save progress
@@ -0,0 +1,400 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.

using System.Collections.Immutable;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using NSubstitute;
using StellaOps.BinaryIndex.Decompiler;
using StellaOps.BinaryIndex.ML;
using StellaOps.BinaryIndex.Semantic;
using Xunit;

#pragma warning disable CS8625 // Suppress nullable warnings for test code
#pragma warning disable CA1707 // Identifiers should not contain underscores

namespace StellaOps.BinaryIndex.Ensemble.Tests;

public class EnsembleDecisionEngineTests
{
    private readonly IAstComparisonEngine _astEngine;
    private readonly ISemanticMatcher _semanticMatcher;
    private readonly IEmbeddingService _embeddingService;
    private readonly EnsembleDecisionEngine _engine;

    public EnsembleDecisionEngineTests()
    {
        _astEngine = Substitute.For<IAstComparisonEngine>();
        _semanticMatcher = Substitute.For<ISemanticMatcher>();
        _embeddingService = Substitute.For<IEmbeddingService>();

        var options = Options.Create(new EnsembleOptions());
        var logger = NullLogger<EnsembleDecisionEngine>.Instance;

        _engine = new EnsembleDecisionEngine(
            _astEngine,
            _semanticMatcher,
            _embeddingService,
            options,
            logger);
    }

    [Fact]
    public async Task CompareAsync_WithExactHashMatch_ReturnsHighScore()
    {
        // Arrange
        var hash = new byte[] { 1, 2, 3, 4, 5 };
        var source = CreateAnalysis("func1", "test", hash);
        var target = CreateAnalysis("func2", "test", hash);

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        Assert.True(result.ExactHashMatch);
        Assert.True(result.EnsembleScore >= 0.1m);
    }

    [Fact]
    public async Task CompareAsync_WithDifferentHashes_ComputesSignals()
    {
        // Arrange
        var source = CreateAnalysis("func1", "test1", new byte[] { 1, 2, 3 });
        var target = CreateAnalysis("func2", "test2", new byte[] { 4, 5, 6 });

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        Assert.False(result.ExactHashMatch);
        Assert.NotEmpty(result.Contributions);
    }

    [Fact]
    public async Task CompareAsync_WithNoSignals_ReturnsZeroScore()
    {
        // Arrange
        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1"
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2"
        };

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        Assert.Equal(0m, result.EnsembleScore);
        Assert.Equal(ConfidenceLevel.VeryLow, result.Confidence);
    }

    [Fact]
    public async Task CompareAsync_WithAstOnly_UsesAstSignal()
    {
        // Arrange
        var ast1 = CreateSimpleAst("func1");
        var ast2 = CreateSimpleAst("func2");

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            Ast = ast1
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            Ast = ast2
        };

        _astEngine.ComputeStructuralSimilarity(ast1, ast2).Returns(0.9m);

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        var syntacticContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Syntactic);
        Assert.NotNull(syntacticContrib);
        Assert.True(syntacticContrib.IsAvailable);
        Assert.Equal(0.9m, syntacticContrib.RawScore);
    }

    [Fact]
    public async Task CompareAsync_WithEmbeddingOnly_UsesEmbeddingSignal()
    {
        // Arrange
        var emb1 = CreateEmbedding("func1");
        var emb2 = CreateEmbedding("func2");

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            Embedding = emb1
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            Embedding = emb2
        };

        _embeddingService.ComputeSimilarity(emb1, emb2, SimilarityMetric.Cosine).Returns(0.85m);

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        var embeddingContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Embedding);
        Assert.NotNull(embeddingContrib);
        Assert.True(embeddingContrib.IsAvailable);
        Assert.Equal(0.85m, embeddingContrib.RawScore);
    }

    [Fact]
    public async Task CompareAsync_WithSemanticGraphOnly_UsesSemanticSignal()
    {
        // Arrange
        var graph1 = CreateSemanticGraph("func1");
        var graph2 = CreateSemanticGraph("func2");

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            SemanticGraph = graph1
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            SemanticGraph = graph2
        };

        _semanticMatcher.ComputeGraphSimilarityAsync(graph1, graph2, Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(0.8m));

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        var semanticContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Semantic);
        Assert.NotNull(semanticContrib);
        Assert.True(semanticContrib.IsAvailable);
        Assert.Equal(0.8m, semanticContrib.RawScore);
    }

    [Fact]
    public async Task CompareAsync_WithAllSignals_CombinesCorrectly()
    {
        // Arrange
        var ast1 = CreateSimpleAst("func1");
        var ast2 = CreateSimpleAst("func2");
        var emb1 = CreateEmbedding("func1");
        var emb2 = CreateEmbedding("func2");
        var graph1 = CreateSemanticGraph("func1");
        var graph2 = CreateSemanticGraph("func2");

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            Ast = ast1,
            Embedding = emb1,
            SemanticGraph = graph1
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            Ast = ast2,
            Embedding = emb2,
            SemanticGraph = graph2
        };

        _astEngine.ComputeStructuralSimilarity(ast1, ast2).Returns(0.9m);
        _embeddingService.ComputeSimilarity(emb1, emb2, SimilarityMetric.Cosine).Returns(0.85m);
        _semanticMatcher.ComputeGraphSimilarityAsync(graph1, graph2, Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(0.8m));

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        Assert.Equal(3, result.Contributions.Count(c => c.IsAvailable));
        Assert.True(result.EnsembleScore > 0.8m);
    }

    [Fact]
    public async Task CompareAsync_AboveThreshold_IsMatch()
    {
        // Arrange
        var ast1 = CreateSimpleAst("func1");
        var ast2 = CreateSimpleAst("func2");
        var emb1 = CreateEmbedding("func1");
        var emb2 = CreateEmbedding("func2");
        var graph1 = CreateSemanticGraph("func1");
        var graph2 = CreateSemanticGraph("func2");

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            Ast = ast1,
            Embedding = emb1,
            SemanticGraph = graph1
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            Ast = ast2,
            Embedding = emb2,
            SemanticGraph = graph2
        };

        // All high scores
        _astEngine.ComputeStructuralSimilarity(ast1, ast2).Returns(0.95m);
        _embeddingService.ComputeSimilarity(emb1, emb2, SimilarityMetric.Cosine).Returns(0.9m);
        _semanticMatcher.ComputeGraphSimilarityAsync(graph1, graph2, Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(0.92m));

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        Assert.True(result.IsMatch);
        Assert.True(result.Confidence >= ConfidenceLevel.Medium);
    }

    [Fact]
    public async Task CompareAsync_BelowThreshold_IsNotMatch()
    {
        // Arrange
        var ast1 = CreateSimpleAst("func1");
        var ast2 = CreateSimpleAst("func2");

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            Ast = ast1
        };
        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            Ast = ast2
        };

        _astEngine.ComputeStructuralSimilarity(ast1, ast2).Returns(0.3m);

        // Act
        var result = await _engine.CompareAsync(source, target);

        // Assert
        Assert.False(result.IsMatch);
    }

    [Fact]
    public async Task FindMatchesAsync_ReturnsOrderedByScore()
    {
        // Arrange
        var query = new FunctionAnalysis
        {
            FunctionId = "query",
            FunctionName = "query"
        };

        var corpus = new[]
        {
            CreateAnalysis("func1", "test1", new byte[] { 1 }),
            CreateAnalysis("func2", "test2", new byte[] { 2 }),
            CreateAnalysis("func3", "test3", new byte[] { 3 })
        };

        var options = new EnsembleOptions { MaxCandidates = 10, MinimumSignalThreshold = 0m };

        // Act
        var results = await _engine.FindMatchesAsync(query, corpus, options);

        // Assert
        Assert.NotEmpty(results);
        for (var i = 1; i < results.Length; i++)
        {
            Assert.True(results[i - 1].EnsembleScore >= results[i].EnsembleScore);
        }
    }

    [Fact]
    public async Task CompareBatchAsync_ReturnsStatistics()
    {
        // Arrange
        var sources = new[] { CreateAnalysis("s1", "source1", new byte[] { 1 }) };
        var targets = new[]
        {
            CreateAnalysis("t1", "target1", new byte[] { 1 }),
            CreateAnalysis("t2", "target2", new byte[] { 2 })
        };

        // Act
        var result = await _engine.CompareBatchAsync(sources, targets);

        // Assert
        Assert.Equal(2, result.Statistics.TotalComparisons);
        Assert.NotEmpty(result.Results);
        Assert.True(result.Duration > TimeSpan.Zero);
    }

    private static FunctionAnalysis CreateAnalysis(string id, string name, byte[] hash)
    {
        return new FunctionAnalysis
        {
            FunctionId = id,
            FunctionName = name,
            NormalizedCodeHash = hash
        };
    }

    private static DecompiledAst CreateSimpleAst(string name)
    {
        var root = new BlockNode([]);
        return new DecompiledAst(root, 1, 1, ImmutableArray<AstPattern>.Empty);
    }

    private static FunctionEmbedding CreateEmbedding(string id)
    {
        return new FunctionEmbedding(
            id,
            id,
            new float[768],
            EmbeddingModel.CodeBertBinary,
            EmbeddingInputType.DecompiledCode,
            DateTimeOffset.UtcNow);
    }

    private static KeySemanticsGraph CreateSemanticGraph(string name)
    {
        var props = new GraphProperties(
            NodeCount: 5,
            EdgeCount: 4,
            CyclomaticComplexity: 2,
            MaxDepth: 3,
            NodeTypeCounts: ImmutableDictionary<SemanticNodeType, int>.Empty,
            EdgeTypeCounts: ImmutableDictionary<SemanticEdgeType, int>.Empty,
            LoopCount: 1,
            BranchCount: 1);

        return new KeySemanticsGraph(
            name,
            ImmutableArray<SemanticNode>.Empty,
            ImmutableArray<SemanticEdge>.Empty,
            props);
    }
}
@@ -0,0 +1,126 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.

