// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Time.Testing;
using StellaOps.BinaryIndex.Decompiler;
using StellaOps.BinaryIndex.ML;
using StellaOps.BinaryIndex.Semantic;
using Xunit;
#pragma warning disable CS8625 // Suppress nullable warnings for test code
#pragma warning disable CA1707 // Identifiers should not contain underscores
namespace StellaOps.BinaryIndex.Ensemble.Tests.Integration;
/// <summary>
/// Integration tests for the full semantic diffing pipeline.
/// These tests wire up real implementations to verify end-to-end functionality.
/// </summary>
[Trait("Category", "Integration")]
public class SemanticDiffingPipelineTests : IAsyncDisposable
{
// Container built in the constructor; every test resolves its services from it and DisposeAsync disposes it.
private readonly ServiceProvider _serviceProvider;
// Fake clock created in the constructor (2026-01-05 12:00:00 UTC) and registered with the container as a singleton instance.
private readonly FakeTimeProvider _timeProvider;
/// <summary>
/// Builds the service provider shared by the tests in this class: debug logging,
/// a fake clock pinned to 2026-01-05 12:00:00 UTC, and the binary similarity services.
/// </summary>
public SemanticDiffingPipelineTests()
{
    var fixedStart = new DateTimeOffset(2026, 1, 5, 12, 0, 0, TimeSpan.Zero);
    _timeProvider = new FakeTimeProvider(fixedStart);

    var services = new ServiceCollection();

    // Logging: debug sink, Debug level and above.
    services.AddLogging(logging =>
    {
        logging.AddDebug().SetMinimumLevel(LogLevel.Debug);
    });

    // NOTE(review): no explicit service type is given, so the instance is registered under
    // its compile-time type (FakeTimeProvider) — confirm this is the type consumers request.
    services.AddSingleton(_timeProvider);

    // Registered last, after logging and the clock.
    services.AddBinarySimilarityServices();

    _serviceProvider = services.BuildServiceProvider();
}
/// <summary>
/// Asynchronously disposes the service provider built in the constructor.
/// </summary>
/// <returns>A task-like value that completes once the provider has been disposed.</returns>
public async ValueTask DisposeAsync()
{
    ValueTask providerDisposal = _serviceProvider.DisposeAsync();
    await providerDisposal;

    // Dispose-pattern call; the class declares no finalizer of its own in this file.
    GC.SuppressFinalize(this);
}
/// <summary>
/// Compares two analyses built from the same source text. Both share one parsed AST and
/// one embedding and carry a hash of the same text; the test expects an ensemble score of
/// at least 0.5 and an exact hash match.
/// </summary>
[Fact]
public async Task Pipeline_WithIdenticalCode_ReturnsHighSimilarity()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();
    var parser = _serviceProvider.GetRequiredService();
    var embeddingService = _serviceProvider.GetRequiredService();

    var code = """
        int calculate_sum(int* arr, int len) {
        int sum = 0;
        for (int i = 0; i < len; i++) {
        sum += arr[i];
        }
        return sum;
        }
        """;

    var sharedAst = parser.Parse(code);
    var sharedEmbedding = await embeddingService.GenerateEmbeddingAsync(
        new EmbeddingInput(code, null, null, EmbeddingInputType.DecompiledCode));

    // Builds one side of the comparison. The hash is recomputed on every call, so the two
    // analyses hold equal but separately allocated hash arrays.
    FunctionAnalysis BuildAnalysis(string functionId)
    {
        return new FunctionAnalysis
        {
            FunctionId = functionId,
            FunctionName = "calculate_sum",
            DecompiledCode = code,
            NormalizedCodeHash = System.Security.Cryptography.SHA256.HashData(
                System.Text.Encoding.UTF8.GetBytes(code)),
            Ast = sharedAst,
            Embedding = sharedEmbedding
        };
    }

    var source = BuildAnalysis("func1");
    var target = BuildAnalysis("func2");

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    // Same AST, same embedding and equal hashes on both sides.
    var score = result.EnsembleScore;
    Assert.True(
        score >= 0.5m,
        $"Expected high similarity for identical code with AST/embedding, got {score}");
    Assert.True(result.ExactHashMatch);
}
/// <summary>
/// Compares two summation loops that differ in identifier names and in how the accumulator
/// is updated, and checks that the result reports at least one available signal.
/// </summary>
[Fact]
public async Task Pipeline_WithSimilarCode_ReturnsModeratelySimilarity()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();
    var parser = _serviceProvider.GetRequiredService();
    var embeddingService = _serviceProvider.GetRequiredService();

    var sumLoop = """
        int calculate_sum(int* arr, int len) {
        int sum = 0;
        for (int i = 0; i < len; i++) {
        sum += arr[i];
        }
        return sum;
        }
        """;
    var totalLoop = """
        int compute_total(int* data, int count) {
        int total = 0;
        for (int j = 0; j < count; j++) {
        total = total + data[j];
        }
        return total;
        }
        """;

    // Parse both snippets first, then embed both, one after the other.
    var sumAst = parser.Parse(sumLoop);
    var totalAst = parser.Parse(totalLoop);
    var sumEmbedding = await embeddingService.GenerateEmbeddingAsync(
        new EmbeddingInput(sumLoop, null, null, EmbeddingInputType.DecompiledCode));
    var totalEmbedding = await embeddingService.GenerateEmbeddingAsync(
        new EmbeddingInput(totalLoop, null, null, EmbeddingInputType.DecompiledCode));

    // SHA-256 over the UTF-8 bytes of the snippet.
    static byte[] Sha256Of(string text) =>
        System.Security.Cryptography.SHA256.HashData(System.Text.Encoding.UTF8.GetBytes(text));

    var source = new FunctionAnalysis
    {
        FunctionId = "func1",
        FunctionName = "calculate_sum",
        DecompiledCode = sumLoop,
        NormalizedCodeHash = Sha256Of(sumLoop),
        Ast = sumAst,
        Embedding = sumEmbedding
    };
    var target = new FunctionAnalysis
    {
        FunctionId = "func2",
        FunctionName = "compute_total",
        DecompiledCode = totalLoop,
        NormalizedCodeHash = Sha256Of(totalLoop),
        Ast = totalAst,
        Embedding = totalEmbedding
    };

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    // Only signal availability is checked here; no score threshold is asserted.
    Assert.NotEmpty(result.Contributions);
    var availableSignals = result.Contributions.Count(contribution => contribution.IsAvailable);
    Assert.True(
        availableSignals >= 1,
        $"Expected at least 1 available signal, got {availableSignals}");
}
/// <summary>
/// Compares an integer summation loop with a character-printing loop and expects an
/// ensemble score below 0.7 and no match.
/// </summary>
[Fact]
public async Task Pipeline_WithDifferentCode_ReturnsLowSimilarity()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();

    var summationCode = """
        int calculate_sum(int* arr, int len) {
        int sum = 0;
        for (int i = 0; i < len; i++) {
        sum += arr[i];
        }
        return sum;
        }
        """;
    var printingCode = """
        void print_string(char* str) {
        while (*str != '\0') {
        putchar(*str);
        str++;
        }
        }
        """;

    // Both analyses carry only the decompiled text and its hash (no AST, embedding or graph).
    var source = CreateFunctionAnalysis("func1", summationCode);
    var target = CreateFunctionAnalysis("func2", printingCode);

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    var score = result.EnsembleScore;
    Assert.True(
        score < 0.7m,
        $"Expected low similarity for different code, got {score}");
    Assert.False(result.IsMatch);
}
/// <summary>
/// Compares two analyses that carry nothing but an id, a name and the same hash array,
/// and expects an exact hash match with an ensemble score of at least 0.1.
/// </summary>
[Fact]
public async Task Pipeline_WithExactHashMatch_ReturnsHighScoreImmediately()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();

    // One array instance is assigned to both sides.
    byte[] sharedHash = [1, 2, 3, 4, 5, 6, 7, 8];

    var source = new FunctionAnalysis { FunctionId = "func1", FunctionName = "test1", NormalizedCodeHash = sharedHash };
    var target = new FunctionAnalysis { FunctionId = "func2", FunctionName = "test2", NormalizedCodeHash = sharedHash };

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    Assert.True(result.ExactHashMatch);
    var score = result.EnsembleScore;
    Assert.True(score >= 0.1m);
}
/// <summary>
/// Runs a batch comparison of two source functions against three target functions and
/// checks the reported comparison count, that results exist, and that the duration is positive.
/// </summary>
[Fact]
public async Task Pipeline_BatchComparison_ReturnsStatistics()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();

    FunctionAnalysis[] sources =
    [
        CreateFunctionAnalysis("s1", "int add(int a, int b) { return a + b; }"),
        CreateFunctionAnalysis("s2", "int sub(int a, int b) { return a - b; }")
    ];
    FunctionAnalysis[] targets =
    [
        CreateFunctionAnalysis("t1", "int add(int x, int y) { return x + y; }"),
        CreateFunctionAnalysis("t2", "int mul(int a, int b) { return a * b; }"),
        CreateFunctionAnalysis("t3", "int div(int a, int b) { return a / b; }")
    ];

    // Every source is expected to be compared with every target: 2 x 3.
    const int expectedComparisons = 6;

    // Act
    var result = await engine.CompareBatchAsync(sources, targets);

    // Assert
    Assert.Equal(expectedComparisons, result.Statistics.TotalComparisons);
    Assert.NotEmpty(result.Results);
    Assert.True(result.Duration > TimeSpan.Zero);
}
/// <summary>
/// Searches a three-function corpus for matches to a squaring function and checks that
/// the returned results are non-empty and sorted by ensemble score, highest first.
/// </summary>
[Fact]
public async Task Pipeline_FindMatches_ReturnsOrderedResults()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();

    var queryCode = """
        int square(int x) {
        return x * x;
        }
        """;
    var query = CreateFunctionAnalysis("query", queryCode);

    FunctionAnalysis[] corpus =
    [
        CreateFunctionAnalysis("f1", "int square(int n) { return n * n; }"),   // Similar
        CreateFunctionAnalysis("f2", "int cube(int x) { return x * x * x; }"), // Somewhat similar
        CreateFunctionAnalysis("f3", "void print(char* s) { puts(s); }")       // Different
    ];

    var options = new EnsembleOptions
    {
        MaxCandidates = 10,
        MinimumSignalThreshold = 0m
    };

    // Act
    var results = await engine.FindMatchesAsync(query, corpus, options);

    // Assert
    Assert.NotEmpty(results);

    // Walk adjacent pairs: each entry must score at least as high as the one after it.
    for (var next = 1; next < results.Length; next++)
    {
        var earlierScore = results[next - 1].EnsembleScore;
        var laterScore = results[next].EnsembleScore;
        Assert.True(
            earlierScore >= laterScore,
            $"Results not ordered: {earlierScore} should be >= {laterScore}");
    }
}
/// <summary>
/// Compares two analyses that carry only a parsed AST (no code text, hash, embedding or
/// semantic graph) and checks that a syntactic contribution is reported, is available,
/// and has a non-negative raw score.
/// </summary>
[Fact]
public async Task Pipeline_WithAstOnly_ComputesSyntacticSignal()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();
    var parser = _serviceProvider.GetRequiredService();

    var code1 = "int foo(int x) { return x + 1; }";
    var code2 = "int bar(int y) { return y + 2; }";
    var ast1 = parser.Parse(code1);
    var ast2 = parser.Parse(code2);

    var source = new FunctionAnalysis
    {
        FunctionId = "func1",
        FunctionName = "foo",
        Ast = ast1
    };
    var target = new FunctionAnalysis
    {
        FunctionId = "func2",
        FunctionName = "bar",
        Ast = ast2
    };

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    var syntacticContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Syntactic);
    Assert.NotNull(syntacticContrib);
    Assert.True(syntacticContrib.IsAvailable);
    Assert.True(syntacticContrib.RawScore >= 0m);
}
/// <summary>
/// Compares two analyses that carry only an embedding generated from decompiled code
/// and checks that an embedding contribution is reported and available.
/// </summary>
[Fact]
public async Task Pipeline_WithEmbeddingOnly_ComputesEmbeddingSignal()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();
    var embeddingService = _serviceProvider.GetRequiredService();

    // Both inputs supply decompiled code only; graph and instruction bytes are left null.
    var addInput = new EmbeddingInput(
        DecompiledCode: "int add(int a, int b) { return a + b; }",
        SemanticGraph: null,
        InstructionBytes: null,
        PreferredInput: EmbeddingInputType.DecompiledCode);
    var sumInput = new EmbeddingInput(
        DecompiledCode: "int sum(int x, int y) { return x + y; }",
        SemanticGraph: null,
        InstructionBytes: null,
        PreferredInput: EmbeddingInputType.DecompiledCode);

    var addEmbedding = await embeddingService.GenerateEmbeddingAsync(addInput);
    var sumEmbedding = await embeddingService.GenerateEmbeddingAsync(sumInput);

    var source = new FunctionAnalysis { FunctionId = "func1", FunctionName = "add", Embedding = addEmbedding };
    var target = new FunctionAnalysis { FunctionId = "func2", FunctionName = "sum", Embedding = sumEmbedding };

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    var embeddingContrib = result.Contributions
        .FirstOrDefault(contribution => contribution.SignalType == SignalType.Embedding);
    Assert.NotNull(embeddingContrib);
    Assert.True(embeddingContrib.IsAvailable);
}
/// <summary>
/// Compares two analyses that carry only a synthetic semantic graph (5 nodes, 4 edges each)
/// and checks that a semantic contribution is reported and available.
/// </summary>
[Fact]
public async Task Pipeline_WithSemanticGraphOnly_ComputesSemanticSignal()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();

    const int nodeCount = 5;
    const int edgeCount = 4;

    var source = new FunctionAnalysis
    {
        FunctionId = "func1",
        FunctionName = "test1",
        SemanticGraph = CreateSemanticGraph("func1", nodeCount, edgeCount)
    };
    var target = new FunctionAnalysis
    {
        FunctionId = "func2",
        FunctionName = "test2",
        SemanticGraph = CreateSemanticGraph("func2", nodeCount, edgeCount)
    };

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    var semanticContrib = result.Contributions
        .FirstOrDefault(contribution => contribution.SignalType == SignalType.Semantic);
    Assert.NotNull(semanticContrib);
    Assert.True(semanticContrib.IsAvailable);
}
/// <summary>
/// Compares two analyses that each carry an AST, an embedding and a semantic graph, then
/// checks that at least two signals are available and that the weights of the available
/// signals total 1.0 within 0.01 (a total of exactly 0 is also accepted).
/// </summary>
[Fact]
public async Task Pipeline_WithAllSignals_CombinesWeightedContributions()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();
    var parser = _serviceProvider.GetRequiredService();
    var embeddingService = _serviceProvider.GetRequiredService();

    var multiplyCode = "int multiply(int a, int b) { return a * b; }";
    var multCode = "int mult(int x, int y) { return x * y; }";

    var multiplyAst = parser.Parse(multiplyCode);
    var multAst = parser.Parse(multCode);

    var multiplyEmbedding = await embeddingService.GenerateEmbeddingAsync(
        new EmbeddingInput(multiplyCode, null, null, EmbeddingInputType.DecompiledCode));
    var multEmbedding = await embeddingService.GenerateEmbeddingAsync(
        new EmbeddingInput(multCode, null, null, EmbeddingInputType.DecompiledCode));

    var multiplyGraph = CreateSemanticGraph("multiply", 4, 3);
    var multGraph = CreateSemanticGraph("mult", 4, 3);

    var source = new FunctionAnalysis
    {
        FunctionId = "func1",
        FunctionName = "multiply",
        Ast = multiplyAst,
        Embedding = multiplyEmbedding,
        SemanticGraph = multiplyGraph
    };
    var target = new FunctionAnalysis
    {
        FunctionId = "func2",
        FunctionName = "mult",
        Ast = multAst,
        Embedding = multEmbedding,
        SemanticGraph = multGraph
    };

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert
    var availableSignals = result.Contributions.Count(contribution => contribution.IsAvailable);
    Assert.True(
        availableSignals >= 2,
        $"Expected at least 2 available signals, got {availableSignals}");

    // Only the available contributions take part in the weight total.
    var totalWeight = result.Contributions
        .Where(contribution => contribution.IsAvailable)
        .Sum(contribution => contribution.Weight);
    var sumsToOne = Math.Abs(totalWeight - 1.0m) < 0.01m;
    var hasNoWeight = totalWeight == 0m;
    Assert.True(
        sumsToOne || hasNoWeight,
        $"Weights should sum to 1.0 (or 0 if no signals), got {totalWeight}");
}
/// <summary>
/// Compares two analyses that carry only an id and a name and expects the reported
/// confidence to be <see cref="ConfidenceLevel.VeryLow"/>.
/// </summary>
[Fact]
public async Task Pipeline_ConfidenceLevel_ReflectsSignalAvailability()
{
    // Arrange
    var engine = _serviceProvider.GetRequiredService();

    // Minimal analysis: id and name only — no hash, code text, AST, embedding or graph is set.
    static FunctionAnalysis Bare(string functionId, string functionName) =>
        new FunctionAnalysis { FunctionId = functionId, FunctionName = functionName };

    var source = Bare("func1", "test1");
    var target = Bare("func2", "test2");

    // Act
    var result = await engine.CompareAsync(source, target);

    // Assert - nothing to compare on either side, so confidence should be very low
    Assert.Equal(ConfidenceLevel.VeryLow, result.Confidence);
}
/// <summary>
/// Runs the same comparison under a 0.99 and a 0.1 match threshold. The ensemble score
/// must be identical in both runs, and the lenient run must report a match unless the
/// score itself is below 0.1.
/// </summary>
[Fact]
public async Task Pipeline_WithCustomOptions_RespectsThreshold()
{
    // Arrange
    const decimal strictThreshold = 0.99m;
    const decimal lenientThreshold = 0.1m;

    var engine = _serviceProvider.GetRequiredService();
    var source = CreateFunctionAnalysis("func1", "int a(int x) { return x; }");
    var target = CreateFunctionAnalysis("func2", "int b(int y) { return y; }");

    var strictOptions = new EnsembleOptions { MatchThreshold = strictThreshold };
    var lenientOptions = new EnsembleOptions { MatchThreshold = lenientThreshold };

    // Act
    var strictResult = await engine.CompareAsync(source, target, strictOptions);
    var lenientResult = await engine.CompareAsync(source, target, lenientOptions);

    // Assert - the threshold must not influence the score itself
    Assert.Equal(strictResult.EnsembleScore, lenientResult.EnsembleScore);

    // The lenient run counts as a match unless the score is under the lenient threshold.
    // No assertion is made about the strict run's match flag.
    Assert.True(lenientResult.IsMatch || strictResult.EnsembleScore < lenientThreshold);
}
/// <summary>
/// Builds a <see cref="FunctionAnalysis"/> from decompiled text alone: the id doubles as
/// the function name and the hash is the SHA-256 of the UTF-8 bytes of the text.
/// </summary>
/// <param name="id">Value used for both the function id and the function name.</param>
/// <param name="code">Decompiled source text to store and hash.</param>
/// <returns>An analysis with id, name, code and hash set; no other member is assigned.</returns>
private static FunctionAnalysis CreateFunctionAnalysis(string id, string code)
{
    var utf8Bytes = System.Text.Encoding.UTF8.GetBytes(code);
    var digest = System.Security.Cryptography.SHA256.HashData(utf8Bytes);

    var analysis = new FunctionAnalysis
    {
        FunctionId = id,
        FunctionName = id,
        DecompiledCode = code,
        NormalizedCodeHash = digest
    };
    return analysis;
}
/// <summary>
/// Builds a synthetic semantic graph for tests: a run of compute nodes with ids
/// 0..nodeCount-1, chained by data-dependency edges from node i to node i + 1.
/// </summary>
/// <param name="name">Name passed through to the graph.</param>
/// <param name="nodeCount">Number of nodes to create.</param>
/// <param name="edgeCount">
/// Requested number of edges. Because edges only link consecutive nodes, at most
/// nodeCount - 1 edges are created.
/// </param>
/// <returns>A graph whose reported node and edge counts match the nodes and edges it contains.</returns>
private static KeySemanticsGraph CreateSemanticGraph(string name, int nodeCount, int edgeCount)
{
    var nodes = new List<SemanticNode>();
    var edges = new List<SemanticEdge>();

    for (var i = 0; i < nodeCount; i++)
    {
        nodes.Add(new SemanticNode(
            Id: i,
            Type: SemanticNodeType.Compute,
            Operation: $"op_{i}",
            Operands: ImmutableArray.Empty,
            Attributes: ImmutableDictionary.Empty));
    }

    // The second condition stops the chain at the last node, so fewer edges than
    // requested are produced when edgeCount exceeds nodeCount - 1.
    for (var i = 0; i < edgeCount && i < nodeCount - 1; i++)
    {
        edges.Add(new SemanticEdge(
            SourceId: i,
            TargetId: i + 1,
            Type: SemanticEdgeType.DataDependency,
            Label: $"edge_{i}"));
    }

    // Report the counts of what was actually built rather than what was requested, so the
    // properties cannot disagree with the node and edge collections.
    var props = new GraphProperties(
        NodeCount: nodes.Count,
        EdgeCount: edges.Count,
        CyclomaticComplexity: 2,
        MaxDepth: 3,
        NodeTypeCounts: ImmutableDictionary.Empty,
        EdgeTypeCounts: ImmutableDictionary.Empty,
        LoopCount: 1,
        BranchCount: 1);

    return new KeySemanticsGraph(
        name,
        [.. nodes],
        [.. edges],
        props);
}
}