using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// TCP transport server implementation for the gateway.
/// </summary>
/// <remarks>
/// Listens on <see cref="TcpTransportOptions.BindAddress"/> / <see cref="TcpTransportOptions.Port"/>,
/// wraps every accepted client in a <see cref="TcpConnection"/>, and surfaces connection lifecycle
/// and received frames through the <see cref="OnConnection"/>, <see cref="OnDisconnection"/> and
/// <see cref="OnFrame"/> events.
/// </remarks>
public sealed class TcpTransportServer : ITransportServer, IAsyncDisposable
{
    private readonly TcpTransportOptions _options;
    private readonly ILogger _logger;
    private readonly ConcurrentDictionary<string, TcpConnection> _connections = new();
    private TcpListener? _listener;
    private CancellationTokenSource? _serverCts;
    private Task? _acceptTask;
    private bool _disposed;

    /// <summary>
    /// Event raised when a connection is established.
    /// </summary>
    /// <remarks>
    /// Raised when the first HELLO frame arrives on a connection (not at TCP accept time), with the
    /// connection ID and the newly created <see cref="ConnectionState"/>.
    /// </remarks>
    public event Action<string, ConnectionState>? OnConnection;

    /// <summary>
    /// Event raised when a connection is lost. The argument is the connection ID.
    /// </summary>
    public event Action<string>? OnDisconnection;

    /// <summary>
    /// Event raised when a frame is received, with the connection ID and the frame.
    /// </summary>
    /// <remarks>
    /// Raised for every frame, including HELLO; for the first HELLO it fires after
    /// <see cref="OnConnection"/>.
    /// </remarks>
    public event Action<string, Frame>? OnFrame;

    /// <summary>
    /// Initializes a new instance of the <see cref="TcpTransportServer"/> class.
    /// </summary>
    /// <param name="options">Transport options (bind address, port, and per-connection settings).</param>
    /// <param name="logger">Logger used by the server and passed to each connection.</param>
    public TcpTransportServer(
        IOptions<TcpTransportOptions> options,
        ILogger<TcpTransportServer> logger)
    {
        _options = options.Value;
        _logger = logger;
    }

    /// <inheritdoc />
    public Task StartAsync(CancellationToken cancellationToken)
    {
        ObjectDisposedException.ThrowIf(_disposed, this);

        _serverCts = new CancellationTokenSource();
        _listener = new TcpListener(_options.BindAddress, _options.Port);
        _listener.Start();

        _logger.LogInformation(
            "TCP transport server listening on {Address}:{Port}",
            _options.BindAddress,
            _options.Port);

        // The accept loop runs in the background; StopAsync awaits it.
        _acceptTask = AcceptLoopAsync(_serverCts.Token);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Accepts clients until <paramref name="cancellationToken"/> is cancelled or the listener is
    /// disposed, registering each as a <see cref="TcpConnection"/> and starting its read loop.
    /// </summary>
    /// <param name="cancellationToken">Server-wide token; also passed to every connection's read loop.</param>
    private async Task AcceptLoopAsync(CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            // Non-null only between accept and the hand-off to TcpConnection.
            TcpClient? pendingClient = null;
            try
            {
                pendingClient = await _listener!.AcceptTcpClientAsync(cancellationToken);
                var connectionId = GenerateConnectionId(pendingClient);
                _logger.LogInformation(
                    "Accepted connection {ConnectionId} from {RemoteEndpoint}",
                    connectionId,
                    pendingClient.Client.RemoteEndPoint);

                var connection = new TcpConnection(connectionId, pendingClient, _options, _logger);

                // NOTE(review): the connection is assumed to own the client from here on (nothing else
                // in this class releases it besides the connection's disposal) — confirm in TcpConnection.
                pendingClient = null;

                _connections[connectionId] = connection;
                connection.OnFrameReceived += HandleFrame;
                connection.OnDisconnected += HandleDisconnect;

                // Start the read loop on the thread pool so accepting continues immediately.
                _ = Task.Run(() => connection.ReadLoopAsync(cancellationToken), CancellationToken.None);
            }
            catch (OperationCanceledException)
            {
                // Expected on shutdown
                break;
            }
            catch (ObjectDisposedException)
            {
                // Listener disposed
                break;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error accepting connection");
            }
            finally
            {
                // Setup failed after the socket was accepted: close it instead of leaking it.
                pendingClient?.Dispose();
            }
        }
    }

    /// <summary>
    /// Handles a frame received on <paramref name="connection"/>. On the first HELLO frame it creates
    /// the connection's <see cref="ConnectionState"/> and raises <see cref="OnConnection"/>. Every frame
    /// is then forwarded through <see cref="OnFrame"/>.
    /// </summary>
    /// <param name="connection">The connection the frame arrived on.</param>
    /// <param name="frame">The received frame.</param>
    private void HandleFrame(TcpConnection connection, Frame frame)
    {
        // If this is a HELLO frame, create the ConnectionState
        if (frame.Type == FrameType.Hello && connection.State is null)
        {
            var state = new ConnectionState
            {
                ConnectionId = connection.ConnectionId,
                Instance = new InstanceDescriptor
                {
                    InstanceId = connection.ConnectionId,
                    ServiceName = "unknown", // Will be updated from HELLO payload
                    Version = "1.0.0",
                    Region = "default"
                },
                Status = InstanceHealthStatus.Healthy,
                LastHeartbeatUtc = DateTime.UtcNow,
                TransportType = TransportType.Tcp
            };

            connection.State = state;
            OnConnection?.Invoke(connection.ConnectionId, state);
        }

        OnFrame?.Invoke(connection.ConnectionId, frame);
    }

    /// <summary>
    /// Handles a connection's disconnect. It logs the reason, unregisters the connection, raises
    /// <see cref="OnDisconnection"/>, and starts disposing the connection in the background.
    /// </summary>
    /// <param name="connection">The connection that disconnected.</param>
    /// <param name="ex">The error that caused the disconnect, or <see langword="null"/> for a clean close.</param>
    private void HandleDisconnect(TcpConnection connection, Exception? ex)
    {
        _logger.LogInformation(
            "Connection {ConnectionId} disconnected{Reason}",
            connection.ConnectionId,
            ex is not null ? $": {ex.Message}" : string.Empty);

        // Remove only if the map still holds this exact instance, so this callback can never
        // evict a different connection registered under the same ID.
        _connections.TryRemove(KeyValuePair.Create(connection.ConnectionId, connection));

        OnDisconnection?.Invoke(connection.ConnectionId);

        // Clean up connection without blocking this callback, but observe the outcome so
        // disposal failures are logged instead of silently lost.
        _ = DisposeConnectionAsync(connection);
    }

    /// <summary>
    /// Disposes <paramref name="connection"/>. Any exception is logged and not rethrown, which makes
    /// this safe to fire and forget.
    /// </summary>
    /// <param name="connection">The connection to dispose.</param>
    private async Task DisposeConnectionAsync(TcpConnection connection)
    {
        try
        {
            await connection.DisposeAsync();
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Error disposing connection {ConnectionId}", connection.ConnectionId);
        }
    }

    /// <summary>
    /// Sends a frame to a connection.
    /// </summary>
    /// <param name="connectionId">The connection ID.</param>
    /// <param name="frame">The frame to send.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <exception cref="InvalidOperationException">No connection with <paramref name="connectionId"/> is registered.</exception>
    public async Task SendFrameAsync(
        string connectionId,
        Frame frame,
        CancellationToken cancellationToken = default)
    {
        if (!_connections.TryGetValue(connectionId, out var connection))
        {
            throw new InvalidOperationException($"Connection {connectionId} not found");
        }

        await connection.WriteFrameAsync(frame, cancellationToken);
    }

    /// <summary>
    /// Gets a connection by ID.
    /// </summary>
    /// <param name="connectionId">The connection ID.</param>
    /// <returns>The connection, or <see langword="null"/> if not found.</returns>
    public TcpConnection? GetConnection(string connectionId)
    {
        return _connections.TryGetValue(connectionId, out var conn) ? conn : null;
    }

    /// <summary>
    /// Gets all active connections as a point-in-time snapshot.
    /// </summary>
    public IEnumerable<TcpConnection> GetConnections() => _connections.Values;

    /// <summary>
    /// Gets the number of active connections.
    /// </summary>
    public int ConnectionCount => _connections.Count;

    /// <summary>
    /// Builds a connection ID of the form <c>tcp-{address}-{port}-{guid}</c>. When the remote
    /// endpoint is not an <see cref="IPEndPoint"/>, the form is <c>tcp-{guid}</c>.
    /// </summary>
    /// <remarks>
    /// The GUID suffix must not be truncated. Cutting the string to a fixed length drops most of the
    /// randomness for IPv4 and all of it for long addresses such as IPv6. Connections from the same
    /// host would then share an ID and overwrite each other in the connection map.
    /// </remarks>
    /// <param name="client">The accepted client.</param>
    /// <returns>A connection ID that is unique for practical purposes.</returns>
    private static string GenerateConnectionId(TcpClient client)
    {
        var suffix = Guid.NewGuid().ToString("N");
        return client.Client.RemoteEndPoint is IPEndPoint endpoint
            ? $"tcp-{endpoint.Address}-{endpoint.Port}-{suffix}"
            : $"tcp-{suffix}";
    }

    /// <inheritdoc />
    public async Task StopAsync(CancellationToken cancellationToken)
    {
        _logger.LogInformation("Stopping TCP transport server");

        if (_serverCts is not null)
        {
            await _serverCts.CancelAsync();
        }

        _listener?.Stop();

        if (_acceptTask is not null)
        {
            await _acceptTask;
        }

        // Close all connections. Each one is handled independently, so a single failing close or
        // dispose does not leave the rest open.
        // NOTE(review): Close() may raise OnDisconnected, in which case HandleDisconnect also disposes
        // the connection; this assumes TcpConnection.DisposeAsync is idempotent — confirm.
        foreach (var connection in _connections.Values)
        {
            try
            {
                connection.Close();
                await connection.DisposeAsync();
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Error closing connection {ConnectionId}", connection.ConnectionId);
            }
        }

        _connections.Clear();
        _logger.LogInformation("TCP transport server stopped");
    }

    /// <inheritdoc />
    public async ValueTask DisposeAsync()
    {
        if (_disposed)
        {
            return;
        }

        _disposed = true;
        await StopAsync(CancellationToken.None);
        _listener?.Dispose();
        _serverCts?.Dispose();
    }
}