283 lines
9.6 KiB
C#
283 lines
9.6 KiB
C#
using System.Security.Cryptography;
|
|
using System.Text.Json;
|
|
using Microsoft.Extensions.Logging;
|
|
using StellaOps.Messaging.Abstractions;
|
|
using StackExchange.Redis;
|
|
|
|
namespace StellaOps.Messaging.Transport.Valkey;
|
|
|
|
/// <summary>
/// Valkey/Redis implementation of <see cref="IAtomicTokenStore{TPayload}"/>.
/// Uses Lua scripts for atomic compare-and-delete operations.
/// </summary>
/// <typeparam name="TPayload">Payload stored alongside each token; serialized with System.Text.Json.</typeparam>
/// <remarks>
/// Each entry is a single JSON string stored under <c>token:{name}:{key}</c> with a Redis TTL equal to the
/// requested lifetime. The token field is always written as <c>"token"</c>, regardless of the configured
/// <see cref="JsonSerializerOptions.PropertyNamingPolicy"/>, because the server-side consume script reads
/// that field by name.
/// </remarks>
public sealed class ValkeyAtomicTokenStore<TPayload> : IAtomicTokenStore<TPayload>
{
    private readonly ValkeyConnectionFactory _connectionFactory;
    private readonly string _name;
    private readonly ILogger<ValkeyAtomicTokenStore<TPayload>>? _logger;
    private readonly JsonSerializerOptions _jsonOptions;
    private readonly TimeProvider _timeProvider;

    // Lua script for atomic consume: GET, compare, DELETE if matches.
    // Replies: {0} = key missing, {1, value} = token matched and key deleted,
    // {2} = token mismatch (key left untouched; the stored value is not sent back).
    // The field name "token" must stay in sync with TokenData.Token's [JsonPropertyName].
    private const string ConsumeScript = @"
local value = redis.call('GET', KEYS[1])
if not value then
    return {0}
end
local data = cjson.decode(value)
if data.token ~= ARGV[1] then
    return {2}
end
redis.call('DEL', KEYS[1])
return {1, value}
";

    /// <summary>
    /// Creates a token store whose keys are namespaced by <paramref name="name"/>.
    /// </summary>
    /// <param name="connectionFactory">Source of the Valkey database connection.</param>
    /// <param name="name">Store name, embedded in every key as <c>token:{name}:{key}</c>.</param>
    /// <param name="logger">Optional logger; used to report a fallback to the non-atomic consume path.</param>
    /// <param name="jsonOptions">Serializer options; defaults to camelCase, non-indented JSON.</param>
    /// <param name="timeProvider">Clock for issue/expiry timestamps; defaults to <see cref="TimeProvider.System"/>.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="connectionFactory"/> or <paramref name="name"/> is <c>null</c>.
    /// </exception>
    public ValkeyAtomicTokenStore(
        ValkeyConnectionFactory connectionFactory,
        string name,
        ILogger<ValkeyAtomicTokenStore<TPayload>>? logger = null,
        JsonSerializerOptions? jsonOptions = null,
        TimeProvider? timeProvider = null)
    {
        _connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory));
        _name = name ?? throw new ArgumentNullException(nameof(name));
        _logger = logger;
        _jsonOptions = jsonOptions ?? new JsonSerializerOptions
        {
            PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
            WriteIndented = false
        };
        _timeProvider = timeProvider ?? TimeProvider.System;
    }

    /// <inheritdoc />
    public string ProviderName => "valkey";

    /// <inheritdoc />
    /// <remarks>
    /// Generates a 256-bit token from <see cref="RandomNumberGenerator"/> (Base64-encoded) and overwrites any
    /// existing entry for <paramref name="key"/>.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="ttl"/> is zero or negative.</exception>
    public async ValueTask<TokenIssueResult> IssueAsync(
        string key,
        TPayload payload,
        TimeSpan ttl,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);
        // Redis rejects non-positive SET expirations; fail fast with a clear argument error instead.
        ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(ttl, TimeSpan.Zero);

        var token = Convert.ToBase64String(RandomNumberGenerator.GetBytes(32));

        return await WriteEntryAsync(key, token, payload, ttl, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    /// <remarks>Stores a caller-supplied token, overwriting any existing entry for <paramref name="key"/>.</remarks>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="token"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="ttl"/> is zero or negative.</exception>
    public async ValueTask<TokenIssueResult> StoreAsync(
        string key,
        string token,
        TPayload payload,
        TimeSpan ttl,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);
        ArgumentNullException.ThrowIfNull(token);
        ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(ttl, TimeSpan.Zero);

        return await WriteEntryAsync(key, token, payload, ttl, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    /// <remarks>
    /// Runs the compare-and-delete atomically on the server. A matching token removes the entry, so a second
    /// consume of the same key reports not-found.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="expectedToken"/> is <c>null</c>.</exception>
    public async ValueTask<TokenConsumeResult<TPayload>> TryConsumeAsync(
        string key,
        string expectedToken,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);
        ArgumentNullException.ThrowIfNull(expectedToken);

        var redisKey = BuildKey(key);
        var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
        var now = _timeProvider.GetUtcNow();

        RedisResult[] reply;
        try
        {
            var result = await db.ScriptEvaluateAsync(
                ConsumeScript,
                new RedisKey[] { redisKey },
                new RedisValue[] { expectedToken }).ConfigureAwait(false);

            reply = (RedisResult[])result!;
        }
        catch (RedisServerException ex) when (ex.Message.Contains("NOSCRIPT", StringComparison.Ordinal))
        {
            _logger?.LogWarning(
                ex,
                "Valkey token store {StoreName} could not run the consume script; falling back to the non-atomic consume path.",
                _name);

            return await TryConsumeNonAtomicAsync(db, redisKey, expectedToken, now).ConfigureAwait(false);
        }

        return (int)reply[0] switch
        {
            0 => TokenConsumeResult<TPayload>.NotFound(),
            1 => InterpretConsumedValue((string)reply[1]!, now),
            2 => TokenConsumeResult<TPayload>.Mismatch(),
            _ => TokenConsumeResult<TPayload>.NotFound(),
        };
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public async ValueTask<bool> ExistsAsync(string key, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);

        var redisKey = BuildKey(key);
        var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
        return await db.KeyExistsAsync(redisKey).ConfigureAwait(false);
    }

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public async ValueTask<bool> RevokeAsync(string key, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);

        var redisKey = BuildKey(key);
        var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
        return await db.KeyDeleteAsync(redisKey).ConfigureAwait(false);
    }

    /// <summary>
    /// Serializes a token entry and writes it under the namespaced key with a Redis TTL of <paramref name="ttl"/>,
    /// overwriting any existing value. Shared by <see cref="IssueAsync"/> and <see cref="StoreAsync"/>.
    /// </summary>
    private async ValueTask<TokenIssueResult> WriteEntryAsync(
        string key,
        string token,
        TPayload payload,
        TimeSpan ttl,
        CancellationToken cancellationToken)
    {
        var redisKey = BuildKey(key);
        var now = _timeProvider.GetUtcNow();
        var expiresAt = now.Add(ttl);

        var entry = new TokenData<TPayload>
        {
            Token = token,
            Payload = payload,
            IssuedAt = now,
            ExpiresAt = expiresAt
        };

        var serialized = JsonSerializer.Serialize(entry, _jsonOptions);
        var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);

        await db.StringSetAsync(redisKey, serialized, ttl).ConfigureAwait(false);

        return TokenIssueResult.Succeeded(token, expiresAt);
    }

    /// <summary>
    /// Maps the JSON of an entry that the consume script has already deleted to a consume result.
    /// The logical expiry is still checked here because it is tracked independently of the Redis TTL.
    /// </summary>
    private TokenConsumeResult<TPayload> InterpretConsumedValue(string json, DateTimeOffset now)
    {
        var data = JsonSerializer.Deserialize<TokenData<TPayload>>(json, _jsonOptions);
        if (data is null)
        {
            return TokenConsumeResult<TPayload>.NotFound();
        }

        if (data.ExpiresAt < now)
        {
            return TokenConsumeResult<TPayload>.Expired(data.IssuedAt, data.ExpiresAt);
        }

        return TokenConsumeResult<TPayload>.Success(data.Payload!, data.IssuedAt, data.ExpiresAt);
    }

    /// <summary>
    /// Consume path used when the Lua script cannot run. It reads the entry client-side, then deletes it
    /// only if the stored value is unchanged. A concurrent consume, or a re-store under the same key, therefore
    /// causes this call to report not-found instead of deleting someone else's entry.
    /// </summary>
    private async ValueTask<TokenConsumeResult<TPayload>> TryConsumeNonAtomicAsync(
        IDatabase db,
        string redisKey,
        string expectedToken,
        DateTimeOffset now)
    {
        var value = await db.StringGetAsync(redisKey).ConfigureAwait(false);
        if (value.IsNullOrEmpty)
        {
            return TokenConsumeResult<TPayload>.NotFound();
        }

        var data = JsonSerializer.Deserialize<TokenData<TPayload>>((string)value!, _jsonOptions);
        if (data is null)
        {
            return TokenConsumeResult<TPayload>.NotFound();
        }

        if (data.ExpiresAt < now)
        {
            // Best-effort cleanup; never remove an entry that was replaced after our read.
            await DeleteIfUnchangedAsync(db, redisKey, value).ConfigureAwait(false);
            return TokenConsumeResult<TPayload>.Expired(data.IssuedAt, data.ExpiresAt);
        }

        if (!TokensEqual(data.Token, expectedToken))
        {
            return TokenConsumeResult<TPayload>.Mismatch();
        }

        // Only the caller whose conditional delete commits wins the consume race.
        if (await DeleteIfUnchangedAsync(db, redisKey, value).ConfigureAwait(false))
        {
            return TokenConsumeResult<TPayload>.Success(data.Payload!, data.IssuedAt, data.ExpiresAt);
        }

        return TokenConsumeResult<TPayload>.NotFound();
    }

    /// <summary>
    /// Deletes <paramref name="redisKey"/> only if it still holds <paramref name="expected"/>, using a
    /// conditional transaction.
    /// </summary>
    /// <returns><c>true</c> if this call deleted the key; otherwise <c>false</c>.</returns>
    private static async Task<bool> DeleteIfUnchangedAsync(IDatabase db, string redisKey, RedisValue expected)
    {
        var transaction = db.CreateTransaction();
        transaction.AddCondition(Condition.StringEqual(redisKey, expected));
        var deleteTask = transaction.KeyDeleteAsync(redisKey);

        // Queued commands complete only when the transaction executes. If the condition fails, the queued
        // task is cancelled, so it is awaited only after a successful commit.
        var committed = await transaction.ExecuteAsync().ConfigureAwait(false);
        return committed && await deleteTask.ConfigureAwait(false);
    }

    /// <summary>
    /// Ordinal token comparison whose running time does not depend on where the first differing
    /// character is (lengths may still differ observably).
    /// </summary>
    private static bool TokensEqual(string? stored, string expected) =>
        stored is not null
        && CryptographicOperations.FixedTimeEquals(
            System.Runtime.InteropServices.MemoryMarshal.AsBytes(stored.AsSpan()),
            System.Runtime.InteropServices.MemoryMarshal.AsBytes(expected.AsSpan()));

    /// <summary>Builds the namespaced Redis key <c>token:{name}:{key}</c>.</summary>
    private string BuildKey(string key) => $"token:{_name}:{key}";

    /// <summary>JSON shape persisted for each token entry.</summary>
    private sealed class TokenData<T>
    {
        // Pinned name: the Lua consume script reads `data.token` regardless of the configured naming policy.
        [System.Text.Json.Serialization.JsonPropertyName("token")]
        public required string Token { get; init; }

        public T? Payload { get; init; }

        public DateTimeOffset IssuedAt { get; init; }

        public DateTimeOffset ExpiresAt { get; init; }
    }
}
|
|
|
|
/// <summary>
/// Factory for creating Valkey atomic token store instances.
/// </summary>
/// <remarks>
/// Every store it creates shares the same connection factory, serializer options and clock. Each store
/// gets its own typed logger, created only when a logger factory was supplied.
/// </remarks>
public sealed class ValkeyAtomicTokenStoreFactory : IAtomicTokenStoreFactory
{
    private readonly ValkeyConnectionFactory _connectionFactory;
    private readonly ILoggerFactory? _loggerFactory;
    private readonly JsonSerializerOptions? _jsonOptions;
    private readonly TimeProvider _timeProvider;

    /// <summary>
    /// Captures the shared dependencies handed to every store this factory creates.
    /// </summary>
    /// <param name="connectionFactory">Valkey connection source shared by all created stores.</param>
    /// <param name="loggerFactory">Optional source of per-store loggers; when <c>null</c>, stores get no logger.</param>
    /// <param name="jsonOptions">Optional serializer options passed through unchanged (stores apply their own default when <c>null</c>).</param>
    /// <param name="timeProvider">Clock for created stores; defaults to <see cref="TimeProvider.System"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="connectionFactory"/> is <c>null</c>.</exception>
    public ValkeyAtomicTokenStoreFactory(
        ValkeyConnectionFactory connectionFactory,
        ILoggerFactory? loggerFactory = null,
        JsonSerializerOptions? jsonOptions = null,
        TimeProvider? timeProvider = null)
    {
        ArgumentNullException.ThrowIfNull(connectionFactory);

        _connectionFactory = connectionFactory;
        _loggerFactory = loggerFactory;
        _jsonOptions = jsonOptions;
        _timeProvider = timeProvider ?? TimeProvider.System;
    }

    /// <inheritdoc />
    public string ProviderName => "valkey";

    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="name"/> is <c>null</c>.</exception>
    public IAtomicTokenStore<TPayload> Create<TPayload>(string name)
    {
        ArgumentNullException.ThrowIfNull(name);

        ILogger<ValkeyAtomicTokenStore<TPayload>>? storeLogger = _loggerFactory is null
            ? null
            : _loggerFactory.CreateLogger<ValkeyAtomicTokenStore<TPayload>>();

        return new ValkeyAtomicTokenStore<TPayload>(
            connectionFactory: _connectionFactory,
            name: name,
            logger: storeLogger,
            jsonOptions: _jsonOptions,
            timeProvider: _timeProvider);
    }
}
|