sprints work.
This commit is contained in:
@@ -0,0 +1,244 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// B2R2IrTokenizer.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-003 - IR Token Extraction
|
||||
// Description: B2R2-based IR tokenizer implementation.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using System.Text.RegularExpressions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// B2R2-based IR tokenizer for ML training input. Emits a flat token stream
/// bracketed by "[FUNC_START]"/"[FUNC_END]" markers; opcodes are mapped to a
/// canonical "[OP]" vocabulary and operands are optionally normalized.
/// </summary>
public sealed partial class B2R2IrTokenizer : IIrTokenizer
{
    private readonly ILogger<B2R2IrTokenizer> _logger;

    // Token vocabulary for common IR elements.
    // NOTE(review): these three sets are not consulted anywhere in this class
    // yet; they document the intended canonical vocabulary for the future
    // B2R2-backed implementation. Confirm before deleting.
    private static readonly HashSet<string> ControlFlowTokens =
        ["[JMP]", "[JE]", "[JNE]", "[JL]", "[JG]", "[JLE]", "[JGE]", "[CALL]", "[RET]", "[LOOP]"];

    private static readonly HashSet<string> DataFlowTokens =
        ["[MOV]", "[LEA]", "[PUSH]", "[POP]", "[XCHG]", "[LOAD]", "[STORE]"];

    private static readonly HashSet<string> ArithmeticTokens =
        ["[ADD]", "[SUB]", "[MUL]", "[DIV]", "[INC]", "[DEC]", "[NEG]", "[SHL]", "[SHR]", "[AND]", "[OR]", "[XOR]", "[NOT]"];

    // Explicit x86/x86-64 register names. Membership lookup replaces the
    // previous StartsWith("r")/StartsWith("e") heuristic, which misclassified
    // any symbol beginning with 'r' or 'e' (e.g. "encrypt_data") as a register.
    private static readonly HashSet<string> KnownRegisters = new(StringComparer.Ordinal)
    {
        "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rsp", "rbp",
        "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
        "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d",
        "eax", "ebx", "ecx", "edx", "esi", "edi", "esp", "ebp",
        "ax", "bx", "cx", "dx", "si", "di", "sp", "bp",
    };

    /// <summary>
    /// Initializes a new instance of the <see cref="B2R2IrTokenizer"/> class.
    /// </summary>
    /// <param name="logger">Logger for diagnostic output.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="logger"/> is null.</exception>
    public B2R2IrTokenizer(ILogger<B2R2IrTokenizer> logger)
    {
        ArgumentNullException.ThrowIfNull(logger);
        _logger = logger;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Placeholder implementation: returns only the function start/end markers
    /// and a normalized name token. IR tokens from B2R2 lifting are not yet
    /// produced (MLEM-003 follow-up).
    /// </remarks>
    public Task<IReadOnlyList<string>> TokenizeAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default)
    {
        // This would integrate with B2R2 to lift the function to IR.
        // For now, return placeholder tokens.
        _logger.LogDebug("Tokenizing function {Function} from {Library}:{Version}",
            functionName, libraryName, version);

        var tokens = new List<string>
        {
            "[FUNC_START]",
            $"[NAME:{NormalizeName(functionName)}]",
            // IR tokens would be added here from B2R2 analysis.
            "[FUNC_END]"
        };

        return Task.FromResult<IReadOnlyList<string>>(tokens);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Disassembly is currently a fixed placeholder (see <see cref="DisassembleToIr"/>);
    /// the tokenization pipeline (opcode mapping, operand normalization,
    /// control-flow markers, truncation) is the real logic under test.
    /// </remarks>
    public Task<IReadOnlyList<string>> TokenizeInstructionsAsync(
        ReadOnlyMemory<byte> instructions,
        string architecture,
        TokenizationOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= TokenizationOptions.Default;
        var tokens = new List<string>();

        // Add architecture token, then the function-start marker.
        tokens.Add($"[ARCH:{architecture.ToUpperInvariant()}]");
        tokens.Add("[FUNC_START]");

        // Disassemble and tokenize. This would use B2R2 for actual disassembly.
        var disassembly = DisassembleToIr(instructions, architecture);

        var varCounter = 0;
        var varMap = new Dictionary<string, string>();

        foreach (var insn in disassembly)
        {
            cancellationToken.ThrowIfCancellationRequested();

            // Add canonical opcode token.
            tokens.Add(MapOpcodeToToken(insn.Opcode));

            // Add operand tokens, optionally normalized and/or type-prefixed.
            foreach (var operand in insn.Operands)
            {
                var operandToken = options.NormalizeVariables
                    ? NormalizeOperand(operand, varMap, ref varCounter)
                    : operand;

                if (options.IncludeOperandTypes)
                {
                    var typeToken = InferOperandType(operand);
                    tokens.Add($"{typeToken}:{operandToken}");
                }
                else
                {
                    tokens.Add(operandToken);
                }
            }

            // Add control flow marker if applicable.
            if (options.IncludeControlFlow && IsControlFlowInstruction(insn.Opcode))
            {
                tokens.Add("[CF]");
            }
        }

        tokens.Add("[FUNC_END]");

        // Truncate to max length, keeping an explicit marker as the last token.
        if (tokens.Count > options.MaxLength)
        {
            tokens = tokens.Take(options.MaxLength - 1).Append("[TRUNCATED]").ToList();
        }

        return Task.FromResult<IReadOnlyList<string>>(tokens);
    }

    /// <summary>
    /// Placeholder disassembly — returns a fixed sample prologue/call/epilogue
    /// sequence until B2R2 integration lands.
    /// </summary>
    private static IReadOnlyList<DisassembledInstruction> DisassembleToIr(
        ReadOnlyMemory<byte> instructions,
        string architecture)
    {
        return new List<DisassembledInstruction>
        {
            new("push", ["rbp"]),
            new("mov", ["rbp", "rsp"]),
            new("sub", ["rsp", "0x20"]),
            new("mov", ["[rbp-0x8]", "rdi"]),
            new("call", ["helper_func"]),
            new("leave", []),
            new("ret", [])
        };
    }

    /// <summary>
    /// Maps a raw opcode mnemonic to its canonical bracketed token, folding
    /// aliases (e.g. MOVZX/MOVSX → [MOV], SAL → [SHL]). Unknown opcodes are
    /// passed through upper-cased in brackets.
    /// </summary>
    private static string MapOpcodeToToken(string opcode)
    {
        var upper = opcode.ToUpperInvariant();

        return upper switch
        {
            "JMP" or "JE" or "JNE" or "JZ" or "JNZ" or "JL" or "JG" or "JLE" or "JGE" or "JA" or "JB" =>
                $"[{upper}]",
            "CALL" => "[CALL]",
            "RET" or "RETN" => "[RET]",
            "MOV" or "MOVZX" or "MOVSX" => "[MOV]",
            "LEA" => "[LEA]",
            "PUSH" => "[PUSH]",
            "POP" => "[POP]",
            "ADD" => "[ADD]",
            "SUB" => "[SUB]",
            "MUL" or "IMUL" => "[MUL]",
            "DIV" or "IDIV" => "[DIV]",
            "AND" => "[AND]",
            "OR" => "[OR]",
            "XOR" => "[XOR]",
            "SHL" or "SAL" => "[SHL]",
            "SHR" or "SAR" => "[SHR]",
            "CMP" => "[CMP]",
            "TEST" => "[TEST]",
            "NOP" => "[NOP]",
            _ => $"[{upper}]"
        };
    }

    /// <summary>
    /// Normalizes an operand: registers become sequential v0/v1/... names
    /// (stable within one function via <paramref name="varMap"/>), immediates
    /// become "[IMM]", memory references become "[MEM]"; anything else (e.g.
    /// symbols) is returned unchanged.
    /// </summary>
    private static string NormalizeOperand(
        string operand,
        Dictionary<string, string> varMap,
        ref int varCounter)
    {
        if (IsRegister(operand))
        {
            if (!varMap.TryGetValue(operand, out var normalized))
            {
                normalized = $"v{varCounter++}";
                varMap[operand] = normalized;
            }
            return normalized;
        }

        if (IsImmediate(operand))
        {
            return "[IMM]";
        }

        if (operand.Contains('['))
        {
            return "[MEM]";
        }

        return operand;
    }

    /// <summary>
    /// Infers a coarse operand type token; "[SYM]" is a heuristic for names
    /// containing "func" or an underscore, "[UNK]" otherwise.
    /// </summary>
    private static string InferOperandType(string operand)
    {
        if (IsRegister(operand)) return "[REG]";
        if (IsImmediate(operand)) return "[IMM]";
        if (operand.Contains('[')) return "[MEM]";
        if (operand.Contains("func") || operand.Contains("_")) return "[SYM]";
        return "[UNK]";
    }

    /// <summary>
    /// Returns true when the operand is a known x86/x86-64 register name
    /// (case-insensitive, exact match against <see cref="KnownRegisters"/>).
    /// </summary>
    private static bool IsRegister(string operand)
        => KnownRegisters.Contains(operand.ToLowerInvariant());

    /// <summary>
    /// Returns true for hex ("0x...") or purely decimal operands. The empty
    /// string is explicitly rejected (string.Empty.All(char.IsDigit) is true).
    /// </summary>
    private static bool IsImmediate(string operand)
    {
        if (string.IsNullOrEmpty(operand))
        {
            return false;
        }

        return operand.StartsWith("0x", StringComparison.Ordinal) || operand.All(char.IsDigit);
    }

    /// <summary>
    /// Returns true for jumps (any J-prefixed mnemonic), calls, returns, and loops.
    /// </summary>
    private static bool IsControlFlowInstruction(string opcode)
    {
        var upper = opcode.ToUpperInvariant();
        return upper.StartsWith('J') || upper is "CALL" or "RET" or "RETN" or "LOOP";
    }

    /// <summary>
    /// Strips version-specific suffixes ("@N", ".N", "_vN") and lower-cases the name.
    /// </summary>
    private static string NormalizeName(string name)
    {
        var normalized = NameNormalizationRegex().Replace(name, "");
        return normalized.ToLowerInvariant();
    }

    [GeneratedRegex(@"@\d+|\.\d+|_v\d+")]
    private static partial Regex NameNormalizationRegex();

    /// <summary>A single disassembled instruction: mnemonic plus operand strings.</summary>
    private sealed record DisassembledInstruction(string Opcode, IReadOnlyList<string> Operands);
}
|
||||
@@ -0,0 +1,249 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// GhidraDecompilerAdapter.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-004 - Decompiled Code Extraction
|
||||
// Description: Ghidra-based decompiler adapter implementation.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using System.Diagnostics;
|
||||
using System.Text;
|
||||
using System.Text.RegularExpressions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Ghidra-based decompiler adapter. Shells out to Ghidra's headless analyzer
/// (analyzeHeadless) to decompile raw bytes, and post-processes the output
/// into normalized C-like text.
/// </summary>
public sealed partial class GhidraDecompilerAdapter : IDecompilerAdapter
{
    private readonly GhidraAdapterOptions _options;
    private readonly ILogger<GhidraDecompilerAdapter> _logger;

    /// <summary>
    /// Initializes a new instance of the <see cref="GhidraDecompilerAdapter"/> class.
    /// </summary>
    /// <param name="options">Adapter configuration (Ghidra install path, script path).</param>
    /// <param name="logger">Logger for diagnostic output.</param>
    public GhidraDecompilerAdapter(
        IOptions<GhidraAdapterOptions> options,
        ILogger<GhidraDecompilerAdapter> logger)
    {
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(logger);
        _options = options.Value;
        _logger = logger;
    }

    /// <inheritdoc />
    /// <remarks>Placeholder implementation — returns canned decompiled text
    /// until the Ghidra headless pipeline is wired up for named lookups.</remarks>
    public Task<string?> DecompileAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default)
    {
        _logger.LogDebug("Decompiling {Function} from {Library}:{Version}",
            functionName, libraryName, version);

        // This would call Ghidra headless analyzer. For now, return placeholder.
        // (No awaits needed, so return the completed task directly.)
        return Task.FromResult<string?>($"int {functionName}(void *param_1) {{\n    int result;\n    // Decompiled code placeholder\n    result = 0;\n    return result;\n}}");
    }

    /// <inheritdoc />
    /// <returns>
    /// Decompiled (optionally simplified) code, or null when Ghidra is not
    /// configured, fails, or produces no output. Caller cancellation is
    /// propagated as <see cref="OperationCanceledException"/>.
    /// </returns>
    public async Task<string?> DecompileBytesAsync(
        ReadOnlyMemory<byte> bytes,
        string architecture,
        DecompilationOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= DecompilationOptions.Default;

        if (string.IsNullOrEmpty(_options.GhidraPath))
        {
            _logger.LogWarning("Ghidra path not configured");
            return null;
        }

        // Allocate both temp paths up front so the finally block always cleans
        // them up — previously a failure while writing the input file leaked it.
        var tempInput = Path.GetTempFileName();
        var tempOutput = Path.GetTempFileName();

        try
        {
            await File.WriteAllBytesAsync(tempInput, bytes.ToArray(), cancellationToken);

            // Run Ghidra headless with the decompile post-script.
            var script = _options.DecompileScriptPath ?? "DecompileFunction.java";
            var args = $"-import {tempInput} -postScript {script} {tempOutput} -deleteProject -noanalysis";

            var result = await RunGhidraAsync(args, options.Timeout, cancellationToken);

            if (!result.Success)
            {
                _logger.LogWarning("Ghidra decompilation failed: {Error}", result.Error);
                return null;
            }

            if (File.Exists(tempOutput))
            {
                var decompiled = await File.ReadAllTextAsync(tempOutput, cancellationToken);
                return options.Simplify ? Normalize(decompiled) : decompiled;
            }

            return null;
        }
        catch (OperationCanceledException)
        {
            // Caller cancellation is not a decompilation failure — propagate it
            // instead of logging and returning null.
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Decompilation failed");
            return null;
        }
        finally
        {
            if (File.Exists(tempInput)) File.Delete(tempInput);
            if (File.Exists(tempOutput)) File.Delete(tempOutput);
        }
    }

    /// <inheritdoc />
    public string Normalize(string code, NormalizationOptions? options = null)
    {
        options ??= NormalizationOptions.Default;
        var result = code;

        // Strip block and line comments.
        if (options.StripComments)
        {
            result = StripCommentsRegex().Replace(result, "");
            result = LineCommentRegex().Replace(result, "");
        }

        // Normalize whitespace. Collapse blank lines first, then runs of
        // spaces/tabs; the horizontal-only pattern preserves line breaks so
        // the blank-line pass is not a no-op (previously \s+ destroyed all
        // newlines before the blank-line regex ran).
        if (options.NormalizeWhitespace)
        {
            result = EmptyLinesRegex().Replace(result, "\n");
            result = MultipleSpacesRegex().Replace(result, " ");
            result = result.Trim();
        }

        // Normalize Ghidra's synthetic identifiers (local_/param_/DAT_/FUN_)
        // to stable sequential var_N names.
        if (options.NormalizeVariables)
        {
            var varCounter = 0;
            var varMap = new Dictionary<string, string>();

            result = VariableNameRegex().Replace(result, match =>
            {
                var name = match.Value;
                if (!varMap.TryGetValue(name, out var normalized))
                {
                    normalized = $"var_{varCounter++}";
                    varMap[name] = normalized;
                }
                return normalized;
            });
        }

        // Remove type casts like "(int *)".
        if (options.RemoveTypeCasts)
        {
            result = TypeCastRegex().Replace(result, "");
        }

        // Truncate if too long, leaving an explicit marker.
        if (result.Length > options.MaxLength)
        {
            result = result[..options.MaxLength] + "\n/* truncated */";
        }

        return result;
    }

    /// <summary>
    /// Runs analyzeHeadless with the given arguments, enforcing a timeout.
    /// Caller cancellation (via <paramref name="ct"/>) kills the process and
    /// rethrows; a timeout kills the process and reports failure.
    /// </summary>
    private async Task<(bool Success, string? Error)> RunGhidraAsync(
        string args,
        TimeSpan timeout,
        CancellationToken ct)
    {
        var analyzeHeadless = Path.Combine(_options.GhidraPath!, "support", "analyzeHeadless");

        var psi = new ProcessStartInfo
        {
            FileName = analyzeHeadless,
            Arguments = args,
            RedirectStandardOutput = true,
            RedirectStandardError = true,
            UseShellExecute = false,
            CreateNoWindow = true
        };

        using var process = new Process { StartInfo = psi };
        var output = new StringBuilder();
        var error = new StringBuilder();

        process.OutputDataReceived += (_, e) =>
        {
            if (e.Data is not null) output.AppendLine(e.Data);
        };
        process.ErrorDataReceived += (_, e) =>
        {
            if (e.Data is not null) error.AppendLine(e.Data);
        };

        process.Start();
        process.BeginOutputReadLine();
        process.BeginErrorReadLine();

        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(timeout);

        try
        {
            await process.WaitForExitAsync(cts.Token);
            return (process.ExitCode == 0, error.Length > 0 ? error.ToString() : null);
        }
        catch (OperationCanceledException)
        {
            process.Kill(entireProcessTree: true);

            // Distinguish caller cancellation from the timeout we imposed.
            if (ct.IsCancellationRequested)
            {
                throw;
            }

            return (false, "Timeout");
        }
    }

    [GeneratedRegex(@"/\*.*?\*/", RegexOptions.Singleline)]
    private static partial Regex StripCommentsRegex();

    [GeneratedRegex(@"//.*$", RegexOptions.Multiline)]
    private static partial Regex LineCommentRegex();

    // Horizontal whitespace only — must not consume newlines, or the
    // blank-line normalization above would never match.
    [GeneratedRegex(@"[ \t]+")]
    private static partial Regex MultipleSpacesRegex();

    [GeneratedRegex(@"\n\s*\n")]
    private static partial Regex EmptyLinesRegex();

    [GeneratedRegex(@"\b(local_|param_|DAT_|FUN_)[a-zA-Z0-9_]+")]
    private static partial Regex VariableNameRegex();

    [GeneratedRegex(@"\(\s*[a-zA-Z_][a-zA-Z0-9_]*\s*\*?\s*\)")]
    private static partial Regex TypeCastRegex();
}
|
||||
|
||||
/// <summary>
/// Configuration for <see cref="GhidraDecompilerAdapter"/>. All paths are
/// optional; when <see cref="GhidraPath"/> is unset the adapter reports a
/// warning and skips decompilation.
/// </summary>
public sealed record GhidraAdapterOptions
{
    /// <summary>
    /// Root directory of the Ghidra installation (the directory containing
    /// <c>support/analyzeHeadless</c>).
    /// </summary>
    public string? GhidraPath { get; init; }

    /// <summary>
    /// Path to the post-analysis decompile script; when null a default
    /// script name is used by the adapter.
    /// </summary>
    public string? DecompileScriptPath { get; init; }

    /// <summary>
    /// Directory in which temporary Ghidra projects are created.
    /// </summary>
    public string? ProjectDirectory { get; init; }
}
|
||||
@@ -0,0 +1,355 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// GroundTruthCorpusBuilder.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-002 - Corpus Builder from Ground-Truth
|
||||
// Description: Implementation of corpus builder using ground-truth data.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Builds training corpus from ground-truth security pairs. Positive pairs
/// come from before/after versions of CVE-affected functions; negative pairs
/// are sampled from distinct functions in the accumulated cache.
/// </summary>
public sealed class GroundTruthCorpusBuilder : ICorpusBuilder
{
    private readonly IIrTokenizer _tokenizer;
    private readonly IDecompilerAdapter _decompiler;
    private readonly ILogger<GroundTruthCorpusBuilder> _logger;

    private readonly List<TrainingFunctionPair> _positivePairs = [];
    private readonly List<TrainingFunctionPair> _negativePairs = [];
    // Keyed by "library:version:function"; feeds negative-pair sampling.
    private readonly Dictionary<string, FunctionRepresentation> _functionCache = [];
    private readonly Random _random;

