fix(mcp): serialize local SQLite writes to end "database is locked" under concurrent stores
Under heavy concurrent memory_store (many subagents/sessions writing close
together) the local SQLite layer raced on the single SQLite writer and surfaced
sqlite3.OperationalError: database is locked, which made memory tools slow and
eventually dropped whole sessions. Two root causes:
- The MCP server (mcp_server.py) and the background SyncEngine (sync.py) each
opened a SEPARATE connection to the same SQLite file. WAL allows one writer;
when the sync writer held the lock across a resync, a concurrent store blew
past busy_timeout and raised "database is locked".
- mcp_server's connection was opened WITHOUT check_same_thread=False, so the
moment two requests were handled on different threads every local store/recall
raised ProgrammingError "SQLite objects created in a thread...".
Fix: a single process-wide serialized writer.
- New LocalStore (local_store.py) owns ONE connection (check_same_thread=False)
guarded by ONE re-entrant lock, keeps WAL, and wraps writes in
transaction()/write() with bounded exponential-backoff retry on the rare
residual lock (e.g. another OS process) instead of failing the call.
- MemoryServer builds the LocalStore and SHARES it with the SyncEngine, so the
sync writer no longer opens a second connection — the two-connections race is
eliminated. All server reads/writes go through the shared lock; stores stay
snappy (enqueue-local + async sync).
- Bound the one genuinely slow path (remote semantic memory_recall) with an
explicit RECALL_TIMEOUT and, on timeout/unreachable backend, return a clear
"working / retry" signal instead of hanging silently or crashing. When a
client supplies _meta.progressToken, emit one notifications/progress so the
call shows life.
Ships to users via the plugin (mcp/memory-mcp.json runs src/.../mcp_server.py);
no server-side/API change needed. TDD: added concurrency tests (many threads +
sync writer on one file) and recall progress/bounding tests; full gate green
(ruff + mypy strict + 185 pytest).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
68088e684e
commit
5151bbe0d5
5 changed files with 590 additions and 116 deletions
|
|
@ -2,7 +2,12 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -408,3 +413,244 @@ class TestSchemaMigration:
|
|||
columns = {row["name"] for row in cursor.fetchall()}
|
||||
assert "server_id" in columns
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
|
||||
def _server_rows(server: MemoryServer) -> int:
|
||||
assert server.sqlite_conn is not None
|
||||
return int(server.sqlite_conn.execute("SELECT COUNT(*) AS c FROM memories").fetchone()["c"])
|
||||
|
||||
|
||||
class TestConcurrentWrites:
|
||||
"""Regression tests for 'database is locked' under heavy concurrent local writes.
|
||||
|
||||
The store path (MemoryServer) and the background sync writer (SyncEngine) must not
|
||||
contend on the single SQLite writer. Before the fix they used two separate connections
|
||||
to the same file; under load the second writer blew past busy_timeout and raised
|
||||
``sqlite3.OperationalError: database is locked``. After the fix all local writes are
|
||||
serialized through one shared, lock-guarded connection, so a lock error is impossible.
|
||||
"""
|
||||
|
||||
def test_concurrent_stores_all_rows_land(self, tmp_path):
|
||||
"""Many threads calling memory_store concurrently → every row lands, no lock error."""
|
||||
db_path = str(tmp_path / "concurrent.db")
|
||||
server = MemoryServer(sqlite_db_path=db_path)
|
||||
try:
|
||||
n_threads = 16
|
||||
per_thread = 12
|
||||
errors: list[BaseException] = []
|
||||
barrier = threading.Barrier(n_threads)
|
||||
|
||||
def worker(tid: int) -> None:
|
||||
barrier.wait() # release all threads at once to maximise contention
|
||||
for i in range(per_thread):
|
||||
try:
|
||||
server.memory_store({
|
||||
"content": f"thread {tid} memory {i}",
|
||||
"expanded_keywords": f"thread {tid} memory {i} concurrent write",
|
||||
})
|
||||
except BaseException as exc: # noqa: BLE001 — capture everything for the assert
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(t,)) for t in range(n_threads)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert errors == [], f"concurrent stores raised: {errors!r}"
|
||||
assert _server_rows(server) == n_threads * per_thread
|
||||
finally:
|
||||
if server.sqlite_conn:
|
||||
server.sqlite_conn.close()
|
||||
|
||||
def test_concurrent_stores_while_sync_writer_active_no_lock(self, tmp_path):
|
||||
"""Store path + background sync writer hammer the SAME file → no 'database is locked'.
|
||||
|
||||
Reproduces the production failure: ``MemoryServer`` and ``SyncEngine`` both write to
|
||||
one SQLite file. We shrink ``busy_timeout`` so the structural race (two writers fighting
|
||||
the single SQLite write lock) surfaces within seconds instead of 30s. The shared-writer
|
||||
fix makes a lock error impossible regardless of busy_timeout.
|
||||
"""
|
||||
from claude_memory.sync import SyncEngine
|
||||
|
||||
db_path = str(tmp_path / "hybrid.db")
|
||||
server = MemoryServer(sqlite_db_path=db_path)
|
||||
# The production hybrid wiring: the sync engine SHARES the server's LocalStore
|
||||
# (one connection, one lock) — the structural fix for the cross-connection race.
|
||||
engine = SyncEngine(
|
||||
db_path=db_path,
|
||||
api_base_url="http://fake-api:8080",
|
||||
api_key="test-key",
|
||||
sync_interval=3600, # never auto-syncs; we drive the writer by hand
|
||||
store=server.store,
|
||||
)
|
||||
assert engine._conn is server.sqlite_conn # truly one shared connection
|
||||
|
||||
# Shrink the busy timeout so that, were there still two writers, the race would
|
||||
# surface in ms not 30s. With one shared connection a lock error is impossible.
|
||||
server.sqlite_conn.execute("PRAGMA busy_timeout=50")
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# A write-heavy pull: many rows upserted inside the sync engine's single lock block —
|
||||
# exactly the kind of long-held writer that starved the store path.
|
||||
def big_pull(method: str, path: str, body: Any = None) -> dict[str, Any]:
|
||||
return {
|
||||
"memories": [
|
||||
{
|
||||
"id": 10_000 + j,
|
||||
"content": f"server row {j}",
|
||||
"category": "facts",
|
||||
"tags": "",
|
||||
"expanded_keywords": "",
|
||||
"importance": 0.5,
|
||||
"is_sensitive": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": None,
|
||||
}
|
||||
for j in range(120)
|
||||
],
|
||||
"server_time": now,
|
||||
}
|
||||
|
||||
errors: list[BaseException] = []
|
||||
stop = threading.Event()
|
||||
|
||||
def sync_writer() -> None:
|
||||
with patch.object(engine, "_api_request", side_effect=big_pull):
|
||||
while not stop.is_set():
|
||||
try:
|
||||
engine._pull_changes()
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
|
||||
n_threads = 12
|
||||
per_thread = 12
|
||||
barrier = threading.Barrier(n_threads)
|
||||
|
||||
def store_worker(tid: int) -> None:
|
||||
barrier.wait()
|
||||
for i in range(per_thread):
|
||||
try:
|
||||
server.memory_store({
|
||||
"content": f"local {tid}.{i}",
|
||||
"expanded_keywords": f"local thread {tid} item {i} keywords here",
|
||||
})
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
|
||||
writer = threading.Thread(target=sync_writer, daemon=True)
|
||||
writer.start()
|
||||
final_rows = 0
|
||||
try:
|
||||
store_threads = [threading.Thread(target=store_worker, args=(t,)) for t in range(n_threads)]
|
||||
for t in store_threads:
|
||||
t.start()
|
||||
for t in store_threads:
|
||||
t.join()
|
||||
stop.set()
|
||||
writer.join(timeout=5)
|
||||
final_rows = _server_rows(server) # read while the connection is still open
|
||||
finally:
|
||||
stop.set()
|
||||
engine.stop() # shares the server's store → does not close the connection
|
||||
if server.sqlite_conn:
|
||||
server.sqlite_conn.close()
|
||||
|
||||
locked = [e for e in errors if isinstance(e, sqlite3.OperationalError) and "locked" in str(e)]
|
||||
assert locked == [], f"hit 'database is locked' under concurrency: {locked!r}"
|
||||
assert errors == [], f"concurrent writes raised: {errors!r}"
|
||||
# Every local store must have landed.
|
||||
assert final_rows >= n_threads * per_thread
|
||||
|
||||
|
||||
class TestRecallProgressAndBounding:
|
||||
"""The slow path — a remote semantic ``memory_recall`` — must be bounded and give a
|
||||
notion of progress instead of hanging silently and dropping the session."""
|
||||
|
||||
def test_remote_recall_timeout_returns_clear_signal_not_raise(self, server):
|
||||
"""A timed-out / unreachable remote recall returns a clear 'retry' message, never hangs/raises."""
|
||||
import claude_memory.mcp_server as m
|
||||
|
||||
with patch.object(m, "_api_request", side_effect=RuntimeError("API connection error: timed out")):
|
||||
text = server._recall_remote({"context": "x", "expanded_query": "a b c d e"})
|
||||
|
||||
assert "recall" in text.lower()
|
||||
# Mentions the bound and that the caller should retry — a clear working/not-done signal.
|
||||
assert "retry" in text.lower() or "again" in text.lower()
|
||||
assert str(m.RECALL_TIMEOUT) in text
|
||||
|
||||
def test_remote_recall_success_formats_rows(self, server):
|
||||
"""A successful remote recall still formats results normally."""
|
||||
import claude_memory.mcp_server as m
|
||||
|
||||
payload = {"memories": [
|
||||
{"id": 7, "category": "facts", "importance": 0.8, "content": "hello",
|
||||
"tags": "t", "created_at": "2026-01-01T00:00:00+00:00"},
|
||||
]}
|
||||
with patch.object(m, "_api_request", return_value=payload):
|
||||
text = server._recall_remote({"context": "x", "expanded_query": "a b c d e"})
|
||||
|
||||
assert "Found 1 memories" in text
|
||||
assert "hello" in text
|
||||
|
||||
def test_remote_recall_uses_bounded_timeout(self, server):
|
||||
"""The remote recall passes the bounded RECALL_TIMEOUT to the HTTP layer."""
|
||||
import claude_memory.mcp_server as m
|
||||
|
||||
with patch.object(m, "_api_request", return_value={"memories": []}) as mock_api:
|
||||
server._recall_remote({"context": "x", "expanded_query": "a b c d e"})
|
||||
|
||||
_, kwargs = mock_api.call_args
|
||||
assert kwargs.get("timeout") == m.RECALL_TIMEOUT
|
||||
|
||||
def test_api_request_accepts_timeout_kwarg(self):
|
||||
"""Module-level _api_request must accept an explicit timeout without breaking callers."""
|
||||
import inspect
|
||||
import claude_memory.mcp_server as m
|
||||
|
||||
sig = inspect.signature(m._api_request)
|
||||
assert "timeout" in sig.parameters
|
||||
|
||||
def test_progress_notification_emitted_for_recall_with_token(self, server):
|
||||
"""When the client supplies a progressToken, a 'working' progress notification is emitted."""
|
||||
sent: list[dict[str, Any]] = []
|
||||
server._notify = lambda method, params: sent.append({"method": method, "params": params})
|
||||
|
||||
with patch.object(server, "memory_recall", return_value="ok"):
|
||||
server.handle_tools_call({
|
||||
"name": "memory_recall",
|
||||
"arguments": {"context": "x", "expanded_query": "a b c d e"},
|
||||
"_meta": {"progressToken": "tok-1"},
|
||||
})
|
||||
|
||||
progress = [s for s in sent if s["method"] == "notifications/progress"]
|
||||
assert progress, "expected a notifications/progress for a tokened recall"
|
||||
assert progress[0]["params"]["progressToken"] == "tok-1"
|
||||
|
||||
def test_no_progress_notification_without_token(self, server):
|
||||
"""No token → no progress notification (don't spam clients that didn't opt in)."""
|
||||
sent: list[dict[str, Any]] = []
|
||||
server._notify = lambda method, params: sent.append({"method": method, "params": params})
|
||||
|
||||
with patch.object(server, "memory_recall", return_value="ok"):
|
||||
server.handle_tools_call({
|
||||
"name": "memory_recall",
|
||||
"arguments": {"context": "x", "expanded_query": "a b c d e"},
|
||||
})
|
||||
|
||||
assert [s for s in sent if s["method"] == "notifications/progress"] == []
|
||||
|
||||
def test_fast_tools_do_not_emit_progress(self, server):
|
||||
"""Only the slow recall path signals progress; a store with a token stays quiet."""
|
||||
sent: list[dict[str, Any]] = []
|
||||
server._notify = lambda method, params: sent.append({"method": method, "params": params})
|
||||
|
||||
server.handle_tools_call({
|
||||
"name": "memory_store",
|
||||
"arguments": {"content": "x", "expanded_keywords": "a b c d e"},
|
||||
"_meta": {"progressToken": "tok-2"},
|
||||
})
|
||||
|
||||
assert [s for s in sent if s["method"] == "notifications/progress"] == []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue