claude-memory-mcp/tests/test_mcp_server.py
Viktor Barzin e47efee6b6
resilient memory sync: decouple push/pull, startup full resync, auth failure handling
- Decouple push and pull in _sync_once() so pull always runs even if push fails
- Add startup full resync to catch drift from other agents and schema changes
- Add periodic full resync every ~10 minutes for continuous drift correction
- Add auth failure detection (401/403) with graceful SQLite-only degradation
- Add /api/auth-check endpoint for lightweight key validation
- Add retry cap (5 attempts) on pending ops to prevent infinite queue buildup
- Add orphan reconciliation: push local-only records with content dedup
- Add memory_count MCP tool for sync diagnostics
- Add version-based SQLite schema migration (PRAGMA user_version)
- Fix API key in ~/.claude.json to match server
- Update README with sync resilience docs, test structure, project layout
- Add 30 new tests covering all new behaviors (155 total, all passing)
2026-03-16 18:37:59 +00:00

410 lines
14 KiB
Python

"""Tests for the Claude Memory MCP server."""
import json
import os
import sys
import pytest
# Force SQLite fallback mode for all tests: remove both API-key variables from
# this process's environment so a key set on the developer/CI machine cannot
# leak into the test run.
# NOTE(review): presumably the server only uses the remote API when one of
# these variables is set -- confirm against claude_memory.mcp_server.
os.environ.pop("MEMORY_API_KEY", None)
os.environ.pop("CLAUDE_MEMORY_API_KEY", None)
# Add src to path so we can import without installing. The entry goes at
# index 0, so the in-repo sources are searched before any other location.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from claude_memory.mcp_server import MemoryServer, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
@pytest.fixture
def server(tmp_path):
    """Yield a MemoryServer backed by a throwaway SQLite file in tmp_path.

    After the requesting test has finished, the server's SQLite connection
    is closed if the server holds one.
    """
    database_file = tmp_path / "test_memory.db"
    memory_server = MemoryServer(sqlite_db_path=str(database_file))
    yield memory_server
    # Teardown: release the connection so the temporary file is not held open.
    connection = memory_server.sqlite_conn
    if connection:
        connection.close()
class TestSQLiteInit:
    """Verify that constructing a MemoryServer initialises its SQLite store."""

    @staticmethod
    def _has_table(conn, table_name):
        """Return True if sqlite_master lists a table named ``table_name``."""
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,),
        )
        return cursor.fetchone() is not None

    def test_creates_database(self, tmp_path):
        """The database file and the ``memories`` table are created.

        The path points into a sub-directory of tmp_path that does not exist
        before the server is constructed.
        """
        db_path = str(tmp_path / "sub" / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        # try/finally: a failing assertion must not leak the open connection.
        try:
            assert os.path.exists(db_path)
            # Verify tables exist
            assert self._has_table(srv.sqlite_conn, "memories")
        finally:
            srv.sqlite_conn.close()

    def test_creates_fts_table(self, tmp_path):
        """The ``memories_fts`` table is created alongside ``memories``."""
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        # try/finally: a failing assertion must not leak the open connection.
        try:
            assert self._has_table(srv.sqlite_conn, "memories_fts")
        finally:
            srv.sqlite_conn.close()
class TestMemoryStore:
    """Exercise ``MemoryServer.memory_store`` against a fresh database."""

    def test_store_basic(self, server):
        """A minimal store call reports memory #1 and mentions "facts"."""
        payload = {
            "content": "User prefers dark mode",
            "expanded_keywords": "dark mode theme preference ui",
        }
        reply = server.memory_store(payload)
        assert "Stored memory #1" in reply
        assert "facts" in reply

    def test_store_with_category(self, server):
        """An explicit category is echoed back in the reply."""
        payload = {
            "content": "User likes Python",
            "category": "preferences",
            "expanded_keywords": "python programming language preference",
        }
        reply = server.memory_store(payload)
        assert "preferences" in reply

    def test_store_with_importance(self, server):
        """An explicit importance value is echoed back in the reply."""
        payload = {
            "content": "Critical info",
            "importance": 0.9,
            "expanded_keywords": "critical important info",
        }
        reply = server.memory_store(payload)
        assert "0.9" in reply

    def test_store_requires_content(self, server):
        """Omitting ``content`` raises ValueError."""
        without_content = {"expanded_keywords": "test"}
        with pytest.raises(ValueError, match="content is required"):
            server.memory_store(without_content)

    def test_store_force_sensitive(self, server):
        """``force_sensitive`` persists ``is_sensitive = 1`` on the row."""
        payload = {
            "content": "API key: sk-1234",
            "force_sensitive": True,
            "expanded_keywords": "api key secret credential",
        }
        reply = server.memory_store(payload)
        assert "Stored memory #1" in reply
        # Verify is_sensitive flag is set by reading the row directly.
        db_cursor = server.sqlite_conn.cursor()
        db_cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1")
        stored_row = db_cursor.fetchone()
        assert stored_row["is_sensitive"] == 1
class TestMemoryRecall:
    """Exercise ``MemoryServer.memory_recall``."""

    def test_recall_finds_memory(self, server):
        """A stored memory is returned for a query sharing its keywords."""
        stored = {
            "content": "User works at Acme Corp",
            "expanded_keywords": "acme corp company work employer",
        }
        server.memory_store(stored)
        query = {
            "context": "work",
            "expanded_query": "company employer job",
        }
        reply = server.memory_recall(query)
        assert "Acme Corp" in reply
        assert "Found 1 memories" in reply

    def test_recall_no_results(self, server):
        """Recall on an empty store reports that nothing was found."""
        query = {
            "context": "nonexistent topic",
            "expanded_query": "nothing here at all",
        }
        reply = server.memory_recall(query)
        assert "No memories found" in reply

    def test_recall_with_category_filter(self, server):
        """The ``category`` argument excludes memories of other categories."""
        seeds = (
            {
                "content": "User prefers vim",
                "category": "preferences",
                "expanded_keywords": "vim editor preference text",
            },
            {
                "content": "Project uses React",
                "category": "projects",
                "expanded_keywords": "react project frontend framework",
            },
        )
        for seed in seeds:
            server.memory_store(seed)
        query = {
            "context": "preferences",
            "expanded_query": "vim editor",
            "category": "preferences",
        }
        reply = server.memory_recall(query)
        assert "vim" in reply
        assert "React" not in reply

    def test_recall_requires_context(self, server):
        """Omitting ``context`` raises ValueError."""
        without_context = {"expanded_query": "test"}
        with pytest.raises(ValueError, match="context is required"):
            server.memory_recall(without_context)
class TestMemoryList:
    """Exercise ``MemoryServer.memory_list``."""

    def test_list_empty(self, server):
        """Listing an empty store yields the "nothing stored" message."""
        listing = server.memory_list({})
        assert "No memories stored yet" in listing

    def test_list_with_memories(self, server):
        """Every stored memory appears, together with a "2 shown" count."""
        seeds = (
            {
                "content": "Memory one",
                "expanded_keywords": "one first test",
            },
            {
                "content": "Memory two",
                "expanded_keywords": "two second test",
            },
        )
        for seed in seeds:
            server.memory_store(seed)
        listing = server.memory_list({})
        assert "Memory one" in listing
        assert "Memory two" in listing
        assert "2 shown" in listing

    def test_list_with_category(self, server):
        """Only memories of the requested category are listed."""
        seeds = (
            {
                "content": "A fact",
                "category": "facts",
                "expanded_keywords": "fact test",
            },
            {
                "content": "A preference",
                "category": "preferences",
                "expanded_keywords": "preference test",
            },
        )
        for seed in seeds:
            server.memory_store(seed)
        listing = server.memory_list({"category": "facts"})
        assert "A fact" in listing
        assert "A preference" not in listing

    def test_list_empty_category(self, server):
        """A category with no entries gets a category-specific message."""
        category_filter = {"category": "projects"}
        listing = server.memory_list(category_filter)
        assert "No memories in category 'projects'" in listing

    def test_list_respects_limit(self, server):
        """With five memories stored, ``limit=2`` shows exactly two."""
        for index in range(5):
            seed = {
                "content": f"Memory {index}",
                "expanded_keywords": f"memory number {index}",
            }
            server.memory_store(seed)
        listing = server.memory_list({"limit": 2})
        assert "2 shown" in listing
class TestMemoryDelete:
    """Exercise ``MemoryServer.memory_delete``."""

    def test_delete_existing(self, server):
        """Deleting memory #1 reports both its id and its content."""
        seed = {
            "content": "To be deleted",
            "expanded_keywords": "delete remove test",
        }
        server.memory_store(seed)
        reply = server.memory_delete({"id": 1})
        assert "Deleted memory #1" in reply
        assert "To be deleted" in reply

    def test_delete_nonexistent(self, server):
        """Deleting an unknown id returns a "not found" reply."""
        unknown = {"id": 999}
        reply = server.memory_delete(unknown)
        assert "not found" in reply

    def test_delete_requires_id(self, server):
        """Omitting ``id`` raises ValueError."""
        no_arguments = {}
        with pytest.raises(ValueError, match="id is required"):
            server.memory_delete(no_arguments)
class TestSecretGet:
    """Exercise ``MemoryServer.secret_get``."""

    def test_secret_get_sensitive(self, server):
        """A memory stored with ``force_sensitive`` is returned in full."""
        seed = {
            "content": "secret password 12345",
            "force_sensitive": True,
            "expanded_keywords": "password secret credential",
        }
        server.memory_store(seed)
        reply = server.secret_get({"id": 1})
        assert "secret password 12345" in reply

    def test_secret_get_not_sensitive(self, server):
        """A memory stored without the flag is reported as not sensitive."""
        seed = {
            "content": "public info",
            "expanded_keywords": "public info test",
        }
        server.memory_store(seed)
        reply = server.secret_get({"id": 1})
        assert "not marked as sensitive" in reply

    def test_secret_get_nonexistent(self, server):
        """An unknown id returns a "not found" reply."""
        unknown = {"id": 999}
        reply = server.secret_get(unknown)
        assert "not found" in reply

    def test_secret_get_requires_id(self, server):
        """Omitting ``id`` raises ValueError."""
        no_arguments = {}
        with pytest.raises(ValueError, match="id is required"):
            server.secret_get(no_arguments)
class TestMCPProtocol:
    """Exercise the MCP handler methods directly (no JSON-RPC envelope)."""

    def test_handle_initialize(self, server):
        """``initialize`` reports protocol version, server info and tools."""
        payload = server.handle_initialize({})
        server_info = payload["serverInfo"]
        assert payload["protocolVersion"] == PROTOCOL_VERSION
        assert server_info["name"] == SERVER_NAME
        assert server_info["version"] == SERVER_VERSION
        assert "tools" in payload["capabilities"]

    def test_handle_tools_list(self, server):
        """Exactly the six expected tools are advertised."""
        expected_names = {
            "memory_store",
            "memory_recall",
            "memory_list",
            "memory_delete",
            "secret_get",
            "memory_count",
        }
        advertised = server.handle_tools_list({})["tools"]
        assert len(advertised) == 6
        advertised_names = {tool["name"] for tool in advertised}
        assert advertised_names == expected_names

    def test_handle_tools_call_store(self, server):
        """Calling ``memory_store`` through ``tools/call`` succeeds."""
        call = {
            "name": "memory_store",
            "arguments": {
                "content": "test memory",
                "expanded_keywords": "test memory keywords",
            },
        }
        outcome = server.handle_tools_call(call)
        assert not outcome.get("isError", False)
        assert "Stored memory" in outcome["content"][0]["text"]

    def test_handle_tools_call_unknown(self, server):
        """An unknown tool name yields an error result, not an exception."""
        call = {
            "name": "nonexistent_tool",
            "arguments": {},
        }
        outcome = server.handle_tools_call(call)
        assert outcome["isError"] is True
        assert "Unknown tool" in outcome["content"][0]["text"]

    def test_handle_tools_call_error(self, server):
        """A tool failure is surfaced as an error result with "Error" text."""
        call = {
            "name": "memory_store",
            "arguments": {},  # missing content
        }
        outcome = server.handle_tools_call(call)
        assert outcome["isError"] is True
        assert "Error" in outcome["content"][0]["text"]
class TestProcessMessage:
    """Exercise ``MemoryServer.process_message`` with JSON-RPC 2.0 requests."""

    def test_initialize(self, server):
        """``initialize`` echoes the request id and returns a result."""
        request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {},
        }
        reply = server.process_message(request)
        assert reply["jsonrpc"] == "2.0"
        assert reply["id"] == 1
        assert "result" in reply
        assert reply["result"]["serverInfo"]["name"] == SERVER_NAME

    def test_tools_list(self, server):
        """``tools/list`` returns a result containing six tools."""
        request = {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list",
            "params": {},
        }
        reply = server.process_message(request)
        assert "result" in reply
        assert len(reply["result"]["tools"]) == 6

    def test_tools_call(self, server):
        """``tools/call`` dispatches to ``memory_store``."""
        store_arguments = {
            "content": "via process_message",
            "expanded_keywords": "process message test",
        }
        request = {
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {
                "name": "memory_store",
                "arguments": store_arguments,
            },
        }
        reply = server.process_message(request)
        assert "result" in reply
        assert "Stored memory" in reply["result"]["content"][0]["text"]

    def test_unknown_method(self, server):
        """An unknown method yields error code -32601 ("Method not found")."""
        request = {
            "jsonrpc": "2.0",
            "id": 4,
            "method": "unknown/method",
            "params": {},
        }
        reply = server.process_message(request)
        assert "error" in reply
        failure = reply["error"]
        assert failure["code"] == -32601
        assert "Method not found" in failure["message"]

    def test_notification_no_id(self, server):
        """A message without an ``id`` produces no response."""
        notification = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        }
        reply = server.process_message(notification)
        assert reply is None

    def test_jsonrpc_response_format(self, server):
        """The response survives a JSON round trip with its envelope intact."""
        request = {
            "jsonrpc": "2.0",
            "id": 5,
            "method": "initialize",
            "params": {},
        }
        reply = server.process_message(request)
        # Verify it's valid JSON when serialized
        round_tripped = json.loads(json.dumps(reply))
        assert round_tripped["jsonrpc"] == "2.0"
        assert round_tripped["id"] == 5
