save checkpoint
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
using StellaOps.Router.Gateway.Configuration;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
public sealed class ErrorPageFallbackMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly string? _notFoundPagePath;
|
||||
private readonly string? _serverErrorPagePath;
|
||||
private readonly ILogger<ErrorPageFallbackMiddleware> _logger;
|
||||
|
||||
public ErrorPageFallbackMiddleware(
|
||||
RequestDelegate next,
|
||||
IEnumerable<StellaOpsRoute> errorRoutes,
|
||||
ILogger<ErrorPageFallbackMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_logger = logger;
|
||||
|
||||
foreach (var route in errorRoutes)
|
||||
{
|
||||
switch (route.Type)
|
||||
{
|
||||
case StellaOpsRouteType.NotFoundPage:
|
||||
_notFoundPagePath = route.TranslatesTo;
|
||||
break;
|
||||
case StellaOpsRouteType.ServerErrorPage:
|
||||
_serverErrorPagePath = route.TranslatesTo;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async Task InvokeAsync(HttpContext context)
|
||||
{
|
||||
// Fast path: no error pages configured, skip body wrapping
|
||||
if (_notFoundPagePath is null && _serverErrorPagePath is null)
|
||||
{
|
||||
await _next(context);
|
||||
return;
|
||||
}
|
||||
|
||||
// Capture the original response body to detect status codes
|
||||
var originalBody = context.Response.Body;
|
||||
using var memoryStream = new MemoryStream();
|
||||
context.Response.Body = memoryStream;
|
||||
|
||||
try
|
||||
{
|
||||
await _next(context);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Unhandled exception in pipeline");
|
||||
context.Response.StatusCode = StatusCodes.Status500InternalServerError;
|
||||
}
|
||||
|
||||
// Check if we need to serve a custom error page
|
||||
if (context.Response.StatusCode == 404 && _notFoundPagePath is not null && memoryStream.Length == 0)
|
||||
{
|
||||
context.Response.Body = originalBody;
|
||||
await ServeErrorPage(context, _notFoundPagePath, 404);
|
||||
return;
|
||||
}
|
||||
|
||||
if (context.Response.StatusCode >= 500 && _serverErrorPagePath is not null && memoryStream.Length == 0)
|
||||
{
|
||||
context.Response.Body = originalBody;
|
||||
await ServeErrorPage(context, _serverErrorPagePath, context.Response.StatusCode);
|
||||
return;
|
||||
}
|
||||
|
||||
// No error page override, copy the original response
|
||||
memoryStream.Position = 0;
|
||||
context.Response.Body = originalBody;
|
||||
await memoryStream.CopyToAsync(originalBody, context.RequestAborted);
|
||||
}
|
||||
|
||||
private async Task ServeErrorPage(HttpContext context, string filePath, int statusCode)
|
||||
{
|
||||
if (!File.Exists(filePath))
|
||||
{
|
||||
_logger.LogWarning("Error page file not found: {FilePath}", filePath);
|
||||
context.Response.StatusCode = statusCode;
|
||||
context.Response.ContentType = "application/json; charset=utf-8";
|
||||
await context.Response.WriteAsync(
|
||||
$$"""{"error":"{{(statusCode == 404 ? "not_found" : "internal_server_error")}}","status":{{statusCode}}}""",
|
||||
context.RequestAborted);
|
||||
return;
|
||||
}
|
||||
|
||||
context.Response.StatusCode = statusCode;
|
||||
context.Response.ContentType = "text/html; charset=utf-8";
|
||||
|
||||
await using var stream = File.OpenRead(filePath);
|
||||
await stream.CopyToAsync(context.Response.Body, context.RequestAborted);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,346 @@
|
||||
using System.Net.WebSockets;
|
||||
using Microsoft.AspNetCore.StaticFiles;
|
||||
using Microsoft.Extensions.FileProviders;
|
||||
using StellaOps.Router.Gateway.Configuration;
|
||||
using StellaOps.Gateway.WebService.Routing;
|
||||
|
||||
namespace StellaOps.Gateway.WebService.Middleware;
|
||||
|
||||
public sealed class RouteDispatchMiddleware
|
||||
{
|
||||
private readonly RequestDelegate _next;
|
||||
private readonly StellaOpsRouteResolver _resolver;
|
||||
private readonly IHttpClientFactory _httpClientFactory;
|
||||
private readonly ILogger<RouteDispatchMiddleware> _logger;
|
||||
private readonly FileExtensionContentTypeProvider _contentTypeProvider = new();
|
||||
|
||||
private static readonly HashSet<string> HopByHopHeaders = new(StringComparer.OrdinalIgnoreCase)
|
||||
{
|
||||
"Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
|
||||
"TE", "Trailers", "Transfer-Encoding", "Upgrade"
|
||||
};
|
||||
|
||||
public RouteDispatchMiddleware(
|
||||
RequestDelegate next,
|
||||
StellaOpsRouteResolver resolver,
|
||||
IHttpClientFactory httpClientFactory,
|
||||
ILogger<RouteDispatchMiddleware> logger)
|
||||
{
|
||||
_next = next;
|
||||
_resolver = resolver;
|
||||
_httpClientFactory = httpClientFactory;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task InvokeAsync(HttpContext context)
|
||||
{
|
||||
// System paths (health, metrics, openapi) bypass route dispatch
|
||||
if (GatewayRoutes.IsSystemPath(context.Request.Path))
|
||||
{
|
||||
await _next(context);
|
||||
return;
|
||||
}
|
||||
|
||||
var route = _resolver.Resolve(context.Request.Path);
|
||||
if (route is null)
|
||||
{
|
||||
await _next(context);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (route.Type)
|
||||
{
|
||||
case StellaOpsRouteType.StaticFiles:
|
||||
await HandleStaticFiles(context, route);
|
||||
break;
|
||||
case StellaOpsRouteType.StaticFile:
|
||||
await HandleStaticFile(context, route);
|
||||
break;
|
||||
case StellaOpsRouteType.ReverseProxy:
|
||||
await HandleReverseProxy(context, route);
|
||||
break;
|
||||
case StellaOpsRouteType.WebSocket:
|
||||
await HandleWebSocket(context, route);
|
||||
break;
|
||||
case StellaOpsRouteType.Microservice:
|
||||
await _next(context);
|
||||
break;
|
||||
default:
|
||||
await _next(context);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleStaticFiles(HttpContext context, StellaOpsRoute route)
|
||||
{
|
||||
var requestPath = context.Request.Path.Value ?? string.Empty;
|
||||
var relativePath = requestPath;
|
||||
|
||||
if (requestPath.StartsWith(route.Path, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
relativePath = requestPath[route.Path.Length..];
|
||||
if (!relativePath.StartsWith('/'))
|
||||
{
|
||||
relativePath = "/" + relativePath;
|
||||
}
|
||||
}
|
||||
|
||||
var directoryPath = route.TranslatesTo!;
|
||||
if (!Directory.Exists(directoryPath))
|
||||
{
|
||||
_logger.LogWarning("StaticFiles directory not found: {Directory}", directoryPath);
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
return;
|
||||
}
|
||||
|
||||
var fileProvider = new PhysicalFileProvider(directoryPath);
|
||||
var fileInfo = fileProvider.GetFileInfo(relativePath);
|
||||
|
||||
if (fileInfo.Exists && !fileInfo.IsDirectory)
|
||||
{
|
||||
await ServeFile(context, fileInfo, relativePath);
|
||||
return;
|
||||
}
|
||||
|
||||
// SPA fallback: serve index.html for paths without extensions
|
||||
var spaFallback = route.Headers.TryGetValue("x-spa-fallback", out var spaValue) &&
|
||||
string.Equals(spaValue, "true", StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
if (spaFallback && !System.IO.Path.HasExtension(relativePath))
|
||||
{
|
||||
var indexFile = fileProvider.GetFileInfo("/index.html");
|
||||
if (indexFile.Exists && !indexFile.IsDirectory)
|
||||
{
|
||||
await ServeFile(context, indexFile, "/index.html");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
}
|
||||
|
||||
private async Task HandleStaticFile(HttpContext context, StellaOpsRoute route)
|
||||
{
|
||||
var requestPath = context.Request.Path.Value ?? string.Empty;
|
||||
|
||||
// StaticFile serves the exact file only at the exact path
|
||||
if (!requestPath.Equals(route.Path, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
return;
|
||||
}
|
||||
|
||||
var filePath = route.TranslatesTo!;
|
||||
if (!File.Exists(filePath))
|
||||
{
|
||||
_logger.LogWarning("StaticFile not found: {File}", filePath);
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
return;
|
||||
}
|
||||
|
||||
var fileName = System.IO.Path.GetFileName(filePath);
|
||||
if (!_contentTypeProvider.TryGetContentType(fileName, out var contentType))
|
||||
{
|
||||
contentType = "application/octet-stream";
|
||||
}
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status200OK;
|
||||
context.Response.ContentType = contentType;
|
||||
|
||||
await using var stream = File.OpenRead(filePath);
|
||||
await stream.CopyToAsync(context.Response.Body, context.RequestAborted);
|
||||
}
|
||||
|
||||
private async Task HandleReverseProxy(HttpContext context, StellaOpsRoute route)
|
||||
{
|
||||
var requestPath = context.Request.Path.Value ?? string.Empty;
|
||||
var remainingPath = requestPath;
|
||||
|
||||
if (!route.IsRegex && requestPath.StartsWith(route.Path, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
remainingPath = requestPath[route.Path.Length..];
|
||||
}
|
||||
|
||||
var upstreamBase = route.TranslatesTo!.TrimEnd('/');
|
||||
var upstreamUri = new Uri($"{upstreamBase}{remainingPath}{context.Request.QueryString}");
|
||||
|
||||
var client = _httpClientFactory.CreateClient("RouteDispatch");
|
||||
client.Timeout = TimeSpan.FromSeconds(30);
|
||||
|
||||
var upstreamRequest = new HttpRequestMessage(new HttpMethod(context.Request.Method), upstreamUri);
|
||||
|
||||
// Copy request headers (excluding hop-by-hop)
|
||||
foreach (var header in context.Request.Headers)
|
||||
{
|
||||
if (HopByHopHeaders.Contains(header.Key) ||
|
||||
header.Key.Equals("Host", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
upstreamRequest.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
|
||||
}
|
||||
|
||||
// Inject configured headers
|
||||
foreach (var (key, value) in route.Headers)
|
||||
{
|
||||
upstreamRequest.Headers.TryAddWithoutValidation(key, value);
|
||||
}
|
||||
|
||||
// Copy request body for methods that support it
|
||||
if (context.Request.ContentLength > 0 || context.Request.ContentType is not null)
|
||||
{
|
||||
upstreamRequest.Content = new StreamContent(context.Request.Body);
|
||||
if (context.Request.ContentType is not null)
|
||||
{
|
||||
upstreamRequest.Content.Headers.TryAddWithoutValidation("Content-Type", context.Request.ContentType);
|
||||
}
|
||||
}
|
||||
|
||||
HttpResponseMessage upstreamResponse;
|
||||
try
|
||||
{
|
||||
upstreamResponse = await client.SendAsync(
|
||||
upstreamRequest,
|
||||
HttpCompletionOption.ResponseHeadersRead,
|
||||
context.RequestAborted);
|
||||
}
|
||||
catch (TaskCanceledException) when (!context.RequestAborted.IsCancellationRequested)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status504GatewayTimeout;
|
||||
return;
|
||||
}
|
||||
catch (HttpRequestException ex)
|
||||
{
|
||||
_logger.LogError(ex, "Reverse proxy upstream request failed for {Upstream}", upstreamUri);
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
return;
|
||||
}
|
||||
|
||||
using (upstreamResponse)
|
||||
{
|
||||
context.Response.StatusCode = (int)upstreamResponse.StatusCode;
|
||||
|
||||
// Copy response headers
|
||||
foreach (var header in upstreamResponse.Headers)
|
||||
{
|
||||
if (!HopByHopHeaders.Contains(header.Key))
|
||||
{
|
||||
context.Response.Headers[header.Key] = header.Value.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var header in upstreamResponse.Content.Headers)
|
||||
{
|
||||
context.Response.Headers[header.Key] = header.Value.ToArray();
|
||||
}
|
||||
|
||||
// Stream response body
|
||||
await using var responseStream = await upstreamResponse.Content.ReadAsStreamAsync(context.RequestAborted);
|
||||
await responseStream.CopyToAsync(context.Response.Body, context.RequestAborted);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleWebSocket(HttpContext context, StellaOpsRoute route)
|
||||
{
|
||||
if (!context.WebSockets.IsWebSocketRequest)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status400BadRequest;
|
||||
return;
|
||||
}
|
||||
|
||||
var requestPath = context.Request.Path.Value ?? string.Empty;
|
||||
var remainingPath = requestPath;
|
||||
|
||||
if (!route.IsRegex && requestPath.StartsWith(route.Path, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
remainingPath = requestPath[route.Path.Length..];
|
||||
}
|
||||
|
||||
var upstreamBase = route.TranslatesTo!.TrimEnd('/');
|
||||
var upstreamUri = new Uri($"{upstreamBase}{remainingPath}");
|
||||
|
||||
using var clientWebSocket = new ClientWebSocket();
|
||||
try
|
||||
{
|
||||
await clientWebSocket.ConnectAsync(upstreamUri, context.RequestAborted);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "WebSocket upstream connection failed for {Upstream}", upstreamUri);
|
||||
context.Response.StatusCode = StatusCodes.Status502BadGateway;
|
||||
return;
|
||||
}
|
||||
|
||||
using var serverWebSocket = await context.WebSockets.AcceptWebSocketAsync();
|
||||
var cts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
|
||||
|
||||
var clientToServer = PumpWebSocket(serverWebSocket, clientWebSocket, cts);
|
||||
var serverToClient = PumpWebSocket(clientWebSocket, serverWebSocket, cts);
|
||||
|
||||
await Task.WhenAny(clientToServer, serverToClient);
|
||||
await cts.CancelAsync();
|
||||
}
|
||||
|
||||
private static async Task PumpWebSocket(
|
||||
WebSocket source,
|
||||
WebSocket destination,
|
||||
CancellationTokenSource cts)
|
||||
{
|
||||
var buffer = new byte[4096];
|
||||
try
|
||||
{
|
||||
while (!cts.Token.IsCancellationRequested)
|
||||
{
|
||||
var result = await source.ReceiveAsync(
|
||||
new ArraySegment<byte>(buffer),
|
||||
cts.Token);
|
||||
|
||||
if (result.MessageType == WebSocketMessageType.Close)
|
||||
{
|
||||
if (destination.State == WebSocketState.Open ||
|
||||
destination.State == WebSocketState.CloseReceived)
|
||||
{
|
||||
await destination.CloseAsync(
|
||||
result.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
|
||||
result.CloseStatusDescription,
|
||||
cts.Token);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (destination.State == WebSocketState.Open)
|
||||
{
|
||||
await destination.SendAsync(
|
||||
new ArraySegment<byte>(buffer, 0, result.Count),
|
||||
result.MessageType,
|
||||
result.EndOfMessage,
|
||||
cts.Token);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Expected during shutdown
|
||||
}
|
||||
catch (WebSocketException)
|
||||
{
|
||||
// Connection closed unexpectedly
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ServeFile(HttpContext context, IFileInfo fileInfo, string fileName)
|
||||
{
|
||||
if (!_contentTypeProvider.TryGetContentType(fileName, out var contentType))
|
||||
{
|
||||
contentType = "application/octet-stream";
|
||||
}
|
||||
|
||||
context.Response.StatusCode = StatusCodes.Status200OK;
|
||||
context.Response.ContentType = contentType;
|
||||
context.Response.ContentLength = fileInfo.Length;
|
||||
|
||||
await using var stream = fileInfo.CreateReadStream();
|
||||
await stream.CopyToAsync(context.Response.Body, context.RequestAborted);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user