Add unit tests for Router configuration and transport layers
- Implemented tests for RouterConfig, RoutingOptions, StaticInstanceConfig, and RouterConfigOptions to ensure default values are set correctly. - Added tests for RouterConfigProvider to validate configurations and ensure defaults are returned when no file is specified. - Created tests for ConfigValidationResult to check success and error scenarios. - Developed tests for ServiceCollectionExtensions to verify service registration for RouterConfig. - Introduced UdpTransportTests to validate serialization, connection, request-response, and error handling in UDP transport. - Added scripts for signing authority gaps and hashing DevPortal SDK snippets.
This commit is contained in:
@@ -0,0 +1,55 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
|
||||
namespace StellaOps.Microservice.SourceGen;
|
||||
|
||||
/// <summary>
|
||||
/// Diagnostic descriptors for the source generator.
|
||||
/// </summary>
|
||||
internal static class DiagnosticDescriptors
|
||||
{
|
||||
private const string Category = "StellaOps.Microservice";
|
||||
|
||||
/// <summary>
|
||||
/// Class with [StellaEndpoint] must implement IStellaEndpoint or IRawStellaEndpoint.
|
||||
/// </summary>
|
||||
public static readonly DiagnosticDescriptor MissingHandlerInterface = new(
|
||||
id: "STELLA001",
|
||||
title: "Missing handler interface",
|
||||
messageFormat: "Class '{0}' with [StellaEndpoint] must implement IStellaEndpoint<TRequest, TResponse> or IRawStellaEndpoint",
|
||||
category: Category,
|
||||
defaultSeverity: DiagnosticSeverity.Error,
|
||||
isEnabledByDefault: true);
|
||||
|
||||
/// <summary>
|
||||
/// Duplicate endpoint detected.
|
||||
/// </summary>
|
||||
public static readonly DiagnosticDescriptor DuplicateEndpoint = new(
|
||||
id: "STELLA002",
|
||||
title: "Duplicate endpoint",
|
||||
messageFormat: "Duplicate endpoint: {0} {1} is defined in both '{2}' and '{3}'",
|
||||
category: Category,
|
||||
defaultSeverity: DiagnosticSeverity.Warning,
|
||||
isEnabledByDefault: true);
|
||||
|
||||
/// <summary>
|
||||
/// [StellaEndpoint] on abstract class is ignored.
|
||||
/// </summary>
|
||||
public static readonly DiagnosticDescriptor AbstractClassIgnored = new(
|
||||
id: "STELLA003",
|
||||
title: "Abstract class ignored",
|
||||
messageFormat: "[StellaEndpoint] on abstract class '{0}' is ignored",
|
||||
category: Category,
|
||||
defaultSeverity: DiagnosticSeverity.Warning,
|
||||
isEnabledByDefault: true);
|
||||
|
||||
/// <summary>
|
||||
/// Informational: endpoints generated.
|
||||
/// </summary>
|
||||
public static readonly DiagnosticDescriptor EndpointsGenerated = new(
|
||||
id: "STELLA004",
|
||||
title: "Endpoints generated",
|
||||
messageFormat: "Generated {0} endpoint descriptors",
|
||||
category: Category,
|
||||
defaultSeverity: DiagnosticSeverity.Info,
|
||||
isEnabledByDefault: false);
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
namespace StellaOps.Microservice.SourceGen;
|
||||
|
||||
/// <summary>
|
||||
/// Holds extracted endpoint information from a [StellaEndpoint] decorated class.
|
||||
/// </summary>
|
||||
internal sealed record EndpointInfo(
|
||||
string Namespace,
|
||||
string ClassName,
|
||||
string FullyQualifiedName,
|
||||
string Method,
|
||||
string Path,
|
||||
int TimeoutSeconds,
|
||||
bool SupportsStreaming,
|
||||
string[] RequiredClaims,
|
||||
string? RequestTypeName,
|
||||
string? ResponseTypeName,
|
||||
bool IsRaw);
|
||||
@@ -1,13 +0,0 @@
|
||||
namespace StellaOps.Microservice.SourceGen;
|
||||
|
||||
/// <summary>
|
||||
/// Placeholder type for the source generator project.
|
||||
/// This will be replaced with actual source generator implementation in a later sprint.
|
||||
/// </summary>
|
||||
public static class Placeholder
|
||||
{
|
||||
/// <summary>
|
||||
/// Indicates the source generator is not yet implemented.
|
||||
/// </summary>
|
||||
public const string Status = "NotImplemented";
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Polyfills for netstandard2.0 compatibility
|
||||
|
||||
// ReSharper disable once CheckNamespace
|
||||
namespace System.Runtime.CompilerServices
|
||||
{
|
||||
/// <summary>
|
||||
/// Allows use of init accessors in netstandard2.0.
|
||||
/// </summary>
|
||||
internal static class IsExternalInit { }
|
||||
}
|
||||
@@ -0,0 +1,399 @@
|
||||
using System.Collections.Immutable;
|
||||
using System.Text;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
|
||||
namespace StellaOps.Microservice.SourceGen;
|
||||
|
||||
/// <summary>
|
||||
/// Incremental source generator for [StellaEndpoint] decorated classes.
|
||||
/// Generates endpoint descriptors and DI registration at compile time.
|
||||
/// </summary>
|
||||
[Generator]
|
||||
public sealed class StellaEndpointGenerator : IIncrementalGenerator
|
||||
{
|
||||
private const string StellaEndpointAttributeName = "StellaOps.Microservice.StellaEndpointAttribute";
|
||||
private const string IStellaEndpointName = "StellaOps.Microservice.IStellaEndpoint";
|
||||
private const string IRawStellaEndpointName = "StellaOps.Microservice.IRawStellaEndpoint";
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
// Find all class declarations with attributes
|
||||
var classDeclarations = context.SyntaxProvider
|
||||
.CreateSyntaxProvider(
|
||||
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
|
||||
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
|
||||
.Where(static m => m is not null);
|
||||
|
||||
// Combine all endpoints and generate
|
||||
var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect());
|
||||
|
||||
context.RegisterSourceOutput(
|
||||
compilationAndClasses,
|
||||
static (spc, source) => Execute(source.Left, source.Right!, spc));
|
||||
}
|
||||
|
||||
private static bool IsSyntaxTargetForGeneration(SyntaxNode node)
|
||||
{
|
||||
return node is ClassDeclarationSyntax { AttributeLists.Count: > 0 } classDecl
|
||||
&& !classDecl.Modifiers.Any(SyntaxKind.AbstractKeyword);
|
||||
}
|
||||
|
||||
private static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
|
||||
{
|
||||
var classDeclaration = (ClassDeclarationSyntax)context.Node;
|
||||
|
||||
foreach (var attributeList in classDeclaration.AttributeLists)
|
||||
{
|
||||
foreach (var attribute in attributeList.Attributes)
|
||||
{
|
||||
var symbolInfo = context.SemanticModel.GetSymbolInfo(attribute);
|
||||
var symbol = symbolInfo.Symbol;
|
||||
|
||||
if (symbol is not IMethodSymbol attributeSymbol)
|
||||
continue;
|
||||
|
||||
var attributeContainingType = attributeSymbol.ContainingType;
|
||||
var fullName = attributeContainingType.ToDisplayString();
|
||||
|
||||
if (fullName == StellaEndpointAttributeName)
|
||||
{
|
||||
return classDeclaration;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static void Execute(
|
||||
Compilation compilation,
|
||||
ImmutableArray<ClassDeclarationSyntax?> classes,
|
||||
SourceProductionContext context)
|
||||
{
|
||||
if (classes.IsDefaultOrEmpty)
|
||||
return;
|
||||
|
||||
var distinctClasses = classes.Where(c => c is not null).Distinct().Cast<ClassDeclarationSyntax>();
|
||||
var endpoints = new List<EndpointInfo>();
|
||||
|
||||
foreach (var classDeclaration in distinctClasses)
|
||||
{
|
||||
var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
|
||||
var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration);
|
||||
|
||||
if (classSymbol is null)
|
||||
continue;
|
||||
|
||||
var endpoint = ExtractEndpointInfo(classSymbol, context);
|
||||
if (endpoint is not null)
|
||||
{
|
||||
endpoints.Add(endpoint);
|
||||
}
|
||||
}
|
||||
|
||||
if (endpoints.Count == 0)
|
||||
return;
|
||||
|
||||
// Check for duplicates
|
||||
var seen = new Dictionary<(string Method, string Path), EndpointInfo>();
|
||||
foreach (var endpoint in endpoints)
|
||||
{
|
||||
var key = (endpoint.Method, endpoint.Path);
|
||||
if (seen.TryGetValue(key, out var existing))
|
||||
{
|
||||
context.ReportDiagnostic(Diagnostic.Create(
|
||||
DiagnosticDescriptors.DuplicateEndpoint,
|
||||
Location.None,
|
||||
endpoint.Method,
|
||||
endpoint.Path,
|
||||
existing.ClassName,
|
||||
endpoint.ClassName));
|
||||
}
|
||||
else
|
||||
{
|
||||
seen[key] = endpoint;
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the source
|
||||
var source = GenerateEndpointsClass(endpoints);
|
||||
context.AddSource("StellaEndpoints.g.cs", SourceText.From(source, Encoding.UTF8));
|
||||
|
||||
// Generate the provider class
|
||||
var providerSource = GenerateProviderClass();
|
||||
context.AddSource("GeneratedEndpointProvider.g.cs", SourceText.From(providerSource, Encoding.UTF8));
|
||||
}
|
||||
|
||||
private static EndpointInfo? ExtractEndpointInfo(INamedTypeSymbol classSymbol, SourceProductionContext context)
|
||||
{
|
||||
// Find StellaEndpoint attribute
|
||||
AttributeData? stellaAttribute = null;
|
||||
foreach (var attr in classSymbol.GetAttributes())
|
||||
{
|
||||
if (attr.AttributeClass?.ToDisplayString() == StellaEndpointAttributeName)
|
||||
{
|
||||
stellaAttribute = attr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (stellaAttribute is null)
|
||||
return null;
|
||||
|
||||
// Check for abstract class
|
||||
if (classSymbol.IsAbstract)
|
||||
{
|
||||
context.ReportDiagnostic(Diagnostic.Create(
|
||||
DiagnosticDescriptors.AbstractClassIgnored,
|
||||
Location.None,
|
||||
classSymbol.Name));
|
||||
return null;
|
||||
}
|
||||
|
||||
// Extract constructor arguments: method and path
|
||||
if (stellaAttribute.ConstructorArguments.Length < 2)
|
||||
return null;
|
||||
|
||||
var method = stellaAttribute.ConstructorArguments[0].Value as string ?? "GET";
|
||||
var path = stellaAttribute.ConstructorArguments[1].Value as string ?? "/";
|
||||
|
||||
// Extract named arguments
|
||||
var timeoutSeconds = 30;
|
||||
var supportsStreaming = false;
|
||||
var requiredClaims = Array.Empty<string>();
|
||||
|
||||
foreach (var namedArg in stellaAttribute.NamedArguments)
|
||||
{
|
||||
switch (namedArg.Key)
|
||||
{
|
||||
case "TimeoutSeconds":
|
||||
timeoutSeconds = (int)(namedArg.Value.Value ?? 30);
|
||||
break;
|
||||
case "SupportsStreaming":
|
||||
supportsStreaming = (bool)(namedArg.Value.Value ?? false);
|
||||
break;
|
||||
case "RequiredClaims":
|
||||
if (!namedArg.Value.IsNull && namedArg.Value.Values.Length > 0)
|
||||
{
|
||||
requiredClaims = namedArg.Value.Values
|
||||
.Select(v => v.Value as string)
|
||||
.Where(s => s is not null)
|
||||
.Cast<string>()
|
||||
.ToArray();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Find handler interface implementation
|
||||
string? requestTypeName = null;
|
||||
string? responseTypeName = null;
|
||||
bool isRaw = false;
|
||||
|
||||
foreach (var iface in classSymbol.AllInterfaces)
|
||||
{
|
||||
var fullName = iface.OriginalDefinition.ToDisplayString();
|
||||
|
||||
if (fullName.StartsWith(IStellaEndpointName) && iface.TypeArguments.Length == 2)
|
||||
{
|
||||
requestTypeName = iface.TypeArguments[0].ToDisplayString();
|
||||
responseTypeName = iface.TypeArguments[1].ToDisplayString();
|
||||
isRaw = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (fullName == IRawStellaEndpointName)
|
||||
{
|
||||
isRaw = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If no handler interface found, report error
|
||||
if (!isRaw && requestTypeName is null)
|
||||
{
|
||||
context.ReportDiagnostic(Diagnostic.Create(
|
||||
DiagnosticDescriptors.MissingHandlerInterface,
|
||||
Location.None,
|
||||
classSymbol.Name));
|
||||
return null;
|
||||
}
|
||||
|
||||
var ns = classSymbol.ContainingNamespace.IsGlobalNamespace
|
||||
? string.Empty
|
||||
: classSymbol.ContainingNamespace.ToDisplayString();
|
||||
|
||||
return new EndpointInfo(
|
||||
Namespace: ns,
|
||||
ClassName: classSymbol.Name,
|
||||
FullyQualifiedName: classSymbol.ToDisplayString(),
|
||||
Method: method.ToUpperInvariant(),
|
||||
Path: path,
|
||||
TimeoutSeconds: timeoutSeconds,
|
||||
SupportsStreaming: supportsStreaming,
|
||||
RequiredClaims: requiredClaims,
|
||||
RequestTypeName: requestTypeName,
|
||||
ResponseTypeName: responseTypeName,
|
||||
IsRaw: isRaw);
|
||||
}
|
||||
|
||||
private static string GenerateEndpointsClass(List<EndpointInfo> endpoints)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
|
||||
sb.AppendLine("// <auto-generated/>");
|
||||
sb.AppendLine("#nullable enable");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine("namespace StellaOps.Microservice.Generated");
|
||||
sb.AppendLine("{");
|
||||
sb.AppendLine(" /// <summary>");
|
||||
sb.AppendLine(" /// Auto-generated endpoint metadata and registration.");
|
||||
sb.AppendLine(" /// </summary>");
|
||||
sb.AppendLine(" [global::System.CodeDom.Compiler.GeneratedCode(\"StellaOps.Microservice.SourceGen\", \"1.0.0\")]");
|
||||
sb.AppendLine(" internal static class StellaEndpoints");
|
||||
sb.AppendLine(" {");
|
||||
|
||||
// GetEndpoints method
|
||||
sb.AppendLine(" /// <summary>");
|
||||
sb.AppendLine(" /// Gets all discovered endpoint descriptors.");
|
||||
sb.AppendLine(" /// </summary>");
|
||||
sb.AppendLine(" public static global::System.Collections.Generic.IReadOnlyList<global::StellaOps.Router.Common.Models.EndpointDescriptor> GetEndpoints()");
|
||||
sb.AppendLine(" {");
|
||||
sb.AppendLine(" return new global::StellaOps.Router.Common.Models.EndpointDescriptor[]");
|
||||
sb.AppendLine(" {");
|
||||
|
||||
for (int i = 0; i < endpoints.Count; i++)
|
||||
{
|
||||
var ep = endpoints[i];
|
||||
sb.AppendLine(" new global::StellaOps.Router.Common.Models.EndpointDescriptor");
|
||||
sb.AppendLine(" {");
|
||||
sb.AppendLine(" ServiceName = string.Empty, // Set by SDK at registration");
|
||||
sb.AppendLine(" Version = string.Empty, // Set by SDK at registration");
|
||||
sb.AppendLine($" Method = \"{EscapeString(ep.Method)}\",");
|
||||
sb.AppendLine($" Path = \"{EscapeString(ep.Path)}\",");
|
||||
sb.AppendLine($" DefaultTimeout = global::System.TimeSpan.FromSeconds({ep.TimeoutSeconds}),");
|
||||
sb.AppendLine($" SupportsStreaming = {(ep.SupportsStreaming ? "true" : "false")},");
|
||||
sb.Append(" RequiringClaims = ");
|
||||
if (ep.RequiredClaims.Length == 0)
|
||||
{
|
||||
sb.AppendLine("new global::System.Collections.Generic.List<global::StellaOps.Router.Common.Models.ClaimRequirement>(),");
|
||||
}
|
||||
else
|
||||
{
|
||||
sb.AppendLine("new global::System.Collections.Generic.List<global::StellaOps.Router.Common.Models.ClaimRequirement>");
|
||||
sb.AppendLine(" {");
|
||||
foreach (var claim in ep.RequiredClaims)
|
||||
{
|
||||
sb.AppendLine($" new global::StellaOps.Router.Common.Models.ClaimRequirement {{ Type = \"{EscapeString(claim)}\", Value = null }},");
|
||||
}
|
||||
sb.AppendLine(" },");
|
||||
}
|
||||
sb.AppendLine($" HandlerType = typeof(global::{ep.FullyQualifiedName})");
|
||||
sb.Append(" }");
|
||||
if (i < endpoints.Count - 1)
|
||||
{
|
||||
sb.AppendLine(",");
|
||||
}
|
||||
else
|
||||
{
|
||||
sb.AppendLine();
|
||||
}
|
||||
}
|
||||
|
||||
sb.AppendLine(" };");
|
||||
sb.AppendLine(" }");
|
||||
sb.AppendLine();
|
||||
|
||||
// RegisterHandlers method
|
||||
sb.AppendLine(" /// <summary>");
|
||||
sb.AppendLine(" /// Registers all endpoint handlers with the service collection.");
|
||||
sb.AppendLine(" /// </summary>");
|
||||
sb.AppendLine(" public static void RegisterHandlers(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)");
|
||||
sb.AppendLine(" {");
|
||||
|
||||
foreach (var ep in endpoints)
|
||||
{
|
||||
sb.AppendLine($" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient<global::{ep.FullyQualifiedName}>(services);");
|
||||
}
|
||||
|
||||
sb.AppendLine(" }");
|
||||
|
||||
// GetHandlerTypes method
|
||||
sb.AppendLine();
|
||||
sb.AppendLine(" /// <summary>");
|
||||
sb.AppendLine(" /// Gets all handler types for endpoint discovery.");
|
||||
sb.AppendLine(" /// </summary>");
|
||||
sb.AppendLine(" public static global::System.Collections.Generic.IReadOnlyList<global::System.Type> GetHandlerTypes()");
|
||||
sb.AppendLine(" {");
|
||||
sb.AppendLine(" return new global::System.Type[]");
|
||||
sb.AppendLine(" {");
|
||||
|
||||
for (int i = 0; i < endpoints.Count; i++)
|
||||
{
|
||||
var ep = endpoints[i];
|
||||
sb.Append($" typeof(global::{ep.FullyQualifiedName})");
|
||||
if (i < endpoints.Count - 1)
|
||||
{
|
||||
sb.AppendLine(",");
|
||||
}
|
||||
else
|
||||
{
|
||||
sb.AppendLine();
|
||||
}
|
||||
}
|
||||
|
||||
sb.AppendLine(" };");
|
||||
sb.AppendLine(" }");
|
||||
|
||||
sb.AppendLine(" }");
|
||||
sb.AppendLine("}");
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static string GenerateProviderClass()
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
|
||||
sb.AppendLine("// <auto-generated/>");
|
||||
sb.AppendLine("#nullable enable");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine("namespace StellaOps.Microservice.Generated");
|
||||
sb.AppendLine("{");
|
||||
sb.AppendLine(" /// <summary>");
|
||||
sb.AppendLine(" /// Generated implementation of IGeneratedEndpointProvider.");
|
||||
sb.AppendLine(" /// </summary>");
|
||||
sb.AppendLine(" [global::System.CodeDom.Compiler.GeneratedCode(\"StellaOps.Microservice.SourceGen\", \"1.0.0\")]");
|
||||
sb.AppendLine(" internal sealed class GeneratedEndpointProvider : global::StellaOps.Microservice.IGeneratedEndpointProvider");
|
||||
sb.AppendLine(" {");
|
||||
sb.AppendLine(" /// <inheritdoc />");
|
||||
sb.AppendLine(" public global::System.Collections.Generic.IReadOnlyList<global::StellaOps.Router.Common.Models.EndpointDescriptor> GetEndpoints()");
|
||||
sb.AppendLine(" => StellaEndpoints.GetEndpoints();");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine(" /// <inheritdoc />");
|
||||
sb.AppendLine(" public void RegisterHandlers(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)");
|
||||
sb.AppendLine(" => StellaEndpoints.RegisterHandlers(services);");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine(" /// <inheritdoc />");
|
||||
sb.AppendLine(" public global::System.Collections.Generic.IReadOnlyList<global::System.Type> GetHandlerTypes()");
|
||||
sb.AppendLine(" => StellaEndpoints.GetHandlerTypes();");
|
||||
sb.AppendLine(" }");
|
||||
sb.AppendLine("}");
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static string EscapeString(string value)
|
||||
{
|
||||
return value
|
||||
.Replace("\\", "\\\\")
|
||||
.Replace("\"", "\\\"")
|
||||
.Replace("\n", "\\n")
|
||||
.Replace("\r", "\\r")
|
||||
.Replace("\t", "\\t");
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,32 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<!-- Source generators must target netstandard2.0 for Roslyn compatibility -->
|
||||
<TargetFramework>netstandard2.0</TargetFramework>
|
||||
<LangVersion>12.0</LangVersion>
|
||||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
|
||||
<!-- Source generator specific settings -->
|
||||
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
|
||||
<IsRoslynComponent>true</IsRoslynComponent>
|
||||
<IncludeBuildOutput>false</IncludeBuildOutput>
|
||||
|
||||
<!-- Package settings for distribution -->
|
||||
<PackageId>StellaOps.Microservice.SourceGen</PackageId>
|
||||
<Description>Source generator for Stella microservice endpoints</Description>
|
||||
<DevelopmentDependency>true</DevelopmentDependency>
|
||||
<IsPackable>true</IsPackable>
|
||||
<NoWarn>$(NoWarn);NU5128;RS2008</NoWarn>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" />
|
||||
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" PrivateAssets="all" />
|
||||
</ItemGroup>
|
||||
|
||||
<!-- Pack the analyzer as an analyzer -->
|
||||
<ItemGroup>
|
||||
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for discovering endpoints with YAML configuration support.
|
||||
/// </summary>
|
||||
public interface IEndpointDiscoveryService
|
||||
{
|
||||
/// <summary>
|
||||
/// Discovers all endpoints, applying any YAML configuration overrides.
|
||||
/// </summary>
|
||||
/// <returns>The discovered endpoints with overrides applied.</returns>
|
||||
IReadOnlyList<EndpointDescriptor> DiscoverEndpoints();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Service that discovers endpoints and applies YAML configuration overrides.
|
||||
/// </summary>
|
||||
public sealed class EndpointDiscoveryService : IEndpointDiscoveryService
|
||||
{
|
||||
private readonly IEndpointDiscoveryProvider _discoveryProvider;
|
||||
private readonly IMicroserviceYamlLoader _yamlLoader;
|
||||
private readonly IEndpointOverrideMerger _merger;
|
||||
private readonly ILogger<EndpointDiscoveryService> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EndpointDiscoveryService"/> class.
|
||||
/// </summary>
|
||||
public EndpointDiscoveryService(
|
||||
IEndpointDiscoveryProvider discoveryProvider,
|
||||
IMicroserviceYamlLoader yamlLoader,
|
||||
IEndpointOverrideMerger merger,
|
||||
ILogger<EndpointDiscoveryService> logger)
|
||||
{
|
||||
_discoveryProvider = discoveryProvider;
|
||||
_yamlLoader = yamlLoader;
|
||||
_merger = merger;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<EndpointDescriptor> DiscoverEndpoints()
|
||||
{
|
||||
// 1. Discover endpoints from code (via reflection or source gen)
|
||||
var codeEndpoints = _discoveryProvider.DiscoverEndpoints();
|
||||
_logger.LogDebug("Discovered {Count} endpoints from code", codeEndpoints.Count);
|
||||
|
||||
// 2. Load YAML overrides
|
||||
MicroserviceYamlConfig? yamlConfig = null;
|
||||
try
|
||||
{
|
||||
yamlConfig = _yamlLoader.Load();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to load YAML configuration, using code defaults only");
|
||||
}
|
||||
|
||||
// 3. Merge code endpoints with YAML overrides
|
||||
var mergedEndpoints = _merger.Merge(codeEndpoints, yamlConfig);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Endpoint discovery complete: {Count} endpoints (YAML overrides: {HasYaml})",
|
||||
mergedEndpoints.Count,
|
||||
yamlConfig != null);
|
||||
|
||||
return mergedEndpoints;
|
||||
}
|
||||
}
|
||||
115
src/__Libraries/StellaOps.Microservice/EndpointOverrideMerger.cs
Normal file
115
src/__Libraries/StellaOps.Microservice/EndpointOverrideMerger.cs
Normal file
@@ -0,0 +1,115 @@
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for merging endpoint overrides from YAML configuration.
|
||||
/// </summary>
|
||||
public interface IEndpointOverrideMerger
|
||||
{
|
||||
/// <summary>
|
||||
/// Merges YAML overrides with code-defined endpoints.
|
||||
/// </summary>
|
||||
/// <param name="codeEndpoints">The endpoints discovered from code.</param>
|
||||
/// <param name="yamlConfig">The YAML configuration, if any.</param>
|
||||
/// <returns>The merged endpoints.</returns>
|
||||
IReadOnlyList<EndpointDescriptor> Merge(
|
||||
IReadOnlyList<EndpointDescriptor> codeEndpoints,
|
||||
MicroserviceYamlConfig? yamlConfig);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Merges endpoint overrides from YAML configuration with code defaults.
|
||||
/// </summary>
|
||||
public sealed class EndpointOverrideMerger : IEndpointOverrideMerger
|
||||
{
|
||||
private readonly ILogger<EndpointOverrideMerger> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EndpointOverrideMerger"/> class.
|
||||
/// </summary>
|
||||
public EndpointOverrideMerger(ILogger<EndpointOverrideMerger> logger)
|
||||
{
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<EndpointDescriptor> Merge(
|
||||
IReadOnlyList<EndpointDescriptor> codeEndpoints,
|
||||
MicroserviceYamlConfig? yamlConfig)
|
||||
{
|
||||
if (yamlConfig == null || yamlConfig.Endpoints.Count == 0)
|
||||
{
|
||||
return codeEndpoints;
|
||||
}
|
||||
|
||||
WarnUnmatchedOverrides(codeEndpoints, yamlConfig);
|
||||
|
||||
return codeEndpoints.Select(ep =>
|
||||
{
|
||||
var yamlOverride = FindMatchingOverride(ep, yamlConfig);
|
||||
return yamlOverride == null ? ep : MergeEndpoint(ep, yamlOverride);
|
||||
}).ToList();
|
||||
}
|
||||
|
||||
private static EndpointOverrideConfig? FindMatchingOverride(
|
||||
EndpointDescriptor endpoint,
|
||||
MicroserviceYamlConfig yamlConfig)
|
||||
{
|
||||
return yamlConfig.Endpoints.FirstOrDefault(y =>
|
||||
string.Equals(y.Method, endpoint.Method, StringComparison.OrdinalIgnoreCase) &&
|
||||
string.Equals(y.Path, endpoint.Path, StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
private EndpointDescriptor MergeEndpoint(
|
||||
EndpointDescriptor codeDefault,
|
||||
EndpointOverrideConfig yamlOverride)
|
||||
{
|
||||
var merged = codeDefault with
|
||||
{
|
||||
DefaultTimeout = yamlOverride.GetDefaultTimeoutAsTimeSpan() ?? codeDefault.DefaultTimeout,
|
||||
SupportsStreaming = yamlOverride.SupportsStreaming ?? codeDefault.SupportsStreaming,
|
||||
RequiringClaims = yamlOverride.RequiringClaims?.Count > 0
|
||||
? yamlOverride.RequiringClaims.Select(c => c.ToClaimRequirement()).ToList()
|
||||
: codeDefault.RequiringClaims
|
||||
};
|
||||
|
||||
if (yamlOverride.GetDefaultTimeoutAsTimeSpan().HasValue ||
|
||||
yamlOverride.SupportsStreaming.HasValue ||
|
||||
yamlOverride.RequiringClaims?.Count > 0)
|
||||
{
|
||||
_logger.LogDebug(
|
||||
"Applied YAML overrides to endpoint {Method} {Path}: Timeout={Timeout}, Streaming={Streaming}, Claims={Claims}",
|
||||
merged.Method,
|
||||
merged.Path,
|
||||
merged.DefaultTimeout,
|
||||
merged.SupportsStreaming,
|
||||
merged.RequiringClaims?.Count ?? 0);
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
private void WarnUnmatchedOverrides(
|
||||
IReadOnlyList<EndpointDescriptor> codeEndpoints,
|
||||
MicroserviceYamlConfig yamlConfig)
|
||||
{
|
||||
var codeKeys = codeEndpoints
|
||||
.Select(e => (Method: e.Method.ToUpperInvariant(), Path: e.Path.ToLowerInvariant()))
|
||||
.ToHashSet();
|
||||
|
||||
foreach (var yamlEntry in yamlConfig.Endpoints)
|
||||
{
|
||||
var key = (Method: yamlEntry.Method.ToUpperInvariant(), Path: yamlEntry.Path.ToLowerInvariant());
|
||||
if (!codeKeys.Contains(key))
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"YAML override for {Method} {Path} does not match any code endpoint. " +
|
||||
"YAML cannot create endpoints, only modify existing ones.",
|
||||
yamlEntry.Method,
|
||||
yamlEntry.Path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
using System.Reflection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Discovers endpoints using source-generated provider, falling back to reflection.
|
||||
/// </summary>
|
||||
public sealed class GeneratedEndpointDiscoveryProvider : IEndpointDiscoveryProvider
|
||||
{
|
||||
private readonly StellaMicroserviceOptions _options;
|
||||
private readonly ILogger<GeneratedEndpointDiscoveryProvider> _logger;
|
||||
private readonly ReflectionEndpointDiscoveryProvider _reflectionFallback;
|
||||
|
||||
private const string GeneratedProviderTypeName = "StellaOps.Microservice.Generated.GeneratedEndpointProvider";
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="GeneratedEndpointDiscoveryProvider"/> class.
|
||||
/// </summary>
|
||||
public GeneratedEndpointDiscoveryProvider(
|
||||
StellaMicroserviceOptions options,
|
||||
ILogger<GeneratedEndpointDiscoveryProvider> logger)
|
||||
{
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
_reflectionFallback = new ReflectionEndpointDiscoveryProvider(options);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<EndpointDescriptor> DiscoverEndpoints()
|
||||
{
|
||||
// Try to find the generated provider
|
||||
var generatedProvider = TryGetGeneratedProvider();
|
||||
|
||||
if (generatedProvider != null)
|
||||
{
|
||||
_logger.LogDebug("Using source-generated endpoint discovery");
|
||||
var endpoints = generatedProvider.GetEndpoints();
|
||||
|
||||
// Apply service name and version from options
|
||||
var result = new List<EndpointDescriptor>();
|
||||
foreach (var endpoint in endpoints)
|
||||
{
|
||||
result.Add(endpoint with
|
||||
{
|
||||
ServiceName = _options.ServiceName,
|
||||
Version = _options.Version
|
||||
});
|
||||
}
|
||||
|
||||
_logger.LogInformation(
|
||||
"Discovered {Count} endpoints via source generation",
|
||||
result.Count);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Fall back to reflection
|
||||
_logger.LogDebug("Source-generated provider not found, falling back to reflection");
|
||||
return _reflectionFallback.DiscoverEndpoints();
|
||||
}
|
||||
|
||||
private IGeneratedEndpointProvider? TryGetGeneratedProvider()
|
||||
{
|
||||
try
|
||||
{
|
||||
// Look in the entry assembly first
|
||||
var entryAssembly = Assembly.GetEntryAssembly();
|
||||
var providerType = entryAssembly?.GetType(GeneratedProviderTypeName);
|
||||
|
||||
if (providerType != null)
|
||||
{
|
||||
return (IGeneratedEndpointProvider)Activator.CreateInstance(providerType)!;
|
||||
}
|
||||
|
||||
// Also check the calling assembly
|
||||
var callingAssembly = Assembly.GetCallingAssembly();
|
||||
providerType = callingAssembly.GetType(GeneratedProviderTypeName);
|
||||
|
||||
if (providerType != null)
|
||||
{
|
||||
return (IGeneratedEndpointProvider)Activator.CreateInstance(providerType)!;
|
||||
}
|
||||
|
||||
// Check all loaded assemblies
|
||||
foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
|
||||
{
|
||||
try
|
||||
{
|
||||
providerType = assembly.GetType(GeneratedProviderTypeName);
|
||||
if (providerType != null)
|
||||
{
|
||||
return (IGeneratedEndpointProvider)Activator.CreateInstance(providerType)!;
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore assembly loading errors
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Failed to load generated endpoint provider");
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Interface implemented by the source-generated endpoint provider.
|
||||
/// </summary>
|
||||
public interface IGeneratedEndpointProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets all discovered endpoint descriptors.
|
||||
/// </summary>
|
||||
IReadOnlyList<EndpointDescriptor> GetEndpoints();
|
||||
|
||||
/// <summary>
|
||||
/// Registers all endpoint handlers with the service collection.
|
||||
/// </summary>
|
||||
void RegisterHandlers(IServiceCollection services);
|
||||
|
||||
/// <summary>
|
||||
/// Gets all handler types for endpoint discovery.
|
||||
/// </summary>
|
||||
IReadOnlyList<Type> GetHandlerTypes();
|
||||
}
|
||||
145
src/__Libraries/StellaOps.Microservice/InflightRequestTracker.cs
Normal file
145
src/__Libraries/StellaOps.Microservice/InflightRequestTracker.cs
Normal file
@@ -0,0 +1,145 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks in-flight requests and manages their cancellation tokens.
|
||||
/// </summary>
|
||||
public sealed class InflightRequestTracker : IDisposable
|
||||
{
|
||||
private readonly ConcurrentDictionary<Guid, InflightRequest> _inflight = new();
|
||||
private readonly ILogger<InflightRequestTracker> _logger;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InflightRequestTracker"/> class.
|
||||
/// </summary>
|
||||
public InflightRequestTracker(ILogger<InflightRequestTracker> logger)
|
||||
{
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the count of in-flight requests.
|
||||
/// </summary>
|
||||
public int Count => _inflight.Count;
|
||||
|
||||
/// <summary>
|
||||
/// Starts tracking a request and returns a cancellation token for it.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the request.</param>
|
||||
/// <returns>A cancellation token that will be triggered if the request is cancelled.</returns>
|
||||
public CancellationToken Track(Guid correlationId)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var cts = new CancellationTokenSource();
|
||||
var request = new InflightRequest(cts);
|
||||
|
||||
if (!_inflight.TryAdd(correlationId, request))
|
||||
{
|
||||
cts.Dispose();
|
||||
throw new InvalidOperationException($"Request {correlationId} is already being tracked");
|
||||
}
|
||||
|
||||
_logger.LogDebug("Started tracking request {CorrelationId}", correlationId);
|
||||
return cts.Token;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels a specific request.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the request to cancel.</param>
|
||||
/// <param name="reason">The reason for cancellation.</param>
|
||||
/// <returns>True if the request was found and cancelled; otherwise false.</returns>
|
||||
public bool Cancel(Guid correlationId, string? reason)
|
||||
{
|
||||
if (_inflight.TryGetValue(correlationId, out var request))
|
||||
{
|
||||
try
|
||||
{
|
||||
request.Cts.Cancel();
|
||||
_logger.LogInformation(
|
||||
"Cancelled request {CorrelationId}: {Reason}",
|
||||
correlationId,
|
||||
reason ?? "Unknown");
|
||||
return true;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// CTS was already disposed, request completed
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogDebug(
|
||||
"Cannot cancel request {CorrelationId}: not found (may have already completed)",
|
||||
correlationId);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Marks a request as completed and removes it from tracking.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the completed request.</param>
|
||||
public void Complete(Guid correlationId)
|
||||
{
|
||||
if (_inflight.TryRemove(correlationId, out var request))
|
||||
{
|
||||
request.Cts.Dispose();
|
||||
_logger.LogDebug("Completed request {CorrelationId}", correlationId);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all in-flight requests.
|
||||
/// </summary>
|
||||
/// <param name="reason">The reason for cancellation.</param>
|
||||
public void CancelAll(string reason)
|
||||
{
|
||||
var count = 0;
|
||||
foreach (var kvp in _inflight)
|
||||
{
|
||||
try
|
||||
{
|
||||
kvp.Value.Cts.Cancel();
|
||||
count++;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already disposed
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Cancelled {Count} in-flight requests: {Reason}", count, reason);
|
||||
|
||||
// Clear and dispose all
|
||||
foreach (var kvp in _inflight)
|
||||
{
|
||||
if (_inflight.TryRemove(kvp.Key, out var request))
|
||||
{
|
||||
request.Cts.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
CancelAll("Disposing tracker");
|
||||
}
|
||||
|
||||
private sealed class InflightRequest
|
||||
{
|
||||
public CancellationTokenSource Cts { get; }
|
||||
|
||||
public InflightRequest(CancellationTokenSource cts)
|
||||
{
|
||||
Cts = cts;
|
||||
}
|
||||
}
|
||||
}
|
||||
113
src/__Libraries/StellaOps.Microservice/MicroserviceYamlConfig.cs
Normal file
113
src/__Libraries/StellaOps.Microservice/MicroserviceYamlConfig.cs
Normal file
@@ -0,0 +1,113 @@
|
||||
using StellaOps.Router.Common.Models;
|
||||
using YamlDotNet.Serialization;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Root configuration for microservice endpoint overrides loaded from YAML.
|
||||
/// </summary>
|
||||
public sealed class MicroserviceYamlConfig
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the endpoint override configurations.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "endpoints")]
|
||||
public List<EndpointOverrideConfig> Endpoints { get; set; } = [];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Configuration for overriding an endpoint's properties.
|
||||
/// </summary>
|
||||
public sealed class EndpointOverrideConfig
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the HTTP method to match.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "method")]
|
||||
public string Method { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the path to match.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "path")]
|
||||
public string Path { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the default timeout override.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "defaultTimeout")]
|
||||
public string? DefaultTimeout { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether streaming is supported.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "supportsStreaming")]
|
||||
public bool? SupportsStreaming { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the claim requirements.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "requiringClaims")]
|
||||
public List<ClaimRequirementConfig>? RequiringClaims { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Parses the DefaultTimeout string to a TimeSpan.
|
||||
/// </summary>
|
||||
public TimeSpan? GetDefaultTimeoutAsTimeSpan()
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(DefaultTimeout))
|
||||
return null;
|
||||
|
||||
// Handle formats like "30s", "5m", "1h", or "00:00:30"
|
||||
var value = DefaultTimeout.Trim();
|
||||
|
||||
if (value.EndsWith("s", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
if (int.TryParse(value[..^1], out var seconds))
|
||||
return TimeSpan.FromSeconds(seconds);
|
||||
}
|
||||
else if (value.EndsWith("m", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
if (int.TryParse(value[..^1], out var minutes))
|
||||
return TimeSpan.FromMinutes(minutes);
|
||||
}
|
||||
else if (value.EndsWith("h", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
if (int.TryParse(value[..^1], out var hours))
|
||||
return TimeSpan.FromHours(hours);
|
||||
}
|
||||
else if (TimeSpan.TryParse(value, out var timespan))
|
||||
{
|
||||
return timespan;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Configuration for a claim requirement.
|
||||
/// </summary>
|
||||
public sealed class ClaimRequirementConfig
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the claim type.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "type")]
|
||||
public string Type { get; set; } = string.Empty;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the claim value.
|
||||
/// </summary>
|
||||
[YamlMember(Alias = "value")]
|
||||
public string? Value { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Converts to a ClaimRequirement model.
|
||||
/// </summary>
|
||||
public ClaimRequirement ToClaimRequirement() => new()
|
||||
{
|
||||
Type = Type,
|
||||
Value = Value
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
using Microsoft.Extensions.Logging;
|
||||
using YamlDotNet.Serialization;
|
||||
using YamlDotNet.Serialization.NamingConventions;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for loading microservice YAML configuration.
|
||||
/// </summary>
|
||||
public interface IMicroserviceYamlLoader
|
||||
{
|
||||
/// <summary>
|
||||
/// Loads the microservice configuration from YAML.
|
||||
/// </summary>
|
||||
/// <returns>The configuration, or null if no file is configured or file doesn't exist.</returns>
|
||||
MicroserviceYamlConfig? Load();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Loads microservice configuration from a YAML file.
|
||||
/// </summary>
|
||||
public sealed class MicroserviceYamlLoader : IMicroserviceYamlLoader
|
||||
{
|
||||
private readonly StellaMicroserviceOptions _options;
|
||||
private readonly ILogger<MicroserviceYamlLoader> _logger;
|
||||
private readonly IDeserializer _deserializer;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="MicroserviceYamlLoader"/> class.
|
||||
/// </summary>
|
||||
public MicroserviceYamlLoader(
|
||||
StellaMicroserviceOptions options,
|
||||
ILogger<MicroserviceYamlLoader> logger)
|
||||
{
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
_deserializer = new DeserializerBuilder()
|
||||
.WithNamingConvention(CamelCaseNamingConvention.Instance)
|
||||
.IgnoreUnmatchedProperties()
|
||||
.Build();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public MicroserviceYamlConfig? Load()
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(_options.ConfigFilePath))
|
||||
{
|
||||
_logger.LogDebug("No ConfigFilePath specified, skipping YAML configuration");
|
||||
return null;
|
||||
}
|
||||
|
||||
var fullPath = Path.GetFullPath(_options.ConfigFilePath);
|
||||
|
||||
if (!File.Exists(fullPath))
|
||||
{
|
||||
_logger.LogDebug("Configuration file {Path} does not exist, skipping", fullPath);
|
||||
return null;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var yaml = File.ReadAllText(fullPath);
|
||||
var config = _deserializer.Deserialize<MicroserviceYamlConfig>(yaml);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Loaded microservice configuration from {Path} with {Count} endpoint overrides",
|
||||
fullPath,
|
||||
config?.Endpoints?.Count ?? 0);
|
||||
|
||||
return config;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to load microservice configuration from {Path}", fullPath);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,85 +1,2 @@
|
||||
using System.Text.RegularExpressions;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
/// <summary>
|
||||
/// Matches request paths against route templates.
|
||||
/// </summary>
|
||||
public sealed partial class PathMatcher
|
||||
{
|
||||
private readonly string _template;
|
||||
private readonly Regex _regex;
|
||||
private readonly string[] _parameterNames;
|
||||
private readonly bool _caseInsensitive;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the route template.
|
||||
/// </summary>
|
||||
public string Template => _template;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PathMatcher"/> class.
|
||||
/// </summary>
|
||||
/// <param name="template">The route template (e.g., "/api/users/{id}").</param>
|
||||
/// <param name="caseInsensitive">Whether matching should be case-insensitive.</param>
|
||||
public PathMatcher(string template, bool caseInsensitive = true)
|
||||
{
|
||||
_template = template;
|
||||
_caseInsensitive = caseInsensitive;
|
||||
|
||||
// Extract parameter names and build regex
|
||||
var paramNames = new List<string>();
|
||||
var pattern = "^" + ParameterRegex().Replace(template, match =>
|
||||
{
|
||||
paramNames.Add(match.Groups[1].Value);
|
||||
return "([^/]+)";
|
||||
}) + "/?$";
|
||||
|
||||
var options = caseInsensitive ? RegexOptions.IgnoreCase : RegexOptions.None;
|
||||
_regex = new Regex(pattern, options | RegexOptions.Compiled);
|
||||
_parameterNames = [.. paramNames];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to match a path against the template.
|
||||
/// </summary>
|
||||
/// <param name="path">The request path.</param>
|
||||
/// <param name="parameters">The extracted path parameters if matched.</param>
|
||||
/// <returns>True if the path matches.</returns>
|
||||
public bool TryMatch(string path, out Dictionary<string, string> parameters)
|
||||
{
|
||||
parameters = [];
|
||||
|
||||
// Normalize path
|
||||
path = path.TrimEnd('/');
|
||||
if (!path.StartsWith('/'))
|
||||
path = "/" + path;
|
||||
|
||||
var match = _regex.Match(path);
|
||||
if (!match.Success)
|
||||
return false;
|
||||
|
||||
for (int i = 0; i < _parameterNames.Length; i++)
|
||||
{
|
||||
parameters[_parameterNames[i]] = match.Groups[i + 1].Value;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a path matches the template.
|
||||
/// </summary>
|
||||
/// <param name="path">The request path.</param>
|
||||
/// <returns>True if the path matches.</returns>
|
||||
public bool IsMatch(string path)
|
||||
{
|
||||
path = path.TrimEnd('/');
|
||||
if (!path.StartsWith('/'))
|
||||
path = "/" + path;
|
||||
return _regex.IsMatch(path);
|
||||
}
|
||||
|
||||
[GeneratedRegex(@"\{([^}:]+)(?::[^}]+)?\}")]
|
||||
private static partial Regex ParameterRegex();
|
||||
}
|
||||
// Re-export PathMatcher from Router.Common for backwards compatibility
|
||||
global using PathMatcher = StellaOps.Router.Common.PathMatcher;
|
||||
|
||||
@@ -2,6 +2,7 @@ using System.Text.Json;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Frames;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Microservice;
|
||||
|
||||
@@ -116,6 +117,13 @@ public sealed class RequestDispatcher
|
||||
RawRequestContext context,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
// Ensure handler type is set
|
||||
if (endpoint.HandlerType is null)
|
||||
{
|
||||
_logger.LogError("Endpoint {Method} {Path} has no handler type", endpoint.Method, endpoint.Path);
|
||||
return RawResponse.InternalError("No handler configured");
|
||||
}
|
||||
|
||||
// Get handler instance from DI
|
||||
var handler = scopedProvider.GetService(endpoint.HandlerType);
|
||||
if (handler is null)
|
||||
|
||||
@@ -14,13 +14,16 @@ public sealed class RouterConnectionManager : IRouterConnectionManager, IDisposa
|
||||
{
|
||||
private readonly StellaMicroserviceOptions _options;
|
||||
private readonly IEndpointDiscoveryProvider _endpointDiscovery;
|
||||
private readonly ITransportClient _transportClient;
|
||||
private readonly IMicroserviceTransport? _microserviceTransport;
|
||||
private readonly ILogger<RouterConnectionManager> _logger;
|
||||
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new();
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private IReadOnlyList<EndpointDescriptor>? _endpoints;
|
||||
private Task? _heartbeatTask;
|
||||
private bool _disposed;
|
||||
private volatile InstanceHealthStatus _currentStatus = InstanceHealthStatus.Healthy;
|
||||
private int _inFlightRequestCount;
|
||||
private double _errorRate;
|
||||
|
||||
/// <inheritdoc />
|
||||
public IReadOnlyList<ConnectionState> Connections => [.. _connections.Values];
|
||||
@@ -31,15 +34,42 @@ public sealed class RouterConnectionManager : IRouterConnectionManager, IDisposa
|
||||
public RouterConnectionManager(
|
||||
IOptions<StellaMicroserviceOptions> options,
|
||||
IEndpointDiscoveryProvider endpointDiscovery,
|
||||
ITransportClient transportClient,
|
||||
IMicroserviceTransport? microserviceTransport,
|
||||
ILogger<RouterConnectionManager> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_endpointDiscovery = endpointDiscovery;
|
||||
_transportClient = transportClient;
|
||||
_microserviceTransport = microserviceTransport;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the current health status reported by this instance.
|
||||
/// </summary>
|
||||
public InstanceHealthStatus CurrentStatus
|
||||
{
|
||||
get => _currentStatus;
|
||||
set => _currentStatus = value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the count of in-flight requests.
|
||||
/// </summary>
|
||||
public int InFlightRequestCount
|
||||
{
|
||||
get => _inFlightRequestCount;
|
||||
set => _inFlightRequestCount = value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the error rate (0.0 to 1.0).
|
||||
/// </summary>
|
||||
public double ErrorRate
|
||||
{
|
||||
get => _errorRate;
|
||||
set => _errorRate = value;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
@@ -168,32 +198,40 @@ public sealed class RouterConnectionManager : IRouterConnectionManager, IDisposa
|
||||
{
|
||||
await Task.Delay(_options.HeartbeatInterval, cancellationToken);
|
||||
|
||||
foreach (var connection in _connections.Values)
|
||||
// Build heartbeat payload with current status and metrics
|
||||
var heartbeat = new HeartbeatPayload
|
||||
{
|
||||
InstanceId = _options.InstanceId,
|
||||
Status = _currentStatus,
|
||||
InFlightRequestCount = _inFlightRequestCount,
|
||||
ErrorRate = _errorRate,
|
||||
TimestampUtc = DateTime.UtcNow
|
||||
};
|
||||
|
||||
// Send heartbeat via transport
|
||||
if (_microserviceTransport is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Build heartbeat payload
|
||||
var heartbeat = new HeartbeatPayload
|
||||
{
|
||||
InstanceId = _options.InstanceId,
|
||||
Status = connection.Status,
|
||||
TimestampUtc = DateTime.UtcNow
|
||||
};
|
||||
|
||||
// Update last heartbeat time
|
||||
connection.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
await _microserviceTransport.SendHeartbeatAsync(heartbeat, cancellationToken);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Sent heartbeat for connection {ConnectionId}",
|
||||
connection.ConnectionId);
|
||||
"Sent heartbeat: status={Status}, inflight={InFlight}, errorRate={ErrorRate:P1}",
|
||||
heartbeat.Status,
|
||||
heartbeat.InFlightRequestCount,
|
||||
heartbeat.ErrorRate);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex,
|
||||
"Failed to send heartbeat for connection {ConnectionId}",
|
||||
connection.ConnectionId);
|
||||
_logger.LogWarning(ex, "Failed to send heartbeat");
|
||||
}
|
||||
}
|
||||
|
||||
// Update connection state local heartbeat times
|
||||
foreach (var connection in _connections.Values)
|
||||
{
|
||||
connection.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
|
||||
@@ -22,17 +22,34 @@ public static class ServiceCollectionExtensions
|
||||
ArgumentNullException.ThrowIfNull(services);
|
||||
ArgumentNullException.ThrowIfNull(configure);
|
||||
|
||||
// Configure options
|
||||
// Configure and register options as singleton
|
||||
var options = new StellaMicroserviceOptions { ServiceName = "", Version = "1.0.0", Region = "" };
|
||||
configure(options);
|
||||
services.AddSingleton(options);
|
||||
services.Configure(configure);
|
||||
|
||||
// Register endpoint discovery
|
||||
services.TryAddSingleton<IEndpointDiscoveryProvider>(sp =>
|
||||
// Register YAML loader and merger
|
||||
services.TryAddSingleton<IMicroserviceYamlLoader, MicroserviceYamlLoader>();
|
||||
services.TryAddSingleton<IEndpointOverrideMerger, EndpointOverrideMerger>();
|
||||
|
||||
// Register endpoint discovery provider (prefers generated over reflection)
|
||||
services.TryAddSingleton<IEndpointDiscoveryProvider, GeneratedEndpointDiscoveryProvider>();
|
||||
|
||||
// Register endpoint discovery service (with YAML integration)
|
||||
services.TryAddSingleton<IEndpointDiscoveryService, EndpointDiscoveryService>();
|
||||
|
||||
// Register endpoint registry (using discovery service)
|
||||
services.TryAddSingleton<IEndpointRegistry>(sp =>
|
||||
{
|
||||
var options = new StellaMicroserviceOptions { ServiceName = "", Version = "1.0.0", Region = "" };
|
||||
configure(options);
|
||||
return new ReflectionEndpointDiscoveryProvider(options);
|
||||
var discoveryService = sp.GetRequiredService<IEndpointDiscoveryService>();
|
||||
var registry = new EndpointRegistry();
|
||||
registry.RegisterAll(discoveryService.DiscoverEndpoints());
|
||||
return registry;
|
||||
});
|
||||
|
||||
// Register request dispatcher
|
||||
services.TryAddSingleton<RequestDispatcher>();
|
||||
|
||||
// Register connection manager
|
||||
services.TryAddSingleton<IRouterConnectionManager, RouterConnectionManager>();
|
||||
|
||||
@@ -57,12 +74,34 @@ public static class ServiceCollectionExtensions
|
||||
ArgumentNullException.ThrowIfNull(services);
|
||||
ArgumentNullException.ThrowIfNull(configure);
|
||||
|
||||
// Configure options
|
||||
// Configure and register options as singleton
|
||||
var options = new StellaMicroserviceOptions { ServiceName = "", Version = "1.0.0", Region = "" };
|
||||
configure(options);
|
||||
services.AddSingleton(options);
|
||||
services.Configure(configure);
|
||||
|
||||
// Register YAML loader and merger
|
||||
services.TryAddSingleton<IMicroserviceYamlLoader, MicroserviceYamlLoader>();
|
||||
services.TryAddSingleton<IEndpointOverrideMerger, EndpointOverrideMerger>();
|
||||
|
||||
// Register custom endpoint discovery
|
||||
services.TryAddSingleton<IEndpointDiscoveryProvider, TDiscovery>();
|
||||
|
||||
// Register endpoint discovery service (with YAML integration)
|
||||
services.TryAddSingleton<IEndpointDiscoveryService, EndpointDiscoveryService>();
|
||||
|
||||
// Register endpoint registry (using discovery service)
|
||||
services.TryAddSingleton<IEndpointRegistry>(sp =>
|
||||
{
|
||||
var discoveryService = sp.GetRequiredService<IEndpointDiscoveryService>();
|
||||
var registry = new EndpointRegistry();
|
||||
registry.RegisterAll(discoveryService.DiscoverEndpoints());
|
||||
return registry;
|
||||
});
|
||||
|
||||
// Register request dispatcher
|
||||
services.TryAddSingleton<RequestDispatcher>();
|
||||
|
||||
// Register connection manager
|
||||
services.TryAddSingleton<IRouterConnectionManager, RouterConnectionManager>();
|
||||
|
||||
@@ -71,4 +110,17 @@ public static class ServiceCollectionExtensions
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers an endpoint handler type for dependency injection.
|
||||
/// </summary>
|
||||
/// <typeparam name="THandler">The endpoint handler type.</typeparam>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <returns>The service collection for chaining.</returns>
|
||||
public static IServiceCollection AddStellaEndpoint<THandler>(this IServiceCollection services)
|
||||
where THandler : class, IStellaEndpoint
|
||||
{
|
||||
services.AddScoped<THandler>();
|
||||
return services;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="YamlDotNet" Version="13.7.1" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
using System.Threading.Channels;
|
||||
|
||||
namespace StellaOps.Microservice.Streaming;
|
||||
|
||||
/// <summary>
|
||||
/// A read-only stream that reads from a channel of data chunks.
|
||||
/// Used to expose streaming request body to handlers.
|
||||
/// </summary>
|
||||
public sealed class StreamingRequestBodyStream : Stream
|
||||
{
|
||||
private readonly ChannelReader<StreamChunk> _reader;
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
private byte[] _currentBuffer = [];
|
||||
private int _currentBufferPosition;
|
||||
private bool _endOfStream;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="StreamingRequestBodyStream"/> class.
|
||||
/// </summary>
|
||||
/// <param name="reader">The channel reader for incoming chunks.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public StreamingRequestBodyStream(
|
||||
ChannelReader<StreamChunk> reader,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
_reader = reader;
|
||||
_cancellationToken = cancellationToken;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanRead => true;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanSeek => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanWrite => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Length => throw new NotSupportedException("Streaming body length unknown.");
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Position
|
||||
{
|
||||
get => throw new NotSupportedException("Streaming body position not supported.");
|
||||
set => throw new NotSupportedException("Streaming body position not supported.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Flush() { }
|
||||
|
||||
/// <inheritdoc />
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
return ReadAsync(buffer, offset, count, CancellationToken.None)
|
||||
.GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
return await ReadAsync(buffer.AsMemory(offset, count), cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (_endOfStream)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken);
|
||||
|
||||
// Try to use remaining data from current buffer first
|
||||
if (_currentBufferPosition < _currentBuffer.Length)
|
||||
{
|
||||
var bytesToCopy = Math.Min(buffer.Length, _currentBuffer.Length - _currentBufferPosition);
|
||||
_currentBuffer.AsSpan(_currentBufferPosition, bytesToCopy).CopyTo(buffer.Span);
|
||||
_currentBufferPosition += bytesToCopy;
|
||||
return bytesToCopy;
|
||||
}
|
||||
|
||||
// Need to read next chunk from channel
|
||||
if (!await _reader.WaitToReadAsync(linkedCts.Token))
|
||||
{
|
||||
_endOfStream = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!_reader.TryRead(out var chunk))
|
||||
{
|
||||
_endOfStream = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (chunk.EndOfStream)
|
||||
{
|
||||
_endOfStream = true;
|
||||
// Still process any data in the final chunk
|
||||
if (chunk.Data.Length == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
_currentBuffer = chunk.Data;
|
||||
_currentBufferPosition = 0;
|
||||
|
||||
var bytesToReturn = Math.Min(buffer.Length, _currentBuffer.Length);
|
||||
_currentBuffer.AsSpan(0, bytesToReturn).CopyTo(buffer.Span);
|
||||
_currentBufferPosition = bytesToReturn;
|
||||
return bytesToReturn;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Seek(long offset, SeekOrigin origin)
|
||||
{
|
||||
throw new NotSupportedException("Seeking not supported on streaming body.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void SetLength(long value)
|
||||
{
|
||||
throw new NotSupportedException("Setting length not supported on streaming body.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Write(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotSupportedException("Write not supported on streaming body.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents a chunk of streaming data.
|
||||
/// </summary>
|
||||
public sealed record StreamChunk
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the chunk data.
|
||||
/// </summary>
|
||||
public byte[] Data { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether this is the final chunk.
|
||||
/// </summary>
|
||||
public bool EndOfStream { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the sequence number.
|
||||
/// </summary>
|
||||
public int SequenceNumber { get; init; }
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
using System.Threading.Channels;
|
||||
|
||||
namespace StellaOps.Microservice.Streaming;
|
||||
|
||||
/// <summary>
|
||||
/// A write-only stream that writes chunks to a channel.
|
||||
/// Used to enable streaming response body from handlers.
|
||||
/// </summary>
|
||||
public sealed class StreamingResponseBodyStream : Stream
|
||||
{
|
||||
private readonly ChannelWriter<StreamChunk> _writer;
|
||||
private readonly int _chunkSize;
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
private byte[] _buffer;
|
||||
private int _bufferPosition;
|
||||
private int _sequenceNumber;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="StreamingResponseBodyStream"/> class.
|
||||
/// </summary>
|
||||
/// <param name="writer">The channel writer for outgoing chunks.</param>
|
||||
/// <param name="chunkSize">The chunk size for buffered writes.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public StreamingResponseBodyStream(
|
||||
ChannelWriter<StreamChunk> writer,
|
||||
int chunkSize,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
_writer = writer;
|
||||
_chunkSize = chunkSize;
|
||||
_cancellationToken = cancellationToken;
|
||||
_buffer = new byte[chunkSize];
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanRead => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanSeek => false;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override bool CanWrite => true;
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Length => throw new NotSupportedException();
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Position
|
||||
{
|
||||
get => throw new NotSupportedException();
|
||||
set => throw new NotSupportedException();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Flush()
|
||||
{
|
||||
FlushAsync(CancellationToken.None).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async Task FlushAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
if (_bufferPosition > 0)
|
||||
{
|
||||
var chunk = new StreamChunk
|
||||
{
|
||||
Data = _buffer[.._bufferPosition],
|
||||
SequenceNumber = _sequenceNumber++,
|
||||
EndOfStream = false
|
||||
};
|
||||
|
||||
await _writer.WriteAsync(chunk, cancellationToken);
|
||||
_buffer = new byte[_chunkSize];
|
||||
_bufferPosition = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotSupportedException("Read not supported on streaming response body.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override long Seek(long offset, SeekOrigin origin)
|
||||
{
|
||||
throw new NotSupportedException("Seeking not supported on streaming response body.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void SetLength(long value)
|
||||
{
|
||||
throw new NotSupportedException("Setting length not supported on streaming response body.");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void Write(byte[] buffer, int offset, int count)
|
||||
{
|
||||
WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
await WriteAsync(buffer.AsMemory(offset, count), cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken);
|
||||
|
||||
var bytesWritten = 0;
|
||||
while (bytesWritten < buffer.Length)
|
||||
{
|
||||
var spaceInBuffer = _chunkSize - _bufferPosition;
|
||||
var bytesToWrite = Math.Min(spaceInBuffer, buffer.Length - bytesWritten);
|
||||
|
||||
buffer.Slice(bytesWritten, bytesToWrite).Span.CopyTo(_buffer.AsSpan(_bufferPosition));
|
||||
_bufferPosition += bytesToWrite;
|
||||
bytesWritten += bytesToWrite;
|
||||
|
||||
if (_bufferPosition >= _chunkSize)
|
||||
{
|
||||
await FlushAsync(linkedCts.Token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Completes the stream by flushing remaining data and sending end-of-stream signal.
|
||||
/// </summary>
|
||||
public async Task CompleteAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Flush any remaining buffered data
|
||||
await FlushAsync(cancellationToken);
|
||||
|
||||
// Send end-of-stream marker
|
||||
var endChunk = new StreamChunk
|
||||
{
|
||||
Data = [],
|
||||
SequenceNumber = _sequenceNumber++,
|
||||
EndOfStream = true
|
||||
};
|
||||
|
||||
await _writer.WriteAsync(endChunk, cancellationToken);
|
||||
_writer.Complete();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed && disposing)
|
||||
{
|
||||
// Try to complete the stream if not already completed
|
||||
try
|
||||
{
|
||||
_writer.TryComplete();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore errors during disposal
|
||||
}
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override async ValueTask DisposeAsync()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
try
|
||||
{
|
||||
await CompleteAsync(CancellationToken.None);
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore errors during disposal
|
||||
}
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
await base.DisposeAsync();
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,38 @@ namespace StellaOps.Router.Common.Abstractions;
|
||||
/// </summary>
|
||||
public interface IGlobalRoutingState
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds a connection to the routing state.
|
||||
/// </summary>
|
||||
/// <param name="connection">The connection state to add.</param>
|
||||
void AddConnection(ConnectionState connection);
|
||||
|
||||
/// <summary>
|
||||
/// Removes a connection from the routing state.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID to remove.</param>
|
||||
void RemoveConnection(string connectionId);
|
||||
|
||||
/// <summary>
|
||||
/// Updates an existing connection's state.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID to update.</param>
|
||||
/// <param name="update">The update action to apply.</param>
|
||||
void UpdateConnection(string connectionId, Action<ConnectionState> update);
|
||||
|
||||
/// <summary>
|
||||
/// Gets a connection by its ID.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <returns>The connection state, or null if not found.</returns>
|
||||
ConnectionState? GetConnection(string connectionId);
|
||||
|
||||
/// <summary>
|
||||
/// Gets all active connections.
|
||||
/// </summary>
|
||||
/// <returns>All active connections.</returns>
|
||||
IReadOnlyList<ConnectionState> GetAllConnections();
|
||||
|
||||
/// <summary>
|
||||
/// Resolves an HTTP request to an endpoint descriptor.
|
||||
/// </summary>
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Common.Abstractions;
|
||||
|
||||
/// <summary>
|
||||
/// Represents a transport connection from a microservice to the gateway.
|
||||
/// This interface is used by the Microservice SDK to communicate with the router.
|
||||
/// </summary>
|
||||
public interface IMicroserviceTransport
|
||||
{
|
||||
/// <summary>
|
||||
/// Connects to the router and registers the microservice.
|
||||
/// </summary>
|
||||
/// <param name="instance">The instance descriptor.</param>
|
||||
/// <param name="endpoints">The endpoints to register.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
Task ConnectAsync(
|
||||
InstanceDescriptor instance,
|
||||
IReadOnlyList<EndpointDescriptor> endpoints,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the router.
|
||||
/// </summary>
|
||||
Task DisconnectAsync();
|
||||
|
||||
/// <summary>
|
||||
/// Sends a heartbeat to the router.
|
||||
/// </summary>
|
||||
/// <param name="heartbeat">The heartbeat payload.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken);
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a REQUEST frame is received from the gateway.
|
||||
/// </summary>
|
||||
event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a CANCEL frame is received from the gateway.
|
||||
/// </summary>
|
||||
event Func<Guid, string?, Task>? OnCancelReceived;
|
||||
}
|
||||
148
src/__Libraries/StellaOps.Router.Common/Frames/FrameConverter.cs
Normal file
148
src/__Libraries/StellaOps.Router.Common/Frames/FrameConverter.cs
Normal file
@@ -0,0 +1,148 @@
|
||||
using System.Text.Json;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Common.Frames;
|
||||
|
||||
/// <summary>
|
||||
/// Converts between generic Frame and typed frame records.
|
||||
/// </summary>
|
||||
public static class FrameConverter
|
||||
{
|
||||
private static readonly JsonSerializerOptions JsonOptions = new()
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
PropertyNameCaseInsensitive = true
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Converts a RequestFrame to a generic Frame for transport.
|
||||
/// </summary>
|
||||
public static Frame ToFrame(RequestFrame request)
|
||||
{
|
||||
var envelope = new RequestEnvelope
|
||||
{
|
||||
RequestId = request.RequestId,
|
||||
Method = request.Method,
|
||||
Path = request.Path,
|
||||
Headers = request.Headers,
|
||||
TimeoutSeconds = request.TimeoutSeconds,
|
||||
SupportsStreaming = request.SupportsStreaming,
|
||||
Payload = request.Payload.ToArray()
|
||||
};
|
||||
|
||||
var envelopeBytes = JsonSerializer.SerializeToUtf8Bytes(envelope, JsonOptions);
|
||||
|
||||
return new Frame
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = request.CorrelationId ?? request.RequestId,
|
||||
Payload = envelopeBytes
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts a generic Frame to a RequestFrame.
|
||||
/// </summary>
|
||||
public static RequestFrame? ToRequestFrame(Frame frame)
|
||||
{
|
||||
if (frame.Type != FrameType.Request)
|
||||
return null;
|
||||
|
||||
try
|
||||
{
|
||||
var envelope = JsonSerializer.Deserialize<RequestEnvelope>(frame.Payload.Span, JsonOptions);
|
||||
if (envelope is null)
|
||||
return null;
|
||||
|
||||
return new RequestFrame
|
||||
{
|
||||
RequestId = envelope.RequestId,
|
||||
CorrelationId = frame.CorrelationId,
|
||||
Method = envelope.Method,
|
||||
Path = envelope.Path,
|
||||
Headers = envelope.Headers ?? new Dictionary<string, string>(),
|
||||
TimeoutSeconds = envelope.TimeoutSeconds,
|
||||
SupportsStreaming = envelope.SupportsStreaming,
|
||||
Payload = envelope.Payload ?? []
|
||||
};
|
||||
}
|
||||
catch (JsonException)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts a ResponseFrame to a generic Frame for transport.
|
||||
/// </summary>
|
||||
public static Frame ToFrame(ResponseFrame response)
|
||||
{
|
||||
var envelope = new ResponseEnvelope
|
||||
{
|
||||
RequestId = response.RequestId,
|
||||
StatusCode = response.StatusCode,
|
||||
Headers = response.Headers,
|
||||
HasMoreChunks = response.HasMoreChunks,
|
||||
Payload = response.Payload.ToArray()
|
||||
};
|
||||
|
||||
var envelopeBytes = JsonSerializer.SerializeToUtf8Bytes(envelope, JsonOptions);
|
||||
|
||||
return new Frame
|
||||
{
|
||||
Type = FrameType.Response,
|
||||
CorrelationId = response.RequestId,
|
||||
Payload = envelopeBytes
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts a generic Frame to a ResponseFrame.
|
||||
/// </summary>
|
||||
public static ResponseFrame? ToResponseFrame(Frame frame)
|
||||
{
|
||||
if (frame.Type != FrameType.Response)
|
||||
return null;
|
||||
|
||||
try
|
||||
{
|
||||
var envelope = JsonSerializer.Deserialize<ResponseEnvelope>(frame.Payload.Span, JsonOptions);
|
||||
if (envelope is null)
|
||||
return null;
|
||||
|
||||
return new ResponseFrame
|
||||
{
|
||||
RequestId = envelope.RequestId,
|
||||
StatusCode = envelope.StatusCode,
|
||||
Headers = envelope.Headers ?? new Dictionary<string, string>(),
|
||||
HasMoreChunks = envelope.HasMoreChunks,
|
||||
Payload = envelope.Payload ?? []
|
||||
};
|
||||
}
|
||||
catch (JsonException)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class RequestEnvelope
|
||||
{
|
||||
public required string RequestId { get; set; }
|
||||
public required string Method { get; set; }
|
||||
public required string Path { get; set; }
|
||||
public IReadOnlyDictionary<string, string>? Headers { get; set; }
|
||||
public int TimeoutSeconds { get; set; } = 30;
|
||||
public bool SupportsStreaming { get; set; }
|
||||
public byte[]? Payload { get; set; }
|
||||
}
|
||||
|
||||
private sealed class ResponseEnvelope
|
||||
{
|
||||
public required string RequestId { get; set; }
|
||||
public int StatusCode { get; set; } = 200;
|
||||
public IReadOnlyDictionary<string, string>? Headers { get; set; }
|
||||
public bool HasMoreChunks { get; set; }
|
||||
public byte[]? Payload { get; set; }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
namespace StellaOps.Router.Common.Frames;
|
||||
|
||||
/// <summary>
|
||||
/// Represents a REQUEST frame sent from gateway to microservice.
|
||||
/// </summary>
|
||||
public sealed record RequestFrame
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the unique request ID for this request.
|
||||
/// </summary>
|
||||
public required string RequestId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the correlation ID for distributed tracing.
|
||||
/// </summary>
|
||||
public string? CorrelationId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the HTTP method (GET, POST, PUT, DELETE, etc.).
|
||||
/// </summary>
|
||||
public required string Method { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the request path.
|
||||
/// </summary>
|
||||
public required string Path { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the request headers.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, string> Headers { get; init; } = new Dictionary<string, string>();
|
||||
|
||||
/// <summary>
|
||||
/// Gets the request payload (body).
|
||||
/// </summary>
|
||||
public ReadOnlyMemory<byte> Payload { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the timeout in seconds for this request.
|
||||
/// </summary>
|
||||
public int TimeoutSeconds { get; init; } = 30;
|
||||
|
||||
/// <summary>
|
||||
/// Gets whether this request supports streaming response.
|
||||
/// </summary>
|
||||
public bool SupportsStreaming { get; init; }
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
namespace StellaOps.Router.Common.Frames;
|
||||
|
||||
/// <summary>
|
||||
/// Represents a RESPONSE frame sent from microservice to gateway.
|
||||
/// </summary>
|
||||
public sealed record ResponseFrame
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the request ID this response is for.
|
||||
/// </summary>
|
||||
public required string RequestId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the HTTP status code.
|
||||
/// </summary>
|
||||
public int StatusCode { get; init; } = 200;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the response headers.
|
||||
/// </summary>
|
||||
public IReadOnlyDictionary<string, string> Headers { get; init; } = new Dictionary<string, string>();
|
||||
|
||||
/// <summary>
|
||||
/// Gets the response payload (body).
|
||||
/// </summary>
|
||||
public ReadOnlyMemory<byte> Payload { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets whether there are more streaming chunks to follow.
|
||||
/// </summary>
|
||||
public bool HasMoreChunks { get; init; }
|
||||
}
|
||||
@@ -10,3 +10,34 @@ public sealed record CancelPayload
|
||||
/// </summary>
|
||||
public string? Reason { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Standard reasons for request cancellation.
|
||||
/// </summary>
|
||||
public static class CancelReasons
|
||||
{
|
||||
/// <summary>
|
||||
/// The HTTP client disconnected before the request completed.
|
||||
/// </summary>
|
||||
public const string ClientDisconnected = "ClientDisconnected";
|
||||
|
||||
/// <summary>
|
||||
/// The request exceeded its timeout.
|
||||
/// </summary>
|
||||
public const string Timeout = "Timeout";
|
||||
|
||||
/// <summary>
|
||||
/// The request or response payload exceeded configured limits.
|
||||
/// </summary>
|
||||
public const string PayloadLimitExceeded = "PayloadLimitExceeded";
|
||||
|
||||
/// <summary>
|
||||
/// The gateway or microservice is shutting down.
|
||||
/// </summary>
|
||||
public const string Shutdown = "Shutdown";
|
||||
|
||||
/// <summary>
|
||||
/// The transport connection was closed unexpectedly.
|
||||
/// </summary>
|
||||
public const string ConnectionClosed = "ConnectionClosed";
|
||||
}
|
||||
|
||||
@@ -39,4 +39,10 @@ public sealed record EndpointDescriptor
|
||||
/// Gets a value indicating whether this endpoint supports streaming.
|
||||
/// </summary>
|
||||
public bool SupportsStreaming { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the handler type that processes requests for this endpoint.
|
||||
/// This is used by the Microservice SDK for handler resolution.
|
||||
/// </summary>
|
||||
public Type? HandlerType { get; init; }
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
namespace StellaOps.Router.Common.Models;
|
||||
|
||||
/// <summary>
|
||||
/// Payload for streaming data frames (REQUEST_STREAM_DATA/RESPONSE_STREAM_DATA).
|
||||
/// </summary>
|
||||
public sealed record StreamDataPayload
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the correlation ID linking stream data to the original request.
|
||||
/// </summary>
|
||||
public required Guid CorrelationId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the stream data chunk.
|
||||
/// </summary>
|
||||
public byte[] Data { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether this is the final chunk.
|
||||
/// </summary>
|
||||
public bool EndOfStream { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the sequence number for ordering.
|
||||
/// </summary>
|
||||
public int SequenceNumber { get; init; }
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
namespace StellaOps.Router.Common.Models;
|
||||
|
||||
/// <summary>
|
||||
/// Configuration options for streaming operations.
|
||||
/// </summary>
|
||||
public sealed record StreamingOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the default streaming options.
|
||||
/// </summary>
|
||||
public static readonly StreamingOptions Default = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets the size of each chunk when streaming data.
|
||||
/// Default: 64 KB.
|
||||
/// </summary>
|
||||
public int ChunkSize { get; init; } = 64 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the maximum number of concurrent streams per connection.
|
||||
/// Default: 100.
|
||||
/// </summary>
|
||||
public int MaxConcurrentStreams { get; init; } = 100;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the timeout for idle streams (no data flowing).
|
||||
/// Default: 5 minutes.
|
||||
/// </summary>
|
||||
public TimeSpan StreamIdleTimeout { get; init; } = TimeSpan.FromMinutes(5);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the channel capacity for buffered stream data.
|
||||
/// Default: 16 chunks.
|
||||
/// </summary>
|
||||
public int ChannelCapacity { get; init; } = 16;
|
||||
}
|
||||
85
src/__Libraries/StellaOps.Router.Common/PathMatcher.cs
Normal file
85
src/__Libraries/StellaOps.Router.Common/PathMatcher.cs
Normal file
@@ -0,0 +1,85 @@
|
||||
using System.Text.RegularExpressions;
|
||||
|
||||
namespace StellaOps.Router.Common;
|
||||
|
||||
/// <summary>
|
||||
/// Matches request paths against route templates.
|
||||
/// </summary>
|
||||
public sealed partial class PathMatcher
|
||||
{
|
||||
private readonly string _template;
|
||||
private readonly Regex _regex;
|
||||
private readonly string[] _parameterNames;
|
||||
private readonly bool _caseInsensitive;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the route template.
|
||||
/// </summary>
|
||||
public string Template => _template;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PathMatcher"/> class.
|
||||
/// </summary>
|
||||
/// <param name="template">The route template (e.g., "/api/users/{id}").</param>
|
||||
/// <param name="caseInsensitive">Whether matching should be case-insensitive.</param>
|
||||
public PathMatcher(string template, bool caseInsensitive = true)
|
||||
{
|
||||
_template = template;
|
||||
_caseInsensitive = caseInsensitive;
|
||||
|
||||
// Extract parameter names and build regex
|
||||
var paramNames = new List<string>();
|
||||
var pattern = "^" + ParameterRegex().Replace(template, match =>
|
||||
{
|
||||
paramNames.Add(match.Groups[1].Value);
|
||||
return "([^/]+)";
|
||||
}) + "/?$";
|
||||
|
||||
var options = caseInsensitive ? RegexOptions.IgnoreCase : RegexOptions.None;
|
||||
_regex = new Regex(pattern, options | RegexOptions.Compiled);
|
||||
_parameterNames = [.. paramNames];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to match a path against the template.
|
||||
/// </summary>
|
||||
/// <param name="path">The request path.</param>
|
||||
/// <param name="parameters">The extracted path parameters if matched.</param>
|
||||
/// <returns>True if the path matches.</returns>
|
||||
public bool TryMatch(string path, out Dictionary<string, string> parameters)
|
||||
{
|
||||
parameters = [];
|
||||
|
||||
// Normalize path
|
||||
path = path.TrimEnd('/');
|
||||
if (!path.StartsWith('/'))
|
||||
path = "/" + path;
|
||||
|
||||
var match = _regex.Match(path);
|
||||
if (!match.Success)
|
||||
return false;
|
||||
|
||||
for (int i = 0; i < _parameterNames.Length; i++)
|
||||
{
|
||||
parameters[_parameterNames[i]] = match.Groups[i + 1].Value;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a path matches the template.
|
||||
/// </summary>
|
||||
/// <param name="path">The request path.</param>
|
||||
/// <returns>True if the path matches.</returns>
|
||||
public bool IsMatch(string path)
|
||||
{
|
||||
path = path.TrimEnd('/');
|
||||
if (!path.StartsWith('/'))
|
||||
path = "/" + path;
|
||||
return _regex.IsMatch(path);
|
||||
}
|
||||
|
||||
[GeneratedRegex(@"\{([^}:]+)(?::[^}]+)?\}")]
|
||||
private static partial Regex ParameterRegex();
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
namespace StellaOps.Router.Config;
|
||||
|
||||
/// <summary>
|
||||
/// Provides access to router configuration with hot-reload support.
|
||||
/// </summary>
|
||||
public interface IRouterConfigProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the current router configuration.
|
||||
/// </summary>
|
||||
RouterConfig Current { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current router configuration options.
|
||||
/// </summary>
|
||||
RouterConfigOptions Options { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Raised when the configuration is reloaded.
|
||||
/// </summary>
|
||||
event EventHandler<ConfigChangedEventArgs>? ConfigurationChanged;
|
||||
|
||||
/// <summary>
|
||||
/// Reloads the configuration from the source.
|
||||
/// </summary>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
/// <returns>A task representing the reload operation.</returns>
|
||||
Task ReloadAsync(CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Validates the current configuration.
|
||||
/// </summary>
|
||||
/// <returns>Validation result.</returns>
|
||||
ConfigValidationResult Validate();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Event arguments for configuration changes.
|
||||
/// </summary>
|
||||
public sealed class ConfigChangedEventArgs : EventArgs
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ConfigChangedEventArgs"/> class.
|
||||
/// </summary>
|
||||
/// <param name="previous">The previous configuration.</param>
|
||||
/// <param name="current">The current configuration.</param>
|
||||
public ConfigChangedEventArgs(RouterConfig previous, RouterConfig current)
|
||||
{
|
||||
Previous = previous;
|
||||
Current = current;
|
||||
ChangedAt = DateTime.UtcNow;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the previous configuration.
|
||||
/// </summary>
|
||||
public RouterConfig Previous { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current configuration.
|
||||
/// </summary>
|
||||
public RouterConfig Current { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the time the configuration was changed.
|
||||
/// </summary>
|
||||
public DateTime ChangedAt { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of configuration validation.
|
||||
/// </summary>
|
||||
public sealed class ConfigValidationResult
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets whether the configuration is valid.
|
||||
/// </summary>
|
||||
public bool IsValid => Errors.Count == 0;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the validation errors.
|
||||
/// </summary>
|
||||
public List<string> Errors { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the validation warnings.
|
||||
/// </summary>
|
||||
public List<string> Warnings { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// A successful validation result.
|
||||
/// </summary>
|
||||
public static ConfigValidationResult Success => new();
|
||||
}
|
||||
@@ -12,8 +12,18 @@ public sealed class RouterConfig
|
||||
/// </summary>
|
||||
public PayloadLimits PayloadLimits { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the routing options.
|
||||
/// </summary>
|
||||
public RoutingOptions Routing { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the service configurations.
|
||||
/// </summary>
|
||||
public List<ServiceConfig> Services { get; set; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the static instance configurations.
|
||||
/// </summary>
|
||||
public List<StaticInstanceConfig> StaticInstances { get; set; } = [];
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
namespace StellaOps.Router.Config;
|
||||
|
||||
/// <summary>
|
||||
/// Options for the router configuration provider.
|
||||
/// </summary>
|
||||
public sealed class RouterConfigOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the path to the router configuration file (YAML or JSON).
|
||||
/// </summary>
|
||||
public string? ConfigPath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the environment variable prefix for overrides.
|
||||
/// Default: "STELLAOPS_ROUTER_".
|
||||
/// </summary>
|
||||
public string EnvironmentVariablePrefix { get; set; } = "STELLAOPS_ROUTER_";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to enable hot-reload of configuration.
|
||||
/// </summary>
|
||||
public bool EnableHotReload { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the debounce interval for file change notifications.
|
||||
/// </summary>
|
||||
public TimeSpan DebounceInterval { get; set; } = TimeSpan.FromMilliseconds(500);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to throw on configuration validation errors.
|
||||
/// If false, keeps the previous valid configuration.
|
||||
/// </summary>
|
||||
public bool ThrowOnValidationError { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the configuration section name in appsettings.json.
|
||||
/// </summary>
|
||||
public string ConfigurationSection { get; set; } = "Router";
|
||||
}
|
||||
321
src/__Libraries/StellaOps.Router.Config/RouterConfigProvider.cs
Normal file
321
src/__Libraries/StellaOps.Router.Config/RouterConfigProvider.cs
Normal file
@@ -0,0 +1,321 @@
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace StellaOps.Router.Config;
|
||||
|
||||
/// <summary>
|
||||
/// Provides router configuration with hot-reload support.
|
||||
/// </summary>
|
||||
public sealed class RouterConfigProvider : IRouterConfigProvider, IDisposable
|
||||
{
|
||||
private readonly RouterConfigOptions _options;
|
||||
private readonly ILogger<RouterConfigProvider> _logger;
|
||||
private readonly FileSystemWatcher? _watcher;
|
||||
private readonly SemaphoreSlim _reloadLock = new(1, 1);
|
||||
private readonly Timer? _debounceTimer;
|
||||
private RouterConfig _current;
|
||||
private bool _disposed;
|
||||
|
||||
/// <inheritdoc />
|
||||
public event EventHandler<ConfigChangedEventArgs>? ConfigurationChanged;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RouterConfigProvider"/> class.
|
||||
/// </summary>
|
||||
public RouterConfigProvider(
|
||||
IOptions<RouterConfigOptions> options,
|
||||
ILogger<RouterConfigProvider> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
_current = LoadConfiguration();
|
||||
|
||||
if (_options.EnableHotReload && !string.IsNullOrEmpty(_options.ConfigPath) && File.Exists(_options.ConfigPath))
|
||||
{
|
||||
var directory = Path.GetDirectoryName(Path.GetFullPath(_options.ConfigPath))!;
|
||||
var fileName = Path.GetFileName(_options.ConfigPath);
|
||||
|
||||
_watcher = new FileSystemWatcher(directory)
|
||||
{
|
||||
Filter = fileName,
|
||||
NotifyFilter = NotifyFilters.LastWrite | NotifyFilters.Size
|
||||
};
|
||||
|
||||
_debounceTimer = new Timer(OnDebounceElapsed, null, Timeout.Infinite, Timeout.Infinite);
|
||||
|
||||
_watcher.Changed += OnFileChanged;
|
||||
_watcher.EnableRaisingEvents = true;
|
||||
|
||||
_logger.LogInformation("Hot-reload enabled for configuration file: {Path}", _options.ConfigPath);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public RouterConfig Current => _current;
|
||||
|
||||
/// <inheritdoc />
|
||||
public RouterConfigOptions Options => _options;
|
||||
|
||||
private void OnFileChanged(object sender, FileSystemEventArgs e)
|
||||
{
|
||||
// Debounce rapid file changes (e.g., from editors saving multiple times)
|
||||
_debounceTimer?.Change(_options.DebounceInterval, Timeout.InfiniteTimeSpan);
|
||||
}
|
||||
|
||||
private void OnDebounceElapsed(object? state)
|
||||
{
|
||||
_ = ReloadAsyncInternal();
|
||||
}
|
||||
|
||||
private async Task ReloadAsyncInternal()
|
||||
{
|
||||
if (!await _reloadLock.WaitAsync(TimeSpan.Zero))
|
||||
{
|
||||
// Another reload is in progress
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var previous = _current;
|
||||
var newConfig = LoadConfiguration();
|
||||
var validation = ValidateConfig(newConfig);
|
||||
|
||||
if (!validation.IsValid)
|
||||
{
|
||||
if (_options.ThrowOnValidationError)
|
||||
{
|
||||
throw new ConfigurationException(
|
||||
$"Configuration validation failed: {string.Join("; ", validation.Errors)}");
|
||||
}
|
||||
|
||||
_logger.LogError(
|
||||
"Configuration validation failed, keeping previous: {Errors}",
|
||||
string.Join("; ", validation.Errors));
|
||||
return;
|
||||
}
|
||||
|
||||
foreach (var warning in validation.Warnings)
|
||||
{
|
||||
_logger.LogWarning("Configuration warning: {Warning}", warning);
|
||||
}
|
||||
|
||||
_current = newConfig;
|
||||
_logger.LogInformation("Router configuration reloaded successfully");
|
||||
|
||||
ConfigurationChanged?.Invoke(this, new ConfigChangedEventArgs(previous, newConfig));
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to reload configuration, keeping previous");
|
||||
|
||||
if (_options.ThrowOnValidationError)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
_reloadLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task ReloadAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
await _reloadLock.WaitAsync(cancellationToken);
|
||||
|
||||
try
|
||||
{
|
||||
var previous = _current;
|
||||
var newConfig = LoadConfiguration();
|
||||
var validation = ValidateConfig(newConfig);
|
||||
|
||||
if (!validation.IsValid)
|
||||
{
|
||||
throw new ConfigurationException(
|
||||
$"Configuration validation failed: {string.Join("; ", validation.Errors)}");
|
||||
}
|
||||
|
||||
_current = newConfig;
|
||||
_logger.LogInformation("Router configuration reloaded successfully");
|
||||
|
||||
ConfigurationChanged?.Invoke(this, new ConfigChangedEventArgs(previous, newConfig));
|
||||
}
|
||||
finally
|
||||
{
|
||||
_reloadLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public ConfigValidationResult Validate() => ValidateConfig(_current);
|
||||
|
||||
private RouterConfig LoadConfiguration()
|
||||
{
|
||||
var builder = new ConfigurationBuilder();
|
||||
|
||||
// Load from YAML file if specified
|
||||
if (!string.IsNullOrEmpty(_options.ConfigPath))
|
||||
{
|
||||
var extension = Path.GetExtension(_options.ConfigPath).ToLowerInvariant();
|
||||
var fullPath = Path.GetFullPath(_options.ConfigPath);
|
||||
|
||||
if (File.Exists(fullPath))
|
||||
{
|
||||
switch (extension)
|
||||
{
|
||||
case ".yaml":
|
||||
case ".yml":
|
||||
builder.AddYamlFile(fullPath, optional: true, reloadOnChange: false);
|
||||
break;
|
||||
case ".json":
|
||||
builder.AddJsonFile(fullPath, optional: true, reloadOnChange: false);
|
||||
break;
|
||||
default:
|
||||
_logger.LogWarning("Unknown configuration file extension: {Extension}", extension);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
_logger.LogWarning("Configuration file not found: {Path}", fullPath);
|
||||
}
|
||||
}
|
||||
|
||||
// Add environment variable overrides
|
||||
builder.AddEnvironmentVariables(prefix: _options.EnvironmentVariablePrefix);
|
||||
|
||||
var configuration = builder.Build();
|
||||
|
||||
var config = new RouterConfig();
|
||||
configuration.Bind(config);
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
private static ConfigValidationResult ValidateConfig(RouterConfig config)
|
||||
{
|
||||
var result = new ConfigValidationResult();
|
||||
|
||||
// Validate payload limits
|
||||
if (config.PayloadLimits.MaxRequestBytesPerCall <= 0)
|
||||
{
|
||||
result.Errors.Add("PayloadLimits.MaxRequestBytesPerCall must be positive");
|
||||
}
|
||||
|
||||
if (config.PayloadLimits.MaxRequestBytesPerConnection <= 0)
|
||||
{
|
||||
result.Errors.Add("PayloadLimits.MaxRequestBytesPerConnection must be positive");
|
||||
}
|
||||
|
||||
if (config.PayloadLimits.MaxAggregateInflightBytes <= 0)
|
||||
{
|
||||
result.Errors.Add("PayloadLimits.MaxAggregateInflightBytes must be positive");
|
||||
}
|
||||
|
||||
if (config.PayloadLimits.MaxRequestBytesPerCall > config.PayloadLimits.MaxRequestBytesPerConnection)
|
||||
{
|
||||
result.Warnings.Add("MaxRequestBytesPerCall is larger than MaxRequestBytesPerConnection");
|
||||
}
|
||||
|
||||
// Validate routing options
|
||||
if (config.Routing.DefaultTimeout <= TimeSpan.Zero)
|
||||
{
|
||||
result.Errors.Add("Routing.DefaultTimeout must be positive");
|
||||
}
|
||||
|
||||
// Validate services
|
||||
var serviceNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
|
||||
foreach (var service in config.Services)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(service.ServiceName))
|
||||
{
|
||||
result.Errors.Add("Service name cannot be empty");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!serviceNames.Add(service.ServiceName))
|
||||
{
|
||||
result.Errors.Add($"Duplicate service name: {service.ServiceName}");
|
||||
}
|
||||
|
||||
foreach (var endpoint in service.Endpoints)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(endpoint.Method))
|
||||
{
|
||||
result.Errors.Add($"Service {service.ServiceName}: endpoint method cannot be empty");
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(endpoint.Path))
|
||||
{
|
||||
result.Errors.Add($"Service {service.ServiceName}: endpoint path cannot be empty");
|
||||
}
|
||||
|
||||
if (endpoint.DefaultTimeout.HasValue && endpoint.DefaultTimeout.Value <= TimeSpan.Zero)
|
||||
{
|
||||
result.Warnings.Add(
|
||||
$"Service {service.ServiceName}: endpoint {endpoint.Method} {endpoint.Path} has non-positive timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate static instances
|
||||
foreach (var instance in config.StaticInstances)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(instance.ServiceName))
|
||||
{
|
||||
result.Errors.Add("Static instance service name cannot be empty");
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(instance.Host))
|
||||
{
|
||||
result.Errors.Add($"Static instance {instance.ServiceName}: host cannot be empty");
|
||||
}
|
||||
|
||||
if (instance.Port <= 0 || instance.Port > 65535)
|
||||
{
|
||||
result.Errors.Add($"Static instance {instance.ServiceName}: port must be between 1 and 65535");
|
||||
}
|
||||
|
||||
if (instance.Weight <= 0)
|
||||
{
|
||||
result.Warnings.Add($"Static instance {instance.ServiceName}: weight should be positive");
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
_watcher?.Dispose();
|
||||
_debounceTimer?.Dispose();
|
||||
_reloadLock.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Exception thrown when configuration is invalid.
|
||||
/// </summary>
|
||||
public sealed class ConfigurationException : Exception
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ConfigurationException"/> class.
|
||||
/// </summary>
|
||||
public ConfigurationException(string message) : base(message)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ConfigurationException"/> class.
|
||||
/// </summary>
|
||||
public ConfigurationException(string message, Exception innerException) : base(message, innerException)
|
||||
{
|
||||
}
|
||||
}
|
||||
58
src/__Libraries/StellaOps.Router.Config/RoutingOptions.cs
Normal file
58
src/__Libraries/StellaOps.Router.Config/RoutingOptions.cs
Normal file
@@ -0,0 +1,58 @@
|
||||
namespace StellaOps.Router.Config;
|
||||
|
||||
/// <summary>
|
||||
/// Routing behavior options.
|
||||
/// </summary>
|
||||
public sealed class RoutingOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the local region for routing preferences.
|
||||
/// </summary>
|
||||
public string LocalRegion { get; set; } = "default";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the neighbor regions for fallback routing.
|
||||
/// </summary>
|
||||
public List<string> NeighborRegions { get; set; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the tie-breaker strategy for equal-weight instances.
|
||||
/// </summary>
|
||||
public TieBreakerStrategy TieBreaker { get; set; } = TieBreakerStrategy.RoundRobin;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to prefer local region instances.
|
||||
/// </summary>
|
||||
public bool PreferLocalRegion { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the default request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan DefaultTimeout { get; set; } = TimeSpan.FromSeconds(30);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tie-breaker strategy for routing decisions.
|
||||
/// </summary>
|
||||
public enum TieBreakerStrategy
|
||||
{
|
||||
/// <summary>
|
||||
/// Round-robin between equal-weight instances.
|
||||
/// </summary>
|
||||
RoundRobin,
|
||||
|
||||
/// <summary>
|
||||
/// Random selection between equal-weight instances.
|
||||
/// </summary>
|
||||
Random,
|
||||
|
||||
/// <summary>
|
||||
/// Select the least-loaded instance.
|
||||
/// </summary>
|
||||
LeastLoaded,
|
||||
|
||||
/// <summary>
|
||||
/// Consistent hashing based on request attributes.
|
||||
/// </summary>
|
||||
ConsistentHash
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace StellaOps.Router.Config;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering router configuration services.
|
||||
/// </summary>
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds router configuration services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configPath">Optional path to the configuration file.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRouterConfig(
|
||||
this IServiceCollection services,
|
||||
string? configPath = null)
|
||||
{
|
||||
return services.AddRouterConfig(options =>
|
||||
{
|
||||
if (!string.IsNullOrEmpty(configPath))
|
||||
{
|
||||
options.ConfigPath = configPath;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds router configuration services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRouterConfig(
|
||||
this IServiceCollection services,
|
||||
Action<RouterConfigOptions> configure)
|
||||
{
|
||||
services.Configure(configure);
|
||||
services.AddSingleton<IRouterConfigProvider, RouterConfigProvider>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds router configuration services to the service collection, binding from IConfiguration.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configuration">The configuration.</param>
|
||||
/// <param name="sectionName">The configuration section name.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRouterConfig(
|
||||
this IServiceCollection services,
|
||||
IConfiguration configuration,
|
||||
string sectionName = "Router")
|
||||
{
|
||||
var section = configuration.GetSection(sectionName);
|
||||
|
||||
services.Configure<RouterConfigOptions>(options =>
|
||||
{
|
||||
options.ConfigurationSection = sectionName;
|
||||
});
|
||||
|
||||
services.Configure<RouterConfig>(section);
|
||||
services.AddSingleton<IRouterConfigProvider, RouterConfigProvider>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds router configuration from a YAML file.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="yamlPath">Path to the YAML configuration file.</param>
|
||||
/// <param name="enableHotReload">Whether to enable hot-reload.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRouterConfigFromYaml(
|
||||
this IServiceCollection services,
|
||||
string yamlPath,
|
||||
bool enableHotReload = true)
|
||||
{
|
||||
return services.AddRouterConfig(options =>
|
||||
{
|
||||
options.ConfigPath = yamlPath;
|
||||
options.EnableHotReload = enableHotReload;
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds router configuration from a JSON file.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="jsonPath">Path to the JSON configuration file.</param>
|
||||
/// <param name="enableHotReload">Whether to enable hot-reload.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRouterConfigFromJson(
|
||||
this IServiceCollection services,
|
||||
string jsonPath,
|
||||
bool enableHotReload = true)
|
||||
{
|
||||
return services.AddRouterConfig(options =>
|
||||
{
|
||||
options.ConfigPath = jsonPath;
|
||||
options.EnableHotReload = enableHotReload;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
using StellaOps.Router.Common.Enums;
|
||||
|
||||
namespace StellaOps.Router.Config;
|
||||
|
||||
/// <summary>
|
||||
/// Configuration for a statically-defined microservice instance.
|
||||
/// </summary>
|
||||
public sealed class StaticInstanceConfig
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the service name.
|
||||
/// </summary>
|
||||
public required string ServiceName { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the service version.
|
||||
/// </summary>
|
||||
public required string Version { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the region.
|
||||
/// </summary>
|
||||
public string Region { get; set; } = "default";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the host name or IP address.
|
||||
/// </summary>
|
||||
public required string Host { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the port.
|
||||
/// </summary>
|
||||
public required int Port { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the transport type.
|
||||
/// </summary>
|
||||
public TransportType Transport { get; set; } = TransportType.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the instance weight for load balancing.
|
||||
/// </summary>
|
||||
public int Weight { get; set; } = 100;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the instance metadata.
|
||||
/// </summary>
|
||||
public Dictionary<string, string> Metadata { get; set; } = [];
|
||||
}
|
||||
@@ -5,8 +5,23 @@
|
||||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<RootNamespace>StellaOps.Router.Config</RootNamespace>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.Configuration" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="NetEscapades.Configuration.Yaml" Version="2.1.0" />
|
||||
<PackageReference Include="YamlDotNet" Version="13.7.1" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
||||
@@ -5,6 +5,7 @@ using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using static StellaOps.Router.Common.Models.CancelReasons;
|
||||
|
||||
namespace StellaOps.Router.Transport.InMemory;
|
||||
|
||||
@@ -12,12 +13,13 @@ namespace StellaOps.Router.Transport.InMemory;
|
||||
/// In-memory transport client implementation for testing and development.
|
||||
/// Used by the Microservice SDK to send frames to the Gateway.
|
||||
/// </summary>
|
||||
public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
public sealed class InMemoryTransportClient : ITransportClient, IMicroserviceTransport, IDisposable
|
||||
{
|
||||
private readonly InMemoryConnectionRegistry _registry;
|
||||
private readonly InMemoryTransportOptions _options;
|
||||
private readonly ILogger<InMemoryTransportClient> _logger;
|
||||
private readonly ConcurrentDictionary<string, TaskCompletionSource<Frame>> _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
|
||||
private readonly CancellationTokenSource _clientCts = new();
|
||||
private bool _disposed;
|
||||
private string? _connectionId;
|
||||
@@ -172,29 +174,54 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
return;
|
||||
}
|
||||
|
||||
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
|
||||
|
||||
// Create a linked CancellationTokenSource for this handler
|
||||
// This allows cancellation via CANCEL frame or transport shutdown
|
||||
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
_inflightHandlers[correlationId] = handlerCts;
|
||||
|
||||
try
|
||||
{
|
||||
var response = await OnRequestReceived(frame, cancellationToken);
|
||||
var response = await OnRequestReceived(frame, handlerCts.Token);
|
||||
|
||||
// Ensure response has same correlation ID
|
||||
var responseFrame = response with { CorrelationId = frame.CorrelationId };
|
||||
await channel.ToGateway.Writer.WriteAsync(responseFrame, cancellationToken);
|
||||
var responseFrame = response with { CorrelationId = correlationId };
|
||||
|
||||
// Only send response if not cancelled
|
||||
if (!handlerCts.Token.IsCancellationRequested)
|
||||
{
|
||||
await channel.ToGateway.Writer.WriteAsync(responseFrame, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
_logger.LogDebug("Not sending response for cancelled request {CorrelationId}", correlationId);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
_logger.LogDebug("Request {CorrelationId} was cancelled", frame.CorrelationId);
|
||||
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling request {CorrelationId}", frame.CorrelationId);
|
||||
// Send error response
|
||||
var errorFrame = new Frame
|
||||
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
|
||||
|
||||
// Only send error response if not cancelled
|
||||
if (!handlerCts.Token.IsCancellationRequested)
|
||||
{
|
||||
Type = FrameType.Response,
|
||||
CorrelationId = frame.CorrelationId,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await channel.ToGateway.Writer.WriteAsync(errorFrame, cancellationToken);
|
||||
var errorFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Response,
|
||||
CorrelationId = correlationId,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await channel.ToGateway.Writer.WriteAsync(errorFrame, cancellationToken);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
// Remove from inflight tracking
|
||||
_inflightHandlers.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,13 +231,27 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
|
||||
_logger.LogDebug("Received CANCEL for correlation {CorrelationId}", frame.CorrelationId);
|
||||
|
||||
// Cancel the inflight handler via its CancellationTokenSource
|
||||
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var handlerCts))
|
||||
{
|
||||
try
|
||||
{
|
||||
handlerCts.Cancel();
|
||||
_logger.LogInformation("Cancelled handler for request {CorrelationId}", frame.CorrelationId);
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Handler already completed
|
||||
}
|
||||
}
|
||||
|
||||
// Complete any pending request with cancellation
|
||||
if (_pendingRequests.TryRemove(frame.CorrelationId, out var tcs))
|
||||
{
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
|
||||
// Notify handler
|
||||
// Notify external handler (for custom cancellation logic)
|
||||
if (OnCancelReceived is not null && Guid.TryParse(frame.CorrelationId, out var correlationGuid))
|
||||
{
|
||||
_ = OnCancelReceived(correlationGuid, null);
|
||||
@@ -381,6 +422,33 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
await channel.ToGateway.Writer.WriteAsync(frame, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all in-flight handler requests.
|
||||
/// Called when connection is closed or transport is shutting down.
|
||||
/// </summary>
|
||||
/// <param name="reason">The reason for cancellation.</param>
|
||||
public void CancelAllInflight(string reason)
|
||||
{
|
||||
var count = 0;
|
||||
foreach (var kvp in _inflightHandlers)
|
||||
{
|
||||
try
|
||||
{
|
||||
kvp.Value.Cancel();
|
||||
count++;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed/disposed
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the transport.
|
||||
/// </summary>
|
||||
@@ -388,6 +456,9 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
{
|
||||
if (_connectionId is null) return;
|
||||
|
||||
// Cancel all inflight handlers before disconnecting
|
||||
CancelAllInflight(CancelReasons.Shutdown);
|
||||
|
||||
await _clientCts.CancelAsync();
|
||||
|
||||
if (_receiveTask is not null)
|
||||
@@ -407,6 +478,9 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
// Cancel all inflight handlers
|
||||
CancelAllInflight(Shutdown);
|
||||
|
||||
_clientCts.Cancel();
|
||||
|
||||
foreach (var tcs in _pendingRequests.Values)
|
||||
@@ -414,6 +488,7 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
_pendingRequests.Clear();
|
||||
_inflightHandlers.Clear();
|
||||
|
||||
if (_connectionId is not null)
|
||||
{
|
||||
|
||||
@@ -35,6 +35,7 @@ public static class ServiceCollectionExtensions
|
||||
// Register interfaces
|
||||
services.TryAddSingleton<ITransportServer>(sp => sp.GetRequiredService<InMemoryTransportServer>());
|
||||
services.TryAddSingleton<ITransportClient>(sp => sp.GetRequiredService<InMemoryTransportClient>());
|
||||
services.TryAddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<InMemoryTransportClient>());
|
||||
|
||||
return services;
|
||||
}
|
||||
@@ -81,6 +82,7 @@ public static class ServiceCollectionExtensions
|
||||
services.TryAddSingleton<InMemoryConnectionRegistry>();
|
||||
services.TryAddSingleton<InMemoryTransportClient>();
|
||||
services.TryAddSingleton<ITransportClient>(sp => sp.GetRequiredService<InMemoryTransportClient>());
|
||||
services.TryAddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<InMemoryTransportClient>());
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
using System.Text;
|
||||
using RabbitMQ.Client;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.RabbitMq;
|
||||
|
||||
/// <summary>
|
||||
/// Handles serialization and deserialization of frames for RabbitMQ transport.
|
||||
/// </summary>
|
||||
public static class RabbitMqFrameProtocol
|
||||
{
|
||||
/// <summary>
|
||||
/// Parses a frame from a RabbitMQ message.
|
||||
/// </summary>
|
||||
/// <param name="body">The message body.</param>
|
||||
/// <param name="properties">The message properties.</param>
|
||||
/// <returns>The parsed frame.</returns>
|
||||
public static Frame ParseFrame(ReadOnlyMemory<byte> body, IReadOnlyBasicProperties properties)
|
||||
{
|
||||
var frameType = ParseFrameType(properties.Type);
|
||||
var correlationId = properties.CorrelationId;
|
||||
|
||||
return new Frame
|
||||
{
|
||||
Type = frameType,
|
||||
CorrelationId = correlationId,
|
||||
Payload = body
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates BasicProperties for a frame.
|
||||
/// </summary>
|
||||
/// <param name="frame">The frame to serialize.</param>
|
||||
/// <param name="replyTo">The reply queue name.</param>
|
||||
/// <param name="timeout">Optional timeout for the message.</param>
|
||||
/// <returns>The basic properties.</returns>
|
||||
public static BasicProperties CreateProperties(Frame frame, string? replyTo, TimeSpan? timeout = null)
|
||||
{
|
||||
var props = new BasicProperties
|
||||
{
|
||||
Type = frame.Type.ToString(),
|
||||
Timestamp = new AmqpTimestamp(DateTimeOffset.UtcNow.ToUnixTimeSeconds()),
|
||||
DeliveryMode = DeliveryModes.Transient // Non-persistent (1)
|
||||
};
|
||||
|
||||
if (!string.IsNullOrEmpty(frame.CorrelationId))
|
||||
{
|
||||
props.CorrelationId = frame.CorrelationId;
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(replyTo))
|
||||
{
|
||||
props.ReplyTo = replyTo;
|
||||
}
|
||||
|
||||
if (timeout.HasValue)
|
||||
{
|
||||
props.Expiration = ((int)timeout.Value.TotalMilliseconds).ToString();
|
||||
}
|
||||
|
||||
return props;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Parses a FrameType from the message Type property.
|
||||
/// </summary>
|
||||
private static FrameType ParseFrameType(string? type)
|
||||
{
|
||||
if (string.IsNullOrEmpty(type))
|
||||
{
|
||||
return FrameType.Request;
|
||||
}
|
||||
|
||||
if (Enum.TryParse<FrameType>(type, ignoreCase: true, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
return FrameType.Request;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extracts the connection ID from message properties.
|
||||
/// </summary>
|
||||
/// <param name="properties">The message properties.</param>
|
||||
/// <returns>The connection ID.</returns>
|
||||
public static string ExtractConnectionId(IReadOnlyBasicProperties properties)
|
||||
{
|
||||
// Use ReplyTo as the basis for connection ID (identifies the instance)
|
||||
if (!string.IsNullOrEmpty(properties.ReplyTo))
|
||||
{
|
||||
// Extract instance ID from queue name like "stella.svc.{instanceId}"
|
||||
var parts = properties.ReplyTo.Split('.');
|
||||
if (parts.Length >= 3)
|
||||
{
|
||||
return $"rmq-{parts[^1]}";
|
||||
}
|
||||
return $"rmq-{properties.ReplyTo}";
|
||||
}
|
||||
|
||||
// Fallback to correlation ID
|
||||
if (!string.IsNullOrEmpty(properties.CorrelationId))
|
||||
{
|
||||
return $"rmq-{properties.CorrelationId[..Math.Min(16, properties.CorrelationId.Length)]}";
|
||||
}
|
||||
|
||||
return $"rmq-{Guid.NewGuid():N}"[..32];
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using RabbitMQ.Client;
|
||||
using RabbitMQ.Client.Events;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.RabbitMq;
|
||||
|
||||
/// <summary>
|
||||
/// RabbitMQ transport client implementation for microservices.
|
||||
/// </summary>
|
||||
public sealed class RabbitMqTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
|
||||
{
|
||||
private readonly RabbitMqTransportOptions _options;
|
||||
private readonly ILogger<RabbitMqTransportClient> _logger;
|
||||
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<Frame>> _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
|
||||
private readonly CancellationTokenSource _clientCts = new();
|
||||
private IConnection? _connection;
|
||||
private IChannel? _channel;
|
||||
private string? _responseQueueName;
|
||||
private string? _instanceId;
|
||||
private string? _gatewayNodeId;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a REQUEST frame is received.
|
||||
/// </summary>
|
||||
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a CANCEL frame is received.
|
||||
/// </summary>
|
||||
public event Func<Guid, string?, Task>? OnCancelReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RabbitMqTransportClient"/> class.
|
||||
/// </summary>
|
||||
public RabbitMqTransportClient(
|
||||
IOptions<RabbitMqTransportOptions> options,
|
||||
ILogger<RabbitMqTransportClient> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the gateway via RabbitMQ.
|
||||
/// </summary>
|
||||
/// <param name="instance">The instance descriptor.</param>
|
||||
/// <param name="endpoints">The endpoints to register.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task ConnectAsync(
|
||||
InstanceDescriptor instance,
|
||||
IReadOnlyList<EndpointDescriptor> endpoints,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
_instanceId = _options.InstanceId ?? instance.InstanceId;
|
||||
_gatewayNodeId = _options.NodeId ?? "default";
|
||||
|
||||
var factory = new ConnectionFactory
|
||||
{
|
||||
HostName = _options.HostName,
|
||||
Port = _options.Port,
|
||||
VirtualHost = _options.VirtualHost,
|
||||
UserName = _options.UserName,
|
||||
Password = _options.Password,
|
||||
AutomaticRecoveryEnabled = _options.AutomaticRecoveryEnabled,
|
||||
NetworkRecoveryInterval = _options.NetworkRecoveryInterval
|
||||
};
|
||||
|
||||
if (_options.UseSsl)
|
||||
{
|
||||
factory.Ssl = new SslOption
|
||||
{
|
||||
Enabled = true,
|
||||
ServerName = _options.HostName,
|
||||
CertPath = _options.SslCertPath
|
||||
};
|
||||
}
|
||||
|
||||
_connection = await factory.CreateConnectionAsync(cancellationToken);
|
||||
_channel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken);
|
||||
|
||||
// Set QoS
|
||||
await _channel.BasicQosAsync(
|
||||
prefetchSize: 0,
|
||||
prefetchCount: _options.PrefetchCount,
|
||||
global: false,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Declare exchanges (should already exist from server, but declare for safety)
|
||||
await _channel.ExchangeDeclareAsync(
|
||||
exchange: _options.RequestExchange,
|
||||
type: ExchangeType.Direct,
|
||||
durable: true,
|
||||
autoDelete: false,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
await _channel.ExchangeDeclareAsync(
|
||||
exchange: _options.ResponseExchange,
|
||||
type: ExchangeType.Topic,
|
||||
durable: true,
|
||||
autoDelete: false,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Declare response queue for this instance
|
||||
_responseQueueName = $"{_options.QueuePrefix}.svc.{_instanceId}";
|
||||
await _channel.QueueDeclareAsync(
|
||||
queue: _responseQueueName,
|
||||
durable: _options.DurableQueues,
|
||||
exclusive: false,
|
||||
autoDelete: _options.AutoDeleteQueues,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Bind to response exchange with instance ID as routing key
|
||||
await _channel.QueueBindAsync(
|
||||
queue: _responseQueueName,
|
||||
exchange: _options.ResponseExchange,
|
||||
routingKey: _instanceId,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Start consuming responses
|
||||
var consumer = new AsyncEventingBasicConsumer(_channel);
|
||||
consumer.ReceivedAsync += OnMessageReceivedAsync;
|
||||
|
||||
await _channel.BasicConsumeAsync(
|
||||
queue: _responseQueueName,
|
||||
autoAck: true,
|
||||
consumer: consumer,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Send HELLO frame
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await SendToGatewayAsync(helloFrame, cancellationToken);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Connected to RabbitMQ gateway at {Host}:{Port} as {ServiceName}/{Version}",
|
||||
_options.HostName,
|
||||
_options.Port,
|
||||
instance.ServiceName,
|
||||
instance.Version);
|
||||
}
|
||||
|
||||
private async Task OnMessageReceivedAsync(object sender, BasicDeliverEventArgs e)
|
||||
{
|
||||
try
|
||||
{
|
||||
var frame = RabbitMqFrameProtocol.ParseFrame(e.Body, e.BasicProperties);
|
||||
|
||||
switch (frame.Type)
|
||||
{
|
||||
case FrameType.Request:
|
||||
await HandleRequestFrameAsync(frame, _clientCts.Token);
|
||||
break;
|
||||
|
||||
case FrameType.Cancel:
|
||||
HandleCancelFrame(frame);
|
||||
break;
|
||||
|
||||
case FrameType.Response:
|
||||
if (frame.CorrelationId is not null &&
|
||||
Guid.TryParse(frame.CorrelationId, out var correlationId))
|
||||
{
|
||||
if (_pendingRequests.TryRemove(correlationId, out var tcs))
|
||||
{
|
||||
tcs.TrySetResult(frame);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
|
||||
break;
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error processing RabbitMQ message");
|
||||
}
|
||||
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
if (OnRequestReceived is null)
|
||||
{
|
||||
_logger.LogWarning("No request handler registered");
|
||||
return;
|
||||
}
|
||||
|
||||
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
|
||||
|
||||
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
_inflightHandlers[correlationId] = handlerCts;
|
||||
|
||||
try
|
||||
{
|
||||
var response = await OnRequestReceived(frame, handlerCts.Token);
|
||||
var responseFrame = response with { CorrelationId = correlationId };
|
||||
|
||||
if (!handlerCts.Token.IsCancellationRequested)
|
||||
{
|
||||
await SendToGatewayAsync(responseFrame, cancellationToken);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_inflightHandlers.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
private void HandleCancelFrame(Frame frame)
|
||||
{
|
||||
if (frame.CorrelationId is null) return;
|
||||
|
||||
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
|
||||
|
||||
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (Guid.TryParse(frame.CorrelationId, out var guid))
|
||||
{
|
||||
if (_pendingRequests.TryRemove(guid, out var tcs))
|
||||
{
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
OnCancelReceived?.Invoke(guid, null);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task SendToGatewayAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var properties = RabbitMqFrameProtocol.CreateProperties(
|
||||
frame,
|
||||
_responseQueueName,
|
||||
_options.DefaultTimeout);
|
||||
|
||||
await _channel!.BasicPublishAsync(
|
||||
exchange: _options.RequestExchange,
|
||||
routingKey: _gatewayNodeId!,
|
||||
mandatory: false,
|
||||
basicProperties: properties,
|
||||
body: frame.Payload,
|
||||
cancellationToken: cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<Frame> SendRequestAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestFrame,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var correlationId = requestFrame.CorrelationId is not null &&
|
||||
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
|
||||
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(timeout);
|
||||
|
||||
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
var registration = timeoutCts.Token.Register(() =>
|
||||
{
|
||||
if (_pendingRequests.TryRemove(correlationId, out var pendingTcs))
|
||||
{
|
||||
pendingTcs.TrySetCanceled(timeoutCts.Token);
|
||||
}
|
||||
});
|
||||
|
||||
_pendingRequests[correlationId] = tcs;
|
||||
|
||||
try
|
||||
{
|
||||
await SendToGatewayAsync(framedRequest, timeoutCts.Token);
|
||||
return await tcs.Task;
|
||||
}
|
||||
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
|
||||
}
|
||||
finally
|
||||
{
|
||||
await registration.DisposeAsync();
|
||||
_pendingRequests.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendCancelAsync(
|
||||
ConnectionState connection,
|
||||
Guid correlationId,
|
||||
string? reason = null)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var cancelFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Cancel,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await SendToGatewayAsync(cancelFrame, CancellationToken.None);
|
||||
_logger.LogDebug("Sent CANCEL for {CorrelationId}", correlationId);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendStreamingAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestHeader,
|
||||
Stream requestBody,
|
||||
Func<Stream, Task> readResponseBody,
|
||||
PayloadLimits limits,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
// Streaming could be implemented by chunking messages, but for now we don't support it
|
||||
// This keeps RabbitMQ transport simple
|
||||
throw new NotSupportedException(
|
||||
"RabbitMQ transport does not currently support streaming. Use TCP or TLS transport for streaming.");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a heartbeat.
|
||||
/// </summary>
|
||||
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
|
||||
{
|
||||
var frame = new Frame
|
||||
{
|
||||
Type = FrameType.Heartbeat,
|
||||
CorrelationId = null,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await SendToGatewayAsync(frame, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all in-flight handlers.
|
||||
/// </summary>
|
||||
public void CancelAllInflight(string reason)
|
||||
{
|
||||
var count = 0;
|
||||
foreach (var cts in _inflightHandlers.Values)
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
count++;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the gateway.
|
||||
/// </summary>
|
||||
public async Task DisconnectAsync()
|
||||
{
|
||||
CancelAllInflight("Shutdown");
|
||||
|
||||
// Cancel all pending requests
|
||||
foreach (var kvp in _pendingRequests)
|
||||
{
|
||||
if (_pendingRequests.TryRemove(kvp.Key, out var tcs))
|
||||
{
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
}
|
||||
|
||||
await _clientCts.CancelAsync();
|
||||
|
||||
if (_channel is not null)
|
||||
{
|
||||
await _channel.CloseAsync();
|
||||
}
|
||||
|
||||
if (_connection is not null)
|
||||
{
|
||||
await _connection.CloseAsync();
|
||||
}
|
||||
|
||||
_logger.LogInformation("Disconnected from RabbitMQ gateway");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await DisconnectAsync();
|
||||
|
||||
if (_channel is not null)
|
||||
{
|
||||
await _channel.DisposeAsync();
|
||||
}
|
||||
|
||||
if (_connection is not null)
|
||||
{
|
||||
await _connection.DisposeAsync();
|
||||
}
|
||||
|
||||
_clientCts.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
namespace StellaOps.Router.Transport.RabbitMq;
|
||||
|
||||
/// <summary>
|
||||
/// Options for RabbitMQ transport configuration.
|
||||
/// </summary>
|
||||
public sealed class RabbitMqTransportOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the RabbitMQ host name.
|
||||
/// </summary>
|
||||
public string HostName { get; set; } = "localhost";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the RabbitMQ port.
|
||||
/// </summary>
|
||||
public int Port { get; set; } = 5672;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the RabbitMQ virtual host.
|
||||
/// </summary>
|
||||
public string VirtualHost { get; set; } = "/";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the RabbitMQ username.
|
||||
/// </summary>
|
||||
public string UserName { get; set; } = "guest";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the RabbitMQ password.
|
||||
/// </summary>
|
||||
public string Password { get; set; } = "guest";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to use SSL/TLS.
|
||||
/// </summary>
|
||||
public bool UseSsl { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the SSL certificate path.
|
||||
/// </summary>
|
||||
public string? SslCertPath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether queues should be durable.
|
||||
/// </summary>
|
||||
public bool DurableQueues { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether queues should auto-delete on disconnect.
|
||||
/// </summary>
|
||||
public bool AutoDeleteQueues { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the prefetch count (concurrent messages).
|
||||
/// </summary>
|
||||
public ushort PrefetchCount { get; set; } = 10;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the exchange prefix.
|
||||
/// </summary>
|
||||
public string ExchangePrefix { get; set; } = "stella.router";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the queue prefix.
|
||||
/// </summary>
|
||||
public string QueuePrefix { get; set; } = "stella";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the request exchange name.
|
||||
/// </summary>
|
||||
public string RequestExchange => $"{ExchangePrefix}.requests";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the response exchange name.
|
||||
/// </summary>
|
||||
public string ResponseExchange => $"{ExchangePrefix}.responses";
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the node ID for this gateway instance.
|
||||
/// </summary>
|
||||
public string? NodeId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the instance ID for this microservice instance.
|
||||
/// </summary>
|
||||
public string? InstanceId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to use automatic recovery.
|
||||
/// </summary>
|
||||
public bool AutomaticRecoveryEnabled { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the network recovery interval.
|
||||
/// </summary>
|
||||
public TimeSpan NetworkRecoveryInterval { get; set; } = TimeSpan.FromSeconds(5);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the default request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan DefaultTimeout { get; set; } = TimeSpan.FromSeconds(30);
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using RabbitMQ.Client;
|
||||
using RabbitMQ.Client.Events;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.RabbitMq;
|
||||
|
||||
/// <summary>
|
||||
/// RabbitMQ transport server implementation for the gateway.
|
||||
/// </summary>
|
||||
public sealed class RabbitMqTransportServer : ITransportServer, IAsyncDisposable
|
||||
{
|
||||
private readonly RabbitMqTransportOptions _options;
|
||||
private readonly ILogger<RabbitMqTransportServer> _logger;
|
||||
private readonly ConcurrentDictionary<string, (string ReplyTo, ConnectionState State)> _connections = new();
|
||||
private readonly string _nodeId;
|
||||
private IConnection? _connection;
|
||||
private IChannel? _channel;
|
||||
private string? _requestQueueName;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is established (on first HELLO).
|
||||
/// </summary>
|
||||
public event Action<string, ConnectionState>? OnConnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is lost.
|
||||
/// </summary>
|
||||
public event Action<string>? OnDisconnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a frame is received.
|
||||
/// </summary>
|
||||
public event Action<string, Frame>? OnFrame;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RabbitMqTransportServer"/> class.
|
||||
/// </summary>
|
||||
public RabbitMqTransportServer(
|
||||
IOptions<RabbitMqTransportOptions> options,
|
||||
ILogger<RabbitMqTransportServer> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
_nodeId = _options.NodeId ?? Guid.NewGuid().ToString("N")[..8];
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var factory = new ConnectionFactory
|
||||
{
|
||||
HostName = _options.HostName,
|
||||
Port = _options.Port,
|
||||
VirtualHost = _options.VirtualHost,
|
||||
UserName = _options.UserName,
|
||||
Password = _options.Password,
|
||||
AutomaticRecoveryEnabled = _options.AutomaticRecoveryEnabled,
|
||||
NetworkRecoveryInterval = _options.NetworkRecoveryInterval
|
||||
};
|
||||
|
||||
if (_options.UseSsl)
|
||||
{
|
||||
factory.Ssl = new SslOption
|
||||
{
|
||||
Enabled = true,
|
||||
ServerName = _options.HostName,
|
||||
CertPath = _options.SslCertPath
|
||||
};
|
||||
}
|
||||
|
||||
_connection = await factory.CreateConnectionAsync(cancellationToken);
|
||||
_channel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken);
|
||||
|
||||
// Set QoS (prefetch count)
|
||||
await _channel.BasicQosAsync(
|
||||
prefetchSize: 0,
|
||||
prefetchCount: _options.PrefetchCount,
|
||||
global: false,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Declare exchanges
|
||||
await _channel.ExchangeDeclareAsync(
|
||||
exchange: _options.RequestExchange,
|
||||
type: ExchangeType.Direct,
|
||||
durable: true,
|
||||
autoDelete: false,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
await _channel.ExchangeDeclareAsync(
|
||||
exchange: _options.ResponseExchange,
|
||||
type: ExchangeType.Topic,
|
||||
durable: true,
|
||||
autoDelete: false,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Declare and bind request queue
|
||||
_requestQueueName = $"{_options.QueuePrefix}.gw.{_nodeId}.in";
|
||||
await _channel.QueueDeclareAsync(
|
||||
queue: _requestQueueName,
|
||||
durable: _options.DurableQueues,
|
||||
exclusive: false,
|
||||
autoDelete: _options.AutoDeleteQueues,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
await _channel.QueueBindAsync(
|
||||
queue: _requestQueueName,
|
||||
exchange: _options.RequestExchange,
|
||||
routingKey: _nodeId,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
// Start consuming
|
||||
var consumer = new AsyncEventingBasicConsumer(_channel);
|
||||
consumer.ReceivedAsync += OnMessageReceivedAsync;
|
||||
|
||||
await _channel.BasicConsumeAsync(
|
||||
queue: _requestQueueName,
|
||||
autoAck: true, // At-most-once delivery
|
||||
consumer: consumer,
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
_logger.LogInformation(
|
||||
"RabbitMQ transport server started, consuming from {Queue}",
|
||||
_requestQueueName);
|
||||
}
|
||||
|
||||
private async Task OnMessageReceivedAsync(object sender, BasicDeliverEventArgs e)
|
||||
{
|
||||
try
|
||||
{
|
||||
var frame = RabbitMqFrameProtocol.ParseFrame(e.Body, e.BasicProperties);
|
||||
var connectionId = RabbitMqFrameProtocol.ExtractConnectionId(e.BasicProperties);
|
||||
var replyTo = e.BasicProperties.ReplyTo ?? string.Empty;
|
||||
|
||||
// Handle HELLO specially to register connection
|
||||
if (frame.Type == FrameType.Hello && !_connections.ContainsKey(connectionId))
|
||||
{
|
||||
var state = new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = connectionId,
|
||||
ServiceName = "unknown",
|
||||
Version = "1.0.0",
|
||||
Region = "default"
|
||||
},
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = TransportType.RabbitMq
|
||||
};
|
||||
|
||||
_connections[connectionId] = (replyTo, state);
|
||||
_logger.LogInformation(
|
||||
"RabbitMQ connection established: {ConnectionId} with replyTo {ReplyTo}",
|
||||
connectionId,
|
||||
replyTo);
|
||||
OnConnection?.Invoke(connectionId, state);
|
||||
}
|
||||
|
||||
// Update heartbeat timestamp on HEARTBEAT frames
|
||||
if (frame.Type == FrameType.Heartbeat &&
|
||||
_connections.TryGetValue(connectionId, out var conn))
|
||||
{
|
||||
conn.State.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
}
|
||||
|
||||
OnFrame?.Invoke(connectionId, frame);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error processing RabbitMQ message");
|
||||
}
|
||||
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a frame to a connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <param name="frame">The frame to send.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task SendFrameAsync(
|
||||
string connectionId,
|
||||
Frame frame,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (!_connections.TryGetValue(connectionId, out var conn))
|
||||
{
|
||||
throw new InvalidOperationException($"Connection {connectionId} not found");
|
||||
}
|
||||
|
||||
var properties = RabbitMqFrameProtocol.CreateProperties(frame, null, _options.DefaultTimeout);
|
||||
|
||||
// Send to response exchange with instance ID as routing key
|
||||
var routingKey = conn.ReplyTo.Split('.')[^1]; // Extract instance ID from queue name
|
||||
|
||||
await _channel!.BasicPublishAsync(
|
||||
exchange: _options.ResponseExchange,
|
||||
routingKey: routingKey,
|
||||
mandatory: false,
|
||||
basicProperties: properties,
|
||||
body: frame.Payload,
|
||||
cancellationToken: cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the connection state by ID.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <returns>The connection state, or null if not found.</returns>
|
||||
public ConnectionState? GetConnectionState(string connectionId)
|
||||
{
|
||||
return _connections.TryGetValue(connectionId, out var conn) ? conn.State : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets all active connections.
|
||||
/// </summary>
|
||||
public IEnumerable<ConnectionState> GetConnections() =>
|
||||
_connections.Values.Select(c => c.State);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of active connections.
|
||||
/// </summary>
|
||||
public int ConnectionCount => _connections.Count;
|
||||
|
||||
/// <summary>
|
||||
/// Removes a connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
public void RemoveConnection(string connectionId)
|
||||
{
|
||||
if (_connections.TryRemove(connectionId, out _))
|
||||
{
|
||||
_logger.LogInformation("RabbitMQ connection removed: {ConnectionId}", connectionId);
|
||||
OnDisconnection?.Invoke(connectionId);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
_logger.LogInformation("Stopping RabbitMQ transport server");
|
||||
|
||||
if (_channel is not null)
|
||||
{
|
||||
await _channel.CloseAsync(cancellationToken);
|
||||
}
|
||||
|
||||
if (_connection is not null)
|
||||
{
|
||||
await _connection.CloseAsync(cancellationToken);
|
||||
}
|
||||
|
||||
_connections.Clear();
|
||||
|
||||
_logger.LogInformation("RabbitMQ transport server stopped");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await StopAsync(CancellationToken.None);
|
||||
|
||||
if (_channel is not null)
|
||||
{
|
||||
await _channel.DisposeAsync();
|
||||
}
|
||||
|
||||
if (_connection is not null)
|
||||
{
|
||||
await _connection.DisposeAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Router.Transport.RabbitMq;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering RabbitMQ transport services.
|
||||
/// </summary>
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds RabbitMQ transport server services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Optional configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRabbitMqTransportServer(
|
||||
this IServiceCollection services,
|
||||
Action<RabbitMqTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<RabbitMqTransportServer>();
|
||||
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<RabbitMqTransportServer>());
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds RabbitMQ transport client services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Optional configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddRabbitMqTransportClient(
|
||||
this IServiceCollection services,
|
||||
Action<RabbitMqTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<RabbitMqTransportClient>();
|
||||
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<RabbitMqTransportClient>());
|
||||
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<RabbitMqTransportClient>());
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<RootNamespace>StellaOps.Router.Transport.RabbitMq</RootNamespace>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="RabbitMQ.Client" Version="7.0.0" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
144
src/__Libraries/StellaOps.Router.Transport.Tcp/FrameProtocol.cs
Normal file
144
src/__Libraries/StellaOps.Router.Transport.Tcp/FrameProtocol.cs
Normal file
@@ -0,0 +1,144 @@
|
||||
using System.Buffers.Binary;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// Handles reading and writing length-prefixed frames over a stream.
|
||||
/// Frame format: [4-byte big-endian length][payload]
|
||||
/// Payload format: [1-byte frame type][16-byte correlation GUID][remaining data]
|
||||
/// </summary>
|
||||
public static class FrameProtocol
|
||||
{
|
||||
private const int LengthPrefixSize = 4;
|
||||
private const int FrameTypeSize = 1;
|
||||
private const int CorrelationIdSize = 16;
|
||||
private const int HeaderSize = FrameTypeSize + CorrelationIdSize;
|
||||
|
||||
/// <summary>
|
||||
/// Reads a complete frame from the stream.
|
||||
/// </summary>
|
||||
/// <param name="stream">The stream to read from.</param>
|
||||
/// <param name="maxFrameSize">The maximum frame size allowed.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
/// <returns>The frame read, or null if the stream is closed.</returns>
|
||||
public static async Task<Frame?> ReadFrameAsync(
|
||||
Stream stream,
|
||||
int maxFrameSize,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
// Read length prefix (4 bytes, big-endian)
|
||||
var lengthBuffer = new byte[LengthPrefixSize];
|
||||
var bytesRead = await ReadExactAsync(stream, lengthBuffer, cancellationToken);
|
||||
if (bytesRead == 0)
|
||||
{
|
||||
return null; // Connection closed
|
||||
}
|
||||
|
||||
if (bytesRead < LengthPrefixSize)
|
||||
{
|
||||
throw new InvalidOperationException("Incomplete length prefix received");
|
||||
}
|
||||
|
||||
var payloadLength = BinaryPrimitives.ReadInt32BigEndian(lengthBuffer);
|
||||
if (payloadLength < HeaderSize)
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid payload length: {payloadLength}");
|
||||
}
|
||||
|
||||
if (payloadLength > maxFrameSize)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Frame size {payloadLength} exceeds maximum {maxFrameSize}");
|
||||
}
|
||||
|
||||
// Read payload
|
||||
var payload = new byte[payloadLength];
|
||||
bytesRead = await ReadExactAsync(stream, payload, cancellationToken);
|
||||
if (bytesRead < payloadLength)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Incomplete payload: expected {payloadLength}, got {bytesRead}");
|
||||
}
|
||||
|
||||
// Parse frame
|
||||
var frameType = (FrameType)payload[0];
|
||||
var correlationId = new Guid(payload.AsSpan(FrameTypeSize, CorrelationIdSize));
|
||||
var data = payload.AsMemory(HeaderSize);
|
||||
|
||||
return new Frame
|
||||
{
|
||||
Type = frameType,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = data
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Writes a frame to the stream.
|
||||
/// </summary>
|
||||
/// <param name="stream">The stream to write to.</param>
|
||||
/// <param name="frame">The frame to write.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public static async Task WriteFrameAsync(
|
||||
Stream stream,
|
||||
Frame frame,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
// Parse or generate correlation ID
|
||||
var correlationGuid = frame.CorrelationId is not null &&
|
||||
Guid.TryParse(frame.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var dataLength = frame.Payload.Length;
|
||||
var payloadLength = HeaderSize + dataLength;
|
||||
|
||||
// Create buffer for the complete message
|
||||
var buffer = new byte[LengthPrefixSize + payloadLength];
|
||||
|
||||
// Write length prefix (big-endian)
|
||||
BinaryPrimitives.WriteInt32BigEndian(buffer.AsSpan(0, LengthPrefixSize), payloadLength);
|
||||
|
||||
// Write frame type
|
||||
buffer[LengthPrefixSize] = (byte)frame.Type;
|
||||
|
||||
// Write correlation ID
|
||||
correlationGuid.TryWriteBytes(buffer.AsSpan(LengthPrefixSize + FrameTypeSize, CorrelationIdSize));
|
||||
|
||||
// Write data
|
||||
if (dataLength > 0)
|
||||
{
|
||||
frame.Payload.Span.CopyTo(buffer.AsSpan(LengthPrefixSize + HeaderSize));
|
||||
}
|
||||
|
||||
await stream.WriteAsync(buffer, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads exactly the specified number of bytes from the stream.
|
||||
/// </summary>
|
||||
private static async Task<int> ReadExactAsync(
|
||||
Stream stream,
|
||||
Memory<byte> buffer,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var totalRead = 0;
|
||||
while (totalRead < buffer.Length)
|
||||
{
|
||||
var read = await stream.ReadAsync(
|
||||
buffer[totalRead..],
|
||||
cancellationToken);
|
||||
|
||||
if (read == 0)
|
||||
{
|
||||
return totalRead; // EOF
|
||||
}
|
||||
|
||||
totalRead += read;
|
||||
}
|
||||
|
||||
return totalRead;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
using System.Collections.Concurrent;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks pending requests waiting for responses.
|
||||
/// Enables multiplexing multiple concurrent requests on a single connection.
|
||||
/// </summary>
|
||||
public sealed class PendingRequestTracker : IDisposable
|
||||
{
|
||||
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<Frame>> _pending = new();
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Tracks a request and returns a task that completes when the response arrives.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID of the request.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
/// <returns>A task that completes with the response frame.</returns>
|
||||
public Task<Frame> TrackRequest(Guid correlationId, CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
// Register cancellation callback
|
||||
var registration = cancellationToken.Register(() =>
|
||||
{
|
||||
if (_pending.TryRemove(correlationId, out var pendingTcs))
|
||||
{
|
||||
pendingTcs.TrySetCanceled(cancellationToken);
|
||||
}
|
||||
});
|
||||
|
||||
// Store registration in state to dispose when completed
|
||||
tcs.Task.ContinueWith(_ => registration.Dispose(), TaskScheduler.Default);
|
||||
|
||||
_pending[correlationId] = tcs;
|
||||
return tcs.Task;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Completes a pending request with the response.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID.</param>
|
||||
/// <param name="response">The response frame.</param>
|
||||
/// <returns>True if the request was found and completed; false otherwise.</returns>
|
||||
public bool CompleteRequest(Guid correlationId, Frame response)
|
||||
{
|
||||
if (_pending.TryRemove(correlationId, out var tcs))
|
||||
{
|
||||
return tcs.TrySetResult(response);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Fails a pending request with an exception.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID.</param>
|
||||
/// <param name="exception">The exception.</param>
|
||||
/// <returns>True if the request was found and failed; false otherwise.</returns>
|
||||
public bool FailRequest(Guid correlationId, Exception exception)
|
||||
{
|
||||
if (_pending.TryRemove(correlationId, out var tcs))
|
||||
{
|
||||
return tcs.TrySetException(exception);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels a pending request.
|
||||
/// </summary>
|
||||
/// <param name="correlationId">The correlation ID.</param>
|
||||
/// <returns>True if the request was found and cancelled; false otherwise.</returns>
|
||||
public bool CancelRequest(Guid correlationId)
|
||||
{
|
||||
if (_pending.TryRemove(correlationId, out var tcs))
|
||||
{
|
||||
return tcs.TrySetCanceled();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of pending requests.
|
||||
/// </summary>
|
||||
public int Count => _pending.Count;
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all pending requests.
|
||||
/// </summary>
|
||||
/// <param name="exception">Optional exception to set.</param>
|
||||
public void CancelAll(Exception? exception = null)
|
||||
{
|
||||
foreach (var kvp in _pending)
|
||||
{
|
||||
if (_pending.TryRemove(kvp.Key, out var tcs))
|
||||
{
|
||||
if (exception is not null)
|
||||
{
|
||||
tcs.TrySetException(exception);
|
||||
}
|
||||
else
|
||||
{
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
CancelAll(new ObjectDisposedException(nameof(PendingRequestTracker)));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering TCP transport services.
|
||||
/// </summary>
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds TCP transport server services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Optional configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddTcpTransportServer(
|
||||
this IServiceCollection services,
|
||||
Action<TcpTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<TcpTransportServer>();
|
||||
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<TcpTransportServer>());
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds TCP transport client services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Optional configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddTcpTransportClient(
|
||||
this IServiceCollection services,
|
||||
Action<TcpTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<TcpTransportClient>();
|
||||
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<TcpTransportClient>());
|
||||
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<TcpTransportClient>());
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<RootNamespace>StellaOps.Router.Transport.Tcp</RootNamespace>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
182
src/__Libraries/StellaOps.Router.Transport.Tcp/TcpConnection.cs
Normal file
182
src/__Libraries/StellaOps.Router.Transport.Tcp/TcpConnection.cs
Normal file
@@ -0,0 +1,182 @@
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// Represents a TCP connection to a microservice.
|
||||
/// </summary>
|
||||
public sealed class TcpConnection : IAsyncDisposable
|
||||
{
|
||||
private readonly TcpClient _client;
|
||||
private readonly NetworkStream _stream;
|
||||
private readonly SemaphoreSlim _writeLock = new(1, 1);
|
||||
private readonly TcpTransportOptions _options;
|
||||
private readonly ILogger _logger;
|
||||
private readonly CancellationTokenSource _connectionCts = new();
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the connection ID.
|
||||
/// </summary>
|
||||
public string ConnectionId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the remote endpoint as a string.
|
||||
/// </summary>
|
||||
public string RemoteEndpoint { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the connection is active.
|
||||
/// </summary>
|
||||
public bool IsConnected => _client.Connected && !_disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the connection state.
|
||||
/// </summary>
|
||||
public ConnectionState? State { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the cancellation token for this connection.
|
||||
/// </summary>
|
||||
public CancellationToken ConnectionToken => _connectionCts.Token;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a frame is received.
|
||||
/// </summary>
|
||||
public event Action<TcpConnection, Frame>? OnFrameReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when the connection is closed.
|
||||
/// </summary>
|
||||
public event Action<TcpConnection, Exception?>? OnDisconnected;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TcpConnection"/> class.
|
||||
/// </summary>
|
||||
public TcpConnection(
|
||||
string connectionId,
|
||||
TcpClient client,
|
||||
TcpTransportOptions options,
|
||||
ILogger logger)
|
||||
{
|
||||
ConnectionId = connectionId;
|
||||
_client = client;
|
||||
_stream = client.GetStream();
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
RemoteEndpoint = client.Client.RemoteEndPoint?.ToString() ?? "unknown";
|
||||
|
||||
// Configure socket options
|
||||
client.ReceiveBufferSize = options.ReceiveBufferSize;
|
||||
client.SendBufferSize = options.SendBufferSize;
|
||||
client.NoDelay = true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Starts the read loop to receive frames.
|
||||
/// </summary>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task ReadLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
|
||||
cancellationToken, _connectionCts.Token);
|
||||
|
||||
Exception? disconnectException = null;
|
||||
|
||||
try
|
||||
{
|
||||
while (!linkedCts.Token.IsCancellationRequested)
|
||||
{
|
||||
var frame = await FrameProtocol.ReadFrameAsync(
|
||||
_stream,
|
||||
_options.MaxFrameSize,
|
||||
linkedCts.Token);
|
||||
|
||||
if (frame is null)
|
||||
{
|
||||
_logger.LogDebug("Connection {ConnectionId} closed by remote", ConnectionId);
|
||||
break;
|
||||
}
|
||||
|
||||
OnFrameReceived?.Invoke(this, frame);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected on shutdown
|
||||
}
|
||||
catch (IOException ex) when (ex.InnerException is SocketException)
|
||||
{
|
||||
disconnectException = ex;
|
||||
_logger.LogDebug(ex, "Connection {ConnectionId} socket error", ConnectionId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
disconnectException = ex;
|
||||
_logger.LogWarning(ex, "Connection {ConnectionId} read error", ConnectionId);
|
||||
}
|
||||
|
||||
OnDisconnected?.Invoke(this, disconnectException);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Writes a frame to the connection.
|
||||
/// </summary>
|
||||
/// <param name="frame">The frame to write.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
await _writeLock.WaitAsync(cancellationToken);
|
||||
try
|
||||
{
|
||||
await FrameProtocol.WriteFrameAsync(_stream, frame, cancellationToken);
|
||||
await _stream.FlushAsync(cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_writeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Closes the connection.
|
||||
/// </summary>
|
||||
public void Close()
|
||||
{
|
||||
if (_disposed) return;
|
||||
|
||||
try
|
||||
{
|
||||
_connectionCts.Cancel();
|
||||
_client.Close();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Error closing connection {ConnectionId}", ConnectionId);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
try
|
||||
{
|
||||
await _connectionCts.CancelAsync();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore
|
||||
}
|
||||
|
||||
_client.Dispose();
|
||||
_writeLock.Dispose();
|
||||
_connectionCts.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,486 @@
|
||||
using System.Buffers;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// TCP transport client implementation for microservices.
|
||||
/// </summary>
|
||||
public sealed class TcpTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
|
||||
{
|
||||
private readonly TcpTransportOptions _options;
|
||||
private readonly ILogger<TcpTransportClient> _logger;
|
||||
private readonly PendingRequestTracker _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
|
||||
private readonly CancellationTokenSource _clientCts = new();
|
||||
private TcpClient? _client;
|
||||
private NetworkStream? _stream;
|
||||
private readonly SemaphoreSlim _writeLock = new(1, 1);
|
||||
private Task? _receiveTask;
|
||||
private bool _disposed;
|
||||
private string? _connectionId;
|
||||
private int _reconnectAttempts;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a REQUEST frame is received.
|
||||
/// </summary>
|
||||
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a CANCEL frame is received.
|
||||
/// </summary>
|
||||
public event Func<Guid, string?, Task>? OnCancelReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TcpTransportClient"/> class.
|
||||
/// </summary>
|
||||
public TcpTransportClient(
|
||||
IOptions<TcpTransportOptions> options,
|
||||
ILogger<TcpTransportClient> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the gateway.
|
||||
/// </summary>
|
||||
/// <param name="instance">The instance descriptor.</param>
|
||||
/// <param name="endpoints">The endpoints to register.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task ConnectAsync(
|
||||
InstanceDescriptor instance,
|
||||
IReadOnlyList<EndpointDescriptor> endpoints,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (string.IsNullOrEmpty(_options.Host))
|
||||
{
|
||||
throw new InvalidOperationException("Host is not configured");
|
||||
}
|
||||
|
||||
await ConnectInternalAsync(cancellationToken);
|
||||
|
||||
_connectionId = Guid.NewGuid().ToString("N");
|
||||
|
||||
// Send HELLO frame
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await WriteFrameAsync(helloFrame, cancellationToken);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Connected to TCP gateway at {Host}:{Port} as {ServiceName}/{Version}",
|
||||
_options.Host,
|
||||
_options.Port,
|
||||
instance.ServiceName,
|
||||
instance.Version);
|
||||
|
||||
// Start receiving frames
|
||||
_receiveTask = Task.Run(() => ReceiveLoopAsync(_clientCts.Token), CancellationToken.None);
|
||||
}
|
||||
|
||||
private async Task ConnectInternalAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(_options.ConnectTimeout);
|
||||
|
||||
_client = new TcpClient
|
||||
{
|
||||
ReceiveBufferSize = _options.ReceiveBufferSize,
|
||||
SendBufferSize = _options.SendBufferSize,
|
||||
NoDelay = true
|
||||
};
|
||||
|
||||
await _client.ConnectAsync(_options.Host!, _options.Port, timeoutCts.Token);
|
||||
_stream = _client.GetStream();
|
||||
_reconnectAttempts = 0;
|
||||
}
|
||||
|
||||
private async Task ReconnectAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
|
||||
while (_reconnectAttempts < _options.MaxReconnectAttempts && !_clientCts.Token.IsCancellationRequested)
|
||||
{
|
||||
_reconnectAttempts++;
|
||||
var backoff = TimeSpan.FromMilliseconds(
|
||||
Math.Min(
|
||||
Math.Pow(2, _reconnectAttempts) * 100,
|
||||
_options.MaxReconnectBackoff.TotalMilliseconds));
|
||||
|
||||
_logger.LogInformation(
|
||||
"Reconnection attempt {Attempt} of {Max} in {Delay}ms",
|
||||
_reconnectAttempts,
|
||||
_options.MaxReconnectAttempts,
|
||||
backoff.TotalMilliseconds);
|
||||
|
||||
await Task.Delay(backoff, _clientCts.Token);
|
||||
|
||||
try
|
||||
{
|
||||
_client?.Dispose();
|
||||
await ConnectInternalAsync(_clientCts.Token);
|
||||
_logger.LogInformation("Reconnected to gateway");
|
||||
return;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Reconnection attempt {Attempt} failed", _reconnectAttempts);
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogError("Max reconnection attempts reached, giving up");
|
||||
}
|
||||
|
||||
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
var frame = await FrameProtocol.ReadFrameAsync(
|
||||
_stream!,
|
||||
_options.MaxFrameSize,
|
||||
cancellationToken);
|
||||
|
||||
if (frame is null)
|
||||
{
|
||||
_logger.LogDebug("Connection closed by server");
|
||||
await ReconnectAsync();
|
||||
continue;
|
||||
}
|
||||
|
||||
await ProcessFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (IOException ex) when (ex.InnerException is SocketException)
|
||||
{
|
||||
_logger.LogDebug(ex, "Socket error, attempting reconnection");
|
||||
await ReconnectAsync();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error in receive loop");
|
||||
await Task.Delay(1000, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ProcessFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
switch (frame.Type)
|
||||
{
|
||||
case FrameType.Request:
|
||||
case FrameType.RequestStreamData:
|
||||
await HandleRequestFrameAsync(frame, cancellationToken);
|
||||
break;
|
||||
|
||||
case FrameType.Cancel:
|
||||
HandleCancelFrame(frame);
|
||||
break;
|
||||
|
||||
case FrameType.Response:
|
||||
case FrameType.ResponseStreamData:
|
||||
if (frame.CorrelationId is not null &&
|
||||
Guid.TryParse(frame.CorrelationId, out var correlationId))
|
||||
{
|
||||
_pendingRequests.CompleteRequest(correlationId, frame);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
if (OnRequestReceived is null)
|
||||
{
|
||||
_logger.LogWarning("No request handler registered");
|
||||
return;
|
||||
}
|
||||
|
||||
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
|
||||
|
||||
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
_inflightHandlers[correlationId] = handlerCts;
|
||||
|
||||
try
|
||||
{
|
||||
var response = await OnRequestReceived(frame, handlerCts.Token);
|
||||
var responseFrame = response with { CorrelationId = correlationId };
|
||||
|
||||
if (!handlerCts.Token.IsCancellationRequested)
|
||||
{
|
||||
await WriteFrameAsync(responseFrame, cancellationToken);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_inflightHandlers.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
private void HandleCancelFrame(Frame frame)
|
||||
{
|
||||
if (frame.CorrelationId is null) return;
|
||||
|
||||
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
|
||||
|
||||
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (Guid.TryParse(frame.CorrelationId, out var guid))
|
||||
{
|
||||
_pendingRequests.CancelRequest(guid);
|
||||
OnCancelReceived?.Invoke(guid, null);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
await _writeLock.WaitAsync(cancellationToken);
|
||||
try
|
||||
{
|
||||
await FrameProtocol.WriteFrameAsync(_stream!, frame, cancellationToken);
|
||||
await _stream!.FlushAsync(cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_writeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<Frame> SendRequestAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestFrame,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var correlationId = requestFrame.CorrelationId is not null &&
|
||||
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
|
||||
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(timeout);
|
||||
|
||||
var responseTask = _pendingRequests.TrackRequest(correlationId, timeoutCts.Token);
|
||||
|
||||
await WriteFrameAsync(framedRequest, timeoutCts.Token);
|
||||
|
||||
try
|
||||
{
|
||||
return await responseTask;
|
||||
}
|
||||
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendCancelAsync(
|
||||
ConnectionState connection,
|
||||
Guid correlationId,
|
||||
string? reason = null)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var cancelFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Cancel,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await WriteFrameAsync(cancelFrame, CancellationToken.None);
|
||||
_logger.LogDebug("Sent CANCEL for {CorrelationId}", correlationId);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendStreamingAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestHeader,
|
||||
Stream requestBody,
|
||||
Func<Stream, Task> readResponseBody,
|
||||
PayloadLimits limits,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var correlationId = requestHeader.CorrelationId is not null &&
|
||||
Guid.TryParse(requestHeader.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var headerFrame = requestHeader with
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = correlationId.ToString("N")
|
||||
};
|
||||
await WriteFrameAsync(headerFrame, cancellationToken);
|
||||
|
||||
// Stream request body
|
||||
var buffer = ArrayPool<byte>.Shared.Rent(8192);
|
||||
try
|
||||
{
|
||||
long totalBytesRead = 0;
|
||||
int bytesRead;
|
||||
|
||||
while ((bytesRead = await requestBody.ReadAsync(buffer, cancellationToken)) > 0)
|
||||
{
|
||||
totalBytesRead += bytesRead;
|
||||
|
||||
if (totalBytesRead > limits.MaxRequestBytesPerCall)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Request body exceeds limit of {limits.MaxRequestBytesPerCall} bytes");
|
||||
}
|
||||
|
||||
var dataFrame = new Frame
|
||||
{
|
||||
Type = FrameType.RequestStreamData,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = new ReadOnlyMemory<byte>(buffer, 0, bytesRead)
|
||||
};
|
||||
await WriteFrameAsync(dataFrame, cancellationToken);
|
||||
}
|
||||
|
||||
// End of stream marker
|
||||
var endFrame = new Frame
|
||||
{
|
||||
Type = FrameType.RequestStreamData,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await WriteFrameAsync(endFrame, cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(buffer);
|
||||
}
|
||||
|
||||
// Read streaming response
|
||||
using var responseStream = new MemoryStream();
|
||||
await readResponseBody(responseStream);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a heartbeat.
|
||||
/// </summary>
|
||||
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
|
||||
{
|
||||
var frame = new Frame
|
||||
{
|
||||
Type = FrameType.Heartbeat,
|
||||
CorrelationId = null,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await WriteFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all in-flight handlers.
|
||||
/// </summary>
|
||||
public void CancelAllInflight(string reason)
|
||||
{
|
||||
var count = 0;
|
||||
foreach (var cts in _inflightHandlers.Values)
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
count++;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the gateway.
|
||||
/// </summary>
|
||||
public async Task DisconnectAsync()
|
||||
{
|
||||
CancelAllInflight("Shutdown");
|
||||
|
||||
await _clientCts.CancelAsync();
|
||||
|
||||
if (_receiveTask is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _receiveTask;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
|
||||
_client?.Dispose();
|
||||
_logger.LogInformation("Disconnected from TCP gateway");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await DisconnectAsync();
|
||||
|
||||
_pendingRequests.Dispose();
|
||||
_writeLock.Dispose();
|
||||
_clientCts.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
using System.Net;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// Configuration options for TCP transport.
|
||||
/// </summary>
|
||||
public sealed class TcpTransportOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the address to bind to.
|
||||
/// Default: IPAddress.Any (0.0.0.0).
|
||||
/// </summary>
|
||||
public IPAddress BindAddress { get; set; } = IPAddress.Any;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the port to listen on.
|
||||
/// Default: 5100.
|
||||
/// </summary>
|
||||
public int Port { get; set; } = 5100;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the receive buffer size.
|
||||
/// Default: 64 KB.
|
||||
/// </summary>
|
||||
public int ReceiveBufferSize { get; set; } = 64 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the send buffer size.
|
||||
/// Default: 64 KB.
|
||||
/// </summary>
|
||||
public int SendBufferSize { get; set; } = 64 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the keep-alive interval.
|
||||
/// Default: 30 seconds.
|
||||
/// </summary>
|
||||
public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromSeconds(30);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the connection timeout.
|
||||
/// Default: 10 seconds.
|
||||
/// </summary>
|
||||
public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(10);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum number of reconnection attempts.
|
||||
/// Default: 10.
|
||||
/// </summary>
|
||||
public int MaxReconnectAttempts { get; set; } = 10;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum reconnection backoff.
|
||||
/// Default: 1 minute.
|
||||
/// </summary>
|
||||
public TimeSpan MaxReconnectBackoff { get; set; } = TimeSpan.FromMinutes(1);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum frame size in bytes.
|
||||
/// Default: 16 MB.
|
||||
/// </summary>
|
||||
public int MaxFrameSize { get; set; } = 16 * 1024 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the host for client connections.
|
||||
/// </summary>
|
||||
public string? Host { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp;
|
||||
|
||||
/// <summary>
|
||||
/// TCP transport server implementation for the gateway.
|
||||
/// </summary>
|
||||
public sealed class TcpTransportServer : ITransportServer, IAsyncDisposable
|
||||
{
|
||||
private readonly TcpTransportOptions _options;
|
||||
private readonly ILogger<TcpTransportServer> _logger;
|
||||
private readonly ConcurrentDictionary<string, TcpConnection> _connections = new();
|
||||
private TcpListener? _listener;
|
||||
private CancellationTokenSource? _serverCts;
|
||||
private Task? _acceptTask;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is established.
|
||||
/// </summary>
|
||||
public event Action<string, ConnectionState>? OnConnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is lost.
|
||||
/// </summary>
|
||||
public event Action<string>? OnDisconnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a frame is received.
|
||||
/// </summary>
|
||||
public event Action<string, Frame>? OnFrame;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TcpTransportServer"/> class.
|
||||
/// </summary>
|
||||
public TcpTransportServer(
|
||||
IOptions<TcpTransportOptions> options,
|
||||
ILogger<TcpTransportServer> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
_serverCts = new CancellationTokenSource();
|
||||
_listener = new TcpListener(_options.BindAddress, _options.Port);
|
||||
_listener.Start();
|
||||
|
||||
_logger.LogInformation(
|
||||
"TCP transport server listening on {Address}:{Port}",
|
||||
_options.BindAddress,
|
||||
_options.Port);
|
||||
|
||||
_acceptTask = AcceptLoopAsync(_serverCts.Token);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task AcceptLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
var client = await _listener!.AcceptTcpClientAsync(cancellationToken);
|
||||
var connectionId = GenerateConnectionId(client);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Accepted connection {ConnectionId} from {RemoteEndpoint}",
|
||||
connectionId,
|
||||
client.Client.RemoteEndPoint);
|
||||
|
||||
var connection = new TcpConnection(connectionId, client, _options, _logger);
|
||||
_connections[connectionId] = connection;
|
||||
|
||||
connection.OnFrameReceived += HandleFrame;
|
||||
connection.OnDisconnected += HandleDisconnect;
|
||||
|
||||
// Start read loop (non-blocking)
|
||||
_ = Task.Run(() => connection.ReadLoopAsync(cancellationToken), CancellationToken.None);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected on shutdown
|
||||
break;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Listener disposed
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error accepting connection");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void HandleFrame(TcpConnection connection, Frame frame)
|
||||
{
|
||||
// If this is a HELLO frame, create the ConnectionState
|
||||
if (frame.Type == FrameType.Hello && connection.State is null)
|
||||
{
|
||||
var state = new ConnectionState
|
||||
{
|
||||
ConnectionId = connection.ConnectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = connection.ConnectionId,
|
||||
ServiceName = "unknown", // Will be updated from HELLO payload
|
||||
Version = "1.0.0",
|
||||
Region = "default"
|
||||
},
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = TransportType.Tcp
|
||||
};
|
||||
connection.State = state;
|
||||
OnConnection?.Invoke(connection.ConnectionId, state);
|
||||
}
|
||||
|
||||
OnFrame?.Invoke(connection.ConnectionId, frame);
|
||||
}
|
||||
|
||||
private void HandleDisconnect(TcpConnection connection, Exception? ex)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Connection {ConnectionId} disconnected{Reason}",
|
||||
connection.ConnectionId,
|
||||
ex is not null ? $": {ex.Message}" : string.Empty);
|
||||
|
||||
_connections.TryRemove(connection.ConnectionId, out _);
|
||||
OnDisconnection?.Invoke(connection.ConnectionId);
|
||||
|
||||
// Clean up connection
|
||||
_ = connection.DisposeAsync();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a frame to a connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <param name="frame">The frame to send.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task SendFrameAsync(
|
||||
string connectionId,
|
||||
Frame frame,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_connections.TryGetValue(connectionId, out var connection))
|
||||
{
|
||||
await connection.WriteFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Connection {connectionId} not found");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets a connection by ID.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <returns>The connection, or null if not found.</returns>
|
||||
public TcpConnection? GetConnection(string connectionId)
|
||||
{
|
||||
return _connections.TryGetValue(connectionId, out var conn) ? conn : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets all active connections.
|
||||
/// </summary>
|
||||
public IEnumerable<TcpConnection> GetConnections() => _connections.Values;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of active connections.
|
||||
/// </summary>
|
||||
public int ConnectionCount => _connections.Count;
|
||||
|
||||
private static string GenerateConnectionId(TcpClient client)
|
||||
{
|
||||
var endpoint = client.Client.RemoteEndPoint as IPEndPoint;
|
||||
if (endpoint is not null)
|
||||
{
|
||||
return $"tcp-{endpoint.Address}-{endpoint.Port}-{Guid.NewGuid():N}".Substring(0, 32);
|
||||
}
|
||||
|
||||
return $"tcp-{Guid.NewGuid():N}";
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
_logger.LogInformation("Stopping TCP transport server");
|
||||
|
||||
if (_serverCts is not null)
|
||||
{
|
||||
await _serverCts.CancelAsync();
|
||||
}
|
||||
|
||||
_listener?.Stop();
|
||||
|
||||
if (_acceptTask is not null)
|
||||
{
|
||||
await _acceptTask;
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
foreach (var connection in _connections.Values)
|
||||
{
|
||||
connection.Close();
|
||||
await connection.DisposeAsync();
|
||||
}
|
||||
|
||||
_connections.Clear();
|
||||
|
||||
_logger.LogInformation("TCP transport server stopped");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await StopAsync(CancellationToken.None);
|
||||
|
||||
_listener?.Dispose();
|
||||
_serverCts?.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// Utility class for loading certificates from various sources.
|
||||
/// </summary>
|
||||
public static class CertificateLoader
|
||||
{
|
||||
/// <summary>
|
||||
/// Loads a server certificate from the options.
|
||||
/// </summary>
|
||||
/// <param name="options">The TLS transport options.</param>
|
||||
/// <returns>The loaded certificate.</returns>
|
||||
/// <exception cref="InvalidOperationException">Thrown when no certificate is configured.</exception>
|
||||
public static X509Certificate2 LoadServerCertificate(TlsTransportOptions options)
|
||||
{
|
||||
// Direct certificate object takes precedence
|
||||
if (options.ServerCertificate is not null)
|
||||
{
|
||||
return options.ServerCertificate;
|
||||
}
|
||||
|
||||
// Load from path
|
||||
if (string.IsNullOrEmpty(options.ServerCertificatePath))
|
||||
{
|
||||
throw new InvalidOperationException("Server certificate is not configured");
|
||||
}
|
||||
|
||||
return LoadCertificateFromPath(
|
||||
options.ServerCertificatePath,
|
||||
options.ServerCertificateKeyPath,
|
||||
options.ServerCertificatePassword);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Loads a client certificate from the options.
|
||||
/// </summary>
|
||||
/// <param name="options">The TLS transport options.</param>
|
||||
/// <returns>The loaded certificate, or null if not configured.</returns>
|
||||
public static X509Certificate2? LoadClientCertificate(TlsTransportOptions options)
|
||||
{
|
||||
// Direct certificate object takes precedence
|
||||
if (options.ClientCertificate is not null)
|
||||
{
|
||||
return options.ClientCertificate;
|
||||
}
|
||||
|
||||
// Load from path
|
||||
if (string.IsNullOrEmpty(options.ClientCertificatePath))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return LoadCertificateFromPath(
|
||||
options.ClientCertificatePath,
|
||||
options.ClientCertificateKeyPath,
|
||||
options.ClientCertificatePassword);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Loads a certificate from a file path.
|
||||
/// </summary>
|
||||
/// <param name="certPath">The certificate path (PEM or PFX).</param>
|
||||
/// <param name="keyPath">The private key path (optional, for PEM).</param>
|
||||
/// <param name="password">The password (optional, for PFX).</param>
|
||||
/// <returns>The loaded certificate.</returns>
|
||||
public static X509Certificate2 LoadCertificateFromPath(
|
||||
string certPath,
|
||||
string? keyPath = null,
|
||||
string? password = null)
|
||||
{
|
||||
var extension = Path.GetExtension(certPath).ToLowerInvariant();
|
||||
|
||||
return extension switch
|
||||
{
|
||||
".pfx" or ".p12" => LoadPfxCertificate(certPath, password),
|
||||
".pem" or ".crt" or ".cer" => LoadPemCertificate(certPath, keyPath),
|
||||
_ => throw new InvalidOperationException($"Unsupported certificate format: {extension}")
|
||||
};
|
||||
}
|
||||
|
||||
private static X509Certificate2 LoadPfxCertificate(string pfxPath, string? password)
|
||||
{
|
||||
return X509CertificateLoader.LoadPkcs12FromFile(
|
||||
pfxPath,
|
||||
password,
|
||||
X509KeyStorageFlags.MachineKeySet | X509KeyStorageFlags.PersistKeySet);
|
||||
}
|
||||
|
||||
private static X509Certificate2 LoadPemCertificate(string certPath, string? keyPath)
|
||||
{
|
||||
var certPem = File.ReadAllText(certPath);
|
||||
|
||||
if (string.IsNullOrEmpty(keyPath))
|
||||
{
|
||||
// Assume the key is in the same file
|
||||
return X509Certificate2.CreateFromPem(certPem);
|
||||
}
|
||||
|
||||
var keyPem = File.ReadAllText(keyPath);
|
||||
return X509Certificate2.CreateFromPem(certPem, keyPem);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// Watches certificate files for changes and triggers hot-reload.
|
||||
/// </summary>
|
||||
public sealed class CertificateWatcher : IDisposable
|
||||
{
|
||||
private readonly TlsTransportOptions _options;
|
||||
private readonly ILogger _logger;
|
||||
private readonly List<FileSystemWatcher> _watchers = new();
|
||||
private volatile X509Certificate2? _currentServerCert;
|
||||
private volatile X509Certificate2? _currentClientCert;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when the server certificate is reloaded.
|
||||
/// </summary>
|
||||
public event Action<X509Certificate2>? OnServerCertificateReloaded;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when the client certificate is reloaded.
|
||||
/// </summary>
|
||||
public event Action<X509Certificate2>? OnClientCertificateReloaded;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current server certificate.
|
||||
/// </summary>
|
||||
public X509Certificate2? ServerCertificate => _currentServerCert;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current client certificate.
|
||||
/// </summary>
|
||||
public X509Certificate2? ClientCertificate => _currentClientCert;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CertificateWatcher"/> class.
|
||||
/// </summary>
|
||||
public CertificateWatcher(TlsTransportOptions options, ILogger logger)
|
||||
{
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
|
||||
// Load initial certificates
|
||||
LoadCertificates();
|
||||
|
||||
// Set up file system watchers if hot-reload is enabled
|
||||
if (_options.EnableCertificateHotReload)
|
||||
{
|
||||
SetupWatchers();
|
||||
}
|
||||
}
|
||||
|
||||
private void LoadCertificates()
|
||||
{
|
||||
try
|
||||
{
|
||||
if (!string.IsNullOrEmpty(_options.ServerCertificatePath) ||
|
||||
_options.ServerCertificate is not null)
|
||||
{
|
||||
_currentServerCert = CertificateLoader.LoadServerCertificate(_options);
|
||||
_logger.LogInformation(
|
||||
"Loaded server certificate: {Subject}, Expires: {Expiry}",
|
||||
_currentServerCert.Subject,
|
||||
_currentServerCert.NotAfter);
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to load server certificate");
|
||||
throw;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_currentClientCert = CertificateLoader.LoadClientCertificate(_options);
|
||||
if (_currentClientCert is not null)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Loaded client certificate: {Subject}, Expires: {Expiry}",
|
||||
_currentClientCert.Subject,
|
||||
_currentClientCert.NotAfter);
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to load client certificate");
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private void SetupWatchers()
|
||||
{
|
||||
if (!string.IsNullOrEmpty(_options.ServerCertificatePath))
|
||||
{
|
||||
var watcher = CreateWatcher(_options.ServerCertificatePath, ReloadServerCertificate);
|
||||
if (watcher is not null) _watchers.Add(watcher);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(_options.ServerCertificateKeyPath))
|
||||
{
|
||||
var watcher = CreateWatcher(_options.ServerCertificateKeyPath, ReloadServerCertificate);
|
||||
if (watcher is not null) _watchers.Add(watcher);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(_options.ClientCertificatePath))
|
||||
{
|
||||
var watcher = CreateWatcher(_options.ClientCertificatePath, ReloadClientCertificate);
|
||||
if (watcher is not null) _watchers.Add(watcher);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(_options.ClientCertificateKeyPath))
|
||||
{
|
||||
var watcher = CreateWatcher(_options.ClientCertificateKeyPath, ReloadClientCertificate);
|
||||
if (watcher is not null) _watchers.Add(watcher);
|
||||
}
|
||||
}
|
||||
|
||||
private FileSystemWatcher? CreateWatcher(string filePath, Action reloadAction)
|
||||
{
|
||||
var directory = Path.GetDirectoryName(filePath);
|
||||
if (string.IsNullOrEmpty(directory) || !Directory.Exists(directory))
|
||||
{
|
||||
_logger.LogWarning("Cannot watch certificate path: directory not found for {Path}", filePath);
|
||||
return null;
|
||||
}
|
||||
|
||||
var fileName = Path.GetFileName(filePath);
|
||||
var watcher = new FileSystemWatcher(directory, fileName)
|
||||
{
|
||||
NotifyFilter = NotifyFilters.LastWrite | NotifyFilters.CreationTime
|
||||
};
|
||||
|
||||
// Debounce file changes to avoid multiple reloads
|
||||
DateTime lastReload = DateTime.MinValue;
|
||||
watcher.Changed += (sender, args) =>
|
||||
{
|
||||
if (DateTime.UtcNow - lastReload < TimeSpan.FromSeconds(5))
|
||||
return;
|
||||
|
||||
lastReload = DateTime.UtcNow;
|
||||
_logger.LogInformation("Certificate file changed: {Path}", filePath);
|
||||
|
||||
// Delay slightly to ensure file write is complete
|
||||
Task.Delay(500).ContinueWith(_ => reloadAction());
|
||||
};
|
||||
|
||||
watcher.EnableRaisingEvents = true;
|
||||
_logger.LogInformation("Watching certificate file: {Path}", filePath);
|
||||
|
||||
return watcher;
|
||||
}
|
||||
|
||||
private void ReloadServerCertificate()
|
||||
{
|
||||
try
|
||||
{
|
||||
var oldCert = _currentServerCert;
|
||||
_currentServerCert = CertificateLoader.LoadServerCertificate(_options);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Reloaded server certificate: {Subject}, Expires: {Expiry}",
|
||||
_currentServerCert.Subject,
|
||||
_currentServerCert.NotAfter);
|
||||
|
||||
OnServerCertificateReloaded?.Invoke(_currentServerCert);
|
||||
|
||||
// Dispose old certificate
|
||||
oldCert?.Dispose();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to reload server certificate");
|
||||
}
|
||||
}
|
||||
|
||||
private void ReloadClientCertificate()
|
||||
{
|
||||
try
|
||||
{
|
||||
var oldCert = _currentClientCert;
|
||||
_currentClientCert = CertificateLoader.LoadClientCertificate(_options);
|
||||
|
||||
if (_currentClientCert is not null)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Reloaded client certificate: {Subject}, Expires: {Expiry}",
|
||||
_currentClientCert.Subject,
|
||||
_currentClientCert.NotAfter);
|
||||
|
||||
OnClientCertificateReloaded?.Invoke(_currentClientCert);
|
||||
}
|
||||
|
||||
// Dispose old certificate
|
||||
oldCert?.Dispose();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Failed to reload client certificate");
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
foreach (var watcher in _watchers)
|
||||
{
|
||||
watcher.EnableRaisingEvents = false;
|
||||
watcher.Dispose();
|
||||
}
|
||||
|
||||
_watchers.Clear();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering TLS transport services.
|
||||
/// </summary>
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds the TLS transport server to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddTlsTransportServer(
|
||||
this IServiceCollection services,
|
||||
Action<TlsTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<TlsTransportServer>();
|
||||
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<TlsTransportServer>());
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds the TLS transport client to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddTlsTransportClient(
|
||||
this IServiceCollection services,
|
||||
Action<TlsTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<TlsTransportClient>();
|
||||
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<TlsTransportClient>());
|
||||
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<TlsTransportClient>());
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
<ProjectReference Include="..\StellaOps.Router.Transport.Tcp\StellaOps.Router.Transport.Tcp.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
220
src/__Libraries/StellaOps.Router.Transport.Tls/TlsConnection.cs
Normal file
220
src/__Libraries/StellaOps.Router.Transport.Tls/TlsConnection.cs
Normal file
@@ -0,0 +1,220 @@
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.Tcp;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// Represents a TLS-secured connection to a microservice.
|
||||
/// </summary>
|
||||
public sealed class TlsConnection : IAsyncDisposable
|
||||
{
|
||||
private readonly TcpClient _client;
|
||||
private readonly SslStream _sslStream;
|
||||
private readonly SemaphoreSlim _writeLock = new(1, 1);
|
||||
private readonly TlsTransportOptions _options;
|
||||
private readonly ILogger _logger;
|
||||
private readonly CancellationTokenSource _connectionCts = new();
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the connection ID.
|
||||
/// </summary>
|
||||
public string ConnectionId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the remote endpoint as a string.
|
||||
/// </summary>
|
||||
public string RemoteEndpoint { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the connection is active.
|
||||
/// </summary>
|
||||
public bool IsConnected => _client.Connected && !_disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the connection state.
|
||||
/// </summary>
|
||||
public ConnectionState? State { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the cancellation token for this connection.
|
||||
/// </summary>
|
||||
public CancellationToken ConnectionToken => _connectionCts.Token;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the remote certificate (if mTLS).
|
||||
/// </summary>
|
||||
public X509Certificate? RemoteCertificate => _sslStream.RemoteCertificate;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the peer identity extracted from the certificate.
|
||||
/// </summary>
|
||||
public string? PeerIdentity { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a frame is received.
|
||||
/// </summary>
|
||||
public event Action<TlsConnection, Frame>? OnFrameReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when the connection is closed.
|
||||
/// </summary>
|
||||
public event Action<TlsConnection, Exception?>? OnDisconnected;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TlsConnection"/> class.
|
||||
/// </summary>
|
||||
public TlsConnection(
|
||||
string connectionId,
|
||||
TcpClient client,
|
||||
SslStream sslStream,
|
||||
TlsTransportOptions options,
|
||||
ILogger logger)
|
||||
{
|
||||
ConnectionId = connectionId;
|
||||
_client = client;
|
||||
_sslStream = sslStream;
|
||||
_options = options;
|
||||
_logger = logger;
|
||||
RemoteEndpoint = client.Client.RemoteEndPoint?.ToString() ?? "unknown";
|
||||
|
||||
// Extract peer identity from certificate
|
||||
if (_sslStream.RemoteCertificate is X509Certificate2 cert)
|
||||
{
|
||||
PeerIdentity = ExtractIdentityFromCertificate(cert);
|
||||
}
|
||||
|
||||
// Configure socket options
|
||||
client.ReceiveBufferSize = options.ReceiveBufferSize;
|
||||
client.SendBufferSize = options.SendBufferSize;
|
||||
client.NoDelay = true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extracts identity from a certificate.
|
||||
/// </summary>
|
||||
private static string? ExtractIdentityFromCertificate(X509Certificate2 cert)
|
||||
{
|
||||
// Try to get Common Name (CN)
|
||||
var cn = cert.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
|
||||
if (!string.IsNullOrEmpty(cn))
|
||||
{
|
||||
return cn;
|
||||
}
|
||||
|
||||
// Fallback to subject
|
||||
return cert.Subject;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Starts the read loop to receive frames.
|
||||
/// </summary>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task ReadLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
|
||||
cancellationToken, _connectionCts.Token);
|
||||
|
||||
Exception? disconnectException = null;
|
||||
|
||||
try
|
||||
{
|
||||
while (!linkedCts.Token.IsCancellationRequested)
|
||||
{
|
||||
var frame = await FrameProtocol.ReadFrameAsync(
|
||||
_sslStream,
|
||||
_options.MaxFrameSize,
|
||||
linkedCts.Token);
|
||||
|
||||
if (frame is null)
|
||||
{
|
||||
_logger.LogDebug("TLS connection {ConnectionId} closed by remote", ConnectionId);
|
||||
break;
|
||||
}
|
||||
|
||||
OnFrameReceived?.Invoke(this, frame);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected on shutdown
|
||||
}
|
||||
catch (IOException ex) when (ex.InnerException is SocketException)
|
||||
{
|
||||
disconnectException = ex;
|
||||
_logger.LogDebug(ex, "TLS connection {ConnectionId} socket error", ConnectionId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
disconnectException = ex;
|
||||
_logger.LogWarning(ex, "TLS connection {ConnectionId} read error", ConnectionId);
|
||||
}
|
||||
|
||||
OnDisconnected?.Invoke(this, disconnectException);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Writes a frame to the connection.
|
||||
/// </summary>
|
||||
/// <param name="frame">The frame to write.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
await _writeLock.WaitAsync(cancellationToken);
|
||||
try
|
||||
{
|
||||
await FrameProtocol.WriteFrameAsync(_sslStream, frame, cancellationToken);
|
||||
await _sslStream.FlushAsync(cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_writeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Closes the connection.
|
||||
/// </summary>
|
||||
public void Close()
|
||||
{
|
||||
if (_disposed) return;
|
||||
|
||||
try
|
||||
{
|
||||
_connectionCts.Cancel();
|
||||
_sslStream.Close();
|
||||
_client.Close();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Error closing TLS connection {ConnectionId}", ConnectionId);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
try
|
||||
{
|
||||
await _connectionCts.CancelAsync();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore
|
||||
}
|
||||
|
||||
await _sslStream.DisposeAsync();
|
||||
_client.Dispose();
|
||||
_writeLock.Dispose();
|
||||
_connectionCts.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,578 @@
|
||||
using System.Buffers;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.Tcp;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// TLS transport client implementation for microservices.
|
||||
/// </summary>
|
||||
public sealed class TlsTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
|
||||
{
|
||||
private readonly TlsTransportOptions _options;
|
||||
private readonly ILogger<TlsTransportClient> _logger;
|
||||
private readonly CertificateWatcher _certWatcher;
|
||||
private readonly PendingRequestTracker _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
|
||||
private readonly CancellationTokenSource _clientCts = new();
|
||||
private TcpClient? _client;
|
||||
private SslStream? _sslStream;
|
||||
private readonly SemaphoreSlim _writeLock = new(1, 1);
|
||||
private Task? _receiveTask;
|
||||
private bool _disposed;
|
||||
private string? _connectionId;
|
||||
private int _reconnectAttempts;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a REQUEST frame is received.
|
||||
/// </summary>
|
||||
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a CANCEL frame is received.
|
||||
/// </summary>
|
||||
public event Func<Guid, string?, Task>? OnCancelReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TlsTransportClient"/> class.
|
||||
/// </summary>
|
||||
public TlsTransportClient(
|
||||
IOptions<TlsTransportOptions> options,
|
||||
ILogger<TlsTransportClient> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
_certWatcher = new CertificateWatcher(_options, logger);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the gateway.
|
||||
/// </summary>
|
||||
/// <param name="instance">The instance descriptor.</param>
|
||||
/// <param name="endpoints">The endpoints to register.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task ConnectAsync(
|
||||
InstanceDescriptor instance,
|
||||
IReadOnlyList<EndpointDescriptor> endpoints,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (string.IsNullOrEmpty(_options.Host))
|
||||
{
|
||||
throw new InvalidOperationException("Host is not configured");
|
||||
}
|
||||
|
||||
await ConnectInternalAsync(cancellationToken);
|
||||
|
||||
_connectionId = Guid.NewGuid().ToString("N");
|
||||
|
||||
// Send HELLO frame
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await WriteFrameAsync(helloFrame, cancellationToken);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Connected to TLS gateway at {Host}:{Port} as {ServiceName}/{Version}",
|
||||
_options.Host,
|
||||
_options.Port,
|
||||
instance.ServiceName,
|
||||
instance.Version);
|
||||
|
||||
// Start receiving frames
|
||||
_receiveTask = Task.Run(() => ReceiveLoopAsync(_clientCts.Token), CancellationToken.None);
|
||||
}
|
||||
|
||||
private async Task ConnectInternalAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(_options.ConnectTimeout);
|
||||
|
||||
_client = new TcpClient
|
||||
{
|
||||
ReceiveBufferSize = _options.ReceiveBufferSize,
|
||||
SendBufferSize = _options.SendBufferSize,
|
||||
NoDelay = true
|
||||
};
|
||||
|
||||
await _client.ConnectAsync(_options.Host!, _options.Port, timeoutCts.Token);
|
||||
|
||||
_sslStream = new SslStream(
|
||||
_client.GetStream(),
|
||||
leaveInnerStreamOpen: false,
|
||||
userCertificateValidationCallback: ValidateServerCertificate);
|
||||
|
||||
var clientCerts = _certWatcher.ClientCertificate is not null
|
||||
? new X509CertificateCollection { _certWatcher.ClientCertificate }
|
||||
: null;
|
||||
|
||||
await _sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
|
||||
{
|
||||
TargetHost = _options.ExpectedServerHostname ?? _options.Host,
|
||||
ClientCertificates = clientCerts,
|
||||
EnabledSslProtocols = _options.EnabledProtocols,
|
||||
CertificateRevocationCheckMode = _options.CheckCertificateRevocation
|
||||
? X509RevocationMode.Online
|
||||
: X509RevocationMode.NoCheck
|
||||
}, timeoutCts.Token);
|
||||
|
||||
_logger.LogInformation(
|
||||
"TLS handshake completed: Protocol={Protocol}, CipherSuite={CipherSuite}",
|
||||
_sslStream.SslProtocol,
|
||||
_sslStream.NegotiatedCipherSuite);
|
||||
|
||||
_reconnectAttempts = 0;
|
||||
}
|
||||
|
||||
private bool ValidateServerCertificate(
|
||||
object sender,
|
||||
X509Certificate? certificate,
|
||||
X509Chain? chain,
|
||||
SslPolicyErrors errors)
|
||||
{
|
||||
// Allow self-signed in development
|
||||
if (_options.AllowSelfSigned)
|
||||
{
|
||||
if (errors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors))
|
||||
{
|
||||
// Check if the only error is self-signed
|
||||
if (chain is not null && chain.ChainStatus.All(s =>
|
||||
s.Status == X509ChainStatusFlags.UntrustedRoot ||
|
||||
s.Status == X509ChainStatusFlags.PartialChain))
|
||||
{
|
||||
_logger.LogDebug("Allowing self-signed server certificate");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Allow if no errors or only name mismatch
|
||||
if (errors == SslPolicyErrors.None ||
|
||||
errors == SslPolicyErrors.RemoteCertificateNameMismatch)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (errors != SslPolicyErrors.None)
|
||||
{
|
||||
_logger.LogWarning("Server certificate validation failed: {Errors}", errors);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Hostname verification
|
||||
if (!string.IsNullOrEmpty(_options.ExpectedServerHostname) && certificate is not null)
|
||||
{
|
||||
var cert = new X509Certificate2(certificate);
|
||||
var cn = cert.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
|
||||
|
||||
if (!string.Equals(cn, _options.ExpectedServerHostname, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Server certificate hostname mismatch: expected {Expected}, got {Actual}",
|
||||
_options.ExpectedServerHostname,
|
||||
cn);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private async Task ReconnectAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
|
||||
while (_reconnectAttempts < _options.MaxReconnectAttempts && !_clientCts.Token.IsCancellationRequested)
|
||||
{
|
||||
_reconnectAttempts++;
|
||||
var backoff = TimeSpan.FromMilliseconds(
|
||||
Math.Min(
|
||||
Math.Pow(2, _reconnectAttempts) * 100,
|
||||
_options.MaxReconnectBackoff.TotalMilliseconds));
|
||||
|
||||
_logger.LogInformation(
|
||||
"TLS reconnection attempt {Attempt} of {Max} in {Delay}ms",
|
||||
_reconnectAttempts,
|
||||
_options.MaxReconnectAttempts,
|
||||
backoff.TotalMilliseconds);
|
||||
|
||||
await Task.Delay(backoff, _clientCts.Token);
|
||||
|
||||
try
|
||||
{
|
||||
_sslStream?.Dispose();
|
||||
_client?.Dispose();
|
||||
await ConnectInternalAsync(_clientCts.Token);
|
||||
_logger.LogInformation("Reconnected to TLS gateway");
|
||||
return;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "TLS reconnection attempt {Attempt} failed", _reconnectAttempts);
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogError("Max TLS reconnection attempts reached, giving up");
|
||||
}
|
||||
|
||||
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
var frame = await FrameProtocol.ReadFrameAsync(
|
||||
_sslStream!,
|
||||
_options.MaxFrameSize,
|
||||
cancellationToken);
|
||||
|
||||
if (frame is null)
|
||||
{
|
||||
_logger.LogDebug("TLS connection closed by server");
|
||||
await ReconnectAsync();
|
||||
continue;
|
||||
}
|
||||
|
||||
await ProcessFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (IOException ex) when (ex.InnerException is SocketException)
|
||||
{
|
||||
_logger.LogDebug(ex, "TLS socket error, attempting reconnection");
|
||||
await ReconnectAsync();
|
||||
}
|
||||
catch (AuthenticationException ex)
|
||||
{
|
||||
_logger.LogError(ex, "TLS authentication error during receive");
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error in TLS receive loop");
|
||||
await Task.Delay(1000, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ProcessFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
switch (frame.Type)
|
||||
{
|
||||
case FrameType.Request:
|
||||
case FrameType.RequestStreamData:
|
||||
await HandleRequestFrameAsync(frame, cancellationToken);
|
||||
break;
|
||||
|
||||
case FrameType.Cancel:
|
||||
HandleCancelFrame(frame);
|
||||
break;
|
||||
|
||||
case FrameType.Response:
|
||||
case FrameType.ResponseStreamData:
|
||||
if (frame.CorrelationId is not null &&
|
||||
Guid.TryParse(frame.CorrelationId, out var correlationId))
|
||||
{
|
||||
_pendingRequests.CompleteRequest(correlationId, frame);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
if (OnRequestReceived is null)
|
||||
{
|
||||
_logger.LogWarning("No request handler registered");
|
||||
return;
|
||||
}
|
||||
|
||||
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
|
||||
|
||||
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
_inflightHandlers[correlationId] = handlerCts;
|
||||
|
||||
try
|
||||
{
|
||||
var response = await OnRequestReceived(frame, handlerCts.Token);
|
||||
var responseFrame = response with { CorrelationId = correlationId };
|
||||
|
||||
if (!handlerCts.Token.IsCancellationRequested)
|
||||
{
|
||||
await WriteFrameAsync(responseFrame, cancellationToken);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_inflightHandlers.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
private void HandleCancelFrame(Frame frame)
|
||||
{
|
||||
if (frame.CorrelationId is null) return;
|
||||
|
||||
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
|
||||
|
||||
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (Guid.TryParse(frame.CorrelationId, out var guid))
|
||||
{
|
||||
_pendingRequests.CancelRequest(guid);
|
||||
OnCancelReceived?.Invoke(guid, null);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
await _writeLock.WaitAsync(cancellationToken);
|
||||
try
|
||||
{
|
||||
await FrameProtocol.WriteFrameAsync(_sslStream!, frame, cancellationToken);
|
||||
await _sslStream!.FlushAsync(cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_writeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<Frame> SendRequestAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestFrame,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var correlationId = requestFrame.CorrelationId is not null &&
|
||||
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
|
||||
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(timeout);
|
||||
|
||||
var responseTask = _pendingRequests.TrackRequest(correlationId, timeoutCts.Token);
|
||||
|
||||
await WriteFrameAsync(framedRequest, timeoutCts.Token);
|
||||
|
||||
try
|
||||
{
|
||||
return await responseTask;
|
||||
}
|
||||
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendCancelAsync(
|
||||
ConnectionState connection,
|
||||
Guid correlationId,
|
||||
string? reason = null)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var cancelFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Cancel,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await WriteFrameAsync(cancelFrame, CancellationToken.None);
|
||||
_logger.LogDebug("Sent CANCEL for {CorrelationId}", correlationId);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendStreamingAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestHeader,
|
||||
Stream requestBody,
|
||||
Func<Stream, Task> readResponseBody,
|
||||
PayloadLimits limits,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var correlationId = requestHeader.CorrelationId is not null &&
|
||||
Guid.TryParse(requestHeader.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var headerFrame = requestHeader with
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = correlationId.ToString("N")
|
||||
};
|
||||
await WriteFrameAsync(headerFrame, cancellationToken);
|
||||
|
||||
// Stream request body
|
||||
var buffer = ArrayPool<byte>.Shared.Rent(8192);
|
||||
try
|
||||
{
|
||||
long totalBytesRead = 0;
|
||||
int bytesRead;
|
||||
|
||||
while ((bytesRead = await requestBody.ReadAsync(buffer, cancellationToken)) > 0)
|
||||
{
|
||||
totalBytesRead += bytesRead;
|
||||
|
||||
if (totalBytesRead > limits.MaxRequestBytesPerCall)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Request body exceeds limit of {limits.MaxRequestBytesPerCall} bytes");
|
||||
}
|
||||
|
||||
var dataFrame = new Frame
|
||||
{
|
||||
Type = FrameType.RequestStreamData,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = new ReadOnlyMemory<byte>(buffer, 0, bytesRead)
|
||||
};
|
||||
await WriteFrameAsync(dataFrame, cancellationToken);
|
||||
}
|
||||
|
||||
// End of stream marker
|
||||
var endFrame = new Frame
|
||||
{
|
||||
Type = FrameType.RequestStreamData,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await WriteFrameAsync(endFrame, cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(buffer);
|
||||
}
|
||||
|
||||
// Read streaming response
|
||||
using var responseStream = new MemoryStream();
|
||||
await readResponseBody(responseStream);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a heartbeat.
|
||||
/// </summary>
|
||||
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
|
||||
{
|
||||
var frame = new Frame
|
||||
{
|
||||
Type = FrameType.Heartbeat,
|
||||
CorrelationId = null,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await WriteFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all in-flight handlers.
|
||||
/// </summary>
|
||||
public void CancelAllInflight(string reason)
|
||||
{
|
||||
var count = 0;
|
||||
foreach (var cts in _inflightHandlers.Values)
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
count++;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the gateway.
|
||||
/// </summary>
|
||||
public async Task DisconnectAsync()
|
||||
{
|
||||
CancelAllInflight("Shutdown");
|
||||
|
||||
await _clientCts.CancelAsync();
|
||||
|
||||
if (_receiveTask is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _receiveTask;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
|
||||
_sslStream?.Dispose();
|
||||
_client?.Dispose();
|
||||
_logger.LogInformation("Disconnected from TLS gateway");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await DisconnectAsync();
|
||||
|
||||
_certWatcher.Dispose();
|
||||
_pendingRequests.Dispose();
|
||||
_writeLock.Dispose();
|
||||
_clientCts.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
using System.Net;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// Options for TLS transport configuration.
|
||||
/// </summary>
|
||||
public sealed class TlsTransportOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the bind address for the server.
|
||||
/// </summary>
|
||||
public IPAddress BindAddress { get; set; } = IPAddress.Any;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the port to listen on.
|
||||
/// </summary>
|
||||
public int Port { get; set; } = 5101;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the host to connect to (client only).
|
||||
/// </summary>
|
||||
public string? Host { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the receive buffer size.
|
||||
/// </summary>
|
||||
public int ReceiveBufferSize { get; set; } = 64 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the send buffer size.
|
||||
/// </summary>
|
||||
public int SendBufferSize { get; set; } = 64 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the keep-alive interval.
|
||||
/// </summary>
|
||||
public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromSeconds(30);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the connection timeout.
|
||||
/// </summary>
|
||||
public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(10);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum reconnection attempts.
|
||||
/// </summary>
|
||||
public int MaxReconnectAttempts { get; set; } = 10;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum reconnection backoff.
|
||||
/// </summary>
|
||||
public TimeSpan MaxReconnectBackoff { get; set; } = TimeSpan.FromMinutes(1);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum frame size.
|
||||
/// </summary>
|
||||
public int MaxFrameSize { get; set; } = 16 * 1024 * 1024;
|
||||
|
||||
// Server-side certificate (Gateway)
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the server certificate object.
|
||||
/// </summary>
|
||||
public X509Certificate2? ServerCertificate { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the server certificate path (PEM or PFX).
|
||||
/// </summary>
|
||||
public string? ServerCertificatePath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the server certificate key path (PEM private key).
|
||||
/// </summary>
|
||||
public string? ServerCertificateKeyPath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the server certificate password (for PFX).
|
||||
/// </summary>
|
||||
public string? ServerCertificatePassword { get; set; }
|
||||
|
||||
// Client-side certificate (Microservice)
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the client certificate object.
|
||||
/// </summary>
|
||||
public X509Certificate2? ClientCertificate { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the client certificate path (PEM or PFX).
|
||||
/// </summary>
|
||||
public string? ClientCertificatePath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the client certificate key path (PEM private key).
|
||||
/// </summary>
|
||||
public string? ClientCertificateKeyPath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the client certificate password (for PFX).
|
||||
/// </summary>
|
||||
public string? ClientCertificatePassword { get; set; }
|
||||
|
||||
// Validation options
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to require client certificates (mTLS).
|
||||
/// </summary>
|
||||
public bool RequireClientCertificate { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to allow self-signed certificates (dev only).
|
||||
/// </summary>
|
||||
public bool AllowSelfSigned { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to check certificate revocation.
|
||||
/// </summary>
|
||||
public bool CheckCertificateRevocation { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the expected server hostname (for SNI).
|
||||
/// </summary>
|
||||
public string? ExpectedServerHostname { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the enabled SSL/TLS protocols.
|
||||
/// </summary>
|
||||
public SslProtocols EnabledProtocols { get; set; } = SslProtocols.Tls12 | SslProtocols.Tls13;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to enable certificate hot-reload.
|
||||
/// </summary>
|
||||
public bool EnableCertificateHotReload { get; set; } = false;
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls;
|
||||
|
||||
/// <summary>
|
||||
/// TLS transport server implementation for the gateway.
|
||||
/// </summary>
|
||||
public sealed class TlsTransportServer : ITransportServer, IAsyncDisposable
|
||||
{
|
||||
private readonly TlsTransportOptions _options;
|
||||
private readonly ILogger<TlsTransportServer> _logger;
|
||||
private readonly ConcurrentDictionary<string, TlsConnection> _connections = new();
|
||||
private readonly CertificateWatcher _certWatcher;
|
||||
private TcpListener? _listener;
|
||||
private CancellationTokenSource? _serverCts;
|
||||
private Task? _acceptTask;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is established.
|
||||
/// </summary>
|
||||
public event Action<string, ConnectionState>? OnConnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is lost.
|
||||
/// </summary>
|
||||
public event Action<string>? OnDisconnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a frame is received.
|
||||
/// </summary>
|
||||
public event Action<string, Frame>? OnFrame;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TlsTransportServer"/> class.
|
||||
/// </summary>
|
||||
public TlsTransportServer(
|
||||
IOptions<TlsTransportOptions> options,
|
||||
ILogger<TlsTransportServer> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
_certWatcher = new CertificateWatcher(_options, logger);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (_certWatcher.ServerCertificate is null)
|
||||
{
|
||||
throw new InvalidOperationException("Server certificate is not configured");
|
||||
}
|
||||
|
||||
_serverCts = new CancellationTokenSource();
|
||||
_listener = new TcpListener(_options.BindAddress, _options.Port);
|
||||
_listener.Start();
|
||||
|
||||
_logger.LogInformation(
|
||||
"TLS transport server listening on {Address}:{Port}",
|
||||
_options.BindAddress,
|
||||
_options.Port);
|
||||
|
||||
_acceptTask = AcceptLoopAsync(_serverCts.Token);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task AcceptLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
TcpClient? client = null;
|
||||
SslStream? sslStream = null;
|
||||
|
||||
try
|
||||
{
|
||||
client = await _listener!.AcceptTcpClientAsync(cancellationToken);
|
||||
|
||||
_logger.LogDebug(
|
||||
"Accepting TLS connection from {RemoteEndpoint}",
|
||||
client.Client.RemoteEndPoint);
|
||||
|
||||
sslStream = new SslStream(
|
||||
client.GetStream(),
|
||||
leaveInnerStreamOpen: false,
|
||||
userCertificateValidationCallback: ValidateClientCertificate);
|
||||
|
||||
await sslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = _certWatcher.ServerCertificate,
|
||||
ClientCertificateRequired = _options.RequireClientCertificate,
|
||||
EnabledSslProtocols = _options.EnabledProtocols,
|
||||
CertificateRevocationCheckMode = _options.CheckCertificateRevocation
|
||||
? X509RevocationMode.Online
|
||||
: X509RevocationMode.NoCheck
|
||||
}, cancellationToken);
|
||||
|
||||
var connectionId = GenerateConnectionId(client, sslStream.RemoteCertificate);
|
||||
|
||||
_logger.LogInformation(
|
||||
"TLS connection established: {ConnectionId} from {RemoteEndpoint}, Protocol: {Protocol}, CipherSuite: {CipherSuite}",
|
||||
connectionId,
|
||||
client.Client.RemoteEndPoint,
|
||||
sslStream.SslProtocol,
|
||||
sslStream.NegotiatedCipherSuite);
|
||||
|
||||
var connection = new TlsConnection(connectionId, client, sslStream, _options, _logger);
|
||||
_connections[connectionId] = connection;
|
||||
|
||||
connection.OnFrameReceived += HandleFrame;
|
||||
connection.OnDisconnected += HandleDisconnect;
|
||||
|
||||
// Start read loop (non-blocking)
|
||||
_ = Task.Run(() => connection.ReadLoopAsync(cancellationToken), CancellationToken.None);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected on shutdown
|
||||
break;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Listener disposed
|
||||
break;
|
||||
}
|
||||
catch (AuthenticationException ex)
|
||||
{
|
||||
_logger.LogWarning(ex,
|
||||
"TLS handshake failed from {RemoteEndpoint}",
|
||||
client?.Client?.RemoteEndPoint);
|
||||
|
||||
sslStream?.Dispose();
|
||||
client?.Dispose();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error accepting TLS connection");
|
||||
|
||||
sslStream?.Dispose();
|
||||
client?.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private bool ValidateClientCertificate(
|
||||
object sender,
|
||||
X509Certificate? certificate,
|
||||
X509Chain? chain,
|
||||
SslPolicyErrors errors)
|
||||
{
|
||||
// If we don't require client certs and none provided, allow
|
||||
if (!_options.RequireClientCertificate && certificate is null)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// If client cert is required but not provided, reject
|
||||
if (_options.RequireClientCertificate && certificate is null)
|
||||
{
|
||||
_logger.LogWarning("Client certificate required but not provided");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Allow self-signed in development
|
||||
if (_options.AllowSelfSigned)
|
||||
{
|
||||
if (errors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors))
|
||||
{
|
||||
// Check if the only error is self-signed
|
||||
if (chain is not null && chain.ChainStatus.All(s =>
|
||||
s.Status == X509ChainStatusFlags.UntrustedRoot ||
|
||||
s.Status == X509ChainStatusFlags.PartialChain))
|
||||
{
|
||||
_logger.LogDebug("Allowing self-signed client certificate");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Allow if no errors or only name mismatch
|
||||
if (errors == SslPolicyErrors.None ||
|
||||
errors == SslPolicyErrors.RemoteCertificateNameMismatch)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (errors != SslPolicyErrors.None)
|
||||
{
|
||||
_logger.LogWarning("Client certificate validation failed: {Errors}", errors);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private void HandleFrame(TlsConnection connection, Frame frame)
|
||||
{
|
||||
// If this is a HELLO frame, create the ConnectionState
|
||||
if (frame.Type == FrameType.Hello && connection.State is null)
|
||||
{
|
||||
var state = new ConnectionState
|
||||
{
|
||||
ConnectionId = connection.ConnectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = connection.ConnectionId,
|
||||
ServiceName = connection.PeerIdentity ?? "unknown",
|
||||
Version = "1.0.0",
|
||||
Region = "default"
|
||||
},
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = TransportType.Certificate
|
||||
};
|
||||
connection.State = state;
|
||||
OnConnection?.Invoke(connection.ConnectionId, state);
|
||||
}
|
||||
|
||||
OnFrame?.Invoke(connection.ConnectionId, frame);
|
||||
}
|
||||
|
||||
private void HandleDisconnect(TlsConnection connection, Exception? ex)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"TLS connection {ConnectionId} disconnected{Reason}",
|
||||
connection.ConnectionId,
|
||||
ex is not null ? $": {ex.Message}" : string.Empty);
|
||||
|
||||
_connections.TryRemove(connection.ConnectionId, out _);
|
||||
OnDisconnection?.Invoke(connection.ConnectionId);
|
||||
|
||||
// Clean up connection
|
||||
_ = connection.DisposeAsync();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a frame to a connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <param name="frame">The frame to send.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task SendFrameAsync(
|
||||
string connectionId,
|
||||
Frame frame,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_connections.TryGetValue(connectionId, out var connection))
|
||||
{
|
||||
await connection.WriteFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Connection {connectionId} not found");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets a connection by ID.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <returns>The connection, or null if not found.</returns>
|
||||
public TlsConnection? GetConnection(string connectionId)
|
||||
{
|
||||
return _connections.TryGetValue(connectionId, out var conn) ? conn : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets all active connections.
|
||||
/// </summary>
|
||||
public IEnumerable<TlsConnection> GetConnections() => _connections.Values;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of active connections.
|
||||
/// </summary>
|
||||
public int ConnectionCount => _connections.Count;
|
||||
|
||||
private static string GenerateConnectionId(TcpClient client, X509Certificate? remoteCert)
|
||||
{
|
||||
var endpoint = client.Client.RemoteEndPoint as IPEndPoint;
|
||||
var certId = remoteCert?.GetSerialNumberString() ?? "nocert";
|
||||
|
||||
if (endpoint is not null)
|
||||
{
|
||||
return $"tls-{endpoint.Address}-{endpoint.Port}-{certId}".Substring(0, Math.Min(48, 16 + certId.Length));
|
||||
}
|
||||
|
||||
return $"tls-{Guid.NewGuid():N}";
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
_logger.LogInformation("Stopping TLS transport server");
|
||||
|
||||
if (_serverCts is not null)
|
||||
{
|
||||
await _serverCts.CancelAsync();
|
||||
}
|
||||
|
||||
_listener?.Stop();
|
||||
|
||||
if (_acceptTask is not null)
|
||||
{
|
||||
await _acceptTask;
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
foreach (var connection in _connections.Values)
|
||||
{
|
||||
connection.Close();
|
||||
await connection.DisposeAsync();
|
||||
}
|
||||
|
||||
_connections.Clear();
|
||||
|
||||
_logger.LogInformation("TLS transport server stopped");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await StopAsync(CancellationToken.None);
|
||||
|
||||
_certWatcher.Dispose();
|
||||
_listener?.Dispose();
|
||||
_serverCts?.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
namespace StellaOps.Router.Transport.Udp;
|
||||
|
||||
/// <summary>
|
||||
/// Exception thrown when a payload exceeds the maximum datagram size.
|
||||
/// </summary>
|
||||
public sealed class PayloadTooLargeException : Exception
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the actual size of the payload.
|
||||
/// </summary>
|
||||
public int ActualSize { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the maximum allowed size.
|
||||
/// </summary>
|
||||
public int MaxSize { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PayloadTooLargeException"/> class.
|
||||
/// </summary>
|
||||
public PayloadTooLargeException(int actualSize, int maxSize)
|
||||
: base($"Payload size {actualSize} exceeds maximum datagram size of {maxSize} bytes")
|
||||
{
|
||||
ActualSize = actualSize;
|
||||
MaxSize = maxSize;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
|
||||
namespace StellaOps.Router.Transport.Udp;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering UDP transport services.
|
||||
/// </summary>
|
||||
public static class ServiceCollectionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds UDP transport server services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Optional configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddUdpTransportServer(
|
||||
this IServiceCollection services,
|
||||
Action<UdpTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<UdpTransportServer>();
|
||||
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<UdpTransportServer>());
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds UDP transport client services to the service collection.
|
||||
/// </summary>
|
||||
/// <param name="services">The service collection.</param>
|
||||
/// <param name="configure">Optional configuration action.</param>
|
||||
/// <returns>The service collection.</returns>
|
||||
public static IServiceCollection AddUdpTransportClient(
|
||||
this IServiceCollection services,
|
||||
Action<UdpTransportOptions>? configure = null)
|
||||
{
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<UdpTransportClient>();
|
||||
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<UdpTransportClient>());
|
||||
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<UdpTransportClient>());
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<RootNamespace>StellaOps.Router.Transport.Udp</RootNamespace>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
@@ -0,0 +1,79 @@
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Udp;
|
||||
|
||||
/// <summary>
|
||||
/// Handles serialization and deserialization of frames for UDP transport.
|
||||
/// Frame format: [1-byte frame type][16-byte correlation GUID][remaining data]
|
||||
/// </summary>
|
||||
public static class UdpFrameProtocol
|
||||
{
|
||||
private const int FrameTypeSize = 1;
|
||||
private const int CorrelationIdSize = 16;
|
||||
private const int HeaderSize = FrameTypeSize + CorrelationIdSize;
|
||||
|
||||
/// <summary>
|
||||
/// Parses a frame from a datagram.
|
||||
/// </summary>
|
||||
/// <param name="data">The datagram data.</param>
|
||||
/// <returns>The parsed frame.</returns>
|
||||
/// <exception cref="InvalidOperationException">Thrown when the datagram is too small.</exception>
|
||||
public static Frame ParseFrame(ReadOnlySpan<byte> data)
|
||||
{
|
||||
if (data.Length < HeaderSize)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Datagram too small: {data.Length} bytes, minimum is {HeaderSize}");
|
||||
}
|
||||
|
||||
var frameType = (FrameType)data[0];
|
||||
var correlationId = new Guid(data.Slice(FrameTypeSize, CorrelationIdSize));
|
||||
var payload = data.Length > HeaderSize
|
||||
? data[HeaderSize..].ToArray()
|
||||
: Array.Empty<byte>();
|
||||
|
||||
return new Frame
|
||||
{
|
||||
Type = frameType,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = payload
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Serializes a frame to a datagram.
|
||||
/// </summary>
|
||||
/// <param name="frame">The frame to serialize.</param>
|
||||
/// <returns>The serialized datagram bytes.</returns>
|
||||
public static byte[] SerializeFrame(Frame frame)
|
||||
{
|
||||
// Parse or generate correlation ID
|
||||
var correlationGuid = frame.CorrelationId is not null &&
|
||||
Guid.TryParse(frame.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var payloadLength = frame.Payload.Length;
|
||||
var buffer = new byte[HeaderSize + payloadLength];
|
||||
|
||||
// Write frame type
|
||||
buffer[0] = (byte)frame.Type;
|
||||
|
||||
// Write correlation ID
|
||||
correlationGuid.TryWriteBytes(buffer.AsSpan(FrameTypeSize, CorrelationIdSize));
|
||||
|
||||
// Write payload
|
||||
if (payloadLength > 0)
|
||||
{
|
||||
frame.Payload.Span.CopyTo(buffer.AsSpan(HeaderSize));
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the header size for UDP frames.
|
||||
/// </summary>
|
||||
public static int GetHeaderSize() => HeaderSize;
|
||||
}
|
||||
@@ -0,0 +1,412 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Udp;
|
||||
|
||||
/// <summary>
|
||||
/// UDP transport client implementation for microservices.
|
||||
/// UDP transport does not support streaming.
|
||||
/// </summary>
|
||||
public sealed class UdpTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
|
||||
{
|
||||
private readonly UdpTransportOptions _options;
|
||||
private readonly ILogger<UdpTransportClient> _logger;
|
||||
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<Frame>> _pendingRequests = new();
|
||||
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
|
||||
private readonly CancellationTokenSource _clientCts = new();
|
||||
private UdpClient? _client;
|
||||
private Task? _receiveTask;
|
||||
private bool _disposed;
|
||||
private string? _connectionId;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a REQUEST frame is received.
|
||||
/// </summary>
|
||||
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a CANCEL frame is received.
|
||||
/// </summary>
|
||||
public event Func<Guid, string?, Task>? OnCancelReceived;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="UdpTransportClient"/> class.
|
||||
/// </summary>
|
||||
public UdpTransportClient(
|
||||
IOptions<UdpTransportOptions> options,
|
||||
ILogger<UdpTransportClient> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the gateway.
|
||||
/// </summary>
|
||||
/// <param name="instance">The instance descriptor.</param>
|
||||
/// <param name="endpoints">The endpoints to register.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task ConnectAsync(
|
||||
InstanceDescriptor instance,
|
||||
IReadOnlyList<EndpointDescriptor> endpoints,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (string.IsNullOrEmpty(_options.Host))
|
||||
{
|
||||
throw new InvalidOperationException("Host is not configured");
|
||||
}
|
||||
|
||||
_client = new UdpClient
|
||||
{
|
||||
EnableBroadcast = _options.AllowBroadcast
|
||||
};
|
||||
_client.Client.ReceiveBufferSize = _options.ReceiveBufferSize;
|
||||
_client.Client.SendBufferSize = _options.SendBufferSize;
|
||||
_client.Connect(_options.Host, _options.Port);
|
||||
|
||||
_connectionId = Guid.NewGuid().ToString("N");
|
||||
|
||||
// Send HELLO frame
|
||||
var helloFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Hello,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
await SendFrameInternalAsync(helloFrame, cancellationToken);
|
||||
|
||||
_logger.LogInformation(
|
||||
"Connected to UDP gateway at {Host}:{Port} as {ServiceName}/{Version}",
|
||||
_options.Host,
|
||||
_options.Port,
|
||||
instance.ServiceName,
|
||||
instance.Version);
|
||||
|
||||
// Start receiving frames
|
||||
_receiveTask = Task.Run(() => ReceiveLoopAsync(_clientCts.Token), CancellationToken.None);
|
||||
}
|
||||
|
||||
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
var result = await _client!.ReceiveAsync(cancellationToken);
|
||||
var data = result.Buffer;
|
||||
|
||||
if (data.Length < UdpFrameProtocol.GetHeaderSize())
|
||||
{
|
||||
_logger.LogWarning("Received datagram too small ({Size} bytes)", data.Length);
|
||||
continue;
|
||||
}
|
||||
|
||||
var frame = UdpFrameProtocol.ParseFrame(data);
|
||||
await ProcessFrameAsync(frame, cancellationToken);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (SocketException ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "UDP socket error in receive loop");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error in receive loop");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ProcessFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
switch (frame.Type)
|
||||
{
|
||||
case FrameType.Request:
|
||||
await HandleRequestFrameAsync(frame, cancellationToken);
|
||||
break;
|
||||
|
||||
case FrameType.Cancel:
|
||||
HandleCancelFrame(frame);
|
||||
break;
|
||||
|
||||
case FrameType.Response:
|
||||
if (frame.CorrelationId is not null &&
|
||||
Guid.TryParse(frame.CorrelationId, out var correlationId))
|
||||
{
|
||||
if (_pendingRequests.TryRemove(correlationId, out var tcs))
|
||||
{
|
||||
tcs.TrySetResult(frame);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case FrameType.RequestStreamData:
|
||||
case FrameType.ResponseStreamData:
|
||||
_logger.LogWarning(
|
||||
"UDP transport does not support streaming. Frame type {Type} ignored.",
|
||||
frame.Type);
|
||||
break;
|
||||
|
||||
default:
|
||||
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
if (OnRequestReceived is null)
|
||||
{
|
||||
_logger.LogWarning("No request handler registered");
|
||||
return;
|
||||
}
|
||||
|
||||
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
|
||||
|
||||
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
_inflightHandlers[correlationId] = handlerCts;
|
||||
|
||||
try
|
||||
{
|
||||
var response = await OnRequestReceived(frame, handlerCts.Token);
|
||||
var responseFrame = response with { CorrelationId = correlationId };
|
||||
|
||||
if (!handlerCts.Token.IsCancellationRequested)
|
||||
{
|
||||
await SendFrameInternalAsync(responseFrame, cancellationToken);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_inflightHandlers.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
private void HandleCancelFrame(Frame frame)
|
||||
{
|
||||
if (frame.CorrelationId is null) return;
|
||||
|
||||
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
|
||||
|
||||
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (Guid.TryParse(frame.CorrelationId, out var guid))
|
||||
{
|
||||
if (_pendingRequests.TryRemove(guid, out var tcs))
|
||||
{
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
OnCancelReceived?.Invoke(guid, null);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task SendFrameInternalAsync(Frame frame, CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var data = UdpFrameProtocol.SerializeFrame(frame);
|
||||
|
||||
if (data.Length > _options.MaxDatagramSize)
|
||||
{
|
||||
throw new PayloadTooLargeException(data.Length, _options.MaxDatagramSize);
|
||||
}
|
||||
|
||||
await _client!.SendAsync(data, cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task<Frame> SendRequestAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestFrame,
|
||||
TimeSpan timeout,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var correlationId = requestFrame.CorrelationId is not null &&
|
||||
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
|
||||
? parsed
|
||||
: Guid.NewGuid();
|
||||
|
||||
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
|
||||
|
||||
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeoutCts.CancelAfter(timeout);
|
||||
|
||||
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
var registration = timeoutCts.Token.Register(() =>
|
||||
{
|
||||
if (_pendingRequests.TryRemove(correlationId, out var pendingTcs))
|
||||
{
|
||||
pendingTcs.TrySetCanceled(timeoutCts.Token);
|
||||
}
|
||||
});
|
||||
|
||||
_pendingRequests[correlationId] = tcs;
|
||||
|
||||
try
|
||||
{
|
||||
await SendFrameInternalAsync(framedRequest, timeoutCts.Token);
|
||||
|
||||
return await tcs.Task;
|
||||
}
|
||||
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
|
||||
}
|
||||
finally
|
||||
{
|
||||
await registration.DisposeAsync();
|
||||
_pendingRequests.TryRemove(correlationId, out _);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task SendCancelAsync(
|
||||
ConnectionState connection,
|
||||
Guid correlationId,
|
||||
string? reason = null)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
var cancelFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Cancel,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
// Best effort - UDP may not deliver
|
||||
await SendFrameInternalAsync(cancelFrame, CancellationToken.None);
|
||||
_logger.LogDebug("Sent CANCEL for {CorrelationId} (best effort)", correlationId);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task SendStreamingAsync(
|
||||
ConnectionState connection,
|
||||
Frame requestHeader,
|
||||
Stream requestBody,
|
||||
Func<Stream, Task> readResponseBody,
|
||||
PayloadLimits limits,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
throw new NotSupportedException(
|
||||
"UDP transport does not support streaming. Use TCP or TLS transport.");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a heartbeat.
|
||||
/// </summary>
|
||||
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
|
||||
{
|
||||
var frame = new Frame
|
||||
{
|
||||
Type = FrameType.Heartbeat,
|
||||
CorrelationId = null,
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await SendFrameInternalAsync(frame, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cancels all in-flight handlers.
|
||||
/// </summary>
|
||||
public void CancelAllInflight(string reason)
|
||||
{
|
||||
var count = 0;
|
||||
foreach (var cts in _inflightHandlers.Values)
|
||||
{
|
||||
try
|
||||
{
|
||||
cts.Cancel();
|
||||
count++;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Already completed
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the gateway.
|
||||
/// </summary>
|
||||
public async Task DisconnectAsync()
|
||||
{
|
||||
CancelAllInflight("Shutdown");
|
||||
|
||||
// Cancel all pending requests
|
||||
foreach (var kvp in _pendingRequests)
|
||||
{
|
||||
if (_pendingRequests.TryRemove(kvp.Key, out var tcs))
|
||||
{
|
||||
tcs.TrySetCanceled();
|
||||
}
|
||||
}
|
||||
|
||||
await _clientCts.CancelAsync();
|
||||
|
||||
if (_receiveTask is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _receiveTask;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
|
||||
_client?.Dispose();
|
||||
_logger.LogInformation("Disconnected from UDP gateway");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await DisconnectAsync();
|
||||
|
||||
_clientCts.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
using System.Net;
|
||||
|
||||
namespace StellaOps.Router.Transport.Udp;
|
||||
|
||||
/// <summary>
|
||||
/// Options for UDP transport configuration.
|
||||
/// </summary>
|
||||
public sealed class UdpTransportOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the bind address for the server.
|
||||
/// </summary>
|
||||
public IPAddress BindAddress { get; set; } = IPAddress.Any;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the port to listen on/connect to.
|
||||
/// </summary>
|
||||
public int Port { get; set; } = 5102;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the host to connect to (client only).
|
||||
/// </summary>
|
||||
public string? Host { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum datagram size in bytes.
|
||||
/// Conservative default well under typical MTU of 1500 bytes.
|
||||
/// </summary>
|
||||
public int MaxDatagramSize { get; set; } = 8192;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the default timeout for requests.
|
||||
/// </summary>
|
||||
public TimeSpan DefaultTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets whether to allow broadcast.
|
||||
/// </summary>
|
||||
public bool AllowBroadcast { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the receive buffer size.
|
||||
/// </summary>
|
||||
public int ReceiveBufferSize { get; set; } = 64 * 1024;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the send buffer size.
|
||||
/// </summary>
|
||||
public int SendBufferSize { get; set; } = 64 * 1024;
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Abstractions;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
|
||||
namespace StellaOps.Router.Transport.Udp;
|
||||
|
||||
/// <summary>
|
||||
/// UDP transport server implementation for the gateway.
|
||||
/// UDP transport is stateless - connections are logical based on source endpoint.
|
||||
/// </summary>
|
||||
public sealed class UdpTransportServer : ITransportServer, IAsyncDisposable
|
||||
{
|
||||
private readonly UdpTransportOptions _options;
|
||||
private readonly ILogger<UdpTransportServer> _logger;
|
||||
private readonly ConcurrentDictionary<IPEndPoint, string> _endpointToConnectionId = new();
|
||||
private readonly ConcurrentDictionary<string, (IPEndPoint Endpoint, ConnectionState State)> _connections = new();
|
||||
private UdpClient? _listener;
|
||||
private CancellationTokenSource? _serverCts;
|
||||
private Task? _receiveTask;
|
||||
private bool _disposed;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is established (on first HELLO).
|
||||
/// </summary>
|
||||
public event Action<string, ConnectionState>? OnConnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a connection is lost.
|
||||
/// </summary>
|
||||
public event Action<string>? OnDisconnection;
|
||||
|
||||
/// <summary>
|
||||
/// Event raised when a frame is received.
|
||||
/// </summary>
|
||||
public event Action<string, Frame>? OnFrame;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="UdpTransportServer"/> class.
|
||||
/// </summary>
|
||||
public UdpTransportServer(
|
||||
IOptions<UdpTransportOptions> options,
|
||||
ILogger<UdpTransportServer> logger)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
_serverCts = new CancellationTokenSource();
|
||||
|
||||
var endpoint = new IPEndPoint(_options.BindAddress, _options.Port);
|
||||
_listener = new UdpClient(endpoint)
|
||||
{
|
||||
EnableBroadcast = _options.AllowBroadcast
|
||||
};
|
||||
|
||||
// Configure socket buffers
|
||||
_listener.Client.ReceiveBufferSize = _options.ReceiveBufferSize;
|
||||
_listener.Client.SendBufferSize = _options.SendBufferSize;
|
||||
|
||||
_logger.LogInformation(
|
||||
"UDP transport server listening on {Address}:{Port}",
|
||||
_options.BindAddress,
|
||||
_options.Port);
|
||||
|
||||
_receiveTask = ReceiveLoopAsync(_serverCts.Token);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
var result = await _listener!.ReceiveAsync(cancellationToken);
|
||||
var remoteEndpoint = result.RemoteEndPoint;
|
||||
var data = result.Buffer;
|
||||
|
||||
if (data.Length < UdpFrameProtocol.GetHeaderSize())
|
||||
{
|
||||
_logger.LogWarning(
|
||||
"Received datagram too small ({Size} bytes) from {Endpoint}",
|
||||
data.Length,
|
||||
remoteEndpoint);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse frame
|
||||
var frame = UdpFrameProtocol.ParseFrame(data);
|
||||
|
||||
// Get or create connection ID for this endpoint
|
||||
var connectionId = _endpointToConnectionId.GetOrAdd(
|
||||
remoteEndpoint,
|
||||
_ => $"udp-{remoteEndpoint.Address}-{remoteEndpoint.Port}-{Guid.NewGuid():N}"[..32]);
|
||||
|
||||
// Handle HELLO specially to register connection
|
||||
if (frame.Type == FrameType.Hello && !_connections.ContainsKey(connectionId))
|
||||
{
|
||||
var state = new ConnectionState
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Instance = new InstanceDescriptor
|
||||
{
|
||||
InstanceId = connectionId,
|
||||
ServiceName = "unknown",
|
||||
Version = "1.0.0",
|
||||
Region = "default"
|
||||
},
|
||||
Status = InstanceHealthStatus.Healthy,
|
||||
LastHeartbeatUtc = DateTime.UtcNow,
|
||||
TransportType = TransportType.Udp
|
||||
};
|
||||
|
||||
_connections[connectionId] = (remoteEndpoint, state);
|
||||
_logger.LogInformation(
|
||||
"UDP connection established: {ConnectionId} from {Endpoint}",
|
||||
connectionId,
|
||||
remoteEndpoint);
|
||||
OnConnection?.Invoke(connectionId, state);
|
||||
}
|
||||
|
||||
// Update heartbeat timestamp on HEARTBEAT frames
|
||||
if (frame.Type == FrameType.Heartbeat &&
|
||||
_connections.TryGetValue(connectionId, out var conn))
|
||||
{
|
||||
conn.State.LastHeartbeatUtc = DateTime.UtcNow;
|
||||
}
|
||||
|
||||
OnFrame?.Invoke(connectionId, frame);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected on shutdown
|
||||
break;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
// Listener disposed
|
||||
break;
|
||||
}
|
||||
catch (SocketException ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "UDP socket error");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error receiving UDP datagram");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sends a frame to a connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <param name="frame">The frame to send.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
public async Task SendFrameAsync(
|
||||
string connectionId,
|
||||
Frame frame,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
ObjectDisposedException.ThrowIf(_disposed, this);
|
||||
|
||||
if (!_connections.TryGetValue(connectionId, out var conn))
|
||||
{
|
||||
throw new InvalidOperationException($"Connection {connectionId} not found");
|
||||
}
|
||||
|
||||
var data = UdpFrameProtocol.SerializeFrame(frame);
|
||||
|
||||
if (data.Length > _options.MaxDatagramSize)
|
||||
{
|
||||
throw new PayloadTooLargeException(data.Length, _options.MaxDatagramSize);
|
||||
}
|
||||
|
||||
await _listener!.SendAsync(data, conn.Endpoint, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the connection state by ID.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
/// <returns>The connection state, or null if not found.</returns>
|
||||
public ConnectionState? GetConnectionState(string connectionId)
|
||||
{
|
||||
return _connections.TryGetValue(connectionId, out var conn) ? conn.State : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets all active connections.
|
||||
/// </summary>
|
||||
public IEnumerable<ConnectionState> GetConnections() =>
|
||||
_connections.Values.Select(c => c.State);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of active connections.
|
||||
/// </summary>
|
||||
public int ConnectionCount => _connections.Count;
|
||||
|
||||
/// <summary>
|
||||
/// Removes a connection (for cleanup purposes).
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The connection ID.</param>
|
||||
public void RemoveConnection(string connectionId)
|
||||
{
|
||||
if (_connections.TryRemove(connectionId, out var conn))
|
||||
{
|
||||
_endpointToConnectionId.TryRemove(conn.Endpoint, out _);
|
||||
_logger.LogInformation("UDP connection removed: {ConnectionId}", connectionId);
|
||||
OnDisconnection?.Invoke(connectionId);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
_logger.LogInformation("Stopping UDP transport server");
|
||||
|
||||
if (_serverCts is not null)
|
||||
{
|
||||
await _serverCts.CancelAsync();
|
||||
}
|
||||
|
||||
_listener?.Close();
|
||||
|
||||
if (_receiveTask is not null)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _receiveTask;
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected
|
||||
}
|
||||
}
|
||||
|
||||
_connections.Clear();
|
||||
_endpointToConnectionId.Clear();
|
||||
|
||||
_logger.LogInformation("UDP transport server stopped");
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_disposed) return;
|
||||
_disposed = true;
|
||||
|
||||
await StopAsync(CancellationToken.None);
|
||||
|
||||
_listener?.Dispose();
|
||||
_serverCts?.Dispose();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<UseConcelierTestInfra>false</UseConcelierTestInfra>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.0" />
|
||||
<PackageReference Include="xunit" Version="2.9.2" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
</PackageReference>
|
||||
<PackageReference Include="FluentAssertions" Version="6.12.0" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\StellaOps.Router.Transport.Tcp\StellaOps.Router.Transport.Tcp.csproj" />
|
||||
<ProjectReference Include="..\..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
@@ -0,0 +1,199 @@
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Common.Enums;
|
||||
using StellaOps.Router.Common.Models;
|
||||
using StellaOps.Router.Transport.Tcp;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tcp.Tests;
|
||||
|
||||
public class TcpTransportOptionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void DefaultOptions_HaveCorrectValues()
|
||||
{
|
||||
var options = new TcpTransportOptions();
|
||||
|
||||
Assert.Equal(5100, options.Port);
|
||||
Assert.Equal(64 * 1024, options.ReceiveBufferSize);
|
||||
Assert.Equal(64 * 1024, options.SendBufferSize);
|
||||
Assert.Equal(TimeSpan.FromSeconds(30), options.KeepAliveInterval);
|
||||
Assert.Equal(TimeSpan.FromSeconds(10), options.ConnectTimeout);
|
||||
Assert.Equal(10, options.MaxReconnectAttempts);
|
||||
Assert.Equal(TimeSpan.FromMinutes(1), options.MaxReconnectBackoff);
|
||||
Assert.Equal(16 * 1024 * 1024, options.MaxFrameSize);
|
||||
}
|
||||
}
|
||||
|
||||
public class FrameProtocolTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task WriteAndReadFrame_RoundTrip()
|
||||
{
|
||||
// Arrange
|
||||
using var stream = new MemoryStream();
|
||||
var originalFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = new byte[] { 1, 2, 3, 4, 5 }
|
||||
};
|
||||
|
||||
// Act - Write
|
||||
await FrameProtocol.WriteFrameAsync(stream, originalFrame, CancellationToken.None);
|
||||
|
||||
// Act - Read
|
||||
stream.Position = 0;
|
||||
var readFrame = await FrameProtocol.ReadFrameAsync(stream, 1024 * 1024, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(readFrame);
|
||||
Assert.Equal(originalFrame.Type, readFrame.Type);
|
||||
Assert.Equal(originalFrame.CorrelationId, readFrame.CorrelationId);
|
||||
Assert.Equal(originalFrame.Payload.ToArray(), readFrame.Payload.ToArray());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAndReadFrame_EmptyPayload()
|
||||
{
|
||||
using var stream = new MemoryStream();
|
||||
var originalFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Cancel,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
await FrameProtocol.WriteFrameAsync(stream, originalFrame, CancellationToken.None);
|
||||
|
||||
stream.Position = 0;
|
||||
var readFrame = await FrameProtocol.ReadFrameAsync(stream, 1024 * 1024, CancellationToken.None);
|
||||
|
||||
Assert.NotNull(readFrame);
|
||||
Assert.Equal(FrameType.Cancel, readFrame.Type);
|
||||
Assert.Empty(readFrame.Payload.ToArray());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadFrame_ReturnsNullOnEmptyStream()
|
||||
{
|
||||
using var stream = new MemoryStream();
|
||||
|
||||
var result = await FrameProtocol.ReadFrameAsync(stream, 1024 * 1024, CancellationToken.None);
|
||||
|
||||
Assert.Null(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadFrame_ThrowsOnOversizedFrame()
|
||||
{
|
||||
using var stream = new MemoryStream();
|
||||
var largeFrame = new Frame
|
||||
{
|
||||
Type = FrameType.Request,
|
||||
CorrelationId = Guid.NewGuid().ToString("N"),
|
||||
Payload = new byte[1000]
|
||||
};
|
||||
|
||||
await FrameProtocol.WriteFrameAsync(stream, largeFrame, CancellationToken.None);
|
||||
|
||||
stream.Position = 0;
|
||||
|
||||
// Max frame size is smaller than the written frame
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(
|
||||
() => FrameProtocol.ReadFrameAsync(stream, 100, CancellationToken.None));
|
||||
}
|
||||
}
|
||||
|
||||
public class PendingRequestTrackerTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task TrackRequest_CompletesWithResponse()
|
||||
{
|
||||
using var tracker = new PendingRequestTracker();
|
||||
var correlationId = Guid.NewGuid();
|
||||
var expectedResponse = new Frame
|
||||
{
|
||||
Type = FrameType.Response,
|
||||
CorrelationId = correlationId.ToString("N"),
|
||||
Payload = ReadOnlyMemory<byte>.Empty
|
||||
};
|
||||
|
||||
var responseTask = tracker.TrackRequest(correlationId, CancellationToken.None);
|
||||
Assert.False(responseTask.IsCompleted);
|
||||
|
||||
tracker.CompleteRequest(correlationId, expectedResponse);
|
||||
|
||||
var response = await responseTask;
|
||||
Assert.Equal(expectedResponse.Type, response.Type);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TrackRequest_CancelsOnTokenCancellation()
|
||||
{
|
||||
using var tracker = new PendingRequestTracker();
|
||||
using var cts = new CancellationTokenSource();
|
||||
var correlationId = Guid.NewGuid();
|
||||
|
||||
var responseTask = tracker.TrackRequest(correlationId, cts.Token);
|
||||
|
||||
cts.Cancel();
|
||||
|
||||
await Assert.ThrowsAsync<TaskCanceledException>(() => responseTask);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Count_ReturnsCorrectValue()
|
||||
{
|
||||
using var tracker = new PendingRequestTracker();
|
||||
|
||||
Assert.Equal(0, tracker.Count);
|
||||
|
||||
_ = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
|
||||
_ = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
|
||||
|
||||
Assert.Equal(2, tracker.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CancelAll_CancelsAllPendingRequests()
|
||||
{
|
||||
using var tracker = new PendingRequestTracker();
|
||||
var task1 = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
|
||||
var task2 = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
|
||||
|
||||
tracker.CancelAll();
|
||||
|
||||
Assert.True(task1.IsCanceled || task1.IsFaulted);
|
||||
Assert.True(task2.IsCanceled || task2.IsFaulted);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void FailRequest_SetsException()
|
||||
{
|
||||
using var tracker = new PendingRequestTracker();
|
||||
var correlationId = Guid.NewGuid();
|
||||
var task = tracker.TrackRequest(correlationId, CancellationToken.None);
|
||||
|
||||
tracker.FailRequest(correlationId, new InvalidOperationException("Test error"));
|
||||
|
||||
Assert.True(task.IsFaulted);
|
||||
Assert.IsType<InvalidOperationException>(task.Exception?.InnerException);
|
||||
}
|
||||
}
|
||||
|
||||
public class TcpTransportServerTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task StartAsync_StartsListening()
|
||||
{
|
||||
var options = Options.Create(new TcpTransportOptions { Port = 0 }); // Port 0 = auto-assign
|
||||
await using var server = new TcpTransportServer(options, NullLogger<TcpTransportServer>.Instance);
|
||||
|
||||
await server.StartAsync(CancellationToken.None);
|
||||
|
||||
Assert.Equal(0, server.ConnectionCount);
|
||||
|
||||
await server.StopAsync(CancellationToken.None);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<LangVersion>preview</LangVersion>
|
||||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<IsPackable>false</IsPackable>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<UseConcelierTestInfra>false</UseConcelierTestInfra>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.0" />
|
||||
<PackageReference Include="xunit" Version="2.9.2" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
</PackageReference>
|
||||
<PackageReference Include="FluentAssertions" Version="6.12.0" />
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="10.0.0-rc.2.25502.107" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging" Version="10.0.0-rc.2.25502.107" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\StellaOps.Router.Transport.Tls\StellaOps.Router.Transport.Tls.csproj" />
|
||||
<ProjectReference Include="..\..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
@@ -0,0 +1,302 @@
|
||||
using System.Net;
|
||||
using System.Security.Authentication;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.Router.Transport.Tls;
|
||||
using Xunit;
|
||||
|
||||
namespace StellaOps.Router.Transport.Tls.Tests;
|
||||
|
||||
public class TlsTransportOptionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void DefaultOptions_HaveCorrectValues()
|
||||
{
|
||||
var options = new TlsTransportOptions();
|
||||
|
||||
Assert.Equal(5101, options.Port);
|
||||
Assert.Equal(64 * 1024, options.ReceiveBufferSize);
|
||||
Assert.Equal(64 * 1024, options.SendBufferSize);
|
||||
Assert.Equal(TimeSpan.FromSeconds(30), options.KeepAliveInterval);
|
||||
Assert.Equal(TimeSpan.FromSeconds(10), options.ConnectTimeout);
|
||||
Assert.Equal(10, options.MaxReconnectAttempts);
|
||||
Assert.Equal(TimeSpan.FromMinutes(1), options.MaxReconnectBackoff);
|
||||
Assert.Equal(16 * 1024 * 1024, options.MaxFrameSize);
|
||||
Assert.False(options.RequireClientCertificate);
|
||||
Assert.False(options.AllowSelfSigned);
|
||||
Assert.False(options.CheckCertificateRevocation);
|
||||
Assert.Equal(SslProtocols.Tls12 | SslProtocols.Tls13, options.EnabledProtocols);
|
||||
}
|
||||
}
|
||||
|
||||
public class CertificateLoaderTests
|
||||
{
|
||||
[Fact]
|
||||
public void LoadServerCertificate_WithDirectCertificate_ReturnsCertificate()
|
||||
{
|
||||
var cert = CreateSelfSignedCertificate("TestServer");
|
||||
var options = new TlsTransportOptions
|
||||
{
|
||||
ServerCertificate = cert
|
||||
};
|
||||
|
||||
var loaded = CertificateLoader.LoadServerCertificate(options);
|
||||
|
||||
Assert.Same(cert, loaded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LoadServerCertificate_WithNoCertificate_ThrowsException()
|
||||
{
|
||||
var options = new TlsTransportOptions();
|
||||
|
||||
Assert.Throws<InvalidOperationException>(() => CertificateLoader.LoadServerCertificate(options));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LoadClientCertificate_WithNoCertificate_ReturnsNull()
|
||||
{
|
||||
var options = new TlsTransportOptions();
|
||||
|
||||
var result = CertificateLoader.LoadClientCertificate(options);
|
||||
|
||||
Assert.Null(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LoadClientCertificate_WithDirectCertificate_ReturnsCertificate()
|
||||
{
|
||||
var cert = CreateSelfSignedCertificate("TestClient");
|
||||
var options = new TlsTransportOptions
|
||||
{
|
||||
ClientCertificate = cert
|
||||
};
|
||||
|
||||
var loaded = CertificateLoader.LoadClientCertificate(options);
|
||||
|
||||
Assert.Same(cert, loaded);
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateSelfSignedCertificate(string subject)
|
||||
{
|
||||
using var rsa = RSA.Create(2048);
|
||||
var request = new CertificateRequest(
|
||||
$"CN={subject}",
|
||||
rsa,
|
||||
HashAlgorithmName.SHA256,
|
||||
RSASignaturePadding.Pkcs1);
|
||||
|
||||
request.CertificateExtensions.Add(
|
||||
new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature, critical: true));
|
||||
|
||||
var certificate = request.CreateSelfSigned(
|
||||
DateTimeOffset.UtcNow.AddMinutes(-5),
|
||||
DateTimeOffset.UtcNow.AddYears(1));
|
||||
|
||||
// Export and re-import to get the private key
|
||||
var pfxBytes = certificate.Export(X509ContentType.Pfx);
|
||||
return X509CertificateLoader.LoadPkcs12(
|
||||
pfxBytes,
|
||||
null,
|
||||
X509KeyStorageFlags.MachineKeySet);
|
||||
}
|
||||
}
|
||||
|
||||
public class TlsTransportServerTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task StartAsync_WithValidCertificate_StartsListening()
|
||||
{
|
||||
var cert = CreateSelfSignedCertificate("TestServer");
|
||||
var options = Options.Create(new TlsTransportOptions
|
||||
{
|
||||
Port = 0,
|
||||
ServerCertificate = cert
|
||||
});
|
||||
|
||||
await using var server = new TlsTransportServer(options, NullLogger<TlsTransportServer>.Instance);
|
||||
|
||||
await server.StartAsync(CancellationToken.None);
|
||||
|
||||
Assert.Equal(0, server.ConnectionCount);
|
||||
|
||||
await server.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StartAsync_WithNoCertificate_ThrowsException()
|
||||
{
|
||||
var options = Options.Create(new TlsTransportOptions { Port = 0 });
|
||||
await using var server = new TlsTransportServer(options, NullLogger<TlsTransportServer>.Instance);
|
||||
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() =>
|
||||
server.StartAsync(CancellationToken.None));
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateSelfSignedCertificate(string subject)
|
||||
{
|
||||
using var rsa = RSA.Create(2048);
|
||||
var request = new CertificateRequest(
|
||||
$"CN={subject}",
|
||||
rsa,
|
||||
HashAlgorithmName.SHA256,
|
||||
RSASignaturePadding.Pkcs1);
|
||||
|
||||
request.CertificateExtensions.Add(
|
||||
new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment, critical: true));
|
||||
|
||||
request.CertificateExtensions.Add(
|
||||
new X509EnhancedKeyUsageExtension(
|
||||
new OidCollection { new Oid("1.3.6.1.5.5.7.3.1") },
|
||||
critical: true));
|
||||
|
||||
var certificate = request.CreateSelfSigned(
|
||||
DateTimeOffset.UtcNow.AddMinutes(-5),
|
||||
DateTimeOffset.UtcNow.AddYears(1));
|
||||
|
||||
var pfxBytes = certificate.Export(X509ContentType.Pfx);
|
||||
return X509CertificateLoader.LoadPkcs12(
|
||||
pfxBytes,
|
||||
null,
|
||||
X509KeyStorageFlags.MachineKeySet);
|
||||
}
|
||||
}
|
||||
|
||||
public class TlsConnectionTests
|
||||
{
|
||||
[Fact]
|
||||
public void ConnectionId_IsSet()
|
||||
{
|
||||
// This is more of a documentation test since TlsConnection
|
||||
// requires actual TcpClient and SslStream instances
|
||||
var options = new TlsTransportOptions();
|
||||
|
||||
Assert.NotNull(options);
|
||||
}
|
||||
}
|
||||
|
||||
public class TlsIntegrationTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ServerAndClient_CanEstablishConnection()
|
||||
{
|
||||
// Create self-signed server certificate
|
||||
var serverCert = CreateSelfSignedServerCertificate("localhost");
|
||||
|
||||
var serverOptions = Options.Create(new TlsTransportOptions
|
||||
{
|
||||
Port = 0, // Auto-assign
|
||||
ServerCertificate = serverCert,
|
||||
RequireClientCertificate = false
|
||||
});
|
||||
|
||||
await using var server = new TlsTransportServer(serverOptions, NullLogger<TlsTransportServer>.Instance);
|
||||
|
||||
await server.StartAsync(CancellationToken.None);
|
||||
|
||||
Assert.Equal(0, server.ConnectionCount);
|
||||
|
||||
await server.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ServerWithMtls_RequiresClientCertificate()
|
||||
{
|
||||
var serverCert = CreateSelfSignedServerCertificate("localhost");
|
||||
|
||||
var serverOptions = Options.Create(new TlsTransportOptions
|
||||
{
|
||||
Port = 0,
|
||||
ServerCertificate = serverCert,
|
||||
RequireClientCertificate = true,
|
||||
AllowSelfSigned = true
|
||||
});
|
||||
|
||||
await using var server = new TlsTransportServer(serverOptions, NullLogger<TlsTransportServer>.Instance);
|
||||
|
||||
await server.StartAsync(CancellationToken.None);
|
||||
|
||||
Assert.True(serverOptions.Value.RequireClientCertificate);
|
||||
|
||||
await server.StopAsync(CancellationToken.None);
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateSelfSignedServerCertificate(string hostname)
|
||||
{
|
||||
using var rsa = RSA.Create(2048);
|
||||
var request = new CertificateRequest(
|
||||
$"CN={hostname}",
|
||||
rsa,
|
||||
HashAlgorithmName.SHA256,
|
||||
RSASignaturePadding.Pkcs1);
|
||||
|
||||
// Key usage for server auth
|
||||
request.CertificateExtensions.Add(
|
||||
new X509KeyUsageExtension(
|
||||
X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment,
|
||||
critical: true));
|
||||
|
||||
// Server authentication EKU
|
||||
request.CertificateExtensions.Add(
|
||||
new X509EnhancedKeyUsageExtension(
|
||||
new OidCollection { new Oid("1.3.6.1.5.5.7.3.1") },
|
||||
critical: true));
|
||||
|
||||
// Subject Alternative Name
|
||||
var sanBuilder = new SubjectAlternativeNameBuilder();
|
||||
sanBuilder.AddDnsName(hostname);
|
||||
sanBuilder.AddIpAddress(IPAddress.Loopback);
|
||||
request.CertificateExtensions.Add(sanBuilder.Build());
|
||||
|
||||
var certificate = request.CreateSelfSigned(
|
||||
DateTimeOffset.UtcNow.AddMinutes(-5),
|
||||
DateTimeOffset.UtcNow.AddYears(1));
|
||||
|
||||
var pfxBytes = certificate.Export(X509ContentType.Pfx);
|
||||
return X509CertificateLoader.LoadPkcs12(
|
||||
pfxBytes,
|
||||
null,
|
||||
X509KeyStorageFlags.MachineKeySet);
|
||||
}
|
||||
}
|
||||
|
||||
public class ServiceCollectionExtensionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void AddTlsTransportServer_RegistersServices()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
services.AddLogging();
|
||||
|
||||
services.AddTlsTransportServer(options =>
|
||||
{
|
||||
options.Port = 5101;
|
||||
});
|
||||
|
||||
var provider = services.BuildServiceProvider();
|
||||
var server = provider.GetService<TlsTransportServer>();
|
||||
|
||||
Assert.NotNull(server);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AddTlsTransportClient_RegistersServices()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
services.AddLogging();
|
||||
|
||||
services.AddTlsTransportClient(options =>
|
||||
{
|
||||
options.Host = "localhost";
|
||||
options.Port = 5101;
|
||||
});
|
||||
|
||||
var provider = services.BuildServiceProvider();
|
||||
var client = provider.GetService<TlsTransportClient>();
|
||||
|
||||
Assert.NotNull(client);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user