up
Some checks failed
AOC Guard CI / aoc-guard (push) Has been cancelled
AOC Guard CI / aoc-verify (push) Has been cancelled
Docs CI / lint-and-preview (push) Has been cancelled
Notify Smoke Test / Notify Unit Tests (push) Has been cancelled
Notify Smoke Test / Notifier Service Tests (push) Has been cancelled
Notify Smoke Test / Notification Smoke Test (push) Has been cancelled
Policy Lint & Smoke / policy-lint (push) Has been cancelled
Scanner Analyzers / Discover Analyzers (push) Has been cancelled
Scanner Analyzers / Build Analyzers (push) Has been cancelled
Scanner Analyzers / Test Language Analyzers (push) Has been cancelled
Scanner Analyzers / Validate Test Fixtures (push) Has been cancelled
Scanner Analyzers / Verify Deterministic Output (push) Has been cancelled
Signals CI & Image / signals-ci (push) Has been cancelled
Signals Reachability Scoring & Events / reachability-smoke (push) Has been cancelled
Signals Reachability Scoring & Events / sign-and-upload (push) Has been cancelled
Manifest Integrity / Validate Schema Integrity (push) Has been cancelled
Manifest Integrity / Validate Contract Documents (push) Has been cancelled
Manifest Integrity / Validate Pack Fixtures (push) Has been cancelled
Manifest Integrity / Audit SHA256SUMS Files (push) Has been cancelled
Manifest Integrity / Verify Merkle Roots (push) Has been cancelled
devportal-offline / build-offline (push) Has been cancelled
Mirror Thin Bundle Sign & Verify / mirror-sign (push) Has been cancelled

This commit is contained in:
StellaOps Bot
2025-12-13 18:08:55 +02:00
parent 6e45066e37
commit f1a39c4ce3
234 changed files with 24038 additions and 6910 deletions

View File

@@ -1,40 +0,0 @@
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Gateway.WebService.OpenApi;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Extension methods for configuring the gateway middleware pipeline.
/// </summary>
public static class ApplicationBuilderExtensions
{
/// <summary>
/// Adds the gateway router middleware pipeline.
/// </summary>
/// <param name="app">The application builder.</param>
/// <returns>The application builder for chaining.</returns>
public static IApplicationBuilder UseGatewayRouter(this IApplicationBuilder app)
{
// Resolve endpoints from routing state
app.UseMiddleware<EndpointResolutionMiddleware>();
// Make routing decisions (select instance)
app.UseMiddleware<RoutingDecisionMiddleware>();
// Dispatch to transport and return response
app.UseMiddleware<TransportDispatchMiddleware>();
return app;
}
/// <summary>
/// Maps OpenAPI endpoints to the application.
/// Should be called before UseGatewayRouter so OpenAPI requests are handled first.
/// </summary>
/// <param name="endpoints">The endpoint route builder.</param>
/// <returns>The endpoint route builder for chaining.</returns>
public static IEndpointRouteBuilder MapGatewayOpenApi(this IEndpointRouteBuilder endpoints)
{
return endpoints.MapGatewayOpenApiEndpoints();
}
}

View File

@@ -1,140 +0,0 @@
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Background service that periodically refreshes claims from Authority.
/// </summary>
internal sealed class AuthorityClaimsRefreshService : BackgroundService
{
private readonly IAuthorityClaimsProvider _claimsProvider;
private readonly IEffectiveClaimsStore _claimsStore;
private readonly AuthorityConnectionOptions _options;
private readonly ILogger<AuthorityClaimsRefreshService> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="AuthorityClaimsRefreshService"/> class.
/// </summary>
public AuthorityClaimsRefreshService(
IAuthorityClaimsProvider claimsProvider,
IEffectiveClaimsStore claimsStore,
IOptions<AuthorityConnectionOptions> options,
ILogger<AuthorityClaimsRefreshService> logger)
{
_claimsProvider = claimsProvider;
_claimsStore = claimsStore;
_options = options.Value;
_logger = logger;
}
/// <inheritdoc />
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
if (!_options.Enabled)
{
_logger.LogInformation("Authority integration is disabled");
return;
}
if (string.IsNullOrWhiteSpace(_options.AuthorityUrl))
{
_logger.LogWarning("Authority URL not configured, skipping claims refresh");
return;
}
// Subscribe to push notifications if enabled
if (_options.UseAuthorityPushNotifications)
{
_claimsProvider.OverridesChanged += OnOverridesChanged;
}
// Initial fetch with optional wait
await FetchWithRetryAsync(stoppingToken);
// Periodic refresh
while (!stoppingToken.IsCancellationRequested)
{
try
{
await Task.Delay(_options.RefreshInterval, stoppingToken);
await RefreshClaimsAsync(stoppingToken);
}
catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
{
break;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error during claims refresh");
}
}
}
private async Task FetchWithRetryAsync(CancellationToken stoppingToken)
{
if (!_options.WaitForAuthorityOnStartup)
{
await RefreshClaimsAsync(stoppingToken);
return;
}
var deadline = DateTime.UtcNow.Add(_options.StartupTimeout);
var retryDelay = TimeSpan.FromSeconds(1);
var attempt = 0;
while (DateTime.UtcNow < deadline && !stoppingToken.IsCancellationRequested)
{
attempt++;
_logger.LogDebug("Fetching claims from Authority (attempt {Attempt})", attempt);
await RefreshClaimsAsync(stoppingToken);
if (_claimsProvider.IsAvailable)
{
_logger.LogInformation(
"Successfully connected to Authority after {Attempts} attempts",
attempt);
return;
}
await Task.Delay(retryDelay, stoppingToken);
retryDelay = TimeSpan.FromSeconds(Math.Min(retryDelay.TotalSeconds * 2, 10));
}
_logger.LogWarning(
"Could not connect to Authority within {Timeout}. Proceeding without Authority claims.",
_options.StartupTimeout);
}
private async Task RefreshClaimsAsync(CancellationToken cancellationToken)
{
try
{
var overrides = await _claimsProvider.GetOverridesAsync(cancellationToken);
_claimsStore.UpdateFromAuthority(overrides);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to refresh claims from Authority");
}
}
private void OnOverridesChanged(object? sender, ClaimsOverrideChangedEventArgs e)
{
_logger.LogInformation("Received claims override update from Authority");
_claimsStore.UpdateFromAuthority(e.Overrides);
}
/// <inheritdoc />
public override void Dispose()
{
if (_options.UseAuthorityPushNotifications)
{
_claimsProvider.OverridesChanged -= OnOverridesChanged;
}
base.Dispose();
}
}

View File

@@ -1,44 +0,0 @@
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Configuration options for connecting to the Authority service.
/// </summary>
public sealed class AuthorityConnectionOptions
{
/// <summary>
/// Configuration section name.
/// </summary>
public const string SectionName = "Authority";
/// <summary>
/// Gets or sets the Authority service URL.
/// </summary>
public string AuthorityUrl { get; set; } = string.Empty;
/// <summary>
/// Gets or sets whether to wait for Authority on startup.
/// If true, the gateway will delay handling traffic until Authority is available.
/// </summary>
public bool WaitForAuthorityOnStartup { get; set; } = true;
/// <summary>
/// Gets or sets the startup timeout when waiting for Authority.
/// </summary>
public TimeSpan StartupTimeout { get; set; } = TimeSpan.FromSeconds(30);
/// <summary>
/// Gets or sets the interval at which to refresh claims from Authority.
/// </summary>
public TimeSpan RefreshInterval { get; set; } = TimeSpan.FromMinutes(5);
/// <summary>
/// Gets or sets whether to use push notifications from Authority.
/// If false, the gateway will poll at the refresh interval.
/// </summary>
public bool UseAuthorityPushNotifications { get; set; }
/// <summary>
/// Gets or sets whether Authority integration is enabled.
/// </summary>
public bool Enabled { get; set; } = true;
}

View File

@@ -1,103 +0,0 @@
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Middleware that enforces claims requirements for endpoints.
/// </summary>
public sealed class AuthorizationMiddleware
{
private readonly RequestDelegate _next;
private readonly IEffectiveClaimsStore _claimsStore;
private readonly ILogger<AuthorizationMiddleware> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="AuthorizationMiddleware"/> class.
/// </summary>
public AuthorizationMiddleware(
RequestDelegate next,
IEffectiveClaimsStore claimsStore,
ILogger<AuthorizationMiddleware> logger)
{
_next = next;
_claimsStore = claimsStore;
_logger = logger;
}
/// <summary>
/// Invokes the middleware.
/// </summary>
public async Task InvokeAsync(HttpContext context)
{
// Get resolved endpoint from earlier middleware
if (!context.Items.TryGetValue(RouterHttpContextKeys.EndpointDescriptor, out var endpointObj) ||
endpointObj is not EndpointDescriptor endpoint)
{
// No endpoint resolved, let next middleware handle
await _next(context);
return;
}
// Get effective claims for this endpoint
var effectiveClaims = _claimsStore.GetEffectiveClaims(
endpoint.ServiceName,
endpoint.Method,
endpoint.Path);
if (effectiveClaims.Count == 0)
{
// No claims required
await _next(context);
return;
}
// Check each required claim
foreach (var required in effectiveClaims)
{
var userClaims = context.User.Claims;
bool hasClaim = required.Value == null
? userClaims.Any(c => c.Type == required.Type)
: userClaims.Any(c => c.Type == required.Type && c.Value == required.Value);
if (!hasClaim)
{
_logger.LogWarning(
"Authorization failed for {Method} {Path}: user lacks claim {ClaimType}={ClaimValue}",
endpoint.Method,
endpoint.Path,
required.Type,
required.Value ?? "(any)");
context.Response.StatusCode = StatusCodes.Status403Forbidden;
context.Response.ContentType = "application/json";
await context.Response.WriteAsJsonAsync(new
{
error = "Forbidden",
message = "Authorization failed: missing required claim",
requiredClaim = new { type = required.Type, value = required.Value }
});
return;
}
}
await _next(context);
}
}
/// <summary>
/// Extension methods for registering the authorization middleware.
/// </summary>
public static class AuthorizationMiddlewareExtensions
{
/// <summary>
/// Adds the claims authorization middleware to the pipeline.
/// </summary>
/// <param name="app">The application builder.</param>
/// <returns>The application builder for chaining.</returns>
public static IApplicationBuilder UseClaimsAuthorization(this IApplicationBuilder app)
{
return app.UseMiddleware<AuthorizationMiddleware>();
}
}

View File

@@ -1,110 +0,0 @@
using System.Collections.Concurrent;
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// In-memory store for effective claims.
/// Merges microservice defaults with Authority overrides.
/// </summary>
internal sealed class EffectiveClaimsStore : IEffectiveClaimsStore
{
private readonly ConcurrentDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> _microserviceClaims = new();
private readonly ConcurrentDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> _authorityClaims = new();
private readonly ILogger<EffectiveClaimsStore> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="EffectiveClaimsStore"/> class.
/// </summary>
public EffectiveClaimsStore(ILogger<EffectiveClaimsStore> logger)
{
_logger = logger;
}
/// <inheritdoc />
public IReadOnlyList<ClaimRequirement> GetEffectiveClaims(string serviceName, string method, string path)
{
var key = EndpointKey.Create(serviceName, method, path);
// Authority takes precedence
if (_authorityClaims.TryGetValue(key, out var authorityClaims))
{
_logger.LogDebug(
"Using Authority claims for {Endpoint}: {ClaimCount} claims",
key,
authorityClaims.Count);
return authorityClaims;
}
// Fall back to microservice defaults
if (_microserviceClaims.TryGetValue(key, out var msClaims))
{
return msClaims;
}
return [];
}
/// <inheritdoc />
public void UpdateFromMicroservice(string serviceName, IReadOnlyList<EndpointDescriptor> endpoints)
{
foreach (var endpoint in endpoints)
{
var key = EndpointKey.Create(serviceName, endpoint.Method, endpoint.Path);
var claims = endpoint.RequiringClaims ?? [];
if (claims.Count > 0)
{
_microserviceClaims[key] = claims;
_logger.LogDebug(
"Registered {ClaimCount} claims from microservice for {Endpoint}",
claims.Count,
key);
}
else
{
_microserviceClaims.TryRemove(key, out _);
}
}
}
/// <inheritdoc />
public void UpdateFromAuthority(IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> overrides)
{
// Clear previous Authority claims
_authorityClaims.Clear();
// Add new Authority claims
foreach (var (key, claims) in overrides)
{
if (claims.Count > 0)
{
_authorityClaims[key] = claims;
}
}
_logger.LogInformation(
"Updated Authority claims: {EndpointCount} endpoints with overrides",
overrides.Count);
}
/// <inheritdoc />
public void RemoveService(string serviceName)
{
var normalizedServiceName = serviceName.ToLowerInvariant();
var keysToRemove = _microserviceClaims.Keys
.Where(k => k.ServiceName == normalizedServiceName)
.ToList();
foreach (var key in keysToRemove)
{
_microserviceClaims.TryRemove(key, out _);
}
_logger.LogDebug(
"Removed {Count} endpoint claims for service {ServiceName}",
keysToRemove.Count,
serviceName);
}
}

View File

@@ -1,24 +0,0 @@
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Key for identifying an endpoint by service name, method, and path.
/// </summary>
/// <param name="ServiceName">The name of the service.</param>
/// <param name="Method">The HTTP method.</param>
/// <param name="Path">The path template.</param>
public readonly record struct EndpointKey(string ServiceName, string Method, string Path)
{
/// <summary>
/// Creates an endpoint key with normalized values.
/// </summary>
public static EndpointKey Create(string serviceName, string method, string path)
{
return new EndpointKey(
serviceName.ToLowerInvariant(),
method.ToUpperInvariant(),
path.ToLowerInvariant());
}
/// <inheritdoc />
public override string ToString() => $"{ServiceName}:{Method} {Path}";
}

View File

@@ -1,133 +0,0 @@
using System.Net.Http.Json;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Fetches claims overrides from the Authority service via HTTP.
/// </summary>
internal sealed class HttpAuthorityClaimsProvider : IAuthorityClaimsProvider
{
private readonly HttpClient _httpClient;
private readonly AuthorityConnectionOptions _options;
private readonly ILogger<HttpAuthorityClaimsProvider> _logger;
private volatile bool _isAvailable;
private static readonly JsonSerializerOptions JsonOptions = new()
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
};
/// <summary>
/// Initializes a new instance of the <see cref="HttpAuthorityClaimsProvider"/> class.
/// </summary>
public HttpAuthorityClaimsProvider(
HttpClient httpClient,
IOptions<AuthorityConnectionOptions> options,
ILogger<HttpAuthorityClaimsProvider> logger)
{
_httpClient = httpClient;
_options = options.Value;
_logger = logger;
}
/// <inheritdoc />
public bool IsAvailable => _isAvailable;
/// <inheritdoc />
public event EventHandler<ClaimsOverrideChangedEventArgs>? OverridesChanged;
/// <inheritdoc />
public async Task<IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>> GetOverridesAsync(
CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(_options.AuthorityUrl))
{
_logger.LogDebug("Authority URL not configured, returning empty overrides");
_isAvailable = false;
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
}
try
{
var url = $"{_options.AuthorityUrl.TrimEnd('/')}/api/v1/claims/overrides";
_logger.LogDebug("Fetching claims overrides from {Url}", url);
var response = await _httpClient.GetAsync(url, cancellationToken);
response.EnsureSuccessStatusCode();
var overrideResponse = await response.Content.ReadFromJsonAsync<ClaimsOverrideResponse>(
JsonOptions,
cancellationToken);
if (overrideResponse?.Overrides == null)
{
_isAvailable = true;
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
}
var result = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
foreach (var entry in overrideResponse.Overrides)
{
var key = EndpointKey.Create(entry.ServiceName, entry.Method, entry.Path);
var claims = entry.RequiringClaims
.Select(c => new ClaimRequirement { Type = c.Type, Value = c.Value })
.ToList();
result[key] = claims;
}
_isAvailable = true;
_logger.LogInformation(
"Fetched {Count} claims overrides from Authority",
result.Count);
return result;
}
catch (Exception ex) when (ex is HttpRequestException or TaskCanceledException)
{
_isAvailable = false;
_logger.LogWarning(ex, "Failed to fetch claims overrides from Authority");
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
}
}
/// <summary>
/// Raises the <see cref="OverridesChanged"/> event.
/// </summary>
internal void RaiseOverridesChanged(IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> overrides)
{
OverridesChanged?.Invoke(this, new ClaimsOverrideChangedEventArgs { Overrides = overrides });
}
/// <summary>
/// DTO for claims override response from Authority.
/// </summary>
private sealed class ClaimsOverrideResponse
{
public List<ClaimsOverrideEntry> Overrides { get; set; } = [];
}
/// <summary>
/// DTO for a single claims override entry.
/// </summary>
private sealed class ClaimsOverrideEntry
{
public string ServiceName { get; set; } = string.Empty;
public string Method { get; set; } = string.Empty;
public string Path { get; set; } = string.Empty;
public List<ClaimRequirementDto> RequiringClaims { get; set; } = [];
}
/// <summary>
/// DTO for a claim requirement.
/// </summary>
private sealed class ClaimRequirementDto
{
public string Type { get; set; } = string.Empty;
public string? Value { get; set; }
}
}

View File

@@ -1,39 +0,0 @@
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Provides claims overrides from the central Authority service.
/// </summary>
public interface IAuthorityClaimsProvider
{
/// <summary>
/// Gets all claims overrides from Authority.
/// </summary>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A dictionary of endpoint keys to claim requirements.</returns>
Task<IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>> GetOverridesAsync(
CancellationToken cancellationToken);
/// <summary>
/// Gets a value indicating whether the Authority is currently available.
/// </summary>
bool IsAvailable { get; }
/// <summary>
/// Occurs when claims overrides change.
/// </summary>
event EventHandler<ClaimsOverrideChangedEventArgs>? OverridesChanged;
}
/// <summary>
/// Event arguments for claims override changes.
/// </summary>
public sealed class ClaimsOverrideChangedEventArgs : EventArgs
{
/// <summary>
/// Gets the updated claims overrides.
/// </summary>
public IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> Overrides { get; init; }
= new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
}

View File

@@ -1,40 +0,0 @@
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Stores and retrieves effective claims for endpoints.
/// Handles merging of microservice defaults with Authority overrides.
/// </summary>
public interface IEffectiveClaimsStore
{
/// <summary>
/// Gets the effective claims for an endpoint.
/// Authority overrides take precedence over microservice defaults.
/// </summary>
/// <param name="serviceName">The service name.</param>
/// <param name="method">The HTTP method.</param>
/// <param name="path">The path template.</param>
/// <returns>The effective claims for the endpoint.</returns>
IReadOnlyList<ClaimRequirement> GetEffectiveClaims(string serviceName, string method, string path);
/// <summary>
/// Updates claims from a microservice's HELLO message.
/// </summary>
/// <param name="serviceName">The service name.</param>
/// <param name="endpoints">The endpoint descriptors with claims.</param>
void UpdateFromMicroservice(string serviceName, IReadOnlyList<EndpointDescriptor> endpoints);
/// <summary>
/// Updates claims from Authority overrides.
/// </summary>
/// <param name="overrides">The Authority claims overrides.</param>
void UpdateFromAuthority(IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> overrides);
/// <summary>
/// Removes all claims for a service.
/// Called when a microservice disconnects.
/// </summary>
/// <param name="serviceName">The service name.</param>
void RemoveService(string serviceName);
}

View File

@@ -1,107 +0,0 @@
namespace StellaOps.Gateway.WebService.Authorization;
/// <summary>
/// Extension methods for registering Authority integration services.
/// </summary>
public static class AuthorizationServiceCollectionExtensions
{
/// <summary>
/// Adds Authority integration services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configuration">The configuration.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddAuthorityIntegration(
this IServiceCollection services,
IConfiguration configuration)
{
// Bind options
services.Configure<AuthorityConnectionOptions>(
configuration.GetSection(AuthorityConnectionOptions.SectionName));
// Register effective claims store
services.AddSingleton<IEffectiveClaimsStore, EffectiveClaimsStore>();
// Register HTTP client for Authority
services.AddHttpClient<IAuthorityClaimsProvider, HttpAuthorityClaimsProvider>(client =>
{
client.Timeout = TimeSpan.FromSeconds(30);
});
// Register background service for claims refresh
services.AddHostedService<AuthorityClaimsRefreshService>();
return services;
}
/// <summary>
/// Adds Authority integration services with custom options.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Action to configure Authority options.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddAuthorityIntegration(
this IServiceCollection services,
Action<AuthorityConnectionOptions>? configure = null)
{
// Register options
if (configure != null)
{
services.Configure(configure);
}
else
{
services.AddOptions<AuthorityConnectionOptions>();
}
// Register effective claims store
services.AddSingleton<IEffectiveClaimsStore, EffectiveClaimsStore>();
// Register HTTP client for Authority
services.AddHttpClient<IAuthorityClaimsProvider, HttpAuthorityClaimsProvider>(client =>
{
client.Timeout = TimeSpan.FromSeconds(30);
});
// Register background service for claims refresh
services.AddHostedService<AuthorityClaimsRefreshService>();
return services;
}
/// <summary>
/// Adds a no-op Authority integration (no external Authority).
/// Claims are only from microservices.
/// </summary>
/// <param name="services">The service collection.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddNoOpAuthorityIntegration(this IServiceCollection services)
{
services.Configure<AuthorityConnectionOptions>(options => options.Enabled = false);
services.AddSingleton<IEffectiveClaimsStore, EffectiveClaimsStore>();
services.AddSingleton<IAuthorityClaimsProvider, NoOpAuthorityClaimsProvider>();
return services;
}
}
/// <summary>
/// A no-op Authority claims provider that returns empty overrides.
/// </summary>
internal sealed class NoOpAuthorityClaimsProvider : IAuthorityClaimsProvider
{
/// <inheritdoc />
public bool IsAvailable => true;
/// <inheritdoc />
#pragma warning disable CS0067 // Event is never used (expected for no-op implementation)
public event EventHandler<ClaimsOverrideChangedEventArgs>? OverridesChanged;
#pragma warning restore CS0067
/// <inheritdoc />
public Task<IReadOnlyDictionary<EndpointKey, IReadOnlyList<StellaOps.Router.Common.Models.ClaimRequirement>>> GetOverridesAsync(
CancellationToken cancellationToken)
{
return Task.FromResult<IReadOnlyDictionary<EndpointKey, IReadOnlyList<StellaOps.Router.Common.Models.ClaimRequirement>>>(
new Dictionary<EndpointKey, IReadOnlyList<StellaOps.Router.Common.Models.ClaimRequirement>>());
}
}

View File

@@ -1,110 +0,0 @@
using Microsoft.Extensions.Logging;
using StellaOps.Gateway.WebService.OpenApi;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.InMemory;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Manages microservice connections and updates routing state.
/// </summary>
internal sealed class ConnectionManager : IHostedService
{
private readonly InMemoryTransportServer _transportServer;
private readonly InMemoryConnectionRegistry _connectionRegistry;
private readonly IGlobalRoutingState _routingState;
private readonly IGatewayOpenApiDocumentCache? _openApiCache;
private readonly ILogger<ConnectionManager> _logger;
public ConnectionManager(
InMemoryTransportServer transportServer,
InMemoryConnectionRegistry connectionRegistry,
IGlobalRoutingState routingState,
ILogger<ConnectionManager> logger,
IGatewayOpenApiDocumentCache? openApiCache = null)
{
_transportServer = transportServer;
_connectionRegistry = connectionRegistry;
_routingState = routingState;
_openApiCache = openApiCache;
_logger = logger;
}
public async Task StartAsync(CancellationToken cancellationToken)
{
// Subscribe to transport server events
_transportServer.OnHelloReceived += HandleHelloReceivedAsync;
_transportServer.OnHeartbeatReceived += HandleHeartbeatReceivedAsync;
_transportServer.OnConnectionClosed += HandleConnectionClosedAsync;
// Start the transport server
await _transportServer.StartAsync(cancellationToken);
_logger.LogInformation("Connection manager started");
}
public async Task StopAsync(CancellationToken cancellationToken)
{
await _transportServer.StopAsync(cancellationToken);
_transportServer.OnHelloReceived -= HandleHelloReceivedAsync;
_transportServer.OnHeartbeatReceived -= HandleHeartbeatReceivedAsync;
_transportServer.OnConnectionClosed -= HandleConnectionClosedAsync;
_logger.LogInformation("Connection manager stopped");
}
private Task HandleHelloReceivedAsync(ConnectionState connectionState, HelloPayload payload)
{
_logger.LogInformation(
"Connection registered: {ConnectionId} from {ServiceName}/{Version} with {EndpointCount} endpoints, {SchemaCount} schemas",
connectionState.ConnectionId,
connectionState.Instance.ServiceName,
connectionState.Instance.Version,
connectionState.Endpoints.Count,
connectionState.Schemas.Count);
// Add the connection to the routing state
_routingState.AddConnection(connectionState);
// Start listening to this connection for frames
_transportServer.StartListeningToConnection(connectionState.ConnectionId);
// Invalidate OpenAPI cache when connections change
_openApiCache?.Invalidate();
return Task.CompletedTask;
}
private Task HandleHeartbeatReceivedAsync(ConnectionState connectionState, HeartbeatPayload payload)
{
_logger.LogDebug(
"Heartbeat received from {ConnectionId}: status={Status}",
connectionState.ConnectionId,
payload.Status);
// Update connection state
_routingState.UpdateConnection(connectionState.ConnectionId, conn =>
{
conn.Status = payload.Status;
conn.LastHeartbeatUtc = DateTime.UtcNow;
});
return Task.CompletedTask;
}
private Task HandleConnectionClosedAsync(string connectionId)
{
_logger.LogInformation("Connection closed: {ConnectionId}", connectionId);
// Remove from routing state
_routingState.RemoveConnection(connectionId);
// Invalidate OpenAPI cache when connections change
_openApiCache?.Invalidate();
return Task.CompletedTask;
}
}

View File

@@ -1,256 +0,0 @@
using System.Collections.Concurrent;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Default implementation of routing plugin that provides health-aware, region-aware routing.
/// </summary>
/// <remarks>
/// Routing algorithm:
/// 1. Filter by ServiceName (exact match from endpoint)
/// 2. Filter by Version (strict semver equality when RequestedVersion specified)
/// 3. Filter by Health (Healthy preferred, Degraded as fallback)
/// 4. Group by Region Tier:
/// - Tier 0: Same region as gateway
/// - Tier 1: Configured neighbor regions
/// - Tier 2: All other regions
/// 5. Within each tier, sort by:
/// - Primary: Lower AveragePingMs
/// - Secondary: More recent LastHeartbeatUtc
/// - Tie-breaker: Random or RoundRobin
/// 6. Return first candidate from best available tier
/// 7. If none remain, return null (503 Service Unavailable)
/// </remarks>
internal sealed class DefaultRoutingPlugin : IRoutingPlugin
{
private readonly RoutingOptions _options;
private readonly GatewayNodeConfig _gatewayConfig;
private readonly ConcurrentDictionary<string, int> _roundRobinCounters = new();
/// <summary>
/// Initializes a new instance of the <see cref="DefaultRoutingPlugin"/> class.
/// </summary>
public DefaultRoutingPlugin(
IOptions<RoutingOptions> options,
IOptions<GatewayNodeConfig> gatewayConfig)
{
_options = options.Value;
_gatewayConfig = gatewayConfig.Value;
}
/// <inheritdoc />
public Task<RoutingDecision?> ChooseInstanceAsync(
RoutingContext context,
CancellationToken cancellationToken)
{
if (context.AvailableConnections.Count == 0)
{
return Task.FromResult<RoutingDecision?>(null);
}
var endpoint = context.Endpoint;
if (endpoint is null)
{
return Task.FromResult<RoutingDecision?>(null);
}
// Start with all available connections
var candidates = context.AvailableConnections.ToList();
// Filter by version if requested
candidates = FilterByVersion(candidates, context.RequestedVersion);
if (candidates.Count == 0)
{
return Task.FromResult<RoutingDecision?>(null);
}
// Filter by health status - prefer healthy, fall back to degraded
candidates = FilterByHealth(candidates);
if (candidates.Count == 0)
{
return Task.FromResult<RoutingDecision?>(null);
}
// Group by region tier and select from best available tier
var selected = SelectByRegionTier(candidates, context.GatewayRegion, endpoint.ServiceName);
if (selected is null)
{
return Task.FromResult<RoutingDecision?>(null);
}
var decision = new RoutingDecision
{
Endpoint = endpoint,
Connection = selected,
TransportType = selected.TransportType,
EffectiveTimeout = TimeSpan.FromMilliseconds(_options.RoutingTimeoutMs)
};
return Task.FromResult<RoutingDecision?>(decision);
}
private List<ConnectionState> FilterByVersion(
List<ConnectionState> candidates,
string? requestedVersion)
{
// Determine effective version to match
var versionToMatch = requestedVersion ?? _options.DefaultVersion;
// If no version specified and no default, return all candidates
if (string.IsNullOrEmpty(versionToMatch))
{
return candidates;
}
if (_options.StrictVersionMatching)
{
// Strict match: exact version equality
return candidates
.Where(c => string.Equals(c.Instance.Version, versionToMatch, StringComparison.Ordinal))
.ToList();
}
// Non-strict: allow compatible versions (for now, just exact match)
// Future: implement semver compatibility checking
return candidates
.Where(c => string.Equals(c.Instance.Version, versionToMatch, StringComparison.Ordinal))
.ToList();
}
private List<ConnectionState> FilterByHealth(List<ConnectionState> candidates)
{
// Filter to only healthy instances first
var healthy = candidates
.Where(c => c.Status == InstanceHealthStatus.Healthy)
.ToList();
if (healthy.Count > 0)
{
return healthy;
}
// If no healthy instances and degraded allowed, include degraded
if (_options.AllowDegradedInstances)
{
var degraded = candidates
.Where(c => c.Status == InstanceHealthStatus.Degraded)
.ToList();
if (degraded.Count > 0)
{
return degraded;
}
}
// No suitable instances
return [];
}
private ConnectionState? SelectByRegionTier(
List<ConnectionState> candidates,
string gatewayRegion,
string serviceName)
{
if (!_options.PreferLocalRegion || string.IsNullOrEmpty(gatewayRegion))
{
// No region preference, select from all candidates
return SelectFromTier(candidates, serviceName);
}
// Tier 0: Same region as gateway
var tier0 = candidates
.Where(c => string.Equals(c.Instance.Region, gatewayRegion, StringComparison.OrdinalIgnoreCase))
.ToList();
var selected = SelectFromTier(tier0, serviceName);
if (selected is not null)
{
return selected;
}
// Tier 1: Configured neighbor regions
var neighborRegions = _gatewayConfig.NeighborRegions;
if (neighborRegions.Count > 0)
{
var tier1 = candidates
.Where(c => neighborRegions.Contains(c.Instance.Region, StringComparer.OrdinalIgnoreCase))
.ToList();
selected = SelectFromTier(tier1, serviceName);
if (selected is not null)
{
return selected;
}
}
// Tier 2: All other regions (remaining candidates not in tier0 or tier1)
var tier2 = candidates
.Where(c => !string.Equals(c.Instance.Region, gatewayRegion, StringComparison.OrdinalIgnoreCase))
.Where(c => !neighborRegions.Contains(c.Instance.Region, StringComparer.OrdinalIgnoreCase))
.ToList();
return SelectFromTier(tier2, serviceName);
}
private ConnectionState? SelectFromTier(List<ConnectionState> tier, string serviceName)
{
if (tier.Count == 0)
{
return null;
}
if (tier.Count == 1)
{
return tier[0];
}
// Sort by ping (ascending), then by heartbeat (descending = more recent first)
var sorted = tier
.OrderBy(c => c.AveragePingMs)
.ThenByDescending(c => c.LastHeartbeatUtc)
.ToList();
var best = sorted[0];
// Find all instances "tied" with the best one
var tied = sorted
.TakeWhile(c =>
Math.Abs(c.AveragePingMs - best.AveragePingMs) <= _options.PingToleranceMs &&
c.LastHeartbeatUtc == best.LastHeartbeatUtc)
.ToList();
if (tied.Count == 1)
{
return tied[0];
}
// Apply tie-breaker
return _options.TieBreaker switch
{
TieBreakerMode.RoundRobin => SelectRoundRobin(tied, serviceName),
_ => SelectRandom(tied)
};
}
private ConnectionState SelectRandom(List<ConnectionState> candidates)
{
var index = Random.Shared.Next(candidates.Count);
return candidates[index];
}
private ConnectionState SelectRoundRobin(List<ConnectionState> candidates, string serviceName)
{
// Get or create counter for this service
var counter = _roundRobinCounters.AddOrUpdate(
serviceName,
_ => 0,
(_, current) => current + 1);
var index = counter % candidates.Count;
return candidates[index];
}
}

View File

@@ -1,55 +0,0 @@
using System.ComponentModel.DataAnnotations;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Static configuration for a gateway node.
/// </summary>
public sealed class GatewayNodeConfig
{
/// <summary>
/// Configuration section name for binding.
/// </summary>
public const string SectionName = "GatewayNode";
/// <summary>
/// Gets or sets the region where this gateway is deployed (e.g., "eu1").
/// Routing decisions use this value; it is never derived from headers or URLs.
/// </summary>
[Required(ErrorMessage = "Region is required for gateway routing")]
public string Region { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the unique identifier for this gateway node (e.g., "gw-eu1-01").
/// </summary>
public string NodeId { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the environment name (e.g., "prod", "staging", "dev").
/// </summary>
public string Environment { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the neighbor regions for fallback routing, in order of preference.
/// </summary>
public List<string> NeighborRegions { get; set; } = [];
/// <summary>
/// Validates the configuration.
/// </summary>
/// <exception cref="InvalidOperationException">Thrown when configuration is invalid.</exception>
public void Validate()
{
if (string.IsNullOrWhiteSpace(Region))
{
throw new InvalidOperationException(
$"{SectionName}:Region is required. Gateway cannot start without a region assignment.");
}
// Generate NodeId if not provided
if (string.IsNullOrWhiteSpace(NodeId))
{
NodeId = $"gw-{Region}-{Guid.NewGuid().ToString("N")[..8]}";
}
}
}

View File

@@ -1,117 +0,0 @@
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Background service that monitors connection health and marks stale instances as unhealthy.
/// </summary>
internal sealed class HealthMonitorService : BackgroundService
{
private readonly IGlobalRoutingState _routingState;
private readonly IOptions<HealthOptions> _options;
private readonly ILogger<HealthMonitorService> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="HealthMonitorService"/> class.
/// </summary>
public HealthMonitorService(
IGlobalRoutingState routingState,
IOptions<HealthOptions> options,
ILogger<HealthMonitorService> logger)
{
_routingState = routingState;
_options = options;
_logger = logger;
}
/// <inheritdoc />
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
_logger.LogInformation(
"Health monitor started. Stale threshold: {StaleThreshold}, Check interval: {CheckInterval}",
_options.Value.StaleThreshold,
_options.Value.CheckInterval);
while (!stoppingToken.IsCancellationRequested)
{
try
{
await Task.Delay(_options.Value.CheckInterval, stoppingToken);
CheckStaleConnections();
}
catch (OperationCanceledException)
{
// Expected on shutdown
break;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error in health monitor loop");
}
}
_logger.LogInformation("Health monitor stopped");
}
private void CheckStaleConnections()
{
var staleThreshold = _options.Value.StaleThreshold;
var degradedThreshold = _options.Value.DegradedThreshold;
var now = DateTime.UtcNow;
var staleCount = 0;
var degradedCount = 0;
foreach (var connection in _routingState.GetAllConnections())
{
// Skip connections that are already draining - they're intentionally stopping
if (connection.Status == InstanceHealthStatus.Draining)
{
continue;
}
var age = now - connection.LastHeartbeatUtc;
// Check for stale (no heartbeat for too long)
if (age > staleThreshold && connection.Status != InstanceHealthStatus.Unhealthy)
{
_routingState.UpdateConnection(connection.ConnectionId, c =>
c.Status = InstanceHealthStatus.Unhealthy);
_logger.LogWarning(
"Instance {InstanceId} ({ServiceName}/{Version}) marked Unhealthy: no heartbeat for {Age:g}",
connection.Instance.InstanceId,
connection.Instance.ServiceName,
connection.Instance.Version,
age);
staleCount++;
}
// Check for degraded (heartbeat delayed but not stale)
else if (age > degradedThreshold &&
connection.Status == InstanceHealthStatus.Healthy)
{
_routingState.UpdateConnection(connection.ConnectionId, c =>
c.Status = InstanceHealthStatus.Degraded);
_logger.LogWarning(
"Instance {InstanceId} ({ServiceName}/{Version}) marked Degraded: delayed heartbeat ({Age:g})",
connection.Instance.InstanceId,
connection.Instance.ServiceName,
connection.Instance.Version,
age);
degradedCount++;
}
}
if (staleCount > 0 || degradedCount > 0)
{
_logger.LogDebug(
"Health check completed: {StaleCount} stale, {DegradedCount} degraded",
staleCount,
degradedCount);
}
}
}

View File

@@ -1,36 +0,0 @@
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Configuration options for health monitoring.
/// </summary>
public sealed class HealthOptions
{
/// <summary>
/// Gets the configuration section name.
/// </summary>
public const string SectionName = "Health";
/// <summary>
/// Gets or sets the threshold after which a connection is considered stale (no heartbeat).
/// Default: 30 seconds.
/// </summary>
public TimeSpan StaleThreshold { get; set; } = TimeSpan.FromSeconds(30);
/// <summary>
/// Gets or sets the threshold after which a connection is considered degraded.
/// Default: 15 seconds.
/// </summary>
public TimeSpan DegradedThreshold { get; set; } = TimeSpan.FromSeconds(15);
/// <summary>
/// Gets or sets the interval at which to check for stale connections.
/// Default: 5 seconds.
/// </summary>
public TimeSpan CheckInterval { get; set; } = TimeSpan.FromSeconds(5);
/// <summary>
/// Gets or sets the number of ping measurements to keep for averaging.
/// Default: 10.
/// </summary>
public int PingHistorySize { get; set; } = 10;
}

View File

@@ -1,159 +0,0 @@
using System.Collections.Concurrent;
using StellaOps.Router.Common;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// In-memory implementation of global routing state.
/// </summary>
internal sealed class InMemoryRoutingState : IGlobalRoutingState
{
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new();
private readonly ConcurrentDictionary<(string Method, string Path), ConcurrentBag<string>> _endpointIndex = new();
private readonly ConcurrentDictionary<(string Method, string Path), PathMatcher> _pathMatchers = new();
private readonly object _indexLock = new();
/// <inheritdoc />
public void AddConnection(ConnectionState connection)
{
_connections[connection.ConnectionId] = connection;
// Index all endpoints
foreach (var endpoint in connection.Endpoints.Values)
{
var key = (endpoint.Method, endpoint.Path);
// Add to endpoint index
var connectionIds = _endpointIndex.GetOrAdd(key, _ => []);
connectionIds.Add(connection.ConnectionId);
// Create path matcher if not exists
_pathMatchers.GetOrAdd(key, _ => new PathMatcher(endpoint.Path));
}
}
/// <inheritdoc />
public void RemoveConnection(string connectionId)
{
if (_connections.TryRemove(connectionId, out var connection))
{
// Remove from endpoint index
foreach (var endpoint in connection.Endpoints.Values)
{
var key = (endpoint.Method, endpoint.Path);
if (_endpointIndex.TryGetValue(key, out var connectionIds))
{
// ConcurrentBag doesn't support removal, so we need to rebuild
lock (_indexLock)
{
var remaining = connectionIds.Where(id => id != connectionId).ToList();
if (remaining.Count == 0)
{
_endpointIndex.TryRemove(key, out _);
_pathMatchers.TryRemove(key, out _);
}
else
{
_endpointIndex[key] = new ConcurrentBag<string>(remaining);
}
}
}
}
}
}
/// <inheritdoc />
public void UpdateConnection(string connectionId, Action<ConnectionState> update)
{
if (_connections.TryGetValue(connectionId, out var connection))
{
update(connection);
}
}
/// <inheritdoc />
public ConnectionState? GetConnection(string connectionId)
{
return _connections.TryGetValue(connectionId, out var connection) ? connection : null;
}
/// <inheritdoc />
public IReadOnlyList<ConnectionState> GetAllConnections()
{
return [.. _connections.Values];
}
/// <inheritdoc />
public EndpointDescriptor? ResolveEndpoint(string method, string path)
{
// First try exact match
foreach (var ((m, p), matcher) in _pathMatchers)
{
if (!string.Equals(m, method, StringComparison.OrdinalIgnoreCase))
continue;
if (matcher.IsMatch(path))
{
// Get first connection with this endpoint
if (_endpointIndex.TryGetValue((m, p), out var connectionIds))
{
foreach (var connectionId in connectionIds)
{
if (_connections.TryGetValue(connectionId, out var conn) &&
conn.Endpoints.TryGetValue((m, p), out var endpoint))
{
return endpoint;
}
}
}
}
}
return null;
}
/// <inheritdoc />
public IReadOnlyList<ConnectionState> GetConnectionsFor(
string serviceName,
string version,
string method,
string path)
{
var result = new List<ConnectionState>();
foreach (var ((m, p), matcher) in _pathMatchers)
{
if (!string.Equals(m, method, StringComparison.OrdinalIgnoreCase))
continue;
if (!matcher.IsMatch(path))
continue;
if (!_endpointIndex.TryGetValue((m, p), out var connectionIds))
continue;
foreach (var connectionId in connectionIds)
{
if (!_connections.TryGetValue(connectionId, out var conn))
continue;
// Filter by service name and version
if (!string.Equals(conn.Instance.ServiceName, serviceName, StringComparison.OrdinalIgnoreCase))
continue;
if (!string.Equals(conn.Instance.Version, version, StringComparison.Ordinal))
continue;
// Check endpoint exists
if (conn.Endpoints.ContainsKey((m, p)))
{
result.Add(conn);
}
}
}
return result;
}
}

View File

@@ -1,135 +0,0 @@
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// A stream wrapper that counts bytes read and enforces a limit.
/// </summary>
public sealed class ByteCountingStream : Stream
{
private readonly Stream _inner;
private readonly long _limit;
private readonly Action? _onLimitExceeded;
private long _bytesRead;
private bool _disposed;
/// <summary>
/// Initializes a new instance of the <see cref="ByteCountingStream"/> class.
/// </summary>
/// <param name="inner">The inner stream to wrap.</param>
/// <param name="limit">The maximum number of bytes that can be read.</param>
/// <param name="onLimitExceeded">Optional callback invoked when the limit is exceeded.</param>
public ByteCountingStream(Stream inner, long limit, Action? onLimitExceeded = null)
{
_inner = inner;
_limit = limit;
_onLimitExceeded = onLimitExceeded;
}
/// <summary>
/// Gets the total number of bytes read from this stream.
/// </summary>
public long BytesRead => Interlocked.Read(ref _bytesRead);
/// <inheritdoc />
public override bool CanRead => _inner.CanRead;
/// <inheritdoc />
public override bool CanSeek => false;
/// <inheritdoc />
public override bool CanWrite => false;
/// <inheritdoc />
public override long Length => _inner.Length;
/// <inheritdoc />
public override long Position
{
get => _inner.Position;
set => throw new NotSupportedException("Seeking not supported on ByteCountingStream.");
}
/// <inheritdoc />
public override void Flush() => _inner.Flush();
/// <inheritdoc />
public override Task FlushAsync(CancellationToken cancellationToken) =>
_inner.FlushAsync(cancellationToken);
/// <inheritdoc />
public override int Read(byte[] buffer, int offset, int count)
{
var read = _inner.Read(buffer, offset, count);
CheckLimit(read);
return read;
}
/// <inheritdoc />
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
CheckLimit(read);
return read;
}
/// <inheritdoc />
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
var read = await _inner.ReadAsync(buffer, cancellationToken);
CheckLimit(read);
return read;
}
/// <inheritdoc />
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException("Seeking not supported on ByteCountingStream.");
}
/// <inheritdoc />
public override void SetLength(long value)
{
throw new NotSupportedException("Setting length not supported on ByteCountingStream.");
}
/// <inheritdoc />
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException("Writing not supported on ByteCountingStream.");
}
private void CheckLimit(int bytesJustRead)
{
if (bytesJustRead <= 0) return;
var newTotal = Interlocked.Add(ref _bytesRead, bytesJustRead);
if (newTotal > _limit)
{
_onLimitExceeded?.Invoke();
throw new PayloadLimitExceededException(newTotal, _limit);
}
}
/// <inheritdoc />
protected override void Dispose(bool disposing)
{
if (!_disposed && disposing)
{
_inner.Dispose();
}
_disposed = true;
base.Dispose(disposing);
}
/// <inheritdoc />
public override async ValueTask DisposeAsync()
{
if (!_disposed)
{
await _inner.DisposeAsync();
}
_disposed = true;
await base.DisposeAsync();
}
}

View File

@@ -1,44 +0,0 @@
using StellaOps.Router.Common.Abstractions;
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// Resolves incoming HTTP requests to endpoint descriptors using the routing state.
/// </summary>
public sealed class EndpointResolutionMiddleware
{
private readonly RequestDelegate _next;
/// <summary>
/// Initializes a new instance of the <see cref="EndpointResolutionMiddleware"/> class.
/// </summary>
public EndpointResolutionMiddleware(RequestDelegate next)
{
_next = next;
}
/// <summary>
/// Invokes the middleware.
/// </summary>
public async Task Invoke(HttpContext context, IGlobalRoutingState routingState)
{
var method = context.Request.Method;
var path = context.Request.Path.ToString();
var endpoint = routingState.ResolveEndpoint(method, path);
if (endpoint is null)
{
context.Response.StatusCode = StatusCodes.Status404NotFound;
await context.Response.WriteAsJsonAsync(new
{
error = "Endpoint not found",
method,
path
});
return;
}
context.Items[RouterHttpContextKeys.EndpointDescriptor] = endpoint;
await _next(context);
}
}

View File

@@ -1,29 +0,0 @@
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// Exception thrown when a payload limit is exceeded during streaming.
/// </summary>
public sealed class PayloadLimitExceededException : Exception
{
/// <summary>
/// Initializes a new instance of the <see cref="PayloadLimitExceededException"/> class.
/// </summary>
/// <param name="bytesRead">The number of bytes read before the limit was exceeded.</param>
/// <param name="limit">The limit that was exceeded.</param>
public PayloadLimitExceededException(long bytesRead, long limit)
: base($"Payload limit exceeded: {bytesRead} bytes read, limit is {limit} bytes")
{
BytesRead = bytesRead;
Limit = limit;
}
/// <summary>
/// Gets the number of bytes read before the limit was exceeded.
/// </summary>
public long BytesRead { get; }
/// <summary>
/// Gets the limit that was exceeded.
/// </summary>
public long Limit { get; }
}

View File

@@ -1,162 +0,0 @@
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// Middleware that enforces payload limits per-request, per-connection, and aggregate.
/// </summary>
public sealed class PayloadLimitsMiddleware
{
private readonly RequestDelegate _next;
private readonly PayloadLimits _limits;
private readonly ILogger<PayloadLimitsMiddleware> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="PayloadLimitsMiddleware"/> class.
/// </summary>
public PayloadLimitsMiddleware(
RequestDelegate next,
IOptions<PayloadLimits> limits,
ILogger<PayloadLimitsMiddleware> logger)
{
_next = next;
_limits = limits.Value;
_logger = logger;
}
/// <summary>
/// Invokes the middleware.
/// </summary>
public async Task Invoke(HttpContext context, IPayloadTracker tracker)
{
var connectionId = context.Connection.Id;
var contentLength = context.Request.ContentLength ?? 0;
// Early rejection for known oversized Content-Length (LIM-002, LIM-003)
if (context.Request.ContentLength.HasValue &&
context.Request.ContentLength.Value > _limits.MaxRequestBytesPerCall)
{
_logger.LogWarning(
"Request rejected: Content-Length {ContentLength} exceeds per-call limit {Limit}. ConnectionId: {ConnectionId}",
context.Request.ContentLength.Value,
_limits.MaxRequestBytesPerCall,
connectionId);
context.Response.StatusCode = StatusCodes.Status413PayloadTooLarge;
await context.Response.WriteAsJsonAsync(new
{
error = "Payload Too Large",
maxBytes = _limits.MaxRequestBytesPerCall,
contentLength = context.Request.ContentLength.Value
});
return;
}
// Try to reserve capacity (checks aggregate and per-connection limits)
if (!tracker.TryReserve(connectionId, contentLength))
{
// Check which limit was hit
if (tracker.IsOverloaded)
{
// Aggregate limit exceeded (LIM-033)
_logger.LogWarning(
"Request rejected: Aggregate limit exceeded. Current inflight: {Current}, Limit: {Limit}. ConnectionId: {ConnectionId}",
tracker.CurrentInflightBytes,
_limits.MaxAggregateInflightBytes,
connectionId);
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
await context.Response.WriteAsJsonAsync(new
{
error = "Service Overloaded",
message = "Too many concurrent requests"
});
}
else
{
// Per-connection limit exceeded (LIM-022)
_logger.LogWarning(
"Request rejected: Per-connection limit exceeded. ConnectionId: {ConnectionId}, Current: {Current}, Limit: {Limit}",
connectionId,
tracker.GetConnectionInflightBytes(connectionId),
_limits.MaxRequestBytesPerConnection);
context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
await context.Response.WriteAsJsonAsync(new
{
error = "Too Many Requests",
message = "Per-connection limit exceeded"
});
}
return;
}
// Store the original body stream
var originalBody = context.Request.Body;
long actualBytesRead = 0;
try
{
// Wrap the request body with ByteCountingStream for streaming requests
if (!context.Request.ContentLength.HasValue || context.Request.ContentLength.Value > 0)
{
var countingStream = new ByteCountingStream(
originalBody,
_limits.MaxRequestBytesPerCall,
() =>
{
_logger.LogWarning(
"Mid-stream limit exceeded. ConnectionId: {ConnectionId}, Limit: {Limit}",
connectionId,
_limits.MaxRequestBytesPerCall);
});
context.Request.Body = countingStream;
// Store reference for later access to bytes read
context.Items["PayloadLimits:CountingStream"] = countingStream;
}
await _next(context);
// Get actual bytes read
if (context.Items["PayloadLimits:CountingStream"] is ByteCountingStream cs)
{
actualBytesRead = cs.BytesRead;
}
}
catch (PayloadLimitExceededException ex)
{
_logger.LogWarning(
"Payload limit exceeded mid-stream. ConnectionId: {ConnectionId}, BytesRead: {BytesRead}, Limit: {Limit}",
connectionId,
ex.BytesRead,
ex.Limit);
// Only set response if not already started
if (!context.Response.HasStarted)
{
context.Response.StatusCode = StatusCodes.Status413PayloadTooLarge;
await context.Response.WriteAsJsonAsync(new
{
error = "Payload Too Large",
maxBytes = _limits.MaxRequestBytesPerCall,
bytesReceived = ex.BytesRead
});
}
actualBytesRead = ex.BytesRead;
}
finally
{
// Restore original body stream
context.Request.Body = originalBody;
// Release reserved capacity
var bytesToRelease = actualBytesRead > 0 ? actualBytesRead : contentLength;
tracker.Release(connectionId, bytesToRelease);
}
}
}

View File

@@ -1,127 +0,0 @@
using System.Collections.Concurrent;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// Tracks payload bytes across requests, connections, and globally.
/// </summary>
public interface IPayloadTracker
{
/// <summary>
/// Tries to reserve capacity for an estimated payload size.
/// </summary>
/// <param name="connectionId">The connection identifier.</param>
/// <param name="estimatedBytes">The estimated bytes to reserve.</param>
/// <returns>True if capacity was reserved; false if limits would be exceeded.</returns>
bool TryReserve(string connectionId, long estimatedBytes);
/// <summary>
/// Releases previously reserved capacity.
/// </summary>
/// <param name="connectionId">The connection identifier.</param>
/// <param name="actualBytes">The actual bytes to release.</param>
void Release(string connectionId, long actualBytes);
/// <summary>
/// Gets the current total inflight bytes across all connections.
/// </summary>
long CurrentInflightBytes { get; }
/// <summary>
/// Gets a value indicating whether the system is overloaded.
/// </summary>
bool IsOverloaded { get; }
/// <summary>
/// Gets the current inflight bytes for a specific connection.
/// </summary>
/// <param name="connectionId">The connection identifier.</param>
/// <returns>The current inflight bytes for the connection.</returns>
long GetConnectionInflightBytes(string connectionId);
}
/// <summary>
/// Default implementation of <see cref="IPayloadTracker"/>.
/// </summary>
public sealed class PayloadTracker : IPayloadTracker
{
private readonly PayloadLimits _limits;
private readonly ILogger<PayloadTracker> _logger;
private long _totalInflightBytes;
private readonly ConcurrentDictionary<string, long> _perConnectionBytes = new();
/// <summary>
/// Initializes a new instance of the <see cref="PayloadTracker"/> class.
/// </summary>
public PayloadTracker(IOptions<PayloadLimits> limits, ILogger<PayloadTracker> logger)
{
_limits = limits.Value;
_logger = logger;
}
/// <inheritdoc />
public long CurrentInflightBytes => Interlocked.Read(ref _totalInflightBytes);
/// <inheritdoc />
public bool IsOverloaded => CurrentInflightBytes > _limits.MaxAggregateInflightBytes;
/// <inheritdoc />
public bool TryReserve(string connectionId, long estimatedBytes)
{
// Check aggregate limit
var newTotal = Interlocked.Add(ref _totalInflightBytes, estimatedBytes);
if (newTotal > _limits.MaxAggregateInflightBytes)
{
Interlocked.Add(ref _totalInflightBytes, -estimatedBytes);
_logger.LogWarning(
"Aggregate payload limit exceeded. Current: {Current}, Limit: {Limit}",
newTotal - estimatedBytes,
_limits.MaxAggregateInflightBytes);
return false;
}
// Check per-connection limit
var connectionBytes = _perConnectionBytes.AddOrUpdate(
connectionId,
estimatedBytes,
(_, current) => current + estimatedBytes);
if (connectionBytes > _limits.MaxRequestBytesPerConnection)
{
// Roll back
_perConnectionBytes.AddOrUpdate(
connectionId,
0,
(_, current) => current - estimatedBytes);
Interlocked.Add(ref _totalInflightBytes, -estimatedBytes);
_logger.LogWarning(
"Per-connection payload limit exceeded for {ConnectionId}. Current: {Current}, Limit: {Limit}",
connectionId,
connectionBytes - estimatedBytes,
_limits.MaxRequestBytesPerConnection);
return false;
}
return true;
}
/// <inheritdoc />
public void Release(string connectionId, long actualBytes)
{
Interlocked.Add(ref _totalInflightBytes, -actualBytes);
_perConnectionBytes.AddOrUpdate(
connectionId,
0,
(_, current) => Math.Max(0, current - actualBytes));
}
/// <inheritdoc />
public long GetConnectionInflightBytes(string connectionId)
{
return _perConnectionBytes.TryGetValue(connectionId, out var bytes) ? bytes : 0;
}
}

