Under heavy concurrent memory_store traffic (many subagents/sessions writing at
nearly the same time), the local SQLite layer contended for SQLite's single
write lock and surfaced sqlite3.OperationalError: database is locked, which made
memory tools slow and eventually dropped whole sessions. Two root causes:
- The MCP server (mcp_server.py) and the background SyncEngine (sync.py) each
opened a SEPARATE connection to the same SQLite file. WAL mode still allows only
one writer at a time; when the sync writer held the write lock across a resync,
a concurrent store waited past busy_timeout and raised "database is locked".
- mcp_server's connection was opened WITHOUT check_same_thread=False, so as soon
as a request was handled on a thread other than the one that opened the
connection, every local store/recall raised ProgrammingError ("SQLite objects
created in a thread can only be used in that same thread...").
Fix: a single process-wide serialized writer.
- New LocalStore (local_store.py) owns ONE connection (check_same_thread=False)
guarded by ONE re-entrant lock, keeps WAL, and wraps writes in
transaction()/write() with bounded exponential-backoff retry on the rare
residual lock (e.g. another OS process) instead of failing the call.
- MemoryServer builds the LocalStore and SHARES it with the SyncEngine, so the
sync writer no longer opens a second connection and the two-connection race is
eliminated. All server reads/writes go through the shared lock; stores stay
fast (write locally first, then sync asynchronously).
- Bound the one genuinely slow path (remote semantic memory_recall) with an
explicit RECALL_TIMEOUT and, on timeout/unreachable backend, return a clear
"working / retry" signal instead of hanging silently or crashing. When a
client supplies _meta.progressToken, emit one notifications/progress so the
call shows life.
Ships to users via the plugin (mcp/memory-mcp.json runs src/.../mcp_server.py);
no server-side/API change is needed. TDD: added concurrency tests (many threads
+ sync writer on one file) and recall progress/bounding tests; the full gate is
green (ruff + mypy strict + 185 pytest tests).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
656 lines
25 KiB
Python
656 lines
25 KiB
Python
"""Tests for the Claude Memory MCP server."""
|
|
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
import sys
|
|
import threading
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
# Force SQLite fallback mode for all tests: drop any API key inherited from the
# caller's environment before the server module is imported further below.
for _api_key_var in ("MEMORY_API_KEY", "CLAUDE_MEMORY_API_KEY"):
    os.environ.pop(_api_key_var, None)
del _api_key_var

# Add src to path so we can import without installing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from claude_memory.mcp_server import MemoryServer, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
|
|
|
|
|
|
@pytest.fixture
def server(tmp_path):
    """Yield a MemoryServer backed by a SQLite file inside this test's ``tmp_path``.

    ``tmp_path`` is unique per test, so every test starts from an empty database.
    The server's SQLite connection, if any, is closed on teardown.
    """
    instance = MemoryServer(sqlite_db_path=str(tmp_path / "test_memory.db"))
    yield instance
    conn = instance.sqlite_conn
    if conn:
        conn.close()
|
|
|
|
|
|
class TestSQLiteInit:
    """Constructing a MemoryServer creates the SQLite file and its tables."""

    def test_creates_database(self, tmp_path):
        # The "sub" directory does not exist yet in the fresh tmp_path, so this also
        # checks that the server creates missing parent directories.
        db_path = str(tmp_path / "sub" / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            assert os.path.exists(db_path)
            # Verify tables exist
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'")
            assert cursor.fetchone() is not None
        finally:
            # Close even when an assertion fails so the connection is not leaked.
            srv.sqlite_conn.close()

    def test_creates_fts_table(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'")
            assert cursor.fetchone() is not None
        finally:
            # Close even when an assertion fails so the connection is not leaked.
            srv.sqlite_conn.close()
|
|
|
|
|
|
class TestMemoryStore:
    """memory_store: confirmation text, optional fields, validation, sensitive flag."""

    def test_store_basic(self, server):
        reply = server.memory_store(
            {"content": "User prefers dark mode", "expanded_keywords": "dark mode theme preference ui"}
        )
        # First row in a fresh DB is #1; the reply names the "facts" category even
        # though none was passed (presumably the default category).
        assert "Stored memory #1" in reply
        assert "facts" in reply

    def test_store_with_category(self, server):
        reply = server.memory_store(
            {
                "content": "User likes Python",
                "category": "preferences",
                "expanded_keywords": "python programming language preference",
            }
        )
        assert "preferences" in reply

    def test_store_with_importance(self, server):
        reply = server.memory_store(
            {"content": "Critical info", "importance": 0.9, "expanded_keywords": "critical important info"}
        )
        assert "0.9" in reply

    def test_store_requires_content(self, server):
        with pytest.raises(ValueError, match="content is required"):
            server.memory_store({"expanded_keywords": "test"})

    def test_store_force_sensitive(self, server):
        reply = server.memory_store(
            {
                "content": "API key: sk-1234",
                "force_sensitive": True,
                "expanded_keywords": "api key secret credential",
            }
        )
        assert "Stored memory #1" in reply
        # The is_sensitive column must be set for a force_sensitive store.
        flag_row = server.sqlite_conn.execute("SELECT is_sensitive FROM memories WHERE id = 1").fetchone()
        assert flag_row["is_sensitive"] == 1
|
|
|
|
|
|
class TestMemoryRecall:
    """memory_recall against the local SQLite store: hits, misses, filters, validation."""

    def test_recall_finds_memory(self, server):
        server.memory_store(
            {"content": "User works at Acme Corp", "expanded_keywords": "acme corp company work employer"}
        )
        found = server.memory_recall({"context": "work", "expanded_query": "company employer job"})
        assert "Acme Corp" in found
        assert "Found 1 memories" in found

    def test_recall_no_results(self, server):
        found = server.memory_recall({"context": "nonexistent topic", "expanded_query": "nothing here at all"})
        assert "No memories found" in found

    def test_recall_with_category_filter(self, server):
        seeds = (
            {
                "content": "User prefers vim",
                "category": "preferences",
                "expanded_keywords": "vim editor preference text",
            },
            {
                "content": "Project uses React",
                "category": "projects",
                "expanded_keywords": "react project frontend framework",
            },
        )
        for seed in seeds:
            server.memory_store(seed)

        # Only the "preferences" row may come back when filtering on that category.
        found = server.memory_recall(
            {"context": "preferences", "expanded_query": "vim editor", "category": "preferences"}
        )
        assert "vim" in found
        assert "React" not in found

    def test_recall_requires_context(self, server):
        with pytest.raises(ValueError, match="context is required"):
            server.memory_recall({"expanded_query": "test"})
|
|
|
|
|
|
class TestMemoryList:
    """memory_list: empty states, category filter and ``limit``."""

    def test_list_empty(self, server):
        assert "No memories stored yet" in server.memory_list({})

    def test_list_with_memories(self, server):
        for payload in (
            {"content": "Memory one", "expanded_keywords": "one first test"},
            {"content": "Memory two", "expanded_keywords": "two second test"},
        ):
            server.memory_store(payload)
        listing = server.memory_list({})
        assert "Memory one" in listing
        assert "Memory two" in listing
        assert "2 shown" in listing

    def test_list_with_category(self, server):
        for payload in (
            {"content": "A fact", "category": "facts", "expanded_keywords": "fact test"},
            {"content": "A preference", "category": "preferences", "expanded_keywords": "preference test"},
        ):
            server.memory_store(payload)
        listing = server.memory_list({"category": "facts"})
        assert "A fact" in listing
        assert "A preference" not in listing

    def test_list_empty_category(self, server):
        assert "No memories in category 'projects'" in server.memory_list({"category": "projects"})

    def test_list_respects_limit(self, server):
        # Store more rows than the limit, then check only the limit is reported.
        for n in range(5):
            server.memory_store({"content": f"Memory {n}", "expanded_keywords": f"memory number {n}"})
        assert "2 shown" in server.memory_list({"limit": 2})
|
|
|
|
|
|
class TestMemoryDelete:
    """memory_delete: success message, unknown ids and validation."""

    def test_delete_existing(self, server):
        server.memory_store({"content": "To be deleted", "expanded_keywords": "delete remove test"})
        outcome = server.memory_delete({"id": 1})
        # The confirmation echoes both the id and the deleted content.
        assert "Deleted memory #1" in outcome
        assert "To be deleted" in outcome

    def test_delete_nonexistent(self, server):
        assert "not found" in server.memory_delete({"id": 999})

    def test_delete_requires_id(self, server):
        with pytest.raises(ValueError, match="id is required"):
            server.memory_delete({})
|
|
|
|
|
|
class TestSecretGet:
    """secret_get: reveals content of sensitive rows, refuses others, validates input."""

    def test_secret_get_sensitive(self, server):
        server.memory_store(
            {
                "content": "secret password 12345",
                "force_sensitive": True,
                "expanded_keywords": "password secret credential",
            }
        )
        assert "secret password 12345" in server.secret_get({"id": 1})

    def test_secret_get_not_sensitive(self, server):
        server.memory_store({"content": "public info", "expanded_keywords": "public info test"})
        assert "not marked as sensitive" in server.secret_get({"id": 1})

    def test_secret_get_nonexistent(self, server):
        assert "not found" in server.secret_get({"id": 999})

    def test_secret_get_requires_id(self, server):
        with pytest.raises(ValueError, match="id is required"):
            server.secret_get({})
|
|
|
|
|
|
class TestMCPProtocol:
    """Direct calls to the MCP handlers: initialize, tools/list and tools/call."""

    def test_handle_initialize(self, server):
        init = server.handle_initialize({})
        assert init["protocolVersion"] == PROTOCOL_VERSION
        server_info = init["serverInfo"]
        assert server_info["name"] == SERVER_NAME
        assert server_info["version"] == SERVER_VERSION
        assert "tools" in init["capabilities"]

    def test_handle_tools_list(self, server):
        listed = server.handle_tools_list({})["tools"]
        assert len(listed) == 6
        assert {tool["name"] for tool in listed} == {
            "memory_store",
            "memory_recall",
            "memory_list",
            "memory_delete",
            "secret_get",
            "memory_count",
        }

    def test_handle_tools_call_store(self, server):
        outcome = server.handle_tools_call(
            {
                "name": "memory_store",
                "arguments": {"content": "test memory", "expanded_keywords": "test memory keywords"},
            }
        )
        assert not outcome.get("isError", False)
        assert "Stored memory" in outcome["content"][0]["text"]

    def test_handle_tools_call_unknown(self, server):
        outcome = server.handle_tools_call({"name": "nonexistent_tool", "arguments": {}})
        assert outcome["isError"] is True
        assert "Unknown tool" in outcome["content"][0]["text"]

    def test_handle_tools_call_error(self, server):
        # memory_store without "content" must surface as a tool error, not an exception.
        outcome = server.handle_tools_call({"name": "memory_store", "arguments": {}})
        assert outcome["isError"] is True
        assert "Error" in outcome["content"][0]["text"]
|
|
|
|
|
|
class TestProcessMessage:
    """JSON-RPC 2.0 envelope handling in ``process_message``."""

    @staticmethod
    def _rpc(server, method, msg_id=None, params=None):
        """Send one JSON-RPC message; leave out ``id`` (a notification) when *msg_id* is None."""
        message = {"jsonrpc": "2.0"}
        if msg_id is not None:
            message["id"] = msg_id
        message["method"] = method
        message["params"] = {} if params is None else params
        return server.process_message(message)

    def test_initialize(self, server):
        reply = self._rpc(server, "initialize", msg_id=1)
        assert reply["jsonrpc"] == "2.0"
        assert reply["id"] == 1
        assert "result" in reply
        assert reply["result"]["serverInfo"]["name"] == SERVER_NAME

    def test_tools_list(self, server):
        reply = self._rpc(server, "tools/list", msg_id=2)
        assert "result" in reply
        assert len(reply["result"]["tools"]) == 6

    def test_tools_call(self, server):
        reply = self._rpc(
            server,
            "tools/call",
            msg_id=3,
            params={
                "name": "memory_store",
                "arguments": {"content": "via process_message", "expanded_keywords": "process message test"},
            },
        )
        assert "result" in reply
        assert "Stored memory" in reply["result"]["content"][0]["text"]

    def test_unknown_method(self, server):
        reply = self._rpc(server, "unknown/method", msg_id=4)
        assert "error" in reply
        assert reply["error"]["code"] == -32601
        assert "Method not found" in reply["error"]["message"]

    def test_notification_no_id(self, server):
        # A message without an id gets no response at all.
        assert self._rpc(server, "notifications/initialized") is None

    def test_jsonrpc_response_format(self, server):
        reply = self._rpc(server, "initialize", msg_id=5)
        # Verify it's valid JSON when serialized
        round_tripped = json.loads(json.dumps(reply))
        assert round_tripped["jsonrpc"] == "2.0"
        assert round_tripped["id"] == 5
|
|
|
|
|
|
class TestMemoryCount:
    """memory_count: totals, per-category breakdown and routing through tools/call."""

    def test_count_empty(self, server):
        assert "0" in server.memory_count({})

    def test_count_after_store(self, server):
        server.memory_store({"content": "test memory", "expanded_keywords": "test memory keywords data"})
        summary = server.memory_count({})
        assert "1" in summary
        assert "facts" in summary

    def test_count_multiple_categories(self, server):
        for payload in (
            {"content": "a fact", "category": "facts", "expanded_keywords": "fact test data words"},
            {
                "content": "a preference",
                "category": "preferences",
                "expanded_keywords": "preference test data words",
            },
        ):
            server.memory_store(payload)
        summary = server.memory_count({})
        assert "facts: 1" in summary
        assert "preferences: 1" in summary

    def test_count_via_tools_call(self, server):
        outcome = server.handle_tools_call({"name": "memory_count", "arguments": {}})
        assert not outcome.get("isError", False)
        assert "0" in outcome["content"][0]["text"]
|
|
|
|
|
|
class TestSchemaMigration:
    """Schema version bookkeeping (``PRAGMA user_version``) and migrated columns."""

    # Schema version the current code is expected to stamp into PRAGMA user_version.
    EXPECTED_SCHEMA_VERSION = 2

    def test_schema_version_set(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            cursor = srv.sqlite_conn.cursor()
            version = cursor.execute("PRAGMA user_version").fetchone()[0]
            assert version == self.EXPECTED_SCHEMA_VERSION
        finally:
            # Close even when an assertion fails so the connection is not leaked.
            srv.sqlite_conn.close()

    def test_migration_idempotent(self, tmp_path):
        """Running _init_sqlite twice should not error."""
        from claude_memory.mcp_server import _init_sqlite

        db_path = str(tmp_path / "test.db")
        conn1, _ = _init_sqlite(db_path)
        conn1.close()
        # Second run starts from the already-initialised file written by the first.
        conn2, _ = _init_sqlite(db_path)
        try:
            version = conn2.execute("PRAGMA user_version").fetchone()[0]
            assert version == self.EXPECTED_SCHEMA_VERSION
        finally:
            conn2.close()

    def test_server_id_column_exists(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("PRAGMA table_info(memories)")
            columns = {row["name"] for row in cursor.fetchall()}
            assert "server_id" in columns
        finally:
            srv.sqlite_conn.close()
|
|
|
|
|
|
def _server_rows(server: MemoryServer) -> int:
    """Count all rows in *server*'s ``memories`` table (no filtering applied)."""
    conn = server.sqlite_conn
    assert conn is not None
    row = conn.execute("SELECT COUNT(*) AS c FROM memories").fetchone()
    return int(row["c"])
|
|
|
|
|
|
class TestConcurrentWrites:
    """Regression tests for 'database is locked' under heavy concurrent local writes.

    The store path (MemoryServer) and the background sync writer (SyncEngine) must not
    contend on the single SQLite writer. Before the fix they used two separate connections
    to the same file; under load the second writer blew past busy_timeout and raised
    ``sqlite3.OperationalError: database is locked``. After the fix all in-process local
    writes are serialized through one shared, lock-guarded connection, so these two
    writers no longer fight over the SQLite write lock.
    """

    # Shared deadline for worker threads: a deadlock in the shared lock then fails
    # the test instead of hanging the whole run.
    JOIN_TIMEOUT_S = 60.0

    @staticmethod
    def _join_all(threads: list[threading.Thread], timeout: float) -> list[str]:
        """Join *threads* within one overall *timeout*; return names of threads still alive."""
        import time

        deadline = time.monotonic() + timeout
        for t in threads:
            t.join(max(0.0, deadline - time.monotonic()))
        return [t.name for t in threads if t.is_alive()]

    def test_concurrent_stores_all_rows_land(self, tmp_path):
        """Many threads calling memory_store concurrently → every row lands, no lock error."""
        db_path = str(tmp_path / "concurrent.db")
        server = MemoryServer(sqlite_db_path=db_path)
        try:
            n_threads = 16
            per_thread = 12
            errors: list[BaseException] = []
            barrier = threading.Barrier(n_threads)

            def worker(tid: int) -> None:
                barrier.wait()  # release all threads at once to maximise contention
                for i in range(per_thread):
                    try:
                        server.memory_store({
                            "content": f"thread {tid} memory {i}",
                            "expanded_keywords": f"thread {tid} memory {i} concurrent write",
                        })
                    except BaseException as exc:  # noqa: BLE001 — capture everything for the assert
                        errors.append(exc)

            # daemon=True so a deadlocked worker cannot keep the interpreter alive.
            threads = [
                threading.Thread(target=worker, args=(t,), daemon=True) for t in range(n_threads)
            ]
            for t in threads:
                t.start()
            stuck = self._join_all(threads, self.JOIN_TIMEOUT_S)

            assert stuck == [], f"store threads did not finish (deadlock?): {stuck!r}"
            assert errors == [], f"concurrent stores raised: {errors!r}"
            assert _server_rows(server) == n_threads * per_thread
        finally:
            if server.sqlite_conn:
                server.sqlite_conn.close()

    def test_concurrent_stores_while_sync_writer_active_no_lock(self, tmp_path):
        """Store path + background sync writer hammer the SAME file → no 'database is locked'.

        Reproduces the production failure: ``MemoryServer`` and ``SyncEngine`` both write to
        one SQLite file. We shrink ``busy_timeout`` so the structural race (two writers fighting
        the single SQLite write lock) would surface within milliseconds instead of 30s. With the
        shared writer there is no second connection to wait on, so the outcome must not depend
        on busy_timeout.
        """
        from claude_memory.sync import SyncEngine

        db_path = str(tmp_path / "hybrid.db")
        server = MemoryServer(sqlite_db_path=db_path)
        # The production hybrid wiring: the sync engine SHARES the server's LocalStore
        # (one connection, one lock) — the structural fix for the cross-connection race.
        engine = SyncEngine(
            db_path=db_path,
            api_base_url="http://fake-api:8080",
            api_key="test-key",
            sync_interval=3600,  # never auto-syncs; we drive the writer by hand
            store=server.store,
        )
        assert engine._conn is server.sqlite_conn  # truly one shared connection

        # Shrink the busy timeout so that, were there still two writers, the race would
        # surface in ms not 30s.
        server.sqlite_conn.execute("PRAGMA busy_timeout=50")

        now = datetime.now(timezone.utc).isoformat()
        n_pulled = 120

        # A write-heavy pull: many rows upserted inside the sync engine's single lock block —
        # exactly the kind of long-held writer that starved the store path.
        def big_pull(method: str, path: str, body: Any = None) -> dict[str, Any]:
            return {
                "memories": [
                    {
                        "id": 10_000 + j,
                        "content": f"server row {j}",
                        "category": "facts",
                        "tags": "",
                        "expanded_keywords": "",
                        "importance": 0.5,
                        "is_sensitive": False,
                        "created_at": now,
                        "updated_at": now,
                        "deleted_at": None,
                    }
                    for j in range(n_pulled)
                ],
                "server_time": now,
            }

        errors: list[BaseException] = []
        stop = threading.Event()
        # Successful pull iterations; proves the sync writer actually ran during the test.
        pulls = [0]

        def sync_writer() -> None:
            with patch.object(engine, "_api_request", side_effect=big_pull):
                while not stop.is_set():
                    try:
                        engine._pull_changes()
                    except BaseException as exc:  # noqa: BLE001
                        errors.append(exc)
                    else:
                        pulls[0] += 1

        n_threads = 12
        per_thread = 12
        barrier = threading.Barrier(n_threads)

        def store_worker(tid: int) -> None:
            barrier.wait()
            for i in range(per_thread):
                try:
                    server.memory_store({
                        "content": f"local {tid}.{i}",
                        "expanded_keywords": f"local thread {tid} item {i} keywords here",
                    })
                except BaseException as exc:  # noqa: BLE001
                    errors.append(exc)

        writer = threading.Thread(target=sync_writer, daemon=True)
        writer.start()
        local_rows = 0
        stuck: list[str] = []
        try:
            store_threads = [
                threading.Thread(target=store_worker, args=(t,), daemon=True) for t in range(n_threads)
            ]
            for t in store_threads:
                t.start()
            stuck = self._join_all(store_threads, self.JOIN_TIMEOUT_S)
            stop.set()
            # The writer must really be gone before we read and close the shared connection.
            stuck += self._join_all([writer], 5.0)
            if not stuck:
                # Count only locally stored rows: the pulled "server row" rows land in the
                # same table, so a total count could hide lost local writes.
                local_rows = int(
                    server.sqlite_conn.execute(
                        "SELECT COUNT(*) AS c FROM memories WHERE content LIKE 'local %'"
                    ).fetchone()["c"]
                )
        finally:
            stop.set()
            engine.stop()  # shares the server's store → does not close the connection
            if server.sqlite_conn:
                server.sqlite_conn.close()

        assert stuck == [], f"threads did not finish (deadlock?): {stuck!r}"
        locked = [e for e in errors if isinstance(e, sqlite3.OperationalError) and "locked" in str(e)]
        assert locked == [], f"hit 'database is locked' under concurrency: {locked!r}"
        assert errors == [], f"concurrent writes raised: {errors!r}"
        assert pulls[0] > 0, "sync writer never completed a pull; contention was not exercised"
        # Every local store must have landed — exactly once.
        assert local_rows == n_threads * per_thread
|
|
|
|
|
|
class TestRecallProgressAndBounding:
    """Bounding and progress reporting for the slow remote ``memory_recall`` path.

    A remote semantic recall must be bounded by ``RECALL_TIMEOUT`` and, when the client
    opts in with a progress token, signal progress instead of hanging silently and
    dropping the session.
    """

    @staticmethod
    def _recall_args() -> dict[str, str]:
        """Return fresh minimal recall arguments (a new dict per call)."""
        return {"context": "x", "expanded_query": "a b c d e"}

    @staticmethod
    def _capture_notifications(server) -> list[dict[str, Any]]:
        """Replace ``server._notify`` with a recorder; return the list it appends to."""
        sent: list[dict[str, Any]] = []
        server._notify = lambda method, params: sent.append({"method": method, "params": params})
        return sent

    @staticmethod
    def _progress_events(sent: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Keep only the recorded ``notifications/progress`` messages."""
        return [entry for entry in sent if entry["method"] == "notifications/progress"]

    def test_remote_recall_timeout_returns_clear_signal_not_raise(self, server):
        """An unreachable/timed-out backend yields a readable retry message, not an exception."""
        import claude_memory.mcp_server as mcp_mod

        failure = RuntimeError("API connection error: timed out")
        with patch.object(mcp_mod, "_api_request", side_effect=failure):
            message = server._recall_remote(self._recall_args())

        lowered = message.lower()
        assert "recall" in lowered
        # The message must say the caller should try again and quote the timeout bound.
        assert "retry" in lowered or "again" in lowered
        assert str(mcp_mod.RECALL_TIMEOUT) in message

    def test_remote_recall_success_formats_rows(self, server):
        """Rows returned by the backend are rendered in the usual result format."""
        import claude_memory.mcp_server as mcp_mod

        remote_row = {
            "id": 7,
            "category": "facts",
            "importance": 0.8,
            "content": "hello",
            "tags": "t",
            "created_at": "2026-01-01T00:00:00+00:00",
        }
        with patch.object(mcp_mod, "_api_request", return_value={"memories": [remote_row]}):
            message = server._recall_remote(self._recall_args())

        assert "Found 1 memories" in message
        assert "hello" in message

    def test_remote_recall_uses_bounded_timeout(self, server):
        """The HTTP call made by a remote recall carries ``timeout=RECALL_TIMEOUT``."""
        import claude_memory.mcp_server as mcp_mod

        with patch.object(mcp_mod, "_api_request", return_value={"memories": []}) as fake_api:
            server._recall_remote(self._recall_args())

        assert fake_api.call_args.kwargs.get("timeout") == mcp_mod.RECALL_TIMEOUT

    def test_api_request_accepts_timeout_kwarg(self):
        """``_api_request`` exposes a ``timeout`` parameter in its signature."""
        import inspect

        import claude_memory.mcp_server as mcp_mod

        assert "timeout" in inspect.signature(mcp_mod._api_request).parameters

    def test_progress_notification_emitted_for_recall_with_token(self, server):
        """A recall carrying ``_meta.progressToken`` emits a progress notification with that token."""
        sent = self._capture_notifications(server)

        with patch.object(server, "memory_recall", return_value="ok"):
            server.handle_tools_call({
                "name": "memory_recall",
                "arguments": self._recall_args(),
                "_meta": {"progressToken": "tok-1"},
            })

        progress = self._progress_events(sent)
        assert progress, "expected a notifications/progress for a tokened recall"
        assert progress[0]["params"]["progressToken"] == "tok-1"

    def test_no_progress_notification_without_token(self, server):
        """Clients that send no progress token receive no progress notifications."""
        sent = self._capture_notifications(server)

        with patch.object(server, "memory_recall", return_value="ok"):
            server.handle_tools_call({"name": "memory_recall", "arguments": self._recall_args()})

        assert self._progress_events(sent) == []

    def test_fast_tools_do_not_emit_progress(self, server):
        """A tokened memory_store stays silent; only the recall path reports progress."""
        sent = self._capture_notifications(server)

        server.handle_tools_call({
            "name": "memory_store",
            "arguments": {"content": "x", "expanded_keywords": "a b c d e"},
            "_meta": {"progressToken": "tok-2"},
        })

        assert self._progress_events(sent) == []
|