// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
|
|
|
|
using System.Collections.Immutable;
|
|
using Microsoft.Extensions.Logging.Abstractions;
|
|
using NSubstitute;
|
|
using StellaOps.BinaryIndex.Semantic;
|
|
using Xunit;
|
|
|
|
namespace StellaOps.BinaryIndex.Ensemble.Tests;
|
|
|
|
/// <summary>
/// Unit tests for <see cref="WeightTuningService"/>. The service is built around a
/// substituted <see cref="IEnsembleDecisionEngine"/>, so every comparison result seen by
/// the service is scripted by the individual test.
/// </summary>
public class WeightTuningServiceTests
{
    private readonly IEnsembleDecisionEngine _decisionEngine;
    private readonly WeightTuningService _service;

    /// <summary>
    /// Creates a fresh engine substitute and a service wired to it with a no-op logger.
    /// </summary>
    public WeightTuningServiceTests()
    {
        _decisionEngine = Substitute.For<IEnsembleDecisionEngine>();
        _service = new WeightTuningService(_decisionEngine, NullLogger<WeightTuningService>.Instance);
    }

    /// <summary>
    /// Tuning over five training pairs with a step of 0.25 yields a non-null result whose
    /// best weights are all non-negative and whose evaluation list is not empty.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithValidPairs_ReturnsBestWeights()
    {
        // Arrange
        var trainingPairs = CreateTrainingPairs(5);

        // The scripted score depends on the weights the service passes in, so different
        // weight combinations produce different ensemble scores.
        _decisionEngine
            .CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(call =>
            {
                var options = call.Arg<EnsembleOptions>();
                var weightedScore =
                    options.SyntacticWeight * 0.9m
                    + options.SemanticWeight * 0.8m
                    + options.EmbeddingWeight * 0.85m;

                return Task.FromResult(
                    CreateResult("s", "t", weightedScore, isMatch: true, confidence: ConfidenceLevel.High));
            });

        // Act
        var tuning = await _service.TuneWeightsAsync(trainingPairs, gridStep: 0.25m);

        // Assert
        Assert.NotNull(tuning);
        Assert.True(tuning.BestWeights.Syntactic >= 0);
        Assert.True(tuning.BestWeights.Semantic >= 0);
        Assert.True(tuning.BestWeights.Embedding >= 0);
        Assert.NotEmpty(tuning.Evaluations);
    }

    /// <summary>
    /// The three components of the best weights add up to 1.0 within a 0.01 tolerance.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WeightsSumToOne()
    {
        // Arrange
        var trainingPairs = CreateTrainingPairs(3);
        StubAnyComparison(
            CreateResult("s", "t", 0.9m, isMatch: true, confidence: ConfidenceLevel.High));

        // Act
        var tuning = await _service.TuneWeightsAsync(trainingPairs, gridStep: 0.5m);

        // Assert
        var best = tuning.BestWeights;
        var total = best.Syntactic + best.Semantic + best.Embedding;
        var deviation = Math.Abs(total - 1.0m);
        Assert.True(deviation < 0.01m);
    }

    /// <summary>
    /// A grid step of zero is rejected with <see cref="ArgumentOutOfRangeException"/>.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithInvalidStep_ThrowsException()
    {
        // Arrange
        var trainingPairs = CreateTrainingPairs(1);
        Func<Task> tuneWithZeroStep = () => _service.TuneWeightsAsync(trainingPairs, gridStep: 0);

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentOutOfRangeException>(tuneWithZeroStep);
    }

    /// <summary>
    /// An empty set of training pairs is rejected with <see cref="ArgumentException"/>.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithNoPairs_ThrowsException()
    {
        // Arrange
        var noPairs = Array.Empty<EnsembleTrainingPair>();
        Func<Task> tuneWithoutPairs = () => _service.TuneWeightsAsync(noPairs);

        // Act & Assert
        await Assert.ThrowsAsync<ArgumentException>(tuneWithoutPairs);
    }

    /// <summary>
    /// With one equivalent pair scripted as a match and one non-equivalent pair scripted
    /// as a non-match, the evaluation echoes the weights and reports accuracy, precision
    /// and recall of 1.0.
    /// </summary>
    [Fact]
    public async Task EvaluateWeightsAsync_ComputesMetrics()
    {
        // Arrange
        var equivalentPair = new EnsembleTrainingPair
        {
            Function1 = CreateAnalysis("f1"),
            Function2 = CreateAnalysis("f2"),
            IsEquivalent = true
        };
        var distinctPair = new EnsembleTrainingPair
        {
            Function1 = CreateAnalysis("f3"),
            Function2 = CreateAnalysis("f4"),
            IsEquivalent = false
        };
        var pairs = new List<EnsembleTrainingPair> { equivalentPair, distinctPair };

        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // Simulate decision engine returning matching for first pair
        _decisionEngine
            .CompareAsync(
                equivalentPair.Function1,
                equivalentPair.Function2,
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f1", "f2", 0.9m, isMatch: true, confidence: ConfidenceLevel.High)));

        // Non-matching for second pair
        _decisionEngine
            .CompareAsync(
                distinctPair.Function1,
                distinctPair.Function2,
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f3", "f4", 0.3m, isMatch: false, confidence: ConfidenceLevel.Low)));

        // Act
        var evaluation = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(weights, evaluation.Weights);
        Assert.Equal(1.0m, evaluation.Accuracy);  // Both predictions correct
        Assert.Equal(1.0m, evaluation.Precision); // TP / (TP + FP) = 1 / 1
        Assert.Equal(1.0m, evaluation.Recall);    // TP / (TP + FN) = 1 / 1
    }

    /// <summary>
    /// A single non-equivalent pair that the engine reports as a match (a false positive)
    /// results in accuracy and precision of 0.
    /// </summary>
    [Fact]
    public async Task EvaluateWeightsAsync_WithFalsePositive_LowersPrecision()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = false // Ground truth: NOT equivalent
            }
        };

        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // But engine says it IS a match (false positive)
        StubAnyComparison(
            CreateResult("f1", "f2", 0.9m, isMatch: true, confidence: ConfidenceLevel.High));

        // Act
        var evaluation = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(0m, evaluation.Accuracy);  // 0 correct out of 1
        Assert.Equal(0m, evaluation.Precision); // 0 true positives
    }

    /// <summary>
    /// Makes the substituted engine return <paramref name="result"/> for every comparison,
    /// regardless of the functions, options or cancellation token supplied.
    /// </summary>
    private void StubAnyComparison(EnsembleResult result) =>
        _decisionEngine
            .CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(result));

    /// <summary>
    /// Builds an <see cref="EnsembleResult"/> with the given identifiers, score, match flag
    /// and confidence, and with an empty contribution list.
    /// </summary>
    private static EnsembleResult CreateResult(
        string sourceId,
        string targetId,
        decimal score,
        bool isMatch,
        ConfidenceLevel confidence)
    {
        return new EnsembleResult
        {
            SourceFunctionId = sourceId,
            TargetFunctionId = targetId,
            EnsembleScore = score,
            Contributions = ImmutableArray<SignalContribution>.Empty,
            IsMatch = isMatch,
            Confidence = confidence
        };
    }

    /// <summary>
    /// Builds <paramref name="count"/> training pairs named <c>func{i}a</c>/<c>func{i}b</c>;
    /// pairs at even indices are labelled equivalent, pairs at odd indices are not.
    /// </summary>
    private static List<EnsembleTrainingPair> CreateTrainingPairs(int count)
    {
        var generated = new List<EnsembleTrainingPair>();

        for (var index = 0; index < count; index++)
        {
            var isEvenIndex = index % 2 == 0;
            var pair = new EnsembleTrainingPair
            {
                Function1 = CreateAnalysis($"func{index}a"),
                Function2 = CreateAnalysis($"func{index}b"),
                IsEquivalent = isEvenIndex
            };

            generated.Add(pair);
        }

        return generated;
    }

    /// <summary>
    /// Builds a <see cref="FunctionAnalysis"/> that uses <paramref name="id"/> as both its
    /// function id and its function name.
    /// </summary>
    private static FunctionAnalysis CreateAnalysis(string id) =>
        new FunctionAnalysis
        {
            FunctionId = id,
            FunctionName = id
        };
}
|