save progress

StellaOps Bot
2025-12-26 22:03:32 +02:00
parent 9a4cd2e0f7
commit e6c47c8f50
3634 changed files with 253222 additions and 56632 deletions


@@ -0,0 +1,9 @@
"""
StellaOps Solution and NuGet Tools.
This package provides CLI tools for:
- sln_generator: Generate consistent .sln solution files
- nuget_normalizer: Normalize NuGet package versions across csproj files
"""
__version__ = "1.0.0"


@@ -0,0 +1,37 @@
"""
StellaOps Solution and NuGet Tools Library.
This package provides shared utilities for:
- Parsing .csproj files
- Generating .sln solution files
- Normalizing NuGet package versions
"""
from .models import CsprojProject, SolutionFolder, PackageUsage
from .version_utils import parse_version, compare_versions, is_stable, select_latest_stable
from .csproj_parser import find_all_csproj, parse_csproj, get_deterministic_guid
from .dependency_graph import build_dependency_graph, get_transitive_dependencies, classify_dependencies
from .sln_writer import generate_solution_content, build_folder_hierarchy
__all__ = [
# Models
"CsprojProject",
"SolutionFolder",
"PackageUsage",
# Version utilities
"parse_version",
"compare_versions",
"is_stable",
"select_latest_stable",
# Csproj parsing
"find_all_csproj",
"parse_csproj",
"get_deterministic_guid",
# Dependency graph
"build_dependency_graph",
"get_transitive_dependencies",
"classify_dependencies",
# Solution writer
"generate_solution_content",
"build_folder_hierarchy",
]
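A minimal usage sketch (not part of the diff above), assuming the library package is importable as "lib" (the actual import path is not shown in this commit view):

from pathlib import Path
from lib import find_all_csproj, parse_csproj, build_dependency_graph

# Discover and parse every project under a hypothetical src/ tree, then build the reference graph.
src = Path("src").resolve()
candidates = [parse_csproj(p, base_path=src) for p in find_all_csproj(src)]
projects = [p for p in candidates if p is not None]
graph = build_dependency_graph(projects)
print(f"{len(projects)} projects, {sum(len(deps) for deps in graph.values())} project-to-project edges")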

Binary file not shown.


@@ -0,0 +1,276 @@
"""
Csproj file parsing utilities.
Provides functions to:
- Find all .csproj files in a directory tree
- Parse csproj files to extract project references and package references
- Generate deterministic GUIDs for projects
"""
import hashlib
import logging
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Optional
from .models import CsprojProject
logger = logging.getLogger(__name__)
# Default patterns to exclude when scanning for csproj files
DEFAULT_EXCLUDE_DIRS = {
"bin",
"obj",
"node_modules",
".git",
".vs",
".idea",
"third_party",
"packages",
".nuget",
".cache",
"Fixtures", # Test fixture files should not be in solutions
"TestData", # Test data files should not be in solutions
}
# Default file patterns to exclude (test fixtures, samples, etc.)
DEFAULT_EXCLUDE_PATTERNS = {
"*.Tests.Fixtures",
"*.Samples",
}
def get_deterministic_guid(path: Path, base_path: Optional[Path] = None) -> str:
"""
Generate a deterministic GUID from a path.
Uses SHA256 hash of the relative path to ensure consistency across runs.
Args:
path: Path to generate GUID for
base_path: Base path to calculate relative path from (optional)
Returns:
GUID string in uppercase format (e.g., "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX")
"""
if base_path:
try:
rel_path = path.relative_to(base_path)
except ValueError:
rel_path = path
else:
rel_path = path
# Normalize path separators and convert to lowercase for consistency
normalized = str(rel_path).replace("\\", "/").lower()
# Generate SHA256 hash
hash_bytes = hashlib.sha256(normalized.encode("utf-8")).digest()
# Format as GUID (use first 16 bytes)
guid_hex = hash_bytes[:16].hex().upper()
guid = f"{guid_hex[:8]}-{guid_hex[8:12]}-{guid_hex[12:16]}-{guid_hex[16:20]}-{guid_hex[20:32]}"
return guid
def find_all_csproj(
root_dir: Path,
exclude_dirs: Optional[set[str]] = None,
exclude_patterns: Optional[set[str]] = None,
) -> list[Path]:
"""
Find all .csproj files under a directory.
Args:
root_dir: Root directory to search
exclude_dirs: Directory names to exclude (defaults to bin, obj, etc.)
exclude_patterns: File name patterns to exclude
Returns:
List of absolute paths to .csproj files, sorted by path
"""
if exclude_dirs is None:
exclude_dirs = DEFAULT_EXCLUDE_DIRS
if exclude_patterns is None:
exclude_patterns = DEFAULT_EXCLUDE_PATTERNS
csproj_files: list[Path] = []
if not root_dir.exists():
logger.warning(f"Directory does not exist: {root_dir}")
return csproj_files
for item in root_dir.rglob("*.csproj"):
# Check if any parent directory should be excluded
skip = False
for parent in item.parents:
if parent.name in exclude_dirs:
skip = True
break
if skip:
continue
# Check file name patterns
for pattern in exclude_patterns:
if item.match(pattern):
skip = True
break
if skip:
continue
csproj_files.append(item.resolve())
return sorted(csproj_files)
def parse_csproj(
csproj_path: Path,
base_path: Optional[Path] = None,
) -> Optional[CsprojProject]:
"""
Parse a .csproj file and extract project information.
Args:
csproj_path: Path to the .csproj file
base_path: Base path for generating deterministic GUID
Returns:
CsprojProject with parsed information, or None if parsing fails
"""
if not csproj_path.exists():
logger.error(f"Csproj file does not exist: {csproj_path}")
return None
try:
tree = ET.parse(csproj_path)
root = tree.getroot()
except ET.ParseError as e:
logger.error(f"Failed to parse XML in {csproj_path}: {e}")
return None
# Extract project name from file name
name = csproj_path.stem
# Generate deterministic GUID
guid = get_deterministic_guid(csproj_path, base_path)
# Parse project references
project_references = _parse_project_references(root, csproj_path)
# Parse package references
package_references = _parse_package_references(root)
return CsprojProject(
path=csproj_path.resolve(),
name=name,
guid=guid,
project_references=project_references,
package_references=package_references,
)
def _parse_project_references(root: ET.Element, csproj_path: Path) -> list[Path]:
"""
Parse ProjectReference elements from csproj XML.
Args:
root: XML root element
csproj_path: Path to the csproj file (for resolving relative paths)
Returns:
List of resolved absolute paths to referenced projects
"""
references: list[Path] = []
csproj_dir = csproj_path.parent
# Handle both with and without namespace
for ref in root.iter():
if ref.tag.endswith("ProjectReference") or ref.tag == "ProjectReference":
include = ref.get("Include")
if include:
# Normalize path separators
include = include.replace("\\", "/")
# Resolve relative path
try:
ref_path = (csproj_dir / include).resolve()
if ref_path.exists():
references.append(ref_path)
else:
logger.warning(
f"Referenced project does not exist: {include} (from {csproj_path})"
)
except Exception as e:
logger.warning(f"Failed to resolve path {include}: {e}")
return references
def _parse_package_references(root: ET.Element) -> dict[str, str]:
"""
Parse PackageReference elements from csproj XML.
Args:
root: XML root element
Returns:
Dictionary mapping package name to version string
"""
packages: dict[str, str] = {}
for ref in root.iter():
if ref.tag.endswith("PackageReference") or ref.tag == "PackageReference":
include = ref.get("Include")
version = ref.get("Version")
if include and version:
packages[include] = version
elif include:
# Version might be in a child element
for child in ref:
if child.tag.endswith("Version") or child.tag == "Version":
if child.text:
packages[include] = child.text.strip()
break
return packages
def get_project_name_from_path(csproj_path: Path) -> str:
"""
Extract project name from csproj file path.
Args:
csproj_path: Path to csproj file
Returns:
Project name (file name without extension)
"""
return csproj_path.stem
def resolve_project_path(
include_path: str,
from_csproj: Path,
) -> Optional[Path]:
"""
Resolve a ProjectReference Include path to an absolute path.
Args:
include_path: The Include attribute value
from_csproj: The csproj file containing the reference
Returns:
Resolved absolute path, or None if resolution fails
"""
# Normalize path separators
include_path = include_path.replace("\\", "/")
try:
resolved = (from_csproj.parent / include_path).resolve()
return resolved if resolved.exists() else None
except Exception:
return None
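A short sketch (not part of the diff above) of the parser helpers; the module path "lib.csproj_parser" and the project path are assumptions:

from pathlib import Path
from lib.csproj_parser import find_all_csproj, get_deterministic_guid, parse_csproj

repo = Path("src").resolve()
# The GUID is a pure function of the lowercased, forward-slash relative path, so reruns agree.
assert get_deterministic_guid(repo / "Module/Module.csproj", repo) == \
    get_deterministic_guid(repo / "Module/Module.csproj", repo)

for csproj in find_all_csproj(repo):
    project = parse_csproj(csproj, base_path=repo)
    if project:
        print(project.name, len(project.project_references), len(project.package_references))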


@@ -0,0 +1,282 @@
"""
Project dependency graph utilities.
Provides functions to:
- Build a dependency graph from parsed projects
- Get transitive dependencies
- Classify dependencies as internal or external to a module
"""
import logging
from pathlib import Path
from typing import Optional
from .models import CsprojProject
logger = logging.getLogger(__name__)
def build_dependency_graph(
projects: list[CsprojProject],
) -> dict[Path, set[Path]]:
"""
Build a dependency graph from a list of projects.
Args:
projects: List of parsed CsprojProject objects
Returns:
Dictionary mapping project path to set of dependency paths
"""
graph: dict[Path, set[Path]] = {}
for project in projects:
graph[project.path] = set(project.project_references)
return graph
def get_transitive_dependencies(
project_path: Path,
graph: dict[Path, set[Path]],
visited: Optional[set[Path]] = None,
) -> set[Path]:
"""
Get all transitive dependencies for a project.
Handles circular dependencies gracefully by tracking visited nodes.
Args:
project_path: Path to the project
graph: Dependency graph from build_dependency_graph
visited: Set of already visited paths (for cycle detection)
Returns:
Set of all transitive dependency paths
"""
if visited is None:
visited = set()
if project_path in visited:
return set() # Cycle detected
visited.add(project_path)
all_deps: set[Path] = set()
direct_deps = graph.get(project_path, set())
all_deps.update(direct_deps)
for dep in direct_deps:
transitive = get_transitive_dependencies(dep, graph, visited.copy())
all_deps.update(transitive)
return all_deps
def classify_dependencies(
project: CsprojProject,
module_dir: Path,
src_root: Path,
) -> dict[str, list[Path]]:
"""
Classify project dependencies as internal or external.
Args:
project: The project to analyze
module_dir: Root directory of the module
src_root: Root of the src/ directory
Returns:
Dictionary with keys:
- 'internal': Dependencies within module_dir
- '__Libraries': Dependencies from src/__Libraries/
- '<ModuleName>': Dependencies from other modules
"""
result: dict[str, list[Path]] = {"internal": []}
module_dir = module_dir.resolve()
src_root = src_root.resolve()
for ref_path in project.project_references:
ref_path = ref_path.resolve()
# Check if internal to module
try:
ref_path.relative_to(module_dir)
result["internal"].append(ref_path)
continue
except ValueError:
pass
# External - classify by source module
category = _get_external_category(ref_path, src_root)
if category not in result:
result[category] = []
result[category].append(ref_path)
return result
def _get_external_category(ref_path: Path, src_root: Path) -> str:
"""
Determine the category for an external dependency.
Args:
ref_path: Path to the referenced project
src_root: Root of the src/ directory
Returns:
Category name (e.g., '__Libraries', 'Authority', 'Scanner')
"""
try:
rel_path = ref_path.relative_to(src_root)
except ValueError:
# Outside of src/ - use 'External'
return "External"
parts = rel_path.parts
if len(parts) == 0:
return "External"
# First part is the module or __Libraries/__Tests etc.
first_part = parts[0]
if first_part == "__Libraries":
return "__Libraries"
elif first_part == "__Tests":
return "__Tests"
elif first_part == "__Analyzers":
return "__Analyzers"
else:
# It's a module name
return first_part
def collect_all_external_dependencies(
projects: list[CsprojProject],
module_dir: Path,
src_root: Path,
project_map: dict[Path, CsprojProject],
) -> dict[str, list[CsprojProject]]:
"""
Collect all external dependencies for a module's projects.
Includes transitive dependencies.
Args:
projects: List of projects in the module
module_dir: Root directory of the module
src_root: Root of the src/ directory
project_map: Map from path to CsprojProject for all known projects
Returns:
Dictionary mapping category to list of external CsprojProject objects
"""
# Build dependency graph for all known projects
all_projects = list(project_map.values())
graph = build_dependency_graph(all_projects)
module_dir = module_dir.resolve()
src_root = src_root.resolve()
# Collect all external dependencies
external_deps: dict[str, set[Path]] = {}
for project in projects:
# Get all transitive dependencies
all_deps = get_transitive_dependencies(project.path, graph)
for dep_path in all_deps:
dep_path = dep_path.resolve()
# Skip if internal to module
try:
dep_path.relative_to(module_dir)
continue
except ValueError:
pass
# External - classify
category = _get_external_category(dep_path, src_root)
if category not in external_deps:
external_deps[category] = set()
external_deps[category].add(dep_path)
# Convert paths to CsprojProject objects
result: dict[str, list[CsprojProject]] = {}
for category, paths in external_deps.items():
result[category] = []
for path in sorted(paths):
if path in project_map:
result[category].append(project_map[path])
else:
logger.warning(f"External dependency not in project map: {path}")
return result
def get_module_projects(
module_dir: Path,
all_projects: list[CsprojProject],
) -> list[CsprojProject]:
"""
Get all projects that belong to a module.
Args:
module_dir: Root directory of the module
all_projects: List of all projects
Returns:
List of projects within the module directory
"""
module_dir = module_dir.resolve()
result: list[CsprojProject] = []
for project in all_projects:
try:
project.path.relative_to(module_dir)
result.append(project)
except ValueError:
pass
return result
def detect_circular_dependencies(
graph: dict[Path, set[Path]],
) -> list[list[Path]]:
"""
Detect circular dependencies in the project graph.
Args:
graph: Dependency graph
Returns:
List of cycles (each cycle is a list of paths)
"""
cycles: list[list[Path]] = []
visited: set[Path] = set()
rec_stack: set[Path] = set()
def dfs(node: Path, path: list[Path]) -> None:
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, set()):
if neighbor not in visited:
dfs(neighbor, path.copy())
elif neighbor in rec_stack:
# Found a cycle
cycle_start = path.index(neighbor)
cycle = path[cycle_start:] + [neighbor]
cycles.append(cycle)
rec_stack.remove(node)
for node in graph:
if node not in visited:
dfs(node, [])
return cycles
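A short sketch (not part of the diff above) of the graph helpers over a tiny hand-built graph; the three project paths and the import path are hypothetical:

from pathlib import Path
from lib.dependency_graph import detect_circular_dependencies, get_transitive_dependencies

a, b, c = Path("/repo/A.csproj"), Path("/repo/B.csproj"), Path("/repo/C.csproj")
graph = {a: {b}, b: {c}, c: {a}}  # A -> B -> C -> A

# The transitive closure from A includes A itself because the walk comes back around the cycle.
print(get_transitive_dependencies(a, graph))
# Exactly one cycle is reported: [A, B, C, A]
print(detect_circular_dependencies(graph))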


@@ -0,0 +1,87 @@
"""
Data models for solution and project management.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
@dataclass
class CsprojProject:
"""Represents a .csproj project file."""
path: Path # Absolute path to .csproj file
name: str # Project name (without extension)
guid: str # Project GUID (generated deterministically from path)
project_references: list[Path] = field(default_factory=list) # Resolved absolute paths
package_references: dict[str, str] = field(default_factory=dict) # Package name -> version
def __hash__(self) -> int:
return hash(self.path)
def __eq__(self, other: object) -> bool:
if not isinstance(other, CsprojProject):
return False
return self.path == other.path
@dataclass
class SolutionFolder:
"""Represents a solution folder in a .sln file."""
name: str # Folder display name
guid: str # Folder GUID
path: str # Full path within solution (e.g., "Module/__Libraries")
parent_guid: Optional[str] = None # Parent folder GUID (None for root folders)
children: list["SolutionFolder"] = field(default_factory=list)
projects: list[CsprojProject] = field(default_factory=list)
def __hash__(self) -> int:
return hash(self.path)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SolutionFolder):
return False
return self.path == other.path
@dataclass
class PackageUsage:
"""Tracks usage of a NuGet package across the codebase."""
package_name: str
usages: dict[Path, str] = field(default_factory=dict) # csproj path -> version string
def get_all_versions(self) -> list[str]:
"""Get list of unique versions used."""
return list(set(self.usages.values()))
def get_usage_count(self) -> int:
"""Get number of projects using this package."""
return len(self.usages)
@dataclass
class NormalizationChange:
"""Represents a version change for a package in a project."""
csproj_path: Path
old_version: str
new_version: str
@dataclass
class NormalizationResult:
"""Result of normalizing a package across the codebase."""
package_name: str
target_version: str
changes: list[NormalizationChange] = field(default_factory=list)
skipped_reason: Optional[str] = None
# Constants for solution file format
CSHARP_PROJECT_TYPE_GUID = "FAE04EC0-301F-11D3-BF4B-00C04F79EFBC"
SOLUTION_FOLDER_TYPE_GUID = "2150E333-8FDC-42A3-9474-1A3956D46DE8"
BYPASS_MARKER = "# STELLAOPS-MANUAL-SOLUTION"
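A short sketch (not part of the diff above) exercising the dataclasses, assuming the module is importable as lib.models; the package name, versions, and project paths are illustrative only:

from pathlib import Path
from lib.models import PackageUsage

usage = PackageUsage(package_name="Newtonsoft.Json")
usage.usages[Path("A/A.csproj")] = "13.0.1"
usage.usages[Path("B/B.csproj")] = "13.0.3"
print(usage.get_usage_count())           # 2
print(sorted(usage.get_all_versions()))  # ['13.0.1', '13.0.3']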


@@ -0,0 +1,416 @@
"""
NuGet API v3 client for package version and vulnerability queries.
"""
import logging
import re
from typing import Any
try:
import requests
except ImportError:
requests = None # type: ignore
from .version_utils import parse_version, is_stable
logger = logging.getLogger(__name__)
NUGET_SERVICE_INDEX = "https://api.nuget.org/v3/index.json"
NUGET_VULN_INDEX = "https://api.nuget.org/v3/vulnerabilities/index.json"
class NuGetApiError(Exception):
"""Error communicating with NuGet API."""
pass
class NuGetApiClient:
"""
Client for NuGet API v3 operations.
Provides methods for:
- Fetching available package versions
- Fetching vulnerability data
- Finding non-vulnerable versions
"""
def __init__(self, source: str = "https://api.nuget.org/v3"):
if requests is None:
raise ImportError(
"requests library is required for NuGet API access. "
"Install with: pip install requests"
)
self.source = source.rstrip("/")
self._session = requests.Session()
self._session.headers.update(
{"User-Agent": "StellaOps-NuGetVulnChecker/1.0"}
)
self._service_index: dict | None = None
self._vuln_cache: dict[str, list[dict]] | None = None
self._search_url: str | None = None
self._registration_url: str | None = None
def _get_service_index(self) -> dict:
"""Fetch and cache the NuGet service index."""
if self._service_index is not None:
return self._service_index
try:
response = self._session.get(f"{self.source}/index.json", timeout=30)
response.raise_for_status()
self._service_index = response.json()
return self._service_index
except Exception as e:
raise NuGetApiError(f"Failed to fetch NuGet service index: {e}")
def _get_search_url(self) -> str:
"""Get the SearchQueryService URL from service index."""
if self._search_url:
return self._search_url
index = self._get_service_index()
resources = index.get("resources", [])
# Look for SearchQueryService
for resource in resources:
resource_type = resource.get("@type", "")
if "SearchQueryService" in resource_type:
self._search_url = resource.get("@id", "")
return self._search_url
raise NuGetApiError("SearchQueryService not found in service index")
def _get_registration_url(self) -> str:
"""Get the RegistrationsBaseUrl from service index."""
if self._registration_url:
return self._registration_url
index = self._get_service_index()
resources = index.get("resources", [])
# Look for RegistrationsBaseUrl
for resource in resources:
resource_type = resource.get("@type", "")
if "RegistrationsBaseUrl" in resource_type:
self._registration_url = resource.get("@id", "").rstrip("/")
return self._registration_url
raise NuGetApiError("RegistrationsBaseUrl not found in service index")
def get_available_versions(self, package_id: str) -> list[str]:
"""
Fetch all available versions of a package from NuGet.
Args:
package_id: The NuGet package ID
Returns:
List of version strings, sorted newest first
"""
try:
# Use registration API for complete version list
reg_url = self._get_registration_url()
package_lower = package_id.lower()
url = f"{reg_url}/{package_lower}/index.json"
response = self._session.get(url, timeout=30)
if response.status_code == 404:
logger.warning(f"Package not found on NuGet: {package_id}")
return []
response.raise_for_status()
data = response.json()
versions = []
for page in data.get("items", []):
# Pages may be inline or require fetching
if "items" in page:
items = page["items"]
else:
# Fetch the page
page_url = page.get("@id")
if page_url:
page_response = self._session.get(page_url, timeout=30)
page_response.raise_for_status()
items = page_response.json().get("items", [])
else:
items = []
for item in items:
catalog_entry = item.get("catalogEntry", {})
version = catalog_entry.get("version")
if version:
versions.append(version)
# Sort by parsed version, newest first
def sort_key(v: str):
    parsed = parse_version(v)
    if parsed is None:
        # Unparseable versions sort lowest; returning a ParsedVersion keeps every
        # sort key comparable (mixing plain tuples with ParsedVersion would fail).
        return parse_version("0.0.0")
    return parsed
versions.sort(key=sort_key, reverse=True)
return versions
except NuGetApiError:
raise
except Exception as e:
logger.warning(f"Failed to fetch versions for {package_id}: {e}")
return []
def get_vulnerability_data(self) -> dict[str, list[dict]]:
"""
Fetch vulnerability data from NuGet VulnerabilityInfo API.
Returns:
Dictionary mapping lowercase package ID to list of vulnerability info dicts.
Each dict contains: severity, advisory_url, versions (affected range)
"""
if self._vuln_cache is not None:
return self._vuln_cache
try:
# Fetch vulnerability index
response = self._session.get(NUGET_VULN_INDEX, timeout=30)
response.raise_for_status()
index = response.json()
vuln_map: dict[str, list[dict]] = {}
# Fetch each vulnerability page
for page_info in index:
page_url = page_info.get("@id")
if not page_url:
continue
try:
page_response = self._session.get(page_url, timeout=60)
page_response.raise_for_status()
page_data = page_response.json()
# Parse vulnerability entries
self._merge_vuln_data(vuln_map, page_data)
except Exception as e:
logger.warning(f"Failed to fetch vulnerability page {page_url}: {e}")
continue
self._vuln_cache = vuln_map
logger.info(f"Loaded vulnerability data for {len(vuln_map)} packages")
return vuln_map
except Exception as e:
logger.warning(f"Failed to fetch vulnerability data: {e}")
return {}
def _merge_vuln_data(
self, vuln_map: dict[str, list[dict]], page_data: Any
) -> None:
"""Merge vulnerability page data into the vulnerability map."""
# The vulnerability data format is a dict mapping package ID (lowercase)
# to list of vulnerability objects
if not isinstance(page_data, dict):
return
for package_id, vulns in page_data.items():
if package_id.startswith("@"):
# Skip metadata fields like @context
continue
package_lower = package_id.lower()
if package_lower not in vuln_map:
vuln_map[package_lower] = []
if isinstance(vulns, list):
vuln_map[package_lower].extend(vulns)
def is_version_vulnerable(
self, package_id: str, version: str, vuln_data: dict[str, list[dict]] | None = None
) -> tuple[bool, list[dict]]:
"""
Check if a specific package version is vulnerable.
Args:
package_id: The package ID
version: The version to check
vuln_data: Optional pre-fetched vulnerability data
Returns:
Tuple of (is_vulnerable, list of matching vulnerabilities)
"""
if vuln_data is None:
vuln_data = self.get_vulnerability_data()
package_lower = package_id.lower()
vulns = vuln_data.get(package_lower, [])
if not vulns:
return False, []
matching = []
parsed_version = parse_version(version)
if parsed_version is None:
return False, []
for vuln in vulns:
# Check version range
version_range = vuln.get("versions", "")
if self._version_in_range(version, parsed_version, version_range):
matching.append(vuln)
return len(matching) > 0, matching
def _version_in_range(
self, version: str, parsed: tuple, range_str: str
) -> bool:
"""
Check if a version is in a NuGet version range.
NuGet range formats:
- "[1.0.0, 2.0.0)" - >= 1.0.0 and < 2.0.0
- "(, 1.0.0)" - < 1.0.0
- "[1.0.0,)" - >= 1.0.0
- "1.0.0" - exact match
"""
if not range_str:
return False
range_str = range_str.strip()
# Handle exact version
if not range_str.startswith(("[", "(")):
exact_parsed = parse_version(range_str)
return exact_parsed == parsed if exact_parsed else False
# Parse range
match = re.match(r"([\[\(])([^,]*),([^)\]]*)([\)\]])", range_str)
if not match:
return False
left_bracket, left_ver, right_ver, right_bracket = match.groups()
left_ver = left_ver.strip()
right_ver = right_ver.strip()
# Check lower bound
if left_ver:
left_parsed = parse_version(left_ver)
if left_parsed:
if left_bracket == "[":
if parsed < left_parsed:
return False
else: # "("
if parsed <= left_parsed:
return False
# Check upper bound
if right_ver:
right_parsed = parse_version(right_ver)
if right_parsed:
if right_bracket == "]":
if parsed > right_parsed:
return False
else: # ")"
if parsed >= right_parsed:
return False
return True
def find_safe_version(
self,
package_id: str,
current_version: str,
prefer_upgrade: bool = True,
) -> str | None:
"""
Find the closest non-vulnerable version.
Strategy:
1. Get all available versions
2. Filter out versions with known vulnerabilities
3. Prefer: patch upgrade > minor upgrade > major upgrade > downgrade
Args:
package_id: The package ID
current_version: Current (vulnerable) version
prefer_upgrade: If True, prefer upgrades over downgrades
Returns:
Suggested safe version, or None if not found
"""
available = self.get_available_versions(package_id)
if not available:
return None
vuln_data = self.get_vulnerability_data()
current_parsed = parse_version(current_version)
if current_parsed is None:
return None
# Find safe versions
from .version_utils import ParsedVersion
safe_versions: list[tuple[str, ParsedVersion]] = []
for version in available:
# Skip prereleases unless current is prerelease
if not is_stable(version) and is_stable(current_version):
continue
parsed = parse_version(version)
if parsed is None:
continue
is_vuln, _ = self.is_version_vulnerable(package_id, version, vuln_data)
if not is_vuln:
safe_versions.append((version, parsed))
if not safe_versions:
return None
# Sort by preference: closest upgrade first
def sort_key(item: tuple[str, ParsedVersion]) -> tuple:
version, parsed = item
major_diff = parsed.major - current_parsed.major
minor_diff = parsed.minor - current_parsed.minor
patch_diff = parsed.patch - current_parsed.patch
# Prefer upgrades (positive diff) over downgrades
# Within upgrades, prefer smaller changes
if prefer_upgrade:
if major_diff > 0 or (major_diff == 0 and minor_diff > 0) or \
(major_diff == 0 and minor_diff == 0 and patch_diff > 0):
# Upgrade: prefer smaller version jumps
return (0, major_diff, minor_diff, patch_diff)
elif major_diff == 0 and minor_diff == 0 and patch_diff == 0:
# Same version (shouldn't happen if vulnerable)
return (1, 0, 0, 0)
else:
# Downgrade: prefer smaller version drops
return (2, -major_diff, -minor_diff, -patch_diff)
else:
# Just prefer closest version
return (abs(major_diff), abs(minor_diff), abs(patch_diff))
safe_versions.sort(key=sort_key)
return safe_versions[0][0] if safe_versions else None
def get_fix_risk(
self, current_version: str, suggested_version: str
) -> str:
"""
Estimate the risk of upgrading to a suggested version.
Returns: "low", "medium", or "high"
"""
current = parse_version(current_version)
suggested = parse_version(suggested_version)
if current is None or suggested is None:
return "unknown"
if suggested.major > current.major:
return "high" # Major version change
elif suggested.minor > current.minor:
return "medium" # Minor version change
else:
return "low" # Patch or no change


@@ -0,0 +1,381 @@
"""
Solution file (.sln) generation utilities.
Provides functions to:
- Build solution folder hierarchy from projects
- Generate complete .sln file content
"""
import logging
from pathlib import Path
from typing import Optional
from .csproj_parser import get_deterministic_guid
from .models import (
BYPASS_MARKER,
CSHARP_PROJECT_TYPE_GUID,
SOLUTION_FOLDER_TYPE_GUID,
CsprojProject,
SolutionFolder,
)
logger = logging.getLogger(__name__)
# Solution file header
SOLUTION_HEADER = """\
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.0.31903.59
MinimumVisualStudioVersion = 10.0.40219.1
"""
def build_folder_hierarchy(
projects: list[CsprojProject],
base_dir: Path,
prefix: str = "",
) -> dict[str, SolutionFolder]:
"""
Build solution folder hierarchy from project paths.
Creates nested folders matching the physical directory structure.
Args:
projects: List of projects to organize
base_dir: Base directory for calculating relative paths
prefix: Optional prefix for folder paths (e.g., "__External")
Returns:
Dictionary mapping folder path to SolutionFolder object
"""
folders: dict[str, SolutionFolder] = {}
base_dir = base_dir.resolve()
for project in projects:
try:
rel_path = project.path.parent.relative_to(base_dir)
except ValueError:
# Project outside base_dir - skip folder creation
continue
parts = list(rel_path.parts)
if not parts:
continue
# Add prefix if specified
if prefix:
parts = [prefix] + list(parts)
# Create folders for each level
current_path = ""
parent_guid: Optional[str] = None
for part in parts:
if current_path:
current_path = f"{current_path}/{part}"
else:
current_path = part
if current_path not in folders:
folder_guid = get_deterministic_guid(
Path(current_path), Path("")
)
folders[current_path] = SolutionFolder(
name=part,
guid=folder_guid,
path=current_path,
parent_guid=parent_guid,
)
parent_guid = folders[current_path].guid
# Assign project to its folder
if current_path in folders:
folders[current_path].projects.append(project)
return folders
def build_external_folder_hierarchy(
external_groups: dict[str, list[CsprojProject]],
src_root: Path,
) -> dict[str, SolutionFolder]:
"""
Build folder hierarchy for external dependencies.
Organizes external projects under __External/<Source>/<Path>.
Args:
external_groups: Dictionary mapping source category to projects
src_root: Root of the src/ directory
Returns:
Dictionary mapping folder path to SolutionFolder object
"""
folders: dict[str, SolutionFolder] = {}
src_root = src_root.resolve()
# Create __External root folder
external_root_path = "__External"
external_root_guid = get_deterministic_guid(Path(external_root_path), Path(""))
folders[external_root_path] = SolutionFolder(
name="__External",
guid=external_root_guid,
path=external_root_path,
parent_guid=None,
)
for category, projects in sorted(external_groups.items()):
if not projects:
continue
# Create category folder (e.g., __External/__Libraries, __External/Authority)
category_path = f"{external_root_path}/{category}"
category_guid = get_deterministic_guid(Path(category_path), Path(""))
if category_path not in folders:
folders[category_path] = SolutionFolder(
name=category,
guid=category_guid,
path=category_path,
parent_guid=external_root_guid,
)
# For each project, create intermediate folders based on path within source
for project in projects:
try:
if category == "__Libraries":
# Path relative to src/__Libraries/
lib_root = src_root / "__Libraries"
rel_path = project.path.parent.relative_to(lib_root)
else:
# Path relative to src/<Module>/
module_root = src_root / category
rel_path = project.path.parent.relative_to(module_root)
except ValueError:
# Just put directly in category folder
folders[category_path].projects.append(project)
continue
parts = list(rel_path.parts)
if not parts:
folders[category_path].projects.append(project)
continue
# Create intermediate folders
current_path = category_path
parent_guid = category_guid
for part in parts:
current_path = f"{current_path}/{part}"
if current_path not in folders:
folder_guid = get_deterministic_guid(Path(current_path), Path(""))
folders[current_path] = SolutionFolder(
name=part,
guid=folder_guid,
path=current_path,
parent_guid=parent_guid,
)
parent_guid = folders[current_path].guid
# Assign project to deepest folder
folders[current_path].projects.append(project)
return folders
def generate_solution_content(
sln_path: Path,
projects: list[CsprojProject],
folders: dict[str, SolutionFolder],
external_folders: Optional[dict[str, SolutionFolder]] = None,
add_bypass_marker: bool = False,
) -> str:
"""
Generate complete .sln file content.
Args:
sln_path: Path where the solution will be written (for relative paths)
projects: List of internal projects
folders: Internal folder hierarchy
external_folders: External dependency folders (optional)
add_bypass_marker: Whether to add the bypass marker comment
Returns:
Complete .sln file content as string
"""
lines: list[str] = []
sln_dir = sln_path.parent.resolve()
# Add header
if add_bypass_marker:
lines.append(BYPASS_MARKER)
lines.append("")
lines.append(SOLUTION_HEADER.rstrip())
# Merge folders
all_folders = dict(folders)
if external_folders:
all_folders.update(external_folders)
# Collect all projects (internal + external from folders)
all_projects: list[CsprojProject] = list(projects)
project_to_folder: dict[Path, str] = {}
for folder_path, folder in all_folders.items():
for proj in folder.projects:
if proj not in all_projects:
all_projects.append(proj)
project_to_folder[proj.path] = folder_path
# Write solution folder entries
for folder_path in sorted(all_folders.keys()):
folder = all_folders[folder_path]
lines.append(
f'Project("{{{SOLUTION_FOLDER_TYPE_GUID}}}") = "{folder.name}", "{folder.name}", "{{{folder.guid}}}"'
)
lines.append("EndProject")
# Write project entries
for project in sorted(all_projects, key=lambda p: p.name):
rel_path = _get_relative_path(sln_dir, project.path)
lines.append(
f'Project("{{{CSHARP_PROJECT_TYPE_GUID}}}") = "{project.name}", "{rel_path}", "{{{project.guid}}}"'
)
lines.append("EndProject")
# Write Global section
lines.append("Global")
# SolutionConfigurationPlatforms
lines.append("\tGlobalSection(SolutionConfigurationPlatforms) = preSolution")
lines.append("\t\tDebug|Any CPU = Debug|Any CPU")
lines.append("\t\tRelease|Any CPU = Release|Any CPU")
lines.append("\tEndGlobalSection")
# ProjectConfigurationPlatforms
lines.append("\tGlobalSection(ProjectConfigurationPlatforms) = postSolution")
for project in sorted(all_projects, key=lambda p: p.name):
guid = project.guid
lines.append(f"\t\t{{{guid}}}.Debug|Any CPU.ActiveCfg = Debug|Any CPU")
lines.append(f"\t\t{{{guid}}}.Debug|Any CPU.Build.0 = Debug|Any CPU")
lines.append(f"\t\t{{{guid}}}.Release|Any CPU.ActiveCfg = Release|Any CPU")
lines.append(f"\t\t{{{guid}}}.Release|Any CPU.Build.0 = Release|Any CPU")
lines.append("\tEndGlobalSection")
# SolutionProperties
lines.append("\tGlobalSection(SolutionProperties) = preSolution")
lines.append("\t\tHideSolutionNode = FALSE")
lines.append("\tEndGlobalSection")
# NestedProjects - assign folders and projects to parent folders
lines.append("\tGlobalSection(NestedProjects) = preSolution")
# Nest folders under their parents
for folder_path in sorted(all_folders.keys()):
folder = all_folders[folder_path]
if folder.parent_guid:
lines.append(f"\t\t{{{folder.guid}}} = {{{folder.parent_guid}}}")
# Nest projects under their folders
for project in sorted(all_projects, key=lambda p: p.name):
if project.path in project_to_folder:
folder_path = project_to_folder[project.path]
folder = all_folders[folder_path]
lines.append(f"\t\t{{{project.guid}}} = {{{folder.guid}}}")
lines.append("\tEndGlobalSection")
# ExtensibilityGlobals (required by VS)
lines.append("\tGlobalSection(ExtensibilityGlobals) = postSolution")
# Generate a solution GUID
sln_guid = get_deterministic_guid(sln_path, sln_path.parent.parent)
lines.append(f"\t\tSolutionGuid = {{{sln_guid}}}")
lines.append("\tEndGlobalSection")
lines.append("EndGlobal")
lines.append("") # Trailing newline
return "\r\n".join(lines)
def _get_relative_path(from_dir: Path, to_path: Path) -> str:
"""
Get relative path from directory to file, using backslashes.
Args:
from_dir: Directory to calculate from
to_path: Target path
Returns:
Relative path with backslashes (Windows format for .sln)
"""
try:
rel = to_path.relative_to(from_dir)
return str(rel).replace("/", "\\")
except ValueError:
# Different drive or not relative - use absolute with backslashes
return str(to_path).replace("/", "\\")
def has_bypass_marker(sln_path: Path) -> bool:
"""
Check if a solution file has the bypass marker.
Args:
sln_path: Path to the solution file
Returns:
True if the bypass marker is found in the first 10 lines
"""
if not sln_path.exists():
return False
try:
with open(sln_path, "r", encoding="utf-8-sig") as f:
for i, line in enumerate(f):
if i >= 10:
break
if BYPASS_MARKER in line:
return True
except Exception as e:
logger.warning(f"Failed to read solution file {sln_path}: {e}")
return False
def write_solution_file(
sln_path: Path,
content: str,
dry_run: bool = False,
) -> bool:
"""
Write solution content to file.
Args:
sln_path: Path to write to
content: Solution file content
dry_run: If True, don't actually write
Returns:
True if successful (or would be successful in dry run)
"""
if dry_run:
logger.info(f"Would write solution to: {sln_path}")
return True
try:
# Ensure parent directory exists
sln_path.parent.mkdir(parents=True, exist_ok=True)
# Write with UTF-8 BOM and CRLF line endings
with open(sln_path, "w", encoding="utf-8-sig", newline="\r\n") as f:
f.write(content)
logger.info(f"Wrote solution to: {sln_path}")
return True
except Exception as e:
logger.error(f"Failed to write solution {sln_path}: {e}")
return False
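A short sketch (not part of the diff above) wiring the writer to the parser; the module directory, solution name, and import paths are assumptions:

from pathlib import Path
from lib.csproj_parser import find_all_csproj, parse_csproj
from lib.sln_writer import build_folder_hierarchy, generate_solution_content, write_solution_file

module_dir = Path("src/Scanner").resolve()  # hypothetical module directory
projects = [p for p in (parse_csproj(c, module_dir) for c in find_all_csproj(module_dir)) if p]
folders = build_folder_hierarchy(projects, module_dir)
sln_path = module_dir / "Scanner.sln"
content = generate_solution_content(sln_path, projects, folders)
write_solution_file(sln_path, content, dry_run=True)  # dry_run=True: log only, write nothing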


@@ -0,0 +1,237 @@
"""
Version parsing and comparison utilities for NuGet packages.
Handles SemVer versions with prerelease suffixes.
"""
import re
from dataclasses import dataclass
from typing import Optional
@dataclass(frozen=True)
class ParsedVersion:
"""Parsed semantic version."""
major: int
minor: int
patch: int
prerelease: Optional[str] = None
build_metadata: Optional[str] = None
original: str = ""
def is_stable(self) -> bool:
"""Returns True if this is a stable (non-prerelease) version."""
return self.prerelease is None
def __lt__(self, other: "ParsedVersion") -> bool:
"""Compare versions following SemVer rules."""
# Compare major.minor.patch first
self_tuple = (self.major, self.minor, self.patch)
other_tuple = (other.major, other.minor, other.patch)
if self_tuple != other_tuple:
return self_tuple < other_tuple
# If equal, prerelease versions are less than stable
if self.prerelease is None and other.prerelease is None:
return False
if self.prerelease is None:
return False # stable > prerelease
if other.prerelease is None:
return True # prerelease < stable
# Both have prerelease - compare alphanumerically
return self._compare_prerelease(self.prerelease, other.prerelease) < 0
def __le__(self, other: "ParsedVersion") -> bool:
return self == other or self < other
def __gt__(self, other: "ParsedVersion") -> bool:
return other < self
def __ge__(self, other: "ParsedVersion") -> bool:
return self == other or self > other
@staticmethod
def _compare_prerelease(a: str, b: str) -> int:
"""Compare prerelease strings according to SemVer."""
a_parts = a.split(".")
b_parts = b.split(".")
for i in range(max(len(a_parts), len(b_parts))):
if i >= len(a_parts):
return -1
if i >= len(b_parts):
return 1
a_part = a_parts[i]
b_part = b_parts[i]
# Try numeric comparison first
a_is_num = a_part.isdigit()
b_is_num = b_part.isdigit()
if a_is_num and b_is_num:
diff = int(a_part) - int(b_part)
if diff != 0:
return diff
elif a_is_num:
return -1 # Numeric < string
elif b_is_num:
return 1 # String > numeric
else:
# Both strings - compare lexically
if a_part < b_part:
return -1
if a_part > b_part:
return 1
return 0
# Regex for parsing NuGet versions
# Matches: 1.2.3, 1.2.3-beta, 1.2.3-beta.1, 1.2.3-rc.1+build, [1.2.3]
VERSION_PATTERN = re.compile(
r"^\[?" # Optional opening bracket
r"(\d+)" # Major (required)
r"(?:\.(\d+))?" # Minor (optional)
r"(?:\.(\d+))?" # Patch (optional)
r"(?:-([a-zA-Z0-9][a-zA-Z0-9.-]*))?" # Prerelease (optional)
r"(?:\+([a-zA-Z0-9][a-zA-Z0-9.-]*))?" # Build metadata (optional)
r"\]?$" # Optional closing bracket
)
# Pattern for wildcard versions (e.g., 1.0.*)
WILDCARD_PATTERN = re.compile(r"\*")
def parse_version(version_str: str) -> Optional[ParsedVersion]:
"""
Parse a NuGet version string.
Args:
version_str: Version string like "1.2.3", "1.2.3-beta.1", "[1.2.3]"
Returns:
ParsedVersion if valid, None if invalid or wildcard
"""
if not version_str:
return None
version_str = version_str.strip()
# Skip wildcard versions
if WILDCARD_PATTERN.search(version_str):
return None
match = VERSION_PATTERN.match(version_str)
if not match:
return None
major = int(match.group(1))
minor = int(match.group(2)) if match.group(2) else 0
patch = int(match.group(3)) if match.group(3) else 0
prerelease = match.group(4)
build_metadata = match.group(5)
return ParsedVersion(
major=major,
minor=minor,
patch=patch,
prerelease=prerelease,
build_metadata=build_metadata,
original=version_str,
)
def is_stable(version_str: str) -> bool:
"""
Check if a version string represents a stable release.
Args:
version_str: Version string to check
Returns:
True if stable (no prerelease suffix), False otherwise
"""
parsed = parse_version(version_str)
if parsed is None:
return False
return parsed.is_stable()
def compare_versions(v1: str, v2: str) -> int:
"""
Compare two version strings.
Args:
v1: First version string
v2: Second version string
Returns:
-1 if v1 < v2, 0 if equal, 1 if v1 > v2
Returns 0 if either version is unparseable
"""
parsed_v1 = parse_version(v1)
parsed_v2 = parse_version(v2)
if parsed_v1 is None or parsed_v2 is None:
return 0
if parsed_v1 < parsed_v2:
return -1
if parsed_v1 > parsed_v2:
return 1
return 0
def select_latest_stable(versions: list[str]) -> Optional[str]:
"""
Select the latest stable version from a list.
Args:
versions: List of version strings
Returns:
Latest stable version string, or None if no stable versions exist
"""
stable_versions: list[tuple[ParsedVersion, str]] = []
for v in versions:
parsed = parse_version(v)
if parsed is not None and parsed.is_stable():
stable_versions.append((parsed, v))
if not stable_versions:
return None
# Sort by parsed version and return the original string of the max
stable_versions.sort(key=lambda x: x[0], reverse=True)
return stable_versions[0][1]
def normalize_version_string(version_str: str) -> str:
"""
Normalize a version string to a canonical form.
Strips brackets, whitespace, and normalizes format.
Args:
version_str: Version string to normalize
Returns:
Normalized version string
"""
parsed = parse_version(version_str)
if parsed is None:
return version_str.strip()
# Rebuild canonical form
result = f"{parsed.major}.{parsed.minor}.{parsed.patch}"
if parsed.prerelease:
result += f"-{parsed.prerelease}"
if parsed.build_metadata:
result += f"+{parsed.build_metadata}"
return result
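A short worked sketch (not part of the diff above) of the version helpers, assuming the module is importable as lib.version_utils:

from lib.version_utils import compare_versions, is_stable, parse_version, select_latest_stable

print(compare_versions("1.2.3-beta.2", "1.2.3"))               # -1: a prerelease sorts below its release
print(is_stable("8.0.0"), is_stable("8.0.0-preview.1"))        # True False
print(select_latest_stable(["8.0.0", "9.0.0-rc.1", "7.0.5"]))  # 8.0.0 (prereleases are ignored)
print(parse_version("[1.2.3]").original)                        # [1.2.3] (NuGet-style brackets accepted)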


@@ -0,0 +1,123 @@
"""
Data models for NuGet vulnerability checking.
"""
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class VulnerabilityDetail:
"""Details about a specific vulnerability."""
severity: str # low, moderate, high, critical
advisory_url: str
@dataclass
class VulnerablePackage:
"""A package with known vulnerabilities."""
package_id: str
resolved_version: str
requested_version: str
vulnerabilities: list[VulnerabilityDetail] = field(default_factory=list)
affected_projects: list[Path] = field(default_factory=list)
suggested_version: str | None = None
fix_risk: str = "unknown" # low, medium, high
@property
def highest_severity(self) -> str:
"""Get the highest severity among all vulnerabilities."""
severity_order = {"low": 1, "moderate": 2, "high": 3, "critical": 4}
if not self.vulnerabilities:
return "unknown"
return max(
self.vulnerabilities,
key=lambda v: severity_order.get(v.severity.lower(), 0),
).severity
@property
def advisory_urls(self) -> list[str]:
"""Get all advisory URLs."""
return [v.advisory_url for v in self.vulnerabilities]
@dataclass
class SuggestedFix:
"""Suggested fix for a vulnerable package."""
version: str
is_major_upgrade: bool
is_minor_upgrade: bool
is_patch_upgrade: bool
breaking_change_risk: str # low, medium, high
@classmethod
def from_versions(
cls, current: str, suggested: str, current_parsed: tuple, suggested_parsed: tuple
) -> "SuggestedFix":
"""Create a SuggestedFix from version tuples."""
is_major = suggested_parsed[0] > current_parsed[0]
is_minor = not is_major and suggested_parsed[1] > current_parsed[1]
is_patch = not is_major and not is_minor and suggested_parsed[2] > current_parsed[2]
# Estimate breaking change risk
if is_major:
risk = "high"
elif is_minor:
risk = "medium"
else:
risk = "low"
return cls(
version=suggested,
is_major_upgrade=is_major,
is_minor_upgrade=is_minor,
is_patch_upgrade=is_patch,
breaking_change_risk=risk,
)
@dataclass
class VulnerabilityReport:
"""Complete vulnerability scan report."""
solution: Path
min_severity: str
total_packages: int
vulnerabilities: list[VulnerablePackage] = field(default_factory=list)
unfixable: list[tuple[str, str]] = field(default_factory=list) # (package, reason)
@property
def vulnerable_count(self) -> int:
"""Count of vulnerable packages."""
return len(self.vulnerabilities)
@property
def fixable_count(self) -> int:
"""Count of packages with suggested fixes."""
return sum(1 for v in self.vulnerabilities if v.suggested_version)
@property
def unfixable_count(self) -> int:
"""Count of packages without fixes."""
return len(self.unfixable) + sum(
1 for v in self.vulnerabilities if not v.suggested_version
)
# Severity level mapping for comparisons
SEVERITY_LEVELS = {
"low": 1,
"moderate": 2,
"high": 3,
"critical": 4,
}
def meets_severity_threshold(vuln_severity: str, min_severity: str) -> bool:
"""Check if vulnerability meets minimum severity threshold."""
vuln_level = SEVERITY_LEVELS.get(vuln_severity.lower(), 0)
min_level = SEVERITY_LEVELS.get(min_severity.lower(), 0)
return vuln_level >= min_level
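A short sketch (not part of the diff above) of the severity helpers; the import path and the advisory data are illustrative assumptions:

from models import VulnerabilityDetail, VulnerablePackage, meets_severity_threshold  # assumed import path

pkg = VulnerablePackage(
    package_id="Example.Package",
    resolved_version="1.0.0",
    requested_version="1.0.0",
    vulnerabilities=[
        VulnerabilityDetail(severity="moderate", advisory_url="https://example.invalid/advisory-1"),
        VulnerabilityDetail(severity="critical", advisory_url="https://example.invalid/advisory-2"),
    ],
)
print(pkg.highest_severity)                                     # critical
print(meets_severity_threshold(pkg.highest_severity, "high"))   # True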


@@ -0,0 +1,648 @@
#!/usr/bin/env python3
"""
StellaOps NuGet Centralization Tool.
Centralizes NuGet package versions to src/Directory.Build.props for packages
used in multiple projects, and removes version attributes from individual .csproj files.
This is the REVERSE of nuget_normalizer.py:
- nuget_normalizer: keeps versions in csproj files, normalizes to latest stable
- nuget_centralizer: moves shared packages to Directory.Build.props, removes versions from csproj
Usage:
python nuget_centralizer.py [OPTIONS]
Options:
--src-root PATH Root of src/ directory (default: ./src)
--dry-run Report without making changes
--report PATH Write JSON report to file
--exclude PACKAGE Exclude package from centralization (repeatable)
--min-usage N Minimum number of projects using a package to centralize it (default: 2)
--check CI mode: exit 1 if centralization needed
-v, --verbose Verbose output
"""
import argparse
import json
import logging
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Set, Tuple
from lib.csproj_parser import find_all_csproj
from lib.models import PackageUsage
from lib.version_utils import select_latest_stable, parse_version
logger = logging.getLogger(__name__)
def setup_logging(verbose: bool) -> None:
"""Configure logging based on verbosity."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(levelname)s: %(message)s",
)
def scan_all_packages(src_root: Path) -> Dict[str, PackageUsage]:
"""
Scan all .csproj files and collect package references.
Args:
src_root: Root of src/ directory
Returns:
Dictionary mapping package name to PackageUsage
"""
packages: Dict[str, PackageUsage] = {}
csproj_files = find_all_csproj(src_root)
logger.info(f"Scanning {len(csproj_files)} .csproj files for package references")
# Regex for PackageReference with Version
package_ref_pattern = re.compile(
r'<PackageReference\s+[^>]*Include\s*=\s*"([^"]+)"[^>]*Version\s*=\s*"([^"]+)"',
re.IGNORECASE,
)
# Alternative pattern for when Version comes first
package_ref_pattern_alt = re.compile(
r'<PackageReference\s+[^>]*Version\s*=\s*"([^"]+)"[^>]*Include\s*=\s*"([^"]+)"',
re.IGNORECASE,
)
for csproj_path in csproj_files:
try:
content = csproj_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning(f"Failed to read {csproj_path}: {e}")
continue
# Find all PackageReference elements with versions
for match in package_ref_pattern.finditer(content):
package_name = match.group(1)
version = match.group(2)
if package_name not in packages:
packages[package_name] = PackageUsage(package_name=package_name)
packages[package_name].usages[csproj_path] = version
# Also try alternative pattern
for match in package_ref_pattern_alt.finditer(content):
version = match.group(1)
package_name = match.group(2)
if package_name not in packages:
packages[package_name] = PackageUsage(package_name=package_name)
packages[package_name].usages[csproj_path] = version
logger.info(f"Found {len(packages)} unique packages")
return packages
def find_packages_to_centralize(
packages: Dict[str, PackageUsage],
exclude_packages: Set[str],
min_usage: int = 2,
) -> Dict[str, Tuple[str, List[Path]]]:
"""
Find packages that should be centralized.
A package is centralized if:
- It is used in at least min_usage projects
- It is not in the exclude list
- It has at least one parseable version (a prerelease is used, with a warning, when no stable version exists)
Args:
packages: Package usage data
exclude_packages: Package names to exclude
min_usage: Minimum number of projects using a package
Returns:
Dictionary mapping package name to (target_version, list of csproj paths)
"""
to_centralize: Dict[str, Tuple[str, List[Path]]] = {}
for package_name, usage in sorted(packages.items()):
# Skip excluded packages
if package_name in exclude_packages:
logger.debug(f"Excluding package: {package_name}")
continue
# Check if used in enough projects
if len(usage.usages) < min_usage:
logger.debug(f"Skipping {package_name}: only used in {len(usage.usages)} project(s)")
continue
# Get all versions and find latest stable
versions = usage.get_all_versions()
parseable_versions = [v for v in versions if parse_version(v) is not None]
if not parseable_versions:
logger.warning(f"Skipping {package_name}: no parseable versions")
continue
target_version = select_latest_stable(parseable_versions)
if target_version is None:
# Try to find any version (including prereleases)
parsed = [
(parse_version(v), v)
for v in parseable_versions
if parse_version(v) is not None
]
if parsed:
parsed.sort(key=lambda x: x[0], reverse=True)
target_version = parsed[0][1]
logger.warning(
f"Package {package_name}: using prerelease version {target_version} "
"(no stable version found)"
)
else:
logger.warning(f"Skipping {package_name}: no valid versions found")
continue
# Add to centralization list
csproj_paths = list(usage.usages.keys())
to_centralize[package_name] = (target_version, csproj_paths)
logger.info(
f"Will centralize {package_name} v{target_version} "
f"(used in {len(csproj_paths)} projects)"
)
return to_centralize
def read_directory_build_props(props_path: Path) -> str:
"""
Read Directory.Build.props file.
Args:
props_path: Path to Directory.Build.props
Returns:
File content as string
"""
if props_path.exists():
return props_path.read_text(encoding="utf-8")
else:
# Create minimal Directory.Build.props
return """<Project>
<PropertyGroup>
<!-- Centralize NuGet package cache to prevent directory sprawl -->
<RestorePackagesPath>$(MSBuildThisFileDirectory)../.nuget/packages</RestorePackagesPath>
<DisableImplicitNuGetFallbackFolder>true</DisableImplicitNuGetFallbackFolder>
<!-- Disable NuGet audit to prevent build failures when mirrors are unreachable -->
<NuGetAudit>false</NuGetAudit>
<WarningsNotAsErrors>$(WarningsNotAsErrors);NU1900;NU1901;NU1902;NU1903;NU1904</WarningsNotAsErrors>
</PropertyGroup>
<!-- Centralized NuGet package versions -->
<ItemGroup>
</ItemGroup>
</Project>
"""
def add_package_to_directory_props(
content: str,
package_name: str,
version: str,
) -> str:
"""
Add or update a PackageReference in Directory.Build.props.
Args:
content: Current content of Directory.Build.props
package_name: Package name
version: Package version
Returns:
Updated content
"""
# Check if package already exists
existing_pattern = re.compile(
rf'<PackageReference\s+Update\s*=\s*"{re.escape(package_name)}"[^>]*Version\s*=\s*"[^"]+"[^>]*/?>',
re.IGNORECASE,
)
if existing_pattern.search(content):
# Update existing entry
def replacer(match):
# Preserve the format of the existing entry
return f'<PackageReference Update="{package_name}" Version="{version}" />'
content = existing_pattern.sub(replacer, content)
logger.debug(f"Updated existing entry for {package_name}")
else:
# Find or create the centralized packages ItemGroup
# Look for the comment marker first
itemgroup_pattern = re.compile(
r'(<!-- Centralized NuGet package versions -->\s*<ItemGroup>)',
re.IGNORECASE,
)
if itemgroup_pattern.search(content):
# Add to existing centralized ItemGroup
new_entry = f'\n <PackageReference Update="{package_name}" Version="{version}" />'
content = itemgroup_pattern.sub(rf'\1{new_entry}', content)
logger.debug(f"Added {package_name} to centralized ItemGroup")
else:
# Create new centralized ItemGroup section before </Project>
project_end = re.compile(r'(\s*</Project>)', re.IGNORECASE)
new_section = f'''
<!-- Centralized NuGet package versions -->
<ItemGroup>
<PackageReference Update="{package_name}" Version="{version}" />
</ItemGroup>
'''
content = project_end.sub(rf'{new_section}\1', content)
logger.debug(f"Created centralized ItemGroup with {package_name}")
return content
def update_directory_build_props(
props_path: Path,
packages_to_centralize: Dict[str, Tuple[str, List[Path]]],
dry_run: bool = False,
) -> bool:
"""
Update Directory.Build.props with centralized package versions.
Args:
props_path: Path to Directory.Build.props
packages_to_centralize: Packages to add
dry_run: If True, don't write files
Returns:
True if successful
"""
if not packages_to_centralize:
logger.info("No packages to centralize")
return True
content = read_directory_build_props(props_path)
# Add or update each package
for package_name, (version, _) in sorted(packages_to_centralize.items()):
content = add_package_to_directory_props(content, package_name, version)
# Sort the PackageReference entries alphabetically
content = sort_package_references(content)
if dry_run:
logger.info(f"Would update {props_path}")
return True
try:
props_path.write_text(content, encoding="utf-8")
logger.info(f"Updated {props_path} with {len(packages_to_centralize)} centralized packages")
return True
except Exception as e:
logger.error(f"Failed to write {props_path}: {e}")
return False
def sort_package_references(content: str) -> str:
"""
Sort PackageReference Update entries alphabetically.
Args:
content: XML content
Returns:
Content with sorted PackageReference entries
"""
# Find the centralized packages ItemGroup section
itemgroup_pattern = re.compile(
r'(<!-- Centralized NuGet package versions -->\s*<ItemGroup>)(.*?)(</ItemGroup>)',
re.IGNORECASE | re.DOTALL,
)
match = itemgroup_pattern.search(content)
if not match:
# No centralized section found, return as-is
return content
prefix = match.group(1)
itemgroup_content = match.group(2)
suffix = match.group(3)
# Find all PackageReference Update entries in this section
package_pattern = re.compile(
r'<PackageReference\s+Update\s*=\s*"([^"]+)"[^>]*Version\s*=\s*"([^"]+)"[^>]*/?>',
re.IGNORECASE,
)
packages = package_pattern.findall(itemgroup_content)
if not packages:
# No packages found, return as-is
return content
# Sort by package name
sorted_packages = sorted(packages, key=lambda x: x[0].lower())
# Rebuild the ItemGroup with sorted entries
sorted_entries = '\n'.join(
f' <PackageReference Update="{pkg}" Version="{ver}" />'
for pkg, ver in sorted_packages
)
new_itemgroup = f'{prefix}\n{sorted_entries}\n {suffix}'
# Replace the old ItemGroup with the sorted one
content = itemgroup_pattern.sub(new_itemgroup, content)
return content
def remove_version_from_csproj(
csproj_path: Path,
package_name: str,
dry_run: bool = False,
) -> bool:
"""
Remove Version attribute from a PackageReference in a .csproj file.
Args:
csproj_path: Path to .csproj file
package_name: Package name
dry_run: If True, don't write files
Returns:
True if successful
"""
try:
content = csproj_path.read_text(encoding="utf-8")
except Exception as e:
logger.error(f"Failed to read {csproj_path}: {e}")
return False
# Pattern to match PackageReference with Version attribute and remove it
# This pattern captures the opening tag, Include attribute, and any other attributes
# Then removes the Version="..." attribute while preserving others
# Pattern 1: <PackageReference Include="..." Version="..." ... />
pattern1 = re.compile(
rf'(<PackageReference\s+Include\s*=\s*"{re.escape(package_name)}")\s+Version\s*=\s*"[^"]+"\s*([^/>]*/>)',
re.IGNORECASE,
)
# Pattern 2: <PackageReference Include="..." attr="..." Version="..." ... />
pattern2 = re.compile(
rf'(<PackageReference\s+Include\s*=\s*"{re.escape(package_name)}"[^>]*?)\s+Version\s*=\s*"[^"]+"\s*([^/>]*/>)',
re.IGNORECASE,
)
# Pattern 3: <PackageReference Version="..." Include="..." ... />
pattern3 = re.compile(
rf'(<PackageReference)\s+Version\s*=\s*"[^"]+"\s+(Include\s*=\s*"{re.escape(package_name)}"[^/>]*/>)',
re.IGNORECASE,
)
new_content = content
# Try each pattern
new_content = pattern1.sub(r'\1 \2', new_content)
new_content = pattern2.sub(r'\1 \2', new_content)
new_content = pattern3.sub(r'\1 \2', new_content)
if new_content == content:
logger.debug(f"No changes needed for {package_name} in {csproj_path.name}")
return True
if dry_run:
logger.info(f"Would remove version from {package_name} in {csproj_path.name}")
return True
try:
csproj_path.write_text(new_content, encoding="utf-8")
logger.info(f"Removed version from {package_name} in {csproj_path.name}")
return True
except Exception as e:
logger.error(f"Failed to write {csproj_path}: {e}")
return False
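# Illustrative sketch only (hypothetical package name):
#
#   <PackageReference Include="Serilog" Version="3.1.1" />
#
# becomes
#
#   <PackageReference Include="Serilog" />
#
# so the effective version comes from the <PackageReference Update="..."/>
# entry that update_directory_build_props() writes to Directory.Build.props.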
def apply_centralization(
props_path: Path,
packages_to_centralize: Dict[str, Tuple[str, List[Path]]],
dry_run: bool = False,
) -> Tuple[int, int]:
"""
Apply centralization by updating Directory.Build.props and csproj files.
Args:
props_path: Path to Directory.Build.props
packages_to_centralize: Packages to centralize
dry_run: If True, don't write files
Returns:
Tuple of (packages centralized, csproj files modified)
"""
# Update Directory.Build.props
if not update_directory_build_props(props_path, packages_to_centralize, dry_run):
return 0, 0
# Remove versions from csproj files
files_modified: Set[Path] = set()
for package_name, (version, csproj_paths) in packages_to_centralize.items():
for csproj_path in csproj_paths:
if remove_version_from_csproj(csproj_path, package_name, dry_run):
files_modified.add(csproj_path)
return len(packages_to_centralize), len(files_modified)
def generate_report(
packages: Dict[str, PackageUsage],
packages_to_centralize: Dict[str, Tuple[str, List[Path]]],
) -> dict:
"""
Generate a JSON report of the centralization.
Args:
packages: All package usage data
packages_to_centralize: Packages to centralize
Returns:
Report dictionary
"""
csproj_files_affected = set()
for _, csproj_paths in packages_to_centralize.values():
csproj_files_affected.update(csproj_paths)
report = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"packages_scanned": len(packages),
"packages_to_centralize": len(packages_to_centralize),
"csproj_files_affected": len(csproj_files_affected),
},
"centralized_packages": [],
}
for package_name, (version, csproj_paths) in sorted(packages_to_centralize.items()):
report["centralized_packages"].append(
{
"package": package_name,
"version": version,
"usage_count": len(csproj_paths),
"files": [str(p) for p in csproj_paths],
}
)
return report
def print_summary(
packages: Dict[str, PackageUsage],
packages_to_centralize: Dict[str, Tuple[str, List[Path]]],
dry_run: bool,
) -> None:
"""Print a summary of the centralization."""
print("\n" + "=" * 60)
print("NuGet Package Centralization Summary")
print("=" * 60)
csproj_files_affected = set()
for _, csproj_paths in packages_to_centralize.values():
csproj_files_affected.update(csproj_paths)
print(f"\nPackages scanned: {len(packages)}")
print(f"Packages to centralize: {len(packages_to_centralize)}")
print(f"Project files affected: {len(csproj_files_affected)}")
if packages_to_centralize:
print("\nPackages to centralize to Directory.Build.props:")
for package_name, (version, csproj_paths) in sorted(packages_to_centralize.items()):
print(f" {package_name}: v{version} (used in {len(csproj_paths)} projects)")
if dry_run:
print("\n[DRY RUN - No files were modified]")
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Centralize NuGet package versions to Directory.Build.props",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--src-root",
type=Path,
default=Path("src"),
help="Root of src/ directory (default: ./src)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Report without making changes",
)
parser.add_argument(
"--report",
type=Path,
help="Write JSON report to file",
)
parser.add_argument(
"--exclude",
action="append",
dest="exclude_packages",
default=[],
help="Exclude package from centralization (repeatable)",
)
parser.add_argument(
"--min-usage",
type=int,
default=2,
help="Minimum number of projects using a package to centralize it (default: 2)",
)
parser.add_argument(
"--check",
action="store_true",
help="CI mode: exit 1 if centralization needed",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose output",
)
args = parser.parse_args()
setup_logging(args.verbose)
# Resolve src root
src_root = args.src_root.resolve()
if not src_root.exists():
logger.error(f"Source root does not exist: {src_root}")
return 1
logger.info(f"Source root: {src_root}")
# Path to Directory.Build.props
props_path = src_root / "Directory.Build.props"
# Scan all packages
packages = scan_all_packages(src_root)
if not packages:
logger.info("No packages found")
return 0
# Find packages to centralize
exclude_set = set(args.exclude_packages)
packages_to_centralize = find_packages_to_centralize(
packages, exclude_set, args.min_usage
)
# Generate report
report = generate_report(packages, packages_to_centralize)
# Write report if requested
if args.report:
try:
args.report.write_text(
json.dumps(report, indent=2, default=str),
encoding="utf-8",
)
logger.info(f"Report written to: {args.report}")
except Exception as e:
logger.error(f"Failed to write report: {e}")
# Print summary
print_summary(packages, packages_to_centralize, args.dry_run or args.check)
# Check mode - just report if centralization is needed
if args.check:
if packages_to_centralize:
logger.error("Package centralization needed")
return 1
logger.info("All shared packages are already centralized")
return 0
# Apply centralization
if not args.dry_run:
packages_count, files_count = apply_centralization(
props_path, packages_to_centralize, dry_run=False
)
print(f"\nCentralized {packages_count} packages, modified {files_count} files")
else:
apply_centralization(props_path, packages_to_centralize, dry_run=True)
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,626 @@
#!/usr/bin/env python3
"""
StellaOps NuGet Version Normalizer.
Scans all .csproj files and normalizes NuGet package versions to the latest stable.
IMPORTANT: Packages centrally managed in Directory.Build.props (via PackageReference Update)
are automatically excluded from normalization. These packages are reported separately.
Usage:
python nuget_normalizer.py [OPTIONS]
Options:
--src-root PATH Root of src/ directory (default: ./src)
--repo-root PATH Root of repository (default: parent of src-root)
--dry-run Report without making changes
--report PATH Write JSON report to file
--exclude PACKAGE Exclude package from normalization (repeatable)
--check CI mode: exit 1 if normalization needed
-v, --verbose Verbose output
"""
import argparse
import json
import logging
import re
import sys
from datetime import datetime, timezone
from pathlib import Path
from lib.csproj_parser import find_all_csproj
from lib.models import NormalizationChange, NormalizationResult, PackageUsage
from lib.version_utils import is_stable, parse_version, select_latest_stable
logger = logging.getLogger(__name__)
def find_directory_build_props(repo_root: Path) -> list[Path]:
"""
Find all Directory.Build.props files in the repository.
Args:
repo_root: Root of the repository
Returns:
List of paths to Directory.Build.props files
"""
props_files = []
for props_file in repo_root.rglob("Directory.Build.props"):
# Skip common exclusion directories
parts = props_file.parts
if any(p in ("bin", "obj", "node_modules", ".git") for p in parts):
continue
props_files.append(props_file)
return props_files
def scan_centrally_managed_packages(repo_root: Path) -> dict[str, tuple[str, Path]]:
"""
Scan Directory.Build.props files for centrally managed package versions.
These are packages defined with <PackageReference Update="..." Version="..."/>
which override versions in individual csproj files.
Args:
repo_root: Root of the repository
Returns:
Dictionary mapping package name to (version, props_file_path)
"""
centrally_managed: dict[str, tuple[str, Path]] = {}
props_files = find_directory_build_props(repo_root)
logger.info(f"Scanning {len(props_files)} Directory.Build.props files for centrally managed packages")
# Pattern for PackageReference Update (central version management)
# <PackageReference Update="PackageName" Version="1.2.3" />
update_pattern = re.compile(
r'<PackageReference\s+Update\s*=\s*"([^"]+)"[^>]*Version\s*=\s*"([^"]+)"',
re.IGNORECASE,
)
# Alternative pattern when Version comes first
update_pattern_alt = re.compile(
r'<PackageReference\s+[^>]*Version\s*=\s*"([^"]+)"[^>]*Update\s*=\s*"([^"]+)"',
re.IGNORECASE,
)
for props_file in props_files:
try:
content = props_file.read_text(encoding="utf-8")
except Exception as e:
logger.warning(f"Failed to read {props_file}: {e}")
continue
# Find PackageReference Update elements
for match in update_pattern.finditer(content):
package_name = match.group(1)
version = match.group(2)
# Store with the props file path for reporting
if package_name not in centrally_managed:
centrally_managed[package_name] = (version, props_file)
logger.debug(f"Found centrally managed: {package_name} v{version} in {props_file}")
for match in update_pattern_alt.finditer(content):
version = match.group(1)
package_name = match.group(2)
if package_name not in centrally_managed:
centrally_managed[package_name] = (version, props_file)
logger.debug(f"Found centrally managed: {package_name} v{version} in {props_file}")
logger.info(f"Found {len(centrally_managed)} centrally managed packages")
return centrally_managed
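# Illustrative sketch only (hypothetical entry): a Directory.Build.props containing
#
#   <ItemGroup>
#     <PackageReference Update="Serilog" Version="3.1.1" />
#   </ItemGroup>
#
# yields {"Serilog": ("3.1.1", <path to that props file>)}, and that package is
# then skipped automatically by calculate_normalizations() below.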
def setup_logging(verbose: bool) -> None:
"""Configure logging based on verbosity."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(levelname)s: %(message)s",
)
def scan_all_packages(src_root: Path) -> dict[str, PackageUsage]:
"""
Scan all .csproj files and collect package references.
Args:
src_root: Root of src/ directory
Returns:
Dictionary mapping package name to PackageUsage
"""
packages: dict[str, PackageUsage] = {}
csproj_files = find_all_csproj(src_root)
logger.info(f"Scanning {len(csproj_files)} .csproj files for package references")
# Regex for PackageReference
# Matches: <PackageReference Include="PackageName" Version="1.2.3" />
# Also handles multi-line and various attribute orderings
package_ref_pattern = re.compile(
r'<PackageReference\s+[^>]*Include\s*=\s*"([^"]+)"[^>]*Version\s*=\s*"([^"]+)"',
re.IGNORECASE,
)
# Alternative pattern for when Version comes first
package_ref_pattern_alt = re.compile(
r'<PackageReference\s+[^>]*Version\s*=\s*"([^"]+)"[^>]*Include\s*=\s*"([^"]+)"',
re.IGNORECASE,
)
for csproj_path in csproj_files:
try:
content = csproj_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning(f"Failed to read {csproj_path}: {e}")
continue
# Find all PackageReference elements
for match in package_ref_pattern.finditer(content):
package_name = match.group(1)
version = match.group(2)
if package_name not in packages:
packages[package_name] = PackageUsage(package_name=package_name)
packages[package_name].usages[csproj_path] = version
# Also try alternative pattern
for match in package_ref_pattern_alt.finditer(content):
version = match.group(1)
package_name = match.group(2)
if package_name not in packages:
packages[package_name] = PackageUsage(package_name=package_name)
packages[package_name].usages[csproj_path] = version
logger.info(f"Found {len(packages)} unique packages")
return packages
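# Illustrative sketch of the return shape (hypothetical data): two projects
# referencing Newtonsoft.Json at different versions produce roughly
#
#   {"Newtonsoft.Json": PackageUsage(package_name="Newtonsoft.Json",
#                                    usages={Path("A/A.csproj"): "13.0.1",
#                                            Path("B/B.csproj"): "13.0.3"})}
#
# calculate_normalizations() then reduces each such entry to one target version.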
def calculate_normalizations(
packages: dict[str, PackageUsage],
exclude_packages: set[str],
centrally_managed: dict[str, tuple[str, Path]] | None = None,
) -> tuple[list[NormalizationResult], list[tuple[str, str, Path]]]:
"""
Calculate which packages need version normalization.
Args:
packages: Package usage data
exclude_packages: Package names to exclude
centrally_managed: Packages managed in Directory.Build.props (auto-excluded)
Returns:
Tuple of (normalization results, list of centrally managed packages that were skipped)
"""
results: list[NormalizationResult] = []
centrally_skipped: list[tuple[str, str, Path]] = []
if centrally_managed is None:
centrally_managed = {}
for package_name, usage in sorted(packages.items()):
# Skip centrally managed packages
if package_name in centrally_managed:
version, props_file = centrally_managed[package_name]
centrally_skipped.append((package_name, version, props_file))
logger.debug(f"Skipping centrally managed package: {package_name} (v{version} in {props_file})")
continue
if package_name in exclude_packages:
logger.debug(f"Excluding package: {package_name}")
continue
versions = usage.get_all_versions()
# Skip if only one version
if len(versions) <= 1:
continue
# Filter out wildcard/unparseable versions; skip the package if none parse
parseable_versions = [v for v in versions if parse_version(v) is not None]
if not parseable_versions:
results.append(
NormalizationResult(
package_name=package_name,
target_version="",
skipped_reason="No parseable versions found",
)
)
continue
# Select latest stable version
target_version = select_latest_stable(parseable_versions)
if target_version is None:
# Try to find any version (including prereleases)
parsed = [
(parse_version(v), v)
for v in parseable_versions
if parse_version(v) is not None
]
if parsed:
parsed.sort(key=lambda x: x[0], reverse=True)
target_version = parsed[0][1]
results.append(
NormalizationResult(
package_name=package_name,
target_version=target_version,
skipped_reason="Only prerelease versions available",
)
)
continue
else:
results.append(
NormalizationResult(
package_name=package_name,
target_version="",
skipped_reason="No stable versions found",
)
)
continue
# Create normalization result with changes
result = NormalizationResult(
package_name=package_name,
target_version=target_version,
)
for csproj_path, current_version in usage.usages.items():
if current_version != target_version:
result.changes.append(
NormalizationChange(
csproj_path=csproj_path,
old_version=current_version,
new_version=target_version,
)
)
if result.changes:
results.append(result)
return results, centrally_skipped
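# Worked example (hypothetical versions): for usages of {"8.0.0", "8.0.4",
# "9.0.0-preview.1"}, select_latest_stable() is expected to return "8.0.4"
# (prereleases are ignored while a stable version exists), so every project
# still pinned to "8.0.0" gets a NormalizationChange with new_version="8.0.4".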
def apply_normalizations(
normalizations: list[NormalizationResult],
dry_run: bool = False,
) -> int:
"""
Apply version normalizations to csproj files.
Args:
normalizations: List of normalization results
dry_run: If True, don't actually modify files
Returns:
Number of files modified
"""
files_modified: set[Path] = set()
for result in normalizations:
if result.skipped_reason:
continue
for change in result.changes:
csproj_path = change.csproj_path
if dry_run:
logger.info(
f"Would update {result.package_name} in {csproj_path.name}: "
f"{change.old_version} -> {change.new_version}"
)
files_modified.add(csproj_path)
continue
try:
content = csproj_path.read_text(encoding="utf-8")
# Replace the specific package version
# Pattern matches the PackageReference for this specific package
pattern = re.compile(
rf'(<PackageReference\s+[^>]*Include\s*=\s*"{re.escape(result.package_name)}"'
rf'[^>]*Version\s*=\s*"){re.escape(change.old_version)}(")',
re.IGNORECASE,
)
new_content, count = pattern.subn(
rf"\g<1>{change.new_version}\g<2>",
content,
)
if count > 0:
csproj_path.write_text(new_content, encoding="utf-8")
files_modified.add(csproj_path)
logger.info(
f"Updated {result.package_name} in {csproj_path.name}: "
f"{change.old_version} -> {change.new_version}"
)
else:
# Try alternative pattern
pattern_alt = re.compile(
rf'(<PackageReference\s+[^>]*Version\s*=\s*"){re.escape(change.old_version)}"'
rf'([^>]*Include\s*=\s*"{re.escape(result.package_name)}")',
re.IGNORECASE,
)
new_content, count = pattern_alt.subn(
rf'\g<1>{change.new_version}"\g<2>',
content,
)
if count > 0:
csproj_path.write_text(new_content, encoding="utf-8")
files_modified.add(csproj_path)
logger.info(
f"Updated {result.package_name} in {csproj_path.name}: "
f"{change.old_version} -> {change.new_version}"
)
else:
logger.warning(
f"Could not find pattern to update {result.package_name} "
f"in {csproj_path}"
)
except Exception as e:
logger.error(f"Failed to update {csproj_path}: {e}")
return len(files_modified)
def generate_report(
packages: dict[str, PackageUsage],
normalizations: list[NormalizationResult],
centrally_skipped: list[tuple[str, str, Path]] | None = None,
) -> dict:
"""
Generate a JSON report of the normalization.
Args:
packages: Package usage data
normalizations: Normalization results
centrally_skipped: Packages skipped due to central management
Returns:
Report dictionary
"""
if centrally_skipped is None:
centrally_skipped = []
# Count changes
packages_normalized = sum(
1 for n in normalizations if n.changes and not n.skipped_reason
)
files_modified = len(
set(
change.csproj_path
for n in normalizations
for change in n.changes
if not n.skipped_reason
)
)
report = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"packages_scanned": len(packages),
"packages_with_inconsistencies": len(normalizations),
"packages_normalized": packages_normalized,
"files_modified": files_modified,
"packages_centrally_managed": len(centrally_skipped),
},
"normalizations": [],
"skipped": [],
"centrally_managed": [],
}
for result in normalizations:
if result.skipped_reason:
report["skipped"].append(
{
"package": result.package_name,
"reason": result.skipped_reason,
"versions": packages[result.package_name].get_all_versions()
if result.package_name in packages
else [],
}
)
elif result.changes:
report["normalizations"].append(
{
"package": result.package_name,
"target_version": result.target_version,
"changes": [
{
"file": str(change.csproj_path),
"old": change.old_version,
"new": change.new_version,
}
for change in result.changes
],
}
)
# Add centrally managed packages
for package_name, version, props_file in centrally_skipped:
report["centrally_managed"].append(
{
"package": package_name,
"version": version,
"managed_in": str(props_file),
}
)
return report
def print_summary(
packages: dict[str, PackageUsage],
normalizations: list[NormalizationResult],
centrally_skipped: list[tuple[str, str, Path]],
dry_run: bool,
) -> None:
"""Print a summary of the normalization."""
print("\n" + "=" * 60)
print("NuGet Version Normalization Summary")
print("=" * 60)
changes_needed = [n for n in normalizations if n.changes and not n.skipped_reason]
skipped = [n for n in normalizations if n.skipped_reason]
print(f"\nPackages scanned: {len(packages)}")
print(f"Packages with version inconsistencies: {len(normalizations)}")
print(f"Packages to normalize: {len(changes_needed)}")
print(f"Packages skipped (other reasons): {len(skipped)}")
print(f"Packages centrally managed (auto-skipped): {len(centrally_skipped)}")
if centrally_skipped:
print("\nCentrally managed packages (in Directory.Build.props):")
for package_name, version, props_file in sorted(centrally_skipped, key=lambda x: x[0]):
rel_path = props_file.name if len(str(props_file)) > 50 else props_file
print(f" {package_name}: v{version} ({rel_path})")
if changes_needed:
print("\nPackages to normalize:")
for result in sorted(changes_needed, key=lambda x: x.package_name):
old_versions = set(c.old_version for c in result.changes)
print(
f" {result.package_name}: {', '.join(sorted(old_versions))} -> {result.target_version}"
)
if skipped and logger.isEnabledFor(logging.DEBUG):
print("\nSkipped packages:")
for result in sorted(skipped, key=lambda x: x.package_name):
print(f" {result.package_name}: {result.skipped_reason}")
if dry_run:
print("\n[DRY RUN - No files were modified]")
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Normalize NuGet package versions across all csproj files",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--src-root",
type=Path,
default=Path("src"),
help="Root of src/ directory (default: ./src)",
)
parser.add_argument(
"--repo-root",
type=Path,
default=None,
help="Root of repository for Directory.Build.props scanning (default: parent of src-root)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Report without making changes",
)
parser.add_argument(
"--report",
type=Path,
help="Write JSON report to file",
)
parser.add_argument(
"--exclude",
action="append",
dest="exclude_packages",
default=[],
help="Exclude package from normalization (repeatable)",
)
parser.add_argument(
"--check",
action="store_true",
help="CI mode: exit 1 if normalization needed",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose output",
)
args = parser.parse_args()
setup_logging(args.verbose)
# Resolve src root
src_root = args.src_root.resolve()
if not src_root.exists():
logger.error(f"Source root does not exist: {src_root}")
return 1
# Resolve repo root (for Directory.Build.props scanning)
repo_root = args.repo_root.resolve() if args.repo_root else src_root.parent
if not repo_root.exists():
logger.error(f"Repository root does not exist: {repo_root}")
return 1
logger.info(f"Source root: {src_root}")
logger.info(f"Repository root: {repo_root}")
# Scan for centrally managed packages in Directory.Build.props
centrally_managed = scan_centrally_managed_packages(repo_root)
# Scan all packages
packages = scan_all_packages(src_root)
if not packages:
logger.info("No packages found")
return 0
# Calculate normalizations (excluding centrally managed packages)
exclude_set = set(args.exclude_packages)
normalizations, centrally_skipped = calculate_normalizations(
packages, exclude_set, centrally_managed
)
# Generate report
report = generate_report(packages, normalizations, centrally_skipped)
# Write report if requested
if args.report:
try:
args.report.write_text(
json.dumps(report, indent=2, default=str),
encoding="utf-8",
)
logger.info(f"Report written to: {args.report}")
except Exception as e:
logger.error(f"Failed to write report: {e}")
# Print summary
print_summary(packages, normalizations, centrally_skipped, args.dry_run or args.check)
# Check mode - just report if normalization is needed
if args.check:
changes_needed = [n for n in normalizations if n.changes and not n.skipped_reason]
if changes_needed:
logger.error("Version normalization needed")
return 1
logger.info("All package versions are consistent")
return 0
# Apply normalizations
if not args.dry_run:
files_modified = apply_normalizations(normalizations, dry_run=False)
print(f"\nModified {files_modified} files")
else:
apply_normalizations(normalizations, dry_run=True)
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,620 @@
#!/usr/bin/env python3
"""
StellaOps NuGet Vulnerability Checker.
Scans NuGet packages for security vulnerabilities and suggests/applies fixes.
Usage:
python nuget_vuln_checker.py [OPTIONS]
Options:
--solution PATH Path to .sln file (default: src/StellaOps.sln)
--min-severity LEVEL Minimum severity: low|moderate|high|critical (default: high)
--fix Auto-fix by updating to non-vulnerable versions
--dry-run Show what would be fixed without modifying files
--report PATH Write JSON report to file
--include-transitive Include transitive dependency vulnerabilities
--exclude PACKAGE Exclude package from checks (repeatable)
-v, --verbose Verbose output
Exit Codes:
0 - No vulnerabilities found (or all below threshold)
1 - Vulnerabilities found above threshold
2 - Error during execution
"""
import argparse
import json
import logging
import re
import shutil
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
from lib.nuget_api import NuGetApiClient, NuGetApiError
from lib.vulnerability_models import (
SEVERITY_LEVELS,
VulnerabilityDetail,
VulnerabilityReport,
VulnerablePackage,
meets_severity_threshold,
)
from lib.version_utils import parse_version
logger = logging.getLogger(__name__)
def setup_logging(verbose: bool) -> None:
"""Configure logging based on verbosity."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(levelname)s: %(message)s",
)
def check_dotnet_available() -> bool:
"""Check if dotnet CLI is available."""
return shutil.which("dotnet") is not None
def run_vulnerability_check(
solution_path: Path, include_transitive: bool
) -> dict | None:
"""
Run dotnet list package --vulnerable and parse JSON output.
Returns parsed JSON or None if command fails.
"""
cmd = [
"dotnet",
"list",
str(solution_path),
"package",
"--vulnerable",
"--format",
"json",
"--output-version",
"1",
]
if include_transitive:
cmd.append("--include-transitive")
logger.info(f"Running: {' '.join(cmd)}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=600, # 10 minute timeout for large solutions
)
# dotnet list package exits 0 even when vulnerabilities are reported;
# a non-zero exit code means the command itself failed
if result.returncode != 0:
logger.error(f"dotnet command failed: {result.stderr}")
return None
# Parse JSON output
if not result.stdout.strip():
logger.warning("Empty output from dotnet list package")
return {"version": 1, "projects": []}
return json.loads(result.stdout)
except subprocess.TimeoutExpired:
logger.error("dotnet command timed out")
return None
except json.JSONDecodeError as e:
logger.error(f"Failed to parse dotnet output as JSON: {e}")
logger.debug(f"Output was: {result.stdout[:500]}...")
return None
except Exception as e:
logger.error(f"Error running dotnet command: {e}")
return None
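# Sketch of the JSON shape consumed below (only the fields that
# parse_vulnerability_output() reads; values are hypothetical):
#
#   {"version": 1,
#    "projects": [
#      {"path": "src/Example/Example.csproj",
#       "frameworks": [
#         {"topLevelPackages": [
#            {"id": "Example.Package",
#             "requestedVersion": "1.0.0",
#             "resolvedVersion": "1.0.0",
#             "vulnerabilities": [
#               {"severity": "High",
#                "advisoryurl": "https://github.com/advisories/GHSA-xxxx-xxxx"}]}],
#          "transitivePackages": []}]}]}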
def parse_vulnerability_output(
data: dict, min_severity: str, exclude_packages: set[str]
) -> list[VulnerablePackage]:
"""
Parse dotnet list package --vulnerable JSON output.
Returns list of VulnerablePackage objects that meet severity threshold.
"""
vulnerable_packages: dict[str, VulnerablePackage] = {}
for project in data.get("projects", []):
project_path = Path(project.get("path", "unknown"))
for framework in project.get("frameworks", []):
# Check both topLevelPackages and transitivePackages
for package_list_key in ["topLevelPackages", "transitivePackages"]:
for package in framework.get(package_list_key, []):
package_id = package.get("id", "")
# Skip excluded packages
if package_id.lower() in {p.lower() for p in exclude_packages}:
logger.debug(f"Skipping excluded package: {package_id}")
continue
vulns = package.get("vulnerabilities", [])
if not vulns:
continue
# Check if any vulnerability meets threshold
matching_vulns = []
for vuln in vulns:
severity = vuln.get("severity", "unknown")
if meets_severity_threshold(severity, min_severity):
matching_vulns.append(
VulnerabilityDetail(
severity=severity,
advisory_url=vuln.get("advisoryurl", ""),
)
)
if not matching_vulns:
continue
# Add or update vulnerable package
key = f"{package_id}@{package.get('resolvedVersion', '')}"
if key not in vulnerable_packages:
vulnerable_packages[key] = VulnerablePackage(
package_id=package_id,
resolved_version=package.get("resolvedVersion", ""),
requested_version=package.get("requestedVersion", ""),
vulnerabilities=matching_vulns,
)
vulnerable_packages[key].affected_projects.append(project_path)
return list(vulnerable_packages.values())
def find_suggested_fixes(
vulnerable_packages: list[VulnerablePackage],
api_client: NuGetApiClient | None,
) -> None:
"""
For each vulnerable package, find a suggested non-vulnerable version.
Modifies packages in-place to add suggested_version and fix_risk.
"""
if api_client is None:
logger.warning("NuGet API client not available, cannot suggest fixes")
return
for pkg in vulnerable_packages:
logger.debug(f"Finding safe version for {pkg.package_id} {pkg.resolved_version}")
try:
safe_version = api_client.find_safe_version(
pkg.package_id, pkg.resolved_version
)
if safe_version:
pkg.suggested_version = safe_version
pkg.fix_risk = api_client.get_fix_risk(
pkg.resolved_version, safe_version
)
logger.info(
f"Found safe version for {pkg.package_id}: "
f"{pkg.resolved_version} -> {safe_version} (risk: {pkg.fix_risk})"
)
else:
logger.warning(
f"No safe version found for {pkg.package_id} {pkg.resolved_version}"
)
except NuGetApiError as e:
logger.warning(f"Failed to query NuGet API for {pkg.package_id}: {e}")
def has_direct_package_reference(content: str, package_id: str) -> bool:
"""Check if the csproj has a direct PackageReference for the package."""
pattern = re.compile(
rf'<PackageReference\s+[^>]*Include\s*=\s*"{re.escape(package_id)}"',
re.IGNORECASE,
)
return pattern.search(content) is not None
def add_package_reference(content: str, package_id: str, version: str) -> str:
"""
Add a new PackageReference to a csproj file.
Inserts into an existing ItemGroup with PackageReferences, or creates a new one.
"""
# Find existing ItemGroup with PackageReferences
itemgroup_pattern = re.compile(
r'(<ItemGroup[^>]*>)(.*?<PackageReference\s)',
re.IGNORECASE | re.DOTALL,
)
match = itemgroup_pattern.search(content)
if match:
# Insert after the opening ItemGroup tag
insert_pos = match.end(1)
new_ref = f'\n <PackageReference Include="{package_id}" Version="{version}" />'
return content[:insert_pos] + new_ref + content[insert_pos:]
# No ItemGroup with PackageReferences found, look for any ItemGroup
any_itemgroup = re.search(r'(<ItemGroup[^>]*>)', content, re.IGNORECASE)
if any_itemgroup:
insert_pos = any_itemgroup.end(1)
new_ref = f'\n <PackageReference Include="{package_id}" Version="{version}" />'
return content[:insert_pos] + new_ref + content[insert_pos:]
# No ItemGroup at all, add before closing </Project>
project_close = content.rfind('</Project>')
if project_close > 0:
new_itemgroup = f'\n <ItemGroup>\n <PackageReference Include="{package_id}" Version="{version}" />\n </ItemGroup>\n'
return content[:project_close] + new_itemgroup + content[project_close:]
# Fallback - shouldn't happen for valid csproj
return content
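# Illustrative sketch only (hypothetical package): add_package_reference(content,
# "Example.Package", "2.0.0") inserts
#
#   <PackageReference Include="Example.Package" Version="2.0.0" />
#
# right after the opening tag of an ItemGroup that already holds
# PackageReferences; failing that, after any ItemGroup; failing that, inside a
# new ItemGroup placed just before </Project>.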
def apply_fixes(
vulnerable_packages: list[VulnerablePackage],
dry_run: bool = False,
) -> int:
"""
Apply suggested fixes to csproj files.
For direct dependencies: updates the version in place.
For transitive dependencies: adds an explicit PackageReference to override.
Returns number of files modified.
"""
files_modified: set[Path] = set()
for pkg in vulnerable_packages:
if not pkg.suggested_version:
continue
for project_path in pkg.affected_projects:
if not project_path.exists():
logger.warning(f"Project file not found: {project_path}")
continue
try:
content = project_path.read_text(encoding="utf-8")
# Check if this is a direct or transitive dependency
is_direct = has_direct_package_reference(content, pkg.package_id)
if is_direct:
# Direct dependency - update version in place
if dry_run:
logger.info(
f"Would update {pkg.package_id} in {project_path.name}: "
f"{pkg.resolved_version} -> {pkg.suggested_version}"
)
files_modified.add(project_path)
continue
# Pattern to match PackageReference for this package
pattern = re.compile(
rf'(<PackageReference\s+[^>]*Include\s*=\s*"{re.escape(pkg.package_id)}"'
rf'[^>]*Version\s*=\s*"){re.escape(pkg.resolved_version)}(")',
re.IGNORECASE,
)
new_content, count = pattern.subn(
rf"\g<1>{pkg.suggested_version}\g<2>",
content,
)
if count > 0:
project_path.write_text(new_content, encoding="utf-8")
files_modified.add(project_path)
logger.info(
f"Updated {pkg.package_id} in {project_path.name}: "
f"{pkg.resolved_version} -> {pkg.suggested_version}"
)
else:
# Try alternative pattern (Version before Include)
pattern_alt = re.compile(
rf'(<PackageReference\s+[^>]*Version\s*=\s*"){re.escape(pkg.resolved_version)}"'
rf'([^>]*Include\s*=\s*"{re.escape(pkg.package_id)}")',
re.IGNORECASE,
)
new_content, count = pattern_alt.subn(
rf'\g<1>{pkg.suggested_version}"\g<2>',
content,
)
if count > 0:
project_path.write_text(new_content, encoding="utf-8")
files_modified.add(project_path)
logger.info(
f"Updated {pkg.package_id} in {project_path.name}: "
f"{pkg.resolved_version} -> {pkg.suggested_version}"
)
else:
logger.warning(
f"Could not find {pkg.package_id} {pkg.resolved_version} "
f"in {project_path}"
)
else:
# Transitive dependency - add explicit PackageReference to override
if dry_run:
logger.info(
f"Would add explicit PackageReference for transitive dependency "
f"{pkg.package_id} {pkg.suggested_version} in {project_path.name} "
f"(overrides vulnerable {pkg.resolved_version})"
)
files_modified.add(project_path)
continue
new_content = add_package_reference(
content, pkg.package_id, pkg.suggested_version
)
if new_content != content:
project_path.write_text(new_content, encoding="utf-8")
files_modified.add(project_path)
logger.info(
f"Added explicit PackageReference for {pkg.package_id} "
f"{pkg.suggested_version} in {project_path.name} "
f"(overrides vulnerable transitive {pkg.resolved_version})"
)
else:
logger.warning(
f"Failed to add PackageReference for {pkg.package_id} "
f"in {project_path}"
)
except Exception as e:
logger.error(f"Failed to update {project_path}: {e}")
return len(files_modified)
def generate_report(
solution: Path,
min_severity: str,
total_packages: int,
vulnerable_packages: list[VulnerablePackage],
) -> dict:
"""Generate JSON report of vulnerability scan."""
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"solution": str(solution),
"min_severity": min_severity,
"summary": {
"total_packages_scanned": total_packages,
"vulnerable_packages": len(vulnerable_packages),
"fixable_packages": sum(
1 for p in vulnerable_packages if p.suggested_version
),
"unfixable_packages": sum(
1 for p in vulnerable_packages if not p.suggested_version
),
},
"vulnerabilities": [
{
"package": pkg.package_id,
"current_version": pkg.resolved_version,
"severity": pkg.highest_severity,
"advisory_urls": pkg.advisory_urls,
"affected_projects": [str(p) for p in pkg.affected_projects],
"suggested_fix": {
"version": pkg.suggested_version,
"risk": pkg.fix_risk,
}
if pkg.suggested_version
else None,
}
for pkg in vulnerable_packages
],
"unfixable": [
{
"package": pkg.package_id,
"version": pkg.resolved_version,
"reason": "No non-vulnerable version available",
}
for pkg in vulnerable_packages
if not pkg.suggested_version
],
}
def print_summary(
vulnerable_packages: list[VulnerablePackage],
min_severity: str,
dry_run: bool,
fix_mode: bool,
) -> None:
"""Print a human-readable summary of findings."""
print("\n" + "=" * 70)
print("NuGet Vulnerability Scan Results")
print("=" * 70)
if not vulnerable_packages:
print(f"\nNo vulnerabilities found at or above '{min_severity}' severity.")
return
print(f"\nFound {len(vulnerable_packages)} vulnerable package(s):\n")
for pkg in sorted(vulnerable_packages, key=lambda p: (
-SEVERITY_LEVELS.get(p.highest_severity.lower(), 0),
p.package_id,
)):
severity_upper = pkg.highest_severity.upper()
print(f" [{severity_upper}] {pkg.package_id} {pkg.resolved_version}")
for vuln in pkg.vulnerabilities:
print(f" Advisory: {vuln.advisory_url}")
if pkg.suggested_version:
risk_str = f" (risk: {pkg.fix_risk})" if pkg.fix_risk != "unknown" else ""
print(f" Suggested fix: {pkg.suggested_version}{risk_str}")
else:
print(" No fix available")
print(f" Affected projects: {len(pkg.affected_projects)}")
for proj in pkg.affected_projects[:3]: # Show first 3
print(f" - {proj.name}")
if len(pkg.affected_projects) > 3:
print(f" - ... and {len(pkg.affected_projects) - 3} more")
print()
# Summary counts
fixable = sum(1 for p in vulnerable_packages if p.suggested_version)
unfixable = len(vulnerable_packages) - fixable
print("-" * 70)
print(f"Summary: {len(vulnerable_packages)} vulnerable, {fixable} fixable, {unfixable} unfixable")
if dry_run:
print("\n[DRY RUN - No files were modified]")
elif not fix_mode:
print("\nRun with --fix to apply suggested fixes, or --dry-run to preview changes")
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Check NuGet packages for security vulnerabilities",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--solution",
type=Path,
default=Path("src/StellaOps.sln"),
help="Path to .sln file (default: src/StellaOps.sln)",
)
parser.add_argument(
"--min-severity",
choices=["low", "moderate", "high", "critical"],
default="high",
help="Minimum severity to report (default: high)",
)
parser.add_argument(
"--fix",
action="store_true",
help="Auto-fix by updating to non-vulnerable versions",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be fixed without modifying files",
)
parser.add_argument(
"--report",
type=Path,
help="Write JSON report to file",
)
parser.add_argument(
"--include-transitive",
action="store_true",
help="Include transitive dependency vulnerabilities",
)
parser.add_argument(
"--exclude",
action="append",
dest="exclude_packages",
default=[],
help="Exclude package from checks (repeatable)",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose output",
)
args = parser.parse_args()
setup_logging(args.verbose)
# Validate solution path
solution_path = args.solution.resolve()
if not solution_path.exists():
logger.error(f"Solution file not found: {solution_path}")
return 2
# Check dotnet is available
if not check_dotnet_available():
logger.error("dotnet CLI not found. Please install .NET SDK.")
return 2
logger.info(f"Scanning solution: {solution_path}")
logger.info(f"Minimum severity: {args.min_severity}")
# Run vulnerability check
vuln_data = run_vulnerability_check(solution_path, args.include_transitive)
if vuln_data is None:
logger.error("Failed to run vulnerability check")
return 2
# Count total packages for reporting
total_packages = 0
for project in vuln_data.get("projects", []):
for framework in project.get("frameworks", []):
total_packages += len(framework.get("topLevelPackages", []))
if args.include_transitive:
total_packages += len(framework.get("transitivePackages", []))
# Parse vulnerabilities
exclude_set = set(args.exclude_packages)
vulnerable_packages = parse_vulnerability_output(
vuln_data, args.min_severity, exclude_set
)
logger.info(f"Found {len(vulnerable_packages)} vulnerable package(s)")
# Try to find suggested fixes via NuGet API
api_client = None
try:
api_client = NuGetApiClient()
find_suggested_fixes(vulnerable_packages, api_client)
except ImportError:
logger.warning(
"requests library not available, cannot suggest fixes. "
"Install with: pip install requests"
)
except Exception as e:
logger.warning(f"NuGet API initialization failed: {e}")
# Generate report
report = generate_report(
solution_path, args.min_severity, total_packages, vulnerable_packages
)
# Write report if requested
if args.report:
try:
args.report.write_text(
json.dumps(report, indent=2, default=str),
encoding="utf-8",
)
logger.info(f"Report written to: {args.report}")
except Exception as e:
logger.error(f"Failed to write report: {e}")
# Print summary
print_summary(vulnerable_packages, args.min_severity, args.dry_run, args.fix)
# Apply fixes if requested
if args.fix or args.dry_run:
files_modified = apply_fixes(vulnerable_packages, dry_run=args.dry_run)
if not args.dry_run:
print(f"\nModified {files_modified} file(s)")
# Exit with appropriate code
if vulnerable_packages:
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,395 @@
#!/usr/bin/env python3
"""
StellaOps Solution Generator.
Generates consistent .sln files for:
- Main solution (src/StellaOps.sln) with all projects
- Module solutions (src/<Module>/StellaOps.<Module>.sln) with external deps in __External/
Usage:
python sln_generator.py [OPTIONS]
Options:
--src-root PATH Root of src/ directory (default: ./src)
--main-only Only regenerate main solution
--module NAME Regenerate specific module solution only
--all Regenerate all solutions (default)
--dry-run Show changes without writing
--check CI mode: exit 1 if solutions need updating
-v, --verbose Verbose output
"""
import argparse
import logging
import sys
from pathlib import Path
from lib.csproj_parser import find_all_csproj, parse_csproj
from lib.dependency_graph import (
collect_all_external_dependencies,
get_module_projects,
)
from lib.models import CsprojProject
from lib.sln_writer import (
build_external_folder_hierarchy,
build_folder_hierarchy,
generate_solution_content,
has_bypass_marker,
write_solution_file,
)
logger = logging.getLogger(__name__)
# Directories under src/ that are modules (have their own solutions)
# Excludes special directories like __Libraries, __Tests, __Analyzers, Web, etc.
EXCLUDED_FROM_MODULE_SOLUTIONS = {
"__Libraries",
"__Tests",
"__Analyzers",
".nuget",
".cache",
".vs",
"Web", # Angular project, not .NET
"plugins",
"app",
"Api",
"Sdk",
"DevPortal",
"Mirror",
"Provenance",
"Symbols",
"Unknowns",
}
def setup_logging(verbose: bool) -> None:
"""Configure logging based on verbosity."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(levelname)s: %(message)s",
)
def discover_modules(src_root: Path) -> list[Path]:
"""
Discover all module directories under src/.
A module is a directory that:
- Is a direct child of src/
- Is not in EXCLUDED_FROM_MODULE_SOLUTIONS
- Contains at least one .csproj file
Returns:
List of absolute paths to module directories
"""
modules: list[Path] = []
for item in src_root.iterdir():
if not item.is_dir():
continue
if item.name in EXCLUDED_FROM_MODULE_SOLUTIONS:
continue
if item.name.startswith("."):
continue
# Check if it contains any csproj files
csproj_files = list(item.rglob("*.csproj"))
if csproj_files:
modules.append(item.resolve())
return sorted(modules)
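# Illustrative sketch only (hypothetical layout): given
#
#   src/Scanner/StellaOps.Scanner.Core/StellaOps.Scanner.Core.csproj
#   src/__Libraries/...
#   src/Web/...
#
# discover_modules() returns only <src>/Scanner: __Libraries and Web appear in
# EXCLUDED_FROM_MODULE_SOLUTIONS, and directories without any .csproj are skipped.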
def load_all_projects(src_root: Path) -> tuple[list[CsprojProject], dict[Path, CsprojProject]]:
"""
Load and parse all projects under src/.
Returns:
Tuple of (list of all projects, map from path to project)
"""
csproj_files = find_all_csproj(src_root)
logger.info(f"Found {len(csproj_files)} .csproj files")
projects: list[CsprojProject] = []
project_map: dict[Path, CsprojProject] = {}
for csproj_path in csproj_files:
project = parse_csproj(csproj_path, src_root)
if project:
projects.append(project)
project_map[project.path] = project
else:
logger.warning(f"Failed to parse: {csproj_path}")
logger.info(f"Successfully parsed {len(projects)} projects")
return projects, project_map
def generate_main_solution(
src_root: Path,
projects: list[CsprojProject],
dry_run: bool = False,
) -> bool:
"""
Generate the main StellaOps.sln with all projects.
Args:
src_root: Root of src/ directory
projects: All parsed projects
dry_run: If True, don't write files
Returns:
True if successful
"""
sln_path = src_root / "StellaOps.sln"
# Check for bypass marker
if has_bypass_marker(sln_path):
logger.info(f"Skipping {sln_path} (has bypass marker)")
return True
logger.info(f"Generating main solution: {sln_path}")
# Build folder hierarchy matching physical structure
folders = build_folder_hierarchy(projects, src_root)
# Generate solution content
content = generate_solution_content(
sln_path=sln_path,
projects=[], # Projects are in folders
folders=folders,
external_folders=None,
add_bypass_marker=False,
)
return write_solution_file(sln_path, content, dry_run)
def generate_module_solution(
module_dir: Path,
src_root: Path,
all_projects: list[CsprojProject],
project_map: dict[Path, CsprojProject],
dry_run: bool = False,
) -> bool:
"""
Generate a module-specific solution.
Args:
module_dir: Root directory of the module
src_root: Root of src/ directory
all_projects: All parsed projects
project_map: Map from path to project
dry_run: If True, don't write files
Returns:
True if successful
"""
module_name = module_dir.name
sln_path = module_dir / f"StellaOps.{module_name}.sln"
# Check for bypass marker
if has_bypass_marker(sln_path):
logger.info(f"Skipping {sln_path} (has bypass marker)")
return True
logger.info(f"Generating module solution: {sln_path}")
# Get projects within this module
module_projects = get_module_projects(module_dir, all_projects)
if not module_projects:
logger.warning(f"No projects found in module: {module_name}")
return True
logger.debug(f" Found {len(module_projects)} projects in module")
# Build internal folder hierarchy
internal_folders = build_folder_hierarchy(module_projects, module_dir)
# Collect external dependencies
external_groups = collect_all_external_dependencies(
projects=module_projects,
module_dir=module_dir,
src_root=src_root,
project_map=project_map,
)
# Build external folder hierarchy
external_folders = {}
if external_groups:
external_folders = build_external_folder_hierarchy(external_groups, src_root)
ext_count = sum(len(v) for v in external_groups.values())
logger.debug(f" Found {ext_count} external dependencies")
# Generate solution content
content = generate_solution_content(
sln_path=sln_path,
projects=[], # Projects are in folders
folders=internal_folders,
external_folders=external_folders,
add_bypass_marker=False,
)
return write_solution_file(sln_path, content, dry_run)
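# Illustrative sketch only (hypothetical module): for a module at src/Scanner,
# the generated StellaOps.Scanner.sln holds the module's own projects in solution
# folders mirroring the directory layout, while projects referenced from outside
# the module are grouped under an __External folder hierarchy (built here by
# build_external_folder_hierarchy() from collect_all_external_dependencies()).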
def check_solutions_up_to_date(
src_root: Path,
modules: list[Path],
all_projects: list[CsprojProject],
project_map: dict[Path, CsprojProject],
) -> bool:
"""
Check if solutions need updating (for --check mode).
Args:
src_root: Root of src/ directory
modules: List of module directories
all_projects: All parsed projects
project_map: Map from path to project
Returns:
True if all solutions are up to date
"""
# This is a simplified check - in a real implementation,
# you would compare generated content with existing files
logger.info("Checking if solutions are up to date...")
# For now, just check if files exist
main_sln = src_root / "StellaOps.sln"
if not main_sln.exists():
logger.error(f"Main solution missing: {main_sln}")
return False
for module_dir in modules:
module_name = module_dir.name
module_sln = module_dir / f"StellaOps.{module_name}.sln"
if has_bypass_marker(module_sln):
continue
if not module_sln.exists():
logger.error(f"Module solution missing: {module_sln}")
return False
logger.info("All solutions appear to be up to date")
return True
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Generate StellaOps solution files",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--src-root",
type=Path,
default=Path("src"),
help="Root of src/ directory (default: ./src)",
)
parser.add_argument(
"--main-only",
action="store_true",
help="Only regenerate main solution",
)
parser.add_argument(
"--module",
type=str,
help="Regenerate specific module solution only",
)
parser.add_argument(
"--all",
action="store_true",
dest="regenerate_all",
help="Regenerate all solutions (default)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show changes without writing",
)
parser.add_argument(
"--check",
action="store_true",
help="CI mode: exit 1 if solutions need updating",
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
help="Verbose output",
)
args = parser.parse_args()
setup_logging(args.verbose)
# Resolve src root
src_root = args.src_root.resolve()
if not src_root.exists():
logger.error(f"Source root does not exist: {src_root}")
return 1
logger.info(f"Source root: {src_root}")
# Load all projects
all_projects, project_map = load_all_projects(src_root)
if not all_projects:
logger.error("No projects found")
return 1
# Discover modules
modules = discover_modules(src_root)
logger.info(f"Discovered {len(modules)} modules")
# Check mode
if args.check:
if check_solutions_up_to_date(src_root, modules, all_projects, project_map):
return 0
return 1
# Determine what to generate
success = True
if args.module:
# Specific module only
module_dir = src_root / args.module
if not module_dir.exists():
logger.error(f"Module directory does not exist: {module_dir}")
return 1
success = generate_module_solution(
module_dir, src_root, all_projects, project_map, args.dry_run
)
elif args.main_only:
# Main solution only
success = generate_main_solution(src_root, all_projects, args.dry_run)
else:
# Generate all (default)
# Main solution
if not generate_main_solution(src_root, all_projects, args.dry_run):
success = False
# Module solutions
for module_dir in modules:
if not generate_module_solution(
module_dir, src_root, all_projects, project_map, args.dry_run
):
success = False
if args.dry_run:
logger.info("Dry run complete - no files were modified")
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())