add multi-user memory sharing with r/w permissions
- New migration 004: memory_shares and tag_shares tables with indexes
- Share individual memories or entire tags with other users (read/write)
- Tag shares are live rules: future memories with shared tags auto-visible
- Recall query merges own + shared memories via UNION, returns shared_by field
- Owner-only delete enforcement (403 for non-owners, even with write access)
- PUT /api/memories/{id} update endpoint with permission checks
- 5 new MCP SSE tools: memory_share, memory_unshare, memory_share_tag,
memory_unshare_tag, memory_update
- Permission helper checks ownership, individual shares, and tag shares
This commit is contained in:
parent
1a275e976c
commit
f45e8ce2b3
4 changed files with 556 additions and 13 deletions
62
migrations/versions/004_add_sharing.py
Normal file
62
migrations/versions/004_add_sharing.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Add memory sharing tables.
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2026-03-22
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "004"
|
||||
down_revision: Union[str, None] = "003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _table_exists(conn, table_name: str) -> bool:
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS(SELECT 1 FROM information_schema.tables "
|
||||
"WHERE table_name = :tbl)"
|
||||
),
|
||||
{"tbl": table_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
if not _table_exists(conn, "memory_shares"):
|
||||
op.create_table(
|
||||
"memory_shares",
|
||||
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
|
||||
sa.Column("memory_id", sa.Integer, sa.ForeignKey("memories.id"), nullable=False),
|
||||
sa.Column("owner_id", sa.String(100), nullable=False),
|
||||
sa.Column("shared_with", sa.String(100), nullable=False),
|
||||
sa.Column("permission", sa.String(10), nullable=False, server_default="read"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
|
||||
sa.UniqueConstraint("memory_id", "shared_with", name="uq_memory_shares_memory_user"),
|
||||
)
|
||||
op.create_index("idx_shares_shared_with", "memory_shares", ["shared_with"])
|
||||
op.create_index("idx_shares_memory_id", "memory_shares", ["memory_id"])
|
||||
|
||||
if not _table_exists(conn, "tag_shares"):
|
||||
op.create_table(
|
||||
"tag_shares",
|
||||
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
|
||||
sa.Column("owner_id", sa.String(100), nullable=False),
|
||||
sa.Column("tag", sa.String(100), nullable=False),
|
||||
sa.Column("shared_with", sa.String(100), nullable=False),
|
||||
sa.Column("permission", sa.String(10), nullable=False, server_default="read"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
|
||||
sa.UniqueConstraint("owner_id", "tag", "shared_with", name="uq_tag_shares_owner_tag_user"),
|
||||
)
|
||||
op.create_index("idx_tag_shares_shared_with", "tag_shares", ["shared_with"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tag_shares")
|
||||
op.drop_table("memory_shares")
|
||||
|
|
@ -15,7 +15,11 @@ from starlette.routing import Mount, Route
|
|||
|
||||
from claude_memory.api.auth import AuthUser, get_current_user, _key_to_user
|
||||
from claude_memory.api.database import close_pool, get_pool, init_pool
|
||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse, SyncResponse
|
||||
from claude_memory.api.models import (
|
||||
MemoryRecall, MemoryResponse, MemoryStore, MemoryUpdate,
|
||||
SecretResponse, ShareMemory, ShareTag, SyncResponse, UnshareTag,
|
||||
)
|
||||
from claude_memory.api.permissions import check_memory_permission
|
||||
from claude_memory.api.vault_service import (
|
||||
delete_secret,
|
||||
get_secret,
|
||||
|
|
@ -179,13 +183,13 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
|||
params.append(body.category)
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Try AND-match first (plainto_tsquery ANDs by default), fall back to
|
||||
# OR-match via individual word disjunction for broader results
|
||||
# Own memories (AND-match)
|
||||
rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT id, content, category, tags, importance, is_sensitive,
|
||||
ts_rank(search_vector, query) AS rank,
|
||||
created_at, updated_at
|
||||
created_at, updated_at,
|
||||
NULL::text AS shared_by, NULL::text AS share_permission
|
||||
FROM memories, plainto_tsquery('english', $2) query
|
||||
WHERE user_id = $1
|
||||
AND deleted_at IS NULL
|
||||
|
|
@ -197,8 +201,65 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
|||
*params,
|
||||
)
|
||||
|
||||
# If AND-match returned too few results, broaden to OR-match
|
||||
if len(rows) < body.limit and query_text:
|
||||
# Individually shared memories
|
||||
shared_rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||
ts_rank(m.search_vector, query) AS rank,
|
||||
m.created_at, m.updated_at,
|
||||
m.user_id AS shared_by, ms.permission AS share_permission
|
||||
FROM memories m
|
||||
JOIN memory_shares ms ON ms.memory_id = m.id,
|
||||
plainto_tsquery('english', $2) query
|
||||
WHERE ms.shared_with = $1
|
||||
AND m.deleted_at IS NULL
|
||||
AND (m.search_vector @@ query OR $2 = '')
|
||||
{category_filter}
|
||||
ORDER BY {order_clause}
|
||||
LIMIT $3
|
||||
""",
|
||||
*params,
|
||||
)
|
||||
|
||||
# Tag-shared memories
|
||||
tag_shared_rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT DISTINCT ON (m.id)
|
||||
m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||
ts_rank(m.search_vector, query) AS rank,
|
||||
m.created_at, m.updated_at,
|
||||
m.user_id AS shared_by, ts.permission AS share_permission
|
||||
FROM memories m
|
||||
JOIN tag_shares ts ON ts.owner_id = m.user_id,
|
||||
plainto_tsquery('english', $2) query
|
||||
WHERE ts.shared_with = $1
|
||||
AND m.deleted_at IS NULL
|
||||
AND (m.search_vector @@ query OR $2 = '')
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
|
||||
WHERE trim(t) = ts.tag
|
||||
)
|
||||
{category_filter}
|
||||
ORDER BY m.id
|
||||
LIMIT $3
|
||||
""",
|
||||
*params,
|
||||
)
|
||||
|
||||
# Merge and deduplicate
|
||||
seen_ids: set[int] = set()
|
||||
all_rows = []
|
||||
for row in list(rows) + list(shared_rows) + list(tag_shared_rows):
|
||||
if row["id"] not in seen_ids:
|
||||
seen_ids.add(row["id"])
|
||||
all_rows.append(row)
|
||||
|
||||
# Sort merged results by importance desc and trim
|
||||
all_rows.sort(key=lambda r: r["importance"], reverse=True)
|
||||
all_rows = all_rows[:body.limit]
|
||||
|
||||
# If AND-match returned too few results, broaden to OR-match (own memories only)
|
||||
if len(all_rows) < body.limit and query_text:
|
||||
words = query_text.split()
|
||||
if len(words) > 1:
|
||||
or_tsquery = " | ".join(w for w in words if w)
|
||||
|
|
@ -207,12 +268,12 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
|||
if body.category:
|
||||
or_cat_filter = "AND category = $4"
|
||||
or_params.append(body.category)
|
||||
seen_ids = {r["id"] for r in rows}
|
||||
or_rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT id, content, category, tags, importance, is_sensitive,
|
||||
ts_rank(search_vector, query) AS rank,
|
||||
created_at, updated_at
|
||||
created_at, updated_at,
|
||||
NULL::text AS shared_by, NULL::text AS share_permission
|
||||
FROM memories, to_tsquery('english', $2) query
|
||||
WHERE user_id = $1
|
||||
AND deleted_at IS NULL
|
||||
|
|
@ -223,11 +284,11 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
|||
""",
|
||||
*or_params,
|
||||
)
|
||||
rows = list(rows) + [r for r in or_rows if r["id"] not in seen_ids]
|
||||
rows = rows[:body.limit]
|
||||
all_rows = all_rows + [r for r in or_rows if r["id"] not in seen_ids]
|
||||
all_rows = all_rows[:body.limit]
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
for row in all_rows:
|
||||
content = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
|
|
@ -242,6 +303,8 @@ async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_curre
|
|||
"rank": float(row["rank"]),
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
"shared_by": row["shared_by"],
|
||||
"share_permission": row["share_permission"],
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -300,19 +363,27 @@ async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_use
|
|||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Only the owner can delete — even write-shared users cannot
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id, vault_path, substr(content, 1, 50) AS preview FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
)
|
||||
if not row:
|
||||
# Check if memory exists but is owned by someone else
|
||||
exists = await conn.fetchrow(
|
||||
"SELECT id FROM memories WHERE id = $1 AND deleted_at IS NULL", memory_id
|
||||
)
|
||||
if exists:
|
||||
raise HTTPException(status_code=403, detail="Only the owner can delete a memory")
|
||||
# Idempotent: return success even if already deleted
|
||||
# Prevents old clients without 404-handling from infinite retry loops
|
||||
return {"deleted": memory_id, "preview": "[already deleted]"}
|
||||
|
||||
if row["vault_path"]:
|
||||
await delete_secret(user.user_id, row["vault_path"])
|
||||
|
||||
# Also clean up any shares for this memory
|
||||
await conn.execute("DELETE FROM memory_shares WHERE memory_id = $1", memory_id)
|
||||
await conn.execute(
|
||||
"UPDATE memories SET deleted_at = NOW(), updated_at = NOW() WHERE id = $1 AND user_id = $2",
|
||||
memory_id,
|
||||
|
|
@ -439,6 +510,221 @@ async def import_memories(
|
|||
return imported
|
||||
|
||||
|
||||
# --- Sharing Endpoints ---
|
||||
|
||||
|
||||
@app.post("/api/memories/{memory_id}/share")
|
||||
async def share_memory(memory_id: int, body: ShareMemory, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
memory_id, user.user_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Memory not found or not owned by you")
|
||||
if body.shared_with == user.user_id:
|
||||
raise HTTPException(status_code=400, detail="Cannot share with yourself")
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_shares (memory_id, owner_id, shared_with, permission)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (memory_id, shared_with)
|
||||
DO UPDATE SET permission = EXCLUDED.permission
|
||||
""",
|
||||
memory_id, user.user_id, body.shared_with, body.permission,
|
||||
)
|
||||
return {"shared": memory_id, "with": body.shared_with, "permission": body.permission}
|
||||
|
||||
|
||||
@app.delete("/api/memories/{memory_id}/share/{target_user}")
|
||||
async def unshare_memory(memory_id: int, target_user: str, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
memory_id, user.user_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Memory not found or not owned by you")
|
||||
|
||||
await conn.execute(
|
||||
"DELETE FROM memory_shares WHERE memory_id = $1 AND shared_with = $2",
|
||||
memory_id, target_user,
|
||||
)
|
||||
return {"unshared": memory_id, "from": target_user}
|
||||
|
||||
|
||||
@app.post("/api/memories/share-tag")
|
||||
async def share_tag(body: ShareTag, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
if body.shared_with == user.user_id:
|
||||
raise HTTPException(status_code=400, detail="Cannot share with yourself")
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (owner_id, tag, shared_with)
|
||||
DO UPDATE SET permission = EXCLUDED.permission
|
||||
""",
|
||||
user.user_id, body.tag, body.shared_with, body.permission,
|
||||
)
|
||||
return {"shared_tag": body.tag, "with": body.shared_with, "permission": body.permission}
|
||||
|
||||
|
||||
@app.delete("/api/memories/share-tag")
|
||||
async def unshare_tag(body: UnshareTag, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
||||
user.user_id, body.tag, body.shared_with,
|
||||
)
|
||||
return {"unshared_tag": body.tag, "from": body.shared_with}
|
||||
|
||||
|
||||
@app.get("/api/memories/shared-with-me")
|
||||
async def shared_with_me(user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Individual shares
|
||||
individual = await conn.fetch(
|
||||
"""
|
||||
SELECT m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||
m.created_at, m.updated_at, m.user_id AS shared_by, ms.permission
|
||||
FROM memories m
|
||||
JOIN memory_shares ms ON ms.memory_id = m.id
|
||||
WHERE ms.shared_with = $1 AND m.deleted_at IS NULL
|
||||
ORDER BY m.importance DESC
|
||||
""",
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
# Tag shares
|
||||
tag_shared = await conn.fetch(
|
||||
"""
|
||||
SELECT DISTINCT ON (m.id) m.id, m.content, m.category, m.tags, m.importance, m.is_sensitive,
|
||||
m.created_at, m.updated_at, m.user_id AS shared_by, ts.permission
|
||||
FROM memories m
|
||||
JOIN tag_shares ts ON ts.owner_id = m.user_id
|
||||
WHERE ts.shared_with = $1 AND m.deleted_at IS NULL
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
|
||||
WHERE trim(t) = ts.tag
|
||||
)
|
||||
ORDER BY m.id, m.importance DESC
|
||||
""",
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
seen_ids = set()
|
||||
results = []
|
||||
for row in list(individual) + list(tag_shared):
|
||||
if row["id"] in seen_ids:
|
||||
continue
|
||||
seen_ids.add(row["id"])
|
||||
content = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
results.append({
|
||||
"id": row["id"], "content": content, "category": row["category"],
|
||||
"tags": row["tags"], "importance": row["importance"],
|
||||
"shared_by": row["shared_by"], "permission": row["permission"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
})
|
||||
|
||||
return {"memories": results}
|
||||
|
||||
|
||||
@app.get("/api/memories/my-shares")
|
||||
async def my_shares(user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
memory_shares = await conn.fetch(
|
||||
"""
|
||||
SELECT ms.memory_id, ms.shared_with, ms.permission, ms.created_at,
|
||||
substr(m.content, 1, 80) AS preview
|
||||
FROM memory_shares ms
|
||||
JOIN memories m ON m.id = ms.memory_id
|
||||
WHERE ms.owner_id = $1 AND m.deleted_at IS NULL
|
||||
ORDER BY ms.created_at DESC
|
||||
""",
|
||||
user.user_id,
|
||||
)
|
||||
tag_shares = await conn.fetch(
|
||||
"SELECT tag, shared_with, permission, created_at FROM tag_shares WHERE owner_id = $1 ORDER BY created_at DESC",
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"memory_shares": [
|
||||
{
|
||||
"memory_id": r["memory_id"], "shared_with": r["shared_with"],
|
||||
"permission": r["permission"], "preview": r["preview"],
|
||||
"created_at": r["created_at"].isoformat(),
|
||||
}
|
||||
for r in memory_shares
|
||||
],
|
||||
"tag_shares": [
|
||||
{
|
||||
"tag": r["tag"], "shared_with": r["shared_with"],
|
||||
"permission": r["permission"],
|
||||
"created_at": r["created_at"].isoformat(),
|
||||
}
|
||||
for r in tag_shares
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.put("/api/memories/{memory_id}")
|
||||
async def update_memory(memory_id: int, body: MemoryUpdate, user: AuthUser = Depends(get_current_user)) -> dict[str, Any]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
allowed, owner_id = await check_memory_permission(conn, memory_id, user.user_id, "write")
|
||||
if not allowed:
|
||||
if owner_id is None:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
raise HTTPException(status_code=403, detail="Write permission required")
|
||||
|
||||
updates = []
|
||||
params: list[Any] = []
|
||||
idx = 1
|
||||
|
||||
if body.content is not None:
|
||||
updates.append(f"content = ${idx}")
|
||||
params.append(body.content)
|
||||
idx += 1
|
||||
if body.tags is not None:
|
||||
updates.append(f"tags = ${idx}")
|
||||
params.append(body.tags)
|
||||
idx += 1
|
||||
if body.importance is not None:
|
||||
updates.append(f"importance = ${idx}")
|
||||
params.append(body.importance)
|
||||
idx += 1
|
||||
if body.expanded_keywords is not None:
|
||||
updates.append(f"expanded_keywords = ${idx}")
|
||||
params.append(body.expanded_keywords)
|
||||
idx += 1
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
|
||||
updates.append("updated_at = NOW()")
|
||||
params.append(memory_id)
|
||||
|
||||
await conn.execute(
|
||||
f"UPDATE memories SET {', '.join(updates)} WHERE id = ${idx}",
|
||||
*params,
|
||||
)
|
||||
|
||||
return {"updated": memory_id}
|
||||
|
||||
|
||||
# --- MCP SSE Transport ---
|
||||
|
||||
|
||||
|
|
@ -608,6 +894,118 @@ async def secret_get(key: str) -> str:
|
|||
return json.dumps({"error": "secret_get is not available via SSE transport"})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_share(id: int, shared_with: str, permission: str = "read") -> str:
|
||||
"""Share a memory with another user. Permission: 'read' or 'write'."""
|
||||
if permission not in ("read", "write"):
|
||||
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id FROM memories WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL",
|
||||
id, user_id,
|
||||
)
|
||||
if not row:
|
||||
return json.dumps({"error": "Memory not found or not owned by you"})
|
||||
await conn.execute(
|
||||
"""INSERT INTO memory_shares (memory_id, owner_id, shared_with, permission)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (memory_id, shared_with) DO UPDATE SET permission = EXCLUDED.permission""",
|
||||
id, user_id, shared_with, permission,
|
||||
)
|
||||
return json.dumps({"shared": id, "with": shared_with, "permission": permission})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_unshare(id: int, shared_with: str) -> str:
|
||||
"""Revoke sharing of a memory from a user."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM memory_shares WHERE memory_id = $1 AND owner_id = $2 AND shared_with = $3",
|
||||
id, user_id, shared_with,
|
||||
)
|
||||
return json.dumps({"unshared": id, "from": shared_with})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_share_tag(tag: str, shared_with: str, permission: str = "read") -> str:
|
||||
"""Share all memories with a given tag with another user. Future memories with this tag are automatically shared."""
|
||||
if permission not in ("read", "write"):
|
||||
return json.dumps({"error": "permission must be 'read' or 'write'"})
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""INSERT INTO tag_shares (owner_id, tag, shared_with, permission)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (owner_id, tag, shared_with) DO UPDATE SET permission = EXCLUDED.permission""",
|
||||
user_id, tag, shared_with, permission,
|
||||
)
|
||||
return json.dumps({"shared_tag": tag, "with": shared_with, "permission": permission})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_unshare_tag(tag: str, shared_with: str) -> str:
|
||||
"""Revoke tag-based sharing."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"DELETE FROM tag_shares WHERE owner_id = $1 AND tag = $2 AND shared_with = $3",
|
||||
user_id, tag, shared_with,
|
||||
)
|
||||
return json.dumps({"unshared_tag": tag, "from": shared_with})
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def memory_update(id: int, content: str | None = None, tags: str | None = None,
|
||||
importance: float | None = None, expanded_keywords: str | None = None) -> str:
|
||||
"""Update an existing memory's content, tags, importance, or keywords."""
|
||||
pool = await get_pool()
|
||||
user_id = "default"
|
||||
async with pool.acquire() as conn:
|
||||
allowed, owner_id = await check_memory_permission(conn, id, user_id, "write")
|
||||
if not allowed:
|
||||
if owner_id is None:
|
||||
return json.dumps({"error": "Memory not found"})
|
||||
return json.dumps({"error": "Write permission required"})
|
||||
|
||||
updates = []
|
||||
params: list[Any] = []
|
||||
idx = 1
|
||||
if content is not None:
|
||||
updates.append(f"content = ${idx}")
|
||||
params.append(content)
|
||||
idx += 1
|
||||
if tags is not None:
|
||||
updates.append(f"tags = ${idx}")
|
||||
params.append(tags)
|
||||
idx += 1
|
||||
if importance is not None:
|
||||
updates.append(f"importance = ${idx}")
|
||||
params.append(importance)
|
||||
idx += 1
|
||||
if expanded_keywords is not None:
|
||||
updates.append(f"expanded_keywords = ${idx}")
|
||||
params.append(expanded_keywords)
|
||||
idx += 1
|
||||
|
||||
if not updates:
|
||||
return json.dumps({"error": "No fields to update"})
|
||||
|
||||
updates.append("updated_at = NOW()")
|
||||
params.append(id)
|
||||
await conn.execute(
|
||||
f"UPDATE memories SET {', '.join(updates)} WHERE id = ${idx}",
|
||||
*params,
|
||||
)
|
||||
|
||||
return json.dumps({"updated": id})
|
||||
|
||||
|
||||
# Auth middleware for /mcp/* routes
|
||||
class MCPAuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: Any) -> Response:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -38,3 +38,26 @@ class SecretResponse(BaseModel):
|
|||
class SyncResponse(BaseModel):
|
||||
memories: list[dict[str, Any]]
|
||||
server_time: str
|
||||
|
||||
|
||||
class ShareMemory(BaseModel):
|
||||
shared_with: str = Field(..., min_length=1, max_length=100)
|
||||
permission: Literal["read", "write"] = "read"
|
||||
|
||||
|
||||
class ShareTag(BaseModel):
|
||||
tag: str = Field(..., min_length=1, max_length=100)
|
||||
shared_with: str = Field(..., min_length=1, max_length=100)
|
||||
permission: Literal["read", "write"] = "read"
|
||||
|
||||
|
||||
class UnshareTag(BaseModel):
|
||||
tag: str = Field(..., min_length=1, max_length=100)
|
||||
shared_with: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class MemoryUpdate(BaseModel):
|
||||
content: Optional[str] = Field(None, max_length=MAX_MEMORY_CHARS)
|
||||
tags: Optional[str] = None
|
||||
importance: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
expanded_keywords: Optional[str] = None
|
||||
|
|
|
|||
60
src/claude_memory/api/permissions.py
Normal file
60
src/claude_memory/api/permissions.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Permission checking for shared memories."""
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
async def check_memory_permission(
|
||||
conn: asyncpg.Connection, memory_id: int, user_id: str, required: str
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Check if user_id has the required permission on memory_id.
|
||||
|
||||
Returns (allowed, owner_id).
|
||||
- Owner always has full access.
|
||||
- Shared users checked via memory_shares and tag_shares.
|
||||
- required: "read" or "write". "read" is satisfied by either permission.
|
||||
"""
|
||||
row = await conn.fetchrow(
|
||||
"SELECT user_id FROM memories WHERE id = $1 AND deleted_at IS NULL",
|
||||
memory_id,
|
||||
)
|
||||
if not row:
|
||||
return False, None
|
||||
|
||||
owner_id = row["user_id"]
|
||||
|
||||
# Owner always has access
|
||||
if owner_id == user_id:
|
||||
return True, owner_id
|
||||
|
||||
# Check individual memory share
|
||||
share = await conn.fetchrow(
|
||||
"SELECT permission FROM memory_shares WHERE memory_id = $1 AND shared_with = $2",
|
||||
memory_id, user_id,
|
||||
)
|
||||
if share:
|
||||
if required == "read" or share["permission"] == "write":
|
||||
return True, owner_id
|
||||
return False, owner_id
|
||||
|
||||
# Check tag-based shares
|
||||
tag_share = await conn.fetchrow(
|
||||
"""
|
||||
SELECT ts.permission
|
||||
FROM tag_shares ts
|
||||
JOIN memories m ON m.user_id = ts.owner_id
|
||||
WHERE m.id = $1 AND ts.shared_with = $2
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM unnest(string_to_array(m.tags, ',')) t
|
||||
WHERE trim(t) = ts.tag
|
||||
)
|
||||
ORDER BY CASE WHEN ts.permission = 'write' THEN 0 ELSE 1 END
|
||||
LIMIT 1
|
||||
""",
|
||||
memory_id, user_id,
|
||||
)
|
||||
if tag_share:
|
||||
if required == "read" or tag_share["permission"] == "write":
|
||||
return True, owner_id
|
||||
return False, owner_id
|
||||
|
||||
return False, owner_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue