Files
git.stella-ops.org/src/BinaryIndex/__Libraries/StellaOps.BinaryIndex.Diff/FunctionDiffer.cs
2026-02-01 21:37:40 +02:00

332 lines
11 KiB
C#

// Licensed under BUSL-1.1. Copyright (C) 2026 StellaOps Contributors.
using StellaOps.BinaryIndex.Analysis;
using System.Collections.Immutable;
namespace StellaOps.BinaryIndex.Diff;
/// <summary>
/// Compares function fingerprints between pre and post binaries.
/// </summary>
/// <remarks>
/// For a single function this combines a CFG summary (hashes, block and edge counts), per-block
/// opcode-hash diffs, a vulnerable-edge diff delegated to <see cref="IEdgeComparator"/>, a
/// simplified sink-reachability diff, an optional block-hash similarity score, and a
/// function-level <see cref="FunctionPatchVerdict"/>.
/// </remarks>
internal sealed class FunctionDiffer : IFunctionDiffer
{
    /// <summary>Separator between the source and target block IDs in an edge pattern.</summary>
    private const string EdgeSeparator = "->";

    /// <summary>
    /// When the CFG structure changed and the absolute basic-block count delta is strictly
    /// greater than this value, the change is treated as a partial fix even without edge/sink evidence.
    /// </summary>
    private const int SignificantBlockCountDelta = 2;

    private readonly IEdgeComparator _edgeComparator;

    /// <summary>Initializes a new <see cref="FunctionDiffer"/>.</summary>
    /// <param name="edgeComparator">Comparator used to diff golden-set edge patterns against extracted edges.</param>
    /// <exception cref="ArgumentNullException"><paramref name="edgeComparator"/> is <see langword="null"/>.</exception>
    public FunctionDiffer(IEdgeComparator edgeComparator)
    {
        ArgumentNullException.ThrowIfNull(edgeComparator);
        _edgeComparator = edgeComparator;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <c>FunctionDiffResult.NotFound</c> when neither fingerprint exists and
    /// <c>FunctionDiffResult.FunctionRemoved</c> when only the pre fingerprint exists.
    /// In every other case the post fingerprint exists and the pre fingerprint may be absent
    /// (function introduced by the post binary).
    /// </remarks>
    public FunctionDiffResult Compare(
        string functionName,
        FunctionFingerprint? preFingerprint,
        FunctionFingerprint? postFingerprint,
        FunctionSignature signature,
        DiffOptions options)
    {
        // Handle missing functions
        if (preFingerprint is null && postFingerprint is null)
        {
            return FunctionDiffResult.NotFound(functionName);
        }

        if (preFingerprint is not null && postFingerprint is null)
        {
            return FunctionDiffResult.FunctionRemoved(functionName);
        }

        var preStatus = preFingerprint is not null ? FunctionStatus.Present : FunctionStatus.Absent;
        var postStatus = postFingerprint is not null ? FunctionStatus.Present : FunctionStatus.Absent;

        // Parsed once and shared by the block diff and the reachability heuristic.
        var vulnerableBlocks = ExtractVulnerableBlockIds(signature.EdgePatterns);

        // A CFG-level comparison only makes sense when both versions exist.
        CfgDiffResult? cfgDiff = preFingerprint is not null && postFingerprint is not null
            ? BuildCfgDiff(preFingerprint, postFingerprint)
            : null;

        var blockDiffs = BuildBlockDiffs(preFingerprint, postFingerprint, vulnerableBlocks);

        // Compare vulnerable edges (EdgePatterns in signature) against edges rebuilt from successors.
        var preEdges = ExtractEdges(preFingerprint);
        var postEdges = ExtractEdges(postFingerprint);
        var edgeDiff = _edgeComparator.Compare(signature.EdgePatterns, preEdges, postEdges);

        // Simplified: no full reachability analysis, see ComputeSimplifiedReachability.
        var reachabilityDiff = ComputeSimplifiedReachability(signature, vulnerableBlocks, preFingerprint, postFingerprint);

        decimal? semanticSimilarity =
            options.IncludeSemanticAnalysis && preFingerprint is not null && postFingerprint is not null
                ? ComputeSemanticSimilarity(preFingerprint, postFingerprint)
                : null;

        var verdict = DetermineVerdict(edgeDiff, reachabilityDiff, cfgDiff, preStatus, postStatus);

        return new FunctionDiffResult
        {
            FunctionName = functionName,
            PreStatus = preStatus,
            PostStatus = postStatus,
            CfgDiff = cfgDiff,
            BlockDiffs = blockDiffs,
            EdgeDiff = edgeDiff,
            ReachabilityDiff = reachabilityDiff,
            SemanticSimilarity = semanticSimilarity,
            Verdict = verdict
        };
    }

    /// <summary>
    /// Summarizes both CFGs: their hashes, block counts, and edge counts.
    /// Edge count is the total number of successor entries across all basic blocks.
    /// </summary>
    private static CfgDiffResult BuildCfgDiff(
        FunctionFingerprint pre,
        FunctionFingerprint post)
    {
        var preEdgeCount = pre.BasicBlockHashes.Sum(b => b.Successors.Length);
        var postEdgeCount = post.BasicBlockHashes.Sum(b => b.Successors.Length);

        return new CfgDiffResult
        {
            PreCfgHash = pre.CfgHash,
            PostCfgHash = post.CfgHash,
            PreBlockCount = pre.BasicBlockHashes.Length,
            PostBlockCount = post.BasicBlockHashes.Length,
            PreEdgeCount = preEdgeCount,
            PostEdgeCount = postEdgeCount
        };
    }

    /// <summary>
    /// Produces one <see cref="BlockDiffResult"/> per block ID seen in either fingerprint,
    /// ordered by block ID (ordinal).
    /// </summary>
    /// <param name="pre">Pre fingerprint, or <see langword="null"/> if absent.</param>
    /// <param name="post">Post fingerprint, or <see langword="null"/> if absent.</param>
    /// <param name="vulnerableBlocks">Block IDs referenced by the signature's edge patterns.</param>
    /// <returns>Empty when both fingerprints are absent.</returns>
    private static ImmutableArray<BlockDiffResult> BuildBlockDiffs(
        FunctionFingerprint? pre,
        FunctionFingerprint? post,
        IReadOnlySet<string> vulnerableBlocks)
    {
        if (pre is null && post is null)
        {
            return [];
        }

        // BlockId -> OpcodeHash. Note: ToDictionary throws if a fingerprint repeats a BlockId.
        var preBlocks = pre?.BasicBlockHashes.ToDictionary(
            b => b.BlockId,
            b => b.OpcodeHash,
            StringComparer.Ordinal) ?? [];
        var postBlocks = post?.BasicBlockHashes.ToDictionary(
            b => b.BlockId,
            b => b.OpcodeHash,
            StringComparer.Ordinal) ?? [];

        // Sorting the (unique) IDs up front yields the results already in block-ID order.
        var allBlockIds = preBlocks.Keys
            .Union(postBlocks.Keys, StringComparer.Ordinal)
            .OrderBy(id => id, StringComparer.Ordinal);

        var results = new List<BlockDiffResult>();
        foreach (var blockId in allBlockIds)
        {
            var existsInPre = preBlocks.TryGetValue(blockId, out var preHash);
            var existsInPost = postBlocks.TryGetValue(blockId, out var postHash);

            results.Add(new BlockDiffResult
            {
                BlockId = blockId,
                ExistsInPre = existsInPre,
                ExistsInPost = existsInPost,
                IsVulnerablePath = vulnerableBlocks.Contains(blockId),
                // Only meaningful when the block exists on both sides.
                HashChanged = existsInPre && existsInPost && !string.Equals(preHash, postHash, StringComparison.Ordinal),
                PreHash = preHash,
                PostHash = postHash
            });
        }

        return [.. results];
    }

    /// <summary>
    /// Collects the block IDs on either side of each <c>from-&gt;to</c> edge pattern.
    /// Whitespace around each ID is trimmed; patterns that do not split into exactly two parts are ignored.
    /// </summary>
    private static HashSet<string> ExtractVulnerableBlockIds(IEnumerable<string> edgePatterns)
    {
        var blockIds = new HashSet<string>(StringComparer.Ordinal);
        foreach (var edge in edgePatterns)
        {
            var parts = edge.Split(EdgeSeparator, StringSplitOptions.TrimEntries);
            if (parts.Length == 2)
            {
                blockIds.Add(parts[0]);
                blockIds.Add(parts[1]);
            }
        }

        return blockIds;
    }

    /// <summary>
    /// Rebuilds the edge list of a fingerprint as <c>"{BlockId}-&gt;{successor}"</c> strings,
    /// one per successor entry of each basic block.
    /// </summary>
    /// <returns>Empty when <paramref name="fingerprint"/> is <see langword="null"/>.</returns>
    private static ImmutableArray<string> ExtractEdges(FunctionFingerprint? fingerprint)
    {
        if (fingerprint is null)
        {
            return [];
        }

        var edges = new List<string>();
        foreach (var block in fingerprint.BasicBlockHashes)
        {
            foreach (var succ in block.Successors)
            {
                edges.Add($"{block.BlockId}{EdgeSeparator}{succ}");
            }
        }

        return [.. edges];
    }

    /// <summary>
    /// Simplified sink reachability: all signature sinks are considered reachable in a binary
    /// if any of its basic blocks has an ID referenced by the signature's edge patterns.
    /// Full reachability analysis requires ReachGraph integration.
    /// </summary>
    /// <remarks>
    /// NOTE(review): when the signature has sinks but no parseable edge patterns, both sides
    /// report no reachable sinks; how <c>SinkReachabilityDiff.Compute</c> classifies that
    /// case is defined outside this file — confirm it does not read as "fixed".
    /// </remarks>
    private static SinkReachabilityDiff ComputeSimplifiedReachability(
        FunctionSignature signature,
        IReadOnlySet<string> vulnerableBlocks,
        FunctionFingerprint? pre,
        FunctionFingerprint? post)
    {
        if (signature.Sinks.IsEmpty)
        {
            return SinkReachabilityDiff.Empty;
        }

        var preReachable = new List<string>();
        var postReachable = new List<string>();

        if (ContainsAnyBlock(pre, vulnerableBlocks))
        {
            preReachable.AddRange(signature.Sinks);
        }

        if (ContainsAnyBlock(post, vulnerableBlocks))
        {
            postReachable.AddRange(signature.Sinks);
        }

        return SinkReachabilityDiff.Compute(
            [.. preReachable],
            [.. postReachable]);
    }

    /// <summary>True when <paramref name="fingerprint"/> exists and has a block whose ID is in <paramref name="blockIds"/>.</summary>
    private static bool ContainsAnyBlock(FunctionFingerprint? fingerprint, IReadOnlySet<string> blockIds) =>
        fingerprint is not null && fingerprint.BasicBlockHashes.Any(b => blockIds.Contains(b.BlockId));

    /// <summary>
    /// Jaccard similarity of the distinct basic-block opcode hashes of the two functions
    /// (|intersection| / |union|), in the range [0, 1].
    /// Full semantic analysis would use embeddings.
    /// </summary>
    /// <returns><c>0</c> when either function has no basic blocks.</returns>
    private static decimal ComputeSemanticSimilarity(
        FunctionFingerprint pre,
        FunctionFingerprint post)
    {
        if (pre.BasicBlockHashes.IsEmpty || post.BasicBlockHashes.IsEmpty)
        {
            return 0m;
        }

        var preHashes = pre.BasicBlockHashes.Select(b => b.OpcodeHash).ToHashSet(StringComparer.Ordinal);
        var postHashes = post.BasicBlockHashes.Select(b => b.OpcodeHash).ToHashSet(StringComparer.Ordinal);

        // Both sets are non-empty here, so the union is at least 1.
        var intersection = preHashes.Count(postHashes.Contains);
        var union = preHashes.Count + postHashes.Count - intersection;

        return (decimal)intersection / union;
    }

    /// <summary>
    /// Derives the function-level verdict. Rules are evaluated in order; the first match wins:
    /// removed function, absent in both, all vulnerable edges removed, all sinks unreachable,
    /// partial edge/sink evidence, significant CFG restructuring, no change, otherwise inconclusive.
    /// </summary>
    private static FunctionPatchVerdict DetermineVerdict(
        VulnerableEdgeDiff edgeDiff,
        SinkReachabilityDiff reachabilityDiff,
        CfgDiffResult? cfgDiff,
        FunctionStatus preStatus,
        FunctionStatus postStatus)
    {
        // Function removed
        if (preStatus == FunctionStatus.Present && postStatus == FunctionStatus.Absent)
        {
            return FunctionPatchVerdict.FunctionRemoved;
        }

        // Function not found in either
        if (preStatus == FunctionStatus.Absent && postStatus == FunctionStatus.Absent)
        {
            return FunctionPatchVerdict.Inconclusive;
        }

        if (edgeDiff.AllVulnerableEdgesRemoved)
        {
            return FunctionPatchVerdict.Fixed;
        }

        if (reachabilityDiff.AllSinksUnreachable)
        {
            return FunctionPatchVerdict.Fixed;
        }

        if (edgeDiff.SomeVulnerableEdgesRemoved || reachabilityDiff.SomeSinksUnreachable)
        {
            return FunctionPatchVerdict.PartialFix;
        }

        // CFG structure changed significantly
        if (cfgDiff?.StructureChanged == true &&
            Math.Abs(cfgDiff.BlockCountDelta) > SignificantBlockCountDelta)
        {
            return FunctionPatchVerdict.PartialFix;
        }

        // No significant change detected
        if (edgeDiff.NoChange && cfgDiff?.StructureChanged != true)
        {
            return FunctionPatchVerdict.StillVulnerable;
        }

        return FunctionPatchVerdict.Inconclusive;
    }
}
/// <summary>
/// Compares vulnerable edges between binaries.
/// </summary>
/// <remarks>
/// Edges are strings of the form <c>from-&gt;to</c>. Whitespace around the two block IDs is
/// ignored when matching, so a golden-set pattern written as <c>"bb1 -&gt; bb2"</c> matches an
/// extracted edge <c>"bb1-&gt;bb2"</c>. Strings that do not split into exactly two parts are
/// compared verbatim (ordinal).
/// </remarks>
internal sealed class EdgeComparator : IEdgeComparator
{
    private const string EdgeSeparator = "->";

    /// <inheritdoc />
    /// <remarks>
    /// A <see langword="default"/> <see cref="ImmutableArray{T}"/> argument is treated as empty.
    /// The returned diff is built from the (normalized) pre/post edges that appear in the golden set,
    /// preserving their input order and multiplicity.
    /// </remarks>
    public VulnerableEdgeDiff Compare(
        ImmutableArray<string> goldenSetEdges,
        ImmutableArray<string> preEdges,
        ImmutableArray<string> postEdges)
    {
        // Find which golden set edges are present in each binary
        var goldenSet = Normalize(goldenSetEdges).ToHashSet(StringComparer.Ordinal);
        var vulnerableInPre = Normalize(preEdges).Where(goldenSet.Contains).ToImmutableArray();
        var vulnerableInPost = Normalize(postEdges).Where(goldenSet.Contains).ToImmutableArray();

        return VulnerableEdgeDiff.Compute(vulnerableInPre, vulnerableInPost);
    }

    /// <summary>Normalizes every edge, treating a default/empty array as an empty sequence.</summary>
    private static IEnumerable<string> Normalize(ImmutableArray<string> edges) =>
        edges.IsDefaultOrEmpty ? Enumerable.Empty<string>() : edges.Select(NormalizeEdge);

    /// <summary>
    /// Rewrites <c>" a -&gt; b "</c> as <c>"a-&gt;b"</c>; returns the input unchanged when it is
    /// null/empty or does not contain exactly one separator.
    /// </summary>
    private static string NormalizeEdge(string edge)
    {
        if (string.IsNullOrEmpty(edge))
        {
            return edge;
        }

        var parts = edge.Split(EdgeSeparator, StringSplitOptions.TrimEntries);
        return parts.Length == 2
            ? string.Concat(parts[0], EdgeSeparator, parts[1])
            : edge;
    }
}