- Implemented a new tool `stella-callgraph-node` that extracts call graphs from JavaScript/TypeScript projects using the Babel AST.
- Added a command-line interface with options for JSON output and help.
- Included functionality to analyze project structure, detect functions, and build call graphs.
- Created a package.json file for dependency management.

feat: introduce stella-callgraph-python for Python call graph extraction

- Developed `stella-callgraph-python` to extract call graphs from Python projects using AST analysis.
- Implemented a command-line interface with options for JSON output and verbose logging.
- Added framework detection to identify popular web frameworks and their entry points.
- Created an AST analyzer to traverse Python code and extract function definitions and calls.
- Included requirements.txt for project dependencies.

chore: add framework detection for Python projects

- Implemented framework detection logic to identify frameworks such as Flask, FastAPI, Django, and others based on project files and import patterns.
- Enhanced the AST analyzer to recognize entry points based on decorators and function definitions.
323 lines
11 KiB
Python
323 lines
11 KiB
Python
"""
|
|
AST analyzer for Python call graph extraction.
|
|
"""
|
|
|
|
import ast
import re
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Optional
|
|
|
|
|
|
@dataclass
class FunctionNode:
    """One function or method vertex of the call graph.

    The analyzer stores instances in its node table keyed by ``id``. The two
    entrypoint fields start unset and are filled in when the function is
    recognized as an entrypoint.
    """

    id: str  # symbol id; call edges refer to nodes through it
    package: str
    name: str  # bare function name
    qualified_name: str  # "Class.method" for methods, the bare name otherwise
    file: str  # source path as handed to the analyzer
    line: int  # line number reported for the definition
    visibility: str  # "public" or "private"
    annotations: list[str] = field(default_factory=list)  # decorator expressions as source text
    is_entrypoint: bool = field(default=False)
    entrypoint_type: Optional[str] = field(default=None)
|
|
|
|
|
|
@dataclass
class CallEdge:
    """A directed call from one symbol to another, plus where the call happens."""

    from_id: str  # symbol id of the calling function
    to_id: str  # symbol id of the callee ("py:external/<name>" when unresolved)
    kind: str  # edge label; the analyzer records "direct"
    file: str  # file containing the call site
    line: int  # line of the call site
|
|
|
|
|
|
@dataclass
class Entrypoint:
    """A function the analyzer considers reachable from outside the program."""

    id: str  # symbol id of the function
    type: str  # category label, e.g. "http_handler", "background_job", "cli_command"
    route: Optional[str] = field(default=None)  # URL path, when one could be extracted
    method: Optional[str] = field(default=None)  # upper-case HTTP verb, when known
|
|
|
|
|
|
class PythonASTAnalyzer:
    """Analyzes Python AST to extract call graph information.

    Call :meth:`analyze_file` once per parsed module, then :meth:`get_result`
    for a dict with ``module``, ``nodes``, ``edges`` and ``entrypoints`` keys.
    """

    # Prefix of the ids given to call targets that match no known function node.
    _EXTERNAL_PREFIX = "py:external/"

    # HTTP verbs, in the priority order used when a decorator lists several methods.
    _HTTP_METHODS = ('GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS')

    # Decorator names (last dotted segment, lower-cased) that register a single HTTP
    # verb, as in ``@app.get('/items')`` or ``@router.post('/items')``.
    _HTTP_VERB_DECORATORS = frozenset({'get', 'post', 'put', 'delete', 'patch', 'head', 'options'})

    # Methods of classes whose name ends in ``View`` that are treated as HTTP handlers.
    _VIEW_METHODS = ('get', 'post', 'put', 'delete', 'patch')

    # A quoted literal starting with '/', i.e. something that looks like a URL path.
    _PATH_LITERAL_RE = re.compile(r"['\"](/[/\w{}<>:.-]*)['\"]")
    # Any quoted literal of URL-ish characters; fallback when no '/'-prefixed one exists.
    _ROUTE_LITERAL_RE = re.compile(r"['\"]([/\w{}<>:.-]+)['\"]")
    # Contents of a ``methods=[...]`` (or tuple / set literal) keyword argument.
    _METHODS_KWARG_RE = re.compile(r"methods\s*=\s*[\[({]([^\])}]*)[\])}]")
    # One quoted string; used to split the ``methods`` literal into its items.
    _QUOTED_RE = re.compile(r"['\"]([^'\"]+)['\"]")

    def __init__(self, package_name: str, root: Path, frameworks: list[str]):
        self.package_name = package_name
        self.root = root
        self.frameworks = frameworks
        self.nodes: dict[str, FunctionNode] = {}
        self.edges: list[CallEdge] = []
        self.entrypoints: list[Entrypoint] = []
        self.current_function: Optional[str] = None
        self.current_file: str = ""
        self.current_class: Optional[str] = None
        # id(edge) -> (edge, raw call target, enclosing class) for every edge made by
        # add_call(); lets get_result() re-resolve targets once all nodes are known.
        # The edge itself is stored so its id() cannot be recycled while the entry lives.
        self._call_sites: dict[int, tuple[CallEdge, str, Optional[str]]] = {}

    def analyze_file(self, tree: ast.AST, relative_path: str) -> None:
        """Analyze one parsed module.

        ``relative_path`` is recorded on every node and edge created for the file
        and becomes part of their symbol ids.
        """
        self.current_file = relative_path
        self.current_function = None
        self.current_class = None
        FunctionVisitor(self).visit(tree)

    def get_result(self) -> dict[str, Any]:
        """Return the analysis result as a JSON-serializable dictionary.

        Call targets are resolved again here against the complete node set, so a
        call to a function defined later (in the same file or a file analyzed
        afterwards) is linked instead of being left as ``py:external/...``.
        """
        edges = self._dedupe_edges(self._resolved_edges())
        return {
            "module": self.package_name,
            "nodes": [self._node_to_dict(n) for n in self.nodes.values()],
            "edges": [self._edge_to_dict(e) for e in edges],
            "entrypoints": [self._entrypoint_to_dict(e) for e in self.entrypoints],
        }

    def _node_to_dict(self, node: "FunctionNode") -> dict[str, Any]:
        """Serialize a node; the column is not tracked and is always reported as 0."""
        return {
            "id": node.id,
            "package": node.package,
            "name": node.name,
            "signature": node.qualified_name,
            "position": {
                "file": node.file,
                "line": node.line,
                "column": 0,
            },
            "visibility": node.visibility,
            # Copy so consumers cannot mutate the node's decorator list through the result.
            "annotations": list(node.annotations),
        }

    def _edge_to_dict(self, edge: "CallEdge") -> dict[str, Any]:
        """Serialize an edge together with its call site."""
        return {
            "from": edge.from_id,
            "to": edge.to_id,
            "kind": edge.kind,
            "site": {
                "file": edge.file,
                "line": edge.line,
            },
        }

    def _entrypoint_to_dict(self, ep: "Entrypoint") -> dict[str, Any]:
        """Serialize an entrypoint, omitting ``route``/``method`` when they are empty."""
        result: dict[str, Any] = {
            "id": ep.id,
            "type": ep.type,
        }
        if ep.route:
            result["route"] = ep.route
        if ep.method:
            result["method"] = ep.method
        return result

    def _resolved_edges(self) -> "list[CallEdge]":
        """Return ``self.edges`` with targets re-resolved against the final node set.

        Edges not created by :meth:`add_call` are passed through unchanged;
        ``self.edges`` itself is not modified.
        """
        resolved = []
        for edge in self.edges:
            site = self._call_sites.get(id(edge))
            if site is not None and site[0] is edge:
                _, target_name, class_name = site
                target_id = self._resolve_in_context(target_name, edge.file, class_name)
                if target_id != edge.to_id:
                    edge = replace(edge, to_id=target_id)
            resolved.append(edge)
        return resolved

    def _dedupe_edges(self, edges: "Optional[list[CallEdge]]" = None) -> "list[CallEdge]":
        """Keep the first edge of each (caller, callee) pair, preserving order.

        Operates on ``self.edges`` when ``edges`` is not given.
        """
        source = self.edges if edges is None else edges
        seen: set[tuple[str, str]] = set()
        unique = []
        for edge in source:
            key = (edge.from_id, edge.to_id)
            if key not in seen:
                seen.add(key)
                unique.append(edge)
        return unique

    def _module_base(self, file: str) -> str:
        """Turn a source path into a dotted module path (``a/b.py`` -> ``a.b``)."""
        # Strip only a trailing ".py"; replace() also mangled directories such as "x.python".
        return file.removesuffix('.py').replace('/', '.').replace('\\', '.')

    def _symbol_id_for(self, file: str, name: str, class_name: Optional[str]) -> str:
        """Build the symbol id of ``name`` (optionally inside ``class_name``) in ``file``."""
        owner = self._module_base(file)
        if class_name:
            owner = f"{owner}.{class_name}"
        return f"py:{self.package_name}/{owner}.{name}"

    def make_symbol_id(self, name: str, class_name: Optional[str] = None) -> str:
        """Create a symbol ID for a function or method in the current file."""
        return self._symbol_id_for(self.current_file, name, class_name)

    def add_function(
        self,
        name: str,
        line: int,
        decorators: list[str],
        class_name: Optional[str] = None,
        is_private: bool = False
    ) -> str:
        """Register a function node, detect whether it is an entrypoint, and return its id."""
        symbol_id = self.make_symbol_id(name, class_name)

        qualified_name = f"{class_name}.{name}" if class_name else name
        # NOTE(review): the startswith('_') clause also marks dunder names private even
        # when the caller passes is_private=False for them; confirm this is intended.
        visibility = "private" if is_private or name.startswith('_') else "public"

        node = FunctionNode(
            id=symbol_id,
            package=self.package_name,
            name=name,
            qualified_name=qualified_name,
            file=self.current_file,
            line=line,
            visibility=visibility,
            annotations=decorators
        )
        self.nodes[symbol_id] = node

        entrypoint = self._detect_entrypoint(name, decorators, class_name)
        if entrypoint:
            node.is_entrypoint = True
            node.entrypoint_type = entrypoint.type
            self.entrypoints.append(entrypoint)

        return symbol_id

    def add_call(self, target_name: str, line: int) -> None:
        """Record a call from the current function to ``target_name``.

        Does nothing when no current function is set. The target is resolved now
        and again in :meth:`get_result`, once every function has been registered.
        """
        if not self.current_function:
            return

        edge = CallEdge(
            from_id=self.current_function,
            to_id=self._resolve_target(target_name),
            kind="direct",
            file=self.current_file,
            line=line
        )
        self.edges.append(edge)
        self._call_sites[id(edge)] = (edge, target_name, self.current_class)

    def _resolve_target(self, name: str) -> str:
        """Resolve a call target seen in the current file and class to a symbol ID."""
        return self._resolve_in_context(name, self.current_file, self.current_class)

    def _resolve_in_context(self, name: str, file: str, class_name: Optional[str]) -> str:
        """Resolve ``name`` as called from ``file`` inside ``class_name``.

        Lookup order:
        1. ``self.x()`` / ``cls.x()`` -> method ``x`` of the enclosing class.
        2. A bare name -> a module-level function of the same file.
        3. Any known node whose name or qualified name equals ``name``.
        Otherwise the target is reported as ``py:external/<name>``.
        """
        receiver, _, member = name.partition('.')
        if receiver in ('self', 'cls') and member and '.' not in member and class_name:
            method_id = self._symbol_id_for(file, member, class_name)
            if method_id in self.nodes:
                return method_id

        if '.' not in name:
            local_id = self._symbol_id_for(file, name, None)
            if local_id in self.nodes:
                return local_id

        for node_id, node in self.nodes.items():
            if node.name == name or node.qualified_name == name:
                return node_id

        return f"{self._EXTERNAL_PREFIX}{name}"

    def _detect_entrypoint(
        self,
        name: str,
        decorators: list[str],
        class_name: Optional[str]
    ) -> Optional["Entrypoint"]:
        """Detect if a function is an entrypoint based on its decorators and name.

        Decorators are matched on their callable name (``app.get('/x')`` -> ``get``)
        rather than on substrings of the whole expression, so route paths or
        arguments such as ``'/commands'`` cannot trigger a false match.
        """
        symbol_id = self.make_symbol_id(name, class_name)

        for decorator in decorators:
            # Flask / FastAPI / Starlette style route registration
            if self._is_http_route_decorator(decorator):
                return Entrypoint(
                    id=symbol_id,
                    type="http_handler",
                    route=self._extract_route_from_decorator(decorator),
                    method=self._extract_method_from_decorator(decorator),
                )

            kind = self._decorator_name(decorator)
            # Celery-style tasks: @app.task, @shared_task, @periodic_task
            if kind == 'task' or kind.endswith('_task'):
                return Entrypoint(id=symbol_id, type="background_job")
            # Click-style commands and command groups
            if kind in ('command', 'group'):
                return Entrypoint(id=symbol_id, type="cli_command")

        # Class-based views: HTTP verbs dispatched to same-named methods
        if class_name and class_name.endswith('View') and name in self._VIEW_METHODS:
            return Entrypoint(id=symbol_id, type="http_handler", method=name.upper())

        # Module-level main() function
        if name == 'main' and not class_name:
            return Entrypoint(id=symbol_id, type="cli_command")

        return None

    def _decorator_callee(self, decorator: str) -> str:
        """Return the lower-cased dotted callable of a decorator (``app.get('/x')`` -> ``app.get``)."""
        return decorator.strip().lstrip('@').split('(', 1)[0].strip().lower()

    def _decorator_name(self, decorator: str) -> str:
        """Return the last segment of the decorator's callable (``app.get('/x')`` -> ``get``)."""
        return self._decorator_callee(decorator).rsplit('.', 1)[-1]

    def _is_http_route_decorator(self, decorator: str) -> bool:
        """Tell whether a decorator registers an HTTP route."""
        callee = self._decorator_callee(decorator)
        kind = callee.rsplit('.', 1)[-1]
        if kind == 'route' or kind.endswith('_route'):
            return True
        if kind not in self._HTTP_VERB_DECORATORS:
            return False
        if kind == 'patch':
            # `patch` is also unittest.mock's decorator; only count it as a route when it is
            # called on a non-mock object (``app.patch``) or is given a URL path.
            on_router = '.' in callee and not callee.endswith('mock.patch')
            return on_router or self._PATH_LITERAL_RE.search(decorator) is not None
        return True

    def _extract_route_from_decorator(self, decorator: str) -> Optional[str]:
        """Extract the route path from a decorator string.

        A '/'-prefixed literal is preferred so keyword arguments listed before the
        path (``methods=['GET'], path='/x'``) are not mistaken for it.
        """
        match = self._PATH_LITERAL_RE.search(decorator) or self._ROUTE_LITERAL_RE.search(decorator)
        return match.group(1) if match else None

    def _extract_method_from_decorator(self, decorator: str) -> Optional[str]:
        """Extract the HTTP method from a decorator string.

        ``@app.post(...)`` yields ``POST``. For ``methods=[...]`` the first listed
        standard verb in ``_HTTP_METHODS`` order wins, and non-standard verbs are
        returned as written (upper-cased). Returns None when nothing is specified.
        """
        verb = self._decorator_name(decorator).upper()
        if verb in self._HTTP_METHODS:
            return verb

        match = self._METHODS_KWARG_RE.search(decorator)
        if not match:
            return None
        listed = [item.upper() for item in self._QUOTED_RE.findall(match.group(1))]
        for method in self._HTTP_METHODS:
            if method in listed:
                return method
        return listed[0] if listed else None
