git.stella-ops.org/src/BinaryIndex/__Libraries/StellaOps.BinaryIndex.ML/OnnxInferenceEngine.cs

// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.
using System.Collections.Immutable;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// ONNX Runtime-based embedding inference engine.
/// </summary>
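/// <remarks>
/// Illustrative usage sketch; the <c>tokenizer</c>, <c>mlOptions</c>, <c>logger</c>, <c>input</c>,
/// <c>otherEmbedding</c>, and <c>ct</c> instances are assumed to be available and are not constructed here:
/// <code>
/// await using var engine = new OnnxInferenceEngine(tokenizer, mlOptions, logger, TimeProvider.System);
/// var embedding = await engine.GenerateEmbeddingAsync(input, new EmbeddingOptions(), ct);
/// var similarity = engine.ComputeSimilarity(embedding, otherEmbedding);
/// </code>
/// </remarks>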
public sealed class OnnxInferenceEngine : IEmbeddingService, IAsyncDisposable
{
private readonly InferenceSession? _session;
private readonly ITokenizer _tokenizer;
private readonly IEmbeddingIndex? _index;
private readonly MlOptions _options;
private readonly ILogger<OnnxInferenceEngine> _logger;
private readonly TimeProvider _timeProvider;
private bool _disposed;
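/// <summary>
/// Initializes the engine. When <see cref="MlOptions.ModelPath"/> points at an existing file,
/// an ONNX Runtime session is created; otherwise the engine falls back to deterministic,
/// hash-based pseudo-embeddings intended for testing.
/// </summary>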
public OnnxInferenceEngine(
ITokenizer tokenizer,
IOptions<MlOptions> options,
ILogger<OnnxInferenceEngine> logger,
TimeProvider timeProvider,
IEmbeddingIndex? index = null)
{
_tokenizer = tokenizer;
_options = options.Value;
_logger = logger;
_timeProvider = timeProvider;
_index = index;
if (!string.IsNullOrEmpty(_options.ModelPath) && File.Exists(_options.ModelPath))
{
var sessionOptions = new SessionOptions
{
GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL,
ExecutionMode = ExecutionMode.ORT_PARALLEL,
InterOpNumThreads = _options.NumThreads,
IntraOpNumThreads = _options.NumThreads
};
_session = new InferenceSession(_options.ModelPath, sessionOptions);
_logger.LogInformation(
"Loaded ONNX model from {Path}",
_options.ModelPath);
}
else
{
_logger.LogWarning(
"No ONNX model found at {Path}, using fallback embedding",
_options.ModelPath);
}
}
/// <inheritdoc />
public async Task<FunctionEmbedding> GenerateEmbeddingAsync(
EmbeddingInput input,
EmbeddingOptions? options = null,
CancellationToken ct = default)
{
ArgumentNullException.ThrowIfNull(input);
ct.ThrowIfCancellationRequested();
options ??= new EmbeddingOptions();
var text = GetInputText(input);
var functionId = ComputeFunctionId(text);
float[] vector;
if (_session is not null)
{
vector = await RunInferenceAsync(text, options, ct);
}
else
{
// Fallback: generate hash-based pseudo-embedding for testing
vector = GenerateFallbackEmbedding(text, 768);
}
if (options.NormalizeVector)
{
NormalizeVector(vector);
}
return new FunctionEmbedding(
functionId,
ExtractFunctionName(text),
vector,
EmbeddingModel.CodeBertBinary,
input.PreferredInput,
_timeProvider.GetUtcNow());
}
/// <inheritdoc />
public async Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
IEnumerable<EmbeddingInput> inputs,
EmbeddingOptions? options = null,
CancellationToken ct = default)
{
ArgumentNullException.ThrowIfNull(inputs);
options ??= new EmbeddingOptions();
var results = new List<FunctionEmbedding>();
// Process in batches
var batch = new List<EmbeddingInput>();
foreach (var input in inputs)
{
ct.ThrowIfCancellationRequested();
batch.Add(input);
if (batch.Count >= options.BatchSize)
{
var batchResults = await ProcessBatchAsync(batch, options, ct);
results.AddRange(batchResults);
batch.Clear();
}
}
// Process remaining
if (batch.Count > 0)
{
var batchResults = await ProcessBatchAsync(batch, options, ct);
results.AddRange(batchResults);
}
return [.. results];
}
/// <inheritdoc />
public decimal ComputeSimilarity(
FunctionEmbedding a,
FunctionEmbedding b,
SimilarityMetric metric = SimilarityMetric.Cosine)
{
ArgumentNullException.ThrowIfNull(a);
ArgumentNullException.ThrowIfNull(b);
if (a.Vector.Length != b.Vector.Length)
{
throw new ArgumentException("Embedding vectors must have same dimension");
}
return metric switch
{
SimilarityMetric.Cosine => CosineSimilarity(a.Vector, b.Vector),
SimilarityMetric.Euclidean => EuclideanSimilarity(a.Vector, b.Vector),
SimilarityMetric.Manhattan => ManhattanSimilarity(a.Vector, b.Vector),
SimilarityMetric.LearnedMetric => CosineSimilarity(a.Vector, b.Vector), // Fallback
_ => throw new ArgumentOutOfRangeException(nameof(metric))
};
}
/// <inheritdoc />
public async Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
FunctionEmbedding query,
int topK = 10,
decimal minSimilarity = 0.7m,
CancellationToken ct = default)
{
ArgumentNullException.ThrowIfNull(query);
if (_index is null)
{
_logger.LogWarning("No embedding index configured, cannot search");
return [];
}
var results = await _index.SearchAsync(query.Vector, topK, ct);
return results
.Where(r => r.Similarity >= minSimilarity)
.Select(r => new EmbeddingMatch(
r.Embedding.FunctionId,
r.Embedding.FunctionName,
r.Similarity,
null, // Library info would come from metadata
null))
.ToImmutableArray();
}
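/// <summary>
/// Runs a single forward pass: tokenizes the text into ids and an attention mask, wraps both
/// as [1, sequenceLength] tensors, executes the ONNX session on a thread-pool thread, and
/// flattens the first output tensor into the embedding vector.
/// </summary>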
private async Task<float[]> RunInferenceAsync(
string text,
EmbeddingOptions options,
CancellationToken ct)
{
if (_session is null)
{
throw new InvalidOperationException("ONNX session not initialized");
}
var (inputIds, attentionMask) = _tokenizer.TokenizeWithMask(text, options.MaxSequenceLength);
var inputIdsTensor = new DenseTensor<long>(inputIds, [1, inputIds.Length]);
var attentionMaskTensor = new DenseTensor<long>(attentionMask, [1, attentionMask.Length]);
var inputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
};
using var results = await Task.Run(() => _session.Run(inputs), ct);
var outputTensor = results.First().AsTensor<float>();
return outputTensor.ToArray();
}
private async Task<IEnumerable<FunctionEmbedding>> ProcessBatchAsync(
List<EmbeddingInput> batch,
EmbeddingOptions options,
CancellationToken ct)
{
// For now, process sequentially
// TODO: Implement true batch inference with batched tensors
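// A fully batched path would presumably pad each tokenized sequence to a shared length,
// stack the ids and attention masks into [batchSize, maxLength] tensors, and issue a single
// session run per batch (sketch of the idea only; not implemented here).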
var results = new List<FunctionEmbedding>();
foreach (var input in batch)
{
var embedding = await GenerateEmbeddingAsync(input, options, ct);
results.Add(embedding);
}
return results;
}
private static string GetInputText(EmbeddingInput input)
{
return input.PreferredInput switch
{
EmbeddingInputType.DecompiledCode => input.DecompiledCode
?? throw new ArgumentException("DecompiledCode required"),
EmbeddingInputType.SemanticGraph => SerializeGraph(input.SemanticGraph
?? throw new ArgumentException("SemanticGraph required")),
EmbeddingInputType.Instructions => SerializeInstructions(input.InstructionBytes
?? throw new ArgumentException("InstructionBytes required")),
_ => throw new ArgumentOutOfRangeException(nameof(input), $"Unsupported embedding input type: {input.PreferredInput}")
};
}
private static string SerializeGraph(Semantic.KeySemanticsGraph graph)
{
// Convert graph to textual representation for tokenization
var sb = new System.Text.StringBuilder();
sb.AppendLine($"// Graph: {graph.Nodes.Length} nodes");
foreach (var node in graph.Nodes)
{
sb.AppendLine($"node {node.Id}: {node.Operation}");
}
foreach (var edge in graph.Edges)
{
sb.AppendLine($"edge {edge.SourceId} -> {edge.TargetId}");
}
return sb.ToString();
}
private static string SerializeInstructions(byte[] bytes)
{
// Convert instruction bytes to hex representation
return Convert.ToHexString(bytes);
}
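/// <summary>
/// Derives a stable identifier from the first 8 bytes (16 hex characters) of the SHA-256
/// hash of the input text.
/// </summary>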
private static string ComputeFunctionId(string text)
{
var hash = System.Security.Cryptography.SHA256.HashData(
System.Text.Encoding.UTF8.GetBytes(text));
return Convert.ToHexString(hash)[..16];
}
private static string ExtractFunctionName(string text)
{
// Try to extract function name from code
var match = System.Text.RegularExpressions.Regex.Match(
text,
@"\b(\w+)\s*\(");
return match.Success ? match.Groups[1].Value : "unknown";
}
private static float[] GenerateFallbackEmbedding(string text, int dimension)
{
// Generate a deterministic pseudo-embedding based on text hash
// This is only for testing when no model is available
var hash = System.Security.Cryptography.SHA256.HashData(
System.Text.Encoding.UTF8.GetBytes(text));
var random = new Random(BitConverter.ToInt32(hash, 0));
var vector = new float[dimension];
for (var i = 0; i < dimension; i++)
{
vector[i] = (float)(random.NextDouble() * 2 - 1);
}
return vector;
}
private static void NormalizeVector(float[] vector)
{
var norm = 0.0;
for (var i = 0; i < vector.Length; i++)
{
norm += vector[i] * vector[i];
}
norm = Math.Sqrt(norm);
if (norm > 0)
{
for (var i = 0; i < vector.Length; i++)
{
vector[i] /= (float)norm;
}
}
}
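/// <summary>
/// Cosine similarity: dot(a, b) / (||a|| * ||b||), clamped to [-1, 1]. Returns 0 when either
/// vector has zero magnitude.
/// </summary>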
private static decimal CosineSimilarity(float[] a, float[] b)
{
var dotProduct = 0.0;
var normA = 0.0;
var normB = 0.0;
for (var i = 0; i < a.Length; i++)
{
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
if (normA == 0 || normB == 0)
{
return 0;
}
var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
return (decimal)Math.Clamp(similarity, -1.0, 1.0);
}
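/// <summary>
/// Converts the Euclidean (L2) distance d into a similarity score via 1 / (1 + d), so
/// identical vectors score 1 and the score decays toward 0 as the distance grows.
/// </summary>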
private static decimal EuclideanSimilarity(float[] a, float[] b)
{
var sumSquares = 0.0;
for (var i = 0; i < a.Length; i++)
{
var diff = a[i] - b[i];
sumSquares += diff * diff;
}
var distance = Math.Sqrt(sumSquares);
// Map distance to a similarity in (0, 1]: identical vectors score 1, larger distances approach 0
return (decimal)(1.0 / (1.0 + distance));
}
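/// <summary>
/// Converts the Manhattan (L1) distance into a similarity score via 1 / (1 + distance),
/// mirroring <see cref="EuclideanSimilarity"/>.
/// </summary>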
private static decimal ManhattanSimilarity(float[] a, float[] b)
{
var sum = 0.0;
for (var i = 0; i < a.Length; i++)
{
sum += Math.Abs(a[i] - b[i]);
}
// Convert distance to similarity
return (decimal)(1.0 / (1.0 + sum));
}
public ValueTask DisposeAsync()
{
if (!_disposed)
{
_session?.Dispose();
_disposed = true;
}
return ValueTask.CompletedTask;
}
}