class TestMemoryCount:
    """Exercise the ``memory_count`` diagnostic tool."""

    def test_count_empty(self, server):
        """On an empty store the report contains "0"."""
        report = server.memory_count({})
        assert "0" in report

    def test_count_after_store(self, server):
        """After one store the report contains "1" and mentions "facts"."""
        seed = {
            "content": "test memory",
            "expanded_keywords": "test memory keywords data",
        }
        server.memory_store(seed)
        report = server.memory_count({})
        assert "1" in report
        assert "facts" in report

    def test_count_multiple_categories(self, server):
        """Each category is reported with its own count."""
        seeds = (
            {
                "content": "a fact",
                "category": "facts",
                "expanded_keywords": "fact test data words",
            },
            {
                "content": "a preference",
                "category": "preferences",
                "expanded_keywords": "preference test data words",
            },
        )
        for seed in seeds:
            server.memory_store(seed)
        report = server.memory_count({})
        assert "facts: 1" in report
        assert "preferences: 1" in report

    def test_count_via_tools_call(self, server):
        """``memory_count`` is reachable through ``handle_tools_call``."""
        call = {
            "name": "memory_count",
            "arguments": {},
        }
        outcome = server.handle_tools_call(call)
        assert not outcome.get("isError", False)
        assert "0" in outcome["content"][0]["text"]
class TestSchemaMigration:
    """Verify the version-based SQLite schema migration."""

    # Value of PRAGMA user_version these tests expect after initialisation.
    EXPECTED_SCHEMA_VERSION = 2

    def test_schema_version_set(self, tmp_path):
        """A freshly created database reports the expected user_version."""
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        # try/finally: a failing assertion must not leak the open connection.
        try:
            cursor = srv.sqlite_conn.cursor()
            version = cursor.execute("PRAGMA user_version").fetchone()[0]
            assert version == self.EXPECTED_SCHEMA_VERSION
        finally:
            srv.sqlite_conn.close()

    def test_migration_idempotent(self, tmp_path):
        """Running _init_sqlite twice should not error."""
        from claude_memory.mcp_server import _init_sqlite
        db_path = str(tmp_path / "test.db")
        # First initialisation creates the schema; close before re-opening.
        conn1, _ = _init_sqlite(db_path)
        conn1.close()
        # Second initialisation runs against the already-initialised file.
        conn2, _ = _init_sqlite(db_path)
        try:
            version = conn2.execute("PRAGMA user_version").fetchone()[0]
            assert version == self.EXPECTED_SCHEMA_VERSION
        finally:
            conn2.close()

    def test_server_id_column_exists(self, tmp_path):
        """The ``memories`` table has a ``server_id`` column."""
        db_path = str(tmp_path / "test.db")
        srv = MemoryServer(sqlite_db_path=db_path)
        # try/finally: a failing assertion must not leak the open connection.
        try:
            cursor = srv.sqlite_conn.cursor()
            cursor.execute("PRAGMA table_info(memories)")
            columns = {row["name"] for row in cursor.fetchall()}
            assert "server_id" in columns
        finally:
            srv.sqlite_conn.close()