save progress
This commit is contained in:
@@ -0,0 +1,269 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using System.Collections.Immutable;
|
||||
using System.Text.RegularExpressions;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML;
|
||||
|
||||
/// <summary>
/// Tokenizer for binary/decompiled code using byte-pair encoding style tokenization.
/// Produces CodeBERT-style sequences: a leading CLS token, the text tokens,
/// a SEP token, then PAD tokens up to the requested length.
/// </summary>
public sealed partial class BinaryCodeTokenizer : ITokenizer
{
    private readonly ImmutableDictionary<string, long> _vocabulary;

    // Reverse (id -> token) lookup used only by Decode. Built lazily and exactly
    // once; previously Decode rebuilt an O(vocabulary) dictionary on every call.
    private readonly Lazy<ImmutableDictionary<long, string>> _reverseVocabulary;

    private readonly long _padToken;
    private readonly long _unkToken;
    private readonly long _clsToken;
    private readonly long _sepToken;

    // Special token IDs (matching CodeBERT conventions)
    private const long DefaultPadToken = 0;
    private const long DefaultUnkToken = 1;
    private const long DefaultClsToken = 2;
    private const long DefaultSepToken = 3;

    /// <summary>
    /// Creates a tokenizer backed by the vocabulary at <paramref name="vocabularyPath"/>,
    /// or by a small built-in vocabulary (intended for tests) when the path is null,
    /// empty, or does not exist.
    /// </summary>
    /// <param name="vocabularyPath">Optional path to a one-token-per-line vocabulary file.</param>
    public BinaryCodeTokenizer(string? vocabularyPath = null)
    {
        if (!string.IsNullOrEmpty(vocabularyPath) && File.Exists(vocabularyPath))
        {
            _vocabulary = LoadVocabulary(vocabularyPath);

            // Fall back to CodeBERT-convention ids when the file omits a special token.
            _padToken = _vocabulary.GetValueOrDefault("<pad>", DefaultPadToken);
            _unkToken = _vocabulary.GetValueOrDefault("<unk>", DefaultUnkToken);
            _clsToken = _vocabulary.GetValueOrDefault("<cls>", DefaultClsToken);
            _sepToken = _vocabulary.GetValueOrDefault("<sep>", DefaultSepToken);
        }
        else
        {
            // Use default vocabulary for testing
            _vocabulary = CreateDefaultVocabulary();
            _padToken = DefaultPadToken;
            _unkToken = DefaultUnkToken;
            _clsToken = DefaultClsToken;
            _sepToken = DefaultSepToken;
        }

        _reverseVocabulary = new Lazy<ImmutableDictionary<long, string>>(BuildReverseVocabulary);
    }

    /// <inheritdoc />
    public long[] Tokenize(string text, int maxLength = 512)
    {
        var (inputIds, _) = TokenizeWithMask(text, maxLength);
        return inputIds;
    }

    /// <inheritdoc />
    public (long[] InputIds, long[] AttentionMask) TokenizeWithMask(string text, int maxLength = 512)
    {
        ArgumentException.ThrowIfNullOrEmpty(text);

        // The sequence always carries CLS and (when room allows) SEP. A length
        // below 2 cannot hold them; the original code indexed position 0
        // unconditionally and threw IndexOutOfRangeException for maxLength <= 0.
        if (maxLength < 2)
        {
            throw new ArgumentOutOfRangeException(
                nameof(maxLength),
                maxLength,
                "maxLength must be at least 2 to hold the CLS and SEP markers.");
        }

        var tokens = TokenizeText(text);
        var inputIds = new long[maxLength];
        var attentionMask = new long[maxLength];

        // Add [CLS] token
        inputIds[0] = _clsToken;
        attentionMask[0] = 1;

        var position = 1;
        foreach (var token in tokens)
        {
            // Reserve the final slot for [SEP]; tokens past that are truncated.
            if (position >= maxLength - 1)
            {
                break;
            }

            // Lookup is case-insensitive; unknown tokens map to <unk>.
            inputIds[position] = _vocabulary.GetValueOrDefault(token.ToLowerInvariant(), _unkToken);
            attentionMask[position] = 1;
            position++;
        }

        // Add [SEP] token
        if (position < maxLength)
        {
            inputIds[position] = _sepToken;
            attentionMask[position] = 1;
            position++;
        }

        // Pad remaining positions (mask 0 so attention ignores them)
        for (var i = position; i < maxLength; i++)
        {
            inputIds[i] = _padToken;
            attentionMask[i] = 0;
        }

        return (inputIds, attentionMask);
    }

    /// <inheritdoc />
    public string Decode(long[] tokenIds)
    {
        ArgumentNullException.ThrowIfNull(tokenIds);

        var reverseVocab = _reverseVocabulary.Value;
        var tokens = new List<string>();

        foreach (var id in tokenIds)
        {
            // Structural markers carry no text; skip them.
            if (id == _padToken || id == _clsToken || id == _sepToken)
            {
                continue;
            }

            tokens.Add(reverseVocab.GetValueOrDefault(id, "<unk>"));
        }

        return string.Join(" ", tokens);
    }

    /// <summary>
    /// Builds the id -> token map used by <see cref="Decode"/>. Built defensively:
    /// a malformed vocabulary could assign one id to two tokens, and a straight
    /// ToImmutableDictionary(kv => kv.Value, ...) would throw on the duplicate key.
    /// First token seen for an id wins.
    /// </summary>
    private ImmutableDictionary<long, string> BuildReverseVocabulary()
    {
        var builder = ImmutableDictionary.CreateBuilder<long, string>();

        foreach (var (token, id) in _vocabulary)
        {
            if (!builder.ContainsKey(id))
            {
                builder.Add(id, token);
            }
        }

        return builder.ToImmutable();
    }

