Refactor code structure for improved readability and maintainability; optimize performance in key functions.
This commit is contained in:
@@ -0,0 +1,106 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Gateway.Configuration;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Services;
|
||||
|
||||
public sealed class GatewayHealthMonitorService : BackgroundService
|
||||
{
|
||||
private readonly IGlobalRoutingState _routingState;
|
||||
private readonly IOptions<HealthOptions> _options;
|
||||
private readonly ILogger<GatewayHealthMonitorService> _logger;
|
||||
|
||||
public GatewayHealthMonitorService(
|
||||
IGlobalRoutingState routingState,
|
||||
IOptions<HealthOptions> options,
|
||||
ILogger<GatewayHealthMonitorService> logger)
|
||||
{
|
||||
_routingState = routingState;
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Health monitor started. Stale threshold: {StaleThreshold}, Check interval: {CheckInterval}",
|
||||
_options.Value.StaleThreshold,
|
||||
_options.Value.CheckInterval);
|
||||
|
||||
while (!stoppingToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
await Task.Delay(_options.Value.CheckInterval, stoppingToken);
|
||||
CheckStaleConnections();
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error in health monitor loop");
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Health monitor stopped");
|
||||
}
|
||||
|
||||
private void CheckStaleConnections()
|
||||
{
|
||||
var staleThreshold = _options.Value.StaleThreshold;
|
||||
var degradedThreshold = _options.Value.DegradedThreshold;
|
||||
var now = DateTime.UtcNow;
|
||||
var staleCount = 0;
|
||||
var degradedCount = 0;
|
||||
|
||||
foreach (var connection in _routingState.GetAllConnections())
|
||||
{
|
||||
if (connection.Status == InstanceHealthStatus.Draining)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var age = now - connection.LastHeartbeatUtc;
|
||||
|
||||
if (age > staleThreshold && connection.Status != InstanceHealthStatus.Unhealthy)
|
||||
{
|
||||
_routingState.UpdateConnection(connection.ConnectionId, c =>
|
||||
c.Status = InstanceHealthStatus.Unhealthy);
|
||||
|
||||
_logger.LogWarning(
|
||||
"Instance {InstanceId} ({ServiceName}/{Version}) marked Unhealthy: no heartbeat for {Age:g}",
|
||||
connection.Instance.InstanceId,
|
||||
connection.Instance.ServiceName,
|
||||
connection.Instance.Version,
|
||||
age);
|
||||
|
||||
staleCount++;
|
||||
}
|
||||
else if (age > degradedThreshold && connection.Status == InstanceHealthStatus.Healthy)
|
||||
{
|
||||
_routingState.UpdateConnection(connection.ConnectionId, c =>
|
||||
c.Status = InstanceHealthStatus.Degraded);
|
||||
|
||||
_logger.LogWarning(
|
||||
"Instance {InstanceId} ({ServiceName}/{Version}) marked Degraded: delayed heartbeat ({Age:g})",
|
||||
connection.Instance.InstanceId,
|
||||
connection.Instance.ServiceName,
|
||||
connection.Instance.Version,
|
||||
age);
|
||||
|
||||
degradedCount++;
|
||||
}
|
||||
}
|
||||
|
||||
if (staleCount > 0 || degradedCount > 0)
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Health check completed: {StaleCount} stale, {DegradedCount} degraded",
|
||||
staleCount,
|
||||
degradedCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,458 @@
|
||||
using System.Linq;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Gateway.WebService.Authorization;
|
||||
using StellaOps.Gateway.WebService.Configuration;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Gateway.OpenApi;
|
||||
using StellaOps.Router.Transport.Tcp;
|
||||
using StellaOps.Router.Transport.Tls;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Services;
|
||||
|
||||
public sealed class GatewayHostedService : IHostedService
|
||||
{
|
||||
private readonly TcpTransportServer _tcpServer;
|
||||
private readonly TlsTransportServer _tlsServer;
|
||||
private readonly IGlobalRoutingState _routingState;
|
||||
private readonly GatewayTransportClient _transportClient;
|
||||
private readonly IEffectiveClaimsStore _claimsStore;
|
||||
private readonly IRouterOpenApiDocumentCache? _openApiCache;
|
||||
private readonly IOptions<GatewayOptions> _options;
|
||||
private readonly GatewayServiceStatus _status;
|
||||
private readonly ILogger<GatewayHostedService> _logger;
|
||||
private readonly JsonSerializerOptions _jsonOptions;
|
||||
private bool _tcpEnabled;
|
||||
private bool _tlsEnabled;
|
||||
|
||||
public GatewayHostedService(
|
||||
TcpTransportServer tcpServer,
|
||||
TlsTransportServer tlsServer,
|
||||
IGlobalRoutingState routingState,
|
||||
GatewayTransportClient transportClient,
|
||||
IEffectiveClaimsStore claimsStore,
|
||||
IOptions<GatewayOptions> options,
|
||||
GatewayServiceStatus status,
|
||||
ILogger<GatewayHostedService> logger,
|
||||
IRouterOpenApiDocumentCache? openApiCache = null)
|
||||
{
|
||||
_tcpServer = tcpServer;
|
||||
_tlsServer = tlsServer;
|
||||
_routingState = routingState;
|
||||
_transportClient = transportClient;
|
||||
_claimsStore = claimsStore;
|
||||
_options = options;
|
||||
_status = status;
|
||||
_logger = logger;
|
||||
_openApiCache = openApiCache;
|
||||
_jsonOptions = new JsonSerializerOptions
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
WriteIndented = false
|
||||
};
|
||||
}
|
||||
|
||||
public async Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
var options = _options.Value;
|
||||
_tcpEnabled = options.Transports.Tcp.Enabled;
|
||||
_tlsEnabled = options.Transports.Tls.Enabled;
|
||||
|
||||
if (!_tcpEnabled && !_tlsEnabled)
|
||||
{
|
||||
_logger.LogWarning("No transports enabled; gateway will not accept microservice connections.");
|
||||
_status.MarkStarted();
|
||||
_status.MarkReady();
|
||||
return;
|
||||
}
|
||||
|
||||
if (_tcpEnabled)
|
||||
{
|
||||
_tcpServer.OnFrame += HandleTcpFrame;
|
||||
_tcpServer.OnDisconnection += HandleTcpDisconnection;
|
||||
await _tcpServer.StartAsync(cancellationToken);
|
||||
_logger.LogInformation("TCP transport started on port {Port}", options.Transports.Tcp.Port);
|
||||
}
|
||||
|
||||
if (_tlsEnabled)
|
||||
{
|
||||
_tlsServer.OnFrame += HandleTlsFrame;
|
||||
_tlsServer.OnDisconnection += HandleTlsDisconnection;
|
||||
await _tlsServer.StartAsync(cancellationToken);
|
||||
_logger.LogInformation("TLS transport started on port {Port}", options.Transports.Tls.Port);
|
||||
}
|
||||
|
||||
_status.MarkStarted();
|
||||
_status.MarkReady();
|
||||
}
|
||||
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
_status.MarkNotReady();
|
||||
|
||||
foreach (var connection in _routingState.GetAllConnections())
|
||||
{
|
||||
_routingState.UpdateConnection(connection.ConnectionId, c => c.Status = InstanceHealthStatus.Draining);
|
||||
}
|
||||
|
||||
if (_tcpEnabled)
|
||||
{
|
||||
await _tcpServer.StopAsync(cancellationToken);
|
||||
_tcpServer.OnFrame -= HandleTcpFrame;
|
||||
_tcpServer.OnDisconnection -= HandleTcpDisconnection;
|
||||
}
|
||||
|
||||
if (_tlsEnabled)
|
||||
{
|
||||
await _tlsServer.StopAsync(cancellationToken);
|
||||
_tlsServer.OnFrame -= HandleTlsFrame;
|
||||
_tlsServer.OnDisconnection -= HandleTlsDisconnection;
|
||||
}
|
||||
}
|
||||
|
||||
private void HandleTcpFrame(string connectionId, Frame frame)
|
||||
{
|
||||
_ = HandleFrameAsync(TransportType.Tcp, connectionId, frame);
|
||||
}
|
||||
|
||||
private void HandleTlsFrame(string connectionId, Frame frame)
|
||||
{
|
||||
_ = HandleFrameAsync(TransportType.Tls, connectionId, frame);
|
||||
}
|
||||
|
||||
private void HandleTcpDisconnection(string connectionId)
|
||||
{
|
||||
HandleDisconnect(connectionId);
|
||||
}
|
||||
|
||||
private void HandleTlsDisconnection(string connectionId)
|
||||
{
|
||||
HandleDisconnect(connectionId);
|
||||
}
|
||||
|
||||
private async Task HandleFrameAsync(TransportType transportType, string connectionId, Frame frame)
|
||||
{
|
||||
try
|
||||
{
|
||||
switch (frame.Type)
|
||||
{
|
||||
case FrameType.Hello:
|
||||
await HandleHelloAsync(transportType, connectionId, frame);
|
||||
break;
|
||||
case FrameType.Heartbeat:
|
||||
await HandleHeartbeatAsync(connectionId, frame);
|
||||
break;
|
||||
case FrameType.Response:
|
||||
case FrameType.ResponseStreamData:
|
||||
_transportClient.HandleResponseFrame(frame);
|
||||
break;
|
||||
case FrameType.Cancel:
|
||||
_logger.LogDebug("Received CANCEL for {ConnectionId} correlation {CorrelationId}", connectionId, frame.CorrelationId);
|
||||
break;
|
||||
default:
|
||||
_logger.LogDebug("Ignoring frame type {FrameType} from {ConnectionId}", frame.Type, connectionId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling frame {FrameType} from {ConnectionId}", frame.Type, connectionId);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleHelloAsync(TransportType transportType, string connectionId, Frame frame)
|
||||
{
|
||||
if (!TryParseHelloPayload(frame, out var payload, out var parseError))
|
||||
{
|
||||
_logger.LogWarning("Invalid HELLO payload for {ConnectionId}: {Error}", connectionId, parseError);
|
||||
CloseConnection(transportType, connectionId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (payload is not null && !TryValidateHelloPayload(payload, out var validationError))
|
||||
{
|
||||
_logger.LogWarning("HELLO validation failed for {ConnectionId}: {Error}", connectionId, validationError);
|
||||
CloseConnection(transportType, connectionId);
|
||||
return;
|
||||
}
|
||||
|
||||
var state = payload is null
|
||||
? BuildFallbackState(transportType, connectionId)
|
||||
: BuildConnectionState(transportType, connectionId, payload);
|
||||
|
||||
_routingState.AddConnection(state);
|
||||
|
||||
if (payload is not null)
|
||||
{
|
||||
_claimsStore.UpdateFromMicroservice(payload.Instance.ServiceName, payload.Endpoints);
|
||||
}
|
||||
|
||||
_openApiCache?.Invalidate();
|
||||
|
||||
_logger.LogInformation(
|
||||
"Connection registered: {ConnectionId} service={ServiceName} version={Version}",
|
||||
connectionId,
|
||||
state.Instance.ServiceName,
|
||||
state.Instance.Version);
|
||||
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task HandleHeartbeatAsync(string connectionId, Frame frame)
|
||||
{
|
||||
if (!_routingState.GetAllConnections().Any(c => c.ConnectionId == connectionId))
|
||||
{
|
||||
_logger.LogDebug("Heartbeat received for unknown connection {ConnectionId}", connectionId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (TryParseHeartbeatPayload(frame, out var payload))
|
||||
{
|
||||
_routingState.UpdateConnection(connectionId, conn =>
|
||||
{
|
||||
conn.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
conn.Status = payload.Status;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
_routingState.UpdateConnection(connectionId, conn =>
|
||||
{
|
||||
conn.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
});
|
||||
}
|
||||
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
private void HandleDisconnect(string connectionId)
|
||||
{
|
||||
var connection = _routingState.GetConnection(connectionId);
|
||||
if (connection is null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_routingState.RemoveConnection(connectionId);
|
||||
_openApiCache?.Invalidate();
|
||||
|
||||
var serviceName = connection.Instance.ServiceName;
|
||||
if (!string.IsNullOrWhiteSpace(serviceName))
|
||||
{
|
||||
var remaining = _routingState.GetAllConnections()
|
||||
.Any(c => string.Equals(c.Instance.ServiceName, serviceName, StringComparison.OrdinalIgnoreCase));
|
||||
|
||||
if (!remaining)
|
||||
{
|
||||
_claimsStore.RemoveService(serviceName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private bool TryParseHelloPayload(Frame frame, out HelloPayload? payload, out string? error)
|
||||
{
|
||||
payload = null;
|
||||
error = null;
|
||||
|
||||
if (frame.Payload.IsEmpty)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
payload = JsonSerializer.Deserialize<HelloPayload>(frame.Payload.Span, _jsonOptions);
|
||||
if (payload is null)
|
||||
{
|
||||
error = "HELLO payload missing";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
catch (JsonException ex)
|
||||
{
|
||||
error = ex.Message;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private bool TryParseHeartbeatPayload(Frame frame, out HeartbeatPayload payload)
|
||||
{
|
||||
payload = new HeartbeatPayload
|
||||
{
|
||||
InstanceId = string.Empty,
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
TimestampUtc = DateTime.UtcNow
|
||||
};
|
||||
|
||||
if (frame.Payload.IsEmpty)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var parsed = JsonSerializer.Deserialize<HeartbeatPayload>(frame.Payload.Span, _jsonOptions);
|
||||
if (parsed is null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
payload = parsed;
|
||||
return true;
|
||||
}
|
||||
catch (JsonException)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static bool TryValidateHelloPayload(HelloPayload payload, out string error)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(payload.Instance.ServiceName))
|
||||
{
|
||||
error = "Instance.ServiceName is required";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(payload.Instance.Version))
|
||||
{
|
||||
error = "Instance.Version is required";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(payload.Instance.Region))
|
||||
{
|
||||
error = "Instance.Region is required";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(payload.Instance.InstanceId))
|
||||
{
|
||||
error = "Instance.InstanceId is required";
|
||||
return false;
|
||||
}
|
||||
|
||||
var seen = new HashSet<(string Method, string Path)>(new EndpointKeyComparer());
|
||||
|
||||
foreach (var endpoint in payload.Endpoints)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(endpoint.Method))
|
||||
{
|
||||
error = "Endpoint.Method is required";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(endpoint.Path) || !endpoint.Path.StartsWith('/'))
|
||||
{
|
||||
error = "Endpoint.Path must start with '/'";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!string.Equals(endpoint.ServiceName, payload.Instance.ServiceName, StringComparison.OrdinalIgnoreCase) ||
|
||||
!string.Equals(endpoint.Version, payload.Instance.Version, StringComparison.Ordinal))
|
||||
{
|
||||
error = "Endpoint.ServiceName/Version must match HelloPayload.Instance";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!seen.Add((endpoint.Method, endpoint.Path)))
|
||||
{
|
||||
error = $"Duplicate endpoint registration for {endpoint.Method} {endpoint.Path}";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (endpoint.SchemaInfo is not null)
|
||||
{
|
||||
if (endpoint.SchemaInfo.RequestSchemaId is not null &&
|
||||
!payload.Schemas.ContainsKey(endpoint.SchemaInfo.RequestSchemaId))
|
||||
{
|
||||
error = $"Endpoint schema reference missing: requestSchemaId='{endpoint.SchemaInfo.RequestSchemaId}'";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (endpoint.SchemaInfo.ResponseSchemaId is not null &&
|
||||
!payload.Schemas.ContainsKey(endpoint.SchemaInfo.ResponseSchemaId))
|
||||
{
|
||||
error = $"Endpoint schema reference missing: responseSchemaId='{endpoint.SchemaInfo.ResponseSchemaId}'";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
error = string.Empty;
|
||||
return true;
|
||||
}
|
||||
|
||||
private static ConnectionState BuildFallbackState(TransportType transportType, string connectionId)
|
||||
{
|
||||
return new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = connectionId,
|
||||
ServiceName = "unknown",
|
||||
Version = "unknown",
|
||||
Region = "unknown"
|
||||
},
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = transportType
|
||||
};
|
||||
}
|
||||
|
||||
private static ConnectionState BuildConnectionState(TransportType transportType, string connectionId, HelloPayload payload)
|
||||
{
|
||||
var state = new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = payload.Instance,
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = transportType,
|
||||
Schemas = payload.Schemas,
|
||||
OpenApiInfo = payload.OpenApiInfo
|
||||
};
|
||||
|
||||
foreach (var endpoint in payload.Endpoints)
|
||||
{
|
||||
state.Endpoints[(endpoint.Method, endpoint.Path)] = endpoint;
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
private void CloseConnection(TransportType transportType, string connectionId)
|
||||
{
|
||||
if (transportType == TransportType.Tcp)
|
||||
{
|
||||
_tcpServer.GetConnection(connectionId)?.Close();
|
||||
return;
|
||||
}
|
||||
|
||||
if (transportType == TransportType.Tls)
|
||||
{
|
||||
_tlsServer.GetConnection(connectionId)?.Close();
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class EndpointKeyComparer : IEqualityComparer<(string Method, string Path)>
|
||||
{
|
||||
public bool Equals((string Method, string Path) x, (string Method, string Path) y)
|
||||
{
|
||||
return string.Equals(x.Method, y.Method, StringComparison.OrdinalIgnoreCase) &&
|
||||
string.Equals(x.Path, y.Path, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
public int GetHashCode((string Method, string Path) obj)
|
||||
{
|
||||
return HashCode.Combine(
|
||||
StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Method),
|
||||
StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Path));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
using System.Diagnostics.Metrics;
|
||||
using System.Linq;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Services;
|
||||
|
||||
public sealed class GatewayMetrics
|
||||
{
|
||||
public const string MeterName = "StellaOps.Gateway.WebService";
|
||||
|
||||
private static readonly Meter Meter = new(MeterName, "1.0.0");
|
||||
private readonly IGlobalRoutingState _routingState;
|
||||
|
||||
public GatewayMetrics(IGlobalRoutingState routingState)
|
||||
{
|
||||
_routingState = routingState;
|
||||
|
||||
Meter.CreateObservableGauge(
|
||||
"gateway_active_connections",
|
||||
() => GetActiveConnections(),
|
||||
description: "Number of active microservice connections.");
|
||||
|
||||
Meter.CreateObservableGauge(
|
||||
"gateway_registered_endpoints",
|
||||
() => GetRegisteredEndpoints(),
|
||||
description: "Number of registered endpoints across all connections.");
|
||||
}
|
||||
|
||||
public long GetActiveConnections()
|
||||
{
|
||||
return _routingState.GetAllConnections().Count;
|
||||
}
|
||||
|
||||
public long GetRegisteredEndpoints()
|
||||
{
|
||||
return _routingState.GetAllConnections().Sum(c => c.Endpoints.Count);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
using System.Threading;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Services;
|
||||
|
||||
public sealed class GatewayServiceStatus
|
||||
{
|
||||
private int _started;
|
||||
private int _ready;
|
||||
|
||||
public bool IsStarted => Volatile.Read(ref _started) == 1;
|
||||
|
||||
public bool IsReady => Volatile.Read(ref _ready) == 1;
|
||||
|
||||
public void MarkStarted()
|
||||
{
|
||||
Volatile.Write(ref _started, 1);
|
||||
}
|
||||
|
||||
public void MarkReady()
|
||||
{
|
||||
Volatile.Write(ref _ready, 1);
|
||||
}
|
||||
|
||||
public void MarkNotReady()
|
||||
{
|
||||
Volatile.Write(ref _ready, 0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
using System.Buffers;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Threading.Channels;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.Tcp;
|
||||
using StellaOps.Router.Transport.Tls;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Services;
|
||||
|
||||
public sealed class GatewayTransportClient : ITransportClient
|
||||
{
|
||||
private readonly TcpTransportServer _tcpServer;
|
||||
private readonly TlsTransportServer _tlsServer;
|
||||
private readonly ILogger<GatewayTransportClient> _logger;
|
||||
private readonly ConcurrentDictionary<string, TaskCompletionSource<Frame>> _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<string, Channel<Frame>> _streamingResponses = new();
|
||||
|
||||
public GatewayTransportClient(
|
||||
TcpTransportServer tcpServer,
|
||||
TlsTransportServer tlsServer,
|
||||
ILogger<GatewayTransportClient> logger)
|
||||
{
|
||||
_tcpServer = tcpServer;
|
||||
_tlsServer = tlsServer;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<Frame> SendRequestAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestFrame,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var correlationId = EnsureCorrelationId(requestFrame);
|
||||
var frame = requestFrame with { CorrelationId = correlationId };
|
||||
|
||||
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
if (!_pendingRequests.TryAdd(correlationId, tcs))
|
||||
{
|
||||
throw new InvalidOperationException($"Duplicate correlation ID {correlationId}");
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
await SendFrameAsync(connection, frame, cancellationToken);
|
||||
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(timeout);
|
||||
|
||||
return await tcs.Task.WaitAsync(timeoutCts.Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_pendingRequests.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task SendCancelAsync(ConnectionState connection, Guid correlationId, string? reason = null)
|
||||
{
|
||||
var frame = new Frame
|
||||
{
|
||||
Type = FrameType.Cancel,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await SendFrameAsync(connection, frame, CancellationToken.None);
|
||||
}
|
||||
|
||||
public async Task SendStreamingAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestHeader,
|
||||
Stream requestBody,
|
||||
Func<Stream, Task> readResponseBody,
|
||||
PayloadLimits limits,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var correlationId = EnsureCorrelationId(requestHeader);
|
||||
var headerFrame = requestHeader with
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = correlationId
|
||||
};
|
||||
|
||||
var channel = Channel.CreateUnbounded<Frame>(new UnboundedChannelOptions
|
||||
{
|
||||
SingleReader = true,
|
||||
SingleWriter = false
|
||||
});
|
||||
|
||||
if (!_streamingResponses.TryAdd(correlationId, channel))
|
||||
{
|
||||
throw new InvalidOperationException($"Duplicate correlation ID {correlationId}");
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
await SendFrameAsync(connection, headerFrame, cancellationToken);
|
||||
await StreamRequestBodyAsync(connection, correlationId, requestBody, limits, cancellationToken);
|
||||
|
||||
using var responseStream = new MemoryStream();
|
||||
await ReadStreamingResponseAsync(channel.Reader, responseStream, cancellationToken);
|
||||
responseStream.Position = 0;
|
||||
await readResponseBody(responseStream);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (_streamingResponses.TryRemove(correlationId, out var removed))
|
||||
{
|
||||
removed.Writer.TryComplete();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void HandleResponseFrame(Frame frame)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(frame.CorrelationId))
|
||||
{
|
||||
_logger.LogDebug("Ignoring response frame without correlation ID");
|
||||
return;
|
||||
}
|
||||
|
||||
if (_pendingRequests.TryGetValue(frame.CorrelationId, out var pending))
|
||||
{
|
||||
pending.TrySetResult(frame);
|
||||
return;
|
||||
}
|
||||
|
||||
if (_streamingResponses.TryGetValue(frame.CorrelationId, out var channel))
|
||||
{
|
||||
channel.Writer.TryWrite(frame);
|
||||
return;
|
||||
}
|
||||
|
||||
_logger.LogDebug("No pending request for correlation ID {CorrelationId}", frame.CorrelationId);
|
||||
}
|
||||
|
||||
private async Task SendFrameAsync(ConnectionState connection, Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
switch (connection.TransportType)
|
||||
{
|
||||
case TransportType.Tcp:
|
||||
await _tcpServer.SendFrameAsync(connection.ConnectionId, frame, cancellationToken);
|
||||
break;
|
||||
case TransportType.Tls:
|
||||
await _tlsServer.SendFrameAsync(connection.ConnectionId, frame, cancellationToken);
|
||||
break;
|
||||
default:
|
||||
throw new NotSupportedException($"Transport type {connection.TransportType} is not supported by the gateway.");
|
||||
}
|
||||
}
|
||||
|
||||
private static string EnsureCorrelationId(Frame frame)
|
||||
{
|
||||
if (!string.IsNullOrWhiteSpace(frame.CorrelationId))
|
||||
{
|
||||
return frame.CorrelationId;
|
||||
}
|
||||
|
||||
return Guid.NewGuid().ToString("N");
|
||||
}
|
||||
|
||||
private async Task StreamRequestBodyAsync(
|
||||
ConnectionState connection,
|
||||
string correlationId,
|
||||
Stream requestBody,
|
||||
PayloadLimits limits,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var buffer = ArrayPool<byte>.Shared.Rent(8192);
|
||||
try
|
||||
{
|
||||
long totalBytesRead = 0;
|
||||
int bytesRead;
|
||||
|
||||
while ((bytesRead = await requestBody.ReadAsync(buffer, cancellationToken)) > 0)
|
||||
{
|
||||
totalBytesRead += bytesRead;
|
||||
|
||||
if (totalBytesRead > limits.MaxRequestBytesPerCall)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Request body exceeds limit of {limits.MaxRequestBytesPerCall} bytes");
|
||||
}
|
||||
|
||||
var dataFrame = new Frame
|
||||
{
|
||||
Type = FrameType.RequestStreamData,
|
||||
CorrelationId = correlationId,
|
||||
Payload = new ReadOnlyMemory<byte>(buffer, 0, bytesRead)
|
||||
};
|
||||
await SendFrameAsync(connection, dataFrame, cancellationToken);
|
||||
}
|
||||
|
||||
var endFrame = new Frame
|
||||
{
|
||||
Type = FrameType.RequestStreamData,
|
||||
CorrelationId = correlationId,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await SendFrameAsync(connection, endFrame, cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task ReadStreamingResponseAsync(
|
||||
ChannelReader<Frame> reader,
|
||||
Stream responseStream,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
while (await reader.WaitToReadAsync(cancellationToken))
|
||||
{
|
||||
while (reader.TryRead(out var frame))
|
||||
{
|
||||
if (frame.Type == FrameType.ResponseStreamData)
|
||||
{
|
||||
if (frame.Payload.Length == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
await responseStream.WriteAsync(frame.Payload, cancellationToken);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (frame.Type == FrameType.Response)
|
||||
{
|
||||
if (frame.Payload.Length > 0)
|
||||
{
|
||||
await responseStream.WriteAsync(frame.Payload, cancellationToken);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user