// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.

using System.Collections.Immutable;

using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// ONNX Runtime-based embedding inference engine.
/// </summary>
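/// <example>
/// Illustrative usage sketch only; <c>tokenizer</c>, <c>mlOptions</c>, <c>logger</c>, and the inputs
/// are assumed to be constructed elsewhere and are not defined in this file.
/// <code>
/// await using var engine = new OnnxInferenceEngine(tokenizer, Options.Create(mlOptions), logger, TimeProvider.System);
/// var embedding  = await engine.GenerateEmbeddingAsync(input);
/// var similarity = engine.ComputeSimilarity(embedding, otherEmbedding); // cosine by default
/// </code>
/// </example>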
public sealed class OnnxInferenceEngine : IEmbeddingService, IAsyncDisposable
{
    private readonly InferenceSession? _session;
    private readonly ITokenizer _tokenizer;
    private readonly IEmbeddingIndex? _index;
    private readonly MlOptions _options;
    private readonly ILogger<OnnxInferenceEngine> _logger;
    private readonly TimeProvider _timeProvider;
    private bool _disposed;

    public OnnxInferenceEngine(
        ITokenizer tokenizer,
        IOptions<MlOptions> options,
        ILogger<OnnxInferenceEngine> logger,
        TimeProvider timeProvider,
        IEmbeddingIndex? index = null)
    {
        _tokenizer = tokenizer;
        _options = options.Value;
        _logger = logger;
        _timeProvider = timeProvider;
        _index = index;

        if (!string.IsNullOrEmpty(_options.ModelPath) && File.Exists(_options.ModelPath))
        {
            var sessionOptions = new SessionOptions
            {
                GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL,
                ExecutionMode = ExecutionMode.ORT_PARALLEL,
                InterOpNumThreads = _options.NumThreads,
                IntraOpNumThreads = _options.NumThreads
            };

            _session = new InferenceSession(_options.ModelPath, sessionOptions);
            _logger.LogInformation(
                "Loaded ONNX model from {Path}",
                _options.ModelPath);
        }
        else
        {
            _logger.LogWarning(
                "No ONNX model found at {Path}, using fallback embedding",
                _options.ModelPath);
        }
    }

    /// <inheritdoc />
    public async Task<FunctionEmbedding> GenerateEmbeddingAsync(
        EmbeddingInput input,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(input);
        ct.ThrowIfCancellationRequested();

        options ??= new EmbeddingOptions();

        var text = GetInputText(input);
        var functionId = ComputeFunctionId(text);

        float[] vector;

        if (_session is not null)
        {
            vector = await RunInferenceAsync(text, options, ct);
        }
        else
        {
            // Fallback: generate hash-based pseudo-embedding for testing
            vector = GenerateFallbackEmbedding(text, 768);
        }

        if (options.NormalizeVector)
        {
            NormalizeVector(vector);
        }

        return new FunctionEmbedding(
            functionId,
            ExtractFunctionName(text),
            vector,
            EmbeddingModel.CodeBertBinary,
            input.PreferredInput,
            _timeProvider.GetUtcNow());
    }

    /// <inheritdoc />
    public async Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
        IEnumerable<EmbeddingInput> inputs,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(inputs);

        options ??= new EmbeddingOptions();
        var results = new List<FunctionEmbedding>();

        // Process in batches
        var batch = new List<EmbeddingInput>();
        foreach (var input in inputs)
        {
            ct.ThrowIfCancellationRequested();
            batch.Add(input);

            if (batch.Count >= options.BatchSize)
            {
                var batchResults = await ProcessBatchAsync(batch, options, ct);
                results.AddRange(batchResults);
                batch.Clear();
            }
        }

        // Process remaining
        if (batch.Count > 0)
        {
            var batchResults = await ProcessBatchAsync(batch, options, ct);
            results.AddRange(batchResults);
        }

        return [.. results];
    }

    /// <inheritdoc />
    public decimal ComputeSimilarity(
        FunctionEmbedding a,
        FunctionEmbedding b,
        SimilarityMetric metric = SimilarityMetric.Cosine)
    {
        ArgumentNullException.ThrowIfNull(a);
        ArgumentNullException.ThrowIfNull(b);

        if (a.Vector.Length != b.Vector.Length)
        {
            throw new ArgumentException("Embedding vectors must have the same dimension");
        }

        return metric switch
        {
            SimilarityMetric.Cosine => CosineSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Euclidean => EuclideanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Manhattan => ManhattanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.LearnedMetric => CosineSimilarity(a.Vector, b.Vector), // No learned metric yet; fall back to cosine
            _ => throw new ArgumentOutOfRangeException(nameof(metric))
        };
    }

    /// <inheritdoc />
    public async Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
        FunctionEmbedding query,
        int topK = 10,
        decimal minSimilarity = 0.7m,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(query);

        if (_index is null)
        {
            _logger.LogWarning("No embedding index configured, cannot search");
            return [];
        }

        var results = await _index.SearchAsync(query.Vector, topK, ct);

        return results
            .Where(r => r.Similarity >= minSimilarity)
            .Select(r => new EmbeddingMatch(
                r.Embedding.FunctionId,
                r.Embedding.FunctionName,
                r.Similarity,
                null, // Library info would come from metadata
                null))
            .ToImmutableArray();
    }
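
    // Single-input inference. Note: this assumes the exported model's first output is already a
    // pooled function embedding of shape [1, hidden]. If the model instead emits token-level hidden
    // states ([1, seqLen, hidden]), a pooling step (e.g. a masked mean over the sequence axis) would
    // be required before returning the vector.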
    private async Task<float[]> RunInferenceAsync(
        string text,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        if (_session is null)
        {
            throw new InvalidOperationException("ONNX session not initialized");
        }

        var (inputIds, attentionMask) = _tokenizer.TokenizeWithMask(text, options.MaxSequenceLength);

        var inputIdsTensor = new DenseTensor<long>(inputIds, [1, inputIds.Length]);
        var attentionMaskTensor = new DenseTensor<long>(attentionMask, [1, attentionMask.Length]);

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
        };

        using var results = await Task.Run(() => _session.Run(inputs), ct);

        var outputTensor = results.First().AsTensor<float>();
        return outputTensor.ToArray();
    }
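
    // Batch processing is currently sequential (see the TODO below). A true batched path would pad
    // every tokenized input to a common length and stack the rows into rank-2 tensors, roughly
    // (illustrative sketch only, not part of this implementation):
    //
    //   var ids  = new DenseTensor<long>([batch.Count, maxLen]);   // padded input_ids
    //   var mask = new DenseTensor<long>([batch.Count, maxLen]);   // 1 for real tokens, 0 for padding
    //   // fill row i from _tokenizer.TokenizeWithMask(texts[i], maxLen)
    //   using var results = _session.Run(inputs);                  // one Run call for the whole batch
    //
    // and then slice one embedding per row out of the output tensor.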
    private async Task<IEnumerable<FunctionEmbedding>> ProcessBatchAsync(
        List<EmbeddingInput> batch,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        // For now, process sequentially
        // TODO: Implement true batch inference with batched tensors
        var results = new List<FunctionEmbedding>();
        foreach (var input in batch)
        {
            var embedding = await GenerateEmbeddingAsync(input, options, ct);
            results.Add(embedding);
        }

        return results;
    }

    private static string GetInputText(EmbeddingInput input)
    {
        return input.PreferredInput switch
        {
            EmbeddingInputType.DecompiledCode => input.DecompiledCode
                ?? throw new ArgumentException("DecompiledCode required"),
            EmbeddingInputType.SemanticGraph => SerializeGraph(input.SemanticGraph
                ?? throw new ArgumentException("SemanticGraph required")),
            EmbeddingInputType.Instructions => SerializeInstructions(input.InstructionBytes
                ?? throw new ArgumentException("InstructionBytes required")),
            _ => throw new ArgumentOutOfRangeException()
        };
    }

    private static string SerializeGraph(Semantic.KeySemanticsGraph graph)
    {
        // Convert graph to textual representation for tokenization
        var sb = new System.Text.StringBuilder();
        sb.AppendLine($"// Graph: {graph.Nodes.Length} nodes");

        foreach (var node in graph.Nodes)
        {
            sb.AppendLine($"node {node.Id}: {node.Operation}");
        }

        foreach (var edge in graph.Edges)
        {
            sb.AppendLine($"edge {edge.SourceId} -> {edge.TargetId}");
        }

        return sb.ToString();
    }

    private static string SerializeInstructions(byte[] bytes)
    {
        // Convert instruction bytes to hex representation
        return Convert.ToHexString(bytes);
    }
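
    // Derives a stable, content-based ID: the first 16 hex characters (8 bytes) of the SHA-256 hash
    // of the input text, so identical inputs always map to the same function ID.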
    private static string ComputeFunctionId(string text)
    {
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));
        return Convert.ToHexString(hash)[..16];
    }
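
    // Heuristic only: the regex picks the first identifier that precedes a '(' in the text, which
    // works for typical decompiled signatures but can return a keyword (e.g. "if") or a callee name.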
    private static string ExtractFunctionName(string text)
    {
        // Try to extract function name from code
        var match = System.Text.RegularExpressions.Regex.Match(
            text,
            @"\b(\w+)\s*\(");
        return match.Success ? match.Groups[1].Value : "unknown";
    }
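
    // Test-only fallback: seeding Random from the SHA-256 of the text makes the vector deterministic
    // per input, but it carries no semantic signal, so similarity scores over these vectors are not
    // meaningful beyond exact-match behavior.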
    private static float[] GenerateFallbackEmbedding(string text, int dimension)
    {
        // Generate a deterministic pseudo-embedding based on text hash
        // This is only for testing when no model is available
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));

        var random = new Random(BitConverter.ToInt32(hash, 0));
        var vector = new float[dimension];

        for (var i = 0; i < dimension; i++)
        {
            vector[i] = (float)(random.NextDouble() * 2 - 1);
        }

        return vector;
    }
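
    // In-place L2 normalization: v := v / ||v||, so cosine similarity over normalized vectors
    // reduces to a dot product. Zero-norm vectors are left unchanged.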
    private static void NormalizeVector(float[] vector)
    {
        var norm = 0.0;
        for (var i = 0; i < vector.Length; i++)
        {
            norm += vector[i] * vector[i];
        }

        norm = Math.Sqrt(norm);
        if (norm > 0)
        {
            for (var i = 0; i < vector.Length; i++)
            {
                vector[i] /= (float)norm;
            }
        }
    }
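
    // cos(a, b) = (a · b) / (||a|| * ||b||); the result is clamped to [-1, 1] to absorb
    // floating-point rounding before the decimal conversion.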
    private static decimal CosineSimilarity(float[] a, float[] b)
    {
        var dotProduct = 0.0;
        var normA = 0.0;
        var normB = 0.0;

        for (var i = 0; i < a.Length; i++)
        {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0 || normB == 0)
        {
            return 0;
        }

        var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
        return (decimal)Math.Clamp(similarity, -1.0, 1.0);
    }

    private static decimal EuclideanSimilarity(float[] a, float[] b)
    {
        var sumSquares = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            var diff = a[i] - b[i];
            sumSquares += diff * diff;
        }

        var distance = Math.Sqrt(sumSquares);
        // Map the Euclidean distance (0 = identical, larger = more different) to a similarity in (0, 1]
        return (decimal)(1.0 / (1.0 + distance));
    }

    private static decimal ManhattanSimilarity(float[] a, float[] b)
    {
        var sum = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            sum += Math.Abs(a[i] - b[i]);
        }

        // Map the Manhattan (L1) distance to a similarity in (0, 1]
        return (decimal)(1.0 / (1.0 + sum));
    }

    public async ValueTask DisposeAsync()
    {
        if (!_disposed)
        {
            _session?.Dispose();
            _disposed = true;
        }

        await Task.CompletedTask;
    }
}