View File

@@ -1,107 +0,0 @@
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// Makes routing decisions for resolved endpoints.
/// </summary>
public sealed class RoutingDecisionMiddleware
{
private readonly RequestDelegate _next;
/// <summary>
/// Initializes a new instance of the <see cref="RoutingDecisionMiddleware"/> class.
/// </summary>
public RoutingDecisionMiddleware(RequestDelegate next)
{
_next = next;
}
/// <summary>
/// Invokes the middleware.
/// </summary>
public async Task Invoke(
HttpContext context,
IRoutingPlugin routingPlugin,
IGlobalRoutingState routingState,
IOptions<GatewayNodeConfig> gatewayConfig,
IOptions<RoutingOptions> routingOptions)
{
var endpoint = context.Items[RouterHttpContextKeys.EndpointDescriptor] as EndpointDescriptor;
if (endpoint is null)
{
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
await context.Response.WriteAsJsonAsync(new { error = "Endpoint descriptor missing" });
return;
}
// Build routing context
var availableConnections = routingState.GetConnectionsFor(
endpoint.ServiceName,
endpoint.Version,
endpoint.Method,
endpoint.Path);
var headers = context.Request.Headers
.ToDictionary(h => h.Key, h => h.Value.ToString());
var routingContext = new RoutingContext
{
Method = context.Request.Method,
Path = context.Request.Path.ToString(),
Headers = headers,
Endpoint = endpoint,
AvailableConnections = availableConnections,
GatewayRegion = gatewayConfig.Value.Region,
RequestedVersion = ExtractVersionFromRequest(context, routingOptions.Value),
CancellationToken = context.RequestAborted
};
var decision = await routingPlugin.ChooseInstanceAsync(
routingContext,
context.RequestAborted);
if (decision is null)
{
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
await context.Response.WriteAsJsonAsync(new
{
error = "No instances available",
service = endpoint.ServiceName,
version = endpoint.Version
});
return;
}
context.Items[RouterHttpContextKeys.RoutingDecision] = decision;
await _next(context);
}
private static string? ExtractVersionFromRequest(HttpContext context, RoutingOptions options)
{
// Check for version in Accept header: Accept: application/vnd.stellaops.v1+json
var acceptHeader = context.Request.Headers.Accept.FirstOrDefault();
if (!string.IsNullOrEmpty(acceptHeader))
{
var versionMatch = System.Text.RegularExpressions.Regex.Match(
acceptHeader,
@"application/vnd\.stellaops\.v(\d+(?:\.\d+)*)\+json");
if (versionMatch.Success)
{
return versionMatch.Groups[1].Value;
}
}
// Check for X-Api-Version header
var versionHeader = context.Request.Headers["X-Api-Version"].FirstOrDefault();
if (!string.IsNullOrEmpty(versionHeader))
{
return versionHeader;
}
// Fall back to default version from options
return options.DefaultVersion;
}
}

View File

@@ -1,457 +0,0 @@
using System.Collections.Concurrent;
using System.Diagnostics;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Frames;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.Middleware;
/// <summary>
/// Dispatches HTTP requests to microservices via the transport layer.
/// </summary>
public sealed class TransportDispatchMiddleware
{
private readonly RequestDelegate _next;
private readonly ILogger<TransportDispatchMiddleware> _logger;
/// <summary>
/// Tracks cancelled request IDs to ignore late responses.
/// Keys expire after 60 seconds to prevent memory leaks.
/// </summary>
private static readonly ConcurrentDictionary<string, DateTimeOffset> CancelledRequests = new();
/// <summary>
/// Initializes a new instance of the <see cref="TransportDispatchMiddleware"/> class.
/// </summary>
public TransportDispatchMiddleware(RequestDelegate next, ILogger<TransportDispatchMiddleware> logger)
{
_next = next;
_logger = logger;
// Start background cleanup task for expired cancelled request entries
_ = Task.Run(CleanupExpiredCancelledRequestsAsync);
}
private static async Task CleanupExpiredCancelledRequestsAsync()
{
while (true)
{
await Task.Delay(TimeSpan.FromSeconds(30));
var cutoff = DateTimeOffset.UtcNow.AddSeconds(-60);
foreach (var kvp in CancelledRequests)
{
if (kvp.Value < cutoff)
{
CancelledRequests.TryRemove(kvp.Key, out _);
}
}
}
}
private static void MarkCancelled(string requestId)
{
CancelledRequests[requestId] = DateTimeOffset.UtcNow;
}
private static bool IsCancelled(string requestId)
{
return CancelledRequests.ContainsKey(requestId);
}
/// <summary>
/// Invokes the middleware.
/// </summary>
public async Task Invoke(
HttpContext context,
ITransportClient transportClient,
IGlobalRoutingState routingState)
{
var decision = context.Items[RouterHttpContextKeys.RoutingDecision] as RoutingDecision;
if (decision is null)
{
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
await context.Response.WriteAsJsonAsync(new { error = "Routing decision missing" });
return;
}
var requestId = Guid.NewGuid().ToString("N");
// Extract headers (exclude some internal headers)
var headers = context.Request.Headers
.Where(h => !h.Key.StartsWith(":", StringComparison.Ordinal))
.ToDictionary(
h => h.Key,
h => h.Value.ToString());
// For streaming endpoints, use streaming dispatch
if (decision.Endpoint.SupportsStreaming)
{
await DispatchStreamingAsync(context, transportClient, routingState, decision, requestId, headers);
return;
}
// Read request body (buffered)
byte[] bodyBytes;
using (var ms = new MemoryStream())
{
await context.Request.Body.CopyToAsync(ms, context.RequestAborted);
bodyBytes = ms.ToArray();
}
// Build request frame
var requestFrame = new RequestFrame
{
RequestId = requestId,
CorrelationId = context.TraceIdentifier,
Method = context.Request.Method,
Path = context.Request.Path.ToString() + context.Request.QueryString.ToString(),
Headers = headers,
Payload = bodyBytes,
TimeoutSeconds = (int)decision.EffectiveTimeout.TotalSeconds,
SupportsStreaming = false
};
var frame = FrameConverter.ToFrame(requestFrame);
_logger.LogDebug(
"Dispatching {Method} {Path} to {ServiceName}/{Version} via {TransportType}",
requestFrame.Method,
requestFrame.Path,
decision.Connection.Instance.ServiceName,
decision.Connection.Instance.Version,
decision.TransportType);
// Create linked cancellation token with timeout
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
timeoutCts.CancelAfter(decision.EffectiveTimeout);
// Register client disconnect handler to send CANCEL
var requestIdGuid = Guid.TryParse(requestId, out var parsed) ? parsed : Guid.NewGuid();
using var clientDisconnectRegistration = context.RequestAborted.Register(() =>
{
// Mark as cancelled to ignore late responses
MarkCancelled(requestId);
// Send CANCEL frame (fire and forget)
_ = Task.Run(async () =>
{
try
{
await transportClient.SendCancelAsync(
decision.Connection,
requestIdGuid,
CancelReasons.ClientDisconnected);
_logger.LogDebug(
"Sent CANCEL for request {RequestId} due to client disconnect",
requestId);
}
catch (Exception ex)
{
_logger.LogWarning(ex,
"Failed to send CANCEL for request {RequestId} on client disconnect",
requestId);
}
});
});
Frame responseFrame;
var startTimestamp = Stopwatch.GetTimestamp();
try
{
responseFrame = await transportClient.SendRequestAsync(
decision.Connection,
frame,
decision.EffectiveTimeout,
timeoutCts.Token);
// Record ping latency and update connection's average
var elapsed = Stopwatch.GetElapsedTime(startTimestamp);
UpdateConnectionPing(routingState, decision.Connection.ConnectionId, elapsed.TotalMilliseconds);
}
catch (OperationCanceledException) when (!context.RequestAborted.IsCancellationRequested)
{
// Internal timeout (not client disconnect)
_logger.LogWarning(
"Request {RequestId} to {ServiceName} timed out after {Timeout}",
requestId,
decision.Connection.Instance.ServiceName,
decision.EffectiveTimeout);
// Mark as cancelled to ignore late responses
MarkCancelled(requestId);
// Send cancel to microservice
try
{
await transportClient.SendCancelAsync(
decision.Connection,
requestIdGuid,
CancelReasons.Timeout);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to send cancel for request {RequestId}", requestId);
}
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
await context.Response.WriteAsJsonAsync(new
{
error = "Upstream timeout",
service = decision.Connection.Instance.ServiceName,
timeout = decision.EffectiveTimeout.TotalSeconds
});
return;
}
catch (OperationCanceledException)
{
// Client disconnected - cancel already sent via registration above
MarkCancelled(requestId);
_logger.LogDebug("Client disconnected, request {RequestId} cancelled", requestId);
return;
}
catch (Exception ex)
{
_logger.LogError(ex,
"Error dispatching request {RequestId} to {ServiceName}",
requestId,
decision.Connection.Instance.ServiceName);
context.Response.StatusCode = StatusCodes.Status502BadGateway;
await context.Response.WriteAsJsonAsync(new
{
error = "Upstream error",
message = ex.Message
});
return;
}
// Check if request was cancelled while waiting for response
if (IsCancelled(requestId))
{
_logger.LogDebug("Ignoring late response for cancelled request {RequestId}", requestId);
return;
}
// Parse response
var response = FrameConverter.ToResponseFrame(responseFrame);
if (response is null)
{
_logger.LogError(
"Invalid response frame from {ServiceName} for request {RequestId}",
decision.Connection.Instance.ServiceName,
requestId);
context.Response.StatusCode = StatusCodes.Status502BadGateway;
await context.Response.WriteAsJsonAsync(new { error = "Invalid upstream response" });
return;
}
// Map response to HTTP
context.Response.StatusCode = response.StatusCode;
// Copy response headers
foreach (var (key, value) in response.Headers)
{
// Skip some headers that shouldn't be copied
if (key.Equals("Transfer-Encoding", StringComparison.OrdinalIgnoreCase) ||
key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase))
{
continue;
}
context.Response.Headers[key] = value;
}
// Write response body
if (response.Payload.Length > 0)
{
await context.Response.Body.WriteAsync(response.Payload, context.RequestAborted);
}
_logger.LogDebug(
"Request {RequestId} completed with status {StatusCode}",
requestId,
response.StatusCode);
}
/// <summary>
/// Updates the connection's average ping time using exponential moving average.
/// </summary>
private static void UpdateConnectionPing(
IGlobalRoutingState routingState,
string connectionId,
double pingMs)
{
const double smoothingFactor = 0.2;
routingState.UpdateConnection(connectionId, connection =>
{
if (connection.AveragePingMs == 0)
{
connection.AveragePingMs = pingMs;
}
else
{
connection.AveragePingMs = (1 - smoothingFactor) * connection.AveragePingMs + smoothingFactor * pingMs;
}
});
}
/// <summary>
/// Dispatches a streaming request to a microservice.
/// </summary>
private async Task DispatchStreamingAsync(
HttpContext context,
ITransportClient transportClient,
IGlobalRoutingState routingState,
RoutingDecision decision,
string requestId,
Dictionary<string, string> headers)
{
var requestIdGuid = Guid.TryParse(requestId, out var parsed) ? parsed : Guid.NewGuid();
// Build request header frame (without body - will stream)
var requestFrame = new RequestFrame
{
RequestId = requestId,
CorrelationId = context.TraceIdentifier,
Method = context.Request.Method,
Path = context.Request.Path.ToString() + context.Request.QueryString.ToString(),
Headers = headers,
Payload = Array.Empty<byte>(), // Empty - body will be streamed
TimeoutSeconds = (int)decision.EffectiveTimeout.TotalSeconds,
SupportsStreaming = true
};
var frame = FrameConverter.ToFrame(requestFrame);
_logger.LogDebug(
"Dispatching streaming {Method} {Path} to {ServiceName}/{Version}",
requestFrame.Method,
requestFrame.Path,
decision.Connection.Instance.ServiceName,
decision.Connection.Instance.Version);
// Create linked cancellation token with timeout
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
timeoutCts.CancelAfter(decision.EffectiveTimeout);
// Register client disconnect handler to send CANCEL
using var clientDisconnectRegistration = context.RequestAborted.Register(() =>
{
MarkCancelled(requestId);
_ = Task.Run(async () =>
{
try
{
await transportClient.SendCancelAsync(
decision.Connection,
requestIdGuid,
CancelReasons.ClientDisconnected);
_logger.LogDebug(
"Sent CANCEL for streaming request {RequestId} due to client disconnect",
requestId);
}
catch (Exception ex)
{
_logger.LogWarning(ex,
"Failed to send CANCEL for streaming request {RequestId}",
requestId);
}
});
});
var startTimestamp = Stopwatch.GetTimestamp();
var responseReceived = false;
try
{
// Use streaming transport method
await transportClient.SendStreamingAsync(
decision.Connection,
frame,
context.Request.Body,
async responseBodyStream =>
{
responseReceived = true;
// For now, read the response stream and write to HTTP response
// The response headers should be set before streaming begins
context.Response.StatusCode = StatusCodes.Status200OK;
context.Response.Headers["Transfer-Encoding"] = "chunked";
context.Response.ContentType = "application/octet-stream";
await responseBodyStream.CopyToAsync(context.Response.Body, timeoutCts.Token);
},
PayloadLimits.Default,
timeoutCts.Token);
// Record ping latency
var elapsed = Stopwatch.GetElapsedTime(startTimestamp);
UpdateConnectionPing(routingState, decision.Connection.ConnectionId, elapsed.TotalMilliseconds);
_logger.LogDebug(
"Streaming request {RequestId} completed",
requestId);
}
catch (OperationCanceledException) when (!context.RequestAborted.IsCancellationRequested)
{
// Internal timeout
_logger.LogWarning(
"Streaming request {RequestId} timed out after {Timeout}",
requestId,
decision.EffectiveTimeout);
MarkCancelled(requestId);
try
{
await transportClient.SendCancelAsync(
decision.Connection,
requestIdGuid,
CancelReasons.Timeout);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to send cancel for streaming request {RequestId}", requestId);
}
if (!responseReceived)
{
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
await context.Response.WriteAsJsonAsync(new
{
error = "Upstream streaming timeout",
service = decision.Connection.Instance.ServiceName,
timeout = decision.EffectiveTimeout.TotalSeconds
});
}
}
catch (OperationCanceledException)
{
// Client disconnected
MarkCancelled(requestId);
_logger.LogDebug("Client disconnected, streaming request {RequestId} cancelled", requestId);
}
catch (Exception ex)
{
_logger.LogError(ex,
"Error dispatching streaming request {RequestId}",
requestId);
if (!responseReceived)
{
context.Response.StatusCode = StatusCodes.Status502BadGateway;
await context.Response.WriteAsJsonAsync(new
{
error = "Upstream streaming error",
message = ex.Message
});
}
}
}
}

View File

@@ -1,106 +0,0 @@
using System.Text.Json.Nodes;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Maps claim requirements to OpenAPI security schemes.
/// </summary>
internal static class ClaimSecurityMapper
{
/// <summary>
/// Generates security schemes from claim requirements.
/// </summary>
/// <param name="endpoints">All endpoint descriptors.</param>
/// <param name="tokenUrl">The OAuth2 token URL.</param>
/// <returns>Security schemes JSON object.</returns>
public static JsonObject GenerateSecuritySchemes(
IEnumerable<EndpointDescriptor> endpoints,
string tokenUrl)
{
var schemes = new JsonObject();
// Always add BearerAuth scheme
schemes["BearerAuth"] = new JsonObject
{
["type"] = "http",
["scheme"] = "bearer",
["bearerFormat"] = "JWT",
["description"] = "JWT Bearer token authentication"
};
// Collect all unique scopes from claims
var scopes = new Dictionary<string, string>();
foreach (var endpoint in endpoints)
{
foreach (var claim in endpoint.RequiringClaims)
{
var scope = claim.Type;
if (!scopes.ContainsKey(scope))
{
scopes[scope] = $"Access scope: {scope}";
}
}
}
// Add OAuth2 scheme if there are any scopes
if (scopes.Count > 0)
{
var scopesObject = new JsonObject();
foreach (var (scope, description) in scopes)
{
scopesObject[scope] = description;
}
schemes["OAuth2"] = new JsonObject
{
["type"] = "oauth2",
["flows"] = new JsonObject
{
["clientCredentials"] = new JsonObject
{
["tokenUrl"] = tokenUrl,
["scopes"] = scopesObject
}
}
};
}
return schemes;
}
/// <summary>
/// Generates security requirement for an endpoint.
/// </summary>
/// <param name="endpoint">The endpoint descriptor.</param>
/// <returns>Security requirement JSON array.</returns>
public static JsonArray GenerateSecurityRequirement(EndpointDescriptor endpoint)
{
var requirements = new JsonArray();
if (endpoint.RequiringClaims.Count == 0)
{
return requirements;
}
var requirement = new JsonObject();
// Always require BearerAuth
requirement["BearerAuth"] = new JsonArray();
// Add OAuth2 scopes
var scopes = new JsonArray();
foreach (var claim in endpoint.RequiringClaims)
{
scopes.Add(claim.Type);
}
if (scopes.Count > 0)
{
requirement["OAuth2"] = scopes;
}
requirements.Add(requirement);
return requirements;
}
}

View File

@@ -1,69 +0,0 @@
using System.Security.Cryptography;
using System.Text;
using Microsoft.Extensions.Options;
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Caches the generated OpenAPI document with TTL-based expiration.
/// </summary>
internal sealed class GatewayOpenApiDocumentCache : IGatewayOpenApiDocumentCache
{
private readonly IOpenApiDocumentGenerator _generator;
private readonly OpenApiAggregationOptions _options;
private readonly object _lock = new();
private string? _cachedDocument;
private string? _cachedETag;
private DateTime _generatedAt;
private bool _invalidated = true;
public GatewayOpenApiDocumentCache(
IOpenApiDocumentGenerator generator,
IOptions<OpenApiAggregationOptions> options)
{
_generator = generator;
_options = options.Value;
}
/// <inheritdoc />
public (string DocumentJson, string ETag, DateTime GeneratedAt) GetDocument()
{
lock (_lock)
{
var now = DateTime.UtcNow;
var ttl = TimeSpan.FromSeconds(_options.CacheTtlSeconds);
// Check if we need to regenerate
if (_invalidated || _cachedDocument is null || now - _generatedAt > ttl)
{
Regenerate();
}
return (_cachedDocument!, _cachedETag!, _generatedAt);
}
}
/// <inheritdoc />
public void Invalidate()
{
lock (_lock)
{
_invalidated = true;
}
}
private void Regenerate()
{
_cachedDocument = _generator.GenerateDocument();
_cachedETag = ComputeETag(_cachedDocument);
_generatedAt = DateTime.UtcNow;
_invalidated = false;
}
private static string ComputeETag(string content)
{
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(content));
return $"\"{Convert.ToHexString(hash)[..16]}\"";
}
}

View File

@@ -1,18 +0,0 @@
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Caches the generated OpenAPI document with TTL-based expiration.
/// </summary>
public interface IGatewayOpenApiDocumentCache
{
/// <summary>
/// Gets the cached document or regenerates if expired.
/// </summary>
/// <returns>A tuple containing the document JSON, ETag, and generation timestamp.</returns>
(string DocumentJson, string ETag, DateTime GeneratedAt) GetDocument();
/// <summary>
/// Invalidates the cache, forcing regeneration on next access.
/// </summary>
void Invalidate();
}

View File

@@ -1,13 +0,0 @@
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Generates OpenAPI documents from aggregated microservice schemas.
/// </summary>
public interface IOpenApiDocumentGenerator
{
/// <summary>
/// Generates the OpenAPI 3.1.0 document as JSON.
/// </summary>
/// <returns>The OpenAPI document as a JSON string.</returns>
string GenerateDocument();
}

View File

@@ -1,62 +0,0 @@
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Configuration options for OpenAPI document aggregation.
/// </summary>
public sealed class OpenApiAggregationOptions
{
/// <summary>
/// The configuration section name.
/// </summary>
public const string SectionName = "OpenApi";
/// <summary>
/// Gets or sets the API title.
/// </summary>
public string Title { get; set; } = "StellaOps Gateway API";
/// <summary>
/// Gets or sets the API description.
/// </summary>
public string Description { get; set; } = "Unified API aggregating all connected microservices.";
/// <summary>
/// Gets or sets the API version.
/// </summary>
public string Version { get; set; } = "1.0.0";
/// <summary>
/// Gets or sets the server URL.
/// </summary>
public string ServerUrl { get; set; } = "/";
/// <summary>
/// Gets or sets the cache TTL in seconds.
/// </summary>
public int CacheTtlSeconds { get; set; } = 60;
/// <summary>
/// Gets or sets whether OpenAPI aggregation is enabled.
/// </summary>
public bool Enabled { get; set; } = true;
/// <summary>
/// Gets or sets the license name.
/// </summary>
public string LicenseName { get; set; } = "AGPL-3.0-or-later";
/// <summary>
/// Gets or sets the contact name.
/// </summary>
public string? ContactName { get; set; }
/// <summary>
/// Gets or sets the contact email.
/// </summary>
public string? ContactEmail { get; set; }
/// <summary>
/// Gets or sets the OAuth2 token URL for security schemes.
/// </summary>
public string TokenUrl { get; set; } = "/auth/token";
}

View File

@@ -1,285 +0,0 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Models;
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Generates OpenAPI 3.1.0 documents from aggregated microservice schemas.
/// </summary>
internal sealed class OpenApiDocumentGenerator : IOpenApiDocumentGenerator
{
private readonly IGlobalRoutingState _routingState;
private readonly OpenApiAggregationOptions _options;
private static readonly JsonSerializerOptions JsonOptions = new()
{
WriteIndented = true
};
public OpenApiDocumentGenerator(
IGlobalRoutingState routingState,
IOptions<OpenApiAggregationOptions> options)
{
_routingState = routingState;
_options = options.Value;
}
/// <inheritdoc />
public string GenerateDocument()
{
var connections = _routingState.GetAllConnections();
var doc = new JsonObject
{
["openapi"] = "3.1.0",
["info"] = GenerateInfo(),
["servers"] = GenerateServers(),
["paths"] = GeneratePaths(connections),
["components"] = GenerateComponents(connections),
["tags"] = GenerateTags(connections)
};
return doc.ToJsonString(JsonOptions);
}
private JsonObject GenerateInfo()
{
var info = new JsonObject
{
["title"] = _options.Title,
["version"] = _options.Version,
["description"] = _options.Description,
["license"] = new JsonObject
{
["name"] = _options.LicenseName
}
};
if (_options.ContactName is not null || _options.ContactEmail is not null)
{
var contact = new JsonObject();
if (_options.ContactName is not null)
contact["name"] = _options.ContactName;
if (_options.ContactEmail is not null)
contact["email"] = _options.ContactEmail;
info["contact"] = contact;
}
return info;
}
private JsonArray GenerateServers()
{
return new JsonArray
{
new JsonObject
{
["url"] = _options.ServerUrl
}
};
}
private JsonObject GeneratePaths(IReadOnlyList<ConnectionState> connections)
{
var paths = new JsonObject();
// Group endpoints by path
var pathGroups = new Dictionary<string, List<(ConnectionState Conn, EndpointDescriptor Endpoint)>>();
foreach (var conn in connections)
{
foreach (var endpoint in conn.Endpoints.Values)
{
if (!pathGroups.TryGetValue(endpoint.Path, out var list))
{
list = [];
pathGroups[endpoint.Path] = list;
}
list.Add((conn, endpoint));
}
}
// Generate path items
foreach (var (path, endpoints) in pathGroups.OrderBy(p => p.Key))
{
var pathItem = new JsonObject();
foreach (var (conn, endpoint) in endpoints)
{
var operation = GenerateOperation(conn, endpoint);
var method = endpoint.Method.ToLowerInvariant();
pathItem[method] = operation;
}
paths[path] = pathItem;
}
return paths;
}
private JsonObject GenerateOperation(ConnectionState conn, EndpointDescriptor endpoint)
{
var operation = new JsonObject
{
["operationId"] = $"{conn.Instance.ServiceName}_{endpoint.Path.Replace("/", "_").Trim('_')}_{endpoint.Method}",
["tags"] = new JsonArray { conn.Instance.ServiceName }
};
// Add documentation from SchemaInfo
if (endpoint.SchemaInfo is not null)
{
if (endpoint.SchemaInfo.Summary is not null)
operation["summary"] = endpoint.SchemaInfo.Summary;
if (endpoint.SchemaInfo.Description is not null)
operation["description"] = endpoint.SchemaInfo.Description;
if (endpoint.SchemaInfo.Deprecated)
operation["deprecated"] = true;
// Override tags if specified
if (endpoint.SchemaInfo.Tags.Count > 0)
{
var tags = new JsonArray();
foreach (var tag in endpoint.SchemaInfo.Tags)
{
tags.Add(tag);
}
operation["tags"] = tags;
}
}
// Add security requirements
var security = ClaimSecurityMapper.GenerateSecurityRequirement(endpoint);
if (security.Count > 0)
{
operation["security"] = security;
}
// Add request body if schema exists
if (endpoint.SchemaInfo?.RequestSchemaId is not null)
{
var schemaRef = $"#/components/schemas/{conn.Instance.ServiceName}_{endpoint.SchemaInfo.RequestSchemaId}";
operation["requestBody"] = new JsonObject
{
["required"] = true,
["content"] = new JsonObject
{
["application/json"] = new JsonObject
{
["schema"] = new JsonObject
{
["$ref"] = schemaRef
}
}
}
};
}
// Add responses
var responses = new JsonObject();
// Success response
var successResponse = new JsonObject
{
["description"] = "Success"
};
if (endpoint.SchemaInfo?.ResponseSchemaId is not null)
{
var schemaRef = $"#/components/schemas/{conn.Instance.ServiceName}_{endpoint.SchemaInfo.ResponseSchemaId}";
successResponse["content"] = new JsonObject
{
["application/json"] = new JsonObject
{
["schema"] = new JsonObject
{
["$ref"] = schemaRef
}
}
};
}
responses["200"] = successResponse;
// Error responses
responses["400"] = new JsonObject { ["description"] = "Bad Request" };
responses["401"] = new JsonObject { ["description"] = "Unauthorized" };
responses["404"] = new JsonObject { ["description"] = "Not Found" };
responses["422"] = new JsonObject { ["description"] = "Validation Error" };
responses["500"] = new JsonObject { ["description"] = "Internal Server Error" };
operation["responses"] = responses;
return operation;
}
private JsonObject GenerateComponents(IReadOnlyList<ConnectionState> connections)
{
var components = new JsonObject();
// Generate schemas with service prefix
var schemas = new JsonObject();
foreach (var conn in connections)
{
foreach (var (schemaId, schemaDef) in conn.Schemas)
{
var prefixedId = $"{conn.Instance.ServiceName}_{schemaId}";
try
{
var schemaNode = JsonNode.Parse(schemaDef.SchemaJson);
if (schemaNode is not null)
{
schemas[prefixedId] = schemaNode;
}
}
catch (JsonException)
{
// Skip invalid schemas
}
}
}
if (schemas.Count > 0)
{
components["schemas"] = schemas;
}
// Generate security schemes
var allEndpoints = connections.SelectMany(c => c.Endpoints.Values);
var securitySchemes = ClaimSecurityMapper.GenerateSecuritySchemes(allEndpoints, _options.TokenUrl);
if (securitySchemes.Count > 0)
{
components["securitySchemes"] = securitySchemes;
}
return components;
}
private JsonArray GenerateTags(IReadOnlyList<ConnectionState> connections)
{
var tags = new JsonArray();
var seen = new HashSet<string>();
foreach (var conn in connections)
{
var serviceName = conn.Instance.ServiceName;
if (seen.Add(serviceName))
{
var tag = new JsonObject
{
["name"] = serviceName,
["description"] = $"{serviceName} microservice (v{conn.Instance.Version})"
};
if (conn.OpenApiInfo?.Description is not null)
{
tag["description"] = conn.OpenApiInfo.Description;
}
tags.Add(tag);
}
}
return tags;
}
}

View File

@@ -1,124 +0,0 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.AspNetCore.Mvc;
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.NamingConventions;
namespace StellaOps.Gateway.WebService.OpenApi;
/// <summary>
/// Endpoints for serving OpenAPI documentation.
/// </summary>
public static class OpenApiEndpoints
{
private static readonly ISerializer YamlSerializer = new SerializerBuilder()
.WithNamingConvention(CamelCaseNamingConvention.Instance)
.Build();
/// <summary>
/// Maps OpenAPI endpoints to the application.
/// </summary>
public static IEndpointRouteBuilder MapGatewayOpenApiEndpoints(this IEndpointRouteBuilder endpoints)
{
endpoints.MapGet("/.well-known/openapi", GetOpenApiDiscovery)
.ExcludeFromDescription();
endpoints.MapGet("/openapi.json", GetOpenApiJson)
.ExcludeFromDescription();
endpoints.MapGet("/openapi.yaml", GetOpenApiYaml)
.ExcludeFromDescription();
return endpoints;
}
private static IResult GetOpenApiDiscovery(
[FromServices] IGatewayOpenApiDocumentCache cache,
HttpContext context)
{
var (_, etag, generatedAt) = cache.GetDocument();
var discovery = new
{
openapi_json = "/openapi.json",
openapi_yaml = "/openapi.yaml",
etag,
generated_at = generatedAt.ToString("O")
};
context.Response.Headers.CacheControl = "public, max-age=60";
return Results.Ok(discovery);
}
private static IResult GetOpenApiJson(
[FromServices] IGatewayOpenApiDocumentCache cache,
HttpContext context)
{
var (documentJson, etag, _) = cache.GetDocument();
// Check If-None-Match header
if (context.Request.Headers.TryGetValue("If-None-Match", out var ifNoneMatch))
{
if (ifNoneMatch == etag)
{
context.Response.Headers.ETag = etag;
context.Response.Headers.CacheControl = "public, max-age=60";
return Results.StatusCode(304);
}
}
context.Response.Headers.ETag = etag;
context.Response.Headers.CacheControl = "public, max-age=60";
return Results.Content(documentJson, "application/json; charset=utf-8");
}
private static IResult GetOpenApiYaml(
[FromServices] IGatewayOpenApiDocumentCache cache,
HttpContext context)
{
var (documentJson, etag, _) = cache.GetDocument();
// Check If-None-Match header
if (context.Request.Headers.TryGetValue("If-None-Match", out var ifNoneMatch))
{
if (ifNoneMatch == etag)
{
context.Response.Headers.ETag = etag;
context.Response.Headers.CacheControl = "public, max-age=60";
return Results.StatusCode(304);
}
}
// Convert JSON to YAML
var jsonNode = JsonNode.Parse(documentJson);
var yamlContent = ConvertToYaml(jsonNode);
context.Response.Headers.ETag = etag;
context.Response.Headers.CacheControl = "public, max-age=60";
return Results.Content(yamlContent, "application/yaml; charset=utf-8");
}
private static string ConvertToYaml(JsonNode? node)
{
if (node is null)
return string.Empty;
var obj = ConvertJsonNodeToObject(node);
return YamlSerializer.Serialize(obj);
}
private static object? ConvertJsonNodeToObject(JsonNode? node)
{
return node switch
{
null => null,
JsonObject obj => obj.ToDictionary(
kvp => kvp.Key,
kvp => ConvertJsonNodeToObject(kvp.Value)),
JsonArray arr => arr.Select(ConvertJsonNodeToObject).ToList(),
JsonValue val => val.GetValue<object>(),
_ => null
};
}
}

View File

@@ -1,84 +0,0 @@
using System.Collections.Concurrent;
using System.Diagnostics;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Tracks round-trip time for requests to compute average ping latency.
/// </summary>
internal sealed class PingTracker
{
private readonly ConcurrentDictionary<Guid, long> _pendingRequests = new();
private readonly object _lock = new();
private double _averagePingMs;
private const double SmoothingFactor = 0.2;
/// <summary>
/// Gets the exponential moving average of ping times in milliseconds.
/// </summary>
public double AveragePingMs
{
get
{
lock (_lock)
{
return _averagePingMs;
}
}
}
/// <summary>
/// Records that a request has been sent.
/// </summary>
/// <param name="correlationId">The correlation ID of the request.</param>
public void RecordRequestSent(Guid correlationId)
{
_pendingRequests[correlationId] = Stopwatch.GetTimestamp();
}
/// <summary>
/// Records that a response has been received and updates the average ping.
/// </summary>
/// <param name="correlationId">The correlation ID of the request.</param>
/// <returns>The round-trip time in milliseconds, or null if the correlation ID was not found.</returns>
public double? RecordResponseReceived(Guid correlationId)
{
if (!_pendingRequests.TryRemove(correlationId, out var startTicks))
{
return null;
}
var elapsed = Stopwatch.GetElapsedTime(startTicks);
var rtt = elapsed.TotalMilliseconds;
lock (_lock)
{
// Exponential moving average: avg = (1 - alpha) * avg + alpha * new_value
if (_averagePingMs == 0)
{
_averagePingMs = rtt; // First measurement
}
else
{
_averagePingMs = (1 - SmoothingFactor) * _averagePingMs + SmoothingFactor * rtt;
}
}
return rtt;
}
/// <summary>
/// Removes a pending request without recording a response.
/// Call this when a request times out or is cancelled.
/// </summary>
/// <param name="correlationId">The correlation ID of the request.</param>
public void RemovePending(Guid correlationId)
{
_pendingRequests.TryRemove(correlationId, out _);
}
/// <summary>
/// Gets the number of pending requests.
/// </summary>
public int PendingCount => _pendingRequests.Count;
}

View File

@@ -1,20 +0,0 @@
using StellaOps.Gateway.WebService;
var builder = WebApplication.CreateBuilder(args);
// Register gateway routing services
builder.Services.AddGatewayRouting(builder.Configuration);
var app = builder.Build();
// Health check endpoint (not routed through gateway middleware)
app.MapGet("/health", () => Results.Ok(new { status = "healthy" }));
// Gateway router middleware pipeline
// All other requests are routed through the gateway
app.UseGatewayRouter();
app.Run();
// Make Program class accessible for integration tests
public partial class Program { }

View File

@@ -1,22 +0,0 @@
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Well-known HttpContext.Items keys for router pipeline.
/// </summary>
public static class RouterHttpContextKeys
{
/// <summary>
/// Key for the resolved <see cref="StellaOps.Router.Common.Models.EndpointDescriptor"/>.
/// </summary>
public const string EndpointDescriptor = "Stella.EndpointDescriptor";
/// <summary>
/// Key for the <see cref="StellaOps.Router.Common.Models.RoutingDecision"/>.
/// </summary>
public const string RoutingDecision = "Stella.RoutingDecision";
/// <summary>
/// Key for path parameters extracted from route template matching.
/// </summary>
public const string PathParameters = "Stella.PathParameters";
}

View File

@@ -1,67 +0,0 @@
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Tie-breaker mode for routing when multiple instances have equal priority.
/// </summary>
public enum TieBreakerMode
{
/// <summary>
/// Select randomly among tied instances.
/// </summary>
Random,
/// <summary>
/// Rotate through tied instances in order.
/// </summary>
RoundRobin
}
/// <summary>
/// Options for routing behavior.
/// </summary>
public sealed class RoutingOptions
{
/// <summary>
/// Configuration section name for binding.
/// </summary>
public const string SectionName = "Routing";
/// <summary>
/// Gets or sets the default version to use when no version is specified in the request.
/// If null, requests without version specification will match any available version.
/// </summary>
public string? DefaultVersion { get; set; }
/// <summary>
/// Gets or sets whether to enable strict version matching.
/// When true, requests must specify an exact version.
/// When false, requests can match compatible versions.
/// </summary>
public bool StrictVersionMatching { get; set; } = true;
/// <summary>
/// Gets or sets the timeout for routing decisions in milliseconds.
/// </summary>
public int RoutingTimeoutMs { get; set; } = 30000;
/// <summary>
/// Gets or sets whether to prefer local region instances over neighbor regions.
/// </summary>
public bool PreferLocalRegion { get; set; } = true;
/// <summary>
/// Gets or sets whether to allow routing to degraded instances when no healthy instances are available.
/// </summary>
public bool AllowDegradedInstances { get; set; } = true;
/// <summary>
/// Gets or sets the tie-breaker mode when multiple instances have equal priority.
/// </summary>
public TieBreakerMode TieBreaker { get; set; } = TieBreakerMode.Random;
/// <summary>
/// Gets or sets the ping tolerance in milliseconds for considering instances "tied".
/// Instances within this tolerance of each other are considered to have equal latency.
/// </summary>
public double PingToleranceMs { get; set; } = 0.1;
}

View File

@@ -1,89 +0,0 @@
using StellaOps.Gateway.WebService.OpenApi;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Transport.InMemory;
namespace StellaOps.Gateway.WebService;
/// <summary>
/// Extension methods for registering gateway routing services.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds gateway routing services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configuration">The configuration.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddGatewayRouting(
this IServiceCollection services,
IConfiguration configuration)
{
// Bind configuration options
services.Configure<GatewayNodeConfig>(
configuration.GetSection(GatewayNodeConfig.SectionName));
services.Configure<RoutingOptions>(
configuration.GetSection(RoutingOptions.SectionName));
services.Configure<HealthOptions>(
configuration.GetSection(HealthOptions.SectionName));
// Register routing state as singleton (shared across all requests)
services.AddSingleton<IGlobalRoutingState, InMemoryRoutingState>();
// Register routing plugin
services.AddSingleton<IRoutingPlugin, DefaultRoutingPlugin>();
// Register InMemory transport (for development/testing)
services.AddInMemoryTransport();
// Register connection manager as hosted service
services.AddHostedService<ConnectionManager>();
// Register health monitor as hosted service
services.AddHostedService<HealthMonitorService>();
// Register OpenAPI aggregation services
services.Configure<OpenApiAggregationOptions>(
configuration.GetSection(OpenApiAggregationOptions.SectionName));
services.AddSingleton<IOpenApiDocumentGenerator, OpenApiDocumentGenerator>();
services.AddSingleton<IGatewayOpenApiDocumentCache, GatewayOpenApiDocumentCache>();
return services;
}
/// <summary>
/// Adds gateway routing services with custom options.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configureGateway">Action to configure gateway node options.</param>
/// <param name="configureRouting">Action to configure routing options.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddGatewayRouting(
this IServiceCollection services,
Action<GatewayNodeConfig>? configureGateway = null,
Action<RoutingOptions>? configureRouting = null)
{
// Ensure default options are registered even if no configuration action provided
services.AddOptions<GatewayNodeConfig>();
services.AddOptions<RoutingOptions>();
// Configure options via actions
if (configureGateway is not null)
{
services.Configure(configureGateway);
}
if (configureRouting is not null)
{
services.Configure(configureRouting);
}
// Register routing state as singleton (shared across all requests)
services.AddSingleton<IGlobalRoutingState, InMemoryRoutingState>();
// Register routing plugin
services.AddSingleton<IRoutingPlugin, DefaultRoutingPlugin>();
return services;
}
}

View File

@@ -1,20 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<LangVersion>preview</LangVersion>
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<TreatWarningsAsErrors>false</TreatWarningsAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="YamlDotNet" Version="16.2.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\__Libraries\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
<ProjectReference Include="..\..\__Libraries\StellaOps.Router.Config\StellaOps.Router.Config.csproj" />
<ProjectReference Include="..\..\__Libraries\StellaOps.Router.Transport.InMemory\StellaOps.Router.Transport.InMemory.csproj" />
</ItemGroup>
<ItemGroup>
<InternalsVisibleTo Include="StellaOps.Gateway.WebService.Tests" />
</ItemGroup>
</Project>

View File

@@ -1,270 +0,0 @@
using FluentAssertions;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Gateway.WebService.Authorization;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="AuthorityClaimsRefreshService"/>.
/// </summary>
public sealed class AuthorityClaimsRefreshServiceTests
{
private readonly Mock<IAuthorityClaimsProvider> _claimsProviderMock;
private readonly Mock<IEffectiveClaimsStore> _claimsStoreMock;
private readonly AuthorityConnectionOptions _options;
public AuthorityClaimsRefreshServiceTests()
{
_claimsProviderMock = new Mock<IAuthorityClaimsProvider>();
_claimsStoreMock = new Mock<IEffectiveClaimsStore>();
_options = new AuthorityConnectionOptions
{
AuthorityUrl = "http://authority.local",
Enabled = true,
RefreshInterval = TimeSpan.FromMilliseconds(100),
WaitForAuthorityOnStartup = false,
StartupTimeout = TimeSpan.FromSeconds(1)
};
_claimsProviderMock.Setup(p => p.GetOverridesAsync(It.IsAny<CancellationToken>()))
.ReturnsAsync(new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>());
}
private AuthorityClaimsRefreshService CreateService()
{
return new AuthorityClaimsRefreshService(
_claimsProviderMock.Object,
_claimsStoreMock.Object,
Options.Create(_options),
NullLogger<AuthorityClaimsRefreshService>.Instance);
}
#region ExecuteAsync Tests - Disabled
[Fact]
public async Task ExecuteAsync_WhenDisabled_DoesNotFetchClaims()
{
// Arrange
_options.Enabled = false;
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(50);
await service.StopAsync(cts.Token);
// Assert
_claimsProviderMock.Verify(
p => p.GetOverridesAsync(It.IsAny<CancellationToken>()),
Times.Never);
}
[Fact]
public async Task ExecuteAsync_WhenNoAuthorityUrl_DoesNotFetchClaims()
{
// Arrange
_options.AuthorityUrl = string.Empty;
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(50);
await service.StopAsync(cts.Token);
// Assert
_claimsProviderMock.Verify(
p => p.GetOverridesAsync(It.IsAny<CancellationToken>()),
Times.Never);
}
#endregion
#region ExecuteAsync Tests - Enabled
[Fact]
public async Task ExecuteAsync_WhenEnabled_FetchesClaims()
{
// Arrange
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(50);
await cts.CancelAsync();
await service.StopAsync(CancellationToken.None);
// Assert
_claimsProviderMock.Verify(
p => p.GetOverridesAsync(It.IsAny<CancellationToken>()),
Times.AtLeastOnce);
}
[Fact]
public async Task ExecuteAsync_UpdatesStoreWithOverrides()
{
// Arrange
var key = EndpointKey.Create("service", "GET", "/api/test");
var overrides = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>
{
[key] = [new ClaimRequirement { Type = "role", Value = "admin" }]
};
_claimsProviderMock.Setup(p => p.GetOverridesAsync(It.IsAny<CancellationToken>()))
.ReturnsAsync(overrides);
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(50);
await cts.CancelAsync();
await service.StopAsync(CancellationToken.None);
// Assert
_claimsStoreMock.Verify(
s => s.UpdateFromAuthority(It.Is<IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>>(
d => d.ContainsKey(key))),
Times.AtLeastOnce);
}
#endregion
#region ExecuteAsync Tests - Wait for Authority
[Fact]
public async Task ExecuteAsync_WaitForAuthority_FetchesOnStartup()
{
// Arrange
_options.WaitForAuthorityOnStartup = true;
_options.StartupTimeout = TimeSpan.FromMilliseconds(500);
// Authority is immediately available
_claimsProviderMock.Setup(p => p.IsAvailable).Returns(true);
var fetchCalled = false;
_claimsProviderMock.Setup(p => p.GetOverridesAsync(It.IsAny<CancellationToken>()))
.Callback(() => fetchCalled = true)
.ReturnsAsync(new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>());
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(100);
await cts.CancelAsync();
await service.StopAsync(CancellationToken.None);
// Assert - fetch was called during startup
fetchCalled.Should().BeTrue();
}
[Fact]
public async Task ExecuteAsync_WaitForAuthority_StopsAfterTimeout()
{
// Arrange
_options.WaitForAuthorityOnStartup = true;
_options.StartupTimeout = TimeSpan.FromMilliseconds(100);
_claimsProviderMock.Setup(p => p.IsAvailable).Returns(false);
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act - should not block forever
var startTask = service.StartAsync(cts.Token);
await Task.Delay(300);
await cts.CancelAsync();
await service.StopAsync(CancellationToken.None);
// Assert - should complete even if Authority never becomes available
startTask.IsCompleted.Should().BeTrue();
}
#endregion
#region Push Notification Tests
[Fact]
public async Task ExecuteAsync_WithPushNotifications_SubscribesToEvent()
{
// Arrange
_options.UseAuthorityPushNotifications = true;
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(50);
await cts.CancelAsync();
await service.StopAsync(CancellationToken.None);
// Assert - verify event subscription by checking it doesn't throw
_claimsProviderMock.VerifyAdd(
p => p.OverridesChanged += It.IsAny<EventHandler<ClaimsOverrideChangedEventArgs>>(),
Times.Once);
}
[Fact]
public async Task Dispose_WithPushNotifications_UnsubscribesFromEvent()
{
// Arrange
_options.UseAuthorityPushNotifications = true;
var service = CreateService();
using var cts = new CancellationTokenSource();
await service.StartAsync(cts.Token);
await Task.Delay(50);
// Act
await cts.CancelAsync();
service.Dispose();
// Assert
_claimsProviderMock.VerifyRemove(
p => p.OverridesChanged -= It.IsAny<EventHandler<ClaimsOverrideChangedEventArgs>>(),
Times.Once);
}
#endregion
#region Error Handling Tests
[Fact]
public async Task ExecuteAsync_ProviderThrows_ContinuesRefreshLoop()
{
// Arrange
var callCount = 0;
_claimsProviderMock.Setup(p => p.GetOverridesAsync(It.IsAny<CancellationToken>()))
.ReturnsAsync(() =>
{
callCount++;
if (callCount == 1)
{
throw new HttpRequestException("Test error");
}
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
});
var service = CreateService();
using var cts = new CancellationTokenSource();
// Act
await service.StartAsync(cts.Token);
await Task.Delay(250); // Wait for at least 2 refresh cycles
await cts.CancelAsync();
await service.StopAsync(CancellationToken.None);
// Assert - should have continued after error
callCount.Should().BeGreaterThan(1);
}
#endregion
}

View File

@@ -1,336 +0,0 @@
using System.Security.Claims;
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Moq;
using StellaOps.Gateway.WebService.Authorization;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="AuthorizationMiddleware"/>.
/// </summary>
public sealed class AuthorizationMiddlewareTests
{
private readonly Mock<IEffectiveClaimsStore> _claimsStoreMock;
private readonly Mock<RequestDelegate> _nextMock;
private bool _nextCalled;
public AuthorizationMiddlewareTests()
{
_claimsStoreMock = new Mock<IEffectiveClaimsStore>();
_nextMock = new Mock<RequestDelegate>();
_nextMock.Setup(n => n(It.IsAny<HttpContext>()))
.Callback(() => _nextCalled = true)
.Returns(Task.CompletedTask);
}
private AuthorizationMiddleware CreateMiddleware()
{
return new AuthorizationMiddleware(
_nextMock.Object,
_claimsStoreMock.Object,
NullLogger<AuthorizationMiddleware>.Instance);
}
private static HttpContext CreateHttpContext(
EndpointDescriptor? endpoint = null,
ClaimsPrincipal? user = null)
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
if (endpoint is not null)
{
context.Items[RouterHttpContextKeys.EndpointDescriptor] = endpoint;
}
if (user is not null)
{
context.User = user;
}
return context;
}
private static EndpointDescriptor CreateEndpoint(
string serviceName = "test-service",
string method = "GET",
string path = "/api/test",
ClaimRequirement[]? claims = null)
{
return new EndpointDescriptor
{
ServiceName = serviceName,
Version = "1.0.0",
Method = method,
Path = path,
RequiringClaims = claims ?? []
};
}
private static ClaimsPrincipal CreateUserWithClaims(params (string Type, string Value)[] claims)
{
var identity = new ClaimsIdentity(
claims.Select(c => new Claim(c.Type, c.Value)),
"TestAuth");
return new ClaimsPrincipal(identity);
}
#region No Endpoint Tests
[Fact]
public async Task InvokeAsync_WithNoEndpoint_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(endpoint: null);
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeTrue();
}
#endregion
#region Empty Claims Tests
[Fact]
public async Task InvokeAsync_WithEmptyRequiringClaims_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var context = CreateHttpContext(endpoint: endpoint);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>());
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeTrue();
context.Response.StatusCode.Should().Be(StatusCodes.Status200OK);
}
#endregion
#region Matching Claims Tests
[Fact]
public async Task InvokeAsync_WithMatchingClaims_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(("role", "admin"));
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeTrue();
context.Response.StatusCode.Should().Be(StatusCodes.Status200OK);
}
[Fact]
public async Task InvokeAsync_WithClaimTypeOnly_MatchesAnyValue()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(("role", "any-value"));
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = null } // Any value matches
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeTrue();
}
[Fact]
public async Task InvokeAsync_WithMultipleMatchingClaims_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(
("role", "admin"),
("department", "engineering"),
("level", "senior"));
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" },
new() { Type = "department", Value = "engineering" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeTrue();
}
#endregion
#region Missing Claims Tests
[Fact]
public async Task InvokeAsync_WithMissingClaim_Returns403()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(("role", "user")); // Has role, but wrong value
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status403Forbidden);
}
[Fact]
public async Task InvokeAsync_WithMissingClaimType_Returns403()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(("department", "engineering"));
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status403Forbidden);
}
[Fact]
public async Task InvokeAsync_WithNoClaims_Returns403()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(); // No claims at all
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status403Forbidden);
}
[Fact]
public async Task InvokeAsync_WithPartialMatchingClaims_Returns403()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims(("role", "admin")); // Has one, missing another
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" },
new() { Type = "department", Value = "engineering" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status403Forbidden);
}
#endregion
#region Response Body Tests
[Fact]
public async Task InvokeAsync_WithMissingClaim_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var user = CreateUserWithClaims();
var context = CreateHttpContext(endpoint: endpoint, user: user);
_claimsStoreMock.Setup(s => s.GetEffectiveClaims(
endpoint.ServiceName, endpoint.Method, endpoint.Path))
.Returns(new List<ClaimRequirement>
{
new() { Type = "role", Value = "admin" }
});
// Act
await middleware.InvokeAsync(context);
// Assert
context.Response.ContentType.Should().StartWith("application/json");
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Forbidden");
responseBody.Should().Contain("role");
}
#endregion
}

View File

@@ -1,222 +0,0 @@
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using StellaOps.Microservice;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.InMemory;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
public class CancellationTests
{
private readonly InMemoryConnectionRegistry _registry = new();
private readonly InMemoryTransportOptions _options = new() { SimulatedLatency = TimeSpan.Zero };
private InMemoryTransportClient CreateClient()
{
return new InMemoryTransportClient(
_registry,
Options.Create(_options),
NullLogger<InMemoryTransportClient>.Instance);
}
[Fact]
public void CancelReasons_HasAllExpectedConstants()
{
Assert.Equal("ClientDisconnected", CancelReasons.ClientDisconnected);
Assert.Equal("Timeout", CancelReasons.Timeout);
Assert.Equal("PayloadLimitExceeded", CancelReasons.PayloadLimitExceeded);
Assert.Equal("Shutdown", CancelReasons.Shutdown);
Assert.Equal("ConnectionClosed", CancelReasons.ConnectionClosed);
}
[Fact]
public async Task ConnectAsync_RegistersWithRegistry()
{
// Arrange
using var client = CreateClient();
var instance = new InstanceDescriptor
{
InstanceId = "test-instance",
ServiceName = "test-service",
Version = "1.0.0",
Region = "us-east-1"
};
// Act
await client.ConnectAsync(instance, [], CancellationToken.None);
// Assert
var connectionIdField = client.GetType()
.GetField("_connectionId", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var connectionId = connectionIdField?.GetValue(client)?.ToString();
Assert.NotNull(connectionId);
var channel = _registry.GetChannel(connectionId!);
Assert.NotNull(channel);
Assert.Equal(instance.InstanceId, channel!.Instance?.InstanceId);
}
[Fact]
public void CancelAllInflight_DoesNotThrowWhenEmpty()
{
// Arrange
using var client = CreateClient();
// Act & Assert - should not throw
client.CancelAllInflight(CancelReasons.Shutdown);
}
[Fact]
public void Dispose_DoesNotThrow()
{
// Arrange
var client = CreateClient();
// Act & Assert - should not throw
client.Dispose();
}
[Fact]
public async Task DisconnectAsync_CancelsAllInflightWithShutdownReason()
{
// Arrange
using var client = CreateClient();
var instance = new InstanceDescriptor
{
InstanceId = "test-instance",
ServiceName = "test-service",
Version = "1.0.0",
Region = "us-east-1"
};
await client.ConnectAsync(instance, [], CancellationToken.None);
// Act
await client.DisconnectAsync();
// Assert - no exception means success
}
}
public class InflightRequestTrackerTests
{
[Fact]
public void Track_ReturnsCancellationToken()
{
// Arrange
using var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var correlationId = Guid.NewGuid();
// Act
var token = tracker.Track(correlationId);
// Assert
Assert.False(token.IsCancellationRequested);
Assert.Equal(1, tracker.Count);
}
[Fact]
public void Track_ThrowsIfAlreadyTracked()
{
// Arrange
using var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var correlationId = Guid.NewGuid();
tracker.Track(correlationId);
// Act & Assert
Assert.Throws<InvalidOperationException>(() => tracker.Track(correlationId));
}
[Fact]
public void Cancel_TriggersCancellationToken()
{
// Arrange
using var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var correlationId = Guid.NewGuid();
var token = tracker.Track(correlationId);
// Act
var result = tracker.Cancel(correlationId, "TestReason");
// Assert
Assert.True(result);
Assert.True(token.IsCancellationRequested);
}
[Fact]
public void Cancel_ReturnsFalseForUnknownRequest()
{
// Arrange
using var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var correlationId = Guid.NewGuid();
// Act
var result = tracker.Cancel(correlationId, "TestReason");
// Assert
Assert.False(result);
}
[Fact]
public void Complete_RemovesFromTracking()
{
// Arrange
using var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var correlationId = Guid.NewGuid();
tracker.Track(correlationId);
Assert.Equal(1, tracker.Count);
// Act
tracker.Complete(correlationId);
// Assert
Assert.Equal(0, tracker.Count);
}
[Fact]
public void CancelAll_CancelsAllTrackedRequests()
{
// Arrange
using var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var tokens = new List<CancellationToken>();
for (var i = 0; i < 5; i++)
{
tokens.Add(tracker.Track(Guid.NewGuid()));
}
// Act
tracker.CancelAll("TestReason");
// Assert
Assert.All(tokens, t => Assert.True(t.IsCancellationRequested));
}
[Fact]
public void Dispose_CancelsAllTrackedRequests()
{
// Arrange
var tracker = new InflightRequestTracker(
NullLogger<InflightRequestTracker>.Instance);
var tokens = new List<CancellationToken>();
for (var i = 0; i < 3; i++)
{
tokens.Add(tracker.Track(Guid.NewGuid()));
}
// Act
tracker.Dispose();
// Assert
Assert.All(tokens, t => Assert.True(t.IsCancellationRequested));
}
}

View File

@@ -1,213 +0,0 @@
using FluentAssertions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.InMemory;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Integration-style tests for <see cref="ConnectionManager"/>.
/// Uses real InMemoryTransportServer since it's a sealed class.
/// </summary>
public sealed class ConnectionManagerTests : IAsyncLifetime
{
private readonly InMemoryConnectionRegistry _connectionRegistry;
private readonly InMemoryTransportServer _transportServer;
private readonly Mock<IGlobalRoutingState> _routingStateMock;
private readonly ConnectionManager _manager;
public ConnectionManagerTests()
{
_connectionRegistry = new InMemoryConnectionRegistry();
var options = Options.Create(new InMemoryTransportOptions());
_transportServer = new InMemoryTransportServer(
_connectionRegistry,
options,
NullLogger<InMemoryTransportServer>.Instance);
_routingStateMock = new Mock<IGlobalRoutingState>(MockBehavior.Loose);
_manager = new ConnectionManager(
_transportServer,
_connectionRegistry,
_routingStateMock.Object,
NullLogger<ConnectionManager>.Instance);
}
public async Task InitializeAsync()
{
await _manager.StartAsync(CancellationToken.None);
}
public async Task DisposeAsync()
{
await _manager.StopAsync(CancellationToken.None);
_transportServer.Dispose();
}
#region StartAsync/StopAsync Tests
[Fact]
public async Task StartAsync_ShouldStartSuccessfully()
{
// The manager starts in InitializeAsync
// Just verify it can be started without exception
await Task.CompletedTask;
}
[Fact]
public async Task StopAsync_ShouldStopSuccessfully()
{
// This is tested in DisposeAsync
await Task.CompletedTask;
}
#endregion
#region Connection Registration Tests via Channel Simulation
[Fact]
public async Task WhenHelloReceived_AddsConnectionToRoutingState()
{
// Arrange
var channel = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
// Simulate sending a HELLO frame through the channel
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString()
};
// Act
await channel.ToGateway.Writer.WriteAsync(helloFrame);
// Give time for the frame to be processed
await Task.Delay(100);
// Assert
_routingStateMock.Verify(
s => s.AddConnection(It.Is<ConnectionState>(c => c.ConnectionId == "conn-1")),
Times.Once);
}
[Fact]
public async Task WhenHeartbeatReceived_UpdatesConnectionState()
{
// Arrange
var channel = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
// First send HELLO to register the connection
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString()
};
await channel.ToGateway.Writer.WriteAsync(helloFrame);
await Task.Delay(100);
// Act - send heartbeat
var heartbeatFrame = new Frame
{
Type = FrameType.Heartbeat,
CorrelationId = Guid.NewGuid().ToString()
};
await channel.ToGateway.Writer.WriteAsync(heartbeatFrame);
await Task.Delay(100);
// Assert
_routingStateMock.Verify(
s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()),
Times.AtLeastOnce);
}
[Fact]
public async Task WhenConnectionClosed_RemovesConnectionFromRoutingState()
{
// Arrange
var channel = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
// First send HELLO to register the connection
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString()
};
await channel.ToGateway.Writer.WriteAsync(helloFrame);
await Task.Delay(100);
// Act - close the channel
await channel.LifetimeToken.CancelAsync();
// Give time for the close to be processed
await Task.Delay(200);
// Assert - may be called multiple times (on close and on stop)
_routingStateMock.Verify(
s => s.RemoveConnection("conn-1"),
Times.AtLeastOnce);
}
[Fact]
public async Task WhenMultipleConnectionsRegister_AllAreTracked()
{
// Arrange
var channel1 = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
var channel2 = CreateAndRegisterChannel("conn-2", "service-b", "2.0.0");
// Act - send HELLO frames
await channel1.ToGateway.Writer.WriteAsync(new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString()
});
await channel2.ToGateway.Writer.WriteAsync(new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString()
});
await Task.Delay(150);
// Assert
_routingStateMock.Verify(
s => s.AddConnection(It.Is<ConnectionState>(c => c.ConnectionId == "conn-1")),
Times.Once);
_routingStateMock.Verify(
s => s.AddConnection(It.Is<ConnectionState>(c => c.ConnectionId == "conn-2")),
Times.Once);
}
#endregion
#region Helper Methods
private InMemoryChannel CreateAndRegisterChannel(
string connectionId, string serviceName, string version)
{
var instance = new InstanceDescriptor
{
InstanceId = $"{serviceName}-{Guid.NewGuid():N}",
ServiceName = serviceName,
Version = version,
Region = "us-east-1"
};
// Create channel through the registry
var channel = _connectionRegistry.CreateChannel(connectionId);
channel.Instance = instance;
// Simulate that the transport server is listening to this connection
_transportServer.StartListeningToConnection(connectionId);
return channel;
}
#endregion
}

View File

@@ -1,538 +0,0 @@
using FluentAssertions;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
public class DefaultRoutingPluginTests
{
private readonly RoutingOptions _options = new()
{
DefaultVersion = null,
StrictVersionMatching = true,
RoutingTimeoutMs = 30000,
PreferLocalRegion = true,
AllowDegradedInstances = true,
TieBreaker = TieBreakerMode.Random,
PingToleranceMs = 0.1
};
private readonly GatewayNodeConfig _gatewayConfig = new()
{
Region = "us-east-1",
NodeId = "gw-test-01",
Environment = "test",
NeighborRegions = ["eu-west-1", "us-west-2"]
};
private DefaultRoutingPlugin CreateSut(
Action<RoutingOptions>? configureOptions = null,
Action<GatewayNodeConfig>? configureGateway = null)
{
configureOptions?.Invoke(_options);
configureGateway?.Invoke(_gatewayConfig);
return new DefaultRoutingPlugin(
Options.Create(_options),
Options.Create(_gatewayConfig));
}
private static ConnectionState CreateConnection(
string connectionId = "conn-1",
string serviceName = "test-service",
string version = "1.0.0",
string region = "us-east-1",
InstanceHealthStatus status = InstanceHealthStatus.Healthy,
double averagePingMs = 0,
DateTime? lastHeartbeatUtc = null)
{
return new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = $"inst-{connectionId}",
ServiceName = serviceName,
Version = version,
Region = region
},
Status = status,
TransportType = TransportType.InMemory,
AveragePingMs = averagePingMs,
LastHeartbeatUtc = lastHeartbeatUtc ?? DateTime.UtcNow
};
}
private static EndpointDescriptor CreateEndpoint(
string method = "GET",
string path = "/api/test",
string serviceName = "test-service",
string version = "1.0.0")
{
return new EndpointDescriptor
{
Method = method,
Path = path,
ServiceName = serviceName,
Version = version
};
}
private static RoutingContext CreateContext(
string method = "GET",
string path = "/api/test",
string gatewayRegion = "us-east-1",
string? requestedVersion = null,
EndpointDescriptor? endpoint = null,
params ConnectionState[] connections)
{
return new RoutingContext
{
Method = method,
Path = path,
GatewayRegion = gatewayRegion,
RequestedVersion = requestedVersion,
Endpoint = endpoint ?? CreateEndpoint(),
AvailableConnections = connections
};
}
[Fact]
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenNoConnections()
{
// Arrange
var sut = CreateSut();
var context = CreateContext();
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().BeNull();
}
[Fact]
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenNoEndpoint()
{
// Arrange
var sut = CreateSut();
var connection = CreateConnection();
var context = new RoutingContext
{
Method = "GET",
Path = "/api/test",
GatewayRegion = "us-east-1",
Endpoint = null,
AvailableConnections = [connection]
};
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().BeNull();
}
[Fact]
public async Task ChooseInstanceAsync_ShouldSelectHealthyConnection()
{
// Arrange
var sut = CreateSut();
var connection = CreateConnection(status: InstanceHealthStatus.Healthy);
var context = CreateContext(connections: [connection]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Should().BeSameAs(connection);
}
[Fact]
public async Task ChooseInstanceAsync_ShouldPreferHealthyOverDegraded()
{
// Arrange
var sut = CreateSut();
var degraded = CreateConnection("conn-1", status: InstanceHealthStatus.Degraded);
var healthy = CreateConnection("conn-2", status: InstanceHealthStatus.Healthy);
var context = CreateContext(connections: [degraded, healthy]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Status.Should().Be(InstanceHealthStatus.Healthy);
}
[Fact]
public async Task ChooseInstanceAsync_ShouldSelectDegraded_WhenNoHealthyAndAllowed()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.AllowDegradedInstances = true);
var degraded = CreateConnection(status: InstanceHealthStatus.Degraded);
var context = CreateContext(connections: [degraded]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Status.Should().Be(InstanceHealthStatus.Degraded);
}
[Fact]
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenOnlyDegradedAndNotAllowed()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.AllowDegradedInstances = false);
var degraded = CreateConnection(status: InstanceHealthStatus.Degraded);
var context = CreateContext(connections: [degraded]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().BeNull();
}
[Fact]
public async Task ChooseInstanceAsync_ShouldExcludeUnhealthy()
{
// Arrange
var sut = CreateSut();
var unhealthy = CreateConnection("conn-1", status: InstanceHealthStatus.Unhealthy);
var healthy = CreateConnection("conn-2", status: InstanceHealthStatus.Healthy);
var context = CreateContext(connections: [unhealthy, healthy]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.ConnectionId.Should().Be("conn-2");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldExcludeDraining()
{
// Arrange
var sut = CreateSut();
var draining = CreateConnection("conn-1", status: InstanceHealthStatus.Draining);
var healthy = CreateConnection("conn-2", status: InstanceHealthStatus.Healthy);
var context = CreateContext(connections: [draining, healthy]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.ConnectionId.Should().Be("conn-2");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldFilterByRequestedVersion()
{
// Arrange
var sut = CreateSut();
var v1 = CreateConnection("conn-1", version: "1.0.0");
var v2 = CreateConnection("conn-2", version: "2.0.0");
var context = CreateContext(requestedVersion: "2.0.0", connections: [v1, v2]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Instance.Version.Should().Be("2.0.0");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldUseDefaultVersion_WhenNoRequestedVersion()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.DefaultVersion = "1.0.0");
var v1 = CreateConnection("conn-1", version: "1.0.0");
var v2 = CreateConnection("conn-2", version: "2.0.0");
var context = CreateContext(requestedVersion: null, connections: [v1, v2]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Instance.Version.Should().Be("1.0.0");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenNoMatchingVersion()
{
// Arrange
var sut = CreateSut();
var v1 = CreateConnection("conn-1", version: "1.0.0");
var context = CreateContext(requestedVersion: "2.0.0", connections: [v1]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().BeNull();
}
[Fact]
public async Task ChooseInstanceAsync_ShouldMatchAnyVersion_WhenNoVersionSpecified()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.DefaultVersion = null);
var v1 = CreateConnection("conn-1", version: "1.0.0");
var v2 = CreateConnection("conn-2", version: "2.0.0");
var context = CreateContext(requestedVersion: null, connections: [v1, v2]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
}
[Fact]
public async Task ChooseInstanceAsync_ShouldPreferLocalRegion()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.PreferLocalRegion = true);
var remote = CreateConnection("conn-1", region: "us-west-2");
var local = CreateConnection("conn-2", region: "us-east-1");
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remote, local]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Instance.Region.Should().Be("us-east-1");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldAllowRemoteRegion_WhenNoLocalAvailable()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.PreferLocalRegion = true);
var remote = CreateConnection("conn-1", region: "us-west-2");
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remote]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Connection.Instance.Region.Should().Be("us-west-2");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldIgnoreRegionPreference_WhenDisabled()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.PreferLocalRegion = false);
// Create connections with same ping and heartbeat so they are tied
var sameHeartbeat = DateTime.UtcNow;
var remote = CreateConnection("conn-1", region: "us-west-2", lastHeartbeatUtc: sameHeartbeat);
var local = CreateConnection("conn-2", region: "us-east-1", lastHeartbeatUtc: sameHeartbeat);
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remote, local]);
// Act - run multiple times to verify random selection includes both
var selectedRegions = new HashSet<string>();
for (int i = 0; i < 50; i++)
{
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
selectedRegions.Add(result!.Connection.Instance.Region);
}
// Assert - with random selection, we should see both regions selected
// Note: This is probabilistic but should almost always pass
selectedRegions.Should().Contain("us-west-2");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldSetCorrectTimeout()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.RoutingTimeoutMs = 5000);
var connection = CreateConnection();
var context = CreateContext(connections: [connection]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.EffectiveTimeout.Should().Be(TimeSpan.FromMilliseconds(5000));
}
[Fact]
public async Task ChooseInstanceAsync_ShouldSetCorrectTransportType()
{
// Arrange
var sut = CreateSut();
var connection = CreateConnection();
var context = CreateContext(connections: [connection]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.TransportType.Should().Be(TransportType.InMemory);
}
[Fact]
public async Task ChooseInstanceAsync_ShouldReturnEndpointFromContext()
{
// Arrange
var sut = CreateSut();
var endpoint = CreateEndpoint(path: "/api/special");
var connection = CreateConnection();
var context = CreateContext(endpoint: endpoint, connections: [connection]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert
result.Should().NotBeNull();
result!.Endpoint.Path.Should().Be("/api/special");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldDistributeLoadAcrossMultipleConnections()
{
// Arrange
var sut = CreateSut();
// Create connections with same ping and heartbeat so they are tied
var sameHeartbeat = DateTime.UtcNow;
var conn1 = CreateConnection("conn-1", lastHeartbeatUtc: sameHeartbeat);
var conn2 = CreateConnection("conn-2", lastHeartbeatUtc: sameHeartbeat);
var conn3 = CreateConnection("conn-3", lastHeartbeatUtc: sameHeartbeat);
var context = CreateContext(connections: [conn1, conn2, conn3]);
// Act - run multiple times
var selectedConnections = new Dictionary<string, int>();
for (int i = 0; i < 100; i++)
{
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
var connId = result!.Connection.ConnectionId;
selectedConnections[connId] = selectedConnections.GetValueOrDefault(connId) + 1;
}
// Assert - all connections should be selected at least once (probabilistic with random tie-breaker)
selectedConnections.Should().HaveCount(3);
selectedConnections.Keys.Should().Contain("conn-1");
selectedConnections.Keys.Should().Contain("conn-2");
selectedConnections.Keys.Should().Contain("conn-3");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldPreferLowerPing()
{
// Arrange
var sut = CreateSut();
var sameHeartbeat = DateTime.UtcNow;
var highPing = CreateConnection("conn-1", averagePingMs: 100, lastHeartbeatUtc: sameHeartbeat);
var lowPing = CreateConnection("conn-2", averagePingMs: 10, lastHeartbeatUtc: sameHeartbeat);
var context = CreateContext(connections: [highPing, lowPing]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert - lower ping should be preferred
result.Should().NotBeNull();
result!.Connection.ConnectionId.Should().Be("conn-2");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldPreferMoreRecentHeartbeat_WhenPingEqual()
{
// Arrange
var sut = CreateSut();
var now = DateTime.UtcNow;
var oldHeartbeat = CreateConnection("conn-1", averagePingMs: 10, lastHeartbeatUtc: now.AddSeconds(-30));
var recentHeartbeat = CreateConnection("conn-2", averagePingMs: 10, lastHeartbeatUtc: now);
var context = CreateContext(connections: [oldHeartbeat, recentHeartbeat]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert - more recent heartbeat should be preferred
result.Should().NotBeNull();
result!.Connection.ConnectionId.Should().Be("conn-2");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldPreferNeighborRegionOverRemote()
{
// Arrange - gateway config has NeighborRegions = ["eu-west-1", "us-west-2"]
var sut = CreateSut();
var sameHeartbeat = DateTime.UtcNow;
var remoteRegion = CreateConnection("conn-1", region: "ap-south-1", lastHeartbeatUtc: sameHeartbeat);
var neighborRegion = CreateConnection("conn-2", region: "eu-west-1", lastHeartbeatUtc: sameHeartbeat);
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remoteRegion, neighborRegion]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert - neighbor region should be preferred over remote
result.Should().NotBeNull();
result!.Connection.Instance.Region.Should().Be("eu-west-1");
}
[Fact]
public async Task ChooseInstanceAsync_ShouldUseRoundRobin_WhenConfigured()
{
// Arrange
var sut = CreateSut(configureOptions: o => o.TieBreaker = TieBreakerMode.RoundRobin);
var sameHeartbeat = DateTime.UtcNow;
var conn1 = CreateConnection("conn-1", lastHeartbeatUtc: sameHeartbeat);
var conn2 = CreateConnection("conn-2", lastHeartbeatUtc: sameHeartbeat);
var context = CreateContext(connections: [conn1, conn2]);
// Act - with round-robin, we should cycle through connections
var selections = new List<string>();
for (int i = 0; i < 4; i++)
{
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
selections.Add(result!.Connection.ConnectionId);
}
// Assert - should alternate between connections
selections.Distinct().Count().Should().Be(2);
}
[Fact]
public async Task ChooseInstanceAsync_ShouldCombineFilters()
{
// Arrange
var sut = CreateSut(configureOptions: o =>
{
o.PreferLocalRegion = true;
o.AllowDegradedInstances = false;
});
// Create various combinations
var wrongVersionHealthyLocal = CreateConnection("conn-1", version: "2.0.0", region: "us-east-1", status: InstanceHealthStatus.Healthy);
var rightVersionDegradedLocal = CreateConnection("conn-2", version: "1.0.0", region: "us-east-1", status: InstanceHealthStatus.Degraded);
var rightVersionHealthyRemote = CreateConnection("conn-3", version: "1.0.0", region: "us-west-2", status: InstanceHealthStatus.Healthy);
var rightVersionHealthyLocal = CreateConnection("conn-4", version: "1.0.0", region: "us-east-1", status: InstanceHealthStatus.Healthy);
var context = CreateContext(
gatewayRegion: "us-east-1",
requestedVersion: "1.0.0",
connections: [wrongVersionHealthyLocal, rightVersionDegradedLocal, rightVersionHealthyRemote, rightVersionHealthyLocal]);
// Act
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
// Assert - should select the only connection matching all criteria
result.Should().NotBeNull();
result!.Connection.ConnectionId.Should().Be("conn-4");
}
}

View File

@@ -1,404 +0,0 @@
using FluentAssertions;
using Microsoft.Extensions.Logging.Abstractions;
using StellaOps.Gateway.WebService.Authorization;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="EffectiveClaimsStore"/>.
/// </summary>
public sealed class EffectiveClaimsStoreTests
{
private readonly EffectiveClaimsStore _store;
public EffectiveClaimsStoreTests()
{
_store = new EffectiveClaimsStore(NullLogger<EffectiveClaimsStore>.Instance);
}
#region GetEffectiveClaims Tests
[Fact]
public void GetEffectiveClaims_NoClaimsRegistered_ReturnsEmptyList()
{
// Arrange - fresh store
// Act
var claims = _store.GetEffectiveClaims("service", "GET", "/api/test");
// Assert
claims.Should().BeEmpty();
}
[Fact]
public void GetEffectiveClaims_MicroserviceClaimsOnly_ReturnsMicroserviceClaims()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "admin" }]
}
};
_store.UpdateFromMicroservice("test-service", endpoints);
// Act
var claims = _store.GetEffectiveClaims("test-service", "GET", "/api/users");
// Assert
claims.Should().HaveCount(1);
claims[0].Type.Should().Be("role");
claims[0].Value.Should().Be("admin");
}
[Fact]
public void GetEffectiveClaims_AuthorityOverridesTakePrecedence()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "user" }]
}
};
_store.UpdateFromMicroservice("test-service", endpoints);
var key = EndpointKey.Create("test-service", "GET", "/api/users");
var overrides = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>
{
[key] = [new ClaimRequirement { Type = "role", Value = "admin" }]
};
_store.UpdateFromAuthority(overrides);
// Act
var claims = _store.GetEffectiveClaims("test-service", "GET", "/api/users");
// Assert
claims.Should().HaveCount(1);
claims[0].Value.Should().Be("admin");
}
[Fact]
public void GetEffectiveClaims_MethodNormalization_MatchesCaseInsensitively()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "get",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "admin" }]
}
};
_store.UpdateFromMicroservice("test-service", endpoints);
// Act
var claims = _store.GetEffectiveClaims("test-service", "GET", "/api/users");
// Assert
claims.Should().HaveCount(1);
}
[Fact]
public void GetEffectiveClaims_PathNormalization_MatchesCaseInsensitively()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/API/USERS",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "admin" }]
}
};
_store.UpdateFromMicroservice("test-service", endpoints);
// Act
var claims = _store.GetEffectiveClaims("test-service", "GET", "/api/users");
// Assert
claims.Should().HaveCount(1);
}
#endregion
#region UpdateFromMicroservice Tests
[Fact]
public void UpdateFromMicroservice_MultipleEndpoints_RegistersAll()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "reader" }]
},
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "POST",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "writer" }]
}
};
// Act
_store.UpdateFromMicroservice("test-service", endpoints);
// Assert
_store.GetEffectiveClaims("test-service", "GET", "/api/users")[0].Value.Should().Be("reader");
_store.GetEffectiveClaims("test-service", "POST", "/api/users")[0].Value.Should().Be("writer");
}
[Fact]
public void UpdateFromMicroservice_EmptyClaims_RemovesFromStore()
{
// Arrange - first add some claims
var endpoints1 = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "admin" }]
}
};
_store.UpdateFromMicroservice("test-service", endpoints1);
// Now update with empty claims
var endpoints2 = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = []
}
};
// Act
_store.UpdateFromMicroservice("test-service", endpoints2);
// Assert
_store.GetEffectiveClaims("test-service", "GET", "/api/users").Should().BeEmpty();
}
[Fact]
public void UpdateFromMicroservice_DefaultEmptyClaims_TreatedAsEmpty()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users"
// RequiringClaims defaults to []
}
};
// Act
_store.UpdateFromMicroservice("test-service", endpoints);
// Assert
_store.GetEffectiveClaims("test-service", "GET", "/api/users").Should().BeEmpty();
}
#endregion
#region UpdateFromAuthority Tests
[Fact]
public void UpdateFromAuthority_ClearsPreviousOverrides()
{
// Arrange - add initial override
var key1 = EndpointKey.Create("service1", "GET", "/api/test1");
var overrides1 = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>
{
[key1] = [new ClaimRequirement { Type = "role", Value = "old" }]
};
_store.UpdateFromAuthority(overrides1);
// Update with new overrides (different key)
var key2 = EndpointKey.Create("service2", "POST", "/api/test2");
var overrides2 = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>
{
[key2] = [new ClaimRequirement { Type = "role", Value = "new" }]
};
// Act
_store.UpdateFromAuthority(overrides2);
// Assert
_store.GetEffectiveClaims("service1", "GET", "/api/test1").Should().BeEmpty();
_store.GetEffectiveClaims("service2", "POST", "/api/test2").Should().HaveCount(1);
}
[Fact]
public void UpdateFromAuthority_EmptyClaimsNotStored()
{
// Arrange
var key = EndpointKey.Create("service", "GET", "/api/test");
var overrides = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>
{
[key] = []
};
// Act
_store.UpdateFromAuthority(overrides);
// Assert - should fall back to microservice (which is empty)
_store.GetEffectiveClaims("service", "GET", "/api/test").Should().BeEmpty();
}
[Fact]
public void UpdateFromAuthority_MultipleOverrides()
{
// Arrange
var key1 = EndpointKey.Create("service1", "GET", "/api/users");
var key2 = EndpointKey.Create("service1", "POST", "/api/users");
var overrides = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>
{
[key1] = [new ClaimRequirement { Type = "role", Value = "reader" }],
[key2] = [new ClaimRequirement { Type = "role", Value = "writer" }]
};
// Act
_store.UpdateFromAuthority(overrides);
// Assert
_store.GetEffectiveClaims("service1", "GET", "/api/users")[0].Value.Should().Be("reader");
_store.GetEffectiveClaims("service1", "POST", "/api/users")[0].Value.Should().Be("writer");
}
#endregion
#region RemoveService Tests
[Fact]
public void RemoveService_RemovesMicroserviceClaims()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "test-service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "admin" }]
}
};
_store.UpdateFromMicroservice("test-service", endpoints);
// Act
_store.RemoveService("test-service");
// Assert
_store.GetEffectiveClaims("test-service", "GET", "/api/users").Should().BeEmpty();
}
[Fact]
public void RemoveService_CaseInsensitive()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
ServiceName = "Test-Service",
Version = "1.0.0",
Method = "GET",
Path = "/api/users",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "admin" }]
}
};
_store.UpdateFromMicroservice("Test-Service", endpoints);
// Act - remove with different case
_store.RemoveService("TEST-SERVICE");
// Assert
_store.GetEffectiveClaims("test-service", "GET", "/api/users").Should().BeEmpty();
}
[Fact]
public void RemoveService_OnlyRemovesTargetService()
{
// Arrange
var endpoints1 = new[]
{
new EndpointDescriptor
{
ServiceName = "service-a",
Version = "1.0.0",
Method = "GET",
Path = "/api/a",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "a" }]
}
};
var endpoints2 = new[]
{
new EndpointDescriptor
{
ServiceName = "service-b",
Version = "1.0.0",
Method = "GET",
Path = "/api/b",
RequiringClaims = [new ClaimRequirement { Type = "role", Value = "b" }]
}
};
_store.UpdateFromMicroservice("service-a", endpoints1);
_store.UpdateFromMicroservice("service-b", endpoints2);
// Act
_store.RemoveService("service-a");
// Assert
_store.GetEffectiveClaims("service-a", "GET", "/api/a").Should().BeEmpty();
_store.GetEffectiveClaims("service-b", "GET", "/api/b").Should().HaveCount(1);
}
[Fact]
public void RemoveService_UnknownService_DoesNotThrow()
{
// Arrange & Act
var action = () => _store.RemoveService("unknown-service");
// Assert
action.Should().NotThrow();
}
#endregion
}

View File

@@ -1,287 +0,0 @@
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Moq;
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="EndpointResolutionMiddleware"/>.
/// </summary>
public sealed class EndpointResolutionMiddlewareTests
{
private readonly Mock<IGlobalRoutingState> _routingStateMock;
private readonly Mock<RequestDelegate> _nextMock;
private bool _nextCalled;
public EndpointResolutionMiddlewareTests()
{
_routingStateMock = new Mock<IGlobalRoutingState>();
_nextMock = new Mock<RequestDelegate>();
_nextMock.Setup(n => n(It.IsAny<HttpContext>()))
.Callback(() => _nextCalled = true)
.Returns(Task.CompletedTask);
}
private EndpointResolutionMiddleware CreateMiddleware()
{
return new EndpointResolutionMiddleware(_nextMock.Object);
}
private static HttpContext CreateHttpContext(string method = "GET", string path = "/api/test")
{
var context = new DefaultHttpContext();
context.Request.Method = method;
context.Request.Path = path;
context.Response.Body = new MemoryStream();
return context;
}
private static EndpointDescriptor CreateEndpoint(
string serviceName = "test-service",
string method = "GET",
string path = "/api/test")
{
return new EndpointDescriptor
{
ServiceName = serviceName,
Version = "1.0.0",
Method = method,
Path = path
};
}
#region Matching Endpoint Tests
[Fact]
public async Task Invoke_WithMatchingEndpoint_SetsHttpContextItem()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var context = CreateHttpContext();
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/test"))
.Returns(endpoint);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeTrue();
context.Items[RouterHttpContextKeys.EndpointDescriptor].Should().Be(endpoint);
}
[Fact]
public async Task Invoke_WithMatchingEndpoint_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var context = CreateHttpContext();
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/test"))
.Returns(endpoint);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
#endregion
#region Unknown Path Tests
[Fact]
public async Task Invoke_WithUnknownPath_Returns404()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(path: "/api/unknown");
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/unknown"))
.Returns((EndpointDescriptor?)null);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status404NotFound);
}
[Fact]
public async Task Invoke_WithUnknownPath_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(path: "/api/unknown");
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/unknown"))
.Returns((EndpointDescriptor?)null);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("not found");
responseBody.Should().Contain("/api/unknown");
}
#endregion
#region HTTP Method Tests
[Fact]
public async Task Invoke_WithPostMethod_ResolvesCorrectly()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(method: "POST");
var context = CreateHttpContext(method: "POST");
_routingStateMock.Setup(r => r.ResolveEndpoint("POST", "/api/test"))
.Returns(endpoint);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeTrue();
context.Items[RouterHttpContextKeys.EndpointDescriptor].Should().Be(endpoint);
}
[Fact]
public async Task Invoke_WithDeleteMethod_ResolvesCorrectly()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(method: "DELETE", path: "/api/users/123");
var context = CreateHttpContext(method: "DELETE", path: "/api/users/123");
_routingStateMock.Setup(r => r.ResolveEndpoint("DELETE", "/api/users/123"))
.Returns(endpoint);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
[Fact]
public async Task Invoke_WithWrongMethod_Returns404()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(method: "DELETE", path: "/api/test");
_routingStateMock.Setup(r => r.ResolveEndpoint("DELETE", "/api/test"))
.Returns((EndpointDescriptor?)null);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status404NotFound);
}
#endregion
#region Path Variations Tests
[Fact]
public async Task Invoke_WithParameterizedPath_ResolvesCorrectly()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(path: "/api/users/{id}");
var context = CreateHttpContext(path: "/api/users/123");
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/users/123"))
.Returns(endpoint);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeTrue();
context.Items[RouterHttpContextKeys.EndpointDescriptor].Should().Be(endpoint);
}
[Fact]
public async Task Invoke_WithRootPath_ResolvesCorrectly()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(path: "/");
var context = CreateHttpContext(path: "/");
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/"))
.Returns(endpoint);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
[Fact]
public async Task Invoke_WithEmptyPath_PassesEmptyStringToRouting()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(path: "");
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", ""))
.Returns((EndpointDescriptor?)null);
// Act
await middleware.Invoke(context, _routingStateMock.Object);
// Assert
_routingStateMock.Verify(r => r.ResolveEndpoint("GET", ""), Times.Once);
}
#endregion
#region Multiple Calls Tests
[Fact]
public async Task Invoke_MultipleCalls_EachResolvesIndependently()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint1 = CreateEndpoint(path: "/api/users");
var endpoint2 = CreateEndpoint(path: "/api/items");
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/users"))
.Returns(endpoint1);
_routingStateMock.Setup(r => r.ResolveEndpoint("GET", "/api/items"))
.Returns(endpoint2);
var context1 = CreateHttpContext(path: "/api/users");
var context2 = CreateHttpContext(path: "/api/items");
// Act
await middleware.Invoke(context1, _routingStateMock.Object);
await middleware.Invoke(context2, _routingStateMock.Object);
// Assert
context1.Items[RouterHttpContextKeys.EndpointDescriptor].Should().Be(endpoint1);
context2.Items[RouterHttpContextKeys.EndpointDescriptor].Should().Be(endpoint2);
}
#endregion
}

View File

@@ -1,277 +0,0 @@
using FluentAssertions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Tests for <see cref="HealthMonitorService"/>.
/// </summary>
public sealed class HealthMonitorServiceTests
{
private readonly Mock<IGlobalRoutingState> _routingStateMock;
private readonly HealthOptions _options;
public HealthMonitorServiceTests()
{
_routingStateMock = new Mock<IGlobalRoutingState>(MockBehavior.Loose);
_options = new HealthOptions
{
StaleThreshold = TimeSpan.FromSeconds(10),
DegradedThreshold = TimeSpan.FromSeconds(5),
CheckInterval = TimeSpan.FromMilliseconds(100)
};
}
private HealthMonitorService CreateService()
{
return new HealthMonitorService(
_routingStateMock.Object,
Options.Create(_options),
NullLogger<HealthMonitorService>.Instance);
}
[Fact]
public async Task ExecuteAsync_MarksStaleConnectionsUnhealthy()
{
// Arrange
var staleConnection = CreateConnection("conn-1", "service-a", "1.0.0");
staleConnection.Status = InstanceHealthStatus.Healthy;
staleConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-15); // Past stale threshold
_routingStateMock.Setup(s => s.GetAllConnections())
.Returns([staleConnection]);
var service = CreateService();
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(500));
// Act
try
{
await service.StartAsync(cts.Token);
await Task.Delay(200, cts.Token);
}
catch (OperationCanceledException)
{
// Expected
}
finally
{
await service.StopAsync(CancellationToken.None);
}
// Assert
_routingStateMock.Verify(
s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()),
Times.AtLeastOnce);
}
[Fact]
public async Task ExecuteAsync_MarksDegradedConnectionsDegraded()
{
// Arrange
var degradedConnection = CreateConnection("conn-1", "service-a", "1.0.0");
degradedConnection.Status = InstanceHealthStatus.Healthy;
degradedConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-7); // Past degraded but not stale
_routingStateMock.Setup(s => s.GetAllConnections())
.Returns([degradedConnection]);
var service = CreateService();
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1));
// Act
try
{
await service.StartAsync(cts.Token);
// Wait enough time for at least one check cycle (CheckInterval is 100ms)
await Task.Delay(300, cts.Token);
}
catch (OperationCanceledException)
{
// Expected
}
finally
{
await service.StopAsync(CancellationToken.None);
}
// Assert
_routingStateMock.Verify(
s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()),
Times.AtLeastOnce);
}
[Fact]
public async Task ExecuteAsync_DoesNotChangeHealthyConnections()
{
// Arrange
var healthyConnection = CreateConnection("conn-1", "service-a", "1.0.0");
healthyConnection.Status = InstanceHealthStatus.Healthy;
healthyConnection.LastHeartbeatUtc = DateTime.UtcNow; // Fresh heartbeat
_routingStateMock.Setup(s => s.GetAllConnections())
.Returns([healthyConnection]);
var service = CreateService();
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
// Act
try
{
await service.StartAsync(cts.Token);
await Task.Delay(200, cts.Token);
}
catch (OperationCanceledException)
{
// Expected
}
finally
{
await service.StopAsync(CancellationToken.None);
}
// Assert - should not have updated the connection
_routingStateMock.Verify(
s => s.UpdateConnection(It.IsAny<string>(), It.IsAny<Action<ConnectionState>>()),
Times.Never);
}
[Fact]
public async Task ExecuteAsync_DoesNotChangeDrainingConnections()
{
// Arrange
var drainingConnection = CreateConnection("conn-1", "service-a", "1.0.0");
drainingConnection.Status = InstanceHealthStatus.Draining;
drainingConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-30); // Very stale
_routingStateMock.Setup(s => s.GetAllConnections())
.Returns([drainingConnection]);
var service = CreateService();
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
// Act
try
{
await service.StartAsync(cts.Token);
await Task.Delay(200, cts.Token);
}
catch (OperationCanceledException)
{
// Expected
}
finally
{
await service.StopAsync(CancellationToken.None);
}
// Assert - draining connections should be left alone
_routingStateMock.Verify(
s => s.UpdateConnection(It.IsAny<string>(), It.IsAny<Action<ConnectionState>>()),
Times.Never);
}
[Fact]
public async Task ExecuteAsync_DoesNotDoubleMarkUnhealthy()
{
// Arrange
var unhealthyConnection = CreateConnection("conn-1", "service-a", "1.0.0");
unhealthyConnection.Status = InstanceHealthStatus.Unhealthy;
unhealthyConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-30); // Very stale
_routingStateMock.Setup(s => s.GetAllConnections())
.Returns([unhealthyConnection]);
var service = CreateService();
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
// Act
try
{
await service.StartAsync(cts.Token);
await Task.Delay(200, cts.Token);
}
catch (OperationCanceledException)
{
// Expected
}
finally
{
await service.StopAsync(CancellationToken.None);
}
// Assert - already unhealthy connections should not be updated
_routingStateMock.Verify(
s => s.UpdateConnection(It.IsAny<string>(), It.IsAny<Action<ConnectionState>>()),
Times.Never);
}
[Fact]
public async Task UpdateAction_SetsStatusToUnhealthy()
{
// Arrange
var connection = CreateConnection("conn-1", "service-a", "1.0.0");
connection.Status = InstanceHealthStatus.Healthy;
connection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-15);
Action<ConnectionState>? capturedAction = null;
_routingStateMock.Setup(s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()))
.Callback<string, Action<ConnectionState>>((id, action) => capturedAction = action);
_routingStateMock.Setup(s => s.GetAllConnections())
.Returns([connection]);
var service = CreateService();
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
// Act - run the service briefly
try
{
await service.StartAsync(cts.Token);
await Task.Delay(200, cts.Token);
}
catch (OperationCanceledException)
{
// Expected
}
finally
{
await service.StopAsync(CancellationToken.None);
}
// Assert
capturedAction.Should().NotBeNull();
// Apply the action to verify it sets Unhealthy
var testConnection = CreateConnection("conn-1", "service-a", "1.0.0");
testConnection.Status = InstanceHealthStatus.Healthy;
capturedAction!(testConnection);
testConnection.Status.Should().Be(InstanceHealthStatus.Unhealthy);
}
private static ConnectionState CreateConnection(
string connectionId, string serviceName, string version)
{
return new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = $"{serviceName}-{Guid.NewGuid():N}",
ServiceName = serviceName,
Version = version,
Region = "us-east-1"
},
Status = InstanceHealthStatus.Healthy,
LastHeartbeatUtc = DateTime.UtcNow,
TransportType = TransportType.InMemory
};
}
}

View File

@@ -1,356 +0,0 @@
using System.Net;
using System.Text.Json;
using FluentAssertions;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using Moq.Protected;
using StellaOps.Gateway.WebService.Authorization;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="HttpAuthorityClaimsProvider"/>.
/// </summary>
public sealed class HttpAuthorityClaimsProviderTests
{
private readonly Mock<HttpMessageHandler> _httpHandlerMock;
private readonly HttpClient _httpClient;
private readonly AuthorityConnectionOptions _options;
public HttpAuthorityClaimsProviderTests()
{
_httpHandlerMock = new Mock<HttpMessageHandler>();
_httpClient = new HttpClient(_httpHandlerMock.Object);
_options = new AuthorityConnectionOptions
{
AuthorityUrl = "http://authority.local"
};
}
private HttpAuthorityClaimsProvider CreateProvider()
{
return new HttpAuthorityClaimsProvider(
_httpClient,
Options.Create(_options),
NullLogger<HttpAuthorityClaimsProvider>.Instance);
}
#region GetOverridesAsync Tests
[Fact]
public async Task GetOverridesAsync_NoAuthorityUrl_ReturnsEmpty()
{
// Arrange
_options.AuthorityUrl = string.Empty;
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeFalse();
}
[Fact]
public async Task GetOverridesAsync_WhitespaceUrl_ReturnsEmpty()
{
// Arrange
_options.AuthorityUrl = " ";
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeFalse();
}
[Fact]
public async Task GetOverridesAsync_SuccessfulResponse_ParsesOverrides()
{
// Arrange
var responseBody = JsonSerializer.Serialize(new
{
overrides = new[]
{
new
{
serviceName = "test-service",
method = "GET",
path = "/api/users",
requiringClaims = new[]
{
new { type = "role", value = "admin" }
}
}
}
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
SetupHttpResponse(HttpStatusCode.OK, responseBody);
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().HaveCount(1);
provider.IsAvailable.Should().BeTrue();
var key = result.Keys.First();
key.ServiceName.Should().Be("test-service");
key.Method.Should().Be("GET");
key.Path.Should().Be("/api/users");
result[key].Should().HaveCount(1);
result[key][0].Type.Should().Be("role");
result[key][0].Value.Should().Be("admin");
}
[Fact]
public async Task GetOverridesAsync_EmptyOverrides_ReturnsEmpty()
{
// Arrange
var responseBody = JsonSerializer.Serialize(new
{
overrides = Array.Empty<object>()
});
SetupHttpResponse(HttpStatusCode.OK, responseBody);
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeTrue();
}
[Fact]
public async Task GetOverridesAsync_NullOverrides_ReturnsEmpty()
{
// Arrange
var responseBody = "{}";
SetupHttpResponse(HttpStatusCode.OK, responseBody);
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeTrue();
}
[Fact]
public async Task GetOverridesAsync_HttpError_ReturnsEmptyAndSetsUnavailable()
{
// Arrange
SetupHttpResponse(HttpStatusCode.InternalServerError, "Error");
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeFalse();
}
[Fact]
public async Task GetOverridesAsync_Timeout_ReturnsEmptyAndSetsUnavailable()
{
// Arrange
_httpHandlerMock.Protected()
.Setup<Task<HttpResponseMessage>>(
"SendAsync",
ItExpr.IsAny<HttpRequestMessage>(),
ItExpr.IsAny<CancellationToken>())
.ThrowsAsync(new TaskCanceledException("Timeout"));
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeFalse();
}
[Fact]
public async Task GetOverridesAsync_NetworkError_ReturnsEmptyAndSetsUnavailable()
{
// Arrange
_httpHandlerMock.Protected()
.Setup<Task<HttpResponseMessage>>(
"SendAsync",
ItExpr.IsAny<HttpRequestMessage>(),
ItExpr.IsAny<CancellationToken>())
.ThrowsAsync(new HttpRequestException("Connection refused"));
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().BeEmpty();
provider.IsAvailable.Should().BeFalse();
}
[Fact]
public async Task GetOverridesAsync_TrimsTrailingSlash()
{
// Arrange
_options.AuthorityUrl = "http://authority.local/";
var responseBody = JsonSerializer.Serialize(new { overrides = Array.Empty<object>() });
string? capturedUrl = null;
_httpHandlerMock.Protected()
.Setup<Task<HttpResponseMessage>>(
"SendAsync",
ItExpr.IsAny<HttpRequestMessage>(),
ItExpr.IsAny<CancellationToken>())
.ReturnsAsync((HttpRequestMessage req, CancellationToken _) =>
{
capturedUrl = req.RequestUri?.ToString();
return new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(responseBody)
};
});
var provider = CreateProvider();
// Act
await provider.GetOverridesAsync(CancellationToken.None);
// Assert
capturedUrl.Should().Be("http://authority.local/api/v1/claims/overrides");
}
[Fact]
public async Task GetOverridesAsync_MultipleOverrides_ParsesAll()
{
// Arrange
var responseBody = JsonSerializer.Serialize(new
{
overrides = new[]
{
new
{
serviceName = "service-a",
method = "GET",
path = "/api/a",
requiringClaims = new[] { new { type = "role", value = "a" } }
},
new
{
serviceName = "service-b",
method = "POST",
path = "/api/b",
requiringClaims = new[]
{
new { type = "role", value = "b1" },
new { type = "department", value = "b2" }
}
}
}
}, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
SetupHttpResponse(HttpStatusCode.OK, responseBody);
var provider = CreateProvider();
// Act
var result = await provider.GetOverridesAsync(CancellationToken.None);
// Assert
result.Should().HaveCount(2);
}
#endregion
#region IsAvailable Tests
[Fact]
public void IsAvailable_InitiallyFalse()
{
// Arrange
var provider = CreateProvider();
// Assert
provider.IsAvailable.Should().BeFalse();
}
[Fact]
public async Task IsAvailable_TrueAfterSuccessfulFetch()
{
// Arrange
SetupHttpResponse(HttpStatusCode.OK, "{}");
var provider = CreateProvider();
// Act
await provider.GetOverridesAsync(CancellationToken.None);
// Assert
provider.IsAvailable.Should().BeTrue();
}
[Fact]
public async Task IsAvailable_FalseAfterFailedFetch()
{
// Arrange
SetupHttpResponse(HttpStatusCode.ServiceUnavailable, "");
var provider = CreateProvider();
// Act
await provider.GetOverridesAsync(CancellationToken.None);
// Assert
provider.IsAvailable.Should().BeFalse();
}
#endregion
#region OverridesChanged Event Tests
[Fact]
public void OverridesChanged_CanBeSubscribed()
{
// Arrange
var provider = CreateProvider();
var eventRaised = false;
// Act
provider.OverridesChanged += (_, _) => eventRaised = true;
// Assert - no exception during subscription, event not raised yet
eventRaised.Should().BeFalse();
provider.Should().NotBeNull();
}
#endregion
#region Helper Methods
private void SetupHttpResponse(HttpStatusCode statusCode, string content)
{
_httpHandlerMock.Protected()
.Setup<Task<HttpResponseMessage>>(
"SendAsync",
ItExpr.IsAny<HttpRequestMessage>(),
ItExpr.IsAny<CancellationToken>())
.ReturnsAsync(new HttpResponseMessage(statusCode)
{
Content = new StringContent(content)
});
}
#endregion
}

View File

@@ -1,323 +0,0 @@
using FluentAssertions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
public class InMemoryRoutingStateTests
{
private readonly InMemoryRoutingState _sut = new();
private static ConnectionState CreateConnection(
string connectionId = "conn-1",
string serviceName = "test-service",
string version = "1.0.0",
string region = "us-east-1",
InstanceHealthStatus status = InstanceHealthStatus.Healthy,
params (string Method, string Path)[] endpoints)
{
var connection = new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = $"inst-{connectionId}",
ServiceName = serviceName,
Version = version,
Region = region
},
Status = status,
TransportType = TransportType.InMemory
};
foreach (var (method, path) in endpoints)
{
connection.Endpoints[(method, path)] = new EndpointDescriptor
{
Method = method,
Path = path,
ServiceName = serviceName,
Version = version
};
}
return connection;
}
[Fact]
public void AddConnection_ShouldStoreConnection()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
// Act
_sut.AddConnection(connection);
// Assert
var result = _sut.GetConnection(connection.ConnectionId);
result.Should().NotBeNull();
result.Should().BeSameAs(connection);
}
[Fact]
public void AddConnection_ShouldIndexEndpoints()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/users/{id}")]);
// Act
_sut.AddConnection(connection);
// Assert
var endpoint = _sut.ResolveEndpoint("GET", "/api/users/123");
endpoint.Should().NotBeNull();
endpoint!.Path.Should().Be("/api/users/{id}");
}
[Fact]
public void RemoveConnection_ShouldRemoveConnection()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
_sut.RemoveConnection(connection.ConnectionId);
// Assert
var result = _sut.GetConnection(connection.ConnectionId);
result.Should().BeNull();
}
[Fact]
public void RemoveConnection_ShouldRemoveEndpointsWhenLastConnection()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
_sut.RemoveConnection(connection.ConnectionId);
// Assert
var endpoint = _sut.ResolveEndpoint("GET", "/api/test");
endpoint.Should().BeNull();
}
[Fact]
public void RemoveConnection_ShouldKeepEndpointsWhenOtherConnectionsExist()
{
// Arrange
var connection1 = CreateConnection("conn-1", endpoints: [("GET", "/api/test")]);
var connection2 = CreateConnection("conn-2", endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection1);
_sut.AddConnection(connection2);
// Act
_sut.RemoveConnection("conn-1");
// Assert
var endpoint = _sut.ResolveEndpoint("GET", "/api/test");
endpoint.Should().NotBeNull();
}
[Fact]
public void UpdateConnection_ShouldApplyUpdate()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
_sut.UpdateConnection(connection.ConnectionId, c => c.Status = InstanceHealthStatus.Degraded);
// Assert
var result = _sut.GetConnection(connection.ConnectionId);
result.Should().NotBeNull();
result!.Status.Should().Be(InstanceHealthStatus.Degraded);
}
[Fact]
public void UpdateConnection_ShouldDoNothingForUnknownConnection()
{
// Act - should not throw
_sut.UpdateConnection("unknown", c => c.Status = InstanceHealthStatus.Degraded);
// Assert
var result = _sut.GetConnection("unknown");
result.Should().BeNull();
}
[Fact]
public void GetConnection_ShouldReturnNullForUnknownConnection()
{
// Act
var result = _sut.GetConnection("unknown");
// Assert
result.Should().BeNull();
}
[Fact]
public void GetAllConnections_ShouldReturnAllConnections()
{
// Arrange
var connection1 = CreateConnection("conn-1", endpoints: [("GET", "/api/test1")]);
var connection2 = CreateConnection("conn-2", endpoints: [("GET", "/api/test2")]);
_sut.AddConnection(connection1);
_sut.AddConnection(connection2);
// Act
var result = _sut.GetAllConnections();
// Assert
result.Should().HaveCount(2);
result.Should().Contain(connection1);
result.Should().Contain(connection2);
}
[Fact]
public void GetAllConnections_ShouldReturnEmptyWhenNoConnections()
{
// Act
var result = _sut.GetAllConnections();
// Assert
result.Should().BeEmpty();
}
[Fact]
public void ResolveEndpoint_ShouldMatchExactPath()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/health")]);
_sut.AddConnection(connection);
// Act
var result = _sut.ResolveEndpoint("GET", "/api/health");
// Assert
result.Should().NotBeNull();
result!.Path.Should().Be("/api/health");
}
[Fact]
public void ResolveEndpoint_ShouldMatchParameterizedPath()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/users/{id}/orders/{orderId}")]);
_sut.AddConnection(connection);
// Act
var result = _sut.ResolveEndpoint("GET", "/api/users/123/orders/456");
// Assert
result.Should().NotBeNull();
result!.Path.Should().Be("/api/users/{id}/orders/{orderId}");
}
[Fact]
public void ResolveEndpoint_ShouldReturnNullForNonMatchingMethod()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
var result = _sut.ResolveEndpoint("POST", "/api/test");
// Assert
result.Should().BeNull();
}
[Fact]
public void ResolveEndpoint_ShouldReturnNullForNonMatchingPath()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
var result = _sut.ResolveEndpoint("GET", "/api/other");
// Assert
result.Should().BeNull();
}
[Fact]
public void ResolveEndpoint_ShouldBeCaseInsensitiveForMethod()
{
// Arrange
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
var result = _sut.ResolveEndpoint("get", "/api/test");
// Assert
result.Should().NotBeNull();
}
[Fact]
public void GetConnectionsFor_ShouldFilterByServiceName()
{
// Arrange
var connection1 = CreateConnection("conn-1", "service-a", endpoints: [("GET", "/api/test")]);
var connection2 = CreateConnection("conn-2", "service-b", endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection1);
_sut.AddConnection(connection2);
// Act
var result = _sut.GetConnectionsFor("service-a", "1.0.0", "GET", "/api/test");
// Assert
result.Should().HaveCount(1);
result[0].Instance.ServiceName.Should().Be("service-a");
}
[Fact]
public void GetConnectionsFor_ShouldFilterByVersion()
{
// Arrange
var connection1 = CreateConnection("conn-1", "service-a", "1.0.0", endpoints: [("GET", "/api/test")]);
var connection2 = CreateConnection("conn-2", "service-a", "2.0.0", endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection1);
_sut.AddConnection(connection2);
// Act
var result = _sut.GetConnectionsFor("service-a", "1.0.0", "GET", "/api/test");
// Assert
result.Should().HaveCount(1);
result[0].Instance.Version.Should().Be("1.0.0");
}
[Fact]
public void GetConnectionsFor_ShouldReturnEmptyWhenNoMatch()
{
// Arrange
var connection = CreateConnection("conn-1", "service-a", endpoints: [("GET", "/api/test")]);
_sut.AddConnection(connection);
// Act
var result = _sut.GetConnectionsFor("service-b", "1.0.0", "GET", "/api/test");
// Assert
result.Should().BeEmpty();
}
[Fact]
public void GetConnectionsFor_ShouldMatchParameterizedPaths()
{
// Arrange
var connection = CreateConnection("conn-1", "service-a", endpoints: [("GET", "/api/users/{id}")]);
_sut.AddConnection(connection);
// Act
var result = _sut.GetConnectionsFor("service-a", "1.0.0", "GET", "/api/users/123");
// Assert
result.Should().HaveCount(1);
}
}

View File

@@ -1,182 +0,0 @@
using FluentAssertions;
using StellaOps.Gateway.WebService.OpenApi;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests.OpenApi;
public class ClaimSecurityMapperTests
{
[Fact]
public void GenerateSecuritySchemes_WithNoEndpoints_ReturnsBearerAuthOnly()
{
// Arrange
var endpoints = Array.Empty<EndpointDescriptor>();
// Act
var schemes = ClaimSecurityMapper.GenerateSecuritySchemes(endpoints, "/auth/token");
// Assert
schemes.Should().ContainKey("BearerAuth");
schemes.Should().NotContainKey("OAuth2");
}
[Fact]
public void GenerateSecuritySchemes_WithClaimRequirements_ReturnsOAuth2()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
Method = "POST",
Path = "/test",
ServiceName = "test",
Version = "1.0.0",
RequiringClaims = [new ClaimRequirement { Type = "test:write" }]
}
};
// Act
var schemes = ClaimSecurityMapper.GenerateSecuritySchemes(endpoints, "/auth/token");
// Assert
schemes.Should().ContainKey("BearerAuth");
schemes.Should().ContainKey("OAuth2");
}
[Fact]
public void GenerateSecuritySchemes_CollectsAllUniqueScopes()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
Method = "POST",
Path = "/invoices",
ServiceName = "billing",
Version = "1.0.0",
RequiringClaims = [new ClaimRequirement { Type = "billing:write" }]
},
new EndpointDescriptor
{
Method = "GET",
Path = "/invoices",
ServiceName = "billing",
Version = "1.0.0",
RequiringClaims = [new ClaimRequirement { Type = "billing:read" }]
},
new EndpointDescriptor
{
Method = "POST",
Path = "/payments",
ServiceName = "billing",
Version = "1.0.0",
RequiringClaims = [new ClaimRequirement { Type = "billing:write" }] // Duplicate
}
};
// Act
var schemes = ClaimSecurityMapper.GenerateSecuritySchemes(endpoints, "/auth/token");
// Assert
var oauth2 = schemes["OAuth2"];
var scopes = oauth2!["flows"]!["clientCredentials"]!["scopes"]!;
scopes.AsObject().Count.Should().Be(2); // Only unique scopes
scopes["billing:write"].Should().NotBeNull();
scopes["billing:read"].Should().NotBeNull();
}
[Fact]
public void GenerateSecuritySchemes_SetsCorrectTokenUrl()
{
// Arrange
var endpoints = new[]
{
new EndpointDescriptor
{
Method = "POST",
Path = "/test",
ServiceName = "test",
Version = "1.0.0",
RequiringClaims = [new ClaimRequirement { Type = "test:write" }]
}
};
// Act
var schemes = ClaimSecurityMapper.GenerateSecuritySchemes(endpoints, "/custom/token");
// Assert
var tokenUrl = schemes["OAuth2"]!["flows"]!["clientCredentials"]!["tokenUrl"]!.GetValue<string>();
tokenUrl.Should().Be("/custom/token");
}
[Fact]
public void GenerateSecurityRequirement_WithNoClaimRequirements_ReturnsEmptyArray()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "GET",
Path = "/public",
ServiceName = "test",
Version = "1.0.0",
RequiringClaims = []
};
// Act
var requirement = ClaimSecurityMapper.GenerateSecurityRequirement(endpoint);
// Assert
requirement.Count.Should().Be(0);
}
[Fact]
public void GenerateSecurityRequirement_WithClaimRequirements_ReturnsBearerAndOAuth2()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "POST",
Path = "/secure",
ServiceName = "test",
Version = "1.0.0",
RequiringClaims =
[
new ClaimRequirement { Type = "billing:write" },
new ClaimRequirement { Type = "billing:admin" }
]
};
// Act
var requirement = ClaimSecurityMapper.GenerateSecurityRequirement(endpoint);
// Assert
requirement.Count.Should().Be(1);
var req = requirement[0]!.AsObject();
req.Should().ContainKey("BearerAuth");
req.Should().ContainKey("OAuth2");
var scopes = req["OAuth2"]!.AsArray();
scopes.Count.Should().Be(2);
}
[Fact]
public void GenerateSecuritySchemes_BearerAuth_HasCorrectStructure()
{
// Arrange
var endpoints = Array.Empty<EndpointDescriptor>();
// Act
var schemes = ClaimSecurityMapper.GenerateSecuritySchemes(endpoints, "/auth/token");
// Assert
var bearer = schemes["BearerAuth"]!.AsObject();
bearer["type"]!.GetValue<string>().Should().Be("http");
bearer["scheme"]!.GetValue<string>().Should().Be("bearer");
bearer["bearerFormat"]!.GetValue<string>().Should().Be("JWT");
}
}

View File

@@ -1,166 +0,0 @@
using FluentAssertions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Gateway.WebService.OpenApi;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests.OpenApi;
public class GatewayOpenApiDocumentCacheTests
{
private readonly Mock<IOpenApiDocumentGenerator> _generator = new();
private readonly OpenApiAggregationOptions _options = new() { CacheTtlSeconds = 60 };
private readonly GatewayOpenApiDocumentCache _sut;
public GatewayOpenApiDocumentCacheTests()
{
_sut = new GatewayOpenApiDocumentCache(
_generator.Object,
Options.Create(_options));
}
[Fact]
public void GetDocument_FirstCall_GeneratesDocument()
{
// Arrange
var expectedDoc = """{"openapi":"3.1.0"}""";
_generator.Setup(x => x.GenerateDocument()).Returns(expectedDoc);
// Act
var (doc, _, _) = _sut.GetDocument();
// Assert
doc.Should().Be(expectedDoc);
_generator.Verify(x => x.GenerateDocument(), Times.Once);
}
[Fact]
public void GetDocument_SubsequentCalls_ReturnsCachedDocument()
{
// Arrange
var expectedDoc = """{"openapi":"3.1.0"}""";
_generator.Setup(x => x.GenerateDocument()).Returns(expectedDoc);
// Act
var (doc1, _, _) = _sut.GetDocument();
var (doc2, _, _) = _sut.GetDocument();
var (doc3, _, _) = _sut.GetDocument();
// Assert
doc1.Should().Be(expectedDoc);
doc2.Should().Be(expectedDoc);
doc3.Should().Be(expectedDoc);
_generator.Verify(x => x.GenerateDocument(), Times.Once);
}
[Fact]
public void GetDocument_AfterInvalidate_RegeneratesDocument()
{
// Arrange
var doc1 = """{"openapi":"3.1.0","version":"1"}""";
var doc2 = """{"openapi":"3.1.0","version":"2"}""";
_generator.SetupSequence(x => x.GenerateDocument())
.Returns(doc1)
.Returns(doc2);
// Act
var (result1, _, _) = _sut.GetDocument();
_sut.Invalidate();
var (result2, _, _) = _sut.GetDocument();
// Assert
result1.Should().Be(doc1);
result2.Should().Be(doc2);
_generator.Verify(x => x.GenerateDocument(), Times.Exactly(2));
}
[Fact]
public void GetDocument_ReturnsConsistentETag()
{
// Arrange
var expectedDoc = """{"openapi":"3.1.0"}""";
_generator.Setup(x => x.GenerateDocument()).Returns(expectedDoc);
// Act
var (_, etag1, _) = _sut.GetDocument();
var (_, etag2, _) = _sut.GetDocument();
// Assert
etag1.Should().NotBeNullOrEmpty();
etag1.Should().Be(etag2);
etag1.Should().StartWith("\"").And.EndWith("\""); // ETag format
}
[Fact]
public void GetDocument_DifferentContent_DifferentETag()
{
// Arrange
var doc1 = """{"openapi":"3.1.0","version":"1"}""";
var doc2 = """{"openapi":"3.1.0","version":"2"}""";
_generator.SetupSequence(x => x.GenerateDocument())
.Returns(doc1)
.Returns(doc2);
// Act
var (_, etag1, _) = _sut.GetDocument();
_sut.Invalidate();
var (_, etag2, _) = _sut.GetDocument();
// Assert
etag1.Should().NotBe(etag2);
}
[Fact]
public void GetDocument_ReturnsGenerationTimestamp()
{
// Arrange
_generator.Setup(x => x.GenerateDocument()).Returns("{}");
var beforeGeneration = DateTime.UtcNow;
// Act
var (_, _, generatedAt) = _sut.GetDocument();
// Assert
generatedAt.Should().BeOnOrAfter(beforeGeneration);
generatedAt.Should().BeOnOrBefore(DateTime.UtcNow);
}
[Fact]
public void Invalidate_CanBeCalledMultipleTimes()
{
// Arrange
_generator.Setup(x => x.GenerateDocument()).Returns("{}");
_sut.GetDocument();
// Act & Assert - should not throw
_sut.Invalidate();
_sut.Invalidate();
_sut.Invalidate();
}
[Fact]
public void GetDocument_WithZeroTtl_AlwaysRegenerates()
{
// Arrange
var options = new OpenApiAggregationOptions { CacheTtlSeconds = 0 };
var sut = new GatewayOpenApiDocumentCache(
_generator.Object,
Options.Create(options));
var callCount = 0;
_generator.Setup(x => x.GenerateDocument())
.Returns(() => $"{{\"call\":{++callCount}}}");
// Act
sut.GetDocument();
// Wait a tiny bit to ensure TTL is exceeded
Thread.Sleep(10);
sut.GetDocument();
// Assert
// With 0 TTL, each call should regenerate
_generator.Verify(x => x.GenerateDocument(), Times.Exactly(2));
}
}

View File

@@ -1,338 +0,0 @@
using System.Text.Json;
using FluentAssertions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Gateway.WebService.OpenApi;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests.OpenApi;
public class OpenApiDocumentGeneratorTests
{
private readonly Mock<IGlobalRoutingState> _routingState = new();
private readonly OpenApiAggregationOptions _options = new();
private readonly OpenApiDocumentGenerator _sut;
public OpenApiDocumentGeneratorTests()
{
_sut = new OpenApiDocumentGenerator(
_routingState.Object,
Options.Create(_options));
}
private static ConnectionState CreateConnection(
string serviceName = "test-service",
string version = "1.0.0",
params EndpointDescriptor[] endpoints)
{
var connection = new ConnectionState
{
ConnectionId = $"conn-{serviceName}",
Instance = new InstanceDescriptor
{
InstanceId = $"inst-{serviceName}",
ServiceName = serviceName,
Version = version,
Region = "us-east-1"
},
Status = InstanceHealthStatus.Healthy,
TransportType = TransportType.InMemory,
Schemas = new Dictionary<string, SchemaDefinition>(),
OpenApiInfo = new ServiceOpenApiInfo
{
Title = serviceName,
Description = $"Test {serviceName} service"
}
};
foreach (var endpoint in endpoints)
{
connection.Endpoints[(endpoint.Method, endpoint.Path)] = endpoint;
}
return connection;
}
[Fact]
public void GenerateDocument_WithNoConnections_ReturnsValidOpenApiDocument()
{
// Arrange
_routingState.Setup(x => x.GetAllConnections()).Returns([]);
// Act
var document = _sut.GenerateDocument();
// Assert
document.Should().NotBeNullOrEmpty();
var doc = JsonDocument.Parse(document);
doc.RootElement.GetProperty("openapi").GetString().Should().Be("3.1.0");
doc.RootElement.GetProperty("info").GetProperty("title").GetString().Should().Be(_options.Title);
}
[Fact]
public void GenerateDocument_SetsCorrectInfoSection()
{
// Arrange
_options.Title = "My Gateway API";
_options.Description = "My description";
_options.Version = "2.0.0";
_options.LicenseName = "MIT";
_routingState.Setup(x => x.GetAllConnections()).Returns([]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
var info = doc.RootElement.GetProperty("info");
info.GetProperty("title").GetString().Should().Be("My Gateway API");
info.GetProperty("description").GetString().Should().Be("My description");
info.GetProperty("version").GetString().Should().Be("2.0.0");
info.GetProperty("license").GetProperty("name").GetString().Should().Be("MIT");
}
[Fact]
public void GenerateDocument_WithConnections_GeneratesPaths()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "GET",
Path = "/api/items",
ServiceName = "inventory",
Version = "1.0.0"
};
var connection = CreateConnection("inventory", "1.0.0", endpoint);
_routingState.Setup(x => x.GetAllConnections()).Returns([connection]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
var paths = doc.RootElement.GetProperty("paths");
paths.TryGetProperty("/api/items", out var pathItem).Should().BeTrue();
pathItem.TryGetProperty("get", out var operation).Should().BeTrue();
}
[Fact]
public void GenerateDocument_WithSchemaInfo_IncludesDocumentation()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "POST",
Path = "/invoices",
ServiceName = "billing",
Version = "1.0.0",
SchemaInfo = new EndpointSchemaInfo
{
Summary = "Create invoice",
Description = "Creates a new invoice",
Tags = ["billing", "invoices"],
Deprecated = false
}
};
var connection = CreateConnection("billing", "1.0.0", endpoint);
_routingState.Setup(x => x.GetAllConnections()).Returns([connection]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
var operation = doc.RootElement
.GetProperty("paths")
.GetProperty("/invoices")
.GetProperty("post");
operation.GetProperty("summary").GetString().Should().Be("Create invoice");
operation.GetProperty("description").GetString().Should().Be("Creates a new invoice");
}
[Fact]
public void GenerateDocument_WithSchemas_IncludesSchemaReferences()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "POST",
Path = "/invoices",
ServiceName = "billing",
Version = "1.0.0",
SchemaInfo = new EndpointSchemaInfo
{
RequestSchemaId = "CreateInvoiceRequest"
}
};
var connection = CreateConnection("billing", "1.0.0", endpoint);
var connectionWithSchemas = new ConnectionState
{
ConnectionId = connection.ConnectionId,
Instance = connection.Instance,
Status = connection.Status,
TransportType = connection.TransportType,
Schemas = new Dictionary<string, SchemaDefinition>
{
["CreateInvoiceRequest"] = new SchemaDefinition
{
SchemaId = "CreateInvoiceRequest",
SchemaJson = """{"type": "object", "properties": {"amount": {"type": "number"}}}""",
ETag = "\"ABC123\""
}
}
};
connectionWithSchemas.Endpoints[(endpoint.Method, endpoint.Path)] = endpoint;
_routingState.Setup(x => x.GetAllConnections()).Returns([connectionWithSchemas]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
// Check request body reference
var requestBody = doc.RootElement
.GetProperty("paths")
.GetProperty("/invoices")
.GetProperty("post")
.GetProperty("requestBody")
.GetProperty("content")
.GetProperty("application/json")
.GetProperty("schema")
.GetProperty("$ref")
.GetString();
requestBody.Should().Be("#/components/schemas/billing_CreateInvoiceRequest");
// Check schema exists in components
var schemas = doc.RootElement.GetProperty("components").GetProperty("schemas");
schemas.TryGetProperty("billing_CreateInvoiceRequest", out _).Should().BeTrue();
}
[Fact]
public void GenerateDocument_WithClaimRequirements_IncludesSecurity()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "POST",
Path = "/invoices",
ServiceName = "billing",
Version = "1.0.0",
RequiringClaims = [new ClaimRequirement { Type = "billing:write" }]
};
var connection = CreateConnection("billing", "1.0.0", endpoint);
_routingState.Setup(x => x.GetAllConnections()).Returns([connection]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
// Check security schemes
var securitySchemes = doc.RootElement
.GetProperty("components")
.GetProperty("securitySchemes");
securitySchemes.TryGetProperty("BearerAuth", out _).Should().BeTrue();
securitySchemes.TryGetProperty("OAuth2", out _).Should().BeTrue();
// Check operation security
var operation = doc.RootElement
.GetProperty("paths")
.GetProperty("/invoices")
.GetProperty("post");
operation.TryGetProperty("security", out _).Should().BeTrue();
}
[Fact]
public void GenerateDocument_WithMultipleServices_GeneratesTags()
{
// Arrange
var billingEndpoint = new EndpointDescriptor
{
Method = "POST",
Path = "/invoices",
ServiceName = "billing",
Version = "1.0.0"
};
var inventoryEndpoint = new EndpointDescriptor
{
Method = "GET",
Path = "/items",
ServiceName = "inventory",
Version = "2.0.0"
};
var billingConn = CreateConnection("billing", "1.0.0", billingEndpoint);
var inventoryConn = CreateConnection("inventory", "2.0.0", inventoryEndpoint);
_routingState.Setup(x => x.GetAllConnections()).Returns([billingConn, inventoryConn]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
var tags = doc.RootElement.GetProperty("tags");
tags.GetArrayLength().Should().Be(2);
var tagNames = new List<string>();
foreach (var tag in tags.EnumerateArray())
{
tagNames.Add(tag.GetProperty("name").GetString()!);
}
tagNames.Should().Contain("billing");
tagNames.Should().Contain("inventory");
}
[Fact]
public void GenerateDocument_WithDeprecatedEndpoint_SetsDeprecatedFlag()
{
// Arrange
var endpoint = new EndpointDescriptor
{
Method = "GET",
Path = "/legacy",
ServiceName = "test",
Version = "1.0.0",
SchemaInfo = new EndpointSchemaInfo
{
Deprecated = true
}
};
var connection = CreateConnection("test", "1.0.0", endpoint);
_routingState.Setup(x => x.GetAllConnections()).Returns([connection]);
// Act
var document = _sut.GenerateDocument();
// Assert
var doc = JsonDocument.Parse(document);
var operation = doc.RootElement
.GetProperty("paths")
.GetProperty("/legacy")
.GetProperty("get");
operation.GetProperty("deprecated").GetBoolean().Should().BeTrue();
}
}

View File

@@ -1,337 +0,0 @@
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="PayloadLimitsMiddleware"/>.
/// </summary>
public sealed class PayloadLimitsMiddlewareTests
{
private readonly Mock<IPayloadTracker> _trackerMock;
private readonly Mock<RequestDelegate> _nextMock;
private readonly PayloadLimits _defaultLimits;
private bool _nextCalled;
public PayloadLimitsMiddlewareTests()
{
_trackerMock = new Mock<IPayloadTracker>();
_nextMock = new Mock<RequestDelegate>();
_nextMock.Setup(n => n(It.IsAny<HttpContext>()))
.Callback(() => _nextCalled = true)
.Returns(Task.CompletedTask);
_defaultLimits = new PayloadLimits
{
MaxRequestBytesPerCall = 10 * 1024 * 1024, // 10MB
MaxRequestBytesPerConnection = 100 * 1024 * 1024, // 100MB
MaxAggregateInflightBytes = 1024 * 1024 * 1024 // 1GB
};
}
private PayloadLimitsMiddleware CreateMiddleware(PayloadLimits? limits = null)
{
return new PayloadLimitsMiddleware(
_nextMock.Object,
Options.Create(limits ?? _defaultLimits),
NullLogger<PayloadLimitsMiddleware>.Instance);
}
private static HttpContext CreateHttpContext(long? contentLength = null, string connectionId = "conn-1")
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
context.Request.Body = new MemoryStream();
context.Connection.Id = connectionId;
if (contentLength.HasValue)
{
context.Request.ContentLength = contentLength;
}
return context;
}
#region Within Limits Tests
[Fact]
public async Task Invoke_WithinLimits_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(true);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
[Fact]
public async Task Invoke_WithNoContentLength_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: null);
_trackerMock.Setup(t => t.TryReserve("conn-1", 0))
.Returns(true);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
[Fact]
public async Task Invoke_WithZeroContentLength_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 0);
_trackerMock.Setup(t => t.TryReserve("conn-1", 0))
.Returns(true);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
#endregion
#region Per-Call Limit Tests
[Fact]
public async Task Invoke_ExceedsPerCallLimit_Returns413()
{
// Arrange
var limits = new PayloadLimits { MaxRequestBytesPerCall = 1000 };
var middleware = CreateMiddleware(limits);
var context = CreateHttpContext(contentLength: 2000);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status413PayloadTooLarge);
}
[Fact]
public async Task Invoke_ExceedsPerCallLimit_WritesErrorResponse()
{
// Arrange
var limits = new PayloadLimits { MaxRequestBytesPerCall = 1000 };
var middleware = CreateMiddleware(limits);
var context = CreateHttpContext(contentLength: 2000);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Payload Too Large");
responseBody.Should().Contain("1000");
responseBody.Should().Contain("2000");
}
[Fact]
public async Task Invoke_ExactlyAtPerCallLimit_CallsNext()
{
// Arrange
var limits = new PayloadLimits { MaxRequestBytesPerCall = 1000 };
var middleware = CreateMiddleware(limits);
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(true);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeTrue();
}
#endregion
#region Aggregate Limit Tests
[Fact]
public async Task Invoke_ExceedsAggregateLimit_Returns503()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(false);
_trackerMock.Setup(t => t.IsOverloaded)
.Returns(true);
_trackerMock.Setup(t => t.CurrentInflightBytes)
.Returns(1024 * 1024 * 1024); // 1GB
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status503ServiceUnavailable);
}
[Fact]
public async Task Invoke_ExceedsAggregateLimit_WritesOverloadedResponse()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(false);
_trackerMock.Setup(t => t.IsOverloaded)
.Returns(true);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Overloaded");
}
#endregion
#region Per-Connection Limit Tests
[Fact]
public async Task Invoke_ExceedsPerConnectionLimit_Returns429()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(false);
_trackerMock.Setup(t => t.IsOverloaded)
.Returns(false); // Not aggregate limit
_trackerMock.Setup(t => t.GetConnectionInflightBytes("conn-1"))
.Returns(100 * 1024 * 1024); // 100MB
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status429TooManyRequests);
}
[Fact]
public async Task Invoke_ExceedsPerConnectionLimit_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(false);
_trackerMock.Setup(t => t.IsOverloaded)
.Returns(false);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Too Many Requests");
}
#endregion
#region Release Tests
[Fact]
public async Task Invoke_AfterSuccess_ReleasesReservation()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(true);
// Act
await middleware.Invoke(context, _trackerMock.Object);
// Assert
_trackerMock.Verify(t => t.Release("conn-1", It.IsAny<long>()), Times.Once);
}
[Fact]
public async Task Invoke_AfterNextThrows_StillReleasesReservation()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(contentLength: 1000);
_trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
.Returns(true);
_nextMock.Setup(n => n(It.IsAny<HttpContext>()))
.ThrowsAsync(new InvalidOperationException("Test error"));
// Act
var act = async () => await middleware.Invoke(context, _trackerMock.Object);
// Assert
await act.Should().ThrowAsync<InvalidOperationException>();
_trackerMock.Verify(t => t.Release("conn-1", It.IsAny<long>()), Times.Once);
}
#endregion
#region Different Connections Tests
[Fact]
public async Task Invoke_DifferentConnections_TrackedSeparately()
{
// Arrange
var middleware = CreateMiddleware();
var context1 = CreateHttpContext(contentLength: 1000, connectionId: "conn-1");
var context2 = CreateHttpContext(contentLength: 2000, connectionId: "conn-2");
_trackerMock.Setup(t => t.TryReserve(It.IsAny<string>(), It.IsAny<long>()))
.Returns(true);
// Act
await middleware.Invoke(context1, _trackerMock.Object);
await middleware.Invoke(context2, _trackerMock.Object);
// Assert
_trackerMock.Verify(t => t.TryReserve("conn-1", 1000), Times.Once);
_trackerMock.Verify(t => t.TryReserve("conn-2", 2000), Times.Once);
}
#endregion
}

View File

@@ -1,254 +0,0 @@
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
public class PayloadTrackerTests
{
private readonly PayloadLimits _limits = new()
{
MaxRequestBytesPerCall = 1024,
MaxRequestBytesPerConnection = 4096,
MaxAggregateInflightBytes = 8192
};
private PayloadTracker CreateTracker()
{
return new PayloadTracker(
Options.Create(_limits),
NullLogger<PayloadTracker>.Instance);
}
[Fact]
public void TryReserve_WithinLimits_ReturnsTrue()
{
var tracker = CreateTracker();
var result = tracker.TryReserve("conn-1", 500);
Assert.True(result);
Assert.Equal(500, tracker.CurrentInflightBytes);
}
[Fact]
public void TryReserve_ExceedsAggregateLimits_ReturnsFalse()
{
var tracker = CreateTracker();
// Reserve from multiple connections to approach aggregate limit (8192)
// Each connection can have up to 4096 bytes
Assert.True(tracker.TryReserve("conn-1", 4000));
Assert.True(tracker.TryReserve("conn-2", 4000));
// Now at 8000 bytes
// Another reservation that exceeds aggregate limit (8000 + 500 > 8192) should fail
var result = tracker.TryReserve("conn-3", 500);
Assert.False(result);
Assert.Equal(8000, tracker.CurrentInflightBytes);
}
[Fact]
public void TryReserve_ExceedsPerConnectionLimit_ReturnsFalse()
{
var tracker = CreateTracker();
// Reserve up to per-connection limit
Assert.True(tracker.TryReserve("conn-1", 4000));
// Next reservation on same connection should fail
var result = tracker.TryReserve("conn-1", 500);
Assert.False(result);
}
[Fact]
public void TryReserve_DifferentConnections_TrackedSeparately()
{
var tracker = CreateTracker();
Assert.True(tracker.TryReserve("conn-1", 3000));
Assert.True(tracker.TryReserve("conn-2", 3000));
Assert.Equal(3000, tracker.GetConnectionInflightBytes("conn-1"));
Assert.Equal(3000, tracker.GetConnectionInflightBytes("conn-2"));
Assert.Equal(6000, tracker.CurrentInflightBytes);
}
[Fact]
public void Release_DecreasesInflightBytes()
{
var tracker = CreateTracker();
tracker.TryReserve("conn-1", 1000);
tracker.Release("conn-1", 500);
Assert.Equal(500, tracker.CurrentInflightBytes);
Assert.Equal(500, tracker.GetConnectionInflightBytes("conn-1"));
}
[Fact]
public void Release_CannotGoNegative()
{
var tracker = CreateTracker();
tracker.TryReserve("conn-1", 100);
tracker.Release("conn-1", 500); // More than reserved
Assert.Equal(0, tracker.GetConnectionInflightBytes("conn-1"));
}
[Fact]
public void IsOverloaded_TrueWhenExceedsLimit()
{
var tracker = CreateTracker();
// Reservation at limit passes (8192 <= 8192 is false for >, so not overloaded at exactly limit)
// But we can't exceed the limit. The IsOverloaded check is for current > limit
// So at exactly 8192, IsOverloaded should be false (8192 > 8192 is false)
// Reserving 8193 would be rejected. So let's test that at limit, IsOverloaded is false
tracker.TryReserve("conn-1", 8192);
// At exactly the limit, IsOverloaded is false (8192 > 8192 = false)
Assert.False(tracker.IsOverloaded);
}
[Fact]
public void IsOverloaded_FalseWhenWithinLimit()
{
var tracker = CreateTracker();
tracker.TryReserve("conn-1", 4000);
Assert.False(tracker.IsOverloaded);
}
[Fact]
public void GetConnectionInflightBytes_ReturnsZeroForUnknownConnection()
{
var tracker = CreateTracker();
var result = tracker.GetConnectionInflightBytes("unknown");
Assert.Equal(0, result);
}
}
public class ByteCountingStreamTests
{
[Fact]
public async Task ReadAsync_CountsBytesRead()
{
var data = new byte[] { 1, 2, 3, 4, 5 };
using var inner = new MemoryStream(data);
using var stream = new ByteCountingStream(inner, 100);
var buffer = new byte[10];
var read = await stream.ReadAsync(buffer);
Assert.Equal(5, read);
Assert.Equal(5, stream.BytesRead);
}
[Fact]
public async Task ReadAsync_ThrowsWhenLimitExceeded()
{
var data = new byte[100];
using var inner = new MemoryStream(data);
using var stream = new ByteCountingStream(inner, 50);
var buffer = new byte[100];
var ex = await Assert.ThrowsAsync<PayloadLimitExceededException>(
() => stream.ReadAsync(buffer).AsTask());
Assert.Equal(100, ex.BytesRead);
Assert.Equal(50, ex.Limit);
}
[Fact]
public async Task ReadAsync_CallsCallbackOnLimitExceeded()
{
var data = new byte[100];
using var inner = new MemoryStream(data);
var callbackCalled = false;
using var stream = new ByteCountingStream(inner, 50, () => callbackCalled = true);
var buffer = new byte[100];
await Assert.ThrowsAsync<PayloadLimitExceededException>(
() => stream.ReadAsync(buffer).AsTask());
Assert.True(callbackCalled);
}
[Fact]
public async Task ReadAsync_AccumulatesAcrossMultipleReads()
{
var data = new byte[100];
using var inner = new MemoryStream(data);
using var stream = new ByteCountingStream(inner, 60);
var buffer = new byte[30];
// First read - 30 bytes
var read1 = await stream.ReadAsync(buffer);
Assert.Equal(30, read1);
Assert.Equal(30, stream.BytesRead);
// Second read - 30 more bytes
var read2 = await stream.ReadAsync(buffer);
Assert.Equal(30, read2);
Assert.Equal(60, stream.BytesRead);
// Third read should exceed limit
await Assert.ThrowsAsync<PayloadLimitExceededException>(
() => stream.ReadAsync(buffer).AsTask());
}
[Fact]
public void Stream_Properties_AreCorrect()
{
using var inner = new MemoryStream();
using var stream = new ByteCountingStream(inner, 100);
Assert.True(stream.CanRead);
Assert.False(stream.CanWrite);
Assert.False(stream.CanSeek);
}
[Fact]
public void Write_ThrowsNotSupported()
{
using var inner = new MemoryStream();
using var stream = new ByteCountingStream(inner, 100);
Assert.Throws<NotSupportedException>(() => stream.Write(new byte[10], 0, 10));
}
[Fact]
public void Seek_ThrowsNotSupported()
{
using var inner = new MemoryStream();
using var stream = new ByteCountingStream(inner, 100);
Assert.Throws<NotSupportedException>(() => stream.Seek(0, SeekOrigin.Begin));
}
}
public class PayloadLimitExceededExceptionTests
{
[Fact]
public void Constructor_SetsProperties()
{
var ex = new PayloadLimitExceededException(1000, 500);
Assert.Equal(1000, ex.BytesRead);
Assert.Equal(500, ex.Limit);
Assert.Contains("1000", ex.Message);
Assert.Contains("500", ex.Message);
}
}

View File

@@ -1,429 +0,0 @@
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="RoutingDecisionMiddleware"/>.
/// </summary>
public sealed class RoutingDecisionMiddlewareTests
{
private readonly Mock<IRoutingPlugin> _routingPluginMock;
private readonly Mock<IGlobalRoutingState> _routingStateMock;
private readonly Mock<RequestDelegate> _nextMock;
private readonly GatewayNodeConfig _gatewayConfig;
private readonly RoutingOptions _routingOptions;
private bool _nextCalled;
public RoutingDecisionMiddlewareTests()
{
_routingPluginMock = new Mock<IRoutingPlugin>();
_routingStateMock = new Mock<IGlobalRoutingState>();
_nextMock = new Mock<RequestDelegate>();
_nextMock.Setup(n => n(It.IsAny<HttpContext>()))
.Callback(() => _nextCalled = true)
.Returns(Task.CompletedTask);
_gatewayConfig = new GatewayNodeConfig
{
Region = "us-east-1",
NodeId = "gw-01",
Environment = "test"
};
_routingOptions = new RoutingOptions
{
DefaultVersion = "1.0.0"
};
}
private RoutingDecisionMiddleware CreateMiddleware()
{
return new RoutingDecisionMiddleware(_nextMock.Object);
}
private HttpContext CreateHttpContext(EndpointDescriptor? endpoint = null)
{
var context = new DefaultHttpContext();
context.Request.Method = "GET";
context.Request.Path = "/api/test";
context.Response.Body = new MemoryStream();
if (endpoint is not null)
{
context.Items[RouterHttpContextKeys.EndpointDescriptor] = endpoint;
}
return context;
}
private static EndpointDescriptor CreateEndpoint(
string serviceName = "test-service",
string version = "1.0.0")
{
return new EndpointDescriptor
{
ServiceName = serviceName,
Version = version,
Method = "GET",
Path = "/api/test"
};
}
private static ConnectionState CreateConnection(
string connectionId = "conn-1",
InstanceHealthStatus status = InstanceHealthStatus.Healthy)
{
return new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = $"inst-{connectionId}",
ServiceName = "test-service",
Version = "1.0.0",
Region = "us-east-1"
},
Status = status,
TransportType = TransportType.InMemory
};
}
private static RoutingDecision CreateDecision(
EndpointDescriptor? endpoint = null,
ConnectionState? connection = null)
{
return new RoutingDecision
{
Endpoint = endpoint ?? CreateEndpoint(),
Connection = connection ?? CreateConnection(),
TransportType = TransportType.InMemory,
EffectiveTimeout = TimeSpan.FromSeconds(30)
};
}
#region Missing Endpoint Tests
[Fact]
public async Task Invoke_WithNoEndpoint_Returns500()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(endpoint: null);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status500InternalServerError);
}
[Fact]
public async Task Invoke_WithNoEndpoint_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(endpoint: null);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("descriptor missing");
}
#endregion
#region Available Instance Tests
[Fact]
public async Task Invoke_WithAvailableInstance_SetsRoutingDecision()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var connection = CreateConnection();
var decision = CreateDecision(endpoint, connection);
var context = CreateHttpContext(endpoint: endpoint);
_routingStateMock.Setup(r => r.GetConnectionsFor(
endpoint.ServiceName, endpoint.Version, endpoint.Method, endpoint.Path))
.Returns([connection]);
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(decision);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
_nextCalled.Should().BeTrue();
context.Items[RouterHttpContextKeys.RoutingDecision].Should().Be(decision);
}
[Fact]
public async Task Invoke_WithAvailableInstance_CallsNext()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var decision = CreateDecision(endpoint);
var context = CreateHttpContext(endpoint: endpoint);
_routingStateMock.Setup(r => r.GetConnectionsFor(
It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>()))
.Returns([CreateConnection()]);
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(decision);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
_nextCalled.Should().BeTrue();
}
#endregion
#region No Instances Tests
[Fact]
public async Task Invoke_WithNoInstances_Returns503()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var context = CreateHttpContext(endpoint: endpoint);
_routingStateMock.Setup(r => r.GetConnectionsFor(
It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>()))
.Returns([]);
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync((RoutingDecision?)null);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status503ServiceUnavailable);
}
[Fact]
public async Task Invoke_WithNoInstances_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var context = CreateHttpContext(endpoint: endpoint);
_routingStateMock.Setup(r => r.GetConnectionsFor(
It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>()))
.Returns([]);
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync((RoutingDecision?)null);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("No instances available");
responseBody.Should().Contain("test-service");
}
#endregion
#region Routing Context Tests
[Fact]
public async Task Invoke_PassesCorrectRoutingContext()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var decision = CreateDecision(endpoint);
var connection = CreateConnection();
var context = CreateHttpContext(endpoint: endpoint);
_routingStateMock.Setup(r => r.GetConnectionsFor(
endpoint.ServiceName, endpoint.Version, endpoint.Method, endpoint.Path))
.Returns([connection]);
RoutingContext? capturedContext = null;
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.Callback<RoutingContext, CancellationToken>((ctx, _) => capturedContext = ctx)
.ReturnsAsync(decision);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
capturedContext.Should().NotBeNull();
capturedContext!.Method.Should().Be("GET");
capturedContext.Path.Should().Be("/api/test");
capturedContext.GatewayRegion.Should().Be("us-east-1");
capturedContext.Endpoint.Should().Be(endpoint);
capturedContext.AvailableConnections.Should().ContainSingle();
}
[Fact]
public async Task Invoke_PassesRequestHeaders()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var decision = CreateDecision(endpoint);
var context = CreateHttpContext(endpoint: endpoint);
context.Request.Headers["X-Custom-Header"] = "CustomValue";
_routingStateMock.Setup(r => r.GetConnectionsFor(
It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>()))
.Returns([CreateConnection()]);
RoutingContext? capturedContext = null;
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.Callback<RoutingContext, CancellationToken>((ctx, _) => capturedContext = ctx)
.ReturnsAsync(decision);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
capturedContext!.Headers.Should().ContainKey("X-Custom-Header");
capturedContext.Headers["X-Custom-Header"].Should().Be("CustomValue");
}
#endregion
#region Version Extraction Tests
[Fact]
public async Task Invoke_WithXApiVersionHeader_ExtractsVersion()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var decision = CreateDecision(endpoint);
var context = CreateHttpContext(endpoint: endpoint);
context.Request.Headers["X-Api-Version"] = "2.0.0";
_routingStateMock.Setup(r => r.GetConnectionsFor(
It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>()))
.Returns([CreateConnection()]);
RoutingContext? capturedContext = null;
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.Callback<RoutingContext, CancellationToken>((ctx, _) => capturedContext = ctx)
.ReturnsAsync(decision);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
capturedContext!.RequestedVersion.Should().Be("2.0.0");
}
[Fact]
public async Task Invoke_WithNoVersionHeader_UsesDefault()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint();
var decision = CreateDecision(endpoint);
var context = CreateHttpContext(endpoint: endpoint);
_routingStateMock.Setup(r => r.GetConnectionsFor(
It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>()))
.Returns([CreateConnection()]);
RoutingContext? capturedContext = null;
_routingPluginMock.Setup(p => p.ChooseInstanceAsync(
It.IsAny<RoutingContext>(), It.IsAny<CancellationToken>()))
.Callback<RoutingContext, CancellationToken>((ctx, _) => capturedContext = ctx)
.ReturnsAsync(decision);
// Act
await middleware.Invoke(
context,
_routingPluginMock.Object,
_routingStateMock.Object,
Options.Create(_gatewayConfig),
Options.Create(_routingOptions));
// Assert
capturedContext!.RequestedVersion.Should().Be("1.0.0"); // From _routingOptions
}
#endregion
}

View File

@@ -1,28 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<LangVersion>preview</LangVersion>
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<TreatWarningsAsErrors>false</TreatWarningsAsErrors>
<!-- Disable Concelier test infrastructure - not needed for Gateway tests -->
<UseConcelierTestInfra>false</UseConcelierTestInfra>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.0" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="10.0.0-preview.7.25380.108" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="FluentAssertions" Version="6.12.0" />
<PackageReference Include="Moq" Version="4.20.70" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\StellaOps.Gateway.WebService\StellaOps.Gateway.WebService.csproj" />
<ProjectReference Include="..\..\..\__Libraries\StellaOps.Router.Transport.InMemory\StellaOps.Router.Transport.InMemory.csproj" />
<ProjectReference Include="..\..\..\__Libraries\StellaOps.Microservice\StellaOps.Microservice.csproj" />
</ItemGroup>
</Project>

View File

@@ -1,315 +0,0 @@
using System.Threading.Channels;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using StellaOps.Microservice.Streaming;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.InMemory;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
public class StreamingTests
{
private readonly InMemoryConnectionRegistry _registry = new();
private readonly InMemoryTransportOptions _options = new() { SimulatedLatency = TimeSpan.Zero };
private InMemoryTransportClient CreateClient()
{
return new InMemoryTransportClient(
_registry,
Options.Create(_options),
NullLogger<InMemoryTransportClient>.Instance);
}
[Fact]
public void StreamDataPayload_HasRequiredProperties()
{
var payload = new StreamDataPayload
{
CorrelationId = Guid.NewGuid(),
Data = new byte[] { 1, 2, 3 },
EndOfStream = true,
SequenceNumber = 5
};
Assert.NotEqual(Guid.Empty, payload.CorrelationId);
Assert.Equal(3, payload.Data.Length);
Assert.True(payload.EndOfStream);
Assert.Equal(5, payload.SequenceNumber);
}
[Fact]
public void StreamingOptions_HasDefaultValues()
{
var options = StreamingOptions.Default;
Assert.Equal(64 * 1024, options.ChunkSize);
Assert.Equal(100, options.MaxConcurrentStreams);
Assert.Equal(TimeSpan.FromMinutes(5), options.StreamIdleTimeout);
Assert.Equal(16, options.ChannelCapacity);
}
}
public class StreamingRequestBodyStreamTests
{
[Fact]
public async Task ReadAsync_ReturnsDataFromChannel()
{
// Arrange
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
var testData = new byte[] { 1, 2, 3, 4, 5 };
await channel.Writer.WriteAsync(new StreamChunk { Data = testData, SequenceNumber = 0 });
await channel.Writer.WriteAsync(new StreamChunk { Data = [], EndOfStream = true, SequenceNumber = 1 });
channel.Writer.Complete();
// Act
var buffer = new byte[10];
var bytesRead = await stream.ReadAsync(buffer);
// Assert
Assert.Equal(5, bytesRead);
Assert.Equal(testData, buffer[..5]);
}
[Fact]
public async Task ReadAsync_ReturnsZeroAtEndOfStream()
{
// Arrange
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
await channel.Writer.WriteAsync(new StreamChunk { Data = [], EndOfStream = true, SequenceNumber = 0 });
channel.Writer.Complete();
// Act
var buffer = new byte[10];
var bytesRead = await stream.ReadAsync(buffer);
// Assert
Assert.Equal(0, bytesRead);
}
[Fact]
public async Task ReadAsync_HandlesMultipleChunks()
{
// Arrange
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
await channel.Writer.WriteAsync(new StreamChunk { Data = [1, 2, 3], SequenceNumber = 0 });
await channel.Writer.WriteAsync(new StreamChunk { Data = [4, 5, 6], SequenceNumber = 1 });
await channel.Writer.WriteAsync(new StreamChunk { Data = [], EndOfStream = true, SequenceNumber = 2 });
channel.Writer.Complete();
// Act
using var memStream = new MemoryStream();
await stream.CopyToAsync(memStream);
// Assert
var result = memStream.ToArray();
Assert.Equal(6, result.Length);
Assert.Equal(new byte[] { 1, 2, 3, 4, 5, 6 }, result);
}
[Fact]
public void Stream_Properties_AreCorrect()
{
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
Assert.True(stream.CanRead);
Assert.False(stream.CanWrite);
Assert.False(stream.CanSeek);
}
[Fact]
public void Write_ThrowsNotSupported()
{
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
Assert.Throws<NotSupportedException>(() => stream.Write([1, 2, 3], 0, 3));
}
}
public class StreamingResponseBodyStreamTests
{
[Fact]
public async Task WriteAsync_WritesToChannel()
{
// Arrange
var channel = Channel.CreateUnbounded<StreamChunk>();
await using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
var testData = new byte[] { 1, 2, 3, 4, 5 };
// Act
await stream.WriteAsync(testData);
await stream.FlushAsync();
// Assert
Assert.True(channel.Reader.TryRead(out var chunk));
Assert.Equal(testData, chunk!.Data);
Assert.False(chunk.EndOfStream);
}
[Fact]
public async Task CompleteAsync_SendsEndOfStream()
{
// Arrange
var channel = Channel.CreateUnbounded<StreamChunk>();
await using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
// Act
await stream.WriteAsync(new byte[] { 1, 2, 3 });
await stream.CompleteAsync();
// Assert - should have data chunk + end chunk
var chunks = new List<StreamChunk>();
await foreach (var chunk in channel.Reader.ReadAllAsync())
{
chunks.Add(chunk);
}
Assert.Equal(2, chunks.Count);
Assert.False(chunks[0].EndOfStream);
Assert.True(chunks[1].EndOfStream);
}
[Fact]
public async Task WriteAsync_ChunksLargeData()
{
// Arrange
var chunkSize = 10;
var channel = Channel.CreateUnbounded<StreamChunk>();
await using var stream = new StreamingResponseBodyStream(channel.Writer, chunkSize, CancellationToken.None);
var testData = new byte[25]; // Will need 3 chunks
for (var i = 0; i < testData.Length; i++)
{
testData[i] = (byte)i;
}
// Act
await stream.WriteAsync(testData);
await stream.CompleteAsync();
// Assert
var chunks = new List<StreamChunk>();
await foreach (var chunk in channel.Reader.ReadAllAsync())
{
chunks.Add(chunk);
}
// Should have 3 chunks (10+10+5) + 1 end-of-stream (with 0 data since remainder already flushed)
Assert.Equal(4, chunks.Count);
Assert.Equal(10, chunks[0].Data.Length);
Assert.Equal(10, chunks[1].Data.Length);
Assert.Equal(5, chunks[2].Data.Length);
Assert.True(chunks[3].EndOfStream);
}
[Fact]
public void Stream_Properties_AreCorrect()
{
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
Assert.False(stream.CanRead);
Assert.True(stream.CanWrite);
Assert.False(stream.CanSeek);
}
[Fact]
public void Read_ThrowsNotSupported()
{
var channel = Channel.CreateUnbounded<StreamChunk>();
using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
Assert.Throws<NotSupportedException>(() => stream.Read(new byte[10], 0, 10));
}
}
public class InMemoryTransportStreamingTests
{
private readonly InMemoryConnectionRegistry _registry = new();
private readonly InMemoryTransportOptions _options = new() { SimulatedLatency = TimeSpan.Zero };
private InMemoryTransportClient CreateClient()
{
return new InMemoryTransportClient(
_registry,
Options.Create(_options),
NullLogger<InMemoryTransportClient>.Instance);
}
[Fact]
public async Task SendStreamingAsync_SendsRequestStreamDataFrames()
{
// Arrange
using var client = CreateClient();
var instance = new InstanceDescriptor
{
InstanceId = "test-instance",
ServiceName = "test-service",
Version = "1.0.0",
Region = "us-east-1"
};
await client.ConnectAsync(instance, [], CancellationToken.None);
// Get connection ID via reflection
var connectionIdField = client.GetType()
.GetField("_connectionId", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var connectionId = connectionIdField?.GetValue(client)?.ToString();
Assert.NotNull(connectionId);
var channel = _registry.GetChannel(connectionId!);
Assert.NotNull(channel);
Assert.NotNull(channel!.State);
// Create request body stream
var requestBody = new MemoryStream(new byte[] { 1, 2, 3, 4, 5 });
// Create request frame
var requestFrame = new Frame
{
Type = FrameType.Request,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
var limits = PayloadLimits.Default;
// Act - Start streaming (this will send frames to microservice)
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var sendTask = client.SendStreamingAsync(
channel.State!,
requestFrame,
requestBody,
_ => Task.CompletedTask,
limits,
cts.Token);
// Read the frames that were sent to microservice
var frames = new List<Frame>();
await foreach (var frame in channel.ToMicroservice.Reader.ReadAllAsync(cts.Token))
{
frames.Add(frame);
if (frame.Type == FrameType.RequestStreamData && frame.Payload.Length == 0)
{
// End of stream - break
break;
}
}
// Assert - should have REQUEST header + data chunks + end-of-stream
Assert.True(frames.Count >= 2);
Assert.Equal(FrameType.Request, frames[0].Type);
Assert.Equal(FrameType.RequestStreamData, frames[^1].Type);
Assert.Equal(0, frames[^1].Payload.Length); // End of stream marker
}
}

View File

@@ -1,786 +0,0 @@
using System.Text;
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Moq;
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Frames;
using StellaOps.Router.Common.Models;
using Xunit;
namespace StellaOps.Gateway.WebService.Tests;
/// <summary>
/// Unit tests for <see cref="TransportDispatchMiddleware"/>.
/// </summary>
public sealed class TransportDispatchMiddlewareTests
{
private readonly Mock<ITransportClient> _transportClientMock;
private readonly Mock<IGlobalRoutingState> _routingStateMock;
private readonly Mock<RequestDelegate> _nextMock;
private bool _nextCalled;
public TransportDispatchMiddlewareTests()
{
_transportClientMock = new Mock<ITransportClient>();
_routingStateMock = new Mock<IGlobalRoutingState>();
_nextMock = new Mock<RequestDelegate>();
_nextMock.Setup(n => n(It.IsAny<HttpContext>()))
.Callback(() => _nextCalled = true)
.Returns(Task.CompletedTask);
}
private TransportDispatchMiddleware CreateMiddleware()
{
return new TransportDispatchMiddleware(
_nextMock.Object,
NullLogger<TransportDispatchMiddleware>.Instance);
}
private static HttpContext CreateHttpContext(
RoutingDecision? decision = null,
string method = "GET",
string path = "/api/test",
byte[]? body = null)
{
var context = new DefaultHttpContext();
context.Request.Method = method;
context.Request.Path = path;
context.Response.Body = new MemoryStream();
if (body is not null)
{
context.Request.Body = new MemoryStream(body);
context.Request.ContentLength = body.Length;
}
else
{
context.Request.Body = new MemoryStream();
}
if (decision is not null)
{
context.Items[RouterHttpContextKeys.RoutingDecision] = decision;
}
return context;
}
private static EndpointDescriptor CreateEndpoint(
string serviceName = "test-service",
string version = "1.0.0",
bool supportsStreaming = false)
{
return new EndpointDescriptor
{
ServiceName = serviceName,
Version = version,
Method = "GET",
Path = "/api/test",
SupportsStreaming = supportsStreaming
};
}
private static ConnectionState CreateConnection(
string connectionId = "conn-1",
InstanceHealthStatus status = InstanceHealthStatus.Healthy)
{
return new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = $"inst-{connectionId}",
ServiceName = "test-service",
Version = "1.0.0",
Region = "us-east-1"
},
Status = status,
TransportType = TransportType.InMemory
};
}
private static RoutingDecision CreateDecision(
EndpointDescriptor? endpoint = null,
ConnectionState? connection = null,
TimeSpan? timeout = null)
{
return new RoutingDecision
{
Endpoint = endpoint ?? CreateEndpoint(),
Connection = connection ?? CreateConnection(),
TransportType = TransportType.InMemory,
EffectiveTimeout = timeout ?? TimeSpan.FromSeconds(30)
};
}
private static Frame CreateResponseFrame(
string requestId = "test-request",
int statusCode = 200,
Dictionary<string, string>? headers = null,
byte[]? payload = null)
{
var response = new ResponseFrame
{
RequestId = requestId,
StatusCode = statusCode,
Headers = headers ?? new Dictionary<string, string>(),
Payload = payload ?? []
};
return FrameConverter.ToFrame(response);
}
#region Missing Routing Decision Tests
[Fact]
public async Task Invoke_WithNoRoutingDecision_Returns500()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(decision: null);
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
_nextCalled.Should().BeFalse();
context.Response.StatusCode.Should().Be(StatusCodes.Status500InternalServerError);
}
[Fact]
public async Task Invoke_WithNoRoutingDecision_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var context = CreateHttpContext(decision: null);
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Routing decision missing");
}
#endregion
#region Successful Request/Response Tests
[Fact]
public async Task Invoke_WithSuccessfulResponse_ForwardsStatusCode()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId, statusCode: 201);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.StatusCode.Should().Be(201);
}
[Fact]
public async Task Invoke_WithResponsePayload_WritesToResponseBody()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
var responsePayload = Encoding.UTF8.GetBytes("{\"result\":\"success\"}");
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId, payload: responsePayload);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Be("{\"result\":\"success\"}");
}
[Fact]
public async Task Invoke_WithResponseHeaders_ForwardsHeaders()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
var responseHeaders = new Dictionary<string, string>
{
["X-Custom-Header"] = "CustomValue",
["Content-Type"] = "application/json"
};
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId, headers: responseHeaders);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Headers.Should().ContainKey("X-Custom-Header");
context.Response.Headers["X-Custom-Header"].ToString().Should().Be("CustomValue");
context.Response.Headers["Content-Type"].ToString().Should().Be("application/json");
}
[Fact]
public async Task Invoke_WithTransferEncodingHeader_DoesNotForward()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
var responseHeaders = new Dictionary<string, string>
{
["Transfer-Encoding"] = "chunked",
["X-Custom-Header"] = "CustomValue"
};
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId, headers: responseHeaders);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Headers.Should().NotContainKey("Transfer-Encoding");
context.Response.Headers.Should().ContainKey("X-Custom-Header");
}
[Fact]
public async Task Invoke_WithRequestBody_SendsBodyInFrame()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var requestBody = Encoding.UTF8.GetBytes("{\"data\":\"test\"}");
var context = CreateHttpContext(decision: decision, body: requestBody);
byte[]? capturedPayload = null;
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.Callback<ConnectionState, Frame, TimeSpan, CancellationToken>((conn, req, timeout, ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
capturedPayload = requestFrame?.Payload.ToArray();
})
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
capturedPayload.Should().BeEquivalentTo(requestBody);
}
[Fact]
public async Task Invoke_WithRequestHeaders_ForwardsHeadersInFrame()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
context.Request.Headers["X-Request-Id"] = "req-123";
context.Request.Headers["Accept"] = "application/json";
IReadOnlyDictionary<string, string>? capturedHeaders = null;
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.Callback<ConnectionState, Frame, TimeSpan, CancellationToken>((conn, req, timeout, ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
capturedHeaders = requestFrame?.Headers;
})
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
capturedHeaders.Should().NotBeNull();
capturedHeaders.Should().ContainKey("X-Request-Id");
capturedHeaders!["X-Request-Id"].Should().Be("req-123");
}
#endregion
#region Timeout Tests
[Fact]
public async Task Invoke_WithTimeout_Returns504()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision(timeout: TimeSpan.FromMilliseconds(50));
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new OperationCanceledException());
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.StatusCode.Should().Be(StatusCodes.Status504GatewayTimeout);
}
[Fact]
public async Task Invoke_WithTimeout_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision(timeout: TimeSpan.FromMilliseconds(50));
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new OperationCanceledException());
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Upstream timeout");
responseBody.Should().Contain("test-service");
}
[Fact]
public async Task Invoke_WithTimeout_SendsCancelFrame()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision(timeout: TimeSpan.FromMilliseconds(50));
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new OperationCanceledException());
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
_transportClientMock.Verify(t => t.SendCancelAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Guid>(),
CancelReasons.Timeout), Times.Once);
}
#endregion
#region Upstream Error Tests
[Fact]
public async Task Invoke_WithUpstreamError_Returns502()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Connection failed"));
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.StatusCode.Should().Be(StatusCodes.Status502BadGateway);
}
[Fact]
public async Task Invoke_WithUpstreamError_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Connection failed"));
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Upstream error");
responseBody.Should().Contain("Connection failed");
}
#endregion
#region Invalid Response Tests
[Fact]
public async Task Invoke_WithInvalidResponseFrame_Returns502()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
// Return a malformed frame that cannot be parsed as ResponseFrame
var invalidFrame = new Frame
{
Type = FrameType.Heartbeat, // Wrong type
CorrelationId = "test",
Payload = Array.Empty<byte>()
};
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(invalidFrame);
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.StatusCode.Should().Be(StatusCodes.Status502BadGateway);
}
[Fact]
public async Task Invoke_WithInvalidResponseFrame_WritesErrorResponse()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
var invalidFrame = new Frame
{
Type = FrameType.Cancel, // Wrong type
CorrelationId = "test",
Payload = Array.Empty<byte>()
};
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(invalidFrame);
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.Body.Seek(0, SeekOrigin.Begin);
using var reader = new StreamReader(context.Response.Body);
var responseBody = await reader.ReadToEndAsync();
responseBody.Should().Contain("Invalid upstream response");
}
#endregion
#region Connection Ping Update Tests
[Fact]
public async Task Invoke_WithSuccessfulResponse_UpdatesConnectionPing()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
_routingStateMock.Verify(r => r.UpdateConnection(
"conn-1",
It.IsAny<Action<ConnectionState>>()), Times.Once);
}
#endregion
#region Streaming Tests
[Fact]
public async Task Invoke_WithStreamingEndpoint_UsesSendStreamingAsync()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(supportsStreaming: true);
var decision = CreateDecision(endpoint: endpoint);
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendStreamingAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<Stream>(),
It.IsAny<Func<Stream, Task>>(),
It.IsAny<PayloadLimits>(),
It.IsAny<CancellationToken>()))
.Callback<ConnectionState, Frame, Stream, Func<Stream, Task>, PayloadLimits, CancellationToken>(
async (conn, req, requestBody, readResponse, limits, ct) =>
{
// Simulate streaming response
using var responseStream = new MemoryStream(Encoding.UTF8.GetBytes("streamed data"));
await readResponse(responseStream);
})
.Returns(Task.CompletedTask);
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
_transportClientMock.Verify(t => t.SendStreamingAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<Stream>(),
It.IsAny<Func<Stream, Task>>(),
It.IsAny<PayloadLimits>(),
It.IsAny<CancellationToken>()), Times.Once);
}
[Fact]
public async Task Invoke_StreamingWithTimeout_Returns504()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(supportsStreaming: true);
var decision = CreateDecision(endpoint: endpoint, timeout: TimeSpan.FromMilliseconds(50));
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendStreamingAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<Stream>(),
It.IsAny<Func<Stream, Task>>(),
It.IsAny<PayloadLimits>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new OperationCanceledException());
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.StatusCode.Should().Be(StatusCodes.Status504GatewayTimeout);
}
[Fact]
public async Task Invoke_StreamingWithUpstreamError_Returns502()
{
// Arrange
var middleware = CreateMiddleware();
var endpoint = CreateEndpoint(supportsStreaming: true);
var decision = CreateDecision(endpoint: endpoint);
var context = CreateHttpContext(decision: decision);
_transportClientMock.Setup(t => t.SendStreamingAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<Stream>(),
It.IsAny<Func<Stream, Task>>(),
It.IsAny<PayloadLimits>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Streaming failed"));
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
context.Response.StatusCode.Should().Be(StatusCodes.Status502BadGateway);
}
#endregion
#region Query String Tests
[Fact]
public async Task Invoke_WithQueryString_IncludesInRequestPath()
{
// Arrange
var middleware = CreateMiddleware();
var decision = CreateDecision();
var context = CreateHttpContext(decision: decision, path: "/api/test");
context.Request.QueryString = new QueryString("?key=value&other=123");
string? capturedPath = null;
_transportClientMock.Setup(t => t.SendRequestAsync(
It.IsAny<ConnectionState>(),
It.IsAny<Frame>(),
It.IsAny<TimeSpan>(),
It.IsAny<CancellationToken>()))
.Callback<ConnectionState, Frame, TimeSpan, CancellationToken>((conn, req, timeout, ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
capturedPath = requestFrame?.Path;
})
.ReturnsAsync((ConnectionState conn, Frame req, TimeSpan timeout, CancellationToken ct) =>
{
var requestFrame = FrameConverter.ToRequestFrame(req);
return CreateResponseFrame(requestId: requestFrame!.RequestId);
});
// Act
await middleware.Invoke(
context,
_transportClientMock.Object,
_routingStateMock.Object);
// Assert
capturedPath.Should().Be("/api/test?key=value&other=123");
}
#endregion
}