126 lines
4.8 KiB
C#
126 lines
4.8 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Collections.Immutable;
|
|
using System.Composition;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.CodeAnalysis;
|
|
using Microsoft.CodeAnalysis.CodeActions;
|
|
using Microsoft.CodeAnalysis.CodeFixes;
|
|
using Microsoft.CodeAnalysis.CSharp;
|
|
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
|
|
|
namespace StellaOps.AirGap.Policy.Analyzers;
|
|
|
|
/// <summary>
/// Offers a remediation template that routes HttpClient creation through the shared EgressPolicy factory.
/// </summary>
/// <remarks>
/// The fix replaces a flagged <c>new HttpClient(...)</c> expression with a call to
/// <c>global::StellaOps.AirGap.Policy.EgressHttpClientFactory.Create(egressPolicy: ..., request: ...)</c>.
/// The generated <c>egressPolicy</c> and <c>request</c> arguments are placeholders
/// (<c>default(IEgressPolicy)</c>, <c>"REPLACE_COMPONENT"</c>, <c>"REPLACE_INTENT"</c> and a dummy URI)
/// that the developer is expected to fill in. When the original creation had constructor arguments
/// or an object initializer, they are preserved inside a <c>clientFactory</c> lambda argument.
/// </remarks>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(HttpClientUsageCodeFixProvider))]
[Shared]
public sealed class HttpClientUsageCodeFixProvider : CodeFixProvider
{
    private const string Title = "Use EgressHttpClientFactory.Create(...)";

    /// <inheritdoc/>
    public override ImmutableArray<string> FixableDiagnosticIds
        => ImmutableArray.Create(HttpClientUsageAnalyzer.DiagnosticId);

    /// <inheritdoc/>
    public override FixAllProvider GetFixAllProvider()
        => WellKnownFixAllProviders.BatchFixer;

    /// <inheritdoc/>
    public override async Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        if (context.Document is null)
        {
            return;
        }

        var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
        if (root is null)
        {
            return;
        }

        var diagnostic = context.Diagnostics[0];

        // When the creation is passed directly as an argument (e.g. Foo(new HttpClient())), the enclosing
        // ArgumentSyntax has exactly the same span as the creation. FindNode returns the outermost node
        // for such ties by default, so request the innermost one to land on the creation expression itself.
        var node = root.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true);
        if (node is not ObjectCreationExpressionSyntax objectCreation)
        {
            return;
        }

        context.RegisterCodeFix(
            CodeAction.Create(
                Title,
                cancellationToken => ReplaceWithFactoryCallAsync(context.Document, objectCreation, cancellationToken),
                equivalenceKey: Title),
            diagnostic);
    }

    /// <summary>
    /// Replaces <paramref name="creation"/> in <paramref name="document"/> with the factory-call template.
    /// </summary>
    /// <param name="document">The document containing the flagged creation.</param>
    /// <param name="creation">The <c>new HttpClient(...)</c> expression to replace.</param>
    /// <param name="cancellationToken">Cancellation token for syntax retrieval.</param>
    /// <returns>The updated document, or the original document if no syntax root is available.</returns>
    private static async Task<Document> ReplaceWithFactoryCallAsync(Document document, ObjectCreationExpressionSyntax creation, CancellationToken cancellationToken)
    {
        var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        if (root is null)
        {
            return document;
        }

        // The replacement inherits the creation's leading/trailing trivia so that surrounding
        // comments and line breaks stay where they were.
        var replacementExpression = BuildReplacementExpression(creation).WithTriviaFrom(creation);
        var updatedRoot = root.ReplaceNode(creation, replacementExpression);
        return document.WithSyntaxRoot(updatedRoot);
    }

    /// <summary>
    /// Builds the <c>EgressHttpClientFactory.Create(...)</c> invocation that stands in for <paramref name="creation"/>.
    /// </summary>
    /// <param name="creation">The original <c>HttpClient</c> creation expression.</param>
    /// <returns>
    /// An invocation with placeholder <c>egressPolicy</c> and <c>request</c> arguments and, when the original
    /// creation had constructor arguments or an initializer, a <c>clientFactory</c> lambda that recreates it.
    /// </returns>
    private static ExpressionSyntax BuildReplacementExpression(ObjectCreationExpressionSyntax creation)
    {
        var requestExpression = SyntaxFactory.ParseExpression(
            "new global::StellaOps.AirGap.Policy.EgressRequest(" +
            "component: \"REPLACE_COMPONENT\", " +
            "destination: new global::System.Uri(\"https://replace-with-endpoint\"), " +
            "intent: \"REPLACE_INTENT\")");

        var egressPolicyExpression = SyntaxFactory.ParseExpression(
            "default(global::StellaOps.AirGap.Policy.IEgressPolicy)");

        var arguments = new List<ArgumentSyntax>
        {
            // The inline TODO comment flags the placeholder policy that the developer must supply.
            SyntaxFactory.Argument(egressPolicyExpression)
                .WithNameColon(SyntaxFactory.NameColon("egressPolicy"))
                .WithTrailingTrivia(
                    SyntaxFactory.Space,
                    SyntaxFactory.Comment("/* TODO: provide IEgressPolicy instance */")),
            SyntaxFactory.Argument(requestExpression)
                .WithNameColon(SyntaxFactory.NameColon("request"))
        };

        if (ShouldUseClientFactory(creation))
        {
            // Preserve the caller's handler/options by deferring the original construction into a factory lambda.
            var clientFactoryLambda = SyntaxFactory.ParenthesizedLambdaExpression(
                SyntaxFactory.ParameterList(),
                CreateHttpClientExpression(creation));

            arguments.Add(
                SyntaxFactory.Argument(clientFactoryLambda)
                    .WithNameColon(SyntaxFactory.NameColon("clientFactory")));
        }

        return SyntaxFactory.InvocationExpression(
                SyntaxFactory.ParseExpression("global::StellaOps.AirGap.Policy.EgressHttpClientFactory.Create"))
            .WithArgumentList(SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments)));
    }

    /// <summary>
    /// Returns <see langword="true"/> when the original creation carries constructor arguments or an
    /// object initializer that must be preserved via a <c>clientFactory</c> lambda.
    /// </summary>
    private static bool ShouldUseClientFactory(ObjectCreationExpressionSyntax creation)
        => (creation.ArgumentList?.Arguments.Count ?? 0) > 0 || creation.Initializer is not null;

    /// <summary>
    /// Recreates <paramref name="creation"/> as a fully-qualified
    /// <c>new global::System.Net.Http.HttpClient(...)</c> expression, keeping its argument list and initializer.
    /// </summary>
    private static ObjectCreationExpressionSyntax CreateHttpClientExpression(ObjectCreationExpressionSyntax creation)
    {
        var httpClientType = SyntaxFactory.ParseTypeName("global::System.Net.Http.HttpClient");
        var arguments = creation.ArgumentList ?? SyntaxFactory.ArgumentList();

        // The reused argument list / initializer still carries the creation's trailing trivia on its last
        // token. That trivia is already transferred to the outer invocation by WithTriviaFrom, so strip it
        // here to avoid duplicating trailing comments inside the lambda.
        return SyntaxFactory.ObjectCreationExpression(httpClientType)
            .WithArgumentList(arguments)
            .WithInitializer(creation.Initializer)
            .WithoutTrailingTrivia();
    }
}
|