// Copyright (c) StellaOps. All rights reserved.
// Licensed under AGPL-3.0-or-later. See LICENSE in the project root.

using System.Collections.Immutable;
using System.Globalization;
using Microsoft.Extensions.Logging;
using StellaOps.BinaryIndex.Semantic.Internal;

namespace StellaOps.BinaryIndex.Semantic;

/// <summary>
/// Default implementation of semantic graph extraction from lifted IR.
/// </summary>
public sealed class SemanticGraphExtractor : ISemanticGraphExtractor
{
    /// <summary>
    /// Maximum number of subsequent non-branch nodes that the simplified
    /// control-dependency heuristic links to each branch node.
    /// </summary>
    private const int MaxControlDependentsPerBranch = 5;

    private readonly ILogger<SemanticGraphExtractor> _logger;
    private readonly GraphCanonicalizer _canonicalizer;

    /// <summary>
    /// Creates a new semantic graph extractor.
    /// </summary>
    /// <param name="logger">Logger instance.</param>
    /// <exception cref="ArgumentNullException"><paramref name="logger"/> is null.</exception>
    public SemanticGraphExtractor(ILogger<SemanticGraphExtractor> logger)
    {
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _canonicalizer = new GraphCanonicalizer();
    }

    /// <inheritdoc />
    public Task<KeySemanticsGraph> ExtractGraphAsync(
        LiftedFunction function,
        GraphExtractionOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(function);
        ct.ThrowIfCancellationRequested();

        options ??= GraphExtractionOptions.Default;

        _logger.LogDebug(
            "Extracting semantic graph from function {FunctionName} with {StatementCount} statements",
            function.Name,
            function.Statements.Length);

        var nodes = new List<SemanticNode>();
        var edges = new List<SemanticEdge>();
        var defMap = new Dictionary<string, int>(); // Variable/register -> defining node ID
        var nodeId = 0;

        foreach (var stmt in function.Statements)
        {
            ct.ThrowIfCancellationRequested();

            // A MaxNodes of zero (or less) disables the limit.
            if (options.MaxNodes > 0 && nodeId >= options.MaxNodes)
            {
                _logger.LogWarning(
                    "Truncating graph at {MaxNodes} nodes for function {FunctionName}",
                    options.MaxNodes,
                    function.Name);
                break;
            }

            // Skip NOPs if configured
            if (!options.IncludeNops && stmt.Kind == IrStatementKind.Nop)
            {
                continue;
            }

            // Statements whose kind has no semantic node type produce no node.
            var node = CreateSemanticNode(ref nodeId, stmt, options);
            if (node is null)
            {
                continue;
            }

            nodes.Add(node);

            // NOTE(review): data-dependency edges are emitted when either the control or the
            // memory dependency flag is set; the SSA path emits them unconditionally.
            // Confirm this gating is intended.
            if (options.ExtractControlDependencies || options.ExtractMemoryDependencies)
            {
                AddDataDependencyEdges(stmt, node.Id, defMap, edges);
            }

            // Track definitions: edges are added before this update, so a statement that
            // reads and writes the same register links to the previous definition.
            if (stmt.Destination is not null)
            {
                var defKey = GetOperandKey(stmt.Destination);
                if (!string.IsNullOrEmpty(defKey))
                {
                    defMap[defKey] = node.Id;
                }
            }
        }

        // Add (heuristic) control dependency edges
        if (options.ExtractControlDependencies)
        {
            AddControlDependencyEdges(nodes, edges);
        }

        // Compute graph properties
        var properties = ComputeProperties(nodes, edges, function.Cfg);

        var graph = new KeySemanticsGraph(
            function.Name,
            [.. nodes],
            [.. edges],
            properties);

        _logger.LogDebug(
            "Extracted graph with {NodeCount} nodes and {EdgeCount} edges for function {FunctionName}",
            graph.Properties.NodeCount,
            graph.Properties.EdgeCount,
            function.Name);

        return Task.FromResult(graph);
    }

    /// <inheritdoc />
    public Task<KeySemanticsGraph> ExtractGraphFromSsaAsync(
        SsaFunction function,
        GraphExtractionOptions? options = null,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(function);
        ct.ThrowIfCancellationRequested();

        options ??= GraphExtractionOptions.Default;

        _logger.LogDebug(
            "Extracting semantic graph from SSA function {FunctionName}",
            function.Name);

        var nodes = new List<SemanticNode>();
        var edges = new List<SemanticEdge>();
        var defMap = new Dictionary<string, int>(); // SSA variable key -> defining node ID
        var nodeId = 0;

        foreach (var stmt in function.Statements)
        {
            ct.ThrowIfCancellationRequested();

            if (options.MaxNodes > 0 && nodeId >= options.MaxNodes)
            {
                // Same diagnostic as the non-SSA path so truncation is never silent.
                _logger.LogWarning(
                    "Truncating graph at {MaxNodes} nodes for function {FunctionName}",
                    options.MaxNodes,
                    function.Name);
                break;
            }

            var node = CreateSemanticNodeFromSsa(ref nodeId, stmt, options);
            if (node is null)
            {
                continue;
            }

            nodes.Add(node);

            // Each versioned source variable links back to the node that defined it, if any.
            foreach (var source in stmt.Sources)
            {
                var useKey = GetSsaVariableKey(source);
                if (defMap.TryGetValue(useKey, out var defNodeId))
                {
                    edges.Add(new SemanticEdge(defNodeId, node.Id, SemanticEdgeType.DataDependency));
                }
            }

            // Track definition
            if (stmt.Destination is not null)
            {
                var defKey = GetSsaVariableKey(stmt.Destination);
                defMap[defKey] = node.Id;
            }
        }

        // Build a minimal CFG from SSA blocks for properties
        var cfg = BuildCfgFromSsaBlocks(function);
        var properties = ComputeProperties(nodes, edges, cfg);

        var graph = new KeySemanticsGraph(
            function.Name,
            [.. nodes],
            [.. edges],
            properties);

        _logger.LogDebug(
            "Extracted graph with {NodeCount} nodes and {EdgeCount} edges for function {FunctionName}",
            graph.Properties.NodeCount,
            graph.Properties.EdgeCount,
            function.Name);

        return Task.FromResult(graph);
    }

    /// <inheritdoc />
    public Task<CanonicalGraph> CanonicalizeAsync(
        KeySemanticsGraph graph,
        CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(graph);
        ct.ThrowIfCancellationRequested();

        var result = _canonicalizer.Canonicalize(graph);
        return Task.FromResult(result);
    }

    /// <summary>
    /// Builds a node for an IR statement, or returns null when the statement kind maps to
    /// <see cref="SemanticNodeType.Unknown"/>. <paramref name="nodeId"/> is consumed
    /// (incremented) only when a node is actually created.
    /// </summary>
    private static SemanticNode? CreateSemanticNode(
        ref int nodeId,
        IrStatement stmt,
        GraphExtractionOptions options)
    {
        var nodeType = MapStatementKindToNodeType(stmt.Kind);
        if (nodeType == SemanticNodeType.Unknown)
        {
            return null;
        }

        var operation = options.NormalizeOperations
            ? NormalizeOperation(stmt.Operation)
            : stmt.Operation;

        // OfType<string>() drops null labels and narrows the element type, so no
        // null-forgiving operator is needed; empty labels are dropped as well.
        var operands = stmt.Sources
            .Select(GetNormalizedOperandName)
            .OfType<string>()
            .Where(name => name.Length > 0)
            .ToImmutableArray();

        return new SemanticNode(nodeId++, nodeType, operation, operands);
    }

