// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.
using StellaOps.BinaryIndex.Semantic;
using System.Collections.Immutable;
namespace StellaOps.BinaryIndex.ML;
/// <summary>
/// Carries the alternative representations of one function from which an embedding can be generated.
/// </summary>
/// <remarks>
/// Each representation is optional. Note that the compiler-generated record equality compares
/// <paramref name="InstructionBytes"/> by array reference, not by content.
/// </remarks>
/// <param name="DecompiledCode">Decompiled C-like code, or <c>null</c> when it is not available.</param>
/// <param name="SemanticGraph">Semantic graph obtained from IR analysis, or <c>null</c> when it is not available.</param>
/// <param name="InstructionBytes">Raw instruction bytes, or <c>null</c> when they are not available.</param>
/// <param name="PreferredInput">The input type that embedding generation should prefer.</param>
public sealed record EmbeddingInput(
    string? DecompiledCode,
    KeySemanticsGraph? SemanticGraph,
    byte[]? InstructionBytes,
    EmbeddingInputType PreferredInput);
/// <summary>
/// Identifies which kind of function representation feeds embedding generation.
/// </summary>
public enum EmbeddingInputType
{
    /// <summary>Decompiled C-like code.</summary>
    DecompiledCode = 0,

    /// <summary>Semantic graph from IR analysis.</summary>
    SemanticGraph = 1,

    /// <summary>Raw instruction bytes.</summary>
    Instructions = 2
}
/// <summary>
/// An embedding vector computed for a single function, together with its provenance.
/// </summary>
/// <remarks>
/// The compiler-generated record equality compares <paramref name="Vector"/> by array reference,
/// so two instances holding equal but distinct arrays are not considered equal.
/// </remarks>
/// <param name="FunctionId">Identifier of the function.</param>
/// <param name="FunctionName">Name of the function.</param>
/// <param name="Vector">The embedding vector (typically 768 dimensions).</param>
/// <param name="Model">Model that produced the embedding.</param>
/// <param name="InputType">Kind of input the embedding was generated from.</param>
/// <param name="GeneratedAt">Timestamp at which the embedding was generated.</param>
public sealed record FunctionEmbedding(
    string FunctionId,
    string FunctionName,
    float[] Vector,
    EmbeddingModel Model,
    EmbeddingInputType InputType,
    DateTimeOffset GeneratedAt);
/// <summary>
/// The embedding models that can be selected.
/// </summary>
public enum EmbeddingModel
{
    /// <summary>CodeBERT fine-tuned for binary code analysis.</summary>
    CodeBertBinary = 0,

    /// <summary>Graph neural network for CFG/call graph analysis.</summary>
    GraphSageFunction = 1,

    /// <summary>Contrastive learning model for function similarity.</summary>
    ContrastiveFunction = 2
}
/// <summary>
/// Metrics that can be used to compare two embeddings.
/// </summary>
public enum SimilarityMetric
{
    /// <summary>Cosine similarity (angle between the vectors).</summary>
    Cosine = 0,

    /// <summary>Euclidean distance, inverted to a similarity.</summary>
    Euclidean = 1,

    /// <summary>Manhattan distance, inverted to a similarity.</summary>
    Manhattan = 2,

    /// <summary>Metric learned by the model.</summary>
    LearnedMetric = 3
}
/// <summary>
/// One hit returned by an embedding similarity search.
/// </summary>
/// <param name="FunctionId">Identifier of the matched function.</param>
/// <param name="FunctionName">Name of the matched function.</param>
/// <param name="Similarity">Similarity score (0.0 to 1.0).</param>
/// <param name="LibraryName">Library that contains the function, or <c>null</c>.</param>
/// <param name="LibraryVersion">Version of that library, or <c>null</c>.</param>
public sealed record EmbeddingMatch(
    string FunctionId,
    string FunctionName,
    decimal Similarity,
    string? LibraryName,
    string? LibraryVersion);
/// <summary>
/// Settings that control embedding generation.
/// </summary>
public sealed record EmbeddingOptions
{
    /// <summary>Maximum sequence length used for tokenization. Defaults to 512.</summary>
    public int MaxSequenceLength { get; init; } = 512;

    /// <summary>Whether the embedding vector is normalized. Defaults to <c>true</c>.</summary>
    public bool NormalizeVector { get; init; } = true;

    /// <summary>Batch size used for batch inference. Defaults to 32.</summary>
    public int BatchSize { get; init; } = 32;
}
/// <summary>
/// A labelled pair of function inputs used for model training.
/// </summary>
/// <param name="FunctionA">Input describing the first function.</param>
/// <param name="FunctionB">Input describing the second function.</param>
/// <param name="IsSimilar">Ground-truth label: whether both inputs are the same function.</param>
/// <param name="SimilarityScore">Optional fine-grained similarity score; <c>null</c> when not provided.</param>
public sealed record TrainingPair(
    EmbeddingInput FunctionA,
    EmbeddingInput FunctionB,
    bool IsSimilar,
    decimal? SimilarityScore);
/// <summary>
/// Settings that control model training.
/// </summary>
public sealed record TrainingOptions
{
    /// <summary>Model architecture to train. Defaults to <see cref="EmbeddingModel.CodeBertBinary"/>.</summary>
    public EmbeddingModel Model { get; init; } = EmbeddingModel.CodeBertBinary;

    /// <summary>Dimension of the embedding vector. Defaults to 768.</summary>
    public int EmbeddingDimension { get; init; } = 768;

    /// <summary>Training batch size. Defaults to 32.</summary>
    public int BatchSize { get; init; } = 32;

    /// <summary>Number of training epochs. Defaults to 10.</summary>
    public int Epochs { get; init; } = 10;

    /// <summary>Learning rate. Defaults to 1e-5.</summary>
    public double LearningRate { get; init; } = 1e-5;

    /// <summary>Margin for the contrastive loss. Defaults to 0.5.</summary>
    public double MarginLoss { get; init; } = 0.5;

    /// <summary>Path to pretrained model weights; <c>null</c> unless set.</summary>
    public string? PretrainedModelPath { get; init; }

    /// <summary>Path where checkpoints are saved; <c>null</c> unless set.</summary>
    public string? CheckpointPath { get; init; }
}
/// <summary>
/// A progress snapshot reported while training is running.
/// </summary>
/// <param name="Epoch">Current epoch.</param>
/// <param name="TotalEpochs">Total number of epochs.</param>
/// <param name="Batch">Current batch.</param>
/// <param name="TotalBatches">Total number of batches.</param>
/// <param name="Loss">Current loss value.</param>
/// <param name="Accuracy">Current accuracy.</param>
public sealed record TrainingProgress(
    int Epoch,
    int TotalEpochs,
    int Batch,
    int TotalBatches,
    double Loss,
    double Accuracy);
/// <summary>
/// Summary produced when model training finishes.
/// </summary>
/// <param name="ModelPath">Path to the saved model.</param>
/// <param name="TotalPairs">Number of training pairs that were used.</param>
/// <param name="Epochs">Number of epochs that were completed.</param>
/// <param name="FinalLoss">Final loss value.</param>
/// <param name="ValidationAccuracy">Validation accuracy.</param>
/// <param name="TrainingTime">Total time spent training.</param>
public sealed record TrainingResult(
    string ModelPath,
    int TotalPairs,
    int Epochs,
    double FinalLoss,
    double ValidationAccuracy,
    TimeSpan TrainingTime);
/// <summary>
/// Metrics produced by evaluating a model.
/// </summary>
/// <remarks>
/// <see cref="ImmutableArray{T}"/> is a struct: a <c>default</c> value is uninitialized and throws when
/// enumerated, so pass <see cref="ImmutableArray{T}.Empty"/> when there are no entries.
/// </remarks>
/// <param name="Accuracy">Overall accuracy.</param>
/// <param name="Precision">Precision (true positives / predicted positives).</param>
/// <param name="Recall">Recall (true positives / actual positives).</param>
/// <param name="F1Score">F1 score (harmonic mean of precision and recall).</param>
/// <param name="AucRoc">Area under the ROC curve.</param>
/// <param name="ConfusionMatrix">Entries of the confusion matrix.</param>
public sealed record EvaluationResult(
    double Accuracy,
    double Precision,
    double Recall,
    double F1Score,
    double AucRoc,
    ImmutableArray<ConfusionEntry> ConfusionMatrix);

/// <summary>
/// A single cell of a confusion matrix.
/// </summary>
/// <param name="Predicted">Label that was predicted.</param>
/// <param name="Actual">Label that was actually correct.</param>
/// <param name="Count">Number of occurrences of this predicted/actual combination.</param>
public sealed record ConfusionEntry(
    string Predicted,
    string Actual,
    int Count);
/// <summary>
/// Formats a model can be exported to.
/// </summary>
public enum ModelExportFormat
{
    /// <summary>ONNX format for cross-platform inference.</summary>
    Onnx = 0,

    /// <summary>PyTorch format.</summary>
    PyTorch = 1,

    /// <summary>TensorFlow SavedModel format.</summary>
    TensorFlow = 2
}
/// <summary>
/// Configuration for the ML service.
/// </summary>
public sealed record MlOptions
{
    /// <summary>Path to the ONNX model file; <c>null</c> unless set.</summary>
    public string? ModelPath { get; init; }

    /// <summary>Path to the tokenizer vocabulary; <c>null</c> unless set.</summary>
    public string? VocabularyPath { get; init; }

    /// <summary>Device used for inference (<c>cpu</c>, <c>cuda</c>). Defaults to <c>"cpu"</c>.</summary>
    public string Device { get; init; } = "cpu";

    /// <summary>Number of threads used for inference. Defaults to 4.</summary>
    public int NumThreads { get; init; } = 4;

    /// <summary>Whether to use the GPU if one is available. Defaults to <c>false</c>.</summary>
    public bool UseGpu { get; init; }

    /// <summary>Maximum batch size for inference. Defaults to 32.</summary>
    public int MaxBatchSize { get; init; } = 32;
}