Files
git.stella-ops.org/src/Policy/StellaOps.Policy.Engine/Services/RiskProfileConfigurationService.cs
master cc69d332e3
Some checks failed
Docs CI / lint-and-preview (push) Has been cancelled
Add unit tests for RabbitMq and Udp transport servers and clients
- Implemented comprehensive unit tests for RabbitMqTransportServer, covering constructor, disposal, connection management, event handlers, and exception handling.
- Added configuration tests for RabbitMqTransportServer to validate SSL, durable queues, auto-recovery, and custom virtual host options.
- Created unit tests for UdpFrameProtocol, including frame parsing and serialization, header size validation, and round-trip data preservation.
- Developed tests for UdpTransportClient, focusing on connection handling, event subscriptions, and exception scenarios.
- Established tests for UdpTransportServer, ensuring proper start/stop behavior, connection state management, and event handling.
- Included tests for UdpTransportOptions to verify default values and modification capabilities.
- Enhanced service registration tests for Udp transport services in the dependency injection container.
2025-12-05 19:01:12 +02:00

345 lines
12 KiB
C#

using System.Collections.Concurrent;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StellaOps.Cryptography;
using StellaOps.Policy.Engine.Options;
using StellaOps.Policy.RiskProfile.Hashing;
using StellaOps.Policy.RiskProfile.Merge;
using StellaOps.Policy.RiskProfile.Models;
using StellaOps.Policy.RiskProfile.Validation;
namespace StellaOps.Policy.Engine.Services;
/// <summary>
/// Service for loading and providing risk profiles from configuration.
/// </summary>
/// <remarks>
/// Profiles come from three sources:
/// <list type="bullet">
/// <item>inline definitions in the risk-profile options;</item>
/// <item><c>*.json</c> files found recursively under the configured profile directory;</item>
/// <item>a built-in profile, added when the configured default profile is otherwise missing.</item>
/// </list>
/// Raw profiles are kept in one cache. Inheritance-resolved profiles are optionally memoized in a
/// second cache, controlled by the <c>CacheResolvedProfiles</c> option.
/// </remarks>
public sealed class RiskProfileConfigurationService
{
    /// <summary>Profile ID used when no usable default profile ID is configured.</summary>
    private const string BuiltInDefaultProfileId = "default";

    private static readonly JsonSerializerOptions JsonOptions = new()
    {
        PropertyNameCaseInsensitive = true,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase
    };

    private readonly ILogger<RiskProfileConfigurationService> _logger;
    private readonly PolicyEngineRiskProfileOptions _options;
    private readonly RiskProfileMergeService _mergeService;
    private readonly RiskProfileHasher _hasher;
    private readonly RiskProfileValidator _validator;

    // Raw profiles as loaded or registered, keyed case-insensitively by profile ID.
    private readonly ConcurrentDictionary<string, RiskProfileModel> _profileCache;

    // Inheritance-resolved profiles; only populated when CacheResolvedProfiles is enabled.
    private readonly ConcurrentDictionary<string, RiskProfileModel> _resolvedCache;

    private readonly object _loadLock = new();

    // volatile: LoadProfiles reads this outside the lock as a fast path. The write that marks
    // loading as finished must not become visible before the cache writes that precede it.
    private volatile bool _loaded;

    /// <summary>
    /// Creates the service from the policy engine options.
    /// </summary>
    /// <param name="logger">Logger for load diagnostics.</param>
    /// <param name="options">Policy engine options; the <c>RiskProfile</c> section is used.</param>
    /// <param name="cryptoHash">Hash provider used to compute profile hashes.</param>
    /// <exception cref="ArgumentNullException">
    /// Any argument is null, or the options carry no risk-profile section.
    /// </exception>
    public RiskProfileConfigurationService(
        ILogger<RiskProfileConfigurationService> logger,
        IOptions<PolicyEngineOptions> options,
        ICryptoHash cryptoHash)
    {
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        // Null-conditional on Value too, so a missing options value is reported as
        // ArgumentNullException rather than NullReferenceException.
        _options = options?.Value?.RiskProfile ?? throw new ArgumentNullException(nameof(options));
        ArgumentNullException.ThrowIfNull(cryptoHash);

        _mergeService = new RiskProfileMergeService();
        _hasher = new RiskProfileHasher(cryptoHash);
        _validator = new RiskProfileValidator();
        _profileCache = new ConcurrentDictionary<string, RiskProfileModel>(StringComparer.OrdinalIgnoreCase);
        _resolvedCache = new ConcurrentDictionary<string, RiskProfileModel>(StringComparer.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Gets whether risk profile integration is enabled.
    /// </summary>
    public bool IsEnabled => _options.Enabled;

    /// <summary>
    /// Gets the default profile ID.
    /// </summary>
    /// <remarks>
    /// Returns the configured value. Falls back to <c>"default"</c> when the configured value is
    /// null or whitespace.
    /// </remarks>
    public string DefaultProfileId =>
        string.IsNullOrWhiteSpace(_options.DefaultProfileId)
            ? BuiltInDefaultProfileId
            : _options.DefaultProfileId;

    /// <summary>
    /// Loads all profiles from configuration and the file system.
    /// </summary>
    /// <remarks>
    /// Only the first successful call does any work; later calls return immediately.
    /// Load order is: inline profiles, then file profiles (which overwrite inline profiles with
    /// the same ID), then the built-in default if the default ID is still missing.
    /// </remarks>
    public void LoadProfiles()
    {
        if (_loaded)
        {
            return;
        }

        lock (_loadLock)
        {
            if (_loaded)
            {
                return;
            }

            LoadInlineProfiles();
            LoadFileProfiles();
            EnsureDefaultProfile();

            _loaded = true;

            _logger.LogInformation(
                "Loaded {Count} risk profiles (default: {DefaultId})",
                _profileCache.Count,
                DefaultProfileId);
        }
    }

    /// <summary>
    /// Gets a profile by ID, resolving inheritance if needed.
    /// </summary>
    /// <param name="profileId">
    /// The profile ID to retrieve. Null or whitespace selects <see cref="DefaultProfileId"/>.
    /// </param>
    /// <returns>The resolved profile, or null if not found.</returns>
    public RiskProfileModel? GetProfile(string? profileId)
    {
        var id = string.IsNullOrWhiteSpace(profileId) ? DefaultProfileId : profileId;

        if (_options.CacheResolvedProfiles && _resolvedCache.TryGetValue(id, out var cached))
        {
            return cached;
        }

        if (!_profileCache.TryGetValue(id, out var profile))
        {
            _logger.LogWarning("Risk profile '{ProfileId}' not found", id);
            return null;
        }

        var resolved = _mergeService.ResolveInheritance(
            profile,
            LookupProfile,
            _options.MaxInheritanceDepth);

        if (_options.CacheResolvedProfiles)
        {
            _resolvedCache.TryAdd(id, resolved);
        }

        return resolved;
    }

    /// <summary>
    /// Gets the default profile, with inheritance resolved.
    /// </summary>
    /// <returns>The resolved default profile, or null if it has not been loaded.</returns>
    public RiskProfileModel? GetDefaultProfile() => GetProfile(DefaultProfileId);

    /// <summary>
    /// Gets all loaded profile IDs as a snapshot.
    /// </summary>
    public IReadOnlyCollection<string> GetProfileIds() => _profileCache.Keys.ToList().AsReadOnly();

    /// <summary>
    /// Computes a deterministic hash for a profile.
    /// </summary>
    public string ComputeHash(RiskProfileModel profile) => _hasher.ComputeHash(profile);

    /// <summary>
    /// Computes a content hash (ignoring identity fields) for a profile.
    /// </summary>
    public string ComputeContentHash(RiskProfileModel profile) => _hasher.ComputeContentHash(profile);

    /// <summary>
    /// Registers a profile programmatically, replacing any existing profile with the same ID.
    /// </summary>
    /// <param name="profile">The profile to register.</param>
    /// <exception cref="ArgumentNullException"><paramref name="profile"/> is null.</exception>
    public void RegisterProfile(RiskProfileModel profile)
    {
        ArgumentNullException.ThrowIfNull(profile);

        _profileCache[profile.Id] = profile;

        // Drop every resolved entry, not only this ID. Other profiles may have been resolved
        // against the previous version of this profile (via LookupProfile), and evicting only
        // this ID would leave those entries stale.
        _resolvedCache.Clear();

        _logger.LogDebug("Registered risk profile '{ProfileId}' v{Version}", profile.Id, profile.Version);
    }

    /// <summary>
    /// Clears the resolved profile cache.
    /// </summary>
    public void ClearResolvedCache()
    {
        _resolvedCache.Clear();
        _logger.LogDebug("Cleared resolved profile cache");
    }

    /// <summary>Parent lookup handed to the merge service during inheritance resolution.</summary>
    private RiskProfileModel? LookupProfile(string id) =>
        _profileCache.TryGetValue(id, out var profile) ? profile : null;

    /// <summary>
    /// Converts and caches the profiles defined inline in options.
    /// A definition that fails to convert is logged and skipped.
    /// </summary>
    private void LoadInlineProfiles()
    {
        foreach (var definition in _options.Profiles)
        {
            try
            {
                var profile = ConvertFromDefinition(definition);
                _profileCache[profile.Id] = profile;
                _logger.LogDebug("Loaded inline profile '{ProfileId}' v{Version}", profile.Id, profile.Version);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Failed to load inline profile '{ProfileId}'", definition.Id);
            }
        }
    }

    /// <summary>
    /// Loads every <c>*.json</c> file under the configured profile directory, recursively.
    /// A file that fails validation, has no profile ID, or throws while loading is logged and
    /// skipped.
    /// </summary>
    private void LoadFileProfiles()
    {
        if (string.IsNullOrWhiteSpace(_options.ProfileDirectory))
        {
            return;
        }

        if (!Directory.Exists(_options.ProfileDirectory))
        {
            _logger.LogWarning("Risk profile directory not found: {Directory}", _options.ProfileDirectory);
            return;
        }

        var files = Directory.GetFiles(_options.ProfileDirectory, "*.json", SearchOption.AllDirectories);

        // Directory.GetFiles does not guarantee an order. Sort so that when several files declare
        // the same profile ID, the same file wins on every run.
        Array.Sort(files, StringComparer.Ordinal);

        foreach (var file in files)
        {
            try
            {
                var json = File.ReadAllText(file);

                if (_options.ValidateOnLoad)
                {
                    var validation = _validator.Validate(json);
                    if (!validation.IsValid)
                    {
                        var errorMessages = validation.Errors?.Values ?? Enumerable.Empty<string>();
                        _logger.LogWarning(
                            "Risk profile file '{File}' failed validation: {Errors}",
                            file,
                            string.Join("; ", errorMessages.Any() ? errorMessages : new[] { "Unknown error" }));
                        continue;
                    }
                }

                var profile = JsonSerializer.Deserialize<RiskProfileModel>(json, JsonOptions);
                if (profile is null)
                {
                    continue;
                }

                // Check explicitly for a missing ID. Otherwise a null ID surfaces as a generic
                // exception from the dictionary indexer, and an empty ID is stored under "".
                if (string.IsNullOrWhiteSpace(profile.Id))
                {
                    _logger.LogWarning("Risk profile file '{File}' does not declare a profile id; skipping", file);
                    continue;
                }

                _profileCache[profile.Id] = profile;
                _logger.LogDebug("Loaded profile '{ProfileId}' from {File}", profile.Id, file);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Failed to load risk profile from '{File}'", file);
            }
        }
    }

    /// <summary>
    /// Adds the built-in profile under <see cref="DefaultProfileId"/> when no profile with that ID
    /// was loaded.
    /// </summary>
    private void EnsureDefaultProfile()
    {
        var defaultId = DefaultProfileId;
        if (_profileCache.ContainsKey(defaultId))
        {
            return;
        }

        // Register under the configured default ID. A hard-coded "default" ID would leave
        // GetDefaultProfile() returning null whenever a different default is configured. It could
        // also overwrite a user-supplied profile that happens to be named "default".
        var defaultProfile = CreateBuiltInDefaultProfile(defaultId);
        _profileCache[defaultProfile.Id] = defaultProfile;
        _logger.LogDebug("Created built-in default profile '{ProfileId}'", defaultProfile.Id);
    }

    /// <summary>
    /// Builds the built-in profile: five signals (CVSS score, KEV, EPSS, reachability,
    /// exploit availability) with fixed weights summing to 1.0.
    /// </summary>
    /// <param name="id">The profile ID to assign.</param>
    private static RiskProfileModel CreateBuiltInDefaultProfile(string id)
    {
        return new RiskProfileModel
        {
            Id = id,
            Version = "1.0.0",
            Description = "Built-in default risk profile with standard vulnerability signals.",
            Signals = new List<RiskSignal>
            {
                new()
                {
                    Name = "cvss_score",
                    Source = "vulnerability",
                    Type = RiskSignalType.Numeric,
                    Path = "/cvss/baseScore",
                    Unit = "score"
                },
                new()
                {
                    Name = "kev",
                    Source = "cisa",
                    Type = RiskSignalType.Boolean,
                    Path = "/kev/inCatalog"
                },
                new()
                {
                    Name = "epss",
                    Source = "first",
                    Type = RiskSignalType.Numeric,
                    Path = "/epss/probability",
                    Unit = "probability"
                },
                new()
                {
                    Name = "reachability",
                    Source = "analysis",
                    Type = RiskSignalType.Categorical,
                    Path = "/reachability/status"
                },
                new()
                {
                    Name = "exploit_available",
                    Source = "exploit-db",
                    Type = RiskSignalType.Boolean,
                    Path = "/exploit/available"
                }
            },
            Weights = new Dictionary<string, double>
            {
                ["cvss_score"] = 0.3,
                ["kev"] = 0.25,
                ["epss"] = 0.2,
                ["reachability"] = 0.15,
                ["exploit_available"] = 0.1
            },
            Overrides = new RiskOverrides(),
            Metadata = new Dictionary<string, object?>
            {
                ["builtin"] = true,
                // NOTE(review): this timestamp changes on every process start. If the hasher
                // includes Metadata, ComputeHash of the built-in profile will differ between
                // runs. Confirm against RiskProfileHasher.
                ["created"] = DateTimeOffset.UtcNow.ToString("o")
            }
        };
    }

    /// <summary>
    /// Converts an options-bound profile definition into a <see cref="RiskProfileModel"/>.
    /// Weights and metadata are copied into new dictionaries. Overrides start empty.
    /// </summary>
    /// <exception cref="ArgumentException">A signal declares an unknown type.</exception>
    private static RiskProfileModel ConvertFromDefinition(RiskProfileDefinition definition)
    {
        return new RiskProfileModel
        {
            Id = definition.Id,
            Version = definition.Version,
            Description = definition.Description,
            Extends = definition.Extends,
            Signals = definition.Signals.Select(s => new RiskSignal
            {
                Name = s.Name,
                Source = s.Source,
                Type = ParseSignalType(s.Type),
                Path = s.Path,
                Transform = s.Transform,
                Unit = s.Unit
            }).ToList(),
            Weights = new Dictionary<string, double>(definition.Weights),
            Overrides = new RiskOverrides(),
            Metadata = definition.Metadata != null
                ? new Dictionary<string, object?>(definition.Metadata)
                : null
        };
    }

    /// <summary>
    /// Parses a signal type name, case-insensitively.
    /// Accepts "boolean"/"bool", "numeric"/"number" and "categorical"/"category".
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="type"/> is not a recognized name.</exception>
    private static RiskSignalType ParseSignalType(string type)
    {
        ArgumentNullException.ThrowIfNull(type);

        return type.ToLowerInvariant() switch
        {
            "boolean" or "bool" => RiskSignalType.Boolean,
            "numeric" or "number" => RiskSignalType.Numeric,
            "categorical" or "category" => RiskSignalType.Categorical,
            _ => throw new ArgumentException($"Unknown signal type: {type}", nameof(type))
        };
    }
}