291 lines
11 KiB
C#
291 lines
11 KiB
C#
using System.Security.Cryptography;
|
|
using System.Text.Json;
|
|
using Dapper;
|
|
using Microsoft.Extensions.Logging;
|
|
using StellaOps.Messaging.Abstractions;
|
|
|
|
namespace StellaOps.Messaging.Transport.Postgres;
|
|
|
|
/// <summary>
/// PostgreSQL implementation of <see cref="IAtomicTokenStore{TPayload}"/>.
/// Uses DELETE ... RETURNING for atomic token consumption.
/// </summary>
/// <remarks>
/// Every store instance owns one table named <c>{schema}.atomic_token_{name}</c>, where the store
/// name is lower-cased and has hyphens replaced with underscores. The table and its expiry index
/// are created lazily on first use. Expired rows are removed when a consume attempt observes them
/// or when the key is revoked or re-issued; this class runs no background sweep.
/// </remarks>
/// <typeparam name="TPayload">Payload type, stored as JSONB via <see cref="JsonSerializer"/>.</typeparam>
public sealed class PostgresAtomicTokenStore<TPayload> : IAtomicTokenStore<TPayload>
{
    /// <summary>Number of random bytes in a generated token (256 bits).</summary>
    private const int TokenByteLength = 32;

    private readonly PostgresConnectionFactory _connectionFactory;
    private readonly string _identifierSuffix;
    private readonly ILogger<PostgresAtomicTokenStore<TPayload>>? _logger;
    private readonly JsonSerializerOptions _jsonOptions;
    private readonly TimeProvider _timeProvider;

    // Serializes the one-time CREATE TABLE/INDEX so concurrent first callers in this process do not
    // issue the DDL simultaneously (concurrent "IF NOT EXISTS" DDL can still fail in PostgreSQL).
    private readonly SemaphoreSlim _initLock = new(1, 1);
    private volatile bool _tableInitialized;

    /// <summary>
    /// Creates a store backed by the table <c>{schema}.atomic_token_{name}</c>.
    /// </summary>
    /// <param name="connectionFactory">Opens connections and supplies the target schema.</param>
    /// <param name="name">Logical store name; becomes part of the table and index identifiers.</param>
    /// <param name="logger">Optional logger.</param>
    /// <param name="jsonOptions">Payload serializer options; defaults to camelCase, non-indented JSON.</param>
    /// <param name="timeProvider">Clock for issue/expiry timestamps; defaults to <see cref="TimeProvider.System"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="connectionFactory"/> or <paramref name="name"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="name"/> contains characters other than letters, digits, '_', '$' or '-', which
    /// cannot appear in the unquoted SQL identifiers generated from it.
    /// </exception>
    public PostgresAtomicTokenStore(
        PostgresConnectionFactory connectionFactory,
        string name,
        ILogger<PostgresAtomicTokenStore<TPayload>>? logger = null,
        JsonSerializerOptions? jsonOptions = null,
        TimeProvider? timeProvider = null)
    {
        _connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory));
        ArgumentNullException.ThrowIfNull(name);
        _identifierSuffix = SanitizeIdentifier(name);
        _logger = logger;
        _jsonOptions = jsonOptions ?? new JsonSerializerOptions
        {
            PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
            WriteIndented = false
        };
        _timeProvider = timeProvider ?? TimeProvider.System;
    }

    /// <inheritdoc />
    public string ProviderName => "postgres";

    // Schema-qualified table name. The suffix is validated in the constructor because it is
    // interpolated directly into SQL text (identifiers cannot be bound as parameters).
    private string TableName => $"{_connectionFactory.Schema}.atomic_token_{_identifierSuffix}";

    /// <inheritdoc />
    public async ValueTask<TokenIssueResult> IssueAsync(
        string key,
        TPayload payload,
        TimeSpan ttl,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);

        // Cryptographically random, so the token cannot be guessed by other parties.
        var tokenBytes = new byte[TokenByteLength];
        RandomNumberGenerator.Fill(tokenBytes);
        var token = Convert.ToBase64String(tokenBytes);

        return await UpsertAsync(key, token, payload, ttl, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async ValueTask<TokenIssueResult> StoreAsync(
        string key,
        string token,
        TPayload payload,
        TimeSpan ttl,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);
        ArgumentNullException.ThrowIfNull(token);

        return await UpsertAsync(key, token, payload, ttl, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    /// <remarks>
    /// The entry is read first so expired and mismatched tokens can be reported distinctly. The
    /// consumption itself is one conditional <c>DELETE ... RETURNING</c>, so at most one concurrent
    /// caller can succeed for a given token.
    /// </remarks>
    public async ValueTask<TokenConsumeResult<TPayload>> TryConsumeAsync(
        string key,
        string expectedToken,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);
        ArgumentNullException.ThrowIfNull(expectedToken);

        await EnsureTableExistsAsync(cancellationToken).ConfigureAwait(false);

        await using var conn = await _connectionFactory.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

        var now = _timeProvider.GetUtcNow();

        // Column aliases are quoted so Dapper maps them onto TokenRow without depending on the
        // global underscore-to-PascalCase matching setting.
        var selectSql = $@"
            SELECT token AS ""Token"", issued_at AS ""IssuedAt"", expires_at AS ""ExpiresAt""
            FROM {TableName}
            WHERE key = @Key";
        var entry = await conn.QuerySingleOrDefaultAsync<TokenRow>(
                new CommandDefinition(selectSql, new { Key = key }, cancellationToken: cancellationToken))
            .ConfigureAwait(false);

        if (entry is null)
        {
            return TokenConsumeResult<TPayload>.NotFound();
        }

        var issuedAt = FromDatabaseTimestamp(entry.IssuedAt);
        var expiresAt = FromDatabaseTimestamp(entry.ExpiresAt);

        if (expiresAt < now)
        {
            // Remove only the row version that was observed. An unconditional DELETE by key could
            // wipe a token that another caller re-issued for this key in the meantime.
            var deleteExpiredSql = $@"
                DELETE FROM {TableName}
                WHERE key = @Key AND token = @Token AND expires_at = @ExpiresAt";
            await conn.ExecuteAsync(new CommandDefinition(
                    deleteExpiredSql,
                    new { Key = key, Token = entry.Token, ExpiresAt = expiresAt.UtcDateTime },
                    cancellationToken: cancellationToken))
                .ConfigureAwait(false);

            return TokenConsumeResult<TPayload>.Expired(issuedAt, expiresAt);
        }

        if (!TokensEqual(entry.Token, expectedToken))
        {
            return TokenConsumeResult<TPayload>.Mismatch();
        }

        // Atomic consume: the row is deleted only if it still carries the expected token and is
        // not expired; the returned columns describe exactly the row that was removed.
        var consumeSql = $@"
            DELETE FROM {TableName}
            WHERE key = @Key AND token = @Token AND expires_at >= @Now
            RETURNING payload AS ""Payload"", issued_at AS ""IssuedAt"", expires_at AS ""ExpiresAt""";
        var consumed = await conn.QuerySingleOrDefaultAsync<TokenRow>(
                new CommandDefinition(
                    consumeSql,
                    new { Key = key, Token = expectedToken, Now = now.UtcDateTime },
                    cancellationToken: cancellationToken))
            .ConfigureAwait(false);

        if (consumed is null)
        {
            // Another caller consumed, revoked or replaced the entry after it was read.
            return TokenConsumeResult<TPayload>.NotFound();
        }

        // IssueAsync/StoreAsync always write JSON (possibly JSON null), but the payload column is
        // nullable; the row is already deleted here, so do not misreport it as NotFound.
        var payload = consumed.Payload is null
            ? default
            : JsonSerializer.Deserialize<TPayload>(consumed.Payload, _jsonOptions);

        return TokenConsumeResult<TPayload>.Success(
            payload!,
            FromDatabaseTimestamp(consumed.IssuedAt),
            FromDatabaseTimestamp(consumed.ExpiresAt));
    }

    /// <inheritdoc />
    public async ValueTask<bool> ExistsAsync(string key, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);

        await EnsureTableExistsAsync(cancellationToken).ConfigureAwait(false);

        await using var conn = await _connectionFactory.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

        var now = _timeProvider.GetUtcNow();
        var sql = $@"SELECT EXISTS(SELECT 1 FROM {TableName} WHERE key = @Key AND expires_at > @Now)";

        return await conn.ExecuteScalarAsync<bool>(
                new CommandDefinition(sql, new { Key = key, Now = now.UtcDateTime }, cancellationToken: cancellationToken))
            .ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async ValueTask<bool> RevokeAsync(string key, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);

        await EnsureTableExistsAsync(cancellationToken).ConfigureAwait(false);

        await using var conn = await _connectionFactory.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

        var sql = $@"DELETE FROM {TableName} WHERE key = @Key";
        var deleted = await conn.ExecuteAsync(
                new CommandDefinition(sql, new { Key = key }, cancellationToken: cancellationToken))
            .ConfigureAwait(false);

        return deleted > 0;
    }

    /// <summary>
    /// Inserts or replaces the entry for <paramref name="key"/> with the given token and payload,
    /// expiring <paramref name="ttl"/> after the current time.
    /// </summary>
    private async ValueTask<TokenIssueResult> UpsertAsync(
        string key,
        string token,
        TPayload payload,
        TimeSpan ttl,
        CancellationToken cancellationToken)
    {
        await EnsureTableExistsAsync(cancellationToken).ConfigureAwait(false);

        await using var conn = await _connectionFactory.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

        var now = _timeProvider.GetUtcNow();
        var expiresAt = now.Add(ttl);
        var payloadJson = JsonSerializer.Serialize(payload, _jsonOptions);

        var sql = $@"
            INSERT INTO {TableName} (key, token, payload, issued_at, expires_at)
            VALUES (@Key, @Token, @Payload::jsonb, @IssuedAt, @ExpiresAt)
            ON CONFLICT (key) DO UPDATE SET
                token = EXCLUDED.token,
                payload = EXCLUDED.payload,
                issued_at = EXCLUDED.issued_at,
                expires_at = EXCLUDED.expires_at";

        await conn.ExecuteAsync(new CommandDefinition(sql, new
        {
            Key = key,
            Token = token,
            Payload = payloadJson,
            IssuedAt = now.UtcDateTime,
            ExpiresAt = expiresAt.UtcDateTime
        }, cancellationToken: cancellationToken)).ConfigureAwait(false);

        return TokenIssueResult.Succeeded(token, expiresAt);
    }

    /// <summary>
    /// Creates the backing table and expiry index once per store instance.
    /// </summary>
    private async ValueTask EnsureTableExistsAsync(CancellationToken cancellationToken)
    {
        if (_tableInitialized)
        {
            return;
        }

        await _initLock.WaitAsync(cancellationToken).ConfigureAwait(false);
        try
        {
            // Re-check: another caller may have finished initialization while we waited.
            if (_tableInitialized)
            {
                return;
            }

            await using var conn = await _connectionFactory.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

            // The index name uses the sanitized suffix: the raw store name may contain '-', which is
            // not valid in an unquoted identifier.
            // NOTE(review): index names are unique per schema; another table using the same
            // idx_{name}_expires pattern in this schema would make IF NOT EXISTS skip this index.
            var sql = $@"
                CREATE TABLE IF NOT EXISTS {TableName} (
                    key TEXT PRIMARY KEY,
                    token TEXT NOT NULL,
                    payload JSONB,
                    issued_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    expires_at TIMESTAMPTZ NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_{_identifierSuffix}_expires ON {TableName} (expires_at);";

            await conn.ExecuteAsync(new CommandDefinition(sql, cancellationToken: cancellationToken)).ConfigureAwait(false);
            _tableInitialized = true;
        }
        finally
        {
            _initLock.Release();
        }
    }

    /// <summary>
    /// Lower-cases <paramref name="name"/>, maps '-' to '_', and rejects anything that would not be a
    /// valid character in the generated unquoted identifiers.
    /// </summary>
    private static string SanitizeIdentifier(string name)
    {
        var suffix = name.ToLowerInvariant().Replace('-', '_');

        foreach (var c in suffix)
        {
            if (!char.IsLetterOrDigit(c) && c != '_' && c != '$')
            {
                throw new ArgumentException(
                    $"Store name '{name}' contains characters that are not valid in a PostgreSQL identifier.",
                    nameof(name));
            }
        }

        return suffix;
    }

    /// <summary>
    /// Ordinal token comparison whose running time does not depend on where the strings differ
    /// (only on their lengths), so the stored token cannot be probed via response timing.
    /// </summary>
    private static bool TokensEqual(string stored, string presented) =>
        CryptographicOperations.FixedTimeEquals(
            System.Runtime.InteropServices.MemoryMarshal.AsBytes(stored.AsSpan()),
            System.Runtime.InteropServices.MemoryMarshal.AsBytes(presented.AsSpan()));

    /// <summary>
    /// Converts a TIMESTAMPTZ value read from the database into a UTC <see cref="DateTimeOffset"/>.
    /// </summary>
    /// <remarks>
    /// Npgsql 6+ returns UTC-kind values; a Local-kind value (legacy timestamp mode) is converted
    /// first because the DateTimeOffset constructor throws when a Local value's offset differs
    /// from the supplied one.
    /// </remarks>
    private static DateTimeOffset FromDatabaseTimestamp(DateTime value) =>
        new(value.Kind == DateTimeKind.Local ? value.ToUniversalTime() : value, TimeSpan.Zero);

    /// <summary>
    /// Row shape read back from the token table; which properties are populated depends on the
    /// columns selected by the query.
    /// </summary>
    private sealed class TokenRow
    {
        public string Token { get; init; } = null!;

        public string? Payload { get; init; }

        public DateTime IssuedAt { get; init; }

        public DateTime ExpiresAt { get; init; }
    }
}
|
|
|
|
/// <summary>
/// Builds <see cref="PostgresAtomicTokenStore{TPayload}"/> instances that all share one
/// connection factory, logger factory, JSON configuration and clock.
/// </summary>
public sealed class PostgresAtomicTokenStoreFactory : IAtomicTokenStoreFactory
{
    private readonly TimeProvider _timeProvider;
    private readonly JsonSerializerOptions? _jsonOptions;
    private readonly ILoggerFactory? _loggerFactory;
    private readonly PostgresConnectionFactory _connectionFactory;

    /// <summary>
    /// Captures the shared dependencies handed to every store this factory creates.
    /// </summary>
    /// <param name="connectionFactory">Connection source for all created stores; required.</param>
    /// <param name="loggerFactory">Optional; when present each store gets its own typed logger.</param>
    /// <param name="jsonOptions">Optional serializer options passed through to each store as-is.</param>
    /// <param name="timeProvider">Clock for created stores; <see cref="TimeProvider.System"/> when omitted.</param>
    /// <exception cref="ArgumentNullException"><paramref name="connectionFactory"/> is null.</exception>
    public PostgresAtomicTokenStoreFactory(
        PostgresConnectionFactory connectionFactory,
        ILoggerFactory? loggerFactory = null,
        JsonSerializerOptions? jsonOptions = null,
        TimeProvider? timeProvider = null)
    {
        ArgumentNullException.ThrowIfNull(connectionFactory);

        _connectionFactory = connectionFactory;
        _loggerFactory = loggerFactory;
        _jsonOptions = jsonOptions;
        _timeProvider = timeProvider ?? TimeProvider.System;
    }

    /// <inheritdoc />
    public string ProviderName => "postgres";

    /// <inheritdoc />
    public IAtomicTokenStore<TPayload> Create<TPayload>(string name)
    {
        ArgumentNullException.ThrowIfNull(name);

        var storeLogger = _loggerFactory is null
            ? null
            : _loggerFactory.CreateLogger<PostgresAtomicTokenStore<TPayload>>();

        return new PostgresAtomicTokenStore<TPayload>(
            connectionFactory: _connectionFactory,
            name: name,
            logger: storeLogger,
            jsonOptions: _jsonOptions,
            timeProvider: _timeProvider);
    }
}
|