using System.Security.Cryptography; using System.Text; using System.Text.Json; using System.Text.RegularExpressions; using StellaOps.Policy.Registry.Contracts; using StellaOps.Policy.Registry.Storage; namespace StellaOps.Policy.Registry.Services; /// /// Default implementation of quick policy simulation service. /// Evaluates policy rules against provided input and returns violations. /// public sealed partial class PolicySimulationService : IPolicySimulationService { private readonly IPolicyPackStore _packStore; private readonly TimeProvider _timeProvider; // Regex patterns for input reference extraction [GeneratedRegex(@"input\.(\w+(?:\.\w+)*)", RegexOptions.None)] private static partial Regex InputReferenceRegex(); [GeneratedRegex(@"input\[""([^""]+)""\]", RegexOptions.None)] private static partial Regex InputBracketReferenceRegex(); public PolicySimulationService(IPolicyPackStore packStore, TimeProvider? timeProvider = null) { _packStore = packStore ?? throw new ArgumentNullException(nameof(packStore)); _timeProvider = timeProvider ?? TimeProvider.System; } public async Task SimulateAsync( Guid tenantId, Guid packId, SimulationRequest request, CancellationToken cancellationToken = default) { var start = _timeProvider.GetTimestamp(); var executedAt = _timeProvider.GetUtcNow(); var simulationId = GenerateSimulationId(tenantId, packId, executedAt); var pack = await _packStore.GetByIdAsync(tenantId, packId, cancellationToken); if (pack is null) { return new PolicySimulationResponse { SimulationId = simulationId, Success = false, ExecutedAt = executedAt, DurationMilliseconds = GetElapsedMs(start), Errors = [new SimulationError { Code = "PACK_NOT_FOUND", Message = $"Policy pack {packId} not found" }] }; } return await SimulateRulesInternalAsync( simulationId, pack.Rules ?? [], request, start, executedAt, cancellationToken); } public async Task SimulateRulesAsync( Guid tenantId, IReadOnlyList rules, SimulationRequest request, CancellationToken cancellationToken = default) { var start = _timeProvider.GetTimestamp(); var executedAt = _timeProvider.GetUtcNow(); var simulationId = GenerateSimulationId(tenantId, Guid.Empty, executedAt); return await SimulateRulesInternalAsync( simulationId, rules, request, start, executedAt, cancellationToken); } public Task ValidateInputAsync( IReadOnlyDictionary input, CancellationToken cancellationToken = default) { var errors = new List(); if (input.Count == 0) { errors.Add(new InputValidationError { Path = "$", Message = "Input must contain at least one property" }); } // Check for common required fields var commonFields = new[] { "subject", "resource", "action", "context" }; var missingFields = commonFields.Where(f => !input.ContainsKey(f)).ToList(); if (missingFields.Count == commonFields.Length) { // Warn if none of the common fields are present errors.Add(new InputValidationError { Path = "$", Message = $"Input should contain at least one of: {string.Join(", ", commonFields)}" }); } return Task.FromResult(errors.Count > 0 ? InputValidationResult.Invalid(errors) : InputValidationResult.Valid()); } private async Task SimulateRulesInternalAsync( string simulationId, IReadOnlyList rules, SimulationRequest request, long startTimestamp, DateTimeOffset executedAt, CancellationToken cancellationToken) { var violations = new List(); var errors = new List(); var trace = new List(); int rulesMatched = 0; var enabledRules = rules.Where(r => r.Enabled).ToList(); foreach (var rule in enabledRules) { cancellationToken.ThrowIfCancellationRequested(); try { var (matched, violation, traceEntry) = EvaluateRule(rule, request.Input, request.Options); if (request.Options?.Trace == true && traceEntry is not null) { trace.Add(traceEntry); } if (matched) { rulesMatched++; if (violation is not null) { violations.Add(violation); } } } catch (Exception ex) { errors.Add(new SimulationError { RuleId = rule.RuleId, Code = "EVALUATION_ERROR", Message = ex.Message }); } } var elapsed = GetElapsedMs(startTimestamp); var severityCounts = violations .GroupBy(v => v.Severity.ToLowerInvariant()) .ToDictionary(g => g.Key, g => g.Count()); var summary = new SimulationSummary { TotalRulesEvaluated = enabledRules.Count, RulesMatched = rulesMatched, ViolationsFound = violations.Count, ViolationsBySeverity = severityCounts }; var result = new SimulationResult { Result = new Dictionary { ["allow"] = violations.Count == 0, ["violations_count"] = violations.Count }, Violations = violations.Count > 0 ? violations : null, Trace = request.Options?.Trace == true && trace.Count > 0 ? trace : null, Explain = request.Options?.Explain == true ? BuildExplainTrace(enabledRules, request.Input) : null }; return new PolicySimulationResponse { SimulationId = simulationId, Success = errors.Count == 0, ExecutedAt = executedAt, DurationMilliseconds = elapsed, Result = result, Summary = summary, Errors = errors.Count > 0 ? errors : null }; } private (bool matched, SimulatedViolation? violation, string? trace) EvaluateRule( PolicyRule rule, IReadOnlyDictionary input, SimulationOptions? options) { // If no Rego code, use basic rule matching based on severity and name if (string.IsNullOrWhiteSpace(rule.Rego)) { // Without Rego, we do pattern-based matching on rule name/description var matched = MatchRuleByName(rule, input); var trace = options?.Trace == true ? $"Rule {rule.RuleId}: matched={matched} (no Rego, name-based)" : null; if (matched) { var violation = new SimulatedViolation { RuleId = rule.RuleId, Severity = rule.Severity.ToString().ToLowerInvariant(), Message = rule.Description ?? $"Violation of rule {rule.Name}" }; return (true, violation, trace); } return (false, null, trace); } // Evaluate Rego-based rule var regoResult = EvaluateRegoRule(rule, input); var regoTrace = options?.Trace == true ? $"Rule {rule.RuleId}: matched={regoResult.matched}, inputs_used={string.Join(",", regoResult.inputsUsed)}" : null; if (regoResult.matched) { var violation = new SimulatedViolation { RuleId = rule.RuleId, Severity = rule.Severity.ToString().ToLowerInvariant(), Message = rule.Description ?? $"Violation of rule {rule.Name}", Context = regoResult.context }; return (true, violation, regoTrace); } return (false, null, regoTrace); } private static bool MatchRuleByName(PolicyRule rule, IReadOnlyDictionary input) { // Simple heuristic matching for rules without Rego var ruleName = rule.Name.ToLowerInvariant(); var ruleDesc = rule.Description?.ToLowerInvariant() ?? ""; // Check if any input key matches rule keywords foreach (var (key, value) in input) { var keyLower = key.ToLowerInvariant(); var valueLower = value?.ToString()?.ToLowerInvariant() ?? ""; if (ruleName.Contains(keyLower) || ruleDesc.Contains(keyLower)) { return true; } if (ruleName.Contains(valueLower) || ruleDesc.Contains(valueLower)) { return true; } } return false; } private (bool matched, HashSet inputsUsed, IReadOnlyDictionary? context) EvaluateRegoRule( PolicyRule rule, IReadOnlyDictionary input) { // Extract input references from Rego code var inputRefs = ExtractInputReferences(rule.Rego!); var inputsUsed = new HashSet(); var context = new Dictionary(); // Simple evaluation: check if referenced inputs exist and have values bool allInputsPresent = true; foreach (var inputRef in inputRefs) { var value = GetNestedValue(input, inputRef); if (value is not null) { inputsUsed.Add(inputRef); context[inputRef] = value; } else { allInputsPresent = false; } } // For this simplified simulation: // - Rule matches if all referenced inputs are present // - This simulates the rule being able to evaluate var matched = inputRefs.Count > 0 && allInputsPresent; return (matched, inputsUsed, context.Count > 0 ? context : null); } private static HashSet ExtractInputReferences(string rego) { var refs = new HashSet(StringComparer.Ordinal); // Match input.field.subfield pattern foreach (Match match in InputReferenceRegex().Matches(rego)) { refs.Add(match.Groups[1].Value); } // Match input["field"] pattern foreach (Match match in InputBracketReferenceRegex().Matches(rego)) { refs.Add(match.Groups[1].Value); } return refs; } private static object? GetNestedValue(IReadOnlyDictionary input, string path) { var parts = path.Split('.'); object? current = input; foreach (var part in parts) { if (current is IReadOnlyDictionary dict) { if (!dict.TryGetValue(part, out current)) { return null; } } else if (current is JsonElement jsonElement) { if (jsonElement.ValueKind == JsonValueKind.Object && jsonElement.TryGetProperty(part, out var prop)) { current = prop; } else { return null; } } else { return null; } } return current; } private static PolicyExplainTrace BuildExplainTrace( IReadOnlyList rules, IReadOnlyDictionary input) { var steps = new List(); steps.Add(new { type = "input_received", keys = input.Keys.ToList() }); foreach (var rule in rules) { steps.Add(new { type = "rule_evaluation", rule_id = rule.RuleId, rule_name = rule.Name, severity = rule.Severity.ToString(), has_rego = !string.IsNullOrWhiteSpace(rule.Rego) }); } steps.Add(new { type = "evaluation_complete", rules_count = rules.Count }); return new PolicyExplainTrace { Steps = steps }; } private static string GenerateSimulationId(Guid tenantId, Guid packId, DateTimeOffset timestamp) { var content = $"{tenantId}:{packId}:{timestamp.ToUnixTimeMilliseconds()}"; var hash = SHA256.HashData(Encoding.UTF8.GetBytes(content)); return $"sim_{Convert.ToHexString(hash)[..16].ToLowerInvariant()}"; } private long GetElapsedMs(long startTimestamp) { var elapsed = _timeProvider.GetElapsedTime(startTimestamp, _timeProvider.GetTimestamp()); return (long)Math.Ceiling(elapsed.TotalMilliseconds); } }