// Copyright (c) StellaOps. All rights reserved.
// Licensed under BUSL-1.1. See LICENSE in the project root.

using StellaOps.BinaryIndex.Semantic;
using System.Collections.Immutable;

namespace StellaOps.BinaryIndex.ML;

/// <summary>
/// Input for generating function embeddings.
/// </summary>
/// <remarks>
/// Record equality compares <paramref name="InstructionBytes"/> by reference, not by content.
/// </remarks>
/// <param name="DecompiledCode">Decompiled C-like code if available.</param>
/// <param name="SemanticGraph">Semantic graph from IR analysis if available.</param>
/// <param name="InstructionBytes">Raw instruction bytes if available.</param>
/// <param name="PreferredInput">Which input type to prefer for embedding generation.</param>
public sealed record EmbeddingInput(
    string? DecompiledCode,
    KeySemanticsGraph? SemanticGraph,
    byte[]? InstructionBytes,
    EmbeddingInputType PreferredInput);

/// <summary>
/// Type of input for embedding generation.
/// </summary>
public enum EmbeddingInputType
{
    /// <summary>Use decompiled C-like code.</summary>
    DecompiledCode,

    /// <summary>Use semantic graph from IR analysis.</summary>
    SemanticGraph,

    /// <summary>Use raw instruction bytes.</summary>
    Instructions
}

/// <summary>
/// A function embedding vector.
/// </summary>
/// <remarks>
/// Record equality compares <paramref name="Vector"/> by reference, not element-wise.
/// </remarks>
/// <param name="FunctionId">Identifier for the function.</param>
/// <param name="FunctionName">Name of the function.</param>
/// <param name="Vector">Embedding vector (typically 768 dimensions).</param>
/// <param name="Model">Model used to generate the embedding.</param>
/// <param name="InputType">Type of input used.</param>
/// <param name="GeneratedAt">When the embedding was generated.</param>
public sealed record FunctionEmbedding(
    string FunctionId,
    string FunctionName,
    float[] Vector,
    EmbeddingModel Model,
    EmbeddingInputType InputType,
    DateTimeOffset GeneratedAt);

/// <summary>
/// Available embedding models.
/// </summary>
public enum EmbeddingModel
{
    /// <summary>Fine-tuned CodeBERT for binary code analysis.</summary>
    CodeBertBinary,

    /// <summary>Graph neural network for CFG/call graph analysis.</summary>
    GraphSageFunction,

    /// <summary>Contrastive learning model for function similarity.</summary>
    ContrastiveFunction
}

/// <summary>
/// Similarity metrics for comparing embeddings.
/// </summary>
public enum SimilarityMetric
{
    /// <summary>Cosine similarity (angle between vectors).</summary>
    Cosine,

    /// <summary>Euclidean distance (inverted to similarity).</summary>
    Euclidean,

    /// <summary>Manhattan distance (inverted to similarity).</summary>
    Manhattan,

    /// <summary>Learned metric from model.</summary>
    LearnedMetric
}

/// <summary>
/// A match from embedding similarity search.
/// </summary>
/// <param name="FunctionId">Matched function identifier.</param>
/// <param name="FunctionName">Matched function name.</param>
/// <param name="Similarity">Similarity score (0.0 to 1.0).</param>
/// <param name="LibraryName">Library containing the function.</param>
/// <param name="LibraryVersion">Version of the library.</param>
public sealed record EmbeddingMatch(
    string FunctionId,
    string FunctionName,
    decimal Similarity,
    string? LibraryName,
    string? LibraryVersion);

/// <summary>
/// Options for embedding generation.
/// </summary>
public sealed record EmbeddingOptions
{
    /// <summary>Maximum sequence length for tokenization.</summary>
    public int MaxSequenceLength { get; init; } = 512;

    /// <summary>Whether to normalize the embedding vector.</summary>
    public bool NormalizeVector { get; init; } = true;

    /// <summary>Batch size for batch inference.</summary>
    public int BatchSize { get; init; } = 32;
}

/// <summary>
/// Training pair for model training.
/// </summary>
/// <param name="FunctionA">First function input.</param>
/// <param name="FunctionB">Second function input.</param>
/// <param name="IsSimilar">Ground truth: are these the same function?</param>
/// <param name="SimilarityScore">Optional fine-grained similarity score.</param>
public sealed record TrainingPair(
    EmbeddingInput FunctionA,
    EmbeddingInput FunctionB,
    bool IsSimilar,
    decimal? SimilarityScore);

