417 lines
14 KiB
Python
417 lines
14 KiB
Python
"""
NuGet API v3 client for package version and vulnerability queries.
"""
|
|
|
|
import logging
|
|
import re
|
|
from typing import Any
|
|
|
|
# ``requests`` is optional at import time: when it is missing the name is
# bound to None, and NuGetApiClient.__init__ raises an ImportError instead.
try:
    import requests
except ImportError:  # pragma: no cover - depends on the installed packages
    requests = None  # type: ignore
|
|
|
|
from .version_utils import parse_version, is_stable
|
|
|
|
logger = logging.getLogger(__name__)

# nuget.org v3 service index URL.
NUGET_SERVICE_INDEX = "https://api.nuget.org/v3/index.json"
# nuget.org vulnerability index; its entries carry the "@id" URLs of the
# vulnerability data pages.
NUGET_VULN_INDEX = "https://api.nuget.org/v3/vulnerabilities/index.json"
|
|
|
|
|
|
class NuGetApiError(Exception):
    """Raised when the NuGet API cannot be reached or lacks a required resource."""
|
|
|
|
|
|
class NuGetApiClient:
|
|
"""
|
|
Client for NuGet API v3 operations.
|
|
|
|
Provides methods for:
|
|
- Fetching available package versions
|
|
- Fetching vulnerability data
|
|
- Finding non-vulnerable versions
|
|
"""
|
|
|
|
def __init__(self, source: str = "https://api.nuget.org/v3"):
|
|
if requests is None:
|
|
raise ImportError(
|
|
"requests library is required for NuGet API access. "
|
|
"Install with: pip install requests"
|
|
)
|
|
|
|
self.source = source.rstrip("/")
|
|
self._session = requests.Session()
|
|
self._session.headers.update(
|
|
{"User-Agent": "StellaOps-NuGetVulnChecker/1.0"}
|
|
)
|
|
self._service_index: dict | None = None
|
|
self._vuln_cache: dict[str, list[dict]] | None = None
|
|
self._search_url: str | None = None
|
|
self._registration_url: str | None = None
|
|
|
|
def _get_service_index(self) -> dict:
|
|
"""Fetch and cache the NuGet service index."""
|
|
if self._service_index is not None:
|
|
return self._service_index
|
|
|
|
try:
|
|
response = self._session.get(f"{self.source}/index.json", timeout=30)
|
|
response.raise_for_status()
|
|
self._service_index = response.json()
|
|
return self._service_index
|
|
except Exception as e:
|
|
raise NuGetApiError(f"Failed to fetch NuGet service index: {e}")
|
|
|
|
def _get_search_url(self) -> str:
|
|
"""Get the SearchQueryService URL from service index."""
|
|
if self._search_url:
|
|
return self._search_url
|
|
|
|
index = self._get_service_index()
|
|
resources = index.get("resources", [])
|
|
|
|
# Look for SearchQueryService
|
|
for resource in resources:
|
|
resource_type = resource.get("@type", "")
|
|
if "SearchQueryService" in resource_type:
|
|
self._search_url = resource.get("@id", "")
|
|
return self._search_url
|
|
|
|
raise NuGetApiError("SearchQueryService not found in service index")
|
|
|
|
def _get_registration_url(self) -> str:
|
|
"""Get the RegistrationsBaseUrl from service index."""
|
|
if self._registration_url:
|
|
return self._registration_url
|
|
|
|
index = self._get_service_index()
|
|
resources = index.get("resources", [])
|
|
|
|
# Look for RegistrationsBaseUrl
|
|
for resource in resources:
|
|
resource_type = resource.get("@type", "")
|
|
if "RegistrationsBaseUrl" in resource_type:
|
|
self._registration_url = resource.get("@id", "").rstrip("/")
|
|
return self._registration_url
|
|
|
|
raise NuGetApiError("RegistrationsBaseUrl not found in service index")
|
|
|
|
def get_available_versions(self, package_id: str) -> list[str]:
|
|
"""
|
|
Fetch all available versions of a package from NuGet.
|
|
|
|
Args:
|
|
package_id: The NuGet package ID
|
|
|
|
Returns:
|
|
List of version strings, sorted newest first
|
|
"""
|
|
try:
|
|
# Use registration API for complete version list
|
|
reg_url = self._get_registration_url()
|
|
package_lower = package_id.lower()
|
|
url = f"{reg_url}/{package_lower}/index.json"
|
|
|
|
response = self._session.get(url, timeout=30)
|
|
if response.status_code == 404:
|
|
logger.warning(f"Package not found on NuGet: {package_id}")
|
|
return []
|
|
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
|
|
versions = []
|
|
for page in data.get("items", []):
|
|
# Pages may be inline or require fetching
|
|
if "items" in page:
|
|
items = page["items"]
|
|
else:
|
|
# Fetch the page
|
|
page_url = page.get("@id")
|
|
if page_url:
|
|
page_response = self._session.get(page_url, timeout=30)
|
|
page_response.raise_for_status()
|
|
items = page_response.json().get("items", [])
|
|
else:
|
|
items = []
|
|
|
|
for item in items:
|
|
catalog_entry = item.get("catalogEntry", {})
|
|
version = catalog_entry.get("version")
|
|
if version:
|
|
versions.append(version)
|
|
|
|
# Sort by parsed version, newest first
|
|
def sort_key(v: str) -> tuple:
|
|
parsed = parse_version(v)
|
|
if parsed is None:
|
|
return (0, 0, 0, "")
|
|
return parsed
|
|
|
|
versions.sort(key=sort_key, reverse=True)
|
|
return versions
|
|
|
|
except NuGetApiError:
|
|
raise
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch versions for {package_id}: {e}")
|
|
return []
|
|
|
|
def get_vulnerability_data(self) -> dict[str, list[dict]]:
|
|
"""
|
|
Fetch vulnerability data from NuGet VulnerabilityInfo API.
|
|
|
|
Returns:
|
|
Dictionary mapping lowercase package ID to list of vulnerability info dicts.
|
|
Each dict contains: severity, advisory_url, versions (affected range)
|
|
"""
|
|
if self._vuln_cache is not None:
|
|
return self._vuln_cache
|
|
|
|
try:
|
|
# Fetch vulnerability index
|
|
response = self._session.get(NUGET_VULN_INDEX, timeout=30)
|
|
response.raise_for_status()
|
|
index = response.json()
|
|
|
|
vuln_map: dict[str, list[dict]] = {}
|
|
|
|
# Fetch each vulnerability page
|
|
for page_info in index:
|
|
page_url = page_info.get("@id")
|
|
if not page_url:
|
|
continue
|
|
|
|
try:
|
|
page_response = self._session.get(page_url, timeout=60)
|
|
page_response.raise_for_status()
|
|
page_data = page_response.json()
|
|
|
|
# Parse vulnerability entries
|
|
self._merge_vuln_data(vuln_map, page_data)
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch vulnerability page {page_url}: {e}")
|
|
continue
|
|
|
|
self._vuln_cache = vuln_map
|
|
logger.info(f"Loaded vulnerability data for {len(vuln_map)} packages")
|
|
return vuln_map
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch vulnerability data: {e}")
|
|
return {}
|
|
|
|
def _merge_vuln_data(
|
|
self, vuln_map: dict[str, list[dict]], page_data: Any
|
|
) -> None:
|
|
"""Merge vulnerability page data into the vulnerability map."""
|
|
# The vulnerability data format is a dict mapping package ID (lowercase)
|
|
# to list of vulnerability objects
|
|
if not isinstance(page_data, dict):
|
|
return
|
|
|
|
for package_id, vulns in page_data.items():
|
|
if package_id.startswith("@"):
|
|
# Skip metadata fields like @context
|
|
continue
|
|
|
|
package_lower = package_id.lower()
|
|
if package_lower not in vuln_map:
|
|
vuln_map[package_lower] = []
|
|
|
|
if isinstance(vulns, list):
|
|
vuln_map[package_lower].extend(vulns)
|
|
|
|
def is_version_vulnerable(
|
|
self, package_id: str, version: str, vuln_data: dict[str, list[dict]] | None = None
|
|
) -> tuple[bool, list[dict]]:
|
|
"""
|
|
Check if a specific package version is vulnerable.
|
|
|
|
Args:
|
|
package_id: The package ID
|
|
version: The version to check
|
|
vuln_data: Optional pre-fetched vulnerability data
|
|
|
|
Returns:
|
|
Tuple of (is_vulnerable, list of matching vulnerabilities)
|
|
"""
|
|
if vuln_data is None:
|
|
vuln_data = self.get_vulnerability_data()
|
|
|
|
package_lower = package_id.lower()
|
|
vulns = vuln_data.get(package_lower, [])
|
|
|
|
if not vulns:
|
|
return False, []
|
|
|
|
matching = []
|
|
parsed_version = parse_version(version)
|
|
if parsed_version is None:
|
|
return False, []
|
|
|
|
for vuln in vulns:
|
|
# Check version range
|
|
version_range = vuln.get("versions", "")
|
|
if self._version_in_range(version, parsed_version, version_range):
|
|
matching.append(vuln)
|
|
|
|
return len(matching) > 0, matching
|
|
|
|
def _version_in_range(
|
|
self, version: str, parsed: tuple, range_str: str
|
|
) -> bool:
|
|
"""
|
|
Check if a version is in a NuGet version range.
|
|
|
|
NuGet range formats:
|
|
- "[1.0.0, 2.0.0)" - >= 1.0.0 and < 2.0.0
|
|
- "(, 1.0.0)" - < 1.0.0
|
|
- "[1.0.0,)" - >= 1.0.0
|
|
- "1.0.0" - exact match
|
|
"""
|
|
if not range_str:
|
|
return False
|
|
|
|
range_str = range_str.strip()
|
|
|
|
# Handle exact version
|
|
if not range_str.startswith(("[", "(")):
|
|
exact_parsed = parse_version(range_str)
|
|
return exact_parsed == parsed if exact_parsed else False
|
|
|
|
# Parse range
|
|
match = re.match(r"([\[\(])([^,]*),([^)\]]*)([\)\]])", range_str)
|
|
if not match:
|
|
return False
|
|
|
|
left_bracket, left_ver, right_ver, right_bracket = match.groups()
|
|
left_ver = left_ver.strip()
|
|
right_ver = right_ver.strip()
|
|
|
|
# Check lower bound
|
|
if left_ver:
|
|
left_parsed = parse_version(left_ver)
|
|
if left_parsed:
|
|
if left_bracket == "[":
|
|
if parsed < left_parsed:
|
|
return False
|
|
else: # "("
|
|
if parsed <= left_parsed:
|
|
return False
|
|
|
|
# Check upper bound
|
|
if right_ver:
|
|
right_parsed = parse_version(right_ver)
|
|
if right_parsed:
|
|
if right_bracket == "]":
|
|
if parsed > right_parsed:
|
|
return False
|
|
else: # ")"
|
|
if parsed >= right_parsed:
|
|
return False
|
|
|
|
return True
|
|
|
|
def find_safe_version(
|
|
self,
|
|
package_id: str,
|
|
current_version: str,
|
|
prefer_upgrade: bool = True,
|
|
) -> str | None:
|
|
"""
|
|
Find the closest non-vulnerable version.
|
|
|
|
Strategy:
|
|
1. Get all available versions
|
|
2. Filter out versions with known vulnerabilities
|
|
3. Prefer: patch upgrade > minor upgrade > major upgrade > downgrade
|
|
|
|
Args:
|
|
package_id: The package ID
|
|
current_version: Current (vulnerable) version
|
|
prefer_upgrade: If True, prefer upgrades over downgrades
|
|
|
|
Returns:
|
|
Suggested safe version, or None if not found
|
|
"""
|
|
available = self.get_available_versions(package_id)
|
|
if not available:
|
|
return None
|
|
|
|
vuln_data = self.get_vulnerability_data()
|
|
current_parsed = parse_version(current_version)
|
|
if current_parsed is None:
|
|
return None
|
|
|
|
# Find safe versions
|
|
from .version_utils import ParsedVersion
|
|
safe_versions: list[tuple[str, ParsedVersion]] = []
|
|
for version in available:
|
|
# Skip prereleases unless current is prerelease
|
|
if not is_stable(version) and is_stable(current_version):
|
|
continue
|
|
|
|
parsed = parse_version(version)
|
|
if parsed is None:
|
|
continue
|
|
|
|
is_vuln, _ = self.is_version_vulnerable(package_id, version, vuln_data)
|
|
if not is_vuln:
|
|
safe_versions.append((version, parsed))
|
|
|
|
if not safe_versions:
|
|
return None
|
|
|
|
# Sort by preference: closest upgrade first
|
|
def sort_key(item: tuple[str, ParsedVersion]) -> tuple:
|
|
version, parsed = item
|
|
major_diff = parsed.major - current_parsed.major
|
|
minor_diff = parsed.minor - current_parsed.minor
|
|
patch_diff = parsed.patch - current_parsed.patch
|
|
|
|
# Prefer upgrades (positive diff) over downgrades
|
|
# Within upgrades, prefer smaller changes
|
|
if prefer_upgrade:
|
|
if major_diff > 0 or (major_diff == 0 and minor_diff > 0) or \
|
|
(major_diff == 0 and minor_diff == 0 and patch_diff > 0):
|
|
# Upgrade: prefer smaller version jumps
|
|
return (0, major_diff, minor_diff, patch_diff)
|
|
elif major_diff == 0 and minor_diff == 0 and patch_diff == 0:
|
|
# Same version (shouldn't happen if vulnerable)
|
|
return (1, 0, 0, 0)
|
|
else:
|
|
# Downgrade: prefer smaller version drops
|
|
return (2, -major_diff, -minor_diff, -patch_diff)
|
|
else:
|
|
# Just prefer closest version
|
|
return (abs(major_diff), abs(minor_diff), abs(patch_diff))
|
|
|
|
safe_versions.sort(key=sort_key)
|
|
return safe_versions[0][0] if safe_versions else None
|
|
|
|
def get_fix_risk(
|
|
self, current_version: str, suggested_version: str
|
|
) -> str:
|
|
"""
|
|
Estimate the risk of upgrading to a suggested version.
|
|
|
|
Returns: "low", "medium", or "high"
|
|
"""
|
|
current = parse_version(current_version)
|
|
suggested = parse_version(suggested_version)
|
|
|
|
if current is None or suggested is None:
|
|
return "unknown"
|
|
|
|
if suggested.major > current.major:
|
|
return "high" # Major version change
|
|
elif suggested.minor > current.minor:
|
|
return "medium" # Minor version change
|
|
else:
|
|
return "low" # Patch or no change
|