    /// <summary>
    /// Split text into candidate tokens: identifiers, hex/decimal literals,
    /// string/char literals, operator runs, and punctuation.
    /// </summary>
    private IEnumerable<string> TokenizeText(string text)
    {
        // Normalize whitespace
        text = WhitespaceRegex().Replace(text, " ");

        // Split on operators and punctuation, keeping them as tokens
        var tokens = new List<string>();
        var matches = TokenRegex().Matches(text);

        foreach (Match match in matches)
        {
            var token = match.Value.Trim();
            if (!string.IsNullOrEmpty(token))
            {
                tokens.Add(token);
            }
        }

        return tokens;
    }

    /// <summary>
    /// Load a vocabulary file where the 0-based line number is the token id.
    /// Blank lines are skipped but still consume an id, preserving alignment
    /// with the model that produced the file.
    /// </summary>
    private static ImmutableDictionary<string, long> LoadVocabulary(string path)
    {
        var vocabulary = new Dictionary<string, long>();
        var lines = File.ReadAllLines(path);

        for (var i = 0; i < lines.Length; i++)
        {
            var token = lines[i].Trim();
            if (!string.IsNullOrEmpty(token))
            {
                vocabulary[token] = i;
            }
        }

        return vocabulary.ToImmutableDictionary();
    }

    private static ImmutableDictionary<string, long> CreateDefaultVocabulary()
    {
        // Basic vocabulary for testing without model
        var vocab = new Dictionary<string, long>
        {
            // Special tokens
            ["<pad>"] = 0,
            ["<unk>"] = 1,
            ["<cls>"] = 2,
            ["<sep>"] = 3,

            // Keywords
            ["void"] = 10,
            ["int"] = 11,
            ["char"] = 12,
            ["short"] = 13,
            ["long"] = 14,
            ["float"] = 15,
            ["double"] = 16,
            ["unsigned"] = 17,
            ["signed"] = 18,
            ["const"] = 19,
            ["static"] = 20,
            ["extern"] = 21,
            ["return"] = 22,
            ["if"] = 23,
            ["else"] = 24,
            ["while"] = 25,
            ["for"] = 26,
            ["do"] = 27,
            ["switch"] = 28,
            ["case"] = 29,
            ["default"] = 30,
            ["break"] = 31,
            ["continue"] = 32,
            ["goto"] = 33,
            ["sizeof"] = 34,
            ["struct"] = 35,
            ["union"] = 36,
            ["enum"] = 37,
            ["typedef"] = 38,

            // Operators
            ["+"] = 50,
            ["-"] = 51,
            ["*"] = 52,
            ["/"] = 53,
            ["%"] = 54,
            ["="] = 55,
            ["=="] = 56,
            ["!="] = 57,
            ["<"] = 58,
            [">"] = 59,
            ["<="] = 60,
            [">="] = 61,
            ["&&"] = 62,
            ["||"] = 63,
            ["!"] = 64,
            ["&"] = 65,
            ["|"] = 66,
            ["^"] = 67,
            ["~"] = 68,
            ["<<"] = 69,
            [">>"] = 70,
            ["++"] = 71,
            ["--"] = 72,
            ["->"] = 73,
            ["."] = 74,

            // Punctuation
            ["("] = 80,
            [")"] = 81,
            ["{"] = 82,
            ["}"] = 83,
            ["["] = 84,
            ["]"] = 85,
            [";"] = 86,
            [","] = 87,
            [":"] = 88,

            // Common Ghidra types
            ["undefined"] = 100,
            ["undefined1"] = 101,
            ["undefined2"] = 102,
            ["undefined4"] = 103,
            ["undefined8"] = 104,
            ["byte"] = 105,
            ["word"] = 106,
            ["dword"] = 107,
            ["qword"] = 108,
            ["bool"] = 109,

            // Common functions
            ["malloc"] = 200,
            ["free"] = 201,
            ["memcpy"] = 202,
            ["memset"] = 203,
            ["strlen"] = 204,
            ["strcpy"] = 205,
            ["strcmp"] = 206,
            ["printf"] = 207,
            ["sprintf"] = 208
        };

        return vocab.ToImmutableDictionary();
    }

    [GeneratedRegex(@"\s+")]
    private static partial Regex WhitespaceRegex();

    [GeneratedRegex(@"([a-zA-Z_][a-zA-Z0-9_]*|0[xX][0-9a-fA-F]+|\d+|""[^""]*""|'[^']*'|[+\-*/%=<>!&|^~]+|[(){}\[\];,.:])")]
    private static partial Regex TokenRegex();
}
|
||||
@@ -0,0 +1,174 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using System.Collections.Immutable;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML;
|
||||
|
||||
/// <summary>
/// Service for generating and comparing function embeddings.
/// </summary>
public interface IEmbeddingService
{
    /// <summary>
    /// Generate embedding vector for a function.
    /// </summary>
    /// <param name="input">Function input data.</param>
    /// <param name="options">Embedding options; null lets the implementation use its defaults.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Function embedding with vector.</returns>
    Task<FunctionEmbedding> GenerateEmbeddingAsync(
        EmbeddingInput input,
        EmbeddingOptions? options = null,
        CancellationToken ct = default);

    /// <summary>
    /// Generate embeddings for multiple functions in batch.
    /// </summary>
    /// <param name="inputs">Function inputs.</param>
    /// <param name="options">Embedding options; null lets the implementation use its defaults.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Function embeddings.</returns>
    Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
        IEnumerable<EmbeddingInput> inputs,
        EmbeddingOptions? options = null,
        CancellationToken ct = default);

    /// <summary>
    /// Compute similarity between two embeddings.
    /// </summary>
    /// <param name="a">First embedding.</param>
    /// <param name="b">Second embedding.</param>
    /// <param name="metric">Similarity metric to use.</param>
    /// <returns>Similarity score (0.0 to 1.0).</returns>
    decimal ComputeSimilarity(
        FunctionEmbedding a,
        FunctionEmbedding b,
        SimilarityMetric metric = SimilarityMetric.Cosine);

    /// <summary>
    /// Find similar functions in an embedding index.
    /// </summary>
    /// <param name="query">Query embedding.</param>
    /// <param name="topK">Number of results to return.</param>
    /// <param name="minSimilarity">Minimum similarity threshold (presumably matches below it are excluded — confirm with implementations).</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Matching functions sorted by similarity.</returns>
    Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
        FunctionEmbedding query,
        int topK = 10,
        decimal minSimilarity = 0.7m,
        CancellationToken ct = default);
}
|
||||
|
||||
/// <summary>
/// Service for training ML models.
/// </summary>
public interface IModelTrainingService
{
    /// <summary>
    /// Train embedding model on function pairs.
    /// </summary>
    /// <param name="trainingData">Training pairs, streamed asynchronously.</param>
    /// <param name="options">Training options (architecture, epochs, learning rate, paths).</param>
    /// <param name="progress">Optional progress reporter invoked during training.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Training result.</returns>
    Task<TrainingResult> TrainAsync(
        IAsyncEnumerable<TrainingPair> trainingData,
        TrainingOptions options,
        IProgress<TrainingProgress>? progress = null,
        CancellationToken ct = default);

