246 lines
9.0 KiB
C#
246 lines
9.0 KiB
C#
using StellaOps.AdvisoryAI.KnowledgeSearch;
|
|
|
|
namespace StellaOps.AdvisoryAI.UnifiedSearch;
|
|
|
|
/// <summary>
/// Fuses lexical and vector search rankings with domain-weighted reciprocal-rank fusion (RRF),
/// then applies additive entity-proximity, freshness and popularity boosts.
/// </summary>
/// <remarks>
/// Each ranking source contributes <c>domainWeight / (ReciprocalRankConstant + rank)</c> for a chunk;
/// contributions from both sources are summed per chunk id before the boosts are added.
/// </remarks>
internal static class WeightedRrfFusion
{
    /// <summary>The RRF "k" constant added to every rank before taking the reciprocal.</summary>
    private const int ReciprocalRankConstant = 60;

    /// <summary>
    /// Additive boost when a detected entity matches the chunk's <c>entity_key</c> or <c>cveId</c>.
    /// Note: this is far larger than a single RRF contribution (at most weight/61), so with domain
    /// weights near 1.0 an entity match effectively dominates the ordering.
    /// </summary>
    private const double EntityProximityBoost = 0.8;

    /// <summary>Boost given to a chunk whose freshness timestamp is at (or after) the reference time.</summary>
    private const double MaxFreshnessBoost = 0.05;

    /// <summary>Age in days at which the linear freshness boost decays to zero.</summary>
    private const int FreshnessDaysCap = 365;

    /// <summary>
    /// Merges lexical and vector hits into a single ranked list.
    /// </summary>
    /// <param name="domainWeights">Per-domain multipliers; domains not present use a weight of 1.0.</param>
    /// <param name="lexicalRanks">Lexical hits; entries are merged by their <c>ChunkId</c>.</param>
    /// <param name="vectorRanks">Vector hits; entries are merged by <c>Row.ChunkId</c>.</param>
    /// <param name="query">Currently not read by this method.</param>
    /// <param name="filters">Currently not read by this method.</param>
    /// <param name="detectedEntities">Entity mentions used for the entity-proximity boost; null or empty disables it.</param>
    /// <param name="enableFreshnessBoost">Whether to apply the freshness boost.</param>
    /// <param name="referenceTime">
    /// "Now" for the freshness boost. When null the Unix epoch is used, so any freshness timestamp at or
    /// after 1970 is clamped to zero age and receives <see cref="MaxFreshnessBoost"/>.
    /// </param>
    /// <param name="popularityMap">Click counts keyed by <c>entity_key</c>; null or empty disables the popularity boost.</param>
    /// <param name="popularityBoostWeight">Multiplier for the popularity boost; non-positive or non-finite disables it.</param>
    /// <returns>
    /// Hits ordered by descending score, then by <c>Kind</c> and <c>ChunkId</c> (ordinal) so ties are deterministic.
    /// Each hit carries a debug dictionary describing how its score was composed.
    /// </returns>
    public static IReadOnlyList<(KnowledgeChunkRow Row, double Score, IReadOnlyDictionary<string, string> Debug)> Fuse(
        IReadOnlyDictionary<string, double> domainWeights,
        IReadOnlyDictionary<string, (string ChunkId, int Rank, KnowledgeChunkRow Row)> lexicalRanks,
        IReadOnlyList<(KnowledgeChunkRow Row, int Rank, double Score)> vectorRanks,
        string query,
        UnifiedSearchFilter? filters,
        IReadOnlyList<EntityMention>? detectedEntities = null,
        bool enableFreshnessBoost = false,
        DateTimeOffset? referenceTime = null,
        IReadOnlyDictionary<string, int>? popularityMap = null,
        double popularityBoostWeight = 0.0)
    {
        var merged = new Dictionary<string, (KnowledgeChunkRow Row, double Score, Dictionary<string, string> Debug)>(StringComparer.Ordinal);

        // Seed the merge table with lexical hits.
        foreach (var lexical in lexicalRanks.Values)
        {
            var domainWeight = GetDomainWeight(domainWeights, lexical.Row);
            var debug = new Dictionary<string, string>(StringComparer.Ordinal)
            {
                ["lexicalRank"] = lexical.Rank.ToString(System.Globalization.CultureInfo.InvariantCulture),
                ["lexicalScore"] = lexical.Row.LexicalScore.ToString("F6", System.Globalization.CultureInfo.InvariantCulture),
                ["domainWeight"] = FormatInvariant(domainWeight, "F4"),
            };

            merged[lexical.ChunkId] = (lexical.Row, domainWeight * ReciprocalRank(lexical.Rank), debug);
        }

        // Add vector contributions, creating entries for vector-only hits.
        foreach (var vector in vectorRanks)
        {
            var chunkId = vector.Row.ChunkId;

            // Computed once per hit; used both for the new-entry debug value and the score contribution.
            var domainWeight = GetDomainWeight(domainWeights, vector.Row);

            if (!merged.TryGetValue(chunkId, out var entry))
            {
                entry = (vector.Row, 0d, new Dictionary<string, string>(StringComparer.Ordinal)
                {
                    ["domainWeight"] = FormatInvariant(domainWeight, "F4"),
                });
            }

            entry.Score += domainWeight * ReciprocalRank(vector.Rank);
            entry.Debug["vectorRank"] = vector.Rank.ToString(System.Globalization.CultureInfo.InvariantCulture);
            entry.Debug["vectorScore"] = FormatInvariant(vector.Score, "F6");
            merged[chunkId] = entry;
        }

        var freshnessReference = referenceTime ?? DateTimeOffset.UnixEpoch;

        return merged.Values
            .Select(entry =>
            {
                var entityBoost = ComputeEntityProximityBoost(entry.Row, detectedEntities);
                var freshnessBoost = enableFreshnessBoost
                    ? ComputeFreshnessBoost(entry.Row, freshnessReference)
                    : 0d;
                var popularityBoost = ComputePopularityBoost(entry.Row, popularityMap, popularityBoostWeight);

                entry.Debug["entityBoost"] = FormatInvariant(entityBoost, "F6");
                entry.Debug["freshnessBoost"] = FormatInvariant(freshnessBoost, "F6");
                entry.Debug["popularityBoost"] = FormatInvariant(popularityBoost, "F6");
                entry.Debug["chunkId"] = entry.Row.ChunkId;

                // Boosts are summed first, then added, so the floating-point result is bit-identical
                // to accumulating "score += a + b + c".
                var boostedScore = entry.Score + (entityBoost + freshnessBoost + popularityBoost);
                return (Row: entry.Row, Score: boostedScore, Debug: entry.Debug);
            })
            .OrderByDescending(static item => item.Score)
            .ThenBy(static item => item.Row.Kind, StringComparer.Ordinal)
            .ThenBy(static item => item.Row.ChunkId, StringComparer.Ordinal)
            .Select(static item => (item.Row, item.Score, (IReadOnlyDictionary<string, string>)item.Debug))
            .ToArray();
    }

    /// <summary>
    /// Returns <c>1 / (ReciprocalRankConstant + rank)</c> for positive ranks and 0 otherwise.
    /// The sum is done in floating point so very large ranks cannot overflow <see cref="int"/>.
    /// </summary>
    private static double ReciprocalRank(int rank)
    {
        if (rank <= 0)
        {
            return 0d;
        }

        return 1d / (ReciprocalRankConstant + (double)rank);
    }

    /// <summary>Looks up the weight for the row's domain, defaulting to 1.0 when the domain has no entry.</summary>
    private static double GetDomainWeight(IReadOnlyDictionary<string, double> domainWeights, KnowledgeChunkRow row)
    {
        var domain = GetRowDomain(row);
        return domainWeights.TryGetValue(domain, out var weight) ? weight : 1.0;
    }

    /// <summary>
    /// Resolves a row's domain: an explicit string <c>domain</c> metadata property wins; otherwise the
    /// domain is derived from <c>row.Kind</c>, falling back to <c>"knowledge"</c>.
    /// </summary>
    private static string GetRowDomain(KnowledgeChunkRow row)
    {
        // GetStringProperty checks that the root is a JSON object first; JsonElement.TryGetProperty
        // throws InvalidOperationException on non-object roots.
        var explicitDomain = GetStringProperty(row.Metadata.RootElement, "domain");
        if (explicitDomain is not null)
        {
            return explicitDomain;
        }

        return row.Kind switch
        {
            "finding" => "findings",
            "vex_statement" => "vex",
            "policy_rule" => "policy",
            "platform_entity" => "platform",
            // md_section, api_operation, doctor_check and any unrecognized kind.
            _ => "knowledge",
        };
    }

