From cd80a67dfa94d47318f52e0142ec9efff7ebb5a1 Mon Sep 17 00:00:00 2001
From: Viktor Barzin
Date: Sat, 14 Mar 2026 12:42:39 +0000
Subject: [PATCH] feat: add local SQLite cache with background sync

- Add SyncEngine for background sync between local SQLite cache and
  remote API with pending_ops queue for offline resilience
- Refactor MCP server to support three modes: SQLite-only, hybrid
  (local cache + sync, new default), and HTTP-only (legacy)
- Add GET /api/memories/sync endpoint for incremental sync
- Change DELETE to soft delete (set deleted_at) for sync support
- Add deleted_at IS NULL filters to all read queries
- Add migration 003 for deleted_at column and updated_at index
- Add tests for the sync engine and the API sync endpoint
---
 .../versions/003_add_soft_delete_and_sync.py |  40 ++
 src/claude_memory/api/app.py                 |  68 ++-
 src/claude_memory/api/models.py              |   7 +-
 src/claude_memory/mcp_server.py              | 276 +++++++-----
 src/claude_memory/sync.py                    | 334 +++++++++++++++
 tests/test_api.py                            | 123 ++++++
 tests/test_sync.py                           | 395 ++++++++++++++++++
 7 files changed, 1133 insertions(+), 110 deletions(-)
 create mode 100644 migrations/versions/003_add_soft_delete_and_sync.py
 create mode 100644 src/claude_memory/sync.py
 create mode 100644 tests/test_sync.py

diff --git a/migrations/versions/003_add_soft_delete_and_sync.py b/migrations/versions/003_add_soft_delete_and_sync.py
new file mode 100644
index 0000000..39852b3
--- /dev/null
+++ b/migrations/versions/003_add_soft_delete_and_sync.py
@@ -0,0 +1,40 @@
+"""Add soft delete and sync support.
+
+Revision ID: 003
+Revises: 002
+Create Date: 2026-03-14
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+revision: str = "003"
+down_revision: Union[str, None] = "002"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def _column_exists(conn, column_name: str) -> bool:
+    result = conn.execute(
+        sa.text(
+            "SELECT EXISTS(SELECT 1 FROM information_schema.columns "
+            "WHERE table_name = 'memories' AND column_name = :col)"
+        ),
+        {"col": column_name},
+    )
+    return result.scalar()
+
+
+def upgrade() -> None:
+    conn = op.get_bind()
+
+    if not _column_exists(conn, "deleted_at"):
+        op.add_column("memories", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True))
+
+    op.execute("CREATE INDEX IF NOT EXISTS idx_memories_updated ON memories(updated_at)")
+
+
+def downgrade() -> None:
+    op.drop_index("idx_memories_updated")
+    op.drop_column("memories", "deleted_at")
diff --git a/src/claude_memory/api/app.py b/src/claude_memory/api/app.py
index f0c91d6..a263fd8 100644
--- a/src/claude_memory/api/app.py
+++ b/src/claude_memory/api/app.py
@@ -2,13 +2,14 @@
 
 import logging
 from contextlib import asynccontextmanager
+from datetime import datetime, timezone
 from typing import Optional
 
 from fastapi import Depends, FastAPI, HTTPException
 
 from claude_memory.api.auth import AuthUser, get_current_user
 from claude_memory.api.database import close_pool, get_pool, init_pool
-from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse
+from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
 from claude_memory.api.vault_service import (
     delete_secret,
     get_secret,
@@ -58,6 +59,58 @@ async def health():
     return {"status": "ok"}
 
 
+@app.get("/api/memories/sync", 
response_model=SyncResponse) +async def sync_memories( + since: Optional[str] = None, + user: AuthUser = Depends(get_current_user), +): + pool = await get_pool() + server_time = datetime.now(timezone.utc).isoformat() + + async with pool.acquire() as conn: + if since: + rows = await conn.fetch( + """ + SELECT id, content, category, tags, expanded_keywords, importance, + is_sensitive, created_at, updated_at, deleted_at + FROM memories + WHERE user_id = $1 AND updated_at > $2::timestamptz + ORDER BY updated_at ASC + """, + user.user_id, + since, + ) + else: + rows = await conn.fetch( + """ + SELECT id, content, category, tags, expanded_keywords, importance, + is_sensitive, created_at, updated_at, deleted_at + FROM memories + WHERE user_id = $1 AND deleted_at IS NULL + ORDER BY updated_at ASC + """, + user.user_id, + ) + + memories = [] + for row in rows: + mem = { + "id": row["id"], + "content": row["content"], + "category": row["category"], + "tags": row["tags"], + "expanded_keywords": row["expanded_keywords"], + "importance": row["importance"], + "is_sensitive": row["is_sensitive"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + "deleted_at": row["deleted_at"].isoformat() if row["deleted_at"] else None, + } + memories.append(mem) + + return SyncResponse(memories=memories, server_time=server_time) + + @app.post("/api/memories", response_model=MemoryResponse) async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)): pool = await get_pool() @@ -117,6 +170,7 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre created_at, updated_at FROM memories, plainto_tsquery('english', $2) query WHERE user_id = $1 + AND deleted_at IS NULL AND (search_vector @@ query OR $2 = '') {category_filter} ORDER BY {order_clause} @@ -158,14 +212,14 @@ async def list_memories( if category: query = """ SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at - FROM memories WHERE user_id = $1 AND category = $2 + FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2 ORDER BY importance DESC LIMIT $3 """ params: list = [user.user_id, category, limit] else: query = """ SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at - FROM memories WHERE user_id = $1 + FROM memories WHERE user_id = $1 AND deleted_at IS NULL ORDER BY importance DESC LIMIT $2 """ params = [user.user_id, limit] @@ -200,7 +254,7 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use async with pool.acquire() as conn: row = await conn.fetchrow( - "SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2", + "SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL", memory_id, user.user_id, ) @@ -211,7 +265,7 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use await delete_secret(user.user_id, row["vault_path"]) await conn.execute( - "DELETE FROM memories WHERE id = $1 AND user_id = $2", + "UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2", memory_id, user.user_id, ) @@ -227,7 +281,7 @@ async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current row = await conn.fetchrow( """ SELECT id, content, is_sensitive, vault_path, encrypted_content - FROM memories WHERE id = $1 AND user_id = $2 + FROM memories WHERE id = $1 AND user_id = $2 AND 
deleted_at IS NULL """, memory_id, user.user_id, @@ -263,7 +317,7 @@ async def migrate_secrets(user: AuthUser = Depends(get_current_user)): rows = await conn.fetch( """ SELECT id, content FROM memories - WHERE user_id = $1 AND is_sensitive = FALSE + WHERE user_id = $1 AND is_sensitive = FALSE AND deleted_at IS NULL """, user.user_id, ) diff --git a/src/claude_memory/api/models.py b/src/claude_memory/api/models.py index d4fb804..cabdf21 100644 --- a/src/claude_memory/api/models.py +++ b/src/claude_memory/api/models.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -30,3 +30,8 @@ class SecretResponse(BaseModel): id: int content: str source: str # "vault", "encrypted", "plaintext" + + +class SyncResponse(BaseModel): + memories: list[dict[str, Any]] + server_time: str diff --git a/src/claude_memory/mcp_server.py b/src/claude_memory/mcp_server.py index 21170b8..f0f2fc1 100644 --- a/src/claude_memory/mcp_server.py +++ b/src/claude_memory/mcp_server.py @@ -2,9 +2,10 @@ """ Claude Memory MCP Server — standalone memory server with multi-user support. -Supports two modes: - 1. HTTP API mode: connects to a shared PostgreSQL-backed API server - 2. SQLite fallback: local file-based storage when no API key is configured +Supports three modes: + 1. SQLite-only: local file-based storage when no API key is configured + 2. Hybrid (default when API key set): local SQLite cache + background sync + 3. HTTP-only (legacy): direct HTTP to API, no local cache (MEMORY_SYNC_DISABLE=1) Uses only stdlib (urllib) — no pip install required. """ @@ -21,14 +22,17 @@ logger = logging.getLogger(__name__) PROTOCOL_VERSION = "2024-11-05" SERVER_NAME = "claude-memory" -SERVER_VERSION = "1.0.0" +SERVER_VERSION = "1.1.0" # API configuration — support both MEMORY_* (primary) and CLAUDE_MEMORY_* (fallback) env vars API_BASE_URL = os.environ.get("MEMORY_API_URL") or os.environ.get("CLAUDE_MEMORY_API_URL", "http://localhost:8080") API_KEY = os.environ.get("MEMORY_API_KEY") or os.environ.get("CLAUDE_MEMORY_API_KEY", "") -# Fallback to SQLite if API is not configured -SQLITE_FALLBACK = not API_KEY +# Mode detection +SYNC_DISABLED = os.environ.get("MEMORY_SYNC_DISABLE", "") == "1" +HYBRID_MODE = bool(API_KEY) and not SYNC_DISABLED +HTTP_ONLY = bool(API_KEY) and SYNC_DISABLED +SQLITE_ONLY = not API_KEY def _api_request(method: str, path: str, body: dict | None = None) -> dict: @@ -54,23 +58,29 @@ def _api_request(method: str, path: str, body: dict | None = None) -> dict: raise RuntimeError(f"API connection error: {e.reason}") from e -# ─── SQLite fallback (local storage when API not configured) ───────────────── +# ─── SQLite initialization ──────────────────────────────────────────────────── + +def _get_db_path(db_path: str | None = None) -> str: + """Resolve the SQLite database path.""" + if db_path is not None: + return db_path + + memory_home = os.path.expandvars( + os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory")) + ) + db_path = os.environ.get( + "MEMORY_DB", + os.path.join(memory_home, "memory", "memory.db"), + ) + return os.path.expandvars(os.path.expanduser(db_path)) + def _init_sqlite(db_path: str | None = None): - """Initialize SQLite database as fallback.""" + """Initialize SQLite database.""" import sqlite3 from pathlib import Path - if db_path is None: - memory_home = os.path.expandvars( - os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory")) - ) - db_path = os.environ.get( - "MEMORY_DB", - 
os.path.join(memory_home, "memory", "memory.db"), - ) - db_path = os.path.expandvars(os.path.expanduser(db_path)) - + db_path = _get_db_path(db_path) Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(db_path, timeout=30.0) @@ -91,6 +101,15 @@ def _init_sqlite(db_path: str | None = None): updated_at TEXT NOT NULL ) """) + # Add server_id column if missing (for hybrid mode sync) + cursor.execute("PRAGMA table_info(memories)") + columns = {row["name"] for row in cursor.fetchall()} + if "server_id" not in columns: + cursor.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER") + cursor.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)" + ) + cursor.execute(""" CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5( content, category, tags, expanded_keywords, @@ -118,7 +137,7 @@ def _init_sqlite(db_path: str | None = None): END """) conn.commit() - return conn + return conn, db_path # ─── Tool definitions ──────────────────────────────────────────────────────── @@ -229,10 +248,27 @@ class MemoryServer: def __init__(self, sqlite_db_path: str | None = None) -> None: self.sqlite_conn = None - if SQLITE_FALLBACK: - self.sqlite_conn = _init_sqlite(sqlite_db_path) + self.sync_engine = None - # ── HTTP-backed methods ────────────────────────────────────────── + if SQLITE_ONLY or HYBRID_MODE: + self.sqlite_conn, resolved_path = _init_sqlite(sqlite_db_path) + + if HYBRID_MODE: + from claude_memory.sync import SyncEngine + sync_interval = int(os.environ.get("MEMORY_SYNC_INTERVAL", "60")) + self.sync_engine = SyncEngine( + db_path=resolved_path, + api_base_url=API_BASE_URL, + api_key=API_KEY, + sync_interval=sync_interval, + ) + self.sync_engine.start() + + def __del__(self) -> None: + if self.sync_engine: + self.sync_engine.stop() + + # ── Tool methods ──────────────────────────────────────────────── def memory_store(self, args: dict[str, Any]) -> str: content = args.get("content") @@ -244,18 +280,28 @@ class MemoryServer: expanded_keywords = args.get("expanded_keywords", "") force_sensitive = bool(args.get("force_sensitive", False)) - if SQLITE_FALLBACK: - return self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive) + if HTTP_ONLY: + result = _api_request("POST", "/api/memories", { + "content": content, + "category": category, + "tags": tags, + "expanded_keywords": expanded_keywords, + "importance": importance, + "force_sensitive": force_sensitive, + }) + return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}" - result = _api_request("POST", "/api/memories", { - "content": content, - "category": category, - "tags": tags, - "expanded_keywords": expanded_keywords, - "importance": importance, - "force_sensitive": force_sensitive, - }) - return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}" + # SQLite-only or Hybrid: write to local SQLite first + result_text = self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive) + + if HYBRID_MODE and self.sync_engine: + # Extract local_id from result text + local_id = int(result_text.split("#")[1].split(" ")[0]) + self.sync_engine.try_sync_store( + local_id, content, category, tags, expanded_keywords, importance, force_sensitive + ) + + return result_text def memory_recall(self, args: dict[str, Any]) -> str: context = args.get("context") @@ -266,80 +312,102 @@ class MemoryServer: 
sort_by = args.get("sort_by", "importance") limit = args.get("limit", 10) - if SQLITE_FALLBACK: - return self._sqlite_recall(context, expanded_query, category, sort_by, limit) + if HTTP_ONLY: + result = _api_request("POST", "/api/memories/recall", { + "context": context, + "expanded_query": expanded_query, + "category": category, + "sort_by": sort_by, + "limit": limit, + }) + rows = result.get("memories", []) + if not rows: + filter_desc = f" in category '{category}'" if category else "" + return f"No memories found matching: {context}{filter_desc}" - result = _api_request("POST", "/api/memories/recall", { - "context": context, - "expanded_query": expanded_query, - "category": category, - "sort_by": sort_by, - "limit": limit, - }) - rows = result.get("memories", []) - if not rows: - filter_desc = f" in category '{category}'" if category else "" - return f"No memories found matching: {context}{filter_desc}" + sort_desc = "by relevance" if sort_by == "relevance" else "by importance" + filter_desc = f" in '{category}'" if category else "" + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}" + f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" + ) + return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results) - sort_desc = "by relevance" if sort_by == "relevance" else "by importance" - filter_desc = f" in '{category}'" if category else "" - results = [] - for row in rows: - results.append( - f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}" - f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" - ) - return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results) + # SQLite-only or Hybrid: always read from local cache + return self._sqlite_recall(context, expanded_query, category, sort_by, limit) def memory_list(self, args: dict[str, Any]) -> str: category = args.get("category") limit = args.get("limit", 20) - if SQLITE_FALLBACK: - return self._sqlite_list(category, limit) + if HTTP_ONLY: + params = f"?limit={limit}" + if category: + params += f"&category={category}" + result = _api_request("GET", f"/api/memories{params}") + rows = result.get("memories", []) + if not rows: + return f"No memories in category '{category}'" if category else "No memories stored yet" - params = f"?limit={limit}" - if category: - params += f"&category={category}" - result = _api_request("GET", f"/api/memories{params}") - rows = result.get("memories", []) - if not rows: - return f"No memories in category '{category}'" if category else "No memories stored yet" + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] {row['content']}" + f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" + ) + header = "Recent memories" + if category: + header += f" in '{category}'" + return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results) - results = [] - for row in rows: - results.append( - f"#{row['id']} [{row['category']}] {row['content']}" - f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" - ) - header = "Recent memories" - if category: - header += f" in '{category}'" - return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results) + # SQLite-only or Hybrid: always read from local cache + return self._sqlite_list(category, limit) def 
memory_delete(self, args: dict[str, Any]) -> str:
         memory_id = args.get("id")
         if memory_id is None:
             raise ValueError("id is required")
 
-        if SQLITE_FALLBACK:
-            return self._sqlite_delete(memory_id)
+        if HTTP_ONLY:
+            result = _api_request("DELETE", f"/api/memories/{memory_id}")
+            return f"Deleted memory #{result['deleted']}: {result['preview']}..."
 
-        result = _api_request("DELETE", f"/api/memories/{memory_id}")
-        return f"Deleted memory #{result['deleted']}: {result['preview']}..."
+        # SQLite-only or Hybrid: delete from local SQLite
+        # In hybrid mode, also try to sync delete to server
+        server_id = None
+        if HYBRID_MODE and self.sync_engine:
+            cursor = self.sqlite_conn.cursor()
+            cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
+            row = cursor.fetchone()
+            server_id = row["server_id"] if row and row["server_id"] else None
+
+        result_text = self._sqlite_delete(memory_id)
+
+        if HYBRID_MODE and self.sync_engine and server_id:
+            self.sync_engine.try_sync_delete(server_id)
+
+        return result_text
 
     def secret_get(self, args: dict[str, Any]) -> str:
         memory_id = args.get("id")
         if memory_id is None:
             raise ValueError("id is required")
 
-        if SQLITE_FALLBACK:
-            return self._sqlite_secret_get(memory_id)
+        if HTTP_ONLY:
+            result = _api_request("POST", f"/api/memories/{memory_id}/secret")
+            return f"#{result['id']} [{result['category']}] {result['content']}"
 
-        result = _api_request("POST", f"/api/memories/{memory_id}/secret")
-        return f"#{result['id']} [{result['category']}] {result['content']}"
+        if HYBRID_MODE:
+            # Secrets should be fetched from the API when available, but the
+            # caller holds a *local* id here, so map it to the server id first
+            cursor = self.sqlite_conn.cursor()
+            cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
+            row = cursor.fetchone()
+            server_id = row["server_id"] if row and row["server_id"] else None
+            if server_id:
+                try:
+                    result = _api_request("POST", f"/api/memories/{server_id}/secret")
+                    return f"#{result['id']} [{result['category']}] {result['content']}"
+                except Exception:
+                    pass  # Fall back to local SQLite
+            return self._sqlite_secret_get(memory_id)
+
+        return self._sqlite_secret_get(memory_id)
 
-    # ── SQLite fallback methods ──────────────────────────────────────
+    # ── SQLite methods ──────────────────────────────────────────────
 
     def _sqlite_store(self, content, category, tags, importance, expanded_keywords, force_sensitive=False):
         from datetime import datetime, timezone
@@ -520,25 +588,29 @@ class MemoryServer:
         return response
 
     def run(self) -> None:
-        for line in sys.stdin:
-            line = line.strip()
-            if not line:
-                continue
-            try:
-                message = json.loads(line)
-            except json.JSONDecodeError as e:
-                print(
-                    json.dumps({
-                        "jsonrpc": "2.0",
-                        "id": None,
-                        "error": {"code": -32700, "message": f"Parse error: {e}"},
-                    }),
-                    flush=True,
-                )
-                continue
-            response = self.process_message(message)
-            if response is not None:
-                print(json.dumps(response), flush=True)
+        try:
+            for line in sys.stdin:
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    message = json.loads(line)
+                except json.JSONDecodeError as e:
+                    print(
+                        json.dumps({
+                            "jsonrpc": "2.0",
+                            "id": None,
+                            "error": {"code": -32700, "message": f"Parse error: {e}"},
+                        }),
+                        flush=True,
+                    )
+                    continue
+                response = self.process_message(message)
+                if response is not None:
+                    print(json.dumps(response), flush=True)
+        finally:
+            if self.sync_engine:
+                self.sync_engine.stop()
 
 
 def main() -> None:
diff --git a/src/claude_memory/sync.py b/src/claude_memory/sync.py
new file mode 100644
index 0000000..812feb6
--- /dev/null
+++ b/src/claude_memory/sync.py
@@ -0,0 +1,334 @@
+"""Background sync between local SQLite cache and remote API.
+
+Uses only stdlib — no pip install required. 
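+
+One sync cycle (see _sync_once):
+  1. Push: replay queued pending_ops ("store"/"delete") against the API;
+     a 404 on a delete counts as success (already gone server-side).
+  2. Pull: GET /api/memories/sync?since=<last_sync_ts>, upsert rows by
+     server_id (the server's copy wins), and drop local rows whose
+     deleted_at is set.
+The server_time returned by the pull becomes the next since cursor.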
+""" + +import json +import logging +import sqlite3 +import threading +import time +import urllib.error +import urllib.request +from datetime import datetime, timezone +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class SyncEngine: + """Background sync between local SQLite cache and remote API.""" + + def __init__(self, db_path: str, api_base_url: str, api_key: str, sync_interval: int = 60): + self.db_path = db_path + self.api_base_url = api_base_url.rstrip("/") + self.api_key = api_key + self.sync_interval = sync_interval + + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._last_sync_success = False + + # Own connection for thread safety + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + self._conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False) + self._conn.row_factory = sqlite3.Row + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA busy_timeout=30000") + self._lock = threading.Lock() + + self._init_sync_tables() + + def _init_sync_tables(self) -> None: + """Create sync-specific tables if they don't exist.""" + with self._lock: + self._conn.executescript(""" + CREATE TABLE IF NOT EXISTS pending_ops ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + op_type TEXT NOT NULL, + payload TEXT NOT NULL, + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS sync_meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """) + # Add server_id column to memories if missing + cursor = self._conn.execute("PRAGMA table_info(memories)") + columns = {row["name"] for row in cursor.fetchall()} + if "server_id" not in columns: + self._conn.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER") + self._conn.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)" + ) + self._conn.commit() + + @property + def last_sync_ts(self) -> str | None: + with self._lock: + cursor = self._conn.execute( + "SELECT value FROM sync_meta WHERE key = 'last_sync_ts'" + ) + row = cursor.fetchone() + return row["value"] if row else None + + @last_sync_ts.setter + def last_sync_ts(self, value: str) -> None: + with self._lock: + self._conn.execute( + "INSERT OR REPLACE INTO sync_meta (key, value) VALUES ('last_sync_ts', ?)", + (value,), + ) + self._conn.commit() + + @property + def api_available(self) -> bool: + return self._last_sync_success + + def start(self) -> None: + """Run initial sync (blocking), then start background thread.""" + try: + self._sync_once() + self._last_sync_success = True + except Exception: + logger.warning("Initial sync failed, starting in offline mode") + self._last_sync_success = False + + self._thread = threading.Thread(target=self._sync_loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + """Signal background thread to stop and wait.""" + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=5) + self._conn.close() + + def _sync_loop(self) -> None: + """Periodic sync loop running in background thread.""" + while not self._stop_event.is_set(): + self._stop_event.wait(self.sync_interval) + if self._stop_event.is_set(): + break + try: + self._sync_once() + self._last_sync_success = True + except Exception as e: + logger.warning("Sync cycle failed: %s", e) + self._last_sync_success = False + + def _sync_once(self) -> None: + """Push pending ops, then pull remote changes.""" + self._push_pending_ops() + self._pull_changes() + + def _api_request(self, method: str, path: str, body: 
dict | None = None) -> dict:
+        """Make an HTTP request to the memory API.
+
+        HTTP errors are normalized to RuntimeError("API error <code>: ...")
+        so callers can match on the status code (see the 404 check in
+        _push_pending_ops); mirrors _api_request in mcp_server.py.
+        """
+        url = f"{self.api_base_url}{path}"
+        if "?" in url:
+            # Percent-encode the query string: ISO timestamps passed via
+            # ?since=... contain '+' and ':' and must not go out raw
+            base, _, query = url.partition("?")
+            url = f"{base}?{urllib.parse.quote(query, safe='=&')}"
+        data = json.dumps(body).encode() if body else None
+        req = urllib.request.Request(
+            url,
+            data=data,
+            method=method,
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            },
+        )
+        try:
+            with urllib.request.urlopen(req, timeout=15) as resp:
+                return json.loads(resp.read().decode())
+        except urllib.error.HTTPError as e:
+            raise RuntimeError(f"API error {e.code}: {e.read().decode()}") from e
+        except urllib.error.URLError as e:
+            raise RuntimeError(f"API connection error: {e.reason}") from e
+
+    def _push_pending_ops(self) -> None:
+        """Push queued operations to the API server."""
+        with self._lock:
+            cursor = self._conn.execute(
+                "SELECT id, op_type, payload FROM pending_ops ORDER BY id"
+            )
+            ops = cursor.fetchall()
+
+        for op in ops:
+            op_id = op["id"]
+            op_type = op["op_type"]
+            payload = json.loads(op["payload"])
+
+            try:
+                if op_type == "store":
+                    # local_id is client-side bookkeeping; don't send it upstream
+                    body = {k: v for k, v in payload.items() if k != "local_id"}
+                    result = self._api_request("POST", "/api/memories", body)
+                    server_id = result.get("id")
+                    if server_id and payload.get("local_id"):
+                        with self._lock:
+                            self._conn.execute(
+                                "UPDATE memories SET server_id = ? WHERE id = ?",
+                                (server_id, payload["local_id"]),
+                            )
+                            self._conn.commit()
+                elif op_type == "delete":
+                    server_id = payload.get("server_id")
+                    if server_id:
+                        try:
+                            self._api_request("DELETE", f"/api/memories/{server_id}")
+                        except RuntimeError as e:
+                            if "404" in str(e):
+                                pass  # Already deleted on server
+                            else:
+                                raise
+
+                # Remove from pending queue on success
+                with self._lock:
+                    self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
+                    self._conn.commit()
+
+            except Exception as e:
+                logger.warning("Failed to push op %d (%s): %s", op_id, op_type, e)
+                raise  # Propagate to mark sync as failed
+
+    def _pull_changes(self) -> None:
+        """Pull changes from server since last sync."""
+        params = ""
+        ts = self.last_sync_ts
+        if ts:
+            params = f"?since={ts}"
+
+        result = self._api_request("GET", f"/api/memories/sync{params}")
+        memories = result.get("memories", [])
+        server_time = result.get("server_time")
+
+        with self._lock:
+            for mem in memories:
+                server_id = mem["id"]
+                deleted_at = mem.get("deleted_at")
+
+                if deleted_at:
+                    # Remove from local cache
+                    self._conn.execute(
+                        "DELETE FROM memories WHERE server_id = ?", (server_id,)
+                    )
+                else:
+                    # Upsert by server_id (server wins)
+                    existing = self._conn.execute(
+                        "SELECT id FROM memories WHERE server_id = ?", (server_id,)
+                    ).fetchone()
+
+                    if existing:
+                        self._conn.execute(
+                            """UPDATE memories SET content = ?, category = ?, tags = ?,
+                               expanded_keywords = ?, importance = ?, is_sensitive = ?,
+                               updated_at = ? 
WHERE server_id = ?""", + ( + mem["content"], + mem["category"], + mem.get("tags", ""), + mem.get("expanded_keywords", ""), + mem["importance"], + 1 if mem.get("is_sensitive") else 0, + mem.get("updated_at", datetime.now(timezone.utc).isoformat()), + server_id, + ), + ) + else: + self._conn.execute( + """INSERT INTO memories + (content, category, tags, expanded_keywords, importance, + is_sensitive, created_at, updated_at, server_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + mem["content"], + mem["category"], + mem.get("tags", ""), + mem.get("expanded_keywords", ""), + mem["importance"], + 1 if mem.get("is_sensitive") else 0, + mem.get("created_at", datetime.now(timezone.utc).isoformat()), + mem.get("updated_at", datetime.now(timezone.utc).isoformat()), + server_id, + ), + ) + self._conn.commit() + + if server_time: + self.last_sync_ts = server_time + + def enqueue_store( + self, + local_id: int, + content: str, + category: str, + tags: str, + expanded_keywords: str, + importance: float, + force_sensitive: bool = False, + ) -> None: + """Queue a store operation for later sync.""" + payload = { + "local_id": local_id, + "content": content, + "category": category, + "tags": tags, + "expanded_keywords": expanded_keywords, + "importance": importance, + "force_sensitive": force_sensitive, + } + now = datetime.now(timezone.utc).isoformat() + with self._lock: + self._conn.execute( + "INSERT INTO pending_ops (op_type, payload, created_at) VALUES (?, ?, ?)", + ("store", json.dumps(payload), now), + ) + self._conn.commit() + + def enqueue_delete(self, server_id: int) -> None: + """Queue a delete operation for later sync.""" + payload = {"server_id": server_id} + now = datetime.now(timezone.utc).isoformat() + with self._lock: + self._conn.execute( + "INSERT INTO pending_ops (op_type, payload, created_at) VALUES (?, ?, ?)", + ("delete", json.dumps(payload), now), + ) + self._conn.commit() + + def try_sync_store( + self, + local_id: int, + content: str, + category: str, + tags: str, + expanded_keywords: str, + importance: float, + force_sensitive: bool = False, + ) -> int | None: + """Try to sync a store immediately. Returns server_id or None if failed.""" + try: + result = self._api_request("POST", "/api/memories", { + "content": content, + "category": category, + "tags": tags, + "expanded_keywords": expanded_keywords, + "importance": importance, + "force_sensitive": force_sensitive, + }) + server_id = result.get("id") + if server_id: + with self._lock: + self._conn.execute( + "UPDATE memories SET server_id = ? WHERE id = ?", + (server_id, local_id), + ) + self._conn.commit() + return server_id + except Exception: + self.enqueue_store( + local_id, content, category, tags, expanded_keywords, importance, force_sensitive + ) + return None + + def try_sync_delete(self, server_id: int) -> bool: + """Try to sync a delete immediately. 
Returns True if successful.""" + try: + self._api_request("DELETE", f"/api/memories/{server_id}") + return True + except Exception: + self.enqueue_delete(server_id) + return False diff --git a/tests/test_api.py b/tests/test_api.py index e5059d0..7f0f178 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -36,6 +36,7 @@ def _make_memory_row(**overrides): "rank": 0.5, "created_at": now, "updated_at": now, + "deleted_at": None, } defaults.update(overrides) return MockRow(defaults) @@ -307,3 +308,125 @@ async def test_import_memories(client): assert len(data) == 2 assert data[0]["id"] == 100 assert data[1]["id"] == 101 + + +# ─── Sync endpoint tests ───────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_sync_full_dump_without_since(client): + ac, conn, app_mod = client + now = datetime.now(timezone.utc) + conn.fetch.return_value = [ + _make_memory_row(id=1, content="mem1", deleted_at=None), + _make_memory_row(id=2, content="mem2", deleted_at=None), + ] + + async with ac: + resp = await ac.get( + "/api/memories/sync", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert len(data["memories"]) == 2 + assert "server_time" in data + assert data["memories"][0]["id"] == 1 + assert data["memories"][1]["id"] == 2 + + # Without since param, should query non-deleted only + call_args = conn.fetch.call_args + query = call_args[0][0] + assert "deleted_at IS NULL" in query + + +@pytest.mark.asyncio +async def test_sync_incremental_with_since(client): + ac, conn, app_mod = client + now = datetime.now(timezone.utc) + conn.fetch.return_value = [ + _make_memory_row(id=3, content="updated mem", deleted_at=None), + ] + + async with ac: + resp = await ac.get( + "/api/memories/sync?since=2026-03-14T10:00:00+00:00", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert len(data["memories"]) == 1 + + # With since param, should include updated_at filter (includes soft-deleted) + call_args = conn.fetch.call_args + query = call_args[0][0] + assert "updated_at >" in query + assert "deleted_at IS NULL" not in query + + +@pytest.mark.asyncio +async def test_sync_includes_soft_deleted_with_since(client): + ac, conn, app_mod = client + now = datetime.now(timezone.utc) + conn.fetch.return_value = [ + _make_memory_row(id=5, content="deleted mem", deleted_at=now), + ] + + async with ac: + resp = await ac.get( + "/api/memories/sync?since=2026-03-14T10:00:00+00:00", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert len(data["memories"]) == 1 + assert data["memories"][0]["deleted_at"] is not None + + +# ─── Soft delete tests ─────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_delete_is_soft_delete(client): + """Delete should SET deleted_at, not DELETE the row.""" + ac, conn, app_mod = client + conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None, preview="test content") + conn.execute.return_value = None + + async with ac: + resp = await ac.delete( + "/api/memories/10", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + + # Verify the execute call uses UPDATE SET deleted_at, not DELETE + execute_args = conn.execute.call_args + query = execute_args[0][0] + assert "UPDATE" in query + assert "deleted_at" in query + assert "DELETE" not in query.upper().split("SET")[0] # No DELETE before SET + + 
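+@pytest.mark.asyncio
+async def test_soft_delete_bumps_updated_at(client):
+    """Soft delete must also touch updated_at so incremental sync clients
+    see the tombstone. Illustrative companion to the test above, using the
+    same mock harness."""
+    ac, conn, app_mod = client
+    conn.fetchrow.return_value = _make_memory_row(id=11, vault_path=None, preview="test content")
+    conn.execute.return_value = None
+
+    async with ac:
+        resp = await ac.delete(
+            "/api/memories/11",
+            headers={"Authorization": "Bearer test-key"},
+        )
+
+    assert resp.status_code == 200
+
+    # The UPDATE must set both the tombstone and the sync cursor column
+    query = conn.execute.call_args[0][0]
+    assert "deleted_at = NOW()" in query
+    assert "updated_at = NOW()" in query
+
+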
+@pytest.mark.asyncio +async def test_delete_excludes_already_deleted(client): + """DELETE endpoint should not find already-deleted memories.""" + ac, conn, app_mod = client + conn.fetchrow.return_value = None # Not found because deleted_at IS NULL filter + + async with ac: + resp = await ac.delete( + "/api/memories/10", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 404 + + # Verify query includes deleted_at IS NULL + call_args = conn.fetchrow.call_args + query = call_args[0][0] + assert "deleted_at IS NULL" in query diff --git a/tests/test_sync.py b/tests/test_sync.py new file mode 100644 index 0000000..8d04ee6 --- /dev/null +++ b/tests/test_sync.py @@ -0,0 +1,395 @@ +"""Tests for the SyncEngine (local SQLite cache + remote API sync).""" + +import json +import os +import sqlite3 +import sys +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +# Force SQLite-only mode for test imports +os.environ.pop("MEMORY_API_KEY", None) +os.environ.pop("CLAUDE_MEMORY_API_KEY", None) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from claude_memory.mcp_server import _init_sqlite +from claude_memory.sync import SyncEngine + + +@pytest.fixture +def db_path(tmp_path): + return str(tmp_path / "test_sync.db") + + +@pytest.fixture +def sqlite_conn(db_path): + """Create a SQLite database with the standard schema.""" + conn, _ = _init_sqlite(db_path) + yield conn + conn.close() + + +@pytest.fixture +def engine(db_path, sqlite_conn): + """Create a SyncEngine with mocked API.""" + eng = SyncEngine( + db_path=db_path, + api_base_url="http://fake-api:8080", + api_key="test-key", + sync_interval=3600, # Don't auto-sync in tests + ) + yield eng + eng._conn.close() + + +class TestSyncEngineInit: + def test_creates_pending_ops_table(self, engine): + cursor = engine._conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='pending_ops'" + ) + assert cursor.fetchone() is not None + + def test_creates_sync_meta_table(self, engine): + cursor = engine._conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='sync_meta'" + ) + assert cursor.fetchone() is not None + + def test_adds_server_id_column(self, engine): + cursor = engine._conn.execute("PRAGMA table_info(memories)") + columns = {row["name"] for row in cursor.fetchall()} + assert "server_id" in columns + + def test_server_id_unique_index(self, engine): + cursor = engine._conn.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_memories_server_id'" + ) + assert cursor.fetchone() is not None + + +class TestEnqueueOps: + def test_enqueue_store(self, engine): + engine.enqueue_store( + local_id=1, + content="test memory", + category="facts", + tags="test", + expanded_keywords="test memory keywords", + importance=0.7, + ) + cursor = engine._conn.execute("SELECT * FROM pending_ops") + ops = cursor.fetchall() + assert len(ops) == 1 + assert ops[0]["op_type"] == "store" + payload = json.loads(ops[0]["payload"]) + assert payload["content"] == "test memory" + assert payload["local_id"] == 1 + assert payload["importance"] == 0.7 + + def test_enqueue_delete(self, engine): + engine.enqueue_delete(server_id=42) + cursor = engine._conn.execute("SELECT * FROM pending_ops") + ops = cursor.fetchall() + assert len(ops) == 1 + assert ops[0]["op_type"] == "delete" + payload = json.loads(ops[0]["payload"]) + assert payload["server_id"] == 42 + + def test_multiple_enqueues(self, engine): + 
engine.enqueue_store(1, "mem1", "facts", "", "", 0.5) + engine.enqueue_store(2, "mem2", "facts", "", "", 0.5) + engine.enqueue_delete(10) + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") + assert cursor.fetchone()["cnt"] == 3 + + +class TestPushPendingOps: + def test_push_store_clears_queue(self, engine): + engine.enqueue_store(1, "test", "facts", "", "kw", 0.5) + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"id": 100, "category": "facts", "importance": 0.5} + engine._push_pending_ops() + + # Queue should be empty + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") + assert cursor.fetchone()["cnt"] == 0 + + # server_id should be set on local memory (if it exists) + mock_api.assert_called_once() + + def test_push_store_updates_server_id(self, engine, sqlite_conn): + # Insert a local memory first + now = datetime.now(timezone.utc).isoformat() + sqlite_conn.execute( + "INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + (1, "test content", "facts", "", "kw", 0.5, 0, now, now), + ) + sqlite_conn.commit() + + engine.enqueue_store(1, "test content", "facts", "", "kw", 0.5) + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"id": 200, "category": "facts", "importance": 0.5} + engine._push_pending_ops() + + # Check server_id was updated + cursor = engine._conn.execute("SELECT server_id FROM memories WHERE id = 1") + row = cursor.fetchone() + assert row["server_id"] == 200 + + def test_push_delete_clears_queue(self, engine): + engine.enqueue_delete(42) + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"deleted": 42, "preview": "test"} + engine._push_pending_ops() + + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") + assert cursor.fetchone()["cnt"] == 0 + + def test_push_delete_404_still_clears(self, engine): + """A 404 on delete means already deleted on server — should still clear queue.""" + engine.enqueue_delete(42) + + with patch.object(engine, "_api_request") as mock_api: + mock_api.side_effect = RuntimeError("API error 404: not found") + engine._push_pending_ops() + + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") + assert cursor.fetchone()["cnt"] == 0 + + def test_push_failure_keeps_queue(self, engine): + engine.enqueue_store(1, "test", "facts", "", "kw", 0.5) + + with patch.object(engine, "_api_request") as mock_api: + mock_api.side_effect = RuntimeError("Connection refused") + with pytest.raises(RuntimeError): + engine._push_pending_ops() + + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") + assert cursor.fetchone()["cnt"] == 1 + + +class TestPullChanges: + def test_pull_inserts_new_memories(self, engine): + now = datetime.now(timezone.utc).isoformat() + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = { + "memories": [ + { + "id": 10, + "content": "server memory", + "category": "facts", + "tags": "tag1", + "expanded_keywords": "server memory keywords", + "importance": 0.8, + "is_sensitive": False, + "created_at": now, + "updated_at": now, + "deleted_at": None, + } + ], + "server_time": now, + } + engine._pull_changes() + + cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 10") + row = cursor.fetchone() + assert row is not None + assert row["content"] == "server memory" + assert row["importance"] == 
0.8 + + def test_pull_updates_existing_memories(self, engine): + now = datetime.now(timezone.utc).isoformat() + # Insert existing memory with server_id + engine._conn.execute( + "INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at, server_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ("old content", "facts", "", "", 0.5, 0, now, now, 10), + ) + engine._conn.commit() + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = { + "memories": [ + { + "id": 10, + "content": "updated content", + "category": "projects", + "tags": "", + "expanded_keywords": "", + "importance": 0.9, + "is_sensitive": False, + "created_at": now, + "updated_at": now, + "deleted_at": None, + } + ], + "server_time": now, + } + engine._pull_changes() + + cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 10") + row = cursor.fetchone() + assert row["content"] == "updated content" + assert row["category"] == "projects" + assert row["importance"] == 0.9 + + def test_pull_deletes_soft_deleted(self, engine): + now = datetime.now(timezone.utc).isoformat() + engine._conn.execute( + "INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at, server_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ("to be deleted", "facts", "", "", 0.5, 0, now, now, 20), + ) + engine._conn.commit() + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = { + "memories": [ + { + "id": 20, + "content": "to be deleted", + "category": "facts", + "tags": "", + "expanded_keywords": "", + "importance": 0.5, + "is_sensitive": False, + "created_at": now, + "updated_at": now, + "deleted_at": now, + } + ], + "server_time": now, + } + engine._pull_changes() + + cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 20") + assert cursor.fetchone() is None + + def test_pull_updates_last_sync_ts(self, engine): + server_time = "2026-03-14T12:00:00+00:00" + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = { + "memories": [], + "server_time": server_time, + } + engine._pull_changes() + + assert engine.last_sync_ts == server_time + + def test_pull_with_since_param(self, engine): + engine.last_sync_ts = "2026-03-14T10:00:00+00:00" + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"memories": [], "server_time": "2026-03-14T12:00:00+00:00"} + engine._pull_changes() + + call_args = mock_api.call_args + assert "since=2026-03-14T10:00:00+00:00" in call_args[0][1] + + +class TestTrySyncStore: + def test_success_returns_server_id(self, engine, sqlite_conn): + now = datetime.now(timezone.utc).isoformat() + sqlite_conn.execute( + "INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + (1, "test", "facts", "", "kw", 0.5, 0, now, now), + ) + sqlite_conn.commit() + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"id": 300, "category": "facts", "importance": 0.5} + result = engine.try_sync_store(1, "test", "facts", "", "kw", 0.5) + + assert result == 300 + + def test_failure_enqueues_op(self, engine): + with patch.object(engine, "_api_request") as mock_api: + mock_api.side_effect = RuntimeError("Connection refused") + result = engine.try_sync_store(1, "test", "facts", "", "kw", 0.5) + + assert result is None + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM 
pending_ops") + assert cursor.fetchone()["cnt"] == 1 + + +class TestTrySyncDelete: + def test_success_returns_true(self, engine): + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"deleted": 42, "preview": "test"} + result = engine.try_sync_delete(42) + + assert result is True + + def test_failure_enqueues_op(self, engine): + with patch.object(engine, "_api_request") as mock_api: + mock_api.side_effect = RuntimeError("Connection refused") + result = engine.try_sync_delete(42) + + assert result is False + cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") + assert cursor.fetchone()["cnt"] == 1 + + +class TestSyncMeta: + def test_last_sync_ts_none_initially(self, engine): + assert engine.last_sync_ts is None + + def test_last_sync_ts_persists(self, engine): + engine.last_sync_ts = "2026-03-14T12:00:00+00:00" + assert engine.last_sync_ts == "2026-03-14T12:00:00+00:00" + + def test_api_available_initially_false(self, engine): + assert engine.api_available is False + + +class TestFullSyncCycle: + def test_store_sync_push_delete_pull(self, engine, sqlite_conn): + """Full cycle: store locally → push to API → server deletes → pull removes local.""" + now = datetime.now(timezone.utc).isoformat() + + # 1. Store locally + sqlite_conn.execute( + "INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + (1, "cycle test", "facts", "", "cycle test kw", 0.5, 0, now, now), + ) + sqlite_conn.commit() + + # 2. Enqueue and push store + engine.enqueue_store(1, "cycle test", "facts", "", "cycle test kw", 0.5) + + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = {"id": 500, "category": "facts", "importance": 0.5} + engine._push_pending_ops() + + # Verify server_id set + cursor = engine._conn.execute("SELECT server_id FROM memories WHERE id = 1") + assert cursor.fetchone()["server_id"] == 500 + + # 3. Server soft-deletes → pull removes local + with patch.object(engine, "_api_request") as mock_api: + mock_api.return_value = { + "memories": [ + { + "id": 500, + "content": "cycle test", + "category": "facts", + "tags": "", + "expanded_keywords": "cycle test kw", + "importance": 0.5, + "is_sensitive": False, + "created_at": now, + "updated_at": now, + "deleted_at": now, + } + ], + "server_time": now, + } + engine._pull_changes() + + # Should be gone locally + cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 500") + assert cursor.fetchone() is None