resilient memory sync: decouple push/pull, startup full resync, auth failure handling
- Decouple push and pull in _sync_once() so pull always runs even if push fails - Add startup full resync to catch drift from other agents and schema changes - Add periodic full resync every ~10 minutes for continuous drift correction - Add auth failure detection (401/403) with graceful SQLite-only degradation - Add /api/auth-check endpoint for lightweight key validation - Add retry cap (5 attempts) on pending ops to prevent infinite queue buildup - Add orphan reconciliation: push local-only records with content dedup - Add memory_count MCP tool for sync diagnostics - Add version-based SQLite schema migration (PRAGMA user_version) - Fix API key in ~/.claude.json to match server - Update README with sync resilience docs, test structure, project layout - Add 30 new tests covering all new behaviors (155 total, all passing)
This commit is contained in:
parent
a18b94d310
commit
e47efee6b6
8 changed files with 948 additions and 134 deletions
96
README.md
96
README.md
|
|
@ -44,6 +44,7 @@ Claude has direct access to these tools during conversation:
|
||||||
| `memory_list` | List recent memories, optionally filtered by category |
|
| `memory_list` | List recent memories, optionally filtered by category |
|
||||||
| `memory_delete` | Delete a memory by ID |
|
| `memory_delete` | Delete a memory by ID |
|
||||||
| `secret_get` | Retrieve the decrypted content of a sensitive memory |
|
| `secret_get` | Retrieve the decrypted content of a sensitive memory |
|
||||||
|
| `memory_count` | Get memory counts by category and sync status diagnostics |
|
||||||
|
|
||||||
### Memory Categories
|
### Memory Categories
|
||||||
Memories are organized into: `facts`, `preferences`, `projects`, `people`, `decisions`
|
Memories are organized into: `facts`, `preferences`, `projects`, `people`, `decisions`
|
||||||
|
|
@ -63,6 +64,7 @@ Claude Code Session
|
||||||
│ │ compaction │ │ memory_list │ │
|
│ │ compaction │ │ memory_list │ │
|
||||||
│ │ auto-approve │ │ memory_delete │ │
|
│ │ auto-approve │ │ memory_delete │ │
|
||||||
│ └──────────────────┘ │ secret_get │ │
|
│ └──────────────────┘ │ secret_get │ │
|
||||||
|
│ │ memory_count │ │
|
||||||
│ └─────────┬──────────┘ │
|
│ └─────────┬──────────┘ │
|
||||||
│ │ │
|
│ │ │
|
||||||
│ ┌─────────────────────────┼──────────┐ │
|
│ ┌─────────────────────────┼──────────┐ │
|
||||||
|
|
@ -89,6 +91,17 @@ Claude Code Session
|
||||||
└──────────────────────┘
|
└──────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Sync Resilience
|
||||||
|
|
||||||
|
The SyncEngine is designed to handle failures gracefully:
|
||||||
|
|
||||||
|
- **Decoupled push/pull** — push failures never block pull. Remote changes from other agents always flow in.
|
||||||
|
- **Auth failure detection** — on 401/403, the engine sets an auth-failed flag, logs a clear warning, and degrades to SQLite-only mode. A periodic health check detects when auth is restored.
|
||||||
|
- **Startup full resync** — on MCP server start, a full cache replacement runs to catch drift from other agents, deleted records, and schema changes.
|
||||||
|
- **Periodic full resync** — every ~10 minutes, a full resync replaces incremental sync to catch any drift.
|
||||||
|
- **Retry cap** — individual pending ops are retried up to 5 times, then permanently skipped to prevent queue buildup.
|
||||||
|
- **Orphan reconciliation** — local records that never synced are deduplicated against server content before push.
|
||||||
|
|
||||||
## Search Algorithm
|
## Search Algorithm
|
||||||
|
|
||||||
Memory recall uses two different full-text search backends depending on the operating mode. Both follow the same query building pattern: the `context` and `expanded_query` fields from the user's recall request are concatenated into a single search string, then processed into a backend-specific query.
|
Memory recall uses two different full-text search backends depending on the operating mode. Both follow the same query building pattern: the `context` and `expanded_query` fields from the user's recall request are concatenated into a single search string, then processed into a backend-specific query.
|
||||||
|
|
@ -258,7 +271,7 @@ export MEMORY_API_KEY="your-api-key"
|
||||||
|
|
||||||
### Option 2: Manual MCP Config
|
### Option 2: Manual MCP Config
|
||||||
|
|
||||||
Add to `~/.claude/settings.json`:
|
Add to `~/.claude.json` under `mcpServers`:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
|
|
@ -266,8 +279,9 @@ Add to `~/.claude/settings.json`:
|
||||||
"claude_memory": {
|
"claude_memory": {
|
||||||
"type": "stdio",
|
"type": "stdio",
|
||||||
"command": "python3",
|
"command": "python3",
|
||||||
"args": ["-m", "claude_memory.mcp_server"],
|
"args": ["/path/to/claude-memory-mcp/src/claude_memory/mcp_server.py"],
|
||||||
"env": {
|
"env": {
|
||||||
|
"PYTHONPATH": "/path/to/claude-memory-mcp/src",
|
||||||
"MEMORY_API_URL": "https://your-server.example.com",
|
"MEMORY_API_URL": "https://your-server.example.com",
|
||||||
"MEMORY_API_KEY": "your-api-key"
|
"MEMORY_API_KEY": "your-api-key"
|
||||||
}
|
}
|
||||||
|
|
@ -276,7 +290,7 @@ Add to `~/.claude/settings.json`:
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Omit the `env` block for SQLite-only mode. Requires `pip install claude-memory-mcp`.
|
Omit `MEMORY_API_URL` and `MEMORY_API_KEY` for SQLite-only mode.
|
||||||
|
|
||||||
### Verify
|
### Verify
|
||||||
|
|
||||||
|
|
@ -394,7 +408,7 @@ The auto-learn hook runs asynchronously after each Claude response. It operates
|
||||||
|
|
||||||
**Judge fallback chain:** Claude CLI (haiku model) → local Ollama (qwen2.5:3b/llama3.2:3b/gemma2:2b/phi3:mini) → heuristic pattern matching (keyword-based extraction from user messages).
|
**Judge fallback chain:** Claude CLI (haiku model) → local Ollama (qwen2.5:3b/llama3.2:3b/gemma2:2b/phi3:mini) → heuristic pattern matching (keyword-based extraction from user messages).
|
||||||
|
|
||||||
Extracted events are deduplicated via SHA-256 content hashing. Each event is stored to both the MCP memory database and a `~/.claude/projects/<project>/memory/auto-learned.md` markdown file.
|
Extracted events are deduplicated via SHA-256 content hashing. Each event is stored to the MCP memory database.
|
||||||
|
|
||||||
### Debug
|
### Debug
|
||||||
|
|
||||||
|
|
@ -422,7 +436,8 @@ In hybrid mode, a `SyncEngine` runs in a daemon thread with its own SQLite conne
|
||||||
|
|
||||||
| Endpoint | Method | Description |
|
| Endpoint | Method | Description |
|
||||||
|----------|--------|-------------|
|
|----------|--------|-------------|
|
||||||
| `/health` | GET | Health check |
|
| `/health` | GET | Health check (unauthenticated) |
|
||||||
|
| `/api/auth-check` | GET | Validate API key without side effects |
|
||||||
| `/api/memories` | POST | Store a memory |
|
| `/api/memories` | POST | Store a memory |
|
||||||
| `/api/memories` | GET | List memories (`?category=facts&limit=20`) |
|
| `/api/memories` | GET | List memories (`?category=facts&limit=20`) |
|
||||||
| `/api/memories/recall` | POST | Search memories by context and expanded query |
|
| `/api/memories/recall` | POST | Search memories by context and expanded query |
|
||||||
|
|
@ -463,6 +478,8 @@ Aliases `CLAUDE_MEMORY_API_URL` and `CLAUDE_MEMORY_API_KEY` are also supported.
|
||||||
|
|
||||||
## Database Migrations
|
## Database Migrations
|
||||||
|
|
||||||
|
### PostgreSQL (API Server)
|
||||||
|
|
||||||
Migrations run automatically on API server startup. To run manually:
|
Migrations run automatically on API server startup. To run manually:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
@ -477,21 +494,90 @@ Three migrations:
|
||||||
|
|
||||||
All migrations are idempotent (check column/table existence before altering).
|
All migrations are idempotent (check column/table existence before altering).
|
||||||
|
|
||||||
|
### SQLite (MCP Client)
|
||||||
|
|
||||||
|
SQLite schema is versioned via `PRAGMA user_version` and migrated automatically on startup. Current version: **2**.
|
||||||
|
|
||||||
|
| Version | Migration |
|
||||||
|
|---------|-----------|
|
||||||
|
| 1 | Add `server_id` column to `memories` table |
|
||||||
|
| 2 | Add `retry_count` column to `pending_ops` table |
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.11+
|
||||||
|
- For API server tests: `httpx`, `pytest-asyncio`
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ViktorBarzin/claude-memory-mcp.git
|
git clone https://github.com/ViktorBarzin/claude-memory-mcp.git
|
||||||
cd claude-memory-mcp
|
cd claude-memory-mcp
|
||||||
python -m venv .venv
|
python -m venv .venv
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
pip install -e ".[api,dev]"
|
pip install -e ".[api,dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# All tests
|
||||||
pytest tests/ -v
|
pytest tests/ -v
|
||||||
|
|
||||||
|
# Individual test suites
|
||||||
|
pytest tests/test_sync.py -v # SyncEngine (client-side sync resilience)
|
||||||
|
pytest tests/test_mcp_server.py -v # MCP server (SQLite, tools, protocol)
|
||||||
|
pytest tests/test_api.py -v # API server (FastAPI endpoints)
|
||||||
|
|
||||||
|
# Linting
|
||||||
ruff check src/ tests/
|
ruff check src/ tests/
|
||||||
mypy src/claude_memory/ --strict
|
mypy src/claude_memory/ --strict
|
||||||
```
|
```
|
||||||
|
|
||||||
The MCP server itself (`mcp_server.py` and `sync.py`) uses **stdlib only** — no pip install needed on the client side. The `[api]` extra adds FastAPI, asyncpg, uvicorn, etc. for the server.
|
The MCP server itself (`mcp_server.py` and `sync.py`) uses **stdlib only** — no pip install needed on the client side. The `[api]` extra adds FastAPI, asyncpg, uvicorn, etc. for the server.
|
||||||
|
|
||||||
|
### Test Structure
|
||||||
|
|
||||||
|
| File | Tests | What it covers |
|
||||||
|
|------|-------|---------------|
|
||||||
|
| `test_sync.py` | SyncEngine unit tests | Push/pull, auth failure handling, retry caps, full resync, orphan reconciliation, decoupled push/pull, diagnostics |
|
||||||
|
| `test_mcp_server.py` | MCP server unit tests | SQLite CRUD, FTS search, tool dispatch, MCP protocol, memory_count, schema migration |
|
||||||
|
| `test_api.py` | API server integration tests | All REST endpoints, auth, user isolation, soft delete, sync, secrets, import |
|
||||||
|
| `test_auth.py` | Auth module tests | Single/multi-user auth, key mapping |
|
||||||
|
| `test_credential_detector.py` | Credential detection | Pattern matching for secrets |
|
||||||
|
| `test_crypto.py` | Encryption tests | AES-256 encrypt/decrypt |
|
||||||
|
| `test_vault_client.py` | Vault integration | Secret storage/retrieval |
|
||||||
|
|
||||||
|
### Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
claude-memory-mcp/
|
||||||
|
├── src/claude_memory/
|
||||||
|
│ ├── mcp_server.py # MCP server entry point (stdio NDJSON)
|
||||||
|
│ ├── sync.py # Background sync engine (SQLite ↔ API)
|
||||||
|
│ ├── credential_detector.py # Sensitive content detection
|
||||||
|
│ └── api/
|
||||||
|
│ ├── app.py # FastAPI application
|
||||||
|
│ ├── auth.py # API key authentication
|
||||||
|
│ ├── database.py # asyncpg connection pool
|
||||||
|
│ ├── models.py # Pydantic models
|
||||||
|
│ └── vault_service.py # HashiCorp Vault integration
|
||||||
|
├── tests/ # pytest test suite
|
||||||
|
├── hooks/ # Claude Code hooks (auto-recall, auto-learn, etc.)
|
||||||
|
├── docker/ # Docker Compose for API server + PostgreSQL
|
||||||
|
├── deploy/ # Kubernetes manifests and Helm chart
|
||||||
|
└── pyproject.toml # Package config (hatchling)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Design Decisions
|
||||||
|
|
||||||
|
- **stdlib-only MCP server**: The MCP server (`mcp_server.py`, `sync.py`) uses only Python stdlib — no pip install required for the client side. This ensures it works in any Claude Code environment without dependency management.
|
||||||
|
- **NDJSON transport**: Claude Code uses NDJSON (one JSON per line) for stdio MCP, not Content-Length framing.
|
||||||
|
- **Non-blocking startup**: MCP server startup must complete in ~15s or Claude Code times out. All network calls are deferred to background threads.
|
||||||
|
- **Suppress stderr**: Any stderr output during MCP startup causes Claude Code to reject the server.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Apache License 2.0. See [LICENSE](LICENSE) for details.
|
Apache License 2.0. See [LICENSE](LICENSE) for details.
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ haiku to detect learnings worth persisting:
|
||||||
Features:
|
Features:
|
||||||
- Multi-turn context window (last 5 exchanges by default)
|
- Multi-turn context window (last 5 exchanges by default)
|
||||||
- State tracking to avoid duplicate extraction
|
- State tracking to avoid duplicate extraction
|
||||||
- Writes to memory API/SQLite AND auto-memory markdown files
|
- Writes to memory API/SQLite only
|
||||||
- Throttled deep extraction: full window every ~5 turns, single-turn otherwise
|
- Throttled deep extraction: full window every ~5 turns, single-turn otherwise
|
||||||
|
|
||||||
Runs with async: true — does NOT block the user.
|
Runs with async: true — does NOT block the user.
|
||||||
|
|
@ -252,36 +252,6 @@ def _store_via_sqlite(content, category, tags, importance, expanded_keywords):
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def _append_to_auto_memory(content: str, event_type: str) -> None:
|
|
||||||
"""Append a learning to the auto-memory markdown file for the current project."""
|
|
||||||
# Find the project memory directory based on CWD
|
|
||||||
cwd = os.getcwd()
|
|
||||||
# Claude Code stores project memory at ~/.claude/projects/<escaped-path>/memory/
|
|
||||||
escaped = cwd.replace("/", "-")
|
|
||||||
if escaped.startswith("-"):
|
|
||||||
escaped = escaped[1:] # Remove leading dash
|
|
||||||
memory_dir = Path.home() / ".claude" / "projects" / f"-{escaped}" / "memory"
|
|
||||||
|
|
||||||
if not memory_dir.exists():
|
|
||||||
# Try without the leading dash
|
|
||||||
memory_dir = Path.home() / ".claude" / "projects" / escaped / "memory"
|
|
||||||
|
|
||||||
if not memory_dir.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
auto_learn_file = memory_dir / "auto-learned.md"
|
|
||||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
header = "# Auto-Learned Knowledge\n\nAutomatically extracted by the auto-learn hook. Review periodically and promote valuable entries to MEMORY.md.\n\n"
|
|
||||||
|
|
||||||
if not auto_learn_file.exists():
|
|
||||||
auto_learn_file.write_text(header)
|
|
||||||
|
|
||||||
# Append the new learning
|
|
||||||
with open(auto_learn_file, "a") as f:
|
|
||||||
f.write(f"- [{now}] **{event_type}**: {content}\n")
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_llm_response(response_text: str) -> list[dict]:
|
def _parse_llm_response(response_text: str) -> list[dict]:
|
||||||
"""Parse LLM response text into events list."""
|
"""Parse LLM response text into events list."""
|
||||||
response_text = response_text.strip()
|
response_text = response_text.strip()
|
||||||
|
|
@ -485,12 +455,6 @@ def _store_events(events: list[dict], extracted_hashes: list[str]) -> list[str]:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Also append to auto-memory markdown
|
|
||||||
try:
|
|
||||||
_append_to_auto_memory(content, event_type)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
new_hashes.append(h)
|
new_hashes.append(h)
|
||||||
|
|
||||||
return new_hashes
|
return new_hashes
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,12 @@ async def health() -> dict[str, str]:
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/auth-check")
|
||||||
|
async def auth_check(user: AuthUser = Depends(get_current_user)) -> dict[str, str]:
|
||||||
|
"""Validate API key without doing any real work."""
|
||||||
|
return {"status": "ok", "user_id": user.user_id}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/memories/sync", response_model=SyncResponse)
|
@app.get("/api/memories/sync", response_model=SyncResponse)
|
||||||
async def sync_memories(
|
async def sync_memories(
|
||||||
since: Optional[str] = None,
|
since: Optional[str] = None,
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,41 @@ def _get_db_path(db_path: str | None = None) -> str:
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA_VERSION = 2
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_sqlite(conn: sqlite3.Connection) -> None:
|
||||||
|
"""Version-based SQLite schema migrations."""
|
||||||
|
current = conn.execute("PRAGMA user_version").fetchone()[0]
|
||||||
|
if current < 1:
|
||||||
|
# Add server_id column for hybrid mode sync
|
||||||
|
cursor = conn.execute("PRAGMA table_info(memories)")
|
||||||
|
columns = {row["name"] for row in cursor.fetchall()}
|
||||||
|
if "server_id" not in columns:
|
||||||
|
conn.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER")
|
||||||
|
conn.execute(
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
|
||||||
|
)
|
||||||
|
if current < 2:
|
||||||
|
# Ensure pending_ops has retry_count (sync.py also handles this, but belt-and-suspenders)
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS pending_ops (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
op_type TEXT NOT NULL,
|
||||||
|
payload TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
retry_count INTEGER DEFAULT 0
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
# Add retry_count if pending_ops already exists without it
|
||||||
|
cursor = conn.execute("PRAGMA table_info(pending_ops)")
|
||||||
|
po_columns = {row["name"] for row in cursor.fetchall()}
|
||||||
|
if "retry_count" not in po_columns:
|
||||||
|
conn.execute("ALTER TABLE pending_ops ADD COLUMN retry_count INTEGER DEFAULT 0")
|
||||||
|
conn.execute(f"PRAGMA user_version = {SCHEMA_VERSION}")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
|
def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
|
||||||
"""Initialize SQLite database."""
|
"""Initialize SQLite database."""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -111,14 +146,8 @@ def _init_sqlite(db_path: str | None = None) -> tuple[sqlite3.Connection, str]:
|
||||||
updated_at TEXT NOT NULL
|
updated_at TEXT NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
# Add server_id column if missing (for hybrid mode sync)
|
# Version-based schema migrations
|
||||||
cursor.execute("PRAGMA table_info(memories)")
|
_migrate_sqlite(conn)
|
||||||
columns = {row["name"] for row in cursor.fetchall()}
|
|
||||||
if "server_id" not in columns:
|
|
||||||
cursor.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER")
|
|
||||||
cursor.execute(
|
|
||||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
||||||
|
|
@ -250,6 +279,14 @@ TOOLS = [
|
||||||
"required": ["id"],
|
"required": ["id"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "memory_count",
|
||||||
|
"description": "Get memory counts by category from local cache and sync status. Useful for diagnostics.",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -418,6 +455,32 @@ class MemoryServer:
|
||||||
|
|
||||||
return self._sqlite_secret_get(memory_id)
|
return self._sqlite_secret_get(memory_id)
|
||||||
|
|
||||||
|
def memory_count(self, args: dict[str, Any]) -> str:
|
||||||
|
if self.sync_engine:
|
||||||
|
counts = self.sync_engine.get_counts()
|
||||||
|
lines = [f"Local memories: {counts['total']}"]
|
||||||
|
for cat, n in counts["by_category"].items():
|
||||||
|
lines.append(f" {cat}: {n}")
|
||||||
|
lines.append(f"Orphans (no server_id): {counts['orphans_no_server_id']}")
|
||||||
|
lines.append(f"Pending ops: {counts['pending_ops']}")
|
||||||
|
lines.append(f"Last sync: {counts['last_sync_ts'] or 'never'}")
|
||||||
|
lines.append(f"Auth failed: {counts['auth_failed']}")
|
||||||
|
lines.append(f"Last sync success: {counts['last_sync_success']}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
if self.sqlite_conn:
|
||||||
|
cursor = self.sqlite_conn.cursor()
|
||||||
|
cursor.execute("SELECT COUNT(*) as c FROM memories")
|
||||||
|
total = cursor.fetchone()["c"]
|
||||||
|
cursor.execute("SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC")
|
||||||
|
by_cat = cursor.fetchall()
|
||||||
|
lines = [f"Local memories (SQLite-only): {total}"]
|
||||||
|
for row in by_cat:
|
||||||
|
lines.append(f" {row['category']}: {row['c']}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
return "No storage available"
|
||||||
|
|
||||||
# ── SQLite methods ──────────────────────────────────────────────
|
# ── SQLite methods ──────────────────────────────────────────────
|
||||||
|
|
||||||
def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str:
|
def _sqlite_store(self, content: str, category: str, tags: str, importance: float, expanded_keywords: str, force_sensitive: bool = False) -> str:
|
||||||
|
|
@ -573,6 +636,7 @@ class MemoryServer:
|
||||||
"memory_list": self.memory_list,
|
"memory_list": self.memory_list,
|
||||||
"memory_delete": self.memory_delete,
|
"memory_delete": self.memory_delete,
|
||||||
"secret_get": self.secret_get,
|
"secret_get": self.secret_get,
|
||||||
|
"memory_count": self.memory_count,
|
||||||
}.get(tool_name)
|
}.get(tool_name)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True}
|
return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,12 @@ from pathlib import Path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Max retries before an individual pending op is permanently skipped
|
||||||
|
MAX_OP_RETRIES = 5
|
||||||
|
|
||||||
|
# Full resync every N sync cycles (~10 min at 60s interval)
|
||||||
|
FULL_RESYNC_EVERY = 10
|
||||||
|
|
||||||
|
|
||||||
class SyncEngine:
|
class SyncEngine:
|
||||||
"""Background sync between local SQLite cache and remote API."""
|
"""Background sync between local SQLite cache and remote API."""
|
||||||
|
|
@ -29,6 +35,7 @@ class SyncEngine:
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
self._thread: threading.Thread | None = None
|
self._thread: threading.Thread | None = None
|
||||||
self._last_sync_success = False
|
self._last_sync_success = False
|
||||||
|
self._auth_failed = False
|
||||||
|
|
||||||
# Own connection for thread safety
|
# Own connection for thread safety
|
||||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -48,7 +55,8 @@ class SyncEngine:
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
op_type TEXT NOT NULL,
|
op_type TEXT NOT NULL,
|
||||||
payload TEXT NOT NULL,
|
payload TEXT NOT NULL,
|
||||||
created_at TEXT NOT NULL
|
created_at TEXT NOT NULL,
|
||||||
|
retry_count INTEGER DEFAULT 0
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS sync_meta (
|
CREATE TABLE IF NOT EXISTS sync_meta (
|
||||||
|
|
@ -64,6 +72,13 @@ class SyncEngine:
|
||||||
self._conn.execute(
|
self._conn.execute(
|
||||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add retry_count column to pending_ops if missing (migration)
|
||||||
|
cursor = self._conn.execute("PRAGMA table_info(pending_ops)")
|
||||||
|
po_columns = {row["name"] for row in cursor.fetchall()}
|
||||||
|
if "retry_count" not in po_columns:
|
||||||
|
self._conn.execute("ALTER TABLE pending_ops ADD COLUMN retry_count INTEGER DEFAULT 0")
|
||||||
|
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -89,7 +104,15 @@ class SyncEngine:
|
||||||
return self._last_sync_success
|
return self._last_sync_success
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Start background sync thread (non-blocking)."""
|
"""Start background sync thread. Runs a full resync on startup."""
|
||||||
|
# Full sync on startup (blocking, before background thread)
|
||||||
|
try:
|
||||||
|
self._full_resync()
|
||||||
|
self._last_sync_success = True
|
||||||
|
self._auth_failed = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Startup full sync failed: %s", e)
|
||||||
|
|
||||||
self._thread = threading.Thread(target=self._sync_loop, daemon=True)
|
self._thread = threading.Thread(target=self._sync_loop, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
|
|
@ -102,21 +125,162 @@ class SyncEngine:
|
||||||
|
|
||||||
def _sync_loop(self) -> None:
|
def _sync_loop(self) -> None:
|
||||||
"""Periodic sync loop running in background thread."""
|
"""Periodic sync loop running in background thread."""
|
||||||
|
cycle = 0
|
||||||
while not self._stop_event.is_set():
|
while not self._stop_event.is_set():
|
||||||
self._stop_event.wait(self.sync_interval)
|
self._stop_event.wait(self.sync_interval)
|
||||||
if self._stop_event.is_set():
|
if self._stop_event.is_set():
|
||||||
break
|
break
|
||||||
|
cycle += 1
|
||||||
try:
|
try:
|
||||||
self._sync_once()
|
# If auth previously failed, try a lightweight check first
|
||||||
|
if self._auth_failed:
|
||||||
|
if not self._check_auth():
|
||||||
|
continue # Still failing, skip this cycle
|
||||||
|
|
||||||
|
if cycle % FULL_RESYNC_EVERY == 0:
|
||||||
|
self._full_resync()
|
||||||
|
else:
|
||||||
|
self._sync_once()
|
||||||
self._last_sync_success = True
|
self._last_sync_success = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Sync cycle failed: %s", e)
|
logger.warning("Sync cycle failed: %s", e)
|
||||||
self._last_sync_success = False
|
self._last_sync_success = False
|
||||||
|
|
||||||
|
def _check_auth(self) -> bool:
|
||||||
|
"""Lightweight auth check. Returns True if auth is OK."""
|
||||||
|
try:
|
||||||
|
self._api_request("GET", "/api/auth-check")
|
||||||
|
self._auth_failed = False
|
||||||
|
logger.info("Auth check passed — resuming sync")
|
||||||
|
return True
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
if e.code in (401, 403):
|
||||||
|
logger.warning(
|
||||||
|
"Auth still failing (HTTP %d) — API key mismatch. "
|
||||||
|
"Update MEMORY_API_KEY in ~/.claude.json", e.code
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
# Non-auth error (e.g. 500) — try the auth-check endpoint might not exist,
|
||||||
|
# fall back to /health
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: try /health (unauthenticated)
|
||||||
|
try:
|
||||||
|
url = f"{self.api_base_url}/health"
|
||||||
|
req = urllib.request.Request(url, method="GET")
|
||||||
|
with urllib.request.urlopen(req, timeout=5):
|
||||||
|
pass
|
||||||
|
# Server is reachable but auth-check failed — auth is still broken
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
# Server unreachable — not an auth problem
|
||||||
|
return False
|
||||||
|
|
||||||
def _sync_once(self) -> None:
|
def _sync_once(self) -> None:
|
||||||
"""Push pending ops, then pull remote changes."""
|
"""Push pending ops, then pull remote changes. Both run independently."""
|
||||||
self._push_pending_ops()
|
push_ok = self._push_pending_ops()
|
||||||
self._pull_changes()
|
pull_ok = self._pull_changes()
|
||||||
|
if not push_ok and not pull_ok:
|
||||||
|
raise RuntimeError("Both push and pull failed")
|
||||||
|
|
||||||
|
def _full_resync(self) -> None:
|
||||||
|
"""Full cache replacement from server — handles drift, deletes, schema changes."""
|
||||||
|
# Step 1: Push orphaned local-only records (deduplicated)
|
||||||
|
self._push_orphans()
|
||||||
|
|
||||||
|
# Step 2: Pull everything from server (no since filter = non-deleted only)
|
||||||
|
result = self._api_request("GET", "/api/memories/sync")
|
||||||
|
memories = result.get("memories", [])
|
||||||
|
server_time = result.get("server_time")
|
||||||
|
server_ids = {m["id"] for m in memories}
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# Delete local records whose server_id no longer exists on server
|
||||||
|
local_rows = self._conn.execute(
|
||||||
|
"SELECT id, server_id FROM memories WHERE server_id IS NOT NULL"
|
||||||
|
).fetchall()
|
||||||
|
for row in local_rows:
|
||||||
|
if row["server_id"] not in server_ids:
|
||||||
|
self._conn.execute("DELETE FROM memories WHERE id = ?", (row["id"],))
|
||||||
|
|
||||||
|
# Delete remaining orphans (already pushed or duplicates)
|
||||||
|
self._conn.execute("DELETE FROM memories WHERE server_id IS NULL")
|
||||||
|
|
||||||
|
# Upsert all server records
|
||||||
|
for mem in memories:
|
||||||
|
server_id = mem["id"]
|
||||||
|
existing = self._conn.execute(
|
||||||
|
"SELECT id FROM memories WHERE server_id = ?", (server_id,)
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
self._conn.execute(
|
||||||
|
"""UPDATE memories SET content=?, category=?, tags=?,
|
||||||
|
expanded_keywords=?, importance=?, is_sensitive=?,
|
||||||
|
updated_at=? WHERE server_id=?""",
|
||||||
|
(
|
||||||
|
mem["content"], mem["category"], mem.get("tags", ""),
|
||||||
|
mem.get("expanded_keywords", ""), mem["importance"],
|
||||||
|
1 if mem.get("is_sensitive") else 0,
|
||||||
|
mem.get("updated_at", ""), server_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._conn.execute(
|
||||||
|
"""INSERT INTO memories (content, category, tags, expanded_keywords,
|
||||||
|
importance, is_sensitive, created_at, updated_at, server_id)
|
||||||
|
VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||||
|
(
|
||||||
|
mem["content"], mem["category"], mem.get("tags", ""),
|
||||||
|
mem.get("expanded_keywords", ""), mem["importance"],
|
||||||
|
1 if mem.get("is_sensitive") else 0,
|
||||||
|
mem.get("created_at", ""), mem.get("updated_at", ""), server_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
|
if server_time:
|
||||||
|
self.last_sync_ts = server_time
|
||||||
|
|
||||||
|
def _push_orphans(self) -> None:
|
||||||
|
"""Push local-only records to server, skipping content duplicates."""
|
||||||
|
with self._lock:
|
||||||
|
orphans = self._conn.execute(
|
||||||
|
"SELECT id, content, category, tags, expanded_keywords, importance "
|
||||||
|
"FROM memories WHERE server_id IS NULL"
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if not orphans:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get all server content for dedup comparison
|
||||||
|
result = self._api_request("GET", "/api/memories/sync")
|
||||||
|
server_contents = {m["content"] for m in result.get("memories", [])}
|
||||||
|
|
||||||
|
for orphan in orphans:
|
||||||
|
if orphan["content"] in server_contents:
|
||||||
|
continue # Skip duplicate
|
||||||
|
try:
|
||||||
|
resp = self._api_request("POST", "/api/memories", {
|
||||||
|
"content": orphan["content"],
|
||||||
|
"category": orphan["category"],
|
||||||
|
"tags": orphan["tags"],
|
||||||
|
"expanded_keywords": orphan["expanded_keywords"],
|
||||||
|
"importance": orphan["importance"],
|
||||||
|
})
|
||||||
|
server_id = resp.get("id")
|
||||||
|
if server_id:
|
||||||
|
with self._lock:
|
||||||
|
self._conn.execute(
|
||||||
|
"UPDATE memories SET server_id=? WHERE id=?",
|
||||||
|
(server_id, orphan["id"]),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
pass # Will be cleaned up by the full resync delete step
|
||||||
|
|
||||||
def _api_request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
|
def _api_request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
"""Make an HTTP request to the memory API."""
|
"""Make an HTTP request to the memory API."""
|
||||||
|
|
@ -131,22 +295,47 @@ class SyncEngine:
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
try:
|
||||||
result: dict[str, Any] = json.loads(resp.read().decode())
|
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||||
return result
|
result: dict[str, Any] = json.loads(resp.read().decode())
|
||||||
|
return result
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
if e.code in (401, 403):
|
||||||
|
self._auth_failed = True
|
||||||
|
logger.warning(
|
||||||
|
"Auth failed (HTTP %d) — API key may have rotated. "
|
||||||
|
"Update MEMORY_API_KEY in ~/.claude.json", e.code
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def _push_pending_ops(self) -> None:
|
def _push_pending_ops(self) -> bool:
|
||||||
"""Push queued operations to the API server."""
|
"""Push queued operations to the API server. Returns True on success."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
cursor = self._conn.execute(
|
cursor = self._conn.execute(
|
||||||
"SELECT id, op_type, payload FROM pending_ops ORDER BY id"
|
"SELECT id, op_type, payload, retry_count FROM pending_ops ORDER BY id"
|
||||||
)
|
)
|
||||||
ops = cursor.fetchall()
|
ops = cursor.fetchall()
|
||||||
|
|
||||||
|
if not ops:
|
||||||
|
return True
|
||||||
|
|
||||||
|
all_ok = True
|
||||||
for op in ops:
|
for op in ops:
|
||||||
op_id = op["id"]
|
op_id = op["id"]
|
||||||
op_type = op["op_type"]
|
op_type = op["op_type"]
|
||||||
payload = json.loads(op["payload"])
|
payload = json.loads(op["payload"])
|
||||||
|
retry_count = op["retry_count"] or 0
|
||||||
|
|
||||||
|
# Skip ops that have exceeded retry limit
|
||||||
|
if retry_count >= MAX_OP_RETRIES:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping op %d (%s) after %d retries — removing from queue",
|
||||||
|
op_id, op_type, retry_count,
|
||||||
|
)
|
||||||
|
with self._lock:
|
||||||
|
self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
|
||||||
|
self._conn.commit()
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if op_type == "store":
|
if op_type == "store":
|
||||||
|
|
@ -164,8 +353,8 @@ class SyncEngine:
|
||||||
if server_id:
|
if server_id:
|
||||||
try:
|
try:
|
||||||
self._api_request("DELETE", f"/api/memories/{server_id}")
|
self._api_request("DELETE", f"/api/memories/{server_id}")
|
||||||
except (RuntimeError, urllib.error.HTTPError) as e:
|
except urllib.error.HTTPError as e:
|
||||||
if "404" in str(e):
|
if e.code == 404:
|
||||||
pass # Already deleted on server
|
pass # Already deleted on server
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
@ -175,75 +364,103 @@ class SyncEngine:
|
||||||
self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
|
self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
except Exception as e:
|
except urllib.error.HTTPError as e:
|
||||||
logger.warning("Failed to push op %d (%s): %s", op_id, op_type, e)
|
if e.code in (401, 403):
|
||||||
raise # Propagate to mark sync as failed
|
self._auth_failed = True
|
||||||
|
logger.warning("Auth failed (HTTP %d) — aborting push", e.code)
|
||||||
def _pull_changes(self) -> None:
|
return False # Abort entire push — no point retrying with bad key
|
||||||
"""Pull changes from server since last sync."""
|
# Increment retry count for non-auth errors
|
||||||
params = ""
|
with self._lock:
|
||||||
ts = self.last_sync_ts
|
|
||||||
if ts:
|
|
||||||
params = f"?since={urllib.parse.quote(ts, safe='')}"
|
|
||||||
|
|
||||||
result = self._api_request("GET", f"/api/memories/sync{params}")
|
|
||||||
memories = result.get("memories", [])
|
|
||||||
server_time = result.get("server_time")
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
for mem in memories:
|
|
||||||
server_id = mem["id"]
|
|
||||||
deleted_at = mem.get("deleted_at")
|
|
||||||
|
|
||||||
if deleted_at:
|
|
||||||
# Remove from local cache
|
|
||||||
self._conn.execute(
|
self._conn.execute(
|
||||||
"DELETE FROM memories WHERE server_id = ?", (server_id,)
|
"UPDATE pending_ops SET retry_count = retry_count + 1 WHERE id = ?",
|
||||||
|
(op_id,),
|
||||||
)
|
)
|
||||||
else:
|
self._conn.commit()
|
||||||
# Upsert by server_id (server wins)
|
logger.warning("Failed to push op %d (%s): HTTP %d", op_id, op_type, e.code)
|
||||||
existing = self._conn.execute(
|
all_ok = False
|
||||||
"SELECT id FROM memories WHERE server_id = ?", (server_id,)
|
except Exception as e:
|
||||||
).fetchone()
|
with self._lock:
|
||||||
|
self._conn.execute(
|
||||||
|
"UPDATE pending_ops SET retry_count = retry_count + 1 WHERE id = ?",
|
||||||
|
(op_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
logger.warning("Failed to push op %d (%s): %s", op_id, op_type, e)
|
||||||
|
all_ok = False
|
||||||
|
|
||||||
if existing:
|
return all_ok
|
||||||
|
|
||||||
|
def _pull_changes(self) -> bool:
|
||||||
|
"""Pull changes from server since last sync. Returns True on success."""
|
||||||
|
try:
|
||||||
|
params = ""
|
||||||
|
ts = self.last_sync_ts
|
||||||
|
if ts:
|
||||||
|
params = f"?since={urllib.parse.quote(ts, safe='')}"
|
||||||
|
|
||||||
|
result = self._api_request("GET", f"/api/memories/sync{params}")
|
||||||
|
memories = result.get("memories", [])
|
||||||
|
server_time = result.get("server_time")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
for mem in memories:
|
||||||
|
server_id = mem["id"]
|
||||||
|
deleted_at = mem.get("deleted_at")
|
||||||
|
|
||||||
|
if deleted_at:
|
||||||
|
# Remove from local cache
|
||||||
self._conn.execute(
|
self._conn.execute(
|
||||||
"""UPDATE memories SET content = ?, category = ?, tags = ?,
|
"DELETE FROM memories WHERE server_id = ?", (server_id,)
|
||||||
expanded_keywords = ?, importance = ?, is_sensitive = ?,
|
|
||||||
updated_at = ? WHERE server_id = ?""",
|
|
||||||
(
|
|
||||||
mem["content"],
|
|
||||||
mem["category"],
|
|
||||||
mem.get("tags", ""),
|
|
||||||
mem.get("expanded_keywords", ""),
|
|
||||||
mem["importance"],
|
|
||||||
1 if mem.get("is_sensitive") else 0,
|
|
||||||
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
|
||||||
server_id,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._conn.execute(
|
# Upsert by server_id (server wins)
|
||||||
"""INSERT INTO memories
|
existing = self._conn.execute(
|
||||||
(content, category, tags, expanded_keywords, importance,
|
"SELECT id FROM memories WHERE server_id = ?", (server_id,)
|
||||||
is_sensitive, created_at, updated_at, server_id)
|
).fetchone()
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
|
||||||
(
|
|
||||||
mem["content"],
|
|
||||||
mem["category"],
|
|
||||||
mem.get("tags", ""),
|
|
||||||
mem.get("expanded_keywords", ""),
|
|
||||||
mem["importance"],
|
|
||||||
1 if mem.get("is_sensitive") else 0,
|
|
||||||
mem.get("created_at", datetime.now(timezone.utc).isoformat()),
|
|
||||||
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
|
||||||
server_id,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
|
|
||||||
if server_time:
|
if existing:
|
||||||
self.last_sync_ts = server_time
|
self._conn.execute(
|
||||||
|
"""UPDATE memories SET content = ?, category = ?, tags = ?,
|
||||||
|
expanded_keywords = ?, importance = ?, is_sensitive = ?,
|
||||||
|
updated_at = ? WHERE server_id = ?""",
|
||||||
|
(
|
||||||
|
mem["content"],
|
||||||
|
mem["category"],
|
||||||
|
mem.get("tags", ""),
|
||||||
|
mem.get("expanded_keywords", ""),
|
||||||
|
mem["importance"],
|
||||||
|
1 if mem.get("is_sensitive") else 0,
|
||||||
|
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
||||||
|
server_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._conn.execute(
|
||||||
|
"""INSERT INTO memories
|
||||||
|
(content, category, tags, expanded_keywords, importance,
|
||||||
|
is_sensitive, created_at, updated_at, server_id)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
|
(
|
||||||
|
mem["content"],
|
||||||
|
mem["category"],
|
||||||
|
mem.get("tags", ""),
|
||||||
|
mem.get("expanded_keywords", ""),
|
||||||
|
mem["importance"],
|
||||||
|
1 if mem.get("is_sensitive") else 0,
|
||||||
|
mem.get("created_at", datetime.now(timezone.utc).isoformat()),
|
||||||
|
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
||||||
|
server_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
|
if server_time:
|
||||||
|
self.last_sync_ts = server_time
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Pull changes failed: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
def enqueue_store(
|
def enqueue_store(
|
||||||
self,
|
self,
|
||||||
|
|
@ -295,6 +512,11 @@ class SyncEngine:
|
||||||
force_sensitive: bool = False,
|
force_sensitive: bool = False,
|
||||||
) -> int | None:
|
) -> int | None:
|
||||||
"""Try to sync a store immediately. Returns server_id or None if failed."""
|
"""Try to sync a store immediately. Returns server_id or None if failed."""
|
||||||
|
if self._auth_failed:
|
||||||
|
self.enqueue_store(
|
||||||
|
local_id, content, category, tags, expanded_keywords, importance, force_sensitive
|
||||||
|
)
|
||||||
|
return None
|
||||||
try:
|
try:
|
||||||
result = self._api_request("POST", "/api/memories", {
|
result = self._api_request("POST", "/api/memories", {
|
||||||
"content": content,
|
"content": content,
|
||||||
|
|
@ -321,6 +543,9 @@ class SyncEngine:
|
||||||
|
|
||||||
def try_sync_delete(self, server_id: int) -> bool:
|
def try_sync_delete(self, server_id: int) -> bool:
|
||||||
"""Try to sync a delete immediately. Returns True if successful."""
|
"""Try to sync a delete immediately. Returns True if successful."""
|
||||||
|
if self._auth_failed:
|
||||||
|
self.enqueue_delete(server_id)
|
||||||
|
return False
|
||||||
try:
|
try:
|
||||||
self._api_request("DELETE", f"/api/memories/{server_id}")
|
self._api_request("DELETE", f"/api/memories/{server_id}")
|
||||||
return True
|
return True
|
||||||
|
|
@ -332,3 +557,27 @@ class SyncEngine:
|
||||||
except Exception:
|
except Exception:
|
||||||
self.enqueue_delete(server_id)
|
self.enqueue_delete(server_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_counts(self) -> dict[str, Any]:
|
||||||
|
"""Get memory counts for diagnostics."""
|
||||||
|
with self._lock:
|
||||||
|
total = self._conn.execute("SELECT COUNT(*) as c FROM memories").fetchone()["c"]
|
||||||
|
by_cat = self._conn.execute(
|
||||||
|
"SELECT category, COUNT(*) as c FROM memories GROUP BY category ORDER BY c DESC"
|
||||||
|
).fetchall()
|
||||||
|
orphans = self._conn.execute(
|
||||||
|
"SELECT COUNT(*) as c FROM memories WHERE server_id IS NULL"
|
||||||
|
).fetchone()["c"]
|
||||||
|
pending = self._conn.execute(
|
||||||
|
"SELECT COUNT(*) as c FROM pending_ops"
|
||||||
|
).fetchone()["c"]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"by_category": {row["category"]: row["c"] for row in by_cat},
|
||||||
|
"orphans_no_server_id": orphans,
|
||||||
|
"pending_ops": pending,
|
||||||
|
"last_sync_ts": self.last_sync_ts,
|
||||||
|
"auth_failed": self._auth_failed,
|
||||||
|
"last_sync_success": self._last_sync_success,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,20 @@ async def test_health_endpoint_no_auth(client):
|
||||||
assert resp.json() == {"status": "ok"}
|
assert resp.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_check_endpoint(client):
|
||||||
|
ac, conn, app_mod = client
|
||||||
|
async with ac:
|
||||||
|
resp = await ac.get(
|
||||||
|
"/api/auth-check",
|
||||||
|
headers={"Authorization": "Bearer test-key"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["user_id"] == "testuser"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_memory_creates_record_with_user_id(client):
|
async def test_store_memory_creates_record_with_user_id(client):
|
||||||
ac, conn, app_mod = client
|
ac, conn, app_mod = client
|
||||||
|
|
|
||||||
|
|
@ -238,9 +238,9 @@ class TestMCPProtocol:
|
||||||
def test_handle_tools_list(self, server):
|
def test_handle_tools_list(self, server):
|
||||||
result = server.handle_tools_list({})
|
result = server.handle_tools_list({})
|
||||||
tools = result["tools"]
|
tools = result["tools"]
|
||||||
assert len(tools) == 5
|
assert len(tools) == 6
|
||||||
names = {t["name"] for t in tools}
|
names = {t["name"] for t in tools}
|
||||||
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"}
|
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get", "memory_count"}
|
||||||
|
|
||||||
def test_handle_tools_call_store(self, server):
|
def test_handle_tools_call_store(self, server):
|
||||||
result = server.handle_tools_call({
|
result = server.handle_tools_call({
|
||||||
|
|
@ -291,7 +291,7 @@ class TestProcessMessage:
|
||||||
"params": {},
|
"params": {},
|
||||||
})
|
})
|
||||||
assert "result" in response
|
assert "result" in response
|
||||||
assert len(response["result"]["tools"]) == 5
|
assert len(response["result"]["tools"]) == 6
|
||||||
|
|
||||||
def test_tools_call(self, server):
|
def test_tools_call(self, server):
|
||||||
response = server.process_message({
|
response = server.process_message({
|
||||||
|
|
@ -340,3 +340,71 @@ class TestProcessMessage:
|
||||||
parsed = json.loads(serialized)
|
parsed = json.loads(serialized)
|
||||||
assert parsed["jsonrpc"] == "2.0"
|
assert parsed["jsonrpc"] == "2.0"
|
||||||
assert parsed["id"] == 5
|
assert parsed["id"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryCount:
|
||||||
|
def test_count_empty(self, server):
|
||||||
|
result = server.memory_count({})
|
||||||
|
assert "0" in result
|
||||||
|
|
||||||
|
def test_count_after_store(self, server):
|
||||||
|
server.memory_store({
|
||||||
|
"content": "test memory",
|
||||||
|
"expanded_keywords": "test memory keywords data",
|
||||||
|
})
|
||||||
|
result = server.memory_count({})
|
||||||
|
assert "1" in result
|
||||||
|
assert "facts" in result
|
||||||
|
|
||||||
|
def test_count_multiple_categories(self, server):
|
||||||
|
server.memory_store({
|
||||||
|
"content": "a fact",
|
||||||
|
"category": "facts",
|
||||||
|
"expanded_keywords": "fact test data words",
|
||||||
|
})
|
||||||
|
server.memory_store({
|
||||||
|
"content": "a preference",
|
||||||
|
"category": "preferences",
|
||||||
|
"expanded_keywords": "preference test data words",
|
||||||
|
})
|
||||||
|
result = server.memory_count({})
|
||||||
|
assert "facts: 1" in result
|
||||||
|
assert "preferences: 1" in result
|
||||||
|
|
||||||
|
def test_count_via_tools_call(self, server):
|
||||||
|
result = server.handle_tools_call({
|
||||||
|
"name": "memory_count",
|
||||||
|
"arguments": {},
|
||||||
|
})
|
||||||
|
assert not result.get("isError", False)
|
||||||
|
assert "0" in result["content"][0]["text"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchemaMigration:
|
||||||
|
def test_schema_version_set(self, tmp_path):
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
srv = MemoryServer(sqlite_db_path=db_path)
|
||||||
|
cursor = srv.sqlite_conn.cursor()
|
||||||
|
version = cursor.execute("PRAGMA user_version").fetchone()[0]
|
||||||
|
assert version == 2
|
||||||
|
srv.sqlite_conn.close()
|
||||||
|
|
||||||
|
def test_migration_idempotent(self, tmp_path):
|
||||||
|
"""Running _init_sqlite twice should not error."""
|
||||||
|
from claude_memory.mcp_server import _init_sqlite
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
conn1, _ = _init_sqlite(db_path)
|
||||||
|
conn1.close()
|
||||||
|
conn2, _ = _init_sqlite(db_path)
|
||||||
|
version = conn2.execute("PRAGMA user_version").fetchone()[0]
|
||||||
|
assert version == 2
|
||||||
|
conn2.close()
|
||||||
|
|
||||||
|
def test_server_id_column_exists(self, tmp_path):
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
srv = MemoryServer(sqlite_db_path=db_path)
|
||||||
|
cursor = srv.sqlite_conn.cursor()
|
||||||
|
cursor.execute("PRAGMA table_info(memories)")
|
||||||
|
columns = {row["name"] for row in cursor.fetchall()}
|
||||||
|
assert "server_id" in columns
|
||||||
|
srv.sqlite_conn.close()
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,9 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import urllib.error
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -154,21 +155,25 @@ class TestPushPendingOps:
|
||||||
"""A 404 on delete means already deleted on server — should still clear queue."""
|
"""A 404 on delete means already deleted on server — should still clear queue."""
|
||||||
engine.enqueue_delete(42)
|
engine.enqueue_delete(42)
|
||||||
|
|
||||||
|
import urllib.error
|
||||||
with patch.object(engine, "_api_request") as mock_api:
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
mock_api.side_effect = RuntimeError("API error 404: not found")
|
mock_api.side_effect = urllib.error.HTTPError(
|
||||||
|
url="http://fake", code=404, msg="Not Found", hdrs=None, fp=None
|
||||||
|
)
|
||||||
engine._push_pending_ops()
|
engine._push_pending_ops()
|
||||||
|
|
||||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||||
assert cursor.fetchone()["cnt"] == 0
|
assert cursor.fetchone()["cnt"] == 0
|
||||||
|
|
||||||
def test_push_failure_keeps_queue(self, engine):
|
def test_push_failure_keeps_queue_returns_false(self, engine):
|
||||||
|
"""Push failure should keep the op in queue and return False (not raise)."""
|
||||||
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
with patch.object(engine, "_api_request") as mock_api:
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
mock_api.side_effect = RuntimeError("Connection refused")
|
mock_api.side_effect = RuntimeError("Connection refused")
|
||||||
with pytest.raises(RuntimeError):
|
result = engine._push_pending_ops()
|
||||||
engine._push_pending_ops()
|
|
||||||
|
|
||||||
|
assert result is False
|
||||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||||
assert cursor.fetchone()["cnt"] == 1
|
assert cursor.fetchone()["cnt"] == 1
|
||||||
|
|
||||||
|
|
@ -393,3 +398,361 @@ class TestFullSyncCycle:
|
||||||
# Should be gone locally
|
# Should be gone locally
|
||||||
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 500")
|
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 500")
|
||||||
assert cursor.fetchone() is None
|
assert cursor.fetchone() is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthFailureHandling:
|
||||||
|
def test_auth_flag_set_on_401(self, engine):
|
||||||
|
"""401 from _api_request should set _auth_failed flag."""
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.side_effect = urllib.error.HTTPError(
|
||||||
|
url="http://fake", code=401, msg="Unauthorized", hdrs=None, fp=None
|
||||||
|
)
|
||||||
|
result = engine._push_pending_ops()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert engine._auth_failed is True
|
||||||
|
|
||||||
|
def test_auth_flag_set_on_403(self, engine):
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.side_effect = urllib.error.HTTPError(
|
||||||
|
url="http://fake", code=403, msg="Forbidden", hdrs=None, fp=None
|
||||||
|
)
|
||||||
|
result = engine._push_pending_ops()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert engine._auth_failed is True
|
||||||
|
|
||||||
|
def test_push_aborts_on_auth_failure(self, engine):
|
||||||
|
"""On 401, push should abort immediately — no further ops attempted."""
|
||||||
|
engine.enqueue_store(1, "test1", "facts", "", "kw", 0.5)
|
||||||
|
engine.enqueue_store(2, "test2", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.side_effect = urllib.error.HTTPError(
|
||||||
|
url="http://fake", code=401, msg="Unauthorized", hdrs=None, fp=None
|
||||||
|
)
|
||||||
|
engine._push_pending_ops()
|
||||||
|
|
||||||
|
# Both ops should still be in queue (aborted before processing second)
|
||||||
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||||
|
assert cursor.fetchone()["cnt"] == 2
|
||||||
|
|
||||||
|
def test_try_sync_store_queues_when_auth_failed(self, engine):
|
||||||
|
"""When auth is failed, try_sync_store should queue without attempting API call."""
|
||||||
|
engine._auth_failed = True
|
||||||
|
|
||||||
|
result = engine.try_sync_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||||
|
assert cursor.fetchone()["cnt"] == 1
|
||||||
|
|
||||||
|
def test_try_sync_delete_queues_when_auth_failed(self, engine):
|
||||||
|
engine._auth_failed = True
|
||||||
|
|
||||||
|
result = engine.try_sync_delete(42)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||||
|
assert cursor.fetchone()["cnt"] == 1
|
||||||
|
|
||||||
|
def test_check_auth_clears_flag_on_success(self, engine):
|
||||||
|
engine._auth_failed = True
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.return_value = {"status": "ok", "user_id": "test"}
|
||||||
|
result = engine._check_auth()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert engine._auth_failed is False
|
||||||
|
|
||||||
|
def test_check_auth_stays_failed_on_401(self, engine):
|
||||||
|
engine._auth_failed = True
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.side_effect = urllib.error.HTTPError(
|
||||||
|
url="http://fake", code=401, msg="Unauthorized", hdrs=None, fp=None
|
||||||
|
)
|
||||||
|
# Also mock urlopen for /health fallback
|
||||||
|
with patch("urllib.request.urlopen") as mock_urlopen:
|
||||||
|
mock_urlopen.return_value.__enter__ = MagicMock()
|
||||||
|
mock_urlopen.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
result = engine._check_auth()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert engine._auth_failed is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryCount:
|
||||||
|
def test_retry_count_incremented_on_failure(self, engine):
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.side_effect = RuntimeError("Connection refused")
|
||||||
|
engine._push_pending_ops()
|
||||||
|
|
||||||
|
cursor = engine._conn.execute("SELECT retry_count FROM pending_ops WHERE id = 1")
|
||||||
|
assert cursor.fetchone()["retry_count"] == 1
|
||||||
|
|
||||||
|
def test_op_skipped_after_max_retries(self, engine):
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
# Set retry_count to max
|
||||||
|
engine._conn.execute("UPDATE pending_ops SET retry_count = 5 WHERE id = 1")
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
result = engine._push_pending_ops()
|
||||||
|
|
||||||
|
# Op should be deleted (skipped), API never called
|
||||||
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||||
|
assert cursor.fetchone()["cnt"] == 0
|
||||||
|
mock_api.assert_not_called()
|
||||||
|
|
||||||
|
def test_retry_count_persists_across_pushes(self, engine):
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.side_effect = RuntimeError("fail")
|
||||||
|
engine._push_pending_ops()
|
||||||
|
engine._push_pending_ops()
|
||||||
|
engine._push_pending_ops()
|
||||||
|
|
||||||
|
cursor = engine._conn.execute("SELECT retry_count FROM pending_ops WHERE id = 1")
|
||||||
|
assert cursor.fetchone()["retry_count"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecoupledPushPull:
|
||||||
|
def test_pull_runs_even_when_push_fails(self, engine):
|
||||||
|
"""Pull should execute even if push fails — they're decoupled."""
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_api(method, path, body=None):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if "POST" == method:
|
||||||
|
raise RuntimeError("Push failed")
|
||||||
|
# GET for pull
|
||||||
|
return {
|
||||||
|
"memories": [{
|
||||||
|
"id": 99, "content": "from server", "category": "facts",
|
||||||
|
"tags": "", "expanded_keywords": "", "importance": 0.5,
|
||||||
|
"is_sensitive": False, "created_at": now, "updated_at": now,
|
||||||
|
"deleted_at": None,
|
||||||
|
}],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request", side_effect=mock_api):
|
||||||
|
engine._sync_once()
|
||||||
|
|
||||||
|
# Pull should have inserted the server memory
|
||||||
|
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 99")
|
||||||
|
assert cursor.fetchone() is not None
|
||||||
|
|
||||||
|
def test_sync_once_returns_normally_on_partial_failure(self, engine):
|
||||||
|
"""If push fails but pull succeeds, _sync_once should not raise."""
|
||||||
|
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||||
|
|
||||||
|
def mock_api(method, path, body=None):
|
||||||
|
if method == "POST":
|
||||||
|
raise RuntimeError("Push failed")
|
||||||
|
return {"memories": [], "server_time": "2026-03-16T12:00:00+00:00"}
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request", side_effect=mock_api):
|
||||||
|
# Should not raise
|
||||||
|
engine._sync_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullResync:
|
||||||
|
def test_full_resync_inserts_server_records(self, engine):
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.return_value = {
|
||||||
|
"memories": [
|
||||||
|
{"id": 1, "content": "server mem 1", "category": "facts",
|
||||||
|
"tags": "", "expanded_keywords": "", "importance": 0.5,
|
||||||
|
"is_sensitive": False, "created_at": now, "updated_at": now},
|
||||||
|
{"id": 2, "content": "server mem 2", "category": "projects",
|
||||||
|
"tags": "", "expanded_keywords": "", "importance": 0.8,
|
||||||
|
"is_sensitive": False, "created_at": now, "updated_at": now},
|
||||||
|
],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
engine._full_resync()
|
||||||
|
|
||||||
|
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM memories")
|
||||||
|
assert cursor.fetchone()["cnt"] == 2
|
||||||
|
|
||||||
|
def test_full_resync_removes_stale_local_records(self, engine):
|
||||||
|
"""Local records with server_ids not on server should be deleted."""
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
# Insert a local record with server_id=999 (not on server)
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at, server_id) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
("stale", "facts", "", "", 0.5, 0, now, now, 999),
|
||||||
|
)
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.return_value = {
|
||||||
|
"memories": [
|
||||||
|
{"id": 1, "content": "current", "category": "facts",
|
||||||
|
"tags": "", "expanded_keywords": "", "importance": 0.5,
|
||||||
|
"is_sensitive": False, "created_at": now, "updated_at": now},
|
||||||
|
],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
engine._full_resync()
|
||||||
|
|
||||||
|
# Stale record should be gone
|
||||||
|
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 999")
|
||||||
|
assert cursor.fetchone() is None
|
||||||
|
# Current record should exist
|
||||||
|
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 1")
|
||||||
|
assert cursor.fetchone() is not None
|
||||||
|
|
||||||
|
def test_full_resync_deletes_orphans_after_push(self, engine):
|
||||||
|
"""Orphans (server_id IS NULL) should be cleaned up after push attempt."""
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?)",
|
||||||
|
("orphan", "facts", "", "", 0.5, 0, now, now),
|
||||||
|
)
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.return_value = {
|
||||||
|
"memories": [],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
engine._full_resync()
|
||||||
|
|
||||||
|
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id IS NULL")
|
||||||
|
assert cursor.fetchone() is None
|
||||||
|
|
||||||
|
def test_full_resync_updates_last_sync_ts(self, engine):
|
||||||
|
server_time = "2026-03-16T15:00:00+00:00"
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.return_value = {"memories": [], "server_time": server_time}
|
||||||
|
engine._full_resync()
|
||||||
|
|
||||||
|
assert engine.last_sync_ts == server_time
|
||||||
|
|
||||||
|
def test_full_resync_updates_existing_records(self, engine):
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at, server_id) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
("old content", "facts", "", "", 0.5, 0, now, now, 10),
|
||||||
|
)
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request") as mock_api:
|
||||||
|
mock_api.return_value = {
|
||||||
|
"memories": [
|
||||||
|
{"id": 10, "content": "new content", "category": "projects",
|
||||||
|
"tags": "updated", "expanded_keywords": "", "importance": 0.9,
|
||||||
|
"is_sensitive": False, "created_at": now, "updated_at": now},
|
||||||
|
],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
engine._full_resync()
|
||||||
|
|
||||||
|
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 10")
|
||||||
|
row = cursor.fetchone()
|
||||||
|
assert row["content"] == "new content"
|
||||||
|
assert row["category"] == "projects"
|
||||||
|
assert row["importance"] == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
class TestPushOrphans:
|
||||||
|
def test_push_orphans_skips_duplicates(self, engine):
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
# Insert orphan with content matching server
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?)",
|
||||||
|
("duplicate content", "facts", "", "", 0.5, 0, now, now),
|
||||||
|
)
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
call_log = []
|
||||||
|
|
||||||
|
def mock_api(method, path, body=None):
|
||||||
|
call_log.append((method, path))
|
||||||
|
return {
|
||||||
|
"memories": [{"id": 1, "content": "duplicate content", "category": "facts",
|
||||||
|
"tags": "", "expanded_keywords": "", "importance": 0.5,
|
||||||
|
"is_sensitive": False, "created_at": now, "updated_at": now}],
|
||||||
|
"server_time": now,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request", side_effect=mock_api):
|
||||||
|
engine._push_orphans()
|
||||||
|
|
||||||
|
# Should have called GET for sync but NOT POST (duplicate skipped)
|
||||||
|
assert all(m != "POST" for m, _ in call_log)
|
||||||
|
|
||||||
|
def test_push_orphans_posts_unique(self, engine):
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
(1, "unique content", "facts", "", "", 0.5, 0, now, now),
|
||||||
|
)
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
def mock_api(method, path, body=None):
|
||||||
|
if method == "GET":
|
||||||
|
return {"memories": [], "server_time": now}
|
||||||
|
if method == "POST":
|
||||||
|
return {"id": 100, "category": "facts", "importance": 0.5}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
with patch.object(engine, "_api_request", side_effect=mock_api):
|
||||||
|
engine._push_orphans()
|
||||||
|
|
||||||
|
# Orphan should now have server_id
|
||||||
|
cursor = engine._conn.execute("SELECT server_id FROM memories WHERE id = 1")
|
||||||
|
assert cursor.fetchone()["server_id"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCounts:
|
||||||
|
def test_empty_counts(self, engine):
|
||||||
|
counts = engine.get_counts()
|
||||||
|
assert counts["total"] == 0
|
||||||
|
assert counts["by_category"] == {}
|
||||||
|
assert counts["orphans_no_server_id"] == 0
|
||||||
|
assert counts["pending_ops"] == 0
|
||||||
|
assert counts["auth_failed"] is False
|
||||||
|
|
||||||
|
def test_counts_with_data(self, engine):
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at, server_id) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
("mem1", "facts", "", "", 0.5, 0, now, now, 1),
|
||||||
|
)
|
||||||
|
engine._conn.execute(
|
||||||
|
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, "
|
||||||
|
"is_sensitive, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?)",
|
||||||
|
("orphan", "projects", "", "", 0.5, 0, now, now),
|
||||||
|
)
|
||||||
|
engine.enqueue_store(99, "queued", "facts", "", "", 0.5)
|
||||||
|
engine._conn.commit()
|
||||||
|
|
||||||
|
counts = engine.get_counts()
|
||||||
|
assert counts["total"] == 2
|
||||||
|
assert counts["by_category"]["facts"] == 1
|
||||||
|
assert counts["by_category"]["projects"] == 1
|
||||||
|
assert counts["orphans_no_server_id"] == 1
|
||||||
|
assert counts["pending_ops"] == 1
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue