sprints work.

This commit is contained in:
master
2026-01-20 00:45:38 +02:00
parent b34bde89fa
commit 4903395618
275 changed files with 52785 additions and 79 deletions

View File

@@ -0,0 +1,244 @@
// -----------------------------------------------------------------------------
// B2R2IrTokenizer.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-003 - IR Token Extraction
// Description: B2R2-based IR tokenizer implementation.
// -----------------------------------------------------------------------------
using System.Text.RegularExpressions;
using Microsoft.Extensions.Logging;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// B2R2-based IR tokenizer for ML training input. Maps disassembled
/// instructions onto a canonical bracketed token vocabulary ([MOV], [CALL],
/// ...) with optional operand normalization so equivalent functions yield
/// similar token sequences.
/// </summary>
public sealed partial class B2R2IrTokenizer : IIrTokenizer
{
    private readonly ILogger<B2R2IrTokenizer> _logger;

    // Token vocabulary for common IR elements (reserved for vocabulary
    // construction once full B2R2 lifting is wired in).
    private static readonly HashSet<string> ControlFlowTokens =
        ["[JMP]", "[JE]", "[JNE]", "[JL]", "[JG]", "[JLE]", "[JGE]", "[CALL]", "[RET]", "[LOOP]"];

    private static readonly HashSet<string> DataFlowTokens =
        ["[MOV]", "[LEA]", "[PUSH]", "[POP]", "[XCHG]", "[LOAD]", "[STORE]"];

    private static readonly HashSet<string> ArithmeticTokens =
        ["[ADD]", "[SUB]", "[MUL]", "[DIV]", "[INC]", "[DEC]", "[NEG]", "[SHL]", "[SHR]", "[AND]", "[OR]", "[XOR]", "[NOT]"];

    // Explicit x86/x86-64 register names. An explicit set is used instead of a
    // "starts with r/e" prefix heuristic so symbols such as "ret_val" or
    // "error" are not misclassified as registers.
    private static readonly HashSet<string> NamedRegisters =
    [
        "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rsp", "rbp", "rip",
        "eax", "ebx", "ecx", "edx", "esi", "edi", "esp", "ebp",
        "ax", "bx", "cx", "dx", "si", "di", "sp", "bp"
    ];

    /// <summary>
    /// Initializes a new instance of the <see cref="B2R2IrTokenizer"/> class.
    /// </summary>
    /// <param name="logger">Logger for diagnostic output.</param>
    public B2R2IrTokenizer(ILogger<B2R2IrTokenizer> logger)
    {
        _logger = logger;
    }

    /// <inheritdoc />
    public Task<IReadOnlyList<string>> TokenizeAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default)
    {
        // This would integrate with B2R2 to lift the function to IR.
        // For now, return placeholder framing tokens around the normalized name.
        _logger.LogDebug("Tokenizing function {Function} from {Library}:{Version}",
            functionName, libraryName, version);

        var tokens = new List<string>
        {
            "[FUNC_START]",
            $"[NAME:{NormalizeName(functionName)}]",
            // IR tokens would be added here from B2R2 analysis
            "[FUNC_END]"
        };
        return Task.FromResult<IReadOnlyList<string>>(tokens);
    }

    /// <inheritdoc />
    public Task<IReadOnlyList<string>> TokenizeInstructionsAsync(
        ReadOnlyMemory<byte> instructions,
        string architecture,
        TokenizationOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= TokenizationOptions.Default;

        // Sequence framing: architecture marker, then function boundaries.
        var tokens = new List<string>
        {
            $"[ARCH:{architecture.ToUpperInvariant()}]",
            "[FUNC_START]"
        };

        // Disassemble and tokenize. This would use B2R2 for actual disassembly.
        var disassembly = DisassembleToIr(instructions, architecture);
        var varCounter = 0;
        var varMap = new Dictionary<string, string>();
        foreach (var insn in disassembly)
        {
            cancellationToken.ThrowIfCancellationRequested();

            // Canonical opcode token first, then one token per operand.
            tokens.Add(MapOpcodeToToken(insn.Opcode));
            foreach (var operand in insn.Operands)
            {
                var operandToken = options.NormalizeVariables
                    ? NormalizeOperand(operand, varMap, ref varCounter)
                    : operand;
                tokens.Add(options.IncludeOperandTypes
                    ? $"{InferOperandType(operand)}:{operandToken}"
                    : operandToken);
            }

            // Explicit control-flow marker helps models attend to branches.
            if (options.IncludeControlFlow && IsControlFlowInstruction(insn.Opcode))
            {
                tokens.Add("[CF]");
            }
        }
        tokens.Add("[FUNC_END]");

        // Truncate to the configured maximum, reserving one slot for the
        // [TRUNCATED] sentinel so consumers can detect clipped sequences.
        if (tokens.Count > options.MaxLength)
        {
            tokens = tokens.Take(options.MaxLength - 1).Append("[TRUNCATED]").ToList();
        }
        return Task.FromResult<IReadOnlyList<string>>(tokens);
    }

    private static IReadOnlyList<DisassembledInstruction> DisassembleToIr(
        ReadOnlyMemory<byte> instructions,
        string architecture)
    {
        // Placeholder - would use B2R2 for actual disassembly.
        // Returns a fixed sample prologue/call/epilogue for demonstration.
        return new List<DisassembledInstruction>
        {
            new("push", ["rbp"]),
            new("mov", ["rbp", "rsp"]),
            new("sub", ["rsp", "0x20"]),
            new("mov", ["[rbp-0x8]", "rdi"]),
            new("call", ["helper_func"]),
            new("leave", []),
            new("ret", [])
        };
    }

    /// <summary>
    /// Maps an opcode to its canonical bracketed token, collapsing mnemonic
    /// variants (e.g. MOVZX/MOVSX -> [MOV], SAL -> [SHL]) to one token each.
    /// Unrecognized opcodes map to "[UPPERCASED]".
    /// </summary>
    private static string MapOpcodeToToken(string opcode)
    {
        var upper = opcode.ToUpperInvariant();
        // Only non-identity mappings are listed; everything else falls through
        // to the default "[UPPER]" form, which is identical for e.g. JMP, CALL,
        // RET, MOV, ADD, CMP, NOP.
        return upper switch
        {
            "RETN" => "[RET]",
            "MOVZX" or "MOVSX" => "[MOV]",
            "IMUL" => "[MUL]",
            "IDIV" => "[DIV]",
            "SAL" => "[SHL]",
            "SAR" => "[SHR]",
            _ => $"[{upper}]"
        };
    }

    /// <summary>
    /// Normalizes an operand: registers become stable per-function variable
    /// names (v0, v1, ...), immediates become [IMM], memory references [MEM].
    /// </summary>
    private static string NormalizeOperand(
        string operand,
        Dictionary<string, string> varMap,
        ref int varCounter)
    {
        // Registers: assign a fresh vN on first sight, reuse it afterwards so
        // data flow through the same register stays visible.
        if (IsRegister(operand))
        {
            if (!varMap.TryGetValue(operand, out var normalized))
            {
                normalized = $"v{varCounter++}";
                varMap[operand] = normalized;
            }
            return normalized;
        }

        // Immediates collapse to a single token; exact constants rarely matter.
        if (IsImmediate(operand))
        {
            return "[IMM]";
        }

        // Memory references (anything with a bracketed address expression).
        if (operand.Contains('['))
        {
            return "[MEM]";
        }
        return operand;
    }

    /// <summary>
    /// Infers a coarse operand-type tag for prefixing normalized operands.
    /// </summary>
    private static string InferOperandType(string operand)
    {
        if (IsRegister(operand)) return "[REG]";
        if (IsImmediate(operand)) return "[IMM]";
        if (operand.Contains('[')) return "[MEM]";
        // Heuristic: underscored names or "func" substrings look like symbols.
        if (operand.Contains("func") || operand.Contains("_")) return "[SYM]";
        return "[UNK]";
    }

    /// <summary>
    /// Returns true when the operand names a known x86/x86-64 register,
    /// including numbered registers r8-r15 with optional b/w/d width suffix.
    /// </summary>
    private static bool IsRegister(string operand)
    {
        if (string.IsNullOrEmpty(operand)) return false;
        var lower = operand.ToLowerInvariant();
        if (NamedRegisters.Contains(lower)) return true;

        // Numbered registers: r8..r15, optionally suffixed with b/w/d.
        if (lower.Length < 2 || lower[0] != 'r') return false;
        var digits = lower[^1] is 'b' or 'w' or 'd' ? lower[1..^1] : lower[1..];
        return digits.Length > 0 &&
               digits.All(char.IsDigit) &&
               int.TryParse(digits, out var n) &&
               n is >= 8 and <= 15;
    }

    /// <summary>
    /// Returns true for hex ("0x...") or purely decimal immediates.
    /// Empty strings are not immediates (guards the vacuous All()).
    /// </summary>
    private static bool IsImmediate(string operand)
    {
        if (string.IsNullOrEmpty(operand)) return false;
        return operand.StartsWith("0x", StringComparison.OrdinalIgnoreCase) ||
               operand.All(char.IsDigit);
    }

    /// <summary>
    /// Returns true for jumps (any J-prefixed mnemonic), calls, returns and loops.
    /// </summary>
    private static bool IsControlFlowInstruction(string opcode)
    {
        var upper = opcode.ToUpperInvariant();
        return upper.StartsWith('J') || upper is "CALL" or "RET" or "RETN" or "LOOP";
    }

    /// <summary>
    /// Normalizes a function name for token emission: strips version-specific
    /// suffixes ("@N", ".N", "_vN") and lower-cases the rest.
    /// </summary>
    private static string NormalizeName(string name)
    {
        var normalized = NameNormalizationRegex().Replace(name, "");
        return normalized.ToLowerInvariant();
    }

    [GeneratedRegex(@"@\d+|\.\d+|_v\d+")]
    private static partial Regex NameNormalizationRegex();

    // Minimal disassembly record: mnemonic plus raw operand strings.
    private sealed record DisassembledInstruction(string Opcode, IReadOnlyList<string> Operands);
}

View File

@@ -0,0 +1,249 @@
// -----------------------------------------------------------------------------
// GhidraDecompilerAdapter.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-004 - Decompiled Code Extraction
// Description: Ghidra-based decompiler adapter implementation.
// -----------------------------------------------------------------------------
using System.Diagnostics;
using System.Text;
using System.Text.RegularExpressions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Ghidra-based decompiler adapter. Shells out to the Ghidra headless
/// analyzer (analyzeHeadless) and post-processes its output into normalized
/// C-like code for ML input.
/// </summary>
public sealed partial class GhidraDecompilerAdapter : IDecompilerAdapter
{
    private readonly GhidraAdapterOptions _options;
    private readonly ILogger<GhidraDecompilerAdapter> _logger;

    /// <summary>
    /// Initializes a new instance of the <see cref="GhidraDecompilerAdapter"/> class.
    /// </summary>
    /// <param name="options">Adapter configuration (Ghidra install path, script path).</param>
    /// <param name="logger">Logger for diagnostic output.</param>
    public GhidraDecompilerAdapter(
        IOptions<GhidraAdapterOptions> options,
        ILogger<GhidraDecompilerAdapter> logger)
    {
        _options = options.Value;
        _logger = logger;
    }

    /// <inheritdoc />
    public Task<string?> DecompileAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default)
    {
        _logger.LogDebug("Decompiling {Function} from {Library}:{Version}",
            functionName, libraryName, version);

        // This would call Ghidra headless; placeholder output for now.
        // (No await needed - the method completes synchronously.)
        return Task.FromResult<string?>(
            $"int {functionName}(void *param_1) {{\n int result;\n // Decompiled code placeholder\n result = 0;\n return result;\n}}");
    }

    /// <inheritdoc />
    public async Task<string?> DecompileBytesAsync(
        ReadOnlyMemory<byte> bytes,
        string architecture,
        DecompilationOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= DecompilationOptions.Default;
        if (string.IsNullOrEmpty(_options.GhidraPath))
        {
            _logger.LogWarning("Ghidra path not configured");
            return null;
        }

        string? tempInput = null;
        string? tempOutput = null;
        try
        {
            // Temp files are created inside the try so they are always cleaned
            // up, even when writing the input bytes throws.
            tempInput = Path.GetTempFileName();
            await File.WriteAllBytesAsync(tempInput, bytes.ToArray(), cancellationToken);
            tempOutput = Path.GetTempFileName();

            // Run Ghidra headless. Paths are quoted because temp directories
            // may contain spaces.
            var script = _options.DecompileScriptPath ?? "DecompileFunction.java";
            var args = $"-import \"{tempInput}\" -postScript {script} \"{tempOutput}\" -deleteProject -noanalysis";
            var result = await RunGhidraAsync(args, options.Timeout, cancellationToken);
            if (!result.Success)
            {
                _logger.LogWarning("Ghidra decompilation failed: {Error}", result.Error);
                return null;
            }

            if (File.Exists(tempOutput))
            {
                var decompiled = await File.ReadAllTextAsync(tempOutput, cancellationToken);
                return options.Simplify ? Normalize(decompiled) : decompiled;
            }
            return null;
        }
        catch (OperationCanceledException)
        {
            // Caller-requested cancellation is not a decompilation failure.
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Decompilation failed");
            return null;
        }
        finally
        {
            if (tempInput is not null && File.Exists(tempInput)) File.Delete(tempInput);
            if (tempOutput is not null && File.Exists(tempOutput)) File.Delete(tempOutput);
        }
    }

    /// <inheritdoc />
    public string Normalize(string code, NormalizationOptions? options = null)
    {
        options ??= NormalizationOptions.Default;
        var result = code;

        // Strip block and line comments first so later passes see only code.
        if (options.StripComments)
        {
            result = StripCommentsRegex().Replace(result, "");
            result = LineCommentRegex().Replace(result, "");
        }

        // Collapse runs of spaces/tabs within lines, then collapse blank
        // lines. Newlines are deliberately preserved by the first pass so the
        // blank-line pass has something to work on and statement structure
        // survives.
        if (options.NormalizeWhitespace)
        {
            result = MultipleSpacesRegex().Replace(result, " ");
            result = EmptyLinesRegex().Replace(result, "\n");
            result = result.Trim();
        }

        // Rename Ghidra's auto-generated identifiers (local_*, param_*, DAT_*,
        // FUN_*) to stable var_N names, consistent within one input.
        if (options.NormalizeVariables)
        {
            var varCounter = 0;
            var varMap = new Dictionary<string, string>();
            result = VariableNameRegex().Replace(result, match =>
            {
                var name = match.Value;
                if (!varMap.TryGetValue(name, out var normalized))
                {
                    normalized = $"var_{varCounter++}";
                    varMap[name] = normalized;
                }
                return normalized;
            });
        }

        // Optionally drop C-style casts like "(int *)".
        if (options.RemoveTypeCasts)
        {
            result = TypeCastRegex().Replace(result, "");
        }

        // Truncate with a visible marker so consumers can detect clipping.
        if (result.Length > options.MaxLength)
        {
            result = result[..options.MaxLength] + "\n/* truncated */";
        }
        return result;
    }

    /// <summary>
    /// Runs the Ghidra headless analyzer with the given arguments, enforcing
    /// <paramref name="timeout"/>. Caller cancellation via <paramref name="ct"/>
    /// is rethrown; a timeout is reported as a failed result instead.
    /// </summary>
    private async Task<(bool Success, string? Error)> RunGhidraAsync(
        string args,
        TimeSpan timeout,
        CancellationToken ct)
    {
        var analyzeHeadless = Path.Combine(_options.GhidraPath!, "support", "analyzeHeadless");
        var psi = new ProcessStartInfo
        {
            FileName = analyzeHeadless,
            Arguments = args,
            RedirectStandardOutput = true,
            RedirectStandardError = true,
            UseShellExecute = false,
            CreateNoWindow = true
        };

        using var process = new Process { StartInfo = psi };
        var output = new StringBuilder();
        var error = new StringBuilder();
        // Both streams must be drained, otherwise a chatty Ghidra run can fill
        // the pipe buffers and deadlock.
        process.OutputDataReceived += (_, e) =>
        {
            if (e.Data is not null) output.AppendLine(e.Data);
        };
        process.ErrorDataReceived += (_, e) =>
        {
            if (e.Data is not null) error.AppendLine(e.Data);
        };

        process.Start();
        process.BeginOutputReadLine();
        process.BeginErrorReadLine();

        // Linked source: fires on either caller cancellation or timeout.
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(timeout);
        try
        {
            await process.WaitForExitAsync(cts.Token);
            return (process.ExitCode == 0, error.Length > 0 ? error.ToString() : null);
        }
        catch (OperationCanceledException)
        {
            process.Kill(entireProcessTree: true);
            if (ct.IsCancellationRequested)
            {
                // Propagate caller cancellation; only report "Timeout" when
                // the deadline (not the caller) cancelled the wait.
                throw;
            }
            return (false, "Timeout");
        }
    }

    [GeneratedRegex(@"/\*.*?\*/", RegexOptions.Singleline)]
    private static partial Regex StripCommentsRegex();

    [GeneratedRegex(@"//.*$", RegexOptions.Multiline)]
    private static partial Regex LineCommentRegex();

    // Spaces/tabs only - newlines are handled by EmptyLinesRegex. Using \s+
    // here would flatten the whole output onto one line.
    [GeneratedRegex(@"[ \t]+")]
    private static partial Regex MultipleSpacesRegex();

    [GeneratedRegex(@"\n\s*\n")]
    private static partial Regex EmptyLinesRegex();

    [GeneratedRegex(@"\b(local_|param_|DAT_|FUN_)[a-zA-Z0-9_]+")]
    private static partial Regex VariableNameRegex();

    [GeneratedRegex(@"\(\s*[a-zA-Z_][a-zA-Z0-9_]*\s*\*?\s*\)")]
    private static partial Regex TypeCastRegex();
}
/// <summary>
/// Options for the Ghidra adapter (consumed by <see cref="GhidraDecompilerAdapter"/>).
/// </summary>
public sealed record GhidraAdapterOptions
{
    /// <summary>
    /// Gets the path to the Ghidra installation root. When null or empty,
    /// byte-level decompilation is disabled and the adapter returns null.
    /// </summary>
    public string? GhidraPath { get; init; }

    /// <summary>
    /// Gets the path to the decompile post-script. When null, the adapter
    /// falls back to "DecompileFunction.java".
    /// </summary>
    public string? DecompileScriptPath { get; init; }

    /// <summary>
    /// Gets the project directory for temporary Ghidra projects.
    /// NOTE(review): not currently read by <see cref="GhidraDecompilerAdapter"/> — confirm intended use.
    /// </summary>
    public string? ProjectDirectory { get; init; }
}

View File

@@ -0,0 +1,355 @@
// -----------------------------------------------------------------------------
// GroundTruthCorpusBuilder.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-002 - Corpus Builder from Ground-Truth
// Description: Implementation of corpus builder using ground-truth data.
// -----------------------------------------------------------------------------
using System.Text.Json;
using Microsoft.Extensions.Logging;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Builds training corpus from ground-truth security pairs. Positive pairs
/// come from the same function before/after a security fix; negative pairs
/// are randomly sampled from distinct functions seen while loading positives.
/// </summary>
public sealed class GroundTruthCorpusBuilder : ICorpusBuilder
{
    private readonly IIrTokenizer _tokenizer;
    private readonly IDecompilerAdapter _decompiler;
    private readonly ILogger<GroundTruthCorpusBuilder> _logger;
    private readonly List<TrainingFunctionPair> _positivePairs = [];
    private readonly List<TrainingFunctionPair> _negativePairs = [];

    // Functions seen while extracting positives, keyed "library:version:name";
    // this is the sampling pool for negative pair generation.
    private readonly Dictionary<string, FunctionRepresentation> _functionCache = [];
    private readonly Random _random;

    private static readonly JsonSerializerOptions JsonOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        WriteIndented = false
    };

    // Cached rather than allocated per export (CA1869): JsonSerializerOptions
    // builds and caches serialization metadata internally.
    private static readonly JsonSerializerOptions JsonIndentedOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        WriteIndented = true
    };

    /// <summary>
    /// Initializes a new instance of the <see cref="GroundTruthCorpusBuilder"/> class.
    /// </summary>
    /// <param name="tokenizer">Tokenizer used to build IR token sequences.</param>
    /// <param name="decompiler">Decompiler used to obtain C-like code.</param>
    /// <param name="logger">Logger for diagnostic output.</param>
    /// <param name="randomSeed">Optional seed for reproducible shuffling and negative sampling.</param>
    public GroundTruthCorpusBuilder(
        IIrTokenizer tokenizer,
        IDecompilerAdapter decompiler,
        ILogger<GroundTruthCorpusBuilder> logger,
        int? randomSeed = null)
    {
        _tokenizer = tokenizer;
        _decompiler = decompiler;
        _logger = logger;
        _random = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random();
    }

    /// <inheritdoc />
    public async Task<TrainingCorpus> BuildCorpusAsync(
        CorpusBuildOptions options,
        CancellationToken cancellationToken = default)
    {
        _logger.LogInformation("Building training corpus with target {Positive} positive, {Negative} negative pairs",
            options.TargetPositivePairs, options.TargetNegativePairs);

        // Load positive pairs from every configured security-pair file.
        if (options.SecurityPairPaths is { Count: > 0 })
        {
            foreach (var path in options.SecurityPairPaths)
            {
                await AddSecurityPairsAsync(path, cancellationToken);
            }
        }

        // Top up negatives so the corpus approaches the configured target.
        var neededNegatives = options.TargetNegativePairs - _negativePairs.Count;
        if (neededNegatives > 0)
        {
            await GenerateNegativePairsAsync(neededNegatives, cancellationToken);
        }

        // Combine and shuffle before splitting so every split is label-mixed.
        var allPairs = _positivePairs.Concat(_negativePairs).ToList();
        Shuffle(allPairs);

        // Split into train/validation/test by ratio; test gets the remainder.
        var splitConfig = options.SplitConfig;
        var trainCount = (int)(allPairs.Count * splitConfig.TrainRatio);
        var valCount = (int)(allPairs.Count * splitConfig.ValidationRatio);
        var trainPairs = allPairs.Take(trainCount).ToList();
        var valPairs = allPairs.Skip(trainCount).Take(valCount).ToList();
        var testPairs = allPairs.Skip(trainCount + valCount).ToList();

        _logger.LogInformation(
            "Corpus built: {Train} train, {Val} validation, {Test} test pairs",
            trainPairs.Count, valPairs.Count, testPairs.Count);

        return new TrainingCorpus
        {
            Version = "1.0",
            CreatedAt = DateTimeOffset.UtcNow,
            Description = "Ground-truth security pairs corpus",
            TrainingPairs = trainPairs,
            ValidationPairs = valPairs,
            TestPairs = testPairs,
            Statistics = GetStatistics()
        };
    }

    /// <inheritdoc />
    public async Task<int> AddSecurityPairsAsync(
        string securityPairPath,
        CancellationToken cancellationToken = default)
    {
        if (!File.Exists(securityPairPath))
        {
            _logger.LogWarning("Security pair file not found: {Path}", securityPairPath);
            return 0;
        }

        var added = 0;
        // Input is JSON Lines: one SecurityPairData object per line.
        await foreach (var line in File.ReadLinesAsync(securityPairPath, cancellationToken))
        {
            if (string.IsNullOrWhiteSpace(line)) continue;
            try
            {
                var pairData = JsonSerializer.Deserialize<SecurityPairData>(line, JsonOptions);
                if (pairData is null) continue;

                // One positive pair per affected function in the record.
                var pairs = await ExtractFunctionPairsAsync(pairData, cancellationToken);
                _positivePairs.AddRange(pairs);
                added += pairs.Count;
            }
            catch (JsonException ex)
            {
                // A malformed line is skipped rather than failing the whole file.
                _logger.LogWarning(ex, "Failed to parse security pair line");
            }
        }

        _logger.LogDebug("Added {Count} pairs from {Path}", added, securityPairPath);
        return added;
    }

    /// <inheritdoc />
    public Task<int> GenerateNegativePairsAsync(
        int count,
        CancellationToken cancellationToken = default)
    {
        // Synchronous method returning Task to satisfy the interface; no
        // 'async' keyword so the compiler does not flag CS1998 (async without await).
        var functions = _functionCache.Values.ToList();
        if (functions.Count < 2)
        {
            _logger.LogWarning("Not enough functions in cache to generate negative pairs");
            return Task.FromResult(0);
        }

        var generated = 0;
        // 'count' bounds the number of attempts; skipped candidates mean the
        // actual yield can be lower than requested.
        for (var i = 0; i < count && !cancellationToken.IsCancellationRequested; i++)
        {
            // Pick two random, distinct indices.
            var idx1 = _random.Next(functions.Count);
            var idx2 = _random.Next(functions.Count);
            if (idx1 == idx2) idx2 = (idx2 + 1) % functions.Count;
            var func1 = functions[idx1];
            var func2 = functions[idx2];

            // The same function from two versions of the same library is a
            // positive-style pair, not a valid "different" example - skip it.
            if (func1.FunctionName == func2.FunctionName &&
                func1.LibraryName == func2.LibraryName)
            {
                continue;
            }

            _negativePairs.Add(new TrainingFunctionPair
            {
                PairId = $"neg_{Guid.NewGuid():N}",
                Function1 = func1,
                Function2 = func2,
                Label = EquivalenceLabel.Different,
                Confidence = 1.0,
                Source = "generated:negative_sampling"
            });
            generated++;
        }

        _logger.LogDebug("Generated {Count} negative pairs", generated);
        return Task.FromResult(generated);
    }

    /// <inheritdoc />
    public async Task ExportAsync(
        string outputPath,
        CorpusExportFormat format = CorpusExportFormat.JsonLines,
        CancellationToken cancellationToken = default)
    {
        var allPairs = _positivePairs.Concat(_negativePairs);
        var directory = Path.GetDirectoryName(outputPath);
        if (!string.IsNullOrEmpty(directory))
        {
            Directory.CreateDirectory(directory);
        }

        switch (format)
        {
            case CorpusExportFormat.JsonLines:
                await using (var writer = new StreamWriter(outputPath))
                {
                    foreach (var pair in allPairs)
                    {
                        cancellationToken.ThrowIfCancellationRequested();
                        var json = JsonSerializer.Serialize(pair, JsonOptions);
                        // Flow the token so a large export can be cancelled mid-write.
                        await writer.WriteLineAsync(json.AsMemory(), cancellationToken);
                    }
                }
                break;
            case CorpusExportFormat.Json:
                var corpus = new TrainingCorpus
                {
                    Version = "1.0",
                    CreatedAt = DateTimeOffset.UtcNow,
                    TrainingPairs = allPairs.ToList(),
                    Statistics = GetStatistics()
                };
                var corpusJson = JsonSerializer.Serialize(corpus, JsonIndentedOptions);
                await File.WriteAllTextAsync(outputPath, corpusJson, cancellationToken);
                break;
            default:
                throw new NotSupportedException($"Export format {format} not yet supported");
        }

        _logger.LogInformation("Exported corpus to {Path}", outputPath);
    }

    /// <inheritdoc />
    public CorpusStatistics GetStatistics()
    {
        var allPairs = _positivePairs.Concat(_negativePairs).ToList();
        var allFunctions = allPairs
            .SelectMany(p => new[] { p.Function1, p.Function2 })
            .ToList();

        return new CorpusStatistics
        {
            TotalPairs = allPairs.Count,
            EquivalentPairs = allPairs.Count(p => p.Label == EquivalenceLabel.Equivalent),
            DifferentPairs = allPairs.Count(p => p.Label == EquivalenceLabel.Different),
            UnknownPairs = allPairs.Count(p => p.Label == EquivalenceLabel.Unknown),
            UniqueLibraries = allFunctions.Select(f => f.LibraryName).Distinct().Count(),
            UniqueFunctions = allFunctions.Select(f => f.FunctionName).Distinct().Count(),
            Architectures = allFunctions.Select(f => f.Architecture).Distinct().ToList()
        };
    }

    /// <summary>
    /// Creates one positive pair per affected function: the function's
    /// representation in the version before the fix paired with the version
    /// after. Both representations are cached for negative sampling.
    /// </summary>
    private async Task<List<TrainingFunctionPair>> ExtractFunctionPairsAsync(
        SecurityPairData pairData,
        CancellationToken ct)
    {
        var pairs = new List<TrainingFunctionPair>();
        foreach (var funcName in pairData.AffectedFunctions ?? [])
        {
            var func1 = await GetFunctionRepresentationAsync(
                pairData.LibraryName,
                pairData.VersionBefore,
                funcName,
                pairData.Architecture ?? "x86_64",
                ct);
            var func2 = await GetFunctionRepresentationAsync(
                pairData.LibraryName,
                pairData.VersionAfter,
                funcName,
                pairData.Architecture ?? "x86_64",
                ct);

            if (func1 is not null && func2 is not null)
            {
                pairs.Add(new TrainingFunctionPair
                {
                    PairId = $"pos_{pairData.CveId}_{funcName}_{Guid.NewGuid():N}",
                    Function1 = func1,
                    Function2 = func2,
                    Label = EquivalenceLabel.Equivalent,
                    Confidence = 1.0,
                    Source = $"groundtruth:security_pair:{pairData.CveId}",
                    Metadata = new TrainingPairMetadata
                    {
                        CveId = pairData.CveId,
                        IsPatched = true,
                        Distribution = pairData.Distribution
                    }
                });

                // Cache functions for negative pair generation.
                _functionCache[$"{func1.LibraryName}:{func1.LibraryVersion}:{func1.FunctionName}"] = func1;
                _functionCache[$"{func2.LibraryName}:{func2.LibraryVersion}:{func2.FunctionName}"] = func2;
            }
        }
        return pairs;
    }

    /// <summary>
    /// Builds a function representation from IR tokens plus (possibly null)
    /// decompiled code via the injected tokenizer and decompiler.
    /// </summary>
    private async Task<FunctionRepresentation?> GetFunctionRepresentationAsync(
        string libraryName,
        string version,
        string functionName,
        string architecture,
        CancellationToken ct)
    {
        var irTokens = await _tokenizer.TokenizeAsync(libraryName, version, functionName, ct);
        var decompiled = await _decompiler.DecompileAsync(libraryName, version, functionName, ct);

        return new FunctionRepresentation
        {
            LibraryName = libraryName,
            LibraryVersion = version,
            FunctionName = functionName,
            Architecture = architecture,
            IrTokens = irTokens,
            DecompiledCode = decompiled
        };
    }

    /// <summary>
    /// In-place Fisher-Yates shuffle driven by the seeded <see cref="_random"/>.
    /// </summary>
    private void Shuffle<T>(List<T> list)
    {
        var n = list.Count;
        while (n > 1)
        {
            n--;
            var k = _random.Next(n + 1);
            (list[k], list[n]) = (list[n], list[k]);
        }
    }
}
/// <summary>
/// Security pair data from ground-truth: one library whose functions changed
/// between two versions around a security fix. Deserialized from JSON Lines
/// input (camelCase property names).
/// </summary>
internal sealed record SecurityPairData
{
    /// <summary>CVE identifier the pair was derived from, if known.</summary>
    public string? CveId { get; init; }

    /// <summary>Library the affected functions belong to.</summary>
    public string LibraryName { get; init; } = "";

    /// <summary>Library version before the security fix.</summary>
    public string VersionBefore { get; init; } = "";

    /// <summary>Library version after the security fix.</summary>
    public string VersionAfter { get; init; } = "";

