add MCP SSE transport for direct Claude Code connection
Adds SSE endpoint at /mcp/sse so Claude Code can connect over HTTPS without needing a local Python bridge script. Benefits: - No local files or sandbox permission issues - Works from any machine (OpenClaw, DevVM) - No startup delay or stderr suppression hack - Auth via Bearer token in request headers
This commit is contained in:
parent
18e27d07d2
commit
48df739c82
2 changed files with 211 additions and 3 deletions
|
|
@ -16,7 +16,7 @@ classifiers = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9"]
|
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0", "alembic>=1.14", "sqlalchemy>=2.0", "psycopg2-binary>=2.9", "mcp>=1.0.0"]
|
||||||
vault = ["hvac>=2.0"]
|
vault = ["hvac>=2.0"]
|
||||||
dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28", "cryptography>=43.0"]
|
dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28", "cryptography>=43.0"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,19 @@
|
||||||
"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search."""
|
"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, AsyncGenerator, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from mcp.server.sse import SseServerTransport
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
from claude_memory.api.auth import AuthUser, get_current_user
|
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
|
||||||
from claude_memory.api.database import close_pool, get_pool, init_pool
|
from claude_memory.api.database import close_pool, get_pool, init_pool
|
||||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
|
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
|
||||||
from claude_memory.api.vault_service import (
|
from claude_memory.api.vault_service import (
|
||||||
|
|
@ -431,3 +437,205 @@ async def import_memories(
|
||||||
)
|
)
|
||||||
|
|
||||||
return imported
|
return imported
|
||||||
|
|
||||||
|
|
||||||
|
# --- MCP SSE Transport ---
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_user_from_token(token: str) -> str | None:
|
||||||
|
"""Resolve API key to user_id, reusing auth module's key map."""
|
||||||
|
return _key_to_user.get(token)
|
||||||
|
|
||||||
|
|
||||||
|
mcp_server = FastMCP("claude-memory")
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_store(content: str, category: str = "facts", tags: str = "",
|
||||||
|
expanded_keywords: str = "", importance: float = 0.5) -> str:
|
||||||
|
"""Store a new memory."""
|
||||||
|
pool = await get_pool()
|
||||||
|
# MCP SSE uses "default" user (single-user mode via middleware auth)
|
||||||
|
user_id = "default"
|
||||||
|
is_sensitive = _detect_sensitive(content)
|
||||||
|
stored_content = content if not is_sensitive else _redact_content(content)
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"""INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
RETURNING id""",
|
||||||
|
user_id, stored_content, category, tags, expanded_keywords, importance, is_sensitive,
|
||||||
|
)
|
||||||
|
memory_id = row["id"]
|
||||||
|
|
||||||
|
if is_sensitive and is_vault_configured():
|
||||||
|
vault_path = await store_secret(user_id, memory_id, content)
|
||||||
|
await conn.execute("UPDATE memories SET vault_path = $1 WHERE id = $2", vault_path, memory_id)
|
||||||
|
|
||||||
|
return json.dumps({"id": memory_id, "category": category, "importance": importance})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_recall(context: str, expanded_query: str = "",
|
||||||
|
category: str | None = None, sort_by: str = "importance",
|
||||||
|
limit: int = 10) -> str:
|
||||||
|
"""Recall memories by semantic search."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
query_text = f"{context} {expanded_query}".strip()
|
||||||
|
if not query_text:
|
||||||
|
return json.dumps({"error": "context is required"})
|
||||||
|
|
||||||
|
hybrid_score = "(ts_rank(search_vector, query) * 0.7 + importance * 0.3)"
|
||||||
|
if sort_by == "importance":
|
||||||
|
hybrid_score = "(ts_rank(search_vector, query) * 0.4 + importance * 0.6)"
|
||||||
|
|
||||||
|
order_clause = f"{hybrid_score} DESC"
|
||||||
|
if sort_by == "recency":
|
||||||
|
order_clause = "created_at DESC"
|
||||||
|
|
||||||
|
category_filter = ""
|
||||||
|
params: list[Any] = [user_id, query_text, limit]
|
||||||
|
if category:
|
||||||
|
category_filter = "AND category = $4"
|
||||||
|
params.append(category)
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(
|
||||||
|
f"""
|
||||||
|
SELECT id, content, category, tags, importance, is_sensitive,
|
||||||
|
ts_rank(search_vector, query) AS rank, created_at, updated_at
|
||||||
|
FROM memories, plainto_tsquery('english', $2) query
|
||||||
|
WHERE user_id = $1 AND deleted_at IS NULL
|
||||||
|
AND (search_vector @@ query OR $2 = '')
|
||||||
|
{category_filter}
|
||||||
|
ORDER BY {order_clause}
|
||||||
|
LIMIT $3
|
||||||
|
""",
|
||||||
|
*params,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
c = row["content"]
|
||||||
|
if row["is_sensitive"]:
|
||||||
|
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
results.append({
|
||||||
|
"id": row["id"], "content": c, "category": row["category"],
|
||||||
|
"tags": row["tags"], "importance": row["importance"],
|
||||||
|
"rank": float(row["rank"]),
|
||||||
|
"created_at": row["created_at"].isoformat(),
|
||||||
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return json.dumps({"memories": results})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||||
|
"""List stored memories."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
|
||||||
|
if category:
|
||||||
|
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||||
|
FROM memories WHERE user_id = $1 AND deleted_at IS NULL AND category = $2
|
||||||
|
ORDER BY importance DESC LIMIT $3"""
|
||||||
|
params: list[Any] = [user_id, category, limit]
|
||||||
|
else:
|
||||||
|
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||||
|
FROM memories WHERE user_id = $1 AND deleted_at IS NULL
|
||||||
|
ORDER BY importance DESC LIMIT $2"""
|
||||||
|
params = [user_id, limit]
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(query, *params)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for row in rows:
|
||||||
|
c = row["content"]
|
||||||
|
if row["is_sensitive"]:
|
||||||
|
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||||
|
results.append({
|
||||||
|
"id": row["id"], "content": c, "category": row["category"],
|
||||||
|
"tags": row["tags"], "importance": row["importance"],
|
||||||
|
"created_at": row["created_at"].isoformat(),
|
||||||
|
"updated_at": row["updated_at"].isoformat(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return json.dumps({"memories": results})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_delete(memory_id: int) -> str:
|
||||||
|
"""Delete a memory by ID."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(
|
||||||
|
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||||
|
memory_id, user_id,
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
return json.dumps({"deleted": memory_id, "preview": "[already deleted]"})
|
||||||
|
|
||||||
|
if row["vault_path"]:
|
||||||
|
await delete_secret(user_id, row["vault_path"])
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
|
||||||
|
memory_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return json.dumps({"deleted": memory_id, "preview": row["preview"]})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def memory_count() -> str:
|
||||||
|
"""Count total memories."""
|
||||||
|
pool = await get_pool()
|
||||||
|
user_id = "default"
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
|
||||||
|
return json.dumps({"count": count})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp_server.tool()
|
||||||
|
async def secret_get(key: str) -> str:
|
||||||
|
"""Retrieve a secret value by key. Returns empty if not found."""
|
||||||
|
return json.dumps({"error": "secret_get is not available via SSE transport"})
|
||||||
|
|
||||||
|
|
||||||
|
# Auth middleware for /mcp/* routes
|
||||||
|
class MCPAuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next): # type: ignore[override]
|
||||||
|
if request.url.path.startswith("/mcp"):
|
||||||
|
auth = request.headers.get("authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not _resolve_user_from_token(token):
|
||||||
|
return Response(content="Unauthorized", status_code=401)
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
app.add_middleware(MCPAuthMiddleware)
|
||||||
|
|
||||||
|
# Mount SSE transport
|
||||||
|
sse_transport = SseServerTransport("/messages/")
|
||||||
|
|
||||||
|
|
||||||
|
class HandleSSE:
|
||||||
|
"""ASGI app for SSE connections."""
|
||||||
|
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
|
||||||
|
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
||||||
|
await mcp_server._mcp_server.run(
|
||||||
|
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Client connects to /mcp/sse, posts to /mcp/messages/
|
||||||
|
app.router.routes.insert(0, Mount("/mcp", routes=[
|
||||||
|
Route("/sse", endpoint=HandleSSE()),
|
||||||
|
Mount("/messages", app=sse_transport.handle_post_message),
|
||||||
|
]))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue