up
This commit is contained in:
@@ -0,0 +1,282 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Messaging.Abstractions;
|
||||
using StackExchange.Redis;
|
||||
|
||||
namespace StellaOps.Messaging.Transport.Valkey;
|
||||
|
||||
/// <summary>
|
||||
/// Valkey/Redis implementation of <see cref="IAtomicTokenStore{TPayload}"/>.
|
||||
/// Uses Lua scripts for atomic compare-and-delete operations.
|
||||
/// </summary>
|
||||
public sealed class ValkeyAtomicTokenStore<TPayload> : IAtomicTokenStore<TPayload>
|
||||
{
|
||||
private readonly ValkeyConnectionFactory _connectionFactory;
|
||||
private readonly string _name;
|
||||
private readonly ILogger<ValkeyAtomicTokenStore<TPayload>>? _logger;
|
||||
private readonly JsonSerializerOptions _jsonOptions;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
|
||||
// Lua script for atomic consume: GET, compare, DELETE if matches
|
||||
private const string ConsumeScript = @"
|
||||
local value = redis.call('GET', KEYS[1])
|
||||
if not value then
|
||||
return {0, nil}
|
||||
end
|
||||
local data = cjson.decode(value)
|
||||
if data.token ~= ARGV[1] then
|
||||
return {2, value}
|
||||
end
|
||||
redis.call('DEL', KEYS[1])
|
||||
return {1, value}
|
||||
";
|
||||
|
||||
public ValkeyAtomicTokenStore(
|
||||
ValkeyConnectionFactory connectionFactory,
|
||||
string name,
|
||||
ILogger<ValkeyAtomicTokenStore<TPayload>>? logger = null,
|
||||
JsonSerializerOptions? jsonOptions = null,
|
||||
TimeProvider? timeProvider = null)
|
||||
{
|
||||
_connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory));
|
||||
_name = name ?? throw new ArgumentNullException(nameof(name));
|
||||
_logger = logger;
|
||||
_jsonOptions = jsonOptions ?? new JsonSerializerOptions
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
WriteIndented = false
|
||||
};
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public string ProviderName => "valkey";
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask<TokenIssueResult> IssueAsync(
|
||||
string key,
|
||||
TPayload payload,
|
||||
TimeSpan ttl,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(key);
|
||||
|
||||
var redisKey = BuildKey(key);
|
||||
var now = _timeProvider.GetUtcNow();
|
||||
var expiresAt = now.Add(ttl);
|
||||
|
||||
// Generate secure random token
|
||||
var tokenBytes = new byte[32];
|
||||
RandomNumberGenerator.Fill(tokenBytes);
|
||||
var token = Convert.ToBase64String(tokenBytes);
|
||||
|
||||
var entry = new TokenData<TPayload>
|
||||
{
|
||||
Token = token,
|
||||
Payload = payload,
|
||||
IssuedAt = now,
|
||||
ExpiresAt = expiresAt
|
||||
};
|
||||
|
||||
var serialized = JsonSerializer.Serialize(entry, _jsonOptions);
|
||||
var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await db.StringSetAsync(redisKey, serialized, ttl).ConfigureAwait(false);
|
||||
|
||||
return TokenIssueResult.Succeeded(token, expiresAt);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask<TokenIssueResult> StoreAsync(
|
||||
string key,
|
||||
string token,
|
||||
TPayload payload,
|
||||
TimeSpan ttl,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(key);
|
||||
ArgumentNullException.ThrowIfNull(token);
|
||||
|
||||
var redisKey = BuildKey(key);
|
||||
var now = _timeProvider.GetUtcNow();
|
||||
var expiresAt = now.Add(ttl);
|
||||
|
||||
var entry = new TokenData<TPayload>
|
||||
{
|
||||
Token = token,
|
||||
Payload = payload,
|
||||
IssuedAt = now,
|
||||
ExpiresAt = expiresAt
|
||||
};
|
||||
|
||||
var serialized = JsonSerializer.Serialize(entry, _jsonOptions);
|
||||
var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await db.StringSetAsync(redisKey, serialized, ttl).ConfigureAwait(false);
|
||||
|
||||
return TokenIssueResult.Succeeded(token, expiresAt);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask<TokenConsumeResult<TPayload>> TryConsumeAsync(
|
||||
string key,
|
||||
string expectedToken,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(key);
|
||||
ArgumentNullException.ThrowIfNull(expectedToken);
|
||||
|
||||
var redisKey = BuildKey(key);
|
||||
var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
|
||||
var now = _timeProvider.GetUtcNow();
|
||||
|
||||
try
|
||||
{
|
||||
var result = await db.ScriptEvaluateAsync(
|
||||
ConsumeScript,
|
||||
new RedisKey[] { redisKey },
|
||||
new RedisValue[] { expectedToken }).ConfigureAwait(false);
|
||||
|
||||
var results = (RedisResult[])result!;
|
||||
var status = (int)results[0];
|
||||
|
||||
switch (status)
|
||||
{
|
||||
case 0: // Not found
|
||||
return TokenConsumeResult<TPayload>.NotFound();
|
||||
|
||||
case 1: // Success
|
||||
var data = JsonSerializer.Deserialize<TokenData<TPayload>>((string)results[1]!, _jsonOptions);
|
||||
if (data is null)
|
||||
{
|
||||
return TokenConsumeResult<TPayload>.NotFound();
|
||||
}
|
||||
|
||||
if (data.ExpiresAt < now)
|
||||
{
|
||||
return TokenConsumeResult<TPayload>.Expired(data.IssuedAt, data.ExpiresAt);
|
||||
}
|
||||
|
||||
return TokenConsumeResult<TPayload>.Success(data.Payload!, data.IssuedAt, data.ExpiresAt);
|
||||
|
||||
case 2: // Mismatch
|
||||
return TokenConsumeResult<TPayload>.Mismatch();
|
||||
|
||||
default:
|
||||
return TokenConsumeResult<TPayload>.NotFound();
|
||||
}
|
||||
}
|
||||
catch (RedisServerException ex) when (ex.Message.Contains("NOSCRIPT"))
|
||||
{
|
||||
// Fallback: non-atomic approach (less safe but works without Lua)
|
||||
return await TryConsumeNonAtomicAsync(db, redisKey, expectedToken, now).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
private async ValueTask<TokenConsumeResult<TPayload>> TryConsumeNonAtomicAsync(
|
||||
IDatabase db,
|
||||
string redisKey,
|
||||
string expectedToken,
|
||||
DateTimeOffset now)
|
||||
{
|
||||
var value = await db.StringGetAsync(redisKey).ConfigureAwait(false);
|
||||
if (value.IsNullOrEmpty)
|
||||
{
|
||||
return TokenConsumeResult<TPayload>.NotFound();
|
||||
}
|
||||
|
||||
var data = JsonSerializer.Deserialize<TokenData<TPayload>>((string)value!, _jsonOptions);
|
||||
if (data is null)
|
||||
{
|
||||
return TokenConsumeResult<TPayload>.NotFound();
|
||||
}
|
||||
|
||||
if (data.ExpiresAt < now)
|
||||
{
|
||||
await db.KeyDeleteAsync(redisKey).ConfigureAwait(false);
|
||||
return TokenConsumeResult<TPayload>.Expired(data.IssuedAt, data.ExpiresAt);
|
||||
}
|
||||
|
||||
if (!string.Equals(data.Token, expectedToken, StringComparison.Ordinal))
|
||||
{
|
||||
return TokenConsumeResult<TPayload>.Mismatch();
|
||||
}
|
||||
|
||||
// Try to delete - if someone else deleted it first, we lost the race
|
||||
if (await db.KeyDeleteAsync(redisKey).ConfigureAwait(false))
|
||||
{
|
||||
return TokenConsumeResult<TPayload>.Success(data.Payload!, data.IssuedAt, data.ExpiresAt);
|
||||
}
|
||||
|
||||
return TokenConsumeResult<TPayload>.NotFound();
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask<bool> ExistsAsync(string key, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(key);
|
||||
|
||||
var redisKey = BuildKey(key);
|
||||
var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
|
||||
return await db.KeyExistsAsync(redisKey).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async ValueTask<bool> RevokeAsync(string key, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(key);
|
||||
|
||||
var redisKey = BuildKey(key);
|
||||
var db = await _connectionFactory.GetDatabaseAsync(cancellationToken).ConfigureAwait(false);
|
||||
return await db.KeyDeleteAsync(redisKey).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private string BuildKey(string key) => $"token:{_name}:{key}";
|
||||
|
||||
private sealed class TokenData<T>
|
||||
{
|
||||
public required string Token { get; init; }
|
||||
public T? Payload { get; init; }
|
||||
public DateTimeOffset IssuedAt { get; init; }
|
||||
public DateTimeOffset ExpiresAt { get; init; }
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Factory for creating Valkey atomic token store instances.
|
||||
/// </summary>
|
||||
public sealed class ValkeyAtomicTokenStoreFactory : IAtomicTokenStoreFactory
|
||||
{
|
||||
private readonly ValkeyConnectionFactory _connectionFactory;
|
||||
private readonly ILoggerFactory? _loggerFactory;
|
||||
private readonly JsonSerializerOptions? _jsonOptions;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
|
||||
public ValkeyAtomicTokenStoreFactory(
|
||||
ValkeyConnectionFactory connectionFactory,
|
||||
ILoggerFactory? loggerFactory = null,
|
||||
JsonSerializerOptions? jsonOptions = null,
|
||||
TimeProvider? timeProvider = null)
|
||||
{
|
||||
_connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory));
|
||||
_loggerFactory = loggerFactory;
|
||||
_jsonOptions = jsonOptions;
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public string ProviderName => "valkey";
|
||||
|
||||
/// <inheritdoc />
|
||||
public IAtomicTokenStore<TPayload> Create<TPayload>(string name)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(name);
|
||||
return new ValkeyAtomicTokenStore<TPayload>(
|
||||
_connectionFactory,
|
||||
name,
|
||||
_loggerFactory?.CreateLogger<ValkeyAtomicTokenStore<TPayload>>(),
|
||||
_jsonOptions,
|
||||
_timeProvider);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user