save progress
This commit is contained in:
@@ -0,0 +1,238 @@
|
||||
// Copyright (c) StellaOps. All rights reserved.
|
||||
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
||||
|
||||
using System.Collections.Immutable;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NSubstitute;
|
||||
using StellaOps.BinaryIndex.Semantic;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.BinaryIndex.Ensemble.Tests;
|
||||
|
||||
public class WeightTuningServiceTests
|
||||
{
|
||||
private readonly IEnsembleDecisionEngine _decisionEngine;
|
||||
private readonly WeightTuningService _service;
|
||||
|
||||
public WeightTuningServiceTests()
|
||||
{
|
||||
_decisionEngine = Substitute.For<IEnsembleDecisionEngine>();
|
||||
var logger = NullLogger<WeightTuningService>.Instance;
|
||||
_service = new WeightTuningService(_decisionEngine, logger);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TuneWeightsAsync_WithValidPairs_ReturnsBestWeights()
|
||||
{
|
||||
// Arrange
|
||||
var pairs = CreateTrainingPairs(5);
|
||||
|
||||
_decisionEngine.CompareAsync(
|
||||
Arg.Any<FunctionAnalysis>(),
|
||||
Arg.Any<FunctionAnalysis>(),
|
||||
Arg.Any<EnsembleOptions>(),
|
||||
Arg.Any<CancellationToken>())
|
||||
.Returns(callInfo =>
|
||||
{
|
||||
var opts = callInfo.Arg<EnsembleOptions>();
|
||||
return Task.FromResult(new EnsembleResult
|
||||
{
|
||||
SourceFunctionId = "s",
|
||||
TargetFunctionId = "t",
|
||||
EnsembleScore = opts.SyntacticWeight * 0.9m + opts.SemanticWeight * 0.8m + opts.EmbeddingWeight * 0.85m,
|
||||
Contributions = ImmutableArray<SignalContribution>.Empty,
|
||||
IsMatch = true,
|
||||
Confidence = ConfidenceLevel.High
|
||||
});
|
||||
});
|
||||
|
||||
// Act
|
||||
var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.25m);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.True(result.BestWeights.Syntactic >= 0);
|
||||
Assert.True(result.BestWeights.Semantic >= 0);
|
||||
Assert.True(result.BestWeights.Embedding >= 0);
|
||||
Assert.NotEmpty(result.Evaluations);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TuneWeightsAsync_WeightsSumToOne()
|
||||
{
|
||||
// Arrange
|
||||
var pairs = CreateTrainingPairs(3);
|
||||
|
||||
_decisionEngine.CompareAsync(
|
||||
Arg.Any<FunctionAnalysis>(),
|
||||
Arg.Any<FunctionAnalysis>(),
|
||||
Arg.Any<EnsembleOptions>(),
|
||||
Arg.Any<CancellationToken>())
|
||||
.Returns(Task.FromResult(new EnsembleResult
|
||||
{
|
||||
SourceFunctionId = "s",
|
||||
TargetFunctionId = "t",
|
||||
EnsembleScore = 0.9m,
|
||||
Contributions = ImmutableArray<SignalContribution>.Empty,
|
||||
IsMatch = true,
|
||||
Confidence = ConfidenceLevel.High
|
||||
}));
|
||||
|
||||
// Act
|
||||
var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.5m);
|
||||
|
||||
// Assert
|
||||
var sum = result.BestWeights.Syntactic + result.BestWeights.Semantic + result.BestWeights.Embedding;
|
||||
Assert.True(Math.Abs(sum - 1.0m) < 0.01m);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TuneWeightsAsync_WithInvalidStep_ThrowsException()
|
||||
{
|
||||
// Arrange
|
||||
var pairs = CreateTrainingPairs(1);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<ArgumentOutOfRangeException>(
|
||||
() => _service.TuneWeightsAsync(pairs, gridStep: 0));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TuneWeightsAsync_WithNoPairs_ThrowsException()
|
||||
{
|
||||
// Arrange
|
||||
var pairs = Array.Empty<EnsembleTrainingPair>();
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<ArgumentException>(
|
||||
() => _service.TuneWeightsAsync(pairs));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task EvaluateWeightsAsync_ComputesMetrics()
|
||||
{
|
||||
// Arrange
|
||||
var pairs = new List<EnsembleTrainingPair>
|
||||
{
|
||||
new()
|
||||
{
|
||||
Function1 = CreateAnalysis("f1"),
|
||||
Function2 = CreateAnalysis("f2"),
|
||||
IsEquivalent = true
|
||||
},
|
||||
new()
|
||||
{
|
||||
Function1 = CreateAnalysis("f3"),
|
||||
Function2 = CreateAnalysis("f4"),
|
||||
IsEquivalent = false
|
||||
}
|
||||
};
|
||||
|
||||
var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);
|
||||
|
||||
// Simulate decision engine returning matching for first pair
|
||||
_decisionEngine.CompareAsync(
|
||||
pairs[0].Function1,
|
||||
pairs[0].Function2,
|
||||
Arg.Any<EnsembleOptions>(),
|
||||
Arg.Any<CancellationToken>())
|
||||
.Returns(Task.FromResult(new EnsembleResult
|
||||
{
|
||||
SourceFunctionId = "f1",
|
||||
TargetFunctionId = "f2",
|
||||
EnsembleScore = 0.9m,
|
||||
Contributions = ImmutableArray<SignalContribution>.Empty,
|
||||
IsMatch = true,
|
||||
Confidence = ConfidenceLevel.High
|
||||
}));
|
||||
|
||||
// Non-matching for second pair
|
||||
_decisionEngine.CompareAsync(
|
||||
pairs[1].Function1,
|
||||
pairs[1].Function2,
|
||||
Arg.Any<EnsembleOptions>(),
|
||||
Arg.Any<CancellationToken>())
|
||||
.Returns(Task.FromResult(new EnsembleResult
|
||||
{
|
||||
SourceFunctionId = "f3",
|
||||
TargetFunctionId = "f4",
|
||||
EnsembleScore = 0.3m,
|
||||
Contributions = ImmutableArray<SignalContribution>.Empty,
|
||||
IsMatch = false,
|
||||
Confidence = ConfidenceLevel.Low
|
||||
}));
|
||||
|
||||
// Act
|
||||
var result = await _service.EvaluateWeightsAsync(weights, pairs);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(weights, result.Weights);
|
||||
Assert.Equal(1.0m, result.Accuracy); // Both predictions correct
|
||||
Assert.Equal(1.0m, result.Precision); // TP / (TP + FP) = 1 / 1
|
||||
Assert.Equal(1.0m, result.Recall); // TP / (TP + FN) = 1 / 1
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task EvaluateWeightsAsync_WithFalsePositive_LowersPrecision()
|
||||
{
|
||||
// Arrange
|
||||
var pairs = new List<EnsembleTrainingPair>
|
||||
{
|
||||
new()
|
||||
{
|
||||
Function1 = CreateAnalysis("f1"),
|
||||
Function2 = CreateAnalysis("f2"),
|
||||
IsEquivalent = false // Ground truth: NOT equivalent
|
||||
}
|
||||
};
|
||||
|
||||
var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);
|
||||
|
||||
// But engine says it IS a match (false positive)
|
||||
_decisionEngine.CompareAsync(
|
||||
Arg.Any<FunctionAnalysis>(),
|
||||
Arg.Any<FunctionAnalysis>(),
|
||||
Arg.Any<EnsembleOptions>(),
|
||||
Arg.Any<CancellationToken>())
|
||||
.Returns(Task.FromResult(new EnsembleResult
|
||||
{
|
||||
SourceFunctionId = "f1",
|
||||
TargetFunctionId = "f2",
|
||||
EnsembleScore = 0.9m,
|
||||
Contributions = ImmutableArray<SignalContribution>.Empty,
|
||||
IsMatch = true, // False positive!
|
||||
Confidence = ConfidenceLevel.High
|
||||
}));
|
||||
|
||||
// Act
|
||||
var result = await _service.EvaluateWeightsAsync(weights, pairs);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(0m, result.Accuracy); // 0 correct out of 1
|
||||
Assert.Equal(0m, result.Precision); // 0 true positives
|
||||
}
|
||||
|
||||
private static List<EnsembleTrainingPair> CreateTrainingPairs(int count)
|
||||
{
|
||||
var pairs = new List<EnsembleTrainingPair>();
|
||||
for (var i = 0; i < count; i++)
|
||||
{
|
||||
pairs.Add(new EnsembleTrainingPair
|
||||
{
|
||||
Function1 = CreateAnalysis($"func{i}a"),
|
||||
Function2 = CreateAnalysis($"func{i}b"),
|
||||
IsEquivalent = i % 2 == 0
|
||||
});
|
||||
}
|
||||
return pairs;
|
||||
}
|
||||
|
||||
private static FunctionAnalysis CreateAnalysis(string id)
|
||||
{
|
||||
return new FunctionAnalysis
|
||||
{
|
||||
FunctionId = id,
|
||||
FunctionName = id
|
||||
};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user