- Decouple push and pull in _sync_once() so pull always runs even if push fails
- Add startup full resync to catch drift from other agents and schema changes
- Add periodic full resync every ~10 minutes for continuous drift correction
- Add auth failure detection (401/403) with graceful SQLite-only degradation
- Add /api/auth-check endpoint for lightweight key validation
- Add retry cap (5 attempts) on pending ops to prevent infinite queue buildup
- Add orphan reconciliation: push local-only records with content dedup
- Add memory_count MCP tool for sync diagnostics
- Add version-based SQLite schema migration (PRAGMA user_version)
- Fix API key in ~/.claude.json to match server
- Update README with sync resilience docs, test structure, project layout
- Add 30 new tests covering all new behaviors (155 total, all passing)
410 lines
14 KiB
Python
410 lines
14 KiB
Python
"""Tests for the Claude Memory MCP server."""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
import pytest
|
|
|
|
# Force SQLite fallback mode for all tests by dropping both API-key variables
for _api_key_var in ("MEMORY_API_KEY", "CLAUDE_MEMORY_API_KEY"):
    os.environ.pop(_api_key_var, None)

# Add src to path so we can import without installing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from claude_memory.mcp_server import MemoryServer, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
|
|
|
|
|
|
@pytest.fixture
def server(tmp_path):
    """Yield a MemoryServer backed by a throwaway SQLite file under ``tmp_path``.

    Once the test has finished, the server's SQLite connection is closed
    if one is set.
    """
    memory_server = MemoryServer(sqlite_db_path=str(tmp_path / "test_memory.db"))
    yield memory_server
    connection = memory_server.sqlite_conn
    if not connection:
        return
    connection.close()
|
|
|
|
|
|
class TestSQLiteInit:
    """Checks the SQLite objects that exist after constructing a MemoryServer."""

    def test_creates_database(self, tmp_path):
        """The database file exists and contains a ``memories`` table.

        The path points into a ``sub`` directory that this test never
        creates itself, so the file can only exist if the server created it.
        """
        db_path = str(tmp_path / "sub" / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            assert os.path.exists(db_path)
            # Verify tables exist
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'")
            assert cursor.fetchone() is not None
        finally:
            # Close even when an assertion fails so the handle is not leaked.
            srv.sqlite_conn.close()

    def test_creates_fts_table(self, tmp_path):
        """A ``memories_fts`` table is present in the freshly created database."""
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'")
            assert cursor.fetchone() is not None
        finally:
            # Close even when the assertion fails so the handle is not leaked.
            srv.sqlite_conn.close()
|
|
|
|
|
|
class TestMemoryStore:
    """Exercises ``memory_store`` against a fresh temporary database."""

    def test_store_basic(self, server):
        """A minimal store call reports memory #1 and mentions ``facts``."""
        payload = {
            "content": "User prefers dark mode",
            "expanded_keywords": "dark mode theme preference ui",
        }
        reply = server.memory_store(payload)
        assert "Stored memory #1" in reply
        assert "facts" in reply

    def test_store_with_category(self, server):
        """An explicit category is echoed back in the reply."""
        payload = {
            "content": "User likes Python",
            "category": "preferences",
            "expanded_keywords": "python programming language preference",
        }
        reply = server.memory_store(payload)
        assert "preferences" in reply

    def test_store_with_importance(self, server):
        """An explicit importance value appears in the reply."""
        payload = {
            "content": "Critical info",
            "importance": 0.9,
            "expanded_keywords": "critical important info",
        }
        reply = server.memory_store(payload)
        assert "0.9" in reply

    def test_store_requires_content(self, server):
        """Omitting ``content`` raises ValueError with a matching message."""
        incomplete = {"expanded_keywords": "test"}
        with pytest.raises(ValueError, match="content is required"):
            server.memory_store(incomplete)

    def test_store_force_sensitive(self, server):
        """``force_sensitive`` stores the row with ``is_sensitive`` set to 1."""
        payload = {
            "content": "API key: sk-1234",
            "force_sensitive": True,
            "expanded_keywords": "api key secret credential",
        }
        reply = server.memory_store(payload)
        assert "Stored memory #1" in reply
        # Verify is_sensitive flag is set
        db_cursor = server.sqlite_conn.cursor()
        db_cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1")
        stored_row = db_cursor.fetchone()
        assert stored_row["is_sensitive"] == 1
|
|
|
|
|
|
class TestMemoryRecall:
    """Exercises ``memory_recall`` against a fresh temporary database."""

    def test_recall_finds_memory(self, server):
        """A stored memory is returned for a query that matches its keywords."""
        stored = {
            "content": "User works at Acme Corp",
            "expanded_keywords": "acme corp company work employer",
        }
        server.memory_store(stored)
        query = {
            "context": "work",
            "expanded_query": "company employer job",
        }
        reply = server.memory_recall(query)
        assert "Acme Corp" in reply
        assert "Found 1 memories" in reply

    def test_recall_no_results(self, server):
        """Recalling from an empty database reports that nothing was found."""
        query = {
            "context": "nonexistent topic",
            "expanded_query": "nothing here at all",
        }
        reply = server.memory_recall(query)
        assert "No memories found" in reply

    def test_recall_with_category_filter(self, server):
        """A category filter keeps memories from other categories out of the reply."""
        vim_preference = {
            "content": "User prefers vim",
            "category": "preferences",
            "expanded_keywords": "vim editor preference text",
        }
        react_project = {
            "content": "Project uses React",
            "category": "projects",
            "expanded_keywords": "react project frontend framework",
        }
        server.memory_store(vim_preference)
        server.memory_store(react_project)
        query = {
            "context": "preferences",
            "expanded_query": "vim editor",
            "category": "preferences",
        }
        reply = server.memory_recall(query)
        assert "vim" in reply
        assert "React" not in reply

    def test_recall_requires_context(self, server):
        """Omitting ``context`` raises ValueError with a matching message."""
        incomplete = {"expanded_query": "test"}
        with pytest.raises(ValueError, match="context is required"):
            server.memory_recall(incomplete)
|
|
|
|
|
|
class TestMemoryList:
    """Exercises ``memory_list`` against a fresh temporary database."""

    def test_list_empty(self, server):
        """Listing an empty database reports that nothing is stored yet."""
        listing = server.memory_list({})
        assert "No memories stored yet" in listing

    def test_list_with_memories(self, server):
        """Both stored memories are listed and the reply reports two shown."""
        first = {
            "content": "Memory one",
            "expanded_keywords": "one first test",
        }
        second = {
            "content": "Memory two",
            "expanded_keywords": "two second test",
        }
        server.memory_store(first)
        server.memory_store(second)
        listing = server.memory_list({})
        assert "Memory one" in listing
        assert "Memory two" in listing
        assert "2 shown" in listing

    def test_list_with_category(self, server):
        """Listing by category shows only memories from that category."""
        fact = {
            "content": "A fact",
            "category": "facts",
            "expanded_keywords": "fact test",
        }
        preference = {
            "content": "A preference",
            "category": "preferences",
            "expanded_keywords": "preference test",
        }
        server.memory_store(fact)
        server.memory_store(preference)
        listing = server.memory_list({"category": "facts"})
        assert "A fact" in listing
        assert "A preference" not in listing

    def test_list_empty_category(self, server):
        """Listing a category with no memories names that category in the reply."""
        listing = server.memory_list({"category": "projects"})
        assert "No memories in category 'projects'" in listing

    def test_list_respects_limit(self, server):
        """With five memories stored, ``limit`` of 2 yields a reply reporting two shown."""
        for index in range(5):
            entry = {
                "content": f"Memory {index}",
                "expanded_keywords": f"memory number {index}",
            }
            server.memory_store(entry)
        listing = server.memory_list({"limit": 2})
        assert "2 shown" in listing
|
|
|
|
|
|
class TestMemoryDelete:
    """Exercises ``memory_delete`` against a fresh temporary database."""

    def test_delete_existing(self, server):
        """Deleting a stored memory reports its id and echoes its content."""
        doomed = {
            "content": "To be deleted",
            "expanded_keywords": "delete remove test",
        }
        server.memory_store(doomed)
        outcome = server.memory_delete({"id": 1})
        assert "Deleted memory #1" in outcome
        assert "To be deleted" in outcome

    def test_delete_nonexistent(self, server):
        """Deleting an id that was never stored reports it as not found."""
        outcome = server.memory_delete({"id": 999})
        assert "not found" in outcome

    def test_delete_requires_id(self, server):
        """Omitting ``id`` raises ValueError with a matching message."""
        no_arguments = {}
        with pytest.raises(ValueError, match="id is required"):
            server.memory_delete(no_arguments)
|
|
|
|
|
|
class TestSecretGet:
    """Exercises ``secret_get`` against a fresh temporary database."""

    def test_secret_get_sensitive(self, server):
        """A memory stored with ``force_sensitive`` is returned with its content."""
        secret = {
            "content": "secret password 12345",
            "force_sensitive": True,
            "expanded_keywords": "password secret credential",
        }
        server.memory_store(secret)
        reply = server.secret_get({"id": 1})
        assert "secret password 12345" in reply

    def test_secret_get_not_sensitive(self, server):
        """A memory stored without the sensitive flag is reported as not sensitive."""
        public = {
            "content": "public info",
            "expanded_keywords": "public info test",
        }
        server.memory_store(public)
        reply = server.secret_get({"id": 1})
        assert "not marked as sensitive" in reply

    def test_secret_get_nonexistent(self, server):
        """Requesting an id that was never stored reports it as not found."""
        reply = server.secret_get({"id": 999})
        assert "not found" in reply

    def test_secret_get_requires_id(self, server):
        """Omitting ``id`` raises ValueError with a matching message."""
        no_arguments = {}
        with pytest.raises(ValueError, match="id is required"):
            server.secret_get(no_arguments)
|
|
|
|
|
|
class TestMCPProtocol:
    """Exercises the ``handle_*`` protocol handlers directly."""

    def test_handle_initialize(self, server):
        """``handle_initialize`` reports the module's version constants and a tools capability."""
        handshake = server.handle_initialize({})
        server_info = handshake["serverInfo"]
        assert handshake["protocolVersion"] == PROTOCOL_VERSION
        assert server_info["name"] == SERVER_NAME
        assert server_info["version"] == SERVER_VERSION
        assert "tools" in handshake["capabilities"]

    def test_handle_tools_list(self, server):
        """``handle_tools_list`` advertises exactly the six expected tool names."""
        expected_names = {
            "memory_store",
            "memory_recall",
            "memory_list",
            "memory_delete",
            "secret_get",
            "memory_count",
        }
        advertised = server.handle_tools_list({})["tools"]
        assert len(advertised) == 6
        assert {tool["name"] for tool in advertised} == expected_names

    def test_handle_tools_call_store(self, server):
        """Calling ``memory_store`` through ``handle_tools_call`` succeeds without an error flag."""
        call = {
            "name": "memory_store",
            "arguments": {
                "content": "test memory",
                "expanded_keywords": "test memory keywords",
            },
        }
        outcome = server.handle_tools_call(call)
        assert not outcome.get("isError", False)
        assert "Stored memory" in outcome["content"][0]["text"]

    def test_handle_tools_call_unknown(self, server):
        """An unrecognised tool name yields an error result naming the problem."""
        call = {
            "name": "nonexistent_tool",
            "arguments": {},
        }
        outcome = server.handle_tools_call(call)
        assert outcome["isError"] is True
        assert "Unknown tool" in outcome["content"][0]["text"]

    def test_handle_tools_call_error(self, server):
        """A tool failure is returned as an error result rather than raised."""
        call = {
            "name": "memory_store",
            "arguments": {},  # missing content
        }
        outcome = server.handle_tools_call(call)
        assert outcome["isError"] is True
        assert "Error" in outcome["content"][0]["text"]
|
|
|
|
|
|
class TestProcessMessage:
    """Exercises ``process_message`` with JSON-RPC style request dicts."""

    def test_initialize(self, server):
        """An ``initialize`` request gets a result carrying the request id and server name."""
        request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {},
        }
        reply = server.process_message(request)
        assert reply["jsonrpc"] == "2.0"
        assert reply["id"] == 1
        assert "result" in reply
        assert reply["result"]["serverInfo"]["name"] == SERVER_NAME

    def test_tools_list(self, server):
        """A ``tools/list`` request returns a result listing six tools."""
        request = {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list",
            "params": {},
        }
        reply = server.process_message(request)
        assert "result" in reply
        assert len(reply["result"]["tools"]) == 6

    def test_tools_call(self, server):
        """A ``tools/call`` request for ``memory_store`` returns the store confirmation."""
        tool_call = {
            "name": "memory_store",
            "arguments": {
                "content": "via process_message",
                "expanded_keywords": "process message test",
            },
        }
        request = {
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": tool_call,
        }
        reply = server.process_message(request)
        assert "result" in reply
        assert "Stored memory" in reply["result"]["content"][0]["text"]

    def test_unknown_method(self, server):
        """An unknown method produces an error with code -32601."""
        request = {
            "jsonrpc": "2.0",
            "id": 4,
            "method": "unknown/method",
            "params": {},
        }
        reply = server.process_message(request)
        assert "error" in reply
        failure = reply["error"]
        assert failure["code"] == -32601
        assert "Method not found" in failure["message"]

    def test_notification_no_id(self, server):
        """A message without an ``id`` produces no response at all."""
        notification = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        }
        reply = server.process_message(notification)
        assert reply is None

    def test_jsonrpc_response_format(self, server):
        """The response survives a JSON round trip with version and id intact."""
        request = {
            "jsonrpc": "2.0",
            "id": 5,
            "method": "initialize",
            "params": {},
        }
        reply = server.process_message(request)
        # Verify it's valid JSON when serialized
        round_tripped = json.loads(json.dumps(reply))
        assert round_tripped["jsonrpc"] == "2.0"
        assert round_tripped["id"] == 5
