332 lines
11 KiB
C#
332 lines
11 KiB
C#
// Licensed under BUSL-1.1. Copyright (C) 2026 StellaOps Contributors.
|
|
|
|
|
|
using StellaOps.BinaryIndex.Analysis;
|
|
using System.Collections.Immutable;
|
|
|
|
namespace StellaOps.BinaryIndex.Diff;
|
|
|
|
/// <summary>
/// Compares function fingerprints between pre and post binaries.
/// </summary>
/// <remarks>
/// For a single function the differ combines a CFG size summary, per-basic-block hash diffs,
/// golden-set ("vulnerable") edge presence as judged by the injected <see cref="IEdgeComparator"/>,
/// a simplified sink-reachability estimate, an optional block-hash similarity score and a
/// function-level <see cref="FunctionPatchVerdict"/>.
/// Edge patterns are expected in the form <c>"fromBlock-&gt;toBlock"</c>; whitespace around the
/// arrow is tolerated when block IDs are extracted from them.
/// </remarks>
internal sealed class FunctionDiffer : IFunctionDiffer
{
    /// <summary>Separator between the source and target block IDs of an edge string.</summary>
    private const string EdgeSeparator = "->";

    /// <summary>
    /// When the CFG structure changed, an absolute block-count delta strictly greater than this
    /// value is treated as a significant change (verdict leans towards a partial fix).
    /// </summary>
    private const int SignificantBlockCountDeltaThreshold = 2;

    private readonly IEdgeComparator _edgeComparator;

    /// <summary>
    /// Creates a differ that delegates vulnerable-edge matching to <paramref name="edgeComparator"/>.
    /// </summary>
    /// <param name="edgeComparator">Comparator used to match golden-set edges against extracted edges.</param>
    /// <exception cref="ArgumentNullException"><paramref name="edgeComparator"/> is <see langword="null"/>.</exception>
    public FunctionDiffer(IEdgeComparator edgeComparator)
    {
        ArgumentNullException.ThrowIfNull(edgeComparator);
        _edgeComparator = edgeComparator;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <c>FunctionDiffResult.NotFound</c> when the function is absent from both binaries and
    /// <c>FunctionDiffResult.FunctionRemoved</c> when it exists only in the pre binary; otherwise a
    /// full diff (CFG, blocks, edges, reachability, optional similarity, verdict) is built.
    /// </remarks>
    public FunctionDiffResult Compare(
        string functionName,
        FunctionFingerprint? preFingerprint,
        FunctionFingerprint? postFingerprint,
        FunctionSignature signature,
        DiffOptions options)
    {
        // Handle missing functions
        if (preFingerprint is null && postFingerprint is null)
        {
            return FunctionDiffResult.NotFound(functionName);
        }

        if (preFingerprint is not null && postFingerprint is null)
        {
            return FunctionDiffResult.FunctionRemoved(functionName);
        }

        var preStatus = preFingerprint is not null ? FunctionStatus.Present : FunctionStatus.Absent;
        var postStatus = postFingerprint is not null ? FunctionStatus.Present : FunctionStatus.Absent;

        // Block IDs referenced by the signature's edge patterns. Parsed once and shared by the block
        // diff and the reachability estimate so both use exactly the same interpretation.
        var vulnerableBlocks = ParseVulnerableBlocks(signature.EdgePatterns);

        // A CFG diff only makes sense when both sides exist.
        CfgDiffResult? cfgDiff = null;
        if (preFingerprint is not null && postFingerprint is not null)
        {
            cfgDiff = BuildCfgDiff(preFingerprint, postFingerprint);
        }

        var blockDiffs = BuildBlockDiffs(preFingerprint, postFingerprint, vulnerableBlocks);

        // Compare vulnerable edges (EdgePatterns in signature) against edges derived from successors.
        var edgeDiff = _edgeComparator.Compare(
            signature.EdgePatterns,
            ExtractEdges(preFingerprint),
            ExtractEdges(postFingerprint));

        // Simplified: no full reachability analysis is performed here.
        var reachabilityDiff = ComputeSimplifiedReachability(
            signature,
            vulnerableBlocks,
            preFingerprint,
            postFingerprint);

        decimal? semanticSimilarity = null;
        if (options.IncludeSemanticAnalysis && preFingerprint is not null && postFingerprint is not null)
        {
            semanticSimilarity = ComputeSemanticSimilarity(preFingerprint, postFingerprint);
        }

        var verdict = DetermineVerdict(edgeDiff, reachabilityDiff, cfgDiff, preStatus, postStatus);

        return new FunctionDiffResult
        {
            FunctionName = functionName,
            PreStatus = preStatus,
            PostStatus = postStatus,
            CfgDiff = cfgDiff,
            BlockDiffs = blockDiffs,
            EdgeDiff = edgeDiff,
            ReachabilityDiff = reachabilityDiff,
            SemanticSimilarity = semanticSimilarity,
            Verdict = verdict
        };
    }

    /// <summary>
    /// Collects the source and target block IDs of every edge pattern that splits into exactly two
    /// parts on <see cref="EdgeSeparator"/> (parts are trimmed). Other patterns are ignored.
    /// </summary>
    /// <param name="edgePatterns">Edge patterns from the function signature.</param>
    /// <returns>An ordinal set of block IDs referenced by the patterns.</returns>
    private static HashSet<string> ParseVulnerableBlocks(ImmutableArray<string> edgePatterns)
    {
        var blocks = new HashSet<string>(StringComparer.Ordinal);
        foreach (var edge in edgePatterns)
        {
            var parts = edge.Split(EdgeSeparator, StringSplitOptions.TrimEntries);
            if (parts.Length == 2)
            {
                blocks.Add(parts[0]);
                blocks.Add(parts[1]);
            }
        }

        return blocks;
    }

    /// <summary>
    /// Summarises CFG hashes, block counts and edge counts for both sides. The edge count of a side
    /// is the total number of successors across its basic blocks.
    /// </summary>
    private static CfgDiffResult BuildCfgDiff(
        FunctionFingerprint pre,
        FunctionFingerprint post)
    {
        return new CfgDiffResult
        {
            PreCfgHash = pre.CfgHash,
            PostCfgHash = post.CfgHash,
            PreBlockCount = pre.BasicBlockHashes.Length,
            PostBlockCount = post.BasicBlockHashes.Length,
            PreEdgeCount = pre.BasicBlockHashes.Sum(b => b.Successors.Length),
            PostEdgeCount = post.BasicBlockHashes.Sum(b => b.Successors.Length)
        };
    }

    /// <summary>
    /// Produces one entry per block ID seen in either side, ordered ordinally by block ID.
    /// <c>HashChanged</c> is only set when the block exists on both sides with different opcode hashes.
    /// </summary>
    /// <param name="pre">Pre fingerprint, or <see langword="null"/> when absent.</param>
    /// <param name="post">Post fingerprint, or <see langword="null"/> when absent.</param>
    /// <param name="vulnerableBlocks">Block IDs referenced by the signature's edge patterns.</param>
    private static ImmutableArray<BlockDiffResult> BuildBlockDiffs(
        FunctionFingerprint? pre,
        FunctionFingerprint? post,
        HashSet<string> vulnerableBlocks)
    {
        if (pre is null && post is null)
        {
            return [];
        }

        // NOTE(review): ToDictionary throws on duplicate BlockIds; this assumes IDs are unique per function.
        var preBlocks = pre?.BasicBlockHashes.ToDictionary(
            b => b.BlockId,
            b => b.OpcodeHash,
            StringComparer.Ordinal) ?? [];

        var postBlocks = post?.BasicBlockHashes.ToDictionary(
            b => b.BlockId,
            b => b.OpcodeHash,
            StringComparer.Ordinal) ?? [];

        // Order the distinct IDs up front so results are emitted already sorted.
        var orderedBlockIds = preBlocks.Keys
            .Union(postBlocks.Keys, StringComparer.Ordinal)
            .OrderBy(id => id, StringComparer.Ordinal);

        var results = ImmutableArray.CreateBuilder<BlockDiffResult>();
        foreach (var blockId in orderedBlockIds)
        {
            var existsInPre = preBlocks.TryGetValue(blockId, out var preHash);
            var existsInPost = postBlocks.TryGetValue(blockId, out var postHash);

            results.Add(new BlockDiffResult
            {
                BlockId = blockId,
                ExistsInPre = existsInPre,
                ExistsInPost = existsInPost,
                IsVulnerablePath = vulnerableBlocks.Contains(blockId),
                HashChanged = existsInPre && existsInPost && !string.Equals(preHash, postHash, StringComparison.Ordinal),
                PreHash = preHash,
                PostHash = postHash
            });
        }

        return results.ToImmutable();
    }

    /// <summary>
    /// Builds <c>"block-&gt;successor"</c> strings for every successor of every basic block, in block
    /// order then successor order. Returns an empty array for a <see langword="null"/> fingerprint.
    /// </summary>
    private static ImmutableArray<string> ExtractEdges(FunctionFingerprint? fingerprint)
    {
        if (fingerprint is null)
        {
            return [];
        }

        return
        [
            .. fingerprint.BasicBlockHashes.SelectMany(
                block => block.Successors.Select(succ => $"{block.BlockId}{EdgeSeparator}{succ}"))
        ];
    }

    /// <summary>
    /// Simplified sink reachability: every signature sink is considered reachable on a side when that
    /// side contains at least one block referenced by the signature's edge patterns.
    /// Full reachability analysis requires ReachGraph integration.
    /// </summary>
    /// <param name="signature">Signature providing the sinks.</param>
    /// <param name="vulnerableBlocks">Block IDs referenced by the signature's edge patterns.</param>
    /// <param name="pre">Pre fingerprint, or <see langword="null"/> when absent.</param>
    /// <param name="post">Post fingerprint, or <see langword="null"/> when absent.</param>
    private static SinkReachabilityDiff ComputeSimplifiedReachability(
        FunctionSignature signature,
        HashSet<string> vulnerableBlocks,
        FunctionFingerprint? pre,
        FunctionFingerprint? post)
    {
        if (signature.Sinks.IsEmpty)
        {
            return SinkReachabilityDiff.Empty;
        }

        var preReachable = new List<string>();
        var postReachable = new List<string>();

        if (ContainsAnyBlock(pre, vulnerableBlocks))
        {
            preReachable.AddRange(signature.Sinks);
        }

        if (ContainsAnyBlock(post, vulnerableBlocks))
        {
            postReachable.AddRange(signature.Sinks);
        }

        return SinkReachabilityDiff.Compute(
            [.. preReachable],
            [.. postReachable]);
    }

    /// <summary>
    /// True when <paramref name="fingerprint"/> is non-null and has a basic block whose ID is in
    /// <paramref name="blockIds"/>.
    /// </summary>
    private static bool ContainsAnyBlock(FunctionFingerprint? fingerprint, HashSet<string> blockIds) =>
        fingerprint?.BasicBlockHashes.Any(b => blockIds.Contains(b.BlockId)) ?? false;

    /// <summary>
    /// Jaccard similarity of the distinct opcode hashes of both functions' basic blocks, in [0, 1].
    /// Returns 0 when either side has no basic blocks. Full semantic analysis would use embeddings.
    /// </summary>
    private static decimal ComputeSemanticSimilarity(
        FunctionFingerprint pre,
        FunctionFingerprint post)
    {
        if (pre.BasicBlockHashes.IsEmpty || post.BasicBlockHashes.IsEmpty)
        {
            return 0m;
        }

        var preHashes = pre.BasicBlockHashes.Select(b => b.OpcodeHash).ToHashSet(StringComparer.Ordinal);
        var postHashes = post.BasicBlockHashes.Select(b => b.OpcodeHash).ToHashSet(StringComparer.Ordinal);

        // |A ∩ B| by probing the other set; |A ∪ B| = |A| + |B| - |A ∩ B| (both sets use the same comparer).
        var intersection = preHashes.Count(h => postHashes.Contains(h));
        var union = preHashes.Count + postHashes.Count - intersection;

        if (union == 0)
        {
            return 0m;
        }

        return (decimal)intersection / union;
    }

    /// <summary>
    /// Maps the collected evidence to a verdict. Rules are evaluated in order and the first match wins:
    /// removed function, absent on both sides, all vulnerable edges removed, all sinks unreachable,
    /// partial edge/sink removal, significant CFG change, no change, otherwise inconclusive.
    /// </summary>
    /// <remarks>
    /// The first two rules are defensive: <see cref="Compare"/> returns before reaching this method
    /// when the post fingerprint is missing.
    /// </remarks>
    private static FunctionPatchVerdict DetermineVerdict(
        VulnerableEdgeDiff edgeDiff,
        SinkReachabilityDiff reachabilityDiff,
        CfgDiffResult? cfgDiff,
        FunctionStatus preStatus,
        FunctionStatus postStatus)
    {
        if (preStatus == FunctionStatus.Present && postStatus == FunctionStatus.Absent)
        {
            return FunctionPatchVerdict.FunctionRemoved;
        }

        if (preStatus == FunctionStatus.Absent && postStatus == FunctionStatus.Absent)
        {
            return FunctionPatchVerdict.Inconclusive;
        }

        if (edgeDiff.AllVulnerableEdgesRemoved)
        {
            return FunctionPatchVerdict.Fixed;
        }

        if (reachabilityDiff.AllSinksUnreachable)
        {
            return FunctionPatchVerdict.Fixed;
        }

        if (edgeDiff.SomeVulnerableEdgesRemoved || reachabilityDiff.SomeSinksUnreachable)
        {
            return FunctionPatchVerdict.PartialFix;
        }

        // CFG structure changed significantly
        if (cfgDiff?.StructureChanged == true &&
            Math.Abs(cfgDiff.BlockCountDelta) > SignificantBlockCountDeltaThreshold)
        {
            return FunctionPatchVerdict.PartialFix;
        }

        // No significant change detected
        if (edgeDiff.NoChange && cfgDiff?.StructureChanged != true)
        {
            return FunctionPatchVerdict.StillVulnerable;
        }

        return FunctionPatchVerdict.Inconclusive;
    }
}
|
|
|
|
/// <summary>
/// Compares vulnerable edges between binaries.
/// </summary>
/// <remarks>
/// An edge from the pre or post set counts as vulnerable when it matches a golden-set edge after
/// whitespace around the <c>"-&gt;"</c> separator is removed on both sides. For example, a golden
/// <c>"bb1 -&gt; bb2"</c> matches an extracted <c>"bb1-&gt;bb2"</c>. Matched edges are reported
/// with their original spelling, in input order. <c>default</c> arrays are treated as empty.
/// </remarks>
internal sealed class EdgeComparator : IEdgeComparator
{
    /// <summary>Separator between the source and target block IDs of an edge string.</summary>
    private const string EdgeSeparator = "->";

    /// <inheritdoc />
    public VulnerableEdgeDiff Compare(
        ImmutableArray<string> goldenSetEdges,
        ImmutableArray<string> preEdges,
        ImmutableArray<string> postEdges)
    {
        // Normalise the golden set once so lookups are insensitive to spacing around the arrow.
        var goldenSet = new HashSet<string>(StringComparer.Ordinal);
        if (!goldenSetEdges.IsDefault)
        {
            foreach (var edge in goldenSetEdges)
            {
                goldenSet.Add(NormalizeEdge(edge));
            }
        }

        var vulnerableInPre = SelectVulnerable(preEdges, goldenSet);
        var vulnerableInPost = SelectVulnerable(postEdges, goldenSet);

        return VulnerableEdgeDiff.Compute(vulnerableInPre, vulnerableInPost);
    }

    /// <summary>
    /// Returns the edges (original spelling, input order) whose normalised form is in
    /// <paramref name="goldenSet"/>; empty for a default/empty input or an empty golden set.
    /// </summary>
    private static ImmutableArray<string> SelectVulnerable(ImmutableArray<string> edges, HashSet<string> goldenSet)
    {
        if (edges.IsDefaultOrEmpty || goldenSet.Count == 0)
        {
            return ImmutableArray<string>.Empty;
        }

        return edges.Where(edge => goldenSet.Contains(NormalizeEdge(edge))).ToImmutableArray();
    }

    /// <summary>
    /// Rewrites <c>"a -&gt; b"</c> as <c>"a-&gt;b"</c>. Strings that do not split into exactly two
    /// parts on the separator (including null/empty) are returned unchanged.
    /// </summary>
    private static string NormalizeEdge(string edge)
    {
        if (string.IsNullOrEmpty(edge))
        {
            return edge;
        }

        var parts = edge.Split(EdgeSeparator, StringSplitOptions.TrimEntries);
        return parts.Length == 2 ? string.Concat(parts[0], EdgeSeparator, parts[1]) : edge;
    }
}
|