Files
git.stella-ops.org/src/BinaryIndex/__Libraries/StellaOps.BinaryIndex.Semantic/SemanticGraphExtractor.cs
StellaOps Bot 37e11918e0 save progress
2026-01-06 09:42:20 +02:00

516 lines
16 KiB
C#

// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.
using System.Collections.Immutable;
using System.Globalization;
using Microsoft.Extensions.Logging;
using StellaOps.BinaryIndex.Semantic.Internal;
namespace StellaOps.BinaryIndex.Semantic;
/// <summary>
/// Default implementation of semantic graph extraction from lifted IR and SSA form.
/// </summary>
/// <remarks>
/// Nodes are created in statement order with consecutive IDs starting at 0, and every
/// dependency edge produced here points from a lower node ID to a higher one: data edges
/// link an earlier definition to a later use, and control edges link a branch to nodes
/// that follow it.
/// </remarks>
public sealed class SemanticGraphExtractor : ISemanticGraphExtractor
{
    /// <summary>
    /// Number of following non-branch nodes the simplified control-dependency heuristic
    /// attaches to each branch node.
    /// </summary>
    private const int ControlDependencyWindow = 5;

    private readonly ILogger<SemanticGraphExtractor> _logger;
    private readonly GraphCanonicalizer _canonicalizer;

    /// <summary>
    /// Creates a new semantic graph extractor.
    /// </summary>
    /// <param name="logger">Logger instance.</param>
    /// <exception cref="ArgumentNullException"><paramref name="logger"/> is <c>null</c>.</exception>
    public SemanticGraphExtractor(ILogger<SemanticGraphExtractor> logger)
    {
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _canonicalizer = new GraphCanonicalizer();
    }

    /// <inheritdoc />
    public Task<KeySemanticsGraph> ExtractGraphAsync(
        LiftedFunction function,
        GraphExtractionOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(function);
        ct.ThrowIfCancellationRequested();
        options ??= GraphExtractionOptions.Default;

        _logger.LogDebug(
            "Extracting semantic graph from function {FunctionName} with {StatementCount} statements",
            function.Name,
            function.Statements.Length);

        var nodes = new List<SemanticNode>();
        var edges = new List<SemanticEdge>();

        // Operand key (register/temporary) -> ID of the node that most recently defined it.
        var defMap = new Dictionary<string, int>(StringComparer.Ordinal);
        var nodeId = 0;

        foreach (var stmt in function.Statements)
        {
            ct.ThrowIfCancellationRequested();

            if (options.MaxNodes > 0 && nodeId >= options.MaxNodes)
            {
                _logger.LogWarning(
                    "Truncating graph at {MaxNodes} nodes for function {FunctionName}",
                    options.MaxNodes,
                    function.Name);
                break;
            }

            // Skip NOPs if configured.
            if (!options.IncludeNops && stmt.Kind == IrStatementKind.Nop)
            {
                continue;
            }

            var node = CreateSemanticNode(ref nodeId, stmt, options);
            if (node is null)
            {
                continue;
            }

            nodes.Add(node);

            // NOTE(review): data-dependency edges are only added when control- or memory-
            // dependency extraction is enabled, whereas the SSA path adds them unconditionally.
            // Confirm this gate is intended for GraphExtractionOptions.
            if (options.ExtractControlDependencies || options.ExtractMemoryDependencies)
            {
                AddDataDependencyEdges(stmt, node.Id, defMap, edges);
            }

            // Record the definition only after wiring uses, so a statement such as
            // `x = x + 1` depends on the previous definition of x rather than on itself.
            if (stmt.Destination is not null)
            {
                var defKey = GetOperandKey(stmt.Destination);
                if (!string.IsNullOrEmpty(defKey))
                {
                    defMap[defKey] = node.Id;
                }
            }
        }

        if (options.ExtractControlDependencies)
        {
            AddControlDependencyEdges(nodes, edges);
        }

        var properties = ComputeProperties(nodes, edges, function.Cfg);

        var graph = new KeySemanticsGraph(
            function.Name,
            [.. nodes],
            [.. edges],
            properties);

        _logger.LogDebug(
            "Extracted graph with {NodeCount} nodes and {EdgeCount} edges for function {FunctionName}",
            graph.Properties.NodeCount,
            graph.Properties.EdgeCount,
            function.Name);

        return Task.FromResult(graph);
    }

