188 lines
6.9 KiB
Python
188 lines
6.9 KiB
Python
|
|
"""Claude Memory API — shared persistent memory with PostgreSQL full-text search."""
|
||
|
|
|
||
|
|
import os
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
import asyncpg
|
||
|
|
from fastapi import Depends, FastAPI, Header, HTTPException
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
# Required settings, read once at import time: a missing variable raises
# KeyError immediately instead of failing on the first request.
DATABASE_URL: str = os.environ["DATABASE_URL"]
API_KEY: str = os.environ["API_KEY"]

# Shared asyncpg connection pool. It is only declared here; ``lifespan``
# assigns it on startup, so referencing it before then raises NameError.
pool: asyncpg.Pool
|
||
|
|
|
||
|
|
|
||
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the asyncpg pool and ensure the schema exists for the app's lifetime.

    On startup this creates the module-level ``pool`` and runs idempotent DDL
    (``CREATE ... IF NOT EXISTS``) for the ``memories`` table and its GIN
    full-text index on ``search_vector``. The pool is always closed on exit,
    including when the startup DDL raises or the app shuts down with an error.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global pool
    pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    # try/finally: without it, a failing CREATE statement or an exception thrown
    # into the generator at ``yield`` would skip ``pool.close()`` and leak
    # every pooled connection.
    try:
        async with pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id SERIAL PRIMARY KEY,
                    content TEXT NOT NULL,
                    category VARCHAR(50) DEFAULT 'facts',
                    tags TEXT DEFAULT '',
                    expanded_keywords TEXT DEFAULT '',
                    importance REAL DEFAULT 0.5,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    search_vector tsvector GENERATED ALWAYS AS (
                        setweight(to_tsvector('english', coalesce(content, '')), 'A') ||
                        setweight(to_tsvector('english', coalesce(expanded_keywords, '')), 'B') ||
                        setweight(to_tsvector('english', coalesce(tags, '')), 'C') ||
                        setweight(to_tsvector('english', coalesce(category, '')), 'D')
                    ) STORED
                )
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_memories_search
                ON memories USING GIN(search_vector)
            """)
        yield
    finally:
        await pool.close()
|
||
|
|
|
||
|
|
|
||
|
|
# ASGI application; ``lifespan`` opens the asyncpg pool and creates the
# ``memories`` schema at startup and closes the pool at shutdown.
app = FastAPI(title="Claude Memory API", lifespan=lifespan)
|
||
|
|
|
||
|
|
|
||
|
|
async def verify_api_key(authorization: str = Header(...)):
    """Reject the request unless it carries ``Authorization: Bearer <API_KEY>``.

    Used as a route dependency. A missing header is rejected by FastAPI's own
    validation (``Header(...)`` makes it required).

    Raises:
        HTTPException: 401 when the header does not match the configured key.
    """
    import hmac

    expected = f"Bearer {API_KEY}".encode("utf-8")
    # Constant-time comparison so response latency does not leak how many
    # leading characters of a guessed key were right. Compare bytes because
    # compare_digest raises TypeError for non-ASCII ``str`` input.
    if not hmac.compare_digest(authorization.encode("utf-8"), expected):
        raise HTTPException(status_code=401, detail="Invalid API key")
|
||
|
|
|
||
|
|
|
||
|
|
class MemoryStore(BaseModel):
    """Request body for storing (or bulk-importing) a single memory."""

    content: str  # memory text; weight 'A' in the full-text search vector
    # The column is VARCHAR(50); validating here turns an over-long value into
    # a 422 instead of a database error surfacing as HTTP 500.
    category: str = Field(default="facts", max_length=50)
    tags: str = ""  # free-form tag text; weight 'C' in the search vector
    expanded_keywords: str = ""  # extra search terms; weight 'B'
    importance: float = Field(default=0.5, ge=0.0, le=1.0)  # 0.0..1.0, used for ordering
|
||
|
|
|
||
|
|
|
||
|
|
class MemoryRecall(BaseModel):
    """Request body for ``POST /api/memories/recall``."""

    context: str  # primary search text; also used for the ILIKE fallback
    expanded_query: str = ""  # extra words OR-ed into the full-text query
    category: Optional[str] = None  # when set, restrict results to this category
    # "relevance" orders by ts_rank; any other value orders by importance.
    sort_by: str = "importance"
    # Passed straight to SQL LIMIT, which PostgreSQL rejects when negative;
    # validate up front so a bad value is a 422 rather than a 500.
    limit: int = Field(default=10, ge=0)
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/health")
async def health():
    """Report service health once the database has answered a trivial query.

    Any failure to acquire a connection or run ``SELECT 1`` propagates, so the
    endpoint only returns ``{"status": "ok"}`` when the database is reachable.
    """
    async with pool.acquire() as db:
        await db.fetchval("SELECT 1")
    status = {"status": "ok"}
    return status
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/api/memories", dependencies=[Depends(verify_api_key)])
async def store_memory(mem: MemoryStore):
    """Insert one memory and echo back its new id, category and importance.

    ``created_at`` and ``updated_at`` are both set to the same UTC timestamp.
    """
    timestamp = datetime.now(timezone.utc)
    values = (
        mem.content,
        mem.category,
        mem.tags,
        mem.expanded_keywords,
        mem.importance,
        timestamp,
    )
    insert_sql = """INSERT INTO memories (content, category, tags, expanded_keywords, importance, created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $6)
        RETURNING id"""
    async with pool.acquire() as db:
        inserted = await db.fetchrow(insert_sql, *values)
    return {"id": inserted["id"], "category": mem.category, "importance": mem.importance}
|
||
|
|
|
||
|
|
|
||
|
|
def _escape_like(text: str) -> str:
    """Escape LIKE/ILIKE metacharacters so *text* matches literally.

    Backslash is PostgreSQL's default LIKE escape character, so it is escaped
    first, then ``%`` and ``_``.
    """
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def _websearch_terms(text: str) -> list[str]:
    """Split *text* into words that are safe to OR together for websearch_to_tsquery.

    Double quotes (phrase syntax) become spaces and leading ``-`` (negation) is
    stripped. User text can then only add alternatives; it can never exclude
    results or open an unterminated phrase.
    """
    tokens = (tok.lstrip("-") for tok in text.replace('"', " ").split())
    return [tok for tok in tokens if tok]


@app.post("/api/memories/recall", dependencies=[Depends(verify_api_key)])
async def recall_memories(req: MemoryRecall):
    """Find memories matching ``context``/``expanded_query``.

    First runs a full-text search (``websearch_to_tsquery`` over the words
    OR-ed together) against ``search_vector``. If that yields nothing, it falls
    back to a literal, case-insensitive substring match of ``context`` against
    content, tags and expanded keywords. Both queries honour ``req.category``
    and ``req.limit``.

    Returns:
        ``{"memories": [...]}``, each row with id, content, category, tags,
        importance, created_at and rank (0.0 for fallback rows).

    Raises:
        HTTPException: 400 when context and expanded_query are both blank.
    """
    import logging

    terms = f"{req.context} {req.expanded_query}".strip()
    if not terms.split():
        raise HTTPException(status_code=400, detail="context is required")

    if req.sort_by == "relevance":
        order = "ts_rank(m.search_vector, query) DESC, m.importance DESC"
    else:
        order = "m.importance DESC, m.created_at DESC"

    # The optional category filter is always bound as the third parameter.
    category_args = [req.category] if req.category else []
    fts_category = "AND m.category = $3" if req.category else ""
    like_category = "AND category = $3" if req.category else ""

    # websearch_to_tsquery handles stop words and short tokens gracefully;
    # OR between terms lets any single match surface results.
    words = _websearch_terms(terms)
    websearch_input = " OR ".join(words)

    sql = f"""
        SELECT m.id, m.content, m.category, m.tags, m.importance,
               m.created_at, ts_rank(m.search_vector, query) AS rank
        FROM memories m, websearch_to_tsquery('english', $1) query
        WHERE m.search_vector @@ query {fts_category}
        ORDER BY {order}
        LIMIT $2
    """
    # 0.0 cast to real so `rank` has the same type (float) as ts_rank's result.
    sql_fallback = f"""
        SELECT id, content, category, tags, importance, created_at, CAST(0.0 AS real) AS rank
        FROM memories
        WHERE (content ILIKE $1 OR tags ILIKE $1 OR expanded_keywords ILIKE $1)
        {like_category}
        ORDER BY importance DESC, created_at DESC
        LIMIT $2
    """

    rows = []
    # One connection serves both the full-text query and the fallback.
    async with pool.acquire() as conn:
        if words:
            try:
                rows = await conn.fetch(sql, websearch_input, req.limit, *category_args)
            except asyncpg.PostgresError:
                # Only database errors degrade to the fallback; anything else
                # (e.g. a dropped connection) propagates instead of being hidden.
                logging.getLogger(__name__).warning(
                    "full-text recall failed; using ILIKE fallback", exc_info=True
                )
                rows = []

        # Fallback for terms the stemmer can't handle. Skip it when context is
        # blank: the pattern would be '%%' and match arbitrary rows.
        needle = req.context.strip()
        if not rows and needle:
            like = f"%{_escape_like(needle)}%"
            rows = await conn.fetch(sql_fallback, like, req.limit, *category_args)

    return {"memories": [dict(r) for r in rows]}
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/api/memories", dependencies=[Depends(verify_api_key)])
async def list_memories(category: Optional[str] = None, limit: int = 20):
    """Return the newest memories, optionally restricted to one category.

    Args:
        category: When truthy, only rows with this exact category are returned.
        limit: Maximum number of rows, newest ``created_at`` first.
    """
    base = "SELECT id, content, category, tags, importance, created_at FROM memories"
    if not category:
        query = f"{base} ORDER BY created_at DESC LIMIT $1"
        params = (limit,)
    else:
        query = f"{base} WHERE category = $1 ORDER BY created_at DESC LIMIT $2"
        params = (category, limit)
    async with pool.acquire() as db:
        records = await db.fetch(query, *params)
    return {"memories": [dict(rec) for rec in records]}
|
||
|
|
|
||
|
|
|
||
|
|
@app.delete("/api/memories/{memory_id}", dependencies=[Depends(verify_api_key)])
async def delete_memory(memory_id: int):
    """Delete one memory by id and return its id plus a 50-character preview.

    Raises:
        HTTPException: 404 when no memory has ``memory_id``.
    """
    delete_sql = "DELETE FROM memories WHERE id = $1 RETURNING id, substr(content, 1, 50) AS preview"
    async with pool.acquire() as db:
        removed = await db.fetchrow(delete_sql, memory_id)
    if removed is None:
        raise HTTPException(status_code=404, detail=f"Memory #{memory_id} not found")
    return {"deleted": removed["id"], "preview": removed["preview"]}
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/api/memories/import", dependencies=[Depends(verify_api_key)])
async def import_memories(memories: list[MemoryStore]):
    """Bulk import memories (for migrating from SQLite).

    All rows share one UTC timestamp for ``created_at``/``updated_at``. They are
    inserted with a single ``executemany`` inside one transaction, so the import
    is all-or-nothing: a failing row leaves no partial batch behind that a
    retry would duplicate.

    Returns:
        ``{"imported": n}`` where ``n`` is the number of rows inserted.
    """
    now = datetime.now(timezone.utc)
    records = [
        (mem.content, mem.category, mem.tags, mem.expanded_keywords, mem.importance, now)
        for mem in memories
    ]
    if records:
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.executemany(
                    """INSERT INTO memories (content, category, tags, expanded_keywords, importance, created_at, updated_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $6)""",
                    records,
                )
    return {"imported": len(records)}
|