Add unit tests for Router configuration and transport layers
Some checks failed
Docs CI / lint-and-preview (push) Has been cancelled
Policy Lint & Smoke / policy-lint (push) Has been cancelled

- Implemented tests for RouterConfig, RoutingOptions, StaticInstanceConfig, and RouterConfigOptions to ensure default values are set correctly.
- Added tests for RouterConfigProvider to validate configurations and ensure defaults are returned when no file is specified.
- Created tests for ConfigValidationResult to check success and error scenarios.
- Developed tests for ServiceCollectionExtensions to verify service registration for RouterConfig.
- Introduced UdpTransportTests to validate serialization, connection, request-response, and error handling in UDP transport.
- Added scripts for signing authority gaps and hashing DevPortal SDK snippets.
This commit is contained in:
StellaOps Bot
2025-12-05 08:01:47 +02:00
parent 635c70e828
commit 6a299d231f
294 changed files with 28434 additions and 1329 deletions

View File

@@ -0,0 +1,55 @@
using Microsoft.CodeAnalysis;
namespace StellaOps.Microservice.SourceGen;
/// <summary>
/// Diagnostic descriptors for the source generator.
/// </summary>
internal static class DiagnosticDescriptors
{
private const string Category = "StellaOps.Microservice";
/// <summary>
/// Class with [StellaEndpoint] must implement IStellaEndpoint or IRawStellaEndpoint.
/// </summary>
public static readonly DiagnosticDescriptor MissingHandlerInterface = new(
id: "STELLA001",
title: "Missing handler interface",
messageFormat: "Class '{0}' with [StellaEndpoint] must implement IStellaEndpoint<TRequest, TResponse> or IRawStellaEndpoint",
category: Category,
defaultSeverity: DiagnosticSeverity.Error,
isEnabledByDefault: true);
/// <summary>
/// Duplicate endpoint detected.
/// </summary>
public static readonly DiagnosticDescriptor DuplicateEndpoint = new(
id: "STELLA002",
title: "Duplicate endpoint",
messageFormat: "Duplicate endpoint: {0} {1} is defined in both '{2}' and '{3}'",
category: Category,
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);
/// <summary>
/// [StellaEndpoint] on abstract class is ignored.
/// </summary>
public static readonly DiagnosticDescriptor AbstractClassIgnored = new(
id: "STELLA003",
title: "Abstract class ignored",
messageFormat: "[StellaEndpoint] on abstract class '{0}' is ignored",
category: Category,
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);
/// <summary>
/// Informational: endpoints generated.
/// </summary>
public static readonly DiagnosticDescriptor EndpointsGenerated = new(
id: "STELLA004",
title: "Endpoints generated",
messageFormat: "Generated {0} endpoint descriptors",
category: Category,
defaultSeverity: DiagnosticSeverity.Info,
isEnabledByDefault: false);
}

View File

@@ -0,0 +1,17 @@
namespace StellaOps.Microservice.SourceGen;
/// <summary>
/// Holds extracted endpoint information from a [StellaEndpoint] decorated class.
/// </summary>
internal sealed record EndpointInfo(
string Namespace,
string ClassName,
string FullyQualifiedName,
string Method,
string Path,
int TimeoutSeconds,
bool SupportsStreaming,
string[] RequiredClaims,
string? RequestTypeName,
string? ResponseTypeName,
bool IsRaw);

View File

@@ -1,13 +0,0 @@
namespace StellaOps.Microservice.SourceGen;
/// <summary>
/// Placeholder type for the source generator project.
/// This will be replaced with actual source generator implementation in a later sprint.
/// </summary>
public static class Placeholder
{
/// <summary>
/// Indicates the source generator is not yet implemented.
/// </summary>
public const string Status = "NotImplemented";
}

View File

@@ -0,0 +1,10 @@
// Polyfills for netstandard2.0 compatibility
// ReSharper disable once CheckNamespace
namespace System.Runtime.CompilerServices
{
/// <summary>
/// Allows use of init accessors in netstandard2.0.
/// </summary>
internal static class IsExternalInit { }
}

View File

@@ -0,0 +1,399 @@
using System.Collections.Immutable;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
namespace StellaOps.Microservice.SourceGen;
/// <summary>
/// Incremental source generator for [StellaEndpoint] decorated classes.
/// Generates endpoint descriptors and DI registration at compile time.
/// </summary>
[Generator]
public sealed class StellaEndpointGenerator : IIncrementalGenerator
{
private const string StellaEndpointAttributeName = "StellaOps.Microservice.StellaEndpointAttribute";
private const string IStellaEndpointName = "StellaOps.Microservice.IStellaEndpoint";
private const string IRawStellaEndpointName = "StellaOps.Microservice.IRawStellaEndpoint";
/// <inheritdoc />
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Find all class declarations with attributes
var classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null);
// Combine all endpoints and generate
var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect());
context.RegisterSourceOutput(
compilationAndClasses,
static (spc, source) => Execute(source.Left, source.Right!, spc));
}
private static bool IsSyntaxTargetForGeneration(SyntaxNode node)
{
return node is ClassDeclarationSyntax { AttributeLists.Count: > 0 } classDecl
&& !classDecl.Modifiers.Any(SyntaxKind.AbstractKeyword);
}
private static ClassDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var classDeclaration = (ClassDeclarationSyntax)context.Node;
foreach (var attributeList in classDeclaration.AttributeLists)
{
foreach (var attribute in attributeList.Attributes)
{
var symbolInfo = context.SemanticModel.GetSymbolInfo(attribute);
var symbol = symbolInfo.Symbol;
if (symbol is not IMethodSymbol attributeSymbol)
continue;
var attributeContainingType = attributeSymbol.ContainingType;
var fullName = attributeContainingType.ToDisplayString();
if (fullName == StellaEndpointAttributeName)
{
return classDeclaration;
}
}
}
return null;
}
private static void Execute(
Compilation compilation,
ImmutableArray<ClassDeclarationSyntax?> classes,
SourceProductionContext context)
{
if (classes.IsDefaultOrEmpty)
return;
var distinctClasses = classes.Where(c => c is not null).Distinct().Cast<ClassDeclarationSyntax>();
var endpoints = new List<EndpointInfo>();
foreach (var classDeclaration in distinctClasses)
{
var semanticModel = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration);
if (classSymbol is null)
continue;
var endpoint = ExtractEndpointInfo(classSymbol, context);
if (endpoint is not null)
{
endpoints.Add(endpoint);
}
}
if (endpoints.Count == 0)
return;
// Check for duplicates
var seen = new Dictionary<(string Method, string Path), EndpointInfo>();
foreach (var endpoint in endpoints)
{
var key = (endpoint.Method, endpoint.Path);
if (seen.TryGetValue(key, out var existing))
{
context.ReportDiagnostic(Diagnostic.Create(
DiagnosticDescriptors.DuplicateEndpoint,
Location.None,
endpoint.Method,
endpoint.Path,
existing.ClassName,
endpoint.ClassName));
}
else
{
seen[key] = endpoint;
}
}
// Generate the source
var source = GenerateEndpointsClass(endpoints);
context.AddSource("StellaEndpoints.g.cs", SourceText.From(source, Encoding.UTF8));
// Generate the provider class
var providerSource = GenerateProviderClass();
context.AddSource("GeneratedEndpointProvider.g.cs", SourceText.From(providerSource, Encoding.UTF8));
}
private static EndpointInfo? ExtractEndpointInfo(INamedTypeSymbol classSymbol, SourceProductionContext context)
{
// Find StellaEndpoint attribute
AttributeData? stellaAttribute = null;
foreach (var attr in classSymbol.GetAttributes())
{
if (attr.AttributeClass?.ToDisplayString() == StellaEndpointAttributeName)
{
stellaAttribute = attr;
break;
}
}
if (stellaAttribute is null)
return null;
// Check for abstract class
if (classSymbol.IsAbstract)
{
context.ReportDiagnostic(Diagnostic.Create(
DiagnosticDescriptors.AbstractClassIgnored,
Location.None,
classSymbol.Name));
return null;
}
// Extract constructor arguments: method and path
if (stellaAttribute.ConstructorArguments.Length < 2)
return null;
var method = stellaAttribute.ConstructorArguments[0].Value as string ?? "GET";
var path = stellaAttribute.ConstructorArguments[1].Value as string ?? "/";
// Extract named arguments
var timeoutSeconds = 30;
var supportsStreaming = false;
var requiredClaims = Array.Empty<string>();
foreach (var namedArg in stellaAttribute.NamedArguments)
{
switch (namedArg.Key)
{
case "TimeoutSeconds":
timeoutSeconds = (int)(namedArg.Value.Value ?? 30);
break;
case "SupportsStreaming":
supportsStreaming = (bool)(namedArg.Value.Value ?? false);
break;
case "RequiredClaims":
if (!namedArg.Value.IsNull && namedArg.Value.Values.Length > 0)
{
requiredClaims = namedArg.Value.Values
.Select(v => v.Value as string)
.Where(s => s is not null)
.Cast<string>()
.ToArray();
}
break;
}
}
// Find handler interface implementation
string? requestTypeName = null;
string? responseTypeName = null;
bool isRaw = false;
foreach (var iface in classSymbol.AllInterfaces)
{
var fullName = iface.OriginalDefinition.ToDisplayString();
if (fullName.StartsWith(IStellaEndpointName) && iface.TypeArguments.Length == 2)
{
requestTypeName = iface.TypeArguments[0].ToDisplayString();
responseTypeName = iface.TypeArguments[1].ToDisplayString();
isRaw = false;
break;
}
if (fullName == IRawStellaEndpointName)
{
isRaw = true;
break;
}
}
// If no handler interface found, report error
if (!isRaw && requestTypeName is null)
{
context.ReportDiagnostic(Diagnostic.Create(
DiagnosticDescriptors.MissingHandlerInterface,
Location.None,
classSymbol.Name));
return null;
}
var ns = classSymbol.ContainingNamespace.IsGlobalNamespace
? string.Empty
: classSymbol.ContainingNamespace.ToDisplayString();
return new EndpointInfo(
Namespace: ns,
ClassName: classSymbol.Name,
FullyQualifiedName: classSymbol.ToDisplayString(),
Method: method.ToUpperInvariant(),
Path: path,
TimeoutSeconds: timeoutSeconds,
SupportsStreaming: supportsStreaming,
RequiredClaims: requiredClaims,
RequestTypeName: requestTypeName,
ResponseTypeName: responseTypeName,
IsRaw: isRaw);
}
private static string GenerateEndpointsClass(List<EndpointInfo> endpoints)
{
var sb = new StringBuilder();
sb.AppendLine("// <auto-generated/>");
sb.AppendLine("#nullable enable");
sb.AppendLine();
sb.AppendLine("namespace StellaOps.Microservice.Generated");
sb.AppendLine("{");
sb.AppendLine(" /// <summary>");
sb.AppendLine(" /// Auto-generated endpoint metadata and registration.");
sb.AppendLine(" /// </summary>");
sb.AppendLine(" [global::System.CodeDom.Compiler.GeneratedCode(\"StellaOps.Microservice.SourceGen\", \"1.0.0\")]");
sb.AppendLine(" internal static class StellaEndpoints");
sb.AppendLine(" {");
// GetEndpoints method
sb.AppendLine(" /// <summary>");
sb.AppendLine(" /// Gets all discovered endpoint descriptors.");
sb.AppendLine(" /// </summary>");
sb.AppendLine(" public static global::System.Collections.Generic.IReadOnlyList<global::StellaOps.Router.Common.Models.EndpointDescriptor> GetEndpoints()");
sb.AppendLine(" {");
sb.AppendLine(" return new global::StellaOps.Router.Common.Models.EndpointDescriptor[]");
sb.AppendLine(" {");
for (int i = 0; i < endpoints.Count; i++)
{
var ep = endpoints[i];
sb.AppendLine(" new global::StellaOps.Router.Common.Models.EndpointDescriptor");
sb.AppendLine(" {");
sb.AppendLine(" ServiceName = string.Empty, // Set by SDK at registration");
sb.AppendLine(" Version = string.Empty, // Set by SDK at registration");
sb.AppendLine($" Method = \"{EscapeString(ep.Method)}\",");
sb.AppendLine($" Path = \"{EscapeString(ep.Path)}\",");
sb.AppendLine($" DefaultTimeout = global::System.TimeSpan.FromSeconds({ep.TimeoutSeconds}),");
sb.AppendLine($" SupportsStreaming = {(ep.SupportsStreaming ? "true" : "false")},");
sb.Append(" RequiringClaims = ");
if (ep.RequiredClaims.Length == 0)
{
sb.AppendLine("new global::System.Collections.Generic.List<global::StellaOps.Router.Common.Models.ClaimRequirement>(),");
}
else
{
sb.AppendLine("new global::System.Collections.Generic.List<global::StellaOps.Router.Common.Models.ClaimRequirement>");
sb.AppendLine(" {");
foreach (var claim in ep.RequiredClaims)
{
sb.AppendLine($" new global::StellaOps.Router.Common.Models.ClaimRequirement {{ Type = \"{EscapeString(claim)}\", Value = null }},");
}
sb.AppendLine(" },");
}
sb.AppendLine($" HandlerType = typeof(global::{ep.FullyQualifiedName})");
sb.Append(" }");
if (i < endpoints.Count - 1)
{
sb.AppendLine(",");
}
else
{
sb.AppendLine();
}
}
sb.AppendLine(" };");
sb.AppendLine(" }");
sb.AppendLine();
// RegisterHandlers method
sb.AppendLine(" /// <summary>");
sb.AppendLine(" /// Registers all endpoint handlers with the service collection.");
sb.AppendLine(" /// </summary>");
sb.AppendLine(" public static void RegisterHandlers(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)");
sb.AppendLine(" {");
foreach (var ep in endpoints)
{
sb.AppendLine($" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient<global::{ep.FullyQualifiedName}>(services);");
}
sb.AppendLine(" }");
// GetHandlerTypes method
sb.AppendLine();
sb.AppendLine(" /// <summary>");
sb.AppendLine(" /// Gets all handler types for endpoint discovery.");
sb.AppendLine(" /// </summary>");
sb.AppendLine(" public static global::System.Collections.Generic.IReadOnlyList<global::System.Type> GetHandlerTypes()");
sb.AppendLine(" {");
sb.AppendLine(" return new global::System.Type[]");
sb.AppendLine(" {");
for (int i = 0; i < endpoints.Count; i++)
{
var ep = endpoints[i];
sb.Append($" typeof(global::{ep.FullyQualifiedName})");
if (i < endpoints.Count - 1)
{
sb.AppendLine(",");
}
else
{
sb.AppendLine();
}
}
sb.AppendLine(" };");
sb.AppendLine(" }");
sb.AppendLine(" }");
sb.AppendLine("}");
return sb.ToString();
}
private static string GenerateProviderClass()
{
var sb = new StringBuilder();
sb.AppendLine("// <auto-generated/>");
sb.AppendLine("#nullable enable");
sb.AppendLine();
sb.AppendLine("namespace StellaOps.Microservice.Generated");
sb.AppendLine("{");
sb.AppendLine(" /// <summary>");
sb.AppendLine(" /// Generated implementation of IGeneratedEndpointProvider.");
sb.AppendLine(" /// </summary>");
sb.AppendLine(" [global::System.CodeDom.Compiler.GeneratedCode(\"StellaOps.Microservice.SourceGen\", \"1.0.0\")]");
sb.AppendLine(" internal sealed class GeneratedEndpointProvider : global::StellaOps.Microservice.IGeneratedEndpointProvider");
sb.AppendLine(" {");
sb.AppendLine(" /// <inheritdoc />");
sb.AppendLine(" public global::System.Collections.Generic.IReadOnlyList<global::StellaOps.Router.Common.Models.EndpointDescriptor> GetEndpoints()");
sb.AppendLine(" => StellaEndpoints.GetEndpoints();");
sb.AppendLine();
sb.AppendLine(" /// <inheritdoc />");
sb.AppendLine(" public void RegisterHandlers(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)");
sb.AppendLine(" => StellaEndpoints.RegisterHandlers(services);");
sb.AppendLine();
sb.AppendLine(" /// <inheritdoc />");
sb.AppendLine(" public global::System.Collections.Generic.IReadOnlyList<global::System.Type> GetHandlerTypes()");
sb.AppendLine(" => StellaEndpoints.GetHandlerTypes();");
sb.AppendLine(" }");
sb.AppendLine("}");
return sb.ToString();
}
private static string EscapeString(string value)
{
return value
.Replace("\\", "\\\\")
.Replace("\"", "\\\"")
.Replace("\n", "\\n")
.Replace("\r", "\\r")
.Replace("\t", "\\t");
}
}

View File

@@ -1,9 +1,32 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<LangVersion>preview</LangVersion>
<!-- Source generators must target netstandard2.0 for Roslyn compatibility -->
<TargetFramework>netstandard2.0</TargetFramework>
<LangVersion>12.0</LangVersion>
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<!-- Source generator specific settings -->
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
<IsRoslynComponent>true</IsRoslynComponent>
<IncludeBuildOutput>false</IncludeBuildOutput>
<!-- Package settings for distribution -->
<PackageId>StellaOps.Microservice.SourceGen</PackageId>
<Description>Source generator for Stella microservice endpoints</Description>
<DevelopmentDependency>true</DevelopmentDependency>
<IsPackable>true</IsPackable>
<NoWarn>$(NoWarn);NU5128;RS2008</NoWarn>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" PrivateAssets="all" />
</ItemGroup>
<!-- Pack the analyzer as an analyzer -->
<ItemGroup>
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,71 @@
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
/// <summary>
/// Interface for discovering endpoints with YAML configuration support.
/// </summary>
public interface IEndpointDiscoveryService
{
/// <summary>
/// Discovers all endpoints, applying any YAML configuration overrides.
/// </summary>
/// <returns>The discovered endpoints with overrides applied.</returns>
IReadOnlyList<EndpointDescriptor> DiscoverEndpoints();
}
/// <summary>
/// Service that discovers endpoints and applies YAML configuration overrides.
/// </summary>
public sealed class EndpointDiscoveryService : IEndpointDiscoveryService
{
private readonly IEndpointDiscoveryProvider _discoveryProvider;
private readonly IMicroserviceYamlLoader _yamlLoader;
private readonly IEndpointOverrideMerger _merger;
private readonly ILogger<EndpointDiscoveryService> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="EndpointDiscoveryService"/> class.
/// </summary>
public EndpointDiscoveryService(
IEndpointDiscoveryProvider discoveryProvider,
IMicroserviceYamlLoader yamlLoader,
IEndpointOverrideMerger merger,
ILogger<EndpointDiscoveryService> logger)
{
_discoveryProvider = discoveryProvider;
_yamlLoader = yamlLoader;
_merger = merger;
_logger = logger;
}
/// <inheritdoc />
public IReadOnlyList<EndpointDescriptor> DiscoverEndpoints()
{
// 1. Discover endpoints from code (via reflection or source gen)
var codeEndpoints = _discoveryProvider.DiscoverEndpoints();
_logger.LogDebug("Discovered {Count} endpoints from code", codeEndpoints.Count);
// 2. Load YAML overrides
MicroserviceYamlConfig? yamlConfig = null;
try
{
yamlConfig = _yamlLoader.Load();
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to load YAML configuration, using code defaults only");
}
// 3. Merge code endpoints with YAML overrides
var mergedEndpoints = _merger.Merge(codeEndpoints, yamlConfig);
_logger.LogInformation(
"Endpoint discovery complete: {Count} endpoints (YAML overrides: {HasYaml})",
mergedEndpoints.Count,
yamlConfig != null);
return mergedEndpoints;
}
}

View File

@@ -0,0 +1,115 @@
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
/// <summary>
/// Interface for merging endpoint overrides from YAML configuration.
/// </summary>
public interface IEndpointOverrideMerger
{
/// <summary>
/// Merges YAML overrides with code-defined endpoints.
/// </summary>
/// <param name="codeEndpoints">The endpoints discovered from code.</param>
/// <param name="yamlConfig">The YAML configuration, if any.</param>
/// <returns>The merged endpoints.</returns>
IReadOnlyList<EndpointDescriptor> Merge(
IReadOnlyList<EndpointDescriptor> codeEndpoints,
MicroserviceYamlConfig? yamlConfig);
}
/// <summary>
/// Merges endpoint overrides from YAML configuration with code defaults.
/// </summary>
public sealed class EndpointOverrideMerger : IEndpointOverrideMerger
{
private readonly ILogger<EndpointOverrideMerger> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="EndpointOverrideMerger"/> class.
/// </summary>
public EndpointOverrideMerger(ILogger<EndpointOverrideMerger> logger)
{
_logger = logger;
}
/// <inheritdoc />
public IReadOnlyList<EndpointDescriptor> Merge(
IReadOnlyList<EndpointDescriptor> codeEndpoints,
MicroserviceYamlConfig? yamlConfig)
{
if (yamlConfig == null || yamlConfig.Endpoints.Count == 0)
{
return codeEndpoints;
}
WarnUnmatchedOverrides(codeEndpoints, yamlConfig);
return codeEndpoints.Select(ep =>
{
var yamlOverride = FindMatchingOverride(ep, yamlConfig);
return yamlOverride == null ? ep : MergeEndpoint(ep, yamlOverride);
}).ToList();
}
private static EndpointOverrideConfig? FindMatchingOverride(
EndpointDescriptor endpoint,
MicroserviceYamlConfig yamlConfig)
{
return yamlConfig.Endpoints.FirstOrDefault(y =>
string.Equals(y.Method, endpoint.Method, StringComparison.OrdinalIgnoreCase) &&
string.Equals(y.Path, endpoint.Path, StringComparison.OrdinalIgnoreCase));
}
private EndpointDescriptor MergeEndpoint(
EndpointDescriptor codeDefault,
EndpointOverrideConfig yamlOverride)
{
var merged = codeDefault with
{
DefaultTimeout = yamlOverride.GetDefaultTimeoutAsTimeSpan() ?? codeDefault.DefaultTimeout,
SupportsStreaming = yamlOverride.SupportsStreaming ?? codeDefault.SupportsStreaming,
RequiringClaims = yamlOverride.RequiringClaims?.Count > 0
? yamlOverride.RequiringClaims.Select(c => c.ToClaimRequirement()).ToList()
: codeDefault.RequiringClaims
};
if (yamlOverride.GetDefaultTimeoutAsTimeSpan().HasValue ||
yamlOverride.SupportsStreaming.HasValue ||
yamlOverride.RequiringClaims?.Count > 0)
{
_logger.LogDebug(
"Applied YAML overrides to endpoint {Method} {Path}: Timeout={Timeout}, Streaming={Streaming}, Claims={Claims}",
merged.Method,
merged.Path,
merged.DefaultTimeout,
merged.SupportsStreaming,
merged.RequiringClaims?.Count ?? 0);
}
return merged;
}
private void WarnUnmatchedOverrides(
IReadOnlyList<EndpointDescriptor> codeEndpoints,
MicroserviceYamlConfig yamlConfig)
{
var codeKeys = codeEndpoints
.Select(e => (Method: e.Method.ToUpperInvariant(), Path: e.Path.ToLowerInvariant()))
.ToHashSet();
foreach (var yamlEntry in yamlConfig.Endpoints)
{
var key = (Method: yamlEntry.Method.ToUpperInvariant(), Path: yamlEntry.Path.ToLowerInvariant());
if (!codeKeys.Contains(key))
{
_logger.LogWarning(
"YAML override for {Method} {Path} does not match any code endpoint. " +
"YAML cannot create endpoints, only modify existing ones.",
yamlEntry.Method,
yamlEntry.Path);
}
}
}
}

View File

@@ -1,3 +1,5 @@
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
/// <summary>

View File

@@ -0,0 +1,110 @@
using System.Reflection;
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
/// <summary>
/// Discovers endpoints using source-generated provider, falling back to reflection.
/// </summary>
public sealed class GeneratedEndpointDiscoveryProvider : IEndpointDiscoveryProvider
{
private readonly StellaMicroserviceOptions _options;
private readonly ILogger<GeneratedEndpointDiscoveryProvider> _logger;
private readonly ReflectionEndpointDiscoveryProvider _reflectionFallback;
private const string GeneratedProviderTypeName = "StellaOps.Microservice.Generated.GeneratedEndpointProvider";
/// <summary>
/// Initializes a new instance of the <see cref="GeneratedEndpointDiscoveryProvider"/> class.
/// </summary>
public GeneratedEndpointDiscoveryProvider(
StellaMicroserviceOptions options,
ILogger<GeneratedEndpointDiscoveryProvider> logger)
{
_options = options;
_logger = logger;
_reflectionFallback = new ReflectionEndpointDiscoveryProvider(options);
}
/// <inheritdoc />
public IReadOnlyList<EndpointDescriptor> DiscoverEndpoints()
{
// Try to find the generated provider
var generatedProvider = TryGetGeneratedProvider();
if (generatedProvider != null)
{
_logger.LogDebug("Using source-generated endpoint discovery");
var endpoints = generatedProvider.GetEndpoints();
// Apply service name and version from options
var result = new List<EndpointDescriptor>();
foreach (var endpoint in endpoints)
{
result.Add(endpoint with
{
ServiceName = _options.ServiceName,
Version = _options.Version
});
}
_logger.LogInformation(
"Discovered {Count} endpoints via source generation",
result.Count);
return result;
}
// Fall back to reflection
_logger.LogDebug("Source-generated provider not found, falling back to reflection");
return _reflectionFallback.DiscoverEndpoints();
}
private IGeneratedEndpointProvider? TryGetGeneratedProvider()
{
try
{
// Look in the entry assembly first
var entryAssembly = Assembly.GetEntryAssembly();
var providerType = entryAssembly?.GetType(GeneratedProviderTypeName);
if (providerType != null)
{
return (IGeneratedEndpointProvider)Activator.CreateInstance(providerType)!;
}
// Also check the calling assembly
var callingAssembly = Assembly.GetCallingAssembly();
providerType = callingAssembly.GetType(GeneratedProviderTypeName);
if (providerType != null)
{
return (IGeneratedEndpointProvider)Activator.CreateInstance(providerType)!;
}
// Check all loaded assemblies
foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
{
try
{
providerType = assembly.GetType(GeneratedProviderTypeName);
if (providerType != null)
{
return (IGeneratedEndpointProvider)Activator.CreateInstance(providerType)!;
}
}
catch
{
// Ignore assembly loading errors
}
}
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Failed to load generated endpoint provider");
}
return null;
}
}

View File

@@ -1,3 +1,5 @@
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
/// <summary>

View File

@@ -0,0 +1,25 @@
using Microsoft.Extensions.DependencyInjection;
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
/// <summary>
/// Interface implemented by the source-generated endpoint provider.
/// </summary>
public interface IGeneratedEndpointProvider
{
/// <summary>
/// Gets all discovered endpoint descriptors.
/// </summary>
IReadOnlyList<EndpointDescriptor> GetEndpoints();
/// <summary>
/// Registers all endpoint handlers with the service collection.
/// </summary>
void RegisterHandlers(IServiceCollection services);
/// <summary>
/// Gets all handler types for endpoint discovery.
/// </summary>
IReadOnlyList<Type> GetHandlerTypes();
}

View File

@@ -0,0 +1,145 @@
using System.Collections.Concurrent;
using Microsoft.Extensions.Logging;
namespace StellaOps.Microservice;
/// <summary>
/// Tracks in-flight requests and manages their cancellation tokens.
/// </summary>
public sealed class InflightRequestTracker : IDisposable
{
private readonly ConcurrentDictionary<Guid, InflightRequest> _inflight = new();
private readonly ILogger<InflightRequestTracker> _logger;
private bool _disposed;
/// <summary>
/// Initializes a new instance of the <see cref="InflightRequestTracker"/> class.
/// </summary>
public InflightRequestTracker(ILogger<InflightRequestTracker> logger)
{
_logger = logger;
}
/// <summary>
/// Gets the count of in-flight requests.
/// </summary>
public int Count => _inflight.Count;
/// <summary>
/// Starts tracking a request and returns a cancellation token for it.
/// </summary>
/// <param name="correlationId">The correlation ID of the request.</param>
/// <returns>A cancellation token that will be triggered if the request is cancelled.</returns>
public CancellationToken Track(Guid correlationId)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var cts = new CancellationTokenSource();
var request = new InflightRequest(cts);
if (!_inflight.TryAdd(correlationId, request))
{
cts.Dispose();
throw new InvalidOperationException($"Request {correlationId} is already being tracked");
}
_logger.LogDebug("Started tracking request {CorrelationId}", correlationId);
return cts.Token;
}
/// <summary>
/// Cancels a specific request.
/// </summary>
/// <param name="correlationId">The correlation ID of the request to cancel.</param>
/// <param name="reason">The reason for cancellation.</param>
/// <returns>True if the request was found and cancelled; otherwise false.</returns>
public bool Cancel(Guid correlationId, string? reason)
{
if (_inflight.TryGetValue(correlationId, out var request))
{
try
{
request.Cts.Cancel();
_logger.LogInformation(
"Cancelled request {CorrelationId}: {Reason}",
correlationId,
reason ?? "Unknown");
return true;
}
catch (ObjectDisposedException)
{
// CTS was already disposed, request completed
return false;
}
}
_logger.LogDebug(
"Cannot cancel request {CorrelationId}: not found (may have already completed)",
correlationId);
return false;
}
/// <summary>
/// Marks a request as completed and removes it from tracking.
/// </summary>
/// <param name="correlationId">The correlation ID of the completed request.</param>
public void Complete(Guid correlationId)
{
if (_inflight.TryRemove(correlationId, out var request))
{
request.Cts.Dispose();
_logger.LogDebug("Completed request {CorrelationId}", correlationId);
}
}
/// <summary>
/// Cancels all in-flight requests.
/// </summary>
/// <param name="reason">The reason for cancellation.</param>
public void CancelAll(string reason)
{
var count = 0;
foreach (var kvp in _inflight)
{
try
{
kvp.Value.Cts.Cancel();
count++;
}
catch (ObjectDisposedException)
{
// Already disposed
}
}
_logger.LogInformation("Cancelled {Count} in-flight requests: {Reason}", count, reason);
// Clear and dispose all
foreach (var kvp in _inflight)
{
if (_inflight.TryRemove(kvp.Key, out var request))
{
request.Cts.Dispose();
}
}
}
/// <inheritdoc />
public void Dispose()
{
if (_disposed) return;
_disposed = true;
CancelAll("Disposing tracker");
}
private sealed class InflightRequest
{
public CancellationTokenSource Cts { get; }
public InflightRequest(CancellationTokenSource cts)
{
Cts = cts;
}
}
}

View File

@@ -0,0 +1,113 @@
using StellaOps.Router.Common.Models;
using YamlDotNet.Serialization;
namespace StellaOps.Microservice;
/// <summary>
/// Root configuration for microservice endpoint overrides loaded from YAML.
/// </summary>
public sealed class MicroserviceYamlConfig
{
/// <summary>
/// Gets or sets the endpoint override configurations.
/// </summary>
[YamlMember(Alias = "endpoints")]
public List<EndpointOverrideConfig> Endpoints { get; set; } = [];
}
/// <summary>
/// Configuration for overriding an endpoint's properties.
/// </summary>
public sealed class EndpointOverrideConfig
{
/// <summary>
/// Gets or sets the HTTP method to match.
/// </summary>
[YamlMember(Alias = "method")]
public string Method { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the path to match.
/// </summary>
[YamlMember(Alias = "path")]
public string Path { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the default timeout override.
/// </summary>
[YamlMember(Alias = "defaultTimeout")]
public string? DefaultTimeout { get; set; }
/// <summary>
/// Gets or sets whether streaming is supported.
/// </summary>
[YamlMember(Alias = "supportsStreaming")]
public bool? SupportsStreaming { get; set; }
/// <summary>
/// Gets or sets the claim requirements.
/// </summary>
[YamlMember(Alias = "requiringClaims")]
public List<ClaimRequirementConfig>? RequiringClaims { get; set; }
/// <summary>
/// Parses the DefaultTimeout string to a TimeSpan.
/// </summary>
public TimeSpan? GetDefaultTimeoutAsTimeSpan()
{
if (string.IsNullOrWhiteSpace(DefaultTimeout))
return null;
// Handle formats like "30s", "5m", "1h", or "00:00:30"
var value = DefaultTimeout.Trim();
if (value.EndsWith("s", StringComparison.OrdinalIgnoreCase))
{
if (int.TryParse(value[..^1], out var seconds))
return TimeSpan.FromSeconds(seconds);
}
else if (value.EndsWith("m", StringComparison.OrdinalIgnoreCase))
{
if (int.TryParse(value[..^1], out var minutes))
return TimeSpan.FromMinutes(minutes);
}
else if (value.EndsWith("h", StringComparison.OrdinalIgnoreCase))
{
if (int.TryParse(value[..^1], out var hours))
return TimeSpan.FromHours(hours);
}
else if (TimeSpan.TryParse(value, out var timespan))
{
return timespan;
}
return null;
}
}
/// <summary>
/// Configuration for a claim requirement.
/// </summary>
public sealed class ClaimRequirementConfig
{
/// <summary>
/// Gets or sets the claim type.
/// </summary>
[YamlMember(Alias = "type")]
public string Type { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the claim value.
/// </summary>
[YamlMember(Alias = "value")]
public string? Value { get; set; }
/// <summary>
/// Converts to a ClaimRequirement model.
/// </summary>
public ClaimRequirement ToClaimRequirement() => new()
{
Type = Type,
Value = Value
};
}

View File

@@ -0,0 +1,78 @@
using Microsoft.Extensions.Logging;
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.NamingConventions;
namespace StellaOps.Microservice;
/// <summary>
/// Interface for loading microservice YAML configuration.
/// </summary>
public interface IMicroserviceYamlLoader
{
/// <summary>
/// Loads the microservice configuration from YAML.
/// </summary>
/// <returns>The configuration, or null if no file is configured or file doesn't exist.</returns>
MicroserviceYamlConfig? Load();
}
/// <summary>
/// Loads microservice configuration from a YAML file.
/// </summary>
public sealed class MicroserviceYamlLoader : IMicroserviceYamlLoader
{
private readonly StellaMicroserviceOptions _options;
private readonly ILogger<MicroserviceYamlLoader> _logger;
private readonly IDeserializer _deserializer;
/// <summary>
/// Initializes a new instance of the <see cref="MicroserviceYamlLoader"/> class.
/// </summary>
public MicroserviceYamlLoader(
StellaMicroserviceOptions options,
ILogger<MicroserviceYamlLoader> logger)
{
_options = options;
_logger = logger;
_deserializer = new DeserializerBuilder()
.WithNamingConvention(CamelCaseNamingConvention.Instance)
.IgnoreUnmatchedProperties()
.Build();
}
/// <inheritdoc />
public MicroserviceYamlConfig? Load()
{
if (string.IsNullOrWhiteSpace(_options.ConfigFilePath))
{
_logger.LogDebug("No ConfigFilePath specified, skipping YAML configuration");
return null;
}
var fullPath = Path.GetFullPath(_options.ConfigFilePath);
if (!File.Exists(fullPath))
{
_logger.LogDebug("Configuration file {Path} does not exist, skipping", fullPath);
return null;
}
try
{
var yaml = File.ReadAllText(fullPath);
var config = _deserializer.Deserialize<MicroserviceYamlConfig>(yaml);
_logger.LogInformation(
"Loaded microservice configuration from {Path} with {Count} endpoint overrides",
fullPath,
config?.Endpoints?.Count ?? 0);
return config;
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to load microservice configuration from {Path}", fullPath);
throw;
}
}
}

View File

@@ -1,85 +1,2 @@
using System.Text.RegularExpressions;
namespace StellaOps.Microservice;
/// <summary>
/// Matches request paths against route templates.
/// </summary>
public sealed partial class PathMatcher
{
private readonly string _template;
private readonly Regex _regex;
private readonly string[] _parameterNames;
private readonly bool _caseInsensitive;
/// <summary>
/// Gets the route template.
/// </summary>
public string Template => _template;
/// <summary>
/// Initializes a new instance of the <see cref="PathMatcher"/> class.
/// </summary>
/// <param name="template">The route template (e.g., "/api/users/{id}").</param>
/// <param name="caseInsensitive">Whether matching should be case-insensitive.</param>
public PathMatcher(string template, bool caseInsensitive = true)
{
_template = template;
_caseInsensitive = caseInsensitive;
// Extract parameter names and build regex
var paramNames = new List<string>();
var pattern = "^" + ParameterRegex().Replace(template, match =>
{
paramNames.Add(match.Groups[1].Value);
return "([^/]+)";
}) + "/?$";
var options = caseInsensitive ? RegexOptions.IgnoreCase : RegexOptions.None;
_regex = new Regex(pattern, options | RegexOptions.Compiled);
_parameterNames = [.. paramNames];
}
/// <summary>
/// Tries to match a path against the template.
/// </summary>
/// <param name="path">The request path.</param>
/// <param name="parameters">The extracted path parameters if matched.</param>
/// <returns>True if the path matches.</returns>
public bool TryMatch(string path, out Dictionary<string, string> parameters)
{
parameters = [];
// Normalize path
path = path.TrimEnd('/');
if (!path.StartsWith('/'))
path = "/" + path;
var match = _regex.Match(path);
if (!match.Success)
return false;
for (int i = 0; i < _parameterNames.Length; i++)
{
parameters[_parameterNames[i]] = match.Groups[i + 1].Value;
}
return true;
}
/// <summary>
/// Checks if a path matches the template.
/// </summary>
/// <param name="path">The request path.</param>
/// <returns>True if the path matches.</returns>
public bool IsMatch(string path)
{
path = path.TrimEnd('/');
if (!path.StartsWith('/'))
path = "/" + path;
return _regex.IsMatch(path);
}
[GeneratedRegex(@"\{([^}:]+)(?::[^}]+)?\}")]
private static partial Regex ParameterRegex();
}
// Re-export PathMatcher from Router.Common for backwards compatibility
global using PathMatcher = StellaOps.Router.Common.PathMatcher;

View File

@@ -2,6 +2,7 @@ using System.Text.Json;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Frames;
using StellaOps.Router.Common.Models;
namespace StellaOps.Microservice;
@@ -116,6 +117,13 @@ public sealed class RequestDispatcher
RawRequestContext context,
CancellationToken cancellationToken)
{
// Ensure handler type is set
if (endpoint.HandlerType is null)
{
_logger.LogError("Endpoint {Method} {Path} has no handler type", endpoint.Method, endpoint.Path);
return RawResponse.InternalError("No handler configured");
}
// Get handler instance from DI
var handler = scopedProvider.GetService(endpoint.HandlerType);
if (handler is null)

View File

@@ -14,13 +14,16 @@ public sealed class RouterConnectionManager : IRouterConnectionManager, IDisposa
{
private readonly StellaMicroserviceOptions _options;
private readonly IEndpointDiscoveryProvider _endpointDiscovery;
private readonly ITransportClient _transportClient;
private readonly IMicroserviceTransport? _microserviceTransport;
private readonly ILogger<RouterConnectionManager> _logger;
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new();
private readonly CancellationTokenSource _cts = new();
private IReadOnlyList<EndpointDescriptor>? _endpoints;
private Task? _heartbeatTask;
private bool _disposed;
private volatile InstanceHealthStatus _currentStatus = InstanceHealthStatus.Healthy;
private int _inFlightRequestCount;
private double _errorRate;
/// <inheritdoc />
public IReadOnlyList<ConnectionState> Connections => [.. _connections.Values];
@@ -31,15 +34,42 @@ public sealed class RouterConnectionManager : IRouterConnectionManager, IDisposa
public RouterConnectionManager(
IOptions<StellaMicroserviceOptions> options,
IEndpointDiscoveryProvider endpointDiscovery,
ITransportClient transportClient,
IMicroserviceTransport? microserviceTransport,
ILogger<RouterConnectionManager> logger)
{
_options = options.Value;
_endpointDiscovery = endpointDiscovery;
_transportClient = transportClient;
_microserviceTransport = microserviceTransport;
_logger = logger;
}
/// <summary>
/// Gets or sets the current health status reported by this instance.
/// </summary>
public InstanceHealthStatus CurrentStatus
{
get => _currentStatus;
set => _currentStatus = value;
}
/// <summary>
/// Gets or sets the count of in-flight requests.
/// </summary>
public int InFlightRequestCount
{
get => _inFlightRequestCount;
set => _inFlightRequestCount = value;
}
/// <summary>
/// Gets or sets the error rate (0.0 to 1.0).
/// </summary>
public double ErrorRate
{
get => _errorRate;
set => _errorRate = value;
}
/// <inheritdoc />
public async Task StartAsync(CancellationToken cancellationToken)
{
@@ -168,32 +198,40 @@ public sealed class RouterConnectionManager : IRouterConnectionManager, IDisposa
{
await Task.Delay(_options.HeartbeatInterval, cancellationToken);
foreach (var connection in _connections.Values)
// Build heartbeat payload with current status and metrics
var heartbeat = new HeartbeatPayload
{
InstanceId = _options.InstanceId,
Status = _currentStatus,
InFlightRequestCount = _inFlightRequestCount,
ErrorRate = _errorRate,
TimestampUtc = DateTime.UtcNow
};
// Send heartbeat via transport
if (_microserviceTransport is not null)
{
try
{
// Build heartbeat payload
var heartbeat = new HeartbeatPayload
{
InstanceId = _options.InstanceId,
Status = connection.Status,
TimestampUtc = DateTime.UtcNow
};
// Update last heartbeat time
connection.LastHeartbeatUtc = DateTime.UtcNow;
await _microserviceTransport.SendHeartbeatAsync(heartbeat, cancellationToken);
_logger.LogDebug(
"Sent heartbeat for connection {ConnectionId}",
connection.ConnectionId);
"Sent heartbeat: status={Status}, inflight={InFlight}, errorRate={ErrorRate:P1}",
heartbeat.Status,
heartbeat.InFlightRequestCount,
heartbeat.ErrorRate);
}
catch (Exception ex)
{
_logger.LogWarning(ex,
"Failed to send heartbeat for connection {ConnectionId}",
connection.ConnectionId);
_logger.LogWarning(ex, "Failed to send heartbeat");
}
}
// Update connection state local heartbeat times
foreach (var connection in _connections.Values)
{
connection.LastHeartbeatUtc = DateTime.UtcNow;
}
}
catch (OperationCanceledException)
{

View File

@@ -22,17 +22,34 @@ public static class ServiceCollectionExtensions
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(configure);
// Configure options
// Configure and register options as singleton
var options = new StellaMicroserviceOptions { ServiceName = "", Version = "1.0.0", Region = "" };
configure(options);
services.AddSingleton(options);
services.Configure(configure);
// Register endpoint discovery
services.TryAddSingleton<IEndpointDiscoveryProvider>(sp =>
// Register YAML loader and merger
services.TryAddSingleton<IMicroserviceYamlLoader, MicroserviceYamlLoader>();
services.TryAddSingleton<IEndpointOverrideMerger, EndpointOverrideMerger>();
// Register endpoint discovery provider (prefers generated over reflection)
services.TryAddSingleton<IEndpointDiscoveryProvider, GeneratedEndpointDiscoveryProvider>();
// Register endpoint discovery service (with YAML integration)
services.TryAddSingleton<IEndpointDiscoveryService, EndpointDiscoveryService>();
// Register endpoint registry (using discovery service)
services.TryAddSingleton<IEndpointRegistry>(sp =>
{
var options = new StellaMicroserviceOptions { ServiceName = "", Version = "1.0.0", Region = "" };
configure(options);
return new ReflectionEndpointDiscoveryProvider(options);
var discoveryService = sp.GetRequiredService<IEndpointDiscoveryService>();
var registry = new EndpointRegistry();
registry.RegisterAll(discoveryService.DiscoverEndpoints());
return registry;
});
// Register request dispatcher
services.TryAddSingleton<RequestDispatcher>();
// Register connection manager
services.TryAddSingleton<IRouterConnectionManager, RouterConnectionManager>();
@@ -57,12 +74,34 @@ public static class ServiceCollectionExtensions
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(configure);
// Configure options
// Configure and register options as singleton
var options = new StellaMicroserviceOptions { ServiceName = "", Version = "1.0.0", Region = "" };
configure(options);
services.AddSingleton(options);
services.Configure(configure);
// Register YAML loader and merger
services.TryAddSingleton<IMicroserviceYamlLoader, MicroserviceYamlLoader>();
services.TryAddSingleton<IEndpointOverrideMerger, EndpointOverrideMerger>();
// Register custom endpoint discovery
services.TryAddSingleton<IEndpointDiscoveryProvider, TDiscovery>();
// Register endpoint discovery service (with YAML integration)
services.TryAddSingleton<IEndpointDiscoveryService, EndpointDiscoveryService>();
// Register endpoint registry (using discovery service)
services.TryAddSingleton<IEndpointRegistry>(sp =>
{
var discoveryService = sp.GetRequiredService<IEndpointDiscoveryService>();
var registry = new EndpointRegistry();
registry.RegisterAll(discoveryService.DiscoverEndpoints());
return registry;
});
// Register request dispatcher
services.TryAddSingleton<RequestDispatcher>();
// Register connection manager
services.TryAddSingleton<IRouterConnectionManager, RouterConnectionManager>();
@@ -71,4 +110,17 @@ public static class ServiceCollectionExtensions
return services;
}
/// <summary>
/// Registers an endpoint handler type for dependency injection.
/// </summary>
/// <typeparam name="THandler">The endpoint handler type.</typeparam>
/// <param name="services">The service collection.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddStellaEndpoint<THandler>(this IServiceCollection services)
where THandler : class, IStellaEndpoint
{
services.AddScoped<THandler>();
return services;
}
}

View File

@@ -11,6 +11,7 @@
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="YamlDotNet" Version="13.7.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />

View File

@@ -0,0 +1,164 @@
using System.Threading.Channels;
namespace StellaOps.Microservice.Streaming;
/// <summary>
/// A read-only stream that reads from a channel of data chunks.
/// Used to expose streaming request body to handlers.
/// </summary>
public sealed class StreamingRequestBodyStream : Stream
{
private readonly ChannelReader<StreamChunk> _reader;
private readonly CancellationToken _cancellationToken;
private byte[] _currentBuffer = [];
private int _currentBufferPosition;
private bool _endOfStream;
private bool _disposed;
/// <summary>
/// Initializes a new instance of the <see cref="StreamingRequestBodyStream"/> class.
/// </summary>
/// <param name="reader">The channel reader for incoming chunks.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public StreamingRequestBodyStream(
ChannelReader<StreamChunk> reader,
CancellationToken cancellationToken)
{
_reader = reader;
_cancellationToken = cancellationToken;
}
/// <inheritdoc />
public override bool CanRead => true;
/// <inheritdoc />
public override bool CanSeek => false;
/// <inheritdoc />
public override bool CanWrite => false;
/// <inheritdoc />
public override long Length => throw new NotSupportedException("Streaming body length unknown.");
/// <inheritdoc />
public override long Position
{
get => throw new NotSupportedException("Streaming body position not supported.");
set => throw new NotSupportedException("Streaming body position not supported.");
}
/// <inheritdoc />
public override void Flush() { }
/// <inheritdoc />
public override int Read(byte[] buffer, int offset, int count)
{
return ReadAsync(buffer, offset, count, CancellationToken.None)
.GetAwaiter().GetResult();
}
/// <inheritdoc />
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return await ReadAsync(buffer.AsMemory(offset, count), cancellationToken);
}
/// <inheritdoc />
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (_endOfStream)
{
return 0;
}
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken);
// Try to use remaining data from current buffer first
if (_currentBufferPosition < _currentBuffer.Length)
{
var bytesToCopy = Math.Min(buffer.Length, _currentBuffer.Length - _currentBufferPosition);
_currentBuffer.AsSpan(_currentBufferPosition, bytesToCopy).CopyTo(buffer.Span);
_currentBufferPosition += bytesToCopy;
return bytesToCopy;
}
// Need to read next chunk from channel
if (!await _reader.WaitToReadAsync(linkedCts.Token))
{
_endOfStream = true;
return 0;
}
if (!_reader.TryRead(out var chunk))
{
_endOfStream = true;
return 0;
}
if (chunk.EndOfStream)
{
_endOfStream = true;
// Still process any data in the final chunk
if (chunk.Data.Length == 0)
{
return 0;
}
}
_currentBuffer = chunk.Data;
_currentBufferPosition = 0;
var bytesToReturn = Math.Min(buffer.Length, _currentBuffer.Length);
_currentBuffer.AsSpan(0, bytesToReturn).CopyTo(buffer.Span);
_currentBufferPosition = bytesToReturn;
return bytesToReturn;
}
/// <inheritdoc />
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException("Seeking not supported on streaming body.");
}
/// <inheritdoc />
public override void SetLength(long value)
{
throw new NotSupportedException("Setting length not supported on streaming body.");
}
/// <inheritdoc />
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException("Write not supported on streaming body.");
}
/// <inheritdoc />
protected override void Dispose(bool disposing)
{
_disposed = true;
base.Dispose(disposing);
}
}
/// <summary>
/// Represents a chunk of streaming data.
/// </summary>
public sealed record StreamChunk
{
/// <summary>
/// Gets the chunk data.
/// </summary>
public byte[] Data { get; init; } = [];
/// <summary>
/// Gets a value indicating whether this is the final chunk.
/// </summary>
public bool EndOfStream { get; init; }
/// <summary>
/// Gets the sequence number.
/// </summary>
public int SequenceNumber { get; init; }
}

View File

@@ -0,0 +1,191 @@
using System.Threading.Channels;
namespace StellaOps.Microservice.Streaming;
/// <summary>
/// A write-only stream that writes chunks to a channel.
/// Used to enable streaming response body from handlers.
/// </summary>
public sealed class StreamingResponseBodyStream : Stream
{
private readonly ChannelWriter<StreamChunk> _writer;
private readonly int _chunkSize;
private readonly CancellationToken _cancellationToken;
private byte[] _buffer;
private int _bufferPosition;
private int _sequenceNumber;
private bool _disposed;
/// <summary>
/// Initializes a new instance of the <see cref="StreamingResponseBodyStream"/> class.
/// </summary>
/// <param name="writer">The channel writer for outgoing chunks.</param>
/// <param name="chunkSize">The chunk size for buffered writes.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public StreamingResponseBodyStream(
ChannelWriter<StreamChunk> writer,
int chunkSize,
CancellationToken cancellationToken)
{
_writer = writer;
_chunkSize = chunkSize;
_cancellationToken = cancellationToken;
_buffer = new byte[chunkSize];
}
/// <inheritdoc />
public override bool CanRead => false;
/// <inheritdoc />
public override bool CanSeek => false;
/// <inheritdoc />
public override bool CanWrite => true;
/// <inheritdoc />
public override long Length => throw new NotSupportedException();
/// <inheritdoc />
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
/// <inheritdoc />
public override void Flush()
{
FlushAsync(CancellationToken.None).GetAwaiter().GetResult();
}
/// <inheritdoc />
public override async Task FlushAsync(CancellationToken cancellationToken)
{
if (_bufferPosition > 0)
{
var chunk = new StreamChunk
{
Data = _buffer[.._bufferPosition],
SequenceNumber = _sequenceNumber++,
EndOfStream = false
};
await _writer.WriteAsync(chunk, cancellationToken);
_buffer = new byte[_chunkSize];
_bufferPosition = 0;
}
}
/// <inheritdoc />
public override int Read(byte[] buffer, int offset, int count)
{
throw new NotSupportedException("Read not supported on streaming response body.");
}
/// <inheritdoc />
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException("Seeking not supported on streaming response body.");
}
/// <inheritdoc />
public override void SetLength(long value)
{
throw new NotSupportedException("Setting length not supported on streaming response body.");
}
/// <inheritdoc />
public override void Write(byte[] buffer, int offset, int count)
{
WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
}
/// <inheritdoc />
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await WriteAsync(buffer.AsMemory(offset, count), cancellationToken);
}
/// <inheritdoc />
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed, this);
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken);
var bytesWritten = 0;
while (bytesWritten < buffer.Length)
{
var spaceInBuffer = _chunkSize - _bufferPosition;
var bytesToWrite = Math.Min(spaceInBuffer, buffer.Length - bytesWritten);
buffer.Slice(bytesWritten, bytesToWrite).Span.CopyTo(_buffer.AsSpan(_bufferPosition));
_bufferPosition += bytesToWrite;
bytesWritten += bytesToWrite;
if (_bufferPosition >= _chunkSize)
{
await FlushAsync(linkedCts.Token);
}
}
}
/// <summary>
/// Completes the stream by flushing remaining data and sending end-of-stream signal.
/// </summary>
public async Task CompleteAsync(CancellationToken cancellationToken = default)
{
// Flush any remaining buffered data
await FlushAsync(cancellationToken);
// Send end-of-stream marker
var endChunk = new StreamChunk
{
Data = [],
SequenceNumber = _sequenceNumber++,
EndOfStream = true
};
await _writer.WriteAsync(endChunk, cancellationToken);
_writer.Complete();
}
/// <inheritdoc />
protected override void Dispose(bool disposing)
{
if (!_disposed && disposing)
{
// Try to complete the stream if not already completed
try
{
_writer.TryComplete();
}
catch
{
// Ignore errors during disposal
}
}
_disposed = true;
base.Dispose(disposing);
}
/// <inheritdoc />
public override async ValueTask DisposeAsync()
{
if (!_disposed)
{
try
{
await CompleteAsync(CancellationToken.None);
}
catch
{
// Ignore errors during disposal
}
}
_disposed = true;
await base.DisposeAsync();
}
}