    /// <inheritdoc />
    public Task<KeySemanticsGraph> ExtractGraphFromSsaAsync(
        SsaFunction function,
        GraphExtractionOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(function);
        ct.ThrowIfCancellationRequested();
        options ??= GraphExtractionOptions.Default;

        _logger.LogDebug(
            "Extracting semantic graph from SSA function {FunctionName}",
            function.Name);

        var nodes = new List<SemanticNode>();
        var edges = new List<SemanticEdge>();

        // SSA variable key ("base_version") -> ID of the node that defines it.
        var defMap = new Dictionary<string, int>(StringComparer.Ordinal);
        var nodeId = 0;

        foreach (var stmt in function.Statements)
        {
            ct.ThrowIfCancellationRequested();

            if (options.MaxNodes > 0 && nodeId >= options.MaxNodes)
            {
                _logger.LogWarning(
                    "Truncating graph at {MaxNodes} nodes for function {FunctionName}",
                    options.MaxNodes,
                    function.Name);
                break;
            }

            var node = CreateSemanticNodeFromSsa(ref nodeId, stmt, options);
            if (node is null)
            {
                continue;
            }

            nodes.Add(node);

            // Def-use links come straight from SSA variable identities.
            // NOTE(review): this is a single forward pass, so a use whose definition appears
            // later in statement order (e.g. a phi operand flowing along a loop back edge)
            // gets no data-dependency edge.
            foreach (var source in stmt.Sources)
            {
                if (defMap.TryGetValue(GetSsaVariableKey(source), out var defNodeId))
                {
                    edges.Add(new SemanticEdge(defNodeId, node.Id, SemanticEdgeType.DataDependency));
                }
            }

            if (stmt.Destination is not null)
            {
                defMap[GetSsaVariableKey(stmt.Destination)] = node.Id;
            }
        }

        // Build a minimal CFG from the SSA blocks so graph properties can be computed.
        var cfg = BuildCfgFromSsaBlocks(function.BasicBlocks);
        var properties = ComputeProperties(nodes, edges, cfg);

        var graph = new KeySemanticsGraph(
            function.Name,
            [.. nodes],
            [.. edges],
            properties);

        _logger.LogDebug(
            "Extracted graph with {NodeCount} nodes and {EdgeCount} edges for function {FunctionName}",
            graph.Properties.NodeCount,
            graph.Properties.EdgeCount,
            function.Name);

        return Task.FromResult(graph);
    }

    /// <inheritdoc />
    public Task<CanonicalGraph> CanonicalizeAsync(
        KeySemanticsGraph graph,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(graph);
        ct.ThrowIfCancellationRequested();

        var result = _canonicalizer.Canonicalize(graph);
        return Task.FromResult(result);
    }

    /// <summary>
    /// Creates a node for an IR statement, or returns <c>null</c> when the statement kind
    /// maps to <see cref="SemanticNodeType.Unknown"/>.
    /// </summary>
    /// <param name="nodeId">Next node ID; incremented only when a node is created.</param>
    /// <param name="stmt">The IR statement.</param>
    /// <param name="options">Extraction options (controls operation normalization).</param>
    private static SemanticNode? CreateSemanticNode(
        ref int nodeId,
        IrStatement stmt,
        GraphExtractionOptions options)
    {
        var nodeType = MapStatementKindToNodeType(stmt.Kind);
        if (nodeType == SemanticNodeType.Unknown)
        {
            return null;
        }

        var operation = options.NormalizeOperations
            ? NormalizeOperation(stmt.Operation)
            : stmt.Operation;

        // OfType<string>() drops nulls and narrows the element type, so no null-forgiving is needed.
        var operands = stmt.Sources
            .Select(GetNormalizedOperandName)
            .OfType<string>()
            .Where(name => name.Length > 0)
            .ToImmutableArray();

        return new SemanticNode(nodeId++, nodeType, operation, operands);
    }

