Files
git.stella-ops.org/src/BinaryIndex/__Libraries/StellaOps.BinaryIndex.ML/Models.cs
2026-02-01 21:37:40 +02:00

261 lines
8.1 KiB
C#

// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.
using StellaOps.BinaryIndex.Semantic;
using System.Collections.Immutable;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Bundles the representations of a function that embedding generation can consume.
/// </summary>
/// <param name="DecompiledCode">Decompiled C-like code, or <c>null</c> when not available.</param>
/// <param name="SemanticGraph">Semantic graph produced by IR analysis, or <c>null</c> when not available.</param>
/// <param name="InstructionBytes">Raw instruction bytes, or <c>null</c> when not available.</param>
/// <param name="PreferredInput">The input type that embedding generation should prefer.</param>
/// <remarks>
/// <paramref name="InstructionBytes"/> is an array, so record equality compares it by reference
/// rather than by content.
/// </remarks>
public sealed record EmbeddingInput(
    string? DecompiledCode,
    KeySemanticsGraph? SemanticGraph,
    byte[]? InstructionBytes,
    EmbeddingInputType PreferredInput);
/// <summary>
/// Identifies which function representation is used as input for embedding generation.
/// </summary>
public enum EmbeddingInputType
{
    /// <summary>Decompiled C-like code.</summary>
    DecompiledCode = 0,

    /// <summary>Semantic graph obtained from IR analysis.</summary>
    SemanticGraph = 1,

    /// <summary>Raw instruction bytes.</summary>
    Instructions = 2,
}
/// <summary>
/// An embedding vector computed for a single function, together with its provenance.
/// </summary>
/// <param name="FunctionId">Identifier of the function.</param>
/// <param name="FunctionName">Name of the function.</param>
/// <param name="Vector">The embedding vector (typically 768 dimensions).</param>
/// <param name="Model">The model that generated the embedding.</param>
/// <param name="InputType">The type of input the embedding was generated from.</param>
/// <param name="GeneratedAt">Timestamp at which the embedding was generated.</param>
/// <remarks>
/// <paramref name="Vector"/> is an array, so record equality compares it by reference
/// rather than element by element.
/// </remarks>
public sealed record FunctionEmbedding(
    string FunctionId,
    string FunctionName,
    float[] Vector,
    EmbeddingModel Model,
    EmbeddingInputType InputType,
    DateTimeOffset GeneratedAt);
/// <summary>
/// The embedding models that can be selected.
/// </summary>
public enum EmbeddingModel
{
    /// <summary>CodeBERT fine-tuned for binary code analysis.</summary>
    CodeBertBinary = 0,

    /// <summary>Graph neural network for CFG/call graph analysis.</summary>
    GraphSageFunction = 1,

    /// <summary>Contrastive learning model for function similarity.</summary>
    ContrastiveFunction = 2,
}
/// <summary>
/// Metrics that can be used to compare two embeddings.
/// </summary>
public enum SimilarityMetric
{
    /// <summary>Cosine similarity (the angle between the vectors).</summary>
    Cosine = 0,

    /// <summary>Euclidean distance, inverted to a similarity.</summary>
    Euclidean = 1,

    /// <summary>Manhattan distance, inverted to a similarity.</summary>
    Manhattan = 2,

    /// <summary>A metric learned by the model.</summary>
    LearnedMetric = 3,
}
/// <summary>
/// One hit returned by an embedding similarity search.
/// </summary>
/// <param name="FunctionId">Identifier of the matched function.</param>
/// <param name="FunctionName">Name of the matched function.</param>
/// <param name="Similarity">Similarity score (0.0 to 1.0).</param>
/// <param name="LibraryName">Library that contains the function; may be <c>null</c>.</param>
/// <param name="LibraryVersion">Version of that library; may be <c>null</c>.</param>
public sealed record EmbeddingMatch(
    string FunctionId,
    string FunctionName,
    decimal Similarity,
    string? LibraryName,
    string? LibraryVersion);
/// <summary>
/// Settings that control how embeddings are generated.
/// </summary>
public sealed record EmbeddingOptions
{
    /// <summary>Maximum sequence length used for tokenization. Defaults to 512.</summary>
    public int MaxSequenceLength { get; init; } = 512;

    /// <summary>Whether the embedding vector is normalized. Defaults to <c>true</c>.</summary>
    public bool NormalizeVector { get; init; } = true;

    /// <summary>Batch size used for batch inference. Defaults to 32.</summary>
    public int BatchSize { get; init; } = 32;
}
/// <summary>
/// A labelled pair of function inputs used for model training.
/// </summary>
/// <param name="FunctionA">Input for the first function.</param>
/// <param name="FunctionB">Input for the second function.</param>
/// <param name="IsSimilar">Ground-truth label: whether the two inputs are the same function.</param>
/// <param name="SimilarityScore">Optional fine-grained similarity score; <c>null</c> when not provided.</param>
public sealed record TrainingPair(
    EmbeddingInput FunctionA,
    EmbeddingInput FunctionB,
    bool IsSimilar,
    decimal? SimilarityScore);
/// <summary>
/// Hyperparameters and file locations that configure model training.
/// </summary>
public sealed record TrainingOptions
{
    /// <summary>Model architecture to train. Defaults to <see cref="EmbeddingModel.CodeBertBinary"/>.</summary>
    public EmbeddingModel Model { get; init; } = EmbeddingModel.CodeBertBinary;

    /// <summary>Dimension of the embedding vector. Defaults to 768.</summary>
    public int EmbeddingDimension { get; init; } = 768;

    /// <summary>Batch size used during training. Defaults to 32.</summary>
    public int BatchSize { get; init; } = 32;

    /// <summary>Number of training epochs. Defaults to 10.</summary>
    public int Epochs { get; init; } = 10;

    /// <summary>Learning rate. Defaults to 1e-5.</summary>
    public double LearningRate { get; init; } = 1e-5;

    /// <summary>Margin for the contrastive loss. Defaults to 0.5.</summary>
    public double MarginLoss { get; init; } = 0.5;

    /// <summary>Path to pretrained model weights; <c>null</c> unless set.</summary>
    public string? PretrainedModelPath { get; init; }

    /// <summary>Path where checkpoints are saved; <c>null</c> unless set.</summary>
    public string? CheckpointPath { get; init; }
}
/// <summary>
/// A progress snapshot reported while training is running.
/// </summary>
/// <param name="Epoch">The current epoch.</param>
/// <param name="TotalEpochs">The total number of epochs.</param>
/// <param name="Batch">The current batch.</param>
/// <param name="TotalBatches">The total number of batches.</param>
/// <param name="Loss">The current loss value.</param>
/// <param name="Accuracy">The current accuracy.</param>
public sealed record TrainingProgress(
    int Epoch,
    int TotalEpochs,
    int Batch,
    int TotalBatches,
    double Loss,
    double Accuracy);
/// <summary>
/// Summary of a model training run.
/// </summary>
/// <param name="ModelPath">Path to the saved model.</param>
/// <param name="TotalPairs">Number of training pairs that were used.</param>
/// <param name="Epochs">Number of epochs that were completed.</param>
/// <param name="FinalLoss">The final loss value.</param>
/// <param name="ValidationAccuracy">The validation accuracy.</param>
/// <param name="TrainingTime">Total time spent training.</param>
public sealed record TrainingResult(
    string ModelPath,
    int TotalPairs,
    int Epochs,
    double FinalLoss,
    double ValidationAccuracy,
    TimeSpan TrainingTime);
/// <summary>
/// Metrics produced by evaluating a model.
/// </summary>
/// <param name="Accuracy">Overall accuracy.</param>
/// <param name="Precision">Precision: true positives divided by predicted positives.</param>
/// <param name="Recall">Recall: true positives divided by actual positives.</param>
/// <param name="F1Score">F1 score: the harmonic mean of precision and recall.</param>
/// <param name="AucRoc">Area under the ROC curve.</param>
/// <param name="ConfusionMatrix">The entries of the confusion matrix.</param>
public sealed record EvaluationResult(
    double Accuracy,
    double Precision,
    double Recall,
    double F1Score,
    double AucRoc,
    ImmutableArray<ConfusionEntry> ConfusionMatrix);
/// <summary>
/// A single cell of a confusion matrix.
/// </summary>
/// <param name="Predicted">The predicted label.</param>
/// <param name="Actual">The actual label.</param>
/// <param name="Count">How many times this predicted/actual combination occurred.</param>
public sealed record ConfusionEntry(string Predicted, string Actual, int Count);
/// <summary>
/// Formats in which a model can be exported.
/// </summary>
public enum ModelExportFormat
{
    /// <summary>ONNX format, for cross-platform inference.</summary>
    Onnx = 0,

    /// <summary>PyTorch format.</summary>
    PyTorch = 1,

    /// <summary>TensorFlow SavedModel format.</summary>
    TensorFlow = 2,
}
/// <summary>
/// Configuration for the ML service.
/// </summary>
public sealed record MlOptions
{
    /// <summary>Path to the ONNX model file; <c>null</c> unless set.</summary>
    public string? ModelPath { get; init; }

    /// <summary>Path to the tokenizer vocabulary; <c>null</c> unless set.</summary>
    public string? VocabularyPath { get; init; }

    /// <summary>Inference device (cpu, cuda). Defaults to <c>"cpu"</c>.</summary>
    public string Device { get; init; } = "cpu";

    /// <summary>Number of threads used for inference. Defaults to 4.</summary>
    public int NumThreads { get; init; } = 4;

    /// <summary>Whether to use the GPU if one is available. Defaults to <c>false</c>.</summary>
    public bool UseGpu { get; init; }

    /// <summary>Maximum batch size for inference. Defaults to 32.</summary>
    public int MaxBatchSize { get; init; } = 32;
}