View File

@@ -7,6 +7,38 @@ namespace StellaOps.Router.Common.Abstractions;
/// </summary>
public interface IGlobalRoutingState
{
/// <summary>
/// Adds a connection to the routing state.
/// </summary>
/// <param name="connection">The connection state to add.</param>
void AddConnection(ConnectionState connection);
/// <summary>
/// Removes a connection from the routing state.
/// </summary>
/// <param name="connectionId">The connection ID to remove.</param>
void RemoveConnection(string connectionId);
/// <summary>
/// Updates an existing connection's state.
/// </summary>
/// <param name="connectionId">The connection ID to update.</param>
/// <param name="update">The update action to apply.</param>
void UpdateConnection(string connectionId, Action<ConnectionState> update);
/// <summary>
/// Gets a connection by its ID.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <returns>The connection state, or null if not found.</returns>
ConnectionState? GetConnection(string connectionId);
/// <summary>
/// Gets all active connections.
/// </summary>
/// <returns>All active connections.</returns>
IReadOnlyList<ConnectionState> GetAllConnections();
/// <summary>
/// Resolves an HTTP request to an endpoint descriptor.
/// </summary>

View File

@@ -0,0 +1,43 @@
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Common.Abstractions;
/// <summary>
/// Represents a transport connection from a microservice to the gateway.
/// This interface is used by the Microservice SDK to communicate with the router.
/// </summary>
public interface IMicroserviceTransport
{
/// <summary>
/// Connects to the router and registers the microservice.
/// </summary>
/// <param name="instance">The instance descriptor.</param>
/// <param name="endpoints">The endpoints to register.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task ConnectAsync(
InstanceDescriptor instance,
IReadOnlyList<EndpointDescriptor> endpoints,
CancellationToken cancellationToken);
/// <summary>
/// Disconnects from the router.
/// </summary>
Task DisconnectAsync();
/// <summary>
/// Sends a heartbeat to the router.
/// </summary>
/// <param name="heartbeat">The heartbeat payload.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken);
/// <summary>
/// Event raised when a REQUEST frame is received from the gateway.
/// </summary>
event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
/// <summary>
/// Event raised when a CANCEL frame is received from the gateway.
/// </summary>
event Func<Guid, string?, Task>? OnCancelReceived;
}

View File

@@ -0,0 +1,148 @@
using System.Text.Json;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Common.Frames;
/// <summary>
/// Converts between generic Frame and typed frame records.
/// </summary>
public static class FrameConverter
{
private static readonly JsonSerializerOptions JsonOptions = new()
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
PropertyNameCaseInsensitive = true
};
/// <summary>
/// Converts a RequestFrame to a generic Frame for transport.
/// </summary>
public static Frame ToFrame(RequestFrame request)
{
var envelope = new RequestEnvelope
{
RequestId = request.RequestId,
Method = request.Method,
Path = request.Path,
Headers = request.Headers,
TimeoutSeconds = request.TimeoutSeconds,
SupportsStreaming = request.SupportsStreaming,
Payload = request.Payload.ToArray()
};
var envelopeBytes = JsonSerializer.SerializeToUtf8Bytes(envelope, JsonOptions);
return new Frame
{
Type = FrameType.Request,
CorrelationId = request.CorrelationId ?? request.RequestId,
Payload = envelopeBytes
};
}
/// <summary>
/// Converts a generic Frame to a RequestFrame.
/// </summary>
public static RequestFrame? ToRequestFrame(Frame frame)
{
if (frame.Type != FrameType.Request)
return null;
try
{
var envelope = JsonSerializer.Deserialize<RequestEnvelope>(frame.Payload.Span, JsonOptions);
if (envelope is null)
return null;
return new RequestFrame
{
RequestId = envelope.RequestId,
CorrelationId = frame.CorrelationId,
Method = envelope.Method,
Path = envelope.Path,
Headers = envelope.Headers ?? new Dictionary<string, string>(),
TimeoutSeconds = envelope.TimeoutSeconds,
SupportsStreaming = envelope.SupportsStreaming,
Payload = envelope.Payload ?? []
};
}
catch (JsonException)
{
return null;
}
}
/// <summary>
/// Converts a ResponseFrame to a generic Frame for transport.
/// </summary>
public static Frame ToFrame(ResponseFrame response)
{
var envelope = new ResponseEnvelope
{
RequestId = response.RequestId,
StatusCode = response.StatusCode,
Headers = response.Headers,
HasMoreChunks = response.HasMoreChunks,
Payload = response.Payload.ToArray()
};
var envelopeBytes = JsonSerializer.SerializeToUtf8Bytes(envelope, JsonOptions);
return new Frame
{
Type = FrameType.Response,
CorrelationId = response.RequestId,
Payload = envelopeBytes
};
}
/// <summary>
/// Converts a generic Frame to a ResponseFrame.
/// </summary>
public static ResponseFrame? ToResponseFrame(Frame frame)
{
if (frame.Type != FrameType.Response)
return null;
try
{
var envelope = JsonSerializer.Deserialize<ResponseEnvelope>(frame.Payload.Span, JsonOptions);
if (envelope is null)
return null;
return new ResponseFrame
{
RequestId = envelope.RequestId,
StatusCode = envelope.StatusCode,
Headers = envelope.Headers ?? new Dictionary<string, string>(),
HasMoreChunks = envelope.HasMoreChunks,
Payload = envelope.Payload ?? []
};
}
catch (JsonException)
{
return null;
}
}
private sealed class RequestEnvelope
{
public required string RequestId { get; set; }
public required string Method { get; set; }
public required string Path { get; set; }
public IReadOnlyDictionary<string, string>? Headers { get; set; }
public int TimeoutSeconds { get; set; } = 30;
public bool SupportsStreaming { get; set; }
public byte[]? Payload { get; set; }
}
private sealed class ResponseEnvelope
{
public required string RequestId { get; set; }
public int StatusCode { get; set; } = 200;
public IReadOnlyDictionary<string, string>? Headers { get; set; }
public bool HasMoreChunks { get; set; }
public byte[]? Payload { get; set; }
}
}

View File

@@ -0,0 +1,47 @@
namespace StellaOps.Router.Common.Frames;
/// <summary>
/// Represents a REQUEST frame sent from gateway to microservice.
/// </summary>
public sealed record RequestFrame
{
/// <summary>
/// Gets the unique request ID for this request.
/// </summary>
public required string RequestId { get; init; }
/// <summary>
/// Gets the correlation ID for distributed tracing.
/// </summary>
public string? CorrelationId { get; init; }
/// <summary>
/// Gets the HTTP method (GET, POST, PUT, DELETE, etc.).
/// </summary>
public required string Method { get; init; }
/// <summary>
/// Gets the request path.
/// </summary>
public required string Path { get; init; }
/// <summary>
/// Gets the request headers.
/// </summary>
public IReadOnlyDictionary<string, string> Headers { get; init; } = new Dictionary<string, string>();
/// <summary>
/// Gets the request payload (body).
/// </summary>
public ReadOnlyMemory<byte> Payload { get; init; }
/// <summary>
/// Gets the timeout in seconds for this request.
/// </summary>
public int TimeoutSeconds { get; init; } = 30;
/// <summary>
/// Gets whether this request supports streaming response.
/// </summary>
public bool SupportsStreaming { get; init; }
}

View File

@@ -0,0 +1,32 @@
namespace StellaOps.Router.Common.Frames;
/// <summary>
/// Represents a RESPONSE frame sent from microservice to gateway.
/// </summary>
public sealed record ResponseFrame
{
/// <summary>
/// Gets the request ID this response is for.
/// </summary>
public required string RequestId { get; init; }
/// <summary>
/// Gets the HTTP status code.
/// </summary>
public int StatusCode { get; init; } = 200;
/// <summary>
/// Gets the response headers.
/// </summary>
public IReadOnlyDictionary<string, string> Headers { get; init; } = new Dictionary<string, string>();
/// <summary>
/// Gets the response payload (body).
/// </summary>
public ReadOnlyMemory<byte> Payload { get; init; }
/// <summary>
/// Gets whether there are more streaming chunks to follow.
/// </summary>
public bool HasMoreChunks { get; init; }
}

View File

@@ -10,3 +10,34 @@ public sealed record CancelPayload
/// </summary>
public string? Reason { get; init; }
}
/// <summary>
/// Standard reasons for request cancellation.
/// </summary>
public static class CancelReasons
{
/// <summary>
/// The HTTP client disconnected before the request completed.
/// </summary>
public const string ClientDisconnected = "ClientDisconnected";
/// <summary>
/// The request exceeded its timeout.
/// </summary>
public const string Timeout = "Timeout";
/// <summary>
/// The request or response payload exceeded configured limits.
/// </summary>
public const string PayloadLimitExceeded = "PayloadLimitExceeded";
/// <summary>
/// The gateway or microservice is shutting down.
/// </summary>
public const string Shutdown = "Shutdown";
/// <summary>
/// The transport connection was closed unexpectedly.
/// </summary>
public const string ConnectionClosed = "ConnectionClosed";
}

View File

@@ -39,4 +39,10 @@ public sealed record EndpointDescriptor
/// Gets a value indicating whether this endpoint supports streaming.
/// </summary>
public bool SupportsStreaming { get; init; }
/// <summary>
/// Gets the handler type that processes requests for this endpoint.
/// This is used by the Microservice SDK for handler resolution.
/// </summary>
public Type? HandlerType { get; init; }
}

View File

@@ -0,0 +1,27 @@
namespace StellaOps.Router.Common.Models;
/// <summary>
/// Payload for streaming data frames (REQUEST_STREAM_DATA/RESPONSE_STREAM_DATA).
/// </summary>
public sealed record StreamDataPayload
{
/// <summary>
/// Gets the correlation ID linking stream data to the original request.
/// </summary>
public required Guid CorrelationId { get; init; }
/// <summary>
/// Gets the stream data chunk.
/// </summary>
public byte[] Data { get; init; } = [];
/// <summary>
/// Gets a value indicating whether this is the final chunk.
/// </summary>
public bool EndOfStream { get; init; }
/// <summary>
/// Gets the sequence number for ordering.
/// </summary>
public int SequenceNumber { get; init; }
}

View File

@@ -0,0 +1,36 @@
namespace StellaOps.Router.Common.Models;
/// <summary>
/// Configuration options for streaming operations.
/// </summary>
public sealed record StreamingOptions
{
/// <summary>
/// Gets the default streaming options.
/// </summary>
public static readonly StreamingOptions Default = new();
/// <summary>
/// Gets the size of each chunk when streaming data.
/// Default: 64 KB.
/// </summary>
public int ChunkSize { get; init; } = 64 * 1024;
/// <summary>
/// Gets the maximum number of concurrent streams per connection.
/// Default: 100.
/// </summary>
public int MaxConcurrentStreams { get; init; } = 100;
/// <summary>
/// Gets the timeout for idle streams (no data flowing).
/// Default: 5 minutes.
/// </summary>
public TimeSpan StreamIdleTimeout { get; init; } = TimeSpan.FromMinutes(5);
/// <summary>
/// Gets the channel capacity for buffered stream data.
/// Default: 16 chunks.
/// </summary>
public int ChannelCapacity { get; init; } = 16;
}

View File

@@ -0,0 +1,85 @@
using System.Text.RegularExpressions;
namespace StellaOps.Router.Common;
/// <summary>
/// Matches request paths against route templates.
/// </summary>
public sealed partial class PathMatcher
{
private readonly string _template;
private readonly Regex _regex;
private readonly string[] _parameterNames;
private readonly bool _caseInsensitive;
/// <summary>
/// Gets the route template.
/// </summary>
public string Template => _template;
/// <summary>
/// Initializes a new instance of the <see cref="PathMatcher"/> class.
/// </summary>
/// <param name="template">The route template (e.g., "/api/users/{id}").</param>
/// <param name="caseInsensitive">Whether matching should be case-insensitive.</param>
public PathMatcher(string template, bool caseInsensitive = true)
{
_template = template;
_caseInsensitive = caseInsensitive;
// Extract parameter names and build regex
var paramNames = new List<string>();
var pattern = "^" + ParameterRegex().Replace(template, match =>
{
paramNames.Add(match.Groups[1].Value);
return "([^/]+)";
}) + "/?$";
var options = caseInsensitive ? RegexOptions.IgnoreCase : RegexOptions.None;
_regex = new Regex(pattern, options | RegexOptions.Compiled);
_parameterNames = [.. paramNames];
}
/// <summary>
/// Tries to match a path against the template.
/// </summary>
/// <param name="path">The request path.</param>
/// <param name="parameters">The extracted path parameters if matched.</param>
/// <returns>True if the path matches.</returns>
public bool TryMatch(string path, out Dictionary<string, string> parameters)
{
parameters = [];
// Normalize path
path = path.TrimEnd('/');
if (!path.StartsWith('/'))
path = "/" + path;
var match = _regex.Match(path);
if (!match.Success)
return false;
for (int i = 0; i < _parameterNames.Length; i++)
{
parameters[_parameterNames[i]] = match.Groups[i + 1].Value;
}
return true;
}
/// <summary>
/// Checks if a path matches the template.
/// </summary>
/// <param name="path">The request path.</param>
/// <returns>True if the path matches.</returns>
public bool IsMatch(string path)
{
path = path.TrimEnd('/');
if (!path.StartsWith('/'))
path = "/" + path;
return _regex.IsMatch(path);
}
[GeneratedRegex(@"\{([^}:]+)(?::[^}]+)?\}")]
private static partial Regex ParameterRegex();
}

View File

@@ -0,0 +1,94 @@
namespace StellaOps.Router.Config;
/// <summary>
/// Provides access to router configuration with hot-reload support.
/// </summary>
public interface IRouterConfigProvider
{
/// <summary>
/// Gets the current router configuration.
/// </summary>
RouterConfig Current { get; }
/// <summary>
/// Gets the current router configuration options.
/// </summary>
RouterConfigOptions Options { get; }
/// <summary>
/// Raised when the configuration is reloaded.
/// </summary>
event EventHandler<ConfigChangedEventArgs>? ConfigurationChanged;
/// <summary>
/// Reloads the configuration from the source.
/// </summary>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>A task representing the reload operation.</returns>
Task ReloadAsync(CancellationToken cancellationToken = default);
/// <summary>
/// Validates the current configuration.
/// </summary>
/// <returns>Validation result.</returns>
ConfigValidationResult Validate();
}
/// <summary>
/// Event arguments for configuration changes.
/// </summary>
public sealed class ConfigChangedEventArgs : EventArgs
{
/// <summary>
/// Initializes a new instance of the <see cref="ConfigChangedEventArgs"/> class.
/// </summary>
/// <param name="previous">The previous configuration.</param>
/// <param name="current">The current configuration.</param>
public ConfigChangedEventArgs(RouterConfig previous, RouterConfig current)
{
Previous = previous;
Current = current;
ChangedAt = DateTime.UtcNow;
}
/// <summary>
/// Gets the previous configuration.
/// </summary>
public RouterConfig Previous { get; }
/// <summary>
/// Gets the current configuration.
/// </summary>
public RouterConfig Current { get; }
/// <summary>
/// Gets the time the configuration was changed.
/// </summary>
public DateTime ChangedAt { get; }
}
/// <summary>
/// Result of configuration validation.
/// </summary>
public sealed class ConfigValidationResult
{
/// <summary>
/// Gets whether the configuration is valid.
/// </summary>
public bool IsValid => Errors.Count == 0;
/// <summary>
/// Gets the validation errors.
/// </summary>
public List<string> Errors { get; init; } = [];
/// <summary>
/// Gets the validation warnings.
/// </summary>
public List<string> Warnings { get; init; } = [];
/// <summary>
/// A successful validation result.
/// </summary>
public static ConfigValidationResult Success => new();
}

View File

@@ -12,8 +12,18 @@ public sealed class RouterConfig
/// </summary>
public PayloadLimits PayloadLimits { get; set; } = new();
/// <summary>
/// Gets or sets the routing options.
/// </summary>
public RoutingOptions Routing { get; set; } = new();
/// <summary>
/// Gets or sets the service configurations.
/// </summary>
public List<ServiceConfig> Services { get; set; } = [];
/// <summary>
/// Gets or sets the static instance configurations.
/// </summary>
public List<StaticInstanceConfig> StaticInstances { get; set; } = [];
}

View File

@@ -0,0 +1,39 @@
namespace StellaOps.Router.Config;
/// <summary>
/// Options for the router configuration provider.
/// </summary>
public sealed class RouterConfigOptions
{
/// <summary>
/// Gets or sets the path to the router configuration file (YAML or JSON).
/// </summary>
public string? ConfigPath { get; set; }
/// <summary>
/// Gets or sets the environment variable prefix for overrides.
/// Default: "STELLAOPS_ROUTER_".
/// </summary>
public string EnvironmentVariablePrefix { get; set; } = "STELLAOPS_ROUTER_";
/// <summary>
/// Gets or sets whether to enable hot-reload of configuration.
/// </summary>
public bool EnableHotReload { get; set; } = true;
/// <summary>
/// Gets or sets the debounce interval for file change notifications.
/// </summary>
public TimeSpan DebounceInterval { get; set; } = TimeSpan.FromMilliseconds(500);
/// <summary>
/// Gets or sets whether to throw on configuration validation errors.
/// If false, keeps the previous valid configuration.
/// </summary>
public bool ThrowOnValidationError { get; set; } = false;
/// <summary>
/// Gets or sets the configuration section name in appsettings.json.
/// </summary>
public string ConfigurationSection { get; set; } = "Router";
}

View File

@@ -0,0 +1,321 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace StellaOps.Router.Config;
/// <summary>
/// Provides router configuration with hot-reload support.
/// </summary>
public sealed class RouterConfigProvider : IRouterConfigProvider, IDisposable
{
private readonly RouterConfigOptions _options;
private readonly ILogger<RouterConfigProvider> _logger;
private readonly FileSystemWatcher? _watcher;
private readonly SemaphoreSlim _reloadLock = new(1, 1);
private readonly Timer? _debounceTimer;
private RouterConfig _current;
private bool _disposed;
/// <inheritdoc />
public event EventHandler<ConfigChangedEventArgs>? ConfigurationChanged;
/// <summary>
/// Initializes a new instance of the <see cref="RouterConfigProvider"/> class.
/// </summary>
public RouterConfigProvider(
IOptions<RouterConfigOptions> options,
ILogger<RouterConfigProvider> logger)
{
_options = options.Value;
_logger = logger;
_current = LoadConfiguration();
if (_options.EnableHotReload && !string.IsNullOrEmpty(_options.ConfigPath) && File.Exists(_options.ConfigPath))
{
var directory = Path.GetDirectoryName(Path.GetFullPath(_options.ConfigPath))!;
var fileName = Path.GetFileName(_options.ConfigPath);
_watcher = new FileSystemWatcher(directory)
{
Filter = fileName,
NotifyFilter = NotifyFilters.LastWrite | NotifyFilters.Size
};
_debounceTimer = new Timer(OnDebounceElapsed, null, Timeout.Infinite, Timeout.Infinite);
_watcher.Changed += OnFileChanged;
_watcher.EnableRaisingEvents = true;
_logger.LogInformation("Hot-reload enabled for configuration file: {Path}", _options.ConfigPath);
}
}
/// <inheritdoc />
public RouterConfig Current => _current;
/// <inheritdoc />
public RouterConfigOptions Options => _options;
private void OnFileChanged(object sender, FileSystemEventArgs e)
{
// Debounce rapid file changes (e.g., from editors saving multiple times)
_debounceTimer?.Change(_options.DebounceInterval, Timeout.InfiniteTimeSpan);
}
private void OnDebounceElapsed(object? state)
{
_ = ReloadAsyncInternal();
}
private async Task ReloadAsyncInternal()
{
if (!await _reloadLock.WaitAsync(TimeSpan.Zero))
{
// Another reload is in progress
return;
}
try
{
var previous = _current;
var newConfig = LoadConfiguration();
var validation = ValidateConfig(newConfig);
if (!validation.IsValid)
{
if (_options.ThrowOnValidationError)
{
throw new ConfigurationException(
$"Configuration validation failed: {string.Join("; ", validation.Errors)}");
}
_logger.LogError(
"Configuration validation failed, keeping previous: {Errors}",
string.Join("; ", validation.Errors));
return;
}
foreach (var warning in validation.Warnings)
{
_logger.LogWarning("Configuration warning: {Warning}", warning);
}
_current = newConfig;
_logger.LogInformation("Router configuration reloaded successfully");
ConfigurationChanged?.Invoke(this, new ConfigChangedEventArgs(previous, newConfig));
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to reload configuration, keeping previous");
if (_options.ThrowOnValidationError)
{
throw;
}
}
finally
{
_reloadLock.Release();
}
}
/// <inheritdoc />
public async Task ReloadAsync(CancellationToken cancellationToken = default)
{
await _reloadLock.WaitAsync(cancellationToken);
try
{
var previous = _current;
var newConfig = LoadConfiguration();
var validation = ValidateConfig(newConfig);
if (!validation.IsValid)
{
throw new ConfigurationException(
$"Configuration validation failed: {string.Join("; ", validation.Errors)}");
}
_current = newConfig;
_logger.LogInformation("Router configuration reloaded successfully");
ConfigurationChanged?.Invoke(this, new ConfigChangedEventArgs(previous, newConfig));
}
finally
{
_reloadLock.Release();
}
}
/// <inheritdoc />
public ConfigValidationResult Validate() => ValidateConfig(_current);
private RouterConfig LoadConfiguration()
{
var builder = new ConfigurationBuilder();
// Load from YAML file if specified
if (!string.IsNullOrEmpty(_options.ConfigPath))
{
var extension = Path.GetExtension(_options.ConfigPath).ToLowerInvariant();
var fullPath = Path.GetFullPath(_options.ConfigPath);
if (File.Exists(fullPath))
{
switch (extension)
{
case ".yaml":
case ".yml":
builder.AddYamlFile(fullPath, optional: true, reloadOnChange: false);
break;
case ".json":
builder.AddJsonFile(fullPath, optional: true, reloadOnChange: false);
break;
default:
_logger.LogWarning("Unknown configuration file extension: {Extension}", extension);
break;
}
}
else
{
_logger.LogWarning("Configuration file not found: {Path}", fullPath);
}
}
// Add environment variable overrides
builder.AddEnvironmentVariables(prefix: _options.EnvironmentVariablePrefix);
var configuration = builder.Build();
var config = new RouterConfig();
configuration.Bind(config);
return config;
}
private static ConfigValidationResult ValidateConfig(RouterConfig config)
{
var result = new ConfigValidationResult();
// Validate payload limits
if (config.PayloadLimits.MaxRequestBytesPerCall <= 0)
{
result.Errors.Add("PayloadLimits.MaxRequestBytesPerCall must be positive");
}
if (config.PayloadLimits.MaxRequestBytesPerConnection <= 0)
{
result.Errors.Add("PayloadLimits.MaxRequestBytesPerConnection must be positive");
}
if (config.PayloadLimits.MaxAggregateInflightBytes <= 0)
{
result.Errors.Add("PayloadLimits.MaxAggregateInflightBytes must be positive");
}
if (config.PayloadLimits.MaxRequestBytesPerCall > config.PayloadLimits.MaxRequestBytesPerConnection)
{
result.Warnings.Add("MaxRequestBytesPerCall is larger than MaxRequestBytesPerConnection");
}
// Validate routing options
if (config.Routing.DefaultTimeout <= TimeSpan.Zero)
{
result.Errors.Add("Routing.DefaultTimeout must be positive");
}
// Validate services
var serviceNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
foreach (var service in config.Services)
{
if (string.IsNullOrWhiteSpace(service.ServiceName))
{
result.Errors.Add("Service name cannot be empty");
continue;
}
if (!serviceNames.Add(service.ServiceName))
{
result.Errors.Add($"Duplicate service name: {service.ServiceName}");
}
foreach (var endpoint in service.Endpoints)
{
if (string.IsNullOrWhiteSpace(endpoint.Method))
{
result.Errors.Add($"Service {service.ServiceName}: endpoint method cannot be empty");
}
if (string.IsNullOrWhiteSpace(endpoint.Path))
{
result.Errors.Add($"Service {service.ServiceName}: endpoint path cannot be empty");
}
if (endpoint.DefaultTimeout.HasValue && endpoint.DefaultTimeout.Value <= TimeSpan.Zero)
{
result.Warnings.Add(
$"Service {service.ServiceName}: endpoint {endpoint.Method} {endpoint.Path} has non-positive timeout");
}
}
}
// Validate static instances
foreach (var instance in config.StaticInstances)
{
if (string.IsNullOrWhiteSpace(instance.ServiceName))
{
result.Errors.Add("Static instance service name cannot be empty");
}
if (string.IsNullOrWhiteSpace(instance.Host))
{
result.Errors.Add($"Static instance {instance.ServiceName}: host cannot be empty");
}
if (instance.Port <= 0 || instance.Port > 65535)
{
result.Errors.Add($"Static instance {instance.ServiceName}: port must be between 1 and 65535");
}
if (instance.Weight <= 0)
{
result.Warnings.Add($"Static instance {instance.ServiceName}: weight should be positive");
}
}
return result;
}
/// <inheritdoc />
public void Dispose()
{
if (_disposed) return;
_disposed = true;
_watcher?.Dispose();
_debounceTimer?.Dispose();
_reloadLock.Dispose();
}
}
/// <summary>
/// Exception thrown when configuration is invalid.
/// </summary>
public sealed class ConfigurationException : Exception
{
/// <summary>
/// Initializes a new instance of the <see cref="ConfigurationException"/> class.
/// </summary>
public ConfigurationException(string message) : base(message)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="ConfigurationException"/> class.
/// </summary>
public ConfigurationException(string message, Exception innerException) : base(message, innerException)
{
}
}

View File

