diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5b4b20c..463cf3a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -70,6 +70,7 @@ jobs: - uses: docker/build-push-action@v7 with: context: . + file: docker/Dockerfile push: true platforms: linux/amd64 # Single-manifest images (no provenance/SBOM attestation children) so diff --git a/.gitignore b/.gitignore index fb63210..4a896e7 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,9 @@ docker/pgdata/ # Beads / Dolt files (added by bd init) .dolt/ .beads-credential-key + +# Agent git worktrees (standing policy: never the shared checkout) +.worktrees/ + +# uv lockfile — CI runs `uv sync` itself; not tracked +uv.lock diff --git a/src/claude_memory/local_store.py b/src/claude_memory/local_store.py new file mode 100644 index 0000000..09cd3d9 --- /dev/null +++ b/src/claude_memory/local_store.py @@ -0,0 +1,116 @@ +"""Single, process-wide serialized SQLite writer for the local memory cache. + +SQLite permits only one writer at a time. The MCP server's store path and the +background sync engine used to open *separate* connections to the *same* file; +under heavy concurrent ``memory_store`` calls those two writers fought over the +single SQLite write lock, blew past ``busy_timeout``, and surfaced +``sqlite3.OperationalError: database is locked`` — which made the tool slow and +eventually dropped the session. + +``LocalStore`` fixes this structurally: it owns ONE connection (opened with +``check_same_thread=False``) guarded by ONE re-entrant lock. Every component that +needs to touch the local DB shares the same ``LocalStore`` instance, so all +writes serialize cleanly through the in-process lock and queue instead of racing +the SQLite writer. On the rare residual lock (e.g. another OS process touching +the file), writes retry with bounded exponential backoff rather than failing the +caller. WAL stays on for concurrent reads. + +Uses only stdlib — no pip install required. +""" + +import logging +import sqlite3 +import threading +import time +from pathlib import Path +from typing import Any, Callable, TypeVar + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +# Bounded retry window for the rare residual "database is locked" — handles a +# lock held by a *different OS process* (the in-process lock already serializes +# this process's own writers). Total worst-case wait ≈ 0.05+0.1+0.2+0.4+0.8 ≈ 1.55s. +_MAX_RETRIES = 5 +_BASE_BACKOFF_S = 0.05 +_BUSY_TIMEOUT_MS = 30000 + + +def _is_locked_error(exc: sqlite3.OperationalError) -> bool: + msg = str(exc).lower() + return "database is locked" in msg or "database is busy" in msg + + +class LocalStore: + """Owns the single shared SQLite connection + lock for local memory writes.""" + + def __init__(self, conn: sqlite3.Connection) -> None: + self.conn = conn + # Re-entrant so a transaction callback may itself call ``execute``/``write`` + # without dead-locking on the same thread. + self.lock = threading.RLock() + + # ── construction ──────────────────────────────────────────────── + + @classmethod + def open(cls, db_path: str) -> "LocalStore": + """Open (creating parent dirs) a WAL connection safe for cross-thread use.""" + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute(f"PRAGMA busy_timeout={_BUSY_TIMEOUT_MS}") + return cls(conn) + + # ── serialized access ─────────────────────────────────────────── + + def transaction(self, fn: Callable[[sqlite3.Connection], T]) -> T: + """Run ``fn(conn)`` holding the shared lock, with bounded retry on lock errors. + + ``fn`` is responsible for issuing its own ``COMMIT`` (call ``conn.commit()``) + when it performs writes. The whole callback runs under the process-wide lock, + so concurrent callers queue rather than collide on the SQLite writer. + """ + last_exc: sqlite3.OperationalError | None = None + for attempt in range(_MAX_RETRIES): + with self.lock: + try: + return fn(self.conn) + except sqlite3.OperationalError as exc: + if not _is_locked_error(exc): + raise + last_exc = exc + # Roll back any partial txn so the retry starts clean and the + # connection isn't left mid-transaction holding locks. + try: + self.conn.rollback() + except sqlite3.Error: + pass + # Back off *outside* the lock so other writers can make progress. + backoff = _BASE_BACKOFF_S * (2 ** attempt) + logger.warning( + "SQLite locked (attempt %d/%d) — backing off %.3fs", attempt + 1, _MAX_RETRIES, backoff + ) + time.sleep(backoff) + assert last_exc is not None + raise last_exc + + def execute(self, sql: str, params: tuple[Any, ...] = ()) -> sqlite3.Cursor: + """Run a single read query under the shared lock (no implicit commit).""" + with self.lock: + return self.conn.execute(sql, params) + + def write(self, sql: str, params: tuple[Any, ...] = ()) -> sqlite3.Cursor: + """Run a single write statement + commit, serialized and retry-guarded.""" + + def _do(conn: sqlite3.Connection) -> sqlite3.Cursor: + cur = conn.execute(sql, params) + conn.commit() + return cur + + return self.transaction(_do) + + def close(self) -> None: + with self.lock: + self.conn.close() diff --git a/src/claude_memory/mcp_server.py b/src/claude_memory/mcp_server.py index a3fa210..367ed71 100644 --- a/src/claude_memory/mcp_server.py +++ b/src/claude_memory/mcp_server.py @@ -17,7 +17,10 @@ import sqlite3 import sys import urllib.error import urllib.request -from typing import Any +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from claude_memory.local_store import LocalStore logger = logging.getLogger(__name__) @@ -35,9 +38,17 @@ HYBRID_MODE = bool(API_KEY) and not SYNC_DISABLED HTTP_ONLY = bool(API_KEY) and SYNC_DISABLED SQLITE_ONLY = not API_KEY +# Default per-request HTTP timeout, and a wider bound for the one genuinely slow path: +# a remote semantic ``memory_recall`` (embedding/search can be slow to warm up). Both are +# explicit upper bounds so a call can never hang the MCP server indefinitely. +DEFAULT_API_TIMEOUT = 15 +RECALL_TIMEOUT = int(os.environ.get("MEMORY_RECALL_TIMEOUT", "30")) -def _api_request(method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]: - """Make an HTTP request to the memory API.""" + +def _api_request( + method: str, path: str, body: dict[str, Any] | None = None, timeout: int = DEFAULT_API_TIMEOUT +) -> dict[str, Any]: + """Make an HTTP request to the memory API (bounded by ``timeout`` seconds).""" url = f"{API_BASE_URL}{path}" data = json.dumps(body).encode() if body else None req = urllib.request.Request( @@ -50,7 +61,7 @@ def _api_request(method: str, path: str, body: dict[str, Any] | None = None) -> }, ) try: - with urllib.request.urlopen(req, timeout=15) as resp: + with urllib.request.urlopen(req, timeout=timeout) as resp: result: dict[str, Any] = json.loads(resp.read().decode()) return result except urllib.error.HTTPError as e: @@ -128,7 +139,9 @@ def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]: db_path = _get_db_path(db_path) Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(db_path, timeout=30.0) + # check_same_thread=False: the MCP server may handle requests on different + # threads and shares this one connection via LocalStore's lock (see local_store.py). + conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA busy_timeout=30000") @@ -390,19 +403,30 @@ class MemoryServer: def __init__(self, sqlite_db_path: str | None = None) -> None: self.sqlite_conn: sqlite3.Connection | None = None + self.store: "LocalStore | None" = None # single serialized writer (see local_store.py) self.sync_engine: Any = None + # Sink for MCP notifications (e.g. progress). Defaults to writing a JSON-RPC + # notification line to stdout; overridable in tests. + self._notify: Callable[[str, dict[str, Any]], None] = self._emit_notification if SQLITE_ONLY or HYBRID_MODE: - self.sqlite_conn, resolved_path = _init_sqlite(sqlite_db_path) + conn, resolved_path = _init_sqlite(sqlite_db_path) + from claude_memory.local_store import LocalStore + self.store = LocalStore(conn) + self.sqlite_conn = conn if HYBRID_MODE: from claude_memory.sync import SyncEngine sync_interval = int(os.environ.get("MEMORY_SYNC_INTERVAL", "60")) + # Share the SAME LocalStore (one connection, one lock) so the background + # sync writer never opens a second connection to the file and never races + # the store path on the single SQLite writer. self.sync_engine = SyncEngine( db_path=resolved_path, api_base_url=API_BASE_URL, api_key=API_KEY, sync_interval=sync_interval, + store=self.store, ) self.sync_engine.start() @@ -455,31 +479,59 @@ class MemoryServer: limit = args.get("limit", 10) if HTTP_ONLY: - result = _api_request("POST", "/api/memories/recall", { - "context": context, - "expanded_query": expanded_query, - "category": category, - "sort_by": sort_by, - "limit": limit, - }) - rows = result.get("memories", []) - if not rows: - filter_desc = f" in category '{category}'" if category else "" - return f"No memories found matching: {context}{filter_desc}" - - sort_desc = "by relevance" if sort_by == "relevance" else "by importance" - filter_desc = f" in '{category}'" if category else "" - results = [] - for row in rows: - results.append( - f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}" - f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" - ) - return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results) + return self._recall_remote(args) # SQLite-only or Hybrid: always read from local cache return self._sqlite_recall(context, expanded_query, category, sort_by, limit) + def _recall_remote(self, args: dict[str, Any]) -> str: + """Remote semantic recall — the one genuinely slow path. + + Bounded by ``RECALL_TIMEOUT`` so it can never hang the MCP server. On a timeout + or unreachable backend it returns a clear "working / not done — retry" signal + rather than raising or blocking silently. + """ + context = args.get("context") + expanded_query = args.get("expanded_query", "") + category = args.get("category") + sort_by = args.get("sort_by", "importance") + limit = args.get("limit", 10) + + try: + result = _api_request( + "POST", "/api/memories/recall", + { + "context": context, + "expanded_query": expanded_query, + "category": category, + "sort_by": sort_by, + "limit": limit, + }, + timeout=RECALL_TIMEOUT, + ) + except RuntimeError as e: + # _api_request wraps connection errors / timeouts as RuntimeError. Surface a + # clear, actionable signal instead of hanging or dumping a stack trace. + return ( + f"Memory recall did not complete within {RECALL_TIMEOUT}s — the semantic " + f"search backend is slow or unreachable ({e}). Please try again." + ) + + rows = result.get("memories", []) + if not rows: + filter_desc = f" in category '{category}'" if category else "" + return f"No memories found matching: {context}{filter_desc}" + + sort_desc = "by relevance" if sort_by == "relevance" else "by importance" + filter_desc = f" in '{category}'" if category else "" + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}" + f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" + ) + return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results) + def memory_list(self, args: dict[str, Any]) -> str: category = args.get("category") limit = args.get("limit", 20) @@ -519,10 +571,11 @@ class MemoryServer: # SQLite-only or Hybrid: delete from local SQLite # In hybrid mode, also try to sync delete to server server_id: int | None = None - if HYBRID_MODE and self.sync_engine and self.sqlite_conn: - cursor = self.sqlite_conn.cursor() - cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,)) - row = cursor.fetchone() + if HYBRID_MODE and self.sync_engine and self.store: + with self.store.lock: + cursor = self.store.conn.cursor() + cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,)) + row = cursor.fetchone() server_id = row["server_id"] if row and row["server_id"] else None result_text = self._sqlite_delete(memory_id) @@ -563,12 +616,13 @@ class MemoryServer: lines.append(f"Last sync success: {counts['last_sync_success']}") return "\n".join(lines) - if self.sqlite_conn: - cursor = self.sqlite_conn.cursor() - cursor.execute("SELECT COUNT(*) as c FROM memories") - total = cursor.fetchone()["c"] - cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC") - by_cat = cursor.fetchall() + if self.store: + with self.store.lock: + cursor = self.store.conn.cursor() + cursor.execute("SELECT COUNT(*) as c FROM memories") + total = cursor.fetchone()["c"] + cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC") + by_cat = cursor.fetchall() lines = [f"Local memories (SQLite-only): {total}"] for row in by_cat: lines.append(f" {row['category']}: {row['c']}") @@ -682,19 +736,25 @@ class MemoryServer: def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str: from datetime import datetime, timezone - assert self.sqlite_conn is not None + assert self.store is not None now = datetime.now(timezone.utc).isoformat() is_sensitive = 1 if force_sensitive else 0 - cursor = self.sqlite_conn.cursor() - cursor.execute( - "INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - (content, category, tags, expanded_keywords, importance, is_sensitive, now, now), - ) - self.sqlite_conn.commit() - return f"Stored memory #{cursor.lastrowid} in category '{category}' with importance {importance:.1f}" + + def _insert(conn: sqlite3.Connection) -> int | None: + cursor = conn.execute( + "INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (content, category, tags, expanded_keywords, importance, is_sensitive, now, now), + ) + conn.commit() + return cursor.lastrowid + + # Serialized + retry-guarded through the shared LocalStore so concurrent + # stores (and the background sync writer) never collide on the SQLite writer. + new_id = self.store.transaction(_insert) + return f"Stored memory #{new_id} in category '{category}' with importance {importance:.1f}" def _sqlite_recall(self, context: str, expanded_query: str, category: str | None, sort_by: str, limit: int) -> str: - assert self.sqlite_conn is not None + assert self.store is not None all_terms = f"{context} {expanded_query}".strip() words = [w.replace(chr(34), "") for w in all_terms.split() if w] and_query = " AND ".join(f'"{w}"' for w in words) @@ -712,35 +772,38 @@ class MemoryServer: "SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at " "FROM memories m JOIN memories_fts fts ON m.id = fts.rowid " ) - cursor = self.sqlite_conn.cursor() rows: list[Any] = [] - try: - # Try AND first for precise matches, fall back to OR for broader results - cat_filter = "AND m.category = ?" if category else "" - for fts_query in (and_query, or_query): - params = [fts_query, category, limit] if category else [fts_query, limit] - cursor.execute( - f"{base_select}WHERE memories_fts MATCH ? {cat_filter} ORDER BY {order} LIMIT ?", - tuple(p for p in params if p is not None), - ) + # Hold the shared lock for the whole read — the connection is shared across + # threads with the sync writer, so reads must serialize too. + with self.store.lock: + cursor = self.store.conn.cursor() + try: + # Try AND first for precise matches, fall back to OR for broader results + cat_filter = "AND m.category = ?" if category else "" + for fts_query in (and_query, or_query): + params = [fts_query, category, limit] if category else [fts_query, limit] + cursor.execute( + f"{base_select}WHERE memories_fts MATCH ? {cat_filter} ORDER BY {order} LIMIT ?", + tuple(p for p in params if p is not None), + ) + rows = cursor.fetchall() + if rows: + break + except sqlite3.OperationalError: + like = f"%{context}%" + if category: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?", + (like, like, category, limit), + ) + else: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?", + (like, like, limit), + ) rows = cursor.fetchall() - if rows: - break - except sqlite3.OperationalError: - like = f"%{context}%" - if category: - cursor.execute( - "SELECT id, content, category, tags, importance, created_at FROM memories " - "WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?", - (like, like, category, limit), - ) - else: - cursor.execute( - "SELECT id, content, category, tags, importance, created_at FROM memories " - "WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?", - (like, like, limit), - ) - rows = cursor.fetchall() if not rows: return f"No memories found matching: {context}" @@ -757,21 +820,22 @@ class MemoryServer: ) def _sqlite_list(self, category: str | None, limit: int) -> str: - assert self.sqlite_conn is not None - cursor = self.sqlite_conn.cursor() - if category: - cursor.execute( - "SELECT id, content, category, tags, importance, created_at FROM memories " - "WHERE category = ? ORDER BY created_at DESC LIMIT ?", - (category, limit), - ) - else: - cursor.execute( - "SELECT id, content, category, tags, importance, created_at FROM memories " - "ORDER BY created_at DESC LIMIT ?", - (limit,), - ) - rows = cursor.fetchall() + assert self.store is not None + with self.store.lock: + cursor = self.store.conn.cursor() + if category: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "WHERE category = ? ORDER BY created_at DESC LIMIT ?", + (category, limit), + ) + else: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "ORDER BY created_at DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() if not rows: return f"No memories in category '{category}'" if category else "No memories stored yet" @@ -785,25 +849,30 @@ class MemoryServer: return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results) def _sqlite_delete(self, memory_id: int) -> str: - assert self.sqlite_conn is not None - cursor = self.sqlite_conn.cursor() - cursor.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,)) - row = cursor.fetchone() - if not row: - return f"Memory #{memory_id} not found" - preview = row["content"][:50] - cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) - self.sqlite_conn.commit() - return f"Deleted memory #{memory_id}: {preview}..." + assert self.store is not None + + def _delete(conn: sqlite3.Connection) -> str: + cursor = conn.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,)) + row = cursor.fetchone() + if not row: + return f"Memory #{memory_id} not found" + preview = row["content"][:50] + conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) + conn.commit() + return f"Deleted memory #{memory_id}: {preview}..." + + # SELECT + DELETE + commit as one serialized, retry-guarded unit. + return self.store.transaction(_delete) def _sqlite_secret_get(self, memory_id: int) -> str: - assert self.sqlite_conn is not None - cursor = self.sqlite_conn.cursor() - cursor.execute( - "SELECT id, content, category, is_sensitive FROM memories WHERE id = ?", - (memory_id,), - ) - row = cursor.fetchone() + assert self.store is not None + with self.store.lock: + cursor = self.store.conn.cursor() + cursor.execute( + "SELECT id, content, category, is_sensitive FROM memories WHERE id = ?", + (memory_id,), + ) + row = cursor.fetchone() if not row: return f"Memory #{memory_id} not found" if not row["is_sensitive"]: @@ -825,9 +894,25 @@ class MemoryServer: tools.extend(SHARING_TOOLS) return {"tools": tools} + # Tools whose work is genuinely slow enough to warrant a progress signal. + _SLOW_TOOLS = frozenset({"memory_recall"}) + def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]: tool_name: str = params.get("name", "") arguments: dict[str, Any] = params.get("arguments", {}) + + # If the client opted into progress (sent a token) and this is a slow tool, emit a + # single "working" progress notification so the call shows life instead of looking hung. + progress_token = (params.get("_meta") or {}).get("progressToken") + if progress_token is not None and tool_name in self._SLOW_TOOLS: + try: + self._notify( + "notifications/progress", + {"progressToken": progress_token, "progress": 0, "message": f"Running {tool_name}…"}, + ) + except Exception: + pass # never let progress reporting break the actual call + try: handler = { "memory_store": self.memory_store, @@ -851,6 +936,10 @@ class MemoryServer: except Exception as e: return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True} + def _emit_notification(self, method: str, params: dict[str, Any]) -> None: + """Default notification sink: write a JSON-RPC notification line to stdout.""" + print(json.dumps({"jsonrpc": "2.0", "method": method, "params": params}), flush=True) + def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None: method = message.get("method") params = message.get("params", {}) diff --git a/src/claude_memory/sync.py b/src/claude_memory/sync.py index b4d1f99..86acda4 100644 --- a/src/claude_memory/sync.py +++ b/src/claude_memory/sync.py @@ -5,14 +5,14 @@ Uses only stdlib — no pip install required. import json import logging -import sqlite3 import threading import urllib.error import urllib.parse import urllib.request from typing import Any from datetime import datetime, timezone -from pathlib import Path + +from claude_memory.local_store import LocalStore logger = logging.getLogger(__name__) @@ -26,7 +26,14 @@ FULL_RESYNC_EVERY = 10 class SyncEngine: """Background sync between local SQLite cache and remote API.""" - def __init__(self, db_path: str, api_base_url: str, api_key: str, sync_interval: int = 60): + def __init__( + self, + db_path: str, + api_base_url: str, + api_key: str, + sync_interval: int = 60, + store: "LocalStore | None" = None, + ): self.db_path = db_path self.api_base_url = api_base_url.rstrip("/") self.api_key = api_key @@ -37,13 +44,20 @@ class SyncEngine: self._last_sync_success = False self._auth_failed = False - # Own connection for thread safety - Path(db_path).parent.mkdir(parents=True, exist_ok=True) - self._conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False) - self._conn.row_factory = sqlite3.Row - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA busy_timeout=30000") - self._lock = threading.Lock() + # Share the caller's LocalStore (one connection, one lock) when given, so the + # background sync writer never opens a SECOND connection to the same file and + # never races the store path on the single SQLite writer. When run standalone + # (e.g. tests), open our own store. Either way, all SQLite access below goes + # through self._store / self._conn / self._lock, which now point at one shared + # connection guarded by one re-entrant lock. + if store is None: + self._store = LocalStore.open(db_path) + self._owns_store = True + else: + self._store = store + self._owns_store = False + self._conn = self._store.conn + self._lock = self._store.lock self._init_sync_tables() @@ -121,7 +135,10 @@ class SyncEngine: self._stop_event.set() if self._thread and self._thread.is_alive(): self._thread.join(timeout=5) - self._conn.close() + # Only close the connection if we opened it; when sharing the server's + # LocalStore, the server owns the connection lifecycle. + if self._owns_store: + self._store.close() def _sync_loop(self) -> None: """Periodic sync loop running in background thread.""" diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 584fd69..e470f5e 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -2,7 +2,12 @@ import json import os +import sqlite3 import sys +import threading +from datetime import datetime, timezone +from typing import Any +from unittest.mock import patch import pytest @@ -408,3 +413,244 @@ class TestSchemaMigration: columns = {row["name"] for row in cursor.fetchall()} assert "server_id" in columns srv.sqlite_conn.close() + + +def _server_rows(server: MemoryServer) -> int: + assert server.sqlite_conn is not None + return int(server.sqlite_conn.execute("SELECT COUNT(*) AS c FROM memories").fetchone()["c"]) + + +class TestConcurrentWrites: + """Regression tests for 'database is locked' under heavy concurrent local writes. + + The store path (MemoryServer) and the background sync writer (SyncEngine) must not + contend on the single SQLite writer. Before the fix they used two separate connections + to the same file; under load the second writer blew past busy_timeout and raised + ``sqlite3.OperationalError: database is locked``. After the fix all local writes are + serialized through one shared, lock-guarded connection, so a lock error is impossible. + """ + + def test_concurrent_stores_all_rows_land(self, tmp_path): + """Many threads calling memory_store concurrently → every row lands, no lock error.""" + db_path = str(tmp_path / "concurrent.db") + server = MemoryServer(sqlite_db_path=db_path) + try: + n_threads = 16 + per_thread = 12 + errors: list[BaseException] = [] + barrier = threading.Barrier(n_threads) + + def worker(tid: int) -> None: + barrier.wait() # release all threads at once to maximise contention + for i in range(per_thread): + try: + server.memory_store({ + "content": f"thread {tid} memory {i}", + "expanded_keywords": f"thread {tid} memory {i} concurrent write", + }) + except BaseException as exc: # noqa: BLE001 — capture everything for the assert + errors.append(exc) + + threads = [threading.Thread(target=worker, args=(t,)) for t in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [], f"concurrent stores raised: {errors!r}" + assert _server_rows(server) == n_threads * per_thread + finally: + if server.sqlite_conn: + server.sqlite_conn.close() + + def test_concurrent_stores_while_sync_writer_active_no_lock(self, tmp_path): + """Store path + background sync writer hammer the SAME file → no 'database is locked'. + + Reproduces the production failure: ``MemoryServer`` and ``SyncEngine`` both write to + one SQLite file. We shrink ``busy_timeout`` so the structural race (two writers fighting + the single SQLite write lock) surfaces within seconds instead of 30s. The shared-writer + fix makes a lock error impossible regardless of busy_timeout. + """ + from claude_memory.sync import SyncEngine + + db_path = str(tmp_path / "hybrid.db") + server = MemoryServer(sqlite_db_path=db_path) + # The production hybrid wiring: the sync engine SHARES the server's LocalStore + # (one connection, one lock) — the structural fix for the cross-connection race. + engine = SyncEngine( + db_path=db_path, + api_base_url="http://fake-api:8080", + api_key="test-key", + sync_interval=3600, # never auto-syncs; we drive the writer by hand + store=server.store, + ) + assert engine._conn is server.sqlite_conn # truly one shared connection + + # Shrink the busy timeout so that, were there still two writers, the race would + # surface in ms not 30s. With one shared connection a lock error is impossible. + server.sqlite_conn.execute("PRAGMA busy_timeout=50") + + now = datetime.now(timezone.utc).isoformat() + + # A write-heavy pull: many rows upserted inside the sync engine's single lock block — + # exactly the kind of long-held writer that starved the store path. + def big_pull(method: str, path: str, body: Any = None) -> dict[str, Any]: + return { + "memories": [ + { + "id": 10_000 + j, + "content": f"server row {j}", + "category": "facts", + "tags": "", + "expanded_keywords": "", + "importance": 0.5, + "is_sensitive": False, + "created_at": now, + "updated_at": now, + "deleted_at": None, + } + for j in range(120) + ], + "server_time": now, + } + + errors: list[BaseException] = [] + stop = threading.Event() + + def sync_writer() -> None: + with patch.object(engine, "_api_request", side_effect=big_pull): + while not stop.is_set(): + try: + engine._pull_changes() + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + n_threads = 12 + per_thread = 12 + barrier = threading.Barrier(n_threads) + + def store_worker(tid: int) -> None: + barrier.wait() + for i in range(per_thread): + try: + server.memory_store({ + "content": f"local {tid}.{i}", + "expanded_keywords": f"local thread {tid} item {i} keywords here", + }) + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + writer = threading.Thread(target=sync_writer, daemon=True) + writer.start() + final_rows = 0 + try: + store_threads = [threading.Thread(target=store_worker, args=(t,)) for t in range(n_threads)] + for t in store_threads: + t.start() + for t in store_threads: + t.join() + stop.set() + writer.join(timeout=5) + final_rows = _server_rows(server) # read while the connection is still open + finally: + stop.set() + engine.stop() # shares the server's store → does not close the connection + if server.sqlite_conn: + server.sqlite_conn.close() + + locked = [e for e in errors if isinstance(e, sqlite3.OperationalError) and "locked" in str(e)] + assert locked == [], f"hit 'database is locked' under concurrency: {locked!r}" + assert errors == [], f"concurrent writes raised: {errors!r}" + # Every local store must have landed. + assert final_rows >= n_threads * per_thread + + +class TestRecallProgressAndBounding: + """The slow path — a remote semantic ``memory_recall`` — must be bounded and give a + notion of progress instead of hanging silently and dropping the session.""" + + def test_remote_recall_timeout_returns_clear_signal_not_raise(self, server): + """A timed-out / unreachable remote recall returns a clear 'retry' message, never hangs/raises.""" + import claude_memory.mcp_server as m + + with patch.object(m, "_api_request", side_effect=RuntimeError("API connection error: timed out")): + text = server._recall_remote({"context": "x", "expanded_query": "a b c d e"}) + + assert "recall" in text.lower() + # Mentions the bound and that the caller should retry — a clear working/not-done signal. + assert "retry" in text.lower() or "again" in text.lower() + assert str(m.RECALL_TIMEOUT) in text + + def test_remote_recall_success_formats_rows(self, server): + """A successful remote recall still formats results normally.""" + import claude_memory.mcp_server as m + + payload = {"memories": [ + {"id": 7, "category": "facts", "importance": 0.8, "content": "hello", + "tags": "t", "created_at": "2026-01-01T00:00:00+00:00"}, + ]} + with patch.object(m, "_api_request", return_value=payload): + text = server._recall_remote({"context": "x", "expanded_query": "a b c d e"}) + + assert "Found 1 memories" in text + assert "hello" in text + + def test_remote_recall_uses_bounded_timeout(self, server): + """The remote recall passes the bounded RECALL_TIMEOUT to the HTTP layer.""" + import claude_memory.mcp_server as m + + with patch.object(m, "_api_request", return_value={"memories": []}) as mock_api: + server._recall_remote({"context": "x", "expanded_query": "a b c d e"}) + + _, kwargs = mock_api.call_args + assert kwargs.get("timeout") == m.RECALL_TIMEOUT + + def test_api_request_accepts_timeout_kwarg(self): + """Module-level _api_request must accept an explicit timeout without breaking callers.""" + import inspect + import claude_memory.mcp_server as m + + sig = inspect.signature(m._api_request) + assert "timeout" in sig.parameters + + def test_progress_notification_emitted_for_recall_with_token(self, server): + """When the client supplies a progressToken, a 'working' progress notification is emitted.""" + sent: list[dict[str, Any]] = [] + server._notify = lambda method, params: sent.append({"method": method, "params": params}) + + with patch.object(server, "memory_recall", return_value="ok"): + server.handle_tools_call({ + "name": "memory_recall", + "arguments": {"context": "x", "expanded_query": "a b c d e"}, + "_meta": {"progressToken": "tok-1"}, + }) + + progress = [s for s in sent if s["method"] == "notifications/progress"] + assert progress, "expected a notifications/progress for a tokened recall" + assert progress[0]["params"]["progressToken"] == "tok-1" + + def test_no_progress_notification_without_token(self, server): + """No token → no progress notification (don't spam clients that didn't opt in).""" + sent: list[dict[str, Any]] = [] + server._notify = lambda method, params: sent.append({"method": method, "params": params}) + + with patch.object(server, "memory_recall", return_value="ok"): + server.handle_tools_call({ + "name": "memory_recall", + "arguments": {"context": "x", "expanded_query": "a b c d e"}, + }) + + assert [s for s in sent if s["method"] == "notifications/progress"] == [] + + def test_fast_tools_do_not_emit_progress(self, server): + """Only the slow recall path signals progress; a store with a token stays quiet.""" + sent: list[dict[str, Any]] = [] + server._notify = lambda method, params: sent.append({"method": method, "params": params}) + + server.handle_tools_call({ + "name": "memory_store", + "arguments": {"content": "x", "expanded_keywords": "a b c d e"}, + "_meta": {"progressToken": "tok-2"}, + }) + + assert [s for s in sent if s["method"] == "notifications/progress"] == []