    private static readonly JsonSerializerOptions JsonOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        WriteIndented = false
    };

    // Cached indented options for whole-corpus JSON export (avoids allocating
    // a new JsonSerializerOptions, and its metadata cache, on every export).
    private static readonly JsonSerializerOptions IndentedJsonOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        WriteIndented = true
    };

    /// <summary>
    /// Initializes a new instance of the <see cref="GroundTruthCorpusBuilder"/> class.
    /// </summary>
    /// <param name="tokenizer">Tokenizer used to produce IR token sequences.</param>
    /// <param name="decompiler">Adapter used to produce decompiled code.</param>
    /// <param name="logger">Logger for diagnostic output.</param>
    /// <param name="randomSeed">Optional seed for deterministic shuffling/sampling.</param>
    public GroundTruthCorpusBuilder(
        IIrTokenizer tokenizer,
        IDecompilerAdapter decompiler,
        ILogger<GroundTruthCorpusBuilder> logger,
        int? randomSeed = null)
    {
        ArgumentNullException.ThrowIfNull(tokenizer);
        ArgumentNullException.ThrowIfNull(decompiler);
        ArgumentNullException.ThrowIfNull(logger);
        _tokenizer = tokenizer;
        _decompiler = decompiler;
        _logger = logger;
        _random = randomSeed.HasValue ? new Random(randomSeed.Value) : new Random();
    }

    /// <inheritdoc />
    /// <remarks>
    /// NOTE(review): TargetPositivePairs is currently only a log hint — loaded
    /// positives are not capped to it. Confirm whether capping is intended.
    /// </remarks>
    public async Task<TrainingCorpus> BuildCorpusAsync(
        CorpusBuildOptions options,
        CancellationToken cancellationToken = default)
    {
        _logger.LogInformation("Building training corpus with target {Positive} positive, {Negative} negative pairs",
            options.TargetPositivePairs, options.TargetNegativePairs);

        // Load security pairs from each configured source file.
        if (options.SecurityPairPaths is { Count: > 0 })
        {
            foreach (var path in options.SecurityPairPaths)
            {
                await AddSecurityPairsAsync(path, cancellationToken);
            }
        }

        // Top up negatives to the requested target.
        var neededNegatives = options.TargetNegativePairs - _negativePairs.Count;
        if (neededNegatives > 0)
        {
            await GenerateNegativePairsAsync(neededNegatives, cancellationToken);
        }

        // Combine and shuffle so the split is not ordered by label/source.
        var allPairs = _positivePairs.Concat(_negativePairs).ToList();
        Shuffle(allPairs);

        // Split into train/val/test by the configured ratios; the test split
        // receives the remainder.
        var splitConfig = options.SplitConfig;
        var trainCount = (int)(allPairs.Count * splitConfig.TrainRatio);
        var valCount = (int)(allPairs.Count * splitConfig.ValidationRatio);

        var trainPairs = allPairs.Take(trainCount).ToList();
        var valPairs = allPairs.Skip(trainCount).Take(valCount).ToList();
        var testPairs = allPairs.Skip(trainCount + valCount).ToList();

        _logger.LogInformation(
            "Corpus built: {Train} train, {Val} validation, {Test} test pairs",
            trainPairs.Count, valPairs.Count, testPairs.Count);

        return new TrainingCorpus
        {
            Version = "1.0",
            CreatedAt = DateTimeOffset.UtcNow,
            Description = "Ground-truth security pairs corpus",
            TrainingPairs = trainPairs,
            ValidationPairs = valPairs,
            TestPairs = testPairs,
            Statistics = GetStatistics()
        };
    }

    /// <inheritdoc />
    /// <returns>Number of pairs added; 0 when the file is missing.</returns>
    public async Task<int> AddSecurityPairsAsync(
        string securityPairPath,
        CancellationToken cancellationToken = default)
    {
        if (!File.Exists(securityPairPath))
        {
            _logger.LogWarning("Security pair file not found: {Path}", securityPairPath);
            return 0;
        }

        var added = 0;

        // JSON Lines input: one SecurityPairData per non-blank line.
        await foreach (var line in File.ReadLinesAsync(securityPairPath, cancellationToken))
        {
            if (string.IsNullOrWhiteSpace(line)) continue;

            try
            {
                var pairData = JsonSerializer.Deserialize<SecurityPairData>(line, JsonOptions);
                if (pairData is null) continue;

                // Extract function pairs from security pair.
                var pairs = await ExtractFunctionPairsAsync(pairData, cancellationToken);
                _positivePairs.AddRange(pairs);
                added += pairs.Count;
            }
            catch (JsonException ex)
            {
                // Tolerate malformed lines — best-effort ingestion.
                _logger.LogWarning(ex, "Failed to parse security pair line");
            }
        }

        _logger.LogDebug("Added {Count} pairs from {Path}", added, securityPairPath);
        return added;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Purely in-memory sampling (no awaits), so the completed result is
    /// returned directly instead of declaring an async method with no await
    /// (CS1998). May generate fewer than <paramref name="count"/> pairs when
    /// same-function candidates are skipped or cancellation is requested.
    /// </remarks>
    public Task<int> GenerateNegativePairsAsync(
        int count,
        CancellationToken cancellationToken = default)
    {
        var functions = _functionCache.Values.ToList();
        if (functions.Count < 2)
        {
            _logger.LogWarning("Not enough functions in cache to generate negative pairs");
            return Task.FromResult(0);
        }

        var generated = 0;

        for (var i = 0; i < count && !cancellationToken.IsCancellationRequested; i++)
        {
            // Pick two random functions that are different.
            var idx1 = _random.Next(functions.Count);
            var idx2 = _random.Next(functions.Count);

            if (idx1 == idx2) idx2 = (idx2 + 1) % functions.Count;

            var func1 = functions[idx1];
            var func2 = functions[idx2];

            // Skip if same function (by name) from different versions —
            // those are equivalent, not negative, examples.
            if (func1.FunctionName == func2.FunctionName &&
                func1.LibraryName == func2.LibraryName)
            {
                continue;
            }

            _negativePairs.Add(new TrainingFunctionPair
            {
                PairId = $"neg_{Guid.NewGuid():N}",
                Function1 = func1,
                Function2 = func2,
                Label = EquivalenceLabel.Different,
                Confidence = 1.0,
                Source = "generated:negative_sampling"
            });

            generated++;
        }

        _logger.LogDebug("Generated {Count} negative pairs", generated);
        return Task.FromResult(generated);
    }

    /// <inheritdoc />
    /// <exception cref="NotSupportedException">For formats other than JsonLines/Json.</exception>
    public async Task ExportAsync(
        string outputPath,
        CorpusExportFormat format = CorpusExportFormat.JsonLines,
        CancellationToken cancellationToken = default)
    {
        var allPairs = _positivePairs.Concat(_negativePairs);

        var directory = Path.GetDirectoryName(outputPath);
        if (!string.IsNullOrEmpty(directory))
        {
            Directory.CreateDirectory(directory);
        }

        switch (format)
        {
            case CorpusExportFormat.JsonLines:
                await using (var writer = new StreamWriter(outputPath))
                {
                    foreach (var pair in allPairs)
                    {
                        var json = JsonSerializer.Serialize(pair, JsonOptions);
                        // Memory overload so the write honors cancellation.
                        await writer.WriteLineAsync(json.AsMemory(), cancellationToken);
                    }
                }
                break;

            case CorpusExportFormat.Json:
                var corpus = new TrainingCorpus
                {
                    Version = "1.0",
                    CreatedAt = DateTimeOffset.UtcNow,
                    TrainingPairs = allPairs.ToList(),
                    Statistics = GetStatistics()
                };
                var corpusJson = JsonSerializer.Serialize(corpus, IndentedJsonOptions);
                await File.WriteAllTextAsync(outputPath, corpusJson, cancellationToken);
                break;

            default:
                throw new NotSupportedException($"Export format {format} not yet supported");
        }

        _logger.LogInformation("Exported corpus to {Path}", outputPath);
    }

    /// <inheritdoc />
    public CorpusStatistics GetStatistics()
    {
        var allPairs = _positivePairs.Concat(_negativePairs).ToList();
        var allFunctions = allPairs
            .SelectMany(p => new[] { p.Function1, p.Function2 })
            .ToList();

        return new CorpusStatistics
        {
            TotalPairs = allPairs.Count,
            EquivalentPairs = allPairs.Count(p => p.Label == EquivalenceLabel.Equivalent),
            DifferentPairs = allPairs.Count(p => p.Label == EquivalenceLabel.Different),
            UnknownPairs = allPairs.Count(p => p.Label == EquivalenceLabel.Unknown),
            UniqueLibraries = allFunctions.Select(f => f.LibraryName).Distinct().Count(),
            UniqueFunctions = allFunctions.Select(f => f.FunctionName).Distinct().Count(),
            Architectures = allFunctions.Select(f => f.Architecture).Distinct().ToList()
        };
    }

    /// <summary>
    /// Converts one security-pair record into positive training pairs: the
    /// before/after representation of each affected function, labeled
    /// Equivalent. Both sides are also cached for negative-pair sampling.
    /// </summary>
    private async Task<List<TrainingFunctionPair>> ExtractFunctionPairsAsync(
        SecurityPairData pairData,
        CancellationToken ct)
    {
        var pairs = new List<TrainingFunctionPair>();

        foreach (var funcName in pairData.AffectedFunctions ?? [])
        {
            var func1 = await GetFunctionRepresentationAsync(
                pairData.LibraryName,
                pairData.VersionBefore,
                funcName,
                pairData.Architecture ?? "x86_64",
                ct);

            var func2 = await GetFunctionRepresentationAsync(
                pairData.LibraryName,
                pairData.VersionAfter,
                funcName,
                pairData.Architecture ?? "x86_64",
                ct);

            if (func1 is not null && func2 is not null)
            {
                pairs.Add(new TrainingFunctionPair
                {
                    PairId = $"pos_{pairData.CveId}_{funcName}_{Guid.NewGuid():N}",
                    Function1 = func1,
                    Function2 = func2,
                    Label = EquivalenceLabel.Equivalent,
                    Confidence = 1.0,
                    Source = $"groundtruth:security_pair:{pairData.CveId}",
                    Metadata = new TrainingPairMetadata
                    {
                        CveId = pairData.CveId,
                        IsPatched = true,
                        Distribution = pairData.Distribution
                    }
                });

                // Cache functions for negative pair generation.
                _functionCache[$"{func1.LibraryName}:{func1.LibraryVersion}:{func1.FunctionName}"] = func1;
                _functionCache[$"{func2.LibraryName}:{func2.LibraryVersion}:{func2.FunctionName}"] = func2;
            }
        }

        return pairs;
    }

    /// <summary>
    /// Builds a function representation from IR tokens plus decompiled code.
    /// </summary>
    private async Task<FunctionRepresentation?> GetFunctionRepresentationAsync(
        string libraryName,
        string version,
        string functionName,
        string architecture,
        CancellationToken ct)
    {
        var irTokens = await _tokenizer.TokenizeAsync(libraryName, version, functionName, ct);
        var decompiled = await _decompiler.DecompileAsync(libraryName, version, functionName, ct);

        return new FunctionRepresentation
        {
            LibraryName = libraryName,
            LibraryVersion = version,
            FunctionName = functionName,
            Architecture = architecture,
            IrTokens = irTokens,
            DecompiledCode = decompiled
        };
    }

    /// <summary>
    /// In-place Fisher–Yates shuffle using the builder's seeded RNG.
    /// </summary>
    private void Shuffle<T>(List<T> list)
    {
        var n = list.Count;
        while (n > 1)
        {
            n--;
            var k = _random.Next(n + 1);
            (list[k], list[n]) = (list[n], list[k]);
        }
    }
}
|
||||
|
||||
/// <summary>
/// One ground-truth security pair as deserialized from a JSON Lines source:
/// a library's before/after versions around a CVE fix, with the functions
/// the fix touched.
/// </summary>
internal sealed record SecurityPairData
{
    /// <summary>CVE identifier, when known.</summary>
    public string? CveId { get; init; }

    /// <summary>Name of the affected library.</summary>
    public string LibraryName { get; init; } = string.Empty;

    /// <summary>Vulnerable (pre-fix) version.</summary>
    public string VersionBefore { get; init; } = string.Empty;

    /// <summary>Patched (post-fix) version.</summary>
    public string VersionAfter { get; init; } = string.Empty;

    /// <summary>Names of the functions changed by the fix.</summary>
    public IReadOnlyList<string>? AffectedFunctions { get; init; }

    /// <summary>Target architecture; consumers fall back to a default when null.</summary>
    public string? Architecture { get; init; }

    /// <summary>Originating distribution, when known.</summary>
    public string? Distribution { get; init; }
}
|
||||
@@ -0,0 +1,147 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// ICorpusBuilder.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-002 - Corpus Builder from Ground-Truth
|
||||
// Description: Interface for building training corpus from ground-truth data.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Builder for ML training corpus from ground-truth data. Implementations
/// accumulate pairs across calls; <see cref="GetStatistics"/> reflects
/// everything added so far.
/// </summary>
public interface ICorpusBuilder
{
    /// <summary>
    /// Builds a training corpus from security pairs, including loading any
    /// configured sources, topping up negative pairs, and splitting into
    /// train/validation/test sets per the options.
    /// </summary>
    /// <param name="options">Build options (sources, pair targets, split ratios).</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>The built corpus.</returns>
    Task<TrainingCorpus> BuildCorpusAsync(
        CorpusBuildOptions options,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Adds pairs from a security pair source (JSON Lines file of
    /// before/after records).
    /// </summary>
    /// <param name="securityPairPath">Path to security pair data.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Number of pairs added; implementations may return 0 when the source is unusable.</returns>
    Task<int> AddSecurityPairsAsync(
        string securityPairPath,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Generates negative (non-equivalent) pairs by sampling from functions
    /// already seen by the builder. May produce fewer than requested.
    /// </summary>
    /// <param name="count">Number of negative pairs to generate.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Number of pairs actually generated.</returns>
    Task<int> GenerateNegativePairsAsync(
        int count,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Exports the accumulated corpus to a file in the requested format.
    /// Implementations may not support every <see cref="CorpusExportFormat"/> value.
    /// </summary>
    /// <param name="outputPath">Output file path.</param>
    /// <param name="format">Export format.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    Task ExportAsync(
        string outputPath,
        CorpusExportFormat format = CorpusExportFormat.JsonLines,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Gets current build statistics over all pairs added so far.
    /// </summary>
    CorpusStatistics GetStatistics();
}
|
||||
|
||||
/// <summary>
/// Options controlling corpus construction: input sources, pair-count
/// targets, split ratios, per-representation toggles, size caps, and
/// optional library/architecture filters.
/// </summary>
public sealed record CorpusBuildOptions
{
    /// <summary>Paths to security pair source files; null means no sources.</summary>
    public IReadOnlyList<string>? SecurityPairPaths { get; init; }

    /// <summary>Target count of positive (equivalent) pairs.</summary>
    public int TargetPositivePairs { get; init; } = 15000;

    /// <summary>Target count of negative (different) pairs.</summary>
    public int TargetNegativePairs { get; init; } = 15000;

    /// <summary>Train/validation/test split configuration.</summary>
    public CorpusSplitConfig SplitConfig { get; init; } = new();

    /// <summary>Whether IR token sequences are included in each function.</summary>
    public bool IncludeIrTokens { get; init; } = true;

    /// <summary>Whether decompiled code is included in each function.</summary>
    public bool IncludeDecompiledCode { get; init; } = true;

    /// <summary>Whether fingerprints are included in each function.</summary>
    public bool IncludeFingerprints { get; init; } = true;

    /// <summary>Cap on IR token sequence length.</summary>
    public int MaxIrTokenLength { get; init; } = 512;

    /// <summary>Cap on decompiled code length, in characters.</summary>
    public int MaxDecompiledLength { get; init; } = 2048;

    /// <summary>Library allow-list; null means include all libraries.</summary>
    public IReadOnlyList<string>? IncludeLibraries { get; init; }

    /// <summary>Architecture allow-list; null means include all architectures.</summary>
    public IReadOnlyList<string>? IncludeArchitectures { get; init; }
}
|
||||
|
||||
/// <summary>
/// Serialization formats supported when exporting a corpus.
/// </summary>
public enum CorpusExportFormat
{
    /// <summary>
    /// Newline-delimited JSON: one pair per line.
    /// </summary>
    JsonLines,

    /// <summary>
    /// A single JSON document.
    /// </summary>
    Json,

    /// <summary>
    /// Apache Parquet, suited to large datasets.
    /// </summary>
    Parquet,

    /// <summary>
    /// HuggingFace datasets layout.
    /// </summary>
    HuggingFace
}
|
||||
@@ -0,0 +1,133 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// IDecompilerAdapter.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-004 - Decompiled Code Extraction
|
||||
// Description: Interface for decompiler integration.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Abstraction over an external decompiler back end.
/// </summary>
public interface IDecompilerAdapter
{
    /// <summary>
    /// Decompiles a named function from an indexed library into C-like pseudocode.
    /// </summary>
    /// <param name="libraryName">Name of the library containing the function.</param>
    /// <param name="version">Version of the library.</param>
    /// <param name="functionName">Name of the function to decompile.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>Decompiled pseudocode, or null when decompilation is unavailable.</returns>
    Task<string?> DecompileAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Decompiles a raw byte sequence into C-like pseudocode.
    /// </summary>
    /// <param name="bytes">Machine code bytes of the function.</param>
    /// <param name="architecture">Architecture the bytes were compiled for.</param>
    /// <param name="options">Decompiler settings; null selects defaults.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>Decompiled pseudocode, or null when decompilation is unavailable.</returns>
    Task<string?> DecompileBytesAsync(
        ReadOnlyMemory<byte> bytes,
        string architecture,
        DecompilationOptions? options = null,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Rewrites decompiler output into a normalized form suitable as ML input.
    /// </summary>
    /// <param name="code">Raw decompiler output.</param>
    /// <param name="options">Normalization settings; null selects defaults.</param>
    /// <returns>The normalized code.</returns>
    string Normalize(string code, NormalizationOptions? options = null);
}
|
||||
|
||||
/// <summary>
/// Settings that control a decompilation run.
/// </summary>
public sealed record DecompilationOptions
{
    /// <summary>
    /// Which decompiler back end to invoke.
    /// </summary>
    public DecompilerType Decompiler { get; init; } = DecompilerType.Ghidra;

    /// <summary>
    /// True to ask the decompiler to simplify its output.
    /// </summary>
    public bool Simplify { get; init; } = true;

    /// <summary>
    /// Per-function decompilation deadline.
    /// </summary>
    public TimeSpan Timeout { get; init; } = TimeSpan.FromSeconds(30);

    /// <summary>
    /// Shared instance carrying all default values.
    /// </summary>
    public static DecompilationOptions Default { get; } = new();
}
|
||||
|
||||
/// <summary>
/// Supported decompiler back ends.
/// </summary>
public enum DecompilerType
{
    /// <summary>
    /// NSA Ghidra.
    /// </summary>
    Ghidra,

    /// <summary>
    /// RetDec.
    /// </summary>
    RetDec,

    /// <summary>
    /// Hex-Rays (bundled with IDA Pro).
    /// </summary>
    HexRays
}
|
||||
|
||||
/// <summary>
/// Settings that control normalization of decompiled code.
/// </summary>
public sealed record NormalizationOptions
{
    /// <summary>
    /// True to remove comments from the output.
    /// </summary>
    public bool StripComments { get; init; } = true;

    /// <summary>
    /// True to rename variables to canonical placeholders.
    /// </summary>
    public bool NormalizeVariables { get; init; } = true;

    /// <summary>
    /// True to collapse whitespace into a canonical form.
    /// </summary>
    public bool NormalizeWhitespace { get; init; } = true;

    /// <summary>
    /// True to strip explicit type casts.
    /// </summary>
    public bool RemoveTypeCasts { get; init; } = false;

    /// <summary>
    /// Maximum output length, in characters.
    /// </summary>
    public int MaxLength { get; init; } = 2048;

    /// <summary>
    /// Shared instance carrying all default values.
    /// </summary>
    public static NormalizationOptions Default { get; } = new();
}
|
||||
@@ -0,0 +1,123 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// IFunctionEmbeddingService.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-006 - Embedding Inference Service
|
||||
// Description: Interface for function embedding inference.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Computes and compares dense vector embeddings of functions.
/// </summary>
public interface IFunctionEmbeddingService
{
    /// <summary>
    /// Computes the embedding vector for a single function.
    /// </summary>
    /// <param name="function">The function to embed.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>The embedding vector.</returns>
    Task<float[]> GetEmbeddingAsync(
        FunctionRepresentation function,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Computes embeddings for many functions, batching inference.
    /// </summary>
    /// <param name="functions">Functions to embed.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>One embedding per input, in the same order.</returns>
    Task<IReadOnlyList<float[]>> GetEmbeddingsBatchAsync(
        IReadOnlyList<FunctionRepresentation> functions,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Measures similarity between two embedding vectors.
    /// </summary>
    /// <param name="embedding1">First vector.</param>
    /// <param name="embedding2">Second vector.</param>
    /// <returns>Similarity in the range 0.0 to 1.0.</returns>
    float ComputeSimilarity(float[] embedding1, float[] embedding2);

    /// <summary>
    /// Searches for functions whose embeddings are close to a query vector.
    /// </summary>
    /// <param name="queryEmbedding">Vector to search near.</param>
    /// <param name="topK">Maximum number of matches to return.</param>
    /// <param name="threshold">Minimum similarity for a match.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>Matching functions together with their similarity scores.</returns>
    Task<IReadOnlyList<EmbeddingSimilarityResult>> FindSimilarAsync(
        float[] queryEmbedding,
        int topK = 10,
        float threshold = 0.7f,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Describes the model backing this service.
    /// </summary>
    EmbeddingModelInfo GetModelInfo();
}
|
||||
|
||||
/// <summary>
/// One hit returned by an embedding similarity search.
/// </summary>
public sealed record EmbeddingSimilarityResult
{
    /// <summary>
    /// Identifier of the matched function.
    /// </summary>
    public required string FunctionId { get; init; }

    /// <summary>
    /// Name of the matched function.
    /// </summary>
    public required string FunctionName { get; init; }

    /// <summary>
    /// Library the function belongs to, when known.
    /// </summary>
    public string? LibraryName { get; init; }

    /// <summary>
    /// Version of that library, when known.
    /// </summary>
    public string? LibraryVersion { get; init; }

    /// <summary>
    /// Similarity score against the query embedding.
    /// </summary>
    public required float Similarity { get; init; }
}
|
||||
|
||||
/// <summary>
/// Describes an embedding model instance.
/// </summary>
public sealed record EmbeddingModelInfo
{
    /// <summary>
    /// Model name.
    /// </summary>
    public required string Name { get; init; }

    /// <summary>
    /// Model version string.
    /// </summary>
    public required string Version { get; init; }

    /// <summary>
    /// Number of components in each embedding vector.
    /// </summary>
    public required int Dimension { get; init; }

    /// <summary>
    /// Longest input token sequence the model accepts.
    /// </summary>
    public int MaxSequenceLength { get; init; }

    /// <summary>
    /// True once the model weights are loaded and ready for inference.
    /// </summary>
    public bool IsLoaded { get; init; }
}
|
||||
@@ -0,0 +1,73 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// IIrTokenizer.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-003 - IR Token Extraction
|
||||
// Description: Interface for IR tokenization for ML input.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Converts function IR into token sequences for transformer input.
/// </summary>
public interface IIrTokenizer
{
    /// <summary>
    /// Produces IR tokens for a named function from an indexed library.
    /// </summary>
    /// <param name="libraryName">Name of the library containing the function.</param>
    /// <param name="version">Version of the library.</param>
    /// <param name="functionName">Name of the function to tokenize.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>The ordered IR token sequence.</returns>
    Task<IReadOnlyList<string>> TokenizeAsync(
        string libraryName,
        string version,
        string functionName,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Produces IR tokens directly from raw instruction bytes.
    /// </summary>
    /// <param name="instructions">Machine code to lift and tokenize.</param>
    /// <param name="architecture">Architecture the bytes were compiled for.</param>
    /// <param name="options">Tokenizer settings; null selects defaults.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>The ordered IR token sequence.</returns>
    Task<IReadOnlyList<string>> TokenizeInstructionsAsync(
        ReadOnlyMemory<byte> instructions,
        string architecture,
        TokenizationOptions? options = null,
        CancellationToken cancellationToken = default);
}
|
||||
|
||||
/// <summary>
/// Settings that control IR tokenization.
/// </summary>
public sealed record TokenizationOptions
{
    /// <summary>
    /// Longest token sequence to emit.
    /// </summary>
    public int MaxLength { get; init; } = 512;

    /// <summary>
    /// True to rename variables to canonical placeholders.
    /// </summary>
    public bool NormalizeVariables { get; init; } = true;

    /// <summary>
    /// True to emit operand type tokens.
    /// </summary>
    public bool IncludeOperandTypes { get; init; } = true;

    /// <summary>
    /// True to emit control-flow tokens.
    /// </summary>
    public bool IncludeControlFlow { get; init; } = true;

    /// <summary>
    /// Shared instance carrying all default values.
    /// </summary>
    public static TokenizationOptions Default { get; } = new();
}
|
||||
@@ -0,0 +1,172 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// MlEmbeddingMatcherAdapter.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-007 - Ensemble Integration
|
||||
// Description: Adapter for integrating ML embeddings into validation harness.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Matcher adapter for ML embeddings integration with validation harness.
/// </summary>
public sealed class MlEmbeddingMatcherAdapter
{
    private readonly IFunctionEmbeddingService _embeddingService;
    private readonly ILogger<MlEmbeddingMatcherAdapter> _logger;

    /// <summary>
    /// Gets the default weight for this matcher in the ensemble.
    /// </summary>
    public const double DefaultWeight = 0.25; // 25% per architecture doc

    /// <summary>
    /// Initializes a new instance of the <see cref="MlEmbeddingMatcherAdapter"/> class.
    /// </summary>
    /// <param name="embeddingService">Service used to compute embeddings and similarity.</param>
    /// <param name="logger">Logger for diagnostics.</param>
    public MlEmbeddingMatcherAdapter(
        IFunctionEmbeddingService embeddingService,
        ILogger<MlEmbeddingMatcherAdapter> logger)
    {
        _embeddingService = embeddingService;
        _logger = logger;
    }

    /// <summary>
    /// Computes match score between two functions using ML embeddings.
    /// </summary>
    /// <param name="function1">First function.</param>
    /// <param name="function2">Second function.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Match score (0.0 to 1.0); 0.0 when scoring fails.</returns>
    public async Task<double> ComputeMatchScoreAsync(
        FunctionRepresentation function1,
        FunctionRepresentation function2,
        CancellationToken cancellationToken = default)
    {
        try
        {
            var embedding1 = await _embeddingService.GetEmbeddingAsync(function1, cancellationToken);
            var embedding2 = await _embeddingService.GetEmbeddingAsync(function2, cancellationToken);

            var similarity = _embeddingService.ComputeSimilarity(embedding1, embedding2);

            _logger.LogDebug(
                "ML embedding match score for {Func1} vs {Func2}: {Score:F4}",
                function1.FunctionName,
                function2.FunctionName,
                similarity);

            return similarity;
        }
        catch (OperationCanceledException)
        {
            // BUGFIX: cancellation is not a scoring failure; previously it was
            // swallowed by the generic catch and reported as a score of 0.0.
            throw;
        }
        catch (Exception ex)
        {
            // Degrade gracefully: a failed ML score contributes 0 to the ensemble.
            _logger.LogWarning(ex, "Failed to compute ML embedding score");
            return 0.0;
        }
    }

    /// <summary>
    /// Computes match scores for a batch of function pairs.
    /// </summary>
    /// <param name="pairs">Function pairs to compare.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>One score per input pair, in the same order.</returns>
    public async Task<IReadOnlyList<double>> ComputeMatchScoresBatchAsync(
        IReadOnlyList<(FunctionRepresentation Function1, FunctionRepresentation Function2)> pairs,
        CancellationToken cancellationToken = default)
    {
        // BUGFIX: deduplicate by the stable function key instead of Distinct().
        // Record equality compares the IrTokens collection property by reference,
        // so logically identical functions were embedded more than once and the
        // lookup result depended on insertion order.
        var uniqueFunctions = new Dictionary<string, FunctionRepresentation>();
        foreach (var (f1, f2) in pairs)
        {
            uniqueFunctions.TryAdd(GetFunctionKey(f1), f1);
            uniqueFunctions.TryAdd(GetFunctionKey(f2), f2);
        }

        var keys = uniqueFunctions.Keys.ToList();
        var functions = keys.Select(k => uniqueFunctions[k]).ToList();

        // Get all embeddings in one batched call.
        var embeddings = await _embeddingService.GetEmbeddingsBatchAsync(functions, cancellationToken);

        // Build key -> embedding lookup (keys and embeddings are index-aligned).
        var embeddingLookup = new Dictionary<string, float[]>(keys.Count);
        for (var i = 0; i < keys.Count; i++)
        {
            embeddingLookup[keys[i]] = embeddings[i];
        }

        // Score each pair from the lookup; missing embeddings score 0.0.
        var scores = new List<double>(pairs.Count);
        foreach (var (func1, func2) in pairs)
        {
            if (embeddingLookup.TryGetValue(GetFunctionKey(func1), out var emb1) &&
                embeddingLookup.TryGetValue(GetFunctionKey(func2), out var emb2))
            {
                scores.Add(_embeddingService.ComputeSimilarity(emb1, emb2));
            }
            else
            {
                scores.Add(0.0);
            }
        }

        return scores;
    }

    /// <summary>
    /// Gets ensemble weight configuration.
    /// </summary>
    public EnsembleWeightConfig GetEnsembleConfig() => new()
    {
        InstructionHashWeight = 0.15,
        SemanticGraphWeight = 0.25,
        DecompiledAstWeight = 0.35,
        MlEmbeddingWeight = 0.25
    };

    // Stable identity key: library, version, function name, and architecture.
    private static string GetFunctionKey(FunctionRepresentation function)
    {
        return $"{function.LibraryName}:{function.LibraryVersion}:{function.FunctionName}:{function.Architecture}";
    }
}
|
||||
|
||||
/// <summary>
/// Weights assigned to each matcher in the ensemble.
/// </summary>
public sealed record EnsembleWeightConfig
{
    /// <summary>
    /// Weight of the instruction hash matcher.
    /// </summary>
    public double InstructionHashWeight { get; init; } = 0.15;

    /// <summary>
    /// Weight of the semantic graph matcher.
    /// </summary>
    public double SemanticGraphWeight { get; init; } = 0.25;

    /// <summary>
    /// Weight of the decompiled AST matcher.
    /// </summary>
    public double DecompiledAstWeight { get; init; } = 0.35;

    /// <summary>
    /// Weight of the ML embedding matcher.
    /// </summary>
    public double MlEmbeddingWeight { get; init; } = 0.25;

    /// <summary>
    /// Checks that the four weights sum to 1.0 (within a 0.001 tolerance).
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the sum deviates from 1.0.</exception>
    public void Validate()
    {
        var sum = InstructionHashWeight + SemanticGraphWeight +
                  DecompiledAstWeight + MlEmbeddingWeight;
        if (Math.Abs(sum - 1.0) > 0.001)
        {
            throw new InvalidOperationException(
                $"Ensemble weights must sum to 1.0, but sum is {sum}");
        }
    }
}
|
||||
@@ -0,0 +1,309 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// OnnxFunctionEmbeddingService.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-006 - Embedding Inference Service
|
||||
// Description: ONNX-based function embedding service.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// ONNX-based function embedding service.
/// </summary>
public sealed class OnnxFunctionEmbeddingService : IFunctionEmbeddingService, IDisposable
{
    private readonly OnnxEmbeddingServiceOptions _options;
    private readonly IIrTokenizer _tokenizer;
    private readonly ILogger<OnnxFunctionEmbeddingService> _logger;

    // Embedding cache keyed by "library:version:function:architecture".
    private readonly Dictionary<string, float[]> _embeddingCache = [];

    // Guards _embeddingCache. SemaphoreSlim (not lock) because callers hold it
    // across await points.
    private readonly SemaphoreSlim _cacheLock = new(1, 1);

    private bool _modelLoaded;
    private bool _disposed;

    /// <summary>
    /// Initializes a new instance of the <see cref="OnnxFunctionEmbeddingService"/> class.
    /// </summary>
    /// <param name="options">Service configuration.</param>
    /// <param name="tokenizer">Tokenizer used when a function has no precomputed IR tokens.</param>
    /// <param name="logger">Logger for diagnostics.</param>
    public OnnxFunctionEmbeddingService(
        IOptions<OnnxEmbeddingServiceOptions> options,
        IIrTokenizer tokenizer,
        ILogger<OnnxFunctionEmbeddingService> logger)
    {
        _options = options.Value;
        _tokenizer = tokenizer;
        _logger = logger;
    }

    /// <inheritdoc />
    public async Task<float[]> GetEmbeddingAsync(
        FunctionRepresentation function,
        CancellationToken cancellationToken = default)
    {
        // Fail fast with a clear error instead of an obscure semaphore failure.
        if (_disposed) throw new ObjectDisposedException(nameof(OnnxFunctionEmbeddingService));

        var cacheKey = GetCacheKey(function);

        // Check cache first.
        if (_options.EnableCache)
        {
            await _cacheLock.WaitAsync(cancellationToken);
            try
            {
                if (_embeddingCache.TryGetValue(cacheKey, out var cached))
                {
                    return cached;
                }
            }
            finally
            {
                _cacheLock.Release();
            }
        }

        // Ensure model is loaded before inference.
        await EnsureModelLoadedAsync(cancellationToken);

        // Prefer precomputed IR tokens; otherwise tokenize on demand.
        // BUGFIX: the previous `as List<string> ?? []` silently discarded all
        // tokens whenever the tokenizer returned an IReadOnlyList that was not
        // a List<T> (e.g. an array).
        List<string> tokens;
        if (function.IrTokens is not null)
        {
            tokens = function.IrTokens.ToList();
        }
        else
        {
            var tokenized = await _tokenizer.TokenizeAsync(
                function.LibraryName,
                function.LibraryVersion,
                function.FunctionName,
                cancellationToken);
            tokens = tokenized.ToList();
        }

        // Truncate or pad to the model's fixed sequence length.
        var maxLen = _options.MaxSequenceLength;
        if (tokens.Count > maxLen)
        {
            tokens.RemoveRange(maxLen, tokens.Count - maxLen);
        }
        while (tokens.Count < maxLen)
        {
            tokens.Add("[PAD]");
        }

        // Tokenize to IDs (simplified - would use actual vocabulary).
        var inputIds = tokens.Select(TokenToId).ToArray();

        // Run inference.
        var embedding = await RunInferenceAsync(inputIds, cancellationToken);

        // Cache result, evicting an arbitrary entry once over capacity.
        if (_options.EnableCache)
        {
            await _cacheLock.WaitAsync(cancellationToken);
            try
            {
                _embeddingCache[cacheKey] = embedding;

                if (_embeddingCache.Count > _options.MaxCacheSize)
                {
                    var toRemove = _embeddingCache.Keys.First();
                    _embeddingCache.Remove(toRemove);
                }
            }
            finally
            {
                _cacheLock.Release();
            }
        }

        return embedding;
    }

    /// <inheritdoc />
    public async Task<IReadOnlyList<float[]>> GetEmbeddingsBatchAsync(
        IReadOnlyList<FunctionRepresentation> functions,
        CancellationToken cancellationToken = default)
    {
        var results = new List<float[]>(functions.Count);

        // Process in fixed-size batches to bound concurrent work.
        var batchSize = _options.BatchSize;
        for (var i = 0; i < functions.Count; i += batchSize)
        {
            var batch = functions.Skip(i).Take(batchSize);
            var batchResults = await Task.WhenAll(
                batch.Select(f => GetEmbeddingAsync(f, cancellationToken)));
            results.AddRange(batchResults);
        }

        return results;
    }

    /// <inheritdoc />
    public float ComputeSimilarity(float[] embedding1, float[] embedding2)
    {
        if (embedding1.Length != embedding2.Length)
        {
            throw new ArgumentException("Embeddings must have same dimension");
        }

        // Cosine similarity; zero vectors yield 0 rather than dividing by zero.
        var dot = Dot(embedding1, embedding2);
        var norm1 = MathF.Sqrt(Dot(embedding1, embedding1));
        var norm2 = MathF.Sqrt(Dot(embedding2, embedding2));

        if (norm1 == 0 || norm2 == 0) return 0;

        return dot / (norm1 * norm2);
    }

    // Dot product of two equal-length vectors.
    private static float Dot(float[] a, float[] b)
    {
        float sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /// <inheritdoc />
    /// <remarks>Searches only embeddings currently held in the in-memory cache.</remarks>
    public async Task<IReadOnlyList<EmbeddingSimilarityResult>> FindSimilarAsync(
        float[] queryEmbedding,
        int topK = 10,
        float threshold = 0.7f,
        CancellationToken cancellationToken = default)
    {
        var results = new List<EmbeddingSimilarityResult>();

        await _cacheLock.WaitAsync(cancellationToken);
        try
        {
            foreach (var (key, embedding) in _embeddingCache)
            {
                var similarity = ComputeSimilarity(queryEmbedding, embedding);
                if (similarity >= threshold)
                {
                    // Cache keys are "library:version:function:architecture".
                    var parts = key.Split(':');
                    results.Add(new EmbeddingSimilarityResult
                    {
                        FunctionId = key,
                        FunctionName = parts.Length > 2 ? parts[2] : key,
                        LibraryName = parts.Length > 0 ? parts[0] : null,
                        LibraryVersion = parts.Length > 1 ? parts[1] : null,
                        Similarity = similarity
                    });
                }
            }
        }
        finally
        {
            _cacheLock.Release();
        }

        return results
            .OrderByDescending(r => r.Similarity)
            .Take(topK)
            .ToList();
    }

    /// <inheritdoc />
    public EmbeddingModelInfo GetModelInfo()
    {
        return new EmbeddingModelInfo
        {
            Name = _options.ModelName,
            Version = _options.ModelVersion,
            Dimension = _options.EmbeddingDimension,
            MaxSequenceLength = _options.MaxSequenceLength,
            IsLoaded = _modelLoaded
        };
    }

    // Loads the ONNX model once; a missing path degrades to placeholder output.
    private Task EnsureModelLoadedAsync(CancellationToken ct)
    {
        if (_modelLoaded) return Task.CompletedTask;

        if (string.IsNullOrEmpty(_options.ModelPath))
        {
            _logger.LogWarning("ONNX model path not configured, using placeholder embeddings");
            return Task.CompletedTask;
        }

        _logger.LogInformation("Loading ONNX model from {Path}", _options.ModelPath);
        // Model loading would happen here - for now mark as loaded.
        _modelLoaded = true;
        return Task.CompletedTask;
    }

    // Placeholder inference producing a deterministic pseudo-embedding.
    private Task<float[]> RunInferenceAsync(long[] inputIds, CancellationToken ct)
    {
        // BUGFIX: previously the RNG was seeded with inputIds.GetHashCode(),
        // a reference hash, so identical token sequences produced different
        // "deterministic" embeddings on every call. Seed from the content.
        var seed = 17;
        unchecked
        {
            foreach (var id in inputIds)
            {
                seed = seed * 31 + (int)(id ^ (id >> 32));
            }
        }

        var rng = new Random(seed);
        var embedding = new float[_options.EmbeddingDimension];
        for (var i = 0; i < embedding.Length; i++)
        {
            embedding[i] = (float)(rng.NextDouble() * 2 - 1);
        }
        return Task.FromResult(embedding);
    }

    // Maps a token to a non-negative ID (simplified - would use actual vocabulary).
    private static long TokenToId(string token)
    {
        // BUGFIX: string.GetHashCode() is randomized per process on .NET Core,
        // so IDs (and thus placeholder embeddings) were not reproducible across
        // runs. Use a stable FNV-1a hash instead.
        unchecked
        {
            var hash = 2166136261u;
            foreach (var c in token)
            {
                hash = (hash ^ c) * 16777619u;
            }
            return hash & 0x7FFFFFFF;
        }
    }

    // Stable cache key. BUGFIX: Architecture is included so the same function
    // built for two architectures cannot share a cached embedding.
    private static string GetCacheKey(FunctionRepresentation function)
    {
        return $"{function.LibraryName}:{function.LibraryVersion}:{function.FunctionName}:{function.Architecture}";
    }

    /// <inheritdoc />
    public void Dispose()
    {
        if (_disposed) return;
        _disposed = true;
        _cacheLock.Dispose();
    }
}
|
||||
|
||||
/// <summary>
/// Configuration for the ONNX embedding service.
/// </summary>
public sealed record OnnxEmbeddingServiceOptions
{
    /// <summary>
    /// Filesystem path to the ONNX model; null falls back to placeholder embeddings.
    /// </summary>
    public string? ModelPath { get; init; }

    /// <summary>
    /// Human-readable model name.
    /// </summary>
    public string ModelName { get; init; } = "function-embeddings";

    /// <summary>
    /// Model version string.
    /// </summary>
    public string ModelVersion { get; init; } = "1.0";

    /// <summary>
    /// Number of components in each embedding vector.
    /// </summary>
    public int EmbeddingDimension { get; init; } = 768;

    /// <summary>
    /// Longest input token sequence the model accepts.
    /// </summary>
    public int MaxSequenceLength { get; init; } = 512;

    /// <summary>
    /// Number of functions embedded per inference batch.
    /// </summary>
    public int BatchSize { get; init; } = 16;

    /// <summary>
    /// True to cache computed embeddings in memory.
    /// </summary>
    public bool EnableCache { get; init; } = true;

    /// <summary>
    /// Upper bound on cached embedding count before eviction.
    /// </summary>
    public int MaxCacheSize { get; init; } = 10000;
}
|
||||
@@ -0,0 +1,299 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// TrainingCorpusModels.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-001 - Training Corpus Schema
|
||||
// Description: Schema definitions for ML training corpus.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// A labeled pair of functions used as one ML training example.
/// </summary>
public sealed record TrainingFunctionPair
{
    /// <summary>
    /// Unique identifier for this pair.
    /// </summary>
    public required string PairId { get; init; }

    /// <summary>
    /// First function of the pair.
    /// </summary>
    public required FunctionRepresentation Function1 { get; init; }

    /// <summary>
    /// Second function of the pair.
    /// </summary>
    public required FunctionRepresentation Function2 { get; init; }

    /// <summary>
    /// Ground-truth equivalence label for the pair.
    /// </summary>
    public required EquivalenceLabel Label { get; init; }

    /// <summary>
    /// Confidence in the label, from 0.0 to 1.0.
    /// </summary>
    public double Confidence { get; init; } = 1.0;

    /// <summary>
    /// Where the ground-truth label came from.
    /// </summary>
    public required string Source { get; init; }

    /// <summary>
    /// Optional extra information about the pair.
    /// </summary>
    public TrainingPairMetadata? Metadata { get; init; }
}
|
||||
|
||||
/// <summary>
/// Multi-modal representation of a single function for training.
/// </summary>
public sealed record FunctionRepresentation
{
    /// <summary>
    /// Name of the library the function came from.
    /// </summary>
    public required string LibraryName { get; init; }

    /// <summary>
    /// Version of that library.
    /// </summary>
    public required string LibraryVersion { get; init; }

    /// <summary>
    /// Name of the function.
    /// </summary>
    public required string FunctionName { get; init; }

    /// <summary>
    /// Architecture the function was compiled for.
    /// </summary>
    public required string Architecture { get; init; }

    /// <summary>
    /// IR token sequence for transformer input, when available.
    /// </summary>
    public IReadOnlyList<string>? IrTokens { get; init; }

    /// <summary>
    /// Decompiled pseudocode, when available.
    /// </summary>
    public string? DecompiledCode { get; init; }

    /// <summary>
    /// Precomputed fingerprints, when available.
    /// </summary>
    public FunctionFingerprints? Fingerprints { get; init; }

    /// <summary>
    /// Function size in bytes, when known.
    /// </summary>
    public int? SizeBytes { get; init; }

    /// <summary>
    /// Basic block count, when known.
    /// </summary>
    public int? BasicBlockCount { get; init; }

    /// <summary>
    /// Cyclomatic complexity, when known.
    /// </summary>
    public int? CyclomaticComplexity { get; init; }
}
|
||||
|
||||
/// <summary>
/// Hash-based fingerprints describing a function.
/// </summary>
public sealed record FunctionFingerprints
{
    /// <summary>
    /// Hash over the instruction stream, when computed.
    /// </summary>
    public string? InstructionHash { get; init; }

    /// <summary>
    /// Hash over the control-flow graph, when computed.
    /// </summary>
    public string? CfgHash { get; init; }

    /// <summary>
    /// Hash over the call graph, when computed.
    /// </summary>
    public string? CallGraphHash { get; init; }

    /// <summary>
    /// Instruction mnemonic counts, when computed.
    /// </summary>
    public IReadOnlyDictionary<string, int>? MnemonicHistogram { get; init; }
}
|
||||
|
||||
/// <summary>
/// Ground-truth label describing whether two functions share semantics.
/// Serialized as a string (not an ordinal) for corpus stability.
/// </summary>
[JsonConverter(typeof(JsonStringEnumConverter))]
public enum EquivalenceLabel
{
    /// <summary>
    /// The two functions have the same semantics.
    /// </summary>
    Equivalent,

    /// <summary>
    /// The two functions have different semantics.
    /// </summary>
    Different,

    /// <summary>
    /// The relationship between the functions is unknown or uncertain.
    /// </summary>
    Unknown
}
|
||||
|
||||
/// <summary>
/// Provenance and labeling metadata attached to a training pair.
/// </summary>
public sealed record TrainingPairMetadata
{
    /// <summary>
    /// Gets the CVE identifier when the pair originates from a security patch.
    /// </summary>
    public string? CveId { get; init; }

    /// <summary>
    /// Gets the kind of patch that produced the pair, if any.
    /// </summary>
    public string? PatchType { get; init; }

    /// <summary>
    /// Gets a value indicating whether the function has the patch applied.
    /// </summary>
    public bool IsPatched { get; init; }

    /// <summary>
    /// Gets the distribution the binaries were sourced from, if any.
    /// </summary>
    public string? Distribution { get; init; }

    /// <summary>
    /// Gets free-form tags attached to the pair, if any.
    /// </summary>
    public IReadOnlyList<string>? Tags { get; init; }
}
|
||||
|
||||
/// <summary>
/// A versioned corpus of labeled function pairs, optionally pre-split into
/// train/validation/test partitions.
/// </summary>
public sealed record TrainingCorpus
{
    /// <summary>
    /// Gets the version string of this corpus.
    /// </summary>
    public required string Version { get; init; }

    /// <summary>
    /// Gets the timestamp at which the corpus was built.
    /// </summary>
    public required DateTimeOffset CreatedAt { get; init; }

    /// <summary>
    /// Gets a human-readable description of the corpus, if any.
    /// </summary>
    public string? Description { get; init; }

    /// <summary>
    /// Gets the pairs used for training.
    /// </summary>
    public required IReadOnlyList<TrainingFunctionPair> TrainingPairs { get; init; }

    /// <summary>
    /// Gets the pairs held out for validation, if split.
    /// </summary>
    public IReadOnlyList<TrainingFunctionPair>? ValidationPairs { get; init; }

    /// <summary>
    /// Gets the pairs held out for testing, if split.
    /// </summary>
    public IReadOnlyList<TrainingFunctionPair>? TestPairs { get; init; }

    /// <summary>
    /// Gets summary statistics for the corpus, if computed.
    /// </summary>
    public CorpusStatistics? Statistics { get; init; }
}
|
||||
|
||||
/// <summary>
/// Aggregate counts describing the composition of a training corpus.
/// </summary>
public sealed record CorpusStatistics
{
    /// <summary>
    /// Gets the total number of pairs in the corpus.
    /// </summary>
    public int TotalPairs { get; init; }

    /// <summary>
    /// Gets the number of pairs labeled equivalent.
    /// </summary>
    public int EquivalentPairs { get; init; }

    /// <summary>
    /// Gets the number of pairs labeled different.
    /// </summary>
    public int DifferentPairs { get; init; }

    /// <summary>
    /// Gets the number of pairs with an unknown label.
    /// </summary>
    public int UnknownPairs { get; init; }

    /// <summary>
    /// Gets the number of distinct libraries represented.
    /// </summary>
    public int UniqueLibraries { get; init; }

    /// <summary>
    /// Gets the number of distinct functions represented.
    /// </summary>
    public int UniqueFunctions { get; init; }

    /// <summary>
    /// Gets the list of architectures covered, if recorded.
    /// </summary>
    public IReadOnlyList<string>? Architectures { get; init; }
}
|
||||
|
||||
/// <summary>
/// Settings controlling how a corpus is partitioned into train, validation,
/// and test sets. The three ratios default to an 80/10/10 split.
/// </summary>
public sealed record CorpusSplitConfig
{
    /// <summary>
    /// Gets the fraction of pairs assigned to the training set (default 0.8).
    /// </summary>
    public double TrainRatio { get; init; } = 0.8;

    /// <summary>
    /// Gets the fraction of pairs assigned to the validation set (default 0.1).
    /// </summary>
    public double ValidationRatio { get; init; } = 0.1;

    /// <summary>
    /// Gets the fraction of pairs assigned to the test set (default 0.1).
    /// </summary>
    public double TestRatio { get; init; } = 0.1;

    /// <summary>
    /// Gets the seed used to make the split reproducible (default 42).
    /// </summary>
    public int? RandomSeed { get; init; } = 42;

    /// <summary>
    /// Gets a value indicating whether pairs are stratified by library so a
    /// library does not leak across splits (default true).
    /// </summary>
    public bool StratifyByLibrary { get; init; } = true;
}
|
||||
@@ -0,0 +1,83 @@
|
||||
// -----------------------------------------------------------------------------
|
||||
// TrainingServiceCollectionExtensions.cs
|
||||
// Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
// Task: MLEM-007, MLEM-009 - DI Registration
|
||||
// Description: Dependency injection extensions for ML training services.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace StellaOps.BinaryIndex.ML.Training;
|
||||
|
||||
/// <summary>
/// Extension methods for registering ML training services.
/// </summary>
public static class TrainingServiceCollectionExtensions
{
    /// <summary>
    /// Adds ML training corpus services: tokenizer, decompiler adapter,
    /// corpus builder, embedding service, and matcher adapter.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="configureOptions">Optional action used to supply
    /// <see cref="MlTrainingOptions"/>; when omitted, defaults are used.</param>
    /// <returns>The service collection for chaining.</returns>
    public static IServiceCollection AddMlTrainingCorpus(
        this IServiceCollection services,
        Action<MlTrainingOptions>? configureOptions = null)
    {
        ArgumentNullException.ThrowIfNull(services);

        // Register default option pipelines so IOptions<T> always resolves.
        services.AddOptions<GhidraAdapterOptions>();
        services.AddOptions<OnnxEmbeddingServiceOptions>();

        if (configureOptions is not null)
        {
            var options = new MlTrainingOptions();
            configureOptions(options);

            // BUGFIX: the previous code used Configure<T>(o => { o = ...; }),
            // which only rebinds the lambda's local parameter and never touches
            // the registered options instance — caller-supplied options were
            // silently discarded. Register the supplied instances directly so
            // IOptions<T> resolves to them.
            services.AddSingleton<IOptions<GhidraAdapterOptions>>(
                Options.Create(options.GhidraOptions ?? new GhidraAdapterOptions()));
            services.AddSingleton<IOptions<OnnxEmbeddingServiceOptions>>(
                Options.Create(options.OnnxOptions ?? new OnnxEmbeddingServiceOptions()));
        }

        // Tokenizer and decompiler adapters.
        services.AddSingleton<IIrTokenizer, B2R2IrTokenizer>();
        services.AddSingleton<IDecompilerAdapter, GhidraDecompilerAdapter>();

        // Ground-truth corpus builder.
        services.AddSingleton<ICorpusBuilder, GroundTruthCorpusBuilder>();

        // ONNX-backed embedding service.
        services.AddSingleton<IFunctionEmbeddingService, OnnxFunctionEmbeddingService>();

        // Adapter that plugs embeddings into the matching pipeline.
        services.AddSingleton<MlEmbeddingMatcherAdapter>();

        return services;
    }
}
|
||||
|
||||
/// <summary>
/// Aggregated configuration for the ML training infrastructure, grouping the
/// per-component option objects consumed by <see cref="TrainingServiceCollectionExtensions"/>.
/// </summary>
public sealed record MlTrainingOptions
{
    /// <summary>
    /// Gets or sets the Ghidra decompiler adapter options.
    /// </summary>
    public GhidraAdapterOptions? GhidraOptions { get; set; }

    /// <summary>
    /// Gets or sets the ONNX embedding service options.
    /// </summary>
    public OnnxEmbeddingServiceOptions? OnnxOptions { get; set; }

    /// <summary>
    /// Gets or sets the corpus build options.
    /// </summary>
    public CorpusBuildOptions? CorpusBuildOptions { get; set; }
}
|
||||
@@ -0,0 +1,450 @@
|
||||
#!/usr/bin/env python3
|
||||
# -----------------------------------------------------------------------------
|
||||
# train_function_embeddings.py
|
||||
# Sprint: SPRINT_20260119_006 ML Embeddings Corpus
|
||||
# Task: MLEM-005 - Embedding Model Training Pipeline
|
||||
# Description: PyTorch/HuggingFace training script for contrastive learning.
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Function Embedding Training Pipeline
|
||||
|
||||
Uses contrastive learning to train CodeBERT-based function embeddings.
|
||||
Positive pairs: Same function across versions
|
||||
Negative pairs: Different functions
|
||||
|
||||
Usage:
|
||||
python train_function_embeddings.py --corpus datasets/training_corpus.jsonl \
|
||||
--output models/function_embeddings.onnx \
|
||||
--epochs 10 --batch-size 32
|
||||
|
||||
Requirements:
|
||||
pip install torch transformers onnx onnxruntime tensorboard
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
try:
|
||||
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
print("Please install transformers: pip install transformers")
|
||||
raise
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TrainingConfig:
    """Training configuration.

    Defaults mirror the CLI defaults in main(). Note that `device` is
    resolved once, at class-definition (import) time, from CUDA availability.
    """
    model_name: str = "microsoft/codebert-base"  # HuggingFace base encoder
    corpus_path: str = "datasets/training_corpus.jsonl"  # JSONL of labeled pairs
    output_path: str = "models/function_embeddings"  # checkpoint/ONNX path prefix

    # Training params
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 2e-5
    warmup_steps: int = 500  # linear warmup steps for the LR scheduler
    weight_decay: float = 0.01

    # Contrastive learning params
    temperature: float = 0.07  # divisor applied to cosine similarity in the loss
    margin: float = 0.5  # negative pairs penalized above this (scaled) similarity

    # Model params
    embedding_dim: int = 768  # projection output size (matches CodeBERT hidden size)
    max_seq_length: int = 512  # tokenizer truncation/padding length

    # Misc
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    log_dir: str = "runs/function_embeddings"  # TensorBoard run directory
|
||||
|
||||
|
||||
class FunctionPairDataset(Dataset):
    """Dataset of labeled function pairs for contrastive learning.

    Each JSONL line in the corpus is one pair object with keys
    "function1", "function2", and "label". Text for each side is taken from
    "decompiledCode" when present, otherwise from the space-joined
    "irTokens" list.
    """

    def __init__(self, corpus_path: str, tokenizer, max_length: int = 512):
        """Load the whole JSONL corpus into memory.

        Args:
            corpus_path: Path to the JSONL corpus file.
            tokenizer: A HuggingFace-style tokenizer callable.
            max_length: Truncation/padding length for tokenization.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pairs = []

        logger.info(f"Loading corpus from {corpus_path}")
        # FIX: open with explicit UTF-8 — JSONL corpora are UTF-8, and the
        # platform default encoding (e.g. cp1252 on Windows) could corrupt or
        # reject non-ASCII content.
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # skip blank lines
                    pair = json.loads(line)
                    self.pairs.append(pair)

        logger.info(f"Loaded {len(self.pairs)} pairs")

    def __len__(self) -> int:
        """Return the number of pairs in the corpus."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Tokenize pair `idx` and return model-ready tensors plus the label."""
        pair = self.pairs[idx]

        # Get function representations (tolerate missing sides).
        func1 = pair.get("function1", {})
        func2 = pair.get("function2", {})

        # Prefer decompiled code, fall back to IR tokens.
        text1 = func1.get("decompiledCode") or " ".join(func1.get("irTokens", []))
        text2 = func2.get("decompiledCode") or " ".join(func2.get("irTokens", []))

        # Tokenize both sides to fixed-length padded tensors.
        enc1 = self.tokenizer(
            text1,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        enc2 = self.tokenizer(
            text2,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        # Label: 1 for equivalent, 0 for anything else (different/unknown).
        label = 1.0 if pair.get("label") == "equivalent" else 0.0

        # squeeze(0) drops the batch dim added by return_tensors="pt".
        return {
            "input_ids_1": enc1["input_ids"].squeeze(0),
            "attention_mask_1": enc1["attention_mask"].squeeze(0),
            "input_ids_2": enc2["input_ids"].squeeze(0),
            "attention_mask_2": enc2["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.float)
        }
|
||||
|
||||
|
||||
class FunctionEmbeddingModel(nn.Module):
    """CodeBERT-based function embedding model.

    Wraps a pretrained HuggingFace encoder with a two-layer projection head
    and returns L2-normalized embeddings, so cosine similarity between two
    outputs equals their dot product.
    """

    def __init__(self, model_name: str, embedding_dim: int = 768):
        super().__init__()
        # Downloads/loads the pretrained encoder weights by name.
        self.encoder = AutoModel.from_pretrained(model_name)
        self.embedding_dim = embedding_dim

        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Compute an L2-normalized function embedding from token ids."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Use [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]

        # Project to embedding space
        embedding = self.projection(cls_output)

        # L2 normalize
        embedding = F.normalize(embedding, p=2, dim=1)

        return embedding

    def get_embedding(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Inference helper: same computation as forward(), under no_grad.

        NOTE(review): the original docstring said "without projection", but
        the code below DOES apply the projection head — the only difference
        from forward() is that gradient tracking is disabled. Confirm which
        behavior was intended.
        """
        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
            cls_output = outputs.last_hidden_state[:, 0, :]
            embedding = self.projection(cls_output)
            return F.normalize(embedding, p=2, dim=1)
|
||||
|
||||
|
||||
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over temperature-scaled cosine similarity."""

    def __init__(self, temperature: float = 0.07, margin: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.margin = margin

    def forward(
        self,
        embedding1: torch.Tensor,
        embedding2: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Compute the mean contrastive loss for a batch of pairs.

        Args:
            embedding1: First function embeddings [B, D]
            embedding2: Second function embeddings [B, D]
            labels: 1 for positive pairs, 0 for negative [B]

        Returns:
            Scalar loss tensor.
        """
        # Cosine similarity scaled by temperature.
        # NOTE(review): the margin below is compared against this *scaled*
        # similarity (range roughly +/- 1/temperature), not the raw cosine;
        # confirm that is intended — margin-on-raw-cosine is the usual form.
        scaled_sim = F.cosine_similarity(embedding1, embedding2) / self.temperature

        # Positive pairs are pulled together (penalize low similarity);
        # negative pairs are pushed apart until the margin is reached.
        positive_term = labels * (1 - scaled_sim)
        negative_term = (1 - labels) * (scaled_sim - self.margin).clamp(min=0)

        return (positive_term + negative_term).mean()
|
||||
|
||||
|
||||
def train_epoch(
    model: FunctionEmbeddingModel,
    dataloader: DataLoader,
    criterion: ContrastiveLoss,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    device: str,
    epoch: int,
    writer: SummaryWriter
) -> float:
    """Run one optimization epoch and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0

    for step, batch in enumerate(dataloader):
        # Move the pair tensors and labels onto the training device.
        ids_a = batch["input_ids_1"].to(device)
        mask_a = batch["attention_mask_1"].to(device)
        ids_b = batch["input_ids_2"].to(device)
        mask_b = batch["attention_mask_2"].to(device)
        targets = batch["label"].to(device)

        # Embed both sides of each pair and score them against the labels.
        emb_a = model(ids_a, mask_a)
        emb_b = model(ids_b, mask_b)
        loss = criterion(emb_a, emb_b, targets)

        # Standard backward pass with gradient clipping for stability.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Per-step LR schedule (linear warmup/decay), when configured.
        if scheduler is not None:
            scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Per-step TensorBoard logging at a globally increasing step index.
        writer.add_scalar("train/loss", batch_loss, epoch * len(dataloader) + step)

        if step % 100 == 0:
            logger.info(f"Epoch {epoch}, Batch {step}/{len(dataloader)}, Loss: {batch_loss:.4f}")

    return running_loss / len(dataloader)
|
||||
|
||||
|
||||
def evaluate(
    model: FunctionEmbeddingModel,
    dataloader: DataLoader,
    criterion: ContrastiveLoss,
    device: str
) -> Tuple[float, float]:
    """Score the model on a held-out split.

    Returns:
        (mean batch loss, pair accuracy), where a pair is predicted
        "equivalent" when cosine similarity exceeds 0.5.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch in dataloader:
            ids_a = batch["input_ids_1"].to(device)
            mask_a = batch["attention_mask_1"].to(device)
            ids_b = batch["input_ids_2"].to(device)
            mask_b = batch["attention_mask_2"].to(device)
            targets = batch["label"].to(device)

            emb_a = model(ids_a, mask_a)
            emb_b = model(ids_b, mask_b)

            loss_sum += criterion(emb_a, emb_b, targets).item()

            # Threshold cosine similarity at 0.5 to get hard predictions.
            preds = (F.cosine_similarity(emb_a, emb_b) > 0.5).float()
            hits += (preds == targets).sum().item()
            seen += targets.size(0)

    mean_loss = loss_sum / len(dataloader)
    accuracy = hits / seen if seen > 0 else 0.0
    return mean_loss, accuracy
|
||||
|
||||
|
||||
def export_onnx(
    model: FunctionEmbeddingModel,
    output_path: str,
    max_seq_length: int = 512
):
    """Export the model to ONNX at f"{output_path}.onnx".

    Args:
        model: The trained embedding model (exported via its forward()).
        output_path: Path prefix; ".onnx" is appended.
        max_seq_length: Sequence length used for the dummy trace inputs.
    """
    model.eval()

    # Dummy inputs (token id 1 everywhere) used only to trace the graph.
    dummy_input_ids = torch.ones(1, max_seq_length, dtype=torch.long)
    dummy_attention_mask = torch.ones(1, max_seq_length, dtype=torch.long)

    output_file = f"{output_path}.onnx"

    # FIX: ensure the destination directory exists — torch.onnx.export does
    # not create parent directories, and callers may pass a fresh path.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    logger.info(f"Exporting model to {output_file}")

    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask),
        output_file,
        input_names=["input_ids", "attention_mask"],
        output_names=["embedding"],
        # Batch dimension stays dynamic so inference can batch freely.
        dynamic_axes={
            "input_ids": {0: "batch_size"},
            "attention_mask": {0: "batch_size"},
            "embedding": {0: "batch_size"}
        },
        opset_version=14
    )

    logger.info(f"Model exported to {output_file}")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse args, train with contrastive loss, checkpoint
    the best model by validation loss, and export the final model to ONNX."""
    parser = argparse.ArgumentParser(description="Train function embedding model")
    parser.add_argument("--corpus", type=str, default="datasets/training_corpus.jsonl",
                        help="Path to training corpus (JSONL format)")
    parser.add_argument("--output", type=str, default="models/function_embeddings",
                        help="Output path for model")
    parser.add_argument("--model-name", type=str, default="microsoft/codebert-base",
                        help="Base model name")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Config: CLI args override the dataclass defaults.
    config = TrainingConfig(
        model_name=args.model_name,
        corpus_path=args.corpus,
        output_path=args.output,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        seed=args.seed
    )

    # Set seed for python/torch/CUDA reproducibility.
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    logger.info(f"Using device: {config.device}")

    # Load tokenizer
    logger.info(f"Loading tokenizer: {config.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Create dataset (loads the whole JSONL corpus into memory).
    dataset = FunctionPairDataset(config.corpus_path, tokenizer, config.max_seq_length)

    # Split into train/val (90/10).
    # NOTE(review): random_split is not given an explicit torch.Generator, so
    # the split relies on the global torch seed set above — confirm intended.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    # Create model
    logger.info(f"Creating model: {config.model_name}")
    model = FunctionEmbeddingModel(config.model_name, config.embedding_dim)
    model.to(config.device)

    # Loss and optimizer
    criterion = ContrastiveLoss(config.temperature, config.margin)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Linear warmup then linear decay over the full training run.
    total_steps = len(train_loader) * config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )

    # TensorBoard
    writer = SummaryWriter(config.log_dir)

    # Training loop: track the best validation loss for checkpointing.
    best_val_loss = float('inf')

    for epoch in range(config.epochs):
        logger.info(f"=== Epoch {epoch + 1}/{config.epochs} ===")

        train_loss = train_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            config.device, epoch, writer
        )

        val_loss, val_accuracy = evaluate(model, val_loader, criterion, config.device)

        logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

        writer.add_scalar("val/loss", val_loss, epoch)
        writer.add_scalar("val/accuracy", val_accuracy, epoch)

        # Save best model (checkpoint only when validation loss improves).
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            os.makedirs(config.output_path, exist_ok=True)

            # Save PyTorch model checkpoint with optimizer state for resume.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_accuracy': val_accuracy
            }, f"{config.output_path}/best_model.pt")

            logger.info(f"Saved best model with val_loss: {val_loss:.4f}")

    # Export to ONNX.
    # NOTE(review): this exports the FINAL-epoch weights, not the best
    # checkpoint saved above — confirm this is intended.
    export_onnx(model, config.output_path, config.max_seq_length)

    writer.close()
    logger.info("Training complete!")
|
||||
|
||||
|
||||
# Allow running directly as a script: python train_function_embeddings.py ...
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user