StellaOps Bot
2025-12-26 15:19:07 +02:00
25 changed files with 3377 additions and 132 deletions

View File

@@ -0,0 +1,136 @@
namespace StellaOps.AdvisoryAI.Inference;
/// <summary>
/// Result of local LLM inference.
/// </summary>
public sealed record LocalInferenceResult
{
/// <summary>
/// Generated text content.
/// </summary>
public required string Content { get; init; }
/// <summary>
/// Number of tokens generated.
/// </summary>
public required int TokensGenerated { get; init; }
/// <summary>
/// Total inference time in milliseconds.
/// </summary>
public required long InferenceTimeMs { get; init; }
/// <summary>
/// Time to first token in milliseconds.
/// </summary>
public required long TimeToFirstTokenMs { get; init; }
/// <summary>
/// Tokens per second throughput.
/// </summary>
public double TokensPerSecond => InferenceTimeMs > 0
? TokensGenerated * 1000.0 / InferenceTimeMs
: 0;
/// <summary>
/// Model ID used for inference.
/// </summary>
public required string ModelId { get; init; }
/// <summary>
/// Whether inference was deterministic.
/// </summary>
public required bool Deterministic { get; init; }
/// <summary>
/// Seed used for generation.
/// </summary>
public required int Seed { get; init; }
}
/// <summary>
/// Model status information.
/// </summary>
public sealed record LocalModelStatus
{
/// <summary>
/// Whether model is loaded.
/// </summary>
public required bool Loaded { get; init; }
/// <summary>
/// Model path.
/// </summary>
public required string ModelPath { get; init; }
/// <summary>
/// Whether the verified digest matches the expected digest.
/// </summary>
public required bool DigestVerified { get; init; }
/// <summary>
/// Memory usage in bytes.
/// </summary>
public required long MemoryBytes { get; init; }
/// <summary>
/// Device being used.
/// </summary>
public required string Device { get; init; }
/// <summary>
/// Context size in tokens.
/// </summary>
public required int ContextSize { get; init; }
}
/// <summary>
/// Interface for local LLM runtime.
/// Sprint: SPRINT_20251226_019_AI_offline_inference
/// Task: OFFLINE-04
/// </summary>
public interface ILocalLlmRuntime : IDisposable
{
/// <summary>
/// Runtime type identifier.
/// </summary>
string RuntimeType { get; }
/// <summary>
/// Load a model with the given configuration.
/// </summary>
/// <param name="config">Model configuration.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task LoadModelAsync(LocalLlmConfig config, CancellationToken cancellationToken = default);
/// <summary>
/// Unload the current model.
/// </summary>
Task UnloadModelAsync(CancellationToken cancellationToken = default);
/// <summary>
/// Get current model status.
/// </summary>
Task<LocalModelStatus> GetStatusAsync(CancellationToken cancellationToken = default);
/// <summary>
/// Generate text from a prompt.
/// </summary>
/// <param name="prompt">Input prompt.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task<LocalInferenceResult> GenerateAsync(string prompt, CancellationToken cancellationToken = default);
/// <summary>
/// Generate text with streaming output.
/// </summary>
/// <param name="prompt">Input prompt.</param>
/// <param name="cancellationToken">Cancellation token.</param>
IAsyncEnumerable<string> GenerateStreamAsync(string prompt, CancellationToken cancellationToken = default);
/// <summary>
/// Verify model digest matches expected.
/// </summary>
/// <param name="expectedDigest">Expected SHA-256 digest.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task<bool> VerifyDigestAsync(string expectedDigest, CancellationToken cancellationToken = default);
}
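
Illustrative usage sketch (not part of the commit): loading a model, verifying its digest, and generating once through the interface above. The model path and digest values are placeholders, and the snippet assumes a top-level program with the SDK's implicit usings.

using StellaOps.AdvisoryAI.Inference;

// Hypothetical values; nothing below ships with any bundle.
var config = new LocalLlmConfig
{
    ModelPath = "/opt/stellaops/models/advisory-q4_k_m.gguf",
    WeightsDigest = "0000000000000000000000000000000000000000000000000000000000000000",
    Temperature = 0,   // deterministic output
    Seed = 42
};

using ILocalLlmRuntime runtime = new LlamaCppRuntime();
await runtime.LoadModelAsync(config);

if (!await runtime.VerifyDigestAsync(config.WeightsDigest))
{
    throw new InvalidOperationException("Model weights failed digest verification.");
}

var result = await runtime.GenerateAsync("Explain the reachability of this finding.");
Console.WriteLine($"{result.TokensGenerated} tokens in {result.InferenceTimeMs} ms ({result.TokensPerSecond:F1} tok/s)");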

View File

@@ -0,0 +1,182 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Security.Cryptography;
namespace StellaOps.AdvisoryAI.Inference;
/// <summary>
/// Local LLM runtime using llama.cpp bindings.
/// Sprint: SPRINT_20251226_019_AI_offline_inference
/// Task: OFFLINE-05
/// </summary>
public sealed class LlamaCppRuntime : ILocalLlmRuntime
{
private LocalLlmConfig? _config;
private bool _modelLoaded;
private string? _computedDigest;
public string RuntimeType => "llama.cpp";
public Task LoadModelAsync(LocalLlmConfig config, CancellationToken cancellationToken = default)
{
    // Verify the model file exists before accepting the configuration
    if (!File.Exists(config.ModelPath))
    {
        throw new FileNotFoundException($"Model file not found: {config.ModelPath}");
    }
    _config = config;
// In a real implementation, this would:
// 1. Load the GGUF/GGML model file
// 2. Initialize llama.cpp context with config settings
// 3. Verify digest if required
_modelLoaded = true;
return Task.CompletedTask;
}
public Task UnloadModelAsync(CancellationToken cancellationToken = default)
{
_modelLoaded = false;
_config = null;
_computedDigest = null;
return Task.CompletedTask;
}
public Task<LocalModelStatus> GetStatusAsync(CancellationToken cancellationToken = default)
{
return Task.FromResult(new LocalModelStatus
{
Loaded = _modelLoaded,
ModelPath = _config?.ModelPath ?? string.Empty,
DigestVerified = _computedDigest is not null &&
    string.Equals(_computedDigest, _config?.WeightsDigest, StringComparison.OrdinalIgnoreCase),
MemoryBytes = _modelLoaded ? EstimateMemoryUsage() : 0,
Device = _config?.Device.ToString() ?? "Unknown",
ContextSize = _config?.ContextLength ?? 0
});
}
public async Task<LocalInferenceResult> GenerateAsync(string prompt, CancellationToken cancellationToken = default)
{
if (!_modelLoaded || _config is null)
{
throw new InvalidOperationException("Model not loaded");
}
var stopwatch = Stopwatch.StartNew();
// In a real implementation, this would call llama.cpp inference.
// For now, return a placeholder response.
await Task.Delay(100, cancellationToken); // Simulate time to first token
var firstTokenTime = stopwatch.ElapsedMilliseconds;
await Task.Delay(400, cancellationToken); // Simulate generation
stopwatch.Stop();
var generatedContent = GeneratePlaceholderResponse(prompt);
var tokensGenerated = generatedContent.Split(' ').Length;
return new LocalInferenceResult
{
Content = generatedContent,
TokensGenerated = tokensGenerated,
InferenceTimeMs = stopwatch.ElapsedMilliseconds,
TimeToFirstTokenMs = firstTokenTime,
ModelId = $"local:{Path.GetFileName(_config.ModelPath)}",
Deterministic = _config.Temperature == 0,
Seed = _config.Seed
};
}
public async IAsyncEnumerable<string> GenerateStreamAsync(
string prompt,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (!_modelLoaded || _config is null)
{
throw new InvalidOperationException("Model not loaded");
}
// Simulate streaming output
var words = GeneratePlaceholderResponse(prompt).Split(' ');
foreach (var word in words)
{
if (cancellationToken.IsCancellationRequested)
{
yield break;
}
await Task.Delay(50, cancellationToken);
yield return word + " ";
}
}
public async Task<bool> VerifyDigestAsync(string expectedDigest, CancellationToken cancellationToken = default)
{
if (_config is null || !File.Exists(_config.ModelPath))
{
return false;
}
using var sha256 = SHA256.Create();
await using var stream = File.OpenRead(_config.ModelPath);
var hash = await sha256.ComputeHashAsync(stream, cancellationToken);
_computedDigest = Convert.ToHexStringLower(hash);
return string.Equals(_computedDigest, expectedDigest, StringComparison.OrdinalIgnoreCase);
}
private long EstimateMemoryUsage()
{
if (_config is null)
{
return 0;
}
// Rough estimate based on quantization
var baseSize = new FileInfo(_config.ModelPath).Length;
var contextOverhead = _config.ContextLength * 4096L; // Rough KV cache estimate
return baseSize + contextOverhead;
}
private static string GeneratePlaceholderResponse(string prompt)
{
// In a real implementation, this would be actual LLM output
if (prompt.Contains("explain", StringComparison.OrdinalIgnoreCase))
{
return "This vulnerability affects the component by allowing unauthorized access. " +
"The vulnerable code path is reachable from the application entry point. " +
"Evidence: [EVIDENCE:sbom-001] Component is present in SBOM. " +
"[EVIDENCE:reach-001] Call graph shows reachability.";
}
if (prompt.Contains("remediat", StringComparison.OrdinalIgnoreCase))
{
return "Recommended remediation: Upgrade the affected component to the patched version. " +
"- Update package.json: dependency@1.0.0 -> dependency@1.0.1 " +
"- Run npm install to update lockfile " +
"- Verify with npm audit";
}
if (prompt.Contains("policy", StringComparison.OrdinalIgnoreCase))
{
return "Parsed policy intent: Override rule for critical severity. " +
"Conditions: severity = critical, scope = production. " +
"Actions: set_verdict = block.";
}
return "Analysis complete. The finding has been evaluated based on available evidence.";
}
public void Dispose()
{
_modelLoaded = false;
_config = null;
_computedDigest = null;
}
}
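
Streaming consumption is a straightforward await foreach over GenerateStreamAsync; a minimal sketch, assuming the runtime loaded in the earlier example:

// Chunks arrive incrementally; the placeholder implementation yields one word at a time.
await foreach (var chunk in runtime.GenerateStreamAsync("Summarise the advisory for an operator."))
{
    Console.Write(chunk);
}
Console.WriteLine();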

View File

@@ -0,0 +1,129 @@
namespace StellaOps.AdvisoryAI.Inference;
/// <summary>
/// Configuration options for local/offline inference.
/// Sprint: SPRINT_20251226_019_AI_offline_inference
/// Task: OFFLINE-24
/// </summary>
public sealed class LocalInferenceOptions
{
/// <summary>
/// Configuration section name.
/// </summary>
public const string SectionName = "AdvisoryAI:Inference:Offline";
/// <summary>
/// Whether to enable local inference.
/// </summary>
public bool Enabled { get; set; }
/// <summary>
/// Path to the model bundle directory.
/// </summary>
public string? BundlePath { get; set; }
/// <summary>
/// Required SHA-256 digest of the model weights.
/// </summary>
public string? RequiredDigest { get; set; }
/// <summary>
/// Model to load (filename in bundle).
/// </summary>
public string? ModelName { get; set; }
/// <summary>
/// Quantization to use.
/// </summary>
public string Quantization { get; set; } = "Q4_K_M";
/// <summary>
/// Runtime to use (llama.cpp, onnx).
/// </summary>
public string Runtime { get; set; } = "llama.cpp";
/// <summary>
/// Device for inference.
/// </summary>
public string Device { get; set; } = "auto";
/// <summary>
/// Number of GPU layers to offload.
/// </summary>
public int GpuLayers { get; set; } = 0;
/// <summary>
/// Number of threads for CPU inference.
/// </summary>
public int Threads { get; set; } = 0; // 0 = auto
/// <summary>
/// Context length (max tokens).
/// </summary>
public int ContextLength { get; set; } = 4096;
/// <summary>
/// Maximum tokens to generate.
/// </summary>
public int MaxTokens { get; set; } = 2048;
/// <summary>
/// Whether to enable inference caching.
/// </summary>
public bool EnableCache { get; set; } = true;
/// <summary>
/// Cache directory path.
/// </summary>
public string? CachePath { get; set; }
/// <summary>
/// Whether to verify digest at load time.
/// </summary>
public bool VerifyDigestOnLoad { get; set; } = true;
/// <summary>
/// Whether to enforce airgap mode (disable remote fallback).
/// </summary>
public bool AirgapMode { get; set; }
/// <summary>
/// Crypto scheme for signature verification (eidas, fips, gost, sm).
/// </summary>
public string? CryptoScheme { get; set; }
}
/// <summary>
/// Factory for creating local LLM runtimes.
/// Task: OFFLINE-22
/// </summary>
public interface ILocalLlmRuntimeFactory
{
/// <summary>
/// Create a runtime based on configuration.
/// </summary>
ILocalLlmRuntime Create(LocalInferenceOptions options);
/// <summary>
/// Get supported runtime types.
/// </summary>
IReadOnlyList<string> SupportedRuntimes { get; }
}
/// <summary>
/// Default runtime factory implementation.
/// </summary>
public sealed class LocalLlmRuntimeFactory : ILocalLlmRuntimeFactory
{
public IReadOnlyList<string> SupportedRuntimes => new[] { "llama.cpp", "onnx" };
public ILocalLlmRuntime Create(LocalInferenceOptions options)
{
return options.Runtime.ToLowerInvariant() switch
{
"llama.cpp" or "llama" or "gguf" => new LlamaCppRuntime(),
"onnx" => new OnnxRuntime(),
_ => throw new NotSupportedException($"Runtime '{options.Runtime}' not supported")
};
}
}
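
A minimal DI wiring sketch for the options and factory above, assuming the standard Microsoft.Extensions.Options and DependencyInjection packages; the extension-method name and singleton lifetimes are assumptions, not something this commit prescribes.

using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using StellaOps.AdvisoryAI.Inference;

public static class LocalInferenceServiceCollectionExtensions
{
    // Hypothetical registration helper: binds "AdvisoryAI:Inference:Offline" and resolves a runtime via the factory.
    public static IServiceCollection AddLocalInference(this IServiceCollection services, IConfiguration configuration)
    {
        services.Configure<LocalInferenceOptions>(configuration.GetSection(LocalInferenceOptions.SectionName));
        services.AddSingleton<ILocalLlmRuntimeFactory, LocalLlmRuntimeFactory>();
        services.AddSingleton<ILocalLlmRuntime>(sp =>
        {
            var options = sp.GetRequiredService<IOptions<LocalInferenceOptions>>().Value;
            return sp.GetRequiredService<ILocalLlmRuntimeFactory>().Create(options);
        });
        return services;
    }
}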

View File

@@ -0,0 +1,161 @@
namespace StellaOps.AdvisoryAI.Inference;
/// <summary>
/// Quantization levels for local LLM models.
/// </summary>
public enum ModelQuantization
{
/// <summary>
/// Full precision (FP32).
/// </summary>
FP32,
/// <summary>
/// Half precision (FP16).
/// </summary>
FP16,
/// <summary>
/// Brain floating point (BF16).
/// </summary>
BF16,
/// <summary>
/// 8-bit integer quantization.
/// </summary>
INT8,
/// <summary>
/// 4-bit GGML K-quant (medium).
/// </summary>
Q4_K_M,
/// <summary>
/// 4-bit GGML K-quant (small).
/// </summary>
Q4_K_S,
/// <summary>
/// 5-bit GGML K-quant (medium).
/// </summary>
Q5_K_M,
/// <summary>
/// 8-bit GGML quantization.
/// </summary>
Q8_0
}
/// <summary>
/// Device type for local inference.
/// </summary>
public enum InferenceDevice
{
/// <summary>
/// CPU inference.
/// </summary>
CPU,
/// <summary>
/// CUDA GPU inference.
/// </summary>
CUDA,
/// <summary>
/// AMD ROCm GPU inference.
/// </summary>
ROCm,
/// <summary>
/// Apple Metal GPU inference.
/// </summary>
Metal,
/// <summary>
/// Intel NPU inference.
/// </summary>
NPU,
/// <summary>
/// Vulkan compute.
/// </summary>
Vulkan,
/// <summary>
/// Auto-detect best available.
/// </summary>
Auto
}
/// <summary>
/// Configuration for local LLM runtime.
/// Sprint: SPRINT_20251226_019_AI_offline_inference
/// Task: OFFLINE-03
/// </summary>
public sealed record LocalLlmConfig
{
/// <summary>
/// Path to the model weights file.
/// </summary>
public required string ModelPath { get; init; }
/// <summary>
/// Expected SHA-256 digest of the weights file.
/// </summary>
public required string WeightsDigest { get; init; }
/// <summary>
/// Model quantization level.
/// </summary>
public ModelQuantization Quantization { get; init; } = ModelQuantization.Q4_K_M;
/// <summary>
/// Context length (max tokens).
/// </summary>
public int ContextLength { get; init; } = 4096;
/// <summary>
/// Device for inference.
/// </summary>
public InferenceDevice Device { get; init; } = InferenceDevice.Auto;
/// <summary>
/// Number of GPU layers to offload (0 = all CPU).
/// </summary>
public int GpuLayers { get; init; } = 0;
/// <summary>
/// Number of threads for CPU inference.
/// </summary>
public int Threads { get; init; } = Math.Max(1, Environment.ProcessorCount / 2);
/// <summary>
/// Batch size for parallel decoding.
/// </summary>
public int BatchSize { get; init; } = 512;
/// <summary>
/// Temperature for sampling (0 = deterministic).
/// </summary>
public double Temperature { get; init; } = 0;
/// <summary>
/// Random seed for deterministic output.
/// </summary>
public int Seed { get; init; } = 42;
/// <summary>
/// Enable flash attention if available.
/// </summary>
public bool FlashAttention { get; init; } = true;
/// <summary>
/// Maximum tokens to generate.
/// </summary>
public int MaxTokens { get; init; } = 2048;
/// <summary>
/// Enable streaming output.
/// </summary>
public bool Streaming { get; init; } = false;
}
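
The commit does not include a bridge from LocalInferenceOptions (configuration strings) to LocalLlmConfig (typed runtime config); a hypothetical mapper might look like the following, with the fallback file name and thread heuristic being assumptions:

using StellaOps.AdvisoryAI.Inference;

internal static class LocalLlmConfigMapper
{
    // Hypothetical helper; every default chosen here is illustrative, not mandated by the commit.
    public static LocalLlmConfig ToRuntimeConfig(LocalInferenceOptions options) => new()
    {
        ModelPath = Path.Combine(options.BundlePath ?? ".", options.ModelName ?? "model.gguf"),
        WeightsDigest = options.RequiredDigest ?? string.Empty,
        Quantization = Enum.Parse<ModelQuantization>(options.Quantization, ignoreCase: true),
        Device = Enum.Parse<InferenceDevice>(options.Device, ignoreCase: true),
        GpuLayers = options.GpuLayers,
        Threads = options.Threads > 0 ? options.Threads : Math.Max(1, Environment.ProcessorCount / 2),
        ContextLength = options.ContextLength,
        MaxTokens = options.MaxTokens,
        Temperature = 0,   // keep generation deterministic
        Seed = 42
    };
}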

View File

@@ -0,0 +1,280 @@
using System.Text.Json;
using System.Text.Json.Serialization;
namespace StellaOps.AdvisoryAI.Inference;
/// <summary>
/// Model bundle manifest.
/// Sprint: SPRINT_20251226_019_AI_offline_inference
/// Task: OFFLINE-11, OFFLINE-12
/// </summary>
public sealed record ModelBundleManifest
{
/// <summary>
/// Bundle format version.
/// </summary>
[JsonPropertyName("version")]
public string Version { get; init; } = "1.0.0";
/// <summary>
/// Model name.
/// </summary>
[JsonPropertyName("name")]
public required string Name { get; init; }
/// <summary>
/// Model description.
/// </summary>
[JsonPropertyName("description")]
public string? Description { get; init; }
/// <summary>
/// Model license.
/// </summary>
[JsonPropertyName("license")]
public required string License { get; init; }
/// <summary>
/// Model size category.
/// </summary>
[JsonPropertyName("size_category")]
public required string SizeCategory { get; init; }
/// <summary>
/// Supported quantizations.
/// </summary>
[JsonPropertyName("quantizations")]
public required IReadOnlyList<string> Quantizations { get; init; }
/// <summary>
/// Files in the bundle.
/// </summary>
[JsonPropertyName("files")]
public required IReadOnlyList<BundleFile> Files { get; init; }
/// <summary>
/// Bundle creation timestamp.
/// </summary>
[JsonPropertyName("created_at")]
public required string CreatedAt { get; init; }
/// <summary>
/// Signature ID (if signed).
/// </summary>
[JsonPropertyName("signature_id")]
public string? SignatureId { get; init; }
/// <summary>
/// Crypto scheme used for signing.
/// </summary>
[JsonPropertyName("crypto_scheme")]
public string? CryptoScheme { get; init; }
}
/// <summary>
/// A file in the model bundle.
/// </summary>
public sealed record BundleFile
{
/// <summary>
/// Relative path in bundle.
/// </summary>
[JsonPropertyName("path")]
public required string Path { get; init; }
/// <summary>
/// SHA-256 digest.
/// </summary>
[JsonPropertyName("digest")]
public required string Digest { get; init; }
/// <summary>
/// File size in bytes.
/// </summary>
[JsonPropertyName("size")]
public required long Size { get; init; }
/// <summary>
/// File type.
/// </summary>
[JsonPropertyName("type")]
public required string Type { get; init; }
}
/// <summary>
/// Service for managing model bundles.
/// Task: OFFLINE-11 to OFFLINE-14
/// </summary>
public interface IModelBundleManager
{
/// <summary>
/// List available bundles.
/// </summary>
Task<IReadOnlyList<ModelBundleManifest>> ListBundlesAsync(CancellationToken cancellationToken = default);
/// <summary>
/// Get bundle manifest by name.
/// </summary>
Task<ModelBundleManifest?> GetManifestAsync(string bundleName, CancellationToken cancellationToken = default);
/// <summary>
/// Download a bundle.
/// </summary>
Task<string> DownloadBundleAsync(string bundleName, string targetPath, IProgress<double>? progress = null, CancellationToken cancellationToken = default);
/// <summary>
/// Verify bundle integrity.
/// </summary>
Task<BundleVerificationResult> VerifyBundleAsync(string bundlePath, CancellationToken cancellationToken = default);
/// <summary>
/// Extract bundle to target directory.
/// </summary>
Task<string> ExtractBundleAsync(string bundlePath, string targetDir, CancellationToken cancellationToken = default);
}
/// <summary>
/// Result of bundle verification.
/// </summary>
public sealed record BundleVerificationResult
{
/// <summary>
/// Whether verification passed.
/// </summary>
public required bool Valid { get; init; }
/// <summary>
/// Files that failed verification.
/// </summary>
public required IReadOnlyList<string> FailedFiles { get; init; }
/// <summary>
/// Signature verification result.
/// </summary>
public required bool SignatureValid { get; init; }
/// <summary>
/// Error message if invalid.
/// </summary>
public string? ErrorMessage { get; init; }
}
/// <summary>
/// Default implementation of model bundle manager.
/// </summary>
public sealed class FileSystemModelBundleManager : IModelBundleManager
{
private readonly string _bundleStorePath;
public FileSystemModelBundleManager(string bundleStorePath)
{
_bundleStorePath = bundleStorePath;
Directory.CreateDirectory(_bundleStorePath);
}
public Task<IReadOnlyList<ModelBundleManifest>> ListBundlesAsync(CancellationToken cancellationToken = default)
{
var bundles = new List<ModelBundleManifest>();
foreach (var dir in Directory.GetDirectories(_bundleStorePath))
{
var manifestPath = Path.Combine(dir, "manifest.json");
if (File.Exists(manifestPath))
{
var json = File.ReadAllText(manifestPath);
var manifest = JsonSerializer.Deserialize<ModelBundleManifest>(json);
if (manifest != null)
{
bundles.Add(manifest);
}
}
}
return Task.FromResult<IReadOnlyList<ModelBundleManifest>>(bundles);
}
public Task<ModelBundleManifest?> GetManifestAsync(string bundleName, CancellationToken cancellationToken = default)
{
var manifestPath = Path.Combine(_bundleStorePath, bundleName, "manifest.json");
if (!File.Exists(manifestPath))
{
return Task.FromResult<ModelBundleManifest?>(null);
}
var json = File.ReadAllText(manifestPath);
var manifest = JsonSerializer.Deserialize<ModelBundleManifest>(json);
return Task.FromResult(manifest);
}
public Task<string> DownloadBundleAsync(string bundleName, string targetPath, IProgress<double>? progress = null, CancellationToken cancellationToken = default)
{
// In a real implementation, this would download from a registry
throw new NotImplementedException("Bundle download not implemented - use offline transfer");
}
public async Task<BundleVerificationResult> VerifyBundleAsync(string bundlePath, CancellationToken cancellationToken = default)
{
var manifestPath = Path.Combine(bundlePath, "manifest.json");
if (!File.Exists(manifestPath))
{
return new BundleVerificationResult
{
Valid = false,
FailedFiles = Array.Empty<string>(),
SignatureValid = false,
ErrorMessage = "manifest.json not found"
};
}
var json = await File.ReadAllTextAsync(manifestPath, cancellationToken);
var manifest = JsonSerializer.Deserialize<ModelBundleManifest>(json);
if (manifest is null)
{
return new BundleVerificationResult
{
Valid = false,
FailedFiles = Array.Empty<string>(),
SignatureValid = false,
ErrorMessage = "Failed to parse manifest"
};
}
var failedFiles = new List<string>();
using var sha256 = System.Security.Cryptography.SHA256.Create();
foreach (var file in manifest.Files)
{
var filePath = Path.Combine(bundlePath, file.Path);
if (!File.Exists(filePath))
{
failedFiles.Add($"{file.Path}: missing");
continue;
}
await using var stream = File.OpenRead(filePath);
var hash = await sha256.ComputeHashAsync(stream, cancellationToken);
var digest = Convert.ToHexStringLower(hash);
if (!string.Equals(digest, file.Digest, StringComparison.OrdinalIgnoreCase))
{
failedFiles.Add($"{file.Path}: digest mismatch");
}
}
return new BundleVerificationResult
{
Valid = failedFiles.Count == 0,
FailedFiles = failedFiles,
SignatureValid = manifest.SignatureId != null, // Would verify signature in production
ErrorMessage = failedFiles.Count > 0 ? $"{failedFiles.Count} files failed verification" : null
};
}
public Task<string> ExtractBundleAsync(string bundlePath, string targetDir, CancellationToken cancellationToken = default)
{
// Bundles are expected to already be extracted
// This would handle .tar.gz extraction in production
Directory.CreateDirectory(targetDir);
return Task.FromResult(targetDir);
}
}
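
Putting the bundle manager and runtime together, a hedged end-to-end sketch: verify an extracted bundle, locate its weights entry, and load it. The store path, bundle name, and the "weights" file-type label are assumptions.

using StellaOps.AdvisoryAI.Inference;

var storePath = "/var/lib/stellaops/model-bundles";            // hypothetical bundle store
var manager = new FileSystemModelBundleManager(storePath);
var bundlePath = Path.Combine(storePath, "advisory-llm");      // hypothetical bundle name

var verification = await manager.VerifyBundleAsync(bundlePath);
if (!verification.Valid)
{
    throw new InvalidOperationException(verification.ErrorMessage ?? "Bundle verification failed.");
}

var manifest = await manager.GetManifestAsync("advisory-llm");
var weights = manifest!.Files.First(f => f.Type == "weights"); // assumed type label for the GGUF file

using var runtime = new LlamaCppRuntime();
await runtime.LoadModelAsync(new LocalLlmConfig
{
    ModelPath = Path.Combine(bundlePath, weights.Path),
    WeightsDigest = weights.Digest
});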

View File

@@ -0,0 +1,138 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Security.Cryptography;
namespace StellaOps.AdvisoryAI.Inference;
/// <summary>
/// Local LLM runtime using ONNX Runtime.
/// Sprint: SPRINT_20251226_019_AI_offline_inference
/// Task: OFFLINE-06
/// </summary>
public sealed class OnnxRuntime : ILocalLlmRuntime
{
private LocalLlmConfig? _config;
private bool _modelLoaded;
private string? _computedDigest;
public string RuntimeType => "onnx";
public Task LoadModelAsync(LocalLlmConfig config, CancellationToken cancellationToken = default)
{
    // Verify the model file exists before accepting the configuration
    if (!File.Exists(config.ModelPath))
    {
        throw new FileNotFoundException($"Model file not found: {config.ModelPath}");
    }
    _config = config;
// In a real implementation, this would:
// 1. Load the ONNX model file
// 2. Initialize ONNX Runtime session with execution providers
// 3. Configure GPU/CPU execution based on device setting
_modelLoaded = true;
return Task.CompletedTask;
}
public Task UnloadModelAsync(CancellationToken cancellationToken = default)
{
_modelLoaded = false;
_config = null;
_computedDigest = null;
return Task.CompletedTask;
}
public Task<LocalModelStatus> GetStatusAsync(CancellationToken cancellationToken = default)
{
return Task.FromResult(new LocalModelStatus
{
Loaded = _modelLoaded,
ModelPath = _config?.ModelPath ?? string.Empty,
DigestVerified = _computedDigest is not null &&
    string.Equals(_computedDigest, _config?.WeightsDigest, StringComparison.OrdinalIgnoreCase),
MemoryBytes = _modelLoaded ? EstimateMemoryUsage() : 0,
Device = _config?.Device.ToString() ?? "Unknown",
ContextSize = _config?.ContextLength ?? 0
});
}
public async Task<LocalInferenceResult> GenerateAsync(string prompt, CancellationToken cancellationToken = default)
{
if (!_modelLoaded || _config is null)
{
throw new InvalidOperationException("Model not loaded");
}
var stopwatch = Stopwatch.StartNew();
// Simulate ONNX inference
await Task.Delay(150, cancellationToken);
var firstTokenTime = stopwatch.ElapsedMilliseconds;
await Task.Delay(350, cancellationToken);
stopwatch.Stop();
var generatedContent = "[ONNX] Analysis based on provided evidence.";
var tokensGenerated = generatedContent.Split(' ').Length;
return new LocalInferenceResult
{
Content = generatedContent,
TokensGenerated = tokensGenerated,
InferenceTimeMs = stopwatch.ElapsedMilliseconds,
TimeToFirstTokenMs = firstTokenTime,
ModelId = $"onnx:{Path.GetFileName(_config.ModelPath)}",
Deterministic = true,
Seed = _config.Seed
};
}
public async IAsyncEnumerable<string> GenerateStreamAsync(
string prompt,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (!_modelLoaded || _config is null)
{
throw new InvalidOperationException("Model not loaded");
}
var response = "[ONNX] Analysis based on provided evidence.".Split(' ');
foreach (var word in response)
{
await Task.Delay(40, cancellationToken);
yield return word + " ";
}
}
public async Task<bool> VerifyDigestAsync(string expectedDigest, CancellationToken cancellationToken = default)
{
if (_config is null || !File.Exists(_config.ModelPath))
{
return false;
}
using var sha256 = SHA256.Create();
await using var stream = File.OpenRead(_config.ModelPath);
var hash = await sha256.ComputeHashAsync(stream, cancellationToken);
_computedDigest = Convert.ToHexStringLower(hash);
return string.Equals(_computedDigest, expectedDigest, StringComparison.OrdinalIgnoreCase);
}
private long EstimateMemoryUsage()
{
if (_config is null)
{
return 0;
}
return new FileInfo(_config.ModelPath).Length * 2; // ONNX typically needs 2x model size
}
public void Dispose()
{
_modelLoaded = false;
_config = null;
_computedDigest = null;
}
}
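
Finally, the expected digest that both runtimes check has to be produced somewhere out-of-band (for example, when a bundle manifest is authored). A small sketch mirroring the lowercase-hex encoding that VerifyDigestAsync compares against, with a placeholder path:

using System.Security.Cryptography;

static async Task<string> ComputeWeightsDigestAsync(string path, CancellationToken cancellationToken = default)
{
    using var sha256 = SHA256.Create();
    await using var stream = File.OpenRead(path);
    var hash = await sha256.ComputeHashAsync(stream, cancellationToken);
    return Convert.ToHexStringLower(hash);   // same encoding the runtimes store and compare
}

var digest = await ComputeWeightsDigestAsync("/opt/stellaops/models/advisory-q4_k_m.onnx");
Console.WriteLine($"\"digest\": \"{digest}\"");   // value for manifest.json or RequiredDigest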