using Xunit;

namespace StellaOps.BinaryIndex.Ensemble.Tests;

public class EnsembleOptionsTests
{
    [Fact]
    public void AreWeightsValid_WithValidWeights_ReturnsTrue()
    {
        // Arrange
        var options = new EnsembleOptions
        {
            SyntacticWeight = 0.25m,
            SemanticWeight = 0.35m,
            EmbeddingWeight = 0.40m
        };

        // Act & Assert
        Assert.True(options.AreWeightsValid());
    }

    [Fact]
    public void AreWeightsValid_WithInvalidWeights_ReturnsFalse()
    {
        // Arrange
        var options = new EnsembleOptions
        {
            SyntacticWeight = 0.50m,
            SemanticWeight = 0.50m,
            EmbeddingWeight = 0.50m
        };

        // Act & Assert
        Assert.False(options.AreWeightsValid());
    }

    [Fact]
    public void NormalizeWeights_NormalizesToOne()
    {
        // Arrange
        var options = new EnsembleOptions
        {
            SyntacticWeight = 1m,
            SemanticWeight = 2m,
            EmbeddingWeight = 2m
        };

        // Act
        options.NormalizeWeights();

        // Assert
        var sum = options.SyntacticWeight + options.SemanticWeight + options.EmbeddingWeight;
        Assert.True(Math.Abs(sum - 1.0m) < 0.001m);
        Assert.Equal(0.2m, options.SyntacticWeight);
        Assert.Equal(0.4m, options.SemanticWeight);
        Assert.Equal(0.4m, options.EmbeddingWeight);
    }

    [Fact]
    public void NormalizeWeights_WithZeroWeights_HandlesGracefully()
    {
        // Arrange
        var options = new EnsembleOptions
        {
            SyntacticWeight = 0m,
            SemanticWeight = 0m,
            EmbeddingWeight = 0m
        };

        // Act
        options.NormalizeWeights();

        // Assert (should not throw, weights stay at 0)
        Assert.Equal(0m, options.SyntacticWeight);
        Assert.Equal(0m, options.SemanticWeight);
        Assert.Equal(0m, options.EmbeddingWeight);
    }

    [Fact]
    public void DefaultOptions_HaveValidWeights()
    {
        // Arrange
        var options = new EnsembleOptions();

        // Assert
        Assert.True(options.AreWeightsValid());
        Assert.Equal(0.25m, options.SyntacticWeight);
        Assert.Equal(0.35m, options.SemanticWeight);
        Assert.Equal(0.40m, options.EmbeddingWeight);
    }

    [Fact]
    public void DefaultOptions_HaveReasonableThreshold()
    {
        // Arrange
        var options = new EnsembleOptions();

        // Assert
        Assert.Equal(0.85m, options.MatchThreshold);
        Assert.True(options.MatchThreshold > 0.5m);
        Assert.True(options.MatchThreshold < 1.0m);
    }

    [Fact]
    public void DefaultOptions_UseExactHashMatch()
    {
        // Arrange
        var options = new EnsembleOptions();

        // Assert
        Assert.True(options.UseExactHashMatch);
    }

    [Fact]
    public void DefaultOptions_UseAdaptiveWeights()
    {
        // Arrange
        var options = new EnsembleOptions();

        // Assert
        Assert.True(options.AdaptiveWeights);
    }
}
@@ -0,0 +1,570 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.

using System.Collections.Immutable;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Time.Testing;
using StellaOps.BinaryIndex.Decompiler;
using StellaOps.BinaryIndex.ML;
using StellaOps.BinaryIndex.Semantic;
using Xunit;

#pragma warning disable CS8625 // Suppress nullable warnings for test code
#pragma warning disable CA1707 // Identifiers should not contain underscores

namespace StellaOps.BinaryIndex.Ensemble.Tests.Integration;

/// <summary>
/// Integration tests for the full semantic diffing pipeline.
/// These tests wire up real implementations to verify end-to-end functionality.
/// </summary>
[Trait("Category", "Integration")]
public class SemanticDiffingPipelineTests : IAsyncDisposable
{
    private readonly ServiceProvider _serviceProvider;
    private readonly FakeTimeProvider _timeProvider;

    public SemanticDiffingPipelineTests()
    {
        _timeProvider = new FakeTimeProvider(new DateTimeOffset(2026, 1, 5, 12, 0, 0, TimeSpan.Zero));

        var services = new ServiceCollection();

        // Add logging
        services.AddLogging(builder => builder.AddDebug().SetMinimumLevel(LogLevel.Debug));

        // Add time provider
        services.AddSingleton<TimeProvider>(_timeProvider);

        // Add all binary similarity services
        services.AddBinarySimilarityServices();

        _serviceProvider = services.BuildServiceProvider();
    }

    public async ValueTask DisposeAsync()
    {
        await _serviceProvider.DisposeAsync();
        GC.SuppressFinalize(this);
    }

    [Fact]
    public async Task Pipeline_WithIdenticalCode_ReturnsHighSimilarity()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();
        var parser = _serviceProvider.GetRequiredService<IDecompiledCodeParser>();
        var embeddingService = _serviceProvider.GetRequiredService<IEmbeddingService>();

        var code = """
            int calculate_sum(int* arr, int len) {
                int sum = 0;
                for (int i = 0; i < len; i++) {
                    sum += arr[i];
                }
                return sum;
            }
            """;

