// Source: git.stella-ops.org/src/Router/StellaOps.Gateway.WebService/Middleware/RouteDispatchMiddleware.cs
// (repository viewer listing: 669 lines, 24 KiB, C#)

using System.Net.WebSockets;
using System.Text.RegularExpressions;
using Microsoft.AspNetCore.StaticFiles;
using Microsoft.Extensions.FileProviders;
using StellaOps.Gateway.WebService.Configuration;
using StellaOps.Router.Gateway.Configuration;
using StellaOps.Router.Gateway;
using StellaOps.Gateway.WebService.Routing;
namespace StellaOps.Gateway.WebService.Middleware;
public sealed class RouteDispatchMiddleware
{
// Next middleware in the pipeline; invoked whenever this middleware does not handle the request itself.
private readonly RequestDelegate _next;
// Resolves a request path to a configured route (plus the regex match, when there is one).
private readonly StellaOpsRouteResolver _resolver;
// Source of the named "RouteDispatch" HttpClient used for reverse-proxied requests.
private readonly IHttpClientFactory _httpClientFactory;
private readonly ILogger<RouteDispatchMiddleware> _logger;
// Maps file names to Content-Type values when serving static files.
private readonly FileExtensionContentTypeProvider _contentTypeProvider = new();
/// <summary>
/// Hop-by-hop headers: they describe a single transport-level connection and are
/// therefore not forwarded by the reverse proxy in either direction.
/// </summary>
/// <remarks>
/// The list in RFC 2616 §13.5.1 spells one entry "Trailers", but the actual header
/// field is "Trailer". Both spellings are listed: "Trailer" so the real header is
/// stripped (the proxy buffers response bodies and never relays trailer fields),
/// and "Trailers" to keep the previous behavior for anything that matched it.
/// </remarks>
private static readonly HashSet<string> HopByHopHeaders = new(StringComparer.OrdinalIgnoreCase)
{
    "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
    "TE", "Trailer", "Trailers", "Transfer-Encoding", "Upgrade"
};
// ReverseProxy paths that are legitimate browser navigation targets (e.g. OIDC flows)
// and must NOT be redirected to the SPA fallback.
// Consulted by IsBrowserNavigation using a case-insensitive comparison.
private static readonly string[] BrowserProxyPaths = ["/connect", "/.well-known"];
// Relative-path prefixes under which a missing file still falls back to the SPA
// index.html even though the path carries a file extension (see ShouldServeSpaFallback).
// NOTE(review): the lookup is a plain case-insensitive prefix test, so "/docs" already
// covers everything "/docs/" matches — confirm whether both entries are needed.
private static readonly string[] SpaRoutesWithDocumentExtensions = ["/docs", "/docs/"];
/// <summary>
/// Creates the middleware and stores its collaborators.
/// </summary>
/// <param name="next">The next middleware in the request pipeline.</param>
/// <param name="resolver">Resolver used to match request paths to configured routes.</param>
/// <param name="httpClientFactory">Factory for the HTTP client used by reverse-proxy routes.</param>
/// <param name="logger">Logger for dispatch diagnostics.</param>
public RouteDispatchMiddleware(
    RequestDelegate next,
    StellaOpsRouteResolver resolver,
    IHttpClientFactory httpClientFactory,
    ILogger<RouteDispatchMiddleware> logger)
    => (_next, _resolver, _httpClientFactory, _logger) = (next, resolver, httpClientFactory, logger);
/// <summary>
/// Resolves the request path to a configured route and hands the request to the handler
/// for that route type. System paths and requests without a matching route are passed
/// to the next middleware untouched.
/// </summary>
/// <param name="context">The current HTTP context.</param>
public async Task InvokeAsync(HttpContext context)
{
    var request = context.Request;

    // System paths (health, metrics, openapi) bypass route dispatch
    if (GatewayRoutes.IsSystemPath(request.Path))
    {
        await _next(context);
        return;
    }

    var (route, regexMatch) = _resolver.Resolve(request.Path);
    if (route is null)
    {
        await _next(context);
        return;
    }

    // SPA fallback: a browser navigation that lands on a service route (ReverseProxy or
    // Microservice) gets the SPA index.html instead of being proxied/dispatched, so UI
    // deep links (for example "/policy") do not collide with backend route prefixes.
    // Known backend browser-navigation paths (for example OIDC /connect) are excluded
    // by IsBrowserNavigation.
    var isServiceRoute = route.Type is StellaOpsRouteType.ReverseProxy or StellaOpsRouteType.Microservice;
    if (isServiceRoute
        && IsBrowserNavigation(request)
        && _resolver.FindSpaFallbackRoute() is { } spaRoute)
    {
        _logger.LogDebug(
            "SPA fallback: serving index.html for browser navigation to {Path} (matched route type: {RouteType})",
            request.Path,
            route.Type);
        await HandleStaticFiles(context, spaRoute);
        return;
    }

    // Microservice routes only annotate the context here; the actual dispatch happens
    // further down the pipeline.
    if (route.Type == StellaOpsRouteType.Microservice)
    {
        PrepareMicroserviceRoute(context, route, regexMatch);
    }

    var dispatch = route.Type switch
    {
        StellaOpsRouteType.StaticFiles => HandleStaticFiles(context, route),
        StellaOpsRouteType.StaticFile => HandleStaticFile(context, route),
        StellaOpsRouteType.ReverseProxy => HandleReverseProxy(context, route, regexMatch),
        StellaOpsRouteType.WebSocket => HandleWebSocket(context, route),
        // Microservice (prepared above) and any other route type continue down the pipeline.
        _ => _next(context),
    };
    await dispatch;
}
/// <summary>
/// Serves a file from the directory configured in <c>route.TranslatesTo</c>. The route
/// prefix is stripped from the request path to obtain the path inside that directory.
/// When the file does not exist and the route carries the header
/// <c>x-spa-fallback: true</c>, <c>/index.html</c> is served for eligible paths instead.
/// Responds with 404 when the directory, the file and the fallback are all unavailable.
/// </summary>
/// <param name="context">The current HTTP context.</param>
/// <param name="route">The matched StaticFiles route (or the SPA fallback route).</param>
private async Task HandleStaticFiles(HttpContext context, StellaOpsRoute route)
{
    var requestPath = context.Request.Path.Value ?? string.Empty;
    var relativePath = requestPath;
    if (requestPath.StartsWith(route.Path, StringComparison.OrdinalIgnoreCase))
    {
        relativePath = requestPath[route.Path.Length..];
        if (!relativePath.StartsWith('/'))
        {
            relativePath = "/" + relativePath;
        }
    }

    var directoryPath = route.TranslatesTo!;
    if (!Directory.Exists(directoryPath))
    {
        _logger.LogWarning("StaticFiles directory not found: {Directory}", directoryPath);
        context.Response.StatusCode = StatusCodes.Status404NotFound;
        return;
    }

    // PhysicalFileProvider is IDisposable and a new instance is created for every request,
    // so it must be released when the request is done. The file streams handed out by
    // CreateReadStream are fully consumed before this method returns.
    using var fileProvider = new PhysicalFileProvider(directoryPath);

    var fileInfo = fileProvider.GetFileInfo(relativePath);
    if (fileInfo.Exists && !fileInfo.IsDirectory)
    {
        await ServeFile(context, fileInfo, relativePath);
        return;
    }

    // SPA fallback: serve index.html for paths that ShouldServeSpaFallback accepts
    var spaFallback = route.Headers.TryGetValue("x-spa-fallback", out var spaValue) &&
        string.Equals(spaValue, "true", StringComparison.OrdinalIgnoreCase);
    if (spaFallback && ShouldServeSpaFallback(relativePath))
    {
        var indexFile = fileProvider.GetFileInfo("/index.html");
        if (indexFile.Exists && !indexFile.IsDirectory)
        {
            await ServeFile(context, indexFile, "/index.html");
            return;
        }
    }

    context.Response.StatusCode = StatusCodes.Status404NotFound;
}
/// <summary>
/// Serves the single file configured in <c>route.TranslatesTo</c>, but only when the
/// request path equals the route path (case-insensitively). Responds with 404 for any
/// other path or when the file does not exist.
/// </summary>
/// <param name="context">The current HTTP context.</param>
/// <param name="route">The matched StaticFile route.</param>
private async Task HandleStaticFile(HttpContext context, StellaOpsRoute route)
{
    var requestPath = context.Request.Path.Value ?? string.Empty;

    // StaticFile serves the exact file only at the exact path
    if (!requestPath.Equals(route.Path, StringComparison.OrdinalIgnoreCase))
    {
        context.Response.StatusCode = StatusCodes.Status404NotFound;
        return;
    }

    var filePath = route.TranslatesTo!;
    if (!File.Exists(filePath))
    {
        _logger.LogWarning("StaticFile not found: {File}", filePath);
        context.Response.StatusCode = StatusCodes.Status404NotFound;
        return;
    }

    var fileName = System.IO.Path.GetFileName(filePath);
    if (!_contentTypeProvider.TryGetContentType(fileName, out var contentType))
    {
        contentType = "application/octet-stream";
    }

    // Open the file before touching the response so that the length is known and an
    // open failure does not leave a half-configured 200 response behind.
    await using var stream = File.OpenRead(filePath);

    context.Response.StatusCode = StatusCodes.Status200OK;
    context.Response.ContentType = contentType;
    // Announce the body size, as ServeFile does for directory-based static files.
    context.Response.ContentLength = stream.Length;
    await stream.CopyToAsync(context.Response.Body, context.RequestAborted);
}
/// <summary>
/// Forwards the request to the upstream configured in <c>route.TranslatesTo</c> and
/// relays the buffered upstream response to the caller.
/// </summary>
/// <remarks>
/// Responds with 504 when the upstream call is cancelled without the client having
/// aborted (a timeout), and with 502 when the upstream request fails or the upstream
/// response body cannot be read.
/// </remarks>
/// <param name="context">The current HTTP context.</param>
/// <param name="route">The matched ReverseProxy route.</param>
/// <param name="regexMatch">Regex match for the route, used to resolve <c>$1</c>, <c>$2</c>, … in the target.</param>
private async Task HandleReverseProxy(HttpContext context, StellaOpsRoute route, Match? regexMatch)
{
    var requestPath = context.Request.Path.Value ?? string.Empty;
    var resolvedTranslatesTo = ResolveCaptureGroups(route.TranslatesTo, regexMatch);
    var captureGroupsResolved = !string.Equals(resolvedTranslatesTo, route.TranslatesTo, StringComparison.Ordinal);

    var remainingPath = requestPath;
    if (captureGroupsResolved)
    {
        // Capture groups resolved: TranslatesTo already contains the full target path.
        remainingPath = string.Empty;
    }
    else if (!route.IsRegex && requestPath.StartsWith(route.Path, StringComparison.OrdinalIgnoreCase))
    {
        remainingPath = requestPath[route.Path.Length..];
    }

    var upstreamBase = resolvedTranslatesTo!.TrimEnd('/');
    var upstreamUri = new Uri($"{upstreamBase}{remainingPath}{context.Request.QueryString}");

    var client = _httpClientFactory.CreateClient("RouteDispatch");
    client.Timeout = TimeSpan.FromSeconds(30);

    var upstreamRequest = new HttpRequestMessage(new HttpMethod(context.Request.Method), upstreamUri);

    // Copy request headers (excluding hop-by-hop and Host)
    foreach (var header in context.Request.Headers)
    {
        if (HopByHopHeaders.Contains(header.Key) ||
            header.Key.Equals("Host", StringComparison.OrdinalIgnoreCase))
        {
            continue;
        }

        upstreamRequest.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
    }

    // Inject configured headers
    foreach (var (key, value) in route.Headers)
    {
        upstreamRequest.Headers.TryAddWithoutValidation(key, value);
    }

    // Forward the request body when the request declares one
    // (a non-zero Content-Length or a Content-Type).
    if (context.Request.ContentLength > 0 || context.Request.ContentType is not null)
    {
        upstreamRequest.Content = new StreamContent(context.Request.Body);
        if (context.Request.ContentType is not null)
        {
            upstreamRequest.Content.Headers.TryAddWithoutValidation("Content-Type", context.Request.ContentType);
        }
    }

    HttpResponseMessage upstreamResponse;
    try
    {
        upstreamResponse = await client.SendAsync(
            upstreamRequest,
            HttpCompletionOption.ResponseHeadersRead,
            context.RequestAborted);
    }
    catch (TaskCanceledException) when (!context.RequestAborted.IsCancellationRequested)
    {
        // Cancelled although the client is still connected: the upstream call timed out.
        _logger.LogWarning("Reverse proxy upstream request timed out for {Upstream}", upstreamUri);
        context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
        return;
    }
    catch (HttpRequestException ex)
    {
        _logger.LogError(ex, "Reverse proxy upstream request failed for {Upstream}", upstreamUri);
        context.Response.StatusCode = StatusCodes.Status502BadGateway;
        return;
    }

    using (upstreamResponse)
    {
        // Read the full response body so we can set an accurate Content-Length.
        // This is necessary because the upstream may use chunked transfer encoding
        // (which we strip as a hop-by-hop header), and without Content-Length or
        // Transfer-Encoding the downstream client cannot determine body length.
        //
        // The body is read BEFORE the upstream status and headers are copied, so that a
        // failure while reading it can still be reported as a clean 502 instead of
        // escaping as an unhandled exception with upstream headers already applied.
        byte[] body;
        try
        {
            body = await upstreamResponse.Content.ReadAsByteArrayAsync(context.RequestAborted);
        }
        catch (HttpRequestException ex)
        {
            _logger.LogError(ex, "Reverse proxy failed to read the upstream response body from {Upstream}", upstreamUri);
            context.Response.StatusCode = StatusCodes.Status502BadGateway;
            return;
        }

        context.Response.StatusCode = (int)upstreamResponse.StatusCode;

        // Copy response headers (excluding hop-by-hop and content-length, which
        // is set below from the buffered body to ensure accuracy)
        foreach (var header in upstreamResponse.Headers)
        {
            if (!HopByHopHeaders.Contains(header.Key))
            {
                context.Response.Headers[header.Key] = header.Value.ToArray();
            }
        }

        foreach (var header in upstreamResponse.Content.Headers)
        {
            if (!string.Equals(header.Key, "Content-Length", StringComparison.OrdinalIgnoreCase))
            {
                context.Response.Headers[header.Key] = header.Value.ToArray();
            }
        }

        if (body.Length > 0)
        {
            context.Response.ContentLength = body.Length;
            await context.Response.Body.WriteAsync(body, context.RequestAborted);
        }
    }
}
/// <summary>
/// Annotates <c>context.Items</c> with the routing hints derived from a Microservice
/// route: the translated request path (only when it differs from the incoming path),
/// the target microservice key (when one can be derived) and the route's default
/// timeout (when configured).
/// </summary>
/// <param name="context">The current HTTP context whose <c>Items</c> are populated.</param>
/// <param name="route">The matched Microservice route.</param>
/// <param name="regexMatch">Regex match for the route, used to resolve <c>$1</c>, <c>$2</c>, … in the target.</param>
private static void PrepareMicroserviceRoute(HttpContext context, StellaOpsRoute route, Match? regexMatch)
{
    var incomingPath = context.Request.Path.Value;
    var effectiveRoute = route;

    // For a regex route whose TranslatesTo references capture groups, work on a copy of
    // the route that carries the substituted target.
    if (regexMatch is not null && !string.IsNullOrWhiteSpace(route.TranslatesTo))
    {
        var substituted = ResolveCaptureGroups(route.TranslatesTo, regexMatch);
        var targetChanged = !string.Equals(substituted, route.TranslatesTo, StringComparison.Ordinal);
        if (targetChanged)
        {
            effectiveRoute = new StellaOpsRoute
            {
                Type = route.Type,
                Path = route.Path,
                IsRegex = route.IsRegex,
                TranslatesTo = substituted,
                DefaultTimeout = route.DefaultTimeout,
                PreserveAuthHeaders = route.PreserveAuthHeaders
            };
        }
    }

    var outgoingPath = ResolveTranslatedMicroservicePath(incomingPath, effectiveRoute);
    var pathWasTranslated = !string.Equals(outgoingPath, incomingPath, StringComparison.Ordinal);
    if (pathWasTranslated)
    {
        context.Items[RouterHttpContextKeys.TranslatedRequestPath] = outgoingPath;
    }

    var serviceKey = ResolveRouteTargetMicroservice(effectiveRoute);
    if (!string.IsNullOrWhiteSpace(serviceKey))
    {
        context.Items[RouterHttpContextKeys.RouteTargetMicroservice] = serviceKey;
    }

    if (string.IsNullOrWhiteSpace(route.DefaultTimeout))
    {
        return;
    }

    // 30 seconds is the fallback handed to the duration parser.
    context.Items[RouterHttpContextKeys.RouteDefaultTimeout] =
        GatewayValueParser.ParseDuration(route.DefaultTimeout, TimeSpan.FromSeconds(30));
}
/// <summary>
/// Replaces the placeholders <c>$1</c>, <c>$2</c>, … in <paramref name="translatesTo"/>
/// with the values of the corresponding capture groups of <paramref name="regexMatch"/>.
/// </summary>
/// <param name="translatesTo">Target template; returned unchanged when null or whitespace.</param>
/// <param name="regexMatch">The route's regex match; when null the template is returned unchanged.</param>
/// <returns>The template with every capture-group placeholder substituted.</returns>
private static string? ResolveCaptureGroups(string? translatesTo, Match? regexMatch)
{
    if (regexMatch is null || string.IsNullOrWhiteSpace(translatesTo))
    {
        return translatesTo;
    }

    var groups = regexMatch.Groups;
    var result = translatesTo;

    // Highest group number first, so that "$10" is substituted before "$1" can eat its prefix.
    var groupNumber = groups.Count;
    while (--groupNumber >= 1)
    {
        result = result.Replace("$" + groupNumber, groups[groupNumber].Value);
    }

    return result;
}
/// <summary>
/// Computes the path a request should carry when forwarded to a microservice route:
/// the route's own prefix is replaced by the path prefix taken from <c>TranslatesTo</c>.
/// </summary>
/// <param name="requestPathValue">The incoming request path; null/whitespace is treated as "/".</param>
/// <param name="route">The (possibly capture-group-resolved) route.</param>
/// <returns>
/// The request path unchanged when the route has no usable target prefix; the normalized
/// target prefix for regex routes; otherwise the target prefix followed by the remainder
/// of the request path.
/// </returns>
private static string ResolveTranslatedMicroservicePath(string? requestPathValue, StellaOpsRoute route)
{
    var requestPath = string.IsNullOrWhiteSpace(requestPathValue) ? "/" : requestPathValue!;

    if (string.IsNullOrWhiteSpace(route.TranslatesTo))
    {
        return requestPath;
    }

    var targetPrefix = ResolveTargetPathPrefix(route);
    if (string.IsNullOrWhiteSpace(targetPrefix))
    {
        return requestPath;
    }

    // For regex routes, the TranslatesTo (after capture group substitution)
    // already contains the full target path. Use it directly.
    if (route.IsRegex)
    {
        return NormalizePath(targetPrefix);
    }

    var routePrefix = NormalizePath(route.Path);
    var normalizedRequest = NormalizePath(requestPath);

    // Strip the route prefix (when present) and keep the rest rooted at "/".
    string remainder;
    if (normalizedRequest.StartsWith(routePrefix, StringComparison.OrdinalIgnoreCase))
    {
        var tail = normalizedRequest[routePrefix.Length..];
        remainder = tail.StartsWith('/') ? tail : "/" + tail;
    }
    else
    {
        remainder = normalizedRequest;
    }

    var combined = targetPrefix == "/"
        ? remainder
        : $"{targetPrefix.TrimEnd('/')}{remainder}";
    return NormalizePath(combined);
}
/// <summary>
/// Extracts the path portion of <c>route.TranslatesTo</c>: the absolute path of an
/// absolute URI, or the value itself when it parses as a relative URI.
/// </summary>
/// <param name="route">The route whose target is inspected.</param>
/// <returns>The normalized path prefix, or an empty string when none can be derived.</returns>
private static string ResolveTargetPathPrefix(StellaOpsRoute route)
{
    var target = route.TranslatesTo;
    if (string.IsNullOrWhiteSpace(target))
    {
        return string.Empty;
    }

    if (Uri.TryCreate(target, UriKind.Absolute, out var absoluteUri))
    {
        return NormalizePath(absoluteUri.AbsolutePath);
    }

    return Uri.TryCreate(target, UriKind.Relative, out _)
        ? NormalizePath(target)
        : string.Empty;
}
/// <summary>
/// Picks the microservice key for a route. The key derived from the host of
/// <c>TranslatesTo</c> wins, unless it is missing or a generic alias while the key
/// derived from the route path is a specific one.
/// </summary>
/// <param name="route">The route to inspect.</param>
/// <returns>The chosen service key, or null when neither source yields one.</returns>
private static string? ResolveRouteTargetMicroservice(StellaOpsRoute route)
{
    var keyFromHost = ExtractServiceKeyFromTranslatesTo(route.TranslatesTo);
    var keyFromPath = ExtractServiceKeyFromPath(route.Path);

    var pathIsMoreSpecific = IsGenericServiceAlias(keyFromHost) && !IsGenericServiceAlias(keyFromPath);
    return pathIsMoreSpecific
        ? keyFromPath
        : (keyFromHost ?? keyFromPath);
}
/// <summary>
/// Derives a service key from the host of <paramref name="translatesTo"/>.
/// </summary>
/// <param name="translatesTo">The route target; expected to be an absolute URI.</param>
/// <returns>The normalized host, or null when the value is blank or not an absolute URI.</returns>
private static string? ExtractServiceKeyFromTranslatesTo(string? translatesTo)
    => !string.IsNullOrWhiteSpace(translatesTo)
       && Uri.TryCreate(translatesTo, UriKind.Absolute, out var targetUri)
        ? NormalizeServiceKey(targetUri.Host)
        : null;
/// <summary>
/// Derives a service key from a route path: the third segment for paths shaped like
/// <c>/api/v1/{service}/…</c>, otherwise the first segment.
/// </summary>
/// <param name="path">The route path.</param>
/// <returns>The normalized key, or null when the path has no segments.</returns>
private static string? ExtractServiceKeyFromPath(string? path)
{
    var parts = NormalizePath(path)
        .Split('/', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
    if (parts.Length == 0)
    {
        return null;
    }

    var isVersionedApiPath =
        parts.Length >= 3 &&
        string.Equals(parts[0], "api", StringComparison.OrdinalIgnoreCase) &&
        string.Equals(parts[1], "v1", StringComparison.OrdinalIgnoreCase);

    return NormalizeServiceKey(parts[isVersionedApiPath ? 2 : 0]);
}
/// <summary>
/// Normalizes a host name or path segment into a service key: trimmed, lower-cased,
/// cut at the first ':' and stripped of a trailing <c>.stella-ops.local</c> suffix.
/// </summary>
/// <param name="value">The raw host or segment.</param>
/// <returns>The normalized key, or null when nothing is left.</returns>
private static string? NormalizeServiceKey(string? value)
{
    const string LocalDomainSuffix = ".stella-ops.local";

    if (string.IsNullOrWhiteSpace(value))
    {
        return null;
    }

    var key = value.Trim().ToLowerInvariant();

    // Drop everything from the first ':' on (e.g. a ":port" suffix).
    var colonIndex = key.IndexOf(':');
    key = colonIndex < 0 ? key : key[..colonIndex];

    if (key.EndsWith(LocalDomainSuffix, StringComparison.Ordinal))
    {
        key = key[..^LocalDomainSuffix.Length];
    }

    return string.IsNullOrWhiteSpace(key) ? null : key;
}
/// <summary>
/// Tells whether a service key carries no service-specific information: it is blank,
/// or one of "api", "web" or "service" (case-insensitive).
/// </summary>
/// <param name="value">The service key to classify.</param>
/// <returns>True for blank or generic keys; otherwise false.</returns>
private static bool IsGenericServiceAlias(string? value)
    => string.IsNullOrWhiteSpace(value)
       || string.Equals(value, "api", StringComparison.OrdinalIgnoreCase)
       || string.Equals(value, "web", StringComparison.OrdinalIgnoreCase)
       || string.Equals(value, "service", StringComparison.OrdinalIgnoreCase);
/// <summary>
/// Normalizes a path so that it is trimmed, starts with '/', and has no trailing '/'.
/// </summary>
/// <param name="value">The raw path.</param>
/// <returns>The normalized path; "/" for null, blank or slash-only input.</returns>
private static string NormalizePath(string? value)
{
    if (string.IsNullOrWhiteSpace(value))
    {
        return "/";
    }

    var trimmed = value.Trim();
    var rooted = trimmed.StartsWith('/') ? trimmed : "/" + trimmed;
    var withoutTrailingSlash = rooted.TrimEnd('/');

    return withoutTrailingSlash.Length == 0 ? "/" : withoutTrailingSlash;
}
/// <summary>
/// Relays a WebSocket connection between the caller and the upstream configured in
/// <c>route.TranslatesTo</c>.
/// </summary>
/// <remarks>
/// Responds with 400 when the request is not a WebSocket upgrade and with 502 when the
/// upstream connection cannot be established. Once both sides are connected, frames are
/// pumped in both directions until either direction ends; the other direction is then
/// cancelled and awaited before the sockets are released.
/// </remarks>
/// <param name="context">The current HTTP context.</param>
/// <param name="route">The matched WebSocket route.</param>
private async Task HandleWebSocket(HttpContext context, StellaOpsRoute route)
{
    if (!context.WebSockets.IsWebSocketRequest)
    {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        return;
    }

    var requestPath = context.Request.Path.Value ?? string.Empty;
    var remainingPath = requestPath;
    if (!route.IsRegex && requestPath.StartsWith(route.Path, StringComparison.OrdinalIgnoreCase))
    {
        remainingPath = requestPath[route.Path.Length..];
    }

    var upstreamBase = route.TranslatesTo!.TrimEnd('/');
    // Forward the query string as well, as HandleReverseProxy does for HTTP requests;
    // WebSocket clients commonly pass connection parameters there.
    var upstreamUri = new Uri($"{upstreamBase}{remainingPath}{context.Request.QueryString}");

    using var clientWebSocket = new ClientWebSocket();
    try
    {
        await clientWebSocket.ConnectAsync(upstreamUri, context.RequestAborted);
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "WebSocket upstream connection failed for {Upstream}", upstreamUri);
        context.Response.StatusCode = StatusCodes.Status502BadGateway;
        return;
    }

    using var serverWebSocket = await context.WebSockets.AcceptWebSocketAsync();
    using var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);

    var clientToServer = PumpWebSocket(serverWebSocket, clientWebSocket, cts);
    var serverToClient = PumpWebSocket(clientWebSocket, serverWebSocket, cts);

    // As soon as one direction ends, stop the other one.
    await Task.WhenAny(clientToServer, serverToClient);
    await cts.CancelAsync();

    // Wait for BOTH pumps before leaving the method: the sockets and the token source
    // are disposed on exit and must not be pulled out from under a pump that is still
    // running. Failures are only logged so that, as before, a relay error does not
    // surface as an unhandled exception.
    try
    {
        await Task.WhenAll(clientToServer, serverToClient);
    }
    catch (Exception ex)
    {
        _logger.LogDebug(ex, "WebSocket relay for {Upstream} ended with an error", upstreamUri);
    }
}
/// <summary>
/// Copies WebSocket frames from <paramref name="source"/> to <paramref name="destination"/>
/// until a close frame arrives or <paramref name="cts"/> is cancelled. A received close
/// frame is relayed to the destination when that socket can still be closed.
/// Cancellation and WebSocket errors end the pump silently.
/// </summary>
/// <param name="source">The socket to read from.</param>
/// <param name="destination">The socket to write to.</param>
/// <param name="cts">Token source whose token stops the pump.</param>
private static async Task PumpWebSocket(
    WebSocket source,
    WebSocket destination,
    CancellationTokenSource cts)
{
    var stopToken = cts.Token;
    var frame = new byte[4096];
    var receiveWindow = new ArraySegment<byte>(frame);

    try
    {
        while (!stopToken.IsCancellationRequested)
        {
            var received = await source.ReceiveAsync(receiveWindow, stopToken);

            if (received.MessageType == WebSocketMessageType.Close)
            {
                // Relay the close handshake only while the destination can still take it.
                var destinationCanClose =
                    destination.State == WebSocketState.Open ||
                    destination.State == WebSocketState.CloseReceived;
                if (destinationCanClose)
                {
                    await destination.CloseAsync(
                        received.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
                        received.CloseStatusDescription,
                        stopToken);
                }

                return;
            }

            // Frames are dropped while the destination is not open.
            if (destination.State != WebSocketState.Open)
            {
                continue;
            }

            await destination.SendAsync(
                new ArraySegment<byte>(frame, 0, received.Count),
                received.MessageType,
                received.EndOfMessage,
                stopToken);
        }
    }
    catch (OperationCanceledException)
    {
        // Expected during shutdown
    }
    catch (WebSocketException)
    {
        // Connection closed unexpectedly
    }
}
/// <summary>
/// Writes a file to the response with status 200, a Content-Type derived from
/// <paramref name="fileName"/> (falling back to <c>application/octet-stream</c>) and a
/// Content-Length taken from the file.
/// </summary>
/// <param name="context">The current HTTP context.</param>
/// <param name="fileInfo">The file to send.</param>
/// <param name="fileName">Name or path used only to look up the content type.</param>
private async Task ServeFile(HttpContext context, IFileInfo fileInfo, string fileName)
{
    var response = context.Response;
    var resolvedType = _contentTypeProvider.TryGetContentType(fileName, out var knownType)
        ? knownType
        : "application/octet-stream";

    response.StatusCode = StatusCodes.Status200OK;
    response.ContentType = resolvedType;
    response.ContentLength = fileInfo.Length;

    await using var fileStream = fileInfo.CreateReadStream();
    await fileStream.CopyToAsync(response.Body, context.RequestAborted);
}
/// <summary>
/// Determines if the request is a browser page navigation (as opposed to an XHR/fetch API call).
/// Browser navigations are GET requests that send Accept: text/html and target paths
/// without file extensions.
/// Known backend browser-navigation paths (OIDC endpoints) and API prefixes are excluded.
/// </summary>
/// <param name="request">The incoming request.</param>
/// <returns>True when the request should be treated as a browser navigation.</returns>
private static bool IsBrowserNavigation(HttpRequest request)
{
    if (!HttpMethods.IsGet(request.Method))
    {
        return false;
    }

    var path = request.Path.Value ?? string.Empty;

    // Paths with file extensions are static asset requests, not SPA navigation
    if (System.IO.Path.HasExtension(path))
    {
        return false;
    }

    // Exclude known backend paths that legitimately receive browser navigations.
    // Matching is per path segment, so "/connect" covers "/connect" and "/connect/…"
    // but not an unrelated UI route such as "/connectors".
    foreach (var excluded in BrowserProxyPaths)
    {
        if (IsPathOrChild(path, excluded))
        {
            return false;
        }
    }

    // API prefixes should continue to dispatch to backend handlers even when
    // entered directly in a browser.
    if (IsPathOrChild(path, "/api") || IsPathOrChild(path, "/v1"))
    {
        return false;
    }

    var accept = request.Headers.Accept.ToString();
    return accept.Contains("text/html", StringComparison.OrdinalIgnoreCase);

    // True when candidate equals prefix or continues it with a '/' (case-insensitive).
    static bool IsPathOrChild(string candidate, string prefix) =>
        candidate.StartsWith(prefix, StringComparison.OrdinalIgnoreCase) &&
        (candidate.Length == prefix.Length || candidate[prefix.Length] == '/');
}
/// <summary>
/// Decides whether a missing static file may be answered with the SPA index.html:
/// always for paths without a file extension, and for paths with an extension only when
/// they start with one of the <c>SpaRoutesWithDocumentExtensions</c> prefixes
/// (case-insensitive).
/// </summary>
/// <param name="relativePath">The request path relative to the static-files route.</param>
/// <returns>True when the SPA fallback applies.</returns>
private static bool ShouldServeSpaFallback(string relativePath)
    => !System.IO.Path.HasExtension(relativePath)
       || Array.Exists(
           SpaRoutesWithDocumentExtensions,
           prefix => relativePath.StartsWith(prefix, StringComparison.OrdinalIgnoreCase));
}