""" AST analyzer for Python call graph extraction. """ import ast from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional @dataclass class FunctionNode: """Represents a function in the call graph.""" id: str package: str name: str qualified_name: str file: str line: int visibility: str annotations: list[str] = field(default_factory=list) is_entrypoint: bool = False entrypoint_type: Optional[str] = None @dataclass class CallEdge: """Represents a call between functions.""" from_id: str to_id: str kind: str file: str line: int @dataclass class Entrypoint: """Represents a detected entrypoint.""" id: str type: str route: Optional[str] = None method: Optional[str] = None class PythonASTAnalyzer: """Analyzes Python AST to extract call graph information.""" def __init__(self, package_name: str, root: Path, frameworks: list[str]): self.package_name = package_name self.root = root self.frameworks = frameworks self.nodes: dict[str, FunctionNode] = {} self.edges: list[CallEdge] = [] self.entrypoints: list[Entrypoint] = [] self.current_function: Optional[str] = None self.current_file: str = "" self.current_class: Optional[str] = None def analyze_file(self, tree: ast.AST, relative_path: str) -> None: """Analyze a single Python file.""" self.current_file = relative_path self.current_function = None self.current_class = None visitor = FunctionVisitor(self) visitor.visit(tree) def get_result(self) -> dict[str, Any]: """Get the analysis result as a dictionary.""" return { "module": self.package_name, "nodes": [self._node_to_dict(n) for n in self.nodes.values()], "edges": [self._edge_to_dict(e) for e in self._dedupe_edges()], "entrypoints": [self._entrypoint_to_dict(e) for e in self.entrypoints] } def _node_to_dict(self, node: FunctionNode) -> dict[str, Any]: return { "id": node.id, "package": node.package, "name": node.name, "signature": node.qualified_name, "position": { "file": node.file, "line": node.line, "column": 0 }, "visibility": node.visibility, "annotations": node.annotations } def _edge_to_dict(self, edge: CallEdge) -> dict[str, Any]: return { "from": edge.from_id, "to": edge.to_id, "kind": edge.kind, "site": { "file": edge.file, "line": edge.line } } def _entrypoint_to_dict(self, ep: Entrypoint) -> dict[str, Any]: result: dict[str, Any] = { "id": ep.id, "type": ep.type } if ep.route: result["route"] = ep.route if ep.method: result["method"] = ep.method return result def _dedupe_edges(self) -> list[CallEdge]: seen: set[str] = set() result: list[CallEdge] = [] for edge in self.edges: key = f"{edge.from_id}|{edge.to_id}" if key not in seen: seen.add(key) result.append(edge) return result def make_symbol_id(self, name: str, class_name: Optional[str] = None) -> str: """Create a symbol ID for a function or method.""" module_base = self.current_file.replace('.py', '').replace('/', '.').replace('\\', '.') if class_name: return f"py:{self.package_name}/{module_base}.{class_name}.{name}" return f"py:{self.package_name}/{module_base}.{name}" def add_function( self, name: str, line: int, decorators: list[str], class_name: Optional[str] = None, is_private: bool = False ) -> str: """Add a function node to the graph.""" symbol_id = self.make_symbol_id(name, class_name) qualified_name = f"{class_name}.{name}" if class_name else name visibility = "private" if is_private or name.startswith('_') else "public" node = FunctionNode( id=symbol_id, package=self.package_name, name=name, qualified_name=qualified_name, file=self.current_file, line=line, visibility=visibility, annotations=decorators ) self.nodes[symbol_id] = node # Detect entrypoints entrypoint = self._detect_entrypoint(name, decorators, class_name) if entrypoint: node.is_entrypoint = True node.entrypoint_type = entrypoint.type self.entrypoints.append(entrypoint) return symbol_id def add_call(self, target_name: str, line: int) -> None: """Add a call edge from the current function.""" if not self.current_function: return # Try to resolve the target target_id = self._resolve_target(target_name) self.edges.append(CallEdge( from_id=self.current_function, to_id=target_id, kind="direct", file=self.current_file, line=line )) def _resolve_target(self, name: str) -> str: """Resolve a call target to a symbol ID.""" # Check if it's a known local function for node_id, node in self.nodes.items(): if node.name == name or node.qualified_name == name: return node_id # External or unresolved return f"py:external/{name}" def _detect_entrypoint( self, name: str, decorators: list[str], class_name: Optional[str] ) -> Optional[Entrypoint]: """Detect if a function is an entrypoint based on frameworks and decorators.""" symbol_id = self.make_symbol_id(name, class_name) for decorator in decorators: # Flask routes if 'route' in decorator.lower() or decorator.lower() in ['get', 'post', 'put', 'delete', 'patch']: route = self._extract_route_from_decorator(decorator) method = self._extract_method_from_decorator(decorator) return Entrypoint(id=symbol_id, type="http_handler", route=route, method=method) # FastAPI routes if decorator.lower() in ['get', 'post', 'put', 'delete', 'patch', 'api_route']: route = self._extract_route_from_decorator(decorator) return Entrypoint(id=symbol_id, type="http_handler", route=route, method=decorator.upper()) # Celery tasks if 'task' in decorator.lower() or 'shared_task' in decorator.lower(): return Entrypoint(id=symbol_id, type="background_job") # Click commands if 'command' in decorator.lower() or 'group' in decorator.lower(): return Entrypoint(id=symbol_id, type="cli_command") # Django views (class-based) if class_name and class_name.endswith('View'): if name in ['get', 'post', 'put', 'delete', 'patch']: return Entrypoint(id=symbol_id, type="http_handler", method=name.upper()) # main() function if name == 'main' and not class_name: return Entrypoint(id=symbol_id, type="cli_command") return None def _extract_route_from_decorator(self, decorator: str) -> Optional[str]: """Extract route path from decorator string.""" import re match = re.search(r"['\"]([/\w{}<>:.-]+)['\"]", decorator) return match.group(1) if match else None def _extract_method_from_decorator(self, decorator: str) -> Optional[str]: """Extract HTTP method from decorator string.""" import re methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'] for method in methods: if method.lower() in decorator.lower(): return method match = re.search(r"methods\s*=\s*\[([^\]]+)\]", decorator) if match: return match.group(1).strip("'\"").upper() return None class FunctionVisitor(ast.NodeVisitor): """AST visitor that extracts function definitions and calls.""" def __init__(self, analyzer: PythonASTAnalyzer): self.analyzer = analyzer def visit_ClassDef(self, node: ast.ClassDef) -> None: """Visit class definitions.""" old_class = self.analyzer.current_class self.analyzer.current_class = node.name self.generic_visit(node) self.analyzer.current_class = old_class def visit_FunctionDef(self, node: ast.FunctionDef) -> None: """Visit function definitions.""" self._visit_function(node) def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: """Visit async function definitions.""" self._visit_function(node) def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: """Common logic for function and async function definitions.""" decorators = [ast.unparse(d) for d in node.decorator_list] is_private = node.name.startswith('_') and not node.name.startswith('__') symbol_id = self.analyzer.add_function( name=node.name, line=node.lineno, decorators=decorators, class_name=self.analyzer.current_class, is_private=is_private ) # Visit function body for calls old_function = self.analyzer.current_function self.analyzer.current_function = symbol_id for child in ast.walk(node): if isinstance(child, ast.Call): target_name = self._get_call_target(child) if target_name: self.analyzer.add_call(target_name, child.lineno) self.analyzer.current_function = old_function def _get_call_target(self, node: ast.Call) -> Optional[str]: """Extract the target name from a Call node.""" if isinstance(node.func, ast.Name): return node.func.id elif isinstance(node.func, ast.Attribute): parts = self._get_attribute_parts(node.func) return '.'.join(parts) return None def _get_attribute_parts(self, node: ast.Attribute) -> list[str]: """Get all parts of an attribute chain.""" parts: list[str] = [] current: ast.expr = node while isinstance(current, ast.Attribute): parts.insert(0, current.attr) current = current.value if isinstance(current, ast.Name): parts.insert(0, current.id) return parts