/// <summary>
/// Options for model training.
/// </summary>
public sealed record TrainingOptions
{
    /// <summary>Model architecture to train.</summary>
    public EmbeddingModel Model { get; init; } = EmbeddingModel.CodeBertBinary;

    /// <summary>Embedding vector dimension.</summary>
    public int EmbeddingDimension { get; init; } = 768;

    /// <summary>Training batch size.</summary>
    public int BatchSize { get; init; } = 32;

    /// <summary>Number of training epochs.</summary>
    public int Epochs { get; init; } = 10;

    /// <summary>Learning rate.</summary>
    public double LearningRate { get; init; } = 1e-5;

    /// <summary>Margin for contrastive loss.</summary>
    public double MarginLoss { get; init; } = 0.5;

    /// <summary>Path to pretrained model weights.</summary>
    public string? PretrainedModelPath { get; init; }

    /// <summary>Path to save checkpoints.</summary>
    public string? CheckpointPath { get; init; }
}

/// <summary>
/// Progress update during training.
/// </summary>
/// <param name="Epoch">Current epoch.</param>
/// <param name="TotalEpochs">Total epochs.</param>
/// <param name="Batch">Current batch.</param>
/// <param name="TotalBatches">Total batches.</param>
/// <param name="Loss">Current loss value.</param>
/// <param name="Accuracy">Current accuracy.</param>
public sealed record TrainingProgress(
    int Epoch,
    int TotalEpochs,
    int Batch,
    int TotalBatches,
    double Loss,
    double Accuracy);

/// <summary>
/// Result of model training.
/// </summary>
/// <param name="ModelPath">Path to saved model.</param>
/// <param name="TotalPairs">Number of training pairs used.</param>
/// <param name="Epochs">Number of epochs completed.</param>
/// <param name="FinalLoss">Final loss value.</param>
/// <param name="ValidationAccuracy">Validation accuracy.</param>
/// <param name="TrainingTime">Total training time.</param>
public sealed record TrainingResult(
    string ModelPath,
    int TotalPairs,
    int Epochs,
    double FinalLoss,
    double ValidationAccuracy,
    TimeSpan TrainingTime);

/// <summary>
/// Result of model evaluation.
/// </summary>
/// <param name="Accuracy">Overall accuracy.</param>
/// <param name="Precision">Precision (true positives / predicted positives).</param>
/// <param name="Recall">Recall (true positives / actual positives).</param>
/// <param name="F1Score">F1 score (harmonic mean of precision and recall).</param>
/// <param name="AucRoc">Area under ROC curve.</param>
/// <param name="ConfusionMatrix">Confusion matrix entries.</param>
public sealed record EvaluationResult(
    double Accuracy,
    double Precision,
    double Recall,
    double F1Score,
    double AucRoc,
    ImmutableArray<ConfusionEntry> ConfusionMatrix);

/// <summary>
/// Entry in confusion matrix.
/// </summary>
/// <param name="Predicted">Predicted label.</param>
/// <param name="Actual">Actual label.</param>
/// <param name="Count">Number of occurrences.</param>
public sealed record ConfusionEntry(
    string Predicted,
    string Actual,
    int Count);

/// <summary>
/// Model export formats.
/// </summary>
public enum ModelExportFormat
{
    /// <summary>ONNX format for cross-platform inference.</summary>
    Onnx,

    /// <summary>PyTorch format.</summary>
    PyTorch,

    /// <summary>TensorFlow SavedModel format.</summary>
    TensorFlow
}

/// <summary>
/// Options for ML service.
/// </summary>
public sealed record MlOptions
{
    /// <summary>Path to ONNX model file.</summary>
    public string? ModelPath { get; init; }

    /// <summary>Path to tokenizer vocabulary.</summary>
    public string? VocabularyPath { get; init; }

    /// <summary>Device to use for inference (cpu, cuda).</summary>
    public string Device { get; init; } = "cpu";

    /// <summary>Number of threads for inference.</summary>
    public int NumThreads { get; init; } = 4;

    /// <summary>Whether to use GPU if available.</summary>
    public bool UseGpu { get; init; } = false;

    /// <summary>Maximum batch size for inference.</summary>
    public int MaxBatchSize { get; init; } = 32;
}