feat: multi-user MCP SSE support + shared memories in recall/list
- Use contextvars to resolve user identity from Authorization header in SSE connections, replacing hardcoded "default" user_id - memory_recall now includes shared memories (individual + tag-based) with deduplication and shared_by attribution - memory_list now includes shared memories with same approach - All 11 MCP tool functions use _current_user contextvar
This commit is contained in:
parent
95dd937765
commit
f242c45c73
1 changed files with 120 additions and 12 deletions
|
|
@ -4,6 +4,7 @@ import json
|
|||
import logging
|
||||
import pathlib
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
|
|
@ -31,6 +32,9 @@ from claude_memory.api.vault_service import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for MCP SSE multi-user support
|
||||
_current_user: ContextVar[str] = ContextVar("_current_user", default="default")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
|
@ -846,8 +850,7 @@ async def memory_store(content: str, category: str = "facts", tags: str = "",
|
|||
expanded_keywords: str = "", importance: float = 0.5) -> str:
|
||||
"""Store a new memory."""
|
||||
pool = await get_pool()
|
||||
# MCP SSE uses "default" user (single-user mode via middleware auth)
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
is_sensitive = _detect_sensitive(content)
|
||||
stored_content = content if not is_sensitive else _redact_content(content)
|
||||
|
||||
|
|
@ -873,7 +876,7 @@ async def memory_recall(context: str, expanded_query: str = "",
|
|||
limit: int = 10) -> str:
|
||||
"""Recall memories by semantic search."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
query_text = f"{context} {expanded_query}".strip()
|
||||
if not query_text:
|
||||
return json.dumps({"error": "context is required"})
|
||||
|
|
@ -896,7 +899,8 @@ async def memory_recall(context: str, expanded_query: str = "",
|
|||
rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT id, content, category, tags, importance, is_sensitive,
|
||||
ts_rank(search_vector, query) AS rank, created_at, updated_at
|
||||
ts_rank(search_vector, query) AS rank, created_at, updated_at,
|
||||
NULL::text AS shared_by
|
||||
FROM memories, plainto_tsquery('english', $2) query
|
||||
WHERE user_id = $1 AND deleted_at IS NULL
|
||||
AND (search_vector @@ query OR $2 = '')
|
||||
|
|
@ -907,8 +911,50 @@ async def memory_recall(context: str, expanded_query: str = "",
|
|||
*params,
|
||||
)
|
||||
|
||||
# Also fetch shared memories (individual + tag-based)
|
||||
shared_rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
|
||||
m.is_sensitive, ts_rank(m.search_vector, query) AS rank,
|
||||
m.created_at, m.updated_at, m.user_id AS shared_by
|
||||
FROM memories m, plainto_tsquery('english', $2) query
|
||||
WHERE m.deleted_at IS NULL
|
||||
AND (m.search_vector @@ query OR $2 = '')
|
||||
AND m.user_id != $1
|
||||
AND (
|
||||
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM tag_shares ts
|
||||
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
|
||||
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
|
||||
)
|
||||
)
|
||||
ORDER BY m.id
|
||||
LIMIT $3
|
||||
""",
|
||||
*params,
|
||||
)
|
||||
|
||||
seen_ids = set()
|
||||
results = []
|
||||
for row in rows:
|
||||
seen_ids.add(row["id"])
|
||||
c = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
entry: dict[str, Any] = {
|
||||
"id": row["id"], "content": c, "category": row["category"],
|
||||
"tags": row["tags"], "importance": row["importance"],
|
||||
"rank": float(row["rank"]),
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
}
|
||||
results.append(entry)
|
||||
|
||||
for row in shared_rows:
|
||||
if row["id"] in seen_ids:
|
||||
continue
|
||||
seen_ids.add(row["id"])
|
||||
c = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
|
|
@ -916,6 +962,7 @@ async def memory_recall(context: str, expanded_query: str = "",
|
|||
"id": row["id"], "content": c, "category": row["category"],
|
||||
"tags": row["tags"], "importance": row["importance"],
|
||||
"rank": float(row["rank"]),
|
||||
"shared_by": row["shared_by"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
})
|
||||
|
|
@ -927,7 +974,7 @@ async def memory_recall(context: str, expanded_query: str = "",
|
|||
async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
||||
"""List stored memories."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
|
||||
if category:
|
||||
query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
|
|
@ -940,11 +987,47 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
|||
ORDER BY importance DESC LIMIT $2"""
|
||||
params = [user_id, limit]
|
||||
|
||||
if category:
|
||||
shared_query = """
|
||||
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
|
||||
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
|
||||
FROM memories m
|
||||
WHERE m.deleted_at IS NULL AND m.category = $2 AND m.user_id != $1
|
||||
AND (
|
||||
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM tag_shares ts
|
||||
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
|
||||
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
|
||||
)
|
||||
)
|
||||
ORDER BY m.id LIMIT $3"""
|
||||
shared_params: list[Any] = [user_id, category, limit]
|
||||
else:
|
||||
shared_query = """
|
||||
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance,
|
||||
m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by
|
||||
FROM memories m
|
||||
WHERE m.deleted_at IS NULL AND m.user_id != $1
|
||||
AND (
|
||||
EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1)
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM tag_shares ts
|
||||
WHERE ts.owner_id = m.user_id AND ts.shared_with = $1
|
||||
AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag)
|
||||
)
|
||||
)
|
||||
ORDER BY m.id LIMIT $2"""
|
||||
shared_params = [user_id, limit]
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(query, *params)
|
||||
shared_rows = await conn.fetch(shared_query, *shared_params)
|
||||
|
||||
seen_ids = set()
|
||||
results = []
|
||||
for row in rows:
|
||||
seen_ids.add(row["id"])
|
||||
c = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
|
|
@ -955,6 +1038,21 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
|||
"updated_at": row["updated_at"].isoformat(),
|
||||
})
|
||||
|
||||
for row in shared_rows:
|
||||
if row["id"] in seen_ids:
|
||||
continue
|
||||
seen_ids.add(row["id"])
|
||||
c = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
c = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
results.append({
|
||||
"id": row["id"], "content": c, "category": row["category"],
|
||||
"tags": row["tags"], "importance": row["importance"],
|
||||
"shared_by": row["shared_by"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
})
|
||||
|
||||
return json.dumps({"memories": results})
|
||||
|
||||
|
||||
|
|
@ -962,7 +1060,7 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str:
|
|||
async def memory_delete(memory_id: int) -> str:
|
||||
"""Delete a memory by ID."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
|
|
@ -987,7 +1085,7 @@ async def memory_delete(memory_id: int) -> str:
|
|||
async def memory_count() -> str:
|
||||
"""Count total memories."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
async with pool.acquire() as conn:
|
||||
count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id)
|
||||
return json.dumps({"count": count})
|
||||
|
|
@ -1005,7 +1103,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
|
|||
if permission not in ("read", "write"):
|
||||
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
|
|
@ -1026,7 +1124,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s
|
|||
async def memory_unshare(id: int, shared_with: str) -> str:
|
||||
"""Revoke sharing of a memory from a user."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
|
||||
|
|
@ -1041,7 +1139,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
|
|||
if permission not in ("read", "write"):
|
||||
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
||||
|
|
@ -1056,7 +1154,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read")
|
|||
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
|
||||
"""Revoke tag-based sharing."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
||||
|
|
@ -1070,7 +1168,7 @@ async def memory_update(id: int, content: str | None = None, tags: str | None =
|
|||
importance: float | None = None, expanded_keywords: str | None = None) -> str:
|
||||
"""Update an existing memory's content, tags, importance, or keywords."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
user_id = _current_user.get()
|
||||
async with pool.acquire() as conn:
|
||||
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
|
||||
if not allowed:
|
||||
|
|
@ -1132,6 +1230,16 @@ sse_transport = SseServerTransport("/messages/")
|
|||
class HandleSSE:
|
||||
"""ASGI app for SSE connections."""
|
||||
async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
|
||||
# Extract user from Authorization header for multi-user MCP
|
||||
user_id = "default"
|
||||
for name, value in scope.get("headers", []):
|
||||
if name == b"authorization":
|
||||
token = value.decode().removeprefix("Bearer ").strip()
|
||||
resolved = _resolve_user_from_token(token)
|
||||
if resolved:
|
||||
user_id = resolved
|
||||
break
|
||||
_current_user.set(user_id)
|
||||
async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream):
|
||||
await mcp_server._mcp_server.run(
|
||||
read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue