Files
git.stella-ops.org/src/Authority/StellaOps.Authority/StellaOps.Auth.Client.Tests/ServiceCollectionExtensionsTests.cs
StellaOps Bot 3f197814c5 save progress
2026-01-02 21:06:27 +02:00

507 lines
20 KiB
C#

#pragma warning disable CS0618 // ConfigureHttpMessageHandlerBuilder is obsolete - test validates legacy handler configuration
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Time.Testing;
using Microsoft.IdentityModel.Tokens;
using StellaOps.Auth.Client;
using StellaOps.AirGap.Policy;
using Xunit;
using StellaOps.TestKit;
namespace StellaOps.Auth.Client.Tests;
/// <summary>
/// Unit tests for the StellaOps auth-client service registration helpers: discovery-client retry
/// behaviour, the egress-policy guard, and the API authentication handler in personal-access-token,
/// client-credentials and password modes.
/// </summary>
public class ServiceCollectionExtensionsTests
{
    /// <summary>
    /// With a single 1 ms retry delay configured, a 500 from discovery is retried once and the second
    /// (successful) response is used; a resilience handler must be present in the handler pipeline.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsAuthClient_ConfiguresRetryPolicy()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.RetryDelays.Clear();
            options.RetryDelays.Add(TimeSpan.FromMilliseconds(1));
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        var recordedHandlers = new List<DelegatingHandler>();
        var attemptCount = 0;
        services.AddHttpClient<StellaOpsDiscoveryCache>()
            .ConfigureHttpMessageHandlerBuilder(builder =>
            {
                // Copy the builder's additional handlers so the test can look for the resilience handler.
                recordedHandlers = new List<DelegatingHandler>(builder.AdditionalHandlers);

                // Scripted responses: first a 500, then a valid discovery document.
                var responses = new Queue<Func<HttpResponseMessage>>(new[]
                {
                    () => CreateResponse(HttpStatusCode.InternalServerError, "{}"),
                    () => CreateResponse(HttpStatusCode.OK, "{\"token_endpoint\":\"https://authority.test/connect/token\",\"jwks_uri\":\"https://authority.test/jwks\"}")
                });

                builder.PrimaryHandler = new LambdaHttpMessageHandler((_, _) =>
                {
                    attemptCount++;

                    // Any request beyond the scripted ones gets an empty successful body.
                    if (responses.Count == 0)
                    {
                        return Task.FromResult(CreateResponse(HttpStatusCode.OK, "{}"));
                    }

                    var factory = responses.Dequeue();
                    return Task.FromResult(factory());
                });
            });

        using var provider = services.BuildServiceProvider();
        var cache = provider.GetRequiredService<StellaOpsDiscoveryCache>();

        var configuration = await cache.GetAsync(CancellationToken.None);

        Assert.Equal(new Uri("https://authority.test/connect/token"), configuration.TokenEndpoint);
        Assert.Equal(2, attemptCount);
        Assert.NotEmpty(recordedHandlers);
        Assert.Contains(recordedHandlers, handler => handler.GetType().Name.Contains("ResilienceHandler", StringComparison.Ordinal));
    }

    /// <summary>
    /// Invokes the non-public static <c>EnsureEgressAllowed</c> via reflection and checks that the
    /// registered <see cref="IEgressPolicy"/> receives exactly one request carrying the component name,
    /// the authority URI as destination, and the supplied intent.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public void EnsureEgressAllowed_InvokesPolicyWhenAuthorityProvided()
    {
        var services = new ServiceCollection();
        var recordingPolicy = new RecordingPolicy();
        services.AddSingleton<IEgressPolicy>(recordingPolicy);
        using var provider = services.BuildServiceProvider();

        var options = new StellaOpsAuthClientOptions
        {
            Authority = "https://authority.test",
            DiscoveryCacheLifetime = TimeSpan.FromMinutes(1),
            JwksCacheLifetime = TimeSpan.FromMinutes(1),
            AllowOfflineCacheFallback = false,
        };
        options.Validate();

        // The helper is not public, so reach it through reflection.
        var method = typeof(ServiceCollectionExtensions)
            .GetMethod("EnsureEgressAllowed", BindingFlags.NonPublic | BindingFlags.Static);
        Assert.NotNull(method);

        method!.Invoke(null, new object?[] { provider, options, "authority-discovery" });

        Assert.Single(recordingPolicy.Requests);
        var request = recordingPolicy.Requests[0];
        Assert.Equal("StellaOpsAuthClient", request.Component);
        Assert.Equal(new Uri("https://authority.test"), request.Destination);
        Assert.Equal("authority-discovery", request.Intent);
    }

    /// <summary>
    /// Builds a response with the given status code and a body labelled <c>application/json</c>.
    /// </summary>
    private static HttpResponseMessage CreateResponse(HttpStatusCode statusCode, string jsonContent)
    {
        return new HttpResponseMessage(statusCode)
        {
            Content = new StringContent(jsonContent)
            {
                Headers = { ContentType = new MediaTypeHeaderValue("application/json") }
            }
        };
    }

    /// <summary>
    /// Issues a GET through <paramref name="client"/>, asserts a 200 response, and disposes the response
    /// so tests that only care about side effects on the fakes do not leak <see cref="HttpResponseMessage"/>s.
    /// </summary>
    private static async Task SendGetAsync(HttpClient client, string requestUri)
    {
        using var response = await client.GetAsync(requestUri);
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
    }

    /// <summary>
    /// Primary handler that delegates every request to a caller-supplied function.
    /// </summary>
    private sealed class LambdaHttpMessageHandler : HttpMessageHandler
    {
        private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> responder;

        public LambdaHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> responder)
        {
            this.responder = responder;
        }

        protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
            => responder(request, cancellationToken);
    }

    /// <summary>
    /// In PAT mode the configured token is sent as a Bearer header together with the tenant in the
    /// custom tenant header, and the token client is never asked for a token.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsApiAuthentication_AttachesPatAndTenantHeader()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        var tokenClient = new ThrowingTokenClient();
        services.AddSingleton<IStellaOpsTokenClient>(tokenClient);

        var handler = new RecordingHttpMessageHandler();
        services.AddHttpClient("notify")
            .ConfigurePrimaryHttpMessageHandler(() => handler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.PersonalAccessToken;
                options.PersonalAccessToken = "pat-token";
                options.Tenant = "tenant-123";
                options.TenantHeader = "X-Custom-Tenant";
            });

        using var provider = services.BuildServiceProvider();
        var client = provider.GetRequiredService<IHttpClientFactory>().CreateClient("notify");

        using var response = await client.GetAsync("https://notify.example/api");

        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        Assert.Single(handler.AuthorizationHistory);
        var authorization = handler.AuthorizationHistory[0];
        Assert.NotNull(authorization);
        Assert.Equal("Bearer", authorization!.Scheme);
        Assert.Equal("pat-token", authorization.Parameter);
        Assert.Single(handler.TenantHeaders);
        Assert.Equal("tenant-123", handler.TenantHeaders[0]);
        Assert.Equal(0, tokenClient.RequestCount);
    }

    /// <summary>
    /// In client-credentials mode a token is requested once and reused across requests and across named
    /// clients sharing the same scope; once fake time moves past its expiry a new token is requested.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsApiAuthentication_UsesClientCredentialsWithCaching()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
            options.ExpirationSkew = TimeSpan.FromSeconds(10);
        });

        // Fixed UTC start instant, constructed directly rather than parsed so it is culture-independent.
        var fakeTime = new FakeTimeProvider(new DateTimeOffset(2025, 11, 2, 0, 0, 0, TimeSpan.Zero));
        services.AddSingleton<TimeProvider>(fakeTime);
        var recordingTokenClient = new RecordingTokenClient(fakeTime);
        services.AddSingleton<IStellaOpsTokenClient>(recordingTokenClient);

        var handler = new RecordingHttpMessageHandler();
        services.AddHttpClient("notify")
            .ConfigurePrimaryHttpMessageHandler(() => handler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.ClientCredentials;
                options.Scope = "notify.read";
                options.Tenant = "tenant-oauth";
            });

        var secondHandler = new RecordingHttpMessageHandler();
        services.AddHttpClient("notify2")
            .ConfigurePrimaryHttpMessageHandler(() => secondHandler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.ClientCredentials;
                options.Scope = "notify.read";
                options.Tenant = "tenant-oauth";
            });

        using var provider = services.BuildServiceProvider();
        var factory = provider.GetRequiredService<IHttpClientFactory>();

        // Two requests on the same client: one token request, one shared-cache lookup, one cache write.
        var client = factory.CreateClient("notify");
        await SendGetAsync(client, "https://notify.example/api");
        await SendGetAsync(client, "https://notify.example/api");

        Assert.Equal(2, handler.AuthorizationHistory.Count);
        Assert.Equal(1, recordingTokenClient.ClientCredentialsCallCount);
        Assert.Equal(1, recordingTokenClient.GetCachedTokenCallCount);
        Assert.Equal(1, recordingTokenClient.CacheTokenCallCount);
        Assert.All(handler.AuthorizationHistory, header =>
        {
            Assert.NotNull(header);
            Assert.Equal("Bearer", header!.Scheme);
            Assert.Equal("token-1", header.Parameter);
        });
        Assert.All(handler.TenantHeaders, value => Assert.Equal("tenant-oauth", value));

        // A second named client with the same scope must reuse the cached token rather than request one.
        var clientTwo = factory.CreateClient("notify2");
        await SendGetAsync(clientTwo, "https://notify.example/api");

        Assert.Equal(1, recordingTokenClient.ClientCredentialsCallCount);
        Assert.True(recordingTokenClient.GetCachedTokenCallCount >= 2);
        Assert.Equal(1, recordingTokenClient.CacheTokenCallCount);
        Assert.Single(secondHandler.AuthorizationHistory);
        Assert.Equal("token-1", secondHandler.AuthorizationHistory[0]!.Parameter);

        // Advance beyond expiry buffer to force refresh.
        fakeTime.Advance(TimeSpan.FromMinutes(2));
        await SendGetAsync(client, "https://notify.example/api");

        Assert.Equal(3, handler.AuthorizationHistory.Count);
        Assert.Equal("token-2", handler.AuthorizationHistory[^1]!.Parameter);
        Assert.Equal(2, recordingTokenClient.ClientCredentialsCallCount);
        Assert.True(recordingTokenClient.GetCachedTokenCallCount >= 2);
        Assert.True(recordingTokenClient.CacheTokenCallCount >= 2);
    }

    /// <summary>
    /// In password mode a token is requested once and reused for subsequent requests on the same client.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsApiAuthentication_UsesPasswordFlowWithCaching()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        // Fixed UTC start instant, constructed directly rather than parsed so it is culture-independent.
        var fakeTime = new FakeTimeProvider(new DateTimeOffset(2025, 11, 2, 0, 0, 0, TimeSpan.Zero));
        services.AddSingleton<TimeProvider>(fakeTime);
        var recordingTokenClient = new RecordingTokenClient(fakeTime);
        services.AddSingleton<IStellaOpsTokenClient>(recordingTokenClient);

        var handler = new RecordingHttpMessageHandler();
        services.AddHttpClient("vuln")
            .ConfigurePrimaryHttpMessageHandler(() => handler)
            .AddStellaOpsApiAuthentication(options =>
            {
                options.Mode = StellaOpsApiAuthMode.Password;
                options.Username = "user1";
                options.Password = "pass1";
                options.Scope = "vuln.view";
            });

        using var provider = services.BuildServiceProvider();
        var client = provider.GetRequiredService<IHttpClientFactory>().CreateClient("vuln");

        await SendGetAsync(client, "https://vuln.example/api");
        await SendGetAsync(client, "https://vuln.example/api");

        Assert.Equal(2, handler.AuthorizationHistory.Count);
        Assert.Equal(1, recordingTokenClient.PasswordCallCount);
        Assert.Equal(1, recordingTokenClient.GetCachedTokenCallCount);
        Assert.Equal(1, recordingTokenClient.CacheTokenCallCount);
    }

    /// <summary>
    /// With <c>EnableRetries = false</c>, a failing discovery request surfaces as
    /// <see cref="HttpRequestException"/> after exactly one attempt.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task AddStellaOpsAuthClient_DisablesRetriesWhenConfigured()
    {
        var services = new ServiceCollection();
        services.AddLogging();
        services.AddStellaOpsAuthClient(options =>
        {
            options.Authority = "https://authority.test";
            options.EnableRetries = false;
            options.DiscoveryCacheLifetime = TimeSpan.FromMinutes(1);
            options.JwksCacheLifetime = TimeSpan.FromMinutes(1);
            options.AllowOfflineCacheFallback = false;
        });

        var attemptCount = 0;
        services.AddHttpClient<StellaOpsDiscoveryCache>()
            .ConfigureHttpMessageHandlerBuilder(builder =>
            {
                // Every request fails with 500.
                builder.PrimaryHandler = new LambdaHttpMessageHandler((_, _) =>
                {
                    attemptCount++;
                    return Task.FromResult(CreateResponse(HttpStatusCode.InternalServerError, "{}"));
                });
            });

        using var provider = services.BuildServiceProvider();
        var cache = provider.GetRequiredService<StellaOpsDiscoveryCache>();

        await Assert.ThrowsAsync<HttpRequestException>(() => cache.GetAsync(CancellationToken.None));
        Assert.Equal(1, attemptCount);
    }

    /// <summary>
    /// Terminal handler that records each request's Authorization header and tenant header, then
    /// answers 200 OK. The tenant is read from <c>X-Custom-Tenant</c> first, then
    /// <c>X-StellaOps-Tenant</c>; <c>null</c> is recorded when neither header is present.
    /// </summary>
    private sealed class RecordingHttpMessageHandler : HttpMessageHandler
    {
        public List<AuthenticationHeaderValue?> AuthorizationHistory { get; } = new();

        public List<string?> TenantHeaders { get; } = new();

        protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            AuthorizationHistory.Add(request.Headers.Authorization);

            if (request.Headers.TryGetValues("X-Custom-Tenant", out var customTenant))
            {
                TenantHeaders.Add(customTenant.Single());
            }
            else if (request.Headers.TryGetValues("X-StellaOps-Tenant", out var defaultTenant))
            {
                TenantHeaders.Add(defaultTenant.Single());
            }
            else
            {
                TenantHeaders.Add(null);
            }

            return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK));
        }
    }

    /// <summary>
    /// Egress policy fake that reports a sealed mode, records every request passed to
    /// <c>Evaluate</c>/<c>EnsureAllowed</c> (sync or async), and always allows it.
    /// </summary>
    private sealed class RecordingPolicy : IEgressPolicy
    {
        private readonly List<EgressRequest> requests = new();

        public bool IsSealed => true;

        public EgressPolicyMode Mode => EgressPolicyMode.Sealed;

        public IReadOnlyList<EgressRequest> Requests => requests;

        public EgressDecision Evaluate(EgressRequest request)
        {
            requests.Add(request);
            return EgressDecision.Allowed;
        }

        public ValueTask<EgressDecision> EvaluateAsync(EgressRequest request, CancellationToken cancellationToken = default)
            => new(Evaluate(request));

        public void EnsureAllowed(EgressRequest request)
        {
            // Record only; never throws, so every request is effectively allowed.
            requests.Add(request);
        }

        public ValueTask EnsureAllowedAsync(EgressRequest request, CancellationToken cancellationToken = default)
        {
            EnsureAllowed(request);
            return ValueTask.CompletedTask;
        }
    }

    /// <summary>
    /// Token client for PAT-mode tests: cache operations are no-ops, and any attempt to acquire a token
    /// increments <see cref="RequestCount"/> and throws <see cref="InvalidOperationException"/>.
    /// </summary>
    private sealed class ThrowingTokenClient : IStellaOpsTokenClient
    {
        public int RequestCount { get; private set; }

        public ValueTask CacheTokenAsync(string key, StellaOpsTokenCacheEntry entry, CancellationToken cancellationToken = default)
            => ValueTask.CompletedTask;

        public ValueTask ClearCachedTokenAsync(string key, CancellationToken cancellationToken = default)
            => ValueTask.CompletedTask;

        public Task<JsonWebKeySet> GetJsonWebKeySetAsync(CancellationToken cancellationToken = default)
            => Task.FromResult(new JsonWebKeySet());

        public ValueTask<StellaOpsTokenCacheEntry?> GetCachedTokenAsync(string key, CancellationToken cancellationToken = default)
            => ValueTask.FromResult<StellaOpsTokenCacheEntry?>(null);

        public Task<StellaOpsTokenResult> RequestClientCredentialsTokenAsync(string? scope = null, IReadOnlyDictionary<string, string>? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            RequestCount++;
            throw new InvalidOperationException("Client credentials flow should not be invoked for PAT mode.");
        }

        public Task<StellaOpsTokenResult> RequestPasswordTokenAsync(string username, string password, string? scope = null, IReadOnlyDictionary<string, string>? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            RequestCount++;
            throw new InvalidOperationException("Password flow should not be invoked for PAT mode.");
        }
    }

    /// <summary>
    /// Token client fake that counts every call and keeps a single cached entry (the cache key is ignored).
    /// Each token request mints a new token named "token-1", "token-2", ..., which expires one minute
    /// after the current fake time.
    /// </summary>
    private sealed class RecordingTokenClient : IStellaOpsTokenClient
    {
        private readonly FakeTimeProvider timeProvider;
        private int tokenCounter;
        private StellaOpsTokenCacheEntry? cachedEntry;

        public RecordingTokenClient(FakeTimeProvider timeProvider)
        {
            this.timeProvider = timeProvider;
        }

        public int ClientCredentialsCallCount { get; private set; }

        public int PasswordCallCount { get; private set; }

        public int GetCachedTokenCallCount { get; private set; }

        public int CacheTokenCallCount { get; private set; }

        public Task<StellaOpsTokenResult> RequestClientCredentialsTokenAsync(string? scope = null, IReadOnlyDictionary<string, string>? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            ClientCredentialsCallCount++;
            return Task.FromResult(IssueToken(scope));
        }

        public Task<StellaOpsTokenResult> RequestPasswordTokenAsync(string username, string password, string? scope = null, IReadOnlyDictionary<string, string>? additionalParameters = null, CancellationToken cancellationToken = default)
        {
            PasswordCallCount++;
            return Task.FromResult(IssueToken(scope));
        }

        public Task<JsonWebKeySet> GetJsonWebKeySetAsync(CancellationToken cancellationToken = default)
            => Task.FromResult(new JsonWebKeySet());

        public ValueTask<StellaOpsTokenCacheEntry?> GetCachedTokenAsync(string key, CancellationToken cancellationToken = default)
        {
            GetCachedTokenCallCount++;
            return ValueTask.FromResult(cachedEntry);
        }

        public ValueTask CacheTokenAsync(string key, StellaOpsTokenCacheEntry entry, CancellationToken cancellationToken = default)
        {
            CacheTokenCallCount++;
            cachedEntry = entry;
            return ValueTask.CompletedTask;
        }

        public ValueTask ClearCachedTokenAsync(string key, CancellationToken cancellationToken = default)
        {
            cachedEntry = null;
            return ValueTask.CompletedTask;
        }

        // Shared by both grant flows: next sequential "Bearer" token, expiring one minute from fake "now",
        // whose scope list is empty when no scope was requested.
        private StellaOpsTokenResult IssueToken(string? scope)
        {
            var tokenId = Interlocked.Increment(ref tokenCounter);
            return new StellaOpsTokenResult(
                $"token-{tokenId}",
                "Bearer",
                timeProvider.GetUtcNow().AddMinutes(1),
                scope is null ? Array.Empty<string>() : new[] { scope },
                null,
                null,
                "{}");
        }
    }
}