feat: add local SQLite cache with background sync and HA deployment
- Add SyncEngine for background sync between local SQLite cache and remote API with pending_ops queue for offline resilience - Refactor MCP server to support three modes: SQLite-only, hybrid (local cache + sync, new default), and HTTP-only (legacy) - Add GET /api/memories/sync endpoint for incremental sync - Change DELETE to soft delete (set deleted_at) for sync support - Add deleted_at IS NULL filters to all read queries - Scale API deployment to 2 replicas with pod anti-affinity, PDB, and startup probe for high availability - Add migration 003 for deleted_at column and updated_at index - Add comprehensive tests for sync engine and API sync endpoint
This commit is contained in:
parent
fe55ac634b
commit
cd80a67dfa
7 changed files with 1133 additions and 110 deletions
40
migrations/versions/003_add_soft_delete_and_sync.py
Normal file
40
migrations/versions/003_add_soft_delete_and_sync.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Add soft delete and sync support.
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2026-03-14
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "003"
|
||||
down_revision: Union[str, None] = "002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _column_exists(conn, column_name: str) -> bool:
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS(SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = 'memories' AND column_name = :col)"
|
||||
),
|
||||
{"col": column_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
if not _column_exists(conn, "deleted_at"):
|
||||
op.add_column("memories", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
op.execute("CREATE INDEX IF NOT EXISTS idx_memories_updated ON memories(updated_at)")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_memories_updated")
|
||||
op.drop_column("memories", "deleted_at")
|
||||
|
|
@ -2,13 +2,14 @@
|
|||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
|
||||
from claude_memory.api.auth import AuthUser, get_current_user
|
||||
from claude_memory.api.database import close_pool, get_pool, init_pool
|
||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse
|
||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
|
||||
from claude_memory.api.vault_service import (
|
||||
delete_secret,
|
||||
get_secret,
|
||||
|
|
@ -58,6 +59,58 @@ async def health():
|
|||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/api/memories/sync", response_model=SyncResponse)
|
||||
async def sync_memories(
|
||||
since: Optional[str] = None,
|
||||
user: AuthUser = Depends(get_current_user),
|
||||
):
|
||||
pool = await get_pool()
|
||||
server_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
if since:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, content, category, tags, expanded_keywords, importance,
|
||||
is_sensitive, created_at, updated_at, deleted_at
|
||||
FROM memories
|
||||
WHERE user_id = $1 AND updated_at > $2::timestamptz
|
||||
ORDER BY updated_at ASC
|
||||
""",
|
||||
user.user_id,
|
||||
since,
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, content, category, tags, expanded_keywords, importance,
|
||||
is_sensitive, created_at, updated_at, deleted_at
|
||||
FROM memories
|
||||
WHERE user_id = $1 AND deleted_at IS NULL
|
||||
ORDER BY updated_at ASC
|
||||
""",
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
memories = []
|
||||
for row in rows:
|
||||
mem = {
|
||||
"id": row["id"],
|
||||
"content": row["content"],
|
||||
"category": row["category"],
|
||||
"tags": row["tags"],
|
||||
"expanded_keywords": row["expanded_keywords"],
|
||||
"importance": row["importance"],
|
||||
"is_sensitive": row["is_sensitive"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
"deleted_at": row["deleted_at"].isoformat() if row["deleted_at"] else None,
|
||||
}
|
||||
memories.append(mem)
|
||||
|
||||
return SyncResponse(memories=memories, server_time=server_time)
|
||||
|
||||
|
||||
@app.post("/api/memories", response_model=MemoryResponse)
|
||||
async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)):
|
||||
pool = await get_pool()
|
||||
|
|
@ -117,6 +170,7 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
|||
created_at, updated_at
|
||||
FROM memories, plainto_tsquery('english', $2) query
|
||||
WHERE user_id = $1
|
||||
AND deleted_at IS NULL
|
||||
AND (search_vector @@ query OR $2 = '')
|
||||
{category_filter}
|
||||
ORDER BY {order_clause}
|
||||
|
|
@ -158,14 +212,14 @@ async def list_memories(
|
|||
if category:
|
||||
query = """
|
||||
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
FROM memories WHERE user_id = $1 AND category = $2
|
||||
FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2
|
||||
ORDER BY importance DESC LIMIT $3
|
||||
"""
|
||||
params: list = [user.user_id, category, limit]
|
||||
else:
|
||||
query = """
|
||||
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
FROM memories WHERE user_id = $1
|
||||
FROM memories WHERE user_id = $1 AND deleted_at IS NULL
|
||||
ORDER BY importance DESC LIMIT $2
|
||||
"""
|
||||
params = [user.user_id, limit]
|
||||
|
|
@ -200,7 +254,7 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use
|
|||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2",
|
||||
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
)
|
||||
|
|
@ -211,7 +265,7 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use
|
|||
await delete_secret(user.user_id, row["vault_path"])
|
||||
|
||||
await conn.execute(
|
||||
"DELETE FROM memories WHERE id = $1 AND user_id = $2",
|
||||
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
)
|
||||
|
|
@ -227,7 +281,7 @@ async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current
|
|||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT id, content, is_sensitive, vault_path, encrypted_content
|
||||
FROM memories WHERE id = $1 AND user_id = $2
|
||||
FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL
|
||||
""",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
|
|
@ -263,7 +317,7 @@ async def migrate_secrets(user: AuthUser = Depends(get_current_user)):
|
|||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, content FROM memories
|
||||
WHERE user_id = $1 AND is_sensitive = FALSE
|
||||
WHERE user_id = $1 AND is_sensitive = FALSE AND deleted_at IS NULL
|
||||
""",
|
||||
user.user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -30,3 +30,8 @@ class SecretResponse(BaseModel):
|
|||
id: int
|
||||
content: str
|
||||
source: str # "vault", "encrypted", "plaintext"
|
||||
|
||||
|
||||
class SyncResponse(BaseModel):
|
||||
memories: list[dict[str, Any]]
|
||||
server_time: str
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
"""
|
||||
Claude Memory MCP Server — standalone memory server with multi-user support.
|
||||
|
||||
Supports two modes:
|
||||
1. HTTP API mode: connects to a shared PostgreSQL-backed API server
|
||||
2. SQLite fallback: local file-based storage when no API key is configured
|
||||
Supports three modes:
|
||||
1. SQLite-only: local file-based storage when no API key is configured
|
||||
2. Hybrid (default when API key set): local SQLite cache + background sync
|
||||
3. HTTP-only (legacy): direct HTTP to API, no local cache (MEMORY_SYNC_DISABLE=1)
|
||||
|
||||
Uses only stdlib (urllib) — no pip install required.
|
||||
"""
|
||||
|
|
@ -21,14 +22,17 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
PROTOCOL_VERSION = "2024-11-05"
|
||||
SERVER_NAME = "claude-memory"
|
||||
SERVER_VERSION = "1.0.0"
|
||||
SERVER_VERSION = "1.1.0"
|
||||
|
||||
# API configuration — support both MEMORY_* (primary) and CLAUDE_MEMORY_* (fallback) env vars
|
||||
API_BASE_URL = os.environ.get("MEMORY_API_URL") or os.environ.get("CLAUDE_MEMORY_API_URL", "http://localhost:8080")
|
||||
API_KEY = os.environ.get("MEMORY_API_KEY") or os.environ.get("CLAUDE_MEMORY_API_KEY", "")
|
||||
|
||||
# Fallback to SQLite if API is not configured
|
||||
SQLITE_FALLBACK = not API_KEY
|
||||
# Mode detection
|
||||
SYNC_DISABLED = os.environ.get("MEMORY_SYNC_DISABLE", "") == "1"
|
||||
HYBRID_MODE = bool(API_KEY) and not SYNC_DISABLED
|
||||
HTTP_ONLY = bool(API_KEY) and SYNC_DISABLED
|
||||
SQLITE_ONLY = not API_KEY
|
||||
|
||||
|
||||
def _api_request(method: str, path: str, body: dict | None = None) -> dict:
|
||||
|
|
@ -54,23 +58,29 @@ def _api_request(method: str, path: str, body: dict | None = None) -> dict:
|
|||
raise RuntimeError(f"API connection error: {e.reason}") from e
|
||||
|
||||
|
||||
# ─── SQLite fallback (local storage when API not configured) ─────────────────
|
||||
# ─── SQLite initialization ────────────────────────────────────────────────────
|
||||
|
||||
def _get_db_path(db_path: str | None = None) -> str:
|
||||
"""Resolve the SQLite database path."""
|
||||
if db_path is not None:
|
||||
return db_path
|
||||
|
||||
memory_home = os.path.expandvars(
|
||||
os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory"))
|
||||
)
|
||||
db_path = os.environ.get(
|
||||
"MEMORY_DB",
|
||||
os.path.join(memory_home, "memory", "memory.db"),
|
||||
)
|
||||
return os.path.expandvars(os.path.expanduser(db_path))
|
||||
|
||||
|
||||
def _init_sqlite(db_path: str | None = None):
|
||||
"""Initialize SQLite database as fallback."""
|
||||
"""Initialize SQLite database."""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
if db_path is None:
|
||||
memory_home = os.path.expandvars(
|
||||
os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory"))
|
||||
)
|
||||
db_path = os.environ.get(
|
||||
"MEMORY_DB",
|
||||
os.path.join(memory_home, "memory", "memory.db"),
|
||||
)
|
||||
db_path = os.path.expandvars(os.path.expanduser(db_path))
|
||||
|
||||
db_path = _get_db_path(db_path)
|
||||
Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(db_path, timeout=30.0)
|
||||
|
|
@ -91,6 +101,15 @@ def _init_sqlite(db_path: str | None = None):
|
|||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
# Add server_id column if missing (for hybrid mode sync)
|
||||
cursor.execute("PRAGMA table_info(memories)")
|
||||
columns = {row["name"] for row in cursor.fetchall()}
|
||||
if "server_id" not in columns:
|
||||
cursor.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER")
|
||||
cursor.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
|
||||
)
|
||||
|
||||
cursor.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
||||
content, category, tags, expanded_keywords,
|
||||
|
|
@ -118,7 +137,7 @@ def _init_sqlite(db_path: str | None = None):
|
|||
END
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
return conn, db_path
|
||||
|
||||
|
||||
# ─── Tool definitions ────────────────────────────────────────────────────────
|
||||
|
|
@ -229,10 +248,27 @@ class MemoryServer:
|
|||
|
||||
def __init__(self, sqlite_db_path: str | None = None) -> None:
|
||||
self.sqlite_conn = None
|
||||
if SQLITE_FALLBACK:
|
||||
self.sqlite_conn = _init_sqlite(sqlite_db_path)
|
||||
self.sync_engine = None
|
||||
|
||||
# ── HTTP-backed methods ──────────────────────────────────────────
|
||||
if SQLITE_ONLY or HYBRID_MODE:
|
||||
self.sqlite_conn, resolved_path = _init_sqlite(sqlite_db_path)
|
||||
|
||||
if HYBRID_MODE:
|
||||
from claude_memory.sync import SyncEngine
|
||||
sync_interval = int(os.environ.get("MEMORY_SYNC_INTERVAL", "60"))
|
||||
self.sync_engine = SyncEngine(
|
||||
db_path=resolved_path,
|
||||
api_base_url=API_BASE_URL,
|
||||
api_key=API_KEY,
|
||||
sync_interval=sync_interval,
|
||||
)
|
||||
self.sync_engine.start()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.sync_engine:
|
||||
self.sync_engine.stop()
|
||||
|
||||
# ── Tool methods ────────────────────────────────────────────────
|
||||
|
||||
def memory_store(self, args: dict[str, Any]) -> str:
|
||||
content = args.get("content")
|
||||
|
|
@ -244,18 +280,28 @@ class MemoryServer:
|
|||
expanded_keywords = args.get("expanded_keywords", "")
|
||||
force_sensitive = bool(args.get("force_sensitive", False))
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive)
|
||||
if HTTP_ONLY:
|
||||
result = _api_request("POST", "/api/memories", {
|
||||
"content": content,
|
||||
"category": category,
|
||||
"tags": tags,
|
||||
"expanded_keywords": expanded_keywords,
|
||||
"importance": importance,
|
||||
"force_sensitive": force_sensitive,
|
||||
})
|
||||
return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}"
|
||||
|
||||
result = _api_request("POST", "/api/memories", {
|
||||
"content": content,
|
||||
"category": category,
|
||||
"tags": tags,
|
||||
"expanded_keywords": expanded_keywords,
|
||||
"importance": importance,
|
||||
"force_sensitive": force_sensitive,
|
||||
})
|
||||
return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}"
|
||||
# SQLite-only or Hybrid: write to local SQLite first
|
||||
result_text = self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive)
|
||||
|
||||
if HYBRID_MODE and self.sync_engine:
|
||||
# Extract local_id from result text
|
||||
local_id = int(result_text.split("#")[1].split(" ")[0])
|
||||
self.sync_engine.try_sync_store(
|
||||
local_id, content, category, tags, expanded_keywords, importance, force_sensitive
|
||||
)
|
||||
|
||||
return result_text
|
||||
|
||||
def memory_recall(self, args: dict[str, Any]) -> str:
|
||||
context = args.get("context")
|
||||
|
|
@ -266,80 +312,102 @@ class MemoryServer:
|
|||
sort_by = args.get("sort_by", "importance")
|
||||
limit = args.get("limit", 10)
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
|
||||
if HTTP_ONLY:
|
||||
result = _api_request("POST", "/api/memories/recall", {
|
||||
"context": context,
|
||||
"expanded_query": expanded_query,
|
||||
"category": category,
|
||||
"sort_by": sort_by,
|
||||
"limit": limit,
|
||||
})
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
filter_desc = f" in category '{category}'" if category else ""
|
||||
return f"No memories found matching: {context}{filter_desc}"
|
||||
|
||||
result = _api_request("POST", "/api/memories/recall", {
|
||||
"context": context,
|
||||
"expanded_query": expanded_query,
|
||||
"category": category,
|
||||
"sort_by": sort_by,
|
||||
"limit": limit,
|
||||
})
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
filter_desc = f" in category '{category}'" if category else ""
|
||||
return f"No memories found matching: {context}{filter_desc}"
|
||||
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
|
||||
filter_desc = f" in '{category}'" if category else ""
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
||||
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
|
||||
|
||||
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
|
||||
filter_desc = f" in '{category}'" if category else ""
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
||||
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
|
||||
# SQLite-only or Hybrid: always read from local cache
|
||||
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
|
||||
|
||||
def memory_list(self, args: dict[str, Any]) -> str:
|
||||
category = args.get("category")
|
||||
limit = args.get("limit", 20)
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_list(category, limit)
|
||||
if HTTP_ONLY:
|
||||
params = f"?limit={limit}"
|
||||
if category:
|
||||
params += f"&category={category}"
|
||||
result = _api_request("GET", f"/api/memories{params}")
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
||||
|
||||
params = f"?limit={limit}"
|
||||
if category:
|
||||
params += f"&category={category}"
|
||||
result = _api_request("GET", f"/api/memories{params}")
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] {row['content']}"
|
||||
f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
header = "Recent memories"
|
||||
if category:
|
||||
header += f" in '{category}'"
|
||||
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] {row['content']}"
|
||||
f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
header = "Recent memories"
|
||||
if category:
|
||||
header += f" in '{category}'"
|
||||
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
||||
# SQLite-only or Hybrid: always read from local cache
|
||||
return self._sqlite_list(category, limit)
|
||||
|
||||
def memory_delete(self, args: dict[str, Any]) -> str:
|
||||
memory_id = args.get("id")
|
||||
if memory_id is None:
|
||||
raise ValueError("id is required")
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_delete(memory_id)
|
||||
if HTTP_ONLY:
|
||||
result = _api_request("DELETE", f"/api/memories/{memory_id}")
|
||||
return f"Deleted memory #{result['deleted']}: {result['preview']}..."
|
||||
|
||||
result = _api_request("DELETE", f"/api/memories/{memory_id}")
|
||||
return f"Deleted memory #{result['deleted']}: {result['preview']}..."
|
||||
# SQLite-only or Hybrid: delete from local SQLite
|
||||
# In hybrid mode, also try to sync delete to server
|
||||
if HYBRID_MODE and self.sync_engine:
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT server_id FROM memories WHERE id = ?", (memory_id,))
|
||||
row = cursor.fetchone()
|
||||
server_id = row["server_id"] if row and row["server_id"] else None
|
||||
|
||||
result_text = self._sqlite_delete(memory_id)
|
||||
|
||||
if HYBRID_MODE and self.sync_engine and server_id:
|
||||
self.sync_engine.try_sync_delete(server_id)
|
||||
|
||||
return result_text
|
||||
|
||||
def secret_get(self, args: dict[str, Any]) -> str:
|
||||
memory_id = args.get("id")
|
||||
if memory_id is None:
|
||||
raise ValueError("id is required")
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_secret_get(memory_id)
|
||||
if HTTP_ONLY or HYBRID_MODE:
|
||||
# Secrets should be fetched from API when available
|
||||
try:
|
||||
result = _api_request("POST", f"/api/memories/{memory_id}/secret")
|
||||
return f"#{result['id']} [{result['category']}] {result['content']}"
|
||||
except Exception:
|
||||
if HYBRID_MODE:
|
||||
# Fall back to local SQLite
|
||||
return self._sqlite_secret_get(memory_id)
|
||||
raise
|
||||
|
||||
result = _api_request("POST", f"/api/memories/{memory_id}/secret")
|
||||
return f"#{result['id']} [{result['category']}] {result['content']}"
|
||||
return self._sqlite_secret_get(memory_id)
|
||||
|
||||
# ── SQLite fallback methods ──────────────────────────────────────
|
||||
# ── SQLite methods ──────────────────────────────────────────────
|
||||
|
||||
def _sqlite_store(self, content, category, tags, importance, expanded_keywords, force_sensitive=False):
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -520,25 +588,29 @@ class MemoryServer:
|
|||
return response
|
||||
|
||||
def run(self) -> None:
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
message = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": -32700, "message": f"Parse error: {e}"},
|
||||
}),
|
||||
flush=True,
|
||||
)
|
||||
continue
|
||||
response = self.process_message(message)
|
||||
if response is not None:
|
||||
print(json.dumps(response), flush=True)
|
||||
try:
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
message = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": -32700, "message": f"Parse error: {e}"},
|
||||
}),
|
||||
flush=True,
|
||||
)
|
||||
continue
|
||||
response = self.process_message(message)
|
||||
if response is not None:
|
||||
print(json.dumps(response), flush=True)
|
||||
finally:
|
||||
if self.sync_engine:
|
||||
self.sync_engine.stop()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
|||
334
src/claude_memory/sync.py
Normal file
334
src/claude_memory/sync.py
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
"""Background sync between local SQLite cache and remote API.
|
||||
|
||||
Uses only stdlib — no pip install required.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncEngine:
|
||||
"""Background sync between local SQLite cache and remote API."""
|
||||
|
||||
def __init__(self, db_path: str, api_base_url: str, api_key: str, sync_interval: int = 60):
|
||||
self.db_path = db_path
|
||||
self.api_base_url = api_base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.sync_interval = sync_interval
|
||||
|
||||
self._stop_event = threading.Event()
|
||||
self._thread: threading.Thread | None = None
|
||||
self._last_sync_success = False
|
||||
|
||||
# Own connection for thread safety
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA busy_timeout=30000")
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._init_sync_tables()
|
||||
|
||||
def _init_sync_tables(self) -> None:
|
||||
"""Create sync-specific tables if they don't exist."""
|
||||
with self._lock:
|
||||
self._conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS pending_ops (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
op_type TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sync_meta (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
);
|
||||
""")
|
||||
# Add server_id column to memories if missing
|
||||
cursor = self._conn.execute("PRAGMA table_info(memories)")
|
||||
columns = {row["name"] for row in cursor.fetchall()}
|
||||
if "server_id" not in columns:
|
||||
self._conn.execute("ALTER TABLE memories ADD COLUMN server_id INTEGER")
|
||||
self._conn.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_memories_server_id ON memories(server_id)"
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@property
|
||||
def last_sync_ts(self) -> str | None:
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT value FROM sync_meta WHERE key = 'last_sync_ts'"
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row["value"] if row else None
|
||||
|
||||
@last_sync_ts.setter
|
||||
def last_sync_ts(self, value: str) -> None:
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"INSERT OR REPLACE INTO sync_meta (key, value) VALUES ('last_sync_ts', ?)",
|
||||
(value,),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@property
|
||||
def api_available(self) -> bool:
|
||||
return self._last_sync_success
|
||||
|
||||
def start(self) -> None:
|
||||
"""Run initial sync (blocking), then start background thread."""
|
||||
try:
|
||||
self._sync_once()
|
||||
self._last_sync_success = True
|
||||
except Exception:
|
||||
logger.warning("Initial sync failed, starting in offline mode")
|
||||
self._last_sync_success = False
|
||||
|
||||
self._thread = threading.Thread(target=self._sync_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal background thread to stop and wait."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=5)
|
||||
self._conn.close()
|
||||
|
||||
def _sync_loop(self) -> None:
|
||||
"""Periodic sync loop running in background thread."""
|
||||
while not self._stop_event.is_set():
|
||||
self._stop_event.wait(self.sync_interval)
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
self._sync_once()
|
||||
self._last_sync_success = True
|
||||
except Exception as e:
|
||||
logger.warning("Sync cycle failed: %s", e)
|
||||
self._last_sync_success = False
|
||||
|
||||
def _sync_once(self) -> None:
|
||||
"""Push pending ops, then pull remote changes."""
|
||||
self._push_pending_ops()
|
||||
self._pull_changes()
|
||||
|
||||
def _api_request(self, method: str, path: str, body: dict | None = None) -> dict:
|
||||
"""Make an HTTP request to the memory API."""
|
||||
url = f"{self.api_base_url}{path}"
|
||||
data = json.dumps(body).encode() if body else None
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=data,
|
||||
method=method,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
return json.loads(resp.read().decode())
|
||||
|
||||
def _push_pending_ops(self) -> None:
|
||||
"""Push queued operations to the API server."""
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, op_type, payload FROM pending_ops ORDER BY id"
|
||||
)
|
||||
ops = cursor.fetchall()
|
||||
|
||||
for op in ops:
|
||||
op_id = op["id"]
|
||||
op_type = op["op_type"]
|
||||
payload = json.loads(op["payload"])
|
||||
|
||||
try:
|
||||
if op_type == "store":
|
||||
result = self._api_request("POST", "/api/memories", payload)
|
||||
server_id = result.get("id")
|
||||
if server_id and payload.get("local_id"):
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE memories SET server_id = ? WHERE id = ?",
|
||||
(server_id, payload["local_id"]),
|
||||
)
|
||||
self._conn.commit()
|
||||
elif op_type == "delete":
|
||||
server_id = payload.get("server_id")
|
||||
if server_id:
|
||||
try:
|
||||
self._api_request("DELETE", f"/api/memories/{server_id}")
|
||||
except RuntimeError as e:
|
||||
if "404" in str(e):
|
||||
pass # Already deleted on server
|
||||
else:
|
||||
raise
|
||||
|
||||
# Remove from pending queue on success
|
||||
with self._lock:
|
||||
self._conn.execute("DELETE FROM pending_ops WHERE id = ?", (op_id,))
|
||||
self._conn.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to push op %d (%s): %s", op_id, op_type, e)
|
||||
raise # Propagate to mark sync as failed
|
||||
|
||||
def _pull_changes(self) -> None:
|
||||
"""Pull changes from server since last sync."""
|
||||
params = ""
|
||||
ts = self.last_sync_ts
|
||||
if ts:
|
||||
params = f"?since={ts}"
|
||||
|
||||
result = self._api_request("GET", f"/api/memories/sync{params}")
|
||||
memories = result.get("memories", [])
|
||||
server_time = result.get("server_time")
|
||||
|
||||
with self._lock:
|
||||
for mem in memories:
|
||||
server_id = mem["id"]
|
||||
deleted_at = mem.get("deleted_at")
|
||||
|
||||
if deleted_at:
|
||||
# Remove from local cache
|
||||
self._conn.execute(
|
||||
"DELETE FROM memories WHERE server_id = ?", (server_id,)
|
||||
)
|
||||
else:
|
||||
# Upsert by server_id (server wins)
|
||||
existing = self._conn.execute(
|
||||
"SELECT id FROM memories WHERE server_id = ?", (server_id,)
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
self._conn.execute(
|
||||
"""UPDATE memories SET content = ?, category = ?, tags = ?,
|
||||
expanded_keywords = ?, importance = ?, is_sensitive = ?,
|
||||
updated_at = ? WHERE server_id = ?""",
|
||||
(
|
||||
mem["content"],
|
||||
mem["category"],
|
||||
mem.get("tags", ""),
|
||||
mem.get("expanded_keywords", ""),
|
||||
mem["importance"],
|
||||
1 if mem.get("is_sensitive") else 0,
|
||||
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
||||
server_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self._conn.execute(
|
||||
"""INSERT INTO memories
|
||||
(content, category, tags, expanded_keywords, importance,
|
||||
is_sensitive, created_at, updated_at, server_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
mem["content"],
|
||||
mem["category"],
|
||||
mem.get("tags", ""),
|
||||
mem.get("expanded_keywords", ""),
|
||||
mem["importance"],
|
||||
1 if mem.get("is_sensitive") else 0,
|
||||
mem.get("created_at", datetime.now(timezone.utc).isoformat()),
|
||||
mem.get("updated_at", datetime.now(timezone.utc).isoformat()),
|
||||
server_id,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
if server_time:
|
||||
self.last_sync_ts = server_time
|
||||
|
||||
def enqueue_store(
|
||||
self,
|
||||
local_id: int,
|
||||
content: str,
|
||||
category: str,
|
||||
tags: str,
|
||||
expanded_keywords: str,
|
||||
importance: float,
|
||||
force_sensitive: bool = False,
|
||||
) -> None:
|
||||
"""Queue a store operation for later sync."""
|
||||
payload = {
|
||||
"local_id": local_id,
|
||||
"content": content,
|
||||
"category": category,
|
||||
"tags": tags,
|
||||
"expanded_keywords": expanded_keywords,
|
||||
"importance": importance,
|
||||
"force_sensitive": force_sensitive,
|
||||
}
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"INSERT INTO pending_ops (op_type, payload, created_at) VALUES (?, ?, ?)",
|
||||
("store", json.dumps(payload), now),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def enqueue_delete(self, server_id: int) -> None:
|
||||
"""Queue a delete operation for later sync."""
|
||||
payload = {"server_id": server_id}
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"INSERT INTO pending_ops (op_type, payload, created_at) VALUES (?, ?, ?)",
|
||||
("delete", json.dumps(payload), now),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def try_sync_store(
|
||||
self,
|
||||
local_id: int,
|
||||
content: str,
|
||||
category: str,
|
||||
tags: str,
|
||||
expanded_keywords: str,
|
||||
importance: float,
|
||||
force_sensitive: bool = False,
|
||||
) -> int | None:
|
||||
"""Try to sync a store immediately. Returns server_id or None if failed."""
|
||||
try:
|
||||
result = self._api_request("POST", "/api/memories", {
|
||||
"content": content,
|
||||
"category": category,
|
||||
"tags": tags,
|
||||
"expanded_keywords": expanded_keywords,
|
||||
"importance": importance,
|
||||
"force_sensitive": force_sensitive,
|
||||
})
|
||||
server_id = result.get("id")
|
||||
if server_id:
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE memories SET server_id = ? WHERE id = ?",
|
||||
(server_id, local_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return server_id
|
||||
except Exception:
|
||||
self.enqueue_store(
|
||||
local_id, content, category, tags, expanded_keywords, importance, force_sensitive
|
||||
)
|
||||
return None
|
||||
|
||||
def try_sync_delete(self, server_id: int) -> bool:
|
||||
"""Try to sync a delete immediately. Returns True if successful."""
|
||||
try:
|
||||
self._api_request("DELETE", f"/api/memories/{server_id}")
|
||||
return True
|
||||
except Exception:
|
||||
self.enqueue_delete(server_id)
|
||||
return False
|
||||
|
|
@ -36,6 +36,7 @@ def _make_memory_row(**overrides):
|
|||
"rank": 0.5,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MockRow(defaults)
|
||||
|
|
@ -307,3 +308,125 @@ async def test_import_memories(client):
|
|||
assert len(data) == 2
|
||||
assert data[0]["id"] == 100
|
||||
assert data[1]["id"] == 101
|
||||
|
||||
|
||||
# ─── Sync endpoint tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_full_dump_without_since(client):
|
||||
ac, conn, app_mod = client
|
||||
now = datetime.now(timezone.utc)
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=1, content="mem1", deleted_at=None),
|
||||
_make_memory_row(id=2, content="mem2", deleted_at=None),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.get(
|
||||
"/api/memories/sync",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["memories"]) == 2
|
||||
assert "server_time" in data
|
||||
assert data["memories"][0]["id"] == 1
|
||||
assert data["memories"][1]["id"] == 2
|
||||
|
||||
# Without since param, should query non-deleted only
|
||||
call_args = conn.fetch.call_args
|
||||
query = call_args[0][0]
|
||||
assert "deleted_at IS NULL" in query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_incremental_with_since(client):
|
||||
ac, conn, app_mod = client
|
||||
now = datetime.now(timezone.utc)
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=3, content="updated mem", deleted_at=None),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.get(
|
||||
"/api/memories/sync?since=2026-03-14T10:00:00+00:00",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["memories"]) == 1
|
||||
|
||||
# With since param, should include updated_at filter (includes soft-deleted)
|
||||
call_args = conn.fetch.call_args
|
||||
query = call_args[0][0]
|
||||
assert "updated_at >" in query
|
||||
assert "deleted_at IS NULL" not in query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_includes_soft_deleted_with_since(client):
|
||||
ac, conn, app_mod = client
|
||||
now = datetime.now(timezone.utc)
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=5, content="deleted mem", deleted_at=now),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.get(
|
||||
"/api/memories/sync?since=2026-03-14T10:00:00+00:00",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["memories"]) == 1
|
||||
assert data["memories"][0]["deleted_at"] is not None
|
||||
|
||||
|
||||
# ─── Soft delete tests ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_is_soft_delete(client):
|
||||
"""Delete should SET deleted_at, not DELETE the row."""
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None, preview="test content")
|
||||
conn.execute.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.delete(
|
||||
"/api/memories/10",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Verify the execute call uses UPDATE SET deleted_at, not DELETE
|
||||
execute_args = conn.execute.call_args
|
||||
query = execute_args[0][0]
|
||||
assert "UPDATE" in query
|
||||
assert "deleted_at" in query
|
||||
assert "DELETE" not in query.upper().split("SET")[0] # No DELETE before SET
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_excludes_already_deleted(client):
|
||||
"""DELETE endpoint should not find already-deleted memories."""
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = None # Not found because deleted_at IS NULL filter
|
||||
|
||||
async with ac:
|
||||
resp = await ac.delete(
|
||||
"/api/memories/10",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
# Verify query includes deleted_at IS NULL
|
||||
call_args = conn.fetchrow.call_args
|
||||
query = call_args[0][0]
|
||||
assert "deleted_at IS NULL" in query
|
||||
|
|
|
|||
395
tests/test_sync.py
Normal file
395
tests/test_sync.py
Normal file
|
|
@ -0,0 +1,395 @@
|
|||
"""Tests for the SyncEngine (local SQLite cache + remote API sync)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Force SQLite-only mode for test imports
|
||||
os.environ.pop("MEMORY_API_KEY", None)
|
||||
os.environ.pop("CLAUDE_MEMORY_API_KEY", None)
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from claude_memory.mcp_server import _init_sqlite
|
||||
from claude_memory.sync import SyncEngine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_path(tmp_path):
|
||||
return str(tmp_path / "test_sync.db")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sqlite_conn(db_path):
|
||||
"""Create a SQLite database with the standard schema."""
|
||||
conn, _ = _init_sqlite(db_path)
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine(db_path, sqlite_conn):
|
||||
"""Create a SyncEngine with mocked API."""
|
||||
eng = SyncEngine(
|
||||
db_path=db_path,
|
||||
api_base_url="http://fake-api:8080",
|
||||
api_key="test-key",
|
||||
sync_interval=3600, # Don't auto-sync in tests
|
||||
)
|
||||
yield eng
|
||||
eng._conn.close()
|
||||
|
||||
|
||||
class TestSyncEngineInit:
|
||||
def test_creates_pending_ops_table(self, engine):
|
||||
cursor = engine._conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='pending_ops'"
|
||||
)
|
||||
assert cursor.fetchone() is not None
|
||||
|
||||
def test_creates_sync_meta_table(self, engine):
|
||||
cursor = engine._conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='sync_meta'"
|
||||
)
|
||||
assert cursor.fetchone() is not None
|
||||
|
||||
def test_adds_server_id_column(self, engine):
|
||||
cursor = engine._conn.execute("PRAGMA table_info(memories)")
|
||||
columns = {row["name"] for row in cursor.fetchall()}
|
||||
assert "server_id" in columns
|
||||
|
||||
def test_server_id_unique_index(self, engine):
|
||||
cursor = engine._conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='index' AND name='idx_memories_server_id'"
|
||||
)
|
||||
assert cursor.fetchone() is not None
|
||||
|
||||
|
||||
class TestEnqueueOps:
|
||||
def test_enqueue_store(self, engine):
|
||||
engine.enqueue_store(
|
||||
local_id=1,
|
||||
content="test memory",
|
||||
category="facts",
|
||||
tags="test",
|
||||
expanded_keywords="test memory keywords",
|
||||
importance=0.7,
|
||||
)
|
||||
cursor = engine._conn.execute("SELECT * FROM pending_ops")
|
||||
ops = cursor.fetchall()
|
||||
assert len(ops) == 1
|
||||
assert ops[0]["op_type"] == "store"
|
||||
payload = json.loads(ops[0]["payload"])
|
||||
assert payload["content"] == "test memory"
|
||||
assert payload["local_id"] == 1
|
||||
assert payload["importance"] == 0.7
|
||||
|
||||
def test_enqueue_delete(self, engine):
|
||||
engine.enqueue_delete(server_id=42)
|
||||
cursor = engine._conn.execute("SELECT * FROM pending_ops")
|
||||
ops = cursor.fetchall()
|
||||
assert len(ops) == 1
|
||||
assert ops[0]["op_type"] == "delete"
|
||||
payload = json.loads(ops[0]["payload"])
|
||||
assert payload["server_id"] == 42
|
||||
|
||||
def test_multiple_enqueues(self, engine):
|
||||
engine.enqueue_store(1, "mem1", "facts", "", "", 0.5)
|
||||
engine.enqueue_store(2, "mem2", "facts", "", "", 0.5)
|
||||
engine.enqueue_delete(10)
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 3
|
||||
|
||||
|
||||
class TestPushPendingOps:
|
||||
def test_push_store_clears_queue(self, engine):
|
||||
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"id": 100, "category": "facts", "importance": 0.5}
|
||||
engine._push_pending_ops()
|
||||
|
||||
# Queue should be empty
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 0
|
||||
|
||||
# server_id should be set on local memory (if it exists)
|
||||
mock_api.assert_called_once()
|
||||
|
||||
def test_push_store_updates_server_id(self, engine, sqlite_conn):
|
||||
# Insert a local memory first
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
sqlite_conn.execute(
|
||||
"INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(1, "test content", "facts", "", "kw", 0.5, 0, now, now),
|
||||
)
|
||||
sqlite_conn.commit()
|
||||
|
||||
engine.enqueue_store(1, "test content", "facts", "", "kw", 0.5)
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"id": 200, "category": "facts", "importance": 0.5}
|
||||
engine._push_pending_ops()
|
||||
|
||||
# Check server_id was updated
|
||||
cursor = engine._conn.execute("SELECT server_id FROM memories WHERE id = 1")
|
||||
row = cursor.fetchone()
|
||||
assert row["server_id"] == 200
|
||||
|
||||
def test_push_delete_clears_queue(self, engine):
|
||||
engine.enqueue_delete(42)
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"deleted": 42, "preview": "test"}
|
||||
engine._push_pending_ops()
|
||||
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 0
|
||||
|
||||
def test_push_delete_404_still_clears(self, engine):
|
||||
"""A 404 on delete means already deleted on server — should still clear queue."""
|
||||
engine.enqueue_delete(42)
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.side_effect = RuntimeError("API error 404: not found")
|
||||
engine._push_pending_ops()
|
||||
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 0
|
||||
|
||||
def test_push_failure_keeps_queue(self, engine):
|
||||
engine.enqueue_store(1, "test", "facts", "", "kw", 0.5)
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.side_effect = RuntimeError("Connection refused")
|
||||
with pytest.raises(RuntimeError):
|
||||
engine._push_pending_ops()
|
||||
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 1
|
||||
|
||||
|
||||
class TestPullChanges:
|
||||
def test_pull_inserts_new_memories(self, engine):
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {
|
||||
"memories": [
|
||||
{
|
||||
"id": 10,
|
||||
"content": "server memory",
|
||||
"category": "facts",
|
||||
"tags": "tag1",
|
||||
"expanded_keywords": "server memory keywords",
|
||||
"importance": 0.8,
|
||||
"is_sensitive": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": None,
|
||||
}
|
||||
],
|
||||
"server_time": now,
|
||||
}
|
||||
engine._pull_changes()
|
||||
|
||||
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 10")
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["content"] == "server memory"
|
||||
assert row["importance"] == 0.8
|
||||
|
||||
def test_pull_updates_existing_memories(self, engine):
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
# Insert existing memory with server_id
|
||||
engine._conn.execute(
|
||||
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at, server_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
("old content", "facts", "", "", 0.5, 0, now, now, 10),
|
||||
)
|
||||
engine._conn.commit()
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {
|
||||
"memories": [
|
||||
{
|
||||
"id": 10,
|
||||
"content": "updated content",
|
||||
"category": "projects",
|
||||
"tags": "",
|
||||
"expanded_keywords": "",
|
||||
"importance": 0.9,
|
||||
"is_sensitive": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": None,
|
||||
}
|
||||
],
|
||||
"server_time": now,
|
||||
}
|
||||
engine._pull_changes()
|
||||
|
||||
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 10")
|
||||
row = cursor.fetchone()
|
||||
assert row["content"] == "updated content"
|
||||
assert row["category"] == "projects"
|
||||
assert row["importance"] == 0.9
|
||||
|
||||
def test_pull_deletes_soft_deleted(self, engine):
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
engine._conn.execute(
|
||||
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at, server_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
("to be deleted", "facts", "", "", 0.5, 0, now, now, 20),
|
||||
)
|
||||
engine._conn.commit()
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {
|
||||
"memories": [
|
||||
{
|
||||
"id": 20,
|
||||
"content": "to be deleted",
|
||||
"category": "facts",
|
||||
"tags": "",
|
||||
"expanded_keywords": "",
|
||||
"importance": 0.5,
|
||||
"is_sensitive": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": now,
|
||||
}
|
||||
],
|
||||
"server_time": now,
|
||||
}
|
||||
engine._pull_changes()
|
||||
|
||||
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 20")
|
||||
assert cursor.fetchone() is None
|
||||
|
||||
def test_pull_updates_last_sync_ts(self, engine):
|
||||
server_time = "2026-03-14T12:00:00+00:00"
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {
|
||||
"memories": [],
|
||||
"server_time": server_time,
|
||||
}
|
||||
engine._pull_changes()
|
||||
|
||||
assert engine.last_sync_ts == server_time
|
||||
|
||||
def test_pull_with_since_param(self, engine):
|
||||
engine.last_sync_ts = "2026-03-14T10:00:00+00:00"
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"memories": [], "server_time": "2026-03-14T12:00:00+00:00"}
|
||||
engine._pull_changes()
|
||||
|
||||
call_args = mock_api.call_args
|
||||
assert "since=2026-03-14T10:00:00+00:00" in call_args[0][1]
|
||||
|
||||
|
||||
class TestTrySyncStore:
|
||||
def test_success_returns_server_id(self, engine, sqlite_conn):
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
sqlite_conn.execute(
|
||||
"INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(1, "test", "facts", "", "kw", 0.5, 0, now, now),
|
||||
)
|
||||
sqlite_conn.commit()
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"id": 300, "category": "facts", "importance": 0.5}
|
||||
result = engine.try_sync_store(1, "test", "facts", "", "kw", 0.5)
|
||||
|
||||
assert result == 300
|
||||
|
||||
def test_failure_enqueues_op(self, engine):
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.side_effect = RuntimeError("Connection refused")
|
||||
result = engine.try_sync_store(1, "test", "facts", "", "kw", 0.5)
|
||||
|
||||
assert result is None
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 1
|
||||
|
||||
|
||||
class TestTrySyncDelete:
|
||||
def test_success_returns_true(self, engine):
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"deleted": 42, "preview": "test"}
|
||||
result = engine.try_sync_delete(42)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_failure_enqueues_op(self, engine):
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.side_effect = RuntimeError("Connection refused")
|
||||
result = engine.try_sync_delete(42)
|
||||
|
||||
assert result is False
|
||||
cursor = engine._conn.execute("SELECT COUNT(*) as cnt FROM pending_ops")
|
||||
assert cursor.fetchone()["cnt"] == 1
|
||||
|
||||
|
||||
class TestSyncMeta:
|
||||
def test_last_sync_ts_none_initially(self, engine):
|
||||
assert engine.last_sync_ts is None
|
||||
|
||||
def test_last_sync_ts_persists(self, engine):
|
||||
engine.last_sync_ts = "2026-03-14T12:00:00+00:00"
|
||||
assert engine.last_sync_ts == "2026-03-14T12:00:00+00:00"
|
||||
|
||||
def test_api_available_initially_false(self, engine):
|
||||
assert engine.api_available is False
|
||||
|
||||
|
||||
class TestFullSyncCycle:
|
||||
def test_store_sync_push_delete_pull(self, engine, sqlite_conn):
|
||||
"""Full cycle: store locally → push to API → server deletes → pull removes local."""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# 1. Store locally
|
||||
sqlite_conn.execute(
|
||||
"INSERT INTO memories (id, content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(1, "cycle test", "facts", "", "cycle test kw", 0.5, 0, now, now),
|
||||
)
|
||||
sqlite_conn.commit()
|
||||
|
||||
# 2. Enqueue and push store
|
||||
engine.enqueue_store(1, "cycle test", "facts", "", "cycle test kw", 0.5)
|
||||
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {"id": 500, "category": "facts", "importance": 0.5}
|
||||
engine._push_pending_ops()
|
||||
|
||||
# Verify server_id set
|
||||
cursor = engine._conn.execute("SELECT server_id FROM memories WHERE id = 1")
|
||||
assert cursor.fetchone()["server_id"] == 500
|
||||
|
||||
# 3. Server soft-deletes → pull removes local
|
||||
with patch.object(engine, "_api_request") as mock_api:
|
||||
mock_api.return_value = {
|
||||
"memories": [
|
||||
{
|
||||
"id": 500,
|
||||
"content": "cycle test",
|
||||
"category": "facts",
|
||||
"tags": "",
|
||||
"expanded_keywords": "cycle test kw",
|
||||
"importance": 0.5,
|
||||
"is_sensitive": False,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": now,
|
||||
}
|
||||
],
|
||||
"server_time": now,
|
||||
}
|
||||
engine._pull_changes()
|
||||
|
||||
# Should be gone locally
|
||||
cursor = engine._conn.execute("SELECT * FROM memories WHERE server_id = 500")
|
||||
assert cursor.fetchone() is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue