resilient memory sync: decouple push/pull, startup full resync, auth failure handling

- Decouple push and pull in _sync_once() so pull always runs even if push fails
- Add startup full resync to catch drift from other agents and schema changes
- Add periodic full resync every ~10 minutes for continuous drift correction
- Add auth failure detection (401/403) with graceful SQLite-only degradation
- Add /api/auth-check endpoint for lightweight key validation
- Add retry cap (5 attempts) on pending ops to prevent infinite queue buildup
- Add orphan reconciliation: push local-only records with content dedup
- Add memory_count MCP tool for sync diagnostics
- Add version-based SQLite schema migration (PRAGMA user_version)
- Fix API key in ~/.claude.json to match server
- Update README with sync resilience docs, test structure, project layout
- Add 30 new tests covering all new behaviors (155 total, all passing)
This commit is contained in:
Viktor Barzin 2026-03-16 18:35:09 +00:00
parent a18b94d310
commit e47efee6b6
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 948 additions and 134 deletions

View file

@ -44,6 +44,7 @@ Claude has direct access to these tools during conversation:
| `memory_list` | List recent memories, optionally filtered by category | | `memory_list` | List recent memories, optionally filtered by category |
| `memory_delete` | Delete a memory by ID | | `memory_delete` | Delete a memory by ID |
| `secret_get` | Retrieve the decrypted content of a sensitive memory | | `secret_get` | Retrieve the decrypted content of a sensitive memory |
| `memory_count` | Get memory counts by category and sync status diagnostics |
### Memory Categories ### Memory Categories
Memories are organized into: `facts`, `preferences`, `projects`, `people`, `decisions` Memories are organized into: `facts`, `preferences`, `projects`, `people`, `decisions`
@ -63,6 +64,7 @@ Claude Code Session
│ │ compaction │ │ memory_list │ │ │ │ compaction │ │ memory_list │ │
│ │ auto-approve │ │ memory_delete │ │ │ │ auto-approve │ │ memory_delete │ │
│ └──────────────────┘ │ secret_get │ │ │ └──────────────────┘ │ secret_get │ │
│ │ memory_count │ │
│ └─────────┬──────────┘ │ │ └─────────┬──────────┘ │
│ │ │ │ │ │
│ ┌─────────────────────────┼──────────┐ │ │ ┌─────────────────────────┼──────────┐ │
@ -89,6 +91,17 @@ Claude Code Session
└──────────────────────┘ └──────────────────────┘
``` ```
### Sync Resilience
The SyncEngine is designed to handle failures gracefully:
- **Decoupled push/pull** — push failures never block pull. Remote changes from other agents always flow in.
- **Auth failure detection** — on 401/403, the engine sets an auth-failed flag, logs a clear warning, and degrades to SQLite-only mode. A periodic health check detects when auth is restored.
- **Startup full resync** — on MCP server start, a full cache replacement runs to catch drift from other agents, deleted records, and schema changes.
- **Periodic full resync** — every ~10 minutes, a full resync replaces incremental sync to catch any drift.
- **Retry cap** — individual pending ops are retried up to 5 times, then permanently skipped to prevent queue buildup.
- **Orphan reconciliation** — local records that never synced are deduplicated against server content before push.
## Search Algorithm ## Search Algorithm
Memory recall uses two different full-text search backends depending on the operating mode. Both follow the same query building pattern: the `context` and `expanded_query` fields from the user's recall request are concatenated into a single search string, then processed into a backend-specific query. Memory recall uses two different full-text search backends depending on the operating mode. Both follow the same query building pattern: the `context` and `expanded_query` fields from the user's recall request are concatenated into a single search string, then processed into a backend-specific query.
@ -258,7 +271,7 @@ export MEMORY_API_KEY="your-api-key"
### Option 2: Manual MCP Config ### Option 2: Manual MCP Config
Add to `~/.claude/settings.json`: Add to `~/.claude.json` under `mcpServers`:
```json ```json
{ {
@ -266,8 +279,9 @@ Add to `~/.claude/settings.json`:
"claude_memory": { "claude_memory": {
"type": "stdio", "type": "stdio",
"command": "python3", "command": "python3",
"args": ["-m", "claude_memory.mcp_server"], "args": ["/path/to/claude-memory-mcp/src/claude_memory/mcp_server.py"],
"env": { "env": {
"PYTHONPATH": "/path/to/claude-memory-mcp/src",
"MEMORY_API_URL": "https://your-server.example.com", "MEMORY_API_URL": "https://your-server.example.com",
"MEMORY_API_KEY": "your-api-key" "MEMORY_API_KEY": "your-api-key"
} }
@ -276,7 +290,7 @@ Add to `~/.claude/settings.json`:
} }
``` ```
Omit the `env` block for SQLite-only mode. Requires `pip install claude-memory-mcp`. Omit `MEMORY_API_URL` and `MEMORY_API_KEY` for SQLite-only mode.
### Verify ### Verify
@ -394,7 +408,7 @@ The auto-learn hook runs asynchronously after each Claude response. It operates
**Judge fallback chain:** Claude CLI (haiku model) → local Ollama (qwen2.5:3b/llama3.2:3b/gemma2:2b/phi3:mini) → heuristic pattern matching (keyword-based extraction from user messages). **Judge fallback chain:** Claude CLI (haiku model) → local Ollama (qwen2.5:3b/llama3.2:3b/gemma2:2b/phi3:mini) → heuristic pattern matching (keyword-based extraction from user messages).
Extracted events are deduplicated via SHA-256 content hashing. Each event is stored to both the MCP memory database and a `~/.claude/projects/<project>/memory/auto-learned.md` markdown file. Extracted events are deduplicated via SHA-256 content hashing. Each event is stored to the MCP memory database.
### Debug ### Debug
@ -422,7 +436,8 @@ In hybrid mode, a `SyncEngine` runs in a daemon thread with its own SQLite conne
| Endpoint | Method | Description | | Endpoint | Method | Description |
|----------|--------|-------------| |----------|--------|-------------|
| `/health` | GET | Health check | | `/health` | GET | Health check (unauthenticated) |
| `/api/auth-check` | GET | Validate API key without side effects |
| `/api/memories` | POST | Store a memory | | `/api/memories` | POST | Store a memory |
| `/api/memories` | GET | List memories (`?category=facts&limit=20`) | | `/api/memories` | GET | List memories (`?category=facts&limit=20`) |
| `/api/memories/recall` | POST | Search memories by context and expanded query | | `/api/memories/recall` | POST | Search memories by context and expanded query |
@ -463,6 +478,8 @@ Aliases `CLAUDE_MEMORY_API_URL` and `CLAUDE_MEMORY_API_KEY` are also supported.
## Database Migrations ## Database Migrations
### PostgreSQL (API Server)
Migrations run automatically on API server startup. To run manually: Migrations run automatically on API server startup. To run manually:
```bash ```bash
@ -477,21 +494,90 @@ Three migrations:
All migrations are idempotent (check column/table existence before altering). All migrations are idempotent (check column/table existence before altering).
### SQLite (MCP Client)
SQLite schema is versioned via `PRAGMA user_version` and migrated automatically on startup. Current version: **2**.
| Version | Migration |
|---------|-----------|
| 1 | Add `server_id` column to `memories` table |
| 2 | Add `retry_count` column to `pending_ops` table |
## Development ## Development
### Prerequisites
- Python 3.11+
- For API server tests: `httpx`, `pytest-asyncio`
### Quick Start
```bash ```bash
git clone https://github.com/ViktorBarzin/claude-memory-mcp.git git clone https://github.com/ViktorBarzin/claude-memory-mcp.git
cd claude-memory-mcp cd claude-memory-mcp
python -m venv .venv python -m venv .venv
source .venv/bin/activate source .venv/bin/activate
pip install -e ".[api,dev]" pip install -e ".[api,dev]"
```
### Running Tests
```bash
# All tests
pytest tests/ -v pytest tests/ -v
# Individual test suites
pytest tests/test_sync.py -v # SyncEngine (client-side sync resilience)
pytest tests/test_mcp_server.py -v # MCP server (SQLite, tools, protocol)
pytest tests/test_api.py -v # API server (FastAPI endpoints)
# Linting
ruff check src/ tests/ ruff check src/ tests/
mypy src/claude_memory/ --strict mypy src/claude_memory/ --strict
``` ```
The MCP server itself (`mcp_server.py` and `sync.py`) uses **stdlib only** — no pip install needed on the client side. The `[api]` extra adds FastAPI, asyncpg, uvicorn, etc. for the server. The MCP server itself (`mcp_server.py` and `sync.py`) uses **stdlib only** — no pip install needed on the client side. The `[api]` extra adds FastAPI, asyncpg, uvicorn, etc. for the server.
### Test Structure
| File | Tests | What it covers |
|------|-------|---------------|
| `test_sync.py` | SyncEngine unit tests | Push/pull, auth failure handling, retry caps, full resync, orphan reconciliation, decoupled push/pull, diagnostics |
| `test_mcp_server.py` | MCP server unit tests | SQLite CRUD, FTS search, tool dispatch, MCP protocol, memory_count, schema migration |
| `test_api.py` | API server integration tests | All REST endpoints, auth, user isolation, soft delete, sync, secrets, import |
| `test_auth.py` | Auth module tests | Single/multi-user auth, key mapping |
| `test_credential_detector.py` | Credential detection | Pattern matching for secrets |
| `test_crypto.py` | Encryption tests | AES-256 encrypt/decrypt |
| `test_vault_client.py` | Vault integration | Secret storage/retrieval |
### Project Structure
```
claude-memory-mcp/
├── src/claude_memory/
│ ├── mcp_server.py # MCP server entry point (stdio NDJSON)
│ ├── sync.py # Background sync engine (SQLite ↔ API)
│ ├── credential_detector.py # Sensitive content detection
│ └── api/
│ ├── app.py # FastAPI application
│ ├── auth.py # API key authentication
│ ├── database.py # asyncpg connection pool
│ ├── models.py # Pydantic models
│ └── vault_service.py # HashiCorp Vault integration
├── tests/ # pytest test suite
├── hooks/ # Claude Code hooks (auto-recall, auto-learn, etc.)
├── docker/ # Docker Compose for API server + PostgreSQL
├── deploy/ # Kubernetes manifests and Helm chart
└── pyproject.toml # Package config (hatchling)
```
### Key Design Decisions
- **stdlib-only MCP server**: The MCP server (`mcp_server.py`, `sync.py`) uses only Python stdlib — no pip install required for the client side. This ensures it works in any Claude Code environment without dependency management.
- **NDJSON transport**: Claude Code uses NDJSON (one JSON per line) for stdio MCP, not Content-Length framing.
- **Non-blocking startup**: MCP server startup must complete in ~15s or Claude Code times out. All network calls are deferred to background threads.
- **Suppress stderr**: Any stderr output during MCP startup causes Claude Code to reject the server.
## License ## License
Apache License 2.0. See [LICENSE](LICENSE) for details. Apache License 2.0. See [LICENSE](LICENSE) for details.

View file

@ -12,7 +12,7 @@ haiku to detect learnings worth persisting:
Features: Features:
- Multi-turn context window (last 5 exchanges by default) - Multi-turn context window (last 5 exchanges by default)
- State tracking to avoid duplicate extraction - State tracking to avoid duplicate extraction
- Writes to memory API/SQLite AND auto-memory markdown files - Writes to memory API/SQLite only
- Throttled deep extraction: full window every ~5 turns, single-turn otherwise - Throttled deep extraction: full window every ~5 turns, single-turn otherwise
Runs with async: true does NOT block the user. Runs with async: true does NOT block the user.
@ -252,36 +252,6 @@ def _store_via_sqlite(content, category, tags, importance, expanded_keywords):
conn.close() conn.close()
def _append_to_auto_memory(content: str, event_type: str) -> None:
"""Append a learning to the auto-memory markdown file for the current project."""
# Find the project memory directory based on CWD
cwd = os.getcwd()
# Claude Code stores project memory at ~/.claude/projects/<escaped-path>/memory/
escaped = cwd.replace("/", "-")
if escaped.startswith("-"):
escaped = escaped[1:] # Remove leading dash
memory_dir = Path.home() / ".claude" / "projects" / f"-{escaped}" / "memory"
if not memory_dir.exists():
# Try without the leading dash
memory_dir = Path.home() / ".claude" / "projects" / escaped / "memory"
if not memory_dir.exists():
return
auto_learn_file = memory_dir / "auto-learned.md"
now = datetime.now(timezone.utc).strftime("%Y-%m-%d")
header = "# Auto-Learned Knowledge\n\nAutomatically extracted by the auto-learn hook. Review periodically and promote valuable entries to MEMORY.md.\n\n"
if not auto_learn_file.exists():
auto_learn_file.write_text(header)
# Append the new learning
with open(auto_learn_file, "a") as f:
f.write(f"- [{now}] **{event_type}**: {content}\n")
def _parse_llm_response(response_text: str) -> list[dict]: def _parse_llm_response(response_text: str) -> list[dict]:
"""Parse LLM response text into events list.""" """Parse LLM response text into events list."""
response_text = response_text.strip() response_text = response_text.strip()
@ -485,12 +455,6 @@ def _store_events(events: list[dict], extracted_hashes: list[str]) -> list[str]:
except Exception: except Exception:
pass pass
# Also append to auto-memory markdown
try:
_append_to_auto_memory(content, event_type)
except Exception:
pass
new_hashes.append(h) new_hashes.append(h)
return new_hashes return new_hashes

View file

@ -59,6 +59,12 @@ async def health() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}
@app.get("/api/auth-check")
async def auth_check(user: AuthUser = Depends(get_current_user)) -> dict[str, str]:
"""Validate API key without doing any real work."""
return {"status": "ok", "user_id": user.user_id}
@app.get("/api/memories/sync", response_model=SyncResponse) @app.get("/api/memories/sync", response_model=SyncResponse)
async def sync_memories( async def sync_memories(
since: Optional[str] = None, since: Optional[str] = None,

View file

@ -86,6 +86,41 @@ def _get_db_path(db_path: str | None = None) -> str:
return resolved return resolved
SCHEMA_VERSION = 2
def _migrate_sqlite(conn: sqlite3.Connection) -> None:
"""Version-based SQLite schema migrations."""
current = conn.execute("PRAGMA user_version").fetchone()[0]
if current < 1:
# Add server_id column for hybrid mode sync
cursor = conn.execute("PRAGMA table_info(memories)")
columns = {row["name"] for row in cursor.fetchall()}
if "server_id" not in columns:
conn.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER")
conn.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
)
if current < 2:
# Ensure pending_ops has retry_count (sync.py also handles this, but belt-and-suspenders)
conn.execute("""
CREATE TABLE IF NOT EXISTS pending_ops (
id INTEGER PRIMARY KEY AUTOINCREMENT,
op_type TEXT NOT NULL,
payload TEXT NOT NULL,
created_at TEXT NOT NULL,
retry_count INTEGER DEFAULT 0
)
""")
# Add retry_count if pending_ops already exists without it
cursor = conn.execute("PRAGMA table_info(pending_ops)")
po_columns = {row["name"] for row in cursor.fetchall()}
if "retry_count" not in po_columns:
conn.execute("ALTER TABLE pending_ops ADD COLUMN retry_count INTEGER DEFAULT 0")
conn.execute(f"PRAGMA user_version = {SCHEMA_VERSION}")
conn.commit()
def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]: def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
"""Initialize SQLite database.""" """Initialize SQLite database."""
from pathlib import Path from pathlib import Path
@ -111,14 +146,8 @@ def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
updated_at TEXT NOT NULL updated_at TEXT NOT NULL
) )
""") """)
# Add server_id column if missing (for hybrid mode sync) # Version-based schema migrations
cursor.execute("PRAGMA table_info(memories)") _migrate_sqlite(conn)
columns = {row["name"] for row in cursor.fetchall()}
if "server_id" not in columns:
cursor.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER")
cursor.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
)
cursor.execute(""" cursor.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5( CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
@ -250,6 +279,14 @@ TOOLS = [
"required": ["id"], "required": ["id"],
}, },
}, },
{
"name": "memory_count",
"description": "Get memory counts by category from local cache and sync status. Useful for diagnostics.",
"inputSchema": {
"type": "object",
"properties": {},
},
},
] ]
@ -418,6 +455,32 @@ class MemoryServer:
return self._sqlite_secret_get(memory_id) return self._sqlite_secret_get(memory_id)
def memory_count(self, args: dict[str, Any]) -> str:
if self.sync_engine:
counts = self.sync_engine.get_counts()
lines = [f"Local memories: {counts['total']}"]
for cat, n in counts["by_category"].items():
lines.append(f" {cat}: {n}")
lines.append(f"Orphans (no server_id): {counts['orphans_no_server_id']}")
lines.append(f"Pending ops: {counts['pending_ops']}")
lines.append(f"Last sync: {counts['last_sync_ts'] or 'never'}")
lines.append(f"Auth failed: {counts['auth_failed']}")
lines.append(f"Last sync success: {counts['last_sync_success']}")
return "\n".join(lines)
if self.sqlite_conn:
cursor = self.sqlite_conn.cursor()
cursor.execute("SELECT COUNT(*) as c FROM memories")
total = cursor.fetchone()["c"]
cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC")
by_cat = cursor.fetchall()
lines = [f"Local memories (SQLite-only): {total}"]
for row in by_cat:
lines.append(f" {row['category']}: {row['c']}")
return "\n".join(lines)
return "No storage available"
# ── SQLite methods ────────────────────────────────────────────── # ── SQLite methods ──────────────────────────────────────────────
def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str: def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str:
@ -573,6 +636,7 @@ class MemoryServer:
"memory_list": self.memory_list, "memory_list": self.memory_list,
"memory_delete": self.memory_delete, "memory_delete": self.memory_delete,
"secret_get": self.secret_get, "secret_get": self.secret_get,
"memory_count": self.memory_count,
}.get(tool_name) }.get(tool_name)
if handler is None: if handler is None:
return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True} return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True}

View file

@ -16,6 +16,12 @@ from pathlib import Path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Max retries before an individual pending op is permanently skipped
MAX_OP_RETRIES = 5
# Full resync every N sync cycles (~10 min at 60s interval)
FULL_RESYNC_EVERY = 10
class SyncEngine: class SyncEngine:
"""Background sync between local SQLite cache and remote API.""" """Background sync between local SQLite cache and remote API."""
@ -29,6 +35,7 @@ class SyncEngine:
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._thread: threading.Thread | None = None self._thread: threading.Thread | None = None
self._last_sync_success = False self._last_sync_success = False
self._auth_failed = False
# Own connection for thread safety # Own connection for thread safety
Path(db_path).parent.mkdir(parents=True, exist_ok=True) Path(db_path).parent.mkdir(parents=True, exist_ok=True)
@ -48,7 +55,8 @@ class SyncEngine:
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
op_type TEXT NOT NULL, op_type TEXT NOT NULL,
payload TEXT NOT NULL, payload TEXT NOT NULL,
created_at TEXT NOT NULL created_at TEXT NOT NULL,
retry_count INTEGER DEFAULT 0
); );
CREATE TABLE IF NOT EXISTS sync_meta ( CREATE TABLE IF NOT EXISTS sync_meta (
@ -64,6 +72,13 @@ class SyncEngine:
self._conn.execute( self._conn.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)" "CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
) )
# Add retry_count column to pending_ops if missing (migration)
cursor = self._conn.execute("PRAGMA table_info(pending_ops)")
po_columns = {row["name"] for row in cursor.fetchall()}
if "retry_count" not in po_columns:
self._conn.execute("ALTER TABLE pending_ops ADD COLUMN retry_count INTEGER DEFAULT 0")
self._conn.commit() self._conn.commit()
@property @property
@ -89,7 +104,15 @@ class SyncEngine:
return self._last_sync_success return self._last_sync_success
def start(self) -> None: def start(self) -> None:
"""Start background sync thread (non-blocking).""" """Start background sync thread. Runs a full resync on startup."""
# Full sync on startup (blocking, before background thread)
try:
self._full_resync()
self._last_sync_success = True
self._auth_failed = False
except Exception as e:
logger.warning("Startup full sync failed: %s", e)
self._thread = threading.Thread(target=self._sync_loop, daemon=True) self._thread = threading.Thread(target=self._sync_loop, daemon=True)
self._thread.start() self._thread.start()
@ -102,21 +125,162 @@ class SyncEngine:
def _sync_loop(self) -> None: def _sync_loop(self) -> None:
"""Periodic sync loop running in background thread.""" """Periodic sync loop running in background thread."""
cycle = 0
while not self._stop_event.is_set(): while not self._stop_event.is_set():
self._stop_event.wait(self.sync_interval) self._stop_event.wait(self.sync_interval)
if self._stop_event.is_set(): if self._stop_event.is_set():
break break
cycle += 1
try: try:
self._sync_once() # If auth previously failed, try a lightweight check first
if self._auth_failed:
if not self._check_auth():
continue # Still failing, skip this cycle
if cycle % FULL_RESYNC_EVERY == 0:
self._full_resync()
else:
self._sync_once()
self._last_sync_success = True self._last_sync_success = True
except Exception as e: except Exception as e:
logger.warning("Sync cycle failed: %s", e) logger.warning("Sync cycle failed: %s", e)
self._last_sync_success = False self._last_sync_success = False
def _check_auth(self) -> bool:
"""Lightweight auth check. Returns True if auth is OK."""
try:
self._api_request("GET", "/api/auth-check")
self._auth_failed = False
logger.info("Auth check passed — resuming sync")
return True
except urllib.error.HTTPError as e:
if e.code in (401, 403):
logger.warning(
"Auth still failing (HTTP %d) — API key mismatch. "
"Update MEMORY_API_KEY in ~/.claude.json", e.code
)
return False
# Non-auth error (e.g. 500) — try the auth-check endpoint might not exist,
# fall back to /health
pass
except Exception:
pass
# Fallback: try /health (unauthenticated)
try:
url = f"{self.api_base_url}/health"
req = urllib.request.Request(url, method="GET")
with urllib.request.urlopen(req, timeout=5):
pass
# Server is reachable but auth-check failed — auth is still broken
return False
except Exception:
# Server unreachable — not an auth problem
return False
def _sync_once(self) -> None: def _sync_once(self) -> None:
"""Push pending ops, then pull remote changes.""" """Push pending ops, then pull remote changes. Both run independently."""
self._push_pending_ops() push_ok = self._push_pending_ops()
self._pull_changes() pull_ok = self._pull_changes()
if not push_ok and not pull_ok:
raise RuntimeError("Both push and pull failed")
def _full_resync(self) -> None:
"""Full cache replacement from server — handles drift, deletes, schema changes."""
# Step 1: Push orphaned local-only records (deduplicated)
self._push_orphans()
# Step 2: Pull everything from server (no since filter = non-deleted only)
result = self._api_request("GET", "/api/memories/sync")
memories = result.get("memories", [])
server_time = result.get("server_time")
server_ids = {m["id"] for m in memories}
with self._lock:
# Delete local records whose server_id no longer exists on server
local_rows = self._conn.execute(
"SELECT id, server_id FROM memories WHERE server_id IS NOT NULL"
).fetchall()
for row in local_rows:
if row["server_id"] not in server_ids:
self._conn.execute("DELETE FROM memories WHERE id = ?", (row["id"],))
# Delete remaining orphans (already pushed or duplicates)
self._conn.execute("DELETE FROM memories WHERE server_id IS NULL")
# Upsert all server records
for mem in memories:
server_id = mem["id"]
existing = self._conn.execute(
"SELECT id FROM memories WHERE server_id = ?", (server_id,)
).fetchone()
if existing:
self._conn.execute(
"""UPDATE memories SET content=?, category=?, tags=?,
expanded_keywords=?, importance=?, is_sensitive=?,
updated_at=? WHERE server_id=?""",
(
mem["content"], mem["category"], mem.get("tags", ""),
mem.get("expanded_keywords", ""), mem["importance"],
1 if mem.get("is_sensitive") else 0,
mem.get("updated_at", ""), server_id,
),
)
else:
self._conn.execute(
"""INSERT INTO memories (content, category, tags, expanded_keywords,
importance, is_sensitive, created_at, updated_at, server_id)
VALUES (?,?,?,?,?,?,?,?,?)""",
(
mem["content"], mem["category"], mem.get("tags", ""),
mem.get("expanded_keywords", ""), mem["importance"],
1 if mem.get("is_sensitive") else 0,
mem.get("created_at", ""), mem.get("updated_at", ""), server_id,
),
)
self._conn.commit()
if server_time:
self.last_sync_ts = server_time
def _push_orphans(self) -> None:
"""Push local-only records to server, skipping content duplicates."""
with self._lock:
orphans = self._conn.execute(
"SELECT id, content, category, tags, expanded_keywords, importance "
"FROM memories WHERE server_id IS NULL"
).fetchall()
if not orphans:
return
# Get all server content for dedup comparison
result = self._api_request("GET", "/api/memories/sync")
server_contents = {m["content"] for m in result.get("memories", [])}
for orphan in orphans:
if orphan["content"] in server_contents:
continue # Skip duplicate
try:
resp = self._api_request("POST", "/api/memories", {
"content": orphan["content"],
"category": orphan["category"],
"tags": orphan["tags"],
"expanded_keywords": orphan["expanded_keywords"],
"importance": orphan["importance"],
})
server_id = resp.get("id")
if server_id:
with self._lock:
self._conn.execute(
"UPDATE memories SET server_id=? WHERE id=?",
(server_id, orphan["id"]),
)
self._conn.commit()
except Exception:
pass # Will be cleaned up by the full resync delete step
def _api_request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]: def _api_request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
"""Make an HTTP request to the memory API.""" """Make an HTTP request to the memory API."""
@ -131,22 +295,47 @@ class SyncEngine:
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
with urllib.request.urlopen(req, timeout=15) as resp: try:
result: dict[str, Any] = json.loads(resp.read().decode()) with urllib.request.urlopen(req, timeout=15) as resp:
return result result: dict[str, Any] = json.loads(resp.read().decode())
return result
except urllib.error.HTTPError as e:
if e.code in (401, 403):
self._auth_failed = True
logger.warning(
"Auth failed (HTTP %d) — API key may have rotated. "
"Update MEMORY_API_KEY in ~/.claude.json", e.code
)
raise
def _push_pending_ops(self) -> None: def _push_pending_ops(self) -> bool:
"""Push queued operations to the API server.""" """Push queued operations to the API server. Returns True on success."""
with self._lock: with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT id, op_type, payload FROM pending_ops ORDER BY id" "SELECT id, op_type, payload, retry_count FROM pending_ops ORDER BY id"
) )
ops = cursor.fetchall() ops = cursor.fetchall()
if not ops:
return True
all_ok = True
for op in ops: for op in ops:
op_id = op["id"] op_id = op["id"]
op_type = op["op_type"] op_type = op["op_type"]
payload = json.loads(op["payload"]) payload = json.loads(op["payload"])
retry_count = op["retry_count"] or 0
# Skip ops that have exceeded retry limit
if retry_count >= MAX_OP_RETRIES:
logger.warning(
"Skipping op %d (%s) after %d retries — removing from queue",
op_id, op_type, retry_count,
)
with self._lock:
self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
self._conn.commit()
continue
try: try:
if op_type == "store": if op_type == "store":
@ -164,8 +353,8 @@ class SyncEngine:
if server_id: if server_id:
try: try:
self._api_request("DELETE", f"/api/memories/{server_id}") self._api_request("DELETE", f"/api/memories/{server_id}")
except (RuntimeError, urllib.error.HTTPError) as e: except urllib.error.HTTPError as e:
if "404" in str(e): if e.code == 404:
pass # Already deleted on server pass # Already deleted on server
else: else:
raise raise
@ -175,75 +364,103 @@ class SyncEngine:
self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,)) self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
self._conn.commit() self._conn.commit()
except Exception as e: except urllib.error.HTTPError as e:
logger.warning("Failed to push op %d (%s): %s", op_id, op_type, e) if e.code in (401, 403):
raise # Propagate to mark sync as failed self._auth_failed = True
logger.warning("Auth failed (HTTP %d) — aborting push", e.code)
def _pull_changes(self) -> None: return False # Abort entire push — no point retrying with bad key
"""Pull changes from server since last sync.""" # Increment retry count for non-auth errors
params = "" with self._lock:
ts = self.last_sync_ts
if ts:
params = f"?since={urllib.parse.quote(ts, safe='')}"
result = self._api_request("GET", f"/api/memories/sync{params}")
memories = result.get("memories", [])
server_time = result.get("server_time")
with self._lock:
for mem in memories:
server_id = mem["id"]
deleted_at = mem.get("deleted_at")
if deleted_at:
# Remove from local cache
self._conn.execute( self._conn.execute(
"DELETE FROM memories WHERE server_id = ?", (server_id,) "UPDATE pending_ops SET retry_count = retry_count + 1 WHERE id = ?",
(op_id,),
) )
else: self._conn.commit()
# Upsert by server_id (server wins) logger.warning("Failed to push op %d (%s): HTTP %d", op_id, op_type, e.code)
existing = self._conn.execute( all_ok = False
"SELECT id FROM memories WHERE server_id = ?", (server_id,) except Exception as e:
).fetchone() with self._lock:
self._conn.execute(
"UPDATE pending_ops SET retry_count = retry_count + 1 WHERE id = ?",
(op_id,),
)
self._conn.commit()
logger.warning("Failed to push op %d (%s): %s", op_id, op_type, e)
all_ok = False
if existing: return all_ok
def _pull_changes(self) -> bool:
"""Pull changes from server since last sync. Returns True on success."""
try:
params = ""
ts = self.last_sync_ts
if ts:
params = f"?since={urllib.parse.quote(ts, safe='')}"
result = self._api_request("GET", f"/api/memories/sync{params}")
memories = result.get("memories", [])
server_time = result.get("server_time")
with self._lock:
for mem in memories:
server_id = mem["id"]
deleted_at = mem.get("deleted_at")
if deleted_at:
# Remove from local cache
self._conn.execute( self._conn.execute(
"""UPDATE memories SET content = ?, category = ?, tags = ?, "DELETE FROM memories WHERE server_id = ?", (server_id,)
expanded_keywords = ?, importance = ?, is_sensitive = ?,
updated_at = ? WHERE server_id = ?""",
(
mem["content"],
mem["category"],
mem.get("tags", ""),
mem.get("expanded_keywords", ""),
mem["importance"],
1 if mem.get("is_sensitive") else 0,
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
server_id,
),
) )
else: else:
self._conn.execute( # Upsert by server_id (server wins)
"""INSERT INTO memories existing = self._conn.execute(
(content, category, tags, expanded_keywords, importance, "SELECT id FROM memories WHERE server_id = ?", (server_id,)
is_sensitive, created_at, updated_at, server_id) ).fetchone()
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
mem["content"],
mem["category"],
mem.get("tags", ""),
mem.get("expanded_keywords", ""),
mem["importance"],
1 if mem.get("is_sensitive") else 0,
mem.get("created_at", datetime.now(timezone.utc).isoformat()),
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
server_id,
),
)
self._conn.commit()
if server_time: if existing:
self.last_sync_ts = server_time self._conn.execute(
"""UPDATE memories SET content = ?, category = ?, tags = ?,
expanded_keywords = ?, importance = ?, is_sensitive = ?,
updated_at = ? WHERE server_id = ?""",
(
mem["content"],
mem["category"],
mem.get("tags", ""),
mem.get("expanded_keywords", ""),
mem["importance"],
1 if mem.get("is_sensitive") else 0,
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
server_id,
),
)
else:
self._conn.execute(
"""INSERT INTO memories
(content, category, tags, expanded_keywords, importance,
is_sensitive, created_at, updated_at, server_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
mem["content"],
mem["category"],
mem.get("tags", ""),
mem.get("expanded_keywords", ""),
mem["importance"],
1 if mem.get("is_sensitive") else 0,
mem.get("created_at", datetime.now(timezone.utc).isoformat()),
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
server_id,
),
)
self._conn.commit()
if server_time:
self.last_sync_ts = server_time
return True
except Exception as e:
logger.warning("Pull changes failed: %s", e)
return False
def enqueue_store( def enqueue_store(
self, self,
@ -295,6 +512,11 @@ class SyncEngine:
force_sensitive: bool = False, force_sensitive: bool = False,
) -> int | None: ) -> int | None:
"""Try to sync a store immediately. Returns server_id or None if failed.""" """Try to sync a store immediately. Returns server_id or None if failed."""
if self._auth_failed:
self.enqueue_store(
local_id, content, category, tags, expanded_keywords, importance, force_sensitive
)
return None
try: try:
result = self._api_request("POST", "/api/memories", { result = self._api_request("POST", "/api/memories", {
"content": content, "content": content,
@ -321,6 +543,9 @@ class SyncEngine:
def try_sync_delete(self, server_id: int) -> bool: def try_sync_delete(self, server_id: int) -> bool:
"""Try to sync a delete immediately. Returns True if successful.""" """Try to sync a delete immediately. Returns True if successful."""
if self._auth_failed:
self.enqueue_delete(server_id)
return False
try: try:
self._api_request("DELETE", f"/api/memories/{server_id}") self._api_request("DELETE", f"/api/memories/{server_id}")
return True return True
@ -332,3 +557,27 @@ class SyncEngine:
except Exception: except Exception:
self.enqueue_delete(server_id) self.enqueue_delete(server_id)
return False return False
def get_counts(self) -> dict[str, Any]:
"""Get memory counts for diagnostics."""
with self._lock:
total = self._conn.execute("SELECT COUNT(*) as c FROM memories").fetchone()["c"]
by_cat = self._conn.execute(
"SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC"
).fetchall()
orphans = self._conn.execute(
"SELECT COUNT(*) as c FROM memories WHERE server_id IS NULL"
).fetchone()["c"]
pending = self._conn.execute(
"SELECT COUNT(*) as c FROM pending_ops"
).fetchone()["c"]
return {
"total": total,
"by_category": {row["category"]: row["c"] for row in by_cat},
"orphans_no_server_id": orphans,
"pending_ops": pending,
"last_sync_ts": self.last_sync_ts,
"auth_failed": self._auth_failed,
"last_sync_success": self._last_sync_success,
}

View file

@ -99,6 +99,20 @@ async def test_health_endpoint_no_auth(client):
assert resp.json() == {"status": "ok"} assert resp.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_auth_check_endpoint(client):
ac, conn, app_mod = client
async with ac:
resp = await ac.get(
"/api/auth-check",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "ok"
assert data["user_id"] == "testuser"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_memory_creates_record_with_user_id(client): async def test_store_memory_creates_record_with_user_id(client):
ac, conn, app_mod = client ac, conn, app_mod = client

View file

@ -238,9 +238,9 @@ class TestMCPProtocol:
def test_handle_tools_list(self, server): def test_handle_tools_list(self, server):
result = server.handle_tools_list({}) result = server.handle_tools_list({})
tools = result["tools"] tools = result["tools"]
assert len(tools) == 5 assert len(tools) == 6
names = {t["name"] for t in tools} names = {t["name"] for t in tools}
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"} assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get", "memory_count"}
def test_handle_tools_call_store(self, server): def test_handle_tools_call_store(self, server):
result = server.handle_tools_call({ result = server.handle_tools_call({
@ -291,7 +291,7 @@ class TestProcessMessage:
"params": {}, "params": {},
}) })
assert "result" in response assert "result" in response
assert len(response["result"]["tools"]) == 5 assert len(response["result"]["tools"]) == 6
def test_tools_call(self, server): def test_tools_call(self, server):
response = server.process_message({ response = server.process_message({
@ -340,3 +340,71 @@ class TestProcessMessage:
parsed = json.loads(serialized) parsed = json.loads(serialized)
assert parsed["jsonrpc"] == "2.0" assert parsed["jsonrpc"] == "2.0"
assert parsed["id"] == 5 assert parsed["id"] == 5
class TestMemoryCount:
def test_count_empty(self, server):
result = server.memory_count({})
assert "0" in result
def test_count_after_store(self, server):
server.memory_store({
"content": "test memory",
"expanded_keywords": "test memory keywords data",
})
result = server.memory_count({})
assert "1" in result
assert "facts" in result
def test_count_multiple_categories(self, server):
server.memory_store({
"content": "a fact",
"category": "facts",
"expanded_keywords": "fact test data words",
})
server.memory_store({
"content": "a preference",
"category": "preferences",
"expanded_keywords": "preference test data words",
})
result = server.memory_count({})
assert "facts: 1" in result
assert "preferences: 1" in result
def test_count_via_tools_call(self, server):
result = server.handle_tools_call({
"name": "memory_count",
"arguments": {},
})
assert not result.get("isError", False)
assert "0" in result["content"][0]["text"]
class TestSchemaMigration:
def test_schema_version_set(self, tmp_path):
db_path = str(tmp_path / "test.db")
srv = MemoryServer(sqlite_db_path=db_path)
cursor = srv.sqlite_conn.cursor()
version = cursor.execute("PRAGMA user_version").fetchone()[0]
assert version == 2
srv.sqlite_conn.close()
def test_migration_idempotent(self, tmp_path):
"""Running _init_sqlite twice should not error."""
from claude_memory.mcp_server import _init_sqlite
db_path = str(tmp_path / "test.db")
conn1, _ = _init_sqlite(db_path)
conn1.close()
conn2, _ = _init_sqlite(db_path)
version = conn2.execute("PRAGMA user_version").fetchone()[0]
assert version == 2
conn2.close()
def test_server_id_column_exists(self, tmp_path):
db_path = str(tmp_path / "test.db")
srv = MemoryServer(sqlite_db_path=db_path)
cursor = srv.sqlite_conn.cursor()
cursor.execute("PRAGMA table_info(memories)")
columns = {row["name"] for row in cursor.fetchall()}
assert "server_id" in columns
srv.sqlite_conn.close()

View file

@ -3,8 +3,9 @@
import json import json
import os import os
import sys import sys
import urllib.error
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest.mock import patch from unittest.mock import patch, MagicMock
import pytest import pytest
@ -154,21 +155,25 @@ class TestPushPendingOps:
"""A 404 on delete means already deleted on server — should still clear queue.""" """A 404 on delete means already deleted on server — should still clear queue."""
engine.enqueue_delete(42) engine.enqueue_delete(42)
import urllib.error
with patch.object(engine, "_api_request") as mock_api: with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = RuntimeError("API error 404: not found") mock_api.side_effect = urllib.error.HTTPError(
url="http://fake", code=404, msg="Not Found", hdrs=None, fp=None
)
engine._push_pending_ops() engine._push_pending_ops()
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
assert cursor.fetchone()["cnt"] == 0 assert cursor.fetchone()["cnt"] == 0
def test_push_failure_keeps_queue(self, engine): def test_push_failure_keeps_queue_returns_false(self, engine):
"""Push failure should keep the op in queue and return False (not raise)."""
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5) engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
with patch.object(engine, "_api_request") as mock_api: with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = RuntimeError("Connection refused") mock_api.side_effect = RuntimeError("Connection refused")
with pytest.raises(RuntimeError): result = engine._push_pending_ops()
engine._push_pending_ops()
assert result is False
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops") cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
assert cursor.fetchone()["cnt"] == 1 assert cursor.fetchone()["cnt"] == 1
@ -393,3 +398,361 @@ class TestFullSyncCycle:
# Should be gone locally # Should be gone locally
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 500") cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 500")
assert cursor.fetchone() is None assert cursor.fetchone() is None
class TestAuthFailureHandling:
def test_auth_flag_set_on_401(self, engine):
"""401 from _api_request should set _auth_failed flag."""
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = urllib.error.HTTPError(
url="http://fake", code=401, msg="Unauthorized", hdrs=None, fp=None
)
result = engine._push_pending_ops()
assert result is False
assert engine._auth_failed is True
def test_auth_flag_set_on_403(self, engine):
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = urllib.error.HTTPError(
url="http://fake", code=403, msg="Forbidden", hdrs=None, fp=None
)
result = engine._push_pending_ops()
assert result is False
assert engine._auth_failed is True
def test_push_aborts_on_auth_failure(self, engine):
"""On 401, push should abort immediately — no further ops attempted."""
engine.enqueue_store(1, "test1", "facts", "", "kw", 0.5)
engine.enqueue_store(2, "test2", "facts", "", "kw", 0.5)
with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = urllib.error.HTTPError(
url="http://fake", code=401, msg="Unauthorized", hdrs=None, fp=None
)
engine._push_pending_ops()
# Both ops should still be in queue (aborted before processing second)
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
assert cursor.fetchone()["cnt"] == 2
def test_try_sync_store_queues_when_auth_failed(self, engine):
"""When auth is failed, try_sync_store should queue without attempting API call."""
engine._auth_failed = True
result = engine.try_sync_store(1, "test", "facts", "", "kw", 0.5)
assert result is None
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
assert cursor.fetchone()["cnt"] == 1
def test_try_sync_delete_queues_when_auth_failed(self, engine):
engine._auth_failed = True
result = engine.try_sync_delete(42)
assert result is False
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
assert cursor.fetchone()["cnt"] == 1
def test_check_auth_clears_flag_on_success(self, engine):
engine._auth_failed = True
with patch.object(engine, "_api_request") as mock_api:
mock_api.return_value = {"status": "ok", "user_id": "test"}
result = engine._check_auth()
assert result is True
assert engine._auth_failed is False
def test_check_auth_stays_failed_on_401(self, engine):
engine._auth_failed = True
with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = urllib.error.HTTPError(
url="http://fake", code=401, msg="Unauthorized", hdrs=None, fp=None
)
# Also mock urlopen for /health fallback
with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.return_value.__enter__ = MagicMock()
mock_urlopen.return_value.__exit__ = MagicMock(return_value=False)
result = engine._check_auth()
assert result is False
assert engine._auth_failed is True
class TestRetryCount:
def test_retry_count_incremented_on_failure(self, engine):
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = RuntimeError("Connection refused")
engine._push_pending_ops()
cursor = engine._conn.execute("SELECT retry_count FROM pending_ops WHERE id = 1")
assert cursor.fetchone()["retry_count"] == 1
def test_op_skipped_after_max_retries(self, engine):
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
# Set retry_count to max
engine._conn.execute("UPDATE pending_ops SET retry_count = 5 WHERE id = 1")
engine._conn.commit()
with patch.object(engine, "_api_request") as mock_api:
result = engine._push_pending_ops()
# Op should be deleted (skipped), API never called
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
assert cursor.fetchone()["cnt"] == 0
mock_api.assert_not_called()
def test_retry_count_persists_across_pushes(self, engine):
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
with patch.object(engine, "_api_request") as mock_api:
mock_api.side_effect = RuntimeError("fail")
engine._push_pending_ops()
engine._push_pending_ops()
engine._push_pending_ops()
cursor = engine._conn.execute("SELECT retry_count FROM pending_ops WHERE id = 1")
assert cursor.fetchone()["retry_count"] == 3
class TestDecoupledPushPull:
def test_pull_runs_even_when_push_fails(self, engine):
"""Pull should execute even if push fails — they're decoupled."""
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
now = datetime.now(timezone.utc).isoformat()
call_count = 0
def mock_api(method, path, body=None):
nonlocal call_count
call_count += 1
if "POST" == method:
raise RuntimeError("Push failed")
# GET for pull
return {
"memories": [{
"id": 99, "content": "from server", "category": "facts",
"tags": "", "expanded_keywords": "", "importance": 0.5,
"is_sensitive": False, "created_at": now, "updated_at": now,
"deleted_at": None,
}],
"server_time": now,
}
with patch.object(engine, "_api_request", side_effect=mock_api):
engine._sync_once()
# Pull should have inserted the server memory
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 99")
assert cursor.fetchone() is not None
def test_sync_once_returns_normally_on_partial_failure(self, engine):
"""If push fails but pull succeeds, _sync_once should not raise."""
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
def mock_api(method, path, body=None):
if method == "POST":
raise RuntimeError("Push failed")
return {"memories": [], "server_time": "2026-03-16T12:00:00+00:00"}
with patch.object(engine, "_api_request", side_effect=mock_api):
# Should not raise
engine._sync_once()
class TestFullResync:
def test_full_resync_inserts_server_records(self, engine):
now = datetime.now(timezone.utc).isoformat()
with patch.object(engine, "_api_request") as mock_api:
mock_api.return_value = {
"memories": [
{"id": 1, "content": "server mem 1", "category": "facts",
"tags": "", "expanded_keywords": "", "importance": 0.5,
"is_sensitive": False, "created_at": now, "updated_at": now},
{"id": 2, "content": "server mem 2", "category": "projects",
"tags": "", "expanded_keywords": "", "importance": 0.8,
"is_sensitive": False, "created_at": now, "updated_at": now},
],
"server_time": now,
}
engine._full_resync()
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM memories")
assert cursor.fetchone()["cnt"] == 2
def test_full_resync_removes_stale_local_records(self, engine):
"""Local records with server_ids not on server should be deleted."""
now = datetime.now(timezone.utc).isoformat()
# Insert a local record with server_id=999 (not on server)
engine._conn.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at, server_id) VALUES (?,?,?,?,?,?,?,?,?)",
("stale", "facts", "", "", 0.5, 0, now, now, 999),
)
engine._conn.commit()
with patch.object(engine, "_api_request") as mock_api:
mock_api.return_value = {
"memories": [
{"id": 1, "content": "current", "category": "facts",
"tags": "", "expanded_keywords": "", "importance": 0.5,
"is_sensitive": False, "created_at": now, "updated_at": now},
],
"server_time": now,
}
engine._full_resync()
# Stale record should be gone
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 999")
assert cursor.fetchone() is None
# Current record should exist
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 1")
assert cursor.fetchone() is not None
def test_full_resync_deletes_orphans_after_push(self, engine):
"""Orphans (server_id IS NULL) should be cleaned up after push attempt."""
now = datetime.now(timezone.utc).isoformat()
engine._conn.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?)",
("orphan", "facts", "", "", 0.5, 0, now, now),
)
engine._conn.commit()
with patch.object(engine, "_api_request") as mock_api:
mock_api.return_value = {
"memories": [],
"server_time": now,
}
engine._full_resync()
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id IS NULL")
assert cursor.fetchone() is None
def test_full_resync_updates_last_sync_ts(self, engine):
server_time = "2026-03-16T15:00:00+00:00"
with patch.object(engine, "_api_request") as mock_api:
mock_api.return_value = {"memories": [], "server_time": server_time}
engine._full_resync()
assert engine.last_sync_ts == server_time
def test_full_resync_updates_existing_records(self, engine):
now = datetime.now(timezone.utc).isoformat()
engine._conn.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at, server_id) VALUES (?,?,?,?,?,?,?,?,?)",
("old content", "facts", "", "", 0.5, 0, now, now, 10),
)
engine._conn.commit()
with patch.object(engine, "_api_request") as mock_api:
mock_api.return_value = {
"memories": [
{"id": 10, "content": "new content", "category": "projects",
"tags": "updated", "expanded_keywords": "", "importance": 0.9,
"is_sensitive": False, "created_at": now, "updated_at": now},
],
"server_time": now,
}
engine._full_resync()
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 10")
row = cursor.fetchone()
assert row["content"] == "new content"
assert row["category"] == "projects"
assert row["importance"] == 0.9
class TestPushOrphans:
def test_push_orphans_skips_duplicates(self, engine):
now = datetime.now(timezone.utc).isoformat()
# Insert orphan with content matching server
engine._conn.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?)",
("duplicate content", "facts", "", "", 0.5, 0, now, now),
)
engine._conn.commit()
call_log = []
def mock_api(method, path, body=None):
call_log.append((method, path))
return {
"memories": [{"id": 1, "content": "duplicate content", "category": "facts",
"tags": "", "expanded_keywords": "", "importance": 0.5,
"is_sensitive": False, "created_at": now, "updated_at": now}],
"server_time": now,
}
with patch.object(engine, "_api_request", side_effect=mock_api):
engine._push_orphans()
# Should have called GET for sync but NOT POST (duplicate skipped)
assert all(m != "POST" for m, _ in call_log)
def test_push_orphans_posts_unique(self, engine):
now = datetime.now(timezone.utc).isoformat()
engine._conn.execute(
"INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?)",
(1, "unique content", "facts", "", "", 0.5, 0, now, now),
)
engine._conn.commit()
def mock_api(method, path, body=None):
if method == "GET":
return {"memories": [], "server_time": now}
if method == "POST":
return {"id": 100, "category": "facts", "importance": 0.5}
return {}
with patch.object(engine, "_api_request", side_effect=mock_api):
engine._push_orphans()
# Orphan should now have server_id
cursor = engine._conn.execute("SELECT server_id FROM memories WHERE id = 1")
assert cursor.fetchone()["server_id"] == 100
class TestGetCounts:
def test_empty_counts(self, engine):
counts = engine.get_counts()
assert counts["total"] == 0
assert counts["by_category"] == {}
assert counts["orphans_no_server_id"] == 0
assert counts["pending_ops"] == 0
assert counts["auth_failed"] is False
def test_counts_with_data(self, engine):
now = datetime.now(timezone.utc).isoformat()
engine._conn.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at, server_id) VALUES (?,?,?,?,?,?,?,?,?)",
("mem1", "facts", "", "", 0.5, 0, now, now, 1),
)
engine._conn.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?)",
("orphan", "projects", "", "", 0.5, 0, now, now),
)
engine.enqueue_store(99, "queued", "facts", "", "", 0.5)
engine._conn.commit()
counts = engine.get_counts()
assert counts["total"] == 2
assert counts["by_category"]["facts"] == 1
assert counts["by_category"]["projects"] == 1
assert counts["orphans_no_server_id"] == 1
assert counts["pending_ops"] == 1