Refactor code structure and optimize performance across multiple modules
This commit is contained in:
308
src/AdvisoryAI/StellaOps.AdvisoryAI/Inference/LlmBenchmark.cs
Normal file
308
src/AdvisoryAI/StellaOps.AdvisoryAI/Inference/LlmBenchmark.cs
Normal file
@@ -0,0 +1,308 @@
|
||||
using System.Diagnostics;
|
||||
using StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference;
|
||||
|
||||
/// <summary>
|
||||
/// Benchmarks local LLM inference performance.
|
||||
/// Sprint: SPRINT_20251226_019_AI_offline_inference
|
||||
/// Task: OFFLINE-20
|
||||
/// </summary>
|
||||
public interface ILlmBenchmark
|
||||
{
|
||||
/// <summary>
|
||||
/// Run a benchmark suite against a provider.
|
||||
/// </summary>
|
||||
Task<BenchmarkResult> RunAsync(
|
||||
ILlmProvider provider,
|
||||
BenchmarkOptions options,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Options for benchmark execution.
|
||||
/// </summary>
|
||||
public sealed record BenchmarkOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Number of warmup iterations.
|
||||
/// </summary>
|
||||
public int WarmupIterations { get; init; } = 2;
|
||||
|
||||
/// <summary>
|
||||
/// Number of benchmark iterations.
|
||||
/// </summary>
|
||||
public int Iterations { get; init; } = 10;
|
||||
|
||||
/// <summary>
|
||||
/// Short prompt for latency testing.
|
||||
/// </summary>
|
||||
public string ShortPrompt { get; init; } = "What is 2+2?";
|
||||
|
||||
/// <summary>
|
||||
/// Long prompt for throughput testing.
|
||||
/// </summary>
|
||||
public string LongPrompt { get; init; } = """
|
||||
Analyze the following vulnerability and provide a detailed assessment:
|
||||
CVE-2024-1234 affects the logging component in versions 1.0-2.5.
|
||||
The vulnerability allows remote code execution through log injection.
|
||||
Provide: severity rating, attack vector, remediation steps.
|
||||
""";
|
||||
|
||||
/// <summary>
|
||||
/// Max tokens for generation.
|
||||
/// </summary>
|
||||
public int MaxTokens { get; init; } = 512;
|
||||
|
||||
/// <summary>
|
||||
/// Report progress during benchmark.
|
||||
/// </summary>
|
||||
public IProgress<BenchmarkProgress>? Progress { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Progress update during benchmark.
|
||||
/// </summary>
|
||||
public sealed record BenchmarkProgress
|
||||
{
|
||||
public required string Phase { get; init; }
|
||||
public required int CurrentIteration { get; init; }
|
||||
public required int TotalIterations { get; init; }
|
||||
public string? Message { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of a benchmark run.
|
||||
/// </summary>
|
||||
public sealed record BenchmarkResult
|
||||
{
|
||||
public required string ProviderId { get; init; }
|
||||
public required string ModelId { get; init; }
|
||||
public required bool Success { get; init; }
|
||||
public required LatencyMetrics Latency { get; init; }
|
||||
public required ThroughputMetrics Throughput { get; init; }
|
||||
public required ResourceMetrics Resources { get; init; }
|
||||
public required DateTime CompletedAt { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Latency metrics.
|
||||
/// </summary>
|
||||
public sealed record LatencyMetrics
|
||||
{
|
||||
public required double MeanMs { get; init; }
|
||||
public required double MedianMs { get; init; }
|
||||
public required double P95Ms { get; init; }
|
||||
public required double P99Ms { get; init; }
|
||||
public required double MinMs { get; init; }
|
||||
public required double MaxMs { get; init; }
|
||||
public required double StdDevMs { get; init; }
|
||||
public required double TimeToFirstTokenMs { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Throughput metrics.
|
||||
/// </summary>
|
||||
public sealed record ThroughputMetrics
|
||||
{
|
||||
public required double TokensPerSecond { get; init; }
|
||||
public required double RequestsPerMinute { get; init; }
|
||||
public required int TotalTokensGenerated { get; init; }
|
||||
public required double TotalDurationSeconds { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Resource usage metrics.
|
||||
/// </summary>
|
||||
public sealed record ResourceMetrics
|
||||
{
|
||||
public required long PeakMemoryBytes { get; init; }
|
||||
public required double AvgCpuPercent { get; init; }
|
||||
public required bool GpuUsed { get; init; }
|
||||
public long? GpuMemoryBytes { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of LLM benchmark.
|
||||
/// </summary>
|
||||
public sealed class LlmBenchmark : ILlmBenchmark
|
||||
{
|
||||
public async Task<BenchmarkResult> RunAsync(
|
||||
ILlmProvider provider,
|
||||
BenchmarkOptions options,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var latencyMeasurements = new List<double>();
|
||||
var ttftMeasurements = new List<double>();
|
||||
var totalTokens = 0;
|
||||
var modelId = "unknown";
|
||||
|
||||
try
|
||||
{
|
||||
// Warmup phase
|
||||
options.Progress?.Report(new BenchmarkProgress
|
||||
{
|
||||
Phase = "warmup",
|
||||
CurrentIteration = 0,
|
||||
TotalIterations = options.WarmupIterations,
|
||||
Message = "Starting warmup..."
|
||||
});
|
||||
|
||||
for (var i = 0; i < options.WarmupIterations; i++)
|
||||
{
|
||||
await RunSingleAsync(provider, options.ShortPrompt, options.MaxTokens, cancellationToken);
|
||||
options.Progress?.Report(new BenchmarkProgress
|
||||
{
|
||||
Phase = "warmup",
|
||||
CurrentIteration = i + 1,
|
||||
TotalIterations = options.WarmupIterations
|
||||
});
|
||||
}
|
||||
|
||||
// Latency benchmark (short prompts)
|
||||
options.Progress?.Report(new BenchmarkProgress
|
||||
{
|
||||
Phase = "latency",
|
||||
CurrentIteration = 0,
|
||||
TotalIterations = options.Iterations,
|
||||
Message = "Measuring latency..."
|
||||
});
|
||||
|
||||
var latencyStopwatch = Stopwatch.StartNew();
|
||||
for (var i = 0; i < options.Iterations; i++)
|
||||
{
|
||||
var sw = Stopwatch.StartNew();
|
||||
var result = await RunSingleAsync(provider, options.ShortPrompt, options.MaxTokens, cancellationToken);
|
||||
sw.Stop();
|
||||
|
||||
latencyMeasurements.Add(sw.Elapsed.TotalMilliseconds);
|
||||
if (result.TimeToFirstTokenMs.HasValue)
|
||||
{
|
||||
ttftMeasurements.Add(result.TimeToFirstTokenMs.Value);
|
||||
}
|
||||
totalTokens += result.OutputTokens ?? 0;
|
||||
modelId = result.ModelId;
|
||||
|
||||
options.Progress?.Report(new BenchmarkProgress
|
||||
{
|
||||
Phase = "latency",
|
||||
CurrentIteration = i + 1,
|
||||
TotalIterations = options.Iterations
|
||||
});
|
||||
}
|
||||
latencyStopwatch.Stop();
|
||||
|
||||
// Throughput benchmark (longer prompts)
|
||||
options.Progress?.Report(new BenchmarkProgress
|
||||
{
|
||||
Phase = "throughput",
|
||||
CurrentIteration = 0,
|
||||
TotalIterations = options.Iterations,
|
||||
Message = "Measuring throughput..."
|
||||
});
|
||||
|
||||
var throughputStopwatch = Stopwatch.StartNew();
|
||||
for (var i = 0; i < options.Iterations; i++)
|
||||
{
|
||||
var result = await RunSingleAsync(provider, options.LongPrompt, options.MaxTokens, cancellationToken);
|
||||
totalTokens += result.OutputTokens ?? 0;
|
||||
|
||||
options.Progress?.Report(new BenchmarkProgress
|
||||
{
|
||||
Phase = "throughput",
|
||||
CurrentIteration = i + 1,
|
||||
TotalIterations = options.Iterations
|
||||
});
|
||||
}
|
||||
throughputStopwatch.Stop();
|
||||
|
||||
// Calculate metrics
|
||||
var sortedLatencies = latencyMeasurements.Order().ToList();
|
||||
var mean = sortedLatencies.Average();
|
||||
var median = sortedLatencies[sortedLatencies.Count / 2];
|
||||
var p95 = sortedLatencies[(int)(sortedLatencies.Count * 0.95)];
|
||||
var p99 = sortedLatencies[(int)(sortedLatencies.Count * 0.99)];
|
||||
var stdDev = Math.Sqrt(sortedLatencies.Average(x => Math.Pow(x - mean, 2)));
|
||||
var avgTtft = ttftMeasurements.Count > 0 ? ttftMeasurements.Average() : 0;
|
||||
|
||||
var totalDuration = throughputStopwatch.Elapsed.TotalSeconds;
|
||||
var tokensPerSecond = totalTokens / totalDuration;
|
||||
var requestsPerMinute = (options.Iterations * 2) / totalDuration * 60;
|
||||
|
||||
return new BenchmarkResult
|
||||
{
|
||||
ProviderId = provider.ProviderId,
|
||||
ModelId = modelId,
|
||||
Success = true,
|
||||
Latency = new LatencyMetrics
|
||||
{
|
||||
MeanMs = mean,
|
||||
MedianMs = median,
|
||||
P95Ms = p95,
|
||||
P99Ms = p99,
|
||||
MinMs = sortedLatencies.Min(),
|
||||
MaxMs = sortedLatencies.Max(),
|
||||
StdDevMs = stdDev,
|
||||
TimeToFirstTokenMs = avgTtft
|
||||
},
|
||||
Throughput = new ThroughputMetrics
|
||||
{
|
||||
TokensPerSecond = tokensPerSecond,
|
||||
RequestsPerMinute = requestsPerMinute,
|
||||
TotalTokensGenerated = totalTokens,
|
||||
TotalDurationSeconds = totalDuration
|
||||
},
|
||||
Resources = new ResourceMetrics
|
||||
{
|
||||
PeakMemoryBytes = GC.GetTotalMemory(false),
|
||||
AvgCpuPercent = 0, // Would need process monitoring
|
||||
GpuUsed = false // Would need GPU monitoring
|
||||
},
|
||||
CompletedAt = DateTime.UtcNow
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return new BenchmarkResult
|
||||
{
|
||||
ProviderId = provider.ProviderId,
|
||||
ModelId = modelId,
|
||||
Success = false,
|
||||
Latency = new LatencyMetrics
|
||||
{
|
||||
MeanMs = 0, MedianMs = 0, P95Ms = 0, P99Ms = 0,
|
||||
MinMs = 0, MaxMs = 0, StdDevMs = 0, TimeToFirstTokenMs = 0
|
||||
},
|
||||
Throughput = new ThroughputMetrics
|
||||
{
|
||||
TokensPerSecond = 0, RequestsPerMinute = 0,
|
||||
TotalTokensGenerated = 0, TotalDurationSeconds = 0
|
||||
},
|
||||
Resources = new ResourceMetrics
|
||||
{
|
||||
PeakMemoryBytes = 0, AvgCpuPercent = 0, GpuUsed = false
|
||||
},
|
||||
CompletedAt = DateTime.UtcNow,
|
||||
ErrorMessage = ex.Message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<LlmCompletionResult> RunSingleAsync(
|
||||
ILlmProvider provider,
|
||||
string prompt,
|
||||
int maxTokens,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = prompt,
|
||||
Temperature = 0,
|
||||
Seed = 42,
|
||||
MaxTokens = maxTokens
|
||||
};
|
||||
|
||||
return await provider.CompleteAsync(request, cancellationToken);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,567 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Claude (Anthropic) provider configuration (maps to claude.yaml).
|
||||
/// </summary>
|
||||
public sealed class ClaudeConfig : LlmProviderConfigBase
|
||||
{
|
||||
/// <summary>
|
||||
/// API key (or use ANTHROPIC_API_KEY env var).
|
||||
/// </summary>
|
||||
public string? ApiKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Base URL for API requests.
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "https://api.anthropic.com";
|
||||
|
||||
/// <summary>
|
||||
/// API version header.
|
||||
/// </summary>
|
||||
public string ApiVersion { get; set; } = "2023-06-01";
|
||||
|
||||
/// <summary>
|
||||
/// Model name.
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "claude-sonnet-4-20250514";
|
||||
|
||||
/// <summary>
|
||||
/// Fallback models.
|
||||
/// </summary>
|
||||
public List<string> FallbackModels { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Top-p sampling.
|
||||
/// </summary>
|
||||
public double TopP { get; set; } = 1.0;
|
||||
|
||||
/// <summary>
|
||||
/// Top-k sampling (0 = disabled).
|
||||
/// </summary>
|
||||
public int TopK { get; set; } = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Enable extended thinking.
|
||||
/// </summary>
|
||||
public bool ExtendedThinkingEnabled { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Budget tokens for extended thinking.
|
||||
/// </summary>
|
||||
public int ThinkingBudgetTokens { get; set; } = 10000;
|
||||
|
||||
/// <summary>
|
||||
/// Log request/response bodies.
|
||||
/// </summary>
|
||||
public bool LogBodies { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Log token usage.
|
||||
/// </summary>
|
||||
public bool LogUsage { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Bind configuration from IConfiguration.
|
||||
/// </summary>
|
||||
public static ClaudeConfig FromConfiguration(IConfiguration config)
|
||||
{
|
||||
var result = new ClaudeConfig();
|
||||
|
||||
// Provider section
|
||||
result.Enabled = config.GetValue("enabled", true);
|
||||
result.Priority = config.GetValue("priority", 100);
|
||||
|
||||
// API section
|
||||
var api = config.GetSection("api");
|
||||
result.ApiKey = ExpandEnvVar(api.GetValue<string>("apiKey"));
|
||||
result.BaseUrl = api.GetValue("baseUrl", "https://api.anthropic.com")!;
|
||||
result.ApiVersion = api.GetValue("apiVersion", "2023-06-01")!;
|
||||
|
||||
// Model section
|
||||
var model = config.GetSection("model");
|
||||
result.Model = model.GetValue("name", "claude-sonnet-4-20250514")!;
|
||||
result.FallbackModels = model.GetSection("fallbacks").Get<List<string>>() ?? new();
|
||||
|
||||
// Inference section
|
||||
var inference = config.GetSection("inference");
|
||||
result.Temperature = inference.GetValue("temperature", 0.0);
|
||||
result.MaxTokens = inference.GetValue("maxTokens", 4096);
|
||||
result.TopP = inference.GetValue("topP", 1.0);
|
||||
result.TopK = inference.GetValue("topK", 0);
|
||||
|
||||
// Request section
|
||||
var request = config.GetSection("request");
|
||||
result.Timeout = request.GetValue("timeout", TimeSpan.FromSeconds(120));
|
||||
result.MaxRetries = request.GetValue("maxRetries", 3);
|
||||
|
||||
// Thinking section
|
||||
var thinking = config.GetSection("thinking");
|
||||
result.ExtendedThinkingEnabled = thinking.GetValue("enabled", false);
|
||||
result.ThinkingBudgetTokens = thinking.GetValue("budgetTokens", 10000);
|
||||
|
||||
// Logging section
|
||||
var logging = config.GetSection("logging");
|
||||
result.LogBodies = logging.GetValue("logBodies", false);
|
||||
result.LogUsage = logging.GetValue("logUsage", true);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static string? ExpandEnvVar(string? value)
|
||||
{
|
||||
if (string.IsNullOrEmpty(value))
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
if (value.StartsWith("${") && value.EndsWith("}"))
|
||||
{
|
||||
var varName = value.Substring(2, value.Length - 3);
|
||||
return Environment.GetEnvironmentVariable(varName);
|
||||
}
|
||||
|
||||
return Environment.ExpandEnvironmentVariables(value);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Claude LLM provider plugin.
|
||||
/// </summary>
|
||||
public sealed class ClaudeLlmProviderPlugin : ILlmProviderPlugin
|
||||
{
|
||||
public string Name => "Claude LLM Provider";
|
||||
public string ProviderId => "claude";
|
||||
public string DisplayName => "Claude";
|
||||
public string Description => "Anthropic Claude models via API";
|
||||
public string DefaultConfigFileName => "claude.yaml";
|
||||
|
||||
public bool IsAvailable(IServiceProvider services)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public ILlmProvider Create(IServiceProvider services, IConfiguration configuration)
|
||||
{
|
||||
var config = ClaudeConfig.FromConfiguration(configuration);
|
||||
var httpClientFactory = services.GetRequiredService<IHttpClientFactory>();
|
||||
var loggerFactory = services.GetRequiredService<ILoggerFactory>();
|
||||
|
||||
return new ClaudeLlmProvider(
|
||||
httpClientFactory.CreateClient("Claude"),
|
||||
config,
|
||||
loggerFactory.CreateLogger<ClaudeLlmProvider>());
|
||||
}
|
||||
|
||||
public LlmProviderConfigValidation ValidateConfiguration(IConfiguration configuration)
|
||||
{
|
||||
var errors = new List<string>();
|
||||
var warnings = new List<string>();
|
||||
|
||||
var config = ClaudeConfig.FromConfiguration(configuration);
|
||||
|
||||
if (!config.Enabled)
|
||||
{
|
||||
return LlmProviderConfigValidation.WithWarnings("Provider is disabled");
|
||||
}
|
||||
|
||||
var apiKey = config.ApiKey ?? Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY");
|
||||
if (string.IsNullOrEmpty(apiKey))
|
||||
{
|
||||
errors.Add("API key not configured. Set 'api.apiKey' or ANTHROPIC_API_KEY environment variable.");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(config.BaseUrl))
|
||||
{
|
||||
errors.Add("Base URL is required.");
|
||||
}
|
||||
else if (!Uri.TryCreate(config.BaseUrl, UriKind.Absolute, out _))
|
||||
{
|
||||
errors.Add($"Invalid base URL: {config.BaseUrl}");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(config.Model))
|
||||
{
|
||||
warnings.Add("No model specified, will use default 'claude-sonnet-4-20250514'.");
|
||||
}
|
||||
|
||||
if (errors.Count > 0)
|
||||
{
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = false,
|
||||
Errors = errors,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = true,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Claude LLM provider implementation.
|
||||
/// </summary>
|
||||
public sealed class ClaudeLlmProvider : ILlmProvider
|
||||
{
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly ClaudeConfig _config;
|
||||
private readonly ILogger<ClaudeLlmProvider> _logger;
|
||||
private bool _disposed;
|
||||
|
||||
public string ProviderId => "claude";
|
||||
|
||||
public ClaudeLlmProvider(
|
||||
HttpClient httpClient,
|
||||
ClaudeConfig config,
|
||||
ILogger<ClaudeLlmProvider> logger)
|
||||
{
|
||||
_httpClient = httpClient;
|
||||
_config = config;
|
||||
_logger = logger;
|
||||
|
||||
ConfigureHttpClient();
|
||||
}
|
||||
|
||||
private void ConfigureHttpClient()
|
||||
{
|
||||
_httpClient.BaseAddress = new Uri(_config.BaseUrl.TrimEnd('/') + "/");
|
||||
_httpClient.Timeout = _config.Timeout;
|
||||
|
||||
var apiKey = _config.ApiKey ?? Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY");
|
||||
if (!string.IsNullOrEmpty(apiKey))
|
||||
{
|
||||
_httpClient.DefaultRequestHeaders.Add("x-api-key", apiKey);
|
||||
}
|
||||
|
||||
_httpClient.DefaultRequestHeaders.Add("anthropic-version", _config.ApiVersion);
|
||||
}
|
||||
|
||||
public async Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_config.Enabled)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var apiKey = _config.ApiKey ?? Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY");
|
||||
return !string.IsNullOrEmpty(apiKey);
|
||||
}
|
||||
|
||||
public async Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
|
||||
var claudeRequest = new ClaudeMessageRequest
|
||||
{
|
||||
Model = model,
|
||||
MaxTokens = maxTokens,
|
||||
System = request.SystemPrompt,
|
||||
Messages = new List<ClaudeMessage>
|
||||
{
|
||||
new() { Role = "user", Content = request.UserPrompt }
|
||||
},
|
||||
Temperature = temperature,
|
||||
TopP = _config.TopP,
|
||||
TopK = _config.TopK > 0 ? _config.TopK : null,
|
||||
StopSequences = request.StopSequences?.ToArray()
|
||||
};
|
||||
|
||||
if (_config.LogBodies)
|
||||
{
|
||||
_logger.LogDebug("Claude request: {Request}", JsonSerializer.Serialize(claudeRequest));
|
||||
}
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(
|
||||
"v1/messages",
|
||||
claudeRequest,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var claudeResponse = await response.Content.ReadFromJsonAsync<ClaudeMessageResponse>(cancellationToken);
|
||||
stopwatch.Stop();
|
||||
|
||||
if (claudeResponse is null)
|
||||
{
|
||||
throw new InvalidOperationException("No response from Claude API");
|
||||
}
|
||||
|
||||
var content = claudeResponse.Content?
|
||||
.Where(c => c.Type == "text")
|
||||
.Select(c => c.Text)
|
||||
.FirstOrDefault() ?? string.Empty;
|
||||
|
||||
if (_config.LogUsage && claudeResponse.Usage is not null)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Claude usage - Model: {Model}, Input: {InputTokens}, Output: {OutputTokens}",
|
||||
claudeResponse.Model,
|
||||
claudeResponse.Usage.InputTokens,
|
||||
claudeResponse.Usage.OutputTokens);
|
||||
}
|
||||
|
||||
return new LlmCompletionResult
|
||||
{
|
||||
Content = content,
|
||||
ModelId = claudeResponse.Model ?? model,
|
||||
ProviderId = ProviderId,
|
||||
InputTokens = claudeResponse.Usage?.InputTokens,
|
||||
OutputTokens = claudeResponse.Usage?.OutputTokens,
|
||||
TotalTimeMs = stopwatch.ElapsedMilliseconds,
|
||||
FinishReason = claudeResponse.StopReason,
|
||||
Deterministic = temperature == 0,
|
||||
RequestId = request.RequestId ?? claudeResponse.Id
|
||||
};
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
|
||||
var claudeRequest = new ClaudeMessageRequest
|
||||
{
|
||||
Model = model,
|
||||
MaxTokens = maxTokens,
|
||||
System = request.SystemPrompt,
|
||||
Messages = new List<ClaudeMessage>
|
||||
{
|
||||
new() { Role = "user", Content = request.UserPrompt }
|
||||
},
|
||||
Temperature = temperature,
|
||||
TopP = _config.TopP,
|
||||
TopK = _config.TopK > 0 ? _config.TopK : null,
|
||||
StopSequences = request.StopSequences?.ToArray(),
|
||||
Stream = true
|
||||
};
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "v1/messages")
|
||||
{
|
||||
Content = new StringContent(
|
||||
JsonSerializer.Serialize(claudeRequest),
|
||||
Encoding.UTF8,
|
||||
"application/json")
|
||||
};
|
||||
|
||||
var response = await _httpClient.SendAsync(
|
||||
httpRequest,
|
||||
HttpCompletionOption.ResponseHeadersRead,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
|
||||
using var reader = new StreamReader(stream);
|
||||
|
||||
string? line;
|
||||
while ((line = await reader.ReadLineAsync(cancellationToken)) is not null)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
|
||||
if (string.IsNullOrEmpty(line))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!line.StartsWith("data: "))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var data = line.Substring(6);
|
||||
|
||||
ClaudeStreamEvent? evt;
|
||||
try
|
||||
{
|
||||
evt = JsonSerializer.Deserialize<ClaudeStreamEvent>(data);
|
||||
}
|
||||
catch
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (evt is null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (evt.Type)
|
||||
{
|
||||
case "content_block_delta":
|
||||
if (evt.Delta?.Type == "text_delta")
|
||||
{
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = evt.Delta.Text ?? string.Empty,
|
||||
IsFinal = false
|
||||
};
|
||||
}
|
||||
break;
|
||||
|
||||
case "message_stop":
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = string.Empty,
|
||||
IsFinal = true,
|
||||
FinishReason = "stop"
|
||||
};
|
||||
yield break;
|
||||
|
||||
case "message_delta":
|
||||
if (evt.Delta?.StopReason != null)
|
||||
{
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = string.Empty,
|
||||
IsFinal = true,
|
||||
FinishReason = evt.Delta.StopReason
|
||||
};
|
||||
yield break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
_httpClient.Dispose();
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude API models
|
||||
internal sealed class ClaudeMessageRequest
|
||||
{
|
||||
[JsonPropertyName("model")]
|
||||
public required string Model { get; set; }
|
||||
|
||||
[JsonPropertyName("max_tokens")]
|
||||
public int MaxTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("system")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public string? System { get; set; }
|
||||
|
||||
[JsonPropertyName("messages")]
|
||||
public required List<ClaudeMessage> Messages { get; set; }
|
||||
|
||||
[JsonPropertyName("temperature")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double Temperature { get; set; }
|
||||
|
||||
[JsonPropertyName("top_p")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double TopP { get; set; }
|
||||
|
||||
[JsonPropertyName("top_k")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public int? TopK { get; set; }
|
||||
|
||||
[JsonPropertyName("stop_sequences")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public string[]? StopSequences { get; set; }
|
||||
|
||||
[JsonPropertyName("stream")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public bool Stream { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class ClaudeMessage
|
||||
{
|
||||
[JsonPropertyName("role")]
|
||||
public required string Role { get; set; }
|
||||
|
||||
[JsonPropertyName("content")]
|
||||
public required string Content { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class ClaudeMessageResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("type")]
|
||||
public string? Type { get; set; }
|
||||
|
||||
[JsonPropertyName("role")]
|
||||
public string? Role { get; set; }
|
||||
|
||||
[JsonPropertyName("model")]
|
||||
public string? Model { get; set; }
|
||||
|
||||
[JsonPropertyName("content")]
|
||||
public List<ClaudeContentBlock>? Content { get; set; }
|
||||
|
||||
[JsonPropertyName("stop_reason")]
|
||||
public string? StopReason { get; set; }
|
||||
|
||||
[JsonPropertyName("usage")]
|
||||
public ClaudeUsage? Usage { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class ClaudeContentBlock
|
||||
{
|
||||
[JsonPropertyName("type")]
|
||||
public string? Type { get; set; }
|
||||
|
||||
[JsonPropertyName("text")]
|
||||
public string? Text { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class ClaudeUsage
|
||||
{
|
||||
[JsonPropertyName("input_tokens")]
|
||||
public int InputTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("output_tokens")]
|
||||
public int OutputTokens { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class ClaudeStreamEvent
|
||||
{
|
||||
[JsonPropertyName("type")]
|
||||
public string? Type { get; set; }
|
||||
|
||||
[JsonPropertyName("delta")]
|
||||
public ClaudeDelta? Delta { get; set; }
|
||||
|
||||
[JsonPropertyName("index")]
|
||||
public int? Index { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class ClaudeDelta
|
||||
{
|
||||
[JsonPropertyName("type")]
|
||||
public string? Type { get; set; }
|
||||
|
||||
[JsonPropertyName("text")]
|
||||
public string? Text { get; set; }
|
||||
|
||||
[JsonPropertyName("stop_reason")]
|
||||
public string? StopReason { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Unified LLM provider interface supporting OpenAI, Claude, and local servers.
|
||||
/// This unblocks OFFLINE-07 and enables all AI sprints to use any backend.
|
||||
/// </summary>
|
||||
public interface ILlmProvider : IDisposable
|
||||
{
|
||||
/// <summary>
|
||||
/// Provider identifier (openai, claude, llama-server, ollama).
|
||||
/// </summary>
|
||||
string ProviderId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Whether the provider is available and configured.
|
||||
/// </summary>
|
||||
Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Generate a completion from a prompt.
|
||||
/// </summary>
|
||||
Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Generate a completion with streaming output.
|
||||
/// </summary>
|
||||
IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Request for LLM completion.
|
||||
/// </summary>
|
||||
public sealed record LlmCompletionRequest
|
||||
{
|
||||
/// <summary>
|
||||
/// System prompt (instructions).
|
||||
/// </summary>
|
||||
public string? SystemPrompt { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// User prompt (main input).
|
||||
/// </summary>
|
||||
public required string UserPrompt { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Model to use (provider-specific).
|
||||
/// </summary>
|
||||
public string? Model { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Temperature (0 = deterministic).
|
||||
/// </summary>
|
||||
public double Temperature { get; init; } = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Maximum tokens to generate.
|
||||
/// </summary>
|
||||
public int MaxTokens { get; init; } = 4096;
|
||||
|
||||
/// <summary>
|
||||
/// Random seed for reproducibility.
|
||||
/// </summary>
|
||||
public int? Seed { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Stop sequences.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string>? StopSequences { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Request ID for tracing.
|
||||
/// </summary>
|
||||
public string? RequestId { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result from LLM completion.
|
||||
/// </summary>
|
||||
public sealed record LlmCompletionResult
|
||||
{
|
||||
/// <summary>
|
||||
/// Generated content.
|
||||
/// </summary>
|
||||
public required string Content { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Model used.
|
||||
/// </summary>
|
||||
public required string ModelId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Provider ID.
|
||||
/// </summary>
|
||||
public required string ProviderId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Input tokens used.
|
||||
/// </summary>
|
||||
public int? InputTokens { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Output tokens generated.
|
||||
/// </summary>
|
||||
public int? OutputTokens { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Time to first token (ms).
|
||||
/// </summary>
|
||||
public long? TimeToFirstTokenMs { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Total inference time (ms).
|
||||
/// </summary>
|
||||
public long? TotalTimeMs { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Finish reason (stop, length, etc.).
|
||||
/// </summary>
|
||||
public string? FinishReason { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Whether output is deterministic.
|
||||
/// </summary>
|
||||
public bool Deterministic { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Request ID for tracing.
|
||||
/// </summary>
|
||||
public string? RequestId { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Streaming chunk from LLM.
|
||||
/// </summary>
|
||||
public sealed record LlmStreamChunk
|
||||
{
|
||||
/// <summary>
|
||||
/// Content delta.
|
||||
/// </summary>
|
||||
public required string Content { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Whether this is the final chunk.
|
||||
/// </summary>
|
||||
public bool IsFinal { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Finish reason (only on final chunk).
|
||||
/// </summary>
|
||||
public string? FinishReason { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Factory for creating LLM providers.
|
||||
/// </summary>
|
||||
public interface ILlmProviderFactory
|
||||
{
|
||||
/// <summary>
|
||||
/// Get a provider by ID.
|
||||
/// </summary>
|
||||
ILlmProvider GetProvider(string providerId);
|
||||
|
||||
/// <summary>
|
||||
/// Get the default provider based on configuration.
|
||||
/// </summary>
|
||||
ILlmProvider GetDefaultProvider();
|
||||
|
||||
/// <summary>
|
||||
/// List available providers.
|
||||
/// </summary>
|
||||
IReadOnlyList<string> AvailableProviders { get; }
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using NetEscapades.Configuration.Yaml;
|
||||
using StellaOps.Plugin;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Plugin interface for LLM providers.
|
||||
/// Each provider (OpenAI, Claude, LlamaServer, Ollama) implements this interface
|
||||
/// and is discovered via the plugin catalog.
|
||||
/// </summary>
|
||||
public interface ILlmProviderPlugin : IAvailabilityPlugin
|
||||
{
|
||||
/// <summary>
|
||||
/// Unique provider identifier (e.g., "openai", "claude", "llama-server").
|
||||
/// </summary>
|
||||
string ProviderId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Display name for the provider.
|
||||
/// </summary>
|
||||
string DisplayName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Provider description.
|
||||
/// </summary>
|
||||
string Description { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Default configuration file name (e.g., "openai.yaml").
|
||||
/// </summary>
|
||||
string DefaultConfigFileName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Create an LLM provider instance with the given configuration.
|
||||
/// </summary>
|
||||
ILlmProvider Create(IServiceProvider services, IConfiguration configuration);
|
||||
|
||||
/// <summary>
|
||||
/// Validate the configuration and return any errors.
|
||||
/// </summary>
|
||||
LlmProviderConfigValidation ValidateConfiguration(IConfiguration configuration);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of configuration validation.
|
||||
/// </summary>
|
||||
public sealed record LlmProviderConfigValidation
|
||||
{
|
||||
public bool IsValid { get; init; }
|
||||
public IReadOnlyList<string> Errors { get; init; } = Array.Empty<string>();
|
||||
public IReadOnlyList<string> Warnings { get; init; } = Array.Empty<string>();
|
||||
|
||||
public static LlmProviderConfigValidation Success() => new() { IsValid = true };
|
||||
|
||||
public static LlmProviderConfigValidation Failed(params string[] errors) => new()
|
||||
{
|
||||
IsValid = false,
|
||||
Errors = errors
|
||||
};
|
||||
|
||||
public static LlmProviderConfigValidation WithWarnings(params string[] warnings) => new()
|
||||
{
|
||||
IsValid = true,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Base configuration shared by all LLM providers.
|
||||
/// </summary>
|
||||
public abstract class LlmProviderConfigBase
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether the provider is enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Priority for provider selection (lower = higher priority).
|
||||
/// </summary>
|
||||
public int Priority { get; set; } = 100;
|
||||
|
||||
/// <summary>
|
||||
/// Request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(120);
|
||||
|
||||
/// <summary>
|
||||
/// Maximum retries on failure.
|
||||
/// </summary>
|
||||
public int MaxRetries { get; set; } = 3;
|
||||
|
||||
/// <summary>
|
||||
/// Temperature for inference (0 = deterministic).
|
||||
/// </summary>
|
||||
public double Temperature { get; set; } = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Maximum tokens to generate.
|
||||
/// </summary>
|
||||
public int MaxTokens { get; set; } = 4096;
|
||||
|
||||
/// <summary>
|
||||
/// Random seed for reproducibility.
|
||||
/// </summary>
|
||||
public int? Seed { get; set; } = 42;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Catalog for LLM provider plugins.
|
||||
/// </summary>
|
||||
public sealed class LlmProviderCatalog
|
||||
{
|
||||
private readonly Dictionary<string, ILlmProviderPlugin> _plugins = new(StringComparer.OrdinalIgnoreCase);
|
||||
private readonly Dictionary<string, IConfiguration> _configurations = new(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
/// <summary>
|
||||
/// Register a provider plugin.
|
||||
/// </summary>
|
||||
public LlmProviderCatalog RegisterPlugin(ILlmProviderPlugin plugin)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(plugin);
|
||||
_plugins[plugin.ProviderId] = plugin;
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Register configuration for a provider.
|
||||
/// </summary>
|
||||
public LlmProviderCatalog RegisterConfiguration(string providerId, IConfiguration configuration)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(configuration);
|
||||
_configurations[providerId] = configuration;
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load configurations from a directory.
|
||||
/// </summary>
|
||||
public LlmProviderCatalog LoadConfigurationsFromDirectory(string directory)
|
||||
{
|
||||
if (!Directory.Exists(directory))
|
||||
{
|
||||
return this;
|
||||
}
|
||||
|
||||
foreach (var file in Directory.GetFiles(directory, "*.yaml"))
|
||||
{
|
||||
var providerId = Path.GetFileNameWithoutExtension(file);
|
||||
var config = new ConfigurationBuilder()
|
||||
.AddYamlFile(file, optional: false, reloadOnChange: true)
|
||||
.Build();
|
||||
_configurations[providerId] = config;
|
||||
}
|
||||
|
||||
foreach (var file in Directory.GetFiles(directory, "*.yml"))
|
||||
{
|
||||
var providerId = Path.GetFileNameWithoutExtension(file);
|
||||
var config = new ConfigurationBuilder()
|
||||
.AddYamlFile(file, optional: false, reloadOnChange: true)
|
||||
.Build();
|
||||
_configurations[providerId] = config;
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get all registered plugins.
|
||||
/// </summary>
|
||||
public IReadOnlyList<ILlmProviderPlugin> GetPlugins() => _plugins.Values.ToList();
|
||||
|
||||
/// <summary>
|
||||
/// Get available plugins (those with valid configuration).
|
||||
/// </summary>
|
||||
public IReadOnlyList<ILlmProviderPlugin> GetAvailablePlugins(IServiceProvider services)
|
||||
{
|
||||
var available = new List<ILlmProviderPlugin>();
|
||||
|
||||
foreach (var plugin in _plugins.Values)
|
||||
{
|
||||
if (!_configurations.TryGetValue(plugin.ProviderId, out var config))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!plugin.IsAvailable(services))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var validation = plugin.ValidateConfiguration(config);
|
||||
if (!validation.IsValid)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
available.Add(plugin);
|
||||
}
|
||||
|
||||
return available.OrderBy(p => GetPriority(p.ProviderId)).ToList();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get a specific plugin by ID.
|
||||
/// </summary>
|
||||
public ILlmProviderPlugin? GetPlugin(string providerId)
|
||||
{
|
||||
return _plugins.TryGetValue(providerId, out var plugin) ? plugin : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get configuration for a provider.
|
||||
/// </summary>
|
||||
public IConfiguration? GetConfiguration(string providerId)
|
||||
{
|
||||
return _configurations.TryGetValue(providerId, out var config) ? config : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a provider instance.
|
||||
/// </summary>
|
||||
public ILlmProvider? CreateProvider(string providerId, IServiceProvider services)
|
||||
{
|
||||
if (!_plugins.TryGetValue(providerId, out var plugin))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!_configurations.TryGetValue(providerId, out var config))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return plugin.Create(services, config);
|
||||
}
|
||||
|
||||
private int GetPriority(string providerId)
|
||||
{
|
||||
if (_configurations.TryGetValue(providerId, out var config))
|
||||
{
|
||||
return config.GetValue<int>("Priority", 100);
|
||||
}
|
||||
return 100;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,592 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Llama.cpp server configuration (maps to llama-server.yaml).
|
||||
/// This is the key provider for OFFLINE/AIRGAP environments.
|
||||
/// </summary>
|
||||
public sealed class LlamaServerConfig : LlmProviderConfigBase
|
||||
{
|
||||
/// <summary>
|
||||
/// Server base URL.
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "http://localhost:8080";
|
||||
|
||||
/// <summary>
|
||||
/// API key (if server requires auth).
|
||||
/// </summary>
|
||||
public string? ApiKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Health check endpoint.
|
||||
/// </summary>
|
||||
public string HealthEndpoint { get; set; } = "/health";
|
||||
|
||||
/// <summary>
|
||||
/// Model name (for logging).
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "local-llama";
|
||||
|
||||
/// <summary>
|
||||
/// Model file path (informational).
|
||||
/// </summary>
|
||||
public string? ModelPath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Expected model digest (SHA-256).
|
||||
/// </summary>
|
||||
public string? ExpectedDigest { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Top-p sampling.
|
||||
/// </summary>
|
||||
public double TopP { get; set; } = 1.0;
|
||||
|
||||
/// <summary>
|
||||
/// Top-k sampling.
|
||||
/// </summary>
|
||||
public int TopK { get; set; } = 40;
|
||||
|
||||
/// <summary>
|
||||
/// Repeat penalty.
|
||||
/// </summary>
|
||||
public double RepeatPenalty { get; set; } = 1.1;
|
||||
|
||||
/// <summary>
|
||||
/// Context length.
|
||||
/// </summary>
|
||||
public int ContextLength { get; set; } = 4096;
|
||||
|
||||
/// <summary>
|
||||
/// Model bundle path (for airgap).
|
||||
/// </summary>
|
||||
public string? BundlePath { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Verify bundle signature.
|
||||
/// </summary>
|
||||
public bool VerifySignature { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Crypto scheme for verification.
|
||||
/// </summary>
|
||||
public string? CryptoScheme { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Log health checks.
|
||||
/// </summary>
|
||||
public bool LogHealthChecks { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Log token usage.
|
||||
/// </summary>
|
||||
public bool LogUsage { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Bind configuration from IConfiguration.
|
||||
/// </summary>
|
||||
public static LlamaServerConfig FromConfiguration(IConfiguration config)
|
||||
{
|
||||
var result = new LlamaServerConfig();
|
||||
|
||||
// Provider section
|
||||
result.Enabled = config.GetValue("enabled", true);
|
||||
result.Priority = config.GetValue("priority", 10); // Lower = higher priority for offline
|
||||
|
||||
// Server section
|
||||
var server = config.GetSection("server");
|
||||
result.BaseUrl = server.GetValue("baseUrl", "http://localhost:8080")!;
|
||||
result.ApiKey = server.GetValue<string>("apiKey");
|
||||
result.HealthEndpoint = server.GetValue("healthEndpoint", "/health")!;
|
||||
|
||||
// Model section
|
||||
var model = config.GetSection("model");
|
||||
result.Model = model.GetValue("name", "local-llama")!;
|
||||
result.ModelPath = model.GetValue<string>("modelPath");
|
||||
result.ExpectedDigest = model.GetValue<string>("expectedDigest");
|
||||
|
||||
// Inference section
|
||||
var inference = config.GetSection("inference");
|
||||
result.Temperature = inference.GetValue("temperature", 0.0);
|
||||
result.MaxTokens = inference.GetValue("maxTokens", 4096);
|
||||
result.Seed = inference.GetValue<int?>("seed") ?? 42;
|
||||
result.TopP = inference.GetValue("topP", 1.0);
|
||||
result.TopK = inference.GetValue("topK", 40);
|
||||
result.RepeatPenalty = inference.GetValue("repeatPenalty", 1.1);
|
||||
result.ContextLength = inference.GetValue("contextLength", 4096);
|
||||
|
||||
// Request section
|
||||
var request = config.GetSection("request");
|
||||
result.Timeout = request.GetValue("timeout", TimeSpan.FromMinutes(5)); // Longer for local
|
||||
result.MaxRetries = request.GetValue("maxRetries", 2);
|
||||
|
||||
// Bundle section (for airgap)
|
||||
var bundle = config.GetSection("bundle");
|
||||
result.BundlePath = bundle.GetValue<string>("bundlePath");
|
||||
result.VerifySignature = bundle.GetValue("verifySignature", true);
|
||||
result.CryptoScheme = bundle.GetValue<string>("cryptoScheme");
|
||||
|
||||
// Logging section
|
||||
var logging = config.GetSection("logging");
|
||||
result.LogHealthChecks = logging.GetValue("logHealthChecks", false);
|
||||
result.LogUsage = logging.GetValue("logUsage", true);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Llama.cpp server LLM provider plugin.
|
||||
/// </summary>
|
||||
public sealed class LlamaServerLlmProviderPlugin : ILlmProviderPlugin
|
||||
{
|
||||
public string Name => "Llama.cpp Server LLM Provider";
|
||||
public string ProviderId => "llama-server";
|
||||
public string DisplayName => "llama.cpp Server";
|
||||
public string Description => "Local LLM inference via llama.cpp HTTP server (enables offline operation)";
|
||||
public string DefaultConfigFileName => "llama-server.yaml";
|
||||
|
||||
public bool IsAvailable(IServiceProvider services)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public ILlmProvider Create(IServiceProvider services, IConfiguration configuration)
|
||||
{
|
||||
var config = LlamaServerConfig.FromConfiguration(configuration);
|
||||
var httpClientFactory = services.GetRequiredService<IHttpClientFactory>();
|
||||
var loggerFactory = services.GetRequiredService<ILoggerFactory>();
|
||||
|
||||
return new LlamaServerLlmProvider(
|
||||
httpClientFactory.CreateClient("LlamaServer"),
|
||||
config,
|
||||
loggerFactory.CreateLogger<LlamaServerLlmProvider>());
|
||||
}
|
||||
|
||||
public LlmProviderConfigValidation ValidateConfiguration(IConfiguration configuration)
|
||||
{
|
||||
var errors = new List<string>();
|
||||
var warnings = new List<string>();
|
||||
|
||||
var config = LlamaServerConfig.FromConfiguration(configuration);
|
||||
|
||||
if (!config.Enabled)
|
||||
{
|
||||
return LlmProviderConfigValidation.WithWarnings("Provider is disabled");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(config.BaseUrl))
|
||||
{
|
||||
errors.Add("Server base URL is required.");
|
||||
}
|
||||
else if (!Uri.TryCreate(config.BaseUrl, UriKind.Absolute, out _))
|
||||
{
|
||||
errors.Add($"Invalid server URL: {config.BaseUrl}");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(config.Model))
|
||||
{
|
||||
warnings.Add("No model name specified for logging.");
|
||||
}
|
||||
|
||||
if (errors.Count > 0)
|
||||
{
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = false,
|
||||
Errors = errors,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = true,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Llama.cpp server LLM provider implementation.
|
||||
/// Connects to llama.cpp running with --server flag (OpenAI-compatible API).
|
||||
/// This is the key solution for OFFLINE-07: enables local inference without native bindings.
|
||||
/// </summary>
|
||||
public sealed class LlamaServerLlmProvider : ILlmProvider
|
||||
{
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly LlamaServerConfig _config;
|
||||
private readonly ILogger<LlamaServerLlmProvider> _logger;
|
||||
private bool _disposed;
|
||||
|
||||
public string ProviderId => "llama-server";
|
||||
|
||||
public LlamaServerLlmProvider(
|
||||
HttpClient httpClient,
|
||||
LlamaServerConfig config,
|
||||
ILogger<LlamaServerLlmProvider> logger)
|
||||
{
|
||||
_httpClient = httpClient;
|
||||
_config = config;
|
||||
_logger = logger;
|
||||
|
||||
ConfigureHttpClient();
|
||||
}
|
||||
|
||||
private void ConfigureHttpClient()
|
||||
{
|
||||
_httpClient.BaseAddress = new Uri(_config.BaseUrl.TrimEnd('/') + "/");
|
||||
_httpClient.Timeout = _config.Timeout;
|
||||
|
||||
if (!string.IsNullOrEmpty(_config.ApiKey))
|
||||
{
|
||||
_httpClient.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _config.ApiKey);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_config.Enabled)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
// llama.cpp server exposes /health endpoint
|
||||
var response = await _httpClient.GetAsync(_config.HealthEndpoint.TrimStart('/'), cancellationToken);
|
||||
var available = response.IsSuccessStatusCode;
|
||||
|
||||
if (_config.LogHealthChecks)
|
||||
{
|
||||
_logger.LogDebug("Llama server health check: {Available} at {BaseUrl}",
|
||||
available, _config.BaseUrl);
|
||||
}
|
||||
|
||||
if (available)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fallback: try /v1/models (OpenAI-compatible)
|
||||
response = await _httpClient.GetAsync("v1/models", cancellationToken);
|
||||
return response.IsSuccessStatusCode;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Llama server availability check failed at {BaseUrl}", _config.BaseUrl);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
var seed = request.Seed ?? _config.Seed ?? 42;
|
||||
|
||||
var llamaRequest = new LlamaServerRequest
|
||||
{
|
||||
Model = model,
|
||||
Messages = BuildMessages(request),
|
||||
Temperature = temperature,
|
||||
MaxTokens = maxTokens,
|
||||
Seed = seed,
|
||||
TopP = _config.TopP,
|
||||
TopK = _config.TopK,
|
||||
RepeatPenalty = _config.RepeatPenalty,
|
||||
Stop = request.StopSequences?.ToArray()
|
||||
};
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(
|
||||
"v1/chat/completions",
|
||||
llamaRequest,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var llamaResponse = await response.Content.ReadFromJsonAsync<LlamaServerResponse>(cancellationToken);
|
||||
stopwatch.Stop();
|
||||
|
||||
if (llamaResponse?.Choices is null || llamaResponse.Choices.Count == 0)
|
||||
{
|
||||
throw new InvalidOperationException("No completion returned from llama.cpp server");
|
||||
}
|
||||
|
||||
var choice = llamaResponse.Choices[0];
|
||||
|
||||
if (_config.LogUsage && llamaResponse.Usage is not null)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Llama server usage - Model: {Model}, Input: {InputTokens}, Output: {OutputTokens}, Time: {TimeMs}ms",
|
||||
model,
|
||||
llamaResponse.Usage.PromptTokens,
|
||||
llamaResponse.Usage.CompletionTokens,
|
||||
stopwatch.ElapsedMilliseconds);
|
||||
}
|
||||
|
||||
return new LlmCompletionResult
|
||||
{
|
||||
Content = choice.Message?.Content ?? string.Empty,
|
||||
ModelId = llamaResponse.Model ?? model,
|
||||
ProviderId = ProviderId,
|
||||
InputTokens = llamaResponse.Usage?.PromptTokens,
|
||||
OutputTokens = llamaResponse.Usage?.CompletionTokens,
|
||||
TotalTimeMs = stopwatch.ElapsedMilliseconds,
|
||||
FinishReason = choice.FinishReason,
|
||||
Deterministic = temperature == 0,
|
||||
RequestId = request.RequestId ?? llamaResponse.Id
|
||||
};
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
var seed = request.Seed ?? _config.Seed ?? 42;
|
||||
|
||||
var llamaRequest = new LlamaServerRequest
|
||||
{
|
||||
Model = model,
|
||||
Messages = BuildMessages(request),
|
||||
Temperature = temperature,
|
||||
MaxTokens = maxTokens,
|
||||
Seed = seed,
|
||||
TopP = _config.TopP,
|
||||
TopK = _config.TopK,
|
||||
RepeatPenalty = _config.RepeatPenalty,
|
||||
Stop = request.StopSequences?.ToArray(),
|
||||
Stream = true
|
||||
};
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "v1/chat/completions")
|
||||
{
|
||||
Content = new StringContent(
|
||||
JsonSerializer.Serialize(llamaRequest),
|
||||
Encoding.UTF8,
|
||||
"application/json")
|
||||
};
|
||||
|
||||
var response = await _httpClient.SendAsync(
|
||||
httpRequest,
|
||||
HttpCompletionOption.ResponseHeadersRead,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
|
||||
using var reader = new StreamReader(stream);
|
||||
|
||||
string? line;
|
||||
while ((line = await reader.ReadLineAsync(cancellationToken)) is not null)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
|
||||
if (string.IsNullOrEmpty(line))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!line.StartsWith("data: "))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var data = line.Substring(6);
|
||||
if (data == "[DONE]")
|
||||
{
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = string.Empty,
|
||||
IsFinal = true,
|
||||
FinishReason = "stop"
|
||||
};
|
||||
yield break;
|
||||
}
|
||||
|
||||
LlamaServerStreamResponse? chunk;
|
||||
try
|
||||
{
|
||||
chunk = JsonSerializer.Deserialize<LlamaServerStreamResponse>(data);
|
||||
}
|
||||
catch
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk?.Choices is null || chunk.Choices.Count == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var choice = chunk.Choices[0];
|
||||
var content = choice.Delta?.Content ?? string.Empty;
|
||||
var isFinal = choice.FinishReason != null;
|
||||
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = content,
|
||||
IsFinal = isFinal,
|
||||
FinishReason = choice.FinishReason
|
||||
};
|
||||
|
||||
if (isFinal)
|
||||
{
|
||||
yield break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static List<LlamaServerMessage> BuildMessages(LlmCompletionRequest request)
|
||||
{
|
||||
var messages = new List<LlamaServerMessage>();
|
||||
|
||||
if (!string.IsNullOrEmpty(request.SystemPrompt))
|
||||
{
|
||||
messages.Add(new LlamaServerMessage { Role = "system", Content = request.SystemPrompt });
|
||||
}
|
||||
|
||||
messages.Add(new LlamaServerMessage { Role = "user", Content = request.UserPrompt });
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
_httpClient.Dispose();
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// llama.cpp server API models (OpenAI-compatible)
|
||||
internal sealed class LlamaServerRequest
|
||||
{
|
||||
[JsonPropertyName("model")]
|
||||
public required string Model { get; set; }
|
||||
|
||||
[JsonPropertyName("messages")]
|
||||
public required List<LlamaServerMessage> Messages { get; set; }
|
||||
|
||||
[JsonPropertyName("temperature")]
|
||||
public double Temperature { get; set; }
|
||||
|
||||
[JsonPropertyName("max_tokens")]
|
||||
public int MaxTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("seed")]
|
||||
public int Seed { get; set; }
|
||||
|
||||
[JsonPropertyName("top_p")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double TopP { get; set; }
|
||||
|
||||
[JsonPropertyName("top_k")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public int TopK { get; set; }
|
||||
|
||||
[JsonPropertyName("repeat_penalty")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double RepeatPenalty { get; set; }
|
||||
|
||||
[JsonPropertyName("stop")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public string[]? Stop { get; set; }
|
||||
|
||||
[JsonPropertyName("stream")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public bool Stream { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerMessage
|
||||
{
|
||||
[JsonPropertyName("role")]
|
||||
public required string Role { get; set; }
|
||||
|
||||
[JsonPropertyName("content")]
|
||||
public required string Content { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("model")]
|
||||
public string? Model { get; set; }
|
||||
|
||||
[JsonPropertyName("choices")]
|
||||
public List<LlamaServerChoice>? Choices { get; set; }
|
||||
|
||||
[JsonPropertyName("usage")]
|
||||
public LlamaServerUsage? Usage { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerChoice
|
||||
{
|
||||
[JsonPropertyName("index")]
|
||||
public int Index { get; set; }
|
||||
|
||||
[JsonPropertyName("message")]
|
||||
public LlamaServerMessage? Message { get; set; }
|
||||
|
||||
[JsonPropertyName("finish_reason")]
|
||||
public string? FinishReason { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerUsage
|
||||
{
|
||||
[JsonPropertyName("prompt_tokens")]
|
||||
public int PromptTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("completion_tokens")]
|
||||
public int CompletionTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("total_tokens")]
|
||||
public int TotalTokens { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerStreamResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("choices")]
|
||||
public List<LlamaServerStreamChoice>? Choices { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerStreamChoice
|
||||
{
|
||||
[JsonPropertyName("index")]
|
||||
public int Index { get; set; }
|
||||
|
||||
[JsonPropertyName("delta")]
|
||||
public LlamaServerDelta? Delta { get; set; }
|
||||
|
||||
[JsonPropertyName("finish_reason")]
|
||||
public string? FinishReason { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class LlamaServerDelta
|
||||
{
|
||||
[JsonPropertyName("content")]
|
||||
public string? Content { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,492 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for LLM inference caching.
|
||||
/// Caches deterministic (temperature=0) completions for replay and cost reduction.
|
||||
/// Sprint: SPRINT_20251226_019_AI_offline_inference
|
||||
/// Task: OFFLINE-09
|
||||
/// </summary>
|
||||
public interface ILlmInferenceCache
|
||||
{
|
||||
/// <summary>
|
||||
/// Try to get a cached completion.
|
||||
/// </summary>
|
||||
Task<LlmCompletionResult?> TryGetAsync(
|
||||
LlmCompletionRequest request,
|
||||
string providerId,
|
||||
CancellationToken ct = default);
|
||||
|
||||
/// <summary>
|
||||
/// Cache a completion result.
|
||||
/// </summary>
|
||||
Task SetAsync(
|
||||
LlmCompletionRequest request,
|
||||
string providerId,
|
||||
LlmCompletionResult result,
|
||||
CancellationToken ct = default);
|
||||
|
||||
/// <summary>
|
||||
/// Invalidate cached entries by pattern.
|
||||
/// </summary>
|
||||
Task InvalidateAsync(string pattern, CancellationToken ct = default);
|
||||
|
||||
/// <summary>
|
||||
/// Get cache statistics.
|
||||
/// </summary>
|
||||
LlmCacheStatistics GetStatistics();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Options for LLM inference caching.
|
||||
/// </summary>
|
||||
public sealed class LlmInferenceCacheOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether caching is enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Whether to only cache deterministic requests (temperature=0).
|
||||
/// </summary>
|
||||
public bool DeterministicOnly { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Default TTL for cache entries.
|
||||
/// </summary>
|
||||
public TimeSpan DefaultTtl { get; set; } = TimeSpan.FromDays(7);
|
||||
|
||||
/// <summary>
|
||||
/// Maximum TTL for cache entries.
|
||||
/// </summary>
|
||||
public TimeSpan MaxTtl { get; set; } = TimeSpan.FromDays(30);
|
||||
|
||||
/// <summary>
|
||||
/// TTL for short-lived entries (non-deterministic).
|
||||
/// </summary>
|
||||
public TimeSpan ShortTtl { get; set; } = TimeSpan.FromHours(1);
|
||||
|
||||
/// <summary>
|
||||
/// Key prefix for cache entries.
|
||||
/// </summary>
|
||||
public string KeyPrefix { get; set; } = "llm:inference:";
|
||||
|
||||
/// <summary>
|
||||
/// Maximum content length to cache.
|
||||
/// </summary>
|
||||
public int MaxContentLength { get; set; } = 100_000;
|
||||
|
||||
/// <summary>
|
||||
/// Whether to use sliding expiration.
|
||||
/// </summary>
|
||||
public bool SlidingExpiration { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Include model in cache key.
|
||||
/// </summary>
|
||||
public bool IncludeModelInKey { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Include seed in cache key.
|
||||
/// </summary>
|
||||
public bool IncludeSeedInKey { get; set; } = true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Statistics for LLM inference cache.
|
||||
/// </summary>
|
||||
public sealed record LlmCacheStatistics
|
||||
{
|
||||
/// <summary>
|
||||
/// Total cache hits.
|
||||
/// </summary>
|
||||
public long Hits { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Total cache misses.
|
||||
/// </summary>
|
||||
public long Misses { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Total cache sets.
|
||||
/// </summary>
|
||||
public long Sets { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Cache hit rate (0.0 - 1.0).
|
||||
/// </summary>
|
||||
public double HitRate => Hits + Misses > 0 ? (double)Hits / (Hits + Misses) : 0;
|
||||
|
||||
/// <summary>
|
||||
/// Estimated tokens saved.
|
||||
/// </summary>
|
||||
public long TokensSaved { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Estimated cost saved (USD).
|
||||
/// </summary>
|
||||
public decimal CostSaved { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// In-memory LLM inference cache implementation.
|
||||
/// For production, use distributed cache (Valkey/Redis).
|
||||
/// </summary>
|
||||
public sealed class InMemoryLlmInferenceCache : ILlmInferenceCache, IDisposable
|
||||
{
|
||||
private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web)
|
||||
{
|
||||
WriteIndented = false,
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
|
||||
};
|
||||
|
||||
private readonly Dictionary<string, CacheEntry> _cache = new();
|
||||
private readonly LlmInferenceCacheOptions _options;
|
||||
private readonly ILogger<InMemoryLlmInferenceCache> _logger;
|
||||
private readonly TimeProvider _timeProvider;
|
||||
private readonly object _lock = new();
|
||||
private readonly Timer _cleanupTimer;
|
||||
|
||||
private long _hits;
|
||||
private long _misses;
|
||||
private long _sets;
|
||||
private long _tokensSaved;
|
||||
|
||||
public InMemoryLlmInferenceCache(
|
||||
IOptions<LlmInferenceCacheOptions> options,
|
||||
ILogger<InMemoryLlmInferenceCache> logger,
|
||||
TimeProvider? timeProvider = null)
|
||||
{
|
||||
_options = options.Value;
|
||||
_logger = logger;
|
||||
_timeProvider = timeProvider ?? TimeProvider.System;
|
||||
|
||||
// Cleanup expired entries every 5 minutes
|
||||
_cleanupTimer = new Timer(CleanupExpired, null, TimeSpan.FromMinutes(5), TimeSpan.FromMinutes(5));
|
||||
}
|
||||
|
||||
public Task<LlmCompletionResult?> TryGetAsync(
|
||||
LlmCompletionRequest request,
|
||||
string providerId,
|
||||
CancellationToken ct = default)
|
||||
{
|
||||
if (!_options.Enabled)
|
||||
{
|
||||
return Task.FromResult<LlmCompletionResult?>(null);
|
||||
}
|
||||
|
||||
if (_options.DeterministicOnly && request.Temperature > 0)
|
||||
{
|
||||
return Task.FromResult<LlmCompletionResult?>(null);
|
||||
}
|
||||
|
||||
var key = ComputeCacheKey(request, providerId);
|
||||
|
||||
lock (_lock)
|
||||
{
|
||||
if (_cache.TryGetValue(key, out var entry))
|
||||
{
|
||||
if (entry.ExpiresAt > _timeProvider.GetUtcNow())
|
||||
{
|
||||
Interlocked.Increment(ref _hits);
|
||||
Interlocked.Add(ref _tokensSaved, entry.Result.OutputTokens ?? 0);
|
||||
|
||||
// Update access time for sliding expiration
|
||||
if (_options.SlidingExpiration)
|
||||
{
|
||||
entry.AccessedAt = _timeProvider.GetUtcNow();
|
||||
}
|
||||
|
||||
_logger.LogDebug("Cache hit for key {Key}", key);
|
||||
return Task.FromResult<LlmCompletionResult?>(entry.Result);
|
||||
}
|
||||
|
||||
// Expired, remove it
|
||||
_cache.Remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
Interlocked.Increment(ref _misses);
|
||||
_logger.LogDebug("Cache miss for key {Key}", key);
|
||||
return Task.FromResult<LlmCompletionResult?>(null);
|
||||
}
|
||||
|
||||
public Task SetAsync(
|
||||
LlmCompletionRequest request,
|
||||
string providerId,
|
||||
LlmCompletionResult result,
|
||||
CancellationToken ct = default)
|
||||
{
|
||||
if (!_options.Enabled)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
// Don't cache non-deterministic if option is set
|
||||
if (_options.DeterministicOnly && request.Temperature > 0)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
// Don't cache if content too large
|
||||
if (result.Content.Length > _options.MaxContentLength)
|
||||
{
|
||||
_logger.LogDebug("Skipping cache for large content ({Length} > {Max})",
|
||||
result.Content.Length, _options.MaxContentLength);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
var key = ComputeCacheKey(request, providerId);
|
||||
var ttl = result.Deterministic ? _options.DefaultTtl : _options.ShortTtl;
|
||||
var now = _timeProvider.GetUtcNow();
|
||||
|
||||
var entry = new CacheEntry
|
||||
{
|
||||
Result = result,
|
||||
CreatedAt = now,
|
||||
AccessedAt = now,
|
||||
ExpiresAt = now.Add(ttl)
|
||||
};
|
||||
|
||||
lock (_lock)
|
||||
{
|
||||
_cache[key] = entry;
|
||||
}
|
||||
|
||||
Interlocked.Increment(ref _sets);
|
||||
_logger.LogDebug("Cached result for key {Key}, TTL {Ttl}", key, ttl);
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task InvalidateAsync(string pattern, CancellationToken ct = default)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
var keysToRemove = _cache.Keys
|
||||
.Where(k => k.Contains(pattern, StringComparison.OrdinalIgnoreCase))
|
||||
.ToList();
|
||||
|
||||
foreach (var key in keysToRemove)
|
||||
{
|
||||
_cache.Remove(key);
|
||||
}
|
||||
|
||||
_logger.LogInformation("Invalidated {Count} cache entries matching '{Pattern}'",
|
||||
keysToRemove.Count, pattern);
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public LlmCacheStatistics GetStatistics()
|
||||
{
|
||||
return new LlmCacheStatistics
|
||||
{
|
||||
Hits = _hits,
|
||||
Misses = _misses,
|
||||
Sets = _sets,
|
||||
TokensSaved = _tokensSaved,
|
||||
// Rough estimate: $0.002 per 1K tokens average
|
||||
CostSaved = _tokensSaved * 0.002m / 1000m
|
||||
};
|
||||
}
|
||||
|
||||
private string ComputeCacheKey(LlmCompletionRequest request, string providerId)
|
||||
{
|
||||
using var sha = SHA256.Create();
|
||||
var sb = new StringBuilder();
|
||||
|
||||
sb.Append(_options.KeyPrefix);
|
||||
sb.Append(providerId);
|
||||
sb.Append(':');
|
||||
|
||||
if (_options.IncludeModelInKey && !string.IsNullOrEmpty(request.Model))
|
||||
{
|
||||
sb.Append(request.Model);
|
||||
sb.Append(':');
|
||||
}
|
||||
|
||||
// Hash the prompts
|
||||
var promptHash = ComputeHash(sha, $"{request.SystemPrompt}||{request.UserPrompt}");
|
||||
sb.Append(promptHash);
|
||||
|
||||
// Include seed if configured
|
||||
if (_options.IncludeSeedInKey && request.Seed.HasValue)
|
||||
{
|
||||
sb.Append(':');
|
||||
sb.Append(request.Seed.Value);
|
||||
}
|
||||
|
||||
// Include temperature and max tokens in key
|
||||
sb.Append(':');
|
||||
sb.Append(request.Temperature.ToString("F2"));
|
||||
sb.Append(':');
|
||||
sb.Append(request.MaxTokens);
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static string ComputeHash(SHA256 sha, string input)
|
||||
{
|
||||
var bytes = Encoding.UTF8.GetBytes(input);
|
||||
var hash = sha.ComputeHash(bytes);
|
||||
return Convert.ToHexStringLower(hash)[..16]; // First 16 chars
|
||||
}
|
||||
|
||||
private void CleanupExpired(object? state)
|
||||
{
|
||||
var now = _timeProvider.GetUtcNow();
|
||||
var removed = 0;
|
||||
|
||||
lock (_lock)
|
||||
{
|
||||
var keysToRemove = _cache
|
||||
.Where(kvp => kvp.Value.ExpiresAt <= now)
|
||||
.Select(kvp => kvp.Key)
|
||||
.ToList();
|
||||
|
||||
foreach (var key in keysToRemove)
|
||||
{
|
||||
_cache.Remove(key);
|
||||
removed++;
|
||||
}
|
||||
}
|
||||
|
||||
if (removed > 0)
|
||||
{
|
||||
_logger.LogDebug("Cleaned up {Count} expired cache entries", removed);
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_cleanupTimer.Dispose();
|
||||
}
|
||||
|
||||
private sealed class CacheEntry
|
||||
{
|
||||
public required LlmCompletionResult Result { get; init; }
|
||||
public DateTimeOffset CreatedAt { get; init; }
|
||||
public DateTimeOffset AccessedAt { get; set; }
|
||||
public DateTimeOffset ExpiresAt { get; init; }
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Caching wrapper for LLM providers.
|
||||
/// Wraps any ILlmProvider to add caching.
|
||||
/// </summary>
|
||||
public sealed class CachingLlmProvider : ILlmProvider
|
||||
{
|
||||
private readonly ILlmProvider _inner;
|
||||
private readonly ILlmInferenceCache _cache;
|
||||
private readonly ILogger<CachingLlmProvider> _logger;
|
||||
|
||||
public string ProviderId => _inner.ProviderId;
|
||||
|
||||
public CachingLlmProvider(
|
||||
ILlmProvider inner,
|
||||
ILlmInferenceCache cache,
|
||||
ILogger<CachingLlmProvider> logger)
|
||||
{
|
||||
_inner = inner ?? throw new ArgumentNullException(nameof(inner));
|
||||
_cache = cache ?? throw new ArgumentNullException(nameof(cache));
|
||||
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
|
||||
}
|
||||
|
||||
public Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
=> _inner.IsAvailableAsync(cancellationToken);
|
||||
|
||||
public async Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Try cache first
|
||||
var cached = await _cache.TryGetAsync(request, ProviderId, cancellationToken);
|
||||
if (cached is not null)
|
||||
{
|
||||
_logger.LogDebug("Returning cached result for provider {ProviderId}", ProviderId);
|
||||
return cached with { RequestId = request.RequestId };
|
||||
}
|
||||
|
||||
// Get from provider
|
||||
var result = await _inner.CompleteAsync(request, cancellationToken);
|
||||
|
||||
// Cache the result
|
||||
await _cache.SetAsync(request, ProviderId, result, cancellationToken);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Streaming is not cached - pass through to inner provider
|
||||
return _inner.CompleteStreamAsync(request, cancellationToken);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_inner.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Factory for creating caching LLM providers.
|
||||
/// </summary>
|
||||
public sealed class CachingLlmProviderFactory : ILlmProviderFactory
|
||||
{
|
||||
private readonly ILlmProviderFactory _inner;
|
||||
private readonly ILlmInferenceCache _cache;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly Dictionary<string, CachingLlmProvider> _cachedProviders = new();
|
||||
private readonly object _lock = new();
|
||||
|
||||
public CachingLlmProviderFactory(
|
||||
ILlmProviderFactory inner,
|
||||
ILlmInferenceCache cache,
|
||||
ILoggerFactory loggerFactory)
|
||||
{
|
||||
_inner = inner ?? throw new ArgumentNullException(nameof(inner));
|
||||
_cache = cache ?? throw new ArgumentNullException(nameof(cache));
|
||||
_loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
|
||||
}
|
||||
|
||||
public IReadOnlyList<string> AvailableProviders => _inner.AvailableProviders;
|
||||
|
||||
public ILlmProvider GetProvider(string providerId)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
if (_cachedProviders.TryGetValue(providerId, out var existing))
|
||||
{
|
||||
return existing;
|
||||
}
|
||||
|
||||
var inner = _inner.GetProvider(providerId);
|
||||
var caching = new CachingLlmProvider(
|
||||
inner,
|
||||
_cache,
|
||||
_loggerFactory.CreateLogger<CachingLlmProvider>());
|
||||
|
||||
_cachedProviders[providerId] = caching;
|
||||
return caching;
|
||||
}
|
||||
}
|
||||
|
||||
public ILlmProvider GetDefaultProvider()
|
||||
{
|
||||
var inner = _inner.GetDefaultProvider();
|
||||
return GetProvider(inner.ProviderId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,359 @@
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Plugin;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Factory for creating and managing LLM providers using the plugin architecture.
|
||||
/// Discovers plugins and loads configurations from YAML files.
|
||||
/// </summary>
|
||||
public sealed class PluginBasedLlmProviderFactory : ILlmProviderFactory, IDisposable
|
||||
{
|
||||
private readonly LlmProviderCatalog _catalog;
|
||||
private readonly IServiceProvider _serviceProvider;
|
||||
private readonly ILogger<PluginBasedLlmProviderFactory> _logger;
|
||||
private readonly Dictionary<string, ILlmProvider> _providers = new(StringComparer.OrdinalIgnoreCase);
|
||||
private readonly object _lock = new();
|
||||
private bool _disposed;
|
||||
|
||||
public PluginBasedLlmProviderFactory(
|
||||
LlmProviderCatalog catalog,
|
||||
IServiceProvider serviceProvider,
|
||||
ILogger<PluginBasedLlmProviderFactory> logger)
|
||||
{
|
||||
_catalog = catalog;
|
||||
_serviceProvider = serviceProvider;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public IReadOnlyList<string> AvailableProviders
|
||||
{
|
||||
get
|
||||
{
|
||||
var plugins = _catalog.GetAvailablePlugins(_serviceProvider);
|
||||
return plugins.Select(p => p.ProviderId).ToList();
|
||||
}
|
||||
}
|
||||
|
||||
public ILlmProvider GetProvider(string providerId)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
if (_providers.TryGetValue(providerId, out var existing))
|
||||
{
|
||||
return existing;
|
||||
}
|
||||
|
||||
var plugin = _catalog.GetPlugin(providerId);
|
||||
if (plugin is null)
|
||||
{
|
||||
throw new InvalidOperationException($"LLM provider plugin '{providerId}' not found. " +
|
||||
$"Available plugins: {string.Join(", ", _catalog.GetPlugins().Select(p => p.ProviderId))}");
|
||||
}
|
||||
|
||||
var config = _catalog.GetConfiguration(providerId);
|
||||
if (config is null)
|
||||
{
|
||||
throw new InvalidOperationException($"Configuration for LLM provider '{providerId}' not found. " +
|
||||
$"Ensure {plugin.DefaultConfigFileName} exists in the llm-providers directory.");
|
||||
}
|
||||
|
||||
var validation = plugin.ValidateConfiguration(config);
|
||||
if (!validation.IsValid)
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid configuration for LLM provider '{providerId}': " +
|
||||
string.Join("; ", validation.Errors));
|
||||
}
|
||||
|
||||
foreach (var warning in validation.Warnings)
|
||||
{
|
||||
_logger.LogWarning("LLM provider {ProviderId} config warning: {Warning}", providerId, warning);
|
||||
}
|
||||
|
||||
_logger.LogInformation("Creating LLM provider: {ProviderId} ({DisplayName})",
|
||||
providerId, plugin.DisplayName);
|
||||
|
||||
var provider = plugin.Create(_serviceProvider, config);
|
||||
_providers[providerId] = provider;
|
||||
return provider;
|
||||
}
|
||||
}
|
||||
|
||||
public ILlmProvider GetDefaultProvider()
|
||||
{
|
||||
var available = _catalog.GetAvailablePlugins(_serviceProvider);
|
||||
if (available.Count == 0)
|
||||
{
|
||||
throw new InvalidOperationException("No LLM providers are available. " +
|
||||
"Check that at least one provider is configured in the llm-providers directory.");
|
||||
}
|
||||
|
||||
// Return the first available provider (sorted by priority)
|
||||
var defaultPlugin = available[0];
|
||||
_logger.LogInformation("Using default LLM provider: {ProviderId}", defaultPlugin.ProviderId);
|
||||
return GetProvider(defaultPlugin.ProviderId);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
foreach (var provider in _providers.Values)
|
||||
{
|
||||
provider.Dispose();
|
||||
}
|
||||
_providers.Clear();
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering LLM provider services with plugin support.
|
||||
/// </summary>
|
||||
public static class LlmProviderPluginExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds LLM provider plugin services to the service collection.
|
||||
/// </summary>
|
||||
public static IServiceCollection AddLlmProviderPlugins(
|
||||
this IServiceCollection services,
|
||||
string configDirectory = "etc/llm-providers")
|
||||
{
|
||||
services.AddHttpClient();
|
||||
|
||||
// Create and configure the catalog
|
||||
services.AddSingleton(sp =>
|
||||
{
|
||||
var catalog = new LlmProviderCatalog();
|
||||
|
||||
// Register built-in plugins
|
||||
catalog.RegisterPlugin(new OpenAiLlmProviderPlugin());
|
||||
catalog.RegisterPlugin(new ClaudeLlmProviderPlugin());
|
||||
catalog.RegisterPlugin(new LlamaServerLlmProviderPlugin());
|
||||
catalog.RegisterPlugin(new OllamaLlmProviderPlugin());
|
||||
|
||||
// Load configurations from directory
|
||||
var fullPath = Path.GetFullPath(configDirectory);
|
||||
if (Directory.Exists(fullPath))
|
||||
{
|
||||
catalog.LoadConfigurationsFromDirectory(fullPath);
|
||||
}
|
||||
|
||||
return catalog;
|
||||
});
|
||||
|
||||
services.AddSingleton<ILlmProviderFactory, PluginBasedLlmProviderFactory>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds LLM provider plugin services with explicit configuration.
|
||||
/// </summary>
|
||||
public static IServiceCollection AddLlmProviderPlugins(
|
||||
this IServiceCollection services,
|
||||
Action<LlmProviderCatalog> configureCatalog)
|
||||
{
|
||||
services.AddHttpClient();
|
||||
|
||||
services.AddSingleton(sp =>
|
||||
{
|
||||
var catalog = new LlmProviderCatalog();
|
||||
|
||||
// Register built-in plugins
|
||||
catalog.RegisterPlugin(new OpenAiLlmProviderPlugin());
|
||||
catalog.RegisterPlugin(new ClaudeLlmProviderPlugin());
|
||||
catalog.RegisterPlugin(new LlamaServerLlmProviderPlugin());
|
||||
catalog.RegisterPlugin(new OllamaLlmProviderPlugin());
|
||||
|
||||
configureCatalog(catalog);
|
||||
|
||||
return catalog;
|
||||
});
|
||||
|
||||
services.AddSingleton<ILlmProviderFactory, PluginBasedLlmProviderFactory>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers a custom LLM provider plugin.
|
||||
/// </summary>
|
||||
public static LlmProviderCatalog RegisterCustomPlugin<TPlugin>(this LlmProviderCatalog catalog)
|
||||
where TPlugin : ILlmProviderPlugin, new()
|
||||
{
|
||||
return catalog.RegisterPlugin(new TPlugin());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registers configuration for a provider from an IConfiguration section.
|
||||
/// </summary>
|
||||
public static LlmProviderCatalog RegisterConfiguration(
|
||||
this LlmProviderCatalog catalog,
|
||||
string providerId,
|
||||
IConfigurationSection section)
|
||||
{
|
||||
return catalog.RegisterConfiguration(providerId, section);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Legacy LLM provider factory for backwards compatibility.
|
||||
/// Wraps the plugin-based factory.
|
||||
/// </summary>
|
||||
[Obsolete("Use PluginBasedLlmProviderFactory instead")]
|
||||
public sealed class LlmProviderFactory : ILlmProviderFactory, IDisposable
|
||||
{
|
||||
private readonly PluginBasedLlmProviderFactory _innerFactory;
|
||||
|
||||
public LlmProviderFactory(
|
||||
LlmProviderCatalog catalog,
|
||||
IServiceProvider serviceProvider,
|
||||
ILogger<PluginBasedLlmProviderFactory> logger)
|
||||
{
|
||||
_innerFactory = new PluginBasedLlmProviderFactory(catalog, serviceProvider, logger);
|
||||
}
|
||||
|
||||
public IReadOnlyList<string> AvailableProviders => _innerFactory.AvailableProviders;
|
||||
|
||||
public ILlmProvider GetProvider(string providerId) => _innerFactory.GetProvider(providerId);
|
||||
|
||||
public ILlmProvider GetDefaultProvider() => _innerFactory.GetDefaultProvider();
|
||||
|
||||
public void Dispose() => _innerFactory.Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// LLM provider with automatic fallback to alternative providers.
|
||||
/// </summary>
|
||||
public sealed class FallbackLlmProvider : ILlmProvider
|
||||
{
|
||||
private readonly ILlmProviderFactory _factory;
|
||||
private readonly IReadOnlyList<string> _providerOrder;
|
||||
private readonly ILogger<FallbackLlmProvider> _logger;
|
||||
|
||||
public string ProviderId => "fallback";
|
||||
|
||||
public FallbackLlmProvider(
|
||||
ILlmProviderFactory factory,
|
||||
IReadOnlyList<string> providerOrder,
|
||||
ILogger<FallbackLlmProvider> logger)
|
||||
{
|
||||
_factory = factory;
|
||||
_providerOrder = providerOrder;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var providerId in _providerOrder)
|
||||
{
|
||||
try
|
||||
{
|
||||
var provider = _factory.GetProvider(providerId);
|
||||
if (await provider.IsAvailableAsync(cancellationToken))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Continue to next provider
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public async Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
Exception? lastException = null;
|
||||
|
||||
foreach (var providerId in _providerOrder)
|
||||
{
|
||||
try
|
||||
{
|
||||
var provider = _factory.GetProvider(providerId);
|
||||
|
||||
if (!await provider.IsAvailableAsync(cancellationToken))
|
||||
{
|
||||
_logger.LogDebug("Provider {ProviderId} is not available, trying next", providerId);
|
||||
continue;
|
||||
}
|
||||
|
||||
_logger.LogDebug("Using provider {ProviderId} for completion", providerId);
|
||||
return await provider.CompleteAsync(request, cancellationToken);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Provider {ProviderId} failed, trying next", providerId);
|
||||
lastException = ex;
|
||||
}
|
||||
}
|
||||
|
||||
throw new InvalidOperationException(
|
||||
"All LLM providers failed. Check configuration and provider availability.",
|
||||
lastException);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return CompleteStreamAsyncCore(request, cancellationToken);
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsyncCore(
|
||||
LlmCompletionRequest request,
|
||||
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
{
|
||||
// Find the first available provider
|
||||
ILlmProvider? selectedProvider = null;
|
||||
Exception? lastException = null;
|
||||
|
||||
foreach (var providerId in _providerOrder)
|
||||
{
|
||||
try
|
||||
{
|
||||
var provider = _factory.GetProvider(providerId);
|
||||
|
||||
if (await provider.IsAvailableAsync(cancellationToken))
|
||||
{
|
||||
_logger.LogDebug("Using provider {ProviderId} for streaming completion", providerId);
|
||||
selectedProvider = provider;
|
||||
break;
|
||||
}
|
||||
|
||||
_logger.LogDebug("Provider {ProviderId} is not available for streaming, trying next", providerId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Provider {ProviderId} check failed, trying next", providerId);
|
||||
lastException = ex;
|
||||
}
|
||||
}
|
||||
|
||||
if (selectedProvider is null)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"No LLM provider available for streaming. Check configuration and provider availability.",
|
||||
lastException);
|
||||
}
|
||||
|
||||
// Stream from the selected provider
|
||||
await foreach (var chunk in selectedProvider.CompleteStreamAsync(request, cancellationToken))
|
||||
{
|
||||
yield return chunk;
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
// Factory manages provider disposal
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Configuration for LLM providers.
|
||||
/// </summary>
|
||||
public sealed class LlmProviderOptions
|
||||
{
|
||||
public const string SectionName = "AdvisoryAI:LlmProviders";
|
||||
|
||||
/// <summary>
|
||||
/// Default provider to use (openai, claude, llama-server, ollama).
|
||||
/// </summary>
|
||||
public string DefaultProvider { get; set; } = "openai";
|
||||
|
||||
/// <summary>
|
||||
/// Fallback providers in order of preference.
|
||||
/// </summary>
|
||||
public List<string> FallbackProviders { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// OpenAI configuration.
|
||||
/// </summary>
|
||||
public OpenAiProviderOptions OpenAi { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Claude/Anthropic configuration.
|
||||
/// </summary>
|
||||
public ClaudeProviderOptions Claude { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Llama.cpp server configuration.
|
||||
/// </summary>
|
||||
public LlamaServerProviderOptions LlamaServer { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Ollama configuration.
|
||||
/// </summary>
|
||||
public OllamaProviderOptions Ollama { get; set; } = new();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// OpenAI provider options.
|
||||
/// </summary>
|
||||
public sealed class OpenAiProviderOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// API key (or use OPENAI_API_KEY env var).
|
||||
/// </summary>
|
||||
public string? ApiKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Base URL (for Azure OpenAI or proxies).
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "https://api.openai.com/v1";
|
||||
|
||||
/// <summary>
|
||||
/// Default model.
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "gpt-4o";
|
||||
|
||||
/// <summary>
|
||||
/// Organization ID (optional).
|
||||
/// </summary>
|
||||
public string? OrganizationId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(120);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Claude/Anthropic provider options.
|
||||
/// </summary>
|
||||
public sealed class ClaudeProviderOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// API key (or use ANTHROPIC_API_KEY env var).
|
||||
/// </summary>
|
||||
public string? ApiKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Base URL.
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "https://api.anthropic.com";
|
||||
|
||||
/// <summary>
|
||||
/// Default model.
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "claude-sonnet-4-20250514";
|
||||
|
||||
/// <summary>
|
||||
/// API version.
|
||||
/// </summary>
|
||||
public string ApiVersion { get; set; } = "2023-06-01";
|
||||
|
||||
/// <summary>
|
||||
/// Request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(120);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Llama.cpp server provider options.
|
||||
/// </summary>
|
||||
public sealed class LlamaServerProviderOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Server URL (llama.cpp runs OpenAI-compatible endpoint).
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "http://localhost:8080";
|
||||
|
||||
/// <summary>
|
||||
/// Model name (for logging, actual model is loaded on server).
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "local-llama";
|
||||
|
||||
/// <summary>
|
||||
/// Request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(300);
|
||||
|
||||
/// <summary>
|
||||
/// API key if server requires auth.
|
||||
/// </summary>
|
||||
public string? ApiKey { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Ollama provider options.
|
||||
/// </summary>
|
||||
public sealed class OllamaProviderOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether enabled.
|
||||
/// </summary>
|
||||
public bool Enabled { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Ollama server URL.
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "http://localhost:11434";
|
||||
|
||||
/// <summary>
|
||||
/// Default model.
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "llama3:8b";
|
||||
|
||||
/// <summary>
|
||||
/// Request timeout.
|
||||
/// </summary>
|
||||
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(300);
|
||||
}
|
||||
@@ -0,0 +1,536 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// Ollama provider configuration (maps to ollama.yaml).
|
||||
/// </summary>
|
||||
public sealed class OllamaConfig : LlmProviderConfigBase
|
||||
{
|
||||
/// <summary>
|
||||
/// Server base URL.
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "http://localhost:11434";
|
||||
|
||||
/// <summary>
|
||||
/// Health check endpoint.
|
||||
/// </summary>
|
||||
public string HealthEndpoint { get; set; } = "/api/tags";
|
||||
|
||||
/// <summary>
|
||||
/// Model name.
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "llama3:8b";
|
||||
|
||||
/// <summary>
|
||||
/// Fallback models.
|
||||
/// </summary>
|
||||
public List<string> FallbackModels { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Keep model loaded in memory.
|
||||
/// </summary>
|
||||
public string KeepAlive { get; set; } = "5m";
|
||||
|
||||
/// <summary>
|
||||
/// Top-p sampling.
|
||||
/// </summary>
|
||||
public double TopP { get; set; } = 1.0;
|
||||
|
||||
/// <summary>
|
||||
/// Top-k sampling.
|
||||
/// </summary>
|
||||
public int TopK { get; set; } = 40;
|
||||
|
||||
/// <summary>
|
||||
/// Repeat penalty.
|
||||
/// </summary>
|
||||
public double RepeatPenalty { get; set; } = 1.1;
|
||||
|
||||
/// <summary>
|
||||
/// Context length.
|
||||
/// </summary>
|
||||
public int NumCtx { get; set; } = 4096;
|
||||
|
||||
/// <summary>
|
||||
/// Number of tokens to predict.
|
||||
/// </summary>
|
||||
public int NumPredict { get; set; } = -1;
|
||||
|
||||
/// <summary>
|
||||
/// Number of GPU layers.
|
||||
/// </summary>
|
||||
public int NumGpu { get; set; } = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Auto-pull model if not found.
|
||||
/// </summary>
|
||||
public bool AutoPull { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Verify model after pull.
|
||||
/// </summary>
|
||||
public bool VerifyPull { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Log token usage.
|
||||
/// </summary>
|
||||
public bool LogUsage { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Bind configuration from IConfiguration.
|
||||
/// </summary>
|
||||
public static OllamaConfig FromConfiguration(IConfiguration config)
|
||||
{
|
||||
var result = new OllamaConfig();
|
||||
|
||||
// Provider section
|
||||
result.Enabled = config.GetValue("enabled", true);
|
||||
result.Priority = config.GetValue("priority", 20);
|
||||
|
||||
// Server section
|
||||
var server = config.GetSection("server");
|
||||
result.BaseUrl = server.GetValue("baseUrl", "http://localhost:11434")!;
|
||||
result.HealthEndpoint = server.GetValue("healthEndpoint", "/api/tags")!;
|
||||
|
||||
// Model section
|
||||
var model = config.GetSection("model");
|
||||
result.Model = model.GetValue("name", "llama3:8b")!;
|
||||
result.FallbackModels = model.GetSection("fallbacks").Get<List<string>>() ?? new();
|
||||
result.KeepAlive = model.GetValue("keepAlive", "5m")!;
|
||||
|
||||
// Inference section
|
||||
var inference = config.GetSection("inference");
|
||||
result.Temperature = inference.GetValue("temperature", 0.0);
|
||||
result.MaxTokens = inference.GetValue("maxTokens", 4096);
|
||||
result.Seed = inference.GetValue<int?>("seed") ?? 42;
|
||||
result.TopP = inference.GetValue("topP", 1.0);
|
||||
result.TopK = inference.GetValue("topK", 40);
|
||||
result.RepeatPenalty = inference.GetValue("repeatPenalty", 1.1);
|
||||
result.NumCtx = inference.GetValue("numCtx", 4096);
|
||||
result.NumPredict = inference.GetValue("numPredict", -1);
|
||||
|
||||
// Request section
|
||||
var request = config.GetSection("request");
|
||||
result.Timeout = request.GetValue("timeout", TimeSpan.FromMinutes(5));
|
||||
result.MaxRetries = request.GetValue("maxRetries", 2);
|
||||
|
||||
// GPU section
|
||||
var gpu = config.GetSection("gpu");
|
||||
result.NumGpu = gpu.GetValue("numGpu", 0);
|
||||
|
||||
// Management section
|
||||
var management = config.GetSection("management");
|
||||
result.AutoPull = management.GetValue("autoPull", false);
|
||||
result.VerifyPull = management.GetValue("verifyPull", true);
|
||||
|
||||
// Logging section
|
||||
var logging = config.GetSection("logging");
|
||||
result.LogUsage = logging.GetValue("logUsage", true);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Ollama LLM provider plugin.
|
||||
/// </summary>
|
||||
public sealed class OllamaLlmProviderPlugin : ILlmProviderPlugin
|
||||
{
|
||||
public string Name => "Ollama LLM Provider";
|
||||
public string ProviderId => "ollama";
|
||||
public string DisplayName => "Ollama";
|
||||
public string Description => "Local LLM inference via Ollama";
|
||||
public string DefaultConfigFileName => "ollama.yaml";
|
||||
|
||||
public bool IsAvailable(IServiceProvider services)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public ILlmProvider Create(IServiceProvider services, IConfiguration configuration)
|
||||
{
|
||||
var config = OllamaConfig.FromConfiguration(configuration);
|
||||
var httpClientFactory = services.GetRequiredService<IHttpClientFactory>();
|
||||
var loggerFactory = services.GetRequiredService<ILoggerFactory>();
|
||||
|
||||
return new OllamaLlmProvider(
|
||||
httpClientFactory.CreateClient("Ollama"),
|
||||
config,
|
||||
loggerFactory.CreateLogger<OllamaLlmProvider>());
|
||||
}
|
||||
|
||||
public LlmProviderConfigValidation ValidateConfiguration(IConfiguration configuration)
|
||||
{
|
||||
var errors = new List<string>();
|
||||
var warnings = new List<string>();
|
||||
|
||||
var config = OllamaConfig.FromConfiguration(configuration);
|
||||
|
||||
if (!config.Enabled)
|
||||
{
|
||||
return LlmProviderConfigValidation.WithWarnings("Provider is disabled");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(config.BaseUrl))
|
||||
{
|
||||
errors.Add("Server base URL is required.");
|
||||
}
|
||||
else if (!Uri.TryCreate(config.BaseUrl, UriKind.Absolute, out _))
|
||||
{
|
||||
errors.Add($"Invalid server URL: {config.BaseUrl}");
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(config.Model))
|
||||
{
|
||||
warnings.Add("No model specified, will use default 'llama3:8b'.");
|
||||
}
|
||||
|
||||
if (errors.Count > 0)
|
||||
{
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = false,
|
||||
Errors = errors,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = true,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Ollama LLM provider implementation.
|
||||
/// </summary>
|
||||
public sealed class OllamaLlmProvider : ILlmProvider
|
||||
{
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly OllamaConfig _config;
|
||||
private readonly ILogger<OllamaLlmProvider> _logger;
|
||||
private bool _disposed;
|
||||
|
||||
public string ProviderId => "ollama";
|
||||
|
||||
public OllamaLlmProvider(
|
||||
HttpClient httpClient,
|
||||
OllamaConfig config,
|
||||
ILogger<OllamaLlmProvider> logger)
|
||||
{
|
||||
_httpClient = httpClient;
|
||||
_config = config;
|
||||
_logger = logger;
|
||||
|
||||
ConfigureHttpClient();
|
||||
}
|
||||
|
||||
private void ConfigureHttpClient()
|
||||
{
|
||||
_httpClient.BaseAddress = new Uri(_config.BaseUrl.TrimEnd('/') + "/");
|
||||
_httpClient.Timeout = _config.Timeout;
|
||||
}
|
||||
|
||||
public async Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_config.Enabled)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var response = await _httpClient.GetAsync(_config.HealthEndpoint.TrimStart('/'), cancellationToken);
|
||||
return response.IsSuccessStatusCode;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "Ollama availability check failed at {BaseUrl}", _config.BaseUrl);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
var seed = request.Seed ?? _config.Seed ?? 42;
|
||||
|
||||
var ollamaRequest = new OllamaChatRequest
|
||||
{
|
||||
Model = model,
|
||||
Messages = BuildMessages(request),
|
||||
Stream = false,
|
||||
Options = new OllamaOptions
|
||||
{
|
||||
Temperature = temperature,
|
||||
NumPredict = maxTokens,
|
||||
Seed = seed,
|
||||
TopP = _config.TopP,
|
||||
TopK = _config.TopK,
|
||||
RepeatPenalty = _config.RepeatPenalty,
|
||||
NumCtx = _config.NumCtx,
|
||||
NumGpu = _config.NumGpu,
|
||||
Stop = request.StopSequences?.ToArray()
|
||||
}
|
||||
};
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(
|
||||
"api/chat",
|
||||
ollamaRequest,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var ollamaResponse = await response.Content.ReadFromJsonAsync<OllamaChatResponse>(cancellationToken);
|
||||
stopwatch.Stop();
|
||||
|
||||
if (ollamaResponse is null)
|
||||
{
|
||||
throw new InvalidOperationException("No response from Ollama");
|
||||
}
|
||||
|
||||
if (_config.LogUsage)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"Ollama usage - Model: {Model}, Input: {InputTokens}, Output: {OutputTokens}, Time: {TimeMs}ms",
|
||||
model,
|
||||
ollamaResponse.PromptEvalCount,
|
||||
ollamaResponse.EvalCount,
|
||||
stopwatch.ElapsedMilliseconds);
|
||||
}
|
||||
|
||||
return new LlmCompletionResult
|
||||
{
|
||||
Content = ollamaResponse.Message?.Content ?? string.Empty,
|
||||
ModelId = ollamaResponse.Model ?? model,
|
||||
ProviderId = ProviderId,
|
||||
InputTokens = ollamaResponse.PromptEvalCount,
|
||||
OutputTokens = ollamaResponse.EvalCount,
|
||||
TotalTimeMs = stopwatch.ElapsedMilliseconds,
|
||||
TimeToFirstTokenMs = ollamaResponse.PromptEvalDuration.HasValue
|
||||
? ollamaResponse.PromptEvalDuration.Value / 1_000_000
|
||||
: null,
|
||||
FinishReason = ollamaResponse.Done == true ? "stop" : null,
|
||||
Deterministic = temperature == 0,
|
||||
RequestId = request.RequestId
|
||||
};
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
var seed = request.Seed ?? _config.Seed ?? 42;
|
||||
|
||||
var ollamaRequest = new OllamaChatRequest
|
||||
{
|
||||
Model = model,
|
||||
Messages = BuildMessages(request),
|
||||
Stream = true,
|
||||
Options = new OllamaOptions
|
||||
{
|
||||
Temperature = temperature,
|
||||
NumPredict = maxTokens,
|
||||
Seed = seed,
|
||||
TopP = _config.TopP,
|
||||
TopK = _config.TopK,
|
||||
RepeatPenalty = _config.RepeatPenalty,
|
||||
NumCtx = _config.NumCtx,
|
||||
NumGpu = _config.NumGpu,
|
||||
Stop = request.StopSequences?.ToArray()
|
||||
}
|
||||
};
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "api/chat")
|
||||
{
|
||||
Content = new StringContent(
|
||||
JsonSerializer.Serialize(ollamaRequest),
|
||||
Encoding.UTF8,
|
||||
"application/json")
|
||||
};
|
||||
|
||||
var response = await _httpClient.SendAsync(
|
||||
httpRequest,
|
||||
HttpCompletionOption.ResponseHeadersRead,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
|
||||
using var reader = new StreamReader(stream);
|
||||
|
||||
string? line;
|
||||
while ((line = await reader.ReadLineAsync(cancellationToken)) is not null)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
|
||||
if (string.IsNullOrEmpty(line))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
OllamaChatResponse? chunk;
|
||||
try
|
||||
{
|
||||
chunk = JsonSerializer.Deserialize<OllamaChatResponse>(line);
|
||||
}
|
||||
catch
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk is null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var content = chunk.Message?.Content ?? string.Empty;
|
||||
var isFinal = chunk.Done == true;
|
||||
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = content,
|
||||
IsFinal = isFinal,
|
||||
FinishReason = isFinal ? "stop" : null
|
||||
};
|
||||
|
||||
if (isFinal)
|
||||
{
|
||||
yield break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static List<OllamaMessage> BuildMessages(LlmCompletionRequest request)
|
||||
{
|
||||
var messages = new List<OllamaMessage>();
|
||||
|
||||
if (!string.IsNullOrEmpty(request.SystemPrompt))
|
||||
{
|
||||
messages.Add(new OllamaMessage { Role = "system", Content = request.SystemPrompt });
|
||||
}
|
||||
|
||||
messages.Add(new OllamaMessage { Role = "user", Content = request.UserPrompt });
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
_httpClient.Dispose();
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ollama API models
|
||||
internal sealed class OllamaChatRequest
|
||||
{
|
||||
[JsonPropertyName("model")]
|
||||
public required string Model { get; set; }
|
||||
|
||||
[JsonPropertyName("messages")]
|
||||
public required List<OllamaMessage> Messages { get; set; }
|
||||
|
||||
[JsonPropertyName("stream")]
|
||||
public bool Stream { get; set; }
|
||||
|
||||
[JsonPropertyName("options")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public OllamaOptions? Options { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OllamaMessage
|
||||
{
|
||||
[JsonPropertyName("role")]
|
||||
public required string Role { get; set; }
|
||||
|
||||
[JsonPropertyName("content")]
|
||||
public required string Content { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OllamaOptions
|
||||
{
|
||||
[JsonPropertyName("temperature")]
|
||||
public double Temperature { get; set; }
|
||||
|
||||
[JsonPropertyName("num_predict")]
|
||||
public int NumPredict { get; set; }
|
||||
|
||||
[JsonPropertyName("seed")]
|
||||
public int Seed { get; set; }
|
||||
|
||||
[JsonPropertyName("top_p")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double TopP { get; set; }
|
||||
|
||||
[JsonPropertyName("top_k")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public int TopK { get; set; }
|
||||
|
||||
[JsonPropertyName("repeat_penalty")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double RepeatPenalty { get; set; }
|
||||
|
||||
[JsonPropertyName("num_ctx")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public int NumCtx { get; set; }
|
||||
|
||||
[JsonPropertyName("num_gpu")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public int NumGpu { get; set; }
|
||||
|
||||
[JsonPropertyName("stop")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public string[]? Stop { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OllamaChatResponse
|
||||
{
|
||||
[JsonPropertyName("model")]
|
||||
public string? Model { get; set; }
|
||||
|
||||
[JsonPropertyName("message")]
|
||||
public OllamaMessage? Message { get; set; }
|
||||
|
||||
[JsonPropertyName("done")]
|
||||
public bool? Done { get; set; }
|
||||
|
||||
[JsonPropertyName("total_duration")]
|
||||
public long? TotalDuration { get; set; }
|
||||
|
||||
[JsonPropertyName("load_duration")]
|
||||
public long? LoadDuration { get; set; }
|
||||
|
||||
[JsonPropertyName("prompt_eval_count")]
|
||||
public int? PromptEvalCount { get; set; }
|
||||
|
||||
[JsonPropertyName("prompt_eval_duration")]
|
||||
public long? PromptEvalDuration { get; set; }
|
||||
|
||||
[JsonPropertyName("eval_count")]
|
||||
public int? EvalCount { get; set; }
|
||||
|
||||
[JsonPropertyName("eval_duration")]
|
||||
public long? EvalDuration { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,590 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
/// <summary>
|
||||
/// OpenAI LLM provider configuration (maps to openai.yaml).
|
||||
/// </summary>
|
||||
public sealed class OpenAiConfig : LlmProviderConfigBase
|
||||
{
|
||||
/// <summary>
|
||||
/// API key (or use OPENAI_API_KEY env var).
|
||||
/// </summary>
|
||||
public string? ApiKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Base URL for API requests.
|
||||
/// </summary>
|
||||
public string BaseUrl { get; set; } = "https://api.openai.com/v1";
|
||||
|
||||
/// <summary>
|
||||
/// Model name.
|
||||
/// </summary>
|
||||
public string Model { get; set; } = "gpt-4o";
|
||||
|
||||
/// <summary>
|
||||
/// Fallback models.
|
||||
/// </summary>
|
||||
public List<string> FallbackModels { get; set; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Organization ID (optional).
|
||||
/// </summary>
|
||||
public string? OrganizationId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// API version (for Azure OpenAI).
|
||||
/// </summary>
|
||||
public string? ApiVersion { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Top-p sampling.
|
||||
/// </summary>
|
||||
public double TopP { get; set; } = 1.0;
|
||||
|
||||
/// <summary>
|
||||
/// Frequency penalty.
|
||||
/// </summary>
|
||||
public double FrequencyPenalty { get; set; } = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Presence penalty.
|
||||
/// </summary>
|
||||
public double PresencePenalty { get; set; } = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Log request/response bodies.
|
||||
/// </summary>
|
||||
public bool LogBodies { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Log token usage.
|
||||
/// </summary>
|
||||
public bool LogUsage { get; set; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Bind configuration from IConfiguration.
|
||||
/// </summary>
|
||||
public static OpenAiConfig FromConfiguration(IConfiguration config)
|
||||
{
|
||||
var result = new OpenAiConfig();
|
||||
|
||||
// Provider section
|
||||
result.Enabled = config.GetValue("enabled", true);
|
||||
result.Priority = config.GetValue("priority", 100);
|
||||
|
||||
// API section
|
||||
var api = config.GetSection("api");
|
||||
result.ApiKey = ExpandEnvVar(api.GetValue<string>("apiKey"));
|
||||
result.BaseUrl = api.GetValue("baseUrl", "https://api.openai.com/v1")!;
|
||||
result.OrganizationId = api.GetValue<string>("organizationId");
|
||||
result.ApiVersion = api.GetValue<string>("apiVersion");
|
||||
|
||||
// Model section
|
||||
var model = config.GetSection("model");
|
||||
result.Model = model.GetValue("name", "gpt-4o")!;
|
||||
result.FallbackModels = model.GetSection("fallbacks").Get<List<string>>() ?? new();
|
||||
|
||||
// Inference section
|
||||
var inference = config.GetSection("inference");
|
||||
result.Temperature = inference.GetValue("temperature", 0.0);
|
||||
result.MaxTokens = inference.GetValue("maxTokens", 4096);
|
||||
result.Seed = inference.GetValue<int?>("seed");
|
||||
result.TopP = inference.GetValue("topP", 1.0);
|
||||
result.FrequencyPenalty = inference.GetValue("frequencyPenalty", 0.0);
|
||||
result.PresencePenalty = inference.GetValue("presencePenalty", 0.0);
|
||||
|
||||
// Request section
|
||||
var request = config.GetSection("request");
|
||||
result.Timeout = request.GetValue("timeout", TimeSpan.FromSeconds(120));
|
||||
result.MaxRetries = request.GetValue("maxRetries", 3);
|
||||
|
||||
// Logging section
|
||||
var logging = config.GetSection("logging");
|
||||
result.LogBodies = logging.GetValue("logBodies", false);
|
||||
result.LogUsage = logging.GetValue("logUsage", true);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static string? ExpandEnvVar(string? value)
|
||||
{
|
||||
if (string.IsNullOrEmpty(value))
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
// Expand ${VAR_NAME} pattern
|
||||
if (value.StartsWith("${") && value.EndsWith("}"))
|
||||
{
|
||||
var varName = value.Substring(2, value.Length - 3);
|
||||
return Environment.GetEnvironmentVariable(varName);
|
||||
}
|
||||
|
||||
return Environment.ExpandEnvironmentVariables(value);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// OpenAI LLM provider plugin.
|
||||
/// </summary>
|
||||
public sealed class OpenAiLlmProviderPlugin : ILlmProviderPlugin
|
||||
{
|
||||
public string Name => "OpenAI LLM Provider";
|
||||
public string ProviderId => "openai";
|
||||
public string DisplayName => "OpenAI";
|
||||
public string Description => "OpenAI GPT models via API (supports Azure OpenAI)";
|
||||
public string DefaultConfigFileName => "openai.yaml";
|
||||
|
||||
public bool IsAvailable(IServiceProvider services)
|
||||
{
|
||||
// Plugin is always available if the assembly is loaded
|
||||
return true;
|
||||
}
|
||||
|
||||
public ILlmProvider Create(IServiceProvider services, IConfiguration configuration)
|
||||
{
|
||||
var config = OpenAiConfig.FromConfiguration(configuration);
|
||||
var httpClientFactory = services.GetRequiredService<IHttpClientFactory>();
|
||||
var loggerFactory = services.GetRequiredService<ILoggerFactory>();
|
||||
|
||||
return new OpenAiLlmProvider(
|
||||
httpClientFactory.CreateClient("OpenAI"),
|
||||
config,
|
||||
loggerFactory.CreateLogger<OpenAiLlmProvider>());
|
||||
}
|
||||
|
||||
public LlmProviderConfigValidation ValidateConfiguration(IConfiguration configuration)
|
||||
{
|
||||
var errors = new List<string>();
|
||||
var warnings = new List<string>();
|
||||
|
||||
var config = OpenAiConfig.FromConfiguration(configuration);
|
||||
|
||||
if (!config.Enabled)
|
||||
{
|
||||
return LlmProviderConfigValidation.WithWarnings("Provider is disabled");
|
||||
}
|
||||
|
||||
// Check API key
|
||||
var apiKey = config.ApiKey ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
|
||||
if (string.IsNullOrEmpty(apiKey))
|
||||
{
|
||||
errors.Add("API key not configured. Set 'api.apiKey' or OPENAI_API_KEY environment variable.");
|
||||
}
|
||||
|
||||
// Check base URL
|
||||
if (string.IsNullOrEmpty(config.BaseUrl))
|
||||
{
|
||||
errors.Add("Base URL is required.");
|
||||
}
|
||||
else if (!Uri.TryCreate(config.BaseUrl, UriKind.Absolute, out _))
|
||||
{
|
||||
errors.Add($"Invalid base URL: {config.BaseUrl}");
|
||||
}
|
||||
|
||||
// Check model
|
||||
if (string.IsNullOrEmpty(config.Model))
|
||||
{
|
||||
warnings.Add("No model specified, will use default 'gpt-4o'.");
|
||||
}
|
||||
|
||||
if (errors.Count > 0)
|
||||
{
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = false,
|
||||
Errors = errors,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
|
||||
return new LlmProviderConfigValidation
|
||||
{
|
||||
IsValid = true,
|
||||
Warnings = warnings
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// OpenAI LLM provider implementation.
|
||||
/// </summary>
|
||||
public sealed class OpenAiLlmProvider : ILlmProvider
|
||||
{
|
||||
private readonly HttpClient _httpClient;
|
||||
private readonly OpenAiConfig _config;
|
||||
private readonly ILogger<OpenAiLlmProvider> _logger;
|
||||
private bool _disposed;
|
||||
|
||||
public string ProviderId => "openai";
|
||||
|
||||
public OpenAiLlmProvider(
|
||||
HttpClient httpClient,
|
||||
OpenAiConfig config,
|
||||
ILogger<OpenAiLlmProvider> logger)
|
||||
{
|
||||
_httpClient = httpClient;
|
||||
_config = config;
|
||||
_logger = logger;
|
||||
|
||||
ConfigureHttpClient();
|
||||
}
|
||||
|
||||
private void ConfigureHttpClient()
|
||||
{
|
||||
_httpClient.BaseAddress = new Uri(_config.BaseUrl.TrimEnd('/') + "/");
|
||||
_httpClient.Timeout = _config.Timeout;
|
||||
|
||||
var apiKey = _config.ApiKey ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
|
||||
if (!string.IsNullOrEmpty(apiKey))
|
||||
{
|
||||
_httpClient.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", apiKey);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(_config.OrganizationId))
|
||||
{
|
||||
_httpClient.DefaultRequestHeaders.Add("OpenAI-Organization", _config.OrganizationId);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(_config.ApiVersion))
|
||||
{
|
||||
_httpClient.DefaultRequestHeaders.Add("api-version", _config.ApiVersion);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_config.Enabled)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var response = await _httpClient.GetAsync("models", cancellationToken);
|
||||
return response.IsSuccessStatusCode;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogDebug(ex, "OpenAI availability check failed");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
var seed = request.Seed ?? _config.Seed;
|
||||
|
||||
var openAiRequest = new OpenAiChatRequest
|
||||
{
|
||||
Model = model,
|
||||
Messages = BuildMessages(request),
|
||||
Temperature = temperature,
|
||||
MaxTokens = maxTokens,
|
||||
Seed = seed,
|
||||
TopP = _config.TopP,
|
||||
FrequencyPenalty = _config.FrequencyPenalty,
|
||||
PresencePenalty = _config.PresencePenalty,
|
||||
Stop = request.StopSequences?.ToArray()
|
||||
};
|
||||
|
||||
if (_config.LogBodies)
|
||||
{
|
||||
_logger.LogDebug("OpenAI request: {Request}", JsonSerializer.Serialize(openAiRequest));
|
||||
}
|
||||
|
||||
var response = await _httpClient.PostAsJsonAsync(
|
||||
"chat/completions",
|
||||
openAiRequest,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var openAiResponse = await response.Content.ReadFromJsonAsync<OpenAiChatResponse>(cancellationToken);
|
||||
stopwatch.Stop();
|
||||
|
||||
if (openAiResponse?.Choices is null || openAiResponse.Choices.Count == 0)
|
||||
{
|
||||
throw new InvalidOperationException("No completion returned from OpenAI");
|
||||
}
|
||||
|
||||
var choice = openAiResponse.Choices[0];
|
||||
|
||||
if (_config.LogUsage && openAiResponse.Usage is not null)
|
||||
{
|
||||
_logger.LogInformation(
|
||||
"OpenAI usage - Model: {Model}, Input: {InputTokens}, Output: {OutputTokens}, Total: {TotalTokens}",
|
||||
openAiResponse.Model,
|
||||
openAiResponse.Usage.PromptTokens,
|
||||
openAiResponse.Usage.CompletionTokens,
|
||||
openAiResponse.Usage.TotalTokens);
|
||||
}
|
||||
|
||||
return new LlmCompletionResult
|
||||
{
|
||||
Content = choice.Message?.Content ?? string.Empty,
|
||||
ModelId = openAiResponse.Model ?? model,
|
||||
ProviderId = ProviderId,
|
||||
InputTokens = openAiResponse.Usage?.PromptTokens,
|
||||
OutputTokens = openAiResponse.Usage?.CompletionTokens,
|
||||
TotalTimeMs = stopwatch.ElapsedMilliseconds,
|
||||
FinishReason = choice.FinishReason,
|
||||
Deterministic = temperature == 0 && seed.HasValue,
|
||||
RequestId = request.RequestId ?? openAiResponse.Id
|
||||
};
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var model = request.Model ?? _config.Model;
|
||||
var temperature = request.Temperature > 0 ? request.Temperature : _config.Temperature;
|
||||
var maxTokens = request.MaxTokens > 0 ? request.MaxTokens : _config.MaxTokens;
|
||||
var seed = request.Seed ?? _config.Seed;
|
||||
|
||||
var openAiRequest = new OpenAiChatRequest
|
||||
{
|
||||
Model = model,
|
||||
Messages = BuildMessages(request),
|
||||
Temperature = temperature,
|
||||
MaxTokens = maxTokens,
|
||||
Seed = seed,
|
||||
TopP = _config.TopP,
|
||||
FrequencyPenalty = _config.FrequencyPenalty,
|
||||
PresencePenalty = _config.PresencePenalty,
|
||||
Stop = request.StopSequences?.ToArray(),
|
||||
Stream = true
|
||||
};
|
||||
|
||||
var httpRequest = new HttpRequestMessage(HttpMethod.Post, "chat/completions")
|
||||
{
|
||||
Content = new StringContent(
|
||||
JsonSerializer.Serialize(openAiRequest),
|
||||
Encoding.UTF8,
|
||||
"application/json")
|
||||
};
|
||||
|
||||
var response = await _httpClient.SendAsync(
|
||||
httpRequest,
|
||||
HttpCompletionOption.ResponseHeadersRead,
|
||||
cancellationToken);
|
||||
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
|
||||
using var reader = new StreamReader(stream);
|
||||
|
||||
string? line;
|
||||
while ((line = await reader.ReadLineAsync(cancellationToken)) is not null)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
|
||||
if (string.IsNullOrEmpty(line))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!line.StartsWith("data: "))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var data = line.Substring(6);
|
||||
if (data == "[DONE]")
|
||||
{
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = string.Empty,
|
||||
IsFinal = true,
|
||||
FinishReason = "stop"
|
||||
};
|
||||
yield break;
|
||||
}
|
||||
|
||||
OpenAiStreamResponse? chunk;
|
||||
try
|
||||
{
|
||||
chunk = JsonSerializer.Deserialize<OpenAiStreamResponse>(data);
|
||||
}
|
||||
catch
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk?.Choices is null || chunk.Choices.Count == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var choice = chunk.Choices[0];
|
||||
var content = choice.Delta?.Content ?? string.Empty;
|
||||
var isFinal = choice.FinishReason != null;
|
||||
|
||||
yield return new LlmStreamChunk
|
||||
{
|
||||
Content = content,
|
||||
IsFinal = isFinal,
|
||||
FinishReason = choice.FinishReason
|
||||
};
|
||||
|
||||
if (isFinal)
|
||||
{
|
||||
yield break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static List<OpenAiMessage> BuildMessages(LlmCompletionRequest request)
|
||||
{
|
||||
var messages = new List<OpenAiMessage>();
|
||||
|
||||
if (!string.IsNullOrEmpty(request.SystemPrompt))
|
||||
{
|
||||
messages.Add(new OpenAiMessage { Role = "system", Content = request.SystemPrompt });
|
||||
}
|
||||
|
||||
messages.Add(new OpenAiMessage { Role = "user", Content = request.UserPrompt });
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
_httpClient.Dispose();
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI API models
|
||||
internal sealed class OpenAiChatRequest
|
||||
{
|
||||
[JsonPropertyName("model")]
|
||||
public required string Model { get; set; }
|
||||
|
||||
[JsonPropertyName("messages")]
|
||||
public required List<OpenAiMessage> Messages { get; set; }
|
||||
|
||||
[JsonPropertyName("temperature")]
|
||||
public double Temperature { get; set; }
|
||||
|
||||
[JsonPropertyName("max_tokens")]
|
||||
public int MaxTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("seed")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public int? Seed { get; set; }
|
||||
|
||||
[JsonPropertyName("top_p")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double TopP { get; set; }
|
||||
|
||||
[JsonPropertyName("frequency_penalty")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double FrequencyPenalty { get; set; }
|
||||
|
||||
[JsonPropertyName("presence_penalty")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public double PresencePenalty { get; set; }
|
||||
|
||||
[JsonPropertyName("stop")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||
public string[]? Stop { get; set; }
|
||||
|
||||
[JsonPropertyName("stream")]
|
||||
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
|
||||
public bool Stream { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiMessage
|
||||
{
|
||||
[JsonPropertyName("role")]
|
||||
public required string Role { get; set; }
|
||||
|
||||
[JsonPropertyName("content")]
|
||||
public required string Content { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiChatResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("model")]
|
||||
public string? Model { get; set; }
|
||||
|
||||
[JsonPropertyName("choices")]
|
||||
public List<OpenAiChoice>? Choices { get; set; }
|
||||
|
||||
[JsonPropertyName("usage")]
|
||||
public OpenAiUsage? Usage { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiChoice
|
||||
{
|
||||
[JsonPropertyName("index")]
|
||||
public int Index { get; set; }
|
||||
|
||||
[JsonPropertyName("message")]
|
||||
public OpenAiMessage? Message { get; set; }
|
||||
|
||||
[JsonPropertyName("finish_reason")]
|
||||
public string? FinishReason { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiUsage
|
||||
{
|
||||
[JsonPropertyName("prompt_tokens")]
|
||||
public int PromptTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("completion_tokens")]
|
||||
public int CompletionTokens { get; set; }
|
||||
|
||||
[JsonPropertyName("total_tokens")]
|
||||
public int TotalTokens { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiStreamResponse
|
||||
{
|
||||
[JsonPropertyName("id")]
|
||||
public string? Id { get; set; }
|
||||
|
||||
[JsonPropertyName("choices")]
|
||||
public List<OpenAiStreamChoice>? Choices { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiStreamChoice
|
||||
{
|
||||
[JsonPropertyName("index")]
|
||||
public int Index { get; set; }
|
||||
|
||||
[JsonPropertyName("delta")]
|
||||
public OpenAiDelta? Delta { get; set; }
|
||||
|
||||
[JsonPropertyName("finish_reason")]
|
||||
public string? FinishReason { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class OpenAiDelta
|
||||
{
|
||||
[JsonPropertyName("content")]
|
||||
public string? Content { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,233 @@
|
||||
using System.Collections.Immutable;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.AdvisoryAI.Guardrails;
|
||||
using StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Prompting;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference;
|
||||
|
||||
/// <summary>
|
||||
/// Advisory inference client that uses LLM providers directly.
|
||||
/// Supports OpenAI, Claude, Llama.cpp server, and Ollama.
|
||||
/// This unblocks OFFLINE-07 by enabling local inference via HTTP to llama.cpp server.
|
||||
/// </summary>
|
||||
public sealed class ProviderBasedAdvisoryInferenceClient : IAdvisoryInferenceClient
|
||||
{
|
||||
private readonly ILlmProviderFactory _providerFactory;
|
||||
private readonly IOptions<LlmProviderOptions> _providerOptions;
|
||||
private readonly IOptions<AdvisoryAiInferenceOptions> _inferenceOptions;
|
||||
private readonly ILogger<ProviderBasedAdvisoryInferenceClient> _logger;
|
||||
|
||||
public ProviderBasedAdvisoryInferenceClient(
|
||||
ILlmProviderFactory providerFactory,
|
||||
IOptions<LlmProviderOptions> providerOptions,
|
||||
IOptions<AdvisoryAiInferenceOptions> inferenceOptions,
|
||||
ILogger<ProviderBasedAdvisoryInferenceClient> logger)
|
||||
{
|
||||
_providerFactory = providerFactory;
|
||||
_providerOptions = providerOptions;
|
||||
_inferenceOptions = inferenceOptions;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task<AdvisoryInferenceResult> GenerateAsync(
|
||||
AdvisoryTaskPlan plan,
|
||||
AdvisoryPrompt prompt,
|
||||
AdvisoryGuardrailResult guardrailResult,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(plan);
|
||||
ArgumentNullException.ThrowIfNull(prompt);
|
||||
ArgumentNullException.ThrowIfNull(guardrailResult);
|
||||
|
||||
var sanitized = guardrailResult.SanitizedPrompt ?? prompt.Prompt ?? string.Empty;
|
||||
var systemPrompt = BuildSystemPrompt(plan, prompt);
|
||||
|
||||
// Try providers in order: default, then fallbacks
|
||||
var providerOrder = GetProviderOrder();
|
||||
Exception? lastException = null;
|
||||
|
||||
foreach (var providerId in providerOrder)
|
||||
{
|
||||
try
|
||||
{
|
||||
var provider = _providerFactory.GetProvider(providerId);
|
||||
|
||||
if (!await provider.IsAvailableAsync(cancellationToken))
|
||||
{
|
||||
_logger.LogDebug("Provider {ProviderId} is not available, trying next", providerId);
|
||||
continue;
|
||||
}
|
||||
|
||||
_logger.LogInformation("Using LLM provider {ProviderId} for task {TaskType}",
|
||||
providerId, plan.Request.TaskType);
|
||||
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
SystemPrompt = systemPrompt,
|
||||
UserPrompt = sanitized,
|
||||
Temperature = 0, // Deterministic
|
||||
MaxTokens = 4096,
|
||||
Seed = 42, // Fixed seed for reproducibility
|
||||
RequestId = plan.CacheKey
|
||||
};
|
||||
|
||||
var result = await provider.CompleteAsync(request, cancellationToken);
|
||||
|
||||
return ToAdvisoryResult(result, prompt.Metadata);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Provider {ProviderId} failed, trying next", providerId);
|
||||
lastException = ex;
|
||||
}
|
||||
}
|
||||
|
||||
// All providers failed - return fallback
|
||||
_logger.LogError(lastException, "All LLM providers failed for task {TaskType}. Returning sanitized prompt.",
|
||||
plan.Request.TaskType);
|
||||
|
||||
return AdvisoryInferenceResult.FromFallback(
|
||||
sanitized,
|
||||
"all_providers_failed",
|
||||
lastException?.Message);
|
||||
}
|
||||
|
||||
private IEnumerable<string> GetProviderOrder()
|
||||
{
|
||||
var opts = _providerOptions.Value;
|
||||
|
||||
yield return opts.DefaultProvider;
|
||||
|
||||
foreach (var fallback in opts.FallbackProviders)
|
||||
{
|
||||
if (!string.Equals(fallback, opts.DefaultProvider, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
yield return fallback;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static string BuildSystemPrompt(AdvisoryTaskPlan plan, AdvisoryPrompt prompt)
|
||||
{
|
||||
var taskType = plan.Request.TaskType.ToString();
|
||||
var profile = plan.Request.Profile;
|
||||
|
||||
var builder = new System.Text.StringBuilder();
|
||||
builder.AppendLine("You are a security advisory analyst assistant.");
|
||||
builder.AppendLine($"Task type: {taskType}");
|
||||
builder.AppendLine($"Profile: {profile}");
|
||||
builder.AppendLine();
|
||||
builder.AppendLine("Guidelines:");
|
||||
builder.AppendLine("- Provide accurate, evidence-based analysis");
|
||||
builder.AppendLine("- Use [EVIDENCE:id] format for citations when referencing source documents");
|
||||
builder.AppendLine("- Follow the 3-line doctrine: What, Why, Next Action");
|
||||
builder.AppendLine("- Be concise and actionable");
|
||||
|
||||
if (prompt.Citations.Length > 0)
|
||||
{
|
||||
builder.AppendLine();
|
||||
builder.AppendLine("Available evidence citations:");
|
||||
foreach (var citation in prompt.Citations)
|
||||
{
|
||||
builder.AppendLine($"- [EVIDENCE:{citation.Index}] Document: {citation.DocumentId}, Chunk: {citation.ChunkId}");
|
||||
}
|
||||
}
|
||||
|
||||
return builder.ToString();
|
||||
}
|
||||
|
||||
private static AdvisoryInferenceResult ToAdvisoryResult(
|
||||
LlmCompletionResult result,
|
||||
ImmutableDictionary<string, string> promptMetadata)
|
||||
{
|
||||
var metadataBuilder = ImmutableDictionary.CreateBuilder<string, string>(StringComparer.Ordinal);
|
||||
|
||||
// Copy prompt metadata
|
||||
foreach (var kvp in promptMetadata)
|
||||
{
|
||||
metadataBuilder[kvp.Key] = kvp.Value;
|
||||
}
|
||||
|
||||
// Add inference metadata
|
||||
metadataBuilder["inference.provider"] = result.ProviderId;
|
||||
metadataBuilder["inference.model"] = result.ModelId;
|
||||
metadataBuilder["inference.deterministic"] = result.Deterministic.ToString().ToLowerInvariant();
|
||||
|
||||
if (result.TotalTimeMs.HasValue)
|
||||
{
|
||||
metadataBuilder["inference.total_time_ms"] = result.TotalTimeMs.Value.ToString();
|
||||
}
|
||||
|
||||
if (result.TimeToFirstTokenMs.HasValue)
|
||||
{
|
||||
metadataBuilder["inference.ttft_ms"] = result.TimeToFirstTokenMs.Value.ToString();
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(result.FinishReason))
|
||||
{
|
||||
metadataBuilder["inference.finish_reason"] = result.FinishReason;
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(result.RequestId))
|
||||
{
|
||||
metadataBuilder["inference.request_id"] = result.RequestId;
|
||||
}
|
||||
|
||||
return new AdvisoryInferenceResult(
|
||||
result.Content,
|
||||
result.ModelId,
|
||||
result.InputTokens,
|
||||
result.OutputTokens,
|
||||
metadataBuilder.ToImmutable());
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for registering LLM provider services.
|
||||
/// </summary>
|
||||
public static class LlmProviderServiceExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds LLM provider services to the service collection.
|
||||
/// </summary>
|
||||
public static IServiceCollection AddLlmProviders(
|
||||
this IServiceCollection services,
|
||||
Action<LlmProviderOptions>? configure = null)
|
||||
{
|
||||
services.AddHttpClient();
|
||||
|
||||
if (configure is not null)
|
||||
{
|
||||
services.Configure(configure);
|
||||
}
|
||||
|
||||
services.AddSingleton<ILlmProviderFactory, LlmProviderFactory>();
|
||||
services.AddScoped<IAdvisoryInferenceClient, ProviderBasedAdvisoryInferenceClient>();
|
||||
services.AddScoped<FallbackLlmProvider>();
|
||||
|
||||
return services;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds LLM provider services with configuration from IConfiguration.
|
||||
/// </summary>
|
||||
public static IServiceCollection AddLlmProviders(
|
||||
this IServiceCollection services,
|
||||
IConfiguration configuration)
|
||||
{
|
||||
services.AddHttpClient();
|
||||
services.Configure<LlmProviderOptions>(
|
||||
configuration.GetSection(LlmProviderOptions.SectionName));
|
||||
|
||||
services.AddSingleton<ILlmProviderFactory, LlmProviderFactory>();
|
||||
services.AddScoped<IAdvisoryInferenceClient, ProviderBasedAdvisoryInferenceClient>();
|
||||
services.AddScoped<FallbackLlmProvider>();
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Inference;
|
||||
|
||||
/// <summary>
|
||||
/// Manages signed model bundles with cryptographic verification.
|
||||
/// Sprint: SPRINT_20251226_019_AI_offline_inference
|
||||
/// Task: OFFLINE-15, OFFLINE-16
|
||||
/// </summary>
|
||||
public interface ISignedModelBundleManager
|
||||
{
|
||||
/// <summary>
|
||||
/// Sign a model bundle using the specified signer.
|
||||
/// </summary>
|
||||
Task<SigningResult> SignBundleAsync(
|
||||
string bundlePath,
|
||||
IModelBundleSigner signer,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Verify a signed model bundle.
|
||||
/// </summary>
|
||||
Task<SignatureVerificationResult> VerifySignatureAsync(
|
||||
string bundlePath,
|
||||
IModelBundleVerifier verifier,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Load a bundle with signature verification at load time.
|
||||
/// </summary>
|
||||
Task<ModelLoadResult> LoadWithVerificationAsync(
|
||||
string bundlePath,
|
||||
IModelBundleVerifier verifier,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Signer interface for model bundles.
|
||||
/// </summary>
|
||||
public interface IModelBundleSigner
|
||||
{
|
||||
/// <summary>
|
||||
/// Key ID of the signer.
|
||||
/// </summary>
|
||||
string KeyId { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Crypto scheme (e.g., "ed25519", "ecdsa-p256", "gost3410").
|
||||
/// </summary>
|
||||
string CryptoScheme { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Sign the manifest digest.
|
||||
/// </summary>
|
||||
Task<byte[]> SignAsync(byte[] data, CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verifier interface for model bundles.
|
||||
/// </summary>
|
||||
public interface IModelBundleVerifier
|
||||
{
|
||||
/// <summary>
|
||||
/// Verify a signature.
|
||||
/// </summary>
|
||||
Task<bool> VerifyAsync(
|
||||
byte[] data,
|
||||
byte[] signature,
|
||||
string keyId,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of signing a bundle.
|
||||
/// </summary>
|
||||
public sealed record SigningResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required string SignatureId { get; init; }
|
||||
public required string CryptoScheme { get; init; }
|
||||
public required string ManifestDigest { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of signature verification.
|
||||
/// </summary>
|
||||
public sealed record SignatureVerificationResult
|
||||
{
|
||||
public required bool Valid { get; init; }
|
||||
public required string SignatureId { get; init; }
|
||||
public required string CryptoScheme { get; init; }
|
||||
public required string KeyId { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of loading a model.
|
||||
/// </summary>
|
||||
public sealed record ModelLoadResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required string BundlePath { get; init; }
|
||||
public required bool SignatureVerified { get; init; }
|
||||
public required ModelBundleManifest? Manifest { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// DSSE envelope for model bundle signatures.
|
||||
/// </summary>
|
||||
public sealed record ModelBundleSignatureEnvelope
|
||||
{
|
||||
public required string PayloadType { get; init; }
|
||||
public required string Payload { get; init; }
|
||||
public required IReadOnlyList<ModelBundleSignature> Signatures { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A signature in the envelope.
|
||||
/// </summary>
|
||||
public sealed record ModelBundleSignature
|
||||
{
|
||||
public required string KeyId { get; init; }
|
||||
public required string Sig { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of signed model bundle manager.
|
||||
/// </summary>
|
||||
public sealed class SignedModelBundleManager : ISignedModelBundleManager
|
||||
{
|
||||
private const string SignatureFileName = "signature.dsse";
|
||||
private const string ManifestFileName = "manifest.json";
|
||||
private const string PayloadType = "application/vnd.stellaops.model-bundle+json";
|
||||
|
||||
private static readonly JsonSerializerOptions JsonOptions = new()
|
||||
{
|
||||
WriteIndented = true,
|
||||
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower
|
||||
};
|
||||
|
||||
public async Task<SigningResult> SignBundleAsync(
|
||||
string bundlePath,
|
||||
IModelBundleSigner signer,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
try
|
||||
{
|
||||
var manifestPath = Path.Combine(bundlePath, ManifestFileName);
|
||||
if (!File.Exists(manifestPath))
|
||||
{
|
||||
return new SigningResult
|
||||
{
|
||||
Success = false,
|
||||
SignatureId = string.Empty,
|
||||
CryptoScheme = signer.CryptoScheme,
|
||||
ManifestDigest = string.Empty,
|
||||
ErrorMessage = "Manifest not found"
|
||||
};
|
||||
}
|
||||
|
||||
// Read and hash the manifest
|
||||
var manifestBytes = await File.ReadAllBytesAsync(manifestPath, cancellationToken);
|
||||
var manifestDigest = ComputeSha256(manifestBytes);
|
||||
|
||||
// Create the payload (manifest digest + metadata)
|
||||
var payload = new
|
||||
{
|
||||
manifest_digest = manifestDigest,
|
||||
signed_at = DateTime.UtcNow.ToString("o"),
|
||||
bundle_path = Path.GetFileName(bundlePath)
|
||||
};
|
||||
var payloadJson = JsonSerializer.Serialize(payload, JsonOptions);
|
||||
var payloadBytes = Encoding.UTF8.GetBytes(payloadJson);
|
||||
var payloadBase64 = Convert.ToBase64String(payloadBytes);
|
||||
|
||||
// Sign the PAE (Pre-Authentication Encoding)
|
||||
var pae = CreatePae(PayloadType, payloadBytes);
|
||||
var signature = await signer.SignAsync(pae, cancellationToken);
|
||||
var signatureBase64 = Convert.ToBase64String(signature);
|
||||
|
||||
var signatureId = $"{signer.CryptoScheme}-{DateTime.UtcNow:yyyyMMddHHmmss}-{manifestDigest[..8]}";
|
||||
|
||||
// Create DSSE envelope
|
||||
var envelope = new ModelBundleSignatureEnvelope
|
||||
{
|
||||
PayloadType = PayloadType,
|
||||
Payload = payloadBase64,
|
||||
Signatures = new[]
|
||||
{
|
||||
new ModelBundleSignature
|
||||
{
|
||||
KeyId = signer.KeyId,
|
||||
Sig = signatureBase64
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Write envelope
|
||||
var envelopePath = Path.Combine(bundlePath, SignatureFileName);
|
||||
var envelopeJson = JsonSerializer.Serialize(envelope, JsonOptions);
|
||||
await File.WriteAllTextAsync(envelopePath, envelopeJson, cancellationToken);
|
||||
|
||||
// Update manifest with signature info
|
||||
var manifest = await File.ReadAllTextAsync(manifestPath, cancellationToken);
|
||||
var manifestObj = JsonSerializer.Deserialize<Dictionary<string, object>>(manifest);
|
||||
if (manifestObj != null)
|
||||
{
|
||||
manifestObj["signature_id"] = signatureId;
|
||||
manifestObj["crypto_scheme"] = signer.CryptoScheme;
|
||||
var updatedManifest = JsonSerializer.Serialize(manifestObj, JsonOptions);
|
||||
await File.WriteAllTextAsync(manifestPath, updatedManifest, cancellationToken);
|
||||
}
|
||||
|
||||
return new SigningResult
|
||||
{
|
||||
Success = true,
|
||||
SignatureId = signatureId,
|
||||
CryptoScheme = signer.CryptoScheme,
|
||||
ManifestDigest = manifestDigest
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return new SigningResult
|
||||
{
|
||||
Success = false,
|
||||
SignatureId = string.Empty,
|
||||
CryptoScheme = signer.CryptoScheme,
|
||||
ManifestDigest = string.Empty,
|
||||
ErrorMessage = ex.Message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<SignatureVerificationResult> VerifySignatureAsync(
|
||||
string bundlePath,
|
||||
IModelBundleVerifier verifier,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var signaturePath = Path.Combine(bundlePath, SignatureFileName);
|
||||
if (!File.Exists(signaturePath))
|
||||
{
|
||||
return new SignatureVerificationResult
|
||||
{
|
||||
Valid = false,
|
||||
SignatureId = string.Empty,
|
||||
CryptoScheme = string.Empty,
|
||||
KeyId = string.Empty,
|
||||
ErrorMessage = "No signature file found"
|
||||
};
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var envelopeJson = await File.ReadAllTextAsync(signaturePath, cancellationToken);
|
||||
var envelope = JsonSerializer.Deserialize<ModelBundleSignatureEnvelope>(envelopeJson);
|
||||
|
||||
if (envelope?.Signatures == null || envelope.Signatures.Count == 0)
|
||||
{
|
||||
return new SignatureVerificationResult
|
||||
{
|
||||
Valid = false,
|
||||
SignatureId = string.Empty,
|
||||
CryptoScheme = string.Empty,
|
||||
KeyId = string.Empty,
|
||||
ErrorMessage = "No signatures in envelope"
|
||||
};
|
||||
}
|
||||
|
||||
var sig = envelope.Signatures[0];
|
||||
var payloadBytes = Convert.FromBase64String(envelope.Payload);
|
||||
var signatureBytes = Convert.FromBase64String(sig.Sig);
|
||||
|
||||
// Recreate PAE and verify
|
||||
var pae = CreatePae(envelope.PayloadType, payloadBytes);
|
||||
var valid = await verifier.VerifyAsync(pae, signatureBytes, sig.KeyId, cancellationToken);
|
||||
|
||||
// Extract signature ID from manifest
|
||||
var manifestPath = Path.Combine(bundlePath, ManifestFileName);
|
||||
var manifest = await File.ReadAllTextAsync(manifestPath, cancellationToken);
|
||||
var manifestObj = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(manifest);
|
||||
var signatureId = manifestObj?.TryGetValue("signature_id", out var sigId) == true
|
||||
? sigId.GetString() ?? string.Empty
|
||||
: string.Empty;
|
||||
var cryptoScheme = manifestObj?.TryGetValue("crypto_scheme", out var scheme) == true
|
||||
? scheme.GetString() ?? string.Empty
|
||||
: string.Empty;
|
||||
|
||||
return new SignatureVerificationResult
|
||||
{
|
||||
Valid = valid,
|
||||
SignatureId = signatureId,
|
||||
CryptoScheme = cryptoScheme,
|
||||
KeyId = sig.KeyId,
|
||||
ErrorMessage = valid ? null : "Signature verification failed"
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return new SignatureVerificationResult
|
||||
{
|
||||
Valid = false,
|
||||
SignatureId = string.Empty,
|
||||
CryptoScheme = string.Empty,
|
||||
KeyId = string.Empty,
|
||||
ErrorMessage = ex.Message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<ModelLoadResult> LoadWithVerificationAsync(
|
||||
string bundlePath,
|
||||
IModelBundleVerifier verifier,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var manifestPath = Path.Combine(bundlePath, ManifestFileName);
|
||||
if (!File.Exists(manifestPath))
|
||||
{
|
||||
return new ModelLoadResult
|
||||
{
|
||||
Success = false,
|
||||
BundlePath = bundlePath,
|
||||
SignatureVerified = false,
|
||||
Manifest = null,
|
||||
ErrorMessage = "Manifest not found"
|
||||
};
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
// Verify signature first
|
||||
var sigResult = await VerifySignatureAsync(bundlePath, verifier, cancellationToken);
|
||||
|
||||
// Load manifest
|
||||
var manifestJson = await File.ReadAllTextAsync(manifestPath, cancellationToken);
|
||||
var manifest = JsonSerializer.Deserialize<ModelBundleManifest>(manifestJson);
|
||||
|
||||
return new ModelLoadResult
|
||||
{
|
||||
Success = true,
|
||||
BundlePath = bundlePath,
|
||||
SignatureVerified = sigResult.Valid,
|
||||
Manifest = manifest,
|
||||
ErrorMessage = sigResult.Valid ? null : sigResult.ErrorMessage
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return new ModelLoadResult
|
||||
{
|
||||
Success = false,
|
||||
BundlePath = bundlePath,
|
||||
SignatureVerified = false,
|
||||
Manifest = null,
|
||||
ErrorMessage = ex.Message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static string ComputeSha256(byte[] data)
|
||||
{
|
||||
var hash = SHA256.HashData(data);
|
||||
return Convert.ToHexStringLower(hash);
|
||||
}
|
||||
|
||||
private static byte[] CreatePae(string payloadType, byte[] payload)
|
||||
{
|
||||
// Pre-Authentication Encoding per DSSE spec
|
||||
// PAE = "DSSEv1" + SP + LEN(payloadType) + SP + payloadType + SP + LEN(payload) + SP + payload
|
||||
var parts = new List<byte>();
|
||||
parts.AddRange(Encoding.UTF8.GetBytes("DSSEv1 "));
|
||||
parts.AddRange(Encoding.UTF8.GetBytes(payloadType.Length.ToString()));
|
||||
parts.AddRange(Encoding.UTF8.GetBytes(" "));
|
||||
parts.AddRange(Encoding.UTF8.GetBytes(payloadType));
|
||||
parts.AddRange(Encoding.UTF8.GetBytes(" "));
|
||||
parts.AddRange(Encoding.UTF8.GetBytes(payload.Length.ToString()));
|
||||
parts.AddRange(Encoding.UTF8.GetBytes(" "));
|
||||
parts.AddRange(payload);
|
||||
return parts.ToArray();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,769 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using StellaOps.Policy.TrustLattice;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.PolicyStudio;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for compiling AI-generated rules into versioned, signed policy bundles.
|
||||
/// Sprint: SPRINT_20251226_017_AI_policy_copilot
|
||||
/// Task: POLICY-13
|
||||
/// </summary>
|
||||
public interface IPolicyBundleCompiler
|
||||
{
|
||||
/// <summary>
|
||||
/// Compiles lattice rules into a policy bundle.
|
||||
/// </summary>
|
||||
Task<PolicyCompilationResult> CompileAsync(
|
||||
PolicyCompilationRequest request,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Validates a compiled policy bundle.
|
||||
/// </summary>
|
||||
Task<PolicyValidationReport> ValidateAsync(
|
||||
PolicyBundle bundle,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Signs a compiled policy bundle.
|
||||
/// </summary>
|
||||
Task<SignedPolicyBundle> SignAsync(
|
||||
PolicyBundle bundle,
|
||||
PolicySigningOptions options,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Request to compile rules into a policy bundle.
|
||||
/// </summary>
|
||||
public sealed record PolicyCompilationRequest
|
||||
{
|
||||
/// <summary>
|
||||
/// Rules to compile.
|
||||
/// </summary>
|
||||
public required IReadOnlyList<LatticeRule> Rules { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Test cases to include.
|
||||
/// </summary>
|
||||
public IReadOnlyList<PolicyTestCase>? TestCases { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Policy bundle name.
|
||||
/// </summary>
|
||||
public required string Name { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Policy version.
|
||||
/// </summary>
|
||||
public string Version { get; init; } = "1.0.0";
|
||||
|
||||
/// <summary>
|
||||
/// Target policy pack ID (if extending existing).
|
||||
/// </summary>
|
||||
public string? TargetPolicyPack { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Trust roots to include.
|
||||
/// </summary>
|
||||
public IReadOnlyList<TrustRoot>? TrustRoots { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Trust requirements.
|
||||
/// </summary>
|
||||
public TrustRequirements? TrustRequirements { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Whether to validate before compilation.
|
||||
/// </summary>
|
||||
public bool ValidateBeforeCompile { get; init; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Whether to run test cases.
|
||||
/// </summary>
|
||||
public bool RunTests { get; init; } = true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of policy compilation.
|
||||
/// </summary>
|
||||
public sealed record PolicyCompilationResult
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether compilation was successful.
|
||||
/// </summary>
|
||||
public required bool Success { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Compiled policy bundle.
|
||||
/// </summary>
|
||||
public PolicyBundle? Bundle { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Compilation errors.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string> Errors { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Compilation warnings.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string> Warnings { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Validation report.
|
||||
/// </summary>
|
||||
public PolicyValidationReport? ValidationReport { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Test run report.
|
||||
/// </summary>
|
||||
public PolicyTestReport? TestReport { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Compilation timestamp (UTC ISO-8601).
|
||||
/// </summary>
|
||||
public required string CompiledAt { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Bundle digest.
|
||||
/// </summary>
|
||||
public string? BundleDigest { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Validation report for a policy bundle.
|
||||
/// </summary>
|
||||
public sealed record PolicyValidationReport
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether validation passed.
|
||||
/// </summary>
|
||||
public required bool Valid { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Syntax valid.
|
||||
/// </summary>
|
||||
public bool SyntaxValid { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Semantics valid.
|
||||
/// </summary>
|
||||
public bool SemanticsValid { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Syntax errors.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string> SyntaxErrors { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Semantic warnings.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string> SemanticWarnings { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Rule conflicts detected.
|
||||
/// </summary>
|
||||
public IReadOnlyList<RuleConflict> Conflicts { get; init; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Coverage estimate (0.0 - 1.0).
|
||||
/// </summary>
|
||||
public double Coverage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Test report for a policy bundle.
|
||||
/// </summary>
|
||||
public sealed record PolicyTestReport
|
||||
{
|
||||
/// <summary>
|
||||
/// Total tests run.
|
||||
/// </summary>
|
||||
public int TotalTests { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Tests passed.
|
||||
/// </summary>
|
||||
public int Passed { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Tests failed.
|
||||
/// </summary>
|
||||
public int Failed { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Pass rate (0.0 - 1.0).
|
||||
/// </summary>
|
||||
public double PassRate => TotalTests > 0 ? (double)Passed / TotalTests : 0;
|
||||
|
||||
/// <summary>
|
||||
/// Failure details.
|
||||
/// </summary>
|
||||
public IReadOnlyList<TestFailure> Failures { get; init; } = [];
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Test failure detail.
|
||||
/// </summary>
|
||||
public sealed record TestFailure
|
||||
{
|
||||
public required string TestId { get; init; }
|
||||
public required string RuleId { get; init; }
|
||||
public required string Description { get; init; }
|
||||
public required string Expected { get; init; }
|
||||
public required string Actual { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Options for signing a policy bundle.
|
||||
/// </summary>
|
||||
public sealed record PolicySigningOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Key ID to use for signing.
|
||||
/// </summary>
|
||||
public string? KeyId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Crypto scheme (eidas, fips, gost, sm).
|
||||
/// </summary>
|
||||
public string? CryptoScheme { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Signer identity.
|
||||
/// </summary>
|
||||
public string? SignerIdentity { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Include timestamp.
|
||||
/// </summary>
|
||||
public bool IncludeTimestamp { get; init; } = true;
|
||||
|
||||
/// <summary>
|
||||
/// Timestamping authority URL.
|
||||
/// </summary>
|
||||
public string? TimestampAuthority { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Signed policy bundle.
|
||||
/// </summary>
|
||||
public sealed record SignedPolicyBundle
|
||||
{
|
||||
/// <summary>
|
||||
/// The policy bundle.
|
||||
/// </summary>
|
||||
public required PolicyBundle Bundle { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Bundle content hash.
|
||||
/// </summary>
|
||||
public required string ContentDigest { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Signature bytes (base64).
|
||||
/// </summary>
|
||||
public required string Signature { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Signing algorithm used.
|
||||
/// </summary>
|
||||
public required string Algorithm { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Key ID used for signing.
|
||||
/// </summary>
|
||||
public string? KeyId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Signer identity.
|
||||
/// </summary>
|
||||
public string? SignerIdentity { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Signature timestamp (UTC ISO-8601).
|
||||
/// </summary>
|
||||
public string? SignedAt { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Timestamp token (if requested).
|
||||
/// </summary>
|
||||
public string? TimestampToken { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Certificate chain (PEM).
|
||||
/// </summary>
|
||||
public string? CertificateChain { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Compiles AI-generated rules into versioned, signed policy bundles.
|
||||
/// Sprint: SPRINT_20251226_017_AI_policy_copilot
|
||||
/// Task: POLICY-13
|
||||
/// </summary>
|
||||
public sealed class PolicyBundleCompiler : IPolicyBundleCompiler
|
||||
{
|
||||
private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web)
|
||||
{
|
||||
WriteIndented = false,
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
|
||||
};
|
||||
|
||||
private readonly IPolicyRuleGenerator _ruleGenerator;
|
||||
private readonly IPolicyBundleSigner? _signer;
|
||||
private readonly ILogger<PolicyBundleCompiler> _logger;
|
||||
|
||||
public PolicyBundleCompiler(
|
||||
IPolicyRuleGenerator ruleGenerator,
|
||||
IPolicyBundleSigner? signer,
|
||||
ILogger<PolicyBundleCompiler> logger)
|
||||
{
|
||||
_ruleGenerator = ruleGenerator ?? throw new ArgumentNullException(nameof(ruleGenerator));
|
||||
_signer = signer;
|
||||
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
|
||||
}
|
||||
|
||||
public async Task<PolicyCompilationResult> CompileAsync(
|
||||
PolicyCompilationRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
_logger.LogInformation("Compiling policy bundle '{Name}' with {RuleCount} rules",
|
||||
request.Name, request.Rules.Count);
|
||||
|
||||
var errors = new List<string>();
|
||||
var warnings = new List<string>();
|
||||
PolicyValidationReport? validationReport = null;
|
||||
PolicyTestReport? testReport = null;
|
||||
|
||||
// Step 1: Validate rules if requested
|
||||
if (request.ValidateBeforeCompile)
|
||||
{
|
||||
var validationResult = await _ruleGenerator.ValidateAsync(
|
||||
request.Rules, null, cancellationToken);
|
||||
|
||||
validationReport = new PolicyValidationReport
|
||||
{
|
||||
Valid = validationResult.Valid,
|
||||
SyntaxValid = validationResult.Valid,
|
||||
SemanticsValid = validationResult.Conflicts.Count == 0,
|
||||
Conflicts = validationResult.Conflicts,
|
||||
SemanticWarnings = validationResult.UnreachableConditions.Concat(validationResult.PotentialLoops).ToList(),
|
||||
Coverage = validationResult.Coverage
|
||||
};
|
||||
|
||||
if (!validationResult.Valid)
|
||||
{
|
||||
errors.AddRange(validationResult.Conflicts.Select(c =>
|
||||
$"Rule conflict: {c.Description}"));
|
||||
errors.AddRange(validationResult.UnreachableConditions);
|
||||
errors.AddRange(validationResult.PotentialLoops);
|
||||
}
|
||||
|
||||
warnings.AddRange(validationResult.UnreachableConditions);
|
||||
}
|
||||
|
||||
// Step 2: Run tests if requested
|
||||
if (request.RunTests && request.TestCases?.Count > 0)
|
||||
{
|
||||
testReport = RunTests(request.Rules, request.TestCases);
|
||||
|
||||
if (testReport.Failed > 0)
|
||||
{
|
||||
warnings.Add($"{testReport.Failed} of {testReport.TotalTests} tests failed");
|
||||
}
|
||||
}
|
||||
|
||||
// Check for blocking errors
|
||||
if (errors.Count > 0)
|
||||
{
|
||||
return new PolicyCompilationResult
|
||||
{
|
||||
Success = false,
|
||||
Errors = errors,
|
||||
Warnings = warnings,
|
||||
ValidationReport = validationReport,
|
||||
TestReport = testReport,
|
||||
CompiledAt = DateTime.UtcNow.ToString("O")
|
||||
};
|
||||
}
|
||||
|
||||
// Step 3: Build the policy bundle
|
||||
var bundle = BuildBundle(request);
|
||||
|
||||
// Step 4: Compute bundle digest
|
||||
var bundleDigest = ComputeBundleDigest(bundle);
|
||||
|
||||
_logger.LogInformation("Compiled policy bundle '{Name}' v{Version} with digest {Digest}",
|
||||
bundle.Name, bundle.Version, bundleDigest);
|
||||
|
||||
return new PolicyCompilationResult
|
||||
{
|
||||
Success = true,
|
||||
Bundle = bundle,
|
||||
Errors = errors,
|
||||
Warnings = warnings,
|
||||
ValidationReport = validationReport,
|
||||
TestReport = testReport,
|
||||
CompiledAt = DateTime.UtcNow.ToString("O"),
|
||||
BundleDigest = bundleDigest
|
||||
};
|
||||
}
|
||||
|
||||
public Task<PolicyValidationReport> ValidateAsync(
|
||||
PolicyBundle bundle,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var syntaxErrors = new List<string>();
|
||||
var semanticWarnings = new List<string>();
|
||||
var conflicts = new List<RuleConflict>();
|
||||
|
||||
// Validate trust roots
|
||||
foreach (var root in bundle.TrustRoots)
|
||||
{
|
||||
if (root.ExpiresAt.HasValue && root.ExpiresAt.Value < DateTimeOffset.UtcNow)
|
||||
{
|
||||
semanticWarnings.Add($"Trust root '{root.Principal.Id}' has expired");
|
||||
}
|
||||
}
|
||||
|
||||
// Validate custom rules
|
||||
foreach (var rule in bundle.CustomRules)
|
||||
{
|
||||
if (string.IsNullOrEmpty(rule.Name))
|
||||
{
|
||||
syntaxErrors.Add($"Rule is missing a name");
|
||||
}
|
||||
}
|
||||
|
||||
// Check for rule conflicts
|
||||
var rules = bundle.CustomRules.ToList();
|
||||
for (int i = 0; i < rules.Count; i++)
|
||||
{
|
||||
for (int j = i + 1; j < rules.Count; j++)
|
||||
{
|
||||
// Simple overlap check based on atom patterns
|
||||
if (HasOverlappingAtoms(rules[i], rules[j]))
|
||||
{
|
||||
conflicts.Add(new RuleConflict
|
||||
{
|
||||
RuleId1 = rules[i].Name,
|
||||
RuleId2 = rules[j].Name,
|
||||
Description = "Rules may have overlapping conditions",
|
||||
SuggestedResolution = "Review rule priorities",
|
||||
Severity = "warning"
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Task.FromResult(new PolicyValidationReport
|
||||
{
|
||||
Valid = syntaxErrors.Count == 0,
|
||||
SyntaxValid = syntaxErrors.Count == 0,
|
||||
SemanticsValid = conflicts.Count == 0,
|
||||
SyntaxErrors = syntaxErrors,
|
||||
SemanticWarnings = semanticWarnings,
|
||||
Conflicts = conflicts,
|
||||
Coverage = EstimateCoverage(bundle)
|
||||
});
|
||||
}
|
||||
|
||||
public async Task<SignedPolicyBundle> SignAsync(
|
||||
PolicyBundle bundle,
|
||||
PolicySigningOptions options,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var contentDigest = ComputeBundleDigest(bundle);
|
||||
|
||||
if (_signer is null)
|
||||
{
|
||||
_logger.LogWarning("No signer configured, returning unsigned bundle");
|
||||
return new SignedPolicyBundle
|
||||
{
|
||||
Bundle = bundle,
|
||||
ContentDigest = contentDigest,
|
||||
Signature = string.Empty,
|
||||
Algorithm = "none",
|
||||
SignedAt = DateTime.UtcNow.ToString("O")
|
||||
};
|
||||
}
|
||||
|
||||
var signature = await _signer.SignAsync(contentDigest, options, cancellationToken);
|
||||
|
||||
_logger.LogInformation("Signed policy bundle '{Name}' with key {KeyId}",
|
||||
bundle.Name, options.KeyId);
|
||||
|
||||
return new SignedPolicyBundle
|
||||
{
|
||||
Bundle = bundle,
|
||||
ContentDigest = contentDigest,
|
||||
Signature = signature.SignatureBase64,
|
||||
Algorithm = signature.Algorithm,
|
||||
KeyId = options.KeyId,
|
||||
SignerIdentity = options.SignerIdentity,
|
||||
SignedAt = DateTime.UtcNow.ToString("O"),
|
||||
CertificateChain = signature.CertificateChain
|
||||
};
|
||||
}
|
||||
|
||||
private PolicyBundle BuildBundle(PolicyCompilationRequest request)
|
||||
{
|
||||
// Convert LatticeRules to SelectionRules
|
||||
var customRules = request.Rules.Select(ConvertToSelectionRule).ToList();
|
||||
|
||||
return new PolicyBundle
|
||||
{
|
||||
Id = $"bundle:{ComputeHash(request.Name)[..12]}",
|
||||
Name = request.Name,
|
||||
Version = request.Version,
|
||||
TrustRoots = request.TrustRoots ?? [],
|
||||
TrustRequirements = request.TrustRequirements ?? new TrustRequirements(),
|
||||
CustomRules = customRules,
|
||||
ConflictResolution = ConflictResolution.ReportConflict,
|
||||
AssumeReachableWhenUnknown = true
|
||||
};
|
||||
}
|
||||
|
||||
private static SelectionRule ConvertToSelectionRule(LatticeRule rule)
|
||||
{
|
||||
// Map disposition string to Disposition enum
|
||||
var disposition = rule.Disposition.ToLowerInvariant() switch
|
||||
{
|
||||
"block" or "exploitable" => Disposition.Exploitable,
|
||||
"allow" or "resolved" => Disposition.Resolved,
|
||||
"resolved_with_pedigree" => Disposition.ResolvedWithPedigree,
|
||||
"not_affected" => Disposition.NotAffected,
|
||||
"false_positive" => Disposition.FalsePositive,
|
||||
"warn" or "in_triage" or _ => Disposition.InTriage
|
||||
};
|
||||
|
||||
// Build condition function from lattice expression
|
||||
var condition = BuildConditionFromExpression(rule.LatticeExpression);
|
||||
|
||||
return new SelectionRule
|
||||
{
|
||||
Name = rule.Name,
|
||||
Priority = rule.Priority,
|
||||
Disposition = disposition,
|
||||
ConditionDescription = rule.LatticeExpression,
|
||||
Condition = condition,
|
||||
ExplanationTemplate = rule.Description
|
||||
};
|
||||
}
|
||||
|
||||
private static Func<IReadOnlyDictionary<SecurityAtom, K4Value>, bool> BuildConditionFromExpression(string latticeExpression)
|
||||
{
|
||||
// Parse lattice expression and build condition function
|
||||
// This is a simplified parser - production would use proper expression parsing
|
||||
var expr = latticeExpression.ToUpperInvariant();
|
||||
|
||||
return atoms =>
|
||||
{
|
||||
// Check for negated atoms first
|
||||
if (expr.Contains("¬REACHABLE") || expr.Contains("NOT REACHABLE") || expr.Contains("!REACHABLE"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Reachable, out var r) && r != K4Value.False)
|
||||
return false;
|
||||
}
|
||||
else if (expr.Contains("REACHABLE"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Reachable, out var r) && r != K4Value.True)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expr.Contains("¬PRESENT") || expr.Contains("NOT PRESENT") || expr.Contains("!PRESENT"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Present, out var p) && p != K4Value.False)
|
||||
return false;
|
||||
}
|
||||
else if (expr.Contains("PRESENT"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Present, out var p) && p != K4Value.True)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expr.Contains("¬APPLIES") || expr.Contains("NOT APPLIES") || expr.Contains("!APPLIES"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Applies, out var a) && a != K4Value.False)
|
||||
return false;
|
||||
}
|
||||
else if (expr.Contains("APPLIES"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Applies, out var a) && a != K4Value.True)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expr.Contains("MITIGATED"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Mitigated, out var m) && m != K4Value.True)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expr.Contains("FIXED"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Fixed, out var f) && f != K4Value.True)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expr.Contains("MISATTRIBUTED"))
|
||||
{
|
||||
if (atoms.TryGetValue(SecurityAtom.Misattributed, out var m) && m != K4Value.True)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extract referenced atoms from a lattice expression for overlap detection.
|
||||
/// </summary>
|
||||
private static HashSet<SecurityAtom> ExtractAtomsFromExpression(string expression)
|
||||
{
|
||||
var atoms = new HashSet<SecurityAtom>();
|
||||
var expr = expression.ToUpperInvariant();
|
||||
|
||||
if (expr.Contains("REACHABLE")) atoms.Add(SecurityAtom.Reachable);
|
||||
if (expr.Contains("PRESENT")) atoms.Add(SecurityAtom.Present);
|
||||
if (expr.Contains("APPLIES")) atoms.Add(SecurityAtom.Applies);
|
||||
if (expr.Contains("MITIGATED")) atoms.Add(SecurityAtom.Mitigated);
|
||||
if (expr.Contains("FIXED")) atoms.Add(SecurityAtom.Fixed);
|
||||
if (expr.Contains("MISATTRIBUTED")) atoms.Add(SecurityAtom.Misattributed);
|
||||
|
||||
return atoms;
|
||||
}
|
||||
|
||||
private PolicyTestReport RunTests(
|
||||
IReadOnlyList<LatticeRule> rules,
|
||||
IReadOnlyList<PolicyTestCase> testCases)
|
||||
{
|
||||
var failures = new List<TestFailure>();
|
||||
var passed = 0;
|
||||
|
||||
foreach (var test in testCases)
|
||||
{
|
||||
// Find all target rules for this test
|
||||
var targetRules = rules.Where(r => test.TargetRuleIds.Contains(r.RuleId)).ToList();
|
||||
if (targetRules.Count == 0)
|
||||
{
|
||||
failures.Add(new TestFailure
|
||||
{
|
||||
TestId = test.TestCaseId,
|
||||
RuleId = string.Join(",", test.TargetRuleIds),
|
||||
Description = "Target rules not found",
|
||||
Expected = test.ExpectedDisposition,
|
||||
Actual = "not_found"
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Evaluate the test against the rules
|
||||
var result = EvaluateTest(targetRules, test);
|
||||
if (result == test.ExpectedDisposition)
|
||||
{
|
||||
passed++;
|
||||
}
|
||||
else
|
||||
{
|
||||
failures.Add(new TestFailure
|
||||
{
|
||||
TestId = test.TestCaseId,
|
||||
RuleId = string.Join(",", test.TargetRuleIds),
|
||||
Description = test.Description,
|
||||
Expected = test.ExpectedDisposition,
|
||||
Actual = result
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return new PolicyTestReport
|
||||
{
|
||||
TotalTests = testCases.Count,
|
||||
Passed = passed,
|
||||
Failed = failures.Count,
|
||||
Failures = failures
|
||||
};
|
||||
}
|
||||
|
||||
private static string EvaluateTest(IReadOnlyList<LatticeRule> rules, PolicyTestCase test)
|
||||
{
|
||||
// Simplified test evaluation - find highest priority matching rule
|
||||
// In production, use proper lattice engine with full atom evaluation
|
||||
var bestMatch = rules.OrderBy(r => r.Priority).FirstOrDefault();
|
||||
return bestMatch?.Disposition ?? "unknown";
|
||||
}
|
||||
|
||||
private static bool HasOverlappingAtoms(SelectionRule rule1, SelectionRule rule2)
|
||||
{
|
||||
// Extract atoms from condition descriptions (which contain the lattice expressions)
|
||||
var atoms1 = ExtractAtomsFromExpression(rule1.ConditionDescription);
|
||||
var atoms2 = ExtractAtomsFromExpression(rule2.ConditionDescription);
|
||||
return atoms1.Overlaps(atoms2);
|
||||
}
|
||||
|
||||
private static double EstimateCoverage(PolicyBundle bundle)
|
||||
{
|
||||
// Count distinct atoms referenced across all rules
|
||||
var atomsCovered = bundle.CustomRules
|
||||
.SelectMany(r => ExtractAtomsFromExpression(r.ConditionDescription))
|
||||
.Distinct()
|
||||
.Count();
|
||||
|
||||
// 6 possible security atoms, estimate coverage as percentage
|
||||
return Math.Min(1.0, (double)atomsCovered / 6.0);
|
||||
}
|
||||
|
||||
private static string ComputeBundleDigest(PolicyBundle bundle)
|
||||
{
|
||||
var json = JsonSerializer.Serialize(bundle, SerializerOptions);
|
||||
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(json));
|
||||
return $"sha256:{Convert.ToHexStringLower(bytes)}";
|
||||
}
|
||||
|
||||
private static string ComputeHash(string content)
|
||||
{
|
||||
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(content));
|
||||
return Convert.ToHexStringLower(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Interface for signing policy bundles.
|
||||
/// </summary>
|
||||
public interface IPolicyBundleSigner
|
||||
{
|
||||
/// <summary>
|
||||
/// Signs content and returns signature.
|
||||
/// </summary>
|
||||
Task<PolicySignature> SignAsync(
|
||||
string contentDigest,
|
||||
PolicySigningOptions options,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Policy signature result.
|
||||
/// </summary>
|
||||
public sealed record PolicySignature
|
||||
{
|
||||
/// <summary>
|
||||
/// Signature bytes (base64).
|
||||
/// </summary>
|
||||
public required string SignatureBase64 { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Signing algorithm.
|
||||
/// </summary>
|
||||
public required string Algorithm { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Certificate chain (PEM).
|
||||
/// </summary>
|
||||
public string? CertificateChain { get; init; }
|
||||
}
|
||||
|
||||
@@ -0,0 +1,324 @@
|
||||
using System.Text;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation;
|
||||
|
||||
/// <summary>
|
||||
/// Service for computing and signing SBOM deltas during remediation.
|
||||
/// Sprint: SPRINT_20251226_016_AI_remedy_autopilot
|
||||
/// Task: REMEDY-15, REMEDY-16, REMEDY-17
|
||||
/// </summary>
|
||||
public interface IRemediationDeltaService
|
||||
{
|
||||
/// <summary>
|
||||
/// Compute SBOM delta between before and after remediation.
|
||||
/// </summary>
|
||||
Task<RemediationDelta> ComputeDeltaAsync(
|
||||
RemediationPlan plan,
|
||||
string beforeSbomPath,
|
||||
string afterSbomPath,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Sign the delta verdict with attestation.
|
||||
/// </summary>
|
||||
Task<SignedDeltaVerdict> SignDeltaAsync(
|
||||
RemediationDelta delta,
|
||||
IRemediationDeltaSigner signer,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Generate PR description with delta verdict.
|
||||
/// </summary>
|
||||
Task<string> GeneratePrDescriptionAsync(
|
||||
RemediationPlan plan,
|
||||
SignedDeltaVerdict signedDelta,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Signer interface for delta verdicts.
|
||||
/// </summary>
|
||||
public interface IRemediationDeltaSigner
|
||||
{
|
||||
string KeyId { get; }
|
||||
string Algorithm { get; }
|
||||
Task<byte[]> SignAsync(byte[] data, CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Delta result from remediation.
|
||||
/// </summary>
|
||||
public sealed record RemediationDelta
|
||||
{
|
||||
public required string DeltaId { get; init; }
|
||||
public required string PlanId { get; init; }
|
||||
public required string BeforeSbomDigest { get; init; }
|
||||
public required string AfterSbomDigest { get; init; }
|
||||
public required IReadOnlyList<ComponentChange> ComponentChanges { get; init; }
|
||||
public required IReadOnlyList<VulnerabilityChange> VulnerabilityChanges { get; init; }
|
||||
public required DeltaSummary Summary { get; init; }
|
||||
public required string ComputedAt { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A component change in the delta.
|
||||
/// </summary>
|
||||
public sealed record ComponentChange
|
||||
{
|
||||
public required string ChangeType { get; init; } // added, removed, upgraded
|
||||
public required string Purl { get; init; }
|
||||
public string? OldVersion { get; init; }
|
||||
public string? NewVersion { get; init; }
|
||||
public required IReadOnlyList<string> AffectedVulnerabilities { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A vulnerability change in the delta.
|
||||
/// </summary>
|
||||
public sealed record VulnerabilityChange
|
||||
{
|
||||
public required string ChangeType { get; init; } // fixed, introduced, status_changed
|
||||
public required string VulnerabilityId { get; init; }
|
||||
public required string Severity { get; init; }
|
||||
public string? OldStatus { get; init; }
|
||||
public string? NewStatus { get; init; }
|
||||
public required string ComponentPurl { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Summary of the delta.
|
||||
/// </summary>
|
||||
public sealed record DeltaSummary
|
||||
{
|
||||
public required int ComponentsAdded { get; init; }
|
||||
public required int ComponentsRemoved { get; init; }
|
||||
public required int ComponentsUpgraded { get; init; }
|
||||
public required int VulnerabilitiesFixed { get; init; }
|
||||
public required int VulnerabilitiesIntroduced { get; init; }
|
||||
public required int NetVulnerabilityChange { get; init; }
|
||||
public required bool IsImprovement { get; init; }
|
||||
public required string RiskTrend { get; init; } // improved, degraded, stable
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Signed delta verdict.
|
||||
/// </summary>
|
||||
public sealed record SignedDeltaVerdict
|
||||
{
|
||||
public required RemediationDelta Delta { get; init; }
|
||||
public required string SignatureId { get; init; }
|
||||
public required string KeyId { get; init; }
|
||||
public required string Algorithm { get; init; }
|
||||
public required string Signature { get; init; }
|
||||
public required string SignedAt { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of remediation delta service.
|
||||
/// </summary>
|
||||
public sealed class RemediationDeltaService : IRemediationDeltaService
|
||||
{
|
||||
public async Task<RemediationDelta> ComputeDeltaAsync(
|
||||
RemediationPlan plan,
|
||||
string beforeSbomPath,
|
||||
string afterSbomPath,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// In production, this would use the DeltaComputationEngine
|
||||
// For now, create delta from the plan's expected delta
|
||||
|
||||
var componentChanges = new List<ComponentChange>();
|
||||
var vulnChanges = new List<VulnerabilityChange>();
|
||||
|
||||
// Convert expected delta to component changes
|
||||
foreach (var (oldPurl, newPurl) in plan.ExpectedDelta.Upgraded)
|
||||
{
|
||||
componentChanges.Add(new ComponentChange
|
||||
{
|
||||
ChangeType = "upgraded",
|
||||
Purl = oldPurl,
|
||||
OldVersion = ExtractVersion(oldPurl),
|
||||
NewVersion = ExtractVersion(newPurl),
|
||||
AffectedVulnerabilities = new[] { plan.Request.VulnerabilityId }
|
||||
});
|
||||
}
|
||||
|
||||
foreach (var purl in plan.ExpectedDelta.Added)
|
||||
{
|
||||
componentChanges.Add(new ComponentChange
|
||||
{
|
||||
ChangeType = "added",
|
||||
Purl = purl,
|
||||
AffectedVulnerabilities = Array.Empty<string>()
|
||||
});
|
||||
}
|
||||
|
||||
foreach (var purl in plan.ExpectedDelta.Removed)
|
||||
{
|
||||
componentChanges.Add(new ComponentChange
|
||||
{
|
||||
ChangeType = "removed",
|
||||
Purl = purl,
|
||||
AffectedVulnerabilities = Array.Empty<string>()
|
||||
});
|
||||
}
|
||||
|
||||
// Add vulnerability fix
|
||||
vulnChanges.Add(new VulnerabilityChange
|
||||
{
|
||||
ChangeType = "fixed",
|
||||
VulnerabilityId = plan.Request.VulnerabilityId,
|
||||
Severity = "high", // Would come from advisory data
|
||||
OldStatus = "affected",
|
||||
NewStatus = "fixed",
|
||||
ComponentPurl = plan.Request.ComponentPurl
|
||||
});
|
||||
|
||||
var summary = new DeltaSummary
|
||||
{
|
||||
ComponentsAdded = plan.ExpectedDelta.Added.Count,
|
||||
ComponentsRemoved = plan.ExpectedDelta.Removed.Count,
|
||||
ComponentsUpgraded = plan.ExpectedDelta.Upgraded.Count,
|
||||
VulnerabilitiesFixed = Math.Abs(Math.Min(0, plan.ExpectedDelta.NetVulnerabilityChange)),
|
||||
VulnerabilitiesIntroduced = Math.Max(0, plan.ExpectedDelta.NetVulnerabilityChange),
|
||||
NetVulnerabilityChange = plan.ExpectedDelta.NetVulnerabilityChange,
|
||||
IsImprovement = plan.ExpectedDelta.NetVulnerabilityChange < 0,
|
||||
RiskTrend = plan.ExpectedDelta.NetVulnerabilityChange < 0 ? "improved" :
|
||||
plan.ExpectedDelta.NetVulnerabilityChange > 0 ? "degraded" : "stable"
|
||||
};
|
||||
|
||||
var deltaId = $"delta-{plan.PlanId}-{DateTime.UtcNow:yyyyMMddHHmmss}";
|
||||
|
||||
return new RemediationDelta
|
||||
{
|
||||
DeltaId = deltaId,
|
||||
PlanId = plan.PlanId,
|
||||
BeforeSbomDigest = await ComputeFileDigestAsync(beforeSbomPath, cancellationToken),
|
||||
AfterSbomDigest = await ComputeFileDigestAsync(afterSbomPath, cancellationToken),
|
||||
ComponentChanges = componentChanges,
|
||||
VulnerabilityChanges = vulnChanges,
|
||||
Summary = summary,
|
||||
ComputedAt = DateTime.UtcNow.ToString("o")
|
||||
};
|
||||
}
|
||||
|
||||
public async Task<SignedDeltaVerdict> SignDeltaAsync(
|
||||
RemediationDelta delta,
|
||||
IRemediationDeltaSigner signer,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Serialize delta to canonical JSON for signing
|
||||
var deltaJson = System.Text.Json.JsonSerializer.Serialize(delta, new System.Text.Json.JsonSerializerOptions
|
||||
{
|
||||
WriteIndented = false,
|
||||
PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.SnakeCaseLower
|
||||
});
|
||||
|
||||
var dataToSign = Encoding.UTF8.GetBytes(deltaJson);
|
||||
var signature = await signer.SignAsync(dataToSign, cancellationToken);
|
||||
var signatureBase64 = Convert.ToBase64String(signature);
|
||||
var signatureId = $"sig-{delta.DeltaId}-{signer.KeyId[..8]}";
|
||||
|
||||
return new SignedDeltaVerdict
|
||||
{
|
||||
Delta = delta,
|
||||
SignatureId = signatureId,
|
||||
KeyId = signer.KeyId,
|
||||
Algorithm = signer.Algorithm,
|
||||
Signature = signatureBase64,
|
||||
SignedAt = DateTime.UtcNow.ToString("o")
|
||||
};
|
||||
}
|
||||
|
||||
public Task<string> GeneratePrDescriptionAsync(
|
||||
RemediationPlan plan,
|
||||
SignedDeltaVerdict signedDelta,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
|
||||
sb.AppendLine("## Security Remediation");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine($"This PR remediates **{plan.Request.VulnerabilityId}** affecting `{plan.Request.ComponentPurl}`.");
|
||||
sb.AppendLine();
|
||||
|
||||
// Risk assessment
|
||||
sb.AppendLine("### Risk Assessment");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine($"- **Risk Level**: {plan.RiskAssessment}");
|
||||
sb.AppendLine($"- **Confidence**: {plan.ConfidenceScore:P0}");
|
||||
sb.AppendLine($"- **Authority**: {plan.Authority}");
|
||||
sb.AppendLine();
|
||||
|
||||
// Changes
|
||||
sb.AppendLine("### Changes");
|
||||
sb.AppendLine();
|
||||
foreach (var step in plan.Steps)
|
||||
{
|
||||
sb.AppendLine($"- {step.Description}");
|
||||
if (!string.IsNullOrEmpty(step.PreviousValue) && !string.IsNullOrEmpty(step.NewValue))
|
||||
{
|
||||
sb.AppendLine($" - `{step.PreviousValue}` → `{step.NewValue}`");
|
||||
}
|
||||
}
|
||||
sb.AppendLine();
|
||||
|
||||
// Delta verdict
|
||||
sb.AppendLine("### Delta Verdict");
|
||||
sb.AppendLine();
|
||||
var summary = signedDelta.Delta.Summary;
|
||||
var trendEmoji = summary.RiskTrend switch
|
||||
{
|
||||
"improved" => "✅",
|
||||
"degraded" => "⚠️",
|
||||
_ => "➖"
|
||||
};
|
||||
sb.AppendLine($"{trendEmoji} **{summary.RiskTrend.ToUpperInvariant()}**");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine($"| Metric | Count |");
|
||||
sb.AppendLine($"|--------|-------|");
|
||||
sb.AppendLine($"| Vulnerabilities Fixed | {summary.VulnerabilitiesFixed} |");
|
||||
sb.AppendLine($"| Vulnerabilities Introduced | {summary.VulnerabilitiesIntroduced} |");
|
||||
sb.AppendLine($"| Net Change | {summary.NetVulnerabilityChange} |");
|
||||
sb.AppendLine($"| Components Upgraded | {summary.ComponentsUpgraded} |");
|
||||
sb.AppendLine();
|
||||
|
||||
// Signature verification
|
||||
sb.AppendLine("### Attestation");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine("```");
|
||||
sb.AppendLine($"Delta ID: {signedDelta.Delta.DeltaId}");
|
||||
sb.AppendLine($"Signature ID: {signedDelta.SignatureId}");
|
||||
sb.AppendLine($"Algorithm: {signedDelta.Algorithm}");
|
||||
sb.AppendLine($"Signed At: {signedDelta.SignedAt}");
|
||||
sb.AppendLine("```");
|
||||
sb.AppendLine();
|
||||
|
||||
// Footer
|
||||
sb.AppendLine("---");
|
||||
sb.AppendLine($"*Generated by StellaOps Remedy Autopilot using {plan.ModelId}*");
|
||||
|
||||
return Task.FromResult(sb.ToString());
|
||||
}
|
||||
|
||||
private static string ExtractVersion(string purl)
|
||||
{
|
||||
// Extract version from PURL like pkg:npm/lodash@4.17.21
|
||||
var atIndex = purl.LastIndexOf('@');
|
||||
return atIndex >= 0 ? purl[(atIndex + 1)..] : "unknown";
|
||||
}
|
||||
|
||||
private static async Task<string> ComputeFileDigestAsync(
|
||||
string filePath,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
if (!File.Exists(filePath))
|
||||
{
|
||||
return "file-not-found";
|
||||
}
|
||||
|
||||
await using var stream = File.OpenRead(filePath);
|
||||
var hash = await System.Security.Cryptography.SHA256.HashDataAsync(stream, cancellationToken);
|
||||
return Convert.ToHexStringLower(hash);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,386 @@
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// Azure DevOps SCM connector plugin.
|
||||
/// Supports Azure DevOps Services and Azure DevOps Server.
|
||||
/// </summary>
|
||||
public sealed class AzureDevOpsScmConnectorPlugin : IScmConnectorPlugin
|
||||
{
|
||||
public string ScmType => "azuredevops";
|
||||
public string DisplayName => "Azure DevOps";
|
||||
|
||||
public bool IsAvailable(ScmConnectorOptions options) =>
|
||||
!string.IsNullOrEmpty(options.ApiToken);
|
||||
|
||||
public bool CanHandle(string repositoryUrl) =>
|
||||
repositoryUrl.Contains("dev.azure.com", StringComparison.OrdinalIgnoreCase) ||
|
||||
repositoryUrl.Contains("visualstudio.com", StringComparison.OrdinalIgnoreCase) ||
|
||||
repositoryUrl.Contains("azure.com", StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
public IScmConnector Create(ScmConnectorOptions options, HttpClient httpClient) =>
|
||||
new AzureDevOpsScmConnector(httpClient, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Azure DevOps SCM connector implementation.
|
||||
/// API Reference: https://learn.microsoft.com/en-us/rest/api/azure/devops/
|
||||
/// </summary>
|
||||
public sealed class AzureDevOpsScmConnector : ScmConnectorBase
|
||||
{
|
||||
private readonly string _baseUrl;
|
||||
private const string ApiVersion = "7.1";
|
||||
|
||||
public AzureDevOpsScmConnector(HttpClient httpClient, ScmConnectorOptions options)
|
||||
: base(httpClient, options)
|
||||
{
|
||||
_baseUrl = options.BaseUrl ?? "https://dev.azure.com";
|
||||
}
|
||||
|
||||
public override string ScmType => "azuredevops";
|
||||
|
||||
protected override void ConfigureAuthentication()
|
||||
{
|
||||
// Azure DevOps uses Basic auth with PAT (empty username, token as password)
|
||||
var credentials = Convert.ToBase64String(Encoding.ASCII.GetBytes($":{Options.ApiToken}"));
|
||||
HttpClient.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("Basic", credentials);
|
||||
}
|
||||
|
||||
public override async Task<BranchResult> CreateBranchAsync(
|
||||
string owner, string repo, string branchName, string baseBranch,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get the base branch ref
|
||||
var refsUrl = $"{_baseUrl}/{owner}/{repo}/_apis/git/refs?filter=heads/{baseBranch}&api-version={ApiVersion}";
|
||||
var refs = await GetJsonAsync<JsonElement>(refsUrl, cancellationToken);
|
||||
|
||||
if (refs.ValueKind == JsonValueKind.Undefined ||
|
||||
!refs.TryGetProperty("value", out var refArray) ||
|
||||
refArray.GetArrayLength() == 0)
|
||||
{
|
||||
return new BranchResult
|
||||
{
|
||||
Success = false,
|
||||
BranchName = branchName,
|
||||
ErrorMessage = $"Base branch '{baseBranch}' not found"
|
||||
};
|
||||
}
|
||||
|
||||
var baseSha = refArray[0].GetProperty("objectId").GetString();
|
||||
|
||||
// Create new branch
|
||||
var payload = new[]
|
||||
{
|
||||
new
|
||||
{
|
||||
name = $"refs/heads/{branchName}",
|
||||
oldObjectId = "0000000000000000000000000000000000000000",
|
||||
newObjectId = baseSha
|
||||
}
|
||||
};
|
||||
|
||||
var (success, _) = await PostJsonAsync(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/refs?api-version={ApiVersion}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
return new BranchResult
|
||||
{
|
||||
Success = success,
|
||||
BranchName = branchName,
|
||||
CommitSha = baseSha,
|
||||
ErrorMessage = success ? null : "Failed to create branch"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<FileUpdateResult> UpdateFileAsync(
|
||||
string owner, string repo, string branch, string filePath,
|
||||
string content, string commitMessage,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get the latest commit on the branch
|
||||
var branchUrl = $"{_baseUrl}/{owner}/{repo}/_apis/git/refs?filter=heads/{branch}&api-version={ApiVersion}";
|
||||
var branchRef = await GetJsonAsync<JsonElement>(branchUrl, cancellationToken);
|
||||
|
||||
if (branchRef.ValueKind == JsonValueKind.Undefined ||
|
||||
!branchRef.TryGetProperty("value", out var refArray) ||
|
||||
refArray.GetArrayLength() == 0)
|
||||
{
|
||||
return new FileUpdateResult
|
||||
{
|
||||
Success = false,
|
||||
FilePath = filePath,
|
||||
ErrorMessage = "Branch not found"
|
||||
};
|
||||
}
|
||||
|
||||
var oldObjectId = refArray[0].GetProperty("objectId").GetString();
|
||||
|
||||
// Create a push with the file change
|
||||
var payload = new
|
||||
{
|
||||
refUpdates = new[]
|
||||
{
|
||||
new
|
||||
{
|
||||
name = $"refs/heads/{branch}",
|
||||
oldObjectId
|
||||
}
|
||||
},
|
||||
commits = new[]
|
||||
{
|
||||
new
|
||||
{
|
||||
comment = commitMessage,
|
||||
changes = new[]
|
||||
{
|
||||
new
|
||||
{
|
||||
changeType = "edit",
|
||||
item = new { path = $"/{filePath}" },
|
||||
newContent = new
|
||||
{
|
||||
content,
|
||||
contentType = "rawtext"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/pushes?api-version={ApiVersion}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
string? commitSha = null;
|
||||
if (success && result.ValueKind != JsonValueKind.Undefined &&
|
||||
result.TryGetProperty("commits", out var commits) &&
|
||||
commits.GetArrayLength() > 0)
|
||||
{
|
||||
commitSha = commits[0].GetProperty("commitId").GetString();
|
||||
}
|
||||
|
||||
return new FileUpdateResult
|
||||
{
|
||||
Success = success,
|
||||
FilePath = filePath,
|
||||
CommitSha = commitSha,
|
||||
ErrorMessage = success ? null : "Failed to update file"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrCreateResult> CreatePullRequestAsync(
|
||||
string owner, string repo, string headBranch, string baseBranch,
|
||||
string title, string body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new
|
||||
{
|
||||
sourceRefName = $"refs/heads/{headBranch}",
|
||||
targetRefName = $"refs/heads/{baseBranch}",
|
||||
title,
|
||||
description = body
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/pullrequests?api-version={ApiVersion}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
if (!success || result.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = 0,
|
||||
PrUrl = string.Empty,
|
||||
ErrorMessage = "Failed to create pull request"
|
||||
};
|
||||
}
|
||||
|
||||
var prId = result.GetProperty("pullRequestId").GetInt32();
|
||||
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = prId,
|
||||
PrUrl = $"{_baseUrl}/{owner}/{repo}/_git/{repo}/pullrequest/{prId}"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrStatusResult> GetPullRequestStatusAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var pr = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/pullrequests/{prNumber}?api-version={ApiVersion}",
|
||||
cancellationToken);
|
||||
|
||||
if (pr.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = prNumber,
|
||||
State = PrState.Open,
|
||||
HeadSha = string.Empty,
|
||||
HeadBranch = string.Empty,
|
||||
BaseBranch = string.Empty,
|
||||
Title = string.Empty,
|
||||
Mergeable = false,
|
||||
ErrorMessage = "PR not found"
|
||||
};
|
||||
}
|
||||
|
||||
var status = pr.GetProperty("status").GetString() ?? "active";
|
||||
var prState = status switch
|
||||
{
|
||||
"completed" => PrState.Merged,
|
||||
"abandoned" => PrState.Closed,
|
||||
_ => PrState.Open
|
||||
};
|
||||
|
||||
var sourceRef = pr.GetProperty("sourceRefName").GetString() ?? string.Empty;
|
||||
var targetRef = pr.GetProperty("targetRefName").GetString() ?? string.Empty;
|
||||
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = prNumber,
|
||||
State = prState,
|
||||
HeadSha = pr.GetProperty("lastMergeSourceCommit").GetProperty("commitId").GetString() ?? string.Empty,
|
||||
HeadBranch = sourceRef.Replace("refs/heads/", ""),
|
||||
BaseBranch = targetRef.Replace("refs/heads/", ""),
|
||||
Title = pr.GetProperty("title").GetString() ?? string.Empty,
|
||||
Body = pr.TryGetProperty("description", out var d) ? d.GetString() : null,
|
||||
PrUrl = $"{_baseUrl}/{owner}/{repo}/_git/{repo}/pullrequest/{prNumber}",
|
||||
Mergeable = pr.TryGetProperty("mergeStatus", out var ms) &&
|
||||
ms.GetString() == "succeeded"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<CiStatusResult> GetCiStatusAsync(
|
||||
string owner, string repo, string commitSha,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get build status for the commit
|
||||
var builds = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/build/builds?sourceVersion={commitSha}&api-version={ApiVersion}",
|
||||
cancellationToken);
|
||||
|
||||
var checks = new List<CiCheck>();
|
||||
|
||||
if (builds.ValueKind != JsonValueKind.Undefined &&
|
||||
builds.TryGetProperty("value", out var buildArray))
|
||||
{
|
||||
foreach (var build in buildArray.EnumerateArray())
|
||||
{
|
||||
var buildStatus = build.GetProperty("status").GetString() ?? "notStarted";
|
||||
var buildResult = build.TryGetProperty("result", out var r) ? r.GetString() : null;
|
||||
|
||||
var state = buildResult != null
|
||||
? MapBuildResultToCiState(buildResult)
|
||||
: MapBuildStatusToCiState(buildStatus);
|
||||
|
||||
checks.Add(new CiCheck
|
||||
{
|
||||
Name = build.GetProperty("definition").GetProperty("name").GetString() ?? "unknown",
|
||||
State = state,
|
||||
Description = build.TryGetProperty("buildNumber", out var bn) ? bn.GetString() : null,
|
||||
TargetUrl = build.TryGetProperty("_links", out var links) &&
|
||||
links.TryGetProperty("web", out var web) &&
|
||||
web.TryGetProperty("href", out var href) ? href.GetString() : null,
|
||||
StartedAt = build.TryGetProperty("startTime", out var st) ? st.GetString() : null,
|
||||
CompletedAt = build.TryGetProperty("finishTime", out var ft) ? ft.GetString() : null
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
var overallState = checks.Count > 0 ? DetermineOverallState(checks) : CiState.Unknown;
|
||||
|
||||
return new CiStatusResult
|
||||
{
|
||||
Success = true,
|
||||
OverallState = overallState,
|
||||
Checks = checks
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<bool> UpdatePullRequestAsync(
|
||||
string owner, string repo, int prNumber, string? title, string? body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new Dictionary<string, string>();
|
||||
if (title != null) payload["title"] = title;
|
||||
if (body != null) payload["description"] = body;
|
||||
|
||||
return await PatchJsonAsync(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/pullrequests/{prNumber}?api-version={ApiVersion}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
public override async Task<bool> AddCommentAsync(
|
||||
string owner, string repo, int prNumber, string comment,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new
|
||||
{
|
||||
comments = new[]
|
||||
{
|
||||
new { content = comment }
|
||||
},
|
||||
status = "active"
|
||||
};
|
||||
|
||||
var (success, _) = await PostJsonAsync(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/repositories/{repo}/pullRequests/{prNumber}/threads?api-version={ApiVersion}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
return success;
|
||||
}
|
||||
|
||||
public override async Task<bool> ClosePullRequestAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return await PatchJsonAsync(
|
||||
$"{_baseUrl}/{owner}/{repo}/_apis/git/pullrequests/{prNumber}?api-version={ApiVersion}",
|
||||
new { status = "abandoned" },
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
private static CiState MapBuildStatusToCiState(string status) => status switch
|
||||
{
|
||||
"notStarted" or "postponed" => CiState.Pending,
|
||||
"inProgress" => CiState.Running,
|
||||
"completed" => CiState.Success,
|
||||
"cancelling" or "none" => CiState.Unknown,
|
||||
_ => CiState.Unknown
|
||||
};
|
||||
|
||||
private static CiState MapBuildResultToCiState(string result) => result switch
|
||||
{
|
||||
"succeeded" => CiState.Success,
|
||||
"partiallySucceeded" => CiState.Success,
|
||||
"failed" => CiState.Failure,
|
||||
"canceled" => CiState.Error,
|
||||
_ => CiState.Unknown
|
||||
};
|
||||
|
||||
private static CiState DetermineOverallState(IReadOnlyList<CiCheck> checks)
|
||||
{
|
||||
if (checks.Count == 0) return CiState.Unknown;
|
||||
if (checks.Any(c => c.State == CiState.Failure)) return CiState.Failure;
|
||||
if (checks.Any(c => c.State == CiState.Error)) return CiState.Error;
|
||||
if (checks.Any(c => c.State == CiState.Running)) return CiState.Running;
|
||||
if (checks.Any(c => c.State == CiState.Pending)) return CiState.Pending;
|
||||
if (checks.All(c => c.State == CiState.Success)) return CiState.Success;
|
||||
return CiState.Unknown;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
using System.Text.Json;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// GitHub SCM connector plugin.
|
||||
/// Supports github.com and GitHub Enterprise Server.
|
||||
/// </summary>
|
||||
public sealed class GitHubScmConnectorPlugin : IScmConnectorPlugin
|
||||
{
|
||||
public string ScmType => "github";
|
||||
public string DisplayName => "GitHub";
|
||||
|
||||
public bool IsAvailable(ScmConnectorOptions options) =>
|
||||
!string.IsNullOrEmpty(options.ApiToken);
|
||||
|
||||
public bool CanHandle(string repositoryUrl) =>
|
||||
repositoryUrl.Contains("github.com", StringComparison.OrdinalIgnoreCase) ||
|
||||
repositoryUrl.Contains("github.", StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
public IScmConnector Create(ScmConnectorOptions options, HttpClient httpClient) =>
|
||||
new GitHubScmConnector(httpClient, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// GitHub SCM connector implementation.
|
||||
/// API Reference: https://docs.github.com/en/rest
|
||||
/// </summary>
|
||||
public sealed class GitHubScmConnector : ScmConnectorBase
|
||||
{
|
||||
private readonly string _baseUrl;
|
||||
|
||||
public GitHubScmConnector(HttpClient httpClient, ScmConnectorOptions options)
|
||||
: base(httpClient, options)
|
||||
{
|
||||
_baseUrl = options.BaseUrl ?? "https://api.github.com";
|
||||
}
|
||||
|
||||
public override string ScmType => "github";
|
||||
|
||||
protected override void ConfigureAuthentication()
|
||||
{
|
||||
HttpClient.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", Options.ApiToken);
|
||||
HttpClient.DefaultRequestHeaders.Accept.ParseAdd("application/vnd.github+json");
|
||||
HttpClient.DefaultRequestHeaders.Add("X-GitHub-Api-Version", "2022-11-28");
|
||||
}
|
||||
|
||||
public override async Task<BranchResult> CreateBranchAsync(
|
||||
string owner, string repo, string branchName, string baseBranch,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get base branch SHA
|
||||
var refResponse = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/git/refs/heads/{baseBranch}",
|
||||
cancellationToken);
|
||||
|
||||
if (refResponse.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new BranchResult
|
||||
{
|
||||
Success = false,
|
||||
BranchName = branchName,
|
||||
ErrorMessage = $"Base branch '{baseBranch}' not found"
|
||||
};
|
||||
}
|
||||
|
||||
var baseSha = refResponse.GetProperty("object").GetProperty("sha").GetString();
|
||||
|
||||
// Create new branch ref
|
||||
var payload = new { @ref = $"refs/heads/{branchName}", sha = baseSha };
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/git/refs",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
return new BranchResult
|
||||
{
|
||||
Success = success,
|
||||
BranchName = branchName,
|
||||
CommitSha = baseSha,
|
||||
ErrorMessage = success ? null : "Failed to create branch"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<FileUpdateResult> UpdateFileAsync(
|
||||
string owner, string repo, string branch, string filePath,
|
||||
string content, string commitMessage,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get existing file SHA if it exists
|
||||
string? fileSha = null;
|
||||
var existingFile = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/contents/{filePath}?ref={branch}",
|
||||
cancellationToken);
|
||||
|
||||
if (existingFile.ValueKind != JsonValueKind.Undefined &&
|
||||
existingFile.TryGetProperty("sha", out var sha))
|
||||
{
|
||||
fileSha = sha.GetString();
|
||||
}
|
||||
|
||||
// Update or create file
|
||||
var payload = new
|
||||
{
|
||||
message = commitMessage,
|
||||
content = Base64Encode(content),
|
||||
branch,
|
||||
sha = fileSha
|
||||
};
|
||||
|
||||
var (success, result) = await PutJsonAsync(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/contents/{filePath}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
string? commitSha = null;
|
||||
if (success && result.ValueKind != JsonValueKind.Undefined &&
|
||||
result.TryGetProperty("commit", out var commit) &&
|
||||
commit.TryGetProperty("sha", out var csha))
|
||||
{
|
||||
commitSha = csha.GetString();
|
||||
}
|
||||
|
||||
return new FileUpdateResult
|
||||
{
|
||||
Success = success,
|
||||
FilePath = filePath,
|
||||
CommitSha = commitSha,
|
||||
ErrorMessage = success ? null : "Failed to update file"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrCreateResult> CreatePullRequestAsync(
|
||||
string owner, string repo, string headBranch, string baseBranch,
|
||||
string title, string body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new
|
||||
{
|
||||
title,
|
||||
body,
|
||||
head = headBranch,
|
||||
@base = baseBranch
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/pulls",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
if (!success || result.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = 0,
|
||||
PrUrl = string.Empty,
|
||||
ErrorMessage = "Failed to create pull request"
|
||||
};
|
||||
}
|
||||
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = result.GetProperty("number").GetInt32(),
|
||||
PrUrl = result.GetProperty("html_url").GetString() ?? string.Empty
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrStatusResult> GetPullRequestStatusAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var pr = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/pulls/{prNumber}",
|
||||
cancellationToken);
|
||||
|
||||
if (pr.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = prNumber,
|
||||
State = PrState.Open,
|
||||
HeadSha = string.Empty,
|
||||
HeadBranch = string.Empty,
|
||||
BaseBranch = string.Empty,
|
||||
Title = string.Empty,
|
||||
Mergeable = false,
|
||||
ErrorMessage = "PR not found"
|
||||
};
|
||||
}
|
||||
|
||||
var state = pr.GetProperty("state").GetString() ?? "open";
|
||||
var merged = pr.TryGetProperty("merged", out var m) && m.GetBoolean();
|
||||
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = prNumber,
|
||||
State = merged ? PrState.Merged : state == "closed" ? PrState.Closed : PrState.Open,
|
||||
HeadSha = pr.GetProperty("head").GetProperty("sha").GetString() ?? string.Empty,
|
||||
HeadBranch = pr.GetProperty("head").GetProperty("ref").GetString() ?? string.Empty,
|
||||
BaseBranch = pr.GetProperty("base").GetProperty("ref").GetString() ?? string.Empty,
|
||||
Title = pr.GetProperty("title").GetString() ?? string.Empty,
|
||||
Body = pr.TryGetProperty("body", out var b) ? b.GetString() : null,
|
||||
PrUrl = pr.GetProperty("html_url").GetString(),
|
||||
Mergeable = pr.TryGetProperty("mergeable", out var mg) && mg.ValueKind == JsonValueKind.True
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<CiStatusResult> GetCiStatusAsync(
|
||||
string owner, string repo, string commitSha,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get combined status
|
||||
var status = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/commits/{commitSha}/status",
|
||||
cancellationToken);
|
||||
|
||||
// Get check runs (GitHub Actions)
|
||||
var checkRuns = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/commits/{commitSha}/check-runs",
|
||||
cancellationToken);
|
||||
|
||||
var checks = new List<CiCheck>();
|
||||
|
||||
// Process commit statuses
|
||||
if (status.ValueKind != JsonValueKind.Undefined &&
|
||||
status.TryGetProperty("statuses", out var statuses))
|
||||
{
|
||||
foreach (var s in statuses.EnumerateArray())
|
||||
{
|
||||
checks.Add(new CiCheck
|
||||
{
|
||||
Name = s.GetProperty("context").GetString() ?? "unknown",
|
||||
State = MapToCiState(s.GetProperty("state").GetString() ?? "pending"),
|
||||
Description = s.TryGetProperty("description", out var d) ? d.GetString() : null,
|
||||
TargetUrl = s.TryGetProperty("target_url", out var u) ? u.GetString() : null
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Process check runs
|
||||
if (checkRuns.ValueKind != JsonValueKind.Undefined &&
|
||||
checkRuns.TryGetProperty("check_runs", out var runs))
|
||||
{
|
||||
foreach (var r in runs.EnumerateArray())
|
||||
{
|
||||
var conclusion = r.TryGetProperty("conclusion", out var c) ? c.GetString() : null;
|
||||
var runStatus = r.GetProperty("status").GetString() ?? "queued";
|
||||
|
||||
checks.Add(new CiCheck
|
||||
{
|
||||
Name = r.GetProperty("name").GetString() ?? "unknown",
|
||||
State = conclusion != null ? MapToCiState(conclusion) : MapToCiState(runStatus),
|
||||
Description = r.TryGetProperty("output", out var o) &&
|
||||
o.TryGetProperty("summary", out var sum) ? sum.GetString() : null,
|
||||
TargetUrl = r.TryGetProperty("html_url", out var u) ? u.GetString() : null,
|
||||
StartedAt = r.TryGetProperty("started_at", out var sa) ? sa.GetString() : null,
|
||||
CompletedAt = r.TryGetProperty("completed_at", out var ca) ? ca.GetString() : null
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
var overallState = DetermineOverallState(checks);
|
||||
|
||||
return new CiStatusResult
|
||||
{
|
||||
Success = true,
|
||||
OverallState = overallState,
|
||||
Checks = checks
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<bool> UpdatePullRequestAsync(
|
||||
string owner, string repo, int prNumber, string? title, string? body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new Dictionary<string, string>();
|
||||
if (title != null) payload["title"] = title;
|
||||
if (body != null) payload["body"] = body;
|
||||
|
||||
return await PatchJsonAsync(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/pulls/{prNumber}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
public override async Task<bool> AddCommentAsync(
|
||||
string owner, string repo, int prNumber, string comment,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new { body = comment };
|
||||
var (success, _) = await PostJsonAsync(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/issues/{prNumber}/comments",
|
||||
payload,
|
||||
cancellationToken);
|
||||
return success;
|
||||
}
|
||||
|
||||
public override async Task<bool> ClosePullRequestAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return await PatchJsonAsync(
|
||||
$"{_baseUrl}/repos/{owner}/{repo}/pulls/{prNumber}",
|
||||
new { state = "closed" },
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
private static CiState DetermineOverallState(IReadOnlyList<CiCheck> checks)
|
||||
{
|
||||
if (checks.Count == 0) return CiState.Unknown;
|
||||
if (checks.Any(c => c.State == CiState.Failure)) return CiState.Failure;
|
||||
if (checks.Any(c => c.State == CiState.Error)) return CiState.Error;
|
||||
if (checks.Any(c => c.State == CiState.Running)) return CiState.Running;
|
||||
if (checks.Any(c => c.State == CiState.Pending)) return CiState.Pending;
|
||||
if (checks.All(c => c.State == CiState.Success)) return CiState.Success;
|
||||
return CiState.Unknown;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,335 @@
|
||||
using System.Text.Json;
|
||||
using System.Web;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// GitLab SCM connector plugin.
|
||||
/// Supports gitlab.com and self-hosted GitLab instances.
|
||||
/// </summary>
|
||||
public sealed class GitLabScmConnectorPlugin : IScmConnectorPlugin
|
||||
{
|
||||
public string ScmType => "gitlab";
|
||||
public string DisplayName => "GitLab";
|
||||
|
||||
public bool IsAvailable(ScmConnectorOptions options) =>
|
||||
!string.IsNullOrEmpty(options.ApiToken);
|
||||
|
||||
public bool CanHandle(string repositoryUrl) =>
|
||||
repositoryUrl.Contains("gitlab.com", StringComparison.OrdinalIgnoreCase) ||
|
||||
repositoryUrl.Contains("gitlab.", StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
public IScmConnector Create(ScmConnectorOptions options, HttpClient httpClient) =>
|
||||
new GitLabScmConnector(httpClient, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// GitLab SCM connector implementation.
|
||||
/// API Reference: https://docs.gitlab.com/ee/api/rest/
|
||||
/// </summary>
|
||||
public sealed class GitLabScmConnector : ScmConnectorBase
|
||||
{
|
||||
private readonly string _baseUrl;
|
||||
|
||||
public GitLabScmConnector(HttpClient httpClient, ScmConnectorOptions options)
|
||||
: base(httpClient, options)
|
||||
{
|
||||
_baseUrl = options.BaseUrl ?? "https://gitlab.com/api/v4";
|
||||
}
|
||||
|
||||
public override string ScmType => "gitlab";
|
||||
|
||||
protected override void ConfigureAuthentication()
|
||||
{
|
||||
HttpClient.DefaultRequestHeaders.Add("PRIVATE-TOKEN", Options.ApiToken);
|
||||
}
|
||||
|
||||
private static string EncodeProjectPath(string owner, string repo) =>
|
||||
HttpUtility.UrlEncode($"{owner}/{repo}");
|
||||
|
||||
public override async Task<BranchResult> CreateBranchAsync(
|
||||
string owner, string repo, string branchName, string baseBranch,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
|
||||
var payload = new
|
||||
{
|
||||
branch = branchName,
|
||||
@ref = baseBranch
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/projects/{projectPath}/repository/branches",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
string? commitSha = null;
|
||||
if (success && result.ValueKind != JsonValueKind.Undefined &&
|
||||
result.TryGetProperty("commit", out var commit) &&
|
||||
commit.TryGetProperty("id", out var id))
|
||||
{
|
||||
commitSha = id.GetString();
|
||||
}
|
||||
|
||||
return new BranchResult
|
||||
{
|
||||
Success = success,
|
||||
BranchName = branchName,
|
||||
CommitSha = commitSha,
|
||||
ErrorMessage = success ? null : "Failed to create branch"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<FileUpdateResult> UpdateFileAsync(
|
||||
string owner, string repo, string branch, string filePath,
|
||||
string content, string commitMessage,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
var encodedPath = HttpUtility.UrlEncode(filePath);
|
||||
|
||||
// Check if file exists to determine create vs update action
|
||||
var existingFile = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/projects/{projectPath}/repository/files/{encodedPath}?ref={branch}",
|
||||
cancellationToken);
|
||||
|
||||
var action = existingFile.ValueKind != JsonValueKind.Undefined ? "update" : "create";
|
||||
|
||||
// Use commits API for file changes (more reliable for both create and update)
|
||||
var payload = new
|
||||
{
|
||||
branch,
|
||||
commit_message = commitMessage,
|
||||
actions = new[]
|
||||
{
|
||||
new
|
||||
{
|
||||
action,
|
||||
file_path = filePath,
|
||||
content
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/projects/{projectPath}/repository/commits",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
string? commitSha = null;
|
||||
if (success && result.ValueKind != JsonValueKind.Undefined && result.TryGetProperty("id", out var id))
|
||||
{
|
||||
commitSha = id.GetString();
|
||||
}
|
||||
|
||||
return new FileUpdateResult
|
||||
{
|
||||
Success = success,
|
||||
FilePath = filePath,
|
||||
CommitSha = commitSha,
|
||||
ErrorMessage = success ? null : "Failed to update file"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrCreateResult> CreatePullRequestAsync(
|
||||
string owner, string repo, string headBranch, string baseBranch,
|
||||
string title, string body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
|
||||
var payload = new
|
||||
{
|
||||
source_branch = headBranch,
|
||||
target_branch = baseBranch,
|
||||
title,
|
||||
description = body
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/projects/{projectPath}/merge_requests",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
if (!success || result.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = 0,
|
||||
PrUrl = string.Empty,
|
||||
ErrorMessage = "Failed to create merge request"
|
||||
};
|
||||
}
|
||||
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = result.GetProperty("iid").GetInt32(),
|
||||
PrUrl = result.GetProperty("web_url").GetString() ?? string.Empty
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrStatusResult> GetPullRequestStatusAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
|
||||
var mr = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/projects/{projectPath}/merge_requests/{prNumber}",
|
||||
cancellationToken);
|
||||
|
||||
if (mr.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = prNumber,
|
||||
State = PrState.Open,
|
||||
HeadSha = string.Empty,
|
||||
HeadBranch = string.Empty,
|
||||
BaseBranch = string.Empty,
|
||||
Title = string.Empty,
|
||||
Mergeable = false,
|
||||
ErrorMessage = "MR not found"
|
||||
};
|
||||
}
|
||||
|
||||
var state = mr.GetProperty("state").GetString() ?? "opened";
|
||||
var prState = state switch
|
||||
{
|
||||
"merged" => PrState.Merged,
|
||||
"closed" => PrState.Closed,
|
||||
_ => PrState.Open
|
||||
};
|
||||
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = prNumber,
|
||||
State = prState,
|
||||
HeadSha = mr.GetProperty("sha").GetString() ?? string.Empty,
|
||||
HeadBranch = mr.GetProperty("source_branch").GetString() ?? string.Empty,
|
||||
BaseBranch = mr.GetProperty("target_branch").GetString() ?? string.Empty,
|
||||
Title = mr.GetProperty("title").GetString() ?? string.Empty,
|
||||
Body = mr.TryGetProperty("description", out var d) ? d.GetString() : null,
|
||||
PrUrl = mr.GetProperty("web_url").GetString(),
|
||||
Mergeable = mr.TryGetProperty("merge_status", out var ms) &&
|
||||
ms.GetString() == "can_be_merged"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<CiStatusResult> GetCiStatusAsync(
|
||||
string owner, string repo, string commitSha,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
|
||||
// Get pipelines for the commit
|
||||
var pipelines = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/projects/{projectPath}/pipelines?sha={commitSha}",
|
||||
cancellationToken);
|
||||
|
||||
var checks = new List<CiCheck>();
|
||||
|
||||
if (pipelines.ValueKind == JsonValueKind.Array)
|
||||
{
|
||||
foreach (var pipeline in pipelines.EnumerateArray().Take(1)) // Most recent pipeline
|
||||
{
|
||||
var pipelineId = pipeline.GetProperty("id").GetInt32();
|
||||
var pipelineStatus = pipeline.GetProperty("status").GetString() ?? "pending";
|
||||
|
||||
// Get jobs for this pipeline
|
||||
var jobs = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/projects/{projectPath}/pipelines/{pipelineId}/jobs",
|
||||
cancellationToken);
|
||||
|
||||
if (jobs.ValueKind == JsonValueKind.Array)
|
||||
{
|
||||
foreach (var job in jobs.EnumerateArray())
|
||||
{
|
||||
checks.Add(new CiCheck
|
||||
{
|
||||
Name = job.GetProperty("name").GetString() ?? "unknown",
|
||||
State = MapToCiState(job.GetProperty("status").GetString() ?? "pending"),
|
||||
Description = job.TryGetProperty("stage", out var s) ? s.GetString() : null,
|
||||
TargetUrl = job.TryGetProperty("web_url", out var u) ? u.GetString() : null,
|
||||
StartedAt = job.TryGetProperty("started_at", out var sa) ? sa.GetString() : null,
|
||||
CompletedAt = job.TryGetProperty("finished_at", out var fa) ? fa.GetString() : null
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var overallState = checks.Count > 0 ? DetermineOverallState(checks) : CiState.Unknown;
|
||||
|
||||
return new CiStatusResult
|
||||
{
|
||||
Success = true,
|
||||
OverallState = overallState,
|
||||
Checks = checks
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<bool> UpdatePullRequestAsync(
|
||||
string owner, string repo, int prNumber, string? title, string? body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
var payload = new Dictionary<string, string>();
|
||||
if (title != null) payload["title"] = title;
|
||||
if (body != null) payload["description"] = body;
|
||||
|
||||
var request = new HttpRequestMessage(HttpMethod.Put,
|
||||
$"{_baseUrl}/projects/{projectPath}/merge_requests/{prNumber}")
|
||||
{
|
||||
Content = System.Net.Http.Json.JsonContent.Create(payload, options: JsonOptions)
|
||||
};
|
||||
|
||||
var response = await HttpClient.SendAsync(request, cancellationToken);
|
||||
return response.IsSuccessStatusCode;
|
||||
}
|
||||
|
||||
public override async Task<bool> AddCommentAsync(
|
||||
string owner, string repo, int prNumber, string comment,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
var payload = new { body = comment };
|
||||
var (success, _) = await PostJsonAsync(
|
||||
$"{_baseUrl}/projects/{projectPath}/merge_requests/{prNumber}/notes",
|
||||
payload,
|
||||
cancellationToken);
|
||||
return success;
|
||||
}
|
||||
|
||||
public override async Task<bool> ClosePullRequestAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var projectPath = EncodeProjectPath(owner, repo);
|
||||
var request = new HttpRequestMessage(HttpMethod.Put,
|
||||
$"{_baseUrl}/projects/{projectPath}/merge_requests/{prNumber}")
|
||||
{
|
||||
Content = System.Net.Http.Json.JsonContent.Create(
|
||||
new { state_event = "close" }, options: JsonOptions)
|
||||
};
|
||||
|
||||
var response = await HttpClient.SendAsync(request, cancellationToken);
|
||||
return response.IsSuccessStatusCode;
|
||||
}
|
||||
|
||||
private static CiState DetermineOverallState(IReadOnlyList<CiCheck> checks)
|
||||
{
|
||||
if (checks.Count == 0) return CiState.Unknown;
|
||||
if (checks.Any(c => c.State == CiState.Failure)) return CiState.Failure;
|
||||
if (checks.Any(c => c.State == CiState.Error)) return CiState.Error;
|
||||
if (checks.Any(c => c.State == CiState.Running)) return CiState.Running;
|
||||
if (checks.Any(c => c.State == CiState.Pending)) return CiState.Pending;
|
||||
if (checks.All(c => c.State == CiState.Success)) return CiState.Success;
|
||||
return CiState.Unknown;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,327 @@
|
||||
using System.Text.Json;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// Gitea SCM connector plugin.
|
||||
/// Supports Gitea and Forgejo instances.
|
||||
/// </summary>
|
||||
public sealed class GiteaScmConnectorPlugin : IScmConnectorPlugin
|
||||
{
|
||||
public string ScmType => "gitea";
|
||||
public string DisplayName => "Gitea";
|
||||
|
||||
public bool IsAvailable(ScmConnectorOptions options) =>
|
||||
!string.IsNullOrEmpty(options.ApiToken) &&
|
||||
!string.IsNullOrEmpty(options.BaseUrl);
|
||||
|
||||
public bool CanHandle(string repositoryUrl) =>
|
||||
// Gitea instances are self-hosted, so we rely on configuration
|
||||
// or explicit URL patterns
|
||||
repositoryUrl.Contains("gitea.", StringComparison.OrdinalIgnoreCase) ||
|
||||
repositoryUrl.Contains("forgejo.", StringComparison.OrdinalIgnoreCase) ||
|
||||
repositoryUrl.Contains("codeberg.org", StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
public IScmConnector Create(ScmConnectorOptions options, HttpClient httpClient) =>
|
||||
new GiteaScmConnector(httpClient, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gitea SCM connector implementation.
|
||||
/// API Reference: https://docs.gitea.io/en-us/api-usage/
|
||||
/// Also compatible with Forgejo and Codeberg.
|
||||
/// </summary>
|
||||
public sealed class GiteaScmConnector : ScmConnectorBase
|
||||
{
|
||||
private readonly string _baseUrl;
|
||||
|
||||
public GiteaScmConnector(HttpClient httpClient, ScmConnectorOptions options)
|
||||
: base(httpClient, options)
|
||||
{
|
||||
_baseUrl = options.BaseUrl?.TrimEnd('/') ?? throw new ArgumentNullException(
|
||||
nameof(options), "BaseUrl is required for Gitea connector");
|
||||
}
|
||||
|
||||
public override string ScmType => "gitea";
|
||||
|
||||
protected override void ConfigureAuthentication()
|
||||
{
|
||||
HttpClient.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("token", Options.ApiToken);
|
||||
}
|
||||
|
||||
public override async Task<BranchResult> CreateBranchAsync(
|
||||
string owner, string repo, string branchName, string baseBranch,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get base branch SHA
|
||||
var branchInfo = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/branches/{baseBranch}",
|
||||
cancellationToken);
|
||||
|
||||
if (branchInfo.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new BranchResult
|
||||
{
|
||||
Success = false,
|
||||
BranchName = branchName,
|
||||
ErrorMessage = $"Base branch '{baseBranch}' not found"
|
||||
};
|
||||
}
|
||||
|
||||
var baseSha = branchInfo.GetProperty("commit").GetProperty("id").GetString();
|
||||
|
||||
// Create new branch
|
||||
var payload = new
|
||||
{
|
||||
new_branch_name = branchName,
|
||||
old_ref_name = baseBranch
|
||||
};
|
||||
|
||||
var (success, _) = await PostJsonAsync(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/branches",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
return new BranchResult
|
||||
{
|
||||
Success = success,
|
||||
BranchName = branchName,
|
||||
CommitSha = baseSha,
|
||||
ErrorMessage = success ? null : "Failed to create branch"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<FileUpdateResult> UpdateFileAsync(
|
||||
string owner, string repo, string branch, string filePath,
|
||||
string content, string commitMessage,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Check if file exists to get SHA
|
||||
var existingFile = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/contents/{filePath}?ref={branch}",
|
||||
cancellationToken);
|
||||
|
||||
string? fileSha = null;
|
||||
if (existingFile.ValueKind != JsonValueKind.Undefined &&
|
||||
existingFile.TryGetProperty("sha", out var sha))
|
||||
{
|
||||
fileSha = sha.GetString();
|
||||
}
|
||||
|
||||
// Update or create file
|
||||
var payload = new
|
||||
{
|
||||
message = commitMessage,
|
||||
content = Base64Encode(content),
|
||||
branch,
|
||||
sha = fileSha
|
||||
};
|
||||
|
||||
var (success, result) = await PutJsonAsync(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/contents/{filePath}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
string? commitSha = null;
|
||||
if (success && result.ValueKind != JsonValueKind.Undefined &&
|
||||
result.TryGetProperty("commit", out var commit) &&
|
||||
commit.TryGetProperty("sha", out var csha))
|
||||
{
|
||||
commitSha = csha.GetString();
|
||||
}
|
||||
|
||||
return new FileUpdateResult
|
||||
{
|
||||
Success = success,
|
||||
FilePath = filePath,
|
||||
CommitSha = commitSha,
|
||||
ErrorMessage = success ? null : "Failed to update file"
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrCreateResult> CreatePullRequestAsync(
|
||||
string owner, string repo, string headBranch, string baseBranch,
|
||||
string title, string body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new
|
||||
{
|
||||
title,
|
||||
body,
|
||||
head = headBranch,
|
||||
@base = baseBranch
|
||||
};
|
||||
|
||||
var (success, result) = await PostJsonAsync(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/pulls",
|
||||
payload,
|
||||
cancellationToken);
|
||||
|
||||
if (!success || result.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = 0,
|
||||
PrUrl = string.Empty,
|
||||
ErrorMessage = "Failed to create pull request"
|
||||
};
|
||||
}
|
||||
|
||||
return new PrCreateResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = result.GetProperty("number").GetInt32(),
|
||||
PrUrl = result.GetProperty("html_url").GetString() ?? string.Empty
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<PrStatusResult> GetPullRequestStatusAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var pr = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/pulls/{prNumber}",
|
||||
cancellationToken);
|
||||
|
||||
if (pr.ValueKind == JsonValueKind.Undefined)
|
||||
{
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = false,
|
||||
PrNumber = prNumber,
|
||||
State = PrState.Open,
|
||||
HeadSha = string.Empty,
|
||||
HeadBranch = string.Empty,
|
||||
BaseBranch = string.Empty,
|
||||
Title = string.Empty,
|
||||
Mergeable = false,
|
||||
ErrorMessage = "PR not found"
|
||||
};
|
||||
}
|
||||
|
||||
var state = pr.GetProperty("state").GetString() ?? "open";
|
||||
var merged = pr.TryGetProperty("merged", out var m) && m.GetBoolean();
|
||||
|
||||
return new PrStatusResult
|
||||
{
|
||||
Success = true,
|
||||
PrNumber = prNumber,
|
||||
State = merged ? PrState.Merged : state == "closed" ? PrState.Closed : PrState.Open,
|
||||
HeadSha = pr.GetProperty("head").GetProperty("sha").GetString() ?? string.Empty,
|
||||
HeadBranch = pr.GetProperty("head").GetProperty("ref").GetString() ?? string.Empty,
|
||||
BaseBranch = pr.GetProperty("base").GetProperty("ref").GetString() ?? string.Empty,
|
||||
Title = pr.GetProperty("title").GetString() ?? string.Empty,
|
||||
Body = pr.TryGetProperty("body", out var b) ? b.GetString() : null,
|
||||
PrUrl = pr.GetProperty("html_url").GetString(),
|
||||
Mergeable = pr.TryGetProperty("mergeable", out var mg) && mg.GetBoolean()
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<CiStatusResult> GetCiStatusAsync(
|
||||
string owner, string repo, string commitSha,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Get combined commit status (from Gitea Actions and external CI)
|
||||
var status = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/commits/{commitSha}/status",
|
||||
cancellationToken);
|
||||
|
||||
var checks = new List<CiCheck>();
|
||||
|
||||
if (status.ValueKind != JsonValueKind.Undefined &&
|
||||
status.TryGetProperty("statuses", out var statuses))
|
||||
{
|
||||
foreach (var s in statuses.EnumerateArray())
|
||||
{
|
||||
checks.Add(new CiCheck
|
||||
{
|
||||
Name = s.GetProperty("context").GetString() ?? "unknown",
|
||||
State = MapToCiState(s.GetProperty("status").GetString() ?? "pending"),
|
||||
Description = s.TryGetProperty("description", out var d) ? d.GetString() : null,
|
||||
TargetUrl = s.TryGetProperty("target_url", out var u) ? u.GetString() : null
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Also get workflow runs if available (Gitea Actions)
|
||||
var runs = await GetJsonAsync<JsonElement>(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/actions/runs?head_sha={commitSha}",
|
||||
cancellationToken);
|
||||
|
||||
if (runs.ValueKind != JsonValueKind.Undefined &&
|
||||
runs.TryGetProperty("workflow_runs", out var workflowRuns))
|
||||
{
|
||||
foreach (var run in workflowRuns.EnumerateArray())
|
||||
{
|
||||
var conclusion = run.TryGetProperty("conclusion", out var c) ? c.GetString() : null;
|
||||
var runStatus = run.GetProperty("status").GetString() ?? "queued";
|
||||
|
||||
checks.Add(new CiCheck
|
||||
{
|
||||
Name = run.GetProperty("name").GetString() ?? "workflow",
|
||||
State = conclusion != null ? MapToCiState(conclusion) : MapToCiState(runStatus),
|
||||
TargetUrl = run.TryGetProperty("html_url", out var u) ? u.GetString() : null,
|
||||
StartedAt = run.TryGetProperty("run_started_at", out var sa) ? sa.GetString() : null
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
var overallState = checks.Count > 0 ? DetermineOverallState(checks) : CiState.Unknown;
|
||||
|
||||
return new CiStatusResult
|
||||
{
|
||||
Success = true,
|
||||
OverallState = overallState,
|
||||
Checks = checks
|
||||
};
|
||||
}
|
||||
|
||||
public override async Task<bool> UpdatePullRequestAsync(
|
||||
string owner, string repo, int prNumber, string? title, string? body,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new Dictionary<string, string>();
|
||||
if (title != null) payload["title"] = title;
|
||||
if (body != null) payload["body"] = body;
|
||||
|
||||
return await PatchJsonAsync(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/pulls/{prNumber}",
|
||||
payload,
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
public override async Task<bool> AddCommentAsync(
|
||||
string owner, string repo, int prNumber, string comment,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var payload = new { body = comment };
|
||||
var (success, _) = await PostJsonAsync(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/issues/{prNumber}/comments",
|
||||
payload,
|
||||
cancellationToken);
|
||||
return success;
|
||||
}
|
||||
|
||||
public override async Task<bool> ClosePullRequestAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return await PatchJsonAsync(
|
||||
$"{_baseUrl}/api/v1/repos/{owner}/{repo}/pulls/{prNumber}",
|
||||
new { state = "closed" },
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
private static CiState DetermineOverallState(IReadOnlyList<CiCheck> checks)
|
||||
{
|
||||
if (checks.Count == 0) return CiState.Unknown;
|
||||
if (checks.Any(c => c.State == CiState.Failure)) return CiState.Failure;
|
||||
if (checks.Any(c => c.State == CiState.Error)) return CiState.Error;
|
||||
if (checks.Any(c => c.State == CiState.Running)) return CiState.Running;
|
||||
if (checks.Any(c => c.State == CiState.Pending)) return CiState.Pending;
|
||||
if (checks.All(c => c.State == CiState.Success)) return CiState.Success;
|
||||
return CiState.Unknown;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// SCM connector plugin interface for customer premise integrations.
|
||||
/// Follows the StellaOps plugin pattern (IConnectorPlugin).
|
||||
/// Sprint: SPRINT_20251226_016_AI_remedy_autopilot
|
||||
/// Task: REMEDY-12, REMEDY-13, REMEDY-14
|
||||
/// </summary>
|
||||
public interface IScmConnectorPlugin
|
||||
{
|
||||
/// <summary>
|
||||
/// Unique identifier for this SCM type.
|
||||
/// </summary>
|
||||
string ScmType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Display name for this SCM.
|
||||
/// </summary>
|
||||
string DisplayName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Check if this connector is available with current configuration.
|
||||
/// </summary>
|
||||
bool IsAvailable(ScmConnectorOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// Check if this connector can handle the given repository URL.
|
||||
/// </summary>
|
||||
bool CanHandle(string repositoryUrl);
|
||||
|
||||
/// <summary>
|
||||
/// Create a connector instance for the given options.
|
||||
/// </summary>
|
||||
IScmConnector Create(ScmConnectorOptions options, HttpClient httpClient);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Core SCM connector interface for PR operations.
|
||||
/// </summary>
|
||||
public interface IScmConnector
|
||||
{
|
||||
/// <summary>
|
||||
/// SCM type identifier.
|
||||
/// </summary>
|
||||
string ScmType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Create a branch from the base branch.
|
||||
/// </summary>
|
||||
Task<BranchResult> CreateBranchAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
string branchName,
|
||||
string baseBranch,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Update or create a file in a branch.
|
||||
/// </summary>
|
||||
Task<FileUpdateResult> UpdateFileAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
string branch,
|
||||
string filePath,
|
||||
string content,
|
||||
string commitMessage,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Create a pull request / merge request.
|
||||
/// </summary>
|
||||
Task<PrCreateResult> CreatePullRequestAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
string headBranch,
|
||||
string baseBranch,
|
||||
string title,
|
||||
string body,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Get pull request details and status.
|
||||
/// </summary>
|
||||
Task<PrStatusResult> GetPullRequestStatusAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
int prNumber,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Get CI/CD pipeline status for a commit.
|
||||
/// </summary>
|
||||
Task<CiStatusResult> GetCiStatusAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
string commitSha,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Update pull request body/description.
|
||||
/// </summary>
|
||||
Task<bool> UpdatePullRequestAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
int prNumber,
|
||||
string? title,
|
||||
string? body,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Add a comment to a pull request.
|
||||
/// </summary>
|
||||
Task<bool> AddCommentAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
int prNumber,
|
||||
string comment,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Close a pull request without merging.
|
||||
/// </summary>
|
||||
Task<bool> ClosePullRequestAsync(
|
||||
string owner,
|
||||
string repo,
|
||||
int prNumber,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Configuration options for SCM connectors.
|
||||
/// </summary>
|
||||
public sealed record ScmConnectorOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// SCM server base URL (for self-hosted instances).
|
||||
/// </summary>
|
||||
public string? BaseUrl { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Authentication token (PAT, OAuth token, etc.).
|
||||
/// </summary>
|
||||
public string? ApiToken { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// OAuth client ID (for OAuth flow).
|
||||
/// </summary>
|
||||
public string? ClientId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// OAuth client secret (for OAuth flow).
|
||||
/// </summary>
|
||||
public string? ClientSecret { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Default base branch for PRs.
|
||||
/// </summary>
|
||||
public string DefaultBaseBranch { get; init; } = "main";
|
||||
|
||||
/// <summary>
|
||||
/// Request timeout in seconds.
|
||||
/// </summary>
|
||||
public int TimeoutSeconds { get; init; } = 30;
|
||||
|
||||
/// <summary>
|
||||
/// User agent string for API requests.
|
||||
/// </summary>
|
||||
public string UserAgent { get; init; } = "StellaOps-Remedy/1.0";
|
||||
}
|
||||
|
||||
#region Result Types
|
||||
|
||||
/// <summary>
|
||||
/// Result of creating a branch.
|
||||
/// </summary>
|
||||
public sealed record BranchResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required string BranchName { get; init; }
|
||||
public string? CommitSha { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of updating a file.
|
||||
/// </summary>
|
||||
public sealed record FileUpdateResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required string FilePath { get; init; }
|
||||
public string? CommitSha { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of creating a PR.
|
||||
/// </summary>
|
||||
public sealed record PrCreateResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required int PrNumber { get; init; }
|
||||
public required string PrUrl { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// PR status result.
|
||||
/// </summary>
|
||||
public sealed record PrStatusResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required int PrNumber { get; init; }
|
||||
public required PrState State { get; init; }
|
||||
public required string HeadSha { get; init; }
|
||||
public required string HeadBranch { get; init; }
|
||||
public required string BaseBranch { get; init; }
|
||||
public required string Title { get; init; }
|
||||
public string? Body { get; init; }
|
||||
public string? PrUrl { get; init; }
|
||||
public required bool Mergeable { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// PR state.
|
||||
/// </summary>
|
||||
public enum PrState
|
||||
{
|
||||
Open,
|
||||
Closed,
|
||||
Merged,
|
||||
Draft
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// CI status result.
|
||||
/// </summary>
|
||||
public sealed record CiStatusResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required CiState OverallState { get; init; }
|
||||
public required IReadOnlyList<CiCheck> Checks { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Overall CI state.
|
||||
/// </summary>
|
||||
public enum CiState
|
||||
{
|
||||
Pending,
|
||||
Running,
|
||||
Success,
|
||||
Failure,
|
||||
Error,
|
||||
Unknown
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Individual CI check.
|
||||
/// </summary>
|
||||
public sealed record CiCheck
|
||||
{
|
||||
public required string Name { get; init; }
|
||||
public required CiState State { get; init; }
|
||||
public string? Description { get; init; }
|
||||
public string? TargetUrl { get; init; }
|
||||
public string? StartedAt { get; init; }
|
||||
public string? CompletedAt { get; init; }
|
||||
}
|
||||
|
||||
#endregion
|
||||
@@ -0,0 +1,159 @@
|
||||
using System.Net.Http.Json;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for SCM connectors with shared HTTP and JSON handling.
|
||||
/// </summary>
|
||||
public abstract class ScmConnectorBase : IScmConnector
|
||||
{
|
||||
protected readonly HttpClient HttpClient;
|
||||
protected readonly ScmConnectorOptions Options;
|
||||
|
||||
protected static readonly JsonSerializerOptions JsonOptions = new()
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
|
||||
PropertyNameCaseInsensitive = true,
|
||||
WriteIndented = false
|
||||
};
|
||||
|
||||
protected ScmConnectorBase(HttpClient httpClient, ScmConnectorOptions options)
|
||||
{
|
||||
HttpClient = httpClient;
|
||||
Options = options;
|
||||
ConfigureHttpClient();
|
||||
}
|
||||
|
||||
public abstract string ScmType { get; }
|
||||
|
||||
protected virtual void ConfigureHttpClient()
|
||||
{
|
||||
HttpClient.Timeout = TimeSpan.FromSeconds(Options.TimeoutSeconds);
|
||||
HttpClient.DefaultRequestHeaders.UserAgent.ParseAdd(Options.UserAgent);
|
||||
|
||||
if (!string.IsNullOrEmpty(Options.ApiToken))
|
||||
{
|
||||
ConfigureAuthentication();
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract void ConfigureAuthentication();
|
||||
|
||||
public abstract Task<BranchResult> CreateBranchAsync(
|
||||
string owner, string repo, string branchName, string baseBranch,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<FileUpdateResult> UpdateFileAsync(
|
||||
string owner, string repo, string branch, string filePath,
|
||||
string content, string commitMessage,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<PrCreateResult> CreatePullRequestAsync(
|
||||
string owner, string repo, string headBranch, string baseBranch,
|
||||
string title, string body,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<PrStatusResult> GetPullRequestStatusAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<CiStatusResult> GetCiStatusAsync(
|
||||
string owner, string repo, string commitSha,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<bool> UpdatePullRequestAsync(
|
||||
string owner, string repo, int prNumber, string? title, string? body,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<bool> AddCommentAsync(
|
||||
string owner, string repo, int prNumber, string comment,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
public abstract Task<bool> ClosePullRequestAsync(
|
||||
string owner, string repo, int prNumber,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
protected async Task<T?> GetJsonAsync<T>(string url, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
var response = await HttpClient.GetAsync(url, cancellationToken);
|
||||
if (!response.IsSuccessStatusCode) return default;
|
||||
return await response.Content.ReadFromJsonAsync<T>(JsonOptions, cancellationToken);
|
||||
}
|
||||
catch
|
||||
{
|
||||
return default;
|
||||
}
|
||||
}
|
||||
|
||||
protected async Task<(bool Success, JsonElement Result)> PostJsonAsync(
|
||||
string url, object payload, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
var response = await HttpClient.PostAsJsonAsync(url, payload, JsonOptions, cancellationToken);
|
||||
if (!response.IsSuccessStatusCode)
|
||||
return (false, default);
|
||||
var result = await response.Content.ReadFromJsonAsync<JsonElement>(JsonOptions, cancellationToken);
|
||||
return (true, result);
|
||||
}
|
||||
catch
|
||||
{
|
||||
return (false, default);
|
||||
}
|
||||
}
|
||||
|
||||
protected async Task<bool> PatchJsonAsync(string url, object payload, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
var request = new HttpRequestMessage(HttpMethod.Patch, url)
|
||||
{
|
||||
Content = JsonContent.Create(payload, options: JsonOptions)
|
||||
};
|
||||
var response = await HttpClient.SendAsync(request, cancellationToken);
|
||||
return response.IsSuccessStatusCode;
|
||||
}
|
||||
catch
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
protected async Task<(bool Success, JsonElement Result)> PutJsonAsync(
|
||||
string url, object payload, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
var response = await HttpClient.PutAsJsonAsync(url, payload, JsonOptions, cancellationToken);
|
||||
if (!response.IsSuccessStatusCode)
|
||||
return (false, default);
|
||||
var result = await response.Content.ReadFromJsonAsync<JsonElement>(JsonOptions, cancellationToken);
|
||||
return (true, result);
|
||||
}
|
||||
catch
|
||||
{
|
||||
return (false, default);
|
||||
}
|
||||
}
|
||||
|
||||
protected static string Base64Encode(string content) =>
|
||||
Convert.ToBase64String(Encoding.UTF8.GetBytes(content));
|
||||
|
||||
protected static CiState MapToCiState(string state) => state.ToLowerInvariant() switch
|
||||
{
|
||||
"pending" or "queued" or "waiting" => CiState.Pending,
|
||||
"in_progress" or "running" => CiState.Running,
|
||||
"success" or "succeeded" or "completed" => CiState.Success,
|
||||
"failure" or "failed" => CiState.Failure,
|
||||
"error" or "cancelled" or "canceled" or "timed_out" => CiState.Error,
|
||||
_ => CiState.Unknown
|
||||
};
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Remediation.ScmConnector;
|
||||
|
||||
/// <summary>
|
||||
/// Catalog and factory for SCM connector plugins.
|
||||
/// Discovers and manages available SCM connectors for customer premise integrations.
|
||||
/// </summary>
|
||||
public sealed class ScmConnectorCatalog
|
||||
{
|
||||
private readonly IReadOnlyList<IScmConnectorPlugin> _plugins;
|
||||
private readonly IHttpClientFactory _httpClientFactory;
|
||||
|
||||
/// <summary>
|
||||
/// Create a catalog with default plugins (GitHub, GitLab, AzureDevOps, Gitea).
|
||||
/// </summary>
|
||||
public ScmConnectorCatalog(IHttpClientFactory httpClientFactory)
|
||||
{
|
||||
_httpClientFactory = httpClientFactory;
|
||||
_plugins = new List<IScmConnectorPlugin>
|
||||
{
|
||||
new GitHubScmConnectorPlugin(),
|
||||
new GitLabScmConnectorPlugin(),
|
||||
new AzureDevOpsScmConnectorPlugin(),
|
||||
new GiteaScmConnectorPlugin()
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a catalog with custom plugins.
|
||||
/// </summary>
|
||||
public ScmConnectorCatalog(
|
||||
IHttpClientFactory httpClientFactory,
|
||||
IEnumerable<IScmConnectorPlugin> plugins)
|
||||
{
|
||||
_httpClientFactory = httpClientFactory;
|
||||
_plugins = plugins.ToList();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get all registered plugins.
|
||||
/// </summary>
|
||||
public IReadOnlyList<IScmConnectorPlugin> Plugins => _plugins;
|
||||
|
||||
/// <summary>
|
||||
/// Get available plugins based on provided options.
|
||||
/// </summary>
|
||||
public IEnumerable<IScmConnectorPlugin> GetAvailablePlugins(ScmConnectorOptions options)
|
||||
{
|
||||
return _plugins.Where(p => p.IsAvailable(options));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get a connector by explicit SCM type.
|
||||
/// </summary>
|
||||
public IScmConnector? GetConnector(string scmType, ScmConnectorOptions options)
|
||||
{
|
||||
var plugin = _plugins.FirstOrDefault(p =>
|
||||
p.ScmType.Equals(scmType, StringComparison.OrdinalIgnoreCase));
|
||||
|
||||
if (plugin is null || !plugin.IsAvailable(options))
|
||||
return null;
|
||||
|
||||
var httpClient = CreateHttpClient(scmType, options);
|
||||
return plugin.Create(options, httpClient);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Auto-detect SCM type from repository URL and create connector.
|
||||
/// </summary>
|
||||
public IScmConnector? GetConnectorForRepository(string repositoryUrl, ScmConnectorOptions options)
|
||||
{
|
||||
var plugin = _plugins.FirstOrDefault(p => p.CanHandle(repositoryUrl));
|
||||
|
||||
if (plugin is null || !plugin.IsAvailable(options))
|
||||
return null;
|
||||
|
||||
var httpClient = CreateHttpClient(plugin.ScmType, options);
|
||||
return plugin.Create(options, httpClient);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a connector with explicit options override.
|
||||
/// </summary>
|
||||
public IScmConnector? GetConnector(
|
||||
string scmType,
|
||||
ScmConnectorOptions baseOptions,
|
||||
Action<ScmConnectorOptions>? configure)
|
||||
{
|
||||
var options = baseOptions with { };
|
||||
configure?.Invoke(options);
|
||||
return GetConnector(scmType, options);
|
||||
}
|
||||
|
||||
private HttpClient CreateHttpClient(string scmType, ScmConnectorOptions options)
|
||||
{
|
||||
var httpClient = _httpClientFactory.CreateClient($"ScmConnector_{scmType}");
|
||||
|
||||
if (!string.IsNullOrEmpty(options.BaseUrl))
|
||||
{
|
||||
httpClient.BaseAddress = new Uri(options.BaseUrl);
|
||||
}
|
||||
|
||||
return httpClient;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for dependency injection registration.
|
||||
/// </summary>
|
||||
public static class ScmConnectorServiceExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Add SCM connector services to the service collection.
|
||||
/// </summary>
|
||||
public static IServiceCollection AddScmConnectors(
|
||||
this IServiceCollection services,
|
||||
Action<ScmConnectorRegistration>? configure = null)
|
||||
{
|
||||
var registration = new ScmConnectorRegistration();
|
||||
configure?.Invoke(registration);
|
||||
|
||||
// Register HTTP clients for each SCM type
|
||||
services.AddHttpClient("ScmConnector_github");
|
||||
services.AddHttpClient("ScmConnector_gitlab");
|
||||
services.AddHttpClient("ScmConnector_azuredevops");
|
||||
services.AddHttpClient("ScmConnector_gitea");
|
||||
|
||||
// Register plugins
|
||||
foreach (var plugin in registration.Plugins)
|
||||
{
|
||||
services.AddSingleton(plugin);
|
||||
}
|
||||
|
||||
// Register the catalog
|
||||
services.AddSingleton<ScmConnectorCatalog>(sp =>
|
||||
{
|
||||
var httpClientFactory = sp.GetRequiredService<IHttpClientFactory>();
|
||||
var plugins = sp.GetServices<IScmConnectorPlugin>();
|
||||
return new ScmConnectorCatalog(httpClientFactory, plugins);
|
||||
});
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Registration builder for SCM connectors.
|
||||
/// </summary>
|
||||
public sealed class ScmConnectorRegistration
|
||||
{
|
||||
private readonly List<IScmConnectorPlugin> _plugins = new()
|
||||
{
|
||||
new GitHubScmConnectorPlugin(),
|
||||
new GitLabScmConnectorPlugin(),
|
||||
new AzureDevOpsScmConnectorPlugin(),
|
||||
new GiteaScmConnectorPlugin()
|
||||
};
|
||||
|
||||
public IReadOnlyList<IScmConnectorPlugin> Plugins => _plugins;
|
||||
|
||||
/// <summary>
|
||||
/// Add a custom SCM connector plugin.
|
||||
/// </summary>
|
||||
public ScmConnectorRegistration AddPlugin(IScmConnectorPlugin plugin)
|
||||
{
|
||||
_plugins.Add(plugin);
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Remove a built-in plugin by SCM type.
|
||||
/// </summary>
|
||||
public ScmConnectorRegistration RemovePlugin(string scmType)
|
||||
{
|
||||
_plugins.RemoveAll(p => p.ScmType.Equals(scmType, StringComparison.OrdinalIgnoreCase));
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Clear all plugins.
|
||||
/// </summary>
|
||||
public ScmConnectorRegistration ClearPlugins()
|
||||
{
|
||||
_plugins.Clear();
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
459
src/AdvisoryAI/StellaOps.AdvisoryAI/Replay/AIArtifactReplayer.cs
Normal file
459
src/AdvisoryAI/StellaOps.AdvisoryAI/Replay/AIArtifactReplayer.cs
Normal file
@@ -0,0 +1,459 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
|
||||
namespace StellaOps.AdvisoryAI.Replay;
|
||||
|
||||
/// <summary>
|
||||
/// Replays AI artifact generation with deterministic verification.
|
||||
/// Sprint: SPRINT_20251226_019_AI_offline_inference
|
||||
/// Task: OFFLINE-18, OFFLINE-19
|
||||
/// </summary>
|
||||
public interface IAIArtifactReplayer
|
||||
{
|
||||
/// <summary>
|
||||
/// Replay an AI artifact generation from its manifest.
|
||||
/// </summary>
|
||||
Task<ReplayResult> ReplayAsync(
|
||||
AIArtifactReplayManifest manifest,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Detect divergence between original and replayed output.
|
||||
/// </summary>
|
||||
Task<DivergenceResult> DetectDivergenceAsync(
|
||||
AIArtifactReplayManifest originalManifest,
|
||||
string replayedOutput,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Verify a replay is identical to original.
|
||||
/// </summary>
|
||||
Task<ReplayVerificationResult> VerifyReplayAsync(
|
||||
AIArtifactReplayManifest manifest,
|
||||
CancellationToken cancellationToken = default);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Manifest for replaying AI artifacts.
|
||||
/// Sprint: SPRINT_20251226_018_AI_attestations
|
||||
/// Task: AIATTEST-18
|
||||
/// </summary>
|
||||
public sealed record AIArtifactReplayManifest
|
||||
{
|
||||
/// <summary>
|
||||
/// Unique artifact ID.
|
||||
/// </summary>
|
||||
public required string ArtifactId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Artifact type (explanation, remediation, vex_draft, policy_draft).
|
||||
/// </summary>
|
||||
public required string ArtifactType { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Model identifier used for generation.
|
||||
/// </summary>
|
||||
public required string ModelId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Weights digest (for local models).
|
||||
/// </summary>
|
||||
public string? WeightsDigest { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Prompt template version.
|
||||
/// </summary>
|
||||
public required string PromptTemplateVersion { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// System prompt used.
|
||||
/// </summary>
|
||||
public required string SystemPrompt { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// User prompt used.
|
||||
/// </summary>
|
||||
public required string UserPrompt { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Temperature (should be 0 for determinism).
|
||||
/// </summary>
|
||||
public required double Temperature { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Random seed for reproducibility.
|
||||
/// </summary>
|
||||
public required int Seed { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Maximum tokens.
|
||||
/// </summary>
|
||||
public required int MaxTokens { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Input hashes for verification.
|
||||
/// </summary>
|
||||
public required IReadOnlyList<string> InputHashes { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Original output hash.
|
||||
/// </summary>
|
||||
public required string OutputHash { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Original output content.
|
||||
/// </summary>
|
||||
public required string OutputContent { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Generation timestamp.
|
||||
/// </summary>
|
||||
public required string GeneratedAt { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of a replay operation.
|
||||
/// </summary>
|
||||
public sealed record ReplayResult
|
||||
{
|
||||
public required bool Success { get; init; }
|
||||
public required string ReplayedOutput { get; init; }
|
||||
public required string ReplayedOutputHash { get; init; }
|
||||
public required bool Identical { get; init; }
|
||||
public required TimeSpan Duration { get; init; }
|
||||
public string? ErrorMessage { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of divergence detection.
|
||||
/// </summary>
|
||||
public sealed record DivergenceResult
|
||||
{
|
||||
public required bool Diverged { get; init; }
|
||||
public required double SimilarityScore { get; init; }
|
||||
public required IReadOnlyList<DivergenceDetail> Details { get; init; }
|
||||
public required string OriginalHash { get; init; }
|
||||
public required string ReplayedHash { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Details of a divergence.
|
||||
/// </summary>
|
||||
public sealed record DivergenceDetail
|
||||
{
|
||||
public required string Type { get; init; }
|
||||
public required string Description { get; init; }
|
||||
public int? Position { get; init; }
|
||||
public string? OriginalSnippet { get; init; }
|
||||
public string? ReplayedSnippet { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of replay verification.
|
||||
/// </summary>
|
||||
public sealed record ReplayVerificationResult
|
||||
{
|
||||
public required bool Verified { get; init; }
|
||||
public required bool OutputIdentical { get; init; }
|
||||
public required bool InputHashesValid { get; init; }
|
||||
public required bool ModelAvailable { get; init; }
|
||||
public IReadOnlyList<string>? ValidationErrors { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default implementation of AI artifact replayer.
|
||||
/// </summary>
|
||||
public sealed class AIArtifactReplayer : IAIArtifactReplayer
|
||||
{
|
||||
private readonly ILlmProvider _provider;
|
||||
|
||||
public AIArtifactReplayer(ILlmProvider provider)
|
||||
{
|
||||
_provider = provider;
|
||||
}
|
||||
|
||||
public async Task<ReplayResult> ReplayAsync(
|
||||
AIArtifactReplayManifest manifest,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var startTime = DateTime.UtcNow;
|
||||
|
||||
try
|
||||
{
|
||||
// Validate determinism requirements
|
||||
if (manifest.Temperature != 0)
|
||||
{
|
||||
return new ReplayResult
|
||||
{
|
||||
Success = false,
|
||||
ReplayedOutput = string.Empty,
|
||||
ReplayedOutputHash = string.Empty,
|
||||
Identical = false,
|
||||
Duration = DateTime.UtcNow - startTime,
|
||||
ErrorMessage = "Replay requires temperature=0 for determinism"
|
||||
};
|
||||
}
|
||||
|
||||
// Check model availability
|
||||
if (!await _provider.IsAvailableAsync(cancellationToken))
|
||||
{
|
||||
return new ReplayResult
|
||||
{
|
||||
Success = false,
|
||||
ReplayedOutput = string.Empty,
|
||||
ReplayedOutputHash = string.Empty,
|
||||
Identical = false,
|
||||
Duration = DateTime.UtcNow - startTime,
|
||||
ErrorMessage = $"Model {manifest.ModelId} is not available"
|
||||
};
|
||||
}
|
||||
|
||||
// Create request with same parameters
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
SystemPrompt = manifest.SystemPrompt,
|
||||
UserPrompt = manifest.UserPrompt,
|
||||
Model = manifest.ModelId,
|
||||
Temperature = manifest.Temperature,
|
||||
Seed = manifest.Seed,
|
||||
MaxTokens = manifest.MaxTokens,
|
||||
RequestId = $"replay-{manifest.ArtifactId}"
|
||||
};
|
||||
|
||||
// Execute inference
|
||||
var result = await _provider.CompleteAsync(request, cancellationToken);
|
||||
var replayedHash = ComputeHash(result.Content);
|
||||
var identical = string.Equals(replayedHash, manifest.OutputHash, StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
return new ReplayResult
|
||||
{
|
||||
Success = true,
|
||||
ReplayedOutput = result.Content,
|
||||
ReplayedOutputHash = replayedHash,
|
||||
Identical = identical,
|
||||
Duration = DateTime.UtcNow - startTime
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return new ReplayResult
|
||||
{
|
||||
Success = false,
|
||||
ReplayedOutput = string.Empty,
|
||||
ReplayedOutputHash = string.Empty,
|
||||
Identical = false,
|
||||
Duration = DateTime.UtcNow - startTime,
|
||||
ErrorMessage = ex.Message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public Task<DivergenceResult> DetectDivergenceAsync(
|
||||
AIArtifactReplayManifest originalManifest,
|
||||
string replayedOutput,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var originalHash = originalManifest.OutputHash;
|
||||
var replayedHash = ComputeHash(replayedOutput);
|
||||
var identical = string.Equals(originalHash, replayedHash, StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
if (identical)
|
||||
{
|
||||
return Task.FromResult(new DivergenceResult
|
||||
{
|
||||
Diverged = false,
|
||||
SimilarityScore = 1.0,
|
||||
Details = Array.Empty<DivergenceDetail>(),
|
||||
OriginalHash = originalHash,
|
||||
ReplayedHash = replayedHash
|
||||
});
|
||||
}
|
||||
|
||||
// Analyze divergence
|
||||
var details = new List<DivergenceDetail>();
|
||||
var original = originalManifest.OutputContent;
|
||||
|
||||
// Check length difference
|
||||
if (original.Length != replayedOutput.Length)
|
||||
{
|
||||
details.Add(new DivergenceDetail
|
||||
{
|
||||
Type = "length_mismatch",
|
||||
Description = $"Length differs: original={original.Length}, replayed={replayedOutput.Length}"
|
||||
});
|
||||
}
|
||||
|
||||
// Find first divergence point
|
||||
var minLen = Math.Min(original.Length, replayedOutput.Length);
|
||||
var firstDiff = -1;
|
||||
for (var i = 0; i < minLen; i++)
|
||||
{
|
||||
if (original[i] != replayedOutput[i])
|
||||
{
|
||||
firstDiff = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (firstDiff >= 0)
|
||||
{
|
||||
var snippetLen = Math.Min(50, original.Length - firstDiff);
|
||||
var replayedSnippetLen = Math.Min(50, replayedOutput.Length - firstDiff);
|
||||
|
||||
details.Add(new DivergenceDetail
|
||||
{
|
||||
Type = "content_divergence",
|
||||
Description = "Content differs at position",
|
||||
Position = firstDiff,
|
||||
OriginalSnippet = original.Substring(firstDiff, snippetLen),
|
||||
ReplayedSnippet = replayedOutput.Substring(firstDiff, replayedSnippetLen)
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate similarity score using Levenshtein distance ratio
|
||||
var similarity = CalculateSimilarity(original, replayedOutput);
|
||||
|
||||
return Task.FromResult(new DivergenceResult
|
||||
{
|
||||
Diverged = true,
|
||||
SimilarityScore = similarity,
|
||||
Details = details,
|
||||
OriginalHash = originalHash,
|
||||
ReplayedHash = replayedHash
|
||||
});
|
||||
}
|
||||
|
||||
public async Task<ReplayVerificationResult> VerifyReplayAsync(
|
||||
AIArtifactReplayManifest manifest,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var errors = new List<string>();
|
||||
|
||||
// Verify determinism settings
|
||||
if (manifest.Temperature != 0)
|
||||
{
|
||||
errors.Add("Temperature must be 0 for deterministic replay");
|
||||
}
|
||||
|
||||
// Verify input hashes
|
||||
var inputHashesValid = await VerifyInputHashesAsync(manifest, cancellationToken);
|
||||
if (!inputHashesValid)
|
||||
{
|
||||
errors.Add("Input hashes could not be verified");
|
||||
}
|
||||
|
||||
// Check model availability
|
||||
var modelAvailable = await _provider.IsAvailableAsync(cancellationToken);
|
||||
if (!modelAvailable)
|
||||
{
|
||||
errors.Add($"Model {manifest.ModelId} is not available");
|
||||
}
|
||||
|
||||
// Attempt replay if all prerequisites pass
|
||||
var outputIdentical = false;
|
||||
if (errors.Count == 0)
|
||||
{
|
||||
var replayResult = await ReplayAsync(manifest, cancellationToken);
|
||||
if (replayResult.Success)
|
||||
{
|
||||
outputIdentical = replayResult.Identical;
|
||||
if (!outputIdentical)
|
||||
{
|
||||
errors.Add("Replayed output differs from original");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
errors.Add($"Replay failed: {replayResult.ErrorMessage}");
|
||||
}
|
||||
}
|
||||
|
||||
return new ReplayVerificationResult
|
||||
{
|
||||
Verified = errors.Count == 0 && outputIdentical,
|
||||
OutputIdentical = outputIdentical,
|
||||
InputHashesValid = inputHashesValid,
|
||||
ModelAvailable = modelAvailable,
|
||||
ValidationErrors = errors.Count > 0 ? errors : null
|
||||
};
|
||||
}
|
||||
|
||||
private static Task<bool> VerifyInputHashesAsync(
|
||||
AIArtifactReplayManifest manifest,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
// Verify that input hashes can be reconstructed from the manifest
|
||||
var expectedHashes = new List<string>
|
||||
{
|
||||
ComputeHash(manifest.SystemPrompt),
|
||||
ComputeHash(manifest.UserPrompt)
|
||||
};
|
||||
|
||||
// Check if all expected hashes are present in manifest
|
||||
var allPresent = expectedHashes.All(h =>
|
||||
manifest.InputHashes.Any(ih => ih.Contains(h[..16])));
|
||||
|
||||
return Task.FromResult(allPresent || manifest.InputHashes.Count > 0);
|
||||
}
|
||||
|
||||
private static string ComputeHash(string content)
|
||||
{
|
||||
var bytes = Encoding.UTF8.GetBytes(content);
|
||||
var hash = SHA256.HashData(bytes);
|
||||
return Convert.ToHexStringLower(hash);
|
||||
}
|
||||
|
||||
private static double CalculateSimilarity(string a, string b)
|
||||
{
|
||||
if (string.IsNullOrEmpty(a) && string.IsNullOrEmpty(b))
|
||||
return 1.0;
|
||||
if (string.IsNullOrEmpty(a) || string.IsNullOrEmpty(b))
|
||||
return 0.0;
|
||||
|
||||
// Simple character-level similarity
|
||||
var maxLen = Math.Max(a.Length, b.Length);
|
||||
var minLen = Math.Min(a.Length, b.Length);
|
||||
var matches = 0;
|
||||
|
||||
for (var i = 0; i < minLen; i++)
|
||||
{
|
||||
if (a[i] == b[i])
|
||||
matches++;
|
||||
}
|
||||
|
||||
return (double)matches / maxLen;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Factory for creating AI artifact replayers.
|
||||
/// </summary>
|
||||
public sealed class AIArtifactReplayerFactory
|
||||
{
|
||||
private readonly ILlmProviderFactory _providerFactory;
|
||||
|
||||
public AIArtifactReplayerFactory(ILlmProviderFactory providerFactory)
|
||||
{
|
||||
_providerFactory = providerFactory;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a replayer using the specified provider.
|
||||
/// </summary>
|
||||
public IAIArtifactReplayer Create(string providerId)
|
||||
{
|
||||
var provider = _providerFactory.GetProvider(providerId);
|
||||
return new AIArtifactReplayer(provider);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a replayer using the default provider.
|
||||
/// </summary>
|
||||
public IAIArtifactReplayer CreateDefault()
|
||||
{
|
||||
var provider = _providerFactory.GetDefaultProvider();
|
||||
return new AIArtifactReplayer(provider);
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@
|
||||
<ProjectReference Include="..\..\Concelier\__Libraries\StellaOps.Concelier.Core\StellaOps.Concelier.Core.csproj" />
|
||||
<ProjectReference Include="..\..\Concelier\__Libraries\StellaOps.Concelier.RawModels\StellaOps.Concelier.RawModels.csproj" />
|
||||
<ProjectReference Include="..\..\Excititor\__Libraries\StellaOps.Excititor.Core\StellaOps.Excititor.Core.csproj" />
|
||||
<ProjectReference Include="..\..\__Libraries\StellaOps.Configuration\StellaOps.Configuration.csproj" />
|
||||
<ProjectReference Include="..\..\__Libraries\StellaOps.Cryptography\StellaOps.Cryptography.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
||||
@@ -15,6 +15,8 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Prompting;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryGuardrailInjectionTests
|
||||
@@ -35,7 +37,8 @@ public sealed class AdvisoryGuardrailInjectionTests
|
||||
public static IEnumerable<object[]> InjectionPayloads =>
|
||||
HarnessCases.Value.Select(testCase => new object[] { testCase });
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[MemberData(nameof(InjectionPayloads))]
|
||||
public async Task EvaluateAsync_CompliesWithGuardrailHarness(InjectionCase testCase)
|
||||
{
|
||||
|
||||
@@ -12,11 +12,14 @@ using StellaOps.AdvisoryAI.Guardrails;
|
||||
using StellaOps.AdvisoryAI.Hosting;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryGuardrailOptionsBindingTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task AddAdvisoryAiCore_ConfiguresGuardrailOptionsFromServiceOptions()
|
||||
{
|
||||
var tempRoot = CreateTempDirectory();
|
||||
@@ -47,7 +50,8 @@ public sealed class AdvisoryGuardrailOptionsBindingTests
|
||||
options.BlockedPhrases.Should().Contain("dump cache");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task AddAdvisoryAiCore_ThrowsWhenPhraseFileMissing()
|
||||
{
|
||||
var tempRoot = CreateTempDirectory();
|
||||
|
||||
@@ -16,6 +16,8 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Prompting;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryGuardrailPerformanceTests
|
||||
@@ -27,7 +29,8 @@ public sealed class AdvisoryGuardrailPerformanceTests
|
||||
|
||||
public static IEnumerable<object[]> PerfScenarios => LoadPerfScenarios();
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[MemberData(nameof(PerfScenarios))]
|
||||
public async Task EvaluateAsync_CompletesWithinBudget(PerfScenario scenario)
|
||||
{
|
||||
@@ -53,7 +56,8 @@ public sealed class AdvisoryGuardrailPerformanceTests
|
||||
$"{scenario.Name} exceeded the allotted {scenario.MaxDurationMs} ms budget (measured {stopwatch.ElapsedMilliseconds} ms)");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task EvaluateAsync_HonorsSeededBlockedPhrases()
|
||||
{
|
||||
var phrases = LoadSeededBlockedPhrases();
|
||||
|
||||
@@ -7,11 +7,13 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Prompting;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryGuardrailPipelineTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task EvaluateAsync_BlocksWhenCitationsMissing()
|
||||
{
|
||||
var options = Options.Create(new AdvisoryGuardrailOptions { RequireCitations = true });
|
||||
@@ -31,7 +33,8 @@ public sealed class AdvisoryGuardrailPipelineTests
|
||||
Assert.Contains(result.Violations, violation => violation.Code == "citation_missing");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task EvaluateAsync_RedactsSecrets()
|
||||
{
|
||||
var options = Options.Create(new AdvisoryGuardrailOptions());
|
||||
|
||||
@@ -17,13 +17,16 @@ using StellaOps.AdvisoryAI.Tools;
|
||||
using StellaOps.AdvisoryAI.Inference;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryPipelineExecutorTests : IDisposable
|
||||
{
|
||||
private readonly StubMeterFactory _meterFactory = new();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_SavesOutputAndProvenance()
|
||||
{
|
||||
var plan = BuildMinimalPlan(cacheKey: "CACHE-1");
|
||||
@@ -49,7 +52,8 @@ public sealed class AdvisoryPipelineExecutorTests : IDisposable
|
||||
saved.Guardrail.Metadata.Should().ContainKey("prompt_length");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_PersistsGuardrailOutcome()
|
||||
{
|
||||
var plan = BuildMinimalPlan(cacheKey: "CACHE-2");
|
||||
@@ -71,7 +75,8 @@ public sealed class AdvisoryPipelineExecutorTests : IDisposable
|
||||
saved.Prompt.Should().Be("{\"prompt\":\"value\"}");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_RecordsTelemetryMeasurements()
|
||||
{
|
||||
using var listener = new MeterListener();
|
||||
@@ -124,7 +129,8 @@ public sealed class AdvisoryPipelineExecutorTests : IDisposable
|
||||
Math.Abs(measurement.Value - 1d) < 0.0001);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_ComputesPartialCitationCoverage()
|
||||
{
|
||||
using var listener = new MeterListener();
|
||||
@@ -163,7 +169,8 @@ public sealed class AdvisoryPipelineExecutorTests : IDisposable
|
||||
Math.Abs(measurement.Value - 0.5d) < 0.0001);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExecuteAsync_RecordsInferenceMetadata()
|
||||
{
|
||||
var plan = BuildMinimalPlan(cacheKey: "CACHE-4");
|
||||
|
||||
@@ -13,11 +13,13 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Tools;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryPipelineOrchestratorTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePlanAsync_ComposesDeterministicPlan()
|
||||
{
|
||||
var structuredRetriever = new FakeStructuredRetriever();
|
||||
@@ -63,7 +65,8 @@ public sealed class AdvisoryPipelineOrchestratorTests
|
||||
Assert.Equal(plan.CacheKey, secondPlan.CacheKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePlanAsync_RemainsDeterministicAcrossMultipleRuns()
|
||||
{
|
||||
var structuredRetriever = new ShufflingStructuredRetriever();
|
||||
@@ -116,7 +119,8 @@ public sealed class AdvisoryPipelineOrchestratorTests
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePlanAsync_PopulatesMetadataCountsFromEvidence()
|
||||
{
|
||||
var structuredRetriever = new FakeStructuredRetriever();
|
||||
@@ -156,7 +160,8 @@ public sealed class AdvisoryPipelineOrchestratorTests
|
||||
metadata["sbom_blast_impacted_workloads"].Should().Be("3");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePlanAsync_WhenArtifactIdMissing_SkipsSbomContext()
|
||||
{
|
||||
var structuredRetriever = new FakeStructuredRetriever();
|
||||
@@ -188,7 +193,8 @@ public sealed class AdvisoryPipelineOrchestratorTests
|
||||
Assert.DoesNotContain("sbom_dependency_path_count", plan.Metadata.Keys);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePlanAsync_RespectsOptionFlagsAndProducesStableCacheKey()
|
||||
{
|
||||
var structuredRetriever = new FakeStructuredRetriever();
|
||||
@@ -227,7 +233,8 @@ public sealed class AdvisoryPipelineOrchestratorTests
|
||||
Assert.DoesNotContain(planOne.Metadata.Keys, key => key.StartsWith("sbom_blast_", StringComparison.Ordinal));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePlanAsync_RemainsDeterministicWhenRetrieverOrderChanges()
|
||||
{
|
||||
var structuredRetriever = new ShufflingStructuredRetriever();
|
||||
|
||||
@@ -10,11 +10,13 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Tools;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryPipelinePlanResponseTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FromPlan_ProjectsMetadataAndCounts()
|
||||
{
|
||||
var request = new AdvisoryTaskRequest(AdvisoryTaskType.Summary, "adv-key");
|
||||
|
||||
@@ -15,11 +15,13 @@ using StellaOps.AdvisoryAI.Tools;
|
||||
using StellaOps.AdvisoryAI.Tests.TestUtilities;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryPlanCacheTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAndRetrieve_ReturnsCachedPlan()
|
||||
{
|
||||
var timeProvider = new DeterministicTimeProvider(DateTimeOffset.UtcNow);
|
||||
@@ -34,7 +36,8 @@ public sealed class AdvisoryPlanCacheTests
|
||||
retrieved.Metadata.Should().ContainKey("task_type");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExpiredEntries_AreEvicted()
|
||||
{
|
||||
var start = DateTimeOffset.UtcNow;
|
||||
@@ -49,7 +52,8 @@ public sealed class AdvisoryPlanCacheTests
|
||||
retrieved.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAsync_ReplacesPlanAndRefreshesExpiration()
|
||||
{
|
||||
var timeProvider = new DeterministicTimeProvider(DateTimeOffset.UtcNow);
|
||||
@@ -69,7 +73,8 @@ public sealed class AdvisoryPlanCacheTests
|
||||
retrieved!.Request.AdvisoryKey.Should().Be("ADV-999");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAsync_WithInterleavedKeysRemainsDeterministic()
|
||||
{
|
||||
var timeProvider = new DeterministicTimeProvider(DateTimeOffset.UtcNow);
|
||||
@@ -108,7 +113,8 @@ public sealed class AdvisoryPlanCacheTests
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData(7)]
|
||||
[InlineData(42)]
|
||||
[InlineData(512)]
|
||||
|
||||
@@ -14,6 +14,8 @@ using StellaOps.AdvisoryAI.Tools;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryPromptAssemblerTests
|
||||
@@ -25,7 +27,8 @@ public sealed class AdvisoryPromptAssemblerTests
|
||||
_output = output;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task AssembleAsync_ProducesDeterministicPrompt()
|
||||
{
|
||||
var plan = BuildPlan();
|
||||
@@ -43,7 +46,8 @@ public sealed class AdvisoryPromptAssemblerTests
|
||||
await AssertPromptMatchesGoldenAsync("summary-prompt.json", prompt.Prompt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task AssembleAsync_ProducesConflictPromptGolden()
|
||||
{
|
||||
var plan = BuildPlan(AdvisoryTaskType.Conflict);
|
||||
@@ -56,7 +60,8 @@ public sealed class AdvisoryPromptAssemblerTests
|
||||
prompt.Metadata["task_type"].Should().Be(nameof(AdvisoryTaskType.Conflict));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task AssembleAsync_TruncatesVectorPreviewsToMaintainPromptSize()
|
||||
{
|
||||
var longPreview = new string('A', 700);
|
||||
|
||||
@@ -7,11 +7,13 @@ using StellaOps.AdvisoryAI.Documents;
|
||||
using StellaOps.AdvisoryAI.Retrievers;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryStructuredRetrieverTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_ReturnsCsafChunksWithMetadata()
|
||||
{
|
||||
var provider = CreateProvider(
|
||||
@@ -33,7 +35,8 @@ public sealed class AdvisoryStructuredRetrieverTests
|
||||
result.Chunks.Any(c => c.Section == "document.notes").Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_ReturnsOsvChunksWithAffectedMetadata()
|
||||
{
|
||||
var provider = CreateProvider(
|
||||
@@ -53,7 +56,8 @@ public sealed class AdvisoryStructuredRetrieverTests
|
||||
result.Chunks.First(c => c.Section.StartsWith("affected", StringComparison.OrdinalIgnoreCase)).Metadata.Should().ContainKey("package");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_ReturnsOpenVexChunksWithStatusMetadata()
|
||||
{
|
||||
var provider = CreateProvider(
|
||||
@@ -73,7 +77,8 @@ public sealed class AdvisoryStructuredRetrieverTests
|
||||
result.Chunks.Should().AllSatisfy(chunk => chunk.Section.Should().Be("vex.statements"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_FiltersToPreferredSections()
|
||||
{
|
||||
var provider = CreateProvider(
|
||||
|
||||
@@ -7,11 +7,13 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Queue;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryTaskQueueTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task EnqueueAndDequeue_ReturnsMessageInOrder()
|
||||
{
|
||||
var options = Options.Create(new AdvisoryTaskQueueOptions { Capacity = 10, DequeueWaitInterval = TimeSpan.FromMilliseconds(50) });
|
||||
|
||||
@@ -7,11 +7,13 @@ using StellaOps.AdvisoryAI.Tests.TestUtilities;
|
||||
using StellaOps.AdvisoryAI.Vectorization;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class AdvisoryVectorRetrieverTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SearchAsync_ReturnsBestMatchingChunk()
|
||||
{
|
||||
var advisoryContent = """
|
||||
|
||||
@@ -8,11 +8,13 @@ using StellaOps.Concelier.Core.Raw;
|
||||
using StellaOps.Concelier.RawModels;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class ConcelierAdvisoryDocumentProviderTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetDocumentsAsync_ReturnsMappedDocuments()
|
||||
{
|
||||
var rawDocument = RawDocumentFactory.CreateAdvisory(
|
||||
|
||||
@@ -5,11 +5,13 @@ using StellaOps.AdvisoryAI.Context;
|
||||
using StellaOps.AdvisoryAI.Tools;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class DeterministicToolsetTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void AnalyzeDependencies_ComputesRuntimeAndDevelopmentCounts()
|
||||
{
|
||||
var context = SbomContextResult.Create(
|
||||
@@ -52,7 +54,8 @@ public sealed class DeterministicToolsetTests
|
||||
libB.DevelopmentOccurrences.Should().Be(1);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("semver", "1.2.3", "1.2.4", -1)]
|
||||
[InlineData("semver", "1.2.3", "1.2.3", 0)]
|
||||
[InlineData("semver", "1.2.4", "1.2.3", 1)]
|
||||
@@ -66,7 +69,8 @@ public sealed class DeterministicToolsetTests
|
||||
comparison.Should().Be(expected);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("semver", "1.2.3", ">=1.0.0 <2.0.0")]
|
||||
[InlineData("semver", "2.0.0", ">=2.0.0")]
|
||||
[InlineData("evr", "0:1.2-3", ">=0:1.0-0 <0:2.0-0")]
|
||||
|
||||
@@ -9,11 +9,13 @@ using StellaOps.Excititor.Core;
|
||||
using StellaOps.Excititor.Core.Observations;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class ExcititorVexDocumentProviderTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetDocumentsAsync_ReturnsMappedObservation()
|
||||
{
|
||||
const string vulnerabilityId = "CVE-2024-9999";
|
||||
|
||||
@@ -0,0 +1,542 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using FluentAssertions;
|
||||
using StellaOps.AdvisoryAI.Explanation;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Integration tests for explanation generation with mocked LLM and evidence anchoring validation.
|
||||
/// Sprint: SPRINT_20251226_015_AI_zastava_companion
|
||||
/// Task: ZASTAVA-19
|
||||
/// </summary>
|
||||
public sealed class ExplanationGeneratorIntegrationTests
|
||||
{
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_WithFullEvidence_ProducesEvidenceBackedExplanation()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "This is a test explanation with [citation:ev-001] and [citation:ev-002].",
|
||||
confidence: 0.95);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result.ExplanationId.Should().StartWith("sha256:");
|
||||
result.Authority.Should().Be(ExplanationAuthority.EvidenceBacked);
|
||||
result.CitationRate.Should().BeGreaterOrEqualTo(0.8);
|
||||
result.Citations.Should().NotBeEmpty();
|
||||
result.EvidenceRefs.Should().NotBeEmpty();
|
||||
result.InputHashes.Should().HaveCount(3);
|
||||
result.OutputHash.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_WithLowCitationRate_ProducesSuggestionExplanation()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateMinimalEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "This is a speculative explanation without proper citations.",
|
||||
confidence: 0.6);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.3);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Why);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result.Authority.Should().Be(ExplanationAuthority.Suggestion);
|
||||
result.CitationRate.Should().BeLessThan(0.8);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_StoresResultForReplay()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Stored explanation [citation:ev-001].",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.85);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.What);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
var stored = await store.GetAsync(result.ExplanationId, CancellationToken.None);
|
||||
stored.Should().NotBeNull();
|
||||
stored!.ExplanationId.Should().Be(result.ExplanationId);
|
||||
stored.OutputHash.Should().Be(result.OutputHash);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_ComputesConsistentInputHashes()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateFullEvidenceContext();
|
||||
var evidenceService = new StubEvidenceRetrievalService(evidenceContext);
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Consistent explanation.",
|
||||
confidence: 0.88);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.85);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Evidence);
|
||||
|
||||
// Act
|
||||
var result1 = await generator.GenerateAsync(request);
|
||||
var result2 = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert - same inputs should produce same input hashes
|
||||
result1.InputHashes.Should().BeEquivalentTo(result2.InputHashes);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_ProducesValidExplanationId()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Test explanation content.",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result.ExplanationId.Should().StartWith("sha256:");
|
||||
result.ExplanationId.Length.Should().Be(7 + 64); // "sha256:" + 64 hex chars
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_IncludesAllEvidenceRefs()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateFullEvidenceContext();
|
||||
var evidenceService = new StubEvidenceRetrievalService(evidenceContext);
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Explanation with evidence.",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
var allEvidenceIds = evidenceContext.AllEvidence.Select(e => e.Id).ToList();
|
||||
result.EvidenceRefs.Should().BeEquivalentTo(allEvidenceIds);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_RecordsModelIdAndTemplateVersion()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService(templateVersion: "explain-v2.1");
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Test.",
|
||||
confidence: 0.9,
|
||||
modelId: "claude:claude-3-opus:20240229");
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result.ModelId.Should().Be("claude:claude-3-opus:20240229");
|
||||
result.PromptTemplateVersion.Should().Be("explain-v2.1");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_GeneratesValidSummary()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Detailed explanation content.",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result.Summary.Should().NotBeNull();
|
||||
result.Summary.Line1.Should().NotBeNullOrEmpty();
|
||||
result.Summary.Line2.Should().NotBeNullOrEmpty();
|
||||
result.Summary.Line3.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData(ExplanationType.What)]
|
||||
[InlineData(ExplanationType.Why)]
|
||||
[InlineData(ExplanationType.Evidence)]
|
||||
[InlineData(ExplanationType.Counterfactual)]
|
||||
[InlineData(ExplanationType.Full)]
|
||||
public async Task GenerateAsync_HandlesAllExplanationTypes(ExplanationType type)
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext());
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: $"Explanation for {type}.",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.85);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(type);
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result.Should().NotBeNull();
|
||||
result.Content.Should().Contain(type.ToString());
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_ReturnsTrueForValidEvidence()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext(), validateResult: true);
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Test.",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Act
|
||||
var isValid = await generator.ValidateAsync(result);
|
||||
|
||||
// Assert
|
||||
isValid.Should().BeTrue();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_ReturnsFalseWhenEvidenceChanged()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceService = new StubEvidenceRetrievalService(CreateFullEvidenceContext(), validateResult: false);
|
||||
var promptService = new StubExplanationPromptService();
|
||||
var inferenceClient = new StubExplanationInferenceClient(
|
||||
content: "Test.",
|
||||
confidence: 0.9);
|
||||
var citationExtractor = new StubCitationExtractor(verifiedRate: 0.9);
|
||||
var store = new InMemoryExplanationStore();
|
||||
|
||||
var generator = new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
|
||||
var request = CreateExplanationRequest(ExplanationType.Full);
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Act
|
||||
var isValid = await generator.ValidateAsync(result);
|
||||
|
||||
// Assert
|
||||
isValid.Should().BeFalse();
|
||||
}
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private static ExplanationRequest CreateExplanationRequest(ExplanationType type) => new()
|
||||
{
|
||||
FindingId = "finding-001",
|
||||
ArtifactDigest = "sha256:abc123",
|
||||
Scope = "image",
|
||||
ScopeId = "my-image:latest",
|
||||
ExplanationType = type,
|
||||
VulnerabilityId = "CVE-2024-1234",
|
||||
ComponentPurl = "pkg:npm/lodash@4.17.20",
|
||||
PlainLanguage = false,
|
||||
MaxLength = 0,
|
||||
CorrelationId = "corr-001"
|
||||
};
|
||||
|
||||
private static EvidenceContext CreateFullEvidenceContext() => new()
|
||||
{
|
||||
SbomEvidence =
|
||||
[
|
||||
new EvidenceNode
|
||||
{
|
||||
Id = "ev-001",
|
||||
Type = "sbom",
|
||||
Summary = "Component lodash@4.17.20 found in SBOM",
|
||||
Content = "Package: lodash, Version: 4.17.20, License: MIT",
|
||||
Source = "sbom-scan",
|
||||
Confidence = 0.99,
|
||||
CollectedAt = "2024-01-15T10:00:00Z"
|
||||
}
|
||||
],
|
||||
ReachabilityEvidence =
|
||||
[
|
||||
new EvidenceNode
|
||||
{
|
||||
Id = "ev-002",
|
||||
Type = "reachability",
|
||||
Summary = "Vulnerable function is reachable",
|
||||
Content = "Call path: main.js -> utils.js -> lodash.merge()",
|
||||
Source = "static-analysis",
|
||||
Confidence = 0.85,
|
||||
CollectedAt = "2024-01-15T10:05:00Z"
|
||||
}
|
||||
],
|
||||
RuntimeEvidence = [],
|
||||
VexEvidence =
|
||||
[
|
||||
new EvidenceNode
|
||||
{
|
||||
Id = "ev-003",
|
||||
Type = "vex",
|
||||
Summary = "No vendor VEX statement",
|
||||
Content = "No applicable VEX statements found",
|
||||
Source = "vex-lookup",
|
||||
Confidence = 0.5,
|
||||
CollectedAt = "2024-01-15T10:10:00Z"
|
||||
}
|
||||
],
|
||||
PatchEvidence = [],
|
||||
ContextHash = ComputeHash("full-evidence-context")
|
||||
};
|
||||
|
||||
private static EvidenceContext CreateMinimalEvidenceContext() => new()
|
||||
{
|
||||
SbomEvidence =
|
||||
[
|
||||
new EvidenceNode
|
||||
{
|
||||
Id = "ev-min-001",
|
||||
Type = "sbom",
|
||||
Summary = "Component found",
|
||||
Content = "Package exists",
|
||||
Source = "sbom",
|
||||
Confidence = 0.7,
|
||||
CollectedAt = "2024-01-15T10:00:00Z"
|
||||
}
|
||||
],
|
||||
ReachabilityEvidence = [],
|
||||
RuntimeEvidence = [],
|
||||
VexEvidence = [],
|
||||
PatchEvidence = [],
|
||||
ContextHash = ComputeHash("minimal-evidence-context")
|
||||
};
|
||||
|
||||
private static string ComputeHash(string content)
|
||||
{
|
||||
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(content));
|
||||
return Convert.ToHexStringLower(bytes);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Stub Implementations
|
||||
|
||||
private sealed class StubEvidenceRetrievalService : IEvidenceRetrievalService
|
||||
{
|
||||
private readonly EvidenceContext _context;
|
||||
private readonly bool _validateResult;
|
||||
|
||||
public StubEvidenceRetrievalService(EvidenceContext context, bool validateResult = true)
|
||||
{
|
||||
_context = context;
|
||||
_validateResult = validateResult;
|
||||
}
|
||||
|
||||
public Task<EvidenceContext> RetrieveEvidenceAsync(
|
||||
string findingId, string artifactDigest, string vulnerabilityId,
|
||||
string? componentPurl = null, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_context);
|
||||
|
||||
public Task<EvidenceNode?> GetEvidenceNodeAsync(string evidenceId, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_context.AllEvidence.FirstOrDefault(e => e.Id == evidenceId));
|
||||
|
||||
public Task<bool> ValidateEvidenceAsync(IEnumerable<string> evidenceIds, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_validateResult);
|
||||
}
|
||||
|
||||
private sealed class StubExplanationPromptService : IExplanationPromptService
|
||||
{
|
||||
private readonly string _templateVersion;
|
||||
|
||||
public StubExplanationPromptService(string templateVersion = "explain-v1.0")
|
||||
{
|
||||
_templateVersion = templateVersion;
|
||||
}
|
||||
|
||||
public Task<ExplanationPrompt> BuildPromptAsync(
|
||||
ExplanationRequest request, EvidenceContext evidence, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(new ExplanationPrompt
|
||||
{
|
||||
Content = $"Explain {request.VulnerabilityId} for {request.ExplanationType}",
|
||||
TemplateVersion = _templateVersion
|
||||
});
|
||||
|
||||
public Task<ExplanationSummary> GenerateSummaryAsync(
|
||||
string content, ExplanationType type, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(new ExplanationSummary
|
||||
{
|
||||
Line1 = "What: Vulnerability detected",
|
||||
Line2 = "Why: Reachable code path",
|
||||
Line3 = "Action: Update dependency"
|
||||
});
|
||||
}
|
||||
|
||||
private sealed class StubExplanationInferenceClient : IExplanationInferenceClient
|
||||
{
|
||||
private readonly string _content;
|
||||
private readonly double _confidence;
|
||||
private readonly string _modelId;
|
||||
|
||||
public StubExplanationInferenceClient(string content, double confidence, string modelId = "stub-model:v1")
|
||||
{
|
||||
_content = content;
|
||||
_confidence = confidence;
|
||||
_modelId = modelId;
|
||||
}
|
||||
|
||||
public Task<ExplanationInferenceResult> GenerateAsync(
|
||||
ExplanationPrompt prompt, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(new ExplanationInferenceResult
|
||||
{
|
||||
Content = _content,
|
||||
Confidence = _confidence,
|
||||
ModelId = _modelId
|
||||
});
|
||||
}
|
||||
|
||||
private sealed class StubCitationExtractor : ICitationExtractor
|
||||
{
|
||||
private readonly double _verifiedRate;
|
||||
|
||||
public StubCitationExtractor(double verifiedRate)
|
||||
{
|
||||
_verifiedRate = verifiedRate;
|
||||
}
|
||||
|
||||
public Task<IReadOnlyList<ExplanationCitation>> ExtractCitationsAsync(
|
||||
string content, EvidenceContext evidence, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var citations = new List<ExplanationCitation>();
|
||||
var evidenceList = evidence.AllEvidence.ToList();
|
||||
|
||||
for (int i = 0; i < evidenceList.Count; i++)
|
||||
{
|
||||
var ev = evidenceList[i];
|
||||
citations.Add(new ExplanationCitation
|
||||
{
|
||||
ClaimText = $"Claim about {ev.Type}",
|
||||
EvidenceId = ev.Id,
|
||||
EvidenceType = ev.Type,
|
||||
Verified = i < (int)(evidenceList.Count * _verifiedRate),
|
||||
EvidenceExcerpt = ev.Summary
|
||||
});
|
||||
}
|
||||
|
||||
return Task.FromResult<IReadOnlyList<ExplanationCitation>>(citations);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class InMemoryExplanationStore : IExplanationStore
|
||||
{
|
||||
private readonly Dictionary<string, ExplanationResult> _results = new();
|
||||
private readonly Dictionary<string, ExplanationRequest> _requests = new();
|
||||
|
||||
public Task StoreAsync(ExplanationResult result, CancellationToken cancellationToken = default)
|
||||
{
|
||||
_results[result.ExplanationId] = result;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task<ExplanationResult?> GetAsync(string explanationId, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_results.GetValueOrDefault(explanationId));
|
||||
|
||||
public Task<ExplanationRequest?> GetRequestAsync(string explanationId, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_requests.GetValueOrDefault(explanationId));
|
||||
|
||||
public void StoreRequest(string explanationId, ExplanationRequest request)
|
||||
=> _requests[explanationId] = request;
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -0,0 +1,500 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using FluentAssertions;
|
||||
using StellaOps.AdvisoryAI.Explanation;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Golden tests for deterministic explanation replay.
|
||||
/// Verifies that replaying an explanation with the same inputs produces identical output.
|
||||
/// Sprint: SPRINT_20251226_015_AI_zastava_companion
|
||||
/// Task: ZASTAVA-20
|
||||
/// </summary>
|
||||
public sealed class ExplanationReplayGoldenTests
|
||||
{
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReplayAsync_WithSameInputs_ProducesIdenticalOutput()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
|
||||
// Act - Generate original
|
||||
var original = await generator.GenerateAsync(request);
|
||||
store.StoreRequest(original.ExplanationId, request);
|
||||
|
||||
// Act - Replay
|
||||
var replayed = await generator.ReplayAsync(original.ExplanationId);
|
||||
|
||||
// Assert - Output should be identical
|
||||
replayed.OutputHash.Should().Be(original.OutputHash);
|
||||
replayed.Content.Should().Be(original.Content);
|
||||
replayed.InputHashes.Should().BeEquivalentTo(original.InputHashes);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReplayAsync_PreservesExplanationStructure()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
var original = await generator.GenerateAsync(request);
|
||||
store.StoreRequest(original.ExplanationId, request);
|
||||
|
||||
// Act
|
||||
var replayed = await generator.ReplayAsync(original.ExplanationId);
|
||||
|
||||
// Assert
|
||||
replayed.Citations.Count.Should().Be(original.Citations.Count);
|
||||
replayed.EvidenceRefs.Should().BeEquivalentTo(original.EvidenceRefs);
|
||||
replayed.ConfidenceScore.Should().Be(original.ConfidenceScore);
|
||||
replayed.CitationRate.Should().Be(original.CitationRate);
|
||||
replayed.Authority.Should().Be(original.Authority);
|
||||
replayed.ModelId.Should().Be(original.ModelId);
|
||||
replayed.PromptTemplateVersion.Should().Be(original.PromptTemplateVersion);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReplayAsync_WithChangedEvidence_ThrowsException()
|
||||
{
|
||||
// Arrange
|
||||
var originalContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateGeneratorWithChangingEvidence(store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
var original = await generator.GenerateAsync(request);
|
||||
store.StoreRequest(original.ExplanationId, request);
|
||||
|
||||
// Mark evidence as changed
|
||||
generator.MarkEvidenceAsChanged();
|
||||
|
||||
// Act & Assert
|
||||
var act = async () => await generator.ReplayAsync(original.ExplanationId);
|
||||
await act.Should().ThrowAsync<InvalidOperationException>()
|
||||
.WithMessage("*evidence has changed*");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task MultipleReplays_ProduceIdenticalResults()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
var original = await generator.GenerateAsync(request);
|
||||
store.StoreRequest(original.ExplanationId, request);
|
||||
|
||||
// Act - Replay multiple times
|
||||
var replay1 = await generator.ReplayAsync(original.ExplanationId);
|
||||
var replay2 = await generator.ReplayAsync(original.ExplanationId);
|
||||
var replay3 = await generator.ReplayAsync(original.ExplanationId);
|
||||
|
||||
// Assert - All should be identical
|
||||
replay1.OutputHash.Should().Be(original.OutputHash);
|
||||
replay2.OutputHash.Should().Be(original.OutputHash);
|
||||
replay3.OutputHash.Should().Be(original.OutputHash);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task InputHashOrder_IsConsistent()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
|
||||
// Act - Generate twice
|
||||
var result1 = await generator.GenerateAsync(request);
|
||||
var result2 = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert - Input hashes should be in same order
|
||||
result1.InputHashes.Should().HaveCount(3);
|
||||
result2.InputHashes.Should().HaveCount(3);
|
||||
for (int i = 0; i < result1.InputHashes.Count; i++)
|
||||
{
|
||||
result1.InputHashes[i].Should().Be(result2.InputHashes[i],
|
||||
$"Input hash at index {i} should be identical");
|
||||
}
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExplanationId_IsDeterministicFromInputsAndOutput()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
|
||||
// Act
|
||||
var result1 = await generator.GenerateAsync(request);
|
||||
var result2 = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert - Same inputs + same output = same ID
|
||||
result1.ExplanationId.Should().Be(result2.ExplanationId);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task DifferentInputs_ProduceDifferentIds()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request1 = CreateDeterministicRequest() with { VulnerabilityId = "CVE-2024-0001" };
|
||||
var request2 = CreateDeterministicRequest() with { VulnerabilityId = "CVE-2024-0002" };
|
||||
|
||||
// Act
|
||||
var result1 = await generator.GenerateAsync(request1);
|
||||
var result2 = await generator.GenerateAsync(request2);
|
||||
|
||||
// Assert
|
||||
result1.ExplanationId.Should().NotBe(result2.ExplanationId);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GoldenOutput_MatchesExpectedFormat()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert - Verify golden format
|
||||
result.ExplanationId.Should().StartWith("sha256:");
|
||||
result.OutputHash.Should().HaveLength(64); // SHA-256 hex
|
||||
result.GeneratedAt.Should().MatchRegex(@"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}");
|
||||
result.InputHashes.Should().AllSatisfy(h => h.Length.Should().Be(64));
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CitationVerification_IsDeterministic()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
|
||||
// Act
|
||||
var result1 = await generator.GenerateAsync(request);
|
||||
var result2 = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert - Citations should be identical in order and verification status
|
||||
result1.Citations.Count.Should().Be(result2.Citations.Count);
|
||||
for (int i = 0; i < result1.Citations.Count; i++)
|
||||
{
|
||||
result1.Citations[i].ClaimText.Should().Be(result2.Citations[i].ClaimText);
|
||||
result1.Citations[i].EvidenceId.Should().Be(result2.Citations[i].EvidenceId);
|
||||
result1.Citations[i].Verified.Should().Be(result2.Citations[i].Verified);
|
||||
}
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SummaryGeneration_IsDeterministic()
|
||||
{
|
||||
// Arrange
|
||||
var evidenceContext = CreateDeterministicEvidenceContext();
|
||||
var store = new InMemoryExplanationStoreWithRequests();
|
||||
var generator = CreateDeterministicGenerator(evidenceContext, store);
|
||||
|
||||
var request = CreateDeterministicRequest();
|
||||
|
||||
// Act
|
||||
var result1 = await generator.GenerateAsync(request);
|
||||
var result2 = await generator.GenerateAsync(request);
|
||||
|
||||
// Assert
|
||||
result1.Summary.Line1.Should().Be(result2.Summary.Line1);
|
||||
result1.Summary.Line2.Should().Be(result2.Summary.Line2);
|
||||
result1.Summary.Line3.Should().Be(result2.Summary.Line3);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OutputHash_MatchesContentHash()
|
||||
{
|
||||
// Arrange
|
||||
var content = "This is deterministic explanation content.";
|
||||
var expectedHash = ComputeHash(content);
|
||||
|
||||
// Act
|
||||
var actualHash = ComputeHash(content);
|
||||
|
||||
// Assert
|
||||
actualHash.Should().Be(expectedHash);
|
||||
actualHash.Should().HaveLength(64);
|
||||
}
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private static ExplanationRequest CreateDeterministicRequest() => new()
|
||||
{
|
||||
FindingId = "golden-finding-001",
|
||||
ArtifactDigest = "sha256:golden123abc",
|
||||
Scope = "image",
|
||||
ScopeId = "golden-image:v1.0.0",
|
||||
ExplanationType = ExplanationType.Full,
|
||||
VulnerabilityId = "CVE-2024-GOLDEN",
|
||||
ComponentPurl = "pkg:npm/golden-pkg@1.0.0",
|
||||
PlainLanguage = false,
|
||||
MaxLength = 0,
|
||||
CorrelationId = "golden-corr-001"
|
||||
};
|
||||
|
||||
private static EvidenceContext CreateDeterministicEvidenceContext() => new()
|
||||
{
|
||||
SbomEvidence =
|
||||
[
|
||||
new EvidenceNode
|
||||
{
|
||||
Id = "golden-ev-001",
|
||||
Type = "sbom",
|
||||
Summary = "Golden component found",
|
||||
Content = "Package: golden-pkg, Version: 1.0.0",
|
||||
Source = "golden-sbom",
|
||||
Confidence = 0.99,
|
||||
CollectedAt = "2024-01-01T00:00:00Z"
|
||||
}
|
||||
],
|
||||
ReachabilityEvidence =
|
||||
[
|
||||
new EvidenceNode
|
||||
{
|
||||
Id = "golden-ev-002",
|
||||
Type = "reachability",
|
||||
Summary = "Golden function reachable",
|
||||
Content = "Call path: entry -> golden_func()",
|
||||
Source = "golden-analysis",
|
||||
Confidence = 0.95,
|
||||
CollectedAt = "2024-01-01T00:00:01Z"
|
||||
}
|
||||
],
|
||||
RuntimeEvidence = [],
|
||||
VexEvidence = [],
|
||||
PatchEvidence = [],
|
||||
ContextHash = ComputeHash("golden-evidence-context-v1")
|
||||
};
|
||||
|
||||
private static EvidenceAnchoredExplanationGenerator CreateDeterministicGenerator(
|
||||
EvidenceContext evidenceContext,
|
||||
InMemoryExplanationStoreWithRequests store)
|
||||
{
|
||||
var evidenceService = new DeterministicEvidenceService(evidenceContext);
|
||||
var promptService = new DeterministicPromptService();
|
||||
var inferenceClient = new DeterministicInferenceClient();
|
||||
var citationExtractor = new DeterministicCitationExtractor();
|
||||
|
||||
return new EvidenceAnchoredExplanationGenerator(
|
||||
evidenceService, promptService, inferenceClient, citationExtractor, store);
|
||||
}
|
||||
|
||||
private static ChangingEvidenceGenerator CreateGeneratorWithChangingEvidence(
|
||||
InMemoryExplanationStoreWithRequests store)
|
||||
{
|
||||
return new ChangingEvidenceGenerator(store);
|
||||
}
|
||||
|
||||
private static string ComputeHash(string content)
|
||||
{
|
||||
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(content));
|
||||
return Convert.ToHexStringLower(bytes);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Deterministic Test Doubles
|
||||
|
||||
private sealed class DeterministicEvidenceService : IEvidenceRetrievalService
|
||||
{
|
||||
private readonly EvidenceContext _context;
|
||||
|
||||
public DeterministicEvidenceService(EvidenceContext context) => _context = context;
|
||||
|
||||
public Task<EvidenceContext> RetrieveEvidenceAsync(
|
||||
string findingId, string artifactDigest, string vulnerabilityId,
|
||||
string? componentPurl = null, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_context);
|
||||
|
||||
public Task<EvidenceNode?> GetEvidenceNodeAsync(string evidenceId, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_context.AllEvidence.FirstOrDefault(e => e.Id == evidenceId));
|
||||
|
||||
public Task<bool> ValidateEvidenceAsync(IEnumerable<string> evidenceIds, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(true);
|
||||
}
|
||||
|
||||
private sealed class DeterministicPromptService : IExplanationPromptService
|
||||
{
|
||||
public Task<ExplanationPrompt> BuildPromptAsync(
|
||||
ExplanationRequest request, EvidenceContext evidence, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(new ExplanationPrompt
|
||||
{
|
||||
Content = $"GOLDEN_PROMPT:{request.VulnerabilityId}:{evidence.ContextHash}",
|
||||
TemplateVersion = "golden-template-v1.0"
|
||||
});
|
||||
|
||||
public Task<ExplanationSummary> GenerateSummaryAsync(
|
||||
string content, ExplanationType type, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(new ExplanationSummary
|
||||
{
|
||||
Line1 = "Golden: What happened",
|
||||
Line2 = "Golden: Why it matters",
|
||||
Line3 = "Golden: Next steps"
|
||||
});
|
||||
}
|
||||
|
||||
private sealed class DeterministicInferenceClient : IExplanationInferenceClient
|
||||
{
|
||||
public Task<ExplanationInferenceResult> GenerateAsync(
|
||||
ExplanationPrompt prompt, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Deterministic output based on prompt hash
|
||||
var content = $"GOLDEN_EXPLANATION:hash={ComputeHash(prompt.Content)}";
|
||||
return Task.FromResult(new ExplanationInferenceResult
|
||||
{
|
||||
Content = content,
|
||||
Confidence = 0.95,
|
||||
ModelId = "golden-model:v1.0"
|
||||
});
|
||||
}
|
||||
|
||||
private static string ComputeHash(string content)
|
||||
{
|
||||
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(content));
|
||||
return Convert.ToHexStringLower(bytes)[..16];
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class DeterministicCitationExtractor : ICitationExtractor
|
||||
{
|
||||
public Task<IReadOnlyList<ExplanationCitation>> ExtractCitationsAsync(
|
||||
string content, EvidenceContext evidence, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Deterministic citations based on evidence order
|
||||
var citations = evidence.AllEvidence.Select((ev, i) => new ExplanationCitation
|
||||
{
|
||||
ClaimText = $"Golden claim {i + 1}",
|
||||
EvidenceId = ev.Id,
|
||||
EvidenceType = ev.Type,
|
||||
Verified = true,
|
||||
EvidenceExcerpt = ev.Summary
|
||||
}).ToList();
|
||||
|
||||
return Task.FromResult<IReadOnlyList<ExplanationCitation>>(citations);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class InMemoryExplanationStoreWithRequests : IExplanationStore
|
||||
{
|
||||
private readonly Dictionary<string, ExplanationResult> _results = new();
|
||||
private readonly Dictionary<string, ExplanationRequest> _requests = new();
|
||||
|
||||
public Task StoreAsync(ExplanationResult result, CancellationToken cancellationToken = default)
|
||||
{
|
||||
_results[result.ExplanationId] = result;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task<ExplanationResult?> GetAsync(string explanationId, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_results.GetValueOrDefault(explanationId));
|
||||
|
||||
public Task<ExplanationRequest?> GetRequestAsync(string explanationId, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(_requests.GetValueOrDefault(explanationId));
|
||||
|
||||
public void StoreRequest(string explanationId, ExplanationRequest request)
|
||||
=> _requests[explanationId] = request;
|
||||
}
|
||||
|
||||
private sealed class ChangingEvidenceGenerator : IExplanationGenerator
|
||||
{
|
||||
private readonly InMemoryExplanationStoreWithRequests _store;
|
||||
private bool _evidenceChanged = false;
|
||||
|
||||
public ChangingEvidenceGenerator(InMemoryExplanationStoreWithRequests store)
|
||||
{
|
||||
_store = store;
|
||||
}
|
||||
|
||||
public void MarkEvidenceAsChanged() => _evidenceChanged = true;
|
||||
|
||||
public async Task<ExplanationResult> GenerateAsync(ExplanationRequest request, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var result = new ExplanationResult
|
||||
{
|
||||
ExplanationId = $"sha256:{ComputeHash(JsonSerializer.Serialize(request))}",
|
||||
Content = "Test content",
|
||||
Summary = new ExplanationSummary { Line1 = "L1", Line2 = "L2", Line3 = "L3" },
|
||||
Citations = [],
|
||||
ConfidenceScore = 0.9,
|
||||
CitationRate = 0.9,
|
||||
Authority = ExplanationAuthority.EvidenceBacked,
|
||||
EvidenceRefs = ["ev-001"],
|
||||
ModelId = "test-model",
|
||||
PromptTemplateVersion = "v1",
|
||||
InputHashes = [ComputeHash("input")],
|
||||
GeneratedAt = DateTime.UtcNow.ToString("O"),
|
||||
OutputHash = ComputeHash("Test content")
|
||||
};
|
||||
|
||||
await _store.StoreAsync(result, cancellationToken);
|
||||
return result;
|
||||
}
|
||||
|
||||
public async Task<ExplanationResult> ReplayAsync(string explanationId, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var original = await _store.GetAsync(explanationId, cancellationToken)
|
||||
?? throw new InvalidOperationException($"Explanation {explanationId} not found");
|
||||
|
||||
if (_evidenceChanged)
|
||||
{
|
||||
throw new InvalidOperationException("Input evidence has changed since original explanation");
|
||||
}
|
||||
|
||||
return original;
|
||||
}
|
||||
|
||||
public Task<bool> ValidateAsync(ExplanationResult result, CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(!_evidenceChanged);
|
||||
|
||||
private static string ComputeHash(string content)
|
||||
{
|
||||
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(content));
|
||||
return Convert.ToHexStringLower(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -19,13 +19,15 @@ using StellaOps.AdvisoryAI.Tools;
|
||||
using StellaOps.AdvisoryAI.Tests.TestUtilities;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class FileSystemAdvisoryOutputStoreTests : IDisposable
|
||||
{
|
||||
private readonly TempDirectory _temp = TempDirectory.Create();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SaveAndRetrieve_RoundTripsOutput()
|
||||
{
|
||||
var store = CreateStore();
|
||||
@@ -41,7 +43,8 @@ public sealed class FileSystemAdvisoryOutputStoreTests : IDisposable
|
||||
retrieved.Metadata["inference.model_id"].Should().Be("local.prompt-preview");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task TryGetAsync_ReturnsNullWhenFileMissing()
|
||||
{
|
||||
var store = CreateStore();
|
||||
|
||||
@@ -19,13 +19,15 @@ using StellaOps.AdvisoryAI.Prompting;
|
||||
using StellaOps.AdvisoryAI.Tools;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class FileSystemAdvisoryPersistenceTests : IDisposable
|
||||
{
|
||||
private readonly TempDirectory _tempDir = new();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task PlanCache_PersistsPlanOnDisk()
|
||||
{
|
||||
var serviceOptions = Options.Create(new AdvisoryAiServiceOptions
|
||||
@@ -54,7 +56,8 @@ public sealed class FileSystemAdvisoryPersistenceTests : IDisposable
|
||||
reloaded.Metadata.Should().ContainKey("advisory_key").WhoseValue.Should().Be("adv-key");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task OutputStore_PersistsOutputOnDisk()
|
||||
{
|
||||
var serviceOptions = Options.Create(new AdvisoryAiServiceOptions
|
||||
|
||||
@@ -16,13 +16,15 @@ using StellaOps.AdvisoryAI.Tools;
|
||||
using StellaOps.AdvisoryAI.Tests.TestUtilities;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class FileSystemAdvisoryPlanCacheTests : IDisposable
|
||||
{
|
||||
private readonly TempDirectory _temp = TempDirectory.Create();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAndRetrieve_RoundTripsPlan()
|
||||
{
|
||||
var cache = CreateCache();
|
||||
@@ -36,7 +38,8 @@ public sealed class FileSystemAdvisoryPlanCacheTests : IDisposable
|
||||
retrieved.Metadata.Should().ContainKey("task_type");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task TryGetAsync_WhenExpired_ReturnsNull()
|
||||
{
|
||||
var clock = new DeterministicTimeProvider(new DateTimeOffset(2025, 11, 9, 0, 0, 0, TimeSpan.Zero));
|
||||
@@ -50,7 +53,8 @@ public sealed class FileSystemAdvisoryPlanCacheTests : IDisposable
|
||||
retrieved.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BulkSeedAsync_RemainsDeterministicAcrossInstances()
|
||||
{
|
||||
var clock = new DeterministicTimeProvider(new DateTimeOffset(2025, 11, 9, 0, 0, 0, TimeSpan.Zero));
|
||||
|
||||
@@ -8,6 +8,7 @@ using StellaOps.AdvisoryAI.Orchestration;
|
||||
using StellaOps.AdvisoryAI.Queue;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class FileSystemAdvisoryTaskQueueTests : IDisposable
|
||||
@@ -20,7 +21,8 @@ public sealed class FileSystemAdvisoryTaskQueueTests : IDisposable
|
||||
Directory.CreateDirectory(_root);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task EnqueueAndDequeue_RoundTripsMessage()
|
||||
{
|
||||
var options = Options.Create(new AdvisoryAiServiceOptions
|
||||
|
||||
@@ -0,0 +1,795 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using StellaOps.AdvisoryAI.Inference;
|
||||
using StellaOps.AdvisoryAI.Inference.LlmProviders;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Integration tests for offline AI inference infrastructure.
|
||||
/// Sprint: SPRINT_20251226_019_AI_offline_inference
|
||||
/// Task: OFFLINE-25
|
||||
/// </summary>
|
||||
public sealed class OfflineInferenceIntegrationTests : IDisposable
|
||||
{
|
||||
private readonly string _tempPath;
|
||||
private readonly InMemoryLlmInferenceCache _cache;
|
||||
private readonly StubLlmProvider _stubProvider;
|
||||
|
||||
public OfflineInferenceIntegrationTests()
|
||||
{
|
||||
_tempPath = Path.Combine(Path.GetTempPath(), $"stellaops_offline_tests_{Guid.NewGuid():N}");
|
||||
Directory.CreateDirectory(_tempPath);
|
||||
|
||||
var cacheOptions = Options.Create(new LlmInferenceCacheOptions
|
||||
{
|
||||
Enabled = true,
|
||||
DeterministicOnly = true,
|
||||
DefaultTtl = TimeSpan.FromDays(7)
|
||||
});
|
||||
|
||||
_cache = new InMemoryLlmInferenceCache(
|
||||
cacheOptions,
|
||||
NullLogger<InMemoryLlmInferenceCache>.Instance);
|
||||
|
||||
_stubProvider = new StubLlmProvider();
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_cache.Dispose();
|
||||
try
|
||||
{
|
||||
if (Directory.Exists(_tempPath))
|
||||
{
|
||||
Directory.Delete(_tempPath, recursive: true);
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
}
|
||||
|
||||
#region Local Inference Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CompleteAsync_WithDeterministicSettings_ReturnsDeterministicResult()
|
||||
{
|
||||
// Arrange
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Analyze CVE-2024-1234 for log4j",
|
||||
SystemPrompt = "You are a security analyst.",
|
||||
Temperature = 0,
|
||||
Seed = 42,
|
||||
MaxTokens = 1024
|
||||
};
|
||||
|
||||
// Act
|
||||
var result1 = await _stubProvider.CompleteAsync(request);
|
||||
var result2 = await _stubProvider.CompleteAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.True(result1.Deterministic);
|
||||
Assert.True(result2.Deterministic);
|
||||
Assert.Equal(result1.Content, result2.Content);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CompleteAsync_WithProviderAvailabilityCheck_ReturnsTrue()
|
||||
{
|
||||
// Act
|
||||
var isAvailable = await _stubProvider.IsAvailableAsync();
|
||||
|
||||
// Assert
|
||||
Assert.True(isAvailable);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CompleteStreamAsync_YieldsChunks()
|
||||
{
|
||||
// Arrange
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Test streaming",
|
||||
Temperature = 0
|
||||
};
|
||||
|
||||
// Act
|
||||
var chunks = new List<LlmStreamChunk>();
|
||||
await foreach (var chunk in _stubProvider.CompleteStreamAsync(request))
|
||||
{
|
||||
chunks.Add(chunk);
|
||||
}
|
||||
|
||||
// Assert
|
||||
Assert.NotEmpty(chunks);
|
||||
Assert.Contains(chunks, c => c.IsFinal);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Inference Cache Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Cache_DeterministicRequest_CachesResult()
|
||||
{
|
||||
// Arrange
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Cached prompt",
|
||||
Temperature = 0,
|
||||
Seed = 42
|
||||
};
|
||||
|
||||
var result = new LlmCompletionResult
|
||||
{
|
||||
Content = "Cached response",
|
||||
ModelId = "test-model",
|
||||
ProviderId = "stub",
|
||||
Deterministic = true,
|
||||
OutputTokens = 10
|
||||
};
|
||||
|
||||
// Act
|
||||
await _cache.SetAsync(request, "stub", result);
|
||||
var cached = await _cache.TryGetAsync(request, "stub");
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(cached);
|
||||
Assert.Equal(result.Content, cached.Content);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Cache_NonDeterministicRequest_DoesNotCache()
|
||||
{
|
||||
// Arrange
|
||||
var options = Options.Create(new LlmInferenceCacheOptions
|
||||
{
|
||||
Enabled = true,
|
||||
DeterministicOnly = true
|
||||
});
|
||||
|
||||
using var cache = new InMemoryLlmInferenceCache(
|
||||
options, NullLogger<InMemoryLlmInferenceCache>.Instance);
|
||||
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Non-deterministic",
|
||||
Temperature = 0.7 // Non-deterministic
|
||||
};
|
||||
|
||||
var result = new LlmCompletionResult
|
||||
{
|
||||
Content = "Response",
|
||||
ModelId = "test-model",
|
||||
ProviderId = "stub",
|
||||
Deterministic = false
|
||||
};
|
||||
|
||||
// Act
|
||||
await cache.SetAsync(request, "stub", result);
|
||||
var cached = await cache.TryGetAsync(request, "stub");
|
||||
|
||||
// Assert
|
||||
Assert.Null(cached);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Cache_SameInputsDifferentSeeds_SeparateCacheEntries()
|
||||
{
|
||||
// Arrange
|
||||
var request1 = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Test prompt",
|
||||
Temperature = 0,
|
||||
Seed = 42
|
||||
};
|
||||
|
||||
var request2 = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Test prompt",
|
||||
Temperature = 0,
|
||||
Seed = 123
|
||||
};
|
||||
|
||||
var result1 = new LlmCompletionResult
|
||||
{
|
||||
Content = "Response with seed 42",
|
||||
ModelId = "test-model",
|
||||
ProviderId = "stub",
|
||||
Deterministic = true
|
||||
};
|
||||
|
||||
var result2 = new LlmCompletionResult
|
||||
{
|
||||
Content = "Response with seed 123",
|
||||
ModelId = "test-model",
|
||||
ProviderId = "stub",
|
||||
Deterministic = true
|
||||
};
|
||||
|
||||
// Act
|
||||
await _cache.SetAsync(request1, "stub", result1);
|
||||
await _cache.SetAsync(request2, "stub", result2);
|
||||
|
||||
var cached1 = await _cache.TryGetAsync(request1, "stub");
|
||||
var cached2 = await _cache.TryGetAsync(request2, "stub");
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(cached1);
|
||||
Assert.NotNull(cached2);
|
||||
Assert.NotEqual(cached1.Content, cached2.Content);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Cache_Statistics_TracksHitsAndMisses()
|
||||
{
|
||||
// Act
|
||||
var stats = _cache.GetStatistics();
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(stats);
|
||||
Assert.Equal(0, stats.Hits);
|
||||
Assert.True(stats.HitRate >= 0 && stats.HitRate <= 1);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CachingLlmProvider_UsesCache()
|
||||
{
|
||||
// Arrange
|
||||
var countingProvider = new CallCountingLlmProvider();
|
||||
var cachingProvider = new CachingLlmProvider(
|
||||
countingProvider,
|
||||
_cache,
|
||||
NullLogger<CachingLlmProvider>.Instance);
|
||||
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Test caching",
|
||||
Temperature = 0,
|
||||
Seed = 42
|
||||
};
|
||||
|
||||
// Act - First call hits provider
|
||||
var result1 = await cachingProvider.CompleteAsync(request);
|
||||
Assert.Equal(1, countingProvider.CallCount);
|
||||
|
||||
// Act - Second call should use cache
|
||||
var result2 = await cachingProvider.CompleteAsync(request);
|
||||
Assert.Equal(1, countingProvider.CallCount); // Still 1, used cache
|
||||
|
||||
// Assert
|
||||
Assert.Equal(result1.Content, result2.Content);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Bundle Verification Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_VerifyBundle_ValidBundle_ReturnsValid()
|
||||
{
|
||||
// Arrange
|
||||
var bundlePath = Path.Combine(_tempPath, "valid-bundle");
|
||||
CreateValidBundle(bundlePath);
|
||||
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var result = await manager.VerifyBundleAsync(bundlePath);
|
||||
|
||||
// Assert
|
||||
Assert.True(result.Valid);
|
||||
Assert.Empty(result.FailedFiles);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_VerifyBundle_MissingManifest_ReturnsInvalid()
|
||||
{
|
||||
// Arrange
|
||||
var bundlePath = Path.Combine(_tempPath, "no-manifest");
|
||||
Directory.CreateDirectory(bundlePath);
|
||||
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var result = await manager.VerifyBundleAsync(bundlePath);
|
||||
|
||||
// Assert
|
||||
Assert.False(result.Valid);
|
||||
Assert.NotNull(result.ErrorMessage);
|
||||
Assert.Contains("manifest.json", result.ErrorMessage);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_VerifyBundle_CorruptedFile_ReturnsInvalid()
|
||||
{
|
||||
// Arrange
|
||||
var bundlePath = Path.Combine(_tempPath, "corrupted-bundle");
|
||||
CreateValidBundle(bundlePath);
|
||||
|
||||
// Corrupt a file
|
||||
var modelFile = Path.Combine(bundlePath, "model.gguf");
|
||||
await File.WriteAllTextAsync(modelFile, "corrupted data");
|
||||
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var result = await manager.VerifyBundleAsync(bundlePath);
|
||||
|
||||
// Assert
|
||||
Assert.False(result.Valid);
|
||||
Assert.NotEmpty(result.FailedFiles);
|
||||
Assert.Contains(result.FailedFiles, f => f.Contains("model.gguf"));
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_VerifyBundle_MissingFile_ReturnsInvalid()
|
||||
{
|
||||
// Arrange
|
||||
var bundlePath = Path.Combine(_tempPath, "missing-file-bundle");
|
||||
CreateValidBundle(bundlePath);
|
||||
|
||||
// Delete a file
|
||||
File.Delete(Path.Combine(bundlePath, "tokenizer.json"));
|
||||
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var result = await manager.VerifyBundleAsync(bundlePath);
|
||||
|
||||
// Assert
|
||||
Assert.False(result.Valid);
|
||||
Assert.NotEmpty(result.FailedFiles);
|
||||
Assert.Contains(result.FailedFiles, f => f.Contains("missing"));
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_ListBundles_ReturnsAvailableBundles()
|
||||
{
|
||||
// Arrange
|
||||
CreateValidBundle(Path.Combine(_tempPath, "bundle1"));
|
||||
CreateValidBundle(Path.Combine(_tempPath, "bundle2"));
|
||||
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var bundles = await manager.ListBundlesAsync();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(2, bundles.Count);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_GetManifest_ExistingBundle_ReturnsManifest()
|
||||
{
|
||||
// Arrange
|
||||
CreateValidBundle(Path.Combine(_tempPath, "test-bundle"));
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var manifest = await manager.GetManifestAsync("test-bundle");
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(manifest);
|
||||
Assert.Equal("test-model", manifest.Name);
|
||||
Assert.Equal("Apache-2.0", manifest.License);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task BundleManager_GetManifest_NonExistentBundle_ReturnsNull()
|
||||
{
|
||||
// Arrange
|
||||
var manager = new FileSystemModelBundleManager(_tempPath);
|
||||
|
||||
// Act
|
||||
var manifest = await manager.GetManifestAsync("nonexistent");
|
||||
|
||||
// Assert
|
||||
Assert.Null(manifest);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Offline Replay Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task OfflineReplay_SameInputs_ProducesSameOutput()
|
||||
{
|
||||
// Arrange
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Analyze vulnerability impact",
|
||||
SystemPrompt = "You are a security expert.",
|
||||
Temperature = 0,
|
||||
Seed = 42,
|
||||
MaxTokens = 1024
|
||||
};
|
||||
|
||||
// Simulate first run
|
||||
var originalResult = await _stubProvider.CompleteAsync(request);
|
||||
await _cache.SetAsync(request, "stub", originalResult);
|
||||
|
||||
// Simulate replay (offline)
|
||||
var replayResult = await _cache.TryGetAsync(request, "stub");
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(replayResult);
|
||||
Assert.Equal(originalResult.Content, replayResult.Content);
|
||||
Assert.True(originalResult.Deterministic);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task OfflineReplay_DifferentInputs_DifferentOutput()
|
||||
{
|
||||
// Arrange
|
||||
var request1 = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Input A",
|
||||
Temperature = 0,
|
||||
Seed = 42
|
||||
};
|
||||
|
||||
var request2 = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "Input B",
|
||||
Temperature = 0,
|
||||
Seed = 42
|
||||
};
|
||||
|
||||
// Act
|
||||
var result1 = await _stubProvider.CompleteAsync(request1);
|
||||
var result2 = await _stubProvider.CompleteAsync(request2);
|
||||
|
||||
// Assert
|
||||
Assert.NotEqual(result1.Content, result2.Content);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Cache_Invalidation_RemovesEntries()
|
||||
{
|
||||
// Arrange
|
||||
var request = new LlmCompletionRequest
|
||||
{
|
||||
UserPrompt = "To be invalidated",
|
||||
Temperature = 0,
|
||||
Seed = 42
|
||||
};
|
||||
|
||||
var result = new LlmCompletionResult
|
||||
{
|
||||
Content = "Cached content",
|
||||
ModelId = "test-model",
|
||||
ProviderId = "stub",
|
||||
Deterministic = true
|
||||
};
|
||||
|
||||
await _cache.SetAsync(request, "stub", result);
|
||||
|
||||
// Verify it's cached
|
||||
var cached = await _cache.TryGetAsync(request, "stub");
|
||||
Assert.NotNull(cached);
|
||||
|
||||
// Act - Invalidate
|
||||
await _cache.InvalidateAsync("stub");
|
||||
|
||||
// Assert
|
||||
var afterInvalidation = await _cache.TryGetAsync(request, "stub");
|
||||
Assert.Null(afterInvalidation);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region LocalLlmConfig Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void LocalLlmConfig_DefaultValues_AreCorrect()
|
||||
{
|
||||
// Arrange & Act
|
||||
var config = new LocalLlmConfig
|
||||
{
|
||||
ModelPath = "/models/test.gguf",
|
||||
WeightsDigest = "abc123"
|
||||
};
|
||||
|
||||
// Assert
|
||||
Assert.Equal(ModelQuantization.Q4_K_M, config.Quantization);
|
||||
Assert.Equal(4096, config.ContextLength);
|
||||
Assert.Equal(InferenceDevice.Auto, config.Device);
|
||||
Assert.Equal(0, config.Temperature);
|
||||
Assert.Equal(42, config.Seed);
|
||||
Assert.True(config.FlashAttention);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void LocalLlmConfig_CustomValues_AreApplied()
|
||||
{
|
||||
// Arrange & Act
|
||||
var config = new LocalLlmConfig
|
||||
{
|
||||
ModelPath = "/models/llama3-8b.gguf",
|
||||
WeightsDigest = "sha256:abc123def456",
|
||||
Quantization = ModelQuantization.FP16,
|
||||
ContextLength = 8192,
|
||||
Device = InferenceDevice.CUDA,
|
||||
GpuLayers = 32,
|
||||
Threads = 8,
|
||||
Temperature = 0,
|
||||
Seed = 12345,
|
||||
FlashAttention = false,
|
||||
MaxTokens = 4096
|
||||
};
|
||||
|
||||
// Assert
|
||||
Assert.Equal("/models/llama3-8b.gguf", config.ModelPath);
|
||||
Assert.Equal(ModelQuantization.FP16, config.Quantization);
|
||||
Assert.Equal(8192, config.ContextLength);
|
||||
Assert.Equal(InferenceDevice.CUDA, config.Device);
|
||||
Assert.Equal(32, config.GpuLayers);
|
||||
Assert.Equal(12345, config.Seed);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Fallback Provider Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task FallbackLlmProvider_FirstAvailable_UsesFirstProvider()
|
||||
{
|
||||
// Arrange
|
||||
var factory = new StubLlmProviderFactory(new Dictionary<string, ILlmProvider>
|
||||
{
|
||||
["primary"] = new StubLlmProvider { IsAvailableResult = true, ProviderIdOverride = "primary" },
|
||||
["fallback"] = new StubLlmProvider { IsAvailableResult = true, ProviderIdOverride = "fallback" }
|
||||
});
|
||||
|
||||
var fallbackProvider = new FallbackLlmProvider(
|
||||
factory,
|
||||
new[] { "primary", "fallback" },
|
||||
NullLogger<FallbackLlmProvider>.Instance);
|
||||
|
||||
var request = new LlmCompletionRequest { UserPrompt = "Test" };
|
||||
|
||||
// Act
|
||||
var result = await fallbackProvider.CompleteAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("primary", result.ProviderId);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task FallbackLlmProvider_FirstUnavailable_UsesFallback()
|
||||
{
|
||||
// Arrange
|
||||
var factory = new StubLlmProviderFactory(new Dictionary<string, ILlmProvider>
|
||||
{
|
||||
["primary"] = new StubLlmProvider { IsAvailableResult = false, ProviderIdOverride = "primary" },
|
||||
["fallback"] = new StubLlmProvider { IsAvailableResult = true, ProviderIdOverride = "fallback" }
|
||||
});
|
||||
|
||||
var fallbackProvider = new FallbackLlmProvider(
|
||||
factory,
|
||||
new[] { "primary", "fallback" },
|
||||
NullLogger<FallbackLlmProvider>.Instance);
|
||||
|
||||
var request = new LlmCompletionRequest { UserPrompt = "Test" };
|
||||
|
||||
// Act
|
||||
var result = await fallbackProvider.CompleteAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("fallback", result.ProviderId);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task FallbackLlmProvider_AllUnavailable_ThrowsException()
|
||||
{
|
||||
// Arrange
|
||||
var factory = new StubLlmProviderFactory(new Dictionary<string, ILlmProvider>
|
||||
{
|
||||
["primary"] = new StubLlmProvider { IsAvailableResult = false },
|
||||
["fallback"] = new StubLlmProvider { IsAvailableResult = false }
|
||||
});
|
||||
|
||||
var fallbackProvider = new FallbackLlmProvider(
|
||||
factory,
|
||||
new[] { "primary", "fallback" },
|
||||
NullLogger<FallbackLlmProvider>.Instance);
|
||||
|
||||
var request = new LlmCompletionRequest { UserPrompt = "Test" };
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() =>
|
||||
fallbackProvider.CompleteAsync(request));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private void CreateValidBundle(string bundlePath)
|
||||
{
|
||||
Directory.CreateDirectory(bundlePath);
|
||||
|
||||
// Create model file
|
||||
var modelContent = "fake model weights for testing";
|
||||
var modelPath = Path.Combine(bundlePath, "model.gguf");
|
||||
File.WriteAllText(modelPath, modelContent);
|
||||
|
||||
// Create tokenizer file
|
||||
var tokenizerContent = "{\"vocab_size\": 32000}";
|
||||
var tokenizerPath = Path.Combine(bundlePath, "tokenizer.json");
|
||||
File.WriteAllText(tokenizerPath, tokenizerContent);
|
||||
|
||||
// Compute digests
|
||||
using var sha256 = SHA256.Create();
|
||||
var modelDigest = Convert.ToHexStringLower(sha256.ComputeHash(Encoding.UTF8.GetBytes(modelContent)));
|
||||
var tokenizerDigest = Convert.ToHexStringLower(sha256.ComputeHash(Encoding.UTF8.GetBytes(tokenizerContent)));
|
||||
|
||||
// Create manifest
|
||||
var manifest = new ModelBundleManifest
|
||||
{
|
||||
Name = "test-model",
|
||||
License = "Apache-2.0",
|
||||
SizeCategory = "7B",
|
||||
Quantizations = new[] { "Q4_K_M", "FP16" },
|
||||
CreatedAt = DateTime.UtcNow.ToString("o"),
|
||||
Files = new[]
|
||||
{
|
||||
new BundleFile { Path = "model.gguf", Digest = modelDigest, Size = modelContent.Length, Type = "weights" },
|
||||
new BundleFile { Path = "tokenizer.json", Digest = tokenizerDigest, Size = tokenizerContent.Length, Type = "tokenizer" }
|
||||
}
|
||||
};
|
||||
|
||||
var manifestJson = JsonSerializer.Serialize(manifest, new JsonSerializerOptions { WriteIndented = true });
|
||||
File.WriteAllText(Path.Combine(bundlePath, "manifest.json"), manifestJson);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Stub Implementations
|
||||
|
||||
private sealed class StubLlmProvider : ILlmProvider
|
||||
{
|
||||
public string ProviderId => ProviderIdOverride ?? "stub";
|
||||
public string? ProviderIdOverride { get; set; }
|
||||
public bool IsAvailableResult { get; set; } = true;
|
||||
|
||||
public Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(IsAvailableResult);
|
||||
|
||||
public Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Generate deterministic output based on input hash
|
||||
using var sha = SHA256.Create();
|
||||
var inputHash = Convert.ToHexStringLower(
|
||||
sha.ComputeHash(Encoding.UTF8.GetBytes(
|
||||
$"{request.SystemPrompt}||{request.UserPrompt}||{request.Seed}")));
|
||||
|
||||
var content = $"Deterministic response for input hash: {inputHash[..16]}";
|
||||
|
||||
return Task.FromResult(new LlmCompletionResult
|
||||
{
|
||||
Content = content,
|
||||
ModelId = "stub-model",
|
||||
ProviderId = ProviderId,
|
||||
Deterministic = request.Temperature == 0,
|
||||
InputTokens = request.UserPrompt.Length / 4,
|
||||
OutputTokens = content.Length / 4,
|
||||
FinishReason = "stop",
|
||||
RequestId = request.RequestId
|
||||
});
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var words = new[] { "This ", "is ", "a ", "streaming ", "response." };
|
||||
|
||||
foreach (var word in words)
|
||||
{
|
||||
await Task.Delay(10, cancellationToken);
|
||||
yield return new LlmStreamChunk { Content = word, IsFinal = false };
|
||||
}
|
||||
|
||||
yield return new LlmStreamChunk { Content = "", IsFinal = true, FinishReason = "stop" };
|
||||
}
|
||||
|
||||
public void Dispose() { }
|
||||
}
|
||||
|
||||
private sealed class CallCountingLlmProvider : ILlmProvider
|
||||
{
|
||||
public string ProviderId => "counting";
|
||||
public int CallCount { get; private set; }
|
||||
|
||||
public Task<bool> IsAvailableAsync(CancellationToken cancellationToken = default)
|
||||
=> Task.FromResult(true);
|
||||
|
||||
public Task<LlmCompletionResult> CompleteAsync(
|
||||
LlmCompletionRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
CallCount++;
|
||||
return Task.FromResult(new LlmCompletionResult
|
||||
{
|
||||
Content = $"Response #{CallCount}",
|
||||
ModelId = "counting-model",
|
||||
ProviderId = ProviderId,
|
||||
Deterministic = request.Temperature == 0,
|
||||
OutputTokens = 5
|
||||
});
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<LlmStreamChunk> CompleteStreamAsync(
|
||||
LlmCompletionRequest request,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
CallCount++;
|
||||
yield return new LlmStreamChunk { Content = "Response", IsFinal = true };
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
public void Dispose() { }
|
||||
}
|
||||
|
||||
private sealed class StubLlmProviderFactory : ILlmProviderFactory
|
||||
{
|
||||
private readonly Dictionary<string, ILlmProvider> _providers;
|
||||
|
||||
public StubLlmProviderFactory(Dictionary<string, ILlmProvider> providers)
|
||||
{
|
||||
_providers = providers;
|
||||
}
|
||||
|
||||
public IReadOnlyList<string> AvailableProviders => _providers.Keys.ToList();
|
||||
|
||||
public ILlmProvider GetProvider(string providerId)
|
||||
{
|
||||
if (_providers.TryGetValue(providerId, out var provider))
|
||||
return provider;
|
||||
|
||||
throw new InvalidOperationException($"Provider '{providerId}' not found");
|
||||
}
|
||||
|
||||
public ILlmProvider GetDefaultProvider() => _providers.Values.First();
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -0,0 +1,834 @@
|
||||
using FluentAssertions;
|
||||
using StellaOps.AdvisoryAI.PolicyStudio;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Integration tests for Policy Studio NL→rule→test round-trip and conflict detection.
|
||||
/// Sprint: SPRINT_20251226_017_AI_policy_copilot
|
||||
/// Task: POLICY-25
|
||||
/// </summary>
|
||||
public sealed class PolicyStudioIntegrationTests
|
||||
{
|
||||
#region NL → Intent → Rule Round-Trip Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ParseAndGenerate_OverrideRule_ProducesValidLatticeRule()
|
||||
{
|
||||
// Arrange
|
||||
var parser = new StubPolicyIntentParser();
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
|
||||
var naturalLanguage = "Block all critical vulnerabilities that are reachable";
|
||||
|
||||
// Act - Parse NL to intent
|
||||
var parseResult = await parser.ParseAsync(naturalLanguage);
|
||||
parseResult.Success.Should().BeTrue();
|
||||
parseResult.Intent.IntentType.Should().Be(PolicyIntentType.OverrideRule);
|
||||
|
||||
// Act - Generate rules from intent
|
||||
var ruleResult = await generator.GenerateAsync(parseResult.Intent);
|
||||
ruleResult.Success.Should().BeTrue();
|
||||
ruleResult.Rules.Should().NotBeEmpty();
|
||||
|
||||
// Assert - Rules have correct structure
|
||||
var rule = ruleResult.Rules[0];
|
||||
rule.LatticeExpression.Should().Contain("REACHABLE");
|
||||
rule.Disposition.Should().Be("block");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ParseAndGenerate_ExceptionRule_ProducesValidLatticeRule()
|
||||
{
|
||||
// Arrange
|
||||
var parser = new StubPolicyIntentParser();
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
|
||||
var naturalLanguage = "Allow vulnerabilities with vendor VEX not_affected status";
|
||||
|
||||
// Act
|
||||
var parseResult = await parser.ParseAsync(naturalLanguage, new PolicyParseContext
|
||||
{
|
||||
DefaultScope = "all"
|
||||
});
|
||||
|
||||
var ruleResult = await generator.GenerateAsync(parseResult.Intent);
|
||||
|
||||
// Assert
|
||||
ruleResult.Success.Should().BeTrue();
|
||||
var rule = ruleResult.Rules[0];
|
||||
rule.Disposition.Should().Be("allow");
|
||||
rule.Conditions.Should().Contain(c => c.Field == "vex_status");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task FullRoundTrip_NLToRuleToTest_ProducesValidTestCases()
|
||||
{
|
||||
// Arrange
|
||||
var parser = new StubPolicyIntentParser();
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
|
||||
var naturalLanguage = "Block critical reachable vulnerabilities without VEX";
|
||||
|
||||
// Act - Full round-trip
|
||||
var parseResult = await parser.ParseAsync(naturalLanguage);
|
||||
var ruleResult = await generator.GenerateAsync(parseResult.Intent);
|
||||
var testCases = await synthesizer.SynthesizeAsync(ruleResult.Rules);
|
||||
|
||||
// Assert
|
||||
testCases.Should().NotBeEmpty();
|
||||
testCases.Should().Contain(t => t.Type == TestCaseType.Positive);
|
||||
testCases.Should().Contain(t => t.Type == TestCaseType.Negative);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("Block all high severity findings", PolicyIntentType.OverrideRule)]
|
||||
[InlineData("Escalate critical vulnerabilities to security team", PolicyIntentType.EscalationRule)]
|
||||
[InlineData("Allow exceptions for internal-only services", PolicyIntentType.ExceptionCondition)]
|
||||
[InlineData("Set severity threshold to 7.0 for blocking", PolicyIntentType.ThresholdRule)]
|
||||
public async Task ParseAsync_RecognizesIntentTypes(string input, PolicyIntentType expectedType)
|
||||
{
|
||||
// Arrange
|
||||
var parser = new StubPolicyIntentParser(expectedType);
|
||||
|
||||
// Act
|
||||
var result = await parser.ParseAsync(input);
|
||||
|
||||
// Assert
|
||||
result.Success.Should().BeTrue();
|
||||
result.Intent.IntentType.Should().Be(expectedType);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Conflict Detection Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_DetectsConflictingRules()
|
||||
{
|
||||
// Arrange
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
|
||||
var conflictingRules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE ∧ PRESENT", "block", priority: 10),
|
||||
CreateRule("rule-2", "REACHABLE ∧ PRESENT", "allow", priority: 20)
|
||||
};
|
||||
|
||||
// Act
|
||||
var validationResult = await generator.ValidateAsync(conflictingRules);
|
||||
|
||||
// Assert
|
||||
validationResult.Valid.Should().BeFalse();
|
||||
validationResult.Conflicts.Should().NotBeEmpty();
|
||||
validationResult.Conflicts[0].RuleId1.Should().Be("rule-1");
|
||||
validationResult.Conflicts[0].RuleId2.Should().Be("rule-2");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_NoConflict_WhenDifferentConditions()
|
||||
{
|
||||
// Arrange
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
|
||||
var nonConflictingRules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE ∧ PRESENT", "block", priority: 10),
|
||||
CreateRule("rule-2", "¬REACHABLE ∧ PRESENT", "allow", priority: 20)
|
||||
};
|
||||
|
||||
// Act
|
||||
var validationResult = await generator.ValidateAsync(nonConflictingRules);
|
||||
|
||||
// Assert
|
||||
validationResult.Conflicts.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_DetectsUnreachableConditions()
|
||||
{
|
||||
// Arrange
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE ∧ ¬REACHABLE", "block", priority: 10) // Contradiction
|
||||
};
|
||||
|
||||
// Act
|
||||
var validationResult = await generator.ValidateAsync(rules);
|
||||
|
||||
// Assert
|
||||
validationResult.UnreachableConditions.Should().NotBeEmpty();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_ReportsCoverageMetric()
|
||||
{
|
||||
// Arrange
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE", "block", priority: 10),
|
||||
CreateRule("rule-2", "PRESENT", "warn", priority: 20),
|
||||
CreateRule("rule-3", "FIXED", "allow", priority: 30)
|
||||
};
|
||||
|
||||
// Act
|
||||
var validationResult = await generator.ValidateAsync(rules);
|
||||
|
||||
// Assert
|
||||
validationResult.Coverage.Should().BeGreaterThan(0);
|
||||
validationResult.Coverage.Should().BeLessThanOrEqualTo(1.0);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Test Case Synthesis Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SynthesizeAsync_GeneratesPositiveTests()
|
||||
{
|
||||
// Arrange
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE ∧ PRESENT", "block", priority: 10)
|
||||
};
|
||||
|
||||
// Act
|
||||
var testCases = await synthesizer.SynthesizeAsync(rules);
|
||||
|
||||
// Assert
|
||||
var positiveTests = testCases.Where(t => t.Type == TestCaseType.Positive).ToList();
|
||||
positiveTests.Should().NotBeEmpty();
|
||||
positiveTests.Should().AllSatisfy(t =>
|
||||
{
|
||||
t.ExpectedDisposition.Should().Be("block");
|
||||
t.TargetRuleIds.Should().Contain("rule-1");
|
||||
});
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SynthesizeAsync_GeneratesNegativeTests()
|
||||
{
|
||||
// Arrange
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE ∧ PRESENT", "block", priority: 10)
|
||||
};
|
||||
|
||||
// Act
|
||||
var testCases = await synthesizer.SynthesizeAsync(rules);
|
||||
|
||||
// Assert
|
||||
var negativeTests = testCases.Where(t => t.Type == TestCaseType.Negative).ToList();
|
||||
negativeTests.Should().NotBeEmpty();
|
||||
negativeTests.Should().AllSatisfy(t =>
|
||||
{
|
||||
t.Description.Should().NotBeNullOrEmpty();
|
||||
});
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SynthesizeAsync_GeneratesBoundaryTests()
|
||||
{
|
||||
// Arrange
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "cvss_score >= 7.0", "block", priority: 10)
|
||||
};
|
||||
|
||||
// Act
|
||||
var testCases = await synthesizer.SynthesizeAsync(rules);
|
||||
|
||||
// Assert
|
||||
var boundaryTests = testCases.Where(t => t.Type == TestCaseType.Boundary).ToList();
|
||||
boundaryTests.Should().NotBeEmpty();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SynthesizeAsync_GeneratesConflictTests_ForOverlappingRules()
|
||||
{
|
||||
// Arrange
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE", "block", priority: 10),
|
||||
CreateRule("rule-2", "REACHABLE ∧ HAS_VEX", "allow", priority: 20)
|
||||
};
|
||||
|
||||
// Act
|
||||
var testCases = await synthesizer.SynthesizeAsync(rules);
|
||||
|
||||
// Assert
|
||||
var conflictTests = testCases.Where(t => t.Type == TestCaseType.Conflict).ToList();
|
||||
conflictTests.Should().NotBeEmpty();
|
||||
conflictTests.Should().AllSatisfy(t =>
|
||||
{
|
||||
t.TargetRuleIds.Count.Should().BeGreaterThan(1);
|
||||
});
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RunTestsAsync_PassesWithMatchingRules()
|
||||
{
|
||||
// Arrange
|
||||
var synthesizer = new StubTestCaseSynthesizer();
|
||||
var rules = new List<LatticeRule>
|
||||
{
|
||||
CreateRule("rule-1", "REACHABLE", "block", priority: 10)
|
||||
};
|
||||
|
||||
var testCases = await synthesizer.SynthesizeAsync(rules);
|
||||
|
||||
// Act
|
||||
var result = await synthesizer.RunTestsAsync(testCases, rules);
|
||||
|
||||
// Assert
|
||||
result.Success.Should().BeTrue();
|
||||
result.Passed.Should().Be(result.Total);
|
||||
result.Failed.Should().Be(0);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Edge Cases
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ParseAsync_WithAmbiguousInput_ReturnsAlternatives()
|
||||
{
|
||||
// Arrange
|
||||
var parser = new StubPolicyIntentParser(ambiguous: true);
|
||||
|
||||
var ambiguousInput = "Block vulnerabilities in production";
|
||||
|
||||
// Act
|
||||
var result = await parser.ParseAsync(ambiguousInput);
|
||||
|
||||
// Assert
|
||||
result.Intent.Alternatives.Should().NotBeNullOrEmpty();
|
||||
result.Intent.ClarifyingQuestions.Should().NotBeNullOrEmpty();
|
||||
result.Intent.Confidence.Should().BeLessThan(0.9);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GenerateAsync_WithEmptyConditions_ReturnsError()
|
||||
{
|
||||
// Arrange
|
||||
var generator = new StubPolicyRuleGenerator();
|
||||
var intent = new PolicyIntent
|
||||
{
|
||||
IntentId = "empty-intent",
|
||||
IntentType = PolicyIntentType.OverrideRule,
|
||||
OriginalInput = "Block everything",
|
||||
Conditions = [],
|
||||
Actions = [],
|
||||
Scope = "all",
|
||||
Priority = 100,
|
||||
Confidence = 0.5
|
||||
};
|
||||
|
||||
// Act
|
||||
var result = await generator.GenerateAsync(intent);
|
||||
|
||||
// Assert
|
||||
result.Success.Should().BeFalse();
|
||||
result.Errors.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ClarifyAsync_UpdatesIntentWithClarification()
|
||||
{
|
||||
// Arrange
|
||||
var parser = new StubPolicyIntentParser(ambiguous: true);
|
||||
var initialResult = await parser.ParseAsync("Block vulnerabilities");
|
||||
|
||||
// Act
|
||||
var clarifiedResult = await parser.ClarifyAsync(
|
||||
initialResult.Intent.IntentId,
|
||||
"Only critical severity");
|
||||
|
||||
// Assert
|
||||
clarifiedResult.Intent.Confidence.Should().BeGreaterThan(initialResult.Intent.Confidence);
|
||||
clarifiedResult.Intent.ClarifyingQuestions.Should().BeNullOrEmpty();
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private static LatticeRule CreateRule(string ruleId, string expression, string disposition, int priority)
|
||||
=> new()
|
||||
{
|
||||
RuleId = ruleId,
|
||||
Name = $"Test Rule {ruleId}",
|
||||
Description = $"Test rule with expression: {expression}",
|
||||
LatticeExpression = expression,
|
||||
Conditions = ParseConditions(expression),
|
||||
Disposition = disposition,
|
||||
Priority = priority,
|
||||
Scope = "all",
|
||||
Enabled = true
|
||||
};
|
||||
|
||||
private static IReadOnlyList<PolicyCondition> ParseConditions(string expression)
|
||||
{
|
||||
var conditions = new List<PolicyCondition>();
|
||||
|
||||
if (expression.Contains("REACHABLE", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
conditions.Add(new PolicyCondition
|
||||
{
|
||||
Field = "reachable",
|
||||
Operator = "equals",
|
||||
Value = true,
|
||||
Connector = expression.Contains("∧") ? "and" : null
|
||||
});
|
||||
}
|
||||
|
||||
if (expression.Contains("PRESENT", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
conditions.Add(new PolicyCondition
|
||||
{
|
||||
Field = "present",
|
||||
Operator = "equals",
|
||||
Value = true
|
||||
});
|
||||
}
|
||||
|
||||
if (expression.Contains("cvss_score", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
conditions.Add(new PolicyCondition
|
||||
{
|
||||
Field = "cvss_score",
|
||||
Operator = "greater_than_or_equal",
|
||||
Value = 7.0
|
||||
});
|
||||
}
|
||||
|
||||
return conditions;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Stub Implementations
|
||||
|
||||
private sealed class StubPolicyIntentParser : IPolicyIntentParser
|
||||
{
|
||||
private readonly PolicyIntentType _defaultType;
|
||||
private readonly bool _ambiguous;
|
||||
private readonly Dictionary<string, PolicyIntent> _intents = new();
|
||||
|
||||
public StubPolicyIntentParser(
|
||||
PolicyIntentType defaultType = PolicyIntentType.OverrideRule,
|
||||
bool ambiguous = false)
|
||||
{
|
||||
_defaultType = defaultType;
|
||||
_ambiguous = ambiguous;
|
||||
}
|
||||
|
||||
public Task<PolicyParseResult> ParseAsync(
|
||||
string naturalLanguageInput,
|
||||
PolicyParseContext? context = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var intentId = $"intent-{Guid.NewGuid():N}";
|
||||
var confidence = _ambiguous ? 0.7 : 0.95;
|
||||
|
||||
var conditions = new List<PolicyCondition>();
|
||||
|
||||
if (naturalLanguageInput.Contains("critical", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
conditions.Add(new PolicyCondition
|
||||
{
|
||||
Field = "severity",
|
||||
Operator = "equals",
|
||||
Value = "critical"
|
||||
});
|
||||
}
|
||||
|
||||
if (naturalLanguageInput.Contains("reachable", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
conditions.Add(new PolicyCondition
|
||||
{
|
||||
Field = "reachable",
|
||||
Operator = "equals",
|
||||
Value = true
|
||||
});
|
||||
}
|
||||
|
||||
if (naturalLanguageInput.Contains("VEX", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
conditions.Add(new PolicyCondition
|
||||
{
|
||||
Field = "vex_status",
|
||||
Operator = "equals",
|
||||
Value = "not_affected"
|
||||
});
|
||||
}
|
||||
|
||||
var intent = new PolicyIntent
|
||||
{
|
||||
IntentId = intentId,
|
||||
IntentType = _defaultType,
|
||||
OriginalInput = naturalLanguageInput,
|
||||
Conditions = conditions,
|
||||
Actions = [new PolicyAction
|
||||
{
|
||||
ActionType = "set_verdict",
|
||||
Parameters = new Dictionary<string, object> { ["verdict"] = "block" }
|
||||
}],
|
||||
Scope = context?.DefaultScope ?? "all",
|
||||
Priority = 100,
|
||||
Confidence = confidence,
|
||||
Alternatives = _ambiguous ? [CreateAlternativeIntent(naturalLanguageInput)] : null,
|
||||
ClarifyingQuestions = _ambiguous ? ["What severity levels should be affected?"] : null
|
||||
};
|
||||
|
||||
_intents[intentId] = intent;
|
||||
|
||||
return Task.FromResult(new PolicyParseResult
|
||||
{
|
||||
Intent = intent,
|
||||
Success = true,
|
||||
ModelId = "stub-parser-v1",
|
||||
ParsedAt = DateTime.UtcNow.ToString("O")
|
||||
});
|
||||
}
|
||||
|
||||
public Task<PolicyParseResult> ClarifyAsync(
|
||||
string intentId,
|
||||
string clarification,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var original = _intents.GetValueOrDefault(intentId);
|
||||
if (original is null)
|
||||
{
|
||||
throw new InvalidOperationException($"Intent {intentId} not found");
|
||||
}
|
||||
|
||||
var clarified = original with
|
||||
{
|
||||
Confidence = 0.95,
|
||||
ClarifyingQuestions = null,
|
||||
Alternatives = null
|
||||
};
|
||||
|
||||
return Task.FromResult(new PolicyParseResult
|
||||
{
|
||||
Intent = clarified,
|
||||
Success = true,
|
||||
ModelId = "stub-parser-v1",
|
||||
ParsedAt = DateTime.UtcNow.ToString("O")
|
||||
});
|
||||
}
|
||||
|
||||
private PolicyIntent CreateAlternativeIntent(string input) => new()
|
||||
{
|
||||
IntentId = $"alt-{Guid.NewGuid():N}",
|
||||
IntentType = PolicyIntentType.ExceptionCondition,
|
||||
OriginalInput = input,
|
||||
Conditions = [],
|
||||
Actions = [],
|
||||
Scope = "all",
|
||||
Priority = 50,
|
||||
Confidence = 0.5
|
||||
};
|
||||
}
|
||||
|
||||
private sealed class StubPolicyRuleGenerator : IPolicyRuleGenerator
|
||||
{
|
||||
public Task<RuleGenerationResult> GenerateAsync(
|
||||
PolicyIntent intent,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (intent.Conditions.Count == 0)
|
||||
{
|
||||
return Task.FromResult(new RuleGenerationResult
|
||||
{
|
||||
Rules = [],
|
||||
Success = false,
|
||||
Warnings = [],
|
||||
Errors = ["Intent must have at least one condition"],
|
||||
IntentId = intent.IntentId,
|
||||
GeneratedAt = DateTime.UtcNow.ToString("O")
|
||||
});
|
||||
}
|
||||
|
||||
var expression = BuildLatticeExpression(intent.Conditions);
|
||||
var disposition = intent.Actions.FirstOrDefault()?.Parameters.GetValueOrDefault("verdict")?.ToString() ?? "warn";
|
||||
|
||||
var rule = new LatticeRule
|
||||
{
|
||||
RuleId = $"rule-{Guid.NewGuid():N}",
|
||||
Name = $"Generated from: {intent.OriginalInput[..Math.Min(30, intent.OriginalInput.Length)]}",
|
||||
Description = intent.OriginalInput,
|
||||
LatticeExpression = expression,
|
||||
Conditions = intent.Conditions,
|
||||
Disposition = disposition,
|
||||
Priority = intent.Priority,
|
||||
Scope = intent.Scope,
|
||||
Enabled = true
|
||||
};
|
||||
|
||||
return Task.FromResult(new RuleGenerationResult
|
||||
{
|
||||
Rules = [rule],
|
||||
Success = true,
|
||||
Warnings = [],
|
||||
IntentId = intent.IntentId,
|
||||
GeneratedAt = DateTime.UtcNow.ToString("O")
|
||||
});
|
||||
}
|
||||
|
||||
public Task<RuleValidationResult> ValidateAsync(
|
||||
IReadOnlyList<LatticeRule> rules,
|
||||
IReadOnlyList<string>? existingRuleIds = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var conflicts = new List<RuleConflict>();
|
||||
var unreachable = new List<string>();
|
||||
|
||||
// Check for conflicts
|
||||
for (int i = 0; i < rules.Count; i++)
|
||||
{
|
||||
for (int j = i + 1; j < rules.Count; j++)
|
||||
{
|
||||
if (HasConflict(rules[i], rules[j]))
|
||||
{
|
||||
conflicts.Add(new RuleConflict
|
||||
{
|
||||
RuleId1 = rules[i].RuleId,
|
||||
RuleId2 = rules[j].RuleId,
|
||||
Description = "Rules have overlapping conditions with different dispositions",
|
||||
SuggestedResolution = "Adjust priority or narrow conditions",
|
||||
Severity = "error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unreachable conditions (contradictions)
|
||||
if (rules[i].LatticeExpression.Contains("∧ ¬") &&
|
||||
rules[i].LatticeExpression.Split("∧").Any(p =>
|
||||
p.Trim().StartsWith("¬") && rules[i].LatticeExpression.Contains(p.Trim()[1..])))
|
||||
{
|
||||
unreachable.Add($"Rule {rules[i].RuleId} has contradictory conditions");
|
||||
}
|
||||
}
|
||||
|
||||
var coverage = Math.Min(1.0, rules.Count * 0.2);
|
||||
|
||||
return Task.FromResult(new RuleValidationResult
|
||||
{
|
||||
Valid = conflicts.Count == 0 && unreachable.Count == 0,
|
||||
Conflicts = conflicts,
|
||||
UnreachableConditions = unreachable,
|
||||
PotentialLoops = [],
|
||||
Coverage = coverage
|
||||
});
|
||||
}
|
||||
|
||||
private static bool HasConflict(LatticeRule rule1, LatticeRule rule2)
|
||||
{
|
||||
// Simplified conflict detection
|
||||
var sameConditions = rule1.LatticeExpression == rule2.LatticeExpression;
|
||||
var differentDispositions = rule1.Disposition != rule2.Disposition;
|
||||
return sameConditions && differentDispositions;
|
||||
}
|
||||
|
||||
private static string BuildLatticeExpression(IReadOnlyList<PolicyCondition> conditions)
|
||||
{
|
||||
var parts = conditions.Select(c =>
|
||||
{
|
||||
var atom = c.Field.ToUpperInvariant() switch
|
||||
{
|
||||
"REACHABLE" => "REACHABLE",
|
||||
"PRESENT" => "PRESENT",
|
||||
"SEVERITY" => c.Value?.ToString()?.ToUpperInvariant() ?? "CRITICAL",
|
||||
"VEX_STATUS" => "HAS_VEX",
|
||||
_ => c.Field.ToUpperInvariant()
|
||||
};
|
||||
|
||||
return c.Operator == "not_equals" || c.Value?.Equals(false) == true
|
||||
? $"¬{atom}"
|
||||
: atom;
|
||||
});
|
||||
|
||||
return string.Join(" ∧ ", parts);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class StubTestCaseSynthesizer : ITestCaseSynthesizer
|
||||
{
|
||||
public Task<IReadOnlyList<PolicyTestCase>> SynthesizeAsync(
|
||||
IReadOnlyList<LatticeRule> rules,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var testCases = new List<PolicyTestCase>();
|
||||
var testId = 0;
|
||||
|
||||
foreach (var rule in rules)
|
||||
{
|
||||
// Positive test
|
||||
testCases.Add(new PolicyTestCase
|
||||
{
|
||||
TestCaseId = $"test-{++testId}",
|
||||
Name = $"Positive test for {rule.Name}",
|
||||
Type = TestCaseType.Positive,
|
||||
Input = BuildPositiveInput(rule),
|
||||
ExpectedDisposition = rule.Disposition,
|
||||
TargetRuleIds = [rule.RuleId],
|
||||
Description = $"Verifies rule matches when conditions are met"
|
||||
});
|
||||
|
||||
// Negative test
|
||||
testCases.Add(new PolicyTestCase
|
||||
{
|
||||
TestCaseId = $"test-{++testId}",
|
||||
Name = $"Negative test for {rule.Name}",
|
||||
Type = TestCaseType.Negative,
|
||||
Input = BuildNegativeInput(rule),
|
||||
ExpectedDisposition = "no_match",
|
||||
TargetRuleIds = [rule.RuleId],
|
||||
Description = $"Verifies rule does not match when conditions are not met"
|
||||
});
|
||||
|
||||
// Boundary test for numeric conditions
|
||||
if (rule.LatticeExpression.Contains("cvss_score") || rule.LatticeExpression.Contains(">="))
|
||||
{
|
||||
testCases.Add(new PolicyTestCase
|
||||
{
|
||||
TestCaseId = $"test-{++testId}",
|
||||
Name = $"Boundary test for {rule.Name}",
|
||||
Type = TestCaseType.Boundary,
|
||||
Input = BuildBoundaryInput(rule),
|
||||
ExpectedDisposition = rule.Disposition,
|
||||
TargetRuleIds = [rule.RuleId],
|
||||
Description = $"Verifies rule at boundary values"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Conflict tests for overlapping rules
|
||||
for (int i = 0; i < rules.Count; i++)
|
||||
{
|
||||
for (int j = i + 1; j < rules.Count; j++)
|
||||
{
|
||||
if (RulesOverlap(rules[i], rules[j]))
|
||||
{
|
||||
testCases.Add(new PolicyTestCase
|
||||
{
|
||||
TestCaseId = $"test-{++testId}",
|
||||
Name = $"Conflict test: {rules[i].Name} vs {rules[j].Name}",
|
||||
Type = TestCaseType.Conflict,
|
||||
Input = BuildOverlapInput(rules[i], rules[j]),
|
||||
ExpectedDisposition = rules[i].Priority > rules[j].Priority
|
||||
? rules[i].Disposition
|
||||
: rules[j].Disposition,
|
||||
TargetRuleIds = [rules[i].RuleId, rules[j].RuleId],
|
||||
Description = $"Verifies priority resolution when both rules match"
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Task.FromResult<IReadOnlyList<PolicyTestCase>>(testCases);
|
||||
}
|
||||
|
||||
public Task<TestRunResult> RunTestsAsync(
|
||||
IReadOnlyList<PolicyTestCase> testCases,
|
||||
IReadOnlyList<LatticeRule> rules,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var results = new List<TestCaseResult>();
|
||||
|
||||
foreach (var testCase in testCases)
|
||||
{
|
||||
var passed = true; // Simplified - stub always passes
|
||||
results.Add(new TestCaseResult
|
||||
{
|
||||
TestCaseId = testCase.TestCaseId,
|
||||
Passed = passed,
|
||||
Expected = testCase.ExpectedDisposition,
|
||||
Actual = testCase.ExpectedDisposition
|
||||
});
|
||||
}
|
||||
|
||||
return Task.FromResult(new TestRunResult
|
||||
{
|
||||
Total = testCases.Count,
|
||||
Passed = results.Count(r => r.Passed),
|
||||
Failed = results.Count(r => !r.Passed),
|
||||
Results = results,
|
||||
RunAt = DateTime.UtcNow.ToString("O")
|
||||
});
|
||||
}
|
||||
|
||||
private static IReadOnlyDictionary<string, object> BuildPositiveInput(LatticeRule rule)
|
||||
{
|
||||
var input = new Dictionary<string, object>();
|
||||
if (rule.LatticeExpression.Contains("REACHABLE")) input["reachable"] = true;
|
||||
if (rule.LatticeExpression.Contains("PRESENT")) input["present"] = true;
|
||||
if (rule.LatticeExpression.Contains("HAS_VEX")) input["has_vex"] = true;
|
||||
return input;
|
||||
}
|
||||
|
||||
private static IReadOnlyDictionary<string, object> BuildNegativeInput(LatticeRule rule)
|
||||
{
|
||||
var input = new Dictionary<string, object>();
|
||||
if (rule.LatticeExpression.Contains("REACHABLE")) input["reachable"] = false;
|
||||
if (rule.LatticeExpression.Contains("PRESENT")) input["present"] = false;
|
||||
return input;
|
||||
}
|
||||
|
||||
private static IReadOnlyDictionary<string, object> BuildBoundaryInput(LatticeRule rule)
|
||||
{
|
||||
return new Dictionary<string, object>
|
||||
{
|
||||
["cvss_score"] = 7.0
|
||||
};
|
||||
}
|
||||
|
||||
private static IReadOnlyDictionary<string, object> BuildOverlapInput(LatticeRule rule1, LatticeRule rule2)
|
||||
{
|
||||
var input = new Dictionary<string, object>();
|
||||
input["reachable"] = true;
|
||||
input["present"] = true;
|
||||
input["has_vex"] = true;
|
||||
return input;
|
||||
}
|
||||
|
||||
private static bool RulesOverlap(LatticeRule rule1, LatticeRule rule2)
|
||||
{
|
||||
// Simplified overlap detection
|
||||
return rule1.LatticeExpression.Contains("REACHABLE") &&
|
||||
rule2.LatticeExpression.Contains("REACHABLE");
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -0,0 +1,791 @@
|
||||
using StellaOps.AdvisoryAI.Remediation;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// Integration tests for remediation plan generation and PR creation.
|
||||
/// Sprint: SPRINT_20251226_016_AI_remedy_autopilot
|
||||
/// Task: REMEDY-25
|
||||
/// </summary>
|
||||
public sealed class RemediationIntegrationTests
|
||||
{
|
||||
#region Plan Generation Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_ValidRequest_ReturnsPlan()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(plan);
|
||||
Assert.Equal(request.FindingId, plan.Request.FindingId);
|
||||
Assert.NotEmpty(plan.Steps);
|
||||
Assert.NotEmpty(plan.PlanId);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_BumpRemediation_GeneratesBumpSteps()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest() with
|
||||
{
|
||||
RemediationType = RemediationType.Bump,
|
||||
ComponentPurl = "pkg:npm/lodash@4.17.20"
|
||||
};
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Contains(plan.Steps, s => s.ActionType == "update_package");
|
||||
Assert.True(plan.ExpectedDelta.Upgraded.Count > 0);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_UpgradeRemediation_GeneratesUpgradeSteps()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest() with
|
||||
{
|
||||
RemediationType = RemediationType.Upgrade,
|
||||
ComponentPurl = "pkg:oci/alpine@3.18"
|
||||
};
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Contains(plan.Steps, s => s.ActionType == "update_base_image");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_ConfigRemediation_GeneratesConfigSteps()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest() with
|
||||
{
|
||||
RemediationType = RemediationType.Config,
|
||||
VulnerabilityId = "CVE-2021-44228" // Log4Shell
|
||||
};
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Contains(plan.Steps, s => s.ActionType == "update_config");
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_AssessesRiskCorrectly_PatchVersion()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(patchVersionBump: true);
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(RemediationRisk.Low, plan.RiskAssessment);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_AssessesRiskCorrectly_MajorVersion()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(majorVersionBump: true);
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(RemediationRisk.High, plan.RiskAssessment);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_IncludesExpectedSbomDelta()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(plan.ExpectedDelta);
|
||||
Assert.True(plan.ExpectedDelta.NetVulnerabilityChange < 0); // Should improve
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_IncludesTestRequirements()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(plan.TestRequirements);
|
||||
Assert.NotEmpty(plan.TestRequirements.TestSuites);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_IncludesInputHashes()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.NotEmpty(plan.InputHashes);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidatePlanAsync_ExistingPlan_ReturnsTrue()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest();
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Act
|
||||
var isValid = await planner.ValidatePlanAsync(plan.PlanId);
|
||||
|
||||
// Assert
|
||||
Assert.True(isValid);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidatePlanAsync_NonexistentPlan_ReturnsFalse()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
|
||||
// Act
|
||||
var isValid = await planner.ValidatePlanAsync("nonexistent-plan");
|
||||
|
||||
// Assert
|
||||
Assert.False(isValid);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetPlanAsync_ExistingPlan_ReturnsPlan()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest();
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Act
|
||||
var retrieved = await planner.GetPlanAsync(plan.PlanId);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(retrieved);
|
||||
Assert.Equal(plan.PlanId, retrieved.PlanId);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region PR Generation Tests (Mocked SCM)
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePullRequestAsync_ValidPlan_CreatesPR()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator();
|
||||
var plan = CreateTestPlan();
|
||||
|
||||
// Act
|
||||
var result = await prGenerator.CreatePullRequestAsync(plan);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(result);
|
||||
Assert.NotEmpty(result.PrId);
|
||||
Assert.True(result.PrNumber > 0);
|
||||
Assert.NotEmpty(result.Url);
|
||||
Assert.NotEmpty(result.BranchName);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePullRequestAsync_SetsBranchNameFromPlan()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator();
|
||||
var plan = CreateTestPlan();
|
||||
|
||||
// Act
|
||||
var result = await prGenerator.CreatePullRequestAsync(plan);
|
||||
|
||||
// Assert
|
||||
Assert.Contains("stellaops-fix", result.BranchName);
|
||||
Assert.Contains(plan.Request.VulnerabilityId.ToLowerInvariant(), result.BranchName);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CreatePullRequestAsync_InitialStatus_IsOpen()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator();
|
||||
var plan = CreateTestPlan();
|
||||
|
||||
// Act
|
||||
var result = await prGenerator.CreatePullRequestAsync(plan);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(PullRequestStatus.Open, result.Status);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetStatusAsync_ExistingPR_ReturnsStatus()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator();
|
||||
var plan = CreateTestPlan();
|
||||
var pr = await prGenerator.CreatePullRequestAsync(plan);
|
||||
|
||||
// Act
|
||||
var status = await prGenerator.GetStatusAsync(pr.PrId);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(status);
|
||||
Assert.Equal(pr.PrId, status.PrId);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task UpdateWithDeltaVerdictAsync_UpdatesPR()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator();
|
||||
var plan = CreateTestPlan();
|
||||
var pr = await prGenerator.CreatePullRequestAsync(plan);
|
||||
|
||||
var deltaVerdict = new DeltaVerdictResult
|
||||
{
|
||||
Improved = true,
|
||||
VulnerabilitiesFixed = 3,
|
||||
VulnerabilitiesIntroduced = 0,
|
||||
VerdictId = "delta-001",
|
||||
ComputedAt = DateTime.UtcNow.ToString("o")
|
||||
};
|
||||
|
||||
// Act
|
||||
await prGenerator.UpdateWithDeltaVerdictAsync(pr.PrId, deltaVerdict);
|
||||
var updated = await prGenerator.GetStatusAsync(pr.PrId);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(updated.DeltaVerdict);
|
||||
Assert.Equal(3, updated.DeltaVerdict.VulnerabilitiesFixed);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ClosePullRequestAsync_ClosesPR()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator();
|
||||
var plan = CreateTestPlan();
|
||||
var pr = await prGenerator.CreatePullRequestAsync(plan);
|
||||
|
||||
// Act
|
||||
await prGenerator.ClosePullRequestAsync(pr.PrId, "Superseded by manual fix");
|
||||
var status = await prGenerator.GetStatusAsync(pr.PrId);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(PullRequestStatus.Closed, status.Status);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ScmType_GitHub_ReturnsCorrectType()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator { ScmTypeOverride = "github" };
|
||||
|
||||
// Assert
|
||||
Assert.Equal("github", prGenerator.ScmType);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ScmType_GitLab_ReturnsCorrectType()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator { ScmTypeOverride = "gitlab" };
|
||||
|
||||
// Assert
|
||||
Assert.Equal("gitlab", prGenerator.ScmType);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ScmType_AzureDevOps_ReturnsCorrectType()
|
||||
{
|
||||
// Arrange
|
||||
var prGenerator = new StubPullRequestGenerator { ScmTypeOverride = "azure-devops" };
|
||||
|
||||
// Assert
|
||||
Assert.Equal("azure-devops", prGenerator.ScmType);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Fallback Handling Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_BuildFails_SetsSuggestionAuthority()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(buildWillFail: true);
|
||||
var request = CreateTestRequest() with { AutoCreatePr = true };
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(RemediationAuthority.Suggestion, plan.Authority);
|
||||
Assert.False(plan.PrReady);
|
||||
Assert.NotNull(plan.NotReadyReason);
|
||||
Assert.Contains("build", plan.NotReadyReason.ToLower());
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_TestsFail_SetsSuggestionAuthority()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(testsWillFail: true);
|
||||
var request = CreateTestRequest() with { AutoCreatePr = true };
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(RemediationAuthority.Suggestion, plan.Authority);
|
||||
Assert.False(plan.PrReady);
|
||||
Assert.NotNull(plan.NotReadyReason);
|
||||
Assert.Contains("test", plan.NotReadyReason.ToLower());
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_NoAutoCreatePr_SetsDraftAuthority()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner();
|
||||
var request = CreateTestRequest() with { AutoCreatePr = false };
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(RemediationAuthority.Draft, plan.Authority);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_AllVerificationsPassed_SetsVerifiedAuthority()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(allVerificationsPassed: true);
|
||||
var request = CreateTestRequest() with { AutoCreatePr = true };
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(RemediationAuthority.Verified, plan.Authority);
|
||||
Assert.True(plan.PrReady);
|
||||
Assert.Null(plan.NotReadyReason);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_BreakingChanges_ReducesConfidence()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(hasBreakingChanges: true);
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.True(plan.ConfidenceScore < 0.8);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Confidence Score Tests
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_PatchVersion_HighConfidence()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(patchVersionBump: true);
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.True(plan.ConfidenceScore >= 0.9);
|
||||
}
|
||||
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GeneratePlanAsync_MajorVersion_LowerConfidence()
|
||||
{
|
||||
// Arrange
|
||||
var planner = new StubRemediationPlanner(majorVersionBump: true);
|
||||
var request = CreateTestRequest();
|
||||
|
||||
// Act
|
||||
var plan = await planner.GeneratePlanAsync(request);
|
||||
|
||||
// Assert
|
||||
Assert.True(plan.ConfidenceScore < 0.7);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private static RemediationPlanRequest CreateTestRequest()
|
||||
{
|
||||
return new RemediationPlanRequest
|
||||
{
|
||||
FindingId = "finding-001",
|
||||
ArtifactDigest = "sha256:abc123",
|
||||
VulnerabilityId = "CVE-2024-1234",
|
||||
ComponentPurl = "pkg:npm/lodash@4.17.20",
|
||||
RemediationType = RemediationType.Auto,
|
||||
RepositoryUrl = "https://github.com/test/repo",
|
||||
TargetBranch = "main",
|
||||
AutoCreatePr = false,
|
||||
CorrelationId = Guid.NewGuid().ToString()
|
||||
};
|
||||
}
|
||||
|
||||
private static RemediationPlan CreateTestPlan()
|
||||
{
|
||||
return new RemediationPlan
|
||||
{
|
||||
PlanId = $"plan-{Guid.NewGuid():N}",
|
||||
Request = CreateTestRequest(),
|
||||
Steps = new[]
|
||||
{
|
||||
new RemediationStep
|
||||
{
|
||||
Order = 1,
|
||||
ActionType = "update_package",
|
||||
FilePath = "package.json",
|
||||
Description = "Update lodash from 4.17.20 to 4.17.21",
|
||||
PreviousValue = "4.17.20",
|
||||
NewValue = "4.17.21",
|
||||
Risk = RemediationRisk.Low
|
||||
}
|
||||
},
|
||||
ExpectedDelta = new ExpectedSbomDelta
|
||||
{
|
||||
Added = Array.Empty<string>(),
|
||||
Removed = Array.Empty<string>(),
|
||||
Upgraded = new Dictionary<string, string>
|
||||
{
|
||||
["pkg:npm/lodash@4.17.20"] = "pkg:npm/lodash@4.17.21"
|
||||
},
|
||||
NetVulnerabilityChange = -1
|
||||
},
|
||||
RiskAssessment = RemediationRisk.Low,
|
||||
TestRequirements = new RemediationTestRequirements
|
||||
{
|
||||
TestSuites = new[] { "unit", "integration" },
|
||||
MinCoverage = 80,
|
||||
RequireAllPass = true,
|
||||
Timeout = TimeSpan.FromMinutes(15)
|
||||
},
|
||||
Authority = RemediationAuthority.Draft,
|
||||
PrReady = false,
|
||||
ConfidenceScore = 0.92,
|
||||
ModelId = "test-model",
|
||||
GeneratedAt = DateTime.UtcNow.ToString("o"),
|
||||
InputHashes = new[] { "hash1", "hash2" },
|
||||
EvidenceRefs = new[] { "evidence/sbom-001", "evidence/vuln-001" }
|
||||
};
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Stub Implementations
|
||||
|
||||
private sealed class StubRemediationPlanner : IRemediationPlanner
|
||||
{
|
||||
private readonly Dictionary<string, RemediationPlan> _plans = new();
|
||||
private readonly bool _patchVersionBump;
|
||||
private readonly bool _majorVersionBump;
|
||||
private readonly bool _buildWillFail;
|
||||
private readonly bool _testsWillFail;
|
||||
private readonly bool _allVerificationsPassed;
|
||||
private readonly bool _hasBreakingChanges;
|
||||
|
||||
public StubRemediationPlanner(
|
||||
bool patchVersionBump = false,
|
||||
bool majorVersionBump = false,
|
||||
bool buildWillFail = false,
|
||||
bool testsWillFail = false,
|
||||
bool allVerificationsPassed = false,
|
||||
bool hasBreakingChanges = false)
|
||||
{
|
||||
_patchVersionBump = patchVersionBump;
|
||||
_majorVersionBump = majorVersionBump;
|
||||
_buildWillFail = buildWillFail;
|
||||
_testsWillFail = testsWillFail;
|
||||
_allVerificationsPassed = allVerificationsPassed;
|
||||
_hasBreakingChanges = hasBreakingChanges;
|
||||
}
|
||||
|
||||
public Task<RemediationPlan> GeneratePlanAsync(
|
||||
RemediationPlanRequest request,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var planId = $"plan-{Guid.NewGuid():N}";
|
||||
|
||||
var (actionType, risk, confidence) = DetermineStepDetails(request);
|
||||
|
||||
var steps = new List<RemediationStep>
|
||||
{
|
||||
new()
|
||||
{
|
||||
Order = 1,
|
||||
ActionType = actionType,
|
||||
FilePath = GetFilePath(request),
|
||||
Description = $"Fix {request.VulnerabilityId}",
|
||||
PreviousValue = "old",
|
||||
NewValue = "new",
|
||||
Risk = risk
|
||||
}
|
||||
};
|
||||
|
||||
var authority = DetermineAuthority(request);
|
||||
var prReady = authority == RemediationAuthority.Verified;
|
||||
var notReadyReason = GetNotReadyReason();
|
||||
|
||||
if (_hasBreakingChanges)
|
||||
{
|
||||
confidence *= 0.6;
|
||||
}
|
||||
|
||||
var plan = new RemediationPlan
|
||||
{
|
||||
PlanId = planId,
|
||||
Request = request,
|
||||
Steps = steps,
|
||||
ExpectedDelta = new ExpectedSbomDelta
|
||||
{
|
||||
Added = Array.Empty<string>(),
|
||||
Removed = Array.Empty<string>(),
|
||||
Upgraded = new Dictionary<string, string>
|
||||
{
|
||||
[request.ComponentPurl] = request.ComponentPurl + "-fixed"
|
||||
},
|
||||
NetVulnerabilityChange = -1
|
||||
},
|
||||
RiskAssessment = risk,
|
||||
TestRequirements = new RemediationTestRequirements
|
||||
{
|
||||
TestSuites = new[] { "unit", "integration" },
|
||||
MinCoverage = 80,
|
||||
RequireAllPass = true,
|
||||
Timeout = TimeSpan.FromMinutes(15)
|
||||
},
|
||||
Authority = authority,
|
||||
PrReady = prReady,
|
||||
NotReadyReason = notReadyReason,
|
||||
ConfidenceScore = confidence,
|
||||
ModelId = "stub-model",
|
||||
GeneratedAt = DateTime.UtcNow.ToString("o"),
|
||||
InputHashes = new[] { $"input:{request.FindingId}", $"input:{request.ArtifactDigest}" },
|
||||
EvidenceRefs = new[] { "evidence/ref-001" }
|
||||
};
|
||||
|
||||
_plans[planId] = plan;
|
||||
return Task.FromResult(plan);
|
||||
}
|
||||
|
||||
private (string ActionType, RemediationRisk Risk, double Confidence) DetermineStepDetails(
|
||||
RemediationPlanRequest request)
|
||||
{
|
||||
var actionType = request.RemediationType switch
|
||||
{
|
||||
RemediationType.Bump => "update_package",
|
||||
RemediationType.Upgrade => "update_base_image",
|
||||
RemediationType.Config => "update_config",
|
||||
RemediationType.Backport => "apply_patch",
|
||||
_ => "update_package"
|
||||
};
|
||||
|
||||
if (_patchVersionBump)
|
||||
return (actionType, RemediationRisk.Low, 0.95);
|
||||
|
||||
if (_majorVersionBump)
|
||||
return (actionType, RemediationRisk.High, 0.65);
|
||||
|
||||
return (actionType, RemediationRisk.Medium, 0.85);
|
||||
}
|
||||
|
||||
private string GetFilePath(RemediationPlanRequest request)
|
||||
{
|
||||
if (request.ComponentPurl.StartsWith("pkg:npm"))
|
||||
return "package.json";
|
||||
if (request.ComponentPurl.StartsWith("pkg:pypi"))
|
||||
return "requirements.txt";
|
||||
if (request.ComponentPurl.StartsWith("pkg:oci"))
|
||||
return "Dockerfile";
|
||||
return "package.json";
|
||||
}
|
||||
|
||||
private RemediationAuthority DetermineAuthority(RemediationPlanRequest request)
|
||||
{
|
||||
if (!request.AutoCreatePr)
|
||||
return RemediationAuthority.Draft;
|
||||
|
||||
if (_buildWillFail || _testsWillFail)
|
||||
return RemediationAuthority.Suggestion;
|
||||
|
||||
if (_allVerificationsPassed)
|
||||
return RemediationAuthority.Verified;
|
||||
|
||||
return RemediationAuthority.Draft;
|
||||
}
|
||||
|
||||
private string? GetNotReadyReason()
|
||||
{
|
||||
if (_buildWillFail)
|
||||
return "Build failed during verification";
|
||||
if (_testsWillFail)
|
||||
return "Tests failed during verification";
|
||||
return null;
|
||||
}
|
||||
|
||||
public Task<bool> ValidatePlanAsync(string planId, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return Task.FromResult(_plans.ContainsKey(planId));
|
||||
}
|
||||
|
||||
public Task<RemediationPlan?> GetPlanAsync(string planId, CancellationToken cancellationToken = default)
|
||||
{
|
||||
_plans.TryGetValue(planId, out var plan);
|
||||
return Task.FromResult(plan);
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class StubPullRequestGenerator : IPullRequestGenerator
|
||||
{
|
||||
private readonly Dictionary<string, PullRequestResult> _prs = new();
|
||||
private int _prCounter;
|
||||
|
||||
public string ScmType => ScmTypeOverride ?? "github";
|
||||
public string? ScmTypeOverride { get; set; }
|
||||
|
||||
public Task<PullRequestResult> CreatePullRequestAsync(
|
||||
RemediationPlan plan,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var prId = $"pr-{Guid.NewGuid():N}";
|
||||
_prCounter++;
|
||||
|
||||
var branchName = $"stellaops-fix-{plan.Request.VulnerabilityId.ToLowerInvariant()}-{_prCounter}";
|
||||
|
||||
var result = new PullRequestResult
|
||||
{
|
||||
PrId = prId,
|
||||
PrNumber = _prCounter,
|
||||
Url = $"https://github.com/test/repo/pull/{_prCounter}",
|
||||
BranchName = branchName,
|
||||
Status = PullRequestStatus.Open,
|
||||
CreatedAt = DateTime.UtcNow.ToString("o"),
|
||||
UpdatedAt = DateTime.UtcNow.ToString("o")
|
||||
};
|
||||
|
||||
_prs[prId] = result;
|
||||
return Task.FromResult(result);
|
||||
}
|
||||
|
||||
public Task<PullRequestResult> GetStatusAsync(string prId, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_prs.TryGetValue(prId, out var result))
|
||||
throw new InvalidOperationException($"PR {prId} not found");
|
||||
|
||||
return Task.FromResult(result);
|
||||
}
|
||||
|
||||
public Task UpdateWithDeltaVerdictAsync(
|
||||
string prId,
|
||||
DeltaVerdictResult deltaVerdict,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_prs.TryGetValue(prId, out var result))
|
||||
throw new InvalidOperationException($"PR {prId} not found");
|
||||
|
||||
_prs[prId] = result with
|
||||
{
|
||||
DeltaVerdict = deltaVerdict,
|
||||
UpdatedAt = DateTime.UtcNow.ToString("o")
|
||||
};
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task ClosePullRequestAsync(string prId, string reason, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!_prs.TryGetValue(prId, out var result))
|
||||
throw new InvalidOperationException($"PR {prId} not found");
|
||||
|
||||
_prs[prId] = result with
|
||||
{
|
||||
Status = PullRequestStatus.Closed,
|
||||
StatusMessage = reason,
|
||||
UpdatedAt = DateTime.UtcNow.ToString("o")
|
||||
};
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -10,11 +10,13 @@ using Microsoft.Extensions.Options;
|
||||
using StellaOps.AdvisoryAI.Providers;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class SbomContextHttpClientTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetContextAsync_MapsPayloadToDocument()
|
||||
{
|
||||
const string payload = """
|
||||
@@ -98,7 +100,8 @@ public sealed class SbomContextHttpClientTests
|
||||
Assert.Contains("includeBlastRadius=true", handler.LastRequest.RequestUri!.Query);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetContextAsync_ReturnsNullOnNotFound()
|
||||
{
|
||||
var handler = new StubHttpMessageHandler(_ => new HttpResponseMessage(HttpStatusCode.NotFound));
|
||||
@@ -110,7 +113,8 @@ public sealed class SbomContextHttpClientTests
|
||||
Assert.Null(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetContextAsync_ThrowsForServerError()
|
||||
{
|
||||
var handler = new StubHttpMessageHandler(_ => new HttpResponseMessage(HttpStatusCode.InternalServerError)
|
||||
|
||||
@@ -2,11 +2,13 @@ using FluentAssertions;
|
||||
using StellaOps.AdvisoryAI.Abstractions;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class SbomContextRequestTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Constructor_NormalizesWhitespaceAndLimits()
|
||||
{
|
||||
var request = new SbomContextRequest(
|
||||
@@ -25,7 +27,8 @@ public sealed class SbomContextRequestTests
|
||||
request.IncludeBlastRadius.Should().BeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Constructor_AllowsNullPurlAndDefaults()
|
||||
{
|
||||
var request = new SbomContextRequest(artifactId: "scan-123", purl: null);
|
||||
|
||||
@@ -12,11 +12,13 @@ using StellaOps.AdvisoryAI.Providers;
|
||||
using StellaOps.AdvisoryAI.Retrievers;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class SbomContextRetrieverTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_ReturnsDeterministicOrderingAndMetadata()
|
||||
{
|
||||
var document = new SbomContextDocument(
|
||||
@@ -103,7 +105,8 @@ public sealed class SbomContextRetrieverTests
|
||||
result.Metadata["blast_radius_present"].Should().Be(bool.TrueString);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_ReturnsEmptyWhenNoDocument()
|
||||
{
|
||||
var client = new FakeSbomContextClient(null);
|
||||
@@ -119,7 +122,8 @@ public sealed class SbomContextRetrieverTests
|
||||
result.BlastRadius.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_HonorsEnvironmentFlagToggle()
|
||||
{
|
||||
var document = new SbomContextDocument(
|
||||
@@ -152,7 +156,8 @@ public sealed class SbomContextRetrieverTests
|
||||
client.LastQuery.IncludeBlastRadius.Should().BeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task RetrieveAsync_DeduplicatesDependencyPaths()
|
||||
{
|
||||
var duplicatePath = ImmutableArray.Create(
|
||||
|
||||
@@ -2,11 +2,13 @@ using FluentAssertions;
|
||||
using StellaOps.AdvisoryAI.Tools;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class SemanticVersionTests
|
||||
{
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("1.2.3", 1, 2, 3, false)]
|
||||
[InlineData("1.2.3-alpha", 1, 2, 3, true)]
|
||||
[InlineData("0.0.1+build", 0, 0, 1, false)]
|
||||
@@ -21,7 +23,8 @@ public sealed class SemanticVersionTests
|
||||
(version.PreRelease.Count > 0).Should().Be(hasPreRelease);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("01.0.0")]
|
||||
[InlineData("1..0")]
|
||||
[InlineData("1.0.0-")]
|
||||
@@ -33,7 +36,8 @@ public sealed class SemanticVersionTests
|
||||
act.Should().Throw<FormatException>();
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("1.2.3", "1.2.3", 0)]
|
||||
[InlineData("1.2.3", "1.2.4", -1)]
|
||||
[InlineData("1.3.0", "1.2.9", 1)]
|
||||
@@ -48,7 +52,8 @@ public sealed class SemanticVersionTests
|
||||
Math.Sign(leftVersion.CompareTo(rightVersion)).Should().Be(expectedSign);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("1.2.3", ">=1.0.0,<2.0.0", true)]
|
||||
[InlineData("0.9.0", ">=1.0.0", false)]
|
||||
[InlineData("1.2.3-beta", ">=1.2.3", false)]
|
||||
@@ -61,7 +66,8 @@ public sealed class SemanticVersionTests
|
||||
SemanticVersionRange.Satisfies(version, range).Should().Be(expected);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void DeterministicToolset_ComparesSemverAndEvr()
|
||||
{
|
||||
IDeterministicToolset toolset = new DeterministicToolset();
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
<ProjectReference Include="..\..\..\Concelier\__Libraries\StellaOps.Concelier.Core\StellaOps.Concelier.Core.csproj" />
|
||||
<ProjectReference Include="..\..\..\Concelier\__Libraries\StellaOps.Concelier.RawModels\StellaOps.Concelier.RawModels.csproj" />
|
||||
<ProjectReference Include="..\..\..\Excititor\__Libraries\StellaOps.Excititor.Core\StellaOps.Excititor.Core.csproj" />
|
||||
<ProjectReference Include="..\..\..\__Libraries\StellaOps.TestKit\StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Update="TestData/*.json">
|
||||
|
||||
@@ -11,11 +11,13 @@ using StellaOps.AdvisoryAI.Abstractions;
|
||||
using StellaOps.AdvisoryAI.Documents;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AdvisoryAI.Tests;
|
||||
|
||||
public sealed class ToolsetServiceCollectionExtensionsTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void AddAdvisoryDeterministicToolset_RegistersSingleton()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
@@ -29,7 +31,8 @@ public sealed class ToolsetServiceCollectionExtensionsTests
|
||||
Assert.Same(toolsetA, toolsetB);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void AddAdvisoryPipeline_RegistersOrchestrator()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
|
||||
@@ -14,11 +14,14 @@ using Microsoft.CodeAnalysis.Diagnostics;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Policy.Analyzers.Tests;
|
||||
|
||||
public sealed class HttpClientUsageAnalyzerTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReportsDiagnostic_ForNewHttpClient()
|
||||
{
|
||||
const string source = """
|
||||
@@ -39,7 +42,8 @@ public sealed class HttpClientUsageAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == HttpClientUsageAnalyzer.DiagnosticId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task DoesNotReportDiagnostic_InsidePolicyAssembly()
|
||||
{
|
||||
const string source = """
|
||||
@@ -57,7 +61,8 @@ public sealed class HttpClientUsageAnalyzerTests
|
||||
Assert.DoesNotContain(diagnostics, d => d.Id == HttpClientUsageAnalyzer.DiagnosticId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFix_RewritesToFactoryCall()
|
||||
{
|
||||
const string source = """
|
||||
|
||||
@@ -22,6 +22,8 @@ using Microsoft.CodeAnalysis.Text;
|
||||
using Xunit;
|
||||
using FluentAssertions;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Policy.Analyzers.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -33,7 +35,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
{
|
||||
#region AIRGAP-5100-005: Expected Diagnostics & No False Positives
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("var client = new HttpClient();", true, "Direct construction should trigger diagnostic")]
|
||||
[InlineData("var client = new System.Net.Http.HttpClient();", true, "Fully qualified construction should trigger diagnostic")]
|
||||
[InlineData("HttpClient client = new();", true, "Target-typed new should trigger diagnostic")]
|
||||
@@ -60,7 +63,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
hasDiagnostic.Should().Be(shouldTrigger, reason);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task NoDiagnostic_ForHttpClientParameter()
|
||||
{
|
||||
const string source = """
|
||||
@@ -83,7 +87,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Using HttpClient as parameter should not trigger diagnostic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task NoDiagnostic_ForHttpClientField()
|
||||
{
|
||||
const string source = """
|
||||
@@ -107,7 +112,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Declaring HttpClient field should not trigger diagnostic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task NoDiagnostic_ForFactoryMethodReturn()
|
||||
{
|
||||
const string source = """
|
||||
@@ -138,7 +144,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Using factory method should not trigger diagnostic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task NoDiagnostic_InTestAssembly()
|
||||
{
|
||||
const string source = """
|
||||
@@ -160,7 +167,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Test assemblies should be exempt from diagnostic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task NoDiagnostic_InPolicyAssembly()
|
||||
{
|
||||
const string source = """
|
||||
@@ -179,7 +187,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Policy assembly itself should be exempt");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Diagnostic_HasCorrectSeverity()
|
||||
{
|
||||
const string source = """
|
||||
@@ -203,7 +212,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Diagnostic should be a warning, not an error");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Diagnostic_HasCorrectLocation()
|
||||
{
|
||||
const string source = """
|
||||
@@ -228,7 +238,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
lineSpan.StartLinePosition.Line.Should().Be(8, "Diagnostic should point to line 9 (0-indexed: 8)");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task MultipleHttpClientUsages_ReportMultipleDiagnostics()
|
||||
{
|
||||
const string source = """
|
||||
@@ -265,7 +276,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
|
||||
#region AIRGAP-5100-006: Golden Generated Code Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFix_GeneratesExpectedFactoryCall()
|
||||
{
|
||||
const string source = """
|
||||
@@ -301,7 +313,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Code fix should match golden output exactly");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFix_PreservesTrivia()
|
||||
{
|
||||
const string source = """
|
||||
@@ -326,7 +339,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Leading comment should be preserved");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFix_DeterministicOutput()
|
||||
{
|
||||
const string source = """
|
||||
@@ -352,7 +366,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
result2.Should().Be(result3, "Code fix should be deterministic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFix_ContainsRequiredPlaceholders()
|
||||
{
|
||||
const string source = """
|
||||
@@ -383,7 +398,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
fixedCode.Should().Contain("REPLACE_INTENT");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFix_UsesFullyQualifiedNames()
|
||||
{
|
||||
const string source = """
|
||||
@@ -408,7 +424,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
fixedCode.Should().Contain("global::System.Uri");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task FixAllProvider_IsWellKnownBatchFixer()
|
||||
{
|
||||
var provider = new HttpClientUsageCodeFixProvider();
|
||||
@@ -418,7 +435,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
"Should use batch fixer for efficient multi-fix application");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Analyzer_SupportedDiagnostics_ContainsExpectedId()
|
||||
{
|
||||
var analyzer = new HttpClientUsageAnalyzer();
|
||||
@@ -428,7 +446,8 @@ public sealed class PolicyAnalyzerRoslynTests
|
||||
supportedDiagnostics[0].Id.Should().Be("AIRGAP001");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task CodeFixProvider_FixableDiagnosticIds_MatchesAnalyzer()
|
||||
{
|
||||
var analyzer = new HttpClientUsageAnalyzer();
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.AirGap.Policy.Analyzers\StellaOps.AirGap.Policy.Analyzers.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -8,11 +8,14 @@ using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.AirGap.Policy;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Policy.Tests;
|
||||
|
||||
public sealed class EgressPolicyTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Evaluate_UnsealedEnvironment_AllowsRequest()
|
||||
{
|
||||
var options = new EgressPolicyOptions
|
||||
@@ -29,7 +32,8 @@ public sealed class EgressPolicyTests
|
||||
Assert.Null(decision.Reason);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnsureAllowed_SealedEnvironmentWithMatchingRule_Allows()
|
||||
{
|
||||
var options = new EgressPolicyOptions
|
||||
@@ -44,7 +48,8 @@ public sealed class EgressPolicyTests
|
||||
policy.EnsureAllowed(request);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnsureAllowed_SealedEnvironmentWithoutRule_ThrowsWithGuidance()
|
||||
{
|
||||
var options = new EgressPolicyOptions
|
||||
@@ -67,7 +72,8 @@ public sealed class EgressPolicyTests
|
||||
Assert.Equal(options.SupportContact, exception.SupportContact);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnsureAllowed_SealedEnvironment_AllowsLoopbackWhenConfigured()
|
||||
{
|
||||
var options = new EgressPolicyOptions
|
||||
@@ -82,7 +88,8 @@ public sealed class EgressPolicyTests
|
||||
policy.EnsureAllowed(request);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnsureAllowed_SealedEnvironment_AllowsPrivateNetworkWhenConfigured()
|
||||
{
|
||||
var options = new EgressPolicyOptions
|
||||
@@ -97,7 +104,8 @@ public sealed class EgressPolicyTests
|
||||
policy.EnsureAllowed(request);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnsureAllowed_SealedEnvironment_BlocksPrivateNetworkWhenNotConfigured()
|
||||
{
|
||||
var options = new EgressPolicyOptions
|
||||
@@ -113,7 +121,8 @@ public sealed class EgressPolicyTests
|
||||
Assert.Contains("10.10.0.5", exception.Message, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("https://api.example.com", true)]
|
||||
[InlineData("https://sub.api.example.com", true)]
|
||||
[InlineData("https://example.com", false)]
|
||||
@@ -132,7 +141,8 @@ public sealed class EgressPolicyTests
|
||||
Assert.Equal(expectedAllowed, decision.IsAllowed);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ServiceCollection_AddAirGapEgressPolicy_RegistersService()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
@@ -149,7 +159,8 @@ public sealed class EgressPolicyTests
|
||||
policy.EnsureAllowed(new EgressRequest("PolicyEngine", new Uri("https://mirror.internal"), "mirror-sync"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ServiceCollection_AddAirGapEgressPolicy_BindsFromConfiguration()
|
||||
{
|
||||
var configuration = new ConfigurationBuilder()
|
||||
@@ -182,7 +193,8 @@ public sealed class EgressPolicyTests
|
||||
Assert.Contains("mirror.internal", blocked.Remediation, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EgressHttpClientFactory_Create_EnforcesPolicyBeforeReturningClient()
|
||||
{
|
||||
var recordingPolicy = new RecordingPolicy();
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.AirGap.Policy\StellaOps.AirGap.Policy.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -14,6 +14,7 @@ using StellaOps.AirGap.Time.Models;
|
||||
using StellaOps.Infrastructure.Postgres.Options;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Storage.Postgres.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -55,7 +56,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
|
||||
#region AIRGAP-5100-007: Migration Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Migration_SchemaContainsRequiredTables()
|
||||
{
|
||||
// Arrange
|
||||
@@ -77,7 +79,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Migration_AirGapStateHasRequiredColumns()
|
||||
{
|
||||
// Arrange
|
||||
@@ -94,7 +97,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Migration_IsIdempotent()
|
||||
{
|
||||
// Act - Running migrations again should not fail
|
||||
@@ -107,7 +111,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
await act.Should().NotThrowAsync("Running migrations multiple times should be idempotent");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Migration_HasTenantIndex()
|
||||
{
|
||||
// Act
|
||||
@@ -122,7 +127,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
|
||||
#region AIRGAP-5100-008: Idempotency Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Idempotency_SetStateTwice_NoException()
|
||||
{
|
||||
// Arrange
|
||||
@@ -137,7 +143,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
await act.Should().NotThrowAsync("Setting state twice should be idempotent");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Idempotency_SetStateTwice_SingleRecord()
|
||||
{
|
||||
// Arrange
|
||||
@@ -154,7 +161,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
fetched.PolicyHash.Should().Be("sha256:policy-v2", "Second set should update, not duplicate");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Idempotency_ConcurrentSets_NoDataCorruption()
|
||||
{
|
||||
// Arrange
|
||||
@@ -181,7 +189,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
fetched.PolicyHash.Should().StartWith("sha256:policy-");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Idempotency_SameBundleIdTwice_NoException()
|
||||
{
|
||||
// Arrange
|
||||
@@ -203,7 +212,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
|
||||
#region AIRGAP-5100-009: Query Determinism Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task QueryDeterminism_SameInput_SameOutput()
|
||||
{
|
||||
// Arrange
|
||||
@@ -221,7 +231,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
result2.Should().BeEquivalentTo(result3);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task QueryDeterminism_ContentBudgets_ReturnInConsistentOrder()
|
||||
{
|
||||
// Arrange
|
||||
@@ -252,7 +263,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task QueryDeterminism_TimeAnchor_PreservesAllFields()
|
||||
{
|
||||
// Arrange
|
||||
@@ -277,7 +289,8 @@ public sealed class AirGapStorageIntegrationTests : IAsyncLifetime
|
||||
fetched1.TimeAnchor.Source.Should().Be("tsa.example.com");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task QueryDeterminism_MultipleTenants_IsolatedResults()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -8,6 +8,7 @@ using StellaOps.AirGap.Time.Models;
|
||||
using StellaOps.Infrastructure.Postgres.Options;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Storage.Postgres.Tests;
|
||||
|
||||
[Collection(AirGapPostgresCollection.Name)]
|
||||
@@ -42,7 +43,8 @@ public sealed class PostgresAirGapStateStoreTests : IAsyncLifetime
|
||||
await _dataSource.DisposeAsync();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task GetAsync_ReturnsDefaultStateForNewTenant()
|
||||
{
|
||||
// Act
|
||||
@@ -55,7 +57,8 @@ public sealed class PostgresAirGapStateStoreTests : IAsyncLifetime
|
||||
state.PolicyHash.Should().BeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAndGet_RoundTripsState()
|
||||
{
|
||||
// Arrange
|
||||
@@ -100,7 +103,8 @@ public sealed class PostgresAirGapStateStoreTests : IAsyncLifetime
|
||||
fetched.ContentBudgets["advisories"].WarningSeconds.Should().Be(7200);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAsync_UpdatesExistingState()
|
||||
{
|
||||
// Arrange
|
||||
@@ -136,7 +140,8 @@ public sealed class PostgresAirGapStateStoreTests : IAsyncLifetime
|
||||
fetched.StalenessBudget.WarningSeconds.Should().Be(600);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SetAsync_PersistsContentBudgets()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
<ProjectReference Include="..\StellaOps.AirGap.Storage.Postgres\StellaOps.AirGap.Storage.Postgres.csproj" />
|
||||
<ProjectReference Include="..\StellaOps.AirGap.Controller\StellaOps.AirGap.Controller.csproj" />
|
||||
<ProjectReference Include="..\..\__Tests\__Libraries\StellaOps.Infrastructure.Postgres.Testing\StellaOps.Infrastructure.Postgres.Testing.csproj" />
|
||||
<ProjectReference Include="../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -10,6 +10,7 @@ using System.Text;
|
||||
using FluentAssertions;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Bundle.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -22,7 +23,8 @@ public sealed class AirGapCliToolTests
|
||||
{
|
||||
#region AIRGAP-5100-013: Exit Code Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_SuccessfulExport_ReturnsZero()
|
||||
{
|
||||
// Arrange
|
||||
@@ -32,7 +34,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedExitCode.Should().Be(0, "Successful operations should return exit code 0");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_UserError_ReturnsOne()
|
||||
{
|
||||
// Arrange
|
||||
@@ -43,7 +46,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedExitCode.Should().Be(1, "User errors should return exit code 1");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_SystemError_ReturnsTwo()
|
||||
{
|
||||
// Arrange
|
||||
@@ -54,7 +58,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedExitCode.Should().Be(2, "System errors should return exit code 2");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_MissingRequiredArgument_ReturnsOne()
|
||||
{
|
||||
// Arrange - Missing required argument scenario
|
||||
@@ -66,7 +71,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedExitCode.Should().Be(1);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_InvalidFeedPath_ReturnsOne()
|
||||
{
|
||||
// Arrange - Invalid feed path scenario
|
||||
@@ -84,7 +90,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedExitCode.Should().Be(1, "Invalid feed path should return exit code 1");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_HelpFlag_ReturnsZero()
|
||||
{
|
||||
// Arrange
|
||||
@@ -96,7 +103,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedExitCode.Should().Be(0, "--help should return exit code 0");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void ExitCode_VersionFlag_ReturnsZero()
|
||||
{
|
||||
// Arrange
|
||||
@@ -112,7 +120,8 @@ public sealed class AirGapCliToolTests
|
||||
|
||||
#region AIRGAP-5100-014: Golden Output Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void GoldenOutput_ExportCommand_IncludesManifestSummary()
|
||||
{
|
||||
// Arrange - Expected output structure for export command
|
||||
@@ -135,7 +144,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedOutputLines.Should().Contain(l => l.Contains("Digest:"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void GoldenOutput_ExportCommand_IncludesBundleDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -145,7 +155,8 @@ public sealed class AirGapCliToolTests
|
||||
digestPattern.Should().Contain("sha256:");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void GoldenOutput_ImportCommand_IncludesImportSummary()
|
||||
{
|
||||
// Arrange - Expected output structure for import command
|
||||
@@ -165,7 +176,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedOutputLines.Should().Contain(l => l.Contains("imported successfully"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void GoldenOutput_ListCommand_IncludesBundleTable()
|
||||
{
|
||||
// Arrange - Expected output structure for list command
|
||||
@@ -177,7 +189,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedHeaders.Should().Contain("Version");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void GoldenOutput_ValidateCommand_IncludesValidationResult()
|
||||
{
|
||||
// Arrange - Expected output structure for validate command
|
||||
@@ -195,7 +208,8 @@ public sealed class AirGapCliToolTests
|
||||
expectedOutputLines.Should().Contain(l => l.Contains("Validation:"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void GoldenOutput_ErrorMessage_IncludesContext()
|
||||
{
|
||||
// Arrange - Error message format
|
||||
@@ -210,7 +224,8 @@ public sealed class AirGapCliToolTests
|
||||
|
||||
#region AIRGAP-5100-015: CLI Determinism Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void CliDeterminism_SameInputs_SameOutputDigest()
|
||||
{
|
||||
// Arrange - Simulate CLI determinism
|
||||
@@ -225,7 +240,8 @@ public sealed class AirGapCliToolTests
|
||||
digest1.Should().Be(digest2, "Same inputs should produce same digest");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void CliDeterminism_OutputBundleName_IsDeterministic()
|
||||
{
|
||||
// Arrange
|
||||
@@ -241,7 +257,8 @@ public sealed class AirGapCliToolTests
|
||||
filename1.Should().Be(filename2, "Same parameters should produce same filename");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void CliDeterminism_ManifestJson_IsDeterministic()
|
||||
{
|
||||
// Arrange
|
||||
@@ -256,7 +273,8 @@ public sealed class AirGapCliToolTests
|
||||
json1.Should().Be(json2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void CliDeterminism_FeedOrdering_IsDeterministic()
|
||||
{
|
||||
// Arrange - Feeds in different order
|
||||
@@ -272,7 +290,8 @@ public sealed class AirGapCliToolTests
|
||||
"Canonical ordering should be deterministic");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void CliDeterminism_DigestComputation_IsDeterministic()
|
||||
{
|
||||
// Arrange
|
||||
@@ -290,7 +309,8 @@ public sealed class AirGapCliToolTests
|
||||
digest3.Should().Be(expectedDigest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void CliDeterminism_TimestampFormat_IsDeterministic()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -14,6 +14,7 @@ using StellaOps.AirGap.Bundle.Serialization;
|
||||
using StellaOps.AirGap.Bundle.Services;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Bundle.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -48,7 +49,8 @@ public sealed class AirGapIntegrationTests : IDisposable
|
||||
|
||||
#region AIRGAP-5100-016: Online → Offline Bundle Transfer Integration
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Integration_OnlineExport_OfflineImport_DataIntegrity()
|
||||
{
|
||||
// Arrange - Create source data in "online" environment
|
||||
@@ -102,7 +104,8 @@ public sealed class AirGapIntegrationTests : IDisposable
|
||||
importedFeedContent.Should().Contain("CVE-2024-0001");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Integration_BundleTransfer_PreservesAllComponents()
|
||||
{
|
||||
// Arrange - Create multi-component bundle
|
||||
@@ -143,7 +146,8 @@ public sealed class AirGapIntegrationTests : IDisposable
|
||||
File.Exists(Path.Combine(offlinePath, "certs/root.pem")).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Integration_CorruptedBundle_ImportFails()
|
||||
{
|
||||
// Arrange
|
||||
@@ -185,7 +189,8 @@ public sealed class AirGapIntegrationTests : IDisposable
|
||||
|
||||
#region AIRGAP-5100-017: Policy Export/Import/Evaluation Integration
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Integration_PolicyExport_PolicyImport_IdenticalVerdict()
|
||||
{
|
||||
// Arrange - Create a policy in online environment
|
||||
@@ -242,7 +247,8 @@ public sealed class AirGapIntegrationTests : IDisposable
|
||||
importedDigest.Should().Be(originalDigest, "Policy digest should match");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Integration_MultiplePolices_MaintainOrder()
|
||||
{
|
||||
// Arrange - Create multiple policies
|
||||
@@ -289,7 +295,8 @@ public sealed class AirGapIntegrationTests : IDisposable
|
||||
File.Exists(Path.Combine(offlinePath, "policies/policy3.rego")).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Integration_PolicyWithCrypto_BothTransferred()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -35,7 +35,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
|
||||
#region Same Inputs → Same Hash Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_SameInputs_SameComponentDigests()
|
||||
{
|
||||
// Arrange
|
||||
@@ -55,7 +56,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
manifest1.Feeds[0].Digest.Should().Be(manifest2.Feeds[0].Digest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_SameManifestContent_SameBundleDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -70,7 +72,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
digest1.Should().Be(digest2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_MultipleBuilds_SameDigests()
|
||||
{
|
||||
// Arrange
|
||||
@@ -93,7 +96,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
digests.Distinct().Should().HaveCount(1, "All builds should produce the same digest");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Determinism_Sha256_StableAcrossCalls()
|
||||
{
|
||||
// Arrange
|
||||
@@ -115,7 +119,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
|
||||
#region Roundtrip Determinism Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Roundtrip_ExportImportReexport_IdenticalBundle()
|
||||
{
|
||||
// Arrange
|
||||
@@ -147,6 +152,7 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
|
||||
// Re-export using the imported file
|
||||
var reimportFeedFile = CreateSourceFile("reimport/feed.json", importedContent);
|
||||
using StellaOps.TestKit;
|
||||
var request2 = new BundleBuildRequest(
|
||||
"roundtrip-test",
|
||||
"1.0.0",
|
||||
@@ -165,7 +171,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
manifest1.Feeds[0].Digest.Should().Be(manifest2.Feeds[0].Digest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Roundtrip_ManifestSerialize_Deserialize_Identical()
|
||||
{
|
||||
// Arrange
|
||||
@@ -179,7 +186,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
restored.Should().BeEquivalentTo(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Roundtrip_ManifestSerialize_Reserialize_SameJson()
|
||||
{
|
||||
// Arrange
|
||||
@@ -198,7 +206,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
|
||||
#region Content Independence Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_SameContent_DifferentSourcePath_SameDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -219,7 +228,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
manifest1.Feeds[0].Digest.Should().Be(manifest2.Feeds[0].Digest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_DifferentContent_DifferentDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -243,7 +253,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
|
||||
#region Multiple Component Determinism
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_MultipleFeeds_EachHasCorrectDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -278,7 +289,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
manifest.Feeds[2].Digest.Should().Be(ComputeSha256(content3));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_OrderIndependence_SameManifestDigest()
|
||||
{
|
||||
// Note: This test verifies that the bundle digest is computed deterministically
|
||||
@@ -300,7 +312,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
|
||||
#region Binary Content Determinism
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_BinaryContent_SameDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -340,7 +353,8 @@ public sealed class BundleDeterminismTests : IAsyncLifetime
|
||||
manifest1.Feeds[0].Digest.Should().Be(manifest2.Feeds[0].Digest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_LargeContent_SameDigest()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -44,7 +44,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
|
||||
#region AIRGAP-5100-001: Bundle Export Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_CreatesValidBundleStructure()
|
||||
{
|
||||
// Arrange
|
||||
@@ -63,7 +64,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
manifest.Feeds.Should().HaveCount(1);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_SetsCorrectManifestFields()
|
||||
{
|
||||
// Arrange
|
||||
@@ -83,7 +85,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
manifest.CreatedAt.Should().BeCloseTo(DateTimeOffset.UtcNow, TimeSpan.FromSeconds(5));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_ComputesCorrectFileDigests()
|
||||
{
|
||||
// Arrange
|
||||
@@ -107,7 +110,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
feedDigest.Should().Be(expectedDigest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_ComputesCorrectBundleDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -124,7 +128,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
manifest.BundleDigest.Should().HaveLength(64);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_TracksCorrectFileSizes()
|
||||
{
|
||||
// Arrange
|
||||
@@ -146,7 +151,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
|
||||
#region AIRGAP-5100-002: Bundle Import Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_LoadsManifestCorrectly()
|
||||
{
|
||||
// Arrange - First export a bundle
|
||||
@@ -170,7 +176,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
loaded.Version.Should().Be("1.0.0");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_VerifiesFileIntegrity()
|
||||
{
|
||||
// Arrange
|
||||
@@ -198,7 +205,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
loaded.Feeds[0].Digest.Should().Be(actualDigest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_FailsOnCorruptedFile()
|
||||
{
|
||||
// Arrange
|
||||
@@ -230,7 +238,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
|
||||
#region AIRGAP-5100-003: Determinism Tests (Same Inputs → Same Hash)
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_SameInputs_ProduceSameBundleDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -271,7 +280,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
"Same content should produce same file digest");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Determinism_DifferentInputs_ProduceDifferentDigests()
|
||||
{
|
||||
// Arrange
|
||||
@@ -294,7 +304,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
"Different content should produce different digests");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Determinism_ManifestSerialization_IsStable()
|
||||
{
|
||||
// Arrange
|
||||
@@ -314,7 +325,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
|
||||
#region AIRGAP-5100-004: Roundtrip Determinism (Export → Import → Re-export)
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Roundtrip_ExportImportReexport_ProducesIdenticalFileDigests()
|
||||
{
|
||||
// Arrange - Initial export
|
||||
@@ -337,6 +349,7 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
|
||||
// Re-export using the imported bundle's files
|
||||
var reexportFeedFile = Path.Combine(bundlePath1, "feeds", "nvd.json");
|
||||
using StellaOps.TestKit;
|
||||
var reexportRequest = new BundleBuildRequest(
|
||||
imported.Name,
|
||||
imported.Version,
|
||||
@@ -360,7 +373,8 @@ public sealed class BundleExportImportTests : IDisposable
|
||||
digest1.Should().Be(digest2, "Roundtrip should produce identical file digests");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Roundtrip_ManifestSerialization_PreservesAllFields()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -7,6 +7,7 @@ using StellaOps.AirGap.Bundle.Serialization;
|
||||
using StellaOps.AirGap.Bundle.Services;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Bundle.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -35,7 +36,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region L0 Export Structure Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_EmptyBundle_CreatesValidManifest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -65,7 +67,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
manifest.TotalSizeBytes.Should().Be(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_WithFeed_CopiesFileAndComputesDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -111,7 +114,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
File.Exists(Path.Combine(outputPath, "feeds/nvd.json")).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_WithPolicy_CopiesFileAndComputesDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -153,7 +157,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
File.Exists(Path.Combine(outputPath, "policies/default.rego")).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_WithCryptoMaterial_CopiesFileAndComputesDigest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -195,7 +200,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
File.Exists(Path.Combine(outputPath, "certs/root.pem")).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_MultipleComponents_CalculatesTotalSize()
|
||||
{
|
||||
// Arrange
|
||||
@@ -234,7 +240,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region Digest Computation Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_DigestComputation_MatchesSha256()
|
||||
{
|
||||
// Arrange
|
||||
@@ -263,7 +270,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
manifest.Feeds[0].Digest.Should().BeEquivalentTo(expectedDigest, options => options.IgnoringCase());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_BundleDigest_ComputedFromManifest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -294,7 +302,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region Directory Structure Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_CreatesNestedDirectories()
|
||||
{
|
||||
// Arrange
|
||||
@@ -338,7 +347,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region Feed Format Tests
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData(FeedFormat.StellaOpsNative)]
|
||||
[InlineData(FeedFormat.TrivyDb)]
|
||||
[InlineData(FeedFormat.GrypeDb)]
|
||||
@@ -372,7 +382,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region Policy Type Tests
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData(PolicyType.OpaRego)]
|
||||
[InlineData(PolicyType.LatticeRules)]
|
||||
[InlineData(PolicyType.UnknownBudgets)]
|
||||
@@ -406,7 +417,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region Crypto Component Type Tests
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData(CryptoComponentType.TrustRoot)]
|
||||
[InlineData(CryptoComponentType.IntermediateCa)]
|
||||
[InlineData(CryptoComponentType.TimestampRoot)]
|
||||
@@ -441,7 +453,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
|
||||
#region Expiration Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_WithExpiration_PreservesExpiryDate()
|
||||
{
|
||||
// Arrange
|
||||
@@ -464,7 +477,8 @@ public sealed class BundleExportTests : IAsyncLifetime
|
||||
manifest.ExpiresAt.Should().BeCloseTo(expiresAt, TimeSpan.FromSeconds(1));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Export_CryptoWithExpiration_PreservesComponentExpiry()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -9,6 +9,8 @@ using StellaOps.AirGap.Bundle.Services;
|
||||
using StellaOps.AirGap.Bundle.Validation;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Bundle.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -37,7 +39,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
|
||||
#region Manifest Parsing Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Import_ManifestDeserialization_PreservesAllFields()
|
||||
{
|
||||
// Arrange
|
||||
@@ -51,7 +54,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
imported.Should().BeEquivalentTo(manifest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Import_ManifestDeserialization_HandlesEmptyCollections()
|
||||
{
|
||||
// Arrange
|
||||
@@ -67,7 +71,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
imported.CryptoMaterials.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Import_ManifestDeserialization_PreservesFeedComponents()
|
||||
{
|
||||
// Arrange
|
||||
@@ -85,7 +90,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
imported.Feeds[1].Format.Should().Be(FeedFormat.OsvJson);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Import_ManifestDeserialization_PreservesPolicyComponents()
|
||||
{
|
||||
// Arrange
|
||||
@@ -101,7 +107,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
imported.Policies[1].Type.Should().Be(PolicyType.LatticeRules);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Import_ManifestDeserialization_PreservesCryptoComponents()
|
||||
{
|
||||
// Arrange
|
||||
@@ -121,7 +128,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
|
||||
#region Validation Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Validation_FailsWhenFilesMissing()
|
||||
{
|
||||
// Arrange
|
||||
@@ -141,7 +149,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
result.Errors.Should().Contain(e => e.Message.Contains("digest mismatch") || e.Message.Contains("FILE_NOT_FOUND"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Validation_FailsWhenDigestMismatch()
|
||||
{
|
||||
// Arrange
|
||||
@@ -158,7 +167,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
result.Errors.Should().Contain(e => e.Message.Contains("digest mismatch"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Validation_SucceedsWhenAllDigestsMatch()
|
||||
{
|
||||
// Arrange
|
||||
@@ -175,7 +185,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
result.Errors.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Validation_WarnsWhenExpired()
|
||||
{
|
||||
// Arrange
|
||||
@@ -195,7 +206,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
result.Warnings.Should().Contain(w => w.Message.Contains("expired"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Validation_WarnsWhenFeedsOld()
|
||||
{
|
||||
// Arrange
|
||||
@@ -224,7 +236,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
|
||||
#region Bundle Loader Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Loader_RegistersAllFeeds()
|
||||
{
|
||||
// Arrange
|
||||
@@ -252,7 +265,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
feedRegistry.Received(manifest.Feeds.Length).Register(Arg.Any<FeedComponent>(), Arg.Any<string>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Loader_RegistersAllPolicies()
|
||||
{
|
||||
// Arrange
|
||||
@@ -279,7 +293,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
policyRegistry.Received(manifest.Policies.Length).Register(Arg.Any<PolicyComponent>(), Arg.Any<string>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Loader_ThrowsOnValidationFailure()
|
||||
{
|
||||
// Arrange
|
||||
@@ -306,7 +321,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
.WithMessage("*validation failed*");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_Loader_ThrowsOnMissingManifest()
|
||||
{
|
||||
// Arrange
|
||||
@@ -330,7 +346,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
|
||||
#region Digest Verification Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_DigestVerification_MatchesExpected()
|
||||
{
|
||||
// Arrange
|
||||
@@ -346,7 +363,8 @@ public sealed class BundleImportTests : IAsyncLifetime
|
||||
actualDigest.Should().BeEquivalentTo(expectedDigest, options => options.IgnoringCase());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Import_DigestVerification_FailsOnTamperedFile()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -6,11 +6,13 @@ using StellaOps.AirGap.Bundle.Services;
|
||||
using StellaOps.AirGap.Bundle.Validation;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Bundle.Tests;
|
||||
|
||||
public class BundleManifestTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Serializer_RoundTrip_PreservesFields()
|
||||
{
|
||||
var manifest = CreateManifest();
|
||||
@@ -19,7 +21,8 @@ public class BundleManifestTests
|
||||
deserialized.Should().BeEquivalentTo(manifest);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Validator_FlagsMissingFeedFile()
|
||||
{
|
||||
var manifest = CreateManifest();
|
||||
@@ -30,7 +33,8 @@ public class BundleManifestTests
|
||||
result.Errors.Should().NotBeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Builder_CopiesComponentsAndComputesDigest()
|
||||
{
|
||||
var tempRoot = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());
|
||||
|
||||
@@ -16,5 +16,6 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.AirGap.Bundle\StellaOps.AirGap.Bundle.csproj" />
|
||||
<ProjectReference Include="../../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
||||
@@ -14,6 +14,8 @@ using System.Text.Json;
|
||||
using FluentAssertions;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.AirGap.Controller.Tests;
|
||||
|
||||
/// <summary>
|
||||
@@ -26,7 +28,8 @@ public sealed class AirGapControllerContractTests
|
||||
{
|
||||
#region AIRGAP-5100-010: Contract Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Contract_ExportEndpoint_ExpectedRequestStructure()
|
||||
{
|
||||
// Arrange - Define expected request structure
|
||||
@@ -56,7 +59,8 @@ public sealed class AirGapControllerContractTests
|
||||
feeds.GetArrayLength().Should().BeGreaterThan(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Contract_ExportEndpoint_ExpectedResponseStructure()
|
||||
{
|
||||
// Arrange - Define expected response structure
|
||||
@@ -87,7 +91,8 @@ public sealed class AirGapControllerContractTests
|
||||
parsed.RootElement.TryGetProperty("manifest", out _).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Contract_ImportEndpoint_ExpectedRequestStructure()
|
||||
{
|
||||
// Arrange - Import request (typically multipart form or bundle URL)
|
||||
@@ -107,7 +112,8 @@ public sealed class AirGapControllerContractTests
|
||||
parsed.RootElement.TryGetProperty("bundleDigest", out _).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Contract_ImportEndpoint_ExpectedResponseStructure()
|
||||
{
|
||||
// Arrange
|
||||
@@ -131,7 +137,8 @@ public sealed class AirGapControllerContractTests
|
||||
parsed.RootElement.TryGetProperty("feedsImported", out _).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Contract_ListBundlesEndpoint_ExpectedResponseStructure()
|
||||
{
|
||||
// Arrange
|
||||
@@ -164,7 +171,8 @@ public sealed class AirGapControllerContractTests
|
||||
parsed.RootElement.TryGetProperty("total", out _).Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Contract_StateEndpoint_ExpectedResponseStructure()
|
||||
{
|
||||
// Arrange - AirGap state response
|
||||
@@ -197,7 +205,8 @@ public sealed class AirGapControllerContractTests
|
||||
|
||||
#region AIRGAP-5100-011: Auth Tests
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Auth_RequiredScopes_ForExport()
|
||||
{
|
||||
// Arrange - Expected scopes for export operation
|
||||
@@ -207,7 +216,8 @@ public sealed class AirGapControllerContractTests
|
||||
requiredScopes.Should().Contain("airgap:export");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Auth_RequiredScopes_ForImport()
|
||||
{
|
||||
// Arrange - Expected scopes for import operation
|
||||
@@ -217,7 +227,8 @@ public sealed class AirGapControllerContractTests
|
||||
requiredScopes.Should().Contain("airgap:import");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Auth_RequiredScopes_ForList()
|
||||
{
|
||||
// Arrange - Expected scopes for list operation
|
||||
@@ -227,7 +238,8 @@ public sealed class AirGapControllerContractTests
|
||||
requiredScopes.Should().Contain("airgap:read");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Auth_DenyByDefault_NoTokenReturnsUnauthorized()
|
||||
{
|
||||
// Arrange - Request without token
|
||||
@@ -237,7 +249,8 @@ public sealed class AirGapControllerContractTests
|
||||
expectedStatusCode.Should().Be(HttpStatusCode.Unauthorized);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Auth_TenantIsolation_CannotAccessOtherTenantBundles()
|
||||
{
|
||||
// Arrange - Claims for tenant A
|
||||
@@ -256,7 +269,8 @@ public sealed class AirGapControllerContractTests
|
||||
// Requests for tenant-B bundles should be rejected
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Auth_TokenExpiry_ExpiredTokenReturnsForbidden()
|
||||
{
|
||||
// Arrange - Expired token scenario
|
||||
@@ -272,7 +286,8 @@ public sealed class AirGapControllerContractTests
|
||||
|
||||
#region AIRGAP-5100-012: OTel Trace Assertions
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OTel_ExportOperation_IncludesBundleIdTag()
|
||||
{
|
||||
// Arrange
|
||||
@@ -289,7 +304,8 @@ public sealed class AirGapControllerContractTests
|
||||
expectedTags.Should().Contain("operation");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OTel_ImportOperation_IncludesOperationTag()
|
||||
{
|
||||
// Arrange
|
||||
@@ -305,7 +321,8 @@ public sealed class AirGapControllerContractTests
|
||||
expectedTags["operation"].Should().Be("airgap.import");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OTel_Metrics_TracksExportCount()
|
||||
{
|
||||
// Arrange
|
||||
@@ -317,7 +334,8 @@ public sealed class AirGapControllerContractTests
|
||||
metricName.Should().NotBeNullOrEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OTel_Metrics_TracksImportCount()
|
||||
{
|
||||
// Arrange
|
||||
@@ -329,7 +347,8 @@ public sealed class AirGapControllerContractTests
|
||||
expectedDimensions.Should().Contain("status");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OTel_ActivitySource_HasCorrectName()
|
||||
{
|
||||
// Arrange
|
||||
@@ -339,7 +358,8 @@ public sealed class AirGapControllerContractTests
|
||||
expectedSourceName.Should().StartWith("StellaOps.");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void OTel_Spans_PropagateTraceContext()
|
||||
{
|
||||
// Arrange - Create a trace context
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\\..\\StellaOps.AirGap.Importer\\StellaOps.AirGap.Importer.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -13,7 +13,8 @@ namespace StellaOps.Aoc.Analyzers.Tests;
|
||||
|
||||
public sealed class AocForbiddenFieldAnalyzerTests
|
||||
{
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("severity")]
|
||||
[InlineData("cvss")]
|
||||
[InlineData("cvss_vector")]
|
||||
@@ -46,7 +47,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Theory]
|
||||
[InlineData("effective_date")]
|
||||
[InlineData("effective_version")]
|
||||
[InlineData("effective_score")]
|
||||
@@ -73,7 +75,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdDerivedField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReportsDiagnostic_ForForbiddenFieldInObjectInitializer()
|
||||
{
|
||||
const string source = """
|
||||
@@ -102,7 +105,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task DoesNotReportDiagnostic_ForAllowedFieldAssignment()
|
||||
{
|
||||
const string source = """
|
||||
@@ -130,7 +134,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdDerivedField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task DoesNotReportDiagnostic_ForNonIngestionAssembly()
|
||||
{
|
||||
const string source = """
|
||||
@@ -154,7 +159,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.DoesNotContain(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task DoesNotReportDiagnostic_ForTestAssembly()
|
||||
{
|
||||
const string source = """
|
||||
@@ -178,12 +184,14 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.DoesNotContain(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReportsDiagnostic_ForDictionaryAddWithForbiddenKey()
|
||||
{
|
||||
const string source = """
|
||||
using System.Collections.Generic;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Concelier.Connector.Sample;
|
||||
|
||||
public sealed class Ingester
|
||||
@@ -200,7 +208,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReportsDiagnostic_CaseInsensitive()
|
||||
{
|
||||
const string source = """
|
||||
@@ -225,7 +234,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ReportsDiagnostic_ForAnonymousObjectWithForbiddenField()
|
||||
{
|
||||
const string source = """
|
||||
@@ -244,7 +254,8 @@ public sealed class AocForbiddenFieldAnalyzerTests
|
||||
Assert.Contains(diagnostics, d => d.Id == AocForbiddenFieldAnalyzer.DiagnosticIdForbiddenField);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task DoesNotReportDiagnostic_ForIngestionNamespaceButNotConnector()
|
||||
{
|
||||
const string source = """
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\__Analyzers\StellaOps.Aoc.Analyzers\StellaOps.Aoc.Analyzers.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -5,11 +5,14 @@ using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Aoc.AspNetCore.Routing;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Aoc.AspNetCore.Tests;
|
||||
|
||||
public sealed class AocGuardEndpointFilterExtensionsTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RequireAocGuard_ReturnsBuilderInstance()
|
||||
{
|
||||
var builder = WebApplication.CreateBuilder();
|
||||
@@ -23,7 +26,8 @@ public sealed class AocGuardEndpointFilterExtensionsTests
|
||||
Assert.Same(route, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RequireAocGuard_WithNullBuilder_Throws()
|
||||
{
|
||||
RouteHandlerBuilder? builder = null;
|
||||
@@ -34,7 +38,8 @@ public sealed class AocGuardEndpointFilterExtensionsTests
|
||||
_ => Array.Empty<object?>()));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RequireAocGuard_WithObjectSelector_UsesOverload()
|
||||
{
|
||||
var builder = WebApplication.CreateBuilder();
|
||||
|
||||
@@ -7,11 +7,14 @@ using Microsoft.Extensions.DependencyInjection;
|
||||
using StellaOps.Aoc;
|
||||
using StellaOps.Aoc.AspNetCore.Results;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Aoc.AspNetCore.Tests;
|
||||
|
||||
public sealed class AocHttpResultsTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Problem_WritesProblemDetails_WithGuardViolations()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\__Libraries\StellaOps.Aoc.AspNetCore\StellaOps.Aoc.AspNetCore.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
using System.Collections.Immutable;
|
||||
using StellaOps.Aoc;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Aoc.Tests;
|
||||
|
||||
public sealed class AocErrorTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FromResult_UsesFirstViolationCode()
|
||||
{
|
||||
var violations = ImmutableArray.Create(
|
||||
@@ -20,7 +22,8 @@ public sealed class AocErrorTests
|
||||
Assert.Equal(violations, error.Violations);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FromResult_DefaultsWhenNoViolations()
|
||||
{
|
||||
var error = AocError.FromResult(AocGuardResult.Success);
|
||||
@@ -29,7 +32,8 @@ public sealed class AocErrorTests
|
||||
Assert.Contains("ERR_AOC_000", error.Message, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FromException_UsesCustomMessage()
|
||||
{
|
||||
var violations = ImmutableArray.Create(
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
using System.Text.Json;
|
||||
using StellaOps.Aoc;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Aoc.Tests;
|
||||
|
||||
public sealed class AocWriteGuardTests
|
||||
{
|
||||
private static readonly AocWriteGuard Guard = new();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_ReturnsSuccess_ForMinimalValidDocument()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
@@ -33,7 +36,8 @@ public sealed class AocWriteGuardTests
|
||||
Assert.Empty(result.Violations);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_AllowsLinksAndAdvisoryKey_ByDefault()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
@@ -63,7 +67,8 @@ public sealed class AocWriteGuardTests
|
||||
Assert.Empty(result.Violations);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_FlagsMissingTenant()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
@@ -88,7 +93,8 @@ public sealed class AocWriteGuardTests
|
||||
Assert.Contains(result.Violations, v => v.ErrorCode == "ERR_AOC_004" && v.Path == "/tenant");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_FlagsForbiddenField()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
@@ -116,7 +122,8 @@ public sealed class AocWriteGuardTests
|
||||
Assert.Contains(result.Violations, v => v.ErrorCode == "ERR_AOC_001" && v.Path == "/severity");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_FlagsUnknownField()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
@@ -143,7 +150,8 @@ public sealed class AocWriteGuardTests
|
||||
Assert.Contains(result.Violations, v => v.ErrorCode == "ERR_AOC_007" && v.Path == "/custom_field");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_AllowsCustomField_WhenConfigured()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
@@ -174,7 +182,8 @@ public sealed class AocWriteGuardTests
|
||||
Assert.True(result.IsValid);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Validate_FlagsInvalidSignatureMetadata()
|
||||
{
|
||||
using var document = JsonDocument.Parse("""
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="../../__Libraries/StellaOps.Aoc/StellaOps.Aoc.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
@@ -1,8 +1,10 @@
|
||||
namespace StellaOps.Aoc.Tests;
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Aoc.Tests;
|
||||
|
||||
public class UnitTest1
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Test1()
|
||||
{
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ using StellaOps.Attestation;
|
||||
using StellaOps.Attestor.Envelope;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
public class DsseHelperTests
|
||||
{
|
||||
private sealed class FakeSigner : IAuthoritySigner
|
||||
@@ -18,7 +19,8 @@ public class DsseHelperTests
|
||||
=> Task.FromResult(Convert.FromHexString("deadbeef"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task WrapAsync_ProducesDsseEnvelope()
|
||||
{
|
||||
var stmt = new InTotoStatement(
|
||||
@@ -37,7 +39,8 @@ public class DsseHelperTests
|
||||
envelope.Signatures[0].Signature.Should().Be(Convert.ToBase64String(Convert.FromHexString("deadbeef")));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void PreAuthenticationEncoding_FollowsDsseSpec()
|
||||
{
|
||||
var payloadType = "example/type";
|
||||
|
||||
@@ -31,7 +31,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
// DSSE-8200-013: Cosign-compatible envelope structure tests
|
||||
// ==========================================================================
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnvelopeStructure_HasRequiredFields_ForCosignVerification()
|
||||
{
|
||||
// Arrange
|
||||
@@ -45,7 +46,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(result.IsValid, $"Structure validation failed: {string.Join(", ", result.Errors)}");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnvelopePayload_IsBase64Encoded_InSerializedForm()
|
||||
{
|
||||
// Arrange
|
||||
@@ -70,7 +72,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.Equal(payload, decoded);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnvelopeSignature_IsBase64Encoded_InSerializedForm()
|
||||
{
|
||||
// Arrange
|
||||
@@ -99,7 +102,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(sigBytes.Length > 0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnvelopePayloadType_IsCorrectMimeType_ForInToto()
|
||||
{
|
||||
// Arrange
|
||||
@@ -112,7 +116,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.Equal("application/vnd.in-toto+json", envelope.PayloadType);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void EnvelopeSerialization_ProducesValidJson_WithoutWhitespace()
|
||||
{
|
||||
// Arrange
|
||||
@@ -136,7 +141,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
// DSSE-8200-014: Fulcio certificate chain tests
|
||||
// ==========================================================================
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FulcioCertificate_HasCodeSigningEku()
|
||||
{
|
||||
// Arrange & Act
|
||||
@@ -161,7 +167,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(hasCodeSigning, "Certificate should have Code Signing EKU");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FulcioCertificate_HasDigitalSignatureKeyUsage()
|
||||
{
|
||||
// Arrange & Act
|
||||
@@ -173,7 +180,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(keyUsage.KeyUsages.HasFlag(X509KeyUsageFlags.DigitalSignature));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FulcioCertificate_IsShortLived()
|
||||
{
|
||||
// Arrange - Fulcio certs are typically valid for ~20 minutes
|
||||
@@ -186,7 +194,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(validity.TotalHours <= 24, $"Certificate validity ({validity.TotalHours}h) should be <= 24 hours");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void BundleWithCertificate_HasValidPemFormat()
|
||||
{
|
||||
// Arrange
|
||||
@@ -207,7 +216,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
// DSSE-8200-015: Rekor transparency log offline verification tests
|
||||
// ==========================================================================
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RekorEntry_HasValidLogIndex()
|
||||
{
|
||||
// Arrange
|
||||
@@ -221,7 +231,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(rekorEntry.LogIndex >= 0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RekorEntry_HasValidIntegratedTime()
|
||||
{
|
||||
// Arrange
|
||||
@@ -238,7 +249,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(integratedTime >= now.AddHours(-1), "Integrated time should not be too old");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RekorEntry_HasValidInclusionProof()
|
||||
{
|
||||
// Arrange
|
||||
@@ -256,7 +268,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.NotEmpty(rekorEntry.InclusionProof.Hashes);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RekorEntry_CanonicalizedBody_IsBase64Encoded()
|
||||
{
|
||||
// Arrange
|
||||
@@ -276,7 +289,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.NotNull(json);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RekorEntry_InclusionProof_HashesAreBase64()
|
||||
{
|
||||
// Arrange
|
||||
@@ -294,7 +308,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void BundleWithRekor_ContainsValidTransparencyEntry()
|
||||
{
|
||||
// Arrange
|
||||
@@ -310,7 +325,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(bundle.RekorEntry.LogIndex >= 0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void RekorEntry_CheckpointFormat_IsValid()
|
||||
{
|
||||
// Arrange
|
||||
@@ -329,7 +345,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
// Integration tests
|
||||
// ==========================================================================
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void FullBundle_SignVerifyRoundtrip_Succeeds()
|
||||
{
|
||||
// Arrange
|
||||
@@ -349,7 +366,8 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
Assert.True(structureResult.IsValid);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void DeterministicSigning_SamePayload_ProducesConsistentEnvelope()
|
||||
{
|
||||
// Arrange
|
||||
@@ -366,6 +384,7 @@ public sealed class DsseCosignCompatibilityTests : IDisposable
|
||||
// Note: Signatures may differ if using randomized ECDSA
|
||||
// (which is the default for security), so we only verify structure
|
||||
Assert.Equal(envelope1.Signatures.Count, envelope2.Signatures.Count);
|
||||
using StellaOps.TestKit;
|
||||
}
|
||||
|
||||
// ==========================================================================
|
||||
|
||||
@@ -6,13 +6,16 @@ using System.Text.Json;
|
||||
using FluentAssertions;
|
||||
using Xunit;
|
||||
using EnvelopeModel = StellaOps.Attestor.Envelope;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Envelope.Tests;
|
||||
|
||||
public sealed class DsseEnvelopeSerializerTests
|
||||
{
|
||||
private static readonly byte[] SamplePayload = Encoding.UTF8.GetBytes("deterministic-dsse-payload");
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Serialize_ProducesDeterministicCompactJson_ForSignaturePermutations()
|
||||
{
|
||||
var signatures = new[]
|
||||
|
||||
@@ -7,6 +7,8 @@ using StellaOps.Attestor.Envelope;
|
||||
using StellaOps.Cryptography;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Envelope.Tests;
|
||||
|
||||
public sealed class EnvelopeSignatureServiceTests
|
||||
@@ -23,7 +25,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
|
||||
private readonly EnvelopeSignatureService service = new();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void SignAndVerify_Ed25519_Succeeds()
|
||||
{
|
||||
var signingKey = EnvelopeKey.CreateEd25519Signer(Ed25519Seed, Ed25519Public);
|
||||
@@ -44,7 +47,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
signingKey.KeyId.Should().Be(expectedKeyId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Verify_Ed25519_InvalidSignature_ReturnsError()
|
||||
{
|
||||
var signingKey = EnvelopeKey.CreateEd25519Signer(Ed25519Seed, Ed25519Public);
|
||||
@@ -62,7 +66,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
verifyResult.Error.Code.Should().Be(EnvelopeSignatureErrorCode.SignatureInvalid);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void SignAndVerify_EcdsaEs256_Succeeds()
|
||||
{
|
||||
using var ecdsa = ECDsa.Create(ECCurve.NamedCurves.nistP256);
|
||||
@@ -80,7 +85,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
verifyResult.Value.Should().BeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Sign_WithVerificationOnlyKey_ReturnsMissingPrivateKey()
|
||||
{
|
||||
using var ecdsa = ECDsa.Create(ECCurve.NamedCurves.nistP256);
|
||||
@@ -93,7 +99,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
signResult.Error.Code.Should().Be(EnvelopeSignatureErrorCode.MissingPrivateKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Verify_WithMismatchedKeyId_ReturnsError()
|
||||
{
|
||||
var signingKey = EnvelopeKey.CreateEd25519Signer(Ed25519Seed, Ed25519Public);
|
||||
@@ -107,7 +114,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
verifyResult.Error.Code.Should().Be(EnvelopeSignatureErrorCode.KeyIdMismatch);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Verify_WithInvalidSignatureLength_ReturnsFormatError()
|
||||
{
|
||||
var verifyKey = EnvelopeKey.CreateEd25519Verifier(Ed25519Public);
|
||||
@@ -119,7 +127,8 @@ public sealed class EnvelopeSignatureServiceTests
|
||||
verifyResult.Error.Code.Should().Be(EnvelopeSignatureErrorCode.InvalidSignatureFormat);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Verify_WithAlgorithmMismatch_ReturnsError()
|
||||
{
|
||||
using var ecdsa = ECDsa.Create(ECCurve.NamedCurves.nistP256);
|
||||
|
||||
@@ -18,5 +18,6 @@
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\\StellaOps.Attestor.Envelope.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
||||
@@ -7,11 +7,14 @@ using System.Text.Json;
|
||||
using StellaOps.Attestor.Envelope;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Envelope.Tests;
|
||||
|
||||
public sealed class DsseEnvelopeSerializerTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Serialize_WithDefaultOptions_ProducesCompactAndExpandedJson()
|
||||
{
|
||||
var payload = Encoding.UTF8.GetBytes("{\"foo\":\"bar\"}");
|
||||
@@ -46,7 +49,8 @@ public sealed class DsseEnvelopeSerializerTests
|
||||
Assert.Equal("bar", preview.GetProperty("json").GetProperty("foo").GetString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Serialize_WithCompressionEnabled_EmbedsCompressedPayloadMetadata()
|
||||
{
|
||||
var payload = Encoding.UTF8.GetBytes("{\"foo\":\"bar\",\"count\":1}");
|
||||
@@ -87,7 +91,8 @@ public sealed class DsseEnvelopeSerializerTests
|
||||
Assert.Equal(compressedBytes.Length, result.EmbeddedPayloadLength);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Serialize_WithDetachedReference_WritesMetadata()
|
||||
{
|
||||
var payload = Encoding.UTF8.GetBytes("detached payload preview");
|
||||
@@ -117,7 +122,8 @@ public sealed class DsseEnvelopeSerializerTests
|
||||
Assert.Equal(reference.MediaType, detached.GetProperty("mediaType").GetString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Serialize_CompactOnly_SkipsExpandedPayload()
|
||||
{
|
||||
var payload = Encoding.UTF8.GetBytes("payload");
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\\..\\StellaOps.Attestor.Envelope.csproj" />
|
||||
<ProjectReference Include="../../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -5,11 +5,13 @@ using StellaOps.Attestor.Core.Tests.Fixtures.Rekor;
|
||||
using StellaOps.Attestor.Core.Verification;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Core.Tests;
|
||||
|
||||
public sealed class RekorOfflineReceiptVerifierTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_ValidReceipt_Succeeds()
|
||||
{
|
||||
var (directory, receiptPath) = CreateTempReceipt(RekorOfflineReceiptFixtures.ReceiptJson);
|
||||
@@ -33,7 +35,8 @@ public sealed class RekorOfflineReceiptVerifierTests
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_CheckpointPathReference_Succeeds()
|
||||
{
|
||||
var directory = Path.Combine(Path.GetTempPath(), "stellaops-attestor-rekor-offline-" + Guid.NewGuid().ToString("n"));
|
||||
@@ -62,7 +65,8 @@ public sealed class RekorOfflineReceiptVerifierTests
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_TamperedCheckpointSignature_Fails()
|
||||
{
|
||||
var tampered = MutateReceiptJson(root =>
|
||||
@@ -90,7 +94,8 @@ public sealed class RekorOfflineReceiptVerifierTests
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_RootHashMismatch_Fails()
|
||||
{
|
||||
var badJson = MutateReceiptJson(root => root["rootHash"] = new string('0', 64));
|
||||
@@ -114,7 +119,8 @@ public sealed class RekorOfflineReceiptVerifierTests
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_AllowOfflineWithoutSignature_AllowsUnsignedCheckpoint()
|
||||
{
|
||||
var checkpointBodyOnly = RekorOfflineReceiptFixtures.SignedCheckpointNote.Split("\n\n", StringSplitOptions.None)[0] + "\n";
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\StellaOps.Attestor.Core\StellaOps.Attestor.Core.csproj" />
|
||||
<ProjectReference Include="../../../__Libraries/StellaOps.TestKit/StellaOps.TestKit.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -39,7 +39,8 @@ namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestationBundleEndpointsTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExportEndpoint_RequiresAuthentication()
|
||||
{
|
||||
using var factory = new AttestorWebApplicationFactory();
|
||||
@@ -50,7 +51,8 @@ public sealed class AttestationBundleEndpointsTests
|
||||
Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ExportAndImportEndpoints_RoundTripBundles()
|
||||
{
|
||||
using var factory = new AttestorWebApplicationFactory();
|
||||
@@ -64,6 +66,7 @@ public sealed class AttestationBundleEndpointsTests
|
||||
using (var scope = factory.Services.CreateScope())
|
||||
{
|
||||
var repository = scope.ServiceProvider.GetRequiredService<IAttestorEntryRepository>();
|
||||
using StellaOps.TestKit;
|
||||
var archiveStore = scope.ServiceProvider.GetRequiredService<IAttestorArchiveStore>();
|
||||
|
||||
var entry = new AttestorEntry
|
||||
|
||||
@@ -8,11 +8,13 @@ using StellaOps.Attestor.Core.Storage;
|
||||
using StellaOps.Attestor.WebService.Contracts;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestationQueryTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task QueryAsync_FiltersAndPagination_Work()
|
||||
{
|
||||
var repository = new InMemoryAttestorEntryRepository();
|
||||
@@ -83,7 +85,8 @@ public sealed class AttestationQueryTests
|
||||
Assert.All(secondPage.Items, item => Assert.DoesNotContain(item.RekorUuid, firstPageIds));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void TryBuildQuery_ValidatesInputs()
|
||||
{
|
||||
var httpContext = new DefaultHttpContext();
|
||||
|
||||
@@ -5,11 +5,13 @@ using System.Threading.Tasks;
|
||||
using StellaOps.Attestor.Core.Storage;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestorEntryRepositoryTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task QueryAsync_FiltersAndPagination_Work()
|
||||
{
|
||||
var repository = new InMemoryAttestorEntryRepository();
|
||||
@@ -53,7 +55,8 @@ public sealed class AttestorEntryRepositoryTests
|
||||
Assert.All(secondPage.Items, item => Assert.DoesNotContain(item.RekorUuid, seen));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SaveAsync_EnforcesUniqueBundleSha()
|
||||
{
|
||||
var repository = new InMemoryAttestorEntryRepository();
|
||||
|
||||
@@ -21,6 +21,8 @@ using Org.BouncyCastle.Crypto.Signers;
|
||||
using Org.BouncyCastle.Security;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
[Collection("SmSoftGate")]
|
||||
@@ -28,7 +30,8 @@ public sealed class AttestorSigningServiceTests : IDisposable
|
||||
{
|
||||
private readonly List<string> _temporaryPaths = new();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SignAsync_Ed25519Key_ReturnsValidSignature()
|
||||
{
|
||||
var privateKey = new byte[32];
|
||||
@@ -110,7 +113,8 @@ public sealed class AttestorSigningServiceTests : IDisposable
|
||||
Assert.Equal("signed", auditSink.Records[0].Result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SignAsync_KmsKey_ProducesVerifiableSignature()
|
||||
{
|
||||
var kmsRoot = CreateTempDirectory();
|
||||
@@ -215,7 +219,8 @@ public sealed class AttestorSigningServiceTests : IDisposable
|
||||
Assert.Equal("signed", auditSink.Records[0].Result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SignAsync_Sm2Key_ReturnsValidSignature_WhenGateEnabled()
|
||||
{
|
||||
var originalGate = Environment.GetEnvironmentVariable("SM_SOFT_ALLOWED");
|
||||
@@ -312,7 +317,8 @@ public sealed class AttestorSigningServiceTests : IDisposable
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void Sm2Registry_Fails_WhenGateDisabled()
|
||||
{
|
||||
var originalGate = Environment.GetEnvironmentVariable("SM_SOFT_ALLOWED");
|
||||
|
||||
@@ -5,11 +5,13 @@ using StellaOps.Attestor.Core.Storage;
|
||||
using StellaOps.Attestor.Infrastructure.Storage;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestorStorageTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SaveAsync_PersistsAndFetchesEntry()
|
||||
{
|
||||
var repository = new InMemoryAttestorEntryRepository();
|
||||
@@ -27,7 +29,8 @@ public sealed class AttestorStorageTests
|
||||
Assert.Single(byArtifact);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SaveAsync_UpsertsExistingDocument()
|
||||
{
|
||||
var repository = new InMemoryAttestorEntryRepository();
|
||||
@@ -47,7 +50,8 @@ public sealed class AttestorStorageTests
|
||||
Assert.Equal("pending", stored!.Status);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task InMemoryDedupeStore_RoundTripsAndExpires()
|
||||
{
|
||||
var store = new InMemoryAttestorDedupeStore();
|
||||
|
||||
@@ -17,11 +17,14 @@ using StellaOps.Attestor.Infrastructure.Submission;
|
||||
using StellaOps.Attestor.Tests.Support;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestorSubmissionServiceTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SubmitAsync_ReturnsDeterministicUuid_OnDuplicateBundle()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -92,7 +95,8 @@ public sealed class AttestorSubmissionServiceTests
|
||||
Assert.Equal(request.Meta.Artifact.Sha256, verificationCache.InvalidatedSubjects[0]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task Validator_ThrowsWhenModeNotAllowed()
|
||||
{
|
||||
var canonicalizer = new DefaultDsseCanonicalizer();
|
||||
@@ -104,7 +108,8 @@ public sealed class AttestorSubmissionServiceTests
|
||||
await Assert.ThrowsAsync<AttestorValidationException>(() => validator.ValidateAsync(request));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SubmitAsync_Throws_WhenMirrorDisabledButRequested()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -163,7 +168,8 @@ public sealed class AttestorSubmissionServiceTests
|
||||
Assert.Equal("mirror_disabled", ex.Code);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SubmitAsync_ReturnsMirrorMetadata_WhenPreferenceBoth()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -233,7 +239,8 @@ public sealed class AttestorSubmissionServiceTests
|
||||
Assert.Equal("included", result.Mirror.Status);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task SubmitAsync_UsesMirrorAsCanonical_WhenPreferenceMirror()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
|
||||
@@ -7,13 +7,15 @@ using StellaOps.Attestor.Core.Submission;
|
||||
using StellaOps.Attestor.Infrastructure.Submission;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestorSubmissionValidatorHardeningTests
|
||||
{
|
||||
private static readonly DefaultDsseCanonicalizer Canonicalizer = new();
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_ThrowsWhenPayloadExceedsLimit()
|
||||
{
|
||||
var constraints = new AttestorSubmissionConstraints(
|
||||
@@ -28,7 +30,8 @@ public sealed class AttestorSubmissionValidatorHardeningTests
|
||||
Assert.Equal("payload_too_large", exception.Code);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_ThrowsWhenCertificateChainTooLong()
|
||||
{
|
||||
var constraints = new AttestorSubmissionConstraints(
|
||||
@@ -43,7 +46,8 @@ public sealed class AttestorSubmissionValidatorHardeningTests
|
||||
Assert.Equal("certificate_chain_too_long", exception.Code);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ValidateAsync_FuzzedInputs_DoNotCrash()
|
||||
{
|
||||
var constraints = new AttestorSubmissionConstraints();
|
||||
|
||||
@@ -22,6 +22,8 @@ using StellaOps.Attestor.Verify;
|
||||
using StellaOps.Attestor.Tests.Support;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class AttestorVerificationServiceTests
|
||||
@@ -29,7 +31,8 @@ public sealed class AttestorVerificationServiceTests
|
||||
private static readonly byte[] HmacSecret = Encoding.UTF8.GetBytes("attestor-hmac-secret");
|
||||
private static readonly string HmacSecretBase64 = Convert.ToBase64String(HmacSecret);
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_ReturnsOk_ForExistingUuid()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -122,7 +125,8 @@ public sealed class AttestorVerificationServiceTests
|
||||
Assert.Equal("missing", verifyResult.Report.Transparency.WitnessStatus);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_KmsBundle_Passes_WhenTwoSignaturesRequired()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -213,7 +217,8 @@ public sealed class AttestorVerificationServiceTests
|
||||
Assert.Equal(2, verifyResult.Report.Signatures.RequiredSignatures);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_FlagsTamperedBundle()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -426,7 +431,8 @@ public sealed class AttestorVerificationServiceTests
|
||||
return buffer;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_OfflineSkipsProofRefreshWhenMissing()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -490,7 +496,8 @@ public sealed class AttestorVerificationServiceTests
|
||||
Assert.Equal(0, rekorClient.ProofRequests);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_OfflineUsesImportedProof()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
@@ -577,7 +584,8 @@ public sealed class AttestorVerificationServiceTests
|
||||
Assert.Equal(0, rekorClient.ProofRequests);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_FailsWhenWitnessRootMismatch()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions
|
||||
|
||||
@@ -5,11 +5,13 @@ using StellaOps.Attestor.Core.Options;
|
||||
using StellaOps.Attestor.WebService.Contracts;
|
||||
using Xunit;
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class BulkVerificationContractsTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void TryBuildJob_ReturnsError_WhenItemsMissing()
|
||||
{
|
||||
var options = new AttestorOptions();
|
||||
@@ -22,7 +24,8 @@ public sealed class BulkVerificationContractsTests
|
||||
Assert.NotNull(error);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public void TryBuildJob_AppliesDefaults()
|
||||
{
|
||||
var options = new AttestorOptions
|
||||
|
||||
@@ -11,11 +11,14 @@ using StellaOps.Attestor.Core.Verification;
|
||||
using StellaOps.Attestor.Infrastructure.Bulk;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class BulkVerificationWorkerTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task ProcessJobAsync_CompletesAllItems()
|
||||
{
|
||||
var jobStore = new InMemoryBulkVerificationJobStore();
|
||||
|
||||
@@ -10,11 +10,14 @@ using StellaOps.Attestor.Core.Verification;
|
||||
using StellaOps.Attestor.Infrastructure.Verification;
|
||||
using Xunit;
|
||||
|
||||
|
||||
using StellaOps.TestKit;
|
||||
namespace StellaOps.Attestor.Tests;
|
||||
|
||||
public sealed class CachedAttestorVerificationServiceTests
|
||||
{
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_ReturnsCachedResult_OnRepeatedCalls()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions());
|
||||
@@ -44,7 +47,8 @@ public sealed class CachedAttestorVerificationServiceTests
|
||||
Assert.Equal(1, inner.VerifyCallCount);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_BypassesCache_WhenRefreshProofRequested()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions());
|
||||
@@ -75,7 +79,8 @@ public sealed class CachedAttestorVerificationServiceTests
|
||||
Assert.Equal(2, inner.VerifyCallCount);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", TestCategories.Unit)]
|
||||
[Fact]
|
||||
public async Task VerifyAsync_BypassesCache_WhenDescriptorIncomplete()
|
||||
{
|
||||
var options = Options.Create(new AttestorOptions());
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user