// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.

using StellaOps.BinaryIndex.Semantic;
using System.Collections.Immutable;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Candidate inputs from which a function embedding can be generated.
/// </summary>
/// <remarks>
/// Every representation is nullable, so a caller may supply only some of them;
/// <paramref name="PreferredInput"/> names the one to favor.
/// <paramref name="InstructionBytes"/> is an array, so the compiler-generated
/// record equality compares it by reference rather than by content.
/// </remarks>
/// <param name="DecompiledCode">Decompiled C-like code, or <c>null</c> if not available.</param>
/// <param name="SemanticGraph">Semantic graph from IR analysis, or <c>null</c> if not available.</param>
/// <param name="InstructionBytes">Raw instruction bytes, or <c>null</c> if not available.</param>
/// <param name="PreferredInput">Which input type to prefer for embedding generation.</param>
public sealed record EmbeddingInput(
    string? DecompiledCode,
    KeySemanticsGraph? SemanticGraph,
    byte[]? InstructionBytes,
    EmbeddingInputType PreferredInput);
/// <summary>
/// Kind of input an embedding is generated from.
/// </summary>
public enum EmbeddingInputType
{
    /// <summary>Use decompiled C-like code.</summary>
    DecompiledCode,

    /// <summary>Use the semantic graph from IR analysis.</summary>
    SemanticGraph,

    /// <summary>Use raw instruction bytes.</summary>
    Instructions
}
/// <summary>
/// An embedding vector computed for a single function.
/// </summary>
/// <remarks>
/// <paramref name="Vector"/> is an array, so the compiler-generated record
/// equality compares it by reference rather than element by element.
/// </remarks>
/// <param name="FunctionId">Identifier for the function.</param>
/// <param name="FunctionName">Name of the function.</param>
/// <param name="Vector">Embedding vector (typically 768 dimensions).</param>
/// <param name="Model">Model used to generate the embedding.</param>
/// <param name="InputType">Type of input the embedding was generated from.</param>
/// <param name="GeneratedAt">When the embedding was generated.</param>
public sealed record FunctionEmbedding(
    string FunctionId,
    string FunctionName,
    float[] Vector,
    EmbeddingModel Model,
    EmbeddingInputType InputType,
    DateTimeOffset GeneratedAt);
/// <summary>
/// Embedding models that can be selected.
/// </summary>
public enum EmbeddingModel
{
    /// <summary>Fine-tuned CodeBERT for binary code analysis.</summary>
    CodeBertBinary,

    /// <summary>Graph neural network for CFG/call graph analysis.</summary>
    GraphSageFunction,

    /// <summary>Contrastive learning model for function similarity.</summary>
    ContrastiveFunction
}
/// <summary>
/// Metrics available for comparing two embeddings.
/// </summary>
public enum SimilarityMetric
{
    /// <summary>Cosine similarity (angle between vectors).</summary>
    Cosine,

    /// <summary>Euclidean distance (inverted to similarity).</summary>
    Euclidean,

    /// <summary>Manhattan distance (inverted to similarity).</summary>
    Manhattan,

    /// <summary>Metric learned by the model.</summary>
    LearnedMetric
}
/// <summary>
/// A single hit returned by an embedding similarity search.
/// </summary>
/// <param name="FunctionId">Matched function identifier.</param>
/// <param name="FunctionName">Matched function name.</param>
/// <param name="Similarity">Similarity score (0.0 to 1.0).</param>
/// <param name="LibraryName">Library containing the function, or <c>null</c> if not provided.</param>
/// <param name="LibraryVersion">Version of the library, or <c>null</c> if not provided.</param>
public sealed record EmbeddingMatch(
    string FunctionId,
    string FunctionName,
    decimal Similarity,
    string? LibraryName,
    string? LibraryVersion);
/// <summary>
/// Settings that control embedding generation.
/// </summary>
public sealed record EmbeddingOptions
{
    /// <summary>Maximum sequence length for tokenization. Defaults to 512.</summary>
    public int MaxSequenceLength { get; init; } = 512;

    /// <summary>Whether to normalize the embedding vector. Defaults to <c>true</c>.</summary>
    public bool NormalizeVector { get; init; } = true;

    /// <summary>Batch size for batch inference. Defaults to 32.</summary>
    public int BatchSize { get; init; } = 32;
}
/// <summary>
/// A labeled pair of function inputs used for model training.
/// </summary>
/// <param name="FunctionA">First function input.</param>
/// <param name="FunctionB">Second function input.</param>
/// <param name="IsSimilar">Ground-truth label: whether the two inputs are the same function.</param>
/// <param name="SimilarityScore">Optional fine-grained similarity score; <c>null</c> when not provided.</param>
public sealed record TrainingPair(
    EmbeddingInput FunctionA,
    EmbeddingInput FunctionB,
    bool IsSimilar,
    decimal? SimilarityScore);
/// <summary>
/// Settings that control model training.
/// </summary>
public sealed record TrainingOptions
{
    /// <summary>Model architecture to train. Defaults to <see cref="EmbeddingModel.CodeBertBinary"/>.</summary>
    public EmbeddingModel Model { get; init; } = EmbeddingModel.CodeBertBinary;

    /// <summary>Embedding vector dimension. Defaults to 768.</summary>
    public int EmbeddingDimension { get; init; } = 768;

    /// <summary>Training batch size. Defaults to 32.</summary>
    public int BatchSize { get; init; } = 32;

    /// <summary>Number of training epochs. Defaults to 10.</summary>
    public int Epochs { get; init; } = 10;

    /// <summary>Learning rate. Defaults to 1e-5.</summary>
    public double LearningRate { get; init; } = 1e-5;

    /// <summary>Margin for contrastive loss. Defaults to 0.5.</summary>
    public double MarginLoss { get; init; } = 0.5;

    /// <summary>Path to pretrained model weights; <c>null</c> when not set.</summary>
    public string? PretrainedModelPath { get; init; }

    /// <summary>Path to save checkpoints; <c>null</c> when not set.</summary>
    public string? CheckpointPath { get; init; }
}
/// <summary>
/// Progress snapshot reported while training is running.
/// </summary>
/// <param name="Epoch">Current epoch.</param>
/// <param name="TotalEpochs">Total number of epochs.</param>
/// <param name="Batch">Current batch.</param>
/// <param name="TotalBatches">Total number of batches.</param>
/// <param name="Loss">Current loss value.</param>
/// <param name="Accuracy">Current accuracy.</param>
public sealed record TrainingProgress(
    int Epoch,
    int TotalEpochs,
    int Batch,
    int TotalBatches,
    double Loss,
    double Accuracy);
/// <summary>
/// Summary produced when model training finishes.
/// </summary>
/// <param name="ModelPath">Path to the saved model.</param>
/// <param name="TotalPairs">Number of training pairs used.</param>
/// <param name="Epochs">Number of epochs completed.</param>
/// <param name="FinalLoss">Final loss value.</param>
/// <param name="ValidationAccuracy">Validation accuracy.</param>
/// <param name="TrainingTime">Total training time.</param>
public sealed record TrainingResult(
    string ModelPath,
    int TotalPairs,
    int Epochs,
    double FinalLoss,
    double ValidationAccuracy,
    TimeSpan TrainingTime);
/// <summary>
/// Metrics produced by evaluating a model.
/// </summary>
/// <param name="Accuracy">Overall accuracy.</param>
/// <param name="Precision">Precision (true positives / predicted positives).</param>
/// <param name="Recall">Recall (true positives / actual positives).</param>
/// <param name="F1Score">F1 score (harmonic mean of precision and recall).</param>
/// <param name="AucRoc">Area under the ROC curve.</param>
/// <param name="ConfusionMatrix">Confusion matrix entries.</param>
public sealed record EvaluationResult(
    double Accuracy,
    double Precision,
    double Recall,
    double F1Score,
    double AucRoc,
    ImmutableArray<ConfusionEntry> ConfusionMatrix);
/// <summary>
/// One cell of a confusion matrix: how often a predicted label co-occurred with an actual label.
/// </summary>
/// <param name="Predicted">Predicted label.</param>
/// <param name="Actual">Actual label.</param>
/// <param name="Count">Number of occurrences.</param>
public sealed record ConfusionEntry(
    string Predicted,
    string Actual,
    int Count);
/// <summary>
/// Formats a model can be exported to.
/// </summary>
public enum ModelExportFormat
{
    /// <summary>ONNX format for cross-platform inference.</summary>
    Onnx,

    /// <summary>PyTorch format.</summary>
    PyTorch,

    /// <summary>TensorFlow SavedModel format.</summary>
    TensorFlow
}
/// <summary>
/// Configuration for the ML service.
/// </summary>
/// <remarks>
/// NOTE(review): <see cref="Device"/> and <see cref="UseGpu"/> both relate to GPU
/// selection; how they interact is not defined in this type. Confirm with the consumer.
/// </remarks>
public sealed record MlOptions
{
    /// <summary>Path to the ONNX model file; <c>null</c> when not set.</summary>
    public string? ModelPath { get; init; }

    /// <summary>Path to the tokenizer vocabulary; <c>null</c> when not set.</summary>
    public string? VocabularyPath { get; init; }

    /// <summary>Device to use for inference (cpu, cuda). Defaults to <c>"cpu"</c>.</summary>
    public string Device { get; init; } = "cpu";

    /// <summary>Number of threads for inference. Defaults to 4.</summary>
    public int NumThreads { get; init; } = 4;

    /// <summary>Whether to use the GPU if available. Defaults to <c>false</c>.</summary>
    public bool UseGpu { get; init; } = false;

    /// <summary>Maximum batch size for inference. Defaults to 32.</summary>
    public int MaxBatchSize { get; init; } = 32;
}