@@ -0,0 +1,58 @@
namespace StellaOps.Router.Config;
/// <summary>
/// Routing behavior options.
/// </summary>
public sealed class RoutingOptions
{
/// <summary>
/// Gets or sets the local region for routing preferences.
/// </summary>
public string LocalRegion { get; set; } = "default";
/// <summary>
/// Gets or sets the neighbor regions for fallback routing.
/// </summary>
public List<string> NeighborRegions { get; set; } = [];
/// <summary>
/// Gets or sets the tie-breaker strategy for equal-weight instances.
/// </summary>
public TieBreakerStrategy TieBreaker { get; set; } = TieBreakerStrategy.RoundRobin;
/// <summary>
/// Gets or sets whether to prefer local region instances.
/// </summary>
public bool PreferLocalRegion { get; set; } = true;
/// <summary>
/// Gets or sets the default request timeout.
/// </summary>
public TimeSpan DefaultTimeout { get; set; } = TimeSpan.FromSeconds(30);
}
/// <summary>
/// Tie-breaker strategy for routing decisions.
/// </summary>
public enum TieBreakerStrategy
{
/// <summary>
/// Round-robin between equal-weight instances.
/// </summary>
RoundRobin,
/// <summary>
/// Random selection between equal-weight instances.
/// </summary>
Random,
/// <summary>
/// Select the least-loaded instance.
/// </summary>
LeastLoaded,
/// <summary>
/// Consistent hashing based on request attributes.
/// </summary>
ConsistentHash
}

View File

@@ -0,0 +1,108 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
namespace StellaOps.Router.Config;
/// <summary>
/// Extension methods for registering router configuration services.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds router configuration services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configPath">Optional path to the configuration file.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRouterConfig(
this IServiceCollection services,
string? configPath = null)
{
return services.AddRouterConfig(options =>
{
if (!string.IsNullOrEmpty(configPath))
{
options.ConfigPath = configPath;
}
});
}
/// <summary>
/// Adds router configuration services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRouterConfig(
this IServiceCollection services,
Action<RouterConfigOptions> configure)
{
services.Configure(configure);
services.AddSingleton<IRouterConfigProvider, RouterConfigProvider>();
return services;
}
/// <summary>
/// Adds router configuration services to the service collection, binding from IConfiguration.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configuration">The configuration.</param>
/// <param name="sectionName">The configuration section name.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRouterConfig(
this IServiceCollection services,
IConfiguration configuration,
string sectionName = "Router")
{
var section = configuration.GetSection(sectionName);
services.Configure<RouterConfigOptions>(options =>
{
options.ConfigurationSection = sectionName;
});
services.Configure<RouterConfig>(section);
services.AddSingleton<IRouterConfigProvider, RouterConfigProvider>();
return services;
}
/// <summary>
/// Adds router configuration from a YAML file.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="yamlPath">Path to the YAML configuration file.</param>
/// <param name="enableHotReload">Whether to enable hot-reload.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRouterConfigFromYaml(
this IServiceCollection services,
string yamlPath,
bool enableHotReload = true)
{
return services.AddRouterConfig(options =>
{
options.ConfigPath = yamlPath;
options.EnableHotReload = enableHotReload;
});
}
/// <summary>
/// Adds router configuration from a JSON file.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="jsonPath">Path to the JSON configuration file.</param>
/// <param name="enableHotReload">Whether to enable hot-reload.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRouterConfigFromJson(
this IServiceCollection services,
string jsonPath,
bool enableHotReload = true)
{
return services.AddRouterConfig(options =>
{
options.ConfigPath = jsonPath;
options.EnableHotReload = enableHotReload;
});
}
}

View File

@@ -0,0 +1,49 @@
using StellaOps.Router.Common.Enums;
namespace StellaOps.Router.Config;
/// <summary>
/// Configuration for a statically-defined microservice instance.
/// </summary>
public sealed class StaticInstanceConfig
{
/// <summary>
/// Gets or sets the service name.
/// </summary>
public required string ServiceName { get; set; }
/// <summary>
/// Gets or sets the service version.
/// </summary>
public required string Version { get; set; }
/// <summary>
/// Gets or sets the region.
/// </summary>
public string Region { get; set; } = "default";
/// <summary>
/// Gets or sets the host name or IP address.
/// </summary>
public required string Host { get; set; }
/// <summary>
/// Gets or sets the port.
/// </summary>
public required int Port { get; set; }
/// <summary>
/// Gets or sets the transport type.
/// </summary>
public TransportType Transport { get; set; } = TransportType.Tcp;
/// <summary>
/// Gets or sets the instance weight for load balancing.
/// </summary>
public int Weight { get; set; } = 100;
/// <summary>
/// Gets or sets the instance metadata.
/// </summary>
public Dictionary<string, string> Metadata { get; set; } = [];
}

View File

@@ -5,8 +5,23 @@
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<RootNamespace>StellaOps.Router.Config</RootNamespace>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="NetEscapades.Configuration.Yaml" Version="2.1.0" />
<PackageReference Include="YamlDotNet" Version="13.7.1" />
</ItemGroup>
</Project>

View File

@@ -5,6 +5,7 @@ using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using static StellaOps.Router.Common.Models.CancelReasons;
namespace StellaOps.Router.Transport.InMemory;
@@ -12,12 +13,13 @@ namespace StellaOps.Router.Transport.InMemory;
/// In-memory transport client implementation for testing and development.
/// Used by the Microservice SDK to send frames to the Gateway.
/// </summary>
public sealed class InMemoryTransportClient : ITransportClient, IDisposable
public sealed class InMemoryTransportClient : ITransportClient, IMicroserviceTransport, IDisposable
{
private readonly InMemoryConnectionRegistry _registry;
private readonly InMemoryTransportOptions _options;
private readonly ILogger<InMemoryTransportClient> _logger;
private readonly ConcurrentDictionary<string, TaskCompletionSource<Frame>> _pendingRequests = new();
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
private readonly CancellationTokenSource _clientCts = new();
private bool _disposed;
private string? _connectionId;
@@ -172,29 +174,54 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
return;
}
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
// Create a linked CancellationTokenSource for this handler
// This allows cancellation via CANCEL frame or transport shutdown
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_inflightHandlers[correlationId] = handlerCts;
try
{
var response = await OnRequestReceived(frame, cancellationToken);
var response = await OnRequestReceived(frame, handlerCts.Token);
// Ensure response has same correlation ID
var responseFrame = response with { CorrelationId = frame.CorrelationId };
await channel.ToGateway.Writer.WriteAsync(responseFrame, cancellationToken);
var responseFrame = response with { CorrelationId = correlationId };
// Only send response if not cancelled
if (!handlerCts.Token.IsCancellationRequested)
{
await channel.ToGateway.Writer.WriteAsync(responseFrame, cancellationToken);
}
else
{
_logger.LogDebug("Not sending response for cancelled request {CorrelationId}", correlationId);
}
}
catch (OperationCanceledException)
{
_logger.LogDebug("Request {CorrelationId} was cancelled", frame.CorrelationId);
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error handling request {CorrelationId}", frame.CorrelationId);
// Send error response
var errorFrame = new Frame
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
// Only send error response if not cancelled
if (!handlerCts.Token.IsCancellationRequested)
{
Type = FrameType.Response,
CorrelationId = frame.CorrelationId,
Payload = ReadOnlyMemory<byte>.Empty
};
await channel.ToGateway.Writer.WriteAsync(errorFrame, cancellationToken);
var errorFrame = new Frame
{
Type = FrameType.Response,
CorrelationId = correlationId,
Payload = ReadOnlyMemory<byte>.Empty
};
await channel.ToGateway.Writer.WriteAsync(errorFrame, cancellationToken);
}
}
finally
{
// Remove from inflight tracking
_inflightHandlers.TryRemove(correlationId, out _);
}
}
@@ -204,13 +231,27 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
_logger.LogDebug("Received CANCEL for correlation {CorrelationId}", frame.CorrelationId);
// Cancel the inflight handler via its CancellationTokenSource
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var handlerCts))
{
try
{
handlerCts.Cancel();
_logger.LogInformation("Cancelled handler for request {CorrelationId}", frame.CorrelationId);
}
catch (ObjectDisposedException)
{
// Handler already completed
}
}
// Complete any pending request with cancellation
if (_pendingRequests.TryRemove(frame.CorrelationId, out var tcs))
{
tcs.TrySetCanceled();
}
// Notify handler
// Notify external handler (for custom cancellation logic)
if (OnCancelReceived is not null && Guid.TryParse(frame.CorrelationId, out var correlationGuid))
{
_ = OnCancelReceived(correlationGuid, null);
@@ -381,6 +422,33 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
await channel.ToGateway.Writer.WriteAsync(frame, cancellationToken);
}
/// <summary>
/// Cancels all in-flight handler requests.
/// Called when connection is closed or transport is shutting down.
/// </summary>
/// <param name="reason">The reason for cancellation.</param>
public void CancelAllInflight(string reason)
{
var count = 0;
foreach (var kvp in _inflightHandlers)
{
try
{
kvp.Value.Cancel();
count++;
}
catch (ObjectDisposedException)
{
// Already completed/disposed
}
}
if (count > 0)
{
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
}
}
/// <summary>
/// Disconnects from the transport.
/// </summary>
@@ -388,6 +456,9 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
{
if (_connectionId is null) return;
// Cancel all inflight handlers before disconnecting
CancelAllInflight(CancelReasons.Shutdown);
await _clientCts.CancelAsync();
if (_receiveTask is not null)
@@ -407,6 +478,9 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
if (_disposed) return;
_disposed = true;
// Cancel all inflight handlers
CancelAllInflight(Shutdown);
_clientCts.Cancel();
foreach (var tcs in _pendingRequests.Values)
@@ -414,6 +488,7 @@ public sealed class InMemoryTransportClient : ITransportClient, IDisposable
tcs.TrySetCanceled();
}
_pendingRequests.Clear();
_inflightHandlers.Clear();
if (_connectionId is not null)
{

View File

@@ -35,6 +35,7 @@ public static class ServiceCollectionExtensions
// Register interfaces
services.TryAddSingleton<ITransportServer>(sp => sp.GetRequiredService<InMemoryTransportServer>());
services.TryAddSingleton<ITransportClient>(sp => sp.GetRequiredService<InMemoryTransportClient>());
services.TryAddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<InMemoryTransportClient>());
return services;
}
@@ -81,6 +82,7 @@ public static class ServiceCollectionExtensions
services.TryAddSingleton<InMemoryConnectionRegistry>();
services.TryAddSingleton<InMemoryTransportClient>();
services.TryAddSingleton<ITransportClient>(sp => sp.GetRequiredService<InMemoryTransportClient>());
services.TryAddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<InMemoryTransportClient>());
return services;
}

View File

@@ -0,0 +1,111 @@
using System.Text;
using RabbitMQ.Client;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.RabbitMq;
/// <summary>
/// Handles serialization and deserialization of frames for RabbitMQ transport.
/// </summary>
public static class RabbitMqFrameProtocol
{
/// <summary>
/// Parses a frame from a RabbitMQ message.
/// </summary>
/// <param name="body">The message body.</param>
/// <param name="properties">The message properties.</param>
/// <returns>The parsed frame.</returns>
public static Frame ParseFrame(ReadOnlyMemory<byte> body, IReadOnlyBasicProperties properties)
{
var frameType = ParseFrameType(properties.Type);
var correlationId = properties.CorrelationId;
return new Frame
{
Type = frameType,
CorrelationId = correlationId,
Payload = body
};
}
/// <summary>
/// Creates BasicProperties for a frame.
/// </summary>
/// <param name="frame">The frame to serialize.</param>
/// <param name="replyTo">The reply queue name.</param>
/// <param name="timeout">Optional timeout for the message.</param>
/// <returns>The basic properties.</returns>
public static BasicProperties CreateProperties(Frame frame, string? replyTo, TimeSpan? timeout = null)
{
var props = new BasicProperties
{
Type = frame.Type.ToString(),
Timestamp = new AmqpTimestamp(DateTimeOffset.UtcNow.ToUnixTimeSeconds()),
DeliveryMode = DeliveryModes.Transient // Non-persistent (1)
};
if (!string.IsNullOrEmpty(frame.CorrelationId))
{
props.CorrelationId = frame.CorrelationId;
}
if (!string.IsNullOrEmpty(replyTo))
{
props.ReplyTo = replyTo;
}
if (timeout.HasValue)
{
props.Expiration = ((int)timeout.Value.TotalMilliseconds).ToString();
}
return props;
}
/// <summary>
/// Parses a FrameType from the message Type property.
/// </summary>
private static FrameType ParseFrameType(string? type)
{
if (string.IsNullOrEmpty(type))
{
return FrameType.Request;
}
if (Enum.TryParse<FrameType>(type, ignoreCase: true, out var result))
{
return result;
}
return FrameType.Request;
}
/// <summary>
/// Extracts the connection ID from message properties.
/// </summary>
/// <param name="properties">The message properties.</param>
/// <returns>The connection ID.</returns>
public static string ExtractConnectionId(IReadOnlyBasicProperties properties)
{
// Use ReplyTo as the basis for connection ID (identifies the instance)
if (!string.IsNullOrEmpty(properties.ReplyTo))
{
// Extract instance ID from queue name like "stella.svc.{instanceId}"
var parts = properties.ReplyTo.Split('.');
if (parts.Length >= 3)
{
return $"rmq-{parts[^1]}";
}
return $"rmq-{properties.ReplyTo}";
}
// Fallback to correlation ID
if (!string.IsNullOrEmpty(properties.CorrelationId))
{
return $"rmq-{properties.CorrelationId[..Math.Min(16, properties.CorrelationId.Length)]}";
}
return $"rmq-{Guid.NewGuid():N}"[..32];
}
}

View File

@@ -0,0 +1,449 @@
using System.Collections.Concurrent;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.RabbitMq;
/// <summary>
/// RabbitMQ transport client implementation for microservices.
/// </summary>
public sealed class RabbitMqTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
{
private readonly RabbitMqTransportOptions _options;
private readonly ILogger<RabbitMqTransportClient> _logger;
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<Frame>> _pendingRequests = new();
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
private readonly CancellationTokenSource _clientCts = new();
private IConnection? _connection;
private IChannel? _channel;
private string? _responseQueueName;
private string? _instanceId;
private string? _gatewayNodeId;
private bool _disposed;
/// <summary>
/// Event raised when a REQUEST frame is received.
/// </summary>
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
/// <summary>
/// Event raised when a CANCEL frame is received.
/// </summary>
public event Func<Guid, string?, Task>? OnCancelReceived;
/// <summary>
/// Initializes a new instance of the <see cref="RabbitMqTransportClient"/> class.
/// </summary>
public RabbitMqTransportClient(
IOptions<RabbitMqTransportOptions> options,
ILogger<RabbitMqTransportClient> logger)
{
_options = options.Value;
_logger = logger;
}
/// <summary>
/// Connects to the gateway via RabbitMQ.
/// </summary>
/// <param name="instance">The instance descriptor.</param>
/// <param name="endpoints">The endpoints to register.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task ConnectAsync(
InstanceDescriptor instance,
IReadOnlyList<EndpointDescriptor> endpoints,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
_instanceId = _options.InstanceId ?? instance.InstanceId;
_gatewayNodeId = _options.NodeId ?? "default";
var factory = new ConnectionFactory
{
HostName = _options.HostName,
Port = _options.Port,
VirtualHost = _options.VirtualHost,
UserName = _options.UserName,
Password = _options.Password,
AutomaticRecoveryEnabled = _options.AutomaticRecoveryEnabled,
NetworkRecoveryInterval = _options.NetworkRecoveryInterval
};
if (_options.UseSsl)
{
factory.Ssl = new SslOption
{
Enabled = true,
ServerName = _options.HostName,
CertPath = _options.SslCertPath
};
}
_connection = await factory.CreateConnectionAsync(cancellationToken);
_channel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken);
// Set QoS
await _channel.BasicQosAsync(
prefetchSize: 0,
prefetchCount: _options.PrefetchCount,
global: false,
cancellationToken: cancellationToken);
// Declare exchanges (should already exist from server, but declare for safety)
await _channel.ExchangeDeclareAsync(
exchange: _options.RequestExchange,
type: ExchangeType.Direct,
durable: true,
autoDelete: false,
cancellationToken: cancellationToken);
await _channel.ExchangeDeclareAsync(
exchange: _options.ResponseExchange,
type: ExchangeType.Topic,
durable: true,
autoDelete: false,
cancellationToken: cancellationToken);
// Declare response queue for this instance
_responseQueueName = $"{_options.QueuePrefix}.svc.{_instanceId}";
await _channel.QueueDeclareAsync(
queue: _responseQueueName,
durable: _options.DurableQueues,
exclusive: false,
autoDelete: _options.AutoDeleteQueues,
cancellationToken: cancellationToken);
// Bind to response exchange with instance ID as routing key
await _channel.QueueBindAsync(
queue: _responseQueueName,
exchange: _options.ResponseExchange,
routingKey: _instanceId,
cancellationToken: cancellationToken);
// Start consuming responses
var consumer = new AsyncEventingBasicConsumer(_channel);
consumer.ReceivedAsync += OnMessageReceivedAsync;
await _channel.BasicConsumeAsync(
queue: _responseQueueName,
autoAck: true,
consumer: consumer,
cancellationToken: cancellationToken);
// Send HELLO frame
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await SendToGatewayAsync(helloFrame, cancellationToken);
_logger.LogInformation(
"Connected to RabbitMQ gateway at {Host}:{Port} as {ServiceName}/{Version}",
_options.HostName,
_options.Port,
instance.ServiceName,
instance.Version);
}
private async Task OnMessageReceivedAsync(object sender, BasicDeliverEventArgs e)
{
try
{
var frame = RabbitMqFrameProtocol.ParseFrame(e.Body, e.BasicProperties);
switch (frame.Type)
{
case FrameType.Request:
await HandleRequestFrameAsync(frame, _clientCts.Token);
break;
case FrameType.Cancel:
HandleCancelFrame(frame);
break;
case FrameType.Response:
if (frame.CorrelationId is not null &&
Guid.TryParse(frame.CorrelationId, out var correlationId))
{
if (_pendingRequests.TryRemove(correlationId, out var tcs))
{
tcs.TrySetResult(frame);
}
}
break;
default:
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
break;
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error processing RabbitMQ message");
}
await Task.CompletedTask;
}
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
{
if (OnRequestReceived is null)
{
_logger.LogWarning("No request handler registered");
return;
}
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_inflightHandlers[correlationId] = handlerCts;
try
{
var response = await OnRequestReceived(frame, handlerCts.Token);
var responseFrame = response with { CorrelationId = correlationId };
if (!handlerCts.Token.IsCancellationRequested)
{
await SendToGatewayAsync(responseFrame, cancellationToken);
}
}
catch (OperationCanceledException)
{
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
}
finally
{
_inflightHandlers.TryRemove(correlationId, out _);
}
}
private void HandleCancelFrame(Frame frame)
{
if (frame.CorrelationId is null) return;
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
{
try
{
cts.Cancel();
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (Guid.TryParse(frame.CorrelationId, out var guid))
{
if (_pendingRequests.TryRemove(guid, out var tcs))
{
tcs.TrySetCanceled();
}
OnCancelReceived?.Invoke(guid, null);
}
}
private async Task SendToGatewayAsync(Frame frame, CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var properties = RabbitMqFrameProtocol.CreateProperties(
frame,
_responseQueueName,
_options.DefaultTimeout);
await _channel!.BasicPublishAsync(
exchange: _options.RequestExchange,
routingKey: _gatewayNodeId!,
mandatory: false,
basicProperties: properties,
body: frame.Payload,
cancellationToken: cancellationToken);
}
/// <inheritdoc />
public async Task<Frame> SendRequestAsync(
ConnectionState connection,
Frame requestFrame,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var correlationId = requestFrame.CorrelationId is not null &&
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
var registration = timeoutCts.Token.Register(() =>
{
if (_pendingRequests.TryRemove(correlationId, out var pendingTcs))
{
pendingTcs.TrySetCanceled(timeoutCts.Token);
}
});
_pendingRequests[correlationId] = tcs;
try
{
await SendToGatewayAsync(framedRequest, timeoutCts.Token);
return await tcs.Task;
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
}
finally
{
await registration.DisposeAsync();
_pendingRequests.TryRemove(correlationId, out _);
}
}
/// <inheritdoc />
public async Task SendCancelAsync(
ConnectionState connection,
Guid correlationId,
string? reason = null)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var cancelFrame = new Frame
{
Type = FrameType.Cancel,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await SendToGatewayAsync(cancelFrame, CancellationToken.None);
_logger.LogDebug("Sent CANCEL for {CorrelationId}", correlationId);
}
/// <inheritdoc />
public async Task SendStreamingAsync(
ConnectionState connection,
Frame requestHeader,
Stream requestBody,
Func<Stream, Task> readResponseBody,
PayloadLimits limits,
CancellationToken cancellationToken)
{
// Streaming could be implemented by chunking messages, but for now we don't support it
// This keeps RabbitMQ transport simple
throw new NotSupportedException(
"RabbitMQ transport does not currently support streaming. Use TCP or TLS transport for streaming.");
}
/// <summary>
/// Sends a heartbeat.
/// </summary>
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
{
var frame = new Frame
{
Type = FrameType.Heartbeat,
CorrelationId = null,
Payload = ReadOnlyMemory<byte>.Empty
};
await SendToGatewayAsync(frame, cancellationToken);
}
/// <summary>
/// Cancels all in-flight handlers.
/// </summary>
public void CancelAllInflight(string reason)
{
var count = 0;
foreach (var cts in _inflightHandlers.Values)
{
try
{
cts.Cancel();
count++;
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (count > 0)
{
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
}
}
/// <summary>
/// Disconnects from the gateway.
/// </summary>
public async Task DisconnectAsync()
{
CancelAllInflight("Shutdown");
// Cancel all pending requests
foreach (var kvp in _pendingRequests)
{
if (_pendingRequests.TryRemove(kvp.Key, out var tcs))
{
tcs.TrySetCanceled();
}
}
await _clientCts.CancelAsync();
if (_channel is not null)
{
await _channel.CloseAsync();
}
if (_connection is not null)
{
await _connection.CloseAsync();
}
_logger.LogInformation("Disconnected from RabbitMQ gateway");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await DisconnectAsync();
if (_channel is not null)
{
await _channel.DisposeAsync();
}
if (_connection is not null)
{
await _connection.DisposeAsync();
}
_clientCts.Dispose();
}
}

View File

@@ -0,0 +1,102 @@
namespace StellaOps.Router.Transport.RabbitMq;
/// <summary>
/// Options for RabbitMQ transport configuration.
/// </summary>
public sealed class RabbitMqTransportOptions
{
/// <summary>
/// Gets or sets the RabbitMQ host name.
/// </summary>
public string HostName { get; set; } = "localhost";
/// <summary>
/// Gets or sets the RabbitMQ port.
/// </summary>
public int Port { get; set; } = 5672;
/// <summary>
/// Gets or sets the RabbitMQ virtual host.
/// </summary>
public string VirtualHost { get; set; } = "/";
/// <summary>
/// Gets or sets the RabbitMQ username.
/// </summary>
public string UserName { get; set; } = "guest";
/// <summary>
/// Gets or sets the RabbitMQ password.
/// </summary>
public string Password { get; set; } = "guest";
/// <summary>
/// Gets or sets whether to use SSL/TLS.
/// </summary>
public bool UseSsl { get; set; } = false;
/// <summary>
/// Gets or sets the SSL certificate path.
/// </summary>
public string? SslCertPath { get; set; }
/// <summary>
/// Gets or sets whether queues should be durable.
/// </summary>
public bool DurableQueues { get; set; } = false;
/// <summary>
/// Gets or sets whether queues should auto-delete on disconnect.
/// </summary>
public bool AutoDeleteQueues { get; set; } = true;
/// <summary>
/// Gets or sets the prefetch count (concurrent messages).
/// </summary>
public ushort PrefetchCount { get; set; } = 10;
/// <summary>
/// Gets or sets the exchange prefix.
/// </summary>
public string ExchangePrefix { get; set; } = "stella.router";
/// <summary>
/// Gets or sets the queue prefix.
/// </summary>
public string QueuePrefix { get; set; } = "stella";
/// <summary>
/// Gets or sets the request exchange name.
/// </summary>
public string RequestExchange => $"{ExchangePrefix}.requests";
/// <summary>
/// Gets or sets the response exchange name.
/// </summary>
public string ResponseExchange => $"{ExchangePrefix}.responses";
/// <summary>
/// Gets or sets the node ID for this gateway instance.
/// </summary>
public string? NodeId { get; set; }
/// <summary>
/// Gets or sets the instance ID for this microservice instance.
/// </summary>
public string? InstanceId { get; set; }
/// <summary>
/// Gets or sets whether to use automatic recovery.
/// </summary>
public bool AutomaticRecoveryEnabled { get; set; } = true;
/// <summary>
/// Gets or sets the network recovery interval.
/// </summary>
public TimeSpan NetworkRecoveryInterval { get; set; } = TimeSpan.FromSeconds(5);
/// <summary>
/// Gets or sets the default request timeout.
/// </summary>
public TimeSpan DefaultTimeout { get; set; } = TimeSpan.FromSeconds(30);
}

View File

@@ -0,0 +1,289 @@
using System.Collections.Concurrent;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.RabbitMq;
/// <summary>
/// RabbitMQ transport server implementation for the gateway.
/// </summary>
public sealed class RabbitMqTransportServer : ITransportServer, IAsyncDisposable
{
private readonly RabbitMqTransportOptions _options;
private readonly ILogger<RabbitMqTransportServer> _logger;
private readonly ConcurrentDictionary<string, (string ReplyTo, ConnectionState State)> _connections = new();
private readonly string _nodeId;
private IConnection? _connection;
private IChannel? _channel;
private string? _requestQueueName;
private bool _disposed;
/// <summary>
/// Event raised when a connection is established (on first HELLO).
/// </summary>
public event Action<string, ConnectionState>? OnConnection;
/// <summary>
/// Event raised when a connection is lost.
/// </summary>
public event Action<string>? OnDisconnection;
/// <summary>
/// Event raised when a frame is received.
/// </summary>
public event Action<string, Frame>? OnFrame;
/// <summary>
/// Initializes a new instance of the <see cref="RabbitMqTransportServer"/> class.
/// </summary>
public RabbitMqTransportServer(
IOptions<RabbitMqTransportOptions> options,
ILogger<RabbitMqTransportServer> logger)
{
_options = options.Value;
_logger = logger;
_nodeId = _options.NodeId ?? Guid.NewGuid().ToString("N")[..8];
}
/// <inheritdoc />
public async Task StartAsync(CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var factory = new ConnectionFactory
{
HostName = _options.HostName,
Port = _options.Port,
VirtualHost = _options.VirtualHost,
UserName = _options.UserName,
Password = _options.Password,
AutomaticRecoveryEnabled = _options.AutomaticRecoveryEnabled,
NetworkRecoveryInterval = _options.NetworkRecoveryInterval
};
if (_options.UseSsl)
{
factory.Ssl = new SslOption
{
Enabled = true,
ServerName = _options.HostName,
CertPath = _options.SslCertPath
};
}
_connection = await factory.CreateConnectionAsync(cancellationToken);
_channel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken);
// Set QoS (prefetch count)
await _channel.BasicQosAsync(
prefetchSize: 0,
prefetchCount: _options.PrefetchCount,
global: false,
cancellationToken: cancellationToken);
// Declare exchanges
await _channel.ExchangeDeclareAsync(
exchange: _options.RequestExchange,
type: ExchangeType.Direct,
durable: true,
autoDelete: false,
cancellationToken: cancellationToken);
await _channel.ExchangeDeclareAsync(
exchange: _options.ResponseExchange,
type: ExchangeType.Topic,
durable: true,
autoDelete: false,
cancellationToken: cancellationToken);
// Declare and bind request queue
_requestQueueName = $"{_options.QueuePrefix}.gw.{_nodeId}.in";
await _channel.QueueDeclareAsync(
queue: _requestQueueName,
durable: _options.DurableQueues,
exclusive: false,
autoDelete: _options.AutoDeleteQueues,
cancellationToken: cancellationToken);
await _channel.QueueBindAsync(
queue: _requestQueueName,
exchange: _options.RequestExchange,
routingKey: _nodeId,
cancellationToken: cancellationToken);
// Start consuming
var consumer = new AsyncEventingBasicConsumer(_channel);
consumer.ReceivedAsync += OnMessageReceivedAsync;
await _channel.BasicConsumeAsync(
queue: _requestQueueName,
autoAck: true, // At-most-once delivery
consumer: consumer,
cancellationToken: cancellationToken);
_logger.LogInformation(
"RabbitMQ transport server started, consuming from {Queue}",
_requestQueueName);
}
private async Task OnMessageReceivedAsync(object sender, BasicDeliverEventArgs e)
{
try
{
var frame = RabbitMqFrameProtocol.ParseFrame(e.Body, e.BasicProperties);
var connectionId = RabbitMqFrameProtocol.ExtractConnectionId(e.BasicProperties);
var replyTo = e.BasicProperties.ReplyTo ?? string.Empty;
// Handle HELLO specially to register connection
if (frame.Type == FrameType.Hello && !_connections.ContainsKey(connectionId))
{
var state = new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = connectionId,
ServiceName = "unknown",
Version = "1.0.0",
Region = "default"
},
Status = InstanceHealthStatus.Healthy,
LastHeartbeatUtc = DateTime.UtcNow,
TransportType = TransportType.RabbitMq
};
_connections[connectionId] = (replyTo, state);
_logger.LogInformation(
"RabbitMQ connection established: {ConnectionId} with replyTo {ReplyTo}",
connectionId,
replyTo);
OnConnection?.Invoke(connectionId, state);
}
// Update heartbeat timestamp on HEARTBEAT frames
if (frame.Type == FrameType.Heartbeat &&
_connections.TryGetValue(connectionId, out var conn))
{
conn.State.LastHeartbeatUtc = DateTime.UtcNow;
}
OnFrame?.Invoke(connectionId, frame);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error processing RabbitMQ message");
}
await Task.CompletedTask;
}
/// <summary>
/// Sends a frame to a connection.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <param name="frame">The frame to send.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task SendFrameAsync(
string connectionId,
Frame frame,
CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (!_connections.TryGetValue(connectionId, out var conn))
{
throw new InvalidOperationException($"Connection {connectionId} not found");
}
var properties = RabbitMqFrameProtocol.CreateProperties(frame, null, _options.DefaultTimeout);
// Send to response exchange with instance ID as routing key
var routingKey = conn.ReplyTo.Split('.')[^1]; // Extract instance ID from queue name
await _channel!.BasicPublishAsync(
exchange: _options.ResponseExchange,
routingKey: routingKey,
mandatory: false,
basicProperties: properties,
body: frame.Payload,
cancellationToken: cancellationToken);
}
/// <summary>
/// Gets the connection state by ID.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <returns>The connection state, or null if not found.</returns>
public ConnectionState? GetConnectionState(string connectionId)
{
return _connections.TryGetValue(connectionId, out var conn) ? conn.State : null;
}
/// <summary>
/// Gets all active connections.
/// </summary>
public IEnumerable<ConnectionState> GetConnections() =>
_connections.Values.Select(c => c.State);
/// <summary>
/// Gets the number of active connections.
/// </summary>
public int ConnectionCount => _connections.Count;
/// <summary>
/// Removes a connection.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
public void RemoveConnection(string connectionId)
{
if (_connections.TryRemove(connectionId, out _))
{
_logger.LogInformation("RabbitMQ connection removed: {ConnectionId}", connectionId);
OnDisconnection?.Invoke(connectionId);
}
}
/// <inheritdoc />
public async Task StopAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("Stopping RabbitMQ transport server");
if (_channel is not null)
{
await _channel.CloseAsync(cancellationToken);
}
if (_connection is not null)
{
await _connection.CloseAsync(cancellationToken);
}
_connections.Clear();
_logger.LogInformation("RabbitMQ transport server stopped");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await StopAsync(CancellationToken.None);
if (_channel is not null)
{
await _channel.DisposeAsync();
}
if (_connection is not null)
{
await _connection.DisposeAsync();
}
}
}

View File

@@ -0,0 +1,53 @@
using Microsoft.Extensions.DependencyInjection;
using StellaOps.Router.Common.Abstractions;
namespace StellaOps.Router.Transport.RabbitMq;
/// <summary>
/// Extension methods for registering RabbitMQ transport services.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds RabbitMQ transport server services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Optional configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRabbitMqTransportServer(
this IServiceCollection services,
Action<RabbitMqTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<RabbitMqTransportServer>();
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<RabbitMqTransportServer>());
return services;
}
/// <summary>
/// Adds RabbitMQ transport client services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Optional configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddRabbitMqTransportClient(
this IServiceCollection services,
Action<RabbitMqTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<RabbitMqTransportClient>();
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<RabbitMqTransportClient>());
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<RabbitMqTransportClient>());
return services;
}
}

View File

@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>preview</LangVersion>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<RootNamespace>StellaOps.Router.Transport.RabbitMq</RootNamespace>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="RabbitMQ.Client" Version="7.0.0" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,144 @@
using System.Buffers.Binary;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// Handles reading and writing length-prefixed frames over a stream.
/// Frame format: [4-byte big-endian length][payload]
/// Payload format: [1-byte frame type][16-byte correlation GUID][remaining data]
/// </summary>
public static class FrameProtocol
{
private const int LengthPrefixSize = 4;
private const int FrameTypeSize = 1;
private const int CorrelationIdSize = 16;
private const int HeaderSize = FrameTypeSize + CorrelationIdSize;
/// <summary>
/// Reads a complete frame from the stream.
/// </summary>
/// <param name="stream">The stream to read from.</param>
/// <param name="maxFrameSize">The maximum frame size allowed.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>The frame read, or null if the stream is closed.</returns>
public static async Task<Frame?> ReadFrameAsync(
Stream stream,
int maxFrameSize,
CancellationToken cancellationToken)
{
// Read length prefix (4 bytes, big-endian)
var lengthBuffer = new byte[LengthPrefixSize];
var bytesRead = await ReadExactAsync(stream, lengthBuffer, cancellationToken);
if (bytesRead == 0)
{
return null; // Connection closed
}
if (bytesRead < LengthPrefixSize)
{
throw new InvalidOperationException("Incomplete length prefix received");
}
var payloadLength = BinaryPrimitives.ReadInt32BigEndian(lengthBuffer);
if (payloadLength < HeaderSize)
{
throw new InvalidOperationException($"Invalid payload length: {payloadLength}");
}
if (payloadLength > maxFrameSize)
{
throw new InvalidOperationException(
$"Frame size {payloadLength} exceeds maximum {maxFrameSize}");
}
// Read payload
var payload = new byte[payloadLength];
bytesRead = await ReadExactAsync(stream, payload, cancellationToken);
if (bytesRead < payloadLength)
{
throw new InvalidOperationException(
$"Incomplete payload: expected {payloadLength}, got {bytesRead}");
}
// Parse frame
var frameType = (FrameType)payload[0];
var correlationId = new Guid(payload.AsSpan(FrameTypeSize, CorrelationIdSize));
var data = payload.AsMemory(HeaderSize);
return new Frame
{
Type = frameType,
CorrelationId = correlationId.ToString("N"),
Payload = data
};
}
/// <summary>
/// Writes a frame to the stream.
/// </summary>
/// <param name="stream">The stream to write to.</param>
/// <param name="frame">The frame to write.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public static async Task WriteFrameAsync(
Stream stream,
Frame frame,
CancellationToken cancellationToken)
{
// Parse or generate correlation ID
var correlationGuid = frame.CorrelationId is not null &&
Guid.TryParse(frame.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var dataLength = frame.Payload.Length;
var payloadLength = HeaderSize + dataLength;
// Create buffer for the complete message
var buffer = new byte[LengthPrefixSize + payloadLength];
// Write length prefix (big-endian)
BinaryPrimitives.WriteInt32BigEndian(buffer.AsSpan(0, LengthPrefixSize), payloadLength);
// Write frame type
buffer[LengthPrefixSize] = (byte)frame.Type;
// Write correlation ID
correlationGuid.TryWriteBytes(buffer.AsSpan(LengthPrefixSize + FrameTypeSize, CorrelationIdSize));
// Write data
if (dataLength > 0)
{
frame.Payload.Span.CopyTo(buffer.AsSpan(LengthPrefixSize + HeaderSize));
}
await stream.WriteAsync(buffer, cancellationToken);
}
/// <summary>
/// Reads exactly the specified number of bytes from the stream.
/// </summary>
private static async Task<int> ReadExactAsync(
Stream stream,
Memory<byte> buffer,
CancellationToken cancellationToken)
{
var totalRead = 0;
while (totalRead < buffer.Length)
{
var read = await stream.ReadAsync(
buffer[totalRead..],
cancellationToken);
if (read == 0)
{
return totalRead; // EOF
}
totalRead += read;
}
return totalRead;
}
}

View File

@@ -0,0 +1,125 @@
using System.Collections.Concurrent;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// Tracks pending requests waiting for responses.
/// Enables multiplexing multiple concurrent requests on a single connection.
/// </summary>
public sealed class PendingRequestTracker : IDisposable
{
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<Frame>> _pending = new();
private bool _disposed;
/// <summary>
/// Tracks a request and returns a task that completes when the response arrives.
/// </summary>
/// <param name="correlationId">The correlation ID of the request.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>A task that completes with the response frame.</returns>
public Task<Frame> TrackRequest(Guid correlationId, CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
// Register cancellation callback
var registration = cancellationToken.Register(() =>
{
if (_pending.TryRemove(correlationId, out var pendingTcs))
{
pendingTcs.TrySetCanceled(cancellationToken);
}
});
// Store registration in state to dispose when completed
tcs.Task.ContinueWith(_ => registration.Dispose(), TaskScheduler.Default);
_pending[correlationId] = tcs;
return tcs.Task;
}
/// <summary>
/// Completes a pending request with the response.
/// </summary>
/// <param name="correlationId">The correlation ID.</param>
/// <param name="response">The response frame.</param>
/// <returns>True if the request was found and completed; false otherwise.</returns>
public bool CompleteRequest(Guid correlationId, Frame response)
{
if (_pending.TryRemove(correlationId, out var tcs))
{
return tcs.TrySetResult(response);
}
return false;
}
/// <summary>
/// Fails a pending request with an exception.
/// </summary>
/// <param name="correlationId">The correlation ID.</param>
/// <param name="exception">The exception.</param>
/// <returns>True if the request was found and failed; false otherwise.</returns>
public bool FailRequest(Guid correlationId, Exception exception)
{
if (_pending.TryRemove(correlationId, out var tcs))
{
return tcs.TrySetException(exception);
}
return false;
}
/// <summary>
/// Cancels a pending request.
/// </summary>
/// <param name="correlationId">The correlation ID.</param>
/// <returns>True if the request was found and cancelled; false otherwise.</returns>
public bool CancelRequest(Guid correlationId)
{
if (_pending.TryRemove(correlationId, out var tcs))
{
return tcs.TrySetCanceled();
}
return false;
}
/// <summary>
/// Gets the number of pending requests.
/// </summary>
public int Count => _pending.Count;
/// <summary>
/// Cancels all pending requests.
/// </summary>
/// <param name="exception">Optional exception to set.</param>
public void CancelAll(Exception? exception = null)
{
foreach (var kvp in _pending)
{
if (_pending.TryRemove(kvp.Key, out var tcs))
{
if (exception is not null)
{
tcs.TrySetException(exception);
}
else
{
tcs.TrySetCanceled();
}
}
}
}
/// <inheritdoc />
public void Dispose()
{
if (_disposed) return;
_disposed = true;
CancelAll(new ObjectDisposedException(nameof(PendingRequestTracker)));
}
}

View File

@@ -0,0 +1,53 @@
using Microsoft.Extensions.DependencyInjection;
using StellaOps.Router.Common.Abstractions;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// Extension methods for registering TCP transport services.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds TCP transport server services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Optional configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddTcpTransportServer(
this IServiceCollection services,
Action<TcpTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<TcpTransportServer>();
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<TcpTransportServer>());
return services;
}
/// <summary>
/// Adds TCP transport client services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Optional configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddTcpTransportClient(
this IServiceCollection services,
Action<TcpTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<TcpTransportClient>();
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<TcpTransportClient>());
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<TcpTransportClient>());
return services;
}
}

View File

@@ -0,0 +1,22 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>preview</LangVersion>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<RootNamespace>StellaOps.Router.Transport.Tcp</RootNamespace>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,182 @@
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// Represents a TCP connection to a microservice.
/// </summary>
public sealed class TcpConnection : IAsyncDisposable
{
private readonly TcpClient _client;
private readonly NetworkStream _stream;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private readonly TcpTransportOptions _options;
private readonly ILogger _logger;
private readonly CancellationTokenSource _connectionCts = new();
private bool _disposed;
/// <summary>
/// Gets the connection ID.
/// </summary>
public string ConnectionId { get; }
/// <summary>
/// Gets the remote endpoint as a string.
/// </summary>
public string RemoteEndpoint { get; }
/// <summary>
/// Gets a value indicating whether the connection is active.
/// </summary>
public bool IsConnected => _client.Connected && !_disposed;
/// <summary>
/// Gets the connection state.
/// </summary>
public ConnectionState? State { get; set; }
/// <summary>
/// Gets the cancellation token for this connection.
/// </summary>
public CancellationToken ConnectionToken => _connectionCts.Token;
/// <summary>
/// Event raised when a frame is received.
/// </summary>
public event Action<TcpConnection, Frame>? OnFrameReceived;
/// <summary>
/// Event raised when the connection is closed.
/// </summary>
public event Action<TcpConnection, Exception?>? OnDisconnected;
/// <summary>
/// Initializes a new instance of the <see cref="TcpConnection"/> class.
/// </summary>
public TcpConnection(
string connectionId,
TcpClient client,
TcpTransportOptions options,
ILogger logger)
{
ConnectionId = connectionId;
_client = client;
_stream = client.GetStream();
_options = options;
_logger = logger;
RemoteEndpoint = client.Client.RemoteEndPoint?.ToString() ?? "unknown";
// Configure socket options
client.ReceiveBufferSize = options.ReceiveBufferSize;
client.SendBufferSize = options.SendBufferSize;
client.NoDelay = true;
}
/// <summary>
/// Starts the read loop to receive frames.
/// </summary>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task ReadLoopAsync(CancellationToken cancellationToken)
{
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
cancellationToken, _connectionCts.Token);
Exception? disconnectException = null;
try
{
while (!linkedCts.Token.IsCancellationRequested)
{
var frame = await FrameProtocol.ReadFrameAsync(
_stream,
_options.MaxFrameSize,
linkedCts.Token);
if (frame is null)
{
_logger.LogDebug("Connection {ConnectionId} closed by remote", ConnectionId);
break;
}
OnFrameReceived?.Invoke(this, frame);
}
}
catch (OperationCanceledException)
{
// Expected on shutdown
}
catch (IOException ex) when (ex.InnerException is SocketException)
{
disconnectException = ex;
_logger.LogDebug(ex, "Connection {ConnectionId} socket error", ConnectionId);
}
catch (Exception ex)
{
disconnectException = ex;
_logger.LogWarning(ex, "Connection {ConnectionId} read error", ConnectionId);
}
OnDisconnected?.Invoke(this, disconnectException);
}
/// <summary>
/// Writes a frame to the connection.
/// </summary>
/// <param name="frame">The frame to write.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed, this);
await _writeLock.WaitAsync(cancellationToken);
try
{
await FrameProtocol.WriteFrameAsync(_stream, frame, cancellationToken);
await _stream.FlushAsync(cancellationToken);
}
finally
{
_writeLock.Release();
}
}
/// <summary>
/// Closes the connection.
/// </summary>
public void Close()
{
if (_disposed) return;
try
{
_connectionCts.Cancel();
_client.Close();
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Error closing connection {ConnectionId}", ConnectionId);
}
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
try
{
await _connectionCts.CancelAsync();
}
catch
{
// Ignore
}
_client.Dispose();
_writeLock.Dispose();
_connectionCts.Dispose();
}
}

View File

@@ -0,0 +1,486 @@
using System.Buffers;
using System.Collections.Concurrent;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// TCP transport client implementation for microservices.
/// </summary>
public sealed class TcpTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
{
private readonly TcpTransportOptions _options;
private readonly ILogger<TcpTransportClient> _logger;
private readonly PendingRequestTracker _pendingRequests = new();
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
private readonly CancellationTokenSource _clientCts = new();
private TcpClient? _client;
private NetworkStream? _stream;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private Task? _receiveTask;
private bool _disposed;
private string? _connectionId;
private int _reconnectAttempts;
/// <summary>
/// Event raised when a REQUEST frame is received.
/// </summary>
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
/// <summary>
/// Event raised when a CANCEL frame is received.
/// </summary>
public event Func<Guid, string?, Task>? OnCancelReceived;
/// <summary>
/// Initializes a new instance of the <see cref="TcpTransportClient"/> class.
/// </summary>
public TcpTransportClient(
IOptions<TcpTransportOptions> options,
ILogger<TcpTransportClient> logger)
{
_options = options.Value;
_logger = logger;
}
/// <summary>
/// Connects to the gateway.
/// </summary>
/// <param name="instance">The instance descriptor.</param>
/// <param name="endpoints">The endpoints to register.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task ConnectAsync(
InstanceDescriptor instance,
IReadOnlyList<EndpointDescriptor> endpoints,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (string.IsNullOrEmpty(_options.Host))
{
throw new InvalidOperationException("Host is not configured");
}
await ConnectInternalAsync(cancellationToken);
_connectionId = Guid.NewGuid().ToString("N");
// Send HELLO frame
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(helloFrame, cancellationToken);
_logger.LogInformation(
"Connected to TCP gateway at {Host}:{Port} as {ServiceName}/{Version}",
_options.Host,
_options.Port,
instance.ServiceName,
instance.Version);
// Start receiving frames
_receiveTask = Task.Run(() => ReceiveLoopAsync(_clientCts.Token), CancellationToken.None);
}
private async Task ConnectInternalAsync(CancellationToken cancellationToken)
{
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(_options.ConnectTimeout);
_client = new TcpClient
{
ReceiveBufferSize = _options.ReceiveBufferSize,
SendBufferSize = _options.SendBufferSize,
NoDelay = true
};
await _client.ConnectAsync(_options.Host!, _options.Port, timeoutCts.Token);
_stream = _client.GetStream();
_reconnectAttempts = 0;
}
private async Task ReconnectAsync()
{
if (_disposed) return;
while (_reconnectAttempts < _options.MaxReconnectAttempts && !_clientCts.Token.IsCancellationRequested)
{
_reconnectAttempts++;
var backoff = TimeSpan.FromMilliseconds(
Math.Min(
Math.Pow(2, _reconnectAttempts) * 100,
_options.MaxReconnectBackoff.TotalMilliseconds));
_logger.LogInformation(
"Reconnection attempt {Attempt} of {Max} in {Delay}ms",
_reconnectAttempts,
_options.MaxReconnectAttempts,
backoff.TotalMilliseconds);
await Task.Delay(backoff, _clientCts.Token);
try
{
_client?.Dispose();
await ConnectInternalAsync(_clientCts.Token);
_logger.LogInformation("Reconnected to gateway");
return;
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Reconnection attempt {Attempt} failed", _reconnectAttempts);
}
}
_logger.LogError("Max reconnection attempts reached, giving up");
}
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
var frame = await FrameProtocol.ReadFrameAsync(
_stream!,
_options.MaxFrameSize,
cancellationToken);
if (frame is null)
{
_logger.LogDebug("Connection closed by server");
await ReconnectAsync();
continue;
}
await ProcessFrameAsync(frame, cancellationToken);
}
catch (OperationCanceledException)
{
break;
}
catch (IOException ex) when (ex.InnerException is SocketException)
{
_logger.LogDebug(ex, "Socket error, attempting reconnection");
await ReconnectAsync();
}
catch (Exception ex)
{
_logger.LogError(ex, "Error in receive loop");
await Task.Delay(1000, cancellationToken);
}
}
}
private async Task ProcessFrameAsync(Frame frame, CancellationToken cancellationToken)
{
switch (frame.Type)
{
case FrameType.Request:
case FrameType.RequestStreamData:
await HandleRequestFrameAsync(frame, cancellationToken);
break;
case FrameType.Cancel:
HandleCancelFrame(frame);
break;
case FrameType.Response:
case FrameType.ResponseStreamData:
if (frame.CorrelationId is not null &&
Guid.TryParse(frame.CorrelationId, out var correlationId))
{
_pendingRequests.CompleteRequest(correlationId, frame);
}
break;
default:
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
break;
}
}
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
{
if (OnRequestReceived is null)
{
_logger.LogWarning("No request handler registered");
return;
}
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_inflightHandlers[correlationId] = handlerCts;
try
{
var response = await OnRequestReceived(frame, handlerCts.Token);
var responseFrame = response with { CorrelationId = correlationId };
if (!handlerCts.Token.IsCancellationRequested)
{
await WriteFrameAsync(responseFrame, cancellationToken);
}
}
catch (OperationCanceledException)
{
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
}
finally
{
_inflightHandlers.TryRemove(correlationId, out _);
}
}
private void HandleCancelFrame(Frame frame)
{
if (frame.CorrelationId is null) return;
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
{
try
{
cts.Cancel();
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (Guid.TryParse(frame.CorrelationId, out var guid))
{
_pendingRequests.CancelRequest(guid);
OnCancelReceived?.Invoke(guid, null);
}
}
private async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
await _writeLock.WaitAsync(cancellationToken);
try
{
await FrameProtocol.WriteFrameAsync(_stream!, frame, cancellationToken);
await _stream!.FlushAsync(cancellationToken);
}
finally
{
_writeLock.Release();
}
}
/// <inheritdoc />
public async Task<Frame> SendRequestAsync(
ConnectionState connection,
Frame requestFrame,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var correlationId = requestFrame.CorrelationId is not null &&
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
var responseTask = _pendingRequests.TrackRequest(correlationId, timeoutCts.Token);
await WriteFrameAsync(framedRequest, timeoutCts.Token);
try
{
return await responseTask;
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
}
}
/// <inheritdoc />
public async Task SendCancelAsync(
ConnectionState connection,
Guid correlationId,
string? reason = null)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var cancelFrame = new Frame
{
Type = FrameType.Cancel,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(cancelFrame, CancellationToken.None);
_logger.LogDebug("Sent CANCEL for {CorrelationId}", correlationId);
}
/// <inheritdoc />
public async Task SendStreamingAsync(
ConnectionState connection,
Frame requestHeader,
Stream requestBody,
Func<Stream, Task> readResponseBody,
PayloadLimits limits,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var correlationId = requestHeader.CorrelationId is not null &&
Guid.TryParse(requestHeader.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var headerFrame = requestHeader with
{
Type = FrameType.Request,
CorrelationId = correlationId.ToString("N")
};
await WriteFrameAsync(headerFrame, cancellationToken);
// Stream request body
var buffer = ArrayPool<byte>.Shared.Rent(8192);
try
{
long totalBytesRead = 0;
int bytesRead;
while ((bytesRead = await requestBody.ReadAsync(buffer, cancellationToken)) > 0)
{
totalBytesRead += bytesRead;
if (totalBytesRead > limits.MaxRequestBytesPerCall)
{
throw new InvalidOperationException(
$"Request body exceeds limit of {limits.MaxRequestBytesPerCall} bytes");
}
var dataFrame = new Frame
{
Type = FrameType.RequestStreamData,
CorrelationId = correlationId.ToString("N"),
Payload = new ReadOnlyMemory<byte>(buffer, 0, bytesRead)
};
await WriteFrameAsync(dataFrame, cancellationToken);
}
// End of stream marker
var endFrame = new Frame
{
Type = FrameType.RequestStreamData,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(endFrame, cancellationToken);
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
// Read streaming response
using var responseStream = new MemoryStream();
await readResponseBody(responseStream);
}
/// <summary>
/// Sends a heartbeat.
/// </summary>
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
{
var frame = new Frame
{
Type = FrameType.Heartbeat,
CorrelationId = null,
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(frame, cancellationToken);
}
/// <summary>
/// Cancels all in-flight handlers.
/// </summary>
public void CancelAllInflight(string reason)
{
var count = 0;
foreach (var cts in _inflightHandlers.Values)
{
try
{
cts.Cancel();
count++;
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (count > 0)
{
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
}
}
/// <summary>
/// Disconnects from the gateway.
/// </summary>
public async Task DisconnectAsync()
{
CancelAllInflight("Shutdown");
await _clientCts.CancelAsync();
if (_receiveTask is not null)
{
try
{
await _receiveTask;
}
catch
{
// Ignore
}
}
_client?.Dispose();
_logger.LogInformation("Disconnected from TCP gateway");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await DisconnectAsync();
_pendingRequests.Dispose();
_writeLock.Dispose();
_clientCts.Dispose();
}
}

View File

@@ -0,0 +1,68 @@
using System.Net;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// Configuration options for TCP transport.
/// </summary>
public sealed class TcpTransportOptions
{
/// <summary>
/// Gets or sets the address to bind to.
/// Default: IPAddress.Any (0.0.0.0).
/// </summary>
public IPAddress BindAddress { get; set; } = IPAddress.Any;
/// <summary>
/// Gets or sets the port to listen on.
/// Default: 5100.
/// </summary>
public int Port { get; set; } = 5100;
/// <summary>
/// Gets or sets the receive buffer size.
/// Default: 64 KB.
/// </summary>
public int ReceiveBufferSize { get; set; } = 64 * 1024;
/// <summary>
/// Gets or sets the send buffer size.
/// Default: 64 KB.
/// </summary>
public int SendBufferSize { get; set; } = 64 * 1024;
/// <summary>
/// Gets or sets the keep-alive interval.
/// Default: 30 seconds.
/// </summary>
public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromSeconds(30);
/// <summary>
/// Gets or sets the connection timeout.
/// Default: 10 seconds.
/// </summary>
public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(10);
/// <summary>
/// Gets or sets the maximum number of reconnection attempts.
/// Default: 10.
/// </summary>
public int MaxReconnectAttempts { get; set; } = 10;
/// <summary>
/// Gets or sets the maximum reconnection backoff.
/// Default: 1 minute.
/// </summary>
public TimeSpan MaxReconnectBackoff { get; set; } = TimeSpan.FromMinutes(1);
/// <summary>
/// Gets or sets the maximum frame size in bytes.
/// Default: 16 MB.
/// </summary>
public int MaxFrameSize { get; set; } = 16 * 1024 * 1024;
/// <summary>
/// Gets or sets the host for client connections.
/// </summary>
public string? Host { get; set; }
}

View File

@@ -0,0 +1,241 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tcp;
/// <summary>
/// TCP transport server implementation for the gateway.
/// </summary>
public sealed class TcpTransportServer : ITransportServer, IAsyncDisposable
{
private readonly TcpTransportOptions _options;
private readonly ILogger<TcpTransportServer> _logger;
private readonly ConcurrentDictionary<string, TcpConnection> _connections = new();
private TcpListener? _listener;
private CancellationTokenSource? _serverCts;
private Task? _acceptTask;
private bool _disposed;
/// <summary>
/// Event raised when a connection is established.
/// </summary>
public event Action<string, ConnectionState>? OnConnection;
/// <summary>
/// Event raised when a connection is lost.
/// </summary>
public event Action<string>? OnDisconnection;
/// <summary>
/// Event raised when a frame is received.
/// </summary>
public event Action<string, Frame>? OnFrame;
/// <summary>
/// Initializes a new instance of the <see cref="TcpTransportServer"/> class.
/// </summary>
public TcpTransportServer(
IOptions<TcpTransportOptions> options,
ILogger<TcpTransportServer> logger)
{
_options = options.Value;
_logger = logger;
}
/// <inheritdoc />
public Task StartAsync(CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
_serverCts = new CancellationTokenSource();
_listener = new TcpListener(_options.BindAddress, _options.Port);
_listener.Start();
_logger.LogInformation(
"TCP transport server listening on {Address}:{Port}",
_options.BindAddress,
_options.Port);
_acceptTask = AcceptLoopAsync(_serverCts.Token);
return Task.CompletedTask;
}
private async Task AcceptLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
var client = await _listener!.AcceptTcpClientAsync(cancellationToken);
var connectionId = GenerateConnectionId(client);
_logger.LogInformation(
"Accepted connection {ConnectionId} from {RemoteEndpoint}",
connectionId,
client.Client.RemoteEndPoint);
var connection = new TcpConnection(connectionId, client, _options, _logger);
_connections[connectionId] = connection;
connection.OnFrameReceived += HandleFrame;
connection.OnDisconnected += HandleDisconnect;
// Start read loop (non-blocking)
_ = Task.Run(() => connection.ReadLoopAsync(cancellationToken), CancellationToken.None);
}
catch (OperationCanceledException)
{
// Expected on shutdown
break;
}
catch (ObjectDisposedException)
{
// Listener disposed
break;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error accepting connection");
}
}
}
private void HandleFrame(TcpConnection connection, Frame frame)
{
// If this is a HELLO frame, create the ConnectionState
if (frame.Type == FrameType.Hello && connection.State is null)
{
var state = new ConnectionState
{
ConnectionId = connection.ConnectionId,
Instance = new InstanceDescriptor
{
InstanceId = connection.ConnectionId,
ServiceName = "unknown", // Will be updated from HELLO payload
Version = "1.0.0",
Region = "default"
},
Status = InstanceHealthStatus.Healthy,
LastHeartbeatUtc = DateTime.UtcNow,
TransportType = TransportType.Tcp
};
connection.State = state;
OnConnection?.Invoke(connection.ConnectionId, state);
}
OnFrame?.Invoke(connection.ConnectionId, frame);
}
private void HandleDisconnect(TcpConnection connection, Exception? ex)
{
_logger.LogInformation(
"Connection {ConnectionId} disconnected{Reason}",
connection.ConnectionId,
ex is not null ? $": {ex.Message}" : string.Empty);
_connections.TryRemove(connection.ConnectionId, out _);
OnDisconnection?.Invoke(connection.ConnectionId);
// Clean up connection
_ = connection.DisposeAsync();
}
/// <summary>
/// Sends a frame to a connection.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <param name="frame">The frame to send.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task SendFrameAsync(
string connectionId,
Frame frame,
CancellationToken cancellationToken = default)
{
if (_connections.TryGetValue(connectionId, out var connection))
{
await connection.WriteFrameAsync(frame, cancellationToken);
}
else
{
throw new InvalidOperationException($"Connection {connectionId} not found");
}
}
/// <summary>
/// Gets a connection by ID.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <returns>The connection, or null if not found.</returns>
public TcpConnection? GetConnection(string connectionId)
{
return _connections.TryGetValue(connectionId, out var conn) ? conn : null;
}
/// <summary>
/// Gets all active connections.
/// </summary>
public IEnumerable<TcpConnection> GetConnections() => _connections.Values;
/// <summary>
/// Gets the number of active connections.
/// </summary>
public int ConnectionCount => _connections.Count;
private static string GenerateConnectionId(TcpClient client)
{
var endpoint = client.Client.RemoteEndPoint as IPEndPoint;
if (endpoint is not null)
{
return $"tcp-{endpoint.Address}-{endpoint.Port}-{Guid.NewGuid():N}".Substring(0, 32);
}
return $"tcp-{Guid.NewGuid():N}";
}
/// <inheritdoc />
public async Task StopAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("Stopping TCP transport server");
if (_serverCts is not null)
{
await _serverCts.CancelAsync();
}
_listener?.Stop();
if (_acceptTask is not null)
{
await _acceptTask;
}
// Close all connections
foreach (var connection in _connections.Values)
{
connection.Close();
await connection.DisposeAsync();
}
_connections.Clear();
_logger.LogInformation("TCP transport server stopped");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await StopAsync(CancellationToken.None);
_listener?.Dispose();
_serverCts?.Dispose();
}
}

View File

@@ -0,0 +1,104 @@
using System.Security.Cryptography.X509Certificates;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// Utility class for loading certificates from various sources.
/// </summary>
public static class CertificateLoader
{
/// <summary>
/// Loads a server certificate from the options.
/// </summary>
/// <param name="options">The TLS transport options.</param>
/// <returns>The loaded certificate.</returns>
/// <exception cref="InvalidOperationException">Thrown when no certificate is configured.</exception>
public static X509Certificate2 LoadServerCertificate(TlsTransportOptions options)
{
// Direct certificate object takes precedence
if (options.ServerCertificate is not null)
{
return options.ServerCertificate;
}
// Load from path
if (string.IsNullOrEmpty(options.ServerCertificatePath))
{
throw new InvalidOperationException("Server certificate is not configured");
}
return LoadCertificateFromPath(
options.ServerCertificatePath,
options.ServerCertificateKeyPath,
options.ServerCertificatePassword);
}
/// <summary>
/// Loads a client certificate from the options.
/// </summary>
/// <param name="options">The TLS transport options.</param>
/// <returns>The loaded certificate, or null if not configured.</returns>
public static X509Certificate2? LoadClientCertificate(TlsTransportOptions options)
{
// Direct certificate object takes precedence
if (options.ClientCertificate is not null)
{
return options.ClientCertificate;
}
// Load from path
if (string.IsNullOrEmpty(options.ClientCertificatePath))
{
return null;
}
return LoadCertificateFromPath(
options.ClientCertificatePath,
options.ClientCertificateKeyPath,
options.ClientCertificatePassword);
}
/// <summary>
/// Loads a certificate from a file path.
/// </summary>
/// <param name="certPath">The certificate path (PEM or PFX).</param>
/// <param name="keyPath">The private key path (optional, for PEM).</param>
/// <param name="password">The password (optional, for PFX).</param>
/// <returns>The loaded certificate.</returns>
public static X509Certificate2 LoadCertificateFromPath(
string certPath,
string? keyPath = null,
string? password = null)
{
var extension = Path.GetExtension(certPath).ToLowerInvariant();
return extension switch
{
".pfx" or ".p12" => LoadPfxCertificate(certPath, password),
".pem" or ".crt" or ".cer" => LoadPemCertificate(certPath, keyPath),
_ => throw new InvalidOperationException($"Unsupported certificate format: {extension}")
};
}
private static X509Certificate2 LoadPfxCertificate(string pfxPath, string? password)
{
return X509CertificateLoader.LoadPkcs12FromFile(
pfxPath,
password,
X509KeyStorageFlags.MachineKeySet | X509KeyStorageFlags.PersistKeySet);
}
private static X509Certificate2 LoadPemCertificate(string certPath, string? keyPath)
{
var certPem = File.ReadAllText(certPath);
if (string.IsNullOrEmpty(keyPath))
{
// Assume the key is in the same file
return X509Certificate2.CreateFromPem(certPem);
}
var keyPem = File.ReadAllText(keyPath);
return X509Certificate2.CreateFromPem(certPem, keyPem);
}
}

View File

@@ -0,0 +1,219 @@
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// Watches certificate files for changes and triggers hot-reload.
/// </summary>
public sealed class CertificateWatcher : IDisposable
{
private readonly TlsTransportOptions _options;
private readonly ILogger _logger;
private readonly List<FileSystemWatcher> _watchers = new();
private volatile X509Certificate2? _currentServerCert;
private volatile X509Certificate2? _currentClientCert;
private bool _disposed;
/// <summary>
/// Event raised when the server certificate is reloaded.
/// </summary>
public event Action<X509Certificate2>? OnServerCertificateReloaded;
/// <summary>
/// Event raised when the client certificate is reloaded.
/// </summary>
public event Action<X509Certificate2>? OnClientCertificateReloaded;
/// <summary>
/// Gets the current server certificate.
/// </summary>
public X509Certificate2? ServerCertificate => _currentServerCert;
/// <summary>
/// Gets the current client certificate.
/// </summary>
public X509Certificate2? ClientCertificate => _currentClientCert;
/// <summary>
/// Initializes a new instance of the <see cref="CertificateWatcher"/> class.
/// </summary>
public CertificateWatcher(TlsTransportOptions options, ILogger logger)
{
_options = options;
_logger = logger;
// Load initial certificates
LoadCertificates();
// Set up file system watchers if hot-reload is enabled
if (_options.EnableCertificateHotReload)
{
SetupWatchers();
}
}
private void LoadCertificates()
{
try
{
if (!string.IsNullOrEmpty(_options.ServerCertificatePath) ||
_options.ServerCertificate is not null)
{
_currentServerCert = CertificateLoader.LoadServerCertificate(_options);
_logger.LogInformation(
"Loaded server certificate: {Subject}, Expires: {Expiry}",
_currentServerCert.Subject,
_currentServerCert.NotAfter);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to load server certificate");
throw;
}
try
{
_currentClientCert = CertificateLoader.LoadClientCertificate(_options);
if (_currentClientCert is not null)
{
_logger.LogInformation(
"Loaded client certificate: {Subject}, Expires: {Expiry}",
_currentClientCert.Subject,
_currentClientCert.NotAfter);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to load client certificate");
throw;
}
}
private void SetupWatchers()
{
if (!string.IsNullOrEmpty(_options.ServerCertificatePath))
{
var watcher = CreateWatcher(_options.ServerCertificatePath, ReloadServerCertificate);
if (watcher is not null) _watchers.Add(watcher);
}
if (!string.IsNullOrEmpty(_options.ServerCertificateKeyPath))
{
var watcher = CreateWatcher(_options.ServerCertificateKeyPath, ReloadServerCertificate);
if (watcher is not null) _watchers.Add(watcher);
}
if (!string.IsNullOrEmpty(_options.ClientCertificatePath))
{
var watcher = CreateWatcher(_options.ClientCertificatePath, ReloadClientCertificate);
if (watcher is not null) _watchers.Add(watcher);
}
if (!string.IsNullOrEmpty(_options.ClientCertificateKeyPath))
{
var watcher = CreateWatcher(_options.ClientCertificateKeyPath, ReloadClientCertificate);
if (watcher is not null) _watchers.Add(watcher);
}
}
private FileSystemWatcher? CreateWatcher(string filePath, Action reloadAction)
{
var directory = Path.GetDirectoryName(filePath);
if (string.IsNullOrEmpty(directory) || !Directory.Exists(directory))
{
_logger.LogWarning("Cannot watch certificate path: directory not found for {Path}", filePath);
return null;
}
var fileName = Path.GetFileName(filePath);
var watcher = new FileSystemWatcher(directory, fileName)
{
NotifyFilter = NotifyFilters.LastWrite | NotifyFilters.CreationTime
};
// Debounce file changes to avoid multiple reloads
DateTime lastReload = DateTime.MinValue;
watcher.Changed += (sender, args) =>
{
if (DateTime.UtcNow - lastReload < TimeSpan.FromSeconds(5))
return;
lastReload = DateTime.UtcNow;
_logger.LogInformation("Certificate file changed: {Path}", filePath);
// Delay slightly to ensure file write is complete
Task.Delay(500).ContinueWith(_ => reloadAction());
};
watcher.EnableRaisingEvents = true;
_logger.LogInformation("Watching certificate file: {Path}", filePath);
return watcher;
}
private void ReloadServerCertificate()
{
try
{
var oldCert = _currentServerCert;
_currentServerCert = CertificateLoader.LoadServerCertificate(_options);
_logger.LogInformation(
"Reloaded server certificate: {Subject}, Expires: {Expiry}",
_currentServerCert.Subject,
_currentServerCert.NotAfter);
OnServerCertificateReloaded?.Invoke(_currentServerCert);
// Dispose old certificate
oldCert?.Dispose();
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to reload server certificate");
}
}
private void ReloadClientCertificate()
{
try
{
var oldCert = _currentClientCert;
_currentClientCert = CertificateLoader.LoadClientCertificate(_options);
if (_currentClientCert is not null)
{
_logger.LogInformation(
"Reloaded client certificate: {Subject}, Expires: {Expiry}",
_currentClientCert.Subject,
_currentClientCert.NotAfter);
OnClientCertificateReloaded?.Invoke(_currentClientCert);
}
// Dispose old certificate
oldCert?.Dispose();
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to reload client certificate");
}
}
/// <inheritdoc />
public void Dispose()
{
if (_disposed) return;
_disposed = true;
foreach (var watcher in _watchers)
{
watcher.EnableRaisingEvents = false;
watcher.Dispose();
}
_watchers.Clear();
}
}

View File

@@ -0,0 +1,53 @@
using Microsoft.Extensions.DependencyInjection;
using StellaOps.Router.Common.Abstractions;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// Extension methods for registering TLS transport services.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds the TLS transport server to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddTlsTransportServer(
this IServiceCollection services,
Action<TlsTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<TlsTransportServer>();
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<TlsTransportServer>());
return services;
}
/// <summary>
/// Adds the TLS transport client to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddTlsTransportClient(
this IServiceCollection services,
Action<TlsTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<TlsTransportClient>();
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<TlsTransportClient>());
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<TlsTransportClient>());
return services;
}
}

View File

@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>preview</LangVersion>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
<ProjectReference Include="..\StellaOps.Router.Transport.Tcp\StellaOps.Router.Transport.Tcp.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,220 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.Tcp;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// Represents a TLS-secured connection to a microservice.
/// </summary>
public sealed class TlsConnection : IAsyncDisposable
{
private readonly TcpClient _client;
private readonly SslStream _sslStream;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private readonly TlsTransportOptions _options;
private readonly ILogger _logger;
private readonly CancellationTokenSource _connectionCts = new();
private bool _disposed;
/// <summary>
/// Gets the connection ID.
/// </summary>
public string ConnectionId { get; }
/// <summary>
/// Gets the remote endpoint as a string.
/// </summary>
public string RemoteEndpoint { get; }
/// <summary>
/// Gets a value indicating whether the connection is active.
/// </summary>
public bool IsConnected => _client.Connected && !_disposed;
/// <summary>
/// Gets the connection state.
/// </summary>
public ConnectionState? State { get; set; }
/// <summary>
/// Gets the cancellation token for this connection.
/// </summary>
public CancellationToken ConnectionToken => _connectionCts.Token;
/// <summary>
/// Gets the remote certificate (if mTLS).
/// </summary>
public X509Certificate? RemoteCertificate => _sslStream.RemoteCertificate;
/// <summary>
/// Gets the peer identity extracted from the certificate.
/// </summary>
public string? PeerIdentity { get; }
/// <summary>
/// Event raised when a frame is received.
/// </summary>
public event Action<TlsConnection, Frame>? OnFrameReceived;
/// <summary>
/// Event raised when the connection is closed.
/// </summary>
public event Action<TlsConnection, Exception?>? OnDisconnected;
/// <summary>
/// Initializes a new instance of the <see cref="TlsConnection"/> class.
/// </summary>
public TlsConnection(
string connectionId,
TcpClient client,
SslStream sslStream,
TlsTransportOptions options,
ILogger logger)
{
ConnectionId = connectionId;
_client = client;
_sslStream = sslStream;
_options = options;
_logger = logger;
RemoteEndpoint = client.Client.RemoteEndPoint?.ToString() ?? "unknown";
// Extract peer identity from certificate
if (_sslStream.RemoteCertificate is X509Certificate2 cert)
{
PeerIdentity = ExtractIdentityFromCertificate(cert);
}
// Configure socket options
client.ReceiveBufferSize = options.ReceiveBufferSize;
client.SendBufferSize = options.SendBufferSize;
client.NoDelay = true;
}
/// <summary>
/// Extracts identity from a certificate.
/// </summary>
private static string? ExtractIdentityFromCertificate(X509Certificate2 cert)
{
// Try to get Common Name (CN)
var cn = cert.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
if (!string.IsNullOrEmpty(cn))
{
return cn;
}
// Fallback to subject
return cert.Subject;
}
/// <summary>
/// Starts the read loop to receive frames.
/// </summary>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task ReadLoopAsync(CancellationToken cancellationToken)
{
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
cancellationToken, _connectionCts.Token);
Exception? disconnectException = null;
try
{
while (!linkedCts.Token.IsCancellationRequested)
{
var frame = await FrameProtocol.ReadFrameAsync(
_sslStream,
_options.MaxFrameSize,
linkedCts.Token);
if (frame is null)
{
_logger.LogDebug("TLS connection {ConnectionId} closed by remote", ConnectionId);
break;
}
OnFrameReceived?.Invoke(this, frame);
}
}
catch (OperationCanceledException)
{
// Expected on shutdown
}
catch (IOException ex) when (ex.InnerException is SocketException)
{
disconnectException = ex;
_logger.LogDebug(ex, "TLS connection {ConnectionId} socket error", ConnectionId);
}
catch (Exception ex)
{
disconnectException = ex;
_logger.LogWarning(ex, "TLS connection {ConnectionId} read error", ConnectionId);
}
OnDisconnected?.Invoke(this, disconnectException);
}
/// <summary>
/// Writes a frame to the connection.
/// </summary>
/// <param name="frame">The frame to write.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed, this);
await _writeLock.WaitAsync(cancellationToken);
try
{
await FrameProtocol.WriteFrameAsync(_sslStream, frame, cancellationToken);
await _sslStream.FlushAsync(cancellationToken);
}
finally
{
_writeLock.Release();
}
}
/// <summary>
/// Closes the connection.
/// </summary>
public void Close()
{
if (_disposed) return;
try
{
_connectionCts.Cancel();
_sslStream.Close();
_client.Close();
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Error closing TLS connection {ConnectionId}", ConnectionId);
}
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
try
{
await _connectionCts.CancelAsync();
}
catch
{
// Ignore
}
await _sslStream.DisposeAsync();
_client.Dispose();
_writeLock.Dispose();
_connectionCts.Dispose();
}
}

View File

@@ -0,0 +1,578 @@
using System.Buffers;
using System.Collections.Concurrent;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.Tcp;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// TLS transport client implementation for microservices.
/// </summary>
public sealed class TlsTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
{
private readonly TlsTransportOptions _options;
private readonly ILogger<TlsTransportClient> _logger;
private readonly CertificateWatcher _certWatcher;
private readonly PendingRequestTracker _pendingRequests = new();
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
private readonly CancellationTokenSource _clientCts = new();
private TcpClient? _client;
private SslStream? _sslStream;
private readonly SemaphoreSlim _writeLock = new(1, 1);
private Task? _receiveTask;
private bool _disposed;
private string? _connectionId;
private int _reconnectAttempts;
/// <summary>
/// Event raised when a REQUEST frame is received.
/// </summary>
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
/// <summary>
/// Event raised when a CANCEL frame is received.
/// </summary>
public event Func<Guid, string?, Task>? OnCancelReceived;
/// <summary>
/// Initializes a new instance of the <see cref="TlsTransportClient"/> class.
/// </summary>
public TlsTransportClient(
IOptions<TlsTransportOptions> options,
ILogger<TlsTransportClient> logger)
{
_options = options.Value;
_logger = logger;
_certWatcher = new CertificateWatcher(_options, logger);
}
/// <summary>
/// Connects to the gateway.
/// </summary>
/// <param name="instance">The instance descriptor.</param>
/// <param name="endpoints">The endpoints to register.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task ConnectAsync(
InstanceDescriptor instance,
IReadOnlyList<EndpointDescriptor> endpoints,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (string.IsNullOrEmpty(_options.Host))
{
throw new InvalidOperationException("Host is not configured");
}
await ConnectInternalAsync(cancellationToken);
_connectionId = Guid.NewGuid().ToString("N");
// Send HELLO frame
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(helloFrame, cancellationToken);
_logger.LogInformation(
"Connected to TLS gateway at {Host}:{Port} as {ServiceName}/{Version}",
_options.Host,
_options.Port,
instance.ServiceName,
instance.Version);
// Start receiving frames
_receiveTask = Task.Run(() => ReceiveLoopAsync(_clientCts.Token), CancellationToken.None);
}
private async Task ConnectInternalAsync(CancellationToken cancellationToken)
{
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(_options.ConnectTimeout);
_client = new TcpClient
{
ReceiveBufferSize = _options.ReceiveBufferSize,
SendBufferSize = _options.SendBufferSize,
NoDelay = true
};
await _client.ConnectAsync(_options.Host!, _options.Port, timeoutCts.Token);
_sslStream = new SslStream(
_client.GetStream(),
leaveInnerStreamOpen: false,
userCertificateValidationCallback: ValidateServerCertificate);
var clientCerts = _certWatcher.ClientCertificate is not null
? new X509CertificateCollection { _certWatcher.ClientCertificate }
: null;
await _sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
{
TargetHost = _options.ExpectedServerHostname ?? _options.Host,
ClientCertificates = clientCerts,
EnabledSslProtocols = _options.EnabledProtocols,
CertificateRevocationCheckMode = _options.CheckCertificateRevocation
? X509RevocationMode.Online
: X509RevocationMode.NoCheck
}, timeoutCts.Token);
_logger.LogInformation(
"TLS handshake completed: Protocol={Protocol}, CipherSuite={CipherSuite}",
_sslStream.SslProtocol,
_sslStream.NegotiatedCipherSuite);
_reconnectAttempts = 0;
}
private bool ValidateServerCertificate(
object sender,
X509Certificate? certificate,
X509Chain? chain,
SslPolicyErrors errors)
{
// Allow self-signed in development
if (_options.AllowSelfSigned)
{
if (errors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors))
{
// Check if the only error is self-signed
if (chain is not null && chain.ChainStatus.All(s =>
s.Status == X509ChainStatusFlags.UntrustedRoot ||
s.Status == X509ChainStatusFlags.PartialChain))
{
_logger.LogDebug("Allowing self-signed server certificate");
return true;
}
}
// Allow if no errors or only name mismatch
if (errors == SslPolicyErrors.None ||
errors == SslPolicyErrors.RemoteCertificateNameMismatch)
{
return true;
}
}
if (errors != SslPolicyErrors.None)
{
_logger.LogWarning("Server certificate validation failed: {Errors}", errors);
return false;
}
// Hostname verification
if (!string.IsNullOrEmpty(_options.ExpectedServerHostname) && certificate is not null)
{
var cert = new X509Certificate2(certificate);
var cn = cert.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
if (!string.Equals(cn, _options.ExpectedServerHostname, StringComparison.OrdinalIgnoreCase))
{
_logger.LogWarning(
"Server certificate hostname mismatch: expected {Expected}, got {Actual}",
_options.ExpectedServerHostname,
cn);
return false;
}
}
return true;
}
private async Task ReconnectAsync()
{
if (_disposed) return;
while (_reconnectAttempts < _options.MaxReconnectAttempts && !_clientCts.Token.IsCancellationRequested)
{
_reconnectAttempts++;
var backoff = TimeSpan.FromMilliseconds(
Math.Min(
Math.Pow(2, _reconnectAttempts) * 100,
_options.MaxReconnectBackoff.TotalMilliseconds));
_logger.LogInformation(
"TLS reconnection attempt {Attempt} of {Max} in {Delay}ms",
_reconnectAttempts,
_options.MaxReconnectAttempts,
backoff.TotalMilliseconds);
await Task.Delay(backoff, _clientCts.Token);
try
{
_sslStream?.Dispose();
_client?.Dispose();
await ConnectInternalAsync(_clientCts.Token);
_logger.LogInformation("Reconnected to TLS gateway");
return;
}
catch (Exception ex)
{
_logger.LogWarning(ex, "TLS reconnection attempt {Attempt} failed", _reconnectAttempts);
}
}
_logger.LogError("Max TLS reconnection attempts reached, giving up");
}
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
var frame = await FrameProtocol.ReadFrameAsync(
_sslStream!,
_options.MaxFrameSize,
cancellationToken);
if (frame is null)
{
_logger.LogDebug("TLS connection closed by server");
await ReconnectAsync();
continue;
}
await ProcessFrameAsync(frame, cancellationToken);
}
catch (OperationCanceledException)
{
break;
}
catch (IOException ex) when (ex.InnerException is SocketException)
{
_logger.LogDebug(ex, "TLS socket error, attempting reconnection");
await ReconnectAsync();
}
catch (AuthenticationException ex)
{
_logger.LogError(ex, "TLS authentication error during receive");
break;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error in TLS receive loop");
await Task.Delay(1000, cancellationToken);
}
}
}
private async Task ProcessFrameAsync(Frame frame, CancellationToken cancellationToken)
{
switch (frame.Type)
{
case FrameType.Request:
case FrameType.RequestStreamData:
await HandleRequestFrameAsync(frame, cancellationToken);
break;
case FrameType.Cancel:
HandleCancelFrame(frame);
break;
case FrameType.Response:
case FrameType.ResponseStreamData:
if (frame.CorrelationId is not null &&
Guid.TryParse(frame.CorrelationId, out var correlationId))
{
_pendingRequests.CompleteRequest(correlationId, frame);
}
break;
default:
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
break;
}
}
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
{
if (OnRequestReceived is null)
{
_logger.LogWarning("No request handler registered");
return;
}
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_inflightHandlers[correlationId] = handlerCts;
try
{
var response = await OnRequestReceived(frame, handlerCts.Token);
var responseFrame = response with { CorrelationId = correlationId };
if (!handlerCts.Token.IsCancellationRequested)
{
await WriteFrameAsync(responseFrame, cancellationToken);
}
}
catch (OperationCanceledException)
{
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
}
finally
{
_inflightHandlers.TryRemove(correlationId, out _);
}
}
private void HandleCancelFrame(Frame frame)
{
if (frame.CorrelationId is null) return;
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
{
try
{
cts.Cancel();
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (Guid.TryParse(frame.CorrelationId, out var guid))
{
_pendingRequests.CancelRequest(guid);
OnCancelReceived?.Invoke(guid, null);
}
}
private async Task WriteFrameAsync(Frame frame, CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
await _writeLock.WaitAsync(cancellationToken);
try
{
await FrameProtocol.WriteFrameAsync(_sslStream!, frame, cancellationToken);
await _sslStream!.FlushAsync(cancellationToken);
}
finally
{
_writeLock.Release();
}
}
/// <inheritdoc />
public async Task<Frame> SendRequestAsync(
ConnectionState connection,
Frame requestFrame,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var correlationId = requestFrame.CorrelationId is not null &&
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
var responseTask = _pendingRequests.TrackRequest(correlationId, timeoutCts.Token);
await WriteFrameAsync(framedRequest, timeoutCts.Token);
try
{
return await responseTask;
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
}
}
/// <inheritdoc />
public async Task SendCancelAsync(
ConnectionState connection,
Guid correlationId,
string? reason = null)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var cancelFrame = new Frame
{
Type = FrameType.Cancel,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(cancelFrame, CancellationToken.None);
_logger.LogDebug("Sent CANCEL for {CorrelationId}", correlationId);
}
/// <inheritdoc />
public async Task SendStreamingAsync(
ConnectionState connection,
Frame requestHeader,
Stream requestBody,
Func<Stream, Task> readResponseBody,
PayloadLimits limits,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var correlationId = requestHeader.CorrelationId is not null &&
Guid.TryParse(requestHeader.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var headerFrame = requestHeader with
{
Type = FrameType.Request,
CorrelationId = correlationId.ToString("N")
};
await WriteFrameAsync(headerFrame, cancellationToken);
// Stream request body
var buffer = ArrayPool<byte>.Shared.Rent(8192);
try
{
long totalBytesRead = 0;
int bytesRead;
while ((bytesRead = await requestBody.ReadAsync(buffer, cancellationToken)) > 0)
{
totalBytesRead += bytesRead;
if (totalBytesRead > limits.MaxRequestBytesPerCall)
{
throw new InvalidOperationException(
$"Request body exceeds limit of {limits.MaxRequestBytesPerCall} bytes");
}
var dataFrame = new Frame
{
Type = FrameType.RequestStreamData,
CorrelationId = correlationId.ToString("N"),
Payload = new ReadOnlyMemory<byte>(buffer, 0, bytesRead)
};
await WriteFrameAsync(dataFrame, cancellationToken);
}
// End of stream marker
var endFrame = new Frame
{
Type = FrameType.RequestStreamData,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(endFrame, cancellationToken);
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
// Read streaming response
using var responseStream = new MemoryStream();
await readResponseBody(responseStream);
}
/// <summary>
/// Sends a heartbeat.
/// </summary>
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
{
var frame = new Frame
{
Type = FrameType.Heartbeat,
CorrelationId = null,
Payload = ReadOnlyMemory<byte>.Empty
};
await WriteFrameAsync(frame, cancellationToken);
}
/// <summary>
/// Cancels all in-flight handlers.
/// </summary>
public void CancelAllInflight(string reason)
{
var count = 0;
foreach (var cts in _inflightHandlers.Values)
{
try
{
cts.Cancel();
count++;
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (count > 0)
{
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
}
}
/// <summary>
/// Disconnects from the gateway.
/// </summary>
public async Task DisconnectAsync()
{
CancelAllInflight("Shutdown");
await _clientCts.CancelAsync();
if (_receiveTask is not null)
{
try
{
await _receiveTask;
}
catch
{
// Ignore
}
}
_sslStream?.Dispose();
_client?.Dispose();
_logger.LogInformation("Disconnected from TLS gateway");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await DisconnectAsync();
_certWatcher.Dispose();
_pendingRequests.Dispose();
_writeLock.Dispose();
_clientCts.Dispose();
}
}

View File

@@ -0,0 +1,137 @@
using System.Net;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// Options for TLS transport configuration.
/// </summary>
public sealed class TlsTransportOptions
{
/// <summary>
/// Gets or sets the bind address for the server.
/// </summary>
public IPAddress BindAddress { get; set; } = IPAddress.Any;
/// <summary>
/// Gets or sets the port to listen on.
/// </summary>
public int Port { get; set; } = 5101;
/// <summary>
/// Gets or sets the host to connect to (client only).
/// </summary>
public string? Host { get; set; }
/// <summary>
/// Gets or sets the receive buffer size.
/// </summary>
public int ReceiveBufferSize { get; set; } = 64 * 1024;
/// <summary>
/// Gets or sets the send buffer size.
/// </summary>
public int SendBufferSize { get; set; } = 64 * 1024;
/// <summary>
/// Gets or sets the keep-alive interval.
/// </summary>
public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromSeconds(30);
/// <summary>
/// Gets or sets the connection timeout.
/// </summary>
public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(10);
/// <summary>
/// Gets or sets the maximum reconnection attempts.
/// </summary>
public int MaxReconnectAttempts { get; set; } = 10;
/// <summary>
/// Gets or sets the maximum reconnection backoff.
/// </summary>
public TimeSpan MaxReconnectBackoff { get; set; } = TimeSpan.FromMinutes(1);
/// <summary>
/// Gets or sets the maximum frame size.
/// </summary>
public int MaxFrameSize { get; set; } = 16 * 1024 * 1024;
// Server-side certificate (Gateway)
/// <summary>
/// Gets or sets the server certificate object.
/// </summary>
public X509Certificate2? ServerCertificate { get; set; }
/// <summary>
/// Gets or sets the server certificate path (PEM or PFX).
/// </summary>
public string? ServerCertificatePath { get; set; }
/// <summary>
/// Gets or sets the server certificate key path (PEM private key).
/// </summary>
public string? ServerCertificateKeyPath { get; set; }
/// <summary>
/// Gets or sets the server certificate password (for PFX).
/// </summary>
public string? ServerCertificatePassword { get; set; }
// Client-side certificate (Microservice)
/// <summary>
/// Gets or sets the client certificate object.
/// </summary>
public X509Certificate2? ClientCertificate { get; set; }
/// <summary>
/// Gets or sets the client certificate path (PEM or PFX).
/// </summary>
public string? ClientCertificatePath { get; set; }
/// <summary>
/// Gets or sets the client certificate key path (PEM private key).
/// </summary>
public string? ClientCertificateKeyPath { get; set; }
/// <summary>
/// Gets or sets the client certificate password (for PFX).
/// </summary>
public string? ClientCertificatePassword { get; set; }
// Validation options
/// <summary>
/// Gets or sets whether to require client certificates (mTLS).
/// </summary>
public bool RequireClientCertificate { get; set; } = false;
/// <summary>
/// Gets or sets whether to allow self-signed certificates (dev only).
/// </summary>
public bool AllowSelfSigned { get; set; } = false;
/// <summary>
/// Gets or sets whether to check certificate revocation.
/// </summary>
public bool CheckCertificateRevocation { get; set; } = false;
/// <summary>
/// Gets or sets the expected server hostname (for SNI).
/// </summary>
public string? ExpectedServerHostname { get; set; }
/// <summary>
/// Gets or sets the enabled SSL/TLS protocols.
/// </summary>
public SslProtocols EnabledProtocols { get; set; } = SslProtocols.Tls12 | SslProtocols.Tls13;
/// <summary>
/// Gets or sets whether to enable certificate hot-reload.
/// </summary>
public bool EnableCertificateHotReload { get; set; } = false;
}

View File

@@ -0,0 +1,342 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Tls;
/// <summary>
/// TLS transport server implementation for the gateway.
/// </summary>
public sealed class TlsTransportServer : ITransportServer, IAsyncDisposable
{
private readonly TlsTransportOptions _options;
private readonly ILogger<TlsTransportServer> _logger;
private readonly ConcurrentDictionary<string, TlsConnection> _connections = new();
private readonly CertificateWatcher _certWatcher;
private TcpListener? _listener;
private CancellationTokenSource? _serverCts;
private Task? _acceptTask;
private bool _disposed;
/// <summary>
/// Event raised when a connection is established.
/// </summary>
public event Action<string, ConnectionState>? OnConnection;
/// <summary>
/// Event raised when a connection is lost.
/// </summary>
public event Action<string>? OnDisconnection;
/// <summary>
/// Event raised when a frame is received.
/// </summary>
public event Action<string, Frame>? OnFrame;
/// <summary>
/// Initializes a new instance of the <see cref="TlsTransportServer"/> class.
/// </summary>
public TlsTransportServer(
IOptions<TlsTransportOptions> options,
ILogger<TlsTransportServer> logger)
{
_options = options.Value;
_logger = logger;
_certWatcher = new CertificateWatcher(_options, logger);
}
/// <inheritdoc />
public Task StartAsync(CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (_certWatcher.ServerCertificate is null)
{
throw new InvalidOperationException("Server certificate is not configured");
}
_serverCts = new CancellationTokenSource();
_listener = new TcpListener(_options.BindAddress, _options.Port);
_listener.Start();
_logger.LogInformation(
"TLS transport server listening on {Address}:{Port}",
_options.BindAddress,
_options.Port);
_acceptTask = AcceptLoopAsync(_serverCts.Token);
return Task.CompletedTask;
}
private async Task AcceptLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
TcpClient? client = null;
SslStream? sslStream = null;
try
{
client = await _listener!.AcceptTcpClientAsync(cancellationToken);
_logger.LogDebug(
"Accepting TLS connection from {RemoteEndpoint}",
client.Client.RemoteEndPoint);
sslStream = new SslStream(
client.GetStream(),
leaveInnerStreamOpen: false,
userCertificateValidationCallback: ValidateClientCertificate);
await sslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions
{
ServerCertificate = _certWatcher.ServerCertificate,
ClientCertificateRequired = _options.RequireClientCertificate,
EnabledSslProtocols = _options.EnabledProtocols,
CertificateRevocationCheckMode = _options.CheckCertificateRevocation
? X509RevocationMode.Online
: X509RevocationMode.NoCheck
}, cancellationToken);
var connectionId = GenerateConnectionId(client, sslStream.RemoteCertificate);
_logger.LogInformation(
"TLS connection established: {ConnectionId} from {RemoteEndpoint}, Protocol: {Protocol}, CipherSuite: {CipherSuite}",
connectionId,
client.Client.RemoteEndPoint,
sslStream.SslProtocol,
sslStream.NegotiatedCipherSuite);
var connection = new TlsConnection(connectionId, client, sslStream, _options, _logger);
_connections[connectionId] = connection;
connection.OnFrameReceived += HandleFrame;
connection.OnDisconnected += HandleDisconnect;
// Start read loop (non-blocking)
_ = Task.Run(() => connection.ReadLoopAsync(cancellationToken), CancellationToken.None);
}
catch (OperationCanceledException)
{
// Expected on shutdown
break;
}
catch (ObjectDisposedException)
{
// Listener disposed
break;
}
catch (AuthenticationException ex)
{
_logger.LogWarning(ex,
"TLS handshake failed from {RemoteEndpoint}",
client?.Client?.RemoteEndPoint);
sslStream?.Dispose();
client?.Dispose();
}
catch (Exception ex)
{
_logger.LogError(ex, "Error accepting TLS connection");
sslStream?.Dispose();
client?.Dispose();
}
}
}
private bool ValidateClientCertificate(
object sender,
X509Certificate? certificate,
X509Chain? chain,
SslPolicyErrors errors)
{
// If we don't require client certs and none provided, allow
if (!_options.RequireClientCertificate && certificate is null)
{
return true;
}
// If client cert is required but not provided, reject
if (_options.RequireClientCertificate && certificate is null)
{
_logger.LogWarning("Client certificate required but not provided");
return false;
}
// Allow self-signed in development
if (_options.AllowSelfSigned)
{
if (errors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors))
{
// Check if the only error is self-signed
if (chain is not null && chain.ChainStatus.All(s =>
s.Status == X509ChainStatusFlags.UntrustedRoot ||
s.Status == X509ChainStatusFlags.PartialChain))
{
_logger.LogDebug("Allowing self-signed client certificate");
return true;
}
}
// Allow if no errors or only name mismatch
if (errors == SslPolicyErrors.None ||
errors == SslPolicyErrors.RemoteCertificateNameMismatch)
{
return true;
}
}
if (errors != SslPolicyErrors.None)
{
_logger.LogWarning("Client certificate validation failed: {Errors}", errors);
return false;
}
return true;
}
private void HandleFrame(TlsConnection connection, Frame frame)
{
// If this is a HELLO frame, create the ConnectionState
if (frame.Type == FrameType.Hello && connection.State is null)
{
var state = new ConnectionState
{
ConnectionId = connection.ConnectionId,
Instance = new InstanceDescriptor
{
InstanceId = connection.ConnectionId,
ServiceName = connection.PeerIdentity ?? "unknown",
Version = "1.0.0",
Region = "default"
},
Status = InstanceHealthStatus.Healthy,
LastHeartbeatUtc = DateTime.UtcNow,
TransportType = TransportType.Certificate
};
connection.State = state;
OnConnection?.Invoke(connection.ConnectionId, state);
}
OnFrame?.Invoke(connection.ConnectionId, frame);
}
private void HandleDisconnect(TlsConnection connection, Exception? ex)
{
_logger.LogInformation(
"TLS connection {ConnectionId} disconnected{Reason}",
connection.ConnectionId,
ex is not null ? $": {ex.Message}" : string.Empty);
_connections.TryRemove(connection.ConnectionId, out _);
OnDisconnection?.Invoke(connection.ConnectionId);
// Clean up connection
_ = connection.DisposeAsync();
}
/// <summary>
/// Sends a frame to a connection.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <param name="frame">The frame to send.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task SendFrameAsync(
string connectionId,
Frame frame,
CancellationToken cancellationToken = default)
{
if (_connections.TryGetValue(connectionId, out var connection))
{
await connection.WriteFrameAsync(frame, cancellationToken);
}
else
{
throw new InvalidOperationException($"Connection {connectionId} not found");
}
}
/// <summary>
/// Gets a connection by ID.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <returns>The connection, or null if not found.</returns>
public TlsConnection? GetConnection(string connectionId)
{
return _connections.TryGetValue(connectionId, out var conn) ? conn : null;
}
/// <summary>
/// Gets all active connections.
/// </summary>
public IEnumerable<TlsConnection> GetConnections() => _connections.Values;
/// <summary>
/// Gets the number of active connections.
/// </summary>
public int ConnectionCount => _connections.Count;
private static string GenerateConnectionId(TcpClient client, X509Certificate? remoteCert)
{
var endpoint = client.Client.RemoteEndPoint as IPEndPoint;
var certId = remoteCert?.GetSerialNumberString() ?? "nocert";
if (endpoint is not null)
{
return $"tls-{endpoint.Address}-{endpoint.Port}-{certId}".Substring(0, Math.Min(48, 16 + certId.Length));
}
return $"tls-{Guid.NewGuid():N}";
}
/// <inheritdoc />
public async Task StopAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("Stopping TLS transport server");
if (_serverCts is not null)
{
await _serverCts.CancelAsync();
}
_listener?.Stop();
if (_acceptTask is not null)
{
await _acceptTask;
}
// Close all connections
foreach (var connection in _connections.Values)
{
connection.Close();
await connection.DisposeAsync();
}
_connections.Clear();
_logger.LogInformation("TLS transport server stopped");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await StopAsync(CancellationToken.None);
_certWatcher.Dispose();
_listener?.Dispose();
_serverCts?.Dispose();
}
}

View File

@@ -0,0 +1,27 @@
namespace StellaOps.Router.Transport.Udp;
/// <summary>
/// Exception thrown when a payload exceeds the maximum datagram size.
/// </summary>
public sealed class PayloadTooLargeException : Exception
{
/// <summary>
/// Gets the actual size of the payload.
/// </summary>
public int ActualSize { get; }
/// <summary>
/// Gets the maximum allowed size.
/// </summary>
public int MaxSize { get; }
/// <summary>
/// Initializes a new instance of the <see cref="PayloadTooLargeException"/> class.
/// </summary>
public PayloadTooLargeException(int actualSize, int maxSize)
: base($"Payload size {actualSize} exceeds maximum datagram size of {maxSize} bytes")
{
ActualSize = actualSize;
MaxSize = maxSize;
}
}

View File

@@ -0,0 +1,53 @@
using Microsoft.Extensions.DependencyInjection;
using StellaOps.Router.Common.Abstractions;
namespace StellaOps.Router.Transport.Udp;
/// <summary>
/// Extension methods for registering UDP transport services.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds UDP transport server services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Optional configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddUdpTransportServer(
this IServiceCollection services,
Action<UdpTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<UdpTransportServer>();
services.AddSingleton<ITransportServer>(sp => sp.GetRequiredService<UdpTransportServer>());
return services;
}
/// <summary>
/// Adds UDP transport client services to the service collection.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configure">Optional configuration action.</param>
/// <returns>The service collection.</returns>
public static IServiceCollection AddUdpTransportClient(
this IServiceCollection services,
Action<UdpTransportOptions>? configure = null)
{
if (configure is not null)
{
services.Configure(configure);
}
services.AddSingleton<UdpTransportClient>();
services.AddSingleton<ITransportClient>(sp => sp.GetRequiredService<UdpTransportClient>());
services.AddSingleton<IMicroserviceTransport>(sp => sp.GetRequiredService<UdpTransportClient>());
return services;
}
}

View File

@@ -0,0 +1,22 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>preview</LangVersion>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<RootNamespace>StellaOps.Router.Transport.Udp</RootNamespace>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Options" Version="10.0.0-rc.2.25502.107" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,79 @@
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Udp;
/// <summary>
/// Handles serialization and deserialization of frames for UDP transport.
/// Frame format: [1-byte frame type][16-byte correlation GUID][remaining data]
/// </summary>
public static class UdpFrameProtocol
{
private const int FrameTypeSize = 1;
private const int CorrelationIdSize = 16;
private const int HeaderSize = FrameTypeSize + CorrelationIdSize;
/// <summary>
/// Parses a frame from a datagram.
/// </summary>
/// <param name="data">The datagram data.</param>
/// <returns>The parsed frame.</returns>
/// <exception cref="InvalidOperationException">Thrown when the datagram is too small.</exception>
public static Frame ParseFrame(ReadOnlySpan<byte> data)
{
if (data.Length < HeaderSize)
{
throw new InvalidOperationException(
$"Datagram too small: {data.Length} bytes, minimum is {HeaderSize}");
}
var frameType = (FrameType)data[0];
var correlationId = new Guid(data.Slice(FrameTypeSize, CorrelationIdSize));
var payload = data.Length > HeaderSize
? data[HeaderSize..].ToArray()
: Array.Empty<byte>();
return new Frame
{
Type = frameType,
CorrelationId = correlationId.ToString("N"),
Payload = payload
};
}
/// <summary>
/// Serializes a frame to a datagram.
/// </summary>
/// <param name="frame">The frame to serialize.</param>
/// <returns>The serialized datagram bytes.</returns>
public static byte[] SerializeFrame(Frame frame)
{
// Parse or generate correlation ID
var correlationGuid = frame.CorrelationId is not null &&
Guid.TryParse(frame.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var payloadLength = frame.Payload.Length;
var buffer = new byte[HeaderSize + payloadLength];
// Write frame type
buffer[0] = (byte)frame.Type;
// Write correlation ID
correlationGuid.TryWriteBytes(buffer.AsSpan(FrameTypeSize, CorrelationIdSize));
// Write payload
if (payloadLength > 0)
{
frame.Payload.Span.CopyTo(buffer.AsSpan(HeaderSize));
}
return buffer;
}
/// <summary>
/// Gets the header size for UDP frames.
/// </summary>
public static int GetHeaderSize() => HeaderSize;
}

View File

@@ -0,0 +1,412 @@
using System.Collections.Concurrent;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Udp;
/// <summary>
/// UDP transport client implementation for microservices.
/// UDP transport does not support streaming.
/// </summary>
public sealed class UdpTransportClient : ITransportClient, IMicroserviceTransport, IAsyncDisposable
{
private readonly UdpTransportOptions _options;
private readonly ILogger<UdpTransportClient> _logger;
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<Frame>> _pendingRequests = new();
private readonly ConcurrentDictionary<string, CancellationTokenSource> _inflightHandlers = new();
private readonly CancellationTokenSource _clientCts = new();
private UdpClient? _client;
private Task? _receiveTask;
private bool _disposed;
private string? _connectionId;
/// <summary>
/// Event raised when a REQUEST frame is received.
/// </summary>
public event Func<Frame, CancellationToken, Task<Frame>>? OnRequestReceived;
/// <summary>
/// Event raised when a CANCEL frame is received.
/// </summary>
public event Func<Guid, string?, Task>? OnCancelReceived;
/// <summary>
/// Initializes a new instance of the <see cref="UdpTransportClient"/> class.
/// </summary>
public UdpTransportClient(
IOptions<UdpTransportOptions> options,
ILogger<UdpTransportClient> logger)
{
_options = options.Value;
_logger = logger;
}
/// <summary>
/// Connects to the gateway.
/// </summary>
/// <param name="instance">The instance descriptor.</param>
/// <param name="endpoints">The endpoints to register.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task ConnectAsync(
InstanceDescriptor instance,
IReadOnlyList<EndpointDescriptor> endpoints,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (string.IsNullOrEmpty(_options.Host))
{
throw new InvalidOperationException("Host is not configured");
}
_client = new UdpClient
{
EnableBroadcast = _options.AllowBroadcast
};
_client.Client.ReceiveBufferSize = _options.ReceiveBufferSize;
_client.Client.SendBufferSize = _options.SendBufferSize;
_client.Connect(_options.Host, _options.Port);
_connectionId = Guid.NewGuid().ToString("N");
// Send HELLO frame
var helloFrame = new Frame
{
Type = FrameType.Hello,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await SendFrameInternalAsync(helloFrame, cancellationToken);
_logger.LogInformation(
"Connected to UDP gateway at {Host}:{Port} as {ServiceName}/{Version}",
_options.Host,
_options.Port,
instance.ServiceName,
instance.Version);
// Start receiving frames
_receiveTask = Task.Run(() => ReceiveLoopAsync(_clientCts.Token), CancellationToken.None);
}
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
var result = await _client!.ReceiveAsync(cancellationToken);
var data = result.Buffer;
if (data.Length < UdpFrameProtocol.GetHeaderSize())
{
_logger.LogWarning("Received datagram too small ({Size} bytes)", data.Length);
continue;
}
var frame = UdpFrameProtocol.ParseFrame(data);
await ProcessFrameAsync(frame, cancellationToken);
}
catch (OperationCanceledException)
{
break;
}
catch (ObjectDisposedException)
{
break;
}
catch (SocketException ex)
{
_logger.LogWarning(ex, "UDP socket error in receive loop");
}
catch (Exception ex)
{
_logger.LogError(ex, "Error in receive loop");
}
}
}
private async Task ProcessFrameAsync(Frame frame, CancellationToken cancellationToken)
{
switch (frame.Type)
{
case FrameType.Request:
await HandleRequestFrameAsync(frame, cancellationToken);
break;
case FrameType.Cancel:
HandleCancelFrame(frame);
break;
case FrameType.Response:
if (frame.CorrelationId is not null &&
Guid.TryParse(frame.CorrelationId, out var correlationId))
{
if (_pendingRequests.TryRemove(correlationId, out var tcs))
{
tcs.TrySetResult(frame);
}
}
break;
case FrameType.RequestStreamData:
case FrameType.ResponseStreamData:
_logger.LogWarning(
"UDP transport does not support streaming. Frame type {Type} ignored.",
frame.Type);
break;
default:
_logger.LogWarning("Unexpected frame type {FrameType}", frame.Type);
break;
}
}
private async Task HandleRequestFrameAsync(Frame frame, CancellationToken cancellationToken)
{
if (OnRequestReceived is null)
{
_logger.LogWarning("No request handler registered");
return;
}
var correlationId = frame.CorrelationId ?? Guid.NewGuid().ToString("N");
using var handlerCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_inflightHandlers[correlationId] = handlerCts;
try
{
var response = await OnRequestReceived(frame, handlerCts.Token);
var responseFrame = response with { CorrelationId = correlationId };
if (!handlerCts.Token.IsCancellationRequested)
{
await SendFrameInternalAsync(responseFrame, cancellationToken);
}
}
catch (OperationCanceledException)
{
_logger.LogDebug("Request {CorrelationId} was cancelled", correlationId);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error handling request {CorrelationId}", correlationId);
}
finally
{
_inflightHandlers.TryRemove(correlationId, out _);
}
}
private void HandleCancelFrame(Frame frame)
{
if (frame.CorrelationId is null) return;
_logger.LogDebug("Received CANCEL for {CorrelationId}", frame.CorrelationId);
if (_inflightHandlers.TryGetValue(frame.CorrelationId, out var cts))
{
try
{
cts.Cancel();
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (Guid.TryParse(frame.CorrelationId, out var guid))
{
if (_pendingRequests.TryRemove(guid, out var tcs))
{
tcs.TrySetCanceled();
}
OnCancelReceived?.Invoke(guid, null);
}
}
private async Task SendFrameInternalAsync(Frame frame, CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var data = UdpFrameProtocol.SerializeFrame(frame);
if (data.Length > _options.MaxDatagramSize)
{
throw new PayloadTooLargeException(data.Length, _options.MaxDatagramSize);
}
await _client!.SendAsync(data, cancellationToken);
}
/// <inheritdoc />
public async Task<Frame> SendRequestAsync(
ConnectionState connection,
Frame requestFrame,
TimeSpan timeout,
CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var correlationId = requestFrame.CorrelationId is not null &&
Guid.TryParse(requestFrame.CorrelationId, out var parsed)
? parsed
: Guid.NewGuid();
var framedRequest = requestFrame with { CorrelationId = correlationId.ToString("N") };
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCts.CancelAfter(timeout);
var tcs = new TaskCompletionSource<Frame>(TaskCreationOptions.RunContinuationsAsynchronously);
var registration = timeoutCts.Token.Register(() =>
{
if (_pendingRequests.TryRemove(correlationId, out var pendingTcs))
{
pendingTcs.TrySetCanceled(timeoutCts.Token);
}
});
_pendingRequests[correlationId] = tcs;
try
{
await SendFrameInternalAsync(framedRequest, timeoutCts.Token);
return await tcs.Task;
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
throw new TimeoutException($"Request {correlationId} timed out after {timeout}");
}
finally
{
await registration.DisposeAsync();
_pendingRequests.TryRemove(correlationId, out _);
}
}
/// <inheritdoc />
public async Task SendCancelAsync(
ConnectionState connection,
Guid correlationId,
string? reason = null)
{
ObjectDisposedException.ThrowIf(_disposed, this);
var cancelFrame = new Frame
{
Type = FrameType.Cancel,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
// Best effort - UDP may not deliver
await SendFrameInternalAsync(cancelFrame, CancellationToken.None);
_logger.LogDebug("Sent CANCEL for {CorrelationId} (best effort)", correlationId);
}
/// <inheritdoc />
public Task SendStreamingAsync(
ConnectionState connection,
Frame requestHeader,
Stream requestBody,
Func<Stream, Task> readResponseBody,
PayloadLimits limits,
CancellationToken cancellationToken)
{
throw new NotSupportedException(
"UDP transport does not support streaming. Use TCP or TLS transport.");
}
/// <summary>
/// Sends a heartbeat.
/// </summary>
public async Task SendHeartbeatAsync(HeartbeatPayload heartbeat, CancellationToken cancellationToken)
{
var frame = new Frame
{
Type = FrameType.Heartbeat,
CorrelationId = null,
Payload = ReadOnlyMemory<byte>.Empty
};
await SendFrameInternalAsync(frame, cancellationToken);
}
/// <summary>
/// Cancels all in-flight handlers.
/// </summary>
public void CancelAllInflight(string reason)
{
var count = 0;
foreach (var cts in _inflightHandlers.Values)
{
try
{
cts.Cancel();
count++;
}
catch (ObjectDisposedException)
{
// Already completed
}
}
if (count > 0)
{
_logger.LogInformation("Cancelled {Count} in-flight handlers: {Reason}", count, reason);
}
}
/// <summary>
/// Disconnects from the gateway.
/// </summary>
public async Task DisconnectAsync()
{
CancelAllInflight("Shutdown");
// Cancel all pending requests
foreach (var kvp in _pendingRequests)
{
if (_pendingRequests.TryRemove(kvp.Key, out var tcs))
{
tcs.TrySetCanceled();
}
}
await _clientCts.CancelAsync();
if (_receiveTask is not null)
{
try
{
await _receiveTask;
}
catch
{
// Ignore
}
}
_client?.Dispose();
_logger.LogInformation("Disconnected from UDP gateway");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await DisconnectAsync();
_clientCts.Dispose();
}
}

View File

@@ -0,0 +1,50 @@
using System.Net;
namespace StellaOps.Router.Transport.Udp;
/// <summary>
/// Options for UDP transport configuration.
/// </summary>
public sealed class UdpTransportOptions
{
/// <summary>
/// Gets or sets the bind address for the server.
/// </summary>
public IPAddress BindAddress { get; set; } = IPAddress.Any;
/// <summary>
/// Gets or sets the port to listen on/connect to.
/// </summary>
public int Port { get; set; } = 5102;
/// <summary>
/// Gets or sets the host to connect to (client only).
/// </summary>
public string? Host { get; set; }
/// <summary>
/// Gets or sets the maximum datagram size in bytes.
/// Conservative default well under typical MTU of 1500 bytes.
/// </summary>
public int MaxDatagramSize { get; set; } = 8192;
/// <summary>
/// Gets or sets the default timeout for requests.
/// </summary>
public TimeSpan DefaultTimeout { get; set; } = TimeSpan.FromSeconds(5);
/// <summary>
/// Gets or sets whether to allow broadcast.
/// </summary>
public bool AllowBroadcast { get; set; } = false;
/// <summary>
/// Gets or sets the receive buffer size.
/// </summary>
public int ReceiveBufferSize { get; set; } = 64 * 1024;
/// <summary>
/// Gets or sets the send buffer size.
/// </summary>
public int SendBufferSize { get; set; } = 64 * 1024;
}

View File

@@ -0,0 +1,266 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Abstractions;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
namespace StellaOps.Router.Transport.Udp;
/// <summary>
/// UDP transport server implementation for the gateway.
/// UDP transport is stateless - connections are logical based on source endpoint.
/// </summary>
public sealed class UdpTransportServer : ITransportServer, IAsyncDisposable
{
private readonly UdpTransportOptions _options;
private readonly ILogger<UdpTransportServer> _logger;
private readonly ConcurrentDictionary<IPEndPoint, string> _endpointToConnectionId = new();
private readonly ConcurrentDictionary<string, (IPEndPoint Endpoint, ConnectionState State)> _connections = new();
private UdpClient? _listener;
private CancellationTokenSource? _serverCts;
private Task? _receiveTask;
private bool _disposed;
/// <summary>
/// Event raised when a connection is established (on first HELLO).
/// </summary>
public event Action<string, ConnectionState>? OnConnection;
/// <summary>
/// Event raised when a connection is lost.
/// </summary>
public event Action<string>? OnDisconnection;
/// <summary>
/// Event raised when a frame is received.
/// </summary>
public event Action<string, Frame>? OnFrame;
/// <summary>
/// Initializes a new instance of the <see cref="UdpTransportServer"/> class.
/// </summary>
public UdpTransportServer(
IOptions<UdpTransportOptions> options,
ILogger<UdpTransportServer> logger)
{
_options = options.Value;
_logger = logger;
}
/// <inheritdoc />
public Task StartAsync(CancellationToken cancellationToken)
{
ObjectDisposedException.ThrowIf(_disposed, this);
_serverCts = new CancellationTokenSource();
var endpoint = new IPEndPoint(_options.BindAddress, _options.Port);
_listener = new UdpClient(endpoint)
{
EnableBroadcast = _options.AllowBroadcast
};
// Configure socket buffers
_listener.Client.ReceiveBufferSize = _options.ReceiveBufferSize;
_listener.Client.SendBufferSize = _options.SendBufferSize;
_logger.LogInformation(
"UDP transport server listening on {Address}:{Port}",
_options.BindAddress,
_options.Port);
_receiveTask = ReceiveLoopAsync(_serverCts.Token);
return Task.CompletedTask;
}
private async Task ReceiveLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
var result = await _listener!.ReceiveAsync(cancellationToken);
var remoteEndpoint = result.RemoteEndPoint;
var data = result.Buffer;
if (data.Length < UdpFrameProtocol.GetHeaderSize())
{
_logger.LogWarning(
"Received datagram too small ({Size} bytes) from {Endpoint}",
data.Length,
remoteEndpoint);
continue;
}
// Parse frame
var frame = UdpFrameProtocol.ParseFrame(data);
// Get or create connection ID for this endpoint
var connectionId = _endpointToConnectionId.GetOrAdd(
remoteEndpoint,
_ => $"udp-{remoteEndpoint.Address}-{remoteEndpoint.Port}-{Guid.NewGuid():N}"[..32]);
// Handle HELLO specially to register connection
if (frame.Type == FrameType.Hello && !_connections.ContainsKey(connectionId))
{
var state = new ConnectionState
{
ConnectionId = connectionId,
Instance = new InstanceDescriptor
{
InstanceId = connectionId,
ServiceName = "unknown",
Version = "1.0.0",
Region = "default"
},
Status = InstanceHealthStatus.Healthy,
LastHeartbeatUtc = DateTime.UtcNow,
TransportType = TransportType.Udp
};
_connections[connectionId] = (remoteEndpoint, state);
_logger.LogInformation(
"UDP connection established: {ConnectionId} from {Endpoint}",
connectionId,
remoteEndpoint);
OnConnection?.Invoke(connectionId, state);
}
// Update heartbeat timestamp on HEARTBEAT frames
if (frame.Type == FrameType.Heartbeat &&
_connections.TryGetValue(connectionId, out var conn))
{
conn.State.LastHeartbeatUtc = DateTime.UtcNow;
}
OnFrame?.Invoke(connectionId, frame);
}
catch (OperationCanceledException)
{
// Expected on shutdown
break;
}
catch (ObjectDisposedException)
{
// Listener disposed
break;
}
catch (SocketException ex)
{
_logger.LogWarning(ex, "UDP socket error");
}
catch (Exception ex)
{
_logger.LogError(ex, "Error receiving UDP datagram");
}
}
}
/// <summary>
/// Sends a frame to a connection.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <param name="frame">The frame to send.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task SendFrameAsync(
string connectionId,
Frame frame,
CancellationToken cancellationToken = default)
{
ObjectDisposedException.ThrowIf(_disposed, this);
if (!_connections.TryGetValue(connectionId, out var conn))
{
throw new InvalidOperationException($"Connection {connectionId} not found");
}
var data = UdpFrameProtocol.SerializeFrame(frame);
if (data.Length > _options.MaxDatagramSize)
{
throw new PayloadTooLargeException(data.Length, _options.MaxDatagramSize);
}
await _listener!.SendAsync(data, conn.Endpoint, cancellationToken);
}
/// <summary>
/// Gets the connection state by ID.
/// </summary>
/// <param name="connectionId">The connection ID.</param>
/// <returns>The connection state, or null if not found.</returns>
public ConnectionState? GetConnectionState(string connectionId)
{
return _connections.TryGetValue(connectionId, out var conn) ? conn.State : null;
}
/// <summary>
/// Gets all active connections.
/// </summary>
public IEnumerable<ConnectionState> GetConnections() =>
_connections.Values.Select(c => c.State);
/// <summary>
/// Gets the number of active connections.
/// </summary>
public int ConnectionCount => _connections.Count;
/// <summary>
/// Removes a connection (for cleanup purposes).
/// </summary>
/// <param name="connectionId">The connection ID.</param>
public void RemoveConnection(string connectionId)
{
if (_connections.TryRemove(connectionId, out var conn))
{
_endpointToConnectionId.TryRemove(conn.Endpoint, out _);
_logger.LogInformation("UDP connection removed: {ConnectionId}", connectionId);
OnDisconnection?.Invoke(connectionId);
}
}
/// <inheritdoc />
public async Task StopAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("Stopping UDP transport server");
if (_serverCts is not null)
{
await _serverCts.CancelAsync();
}
_listener?.Close();
if (_receiveTask is not null)
{
try
{
await _receiveTask;
}
catch (OperationCanceledException)
{
// Expected
}
}
_connections.Clear();
_endpointToConnectionId.Clear();
_logger.LogInformation("UDP transport server stopped");
}
/// <inheritdoc />
public async ValueTask DisposeAsync()
{
if (_disposed) return;
_disposed = true;
await StopAsync(CancellationToken.None);
_listener?.Dispose();
_serverCts?.Dispose();
}
}

View File

@@ -0,0 +1,24 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<LangVersion>preview</LangVersion>
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<UseConcelierTestInfra>false</UseConcelierTestInfra>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.0" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="FluentAssertions" Version="6.12.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\StellaOps.Router.Transport.Tcp\StellaOps.Router.Transport.Tcp.csproj" />
<ProjectReference Include="..\..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,199 @@
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using StellaOps.Router.Common.Enums;
using StellaOps.Router.Common.Models;
using StellaOps.Router.Transport.Tcp;
using Xunit;
namespace StellaOps.Router.Transport.Tcp.Tests;
public class TcpTransportOptionsTests
{
[Fact]
public void DefaultOptions_HaveCorrectValues()
{
var options = new TcpTransportOptions();
Assert.Equal(5100, options.Port);
Assert.Equal(64 * 1024, options.ReceiveBufferSize);
Assert.Equal(64 * 1024, options.SendBufferSize);
Assert.Equal(TimeSpan.FromSeconds(30), options.KeepAliveInterval);
Assert.Equal(TimeSpan.FromSeconds(10), options.ConnectTimeout);
Assert.Equal(10, options.MaxReconnectAttempts);
Assert.Equal(TimeSpan.FromMinutes(1), options.MaxReconnectBackoff);
Assert.Equal(16 * 1024 * 1024, options.MaxFrameSize);
}
}
public class FrameProtocolTests
{
[Fact]
public async Task WriteAndReadFrame_RoundTrip()
{
// Arrange
using var stream = new MemoryStream();
var originalFrame = new Frame
{
Type = FrameType.Request,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = new byte[] { 1, 2, 3, 4, 5 }
};
// Act - Write
await FrameProtocol.WriteFrameAsync(stream, originalFrame, CancellationToken.None);
// Act - Read
stream.Position = 0;
var readFrame = await FrameProtocol.ReadFrameAsync(stream, 1024 * 1024, CancellationToken.None);
// Assert
Assert.NotNull(readFrame);
Assert.Equal(originalFrame.Type, readFrame.Type);
Assert.Equal(originalFrame.CorrelationId, readFrame.CorrelationId);
Assert.Equal(originalFrame.Payload.ToArray(), readFrame.Payload.ToArray());
}
[Fact]
public async Task WriteAndReadFrame_EmptyPayload()
{
using var stream = new MemoryStream();
var originalFrame = new Frame
{
Type = FrameType.Cancel,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
await FrameProtocol.WriteFrameAsync(stream, originalFrame, CancellationToken.None);
stream.Position = 0;
var readFrame = await FrameProtocol.ReadFrameAsync(stream, 1024 * 1024, CancellationToken.None);
Assert.NotNull(readFrame);
Assert.Equal(FrameType.Cancel, readFrame.Type);
Assert.Empty(readFrame.Payload.ToArray());
}
[Fact]
public async Task ReadFrame_ReturnsNullOnEmptyStream()
{
using var stream = new MemoryStream();
var result = await FrameProtocol.ReadFrameAsync(stream, 1024 * 1024, CancellationToken.None);
Assert.Null(result);
}
[Fact]
public async Task ReadFrame_ThrowsOnOversizedFrame()
{
using var stream = new MemoryStream();
var largeFrame = new Frame
{
Type = FrameType.Request,
CorrelationId = Guid.NewGuid().ToString("N"),
Payload = new byte[1000]
};
await FrameProtocol.WriteFrameAsync(stream, largeFrame, CancellationToken.None);
stream.Position = 0;
// Max frame size is smaller than the written frame
await Assert.ThrowsAsync<InvalidOperationException>(
() => FrameProtocol.ReadFrameAsync(stream, 100, CancellationToken.None));
}
}
public class PendingRequestTrackerTests
{
[Fact]
public async Task TrackRequest_CompletesWithResponse()
{
using var tracker = new PendingRequestTracker();
var correlationId = Guid.NewGuid();
var expectedResponse = new Frame
{
Type = FrameType.Response,
CorrelationId = correlationId.ToString("N"),
Payload = ReadOnlyMemory<byte>.Empty
};
var responseTask = tracker.TrackRequest(correlationId, CancellationToken.None);
Assert.False(responseTask.IsCompleted);
tracker.CompleteRequest(correlationId, expectedResponse);
var response = await responseTask;
Assert.Equal(expectedResponse.Type, response.Type);
}
[Fact]
public async Task TrackRequest_CancelsOnTokenCancellation()
{
using var tracker = new PendingRequestTracker();
using var cts = new CancellationTokenSource();
var correlationId = Guid.NewGuid();
var responseTask = tracker.TrackRequest(correlationId, cts.Token);
cts.Cancel();
await Assert.ThrowsAsync<TaskCanceledException>(() => responseTask);
}
[Fact]
public void Count_ReturnsCorrectValue()
{
using var tracker = new PendingRequestTracker();
Assert.Equal(0, tracker.Count);
_ = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
_ = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
Assert.Equal(2, tracker.Count);
}
[Fact]
public void CancelAll_CancelsAllPendingRequests()
{
using var tracker = new PendingRequestTracker();
var task1 = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
var task2 = tracker.TrackRequest(Guid.NewGuid(), CancellationToken.None);
tracker.CancelAll();
Assert.True(task1.IsCanceled || task1.IsFaulted);
Assert.True(task2.IsCanceled || task2.IsFaulted);
}
[Fact]
public void FailRequest_SetsException()
{
using var tracker = new PendingRequestTracker();
var correlationId = Guid.NewGuid();
var task = tracker.TrackRequest(correlationId, CancellationToken.None);
tracker.FailRequest(correlationId, new InvalidOperationException("Test error"));
Assert.True(task.IsFaulted);
Assert.IsType<InvalidOperationException>(task.Exception?.InnerException);
}
}
public class TcpTransportServerTests
{
[Fact]
public async Task StartAsync_StartsListening()
{
var options = Options.Create(new TcpTransportOptions { Port = 0 }); // Port 0 = auto-assign
await using var server = new TcpTransportServer(options, NullLogger<TcpTransportServer>.Instance);
await server.StartAsync(CancellationToken.None);
Assert.Equal(0, server.ConnectionCount);
await server.StopAsync(CancellationToken.None);
}
}

View File

@@ -0,0 +1,26 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<LangVersion>preview</LangVersion>
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<IsPackable>false</IsPackable>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<UseConcelierTestInfra>false</UseConcelierTestInfra>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.14.0" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="FluentAssertions" Version="6.12.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="10.0.0-rc.2.25502.107" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="10.0.0-rc.2.25502.107" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\StellaOps.Router.Transport.Tls\StellaOps.Router.Transport.Tls.csproj" />
<ProjectReference Include="..\..\StellaOps.Router.Common\StellaOps.Router.Common.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,302 @@
using System.Net;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using StellaOps.Router.Transport.Tls;
using Xunit;
namespace StellaOps.Router.Transport.Tls.Tests;
public class TlsTransportOptionsTests
{
[Fact]
public void DefaultOptions_HaveCorrectValues()
{
var options = new TlsTransportOptions();
Assert.Equal(5101, options.Port);
Assert.Equal(64 * 1024, options.ReceiveBufferSize);
Assert.Equal(64 * 1024, options.SendBufferSize);
Assert.Equal(TimeSpan.FromSeconds(30), options.KeepAliveInterval);
Assert.Equal(TimeSpan.FromSeconds(10), options.ConnectTimeout);
Assert.Equal(10, options.MaxReconnectAttempts);
Assert.Equal(TimeSpan.FromMinutes(1), options.MaxReconnectBackoff);
Assert.Equal(16 * 1024 * 1024, options.MaxFrameSize);
Assert.False(options.RequireClientCertificate);
Assert.False(options.AllowSelfSigned);
Assert.False(options.CheckCertificateRevocation);
Assert.Equal(SslProtocols.Tls12 | SslProtocols.Tls13, options.EnabledProtocols);
}
}
public class CertificateLoaderTests
{
[Fact]
public void LoadServerCertificate_WithDirectCertificate_ReturnsCertificate()
{
var cert = CreateSelfSignedCertificate("TestServer");
var options = new TlsTransportOptions
{
ServerCertificate = cert
};
var loaded = CertificateLoader.LoadServerCertificate(options);
Assert.Same(cert, loaded);
}
[Fact]
public void LoadServerCertificate_WithNoCertificate_ThrowsException()
{
var options = new TlsTransportOptions();
Assert.Throws<InvalidOperationException>(() => CertificateLoader.LoadServerCertificate(options));
}
[Fact]
public void LoadClientCertificate_WithNoCertificate_ReturnsNull()
{
var options = new TlsTransportOptions();
var result = CertificateLoader.LoadClientCertificate(options);
Assert.Null(result);
}
[Fact]
public void LoadClientCertificate_WithDirectCertificate_ReturnsCertificate()
{
var cert = CreateSelfSignedCertificate("TestClient");
var options = new TlsTransportOptions
{
ClientCertificate = cert
};
var loaded = CertificateLoader.LoadClientCertificate(options);
Assert.Same(cert, loaded);
}
private static X509Certificate2 CreateSelfSignedCertificate(string subject)
{
using var rsa = RSA.Create(2048);
var request = new CertificateRequest(
$"CN={subject}",
rsa,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);
request.CertificateExtensions.Add(
new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature, critical: true));
var certificate = request.CreateSelfSigned(
DateTimeOffset.UtcNow.AddMinutes(-5),
DateTimeOffset.UtcNow.AddYears(1));
// Export and re-import to get the private key
var pfxBytes = certificate.Export(X509ContentType.Pfx);
return X509CertificateLoader.LoadPkcs12(
pfxBytes,
null,
X509KeyStorageFlags.MachineKeySet);
}
}
public class TlsTransportServerTests
{
[Fact]
public async Task StartAsync_WithValidCertificate_StartsListening()
{
var cert = CreateSelfSignedCertificate("TestServer");
var options = Options.Create(new TlsTransportOptions
{
Port = 0,
ServerCertificate = cert
});
await using var server = new TlsTransportServer(options, NullLogger<TlsTransportServer>.Instance);
await server.StartAsync(CancellationToken.None);
Assert.Equal(0, server.ConnectionCount);
await server.StopAsync(CancellationToken.None);
}
[Fact]
public async Task StartAsync_WithNoCertificate_ThrowsException()
{
var options = Options.Create(new TlsTransportOptions { Port = 0 });
await using var server = new TlsTransportServer(options, NullLogger<TlsTransportServer>.Instance);
await Assert.ThrowsAsync<InvalidOperationException>(() =>
server.StartAsync(CancellationToken.None));
}
private static X509Certificate2 CreateSelfSignedCertificate(string subject)
{
using var rsa = RSA.Create(2048);
var request = new CertificateRequest(
$"CN={subject}",
rsa,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);
request.CertificateExtensions.Add(
new X509KeyUsageExtension(X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment, critical: true));
request.CertificateExtensions.Add(
new X509EnhancedKeyUsageExtension(
new OidCollection { new Oid("1.3.6.1.5.5.7.3.1") },
critical: true));
var certificate = request.CreateSelfSigned(
DateTimeOffset.UtcNow.AddMinutes(-5),
DateTimeOffset.UtcNow.AddYears(1));
var pfxBytes = certificate.Export(X509ContentType.Pfx);
return X509CertificateLoader.LoadPkcs12(
pfxBytes,
null,
X509KeyStorageFlags.MachineKeySet);
}
}
public class TlsConnectionTests
{
[Fact]
public void ConnectionId_IsSet()
{
// This is more of a documentation test since TlsConnection
// requires actual TcpClient and SslStream instances
var options = new TlsTransportOptions();
Assert.NotNull(options);
}
}
public class TlsIntegrationTests
{
[Fact]
public async Task ServerAndClient_CanEstablishConnection()
{
// Create self-signed server certificate
var serverCert = CreateSelfSignedServerCertificate("localhost");
var serverOptions = Options.Create(new TlsTransportOptions
{
Port = 0, // Auto-assign
ServerCertificate = serverCert,
RequireClientCertificate = false
});
await using var server = new TlsTransportServer(serverOptions, NullLogger<TlsTransportServer>.Instance);
await server.StartAsync(CancellationToken.None);
Assert.Equal(0, server.ConnectionCount);
await server.StopAsync(CancellationToken.None);
}
[Fact]
public async Task ServerWithMtls_RequiresClientCertificate()
{
var serverCert = CreateSelfSignedServerCertificate("localhost");
var serverOptions = Options.Create(new TlsTransportOptions
{
Port = 0,
ServerCertificate = serverCert,
RequireClientCertificate = true,
AllowSelfSigned = true
});
await using var server = new TlsTransportServer(serverOptions, NullLogger<TlsTransportServer>.Instance);
await server.StartAsync(CancellationToken.None);
Assert.True(serverOptions.Value.RequireClientCertificate);
await server.StopAsync(CancellationToken.None);
}
private static X509Certificate2 CreateSelfSignedServerCertificate(string hostname)
{
using var rsa = RSA.Create(2048);
var request = new CertificateRequest(
$"CN={hostname}",
rsa,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);
// Key usage for server auth
request.CertificateExtensions.Add(
new X509KeyUsageExtension(
X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment,
critical: true));
// Server authentication EKU
request.CertificateExtensions.Add(
new X509EnhancedKeyUsageExtension(
new OidCollection { new Oid("1.3.6.1.5.5.7.3.1") },
critical: true));
// Subject Alternative Name
var sanBuilder = new SubjectAlternativeNameBuilder();
sanBuilder.AddDnsName(hostname);
sanBuilder.AddIpAddress(IPAddress.Loopback);
request.CertificateExtensions.Add(sanBuilder.Build());
var certificate = request.CreateSelfSigned(
DateTimeOffset.UtcNow.AddMinutes(-5),
DateTimeOffset.UtcNow.AddYears(1));
var pfxBytes = certificate.Export(X509ContentType.Pfx);
return X509CertificateLoader.LoadPkcs12(
pfxBytes,
null,
X509KeyStorageFlags.MachineKeySet);
}
}
public class ServiceCollectionExtensionsTests
{
[Fact]
public void AddTlsTransportServer_RegistersServices()
{
var services = new ServiceCollection();
services.AddLogging();
services.AddTlsTransportServer(options =>
{
options.Port = 5101;
});
var provider = services.BuildServiceProvider();
var server = provider.GetService<TlsTransportServer>();
Assert.NotNull(server);
}
[Fact]
public void AddTlsTransportClient_RegistersServices()
{
var services = new ServiceCollection();
services.AddLogging();
services.AddTlsTransportClient(options =>
{
options.Host = "localhost";
options.Port = 5101;
});
var provider = services.BuildServiceProvider();
var client = provider.GetService<TlsTransportClient>();
Assert.NotNull(client);
}
}