311 lines
11 KiB
C#
311 lines
11 KiB
C#
using System.Diagnostics;
|
|
using System.Net;
|
|
using System.Net.Http.Json;
|
|
using System.Text.Json;
|
|
using Microsoft.Extensions.Logging;
|
|
using Microsoft.Extensions.Options;
|
|
using StellaOps.Policy.Engine.Telemetry;
|
|
|
|
namespace StellaOps.Policy.Engine.ReachabilityFacts;
|
|
|
|
/// <summary>
|
|
/// HTTP client for fetching reachability facts from the Signals service.
|
|
/// </summary>
|
|
public sealed class ReachabilityFactsSignalsClient : IReachabilityFactsSignalsClient
{
    private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web)
    {
        PropertyNameCaseInsensitive = true,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
    };

    private readonly HttpClient _httpClient;
    private readonly ReachabilityFactsSignalsClientOptions _options;
    private readonly ILogger<ReachabilityFactsSignalsClient> _logger;

    /// <summary>
    /// Creates the client and applies request defaults to <paramref name="httpClient"/>.
    /// </summary>
    /// <param name="httpClient">
    /// HTTP client used for every request. Its base address is taken from
    /// <see cref="ReachabilityFactsSignalsClientOptions.BaseUri"/> only when it has none yet;
    /// its default <c>Accept</c> header is replaced with <c>application/json</c>.
    /// </param>
    /// <param name="options">Client configuration.</param>
    /// <param name="logger">Logger used for diagnostics.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public ReachabilityFactsSignalsClient(
        HttpClient httpClient,
        IOptions<ReachabilityFactsSignalsClientOptions> options,
        ILogger<ReachabilityFactsSignalsClient> logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));

        ArgumentNullException.ThrowIfNull(options);
        _options = options.Value;

        // An address already set on the HttpClient wins over the configured one.
        if (_httpClient.BaseAddress is null && _options.BaseUri is not null)
        {
            _httpClient.BaseAddress = _options.BaseUri;
        }

        _httpClient.DefaultRequestHeaders.Accept.Clear();
        _httpClient.DefaultRequestHeaders.Accept.ParseAdd("application/json");
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <c>null</c> when the service answers 404. Any other failure is logged and rethrown;
    /// cancellation requested through <paramref name="cancellationToken"/> is rethrown without
    /// being logged as an error.
    /// </remarks>
    public async Task<SignalsReachabilityFactResponse?> GetBySubjectAsync(
        string subjectKey,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(subjectKey);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_fact",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", subjectKey);

        var path = $"signals/facts/{Uri.EscapeDataString(subjectKey)}";

        try
        {
            // Dispose the response so its content stream and connection are released promptly.
            using var response = await _httpClient.GetAsync(path, cancellationToken).ConfigureAwait(false);

            if (response.StatusCode == HttpStatusCode.NotFound)
            {
                _logger.LogDebug("Reachability fact not found for subject {SubjectKey}", subjectKey);
                return null;
            }

            response.EnsureSuccessStatusCode();

            var fact = await response.Content
                .ReadFromJsonAsync<SignalsReachabilityFactResponse>(SerializerOptions, cancellationToken)
                .ConfigureAwait(false);

            _logger.LogDebug(
                "Retrieved reachability fact for subject {SubjectKey}: score={Score}, states={StateCount}",
                subjectKey,
                fact?.Score,
                fact?.States?.Count ?? 0);

            return fact;
        }
        catch (HttpRequestException ex) when (ex.StatusCode == HttpStatusCode.NotFound)
        {
            // A 404 surfaced as an exception (rather than a response) is still "no fact".
            return null;
        }
        catch (Exception ex) when (!IsCallerCancellation(ex, cancellationToken))
        {
            _logger.LogError(ex, "Failed to get reachability fact for subject {SubjectKey}", subjectKey);
            throw;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Each distinct key is fetched once through <see cref="GetBySubjectAsync"/>; keys without a
    /// fact are omitted from the result. At least one request is always allowed in flight, even
    /// when <see cref="ReachabilityFactsSignalsClientOptions.MaxConcurrentRequests"/> is
    /// configured as zero or negative.
    /// </remarks>
    public async Task<IReadOnlyDictionary<string, SignalsReachabilityFactResponse>> GetBatchBySubjectsAsync(
        IReadOnlyList<string> subjectKeys,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(subjectKeys);

        if (subjectKeys.Count == 0)
        {
            return new Dictionary<string, SignalsReachabilityFactResponse>(StringComparer.Ordinal);
        }

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_facts_batch",
            ActivityKind.Client);
        activity?.SetTag("signals.batch_size", subjectKeys.Count);

        var result = new Dictionary<string, SignalsReachabilityFactResponse>(StringComparer.Ordinal);

        // The result is keyed by subject, so repeated keys would only cause redundant requests.
        var distinctKeys = subjectKeys.Distinct(StringComparer.Ordinal).ToArray();

        // Signals doesn't expose a batch endpoint, so we fetch in parallel with a concurrency limit.
        // A limit of zero would block every fetch and a negative one is rejected by SemaphoreSlim,
        // so clamp the configured value to at least one.
        var maxConcurrency = Math.Max(1, _options.MaxConcurrentRequests);

        // Safe to dispose on scope exit: Task.WhenAll only completes (or throws) after every
        // fetch task has finished, so nothing touches the semaphore afterwards.
        using var semaphore = new SemaphoreSlim(maxConcurrency);

        var fetches = distinctKeys.Select(async key =>
        {
            await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
            try
            {
                var fact = await GetBySubjectAsync(key, cancellationToken).ConfigureAwait(false);
                return (Key: key, Fact: fact);
            }
            finally
            {
                semaphore.Release();
            }
        });

        var results = await Task.WhenAll(fetches).ConfigureAwait(false);

        foreach (var (key, fact) in results)
        {
            if (fact is not null)
            {
                result[key] = fact;
            }
        }

        activity?.SetTag("signals.found_count", result.Count);
        _logger.LogDebug(
            "Batch retrieved {FoundCount}/{TotalCount} reachability facts",
            result.Count,
            distinctKeys.Length);

        return result;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Best effort: a non-success status code or any exception raised while sending the request
    /// (including cancellation) is logged and reported as <c>false</c> rather than thrown.
    /// </remarks>
    public async Task<bool> TriggerRecomputeAsync(
        SignalsRecomputeRequest request,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(request);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.trigger_recompute",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", request.SubjectKey);
        activity?.SetTag("signals.tenant_id", request.TenantId);

        try
        {
            using var response = await _httpClient.PostAsJsonAsync(
                "signals/reachability/recompute",
                new { subjectKey = request.SubjectKey, tenantId = request.TenantId },
                SerializerOptions,
                cancellationToken).ConfigureAwait(false);

            if (response.IsSuccessStatusCode)
            {
                _logger.LogInformation(
                    "Triggered reachability recompute for subject {SubjectKey}",
                    request.SubjectKey);
                return true;
            }

            _logger.LogWarning(
                "Failed to trigger reachability recompute for subject {SubjectKey}: {StatusCode}",
                request.SubjectKey,
                response.StatusCode);
            return false;
        }
        catch (Exception ex)
        {
            // Deliberately broad: callers get a boolean outcome instead of an exception.
            _logger.LogError(
                ex,
                "Error triggering reachability recompute for subject {SubjectKey}",
                request.SubjectKey);
            return false;
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <c>null</c> when no fact exists for <paramref name="subjectKey"/>. When the fact has
    /// no callgraph ID, or the slice request fails with a non-success status code or an
    /// <see cref="HttpRequestException"/>, the fact is returned with a <c>null</c> subgraph.
    /// </remarks>
    public async Task<ReachabilityFactWithSubgraph?> GetWithSubgraphAsync(
        string subjectKey,
        string? cveId = null,
        CancellationToken cancellationToken = default)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(subjectKey);

        using var activity = PolicyEngineTelemetry.ActivitySource.StartActivity(
            "signals_client.get_fact_with_subgraph",
            ActivityKind.Client);
        activity?.SetTag("signals.subject_key", subjectKey);
        if (cveId is not null)
        {
            activity?.SetTag("signals.cve_id", cveId);
        }

        // Get base reachability fact from Signals
        var fact = await GetBySubjectAsync(subjectKey, cancellationToken).ConfigureAwait(false);
        if (fact is null)
        {
            _logger.LogDebug("No reachability fact found for subject {SubjectKey}", subjectKey);
            return null;
        }

        if (string.IsNullOrEmpty(fact.CallgraphId))
        {
            _logger.LogDebug(
                "Reachability fact for subject {SubjectKey} has no callgraph ID",
                subjectKey);
            return new ReachabilityFactWithSubgraph(fact, null);
        }

        // Fetch subgraph slice from ReachGraph Store
        try
        {
            var sliceUri = BuildSliceUri(fact.CallgraphId, cveId);

            using var response = await _httpClient.GetAsync(sliceUri, cancellationToken).ConfigureAwait(false);

            if (!response.IsSuccessStatusCode)
            {
                _logger.LogWarning(
                    "Failed to fetch subgraph slice for callgraph {CallgraphId}: {StatusCode}",
                    fact.CallgraphId,
                    response.StatusCode);
                return new ReachabilityFactWithSubgraph(fact, null);
            }

            var slice = await response.Content
                .ReadFromJsonAsync<ReachGraphSlice>(SerializerOptions, cancellationToken)
                .ConfigureAwait(false);

            _logger.LogDebug(
                "Fetched subgraph slice for callgraph {CallgraphId}: {NodeCount} nodes, {PathCount} paths",
                fact.CallgraphId,
                slice?.NodeCount ?? 0,
                slice?.Paths?.Count ?? 0);

            return new ReachabilityFactWithSubgraph(fact, slice);
        }
        catch (HttpRequestException ex)
        {
            // The slice is optional enrichment; a transport failure must not lose the fact itself.
            _logger.LogWarning(
                ex,
                "Error fetching subgraph slice for callgraph {CallgraphId}",
                fact.CallgraphId);
            return new ReachabilityFactWithSubgraph(fact, null);
        }
    }

    /// <summary>
    /// Builds the request URI for the subgraph slice of <paramref name="callgraphId"/>,
    /// optionally filtered by <paramref name="cveId"/>.
    /// </summary>
    /// <returns>
    /// An absolute URI under <see cref="ReachabilityFactsSignalsClientOptions.ReachGraphStoreBaseUri"/>
    /// when that option is an absolute URI; otherwise a relative URI that the HTTP client resolves
    /// against its own base address.
    /// </returns>
    private Uri BuildSliceUri(string callgraphId, string? cveId)
    {
        var query = cveId is not null
            ? $"?cve={Uri.EscapeDataString(cveId)}"
            : "";
        var storePath = $"v1/reachgraphs/{Uri.EscapeDataString(callgraphId)}/slice{query}";

        return _options.ReachGraphStoreBaseUri switch
        {
            // No dedicated store configured: use the "reachgraph/" prefix on the client's base address.
            null => new Uri($"reachgraph/{storePath}", UriKind.RelativeOrAbsolute),

            // Dedicated store configured: target it explicitly. A relative path alone would be
            // resolved against the HTTP client's base address and never reach the store.
            { IsAbsoluteUri: true } storeBase => new Uri(storeBase, storePath),

            // A relative store URI cannot serve as a base; keep the relative store path.
            _ => new Uri(storePath, UriKind.RelativeOrAbsolute),
        };
    }

    /// <summary>
    /// Tells whether <paramref name="exception"/> is a cancellation that the caller asked for via
    /// <paramref name="cancellationToken"/>, as opposed to a cancellation raised for another
    /// reason while the token was not signalled.
    /// </summary>
    private static bool IsCallerCancellation(Exception exception, CancellationToken cancellationToken)
        => exception is OperationCanceledException && cancellationToken.IsCancellationRequested;
}
|
|
|
|
/// <summary>
|
|
/// Configuration options for the Signals reachability client.
|
|
/// </summary>
|
|
public sealed class ReachabilityFactsSignalsClientOptions
{
    /// <summary>
    /// Configuration section name for these options.
    /// </summary>
    public const string SectionName = "ReachabilitySignals";

    private const int DefaultMaxConcurrentRequests = 10;
    private const int DefaultTimeoutSeconds = 30;
    private const int DefaultRetryCount = 3;

    /// <summary>
    /// Address of the Signals service that relative request paths are resolved against.
    /// </summary>
    /// <value><c>null</c> unless configured.</value>
    public Uri? BaseUri { get; set; }

    /// <summary>
    /// Address of the ReachGraph Store service.
    /// </summary>
    /// <value>
    /// <c>null</c> unless configured; in that case the store is addressed through the same
    /// base URI as Signals.
    /// </value>
    public Uri? ReachGraphStoreBaseUri { get; set; }

    /// <summary>
    /// Upper bound on the number of requests a batch operation keeps in flight at once.
    /// </summary>
    /// <value>Defaults to 10.</value>
    public int MaxConcurrentRequests { get; set; } = DefaultMaxConcurrentRequests;

    /// <summary>
    /// Timeout for a single request.
    /// </summary>
    /// <value>Defaults to 30 seconds.</value>
    /// <remarks>
    /// NOTE(review): not read by <c>ReachabilityFactsSignalsClient</c> in this file; presumably
    /// applied where the HTTP client is configured — confirm.
    /// </remarks>
    public TimeSpan Timeout { get; set; } = new TimeSpan(0, 0, DefaultTimeoutSeconds);

    /// <summary>
    /// Number of retries for transient failures.
    /// </summary>
    /// <value>Defaults to 3.</value>
    /// <remarks>
    /// NOTE(review): not read by <c>ReachabilityFactsSignalsClient</c> in this file; presumably
    /// consumed by a retry policy configured elsewhere — confirm.
    /// </remarks>
    public int RetryCount { get; set; } = DefaultRetryCount;
}
|