add MCP SSE transport for direct Claude Code connection
Adds SSE endpoint at /mcp/sse so Claude Code can connect over HTTPS without needing a local Python bridge script. Benefits: - No local files or sandbox permission issues - Works from any machine (OpenClaw, DevVM) - No startup delay or stderr suppression hack - Auth via Bearer token in request headers
This commit is contained in:
parent
18e27d07d2
commit
48df739c82
2 changed files with 211 additions and 3 deletions
|
|
@ -16,7 +16,7 @@ classifiers = [
|
|||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9"]
|
||||
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9", "mcp>=1.0.0"]
|
||||
vault = ["hvac>=2.0"]
|
||||
dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28", "cryptography>=43.0"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from claude_memory.api.auth import AuthUser, get_current_user
|
||||
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
|
||||
from claude_memory.api.database import close_pool, get_pool, init_pool
|
||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
|
||||
from claude_memory.api.vault_service import (
|
||||
|
|
@ -431,3 +437,205 @@ async def import_memories(
|
|||
)
|
||||
|
||||
return imported
|
||||
|
||||
|
||||
# --- MCP SSE Transport ---
|
||||
|
||||
|
||||
def _resolve_user_from_token(token: str) -> str | None:
|
||||
"""Resolve API key to user_id, reusing auth module's key map."""
|
||||
return _key_to_user.get(token)
|
||||
|
||||
|
||||
mcp_server = FastMCP("claude-memory")
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_store(content: str, category: str = "facts", tags: str = "",
|
||||
expanded_keywords: str = "", importance: float = 0.5) -> str:
|
||||
"""Store a new memory."""
|
||||
pool = await get_pool()
|
||||
# MCP SSE uses "default" user (single-user mode via middleware auth)
|
||||
user_id = "default"
|
||||
is_sensitive = _detect_sensitive(content)
|
||||
stored_content = content if not is_sensitive else _redact_content(content)
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id""",
|
||||
user_id, stored_content, category, tags, expanded_keywords, importance, is_sensitive,
|
||||
)
|
||||
memory_id = row["id"]
|
||||
|
||||
if is_sensitive and is_vault_configured():
|
||||
vault_path = await store_secret(user_id, memory_id, content)
|
||||
await conn.execute("UPDATE memories SET vault_path = $1 WHERE id = $2", vault_path, memory_id)
|
||||
|
||||
return json.dumps({"id": memory_id, "category": category, "importance": importance})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_recall(context: str, expanded_query: str = "",
|
||||
category: str | None = None, sort_by: str = "importance",
|
||||
limit: int = 10) -> str:
|
||||
"""Recall memories by semantic search."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
query_text = f"{context} {expanded_query}".strip()
|
||||
if not query_text:
|
||||
return json.dumps({"error": "context is required"})
|
||||
|
||||
hybrid_score = "(ts_rank(search_vector, query) * 0.7 + importance * 0.3)"
|
||||
if sort_by == "importance":
|
||||
hybrid_score = "(ts_rank(search_vector, query) * 0.4 + importance * 0.6)"
|
||||
|
||||
order_clause = f"{hybrid_score} DESC"
|
||||
if sort_by == "recency":
|
||||
order_clause = "created_at DESC"
|
||||
|
||||
category_filter = ""
|
||||
params: list[Any] = [user_id, query_text, limit]
|
||||
if category:
|
||||
category_filter = "AND category = $4"
|
||||
params.append(category)
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT id, content, category, tags, importance, is_sensitive,
|
||||
ts_rank(search_vector, query) AS rank, created_at, updated_at
|
||||
FROM memories, plainto_tsquery('english', $2) query
|
||||
WHERE user_id = $1 AND deleted_at IS NULL
|
||||
AND (search_vector @@ query OR $2 = '')
|
||||
{category_filter}
|
||||
ORDER BY {order_clause}
|
||||
LIMIT $3
|
||||
""",
|
||||
*params,
|
||||
)
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
c = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
results.append({
|
||||
"id": row["id"], "content": c, "category": row["category"],
|
||||
"tags": row["tags"], "importance": row["importance"],
|
||||
"rank": float(row["rank"]),
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
})
|
||||
|
||||
return json.dumps({"memories": results})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||
"""List stored memories."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
|
||||
if category:
|
||||
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2
|
||||
ORDER BY importance DESC LIMIT $3"""
|
||||
params: list[Any] = [user_id, category, limit]
|
||||
else:
|
||||
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
FROM memories WHERE user_id = $1 AND deleted_at IS NULL
|
||||
ORDER BY importance DESC LIMIT $2"""
|
||||
params = [user_id, limit]
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
c = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
results.append({
|
||||
"id": row["id"], "content": c, "category": row["category"],
|
||||
"tags": row["tags"], "importance": row["importance"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
})
|
||||
|
||||
return json.dumps({"memories": results})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_delete(memory_id: int) -> str:
|
||||
"""Delete a memory by ID."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
memory_id, user_id,
|
||||
)
|
||||
if not row:
|
||||
return json.dumps({"deleted": memory_id, "preview": "[already deleted]"})
|
||||
|
||||
if row["vault_path"]:
|
||||
await delete_secret(user_id, row["vault_path"])
|
||||
|
||||
await conn.execute(
|
||||
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
|
||||
memory_id, user_id,
|
||||
)
|
||||
|
||||
return json.dumps({"deleted": memory_id, "preview": row["preview"]})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_count() -> str:
|
||||
"""Count total memories."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
async with pool.acquire() as conn:
|
||||
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
|
||||
return json.dumps({"count": count})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def secret_get(key: str) -> str:
|
||||
"""Retrieve a secret value by key. Returns empty if not found."""
|
||||
return json.dumps({"error": "secret_get is not available via SSE transport"})
|
||||
|
||||
|
||||
# Auth middleware for /mcp/* routes
|
||||
class MCPAuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next): # type: ignore[override]
|
||||
if request.url.path.startswith("/mcp"):
|
||||
auth = request.headers.get("authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not _resolve_user_from_token(token):
|
||||
return Response(content="Unauthorized", status_code=401)
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
app.add_middleware(MCPAuthMiddleware)
|
||||
|
||||
# Mount SSE transport
|
||||
sse_transport = SseServerTransport("/messages/")
|
||||
|
||||
|
||||
class HandleSSE:
|
||||
"""ASGI app for SSE connections."""
|
||||
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
|
||||
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
||||
await mcp_server._mcp_server.run(
|
||||
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()
|
||||
)
|
||||
|
||||
|
||||
# Client connects to /mcp/sse, posts to /mcp/messages/
|
||||
app.router.routes.insert(0, Mount("/mcp", routes=[
|
||||
Route("/sse", endpoint=HandleSSE()),
|
||||
Mount("/messages", app=sse_transport.handle_post_message),
|
||||
]))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue