using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Gateway.OpenApi;
using StellaOps.Router.Transport.InMemory;
namespace StellaOps.Router.Gateway.Services;
/// <summary>
/// Hosted service that connects the in-memory transport server's connection lifecycle events
/// (HELLO, heartbeat, connection closed) to the global routing state.
/// </summary>
/// <remarks>
/// HELLO payloads are validated before a connection is added to routing. Invalid registrations
/// are rejected, and their channel is removed from the connection registry. When an OpenAPI
/// document cache is supplied, it is invalidated whenever a connection is added to or removed
/// from routing.
/// </remarks>
internal sealed class ConnectionManager : IHostedService
{
    private readonly InMemoryTransportServer _transportServer;
    private readonly InMemoryConnectionRegistry _connectionRegistry;
    private readonly IGlobalRoutingState _routingState;
    private readonly IRouterOpenApiDocumentCache? _openApiCache;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates a connection manager.
    /// </summary>
    /// <param name="transportServer">Transport server whose events are observed.</param>
    /// <param name="connectionRegistry">Registry from which rejected channels are removed.</param>
    /// <param name="routingState">Routing state that tracks accepted connections.</param>
    /// <param name="logger">Logger for lifecycle and diagnostic messages.</param>
    /// <param name="openApiCache">Optional OpenAPI document cache, invalidated on connection changes.</param>
    public ConnectionManager(
        InMemoryTransportServer transportServer,
        InMemoryConnectionRegistry connectionRegistry,
        IGlobalRoutingState routingState,
        ILogger logger,
        IRouterOpenApiDocumentCache? openApiCache = null)
    {
        _transportServer = transportServer;
        _connectionRegistry = connectionRegistry;
        _routingState = routingState;
        _openApiCache = openApiCache;
        _logger = logger;
    }

    /// <summary>
    /// Subscribes to transport events and starts the transport server.
    /// </summary>
    /// <param name="cancellationToken">Token forwarded to the transport server's start.</param>
    public async Task StartAsync(CancellationToken cancellationToken)
    {
        // Subscribe before starting so that no event raised during startup is missed.
        SubscribeToTransportEvents();

        try
        {
            await _transportServer.StartAsync(cancellationToken);
        }
        catch
        {
            // Startup failed or was cancelled. Detach the handlers so that a retry does not
            // register them twice and the transport does not keep this instance reachable.
            UnsubscribeFromTransportEvents();
            throw;
        }

        _logger.LogInformation("Connection manager started");
    }

    /// <summary>
    /// Stops the transport server and detaches from its events.
    /// </summary>
    /// <param name="cancellationToken">Token forwarded to the transport server's stop.</param>
    public async Task StopAsync(CancellationToken cancellationToken)
    {
        try
        {
            await _transportServer.StopAsync(cancellationToken);
        }
        finally
        {
            // Always detach, even when stopping the transport throws or is cancelled.
            UnsubscribeFromTransportEvents();
        }

        _logger.LogInformation("Connection manager stopped");
    }

    /// <summary>Attaches this instance's handlers to the transport server events.</summary>
    private void SubscribeToTransportEvents()
    {
        _transportServer.OnHelloReceived += HandleHelloReceivedAsync;
        _transportServer.OnHeartbeatReceived += HandleHeartbeatReceivedAsync;
        _transportServer.OnConnectionClosed += HandleConnectionClosedAsync;
    }

    /// <summary>Detaches the handlers attached by <see cref="SubscribeToTransportEvents"/>.</summary>
    private void UnsubscribeFromTransportEvents()
    {
        _transportServer.OnHelloReceived -= HandleHelloReceivedAsync;
        _transportServer.OnHeartbeatReceived -= HandleHeartbeatReceivedAsync;
        _transportServer.OnConnectionClosed -= HandleConnectionClosedAsync;
    }

    /// <summary>
    /// Validates a HELLO payload.
    /// </summary>
    /// <remarks>
    /// When the payload is valid, the connection is added to routing, the transport starts
    /// listening to it, and the OpenAPI cache (if any) is invalidated. When it is invalid, the
    /// connection's channel is removed from the registry and a warning is logged.
    /// </remarks>
    private Task HandleHelloReceivedAsync(ConnectionState connectionState, HelloPayload payload)
    {
        if (!TryValidateHelloPayload(payload, out var validationError))
        {
            _logger.LogWarning(
                "Rejecting HELLO for connection {ConnectionId}: {Error}",
                connectionState.ConnectionId,
                validationError);
            _connectionRegistry.RemoveChannel(connectionState.ConnectionId);
            return Task.CompletedTask;
        }

        _logger.LogInformation(
            "Connection registered: {ConnectionId} from {ServiceName}/{Version} with {EndpointCount} endpoints, {SchemaCount} schemas",
            connectionState.ConnectionId,
            connectionState.Instance.ServiceName,
            connectionState.Instance.Version,
            connectionState.Endpoints.Count,
            connectionState.Schemas.Count);

        // Make the connection routable.
        _routingState.AddConnection(connectionState);

        // Begin receiving frames for this connection.
        _transportServer.StartListeningToConnection(connectionState.ConnectionId);

        // The set of routed connections changed, so any cached OpenAPI document is stale.
        _openApiCache?.Invalidate();

        return Task.CompletedTask;
    }