    /// <summary>
    /// Creates a node for an SSA statement, or returns <c>null</c> when the statement kind
    /// maps to <see cref="SemanticNodeType.Unknown"/>. Operands are rendered with
    /// <see cref="GetSsaVariableKey"/>.
    /// </summary>
    /// <param name="nodeId">Next node ID; incremented only when a node is created.</param>
    /// <param name="stmt">The SSA statement.</param>
    /// <param name="options">Extraction options (controls operation normalization).</param>
    private static SemanticNode? CreateSemanticNodeFromSsa(
        ref int nodeId,
        SsaStatement stmt,
        GraphExtractionOptions options)
    {
        var nodeType = MapStatementKindToNodeType(stmt.Kind);
        if (nodeType == SemanticNodeType.Unknown)
        {
            return null;
        }

        var operation = options.NormalizeOperations
            ? NormalizeOperation(stmt.Operation)
            : stmt.Operation;

        var operands = stmt.Sources
            .Select(GetSsaVariableKey)
            .ToImmutableArray();

        return new SemanticNode(nodeId++, nodeType, operation, operands);
    }

    /// <summary>
    /// Maps an IR statement kind to a semantic node type; unmapped kinds (including
    /// <c>Nop</c>) yield <see cref="SemanticNodeType.Unknown"/>.
    /// </summary>
    private static SemanticNodeType MapStatementKindToNodeType(IrStatementKind kind)
    {
        return kind switch
        {
            IrStatementKind.Assign => SemanticNodeType.Compute,
            IrStatementKind.BinaryOp => SemanticNodeType.Compute,
            IrStatementKind.UnaryOp => SemanticNodeType.Compute,
            IrStatementKind.Compare => SemanticNodeType.Compute,
            IrStatementKind.Load => SemanticNodeType.Load,
            IrStatementKind.Store => SemanticNodeType.Store,
            IrStatementKind.Jump => SemanticNodeType.Branch,
            IrStatementKind.ConditionalJump => SemanticNodeType.Branch,
            IrStatementKind.Call => SemanticNodeType.Call,
            IrStatementKind.Return => SemanticNodeType.Return,
            IrStatementKind.Phi => SemanticNodeType.Phi,
            IrStatementKind.Cast => SemanticNodeType.Cast,
            IrStatementKind.Extend => SemanticNodeType.Cast,
            _ => SemanticNodeType.Unknown
        };
    }

    /// <summary>
    /// Maps architecture/IR-specific operation mnemonics onto a canonical upper-case name;
    /// unrecognized mnemonics are returned upper-cased (invariant culture).
    /// </summary>
    private static string NormalizeOperation(string operation)
    {
        var upper = operation.ToUpperInvariant();

        return upper switch
        {
            "ADD" or "IADD" or "FADD" => "ADD",
            "SUB" or "ISUB" or "FSUB" => "SUB",
            "MUL" or "IMUL" or "FMUL" => "MUL",
            "DIV" or "IDIV" or "FDIV" or "UDIV" => "DIV",
            "MOD" or "REM" or "UREM" or "SREM" => "MOD",
            "AND" or "IAND" => "AND",
            "OR" or "IOR" => "OR",
            "XOR" or "IXOR" => "XOR",
            "NOT" or "INOT" => "NOT",
            "NEG" or "INEG" or "FNEG" => "NEG",
            "SHL" or "ISHL" => "SHL",
            "SHR" or "ISHR" or "LSHR" or "ASHR" => "SHR",
            "CMP" or "ICMP" or "FCMP" => "CMP",
            "MOV" or "COPY" or "ASSIGN" => "MOV",
            "LOAD" or "LDR" or "LD" => "LOAD",
            "STORE" or "STR" or "ST" => "STORE",
            "CALL" or "INVOKE" => "CALL",
            "RET" or "RETURN" => "RET",
            "JMP" or "BR" or "GOTO" => "JMP",
            "JCC" or "BRC" or "CONDJMP" => "JCC",
            "ZEXT" or "SEXT" or "TRUNC" => "EXT",
            _ => upper
        };
    }

    /// <summary>
    /// Renders an IR operand as node operand text: <c>R:name</c>, <c>T:name</c>,
    /// <c>I:value</c>, <c>M</c> for memory, the raw name for labels, or <c>null</c> for
    /// other kinds.
    /// </summary>
    private static string? GetNormalizedOperandName(IrOperand operand)
    {
        return operand.Kind switch
        {
            IrOperandKind.Register => $"R:{operand.Name}",
            IrOperandKind.Temporary => $"T:{operand.Name}",
            // Invariant culture keeps the text (e.g. the negative sign) identical on every
            // machine, so graphs built from the same code compare equal.
            IrOperandKind.Immediate => string.Create(CultureInfo.InvariantCulture, $"I:{operand.Value}"),
            IrOperandKind.Memory => "M",
            IrOperandKind.Label => operand.Name, // Call targets/labels keep their name for API extraction
            _ => null
        };
    }

    /// <summary>
    /// Returns the def/use tracking key for a register or temporary operand, or an empty
    /// string for operand kinds that are not tracked.
    /// </summary>
    private static string GetOperandKey(IrOperand operand)
    {
        return operand.Kind switch
        {
            IrOperandKind.Register => $"R:{operand.Name}",
            IrOperandKind.Temporary => $"T:{operand.Name}",
            _ => string.Empty
        };
    }

    /// <summary>
    /// Returns the key <c>{BaseName}_{Version}</c> for an SSA variable, formatted with the
    /// invariant culture.
    /// </summary>
    private static string GetSsaVariableKey(SsaVariable variable)
    {
        return string.Create(
            CultureInfo.InvariantCulture,
            $"{variable.BaseName}_{variable.Version}");
    }

    /// <summary>
    /// Adds a data-dependency edge from the current definition of each tracked source
    /// operand of <paramref name="stmt"/> to <paramref name="targetNodeId"/>.
    /// </summary>
    private static void AddDataDependencyEdges(
        IrStatement stmt,
        int targetNodeId,
        Dictionary<string, int> defMap,
        List<SemanticEdge> edges)
    {
        foreach (var source in stmt.Sources)
        {
            var useKey = GetOperandKey(source);
            if (!string.IsNullOrEmpty(useKey) && defMap.TryGetValue(useKey, out var defNodeId))
            {
                edges.Add(new SemanticEdge(
                    defNodeId,
                    targetNodeId,
                    SemanticEdgeType.DataDependency));
            }
        }
    }

    /// <summary>
    /// Simplified control-dependency heuristic: each branch node gets a control edge to
    /// the next <see cref="ControlDependencyWindow"/> non-branch nodes that follow it.
    /// Full control-dependence analysis (post-dominance frontiers) is not performed.
    /// </summary>
    /// <remarks>
    /// Relies on <paramref name="nodes"/> being in increasing ID order, which holds because
    /// the extractors append nodes as IDs are handed out; "following" therefore means
    /// "later in the list".
    /// </remarks>
    private static void AddControlDependencyEdges(
        List<SemanticNode> nodes,
        List<SemanticEdge> edges)
    {
        for (var i = 0; i < nodes.Count; i++)
        {
            var branch = nodes[i];
            if (branch.Type != SemanticNodeType.Branch)
            {
                continue;
            }

            var linked = 0;
            for (var j = i + 1; j < nodes.Count && linked < ControlDependencyWindow; j++)
            {
                var dependent = nodes[j];
                if (dependent.Type == SemanticNodeType.Branch)
                {
                    continue;
                }

                edges.Add(new SemanticEdge(
                    branch.Id,
                    dependent.Id,
                    SemanticEdgeType.ControlDependency));
                linked++;
            }
        }
    }

    /// <summary>
    /// Builds a minimal CFG from SSA blocks: the first block is the entry, blocks without
    /// successors are exits, and every successor link becomes a
    /// <see cref="CfgEdgeKind.FallThrough"/> edge (edge kinds are not recovered).
    /// </summary>
    private static ControlFlowGraph BuildCfgFromSsaBlocks(ImmutableArray<SsaBasicBlock> blocks)
    {
        if (blocks.IsEmpty)
        {
            return new ControlFlowGraph(0, [], []);
        }

        var edges = new List<CfgEdge>();
        foreach (var block in blocks)
        {
            foreach (var succ in block.Successors)
            {
                edges.Add(new CfgEdge(block.Id, succ, CfgEdgeKind.FallThrough));
            }
        }

        var exitBlocks = blocks
            .Where(b => b.Successors.IsEmpty)
            .Select(b => b.Id)
            .ToImmutableArray();

        return new ControlFlowGraph(
            blocks[0].Id,
            exitBlocks,
            [.. edges]);
    }

