save progress

This commit is contained in:
StellaOps Bot
2026-01-06 09:42:02 +02:00
parent 94d68bee8b
commit 37e11918e0
443 changed files with 85863 additions and 897 deletions

View File

@@ -0,0 +1,269 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
using System.Text.RegularExpressions;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Tokenizer for binary/decompiled code using a word-level vocabulary with
/// CodeBERT-style special tokens ([CLS]/[SEP]/[PAD]/[UNK]).
/// </summary>
public sealed partial class BinaryCodeTokenizer : ITokenizer
{
    private readonly ImmutableDictionary<string, long> _vocabulary;

    // Reverse lookup (id -> token), built once in the constructor so Decode
    // does not rebuild it on every call.
    private readonly ImmutableDictionary<long, string> _reverseVocabulary;

    private readonly long _padToken;
    private readonly long _unkToken;
    private readonly long _clsToken;
    private readonly long _sepToken;

    // Special token IDs (matching CodeBERT conventions).
    private const long DefaultPadToken = 0;
    private const long DefaultUnkToken = 1;
    private const long DefaultClsToken = 2;
    private const long DefaultSepToken = 3;

    /// <summary>
    /// Creates a tokenizer. When <paramref name="vocabularyPath"/> points to an
    /// existing file it is loaded (one token per line, line number = token id);
    /// otherwise a small built-in vocabulary intended for testing is used.
    /// </summary>
    /// <param name="vocabularyPath">Optional path to a vocabulary file.</param>
    public BinaryCodeTokenizer(string? vocabularyPath = null)
    {
        if (!string.IsNullOrEmpty(vocabularyPath) && File.Exists(vocabularyPath))
        {
            _vocabulary = LoadVocabulary(vocabularyPath);
            _padToken = _vocabulary.GetValueOrDefault("<pad>", DefaultPadToken);
            _unkToken = _vocabulary.GetValueOrDefault("<unk>", DefaultUnkToken);
            _clsToken = _vocabulary.GetValueOrDefault("<cls>", DefaultClsToken);
            _sepToken = _vocabulary.GetValueOrDefault("<sep>", DefaultSepToken);
        }
        else
        {
            // Use default vocabulary for testing.
            _vocabulary = CreateDefaultVocabulary();
            _padToken = DefaultPadToken;
            _unkToken = DefaultUnkToken;
            _clsToken = DefaultClsToken;
            _sepToken = DefaultSepToken;
        }

        _reverseVocabulary = BuildReverseVocabulary(_vocabulary);
    }