    /// <summary>Names of functions changed by the fix; null is treated as empty.</summary>
    public IReadOnlyList<string>? AffectedFunctions { get; init; }

    /// <summary>Target architecture; callers default to "x86_64" when null.</summary>
    public string? Architecture { get; init; }

    /// <summary>Source distribution of the binaries, if any.</summary>
    public string? Distribution { get; init; }
}

View File

@@ -0,0 +1,147 @@
// -----------------------------------------------------------------------------
// ICorpusBuilder.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-002 - Corpus Builder from Ground-Truth
// Description: Interface for building training corpus from ground-truth data.
// -----------------------------------------------------------------------------
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Builder for ML training corpus from ground-truth data. Implementations are
/// expected to be stateful: pairs accumulate across
/// <see cref="AddSecurityPairsAsync"/> / <see cref="GenerateNegativePairsAsync"/>
/// calls and are consumed by <see cref="BuildCorpusAsync"/> and <see cref="ExportAsync"/>.
/// </summary>
public interface ICorpusBuilder
{
    /// <summary>
    /// Builds a training corpus from security pairs.
    /// </summary>
    /// <param name="options">Build options (sources, pair targets, split ratios).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>The built corpus, split into train/validation/test sets.</returns>
    Task<TrainingCorpus> BuildCorpusAsync(
        CorpusBuildOptions options,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Adds pairs from a security pair source.
    /// </summary>
    /// <param name="securityPairPath">Path to security pair data (typically a JSON Lines file).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Number of pairs added.</returns>
    Task<int> AddSecurityPairsAsync(
        string securityPairPath,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Generates negative pairs from existing functions.
    /// </summary>
    /// <param name="count">Requested number of negative pairs (an upper bound).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Number of pairs actually generated (may be fewer than requested).</returns>
    Task<int> GenerateNegativePairsAsync(
        int count,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Exports the accumulated corpus to a file.
    /// </summary>
    /// <param name="outputPath">Output file path.</param>
    /// <param name="format">Export format; defaults to JSON Lines.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    Task ExportAsync(
        string outputPath,
        CorpusExportFormat format = CorpusExportFormat.JsonLines,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Gets statistics over the pairs accumulated so far.
    /// </summary>
    CorpusStatistics GetStatistics();
}
/// <summary>
/// Options for corpus building.
/// </summary>
public sealed record CorpusBuildOptions
{
    /// <summary>
    /// Gets paths to security pair data files; null or empty means no
    /// positive-pair sources are loaded.
    /// </summary>
    public IReadOnlyList<string>? SecurityPairPaths { get; init; }

    /// <summary>
    /// Gets the target number of positive pairs. Default: 15000.
    /// </summary>
    public int TargetPositivePairs { get; init; } = 15000;

    /// <summary>
    /// Gets the target number of negative pairs. Default: 15000.
    /// </summary>
    public int TargetNegativePairs { get; init; } = 15000;

    /// <summary>
    /// Gets the train/validation/test split configuration.
    /// </summary>
    public CorpusSplitConfig SplitConfig { get; init; } = new();

    /// <summary>
    /// Gets whether to include IR tokens. Default: true.
    /// </summary>
    public bool IncludeIrTokens { get; init; } = true;

    /// <summary>
    /// Gets whether to include decompiled code. Default: true.
    /// </summary>
    public bool IncludeDecompiledCode { get; init; } = true;

    /// <summary>
    /// Gets whether to include fingerprints. Default: true.
    /// </summary>
    public bool IncludeFingerprints { get; init; } = true;

    /// <summary>
    /// Gets the maximum IR token sequence length. Default: 512.
    /// </summary>
    public int MaxIrTokenLength { get; init; } = 512;

    /// <summary>
    /// Gets the maximum decompiled code length in characters. Default: 2048.
    /// </summary>
    public int MaxDecompiledLength { get; init; } = 2048;

    /// <summary>
    /// Gets libraries to include (null = all).
    /// </summary>
    public IReadOnlyList<string>? IncludeLibraries { get; init; }

    /// <summary>
    /// Gets architectures to include (null = all).
    /// </summary>
    public IReadOnlyList<string>? IncludeArchitectures { get; init; }
}
/// <summary>
/// Export format for corpus.
/// </summary>
public enum CorpusExportFormat
{
    /// <summary>
    /// JSON Lines format (one pair per line); the default export format.
    /// </summary>
    JsonLines,

    /// <summary>
    /// Single JSON file containing the whole corpus as one document.
    /// </summary>
    Json,

    /// <summary>
    /// Parquet format for large datasets.
    /// NOTE(review): not all builders support this yet — confirm before use.
    /// </summary>
    Parquet,

    /// <summary>
    /// HuggingFace datasets format.
    /// NOTE(review): not all builders support this yet — confirm before use.
    /// </summary>
    HuggingFace
}

View File

@@ -0,0 +1,133 @@
// -----------------------------------------------------------------------------
// IDecompilerAdapter.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-004 - Decompiled Code Extraction
// Description: Interface for decompiler integration.
// -----------------------------------------------------------------------------
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Adapter for decompiler integration.
/// </summary>
public interface IDecompilerAdapter
{
    /// <summary>
    /// Decompiles a function to C-like code.
    /// </summary>
    /// <param name="libraryName">Library name.</param>
    /// <param name="version">Library version.</param>
    /// <param name="functionName">Function name.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Decompiled code, or <c>null</c> when the function cannot be decompiled.</returns>
    Task<string?> DecompileAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Decompiles raw function bytes to C-like code.
    /// </summary>
    /// <param name="bytes">Function bytes.</param>
    /// <param name="architecture">Target architecture.</param>
    /// <param name="options">Decompilation options; null uses <see cref="DecompilationOptions.Default"/>.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Decompiled code, or <c>null</c> when decompilation fails or is unavailable.</returns>
    Task<string?> DecompileBytesAsync(
        ReadOnlyMemory<byte> bytes,
        string architecture,
        DecompilationOptions? options = null,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Normalizes decompiled code for ML input (comment stripping, whitespace
    /// and variable-name normalization, truncation).
    /// </summary>
    /// <param name="code">Raw decompiled code.</param>
    /// <param name="options">Normalization options; null uses <see cref="NormalizationOptions.Default"/>.</param>
    /// <returns>Normalized code.</returns>
    string Normalize(string code, NormalizationOptions? options = null);
}
/// <summary>
/// Options for decompilation.
/// </summary>
public sealed record DecompilationOptions
{
    /// <summary>
    /// Gets the decompiler to use. Default: <see cref="DecompilerType.Ghidra"/>.
    /// </summary>
    public DecompilerType Decompiler { get; init; } = DecompilerType.Ghidra;

    /// <summary>
    /// Gets whether to simplify (normalize) the output. Default: true.
    /// </summary>
    public bool Simplify { get; init; } = true;

    /// <summary>
    /// Gets the timeout for a single decompilation run. Default: 30 seconds.
    /// </summary>
    public TimeSpan Timeout { get; init; } = TimeSpan.FromSeconds(30);

    /// <summary>
    /// Gets a shared instance with all defaults applied.
    /// </summary>
    public static DecompilationOptions Default { get; } = new();
}
/// <summary>
/// Available decompilers.
/// </summary>
public enum DecompilerType
{
    /// <summary>
    /// Ghidra decompiler (headless analyzer); the default.
    /// </summary>
    Ghidra,

    /// <summary>
    /// RetDec decompiler.
    /// </summary>
    RetDec,

    /// <summary>
    /// Hex-Rays decompiler (IDA Pro).
    /// </summary>
    HexRays
}
/// <summary>
/// Options for code normalization.
/// </summary>
public sealed record NormalizationOptions
{
    /// <summary>
    /// Gets whether to strip block and line comments. Default: true.
    /// </summary>
    public bool StripComments { get; init; } = true;

    /// <summary>
    /// Gets whether to normalize variable names (e.g. decompiler-generated
    /// identifiers renamed to stable placeholders). Default: true.
    /// </summary>
    public bool NormalizeVariables { get; init; } = true;

    /// <summary>
    /// Gets whether to normalize whitespace. Default: true.
    /// </summary>
    public bool NormalizeWhitespace { get; init; } = true;

    /// <summary>
    /// Gets whether to remove type casts. Default: false.
    /// </summary>
    public bool RemoveTypeCasts { get; init; } = false;

    /// <summary>
    /// Gets the maximum output length in characters. Default: 2048.
    /// </summary>
    public int MaxLength { get; init; } = 2048;

    /// <summary>
    /// Gets a shared instance with all defaults applied.
    /// </summary>
    public static NormalizationOptions Default { get; } = new();
}

View File

@@ -0,0 +1,123 @@
// -----------------------------------------------------------------------------
// IFunctionEmbeddingService.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-006 - Embedding Inference Service
// Description: Interface for function embedding inference.
// -----------------------------------------------------------------------------
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Service for computing function embeddings.
/// </summary>
public interface IFunctionEmbeddingService
{
    /// <summary>
    /// Computes an embedding for a function representation.
    /// </summary>
    /// <param name="function">Function representation.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Embedding vector (length matches <see cref="EmbeddingModelInfo.Dimension"/>).</returns>
    Task<float[]> GetEmbeddingAsync(
        FunctionRepresentation function,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Computes embeddings for multiple functions (batched for throughput).
    /// </summary>
    /// <param name="functions">Function representations.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Embedding vectors, one per input function, in input order.</returns>
    Task<IReadOnlyList<float[]>> GetEmbeddingsBatchAsync(
        IReadOnlyList<FunctionRepresentation> functions,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Computes similarity between two embeddings.
    /// </summary>
    /// <param name="embedding1">First embedding.</param>
    /// <param name="embedding2">Second embedding.</param>
    /// <returns>Similarity score (0.0 to 1.0; higher means more similar).</returns>
    float ComputeSimilarity(float[] embedding1, float[] embedding2);

    /// <summary>
    /// Finds similar functions by embedding.
    /// </summary>
    /// <param name="queryEmbedding">Query embedding.</param>
    /// <param name="topK">Maximum number of results to return (default 10).</param>
    /// <param name="threshold">Minimum similarity threshold (default 0.7).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Similar functions with scores at or above the threshold.</returns>
    Task<IReadOnlyList<EmbeddingSimilarityResult>> FindSimilarAsync(
        float[] queryEmbedding,
        int topK = 10,
        float threshold = 0.7f,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Gets information about the loaded embedding model.
    /// </summary>
    EmbeddingModelInfo GetModelInfo();
}
/// <summary>
/// Result of similarity search.
/// </summary>
public sealed record EmbeddingSimilarityResult
{
    /// <summary>
    /// Gets the function ID (implementation-defined; may encode library/version/name).
    /// </summary>
    public required string FunctionId { get; init; }
    /// <summary>
    /// Gets the function name.
    /// </summary>
    public required string FunctionName { get; init; }
    /// <summary>
    /// Gets the library name, when known.
    /// </summary>
    public string? LibraryName { get; init; }
    /// <summary>
    /// Gets the library version, when known.
    /// </summary>
    public string? LibraryVersion { get; init; }
    /// <summary>
    /// Gets the similarity score (0.0 to 1.0; higher is more similar).
    /// </summary>
    public required float Similarity { get; init; }
}
/// <summary>
/// Information about the embedding model.
/// </summary>
public sealed record EmbeddingModelInfo
{
    /// <summary>
    /// Gets the model name.
    /// </summary>
    public required string Name { get; init; }
    /// <summary>
    /// Gets the model version.
    /// </summary>
    public required string Version { get; init; }
    /// <summary>
    /// Gets the embedding dimension (length of vectors produced by the model).
    /// </summary>
    public required int Dimension { get; init; }
    /// <summary>
    /// Gets the maximum input token sequence length.
    /// </summary>
    public int MaxSequenceLength { get; init; }
    /// <summary>
    /// Gets whether the underlying model has been loaded into memory.
    /// </summary>
    public bool IsLoaded { get; init; }
}

View File

@@ -0,0 +1,73 @@
// -----------------------------------------------------------------------------
// IIrTokenizer.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-003 - IR Token Extraction
// Description: Interface for IR tokenization for ML input.
// -----------------------------------------------------------------------------
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Tokenizes function IR for transformer input.
/// </summary>
public interface IIrTokenizer
{
    /// <summary>
    /// Tokenizes a function (looked up by library/version/name) into IR tokens.
    /// </summary>
    /// <param name="libraryName">Library name.</param>
    /// <param name="version">Library version.</param>
    /// <param name="functionName">Function name.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Ordered list of IR tokens.</returns>
    Task<IReadOnlyList<string>> TokenizeAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default);
    /// <summary>
    /// Tokenizes raw instruction bytes.
    /// </summary>
    /// <param name="instructions">Raw instruction bytes.</param>
    /// <param name="architecture">Target architecture identifier.</param>
    /// <param name="options">Tokenization options (null = implementation defaults).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Ordered list of IR tokens.</returns>
    Task<IReadOnlyList<string>> TokenizeInstructionsAsync(
        ReadOnlyMemory<byte> instructions,
        string architecture,
        TokenizationOptions? options = null,
        CancellationToken cancellationToken = default);
}
/// <summary>
/// Options for IR tokenization.
/// </summary>
public sealed record TokenizationOptions
{
    /// <summary>
    /// Gets the maximum token sequence length (default 512).
    /// </summary>
    public int MaxLength { get; init; } = 512;
    /// <summary>
    /// Gets whether to normalize variable names (default true).
    /// </summary>
    public bool NormalizeVariables { get; init; } = true;
    /// <summary>
    /// Gets whether to include operand types (default true).
    /// </summary>
    public bool IncludeOperandTypes { get; init; } = true;
    /// <summary>
    /// Gets whether to include control flow tokens (default true).
    /// </summary>
    public bool IncludeControlFlow { get; init; } = true;
    /// <summary>
    /// Gets a shared instance with all defaults applied.
    /// </summary>
    public static TokenizationOptions Default { get; } = new();
}

View File

@@ -0,0 +1,172 @@
// -----------------------------------------------------------------------------
// MlEmbeddingMatcherAdapter.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-007 - Ensemble Integration
// Description: Adapter for integrating ML embeddings into validation harness.
// -----------------------------------------------------------------------------
using Microsoft.Extensions.Logging;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Matcher adapter for ML embeddings integration with validation harness.
/// </summary>
public sealed class MlEmbeddingMatcherAdapter
{
    private readonly IFunctionEmbeddingService _embeddingService;
    private readonly ILogger<MlEmbeddingMatcherAdapter> _logger;
    /// <summary>
    /// Gets the default weight for this matcher in the ensemble.
    /// </summary>
    public const double DefaultWeight = 0.25; // 25% per architecture doc
    /// <summary>
    /// Initializes a new instance of the <see cref="MlEmbeddingMatcherAdapter"/> class.
    /// </summary>
    /// <param name="embeddingService">Embedding inference service.</param>
    /// <param name="logger">Logger.</param>
    public MlEmbeddingMatcherAdapter(
        IFunctionEmbeddingService embeddingService,
        ILogger<MlEmbeddingMatcherAdapter> logger)
    {
        _embeddingService = embeddingService;
        _logger = logger;
    }
    /// <summary>
    /// Computes match score between two functions using ML embeddings.
    /// </summary>
    /// <param name="function1">First function.</param>
    /// <param name="function2">Second function.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Match score (0.0 to 1.0); 0.0 when embedding computation fails.</returns>
    public async Task<double> ComputeMatchScoreAsync(
        FunctionRepresentation function1,
        FunctionRepresentation function2,
        CancellationToken cancellationToken = default)
    {
        try
        {
            var embedding1 = await _embeddingService.GetEmbeddingAsync(function1, cancellationToken);
            var embedding2 = await _embeddingService.GetEmbeddingAsync(function2, cancellationToken);
            var similarity = _embeddingService.ComputeSimilarity(embedding1, embedding2);
            _logger.LogDebug(
                "ML embedding match score for {Func1} vs {Func2}: {Score:F4}",
                function1.FunctionName,
                function2.FunctionName,
                similarity);
            return similarity;
        }
        catch (Exception ex)
        {
            // Best-effort: a failed embedding degrades to "no match" rather than
            // failing the whole ensemble evaluation.
            _logger.LogWarning(ex, "Failed to compute ML embedding score");
            return 0.0;
        }
    }
    /// <summary>
    /// Computes match scores for a batch of function pairs.
    /// </summary>
    /// <param name="pairs">Function pairs to compare.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Match scores for each pair, in input order.</returns>
    public async Task<IReadOnlyList<double>> ComputeMatchScoresBatchAsync(
        IReadOnlyList<(FunctionRepresentation Function1, FunctionRepresentation Function2)> pairs,
        CancellationToken cancellationToken = default)
    {
        // De-duplicate via record value equality so each distinct
        // representation is embedded exactly once.
        var allFunctions = pairs
            .SelectMany(p => new[] { p.Function1, p.Function2 })
            .Distinct()
            .ToList();
        // Get all embeddings in batch
        var embeddings = await _embeddingService.GetEmbeddingsBatchAsync(allFunctions, cancellationToken);
        // Key the lookup by the representation itself rather than by a
        // name/version string: two distinct representations (e.g. differing
        // only in IR tokens) could share the same string key, in which case
        // the later embedding silently overwrote the earlier one.
        var embeddingLookup = new Dictionary<FunctionRepresentation, float[]>(allFunctions.Count);
        for (var i = 0; i < allFunctions.Count; i++)
        {
            embeddingLookup[allFunctions[i]] = embeddings[i];
        }
        // Compute scores
        var scores = new List<double>(pairs.Count);
        foreach (var (func1, func2) in pairs)
        {
            if (embeddingLookup.TryGetValue(func1, out var emb1) &&
                embeddingLookup.TryGetValue(func2, out var emb2))
            {
                scores.Add(_embeddingService.ComputeSimilarity(emb1, emb2));
            }
            else
            {
                scores.Add(0.0);
            }
        }
        return scores;
    }
    /// <summary>
    /// Gets ensemble weight configuration.
    /// </summary>
    public EnsembleWeightConfig GetEnsembleConfig() => new()
    {
        InstructionHashWeight = 0.15,
        SemanticGraphWeight = 0.25,
        DecompiledAstWeight = 0.35,
        MlEmbeddingWeight = DefaultWeight // keep in sync with the declared default
    };
}
/// <summary>
/// Ensemble weight configuration.
/// </summary>
/// <remarks>
/// The four matcher weights are expected to sum to 1.0; call
/// <see cref="Validate"/> after construction to enforce this.
/// </remarks>
public sealed record EnsembleWeightConfig
{
    /// <summary>
    /// Gets the instruction hash matcher weight (default 0.15).
    /// </summary>
    public double InstructionHashWeight { get; init; } = 0.15;
    /// <summary>
    /// Gets the semantic graph matcher weight (default 0.25).
    /// </summary>
    public double SemanticGraphWeight { get; init; } = 0.25;
    /// <summary>
    /// Gets the decompiled AST matcher weight (default 0.35).
    /// </summary>
    public double DecompiledAstWeight { get; init; } = 0.35;
    /// <summary>
    /// Gets the ML embedding matcher weight (default 0.25).
    /// </summary>
    public double MlEmbeddingWeight { get; init; } = 0.25;
    /// <summary>
    /// Validates that weights sum to 1.0 (tolerance 0.001).
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the weights do not sum to 1.0.</exception>
    public void Validate()
    {
        var sum =
            InstructionHashWeight +
            SemanticGraphWeight +
            DecompiledAstWeight +
            MlEmbeddingWeight;
        if (Math.Abs(sum - 1.0) <= 0.001)
        {
            return;
        }
        throw new InvalidOperationException(
            $"Ensemble weights must sum to 1.0, but sum is {sum}");
    }
}

View File

@@ -0,0 +1,309 @@
// -----------------------------------------------------------------------------
// OnnxFunctionEmbeddingService.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-006 - Embedding Inference Service
// Description: ONNX-based function embedding service.
// -----------------------------------------------------------------------------
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// ONNX-based function embedding service.
/// </summary>
/// <remarks>
/// Model execution is currently a placeholder: <c>RunInferenceAsync</c>
/// produces a deterministic pseudo-random vector derived from the input token
/// IDs instead of running an ONNX session. <c>FindSimilarAsync</c> searches
/// only embeddings already present in the in-memory cache.
/// </remarks>
public sealed class OnnxFunctionEmbeddingService : IFunctionEmbeddingService, IDisposable
{
    private readonly OnnxEmbeddingServiceOptions _options;
    private readonly IIrTokenizer _tokenizer;
    private readonly ILogger<OnnxFunctionEmbeddingService> _logger;
    // Embedding cache keyed by "library:version:function"; guarded by _cacheLock.
    private readonly Dictionary<string, float[]> _embeddingCache = [];
    private readonly SemaphoreSlim _cacheLock = new(1, 1);
    private bool _modelLoaded;
    private bool _disposed;
    /// <summary>
    /// Initializes a new instance of the <see cref="OnnxFunctionEmbeddingService"/> class.
    /// </summary>
    public OnnxFunctionEmbeddingService(
        IOptions<OnnxEmbeddingServiceOptions> options,
        IIrTokenizer tokenizer,
        ILogger<OnnxFunctionEmbeddingService> logger)
    {
        _options = options.Value;
        _tokenizer = tokenizer;
        _logger = logger;
    }
    /// <inheritdoc />
    public async Task<float[]> GetEmbeddingAsync(
        FunctionRepresentation function,
        CancellationToken cancellationToken = default)
    {
        var cacheKey = GetCacheKey(function);
        // Check cache
        if (_options.EnableCache)
        {
            await _cacheLock.WaitAsync(cancellationToken);
            try
            {
                if (_embeddingCache.TryGetValue(cacheKey, out var cached))
                {
                    return cached;
                }
            }
            finally
            {
                _cacheLock.Release();
            }
        }
        // Ensure model is loaded
        await EnsureModelLoadedAsync(cancellationToken);
        // Prefer pre-computed IR tokens; otherwise tokenize on demand.
        // (The previous `as List<string>` cast silently discarded tokenizer
        // output whenever the implementation returned any other list type.)
        List<string> tokens;
        if (function.IrTokens is not null)
        {
            tokens = function.IrTokens.ToList();
        }
        else
        {
            var produced = await _tokenizer.TokenizeAsync(
                function.LibraryName,
                function.LibraryVersion,
                function.FunctionName,
                cancellationToken);
            tokens = produced.ToList();
        }
        // Pad or truncate to the model's fixed sequence length.
        var maxLen = _options.MaxSequenceLength;
        if (tokens.Count > maxLen)
        {
            tokens = tokens.Take(maxLen).ToList();
        }
        else
        {
            while (tokens.Count < maxLen)
            {
                tokens.Add("[PAD]");
            }
        }
        // Tokenize to IDs (simplified - would use actual vocabulary)
        var inputIds = tokens.Select(TokenToId).ToArray();
        // Run inference
        var embedding = await RunInferenceAsync(inputIds, cancellationToken);
        // Cache result
        if (_options.EnableCache)
        {
            await _cacheLock.WaitAsync(cancellationToken);
            try
            {
                _embeddingCache[cacheKey] = embedding;
                // Naive eviction: drop an arbitrary entry once over capacity.
                if (_embeddingCache.Count > _options.MaxCacheSize)
                {
                    var toRemove = _embeddingCache.Keys.First();
                    _embeddingCache.Remove(toRemove);
                }
            }
            finally
            {
                _cacheLock.Release();
            }
        }
        return embedding;
    }
    /// <inheritdoc />
    public async Task<IReadOnlyList<float[]>> GetEmbeddingsBatchAsync(
        IReadOnlyList<FunctionRepresentation> functions,
        CancellationToken cancellationToken = default)
    {
        var results = new List<float[]>(functions.Count);
        // Process in batches to bound concurrent inference requests.
        var batchSize = _options.BatchSize;
        for (var i = 0; i < functions.Count; i += batchSize)
        {
            var batch = functions.Skip(i).Take(batchSize);
            var batchResults = await Task.WhenAll(
                batch.Select(f => GetEmbeddingAsync(f, cancellationToken)));
            results.AddRange(batchResults);
        }
        return results;
    }
    /// <inheritdoc />
    public float ComputeSimilarity(float[] embedding1, float[] embedding2)
    {
        if (embedding1.Length != embedding2.Length)
        {
            throw new ArgumentException("Embeddings must have same dimension");
        }
        // Cosine similarity
        var dot = Dot(embedding1, embedding2);
        var norm1 = MathF.Sqrt(Dot(embedding1, embedding1));
        var norm2 = MathF.Sqrt(Dot(embedding2, embedding2));
        if (norm1 == 0 || norm2 == 0) return 0;
        return dot / (norm1 * norm2);
    }
    private static float Dot(float[] a, float[] b)
    {
        float sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            sum += a[i] * b[i];
        }
        return sum;
    }
    /// <inheritdoc />
    public async Task<IReadOnlyList<EmbeddingSimilarityResult>> FindSimilarAsync(
        float[] queryEmbedding,
        int topK = 10,
        float threshold = 0.7f,
        CancellationToken cancellationToken = default)
    {
        // Linear scan over cached embeddings only; functions never embedded
        // (or evicted) are invisible to this search.
        var results = new List<EmbeddingSimilarityResult>();
        await _cacheLock.WaitAsync(cancellationToken);
        try
        {
            foreach (var (key, embedding) in _embeddingCache)
            {
                var similarity = ComputeSimilarity(queryEmbedding, embedding);
                if (similarity >= threshold)
                {
                    // Cache keys have the form "library:version:function".
                    var parts = key.Split(':');
                    results.Add(new EmbeddingSimilarityResult
                    {
                        FunctionId = key,
                        FunctionName = parts.Length > 2 ? parts[2] : key,
                        LibraryName = parts.Length > 0 ? parts[0] : null,
                        LibraryVersion = parts.Length > 1 ? parts[1] : null,
                        Similarity = similarity
                    });
                }
            }
        }
        finally
        {
            _cacheLock.Release();
        }
        return results
            .OrderByDescending(r => r.Similarity)
            .Take(topK)
            .ToList();
    }
    /// <inheritdoc />
    public EmbeddingModelInfo GetModelInfo()
    {
        return new EmbeddingModelInfo
        {
            Name = _options.ModelName,
            Version = _options.ModelVersion,
            Dimension = _options.EmbeddingDimension,
            MaxSequenceLength = _options.MaxSequenceLength,
            IsLoaded = _modelLoaded
        };
    }
    private Task EnsureModelLoadedAsync(CancellationToken ct)
    {
        if (_modelLoaded) return Task.CompletedTask;
        if (string.IsNullOrEmpty(_options.ModelPath))
        {
            _logger.LogWarning("ONNX model path not configured, using placeholder embeddings");
            return Task.CompletedTask;
        }
        _logger.LogInformation("Loading ONNX model from {Path}", _options.ModelPath);
        // Model loading would happen here - for now mark as loaded
        _modelLoaded = true;
        return Task.CompletedTask;
    }
    private Task<float[]> RunInferenceAsync(long[] inputIds, CancellationToken ct)
    {
        // Placeholder inference: seed the RNG from the token ID *contents*.
        // (Array.GetHashCode is identity-based, so the previous code produced
        // a different embedding for identical inputs on every call.)
        var seed = 17;
        foreach (var id in inputIds)
        {
            unchecked
            {
                seed = (seed * 31) + (int)(id ^ (id >> 32));
            }
        }
        var rng = new Random(seed);
        var embedding = new float[_options.EmbeddingDimension];
        for (var i = 0; i < embedding.Length; i++)
        {
            embedding[i] = (float)((rng.NextDouble() * 2) - 1);
        }
        return Task.FromResult(embedding);
    }
    private static long TokenToId(string token)
    {
        // FNV-1a over UTF-16 code units. string.GetHashCode is randomized per
        // process on modern .NET, which would make token IDs (and therefore
        // the placeholder embeddings) unstable across runs.
        unchecked
        {
            var hash = 2166136261u;
            foreach (var ch in token)
            {
                hash = (hash ^ ch) * 16777619u;
            }
            return hash & 0x7FFFFFFF;
        }
    }
    private static string GetCacheKey(FunctionRepresentation function)
    {
        return $"{function.LibraryName}:{function.LibraryVersion}:{function.FunctionName}";
    }
    /// <inheritdoc />
    public void Dispose()
    {
        if (_disposed) return;
        _disposed = true;
        _cacheLock.Dispose();
    }
}
/// <summary>
/// Options for ONNX embedding service.
/// </summary>
public sealed record OnnxEmbeddingServiceOptions
{
    /// <summary>
    /// Gets the path to the ONNX model file; when null or empty the service
    /// logs a warning and uses placeholder embeddings.
    /// </summary>
    public string? ModelPath { get; init; }
    /// <summary>
    /// Gets the model name (default "function-embeddings").
    /// </summary>
    public string ModelName { get; init; } = "function-embeddings";
    /// <summary>
    /// Gets the model version (default "1.0").
    /// </summary>
    public string ModelVersion { get; init; } = "1.0";
    /// <summary>
    /// Gets the embedding dimension (default 768).
    /// </summary>
    public int EmbeddingDimension { get; init; } = 768;
    /// <summary>
    /// Gets the maximum token sequence length; inputs are padded/truncated to
    /// this length (default 512).
    /// </summary>
    public int MaxSequenceLength { get; init; } = 512;
    /// <summary>
    /// Gets the batch size for inference (default 16).
    /// </summary>
    public int BatchSize { get; init; } = 16;
    /// <summary>
    /// Gets whether to enable the in-memory embedding cache (default true).
    /// </summary>
    public bool EnableCache { get; init; } = true;
    /// <summary>
    /// Gets the maximum number of cached embeddings (default 10000).
    /// </summary>
    public int MaxCacheSize { get; init; } = 10000;
}

View File

@@ -0,0 +1,299 @@
// -----------------------------------------------------------------------------
// TrainingCorpusModels.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-001 - Training Corpus Schema
// Description: Schema definitions for ML training corpus.
// -----------------------------------------------------------------------------
using System.Text.Json.Serialization;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// A labeled function pair for ML training.
/// </summary>
public sealed record TrainingFunctionPair
{
    /// <summary>
    /// Gets the unique pair identifier.
    /// </summary>
    public required string PairId { get; init; }
    /// <summary>
    /// Gets the first function of the pair.
    /// </summary>
    public required FunctionRepresentation Function1 { get; init; }
    /// <summary>
    /// Gets the second function of the pair.
    /// </summary>
    public required FunctionRepresentation Function2 { get; init; }
    /// <summary>
    /// Gets the equivalence label (equivalent / different / unknown).
    /// </summary>
    public required EquivalenceLabel Label { get; init; }
    /// <summary>
    /// Gets the confidence in the label (0.0 to 1.0, default 1.0).
    /// </summary>
    public double Confidence { get; init; } = 1.0;
    /// <summary>
    /// Gets the source of the ground-truth label.
    /// </summary>
    public required string Source { get; init; }
    /// <summary>
    /// Gets optional metadata about the pair (CVE, patch info, tags).
    /// </summary>
    public TrainingPairMetadata? Metadata { get; init; }
}
/// <summary>
/// Representation of a function for training.
/// </summary>
/// <remarks>
/// Records compare property-by-property; note that collection-typed properties
/// such as <see cref="IrTokens"/> compare by reference, so two instances with
/// equal scalars but different token list instances are not equal.
/// </remarks>
public sealed record FunctionRepresentation
{
    /// <summary>
    /// Gets the library name.
    /// </summary>
    public required string LibraryName { get; init; }
    /// <summary>
    /// Gets the library version.
    /// </summary>
    public required string LibraryVersion { get; init; }
    /// <summary>
    /// Gets the function name.
    /// </summary>
    public required string FunctionName { get; init; }
    /// <summary>
    /// Gets the target architecture.
    /// </summary>
    public required string Architecture { get; init; }
    /// <summary>
    /// Gets the IR tokens (for transformer input), when pre-computed.
    /// </summary>
    public IReadOnlyList<string>? IrTokens { get; init; }
    /// <summary>
    /// Gets the decompiled code, when available.
    /// </summary>
    public string? DecompiledCode { get; init; }
    /// <summary>
    /// Gets computed fingerprints, when available.
    /// </summary>
    public FunctionFingerprints? Fingerprints { get; init; }
    /// <summary>
    /// Gets the function size in bytes.
    /// </summary>
    public int? SizeBytes { get; init; }
    /// <summary>
    /// Gets the number of basic blocks.
    /// </summary>
    public int? BasicBlockCount { get; init; }
    /// <summary>
    /// Gets the cyclomatic complexity.
    /// </summary>
    public int? CyclomaticComplexity { get; init; }
}
/// <summary>
/// Function fingerprints for training data.
/// </summary>
public sealed record FunctionFingerprints
{
    /// <summary>
    /// Gets the instruction hash, when computed.
    /// </summary>
    public string? InstructionHash { get; init; }
    /// <summary>
    /// Gets the control-flow-graph hash, when computed.
    /// </summary>
    public string? CfgHash { get; init; }
    /// <summary>
    /// Gets the call graph hash, when computed.
    /// </summary>
    public string? CallGraphHash { get; init; }
    /// <summary>
    /// Gets the mnemonic histogram (mnemonic to occurrence count), when computed.
    /// </summary>
    public IReadOnlyDictionary<string, int>? MnemonicHistogram { get; init; }
}
/// <summary>
/// Equivalence label for function pairs.
/// </summary>
/// <remarks>Serialized as a string name (e.g. "Equivalent") in JSON.</remarks>
[JsonConverter(typeof(JsonStringEnumConverter))]
public enum EquivalenceLabel
{
    /// <summary>
    /// Functions are equivalent (same semantics).
    /// </summary>
    Equivalent,
    /// <summary>
    /// Functions are different (different semantics).
    /// </summary>
    Different,
    /// <summary>
    /// Equivalence is unknown/uncertain.
    /// </summary>
    Unknown
}
/// <summary>
/// Metadata about a training pair.
/// </summary>
public sealed record TrainingPairMetadata
{
    /// <summary>
    /// Gets the CVE ID, when the pair originates from a security patch.
    /// </summary>
    public string? CveId { get; init; }
    /// <summary>
    /// Gets the patch type, when known.
    /// </summary>
    public string? PatchType { get; init; }
    /// <summary>
    /// Gets whether the function is the patched variant.
    /// </summary>
    public bool IsPatched { get; init; }
    /// <summary>
    /// Gets the distribution the binaries came from, when known.
    /// </summary>
    public string? Distribution { get; init; }
    /// <summary>
    /// Gets additional free-form tags.
    /// </summary>
    public IReadOnlyList<string>? Tags { get; init; }
}
/// <summary>
/// A training corpus containing labeled function pairs.
/// </summary>
public sealed record TrainingCorpus
{
    /// <summary>
    /// Gets the corpus version.
    /// </summary>
    public required string Version { get; init; }
    /// <summary>
    /// Gets when the corpus was created.
    /// </summary>
    public required DateTimeOffset CreatedAt { get; init; }
    /// <summary>
    /// Gets the corpus description, when provided.
    /// </summary>
    public string? Description { get; init; }
    /// <summary>
    /// Gets the training split pairs.
    /// </summary>
    public required IReadOnlyList<TrainingFunctionPair> TrainingPairs { get; init; }
    /// <summary>
    /// Gets the validation split pairs, when the corpus has been split.
    /// </summary>
    public IReadOnlyList<TrainingFunctionPair>? ValidationPairs { get; init; }
    /// <summary>
    /// Gets the test split pairs, when the corpus has been split.
    /// </summary>
    public IReadOnlyList<TrainingFunctionPair>? TestPairs { get; init; }
    /// <summary>
    /// Gets pre-computed corpus statistics, when available.
    /// </summary>
    public CorpusStatistics? Statistics { get; init; }
}
/// <summary>
/// Statistics about a training corpus.
/// </summary>
public sealed record CorpusStatistics
{
    /// <summary>
    /// Gets the total pair count across all splits.
    /// </summary>
    public int TotalPairs { get; init; }
    /// <summary>
    /// Gets the count of pairs labeled equivalent.
    /// </summary>
    public int EquivalentPairs { get; init; }
    /// <summary>
    /// Gets the count of pairs labeled different.
    /// </summary>
    public int DifferentPairs { get; init; }
    /// <summary>
    /// Gets the count of pairs labeled unknown.
    /// </summary>
    public int UnknownPairs { get; init; }
    /// <summary>
    /// Gets the number of unique libraries represented.
    /// </summary>
    public int UniqueLibraries { get; init; }
    /// <summary>
    /// Gets the number of unique functions represented.
    /// </summary>
    public int UniqueFunctions { get; init; }
    /// <summary>
    /// Gets the architectures covered, when recorded.
    /// </summary>
    public IReadOnlyList<string>? Architectures { get; init; }
}
/// <summary>
/// Configuration for corpus splitting.
/// </summary>
public sealed record CorpusSplitConfig
{
    /// <summary>
    /// Gets the training set ratio (default 0.8).
    /// </summary>
    public double TrainRatio { get; init; } = 0.8;
    /// <summary>
    /// Gets the validation set ratio (default 0.1).
    /// </summary>
    public double ValidationRatio { get; init; } = 0.1;
    /// <summary>
    /// Gets the test set ratio (default 0.1).
    /// </summary>
    public double TestRatio { get; init; } = 0.1;
    /// <summary>
    /// Gets the random seed for reproducibility (default 42).
    /// </summary>
    public int? RandomSeed { get; init; } = 42;
    /// <summary>
    /// Gets whether to stratify by library (default true).
    /// </summary>
    public bool StratifyByLibrary { get; init; } = true;
    /// <summary>
    /// Validates that the three split ratios sum to 1.0 (tolerance 0.001),
    /// mirroring <c>EnsembleWeightConfig.Validate</c>.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the ratios do not sum to 1.0.</exception>
    public void Validate()
    {
        var sum = TrainRatio + ValidationRatio + TestRatio;
        if (Math.Abs(sum - 1.0) > 0.001)
        {
            throw new InvalidOperationException(
                $"Corpus split ratios must sum to 1.0, but sum is {sum}");
        }
    }
}

View File

@@ -0,0 +1,83 @@
// -----------------------------------------------------------------------------
// TrainingServiceCollectionExtensions.cs
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
// Task: MLEM-007, MLEM-009 - DI Registration
// Description: Dependency injection extensions for ML training services.
// -----------------------------------------------------------------------------
using Microsoft.Extensions.DependencyInjection;
namespace StellaOps.BinaryIndex.ML.Training;
/// <summary>
/// Extension methods for registering ML training services.
/// </summary>
public static class TrainingServiceCollectionExtensions
{
    /// <summary>
    /// Adds ML training corpus services.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="configureOptions">Configuration action.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlTrainingCorpus(
        this IServiceCollection services,
        Action<MlTrainingOptions>? configureOptions = null)
    {
        // Register options with defaults.
        services.AddOptions<GhidraAdapterOptions>();
        services.AddOptions<OnnxEmbeddingServiceOptions>();
        if (configureOptions is not null)
        {
            var options = new MlTrainingOptions();
            configureOptions(options);
            // The previous code assigned the Configure callback's parameter
            // (`o = ...`), which is a no-op — the caller's options were
            // silently discarded. These option types carry init-only setters,
            // so instead register a fully-built IOptions<T> instance; the
            // last IOptions<T> registration wins at resolution time.
            if (options.GhidraOptions is not null)
            {
                services.AddSingleton<Microsoft.Extensions.Options.IOptions<GhidraAdapterOptions>>(
                    Microsoft.Extensions.Options.Options.Create(options.GhidraOptions));
            }
            if (options.OnnxOptions is not null)
            {
                services.AddSingleton<Microsoft.Extensions.Options.IOptions<OnnxEmbeddingServiceOptions>>(
                    Microsoft.Extensions.Options.Options.Create(options.OnnxOptions));
            }
        }
        // Register tokenizer and decompiler
        services.AddSingleton<IIrTokenizer, B2R2IrTokenizer>();
        services.AddSingleton<IDecompilerAdapter, GhidraDecompilerAdapter>();
        // Register corpus builder
        services.AddSingleton<ICorpusBuilder, GroundTruthCorpusBuilder>();
        // Register embedding service
        services.AddSingleton<IFunctionEmbeddingService, OnnxFunctionEmbeddingService>();
        // Register matcher adapter
        services.AddSingleton<MlEmbeddingMatcherAdapter>();
        return services;
    }
}
/// <summary>
/// Options for ML training infrastructure, consumed by
/// <c>AddMlTrainingCorpus</c> to configure the registered services.
/// </summary>
public sealed record MlTrainingOptions
{
    /// <summary>
    /// Gets or sets Ghidra adapter options; null leaves the registered defaults in place.
    /// </summary>
    public GhidraAdapterOptions? GhidraOptions { get; set; }
    /// <summary>
    /// Gets or sets ONNX embedding options; null leaves the registered defaults in place.
    /// </summary>
    public OnnxEmbeddingServiceOptions? OnnxOptions { get; set; }
    /// <summary>
    /// Gets or sets corpus build options.
    /// </summary>
    public CorpusBuildOptions? CorpusBuildOptions { get; set; }
}

View File

@@ -0,0 +1,450 @@
#!/usr/bin/env python3
# -----------------------------------------------------------------------------
# train_function_embeddings.py
# Sprint: SPRINT_20260119_006 ML Embeddings Corpus
# Task: MLEM-005 - Embedding Model Training Pipeline
# Description: PyTorch/HuggingFace training script for contrastive learning.
# -----------------------------------------------------------------------------
"""
Function Embedding Training Pipeline
Uses contrastive learning to train CodeBERT-based function embeddings.
Positive pairs: Same function across versions
Negative pairs: Different functions
Usage:
python train_function_embeddings.py --corpus datasets/training_corpus.jsonl \
--output models/function_embeddings.onnx \
--epochs 10 --batch-size 32
Requirements:
pip install torch transformers onnx onnxruntime tensorboard
"""
import argparse
import json
import logging
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
try:
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
except ImportError:
print("Please install transformers: pip install transformers")
raise
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class TrainingConfig:
    """Training configuration for the contrastive embedding pipeline."""
    model_name: str = "microsoft/codebert-base"  # HF checkpoint used as encoder
    corpus_path: str = "datasets/training_corpus.jsonl"
    output_path: str = "models/function_embeddings"
    # Training params
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 2e-5
    warmup_steps: int = 500
    weight_decay: float = 0.01
    # Contrastive learning params
    temperature: float = 0.07  # similarity scaling used by ContrastiveLoss
    margin: float = 0.5  # negative-pair margin used by ContrastiveLoss
    # Model params
    embedding_dim: int = 768
    max_seq_length: int = 512
    # Misc
    seed: int = 42
    # NOTE: evaluated once at class-definition time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    log_dir: str = "runs/function_embeddings"
class FunctionPairDataset(Dataset):
    """Dataset yielding tokenized function pairs for contrastive learning.

    Each corpus line is a JSON object with ``function1``/``function2``
    representations and an equivalence ``label``.
    """
    def __init__(self, corpus_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pairs = []
        logger.info(f"Loading corpus from {corpus_path}")
        # Explicit encoding: JSONL corpora are conventionally UTF-8, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    pair = json.loads(line)
                    self.pairs.append(pair)
        logger.info(f"Loaded {len(self.pairs)} pairs")
    def __len__(self) -> int:
        return len(self.pairs)
    def __getitem__(self, idx: int) -> dict:
        """Tokenize pair ``idx`` and return model-ready tensors plus label."""
        pair = self.pairs[idx]
        # Get function representations
        func1 = pair.get("function1", {})
        func2 = pair.get("function2", {})
        # Prefer decompiled code, fall back to IR tokens
        text1 = func1.get("decompiledCode") or " ".join(func1.get("irTokens", []))
        text2 = func2.get("decompiledCode") or " ".join(func2.get("irTokens", []))
        # Tokenize
        enc1 = self.tokenizer(
            text1,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        enc2 = self.tokenizer(
            text2,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        # Label: 1 for equivalent, 0 for different
        label = 1.0 if pair.get("label") == "equivalent" else 0.0
        return {
            "input_ids_1": enc1["input_ids"].squeeze(0),
            "attention_mask_1": enc1["attention_mask"].squeeze(0),
            "input_ids_2": enc2["input_ids"].squeeze(0),
            "attention_mask_2": enc2["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.float)
        }
class FunctionEmbeddingModel(nn.Module):
    """CodeBERT-based function embedding model.

    Wraps a pretrained transformer encoder with a 2-layer projection head and
    L2-normalizes the output, so cosine similarity reduces to a dot product.
    """
    def __init__(self, model_name: str, embedding_dim: int = 768):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.embedding_dim = embedding_dim
        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Compute L2-normalized embeddings of shape [batch, embedding_dim]."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]
        # Project to embedding space
        embedding = self.projection(cls_output)
        # L2 normalize
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding
    def get_embedding(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Compute embeddings under ``torch.no_grad`` (inference path).

        NOTE(review): despite the original comment, this DOES apply the
        projection head — it is identical to ``forward`` minus gradients.
        """
        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
            cls_output = outputs.last_hidden_state[:, 0, :]
            embedding = self.projection(cls_output)
            return F.normalize(embedding, p=2, dim=1)
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over temperature-scaled cosine similarity.

    Positive pairs (label 1) are pulled together by penalizing
    ``1 - similarity``; negative pairs (label 0) are pushed apart until the
    scaled similarity drops below ``margin``.
    """

    def __init__(self, temperature: float = 0.07, margin: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.margin = margin

    def forward(
        self,
        embedding1: torch.Tensor,
        embedding2: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean contrastive loss for a batch of embedding pairs.

        Args:
            embedding1: First function embeddings [B, D]
            embedding2: Second function embeddings [B, D]
            labels: 1 for positive pairs, 0 for negative [B]
        """
        scaled = F.cosine_similarity(embedding1, embedding2) / self.temperature
        attract = labels * (1 - scaled)
        repel = (1 - labels) * F.relu(scaled - self.margin)
        return (attract + repel).mean()
def train_epoch(
    model: FunctionEmbeddingModel,
    dataloader: DataLoader,
    criterion: ContrastiveLoss,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    device: str,
    epoch: int,
    writer: SummaryWriter
) -> float:
    """Train for one epoch.

    Args:
        model: Embedding model (set to train mode).
        dataloader: Batches containing paired inputs and a float ``label``.
        criterion: Loss callable taking (emb1, emb2, labels).
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler, stepped once per batch when given.
        device: Device string to move batch tensors to.
        epoch: Epoch index, used only for logging/global-step computation.
        writer: TensorBoard writer; receives per-batch "train/loss".

    Returns:
        Mean loss over the epoch's batches.
    """
    model.train()
    total_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
        # Move to device
        input_ids_1 = batch["input_ids_1"].to(device)
        attention_mask_1 = batch["attention_mask_1"].to(device)
        input_ids_2 = batch["input_ids_2"].to(device)
        attention_mask_2 = batch["attention_mask_2"].to(device)
        labels = batch["label"].to(device)
        # Forward pass
        emb1 = model(input_ids_1, attention_mask_1)
        emb2 = model(input_ids_2, attention_mask_2)
        # Compute loss
        loss = criterion(emb1, emb2, labels)
        # Backward pass with gradient clipping to stabilize training
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        total_loss += loss.item()
        # Log to tensorboard
        global_step = epoch * len(dataloader) + batch_idx
        writer.add_scalar("train/loss", loss.item(), global_step)
        if batch_idx % 100 == 0:
            logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
    return total_loss / len(dataloader)
def evaluate(
    model: FunctionEmbeddingModel,
    dataloader: DataLoader,
    criterion: ContrastiveLoss,
    device: str
) -> Tuple[float, float]:
    """Evaluate the model; returns (average loss, pair-classification accuracy)."""
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for batch in dataloader:
            ids_a = batch["input_ids_1"].to(device)
            mask_a = batch["attention_mask_1"].to(device)
            ids_b = batch["input_ids_2"].to(device)
            mask_b = batch["attention_mask_2"].to(device)
            labels = batch["label"].to(device)

            emb_a = model(ids_a, mask_a)
            emb_b = model(ids_b, mask_b)
            loss_sum += criterion(emb_a, emb_b, labels).item()

            # A pair is predicted "same function" when cosine similarity > 0.5.
            preds = (F.cosine_similarity(emb_a, emb_b) > 0.5).float()
            num_correct += (preds == labels).sum().item()
            num_seen += labels.size(0)
    avg_loss = loss_sum / len(dataloader)
    accuracy = num_correct / num_seen if num_seen > 0 else 0.0
    return avg_loss, accuracy
def export_onnx(
    model: FunctionEmbeddingModel,
    output_path: str,
    max_seq_length: int = 512
):
    """Serialize the model to ONNX at '<output_path>.onnx' via tracing export."""
    model.eval()
    # Fixed-size dummy inputs drive the tracing-based export.
    fake_ids = torch.ones(1, max_seq_length, dtype=torch.long)
    fake_mask = torch.ones(1, max_seq_length, dtype=torch.long)
    target = f"{output_path}.onnx"
    logger.info(f"Exporting model to {target}")
    torch.onnx.export(
        model,
        (fake_ids, fake_mask),
        target,
        input_names=["input_ids", "attention_mask"],
        output_names=["embedding"],
        # Only the batch dimension is dynamic; sequence length stays fixed
        # at max_seq_length in the exported graph.
        dynamic_axes={
            "input_ids": {0: "batch_size"},
            "attention_mask": {0: "batch_size"},
            "embedding": {0: "batch_size"}
        },
        opset_version=14
    )
    logger.info(f"Model exported to {target}")
def main():
    """CLI entry point: parse args, train with contrastive loss, export ONNX."""
    arg_parser = argparse.ArgumentParser(description="Train function embedding model")
    arg_parser.add_argument("--corpus", type=str, default="datasets/training_corpus.jsonl",
                            help="Path to training corpus (JSONL format)")
    arg_parser.add_argument("--output", type=str, default="models/function_embeddings",
                            help="Output path for model")
    arg_parser.add_argument("--model-name", type=str, default="microsoft/codebert-base",
                            help="Base model name")
    arg_parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    arg_parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
    arg_parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    arg_parser.add_argument("--seed", type=int, default=42, help="Random seed")
    cli = arg_parser.parse_args()

    # Collect hyperparameters into one config object.
    cfg = TrainingConfig(
        model_name=cli.model_name,
        corpus_path=cli.corpus,
        output_path=cli.output,
        epochs=cli.epochs,
        batch_size=cli.batch_size,
        learning_rate=cli.lr,
        seed=cli.seed
    )

    # Seed Python, CPU torch, and (when available) all CUDA devices.
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)
    logger.info(f"Using device: {cfg.device}")

    # Tokenizer that converts function token sequences into model inputs.
    logger.info(f"Loading tokenizer: {cfg.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)

    # 90/10 train/validation split over the pair corpus.
    dataset = FunctionPairDataset(cfg.corpus_path, tokenizer, cfg.max_seq_length)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False)

    logger.info(f"Creating model: {cfg.model_name}")
    model = FunctionEmbeddingModel(cfg.model_name, cfg.embedding_dim)
    model.to(cfg.device)

    # Contrastive objective, AdamW optimizer, linear warmup/decay schedule.
    criterion = ContrastiveLoss(cfg.temperature, cfg.margin)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay
    )
    total_steps = len(train_loader) * cfg.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg.warmup_steps,
        num_training_steps=total_steps
    )

    writer = SummaryWriter(cfg.log_dir)

    # Training loop: checkpoint whenever validation loss improves.
    best_val_loss = float('inf')
    for epoch in range(cfg.epochs):
        logger.info(f"=== Epoch {epoch + 1}/{cfg.epochs} ===")
        train_loss = train_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            cfg.device, epoch, writer
        )
        val_loss, val_accuracy = evaluate(model, val_loader, criterion, cfg.device)
        logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        writer.add_scalar("val/loss", val_loss, epoch)
        writer.add_scalar("val/accuracy", val_accuracy, epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs(cfg.output_path, exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_accuracy': val_accuracy
            }, f"{cfg.output_path}/best_model.pt")
            logger.info(f"Saved best model with val_loss: {val_loss:.4f}")

    # NOTE(review): this exports the final-epoch weights, not the best
    # checkpoint saved above — confirm that is the intended artifact.
    export_onnx(model, cfg.output_path, cfg.max_seq_length)
    writer.close()
    logger.info("Training complete!")
# Allow the script to be executed directly from the command line.
if __name__ == "__main__":
    main()