# BinaryIndex ML Model Training Guide

This document describes how to train, export, and deploy ML models for the BinaryIndex binary similarity detection system.

## Overview

The BinaryIndex ML pipeline uses transformer-based models to generate function embeddings that capture semantic similarity. The primary model is **CodeBERT-Binary**, a fine-tuned variant of CodeBERT optimized for decompiled binary code comparison.

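To make "semantic similarity" concrete: two functions are considered similar when their embedding vectors point in nearly the same direction. The minimal NumPy sketch below (illustrative only, not part of the BinaryIndex codebase) shows the cosine similarity computation that the deployment section later invokes via `SimilarityMetric.Cosine`.

```python
# Illustrative only: cosine similarity between two 768-dim function embeddings.
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """1.0 means identical direction, 0.0 means orthogonal (unrelated)."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Stand-in vectors; in practice these come from the CodeBERT-Binary model.
emb_a = np.random.rand(768)
emb_b = np.random.rand(768)
print(f"similarity: {cosine_similarity(emb_a, emb_b):.3f}")
```
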
## Architecture

```
┌─────────────────────────────────────────────────────────────────────┐
│                       Model Training Pipeline                       │
│                                                                     │
│  ┌───────────────┐    ┌────────────────┐    ┌──────────────────┐    │
│  │ Training Data │ -> │  Fine-tuning   │ -> │   Model Export   │    │
│  │  (Function    │    │  (Contrastive  │    │  (ONNX format)   │    │
│  │   Pairs)      │    │   Learning)    │    │                  │    │
│  └───────────────┘    └────────────────┘    └──────────────────┘    │
│                                                                     │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                      Inference Pipeline                        │  │
│  │                                                                │  │
│  │  Code -> Tokenizer -> ONNX Runtime -> Embedding (768-dim)      │  │
│  │                                                                │  │
│  └───────────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────────┘
```

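The inference row of the diagram can also be exercised directly from Python, which is useful when debugging an exported model. The sketch below is illustrative only; the production path is the .NET ONNX Runtime integration described under Deployment, and the model path and stock CodeBERT tokenizer are assumptions.

```python
# Illustrative sketch of: Code -> Tokenizer -> ONNX Runtime -> Embedding (768-dim).
import numpy as np
import onnxruntime as ort
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
session = ort.InferenceSession("models/codebert-binary.onnx")  # assumed path

def embed(code: str) -> np.ndarray:
    enc = tokenizer(code, truncation=True, max_length=512, return_tensors="np")
    # Dtypes mirror the dummy inputs used by the export script later in this guide
    # (int64 token ids, float32 attention mask).
    inputs = {
        "input_ids": enc["input_ids"].astype(np.int64),
        "attention_mask": enc["attention_mask"].astype(np.float32),
    }
    return session.run(["embedding"], inputs)[0][0]  # shape: (768,)

print(embed("int sum(int* a, int n) { int s = 0; for (int i = 0; i < n; i++) s += a[i]; return s; }").shape)
```
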
## Training Data Requirements

### Positive Pairs (Similar Functions)

| Source | Description | Estimated Count |
|--------|-------------|-----------------|
| Same function, different optimization | O0 vs O2 vs O3 compilations | ~50,000 |
| Same function, different compiler | GCC vs Clang vs MSVC | ~30,000 |
| Same function, different version | From corpus snapshots | ~100,000 |
| Vulnerability patches | Vulnerable vs fixed versions | ~20,000 |

### Negative Pairs (Dissimilar Functions)

| Source | Description | Estimated Count |
|--------|-------------|-----------------|
| Random function pairs | Random sampling from corpus | ~100,000 |
| Similar-named different functions | Hard negatives for robustness | ~50,000 |
| Same library, different functions | Medium-difficulty negatives | ~50,000 |

**Total training data:** ~400,000 labeled pairs

### Data Format

Training data is stored in JSON Lines (JSONL) format:

```json
{"function_a": "int sum(int* a, int n) { int s = 0; for (int i = 0; i < n; i++) s += a[i]; return s; }", "function_b": "int total(int* arr, int len) { int t = 0; for (int j = 0; j < len; j++) t += arr[j]; return t; }", "is_similar": true, "similarity_score": 0.95}
{"function_a": "int sum(int* a, int n) { ... }", "function_b": "void print(char* s) { ... }", "is_similar": false, "similarity_score": 0.1}
```

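For illustration, a minimal PyTorch `Dataset` over this JSONL format might look like the sketch below; the class name, tokenizer choice, and padding strategy are assumptions rather than the project's actual training code.

```python
# Hypothetical loader for the JSONL pair format above (illustrative sketch).
import json
import torch
from torch.utils.data import Dataset
from transformers import RobertaTokenizer

class FunctionPairDataset(Dataset):
    def __init__(self, path: str, max_length: int = 512):
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
        self.max_length = max_length
        with open(path, "r", encoding="utf-8") as f:
            self.pairs = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]

        def encode(text):
            return self.tokenizer(text, truncation=True, padding="max_length",
                                  max_length=self.max_length, return_tensors="pt")

        label = torch.tensor(1.0 if pair["is_similar"] else 0.0)
        return encode(pair["function_a"]), encode(pair["function_b"]), label
```
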
## Training Process

### Prerequisites

- Python 3.10+
- PyTorch 2.0+
- Transformers 4.30+
- CUDA 11.8+ (for GPU training)
- 64GB RAM, 32GB VRAM (V100 or A100 recommended)

### Installation

```bash
cd tools/ml
pip install -r requirements.txt
```

### Configuration

Create a training configuration file `config/training.yaml`:

```yaml
model:
  base_model: microsoft/codebert-base
  embedding_dim: 768
  max_sequence_length: 512

training:
  batch_size: 32
  epochs: 10
  learning_rate: 1e-5
  warmup_steps: 1000
  weight_decay: 0.01

contrastive:
  margin: 0.5
  temperature: 0.07

data:
  train_path: data/train.jsonl
  val_path: data/val.jsonl
  test_path: data/test.jsonl

output:
  model_dir: models/codebert-binary
  checkpoint_interval: 1000
```

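A short sketch of how the training script might read this file with PyYAML (the actual argument handling in `train_codebert_binary.py` may differ):

```python
# Assumed config loading; paths and key names follow config/training.yaml above.
import yaml

with open("config/training.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["model"]["base_model"])     # microsoft/codebert-base
print(config["training"]["batch_size"])  # 32
```
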
### Running Training

```bash
python train_codebert_binary.py --config config/training.yaml
```

Training logs are written to `logs/` and checkpoints to `models/`.

### Training Script Overview

```python
# tools/ml/train_codebert_binary.py
import torch
from transformers import RobertaModel


class CodeBertBinaryModel(torch.nn.Module):
    """CodeBERT fine-tuned for binary code similarity."""

    def __init__(self, pretrained_model="microsoft/codebert-base"):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained(pretrained_model)
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        projected = self.projection(pooled)
        return torch.nn.functional.normalize(projected, p=2, dim=1)


class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss for learning similarity embeddings."""

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, embedding_a, embedding_b, label):
        distance = torch.nn.functional.pairwise_distance(embedding_a, embedding_b)
        # label=1: similar, label=0: dissimilar
        loss = label * distance.pow(2) + \
               (1 - label) * torch.clamp(self.margin - distance, min=0).pow(2)
        return loss.mean()
```

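As a rough sketch of how these two pieces fit together, the loop below trains the model with the contrastive loss using the hyperparameters from `config/training.yaml`. It reuses the hypothetical `FunctionPairDataset` from the Data Format section and omits the warmup schedule, validation, and checkpointing that the real script would add.

```python
# Minimal training loop sketch (illustrative); not the full train_codebert_binary.py.
import torch
from torch.utils.data import DataLoader

def train_one_epoch(model, loader, optimizer, criterion, device="cuda"):
    model.train()
    for enc_a, enc_b, label in loader:
        optimizer.zero_grad()
        # Each encoding is (batch, 1, seq_len) after default collation; drop the extra dim.
        emb_a = model(enc_a["input_ids"].squeeze(1).to(device),
                      enc_a["attention_mask"].squeeze(1).to(device))
        emb_b = model(enc_b["input_ids"].squeeze(1).to(device),
                      enc_b["attention_mask"].squeeze(1).to(device))
        loss = criterion(emb_a, emb_b, label.to(device))
        loss.backward()
        optimizer.step()

model = CodeBertBinaryModel().to("cuda")
criterion = ContrastiveLoss(margin=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
loader = DataLoader(FunctionPairDataset("data/train.jsonl"), batch_size=32, shuffle=True)

for epoch in range(10):
    train_one_epoch(model, loader, optimizer, criterion)
```
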
## Model Export

After training, export the model to ONNX format for inference:

```bash
python export_onnx.py \
  --model models/codebert-binary/best.pt \
  --output models/codebert-binary.onnx \
  --opset 17
```

### Export Script

```python
# tools/ml/export_onnx.py
import torch


def export_to_onnx(model, output_path):
    model.eval()
    dummy_input = torch.randint(0, 50000, (1, 512))
    dummy_mask = torch.ones(1, 512)

    torch.onnx.export(
        model,
        (dummy_input, dummy_mask),
        output_path,
        input_names=['input_ids', 'attention_mask'],
        output_names=['embedding'],
        dynamic_axes={
            'input_ids': {0: 'batch', 1: 'seq'},
            'attention_mask': {0: 'batch', 1: 'seq'},
            'embedding': {0: 'batch'}
        },
        opset_version=17
    )
```

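Before deploying, it is worth a quick numerical check that the exported graph reproduces the PyTorch model. The snippet below is an illustrative sketch (the helper name and tolerance are assumptions) that feeds the export script's dummy input to both and compares the embeddings.

```python
# Illustrative parity check between the PyTorch model and the exported ONNX graph.
import numpy as np
import onnxruntime as ort
import torch

def verify_export(model, onnx_path, atol=1e-4):
    model.eval()
    input_ids = torch.randint(0, 50000, (1, 512))
    mask = torch.ones(1, 512)
    with torch.no_grad():
        expected = model(input_ids, mask).numpy()
    session = ort.InferenceSession(onnx_path)
    actual = session.run(["embedding"], {
        "input_ids": input_ids.numpy(),
        "attention_mask": mask.numpy(),
    })[0]
    assert np.allclose(expected, actual, atol=atol), "ONNX output diverges from PyTorch"
```
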
## Deployment

### Configuration

Configure the ML service in your application:

```yaml
# etc/binaryindex.yaml
ml:
  enabled: true
  model_path: /opt/stellaops/models/codebert-binary.onnx
  vocabulary_path: /opt/stellaops/models/vocab.txt
  num_threads: 4
  batch_size: 16
```

### Code Integration

```csharp
// Register ML services
services.AddMlServices(options =>
{
    options.ModelPath = config["ml:model_path"];
    options.VocabularyPath = config["ml:vocabulary_path"];
    options.NumThreads = config.GetValue<int>("ml:num_threads");
});

// Use embedding service
var embedding = await embeddingService.GenerateEmbeddingAsync(
    new EmbeddingInput(decompiledCode, null, null, EmbeddingInputType.DecompiledCode));

// Compare embeddings
var similarity = embeddingService.ComputeSimilarity(embA, embB, SimilarityMetric.Cosine);
```

### Fallback Mode

When no ONNX model is available, the system generates hash-based pseudo-embeddings:

```csharp
// In OnnxInferenceEngine.cs
if (_session is null)
{
    // Fallback: generate hash-based pseudo-embedding for testing
    vector = GenerateFallbackEmbedding(text, 768);
}
```

This allows the system to operate without a trained model (useful for testing) but with reduced accuracy.

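The fallback itself lives in the C# inference engine; the Python snippet below only illustrates the idea of a hash-seeded, deterministic pseudo-embedding and is not the actual `GenerateFallbackEmbedding` implementation.

```python
# Illustration of the fallback idea: hash the input, seed an RNG, emit a unit vector.
# The same text always maps to the same vector, but it carries no semantic signal.
import hashlib
import numpy as np

def fallback_embedding(text: str, dim: int = 768) -> np.ndarray:
    seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:8], "little")
    vec = np.random.default_rng(seed).standard_normal(dim).astype(np.float32)
    return vec / np.linalg.norm(vec)
```
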
## Evaluation

### Metrics

| Metric | Definition | Target |
|--------|------------|--------|
| Accuracy | (TP + TN) / Total | > 90% |
| Precision | TP / (TP + FP) | > 95% |
| Recall | TP / (TP + FN) | > 85% |
| F1 Score | 2 * P * R / (P + R) | > 90% |
| Latency | Per-function embedding time | < 100ms |

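For reference, the definitions above map to code as in the sketch below (illustrative; `evaluate.py` may compute them differently), given boolean ground-truth labels and predictions obtained by thresholding the similarity score.

```python
# Sketch: confusion-matrix counts and the derived metrics from the table above.
def classification_metrics(labels, predictions):
    tp = sum(1 for y, p in zip(labels, predictions) if y and p)
    tn = sum(1 for y, p in zip(labels, predictions) if not y and not p)
    fp = sum(1 for y, p in zip(labels, predictions) if not y and p)
    fn = sum(1 for y, p in zip(labels, predictions) if y and not p)
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    return {
        "accuracy": (tp + tn) / max(len(labels), 1),
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / max(precision + recall, 1e-9),
    }
```
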
### Running Evaluation

```bash
python evaluate.py \
  --model models/codebert-binary.onnx \
  --test data/test.jsonl \
  --output results/evaluation.json
```

### Benchmark Results

From `EnsembleAccuracyBenchmarks`:

| Approach | Accuracy | Precision | Recall | F1 Score | Latency |
|----------|----------|-----------|--------|----------|---------|
| Phase 1 (Hash only) | 70% | 100% | 0% | 0% | 1ms |
| AST only | 75% | 80% | 70% | 74% | 5ms |
| Embedding only | 80% | 85% | 75% | 80% | 50ms |
| Ensemble (Phase 4) | 92% | 95% | 88% | 91% | 80ms |

## Troubleshooting

### Common Issues

**Model not loading:**
- Verify the ONNX file path is correct
- Check that ONNX Runtime is installed: `dotnet add package Microsoft.ML.OnnxRuntime`
- Ensure the model was exported with a compatible opset version

**Low accuracy:**
- Verify training data quality and balance
- Check for data leakage between train/test splits
- Adjust the contrastive loss margin

**High latency:**
- Reduce the maximum sequence length (default 512)
- Enable batching for bulk operations
- Consider GPU acceleration for high-volume deployments

### Logging

Enable detailed ML logging:

```csharp
services.AddLogging(builder =>
{
    builder.AddFilter("StellaOps.BinaryIndex.ML", LogLevel.Debug);
});
```

## References

- [CodeBERT Paper](https://arxiv.org/abs/2002.08155)
- [Binary Code Similarity Detection](https://arxiv.org/abs/2308.01463)
- [ONNX Runtime Documentation](https://onnxruntime.ai/docs/)
- [Contrastive Learning for Code](https://arxiv.org/abs/2103.03143)