        var ast = parser.Parse(code);
        var emb = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(code, null, null, EmbeddingInputType.DecompiledCode));

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "calculate_sum",
            DecompiledCode = code,
            NormalizedCodeHash = System.Security.Cryptography.SHA256.HashData(
                System.Text.Encoding.UTF8.GetBytes(code)),
            Ast = ast,
            Embedding = emb
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "calculate_sum",
            DecompiledCode = code,
            NormalizedCodeHash = System.Security.Cryptography.SHA256.HashData(
                System.Text.Encoding.UTF8.GetBytes(code)),
            Ast = ast,
            Embedding = emb
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        // With identical AST and embedding, plus exact hash match, should be very high
        Assert.True(result.EnsembleScore >= 0.5m,
            $"Expected high similarity for identical code with AST/embedding, got {result.EnsembleScore}");
        Assert.True(result.ExactHashMatch);
    }

    [Fact]
    public async Task Pipeline_WithSimilarCode_ReturnsModerateSimilarity()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();
        var parser = _serviceProvider.GetRequiredService<IDecompiledCodeParser>();
        var embeddingService = _serviceProvider.GetRequiredService<IEmbeddingService>();

        var code1 = """
            int calculate_sum(int* arr, int len) {
                int sum = 0;
                for (int i = 0; i < len; i++) {
                    sum += arr[i];
                }
                return sum;
            }
            """;

        var code2 = """
            int compute_total(int* data, int count) {
                int total = 0;
                for (int j = 0; j < count; j++) {
                    total = total + data[j];
                }
                return total;
            }
            """;

        var ast1 = parser.Parse(code1);
        var ast2 = parser.Parse(code2);
        var emb1 = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(code1, null, null, EmbeddingInputType.DecompiledCode));
        var emb2 = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(code2, null, null, EmbeddingInputType.DecompiledCode));

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "calculate_sum",
            DecompiledCode = code1,
            NormalizedCodeHash = System.Security.Cryptography.SHA256.HashData(
                System.Text.Encoding.UTF8.GetBytes(code1)),
            Ast = ast1,
            Embedding = emb1
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "compute_total",
            DecompiledCode = code2,
            NormalizedCodeHash = System.Security.Cryptography.SHA256.HashData(
                System.Text.Encoding.UTF8.GetBytes(code2)),
            Ast = ast2,
            Embedding = emb2
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        // With different but structurally similar code, should have some signal
        Assert.NotEmpty(result.Contributions);
        var availableSignals = result.Contributions.Count(c => c.IsAvailable);
        Assert.True(availableSignals >= 1, $"Expected at least 1 available signal, got {availableSignals}");
    }

    [Fact]
    public async Task Pipeline_WithDifferentCode_ReturnsLowSimilarity()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();

        var source = CreateFunctionAnalysis("func1", """
            int calculate_sum(int* arr, int len) {
                int sum = 0;
                for (int i = 0; i < len; i++) {
                    sum += arr[i];
                }
                return sum;
            }
            """);

        var target = CreateFunctionAnalysis("func2", """
            void print_string(char* str) {
                while (*str != '\0') {
                    putchar(*str);
                    str++;
                }
            }
            """);

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        Assert.True(result.EnsembleScore < 0.7m,
            $"Expected low similarity for different code, got {result.EnsembleScore}");
        Assert.False(result.IsMatch);
    }

    [Fact]
    public async Task Pipeline_WithExactHashMatch_ReturnsHighScoreImmediately()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();
        var hash = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            NormalizedCodeHash = hash
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            NormalizedCodeHash = hash
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        Assert.True(result.ExactHashMatch);
        Assert.True(result.EnsembleScore >= 0.1m);
    }

    [Fact]
    public async Task Pipeline_BatchComparison_ReturnsStatistics()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();

        var sources = new[]
        {
            CreateFunctionAnalysis("s1", "int add(int a, int b) { return a + b; }"),
            CreateFunctionAnalysis("s2", "int sub(int a, int b) { return a - b; }")
        };

        var targets = new[]
        {
            CreateFunctionAnalysis("t1", "int add(int x, int y) { return x + y; }"),
            CreateFunctionAnalysis("t2", "int mul(int a, int b) { return a * b; }"),
            CreateFunctionAnalysis("t3", "int div(int a, int b) { return a / b; }")
        };

        // Act
        var result = await engine.CompareBatchAsync(sources, targets);

        // Assert
        Assert.Equal(6, result.Statistics.TotalComparisons); // 2 x 3 = 6
        Assert.NotEmpty(result.Results);
        Assert.True(result.Duration > TimeSpan.Zero);
    }

    [Fact]
    public async Task Pipeline_FindMatches_ReturnsOrderedResults()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();

        var query = CreateFunctionAnalysis("query", """
            int square(int x) {
                return x * x;
            }
            """);

        var corpus = new[]
        {
            CreateFunctionAnalysis("f1", "int square(int n) { return n * n; }"), // Similar
            CreateFunctionAnalysis("f2", "int cube(int x) { return x * x * x; }"), // Somewhat similar
            CreateFunctionAnalysis("f3", "void print(char* s) { puts(s); }") // Different
        };

        var options = new EnsembleOptions { MaxCandidates = 10, MinimumSignalThreshold = 0m };

        // Act
        var results = await engine.FindMatchesAsync(query, corpus, options);

        // Assert
        Assert.NotEmpty(results);

        // Results should be ordered by score descending
        for (var i = 1; i < results.Length; i++)
        {
            Assert.True(results[i - 1].EnsembleScore >= results[i].EnsembleScore,
                $"Results not ordered: {results[i - 1].EnsembleScore} should be >= {results[i].EnsembleScore}");
        }
    }

    [Fact]
    public async Task Pipeline_WithAstOnly_ComputesSyntacticSignal()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();
        var astEngine = _serviceProvider.GetRequiredService<IAstComparisonEngine>();
        var parser = _serviceProvider.GetRequiredService<IDecompiledCodeParser>();

        var code1 = "int foo(int x) { return x + 1; }";
        var code2 = "int bar(int y) { return y + 2; }";

        var ast1 = parser.Parse(code1);
        var ast2 = parser.Parse(code2);

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "foo",
            Ast = ast1
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "bar",
            Ast = ast2
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        var syntacticContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Syntactic);
        Assert.NotNull(syntacticContrib);
        Assert.True(syntacticContrib.IsAvailable);
        Assert.True(syntacticContrib.RawScore >= 0m);
    }

    [Fact]
    public async Task Pipeline_WithEmbeddingOnly_ComputesEmbeddingSignal()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();
        var embeddingService = _serviceProvider.GetRequiredService<IEmbeddingService>();

        var emb1 = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(
                DecompiledCode: "int add(int a, int b) { return a + b; }",
                SemanticGraph: null,
                InstructionBytes: null,
                PreferredInput: EmbeddingInputType.DecompiledCode));

        var emb2 = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(
                DecompiledCode: "int sum(int x, int y) { return x + y; }",
                SemanticGraph: null,
                InstructionBytes: null,
                PreferredInput: EmbeddingInputType.DecompiledCode));

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "add",
            Embedding = emb1
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "sum",
            Embedding = emb2
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        var embeddingContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Embedding);
        Assert.NotNull(embeddingContrib);
        Assert.True(embeddingContrib.IsAvailable);
    }

    [Fact]
    public async Task Pipeline_WithSemanticGraphOnly_ComputesSemanticSignal()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();

        var graph1 = CreateSemanticGraph("func1", 5, 4);
        var graph2 = CreateSemanticGraph("func2", 5, 4);

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1",
            SemanticGraph = graph1
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2",
            SemanticGraph = graph2
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        var semanticContrib = result.Contributions.FirstOrDefault(c => c.SignalType == SignalType.Semantic);
        Assert.NotNull(semanticContrib);
        Assert.True(semanticContrib.IsAvailable);
    }

    [Fact]
    public async Task Pipeline_WithAllSignals_CombinesWeightedContributions()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();
        var parser = _serviceProvider.GetRequiredService<IDecompiledCodeParser>();
        var embeddingService = _serviceProvider.GetRequiredService<IEmbeddingService>();

        var code1 = "int multiply(int a, int b) { return a * b; }";
        var code2 = "int mult(int x, int y) { return x * y; }";

        var ast1 = parser.Parse(code1);
        var ast2 = parser.Parse(code2);

        var emb1 = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(code1, null, null, EmbeddingInputType.DecompiledCode));
        var emb2 = await embeddingService.GenerateEmbeddingAsync(
            new EmbeddingInput(code2, null, null, EmbeddingInputType.DecompiledCode));

        var graph1 = CreateSemanticGraph("multiply", 4, 3);
        var graph2 = CreateSemanticGraph("mult", 4, 3);

        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "multiply",
            Ast = ast1,
            Embedding = emb1,
            SemanticGraph = graph1
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "mult",
            Ast = ast2,
            Embedding = emb2,
            SemanticGraph = graph2
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert
        var availableSignals = result.Contributions.Count(c => c.IsAvailable);
        Assert.True(availableSignals >= 2, $"Expected at least 2 available signals, got {availableSignals}");

        // Verify weighted contributions sum correctly
        var totalWeight = result.Contributions
            .Where(c => c.IsAvailable)
            .Sum(c => c.Weight);
        Assert.True(Math.Abs(totalWeight - 1.0m) < 0.01m || totalWeight == 0m,
            $"Weights should sum to 1.0 (or 0 if no signals), got {totalWeight}");
    }

    [Fact]
    public async Task Pipeline_ConfidenceLevel_ReflectsSignalAvailability()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();

        // Create minimal analysis with only hash
        var source = new FunctionAnalysis
        {
            FunctionId = "func1",
            FunctionName = "test1"
        };

        var target = new FunctionAnalysis
        {
            FunctionId = "func2",
            FunctionName = "test2"
        };

        // Act
        var result = await engine.CompareAsync(source, target);

        // Assert - with no signals, confidence should be very low
        Assert.Equal(ConfidenceLevel.VeryLow, result.Confidence);
    }

    [Fact]
    public async Task Pipeline_WithCustomOptions_RespectsThreshold()
    {
        // Arrange
        var engine = _serviceProvider.GetRequiredService<IEnsembleDecisionEngine>();

        var source = CreateFunctionAnalysis("func1", "int a(int x) { return x; }");
        var target = CreateFunctionAnalysis("func2", "int b(int y) { return y; }");

        var strictOptions = new EnsembleOptions { MatchThreshold = 0.99m };
        var lenientOptions = new EnsembleOptions { MatchThreshold = 0.1m };

        // Act
        var strictResult = await engine.CompareAsync(source, target, strictOptions);
        var lenientResult = await engine.CompareAsync(source, target, lenientOptions);

        // Assert - same comparison, different thresholds
        Assert.Equal(strictResult.EnsembleScore, lenientResult.EnsembleScore);

        // With very strict threshold, unlikely to be a match
        // With very lenient threshold, likely to be a match
        Assert.True(lenientResult.IsMatch || strictResult.EnsembleScore < 0.1m);
    }

    private static FunctionAnalysis CreateFunctionAnalysis(string id, string code)
    {
        return new FunctionAnalysis
        {
            FunctionId = id,
            FunctionName = id,
            DecompiledCode = code,
            NormalizedCodeHash = System.Security.Cryptography.SHA256.HashData(
                System.Text.Encoding.UTF8.GetBytes(code))
        };
    }

    private static KeySemanticsGraph CreateSemanticGraph(string name, int nodeCount, int edgeCount)
    {
        var nodes = new List<SemanticNode>();
        var edges = new List<SemanticEdge>();

        for (var i = 0; i < nodeCount; i++)
        {
            nodes.Add(new SemanticNode(
                Id: i,
                Type: SemanticNodeType.Compute,
                Operation: $"op_{i}",
                Operands: ImmutableArray<string>.Empty,
                Attributes: ImmutableDictionary<string, string>.Empty));
        }

        for (var i = 0; i < edgeCount && i < nodeCount - 1; i++)
        {
            edges.Add(new SemanticEdge(
                SourceId: i,
                TargetId: i + 1,
                Type: SemanticEdgeType.DataDependency,
                Label: $"edge_{i}"));
        }

        var props = new GraphProperties(
            NodeCount: nodeCount,
            EdgeCount: edgeCount,
            CyclomaticComplexity: 2,
            MaxDepth: 3,
            NodeTypeCounts: ImmutableDictionary<SemanticNodeType, int>.Empty,
            EdgeTypeCounts: ImmutableDictionary<SemanticEdgeType, int>.Empty,
            LoopCount: 1,
            BranchCount: 1);

        return new KeySemanticsGraph(
            name,
            [.. nodes],
            [.. edges],
            props);
    }
}
@@ -0,0 +1,32 @@
<!-- Copyright (c) StellaOps. All rights reserved. -->
<!-- Licensed under AGPL-3.0-or-later. See LICENSE in the project root. -->
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net10.0</TargetFramework>
    <LangVersion>preview</LangVersion>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <IsPackable>false</IsPackable>
    <TreatWarningsAsErrors>true</TreatWarningsAsErrors>
    <NoWarn>$(NoWarn);xUnit1051</NoWarn>
    <RootNamespace>StellaOps.BinaryIndex.Ensemble.Tests</RootNamespace>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="FluentAssertions" />
    <PackageReference Include="NSubstitute" />
    <PackageReference Include="Microsoft.Extensions.DependencyInjection" />
    <PackageReference Include="Microsoft.Extensions.Logging" />
    <PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" />
  </ItemGroup>

  <ItemGroup>
    <ProjectReference Include="..\..\__Libraries\StellaOps.BinaryIndex.Ensemble\StellaOps.BinaryIndex.Ensemble.csproj" />
    <ProjectReference Include="..\..\__Libraries\StellaOps.BinaryIndex.Decompiler\StellaOps.BinaryIndex.Decompiler.csproj" />
    <ProjectReference Include="..\..\__Libraries\StellaOps.BinaryIndex.ML\StellaOps.BinaryIndex.ML.csproj" />
    <ProjectReference Include="..\..\__Libraries\StellaOps.BinaryIndex.Semantic\StellaOps.BinaryIndex.Semantic.csproj" />
    <ProjectReference Include="..\..\..\__Libraries\StellaOps.TestKit\StellaOps.TestKit.csproj" />
  </ItemGroup>