    /// <summary>
    /// SSA counterpart of <see cref="CreateSemanticNode"/>: operands are the versioned
    /// variable keys of the statement's sources.
    /// </summary>
    private static SemanticNode? CreateSemanticNodeFromSsa(
        ref int nodeId,
        SsaStatement stmt,
        GraphExtractionOptions options)
    {
        var nodeType = MapStatementKindToNodeType(stmt.Kind);
        if (nodeType == SemanticNodeType.Unknown)
        {
            return null;
        }

        var operation = options.NormalizeOperations
            ? NormalizeOperation(stmt.Operation)
            : stmt.Operation;

        // Reuse the def-use key format so operand labels and defMap keys cannot drift apart.
        var operands = stmt.Sources
            .Select(GetSsaVariableKey)
            .ToImmutableArray();

        return new SemanticNode(nodeId++, nodeType, operation, operands);
    }

    /// <summary>
    /// Maps an IR statement kind to its semantic node type; kinds not listed here map to
    /// <see cref="SemanticNodeType.Unknown"/>.
    /// </summary>
    private static SemanticNodeType MapStatementKindToNodeType(IrStatementKind kind)
    {
        return kind switch
        {
            IrStatementKind.Assign => SemanticNodeType.Compute,
            IrStatementKind.BinaryOp => SemanticNodeType.Compute,
            IrStatementKind.UnaryOp => SemanticNodeType.Compute,
            IrStatementKind.Compare => SemanticNodeType.Compute,
            IrStatementKind.Load => SemanticNodeType.Load,
            IrStatementKind.Store => SemanticNodeType.Store,
            IrStatementKind.Jump => SemanticNodeType.Branch,
            IrStatementKind.ConditionalJump => SemanticNodeType.Branch,
            IrStatementKind.Call => SemanticNodeType.Call,
            IrStatementKind.Return => SemanticNodeType.Return,
            IrStatementKind.Phi => SemanticNodeType.Phi,
            IrStatementKind.Cast => SemanticNodeType.Cast,
            IrStatementKind.Extend => SemanticNodeType.Cast,
            _ => SemanticNodeType.Unknown
        };
    }

    /// <summary>
    /// Normalizes an operation mnemonic to a canonical upper-case form, collapsing the
    /// listed aliases; unrecognised mnemonics are returned upper-cased.
    /// </summary>
    private static string NormalizeOperation(string operation)
    {
        // Upper-case once; the result is both the switch subject and the fallback value.
        var upper = operation.ToUpperInvariant();

        return upper switch
        {
            "ADD" or "IADD" or "FADD" => "ADD",
            "SUB" or "ISUB" or "FSUB" => "SUB",
            "MUL" or "IMUL" or "FMUL" => "MUL",
            "DIV" or "IDIV" or "FDIV" or "UDIV" => "DIV",
            "MOD" or "REM" or "UREM" or "SREM" => "MOD",
            "AND" or "IAND" => "AND",
            "OR" or "IOR" => "OR",
            "XOR" or "IXOR" => "XOR",
            "NOT" or "INOT" => "NOT",
            "NEG" or "INEG" or "FNEG" => "NEG",
            "SHL" or "ISHL" => "SHL",
            "SHR" or "ISHR" or "LSHR" or "ASHR" => "SHR",
            "CMP" or "ICMP" or "FCMP" => "CMP",
            "MOV" or "COPY" or "ASSIGN" => "MOV",
            "LOAD" or "LDR" or "LD" => "LOAD",
            "STORE" or "STR" or "ST" => "STORE",
            "CALL" or "INVOKE" => "CALL",
            "RET" or "RETURN" => "RET",
            "JMP" or "BR" or "GOTO" => "JMP",
            "JCC" or "BRC" or "CONDJMP" => "JCC",
            "ZEXT" or "SEXT" or "TRUNC" => "EXT",
            _ => upper
        };
    }

    /// <summary>
    /// Produces the operand label stored on a node, or null for operand kinds that are
    /// not represented.
    /// </summary>
    private static string? GetNormalizedOperandName(IrOperand operand)
    {
        return operand.Kind switch
        {
            IrOperandKind.Register => $"R:{operand.Name}",
            IrOperandKind.Temporary => $"T:{operand.Name}",
            // Invariant culture keeps the label identical regardless of the host's culture.
            IrOperandKind.Immediate => string.Create(CultureInfo.InvariantCulture, $"I:{operand.Value}"),
            IrOperandKind.Memory => "M",
            IrOperandKind.Label => operand.Name, // Call targets/labels keep their name for API extraction
            _ => null
        };
    }

