research: benchmark hybrid (lexical+dense+graph) recall vs current FTS
Some checks failed
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build (push) Has been cancelled
Build and Push / deploy (push) Has been cancelled
Build and Push / notify-failure (push) Has been cancelled

Viktor asked to enhance the memory system with 'semantics' — remember concepts
(not just tokens) linked in a graph — and to prove, by benchmarking against the
current system, that it actually improves recall. A multi-phase research workflow
(18 agents) did landscape research, an adversarially-reviewed integration design,
a stratified eval set over the real 5,452-memory corpus, and a head-to-head
prototype-vs-current benchmark.

Result: hybrid (lexical FTS + dense embeddings, RRF-fused) beats FTS on every
overall metric, driven by a robust paraphrase win (recall@10 +0.350). Recommend
adopting lexical+dense; the concept graph is DEFERRED.

Post-run adversarial review correction (applied to all docs before commit): the
prototype's fusion config structurally barred the graph leg from the ranked top-k,
so the 'graph contributes nothing' ablation was a math artifact, NOT an empirical
result — the graph is UNEVALUATED, not disproven (deferred on cost+uncertainty).
Multi-hop deltas are not statistically significant. Glossary in CONTEXT.md; framing
in ADR-0001-0003; findings in ADR-0004-0006 + docs/research/.

Privacy: the corpus/queries/qrels/results are the user's real memories and stay
gitignored (data/, cache/, results/, build_eval_set.py); only harness code,
aggregate numbers, and synthetic examples are committed.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Viktor Barzin 2026-06-25 17:51:53 +00:00
parent 7439540f8f
commit 1cc8a2b378
23 changed files with 3428 additions and 0 deletions

View file

@ -0,0 +1,28 @@
"""Benchmark harness for claude-memory recall evaluation.
Public API:
from harness import Retriever, load_dataset, run_benchmark, BenchmarkResult
from harness import metrics
A retriever is any object (or callable) implementing:
retrieve(query: str, k: int) -> list[memory_id] # ranked, best first
memory_id matches the `id` field in corpus.jsonl / qrels.jsonl (int).
"""
from .types import Retriever, Query, Memory, Qrels
from .dataset import load_dataset, Dataset
from .runner import run_benchmark, BenchmarkResult, StratumResult
from . import metrics
__all__ = [
"Retriever",
"Query",
"Memory",
"Qrels",
"load_dataset",
"Dataset",
"run_benchmark",
"BenchmarkResult",
"StratumResult",
"metrics",
]

View file

@ -0,0 +1,93 @@
"""Reference LEXICAL baseline retrievers that mirror the production system.
These exist so (a) the eval-set author can VERIFY a query's labels and check
that paraphrase queries genuinely defeat lexical matching, and (b) later agents
have an honest "current system" to beat.
`SqliteFtsRetriever` builds an in-memory SQLite FTS5 index over the corpus and
runs the SAME query shape the production local store uses:
words -> '"w1" OR "w2" ...' MATCH, ORDER BY bm25(), importance as tiebreak.
(README "SQLite: FTS5 with BM25".) This is the closest faithful, dependency-free
baseline. The Postgres tsvector path is documented in the README; its ranking
differs (weighted A/B/C/D + importance-first default) but for a quality ceiling
comparison the FTS5/BM25 relevance ordering is the right lexical reference.
"""
from __future__ import annotations
import re
import sqlite3
from collections.abc import Sequence
from .types import Memory, MemoryId
# FTS5 reserved-ish tokens; we quote every term anyway, but strip embedded quotes.
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
class SqliteFtsRetriever:
"""Faithful FTS5/BM25 lexical baseline (mirrors local_store search)."""
name = "sqlite_fts5_bm25"
def __init__(self, sort_by: str = "relevance") -> None:
# "relevance": ORDER BY bm25(), importance DESC (best for quality eval)
# "importance": ORDER BY importance DESC, ... (production default)
self.sort_by = sort_by
self._con: sqlite3.Connection | None = None
def build_index(self, corpus: Sequence[Memory]) -> None:
con = sqlite3.connect(":memory:")
con.execute(
"""
CREATE VIRTUAL TABLE memories_fts USING fts5(
content, category, tags, expanded_keywords,
memory_id UNINDEXED, importance UNINDEXED
)
"""
)
con.executemany(
"INSERT INTO memories_fts(content, category, tags, expanded_keywords, memory_id, importance)"
" VALUES (?,?,?,?,?,?)",
[
(m.content, m.category, m.tags, m.expanded_keywords, m.id, m.importance)
for m in corpus
],
)
con.commit()
self._con = con
def _fts_query(self, query: str) -> str:
words = _WORD_RE.findall(query.lower())
if not words:
return ""
return " OR ".join(f'"{w}"' for w in words)
def retrieve(self, query: str, k: int) -> list[MemoryId]:
assert self._con is not None, "call build_index first"
match = self._fts_query(query)
if not match:
return []
if self.sort_by == "importance":
order = "importance DESC, bm25(memories_fts)"
else:
order = "bm25(memories_fts), importance DESC"
try:
rows = self._con.execute(
f"SELECT memory_id FROM memories_fts WHERE memories_fts MATCH ? "
f"ORDER BY {order} LIMIT ?",
(match, k),
).fetchall()
except sqlite3.OperationalError:
# mirror production LIKE fallback on FTS syntax errors
like = f"%{query}%"
rows = self._con.execute(
"SELECT memory_id FROM memories_fts WHERE content LIKE ? OR tags LIKE ? "
"ORDER BY importance DESC LIMIT ?",
(like, like, k),
).fetchall()
return [r[0] for r in rows]
def close(self) -> None:
if self._con is not None:
self._con.close()
self._con = None

View file

