using System.Diagnostics;
using System.Net;
using System.Net.Http.Json;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Policy.Engine.Telemetry;
namespace StellaOps.Policy.Engine.ReachabilityFacts;
/// <summary>
/// HTTP client for fetching reachability facts from Signals service.
/// </summary>
/// <remarks>
/// NOTE(review): several generic type arguments (for example on <c>Task</c>, <c>IOptions</c>,
/// <c>Dictionary</c> and <c>ReadFromJsonAsync</c>) are missing from this listing. They are kept
/// exactly as found because the payload type names are not visible here; restore them from the
/// interface declaration.
/// </remarks>
public sealed class ReachabilityFactsSignalsClient : IReachabilityFactsSignalsClient
{
    private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web)
    {
        PropertyNameCaseInsensitive = true,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
    };

    private readonly HttpClient _httpClient;
    private readonly ReachabilityFactsSignalsClientOptions _options;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates the client on top of an externally supplied <see cref="HttpClient"/>.
    /// </summary>
    /// <param name="httpClient">HTTP client used for every request issued by this class.</param>
    /// <param name="options">Client options; <c>BaseUri</c> is applied only when the client has no base address yet.</param>
    /// <param name="logger">Logger for diagnostics.</param>
    /// <exception cref="ArgumentNullException">Any argument is <see langword="null"/>.</exception>
    public ReachabilityFactsSignalsClient(
        HttpClient httpClient,
        IOptions options,
        ILogger logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        ArgumentNullException.ThrowIfNull(options);
        _options = options.Value;

        // A base address already configured on the client wins over the options value.
        if (_httpClient.BaseAddress is null && _options.BaseUri is not null)
        {
            _httpClient.BaseAddress = _options.BaseUri;
        }

        _httpClient.DefaultRequestHeaders.Accept.Clear();
        _httpClient.DefaultRequestHeaders.Accept.ParseAdd("application/json");
    }

    /// <summary>
    /// Fetches the reachability fact for a single subject from <c>signals/facts/{subjectKey}</c>.
    /// </summary>
    /// <param name="subjectKey">Subject identifier; it is URL-escaped before being placed in the path.</param>
    /// <param name="cancellationToken">Token used to cancel the request.</param>
    /// <returns>The deserialized fact, or <see langword="null"/> when the service answers 404 Not Found.</returns>
    /// <exception cref="ArgumentException"><paramref name="subjectKey"/> is null, empty or whitespace.</exception>
    /// <exception cref="HttpRequestException">The service answered with a non-success status other than 404.</exception>
    public async Task GetBySubjectAsync(
        string subjectKey,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(subjectKey);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_fact",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", subjectKey);

        var path = $"signals/facts/{Uri.EscapeDataString(subjectKey)}";

        try
        {
            // Dispose the response so the underlying connection is released promptly.
            using var response = await _httpClient.GetAsync(path, cancellationToken).ConfigureAwait(false);

            if (response.StatusCode == HttpStatusCode.NotFound)
            {
                _logger.LogDebug("Reachability fact not found for subject {SubjectKey}", subjectKey);
                return null;
            }

            response.EnsureSuccessStatusCode();

            var fact = await response.Content
                .ReadFromJsonAsync(SerializerOptions, cancellationToken)
                .ConfigureAwait(false);

            _logger.LogDebug(
                "Retrieved reachability fact for subject {SubjectKey}: score={Score}, states={StateCount}",
                subjectKey,
                fact?.Score,
                fact?.States?.Count ?? 0);

            return fact;
        }
        catch (HttpRequestException ex) when (ex.StatusCode == HttpStatusCode.NotFound)
        {
            // Covers a 404 surfaced as an exception (e.g. by a delegating handler) instead of a response.
            return null;
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // Caller-requested cancellation is not a failure; propagate it without an error log entry.
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get reachability fact for subject {SubjectKey}", subjectKey);
            throw;
        }
    }

    /// <summary>
    /// Fetches facts for several subjects by issuing one <see cref="GetBySubjectAsync"/> call per
    /// distinct key, with a bounded number of calls in flight.
    /// </summary>
    /// <param name="subjectKeys">Subject identifiers to look up; duplicates are requested only once.</param>
    /// <param name="cancellationToken">Token used to cancel the pending requests.</param>
    /// <returns>A dictionary keyed by subject (ordinal comparison) containing only the subjects for which a fact was found.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="subjectKeys"/> is <see langword="null"/>.</exception>
    public async Task> GetBatchBySubjectsAsync(
        IReadOnlyList subjectKeys,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(subjectKeys);

        if (subjectKeys.Count == 0)
        {
            return new Dictionary();
        }

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_facts_batch",
            ActivityKind.Client);
        activity?.SetTag("signals.batch_size", subjectKeys.Count);

        var result = new Dictionary(StringComparer.Ordinal);

        // Signals doesn't expose a batch endpoint, so we fetch in parallel with concurrency limit.
        // Clamp to at least one permit: zero permits would make every request wait forever and a
        // negative count is rejected by SemaphoreSlim.
        var maxConcurrency = Math.Max(1, _options.MaxConcurrentRequests);
        using var semaphore = new SemaphoreSlim(maxConcurrency);

        // The result is keyed by ordinal subject key, so a repeated key needs only one request.
        var tasks = subjectKeys
            .Distinct(StringComparer.Ordinal)
            .Select(async key =>
            {
                await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
                try
                {
                    var fact = await GetBySubjectAsync(key, cancellationToken).ConfigureAwait(false);
                    return (Key: key, Fact: fact);
                }
                finally
                {
                    semaphore.Release();
                }
            });

        // WhenAll completes only after every task has finished, so disposing the semaphore afterwards is safe.
        var results = await Task.WhenAll(tasks).ConfigureAwait(false);

        foreach (var (key, fact) in results)
        {
            if (fact is not null)
            {
                result[key] = fact;
            }
        }

        activity?.SetTag("signals.found_count", result.Count);

        _logger.LogDebug(
            "Batch retrieved {FoundCount}/{TotalCount} reachability facts",
            result.Count,
            subjectKeys.Count);

        return result;
    }

    /// <summary>
    /// Asks Signals to recompute reachability by posting to <c>signals/reachability/recompute</c>.
    /// </summary>
    /// <param name="request">Recompute request; its subject key and tenant id are sent in the body.</param>
    /// <param name="cancellationToken">Token used to cancel the request.</param>
    /// <returns>
    /// <see langword="true"/> when the service answers with a success status; <see langword="false"/>
    /// for any other status or when the call throws.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <see langword="null"/>.</exception>
    public async Task TriggerRecomputeAsync(
        SignalsRecomputeRequest request,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(request);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.trigger_recompute",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", request.SubjectKey);
        activity?.SetTag("signals.tenant_id", request.TenantId);

        try
        {
            using var response = await _httpClient.PostAsJsonAsync(
                "signals/reachability/recompute",
                new { subjectKey = request.SubjectKey, tenantId = request.TenantId },
                SerializerOptions,
                cancellationToken).ConfigureAwait(false);

            if (response.IsSuccessStatusCode)
            {
                _logger.LogInformation(
                    "Triggered reachability recompute for subject {SubjectKey}",
                    request.SubjectKey);
                return true;
            }

            _logger.LogWarning(
                "Failed to trigger reachability recompute for subject {SubjectKey}: {StatusCode}",
                request.SubjectKey,
                response.StatusCode);
            return false;
        }
        catch (Exception ex)
        {
            // Best-effort by design: every failure, including cancellation, is reported as false.
            _logger.LogError(
                ex,
                "Error triggering reachability recompute for subject {SubjectKey}",
                request.SubjectKey);
            return false;
        }
    }

    /// <summary>
    /// Fetches the reachability fact for a subject and, when the fact carries a callgraph id,
    /// the matching subgraph slice from the ReachGraph store.
    /// </summary>
    /// <param name="subjectKey">Subject identifier.</param>
    /// <param name="cveId">Optional CVE id; when given it is sent as the <c>cve</c> query parameter of the slice request.</param>
    /// <param name="cancellationToken">Token used to cancel the requests.</param>
    /// <returns>
    /// <see langword="null"/> when no fact exists; otherwise the fact paired with its slice, or with
    /// <see langword="null"/> when the fact has no callgraph id or the slice could not be fetched or parsed.
    /// </returns>
    /// <exception cref="ArgumentException"><paramref name="subjectKey"/> is null, empty or whitespace.</exception>
    public async Task GetWithSubgraphAsync(
        string subjectKey,
        string? cveId = null,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(subjectKey);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_fact_with_subgraph",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", subjectKey);
        if (cveId is not null)
        {
            activity?.SetTag("signals.cve_id", cveId);
        }

        // Get base reachability fact from Signals
        var fact = await GetBySubjectAsync(subjectKey, cancellationToken).ConfigureAwait(false);
        if (fact is null)
        {
            _logger.LogDebug("No reachability fact found for subject {SubjectKey}", subjectKey);
            return null;
        }

        if (string.IsNullOrEmpty(fact.CallgraphId))
        {
            _logger.LogDebug(
                "Reachability fact for subject {SubjectKey} has no callgraph ID",
                subjectKey);
            return new ReachabilityFactWithSubgraph(fact, null);
        }

        // Fetch subgraph slice from ReachGraph Store; a failure here must not discard the fact.
        try
        {
            var sliceUri = BuildSliceUri(fact.CallgraphId, cveId);

            using var response = await _httpClient.GetAsync(sliceUri, cancellationToken).ConfigureAwait(false);

            if (!response.IsSuccessStatusCode)
            {
                _logger.LogWarning(
                    "Failed to fetch subgraph slice for callgraph {CallgraphId}: {StatusCode}",
                    fact.CallgraphId,
                    response.StatusCode);
                return new ReachabilityFactWithSubgraph(fact, null);
            }

            var slice = await response.Content
                .ReadFromJsonAsync(SerializerOptions, cancellationToken)
                .ConfigureAwait(false);

            _logger.LogDebug(
                "Fetched subgraph slice for callgraph {CallgraphId}: {NodeCount} nodes, {PathCount} paths",
                fact.CallgraphId,
                slice?.NodeCount ?? 0,
                slice?.Paths?.Count ?? 0);

            return new ReachabilityFactWithSubgraph(fact, slice);
        }
        catch (Exception ex) when (ex is HttpRequestException or JsonException)
        {
            // Transport errors and malformed slice payloads both degrade to "fact without subgraph".
            _logger.LogWarning(
                ex,
                "Error fetching subgraph slice for callgraph {CallgraphId}",
                fact.CallgraphId);
            return new ReachabilityFactWithSubgraph(fact, null);
        }
    }

    /// <summary>
    /// Builds the request URI for a callgraph slice.
    /// </summary>
    /// <remarks>
    /// With an absolute <c>ReachGraphStoreBaseUri</c> the result is an absolute URI on that host, so the
    /// request bypasses the HTTP client's base address. As with <see cref="HttpClient.BaseAddress"/>,
    /// the configured base should end with '/' if it contains a path. Without a store base URI the
    /// slice is requested under the <c>reachgraph/</c> prefix relative to the client's base address.
    /// </remarks>
    private Uri BuildSliceUri(string callgraphId, string? cveId)
    {
        var query = cveId is not null
            ? $"?cve={Uri.EscapeDataString(cveId)}"
            : string.Empty;
        var slicePath = $"v1/reachgraphs/{Uri.EscapeDataString(callgraphId)}/slice{query}";

        var storeBase = _options.ReachGraphStoreBaseUri;
        if (storeBase is null)
        {
            return new Uri($"reachgraph/{slicePath}", UriKind.Relative);
        }

        // A non-absolute store URI cannot act as a base; keep the path relative to the client's base address.
        return storeBase.IsAbsoluteUri
            ? new Uri(storeBase, slicePath)
            : new Uri(slicePath, UriKind.Relative);
    }
}
/// <summary>
/// Settings that configure the Signals reachability client.
/// </summary>
public sealed class ReachabilityFactsSignalsClientOptions
{
    /// <summary>
    /// Name of the configuration section these options are read from.
    /// </summary>
    public const string SectionName = "ReachabilitySignals";

    private const int DefaultMaxConcurrentRequests = 10;
    private const int DefaultRetryCount = 3;
    private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(30);

    /// <summary>
    /// Base URI of the Signals service.
    /// </summary>
    public Uri? BaseUri { get; set; }

    /// <summary>
    /// Base URI of the ReachGraph Store service.
    /// When <see langword="null"/>, the Signals base URI is used instead.
    /// </summary>
    public Uri? ReachGraphStoreBaseUri { get; set; }

    /// <summary>
    /// Upper bound on simultaneous requests during batch operations.
    /// Defaults to 10.
    /// </summary>
    public int MaxConcurrentRequests { get; set; } = DefaultMaxConcurrentRequests;

    /// <summary>
    /// Timeout for a request.
    /// Defaults to 30 seconds.
    /// </summary>
    public TimeSpan Timeout { get; set; } = DefaultTimeout;

    /// <summary>
    /// Number of retries for transient failures.
    /// Defaults to 3.
    /// </summary>
    public int RetryCount { get; set; } = DefaultRetryCount;
}