Files
git.stella-ops.org/src/AdvisoryAI/StellaOps.AdvisoryAI/UnifiedSearch/WeightedRrfFusion.cs

246 lines
9.0 KiB
C#

using StellaOps.AdvisoryAI.KnowledgeSearch;
namespace StellaOps.AdvisoryAI.UnifiedSearch;
internal static class WeightedRrfFusion
{
private const int ReciprocalRankConstant = 60;
private const double EntityProximityBoost = 0.8;
private const double MaxFreshnessBoost = 0.05;
private const int FreshnessDaysCap = 365;
/// <summary>
/// Fuses lexical and vector rankings into a single ordered result list using
/// domain-weighted reciprocal-rank fusion, then adds the entity-proximity,
/// freshness and popularity boosts to each fused score.
/// </summary>
/// <param name="domainWeights">Multiplier per domain name; rows whose domain has no entry are weighted 1.0.</param>
/// <param name="lexicalRanks">Lexical hits; only the values (chunk id, rank, row) are read.</param>
/// <param name="vectorRanks">Vector hits with their rank and score.</param>
/// <param name="query">Not read by this method.</param>
/// <param name="filters">Not read by this method.</param>
/// <param name="detectedEntities">Entity mentions used for the entity-proximity boost; null or empty disables it.</param>
/// <param name="enableFreshnessBoost">When true, rows carrying a parseable <c>freshness</c> metadata value receive a freshness boost.</param>
/// <param name="referenceTime">Instant against which freshness is measured; defaults to the Unix epoch when omitted.</param>
/// <param name="popularityMap">Click counts keyed by entity key; null or empty disables the popularity boost.</param>
/// <param name="popularityBoostWeight">Multiplier applied to the logarithmic popularity boost; values &lt;= 0 disable it.</param>
/// <returns>
/// Rows ordered by descending fused score, ties broken by <c>Kind</c> and then <c>ChunkId</c>
/// (both ordinal), each paired with a dictionary of debug values describing how the score was built.
/// </returns>
public static IReadOnlyList<(KnowledgeChunkRow Row, double Score, IReadOnlyDictionary<string, string> Debug)> Fuse(
    IReadOnlyDictionary<string, double> domainWeights,
    IReadOnlyDictionary<string, (string ChunkId, int Rank, KnowledgeChunkRow Row)> lexicalRanks,
    IReadOnlyList<(KnowledgeChunkRow Row, int Rank, double Score)> vectorRanks,
    string query,
    UnifiedSearchFilter? filters,
    IReadOnlyList<EntityMention>? detectedEntities = null,
    bool enableFreshnessBoost = false,
    DateTimeOffset? referenceTime = null,
    IReadOnlyDictionary<string, int>? popularityMap = null,
    double popularityBoostWeight = 0.0)
{
    // All debug values are machine-readable, so every number is formatted culture-invariantly.
    var invariant = System.Globalization.CultureInfo.InvariantCulture;
    var merged = new Dictionary<string, (KnowledgeChunkRow Row, double Score, Dictionary<string, string> Debug)>(StringComparer.Ordinal);

    // Pass 1: seed the merged set with the lexical hits.
    foreach (var lexical in lexicalRanks.Values)
    {
        var domainWeight = GetDomainWeight(domainWeights, lexical.Row);
        var debug = new Dictionary<string, string>(StringComparer.Ordinal)
        {
            ["lexicalRank"] = lexical.Rank.ToString(invariant),
            ["lexicalScore"] = lexical.Row.LexicalScore.ToString("F6", invariant),
            ["domainWeight"] = domainWeight.ToString("F4", invariant)
        };
        merged[lexical.ChunkId] = (lexical.Row, domainWeight * ReciprocalRank(lexical.Rank), debug);
    }

    // Pass 2: add the vector contribution, creating an entry for vector-only hits.
    foreach (var vector in vectorRanks)
    {
        // Resolved once per row; the lookup inspects the row's JSON metadata.
        var domainWeight = GetDomainWeight(domainWeights, vector.Row);
        if (!merged.TryGetValue(vector.Row.ChunkId, out var entry))
        {
            entry = (vector.Row, 0d, new Dictionary<string, string>(StringComparer.Ordinal)
            {
                ["domainWeight"] = domainWeight.ToString("F4", invariant)
            });
        }

        entry.Score += domainWeight * ReciprocalRank(vector.Rank);
        entry.Debug["vectorRank"] = vector.Rank.ToString(invariant);
        entry.Debug["vectorScore"] = vector.Score.ToString("F6", invariant);

        // The tuple is a value type, so the updated copy has to be written back.
        merged[vector.Row.ChunkId] = entry;
    }

    // NOTE(review): with no explicit reference time the Unix epoch is used; negative ages are
    // clamped to zero by the freshness computation, so every row with a parseable freshness
    // value then receives the maximum boost. Confirm this fallback is intended.
    var freshnessReference = referenceTime ?? DateTimeOffset.UnixEpoch;

    // Pass 3: apply the additive boosts and record them in the debug dictionary.
    var boosted = new List<(KnowledgeChunkRow Row, double Score, Dictionary<string, string> Debug)>(merged.Count);
    foreach (var item in merged.Values)
    {
        var entityBoost = ComputeEntityProximityBoost(item.Row, detectedEntities);
        var freshnessBoost = enableFreshnessBoost
            ? ComputeFreshnessBoost(item.Row, freshnessReference)
            : 0d;
        var popBoost = ComputePopularityBoost(item.Row, popularityMap, popularityBoostWeight);
        var totalBoost = entityBoost + freshnessBoost + popBoost;

        item.Debug["entityBoost"] = entityBoost.ToString("F6", invariant);
        item.Debug["freshnessBoost"] = freshnessBoost.ToString("F6", invariant);
        item.Debug["popularityBoost"] = popBoost.ToString("F6", invariant);
        item.Debug["chunkId"] = item.Row.ChunkId;

        boosted.Add((item.Row, item.Score + totalBoost, item.Debug));
    }

    // Kind and ChunkId tie-breakers keep the output order deterministic for equal scores.
    return boosted
        .OrderByDescending(static item => item.Score)
        .ThenBy(static item => item.Row.Kind, StringComparer.Ordinal)
        .ThenBy(static item => item.Row.ChunkId, StringComparer.Ordinal)
        .Select(static item => (item.Row, item.Score, (IReadOnlyDictionary<string, string>)item.Debug))
        .ToArray();
}
/// <summary>
/// Returns the reciprocal-rank contribution <c>1 / (ReciprocalRankConstant + rank)</c>
/// for a positive rank, and zero for any rank that is zero or negative.
/// </summary>
private static double ReciprocalRank(int rank)
    => rank > 0
        ? 1d / (ReciprocalRankConstant + rank)
        : 0d;
/// <summary>
/// Looks up the weight configured for the row's domain, falling back to a neutral
/// weight of 1.0 when the domain has no entry in <paramref name="domainWeights"/>.
/// </summary>
private static double GetDomainWeight(IReadOnlyDictionary<string, double> domainWeights, KnowledgeChunkRow row)
{
    if (domainWeights.TryGetValue(GetRowDomain(row), out var configured))
    {
        return configured;
    }

    return 1.0;
}
/// <summary>
/// Resolves the domain name of a row: an explicit string <c>domain</c> property in the
/// row metadata wins; otherwise the domain is derived from the row's <c>Kind</c>,
/// defaulting to <c>knowledge</c>.
/// </summary>
private static string GetRowDomain(KnowledgeChunkRow row)
{
    var metadata = row.Metadata.RootElement;

    // TryGetProperty throws InvalidOperationException on anything but a JSON object,
    // so non-object metadata must fall through to the Kind-based mapping.
    if (metadata.ValueKind == System.Text.Json.JsonValueKind.Object &&
        metadata.TryGetProperty("domain", out var domainProp) &&
        domainProp.ValueKind == System.Text.Json.JsonValueKind.String)
    {
        return domainProp.GetString() ?? "knowledge";
    }

    return row.Kind switch
    {
        "finding" => "findings",
        "vex_statement" => "vex",
        "policy_rule" => "policy",
        "platform_entity" => "platform",
        // Listed explicitly to document the known knowledge kinds; unknown kinds map the same way.
        "md_section" or "api_operation" or "doctor_check" => "knowledge",
        _ => "knowledge"
    };
}
/// <summary>
/// Returns <c>EntityProximityBoost</c> when any detected entity mention matches the row's
/// metadata: either the <c>entity_key</c> contains the mention value, or the <c>cveId</c>
/// equals it (both case-insensitive, ordinal). Returns zero otherwise.
/// </summary>
/// <param name="row">Row whose JSON metadata is inspected.</param>
/// <param name="detectedEntities">Mentions detected in the query; null or empty yields no boost.</param>
private static double ComputeEntityProximityBoost(
    KnowledgeChunkRow row,
    IReadOnlyList<EntityMention>? detectedEntities)
{
    if (detectedEntities is not { Count: > 0 })
    {
        return 0d;
    }

    var metadata = row.Metadata.RootElement;
    if (metadata.ValueKind != System.Text.Json.JsonValueKind.Object)
    {
        return 0d;
    }

    var entityKey = ReadNonBlankString(metadata, "entity_key");
    var cveId = ReadNonBlankString(metadata, "cveId");
    if (entityKey is null && cveId is null)
    {
        return 0d;
    }

    foreach (var mention in detectedEntities)
    {
        var mentionValue = mention.Value;

        // A blank mention carries no entity: an empty string would make Contains succeed
        // for every entity key, and a null value would throw.
        if (string.IsNullOrWhiteSpace(mentionValue))
        {
            continue;
        }

        if (entityKey is not null &&
            entityKey.Contains(mentionValue, StringComparison.OrdinalIgnoreCase))
        {
            return EntityProximityBoost;
        }

        if (cveId is not null &&
            cveId.Equals(mentionValue, StringComparison.OrdinalIgnoreCase))
        {
            return EntityProximityBoost;
        }
    }

    return 0d;

    // Reads a string-valued property, treating missing, non-string and blank values as absent.
    static string? ReadNonBlankString(System.Text.Json.JsonElement element, string propertyName)
    {
        if (element.TryGetProperty(propertyName, out var property) &&
            property.ValueKind == System.Text.Json.JsonValueKind.String)
        {
            var text = property.GetString();
            return string.IsNullOrWhiteSpace(text) ? null : text;
        }

        return null;
    }
}
/// <summary>
/// Computes a freshness boost that decays linearly from <c>MaxFreshnessBoost</c> (age zero)
/// to zero (age of <c>FreshnessDaysCap</c> days or more), based on the row's string
/// <c>freshness</c> metadata property.
/// </summary>
/// <param name="row">Row whose JSON metadata is inspected.</param>
/// <param name="referenceTime">Instant the age is measured against.</param>
/// <returns>
/// Zero when the metadata is not an object, the property is missing or not a string, the value
/// cannot be parsed as a date, or the age reaches the cap; otherwise the decayed boost.
/// </returns>
private static double ComputeFreshnessBoost(KnowledgeChunkRow row, DateTimeOffset referenceTime)
{
    var metadata = row.Metadata.RootElement;
    if (metadata.ValueKind != System.Text.Json.JsonValueKind.Object)
    {
        return 0d;
    }

    if (!metadata.TryGetProperty("freshness", out var freshnessProp) ||
        freshnessProp.ValueKind != System.Text.Json.JsonValueKind.String)
    {
        return 0d;
    }

    // Parse culture-invariantly and treat offset-less values as UTC so the result does not
    // depend on the host's regional settings or local time zone.
    if (!DateTimeOffset.TryParse(
            freshnessProp.GetString(),
            System.Globalization.CultureInfo.InvariantCulture,
            System.Globalization.DateTimeStyles.AssumeUniversal,
            out var freshness))
    {
        return 0d;
    }

    // Timestamps later than the reference time count as age zero (full boost).
    var daysSinceFresh = Math.Max(0d, (referenceTime - freshness).TotalDays);
    if (daysSinceFresh >= FreshnessDaysCap)
    {
        return 0d;
    }

    return MaxFreshnessBoost * (1d - daysSinceFresh / FreshnessDaysCap);
}
/// <summary>
/// Computes an additive popularity boost based on click-through frequency.
/// Uses a logarithmic function to provide diminishing returns for very popular items,
/// preventing feedback loops.
/// </summary>
/// <param name="row">Row whose metadata <c>entity_key</c> is used to look up the click count.</param>
/// <param name="popularityMap">Click counts keyed by entity key; null or empty yields no boost.</param>
/// <param name="popularityBoostWeight">Multiplier for the logarithmic term; values &lt;= 0 yield no boost.</param>
/// <returns>
/// <c>log2(1 + clickCount) * popularityBoostWeight</c>, or zero when the row has no usable
/// entity key, the key is not in the map, or its click count is not positive.
/// </returns>
private static double ComputePopularityBoost(
    KnowledgeChunkRow row,
    IReadOnlyDictionary<string, int>? popularityMap,
    double popularityBoostWeight)
{
    if (popularityMap is null || popularityMap.Count == 0 || popularityBoostWeight <= 0d)
    {
        return 0d;
    }

    var metadata = row.Metadata.RootElement;
    if (metadata.ValueKind != System.Text.Json.JsonValueKind.Object)
    {
        return 0d;
    }

    string? entityKey = null;
    if (metadata.TryGetProperty("entity_key", out var entityKeyProp) &&
        entityKeyProp.ValueKind == System.Text.Json.JsonValueKind.String)
    {
        entityKey = entityKeyProp.GetString();
    }

    if (string.IsNullOrWhiteSpace(entityKey))
    {
        return 0d;
    }

    if (!popularityMap.TryGetValue(entityKey, out var clickCount) || clickCount <= 0)
    {
        return 0d;
    }

    // Logarithmic boost: log2(1 + clickCount) * weight.
    // The addition is done in double so that clickCount == int.MaxValue cannot wrap around
    // to a negative int, which would make Log2 return NaN.
    return Math.Log2(1d + clickCount) * popularityBoostWeight;
}
}