    /// <summary>
    /// Computes summary properties of the semantic graph and its CFG.
    /// </summary>
    private static GraphProperties ComputeProperties(
        List<SemanticNode> nodes,
        List<SemanticEdge> edges,
        ControlFlowGraph cfg)
    {
        var nodeTypeCounts = nodes
            .GroupBy(n => n.Type)
            .ToImmutableDictionary(g => g.Key, g => g.Count());

        var edgeTypeCounts = edges
            .GroupBy(e => e.Type)
            .ToImmutableDictionary(g => g.Key, g => g.Count());

        var cyclomaticComplexity = ComputeCyclomaticComplexity(cfg);
        var branchCount = nodes.Count(n => n.Type == SemanticNodeType.Branch);
        var maxDepth = ComputeMaxDepth(nodes, edges);
        var loopCount = CountBackEdges(cfg);

        return new GraphProperties(
            nodes.Count,
            edges.Count,
            cyclomaticComplexity,
            maxDepth,
            nodeTypeCounts,
            edgeTypeCounts,
            loopCount,
            branchCount);
    }

    /// <summary>
    /// McCabe cyclomatic complexity M = E - N + 2P with P = 1 (a single function),
    /// clamped to at least 1.
    /// </summary>
    /// <remarks>
    /// N is the number of distinct block IDs referenced by CFG edges or listed as exits
    /// (at least 1), since only edges and exits are consulted here.
    /// </remarks>
    private static int ComputeCyclomaticComplexity(ControlFlowGraph cfg)
    {
        var blockCount = cfg.Edges
            .SelectMany(e => new[] { e.SourceBlockId, e.TargetBlockId })
            .Concat(cfg.ExitBlockIds)
            .Distinct()
            .Count();
        blockCount = Math.Max(1, blockCount);

        return Math.Max(1, cfg.Edges.Length - blockCount + 2);
    }

    /// <summary>
    /// Estimates graph depth: a breadth-first search starts at every node with no incoming
    /// edge (or at the first node if every node has one), roots at depth 1, and returns the
    /// largest depth reached.
    /// </summary>
    /// <remarks>
    /// Each node is visited once, at its shortest distance from any root, so the result is
    /// the number of BFS layers rather than the length of the longest dependency chain.
    /// </remarks>
    private static int ComputeMaxDepth(List<SemanticNode> nodes, List<SemanticEdge> edges)
    {
        if (nodes.Count == 0)
        {
            return 0;
        }

        var outEdges = new Dictionary<int, List<int>>();
        foreach (var edge in edges)
        {
            if (!outEdges.TryGetValue(edge.SourceId, out var list))
            {
                list = [];
                outEdges[edge.SourceId] = list;
            }

            list.Add(edge.TargetId);
        }

        var hasIncoming = new HashSet<int>(edges.Select(e => e.TargetId));
        var roots = nodes.Where(n => !hasIncoming.Contains(n.Id)).Select(n => n.Id).ToList();
        if (roots.Count == 0)
        {
            roots.Add(nodes[0].Id);
        }

        var maxDepth = 0;
        var visited = new HashSet<int>();
        var queue = new Queue<(int NodeId, int Depth)>();
        foreach (var root in roots)
        {
            queue.Enqueue((root, 1));
        }

        while (queue.Count > 0)
        {
            var (current, depth) = queue.Dequeue();
            if (!visited.Add(current))
            {
                continue;
            }

            maxDepth = Math.Max(maxDepth, depth);

            if (outEdges.TryGetValue(current, out var neighbors))
            {
                foreach (var neighbor in neighbors)
                {
                    if (!visited.Contains(neighbor))
                    {
                        queue.Enqueue((neighbor, depth + 1));
                    }
                }
            }
        }

        return maxDepth;
    }

    /// <summary>
    /// Estimates the loop count as the number of CFG back edges.
    /// </summary>
    /// <remarks>
    /// A true back edge targets a block that dominates its source. Dominators are not
    /// computed here, so an edge is treated as a back edge when its target ID is not greater
    /// than its source ID. This includes self-loops, since a block always dominates itself.
    /// </remarks>
    private static int CountBackEdges(ControlFlowGraph cfg)
    {
        return cfg.Edges.Count(e => e.TargetBlockId <= e.SourceBlockId);
    }
}