research: benchmark hybrid (lexical+dense+graph) recall vs current FTS
Viktor asked to enhance the memory system with 'semantics' — remember concepts (not just tokens) linked in a graph — and to prove, by benchmarking against the current system, that it actually improves recall. A multi-phase research workflow (18 agents) did landscape research, an adversarially-reviewed integration design, a stratified eval set over the real 5,452-memory corpus, and a head-to-head prototype-vs-current benchmark.

Result: hybrid (lexical FTS + dense embeddings, RRF-fused) beats FTS on every overall metric, driven by a robust paraphrase win (recall@10 +0.350). Recommend adopting lexical+dense; the concept graph is DEFERRED.

Post-run adversarial review correction (applied to all docs before commit): the prototype's fusion config structurally barred the graph leg from the ranked top-k, so the 'graph contributes nothing' ablation was a math artifact, NOT an empirical result — the graph is UNEVALUATED, not disproven (deferred on cost+uncertainty). Multi-hop deltas are not statistically significant.

Glossary in CONTEXT.md; framing in ADR-0001-0003; findings in ADR-0004-0006 + docs/research/.

Privacy: the corpus/queries/qrels/results are the user's real memories and stay gitignored (data/, cache/, results/, build_eval_set.py); only harness code, aggregate numbers, and synthetic examples are committed.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
parent 7439540f8f
commit 1cc8a2b378
23 changed files with 3428 additions and 0 deletions
25
benchmarks/.gitignore
vendored
Normal file
@@ -0,0 +1,25 @@
# Benchmark dataset is the user's REAL personal memories — NEVER commit.
# Privacy hard-rule (research task brief): corpus/queries/qrels stay LOCAL.
data/
.venv/
cache/
*.npy
*.faiss
*.db

# The eval-set GENERATOR embeds real memory-derived query text + author notes
# (paraphrases of real memories, real ids/notes). Treat it as a data artifact:
# LOCAL-ONLY, never commit. Regenerates data/ from corpus.jsonl. The HARNESS
# itself (harness/*.py, the other scripts) contains NO real content and is safe.
scripts/build_eval_set.py

# Python noise
__pycache__/
*.pyc
.pytest_cache/
*.egg-info/
.ipynb_checkpoints/

# Results from runs may quote real content — keep local by default.
results/
*.results.json
126
benchmarks/README.md
Normal file
@@ -0,0 +1,126 @@
# claude-memory recall benchmark

Stratified retrieval benchmark gating the hybrid-recall adoption decision
(ADR-0001): does dense-vector semantic recall + a concept graph beat the current
lexical FTS on **recall@5, recall@10, nDCG@10, MRR**? Quality decides adoption;
latency/storage are measured but non-gating.

> **PRIVACY — read first.** The corpus is the operator's REAL personal memories.
> `data/` (corpus/queries/qrels), `.venv/`, `cache/`, `results/`, and
> `scripts/build_eval_set.py` (the generator embeds memory-derived query text)
> are **gitignored and must never be committed**. Everything else here contains
> only code / aggregate numbers and is safe to commit. Sensitive memories
> (`is_sensitive=1`) are excluded from the corpus entirely.

## Layout

```
benchmarks/
  harness/                  # importable package (committable; no real content)
    types.py                # Memory, Query, Qrels, Retriever protocol
    metrics.py              # recall@k, nDCG@k, MRR (binary relevance)
    dataset.py              # load_dataset() + referential-integrity validation
    runner.py               # run_benchmark() -> overall + per-stratum + latency
    baselines.py            # SqliteFtsRetriever (faithful FTS5/BM25 reference)
    example_retriever.py    # worked example of the plug-in interface
    test_harness.py         # unit tests (pytest)
  scripts/
    export_corpus.py        # SQLite -> data/corpus.jsonl (non-sensitive only)
    build_eval_set.py       # -> data/queries.jsonl + qrels.jsonl [GITIGNORED]
    dataset_stats.py        # validate + print AGGREGATE stats (safe)
    run_eval.py             # CLI: run a retriever, print/save metrics
  data/                     # [GITIGNORED] corpus.jsonl, queries.jsonl, qrels.jsonl
  .venv/                    # [GITIGNORED]
```

## Dataset schema (JSONL, one object per line)

**`corpus.jsonl`** — every non-sensitive memory:
```json
{"id": 137, "content": "...", "category": "decisions", "tags": "memory,architecture",
 "expanded_keywords": "...", "importance": 0.85}
```
`id` (int) is the join key everywhere. `tags` is comma-separated; `expanded_keywords`
space-separated (matches the production schema).

**`queries.jsonl`** — eval queries, three strata:
```json
{"query_id": "para_006", "text": "...", "stratum": "paraphrase", "relevant_ids": [380],
 "_note": "author rationale", "_jaccard": 0.023}
```
- `stratum` ∈ `exact` | `paraphrase` | `multihop`.
- `relevant_ids` is a convenience copy; **`qrels.jsonl` is authoritative**.
- `_note` / `_jaccard` are provenance fields (underscore-prefixed); ignore them in
  scoring.

**`qrels.jsonl`** — binary relevance judgments (authoritative):
```json
{"query_id": "multi_006", "relevant_ids": [263, 423, 637]}
```
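
To make "binary relevance" concrete, the sketch below scores one qrels line against a ranked list using the definitions documented in `harness/metrics.py`. It is an illustration only: the ranked ids are invented, the three functions are simplified stand-ins rather than the harness API, and the ranked list is assumed to contain no duplicate ids (the real metrics de-duplicate first).

```python
import json
import math

# One qrels line in the documented format (the README's own example line).
qrels_line = '{"query_id": "multi_006", "relevant_ids": [263, 423, 637]}'
relevant = set(json.loads(qrels_line)["relevant_ids"])

# Hypothetical retriever output, best-first. These ids are made up.
ranked = [423, 11, 263, 42, 7]


def recall_at_k(ranked, relevant, k):
    """Fraction of all relevant ids that appear in the top k."""
    return len(relevant & set(ranked[:k])) / len(relevant)


def reciprocal_rank(ranked, relevant):
    """1 / rank of the first relevant id, or 0.0 if none was retrieved."""
    for rank, mem_id in enumerate(ranked, start=1):
        if mem_id in relevant:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(ranked, relevant, k):
    """Binary-gain nDCG with the log2(rank + 1) discount."""
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, mem_id in enumerate(ranked[:k], start=1)
        if mem_id in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg else 0.0


r5 = recall_at_k(ranked, relevant, 5)    # 2 of the 3 relevant ids are in the top 5
rr = reciprocal_rank(ranked, relevant)   # first relevant id sits at rank 1
ndcg = ndcg_at_k(ranked, relevant, 10)   # hits at ranks 1 and 3, ideal is ranks 1-3
```

Because id 637 is never retrieved, recall@5 stays at 2/3 even though the reciprocal rank is already 1.0, which is why the benchmark reports several metrics side by side.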

### Strata (what each one tests)

| stratum | construction | who should win |
|---|---|---|
| **exact** | query = a salient phrase lifted from ONE memory; that memory is relevant (verified as the top FTS hit at build time) | lexical already strong; floor check |
| **paraphrase** | query restates ONE memory's meaning in DIFFERENT words (low lexical overlap, validated Jaccard ≤ ~0.18 vs content+keywords) | **dense embeddings** |
| **multihop** | query needs 2+ DISTINCT memories sharing an entity/concept (e.g. project + decision, or a multi-part runbook); ALL are relevant | **concept graph** |

Where a near-duplicate memory equally satisfies a single-target query, qrels was
augmented to include the twin (so a good retriever isn't penalised); deliberate
discriminator queries are kept single-target on purpose.
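
The paraphrase stratum depends on a lexical-overlap check. The sketch below shows one way such a token-level Jaccard check could look. The tokenizer, the function names, and the example strings are assumptions for illustration; the actual generator (`scripts/build_eval_set.py`) is gitignored and may tokenize differently. Only the ~0.18 ceiling comes from the table above.

```python
import re

# Ceiling quoted in the strata table ("validated Jaccard <= ~0.18").
PARAPHRASE_MAX_JACCARD = 0.18


def token_set(text):
    """Lowercased alphanumeric tokens (an assumed tokenization)."""
    return set(re.findall(r"[a-z0-9_]+", text.lower()))


def jaccard(query, memory_text):
    """|Q intersect M| / |Q union M| over token sets; 0.0 when both are empty."""
    q, m = token_set(query), token_set(memory_text)
    if not q and not m:
        return 0.0
    return len(q & m) / len(q | m)


def is_low_overlap(query, memory_text):
    """True when the query shares few enough tokens to count as a paraphrase."""
    return jaccard(query, memory_text) <= PARAPHRASE_MAX_JACCARD


# Synthetic memory text (content + keywords), not taken from the real corpus.
memory_text = "rotate the api key every ninety days"
exact_query = "rotate the api key"                       # lifted phrase
para_query = "how often should credentials be renewed"   # same meaning, new words

exact_overlap = jaccard(exact_query, memory_text)  # 4 shared tokens / 7 in the union
para_overlap = jaccard(para_query, memory_text)    # no shared tokens
```

A query such as `exact_query` would fail the check and belong in the exact stratum, while `para_query` passes it and can only be answered by a retriever that matches meaning rather than tokens.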

## Pluggable retriever interface

A retriever is any object implementing **one** method:

```python
def retrieve(self, query: str, k: int) -> list[int]:
    """Return up to k memory ids (corpus `id`s), ranked best-first."""
```

Optional lifecycle hooks the runner uses if present (duck-typed):

```python
def build_index(self, corpus: list[Memory]) -> None: ...   # timed separately
def index_size_bytes(self) -> int: ...                      # reported
name: str                                                   # label in reports
```

A bare callable `retrieve(query, k) -> list[int]` also works.
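
As a rough illustration of both forms, the sketch below defines a tiny class-based retriever and a bare callable. Everything here is hypothetical: `Memory` is a two-field stand-in for `harness.types.Memory`, and `KeywordCountRetriever` and `first_k_ids` exist only for this example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Memory:
    """Stand-in for harness.types.Memory with only the fields used here."""
    id: int
    content: str


class KeywordCountRetriever:
    """Hypothetical retriever: ranks memories by query-word occurrences."""

    name = "keyword_count_demo"  # optional label used in reports

    def __init__(self):
        self._corpus = []

    def build_index(self, corpus):
        # Optional hook; the runner times this call separately.
        self._corpus = list(corpus)

    def retrieve(self, query, k):
        words = query.lower().split()
        scored = []
        for m in self._corpus:
            text = m.content.lower()
            score = sum(text.count(w) for w in words)
            if score:
                scored.append((score, m.id))
        # Highest score first; ties broken by ascending id for determinism.
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        return [mem_id for _, mem_id in scored[:k]]


# Synthetic corpus for the demo.
corpus = [
    Memory(1, "postgres backup runbook"),
    Memory(2, "backup rotation policy for postgres backup"),
    Memory(3, "grocery list"),
]

retriever = KeywordCountRetriever()
retriever.build_index(corpus)
top = retriever.retrieve("postgres backup", 2)


def first_k_ids(query, k):
    """Bare-callable form: ignores the query and returns the first k ids."""
    return [m.id for m in corpus[:k]]
```

Either `retriever` or `first_k_ids` matches the shape described above; the class form is the one to use when index build time or index size should appear in the report.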

## Run it

```bash
.venv/bin/python scripts/export_corpus.py      # (re)build data/corpus.jsonl
.venv/bin/python scripts/build_eval_set.py     # (re)build queries+qrels (local)
.venv/bin/python scripts/dataset_stats.py      # validate + aggregate stats
.venv/bin/python -m pytest harness/test_harness.py -q

# evaluate a retriever (built-in alias or module:Class)
.venv/bin/python scripts/run_eval.py --retriever fts5
.venv/bin/python scripts/run_eval.py --retriever your_pkg.mod:YourRetriever --json results/yours.json
```

Programmatic use:

```python
from harness import load_dataset, run_benchmark
ds = load_dataset()
result = run_benchmark(MyRetriever(), ds)   # builds index, times queries
print(result.summary())                     # overall + per-stratum table
result.to_dict()                            # full machine-readable result
```

`run_benchmark` requests `retrieve_k=20` per query by default (≥ the max metric
cutoff of 10), macro-averages metrics over queries (overall + per stratum), and
reports per-query latency p50/p95 plus index build time/size when the hooks exist.
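
The aggregation described above can be sketched in a few lines. This is a simplified illustration, not the runner itself: the per-query rows are invented numbers, and `macro_average` and `percentile` are hypothetical helper names (the percentile uses the same linear-interpolation rule as the runner's `_percentile`).

```python
from collections import defaultdict

# Made-up per-query results: one row per query, as the runner would collect them.
per_query = [
    {"stratum": "exact", "recall@10": 1.0, "latency_ms": 2.0},
    {"stratum": "exact", "recall@10": 1.0, "latency_ms": 4.0},
    {"stratum": "paraphrase", "recall@10": 0.0, "latency_ms": 6.0},
    {"stratum": "multihop", "recall@10": 0.5, "latency_ms": 8.0},
]


def macro_average(rows, metric):
    """Mean of a per-query metric; every query carries equal weight."""
    return sum(row[metric] for row in rows) / len(rows) if rows else 0.0


def percentile(values, pct):
    """Linear-interpolation percentile, pct in [0, 100]; empty input -> 0.0."""
    if not values:
        return 0.0
    ordered = sorted(values)
    rank = (pct / 100.0) * (len(ordered) - 1)
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)


# Group rows by stratum, then average within each group.
by_stratum = defaultdict(list)
for row in per_query:
    by_stratum[row["stratum"]].append(row)

overall = macro_average(per_query, "recall@10")
per_stratum = {s: macro_average(rows, "recall@10") for s, rows in by_stratum.items()}

latencies = [row["latency_ms"] for row in per_query]
p50 = percentile(latencies, 50)
p95 = percentile(latencies, 95)
```

Note that the overall figure averages over queries, not over strata, so a stratum with more queries pulls the overall number toward its own score. That is why the per-stratum rows are reported alongside it.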

## Reference baseline

`harness.baselines.SqliteFtsRetriever` mirrors the production local-store search
(README "Search Algorithm"): FTS5 over content/category/tags/expanded_keywords,
`'"w1" OR "w2" ...'` MATCH, `ORDER BY bm25(), importance`. This is the lexical
"current system" any hybrid retriever must beat. (The Postgres `tsvector` path
uses weighted A/B/C/D ranking and an importance-first default; FTS5/BM25 is the
faithful, dependency-free relevance reference for the quality comparison.)
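
For readers who want to see the query shape without opening `baselines.py`, here is a minimal sketch of how a free-text query becomes the quoted `OR` expression. The function name `fts_match_expression` is hypothetical; the word pattern and the SQL text follow what `SqliteFtsRetriever` does in its relevance-ordered mode, and the SQL is shown as a string only, not executed.

```python
import re

# Same token pattern the baseline uses: runs of letters, digits, underscores.
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")


def fts_match_expression(query):
    """Turn free text into '"w1" OR "w2" ...'; empty string if no tokens."""
    words = _WORD_RE.findall(query.lower())
    return " OR ".join(f'"{w}"' for w in words)


# Relevance-ordered query the expression is bound into (two parameters:
# the MATCH expression and the result limit k).
FTS_SQL = (
    "SELECT memory_id FROM memories_fts WHERE memories_fts MATCH ? "
    "ORDER BY bm25(memories_fts), importance DESC LIMIT ?"
)

expr = fts_match_expression("Why did we pick SQLite-FTS5?")
empty = fts_match_expression("?!")
```

Quoting every token keeps punctuation and FTS5 operators in user text from being parsed as query syntax, and joining with `OR` means a memory matching any single word is a candidate, with BM25 deciding the order.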
28
benchmarks/harness/__init__.py
Normal file
@@ -0,0 +1,28 @@
"""Benchmark harness for claude-memory recall evaluation.

Public API:
    from harness import Retriever, load_dataset, run_benchmark, BenchmarkResult
    from harness import metrics

A retriever is any object (or callable) implementing:
    retrieve(query: str, k: int) -> list[memory_id]   # ranked, best first

memory_id matches the `id` field in corpus.jsonl / qrels.jsonl (int).
"""
from .types import Retriever, Query, Memory, Qrels
from .dataset import load_dataset, Dataset
from .runner import run_benchmark, BenchmarkResult, StratumResult
from . import metrics

__all__ = [
    "Retriever",
    "Query",
    "Memory",
    "Qrels",
    "load_dataset",
    "Dataset",
    "run_benchmark",
    "BenchmarkResult",
    "StratumResult",
    "metrics",
]
93
benchmarks/harness/baselines.py
Normal file
@@ -0,0 +1,93 @@
"""Reference LEXICAL baseline retrievers that mirror the production system.

These exist so (a) the eval-set author can VERIFY a query's labels and check
that paraphrase queries genuinely defeat lexical matching, and (b) later agents
have an honest "current system" to beat.

`SqliteFtsRetriever` builds an in-memory SQLite FTS5 index over the corpus and
runs the SAME query shape the production local store uses:
    words -> '"w1" OR "w2" ...' MATCH, ORDER BY bm25(), importance as tiebreak.
(README "SQLite: FTS5 with BM25".) This is the closest faithful, dependency-free
baseline. The Postgres tsvector path is documented in the README; its ranking
differs (weighted A/B/C/D + importance-first default) but for a quality ceiling
comparison the FTS5/BM25 relevance ordering is the right lexical reference.
"""
from __future__ import annotations

import re
import sqlite3
from collections.abc import Sequence

from .types import Memory, MemoryId

# FTS5 reserved-ish tokens; we quote every term anyway, but strip embedded quotes.
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")


class SqliteFtsRetriever:
    """Faithful FTS5/BM25 lexical baseline (mirrors local_store search)."""

    name = "sqlite_fts5_bm25"

    def __init__(self, sort_by: str = "relevance") -> None:
        # "relevance": ORDER BY bm25(), importance DESC (best for quality eval)
        # "importance": ORDER BY importance DESC, ... (production default)
        self.sort_by = sort_by
        self._con: sqlite3.Connection | None = None

    def build_index(self, corpus: Sequence[Memory]) -> None:
        con = sqlite3.connect(":memory:")
        con.execute(
            """
            CREATE VIRTUAL TABLE memories_fts USING fts5(
                content, category, tags, expanded_keywords,
                memory_id UNINDEXED, importance UNINDEXED
            )
            """
        )
        con.executemany(
            "INSERT INTO memories_fts(content, category, tags, expanded_keywords, memory_id, importance)"
            " VALUES (?,?,?,?,?,?)",
            [
                (m.content, m.category, m.tags, m.expanded_keywords, m.id, m.importance)
                for m in corpus
            ],
        )
        con.commit()
        self._con = con

    def _fts_query(self, query: str) -> str:
        words = _WORD_RE.findall(query.lower())
        if not words:
            return ""
        return " OR ".join(f'"{w}"' for w in words)

    def retrieve(self, query: str, k: int) -> list[MemoryId]:
        assert self._con is not None, "call build_index first"
        match = self._fts_query(query)
        if not match:
            return []
        if self.sort_by == "importance":
            order = "importance DESC, bm25(memories_fts)"
        else:
            order = "bm25(memories_fts), importance DESC"
        try:
            rows = self._con.execute(
                f"SELECT memory_id FROM memories_fts WHERE memories_fts MATCH ? "
                f"ORDER BY {order} LIMIT ?",
                (match, k),
            ).fetchall()
        except sqlite3.OperationalError:
            # mirror production LIKE fallback on FTS syntax errors
            like = f"%{query}%"
            rows = self._con.execute(
                "SELECT memory_id FROM memories_fts WHERE content LIKE ? OR tags LIKE ? "
                "ORDER BY importance DESC LIMIT ?",
                (like, like, k),
            ).fetchall()
        return [r[0] for r in rows]

    def close(self) -> None:
        if self._con is not None:
            self._con.close()
            self._con = None
115
benchmarks/harness/dataset.py
Normal file
@@ -0,0 +1,115 @@
"""Load corpus / queries / qrels JSONL into typed objects."""
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path

from .types import Memory, Query, Qrels, MemoryId

_DATA_DIR = Path(__file__).resolve().parents[1] / "data"


@dataclass
class Dataset:
    corpus: list[Memory]
    queries: list[Query]
    qrels: Qrels

    @property
    def corpus_by_id(self) -> dict[MemoryId, Memory]:
        return {m.id: m for m in self.corpus}

    def strata(self) -> set[str]:
        return {q.stratum for q in self.queries}


def _read_jsonl(path: Path) -> list[dict]:
    out: list[dict] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                out.append(json.loads(line))
    return out


def load_corpus(path: Path | None = None) -> list[Memory]:
    path = path or (_DATA_DIR / "corpus.jsonl")
    rows = _read_jsonl(path)
    return [
        Memory(
            id=r["id"],
            content=r["content"],
            category=r.get("category", "facts"),
            tags=r.get("tags", "") or "",
            expanded_keywords=r.get("expanded_keywords", "") or "",
            importance=r.get("importance", 0.5),
        )
        for r in rows
    ]


def load_queries(path: Path | None = None) -> list[Query]:
    path = path or (_DATA_DIR / "queries.jsonl")
    rows = _read_jsonl(path)
    return [
        Query(
            query_id=r["query_id"],
            text=r["text"],
            stratum=r["stratum"],
            relevant_ids=tuple(r.get("relevant_ids", [])),
        )
        for r in rows
    ]


def load_qrels(path: Path | None = None) -> Qrels:
    path = path or (_DATA_DIR / "qrels.jsonl")
    rows = _read_jsonl(path)
    qrels: Qrels = {}
    for r in rows:
        qid = r["query_id"]
        rel = set(r["relevant_ids"])
        qrels.setdefault(qid, set()).update(rel)
    return qrels


def load_dataset(
    corpus_path: Path | None = None,
    queries_path: Path | None = None,
    qrels_path: Path | None = None,
    *,
    validate: bool = True,
) -> Dataset:
    corpus = load_corpus(corpus_path)
    queries = load_queries(queries_path)
    qrels = load_qrels(qrels_path)

    if validate:
        _validate(corpus, queries, qrels)

    return Dataset(corpus=corpus, queries=queries, qrels=qrels)


def _validate(corpus: list[Memory], queries: list[Query], qrels: Qrels) -> None:
    corpus_ids = {m.id for m in corpus}
    q_ids = {q.query_id for q in queries}

    # Every query must have a qrels entry, and vice versa.
    missing_qrels = q_ids - set(qrels)
    if missing_qrels:
        raise ValueError(f"queries without qrels: {sorted(missing_qrels)[:10]}")
    orphan_qrels = set(qrels) - q_ids
    if orphan_qrels:
        raise ValueError(f"qrels without queries: {sorted(orphan_qrels)[:10]}")

    # Every relevant id must exist in the corpus and the set must be non-empty.
    for qid, rels in qrels.items():
        if not rels:
            raise ValueError(f"empty qrels for query {qid}")
        unknown = rels - corpus_ids
        if unknown:
            raise ValueError(
                f"query {qid} references non-corpus ids {sorted(unknown)[:10]}"
            )
59
benchmarks/harness/example_retriever.py
Normal file
@@ -0,0 +1,59 @@
"""Worked example: how a later agent plugs a retriever into the harness.

A retriever needs only one method:

    retrieve(self, query: str, k: int) -> list[int]   # ranked memory ids

Optionally it may implement lifecycle hooks the runner will use if present:

    build_index(self, corpus: list[Memory]) -> None   # timed separately
    index_size_bytes(self) -> int                     # reported

Run this file directly for a smoke test against the local eval set:
    .venv/bin/python -m harness.example_retriever
"""
from __future__ import annotations

from collections.abc import Sequence

from .types import Memory, MemoryId


class SubstringRetriever:
    """Trivial baseline: rank by count of query-word occurrences in content.

    Deliberately weak — exists only to demonstrate the interface. The real
    lexical baseline is harness.baselines.SqliteFtsRetriever.
    """

    name = "substring_demo"

    def __init__(self) -> None:
        self._corpus: list[Memory] = []

    def build_index(self, corpus: Sequence[Memory]) -> None:
        self._corpus = list(corpus)

    def retrieve(self, query: str, k: int) -> list[MemoryId]:
        words = [w for w in query.lower().split() if len(w) > 2]
        scored: list[tuple[int, float]] = []
        for m in self._corpus:
            hay = (m.content + " " + m.expanded_keywords + " " + m.tags).lower()
            score = sum(hay.count(w) for w in words)
            if score:
                scored.append((m.id, score + m.importance))  # importance tiebreak
        scored.sort(key=lambda t: t[1], reverse=True)
        return [mid for mid, _ in scored[:k]]


def _smoke() -> None:
    from .dataset import load_dataset
    from .runner import run_benchmark

    ds = load_dataset()
    res = run_benchmark(SubstringRetriever(), ds)
    print(res.summary())


if __name__ == "__main__":
    _smoke()
100
benchmarks/harness/metrics.py
Normal file
@@ -0,0 +1,100 @@
"""Retrieval metrics with BINARY relevance.

Conventions
-----------
- `ranked`: list of memory ids, best-first, as returned by a retriever.
- `relevant`: set of relevant memory ids for the query (from qrels).
- All functions are pure and operate on a single query; the runner aggregates
  (macro-average over queries).

Definitions
-----------
recall@k = |relevant ∩ ranked[:k]| / |relevant|
           (fraction of all relevant items retrieved within the top k)
MRR      = 1 / rank_of_first_relevant (0 if none retrieved at all)
nDCG@k   = DCG@k / IDCG@k with binary gains (gain=1 for relevant)
           DCG@k = sum over i in [1..k] of rel_i / log2(i + 1)
           IDCG@k is the DCG of the ideal ranking (all relevant first),
           capped at min(|relevant|, k) ones.

Notes
-----
- nDCG uses the standard log2(rank+1) discount (Järvelin & Kekäläinen 2002);
  with binary gains this is the common IR convention also used by BEIR/pytrec_eval.
- MRR is reported as the reciprocal rank of the FIRST relevant hit, which for a
  single query equals the per-query reciprocal-rank that the runner averages.
- Duplicate ids in `ranked` are de-duplicated keeping first occurrence, so a
  retriever cannot inflate recall by repeating an id.
"""
from __future__ import annotations

import math
from collections.abc import Iterable, Sequence

MemoryId = int


def _dedup_keep_order(ranked: Sequence[MemoryId]) -> list[MemoryId]:
    seen: set[MemoryId] = set()
    out: list[MemoryId] = []
    for x in ranked:
        if x not in seen:
            seen.add(x)
            out.append(x)
    return out


def recall_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
    rel = set(relevant)
    if not rel:
        # Undefined; treat as 0 contribution. Runner should never pass empty.
        return 0.0
    top = _dedup_keep_order(ranked)[:k]
    hits = sum(1 for x in top if x in rel)
    return hits / len(rel)


def reciprocal_rank(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId]) -> float:
    rel = set(relevant)
    if not rel:
        return 0.0
    for i, x in enumerate(_dedup_keep_order(ranked), start=1):
        if x in rel:
            return 1.0 / i
    return 0.0


def dcg_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
    rel = set(relevant)
    top = _dedup_keep_order(ranked)[:k]
    dcg = 0.0
    for i, x in enumerate(top, start=1):
        if x in rel:
            dcg += 1.0 / math.log2(i + 1)
    return dcg


def ndcg_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
    rel = set(relevant)
    if not rel:
        return 0.0
    dcg = dcg_at_k(ranked, rel, k)
    ideal_hits = min(len(rel), k)
    idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
    if idcg == 0.0:
        return 0.0
    return dcg / idcg


def per_query_metrics(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId]) -> dict[str, float]:
    """All headline metrics for one query."""
    rel = set(relevant)
    return {
        "recall@5": recall_at_k(ranked, rel, 5),
        "recall@10": recall_at_k(ranked, rel, 10),
        "ndcg@10": ndcg_at_k(ranked, rel, 10),
        "mrr": reciprocal_rank(ranked, rel),
    }


METRIC_NAMES = ("recall@5", "recall@10", "ndcg@10", "mrr")
223
benchmarks/harness/runner.py
Normal file
@@ -0,0 +1,223 @@
"""Benchmark runner: drive a pluggable retriever over the eval set and report
overall + per-stratum quality metrics, plus per-query latency and (optional)
index build time / size.

Quality decides adoption (recall@k, nDCG@10, MRR). Latency and storage are
measured and reported but DO NOT gate the decision (ADR-0001 success metric).
"""
from __future__ import annotations

import statistics
import time
from collections.abc import Callable
from dataclasses import dataclass, field, asdict
from typing import Any

from . import metrics
from .dataset import Dataset
from .types import MemoryId, Query, Retriever

# A retriever may be the Protocol object or a bare callable retrieve(query, k).
RetrieverLike = Retriever | Callable[[str, int], list[MemoryId]]

# k used for the retrieve() call. We request enough depth to compute all
# metrics (max cutoff is 10) with headroom so ties past k=10 don't distort.
DEFAULT_RETRIEVE_K = 20


def _percentile(values: list[float], pct: float) -> float:
    """Linear-interpolation percentile (pct in [0,100]). Empty -> 0.0."""
    if not values:
        return 0.0
    if len(values) == 1:
        return values[0]
    s = sorted(values)
    rank = (pct / 100.0) * (len(s) - 1)
    lo = int(rank)
    hi = min(lo + 1, len(s) - 1)
    frac = rank - lo
    return s[lo] + (s[hi] - s[lo]) * frac


@dataclass
class StratumResult:
    stratum: str
    n_queries: int
    metrics: dict[str, float]  # macro-averaged metric -> value


@dataclass
class BenchmarkResult:
    retriever_name: str
    n_queries: int
    retrieve_k: int
    overall: dict[str, float]
    per_stratum: dict[str, StratumResult]
    latency_ms: dict[str, float]  # mean / p50 / p95 / max
    index_build_seconds: float | None = None
    index_size_bytes: int | None = None
    per_query: list[dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> dict:
        d = asdict(self)
        d["per_stratum"] = {k: asdict(v) for k, v in self.per_stratum.items()}
        return d

    def summary(self) -> str:
        lines = [
            f"Retriever: {self.retriever_name}",
            f"Queries: {self.n_queries} (retrieve_k={self.retrieve_k})",
        ]
        if self.index_build_seconds is not None:
            lines.append(f"Index build: {self.index_build_seconds:.3f}s")
        if self.index_size_bytes is not None:
            lines.append(f"Index size: {self.index_size_bytes / 1e6:.2f} MB")
        lat = self.latency_ms
        lines.append(
            "Latency/query: "
            f"p50={lat['p50']:.2f}ms p95={lat['p95']:.2f}ms "
            f"mean={lat['mean']:.2f}ms max={lat['max']:.2f}ms"
        )
        cols = metrics.METRIC_NAMES
        header = " ".join(f"{c:>10}" for c in cols)
        lines.append("")
        lines.append(f"{'stratum':<12}{'n':>5} {header}")
        lines.append("-" * (19 + len(header)))
        for name in ("overall", *sorted(self.per_stratum)):
            if name == "overall":
                m, n = self.overall, self.n_queries
            else:
                sr = self.per_stratum[name]
                m, n = sr.metrics, sr.n_queries
            row = " ".join(f"{m[c]:>10.4f}" for c in cols)
            lines.append(f"{name:<12}{n:>5} {row}")
        return "\n".join(lines)


def _get_retrieve_fn(retriever: RetrieverLike) -> Callable[[str, int], list[MemoryId]]:
    if hasattr(retriever, "retrieve"):
        return retriever.retrieve  # type: ignore[attr-defined]
    if callable(retriever):
        return retriever
    raise TypeError("retriever must implement retrieve(query, k) or be callable")


def _maybe_build_index(retriever: RetrieverLike, dataset: Dataset) -> tuple[float | None, int | None]:
    """Call optional lifecycle hooks if present (duck-typed).

    - build_index(corpus) -> None : measured wall-clock build time.
    - index_size_bytes() -> int : reported on-disk/in-memory index size.
    Returns (build_seconds_or_None, size_bytes_or_None).
    """
    build_seconds: float | None = None
    size_bytes: int | None = None

    build = getattr(retriever, "build_index", None)
    if callable(build):
        t0 = time.perf_counter()
        build(dataset.corpus)
        build_seconds = time.perf_counter() - t0

    size_fn = getattr(retriever, "index_size_bytes", None)
    if callable(size_fn):
        try:
            size_bytes = int(size_fn())
        except Exception:
            size_bytes = None

    return build_seconds, size_bytes


def run_benchmark(
    retriever: RetrieverLike,
    dataset: Dataset,
    *,
    retrieve_k: int = DEFAULT_RETRIEVE_K,
    retriever_name: str | None = None,
    warmup: bool = True,
    collect_per_query: bool = True,
) -> BenchmarkResult:
    """Evaluate `retriever` over `dataset`.

    The retriever is asked for `retrieve_k` ids per query (>= max metric
    cutoff of 10). Metrics are macro-averaged over queries, overall and per
    stratum. Latency is measured around each retrieve() call only (index build
    is timed separately via the optional build_index hook).
    """
    name = retriever_name or getattr(retriever, "name", None) or type(retriever).__name__
    retrieve = _get_retrieve_fn(retriever)
    qrels = dataset.qrels

    build_seconds, size_bytes = _maybe_build_index(retriever, dataset)

    # Optional warmup (first call can pay import/JIT/connection costs that would
    # skew p95). Excluded from latency stats. Uses the first query if any.
    if warmup and dataset.queries:
        try:
            retrieve(dataset.queries[0].text, retrieve_k)
        except Exception:
            pass  # warmup failures surface on the real call below

    per_query_rows: list[dict[str, Any]] = []
    latencies_ms: list[float] = []
    # accumulate per-stratum metric sums for macro-average
    strata: dict[str, dict[str, float]] = {}
    strata_counts: dict[str, int] = {}
    overall_sums = {m: 0.0 for m in metrics.METRIC_NAMES}

    for q in dataset.queries:
        rel = qrels[q.query_id]
        t0 = time.perf_counter()
        ranked = list(retrieve(q.text, retrieve_k))
        dt_ms = (time.perf_counter() - t0) * 1000.0
        latencies_ms.append(dt_ms)

        m = metrics.per_query_metrics(ranked, rel)
        for key, val in m.items():
            overall_sums[key] += val
        strata.setdefault(q.stratum, {mm: 0.0 for mm in metrics.METRIC_NAMES})
        strata_counts[q.stratum] = strata_counts.get(q.stratum, 0) + 1
        for key, val in m.items():
            strata[q.stratum][key] += val

        if collect_per_query:
            per_query_rows.append(
                {
                    "query_id": q.query_id,
                    "stratum": q.stratum,
                    "n_relevant": len(rel),
                    "latency_ms": round(dt_ms, 3),
                    "retrieved": ranked[:retrieve_k],
                    **{k: round(v, 6) for k, v in m.items()},
                }
            )

    n = len(dataset.queries)
    overall = {k: (overall_sums[k] / n if n else 0.0) for k in metrics.METRIC_NAMES}
    per_stratum: dict[str, StratumResult] = {}
    for s, sums in strata.items():
        c = strata_counts[s]
        per_stratum[s] = StratumResult(
            stratum=s,
            n_queries=c,
            metrics={k: (sums[k] / c if c else 0.0) for k in metrics.METRIC_NAMES},
        )

    latency_stats = {
        "mean": statistics.fmean(latencies_ms) if latencies_ms else 0.0,
        "p50": _percentile(latencies_ms, 50),
        "p95": _percentile(latencies_ms, 95),
        "max": max(latencies_ms) if latencies_ms else 0.0,
    }

    return BenchmarkResult(
        retriever_name=name,
        n_queries=n,
        retrieve_k=retrieve_k,
        overall=overall,
        per_stratum=per_stratum,
        latency_ms=latency_stats,
        index_build_seconds=build_seconds,
        index_size_bytes=size_bytes,
        per_query=per_query_rows,
    )
145
benchmarks/harness/test_harness.py
Normal file
@ -0,0 +1,145 @@
"""Unit tests for metrics + runner. No real corpus needed (synthetic data).

Run: .venv/bin/python -m pytest harness/test_harness.py -q
"""
from __future__ import annotations

import math

from harness import metrics
from harness.dataset import Dataset
from harness.runner import run_benchmark, _percentile
from harness.types import Memory, Query


# ---------------- metrics ----------------

def test_recall_at_k_basic():
    ranked = [9, 8, 3, 7, 1]
    rel = {3, 1, 99}  # 99 never retrieved
    assert metrics.recall_at_k(ranked, rel, 5) == 2 / 3
    assert metrics.recall_at_k(ranked, rel, 2) == 0.0  # neither in top2
    assert metrics.recall_at_k(ranked, rel, 3) == 1 / 3  # only id 3 in top3


def test_recall_perfect_and_zero():
    assert metrics.recall_at_k([1, 2, 3], {1, 2, 3}, 5) == 1.0
    assert metrics.recall_at_k([4, 5, 6], {1, 2, 3}, 5) == 0.0


def test_reciprocal_rank():
    assert metrics.reciprocal_rank([5, 4, 3], {3}) == 1 / 3
    assert metrics.reciprocal_rank([3, 4, 5], {3}) == 1.0
    assert metrics.reciprocal_rank([7, 8], {3}) == 0.0
    # first relevant wins
    assert metrics.reciprocal_rank([9, 3, 1], {1, 3}) == 1 / 2


def test_ndcg_perfect():
    # all relevant at the top -> nDCG == 1
    assert math.isclose(metrics.ndcg_at_k([1, 2, 3, 4], {1, 2, 3}, 10), 1.0)


def test_ndcg_known_value():
    # single relevant doc at rank 2: DCG = 1/log2(3); IDCG = 1/log2(2)=1
    ranked = [9, 1, 8]
    val = metrics.ndcg_at_k(ranked, {1}, 10)
    assert math.isclose(val, (1 / math.log2(3)) / 1.0)


def test_ndcg_two_relevant_suboptimal_order():
    # relevant {1,2}; retrieved at ranks 1 and 3
    ranked = [1, 9, 2]
    dcg = 1 / math.log2(2) + 1 / math.log2(4)  # ranks 1 and 3
    idcg = 1 / math.log2(2) + 1 / math.log2(3)  # ideal ranks 1 and 2
    assert math.isclose(metrics.ndcg_at_k(ranked, {1, 2}, 10), dcg / idcg)


def test_dedup_does_not_inflate():
    # repeating a relevant id must not increase recall beyond 1 hit's worth
    ranked = [3, 3, 3, 3]
    assert metrics.recall_at_k(ranked, {3, 7}, 5) == 0.5
    assert metrics.reciprocal_rank(ranked, {3}) == 1.0


def test_empty_relevant_is_zero():
    assert metrics.recall_at_k([1, 2], set(), 5) == 0.0
    assert metrics.ndcg_at_k([1, 2], set(), 5) == 0.0


# ---------------- percentile ----------------

def test_percentile():
    vals = [10, 20, 30, 40]
    assert _percentile(vals, 50) == 25.0  # interpolated median
    assert _percentile(vals, 0) == 10
    assert _percentile(vals, 100) == 40
    assert _percentile([5.0], 95) == 5.0
    assert _percentile([], 50) == 0.0


# ---------------- runner ----------------

def _toy_dataset() -> Dataset:
    corpus = [Memory(id=i, content=f"memory {i}") for i in range(1, 11)]
    queries = [
        Query("q_exact_1", "find 1", "exact", (1,)),
        Query("q_para_1", "restate 5", "paraphrase", (5,)),
        Query("q_multi_1", "join 3 and 4", "multihop", (3, 4)),
    ]
    qrels = {"q_exact_1": {1}, "q_para_1": {5}, "q_multi_1": {3, 4}}
    return Dataset(corpus=corpus, queries=queries, qrels=qrels)


class _PerfectRetriever:
    """Returns exactly the relevant ids first (oracle) — for runner plumbing."""

    def __init__(self, qrels):
        self._qrels = qrels
        self._by_text = None

    def build_index(self, corpus):
        self._n = len(corpus)

    def index_size_bytes(self):
        return 1234

    def retrieve(self, query, k):
        # map query text back via the toy queries' known answers
        mapping = {"find 1": [1], "restate 5": [5], "join 3 and 4": [3, 4]}
        ids = mapping.get(query, [])
        # pad with distractors
        pad = [x for x in range(100, 100 + k)]
        return (ids + pad)[:k]


def test_runner_perfect_retriever():
    ds = _toy_dataset()
    r = _PerfectRetriever(ds.qrels)
    res = run_benchmark(r, ds, retriever_name="perfect")
    assert res.n_queries == 3
    assert math.isclose(res.overall["recall@10"], 1.0)
    assert math.isclose(res.overall["mrr"], 1.0)
    assert math.isclose(res.overall["ndcg@10"], 1.0)
    # per-stratum present
    assert set(res.per_stratum) == {"exact", "paraphrase", "multihop"}
    assert res.per_stratum["multihop"].n_queries == 1
    # lifecycle hooks captured
    assert res.index_build_seconds is not None
    assert res.index_size_bytes == 1234
    # latency recorded
    assert res.latency_ms["p95"] >= 0.0


def test_runner_callable_retriever_and_misses():
    ds = _toy_dataset()

    def retrieve(query, k):  # always wrong
        return [999][:k]

    res = run_benchmark(retrieve, ds, retriever_name="bad", warmup=False)
    assert res.overall["recall@10"] == 0.0
    assert res.overall["mrr"] == 0.0
    assert res.index_build_seconds is None  # no hook on a bare callable
    assert "perfect" not in res.summary()
    assert "bad" in res.summary()
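The metric tests above pin down the expected behaviour of `recall_at_k`, `reciprocal_rank`, and `ndcg_at_k`, but `harness/metrics.py` itself is not part of this section. The sketch below is one way to satisfy those tests under binary relevance. It assumes duplicates are collapsed before truncating to the top k, which is a guess; the real module may order those steps differently.

```python
import math


def _dedup(ranked):
    """Keep the first occurrence of each id, preserving rank order."""
    seen = set()
    out = []
    for doc_id in ranked:
        if doc_id not in seen:
            seen.add(doc_id)
            out.append(doc_id)
    return out


def recall_at_k(ranked, relevant, k):
    """Fraction of relevant ids found in the top k (0.0 if nothing is relevant)."""
    if not relevant:
        return 0.0
    top = _dedup(ranked)[:k]
    return len(set(top) & set(relevant)) / len(relevant)


def reciprocal_rank(ranked, relevant):
    """1 / rank of the first relevant id, or 0.0 if none is retrieved."""
    for rank, doc_id in enumerate(_dedup(ranked), start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(ranked, relevant, k):
    """Binary-gain nDCG: DCG of the top k divided by the ideal DCG."""
    if not relevant:
        return 0.0
    top = _dedup(ranked)[:k]
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, doc_id in enumerate(top, start=1)
        if doc_id in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg
```

Collapsing duplicates first means a retriever cannot raise its score by repeating a relevant id, which is the property `test_dedup_does_not_inflate` checks.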
53
benchmarks/harness/types.py
Normal file
@ -0,0 +1,53 @@
"""Core dataclasses and the pluggable Retriever protocol."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Protocol, runtime_checkable

MemoryId = int


@dataclass(frozen=True)
class Memory:
    """One corpus entry (mirrors corpus.jsonl)."""

    id: MemoryId
    content: str
    category: str = "facts"
    tags: str = ""
    expanded_keywords: str = ""
    importance: float = 0.5


@dataclass(frozen=True)
class Query:
    """One eval query (mirrors queries.jsonl)."""

    query_id: str
    text: str
    stratum: str  # "exact" | "paraphrase" | "multihop"
    # convenience copy of relevant ids; authoritative source is Qrels
    relevant_ids: tuple[MemoryId, ...] = field(default_factory=tuple)


# query_id -> set of relevant memory ids (binary relevance)
Qrels = dict[str, set[MemoryId]]


@runtime_checkable
class Retriever(Protocol):
    """Pluggable retriever contract.

    Implementations rank corpus memories for a query and return the top-k
    memory ids, best match first. The harness will call `retrieve` once per
    query and compare against qrels.

    Optional lifecycle hooks let a retriever build an index from the corpus
    and report index build time / on-disk size; the runner uses them if
    present (duck-typed), so a minimal retriever need only implement
    `retrieve`.
    """

    def retrieve(self, query: str, k: int) -> list[MemoryId]:
        """Return up to k memory ids, ranked best-first."""
        ...
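To illustrate the contract above, here is a minimal sketch of a class that satisfies a `Retriever`-style protocol without inheriting from it. The protocol is redeclared locally so the example is self-contained, and `KeywordOverlapRetriever` is a hypothetical toy, not one of the benchmark retrievers.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Retriever(Protocol):
    def retrieve(self, query: str, k: int) -> list[int]:
        ...


class KeywordOverlapRetriever:
    """Hypothetical toy retriever: ranks memories by shared lowercase tokens."""

    def __init__(self, corpus: dict[int, str]) -> None:
        self._corpus = corpus

    def retrieve(self, query: str, k: int) -> list[int]:
        q_tokens = set(query.lower().split())
        scored = []
        for mem_id, text in self._corpus.items():
            overlap = len(q_tokens & set(text.lower().split()))
            if overlap:
                scored.append((overlap, mem_id))
        # Highest overlap first; ties broken by lower id for determinism.
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        return [mem_id for _, mem_id in scored[:k]]


retriever = KeywordOverlapRetriever(
    {1: "backup the homelab nas", 2: "nas firmware update", 3: "dentist appointment"}
)
conforms = isinstance(retriever, Retriever)  # structural check, no inheritance needed
top = retriever.retrieve("nas backup", 2)
```

Note that `isinstance` against a `runtime_checkable` protocol only verifies that a `retrieve` attribute exists; it does not check the signature or return type, so it is a smoke test rather than a guarantee.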
10
benchmarks/retrievers/__init__.py
Normal file
@ -0,0 +1,10 @@
"""Pluggable retrievers for the claude-memory recall benchmark.

Each retriever implements the harness `retrieve(query, k) -> list[int]` contract
(see ``harness/types.py`` :: ``Retriever``) and, optionally, the ``build_index`` /
``index_size_bytes`` lifecycle hooks the runner duck-types.

``fts.FtsRetriever`` is the LEXICAL BASELINE — the product's current local-store
recall (SQLite FTS5/BM25). It is the "current system" any hybrid retriever must
beat on recall@k / nDCG@10 / MRR (ADR-0001).
"""
224
benchmarks/retrievers/fts.py
Normal file
@ -0,0 +1,224 @@
"""BASELINE retriever: the product's CURRENT lexical recall (SQLite FTS5/BM25).

This is the "current system" the hybrid upgrade (dense embeddings + concept
graph, ADR-0001) must beat on recall@k / nDCG@10 / MRR. It is a *faithful*
reimplementation of the production local-store recall path, not an idealised
sketch — it mirrors ``src/claude_memory/mcp_server.py :: _sqlite_recall`` (and
the FTS5 schema/triggers in the same module) line-for-line where it matters.

Production recall (``sort_by="relevance"``) does ALL of the following, and so
does this retriever:

1. **Concatenate then split.** The MCP tool builds
   ``all_terms = f"{context} {expanded_query}"`` and splits it on whitespace,
   stripping any embedded ``"`` from each token. The harness already hands us
   one ``query`` string (the concatenation happens upstream of recall), so here
   ``query`` IS ``all_terms``; we split + strip identically.

2. **AND-first, then OR-broaden.** Production builds BOTH
   ``'"w1" AND "w2" ...'`` and ``'"w1" OR "w2" ...'`` and runs the **AND** match
   first; only if it returns zero rows does it fall back to the **OR** match.
   (The README's "Search Algorithm" prose shows only the OR form; the *code* is
   AND→OR, and the code is authoritative. We replicate the code.)

3. **Blended BM25+importance relevance ordering.** ``sort_by="relevance"`` is
   NOT a pure ``ORDER BY bm25()``. It is the blend
   ``(-bm25(memories_fts) * 0.7 + importance * 0.3) DESC`` (bm25 is negated
   because SQLite returns more-negative = better-match). We use the EXACT same
   expression. We deliberately evaluate ``relevance`` (not the production
   ``importance`` default) so the benchmark measures RETRIEVAL quality rather
   than the importance-sort prior — per the research brief.

4. **FTS5 default tokenizer.** The production virtual table is declared with no
   explicit tokenizer, i.e. ``unicode61`` — case-folding + unicode diacritic
   stripping, NO stemming and NO stop-word removal. We declare ours the same
   way, so "running" does not match "run" (a known lexical weakness the dense
   path is expected to fix on the *paraphrase* stratum).

5. **LIKE fallback.** If the FTS5 MATCH raises ``sqlite3.OperationalError``
   (e.g. a token that trips the FTS5 query grammar), production degrades to a
   ``content LIKE %context% OR tags LIKE %context%`` scan ordered by importance.
   We mirror that fallback (using the full query as the LIKE needle, since the
   harness query is the whole ``all_terms``).

DIFFERENCES FROM PRODUCTION (all immaterial to ranking, documented for honesty):
- The benchmark corpus has no per-user / soft-delete / category filtering, so we
  drop the ``user_id``/``deleted_at``/``category`` predicates. No category is
  passed by the harness, so the category branch is never taken anyway.
- We build a fresh in-memory FTS5 index over ``data/corpus.jsonl`` rather than
  reading the live ``memory.db``; same schema, same tokenizer, same columns
  (content/category/tags/expanded_keywords), so BM25 statistics match what the
  product would compute over the same documents.

The harness reference ``harness.baselines.SqliteFtsRetriever`` implements the
*README* ordering (pure ``ORDER BY bm25(), importance``). This module is the
faithful-to-the-CODE variant and is the one the RUN reports as ``retriever="fts"``.
"""
from __future__ import annotations

import re
import sqlite3
from collections.abc import Sequence

# Import the corpus dataclass from the sibling harness package. run_eval.py and
# run_benchmark put the benchmarks/ root on sys.path; support direct execution
# (python retrievers/fts.py) too by adding it ourselves if the import fails.
try:  # pragma: no cover - exercised by both import paths
    from harness.types import Memory, MemoryId
except ModuleNotFoundError:  # pragma: no cover
    import sys
    from pathlib import Path

    sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
    from harness.types import Memory, MemoryId

# Mirror production token extraction: split ``all_terms`` on whitespace and strip
# any embedded double-quote from each token (mcp_server uses
# ``w.replace(chr(34), "")``). We lowercase as well; FTS5 unicode61 case-folds
# regardless, so this only normalises the quoted MATCH literals we emit.
_DQUOTE = '"'


class FtsRetriever:
    """Faithful reimplementation of the production SQLite FTS5/BM25 recall.

    Mirrors ``_sqlite_recall(sort_by="relevance")``: AND-first then OR-broaden
    over an FTS5(content, category, tags, expanded_keywords) index, ranked by
    the blended ``(-bm25*0.7 + importance*0.3)`` score, with a LIKE fallback.
    """

    #: Label surfaced in benchmark reports / the RUN schema.
    name = "fts"

    def __init__(self, sort_by: str = "relevance") -> None:
        # We benchmark "relevance" so the metric reflects retrieval quality, not
        # the importance prior. "importance" is kept for parity / diagnostics.
        if sort_by not in ("relevance", "importance"):
            raise ValueError(f"sort_by must be 'relevance' or 'importance', got {sort_by!r}")
        self.sort_by = sort_by
        self._con: sqlite3.Connection | None = None

    # ── lifecycle hooks (duck-typed by the runner) ───────────────────────────

    def build_index(self, corpus: Sequence[Memory]) -> None:
        """Build a fresh in-memory FTS5 index over the corpus.

        Same virtual-table shape and (default ``unicode61``) tokenizer as the
        production ``memories_fts`` table. We carry ``memory_id`` and
        ``importance`` as UNINDEXED columns so the relevance blend can read
        importance without a join — semantically identical to the production
        ``memories m JOIN memories_fts fts ON m.id = fts.rowid`` read.
        """
        con = sqlite3.connect(":memory:")
        con.execute(
            """
            CREATE VIRTUAL TABLE memories_fts USING fts5(
                content, category, tags, expanded_keywords,
                memory_id UNINDEXED, importance UNINDEXED
            )
            """
        )
        con.executemany(
            "INSERT INTO memories_fts"
            "(content, category, tags, expanded_keywords, memory_id, importance)"
            " VALUES (?,?,?,?,?,?)",
            [
                (
                    m.content,
                    m.category,
                    m.tags,
                    m.expanded_keywords,
                    int(m.id),
                    float(m.importance),
                )
                for m in corpus
            ],
        )
        con.commit()
        self._con = con

    def index_size_bytes(self) -> int:
        """Approximate index size in bytes (``page_count * page_size``).

        The index is in-memory, so this is SQLite's page accounting for the
        whole database, which holds only the FTS5 table and its shadow tables.
        Reported for the storage column, non-gating per ADR-0001.
        """
        if self._con is None:
            return 0
        try:
            page_count = self._con.execute("PRAGMA page_count").fetchone()[0]
            page_size = self._con.execute("PRAGMA page_size").fetchone()[0]
            return int(page_count) * int(page_size)
        except sqlite3.Error:
            return 0

    # ── query construction (mirrors _sqlite_recall) ──────────────────────────

    @staticmethod
    def _tokens(query: str) -> list[str]:
        """Split ``all_terms`` exactly as production does: whitespace split,
        drop embedded double-quotes, drop empties."""
        return [w.replace(_DQUOTE, "").lower() for w in query.split() if w.strip()]

    @classmethod
    def _and_or_queries(cls, query: str) -> tuple[str, str]:
        """Build the ('"w1" AND "w2" ...', '"w1" OR "w2" ...') MATCH pair."""
        words = cls._tokens(query)
        if not words:
            return "", ""
        quoted = [f'"{w}"' for w in words]
        return " AND ".join(quoted), " OR ".join(quoted)

    def _order_clause(self) -> str:
        # bm25() is negative (more-negative = better), so negate before blending.
        if self.sort_by == "relevance":
            return "(-bm25(memories_fts) * 0.7 + importance * 0.3) DESC"
        return "(-bm25(memories_fts) * 0.4 + importance * 0.6) DESC"

    # ── retrieve ──────────────────────────────────────────────────────────────

    def retrieve(self, query: str, k: int) -> list[MemoryId]:
        """Return up to ``k`` memory ids, ranked best-first.

        AND-match first (precise); if it yields nothing, OR-broaden. On an FTS5
        grammar error, fall back to a LIKE scan ordered by importance — exactly
        the production degradation path.
        """
        assert self._con is not None, "call build_index first"
        and_query, or_query = self._and_or_queries(query)
        if not or_query:  # no usable tokens
            return []

        order = self._order_clause()
        base_select = "SELECT memory_id FROM memories_fts WHERE memories_fts MATCH ? "
        try:
            rows: list[tuple[int]] = []
            # AND first for precise matches, fall back to OR for broader recall.
            for fts_query in (and_query, or_query):
                rows = self._con.execute(
                    f"{base_select}ORDER BY {order} LIMIT ?",
                    (fts_query, k),
                ).fetchall()
                if rows:
                    break
        except sqlite3.OperationalError:
            # Mirror production LIKE fallback: full query as the needle,
            # ordered by importance.
            like = f"%{query}%"
            rows = self._con.execute(
                "SELECT memory_id FROM memories_fts "
                "WHERE content LIKE ? OR tags LIKE ? "
                "ORDER BY importance DESC LIMIT ?",
                (like, like, k),
            ).fetchall()
        return [r[0] for r in rows]

    def close(self) -> None:
        if self._con is not None:
            self._con.close()
            self._con = None


# Convenience for `run_eval.py --retriever retrievers.fts:FtsRetriever`
# and a no-arg default instantiation (sort_by="relevance").
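The AND-first, OR-broaden behaviour of this baseline can be seen in isolation with a small sketch. It assumes the Python build's SQLite includes FTS5. The table name `docs`, the three toy rows, and the `search` helper are all hypothetical, and the LIKE fallback is left out for brevity. Unlike the retriever, this sketch also drops tokens left empty after quote stripping.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute(
    "CREATE VIRTUAL TABLE docs USING fts5(content, doc_id UNINDEXED, importance UNINDEXED)"
)
con.executemany(
    "INSERT INTO docs(content, doc_id, importance) VALUES (?,?,?)",
    [
        ("router firmware upgrade notes", 1, 0.5),
        ("garden watering schedule", 2, 0.5),
        ("firmware rollback procedure", 3, 0.5),
    ],
)


def search(query, k):
    """Try the AND form first; broaden to OR only if AND returns no rows."""
    words = [w.replace('"', "").lower() for w in query.split()]
    quoted = ['"' + w + '"' for w in words if w]  # sketch choice: skip empty tokens
    if not quoted:
        return []
    sql = (
        "SELECT doc_id FROM docs WHERE docs MATCH ? "
        # bm25() is lower-is-better in SQLite, hence the negation before blending.
        "ORDER BY (-bm25(docs) * 0.7 + importance * 0.3) DESC LIMIT ?"
    )
    for match in (" AND ".join(quoted), " OR ".join(quoted)):
        rows = con.execute(sql, (match, k)).fetchall()
        if rows:
            return [r[0] for r in rows]
    return []


precise = search("router firmware", 5)  # AND form matches only the first row
broadened = search("router garden", 5)  # AND form is empty, so the OR form is used
```

A consequence worth keeping in mind when reading per-stratum numbers: a query never mixes the two forms, so one row that happens to contain every token suppresses all partial matches.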
570
benchmarks/retrievers/hybrid.py
Normal file
@ -0,0 +1,570 @@
"""HYBRID retriever (ADR-0001/0002/0003 prototype): lexical FTS + dense semantic
recall + a memory-node concept graph, fused with Reciprocal Rank Fusion (RRF).

This is the self-contained prototype the hybrid-recall ADOPTION decision is gated
on (ADR-0001): do dense embeddings + a concept graph beat the current lexical
FTS5/BM25 on recall@5/recall@10/nDCG@10/MRR? Quality decides; latency/storage are
reported but non-gating.

It implements the harness ``retrieve(query, k) -> list[int]`` contract and the
optional ``build_index(corpus)`` / ``index_size_bytes()`` / ``name`` hooks.

Three legs, mirroring the FINAL DESIGN
======================================

1. **Lexical (FTS5/BM25).** We reuse the *faithful* production reimplementation
   ``retrievers.fts.FtsRetriever`` verbatim — AND-first then OR-broaden over an
   FTS5(content, category, tags, expanded_keywords) index, ranked by the blended
   ``(-bm25*0.7 + importance*0.3)``. This is the exact "current system" the hybrid
   must beat, so the lexical leg of the hybrid IS that system (no drift).

2. **Dense (semantic).** Embeddings per FINAL DESIGN: a HOSTED API is used ONLY if
   its key is in the environment (``OPENAI_API_KEY`` / ``VOYAGE_API_KEY`` /
   ``CO_API_KEY``) AND the memory is non-sensitive (ADR-0003); otherwise the local
   default ``BAAI/bge-large-en-v1.5`` (1024-d, MIT, sentence-transformers). The
   benchmark corpus is already sensitive-free (``is_sensitive=1`` excluded at
   export, README privacy note), so here the choice is purely "hosted key present
   or not". Vectors are L2-normalised; similarity is cosine = dot product. The
   corpus matrix is cached to ``cache/`` (gitignored) keyed by model id + a corpus
   fingerprint, so re-runs skip re-embedding. BGE retrieval convention: the QUERY
   gets the instruction prefix "Represent this sentence for searching relevant
   passages: "; passages are embedded raw (per the official BAAI model card).

3. **Graph (concept expansion).** A memory-node concept graph built with the
   design's TRACTABLE extraction — NO 5452 sequential LLM calls. Concepts are the
   union of each memory's ``tags`` and its already-LLM-generated
   ``expanded_keywords`` (plus salient content noun-phrases via a lightweight
   regex/stop-word filter), normalised and de-pluralised. A concept that appears
   in 2..N memories (very common concepts above a document-frequency ceiling are
   dropped as non-discriminative) links those memories: ``memory -[shares
   concept c]- memory``. At query time we take the fused dense+lexical SEEDS, walk
   1 hop to neighbours that share *discriminative* concepts, and emit those
   neighbours as a third ranked list. This targets the **multihop** stratum
   (queries needing 2+ memories that share an entity/concept) without re-ranking
   the precise hits the other legs already nail.

Fusion (``retrieval_fusion``)
=============================
Reciprocal Rank Fusion (Cormack et al., 2009): for a document *d* with rank
``r_leg(d)`` (1-based) in a leg's ranked list,

    RRF(d) = Σ_leg w_leg / (k_rrf + r_leg(d))

with ``k_rrf = 60`` (the standard constant) and per-leg weights. RRF is
score-scale-free (no BM25-vs-cosine calibration), which is why the design floats
"RRF vs CC" and we pick RRF for the prototype. The dense and lexical legs carry
full weight; the graph leg is down-weighted (it is a RECALL extender for multihop,
and the design explicitly flags a possible negative graph prior — so it can add
documents but should not dethrone strong dense/lexical hits). All three weights
are class attributes so the kill-gate analysis can ablate the graph to zero.

Graceful degradation (task requirement)
=======================================
If the embedding model cannot be loaded/used (missing package, download failure,
OOM), the dense leg is skipped, the failure is recorded in ``self.errors``, and the
retriever degrades to **FTS + graph** (or FTS-only if the graph also failed). The
harness still gets metrics for whatever worked.
"""
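The weighted RRF formula in the docstring above can be written out as a short standalone sketch. It is an illustration of the formula only, not the prototype's fusion routine: the function name `rrf_fuse`, the toy ids, and the tie-break by id are assumptions, while `k_rrf = 60` and the 1.0 / 1.0 / 0.35 weights come from the docstring and class attributes.

```python
from collections import defaultdict


def rrf_fuse(ranked_lists, weights, k_rrf=60, top_k=10):
    """Weighted RRF: score(d) = sum over legs of w_leg / (k_rrf + rank_leg(d))."""
    scores = defaultdict(float)
    for leg, ranked in ranked_lists.items():
        w = weights.get(leg, 1.0)
        for rank, doc_id in enumerate(ranked, start=1):  # ranks are 1-based
            scores[doc_id] += w / (k_rrf + rank)
    # Highest fused score first; ties broken by id for determinism.
    ordered = sorted(scores, key=lambda d: (-scores[d], d))
    return ordered[:top_k]


legs = {
    "dense": [11, 12, 13],
    "fts": [12, 14, 11],
    "graph": [15, 13],
}
weights = {"dense": 1.0, "fts": 1.0, "graph": 0.35}
fused = rrf_fuse(legs, weights)
# → [12, 11, 13, 14, 15]
```

In this toy run the graph-only document 15 lands last even though it is the graph leg's top hit: its score is 0.35 / 61 ≈ 0.0057, which is lower than a single full-weight hit at any rank up to about 114. With these weights a graph-only candidate can therefore only enter the top k when the other legs return very few documents, which is worth bearing in mind when interpreting a graph ablation.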
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
# ── package-relative imports that also work under direct execution ────────────
|
||||
try: # pragma: no cover - exercised by both import paths
|
||||
from harness.types import Memory, MemoryId
|
||||
from retrievers.fts import FtsRetriever
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from harness.types import Memory, MemoryId
|
||||
from retrievers.fts import FtsRetriever
|
||||
|
||||
_BENCH_ROOT = Path(__file__).resolve().parents[1]
|
||||
_CACHE_DIR = _BENCH_ROOT / "cache"
|
||||
|
||||
# Local default embedding model (FINAL DESIGN: prototype default + sensitive-only
|
||||
# fallback). 1024-d, MIT-licensed, strong on MTEB retrieval.
|
||||
_LOCAL_MODEL = "BAAI/bge-large-en-v1.5"
|
||||
# BGE retrieval query instruction (official BAAI model card recommendation; the
|
||||
# v1.5 line relaxed it but it still helps short-query / long-passage asymmetry,
|
||||
# which is exactly the paraphrase stratum). Applied to QUERIES only.
|
||||
_BGE_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
|
||||
|
||||
# RRF constant (Cormack/Clarke/Buettcher 2009). 60 is the canonical default.
|
||||
_RRF_K = 60
|
||||
|
||||
# Concept-graph tuning.
|
||||
# _CONCEPT_MIN_DF : a concept must appear in >= this many memories to form edges
|
||||
# (df==1 links nothing; we need a shared concept).
|
||||
# _CONCEPT_MAX_DF_FRAC : drop concepts appearing in more than this fraction of
|
||||
# the corpus — they are non-discriminative hubs ("memory",
|
||||
# "homelab") that would over-connect the graph (design risk:
|
||||
# "over-merge").
|
||||
# _GRAPH_SEEDS : how many fused seeds to expand from.
|
||||
# _GRAPH_NEIGHBOURS_PER_SEED : cap neighbours pulled per seed (keeps the graph
|
||||
# leg from flooding the candidate pool).
|
||||
_CONCEPT_MIN_DF = 2
|
||||
_CONCEPT_MAX_DF_FRAC = 0.02
|
||||
_GRAPH_SEEDS = 10
|
||||
_GRAPH_NEIGHBOURS_PER_SEED = 25
|
||||
|
||||
# A small English stop-word set for the lightweight noun-phrase extraction. We
|
||||
# deliberately keep this tiny + dependency-free (no spaCy/NLTK download on the hot
|
||||
# path); the heavy lifting is done by the pre-computed ``expanded_keywords``.
|
||||
_STOPWORDS = frozenset(
|
||||
"""
|
||||
a an the of to in on at by for with from into over under and or but not is are
|
||||
was were be been being do does did has have had this that these those it its as
|
||||
if then than so such no yes can will would should could may might must i you he
|
||||
she they we me him her them us my your his their our about above after again all
|
||||
any because before below between both during each few more most other some only
|
||||
own same too very up down out off here there when where which who whom what how
|
||||
""".split()
|
||||
)
|
||||
_WORD_RE = re.compile(r"[A-Za-z][A-Za-z0-9_+.-]{2,}")
|
||||
|
||||
|
||||
def _normalise_concept(token: str) -> str:
|
||||
"""Lowercase, strip surrounding punctuation, light de-plural so concept
|
||||
variants collapse to one node (e.g. 'decisions'->'decision',
|
||||
'addresses'->'address', 'policies'->'policy'). This is a heuristic collapser,
|
||||
not a linguistically perfect stemmer — its only job is to merge obvious
|
||||
plural/singular pairs so the graph links them; exactness is not load-bearing.
|
||||
Order matters: -ies, then -sses, then sibilant -es, then a bare trailing -s."""
|
||||
t = token.lower().strip(".,;:!?()[]{}\"'`")
|
||||
if len(t) > 4 and t.endswith("ies"): # policies -> policy
|
||||
return t[:-3] + "y"
|
||||
if len(t) > 4 and t.endswith("sses"): # addresses -> address, classes -> class
|
||||
return t[:-2]
|
||||
if len(t) > 4 and t.endswith(("ches", "shes", "xes", "zes", "ses")): # boxes->box
|
||||
return t[:-2]
|
||||
if len(t) > 3 and t.endswith("s") and not t.endswith(("ss", "us", "is")): # tags->tag
|
||||
return t[:-1]
|
||||
return t
|
||||
|
||||
|
||||
def _concepts_for(memory: Memory) -> set[str]:
|
||||
"""Extract the concept set for one memory: tags ∪ expanded_keywords ∪ salient
|
||||
content tokens. ``expanded_keywords`` is already an LLM-generated keyword field
|
||||
in the corpus, so this is the design's 'tractable extraction' — we reuse the
|
||||
extraction that production already pays for instead of new LLM calls."""
|
||||
concepts: set[str] = set()
|
||||
# tags: comma-separated
|
||||
for tag in memory.tags.split(","):
|
||||
c = _normalise_concept(tag)
|
||||
if len(c) >= 3 and c not in _STOPWORDS:
|
||||
concepts.add(c)
|
||||
# expanded_keywords: space-separated, already curated
|
||||
for kw in memory.expanded_keywords.split():
|
||||
c = _normalise_concept(kw)
|
||||
if len(c) >= 3 and c not in _STOPWORDS:
|
||||
concepts.add(c)
|
||||
# salient content tokens (lightweight noun-phrase proxy: alpha tokens len>=3,
|
||||
# not stop-words). This is a cheap NER/noun-phrase stand-in per the design.
|
||||
for m in _WORD_RE.finditer(memory.content):
|
||||
c = _normalise_concept(m.group(0))
|
||||
if len(c) >= 3 and c not in _STOPWORDS:
|
||||
concepts.add(c)
|
||||
return concepts
|
||||
|
||||
|
||||
def _corpus_fingerprint(corpus: Sequence[Memory]) -> str:
|
||||
"""Stable hash over (id, content) so the embedding cache invalidates if the
|
||||
corpus changes but is reused across runs of the same corpus."""
|
||||
h = hashlib.sha256()
|
||||
for m in corpus:
|
||||
h.update(str(m.id).encode())
|
||||
h.update(b"\x00")
|
||||
h.update(m.content.encode("utf-8", "replace"))
|
||||
h.update(b"\x01")
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
class HybridRetriever:
|
||||
"""Lexical FTS + dense (bge-large-en-v1.5 / hosted) + concept-graph expansion,
|
||||
fused with RRF. Degrades to FTS(+graph) if embeddings are unavailable."""
|
||||
|
||||
#: Label surfaced in benchmark reports / the RUN schema.
|
||||
name = "hybrid"
|
||||
|
||||
# Per-leg RRF weights. Dense + lexical carry full weight; graph is a
|
||||
# down-weighted recall extender (design: possible negative graph prior).
|
||||
w_dense = 1.0
|
||||
w_fts = 1.0
|
||||
w_graph = 0.35
|
||||
|
||||
def __init__(self, model_name: str | None = None) -> None:
|
||||
self.errors: list[str] = []
|
||||
self.model_name = model_name or _LOCAL_MODEL
|
||||
self.embedding_backend: str = "none" # "local:<model>" | "hosted:<provider>:<model>"
|
||||
self.embedding_dim: int | None = None
|
||||
|
||||
# FTS leg (always available; pure stdlib sqlite).
|
||||
self._fts = FtsRetriever(sort_by="relevance")
|
||||
|
||||
# Dense leg state.
|
||||
self._model = None # SentenceTransformer or None
|
||||
self._np = None # numpy module handle (set on successful dense build)
|
||||
self._emb = None # (N, d) float32 L2-normalised matrix, row i ↔ self._ids[i]
|
||||
self._ids: list[MemoryId] = [] # row order of self._emb
|
||||
|
||||
# Graph leg state.
|
||||
self._graph = None # networkx.Graph or None
|
||||
self._concept_to_mems: dict[str, list[MemoryId]] = {}
|
||||
self._mem_concepts: dict[MemoryId, set[str]] = {}
|
||||
self._n_concepts_total = 0 # before df pruning, for reporting
|
||||
self._n_concepts_kept = 0
|
||||
self._n_edges = 0
|
||||
|
||||
self._corpus_size = 0
|
||||
|
||||
# ── lifecycle: build_index (timed by the runner) ─────────────────────────
|
||||
|
||||
def build_index(self, corpus: Sequence[Memory]) -> None:
|
||||
corpus = list(corpus)
|
||||
self._corpus_size = len(corpus)
|
||||
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1) lexical leg
|
||||
self._fts.build_index(corpus)
|
||||
|
||||
# 2) dense leg (graceful)
|
||||
try:
|
||||
self._build_dense(corpus)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
self.errors.append(f"dense leg disabled: {type(exc).__name__}: {exc}")
|
||||
self._model = None
|
||||
self._emb = None
|
||||
|
||||
# 3) graph leg (graceful)
|
||||
try:
|
||||
self._build_graph(corpus)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
self.errors.append(f"graph leg disabled: {type(exc).__name__}: {exc}")
|
||||
self._graph = None
|
||||
self._concept_to_mems = {}
|
||||
|
||||
# ── dense leg ────────────────────────────────────────────────────────────
|
||||
|
||||
def _select_embedding_backend(self) -> str:
|
||||
"""Pick the embedding backend per FINAL DESIGN: hosted only if a key is in
|
||||
the env (non-sensitive corpus already guaranteed by export), else local.
|
||||
Returns a human label and sets self.model_name accordingly."""
|
||||
if os.environ.get("VOYAGE_API_KEY"):
|
||||
self.model_name = "voyage-3.5"
|
||||
return "hosted:voyage:voyage-3.5"
|
||||
if os.environ.get("OPENAI_API_KEY"):
|
||||
self.model_name = "text-embedding-3-large"
|
||||
return "hosted:openai:text-embedding-3-large"
|
||||
if os.environ.get("CO_API_KEY"):
|
||||
self.model_name = "embed-english-v3.0"
|
||||
return "hosted:cohere:embed-english-v3.0"
|
||||
self.model_name = _LOCAL_MODEL
|
||||
return f"local:{_LOCAL_MODEL}"
|
||||
|
||||
    def _build_dense(self, corpus: Sequence[Memory]) -> None:
        import numpy as np  # required for the dense leg

        self._np = np
        self.embedding_backend = self._select_embedding_backend()
        self._ids = [m.id for m in corpus]
        fp = _corpus_fingerprint(corpus)
        safe_model = self.model_name.replace("/", "_")
        emb_path = _CACHE_DIR / f"emb_{safe_model}_{fp}.npy"
        ids_path = _CACHE_DIR / f"emb_{safe_model}_{fp}.ids.npy"

        # cache hit?
        if emb_path.exists() and ids_path.exists():
            cached_ids = np.load(ids_path)
            if list(cached_ids.tolist()) == self._ids:
                self._emb = np.load(emb_path).astype(np.float32)
                self.embedding_dim = int(self._emb.shape[1])
                return  # cached embeddings reused

        # cache miss → embed
        if self.embedding_backend.startswith("hosted:"):
            vecs = self._embed_hosted([m.content for m in corpus])
        else:
            vecs = self._embed_local([m.content for m in corpus])
        vecs = vecs.astype(np.float32)
        # L2-normalise so dot product == cosine.
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        vecs = vecs / norms
        self._emb = vecs
        self.embedding_dim = int(vecs.shape[1])
        np.save(emb_path, vecs)
        np.save(ids_path, np.array(self._ids, dtype=np.int64))

    def _load_local_model(self):
        from sentence_transformers import SentenceTransformer

        if self._model is None:
            # CPU is fine for ~5.5k short docs; force CPU to avoid CUDA init noise.
            self._model = SentenceTransformer(_LOCAL_MODEL, device="cpu")
            # Median memory is ~120 tokens; cap the window at 384 so the rare long
            # memory (1.6% > 512 tok) doesn't pad an entire batch to 512. bge's
            # native max is 512; 384 keeps ~p99 intact while bounding CPU cost.
            self._model.max_seq_length = min(self._model.max_seq_length, 384)
        return self._model

    def _embed_local(self, texts: list[str]):
        import numpy as np

        model = self._load_local_model()
        # Length-sort so each batch pads to a homogeneous length (big CPU win), then
        # restore original order. Passages embedded raw; the caller L2-normalises so
        # the local and hosted paths stay byte-for-byte consistent downstream.
        order = sorted(range(len(texts)), key=lambda i: len(texts[i]))
        sorted_texts = [texts[i] for i in order]
        out = model.encode(
            sorted_texts,
            batch_size=64,
            convert_to_numpy=True,
            normalize_embeddings=False,
            show_progress_bar=False,
        )
        out = np.asarray(out)
        # invert the permutation
        restored = np.empty_like(out)
        restored[np.asarray(order)] = out
        return restored

    def _embed_query_local(self, query: str):
        import numpy as np

        model = self._load_local_model()
        out = model.encode(
            [_BGE_QUERY_INSTRUCTION + query],
            convert_to_numpy=True,
            normalize_embeddings=True,  # query L2-normalised → cosine via dot
            show_progress_bar=False,
        )
        return np.asarray(out)[0]

    def _embed_hosted(self, texts: list[str]):
        """Batch-embed passages via the selected hosted API. Implemented for
        Voyage / OpenAI / Cohere; only reached when the matching key is set."""
        import numpy as np

        backend = self.embedding_backend
        if backend.startswith("hosted:voyage"):
            import voyageai

            client = voyageai.Client()
            vecs: list[list[float]] = []
            for i in range(0, len(texts), 128):
                batch = texts[i : i + 128]
                r = client.embed(batch, model="voyage-3.5", input_type="document")
                vecs.extend(r.embeddings)
            return np.asarray(vecs)
        if backend.startswith("hosted:openai"):
            from openai import OpenAI

            client = OpenAI()
            vecs = []
            for i in range(0, len(texts), 256):
                batch = texts[i : i + 256]
                r = client.embeddings.create(model="text-embedding-3-large", input=batch)
                vecs.extend([d.embedding for d in r.data])
            return np.asarray(vecs)
        if backend.startswith("hosted:cohere"):
            import cohere

            client = cohere.Client()
            vecs = []
            for i in range(0, len(texts), 96):
                batch = texts[i : i + 96]
                r = client.embed(texts=batch, model="embed-english-v3.0", input_type="search_document")
                vecs.extend(r.embeddings)
            return np.asarray(vecs)
        raise RuntimeError(f"unknown hosted backend {backend!r}")

    def _embed_query_hosted(self, query: str):
        import numpy as np

        backend = self.embedding_backend
        if backend.startswith("hosted:voyage"):
            import voyageai

            client = voyageai.Client()
            r = client.embed([query], model="voyage-3.5", input_type="query")
            v = np.asarray(r.embeddings[0], dtype=np.float32)
        elif backend.startswith("hosted:openai"):
            from openai import OpenAI

            client = OpenAI()
            r = client.embeddings.create(model="text-embedding-3-large", input=[query])
            v = np.asarray(r.data[0].embedding, dtype=np.float32)
        elif backend.startswith("hosted:cohere"):
            import cohere

            client = cohere.Client()
            r = client.embed(texts=[query], model="embed-english-v3.0", input_type="search_query")
            v = np.asarray(r.embeddings[0], dtype=np.float32)
        else:
            raise RuntimeError(f"unknown hosted backend {backend!r}")
        n = np.linalg.norm(v)
        return v / n if n else v

    def _dense_rank(self, query: str, k: int) -> list[MemoryId]:
        """Top-k corpus ids by cosine similarity to the query embedding."""
        if self._emb is None or self._np is None:
            return []
        np = self._np
        kk = min(k, self._emb.shape[0])
        if kk <= 0:
            # k <= 0 or an empty index: argpartition below needs kk >= 1, and
            # there is no point embedding the query for an empty result.
            return []
        if self.embedding_backend.startswith("hosted:"):
            qv = self._embed_query_hosted(query)
        else:
            qv = self._embed_query_local(query)
        sims = self._emb @ qv  # (N,) cosine sims (both sides L2-normalised)
        # argpartition for the top-kk, then sort those by score desc.
        idx = np.argpartition(-sims, kk - 1)[:kk]
        idx = idx[np.argsort(-sims[idx])]
        return [self._ids[i] for i in idx]

    # ── graph leg ──────────────────────────────────────────────────────────

    def _build_graph(self, corpus: Sequence[Memory]) -> None:
        import networkx as nx

        n = len(corpus)
        max_df = max(_CONCEPT_MIN_DF, int(_CONCEPT_MAX_DF_FRAC * n))

        # concept → set(memory ids)
        concept_to_mems: dict[str, set[MemoryId]] = defaultdict(set)
        mem_concepts: dict[MemoryId, set[str]] = {}
        for m in corpus:
            cs = _concepts_for(m)
            mem_concepts[m.id] = cs
            for c in cs:
                concept_to_mems[c].add(m.id)
        self._n_concepts_total = len(concept_to_mems)

        # Keep only discriminative concepts: appear in [_CONCEPT_MIN_DF, max_df]
        # memories. Below MIN_DF links nothing; above max_df is a non-specific hub.
        kept: dict[str, list[MemoryId]] = {}
        for c, mems in concept_to_mems.items():
            df = len(mems)
            if _CONCEPT_MIN_DF <= df <= max_df:
                kept[c] = sorted(mems)
        self._n_concepts_kept = len(kept)
        self._concept_to_mems = kept
        # restrict each memory's concept set to kept concepts (for neighbour scoring)
        self._mem_concepts = {
            mid: {c for c in cs if c in kept} for mid, cs in mem_concepts.items()
        }

        # Build a weighted memory-node graph: edge weight = # shared kept concepts.
        # Edges are added as full per-concept cliques. Fan-out is bounded only by
        # the max_df pruning above, which keeps the O(df^2) clique cost in check on
        # the densest kept concepts (design risk: over-merge).
        g = nx.Graph()
        g.add_nodes_from(m.id for m in corpus)
        edge_w: dict[tuple[MemoryId, MemoryId], int] = defaultdict(int)
        for c, mems in kept.items():
            # mems is small (<= max_df) by construction; full clique is fine.
            for i in range(len(mems)):
                a = mems[i]
                for j in range(i + 1, len(mems)):
                    b = mems[j]
                    key = (a, b) if a < b else (b, a)
                    edge_w[key] += 1
        for (a, b), w in edge_w.items():
            g.add_edge(a, b, weight=w)
        self._n_edges = g.number_of_edges()
        self._graph = g

    def _graph_rank(self, seeds: list[MemoryId], exclude: set[MemoryId], k: int) -> list[MemoryId]:
        """From fused seeds, walk 1 hop and rank neighbour memories by accumulated
        edge weight (shared-concept strength), weighted by the seed's own rank so
        higher-confidence seeds pull harder. Returns up to k NEW ids (not in
        ``exclude``)."""
        if self._graph is None or not seeds:
            return []
        g = self._graph
        scores: dict[MemoryId, float] = defaultdict(float)
        for rank, s in enumerate(seeds[:_GRAPH_SEEDS], start=1):
            if s not in g:
                continue
            seed_w = 1.0 / rank  # earlier seeds contribute more
            nbrs = sorted(
                g[s].items(), key=lambda kv: kv[1].get("weight", 1), reverse=True
            )[:_GRAPH_NEIGHBOURS_PER_SEED]
            for nbr, data in nbrs:
                if nbr in exclude:
                    continue
                scores[nbr] += seed_w * float(data.get("weight", 1))
        ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
        return [mid for mid, _ in ranked[:k]]

    # ── fusion + retrieve ────────────────────────────────────────────────────

    @staticmethod
    def _rrf_accumulate(scores: dict[MemoryId, float], ranked: list[MemoryId], weight: float) -> None:
        for r, mid in enumerate(ranked, start=1):
            scores[mid] += weight / (_RRF_K + r)

    def retrieve(self, query: str, k: int) -> list[MemoryId]:
        """Fuse lexical + dense + graph-expansion via weighted RRF and return the
        top-k memory ids. Pulls each leg deeper than k so fusion has material to
        re-order, then truncates."""
        depth = max(k, 50)  # per-leg retrieval depth before fusion

        fts_ranked = self._fts.retrieve(query, depth)
        dense_ranked = self._dense_rank(query, depth)  # [] if dense disabled

        # Seeds for graph expansion: RRF of the two base legs (so the graph walks
        # from the best-agreed memories, not just one leg's view).
        seed_scores: dict[MemoryId, float] = defaultdict(float)
        self._rrf_accumulate(seed_scores, fts_ranked, self.w_fts)
        self._rrf_accumulate(seed_scores, dense_ranked, self.w_dense)
        seeds = [mid for mid, _ in sorted(seed_scores.items(), key=lambda kv: kv[1], reverse=True)]
        base_set = set(seeds)
        graph_ranked = self._graph_rank(seeds, exclude=base_set, k=depth)

        # Final weighted RRF over all three legs.
        scores: dict[MemoryId, float] = defaultdict(float)
        self._rrf_accumulate(scores, fts_ranked, self.w_fts)
        self._rrf_accumulate(scores, dense_ranked, self.w_dense)
        self._rrf_accumulate(scores, graph_ranked, self.w_graph)

        fused = sorted(scores.items(), key=lambda kv: (kv[1], -kv[0]), reverse=True)
        return [mid for mid, _ in fused[:k]]

    # ── reporting hooks ───────────────────────────────────────────────────────

    def index_size_bytes(self) -> int:
        """Sum of the dense matrix bytes + FTS index bytes (graph is in-memory
        networkx; we approximate it via node+edge count * a small constant). Non-
        gating per ADR-0001; reported for the storage column."""
        total = 0
        if self._emb is not None:
            total += int(self._emb.nbytes)
        try:
            total += self._fts.index_size_bytes()
        except Exception:
            pass
        if self._graph is not None:
            # rough: ~64 B/node + ~96 B/edge accounting for python object overhead.
            total += self._graph.number_of_nodes() * 64 + self._graph.number_of_edges() * 96
        return total

    def graph_stats(self) -> dict[str, int]:
        return {
            "nodes": self._graph.number_of_nodes() if self._graph is not None else 0,
            "edges": self._n_edges,
            "concepts_total": self._n_concepts_total,
            "concepts_kept": self._n_concepts_kept,
        }

    def close(self) -> None:
        self._fts.close()


# Convenience for `run_eval.py --retriever retrievers.hybrid:HybridRetriever`.
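The weighted RRF fusion used by `retrieve` above can be tried in isolation. The following is a minimal sketch, not the committed implementation: `RRF_K = 60`, the weights, and the rankings are hypothetical values chosen for illustration (the module's real `_RRF_K`, `w_fts`, `w_dense` and `w_graph` are defined outside this excerpt). It shows why a candidate that appears only in the graph leg collects a single RRF term and can therefore sit below every base-leg candidate when its weight is small.

```python
from collections import defaultdict

RRF_K = 60  # assumption: a conventional RRF constant, not necessarily the module's _RRF_K


def rrf_accumulate(scores, ranked, weight):
    """Add weight / (RRF_K + rank) to each id's score; ranks are 1-based."""
    for rank, doc_id in enumerate(ranked, start=1):
        scores[doc_id] += weight / (RRF_K + rank)


def fuse(legs, k):
    """legs is a list of (ranked_ids, weight) pairs. Returns the top-k ids by
    fused score, breaking ties by lower id (same sort key as retrieve())."""
    scores = defaultdict(float)
    for ranked, weight in legs:
        rrf_accumulate(scores, ranked, weight)
    fused = sorted(scores.items(), key=lambda kv: (kv[1], -kv[0]), reverse=True)
    return [doc_id for doc_id, _ in fused[:k]]


# Hypothetical rankings and weights, for illustration only.
fts = [1, 2, 3]
dense = [1, 4, 2]
graph = [9]  # a neighbour that neither base leg returned
top = fuse([(fts, 1.0), (dense, 1.0), (graph, 0.3)], k=4)
print(top)
```

With these made-up numbers the graph-only id 9 scores 0.3/61, below the weakest base-leg hit (1/63), so it is cut at `k=4`; raising the graph weight to 1.0 lifts it to 1/61 and into the list.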
204
benchmarks/retrievers/test_hybrid.py
Normal file
@@ -0,0 +1,204 @@
"""Unit tests for the HYBRID retriever's pure logic: concept normalisation, the
concept-graph build + 1-hop expansion, weighted RRF fusion, and graceful
degradation when the dense leg is unavailable.

These tests are MODEL-FREE on purpose: they never load sentence-transformers (a
~1.3 GB / multi-minute CPU load). The dense leg is exercised by monkeypatching the
ranking method, so the fusion + graph behaviour is verified deterministically and
fast. The full end-to-end quality run is done via scripts/run_eval.py against the
real (local, gitignored) corpus.

Run: .venv/bin/python -m pytest retrievers/test_hybrid.py -q
"""
from __future__ import annotations

import math

from harness.types import Memory
from retrievers.hybrid import (
    _RRF_K,
    HybridRetriever,
    _concepts_for,
    _normalise_concept,
)


# ---------------- concept normalisation ----------------

def test_normalise_concept_depluralisation():
    cases = {
        "Decisions": "decision",
        "policies": "policy",
        "addresses": "address",
        "boxes": "box",
        "tags": "tag",
        # invariants: don't over-strip
        "access": "access",
        "class": "class",
        "status": "status",
        "analysis": "analysis",
        "kubernetes": "kubernete",  # heuristic, acceptable (collapses consistently)
        "k8s": "k8s",
        "GPU": "gpu",
    }
    for inp, exp in cases.items():
        assert _normalise_concept(inp) == exp, f"{inp!r} -> {_normalise_concept(inp)!r}"


def test_normalise_concept_is_stable_under_repetition():
    # normalising an already-normalised token must be a no-op (idempotent), so the
    # graph collapses variants consistently no matter the source field.
    for tok in ["decision", "policy", "address", "tag", "gpu", "access"]:
        assert _normalise_concept(_normalise_concept(tok)) == _normalise_concept(tok)


def test_concepts_for_unions_tags_keywords_content():
    m = Memory(
        id=1,
        content="The Postgres cluster uses pgvector for embeddings.",
        tags="database,postgres",
        expanded_keywords="cnpg vector search",
    )
    cs = _concepts_for(m)
    # from tags (note: 'postgres' de-plurals to 'postgre', a consistent heuristic
    # collapse; what matters is every memory mentioning it lands on the SAME node).
    assert "database" in cs and "postgre" in cs
    # from expanded_keywords
    assert "cnpg" in cs and "vector" in cs and "search" in cs
    # from content (salient tokens, stop-words removed)
    assert "pgvector" in cs and "embedding" in cs  # 'embeddings' -> 'embedding'
    assert "the" not in cs and "for" not in cs  # stop-words excluded


# ---------------- graph build + expansion ----------------

def _shared_concept_corpus() -> list[Memory]:
    # Three memories share concept "alpha" (df=3); two share "beta" (df=2); "gamma"
    # is unique (df=1, links nothing). With min_df=2 and a generous max_df, alpha
    # and beta both form edges.
    return [
        Memory(id=10, content="alpha topic one", tags="alpha", expanded_keywords="beta"),
        Memory(id=20, content="alpha topic two", tags="alpha", expanded_keywords="beta"),
        Memory(id=30, content="alpha topic three", tags="alpha", expanded_keywords="gamma"),
        Memory(id=40, content="unrelated delta", tags="delta", expanded_keywords="delta"),
    ]


def test_graph_build_links_shared_concepts():
    r = HybridRetriever()
    # widen max_df so small-corpus concepts aren't pruned as "hubs"
    import retrievers.hybrid as H

    old = H._CONCEPT_MAX_DF_FRAC
    H._CONCEPT_MAX_DF_FRAC = 1.0
    try:
        r._build_graph(_shared_concept_corpus())
    finally:
        H._CONCEPT_MAX_DF_FRAC = old

    g = r._graph
    assert g is not None
    # alpha links 10-20-30 (a triangle); beta links 10-20; "topic" links 10-20-30
    # too (shared content token). So the triangle exists and 10-20 is the heaviest
    # edge (they additionally share 'beta').
    assert g.has_edge(10, 20)
    assert g.has_edge(10, 30)
    assert g.has_edge(20, 30)
    # 10-20 share alpha + beta + topic (=3); 10-30 share alpha + topic (=2). The
    # exact counts aren't load-bearing; the INVARIANT is w(10,20) > w(10,30).
    assert g[10][20]["weight"] > g[10][30]["weight"]
    # the unrelated memory 40 (concept 'delta', df=1) links nothing.
    assert g.degree(40) == 0
    stats = r.graph_stats()
    assert stats["nodes"] == 4 and stats["edges"] >= 3


def test_graph_rank_expands_from_seeds_by_weight():
    r = HybridRetriever()
    import retrievers.hybrid as H

    old = H._CONCEPT_MAX_DF_FRAC
    H._CONCEPT_MAX_DF_FRAC = 1.0
    try:
        r._build_graph(_shared_concept_corpus())
    finally:
        H._CONCEPT_MAX_DF_FRAC = old

    # Seed from memory 10; neighbours 20 and 30 should both surface, with 20
    # ranked above 30 because the 10-20 edge is heavier (it additionally carries
    # the shared 'beta' concept). The exact weights aren't load-bearing here.
    nbrs = r._graph_rank([10], exclude={10}, k=10)
    assert nbrs[:2] == [20, 30]
    # excluded seeds are never returned
    assert 10 not in nbrs


def test_graph_rank_empty_without_graph_or_seeds():
    r = HybridRetriever()  # no graph built
    assert r._graph_rank([1, 2], exclude=set(), k=5) == []
    r._graph = object.__new__(type("G", (), {}))  # truthy but unused
    assert r._graph_rank([], exclude=set(), k=5) == []  # no seeds


# ---------------- RRF fusion ----------------

def test_rrf_accumulate_formula():
    from collections import defaultdict

    scores: dict[int, float] = defaultdict(float)
    HybridRetriever._rrf_accumulate(scores, [7, 8, 9], weight=1.0)
    assert math.isclose(scores[7], 1.0 / (_RRF_K + 1))
    assert math.isclose(scores[8], 1.0 / (_RRF_K + 2))
    assert math.isclose(scores[9], 1.0 / (_RRF_K + 3))
    # a second weighted list adds on top
    HybridRetriever._rrf_accumulate(scores, [8], weight=0.5)
    assert math.isclose(scores[8], 1.0 / (_RRF_K + 2) + 0.5 / (_RRF_K + 1))


def test_retrieve_fuses_all_three_legs_and_degrades():
    """End-to-end fusion with the dense leg STUBBED (no model). Verifies that FTS +
    dense agreement floats a doc to the top and that doc 2 is also returned. Doc 2
    contains every query token, so the lexical leg can return it on its own; this
    test therefore does not isolate the graph leg's contribution. Dense-disabled
    degradation is covered by test_graceful_degradation_records_error."""
    corpus = [
        Memory(id=1, content="alpha shared concept", tags="alpha", expanded_keywords="alpha"),
        Memory(id=2, content="alpha shared concept too", tags="alpha", expanded_keywords="alpha"),
        Memory(id=3, content="beta unrelated", tags="beta", expanded_keywords="beta"),
    ]
    import retrievers.hybrid as H

    old = H._CONCEPT_MAX_DF_FRAC
    H._CONCEPT_MAX_DF_FRAC = 1.0
    try:
        r = HybridRetriever()
        # Stub the dense BUILD so the test never loads the ~1.3 GB model nor writes
        # to the shared cache/ dir; build_index then only does FTS + graph.
        r._build_dense = lambda _c: None  # type: ignore[method-assign]
        r.build_index(corpus)  # FTS + graph build only
        # Stub the dense RANKER deterministically to "agree" with FTS on doc 1.
        r._dense_rank = lambda q, k: [1]  # type: ignore[method-assign]

        # query matching doc 1 lexically; doc 2 shares concept 'alpha' with doc 1
        # (graph neighbour) and also matches the query lexically.
        out = r.retrieve("alpha shared concept", k=3)
        assert out, "should return something"
        assert out[0] == 1  # FTS+dense agreement puts doc 1 first
        assert 2 in out  # doc 2 is returned (lexical match; also a graph neighbour)
    finally:
        H._CONCEPT_MAX_DF_FRAC = old


def test_graceful_degradation_records_error(monkeypatch):
    """If the dense build raises, the retriever records it and still serves FTS."""
    corpus = [Memory(id=i, content=f"doc number {i} content", tags="t") for i in range(1, 6)]
    r = HybridRetriever()

    def boom(_corpus):
        raise RuntimeError("simulated embedding failure")

    monkeypatch.setattr(r, "_build_dense", boom)
    r.build_index(corpus)
    assert any("dense leg disabled" in e for e in r.errors)
    assert r._emb is None
    # FTS still answers
    out = r.retrieve("doc number 3 content", k=5)
    assert 3 in out
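The edge-weight invariant asserted in `test_graph_build_links_shared_concepts` can be worked through by hand. This is a minimal sketch under stated assumptions: the concept sets below are hypothetical stand-ins for what `_concepts_for` would return for the four synthetic memories (the real normaliser and stop-word list are not shown in this excerpt), and `build_edge_weights` is an illustrative helper, not a function from `retrievers.hybrid`.

```python
from collections import defaultdict
from itertools import combinations

# Hypothetical concept sets for the four synthetic memories; assumed values,
# not produced by the real _concepts_for.
mem_concepts = {
    10: {"alpha", "topic", "beta"},
    20: {"alpha", "topic", "beta"},
    30: {"alpha", "topic", "gamma"},
    40: {"delta"},
}
MIN_DF = 2  # assumption mirroring the min_df=2 mentioned in the test comments


def build_edge_weights(mem_concepts, min_df, max_df):
    """Edge weight between two memories = number of kept concepts they share."""
    concept_to_mems = defaultdict(set)
    for mem_id, concepts in mem_concepts.items():
        for concept in concepts:
            concept_to_mems[concept].add(mem_id)
    weights = defaultdict(int)
    for mems in concept_to_mems.values():
        if min_df <= len(mems) <= max_df:  # drop singletons and over-common hubs
            for a, b in combinations(sorted(mems), 2):
                weights[(a, b)] += 1
    return dict(weights)


edges = build_edge_weights(mem_concepts, MIN_DF, max_df=len(mem_concepts))
print(edges)
```

Under these assumed sets, "alpha" and "topic" each contribute one unit to all three edges of the 10-20-30 triangle and "beta" adds one more to 10-20, while "gamma" and "delta" (df=1) are pruned, so memory 40 stays isolated.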
49
benchmarks/scripts/dataset_stats.py
Normal file
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""Validate the eval set and print AGGREGATE stats (safe to share / commit-able
numbers only — prints NO raw memory content)."""
from __future__ import annotations

import json
import statistics
import sys
from collections import Counter
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from harness import load_dataset  # noqa: E402


def main() -> None:
    ds = load_dataset(validate=True)  # raises on any referential-integrity issue

    strata = Counter(q.stratum for q in ds.queries)
    rel_per_q = {s: [] for s in strata}
    for q in ds.queries:
        rel_per_q[q.stratum].append(len(ds.qrels[q.query_id]))

    # how many DISTINCT corpus memories are exercised as relevant
    relevant_union = set()
    for rels in ds.qrels.values():
        relevant_union |= rels

    out = {
        "corpus_count": len(ds.corpus),
        "query_count": len(ds.queries),
        "strata": dict(strata),
        "relevant_ids_per_query": {
            s: {
                "min": min(v),
                "median": statistics.median(v),
                "max": max(v),
                "mean": round(statistics.fmean(v), 2),
            }
            for s, v in rel_per_q.items()
        },
        "distinct_relevant_memories": len(relevant_union),
        "validation": "PASS (all qrels ids exist in corpus; every query has qrels)",
    }
    print(json.dumps(out, indent=2))


if __name__ == "__main__":
    main()
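The per-stratum aggregation in `main` above can be tried without the private dataset. A minimal sketch on synthetic data: the stratum names and counts below are made up for illustration and do not come from the real eval set, and the tuples stand in for the harness's query objects.

```python
import statistics
from collections import Counter

# Synthetic (stratum, number-of-relevant-ids) pairs; hypothetical values only.
queries = [
    ("paraphrase", 1),
    ("paraphrase", 3),
    ("multi_hop", 2),
    ("multi_hop", 4),
    ("multi_hop", 6),
]

strata = Counter(stratum for stratum, _ in queries)
rel_per_stratum = {s: [] for s in strata}
for stratum, n_relevant in queries:
    rel_per_stratum[stratum].append(n_relevant)

# Same four aggregates the script reports; no raw content is involved.
summary = {
    s: {
        "min": min(v),
        "median": statistics.median(v),
        "max": max(v),
        "mean": round(statistics.fmean(v), 2),
    }
    for s, v in rel_per_stratum.items()
}
print(summary)
```

Only counts and aggregates appear in the output, which is what makes this kind of report safe to share.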
78
benchmarks/scripts/export_corpus.py
Normal file
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""Export the local SQLite memory cache to a LOCAL-ONLY corpus.jsonl.

Privacy: emits ONLY rows where is_sensitive=0. The output file lives under
benchmarks/data/ which is gitignored. NEVER commit corpus.jsonl.

Fields emitted per line: {id, content, category, tags, expanded_keywords, importance}
"""
from __future__ import annotations

import argparse
import json
import sqlite3
import sys
from pathlib import Path

DEFAULT_DB = Path.home() / ".claude" / "claude-memory" / "memory" / "memory.db"
DEFAULT_OUT = Path(__file__).resolve().parents[1] / "data" / "corpus.jsonl"


def export(db_path: Path, out_path: Path) -> dict:
    if not db_path.exists():
        raise SystemExit(f"DB not found: {db_path}")

    con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    con.row_factory = sqlite3.Row
    cur = con.cursor()

    total = cur.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
    sensitive = cur.execute(
        "SELECT COUNT(*) FROM memories WHERE is_sensitive=1"
    ).fetchone()[0]

    rows = cur.execute(
        """
        SELECT id, content, category, tags, expanded_keywords, importance
        FROM memories
        WHERE is_sensitive=0
        ORDER BY id
        """
    ).fetchall()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with out_path.open("w", encoding="utf-8") as f:
        for r in rows:
            rec = {
                "id": r["id"],
                "content": r["content"],
                "category": r["category"],
                "tags": r["tags"],
                "expanded_keywords": r["expanded_keywords"],
                "importance": r["importance"],
            }
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            written += 1
    con.close()

    return {
        "total_rows": total,
        "sensitive_excluded": sensitive,
        "non_sensitive_written": written,
        "out_path": str(out_path),
    }


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--db", type=Path, default=DEFAULT_DB)
    ap.add_argument("--out", type=Path, default=DEFAULT_OUT)
    args = ap.parse_args()
    stats = export(args.db, args.out)
    json.dump(stats, sys.stdout, indent=2)
    print()


if __name__ == "__main__":
    main()
65
benchmarks/scripts/run_eval.py
Normal file
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""Run the benchmark for a named retriever and print overall + per-stratum metrics.

Usage:
  .venv/bin/python scripts/run_eval.py --retriever fts5            # lexical baseline
  .venv/bin/python scripts/run_eval.py --retriever substring       # demo
  .venv/bin/python scripts/run_eval.py --retriever mypkg.mymod:MyRetriever
  .venv/bin/python scripts/run_eval.py --retriever fts5 --json results/fts5.json

The --retriever value is either a built-in alias or a "module:Class" path. The
class is instantiated with no args; the runner calls build_index() if present.

Outputs are LOCAL-ONLY when written under results/ (gitignored): a results file
may echo retrieved ids (not content), but keep it local to be safe.
"""
from __future__ import annotations

import argparse
import importlib
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from harness import load_dataset, run_benchmark  # noqa: E402
from harness.baselines import SqliteFtsRetriever  # noqa: E402
from harness.example_retriever import SubstringRetriever  # noqa: E402

ALIASES = {
    "fts5": lambda: SqliteFtsRetriever(sort_by="relevance"),
    "fts5_importance": lambda: SqliteFtsRetriever(sort_by="importance"),
    "substring": SubstringRetriever,
}


def resolve(spec: str):
    if spec in ALIASES:
        return ALIASES[spec]()
    if ":" not in spec:
        raise SystemExit(f"unknown retriever alias '{spec}' (use module:Class or one of {list(ALIASES)})")
    mod_name, cls_name = spec.split(":", 1)
    mod = importlib.import_module(mod_name)
    return getattr(mod, cls_name)()


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--retriever", default="fts5")
    ap.add_argument("--k", type=int, default=20, help="depth requested from retriever")
    ap.add_argument("--json", type=Path, default=None, help="write full result JSON here")
    args = ap.parse_args()

    ds = load_dataset(validate=True)
    retr = resolve(args.retriever)
    res = run_benchmark(retr, ds, retrieve_k=args.k)
    print(res.summary())

    if args.json:
        args.json.parent.mkdir(parents=True, exist_ok=True)
        args.json.write_text(json.dumps(res.to_dict(), indent=2))
        print(f"\nwrote {args.json}")


if __name__ == "__main__":
    main()
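The alias-or-`module:Class` lookup described in the script's docstring can be tried without the harness installed. A minimal sketch: the `aliases` table and the standard-library targets below are stand-ins chosen so the example runs anywhere; they are not the script's real `ALIASES`.

```python
import importlib


def resolve(spec: str, aliases: dict):
    """Return an instance for a built-in alias or a 'module:Class' path."""
    if spec in aliases:
        return aliases[spec]()
    if ":" not in spec:
        raise SystemExit(f"unknown retriever alias {spec!r}")
    mod_name, cls_name = spec.split(":", 1)
    module = importlib.import_module(mod_name)
    # Instantiated with no arguments, as the runner does.
    return getattr(module, cls_name)()


# Stand-ins from the standard library so the sketch runs without the harness.
aliases = {"empty_list": list}
from_alias = resolve("empty_list", aliases)
from_path = resolve("collections:OrderedDict", aliases)
print(type(from_alias).__name__, type(from_path).__name__)
```

A misspelled module surfaces as `ModuleNotFoundError` and a misspelled class as `AttributeError`; only a spec with no colon that is not an alias reaches the `SystemExit` branch.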