Files
git.stella-ops.org/src/BinaryIndex/__Tests/StellaOps.BinaryIndex.Ensemble.Tests/WeightTuningServiceTests.cs
StellaOps Bot 37e11918e0 save progress
2026-01-06 09:42:20 +02:00

239 lines
7.7 KiB
C#

// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
using Microsoft.Extensions.Logging.Abstractions;
using NSubstitute;
using StellaOps.BinaryIndex.Semantic;
using Xunit;
namespace StellaOps.BinaryIndex.Ensemble.Tests;
/// <summary>
/// Unit tests for <see cref="WeightTuningService"/>, covering grid-search tuning
/// (<c>TuneWeightsAsync</c>) and metric computation (<c>EvaluateWeightsAsync</c>).
/// The <see cref="IEnsembleDecisionEngine"/> dependency is an NSubstitute substitute,
/// so each test decides exactly what the engine reports for a compared pair.
/// </summary>
public class WeightTuningServiceTests
{
    private readonly IEnsembleDecisionEngine _decisionEngine;
    private readonly WeightTuningService _service;

    public WeightTuningServiceTests()
    {
        _decisionEngine = Substitute.For<IEnsembleDecisionEngine>();
        var logger = NullLogger<WeightTuningService>.Instance;
        _service = new WeightTuningService(_decisionEngine, logger);
    }

    /// <summary>
    /// Tuning with a 0.25 grid step returns non-negative weights and records at least
    /// one evaluation.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithValidPairs_ReturnsBestWeights()
    {
        // Arrange
        var pairs = CreateTrainingPairs(5);

        // The reported score is computed from the weights the service passes in, so each
        // grid point yields a different score. IsMatch is always true regardless of weights.
        _decisionEngine.CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(callInfo =>
            {
                var opts = callInfo.Arg<EnsembleOptions>();
                var score = opts.SyntacticWeight * 0.9m
                    + opts.SemanticWeight * 0.8m
                    + opts.EmbeddingWeight * 0.85m;
                return Task.FromResult(
                    CreateResult("s", "t", score, isMatch: true, confidence: ConfidenceLevel.High));
            });

        // Act
        var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.25m);

        // Assert
        Assert.NotNull(result);
        Assert.True(result.BestWeights.Syntactic >= 0);
        Assert.True(result.BestWeights.Semantic >= 0);
        Assert.True(result.BestWeights.Embedding >= 0);
        Assert.NotEmpty(result.Evaluations);
    }

    /// <summary>
    /// The best weights chosen by the tuner sum to 1.0, within a 0.01 tolerance.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WeightsSumToOne()
    {
        // Arrange
        var pairs = CreateTrainingPairs(3);
        _decisionEngine.CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("s", "t", 0.9m, isMatch: true, confidence: ConfidenceLevel.High)));

        // Act
        var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.5m);

        // Assert
        var sum = result.BestWeights.Syntactic + result.BestWeights.Semantic + result.BestWeights.Embedding;
        // Report the actual sum on failure instead of a bare "Expected True".
        Assert.True(Math.Abs(sum - 1.0m) < 0.01m, $"Expected weights to sum to 1.0 but got {sum}.");
    }

    /// <summary>
    /// A grid step of zero is rejected with <see cref="ArgumentOutOfRangeException"/>.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithInvalidStep_ThrowsException()
    {
        // Arrange
        var pairs = CreateTrainingPairs(1);

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentOutOfRangeException>(
            () => _service.TuneWeightsAsync(pairs, gridStep: 0));
    }

    /// <summary>
    /// An empty training set is rejected with <see cref="ArgumentException"/>.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithNoPairs_ThrowsException()
    {
        // Arrange
        var pairs = Array.Empty<EnsembleTrainingPair>();

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentException>(
            () => _service.TuneWeightsAsync(pairs));
    }

    /// <summary>
    /// With one true positive and one true negative, accuracy, precision and recall
    /// are all 1.0, and the evaluated weights are echoed back on the result.
    /// </summary>
    [Fact]
    public async Task EvaluateWeightsAsync_ComputesMetrics()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = true
            },
            new()
            {
                Function1 = CreateAnalysis("f3"),
                Function2 = CreateAnalysis("f4"),
                IsEquivalent = false
            }
        };
        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // Engine reports a match for the first (equivalent) pair -> true positive.
        _decisionEngine.CompareAsync(
                pairs[0].Function1,
                pairs[0].Function2,
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f1", "f2", 0.9m, isMatch: true, confidence: ConfidenceLevel.High)));

        // Engine reports no match for the second (non-equivalent) pair -> true negative.
        _decisionEngine.CompareAsync(
                pairs[1].Function1,
                pairs[1].Function2,
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f3", "f4", 0.3m, isMatch: false, confidence: ConfidenceLevel.Low)));

        // Act
        var result = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(weights, result.Weights);
        Assert.Equal(1.0m, result.Accuracy);  // Both predictions correct (2 / 2)
        Assert.Equal(1.0m, result.Precision); // TP / (TP + FP) = 1 / 1
        Assert.Equal(1.0m, result.Recall);    // TP / (TP + FN) = 1 / 1
    }

    /// <summary>
    /// A single false positive (engine says match, ground truth says not equivalent)
    /// drives both accuracy and precision to zero.
    /// </summary>
    [Fact]
    public async Task EvaluateWeightsAsync_WithFalsePositive_LowersPrecision()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = false // Ground truth: NOT equivalent
            }
        };
        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // But the engine says it IS a match (false positive).
        _decisionEngine.CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f1", "f2", 0.9m, isMatch: true, confidence: ConfidenceLevel.High)));

        // Act
        var result = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(0m, result.Accuracy);  // 0 correct out of 1
        Assert.Equal(0m, result.Precision); // 0 true positives, 1 false positive
    }

    /// <summary>
    /// Creates <paramref name="count"/> training pairs named <c>func{i}a</c>/<c>func{i}b</c>.
    /// Even-indexed pairs are marked equivalent and odd-indexed pairs are not.
    /// </summary>
    private static List<EnsembleTrainingPair> CreateTrainingPairs(int count)
    {
        var pairs = new List<EnsembleTrainingPair>();
        for (var i = 0; i < count; i++)
        {
            pairs.Add(new EnsembleTrainingPair
            {
                Function1 = CreateAnalysis($"func{i}a"),
                Function2 = CreateAnalysis($"func{i}b"),
                IsEquivalent = i % 2 == 0
            });
        }

        return pairs;
    }

    /// <summary>
    /// Creates a minimal <see cref="FunctionAnalysis"/> whose id and name are both <paramref name="id"/>.
    /// </summary>
    private static FunctionAnalysis CreateAnalysis(string id)
    {
        return new FunctionAnalysis
        {
            FunctionId = id,
            FunctionName = id
        };
    }

    /// <summary>
    /// Builds a canned <see cref="EnsembleResult"/> with no signal contributions. It is used
    /// as the substituted decision engine's response.
    /// </summary>
    /// <param name="sourceId">Value for <c>SourceFunctionId</c>.</param>
    /// <param name="targetId">Value for <c>TargetFunctionId</c>.</param>
    /// <param name="score">Value for <c>EnsembleScore</c>.</param>
    /// <param name="isMatch">Whether the engine reports the pair as a match.</param>
    /// <param name="confidence">Reported confidence level.</param>
    private static EnsembleResult CreateResult(
        string sourceId,
        string targetId,
        decimal score,
        bool isMatch,
        ConfidenceLevel confidence)
    {
        return new EnsembleResult
        {
            SourceFunctionId = sourceId,
            TargetFunctionId = targetId,
            EnsembleScore = score,
            Contributions = ImmutableArray<SignalContribution>.Empty,
            IsMatch = isMatch,
            Confidence = confidence
        };
    }
}