Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5151bbe0d5 | ||
|
|
68088e684e |
6 changed files with 591 additions and 116 deletions
1
.github/workflows/build.yml
vendored
1
.github/workflows/build.yml
vendored
|
|
@ -70,6 +70,7 @@ jobs:
|
|||
- uses: docker/build-push-action@v7
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64
|
||||
# Single-manifest images (no provenance/SBOM attestation children) so
|
||||
|
|
|
|||
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -48,3 +48,9 @@ docker/pgdata/
|
|||
# Beads / Dolt files (added by bd init)
|
||||
.dolt/
|
||||
.beads-credential-key
|
||||
|
||||
# Agent git worktrees (standing policy: never the shared checkout)
|
||||
.worktrees/
|
||||
|
||||
# uv lockfile — CI runs `uv sync` itself; not tracked
|
||||
uv.lock
|
||||
|
|
|
|||
116
src/claude_memory/local_store.py
Normal file
116
src/claude_memory/local_store.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Single, process-wide serialized SQLite writer for the local memory cache.
|
||||
|
||||
SQLite permits only one writer at a time. The MCP server's store path and the
|
||||
background sync engine used to open *separate* connections to the *same* file;
|
||||
under heavy concurrent ``memory_store`` calls those two writers fought over the
|
||||
single SQLite write lock, blew past ``busy_timeout``, and surfaced
|
||||
``sqlite3.OperationalError: database is locked`` — which made the tool slow and
|
||||
eventually dropped the session.
|
||||
|
||||
``LocalStore`` fixes this structurally: it owns ONE connection (opened with
|
||||
``check_same_thread=False``) guarded by ONE re-entrant lock. Every component that
|
||||
needs to touch the local DB shares the same ``LocalStore`` instance, so all
|
||||
writes serialize cleanly through the in-process lock and queue instead of racing
|
||||
the SQLite writer. On the rare residual lock (e.g. another OS process touching
|
||||
the file), writes retry with bounded exponential backoff rather than failing the
|
||||
caller. WAL stays on for concurrent reads.
|
||||
|
||||
Uses only stdlib — no pip install required.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
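# Bounded retry window for the rare residual "database is locked" — handles a
# lock held by a *different OS process* (the in-process lock already serializes
# this process's own writers). Backoff sleeps double from _BASE_BACKOFF_S, so the
# cumulative sleep is at most ≈ 0.05+0.1+0.2+0.4+0.8 ≈ 1.55s — in addition to
# however long SQLite itself blocks on each attempt under busy_timeout.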
|
||||
_MAX_RETRIES = 5
|
||||
_BASE_BACKOFF_S = 0.05
|
||||
_BUSY_TIMEOUT_MS = 30000
|
||||
|
||||
|
||||
def _is_locked_error(exc: sqlite3.OperationalError) -> bool:
|
||||
msg = str(exc).lower()
|
||||
return "database is locked" in msg or "database is busy" in msg
|
||||
|
||||
|
||||
class LocalStore:
    """Owns the single shared SQLite connection + lock for local memory writes.

    Attributes:
        conn: The one ``sqlite3.Connection`` every component shares.
        lock: A ``threading.RLock`` that must be held for any use of ``conn``.
    """

    def __init__(self, conn: sqlite3.Connection) -> None:
        """Wrap an already-open connection; the caller keeps ownership of its setup."""
        self.conn = conn
        # Re-entrant so a transaction callback may itself call ``execute``/``write``
        # without dead-locking on the same thread.
        self.lock = threading.RLock()

    # ── construction ────────────────────────────────────────────────

    @classmethod
    def open(cls, db_path: str) -> "LocalStore":
        """Open (creating parent dirs) a WAL connection safe for cross-thread use.

        Args:
            db_path: Filesystem path of the SQLite database file.

        Returns:
            A ``LocalStore`` wrapping the freshly opened connection.

        Raises:
            sqlite3.Error: If the connection cannot be opened or configured.
        """
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        # Derive the connect timeout from the same constant as the PRAGMA so the
        # two can never drift apart (30000 ms -> 30.0 s).
        conn = sqlite3.connect(
            db_path, timeout=_BUSY_TIMEOUT_MS / 1000, check_same_thread=False
        )
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute(f"PRAGMA busy_timeout={_BUSY_TIMEOUT_MS}")
        except Exception:
            # Don't leak a half-configured connection if a PRAGMA fails.
            conn.close()
            raise
        return cls(conn)

    # ── serialized access ───────────────────────────────────────────

    def _rollback_quietly(self) -> None:
        """Discard any open transaction on the shared connection; never raises."""
        try:
            self.conn.rollback()
        except sqlite3.Error:
            # Best effort: the original failure is what the caller needs to see,
            # but leave a trace instead of swallowing this silently.
            logger.debug("Rollback after a failed transaction also failed", exc_info=True)

    def transaction(self, fn: Callable[[sqlite3.Connection], T]) -> T:
        """Run ``fn(conn)`` holding the shared lock, with bounded retry on lock errors.

        ``fn`` is responsible for issuing its own ``COMMIT`` (call ``conn.commit()``)
        when it performs writes. The whole callback runs under the process-wide lock,
        so concurrent callers queue rather than collide on the SQLite writer.

        If ``fn`` raises, any uncommitted work on the shared connection is rolled
        back before the exception propagates. Without that, the next caller's
        ``commit()`` on this same connection would persist the failed callback's
        partial writes. Because the connection is shared, a failure inside a nested
        ``transaction`` call also discards the enclosing callback's uncommitted work.

        Args:
            fn: Callback receiving the shared connection. It may run more than
                once (after a rollback) when SQLite reports the database locked.

        Returns:
            Whatever ``fn`` returns.

        Raises:
            sqlite3.OperationalError: A non-lock error from ``fn``, or the lock
                error from the final attempt once retries are exhausted.
            Exception: Anything else ``fn`` raises, re-raised after rollback.
        """
        attempts = max(1, _MAX_RETRIES)  # always run ``fn`` at least once
        attempt = 0
        while True:
            attempt += 1
            with self.lock:
                try:
                    return fn(self.conn)
                except sqlite3.OperationalError as exc:
                    # Roll back any partial txn so the retry starts clean and the
                    # connection isn't left mid-transaction holding locks.
                    self._rollback_quietly()
                    if attempt >= attempts or not _is_locked_error(exc):
                        # Out of retries (no point sleeping first) or not a lock
                        # problem at all: surface the error to the caller.
                        raise
                except Exception:
                    self._rollback_quietly()
                    raise
            # Back off *outside* the lock so other writers can make progress.
            backoff = _BASE_BACKOFF_S * (2 ** (attempt - 1))
            logger.warning(
                "SQLite locked (attempt %d/%d) — backing off %.3fs", attempt, _MAX_RETRIES, backoff
            )
            time.sleep(backoff)

    def execute(self, sql: str, params: tuple[Any, ...] = ()) -> sqlite3.Cursor:
        """Run a single read query under the shared lock (no implicit commit).

        The lock covers only the ``execute`` call itself; the returned cursor is
        consumed by the caller after the lock has been released.

        Args:
            sql: SQL text with ``?`` placeholders.
            params: Values bound to the placeholders.

        Returns:
            The cursor produced by ``conn.execute``.
        """
        with self.lock:
            return self.conn.execute(sql, params)

    def write(self, sql: str, params: tuple[Any, ...] = ()) -> sqlite3.Cursor:
        """Run a single write statement + commit, serialized and retry-guarded.

        Args:
            sql: SQL text with ``?`` placeholders.
            params: Values bound to the placeholders.

        Returns:
            The cursor of the executed statement (e.g. for ``lastrowid``).

        Raises:
            sqlite3.OperationalError: See ``transaction``.
        """

        def _do(conn: sqlite3.Connection) -> sqlite3.Cursor:
            cur = conn.execute(sql, params)
            conn.commit()
            return cur

        return self.transaction(_do)

    def close(self) -> None:
        """Close the shared connection once no other thread holds the lock."""
        with self.lock:
            self.conn.close()
|
||||
|
|
@ -17,7 +17,10 @@ import sqlite3
|
|||
import sys
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from claude_memory.local_store import LocalStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,9 +38,17 @@ HYBRID_MODE = bool(API_KEY) and not SYNC_DISABLED
|
|||
HTTP_ONLY = bool(API_KEY) and SYNC_DISABLED
|
||||
SQLITE_ONLY = not API_KEY
|
||||
|
||||
# Default per-request HTTP timeout, and a wider bound for the one genuinely slow path:
|
||||
# a remote semantic ``memory_recall`` (embedding/search can be slow to warm up). Both are
|
||||
# explicit upper bounds so a call can never hang the MCP server indefinitely.
|
||||
DEFAULT_API_TIMEOUT = 15
|
||||
RECALL_TIMEOUT = int(os.environ.get("MEMORY_RECALL_TIMEOUT", "30"))
|
||||
|
||||
def _api_request(method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Make an HTTP request to the memory API."""
|
||||
|
||||
def _api_request(
|
||||
method: str, path: str, body: dict[str, Any] | None = None, timeout: int = DEFAULT_API_TIMEOUT
|
||||
) -> dict[str, Any]:
|
||||
"""Make an HTTP request to the memory API (bounded by ``timeout`` seconds)."""
|
||||
url = f"{API_BASE_URL}{path}"
|
||||
data = json.dumps(body).encode() if body else None
|
||||
req = urllib.request.Request(
|
||||
|
|
@ -50,7 +61,7 @@ def _api_request(method: str, path: str, body: dict[str, Any] | None = None) ->
|
|||
},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
result: dict[str, Any] = json.loads(resp.read().decode())
|
||||
return result
|
||||
except urllib.error.HTTPError as e:
|
||||
|
|
@ -128,7 +139,9 @@ def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
|
|||
db_path = _get_db_path(db_path)
|
||||
Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(db_path, timeout=30.0)
|
||||
# check_same_thread=False: the MCP server may handle requests on different
|
||||
# threads and shares this one connection via LocalStore's lock (see local_store.py).
|
||||
conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=30000")
|
||||
|
|
@ -390,19 +403,30 @@ class MemoryServer:
|
|||
|
||||
def __init__(self, sqlite_db_path: str | None = None) -> None:
|
||||
self.sqlite_conn: sqlite3.Connection | None = None
|
||||
self.store: "LocalStore | None" = None # single serialized writer (see local_store.py)
|
||||
self.sync_engine: Any = None
|
||||
# Sink for MCP notifications (e.g. progress). Defaults to writing a JSON-RPC
|
||||
# notification line to stdout; overridable in tests.
|
||||
self._notify: Callable[[str, dict[str, Any]], None] = self._emit_notification
|
||||
|
||||
if SQLITE_ONLY or HYBRID_MODE:
|
||||
self.sqlite_conn, resolved_path = _init_sqlite(sqlite_db_path)
|
||||
conn, resolved_path = _init_sqlite(sqlite_db_path)
|
||||
from claude_memory.local_store import LocalStore
|
||||
self.store = LocalStore(conn)
|
||||
self.sqlite_conn = conn
|
||||
|
||||
if HYBRID_MODE:
|
||||
from claude_memory.sync import SyncEngine
|
||||
sync_interval = int(os.environ.get("MEMORY_SYNC_INTERVAL", "60"))
|
||||
# Share the SAME LocalStore (one connection, one lock) so the background
|
||||
# sync writer never opens a second connection to the file and never races
|
||||
# the store path on the single SQLite writer.
|
||||
self.sync_engine = SyncEngine(
|
||||
db_path=resolved_path,
|
||||
api_base_url=API_BASE_URL,
|
||||
api_key=API_KEY,
|
||||
sync_interval=sync_interval,
|
||||
store=self.store,
|
||||
)
|
||||
self.sync_engine.start()
|
||||
|
||||
|
|
@ -455,31 +479,59 @@ class MemoryServer:
|
|||
limit = args.get("limit", 10)
|
||||
|
||||
if HTTP_ONLY:
|
||||
result = _api_request("POST", "/api/memories/recall", {
|
||||
"context": context,
|
||||
"expanded_query": expanded_query,
|
||||
"category": category,
|
||||
"sort_by": sort_by,
|
||||
"limit": limit,
|
||||
})
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
filter_desc = f" in category '{category}'" if category else ""
|
||||
return f"No memories found matching: {context}{filter_desc}"
|
||||
|
||||
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
|
||||
filter_desc = f" in '{category}'" if category else ""
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
||||
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
|
||||
return self._recall_remote(args)
|
||||
|
||||
# SQLite-only or Hybrid: always read from local cache
|
||||
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
|
||||
|
||||
def _recall_remote(self, args: dict[str, Any]) -> str:
    """Run a recall against the remote API and format the result as text.

    The request is capped at ``RECALL_TIMEOUT`` seconds. A ``RuntimeError`` from
    ``_api_request`` is turned into a "try again" message instead of propagating,
    so the caller always gets a string back.
    """
    context = args.get("context")
    category = args.get("category")
    sort_by = args.get("sort_by", "importance")
    payload = {
        "context": context,
        "expanded_query": args.get("expanded_query", ""),
        "category": category,
        "sort_by": sort_by,
        "limit": args.get("limit", 10),
    }

    try:
        response = _api_request(
            "POST", "/api/memories/recall", payload, timeout=RECALL_TIMEOUT
        )
    except RuntimeError as e:
        # _api_request reports connection errors / timeouts as RuntimeError; answer
        # with an actionable message rather than a stack trace.
        return (
            f"Memory recall did not complete within {RECALL_TIMEOUT}s — the semantic "
            f"search backend is slow or unreachable ({e}). Please try again."
        )

    memories = response.get("memories", [])
    if not memories:
        scope = f" in category '{category}'" if category else ""
        return f"No memories found matching: {context}{scope}"

    ordering = "by importance" if sort_by != "relevance" else "by relevance"
    scope = f" in '{category}'" if category else ""
    entries = [
        f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
        f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
        for row in memories
    ]
    return f"Found {len(memories)} memories{scope} ({ordering}):\n\n" + "\n\n".join(entries)
|
||||
|
||||
def memory_list(self, args: dict[str, Any]) -> str:
|
||||
category = args.get("category")
|
||||
limit = args.get("limit", 20)
|
||||
|
|
@ -519,10 +571,11 @@ class MemoryServer:
|
|||
# SQLite-only or Hybrid: delete from local SQLite
|
||||
# In hybrid mode, also try to sync delete to server
|
||||
server_id: int | None = None
|
||||
if HYBRID_MODE and self.sync_engine and self.sqlite_conn:
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
|
||||
row = cursor.fetchone()
|
||||
if HYBRID_MODE and self.sync_engine and self.store:
|
||||
with self.store.lock:
|
||||
cursor = self.store.conn.cursor()
|
||||
cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
|
||||
row = cursor.fetchone()
|
||||
server_id = row["server_id"] if row and row["server_id"] else None
|
||||
|
||||
result_text = self._sqlite_delete(memory_id)
|
||||
|
|
@ -563,12 +616,13 @@ class MemoryServer:
|
|||
lines.append(f"Last sync success: {counts['last_sync_success']}")
|
||||
return "\n".join(lines)
|
||||
|
||||
if self.sqlite_conn:
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as c FROM memories")
|
||||
total = cursor.fetchone()["c"]
|
||||
cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC")
|
||||
by_cat = cursor.fetchall()
|
||||
if self.store:
|
||||
with self.store.lock:
|
||||
cursor = self.store.conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) as c FROM memories")
|
||||
total = cursor.fetchone()["c"]
|
||||
cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC")
|
||||
by_cat = cursor.fetchall()
|
||||
lines = [f"Local memories (SQLite-only): {total}"]
|
||||
for row in by_cat:
|
||||
lines.append(f" {row['category']}: {row['c']}")
|
||||
|
|
@ -682,19 +736,25 @@ class MemoryServer:
|
|||
def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
assert self.sqlite_conn is not None
|
||||
assert self.store is not None
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
is_sensitive = 1 if force_sensitive else 0
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(content, category, tags, expanded_keywords, importance, is_sensitive, now, now),
|
||||
)
|
||||
self.sqlite_conn.commit()
|
||||
return f"Stored memory #{cursor.lastrowid} in category '{category}' with importance {importance:.1f}"
|
||||
|
||||
def _insert(conn: sqlite3.Connection) -> int | None:
|
||||
cursor = conn.execute(
|
||||
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(content, category, tags, expanded_keywords, importance, is_sensitive, now, now),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
# Serialized + retry-guarded through the shared LocalStore so concurrent
|
||||
# stores (and the background sync writer) never collide on the SQLite writer.
|
||||
new_id = self.store.transaction(_insert)
|
||||
return f"Stored memory #{new_id} in category '{category}' with importance {importance:.1f}"
|
||||
|
||||
def _sqlite_recall(self, context: str, expanded_query: str, category: str | None, sort_by: str, limit: int) -> str:
|
||||
assert self.sqlite_conn is not None
|
||||
assert self.store is not None
|
||||
all_terms = f"{context} {expanded_query}".strip()
|
||||
words = [w.replace(chr(34), "") for w in all_terms.split() if w]
|
||||
and_query = " AND ".join(f'"{w}"' for w in words)
|
||||
|
|
@ -712,35 +772,38 @@ class MemoryServer:
|
|||
"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
|
||||
"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
|
||||
)
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
rows: list[Any] = []
|
||||
try:
|
||||
# Try AND first for precise matches, fall back to OR for broader results
|
||||
cat_filter = "AND m.category = ?" if category else ""
|
||||
for fts_query in (and_query, or_query):
|
||||
params = [fts_query, category, limit] if category else [fts_query, limit]
|
||||
cursor.execute(
|
||||
f"{base_select}WHERE memories_fts MATCH ? {cat_filter} ORDER BY {order} LIMIT ?",
|
||||
tuple(p for p in params if p is not None),
|
||||
)
|
||||
# Hold the shared lock for the whole read — the connection is shared across
|
||||
# threads with the sync writer, so reads must serialize too.
|
||||
with self.store.lock:
|
||||
cursor = self.store.conn.cursor()
|
||||
try:
|
||||
# Try AND first for precise matches, fall back to OR for broader results
|
||||
cat_filter = "AND m.category = ?" if category else ""
|
||||
for fts_query in (and_query, or_query):
|
||||
params = [fts_query, category, limit] if category else [fts_query, limit]
|
||||
cursor.execute(
|
||||
f"{base_select}WHERE memories_fts MATCH ? {cat_filter} ORDER BY {order} LIMIT ?",
|
||||
tuple(p for p in params if p is not None),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
if rows:
|
||||
break
|
||||
except sqlite3.OperationalError:
|
||||
like = f"%{context}%"
|
||||
if category:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, limit),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
if rows:
|
||||
break
|
||||
except sqlite3.OperationalError:
|
||||
like = f"%{context}%"
|
||||
if category:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, limit),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
return f"No memories found matching: {context}"
|
||||
|
|
@ -757,21 +820,22 @@ class MemoryServer:
|
|||
)
|
||||
|
||||
def _sqlite_list(self, category: str | None, limit: int) -> str:
|
||||
assert self.sqlite_conn is not None
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
if category:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE category = ? ORDER BY created_at DESC LIMIT ?",
|
||||
(category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"ORDER BY created_at DESC LIMIT ?",
|
||||
(limit,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
assert self.store is not None
|
||||
with self.store.lock:
|
||||
cursor = self.store.conn.cursor()
|
||||
if category:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE category = ? ORDER BY created_at DESC LIMIT ?",
|
||||
(category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"ORDER BY created_at DESC LIMIT ?",
|
||||
(limit,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
if not rows:
|
||||
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
||||
|
||||
|
|
@ -785,25 +849,30 @@ class MemoryServer:
|
|||
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
||||
|
||||
def _sqlite_delete(self, memory_id: int) -> str:
|
||||
assert self.sqlite_conn is not None
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return f"Memory #{memory_id} not found"
|
||||
preview = row["content"][:50]
|
||||
cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
|
||||
self.sqlite_conn.commit()
|
||||
return f"Deleted memory #{memory_id}: {preview}..."
|
||||
assert self.store is not None
|
||||
|
||||
def _delete(conn: sqlite3.Connection) -> str:
|
||||
cursor = conn.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return f"Memory #{memory_id} not found"
|
||||
preview = row["content"][:50]
|
||||
conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
|
||||
conn.commit()
|
||||
return f"Deleted memory #{memory_id}: {preview}..."
|
||||
|
||||
# SELECT + DELETE + commit as one serialized, retry-guarded unit.
|
||||
return self.store.transaction(_delete)
|
||||
|
||||
def _sqlite_secret_get(self, memory_id: int) -> str:
|
||||
assert self.sqlite_conn is not None
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, is_sensitive FROM memories WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
assert self.store is not None
|
||||
with self.store.lock:
|
||||
cursor = self.store.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, is_sensitive FROM memories WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return f"Memory #{memory_id} not found"
|
||||
if not row["is_sensitive"]:
|
||||
|
|
@ -825,9 +894,25 @@ class MemoryServer:
|
|||
tools.extend(SHARING_TOOLS)
|
||||
return {"tools": tools}
|
||||
|
||||
# Tools whose work is genuinely slow enough to warrant a progress signal.
|
||||
_SLOW_TOOLS = frozenset({"memory_recall"})
|
||||
|
||||
def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
tool_name: str = params.get("name", "")
|
||||
arguments: dict[str, Any] = params.get("arguments", {})
|
||||
|
||||
# If the client opted into progress (sent a token) and this is a slow tool, emit a
|
||||
# single "working" progress notification so the call shows life instead of looking hung.
|
||||
progress_token = (params.get("_meta") or {}).get("progressToken")
|
||||
if progress_token is not None and tool_name in self._SLOW_TOOLS:
|
||||
try:
|
||||
self._notify(
|
||||
"notifications/progress",
|
||||
{"progressToken": progress_token, "progress": 0, "message": f"Running {tool_name}…"},
|
||||
)
|
||||
except Exception:
|
||||
pass # never let progress reporting break the actual call
|
||||
|
||||
try:
|
||||
handler = {
|
||||
"memory_store": self.memory_store,
|
||||
|
|
@ -851,6 +936,10 @@ class MemoryServer:
|
|||
except Exception as e:
|
||||
return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True}
|
||||
|
||||
def _emit_notification(self, method: str, params: dict[str, Any]) -> None:
|
||||
"""Default notification sink: write a JSON-RPC notification line to stdout."""
|
||||
print(json.dumps({"jsonrpc": "2.0", "method": method, "params": params}), flush=True)
|
||||
|
||||
def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
|
|
|
|||
|
|
@ -5,14 +5,14 @@ Uses only stdlib — no pip install required.
|
|||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from claude_memory.local_store import LocalStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,7 +26,14 @@ FULL_RESYNC_EVERY = 10
|
|||
class SyncEngine:
|
||||
"""Background sync between local SQLite cache and remote API."""
|
||||
|
||||
def __init__(self, db_path: str, api_base_url: str, api_key: str, sync_interval: int = 60):
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str,
|
||||
api_base_url: str,
|
||||
api_key: str,
|
||||
sync_interval: int = 60,
|
||||
store: "LocalStore | None" = None,
|
||||
):
|
||||
self.db_path = db_path
|
||||
self.api_base_url = api_base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
|
|
@ -37,13 +44,20 @@ class SyncEngine:
|
|||
self._last_sync_success = False
|
||||
self._auth_failed = False
|
||||
|
||||
# Own connection for thread safety
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA busy_timeout=30000")
|
||||
self._lock = threading.Lock()
|
||||
# Share the caller's LocalStore (one connection, one lock) when given, so the
|
||||
# background sync writer never opens a SECOND connection to the same file and
|
||||
# never races the store path on the single SQLite writer. When run standalone
|
||||
# (e.g. tests), open our own store. Either way, all SQLite access below goes
|
||||
# through self._store / self._conn / self._lock, which now point at one shared
|
||||
# connection guarded by one re-entrant lock.
|
||||
if store is None:
|
||||
self._store = LocalStore.open(db_path)
|
||||
self._owns_store = True
|
||||
else:
|
||||
self._store = store
|
||||
self._owns_store = False
|
||||
self._conn = self._store.conn
|
||||
self._lock = self._store.lock
|
||||
|
||||
self._init_sync_tables()
|
||||
|
||||
|
|
@ -121,7 +135,10 @@ class SyncEngine:
|
|||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=5)
|
||||
self._conn.close()
|
||||
# Only close the connection if we opened it; when sharing the server's
|
||||
# LocalStore, the server owns the connection lifecycle.
|
||||
if self._owns_store:
|
||||
self._store.close()
|
||||
|
||||
def _sync_loop(self) -> None:
|
||||
"""Periodic sync loop running in background thread."""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,12 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -408,3 +413,244 @@ class TestSchemaMigration:
|
|||
columns = {row["name"] for row in cursor.fetchall()}
|
||||
assert "server_id" in columns
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
|
||||
def _server_rows(server: MemoryServer) -> int:
    """Return the number of rows in the server's local ``memories`` table."""
    conn = server.sqlite_conn
    assert conn is not None
    count_row = conn.execute("SELECT COUNT(*) AS c FROM memories").fetchone()
    return int(count_row["c"])
|
||||
|
||||
|
||||
class TestConcurrentWrites:
|
||||
"""Regression tests for 'database is locked' under heavy concurrent local writes.
|
||||
|
||||
The store path (MemoryServer) and the background sync writer (SyncEngine) must not
|
||||
contend on the single SQLite writer. Before the fix they used two separate connections
|
||||
to the same file; under load the second writer blew past busy_timeout and raised
|
||||
``sqlite3.OperationalError: database is locked``. After the fix all local writes are
|
||||
serialized through one shared, lock-guarded connection, so a lock error is impossible.
|
||||
"""
|
||||
|
||||
def test_concurrent_stores_all_rows_land(self, tmp_path):
    """Concurrent ``memory_store`` calls from many threads all persist, with no errors."""
    n_threads = 16
    per_thread = 12
    server = MemoryServer(sqlite_db_path=str(tmp_path / "concurrent.db"))
    try:
        failures: list[BaseException] = []
        start_gate = threading.Barrier(n_threads)

        def worker(tid: int) -> None:
            # Hold every thread at the gate so they all start writing together.
            start_gate.wait()
            for i in range(per_thread):
                payload = {
                    "content": f"thread {tid} memory {i}",
                    "expanded_keywords": f"thread {tid} memory {i} concurrent write",
                }
                try:
                    server.memory_store(payload)
                except BaseException as exc:  # noqa: BLE001 — collected for the assert below
                    failures.append(exc)

        pool = [threading.Thread(target=worker, args=(tid,)) for tid in range(n_threads)]
        for thread in pool:
            thread.start()
        for thread in pool:
            thread.join()

        assert failures == [], f"concurrent stores raised: {failures!r}"
        assert _server_rows(server) == n_threads * per_thread
    finally:
        if server.sqlite_conn:
            server.sqlite_conn.close()
|
||||
|
||||
def test_concurrent_stores_while_sync_writer_active_no_lock(self, tmp_path):
    """Store path + background sync writer hammer the SAME file → no 'database is locked'.

    Reproduces the production failure: ``MemoryServer`` and ``SyncEngine`` both write to
    one SQLite file. We shrink ``busy_timeout`` so the structural race (two writers fighting
    the single SQLite write lock) surfaces within seconds instead of 30s. The shared-writer
    fix makes a lock error impossible regardless of busy_timeout.

    Layout of the test: one daemon thread loops ``engine._pull_changes()`` against a fake
    API response of 120 rows while 12 store threads each call ``server.memory_store`` 12
    times. Any exception from either side is collected in ``errors`` and fails the test.
    """
    from claude_memory.sync import SyncEngine

    db_path = str(tmp_path / "hybrid.db")
    server = MemoryServer(sqlite_db_path=db_path)
    # The production hybrid wiring: the sync engine SHARES the server's LocalStore
    # (one connection, one lock) — the structural fix for the cross-connection race.
    engine = SyncEngine(
        db_path=db_path,
        api_base_url="http://fake-api:8080",
        api_key="test-key",
        sync_interval=3600,  # never auto-syncs; we drive the writer by hand
        store=server.store,
    )
    assert engine._conn is server.sqlite_conn  # truly one shared connection

    # Shrink the busy timeout so that, were there still two writers, the race would
    # surface in ms not 30s. With one shared connection a lock error is impossible.
    server.sqlite_conn.execute("PRAGMA busy_timeout=50")

    now = datetime.now(timezone.utc).isoformat()

    # A write-heavy pull: many rows upserted inside the sync engine's single lock block —
    # exactly the kind of long-held writer that starved the store path.
    def big_pull(method: str, path: str, body: Any = None) -> dict[str, Any]:
        return {
            "memories": [
                {
                    "id": 10_000 + j,
                    "content": f"server row {j}",
                    "category": "facts",
                    "tags": "",
                    "expanded_keywords": "",
                    "importance": 0.5,
                    "is_sensitive": False,
                    "created_at": now,
                    "updated_at": now,
                    "deleted_at": None,
                }
                for j in range(120)
            ],
            "server_time": now,
        }

    errors: list[BaseException] = []
    stop = threading.Event()

    def sync_writer() -> None:
        with patch.object(engine, "_api_request", side_effect=big_pull):
            while not stop.is_set():
                try:
                    engine._pull_changes()
                except BaseException as exc:  # noqa: BLE001
                    errors.append(exc)
                    # A single recorded failure already fails the test; bail out
                    # instead of spinning and flooding ``errors`` with repeats.
                    return

    n_threads = 12
    per_thread = 12
    # Release all store threads at once to maximise write contention.
    barrier = threading.Barrier(n_threads)

    def store_worker(tid: int) -> None:
        barrier.wait()
        for i in range(per_thread):
            try:
                server.memory_store({
                    "content": f"local {tid}.{i}",
                    "expanded_keywords": f"local thread {tid} item {i} keywords here",
                })
            except BaseException as exc:  # noqa: BLE001
                errors.append(exc)

    writer = threading.Thread(target=sync_writer, daemon=True)
    writer.start()
    final_rows = 0
    try:
        store_threads = [threading.Thread(target=store_worker, args=(t,)) for t in range(n_threads)]
        for t in store_threads:
            t.start()
        for t in store_threads:
            t.join()
        stop.set()
        writer.join(timeout=5)
        # ``join(timeout=...)`` returns silently on timeout; without this check a hung
        # writer would go unnoticed and still be running when the connection is closed.
        assert not writer.is_alive(), "sync writer did not stop within 5s"
        final_rows = _server_rows(server)  # read while the connection is still open
    finally:
        stop.set()
        engine.stop()  # shares the server's store → does not close the connection
        if server.sqlite_conn:
            server.sqlite_conn.close()

    locked = [e for e in errors if isinstance(e, sqlite3.OperationalError) and "locked" in str(e)]
    assert locked == [], f"hit 'database is locked' under concurrency: {locked!r}"
    assert errors == [], f"concurrent writes raised: {errors!r}"
    # Every local store must have landed.
    # NOTE(review): if ``_server_rows`` also counts the 120 pulled rows, this lower bound
    # can pass with local stores missing — confirm and tighten against the helper.
    assert final_rows >= n_threads * per_thread
|
||||
|
||||
|
||||
class TestRecallProgressAndBounding:
    """Remote semantic ``memory_recall`` is the slow path: it has to be time-bounded and
    signal progress rather than hang silently and drop the session."""

    @staticmethod
    def _recall_args() -> dict[str, Any]:
        """Return a fresh copy of the recall arguments every test here uses."""
        return {"context": "x", "expanded_query": "a b c d e"}

    @staticmethod
    def _record_notifications(server) -> list[dict[str, Any]]:
        """Replace ``server._notify`` with a recorder and return the list it fills."""
        recorded: list[dict[str, Any]] = []

        def recorder(method, params):
            recorded.append({"method": method, "params": params})

        server._notify = recorder
        return recorded

    @staticmethod
    def _progress_only(recorded: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Keep only the ``notifications/progress`` entries of *recorded*."""
        return [entry for entry in recorded if entry["method"] == "notifications/progress"]

    def test_remote_recall_timeout_returns_clear_signal_not_raise(self, server):
        """An API error on remote recall comes back as a 'retry' message, not an exception."""
        import claude_memory.mcp_server as m

        failure = RuntimeError("API connection error: timed out")
        with patch.object(m, "_api_request", side_effect=failure):
            text = server._recall_remote(self._recall_args())

        lowered = text.lower()
        assert "recall" in lowered
        # The message names the bound and tells the caller to try again.
        assert any(hint in lowered for hint in ("retry", "again"))
        assert str(m.RECALL_TIMEOUT) in text

    def test_remote_recall_success_formats_rows(self, server):
        """A successful remote recall is still rendered in the normal result format."""
        import claude_memory.mcp_server as m

        row = {
            "id": 7,
            "category": "facts",
            "importance": 0.8,
            "content": "hello",
            "tags": "t",
            "created_at": "2026-01-01T00:00:00+00:00",
        }
        with patch.object(m, "_api_request", return_value={"memories": [row]}):
            text = server._recall_remote(self._recall_args())

        assert "Found 1 memories" in text
        assert "hello" in text

    def test_remote_recall_uses_bounded_timeout(self, server):
        """Remote recall hands RECALL_TIMEOUT to the HTTP layer as ``timeout``."""
        import claude_memory.mcp_server as m

        with patch.object(m, "_api_request", return_value={"memories": []}) as mock_api:
            server._recall_remote(self._recall_args())

        _positional, call_kwargs = mock_api.call_args
        assert call_kwargs.get("timeout") == m.RECALL_TIMEOUT

    def test_api_request_accepts_timeout_kwarg(self):
        """The module-level ``_api_request`` exposes a ``timeout`` parameter."""
        import inspect

        import claude_memory.mcp_server as m

        assert "timeout" in inspect.signature(m._api_request).parameters

    def test_progress_notification_emitted_for_recall_with_token(self, server):
        """A recall carrying a progressToken produces a progress notification for it."""
        recorded = self._record_notifications(server)

        request = {
            "name": "memory_recall",
            "arguments": self._recall_args(),
            "_meta": {"progressToken": "tok-1"},
        }
        with patch.object(server, "memory_recall", return_value="ok"):
            server.handle_tools_call(request)

        progress = self._progress_only(recorded)
        assert progress, "expected a notifications/progress for a tokened recall"
        assert progress[0]["params"]["progressToken"] == "tok-1"

    def test_no_progress_notification_without_token(self, server):
        """Without a token no progress notification is sent (the client did not opt in)."""
        recorded = self._record_notifications(server)

        request = {
            "name": "memory_recall",
            "arguments": self._recall_args(),
        }
        with patch.object(server, "memory_recall", return_value="ok"):
            server.handle_tools_call(request)

        assert self._progress_only(recorded) == []

    def test_fast_tools_do_not_emit_progress(self, server):
        """A ``memory_store`` call stays quiet even when it carries a progressToken."""
        recorded = self._record_notifications(server)

        request = {
            "name": "memory_store",
            "arguments": {"content": "x", "expanded_keywords": "a b c d e"},
            "_meta": {"progressToken": "tok-2"},
        }
        server.handle_tools_call(request)

        assert self._progress_only(recorded) == []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue