#pragma warning disable CS0618 // ConfigureHttpMessageHandlerBuilder is obsolete - test validates legacy handler configuration
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Time.Testing;
using Microsoft.IdentityModel.Tokens;
using StellaOps.Auth.Client;
using StellaOps.AirGap.Policy;
using Xunit;
using StellaOps.TestKit;

namespace StellaOps.Auth.Client.Tests;

/// <summary>
/// Unit tests for the dependency-injection helpers in <c>ServiceCollectionExtensions</c>:
/// <c>AddStellaOpsAuthClient</c> (retry behaviour of the discovery HTTP pipeline),
/// <c>AddStellaOpsApiAuthentication</c> (PAT, client-credentials and password modes), and the
/// private <c>EnsureEgressAllowed</c> guard (invoked via reflection).
/// </summary>
public class ServiceCollectionExtensionsTests
{
    /// <summary>
    /// The first discovery request returns HTTP 500 and the second returns a discovery document.
    /// The cache must still resolve the token endpoint. Exactly two attempts are expected, and a
    /// handler whose type name contains "ResilienceHandler" must be present in the pipeline.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsAuthClient_ConfiguresRetryPolicy()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            // A single 1 ms retry delay keeps the test fast while still allowing one retry.
            options.RetryDelays.Clear();
            options.RetryDelays.Add(TimeSpan.FromMilliseconds(1));
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        var recordedHandlers = new List();
        var attemptCount = 0;

        services.AddHttpClient()
            .ConfigureHttpMessageHandlerBuilder(builder =>
            {
                // Snapshot the handlers already added to the builder when this action runs
                // (presumably including the resilience handler registered by AddStellaOpsAuthClient).
                recordedHandlers = new List(builder.AdditionalHandlers);

                // The queue is created inside the builder action, so every run of this action gets a
                // fresh "500, then discovery document" sequence.
                var responses = new Queue>(new[]
                {
                    () => CreateResponse(HttpStatusCode.InternalServerError, "{}"),
                    () => CreateResponse(HttpStatusCode.OK, "{\"token_endpoint\":\"https://authority.test/connect/token\",\"jwks_uri\":\"https://authority.test/jwks\"}")
                });

                // Replace the network with a stub that counts attempts and replays the queue.
                builder.PrimaryHandler = new LambdaHttpMessageHandler((_, _) =>
                {
                    attemptCount++;

                    // Once the scripted responses are exhausted, answer with an empty 200 body.
                    if (responses.Count == 0)
                    {
                        return Task.FromResult(CreateResponse(HttpStatusCode.OK, "{}"));
                    }

                    var factory = responses.Dequeue();
                    return Task.FromResult(factory());
                });
            });

        using var provider = services.BuildServiceProvider();

        var cache = provider.GetRequiredService();
        var configuration = await cache.GetAsync(CancellationToken.None);

        Assert.Equal(new Uri("https://authority.test/connect/token"), configuration.TokenEndpoint);
        // One failed attempt plus one successful retry.
        Assert.Equal(2, attemptCount);
        Assert.NotEmpty(recordedHandlers);
        Assert.Contains(recordedHandlers, handler => handler.GetType().Name.Contains("ResilienceHandler", StringComparison.Ordinal));
    }

    /// <summary>
    /// Invokes the private static <c>EnsureEgressAllowed(provider, options, intent)</c> via reflection.
    /// It must consult the registered egress policy exactly once, passing the component name,
    /// the authority URI and the supplied intent.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public void EnsureEgressAllowed_InvokesPolicyWhenAuthorityProvided()
    {
        var services = new ServiceCollection();
        var recordingPolicy = new RecordingPolicy();
        services.AddSingleton(recordingPolicy);
        using var provider = services.BuildServiceProvider();

        var options = new StellaOpsAuthClientOptions
        {
            Authority = "https://authority.test",
            DiscoveryCacheLifetime = TimeSpan.FromMinutes(1),
            JwksCacheLifetime = TimeSpan.FromMinutes(1),
            AllowOfflineCacheFallback = false,
        };
        // NOTE(review): presumably Validate() populates derived values (e.g. the authority URI)
        // that EnsureEgressAllowed reads; confirm.
        options.Validate();

        // The method is non-public, so it is reached through reflection; a rename breaks this test.
        var method = typeof(ServiceCollectionExtensions)
            .GetMethod("EnsureEgressAllowed", BindingFlags.NonPublic | BindingFlags.Static);

        Assert.NotNull(method);
        method!.Invoke(null, new object?[] { provider, options, "authority-discovery" });

        // RecordingPolicy records from both Evaluate and EnsureAllowed, so "single" means exactly one
        // policy call of either kind was made.
        Assert.Single(recordingPolicy.Requests);
        var request = recordingPolicy.Requests[0];
        Assert.Equal("StellaOpsAuthClient", request.Component);
        Assert.Equal(new Uri("https://authority.test"), request.Destination);
        Assert.Equal("authority-discovery", request.Intent);
    }

    /// <summary>
    /// Builds a response with the given status code and a JSON body tagged <c>application/json</c>.
    /// </summary>
    /// <param name="statusCode">HTTP status code for the response.</param>
    /// <param name="jsonContent">Raw JSON text used as the response body.</param>
    /// <returns>A new, undisposed <see cref="HttpResponseMessage"/>.</returns>
    private static HttpResponseMessage CreateResponse(HttpStatusCode statusCode, string jsonContent)
    {
        return new HttpResponseMessage(statusCode)
        {
            Content = new StringContent(jsonContent)
            {
                Headers = { ContentType = new MediaTypeHeaderValue("application/json") }
            }
        };
    }

    /// <summary>
    /// Terminal message handler that forwards every request to a caller-supplied delegate.
    /// </summary>
    private sealed class LambdaHttpMessageHandler : HttpMessageHandler
    {
        private readonly Func> responder;

        public LambdaHttpMessageHandler(Func> responder)
        {
            this.responder = responder;
        }

        protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
            => responder(request, cancellationToken);
    }

    /// <summary>
    /// In personal-access-token mode, the handler must send the PAT as a Bearer token and put the
    /// tenant in the configured custom header. It must never call the token client.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsApiAuthentication_AttachesPatAndTenantHeader()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        // Any OAuth token request made through this client throws and is counted.
        var tokenClient = new ThrowingTokenClient();
        services.AddSingleton(tokenClient);

        var handler = new RecordingHttpMessageHandler();

        services.AddHttpClient("notify")
            .ConfigurePrimaryHttpMessageHandler(() => handler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.PersonalAccessToken;
                options.PersonalAccessToken = "pat-token";
                options.Tenant = "tenant-123";
                options.TenantHeader = "X-Custom-Tenant";
            });

        using var provider = services.BuildServiceProvider();
        var client = provider.GetRequiredService().CreateClient("notify");

        var response = await client.GetAsync("https://notify.example/api");

        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        Assert.Single(handler.AuthorizationHistory);
        var authorization = handler.AuthorizationHistory[0];
        Assert.NotNull(authorization);
        Assert.Equal("Bearer", authorization!.Scheme);
        Assert.Equal("pat-token", authorization.Parameter);
        Assert.Single(handler.TenantHeaders);
        Assert.Equal("tenant-123", handler.TenantHeaders[0]);
        // PAT mode must not fall back to any OAuth flow.
        Assert.Equal(0, tokenClient.RequestCount);
    }

    /// <summary>
    /// In client-credentials mode, a token must be requested once and stored through the token
    /// client's cache. The test checks three things:
    /// <list type="bullet">
    /// <item>the token is reused across requests on the same client;</item>
    /// <item>the token is picked up from the shared cache by a second named client;</item>
    /// <item>a new token is requested once the fake clock moves past the token's expiry.</item>
    /// </list>
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsApiAuthentication_UsesClientCredentialsWithCaching()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
            options.ExpirationSkew = TimeSpan.FromSeconds(10);
        });

        // NOTE(review): DateTimeOffset.Parse without a culture is culture-sensitive (CA1305); the ISO-8601
        // literal parses fine, but new DateTimeOffset(2025, 11, 2, 0, 0, 0, TimeSpan.Zero) would be explicit.
        var fakeTime = new FakeTimeProvider(DateTimeOffset.Parse("2025-11-02T00:00:00Z"));
        services.AddSingleton(fakeTime);

        // Issues "token-1", "token-2", ... each expiring one minute after fakeTime's current time.
        var recordingTokenClient = new RecordingTokenClient(fakeTime);
        services.AddSingleton(recordingTokenClient);

        var handler = new RecordingHttpMessageHandler();

        services.AddHttpClient("notify")
            .ConfigurePrimaryHttpMessageHandler(() => handler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.ClientCredentials;
                options.Scope = "notify.read";
                options.Tenant = "tenant-oauth";
            });

        // A second named client with identical settings, used to observe reuse of the shared cache entry.
        var secondHandler = new RecordingHttpMessageHandler();

        services.AddHttpClient("notify2")
            .ConfigurePrimaryHttpMessageHandler(() => secondHandler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.ClientCredentials;
                options.Scope = "notify.read";
                options.Tenant = "tenant-oauth";
            });

        using var provider = services.BuildServiceProvider();
        var factory = provider.GetRequiredService();
        var client = factory.CreateClient("notify");

        await client.GetAsync("https://notify.example/api");
        await client.GetAsync("https://notify.example/api");

        // Only one cache lookup, token request and cache write across two requests: the second request
        // did not consult the token client's cache (presumably served from a token held by the handler).
        Assert.Equal(2, handler.AuthorizationHistory.Count);
        Assert.Equal(1, recordingTokenClient.ClientCredentialsCallCount);
        Assert.Equal(1, recordingTokenClient.GetCachedTokenCallCount);
        Assert.Equal(1, recordingTokenClient.CacheTokenCallCount);
        Assert.All(handler.AuthorizationHistory, header =>
        {
            Assert.NotNull(header);
            Assert.Equal("Bearer", header!.Scheme);
            Assert.Equal("token-1", header.Parameter);
        });
        Assert.All(handler.TenantHeaders, value => Assert.Equal("tenant-oauth", value));

        var clientTwo = factory.CreateClient("notify2");
        await clientTwo.GetAsync("https://notify.example/api");

        // The second client reads the cached "token-1" instead of requesting a new one.
        Assert.Equal(1, recordingTokenClient.ClientCredentialsCallCount);
        Assert.True(recordingTokenClient.GetCachedTokenCallCount >= 2);
        Assert.Equal(1, recordingTokenClient.CacheTokenCallCount);
        Assert.Single(secondHandler.AuthorizationHistory);
        Assert.Equal("token-1", secondHandler.AuthorizationHistory[0]!.Parameter);

        // Advance beyond expiry buffer to force refresh.
        // (Tokens expire 1 minute after issue; 2 minutes exceeds that even before the 10 s skew.)
        fakeTime.Advance(TimeSpan.FromMinutes(2));
        await client.GetAsync("https://notify.example/api");

        Assert.Equal(3, handler.AuthorizationHistory.Count);
        Assert.Equal("token-2", handler.AuthorizationHistory[^1]!.Parameter);
        Assert.Equal(2, recordingTokenClient.ClientCredentialsCallCount);
        // NOTE(review): this lower bound was already established before the refresh, so it does not
        // verify anything about the refresh path; consider tightening once the expected count is confirmed.
        Assert.True(recordingTokenClient.GetCachedTokenCallCount >= 2);
        Assert.True(recordingTokenClient.CacheTokenCallCount >= 2);
    }

    /// <summary>
    /// In password mode, a token must be requested once and reused for a second request on the same
    /// client, with one cache lookup and one cache write.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsApiAuthentication_UsesPasswordFlowWithCaching()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        var fakeTime = new FakeTimeProvider(DateTimeOffset.Parse("2025-11-02T00:00:00Z"));
        services.AddSingleton(fakeTime);
        var recordingTokenClient = new RecordingTokenClient(fakeTime);
        services.AddSingleton(recordingTokenClient);
        var handler = new RecordingHttpMessageHandler();
        services.AddHttpClient("vuln")
            .ConfigurePrimaryHttpMessageHandler(() => handler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.Password;
                options.Username = "user1";
                options.Password = "pass1";
                options.Scope = "vuln.view";
            });

        using var provider = services.BuildServiceProvider();
        var client = provider.GetRequiredService().CreateClient("vuln");

        await client.GetAsync("https://vuln.example/api");
        await client.GetAsync("https://vuln.example/api");

        Assert.Equal(2, handler.AuthorizationHistory.Count);
        Assert.Equal(1, recordingTokenClient.PasswordCallCount);
        Assert.Equal(1, recordingTokenClient.GetCachedTokenCallCount);
        Assert.Equal(1, recordingTokenClient.CacheTokenCallCount);
    }

    /// <summary>
    /// With <c>EnableRetries = false</c>, a failing discovery request must surface as an exception
    /// after exactly one attempt.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsAuthClient_DisablesRetriesWhenConfigured()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.EnableRetries = false;
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        var attemptCount = 0;

        services.AddHttpClient()
            .ConfigureHttpMessageHandlerBuilder(builder =>
            {
                // Every attempt fails with HTTP 500, so any retry would show up in attemptCount.
                builder.PrimaryHandler = new LambdaHttpMessageHandler((_, _) =>
                {
                    attemptCount++;
                    return Task.FromResult(CreateResponse(HttpStatusCode.InternalServerError, "{}"));
                });
            });

        using var provider = services.BuildServiceProvider();

        var cache = provider.GetRequiredService();

        await Assert.ThrowsAsync(() => cache.GetAsync(CancellationToken.None));
        Assert.Equal(1, attemptCount);
    }

    /// <summary>
    /// Terminal handler that records each request's Authorization header and tenant header, then
    /// answers with an empty HTTP 200.
    /// </summary>
    private sealed class RecordingHttpMessageHandler : HttpMessageHandler
    {
        /// <summary>Authorization header of each request, in arrival order (null when absent).</summary>
        public List AuthorizationHistory { get; } = new();

        /// <summary>
        /// Tenant value of each request, in arrival order. "X-Custom-Tenant" is preferred over
        /// "X-StellaOps-Tenant"; the entry is null when neither header is present.
        /// </summary>
        public List TenantHeaders { get; } = new();

        protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            AuthorizationHistory.Add(request.Headers.Authorization);

            // Single() throws if the header carries more than one value, which fails the calling test.
            if (request.Headers.TryGetValues("X-Custom-Tenant", out var customTenant))
            {
                TenantHeaders.Add(customTenant.Single());
            }
            else if (request.Headers.TryGetValues("X-StellaOps-Tenant", out var defaultTenant))
            {
                TenantHeaders.Add(defaultTenant.Single());
            }
            else
            {
                TenantHeaders.Add(null);
            }

            return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK));
        }
    }

    /// <summary>
    /// Egress policy fake that reports sealed mode, allows every request, and records each request
    /// passed to <c>Evaluate</c> or <c>EnsureAllowed</c>, including their async wrappers.
    /// </summary>
    private sealed class RecordingPolicy : IEgressPolicy
    {
        private readonly List requests = new();

        public bool IsSealed => true;

        public EgressPolicyMode Mode => EgressPolicyMode.Sealed;

        /// <summary>All requests seen so far, in call order.</summary>
        public IReadOnlyList Requests => requests;

        public EgressDecision Evaluate(EgressRequest request)
        {
            requests.Add(request);
            return EgressDecision.Allowed;
        }

        public ValueTask EvaluateAsync(EgressRequest request, CancellationToken cancellationToken = default)
            => new(Evaluate(request));

        public void EnsureAllowed(EgressRequest request)
        {
            // Never throws: every request is allowed.
            requests.Add(request);
        }

        public ValueTask EnsureAllowedAsync(EgressRequest request, CancellationToken cancellationToken = default)
        {
            EnsureAllowed(request);
            return ValueTask.CompletedTask;
        }
    }

    /// <summary>
    /// Token client fake for PAT-mode tests. Cache operations are no-ops, cache lookups always miss,
    /// and any token request is counted and then throws.
    /// </summary>
    private sealed class ThrowingTokenClient : IStellaOpsTokenClient
    {
        /// <summary>Number of client-credentials or password token requests attempted.</summary>
        public int RequestCount { get; private set; }

        public ValueTask CacheTokenAsync(string key, StellaOpsTokenCacheEntry entry, CancellationToken cancellationToken = default) => ValueTask.CompletedTask;

        public ValueTask ClearCachedTokenAsync(string key, CancellationToken cancellationToken = default) => ValueTask.CompletedTask;

        public Task GetJsonWebKeySetAsync(CancellationToken cancellationToken = default) => Task.FromResult(new JsonWebKeySet());

        public ValueTask GetCachedTokenAsync(string key, CancellationToken cancellationToken = default) => ValueTask.FromResult(null);

        public Task RequestClientCredentialsTokenAsync(string? scope = null, IReadOnlyDictionary? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            RequestCount++;
            throw new InvalidOperationException("Client credentials flow should not be invoked for PAT mode.");
        }

        public Task RequestPasswordTokenAsync(string username, string password, string? scope = null, IReadOnlyDictionary? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            RequestCount++;
            throw new InvalidOperationException("Password flow should not be invoked for PAT mode.");
        }
    }

    /// <summary>
    /// Token client fake that counts every call and issues sequential Bearer tokens
    /// ("token-1", "token-2", ...). Each token expires one minute after the fake clock's current time.
    /// Its cache is a single slot that ignores the key.
    /// </summary>
    private sealed class RecordingTokenClient : IStellaOpsTokenClient
    {
        private readonly FakeTimeProvider timeProvider;
        private int tokenCounter;
        private StellaOpsTokenCacheEntry? cachedEntry;

        public RecordingTokenClient(FakeTimeProvider timeProvider)
        {
            this.timeProvider = timeProvider;
        }

        public int ClientCredentialsCallCount { get; private set; }
        public int PasswordCallCount { get; private set; }
        public int GetCachedTokenCallCount { get; private set; }
        public int CacheTokenCallCount { get; private set; }

        public Task RequestClientCredentialsTokenAsync(string? scope = null, IReadOnlyDictionary? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            ClientCredentialsCallCount++;
            // Shared counter with the password flow, so token ids are unique across both flows.
            var tokenId = Interlocked.Increment(ref tokenCounter);
            var result = new StellaOpsTokenResult(
                $"token-{tokenId}",
                "Bearer",
                timeProvider.GetUtcNow().AddMinutes(1),
                scope is null ? Array.Empty() : new[] { scope },
                null,
                null,
                "{}");
            return Task.FromResult(result);
        }

        public Task RequestPasswordTokenAsync(string username, string password, string? scope = null, IReadOnlyDictionary? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            PasswordCallCount++;
            var tokenId = Interlocked.Increment(ref tokenCounter);
            var result = new StellaOpsTokenResult(
                $"token-{tokenId}",
                "Bearer",
                timeProvider.GetUtcNow().AddMinutes(1),
                scope is null ? Array.Empty() : new[] { scope },
                null,
                null,
                "{}");
            return Task.FromResult(result);
        }

        public Task GetJsonWebKeySetAsync(CancellationToken cancellationToken = default) => Task.FromResult(new JsonWebKeySet());

        public ValueTask GetCachedTokenAsync(string key, CancellationToken cancellationToken = default)
        {
            GetCachedTokenCallCount++;
            // Returns the stored entry regardless of key or expiry; expiry checks are left to the caller.
            return ValueTask.FromResult(cachedEntry);
        }

        public ValueTask CacheTokenAsync(string key, StellaOpsTokenCacheEntry entry, CancellationToken cancellationToken = default)
        {
            CacheTokenCallCount++;
            cachedEntry = entry;
            return ValueTask.CompletedTask;
        }

        public ValueTask ClearCachedTokenAsync(string key, CancellationToken cancellationToken = default)
        {
            // Not counted; only empties the single cache slot.
            cachedEntry = null;
            return ValueTask.CompletedTask;
        }
    }
}