    /// <summary>
    /// Evaluate model on test data.
    /// </summary>
    /// <param name="testData">Test pairs, streamed asynchronously.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Evaluation metrics.</returns>
    Task<EvaluationResult> EvaluateAsync(
        IAsyncEnumerable<TrainingPair> testData,
        CancellationToken ct = default);

    /// <summary>
    /// Export trained model to specified format.
    /// </summary>
    /// <param name="outputPath">Output path for model.</param>
    /// <param name="format">Export format; defaults to ONNX.</param>
    /// <param name="ct">Cancellation token.</param>
    Task ExportModelAsync(
        string outputPath,
        ModelExportFormat format = ModelExportFormat.Onnx,
        CancellationToken ct = default);
}
|
||||
|
||||
/// <summary>
/// Tokenizer for converting code to token sequences.
/// </summary>
public interface ITokenizer
{
    /// <summary>
    /// Tokenize text into token IDs.
    /// </summary>
    /// <param name="text">Input text.</param>
    /// <param name="maxLength">Maximum sequence length; output arrays have exactly this length.</param>
    /// <returns>Token ID array.</returns>
    long[] Tokenize(string text, int maxLength = 512);

    /// <summary>
    /// Tokenize with attention mask.
    /// </summary>
    /// <param name="text">Input text.</param>
    /// <param name="maxLength">Maximum sequence length; output arrays have exactly this length.</param>
    /// <returns>Token IDs and attention mask (1 for real tokens, 0 for padding).</returns>
    (long[] InputIds, long[] AttentionMask) TokenizeWithMask(string text, int maxLength = 512);

    /// <summary>
    /// Decode token IDs back to text. Round-tripping is lossy: implementations
    /// may drop padding/marker tokens and normalize whitespace.
    /// </summary>
    /// <param name="tokenIds">Token IDs.</param>
    /// <returns>Decoded text.</returns>
    string Decode(long[] tokenIds);
}
|
||||
|
||||
/// <summary>
/// Index for efficient embedding similarity search.
/// </summary>
/// <remarks>
/// NOTE(review): this contract does not specify thread-safety; confirm with the
/// concrete implementation before sharing an index across threads.
/// </remarks>
public interface IEmbeddingIndex
{
    /// <summary>
    /// Add embedding to index.
    /// </summary>
    /// <param name="embedding">Embedding to add.</param>
    /// <param name="ct">Cancellation token.</param>
    Task AddAsync(FunctionEmbedding embedding, CancellationToken ct = default);

    /// <summary>
    /// Add multiple embeddings to index.
    /// </summary>
    /// <param name="embeddings">Embeddings to add.</param>
    /// <param name="ct">Cancellation token.</param>
    Task AddBatchAsync(IEnumerable<FunctionEmbedding> embeddings, CancellationToken ct = default);

    /// <summary>
    /// Search for similar embeddings.
    /// </summary>
    /// <param name="query">Query vector.</param>
    /// <param name="topK">Number of results.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Similar embeddings with scores.</returns>
    Task<ImmutableArray<(FunctionEmbedding Embedding, decimal Similarity)>> SearchAsync(
        float[] query,
        int topK,
        CancellationToken ct = default);

    /// <summary>
    /// Get total count of indexed embeddings.
    /// </summary>
    int Count { get; }

    /// <summary>
    /// Clear all embeddings from index.
    /// </summary>
    void Clear();
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Immutable;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML;
|
||||
|
||||
/// <summary>
/// In-memory embedding index using brute-force cosine similarity search.
/// For production use, consider using a vector database like Milvus or Pinecone.
/// </summary>
public sealed class InMemoryEmbeddingIndex : IEmbeddingIndex
{
    // Keyed by FunctionId. ConcurrentDictionary already provides thread-safe
    // add/remove/lookup, so no extra locking is needed (an unused lock object
    // was removed from the original).
    private readonly ConcurrentDictionary<string, FunctionEmbedding> _embeddings = new();

    /// <inheritdoc />
    public int Count => _embeddings.Count;

    /// <inheritdoc />
    public Task AddAsync(FunctionEmbedding embedding, CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(embedding);
        ct.ThrowIfCancellationRequested();

        // Last write wins for a given FunctionId.
        _embeddings[embedding.FunctionId] = embedding;
        return Task.CompletedTask;
    }

    /// <inheritdoc />
    public Task AddBatchAsync(IEnumerable<FunctionEmbedding> embeddings, CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(embeddings);

        foreach (var embedding in embeddings)
        {
            ct.ThrowIfCancellationRequested();
            _embeddings[embedding.FunctionId] = embedding;
        }

        return Task.CompletedTask;
    }

    /// <inheritdoc />
    public Task<ImmutableArray<(FunctionEmbedding Embedding, decimal Similarity)>> SearchAsync(
        float[] query,
        int topK,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(query);
        if (topK <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(topK), "topK must be positive");
        }

        ct.ThrowIfCancellationRequested();

        // Calculate similarity for all embeddings (brute force: O(N * D)).
        var similarities = new List<(FunctionEmbedding Embedding, decimal Similarity)>();

