research: benchmark hybrid (lexical+dense+graph) recall vs current FTS
Viktor asked to enhance the memory system with 'semantics' — remember concepts (not just tokens) linked in a graph — and to prove, by benchmarking against the current system, that it actually improves recall. A multi-phase research workflow (18 agents) did landscape research, an adversarially-reviewed integration design, a stratified eval set over the real 5,452-memory corpus, and a head-to-head prototype-vs-current benchmark. Result: hybrid (lexical FTS + dense embeddings, RRF-fused) beats FTS on every overall metric, driven by a robust paraphrase win (recall@10 +0.350). Recommend adopting lexical+dense; the concept graph is DEFERRED. Post-run adversarial review correction (applied to all docs before commit): the prototype's fusion config structurally barred the graph leg from the ranked top-k, so the 'graph contributes nothing' ablation was a math artifact, NOT an empirical result — the graph is UNEVALUATED, not disproven (deferred on cost+uncertainty). Multi-hop deltas are not statistically significant. Glossary in CONTEXT.md; framing in ADR-0001-0003; findings in ADR-0004-0006 + docs/research/. Privacy: the corpus/queries/qrels/results are the user's real memories and stay gitignored (data/, cache/, results/, build_eval_set.py); only harness code, aggregate numbers, and synthetic examples are committed. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
7439540f8f
commit
1cc8a2b378
23 changed files with 3428 additions and 0 deletions
28
benchmarks/harness/__init__.py
Normal file
28
benchmarks/harness/__init__.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Benchmark harness for claude-memory recall evaluation.
|
||||
|
||||
Public API:
|
||||
from harness import Retriever, load_dataset, run_benchmark, BenchmarkResult
|
||||
from harness import metrics
|
||||
|
||||
A retriever is any object (or callable) implementing:
|
||||
retrieve(query: str, k: int) -> list[memory_id] # ranked, best first
|
||||
|
||||
memory_id matches the `id` field in corpus.jsonl / qrels.jsonl (int).
|
||||
"""
|
||||
from .types import Retriever, Query, Memory, Qrels
|
||||
from .dataset import load_dataset, Dataset
|
||||
from .runner import run_benchmark, BenchmarkResult, StratumResult
|
||||
from . import metrics
|
||||
|
||||
__all__ = [
|
||||
"Retriever",
|
||||
"Query",
|
||||
"Memory",
|
||||
"Qrels",
|
||||
"load_dataset",
|
||||
"Dataset",
|
||||
"run_benchmark",
|
||||
"BenchmarkResult",
|
||||
"StratumResult",
|
||||
"metrics",
|
||||
]
|
||||
93
benchmarks/harness/baselines.py
Normal file
93
benchmarks/harness/baselines.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Reference LEXICAL baseline retrievers that mirror the production system.
|
||||
|
||||
These exist so (a) the eval-set author can VERIFY a query's labels and check
|
||||
that paraphrase queries genuinely defeat lexical matching, and (b) later agents
|
||||
have an honest "current system" to beat.
|
||||
|
||||
`SqliteFtsRetriever` builds an in-memory SQLite FTS5 index over the corpus and
|
||||
runs the SAME query shape the production local store uses:
|
||||
words -> '"w1" OR "w2" ...' MATCH, ORDER BY bm25(), importance as tiebreak.
|
||||
(README "SQLite: FTS5 with BM25".) This is the closest faithful, dependency-free
|
||||
baseline. The Postgres tsvector path is documented in the README; its ranking
|
||||
differs (weighted A/B/C/D + importance-first default) but for a quality ceiling
|
||||
comparison the FTS5/BM25 relevance ordering is the right lexical reference.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from collections.abc import Sequence
|
||||
|
||||
from .types import Memory, MemoryId
|
||||
|
||||
# FTS5 reserved-ish tokens; we quote every term anyway, but strip embedded quotes.
|
||||
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
|
||||
|
||||
class SqliteFtsRetriever:
|
||||
"""Faithful FTS5/BM25 lexical baseline (mirrors local_store search)."""
|
||||
|
||||
name = "sqlite_fts5_bm25"
|
||||
|
||||
def __init__(self, sort_by: str = "relevance") -> None:
|
||||
# "relevance": ORDER BY bm25(), importance DESC (best for quality eval)
|
||||
# "importance": ORDER BY importance DESC, ... (production default)
|
||||
self.sort_by = sort_by
|
||||
self._con: sqlite3.Connection | None = None
|
||||
|
||||
def build_index(self, corpus: Sequence[Memory]) -> None:
|
||||
con = sqlite3.connect(":memory:")
|
||||
con.execute(
|
||||
"""
|
||||
CREATE VIRTUAL TABLE memories_fts USING fts5(
|
||||
content, category, tags, expanded_keywords,
|
||||
memory_id UNINDEXED, importance UNINDEXED
|
||||
)
|
||||
"""
|
||||
)
|
||||
con.executemany(
|
||||
"INSERT INTO memories_fts(content, category, tags, expanded_keywords, memory_id, importance)"
|
||||
" VALUES (?,?,?,?,?,?)",
|
||||
[
|
||||
(m.content, m.category, m.tags, m.expanded_keywords, m.id, m.importance)
|
||||
for m in corpus
|
||||
],
|
||||
)
|
||||
con.commit()
|
||||
self._con = con
|
||||
|
||||
def _fts_query(self, query: str) -> str:
|
||||
words = _WORD_RE.findall(query.lower())
|
||||
if not words:
|
||||
return ""
|
||||
return " OR ".join(f'"{w}"' for w in words)
|
||||
|
||||
def retrieve(self, query: str, k: int) -> list[MemoryId]:
|
||||
assert self._con is not None, "call build_index first"
|
||||
match = self._fts_query(query)
|
||||
if not match:
|
||||
return []
|
||||
if self.sort_by == "importance":
|
||||
order = "importance DESC, bm25(memories_fts)"
|
||||
else:
|
||||
order = "bm25(memories_fts), importance DESC"
|
||||
try:
|
||||
rows = self._con.execute(
|
||||
f"SELECT memory_id FROM memories_fts WHERE memories_fts MATCH ? "
|
||||
f"ORDER BY {order} LIMIT ?",
|
||||
(match, k),
|
||||
).fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
# mirror production LIKE fallback on FTS syntax errors
|
||||
like = f"%{query}%"
|
||||
rows = self._con.execute(
|
||||
"SELECT memory_id FROM memories_fts WHERE content LIKE ? OR tags LIKE ? "
|
||||
"ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, k),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def close(self) -> None:
|
||||
if self._con is not None:
|
||||
self._con.close()
|
||||
self._con = None
|
||||
115
benchmarks/harness/dataset.py
Normal file
115
benchmarks/harness/dataset.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Load corpus / queries / qrels JSONL into typed objects."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from .types import Memory, Query, Qrels, MemoryId
|
||||
|
||||
_DATA_DIR = Path(__file__).resolve().parents[1] / "data"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dataset:
|
||||
corpus: list[Memory]
|
||||
queries: list[Query]
|
||||
qrels: Qrels
|
||||
|
||||
@property
|
||||
def corpus_by_id(self) -> dict[MemoryId, Memory]:
|
||||
return {m.id: m for m in self.corpus}
|
||||
|
||||
def strata(self) -> set[str]:
|
||||
return {q.stratum for q in self.queries}
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
out: list[dict] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
|
||||
def load_corpus(path: Path | None = None) -> list[Memory]:
|
||||
path = path or (_DATA_DIR / "corpus.jsonl")
|
||||
rows = _read_jsonl(path)
|
||||
return [
|
||||
Memory(
|
||||
id=r["id"],
|
||||
content=r["content"],
|
||||
category=r.get("category", "facts"),
|
||||
tags=r.get("tags", "") or "",
|
||||
expanded_keywords=r.get("expanded_keywords", "") or "",
|
||||
importance=r.get("importance", 0.5),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def load_queries(path: Path | None = None) -> list[Query]:
|
||||
path = path or (_DATA_DIR / "queries.jsonl")
|
||||
rows = _read_jsonl(path)
|
||||
return [
|
||||
Query(
|
||||
query_id=r["query_id"],
|
||||
text=r["text"],
|
||||
stratum=r["stratum"],
|
||||
relevant_ids=tuple(r.get("relevant_ids", [])),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def load_qrels(path: Path | None = None) -> Qrels:
|
||||
path = path or (_DATA_DIR / "qrels.jsonl")
|
||||
rows = _read_jsonl(path)
|
||||
qrels: Qrels = {}
|
||||
for r in rows:
|
||||
qid = r["query_id"]
|
||||
rel = set(r["relevant_ids"])
|
||||
qrels.setdefault(qid, set()).update(rel)
|
||||
return qrels
|
||||
|
||||
|
||||
def load_dataset(
|
||||
corpus_path: Path | None = None,
|
||||
queries_path: Path | None = None,
|
||||
qrels_path: Path | None = None,
|
||||
*,
|
||||
validate: bool = True,
|
||||
) -> Dataset:
|
||||
corpus = load_corpus(corpus_path)
|
||||
queries = load_queries(queries_path)
|
||||
qrels = load_qrels(qrels_path)
|
||||
|
||||
if validate:
|
||||
_validate(corpus, queries, qrels)
|
||||
|
||||
return Dataset(corpus=corpus, queries=queries, qrels=qrels)
|
||||
|
||||
|
||||
def _validate(corpus: list[Memory], queries: list[Query], qrels: Qrels) -> None:
|
||||
corpus_ids = {m.id for m in corpus}
|
||||
q_ids = {q.query_id for q in queries}
|
||||
|
||||
# Every query must have a qrels entry, and vice versa.
|
||||
missing_qrels = q_ids - set(qrels)
|
||||
if missing_qrels:
|
||||
raise ValueError(f"queries without qrels: {sorted(missing_qrels)[:10]}")
|
||||
orphan_qrels = set(qrels) - q_ids
|
||||
if orphan_qrels:
|
||||
raise ValueError(f"qrels without queries: {sorted(orphan_qrels)[:10]}")
|
||||
|
||||
# Every relevant id must exist in the corpus and the set must be non-empty.
|
||||
for qid, rels in qrels.items():
|
||||
if not rels:
|
||||
raise ValueError(f"empty qrels for query {qid}")
|
||||
unknown = rels - corpus_ids
|
||||
if unknown:
|
||||
raise ValueError(
|
||||
f"query {qid} references non-corpus ids {sorted(unknown)[:10]}"
|
||||
)
|
||||
59
benchmarks/harness/example_retriever.py
Normal file
59
benchmarks/harness/example_retriever.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""Worked example: how a later agent plugs a retriever into the harness.
|
||||
|
||||
A retriever needs only one method:
|
||||
|
||||
retrieve(self, query: str, k: int) -> list[int] # ranked memory ids
|
||||
|
||||
Optionally it may implement lifecycle hooks the runner will use if present:
|
||||
|
||||
build_index(self, corpus: list[Memory]) -> None # timed separately
|
||||
index_size_bytes(self) -> int # reported
|
||||
|
||||
Run this file directly for a smoke test against the local eval set:
|
||||
.venv/bin/python -m harness.example_retriever
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from .types import Memory, MemoryId
|
||||
|
||||
|
||||
class SubstringRetriever:
|
||||
"""Trivial baseline: rank by count of query-word occurrences in content.
|
||||
|
||||
Deliberately weak — exists only to demonstrate the interface. The real
|
||||
lexical baseline is harness.baselines.SqliteFtsRetriever.
|
||||
"""
|
||||
|
||||
name = "substring_demo"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._corpus: list[Memory] = []
|
||||
|
||||
def build_index(self, corpus: Sequence[Memory]) -> None:
|
||||
self._corpus = list(corpus)
|
||||
|
||||
def retrieve(self, query: str, k: int) -> list[MemoryId]:
|
||||
words = [w for w in query.lower().split() if len(w) > 2]
|
||||
scored: list[tuple[int, float]] = []
|
||||
for m in self._corpus:
|
||||
hay = (m.content + " " + m.expanded_keywords + " " + m.tags).lower()
|
||||
score = sum(hay.count(w) for w in words)
|
||||
if score:
|
||||
scored.append((m.id, score + m.importance)) # importance tiebreak
|
||||
scored.sort(key=lambda t: t[1], reverse=True)
|
||||
return [mid for mid, _ in scored[:k]]
|
||||
|
||||
|
||||
def _smoke() -> None:
|
||||
from .dataset import load_dataset
|
||||
from .runner import run_benchmark
|
||||
|
||||
ds = load_dataset()
|
||||
res = run_benchmark(SubstringRetriever(), ds)
|
||||
print(res.summary())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_smoke()
|
||||
100
benchmarks/harness/metrics.py
Normal file
100
benchmarks/harness/metrics.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""Retrieval metrics with BINARY relevance.
|
||||
|
||||
Conventions
|
||||
-----------
|
||||
- `ranked`: list of memory ids, best-first, as returned by a retriever.
|
||||
- `relevant`: set of relevant memory ids for the query (from qrels).
|
||||
- All functions are pure and operate on a single query; the runner aggregates
|
||||
(macro-average over queries).
|
||||
|
||||
Definitions
|
||||
-----------
|
||||
recall@k = |relevant ∩ ranked[:k]| / |relevant|
|
||||
(fraction of all relevant items retrieved within the top k)
|
||||
MRR = 1 / rank_of_first_relevant (0 if none retrieved at all)
|
||||
nDCG@k = DCG@k / IDCG@k with binary gains (gain=1 for relevant)
|
||||
DCG@k = sum over i in [1..k] of rel_i / log2(i + 1)
|
||||
IDCG@k is the DCG of the ideal ranking (all relevant first),
|
||||
capped at min(|relevant|, k) ones.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- nDCG uses the standard log2(rank+1) discount (Järvelin & Kekäläinen 2002);
|
||||
with binary gains this is the common IR convention also used by BEIR/pytrec_eval.
|
||||
- MRR is reported as the reciprocal rank of the FIRST relevant hit, which for a
|
||||
single query equals the per-query reciprocal-rank that the runner averages.
|
||||
- Duplicate ids in `ranked` are de-duplicated keeping first occurrence, so a
|
||||
retriever cannot inflate recall by repeating an id.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
MemoryId = int
|
||||
|
||||
|
||||
def _dedup_keep_order(ranked: Sequence[MemoryId]) -> list[MemoryId]:
|
||||
seen: set[MemoryId] = set()
|
||||
out: list[MemoryId] = []
|
||||
for x in ranked:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
def recall_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
|
||||
rel = set(relevant)
|
||||
if not rel:
|
||||
# Undefined; treat as 0 contribution. Runner should never pass empty.
|
||||
return 0.0
|
||||
top = _dedup_keep_order(ranked)[:k]
|
||||
hits = sum(1 for x in top if x in rel)
|
||||
return hits / len(rel)
|
||||
|
||||
|
||||
def reciprocal_rank(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId]) -> float:
|
||||
rel = set(relevant)
|
||||
if not rel:
|
||||
return 0.0
|
||||
for i, x in enumerate(_dedup_keep_order(ranked), start=1):
|
||||
if x in rel:
|
||||
return 1.0 / i
|
||||
return 0.0
|
||||
|
||||
|
||||
def dcg_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
|
||||
rel = set(relevant)
|
||||
top = _dedup_keep_order(ranked)[:k]
|
||||
dcg = 0.0
|
||||
for i, x in enumerate(top, start=1):
|
||||
if x in rel:
|
||||
dcg += 1.0 / math.log2(i + 1)
|
||||
return dcg
|
||||
|
||||
|
||||
def ndcg_at_k(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId], k: int) -> float:
|
||||
rel = set(relevant)
|
||||
if not rel:
|
||||
return 0.0
|
||||
dcg = dcg_at_k(ranked, rel, k)
|
||||
ideal_hits = min(len(rel), k)
|
||||
idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
|
||||
if idcg == 0.0:
|
||||
return 0.0
|
||||
return dcg / idcg
|
||||
|
||||
|
||||
def per_query_metrics(ranked: Sequence[MemoryId], relevant: Iterable[MemoryId]) -> dict[str, float]:
|
||||
"""All headline metrics for one query."""
|
||||
rel = set(relevant)
|
||||
return {
|
||||
"recall@5": recall_at_k(ranked, rel, 5),
|
||||
"recall@10": recall_at_k(ranked, rel, 10),
|
||||
"ndcg@10": ndcg_at_k(ranked, rel, 10),
|
||||
"mrr": reciprocal_rank(ranked, rel),
|
||||
}
|
||||
|
||||
|
||||
METRIC_NAMES = ("recall@5", "recall@10", "ndcg@10", "mrr")
|
||||
223
benchmarks/harness/runner.py
Normal file
223
benchmarks/harness/runner.py
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
"""Benchmark runner: drive a pluggable retriever over the eval set and report
|
||||
overall + per-stratum quality metrics, plus per-query latency and (optional)
|
||||
index build time / size.
|
||||
|
||||
Quality decides adoption (recall@k, nDCG@10, MRR). Latency and storage are
|
||||
measured and reported but DO NOT gate the decision (ADR-0001 success metric).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any
|
||||
|
||||
from . import metrics
|
||||
from .dataset import Dataset
|
||||
from .types import MemoryId, Query, Retriever
|
||||
|
||||
# A retriever may be the Protocol object or a bare callable retrieve(query, k).
|
||||
RetrieverLike = Retriever | Callable[[str, int], list[MemoryId]]
|
||||
|
||||
# k used for the retrieve() call. We request enough depth to compute all
|
||||
# metrics (max cutoff is 10) with headroom so ties past k=10 don't distort.
|
||||
DEFAULT_RETRIEVE_K = 20
|
||||
|
||||
|
||||
def _percentile(values: list[float], pct: float) -> float:
|
||||
"""Linear-interpolation percentile (pct in [0,100]). Empty -> 0.0."""
|
||||
if not values:
|
||||
return 0.0
|
||||
if len(values) == 1:
|
||||
return values[0]
|
||||
s = sorted(values)
|
||||
rank = (pct / 100.0) * (len(s) - 1)
|
||||
lo = int(rank)
|
||||
hi = min(lo + 1, len(s) - 1)
|
||||
frac = rank - lo
|
||||
return s[lo] + (s[hi] - s[lo]) * frac
|
||||
|
||||
|
||||
@dataclass
|
||||
class StratumResult:
|
||||
stratum: str
|
||||
n_queries: int
|
||||
metrics: dict[str, float] # macro-averaged metric -> value
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
retriever_name: str
|
||||
n_queries: int
|
||||
retrieve_k: int
|
||||
overall: dict[str, float]
|
||||
per_stratum: dict[str, StratumResult]
|
||||
latency_ms: dict[str, float] # mean / p50 / p95 / max
|
||||
index_build_seconds: float | None = None
|
||||
index_size_bytes: int | None = None
|
||||
per_query: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = asdict(self)
|
||||
d["per_stratum"] = {k: asdict(v) for k, v in self.per_stratum.items()}
|
||||
return d
|
||||
|
||||
def summary(self) -> str:
|
||||
lines = [
|
||||
f"Retriever: {self.retriever_name}",
|
||||
f"Queries: {self.n_queries} (retrieve_k={self.retrieve_k})",
|
||||
]
|
||||
if self.index_build_seconds is not None:
|
||||
lines.append(f"Index build: {self.index_build_seconds:.3f}s")
|
||||
if self.index_size_bytes is not None:
|
||||
lines.append(f"Index size: {self.index_size_bytes / 1e6:.2f} MB")
|
||||
lat = self.latency_ms
|
||||
lines.append(
|
||||
"Latency/query: "
|
||||
f"p50={lat['p50']:.2f}ms p95={lat['p95']:.2f}ms "
|
||||
f"mean={lat['mean']:.2f}ms max={lat['max']:.2f}ms"
|
||||
)
|
||||
cols = metrics.METRIC_NAMES
|
||||
header = " ".join(f"{c:>10}" for c in cols)
|
||||
lines.append("")
|
||||
lines.append(f"{'stratum':<12}{'n':>5} {header}")
|
||||
lines.append("-" * (19 + len(header)))
|
||||
for name in ("overall", *sorted(self.per_stratum)):
|
||||
if name == "overall":
|
||||
m, n = self.overall, self.n_queries
|
||||
else:
|
||||
sr = self.per_stratum[name]
|
||||
m, n = sr.metrics, sr.n_queries
|
||||
row = " ".join(f"{m[c]:>10.4f}" for c in cols)
|
||||
lines.append(f"{name:<12}{n:>5} {row}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _get_retrieve_fn(retriever: RetrieverLike) -> Callable[[str, int], list[MemoryId]]:
|
||||
if hasattr(retriever, "retrieve"):
|
||||
return retriever.retrieve # type: ignore[attr-defined]
|
||||
if callable(retriever):
|
||||
return retriever
|
||||
raise TypeError("retriever must implement retrieve(query, k) or be callable")
|
||||
|
||||
|
||||
def _maybe_build_index(retriever: RetrieverLike, dataset: Dataset) -> tuple[float | None, int | None]:
|
||||
"""Call optional lifecycle hooks if present (duck-typed).
|
||||
|
||||
- build_index(corpus) -> None : measured wall-clock build time.
|
||||
- index_size_bytes() -> int : reported on-disk/in-memory index size.
|
||||
Returns (build_seconds_or_None, size_bytes_or_None).
|
||||
"""
|
||||
build_seconds: float | None = None
|
||||
size_bytes: int | None = None
|
||||
|
||||
build = getattr(retriever, "build_index", None)
|
||||
if callable(build):
|
||||
t0 = time.perf_counter()
|
||||
build(dataset.corpus)
|
||||
build_seconds = time.perf_counter() - t0
|
||||
|
||||
size_fn = getattr(retriever, "index_size_bytes", None)
|
||||
if callable(size_fn):
|
||||
try:
|
||||
size_bytes = int(size_fn())
|
||||
except Exception:
|
||||
size_bytes = None
|
||||
|
||||
return build_seconds, size_bytes
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
retriever: RetrieverLike,
|
||||
dataset: Dataset,
|
||||
*,
|
||||
retrieve_k: int = DEFAULT_RETRIEVE_K,
|
||||
retriever_name: str | None = None,
|
||||
warmup: bool = True,
|
||||
collect_per_query: bool = True,
|
||||
) -> BenchmarkResult:
|
||||
"""Evaluate `retriever` over `dataset`.
|
||||
|
||||
The retriever is asked for `retrieve_k` ids per query (>= max metric
|
||||
cutoff of 10). Metrics are macro-averaged over queries, overall and per
|
||||
stratum. Latency is measured around each retrieve() call only (index build
|
||||
is timed separately via the optional build_index hook).
|
||||
"""
|
||||
name = retriever_name or getattr(retriever, "name", None) or type(retriever).__name__
|
||||
retrieve = _get_retrieve_fn(retriever)
|
||||
qrels = dataset.qrels
|
||||
|
||||
build_seconds, size_bytes = _maybe_build_index(retriever, dataset)
|
||||
|
||||
# Optional warmup (first call can pay import/JIT/connection costs that would
|
||||
# skew p95). Excluded from latency stats. Uses the first query if any.
|
||||
if warmup and dataset.queries:
|
||||
try:
|
||||
retrieve(dataset.queries[0].text, retrieve_k)
|
||||
except Exception:
|
||||
pass # warmup failures surface on the real call below
|
||||
|
||||
per_query_rows: list[dict[str, Any]] = []
|
||||
latencies_ms: list[float] = []
|
||||
# accumulate per-stratum metric sums for macro-average
|
||||
strata: dict[str, dict[str, float]] = {}
|
||||
strata_counts: dict[str, int] = {}
|
||||
overall_sums = {m: 0.0 for m in metrics.METRIC_NAMES}
|
||||
|
||||
for q in dataset.queries:
|
||||
rel = qrels[q.query_id]
|
||||
t0 = time.perf_counter()
|
||||
ranked = list(retrieve(q.text, retrieve_k))
|
||||
dt_ms = (time.perf_counter() - t0) * 1000.0
|
||||
latencies_ms.append(dt_ms)
|
||||
|
||||
m = metrics.per_query_metrics(ranked, rel)
|
||||
for key, val in m.items():
|
||||
overall_sums[key] += val
|
||||
strata.setdefault(q.stratum, {mm: 0.0 for mm in metrics.METRIC_NAMES})
|
||||
strata_counts[q.stratum] = strata_counts.get(q.stratum, 0) + 1
|
||||
for key, val in m.items():
|
||||
strata[q.stratum][key] += val
|
||||
|
||||
if collect_per_query:
|
||||
per_query_rows.append(
|
||||
{
|
||||
"query_id": q.query_id,
|
||||
"stratum": q.stratum,
|
||||
"n_relevant": len(rel),
|
||||
"latency_ms": round(dt_ms, 3),
|
||||
"retrieved": ranked[:retrieve_k],
|
||||
**{k: round(v, 6) for k, v in m.items()},
|
||||
}
|
||||
)
|
||||
|
||||
n = len(dataset.queries)
|
||||
overall = {k: (overall_sums[k] / n if n else 0.0) for k in metrics.METRIC_NAMES}
|
||||
per_stratum: dict[str, StratumResult] = {}
|
||||
for s, sums in strata.items():
|
||||
c = strata_counts[s]
|
||||
per_stratum[s] = StratumResult(
|
||||
stratum=s,
|
||||
n_queries=c,
|
||||
metrics={k: (sums[k] / c if c else 0.0) for k in metrics.METRIC_NAMES},
|
||||
)
|
||||
|
||||
latency_stats = {
|
||||
"mean": statistics.fmean(latencies_ms) if latencies_ms else 0.0,
|
||||
"p50": _percentile(latencies_ms, 50),
|
||||
"p95": _percentile(latencies_ms, 95),
|
||||
"max": max(latencies_ms) if latencies_ms else 0.0,
|
||||
}
|
||||
|
||||
return BenchmarkResult(
|
||||
retriever_name=name,
|
||||
n_queries=n,
|
||||
retrieve_k=retrieve_k,
|
||||
overall=overall,
|
||||
per_stratum=per_stratum,
|
||||
latency_ms=latency_stats,
|
||||
index_build_seconds=build_seconds,
|
||||
index_size_bytes=size_bytes,
|
||||
per_query=per_query_rows,
|
||||
)
|
||||
145
benchmarks/harness/test_harness.py
Normal file
145
benchmarks/harness/test_harness.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Unit tests for metrics + runner. No real corpus needed (synthetic data).
|
||||
|
||||
Run: .venv/bin/python -m pytest harness/test_harness.py -q
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from harness import metrics
|
||||
from harness.dataset import Dataset
|
||||
from harness.runner import run_benchmark, _percentile
|
||||
from harness.types import Memory, Query
|
||||
|
||||
|
||||
# ---------------- metrics ----------------
|
||||
|
||||
def test_recall_at_k_basic():
|
||||
ranked = [9, 8, 3, 7, 1]
|
||||
rel = {3, 1, 99} # 99 never retrieved
|
||||
assert metrics.recall_at_k(ranked, rel, 5) == 2 / 3
|
||||
assert metrics.recall_at_k(ranked, rel, 2) == 0.0 # neither in top2
|
||||
assert metrics.recall_at_k(ranked, rel, 3) == 1 / 3 # only id 3 in top3
|
||||
|
||||
|
||||
def test_recall_perfect_and_zero():
|
||||
assert metrics.recall_at_k([1, 2, 3], {1, 2, 3}, 5) == 1.0
|
||||
assert metrics.recall_at_k([4, 5, 6], {1, 2, 3}, 5) == 0.0
|
||||
|
||||
|
||||
def test_reciprocal_rank():
|
||||
assert metrics.reciprocal_rank([5, 4, 3], {3}) == 1 / 3
|
||||
assert metrics.reciprocal_rank([3, 4, 5], {3}) == 1.0
|
||||
assert metrics.reciprocal_rank([7, 8], {3}) == 0.0
|
||||
# first relevant wins
|
||||
assert metrics.reciprocal_rank([9, 3, 1], {1, 3}) == 1 / 2
|
||||
|
||||
|
||||
def test_ndcg_perfect():
|
||||
# all relevant at the top -> nDCG == 1
|
||||
assert math.isclose(metrics.ndcg_at_k([1, 2, 3, 4], {1, 2, 3}, 10), 1.0)
|
||||
|
||||
|
||||
def test_ndcg_known_value():
|
||||
# single relevant doc at rank 2: DCG = 1/log2(3); IDCG = 1/log2(2)=1
|
||||
ranked = [9, 1, 8]
|
||||
val = metrics.ndcg_at_k(ranked, {1}, 10)
|
||||
assert math.isclose(val, (1 / math.log2(3)) / 1.0)
|
||||
|
||||
|
||||
def test_ndcg_two_relevant_suboptimal_order():
|
||||
# relevant {1,2}; retrieved at ranks 1 and 3
|
||||
ranked = [1, 9, 2]
|
||||
dcg = 1 / math.log2(2) + 1 / math.log2(4) # ranks 1 and 3
|
||||
idcg = 1 / math.log2(2) + 1 / math.log2(3) # ideal ranks 1 and 2
|
||||
assert math.isclose(metrics.ndcg_at_k(ranked, {1, 2}, 10), dcg / idcg)
|
||||
|
||||
|
||||
def test_dedup_does_not_inflate():
|
||||
# repeating a relevant id must not increase recall beyond 1 hit's worth
|
||||
ranked = [3, 3, 3, 3]
|
||||
assert metrics.recall_at_k(ranked, {3, 7}, 5) == 0.5
|
||||
assert metrics.reciprocal_rank(ranked, {3}) == 1.0
|
||||
|
||||
|
||||
def test_empty_relevant_is_zero():
|
||||
assert metrics.recall_at_k([1, 2], set(), 5) == 0.0
|
||||
assert metrics.ndcg_at_k([1, 2], set(), 5) == 0.0
|
||||
|
||||
|
||||
# ---------------- percentile ----------------
|
||||
|
||||
def test_percentile():
|
||||
vals = [10, 20, 30, 40]
|
||||
assert _percentile(vals, 50) == 25.0 # interpolated median
|
||||
assert _percentile(vals, 0) == 10
|
||||
assert _percentile(vals, 100) == 40
|
||||
assert _percentile([5.0], 95) == 5.0
|
||||
assert _percentile([], 50) == 0.0
|
||||
|
||||
|
||||
# ---------------- runner ----------------
|
||||
|
||||
def _toy_dataset() -> Dataset:
|
||||
corpus = [Memory(id=i, content=f"memory {i}") for i in range(1, 11)]
|
||||
queries = [
|
||||
Query("q_exact_1", "find 1", "exact", (1,)),
|
||||
Query("q_para_1", "restate 5", "paraphrase", (5,)),
|
||||
Query("q_multi_1", "join 3 and 4", "multihop", (3, 4)),
|
||||
]
|
||||
qrels = {"q_exact_1": {1}, "q_para_1": {5}, "q_multi_1": {3, 4}}
|
||||
return Dataset(corpus=corpus, queries=queries, qrels=qrels)
|
||||
|
||||
|
||||
class _PerfectRetriever:
|
||||
"""Returns exactly the relevant ids first (oracle) — for runner plumbing."""
|
||||
|
||||
def __init__(self, qrels):
|
||||
self._qrels = qrels
|
||||
self._by_text = None
|
||||
|
||||
def build_index(self, corpus):
|
||||
self._n = len(corpus)
|
||||
|
||||
def index_size_bytes(self):
|
||||
return 1234
|
||||
|
||||
def retrieve(self, query, k):
|
||||
# map query text back via the toy queries' known answers
|
||||
mapping = {"find 1": [1], "restate 5": [5], "join 3 and 4": [3, 4]}
|
||||
ids = mapping.get(query, [])
|
||||
# pad with distractors
|
||||
pad = [x for x in range(100, 100 + k)]
|
||||
return (ids + pad)[:k]
|
||||
|
||||
|
||||
def test_runner_perfect_retriever():
|
||||
ds = _toy_dataset()
|
||||
r = _PerfectRetriever(ds.qrels)
|
||||
res = run_benchmark(r, ds, retriever_name="perfect")
|
||||
assert res.n_queries == 3
|
||||
assert math.isclose(res.overall["recall@10"], 1.0)
|
||||
assert math.isclose(res.overall["mrr"], 1.0)
|
||||
assert math.isclose(res.overall["ndcg@10"], 1.0)
|
||||
# per-stratum present
|
||||
assert set(res.per_stratum) == {"exact", "paraphrase", "multihop"}
|
||||
assert res.per_stratum["multihop"].n_queries == 1
|
||||
# lifecycle hooks captured
|
||||
assert res.index_build_seconds is not None
|
||||
assert res.index_size_bytes == 1234
|
||||
# latency recorded
|
||||
assert res.latency_ms["p95"] >= 0.0
|
||||
|
||||
|
||||
def test_runner_callable_retriever_and_misses():
|
||||
ds = _toy_dataset()
|
||||
|
||||
def retrieve(query, k): # always wrong
|
||||
return [999][:k]
|
||||
|
||||
res = run_benchmark(retrieve, ds, retriever_name="bad", warmup=False)
|
||||
assert res.overall["recall@10"] == 0.0
|
||||
assert res.overall["mrr"] == 0.0
|
||||
assert res.index_build_seconds is None # no hook on a bare callable
|
||||
assert "perfect" not in res.summary()
|
||||
assert "bad" in res.summary()
|
||||
53
benchmarks/harness/types.py
Normal file
53
benchmarks/harness/types.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Core dataclasses and the pluggable Retriever protocol."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
MemoryId = int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Memory:
|
||||
"""One corpus entry (mirrors corpus.jsonl)."""
|
||||
|
||||
id: MemoryId
|
||||
content: str
|
||||
category: str = "facts"
|
||||
tags: str = ""
|
||||
expanded_keywords: str = ""
|
||||
importance: float = 0.5
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Query:
|
||||
"""One eval query (mirrors queries.jsonl)."""
|
||||
|
||||
query_id: str
|
||||
text: str
|
||||
stratum: str # "exact" | "paraphrase" | "multihop"
|
||||
# convenience copy of relevant ids; authoritative source is Qrels
|
||||
relevant_ids: tuple[MemoryId, ...] = field(default_factory=tuple)
|
||||
|
||||
|
||||
# query_id -> set of relevant memory ids (binary relevance)
|
||||
Qrels = dict[str, set[MemoryId]]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Retriever(Protocol):
|
||||
"""Pluggable retriever contract.
|
||||
|
||||
Implementations rank corpus memories for a query and return the top-k
|
||||
memory ids, best match first. The harness will call `retrieve` once per
|
||||
query and compare against qrels.
|
||||
|
||||
Optional lifecycle hooks let a retriever build an index from the corpus
|
||||
and report index build time / on-disk size; the runner uses them if
|
||||
present (duck-typed), so a minimal retriever need only implement
|
||||
`retrieve`.
|
||||
"""
|
||||
|
||||
def retrieve(self, query: str, k: int) -> list[MemoryId]:
|
||||
"""Return up to k memory ids, ranked best-first."""
|
||||
...
|
||||
Loading…
Add table
Add a link
Reference in a new issue