    /// <inheritdoc />
    public long[] Tokenize(string text, int maxLength = 512)
    {
        var (inputIds, _) = TokenizeWithMask(text, maxLength);
        return inputIds;
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentException">Thrown when <paramref name="text"/> is null or empty.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown when <paramref name="maxLength"/> is less than 2; the sequence must
    /// have room for at least [CLS] and [SEP]. (Previously such values caused an
    /// unhelpful <see cref="IndexOutOfRangeException"/>.)
    /// </exception>
    public (long[] InputIds, long[] AttentionMask) TokenizeWithMask(string text, int maxLength = 512)
    {
        ArgumentException.ThrowIfNullOrEmpty(text);
        ArgumentOutOfRangeException.ThrowIfLessThan(maxLength, 2);

        var tokens = TokenizeText(text);
        var inputIds = new long[maxLength];
        var attentionMask = new long[maxLength];

        // Sequence layout: [CLS] tokens... [SEP] [PAD]...
        inputIds[0] = _clsToken;
        attentionMask[0] = 1;

        var position = 1;
        foreach (var token in tokens)
        {
            // Reserve the final slot for [SEP].
            if (position >= maxLength - 1)
            {
                break;
            }

            // Vocabulary keys are lower-case; unmatched tokens map to <unk>.
            inputIds[position] = _vocabulary.GetValueOrDefault(token.ToLowerInvariant(), _unkToken);
            attentionMask[position] = 1;
            position++;
        }

        // Add [SEP] token.
        if (position < maxLength)
        {
            inputIds[position] = _sepToken;
            attentionMask[position] = 1;
            position++;
        }

        // Pad remaining positions; attention mask 0 tells the model to ignore them.
        for (var i = position; i < maxLength; i++)
        {
            inputIds[i] = _padToken;
            attentionMask[i] = 0;
        }

        return (inputIds, attentionMask);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Lossy round-trip: pad/cls/sep tokens are dropped, original casing and
    /// exact whitespace are not restored.
    /// </remarks>
    public string Decode(long[] tokenIds)
    {
        ArgumentNullException.ThrowIfNull(tokenIds);

        var tokens = new List<string>(tokenIds.Length);
        foreach (var id in tokenIds)
        {
            if (id == _padToken || id == _clsToken || id == _sepToken)
            {
                continue;
            }

            tokens.Add(_reverseVocabulary.GetValueOrDefault(id, "<unk>"));
        }

        return string.Join(" ", tokens);
    }

    /// <summary>
    /// Splits normalized text into identifier, number, string-literal, operator
    /// and punctuation tokens via <see cref="TokenRegex"/>.
    /// </summary>
    private IEnumerable<string> TokenizeText(string text)
    {
        // Normalize whitespace.
        text = WhitespaceRegex().Replace(text, " ");

        // Split on operators and punctuation, keeping them as tokens.
        var tokens = new List<string>();
        var matches = TokenRegex().Matches(text);
        foreach (Match match in matches)
        {
            var token = match.Value.Trim();
            if (!string.IsNullOrEmpty(token))
            {
                tokens.Add(token);
            }
        }

        return tokens;
    }

    /// <summary>
    /// Builds the id -&gt; token reverse map. On duplicate ids the first mapping
    /// wins; the previous ToImmutableDictionary(kv =&gt; kv.Value, ...) approach
    /// threw for duplicate ids in a user-supplied vocabulary.
    /// </summary>
    private static ImmutableDictionary<long, string> BuildReverseVocabulary(
        ImmutableDictionary<string, long> vocabulary)
    {
        var reverse = new Dictionary<long, string>(vocabulary.Count);
        foreach (var (token, id) in vocabulary)
        {
            reverse.TryAdd(id, token);
        }

        return reverse.ToImmutableDictionary();
    }

    /// <summary>Loads a one-token-per-line vocabulary; the line index is the token id.</summary>
    private static ImmutableDictionary<string, long> LoadVocabulary(string path)
    {
        var vocabulary = new Dictionary<string, long>();
        var lines = File.ReadAllLines(path);
        for (var i = 0; i < lines.Length; i++)
        {
            var token = lines[i].Trim();
            if (!string.IsNullOrEmpty(token))
            {
                vocabulary[token] = i;
            }
        }

        return vocabulary.ToImmutableDictionary();
    }

    /// <summary>Small built-in vocabulary used when no vocabulary file is supplied (testing only).</summary>
    private static ImmutableDictionary<string, long> CreateDefaultVocabulary()
    {
        // Basic vocabulary for testing without model.
        var vocab = new Dictionary<string, long>
        {
            // Special tokens
            ["<pad>"] = 0,
            ["<unk>"] = 1,
            ["<cls>"] = 2,
            ["<sep>"] = 3,
            // Keywords
            ["void"] = 10,
            ["int"] = 11,
            ["char"] = 12,
            ["short"] = 13,
            ["long"] = 14,
            ["float"] = 15,
            ["double"] = 16,
            ["unsigned"] = 17,
            ["signed"] = 18,
            ["const"] = 19,
            ["static"] = 20,
            ["extern"] = 21,
            ["return"] = 22,
            ["if"] = 23,
            ["else"] = 24,
            ["while"] = 25,
            ["for"] = 26,
            ["do"] = 27,
            ["switch"] = 28,
            ["case"] = 29,
            ["default"] = 30,
            ["break"] = 31,
            ["continue"] = 32,
            ["goto"] = 33,
            ["sizeof"] = 34,
            ["struct"] = 35,
            ["union"] = 36,
            ["enum"] = 37,
            ["typedef"] = 38,
            // Operators
            ["+"] = 50,
            ["-"] = 51,
            ["*"] = 52,
            ["/"] = 53,
            ["%"] = 54,
            ["="] = 55,
            ["=="] = 56,
            ["!="] = 57,
            ["<"] = 58,
            [">"] = 59,
            ["<="] = 60,
            [">="] = 61,
            ["&&"] = 62,
            ["||"] = 63,
            ["!"] = 64,
            ["&"] = 65,
            ["|"] = 66,
            ["^"] = 67,
            ["~"] = 68,
            ["<<"] = 69,
            [">>"] = 70,
            ["++"] = 71,
            ["--"] = 72,
            ["->"] = 73,
            ["."] = 74,
            // Punctuation
            ["("] = 80,
            [")"] = 81,
            ["{"] = 82,
            ["}"] = 83,
            ["["] = 84,
            ["]"] = 85,
            [";"] = 86,
            [","] = 87,
            [":"] = 88,
            // Common Ghidra types
            ["undefined"] = 100,
            ["undefined1"] = 101,
            ["undefined2"] = 102,
            ["undefined4"] = 103,
            ["undefined8"] = 104,
            ["byte"] = 105,
            ["word"] = 106,
            ["dword"] = 107,
            ["qword"] = 108,
            ["bool"] = 109,
            // Common functions
            ["malloc"] = 200,
            ["free"] = 201,
            ["memcpy"] = 202,
            ["memset"] = 203,
            ["strlen"] = 204,
            ["strcpy"] = 205,
            ["strcmp"] = 206,
            ["printf"] = 207,
            ["sprintf"] = 208
        };
        return vocab.ToImmutableDictionary();
    }

    [GeneratedRegex(@"\s+")]
    private static partial Regex WhitespaceRegex();

    [GeneratedRegex(@"([a-zA-Z_][a-zA-Z0-9_]*|0[xX][0-9a-fA-F]+|\d+|""[^""]*""|'[^']*'|[+\-*/%=<>!&|^~]+|[(){}\[\];,.:])")]
    private static partial Regex TokenRegex();
}

View File

@@ -0,0 +1,174 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Service for generating function embedding vectors and comparing them.
/// Implementations may fall back to a deterministic pseudo-embedding when no
/// model is configured (see <c>OnnxInferenceEngine</c> in this assembly).
/// </summary>
public interface IEmbeddingService
{
/// <summary>
/// Generate an embedding vector for a single function.
/// </summary>
/// <param name="input">Function input data; the field selected by its PreferredInput must be populated.</param>
/// <param name="options">Embedding options; implementation defaults are used when null.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>Function embedding with vector.</returns>
Task<FunctionEmbedding> GenerateEmbeddingAsync(
EmbeddingInput input,
EmbeddingOptions? options = null,
CancellationToken ct = default);
/// <summary>
/// Generate embeddings for multiple functions in batch.
/// </summary>
/// <param name="inputs">Function inputs; order of results matches input order.</param>
/// <param name="options">Embedding options; implementation defaults are used when null.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>Function embeddings.</returns>
Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
IEnumerable<EmbeddingInput> inputs,
EmbeddingOptions? options = null,
CancellationToken ct = default);
/// <summary>
/// Compute similarity between two embeddings of equal dimension.
/// </summary>
/// <param name="a">First embedding.</param>
/// <param name="b">Second embedding.</param>
/// <param name="metric">Similarity metric to use.</param>
/// <returns>Similarity score (nominally 0.0 to 1.0; cosine-based implementations may return negative values for opposed vectors).</returns>
decimal ComputeSimilarity(
FunctionEmbedding a,
FunctionEmbedding b,
SimilarityMetric metric = SimilarityMetric.Cosine);
/// <summary>
/// Find similar functions in an embedding index.
/// </summary>
/// <param name="query">Query embedding.</param>
/// <param name="topK">Maximum number of results to return.</param>
/// <param name="minSimilarity">Minimum similarity threshold; matches below it are excluded.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>Matching functions sorted by similarity (may be empty when no index is configured).</returns>
Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
FunctionEmbedding query,
int topK = 10,
decimal minSimilarity = 0.7m,
CancellationToken ct = default);
}
/// <summary>
/// Service for training, evaluating and exporting embedding ML models.
/// </summary>
public interface IModelTrainingService
{
/// <summary>
/// Train an embedding model on labeled function pairs.
/// </summary>
/// <param name="trainingData">Training pairs, streamed to support large datasets.</param>
/// <param name="options">Training options (architecture, epochs, learning rate, ...).</param>
/// <param name="progress">Optional progress reporter invoked during training.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>Training result with final metrics and model path.</returns>
Task<TrainingResult> TrainAsync(
IAsyncEnumerable<TrainingPair> trainingData,
TrainingOptions options,
IProgress<TrainingProgress>? progress = null,
CancellationToken ct = default);
/// <summary>
/// Evaluate the current model on held-out test data.
/// </summary>
/// <param name="testData">Test pairs, streamed.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>Evaluation metrics (accuracy, precision/recall, AUC-ROC, confusion matrix).</returns>
Task<EvaluationResult> EvaluateAsync(
IAsyncEnumerable<TrainingPair> testData,
CancellationToken ct = default);
/// <summary>
/// Export the trained model to the specified format.
/// </summary>
/// <param name="outputPath">Output path for the model file.</param>
/// <param name="format">Export format; defaults to ONNX for cross-platform inference.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>A task that completes when the export has been written.</returns>
Task ExportModelAsync(
string outputPath,
ModelExportFormat format = ModelExportFormat.Onnx,
CancellationToken ct = default);
}
/// <summary>
/// Tokenizer for converting code text to fixed-length token ID sequences
/// suitable for transformer-style model input.
/// </summary>
public interface ITokenizer
{
/// <summary>
/// Tokenize text into token IDs, padded/truncated to <paramref name="maxLength"/>.
/// </summary>
/// <param name="text">Input text.</param>
/// <param name="maxLength">Maximum sequence length; output array has exactly this length.</param>
/// <returns>Token ID array.</returns>
long[] Tokenize(string text, int maxLength = 512);
/// <summary>
/// Tokenize with an attention mask (1 = real token, 0 = padding).
/// </summary>
/// <param name="text">Input text.</param>
/// <param name="maxLength">Maximum sequence length; both arrays have exactly this length.</param>
/// <returns>Token IDs and attention mask.</returns>
(long[] InputIds, long[] AttentionMask) TokenizeWithMask(string text, int maxLength = 512);
/// <summary>
/// Decode token IDs back to text. Implementations may be lossy (e.g. dropping
/// special tokens and original casing).
/// </summary>
/// <param name="tokenIds">Token IDs.</param>
/// <returns>Decoded text.</returns>
string Decode(long[] tokenIds);
}
/// <summary>
/// Index for efficient embedding similarity search. Implementations decide the
/// search strategy (brute force in-memory, vector database, ...).
/// </summary>
public interface IEmbeddingIndex
{
/// <summary>
/// Add an embedding to the index. Behavior for an already-indexed function ID
/// is implementation-defined (the in-memory implementation overwrites).
/// </summary>
/// <param name="embedding">Embedding to add.</param>
/// <param name="ct">Cancellation token.</param>
Task AddAsync(FunctionEmbedding embedding, CancellationToken ct = default);
/// <summary>
/// Add multiple embeddings to the index.
/// </summary>
/// <param name="embeddings">Embeddings to add.</param>
/// <param name="ct">Cancellation token.</param>
Task AddBatchAsync(IEnumerable<FunctionEmbedding> embeddings, CancellationToken ct = default);
/// <summary>
/// Search for the embeddings most similar to the query vector.
/// </summary>
/// <param name="query">Query vector; should match the dimension of indexed embeddings.</param>
/// <param name="topK">Maximum number of results.</param>
/// <param name="ct">Cancellation token.</param>
/// <returns>Similar embeddings with scores, most similar first.</returns>
Task<ImmutableArray<(FunctionEmbedding Embedding, decimal Similarity)>> SearchAsync(
float[] query,
int topK,
CancellationToken ct = default);
/// <summary>
/// Get total count of indexed embeddings.
/// </summary>
int Count { get; }
/// <summary>
/// Clear all embeddings from the index.
/// </summary>
void Clear();
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Concurrent;
using System.Collections.Immutable;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// In-memory embedding index using brute-force cosine similarity search
/// (O(n) per query over all stored vectors). Thread-safe for concurrent
/// add/search via <see cref="ConcurrentDictionary{TKey,TValue}"/>.
/// For production use, consider a vector database like Milvus or Pinecone.
/// </summary>
public sealed class InMemoryEmbeddingIndex : IEmbeddingIndex
{
    // Keyed by FunctionId; re-adding the same id overwrites the prior entry.
    private readonly ConcurrentDictionary<string, FunctionEmbedding> _embeddings = new();

    /// <inheritdoc />
    public int Count => _embeddings.Count;

    /// <inheritdoc />
    public Task AddAsync(FunctionEmbedding embedding, CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(embedding);
        ct.ThrowIfCancellationRequested();
        _embeddings[embedding.FunctionId] = embedding;
        return Task.CompletedTask;
    }

    /// <inheritdoc />
    public Task AddBatchAsync(IEnumerable<FunctionEmbedding> embeddings, CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(embeddings);
        foreach (var embedding in embeddings)
        {
            ct.ThrowIfCancellationRequested();
            _embeddings[embedding.FunctionId] = embedding;
        }

        return Task.CompletedTask;
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="topK"/> is not positive.</exception>
    public Task<ImmutableArray<(FunctionEmbedding Embedding, decimal Similarity)>> SearchAsync(
        float[] query,
        int topK,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(query);
        if (topK <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(topK), "topK must be positive");
        }

        // Score every stored embedding against the query (brute force).
        var similarities = new List<(FunctionEmbedding Embedding, decimal Similarity)>(_embeddings.Count);
        foreach (var embedding in _embeddings.Values)
        {
            // Honor cancellation inside the scan; large indexes can take a while.
            ct.ThrowIfCancellationRequested();

            if (embedding.Vector.Length != query.Length)
            {
                continue; // Skip incompatible dimensions.
            }

            similarities.Add((embedding, CosineSimilarity(query, embedding.Vector)));
        }

        // Sort by similarity descending and take the top K.
        var results = similarities
            .OrderByDescending(s => s.Similarity)
            .Take(topK)
            .ToImmutableArray();

        return Task.FromResult(results);
    }

    /// <inheritdoc />
    public void Clear()
    {
        _embeddings.Clear();
    }

    /// <summary>
    /// Get an embedding by function ID.
    /// </summary>
    /// <param name="functionId">Function identifier.</param>
    /// <returns>Embedding if found, null otherwise.</returns>
    public FunctionEmbedding? Get(string functionId)
    {
        return _embeddings.TryGetValue(functionId, out var embedding) ? embedding : null;
    }

    /// <summary>
    /// Remove an embedding by function ID.
    /// </summary>
    /// <param name="functionId">Function identifier.</param>
    /// <returns>True if removed, false if not found.</returns>
    public bool Remove(string functionId)
    {
        return _embeddings.TryRemove(functionId, out _);
    }

    /// <summary>
    /// Get all embeddings (snapshot semantics of the underlying concurrent dictionary).
    /// </summary>
    /// <returns>All stored embeddings.</returns>
    public IEnumerable<FunctionEmbedding> GetAll()
    {
        return _embeddings.Values;
    }

    /// <summary>
    /// Cosine similarity in double precision, clamped to [-1, 1] before the
    /// decimal conversion to guard against floating-point drift. Returns 0 for
    /// zero-norm vectors (undefined angle).
    /// </summary>
    private static decimal CosineSimilarity(float[] a, float[] b)
    {
        var dotProduct = 0.0;
        var normA = 0.0;
        var normB = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0 || normB == 0)
        {
            return 0;
        }

        var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
        return (decimal)Math.Clamp(similarity, -1.0, 1.0);
    }
}

View File

@@ -0,0 +1,75 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using Microsoft.Extensions.DependencyInjection;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Extension methods for registering the BinaryIndex ML services
/// (tokenizer, embedding index, and ONNX-backed embedding service).
/// </summary>
public static class MlServiceCollectionExtensions
{
    /// <summary>
    /// Adds the default ML embedding services: a singleton
    /// <see cref="BinaryCodeTokenizer"/> (built-in vocabulary), a singleton
    /// <see cref="InMemoryEmbeddingIndex"/>, and a scoped
    /// <see cref="OnnxInferenceEngine"/>.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlServices(this IServiceCollection services)
    {
        ArgumentNullException.ThrowIfNull(services);

        return services
            .AddSingleton<ITokenizer, BinaryCodeTokenizer>()
            .AddSingleton<IEmbeddingIndex, InMemoryEmbeddingIndex>()
            .AddScoped<IEmbeddingService, OnnxInferenceEngine>();
    }

    /// <summary>
    /// Adds the default ML services and applies a configuration callback to
    /// <see cref="MlOptions"/>.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="configureOptions">Action to configure ML options.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlServices(
        this IServiceCollection services,
        Action<MlOptions> configureOptions)
    {
        ArgumentNullException.ThrowIfNull(services);
        ArgumentNullException.ThrowIfNull(configureOptions);

        return services
            .Configure(configureOptions)
            .AddMlServices();
    }

    /// <summary>
    /// Adds the ML services with a tokenizer backed by an on-disk vocabulary
    /// file instead of the built-in vocabulary.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="vocabularyPath">Path to the vocabulary file.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlServicesWithVocabulary(
        this IServiceCollection services,
        string vocabularyPath)
    {
        ArgumentNullException.ThrowIfNull(services);
        ArgumentException.ThrowIfNullOrEmpty(vocabularyPath);

        return services
            .AddSingleton<ITokenizer>(_ => new BinaryCodeTokenizer(vocabularyPath))
            .AddSingleton<IEmbeddingIndex, InMemoryEmbeddingIndex>()
            .AddScoped<IEmbeddingService, OnnxInferenceEngine>();
    }
}

View File

