// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.

using System.Collections.Immutable;
using Microsoft.Extensions.Logging.Abstractions;
using NSubstitute;
using StellaOps.BinaryIndex.Semantic;
using Xunit;

namespace StellaOps.BinaryIndex.Ensemble.Tests;

// NOTE(review): the generic type arguments in this file were reconstructed.
//   * Derived directly from the code in this file: IEnsembleDecisionEngine (field type),
//     FunctionAnalysis (return type of CreateAnalysis), EnsembleTrainingPair (type built in
//     CreateTrainingPairs), WeightTuningService (logger category).
//   * Inferred, confirm against the production assembly: EnsembleOptions (third argument of
//     CompareAsync), CancellationToken (fourth argument of CompareAsync), SignalContribution
//     (element type of EnsembleResult.Contributions), and the exception types asserted by the
//     two *_ThrowsException tests.

/// <summary>
/// Unit tests for <see cref="WeightTuningService"/>, exercised against an NSubstitute
/// stand-in for <see cref="IEnsembleDecisionEngine"/>.
/// </summary>
public class WeightTuningServiceTests
{
    private readonly IEnsembleDecisionEngine _decisionEngine;
    private readonly WeightTuningService _service;

    /// <summary>
    /// Builds a fresh substitute engine and service instance for every test (xUnit creates
    /// one class instance per test method).
    /// </summary>
    public WeightTuningServiceTests()
    {
        _decisionEngine = Substitute.For<IEnsembleDecisionEngine>();

        // NullLogger<T> implements both ILogger<T> and ILogger, so it satisfies either
        // constructor parameter type.
        var logger = NullLogger<WeightTuningService>.Instance;
        _service = new WeightTuningService(_decisionEngine, logger);
    }

    /// <summary>
    /// A grid search over valid training pairs yields non-negative weights and a non-empty
    /// list of evaluations.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithValidPairs_ReturnsBestWeights()
    {
        // Arrange
        var pairs = CreateTrainingPairs(5);

        _decisionEngine.CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(callInfo =>
            {
                // The returned score depends on the weights passed in, so different grid
                // points produce different scores.
                var opts = callInfo.Arg<EnsembleOptions>();
                var score = opts.SyntacticWeight * 0.9m
                    + opts.SemanticWeight * 0.8m
                    + opts.EmbeddingWeight * 0.85m;

                return Task.FromResult(
                    CreateResult("s", "t", score, isMatch: true, ConfidenceLevel.High));
            });

        // Act
        var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.25m);

        // Assert
        Assert.NotNull(result);
        Assert.True(result.BestWeights.Syntactic >= 0);
        Assert.True(result.BestWeights.Semantic >= 0);
        Assert.True(result.BestWeights.Embedding >= 0);
        Assert.NotEmpty(result.Evaluations);
    }

    /// <summary>
    /// The best weights returned by tuning sum to one (within a 0.01 tolerance).
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WeightsSumToOne()
    {
        // Arrange
        var pairs = CreateTrainingPairs(3);

        _decisionEngine.CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("s", "t", 0.9m, isMatch: true, ConfidenceLevel.High)));

        // Act
        var result = await _service.TuneWeightsAsync(pairs, gridStep: 0.5m);

        // Assert
        var sum = result.BestWeights.Syntactic
            + result.BestWeights.Semantic
            + result.BestWeights.Embedding;

        Assert.True(
            Math.Abs(sum - 1.0m) < 0.01m,
            $"Expected the best weights to sum to 1, but the sum was {sum}.");
    }

    /// <summary>
    /// A grid step of zero is rejected.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithInvalidStep_ThrowsException()
    {
        // Arrange
        var pairs = CreateTrainingPairs(1);

        // Act & Assert
        // NOTE(review): the exact exception type could not be recovered; ThrowsAnyAsync accepts
        // ArgumentException and anything derived from it (e.g. ArgumentOutOfRangeException).
        // Tighten to Assert.ThrowsAsync<TExact> once the service contract is confirmed.
        await Assert.ThrowsAnyAsync<ArgumentException>(
            () => _service.TuneWeightsAsync(pairs, gridStep: 0));
    }

    /// <summary>
    /// An empty training set is rejected.
    /// </summary>
    [Fact]
    public async Task TuneWeightsAsync_WithNoPairs_ThrowsException()
    {
        // Arrange
        var pairs = Array.Empty<EnsembleTrainingPair>();

        // Act & Assert
        // NOTE(review): exception type inferred; see the note on the invalid-step test.
        await Assert.ThrowsAnyAsync<ArgumentException>(
            () => _service.TuneWeightsAsync(pairs));
    }

    /// <summary>
    /// With one correctly predicted match and one correctly predicted non-match, accuracy,
    /// precision and recall are all 1.
    /// </summary>
    [Fact]
    public async Task EvaluateWeightsAsync_ComputesMetrics()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = true
            },
            new()
            {
                Function1 = CreateAnalysis("f3"),
                Function2 = CreateAnalysis("f4"),
                IsEquivalent = false
            }
        };

        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // Simulate decision engine returning matching for first pair
        _decisionEngine.CompareAsync(
                pairs[0].Function1,
                pairs[0].Function2,
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f1", "f2", 0.9m, isMatch: true, ConfidenceLevel.High)));

        // Non-matching for second pair
        _decisionEngine.CompareAsync(
                pairs[1].Function1,
                pairs[1].Function2,
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f3", "f4", 0.3m, isMatch: false, ConfidenceLevel.Low)));

        // Act
        var result = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(weights, result.Weights);
        Assert.Equal(1.0m, result.Accuracy);  // Both predictions correct
        Assert.Equal(1.0m, result.Precision); // TP / (TP + FP) = 1 / 1
        Assert.Equal(1.0m, result.Recall);    // TP / (TP + FN) = 1 / 1
    }

    /// <summary>
    /// A single false positive drives both accuracy and precision to zero.
    /// </summary>
    [Fact]
    public async Task EvaluateWeightsAsync_WithFalsePositive_LowersPrecision()
    {
        // Arrange
        var pairs = new List<EnsembleTrainingPair>
        {
            new()
            {
                Function1 = CreateAnalysis("f1"),
                Function2 = CreateAnalysis("f2"),
                IsEquivalent = false // Ground truth: NOT equivalent
            }
        };

        var weights = new EffectiveWeights(0.33m, 0.33m, 0.34m);

        // But engine says it IS a match (false positive)
        _decisionEngine.CompareAsync(
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<FunctionAnalysis>(),
                Arg.Any<EnsembleOptions>(),
                Arg.Any<CancellationToken>())
            .Returns(Task.FromResult(
                CreateResult("f1", "f2", 0.9m, isMatch: true, ConfidenceLevel.High)));

        // Act
        var result = await _service.EvaluateWeightsAsync(weights, pairs);

        // Assert
        Assert.Equal(0m, result.Accuracy);  // 0 correct out of 1
        Assert.Equal(0m, result.Precision); // 0 true positives
    }

    /// <summary>
    /// Creates <paramref name="count"/> training pairs named <c>func{i}a</c>/<c>func{i}b</c>;
    /// pairs at even indices are labelled equivalent, odd indices non-equivalent.
    /// </summary>
    private static List<EnsembleTrainingPair> CreateTrainingPairs(int count)
    {
        var pairs = new List<EnsembleTrainingPair>();

        for (var i = 0; i < count; i++)
        {
            pairs.Add(new EnsembleTrainingPair
            {
                Function1 = CreateAnalysis($"func{i}a"),
                Function2 = CreateAnalysis($"func{i}b"),
                IsEquivalent = i % 2 == 0
            });
        }

        return pairs;
    }

    /// <summary>
    /// Creates a <see cref="FunctionAnalysis"/> whose id and name are both <paramref name="id"/>.
    /// </summary>
    private static FunctionAnalysis CreateAnalysis(string id)
    {
        return new FunctionAnalysis
        {
            FunctionId = id,
            FunctionName = id
        };
    }

    /// <summary>
    /// Creates a canned <see cref="EnsembleResult"/> with no per-signal contributions, for use
    /// as a stubbed return value of the decision engine.
    /// </summary>
    private static EnsembleResult CreateResult(
        string sourceId,
        string targetId,
        decimal score,
        bool isMatch,
        ConfidenceLevel confidence)
    {
        return new EnsembleResult
        {
            SourceFunctionId = sourceId,
            TargetFunctionId = targetId,
            EnsembleScore = score,
            // NOTE(review): element type inferred; confirm against EnsembleResult.Contributions.
            Contributions = ImmutableArray<SignalContribution>.Empty,
            IsMatch = isMatch,
            Confidence = confidence
        };
    }
}