fix(mcp): serialize local SQLite writes to end "database is locked" under concurrent stores
Under heavy concurrent memory_store (many subagents/sessions writing close
together) the local SQLite layer raced on the single SQLite writer and surfaced
sqlite3.OperationalError: database is locked, which made memory tools slow and
eventually dropped whole sessions. Two root causes:
- The MCP server (mcp_server.py) and the background SyncEngine (sync.py) each
opened a SEPARATE connection to the same SQLite file. WAL allows one writer;
when the sync writer held the lock across a resync, a concurrent store blew
past busy_timeout and raised "database is locked".
- mcp_server's connection was opened WITHOUT check_same_thread=False, so the
moment two requests were handled on different threads every local store/recall
raised ProgrammingError "SQLite objects created in a thread...".
Fix: a single process-wide serialized writer.
- New LocalStore (local_store.py) owns ONE connection (check_same_thread=False)
guarded by ONE re-entrant lock, keeps WAL, and wraps writes in
transaction()/write() with bounded exponential-backoff retry on the rare
residual lock (e.g. another OS process) instead of failing the call.
- MemoryServer builds the LocalStore and SHARES it with the SyncEngine, so the
sync writer no longer opens a second connection — the two-connections race is
eliminated. All server reads/writes go through the shared lock; stores stay
snappy (enqueue-local + async sync).
- Bound the one genuinely slow path (remote semantic memory_recall) with an
explicit RECALL_TIMEOUT and, on timeout/unreachable backend, return a clear
"working / retry" signal instead of hanging silently or crashing. When a
client supplies _meta.progressToken, emit one notifications/progress so the
call shows life.
Ships to users via the plugin (mcp/memory-mcp.json runs src/.../mcp_server.py);
no server-side/API change needed. TDD: added concurrency tests (many threads +
sync writer on one file) and recall progress/bounding tests; full gate green
(ruff + mypy strict + 185 pytest).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
68088e684e
commit
5151bbe0d5
5 changed files with 590 additions and 116 deletions
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -48,3 +48,9 @@ docker/pgdata/
|
||||||
# Beads / Dolt files (added by bd init)
|
# Beads / Dolt files (added by bd init)
|
||||||
.dolt/
|
.dolt/
|
||||||
.beads-credential-key
|
.beads-credential-key
|
||||||
|
|
||||||
|
# Agent git worktrees (standing policy: never the shared checkout)
|
||||||
|
.worktrees/
|
||||||
|
|
||||||
|
# uv lockfile — CI runs `uv sync` itself; not tracked
|
||||||
|
uv.lock
|
||||||
|
|
|
||||||
116
src/claude_memory/local_store.py
Normal file
116
src/claude_memory/local_store.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""Single, process-wide serialized SQLite writer for the local memory cache.
|
||||||
|
|
||||||
|
SQLite permits only one writer at a time. The MCP server's store path and the
|
||||||
|
background sync engine used to open *separate* connections to the *same* file;
|
||||||
|
under heavy concurrent ``memory_store`` calls those two writers fought over the
|
||||||
|
single SQLite write lock, blew past ``busy_timeout``, and surfaced
|
||||||
|
``sqlite3.OperationalError: database is locked`` — which made the tool slow and
|
||||||
|
eventually dropped the session.
|
||||||
|
|
||||||
|
``LocalStore`` fixes this structurally: it owns ONE connection (opened with
|
||||||
|
``check_same_thread=False``) guarded by ONE re-entrant lock. Every component that
|
||||||
|
needs to touch the local DB shares the same ``LocalStore`` instance, so all
|
||||||
|
writes serialize cleanly through the in-process lock and queue instead of racing
|
||||||
|
the SQLite writer. On the rare residual lock (e.g. another OS process touching
|
||||||
|
the file), writes retry with bounded exponential backoff rather than failing the
|
||||||
|
caller. WAL stays on for concurrent reads.
|
||||||
|
|
||||||
|
Uses only stdlib — no pip install required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# Bounded retry window for the rare residual "database is locked" — handles a
|
||||||
|
# lock held by a *different OS process* (the in-process lock already serializes
|
||||||
|
# this process's own writers). Total worst-case wait ≈ 0.05+0.1+0.2+0.4+0.8 ≈ 1.55s.
|
||||||
|
_MAX_RETRIES = 5
|
||||||
|
_BASE_BACKOFF_S = 0.05
|
||||||
|
_BUSY_TIMEOUT_MS = 30000
|
||||||
|
|
||||||
|
|
||||||
|
def _is_locked_error(exc: sqlite3.OperationalError) -> bool:
|
||||||
|
msg = str(exc).lower()
|
||||||
|
return "database is locked" in msg or "database is busy" in msg
|
||||||
|
|
||||||
|
|
||||||
|
class LocalStore:
|
||||||
|
"""Owns the single shared SQLite connection + lock for local memory writes."""
|
||||||
|
|
||||||
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
|
self.conn = conn
|
||||||
|
# Re-entrant so a transaction callback may itself call ``execute``/``write``
|
||||||
|
# without dead-locking on the same thread.
|
||||||
|
self.lock = threading.RLock()
|
||||||
|
|
||||||
|
# ── construction ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def open(cls, db_path: str) -> "LocalStore":
|
||||||
|
"""Open (creating parent dirs) a WAL connection safe for cross-thread use."""
|
||||||
|
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute(f"PRAGMA busy_timeout={_BUSY_TIMEOUT_MS}")
|
||||||
|
return cls(conn)
|
||||||
|
|
||||||
|
# ── serialized access ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def transaction(self, fn: Callable[[sqlite3.Connection], T]) -> T:
|
||||||
|
"""Run ``fn(conn)`` holding the shared lock, with bounded retry on lock errors.
|
||||||
|
|
||||||
|
``fn`` is responsible for issuing its own ``COMMIT`` (call ``conn.commit()``)
|
||||||
|
when it performs writes. The whole callback runs under the process-wide lock,
|
||||||
|
so concurrent callers queue rather than collide on the SQLite writer.
|
||||||
|
"""
|
||||||
|
last_exc: sqlite3.OperationalError | None = None
|
||||||
|
for attempt in range(_MAX_RETRIES):
|
||||||
|
with self.lock:
|
||||||
|
try:
|
||||||
|
return fn(self.conn)
|
||||||
|
except sqlite3.OperationalError as exc:
|
||||||
|
if not _is_locked_error(exc):
|
||||||
|
raise
|
||||||
|
last_exc = exc
|
||||||
|
# Roll back any partial txn so the retry starts clean and the
|
||||||
|
# connection isn't left mid-transaction holding locks.
|
||||||
|
try:
|
||||||
|
self.conn.rollback()
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
# Back off *outside* the lock so other writers can make progress.
|
||||||
|
backoff = _BASE_BACKOFF_S * (2 ** attempt)
|
||||||
|
logger.warning(
|
||||||
|
"SQLite locked (attempt %d/%d) — backing off %.3fs", attempt + 1, _MAX_RETRIES, backoff
|
||||||
|
)
|
||||||
|
time.sleep(backoff)
|
||||||
|
assert last_exc is not None
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
def execute(self, sql: str, params: tuple[Any, ...] = ()) -> sqlite3.Cursor:
|
||||||
|
"""Run a single read query under the shared lock (no implicit commit)."""
|
||||||
|
with self.lock:
|
||||||
|
return self.conn.execute(sql, params)
|
||||||
|
|
||||||
|
def write(self, sql: str, params: tuple[Any, ...] = ()) -> sqlite3.Cursor:
|
||||||
|
"""Run a single write statement + commit, serialized and retry-guarded."""
|
||||||
|
|
||||||
|
def _do(conn: sqlite3.Connection) -> sqlite3.Cursor:
|
||||||
|
cur = conn.execute(sql, params)
|
||||||
|
conn.commit()
|
||||||
|
return cur
|
||||||
|
|
||||||
|
return self.transaction(_do)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
with self.lock:
|
||||||
|
self.conn.close()
|
||||||
|
|
@ -17,7 +17,10 @@ import sqlite3
|
||||||
import sys
|
import sys
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from claude_memory.local_store import LocalStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -35,9 +38,17 @@ HYBRID_MODE = bool(API_KEY) and not SYNC_DISABLED
|
||||||
HTTP_ONLY = bool(API_KEY) and SYNC_DISABLED
|
HTTP_ONLY = bool(API_KEY) and SYNC_DISABLED
|
||||||
SQLITE_ONLY = not API_KEY
|
SQLITE_ONLY = not API_KEY
|
||||||
|
|
||||||
|
# Default per-request HTTP timeout, and a wider bound for the one genuinely slow path:
|
||||||
|
# a remote semantic ``memory_recall`` (embedding/search can be slow to warm up). Both are
|
||||||
|
# explicit upper bounds so a call can never hang the MCP server indefinitely.
|
||||||
|
DEFAULT_API_TIMEOUT = 15
|
||||||
|
RECALL_TIMEOUT = int(os.environ.get("MEMORY_RECALL_TIMEOUT", "30"))
|
||||||
|
|
||||||
def _api_request(method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
|
|
||||||
"""Make an HTTP request to the memory API."""
|
def _api_request(
|
||||||
|
method: str, path: str, body: dict[str, Any] | None = None, timeout: int = DEFAULT_API_TIMEOUT
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Make an HTTP request to the memory API (bounded by ``timeout`` seconds)."""
|
||||||
url = f"{API_BASE_URL}{path}"
|
url = f"{API_BASE_URL}{path}"
|
||||||
data = json.dumps(body).encode() if body else None
|
data = json.dumps(body).encode() if body else None
|
||||||
req = urllib.request.Request(
|
req = urllib.request.Request(
|
||||||
|
|
@ -50,7 +61,7 @@ def _api_request(method: str, path: str, body: dict[str, Any] | None = None) ->
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||||
result: dict[str, Any] = json.loads(resp.read().decode())
|
result: dict[str, Any] = json.loads(resp.read().decode())
|
||||||
return result
|
return result
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError as e:
|
||||||
|
|
@ -128,7 +139,9 @@ def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
|
||||||
db_path = _get_db_path(db_path)
|
db_path = _get_db_path(db_path)
|
||||||
Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True)
|
Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
conn = sqlite3.connect(db_path, timeout=30.0)
|
# check_same_thread=False: the MCP server may handle requests on different
|
||||||
|
# threads and shares this one connection via LocalStore's lock (see local_store.py).
|
||||||
|
conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
conn.execute("PRAGMA busy_timeout=30000")
|
conn.execute("PRAGMA busy_timeout=30000")
|
||||||
|
|
@ -390,19 +403,30 @@ class MemoryServer:
|
||||||
|
|
||||||
def __init__(self, sqlite_db_path: str | None = None) -> None:
|
def __init__(self, sqlite_db_path: str | None = None) -> None:
|
||||||
self.sqlite_conn: sqlite3.Connection | None = None
|
self.sqlite_conn: sqlite3.Connection | None = None
|
||||||
|
self.store: "LocalStore | None" = None # single serialized writer (see local_store.py)
|
||||||
self.sync_engine: Any = None
|
self.sync_engine: Any = None
|
||||||
|
# Sink for MCP notifications (e.g. progress). Defaults to writing a JSON-RPC
|
||||||
|
# notification line to stdout; overridable in tests.
|
||||||
|
self._notify: Callable[[str, dict[str, Any]], None] = self._emit_notification
|
||||||
|
|
||||||
if SQLITE_ONLY or HYBRID_MODE:
|
if SQLITE_ONLY or HYBRID_MODE:
|
||||||
self.sqlite_conn, resolved_path = _init_sqlite(sqlite_db_path)
|
conn, resolved_path = _init_sqlite(sqlite_db_path)
|
||||||
|
from claude_memory.local_store import LocalStore
|
||||||
|
self.store = LocalStore(conn)
|
||||||
|
self.sqlite_conn = conn
|
||||||
|
|
||||||
if HYBRID_MODE:
|
if HYBRID_MODE:
|
||||||
from claude_memory.sync import SyncEngine
|
from claude_memory.sync import SyncEngine
|
||||||
sync_interval = int(os.environ.get("MEMORY_SYNC_INTERVAL", "60"))
|
sync_interval = int(os.environ.get("MEMORY_SYNC_INTERVAL", "60"))
|
||||||
|
# Share the SAME LocalStore (one connection, one lock) so the background
|
||||||
|
# sync writer never opens a second connection to the file and never races
|
||||||
|
# the store path on the single SQLite writer.
|
||||||
self.sync_engine = SyncEngine(
|
self.sync_engine = SyncEngine(
|
||||||
db_path=resolved_path,
|
db_path=resolved_path,
|
||||||
api_base_url=API_BASE_URL,
|
api_base_url=API_BASE_URL,
|
||||||
api_key=API_KEY,
|
api_key=API_KEY,
|
||||||
sync_interval=sync_interval,
|
sync_interval=sync_interval,
|
||||||
|
store=self.store,
|
||||||
)
|
)
|
||||||
self.sync_engine.start()
|
self.sync_engine.start()
|
||||||
|
|
||||||
|
|
@ -455,31 +479,59 @@ class MemoryServer:
|
||||||
limit = args.get("limit", 10)
|
limit = args.get("limit", 10)
|
||||||
|
|
||||||
if HTTP_ONLY:
|
if HTTP_ONLY:
|
||||||
result = _api_request("POST", "/api/memories/recall", {
|
return self._recall_remote(args)
|
||||||
"context": context,
|
|
||||||
"expanded_query": expanded_query,
|
|
||||||
"category": category,
|
|
||||||
"sort_by": sort_by,
|
|
||||||
"limit": limit,
|
|
||||||
})
|
|
||||||
rows = result.get("memories", [])
|
|
||||||
if not rows:
|
|
||||||
filter_desc = f" in category '{category}'" if category else ""
|
|
||||||
return f"No memories found matching: {context}{filter_desc}"
|
|
||||||
|
|
||||||
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
|
|
||||||
filter_desc = f" in '{category}'" if category else ""
|
|
||||||
results = []
|
|
||||||
for row in rows:
|
|
||||||
results.append(
|
|
||||||
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
|
||||||
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
|
||||||
)
|
|
||||||
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
|
|
||||||
|
|
||||||
# SQLite-only or Hybrid: always read from local cache
|
# SQLite-only or Hybrid: always read from local cache
|
||||||
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
|
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
|
||||||
|
|
||||||
|
def _recall_remote(self, args: dict[str, Any]) -> str:
|
||||||
|
"""Remote semantic recall — the one genuinely slow path.
|
||||||
|
|
||||||
|
Bounded by ``RECALL_TIMEOUT`` so it can never hang the MCP server. On a timeout
|
||||||
|
or unreachable backend it returns a clear "working / not done — retry" signal
|
||||||
|
rather than raising or blocking silently.
|
||||||
|
"""
|
||||||
|
context = args.get("context")
|
||||||
|
expanded_query = args.get("expanded_query", "")
|
||||||
|
category = args.get("category")
|
||||||
|
sort_by = args.get("sort_by", "importance")
|
||||||
|
limit = args.get("limit", 10)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = _api_request(
|
||||||
|
"POST", "/api/memories/recall",
|
||||||
|
{
|
||||||
|
"context": context,
|
||||||
|
"expanded_query": expanded_query,
|
||||||
|
"category": category,
|
||||||
|
"sort_by": sort_by,
|
||||||
|
"limit": limit,
|
||||||
|
},
|
||||||
|
timeout=RECALL_TIMEOUT,
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# _api_request wraps connection errors / timeouts as RuntimeError. Surface a
|
||||||
|
# clear, actionable signal instead of hanging or dumping a stack trace.
|
||||||
|
return (
|
||||||
|
f"Memory recall did not complete within {RECALL_TIMEOUT}s — the semantic "
|
||||||
|
f"search backend is slow or unreachable ({e}). Please try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = result.get("memories", [])
|
||||||
|
if not rows:
|
||||||
|
filter_desc = f" in category '{category}'" if category else ""
|
||||||
|
return f"No memories found matching: {context}{filter_desc}"
|
||||||
|
|
||||||
|
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
|
||||||
|
filter_desc = f" in '{category}'" if category else ""
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
results.append(
|
||||||
|
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
||||||
|
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||||
|
)
|
||||||
|
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
|
||||||
|
|
||||||
def memory_list(self, args: dict[str, Any]) -> str:
|
def memory_list(self, args: dict[str, Any]) -> str:
|
||||||
category = args.get("category")
|
category = args.get("category")
|
||||||
limit = args.get("limit", 20)
|
limit = args.get("limit", 20)
|
||||||
|
|
@ -519,10 +571,11 @@ class MemoryServer:
|
||||||
# SQLite-only or Hybrid: delete from local SQLite
|
# SQLite-only or Hybrid: delete from local SQLite
|
||||||
# In hybrid mode, also try to sync delete to server
|
# In hybrid mode, also try to sync delete to server
|
||||||
server_id: int | None = None
|
server_id: int | None = None
|
||||||
if HYBRID_MODE and self.sync_engine and self.sqlite_conn:
|
if HYBRID_MODE and self.sync_engine and self.store:
|
||||||
cursor = self.sqlite_conn.cursor()
|
with self.store.lock:
|
||||||
cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
|
cursor = self.store.conn.cursor()
|
||||||
row = cursor.fetchone()
|
cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
|
||||||
|
row = cursor.fetchone()
|
||||||
server_id = row["server_id"] if row and row["server_id"] else None
|
server_id = row["server_id"] if row and row["server_id"] else None
|
||||||
|
|
||||||
result_text = self._sqlite_delete(memory_id)
|
result_text = self._sqlite_delete(memory_id)
|
||||||
|
|
@ -563,12 +616,13 @@ class MemoryServer:
|
||||||
lines.append(f"Last sync success: {counts['last_sync_success']}")
|
lines.append(f"Last sync success: {counts['last_sync_success']}")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
if self.sqlite_conn:
|
if self.store:
|
||||||
cursor = self.sqlite_conn.cursor()
|
with self.store.lock:
|
||||||
cursor.execute("SELECT COUNT(*) as c FROM memories")
|
cursor = self.store.conn.cursor()
|
||||||
total = cursor.fetchone()["c"]
|
cursor.execute("SELECT COUNT(*) as c FROM memories")
|
||||||
cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC")
|
total = cursor.fetchone()["c"]
|
||||||
by_cat = cursor.fetchall()
|
cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC")
|
||||||
|
by_cat = cursor.fetchall()
|
||||||
lines = [f"Local memories (SQLite-only): {total}"]
|
lines = [f"Local memories (SQLite-only): {total}"]
|
||||||
for row in by_cat:
|
for row in by_cat:
|
||||||
lines.append(f" {row['category']}: {row['c']}")
|
lines.append(f" {row['category']}: {row['c']}")
|
||||||
|
|
@ -682,19 +736,25 @@ class MemoryServer:
|
||||||
def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str:
|
def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str:
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
assert self.sqlite_conn is not None
|
assert self.store is not None
|
||||||
now = datetime.now(timezone.utc).isoformat()
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
is_sensitive = 1 if force_sensitive else 0
|
is_sensitive = 1 if force_sensitive else 0
|
||||||
cursor = self.sqlite_conn.cursor()
|
|
||||||
cursor.execute(
|
def _insert(conn: sqlite3.Connection) -> int | None:
|
||||||
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
cursor = conn.execute(
|
||||||
(content, category, tags, expanded_keywords, importance, is_sensitive, now, now),
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
)
|
(content, category, tags, expanded_keywords, importance, is_sensitive, now, now),
|
||||||
self.sqlite_conn.commit()
|
)
|
||||||
return f"Stored memory #{cursor.lastrowid} in category '{category}' with importance {importance:.1f}"
|
conn.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
# Serialized + retry-guarded through the shared LocalStore so concurrent
|
||||||
|
# stores (and the background sync writer) never collide on the SQLite writer.
|
||||||
|
new_id = self.store.transaction(_insert)
|
||||||
|
return f"Stored memory #{new_id} in category '{category}' with importance {importance:.1f}"
|
||||||
|
|
||||||
def _sqlite_recall(self, context: str, expanded_query: str, category: str | None, sort_by: str, limit: int) -> str:
|
def _sqlite_recall(self, context: str, expanded_query: str, category: str | None, sort_by: str, limit: int) -> str:
|
||||||
assert self.sqlite_conn is not None
|
assert self.store is not None
|
||||||
all_terms = f"{context} {expanded_query}".strip()
|
all_terms = f"{context} {expanded_query}".strip()
|
||||||
words = [w.replace(chr(34), "") for w in all_terms.split() if w]
|
words = [w.replace(chr(34), "") for w in all_terms.split() if w]
|
||||||
and_query = " AND ".join(f'"{w}"' for w in words)
|
and_query = " AND ".join(f'"{w}"' for w in words)
|
||||||
|
|
@ -712,35 +772,38 @@ class MemoryServer:
|
||||||
"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
|
"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
|
||||||
"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
|
"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
|
||||||
)
|
)
|
||||||
cursor = self.sqlite_conn.cursor()
|
|
||||||
rows: list[Any] = []
|
rows: list[Any] = []
|
||||||
try:
|
# Hold the shared lock for the whole read — the connection is shared across
|
||||||
# Try AND first for precise matches, fall back to OR for broader results
|
# threads with the sync writer, so reads must serialize too.
|
||||||
cat_filter = "AND m.category = ?" if category else ""
|
with self.store.lock:
|
||||||
for fts_query in (and_query, or_query):
|
cursor = self.store.conn.cursor()
|
||||||
params = [fts_query, category, limit] if category else [fts_query, limit]
|
try:
|
||||||
cursor.execute(
|
# Try AND first for precise matches, fall back to OR for broader results
|
||||||
f"{base_select}WHERE memories_fts MATCH ? {cat_filter} ORDER BY {order} LIMIT ?",
|
cat_filter = "AND m.category = ?" if category else ""
|
||||||
tuple(p for p in params if p is not None),
|
for fts_query in (and_query, or_query):
|
||||||
)
|
params = [fts_query, category, limit] if category else [fts_query, limit]
|
||||||
|
cursor.execute(
|
||||||
|
f"{base_select}WHERE memories_fts MATCH ? {cat_filter} ORDER BY {order} LIMIT ?",
|
||||||
|
tuple(p for p in params if p is not None),
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
if rows:
|
||||||
|
break
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
like = f"%{context}%"
|
||||||
|
if category:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||||
|
"WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?",
|
||||||
|
(like, like, category, limit),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||||
|
"WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?",
|
||||||
|
(like, like, limit),
|
||||||
|
)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
if rows:
|
|
||||||
break
|
|
||||||
except sqlite3.OperationalError:
|
|
||||||
like = f"%{context}%"
|
|
||||||
if category:
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
|
||||||
"WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?",
|
|
||||||
(like, like, category, limit),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
|
||||||
"WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?",
|
|
||||||
(like, like, limit),
|
|
||||||
)
|
|
||||||
rows = cursor.fetchall()
|
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
return f"No memories found matching: {context}"
|
return f"No memories found matching: {context}"
|
||||||
|
|
@ -757,21 +820,22 @@ class MemoryServer:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sqlite_list(self, category: str | None, limit: int) -> str:
|
def _sqlite_list(self, category: str | None, limit: int) -> str:
|
||||||
assert self.sqlite_conn is not None
|
assert self.store is not None
|
||||||
cursor = self.sqlite_conn.cursor()
|
with self.store.lock:
|
||||||
if category:
|
cursor = self.store.conn.cursor()
|
||||||
cursor.execute(
|
if category:
|
||||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
cursor.execute(
|
||||||
"WHERE category = ? ORDER BY created_at DESC LIMIT ?",
|
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||||
(category, limit),
|
"WHERE category = ? ORDER BY created_at DESC LIMIT ?",
|
||||||
)
|
(category, limit),
|
||||||
else:
|
)
|
||||||
cursor.execute(
|
else:
|
||||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
cursor.execute(
|
||||||
"ORDER BY created_at DESC LIMIT ?",
|
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||||
(limit,),
|
"ORDER BY created_at DESC LIMIT ?",
|
||||||
)
|
(limit,),
|
||||||
rows = cursor.fetchall()
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
if not rows:
|
if not rows:
|
||||||
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
||||||
|
|
||||||
|
|
@ -785,25 +849,30 @@ class MemoryServer:
|
||||||
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
||||||
|
|
||||||
def _sqlite_delete(self, memory_id: int) -> str:
|
def _sqlite_delete(self, memory_id: int) -> str:
|
||||||
assert self.sqlite_conn is not None
|
assert self.store is not None
|
||||||
cursor = self.sqlite_conn.cursor()
|
|
||||||
cursor.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,))
|
def _delete(conn: sqlite3.Connection) -> str:
|
||||||
row = cursor.fetchone()
|
cursor = conn.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,))
|
||||||
if not row:
|
row = cursor.fetchone()
|
||||||
return f"Memory #{memory_id} not found"
|
if not row:
|
||||||
preview = row["content"][:50]
|
return f"Memory #{memory_id} not found"
|
||||||
cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
|
preview = row["content"][:50]
|
||||||
self.sqlite_conn.commit()
|
conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
|
||||||
return f"Deleted memory #{memory_id}: {preview}..."
|
conn.commit()
|
||||||
|
return f"Deleted memory #{memory_id}: {preview}..."
|
||||||
|
|
||||||
|
# SELECT + DELETE + commit as one serialized, retry-guarded unit.
|
||||||
|
return self.store.transaction(_delete)
|
||||||
|
|
||||||
def _sqlite_secret_get(self, memory_id: int) -> str:
|
def _sqlite_secret_get(self, memory_id: int) -> str:
|
||||||
assert self.sqlite_conn is not None
|
assert self.store is not None
|
||||||
cursor = self.sqlite_conn.cursor()
|
with self.store.lock:
|
||||||
cursor.execute(
|
cursor = self.store.conn.cursor()
|
||||||
"SELECT id, content, category, is_sensitive FROM memories WHERE id = ?",
|
cursor.execute(
|
||||||
(memory_id,),
|
"SELECT id, content, category, is_sensitive FROM memories WHERE id = ?",
|
||||||
)
|
(memory_id,),
|
||||||
row = cursor.fetchone()
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return f"Memory #{memory_id} not found"
|
return f"Memory #{memory_id} not found"
|
||||||
if not row["is_sensitive"]:
|
if not row["is_sensitive"]:
|
||||||
|
|
@ -825,9 +894,25 @@ class MemoryServer:
|
||||||
tools.extend(SHARING_TOOLS)
|
tools.extend(SHARING_TOOLS)
|
||||||
return {"tools": tools}
|
return {"tools": tools}
|
||||||
|
|
||||||
|
# Tools whose work is genuinely slow enough to warrant a progress signal.
|
||||||
|
_SLOW_TOOLS = frozenset({"memory_recall"})
|
||||||
|
|
||||||
def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]:
|
def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
tool_name: str = params.get("name", "")
|
tool_name: str = params.get("name", "")
|
||||||
arguments: dict[str, Any] = params.get("arguments", {})
|
arguments: dict[str, Any] = params.get("arguments", {})
|
||||||
|
|
||||||
|
# If the client opted into progress (sent a token) and this is a slow tool, emit a
|
||||||
|
# single "working" progress notification so the call shows life instead of looking hung.
|
||||||
|
progress_token = (params.get("_meta") or {}).get("progressToken")
|
||||||
|
if progress_token is not None and tool_name in self._SLOW_TOOLS:
|
||||||
|
try:
|
||||||
|
self._notify(
|
||||||
|
"notifications/progress",
|
||||||
|
{"progressToken": progress_token, "progress": 0, "message": f"Running {tool_name}…"},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # never let progress reporting break the actual call
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handler = {
|
handler = {
|
||||||
"memory_store": self.memory_store,
|
"memory_store": self.memory_store,
|
||||||
|
|
@ -851,6 +936,10 @@ class MemoryServer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True}
|
return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True}
|
||||||
|
|
||||||
|
def _emit_notification(self, method: str, params: dict[str, Any]) -> None:
|
||||||
|
"""Default notification sink: write a JSON-RPC notification line to stdout."""
|
||||||
|
print(json.dumps({"jsonrpc": "2.0", "method": method, "params": params}), flush=True)
|
||||||
|
|
||||||
def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
|
def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
method = message.get("method")
|
method = message.get("method")
|
||||||
params = message.get("params", {})
|
params = message.get("params", {})
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,14 @@ Uses only stdlib — no pip install required.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
|
||||||
import threading
|
import threading
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
|
||||||
|
from claude_memory.local_store import LocalStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -26,7 +26,14 @@ FULL_RESYNC_EVERY = 10
|
||||||
class SyncEngine:
|
class SyncEngine:
|
||||||
"""Background sync between local SQLite cache and remote API."""
|
"""Background sync between local SQLite cache and remote API."""
|
||||||
|
|
||||||
def __init__(self, db_path: str, api_base_url: str, api_key: str, sync_interval: int = 60):
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str,
|
||||||
|
api_base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
sync_interval: int = 60,
|
||||||
|
store: "LocalStore | None" = None,
|
||||||
|
):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.api_base_url = api_base_url.rstrip("/")
|
self.api_base_url = api_base_url.rstrip("/")
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
@ -37,13 +44,20 @@ class SyncEngine:
|
||||||
self._last_sync_success = False
|
self._last_sync_success = False
|
||||||
self._auth_failed = False
|
self._auth_failed = False
|
||||||
|
|
||||||
# Own connection for thread safety
|
# Share the caller's LocalStore (one connection, one lock) when given, so the
|
||||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
# background sync writer never opens a SECOND connection to the same file and
|
||||||
self._conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False)
|
# never races the store path on the single SQLite writer. When run standalone
|
||||||
self._conn.row_factory = sqlite3.Row
|
# (e.g. tests), open our own store. Either way, all SQLite access below goes
|
||||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
# through self._store / self._conn / self._lock, which now point at one shared
|
||||||
self._conn.execute("PRAGMA busy_timeout=30000")
|
# connection guarded by one re-entrant lock.
|
||||||
self._lock = threading.Lock()
|
if store is None:
|
||||||
|
self._store = LocalStore.open(db_path)
|
||||||
|
self._owns_store = True
|
||||||
|
else:
|
||||||
|
self._store = store
|
||||||
|
self._owns_store = False
|
||||||
|
self._conn = self._store.conn
|
||||||
|
self._lock = self._store.lock
|
||||||
|
|
||||||
self._init_sync_tables()
|
self._init_sync_tables()
|
||||||
|
|
||||||
|
|
@ -121,7 +135,10 @@ class SyncEngine:
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
if self._thread and self._thread.is_alive():
|
if self._thread and self._thread.is_alive():
|
||||||
self._thread.join(timeout=5)
|
self._thread.join(timeout=5)
|
||||||
self._conn.close()
|
# Only close the connection if we opened it; when sharing the server's
|
||||||
|
# LocalStore, the server owns the connection lifecycle.
|
||||||
|
if self._owns_store:
|
||||||
|
self._store.close()
|
||||||
|
|
||||||
def _sync_loop(self) -> None:
|
def _sync_loop(self) -> None:
|
||||||
"""Periodic sync loop running in background thread."""
|
"""Periodic sync loop running in background thread."""
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,12 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -408,3 +413,244 @@ class TestSchemaMigration:
|
||||||
columns = {row["name"] for row in cursor.fetchall()}
|
columns = {row["name"] for row in cursor.fetchall()}
|
||||||
assert "server_id" in columns
|
assert "server_id" in columns
|
||||||
srv.sqlite_conn.close()
|
srv.sqlite_conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _server_rows(server: MemoryServer) -> int:
|
||||||
|
assert server.sqlite_conn is not None
|
||||||
|
return int(server.sqlite_conn.execute("SELECT COUNT(*) AS c FROM memories").fetchone()["c"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentWrites:
|
||||||
|
"""Regression tests for 'database is locked' under heavy concurrent local writes.
|
||||||
|
|
||||||
|
The store path (MemoryServer) and the background sync writer (SyncEngine) must not
|
||||||
|
contend on the single SQLite writer. Before the fix they used two separate connections
|
||||||
|
to the same file; under load the second writer blew past busy_timeout and raised
|
||||||
|
``sqlite3.OperationalError: database is locked``. After the fix all local writes are
|
||||||
|
serialized through one shared, lock-guarded connection, so a lock error is impossible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_concurrent_stores_all_rows_land(self, tmp_path):
|
||||||
|
"""Many threads calling memory_store concurrently → every row lands, no lock error."""
|
||||||
|
db_path = str(tmp_path / "concurrent.db")
|
||||||
|
server = MemoryServer(sqlite_db_path=db_path)
|
||||||
|
try:
|
||||||
|
n_threads = 16
|
||||||
|
per_thread = 12
|
||||||
|
errors: list[BaseException] = []
|
||||||
|
barrier = threading.Barrier(n_threads)
|
||||||
|
|
||||||
|
def worker(tid: int) -> None:
|
||||||
|
barrier.wait() # release all threads at once to maximise contention
|
||||||
|
for i in range(per_thread):
|
||||||
|
try:
|
||||||
|
server.memory_store({
|
||||||
|
"content": f"thread {tid} memory {i}",
|
||||||
|
"expanded_keywords": f"thread {tid} memory {i} concurrent write",
|
||||||
|
})
|
||||||
|
except BaseException as exc: # noqa: BLE001 — capture everything for the assert
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=worker, args=(t,)) for t in range(n_threads)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert errors == [], f"concurrent stores raised: {errors!r}"
|
||||||
|
assert _server_rows(server) == n_threads * per_thread
|
||||||
|
finally:
|
||||||
|
if server.sqlite_conn:
|
||||||
|
server.sqlite_conn.close()
|
||||||
|
|
||||||
|
def test_concurrent_stores_while_sync_writer_active_no_lock(self, tmp_path):
|
||||||
|
"""Store path + background sync writer hammer the SAME file → no 'database is locked'.
|
||||||
|
|
||||||
|
Reproduces the production failure: ``MemoryServer`` and ``SyncEngine`` both write to
|
||||||
|
one SQLite file. We shrink ``busy_timeout`` so the structural race (two writers fighting
|
||||||
|
the single SQLite write lock) surfaces within seconds instead of 30s. The shared-writer
|
||||||
|
fix makes a lock error impossible regardless of busy_timeout.
|
||||||
|
"""
|
||||||
|
from claude_memory.sync import SyncEngine
|
||||||
|
|
||||||
|
db_path = str(tmp_path / "hybrid.db")
|
||||||
|
server = MemoryServer(sqlite_db_path=db_path)
|
||||||
|
# The production hybrid wiring: the sync engine SHARES the server's LocalStore
|
||||||
|
# (one connection, one lock) — the structural fix for the cross-connection race.
|
||||||
|
engine = SyncEngine(
|
||||||
|
db_path=db_path,
|
||||||
|
api_base_url="http://fake-api:8080",
|
||||||
|
api_key="test-key",
|
||||||
|
sync_interval=3600, # never auto-syncs; we drive the writer by hand
|
||||||
|
store=server.store,
|
||||||
|
)
|
||||||
|
assert engine._conn is server.sqlite_conn # truly one shared connection
|
||||||
|
|
||||||
|
# Shrink the busy timeout so that, were there still two writers, the race would
|
||||||
|
# surface in ms not 30s. With one shared connection a lock error is impossible.
|
||||||
|
server.sqlite_conn.execute("PRAGMA busy_timeout=50")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
# A write-heavy pull: many rows upserted inside the sync engine's single lock block —
|
||||||
|
# exactly the kind of long-held writer that starved the store path.
|
||||||
|
def big_pull(method: str, path: str, body: Any = None) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"memories": [
|
||||||
|
{
|
||||||
|
"id": 10_000 + j,
|
||||||
|
"content": f"server row {j}",
|
||||||
|
"category": "facts",
|
||||||
|
"tags": "",
|
||||||
|
"expanded_keywords": "",
|
||||||
|
"importance": 0.5,
|
||||||
|
"is_sensitive": False,
|
||||||
|
"created_at": now,
|
||||||
|
"updated_at": now,
|
||||||
|
"deleted_at": None,
|
||||||
|
}
|
||||||
|
for j in range(120)
|
||||||
|
],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
|
||||||
|
errors: list[BaseException] = []
|
||||||
|
stop = threading.Event()
|
||||||
|
|
||||||
|
def sync_writer() -> None:
|
||||||
|
with patch.object(engine, "_api_request", side_effect=big_pull):
|
||||||
|
while not stop.is_set():
|
||||||
|
try:
|
||||||
|
engine._pull_changes()
|
||||||
|
except BaseException as exc: # noqa: BLE001
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
n_threads = 12
|
||||||
|
per_thread = 12
|
||||||
|
barrier = threading.Barrier(n_threads)
|
||||||
|
|
||||||
|
def store_worker(tid: int) -> None:
|
||||||
|
barrier.wait()
|
||||||
|
for i in range(per_thread):
|
||||||
|
try:
|
||||||
|
server.memory_store({
|
||||||
|
"content": f"local {tid}.{i}",
|
||||||
|
"expanded_keywords": f"local thread {tid} item {i} keywords here",
|
||||||
|
})
|
||||||
|
except BaseException as exc: # noqa: BLE001
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
writer = threading.Thread(target=sync_writer, daemon=True)
|
||||||
|
writer.start()
|
||||||
|
final_rows = 0
|
||||||
|
try:
|
||||||
|
store_threads = [threading.Thread(target=store_worker, args=(t,)) for t in range(n_threads)]
|
||||||
|
for t in store_threads:
|
||||||
|
t.start()
|
||||||
|
for t in store_threads:
|
||||||
|
t.join()
|
||||||
|
stop.set()
|
||||||
|
writer.join(timeout=5)
|
||||||
|
final_rows = _server_rows(server) # read while the connection is still open
|
||||||
|
finally:
|
||||||
|
stop.set()
|
||||||
|
engine.stop() # shares the server's store → does not close the connection
|
||||||
|
if server.sqlite_conn:
|
||||||
|
server.sqlite_conn.close()
|
||||||
|
|
||||||
|
locked = [e for e in errors if isinstance(e, sqlite3.OperationalError) and "locked" in str(e)]
|
||||||
|
assert locked == [], f"hit 'database is locked' under concurrency: {locked!r}"
|
||||||
|
assert errors == [], f"concurrent writes raised: {errors!r}"
|
||||||
|
# Every local store must have landed.
|
||||||
|
assert final_rows >= n_threads * per_thread
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecallProgressAndBounding:
|
||||||
|
"""The slow path — a remote semantic ``memory_recall`` — must be bounded and give a
|
||||||
|
notion of progress instead of hanging silently and dropping the session."""
|
||||||
|
|
||||||
|
def test_remote_recall_timeout_returns_clear_signal_not_raise(self, server):
|
||||||
|
"""A timed-out / unreachable remote recall returns a clear 'retry' message, never hangs/raises."""
|
||||||
|
import claude_memory.mcp_server as m
|
||||||
|
|
||||||
|
with patch.object(m, "_api_request", side_effect=RuntimeError("API connection error: timed out")):
|
||||||
|
text = server._recall_remote({"context": "x", "expanded_query": "a b c d e"})
|
||||||
|
|
||||||
|
assert "recall" in text.lower()
|
||||||
|
# Mentions the bound and that the caller should retry — a clear working/not-done signal.
|
||||||
|
assert "retry" in text.lower() or "again" in text.lower()
|
||||||
|
assert str(m.RECALL_TIMEOUT) in text
|
||||||
|
|
||||||
|
def test_remote_recall_success_formats_rows(self, server):
|
||||||
|
"""A successful remote recall still formats results normally."""
|
||||||
|
import claude_memory.mcp_server as m
|
||||||
|
|
||||||
|
payload = {"memories": [
|
||||||
|
{"id": 7, "category": "facts", "importance": 0.8, "content": "hello",
|
||||||
|
"tags": "t", "created_at": "2026-01-01T00:00:00+00:00"},
|
||||||
|
]}
|
||||||
|
with patch.object(m, "_api_request", return_value=payload):
|
||||||
|
text = server._recall_remote({"context": "x", "expanded_query": "a b c d e"})
|
||||||
|
|
||||||
|
assert "Found 1 memories" in text
|
||||||
|
assert "hello" in text
|
||||||
|
|
||||||
|
def test_remote_recall_uses_bounded_timeout(self, server):
|
||||||
|
"""The remote recall passes the bounded RECALL_TIMEOUT to the HTTP layer."""
|
||||||
|
import claude_memory.mcp_server as m
|
||||||
|
|
||||||
|
with patch.object(m, "_api_request", return_value={"memories": []}) as mock_api:
|
||||||
|
server._recall_remote({"context": "x", "expanded_query": "a b c d e"})
|
||||||
|
|
||||||
|
_, kwargs = mock_api.call_args
|
||||||
|
assert kwargs.get("timeout") == m.RECALL_TIMEOUT
|
||||||
|
|
||||||
|
def test_api_request_accepts_timeout_kwarg(self):
|
||||||
|
"""Module-level _api_request must accept an explicit timeout without breaking callers."""
|
||||||
|
import inspect
|
||||||
|
import claude_memory.mcp_server as m
|
||||||
|
|
||||||
|
sig = inspect.signature(m._api_request)
|
||||||
|
assert "timeout" in sig.parameters
|
||||||
|
|
||||||
|
def test_progress_notification_emitted_for_recall_with_token(self, server):
|
||||||
|
"""When the client supplies a progressToken, a 'working' progress notification is emitted."""
|
||||||
|
sent: list[dict[str, Any]] = []
|
||||||
|
server._notify = lambda method, params: sent.append({"method": method, "params": params})
|
||||||
|
|
||||||
|
with patch.object(server, "memory_recall", return_value="ok"):
|
||||||
|
server.handle_tools_call({
|
||||||
|
"name": "memory_recall",
|
||||||
|
"arguments": {"context": "x", "expanded_query": "a b c d e"},
|
||||||
|
"_meta": {"progressToken": "tok-1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
progress = [s for s in sent if s["method"] == "notifications/progress"]
|
||||||
|
assert progress, "expected a notifications/progress for a tokened recall"
|
||||||
|
assert progress[0]["params"]["progressToken"] == "tok-1"
|
||||||
|
|
||||||
|
def test_no_progress_notification_without_token(self, server):
|
||||||
|
"""No token → no progress notification (don't spam clients that didn't opt in)."""
|
||||||
|
sent: list[dict[str, Any]] = []
|
||||||
|
server._notify = lambda method, params: sent.append({"method": method, "params": params})
|
||||||
|
|
||||||
|
with patch.object(server, "memory_recall", return_value="ok"):
|
||||||
|
server.handle_tools_call({
|
||||||
|
"name": "memory_recall",
|
||||||
|
"arguments": {"context": "x", "expanded_query": "a b c d e"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert [s for s in sent if s["method"] == "notifications/progress"] == []
|
||||||
|
|
||||||
|
def test_fast_tools_do_not_emit_progress(self, server):
|
||||||
|
"""Only the slow recall path signals progress; a store with a token stays quiet."""
|
||||||
|
sent: list[dict[str, Any]] = []
|
||||||
|
server._notify = lambda method, params: sent.append({"method": method, "params": params})
|
||||||
|
|
||||||
|
server.handle_tools_call({
|
||||||
|
"name": "memory_store",
|
||||||
|
"arguments": {"content": "x", "expanded_keywords": "a b c d e"},
|
||||||
|
"_meta": {"progressToken": "tok-2"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert [s for s in sent if s["method"] == "notifications/progress"] == []
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue