Add unit tests for Router configuration and transport layers
- Implemented tests for RouterConfig, RoutingOptions, StaticInstanceConfig, and RouterConfigOptions to ensure default values are set correctly. - Added tests for RouterConfigProvider to validate configurations and ensure defaults are returned when no file is specified. - Created tests for ConfigValidationResult to check success and error scenarios. - Developed tests for ServiceCollectionExtensions to verify service registration for RouterConfig. - Introduced UdpTransportTests to validate serialization, connection, request-response, and error handling in UDP transport. - Added scripts for signing authority gaps and hashing DevPortal SDK snippets.
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// A stream wrapper that counts bytes read and enforces a limit.
|
||||
/// </summary>
|
||||
public sealed class ByteCountingStream : Stream
|
||||
{
|
||||
private readonly Stream _inner;
|
||||
private readonly long _limit;
|
||||
private readonly Action? _onLimitExceeded;
|
||||
private long _bytesRead;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ByteCountingStream"/> class.
|
||||
/// </summary>
|
||||
/// <param name="inner">The inner stream to wrap.</param>
|
||||
/// <param name="limit">The maximum number of bytes that can be read.</param>
|
||||
/// <param name="onLimitExceeded">Optional callback invoked when the limit is exceeded.</param>
|
||||
public ByteCountingStream(Stream inner, long limit, Action? onLimitExceeded = null)
|
||||
{
|
||||
_inner = inner;
|
||||
_limit = limit;
|
||||
_onLimitExceeded = onLimitExceeded;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the total number of bytes read from this stream.
|
||||
/// </summary>
|
||||
public long BytesRead => Interlocked.Read(ref _bytesRead);
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanRead => _inner.CanRead;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanSeek => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanWrite => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Length => _inner.Length;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Position
|
||||
{
|
||||
get => _inner.Position;
|
||||
set => throw new NotSupportedException("Seeking not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Flush() => _inner.Flush();
|
||||
|
||||
/// <inheritdoc />
|
||||
public override Task FlushAsync(CancellationToken cancellationToken) =>
|
||||
_inner.FlushAsync(cancellationToken);
|
||||
|
||||
/// <inheritdoc />
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
var read = _inner.Read(buffer, offset, count);
|
||||
CheckLimit(read);
|
||||
return read;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
var read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
|
||||
CheckLimit(read);
|
||||
return read;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var read = await _inner.ReadAsync(buffer, cancellationToken);
|
||||
CheckLimit(read);
|
||||
return read;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Seek(long offset, SeekOrigin origin)
|
||||
{
|
||||
throw new NotSupportedException("Seeking not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void SetLength(long value)
|
||||
{
|
||||
throw new NotSupportedException("Setting length not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Write(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotSupportedException("Writing not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
private void CheckLimit(int bytesJustRead)
|
||||
{
|
||||
if (bytesJustRead <= 0) return;
|
||||
|
||||
var newTotal = Interlocked.Add(ref _bytesRead, bytesJustRead);
|
||||
if (newTotal > _limit)
|
||||
{
|
||||
_onLimitExceeded?.Invoke();
|
||||
throw new PayloadLimitExceededException(newTotal, _limit);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed && disposing)
|
||||
{
|
||||
_inner.Dispose();
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask DisposeAsync()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
await _inner.DisposeAsync();
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
await base.DisposeAsync();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Resolves incoming HTTP requests to endpoint descriptors using the routing state.
|
||||
/// </summary>
|
||||
public sealed class EndpointResolutionMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EndpointResolutionMiddleware"/> class.
|
||||
/// </summary>
|
||||
public EndpointResolutionMiddleware(RequestDelegate next)
|
||||
{
|
||||
_next = next;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(HttpContext context, IGlobalRoutingState routingState)
|
||||
{
|
||||
var method = context.Request.Method;
|
||||
var path = context.Request.Path.ToString();
|
||||
|
||||
var endpoint = routingState.ResolveEndpoint(method, path);
|
||||
if (endpoint is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Endpoint not found",
|
||||
method,
|
||||
path
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
context.Items[RouterHttpContextKeys.EndpointDescriptor] = endpoint;
|
||||
await _next(context);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Exception thrown when a payload limit is exceeded during streaming.
|
||||
/// </summary>
|
||||
public sealed class PayloadLimitExceededException : Exception
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadLimitExceededException"/> class.
|
||||
/// </summary>
|
||||
/// <param name="bytesRead">The number of bytes read before the limit was exceeded.</param>
|
||||
/// <param name="limit">The limit that was exceeded.</param>
|
||||
public PayloadLimitExceededException(long bytesRead, long limit)
|
||||
: base($"Payload limit exceeded: {bytesRead} bytes read, limit is {limit} bytes")
|
||||
{
|
||||
BytesRead = bytesRead;
|
||||
Limit = limit;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of bytes read before the limit was exceeded.
|
||||
/// </summary>
|
||||
public long BytesRead { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the limit that was exceeded.
|
||||
/// </summary>
|
||||
public long Limit { get; }
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Middleware that enforces payload limits per-request, per-connection, and aggregate.
|
||||
/// </summary>
|
||||
public sealed class PayloadLimitsMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly PayloadLimits _limits;
|
||||
private readonly ILogger<PayloadLimitsMiddleware> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadLimitsMiddleware"/> class.
|
||||
/// </summary>
|
||||
public PayloadLimitsMiddleware(
|
||||
RequestDelegate next,
|
||||
IOptions<PayloadLimits> limits,
|
||||
ILogger<PayloadLimitsMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_limits = limits.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(HttpContext context, IPayloadTracker tracker)
|
||||
{
|
||||
var connectionId = context.Connection.Id;
|
||||
var contentLength = context.Request.ContentLength ?? 0;
|
||||
|
||||
// Early rejection for known oversized Content-Length (LIM-002, LIM-003)
|
||||
if (context.Request.ContentLength.HasValue &&
|
||||
context.Request.ContentLength.Value > _limits.MaxRequestBytesPerCall)
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Request rejected: Content-Length {ContentLength} exceeds per-call limit {Limit}. ConnectionId: {ConnectionId}",
|
||||
context.Request.ContentLength.Value,
|
||||
_limits.MaxRequestBytesPerCall,
|
||||
connectionId);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status413PayloadTooLarge;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Payload Too Large",
|
||||
maxBytes = _limits.MaxRequestBytesPerCall,
|
||||
contentLength = context.Request.ContentLength.Value
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to reserve capacity (checks aggregate and per-connection limits)
|
||||
if (!tracker.TryReserve(connectionId, contentLength))
|
||||
{
|
||||
// Check which limit was hit
|
||||
if (tracker.IsOverloaded)
|
||||
{
|
||||
// Aggregate limit exceeded (LIM-033)
|
||||
_logger.LogWarning(
|
||||
"Request rejected: Aggregate limit exceeded. Current inflight: {Current}, Limit: {Limit}. ConnectionId: {ConnectionId}",
|
||||
tracker.CurrentInflightBytes,
|
||||
_limits.MaxAggregateInflightBytes,
|
||||
connectionId);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Service Overloaded",
|
||||
message = "Too many concurrent requests"
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Per-connection limit exceeded (LIM-022)
|
||||
_logger.LogWarning(
|
||||
"Request rejected: Per-connection limit exceeded. ConnectionId: {ConnectionId}, Current: {Current}, Limit: {Limit}",
|
||||
connectionId,
|
||||
tracker.GetConnectionInflightBytes(connectionId),
|
||||
_limits.MaxRequestBytesPerConnection);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Too Many Requests",
|
||||
message = "Per-connection limit exceeded"
|
||||
});
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Store the original body stream
|
||||
var originalBody = context.Request.Body;
|
||||
long actualBytesRead = 0;
|
||||
|
||||
try
|
||||
{
|
||||
// Wrap the request body with ByteCountingStream for streaming requests
|
||||
if (!context.Request.ContentLength.HasValue || context.Request.ContentLength.Value > 0)
|
||||
{
|
||||
var countingStream = new ByteCountingStream(
|
||||
originalBody,
|
||||
_limits.MaxRequestBytesPerCall,
|
||||
() =>
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Mid-stream limit exceeded. ConnectionId: {ConnectionId}, Limit: {Limit}",
|
||||
connectionId,
|
||||
_limits.MaxRequestBytesPerCall);
|
||||
});
|
||||
|
||||
context.Request.Body = countingStream;
|
||||
|
||||
// Store reference for later access to bytes read
|
||||
context.Items["PayloadLimits:CountingStream"] = countingStream;
|
||||
}
|
||||
|
||||
await _next(context);
|
||||
|
||||
// Get actual bytes read
|
||||
if (context.Items["PayloadLimits:CountingStream"] is ByteCountingStream cs)
|
||||
{
|
||||
actualBytesRead = cs.BytesRead;
|
||||
}
|
||||
}
|
||||
catch (PayloadLimitExceededException ex)
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Payload limit exceeded mid-stream. ConnectionId: {ConnectionId}, BytesRead: {BytesRead}, Limit: {Limit}",
|
||||
connectionId,
|
||||
ex.BytesRead,
|
||||
ex.Limit);
|
||||
|
||||
// Only set response if not already started
|
||||
if (!context.Response.HasStarted)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status413PayloadTooLarge;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Payload Too Large",
|
||||
maxBytes = _limits.MaxRequestBytesPerCall,
|
||||
bytesReceived = ex.BytesRead
|
||||
});
|
||||
}
|
||||
|
||||
actualBytesRead = ex.BytesRead;
|
||||
}
|
||||
finally
|
||||
{
|
||||
// Restore original body stream
|
||||
context.Request.Body = originalBody;
|
||||
|
||||
// Release reserved capacity
|
||||
var bytesToRelease = actualBytesRead > 0 ? actualBytesRead : contentLength;
|
||||
tracker.Release(connectionId, bytesToRelease);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks payload bytes across requests, connections, and globally.
|
||||
/// </summary>
|
||||
public interface IPayloadTracker
|
||||
{
|
||||
/// <summary>
|
||||
/// Tries to reserve capacity for an estimated payload size.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection identifier.</param>
|
||||
/// <param name="estimatedBytes">The estimated bytes to reserve.</param>
|
||||
/// <returns>True if capacity was reserved; false if limits would be exceeded.</returns>
|
||||
bool TryReserve(string connectionId, long estimatedBytes);
|
||||
|
||||
/// <summary>
|
||||
/// Releases previously reserved capacity.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection identifier.</param>
|
||||
/// <param name="actualBytes">The actual bytes to release.</param>
|
||||
void Release(string connectionId, long actualBytes);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current total inflight bytes across all connections.
|
||||
/// </summary>
|
||||
long CurrentInflightBytes { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the system is overloaded.
|
||||
/// </summary>
|
||||
bool IsOverloaded { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current inflight bytes for a specific connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection identifier.</param>
|
||||
/// <returns>The current inflight bytes for the connection.</returns>
|
||||
long GetConnectionInflightBytes(string connectionId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of <see cref="IPayloadTracker"/>.
|
||||
/// </summary>
|
||||
public sealed class PayloadTracker : IPayloadTracker
|
||||
{
|
||||
private readonly PayloadLimits _limits;
|
||||
private readonly ILogger<PayloadTracker> _logger;
|
||||
private long _totalInflightBytes;
|
||||
private readonly ConcurrentDictionary<string, long> _perConnectionBytes = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadTracker"/> class.
|
||||
/// </summary>
|
||||
public PayloadTracker(IOptions<PayloadLimits> limits, ILogger<PayloadTracker> logger)
|
||||
{
|
||||
_limits = limits.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public long CurrentInflightBytes => Interlocked.Read(ref _totalInflightBytes);
|
||||
|
||||
/// <inheritdoc />
|
||||
public bool IsOverloaded => CurrentInflightBytes > _limits.MaxAggregateInflightBytes;
|
||||
|
||||
/// <inheritdoc />
|
||||
public bool TryReserve(string connectionId, long estimatedBytes)
|
||||
{
|
||||
// Check aggregate limit
|
||||
var newTotal = Interlocked.Add(ref _totalInflightBytes, estimatedBytes);
|
||||
if (newTotal > _limits.MaxAggregateInflightBytes)
|
||||
{
|
||||
Interlocked.Add(ref _totalInflightBytes, -estimatedBytes);
|
||||
_logger.LogWarning(
|
||||
"Aggregate payload limit exceeded. Current: {Current}, Limit: {Limit}",
|
||||
newTotal - estimatedBytes,
|
||||
_limits.MaxAggregateInflightBytes);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check per-connection limit
|
||||
var connectionBytes = _perConnectionBytes.AddOrUpdate(
|
||||
connectionId,
|
||||
estimatedBytes,
|
||||
(_, current) => current + estimatedBytes);
|
||||
|
||||
if (connectionBytes > _limits.MaxRequestBytesPerConnection)
|
||||
{
|
||||
// Roll back
|
||||
_perConnectionBytes.AddOrUpdate(
|
||||
connectionId,
|
||||
0,
|
||||
(_, current) => current - estimatedBytes);
|
||||
Interlocked.Add(ref _totalInflightBytes, -estimatedBytes);
|
||||
|
||||
_logger.LogWarning(
|
||||
"Per-connection payload limit exceeded for {ConnectionId}. Current: {Current}, Limit: {Limit}",
|
||||
connectionId,
|
||||
connectionBytes - estimatedBytes,
|
||||
_limits.MaxRequestBytesPerConnection);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Release(string connectionId, long actualBytes)
|
||||
{
|
||||
Interlocked.Add(ref _totalInflightBytes, -actualBytes);
|
||||
|
||||
_perConnectionBytes.AddOrUpdate(
|
||||
connectionId,
|
||||
0,
|
||||
(_, current) => Math.Max(0, current - actualBytes));
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public long GetConnectionInflightBytes(string connectionId)
|
||||
{
|
||||
return _perConnectionBytes.TryGetValue(connectionId, out var bytes) ? bytes : 0;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Makes routing decisions for resolved endpoints.
|
||||
/// </summary>
|
||||
public sealed class RoutingDecisionMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RoutingDecisionMiddleware"/> class.
|
||||
/// </summary>
|
||||
public RoutingDecisionMiddleware(RequestDelegate next)
|
||||
{
|
||||
_next = next;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(
|
||||
HttpContext context,
|
||||
IRoutingPlugin routingPlugin,
|
||||
IGlobalRoutingState routingState,
|
||||
IOptions<GatewayNodeConfig> gatewayConfig,
|
||||
IOptions<RoutingOptions> routingOptions)
|
||||
{
|
||||
var endpoint = context.Items[RouterHttpContextKeys.EndpointDescriptor] as EndpointDescriptor;
|
||||
if (endpoint is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
|
||||
await context.Response.WriteAsJsonAsync(new { error = "Endpoint descriptor missing" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Build routing context
|
||||
var availableConnections = routingState.GetConnectionsFor(
|
||||
endpoint.ServiceName,
|
||||
endpoint.Version,
|
||||
endpoint.Method,
|
||||
endpoint.Path);
|
||||
|
||||
var headers = context.Request.Headers
|
||||
.ToDictionary(h => h.Key, h => h.Value.ToString());
|
||||
|
||||
var routingContext = new RoutingContext
|
||||
{
|
||||
Method = context.Request.Method,
|
||||
Path = context.Request.Path.ToString(),
|
||||
Headers = headers,
|
||||
Endpoint = endpoint,
|
||||
AvailableConnections = availableConnections,
|
||||
GatewayRegion = gatewayConfig.Value.Region,
|
||||
RequestedVersion = ExtractVersionFromRequest(context, routingOptions.Value),
|
||||
CancellationToken = context.RequestAborted
|
||||
};
|
||||
|
||||
var decision = await routingPlugin.ChooseInstanceAsync(
|
||||
routingContext,
|
||||
context.RequestAborted);
|
||||
|
||||
if (decision is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "No instances available",
|
||||
service = endpoint.ServiceName,
|
||||
version = endpoint.Version
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
context.Items[RouterHttpContextKeys.RoutingDecision] = decision;
|
||||
await _next(context);
|
||||
}
|
||||
|
||||
private static string? ExtractVersionFromRequest(HttpContext context, RoutingOptions options)
|
||||
{
|
||||
// Check for version in Accept header: Accept: application/vnd.stellaops.v1+json
|
||||
var acceptHeader = context.Request.Headers.Accept.FirstOrDefault();
|
||||
if (!string.IsNullOrEmpty(acceptHeader))
|
||||
{
|
||||
var versionMatch = System.Text.RegularExpressions.Regex.Match(
|
||||
acceptHeader,
|
||||
@"application/vnd\.stellaops\.v(\d+(?:\.\d+)*)\+json");
|
||||
if (versionMatch.Success)
|
||||
{
|
||||
return versionMatch.Groups[1].Value;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for X-Api-Version header
|
||||
var versionHeader = context.Request.Headers["X-Api-Version"].FirstOrDefault();
|
||||
if (!string.IsNullOrEmpty(versionHeader))
|
||||
{
|
||||
return versionHeader;
|
||||
}
|
||||
|
||||
// Fall back to default version from options
|
||||
return options.DefaultVersion;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,457 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Diagnostics;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Frames;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Dispatches HTTP requests to microservices via the transport layer.
|
||||
/// </summary>
|
||||
public sealed class TransportDispatchMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly ILogger<TransportDispatchMiddleware> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks cancelled request IDs to ignore late responses.
|
||||
/// Keys expire after 60 seconds to prevent memory leaks.
|
||||
/// </summary>
|
||||
private static readonly ConcurrentDictionary<string, DateTimeOffset> CancelledRequests = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TransportDispatchMiddleware"/> class.
|
||||
/// </summary>
|
||||
public TransportDispatchMiddleware(RequestDelegate next, ILogger<TransportDispatchMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_logger = logger;
|
||||
|
||||
// Start background cleanup task for expired cancelled request entries
|
||||
_ = Task.Run(CleanupExpiredCancelledRequestsAsync);
|
||||
}
|
||||
|
||||
private static async Task CleanupExpiredCancelledRequestsAsync()
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
await Task.Delay(TimeSpan.FromSeconds(30));
|
||||
|
||||
var cutoff = DateTimeOffset.UtcNow.AddSeconds(-60);
|
||||
foreach (var kvp in CancelledRequests)
|
||||
{
|
||||
if (kvp.Value < cutoff)
|
||||
{
|
||||
CancelledRequests.TryRemove(kvp.Key, out _);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void MarkCancelled(string requestId)
|
||||
{
|
||||
CancelledRequests[requestId] = DateTimeOffset.UtcNow;
|
||||
}
|
||||
|
||||
private static bool IsCancelled(string requestId)
|
||||
{
|
||||
return CancelledRequests.ContainsKey(requestId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(
|
||||
HttpContext context,
|
||||
ITransportClient transportClient,
|
||||
IGlobalRoutingState routingState)
|
||||
{
|
||||
var decision = context.Items[RouterHttpContextKeys.RoutingDecision] as RoutingDecision;
|
||||
if (decision is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
|
||||
await context.Response.WriteAsJsonAsync(new { error = "Routing decision missing" });
|
||||
return;
|
||||
}
|
||||
|
||||
var requestId = Guid.NewGuid().ToString("N");
|
||||
|
||||
// Extract headers (exclude some internal headers)
|
||||
var headers = context.Request.Headers
|
||||
.Where(h => !h.Key.StartsWith(":", StringComparison.Ordinal))
|
||||
.ToDictionary(
|
||||
h => h.Key,
|
||||
h => h.Value.ToString());
|
||||
|
||||
// For streaming endpoints, use streaming dispatch
|
||||
if (decision.Endpoint.SupportsStreaming)
|
||||
{
|
||||
await DispatchStreamingAsync(context, transportClient, routingState, decision, requestId, headers);
|
||||
return;
|
||||
}
|
||||
|
||||
// Read request body (buffered)
|
||||
byte[] bodyBytes;
|
||||
using (var ms = new MemoryStream())
|
||||
{
|
||||
await context.Request.Body.CopyToAsync(ms, context.RequestAborted);
|
||||
bodyBytes = ms.ToArray();
|
||||
}
|
||||
|
||||
// Build request frame
|
||||
var requestFrame = new RequestFrame
|
||||
{
|
||||
RequestId = requestId,
|
||||
CorrelationId = context.TraceIdentifier,
|
||||
Method = context.Request.Method,
|
||||
Path = context.Request.Path.ToString() + context.Request.QueryString.ToString(),
|
||||
Headers = headers,
|
||||
Payload = bodyBytes,
|
||||
TimeoutSeconds = (int)decision.EffectiveTimeout.TotalSeconds,
|
||||
SupportsStreaming = false
|
||||
};
|
||||
|
||||
var frame = FrameConverter.ToFrame(requestFrame);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Dispatching {Method} {Path} to {ServiceName}/{Version} via {TransportType}",
|
||||
requestFrame.Method,
|
||||
requestFrame.Path,
|
||||
decision.Connection.Instance.ServiceName,
|
||||
decision.Connection.Instance.Version,
|
||||
decision.TransportType);
|
||||
|
||||
// Create linked cancellation token with timeout
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
|
||||
timeoutCts.CancelAfter(decision.EffectiveTimeout);
|
||||
|
||||
// Register client disconnect handler to send CANCEL
|
||||
var requestIdGuid = Guid.TryParse(requestId, out var parsed) ? parsed : Guid.NewGuid();
|
||||
using var clientDisconnectRegistration = context.RequestAborted.Register(() =>
|
||||
{
|
||||
// Mark as cancelled to ignore late responses
|
||||
MarkCancelled(requestId);
|
||||
|
||||
// Send CANCEL frame (fire and forget)
|
||||
_ = Task.Run(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.ClientDisconnected);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Sent CANCEL for request {RequestId} due to client disconnect",
|
||||
requestId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex,
|
||||
"Failed to send CANCEL for request {RequestId} on client disconnect",
|
||||
requestId);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Frame responseFrame;
|
||||
var startTimestamp = Stopwatch.GetTimestamp();
|
||||
try
|
||||
{
|
||||
responseFrame = await transportClient.SendRequestAsync(
|
||||
decision.Connection,
|
||||
frame,
|
||||
decision.EffectiveTimeout,
|
||||
timeoutCts.Token);
|
||||
|
||||
// Record ping latency and update connection's average
|
||||
var elapsed = Stopwatch.GetElapsedTime(startTimestamp);
|
||||
UpdateConnectionPing(routingState, decision.Connection.ConnectionId, elapsed.TotalMilliseconds);
|
||||
}
|
||||
catch (OperationCanceledException) when (!context.RequestAborted.IsCancellationRequested)
|
||||
{
|
||||
// Internal timeout (not client disconnect)
|
||||
_logger.LogWarning(
|
||||
"Request {RequestId} to {ServiceName} timed out after {Timeout}",
|
||||
requestId,
|
||||
decision.Connection.Instance.ServiceName,
|
||||
decision.EffectiveTimeout);
|
||||
|
||||
// Mark as cancelled to ignore late responses
|
||||
MarkCancelled(requestId);
|
||||
|
||||
// Send cancel to microservice
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.Timeout);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Failed to send cancel for request {RequestId}", requestId);
|
||||
}
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream timeout",
|
||||
service = decision.Connection.Instance.ServiceName,
|
||||
timeout = decision.EffectiveTimeout.TotalSeconds
|
||||
});
|
||||
return;
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Client disconnected - cancel already sent via registration above
|
||||
MarkCancelled(requestId);
|
||||
_logger.LogDebug("Client disconnected, request {RequestId} cancelled", requestId);
|
||||
return;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Error dispatching request {RequestId} to {ServiceName}",
|
||||
requestId,
|
||||
decision.Connection.Instance.ServiceName);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream error",
|
||||
message = ex.Message
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if request was cancelled while waiting for response
|
||||
if (IsCancelled(requestId))
|
||||
{
|
||||
_logger.LogDebug("Ignoring late response for cancelled request {RequestId}", requestId);
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var response = FrameConverter.ToResponseFrame(responseFrame);
|
||||
if (response is null)
|
||||
{
|
||||
_logger.LogError(
|
||||
"Invalid response frame from {ServiceName} for request {RequestId}",
|
||||
decision.Connection.Instance.ServiceName,
|
||||
requestId);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
await context.Response.WriteAsJsonAsync(new { error = "Invalid upstream response" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Map response to HTTP
|
||||
context.Response.StatusCode = response.StatusCode;
|
||||
|
||||
// Copy response headers
|
||||
foreach (var (key, value) in response.Headers)
|
||||
{
|
||||
// Skip some headers that shouldn't be copied
|
||||
if (key.Equals("Transfer-Encoding", StringComparison.OrdinalIgnoreCase) ||
|
||||
key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
context.Response.Headers[key] = value;
|
||||
}
|
||||
|
||||
// Write response body
|
||||
if (response.Payload.Length > 0)
|
||||
{
|
||||
await context.Response.Body.WriteAsync(response.Payload, context.RequestAborted);
|
||||
}
|
||||
|
||||
_logger.LogDebug(
|
||||
"Request {RequestId} completed with status {StatusCode}",
|
||||
requestId,
|
||||
response.StatusCode);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Updates the connection's average ping time using exponential moving average.
|
||||
/// </summary>
|
||||
private static void UpdateConnectionPing(
|
||||
IGlobalRoutingState routingState,
|
||||
string connectionId,
|
||||
double pingMs)
|
||||
{
|
||||
const double smoothingFactor = 0.2;
|
||||
|
||||
routingState.UpdateConnection(connectionId, connection =>
|
||||
{
|
||||
if (connection.AveragePingMs == 0)
|
||||
{
|
||||
connection.AveragePingMs = pingMs;
|
||||
}
|
||||
else
|
||||
{
|
||||
connection.AveragePingMs = (1 - smoothingFactor) * connection.AveragePingMs + smoothingFactor * pingMs;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dispatches a streaming request to a microservice.
|
||||
/// </summary>
|
||||
private async Task DispatchStreamingAsync(
|
||||
HttpContext context,
|
||||
ITransportClient transportClient,
|
||||
IGlobalRoutingState routingState,
|
||||
RoutingDecision decision,
|
||||
string requestId,
|
||||
Dictionary<string, string> headers)
|
||||
{
|
||||
var requestIdGuid = Guid.TryParse(requestId, out var parsed) ? parsed : Guid.NewGuid();
|
||||
|
||||
// Build request header frame (without body - will stream)
|
||||
var requestFrame = new RequestFrame
|
||||
{
|
||||
RequestId = requestId,
|
||||
CorrelationId = context.TraceIdentifier,
|
||||
Method = context.Request.Method,
|
||||
Path = context.Request.Path.ToString() + context.Request.QueryString.ToString(),
|
||||
Headers = headers,
|
||||
Payload = Array.Empty<byte>(), // Empty - body will be streamed
|
||||
TimeoutSeconds = (int)decision.EffectiveTimeout.TotalSeconds,
|
||||
SupportsStreaming = true
|
||||
};
|
||||
|
||||
var frame = FrameConverter.ToFrame(requestFrame);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Dispatching streaming {Method} {Path} to {ServiceName}/{Version}",
|
||||
requestFrame.Method,
|
||||
requestFrame.Path,
|
||||
decision.Connection.Instance.ServiceName,
|
||||
decision.Connection.Instance.Version);
|
||||
|
||||
// Create linked cancellation token with timeout
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
|
||||
timeoutCts.CancelAfter(decision.EffectiveTimeout);
|
||||
|
||||
// Register client disconnect handler to send CANCEL
|
||||
using var clientDisconnectRegistration = context.RequestAborted.Register(() =>
|
||||
{
|
||||
MarkCancelled(requestId);
|
||||
|
||||
_ = Task.Run(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.ClientDisconnected);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Sent CANCEL for streaming request {RequestId} due to client disconnect",
|
||||
requestId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex,
|
||||
"Failed to send CANCEL for streaming request {RequestId}",
|
||||
requestId);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
var startTimestamp = Stopwatch.GetTimestamp();
|
||||
var responseReceived = false;
|
||||
|
||||
try
|
||||
{
|
||||
// Use streaming transport method
|
||||
await transportClient.SendStreamingAsync(
|
||||
decision.Connection,
|
||||
frame,
|
||||
context.Request.Body,
|
||||
async responseBodyStream =>
|
||||
{
|
||||
responseReceived = true;
|
||||
|
||||
// For now, read the response stream and write to HTTP response
|
||||
// The response headers should be set before streaming begins
|
||||
context.Response.StatusCode = StatusCodes.Status200OK;
|
||||
context.Response.Headers["Transfer-Encoding"] = "chunked";
|
||||
context.Response.ContentType = "application/octet-stream";
|
||||
|
||||
await responseBodyStream.CopyToAsync(context.Response.Body, timeoutCts.Token);
|
||||
},
|
||||
PayloadLimits.Default,
|
||||
timeoutCts.Token);
|
||||
|
||||
// Record ping latency
|
||||
var elapsed = Stopwatch.GetElapsedTime(startTimestamp);
|
||||
UpdateConnectionPing(routingState, decision.Connection.ConnectionId, elapsed.TotalMilliseconds);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Streaming request {RequestId} completed",
|
||||
requestId);
|
||||
}
|
||||
catch (OperationCanceledException) when (!context.RequestAborted.IsCancellationRequested)
|
||||
{
|
||||
// Internal timeout
|
||||
_logger.LogWarning(
|
||||
"Streaming request {RequestId} timed out after {Timeout}",
|
||||
requestId,
|
||||
decision.EffectiveTimeout);
|
||||
|
||||
MarkCancelled(requestId);
|
||||
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.Timeout);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Failed to send cancel for streaming request {RequestId}", requestId);
|
||||
}
|
||||
|
||||
if (!responseReceived)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream streaming timeout",
|
||||
service = decision.Connection.Instance.ServiceName,
|
||||
timeout = decision.EffectiveTimeout.TotalSeconds
|
||||
});
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Client disconnected
|
||||
MarkCancelled(requestId);
|
||||
_logger.LogDebug("Client disconnected, streaming request {RequestId} cancelled", requestId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Error dispatching streaming request {RequestId}",
|
||||
requestId);
|
||||
|
||||
if (!responseReceived)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream streaming error",
|
||||
message = ex.Message
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user