|
|
|
|
|
|
class TestMemoryCount:
    """Exercises the ``memory_count`` tool against a fresh temporary database."""

    def test_count_empty(self, server):
        """The count report for an empty database contains ``0``."""
        report = server.memory_count({})
        assert "0" in report

    def test_count_after_store(self, server):
        """After one store, the report contains ``1`` and mentions ``facts``."""
        entry = {
            "content": "test memory",
            "expanded_keywords": "test memory keywords data",
        }
        server.memory_store(entry)
        report = server.memory_count({})
        assert "1" in report
        assert "facts" in report

    def test_count_multiple_categories(self, server):
        """One memory in each of two categories is reported per category."""
        fact = {
            "content": "a fact",
            "category": "facts",
            "expanded_keywords": "fact test data words",
        }
        preference = {
            "content": "a preference",
            "category": "preferences",
            "expanded_keywords": "preference test data words",
        }
        server.memory_store(fact)
        server.memory_store(preference)
        report = server.memory_count({})
        assert "facts: 1" in report
        assert "preferences: 1" in report

    def test_count_via_tools_call(self, server):
        """``memory_count`` is reachable through ``handle_tools_call`` without an error flag."""
        call = {
            "name": "memory_count",
            "arguments": {},
        }
        outcome = server.handle_tools_call(call)
        assert not outcome.get("isError", False)
        assert "0" in outcome["content"][0]["text"]