        foreach (var embedding in _embeddings.Values)
        {
            if (embedding.Vector.Length != query.Length)
            {
                continue; // Skip incompatible dimensions
            }

            var similarity = CosineSimilarity(query, embedding.Vector);
            similarities.Add((embedding, similarity));
        }

        // Sort by similarity descending and take top K. Scores may be negative
        // (cosine similarity ranges over [-1, 1]); no threshold is applied here.
        var results = similarities
            .OrderByDescending(s => s.Similarity)
            .Take(topK)
            .ToImmutableArray();

        return Task.FromResult(results);
    }

    /// <inheritdoc />
    public void Clear()
    {
        _embeddings.Clear();
    }

    /// <summary>
    /// Get an embedding by function ID.
    /// </summary>
    /// <param name="functionId">Function identifier.</param>
    /// <returns>Embedding if found, null otherwise.</returns>
    public FunctionEmbedding? Get(string functionId)
    {
        return _embeddings.TryGetValue(functionId, out var embedding) ? embedding : null;
    }

    /// <summary>
    /// Remove an embedding by function ID.
    /// </summary>
    /// <param name="functionId">Function identifier.</param>
    /// <returns>True if removed, false if not found.</returns>
    public bool Remove(string functionId)
    {
        return _embeddings.TryRemove(functionId, out _);
    }

    /// <summary>
    /// Get all embeddings.
    /// </summary>
    /// <returns>All stored embeddings (moment-in-time snapshot of the dictionary's values).</returns>
    public IEnumerable<FunctionEmbedding> GetAll()
    {
        return _embeddings.Values;
    }

    /// <summary>
    /// Cosine similarity of two equal-length vectors. Accumulates in double for
    /// precision, returns 0 for zero-norm inputs, and clamps to [-1, 1] before
    /// converting to decimal.
    /// </summary>
    private static decimal CosineSimilarity(float[] a, float[] b)
    {
        var dotProduct = 0.0;
        var normA = 0.0;
        var normB = 0.0;

        for (var i = 0; i < a.Length; i++)
        {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0 || normB == 0)
        {
            return 0;
        }

        var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
        return (decimal)Math.Clamp(similarity, -1.0, 1.0);
    }
}
|
||||
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML;
|
||||
|
||||
/// <summary>
/// Extension methods for registering ML services.
/// </summary>
public static class MlServiceCollectionExtensions
{
    /// <summary>
    /// Adds ML embedding services to the service collection using the built-in
    /// tokenizer vocabulary.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlServices(this IServiceCollection services)
    {
        ArgumentNullException.ThrowIfNull(services);

        // Register tokenizer (default built-in vocabulary)
        services.AddSingleton<ITokenizer, BinaryCodeTokenizer>();

        return AddCoreMlServices(services);
    }

    /// <summary>
    /// Adds ML services with custom options.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="configureOptions">Action to configure ML options.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlServices(
        this IServiceCollection services,
        Action<MlOptions> configureOptions)
    {
        ArgumentNullException.ThrowIfNull(services);
        ArgumentNullException.ThrowIfNull(configureOptions);

        // NOTE(review): Configure() mutates an MlOptions instance after construction;
        // this only works if MlOptions properties are settable — confirm MlOptions
        // does not use init-only setters.
        services.Configure(configureOptions);
        return services.AddMlServices();
    }

    /// <summary>
    /// Adds ML services with a custom tokenizer.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="vocabularyPath">Path to vocabulary file.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlServicesWithVocabulary(
        this IServiceCollection services,
        string vocabularyPath)
    {
        ArgumentNullException.ThrowIfNull(services);
        ArgumentException.ThrowIfNullOrEmpty(vocabularyPath);

        // Register tokenizer with vocabulary
        services.AddSingleton<ITokenizer>(sp => new BinaryCodeTokenizer(vocabularyPath));

        return AddCoreMlServices(services);
    }

    // Registrations shared by every AddMlServices* overload (previously duplicated
    // verbatim in two public methods).
    private static IServiceCollection AddCoreMlServices(IServiceCollection services)
    {
        // Register embedding index
        services.AddSingleton<IEmbeddingIndex, InMemoryEmbeddingIndex>();

        // Register embedding service
        services.AddScoped<IEmbeddingService, OnnxInferenceEngine>();

        return services;
    }
}
|
||||
259
src/BinaryIndex/__Libraries/StellaOps.BinaryIndex.ML/Models.cs
Normal file
259
src/BinaryIndex/__Libraries/StellaOps.BinaryIndex.ML/Models.cs
Normal file
@@ -0,0 +1,259 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using System.Collections.Immutable;
|
||||
using StellaOps.BinaryIndex.Semantic;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML;
|
||||
|
||||
/// <summary>
/// Input for generating function embeddings.
/// </summary>
/// <remarks>
/// NOTE(review): presumably at least the field matching <see cref="PreferredInput"/>
/// is non-null — confirm what implementations do when it is missing.
/// </remarks>
/// <param name="DecompiledCode">Decompiled C-like code if available.</param>
/// <param name="SemanticGraph">Semantic graph from IR analysis if available.</param>
/// <param name="InstructionBytes">Raw instruction bytes if available.</param>
/// <param name="PreferredInput">Which input type to prefer for embedding generation.</param>
public sealed record EmbeddingInput(
    string? DecompiledCode,
    KeySemanticsGraph? SemanticGraph,
    byte[]? InstructionBytes,
    EmbeddingInputType PreferredInput);
|
||||
|
||||
/// <summary>
/// Type of input for embedding generation.
/// </summary>
/// <remarks>
/// Members have no explicit values; do not reorder them in case ordinals are
/// persisted or serialized anywhere.
/// </remarks>
public enum EmbeddingInputType
{
    /// <summary>Use decompiled C-like code.</summary>
    DecompiledCode,

    /// <summary>Use semantic graph from IR analysis.</summary>
    SemanticGraph,

    /// <summary>Use raw instruction bytes.</summary>
    Instructions
}
|
||||
|
||||
/// <summary>
/// A function embedding vector.
/// </summary>
/// <remarks>
/// Because <see cref="Vector"/> is an array, record value-equality compares it by
/// reference, not element-wise; two embeddings with identical numbers but distinct
/// arrays are not Equals. The array is also mutable — callers must not modify it.
/// </remarks>
/// <param name="FunctionId">Identifier for the function.</param>
/// <param name="FunctionName">Name of the function.</param>
/// <param name="Vector">Embedding vector (typically 768 dimensions).</param>
/// <param name="Model">Model used to generate the embedding.</param>
/// <param name="InputType">Type of input used.</param>
/// <param name="GeneratedAt">When the embedding was generated.</param>
public sealed record FunctionEmbedding(
    string FunctionId,
    string FunctionName,
    float[] Vector,
    EmbeddingModel Model,
    EmbeddingInputType InputType,
    DateTimeOffset GeneratedAt);
|
||||
|
||||
/// <summary>
/// Available embedding models.
/// </summary>
/// <remarks>
/// Members have no explicit values; do not reorder them in case ordinals are
/// persisted or serialized anywhere.
/// </remarks>
public enum EmbeddingModel
{
    /// <summary>Fine-tuned CodeBERT for binary code analysis.</summary>
    CodeBertBinary,

    /// <summary>Graph neural network for CFG/call graph analysis.</summary>
    GraphSageFunction,

    /// <summary>Contrastive learning model for function similarity.</summary>
    ContrastiveFunction
}
|
||||
|
||||
/// <summary>
/// Similarity metrics for comparing embeddings.
/// </summary>
public enum SimilarityMetric
{
    /// <summary>Cosine similarity (angle between vectors).</summary>
    Cosine,

    /// <summary>Euclidean distance (inverted to similarity).</summary>
    Euclidean,

    /// <summary>Manhattan distance (inverted to similarity).</summary>
    Manhattan,

    /// <summary>Learned metric from model.</summary>
    LearnedMetric
}
|
||||
|
||||
/// <summary>
/// A match from embedding similarity search.
/// </summary>
/// <param name="FunctionId">Matched function identifier.</param>
/// <param name="FunctionName">Matched function name.</param>
/// <param name="Similarity">Similarity score (0.0 to 1.0).</param>
/// <param name="LibraryName">Library containing the function; null when unknown.</param>
/// <param name="LibraryVersion">Version of the library; null when unknown.</param>
public sealed record EmbeddingMatch(
    string FunctionId,
    string FunctionName,
    decimal Similarity,
    string? LibraryName,
    string? LibraryVersion);
|
||||
|
||||
/// <summary>
/// Options for embedding generation. Immutable after construction (init-only).
/// </summary>
public sealed record EmbeddingOptions
{
    /// <summary>Maximum sequence length for tokenization. Default 512 (CodeBERT-style limit).</summary>
    public int MaxSequenceLength { get; init; } = 512;

    /// <summary>Whether to normalize the embedding vector. Default true.</summary>
    public bool NormalizeVector { get; init; } = true;

    /// <summary>Batch size for batch inference. Default 32.</summary>
    public int BatchSize { get; init; } = 32;
}
|
||||
|
||||
/// <summary>
/// Training pair for model training.
/// </summary>
/// <param name="FunctionA">First function input.</param>
/// <param name="FunctionB">Second function input.</param>
/// <param name="IsSimilar">Ground truth: are these the same function?</param>
/// <param name="SimilarityScore">Optional fine-grained similarity score; null when only the binary label is known.</param>
public sealed record TrainingPair(
    EmbeddingInput FunctionA,
    EmbeddingInput FunctionB,
    bool IsSimilar,
    decimal? SimilarityScore);
|
||||
|
||||
/// <summary>
/// Options for model training. Immutable after construction (init-only).
/// </summary>
public sealed record TrainingOptions
{
    /// <summary>Model architecture to train. Default: fine-tuned CodeBERT.</summary>
    public EmbeddingModel Model { get; init; } = EmbeddingModel.CodeBertBinary;

    /// <summary>Embedding vector dimension. Default 768 (BERT-base width).</summary>
    public int EmbeddingDimension { get; init; } = 768;

    /// <summary>Training batch size. Default 32.</summary>
    public int BatchSize { get; init; } = 32;

    /// <summary>Number of training epochs. Default 10.</summary>
    public int Epochs { get; init; } = 10;

    /// <summary>Learning rate. Default 1e-5.</summary>
    public double LearningRate { get; init; } = 1e-5;

    /// <summary>Margin for contrastive loss. Default 0.5.</summary>
    public double MarginLoss { get; init; } = 0.5;

    /// <summary>Path to pretrained model weights; null to train from scratch.</summary>
    public string? PretrainedModelPath { get; init; }

    /// <summary>Path to save checkpoints; null to skip checkpointing.</summary>
    public string? CheckpointPath { get; init; }
}
|
||||
|
||||
/// <summary>
/// Progress update during training, reported via <see cref="IProgress{T}"/>.
/// </summary>
/// <param name="Epoch">Current epoch.</param>
/// <param name="TotalEpochs">Total epochs.</param>
/// <param name="Batch">Current batch.</param>
/// <param name="TotalBatches">Total batches.</param>
/// <param name="Loss">Current loss value.</param>
/// <param name="Accuracy">Current accuracy.</param>
public sealed record TrainingProgress(
    int Epoch,
    int TotalEpochs,
    int Batch,
    int TotalBatches,
    double Loss,
    double Accuracy);
|
||||
|
||||
/// <summary>
/// Result of model training.
/// </summary>
/// <param name="ModelPath">Path to saved model.</param>
/// <param name="TotalPairs">Number of training pairs used.</param>
/// <param name="Epochs">Number of epochs completed.</param>
/// <param name="FinalLoss">Final loss value.</param>
/// <param name="ValidationAccuracy">Validation accuracy.</param>
/// <param name="TrainingTime">Total training time.</param>
public sealed record TrainingResult(
    string ModelPath,
    int TotalPairs,
    int Epochs,
    double FinalLoss,
    double ValidationAccuracy,
    TimeSpan TrainingTime);
