using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StackExchange.Redis;
using StellaOps.BinaryIndex.Contracts.Resolution;
using System.Text.Json;
namespace StellaOps.BinaryIndex.Cache;
/// <summary>
/// Caching service for binary resolution results.
/// Uses Valkey/Redis for high-performance caching with configurable TTLs.
/// </summary>
public interface IResolutionCacheService
{
    /// <summary>
    /// Gets the cached resolution stored under a key.
    /// </summary>
    /// <param name="cacheKey">The cache key, typically produced by <see cref="GenerateCacheKey"/>.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>The cached resolution if found; otherwise <see langword="null"/>.</returns>
    Task<CachedResolution?> GetAsync(string cacheKey, CancellationToken ct = default);

    /// <summary>
    /// Stores a resolution result under a key.
    /// </summary>
    /// <param name="cacheKey">The cache key.</param>
    /// <param name="result">The resolution result to cache.</param>
    /// <param name="ttl">Time-to-live for the cache entry.</param>
    /// <param name="ct">Cancellation token.</param>
    Task SetAsync(string cacheKey, CachedResolution result, TimeSpan ttl, CancellationToken ct = default);

    /// <summary>
    /// Removes every cache entry whose key matches a Redis glob-style pattern.
    /// </summary>
    /// <param name="pattern">
    /// Redis key pattern, e.g. <c>"resolution:*:CVE-2024-0001"</c> to drop every entry cached for one CVE.
    /// </param>
    /// <param name="ct">Cancellation token.</param>
    Task InvalidateByPatternAsync(string pattern, CancellationToken ct = default);

    /// <summary>
    /// Builds the cache key for a resolution request.
    /// </summary>
    /// <param name="request">The resolution request.</param>
    /// <returns>A deterministic cache key: equal requests always map to the same key.</returns>
    string GenerateCacheKey(VulnResolutionRequest request);
}
/// <summary>
/// A resolution result as it is stored in the resolution cache.
/// </summary>
public sealed record CachedResolution
{
    /// <summary>The resolution status (must always be supplied).</summary>
    public required ResolutionStatus Status { get; init; }

    /// <summary>The version that fixes the issue, when one applies.</summary>
    public string? FixedVersion { get; init; }

    /// <summary>Reference to the evidence record backing this result, if any.</summary>
    public string? EvidenceRef { get; init; }

    /// <summary>Timestamp at which the entry was written to the cache.</summary>
    public DateTimeOffset CachedAt { get; init; }

    /// <summary>Version key used when invalidating entries.</summary>
    public string? VersionKey { get; init; }

    /// <summary>Confidence score attached to the resolution.</summary>
    public decimal Confidence { get; init; }

    /// <summary>Name of the match type that produced the resolution.</summary>
    public string? MatchType { get; init; }
}
/// <summary>
/// Tunable settings for resolution caching: per-status TTLs, key prefix and
/// probabilistic early-expiry behavior.
/// </summary>
public sealed class ResolutionCacheOptions
{
    /// <summary>Configuration section name for these options.</summary>
    public const string SectionName = "ResolutionCache";

    /// <summary>TTL for fixed (high confidence) results. Default: 24 hours.</summary>
    public TimeSpan FixedTtl { get; set; } = new TimeSpan(24, 0, 0);

    /// <summary>TTL for vulnerable results. Default: 4 hours.</summary>
    public TimeSpan VulnerableTtl { get; set; } = new TimeSpan(4, 0, 0);

    /// <summary>TTL for unknown results. Default: 1 hour.</summary>
    public TimeSpan UnknownTtl { get; set; } = new TimeSpan(1, 0, 0);

    /// <summary>Prefix placed in front of every generated cache key. Default: <c>"resolution"</c>.</summary>
    public string KeyPrefix { get; set; } = "resolution";

