// -----------------------------------------------------------------------------
// IdempotencyMiddleware.cs
// Sprint: SPRINT_3500_0002_0003_proof_replay_api
// Task: T3 - Idempotency Middleware
// Description: Middleware for POST endpoint idempotency using Content-Digest header
// -----------------------------------------------------------------------------

using System.IO;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Scanner.Storage.Entities;
using StellaOps.Scanner.Storage.Repositories;
using StellaOps.Scanner.WebService.Options;

namespace StellaOps.Scanner.WebService.Middleware;

/// <summary>
/// Middleware that implements idempotency for POST endpoints using RFC 9530 Content-Digest header.
/// </summary>
/// <remarks>
/// For POST requests to configured endpoints, a lookup key of (tenant, request-body digest, path) is
/// checked in <see cref="IIdempotencyKeyRepository"/>. On a hit, the stored status, body and selected
/// headers are replayed without invoking the rest of the pipeline. On a miss, the response is buffered
/// and, if it is 2xx, stored for the configured window before being copied to the client.
/// </remarks>
public sealed class IdempotencyMiddleware
{
    private const string ContentDigestHeader = "Content-Digest";

    private readonly RequestDelegate _next;
    private readonly ILogger<IdempotencyMiddleware> _logger;

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">The next delegate in the request pipeline.</param>
    /// <param name="logger">Logger for cache hits and caching failures.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="next"/> or <paramref name="logger"/> is null.</exception>
    public IdempotencyMiddleware(
        RequestDelegate next,
        ILogger<IdempotencyMiddleware> logger)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Processes a request, replaying a cached response for a repeated idempotent POST or
    /// capturing and caching the response of a first-seen one.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="repository">Store used to look up and save idempotency keys.</param>
    /// <param name="options">Idempotency configuration (enabled flag, endpoint prefixes, window).</param>
    /// <exception cref="ArgumentNullException">When any argument is null.</exception>
    public async Task InvokeAsync(
        HttpContext context,
        IIdempotencyKeyRepository repository,
        IOptions<IdempotencyOptions> options)
    {
        ArgumentNullException.ThrowIfNull(context);
        ArgumentNullException.ThrowIfNull(repository);
        ArgumentNullException.ThrowIfNull(options);

        var opts = options.Value;
        var path = context.Request.Path.Value ?? string.Empty;

        // Only POST requests to configured endpoints participate, and only when the feature is enabled.
        if (!HttpMethods.IsPost(context.Request.Method)
            || !opts.Enabled
            || !IsIdempotentEndpoint(path, opts.IdempotentEndpoints))
        {
            await _next(context).ConfigureAwait(false);
            return;
        }

        // No digest means nothing to key on (e.g. empty body and no Content-Digest header).
        var contentDigest = await GetOrComputeContentDigestAsync(context.Request).ConfigureAwait(false);
        if (string.IsNullOrEmpty(contentDigest))
        {
            await _next(context).ConfigureAwait(false);
            return;
        }

        var tenantId = GetTenantId(context);

        var existingKey = await repository
            .TryGetAsync(tenantId, contentDigest, path, context.RequestAborted)
            .ConfigureAwait(false);

        if (existingKey is not null)
        {
            _logger.LogInformation(
                "Returning cached response for idempotency key {KeyId}, tenant {TenantId}",
                existingKey.KeyId,
                tenantId);

            await WriteCachedResponseAsync(context, existingKey).ConfigureAwait(false);
            return;
        }

        // Swap in an in-memory body so the downstream response can be captured before it is sent.
        var originalBodyStream = context.Response.Body;
        using var responseBuffer = new MemoryStream();
        context.Response.Body = responseBuffer;

        try
        {
            await _next(context).ConfigureAwait(false);

            // Only cache successful responses (2xx).
            if (context.Response.StatusCode is >= 200 and < 300)
            {
                responseBuffer.Position = 0;

                string responseBody;
                // leaveOpen: the buffer is still needed below to copy the response to the client.
                using (var reader = new StreamReader(
                    responseBuffer,
                    Encoding.UTF8,
                    detectEncodingFromByteOrderMarks: true,
                    bufferSize: 1024,
                    leaveOpen: true))
                {
                    responseBody = await reader.ReadToEndAsync(context.RequestAborted).ConfigureAwait(false);
                }

                var now = DateTimeOffset.UtcNow;
                var idempotencyKey = new IdempotencyKeyRow
                {
                    TenantId = tenantId,
                    ContentDigest = contentDigest,
                    EndpointPath = path,
                    ResponseStatus = context.Response.StatusCode,
                    ResponseBody = responseBody,
                    ResponseHeaders = SerializeHeaders(context.Response.Headers),
                    CreatedAt = now,
                    ExpiresAt = now.Add(opts.Window)
                };

                try
                {
                    await repository.SaveAsync(idempotencyKey, context.RequestAborted).ConfigureAwait(false);
                    _logger.LogDebug(
                        "Cached idempotency key for tenant {TenantId}, digest {ContentDigest}",
                        tenantId,
                        contentDigest);
                }
                catch (Exception ex)
                {
                    // Log but don't fail the request if caching fails
                    _logger.LogWarning(ex, "Failed to cache idempotency key");
                }
            }

            // Copy buffered response to original stream
            responseBuffer.Position = 0;
            await responseBuffer.CopyToAsync(originalBodyStream, context.RequestAborted).ConfigureAwait(false);
        }
        finally
        {
            context.Response.Body = originalBodyStream;
        }
    }

    /// <summary>
    /// Returns true when <paramref name="path"/> starts (case-insensitively) with any configured prefix.
    /// </summary>
    /// <param name="path">The request path.</param>
    /// <param name="idempotentEndpoints">Configured endpoint prefixes; may be null if unbound.</param>
    private static bool IsIdempotentEndpoint(string path, IReadOnlyList<string> idempotentEndpoints)
    {
        if (idempotentEndpoints is null)
        {
            return false;
        }

        foreach (var pattern in idempotentEndpoints)
        {
            // An empty prefix would match every path via StartsWith; treat it as misconfiguration and skip.
            if (!string.IsNullOrEmpty(pattern)
                && path.StartsWith(pattern, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Determines the digest used as the idempotency key for the request body.
    /// </summary>
    /// <remarks>
    /// When the request has a body, the key is always the server-computed RFC 9530 <c>sha-256</c>
    /// digest. A client-supplied <c>Content-Digest</c> header is not trusted as the key in that case,
    /// because a wrong or reused header would make different payloads share one cached response.
    /// Honest clients that send a sha-256 digest produce the identical value. When there is no body,
    /// the client header (if any) is returned, so the method behaves as before for empty requests.
    /// </remarks>
    /// <param name="request">The incoming request; its body is buffered and rewound.</param>
    /// <returns>The digest string, or null when there is neither a body nor a Content-Digest header.</returns>
    private static async Task<string?> GetOrComputeContentDigestAsync(HttpRequest request)
    {
        string? clientDigest = null;
        if (request.Headers.TryGetValue(ContentDigestHeader, out var digestHeader)
            && !string.IsNullOrWhiteSpace(digestHeader))
        {
            clientDigest = digestHeader.ToString();
        }

        if (request.ContentLength == 0)
        {
            return clientDigest;
        }

        // Buffer so downstream handlers can still read the body after it has been hashed.
        // ContentLength may be null (e.g. chunked), so hash whatever is actually there.
        request.EnableBuffering();
        request.Body.Position = 0;

        var hash = await SHA256.HashDataAsync(request.Body, request.HttpContext.RequestAborted)
            .ConfigureAwait(false);

        // After reading to the end, the buffered stream's position equals the number of body bytes.
        var bodyLength = request.Body.Position;
        request.Body.Position = 0;

        if (bodyLength == 0)
        {
            return clientDigest;
        }

        return $"sha-256=:{Convert.ToBase64String(hash)}:";
    }

    /// <summary>
    /// Resolves the tenant scope for the idempotency key.
    /// </summary>
    /// <returns>
    /// The <c>tenant_id</c> claim when present; otherwise <c>ip:&lt;remote address&gt;</c>;
    /// otherwise the literal <c>"default"</c>.
    /// </returns>
    private static string GetTenantId(HttpContext context)
    {
        // Try to get tenant from claims
        var tenantClaim = context.User?.FindFirst("tenant_id")?.Value;
        if (!string.IsNullOrEmpty(tenantClaim))
        {
            return tenantClaim;
        }

        // Fall back to client IP or default
        var clientIp = context.Connection.RemoteIpAddress?.ToString();
        return !string.IsNullOrEmpty(clientIp) ? $"ip:{clientIp}" : "default";
    }

    /// <summary>
    /// Writes a previously stored response: status, replayable headers, idempotency markers and body.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="key">The stored idempotency record to replay.</param>
    private async Task WriteCachedResponseAsync(HttpContext context, IdempotencyKeyRow key)
    {
        context.Response.StatusCode = key.ResponseStatus;
        // Fallback for records stored without a Content-Type; a cached Content-Type overrides it below.
        context.Response.ContentType = "application/json";

        // Replay cached headers
        if (!string.IsNullOrEmpty(key.ResponseHeaders))
        {
            try
            {
                var headers = JsonSerializer.Deserialize<Dictionary<string, string>>(key.ResponseHeaders);
                if (headers is not null)
                {
                    foreach (var (name, value) in headers)
                    {
                        if (!IsRestrictedHeader(name))
                        {
                            context.Response.Headers[name] = value;
                        }
                    }
                }
            }
            catch (Exception ex)
            {
                // Best effort: unreadable stored headers must not prevent replaying the body.
                _logger.LogWarning(
                    ex,
                    "Failed to replay cached headers for idempotency key {KeyId}",
                    key.KeyId);
            }
        }

        // Set after replay so stored headers can never overwrite the cache markers.
        context.Response.Headers["X-Idempotency-Key"] = key.KeyId.ToString();
        context.Response.Headers["X-Idempotency-Cached"] = "true";

        if (!string.IsNullOrEmpty(key.ResponseBody))
        {
            await context.Response.WriteAsync(key.ResponseBody, context.RequestAborted).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Serializes the subset of response headers accepted by <see cref="ShouldCacheHeader"/> to JSON.
    /// </summary>
    /// <returns>A JSON object of header name to value, or null when no header qualifies.</returns>
    private static string? SerializeHeaders(IHeaderDictionary headers)
    {
        var selected = new Dictionary<string, string>();
        foreach (var header in headers)
        {
            if (ShouldCacheHeader(header.Key))
            {
                selected[header.Key] = header.Value.ToString();
            }
        }

        return selected.Count > 0 ? JsonSerializer.Serialize(selected) : null;
    }

    /// <summary>
    /// Returns true for headers worth storing and replaying: any <c>X-</c> header, <c>Location</c>,
    /// <c>Content-Digest</c> and <c>Content-Type</c>.
    /// </summary>
    private static bool ShouldCacheHeader(string name)
    {
        return name.StartsWith("X-", StringComparison.OrdinalIgnoreCase)
            || string.Equals(name, "Location", StringComparison.OrdinalIgnoreCase)
            || string.Equals(name, ContentDigestHeader, StringComparison.OrdinalIgnoreCase)
            || string.Equals(name, "Content-Type", StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Returns true for framing/connection headers that must not be replayed from storage.
    /// </summary>
    private static bool IsRestrictedHeader(string name)
    {
        return string.Equals(name, "Content-Length", StringComparison.OrdinalIgnoreCase)
            || string.Equals(name, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)
            || string.Equals(name, "Connection", StringComparison.OrdinalIgnoreCase);
    }
}