Files
git.stella-ops.org/src/Scanner/__Tests/StellaOps.Scanner.WebService.Tests/ReachabilityDriftEndpointsTests.cs

175 lines
6.4 KiB
C#

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Net;
using System.Net.Http.Json;
using Microsoft.Extensions.DependencyInjection;
using StellaOps.Scanner.CallGraph;
using StellaOps.Scanner.Contracts;
using StellaOps.Scanner.Reachability;
using StellaOps.Scanner.ReachabilityDrift;
using StellaOps.Scanner.Storage.Repositories;
using StellaOps.Scanner.WebService.Contracts;
using Xunit;
using StellaOps.TestKit;
namespace StellaOps.Scanner.WebService.Tests;
/// <summary>
/// Endpoint tests for reachability drift: <c>GET /api/v1/scans/{scanId}/drift</c> and
/// <c>GET /api/v1/drift/{driftId}/sinks</c>. Each test hosts the web service through
/// <see cref="ScannerApplicationFactory"/> with <c>scanner:authority:enabled</c> set to <c>false</c>.
/// </summary>
public sealed class ReachabilityDriftEndpointsTests
{
    /// <summary>
    /// Requesting drift for a freshly submitted scan without a <c>baseScanId</c> query
    /// parameter (and with no drift result computed yet) must return 404 Not Found.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task GetDriftReturnsNotFoundWhenNoResultAndNoBaseScanProvided()
    {
        using var secrets = new TestSurfaceSecretsScope();
        using var factory = new ScannerApplicationFactory().WithOverrides(configuration =>
        {
            configuration["scanner:authority:enabled"] = "false";
        });
        using var client = factory.CreateClient();
        var scanId = await CreateScanAsync(client);

        using var response = await client.GetAsync($"/api/v1/scans/{scanId}/drift?language=dotnet");

        Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
    }

    /// <summary>
    /// With call-graph snapshots seeded for a base scan (sink unreachable) and a head scan
    /// (sink reachable), the drift endpoint computes a result with exactly one newly reachable
    /// sink, and the sinks endpoint lists that same sink for the <c>became_reachable</c> direction.
    /// </summary>
    [Trait("Category", TestCategories.Unit)]
    [Fact]
    public async Task GetDriftComputesResultAndListsDriftedSinks()
    {
        using var secrets = new TestSurfaceSecretsScope();
        using var factory = new ScannerApplicationFactory().WithOverrides(configuration =>
        {
            configuration["scanner:authority:enabled"] = "false";
        });
        using var client = factory.CreateClient();
        var baseScanId = await CreateScanAsync(client, "base");
        var headScanId = await CreateScanAsync(client, "head");
        await SeedCallGraphSnapshotsAsync(factory.Services, baseScanId, headScanId);

        using var response = await client.GetAsync(
            $"/api/v1/scans/{headScanId}/drift?baseScanId={baseScanId}&language=dotnet&includeFullPath=false");

        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        var drift = await response.Content.ReadFromJsonAsync<ReachabilityDriftResult>();
        Assert.NotNull(drift);
        Assert.Equal(baseScanId, drift!.BaseScanId);
        Assert.Equal(headScanId, drift.HeadScanId);
        Assert.Equal("dotnet", drift.Language);
        Assert.Empty(drift.NewlyUnreachable);

        // Assert.Single both checks the cardinality and hands back the element.
        var sink = Assert.Single(drift.NewlyReachable);
        Assert.Equal(DriftDirection.BecameReachable, sink.Direction);
        Assert.Equal("sink", sink.SinkNodeId);
        Assert.Equal(DriftCauseKind.GuardRemoved, sink.Cause.Kind);

        using var sinksResponse = await client.GetAsync(
            $"/api/v1/drift/{drift.Id}/sinks?direction=became_reachable&offset=0&limit=10");

        Assert.Equal(HttpStatusCode.OK, sinksResponse.StatusCode);
        var sinksPayload = await sinksResponse.Content.ReadFromJsonAsync<DriftedSinksResponse>();
        Assert.NotNull(sinksPayload);
        Assert.Equal(drift.Id, sinksPayload!.DriftId);
        Assert.Equal(DriftDirection.BecameReachable, sinksPayload.Direction);
        Assert.Equal(0, sinksPayload.Offset);
        Assert.Equal(10, sinksPayload.Limit);
        Assert.Equal(1, sinksPayload.Count);

        // The listed sink must be the same one reported by the drift result above.
        var listedSink = Assert.Single(sinksPayload.Sinks);
        Assert.Equal("sink", listedSink.SinkNodeId);
        Assert.Equal(DriftDirection.BecameReachable, listedSink.Direction);
    }