|
|
|
|
|
|
class FunctionVisitor(ast.NodeVisitor):
|
|
"""AST visitor that extracts function definitions and calls."""
|
|
|
|
def __init__(self, analyzer: PythonASTAnalyzer):
|
|
self.analyzer = analyzer
|
|
|
|
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
|
"""Visit class definitions."""
|
|
old_class = self.analyzer.current_class
|
|
self.analyzer.current_class = node.name
|
|
|
|
self.generic_visit(node)
|
|
|
|
self.analyzer.current_class = old_class
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
|
"""Visit function definitions."""
|
|
self._visit_function(node)
|
|
|
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
|
"""Visit async function definitions."""
|
|
self._visit_function(node)
|
|
|
|
def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
|
|
"""Common logic for function and async function definitions."""
|
|
decorators = [ast.unparse(d) for d in node.decorator_list]
|
|
is_private = node.name.startswith('_') and not node.name.startswith('__')
|
|
|
|
symbol_id = self.analyzer.add_function(
|
|
name=node.name,
|
|
line=node.lineno,
|
|
decorators=decorators,
|
|
class_name=self.analyzer.current_class,
|
|
is_private=is_private
|
|
)
|
|
|
|
# Visit function body for calls
|
|
old_function = self.analyzer.current_function
|
|
self.analyzer.current_function = symbol_id
|
|
|
|
for child in ast.walk(node):
|
|
if isinstance(child, ast.Call):
|
|
target_name = self._get_call_target(child)
|
|
if target_name:
|
|
self.analyzer.add_call(target_name, child.lineno)
|
|
|
|
self.analyzer.current_function = old_function
|
|
|
|
def _get_call_target(self, node: ast.Call) -> Optional[str]:
|
|
"""Extract the target name from a Call node."""
|
|
if isinstance(node.func, ast.Name):
|
|
return node.func.id
|
|
elif isinstance(node.func, ast.Attribute):
|
|
parts = self._get_attribute_parts(node.func)
|
|
return '.'.join(parts)
|
|
return None
|
|
|
|
def _get_attribute_parts(self, node: ast.Attribute) -> list[str]:
|
|
"""Get all parts of an attribute chain."""
|
|
parts: list[str] = []
|
|
current: ast.expr = node
|
|
|
|
while isinstance(current, ast.Attribute):
|
|
parts.insert(0, current.attr)
|
|
current = current.value
|
|
|
|
if isinstance(current, ast.Name):
|
|
parts.insert(0, current.id)
|
|
|
|
return parts
|