    /// <summary>
    /// Checks a HELLO payload for required instance fields and well-formed endpoint registrations.
    /// </summary>
    /// <param name="payload">The HELLO payload to validate.</param>
    /// <param name="error">
    /// A description of the first problem found, or <see cref="string.Empty"/> when the payload is valid.
    /// </param>
    /// <returns><c>true</c> when the payload is valid; otherwise <c>false</c>.</returns>
    /// <remarks>
    /// The payload is rejected when any of the following holds:
    /// <list type="bullet">
    /// <item><description>ServiceName, Version, Region or InstanceId is blank.</description></item>
    /// <item><description>An endpoint has a blank method.</description></item>
    /// <item><description>An endpoint path does not start with '/'.</description></item>
    /// <item><description>An endpoint's service name differs from the instance's (ignoring case).</description></item>
    /// <item><description>An endpoint's version differs from the instance's (case-sensitive).</description></item>
    /// <item><description>Two endpoints share the same method and path (ignoring case).</description></item>
    /// <item><description>An endpoint references a request or response schema id missing from <c>payload.Schemas</c>.</description></item>
    /// </list>
    /// </remarks>
    private static bool TryValidateHelloPayload(HelloPayload payload, out string error)
    {
        if (string.IsNullOrWhiteSpace(payload.Instance.ServiceName))
        {
            error = "Instance.ServiceName is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.Version))
        {
            error = "Instance.Version is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.Region))
        {
            error = "Instance.Region is required";
            return false;
        }

        if (string.IsNullOrWhiteSpace(payload.Instance.InstanceId))
        {
            error = "Instance.InstanceId is required";
            return false;
        }

        var seen = new HashSet<(string Method, string Path)>(EndpointKeyComparer.Instance);
        foreach (var endpoint in payload.Endpoints)
        {
            if (string.IsNullOrWhiteSpace(endpoint.Method))
            {
                error = "Endpoint.Method is required";
                return false;
            }

            if (string.IsNullOrWhiteSpace(endpoint.Path) || !endpoint.Path.StartsWith('/'))
            {
                error = "Endpoint.Path must start with '/'";
                return false;
            }

            if (!string.Equals(endpoint.ServiceName, payload.Instance.ServiceName, StringComparison.OrdinalIgnoreCase) ||
                !string.Equals(endpoint.Version, payload.Instance.Version, StringComparison.Ordinal))
            {
                error = "Endpoint.ServiceName/Version must match HelloPayload.Instance";
                return false;
            }

            // Add returns false when an equivalent (method, path) key was already registered.
            if (!seen.Add((endpoint.Method, endpoint.Path)))
            {
                error = $"Duplicate endpoint registration for {endpoint.Method} {endpoint.Path}";
                return false;
            }

            if (endpoint.SchemaInfo is not null)
            {
                if (endpoint.SchemaInfo.RequestSchemaId is not null &&
                    !payload.Schemas.ContainsKey(endpoint.SchemaInfo.RequestSchemaId))
                {
                    error = $"Endpoint schema reference missing: requestSchemaId='{endpoint.SchemaInfo.RequestSchemaId}'";
                    return false;
                }

                if (endpoint.SchemaInfo.ResponseSchemaId is not null &&
                    !payload.Schemas.ContainsKey(endpoint.SchemaInfo.ResponseSchemaId))
                {
                    error = $"Endpoint schema reference missing: responseSchemaId='{endpoint.SchemaInfo.ResponseSchemaId}'";
                    return false;
                }
            }
        }

        error = string.Empty;
        return true;
    }

    /// <summary>
    /// Compares (method, path) endpoint keys with ordinal, case-insensitive rules on both parts.
    /// </summary>
    private sealed class EndpointKeyComparer : IEqualityComparer<(string Method, string Path)>
    {
        /// <summary>Shared instance. The comparer holds no state, so one instance suffices.</summary>
        public static readonly EndpointKeyComparer Instance = new();

        public bool Equals((string Method, string Path) x, (string Method, string Path) y)
        {
            return string.Equals(x.Method, y.Method, StringComparison.OrdinalIgnoreCase) &&
                   string.Equals(x.Path, y.Path, StringComparison.OrdinalIgnoreCase);
        }

        public int GetHashCode((string Method, string Path) obj)
        {
            // Hash with the same case-insensitive rules used by Equals, keeping the two consistent.
            return HashCode.Combine(
                StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Method),
                StringComparer.OrdinalIgnoreCase.GetHashCode(obj.Path));
        }
    }

    /// <summary>
    /// Records the reported status and sets the last-heartbeat time, in UTC, for the connection
    /// that sent the heartbeat.
    /// </summary>
    private Task HandleHeartbeatReceivedAsync(ConnectionState connectionState, HeartbeatPayload payload)
    {
        _logger.LogDebug(
            "Heartbeat received from {ConnectionId}: status={Status}",
            connectionState.ConnectionId,
            payload.Status);

        _routingState.UpdateConnection(connectionState.ConnectionId, conn =>
        {
            conn.Status = payload.Status;
            conn.LastHeartbeatUtc = DateTime.UtcNow;
        });

        return Task.CompletedTask;
    }

    /// <summary>
    /// Removes a closed connection from routing and invalidates the OpenAPI cache, if one is configured.
    /// </summary>
    private Task HandleConnectionClosedAsync(string connectionId)
    {
        _logger.LogInformation("Connection closed: {ConnectionId}", connectionId);

        _routingState.RemoveConnection(connectionId);

        // The set of routed connections changed, so any cached OpenAPI document is stale.
        _openApiCache?.Invalidate();

        return Task.CompletedTask;
    }
}