    /// <summary>
    /// Returns <see cref="EntityProximityBoost"/> when any non-blank detected mention is contained in the
    /// row's <c>entity_key</c> (case-insensitive) or equals its <c>cveId</c> (case-insensitive); otherwise 0.
    /// </summary>
    private static double ComputeEntityProximityBoost(
        KnowledgeChunkRow row,
        IReadOnlyList<EntityMention>? detectedEntities)
    {
        if (detectedEntities is not { Count: > 0 })
        {
            return 0d;
        }

        var metadata = row.Metadata.RootElement;
        var entityKey = GetStringProperty(metadata, "entity_key");
        var cveId = GetStringProperty(metadata, "cveId");
        var hasEntityKey = !string.IsNullOrWhiteSpace(entityKey);
        var hasCveId = !string.IsNullOrWhiteSpace(cveId);

        if (!hasEntityKey && !hasCveId)
        {
            return 0d;
        }

        foreach (var mention in detectedEntities)
        {
            var value = mention.Value;

            // A blank mention would "match" every entity key via Contains(""), and a null one would throw.
            if (string.IsNullOrWhiteSpace(value))
            {
                continue;
            }

            if (hasEntityKey && entityKey!.Contains(value, StringComparison.OrdinalIgnoreCase))
            {
                return EntityProximityBoost;
            }

            if (hasCveId && cveId!.Equals(value, StringComparison.OrdinalIgnoreCase))
            {
                return EntityProximityBoost;
            }
        }

        return 0d;
    }

    /// <summary>
    /// Linear freshness boost: <see cref="MaxFreshnessBoost"/> at age 0, decaying to 0 at
    /// <see cref="FreshnessDaysCap"/> days. Timestamps after <paramref name="referenceTime"/> count as age 0.
    /// Rows without a parseable string <c>freshness</c> property get 0.
    /// </summary>
    private static double ComputeFreshnessBoost(KnowledgeChunkRow row, DateTimeOffset referenceTime)
    {
        var rawFreshness = GetStringProperty(row.Metadata.RootElement, "freshness");

        // Parse machine data with the invariant culture and treat offset-less timestamps as UTC so the
        // result does not depend on the host's culture or local time zone.
        if (rawFreshness is null ||
            !DateTimeOffset.TryParse(
                rawFreshness,
                System.Globalization.CultureInfo.InvariantCulture,
                System.Globalization.DateTimeStyles.AssumeUniversal,
                out var freshness))
        {
            return 0d;
        }

        var daysSinceFresh = Math.Max(0d, (referenceTime - freshness).TotalDays);
        if (daysSinceFresh >= FreshnessDaysCap)
        {
            return 0d;
        }

        return MaxFreshnessBoost * (1d - daysSinceFresh / FreshnessDaysCap);
    }

    /// <summary>
    /// Computes an additive popularity boost based on click-through frequency.
    /// Uses a logarithmic function to provide diminishing returns for very popular items,
    /// preventing feedback loops.
    /// </summary>
    /// <returns>
    /// <c>log2(1 + clicks) * popularityBoostWeight</c> for rows whose <c>entity_key</c> has a positive click
    /// count; 0 when the map is null/empty, the weight is non-positive or non-finite, or there is no match.
    /// </returns>
    private static double ComputePopularityBoost(
        KnowledgeChunkRow row,
        IReadOnlyDictionary<string, int>? popularityMap,
        double popularityBoostWeight)
    {
        // IsFinite also rejects NaN, which would otherwise slip past "<= 0" and poison the sort.
        if (popularityMap is not { Count: > 0 } ||
            !double.IsFinite(popularityBoostWeight) ||
            popularityBoostWeight <= 0d)
        {
            return 0d;
        }

        var entityKey = GetStringProperty(row.Metadata.RootElement, "entity_key");
        if (string.IsNullOrWhiteSpace(entityKey) ||
            !popularityMap.TryGetValue(entityKey, out var clickCount) ||
            clickCount <= 0)
        {
            return 0d;
        }

        // 1d + clickCount: done in double so int.MaxValue cannot wrap to a negative argument (=> NaN).
        return Math.Log2(1d + clickCount) * popularityBoostWeight;
    }

    /// <summary>
    /// Returns the string value of <paramref name="propertyName"/> when <paramref name="element"/> is a JSON
    /// object and the property is a JSON string; otherwise null.
    /// </summary>
    private static string? GetStringProperty(System.Text.Json.JsonElement element, string propertyName)
    {
        if (element.ValueKind == System.Text.Json.JsonValueKind.Object &&
            element.TryGetProperty(propertyName, out var property) &&
            property.ValueKind == System.Text.Json.JsonValueKind.String)
        {
            return property.GetString();
        }

        return null;
    }

    /// <summary>Formats a number with the invariant culture for the debug dictionary.</summary>
    private static string FormatInvariant(double value, string format)
        => value.ToString(format, System.Globalization.CultureInfo.InvariantCulture);
}
|