    /// <summary>
    /// Stores two call-graph snapshots through <see cref="ICallGraphSnapshotRepository"/>:
    /// the base snapshot has no edges (the sink node is disconnected from the entrypoint),
    /// while the head snapshot adds a direct <c>entry -&gt; sink</c> edge.
    /// </summary>
    /// <param name="services">Root service provider of the test host; a scope is created from it.</param>
    /// <param name="baseScanId">Scan id assigned to the edge-less snapshot.</param>
    /// <param name="headScanId">Scan id assigned to the snapshot containing the edge.</param>
    private static async Task SeedCallGraphSnapshotsAsync(IServiceProvider services, string baseScanId, string headScanId)
    {
        using var scope = services.CreateScope();
        var repo = scope.ServiceProvider.GetRequiredService<ICallGraphSnapshotRepository>();

        var baseSnapshot = CreateSnapshot(
            scanId: baseScanId,
            edges: ImmutableArray<CallGraphEdge>.Empty);
        var headSnapshot = CreateSnapshot(
            scanId: headScanId,
            edges: ImmutableArray.Create(new CallGraphEdge("entry", "sink", CallKind.Direct, "Demo.cs:1")));

        await repo.StoreAsync(baseSnapshot);
        await repo.StoreAsync(headSnapshot);
    }

    /// <summary>
    /// Builds a "dotnet" call-graph snapshot with two fixed nodes: an HTTP-handler entrypoint
    /// (<c>entry</c>) and a command-execution sink (<c>sink</c>). Only the edges vary per call.
    /// </summary>
    /// <param name="scanId">Scan the snapshot belongs to.</param>
    /// <param name="edges">Call edges between the two fixed nodes.</param>
    /// <returns>The snapshot with <c>GraphDigest</c> filled in.</returns>
    private static CallGraphSnapshot CreateSnapshot(string scanId, ImmutableArray<CallGraphEdge> edges)
    {
        var nodes = ImmutableArray.Create(
            new CallGraphNode(
                NodeId: "entry",
                Symbol: "Demo.Entry",
                File: "Demo.cs",
                Line: 1,
                Package: "pkg:generic/demo@1.0.0",
                Visibility: Visibility.Public,
                IsEntrypoint: true,
                EntrypointType: EntrypointType.HttpHandler,
                IsSink: false,
                SinkCategory: null),
            new CallGraphNode(
                NodeId: "sink",
                Symbol: "Demo.Sink",
                File: "Demo.cs",
                Line: 2,
                Package: "pkg:generic/demo@1.0.0",
                Visibility: Visibility.Public,
                IsEntrypoint: false,
                EntrypointType: null,
                IsSink: true,
                SinkCategory: SinkCategory.CmdExec));

        // The digest is computed over a snapshot whose GraphDigest is still empty,
        // then copied into the final record.
        var provisional = new CallGraphSnapshot(
            ScanId: scanId,
            GraphDigest: string.Empty,
            Language: "dotnet",
            ExtractedAt: DateTimeOffset.UnixEpoch,
            Nodes: nodes,
            Edges: edges,
            EntrypointIds: ImmutableArray.Create("entry"),
            SinkIds: ImmutableArray.Create("sink"));

        return provisional with { GraphDigest = CallGraphDigests.ComputeGraphDigest(provisional) };
    }

    /// <summary>
    /// Submits a scan for a fixed demo image via <c>POST /api/v1/scans</c>.
    /// The request must be answered with 202 Accepted.
    /// </summary>
    /// <param name="client">Client bound to the test host.</param>
    /// <param name="clientRequestId">
    /// Optional client request id. It is also echoed into the <c>test.request</c> metadata entry,
    /// or an empty string is used there when it is null.
    /// </param>
    /// <returns>The non-blank scan id returned by the service.</returns>
    private static async Task<string> CreateScanAsync(HttpClient client, string? clientRequestId = null)
    {
        using var response = await client.PostAsJsonAsync("/api/v1/scans", new ScanSubmitRequest
        {
            Image = new ScanImageDescriptor
            {
                Reference = "example.com/demo:1.0",
                Digest = "sha256:0123456789abcdef"
            },
            ClientRequestId = clientRequestId,
            Metadata = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
            {
                ["test.request"] = clientRequestId ?? string.Empty
            }
        });

        Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
        var payload = await response.Content.ReadFromJsonAsync<ScanSubmitResponse>();
        Assert.NotNull(payload);
        Assert.False(string.IsNullOrWhiteSpace(payload!.ScanId));
        return payload.ScanId;
    }

    /// <summary>
    /// Test-local DTO for deserializing the <c>/api/v1/drift/{driftId}/sinks</c> response.
    /// It is presumed to mirror the server's JSON shape; verify against the endpoint contract.
    /// </summary>
    private sealed record DriftedSinksResponse(
        Guid DriftId,
        DriftDirection Direction,
        int Offset,
        int Limit,
        int Count,
        DriftedSink[] Sinks);
}