save progress
This commit is contained in:
416
tools/slntools/lib/nuget_api.py
Normal file
416
tools/slntools/lib/nuget_api.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
NuGet API v3 client for package version and vulnerability queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None # type: ignore
|
||||
|
||||
from .version_utils import parse_version, is_stable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NUGET_SERVICE_INDEX = "https://api.nuget.org/v3/index.json"
|
||||
NUGET_VULN_INDEX = "https://api.nuget.org/v3/vulnerabilities/index.json"
|
||||
|
||||
|
||||
class NuGetApiError(Exception):
|
||||
"""Error communicating with NuGet API."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NuGetApiClient:
|
||||
"""
|
||||
Client for NuGet API v3 operations.
|
||||
|
||||
Provides methods for:
|
||||
- Fetching available package versions
|
||||
- Fetching vulnerability data
|
||||
- Finding non-vulnerable versions
|
||||
"""
|
||||
|
||||
def __init__(self, source: str = "https://api.nuget.org/v3"):
|
||||
if requests is None:
|
||||
raise ImportError(
|
||||
"requests library is required for NuGet API access. "
|
||||
"Install with: pip install requests"
|
||||
)
|
||||
|
||||
self.source = source.rstrip("/")
|
||||
self._session = requests.Session()
|
||||
self._session.headers.update(
|
||||
{"User-Agent": "StellaOps-NuGetVulnChecker/1.0"}
|
||||
)
|
||||
self._service_index: dict | None = None
|
||||
self._vuln_cache: dict[str, list[dict]] | None = None
|
||||
self._search_url: str | None = None
|
||||
self._registration_url: str | None = None
|
||||
|
||||
def _get_service_index(self) -> dict:
|
||||
"""Fetch and cache the NuGet service index."""
|
||||
if self._service_index is not None:
|
||||
return self._service_index
|
||||
|
||||
try:
|
||||
response = self._session.get(f"{self.source}/index.json", timeout=30)
|
||||
response.raise_for_status()
|
||||
self._service_index = response.json()
|
||||
return self._service_index
|
||||
except Exception as e:
|
||||
raise NuGetApiError(f"Failed to fetch NuGet service index: {e}")
|
||||
|
||||
def _get_search_url(self) -> str:
|
||||
"""Get the SearchQueryService URL from service index."""
|
||||
if self._search_url:
|
||||
return self._search_url
|
||||
|
||||
index = self._get_service_index()
|
||||
resources = index.get("resources", [])
|
||||
|
||||
# Look for SearchQueryService
|
||||
for resource in resources:
|
||||
resource_type = resource.get("@type", "")
|
||||
if "SearchQueryService" in resource_type:
|
||||
self._search_url = resource.get("@id", "")
|
||||
return self._search_url
|
||||
|
||||
raise NuGetApiError("SearchQueryService not found in service index")
|
||||
|
||||
def _get_registration_url(self) -> str:
|
||||
"""Get the RegistrationsBaseUrl from service index."""
|
||||
if self._registration_url:
|
||||
return self._registration_url
|
||||
|
||||
index = self._get_service_index()
|
||||
resources = index.get("resources", [])
|
||||
|
||||
# Look for RegistrationsBaseUrl
|
||||
for resource in resources:
|
||||
resource_type = resource.get("@type", "")
|
||||
if "RegistrationsBaseUrl" in resource_type:
|
||||
self._registration_url = resource.get("@id", "").rstrip("/")
|
||||
return self._registration_url
|
||||
|
||||
raise NuGetApiError("RegistrationsBaseUrl not found in service index")
|
||||
|
||||
def get_available_versions(self, package_id: str) -> list[str]:
|
||||
"""
|
||||
Fetch all available versions of a package from NuGet.
|
||||
|
||||
Args:
|
||||
package_id: The NuGet package ID
|
||||
|
||||
Returns:
|
||||
List of version strings, sorted newest first
|
||||
"""
|
||||
try:
|
||||
# Use registration API for complete version list
|
||||
reg_url = self._get_registration_url()
|
||||
package_lower = package_id.lower()
|
||||
url = f"{reg_url}/{package_lower}/index.json"
|
||||
|
||||
response = self._session.get(url, timeout=30)
|
||||
if response.status_code == 404:
|
||||
logger.warning(f"Package not found on NuGet: {package_id}")
|
||||
return []
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
versions = []
|
||||
for page in data.get("items", []):
|
||||
# Pages may be inline or require fetching
|
||||
if "items" in page:
|
||||
items = page["items"]
|
||||
else:
|
||||
# Fetch the page
|
||||
page_url = page.get("@id")
|
||||
if page_url:
|
||||
page_response = self._session.get(page_url, timeout=30)
|
||||
page_response.raise_for_status()
|
||||
items = page_response.json().get("items", [])
|
||||
else:
|
||||
items = []
|
||||
|
||||
for item in items:
|
||||
catalog_entry = item.get("catalogEntry", {})
|
||||
version = catalog_entry.get("version")
|
||||
if version:
|
||||
versions.append(version)
|
||||
|
||||
# Sort by parsed version, newest first
|
||||
def sort_key(v: str) -> tuple:
|
||||
parsed = parse_version(v)
|
||||
if parsed is None:
|
||||
return (0, 0, 0, "")
|
||||
return parsed
|
||||
|
||||
versions.sort(key=sort_key, reverse=True)
|
||||
return versions
|
||||
|
||||
except NuGetApiError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch versions for {package_id}: {e}")
|
||||
return []
|
||||
|
||||
def get_vulnerability_data(self) -> dict[str, list[dict]]:
|
||||
"""
|
||||
Fetch vulnerability data from NuGet VulnerabilityInfo API.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping lowercase package ID to list of vulnerability info dicts.
|
||||
Each dict contains: severity, advisory_url, versions (affected range)
|
||||
"""
|
||||
if self._vuln_cache is not None:
|
||||
return self._vuln_cache
|
||||
|
||||
try:
|
||||
# Fetch vulnerability index
|
||||
response = self._session.get(NUGET_VULN_INDEX, timeout=30)
|
||||
response.raise_for_status()
|
||||
index = response.json()
|
||||
|
||||
vuln_map: dict[str, list[dict]] = {}
|
||||
|
||||
# Fetch each vulnerability page
|
||||
for page_info in index:
|
||||
page_url = page_info.get("@id")
|
||||
if not page_url:
|
||||
continue
|
||||
|
||||
try:
|
||||
page_response = self._session.get(page_url, timeout=60)
|
||||
page_response.raise_for_status()
|
||||
page_data = page_response.json()
|
||||
|
||||
# Parse vulnerability entries
|
||||
self._merge_vuln_data(vuln_map, page_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch vulnerability page {page_url}: {e}")
|
||||
continue
|
||||
|
||||
self._vuln_cache = vuln_map
|
||||
logger.info(f"Loaded vulnerability data for {len(vuln_map)} packages")
|
||||
return vuln_map
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch vulnerability data: {e}")
|
||||
return {}
|
||||
|
||||
def _merge_vuln_data(
|
||||
self, vuln_map: dict[str, list[dict]], page_data: Any
|
||||
) -> None:
|
||||
"""Merge vulnerability page data into the vulnerability map."""
|
||||
# The vulnerability data format is a dict mapping package ID (lowercase)
|
||||
# to list of vulnerability objects
|
||||
if not isinstance(page_data, dict):
|
||||
return
|
||||
|
||||
for package_id, vulns in page_data.items():
|
||||
if package_id.startswith("@"):
|
||||
# Skip metadata fields like @context
|
||||
continue
|
||||
|
||||
package_lower = package_id.lower()
|
||||
if package_lower not in vuln_map:
|
||||
vuln_map[package_lower] = []
|
||||
|
||||
if isinstance(vulns, list):
|
||||
vuln_map[package_lower].extend(vulns)
|
||||
|
||||
def is_version_vulnerable(
|
||||
self, package_id: str, version: str, vuln_data: dict[str, list[dict]] | None = None
|
||||
) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Check if a specific package version is vulnerable.
|
||||
|
||||
Args:
|
||||
package_id: The package ID
|
||||
version: The version to check
|
||||
vuln_data: Optional pre-fetched vulnerability data
|
||||
|
||||
Returns:
|
||||
Tuple of (is_vulnerable, list of matching vulnerabilities)
|
||||
"""
|
||||
if vuln_data is None:
|
||||
vuln_data = self.get_vulnerability_data()
|
||||
|
||||
package_lower = package_id.lower()
|
||||
vulns = vuln_data.get(package_lower, [])
|
||||
|
||||
if not vulns:
|
||||
return False, []
|
||||
|
||||
matching = []
|
||||
parsed_version = parse_version(version)
|
||||
if parsed_version is None:
|
||||
return False, []
|
||||
|
||||
for vuln in vulns:
|
||||
# Check version range
|
||||
version_range = vuln.get("versions", "")
|
||||
if self._version_in_range(version, parsed_version, version_range):
|
||||
matching.append(vuln)
|
||||
|
||||
return len(matching) > 0, matching
|
||||
|
||||
def _version_in_range(
|
||||
self, version: str, parsed: tuple, range_str: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a version is in a NuGet version range.
|
||||
|
||||
NuGet range formats:
|
||||
- "[1.0.0, 2.0.0)" - >= 1.0.0 and < 2.0.0
|
||||
- "(, 1.0.0)" - < 1.0.0
|
||||
- "[1.0.0,)" - >= 1.0.0
|
||||
- "1.0.0" - exact match
|
||||
"""
|
||||
if not range_str:
|
||||
return False
|
||||
|
||||
range_str = range_str.strip()
|
||||
|
||||
# Handle exact version
|
||||
if not range_str.startswith(("[", "(")):
|
||||
exact_parsed = parse_version(range_str)
|
||||
return exact_parsed == parsed if exact_parsed else False
|
||||
|
||||
# Parse range
|
||||
match = re.match(r"([\[\(])([^,]*),([^)\]]*)([\)\]])", range_str)
|
||||
if not match:
|
||||
return False
|
||||
|
||||
left_bracket, left_ver, right_ver, right_bracket = match.groups()
|
||||
left_ver = left_ver.strip()
|
||||
right_ver = right_ver.strip()
|
||||
|
||||
# Check lower bound
|
||||
if left_ver:
|
||||
left_parsed = parse_version(left_ver)
|
||||
if left_parsed:
|
||||
if left_bracket == "[":
|
||||
if parsed < left_parsed:
|
||||
return False
|
||||
else: # "("
|
||||
if parsed <= left_parsed:
|
||||
return False
|
||||
|
||||
# Check upper bound
|
||||
if right_ver:
|
||||
right_parsed = parse_version(right_ver)
|
||||
if right_parsed:
|
||||
if right_bracket == "]":
|
||||
if parsed > right_parsed:
|
||||
return False
|
||||
else: # ")"
|
||||
if parsed >= right_parsed:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def find_safe_version(
|
||||
self,
|
||||
package_id: str,
|
||||
current_version: str,
|
||||
prefer_upgrade: bool = True,
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the closest non-vulnerable version.
|
||||
|
||||
Strategy:
|
||||
1. Get all available versions
|
||||
2. Filter out versions with known vulnerabilities
|
||||
3. Prefer: patch upgrade > minor upgrade > major upgrade > downgrade
|
||||
|
||||
Args:
|
||||
package_id: The package ID
|
||||
current_version: Current (vulnerable) version
|
||||
prefer_upgrade: If True, prefer upgrades over downgrades
|
||||
|
||||
Returns:
|
||||
Suggested safe version, or None if not found
|
||||
"""
|
||||
available = self.get_available_versions(package_id)
|
||||
if not available:
|
||||
return None
|
||||
|
||||
vuln_data = self.get_vulnerability_data()
|
||||
current_parsed = parse_version(current_version)
|
||||
if current_parsed is None:
|
||||
return None
|
||||
|
||||
# Find safe versions
|
||||
from .version_utils import ParsedVersion
|
||||
safe_versions: list[tuple[str, ParsedVersion]] = []
|
||||
for version in available:
|
||||
# Skip prereleases unless current is prerelease
|
||||
if not is_stable(version) and is_stable(current_version):
|
||||
continue
|
||||
|
||||
parsed = parse_version(version)
|
||||
if parsed is None:
|
||||
continue
|
||||
|
||||
is_vuln, _ = self.is_version_vulnerable(package_id, version, vuln_data)
|
||||
if not is_vuln:
|
||||
safe_versions.append((version, parsed))
|
||||
|
||||
if not safe_versions:
|
||||
return None
|
||||
|
||||
# Sort by preference: closest upgrade first
|
||||
def sort_key(item: tuple[str, ParsedVersion]) -> tuple:
|
||||
version, parsed = item
|
||||
major_diff = parsed.major - current_parsed.major
|
||||
minor_diff = parsed.minor - current_parsed.minor
|
||||
patch_diff = parsed.patch - current_parsed.patch
|
||||
|
||||
# Prefer upgrades (positive diff) over downgrades
|
||||
# Within upgrades, prefer smaller changes
|
||||
if prefer_upgrade:
|
||||
if major_diff > 0 or (major_diff == 0 and minor_diff > 0) or \
|
||||
(major_diff == 0 and minor_diff == 0 and patch_diff > 0):
|
||||
# Upgrade: prefer smaller version jumps
|
||||
return (0, major_diff, minor_diff, patch_diff)
|
||||
elif major_diff == 0 and minor_diff == 0 and patch_diff == 0:
|
||||
# Same version (shouldn't happen if vulnerable)
|
||||
return (1, 0, 0, 0)
|
||||
else:
|
||||
# Downgrade: prefer smaller version drops
|
||||
return (2, -major_diff, -minor_diff, -patch_diff)
|
||||
else:
|
||||
# Just prefer closest version
|
||||
return (abs(major_diff), abs(minor_diff), abs(patch_diff))
|
||||
|
||||
safe_versions.sort(key=sort_key)
|
||||
return safe_versions[0][0] if safe_versions else None
|
||||
|
||||
def get_fix_risk(
|
||||
self, current_version: str, suggested_version: str
|
||||
) -> str:
|
||||
"""
|
||||
Estimate the risk of upgrading to a suggested version.
|
||||
|
||||
Returns: "low", "medium", or "high"
|
||||
"""
|
||||
current = parse_version(current_version)
|
||||
suggested = parse_version(suggested_version)
|
||||
|
||||
if current is None or suggested is None:
|
||||
return "unknown"
|
||||
|
||||
if suggested.major > current.major:
|
||||
return "high" # Major version change
|
||||
elif suggested.minor > current.minor:
|
||||
return "medium" # Minor version change
|
||||
else:
|
||||
return "low" # Patch or no change
|
||||
Reference in New Issue
Block a user