using System.Diagnostics;
using System.Net;
using System.Net.Http.Json;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Policy.Engine.Telemetry;

namespace StellaOps.Policy.Engine.ReachabilityFacts;

/// <summary>
/// HTTP client for fetching reachability facts from the Signals service.
/// </summary>
/// <remarks>
/// Subgraph slices are fetched from the ReachGraph Store. The client uses
/// <see cref="ReachabilityFactsSignalsClientOptions.ReachGraphStoreBaseUri"/> when it is configured.
/// Otherwise it goes through the Signals base address under the <c>reachgraph/</c> prefix.
/// </remarks>
public sealed class ReachabilityFactsSignalsClient : IReachabilityFactsSignalsClient
{
    // JsonSerializerDefaults.Web already implies camelCase names and case-insensitive matching.
    // The explicit settings only make the wire contract visible at a glance.
    private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web)
    {
        PropertyNameCaseInsensitive = true,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
    };

    private readonly HttpClient _httpClient;
    private readonly ReachabilityFactsSignalsClientOptions _options;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates the client and prepares <paramref name="httpClient"/> for JSON requests to Signals.
    /// </summary>
    /// <param name="httpClient">Client used for all requests. Its default Accept header is replaced with <c>application/json</c>.</param>
    /// <param name="options">Client options. <c>BaseUri</c> is applied only if the client has no base address yet.</param>
    /// <param name="logger">Logger for diagnostics.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public ReachabilityFactsSignalsClient(
        HttpClient httpClient,
        IOptions options,
        ILogger logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        ArgumentNullException.ThrowIfNull(options);
        _options = options.Value;

        // Respect a base address that was already set on the injected client.
        if (_httpClient.BaseAddress is null && _options.BaseUri is not null)
        {
            _httpClient.BaseAddress = _options.BaseUri;
        }

        _httpClient.DefaultRequestHeaders.Accept.Clear();
        _httpClient.DefaultRequestHeaders.Accept.ParseAdd("application/json");
    }

    /// <inheritdoc />
    /// <remarks>
    /// Issues <c>GET signals/facts/{subjectKey}</c> and returns <c>null</c> when Signals answers 404.
    /// Any other non-success status raises <see cref="HttpRequestException"/>.
    /// Failures are logged and rethrown. Caller-requested cancellation is rethrown without an error log.
    /// </remarks>
    public async Task GetBySubjectAsync(
        string subjectKey,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(subjectKey);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_fact",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", subjectKey);

        var path = $"signals/facts/{Uri.EscapeDataString(subjectKey)}";

        try
        {
            // Dispose the response so the underlying connection is released promptly.
            using var response = await _httpClient.GetAsync(path, cancellationToken).ConfigureAwait(false);

            if (response.StatusCode == HttpStatusCode.NotFound)
            {
                _logger.LogDebug("Reachability fact not found for subject {SubjectKey}", subjectKey);
                return null;
            }

            response.EnsureSuccessStatusCode();

            var fact = await response.Content
                .ReadFromJsonAsync(SerializerOptions, cancellationToken)
                .ConfigureAwait(false);

            _logger.LogDebug(
                "Retrieved reachability fact for subject {SubjectKey}: score={Score}, states={StateCount}",
                subjectKey,
                fact?.Score,
                fact?.States?.Count ?? 0);

            return fact;
        }
        catch (HttpRequestException ex) when (ex.StatusCode == HttpStatusCode.NotFound)
        {
            return null;
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // The caller abandoned the request; that is not a service failure worth an error log.
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get reachability fact for subject {SubjectKey}", subjectKey);
            throw;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Signals has no batch endpoint, so each distinct key is fetched individually.
    /// At most <c>MaxConcurrentRequests</c> fetches run at once, and values below 1 are treated as 1.
    /// Subjects without a fact are omitted from the result.
    /// If any single fetch fails, the whole batch fails.
    /// </remarks>
    public async Task> GetBatchBySubjectsAsync(
        IReadOnlyList subjectKeys,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(subjectKeys);

        if (subjectKeys.Count == 0)
        {
            return new Dictionary();
        }

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_facts_batch",
            ActivityKind.Client);
        activity?.SetTag("signals.batch_size", subjectKeys.Count);

        var result = new Dictionary(StringComparer.Ordinal);

        // Signals doesn't expose a batch endpoint, so we fetch in parallel with a concurrency limit.
        // A limit of 0 would make every WaitAsync block forever, and a negative one throws, so clamp to 1.
        // Task.WhenAll below waits for every task, so no Release can run after this is disposed.
        using var semaphore = new SemaphoreSlim(Math.Max(1, _options.MaxConcurrentRequests));

        // Duplicate keys land on the same dictionary entry anyway, so request each subject only once.
        var tasks = subjectKeys
            .Distinct(StringComparer.Ordinal)
            .Select(async key =>
            {
                await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
                try
                {
                    var fact = await GetBySubjectAsync(key, cancellationToken).ConfigureAwait(false);
                    return (Key: key, Fact: fact);
                }
                finally
                {
                    semaphore.Release();
                }
            });

        var results = await Task.WhenAll(tasks).ConfigureAwait(false);

        foreach (var (key, fact) in results)
        {
            if (fact is not null)
            {
                result[key] = fact;
            }
        }

        activity?.SetTag("signals.found_count", result.Count);
        _logger.LogDebug(
            "Batch retrieved {FoundCount}/{TotalCount} reachability facts",
            result.Count,
            subjectKeys.Count);

        return result;
    }

    /// <inheritdoc />
    /// <remarks>
    /// POSTs <c>{ subjectKey, tenantId }</c> to <c>signals/reachability/recompute</c>.
    /// This is best-effort: a non-success status or a transport/serialization error is logged and yields <c>false</c>.
    /// Cancellation requested through <paramref name="cancellationToken"/> is propagated instead of being reported as <c>false</c>.
    /// </remarks>
    public async Task TriggerRecomputeAsync(
        SignalsRecomputeRequest request,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(request);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.trigger_recompute",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", request.SubjectKey);
        activity?.SetTag("signals.tenant_id", request.TenantId);

        try
        {
            using var response = await _httpClient.PostAsJsonAsync(
                "signals/reachability/recompute",
                new { subjectKey = request.SubjectKey, tenantId = request.TenantId },
                SerializerOptions,
                cancellationToken).ConfigureAwait(false);

            if (response.IsSuccessStatusCode)
            {
                _logger.LogInformation(
                    "Triggered reachability recompute for subject {SubjectKey}",
                    request.SubjectKey);
                return true;
            }

            _logger.LogWarning(
                "Failed to trigger reachability recompute for subject {SubjectKey}: {StatusCode}",
                request.SubjectKey,
                response.StatusCode);
            return false;
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // Best-effort covers service failures, not the caller abandoning the operation.
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogError(
                ex,
                "Error triggering reachability recompute for subject {SubjectKey}",
                request.SubjectKey);
            return false;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <c>null</c> when Signals has no fact for the subject.
    /// Otherwise the fact is always returned.
    /// The subgraph slice is attached only when the fact has a callgraph ID and the slice fetch succeeds.
    /// Slice failures are logged and degrade to a <c>null</c> slice. These include non-success status,
    /// transport errors, malformed JSON, and HTTP timeouts not requested by the caller.
    /// </remarks>
    public async Task GetWithSubgraphAsync(
        string subjectKey,
        string? cveId = null,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(subjectKey);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_fact_with_subgraph",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", subjectKey);
        if (cveId is not null)
        {
            activity?.SetTag("signals.cve_id", cveId);
        }

        // Get base reachability fact from Signals.
        var fact = await GetBySubjectAsync(subjectKey, cancellationToken).ConfigureAwait(false);
        if (fact is null)
        {
            _logger.LogDebug("No reachability fact found for subject {SubjectKey}", subjectKey);
            return null;
        }

        if (string.IsNullOrEmpty(fact.CallgraphId))
        {
            _logger.LogDebug(
                "Reachability fact for subject {SubjectKey} has no callgraph ID",
                subjectKey);
            return new ReachabilityFactWithSubgraph(fact, null);
        }

        // Built outside the try so a misconfigured (non-absolute) store URI surfaces loudly.
        var sliceUri = BuildSliceUri(fact.CallgraphId, cveId);

        try
        {
            using var response = await _httpClient.GetAsync(sliceUri, cancellationToken).ConfigureAwait(false);
            if (!response.IsSuccessStatusCode)
            {
                _logger.LogWarning(
                    "Failed to fetch subgraph slice for callgraph {CallgraphId}: {StatusCode}",
                    fact.CallgraphId,
                    response.StatusCode);
                return new ReachabilityFactWithSubgraph(fact, null);
            }

            var slice = await response.Content
                .ReadFromJsonAsync(SerializerOptions, cancellationToken)
                .ConfigureAwait(false);

            _logger.LogDebug(
                "Fetched subgraph slice for callgraph {CallgraphId}: {NodeCount} nodes, {PathCount} paths",
                fact.CallgraphId,
                slice?.NodeCount ?? 0,
                slice?.Paths?.Count ?? 0);

            return new ReachabilityFactWithSubgraph(fact, slice);
        }
        catch (Exception ex) when (
            ex is HttpRequestException or JsonException
            || (ex is OperationCanceledException && !cancellationToken.IsCancellationRequested))
        {
            // The slice is optional enrichment: a transport error, a malformed payload or an HTTP timeout
            // should not discard the fact we already have. Caller cancellation still propagates.
            _logger.LogWarning(
                ex,
                "Error fetching subgraph slice for callgraph {CallgraphId}",
                fact.CallgraphId);
            return new ReachabilityFactWithSubgraph(fact, null);
        }
    }

    /// <summary>
    /// Builds the ReachGraph Store slice URI for <paramref name="callgraphId"/>, optionally filtered by CVE.
    /// </summary>
    /// <remarks>
    /// When <see cref="ReachabilityFactsSignalsClientOptions.ReachGraphStoreBaseUri"/> is set, the result is an
    /// absolute URI under that base. <see cref="HttpClient"/> sends absolute URIs as-is, ignoring its Signals base address.
    /// Otherwise the result is the relative path <c>reachgraph/v1/reachgraphs/...</c>, which is resolved
    /// against the Signals base address.
    /// </remarks>
    /// <exception cref="ArgumentOutOfRangeException">The configured store URI is not absolute.</exception>
    private Uri BuildSliceUri(string callgraphId, string? cveId)
    {
        var sliceQuery = cveId is not null ? $"?cve={Uri.EscapeDataString(cveId)}" : "";
        var storeRelativePath = $"v1/reachgraphs/{Uri.EscapeDataString(callgraphId)}/slice{sliceQuery}";

        return _options.ReachGraphStoreBaseUri is { } storeBaseUri
            ? new Uri(storeBaseUri, storeRelativePath)
            : new Uri($"reachgraph/{storeRelativePath}", UriKind.Relative);
    }
}

/// <summary>
/// Configuration options for the Signals reachability client.
/// </summary>
public sealed class ReachabilityFactsSignalsClientOptions
{
    /// <summary>
    /// Configuration section name.
    /// </summary>
    public const string SectionName = "ReachabilitySignals";

    /// <summary>
    /// Base URI for the Signals service.
    /// It is applied to the HTTP client only if the client has no base address already.
    /// </summary>
    public Uri? BaseUri { get; set; }

    /// <summary>
    /// Base URI for the ReachGraph Store service. It must be absolute.
    /// End it with a trailing slash if it carries a path, otherwise the last path segment is replaced during resolution.
    /// If null, slice requests go through the Signals base address under the <c>reachgraph/</c> prefix.
    /// </summary>
    public Uri? ReachGraphStoreBaseUri { get; set; }

    /// <summary>
    /// Maximum concurrent requests for batch operations.
    /// Values below 1 are treated as 1.
    /// Default: 10.
    /// </summary>
    public int MaxConcurrentRequests { get; set; } = 10;

    /// <summary>
    /// Request timeout.
    /// Default: 30 seconds.
    /// </summary>
    /// <remarks>
    /// NOTE(review): not read by ReachabilityFactsSignalsClient itself; presumably applied where the
    /// HttpClient is registered. Confirm.
    /// </remarks>
    public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(30);

    /// <summary>
    /// Retry count for transient failures.
    /// Default: 3.
    /// </summary>
    /// <remarks>
    /// NOTE(review): not read by ReachabilityFactsSignalsClient itself; presumably consumed by a retry
    /// policy configured elsewhere. Confirm.
    /// </remarks>
    public int RetryCount { get; set; } = 3;
}