    /// <summary>
    /// Returns the def-use tracking key for a register or temporary operand, or an empty
    /// string for operand kinds that are not tracked.
    /// </summary>
    private static string GetOperandKey(IrOperand operand)
    {
        return operand.Kind switch
        {
            IrOperandKind.Register => $"R:{operand.Name}",
            IrOperandKind.Temporary => $"T:{operand.Name}",
            _ => string.Empty
        };
    }

    /// <summary>
    /// Returns the culture-invariant key <c>BaseName_Version</c> for an SSA variable.
    /// </summary>
    private static string GetSsaVariableKey(SsaVariable variable)
    {
        return string.Create(
            CultureInfo.InvariantCulture,
            $"{variable.BaseName}_{variable.Version}");
    }

    /// <summary>
    /// Adds one data-dependency edge per source operand of <paramref name="stmt"/> whose
    /// key currently has a defining node in <paramref name="defMap"/>.
    /// </summary>
    private static void AddDataDependencyEdges(
        IrStatement stmt,
        int targetNodeId,
        Dictionary<string, int> defMap,
        List<SemanticEdge> edges)
    {
        foreach (var source in stmt.Sources)
        {
            var useKey = GetOperandKey(source);
            if (!string.IsNullOrEmpty(useKey) && defMap.TryGetValue(useKey, out var defNodeId))
            {
                edges.Add(new SemanticEdge(
                    defNodeId,
                    targetNodeId,
                    SemanticEdgeType.DataDependency));
            }
        }
    }

    /// <summary>
    /// Adds simplified control-dependency edges: each branch node is linked to the first
    /// <see cref="MaxControlDependentsPerBranch"/> non-branch nodes with a greater ID.
    /// This is a heuristic and does not consult the control-flow graph; full control
    /// dependence analysis is more complex.
    /// </summary>
    private static void AddControlDependencyEdges(
        List<SemanticNode> nodes,
        List<SemanticEdge> edges)
    {
        // Find branch nodes
        var branchNodes = nodes
            .Where(n => n.Type == SemanticNodeType.Branch)
            .ToList();

        foreach (var branch in branchNodes)
        {
            // Simplified: just the next few non-branch nodes after the branch.
            var dependentNodes = nodes
                .Where(n => n.Id > branch.Id && n.Type != SemanticNodeType.Branch)
                .Take(MaxControlDependentsPerBranch);

            foreach (var dependent in dependentNodes)
            {
                edges.Add(new SemanticEdge(
                    branch.Id,
                    dependent.Id,
                    SemanticEdgeType.ControlDependency));
            }
        }
    }

    /// <summary>
    /// Builds a minimal control-flow graph from the SSA function's basic blocks: the first
    /// block is the entry, blocks without successors are exits, and every successor link
    /// becomes a <see cref="CfgEdgeKind.FallThrough"/> edge.
    /// </summary>
    private static ControlFlowGraph BuildCfgFromSsaBlocks(SsaFunction function)
    {
        var blocks = function.BasicBlocks;
        if (blocks.IsEmpty)
        {
            return new ControlFlowGraph(0, [], []);
        }

        var edges = new List<CfgEdge>();
        foreach (var block in blocks)
        {
            foreach (var succ in block.Successors)
            {
                edges.Add(new CfgEdge(block.Id, succ, CfgEdgeKind.FallThrough));
            }
        }

        var exitBlocks = blocks
            .Where(b => b.Successors.IsEmpty)
            .Select(b => b.Id)
            .ToImmutableArray();

        return new ControlFlowGraph(
            blocks[0].Id,
            exitBlocks,
            [.. edges]);
    }

    /// <summary>
    /// Computes summary properties (counts per type, complexity, depth, loop and branch
    /// counts) for the extracted graph.
    /// </summary>
    private static GraphProperties ComputeProperties(
        List<SemanticNode> nodes,
        List<SemanticEdge> edges,
        ControlFlowGraph cfg)
    {
        var nodeTypeCounts = nodes
            .GroupBy(n => n.Type)
            .ToImmutableDictionary(g => g.Key, g => g.Count());

        var edgeTypeCounts = edges
            .GroupBy(e => e.Type)
            .ToImmutableDictionary(g => g.Key, g => g.Count());

        // Cyclomatic complexity: E - N + 2P (simplified for single function), floored at 1.
        // NOTE(review): N here is the number of exit blocks, not the number of blocks in
        // the CFG as the classic formula requires - confirm before relying on this value.
        var cyclomaticComplexity = cfg.Edges.Length - cfg.ExitBlockIds.Length + 2;
        cyclomaticComplexity = Math.Max(1, cyclomaticComplexity);

        // Count branches
        var branchCount = nodes.Count(n => n.Type == SemanticNodeType.Branch);

        // Estimate max depth (simplified)
        var maxDepth = ComputeMaxDepth(nodes, edges);

        // Estimate loop count from back edges
        var loopCount = CountBackEdges(cfg);

        return new GraphProperties(
            nodes.Count,
            edges.Count,
            cyclomaticComplexity,
            maxDepth,
            nodeTypeCounts,
            edgeTypeCounts,
            loopCount,
            branchCount);
    }

    /// <summary>
    /// Estimates graph depth with a breadth-first traversal from all nodes without
    /// incoming edges (or from the first node when every node has one). The result is the
    /// largest breadth-first level reached, i.e. based on shortest distances from a root,
    /// not the longest path. Returns 0 for an empty graph.
    /// </summary>
    private static int ComputeMaxDepth(List<SemanticNode> nodes, List<SemanticEdge> edges)
    {
        if (nodes.Count == 0)
        {
            return 0;
        }

        // Build the adjacency list and the set of edge targets in a single pass.
        var outEdges = new Dictionary<int, List<int>>();
        var hasIncoming = new HashSet<int>();
        foreach (var edge in edges)
        {
            if (!outEdges.TryGetValue(edge.SourceId, out var targets))
            {
                targets = [];
                outEdges[edge.SourceId] = targets;
            }

            targets.Add(edge.TargetId);
            hasIncoming.Add(edge.TargetId);
        }

        // Nodes are marked visited when enqueued, so each node enters the queue at most
        // once. Because the queue is FIFO with non-decreasing depths, the first enqueue of
        // a node already carries its smallest depth.
        var visited = new HashSet<int>();
        var queue = new Queue<(int nodeId, int depth)>();

        foreach (var node in nodes)
        {
            if (!hasIncoming.Contains(node.Id) && visited.Add(node.Id))
            {
                queue.Enqueue((node.Id, 1));
            }
        }

        // Every node has an incoming edge (e.g. a pure cycle): start from the first node.
        if (queue.Count == 0)
        {
            visited.Add(nodes[0].Id);
            queue.Enqueue((nodes[0].Id, 1));
        }

        var maxDepth = 0;
        while (queue.Count > 0)
        {
            var (current, depth) = queue.Dequeue();
            maxDepth = Math.Max(maxDepth, depth);

            if (!outEdges.TryGetValue(current, out var neighbors))
            {
                continue;
            }

            foreach (var neighbor in neighbors)
            {
                if (visited.Add(neighbor))
                {
                    queue.Enqueue((neighbor, depth + 1));
                }
            }
        }

        return maxDepth;
    }

    /// <summary>
    /// Approximates the number of loops by counting CFG edges whose target block ID is
    /// smaller than their source block ID.
    /// </summary>
    private static int CountBackEdges(ControlFlowGraph cfg)
    {
        // A back edge is an edge to a node that dominates the source
        // Simplified: count edges where target ID < source ID
        return cfg.Edges.Count(e => e.TargetBlockId < e.SourceBlockId);
    }
}