254 lines
8.4 KiB
C#
254 lines
8.4 KiB
C#
using System.Buffers;
|
|
using System.Collections.Concurrent;
|
|
using System.Threading.Channels;
|
|
using StellaOps.Router.Common.Abstractions;
|
|
using StellaOps.Router.Common.Enums;
|
|
using StellaOps.Router.Common.Models;
|
|
using StellaOps.Router.Transport.Tcp;
|
|
using StellaOps.Router.Transport.Tls;
|
|
using StellaOps.Router.Transport.Messaging;
|
|
|
|
namespace StellaOps.Gateway.WebService.Services;
|
|
|
|
/// <summary>
/// Gateway-side <see cref="ITransportClient"/> that sends frames to connected services through the
/// transport server matching the connection's transport type (TCP, TLS/certificate, or messaging)
/// and correlates the frames that come back via <see cref="HandleResponseFrame"/>.
/// </summary>
public sealed class GatewayTransportClient : ITransportClient
{
    // Size requested from the shared pool for the scratch buffer used to read streamed request bodies.
    private const int RequestBodyChunkSize = 8192;

    private readonly TcpTransportServer _tcpServer;
    private readonly TlsTransportServer _tlsServer;
    private readonly MessagingTransportServer? _messagingServer;
    private readonly ILogger<GatewayTransportClient> _logger;

    // Buffered requests awaiting a single reply frame, keyed by correlation ID.
    private readonly ConcurrentDictionary<string, TaskCompletionSource<Frame>> _pendingRequests = new();

    // Streaming requests; reply frames are queued into the channel until the reader sees a terminator.
    private readonly ConcurrentDictionary<string, Channel<Frame>> _streamingResponses = new();

    /// <summary>
    /// Creates the client over the gateway's transport servers.
    /// </summary>
    /// <param name="tcpServer">Server used for <see cref="TransportType.Tcp"/> connections.</param>
    /// <param name="tlsServer">Server used for <see cref="TransportType.Certificate"/> connections.</param>
    /// <param name="logger">Logger for diagnostic messages.</param>
    /// <param name="messagingServer">
    /// Server used for <see cref="TransportType.Messaging"/> connections; when <c>null</c>, sending on a
    /// messaging connection throws <see cref="InvalidOperationException"/>.
    /// </param>
    public GatewayTransportClient(
        TcpTransportServer tcpServer,
        TlsTransportServer tlsServer,
        ILogger<GatewayTransportClient> logger,
        MessagingTransportServer? messagingServer = null)
    {
        _tcpServer = tcpServer;
        _tlsServer = tlsServer;
        _messagingServer = messagingServer;
        _logger = logger;
    }

    /// <summary>
    /// Sends <paramref name="requestFrame"/> on <paramref name="connection"/> and waits for the first
    /// frame that carries the same correlation ID.
    /// </summary>
    /// <param name="connection">Connection to send on; its transport type selects the server.</param>
    /// <param name="requestFrame">Frame to send. A correlation ID is generated when it has none.</param>
    /// <param name="timeout">Upper bound on the whole exchange: both the send and the wait for the reply.</param>
    /// <param name="cancellationToken">Caller cancellation.</param>
    /// <returns>The first frame routed to this request's correlation ID.</returns>
    /// <exception cref="InvalidOperationException">A request with the same correlation ID is already pending.</exception>
    /// <exception cref="OperationCanceledException">The timeout elapsed or <paramref name="cancellationToken"/> was cancelled.</exception>
    public async Task<Frame> SendRequestAsync(
        ConnectionState connection,
        Frame requestFrame,
        TimeSpan timeout,
        CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(connection);

        var correlationId = EnsureCorrelationId(requestFrame);
        var frame = requestFrame with { CorrelationId = correlationId };

        // Register before sending so a reply that arrives immediately cannot be missed.
        var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
        if (!_pendingRequests.TryAdd(correlationId, tcs))
        {
            throw new InvalidOperationException($"Duplicate correlation ID {correlationId}");
        }

        try
        {
            // Arm the timeout before sending: a stalled transport write is then bounded too, and an
            // invalid timeout is rejected before anything goes out on the wire.
            using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            timeoutCts.CancelAfter(timeout);

            await SendFrameAsync(connection, frame, timeoutCts.Token);
            return await tcs.Task.WaitAsync(timeoutCts.Token);
        }
        finally
        {
            _pendingRequests.TryRemove(correlationId, out _);
        }
    }

    /// <summary>
    /// Sends a <see cref="FrameType.Cancel"/> frame with an empty payload for
    /// <paramref name="correlationId"/> on <paramref name="connection"/>.
    /// </summary>
    /// <param name="connection">Connection the original request was sent on.</param>
    /// <param name="correlationId">Correlation ID of the request to cancel.</param>
    /// <param name="reason">Optional reason. It is only written to the debug log, because the frame carries no payload.</param>
    public async Task SendCancelAsync(ConnectionState connection, Guid correlationId, string? reason = null)
    {
        ArgumentNullException.ThrowIfNull(connection);

        // "N" format (32 hex digits, no hyphens) matches the IDs EnsureCorrelationId generates.
        var correlationKey = correlationId.ToString("N");

        _logger.LogDebug(
            "Sending cancel for correlation ID {CorrelationId} (reason: {Reason})",
            correlationKey,
            reason);

        var frame = new Frame
        {
            Type = FrameType.Cancel,
            CorrelationId = correlationKey,
            Payload = ReadOnlyMemory<byte>.Empty
        };

        // This method receives no token, so the send itself is not cancellable.
        await SendFrameAsync(connection, frame, CancellationToken.None);
    }

    /// <summary>
    /// Sends a streamed request and hands the reply body to <paramref name="readResponseBody"/>.
    /// The request goes out as a header frame, then body chunks, then an empty end-of-body frame.
    /// </summary>
    /// <remarks>
    /// The reply is collected into an in-memory buffer before <paramref name="readResponseBody"/> runs,
    /// and its size is not limited here. Only <paramref name="cancellationToken"/> bounds the wait for
    /// the reply; there is no separate timeout.
    /// </remarks>
    /// <param name="connection">Connection to send on.</param>
    /// <param name="requestHeader">Header frame; its type is forced to <see cref="FrameType.Request"/>.</param>
    /// <param name="requestBody">Body to stream in chunks.</param>
    /// <param name="readResponseBody">Callback that receives the buffered reply body, positioned at 0.</param>
    /// <param name="limits">Limits; <c>MaxRequestBytesPerCall</c> caps the request body size.</param>
    /// <param name="cancellationToken">Caller cancellation.</param>
    /// <exception cref="InvalidOperationException">
    /// Duplicate correlation ID, or the request body exceeds <c>limits.MaxRequestBytesPerCall</c>.
    /// </exception>
    public async Task SendStreamingAsync(
        ConnectionState connection,
        Frame requestHeader,
        Stream requestBody,
        Func<Stream, Task> readResponseBody,
        PayloadLimits limits,
        CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(connection);
        ArgumentNullException.ThrowIfNull(requestBody);
        ArgumentNullException.ThrowIfNull(readResponseBody);

        var correlationId = EnsureCorrelationId(requestHeader);
        var headerFrame = requestHeader with
        {
            Type = FrameType.Request,
            CorrelationId = correlationId
        };

        // Single reader (this method) and possibly many writers (HandleResponseFrame callers).
        var channel = Channel.CreateUnbounded<Frame>(new UnboundedChannelOptions
        {
            SingleReader = true,
            SingleWriter = false
        });

        // Register before sending so early reply frames are queued rather than dropped.
        if (!_streamingResponses.TryAdd(correlationId, channel))
        {
            throw new InvalidOperationException($"Duplicate correlation ID {correlationId}");
        }

        try
        {
            await SendFrameAsync(connection, headerFrame, cancellationToken);
            await StreamRequestBodyAsync(connection, correlationId, requestBody, limits, cancellationToken);

            using var responseStream = new MemoryStream();
            await ReadStreamingResponseAsync(channel.Reader, responseStream, cancellationToken);
            responseStream.Position = 0;
            await readResponseBody(responseStream);
        }
        finally
        {
            if (_streamingResponses.TryRemove(correlationId, out var removed))
            {
                removed.Writer.TryComplete();
            }
        }
    }

    /// <summary>
    /// Routes an inbound frame to the request waiting on its correlation ID.
    /// </summary>
    /// <remarks>
    /// A buffered request is completed with the frame if one is waiting on that ID. Otherwise the frame
    /// is queued to a streaming request waiting on that ID. Frames with no correlation ID, or with no
    /// waiter, are logged at debug level and dropped.
    /// </remarks>
    /// <param name="frame">Frame received from a transport.</param>
    public void HandleResponseFrame(Frame frame)
    {
        if (string.IsNullOrWhiteSpace(frame.CorrelationId))
        {
            _logger.LogDebug("Ignoring response frame without correlation ID");
            return;
        }

        if (_pendingRequests.TryGetValue(frame.CorrelationId, out var pending))
        {
            pending.TrySetResult(frame);
            return;
        }

        if (_streamingResponses.TryGetValue(frame.CorrelationId, out var channel))
        {
            // The channel is unbounded, so TryWrite only fails once SendStreamingAsync has completed it.
            if (!channel.Writer.TryWrite(frame))
            {
                _logger.LogDebug(
                    "Dropping late streaming frame for correlation ID {CorrelationId}",
                    frame.CorrelationId);
            }

            return;
        }

        _logger.LogDebug("No pending request for correlation ID {CorrelationId}", frame.CorrelationId);
    }

