using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using StellaOps.Gateway.WebService.Middleware;
using StellaOps.Router.Common.Models;
using Xunit;

namespace StellaOps.Gateway.WebService.Tests;

/// <summary>
/// Unit tests for <see cref="PayloadLimitsMiddleware"/>.
/// </summary>
public sealed class PayloadLimitsMiddlewareTests
{
    // NOTE(review): the generic argument was missing in this file. IPayloadTracker is
    // assumed from the members mocked below (TryReserve, Release, IsOverloaded,
    // CurrentInflightBytes, GetConnectionInflightBytes) - confirm against the second
    // parameter of PayloadLimitsMiddleware.Invoke.
    private readonly Mock<IPayloadTracker> _trackerMock;

    // The "next" delegate in the pipeline; it is invoked with an HttpContext and returns a Task.
    private readonly Mock<RequestDelegate> _nextMock;

    private readonly PayloadLimits _defaultLimits;

    // Set by the default _nextMock setup whenever the next delegate is invoked.
    private bool _nextCalled;

    /// <summary>
    /// Creates fresh mocks and default limits for every test (xUnit constructs the class per test).
    /// </summary>
    public PayloadLimitsMiddlewareTests()
    {
        _trackerMock = new Mock<IPayloadTracker>();
        _nextMock = new Mock<RequestDelegate>();

        _nextMock.Setup(n => n(It.IsAny<HttpContext>()))
            .Callback(() => _nextCalled = true)
            .Returns(Task.CompletedTask);

        _defaultLimits = new PayloadLimits
        {
            MaxRequestBytesPerCall = 10 * 1024 * 1024,         // 10MB
            MaxRequestBytesPerConnection = 100 * 1024 * 1024,  // 100MB
            MaxAggregateInflightBytes = 1024 * 1024 * 1024     // 1GB
        };
    }

    /// <summary>
    /// Builds the middleware under test around the mocked next delegate.
    /// </summary>
    /// <param name="limits">Limits to apply; the default limits are used when <c>null</c>.</param>
    /// <returns>A new <see cref="PayloadLimitsMiddleware"/> instance.</returns>
    private PayloadLimitsMiddleware CreateMiddleware(PayloadLimits? limits = null)
    {
        // NullLogger<T> satisfies both ILogger<T> and ILogger constructor parameters.
        return new PayloadLimitsMiddleware(
            _nextMock.Object,
            Options.Create(limits ?? _defaultLimits),
            NullLogger<PayloadLimitsMiddleware>.Instance);
    }

    /// <summary>
    /// Creates an <see cref="HttpContext"/> with in-memory request/response bodies.
    /// </summary>
    /// <param name="contentLength">Value for the request Content-Length; left unset when <c>null</c>.</param>
    /// <param name="connectionId">Connection identifier assigned to the context.</param>
    /// <returns>The configured context.</returns>
    private static HttpContext CreateHttpContext(long? contentLength = null, string connectionId = "conn-1")
    {
        var context = new DefaultHttpContext();
        context.Response.Body = new MemoryStream();
        context.Request.Body = new MemoryStream();
        context.Connection.Id = connectionId;

        if (contentLength.HasValue)
        {
            context.Request.ContentLength = contentLength;
        }

        return context;
    }

    /// <summary>
    /// Rewinds the response body stream and reads everything written to it as text.
    /// </summary>
    /// <param name="context">Context whose response body should be read.</param>
    /// <returns>The full response body.</returns>
    private static async Task<string> ReadResponseBodyAsync(HttpContext context)
    {
        // The body position is at the end of whatever was written, so rewind before reading.
        context.Response.Body.Seek(0, SeekOrigin.Begin);
        using var reader = new StreamReader(context.Response.Body);
        return await reader.ReadToEndAsync();
    }

    #region Within Limits Tests

    // A request below every limit, with the reservation accepted, reaches the next delegate.
    [Fact]
    public async Task Invoke_WithinLimits_CallsNext()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(true);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeTrue();
    }

    // A request without Content-Length is expected to reserve 0 bytes and continue.
    [Fact]
    public async Task Invoke_WithNoContentLength_CallsNext()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: null);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 0))
            .Returns(true);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeTrue();
    }

    // An explicit Content-Length of 0 reserves 0 bytes and continues.
    [Fact]
    public async Task Invoke_WithZeroContentLength_CallsNext()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 0);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 0))
            .Returns(true);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeTrue();
    }

    #endregion

    #region Per-Call Limit Tests

    // Content-Length above MaxRequestBytesPerCall short-circuits with 413.
    [Fact]
    public async Task Invoke_ExceedsPerCallLimit_Returns413()
    {
        // Arrange
        var limits = new PayloadLimits { MaxRequestBytesPerCall = 1000 };
        var middleware = CreateMiddleware(limits);
        var context = CreateHttpContext(contentLength: 2000);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeFalse();
        context.Response.StatusCode.Should().Be(StatusCodes.Status413PayloadTooLarge);
    }

    // The 413 body names the error and includes both the limit and the actual size.
    [Fact]
    public async Task Invoke_ExceedsPerCallLimit_WritesErrorResponse()
    {
        // Arrange
        var limits = new PayloadLimits { MaxRequestBytesPerCall = 1000 };
        var middleware = CreateMiddleware(limits);
        var context = CreateHttpContext(contentLength: 2000);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        var responseBody = await ReadResponseBodyAsync(context);

        responseBody.Should().Contain("Payload Too Large");
        responseBody.Should().Contain("1000");
        responseBody.Should().Contain("2000");
    }

    // Boundary case: a request exactly at the per-call limit is still allowed through.
    [Fact]
    public async Task Invoke_ExactlyAtPerCallLimit_CallsNext()
    {
        // Arrange
        var limits = new PayloadLimits { MaxRequestBytesPerCall = 1000 };
        var middleware = CreateMiddleware(limits);
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(true);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeTrue();
    }

    #endregion

    #region Aggregate Limit Tests

    // Reservation refused while the tracker reports overload maps to 503.
    [Fact]
    public async Task Invoke_ExceedsAggregateLimit_Returns503()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(false);
        _trackerMock.Setup(t => t.IsOverloaded)
            .Returns(true);
        _trackerMock.Setup(t => t.CurrentInflightBytes)
            .Returns(1024 * 1024 * 1024); // 1GB

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeFalse();
        context.Response.StatusCode.Should().Be(StatusCodes.Status503ServiceUnavailable);
    }

    // The overload response body mentions "Overloaded".
    [Fact]
    public async Task Invoke_ExceedsAggregateLimit_WritesOverloadedResponse()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(false);
        _trackerMock.Setup(t => t.IsOverloaded)
            .Returns(true);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        var responseBody = await ReadResponseBodyAsync(context);

        responseBody.Should().Contain("Overloaded");
    }

    #endregion

    #region Per-Connection Limit Tests

    // Reservation refused while the tracker is NOT overloaded maps to 429.
    [Fact]
    public async Task Invoke_ExceedsPerConnectionLimit_Returns429()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(false);
        _trackerMock.Setup(t => t.IsOverloaded)
            .Returns(false); // Not aggregate limit
        _trackerMock.Setup(t => t.GetConnectionInflightBytes("conn-1"))
            .Returns(100 * 1024 * 1024); // 100MB

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        _nextCalled.Should().BeFalse();
        context.Response.StatusCode.Should().Be(StatusCodes.Status429TooManyRequests);
    }

    // The per-connection rejection body mentions "Too Many Requests".
    [Fact]
    public async Task Invoke_ExceedsPerConnectionLimit_WritesErrorResponse()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(false);
        _trackerMock.Setup(t => t.IsOverloaded)
            .Returns(false);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        var responseBody = await ReadResponseBodyAsync(context);

        responseBody.Should().Contain("Too Many Requests");
    }

    #endregion

    #region Release Tests

    // After a successful pass through the pipeline the reservation is released exactly once.
    [Fact]
    public async Task Invoke_AfterSuccess_ReleasesReservation()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(true);

        // Act
        await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        // NOTE(review): byte counts are assumed to be long (matches HttpRequest.ContentLength) - confirm.
        _trackerMock.Verify(t => t.Release("conn-1", It.IsAny<long>()), Times.Once);
    }

    // The reservation is released even when the next delegate throws.
    [Fact]
    public async Task Invoke_AfterNextThrows_StillReleasesReservation()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context = CreateHttpContext(contentLength: 1000);

        _trackerMock.Setup(t => t.TryReserve("conn-1", 1000))
            .Returns(true);

        // Overrides the default setup from the constructor so the next delegate faults.
        _nextMock.Setup(n => n(It.IsAny<HttpContext>()))
            .ThrowsAsync(new InvalidOperationException("Test error"));

        // Act
        var act = async () => await middleware.Invoke(context, _trackerMock.Object);

        // Assert
        await act.Should().ThrowAsync<InvalidOperationException>();
        _trackerMock.Verify(t => t.Release("conn-1", It.IsAny<long>()), Times.Once);
    }

    #endregion

    #region Different Connections Tests

    // Each request reserves against its own connection id with its own content length.
    [Fact]
    public async Task Invoke_DifferentConnections_TrackedSeparately()
    {
        // Arrange
        var middleware = CreateMiddleware();
        var context1 = CreateHttpContext(contentLength: 1000, connectionId: "conn-1");
        var context2 = CreateHttpContext(contentLength: 2000, connectionId: "conn-2");

        _trackerMock.Setup(t => t.TryReserve(It.IsAny<string>(), It.IsAny<long>()))
            .Returns(true);

        // Act
        await middleware.Invoke(context1, _trackerMock.Object);
        await middleware.Invoke(context2, _trackerMock.Object);

        // Assert
        _trackerMock.Verify(t => t.TryReserve("conn-1", 1000), Times.Once);
        _trackerMock.Verify(t => t.TryReserve("conn-2", 2000), Times.Once);
    }

    #endregion
}