@ -0,0 +1,115 @@
"""Load corpus / queries / qrels JSONL into typed objects."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from .types import Memory, Query, Qrels, MemoryId
_DATA_DIR = Path(__file__).resolve().parents[1] / "data"
@dataclass
class Dataset:
corpus: list[Memory]
queries: list[Query]
qrels: Qrels
@property
def corpus_by_id(self) -> dict[MemoryId, Memory]:
return {m.id: m for m in self.corpus}
def strata(self) -> set[str]:
return {q.stratum for q in self.queries}
def _read_jsonl(path: Path) -> list[dict]:
out: list[dict] = []
with path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
out.append(json.loads(line))
return out
def load_corpus(path: Path | None = None) -> list[Memory]:
path = path or (_DATA_DIR / "corpus.jsonl")
rows = _read_jsonl(path)
return [
Memory(
id=r["id"],
content=r["content"],
category=r.get("category", "facts"),
tags=r.get("tags", "") or "",
expanded_keywords=r.get("expanded_keywords", "") or "",
importance=r.get("importance", 0.5),
)
for r in rows
]
def load_queries(path: Path | None = None) -> list[Query]:
path = path or (_DATA_DIR / "queries.jsonl")
rows = _read_jsonl(path)
return [
Query(
query_id=r["query_id"],
text=r["text"],
stratum=r["stratum"],
relevant_ids=tuple(r.get("relevant_ids", [])),
)
for r in rows
]
def load_qrels(path: Path | None = None) -> Qrels:
path = path or (_DATA_DIR / "qrels.jsonl")
rows = _read_jsonl(path)
qrels: Qrels = {}
for r in rows:
qid = r["query_id"]
rel = set(r["relevant_ids"])
qrels.setdefault(qid, set()).update(rel)
return qrels
def load_dataset(
corpus_path: Path | None = None,
queries_path: Path | None = None,
qrels_path: Path | None = None,
*,
validate: bool = True,
) -> Dataset:
corpus = load_corpus(corpus_path)
queries = load_queries(queries_path)
qrels = load_qrels(qrels_path)
if validate:
_validate(corpus, queries, qrels)
return Dataset(corpus=corpus, queries=queries, qrels=qrels)
def _validate(corpus: list[Memory], queries: list[Query], qrels: Qrels) -> None:
corpus_ids = {m.id for m in corpus}
q_ids = {q.query_id for q in queries}
# Every query must have a qrels entry, and vice versa.
missing_qrels = q_ids - set(qrels)
if missing_qrels:
raise ValueError(f"queries without qrels: {sorted(missing_qrels)[:10]}")
orphan_qrels = set(qrels) - q_ids
if orphan_qrels:
raise ValueError(f"qrels without queries: {sorted(orphan_qrels)[:10]}")
# Every relevant id must exist in the corpus and the set must be non-empty.
for qid, rels in qrels.items():
if not rels:
raise ValueError(f"empty qrels for query {qid}")
unknown = rels - corpus_ids
if unknown:
raise ValueError(
f"query {qid} references non-corpus ids {sorted(unknown)[:10]}"
)

View file

@ -0,0 +1,59 @@
"""Worked example: how a later agent plugs a retriever into the harness.
A retriever needs only one method:
retrieve(self, query: str, k: int) -> list[int] # ranked memory ids
Optionally it may implement lifecycle hooks the runner will use if present:
build_index(self, corpus: list[Memory]) -> None # timed separately
index_size_bytes(self) -> int # reported
Run this file directly for a smoke test against the local eval set:
.venv/bin/python -m harness.example_retriever
"""
from __future__ import annotations
from collections.abc import Sequence
from .types import Memory, MemoryId
class SubstringRetriever:
"""Trivial baseline: rank by count of query-word occurrences in content.
Deliberately weak exists only to demonstrate the interface. The real
lexical baseline is harness.baselines.SqliteFtsRetriever.
"""
name = "substring_demo"
def __init__(self) -> None:
self._corpus: list[Memory] = []
def build_index(self, corpus: Sequence[Memory]) -> None:
self._corpus = list(corpus)
def retrieve(self, query: str, k: int) -> list[MemoryId]:
words = [w for w in query.lower().split() if len(w) > 2]
scored: list[tuple[int, float]] = []
for m in self._corpus:
hay = (m.content + " " + m.expanded_keywords + " " + m.tags).lower()
score = sum(hay.count(w) for w in words)
if score:
scored.append((m.id, score + m.importance)) # importance tiebreak
scored.sort(key=lambda t: t[1], reverse=True)
return [mid for mid, _ in scored[:k]]
def _smoke() -> None:
from .dataset import load_dataset
from .runner import run_benchmark
ds = load_dataset()
res = run_benchmark(SubstringRetriever(), ds)
print(res.summary())
if __name__ == "__main__":
_smoke()

View file

@ -0,0 +1,100 @@
"""Retrieval metrics with BINARY relevance.
Conventions
-----------
- `ranked`: list of memory ids, best-first, as returned by a retriever.
- `relevant`: set of relevant memory ids for the query (from qrels).
- All functions are pure and operate on a single query; the runner aggregates
(macro-average over queries).
Definitions
-----------
recall@k = |relevant ranked[:k]| / |relevant|
(fraction of all relevant items retrieved within the top k)
MRR = 1 / rank_of_first_relevant (0 if none retrieved at all)
nDCG@k = DCG@k / IDCG@k with binary gains (gain=1 for relevant)
DCG@k = sum over i in [1..k] of rel_i / log2(i + 1)
IDCG@k is the DCG of the ideal ranking (all relevant first),
capped at min(|relevant|, k) ones.
Notes
-----
- nDCG uses the standard log2(rank+1) discount (Järvelin & Kekäläinen 2002);
with binary gains this is the common IR convention also used by BEIR/pytrec_eval.
- MRR is reported as the reciprocal rank of the FIRST relevant hit, which for a
single query equals the per-query reciprocal-rank that the runner averages.
- Duplicate ids in `ranked` are de-duplicated keeping first occurrence, so a
retriever cannot inflate recall by repeating an id.
"""
from __future__ import annotations
import math
from collections.abc import Iterable, Sequence
MemoryId = int
def _dedup_keep_order(ranked: Sequence[MemoryId]) -> list[MemoryId]:
seen: set[MemoryId] = set()
out: list[MemoryId] = []
for x in ranked:
if x not in seen:
seen.add(x)
out.append(x)
return out
def recall_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
rel = set(relevant)
if not rel:
# Undefined; treat as 0 contribution. Runner should never pass empty.
return 0.0
top = _dedup_keep_order(ranked)[:k]
hits = sum(1 for x in top if x in rel)
return hits / len(rel)
def reciprocal_rank(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId]) -> float:
rel = set(relevant)
if not rel:
return 0.0
for i, x in enumerate(_dedup_keep_order(ranked), start=1):
if x in rel:
return 1.0 / i
return 0.0
def dcg_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
rel = set(relevant)
top = _dedup_keep_order(ranked)[:k]
dcg = 0.0
for i, x in enumerate(top, start=1):
if x in rel:
dcg += 1.0 / math.log2(i + 1)
return dcg
def ndcg_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
rel = set(relevant)
if not rel:
return 0.0
dcg = dcg_at_k(ranked, rel, k)
ideal_hits = min(len(rel), k)
idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
if idcg == 0.0:
return 0.0
return dcg / idcg
def per_query_metrics(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId]) -> dict[str, float]:
"""All headline metrics for one query."""
rel = set(relevant)
return {
"recall@5": recall_at_k(ranked, rel, 5),
"recall@10": recall_at_k(ranked, rel, 10),
"ndcg@10": ndcg_at_k(ranked, rel, 10),
"mrr": reciprocal_rank(ranked, rel),
}
METRIC_NAMES = ("recall@5", "recall@10", "ndcg@10", "mrr")

View file

@ -0,0 +1,223 @@
"""Benchmark runner: drive a pluggable retriever over the eval set and report
overall + per-stratum quality metrics, plus per-query latency and (optional)
index build time / size.
Quality decides adoption (recall@k, nDCG@10, MRR). Latency and storage are
measured and reported but DO NOT gate the decision (ADR-0001 success metric).
"""
from __future__ import annotations
import statistics
import time
from collections.abc import Callable
from dataclasses import dataclass, field, asdict
from typing import Any
from . import metrics
from .dataset import Dataset
from .types import MemoryId, Query, Retriever
# A retriever may be the Protocol object or a bare callable retrieve(query, k).
RetrieverLike = Retriever | Callable[[str, int], list[MemoryId]]
# k used for the retrieve() call. We request enough depth to compute all
# metrics (max cutoff is 10) with headroom so ties past k=10 don't distort.
DEFAULT_RETRIEVE_K = 20
def _percentile(values: list[float], pct: float) -> float:
"""Linear-interpolation percentile (pct in [0,100]). Empty -> 0.0."""
if not values:
return 0.0
if len(values) == 1:
return values[0]
s = sorted(values)
rank = (pct / 100.0) * (len(s) - 1)
lo = int(rank)
hi = min(lo + 1, len(s) - 1)
frac = rank - lo
return s[lo] + (s[hi] - s[lo]) * frac
@dataclass
class StratumResult:
stratum: str
n_queries: int
metrics: dict[str, float] # macro-averaged metric -> value
@dataclass
class BenchmarkResult:
retriever_name: str
n_queries: int
retrieve_k: int
overall: dict[str, float]
per_stratum: dict[str, StratumResult]
latency_ms: dict[str, float] # mean / p50 / p95 / max
index_build_seconds: float | None = None
index_size_bytes: int | None = None
per_query: list[dict[str, Any]] = field(default_factory=list)
def to_dict(self) -> dict:
d = asdict(self)
d["per_stratum"] = {k: asdict(v) for k, v in self.per_stratum.items()}
return d
def summary(self) -> str:
lines = [
f"Retriever: {self.retriever_name}",
f"Queries: {self.n_queries} (retrieve_k={self.retrieve_k})",
]
if self.index_build_seconds is not None:
lines.append(f"Index build: {self.index_build_seconds:.3f}s")
if self.index_size_bytes is not None:
lines.append(f"Index size: {self.index_size_bytes / 1e6:.2f} MB")
lat = self.latency_ms
lines.append(
"Latency/query: "
f"p50={lat['p50']:.2f}ms p95={lat['p95']:.2f}ms "
f"mean={lat['mean']:.2f}ms max={lat['max']:.2f}ms"
)
cols = metrics.METRIC_NAMES
header = " ".join(f"{c:>10}" for c in cols)
lines.append("")
lines.append(f"{'stratum':<12}{'n':>5} {header}")
lines.append("-" * (19 + len(header)))
for name in ("overall", *sorted(self.per_stratum)):
if name == "overall":
m, n = self.overall, self.n_queries
else:
sr = self.per_stratum[name]
m, n = sr.metrics, sr.n_queries
row = " ".join(f"{m[c]:>10.4f}" for c in cols)
lines.append(f"{name:<12}{n:>5} {row}")
return "\n".join(lines)
def _get_retrieve_fn(retriever: RetrieverLike) -> Callable[[str, int], list[MemoryId]]:
if hasattr(retriever, "retrieve"):
return retriever.retrieve # type: ignore[attr-defined]
if callable(retriever):
return retriever
raise TypeError("retriever must implement retrieve(query, k) or be callable")
def _maybe_build_index(retriever: RetrieverLike, dataset: Dataset) -> tuple[float | None, int | None]:
"""Call optional lifecycle hooks if present (duck-typed).
- build_index(corpus) -> None : measured wall-clock build time.
- index_size_bytes() -> int : reported on-disk/in-memory index size.
Returns (build_seconds_or_None, size_bytes_or_None).
"""
build_seconds: float | None = None
size_bytes: int | None = None
build = getattr(retriever, "build_index", None)
if callable(build):
t0 = time.perf_counter()
build(dataset.corpus)
build_seconds = time.perf_counter() - t0
size_fn = getattr(retriever, "index_size_bytes", None)
if callable(size_fn):
try:
size_bytes = int(size_fn())
except Exception:
size_bytes = None
return build_seconds, size_bytes
def run_benchmark(
retriever: RetrieverLike,
dataset: Dataset,
*,
retrieve_k: int = DEFAULT_RETRIEVE_K,
retriever_name: str | None = None,
warmup: bool = True,
collect_per_query: bool = True,
) -> BenchmarkResult:
"""Evaluate `retriever` over `dataset`.
The retriever is asked for `retrieve_k` ids per query (>= max metric
cutoff of 10). Metrics are macro-averaged over queries, overall and per
stratum. Latency is measured around each retrieve() call only (index build
is timed separately via the optional build_index hook).
"""
name = retriever_name or getattr(retriever, "name", None) or type(retriever).__name__
retrieve = _get_retrieve_fn(retriever)
qrels = dataset.qrels
build_seconds, size_bytes = _maybe_build_index(retriever, dataset)
# Optional warmup (first call can pay import/JIT/connection costs that would
# skew p95). Excluded from latency stats. Uses the first query if any.
if warmup and dataset.queries:
try:
retrieve(dataset.queries[0].text, retrieve_k)
except Exception:
pass # warmup failures surface on the real call below
per_query_rows: list[dict[str, Any]] = []
latencies_ms: list[float] = []
# accumulate per-stratum metric sums for macro-average
strata: dict[str, dict[str, float]] = {}
strata_counts: dict[str, int] = {}
overall_sums = {m: 0.0 for m in metrics.METRIC_NAMES}
for q in dataset.queries:
rel = qrels[q.query_id]
t0 = time.perf_counter()
ranked = list(retrieve(q.text, retrieve_k))
dt_ms = (time.perf_counter() - t0) * 1000.0
latencies_ms.append(dt_ms)
m = metrics.per_query_metrics(ranked, rel)
for key, val in m.items():
overall_sums[key] += val
strata.setdefault(q.stratum, {mm: 0.0 for mm in metrics.METRIC_NAMES})
strata_counts[q.stratum] = strata_counts.get(q.stratum, 0) + 1
for key, val in m.items():
strata[q.stratum][key] += val
if collect_per_query:
per_query_rows.append(
{
"query_id": q.query_id,
"stratum": q.stratum,
"n_relevant": len(rel),
"latency_ms": round(dt_ms, 3),
"retrieved": ranked[:retrieve_k],
**{k: round(v, 6) for k, v in m.items()},
}
)
n = len(dataset.queries)
overall = {k: (overall_sums[k] / n if n else 0.0) for k in metrics.METRIC_NAMES}
per_stratum: dict[str, StratumResult] = {}
for s, sums in strata.items():
c = strata_counts[s]
per_stratum[s] = StratumResult(
stratum=s,
n_queries=c,
metrics={k: (sums[k] / c if c else 0.0) for k in metrics.METRIC_NAMES},
)
latency_stats = {
"mean": statistics.fmean(latencies_ms) if latencies_ms else 0.0,
"p50": _percentile(latencies_ms, 50),
"p95": _percentile(latencies_ms, 95),
"max": max(latencies_ms) if latencies_ms else 0.0,
}
return BenchmarkResult(
retriever_name=name,
n_queries=n,
retrieve_k=retrieve_k,
overall=overall,
per_stratum=per_stratum,
latency_ms=latency_stats,
index_build_seconds=build_seconds,
index_size_bytes=size_bytes,
per_query=per_query_rows,
)

View file

@ -0,0 +1,145 @@
"""Unit tests for metrics + runner. No real corpus needed (synthetic data).
Run: .venv/bin/python -m pytest harness/test_harness.py -q
"""
from __future__ import annotations
import math
from harness import metrics
from harness.dataset import Dataset
from harness.runner import run_benchmark, _percentile
from harness.types import Memory, Query
# ---------------- metrics ----------------
def test_recall_at_k_basic():
ranked = [9, 8, 3, 7, 1]
rel = {3, 1, 99} # 99 never retrieved
assert metrics.recall_at_k(ranked, rel, 5) == 2 / 3
assert metrics.recall_at_k(ranked, rel, 2) == 0.0 # neither in top2
assert metrics.recall_at_k(ranked, rel, 3) == 1 / 3 # only id 3 in top3
def test_recall_perfect_and_zero():
assert metrics.recall_at_k([1, 2, 3], {1, 2, 3}, 5) == 1.0
assert metrics.recall_at_k([4, 5, 6], {1, 2, 3}, 5) == 0.0
def test_reciprocal_rank():
assert metrics.reciprocal_rank([5, 4, 3], {3}) == 1 / 3
assert metrics.reciprocal_rank([3, 4, 5], {3}) == 1.0
assert metrics.reciprocal_rank([7, 8], {3}) == 0.0
# first relevant wins
assert metrics.reciprocal_rank([9, 3, 1], {1, 3}) == 1 / 2
def test_ndcg_perfect():
# all relevant at the top -> nDCG == 1
assert math.isclose(metrics.ndcg_at_k([1, 2, 3, 4], {1, 2, 3}, 10), 1.0)
def test_ndcg_known_value():
# single relevant doc at rank 2: DCG = 1/log2(3); IDCG = 1/log2(2)=1
ranked = [9, 1, 8]
val = metrics.ndcg_at_k(ranked, {1}, 10)
assert math.isclose(val, (1 / math.log2(3)) / 1.0)
def test_ndcg_two_relevant_suboptimal_order():
# relevant {1,2}; retrieved at ranks 1 and 3
ranked = [1, 9, 2]
dcg = 1 / math.log2(2) + 1 / math.log2(4) # ranks 1 and 3
idcg = 1 / math.log2(2) + 1 / math.log2(3) # ideal ranks 1 and 2
assert math.isclose(metrics.ndcg_at_k(ranked, {1, 2}, 10), dcg / idcg)
def test_dedup_does_not_inflate():
# repeating a relevant id must not increase recall beyond 1 hit's worth
ranked = [3, 3, 3, 3]
assert metrics.recall_at_k(ranked, {3, 7}, 5) == 0.5
assert metrics.reciprocal_rank(ranked, {3}) == 1.0
def test_empty_relevant_is_zero():
assert metrics.recall_at_k([1, 2], set(), 5) == 0.0
assert metrics.ndcg_at_k([1, 2], set(), 5) == 0.0
# ---------------- percentile ----------------
def test_percentile():
vals = [10, 20, 30, 40]
assert _percentile(vals, 50) == 25.0 # interpolated median
assert _percentile(vals, 0) == 10
assert _percentile(vals, 100) == 40
assert _percentile([5.0], 95) == 5.0
assert _percentile([], 50) == 0.0
# ---------------- runner ----------------
def _toy_dataset() -> Dataset:
corpus = [Memory(id=i, content=f"memory {i}") for i in range(1, 11)]
queries = [
Query("q_exact_1", "find 1", "exact", (1,)),
Query("q_para_1", "restate 5", "paraphrase", (5,)),
Query("q_multi_1", "join 3 and 4", "multihop", (3, 4)),
]
qrels = {"q_exact_1": {1}, "q_para_1": {5}, "q_multi_1": {3, 4}}
return Dataset(corpus=corpus, queries=queries, qrels=qrels)
class _PerfectRetriever:
"""Returns exactly the relevant ids first (oracle) — for runner plumbing."""
def __init__(self, qrels):
self._qrels = qrels
self._by_text = None
def build_index(self, corpus):
self._n = len(corpus)
def index_size_bytes(self):
return 1234
def retrieve(self, query, k):
# map query text back via the toy queries' known answers
mapping = {"find 1": [1], "restate 5": [5], "join 3 and 4": [3, 4]}
ids = mapping.get(query, [])
# pad with distractors
pad = [x for x in range(100, 100 + k)]
return (ids + pad)[:k]
def test_runner_perfect_retriever():
ds = _toy_dataset()
r = _PerfectRetriever(ds.qrels)
res = run_benchmark(r, ds, retriever_name="perfect")
assert res.n_queries == 3
assert math.isclose(res.overall["recall@10"], 1.0)
assert math.isclose(res.overall["mrr"], 1.0)
assert math.isclose(res.overall["ndcg@10"], 1.0)
# per-stratum present
assert set(res.per_stratum) == {"exact", "paraphrase", "multihop"}
assert res.per_stratum["multihop"].n_queries == 1
# lifecycle hooks captured
assert res.index_build_seconds is not None
assert res.index_size_bytes == 1234
# latency recorded
assert res.latency_ms["p95"] >= 0.0
def test_runner_callable_retriever_and_misses():
ds = _toy_dataset()
def retrieve(query, k): # always wrong
return [999][:k]
res = run_benchmark(retrieve, ds, retriever_name="bad", warmup=False)
assert res.overall["recall@10"] == 0.0
assert res.overall["mrr"] == 0.0
assert res.index_build_seconds is None # no hook on a bare callable
assert "perfect" not in res.summary()
assert "bad" in res.summary()

View file

@ -0,0 +1,53 @@
"""Core dataclasses and the pluggable Retriever protocol."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Protocol, runtime_checkable
MemoryId = int
@dataclass(frozen=True)
class Memory:
"""One corpus entry (mirrors corpus.jsonl)."""
id: MemoryId
content: str
category: str = "facts"
tags: str = ""
expanded_keywords: str = ""
importance: float = 0.5
@dataclass(frozen=True)
class Query:
"""One eval query (mirrors queries.jsonl)."""
query_id: str
text: str
stratum: str # "exact" | "paraphrase" | "multihop"
# convenience copy of relevant ids; authoritative source is Qrels
relevant_ids: tuple[MemoryId, ...] = field(default_factory=tuple)
# query_id -> set of relevant memory ids (binary relevance)
Qrels = dict[str, set[MemoryId]]
@runtime_checkable
class Retriever(Protocol):
"""Pluggable retriever contract.
Implementations rank corpus memories for a query and return the top-k
memory ids, best match first. The harness will call `retrieve` once per
query and compare against qrels.
Optional lifecycle hooks let a retriever build an index from the corpus
and report index build time / on-disk size; the runner uses them if
present (duck-typed), so a minimal retriever need only implement
`retrieve`.
"""
def retrieve(self, query: str, k: int) -> list[MemoryId]:
"""Return up to k memory ids, ranked best-first."""
...