StellaOps Bot 37e11918e0 save progress
2026-01-06 09:42:20 +02:00

BinaryIndex ML Model Training Guide

This document describes how to train, export, and deploy ML models for the BinaryIndex binary similarity detection system.

Overview

The BinaryIndex ML pipeline uses transformer-based models to generate function embeddings that capture semantic similarity. The primary model is CodeBERT-Binary, a fine-tuned variant of CodeBERT optimized for decompiled binary code comparison.

Architecture

```
┌─────────────────────────────────────────────────────────────────────┐
│                    Model Training Pipeline                          │
│                                                                     │
│  ┌───────────────┐    ┌────────────────┐    ┌──────────────────┐  │
│  │ Training Data │ -> │ Fine-tuning    │ -> │ Model Export     │  │
│  │ (Function     │    │ (Contrastive   │    │ (ONNX format)    │  │
│  │ Pairs)        │    │ Learning)      │    │                  │  │
│  └───────────────┘    └────────────────┘    └──────────────────┘  │
│                                                                     │
│  ┌───────────────────────────────────────────────────────────────┐ │
│  │                    Inference Pipeline                         │ │
│  │                                                               │ │
│  │  Code -> Tokenizer -> ONNX Runtime -> Embedding (768-dim)    │ │
│  │                                                               │ │
│  └───────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```

Training Data Requirements

Positive Pairs (Similar Functions)

| Source | Description | Estimated count |
|--------|-------------|-----------------|
| Same function, different optimization | O0 vs O2 vs O3 compilations | ~50,000 |
| Same function, different compiler | GCC vs Clang vs MSVC | ~30,000 |
| Same function, different version | From corpus snapshots | ~100,000 |
| Vulnerability patches | Vulnerable vs fixed versions | ~20,000 |

Negative Pairs (Dissimilar Functions)

| Source | Description | Estimated count |
|--------|-------------|-----------------|
| Random function pairs | Random sampling from corpus | ~100,000 |
| Similar-named different functions | Hard negatives for robustness | ~50,000 |
| Same library, different functions | Medium-difficulty negatives | ~50,000 |

Total training data: ~400,000 labeled pairs (~200,000 positive and ~200,000 negative, a balanced 1:1 split).
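
As a rough illustration of how such pairs might be assembled, the sketch below groups compiled variants by function identity, emits every combination of two variants of the same function as a positive pair, and samples random negatives from two different functions. The `FunctionVariant` record and `build_pairs` helper are hypothetical, not the project's corpus tooling, and only the random-pair negatives are covered; hard and same-library negatives would need extra selection logic.

```python
import itertools
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class FunctionVariant:
    """Hypothetical corpus record: one compiled variant of one function."""
    function_id: str  # stable identity of the source function
    variant: str      # e.g. optimization level, compiler, or version
    code: str         # decompiled code for this variant


def build_pairs(variants, num_negatives, seed=0):
    """Return (positives, negatives) as lists of JSONL-style dicts."""
    rng = random.Random(seed)  # seeded so the sampled pairs are reproducible

    by_function = {}
    for v in variants:
        by_function.setdefault(v.function_id, []).append(v)

    # Positives: every pair of distinct variants of the same function
    positives = [
        {"function_a": a.code, "function_b": b.code, "is_similar": True}
        for group in by_function.values()
        for a, b in itertools.combinations(group, 2)
    ]

    # Negatives: one random variant from each of two different functions
    function_ids = list(by_function)
    if len(function_ids) < 2:
        return positives, []
    negatives = []
    for _ in range(num_negatives):
        id_a, id_b = rng.sample(function_ids, 2)
        a = rng.choice(by_function[id_a])
        b = rng.choice(by_function[id_b])
        negatives.append({"function_a": a.code, "function_b": b.code, "is_similar": False})
    return positives, negatives


variants = [
    FunctionVariant("sum", "O0", "int sum_O0(...)"),
    FunctionVariant("sum", "O2", "int sum_O2(...)"),
    FunctionVariant("sum", "O3", "int sum_O3(...)"),
    FunctionVariant("print", "O0", "void print_O0(...)"),
]
pos, neg = build_pairs(variants, num_negatives=2)
print(len(pos), len(neg))  # 3 2
```

Three variants of `sum` yield three positive pairs; `print` has a single variant, so it contributes only to negatives.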

Data Format

Training data is stored in JSON Lines (JSONL) format, with one labeled function pair per line:

```jsonl
{"function_a": "int sum(int* a, int n) { int s = 0; for (int i = 0; i < n; i++) s += a[i]; return s; }", "function_b": "int total(int* arr, int len) { int t = 0; for (int j = 0; j < len; j++) t += arr[j]; return t; }", "is_similar": true, "similarity_score": 0.95}
{"function_a": "int sum(int* a, int n) { ... }", "function_b": "void print(char* s) { ... }", "is_similar": false, "similarity_score": 0.1}
```
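
A minimal reader for this format could validate each line as it streams. The sketch assumes that all four fields shown above are required and that `similarity_score` lies between 0 and 1; neither rule is stated here as a formal schema, and `read_pairs` is a hypothetical helper.

```python
import io
import json

# Assumed schema: field name -> accepted Python type(s) after json.loads
REQUIRED_FIELDS = {
    "function_a": str,
    "function_b": str,
    "is_similar": bool,
    "similarity_score": (int, float),
}


def read_pairs(stream):
    """Yield validated pair records from a JSONL text stream."""
    for line_no, line in enumerate(stream, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        for field, expected in REQUIRED_FIELDS.items():
            if not isinstance(record.get(field), expected):
                raise ValueError(f"line {line_no}: bad or missing {field!r}")
        if not 0.0 <= record["similarity_score"] <= 1.0:
            raise ValueError(f"line {line_no}: similarity_score out of range")
        yield record


sample = io.StringIO(
    '{"function_a": "int f() { return 1; }", "function_b": "int g() { return 1; }", '
    '"is_similar": true, "similarity_score": 0.95}\n'
)
pairs = list(read_pairs(sample))
print(pairs[0]["is_similar"])  # True
```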

Training Process

Prerequisites

  • Python 3.10+
  • PyTorch 2.0+
  • Transformers 4.30+
  • CUDA 11.8+ (for GPU training)
  • 64GB RAM, 32GB VRAM (V100 or A100 recommended)

Installation

```shell
cd tools/ml
pip install -r requirements.txt
```

Configuration

Create a training configuration file config/training.yaml:

```yaml
model:
  base_model: microsoft/codebert-base
  embedding_dim: 768
  max_sequence_length: 512

training:
  batch_size: 32
  epochs: 10
  # Written as 1.0e-5 rather than 1e-5: YAML 1.1 loaders such as PyYAML only
  # resolve floats that contain a decimal point and would read 1e-5 as a string
  learning_rate: 1.0e-5
  warmup_steps: 1000
  weight_decay: 0.01

contrastive:
  margin: 0.5
  temperature: 0.07

data:
  train_path: data/train.jsonl
  val_path: data/val.jsonl
  test_path: data/test.jsonl

output:
  model_dir: models/codebert-binary
  checkpoint_interval: 1000
```
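
The `learning_rate` and `warmup_steps` settings interact through the learning-rate schedule. The configuration does not say which schedule the training script uses; the sketch below assumes linear warmup followed by linear decay to zero, the shape used by Hugging Face's `get_linear_schedule_with_warmup`. The `total_steps` default of 125,000 is an upper bound computed as epochs × pairs ÷ batch size (10 × 400,000 ÷ 32), assuming all pairs were training data; the real value depends on the train split size.

```python
def linear_warmup_lr(step, base_lr=1.0e-5, warmup_steps=1000, total_steps=125_000):
    """Learning rate at an optimizer step (assumed linear warmup, then linear decay)."""
    if step < warmup_steps:
        # Ramp linearly from 0 up to base_lr during warmup
        return base_lr * step / warmup_steps
    # Decay linearly from base_lr to 0 at total_steps, then stay at 0
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for step in (0, 500, 1000, 63_000, 125_000):
    print(step, linear_warmup_lr(step))
```

Warmup keeps early updates small while the freshly initialized projection layer produces noisy gradients, which helps avoid disrupting the pretrained encoder weights.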

Running Training

```shell
python train_codebert_binary.py --config config/training.yaml
```

Training logs are written to logs/ and checkpoints to models/.

Training Script Overview

```python
# tools/ml/train_codebert_binary.py

import torch
from transformers import RobertaModel


class CodeBertBinaryModel(torch.nn.Module):
    """CodeBERT fine-tuned for binary code similarity."""

    def __init__(self, pretrained_model="microsoft/codebert-base"):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained(pretrained_model)
        hidden = self.encoder.config.hidden_size  # 768 for codebert-base
        self.projection = torch.nn.Linear(hidden, hidden)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First token (<s>, RoBERTa's equivalent of [CLS]) as the pooled representation
        pooled = outputs.last_hidden_state[:, 0, :]
        projected = self.projection(pooled)
        # L2-normalize so cosine similarity reduces to a dot product
        return torch.nn.functional.normalize(projected, p=2, dim=1)


class ContrastiveLoss(torch.nn.Module):
    """Margin-based contrastive loss for learning similarity embeddings."""

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, embedding_a, embedding_b, label):
        # label: float tensor, 1.0 = similar pair, 0.0 = dissimilar pair
        distance = torch.nn.functional.pairwise_distance(embedding_a, embedding_b)
        # Pull similar pairs together; push dissimilar pairs at least `margin` apart
        loss = label * distance.pow(2) + \
               (1 - label) * torch.clamp(self.margin - distance, min=0).pow(2)
        return loss.mean()
```
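
To see what the margin does, the following pure-Python sketch mirrors the per-pair formula in `ContrastiveLoss` without PyTorch, using two-dimensional unit vectors as stand-ins for normalized embeddings. It is an illustration only, not part of the training script. Because the model L2-normalizes its output, the Euclidean distance between two embeddings lies in [0, 2], so a margin of 0.5 asks dissimilar pairs to sit at least 0.5 apart.

```python
import math


def contrastive_loss(a, b, label, margin=0.5):
    """Per-pair margin contrastive loss; label 1 = similar, 0 = dissimilar.

    Mirrors ContrastiveLoss above, minus the small eps that
    torch.nn.functional.pairwise_distance adds for numerical stability.
    """
    distance = math.dist(a, b)  # Euclidean distance
    return label * distance ** 2 + (1 - label) * max(margin - distance, 0.0) ** 2


def unit(angle_deg):
    """2-D unit vector at the given angle (stand-in for a normalized embedding)."""
    r = math.radians(angle_deg)
    return (math.cos(r), math.sin(r))


# Similar pair 10 degrees apart: small distance, small loss
print(contrastive_loss(unit(0), unit(10), label=1))
# Dissimilar pair 10 degrees apart: closer than the margin, so it is penalized
print(contrastive_loss(unit(0), unit(10), label=0))
# Dissimilar pair 90 degrees apart: distance sqrt(2) exceeds 0.5, so the loss is zero
print(contrastive_loss(unit(0), unit(90), label=0))
```

Dissimilar pairs that are already farther apart than the margin contribute nothing, so training effort concentrates on the hard negatives.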

Model Export

After training, export the model to ONNX format for inference:

```shell
python export_onnx.py \
    --model models/codebert-binary/best.pt \
    --output models/codebert-binary.onnx \
    --opset 17
```

Export Script

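```python
# tools/ml/export_onnx.py

import torch


def export_to_onnx(model, output_path):
    model.eval()
    # Dummy inputs are used only to trace the graph; their dtypes become the
    # ONNX input types. Use int64 for both, matching Hugging Face tokenizer output.
    dummy_input = torch.randint(0, 50000, (1, 512), dtype=torch.long)
    dummy_mask = torch.ones(1, 512, dtype=torch.long)

    torch.onnx.export(
        model,
        (dummy_input, dummy_mask),
        output_path,
        input_names=['input_ids', 'attention_mask'],
        output_names=['embedding'],
        # Batch size and sequence length stay variable at inference time
        dynamic_axes={
            'input_ids': {0: 'batch', 1: 'seq'},
            'attention_mask': {0: 'batch', 1: 'seq'},
            'embedding': {0: 'batch'}
        },
        opset_version=17
    )
```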

Deployment

Configuration

Configure the ML service in your application:

```yaml
# etc/binaryindex.yaml
ml:
  enabled: true
  model_path: /opt/stellaops/models/codebert-binary.onnx
  vocabulary_path: /opt/stellaops/models/vocab.txt
  num_threads: 4
  batch_size: 16
```

Code Integration

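```csharp
// Register ML services.
// Configuration keys use ':' as the section separator, so "ml:model_path"
// reads model_path under the ml section once etc/binaryindex.yaml is loaded.
services.AddMlServices(options =>
{
    options.ModelPath = config["ml:model_path"];
    options.VocabularyPath = config["ml:vocabulary_path"];
    options.NumThreads = config.GetValue<int>("ml:num_threads");
});

// Use the embedding service (embeddingService is resolved from dependency injection)
var embedding = await embeddingService.GenerateEmbeddingAsync(
    new EmbeddingInput(decompiledCode, null, null, EmbeddingInputType.DecompiledCode));

// Compare two embeddings (embA and embB produced as above)
var similarity = embeddingService.ComputeSimilarity(embA, embB, SimilarityMetric.Cosine);
```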
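
For reference, cosine similarity is the dot product of two vectors divided by the product of their norms. The Python sketch below shows that standard computation, assuming `SimilarityMetric.Cosine` follows the usual definition (its implementation is not shown here). Because `CodeBertBinaryModel` L2-normalizes its output, both norms are 1 for real embeddings and the score reduces to a dot product.

```python
import math


def cosine_similarity(a, b):
    """Standard cosine similarity in [-1, 1]; raises on mismatched or zero vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0
print(cosine_similarity([1.0, 2.0], [-1.0, -2.0]))
```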

Fallback Mode

When no ONNX model is available, the system generates hash-based pseudo-embeddings:

```csharp
// In OnnxInferenceEngine.cs
if (_session is null)
{
    // Fallback: generate hash-based pseudo-embedding for testing
    vector = GenerateFallbackEmbedding(text, 768);
}
```

This lets the system operate without a trained model, which is useful for testing, but with reduced accuracy; it is not intended for production similarity matching.
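
The actual `GenerateFallbackEmbedding` implementation is not shown. As a hypothetical illustration of the idea, the Python sketch below derives a deterministic, unit-length 768-dimensional vector from SHA-256 hashes of the input text. In this sketch identical inputs always map to the same vector, while any change to the text yields an unrelated vector, which is why such a mode suits testing rather than real similarity search.

```python
import hashlib
import math
import struct


def fallback_embedding(text, dim=768):
    """Hypothetical hash-based pseudo-embedding: deterministic, unit-length, not semantic."""
    values = []
    counter = 0
    while len(values) < dim:
        # Extend the hash stream by hashing the text together with a block counter
        digest = hashlib.sha256(f"{counter}:{text}".encode("utf-8")).digest()
        # Each 4-byte chunk becomes an unsigned int, mapped into [-1, 1]
        for (word,) in struct.iter_unpack(">I", digest):
            values.append(word / 0xFFFFFFFF * 2.0 - 1.0)
        counter += 1
    values = values[:dim]
    # L2-normalize so the vector looks like a real model output
    norm = math.sqrt(sum(v * v for v in values))
    return [v / norm for v in values]


a = fallback_embedding("int sum(int* a, int n)")
b = fallback_embedding("int sum(int* a, int n)")
c = fallback_embedding("int sum(int* a, int m)")
print(a == b)  # True: same text, same vector
print(a == c)  # False: a one-character change gives a different vector
```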

Evaluation

Metrics

| Metric | Definition | Target |
|--------|------------|--------|
| Accuracy | (TP + TN) / Total | > 90% |
| Precision | TP / (TP + FP) | > 95% |
| Recall | TP / (TP + FN) | > 85% |
| F1 score | 2 × P × R / (P + R) | > 90% |
| Latency | Per-function embedding time | < 100 ms |
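
The definitions above can be computed directly from confusion-matrix counts, as in the Python sketch below (`classification_metrics` is a hypothetical helper, not the evaluation script's API). The example counts are chosen so that precision is exactly 95% and recall exactly 85%. The resulting F1 is about 89.7%, which shows that the targets interact: meeting the F1 target requires exceeding at least one of the precision or recall targets.

```python
def classification_metrics(tp, tn, fp, fn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total
    # Convention used here: 0.0 when the denominator is zero
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# 323 / (323 + 17) = 0.95 precision, 323 / (323 + 57) = 0.85 recall
m = classification_metrics(tp=323, tn=600, fp=17, fn=57)
print({name: round(value, 4) for name, value in m.items()})
```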

Running Evaluation

```shell
python evaluate.py \
    --model models/codebert-binary.onnx \
    --test data/test.jsonl \
    --output results/evaluation.json
```
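
Turning embedding similarities into the yes/no decisions that these metrics count requires a decision threshold. How `evaluate.py` chooses it is not documented; one common approach, sketched below under that assumption, is to sweep candidate thresholds on the validation split (not the test split, to avoid tuning on test data) and keep the one with the highest F1. The `f1_at_threshold` and `best_threshold` helpers and the sample scores are hypothetical.

```python
def f1_at_threshold(scores, labels, threshold):
    """F1 when pairs with score >= threshold are predicted similar."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and not y)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def best_threshold(scores, labels):
    """Try each observed score as a threshold; return (threshold, f1) with the best F1."""
    candidates = sorted(set(scores))
    return max(((t, f1_at_threshold(scores, labels, t)) for t in candidates),
               key=lambda pair: pair[1])


# Cosine similarities for validation pairs and their ground-truth labels
scores = [0.95, 0.91, 0.88, 0.72, 0.60, 0.41, 0.30]
labels = [True, True, True, False, True, False, False]
print(best_threshold(scores, labels))
```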

Benchmark Results

Results from the EnsembleAccuracyBenchmarks suite:

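| Approach | Accuracy | Precision | Recall | F1 score | Latency |
|----------|----------|-----------|--------|----------|---------|
| Phase 1 (Hash only) | 70% | 100% | 0% | 0% | 1 ms |
| AST only | 75% | 80% | 70% | 74% | 5 ms |
| Embedding only | 80% | 85% | 75% | 80% | 50 ms |
| Ensemble (Phase 4) | 92% | 95% | 88% | 91% | 80 ms |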

Troubleshooting

Common Issues

Model not loading:

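  • Verify that the ONNX file path (ml.model_path) is correct
  • Check that the ONNX Runtime package is referenced: dotnet add package Microsoft.ML.OnnxRuntime
  • Ensure the model was exported with an opset version supported by the installed ONNX Runtime (the export step above uses opset 17)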

Low accuracy:

  • Verify training data quality and balance
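  • Check for data leakage between the train and test splits, for example variants of the same function appearing in both
  • Tune the contrastive loss margin (contrastive.margin in config/training.yaml)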
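
One common source of leakage is splitting pairs at random, so that different variants of the same function land in both train and test. A minimal sketch of a group-aware split follows. It assumes each pair can be tagged with the identities of its two functions; the `function_id_a` and `function_id_b` fields are hypothetical, since the JSONL format above does not include them. Each function is assigned to a split by hashing its ID, and only pairs whose two functions share a split are kept.

```python
import hashlib


def split_of(function_id, test_fraction=0.1):
    """Deterministically assign a function ID to 'train' or 'test' by hashing it."""
    bucket = int(hashlib.sha256(function_id.encode("utf-8")).hexdigest(), 16) % 1000
    return "test" if bucket < test_fraction * 1000 else "train"


def leak_free_split(pairs, test_fraction=0.1):
    """Split pairs so that no function ID appears in both train and test.

    Pairs whose two functions hash to different splits are dropped.
    """
    train, test = [], []
    for pair in pairs:
        split_a = split_of(pair["function_id_a"], test_fraction)
        split_b = split_of(pair["function_id_b"], test_fraction)
        if split_a != split_b:
            continue  # the pair would straddle the boundary, so drop it
        (test if split_a == "test" else train).append(pair)
    return train, test


pairs = [{"function_id_a": f"f{i}", "function_id_b": f"f{(i * 7) % 50}"} for i in range(50)]
train, test = leak_free_split(pairs, test_fraction=0.3)
print(len(train), len(test))
```

Dropping straddling pairs discards some cross-function negatives; if that skews the class balance, negatives can be resampled within each split afterwards.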

High latency:

  • Reduce the maximum sequence length (default: 512 tokens); longer functions are then truncated
  • Batch requests for bulk operations (see ml.batch_size)
  • Consider GPU acceleration for high-volume deployments
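
Because the ONNX model is exported with dynamic `batch` and `seq` axes, each batch only needs padding to its own longest sequence rather than to 512 tokens. The sketch below groups token-ID sequences into batches (16 per batch, matching `ml.batch_size` in the example etc/binaryindex.yaml), truncates to a maximum length, and pads per batch; sorting by length first reduces padding further. The `make_batches` helper is hypothetical, and the pad ID of 1 assumes RoBERTa's `<pad>` token, which codebert-base inherits.

```python
def make_batches(sequences, batch_size=16, max_len=512, pad_id=1):
    """Yield (original_indices, input_ids, attention_mask), padded per batch.

    Sequences are sorted by length so each batch holds similar-length inputs;
    the returned indices let callers restore the original order.
    """
    order = sorted(range(len(sequences)), key=lambda i: len(sequences[i]))
    for start in range(0, len(order), batch_size):
        indices = order[start:start + batch_size]
        # Naive truncation; a real tokenizer would keep the closing </s> token
        batch = [sequences[i][:max_len] for i in indices]
        width = max(len(seq) for seq in batch)
        input_ids = [seq + [pad_id] * (width - len(seq)) for seq in batch]
        attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in batch]
        yield indices, input_ids, attention_mask


seqs = [[0, 5, 2], [0, 7, 8, 9, 2], [0, 2]]
for indices, ids, mask in make_batches(seqs, batch_size=2):
    print(indices, ids, mask)
```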

Logging

Enable detailed ML logging:

```csharp
services.AddLogging(builder =>
{
    builder.AddFilter("StellaOps.BinaryIndex.ML", LogLevel.Debug);
});
```

References