@@ -0,0 +1,259 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
using StellaOps.BinaryIndex.Semantic;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Input for generating function embeddings. Only the field selected by
/// <see cref="PreferredInput"/> needs to be populated; consumers throw when it is null.
/// </summary>
/// <param name="DecompiledCode">Decompiled C-like code if available.</param>
/// <param name="SemanticGraph">Semantic graph from IR analysis if available.</param>
/// <param name="InstructionBytes">Raw instruction bytes if available.</param>
/// <param name="PreferredInput">Which input type to use for embedding generation.</param>
public sealed record EmbeddingInput(
string? DecompiledCode,
KeySemanticsGraph? SemanticGraph,
byte[]? InstructionBytes,
EmbeddingInputType PreferredInput);
/// <summary>
/// Type of input used for embedding generation.
/// </summary>
public enum EmbeddingInputType
{
/// <summary>Use decompiled C-like code (requires <see cref="EmbeddingInput.DecompiledCode"/>).</summary>
DecompiledCode,
/// <summary>Use semantic graph from IR analysis (requires <see cref="EmbeddingInput.SemanticGraph"/>).</summary>
SemanticGraph,
/// <summary>Use raw instruction bytes (requires <see cref="EmbeddingInput.InstructionBytes"/>).</summary>
Instructions
}
/// <summary>
/// A function embedding vector.
/// NOTE(review): <see cref="Vector"/> is a mutable <c>float[]</c> exposed from an
/// immutable record — callers should treat it as read-only.
/// </summary>
/// <param name="FunctionId">Identifier for the function.</param>
/// <param name="FunctionName">Name of the function.</param>
/// <param name="Vector">Embedding vector (typically 768 dimensions).</param>
/// <param name="Model">Model used to generate the embedding.</param>
/// <param name="InputType">Type of input used.</param>
/// <param name="GeneratedAt">When the embedding was generated (UTC offset timestamp).</param>
public sealed record FunctionEmbedding(
string FunctionId,
string FunctionName,
float[] Vector,
EmbeddingModel Model,
EmbeddingInputType InputType,
DateTimeOffset GeneratedAt);
/// <summary>
/// Available embedding model architectures.
/// </summary>
public enum EmbeddingModel
{
/// <summary>Fine-tuned CodeBERT for binary code analysis.</summary>
CodeBertBinary,
/// <summary>Graph neural network for CFG/call graph analysis.</summary>
GraphSageFunction,
/// <summary>Contrastive learning model for function similarity.</summary>
ContrastiveFunction
}
/// <summary>
/// Similarity metrics for comparing embeddings.
/// </summary>
public enum SimilarityMetric
{
/// <summary>Cosine similarity (angle between vectors).</summary>
Cosine,
/// <summary>Euclidean distance, inverted to a similarity score.</summary>
Euclidean,
/// <summary>Manhattan distance, inverted to a similarity score.</summary>
Manhattan,
/// <summary>Learned metric from the model (implementations may fall back to cosine).</summary>
LearnedMetric
}
/// <summary>
/// A match returned from embedding similarity search.
/// </summary>
/// <param name="FunctionId">Matched function identifier.</param>
/// <param name="FunctionName">Matched function name.</param>
/// <param name="Similarity">Similarity score (0.0 to 1.0).</param>
/// <param name="LibraryName">Library containing the function, when known.</param>
/// <param name="LibraryVersion">Version of the library, when known.</param>
public sealed record EmbeddingMatch(
string FunctionId,
string FunctionName,
decimal Similarity,
string? LibraryName,
string? LibraryVersion);
/// <summary>
/// Options controlling embedding generation.
/// </summary>
public sealed record EmbeddingOptions
{
/// <summary>Maximum token sequence length for tokenization (default 512).</summary>
public int MaxSequenceLength { get; init; } = 512;
/// <summary>Whether to L2-normalize the embedding vector after inference (default true).</summary>
public bool NormalizeVector { get; init; } = true;
/// <summary>Batch size used by batch inference (default 32).</summary>
public int BatchSize { get; init; } = 32;
}
/// <summary>
/// A labeled pair of functions used for model training or evaluation.
/// </summary>
/// <param name="FunctionA">First function input.</param>
/// <param name="FunctionB">Second function input.</param>
/// <param name="IsSimilar">Ground truth: are these the same function?</param>
/// <param name="SimilarityScore">Optional fine-grained similarity score.</param>
public sealed record TrainingPair(
EmbeddingInput FunctionA,
EmbeddingInput FunctionB,
bool IsSimilar,
decimal? SimilarityScore);
/// <summary>
/// Options for model training.
/// </summary>
public sealed record TrainingOptions
{
/// <summary>Model architecture to train (default CodeBERT-binary).</summary>
public EmbeddingModel Model { get; init; } = EmbeddingModel.CodeBertBinary;
/// <summary>Embedding vector dimension (default 768, matching BERT-base hidden size).</summary>
public int EmbeddingDimension { get; init; } = 768;
/// <summary>Training batch size (default 32).</summary>
public int BatchSize { get; init; } = 32;
/// <summary>Number of training epochs (default 10).</summary>
public int Epochs { get; init; } = 10;
/// <summary>Optimizer learning rate (default 1e-5).</summary>
public double LearningRate { get; init; } = 1e-5;
/// <summary>Margin for contrastive loss (default 0.5).</summary>
public double MarginLoss { get; init; } = 0.5;
/// <summary>Path to pretrained model weights, or null to train from scratch.</summary>
public string? PretrainedModelPath { get; init; }
/// <summary>Path to save checkpoints, or null to skip checkpointing.</summary>
public string? CheckpointPath { get; init; }
}
/// <summary>
/// Progress update reported during training.
/// </summary>
/// <param name="Epoch">Current epoch (within <paramref name="TotalEpochs"/>).</param>
/// <param name="TotalEpochs">Total epochs.</param>
/// <param name="Batch">Current batch within the epoch.</param>
/// <param name="TotalBatches">Total batches per epoch.</param>
/// <param name="Loss">Current loss value.</param>
/// <param name="Accuracy">Current accuracy.</param>
public sealed record TrainingProgress(
int Epoch,
int TotalEpochs,
int Batch,
int TotalBatches,
double Loss,
double Accuracy);
/// <summary>
/// Result of a completed training run.
/// </summary>
/// <param name="ModelPath">Path to the saved model.</param>
/// <param name="TotalPairs">Number of training pairs used.</param>
/// <param name="Epochs">Number of epochs completed.</param>
/// <param name="FinalLoss">Final loss value.</param>
/// <param name="ValidationAccuracy">Validation accuracy.</param>
/// <param name="TrainingTime">Total wall-clock training time.</param>
public sealed record TrainingResult(
string ModelPath,
int TotalPairs,
int Epochs,
double FinalLoss,
double ValidationAccuracy,
TimeSpan TrainingTime);
/// <summary>
/// Result of model evaluation on held-out data.
/// </summary>
/// <param name="Accuracy">Overall accuracy.</param>
/// <param name="Precision">Precision (true positives / predicted positives).</param>
/// <param name="Recall">Recall (true positives / actual positives).</param>
/// <param name="F1Score">F1 score (harmonic mean of precision and recall).</param>
/// <param name="AucRoc">Area under the ROC curve.</param>
/// <param name="ConfusionMatrix">Confusion matrix entries.</param>
public sealed record EvaluationResult(
double Accuracy,
double Precision,
double Recall,
double F1Score,
double AucRoc,
ImmutableArray<ConfusionEntry> ConfusionMatrix);
/// <summary>
/// One cell of a confusion matrix.
/// </summary>
/// <param name="Predicted">Predicted label.</param>
/// <param name="Actual">Actual label.</param>
/// <param name="Count">Number of occurrences.</param>
public sealed record ConfusionEntry(
string Predicted,
string Actual,
int Count);
/// <summary>
/// Supported model export formats.
/// </summary>
public enum ModelExportFormat
{
/// <summary>ONNX format for cross-platform inference.</summary>
Onnx,
/// <summary>PyTorch format.</summary>
PyTorch,
/// <summary>TensorFlow SavedModel format.</summary>
TensorFlow
}
/// <summary>
/// Configuration options for the ML inference services.
/// </summary>
public sealed record MlOptions
{
/// <summary>Path to the ONNX model file; when null or missing, consumers fall back to a hash-based pseudo-embedding.</summary>
public string? ModelPath { get; init; }
/// <summary>Path to the tokenizer vocabulary; when null, a built-in test vocabulary is used.</summary>
public string? VocabularyPath { get; init; }
/// <summary>Device to use for inference ("cpu" or "cuda"); default "cpu".</summary>
public string Device { get; init; } = "cpu";
/// <summary>Number of threads for inference (both inter- and intra-op); default 4.</summary>
public int NumThreads { get; init; } = 4;
/// <summary>Whether to use GPU if available; default false.</summary>
public bool UseGpu { get; init; } = false;
/// <summary>Maximum batch size for inference; default 32.</summary>
public int MaxBatchSize { get; init; } = 32;
}

View File

@@ -0,0 +1,381 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// ONNX Runtime-based embedding inference engine. When no model file is
/// configured (or the file is missing), a deterministic hash-seeded
/// pseudo-embedding is produced so the pipeline remains usable in tests.
/// </summary>
public sealed class OnnxInferenceEngine : IEmbeddingService, IAsyncDisposable
{
    // Matches the first identifier followed by '(' — a heuristic for the
    // function name in decompiled code. Cached so it is compiled once instead
    // of on every ExtractFunctionName call.
    private static readonly System.Text.RegularExpressions.Regex FunctionNameRegex =
        new(@"\b(\w+)\s*\(", System.Text.RegularExpressions.RegexOptions.Compiled);

    private readonly InferenceSession? _session;
    private readonly ITokenizer _tokenizer;
    private readonly IEmbeddingIndex? _index;
    private readonly MlOptions _options;
    private readonly ILogger<OnnxInferenceEngine> _logger;
    private readonly TimeProvider _timeProvider;
    private bool _disposed;

    /// <summary>
    /// Creates the engine. An <see cref="InferenceSession"/> is opened only when
    /// <see cref="MlOptions.ModelPath"/> points to an existing file.
    /// </summary>
    /// <param name="tokenizer">Tokenizer used to prepare model input.</param>
    /// <param name="options">ML options (model path, thread counts, ...).</param>
    /// <param name="logger">Logger.</param>
    /// <param name="timeProvider">Clock abstraction for embedding timestamps.</param>
    /// <param name="index">Optional embedding index used by <see cref="FindSimilarAsync"/>.</param>
    public OnnxInferenceEngine(
        ITokenizer tokenizer,
        IOptions<MlOptions> options,
        ILogger<OnnxInferenceEngine> logger,
        TimeProvider timeProvider,
        IEmbeddingIndex? index = null)
    {
        ArgumentNullException.ThrowIfNull(tokenizer);
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(logger);
        ArgumentNullException.ThrowIfNull(timeProvider);

        _tokenizer = tokenizer;
        _options = options.Value;
        _logger = logger;
        _timeProvider = timeProvider;
        _index = index;

        if (!string.IsNullOrEmpty(_options.ModelPath) && File.Exists(_options.ModelPath))
        {
            var sessionOptions = new SessionOptions
            {
                GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL,
                ExecutionMode = ExecutionMode.ORT_PARALLEL,
                InterOpNumThreads = _options.NumThreads,
                IntraOpNumThreads = _options.NumThreads
            };
            _session = new InferenceSession(_options.ModelPath, sessionOptions);
            _logger.LogInformation(
                "Loaded ONNX model from {Path}",
                _options.ModelPath);
        }
        else
        {
            _logger.LogWarning(
                "No ONNX model found at {Path}, using fallback embedding",
                _options.ModelPath);
        }
    }