|
|
|
|
|
|
class TestSchemaMigration:
    """Checks the SQLite schema version and columns after initialisation."""

    def test_schema_version_set(self, tmp_path):
        """A freshly created database reports ``PRAGMA user_version`` of 2."""
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            cursor = srv.sqlite_conn.cursor()
            version = cursor.execute("PRAGMA user_version").fetchone()[0]
            assert version == 2
        finally:
            # Close even when the assertion fails so the handle is not leaked.
            srv.sqlite_conn.close()

    def test_migration_idempotent(self, tmp_path):
        """Running _init_sqlite twice should not error."""
        from claude_memory.mcp_server import _init_sqlite

        db_path = str(tmp_path / "test.db")
        conn1, _ = _init_sqlite(db_path)
        conn1.close()
        # The second run is against the already-initialised database file.
        conn2, _ = _init_sqlite(db_path)
        try:
            version = conn2.execute("PRAGMA user_version").fetchone()[0]
            assert version == 2
        finally:
            # Close even when the assertion fails so the handle is not leaked.
            conn2.close()

    def test_server_id_column_exists(self, tmp_path):
        """The ``memories`` table has a ``server_id`` column."""
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        try:
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("PRAGMA table_info(memories)")
            columns = {row["name"] for row in cursor.fetchall()}
            assert "server_id" in columns
        finally:
            # Close even when the assertion fails so the handle is not leaked.
            srv.sqlite_conn.close()
|