    /// <summary>
    /// Whether cache hits may be probabilistically reported as misses shortly before
    /// expiry, to avoid many callers recomputing the same entry at once. Default: enabled.
    /// </summary>
    public bool EnableEarlyExpiry { get; set; } = true;

    /// <summary>
    /// Scale factor for the early-expiry probability; expected range 0.0-1.0. Default: 0.1.
    /// </summary>
    public double EarlyExpiryFactor { get; set; } = 0.1;
}
/// <summary>
/// Valkey/Redis implementation of <see cref="IResolutionCacheService"/>.
/// </summary>
/// <remarks>
/// Redis and serialization failures (anything other than cancellation) are logged at warning
/// level and swallowed: a failed read is reported as a miss, failed writes and invalidations
/// are no-ops.
/// </remarks>
public sealed class ResolutionCacheService : IResolutionCacheService
{
    // SCAN page size, and the number of matched keys collected before deletes are issued.
    private const int InvalidationBatchSize = 500;

    // Time constant (seconds) of the exponential early-expiry curve used by ShouldExpireEarly.
    private const double EarlyExpiryDecaySeconds = 3600;

    private readonly IConnectionMultiplexer _redis;
    private readonly ResolutionCacheOptions _options;
    private readonly ILogger<ResolutionCacheService> _logger;
    private readonly JsonSerializerOptions _jsonOptions;
    private readonly IRandomSource _random;

    /// <summary>
    /// Creates the cache service.
    /// </summary>
    /// <param name="redis">Connection to the Valkey/Redis deployment.</param>
    /// <param name="options">Cache settings (TTLs, key prefix, early-expiry tuning).</param>
    /// <param name="logger">Logger for hits, misses and failures.</param>
    /// <param name="random">Random source used for probabilistic early expiry.</param>
    /// <exception cref="ArgumentNullException">
    /// An argument, or <paramref name="options"/>.Value, is <see langword="null"/>.
    /// </exception>
    public ResolutionCacheService(
        IConnectionMultiplexer redis,
        IOptions<ResolutionCacheOptions> options,
        ILogger<ResolutionCacheService> logger,
        IRandomSource random)
    {
        _redis = redis ?? throw new ArgumentNullException(nameof(redis));
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _random = random ?? throw new ArgumentNullException(nameof(random));
        _jsonOptions = new JsonSerializerOptions
        {
            PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
            WriteIndented = false
        };
    }

    /// <inheritdoc />
    /// <remarks>
    /// When <see cref="ResolutionCacheOptions.EnableEarlyExpiry"/> is set, the value and its
    /// remaining TTL are read in one round-trip, and a hit may be reported as a miss as the
    /// entry nears expiry (see <see cref="ShouldExpireEarly"/>).
    /// </remarks>
    public async Task<CachedResolution?> GetAsync(string cacheKey, CancellationToken ct = default)
    {
        try
        {
            var db = _redis.GetDatabase();

            RedisValue value;
            TimeSpan? remainingTtl = null;
            if (_options.EnableEarlyExpiry)
            {
                // Value and TTL together, instead of a GET followed by a separate TTL call.
                var entry = await db.StringGetWithExpiryAsync(cacheKey).WaitAsync(ct).ConfigureAwait(false);
                value = entry.Value;
                remainingTtl = entry.Expiry;
            }
            else
            {
                value = await db.StringGetAsync(cacheKey).WaitAsync(ct).ConfigureAwait(false);
            }

            if (value.IsNullOrEmpty)
            {
                _logger.LogDebug("Cache miss for key {CacheKey}", cacheKey);
                return null;
            }

            var cached = JsonSerializer.Deserialize<CachedResolution>(value.ToString(), _jsonOptions);

            // Probabilistic early expiry: occasionally report a miss so one caller refreshes
            // the entry before it actually expires.
            if (_options.EnableEarlyExpiry && cached is not null && ShouldExpireEarly(remainingTtl))
            {
                _logger.LogDebug("Early expiry triggered for key {CacheKey}", cacheKey);
                return null;
            }

            _logger.LogDebug("Cache hit for key {CacheKey}", cacheKey);
            return cached;
        }
        catch (OperationCanceledException)
        {
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to get cache entry for key {CacheKey}", cacheKey);
            return null;
        }
    }

    /// <inheritdoc />
    public async Task SetAsync(string cacheKey, CachedResolution result, TimeSpan ttl, CancellationToken ct = default)
    {
        try
        {
            var db = _redis.GetDatabase();
            var value = JsonSerializer.Serialize(result, _jsonOptions);
            await db.StringSetAsync(cacheKey, value, ttl).WaitAsync(ct).ConfigureAwait(false);
            _logger.LogDebug("Cached resolution for key {CacheKey} with TTL {Ttl}", cacheKey, ttl);
        }
        catch (OperationCanceledException)
        {
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to cache resolution for key {CacheKey}", cacheKey);
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Keys are enumerated with SCAN on every connected primary and deleted in batches.
    /// Deletes are grouped by hash slot, so this also works against a cluster.
    /// </remarks>
    /// <exception cref="ArgumentException"><paramref name="pattern"/> is null, empty or whitespace.</exception>
    public async Task InvalidateByPatternAsync(string pattern, CancellationToken ct = default)
    {
        // An empty pattern would make SCAN run without MATCH, i.e. enumerate and delete every
        // key in the database, so reject it instead of treating it as "match all".
        ArgumentException.ThrowIfNullOrWhiteSpace(pattern);

        try
        {
            var db = _redis.GetDatabase();
            var endpoints = _redis.GetEndPoints();
            if (endpoints.Length == 0)
            {
                _logger.LogWarning("No Redis endpoints available for pattern invalidation");
                return;
            }

            long totalDeleted = 0;
            var buffer = new List<RedisKey>(InvalidationBatchSize);

            foreach (var endpoint in endpoints)
            {
                ct.ThrowIfCancellationRequested();
                var server = _redis.GetServer(endpoint);

                // Replicas mirror their primary's keyspace; scanning them only repeats work.
                if (!server.IsConnected || server.IsReplica)
                {
                    continue;
                }

                // KeysAsync pages through SCAN without blocking a thread-pool thread.
                await foreach (var key in server
                    .KeysAsync(database: db.Database, pattern: pattern, pageSize: InvalidationBatchSize)
                    .WithCancellation(ct)
                    .ConfigureAwait(false))
                {
                    ct.ThrowIfCancellationRequested();
                    buffer.Add(key);
                    if (buffer.Count >= InvalidationBatchSize)
                    {
                        totalDeleted += await DeleteKeysAsync(db, buffer, ct).ConfigureAwait(false);
                        buffer.Clear();
                    }
                }

                if (buffer.Count > 0)
                {
                    totalDeleted += await DeleteKeysAsync(db, buffer, ct).ConfigureAwait(false);
                    buffer.Clear();
                }
            }

            if (totalDeleted > 0)
            {
                _logger.LogInformation(
                    "Invalidated {Count} cache entries matching pattern {Pattern}",
                    totalDeleted,
                    pattern);
            }
        }
        catch (OperationCanceledException)
        {
            throw;
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to invalidate cache entries matching pattern {Pattern}", pattern);
        }
    }

    /// <inheritdoc />
    /// <remarks>
    /// Format: <c>{KeyPrefix}:{algorithm}:{identity}:{cveId or "all"}</c>, where the algorithm
    /// and identity come from the most specific identifier present on the request.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <see langword="null"/>.</exception>
    public string GenerateCacheKey(VulnResolutionRequest request)
    {
        ArgumentNullException.ThrowIfNull(request);

        var (algorithm, identity) = SelectIdentity(request);
        var cveId = request.CveId ?? "all";
        return $"{_options.KeyPrefix}:{algorithm}:{identity}:{cveId}";
    }

    /// <summary>
    /// Gets the TTL to use for a resolution status: <see cref="ResolutionCacheOptions.FixedTtl"/>
    /// for fixed and not-affected results, <see cref="ResolutionCacheOptions.VulnerableTtl"/>
    /// for vulnerable results, and <see cref="ResolutionCacheOptions.UnknownTtl"/> otherwise.
    /// </summary>
    public TimeSpan GetTtlForStatus(ResolutionStatus status)
    {
        return status switch
        {
            ResolutionStatus.Fixed => _options.FixedTtl,
            ResolutionStatus.Vulnerable => _options.VulnerableTtl,
            ResolutionStatus.NotAffected => _options.FixedTtl,
            _ => _options.UnknownTtl
        };
    }

    /// <summary>
    /// Deletes <paramref name="keys"/>, issuing one DEL per hash slot. On a cluster, a multi-key
    /// DEL must not span slots; with a single-slot result this is one DEL as before.
    /// </summary>
    /// <returns>The number of keys actually removed.</returns>
    private async Task<long> DeleteKeysAsync(IDatabase db, IReadOnlyCollection<RedisKey> keys, CancellationToken ct)
    {
        var deletions = keys
            .GroupBy(key => _redis.HashSlot(key))
            .Select(slotKeys => db.KeyDeleteAsync(slotKeys.ToArray()));
        var counts = await Task.WhenAll(deletions).WaitAsync(ct).ConfigureAwait(false);
        return counts.Sum();
    }

    /// <summary>
    /// Picks the most specific identifier on the request and returns its algorithm label together
    /// with the identity value used in the cache key. Precedence: build id, fingerprint, text
    /// SHA-256, file SHA-256, then package + distro release. Empty strings are treated as absent,
    /// so they never produce keys with an empty identity segment.
    /// </summary>
    private static (string Algorithm, string Identity) SelectIdentity(VulnResolutionRequest request)
    {
        if (!string.IsNullOrEmpty(request.BuildId))
        {
            return ("build_id", request.BuildId);
        }

        if (!string.IsNullOrEmpty(request.Fingerprint))
        {
            return (request.FingerprintAlgorithm ?? "combined", ComputeShortHash(request.Fingerprint));
        }

        if (request.Hashes?.TextSha256 is { Length: > 0 } textSha256)
        {
            return ("text_sha256", textSha256);
        }

        if (request.Hashes?.FileSha256 is { Length: > 0 } fileSha256)
        {
            return ("file_sha256", fileSha256);
        }

        // Fall back to package + distro release.
        var packageKey = $"{request.Package}:{request.DistroRelease ?? "unknown"}";
        return ("package", ComputeShortHash(packageKey));
    }

    /// <summary>
    /// Returns the first 16 lowercase hex characters (64 bits) of the SHA-256 of the UTF-8 input.
    /// </summary>
    private static string ComputeShortHash(string input)
    {
        var bytes = System.Text.Encoding.UTF8.GetBytes(input);
        var hash = System.Security.Cryptography.SHA256.HashData(bytes);
        return Convert.ToHexStringLower(hash)[..16];
    }

    /// <summary>
    /// Decides whether a cache hit should be reported as a miss.
    /// </summary>
    /// <remarks>
    /// A missing TTL (no expiry set, or the key vanished after the read) or a non-positive TTL
    /// always expires. Otherwise the entry expires with probability
    /// <c>EarlyExpiryFactor * exp(-remainingSeconds / 3600)</c>, which grows as the TTL shrinks.
    /// </remarks>
    private bool ShouldExpireEarly(TimeSpan? remainingTtl)
    {
        if (!remainingTtl.HasValue || remainingTtl.Value <= TimeSpan.Zero)
        {
            return true;
        }

        var roll = _random.NextDouble();
        var threshold = _options.EarlyExpiryFactor * Math.Exp(-remainingTtl.Value.TotalSeconds / EarlyExpiryDecaySeconds);
        return roll < threshold;
    }
}