From f242c45c73e064b138b9648e676f25d8efcf4658 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sun, 22 Mar 2026 22:50:18 +0200 Subject: [PATCH] feat: multi-user MCP SSE support + shared memories in recall/list - Use contextvars to resolve user identity from Authorization header in SSE connections, replacing hardcoded "default" user_id - memory_recall now includes shared memories (individual + tag-based) with deduplication and shared_by attribution - memory_list now includes shared memories with same approach - All 11 MCP tool functions use _current_user contextvar --- src/claude_memory/api/app.py | 132 +++++++++++++++++++++++++++++++---- 1 file changed, 120 insertions(+), 12 deletions(-) diff --git a/src/claude_memory/api/app.py b/src/claude_memory/api/app.py index 12aaf67..526c0b5 100644 --- a/src/claude_memory/api/app.py +++ b/src/claude_memory/api/app.py @@ -4,6 +4,7 @@ import json import logging import pathlib from contextlib import asynccontextmanager +from contextvars import ContextVar from datetime import datetime, timezone from typing import Any, AsyncGenerator, Optional @@ -31,6 +32,9 @@ from claude_memory.api.vault_service import ( logger = logging.getLogger(__name__) +# Context variable for MCP SSE multi-user support +_current_user: ContextVar[str] = ContextVar("_current_user", default="default") + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: @@ -846,8 +850,7 @@ async def memory_store(content: str, category: str = "facts", tags: str = "", expanded_keywords: str = "", importance: float = 0.5) -> str: """Store a new memory.""" pool = await get_pool() - # MCP SSE uses "default" user (single-user mode via middleware auth) - user_id = "default" + user_id = _current_user.get() is_sensitive = _detect_sensitive(content) stored_content = content if not is_sensitive else _redact_content(content) @@ -873,7 +876,7 @@ async def memory_recall(context: str, expanded_query: str = "", limit: int = 10) -> str: """Recall memories by semantic search.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() query_text = f"{context} {expanded_query}".strip() if not query_text: return json.dumps({"error": "context is required"}) @@ -896,7 +899,8 @@ async def memory_recall(context: str, expanded_query: str = "", rows = await conn.fetch( f""" SELECT id, content, category, tags, importance, is_sensitive, - ts_rank(search_vector, query) AS rank, created_at, updated_at + ts_rank(search_vector, query) AS rank, created_at, updated_at, + NULL::text AS shared_by FROM memories, plainto_tsquery('english', $2) query WHERE user_id = $1 AND deleted_at IS NULL AND (search_vector @@ query OR $2 = '') @@ -907,8 +911,50 @@ async def memory_recall(context: str, expanded_query: str = "", *params, ) + # Also fetch shared memories (individual + tag-based) + shared_rows = await conn.fetch( + f""" + SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance, + m.is_sensitive, ts_rank(m.search_vector, query) AS rank, + m.created_at, m.updated_at, m.user_id AS shared_by + FROM memories m, plainto_tsquery('english', $2) query + WHERE m.deleted_at IS NULL + AND (m.search_vector @@ query OR $2 = '') + AND m.user_id != $1 + AND ( + EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1) + OR EXISTS ( + SELECT 1 FROM tag_shares ts + WHERE ts.owner_id = m.user_id AND ts.shared_with = $1 + AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag) + ) + ) + ORDER BY m.id + LIMIT $3 + """, + *params, + ) + + seen_ids = set() results = [] for row in rows: + seen_ids.add(row["id"]) + c = row["content"] + if row["is_sensitive"]: + c = f"[SENSITIVE - use secret_get(id={row['id']})]" + entry: dict[str, Any] = { + "id": row["id"], "content": c, "category": row["category"], + "tags": row["tags"], "importance": row["importance"], + "rank": float(row["rank"]), + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + results.append(entry) + + for row in shared_rows: + if row["id"] in seen_ids: + continue + seen_ids.add(row["id"]) c = row["content"] if row["is_sensitive"]: c = f"[SENSITIVE - use secret_get(id={row['id']})]" @@ -916,6 +962,7 @@ async def memory_recall(context: str, expanded_query: str = "", "id": row["id"], "content": c, "category": row["category"], "tags": row["tags"], "importance": row["importance"], "rank": float(row["rank"]), + "shared_by": row["shared_by"], "created_at": row["created_at"].isoformat(), "updated_at": row["updated_at"].isoformat(), }) @@ -927,7 +974,7 @@ async def memory_recall(context: str, expanded_query: str = "", async def memory_list(category: str | None = None, limit: int = 20) -> str: """List stored memories.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() if category: query = """SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at @@ -940,11 +987,47 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str: ORDER BY importance DESC LIMIT $2""" params = [user_id, limit] + if category: + shared_query = """ + SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance, + m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by + FROM memories m + WHERE m.deleted_at IS NULL AND m.category = $2 AND m.user_id != $1 + AND ( + EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1) + OR EXISTS ( + SELECT 1 FROM tag_shares ts + WHERE ts.owner_id = m.user_id AND ts.shared_with = $1 + AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag) + ) + ) + ORDER BY m.id LIMIT $3""" + shared_params: list[Any] = [user_id, category, limit] + else: + shared_query = """ + SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance, + m.is_sensitive, m.created_at, m.updated_at, m.user_id AS shared_by + FROM memories m + WHERE m.deleted_at IS NULL AND m.user_id != $1 + AND ( + EXISTS (SELECT 1 FROM memory_shares ms WHERE ms.memory_id = m.id AND ms.shared_with = $1) + OR EXISTS ( + SELECT 1 FROM tag_shares ts + WHERE ts.owner_id = m.user_id AND ts.shared_with = $1 + AND EXISTS (SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t WHERE trim(t) = ts.tag) + ) + ) + ORDER BY m.id LIMIT $2""" + shared_params = [user_id, limit] + async with pool.acquire() as conn: rows = await conn.fetch(query, *params) + shared_rows = await conn.fetch(shared_query, *shared_params) + seen_ids = set() results = [] for row in rows: + seen_ids.add(row["id"]) c = row["content"] if row["is_sensitive"]: c = f"[SENSITIVE - use secret_get(id={row['id']})]" @@ -955,6 +1038,21 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str: "updated_at": row["updated_at"].isoformat(), }) + for row in shared_rows: + if row["id"] in seen_ids: + continue + seen_ids.add(row["id"]) + c = row["content"] + if row["is_sensitive"]: + c = f"[SENSITIVE - use secret_get(id={row['id']})]" + results.append({ + "id": row["id"], "content": c, "category": row["category"], + "tags": row["tags"], "importance": row["importance"], + "shared_by": row["shared_by"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + }) + return json.dumps({"memories": results}) @@ -962,7 +1060,7 @@ async def memory_list(category: str | None = None, limit: int = 20) -> str: async def memory_delete(memory_id: int) -> str: """Delete a memory by ID.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: row = await conn.fetchrow( @@ -987,7 +1085,7 @@ async def memory_delete(memory_id: int) -> str: async def memory_count() -> str: """Count total memories.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: count = await conn.fetchval("SELECT COUNT(*) FROM memories WHERE user_id = $1 AND deleted_at IS NULL", user_id) return json.dumps({"count": count}) @@ -1005,7 +1103,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s if permission not in ("read", "write"): return json.dumps({"error": "permission must be 'read' or 'write'"}) pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: row = await conn.fetchrow( "SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL", @@ -1026,7 +1124,7 @@ async def memory_share(id: int, shared_with: str, permission: str = "read") -> s async def memory_unshare(id: int, shared_with: str) -> str: """Revoke sharing of a memory from a user.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: await conn.execute( "DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3", @@ -1041,7 +1139,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read") if permission not in ("read", "write"): return json.dumps({"error": "permission must be 'read' or 'write'"}) pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: await conn.execute( """INSERT INTO tag_shares (owner_id, tag, shared_with, permission) @@ -1056,7 +1154,7 @@ async def memory_share_tag(tag: str, shared_with: str, permission: str = "read") async def memory_unshare_tag(tag: str, shared_with: str) -> str: """Revoke tag-based sharing.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: await conn.execute( "DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3", @@ -1070,7 +1168,7 @@ async def memory_update(id: int, content: str | None = None, tags: str | None = importance: float | None = None, expanded_keywords: str | None = None) -> str: """Update an existing memory's content, tags, importance, or keywords.""" pool = await get_pool() - user_id = "default" + user_id = _current_user.get() async with pool.acquire() as conn: allowed, owner_id = await check_memory_permission(conn, id, user_id, "write") if not allowed: @@ -1132,6 +1230,16 @@ sse_transport = SseServerTransport("/messages/") class HandleSSE: """ASGI app for SSE connections.""" async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + # Extract user from Authorization header for multi-user MCP + user_id = "default" + for name, value in scope.get("headers", []): + if name == b"authorization": + token = value.decode().removeprefix("Bearer ").strip() + resolved = _resolve_user_from_token(token) + if resolved: + user_id = resolved + break + _current_user.set(user_id) async with sse_transport.connect_sse(scope, receive, send) as (read_stream, write_stream): await mcp_server._mcp_server.run( read_stream, write_stream, mcp_server._mcp_server.create_initialization_options()