    /// <inheritdoc />
    /// <exception cref="ObjectDisposedException">Thrown when the engine has been disposed.</exception>
    public async Task<FunctionEmbedding> GenerateEmbeddingAsync(
        EmbeddingInput input,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(input);
        ObjectDisposedException.ThrowIf(_disposed, this);
        ct.ThrowIfCancellationRequested();

        options ??= new EmbeddingOptions();
        var text = GetInputText(input);
        var functionId = ComputeFunctionId(text);

        float[] vector;
        if (_session is not null)
        {
            vector = await RunInferenceAsync(text, options, ct).ConfigureAwait(false);
        }
        else
        {
            // Fallback: generate hash-based pseudo-embedding for testing.
            vector = GenerateFallbackEmbedding(text, 768);
        }

        if (options.NormalizeVector)
        {
            NormalizeVector(vector);
        }

        return new FunctionEmbedding(
            functionId,
            ExtractFunctionName(text),
            vector,
            EmbeddingModel.CodeBertBinary,
            input.PreferredInput,
            _timeProvider.GetUtcNow());
    }

    /// <inheritdoc />
    /// <exception cref="ObjectDisposedException">Thrown when the engine has been disposed.</exception>
    public async Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
        IEnumerable<EmbeddingInput> inputs,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(inputs);
        ObjectDisposedException.ThrowIf(_disposed, this);

        options ??= new EmbeddingOptions();
        var results = new List<FunctionEmbedding>();

        // Accumulate inputs into chunks of options.BatchSize, flushing each full chunk.
        var batch = new List<EmbeddingInput>();
        foreach (var input in inputs)
        {
            ct.ThrowIfCancellationRequested();
            batch.Add(input);
            if (batch.Count >= options.BatchSize)
            {
                var batchResults = await ProcessBatchAsync(batch, options, ct).ConfigureAwait(false);
                results.AddRange(batchResults);
                batch.Clear();
            }
        }

        // Flush the final partial chunk.
        if (batch.Count > 0)
        {
            var batchResults = await ProcessBatchAsync(batch, options, ct).ConfigureAwait(false);
            results.AddRange(batchResults);
        }

        return [.. results];
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentException">Thrown when the embedding dimensions differ.</exception>
    public decimal ComputeSimilarity(
        FunctionEmbedding a,
        FunctionEmbedding b,
        SimilarityMetric metric = SimilarityMetric.Cosine)
    {
        ArgumentNullException.ThrowIfNull(a);
        ArgumentNullException.ThrowIfNull(b);
        if (a.Vector.Length != b.Vector.Length)
        {
            throw new ArgumentException("Embedding vectors must have same dimension");
        }

        return metric switch
        {
            SimilarityMetric.Cosine => CosineSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Euclidean => EuclideanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Manhattan => ManhattanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.LearnedMetric => CosineSimilarity(a.Vector, b.Vector), // No learned metric yet; cosine fallback.
            _ => throw new ArgumentOutOfRangeException(nameof(metric))
        };
    }

    /// <inheritdoc />
    /// <exception cref="ObjectDisposedException">Thrown when the engine has been disposed.</exception>
    public async Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
        FunctionEmbedding query,
        int topK = 10,
        decimal minSimilarity = 0.7m,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(query);
        ObjectDisposedException.ThrowIf(_disposed, this);

        if (_index is null)
        {
            _logger.LogWarning("No embedding index configured, cannot search");
            return [];
        }