    /// <summary>
    /// Sends <paramref name="frame"/> through the transport server that matches
    /// <paramref name="connection"/>'s transport type.
    /// </summary>
    /// <exception cref="InvalidOperationException">Messaging connection, but no messaging server was configured.</exception>
    /// <exception cref="NotSupportedException">The transport type has no server in the gateway.</exception>
    private async Task SendFrameAsync(ConnectionState connection, Frame frame, CancellationToken cancellationToken)
    {
        switch (connection.TransportType)
        {
            case TransportType.Tcp:
                await _tcpServer.SendFrameAsync(connection.ConnectionId, frame, cancellationToken);
                break;

            case TransportType.Certificate:
                await _tlsServer.SendFrameAsync(connection.ConnectionId, frame, cancellationToken);
                break;

            case TransportType.Messaging:
                if (_messagingServer is null)
                {
                    throw new InvalidOperationException("Messaging transport is not enabled");
                }

                await _messagingServer.SendToMicroserviceAsync(connection.ConnectionId, frame, cancellationToken);
                break;

            default:
                throw new NotSupportedException($"Transport type {connection.TransportType} is not supported by the gateway.");
        }
    }

    /// <summary>
    /// Returns the frame's correlation ID, or a new GUID in "N" format when it is null or whitespace.
    /// </summary>
    private static string EnsureCorrelationId(Frame frame) =>
        string.IsNullOrWhiteSpace(frame.CorrelationId)
            ? Guid.NewGuid().ToString("N")
            : frame.CorrelationId;

    /// <summary>
    /// Sends <paramref name="requestBody"/> as <see cref="FrameType.RequestStreamData"/> chunks, then
    /// sends an empty <see cref="FrameType.RequestStreamData"/> frame to mark the end of the body.
    /// </summary>
    /// <exception cref="InvalidOperationException">The body exceeds <c>limits.MaxRequestBytesPerCall</c>.</exception>
    private async Task StreamRequestBodyAsync(
        ConnectionState connection,
        string correlationId,
        Stream requestBody,
        PayloadLimits limits,
        CancellationToken cancellationToken)
    {
        var buffer = ArrayPool<byte>.Shared.Rent(RequestBodyChunkSize);
        try
        {
            long totalBytesRead = 0;
            int bytesRead;

            // ReadAsync returns 0 only at end of stream, so data frames are never empty and cannot be
            // confused with the end marker.
            while ((bytesRead = await requestBody.ReadAsync(buffer, cancellationToken)) > 0)
            {
                totalBytesRead += bytesRead;

                if (totalBytesRead > limits.MaxRequestBytesPerCall)
                {
                    throw new InvalidOperationException(
                        $"Request body exceeds limit of {limits.MaxRequestBytesPerCall} bytes");
                }

                // Copy the chunk out of the pooled buffer. The next ReadAsync overwrites the buffer and
                // it goes back to the shared pool afterwards, so a transport that keeps the payload
                // after SendFrameAsync returns (e.g. by queueing the frame) must not alias it.
                var dataFrame = new Frame
                {
                    Type = FrameType.RequestStreamData,
                    CorrelationId = correlationId,
                    Payload = new ReadOnlyMemory<byte>(buffer.AsSpan(0, bytesRead).ToArray())
                };
                await SendFrameAsync(connection, dataFrame, cancellationToken);
            }

            var endFrame = new Frame
            {
                Type = FrameType.RequestStreamData,
                CorrelationId = correlationId,
                Payload = ReadOnlyMemory<byte>.Empty
            };
            await SendFrameAsync(connection, endFrame, cancellationToken);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(buffer);
        }
    }

    /// <summary>
    /// Copies reply payloads from <paramref name="reader"/> into <paramref name="responseStream"/>
    /// until a terminating frame arrives.
    /// </summary>
    /// <remarks>
    /// Non-empty <see cref="FrameType.ResponseStreamData"/> payloads are appended. An empty one ends
    /// the stream. A <see cref="FrameType.Response"/> frame appends its payload (if any) and ends the
    /// stream. Other frame types are ignored. The method also returns if the channel is completed.
    /// </remarks>
    private static async Task ReadStreamingResponseAsync(
        ChannelReader<Frame> reader,
        Stream responseStream,
        CancellationToken cancellationToken)
    {
        await foreach (var frame in reader.ReadAllAsync(cancellationToken))
        {
            switch (frame.Type)
            {
                case FrameType.ResponseStreamData when frame.Payload.IsEmpty:
                    return;

                case FrameType.ResponseStreamData:
                    await responseStream.WriteAsync(frame.Payload, cancellationToken);
                    break;

                case FrameType.Response:
                    if (!frame.Payload.IsEmpty)
                    {
                        await responseStream.WriteAsync(frame.Payload, cancellationToken);
                    }

                    return;
            }
        }
    }
}
|