using System.IO;
using System.Net;
using System.Net.Http.Headers;
using Microsoft.Extensions.Logging;

namespace StellaOps.Feedser.Source.Vndr.Cisco.Internal;

/// <summary>
/// Delegating handler that stamps every outgoing request with a <c>Bearer</c> authorization header
/// obtained from <see cref="CiscoAccessTokenProvider"/> and, when the response is
/// <see cref="HttpStatusCode.Unauthorized"/>, refreshes the token and resends the request once.
/// </summary>
internal sealed class CiscoOAuthMessageHandler : DelegatingHandler
{
    private readonly CiscoAccessTokenProvider _tokenProvider;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates the handler.
    /// </summary>
    /// <param name="tokenProvider">Source of the bearer token; also asked to refresh or invalidate it after a 401.</param>
    /// <param name="logger">Receives a warning whenever a 401 triggers a token refresh.</param>
    /// <exception cref="ArgumentNullException">Either argument is <see langword="null"/>.</exception>
    public CiscoOAuthMessageHandler(
        CiscoAccessTokenProvider tokenProvider,
        ILogger logger)
    {
        _tokenProvider = tokenProvider ?? throw new ArgumentNullException(nameof(tokenProvider));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Sends <paramref name="request"/> with a bearer token and retries once with a refreshed token
    /// if the first response is 401 Unauthorized.
    /// </summary>
    /// <param name="request">The outgoing request; its <c>Authorization</c> header is overwritten.</param>
    /// <param name="cancellationToken">Flows to token acquisition, content buffering and both sends.</param>
    /// <returns>
    /// The first response when it is not a 401; otherwise the response to the retried request,
    /// whatever its status code.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <see langword="null"/>.</exception>
    /// <exception cref="HttpRequestException">
    /// The first response was a 401 and no retry copy of the request could be prepared.
    /// </exception>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(request);

        // The copy is taken before the first send because a request message cannot be sent twice.
        // The using declaration releases it on every exit path, including the common case where
        // no retry is needed and the cases where token refresh/acquisition throws.
        using var retryTemplate = await TryCloneRequestAsync(request, cancellationToken).ConfigureAwait(false);

        request.Headers.Authorization = new AuthenticationHeaderValue(
            "Bearer",
            await _tokenProvider.GetTokenAsync(cancellationToken).ConfigureAwait(false));

        var response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
        if (response.StatusCode != HttpStatusCode.Unauthorized)
        {
            return response;
        }

        // The 401 response is discarded; only the retry's response is surfaced to the caller.
        response.Dispose();
        _logger.LogWarning("Cisco openVuln request returned 401 Unauthorized; refreshing access token.");
        await _tokenProvider.RefreshAsync(cancellationToken).ConfigureAwait(false);

        if (retryTemplate is null)
        {
            _tokenProvider.Invalidate();
            throw new HttpRequestException("Cisco openVuln request returned 401 Unauthorized and could not be retried.");
        }

        retryTemplate.Headers.Authorization = new AuthenticationHeaderValue(
            "Bearer",
            await _tokenProvider.GetTokenAsync(cancellationToken).ConfigureAwait(false));

        var retryResponse = await base.SendAsync(retryTemplate, cancellationToken).ConfigureAwait(false);
        if (retryResponse.StatusCode == HttpStatusCode.Unauthorized)
        {
            // A second 401 right after a refresh: tell the provider to drop its token.
            _tokenProvider.Invalidate();
        }

        // Disposing retryTemplate on return releases the request copy only, not the response.
        return retryResponse;
    }

    /// <summary>
    /// Clones <paramref name="request"/> for a potential retry, returning <see langword="null"/>
    /// when cloning fails with an <see cref="IOException"/>.
    /// </summary>
    private static async Task<HttpRequestMessage?> TryCloneRequestAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        try
        {
            return await CloneRequestAsync(request, cancellationToken).ConfigureAwait(false);
        }
        catch (IOException)
        {
            // Unable to buffer content; retry will fail if needed.
            return null;
        }
    }

    /// <summary>
    /// Builds an independent copy of <paramref name="request"/>: method, URI, version, version policy,
    /// headers, per-request options and a byte-array snapshot of the content with its content headers.
    /// </summary>
    /// <param name="request">The request to copy. Its content, if any, is buffered in place.</param>
    /// <param name="cancellationToken">Cancels the content buffering.</param>
    /// <returns>A new request message that the caller owns and must dispose.</returns>
    private static async Task<HttpRequestMessage> CloneRequestAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        var clone = new HttpRequestMessage(request.Method, request.RequestUri)
        {
            Version = request.Version,
            VersionPolicy = request.VersionPolicy,
        };

        try
        {
            foreach (var header in request.Headers)
            {
                clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
            }

            // Carry per-request options across so downstream handlers see the same values on the retry.
            foreach (var option in request.Options)
            {
                clone.Options.Set(new HttpRequestOptionsKey<object?>(option.Key), option.Value);
            }

            if (request.Content is { } originalContent)
            {
                // ReadAsByteArrayAsync buffers the content inside the HttpContent instance, so the
                // original request can still be serialized afterwards even when it wraps a
                // forward-only stream; copying straight to a side stream would have drained it.
                var payload = await originalContent.ReadAsByteArrayAsync(cancellationToken).ConfigureAwait(false);

                var contentClone = new ByteArrayContent(payload);
                foreach (var header in originalContent.Headers)
                {
                    contentClone.Headers.TryAddWithoutValidation(header.Key, header.Value);
                }

                clone.Content = contentClone;
            }

            return clone;
        }
        catch
        {
            // Do not leak the partially built copy when buffering fails or is cancelled.
            clone.Dispose();
            throw;
        }
    }
}