|
||||
|
||||
/// <summary>
/// Result of model evaluation.
/// </summary>
/// <param name="Accuracy">Overall accuracy.</param>
/// <param name="Precision">Precision (true positives / predicted positives).</param>
/// <param name="Recall">Recall (true positives / actual positives).</param>
/// <param name="F1Score">F1 score (harmonic mean of precision and recall).</param>
/// <param name="AucRoc">Area under ROC curve.</param>
/// <param name="ConfusionMatrix">Confusion matrix entries.</param>
public sealed record EvaluationResult(
    double Accuracy,
    double Precision,
    double Recall,
    double F1Score,
    double AucRoc,
    ImmutableArray<ConfusionEntry> ConfusionMatrix);
|
||||
|
||||
/// <summary>
/// Entry in a confusion matrix (one predicted/actual label combination).
/// </summary>
/// <param name="Predicted">Predicted label.</param>
/// <param name="Actual">Actual label.</param>
/// <param name="Count">Number of occurrences.</param>
public sealed record ConfusionEntry(
    string Predicted,
    string Actual,
    int Count);
|
||||
|
||||
/// <summary>
/// Model export formats.
/// </summary>
public enum ModelExportFormat
{
    /// <summary>ONNX format for cross-platform inference.</summary>
    Onnx,

    /// <summary>PyTorch format.</summary>
    PyTorch,

    /// <summary>TensorFlow SavedModel format.</summary>
    TensorFlow
}
|
||||
|
||||
/// <summary>
/// Options for ML service.
/// </summary>
/// <remarks>
/// Properties use settable accessors rather than init-only so the instance can be
/// mutated by the options pattern (e.g. <c>services.Configure&lt;MlOptions&gt;(o =&gt; o.Device = "cuda")</c>);
/// an <c>Action&lt;MlOptions&gt;</c> callback cannot assign init-only properties,
/// which made the previous init-only shape unusable with IServiceCollection.Configure.
/// Object initializers and <c>with</c> expressions continue to work unchanged.
/// </remarks>
public sealed record MlOptions
{
    /// <summary>Path to ONNX model file; null means no model is loaded.</summary>
    public string? ModelPath { get; set; }

    /// <summary>Path to tokenizer vocabulary; null means a built-in vocabulary is used.</summary>
    public string? VocabularyPath { get; set; }

    /// <summary>Device to use for inference (cpu, cuda). Default "cpu".</summary>
    public string Device { get; set; } = "cpu";

    /// <summary>Number of threads for inference. Default 4.</summary>
    public int NumThreads { get; set; } = 4;

    /// <summary>Whether to use GPU if available. Default false.</summary>
    public bool UseGpu { get; set; } = false;

    /// <summary>Maximum batch size for inference. Default 32.</summary>
    public int MaxBatchSize { get; set; } = 32;
}
|
||||
@@ -0,0 +1,381 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using System.Collections.Immutable;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.ML.OnnxRuntime;
|
||||
using Microsoft.ML.OnnxRuntime.Tensors;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML;
|
||||
|
||||
/// <summary>
/// ONNX Runtime-based embedding inference engine. Produces function embeddings from
/// decompiled code, semantic graphs, or raw instruction bytes. When no ONNX model is
/// configured (or the file is missing), deterministic hash-based pseudo-embeddings
/// are generated instead — intended for testing only.
/// </summary>
public sealed class OnnxInferenceEngine : IEmbeddingService, IAsyncDisposable
{
    // Fallback vector width; presumably matches the model's hidden size (CodeBERT
    // convention is 768) — TODO confirm against the deployed model's output shape.
    private const int FallbackEmbeddingDimension = 768;

    // Cached compiled pattern: ExtractFunctionName runs once per embedding request,
    // so recompiling the regex on every call would be wasted work.
    private static readonly System.Text.RegularExpressions.Regex FunctionNamePattern = new(
        @"\b(\w+)\s*\(",
        System.Text.RegularExpressions.RegexOptions.Compiled);

    private readonly InferenceSession? _session;
    private readonly ITokenizer _tokenizer;
    private readonly IEmbeddingIndex? _index;
    private readonly MlOptions _options;
    private readonly ILogger<OnnxInferenceEngine> _logger;
    private readonly TimeProvider _timeProvider;
    private bool _disposed;

    /// <summary>
    /// Creates the engine, loading the ONNX model from <see cref="MlOptions.ModelPath"/>
    /// when the file exists; otherwise the engine runs in fallback mode.
    /// </summary>
    /// <param name="tokenizer">Tokenizer that converts input text into model token IDs.</param>
    /// <param name="options">ML configuration (model path, thread counts, batch size).</param>
    /// <param name="logger">Diagnostic logger.</param>
    /// <param name="timeProvider">Clock source for embedding timestamps (injected for testability).</param>
    /// <param name="index">Optional vector index backing <see cref="FindSimilarAsync"/>.</param>
    public OnnxInferenceEngine(
        ITokenizer tokenizer,
        IOptions<MlOptions> options,
        ILogger<OnnxInferenceEngine> logger,
        TimeProvider timeProvider,
        IEmbeddingIndex? index = null)
    {
        ArgumentNullException.ThrowIfNull(tokenizer);
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(logger);
        ArgumentNullException.ThrowIfNull(timeProvider);

        _tokenizer = tokenizer;
        _options = options.Value;
        _logger = logger;
        _timeProvider = timeProvider;
        _index = index;

        if (!string.IsNullOrEmpty(_options.ModelPath) && File.Exists(_options.ModelPath))
        {
            var sessionOptions = new SessionOptions
            {
                GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL,
                ExecutionMode = ExecutionMode.ORT_PARALLEL,
                InterOpNumThreads = _options.NumThreads,
                IntraOpNumThreads = _options.NumThreads
            };

            _session = new InferenceSession(_options.ModelPath, sessionOptions);
            _logger.LogInformation(
                "Loaded ONNX model from {Path}",
                _options.ModelPath);
        }
        else
        {
            _logger.LogWarning(
                "No ONNX model found at {Path}, using fallback embedding",
                _options.ModelPath);
        }
    }

    /// <inheritdoc />
    public async Task<FunctionEmbedding> GenerateEmbeddingAsync(
        EmbeddingInput input,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(input);
        ObjectDisposedException.ThrowIf(_disposed, this);
        ct.ThrowIfCancellationRequested();

        options ??= new EmbeddingOptions();

        var text = GetInputText(input);
        var functionId = ComputeFunctionId(text);

        // Use the real model when loaded; otherwise a deterministic pseudo-embedding
        // so callers/tests can run without model artifacts.
        var vector = _session is not null
            ? await RunInferenceAsync(text, options, ct).ConfigureAwait(false)
            : GenerateFallbackEmbedding(text, FallbackEmbeddingDimension);

        if (options.NormalizeVector)
        {
            NormalizeVector(vector);
        }

        return new FunctionEmbedding(
            functionId,
            ExtractFunctionName(text),
            vector,
            EmbeddingModel.CodeBertBinary,
            input.PreferredInput,
            _timeProvider.GetUtcNow());
    }

    /// <inheritdoc />
    public async Task<ImmutableArray<FunctionEmbedding>> GenerateBatchAsync(
        IEnumerable<EmbeddingInput> inputs,
        EmbeddingOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(inputs);
        ObjectDisposedException.ThrowIf(_disposed, this);

        options ??= new EmbeddingOptions();
        var results = new List<FunctionEmbedding>();

        // Accumulate inputs and flush a batch whenever it reaches options.BatchSize.
        var batch = new List<EmbeddingInput>();
        foreach (var input in inputs)
        {
            ct.ThrowIfCancellationRequested();
            batch.Add(input);

            if (batch.Count >= options.BatchSize)
            {
                var batchResults = await ProcessBatchAsync(batch, options, ct).ConfigureAwait(false);
                results.AddRange(batchResults);
                batch.Clear();
            }
        }

        // Flush any trailing partial batch.
        if (batch.Count > 0)
        {
            var batchResults = await ProcessBatchAsync(batch, options, ct).ConfigureAwait(false);
            results.AddRange(batchResults);
        }

        return [.. results];
    }

    /// <inheritdoc />
    public decimal ComputeSimilarity(
        FunctionEmbedding a,
        FunctionEmbedding b,
        SimilarityMetric metric = SimilarityMetric.Cosine)
    {
        ArgumentNullException.ThrowIfNull(a);
        ArgumentNullException.ThrowIfNull(b);

        if (a.Vector.Length != b.Vector.Length)
        {
            throw new ArgumentException("Embedding vectors must have same dimension", nameof(b));
        }

        return metric switch
        {
            SimilarityMetric.Cosine => CosineSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Euclidean => EuclideanSimilarity(a.Vector, b.Vector),
            SimilarityMetric.Manhattan => ManhattanSimilarity(a.Vector, b.Vector),
            // No learned metric implemented yet; cosine is the stand-in.
            SimilarityMetric.LearnedMetric => CosineSimilarity(a.Vector, b.Vector),
            _ => throw new ArgumentOutOfRangeException(nameof(metric), metric, "Unknown similarity metric")
        };
    }

    /// <inheritdoc />
    public async Task<ImmutableArray<EmbeddingMatch>> FindSimilarAsync(
        FunctionEmbedding query,
        int topK = 10,
        decimal minSimilarity = 0.7m,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(query);
        ObjectDisposedException.ThrowIf(_disposed, this);

        if (_index is null)
        {
            _logger.LogWarning("No embedding index configured, cannot search");
            return [];
        }

        var results = await _index.SearchAsync(query.Vector, topK, ct).ConfigureAwait(false);

        return results
            .Where(r => r.Similarity >= minSimilarity)
            .Select(r => new EmbeddingMatch(
                r.Embedding.FunctionId,
                r.Embedding.FunctionName,
                r.Similarity,
                null, // Library info would come from metadata
                null))
            .ToImmutableArray();
    }

    /// <summary>
    /// Tokenizes <paramref name="text"/> and runs the ONNX session, returning the
    /// raw output tensor flattened to a float array.
    /// </summary>
    private async Task<float[]> RunInferenceAsync(
        string text,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        if (_session is null)
        {
            throw new InvalidOperationException("ONNX session not initialized");
        }

        var (inputIds, attentionMask) = _tokenizer.TokenizeWithMask(text, options.MaxSequenceLength);

        // Batch dimension of 1: a single sequence per inference call.
        var inputIdsTensor = new DenseTensor<long>(inputIds, [1, inputIds.Length]);
        var attentionMaskTensor = new DenseTensor<long>(attentionMask, [1, attentionMask.Length]);

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
        };

        // InferenceSession.Run is synchronous/CPU-bound; offload so callers' threads
        // are not blocked.
        using var results = await Task.Run(() => _session.Run(inputs), ct).ConfigureAwait(false);

