Files
git.stella-ops.org/src/Gateway/StellaOps.Gateway.WebService/Services/GatewayHostedService.cs

459 lines
15 KiB
C#

using System.Linq;
using System.Text.Json;
using Microsoft.Extensions.Options;
using StellaOps.Gateway.WebService.Authorization;
using StellaOps.Gateway.WebService.Configuration;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Gateway.OpenApi;
using StellaOps.Router.Transport.Tcp;
using StellaOps.Router.Transport.Tls;
namespace StellaOps.Gateway.WebService.Services;
/// <summary>
/// Hosted service that starts the gateway's TCP and/or TLS transport servers, handles frames
/// received from microservice connections (HELLO, HEARTBEAT, responses, CANCEL), and keeps the
/// routing state, the effective-claims store and the optional OpenAPI document cache in sync
/// with the set of connected instances.
/// </summary>
public sealed class GatewayHostedService : IHostedService
{
    private readonly TcpTransportServer _tcpServer;
    private readonly TlsTransportServer _tlsServer;
    private readonly IGlobalRoutingState _routingState;
    private readonly GatewayTransportClient _transportClient;
    private readonly IEffectiveClaimsStore _claimsStore;
    private readonly IRouterOpenApiDocumentCache? _openApiCache;
    private readonly IOptions<GatewayOptions> _options;
    private readonly GatewayServiceStatus _status;
    private readonly ILogger<GatewayHostedService> _logger;
    private readonly JsonSerializerOptions _jsonOptions;

    // Captured from options in StartAsync so StopAsync only tears down the transports
    // that were enabled at start-up.
    private bool _tcpEnabled;
    private bool _tlsEnabled;

    /// <summary>
    /// Creates the hosted service.
    /// </summary>
    /// <param name="tcpServer">TCP transport server; used only when TCP is enabled in options.</param>
    /// <param name="tlsServer">TLS transport server; used only when TLS is enabled in options.</param>
    /// <param name="routingState">Registry of live microservice connections.</param>
    /// <param name="transportClient">Receives RESPONSE / RESPONSE_STREAM_DATA frames.</param>
    /// <param name="claimsStore">Store updated with the endpoints each service announces.</param>
    /// <param name="options">Gateway options (transport enablement and ports).</param>
    /// <param name="status">Started/ready status flags for the gateway.</param>
    /// <param name="logger">Logger.</param>
    /// <param name="openApiCache">Optional OpenAPI cache, invalidated whenever connections change.</param>
    public GatewayHostedService(
        TcpTransportServer tcpServer,
        TlsTransportServer tlsServer,
        IGlobalRoutingState routingState,
        GatewayTransportClient transportClient,
        IEffectiveClaimsStore claimsStore,
        IOptions<GatewayOptions> options,
        GatewayServiceStatus status,
        ILogger<GatewayHostedService> logger,
        IRouterOpenApiDocumentCache? openApiCache = null)
    {
        _tcpServer = tcpServer;
        _tlsServer = tlsServer;
        _routingState = routingState;
        _transportClient = transportClient;
        _claimsStore = claimsStore;
        _options = options;
        _status = status;
        _logger = logger;
        _openApiCache = openApiCache;
        _jsonOptions = new JsonSerializerOptions
        {
            PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
            WriteIndented = false
        };
    }

    /// <summary>
    /// Starts every transport enabled in the gateway options, wiring frame and disconnect
    /// handlers first, then marks the gateway started and ready. When no transport is enabled
    /// a warning is logged and the gateway is still marked started/ready, but it will not
    /// accept microservice connections.
    /// </summary>
    public async Task StartAsync(CancellationToken cancellationToken)
    {
        var options = _options.Value;
        _tcpEnabled = options.Transports.Tcp.Enabled;
        _tlsEnabled = options.Transports.Tls.Enabled;

        if (!_tcpEnabled && !_tlsEnabled)
        {
            _logger.LogWarning("No transports enabled; gateway will not accept microservice connections.");
            _status.MarkStarted();
            _status.MarkReady();
            return;
        }

        if (_tcpEnabled)
        {
            _tcpServer.OnFrame += HandleTcpFrame;
            _tcpServer.OnDisconnection += HandleTcpDisconnection;
            await _tcpServer.StartAsync(cancellationToken);
            _logger.LogInformation("TCP transport started on port {Port}", options.Transports.Tcp.Port);
        }

        if (_tlsEnabled)
        {
            _tlsServer.OnFrame += HandleTlsFrame;
            _tlsServer.OnDisconnection += HandleTlsDisconnection;
            await _tlsServer.StartAsync(cancellationToken);
            _logger.LogInformation("TLS transport started on port {Port}", options.Transports.Tls.Port);
        }

        _status.MarkStarted();
        _status.MarkReady();
    }

    /// <summary>
    /// Marks the gateway not ready and sets every known connection to
    /// <see cref="InstanceHealthStatus.Draining"/>. It then stops the enabled transports and
    /// detaches their frame/disconnect handlers.
    /// </summary>
    /// <remarks>
    /// Each transport is stopped inside try/finally. If stopping TCP throws, TLS is still
    /// stopped and handlers are still detached. A stop failure is still propagated to the caller.
    /// </remarks>
    public async Task StopAsync(CancellationToken cancellationToken)
    {
        _status.MarkNotReady();

        foreach (var connection in _routingState.GetAllConnections())
        {
            _routingState.UpdateConnection(connection.ConnectionId, c => c.Status = InstanceHealthStatus.Draining);
        }

        try
        {
            if (_tcpEnabled)
            {
                try
                {
                    await _tcpServer.StopAsync(cancellationToken);
                }
                finally
                {
                    _tcpServer.OnFrame -= HandleTcpFrame;
                    _tcpServer.OnDisconnection -= HandleTcpDisconnection;
                }
            }
        }
        finally
        {
            if (_tlsEnabled)
            {
                try
                {
                    await _tlsServer.StopAsync(cancellationToken);
                }
                finally
                {
                    _tlsServer.OnFrame -= HandleTlsFrame;
                    _tlsServer.OnDisconnection -= HandleTlsDisconnection;
                }
            }
        }
    }

    // Transport event adapters: tag the frame with its transport and hand off.
    // HandleFrameAsync catches and logs all exceptions, so discarding the task is safe.
    private void HandleTcpFrame(string connectionId, Frame frame)
    {
        _ = HandleFrameAsync(TransportType.Tcp, connectionId, frame);
    }

    private void HandleTlsFrame(string connectionId, Frame frame)
    {
        _ = HandleFrameAsync(TransportType.Tls, connectionId, frame);
    }

    private void HandleTcpDisconnection(string connectionId)
    {
        HandleDisconnect(connectionId);
    }

    private void HandleTlsDisconnection(string connectionId)
    {
        HandleDisconnect(connectionId);
    }

    /// <summary>
    /// Dispatches a received frame by <see cref="FrameType"/>. Response frames go to the
    /// transport client, and CANCEL and unknown frame types are only logged. Any exception is
    /// logged and not rethrown, because callers fire and forget this task.
    /// </summary>
    private async Task HandleFrameAsync(TransportType transportType, string connectionId, Frame frame)
    {
        try
        {
            switch (frame.Type)
            {
                case FrameType.Hello:
                    await HandleHelloAsync(transportType, connectionId, frame);
                    break;
                case FrameType.Heartbeat:
                    await HandleHeartbeatAsync(connectionId, frame);
                    break;
                case FrameType.Response:
                case FrameType.ResponseStreamData:
                    _transportClient.HandleResponseFrame(frame);
                    break;
                case FrameType.Cancel:
                    _logger.LogDebug("Received CANCEL for {ConnectionId} correlation {CorrelationId}", connectionId, frame.CorrelationId);
                    break;
                default:
                    _logger.LogDebug("Ignoring frame type {FrameType} from {ConnectionId}", frame.Type, connectionId);
                    break;
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error handling frame {FrameType} from {ConnectionId}", frame.Type, connectionId);
        }
    }

    /// <summary>
    /// Registers a connection from a HELLO frame. Malformed or invalid payloads close the
    /// connection. An empty payload registers a placeholder ("unknown") identity. A valid
    /// payload also updates the claims store with the announced endpoints. The OpenAPI cache
    /// is invalidated on every registration.
    /// </summary>
    private Task HandleHelloAsync(TransportType transportType, string connectionId, Frame frame)
    {
        if (!TryParseHelloPayload(frame, out var payload, out var parseError))
        {
            _logger.LogWarning("Invalid HELLO payload for {ConnectionId}: {Error}", connectionId, parseError);
            CloseConnection(transportType, connectionId);
            return Task.CompletedTask;
        }

        if (payload is not null && !TryValidateHelloPayload(payload, out var validationError))
        {
            _logger.LogWarning("HELLO validation failed for {ConnectionId}: {Error}", connectionId, validationError);
            CloseConnection(transportType, connectionId);
            return Task.CompletedTask;
        }

        var state = payload is null
            ? BuildFallbackState(transportType, connectionId)
            : BuildConnectionState(transportType, connectionId, payload);

        _routingState.AddConnection(state);

        if (payload is not null)
        {
            _claimsStore.UpdateFromMicroservice(payload.Instance.ServiceName, payload.Endpoints);
        }

        _openApiCache?.Invalidate();

        _logger.LogInformation(
            "Connection registered: {ConnectionId} service={ServiceName} version={Version}",
            connectionId,
            state.Instance.ServiceName,
            state.Instance.Version);

        return Task.CompletedTask;
    }

    /// <summary>
    /// Refreshes <c>LastHeartbeatUtc</c> for a registered connection. The health status is
    /// updated only when the heartbeat payload parses successfully. Heartbeats for unknown
    /// connections are logged at debug level and ignored.
    /// </summary>
    private Task HandleHeartbeatAsync(string connectionId, Frame frame)
    {
        // Point lookup (as in HandleDisconnect) instead of scanning every connection per heartbeat.
        if (_routingState.GetConnection(connectionId) is null)
        {
            _logger.LogDebug("Heartbeat received for unknown connection {ConnectionId}", connectionId);
            return Task.CompletedTask;
        }

        var hasPayload = TryParseHeartbeatPayload(frame, out var payload);

        _routingState.UpdateConnection(connectionId, conn =>
        {
            conn.LastHeartbeatUtc = DateTime.UtcNow;
            if (hasPayload)
            {
                // On parse failure `payload` is a placeholder, so the status is left unchanged.
                conn.Status = payload.Status;
            }
        });

        return Task.CompletedTask;
    }

    /// <summary>
    /// Removes a disconnected connection from routing state and invalidates the OpenAPI cache.
    /// When no remaining connection has the same service name (compared case-insensitively),
    /// that service's claims are removed.
    /// </summary>
    private void HandleDisconnect(string connectionId)
    {
        var connection = _routingState.GetConnection(connectionId);
        if (connection is null)
        {
            return;
        }

        _routingState.RemoveConnection(connectionId);
        _openApiCache?.Invalidate();

        var serviceName = connection.Instance.ServiceName;
        if (!string.IsNullOrWhiteSpace(serviceName))
        {
            var remaining = _routingState.GetAllConnections()
                .Any(c => string.Equals(c.Instance.ServiceName, serviceName, StringComparison.OrdinalIgnoreCase));

            if (!remaining)
            {
                _claimsStore.RemoveService(serviceName);
            }
        }
    }

    /// <summary>
    /// Deserializes a HELLO payload.
    /// </summary>
    /// <returns>
    /// <c>true</c> with a <c>null</c> <paramref name="payload"/> when the frame payload is
    /// empty. <c>true</c> with the payload on success. <c>false</c> with
    /// <paramref name="error"/> set when the JSON is malformed or deserializes to <c>null</c>.
    /// </returns>
    private bool TryParseHelloPayload(Frame frame, out HelloPayload? payload, out string? error)
    {
        payload = null;
        error = null;

        if (frame.Payload.IsEmpty)
        {
            return true;
        }

        try
        {
            payload = JsonSerializer.Deserialize<HelloPayload>(frame.Payload.Span, _jsonOptions);
            if (payload is null)
            {
                error = "HELLO payload missing";
                return false;
            }

            return true;
        }
        catch (JsonException ex)
        {
            error = ex.Message;
            return false;
        }
    }

    /// <summary>
    /// Deserializes a HEARTBEAT payload. Returns <c>false</c> when the payload is empty,
    /// malformed or <c>null</c>. In that case <paramref name="payload"/> holds a placeholder
    /// (Healthy, current time) that callers should not apply.
    /// </summary>
    private bool TryParseHeartbeatPayload(Frame frame, out HeartbeatPayload payload)
    {
        payload = new HeartbeatPayload
        {
            InstanceId = string.Empty,
            Status = InstanceHealthStatus.Healthy,
            TimestampUtc = DateTime.UtcNow
        };

        if (frame.Payload.IsEmpty)
        {
            return false;
        }

        try
        {
            var parsed = JsonSerializer.Deserialize<HeartbeatPayload>(frame.Payload.Span, _jsonOptions);
            if (parsed is null)
            {
                return false;
            }

            payload = parsed;
            return true;
        }
        catch (JsonException)
        {
            return false;
        }
    }

    /// <summary>
    /// Validates a HELLO payload received from the network. The checks are:
    /// <list type="bullet">
    /// <item>the instance identity fields are present;</item>
    /// <item>every endpoint has a method and a path starting with '/';</item>
    /// <item>every endpoint matches the instance service name/version;</item>
    /// <item>no duplicate (method, path) pairs exist, compared case-insensitively;</item>
    /// <item>every referenced schema id exists in the payload's schemas.</item>
    /// </list>
    /// </summary>
    /// <remarks>
    /// Null checks are explicit because the payload is untrusted. A JSON <c>null</c> must
    /// produce a validation error, so the caller closes the connection. Without the checks,
    /// a NullReferenceException would be swallowed upstream and the connection would stay open.
    /// </remarks>
    private static bool TryValidateHelloPayload(HelloPayload payload, out string error)
    {
        if (payload.Instance is null)
        {
            error = "Instance is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.ServiceName))
        {
            error = "Instance.ServiceName is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.Version))
        {
            error = "Instance.Version is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.Region))
        {
            error = "Instance.Region is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.InstanceId))
        {
            error = "Instance.InstanceId is required";
            return false;
        }

        if (payload.Endpoints is null)
        {
            error = "Endpoints is required";
            return false;
        }

        var schemas = payload.Schemas;
        var seen = new HashSet<(string Method, string Path)>(new EndpointKeyComparer());

        foreach (var endpoint in payload.Endpoints)
        {
            if (endpoint is null)
            {
                error = "Endpoint entries must not be null";
                return false;
            }

            if (string.IsNullOrWhiteSpace(endpoint.Method))
            {
                error = "Endpoint.Method is required";
                return false;
            }

            if (string.IsNullOrWhiteSpace(endpoint.Path) || !endpoint.Path.StartsWith('/'))
            {
                error = "Endpoint.Path must start with '/'";
                return false;
            }

            // Service name is matched case-insensitively, version exactly.
            if (!string.Equals(endpoint.ServiceName, payload.Instance.ServiceName, StringComparison.OrdinalIgnoreCase) ||
                !string.Equals(endpoint.Version, payload.Instance.Version, StringComparison.Ordinal))
            {
                error = "Endpoint.ServiceName/Version must match HelloPayload.Instance";
                return false;
            }

            if (!seen.Add((endpoint.Method, endpoint.Path)))
            {
                error = $"Duplicate endpoint registration for {endpoint.Method} {endpoint.Path}";
                return false;
            }

            if (endpoint.SchemaInfo is not null)
            {
                // A null schemas collection cannot satisfy any reference, so report it as missing.
                if (endpoint.SchemaInfo.RequestSchemaId is not null &&
                    (schemas is null || !schemas.ContainsKey(endpoint.SchemaInfo.RequestSchemaId)))
                {
                    error = $"Endpoint schema reference missing: requestSchemaId='{endpoint.SchemaInfo.RequestSchemaId}'";
                    return false;
                }

                if (endpoint.SchemaInfo.ResponseSchemaId is not null &&
                    (schemas is null || !schemas.ContainsKey(endpoint.SchemaInfo.ResponseSchemaId)))
                {
                    error = $"Endpoint schema reference missing: responseSchemaId='{endpoint.SchemaInfo.ResponseSchemaId}'";
                    return false;
                }
            }
        }

        error = string.Empty;
        return true;
    }

    /// <summary>
    /// Connection state used when HELLO carried no payload. The instance identity is the
    /// placeholder "unknown", with the connection id used as the instance id.
    /// </summary>
    private static ConnectionState BuildFallbackState(TransportType transportType, string connectionId)
    {
        return new ConnectionState
        {
            ConnectionId = connectionId,
            Instance = new InstanceDescriptor
            {
                InstanceId = connectionId,
                ServiceName = "unknown",
                Version = "unknown",
                Region = "unknown"
            },
            Status = InstanceHealthStatus.Healthy,
            LastHeartbeatUtc = DateTime.UtcNow,
            TransportType = transportType
        };
    }

    /// <summary>
    /// Builds the connection state from a validated HELLO payload, indexing its endpoints
    /// by (method, path).
    /// </summary>
    private static ConnectionState BuildConnectionState(TransportType transportType, string connectionId, HelloPayload payload)
    {
        var state = new ConnectionState
        {
            ConnectionId = connectionId,
            Instance = payload.Instance,
            Status = InstanceHealthStatus.Healthy,
            LastHeartbeatUtc = DateTime.UtcNow,
            TransportType = transportType,
            Schemas = payload.Schemas,
            OpenApiInfo = payload.OpenApiInfo
        };

        foreach (var endpoint in payload.Endpoints)
        {
            state.Endpoints[(endpoint.Method, endpoint.Path)] = endpoint;
        }

        return state;
    }

    /// <summary>
    /// Closes the connection on the transport it arrived on. This is a no-op if the transport
    /// no longer knows the connection.
    /// </summary>
    private void CloseConnection(TransportType transportType, string connectionId)
    {
        if (transportType == TransportType.Tcp)
        {
            _tcpServer.GetConnection(connectionId)?.Close();
            return;
        }

        if (transportType == TransportType.Tls)
        {
            _tlsServer.GetConnection(connectionId)?.Close();
        }
    }

    /// <summary>
    /// Ordinal, case-insensitive equality for (Method, Path) keys. It is used to detect
    /// duplicate endpoint registrations in a HELLO payload.
    /// </summary>
    private sealed class EndpointKeyComparer : IEqualityComparer<(string Method, string Path)>
    {
        public bool Equals((string Method, string Path) x, (string Method, string Path) y)
        {
            return string.Equals(x.Method, y.Method, StringComparison.OrdinalIgnoreCase) &&
                   string.Equals(x.Path, y.Path, StringComparison.OrdinalIgnoreCase);
        }

        public int GetHashCode((string Method, string Path) obj)
        {
            return HashCode.Combine(
                StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Method),
                StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Path));
        }
    }
}