// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.

using System.Collections.Immutable;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

namespace StellaOps.BinaryIndex.ML;

/// <summary>
/// ONNX Runtime-based embedding inference engine.
/// </summary>
public sealed class OnnxInferenceEngine : IEmbeddingService, IAsyncDisposable
{
    private readonly InferenceSession? _session;
    private readonly ITokenizer _tokenizer;
    private readonly IEmbeddingIndex? _index;
    private readonly MlOptions _options;
    private readonly ILogger<OnnxInferenceEngine> _logger;
    private readonly TimeProvider _timeProvider;
    private bool _disposed;

    public OnnxInferenceEngine(
        ITokenizer tokenizer,
        IOptions<MlOptions> options,
        ILogger<OnnxInferenceEngine> logger,
        TimeProvider timeProvider,
        IEmbeddingIndex? index = null)
    {
        _tokenizer = tokenizer;
        _options = options.Value;
        _logger = logger;
        _timeProvider = timeProvider;
        _index = index;

        if (!string.IsNullOrEmpty(_options.ModelPath) && File.Exists(_options.ModelPath))
        {
            var sessionOptions = new SessionOptions
            {
                GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL,
                ExecutionMode = ExecutionMode.ORT_PARALLEL,
                InterOpNumThreads = _options.NumThreads,
                IntraOpNumThreads = _options.NumThreads
            };

            _session = new InferenceSession(_options.ModelPath, sessionOptions);
            _logger.LogInformation("Loaded ONNX model from {Path}", _options.ModelPath);
        }
        else
        {
            _logger.LogWarning(
                "No ONNX model found at {Path}, using fallback embedding",
                _options.ModelPath);
        }
    }

    /// <inheritdoc />
    public async Task<FunctionEmbedding> GenerateEmbeddingAsync(
        EmbeddingInput input,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(input);
        ct.ThrowIfCancellationRequested();
        options ??= new EmbeddingOptions();

        var text = GetInputText(input);
        var functionId = ComputeFunctionId(text);

        float[] vector;
        if (_session is not null)
        {
            vector = await RunInferenceAsync(text, options, ct);
        }
        else
        {
            // Fallback: generate a hash-based pseudo-embedding for testing
            vector = GenerateFallbackEmbedding(text, 768);
        }

        if (options.NormalizeVector)
        {
            NormalizeVector(vector);
        }

        return new FunctionEmbedding(
            functionId,
            ExtractFunctionName(text),
            vector,
            EmbeddingModel.CodeBertBinary,
            input.PreferredInput,
            _timeProvider.GetUtcNow());
    }

    /// <inheritdoc />
    public async Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
        IEnumerable<EmbeddingInput> inputs,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(inputs);
        options ??= new EmbeddingOptions();

        var results = new List<FunctionEmbedding>();

        // Process in batches
        var batch = new List<EmbeddingInput>();
        foreach (var input in inputs)
        {
            ct.ThrowIfCancellationRequested();
            batch.Add(input);

            if (batch.Count >= options.BatchSize)
            {
                var batchResults = await ProcessBatchAsync(batch, options, ct);
                results.AddRange(batchResults);
                batch.Clear();
            }
        }

        // Process remaining
        if (batch.Count > 0)
        {
            var batchResults = await ProcessBatchAsync(batch, options, ct);
            results.AddRange(batchResults);
        }

        return [.. results];
    }

    /// <inheritdoc />
    public decimal ComputeSimilarity(
        FunctionEmbedding a,
        FunctionEmbedding b,
        SimilarityMetric metric = SimilarityMetric.Cosine)
    {
        ArgumentNullException.ThrowIfNull(a);
        ArgumentNullException.ThrowIfNull(b);

        if (a.Vector.Length != b.Vector.Length)
        {
            throw new ArgumentException("Embedding vectors must have same dimension");
        }

        return metric switch
        {
            SimilarityMetric.Cosine => CosineSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Euclidean => EuclideanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Manhattan => ManhattanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.LearnedMetric => CosineSimilarity(a.Vector, b.Vector), // Fallback
            _ => throw new ArgumentOutOfRangeException(nameof(metric))
        };
    }

    /// <inheritdoc />
    public async Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
        FunctionEmbedding query,
        int topK = 10,
        decimal minSimilarity = 0.7m,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(query);

        if (_index is null)
        {
            _logger.LogWarning("No embedding index configured, cannot search");
            return [];
        }

        var results = await _index.SearchAsync(query.Vector, topK, ct);

        return results
            .Where(r => r.Similarity >= minSimilarity)
            .Select(r => new EmbeddingMatch(
                r.Embedding.FunctionId,
                r.Embedding.FunctionName,
                r.Similarity,
                null, // Library info would come from metadata
                null))
            .ToImmutableArray();
    }

    private async Task<float[]> RunInferenceAsync(
        string text,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        if (_session is null)
        {
            throw new InvalidOperationException("ONNX session not initialized");
        }

        var (inputIds, attentionMask) = _tokenizer.TokenizeWithMask(text, options.MaxSequenceLength);

        var inputIdsTensor = new DenseTensor<long>(inputIds, [1, inputIds.Length]);
        var attentionMaskTensor = new DenseTensor<long>(attentionMask, [1, attentionMask.Length]);

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
        };

        using var results = await Task.Run(() => _session.Run(inputs), ct);
        var outputTensor = results.First().AsTensor<float>();
        return outputTensor.ToArray();
    }

    private async Task<List<FunctionEmbedding>> ProcessBatchAsync(
        List<EmbeddingInput> batch,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        // For now, process sequentially
        // TODO: Implement true batch inference with batched tensors
        var results = new List<FunctionEmbedding>();
        foreach (var input in batch)
        {
            var embedding = await GenerateEmbeddingAsync(input, options, ct);
            results.Add(embedding);
        }

        return results;
    }

    private static string GetInputText(EmbeddingInput input)
    {
        return input.PreferredInput switch
        {
            EmbeddingInputType.DecompiledCode => input.DecompiledCode
                ?? throw new ArgumentException("DecompiledCode required"),
            EmbeddingInputType.SemanticGraph => SerializeGraph(
                input.SemanticGraph ?? throw new ArgumentException("SemanticGraph required")),
            EmbeddingInputType.Instructions => SerializeInstructions(
                input.InstructionBytes ?? throw new ArgumentException("InstructionBytes required")),
            _ => throw new ArgumentOutOfRangeException()
        };
    }

    private static string SerializeGraph(Semantic.KeySemanticsGraph graph)
    {
        // Convert graph to a textual representation for tokenization
        var sb = new System.Text.StringBuilder();
        sb.AppendLine($"// Graph: {graph.Nodes.Length} nodes");

        foreach (var node in graph.Nodes)
        {
            sb.AppendLine($"node {node.Id}: {node.Operation}");
        }

        foreach (var edge in graph.Edges)
        {
            sb.AppendLine($"edge {edge.SourceId} -> {edge.TargetId}");
        }

        return sb.ToString();
    }

    private static string SerializeInstructions(byte[] bytes)
    {
        // Convert instruction bytes to hex representation
        return Convert.ToHexString(bytes);
    }

    private static string ComputeFunctionId(string text)
    {
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));
        return Convert.ToHexString(hash)[..16];
    }

    private static string ExtractFunctionName(string text)
    {
        // Try to extract the function name from the code
        var match = System.Text.RegularExpressions.Regex.Match(text, @"\b(\w+)\s*\(");
        return match.Success ? match.Groups[1].Value : "unknown";
    }

    private static float[] GenerateFallbackEmbedding(string text, int dimension)
    {
        // Generate a deterministic pseudo-embedding based on the text hash.
        // This is only for testing when no model is available.
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));
        var random = new Random(BitConverter.ToInt32(hash, 0));
        var vector = new float[dimension];
        for (var i = 0; i < dimension; i++)
        {
            vector[i] = (float)(random.NextDouble() * 2 - 1);
        }

        return vector;
    }

    private static void NormalizeVector(float[] vector)
    {
        var norm = 0.0;
        for (var i = 0; i < vector.Length; i++)
        {
            norm += vector[i] * vector[i];
        }

        norm = Math.Sqrt(norm);
        if (norm > 0)
        {
            for (var i = 0; i < vector.Length; i++)
            {
                vector[i] /= (float)norm;
            }
        }
    }

    private static decimal CosineSimilarity(float[] a, float[] b)
    {
        var dotProduct = 0.0;
        var normA = 0.0;
        var normB = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0 || normB == 0)
        {
            return 0;
        }

        var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
        return (decimal)Math.Clamp(similarity, -1.0, 1.0);
    }

    private static decimal EuclideanSimilarity(float[] a, float[] b)
    {
        var sumSquares = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            var diff = a[i] - b[i];
            sumSquares += diff * diff;
        }

        var distance = Math.Sqrt(sumSquares);

        // Convert distance to similarity: distance 0 maps to 1.0, larger distances approach 0
        return (decimal)(1.0 / (1.0 + distance));
    }

    private static decimal ManhattanSimilarity(float[] a, float[] b)
    {
        var sum = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            sum += Math.Abs(a[i] - b[i]);
        }

        // Convert distance to similarity: distance 0 maps to 1.0, larger distances approach 0
        return (decimal)(1.0 / (1.0 + sum));
    }

    public async ValueTask DisposeAsync()
    {
        if (!_disposed)
        {
            _session?.Dispose();
            _disposed = true;
        }

        await Task.CompletedTask;
    }
}
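// Usage sketch (illustrative comment only, not compiled): manual wiring outside of DI.
// Assumptions beyond this file: an ITokenizer implementation named `tokenizer`, an already
// constructed EmbeddingInput named `input`, settable ModelPath/NumThreads on MlOptions, and
// usings for Microsoft.Extensions.Options and Microsoft.Extensions.Logging.Abstractions.
//
//   await using var engine = new OnnxInferenceEngine(
//       tokenizer,
//       Options.Create(new MlOptions { ModelPath = "codebert-binary.onnx", NumThreads = 2 }),
//       NullLogger<OnnxInferenceEngine>.Instance,
//       TimeProvider.System);
//
//   // Falls back to the deterministic hash-based pseudo-embedding when no model file is present.
//   var embedding = await engine.GenerateEmbeddingAsync(input);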