using System.Linq;
using System.Text.Json;
using Microsoft.Extensions.Options;
using StellaOps.Gateway.WebService.Authorization;
using StellaOps.Gateway.WebService.Configuration;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Gateway.OpenApi;
using StellaOps.Router.Transport.Tcp;
using StellaOps.Router.Transport.Tls;

namespace StellaOps.Gateway.WebService.Services;

/// <summary>
/// Hosted service that runs the gateway's microservice-facing transports (TCP and/or TLS),
/// processes the frames microservices send over them, and keeps the global routing state,
/// the effective-claims store and the optional OpenAPI document cache in sync with the set
/// of registered connections.
/// </summary>
public sealed class GatewayHostedService : IHostedService
{
    private readonly TcpTransportServer _tcpServer;
    private readonly TlsTransportServer _tlsServer;
    private readonly IGlobalRoutingState _routingState;
    private readonly GatewayTransportClient _transportClient;
    private readonly IEffectiveClaimsStore _claimsStore;
    private readonly IRouterOpenApiDocumentCache? _openApiCache;

    // NOTE(review): IOptions has no non-generic form, so the options type argument appears to be
    // missing here and in the constructor. ILogger is also probably meant to be
    // ILogger<GatewayHostedService> so DI can resolve it. Confirm and restore the intended type arguments.
    private readonly IOptions _options;
    private readonly GatewayServiceStatus _status;
    private readonly ILogger _logger;

    /// <summary>Serializer settings used for HELLO and HEARTBEAT payloads (camelCase property names).</summary>
    private readonly JsonSerializerOptions _jsonOptions;

    // Captured in StartAsync so that StopAsync only stops/unsubscribes the transports that were enabled at startup.
    private bool _tcpEnabled;
    private bool _tlsEnabled;

    /// <summary>
    /// Creates the hosted service.
    /// </summary>
    /// <param name="tcpServer">TCP transport server; used only when TCP is enabled in options.</param>
    /// <param name="tlsServer">TLS transport server; used only when TLS is enabled in options.</param>
    /// <param name="routingState">Registry of live microservice connections.</param>
    /// <param name="transportClient">Receives RESPONSE / RESPONSE-STREAM-DATA frames.</param>
    /// <param name="claimsStore">Store updated with each service's endpoint claims.</param>
    /// <param name="options">Gateway options (transport enablement and ports).</param>
    /// <param name="status">Started/ready status flags for the gateway.</param>
    /// <param name="logger">Logger.</param>
    /// <param name="openApiCache">Optional OpenAPI document cache; invalidated whenever a connection is registered or removed.</param>
    public GatewayHostedService(
        TcpTransportServer tcpServer,
        TlsTransportServer tlsServer,
        IGlobalRoutingState routingState,
        GatewayTransportClient transportClient,
        IEffectiveClaimsStore claimsStore,
        IOptions options,
        GatewayServiceStatus status,
        ILogger logger,
        IRouterOpenApiDocumentCache? openApiCache = null)
    {
        _tcpServer = tcpServer;
        _tlsServer = tlsServer;
        _routingState = routingState;
        _transportClient = transportClient;
        _claimsStore = claimsStore;
        _options = options;
        _status = status;
        _logger = logger;
        _openApiCache = openApiCache;
        _jsonOptions = new JsonSerializerOptions
        {
            PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
            WriteIndented = false
        };
    }

    /// <summary>
    /// Subscribes to and starts every enabled transport server, then marks the gateway started and ready.
    /// When no transport is enabled, logs a warning and marks the gateway started and ready without
    /// starting any server.
    /// </summary>
    /// <param name="cancellationToken">Forwarded to each transport server's StartAsync.</param>
    public async Task StartAsync(CancellationToken cancellationToken)
    {
        var options = _options.Value;
        _tcpEnabled = options.Transports.Tcp.Enabled;
        _tlsEnabled = options.Transports.Tls.Enabled;

        if (!_tcpEnabled && !_tlsEnabled)
        {
            _logger.LogWarning("No transports enabled; gateway will not accept microservice connections.");
            _status.MarkStarted();
            _status.MarkReady();
            return;
        }

        if (_tcpEnabled)
        {
            // Handlers are attached before the server is started.
            _tcpServer.OnFrame += HandleTcpFrame;
            _tcpServer.OnDisconnection += HandleTcpDisconnection;
            await _tcpServer.StartAsync(cancellationToken);
            _logger.LogInformation("TCP transport started on port {Port}", options.Transports.Tcp.Port);
        }

        if (_tlsEnabled)
        {
            _tlsServer.OnFrame += HandleTlsFrame;
            _tlsServer.OnDisconnection += HandleTlsDisconnection;
            await _tlsServer.StartAsync(cancellationToken);
            _logger.LogInformation("TLS transport started on port {Port}", options.Transports.Tls.Port);
        }

        _status.MarkStarted();
        _status.MarkReady();
    }

    /// <summary>
    /// Marks the gateway not ready and sets every registered connection to
    /// <see cref="InstanceHealthStatus.Draining"/>. It then stops each transport that
    /// <see cref="StartAsync"/> enabled and unsubscribes that transport's handlers.
    /// </summary>
    /// <param name="cancellationToken">Forwarded to each transport server's StopAsync.</param>
    public async Task StopAsync(CancellationToken cancellationToken)
    {
        _status.MarkNotReady();

        foreach (var connection in _routingState.GetAllConnections())
        {
            _routingState.UpdateConnection(connection.ConnectionId, c => c.Status = InstanceHealthStatus.Draining);
        }

        if (_tcpEnabled)
        {
            await _tcpServer.StopAsync(cancellationToken);
            _tcpServer.OnFrame -= HandleTcpFrame;
            _tcpServer.OnDisconnection -= HandleTcpDisconnection;
        }

        if (_tlsEnabled)
        {
            await _tlsServer.StopAsync(cancellationToken);
            _tlsServer.OnFrame -= HandleTlsFrame;
            _tlsServer.OnDisconnection -= HandleTlsDisconnection;
        }
    }

    private void HandleTcpFrame(string connectionId, Frame frame) =>
        HandleFrame(TransportType.Tcp, connectionId, frame);

    private void HandleTlsFrame(string connectionId, Frame frame) =>
        HandleFrame(TransportType.Tls, connectionId, frame);

    private void HandleTcpDisconnection(string connectionId) => HandleDisconnect(connectionId);

    private void HandleTlsDisconnection(string connectionId) => HandleDisconnect(connectionId);