        var outputTensor = results.First().AsTensor<float>();
        return outputTensor.ToArray();
    }

    /// <summary>Generates embeddings for a batch of inputs.</summary>
    private async Task<IEnumerable<FunctionEmbedding>> ProcessBatchAsync(
        List<EmbeddingInput> batch,
        EmbeddingOptions options,
        CancellationToken ct)
    {
        // For now, process sequentially
        // TODO: Implement true batch inference with batched tensors
        var results = new List<FunctionEmbedding>(batch.Count);
        foreach (var input in batch)
        {
            var embedding = await GenerateEmbeddingAsync(input, options, ct).ConfigureAwait(false);
            results.Add(embedding);
        }
        return results;
    }

    /// <summary>
    /// Resolves the textual representation of an input according to its preferred type.
    /// </summary>
    /// <exception cref="ArgumentException">The field required by the preferred type is null.</exception>
    private static string GetInputText(EmbeddingInput input)
    {
        return input.PreferredInput switch
        {
            EmbeddingInputType.DecompiledCode => input.DecompiledCode
                ?? throw new ArgumentException("DecompiledCode required"),
            EmbeddingInputType.SemanticGraph => SerializeGraph(input.SemanticGraph
                ?? throw new ArgumentException("SemanticGraph required")),
            EmbeddingInputType.Instructions => SerializeInstructions(input.InstructionBytes
                ?? throw new ArgumentException("InstructionBytes required")),
            _ => throw new ArgumentOutOfRangeException(nameof(input), input.PreferredInput, "Unsupported input type")
        };
    }

    /// <summary>Flattens a semantic graph into a line-oriented text form for tokenization.</summary>
    private static string SerializeGraph(Semantic.KeySemanticsGraph graph)
    {
        // Convert graph to textual representation for tokenization
        var sb = new System.Text.StringBuilder();
        sb.AppendLine($"// Graph: {graph.Nodes.Length} nodes");

        foreach (var node in graph.Nodes)
        {
            sb.AppendLine($"node {node.Id}: {node.Operation}");
        }

        foreach (var edge in graph.Edges)
        {
            sb.AppendLine($"edge {edge.SourceId} -> {edge.TargetId}");
        }

        return sb.ToString();
    }

    /// <summary>Renders raw instruction bytes as an uppercase hex string.</summary>
    private static string SerializeInstructions(byte[] bytes)
    {
        return Convert.ToHexString(bytes);
    }

    /// <summary>
    /// Derives a stable 16-hex-character function ID from the SHA-256 of the input text.
    /// </summary>
    private static string ComputeFunctionId(string text)
    {
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));
        return Convert.ToHexString(hash)[..16];
    }

    /// <summary>
    /// Best-effort extraction of a function name: the first identifier followed by
    /// an opening parenthesis, or "unknown" when none matches.
    /// </summary>
    private static string ExtractFunctionName(string text)
    {
        var match = FunctionNamePattern.Match(text);
        return match.Success ? match.Groups[1].Value : "unknown";
    }

    /// <summary>
    /// Generates a deterministic pseudo-embedding seeded from the SHA-256 of the text.
    /// Values lie in [-1, 1). Testing aid only — carries no semantic signal.
    /// </summary>
    private static float[] GenerateFallbackEmbedding(string text, int dimension)
    {
        var hash = System.Security.Cryptography.SHA256.HashData(
            System.Text.Encoding.UTF8.GetBytes(text));

        // Seeding from the hash makes the output reproducible for identical text.
        var random = new Random(BitConverter.ToInt32(hash, 0));
        var vector = new float[dimension];

        for (var i = 0; i < dimension; i++)
        {
            vector[i] = (float)(random.NextDouble() * 2 - 1);
        }

        return vector;
    }

    /// <summary>Scales the vector in place to unit L2 norm; zero vectors are left unchanged.</summary>
    private static void NormalizeVector(float[] vector)
    {
        var norm = 0.0;
        for (var i = 0; i < vector.Length; i++)
        {
            norm += vector[i] * vector[i];
        }

        norm = Math.Sqrt(norm);
        if (norm > 0)
        {
            for (var i = 0; i < vector.Length; i++)
            {
                vector[i] /= (float)norm;
            }
        }
    }

    /// <summary>Cosine similarity in [-1, 1]; returns 0 when either vector is all zeros.</summary>
    private static decimal CosineSimilarity(float[] a, float[] b)
    {
        var dotProduct = 0.0;
        var normA = 0.0;
        var normB = 0.0;

        for (var i = 0; i < a.Length; i++)
        {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0 || normB == 0)
        {
            return 0;
        }

        var similarity = dotProduct / (Math.Sqrt(normA) * Math.Sqrt(normB));
        // Clamp guards against floating-point drift just outside [-1, 1].
        return (decimal)Math.Clamp(similarity, -1.0, 1.0);
    }

    /// <summary>Euclidean distance mapped to (0, 1] via 1 / (1 + d); 1 means identical.</summary>
    private static decimal EuclideanSimilarity(float[] a, float[] b)
    {
        var sumSquares = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            var diff = a[i] - b[i];
            sumSquares += diff * diff;
        }

        var distance = Math.Sqrt(sumSquares);
        // Convert distance to similarity (0 = identical, larger = more different)
        return (decimal)(1.0 / (1.0 + distance));
    }

    /// <summary>Manhattan (L1) distance mapped to (0, 1] via 1 / (1 + d); 1 means identical.</summary>
    private static decimal ManhattanSimilarity(float[] a, float[] b)
    {
        var sum = 0.0;
        for (var i = 0; i < a.Length; i++)
        {
            sum += Math.Abs(a[i] - b[i]);
        }

        // Convert distance to similarity
        return (decimal)(1.0 / (1.0 + sum));
    }

    /// <inheritdoc />
    public ValueTask DisposeAsync()
    {
        if (!_disposed)
        {
            _session?.Dispose();
            _disposed = true;
        }

        // Nothing asynchronous to wait for; complete synchronously.
        return ValueTask.CompletedTask;
    }
}
|
||||
@@ -0,0 +1,23 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net10.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
    <Description>Machine learning-based function similarity using embeddings and ONNX inference for BinaryIndex.</Description>
  </PropertyGroup>

  <ItemGroup>
    <ProjectReference Include="..\StellaOps.BinaryIndex.Decompiler\StellaOps.BinaryIndex.Decompiler.csproj" />
    <ProjectReference Include="..\StellaOps.BinaryIndex.Semantic\StellaOps.BinaryIndex.Semantic.csproj" />
  </ItemGroup>

  <ItemGroup>
    <PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
    <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
    <PackageReference Include="Microsoft.Extensions.Options" />
    <PackageReference Include="Microsoft.ML.OnnxRuntime" />
  </ItemGroup>

</Project>
|
||||
Reference in New Issue
Block a user