</Project>
@@ -0,0 +1,238 @@
// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.

using System.Collections.Immutable;
using Microsoft.Extensions.Logging.Abstractions;
using NSubstitute;
using StellaOps.BinaryIndex.Semantic;
using Xunit;

namespace StellaOps.BinaryIndex.Ensemble.Tests;

public class WeightTuningServiceTests
{
    private readonly IEnsembleDecisionEngine _decisionEngine;
    private readonly WeightTuningService _service;

    public WeightTuningServiceTests()
    {
        _decisionEngine = Substitute.For<IEnsembleDecisionEngine>();
        var logger = NullLogger<WeightTuningService>.Instance;
        _service = new WeightTuningService(_decisionEngine, logger);
    }

    [Fact]
    public async Task TuneWeightsAsync_WithValidPairs_ReturnsBestWeights()
    {
        // Arrange
        var pairs = CreateTrainingPairs(5);

        _decisionEngine.CompareAsync(
            Arg.Any<FunctionAnalysis>(),
            Arg.Any<FunctionAnalysis>(),
            Arg.Any<EnsembleOptions>(),
            Arg.Any<CancellationToken>())
            .Returns(callInfo =>
            {
                var opts = callInfo.Arg<EnsembleOptions>();
                return Task.FromResult(new EnsembleResult
                {
                    SourceFunctionId = "s",
                    TargetFunctionId = "t",
                    EnsembleScore = opts.SyntacticWeight * 0.9m + opts.SemanticWeight * 0.8m + opts.EmbeddingWeight * 0.85m,
                    Contributions = ImmutableArray<SignalContribution>.Empty,
                    IsMatch = true,
                    Confidence = ConfidenceLevel.High
                });
            });

        // Act
        var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.25m);

        // Assert
        Assert.NotNull(result);
        Assert.True(result.BestWeights.Syntactic >= 0);
        Assert.True(result.BestWeights.Semantic >= 0);
        Assert.True(result.BestWeights.Embedding >= 0);
        Assert.NotEmpty(result.Evaluations);
    }

    [Fact]
    public async Task TuneWeightsAsync_WeightsSumToOne()
    {
        // Arrange
        var pairs = CreateTrainingPairs(3);

        _decisionEngine.CompareAsync(
            Arg.Any<FunctionAnalysis>(),
            Arg.Any<FunctionAnalysis>(),
            Arg.Any<EnsembleOptions>(),
            Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(new EnsembleResult
            {
                SourceFunctionId = "s",
                TargetFunctionId = "t",
                EnsembleScore = 0.9m,
                Contributions = ImmutableArray<SignalContribution>.Empty,
                IsMatch = true,
                Confidence = ConfidenceLevel.High
            }));

        // Act
        var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.5m);

        // Assert
        var sum = result.BestWeights.Syntactic + result.BestWeights.Semantic + result.BestWeights.Embedding;
        Assert.True(Math.Abs(sum - 1.0m) < 0.01m);
    }

    [Fact]
    public async Task TuneWeightsAsync_WithInvalidStep_ThrowsException()
    {
        // Arrange
        var pairs = CreateTrainingPairs(1);

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentOutOfRangeException>(
            () => _service.TuneWeightsAsync(pairs, gridStep: 0));
    }

    [Fact]
    public async Task TuneWeightsAsync_WithNoPairs_ThrowsException()
    {
        // Arrange
        var pairs = Array.Empty<EnsembleTrainingPair>();

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentException>(
            () => _service.TuneWeightsAsync(pairs));
    }

    [Fact]
    public async Task EvaluateWeightsAsync_ComputesMetrics()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = true
            },
            new()
            {
                Function1 = CreateAnalysis("f3"),
                Function2 = CreateAnalysis("f4"),
                IsEquivalent = false
            }
        };

        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // Simulate decision engine returning matching for first pair
        _decisionEngine.CompareAsync(
            pairs[0].Function1,
            pairs[0].Function2,
            Arg.Any<EnsembleOptions>(),
            Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(new EnsembleResult
            {
                SourceFunctionId = "f1",
                TargetFunctionId = "f2",
                EnsembleScore = 0.9m,
                Contributions = ImmutableArray<SignalContribution>.Empty,
                IsMatch = true,
                Confidence = ConfidenceLevel.High
            }));

        // Non-matching for second pair
        _decisionEngine.CompareAsync(
            pairs[1].Function1,
            pairs[1].Function2,
            Arg.Any<EnsembleOptions>(),
            Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(new EnsembleResult
            {
                SourceFunctionId = "f3",
                TargetFunctionId = "f4",
                EnsembleScore = 0.3m,
                Contributions = ImmutableArray<SignalContribution>.Empty,
                IsMatch = false,
                Confidence = ConfidenceLevel.Low
            }));

        // Act
        var result = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(weights, result.Weights);
        Assert.Equal(1.0m, result.Accuracy); // Both predictions correct
        Assert.Equal(1.0m, result.Precision); // TP / (TP + FP) = 1 / 1
        Assert.Equal(1.0m, result.Recall); // TP / (TP + FN) = 1 / 1
    }

    [Fact]
    public async Task EvaluateWeightsAsync_WithFalsePositive_LowersPrecision()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = false // Ground truth: NOT equivalent
            }
        };

        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // But engine says it IS a match (false positive)
        _decisionEngine.CompareAsync(
            Arg.Any<FunctionAnalysis>(),
            Arg.Any<FunctionAnalysis>(),
            Arg.Any<EnsembleOptions>(),
            Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(new EnsembleResult
            {
                SourceFunctionId = "f1",
                TargetFunctionId = "f2",
                EnsembleScore = 0.9m,
                Contributions = ImmutableArray<SignalContribution>.Empty,
                IsMatch = true, // False positive!
                Confidence = ConfidenceLevel.High
            }));

        // Act
        var result = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(0m, result.Accuracy); // 0 correct out of 1
        Assert.Equal(0m, result.Precision); // 0 true positives
    }

    private static List<EnsembleTrainingPair> CreateTrainingPairs(int count)
    {
        var pairs = new List<EnsembleTrainingPair>();
        for (var i = 0; i < count; i++)
        {
            pairs.Add(new EnsembleTrainingPair
            {
                Function1 = CreateAnalysis($"func{i}a"),
                Function2 = CreateAnalysis($"func{i}b"),
                IsEquivalent = i % 2 == 0
            });
        }
        return pairs;
    }

    private static FunctionAnalysis CreateAnalysis(string id)
    {
        return new FunctionAnalysis
        {
            FunctionId = id,
            FunctionName = id
        };
    }
}