    /// <summary>
    /// Dispatches a frame received from a microservice connection according to its <see cref="FrameType"/>.
    /// All handling is synchronous and completes before the transport's OnFrame callback returns.
    /// Any exception is logged instead of being propagated into the transport.
    /// </summary>
    private void HandleFrame(TransportType transportType, string connectionId, Frame frame)
    {
        try
        {
            switch (frame.Type)
            {
                case FrameType.Hello:
                    HandleHello(transportType, connectionId, frame);
                    break;
                case FrameType.Heartbeat:
                    HandleHeartbeat(connectionId, frame);
                    break;
                case FrameType.Response:
                case FrameType.ResponseStreamData:
                    _transportClient.HandleResponseFrame(frame);
                    break;
                case FrameType.Cancel:
                    _logger.LogDebug("Received CANCEL for {ConnectionId} correlation {CorrelationId}", connectionId, frame.CorrelationId);
                    break;
                default:
                    _logger.LogDebug("Ignoring frame type {FrameType} from {ConnectionId}", frame.Type, connectionId);
                    break;
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error handling frame {FrameType} from {ConnectionId}", frame.Type, connectionId);
        }
    }

    /// <summary>
    /// Registers a connection from its HELLO frame.
    /// A payload that cannot be parsed or fails validation closes the connection.
    /// An empty payload registers a placeholder "unknown" instance.
    /// A valid payload registers the described instance and its endpoints, and updates the claims store.
    /// In both registration cases the OpenAPI cache, if any, is invalidated.
    /// </summary>
    private void HandleHello(TransportType transportType, string connectionId, Frame frame)
    {
        if (!TryParseHelloPayload(frame, out var payload, out var parseError))
        {
            _logger.LogWarning("Invalid HELLO payload for {ConnectionId}: {Error}", connectionId, parseError);
            CloseConnection(transportType, connectionId);
            return;
        }

        if (payload is not null && !TryValidateHelloPayload(payload, out var validationError))
        {
            _logger.LogWarning("HELLO validation failed for {ConnectionId}: {Error}", connectionId, validationError);
            CloseConnection(transportType, connectionId);
            return;
        }

        var state = payload is null
            ? BuildFallbackState(transportType, connectionId)
            : BuildConnectionState(transportType, connectionId, payload);

        _routingState.AddConnection(state);

        if (payload is not null)
        {
            _claimsStore.UpdateFromMicroservice(payload.Instance.ServiceName, payload.Endpoints);
        }

        _openApiCache?.Invalidate();

        _logger.LogInformation(
            "Connection registered: {ConnectionId} service={ServiceName} version={Version}",
            connectionId,
            state.Instance.ServiceName,
            state.Instance.Version);
    }

    /// <summary>
    /// Refreshes the last-heartbeat timestamp of a registered connection.
    /// The connection's status is also updated when the frame carries a parseable heartbeat payload.
    /// Heartbeats for unknown connections are logged at debug level and ignored.
    /// </summary>
    private void HandleHeartbeat(string connectionId, Frame frame)
    {
        // Direct lookup by ID (as HandleDisconnect does) instead of scanning every connection on each heartbeat.
        if (_routingState.GetConnection(connectionId) is null)
        {
            _logger.LogDebug("Heartbeat received for unknown connection {ConnectionId}", connectionId);
            return;
        }

        if (TryParseHeartbeatPayload(frame, out var payload))
        {
            _routingState.UpdateConnection(connectionId, conn =>
            {
                conn.LastHeartbeatUtc = DateTime.UtcNow;
                conn.Status = payload.Status;
            });
        }
        else
        {
            _routingState.UpdateConnection(connectionId, conn =>
            {
                conn.LastHeartbeatUtc = DateTime.UtcNow;
            });
        }
    }

    /// <summary>
    /// Removes a connection from the routing state and invalidates the OpenAPI cache.
    /// When no other connection for the same service name remains (case-insensitive comparison),
    /// the service's claims are removed as well. Unknown connection IDs are ignored.
    /// </summary>
    private void HandleDisconnect(string connectionId)
    {
        var connection = _routingState.GetConnection(connectionId);
        if (connection is null)
        {
            return;
        }

        _routingState.RemoveConnection(connectionId);
        _openApiCache?.Invalidate();

        var serviceName = connection.Instance.ServiceName;
        if (string.IsNullOrWhiteSpace(serviceName))
        {
            return;
        }

        var remaining = _routingState.GetAllConnections()
            .Any(c => string.Equals(c.Instance.ServiceName, serviceName, StringComparison.OrdinalIgnoreCase));

        if (!remaining)
        {
            _claimsStore.RemoveService(serviceName);
        }
    }

    /// <summary>
    /// Deserializes a HELLO payload.
    /// </summary>
    /// <param name="frame">The HELLO frame.</param>
    /// <param name="payload">
    /// The parsed payload.
    /// It is <see langword="null"/> when the frame payload is empty (still a success) or on failure.
    /// </param>
    /// <param name="error">The failure reason; <see langword="null"/> on success.</param>
    /// <returns>
    /// <see langword="false"/> when the JSON is malformed or deserializes to <see langword="null"/>.
    /// Otherwise <see langword="true"/>.
    /// </returns>
    private bool TryParseHelloPayload(Frame frame, out HelloPayload? payload, out string? error)
    {
        payload = null;
        error = null;

        if (frame.Payload.IsEmpty)
        {
            return true;
        }

        try
        {
            // Explicit type argument: TValue cannot be inferred from the return value.
            payload = JsonSerializer.Deserialize<HelloPayload>(frame.Payload.Span, _jsonOptions);
            if (payload is null)
            {
                error = "HELLO payload missing";
                return false;
            }

            return true;
        }
        catch (JsonException ex)
        {
            error = ex.Message;
            return false;
        }
    }

    /// <summary>
    /// Deserializes a HEARTBEAT payload.
    /// </summary>
    /// <param name="frame">The HEARTBEAT frame.</param>
    /// <param name="payload">
    /// The parsed payload on success.
    /// Otherwise a default payload with <see cref="InstanceHealthStatus.Healthy"/> status.
    /// </param>
    /// <returns>
    /// <see langword="false"/> when the payload is empty, malformed, or deserializes to <see langword="null"/>.
    /// Otherwise <see langword="true"/>.
    /// </returns>
    private bool TryParseHeartbeatPayload(Frame frame, out HeartbeatPayload payload)
    {
        payload = new HeartbeatPayload
        {
            InstanceId = string.Empty,
            Status = InstanceHealthStatus.Healthy,
            TimestampUtc = DateTime.UtcNow
        };

        if (frame.Payload.IsEmpty)
        {
            return false;
        }

        try
        {
            // Explicit type argument: TValue cannot be inferred from the return value.
            var parsed = JsonSerializer.Deserialize<HeartbeatPayload>(frame.Payload.Span, _jsonOptions);
            if (parsed is null)
            {
                return false;
            }

            payload = parsed;
            return true;
        }
        catch (JsonException)
        {
            return false;
        }
    }

    /// <summary>
    /// Validates a parsed HELLO payload. The checks are:
    /// <list type="bullet">
    /// <item>The instance has a service name, version, region and ID.</item>
    /// <item>Every endpoint has a method and a path beginning with '/'.</item>
    /// <item>Every endpoint's service name (case-insensitive) and version (ordinal) match the instance.</item>
    /// <item>No method/path pair is registered twice (case-insensitive).</item>
    /// <item>Every referenced request/response schema ID exists in the payload's schemas.</item>
    /// </list>
    /// </summary>
    /// <param name="payload">The parsed HELLO payload.</param>
    /// <param name="error">The first failure found, or <see cref="string.Empty"/> when valid.</param>
    /// <returns><see langword="true"/> when every check passes.</returns>
    private static bool TryValidateHelloPayload(HelloPayload payload, out string error)
    {
        if (string.IsNullOrWhiteSpace(payload.Instance.ServiceName))
        {
            error = "Instance.ServiceName is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.Version))
        {
            error = "Instance.Version is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.Region))
        {
            error = "Instance.Region is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.InstanceId))
        {
            error = "Instance.InstanceId is required";
            return false;
        }

        var seen = new HashSet<(string Method, string Path)>(new EndpointKeyComparer());

        foreach (var endpoint in payload.Endpoints)
        {
            if (string.IsNullOrWhiteSpace(endpoint.Method))
            {
                error = "Endpoint.Method is required";
                return false;
            }

            if (string.IsNullOrWhiteSpace(endpoint.Path) || !endpoint.Path.StartsWith('/'))
            {
                error = "Endpoint.Path must start with '/'";
                return false;
            }

            if (!string.Equals(endpoint.ServiceName, payload.Instance.ServiceName, StringComparison.OrdinalIgnoreCase) ||
                !string.Equals(endpoint.Version, payload.Instance.Version, StringComparison.Ordinal))
            {
                error = "Endpoint.ServiceName/Version must match HelloPayload.Instance";
                return false;
            }

            if (!seen.Add((endpoint.Method, endpoint.Path)))
            {
                error = $"Duplicate endpoint registration for {endpoint.Method} {endpoint.Path}";
                return false;
            }

            if (endpoint.SchemaInfo is not null)
            {
                if (endpoint.SchemaInfo.RequestSchemaId is not null &&
                    !payload.Schemas.ContainsKey(endpoint.SchemaInfo.RequestSchemaId))
                {
                    error = $"Endpoint schema reference missing: requestSchemaId='{endpoint.SchemaInfo.RequestSchemaId}'";
                    return false;
                }

                if (endpoint.SchemaInfo.ResponseSchemaId is not null &&
                    !payload.Schemas.ContainsKey(endpoint.SchemaInfo.ResponseSchemaId))
                {
                    error = $"Endpoint schema reference missing: responseSchemaId='{endpoint.SchemaInfo.ResponseSchemaId}'";
                    return false;
                }
            }
        }

        error = string.Empty;
        return true;
    }

    /// <summary>
    /// Builds the connection state registered for a HELLO without a payload.
    /// The state is healthy, uses the connection ID as the instance ID, and sets
    /// service name, version and region to "unknown".
    /// </summary>
    private static ConnectionState BuildFallbackState(TransportType transportType, string connectionId)
    {
        return new ConnectionState
        {
            ConnectionId = connectionId,
            Instance = new InstanceDescriptor
            {
                InstanceId = connectionId,
                ServiceName = "unknown",
                Version = "unknown",
                Region = "unknown"
            },
            Status = InstanceHealthStatus.Healthy,
            LastHeartbeatUtc = DateTime.UtcNow,
            TransportType = transportType
        };
    }

    /// <summary>
    /// Builds a healthy connection state from a validated HELLO payload.
    /// The payload's endpoints are indexed by (method, path).
    /// </summary>
    private static ConnectionState BuildConnectionState(TransportType transportType, string connectionId, HelloPayload payload)
    {
        var state = new ConnectionState
        {
            ConnectionId = connectionId,
            Instance = payload.Instance,
            Status = InstanceHealthStatus.Healthy,
            LastHeartbeatUtc = DateTime.UtcNow,
            TransportType = transportType,
            Schemas = payload.Schemas,
            OpenApiInfo = payload.OpenApiInfo
        };

        foreach (var endpoint in payload.Endpoints)
        {
            state.Endpoints[(endpoint.Method, endpoint.Path)] = endpoint;
        }

        return state;
    }

    /// <summary>
    /// Closes the connection with the given ID on the matching transport server.
    /// If that server returns no connection for the ID, nothing happens.
    /// </summary>
    private void CloseConnection(TransportType transportType, string connectionId)
    {
        if (transportType == TransportType.Tcp)
        {
            _tcpServer.GetConnection(connectionId)?.Close();
            return;
        }

        if (transportType == TransportType.Tls)
        {
            _tlsServer.GetConnection(connectionId)?.Close();
        }
    }

    /// <summary>
    /// Case-insensitive (ordinal) equality for (method, path) endpoint keys.
    /// Used to detect duplicate endpoint registrations in a HELLO payload.
    /// </summary>
    private sealed class EndpointKeyComparer : IEqualityComparer<(string Method, string Path)>
    {
        public bool Equals((string Method, string Path) x, (string Method, string Path) y)
        {
            return string.Equals(x.Method, y.Method, StringComparison.OrdinalIgnoreCase) &&
                   string.Equals(x.Path, y.Path, StringComparison.OrdinalIgnoreCase);
        }

        public int GetHashCode((string Method, string Path) obj)
        {
            return HashCode.Combine(
                StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Method),
                StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Path));
        }
    }
}