Files
git.stella-ops.org/tools/stella-callgraph-python/ast_analyzer.py
master 8779e9226f feat: add stella-callgraph-node for JavaScript/TypeScript call graph extraction
- Implemented a new tool `stella-callgraph-node` that extracts call graphs from JavaScript/TypeScript projects using Babel AST.
- Added command-line interface with options for JSON output and help.
- Included functionality to analyze project structure, detect functions, and build call graphs.
- Created a package.json file for dependency management.

feat: introduce stella-callgraph-python for Python call graph extraction

- Developed `stella-callgraph-python` to extract call graphs from Python projects using AST analysis.
- Implemented command-line interface with options for JSON output and verbose logging.
- Added framework detection to identify popular web frameworks and their entry points.
- Created an AST analyzer to traverse Python code and extract function definitions and calls.
- Included requirements.txt for project dependencies.

chore: add framework detection for Python projects

- Implemented framework detection logic to identify frameworks like Flask, FastAPI, Django, and others based on project files and import patterns.
- Enhanced the AST analyzer to recognize entry points based on decorators and function definitions.
2025-12-19 18:11:59 +02:00

323 lines
11 KiB
Python

"""
AST analyzer for Python call graph extraction.
"""
import ast
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
@dataclass
class FunctionNode:
    """A single function or method recorded as a node of the call graph.

    Attributes:
        id: Symbol ID of the function, used as the key in the node table and
            as the endpoint of call edges.
        package: Name of the package the function was found in.
        name: Bare function name as written in its ``def``.
        qualified_name: ``Class.name`` for methods, otherwise the bare name.
        file: Path of the file that defines the function.
        line: Line number of the definition.
        visibility: ``"public"`` or ``"private"``.
        annotations: Source text of each decorator applied to the function.
        is_entrypoint: True when the function was recognised as an entrypoint.
        entrypoint_type: Kind of entrypoint, or ``None`` if it is not one.
    """

    # Identity and naming.
    id: str
    package: str
    name: str
    qualified_name: str

    # Source position.
    file: str
    line: int

    # Classification.
    visibility: str
    annotations: list[str] = field(default_factory=list)
    is_entrypoint: bool = False
    entrypoint_type: Optional[str] = None
@dataclass
class CallEdge:
    """A directed call from one function to another.

    Attributes:
        from_id: Symbol ID of the calling function.
        to_id: Symbol ID of the call target.
        kind: Kind of call (for example ``"direct"``).
        file: Path of the file containing the call site.
        line: Line number of the call site.
    """

    # Endpoints of the edge.
    from_id: str
    to_id: str

    # How the call is made.
    kind: str

    # Where the call site is located.
    file: str
    line: int
@dataclass
class Entrypoint:
    """A function recognised as an entrypoint into the analysed code.

    Attributes:
        id: Symbol ID of the function that acts as the entrypoint.
        type: Kind of entrypoint (for example ``"http_handler"``).
        route: Route path served by the entrypoint, when one is known.
        method: HTTP method handled by the entrypoint, when one is known.
    """

    # Required identity of the entrypoint.
    id: str
    type: str

    # Optional HTTP details; left as None when they do not apply.
    route: Optional[str] = None
    method: Optional[str] = None
class PythonASTAnalyzer:
    """Collect call-graph nodes, edges and entrypoints from parsed Python files.

    Each file is fed to the same analyzer instance through
    :meth:`analyze_file`; :meth:`get_result` then returns everything
    accumulated so far as plain dictionaries.
    """

    # Names that directly denote an HTTP verb, either as the decorator's own
    # name (``@app.get(...)``) or as a method name on a ``*View`` class.
    _HTTP_VERBS = frozenset({"get", "post", "put", "delete", "patch"})
    # Verbs accepted when the decorator's own name carries the HTTP method.
    _HTTP_METHODS = frozenset(
        {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
    )
    # First quoted, path-looking string inside a decorator.
    _ROUTE_RE = re.compile(r"['\"]([/\w{}<>:.-]+)['\"]")
    # ``methods=[...]`` keyword (list, tuple or set literal) inside a decorator.
    _METHODS_RE = re.compile(r"methods\s*=\s*[\[({]([^\])}]+)[\])}]")

    def __init__(self, package_name: str, root: Path, frameworks: list[str]):
        """Create an empty analyzer.

        Args:
            package_name: Package name embedded in every symbol ID.
            root: Project root; stored on the instance.
            frameworks: Detected framework names; stored on the instance
                (the methods of this class do not consult it).
        """
        self.package_name = package_name
        self.root = root
        self.frameworks = frameworks
        self.nodes: dict[str, FunctionNode] = {}
        self.edges: list[CallEdge] = []
        self.entrypoints: list[Entrypoint] = []
        # Traversal state, updated while a file is being visited.
        self.current_function: Optional[str] = None
        self.current_file: str = ""
        self.current_class: Optional[str] = None

    def analyze_file(self, tree: ast.AST, relative_path: str) -> None:
        """Analyze one parsed file and merge its findings into the graph.

        Args:
            tree: Parsed AST of the file.
            relative_path: Path recorded on nodes/edges and used to derive
                the module part of symbol IDs.
        """
        # Reset per-file traversal state so nothing leaks between files.
        self.current_file = relative_path
        self.current_function = None
        self.current_class = None
        FunctionVisitor(self).visit(tree)

    def get_result(self) -> dict[str, Any]:
        """Return the accumulated graph as a dictionary of plain values.

        Edges are de-duplicated per (caller, callee) pair before export.
        """
        return {
            "module": self.package_name,
            "nodes": [self._node_to_dict(n) for n in self.nodes.values()],
            "edges": [self._edge_to_dict(e) for e in self._dedupe_edges()],
            "entrypoints": [self._entrypoint_to_dict(e) for e in self.entrypoints],
        }

    def _node_to_dict(self, node: FunctionNode) -> dict[str, Any]:
        """Serialize a node; the qualified name is exported as ``signature``."""
        return {
            "id": node.id,
            "package": node.package,
            "name": node.name,
            "signature": node.qualified_name,
            "position": {
                "file": node.file,
                "line": node.line,
                # Column information is not tracked; always exported as 0.
                "column": 0,
            },
            "visibility": node.visibility,
            "annotations": node.annotations,
        }

    def _edge_to_dict(self, edge: CallEdge) -> dict[str, Any]:
        """Serialize a call edge together with its call site."""
        return {
            "from": edge.from_id,
            "to": edge.to_id,
            "kind": edge.kind,
            "site": {
                "file": edge.file,
                "line": edge.line,
            },
        }

    def _entrypoint_to_dict(self, ep: Entrypoint) -> dict[str, Any]:
        """Serialize an entrypoint, omitting ``route``/``method`` when unset."""
        result: dict[str, Any] = {
            "id": ep.id,
            "type": ep.type,
        }
        if ep.route:
            result["route"] = ep.route
        if ep.method:
            result["method"] = ep.method
        return result

    def _dedupe_edges(self) -> list[CallEdge]:
        """Return the edges with one entry per (caller, callee) pair.

        The first edge seen for a pair wins, so its call site is the one kept.
        """
        # A tuple key cannot collide the way a joined "a|b" string could when
        # an ID itself contains the separator.
        seen: set[tuple[str, str]] = set()
        unique: list[CallEdge] = []
        for edge in self.edges:
            key = (edge.from_id, edge.to_id)
            if key in seen:
                continue
            seen.add(key)
            unique.append(edge)
        return unique

    def make_symbol_id(self, name: str, class_name: Optional[str] = None) -> str:
        """Create a symbol ID for a function or method in the current file.

        The ID has the form ``py:<package>/<module>[.<class>].<name>`` where
        ``<module>`` is the current file path with a trailing ``.py`` removed
        and path separators turned into dots.
        """
        # Strip only the file suffix: a blanket replace would also eat ".py"
        # occurring inside directory or file names.
        module_base = (
            self.current_file.removesuffix('.py').replace('/', '.').replace('\\', '.')
        )
        if class_name:
            return f"py:{self.package_name}/{module_base}.{class_name}.{name}"
        return f"py:{self.package_name}/{module_base}.{name}"

    def add_function(
        self,
        name: str,
        line: int,
        decorators: list[str],
        class_name: Optional[str] = None,
        is_private: bool = False
    ) -> str:
        """Add a function node to the graph and record it if it is an entrypoint.

        Args:
            name: Bare function name.
            line: Line number of the definition.
            decorators: Source text of each decorator, in order.
            class_name: Enclosing class name for methods, else ``None``.
            is_private: Force private visibility; names starting with ``_``
                are treated as private regardless.

        Returns:
            The symbol ID of the registered function. A node already stored
            under the same ID is replaced.
        """
        symbol_id = self.make_symbol_id(name, class_name)
        qualified_name = f"{class_name}.{name}" if class_name else name
        visibility = "private" if is_private or name.startswith('_') else "public"
        node = FunctionNode(
            id=symbol_id,
            package=self.package_name,
            name=name,
            qualified_name=qualified_name,
            file=self.current_file,
            line=line,
            visibility=visibility,
            annotations=decorators,
        )
        self.nodes[symbol_id] = node

        entrypoint = self._detect_entrypoint(name, decorators, class_name)
        if entrypoint:
            node.is_entrypoint = True
            node.entrypoint_type = entrypoint.type
            self.entrypoints.append(entrypoint)
        return symbol_id

    def add_call(self, target_name: str, line: int) -> None:
        """Add a ``direct`` call edge from the current function.

        Calls seen while no function is current (module level) are ignored.
        """
        if not self.current_function:
            return
        self.edges.append(CallEdge(
            from_id=self.current_function,
            to_id=self._resolve_target(target_name),
            kind="direct",
            file=self.current_file,
            line=line,
        ))

    def _resolve_target(self, name: str) -> str:
        """Resolve a call target name to a symbol ID.

        Returns the ID of the first registered node whose bare or qualified
        name equals ``name``; otherwise a ``py:external/<name>`` placeholder.
        Only functions registered before the call is seen can match.
        """
        return next(
            (
                node_id
                for node_id, node in self.nodes.items()
                if node.name == name or node.qualified_name == name
            ),
            f"py:external/{name}",
        )

    @staticmethod
    def _decorator_leaf(decorator: str) -> str:
        """Return the lower-cased last name of the decorator's callable.

        ``"app.route('/x')"`` gives ``"route"``; ``"shared_task"`` gives
        ``"shared_task"``. Arguments are never part of the result.
        """
        callee = decorator.split("(", 1)[0]
        return callee.rsplit(".", 1)[-1].strip().lower()

    def _classify_decorator(
        self, decorator: str
    ) -> Optional[tuple[str, Optional[str], Optional[str]]]:
        """Classify one decorator as ``(entrypoint_type, route, method)``.

        Classification looks only at the decorator's own name, never at its
        arguments, so argument text such as ``'/tasks'`` or ``'command'``
        cannot trigger a match. Returns ``None`` for unrelated decorators.
        """
        leaf = self._decorator_leaf(decorator)
        # Verb-named decorators: @app.get('/x'), @router.post('/y'), bare @get.
        if leaf in self._HTTP_VERBS:
            route = self._extract_route_from_decorator(decorator)
            return ("http_handler", route, leaf.upper())
        # Route-style decorators: @app.route(...), @router.api_route(...).
        if 'route' in leaf:
            route = self._extract_route_from_decorator(decorator)
            method = self._extract_method_from_decorator(decorator)
            return ("http_handler", route, method)
        # Task decorators: @celery.task, @shared_task, ...
        if 'task' in leaf:
            return ("background_job", None, None)
        # CLI decorators: @click.command(), @cli.group(), ...
        if 'command' in leaf or 'group' in leaf:
            return ("cli_command", None, None)
        return None

    def _detect_entrypoint(
        self,
        name: str,
        decorators: list[str],
        class_name: Optional[str]
    ) -> Optional[Entrypoint]:
        """Detect whether a function is an entrypoint.

        Decorators are checked in order and the first recognised one decides.
        Without a recognised decorator, verb-named methods of a class whose
        name ends in ``View`` are HTTP handlers, and a top-level ``main`` is
        a CLI command.

        Returns:
            The detected entrypoint, or ``None``.
        """
        symbol_id = self.make_symbol_id(name, class_name)
        for decorator in decorators:
            classified = self._classify_decorator(decorator)
            if classified is not None:
                kind, route, method = classified
                return Entrypoint(id=symbol_id, type=kind, route=route, method=method)

        # Class-based views: get/post/... on a class named "*View".
        if class_name and class_name.endswith('View') and name in self._HTTP_VERBS:
            return Entrypoint(id=symbol_id, type="http_handler", method=name.upper())

        # Conventional script entry function.
        if name == 'main' and not class_name:
            return Entrypoint(id=symbol_id, type="cli_command")
        return None

    def _extract_route_from_decorator(self, decorator: str) -> Optional[str]:
        """Return the first quoted path-like string in the decorator, if any."""
        match = self._ROUTE_RE.search(decorator)
        return match.group(1) if match else None

    def _extract_method_from_decorator(self, decorator: str) -> Optional[str]:
        """Extract the HTTP method declared by a decorator.

        An explicit ``methods=[...]`` keyword wins and its first entry is
        returned upper-cased. Otherwise the decorator's own name is used when
        it is an HTTP verb (``app.get(...)`` gives ``"GET"``). The route text
        is never searched, so a path such as ``'/widgets'`` cannot be mistaken
        for ``GET``.

        Returns:
            The upper-cased method name, or ``None`` if none is declared.
        """
        listed = self._METHODS_RE.search(decorator)
        if listed:
            for token in listed.group(1).split(","):
                method = token.strip().strip("'\"").upper()
                if method:
                    return method
        leaf = self._decorator_leaf(decorator).upper()
        return leaf if leaf in self._HTTP_METHODS else None
class FunctionVisitor(ast.NodeVisitor):
    """AST visitor that reports function definitions and calls to an analyzer.

    Module-level functions and methods of (possibly nested) classes are
    registered through ``analyzer.add_function``; every call found in a
    function body is reported through ``analyzer.add_call``.

    NOTE: the visitor does not descend into a function after handling it, so
    functions and classes defined inside a function body are not registered
    as nodes of their own; calls made inside them are attributed to the
    enclosing registered function.
    """

    def __init__(self, analyzer: PythonASTAnalyzer):
        """Bind the visitor to the analyzer that receives its findings."""
        self.analyzer = analyzer

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Visit a class body with ``analyzer.current_class`` set to its name."""
        enclosing_class = self.analyzer.current_class
        self.analyzer.current_class = node.name
        try:
            self.generic_visit(node)
        finally:
            # Restore the outer class so later siblings are attributed correctly.
            self.analyzer.current_class = enclosing_class

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Visit function definitions."""
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Visit async function definitions."""
        self._visit_function(node)

    def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Register a function and record the calls made in its body.

        Only the statements of the body are scanned. Decorator expressions
        and default-argument expressions are evaluated when the function is
        *defined*, not when it runs, so calls appearing there (for example
        ``@app.route(...)``) are not reported as calls made by the function.
        """
        analyzer = self.analyzer
        decorators = [ast.unparse(d) for d in node.decorator_list]
        # A single leading underscore marks a private name; dunder names are
        # not flagged here.
        is_private = node.name.startswith('_') and not node.name.startswith('__')
        symbol_id = analyzer.add_function(
            name=node.name,
            line=node.lineno,
            decorators=decorators,
            class_name=analyzer.current_class,
            is_private=is_private,
        )

        enclosing_function = analyzer.current_function
        analyzer.current_function = symbol_id
        try:
            for statement in node.body:
                for child in ast.walk(statement):
                    if not isinstance(child, ast.Call):
                        continue
                    target_name = self._get_call_target(child)
                    if target_name:
                        analyzer.add_call(target_name, child.lineno)
        finally:
            # Always restore, even if the analyzer raises part-way through.
            analyzer.current_function = enclosing_function

    def _get_call_target(self, node: ast.Call) -> Optional[str]:
        """Return the textual target of a call, or ``None`` if it has no name.

        ``f()`` gives ``"f"`` and ``a.b.c()`` gives ``"a.b.c"``. Calls whose
        callee is neither a name nor an attribute (for example ``f()()``)
        give ``None``.
        """
        callee = node.func
        if isinstance(callee, ast.Name):
            return callee.id
        if isinstance(callee, ast.Attribute):
            return '.'.join(self._get_attribute_parts(callee))
        return None

    def _get_attribute_parts(self, node: ast.Attribute) -> list[str]:
        """Return the names of an attribute chain from left to right.

        When the chain is rooted in something other than a plain name (for
        example ``f().x.y``), only the attribute names are returned.
        """
        names: list[str] = []
        cursor: ast.expr = node
        # Collect from the outermost attribute inwards, then reverse once.
        while isinstance(cursor, ast.Attribute):
            names.append(cursor.attr)
            cursor = cursor.value
        if isinstance(cursor, ast.Name):
            names.append(cursor.id)
        names.reverse()
        return names