        var results = await _index.SearchAsync(query.Vector, topK, ct).ConfigureAwait(false);
        return results
            .Where(r => r.Similarity >= minSimilarity)
            .Select(r => new EmbeddingMatch(
                r.Embedding.FunctionId,
                r.Embedding.FunctionName,
                r.Similarity,
                null, // Library info would come from metadata.
                null))
            .ToImmutableArray();
    }

    /// <summary>
    /// Runs the ONNX session on a single tokenized input and returns the raw
    /// output tensor flattened to a float array.
    /// NOTE(review): this takes the first output tensor as-is; whether pooling
    /// is needed depends on the exported model — confirm against the model's
    /// output signature.
    /// </summary>
    private async Task<float[]> RunInferenceAsync(
        string text,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        if (_session is null)
        {
            throw new InvalidOperationException("ONNX session not initialized");
        }

        var (inputIds, attentionMask) = _tokenizer.TokenizeWithMask(text, options.MaxSequenceLength);
        var inputIdsTensor = new DenseTensor<long>(inputIds, [1, inputIds.Length]);
        var attentionMaskTensor = new DenseTensor<long>(attentionMask, [1, attentionMask.Length]);
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
        };

        // Session.Run is synchronous CPU-bound work; offload so callers are not blocked.
        using var results = await Task.Run(() => _session.Run(inputs), ct).ConfigureAwait(false);
        var outputTensor = results.First().AsTensor<float>();
        return outputTensor.ToArray();
    }

    /// <summary>Processes one chunk of inputs sequentially (one embedding per input, input order preserved).</summary>
    private async Task<IEnumerable<FunctionEmbedding>> ProcessBatchAsync(
        List<EmbeddingInput> batch,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        // For now, process sequentially.
        // TODO: Implement true batch inference with batched tensors.
        var results = new List<FunctionEmbedding>(batch.Count);
        foreach (var input in batch)
        {
            var embedding = await GenerateEmbeddingAsync(input, options, ct).ConfigureAwait(false);
            results.Add(embedding);
        }

        return results;
    }

    /// <summary>
    /// Resolves the text representation for the selected input type; throws
    /// when the corresponding field is null.
    /// </summary>
    private static string GetInputText(EmbeddingInput input)
    {
        return input.PreferredInput switch
        {
            EmbeddingInputType.DecompiledCode => input.DecompiledCode
                ?? throw new ArgumentException("DecompiledCode required"),
            EmbeddingInputType.SemanticGraph => SerializeGraph(input.SemanticGraph
                ?? throw new ArgumentException("SemanticGraph required")),
            EmbeddingInputType.Instructions => SerializeInstructions(input.InstructionBytes
                ?? throw new ArgumentException("InstructionBytes required")),
            _ => throw new ArgumentOutOfRangeException(
                nameof(input), input.PreferredInput, "Unknown embedding input type")
        };
    }

    /// <summary>Converts a semantic graph to a line-based textual form for tokenization.</summary>
    private static string SerializeGraph(Semantic.KeySemanticsGraph graph)
    {
        var sb = new System.Text.StringBuilder();
        sb.AppendLine($"// Graph: {graph.Nodes.Length} nodes");
        foreach (var node in graph.Nodes)
        {
            sb.AppendLine($"node {node.Id}: {node.Operation}");
        }

        foreach (var edge in graph.Edges)
        {
            sb.AppendLine($"edge {edge.SourceId} -> {edge.TargetId}");
        }

        return sb.ToString();
    }

    /// <summary>Converts raw instruction bytes to an uppercase hex string.</summary>
    private static string SerializeInstructions(byte[] bytes)
    {
        return Convert.ToHexString(bytes);
    }

    /// <summary>Stable function id: first 16 hex chars of the SHA-256 of the input text.</summary>
    private static string ComputeFunctionId(string text)
    {
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));
        return Convert.ToHexString(hash)[..16];
    }

    /// <summary>Heuristic name extraction: first identifier followed by '(' , else "unknown".</summary>
    private static string ExtractFunctionName(string text)
    {
        var match = FunctionNameRegex.Match(text);
        return match.Success ? match.Groups[1].Value : "unknown";
    }

    /// <summary>
    /// Generates a deterministic pseudo-embedding seeded from the text's
    /// SHA-256 hash. Only for testing when no model is available.
    /// </summary>
    private static float[] GenerateFallbackEmbedding(string text, int dimension)
    {
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));
        var random = new Random(BitConverter.ToInt32(hash, 0));
        var vector = new float[dimension];
        for (var i = 0; i < dimension; i++)
        {
            vector[i] = (float)(random.NextDouble() * 2 - 1);
        }

        return vector;
    }

    /// <summary>L2-normalizes the vector in place; zero vectors are left unchanged.</summary>
    private static void NormalizeVector(float[] vector)
    {
        var norm = 0.0;
        for (var i = 0; i < vector.Length; i++)
        {
            norm += vector[i] * vector[i];
        }

        norm = Math.Sqrt(norm);
        if (norm > 0)
        {
            for (var i = 0; i < vector.Length; i++)
            {
                vector[i] /= (float)norm;
            }
        }
    }

    /// <summary>Cosine similarity clamped to [-1, 1]; 0 for zero-norm vectors.</summary>
    private static decimal CosineSimilarity(float[] a, float[] b)
    {
        var dotProduct = 0.0;
        var normA = 0.0;
        var normB = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0 || normB == 0)
        {
            return 0;
        }

        var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
        return (decimal)Math.Clamp(similarity, -1.0, 1.0);
    }

    /// <summary>Euclidean distance mapped to (0, 1] via 1 / (1 + d).</summary>
    private static decimal EuclideanSimilarity(float[] a, float[] b)
    {
        var sumSquares = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            var diff = a[i] - b[i];
            sumSquares += diff * diff;
        }

        var distance = Math.Sqrt(sumSquares);
        return (decimal)(1.0 / (1.0 + distance));
    }

    /// <summary>Manhattan distance mapped to (0, 1] via 1 / (1 + d).</summary>
    private static decimal ManhattanSimilarity(float[] a, float[] b)
    {
        var sum = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            sum += Math.Abs(a[i] - b[i]);
        }

        return (decimal)(1.0 / (1.0 + sum));
    }

    /// <summary>Disposes the ONNX session (idempotent). No awaiting is required, so this completes synchronously.</summary>
    public ValueTask DisposeAsync()
    {
        if (!_disposed)
        {
            _session?.Dispose();
            _disposed = true;
        }

        return ValueTask.CompletedTask;
    }
}

View File

@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<Description>Machine learning-based function similarity using embeddings and ONNX inference for BinaryIndex.</Description>
</PropertyGroup>
<!-- Sibling BinaryIndex projects: Decompiler supplies decompiled-code inputs,
     Semantic supplies KeySemanticsGraph used by the embedding models. -->
<ItemGroup>
<ProjectReference Include="..\StellaOps.BinaryIndex.Decompiler\StellaOps.BinaryIndex.Decompiler.csproj" />
<ProjectReference Include="..\StellaOps.BinaryIndex.Semantic\StellaOps.BinaryIndex.Semantic.csproj" />
</ItemGroup>
<!-- Versions are centrally managed (no Version attributes here — presumably
     Directory.Packages.props; verify). Microsoft.ML.OnnxRuntime additionally
     requires its native runtime package at deployment — TODO confirm. -->
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Options" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" />
</ItemGroup>
</Project>
</Project>