Add unit tests for Router configuration and transport layers
- Implemented tests for RouterConfig, RoutingOptions, StaticInstanceConfig, and RouterConfigOptions to ensure default values are set correctly. - Added tests for RouterConfigProvider to validate configurations and ensure defaults are returned when no file is specified. - Created tests for ConfigValidationResult to check success and error scenarios. - Developed tests for ServiceCollectionExtensions to verify service registration for RouterConfig. - Introduced UdpTransportTests to validate serialization, connection, request-response, and error handling in UDP transport. - Added scripts for signing authority gaps and hashing DevPortal SDK snippets.
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
using StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for configuring the gateway middleware pipeline.
|
||||
/// </summary>
|
||||
public static class ApplicationBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds the gateway router middleware pipeline.
|
||||
/// </summary>
|
||||
/// <param name="app">The application builder.</param>
|
||||
/// <returns>The application builder for chaining.</returns>
|
||||
public static IApplicationBuilder UseGatewayRouter(this IApplicationBuilder app)
|
||||
{
|
||||
// Resolve endpoints from routing state
|
||||
app.UseMiddleware<EndpointResolutionMiddleware>();
|
||||
|
||||
// Make routing decisions (select instance)
|
||||
app.UseMiddleware<RoutingDecisionMiddleware>();
|
||||
|
||||
// Dispatch to transport and return response
|
||||
app.UseMiddleware<TransportDispatchMiddleware>();
|
||||
|
||||
return app;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Background service that periodically refreshes claims from Authority.
|
||||
/// </summary>
|
||||
internal sealed class AuthorityClaimsRefreshService : BackgroundService
|
||||
{
|
||||
private readonly IAuthorityClaimsProvider _claimsProvider;
|
||||
private readonly IEffectiveClaimsStore _claimsStore;
|
||||
private readonly AuthorityConnectionOptions _options;
|
||||
private readonly ILogger<AuthorityClaimsRefreshService> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="AuthorityClaimsRefreshService"/> class.
|
||||
/// </summary>
|
||||
public AuthorityClaimsRefreshService(
|
||||
IAuthorityClaimsProvider claimsProvider,
|
||||
IEffectiveClaimsStore claimsStore,
|
||||
IOptions<AuthorityConnectionOptions> options,
|
||||
ILogger<AuthorityClaimsRefreshService> logger)
|
||||
{
|
||||
_claimsProvider = claimsProvider;
|
||||
_claimsStore = claimsStore;
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
|
||||
{
|
||||
if (!_options.Enabled)
|
||||
{
|
||||
_logger.LogInformation("Authority integration is disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(_options.AuthorityUrl))
|
||||
{
|
||||
_logger.LogWarning("Authority URL not configured, skipping claims refresh");
|
||||
return;
|
||||
}
|
||||
|
||||
// Subscribe to push notifications if enabled
|
||||
if (_options.UseAuthorityPushNotifications)
|
||||
{
|
||||
_claimsProvider.OverridesChanged += OnOverridesChanged;
|
||||
}
|
||||
|
||||
// Initial fetch with optional wait
|
||||
await FetchWithRetryAsync(stoppingToken);
|
||||
|
||||
// Periodic refresh
|
||||
while (!stoppingToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
await Task.Delay(_options.RefreshInterval, stoppingToken);
|
||||
await RefreshClaimsAsync(stoppingToken);
|
||||
}
|
||||
catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error during claims refresh");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task FetchWithRetryAsync(CancellationToken stoppingToken)
|
||||
{
|
||||
if (!_options.WaitForAuthorityOnStartup)
|
||||
{
|
||||
await RefreshClaimsAsync(stoppingToken);
|
||||
return;
|
||||
}
|
||||
|
||||
var deadline = DateTime.UtcNow.Add(_options.StartupTimeout);
|
||||
var retryDelay = TimeSpan.FromSeconds(1);
|
||||
var attempt = 0;
|
||||
|
||||
while (DateTime.UtcNow < deadline && !stoppingToken.IsCancellationRequested)
|
||||
{
|
||||
attempt++;
|
||||
_logger.LogDebug("Fetching claims from Authority (attempt {Attempt})", attempt);
|
||||
|
||||
await RefreshClaimsAsync(stoppingToken);
|
||||
|
||||
if (_claimsProvider.IsAvailable)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Successfully connected to Authority after {Attempts} attempts",
|
||||
attempt);
|
||||
return;
|
||||
}
|
||||
|
||||
await Task.Delay(retryDelay, stoppingToken);
|
||||
retryDelay = TimeSpan.FromSeconds(Math.Min(retryDelay.TotalSeconds * 2, 10));
|
||||
}
|
||||
|
||||
_logger.LogWarning(
|
||||
"Could not connect to Authority within {Timeout}. Proceeding without Authority claims.",
|
||||
_options.StartupTimeout);
|
||||
}
|
||||
|
||||
private async Task RefreshClaimsAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
var overrides = await _claimsProvider.GetOverridesAsync(cancellationToken);
|
||||
_claimsStore.UpdateFromAuthority(overrides);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Failed to refresh claims from Authority");
|
||||
}
|
||||
}
|
||||
|
||||
private void OnOverridesChanged(object? sender, ClaimsOverrideChangedEventArgs e)
|
||||
{
|
||||
_logger.LogInformation("Received claims override update from Authority");
|
||||
_claimsStore.UpdateFromAuthority(e.Overrides);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Dispose()
|
||||
{
|
||||
if (_options.UseAuthorityPushNotifications)
|
||||
{
|
||||
_claimsProvider.OverridesChanged -= OnOverridesChanged;
|
||||
}
|
||||
|
||||
base.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Configuration options for connecting to the Authority service.
|
||||
/// </summary>
|
||||
public sealed class AuthorityConnectionOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Configuration section name.
|
||||
/// </summary>
|
||||
public const string SectionName = "Authority";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the Authority service URL.
|
||||
/// </summary>
|
||||
public string AuthorityUrl { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to wait for Authority on startup.
|
||||
/// If true, the gateway will delay handling traffic until Authority is available.
|
||||
/// </summary>
|
||||
public bool WaitForAuthorityOnStartup { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the startup timeout when waiting for Authority.
|
||||
/// </summary>
|
||||
public TimeSpan StartupTimeout { get; set; } = TimeSpan.FromSeconds(30);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the interval at which to refresh claims from Authority.
|
||||
/// </summary>
|
||||
public TimeSpan RefreshInterval { get; set; } = TimeSpan.FromMinutes(5);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to use push notifications from Authority.
|
||||
/// If false, the gateway will poll at the refresh interval.
|
||||
/// </summary>
|
||||
public bool UseAuthorityPushNotifications { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether Authority integration is enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; } = true;
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Middleware that enforces claims requirements for endpoints.
|
||||
/// </summary>
|
||||
public sealed class AuthorizationMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly IEffectiveClaimsStore _claimsStore;
|
||||
private readonly ILogger<AuthorizationMiddleware> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="AuthorizationMiddleware"/> class.
|
||||
/// </summary>
|
||||
public AuthorizationMiddleware(
|
||||
RequestDelegate next,
|
||||
IEffectiveClaimsStore claimsStore,
|
||||
ILogger<AuthorizationMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_claimsStore = claimsStore;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task InvokeAsync(HttpContext context)
|
||||
{
|
||||
// Get resolved endpoint from earlier middleware
|
||||
if (!context.Items.TryGetValue(RouterHttpContextKeys.EndpointDescriptor, out var endpointObj) ||
|
||||
endpointObj is not EndpointDescriptor endpoint)
|
||||
{
|
||||
// No endpoint resolved, let next middleware handle
|
||||
await _next(context);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get effective claims for this endpoint
|
||||
var effectiveClaims = _claimsStore.GetEffectiveClaims(
|
||||
endpoint.ServiceName,
|
||||
endpoint.Method,
|
||||
endpoint.Path);
|
||||
|
||||
if (effectiveClaims.Count == 0)
|
||||
{
|
||||
// No claims required
|
||||
await _next(context);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check each required claim
|
||||
foreach (var required in effectiveClaims)
|
||||
{
|
||||
var userClaims = context.User.Claims;
|
||||
|
||||
bool hasClaim = required.Value == null
|
||||
? userClaims.Any(c => c.Type == required.Type)
|
||||
: userClaims.Any(c => c.Type == required.Type && c.Value == required.Value);
|
||||
|
||||
if (!hasClaim)
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Authorization failed for {Method} {Path}: user lacks claim {ClaimType}={ClaimValue}",
|
||||
endpoint.Method,
|
||||
endpoint.Path,
|
||||
required.Type,
|
||||
required.Value ?? "(any)");
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status403Forbidden;
|
||||
context.Response.ContentType = "application/json";
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Forbidden",
|
||||
message = "Authorization failed: missing required claim",
|
||||
requiredClaim = new { type = required.Type, value = required.Value }
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
await _next(context);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering the authorization middleware.
|
||||
/// </summary>
|
||||
public static class AuthorizationMiddlewareExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds the claims authorization middleware to the pipeline.
|
||||
/// </summary>
|
||||
/// <param name="app">The application builder.</param>
|
||||
/// <returns>The application builder for chaining.</returns>
|
||||
public static IApplicationBuilder UseClaimsAuthorization(this IApplicationBuilder app)
|
||||
{
|
||||
return app.UseMiddleware<AuthorizationMiddleware>();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// In-memory store for effective claims.
|
||||
/// Merges microservice defaults with Authority overrides.
|
||||
/// </summary>
|
||||
internal sealed class EffectiveClaimsStore : IEffectiveClaimsStore
|
||||
{
|
||||
private readonly ConcurrentDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> _microserviceClaims = new();
|
||||
private readonly ConcurrentDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> _authorityClaims = new();
|
||||
private readonly ILogger<EffectiveClaimsStore> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EffectiveClaimsStore"/> class.
|
||||
/// </summary>
|
||||
public EffectiveClaimsStore(ILogger<EffectiveClaimsStore> logger)
|
||||
{
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<ClaimRequirement> GetEffectiveClaims(string serviceName, string method, string path)
|
||||
{
|
||||
var key = EndpointKey.Create(serviceName, method, path);
|
||||
|
||||
// Authority takes precedence
|
||||
if (_authorityClaims.TryGetValue(key, out var authorityClaims))
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Using Authority claims for {Endpoint}: {ClaimCount} claims",
|
||||
key,
|
||||
authorityClaims.Count);
|
||||
return authorityClaims;
|
||||
}
|
||||
|
||||
// Fall back to microservice defaults
|
||||
if (_microserviceClaims.TryGetValue(key, out var msClaims))
|
||||
{
|
||||
return msClaims;
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void UpdateFromMicroservice(string serviceName, IReadOnlyList<EndpointDescriptor> endpoints)
|
||||
{
|
||||
foreach (var endpoint in endpoints)
|
||||
{
|
||||
var key = EndpointKey.Create(serviceName, endpoint.Method, endpoint.Path);
|
||||
var claims = endpoint.RequiringClaims ?? [];
|
||||
|
||||
if (claims.Count > 0)
|
||||
{
|
||||
_microserviceClaims[key] = claims;
|
||||
_logger.LogDebug(
|
||||
"Registered {ClaimCount} claims from microservice for {Endpoint}",
|
||||
claims.Count,
|
||||
key);
|
||||
}
|
||||
else
|
||||
{
|
||||
_microserviceClaims.TryRemove(key, out _);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void UpdateFromAuthority(IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> overrides)
|
||||
{
|
||||
// Clear previous Authority claims
|
||||
_authorityClaims.Clear();
|
||||
|
||||
// Add new Authority claims
|
||||
foreach (var (key, claims) in overrides)
|
||||
{
|
||||
if (claims.Count > 0)
|
||||
{
|
||||
_authorityClaims[key] = claims;
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation(
|
||||
"Updated Authority claims: {EndpointCount} endpoints with overrides",
|
||||
overrides.Count);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void RemoveService(string serviceName)
|
||||
{
|
||||
var normalizedServiceName = serviceName.ToLowerInvariant();
|
||||
var keysToRemove = _microserviceClaims.Keys
|
||||
.Where(k => k.ServiceName == normalizedServiceName)
|
||||
.ToList();
|
||||
|
||||
foreach (var key in keysToRemove)
|
||||
{
|
||||
_microserviceClaims.TryRemove(key, out _);
|
||||
}
|
||||
|
||||
_logger.LogDebug(
|
||||
"Removed {Count} endpoint claims for service {ServiceName}",
|
||||
keysToRemove.Count,
|
||||
serviceName);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Key for identifying an endpoint by service name, method, and path.
|
||||
/// </summary>
|
||||
/// <param name="ServiceName">The name of the service.</param>
|
||||
/// <param name="Method">The HTTP method.</param>
|
||||
/// <param name="Path">The path template.</param>
|
||||
public readonly record struct EndpointKey(string ServiceName, string Method, string Path)
|
||||
{
|
||||
/// <summary>
|
||||
/// Creates an endpoint key with normalized values.
|
||||
/// </summary>
|
||||
public static EndpointKey Create(string serviceName, string method, string path)
|
||||
{
|
||||
return new EndpointKey(
|
||||
serviceName.ToLowerInvariant(),
|
||||
method.ToUpperInvariant(),
|
||||
path.ToLowerInvariant());
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string ToString() => $"{ServiceName}:{Method} {Path}";
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Fetches claims overrides from the Authority service via HTTP.
|
||||
/// </summary>
|
||||
internal sealed class HttpAuthorityClaimsProvider : IAuthorityClaimsProvider
|
||||
{
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly AuthorityConnectionOptions _options;
|
||||
private readonly ILogger<HttpAuthorityClaimsProvider> _logger;
|
||||
private volatile bool _isAvailable;
|
||||
|
||||
private static readonly JsonSerializerOptions JsonOptions = new()
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="HttpAuthorityClaimsProvider"/> class.
|
||||
/// </summary>
|
||||
public HttpAuthorityClaimsProvider(
|
||||
HttpClient httpClient,
|
||||
IOptions<AuthorityConnectionOptions> options,
|
||||
ILogger<HttpAuthorityClaimsProvider> logger)
|
||||
{
|
||||
_httpClient = httpClient;
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public bool IsAvailable => _isAvailable;
|
||||
|
||||
/// <inheritdoc />
|
||||
public event EventHandler<ClaimsOverrideChangedEventArgs>? OverridesChanged;
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>> GetOverridesAsync(
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(_options.AuthorityUrl))
|
||||
{
|
||||
_logger.LogDebug("Authority URL not configured, returning empty overrides");
|
||||
_isAvailable = false;
|
||||
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var url = $"{_options.AuthorityUrl.TrimEnd('/')}/api/v1/claims/overrides";
|
||||
|
||||
_logger.LogDebug("Fetching claims overrides from {Url}", url);
|
||||
|
||||
var response = await _httpClient.GetAsync(url, cancellationToken);
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var overrideResponse = await response.Content.ReadFromJsonAsync<ClaimsOverrideResponse>(
|
||||
JsonOptions,
|
||||
cancellationToken);
|
||||
|
||||
if (overrideResponse?.Overrides == null)
|
||||
{
|
||||
_isAvailable = true;
|
||||
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
|
||||
}
|
||||
|
||||
var result = new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
|
||||
foreach (var entry in overrideResponse.Overrides)
|
||||
{
|
||||
var key = EndpointKey.Create(entry.ServiceName, entry.Method, entry.Path);
|
||||
var claims = entry.RequiringClaims
|
||||
.Select(c => new ClaimRequirement { Type = c.Type, Value = c.Value })
|
||||
.ToList();
|
||||
result[key] = claims;
|
||||
}
|
||||
|
||||
_isAvailable = true;
|
||||
_logger.LogInformation(
|
||||
"Fetched {Count} claims overrides from Authority",
|
||||
result.Count);
|
||||
|
||||
return result;
|
||||
}
|
||||
catch (Exception ex) when (ex is HttpRequestException or TaskCanceledException)
|
||||
{
|
||||
_isAvailable = false;
|
||||
_logger.LogWarning(ex, "Failed to fetch claims overrides from Authority");
|
||||
return new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Raises the <see cref="OverridesChanged"/> event.
|
||||
/// </summary>
|
||||
internal void RaiseOverridesChanged(IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> overrides)
|
||||
{
|
||||
OverridesChanged?.Invoke(this, new ClaimsOverrideChangedEventArgs { Overrides = overrides });
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// DTO for claims override response from Authority.
|
||||
/// </summary>
|
||||
private sealed class ClaimsOverrideResponse
|
||||
{
|
||||
public List<ClaimsOverrideEntry> Overrides { get; set; } = [];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// DTO for a single claims override entry.
|
||||
/// </summary>
|
||||
private sealed class ClaimsOverrideEntry
|
||||
{
|
||||
public string ServiceName { get; set; } = string.Empty;
|
||||
public string Method { get; set; } = string.Empty;
|
||||
public string Path { get; set; } = string.Empty;
|
||||
public List<ClaimRequirementDto> RequiringClaims { get; set; } = [];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// DTO for a claim requirement.
|
||||
/// </summary>
|
||||
private sealed class ClaimRequirementDto
|
||||
{
|
||||
public string Type { get; set; } = string.Empty;
|
||||
public string? Value { get; set; }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Provides claims overrides from the central Authority service.
|
||||
/// </summary>
|
||||
public interface IAuthorityClaimsProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets all claims overrides from Authority.
|
||||
/// </summary>
|
||||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
/// <returns>A dictionary of endpoint keys to claim requirements.</returns>
|
||||
Task<IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>> GetOverridesAsync(
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the Authority is currently available.
|
||||
/// </summary>
|
||||
bool IsAvailable { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Occurs when claims overrides change.
|
||||
/// </summary>
|
||||
event EventHandler<ClaimsOverrideChangedEventArgs>? OverridesChanged;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Event arguments for claims override changes.
|
||||
/// </summary>
|
||||
public sealed class ClaimsOverrideChangedEventArgs : EventArgs
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the updated claims overrides.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> Overrides { get; init; }
|
||||
= new Dictionary<EndpointKey, IReadOnlyList<ClaimRequirement>>();
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Stores and retrieves effective claims for endpoints.
|
||||
/// Handles merging of microservice defaults with Authority overrides.
|
||||
/// </summary>
|
||||
public interface IEffectiveClaimsStore
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the effective claims for an endpoint.
|
||||
/// Authority overrides take precedence over microservice defaults.
|
||||
/// </summary>
|
||||
/// <param name="serviceName">The service name.</param>
|
||||
/// <param name="method">The HTTP method.</param>
|
||||
/// <param name="path">The path template.</param>
|
||||
/// <returns>The effective claims for the endpoint.</returns>
|
||||
IReadOnlyList<ClaimRequirement> GetEffectiveClaims(string serviceName, string method, string path);
|
||||
|
||||
/// <summary>
|
||||
/// Updates claims from a microservice's HELLO message.
|
||||
/// </summary>
|
||||
/// <param name="serviceName">The service name.</param>
|
||||
/// <param name="endpoints">The endpoint descriptors with claims.</param>
|
||||
void UpdateFromMicroservice(string serviceName, IReadOnlyList<EndpointDescriptor> endpoints);
|
||||
|
||||
/// <summary>
|
||||
/// Updates claims from Authority overrides.
|
||||
/// </summary>
|
||||
/// <param name="overrides">The Authority claims overrides.</param>
|
||||
void UpdateFromAuthority(IReadOnlyDictionary<EndpointKey, IReadOnlyList<ClaimRequirement>> overrides);
|
||||
|
||||
/// <summary>
|
||||
/// Removes all claims for a service.
|
||||
/// Called when a microservice disconnects.
|
||||
/// </summary>
|
||||
/// <param name="serviceName">The service name.</param>
|
||||
void RemoveService(string serviceName);
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
namespace StellaOps.Gateway.WebService.Authorization;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering Authority integration services.
|
||||
/// </summary>
|
||||
public static class AuthorizationServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds Authority integration services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configuration">The configuration.</param>
|
||||
/// <returns>The service collection for chaining.</returns>
|
||||
public static IServiceCollection AddAuthorityIntegration(
|
||||
this IServiceCollection services,
|
||||
IConfiguration configuration)
|
||||
{
|
||||
// Bind options
|
||||
services.Configure<AuthorityConnectionOptions>(
|
||||
configuration.GetSection(AuthorityConnectionOptions.SectionName));
|
||||
|
||||
// Register effective claims store
|
||||
services.AddSingleton<IEffectiveClaimsStore, EffectiveClaimsStore>();
|
||||
|
||||
// Register HTTP client for Authority
|
||||
services.AddHttpClient<IAuthorityClaimsProvider, HttpAuthorityClaimsProvider>(client =>
|
||||
{
|
||||
client.Timeout = TimeSpan.FromSeconds(30);
|
||||
});
|
||||
|
||||
// Register background service for claims refresh
|
||||
services.AddHostedService<AuthorityClaimsRefreshService>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds Authority integration services with custom options.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Action to configure Authority options.</param>
|
||||
/// <returns>The service collection for chaining.</returns>
|
||||
public static IServiceCollection AddAuthorityIntegration(
|
||||
this IServiceCollection services,
|
||||
Action<AuthorityConnectionOptions>? configure = null)
|
||||
{
|
||||
// Register options
|
||||
if (configure != null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
else
|
||||
{
|
||||
services.AddOptions<AuthorityConnectionOptions>();
|
||||
}
|
||||
|
||||
// Register effective claims store
|
||||
services.AddSingleton<IEffectiveClaimsStore, EffectiveClaimsStore>();
|
||||
|
||||
// Register HTTP client for Authority
|
||||
services.AddHttpClient<IAuthorityClaimsProvider, HttpAuthorityClaimsProvider>(client =>
|
||||
{
|
||||
client.Timeout = TimeSpan.FromSeconds(30);
|
||||
});
|
||||
|
||||
// Register background service for claims refresh
|
||||
services.AddHostedService<AuthorityClaimsRefreshService>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds a no-op Authority integration (no external Authority).
|
||||
/// Claims are only from microservices.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <returns>The service collection for chaining.</returns>
|
||||
public static IServiceCollection AddNoOpAuthorityIntegration(this IServiceCollection services)
|
||||
{
|
||||
services.Configure<AuthorityConnectionOptions>(options => options.Enabled = false);
|
||||
services.AddSingleton<IEffectiveClaimsStore, EffectiveClaimsStore>();
|
||||
services.AddSingleton<IAuthorityClaimsProvider, NoOpAuthorityClaimsProvider>();
|
||||
return services;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A no-op Authority claims provider that returns empty overrides.
|
||||
/// </summary>
|
||||
internal sealed class NoOpAuthorityClaimsProvider : IAuthorityClaimsProvider
|
||||
{
|
||||
/// <inheritdoc />
|
||||
public bool IsAvailable => true;
|
||||
|
||||
/// <inheritdoc />
|
||||
#pragma warning disable CS0067 // Event is never used (expected for no-op implementation)
|
||||
public event EventHandler<ClaimsOverrideChangedEventArgs>? OverridesChanged;
|
||||
#pragma warning restore CS0067
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<IReadOnlyDictionary<EndpointKey, IReadOnlyList<StellaOps.Router.Common.Models.ClaimRequirement>>> GetOverridesAsync(
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
return Task.FromResult<IReadOnlyDictionary<EndpointKey, IReadOnlyList<StellaOps.Router.Common.Models.ClaimRequirement>>>(
|
||||
new Dictionary<EndpointKey, IReadOnlyList<StellaOps.Router.Common.Models.ClaimRequirement>>());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.InMemory;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Manages microservice connections and updates routing state.
|
||||
/// </summary>
|
||||
internal sealed class ConnectionManager : IHostedService
|
||||
{
|
||||
private readonly InMemoryTransportServer _transportServer;
|
||||
private readonly InMemoryConnectionRegistry _connectionRegistry;
|
||||
private readonly IGlobalRoutingState _routingState;
|
||||
private readonly ILogger<ConnectionManager> _logger;
|
||||
|
||||
public ConnectionManager(
|
||||
InMemoryTransportServer transportServer,
|
||||
InMemoryConnectionRegistry connectionRegistry,
|
||||
IGlobalRoutingState routingState,
|
||||
ILogger<ConnectionManager> logger)
|
||||
{
|
||||
_transportServer = transportServer;
|
||||
_connectionRegistry = connectionRegistry;
|
||||
_routingState = routingState;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
// Subscribe to transport server events
|
||||
_transportServer.OnHelloReceived += HandleHelloReceivedAsync;
|
||||
_transportServer.OnHeartbeatReceived += HandleHeartbeatReceivedAsync;
|
||||
_transportServer.OnConnectionClosed += HandleConnectionClosedAsync;
|
||||
|
||||
// Start the transport server
|
||||
await _transportServer.StartAsync(cancellationToken);
|
||||
|
||||
_logger.LogInformation("Connection manager started");
|
||||
}
|
||||
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
await _transportServer.StopAsync(cancellationToken);
|
||||
|
||||
_transportServer.OnHelloReceived -= HandleHelloReceivedAsync;
|
||||
_transportServer.OnHeartbeatReceived -= HandleHeartbeatReceivedAsync;
|
||||
_transportServer.OnConnectionClosed -= HandleConnectionClosedAsync;
|
||||
|
||||
_logger.LogInformation("Connection manager stopped");
|
||||
}
|
||||
|
||||
private Task HandleHelloReceivedAsync(ConnectionState connectionState, HelloPayload payload)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Connection registered: {ConnectionId} from {ServiceName}/{Version} with {EndpointCount} endpoints",
|
||||
connectionState.ConnectionId,
|
||||
connectionState.Instance.ServiceName,
|
||||
connectionState.Instance.Version,
|
||||
connectionState.Endpoints.Count);
|
||||
|
||||
// Add the connection to the routing state
|
||||
_routingState.AddConnection(connectionState);
|
||||
|
||||
// Start listening to this connection for frames
|
||||
_transportServer.StartListeningToConnection(connectionState.ConnectionId);
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private Task HandleHeartbeatReceivedAsync(ConnectionState connectionState, HeartbeatPayload payload)
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Heartbeat received from {ConnectionId}: status={Status}",
|
||||
connectionState.ConnectionId,
|
||||
payload.Status);
|
||||
|
||||
// Update connection state
|
||||
_routingState.UpdateConnection(connectionState.ConnectionId, conn =>
|
||||
{
|
||||
conn.Status = payload.Status;
|
||||
conn.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
});
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private Task HandleConnectionClosedAsync(string connectionId)
|
||||
{
|
||||
_logger.LogInformation("Connection closed: {ConnectionId}", connectionId);
|
||||
|
||||
// Remove from routing state
|
||||
_routingState.RemoveConnection(connectionId);
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
256
src/Gateway/StellaOps.Gateway.WebService/DefaultRoutingPlugin.cs
Normal file
256
src/Gateway/StellaOps.Gateway.WebService/DefaultRoutingPlugin.cs
Normal file
@@ -0,0 +1,256 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of routing plugin that provides health-aware, region-aware routing.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Routing algorithm:
|
||||
/// 1. Filter by ServiceName (exact match from endpoint)
|
||||
/// 2. Filter by Version (strict semver equality when RequestedVersion specified)
|
||||
/// 3. Filter by Health (Healthy preferred, Degraded as fallback)
|
||||
/// 4. Group by Region Tier:
|
||||
/// - Tier 0: Same region as gateway
|
||||
/// - Tier 1: Configured neighbor regions
|
||||
/// - Tier 2: All other regions
|
||||
/// 5. Within each tier, sort by:
|
||||
/// - Primary: Lower AveragePingMs
|
||||
/// - Secondary: More recent LastHeartbeatUtc
|
||||
/// - Tie-breaker: Random or RoundRobin
|
||||
/// 6. Return first candidate from best available tier
|
||||
/// 7. If none remain, return null (503 Service Unavailable)
|
||||
/// </remarks>
|
||||
internal sealed class DefaultRoutingPlugin : IRoutingPlugin
|
||||
{
|
||||
private readonly RoutingOptions _options;
|
||||
private readonly GatewayNodeConfig _gatewayConfig;
|
||||
private readonly ConcurrentDictionary<string, int> _roundRobinCounters = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="DefaultRoutingPlugin"/> class.
|
||||
/// </summary>
|
||||
public DefaultRoutingPlugin(
|
||||
IOptions<RoutingOptions> options,
|
||||
IOptions<GatewayNodeConfig> gatewayConfig)
|
||||
{
|
||||
_options = options.Value;
|
||||
_gatewayConfig = gatewayConfig.Value;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task<RoutingDecision?> ChooseInstanceAsync(
|
||||
RoutingContext context,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
if (context.AvailableConnections.Count == 0)
|
||||
{
|
||||
return Task.FromResult<RoutingDecision?>(null);
|
||||
}
|
||||
|
||||
var endpoint = context.Endpoint;
|
||||
if (endpoint is null)
|
||||
{
|
||||
return Task.FromResult<RoutingDecision?>(null);
|
||||
}
|
||||
|
||||
// Start with all available connections
|
||||
var candidates = context.AvailableConnections.ToList();
|
||||
|
||||
// Filter by version if requested
|
||||
candidates = FilterByVersion(candidates, context.RequestedVersion);
|
||||
if (candidates.Count == 0)
|
||||
{
|
||||
return Task.FromResult<RoutingDecision?>(null);
|
||||
}
|
||||
|
||||
// Filter by health status - prefer healthy, fall back to degraded
|
||||
candidates = FilterByHealth(candidates);
|
||||
if (candidates.Count == 0)
|
||||
{
|
||||
return Task.FromResult<RoutingDecision?>(null);
|
||||
}
|
||||
|
||||
// Group by region tier and select from best available tier
|
||||
var selected = SelectByRegionTier(candidates, context.GatewayRegion, endpoint.ServiceName);
|
||||
if (selected is null)
|
||||
{
|
||||
return Task.FromResult<RoutingDecision?>(null);
|
||||
}
|
||||
|
||||
var decision = new RoutingDecision
|
||||
{
|
||||
Endpoint = endpoint,
|
||||
Connection = selected,
|
||||
TransportType = selected.TransportType,
|
||||
EffectiveTimeout = TimeSpan.FromMilliseconds(_options.RoutingTimeoutMs)
|
||||
};
|
||||
|
||||
return Task.FromResult<RoutingDecision?>(decision);
|
||||
}
|
||||
|
||||
private List<ConnectionState> FilterByVersion(
|
||||
List<ConnectionState> candidates,
|
||||
string? requestedVersion)
|
||||
{
|
||||
// Determine effective version to match
|
||||
var versionToMatch = requestedVersion ?? _options.DefaultVersion;
|
||||
|
||||
// If no version specified and no default, return all candidates
|
||||
if (string.IsNullOrEmpty(versionToMatch))
|
||||
{
|
||||
return candidates;
|
||||
}
|
||||
|
||||
if (_options.StrictVersionMatching)
|
||||
{
|
||||
// Strict match: exact version equality
|
||||
return candidates
|
||||
.Where(c => string.Equals(c.Instance.Version, versionToMatch, StringComparison.Ordinal))
|
||||
.ToList();
|
||||
}
|
||||
|
||||
// Non-strict: allow compatible versions (for now, just exact match)
|
||||
// Future: implement semver compatibility checking
|
||||
return candidates
|
||||
.Where(c => string.Equals(c.Instance.Version, versionToMatch, StringComparison.Ordinal))
|
||||
.ToList();
|
||||
}
|
||||
|
||||
private List<ConnectionState> FilterByHealth(List<ConnectionState> candidates)
|
||||
{
|
||||
// Filter to only healthy instances first
|
||||
var healthy = candidates
|
||||
.Where(c => c.Status == InstanceHealthStatus.Healthy)
|
||||
.ToList();
|
||||
|
||||
if (healthy.Count > 0)
|
||||
{
|
||||
return healthy;
|
||||
}
|
||||
|
||||
// If no healthy instances and degraded allowed, include degraded
|
||||
if (_options.AllowDegradedInstances)
|
||||
{
|
||||
var degraded = candidates
|
||||
.Where(c => c.Status == InstanceHealthStatus.Degraded)
|
||||
.ToList();
|
||||
|
||||
if (degraded.Count > 0)
|
||||
{
|
||||
return degraded;
|
||||
}
|
||||
}
|
||||
|
||||
// No suitable instances
|
||||
return [];
|
||||
}
|
||||
|
||||
private ConnectionState? SelectByRegionTier(
|
||||
List<ConnectionState> candidates,
|
||||
string gatewayRegion,
|
||||
string serviceName)
|
||||
{
|
||||
if (!_options.PreferLocalRegion || string.IsNullOrEmpty(gatewayRegion))
|
||||
{
|
||||
// No region preference, select from all candidates
|
||||
return SelectFromTier(candidates, serviceName);
|
||||
}
|
||||
|
||||
// Tier 0: Same region as gateway
|
||||
var tier0 = candidates
|
||||
.Where(c => string.Equals(c.Instance.Region, gatewayRegion, StringComparison.OrdinalIgnoreCase))
|
||||
.ToList();
|
||||
|
||||
var selected = SelectFromTier(tier0, serviceName);
|
||||
if (selected is not null)
|
||||
{
|
||||
return selected;
|
||||
}
|
||||
|
||||
// Tier 1: Configured neighbor regions
|
||||
var neighborRegions = _gatewayConfig.NeighborRegions;
|
||||
if (neighborRegions.Count > 0)
|
||||
{
|
||||
var tier1 = candidates
|
||||
.Where(c => neighborRegions.Contains(c.Instance.Region, StringComparer.OrdinalIgnoreCase))
|
||||
.ToList();
|
||||
|
||||
selected = SelectFromTier(tier1, serviceName);
|
||||
if (selected is not null)
|
||||
{
|
||||
return selected;
|
||||
}
|
||||
}
|
||||
|
||||
// Tier 2: All other regions (remaining candidates not in tier0 or tier1)
|
||||
var tier2 = candidates
|
||||
.Where(c => !string.Equals(c.Instance.Region, gatewayRegion, StringComparison.OrdinalIgnoreCase))
|
||||
.Where(c => !neighborRegions.Contains(c.Instance.Region, StringComparer.OrdinalIgnoreCase))
|
||||
.ToList();
|
||||
|
||||
return SelectFromTier(tier2, serviceName);
|
||||
}
|
||||
|
||||
private ConnectionState? SelectFromTier(List<ConnectionState> tier, string serviceName)
|
||||
{
|
||||
if (tier.Count == 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (tier.Count == 1)
|
||||
{
|
||||
return tier[0];
|
||||
}
|
||||
|
||||
// Sort by ping (ascending), then by heartbeat (descending = more recent first)
|
||||
var sorted = tier
|
||||
.OrderBy(c => c.AveragePingMs)
|
||||
.ThenByDescending(c => c.LastHeartbeatUtc)
|
||||
.ToList();
|
||||
|
||||
var best = sorted[0];
|
||||
|
||||
// Find all instances "tied" with the best one
|
||||
var tied = sorted
|
||||
.TakeWhile(c =>
|
||||
Math.Abs(c.AveragePingMs - best.AveragePingMs) <= _options.PingToleranceMs &&
|
||||
c.LastHeartbeatUtc == best.LastHeartbeatUtc)
|
||||
.ToList();
|
||||
|
||||
if (tied.Count == 1)
|
||||
{
|
||||
return tied[0];
|
||||
}
|
||||
|
||||
// Apply tie-breaker
|
||||
return _options.TieBreaker switch
|
||||
{
|
||||
TieBreakerMode.RoundRobin => SelectRoundRobin(tied, serviceName),
|
||||
_ => SelectRandom(tied)
|
||||
};
|
||||
}
|
||||
|
||||
private ConnectionState SelectRandom(List<ConnectionState> candidates)
|
||||
{
|
||||
var index = Random.Shared.Next(candidates.Count);
|
||||
return candidates[index];
|
||||
}
|
||||
|
||||
private ConnectionState SelectRoundRobin(List<ConnectionState> candidates, string serviceName)
|
||||
{
|
||||
// Get or create counter for this service
|
||||
var counter = _roundRobinCounters.AddOrUpdate(
|
||||
serviceName,
|
||||
_ => 0,
|
||||
(_, current) => current + 1);
|
||||
|
||||
var index = counter % candidates.Count;
|
||||
return candidates[index];
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
using System.ComponentModel.DataAnnotations;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
@@ -6,23 +8,48 @@ namespace StellaOps.Gateway.WebService;
|
||||
public sealed class GatewayNodeConfig
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the region where this gateway is deployed (e.g., "eu1").
|
||||
/// Configuration section name for binding.
|
||||
/// </summary>
|
||||
public const string SectionName = "GatewayNode";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the region where this gateway is deployed (e.g., "eu1").
|
||||
/// Routing decisions use this value; it is never derived from headers or URLs.
|
||||
/// </summary>
|
||||
public required string Region { get; init; }
|
||||
[Required(ErrorMessage = "Region is required for gateway routing")]
|
||||
public string Region { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the unique identifier for this gateway node (e.g., "gw-eu1-01").
|
||||
/// Gets or sets the unique identifier for this gateway node (e.g., "gw-eu1-01").
|
||||
/// </summary>
|
||||
public required string NodeId { get; init; }
|
||||
public string NodeId { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the environment name (e.g., "prod", "staging", "dev").
|
||||
/// Gets or sets the environment name (e.g., "prod", "staging", "dev").
|
||||
/// </summary>
|
||||
public required string Environment { get; init; }
|
||||
public string Environment { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the neighbor regions for fallback routing, in order of preference.
|
||||
/// Gets or sets the neighbor regions for fallback routing, in order of preference.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string> NeighborRegions { get; init; } = [];
|
||||
public List<string> NeighborRegions { get; set; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Validates the configuration.
|
||||
/// </summary>
|
||||
/// <exception cref="InvalidOperationException">Thrown when configuration is invalid.</exception>
|
||||
public void Validate()
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(Region))
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"{SectionName}:Region is required. Gateway cannot start without a region assignment.");
|
||||
}
|
||||
|
||||
// Generate NodeId if not provided
|
||||
if (string.IsNullOrWhiteSpace(NodeId))
|
||||
{
|
||||
NodeId = $"gw-{Region}-{Guid.NewGuid().ToString("N")[..8]}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
117
src/Gateway/StellaOps.Gateway.WebService/HealthMonitorService.cs
Normal file
117
src/Gateway/StellaOps.Gateway.WebService/HealthMonitorService.cs
Normal file
@@ -0,0 +1,117 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Background service that monitors connection health and marks stale instances as unhealthy.
|
||||
/// </summary>
|
||||
internal sealed class HealthMonitorService : BackgroundService
|
||||
{
|
||||
private readonly IGlobalRoutingState _routingState;
|
||||
private readonly IOptions<HealthOptions> _options;
|
||||
private readonly ILogger<HealthMonitorService> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="HealthMonitorService"/> class.
|
||||
/// </summary>
|
||||
public HealthMonitorService(
|
||||
IGlobalRoutingState routingState,
|
||||
IOptions<HealthOptions> options,
|
||||
ILogger<HealthMonitorService> logger)
|
||||
{
|
||||
_routingState = routingState;
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Health monitor started. Stale threshold: {StaleThreshold}, Check interval: {CheckInterval}",
|
||||
_options.Value.StaleThreshold,
|
||||
_options.Value.CheckInterval);
|
||||
|
||||
while (!stoppingToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
await Task.Delay(_options.Value.CheckInterval, stoppingToken);
|
||||
CheckStaleConnections();
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected on shutdown
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error in health monitor loop");
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Health monitor stopped");
|
||||
}
|
||||
|
||||
private void CheckStaleConnections()
|
||||
{
|
||||
var staleThreshold = _options.Value.StaleThreshold;
|
||||
var degradedThreshold = _options.Value.DegradedThreshold;
|
||||
var now = DateTime.UtcNow;
|
||||
var staleCount = 0;
|
||||
var degradedCount = 0;
|
||||
|
||||
foreach (var connection in _routingState.GetAllConnections())
|
||||
{
|
||||
// Skip connections that are already draining - they're intentionally stopping
|
||||
if (connection.Status == InstanceHealthStatus.Draining)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var age = now - connection.LastHeartbeatUtc;
|
||||
|
||||
// Check for stale (no heartbeat for too long)
|
||||
if (age > staleThreshold && connection.Status != InstanceHealthStatus.Unhealthy)
|
||||
{
|
||||
_routingState.UpdateConnection(connection.ConnectionId, c =>
|
||||
c.Status = InstanceHealthStatus.Unhealthy);
|
||||
|
||||
_logger.LogWarning(
|
||||
"Instance {InstanceId} ({ServiceName}/{Version}) marked Unhealthy: no heartbeat for {Age:g}",
|
||||
connection.Instance.InstanceId,
|
||||
connection.Instance.ServiceName,
|
||||
connection.Instance.Version,
|
||||
age);
|
||||
|
||||
staleCount++;
|
||||
}
|
||||
// Check for degraded (heartbeat delayed but not stale)
|
||||
else if (age > degradedThreshold &&
|
||||
connection.Status == InstanceHealthStatus.Healthy)
|
||||
{
|
||||
_routingState.UpdateConnection(connection.ConnectionId, c =>
|
||||
c.Status = InstanceHealthStatus.Degraded);
|
||||
|
||||
_logger.LogWarning(
|
||||
"Instance {InstanceId} ({ServiceName}/{Version}) marked Degraded: delayed heartbeat ({Age:g})",
|
||||
connection.Instance.InstanceId,
|
||||
connection.Instance.ServiceName,
|
||||
connection.Instance.Version,
|
||||
age);
|
||||
|
||||
degradedCount++;
|
||||
}
|
||||
}
|
||||
|
||||
if (staleCount > 0 || degradedCount > 0)
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Health check completed: {StaleCount} stale, {DegradedCount} degraded",
|
||||
staleCount,
|
||||
degradedCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
36
src/Gateway/StellaOps.Gateway.WebService/HealthOptions.cs
Normal file
36
src/Gateway/StellaOps.Gateway.WebService/HealthOptions.cs
Normal file
@@ -0,0 +1,36 @@
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Configuration options for health monitoring.
|
||||
/// </summary>
|
||||
public sealed class HealthOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the configuration section name.
|
||||
/// </summary>
|
||||
public const string SectionName = "Health";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the threshold after which a connection is considered stale (no heartbeat).
|
||||
/// Default: 30 seconds.
|
||||
/// </summary>
|
||||
public TimeSpan StaleThreshold { get; set; } = TimeSpan.FromSeconds(30);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the threshold after which a connection is considered degraded.
|
||||
/// Default: 15 seconds.
|
||||
/// </summary>
|
||||
public TimeSpan DegradedThreshold { get; set; } = TimeSpan.FromSeconds(15);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the interval at which to check for stale connections.
|
||||
/// Default: 5 seconds.
|
||||
/// </summary>
|
||||
public TimeSpan CheckInterval { get; set; } = TimeSpan.FromSeconds(5);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the number of ping measurements to keep for averaging.
|
||||
/// Default: 10.
|
||||
/// </summary>
|
||||
public int PingHistorySize { get; set; } = 10;
|
||||
}
|
||||
159
src/Gateway/StellaOps.Gateway.WebService/InMemoryRoutingState.cs
Normal file
159
src/Gateway/StellaOps.Gateway.WebService/InMemoryRoutingState.cs
Normal file
@@ -0,0 +1,159 @@
|
||||
using System.Collections.Concurrent;
|
||||
using StellaOps.Router.Common;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// In-memory implementation of global routing state.
|
||||
/// </summary>
|
||||
internal sealed class InMemoryRoutingState : IGlobalRoutingState
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new();
|
||||
private readonly ConcurrentDictionary<(string Method, string Path), ConcurrentBag<string>> _endpointIndex = new();
|
||||
private readonly ConcurrentDictionary<(string Method, string Path), PathMatcher> _pathMatchers = new();
|
||||
private readonly object _indexLock = new();
|
||||
|
||||
/// <inheritdoc />
|
||||
public void AddConnection(ConnectionState connection)
|
||||
{
|
||||
_connections[connection.ConnectionId] = connection;
|
||||
|
||||
// Index all endpoints
|
||||
foreach (var endpoint in connection.Endpoints.Values)
|
||||
{
|
||||
var key = (endpoint.Method, endpoint.Path);
|
||||
|
||||
// Add to endpoint index
|
||||
var connectionIds = _endpointIndex.GetOrAdd(key, _ => []);
|
||||
connectionIds.Add(connection.ConnectionId);
|
||||
|
||||
// Create path matcher if not exists
|
||||
_pathMatchers.GetOrAdd(key, _ => new PathMatcher(endpoint.Path));
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void RemoveConnection(string connectionId)
|
||||
{
|
||||
if (_connections.TryRemove(connectionId, out var connection))
|
||||
{
|
||||
// Remove from endpoint index
|
||||
foreach (var endpoint in connection.Endpoints.Values)
|
||||
{
|
||||
var key = (endpoint.Method, endpoint.Path);
|
||||
if (_endpointIndex.TryGetValue(key, out var connectionIds))
|
||||
{
|
||||
// ConcurrentBag doesn't support removal, so we need to rebuild
|
||||
lock (_indexLock)
|
||||
{
|
||||
var remaining = connectionIds.Where(id => id != connectionId).ToList();
|
||||
if (remaining.Count == 0)
|
||||
{
|
||||
_endpointIndex.TryRemove(key, out _);
|
||||
_pathMatchers.TryRemove(key, out _);
|
||||
}
|
||||
else
|
||||
{
|
||||
_endpointIndex[key] = new ConcurrentBag<string>(remaining);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void UpdateConnection(string connectionId, Action<ConnectionState> update)
|
||||
{
|
||||
if (_connections.TryGetValue(connectionId, out var connection))
|
||||
{
|
||||
update(connection);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public ConnectionState? GetConnection(string connectionId)
|
||||
{
|
||||
return _connections.TryGetValue(connectionId, out var connection) ? connection : null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<ConnectionState> GetAllConnections()
|
||||
{
|
||||
return [.. _connections.Values];
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public EndpointDescriptor? ResolveEndpoint(string method, string path)
|
||||
{
|
||||
// First try exact match
|
||||
foreach (var ((m, p), matcher) in _pathMatchers)
|
||||
{
|
||||
if (!string.Equals(m, method, StringComparison.OrdinalIgnoreCase))
|
||||
continue;
|
||||
|
||||
if (matcher.IsMatch(path))
|
||||
{
|
||||
// Get first connection with this endpoint
|
||||
if (_endpointIndex.TryGetValue((m, p), out var connectionIds))
|
||||
{
|
||||
foreach (var connectionId in connectionIds)
|
||||
{
|
||||
if (_connections.TryGetValue(connectionId, out var conn) &&
|
||||
conn.Endpoints.TryGetValue((m, p), out var endpoint))
|
||||
{
|
||||
return endpoint;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<ConnectionState> GetConnectionsFor(
|
||||
string serviceName,
|
||||
string version,
|
||||
string method,
|
||||
string path)
|
||||
{
|
||||
var result = new List<ConnectionState>();
|
||||
|
||||
foreach (var ((m, p), matcher) in _pathMatchers)
|
||||
{
|
||||
if (!string.Equals(m, method, StringComparison.OrdinalIgnoreCase))
|
||||
continue;
|
||||
|
||||
if (!matcher.IsMatch(path))
|
||||
continue;
|
||||
|
||||
if (!_endpointIndex.TryGetValue((m, p), out var connectionIds))
|
||||
continue;
|
||||
|
||||
foreach (var connectionId in connectionIds)
|
||||
{
|
||||
if (!_connections.TryGetValue(connectionId, out var conn))
|
||||
continue;
|
||||
|
||||
// Filter by service name and version
|
||||
if (!string.Equals(conn.Instance.ServiceName, serviceName, StringComparison.OrdinalIgnoreCase))
|
||||
continue;
|
||||
|
||||
if (!string.Equals(conn.Instance.Version, version, StringComparison.Ordinal))
|
||||
continue;
|
||||
|
||||
// Check endpoint exists
|
||||
if (conn.Endpoints.ContainsKey((m, p)))
|
||||
{
|
||||
result.Add(conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// A stream wrapper that counts bytes read and enforces a limit.
|
||||
/// </summary>
|
||||
public sealed class ByteCountingStream : Stream
|
||||
{
|
||||
private readonly Stream _inner;
|
||||
private readonly long _limit;
|
||||
private readonly Action? _onLimitExceeded;
|
||||
private long _bytesRead;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ByteCountingStream"/> class.
|
||||
/// </summary>
|
||||
/// <param name="inner">The inner stream to wrap.</param>
|
||||
/// <param name="limit">The maximum number of bytes that can be read.</param>
|
||||
/// <param name="onLimitExceeded">Optional callback invoked when the limit is exceeded.</param>
|
||||
public ByteCountingStream(Stream inner, long limit, Action? onLimitExceeded = null)
|
||||
{
|
||||
_inner = inner;
|
||||
_limit = limit;
|
||||
_onLimitExceeded = onLimitExceeded;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the total number of bytes read from this stream.
|
||||
/// </summary>
|
||||
public long BytesRead => Interlocked.Read(ref _bytesRead);
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanRead => _inner.CanRead;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanSeek => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanWrite => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Length => _inner.Length;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Position
|
||||
{
|
||||
get => _inner.Position;
|
||||
set => throw new NotSupportedException("Seeking not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Flush() => _inner.Flush();
|
||||
|
||||
/// <inheritdoc />
|
||||
public override Task FlushAsync(CancellationToken cancellationToken) =>
|
||||
_inner.FlushAsync(cancellationToken);
|
||||
|
||||
/// <inheritdoc />
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
var read = _inner.Read(buffer, offset, count);
|
||||
CheckLimit(read);
|
||||
return read;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
var read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
|
||||
CheckLimit(read);
|
||||
return read;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var read = await _inner.ReadAsync(buffer, cancellationToken);
|
||||
CheckLimit(read);
|
||||
return read;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Seek(long offset, SeekOrigin origin)
|
||||
{
|
||||
throw new NotSupportedException("Seeking not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void SetLength(long value)
|
||||
{
|
||||
throw new NotSupportedException("Setting length not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Write(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotSupportedException("Writing not supported on ByteCountingStream.");
|
||||
}
|
||||
|
||||
private void CheckLimit(int bytesJustRead)
|
||||
{
|
||||
if (bytesJustRead <= 0) return;
|
||||
|
||||
var newTotal = Interlocked.Add(ref _bytesRead, bytesJustRead);
|
||||
if (newTotal > _limit)
|
||||
{
|
||||
_onLimitExceeded?.Invoke();
|
||||
throw new PayloadLimitExceededException(newTotal, _limit);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed && disposing)
|
||||
{
|
||||
_inner.Dispose();
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask DisposeAsync()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
await _inner.DisposeAsync();
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
await base.DisposeAsync();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Resolves incoming HTTP requests to endpoint descriptors using the routing state.
|
||||
/// </summary>
|
||||
public sealed class EndpointResolutionMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EndpointResolutionMiddleware"/> class.
|
||||
/// </summary>
|
||||
public EndpointResolutionMiddleware(RequestDelegate next)
|
||||
{
|
||||
_next = next;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(HttpContext context, IGlobalRoutingState routingState)
|
||||
{
|
||||
var method = context.Request.Method;
|
||||
var path = context.Request.Path.ToString();
|
||||
|
||||
var endpoint = routingState.ResolveEndpoint(method, path);
|
||||
if (endpoint is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Endpoint not found",
|
||||
method,
|
||||
path
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
context.Items[RouterHttpContextKeys.EndpointDescriptor] = endpoint;
|
||||
await _next(context);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Exception thrown when a payload limit is exceeded during streaming.
|
||||
/// </summary>
|
||||
public sealed class PayloadLimitExceededException : Exception
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadLimitExceededException"/> class.
|
||||
/// </summary>
|
||||
/// <param name="bytesRead">The number of bytes read before the limit was exceeded.</param>
|
||||
/// <param name="limit">The limit that was exceeded.</param>
|
||||
public PayloadLimitExceededException(long bytesRead, long limit)
|
||||
: base($"Payload limit exceeded: {bytesRead} bytes read, limit is {limit} bytes")
|
||||
{
|
||||
BytesRead = bytesRead;
|
||||
Limit = limit;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of bytes read before the limit was exceeded.
|
||||
/// </summary>
|
||||
public long BytesRead { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the limit that was exceeded.
|
||||
/// </summary>
|
||||
public long Limit { get; }
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Middleware that enforces payload limits per-request, per-connection, and aggregate.
|
||||
/// </summary>
|
||||
public sealed class PayloadLimitsMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly PayloadLimits _limits;
|
||||
private readonly ILogger<PayloadLimitsMiddleware> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadLimitsMiddleware"/> class.
|
||||
/// </summary>
|
||||
public PayloadLimitsMiddleware(
|
||||
RequestDelegate next,
|
||||
IOptions<PayloadLimits> limits,
|
||||
ILogger<PayloadLimitsMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_limits = limits.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(HttpContext context, IPayloadTracker tracker)
|
||||
{
|
||||
var connectionId = context.Connection.Id;
|
||||
var contentLength = context.Request.ContentLength ?? 0;
|
||||
|
||||
// Early rejection for known oversized Content-Length (LIM-002, LIM-003)
|
||||
if (context.Request.ContentLength.HasValue &&
|
||||
context.Request.ContentLength.Value > _limits.MaxRequestBytesPerCall)
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Request rejected: Content-Length {ContentLength} exceeds per-call limit {Limit}. ConnectionId: {ConnectionId}",
|
||||
context.Request.ContentLength.Value,
|
||||
_limits.MaxRequestBytesPerCall,
|
||||
connectionId);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status413PayloadTooLarge;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Payload Too Large",
|
||||
maxBytes = _limits.MaxRequestBytesPerCall,
|
||||
contentLength = context.Request.ContentLength.Value
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to reserve capacity (checks aggregate and per-connection limits)
|
||||
if (!tracker.TryReserve(connectionId, contentLength))
|
||||
{
|
||||
// Check which limit was hit
|
||||
if (tracker.IsOverloaded)
|
||||
{
|
||||
// Aggregate limit exceeded (LIM-033)
|
||||
_logger.LogWarning(
|
||||
"Request rejected: Aggregate limit exceeded. Current inflight: {Current}, Limit: {Limit}. ConnectionId: {ConnectionId}",
|
||||
tracker.CurrentInflightBytes,
|
||||
_limits.MaxAggregateInflightBytes,
|
||||
connectionId);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Service Overloaded",
|
||||
message = "Too many concurrent requests"
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Per-connection limit exceeded (LIM-022)
|
||||
_logger.LogWarning(
|
||||
"Request rejected: Per-connection limit exceeded. ConnectionId: {ConnectionId}, Current: {Current}, Limit: {Limit}",
|
||||
connectionId,
|
||||
tracker.GetConnectionInflightBytes(connectionId),
|
||||
_limits.MaxRequestBytesPerConnection);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Too Many Requests",
|
||||
message = "Per-connection limit exceeded"
|
||||
});
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Store the original body stream
|
||||
var originalBody = context.Request.Body;
|
||||
long actualBytesRead = 0;
|
||||
|
||||
try
|
||||
{
|
||||
// Wrap the request body with ByteCountingStream for streaming requests
|
||||
if (!context.Request.ContentLength.HasValue || context.Request.ContentLength.Value > 0)
|
||||
{
|
||||
var countingStream = new ByteCountingStream(
|
||||
originalBody,
|
||||
_limits.MaxRequestBytesPerCall,
|
||||
() =>
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Mid-stream limit exceeded. ConnectionId: {ConnectionId}, Limit: {Limit}",
|
||||
connectionId,
|
||||
_limits.MaxRequestBytesPerCall);
|
||||
});
|
||||
|
||||
context.Request.Body = countingStream;
|
||||
|
||||
// Store reference for later access to bytes read
|
||||
context.Items["PayloadLimits:CountingStream"] = countingStream;
|
||||
}
|
||||
|
||||
await _next(context);
|
||||
|
||||
// Get actual bytes read
|
||||
if (context.Items["PayloadLimits:CountingStream"] is ByteCountingStream cs)
|
||||
{
|
||||
actualBytesRead = cs.BytesRead;
|
||||
}
|
||||
}
|
||||
catch (PayloadLimitExceededException ex)
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Payload limit exceeded mid-stream. ConnectionId: {ConnectionId}, BytesRead: {BytesRead}, Limit: {Limit}",
|
||||
connectionId,
|
||||
ex.BytesRead,
|
||||
ex.Limit);
|
||||
|
||||
// Only set response if not already started
|
||||
if (!context.Response.HasStarted)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status413PayloadTooLarge;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Payload Too Large",
|
||||
maxBytes = _limits.MaxRequestBytesPerCall,
|
||||
bytesReceived = ex.BytesRead
|
||||
});
|
||||
}
|
||||
|
||||
actualBytesRead = ex.BytesRead;
|
||||
}
|
||||
finally
|
||||
{
|
||||
// Restore original body stream
|
||||
context.Request.Body = originalBody;
|
||||
|
||||
// Release reserved capacity
|
||||
var bytesToRelease = actualBytesRead > 0 ? actualBytesRead : contentLength;
|
||||
tracker.Release(connectionId, bytesToRelease);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks payload bytes across requests, connections, and globally.
|
||||
/// </summary>
|
||||
public interface IPayloadTracker
|
||||
{
|
||||
/// <summary>
|
||||
/// Tries to reserve capacity for an estimated payload size.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection identifier.</param>
|
||||
/// <param name="estimatedBytes">The estimated bytes to reserve.</param>
|
||||
/// <returns>True if capacity was reserved; false if limits would be exceeded.</returns>
|
||||
bool TryReserve(string connectionId, long estimatedBytes);
|
||||
|
||||
/// <summary>
|
||||
/// Releases previously reserved capacity.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection identifier.</param>
|
||||
/// <param name="actualBytes">The actual bytes to release.</param>
|
||||
void Release(string connectionId, long actualBytes);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current total inflight bytes across all connections.
|
||||
/// </summary>
|
||||
long CurrentInflightBytes { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the system is overloaded.
|
||||
/// </summary>
|
||||
bool IsOverloaded { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current inflight bytes for a specific connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection identifier.</param>
|
||||
/// <returns>The current inflight bytes for the connection.</returns>
|
||||
long GetConnectionInflightBytes(string connectionId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of <see cref="IPayloadTracker"/>.
|
||||
/// </summary>
|
||||
public sealed class PayloadTracker : IPayloadTracker
|
||||
{
|
||||
private readonly PayloadLimits _limits;
|
||||
private readonly ILogger<PayloadTracker> _logger;
|
||||
private long _totalInflightBytes;
|
||||
private readonly ConcurrentDictionary<string, long> _perConnectionBytes = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadTracker"/> class.
|
||||
/// </summary>
|
||||
public PayloadTracker(IOptions<PayloadLimits> limits, ILogger<PayloadTracker> logger)
|
||||
{
|
||||
_limits = limits.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public long CurrentInflightBytes => Interlocked.Read(ref _totalInflightBytes);
|
||||
|
||||
/// <inheritdoc />
|
||||
public bool IsOverloaded => CurrentInflightBytes > _limits.MaxAggregateInflightBytes;
|
||||
|
||||
/// <inheritdoc />
|
||||
public bool TryReserve(string connectionId, long estimatedBytes)
|
||||
{
|
||||
// Check aggregate limit
|
||||
var newTotal = Interlocked.Add(ref _totalInflightBytes, estimatedBytes);
|
||||
if (newTotal > _limits.MaxAggregateInflightBytes)
|
||||
{
|
||||
Interlocked.Add(ref _totalInflightBytes, -estimatedBytes);
|
||||
_logger.LogWarning(
|
||||
"Aggregate payload limit exceeded. Current: {Current}, Limit: {Limit}",
|
||||
newTotal - estimatedBytes,
|
||||
_limits.MaxAggregateInflightBytes);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check per-connection limit
|
||||
var connectionBytes = _perConnectionBytes.AddOrUpdate(
|
||||
connectionId,
|
||||
estimatedBytes,
|
||||
(_, current) => current + estimatedBytes);
|
||||
|
||||
if (connectionBytes > _limits.MaxRequestBytesPerConnection)
|
||||
{
|
||||
// Roll back
|
||||
_perConnectionBytes.AddOrUpdate(
|
||||
connectionId,
|
||||
0,
|
||||
(_, current) => current - estimatedBytes);
|
||||
Interlocked.Add(ref _totalInflightBytes, -estimatedBytes);
|
||||
|
||||
_logger.LogWarning(
|
||||
"Per-connection payload limit exceeded for {ConnectionId}. Current: {Current}, Limit: {Limit}",
|
||||
connectionId,
|
||||
connectionBytes - estimatedBytes,
|
||||
_limits.MaxRequestBytesPerConnection);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Release(string connectionId, long actualBytes)
|
||||
{
|
||||
Interlocked.Add(ref _totalInflightBytes, -actualBytes);
|
||||
|
||||
_perConnectionBytes.AddOrUpdate(
|
||||
connectionId,
|
||||
0,
|
||||
(_, current) => Math.Max(0, current - actualBytes));
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public long GetConnectionInflightBytes(string connectionId)
|
||||
{
|
||||
return _perConnectionBytes.TryGetValue(connectionId, out var bytes) ? bytes : 0;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Makes routing decisions for resolved endpoints.
|
||||
/// </summary>
|
||||
public sealed class RoutingDecisionMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RoutingDecisionMiddleware"/> class.
|
||||
/// </summary>
|
||||
public RoutingDecisionMiddleware(RequestDelegate next)
|
||||
{
|
||||
_next = next;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(
|
||||
HttpContext context,
|
||||
IRoutingPlugin routingPlugin,
|
||||
IGlobalRoutingState routingState,
|
||||
IOptions<GatewayNodeConfig> gatewayConfig,
|
||||
IOptions<RoutingOptions> routingOptions)
|
||||
{
|
||||
var endpoint = context.Items[RouterHttpContextKeys.EndpointDescriptor] as EndpointDescriptor;
|
||||
if (endpoint is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
|
||||
await context.Response.WriteAsJsonAsync(new { error = "Endpoint descriptor missing" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Build routing context
|
||||
var availableConnections = routingState.GetConnectionsFor(
|
||||
endpoint.ServiceName,
|
||||
endpoint.Version,
|
||||
endpoint.Method,
|
||||
endpoint.Path);
|
||||
|
||||
var headers = context.Request.Headers
|
||||
.ToDictionary(h => h.Key, h => h.Value.ToString());
|
||||
|
||||
var routingContext = new RoutingContext
|
||||
{
|
||||
Method = context.Request.Method,
|
||||
Path = context.Request.Path.ToString(),
|
||||
Headers = headers,
|
||||
Endpoint = endpoint,
|
||||
AvailableConnections = availableConnections,
|
||||
GatewayRegion = gatewayConfig.Value.Region,
|
||||
RequestedVersion = ExtractVersionFromRequest(context, routingOptions.Value),
|
||||
CancellationToken = context.RequestAborted
|
||||
};
|
||||
|
||||
var decision = await routingPlugin.ChooseInstanceAsync(
|
||||
routingContext,
|
||||
context.RequestAborted);
|
||||
|
||||
if (decision is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "No instances available",
|
||||
service = endpoint.ServiceName,
|
||||
version = endpoint.Version
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
context.Items[RouterHttpContextKeys.RoutingDecision] = decision;
|
||||
await _next(context);
|
||||
}
|
||||
|
||||
private static string? ExtractVersionFromRequest(HttpContext context, RoutingOptions options)
|
||||
{
|
||||
// Check for version in Accept header: Accept: application/vnd.stellaops.v1+json
|
||||
var acceptHeader = context.Request.Headers.Accept.FirstOrDefault();
|
||||
if (!string.IsNullOrEmpty(acceptHeader))
|
||||
{
|
||||
var versionMatch = System.Text.RegularExpressions.Regex.Match(
|
||||
acceptHeader,
|
||||
@"application/vnd\.stellaops\.v(\d+(?:\.\d+)*)\+json");
|
||||
if (versionMatch.Success)
|
||||
{
|
||||
return versionMatch.Groups[1].Value;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for X-Api-Version header
|
||||
var versionHeader = context.Request.Headers["X-Api-Version"].FirstOrDefault();
|
||||
if (!string.IsNullOrEmpty(versionHeader))
|
||||
{
|
||||
return versionHeader;
|
||||
}
|
||||
|
||||
// Fall back to default version from options
|
||||
return options.DefaultVersion;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,457 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Diagnostics;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Frames;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
/// <summary>
|
||||
/// Dispatches HTTP requests to microservices via the transport layer.
|
||||
/// </summary>
|
||||
public sealed class TransportDispatchMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly ILogger<TransportDispatchMiddleware> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks cancelled request IDs to ignore late responses.
|
||||
/// Keys expire after 60 seconds to prevent memory leaks.
|
||||
/// </summary>
|
||||
private static readonly ConcurrentDictionary<string, DateTimeOffset> CancelledRequests = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TransportDispatchMiddleware"/> class.
|
||||
/// </summary>
|
||||
public TransportDispatchMiddleware(RequestDelegate next, ILogger<TransportDispatchMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_logger = logger;
|
||||
|
||||
// Start background cleanup task for expired cancelled request entries
|
||||
_ = Task.Run(CleanupExpiredCancelledRequestsAsync);
|
||||
}
|
||||
|
||||
private static async Task CleanupExpiredCancelledRequestsAsync()
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
await Task.Delay(TimeSpan.FromSeconds(30));
|
||||
|
||||
var cutoff = DateTimeOffset.UtcNow.AddSeconds(-60);
|
||||
foreach (var kvp in CancelledRequests)
|
||||
{
|
||||
if (kvp.Value < cutoff)
|
||||
{
|
||||
CancelledRequests.TryRemove(kvp.Key, out _);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void MarkCancelled(string requestId)
|
||||
{
|
||||
CancelledRequests[requestId] = DateTimeOffset.UtcNow;
|
||||
}
|
||||
|
||||
private static bool IsCancelled(string requestId)
|
||||
{
|
||||
return CancelledRequests.ContainsKey(requestId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes the middleware.
|
||||
/// </summary>
|
||||
public async Task Invoke(
|
||||
HttpContext context,
|
||||
ITransportClient transportClient,
|
||||
IGlobalRoutingState routingState)
|
||||
{
|
||||
var decision = context.Items[RouterHttpContextKeys.RoutingDecision] as RoutingDecision;
|
||||
if (decision is null)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
|
||||
await context.Response.WriteAsJsonAsync(new { error = "Routing decision missing" });
|
||||
return;
|
||||
}
|
||||
|
||||
var requestId = Guid.NewGuid().ToString("N");
|
||||
|
||||
// Extract headers (exclude some internal headers)
|
||||
var headers = context.Request.Headers
|
||||
.Where(h => !h.Key.StartsWith(":", StringComparison.Ordinal))
|
||||
.ToDictionary(
|
||||
h => h.Key,
|
||||
h => h.Value.ToString());
|
||||
|
||||
// For streaming endpoints, use streaming dispatch
|
||||
if (decision.Endpoint.SupportsStreaming)
|
||||
{
|
||||
await DispatchStreamingAsync(context, transportClient, routingState, decision, requestId, headers);
|
||||
return;
|
||||
}
|
||||
|
||||
// Read request body (buffered)
|
||||
byte[] bodyBytes;
|
||||
using (var ms = new MemoryStream())
|
||||
{
|
||||
await context.Request.Body.CopyToAsync(ms, context.RequestAborted);
|
||||
bodyBytes = ms.ToArray();
|
||||
}
|
||||
|
||||
// Build request frame
|
||||
var requestFrame = new RequestFrame
|
||||
{
|
||||
RequestId = requestId,
|
||||
CorrelationId = context.TraceIdentifier,
|
||||
Method = context.Request.Method,
|
||||
Path = context.Request.Path.ToString() + context.Request.QueryString.ToString(),
|
||||
Headers = headers,
|
||||
Payload = bodyBytes,
|
||||
TimeoutSeconds = (int)decision.EffectiveTimeout.TotalSeconds,
|
||||
SupportsStreaming = false
|
||||
};
|
||||
|
||||
var frame = FrameConverter.ToFrame(requestFrame);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Dispatching {Method} {Path} to {ServiceName}/{Version} via {TransportType}",
|
||||
requestFrame.Method,
|
||||
requestFrame.Path,
|
||||
decision.Connection.Instance.ServiceName,
|
||||
decision.Connection.Instance.Version,
|
||||
decision.TransportType);
|
||||
|
||||
// Create linked cancellation token with timeout
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
|
||||
timeoutCts.CancelAfter(decision.EffectiveTimeout);
|
||||
|
||||
// Register client disconnect handler to send CANCEL
|
||||
var requestIdGuid = Guid.TryParse(requestId, out var parsed) ? parsed : Guid.NewGuid();
|
||||
using var clientDisconnectRegistration = context.RequestAborted.Register(() =>
|
||||
{
|
||||
// Mark as cancelled to ignore late responses
|
||||
MarkCancelled(requestId);
|
||||
|
||||
// Send CANCEL frame (fire and forget)
|
||||
_ = Task.Run(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.ClientDisconnected);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Sent CANCEL for request {RequestId} due to client disconnect",
|
||||
requestId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex,
|
||||
"Failed to send CANCEL for request {RequestId} on client disconnect",
|
||||
requestId);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Frame responseFrame;
|
||||
var startTimestamp = Stopwatch.GetTimestamp();
|
||||
try
|
||||
{
|
||||
responseFrame = await transportClient.SendRequestAsync(
|
||||
decision.Connection,
|
||||
frame,
|
||||
decision.EffectiveTimeout,
|
||||
timeoutCts.Token);
|
||||
|
||||
// Record ping latency and update connection's average
|
||||
var elapsed = Stopwatch.GetElapsedTime(startTimestamp);
|
||||
UpdateConnectionPing(routingState, decision.Connection.ConnectionId, elapsed.TotalMilliseconds);
|
||||
}
|
||||
catch (OperationCanceledException) when (!context.RequestAborted.IsCancellationRequested)
|
||||
{
|
||||
// Internal timeout (not client disconnect)
|
||||
_logger.LogWarning(
|
||||
"Request {RequestId} to {ServiceName} timed out after {Timeout}",
|
||||
requestId,
|
||||
decision.Connection.Instance.ServiceName,
|
||||
decision.EffectiveTimeout);
|
||||
|
||||
// Mark as cancelled to ignore late responses
|
||||
MarkCancelled(requestId);
|
||||
|
||||
// Send cancel to microservice
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.Timeout);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Failed to send cancel for request {RequestId}", requestId);
|
||||
}
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream timeout",
|
||||
service = decision.Connection.Instance.ServiceName,
|
||||
timeout = decision.EffectiveTimeout.TotalSeconds
|
||||
});
|
||||
return;
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Client disconnected - cancel already sent via registration above
|
||||
MarkCancelled(requestId);
|
||||
_logger.LogDebug("Client disconnected, request {RequestId} cancelled", requestId);
|
||||
return;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Error dispatching request {RequestId} to {ServiceName}",
|
||||
requestId,
|
||||
decision.Connection.Instance.ServiceName);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream error",
|
||||
message = ex.Message
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if request was cancelled while waiting for response
|
||||
if (IsCancelled(requestId))
|
||||
{
|
||||
_logger.LogDebug("Ignoring late response for cancelled request {RequestId}", requestId);
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var response = FrameConverter.ToResponseFrame(responseFrame);
|
||||
if (response is null)
|
||||
{
|
||||
_logger.LogError(
|
||||
"Invalid response frame from {ServiceName} for request {RequestId}",
|
||||
decision.Connection.Instance.ServiceName,
|
||||
requestId);
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
await context.Response.WriteAsJsonAsync(new { error = "Invalid upstream response" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Map response to HTTP
|
||||
context.Response.StatusCode = response.StatusCode;
|
||||
|
||||
// Copy response headers
|
||||
foreach (var (key, value) in response.Headers)
|
||||
{
|
||||
// Skip some headers that shouldn't be copied
|
||||
if (key.Equals("Transfer-Encoding", StringComparison.OrdinalIgnoreCase) ||
|
||||
key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
context.Response.Headers[key] = value;
|
||||
}
|
||||
|
||||
// Write response body
|
||||
if (response.Payload.Length > 0)
|
||||
{
|
||||
await context.Response.Body.WriteAsync(response.Payload, context.RequestAborted);
|
||||
}
|
||||
|
||||
_logger.LogDebug(
|
||||
"Request {RequestId} completed with status {StatusCode}",
|
||||
requestId,
|
||||
response.StatusCode);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Updates the connection's average ping time using exponential moving average.
|
||||
/// </summary>
|
||||
private static void UpdateConnectionPing(
|
||||
IGlobalRoutingState routingState,
|
||||
string connectionId,
|
||||
double pingMs)
|
||||
{
|
||||
const double smoothingFactor = 0.2;
|
||||
|
||||
routingState.UpdateConnection(connectionId, connection =>
|
||||
{
|
||||
if (connection.AveragePingMs == 0)
|
||||
{
|
||||
connection.AveragePingMs = pingMs;
|
||||
}
|
||||
else
|
||||
{
|
||||
connection.AveragePingMs = (1 - smoothingFactor) * connection.AveragePingMs + smoothingFactor * pingMs;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dispatches a streaming request to a microservice.
|
||||
/// </summary>
|
||||
private async Task DispatchStreamingAsync(
|
||||
HttpContext context,
|
||||
ITransportClient transportClient,
|
||||
IGlobalRoutingState routingState,
|
||||
RoutingDecision decision,
|
||||
string requestId,
|
||||
Dictionary<string, string> headers)
|
||||
{
|
||||
var requestIdGuid = Guid.TryParse(requestId, out var parsed) ? parsed : Guid.NewGuid();
|
||||
|
||||
// Build request header frame (without body - will stream)
|
||||
var requestFrame = new RequestFrame
|
||||
{
|
||||
RequestId = requestId,
|
||||
CorrelationId = context.TraceIdentifier,
|
||||
Method = context.Request.Method,
|
||||
Path = context.Request.Path.ToString() + context.Request.QueryString.ToString(),
|
||||
Headers = headers,
|
||||
Payload = Array.Empty<byte>(), // Empty - body will be streamed
|
||||
TimeoutSeconds = (int)decision.EffectiveTimeout.TotalSeconds,
|
||||
SupportsStreaming = true
|
||||
};
|
||||
|
||||
var frame = FrameConverter.ToFrame(requestFrame);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Dispatching streaming {Method} {Path} to {ServiceName}/{Version}",
|
||||
requestFrame.Method,
|
||||
requestFrame.Path,
|
||||
decision.Connection.Instance.ServiceName,
|
||||
decision.Connection.Instance.Version);
|
||||
|
||||
// Create linked cancellation token with timeout
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
|
||||
timeoutCts.CancelAfter(decision.EffectiveTimeout);
|
||||
|
||||
// Register client disconnect handler to send CANCEL
|
||||
using var clientDisconnectRegistration = context.RequestAborted.Register(() =>
|
||||
{
|
||||
MarkCancelled(requestId);
|
||||
|
||||
_ = Task.Run(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.ClientDisconnected);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Sent CANCEL for streaming request {RequestId} due to client disconnect",
|
||||
requestId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex,
|
||||
"Failed to send CANCEL for streaming request {RequestId}",
|
||||
requestId);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
var startTimestamp = Stopwatch.GetTimestamp();
|
||||
var responseReceived = false;
|
||||
|
||||
try
|
||||
{
|
||||
// Use streaming transport method
|
||||
await transportClient.SendStreamingAsync(
|
||||
decision.Connection,
|
||||
frame,
|
||||
context.Request.Body,
|
||||
async responseBodyStream =>
|
||||
{
|
||||
responseReceived = true;
|
||||
|
||||
// For now, read the response stream and write to HTTP response
|
||||
// The response headers should be set before streaming begins
|
||||
context.Response.StatusCode = StatusCodes.Status200OK;
|
||||
context.Response.Headers["Transfer-Encoding"] = "chunked";
|
||||
context.Response.ContentType = "application/octet-stream";
|
||||
|
||||
await responseBodyStream.CopyToAsync(context.Response.Body, timeoutCts.Token);
|
||||
},
|
||||
PayloadLimits.Default,
|
||||
timeoutCts.Token);
|
||||
|
||||
// Record ping latency
|
||||
var elapsed = Stopwatch.GetElapsedTime(startTimestamp);
|
||||
UpdateConnectionPing(routingState, decision.Connection.ConnectionId, elapsed.TotalMilliseconds);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Streaming request {RequestId} completed",
|
||||
requestId);
|
||||
}
|
||||
catch (OperationCanceledException) when (!context.RequestAborted.IsCancellationRequested)
|
||||
{
|
||||
// Internal timeout
|
||||
_logger.LogWarning(
|
||||
"Streaming request {RequestId} timed out after {Timeout}",
|
||||
requestId,
|
||||
decision.EffectiveTimeout);
|
||||
|
||||
MarkCancelled(requestId);
|
||||
|
||||
try
|
||||
{
|
||||
await transportClient.SendCancelAsync(
|
||||
decision.Connection,
|
||||
requestIdGuid,
|
||||
CancelReasons.Timeout);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Failed to send cancel for streaming request {RequestId}", requestId);
|
||||
}
|
||||
|
||||
if (!responseReceived)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream streaming timeout",
|
||||
service = decision.Connection.Instance.ServiceName,
|
||||
timeout = decision.EffectiveTimeout.TotalSeconds
|
||||
});
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Client disconnected
|
||||
MarkCancelled(requestId);
|
||||
_logger.LogDebug("Client disconnected, streaming request {RequestId} cancelled", requestId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex,
|
||||
"Error dispatching streaming request {RequestId}",
|
||||
requestId);
|
||||
|
||||
if (!responseReceived)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
await context.Response.WriteAsJsonAsync(new
|
||||
{
|
||||
error = "Upstream streaming error",
|
||||
message = ex.Message
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
84
src/Gateway/StellaOps.Gateway.WebService/PingTracker.cs
Normal file
84
src/Gateway/StellaOps.Gateway.WebService/PingTracker.cs
Normal file
@@ -0,0 +1,84 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks round-trip time for requests to compute average ping latency.
|
||||
/// </summary>
|
||||
internal sealed class PingTracker
|
||||
{
|
||||
private readonly ConcurrentDictionary<Guid, long> _pendingRequests = new();
|
||||
private readonly object _lock = new();
|
||||
private double _averagePingMs;
|
||||
private const double SmoothingFactor = 0.2;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the exponential moving average of ping times in milliseconds.
|
||||
/// </summary>
|
||||
public double AveragePingMs
|
||||
{
|
||||
get
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
return _averagePingMs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Records that a request has been sent.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the request.</param>
|
||||
public void RecordRequestSent(Guid correlationId)
|
||||
{
|
||||
_pendingRequests[correlationId] = Stopwatch.GetTimestamp();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Records that a response has been received and updates the average ping.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the request.</param>
|
||||
/// <returns>The round-trip time in milliseconds, or null if the correlation ID was not found.</returns>
|
||||
public double? RecordResponseReceived(Guid correlationId)
|
||||
{
|
||||
if (!_pendingRequests.TryRemove(correlationId, out var startTicks))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var elapsed = Stopwatch.GetElapsedTime(startTicks);
|
||||
var rtt = elapsed.TotalMilliseconds;
|
||||
|
||||
lock (_lock)
|
||||
{
|
||||
// Exponential moving average: avg = (1 - alpha) * avg + alpha * new_value
|
||||
if (_averagePingMs == 0)
|
||||
{
|
||||
_averagePingMs = rtt; // First measurement
|
||||
}
|
||||
else
|
||||
{
|
||||
_averagePingMs = (1 - SmoothingFactor) * _averagePingMs + SmoothingFactor * rtt;
|
||||
}
|
||||
}
|
||||
|
||||
return rtt;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Removes a pending request without recording a response.
|
||||
/// Call this when a request times out or is cancelled.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the request.</param>
|
||||
public void RemovePending(Guid correlationId)
|
||||
{
|
||||
_pendingRequests.TryRemove(correlationId, out _);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of pending requests.
|
||||
/// </summary>
|
||||
public int PendingCount => _pendingRequests.Count;
|
||||
}
|
||||
@@ -1,12 +1,19 @@
|
||||
using StellaOps.Gateway.WebService;
|
||||
|
||||
var builder = WebApplication.CreateBuilder(args);
|
||||
|
||||
// Placeholder: Gateway services will be registered here in later sprints
|
||||
// Register gateway routing services
|
||||
builder.Services.AddGatewayRouting(builder.Configuration);
|
||||
|
||||
var app = builder.Build();
|
||||
|
||||
// Placeholder: Middleware pipeline will be configured here in later sprints
|
||||
// Health check endpoint (not routed through gateway middleware)
|
||||
app.MapGet("/health", () => Results.Ok(new { status = "healthy" }));
|
||||
|
||||
// Gateway router middleware pipeline
|
||||
// All other requests are routed through the gateway
|
||||
app.UseGatewayRouter();
|
||||
|
||||
app.Run();
|
||||
|
||||
// Make Program class accessible for integration tests
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Well-known HttpContext.Items keys for router pipeline.
|
||||
/// </summary>
|
||||
public static class RouterHttpContextKeys
|
||||
{
|
||||
/// <summary>
|
||||
/// Key for the resolved <see cref="StellaOps.Router.Common.Models.EndpointDescriptor"/>.
|
||||
/// </summary>
|
||||
public const string EndpointDescriptor = "Stella.EndpointDescriptor";
|
||||
|
||||
/// <summary>
|
||||
/// Key for the <see cref="StellaOps.Router.Common.Models.RoutingDecision"/>.
|
||||
/// </summary>
|
||||
public const string RoutingDecision = "Stella.RoutingDecision";
|
||||
|
||||
/// <summary>
|
||||
/// Key for path parameters extracted from route template matching.
|
||||
/// </summary>
|
||||
public const string PathParameters = "Stella.PathParameters";
|
||||
}
|
||||
67
src/Gateway/StellaOps.Gateway.WebService/RoutingOptions.cs
Normal file
67
src/Gateway/StellaOps.Gateway.WebService/RoutingOptions.cs
Normal file
@@ -0,0 +1,67 @@
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Tie-breaker mode for routing when multiple instances have equal priority.
|
||||
/// </summary>
|
||||
public enum TieBreakerMode
|
||||
{
|
||||
/// <summary>
|
||||
/// Select randomly among tied instances.
|
||||
/// </summary>
|
||||
Random,
|
||||
|
||||
/// <summary>
|
||||
/// Rotate through tied instances in order.
|
||||
/// </summary>
|
||||
RoundRobin
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Options for routing behavior.
|
||||
/// </summary>
|
||||
public sealed class RoutingOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Configuration section name for binding.
|
||||
/// </summary>
|
||||
public const string SectionName = "Routing";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the default version to use when no version is specified in the request.
|
||||
/// If null, requests without version specification will match any available version.
|
||||
/// </summary>
|
||||
public string? DefaultVersion { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to enable strict version matching.
|
||||
/// When true, requests must specify an exact version.
|
||||
/// When false, requests can match compatible versions.
|
||||
/// </summary>
|
||||
public bool StrictVersionMatching { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the timeout for routing decisions in milliseconds.
|
||||
/// </summary>
|
||||
public int RoutingTimeoutMs { get; set; } = 30000;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to prefer local region instances over neighbor regions.
|
||||
/// </summary>
|
||||
public bool PreferLocalRegion { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to allow routing to degraded instances when no healthy instances are available.
|
||||
/// </summary>
|
||||
public bool AllowDegradedInstances { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the tie-breaker mode when multiple instances have equal priority.
|
||||
/// </summary>
|
||||
public TieBreakerMode TieBreaker { get; set; } = TieBreakerMode.Random;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the ping tolerance in milliseconds for considering instances "tied".
|
||||
/// Instances within this tolerance of each other are considered to have equal latency.
|
||||
/// </summary>
|
||||
public double PingToleranceMs { get; set; } = 0.1;
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Transport.InMemory;
|
||||
|
||||
namespace StellaOps.Gateway.WebService;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering gateway routing services.
|
||||
/// </summary>
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds gateway routing services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configuration">The configuration.</param>
|
||||
/// <returns>The service collection for chaining.</returns>
|
||||
public static IServiceCollection AddGatewayRouting(
|
||||
this IServiceCollection services,
|
||||
IConfiguration configuration)
|
||||
{
|
||||
// Bind configuration options
|
||||
services.Configure<GatewayNodeConfig>(
|
||||
configuration.GetSection(GatewayNodeConfig.SectionName));
|
||||
services.Configure<RoutingOptions>(
|
||||
configuration.GetSection(RoutingOptions.SectionName));
|
||||
services.Configure<HealthOptions>(
|
||||
configuration.GetSection(HealthOptions.SectionName));
|
||||
|
||||
// Register routing state as singleton (shared across all requests)
|
||||
services.AddSingleton<IGlobalRoutingState, InMemoryRoutingState>();
|
||||
|
||||
// Register routing plugin
|
||||
services.AddSingleton<IRoutingPlugin, DefaultRoutingPlugin>();
|
||||
|
||||
// Register InMemory transport (for development/testing)
|
||||
services.AddInMemoryTransport();
|
||||
|
||||
// Register connection manager as hosted service
|
||||
services.AddHostedService<ConnectionManager>();
|
||||
|
||||
// Register health monitor as hosted service
|
||||
services.AddHostedService<HealthMonitorService>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds gateway routing services with custom options.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configureGateway">Action to configure gateway node options.</param>
|
||||
/// <param name="configureRouting">Action to configure routing options.</param>
|
||||
/// <returns>The service collection for chaining.</returns>
|
||||
public static IServiceCollection AddGatewayRouting(
|
||||
this IServiceCollection services,
|
||||
Action<GatewayNodeConfig>? configureGateway = null,
|
||||
Action<RoutingOptions>? configureRouting = null)
|
||||
{
|
||||
// Ensure default options are registered even if no configuration action provided
|
||||
services.AddOptions<GatewayNodeConfig>();
|
||||
services.AddOptions<RoutingOptions>();
|
||||
|
||||
// Configure options via actions
|
||||
if (configureGateway is not null)
|
||||
{
|
||||
services.Configure(configureGateway);
|
||||
}
|
||||
|
||||
if (configureRouting is not null)
|
||||
{
|
||||
services.Configure(configureRouting);
|
||||
}
|
||||
|
||||
// Register routing state as singleton (shared across all requests)
|
||||
services.AddSingleton<IGlobalRoutingState, InMemoryRoutingState>();
|
||||
|
||||
// Register routing plugin
|
||||
services.AddSingleton<IRoutingPlugin, DefaultRoutingPlugin>();
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -9,5 +9,9 @@
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\__Libraries\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
<ProjectReference Include="..\..\__Libraries\StellaOps.Router.Config\StellaOps.Router.Config.csproj" />
|
||||
<ProjectReference Include="..\..\__Libraries\StellaOps.Router.Transport.InMemory\StellaOps.Router.Transport.InMemory.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<InternalsVisibleTo Include="StellaOps.Gateway.WebService.Tests" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Microservice;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.InMemory;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
public class CancellationTests
|
||||
{
|
||||
private readonly InMemoryConnectionRegistry _registry = new();
|
||||
private readonly InMemoryTransportOptions _options = new() { SimulatedLatency = TimeSpan.Zero };
|
||||
|
||||
private InMemoryTransportClient CreateClient()
|
||||
{
|
||||
return new InMemoryTransportClient(
|
||||
_registry,
|
||||
Options.Create(_options),
|
||||
NullLogger<InMemoryTransportClient>.Instance);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CancelReasons_HasAllExpectedConstants()
|
||||
{
|
||||
Assert.Equal("ClientDisconnected", CancelReasons.ClientDisconnected);
|
||||
Assert.Equal("Timeout", CancelReasons.Timeout);
|
||||
Assert.Equal("PayloadLimitExceeded", CancelReasons.PayloadLimitExceeded);
|
||||
Assert.Equal("Shutdown", CancelReasons.Shutdown);
|
||||
Assert.Equal("ConnectionClosed", CancelReasons.ConnectionClosed);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ConnectAsync_RegistersWithRegistry()
|
||||
{
|
||||
// Arrange
|
||||
using var client = CreateClient();
|
||||
var instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = "test-instance",
|
||||
ServiceName = "test-service",
|
||||
Version = "1.0.0",
|
||||
Region = "us-east-1"
|
||||
};
|
||||
|
||||
// Act
|
||||
await client.ConnectAsync(instance, [], CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
var connectionIdField = client.GetType()
|
||||
.GetField("_connectionId", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
|
||||
var connectionId = connectionIdField?.GetValue(client)?.ToString();
|
||||
Assert.NotNull(connectionId);
|
||||
|
||||
var channel = _registry.GetChannel(connectionId!);
|
||||
Assert.NotNull(channel);
|
||||
Assert.Equal(instance.InstanceId, channel!.Instance?.InstanceId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CancelAllInflight_DoesNotThrowWhenEmpty()
|
||||
{
|
||||
// Arrange
|
||||
using var client = CreateClient();
|
||||
|
||||
// Act & Assert - should not throw
|
||||
client.CancelAllInflight(CancelReasons.Shutdown);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Dispose_DoesNotThrow()
|
||||
{
|
||||
// Arrange
|
||||
var client = CreateClient();
|
||||
|
||||
// Act & Assert - should not throw
|
||||
client.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DisconnectAsync_CancelsAllInflightWithShutdownReason()
|
||||
{
|
||||
// Arrange
|
||||
using var client = CreateClient();
|
||||
var instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = "test-instance",
|
||||
ServiceName = "test-service",
|
||||
Version = "1.0.0",
|
||||
Region = "us-east-1"
|
||||
};
|
||||
|
||||
await client.ConnectAsync(instance, [], CancellationToken.None);
|
||||
|
||||
// Act
|
||||
await client.DisconnectAsync();
|
||||
|
||||
// Assert - no exception means success
|
||||
}
|
||||
}
|
||||
|
||||
public class InflightRequestTrackerTests
|
||||
{
|
||||
[Fact]
|
||||
public void Track_ReturnsCancellationToken()
|
||||
{
|
||||
// Arrange
|
||||
using var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
var correlationId = Guid.NewGuid();
|
||||
|
||||
// Act
|
||||
var token = tracker.Track(correlationId);
|
||||
|
||||
// Assert
|
||||
Assert.False(token.IsCancellationRequested);
|
||||
Assert.Equal(1, tracker.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Track_ThrowsIfAlreadyTracked()
|
||||
{
|
||||
// Arrange
|
||||
using var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
var correlationId = Guid.NewGuid();
|
||||
tracker.Track(correlationId);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<InvalidOperationException>(() => tracker.Track(correlationId));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Cancel_TriggersCancellationToken()
|
||||
{
|
||||
// Arrange
|
||||
using var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
var correlationId = Guid.NewGuid();
|
||||
var token = tracker.Track(correlationId);
|
||||
|
||||
// Act
|
||||
var result = tracker.Cancel(correlationId, "TestReason");
|
||||
|
||||
// Assert
|
||||
Assert.True(result);
|
||||
Assert.True(token.IsCancellationRequested);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Cancel_ReturnsFalseForUnknownRequest()
|
||||
{
|
||||
// Arrange
|
||||
using var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
var correlationId = Guid.NewGuid();
|
||||
|
||||
// Act
|
||||
var result = tracker.Cancel(correlationId, "TestReason");
|
||||
|
||||
// Assert
|
||||
Assert.False(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Complete_RemovesFromTracking()
|
||||
{
|
||||
// Arrange
|
||||
using var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
var correlationId = Guid.NewGuid();
|
||||
tracker.Track(correlationId);
|
||||
Assert.Equal(1, tracker.Count);
|
||||
|
||||
// Act
|
||||
tracker.Complete(correlationId);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(0, tracker.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CancelAll_CancelsAllTrackedRequests()
|
||||
{
|
||||
// Arrange
|
||||
using var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
|
||||
var tokens = new List<CancellationToken>();
|
||||
for (var i = 0; i < 5; i++)
|
||||
{
|
||||
tokens.Add(tracker.Track(Guid.NewGuid()));
|
||||
}
|
||||
|
||||
// Act
|
||||
tracker.CancelAll("TestReason");
|
||||
|
||||
// Assert
|
||||
Assert.All(tokens, t => Assert.True(t.IsCancellationRequested));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Dispose_CancelsAllTrackedRequests()
|
||||
{
|
||||
// Arrange
|
||||
var tracker = new InflightRequestTracker(
|
||||
NullLogger<InflightRequestTracker>.Instance);
|
||||
|
||||
var tokens = new List<CancellationToken>();
|
||||
for (var i = 0; i < 3; i++)
|
||||
{
|
||||
tokens.Add(tracker.Track(Guid.NewGuid()));
|
||||
}
|
||||
|
||||
// Act
|
||||
tracker.Dispose();
|
||||
|
||||
// Assert
|
||||
Assert.All(tokens, t => Assert.True(t.IsCancellationRequested));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
using FluentAssertions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Moq;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.InMemory;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Integration-style tests for <see cref="ConnectionManager"/>.
|
||||
/// Uses real InMemoryTransportServer since it's a sealed class.
|
||||
/// </summary>
|
||||
public sealed class ConnectionManagerTests : IAsyncLifetime
|
||||
{
|
||||
private readonly InMemoryConnectionRegistry _connectionRegistry;
|
||||
private readonly InMemoryTransportServer _transportServer;
|
||||
private readonly Mock<IGlobalRoutingState> _routingStateMock;
|
||||
private readonly ConnectionManager _manager;
|
||||
|
||||
public ConnectionManagerTests()
|
||||
{
|
||||
_connectionRegistry = new InMemoryConnectionRegistry();
|
||||
|
||||
var options = Options.Create(new InMemoryTransportOptions());
|
||||
_transportServer = new InMemoryTransportServer(
|
||||
_connectionRegistry,
|
||||
options,
|
||||
NullLogger<InMemoryTransportServer>.Instance);
|
||||
|
||||
_routingStateMock = new Mock<IGlobalRoutingState>(MockBehavior.Loose);
|
||||
|
||||
_manager = new ConnectionManager(
|
||||
_transportServer,
|
||||
_connectionRegistry,
|
||||
_routingStateMock.Object,
|
||||
NullLogger<ConnectionManager>.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
await _manager.StartAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _manager.StopAsync(CancellationToken.None);
|
||||
_transportServer.Dispose();
|
||||
}
|
||||
|
||||
#region StartAsync/StopAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task StartAsync_ShouldStartSuccessfully()
|
||||
{
|
||||
// The manager starts in InitializeAsync
|
||||
// Just verify it can be started without exception
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StopAsync_ShouldStopSuccessfully()
|
||||
{
|
||||
// This is tested in DisposeAsync
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Connection Registration Tests via Channel Simulation
|
||||
|
||||
[Fact]
|
||||
public async Task WhenHelloReceived_AddsConnectionToRoutingState()
|
||||
{
|
||||
// Arrange
|
||||
var channel = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
|
||||
|
||||
// Simulate sending a HELLO frame through the channel
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
};
|
||||
|
||||
// Act
|
||||
await channel.ToGateway.Writer.WriteAsync(helloFrame);
|
||||
|
||||
// Give time for the frame to be processed
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert
|
||||
_routingStateMock.Verify(
|
||||
s => s.AddConnection(It.Is<ConnectionState>(c => c.ConnectionId == "conn-1")),
|
||||
Times.Once);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WhenHeartbeatReceived_UpdatesConnectionState()
|
||||
{
|
||||
// Arrange
|
||||
var channel = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
|
||||
|
||||
// First send HELLO to register the connection
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
};
|
||||
await channel.ToGateway.Writer.WriteAsync(helloFrame);
|
||||
await Task.Delay(100);
|
||||
|
||||
// Act - send heartbeat
|
||||
var heartbeatFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Heartbeat,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
};
|
||||
await channel.ToGateway.Writer.WriteAsync(heartbeatFrame);
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert
|
||||
_routingStateMock.Verify(
|
||||
s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()),
|
||||
Times.AtLeastOnce);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WhenConnectionClosed_RemovesConnectionFromRoutingState()
|
||||
{
|
||||
// Arrange
|
||||
var channel = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
|
||||
|
||||
// First send HELLO to register the connection
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
};
|
||||
await channel.ToGateway.Writer.WriteAsync(helloFrame);
|
||||
await Task.Delay(100);
|
||||
|
||||
// Act - close the channel
|
||||
await channel.LifetimeToken.CancelAsync();
|
||||
|
||||
// Give time for the close to be processed
|
||||
await Task.Delay(200);
|
||||
|
||||
// Assert - may be called multiple times (on close and on stop)
|
||||
_routingStateMock.Verify(
|
||||
s => s.RemoveConnection("conn-1"),
|
||||
Times.AtLeastOnce);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WhenMultipleConnectionsRegister_AllAreTracked()
|
||||
{
|
||||
// Arrange
|
||||
var channel1 = CreateAndRegisterChannel("conn-1", "service-a", "1.0.0");
|
||||
var channel2 = CreateAndRegisterChannel("conn-2", "service-b", "2.0.0");
|
||||
|
||||
// Act - send HELLO frames
|
||||
await channel1.ToGateway.Writer.WriteAsync(new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
});
|
||||
await channel2.ToGateway.Writer.WriteAsync(new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
});
|
||||
await Task.Delay(150);
|
||||
|
||||
// Assert
|
||||
_routingStateMock.Verify(
|
||||
s => s.AddConnection(It.Is<ConnectionState>(c => c.ConnectionId == "conn-1")),
|
||||
Times.Once);
|
||||
_routingStateMock.Verify(
|
||||
s => s.AddConnection(It.Is<ConnectionState>(c => c.ConnectionId == "conn-2")),
|
||||
Times.Once);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private InMemoryChannel CreateAndRegisterChannel(
|
||||
string connectionId, string serviceName, string version)
|
||||
{
|
||||
var instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = $"{serviceName}-{Guid.NewGuid():N}",
|
||||
ServiceName = serviceName,
|
||||
Version = version,
|
||||
Region = "us-east-1"
|
||||
};
|
||||
|
||||
// Create channel through the registry
|
||||
var channel = _connectionRegistry.CreateChannel(connectionId);
|
||||
channel.Instance = instance;
|
||||
|
||||
// Simulate that the transport server is listening to this connection
|
||||
_transportServer.StartListeningToConnection(connectionId);
|
||||
|
||||
return channel;
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -0,0 +1,538 @@
|
||||
using FluentAssertions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
public class DefaultRoutingPluginTests
|
||||
{
|
||||
private readonly RoutingOptions _options = new()
|
||||
{
|
||||
DefaultVersion = null,
|
||||
StrictVersionMatching = true,
|
||||
RoutingTimeoutMs = 30000,
|
||||
PreferLocalRegion = true,
|
||||
AllowDegradedInstances = true,
|
||||
TieBreaker = TieBreakerMode.Random,
|
||||
PingToleranceMs = 0.1
|
||||
};
|
||||
|
||||
private readonly GatewayNodeConfig _gatewayConfig = new()
|
||||
{
|
||||
Region = "us-east-1",
|
||||
NodeId = "gw-test-01",
|
||||
Environment = "test",
|
||||
NeighborRegions = ["eu-west-1", "us-west-2"]
|
||||
};
|
||||
|
||||
private DefaultRoutingPlugin CreateSut(
|
||||
Action<RoutingOptions>? configureOptions = null,
|
||||
Action<GatewayNodeConfig>? configureGateway = null)
|
||||
{
|
||||
configureOptions?.Invoke(_options);
|
||||
configureGateway?.Invoke(_gatewayConfig);
|
||||
return new DefaultRoutingPlugin(
|
||||
Options.Create(_options),
|
||||
Options.Create(_gatewayConfig));
|
||||
}
|
||||
|
||||
private static ConnectionState CreateConnection(
|
||||
string connectionId = "conn-1",
|
||||
string serviceName = "test-service",
|
||||
string version = "1.0.0",
|
||||
string region = "us-east-1",
|
||||
InstanceHealthStatus status = InstanceHealthStatus.Healthy,
|
||||
double averagePingMs = 0,
|
||||
DateTime? lastHeartbeatUtc = null)
|
||||
{
|
||||
return new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = $"inst-{connectionId}",
|
||||
ServiceName = serviceName,
|
||||
Version = version,
|
||||
Region = region
|
||||
},
|
||||
Status = status,
|
||||
TransportType = TransportType.InMemory,
|
||||
AveragePingMs = averagePingMs,
|
||||
LastHeartbeatUtc = lastHeartbeatUtc ?? DateTime.UtcNow
|
||||
};
|
||||
}
|
||||
|
||||
private static EndpointDescriptor CreateEndpoint(
|
||||
string method = "GET",
|
||||
string path = "/api/test",
|
||||
string serviceName = "test-service",
|
||||
string version = "1.0.0")
|
||||
{
|
||||
return new EndpointDescriptor
|
||||
{
|
||||
Method = method,
|
||||
Path = path,
|
||||
ServiceName = serviceName,
|
||||
Version = version
|
||||
};
|
||||
}
|
||||
|
||||
private static RoutingContext CreateContext(
|
||||
string method = "GET",
|
||||
string path = "/api/test",
|
||||
string gatewayRegion = "us-east-1",
|
||||
string? requestedVersion = null,
|
||||
EndpointDescriptor? endpoint = null,
|
||||
params ConnectionState[] connections)
|
||||
{
|
||||
return new RoutingContext
|
||||
{
|
||||
Method = method,
|
||||
Path = path,
|
||||
GatewayRegion = gatewayRegion,
|
||||
RequestedVersion = requestedVersion,
|
||||
Endpoint = endpoint ?? CreateEndpoint(),
|
||||
AvailableConnections = connections
|
||||
};
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenNoConnections()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var context = CreateContext();
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenNoEndpoint()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var connection = CreateConnection();
|
||||
var context = new RoutingContext
|
||||
{
|
||||
Method = "GET",
|
||||
Path = "/api/test",
|
||||
GatewayRegion = "us-east-1",
|
||||
Endpoint = null,
|
||||
AvailableConnections = [connection]
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldSelectHealthyConnection()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var connection = CreateConnection(status: InstanceHealthStatus.Healthy);
|
||||
var context = CreateContext(connections: [connection]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Should().BeSameAs(connection);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldPreferHealthyOverDegraded()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var degraded = CreateConnection("conn-1", status: InstanceHealthStatus.Degraded);
|
||||
var healthy = CreateConnection("conn-2", status: InstanceHealthStatus.Healthy);
|
||||
var context = CreateContext(connections: [degraded, healthy]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Status.Should().Be(InstanceHealthStatus.Healthy);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldSelectDegraded_WhenNoHealthyAndAllowed()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.AllowDegradedInstances = true);
|
||||
var degraded = CreateConnection(status: InstanceHealthStatus.Degraded);
|
||||
var context = CreateContext(connections: [degraded]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Status.Should().Be(InstanceHealthStatus.Degraded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenOnlyDegradedAndNotAllowed()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.AllowDegradedInstances = false);
|
||||
var degraded = CreateConnection(status: InstanceHealthStatus.Degraded);
|
||||
var context = CreateContext(connections: [degraded]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldExcludeUnhealthy()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var unhealthy = CreateConnection("conn-1", status: InstanceHealthStatus.Unhealthy);
|
||||
var healthy = CreateConnection("conn-2", status: InstanceHealthStatus.Healthy);
|
||||
var context = CreateContext(connections: [unhealthy, healthy]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.ConnectionId.Should().Be("conn-2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldExcludeDraining()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var draining = CreateConnection("conn-1", status: InstanceHealthStatus.Draining);
|
||||
var healthy = CreateConnection("conn-2", status: InstanceHealthStatus.Healthy);
|
||||
var context = CreateContext(connections: [draining, healthy]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.ConnectionId.Should().Be("conn-2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldFilterByRequestedVersion()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var v1 = CreateConnection("conn-1", version: "1.0.0");
|
||||
var v2 = CreateConnection("conn-2", version: "2.0.0");
|
||||
var context = CreateContext(requestedVersion: "2.0.0", connections: [v1, v2]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Instance.Version.Should().Be("2.0.0");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldUseDefaultVersion_WhenNoRequestedVersion()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.DefaultVersion = "1.0.0");
|
||||
var v1 = CreateConnection("conn-1", version: "1.0.0");
|
||||
var v2 = CreateConnection("conn-2", version: "2.0.0");
|
||||
var context = CreateContext(requestedVersion: null, connections: [v1, v2]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Instance.Version.Should().Be("1.0.0");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldReturnNull_WhenNoMatchingVersion()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var v1 = CreateConnection("conn-1", version: "1.0.0");
|
||||
var context = CreateContext(requestedVersion: "2.0.0", connections: [v1]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldMatchAnyVersion_WhenNoVersionSpecified()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.DefaultVersion = null);
|
||||
var v1 = CreateConnection("conn-1", version: "1.0.0");
|
||||
var v2 = CreateConnection("conn-2", version: "2.0.0");
|
||||
var context = CreateContext(requestedVersion: null, connections: [v1, v2]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldPreferLocalRegion()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.PreferLocalRegion = true);
|
||||
var remote = CreateConnection("conn-1", region: "us-west-2");
|
||||
var local = CreateConnection("conn-2", region: "us-east-1");
|
||||
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remote, local]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Instance.Region.Should().Be("us-east-1");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldAllowRemoteRegion_WhenNoLocalAvailable()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.PreferLocalRegion = true);
|
||||
var remote = CreateConnection("conn-1", region: "us-west-2");
|
||||
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remote]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Instance.Region.Should().Be("us-west-2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldIgnoreRegionPreference_WhenDisabled()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.PreferLocalRegion = false);
|
||||
// Create connections with same ping and heartbeat so they are tied
|
||||
var sameHeartbeat = DateTime.UtcNow;
|
||||
var remote = CreateConnection("conn-1", region: "us-west-2", lastHeartbeatUtc: sameHeartbeat);
|
||||
var local = CreateConnection("conn-2", region: "us-east-1", lastHeartbeatUtc: sameHeartbeat);
|
||||
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remote, local]);
|
||||
|
||||
// Act - run multiple times to verify random selection includes both
|
||||
var selectedRegions = new HashSet<string>();
|
||||
for (int i = 0; i < 50; i++)
|
||||
{
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
selectedRegions.Add(result!.Connection.Instance.Region);
|
||||
}
|
||||
|
||||
// Assert - with random selection, we should see both regions selected
|
||||
// Note: This is probabilistic but should almost always pass
|
||||
selectedRegions.Should().Contain("us-west-2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldSetCorrectTimeout()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.RoutingTimeoutMs = 5000);
|
||||
var connection = CreateConnection();
|
||||
var context = CreateContext(connections: [connection]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.EffectiveTimeout.Should().Be(TimeSpan.FromMilliseconds(5000));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldSetCorrectTransportType()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var connection = CreateConnection();
|
||||
var context = CreateContext(connections: [connection]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.TransportType.Should().Be(TransportType.InMemory);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldReturnEndpointFromContext()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var endpoint = CreateEndpoint(path: "/api/special");
|
||||
var connection = CreateConnection();
|
||||
var context = CreateContext(endpoint: endpoint, connections: [connection]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Endpoint.Path.Should().Be("/api/special");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldDistributeLoadAcrossMultipleConnections()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
// Create connections with same ping and heartbeat so they are tied
|
||||
var sameHeartbeat = DateTime.UtcNow;
|
||||
var conn1 = CreateConnection("conn-1", lastHeartbeatUtc: sameHeartbeat);
|
||||
var conn2 = CreateConnection("conn-2", lastHeartbeatUtc: sameHeartbeat);
|
||||
var conn3 = CreateConnection("conn-3", lastHeartbeatUtc: sameHeartbeat);
|
||||
var context = CreateContext(connections: [conn1, conn2, conn3]);
|
||||
|
||||
// Act - run multiple times
|
||||
var selectedConnections = new Dictionary<string, int>();
|
||||
for (int i = 0; i < 100; i++)
|
||||
{
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
var connId = result!.Connection.ConnectionId;
|
||||
selectedConnections[connId] = selectedConnections.GetValueOrDefault(connId) + 1;
|
||||
}
|
||||
|
||||
// Assert - all connections should be selected at least once (probabilistic with random tie-breaker)
|
||||
selectedConnections.Should().HaveCount(3);
|
||||
selectedConnections.Keys.Should().Contain("conn-1");
|
||||
selectedConnections.Keys.Should().Contain("conn-2");
|
||||
selectedConnections.Keys.Should().Contain("conn-3");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldPreferLowerPing()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var sameHeartbeat = DateTime.UtcNow;
|
||||
var highPing = CreateConnection("conn-1", averagePingMs: 100, lastHeartbeatUtc: sameHeartbeat);
|
||||
var lowPing = CreateConnection("conn-2", averagePingMs: 10, lastHeartbeatUtc: sameHeartbeat);
|
||||
var context = CreateContext(connections: [highPing, lowPing]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert - lower ping should be preferred
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.ConnectionId.Should().Be("conn-2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldPreferMoreRecentHeartbeat_WhenPingEqual()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut();
|
||||
var now = DateTime.UtcNow;
|
||||
var oldHeartbeat = CreateConnection("conn-1", averagePingMs: 10, lastHeartbeatUtc: now.AddSeconds(-30));
|
||||
var recentHeartbeat = CreateConnection("conn-2", averagePingMs: 10, lastHeartbeatUtc: now);
|
||||
var context = CreateContext(connections: [oldHeartbeat, recentHeartbeat]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert - more recent heartbeat should be preferred
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.ConnectionId.Should().Be("conn-2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldPreferNeighborRegionOverRemote()
|
||||
{
|
||||
// Arrange - gateway config has NeighborRegions = ["eu-west-1", "us-west-2"]
|
||||
var sut = CreateSut();
|
||||
var sameHeartbeat = DateTime.UtcNow;
|
||||
var remoteRegion = CreateConnection("conn-1", region: "ap-south-1", lastHeartbeatUtc: sameHeartbeat);
|
||||
var neighborRegion = CreateConnection("conn-2", region: "eu-west-1", lastHeartbeatUtc: sameHeartbeat);
|
||||
var context = CreateContext(gatewayRegion: "us-east-1", connections: [remoteRegion, neighborRegion]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert - neighbor region should be preferred over remote
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.Instance.Region.Should().Be("eu-west-1");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldUseRoundRobin_WhenConfigured()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o => o.TieBreaker = TieBreakerMode.RoundRobin);
|
||||
var sameHeartbeat = DateTime.UtcNow;
|
||||
var conn1 = CreateConnection("conn-1", lastHeartbeatUtc: sameHeartbeat);
|
||||
var conn2 = CreateConnection("conn-2", lastHeartbeatUtc: sameHeartbeat);
|
||||
var context = CreateContext(connections: [conn1, conn2]);
|
||||
|
||||
// Act - with round-robin, we should cycle through connections
|
||||
var selections = new List<string>();
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
selections.Add(result!.Connection.ConnectionId);
|
||||
}
|
||||
|
||||
// Assert - should alternate between connections
|
||||
selections.Distinct().Count().Should().Be(2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChooseInstanceAsync_ShouldCombineFilters()
|
||||
{
|
||||
// Arrange
|
||||
var sut = CreateSut(configureOptions: o =>
|
||||
{
|
||||
o.PreferLocalRegion = true;
|
||||
o.AllowDegradedInstances = false;
|
||||
});
|
||||
|
||||
// Create various combinations
|
||||
var wrongVersionHealthyLocal = CreateConnection("conn-1", version: "2.0.0", region: "us-east-1", status: InstanceHealthStatus.Healthy);
|
||||
var rightVersionDegradedLocal = CreateConnection("conn-2", version: "1.0.0", region: "us-east-1", status: InstanceHealthStatus.Degraded);
|
||||
var rightVersionHealthyRemote = CreateConnection("conn-3", version: "1.0.0", region: "us-west-2", status: InstanceHealthStatus.Healthy);
|
||||
var rightVersionHealthyLocal = CreateConnection("conn-4", version: "1.0.0", region: "us-east-1", status: InstanceHealthStatus.Healthy);
|
||||
|
||||
var context = CreateContext(
|
||||
gatewayRegion: "us-east-1",
|
||||
requestedVersion: "1.0.0",
|
||||
connections: [wrongVersionHealthyLocal, rightVersionDegradedLocal, rightVersionHealthyRemote, rightVersionHealthyLocal]);
|
||||
|
||||
// Act
|
||||
var result = await sut.ChooseInstanceAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert - should select the only connection matching all criteria
|
||||
result.Should().NotBeNull();
|
||||
result!.Connection.ConnectionId.Should().Be("conn-4");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
using FluentAssertions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Moq;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Tests for <see cref="HealthMonitorService"/>.
|
||||
/// </summary>
|
||||
public sealed class HealthMonitorServiceTests
|
||||
{
|
||||
private readonly Mock<IGlobalRoutingState> _routingStateMock;
|
||||
private readonly HealthOptions _options;
|
||||
|
||||
public HealthMonitorServiceTests()
|
||||
{
|
||||
_routingStateMock = new Mock<IGlobalRoutingState>(MockBehavior.Loose);
|
||||
_options = new HealthOptions
|
||||
{
|
||||
StaleThreshold = TimeSpan.FromSeconds(10),
|
||||
DegradedThreshold = TimeSpan.FromSeconds(5),
|
||||
CheckInterval = TimeSpan.FromMilliseconds(100)
|
||||
};
|
||||
}
|
||||
|
||||
private HealthMonitorService CreateService()
|
||||
{
|
||||
return new HealthMonitorService(
|
||||
_routingStateMock.Object,
|
||||
Options.Create(_options),
|
||||
NullLogger<HealthMonitorService>.Instance);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_MarksStaleConnectionsUnhealthy()
|
||||
{
|
||||
// Arrange
|
||||
var staleConnection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
staleConnection.Status = InstanceHealthStatus.Healthy;
|
||||
staleConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-15); // Past stale threshold
|
||||
|
||||
_routingStateMock.Setup(s => s.GetAllConnections())
|
||||
.Returns([staleConnection]);
|
||||
|
||||
var service = CreateService();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(500));
|
||||
|
||||
// Act
|
||||
try
|
||||
{
|
||||
await service.StartAsync(cts.Token);
|
||||
await Task.Delay(200, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
finally
|
||||
{
|
||||
await service.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
// Assert
|
||||
_routingStateMock.Verify(
|
||||
s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()),
|
||||
Times.AtLeastOnce);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_MarksDegradedConnectionsDegraded()
|
||||
{
|
||||
// Arrange
|
||||
var degradedConnection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
degradedConnection.Status = InstanceHealthStatus.Healthy;
|
||||
degradedConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-7); // Past degraded but not stale
|
||||
|
||||
_routingStateMock.Setup(s => s.GetAllConnections())
|
||||
.Returns([degradedConnection]);
|
||||
|
||||
var service = CreateService();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1));
|
||||
|
||||
// Act
|
||||
try
|
||||
{
|
||||
await service.StartAsync(cts.Token);
|
||||
// Wait enough time for at least one check cycle (CheckInterval is 100ms)
|
||||
await Task.Delay(300, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
finally
|
||||
{
|
||||
await service.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
// Assert
|
||||
_routingStateMock.Verify(
|
||||
s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()),
|
||||
Times.AtLeastOnce);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_DoesNotChangeHealthyConnections()
|
||||
{
|
||||
// Arrange
|
||||
var healthyConnection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
healthyConnection.Status = InstanceHealthStatus.Healthy;
|
||||
healthyConnection.LastHeartbeatUtc = DateTime.UtcNow; // Fresh heartbeat
|
||||
|
||||
_routingStateMock.Setup(s => s.GetAllConnections())
|
||||
.Returns([healthyConnection]);
|
||||
|
||||
var service = CreateService();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
|
||||
|
||||
// Act
|
||||
try
|
||||
{
|
||||
await service.StartAsync(cts.Token);
|
||||
await Task.Delay(200, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
finally
|
||||
{
|
||||
await service.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
// Assert - should not have updated the connection
|
||||
_routingStateMock.Verify(
|
||||
s => s.UpdateConnection(It.IsAny<string>(), It.IsAny<Action<ConnectionState>>()),
|
||||
Times.Never);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_DoesNotChangeDrainingConnections()
|
||||
{
|
||||
// Arrange
|
||||
var drainingConnection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
drainingConnection.Status = InstanceHealthStatus.Draining;
|
||||
drainingConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-30); // Very stale
|
||||
|
||||
_routingStateMock.Setup(s => s.GetAllConnections())
|
||||
.Returns([drainingConnection]);
|
||||
|
||||
var service = CreateService();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
|
||||
|
||||
// Act
|
||||
try
|
||||
{
|
||||
await service.StartAsync(cts.Token);
|
||||
await Task.Delay(200, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
finally
|
||||
{
|
||||
await service.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
// Assert - draining connections should be left alone
|
||||
_routingStateMock.Verify(
|
||||
s => s.UpdateConnection(It.IsAny<string>(), It.IsAny<Action<ConnectionState>>()),
|
||||
Times.Never);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_DoesNotDoubleMarkUnhealthy()
|
||||
{
|
||||
// Arrange
|
||||
var unhealthyConnection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
unhealthyConnection.Status = InstanceHealthStatus.Unhealthy;
|
||||
unhealthyConnection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-30); // Very stale
|
||||
|
||||
_routingStateMock.Setup(s => s.GetAllConnections())
|
||||
.Returns([unhealthyConnection]);
|
||||
|
||||
var service = CreateService();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
|
||||
|
||||
// Act
|
||||
try
|
||||
{
|
||||
await service.StartAsync(cts.Token);
|
||||
await Task.Delay(200, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
finally
|
||||
{
|
||||
await service.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
// Assert - already unhealthy connections should not be updated
|
||||
_routingStateMock.Verify(
|
||||
s => s.UpdateConnection(It.IsAny<string>(), It.IsAny<Action<ConnectionState>>()),
|
||||
Times.Never);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UpdateAction_SetsStatusToUnhealthy()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
connection.Status = InstanceHealthStatus.Healthy;
|
||||
connection.LastHeartbeatUtc = DateTime.UtcNow.AddSeconds(-15);
|
||||
|
||||
Action<ConnectionState>? capturedAction = null;
|
||||
_routingStateMock.Setup(s => s.UpdateConnection("conn-1", It.IsAny<Action<ConnectionState>>()))
|
||||
.Callback<string, Action<ConnectionState>>((id, action) => capturedAction = action);
|
||||
_routingStateMock.Setup(s => s.GetAllConnections())
|
||||
.Returns([connection]);
|
||||
|
||||
var service = CreateService();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(300));
|
||||
|
||||
// Act - run the service briefly
|
||||
try
|
||||
{
|
||||
await service.StartAsync(cts.Token);
|
||||
await Task.Delay(200, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
finally
|
||||
{
|
||||
await service.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
// Assert
|
||||
capturedAction.Should().NotBeNull();
|
||||
|
||||
// Apply the action to verify it sets Unhealthy
|
||||
var testConnection = CreateConnection("conn-1", "service-a", "1.0.0");
|
||||
testConnection.Status = InstanceHealthStatus.Healthy;
|
||||
capturedAction!(testConnection);
|
||||
|
||||
testConnection.Status.Should().Be(InstanceHealthStatus.Unhealthy);
|
||||
}
|
||||
|
||||
private static ConnectionState CreateConnection(
|
||||
string connectionId, string serviceName, string version)
|
||||
{
|
||||
return new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = $"{serviceName}-{Guid.NewGuid():N}",
|
||||
ServiceName = serviceName,
|
||||
Version = version,
|
||||
Region = "us-east-1"
|
||||
},
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = TransportType.InMemory
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
using FluentAssertions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
public class InMemoryRoutingStateTests
|
||||
{
|
||||
private readonly InMemoryRoutingState _sut = new();
|
||||
|
||||
private static ConnectionState CreateConnection(
|
||||
string connectionId = "conn-1",
|
||||
string serviceName = "test-service",
|
||||
string version = "1.0.0",
|
||||
string region = "us-east-1",
|
||||
InstanceHealthStatus status = InstanceHealthStatus.Healthy,
|
||||
params (string Method, string Path)[] endpoints)
|
||||
{
|
||||
var connection = new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = $"inst-{connectionId}",
|
||||
ServiceName = serviceName,
|
||||
Version = version,
|
||||
Region = region
|
||||
},
|
||||
Status = status,
|
||||
TransportType = TransportType.InMemory
|
||||
};
|
||||
|
||||
foreach (var (method, path) in endpoints)
|
||||
{
|
||||
connection.Endpoints[(method, path)] = new EndpointDescriptor
|
||||
{
|
||||
Method = method,
|
||||
Path = path,
|
||||
ServiceName = serviceName,
|
||||
Version = version
|
||||
};
|
||||
}
|
||||
|
||||
return connection;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AddConnection_ShouldStoreConnection()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
|
||||
// Act
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Assert
|
||||
var result = _sut.GetConnection(connection.ConnectionId);
|
||||
result.Should().NotBeNull();
|
||||
result.Should().BeSameAs(connection);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AddConnection_ShouldIndexEndpoints()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/users/{id}")]);
|
||||
|
||||
// Act
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Assert
|
||||
var endpoint = _sut.ResolveEndpoint("GET", "/api/users/123");
|
||||
endpoint.Should().NotBeNull();
|
||||
endpoint!.Path.Should().Be("/api/users/{id}");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RemoveConnection_ShouldRemoveConnection()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
_sut.RemoveConnection(connection.ConnectionId);
|
||||
|
||||
// Assert
|
||||
var result = _sut.GetConnection(connection.ConnectionId);
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RemoveConnection_ShouldRemoveEndpointsWhenLastConnection()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
_sut.RemoveConnection(connection.ConnectionId);
|
||||
|
||||
// Assert
|
||||
var endpoint = _sut.ResolveEndpoint("GET", "/api/test");
|
||||
endpoint.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RemoveConnection_ShouldKeepEndpointsWhenOtherConnectionsExist()
|
||||
{
|
||||
// Arrange
|
||||
var connection1 = CreateConnection("conn-1", endpoints: [("GET", "/api/test")]);
|
||||
var connection2 = CreateConnection("conn-2", endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection1);
|
||||
_sut.AddConnection(connection2);
|
||||
|
||||
// Act
|
||||
_sut.RemoveConnection("conn-1");
|
||||
|
||||
// Assert
|
||||
var endpoint = _sut.ResolveEndpoint("GET", "/api/test");
|
||||
endpoint.Should().NotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void UpdateConnection_ShouldApplyUpdate()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
_sut.UpdateConnection(connection.ConnectionId, c => c.Status = InstanceHealthStatus.Degraded);
|
||||
|
||||
// Assert
|
||||
var result = _sut.GetConnection(connection.ConnectionId);
|
||||
result.Should().NotBeNull();
|
||||
result!.Status.Should().Be(InstanceHealthStatus.Degraded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void UpdateConnection_ShouldDoNothingForUnknownConnection()
|
||||
{
|
||||
// Act - should not throw
|
||||
_sut.UpdateConnection("unknown", c => c.Status = InstanceHealthStatus.Degraded);
|
||||
|
||||
// Assert
|
||||
var result = _sut.GetConnection("unknown");
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetConnection_ShouldReturnNullForUnknownConnection()
|
||||
{
|
||||
// Act
|
||||
var result = _sut.GetConnection("unknown");
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetAllConnections_ShouldReturnAllConnections()
|
||||
{
|
||||
// Arrange
|
||||
var connection1 = CreateConnection("conn-1", endpoints: [("GET", "/api/test1")]);
|
||||
var connection2 = CreateConnection("conn-2", endpoints: [("GET", "/api/test2")]);
|
||||
_sut.AddConnection(connection1);
|
||||
_sut.AddConnection(connection2);
|
||||
|
||||
// Act
|
||||
var result = _sut.GetAllConnections();
|
||||
|
||||
// Assert
|
||||
result.Should().HaveCount(2);
|
||||
result.Should().Contain(connection1);
|
||||
result.Should().Contain(connection2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetAllConnections_ShouldReturnEmptyWhenNoConnections()
|
||||
{
|
||||
// Act
|
||||
var result = _sut.GetAllConnections();
|
||||
|
||||
// Assert
|
||||
result.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ResolveEndpoint_ShouldMatchExactPath()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/health")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.ResolveEndpoint("GET", "/api/health");
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Path.Should().Be("/api/health");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ResolveEndpoint_ShouldMatchParameterizedPath()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/users/{id}/orders/{orderId}")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.ResolveEndpoint("GET", "/api/users/123/orders/456");
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result!.Path.Should().Be("/api/users/{id}/orders/{orderId}");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ResolveEndpoint_ShouldReturnNullForNonMatchingMethod()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.ResolveEndpoint("POST", "/api/test");
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ResolveEndpoint_ShouldReturnNullForNonMatchingPath()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.ResolveEndpoint("GET", "/api/other");
|
||||
|
||||
// Assert
|
||||
result.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ResolveEndpoint_ShouldBeCaseInsensitiveForMethod()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection(endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.ResolveEndpoint("get", "/api/test");
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetConnectionsFor_ShouldFilterByServiceName()
|
||||
{
|
||||
// Arrange
|
||||
var connection1 = CreateConnection("conn-1", "service-a", endpoints: [("GET", "/api/test")]);
|
||||
var connection2 = CreateConnection("conn-2", "service-b", endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection1);
|
||||
_sut.AddConnection(connection2);
|
||||
|
||||
// Act
|
||||
var result = _sut.GetConnectionsFor("service-a", "1.0.0", "GET", "/api/test");
|
||||
|
||||
// Assert
|
||||
result.Should().HaveCount(1);
|
||||
result[0].Instance.ServiceName.Should().Be("service-a");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetConnectionsFor_ShouldFilterByVersion()
|
||||
{
|
||||
// Arrange
|
||||
var connection1 = CreateConnection("conn-1", "service-a", "1.0.0", endpoints: [("GET", "/api/test")]);
|
||||
var connection2 = CreateConnection("conn-2", "service-a", "2.0.0", endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection1);
|
||||
_sut.AddConnection(connection2);
|
||||
|
||||
// Act
|
||||
var result = _sut.GetConnectionsFor("service-a", "1.0.0", "GET", "/api/test");
|
||||
|
||||
// Assert
|
||||
result.Should().HaveCount(1);
|
||||
result[0].Instance.Version.Should().Be("1.0.0");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetConnectionsFor_ShouldReturnEmptyWhenNoMatch()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection("conn-1", "service-a", endpoints: [("GET", "/api/test")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.GetConnectionsFor("service-b", "1.0.0", "GET", "/api/test");
|
||||
|
||||
// Assert
|
||||
result.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetConnectionsFor_ShouldMatchParameterizedPaths()
|
||||
{
|
||||
// Arrange
|
||||
var connection = CreateConnection("conn-1", "service-a", endpoints: [("GET", "/api/users/{id}")]);
|
||||
_sut.AddConnection(connection);
|
||||
|
||||
// Act
|
||||
var result = _sut.GetConnectionsFor("service-a", "1.0.0", "GET", "/api/users/123");
|
||||
|
||||
// Assert
|
||||
result.Should().HaveCount(1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Gateway.WebService.Middleware;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
public class PayloadTrackerTests
|
||||
{
|
||||
private readonly PayloadLimits _limits = new()
|
||||
{
|
||||
MaxRequestBytesPerCall = 1024,
|
||||
MaxRequestBytesPerConnection = 4096,
|
||||
MaxAggregateInflightBytes = 8192
|
||||
};
|
||||
|
||||
private PayloadTracker CreateTracker()
|
||||
{
|
||||
return new PayloadTracker(
|
||||
Options.Create(_limits),
|
||||
NullLogger<PayloadTracker>.Instance);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TryReserve_WithinLimits_ReturnsTrue()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
var result = tracker.TryReserve("conn-1", 500);
|
||||
|
||||
Assert.True(result);
|
||||
Assert.Equal(500, tracker.CurrentInflightBytes);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TryReserve_ExceedsAggregateLimits_ReturnsFalse()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
// Reserve from multiple connections to approach aggregate limit (8192)
|
||||
// Each connection can have up to 4096 bytes
|
||||
Assert.True(tracker.TryReserve("conn-1", 4000));
|
||||
Assert.True(tracker.TryReserve("conn-2", 4000));
|
||||
// Now at 8000 bytes
|
||||
|
||||
// Another reservation that exceeds aggregate limit (8000 + 500 > 8192) should fail
|
||||
var result = tracker.TryReserve("conn-3", 500);
|
||||
|
||||
Assert.False(result);
|
||||
Assert.Equal(8000, tracker.CurrentInflightBytes);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TryReserve_ExceedsPerConnectionLimit_ReturnsFalse()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
// Reserve up to per-connection limit
|
||||
Assert.True(tracker.TryReserve("conn-1", 4000));
|
||||
|
||||
// Next reservation on same connection should fail
|
||||
var result = tracker.TryReserve("conn-1", 500);
|
||||
|
||||
Assert.False(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TryReserve_DifferentConnections_TrackedSeparately()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
Assert.True(tracker.TryReserve("conn-1", 3000));
|
||||
Assert.True(tracker.TryReserve("conn-2", 3000));
|
||||
|
||||
Assert.Equal(3000, tracker.GetConnectionInflightBytes("conn-1"));
|
||||
Assert.Equal(3000, tracker.GetConnectionInflightBytes("conn-2"));
|
||||
Assert.Equal(6000, tracker.CurrentInflightBytes);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Release_DecreasesInflightBytes()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
tracker.TryReserve("conn-1", 1000);
|
||||
tracker.Release("conn-1", 500);
|
||||
|
||||
Assert.Equal(500, tracker.CurrentInflightBytes);
|
||||
Assert.Equal(500, tracker.GetConnectionInflightBytes("conn-1"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Release_CannotGoNegative()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
tracker.TryReserve("conn-1", 100);
|
||||
tracker.Release("conn-1", 500); // More than reserved
|
||||
|
||||
Assert.Equal(0, tracker.GetConnectionInflightBytes("conn-1"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsOverloaded_TrueWhenExceedsLimit()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
// Reservation at limit passes (8192 <= 8192 is false for >, so not overloaded at exactly limit)
|
||||
// But we can't exceed the limit. The IsOverloaded check is for current > limit
|
||||
// So at exactly 8192, IsOverloaded should be false (8192 > 8192 is false)
|
||||
// Reserving 8193 would be rejected. So let's test that at limit, IsOverloaded is false
|
||||
tracker.TryReserve("conn-1", 8192);
|
||||
|
||||
// At exactly the limit, IsOverloaded is false (8192 > 8192 = false)
|
||||
Assert.False(tracker.IsOverloaded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsOverloaded_FalseWhenWithinLimit()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
tracker.TryReserve("conn-1", 4000);
|
||||
|
||||
Assert.False(tracker.IsOverloaded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetConnectionInflightBytes_ReturnsZeroForUnknownConnection()
|
||||
{
|
||||
var tracker = CreateTracker();
|
||||
|
||||
var result = tracker.GetConnectionInflightBytes("unknown");
|
||||
|
||||
Assert.Equal(0, result);
|
||||
}
|
||||
}
|
||||
|
||||
public class ByteCountingStreamTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ReadAsync_CountsBytesRead()
|
||||
{
|
||||
var data = new byte[] { 1, 2, 3, 4, 5 };
|
||||
using var inner = new MemoryStream(data);
|
||||
using var stream = new ByteCountingStream(inner, 100);
|
||||
|
||||
var buffer = new byte[10];
|
||||
var read = await stream.ReadAsync(buffer);
|
||||
|
||||
Assert.Equal(5, read);
|
||||
Assert.Equal(5, stream.BytesRead);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_ThrowsWhenLimitExceeded()
|
||||
{
|
||||
var data = new byte[100];
|
||||
using var inner = new MemoryStream(data);
|
||||
using var stream = new ByteCountingStream(inner, 50);
|
||||
|
||||
var buffer = new byte[100];
|
||||
|
||||
var ex = await Assert.ThrowsAsync<PayloadLimitExceededException>(
|
||||
() => stream.ReadAsync(buffer).AsTask());
|
||||
|
||||
Assert.Equal(100, ex.BytesRead);
|
||||
Assert.Equal(50, ex.Limit);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_CallsCallbackOnLimitExceeded()
|
||||
{
|
||||
var data = new byte[100];
|
||||
using var inner = new MemoryStream(data);
|
||||
var callbackCalled = false;
|
||||
using var stream = new ByteCountingStream(inner, 50, () => callbackCalled = true);
|
||||
|
||||
var buffer = new byte[100];
|
||||
|
||||
await Assert.ThrowsAsync<PayloadLimitExceededException>(
|
||||
() => stream.ReadAsync(buffer).AsTask());
|
||||
|
||||
Assert.True(callbackCalled);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_AccumulatesAcrossMultipleReads()
|
||||
{
|
||||
var data = new byte[100];
|
||||
using var inner = new MemoryStream(data);
|
||||
using var stream = new ByteCountingStream(inner, 60);
|
||||
|
||||
var buffer = new byte[30];
|
||||
|
||||
// First read - 30 bytes
|
||||
var read1 = await stream.ReadAsync(buffer);
|
||||
Assert.Equal(30, read1);
|
||||
Assert.Equal(30, stream.BytesRead);
|
||||
|
||||
// Second read - 30 more bytes
|
||||
var read2 = await stream.ReadAsync(buffer);
|
||||
Assert.Equal(30, read2);
|
||||
Assert.Equal(60, stream.BytesRead);
|
||||
|
||||
// Third read should exceed limit
|
||||
await Assert.ThrowsAsync<PayloadLimitExceededException>(
|
||||
() => stream.ReadAsync(buffer).AsTask());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Stream_Properties_AreCorrect()
|
||||
{
|
||||
using var inner = new MemoryStream();
|
||||
using var stream = new ByteCountingStream(inner, 100);
|
||||
|
||||
Assert.True(stream.CanRead);
|
||||
Assert.False(stream.CanWrite);
|
||||
Assert.False(stream.CanSeek);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Write_ThrowsNotSupported()
|
||||
{
|
||||
using var inner = new MemoryStream();
|
||||
using var stream = new ByteCountingStream(inner, 100);
|
||||
|
||||
Assert.Throws<NotSupportedException>(() => stream.Write(new byte[10], 0, 10));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Seek_ThrowsNotSupported()
|
||||
{
|
||||
using var inner = new MemoryStream();
|
||||
using var stream = new ByteCountingStream(inner, 100);
|
||||
|
||||
Assert.Throws<NotSupportedException>(() => stream.Seek(0, SeekOrigin.Begin));
|
||||
}
|
||||
}
|
||||
|
||||
public class PayloadLimitExceededExceptionTests
|
||||
{
|
||||
[Fact]
|
||||
public void Constructor_SetsProperties()
|
||||
{
|
||||
var ex = new PayloadLimitExceededException(1000, 500);
|
||||
|
||||
Assert.Equal(1000, ex.BytesRead);
|
||||
Assert.Equal(500, ex.Limit);
|
||||
Assert.Contains("1000", ex.Message);
|
||||
Assert.Contains("500", ex.Message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<!-- Disable Concelier test infrastructure - we don't need MongoDB -->
|
||||
<UseConcelierTestInfra>false</UseConcelierTestInfra>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.0" />
|
||||
<PackageReference Include="xunit" Version="2.9.2" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
</PackageReference>
|
||||
<PackageReference Include="FluentAssertions" Version="6.12.0" />
|
||||
<PackageReference Include="Moq" Version="4.20.70" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\StellaOps.Gateway.WebService\StellaOps.Gateway.WebService.csproj" />
|
||||
<ProjectReference Include="..\..\..\__Libraries\StellaOps.Router.Transport.InMemory\StellaOps.Router.Transport.InMemory.csproj" />
|
||||
<ProjectReference Include="..\..\..\__Libraries\StellaOps.Microservice\StellaOps.Microservice.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
@@ -0,0 +1,315 @@
|
||||
using System.Threading.Channels;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Microservice.Streaming;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.InMemory;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Tests;
|
||||
|
||||
public class StreamingTests
|
||||
{
|
||||
private readonly InMemoryConnectionRegistry _registry = new();
|
||||
private readonly InMemoryTransportOptions _options = new() { SimulatedLatency = TimeSpan.Zero };
|
||||
|
||||
private InMemoryTransportClient CreateClient()
|
||||
{
|
||||
return new InMemoryTransportClient(
|
||||
_registry,
|
||||
Options.Create(_options),
|
||||
NullLogger<InMemoryTransportClient>.Instance);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StreamDataPayload_HasRequiredProperties()
|
||||
{
|
||||
var payload = new StreamDataPayload
|
||||
{
|
||||
CorrelationId = Guid.NewGuid(),
|
||||
Data = new byte[] { 1, 2, 3 },
|
||||
EndOfStream = true,
|
||||
SequenceNumber = 5
|
||||
};
|
||||
|
||||
Assert.NotEqual(Guid.Empty, payload.CorrelationId);
|
||||
Assert.Equal(3, payload.Data.Length);
|
||||
Assert.True(payload.EndOfStream);
|
||||
Assert.Equal(5, payload.SequenceNumber);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StreamingOptions_HasDefaultValues()
|
||||
{
|
||||
var options = StreamingOptions.Default;
|
||||
|
||||
Assert.Equal(64 * 1024, options.ChunkSize);
|
||||
Assert.Equal(100, options.MaxConcurrentStreams);
|
||||
Assert.Equal(TimeSpan.FromMinutes(5), options.StreamIdleTimeout);
|
||||
Assert.Equal(16, options.ChannelCapacity);
|
||||
}
|
||||
}
|
||||
|
||||
public class StreamingRequestBodyStreamTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ReadAsync_ReturnsDataFromChannel()
|
||||
{
|
||||
// Arrange
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
|
||||
|
||||
var testData = new byte[] { 1, 2, 3, 4, 5 };
|
||||
await channel.Writer.WriteAsync(new StreamChunk { Data = testData, SequenceNumber = 0 });
|
||||
await channel.Writer.WriteAsync(new StreamChunk { Data = [], EndOfStream = true, SequenceNumber = 1 });
|
||||
channel.Writer.Complete();
|
||||
|
||||
// Act
|
||||
var buffer = new byte[10];
|
||||
var bytesRead = await stream.ReadAsync(buffer);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(5, bytesRead);
|
||||
Assert.Equal(testData, buffer[..5]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_ReturnsZeroAtEndOfStream()
|
||||
{
|
||||
// Arrange
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
|
||||
|
||||
await channel.Writer.WriteAsync(new StreamChunk { Data = [], EndOfStream = true, SequenceNumber = 0 });
|
||||
channel.Writer.Complete();
|
||||
|
||||
// Act
|
||||
var buffer = new byte[10];
|
||||
var bytesRead = await stream.ReadAsync(buffer);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(0, bytesRead);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_HandlesMultipleChunks()
|
||||
{
|
||||
// Arrange
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
|
||||
|
||||
await channel.Writer.WriteAsync(new StreamChunk { Data = [1, 2, 3], SequenceNumber = 0 });
|
||||
await channel.Writer.WriteAsync(new StreamChunk { Data = [4, 5, 6], SequenceNumber = 1 });
|
||||
await channel.Writer.WriteAsync(new StreamChunk { Data = [], EndOfStream = true, SequenceNumber = 2 });
|
||||
channel.Writer.Complete();
|
||||
|
||||
// Act
|
||||
using var memStream = new MemoryStream();
|
||||
await stream.CopyToAsync(memStream);
|
||||
|
||||
// Assert
|
||||
var result = memStream.ToArray();
|
||||
Assert.Equal(6, result.Length);
|
||||
Assert.Equal(new byte[] { 1, 2, 3, 4, 5, 6 }, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Stream_Properties_AreCorrect()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
|
||||
|
||||
Assert.True(stream.CanRead);
|
||||
Assert.False(stream.CanWrite);
|
||||
Assert.False(stream.CanSeek);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Write_ThrowsNotSupported()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingRequestBodyStream(channel.Reader, CancellationToken.None);
|
||||
|
||||
Assert.Throws<NotSupportedException>(() => stream.Write([1, 2, 3], 0, 3));
|
||||
}
|
||||
}
|
||||
|
||||
public class StreamingResponseBodyStreamTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task WriteAsync_WritesToChannel()
|
||||
{
|
||||
// Arrange
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
await using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
|
||||
|
||||
var testData = new byte[] { 1, 2, 3, 4, 5 };
|
||||
|
||||
// Act
|
||||
await stream.WriteAsync(testData);
|
||||
await stream.FlushAsync();
|
||||
|
||||
// Assert
|
||||
Assert.True(channel.Reader.TryRead(out var chunk));
|
||||
Assert.Equal(testData, chunk!.Data);
|
||||
Assert.False(chunk.EndOfStream);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteAsync_SendsEndOfStream()
|
||||
{
|
||||
// Arrange
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
await using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
|
||||
|
||||
// Act
|
||||
await stream.WriteAsync(new byte[] { 1, 2, 3 });
|
||||
await stream.CompleteAsync();
|
||||
|
||||
// Assert - should have data chunk + end chunk
|
||||
var chunks = new List<StreamChunk>();
|
||||
await foreach (var chunk in channel.Reader.ReadAllAsync())
|
||||
{
|
||||
chunks.Add(chunk);
|
||||
}
|
||||
|
||||
Assert.Equal(2, chunks.Count);
|
||||
Assert.False(chunks[0].EndOfStream);
|
||||
Assert.True(chunks[1].EndOfStream);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_ChunksLargeData()
|
||||
{
|
||||
// Arrange
|
||||
var chunkSize = 10;
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
await using var stream = new StreamingResponseBodyStream(channel.Writer, chunkSize, CancellationToken.None);
|
||||
|
||||
var testData = new byte[25]; // Will need 3 chunks
|
||||
for (var i = 0; i < testData.Length; i++)
|
||||
{
|
||||
testData[i] = (byte)i;
|
||||
}
|
||||
|
||||
// Act
|
||||
await stream.WriteAsync(testData);
|
||||
await stream.CompleteAsync();
|
||||
|
||||
// Assert
|
||||
var chunks = new List<StreamChunk>();
|
||||
await foreach (var chunk in channel.Reader.ReadAllAsync())
|
||||
{
|
||||
chunks.Add(chunk);
|
||||
}
|
||||
|
||||
// Should have 3 chunks (10+10+5) + 1 end-of-stream (with 0 data since remainder already flushed)
|
||||
Assert.Equal(4, chunks.Count);
|
||||
Assert.Equal(10, chunks[0].Data.Length);
|
||||
Assert.Equal(10, chunks[1].Data.Length);
|
||||
Assert.Equal(5, chunks[2].Data.Length);
|
||||
Assert.True(chunks[3].EndOfStream);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Stream_Properties_AreCorrect()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
|
||||
|
||||
Assert.False(stream.CanRead);
|
||||
Assert.True(stream.CanWrite);
|
||||
Assert.False(stream.CanSeek);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Read_ThrowsNotSupported()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<StreamChunk>();
|
||||
using var stream = new StreamingResponseBodyStream(channel.Writer, 1024, CancellationToken.None);
|
||||
|
||||
Assert.Throws<NotSupportedException>(() => stream.Read(new byte[10], 0, 10));
|
||||
}
|
||||
}
|
||||
|
||||
public class InMemoryTransportStreamingTests
|
||||
{
|
||||
private readonly InMemoryConnectionRegistry _registry = new();
|
||||
private readonly InMemoryTransportOptions _options = new() { SimulatedLatency = TimeSpan.Zero };
|
||||
|
||||
private InMemoryTransportClient CreateClient()
|
||||
{
|
||||
return new InMemoryTransportClient(
|
||||
_registry,
|
||||
Options.Create(_options),
|
||||
NullLogger<InMemoryTransportClient>.Instance);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SendStreamingAsync_SendsRequestStreamDataFrames()
|
||||
{
|
||||
// Arrange
|
||||
using var client = CreateClient();
|
||||
var instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = "test-instance",
|
||||
ServiceName = "test-service",
|
||||
Version = "1.0.0",
|
||||
Region = "us-east-1"
|
||||
};
|
||||
|
||||
await client.ConnectAsync(instance, [], CancellationToken.None);
|
||||
|
||||
// Get connection ID via reflection
|
||||
var connectionIdField = client.GetType()
|
||||
.GetField("_connectionId", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
|
||||
var connectionId = connectionIdField?.GetValue(client)?.ToString();
|
||||
Assert.NotNull(connectionId);
|
||||
|
||||
var channel = _registry.GetChannel(connectionId!);
|
||||
Assert.NotNull(channel);
|
||||
Assert.NotNull(channel!.State);
|
||||
|
||||
// Create request body stream
|
||||
var requestBody = new MemoryStream(new byte[] { 1, 2, 3, 4, 5 });
|
||||
|
||||
// Create request frame
|
||||
var requestFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
var limits = PayloadLimits.Default;
|
||||
|
||||
// Act - Start streaming (this will send frames to microservice)
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||
var sendTask = client.SendStreamingAsync(
|
||||
channel.State!,
|
||||
requestFrame,
|
||||
requestBody,
|
||||
_ => Task.CompletedTask,
|
||||
limits,
|
||||
cts.Token);
|
||||
|
||||
// Read the frames that were sent to microservice
|
||||
var frames = new List<Frame>();
|
||||
await foreach (var frame in channel.ToMicroservice.Reader.ReadAllAsync(cts.Token))
|
||||
{
|
||||
frames.Add(frame);
|
||||
if (frame.Type == FrameType.RequestStreamData && frame.Payload.Length == 0)
|
||||
{
|
||||
// End of stream - break
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Assert - should have REQUEST header + data chunks + end-of-stream
|
||||
Assert.True(frames.Count >= 2);
|
||||
Assert.Equal(FrameType.Request, frames[0].Type);
|
||||
Assert.Equal(FrameType.RequestStreamData, frames[^1].Type);
|
||||
Assert.Equal(0, frames[^1].Payload.Length); // End of stream marker
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user