feat: standalone claude-memory-mcp with multi-user support and Vault integration

Extracted from private infra repo into standalone open-source project.

Three operating modes:
- Local: SQLite + FTS5 (zero dependencies)
- Server: PostgreSQL via HTTP API with multi-user auth
- Full: PostgreSQL + HashiCorp Vault for secret management

Features:
- MCP stdio server with 5 tools (store/recall/list/delete/secret_get)
- FastAPI HTTP API with multi-user Bearer token auth (API_KEYS JSON map)
- Regex-based credential detection with auto-redaction
- AES-256-GCM encryption fallback for non-Vault deployments
- Vault KV v2 client (stdlib urllib, K8s SA auto-auth)
- Per-user data isolation (all queries scoped by user_id)
- Secret migration endpoint for existing plain-text credentials
- Backward-compatible env var aliases (CLAUDE_MEMORY_API_URL)

Infrastructure:
- Docker + docker-compose (API + PostgreSQL + optional Vault)
- Woodpecker CI (test → build → push → kubectl deploy)
- GitHub Actions CI (Python 3.11/3.12/3.13) + Release (GHCR + PyPI)
- Helm chart + raw Kubernetes manifests

96 tests passing across 6 test files.
This commit is contained in:
Viktor Barzin 2026-03-14 09:42:05 +00:00
commit 0ed5e1e016
No known key found for this signature in database
GPG key ID: 0EB088298288D958
40 changed files with 3381 additions and 0 deletions

0
tests/__init__.py Normal file
View file

304
tests/test_api.py Normal file
View file

@ -0,0 +1,304 @@
"""Tests for the Claude Memory API endpoints."""
import importlib
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from claude_memory.api.auth import AuthUser
# Helpers to build mock asyncpg rows (they behave like dicts with attribute access)
class MockRow(dict):
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(key)
def _make_memory_row(**overrides):
now = datetime.now(timezone.utc)
defaults = {
"id": 1,
"user_id": "testuser",
"content": "test content",
"category": "facts",
"tags": "",
"expanded_keywords": "",
"importance": 0.5,
"is_sensitive": False,
"vault_path": None,
"encrypted_content": None,
"rank": 0.5,
"created_at": now,
"updated_at": now,
}
defaults.update(overrides)
return MockRow(defaults)
@pytest.fixture
def mock_pool():
"""Create a mock asyncpg pool with connection context manager."""
pool = MagicMock()
conn = AsyncMock()
# pool.acquire() returns an async context manager yielding conn
acm = MagicMock()
acm.__aenter__ = AsyncMock(return_value=conn)
acm.__aexit__ = AsyncMock(return_value=False)
pool.acquire.return_value = acm
return pool, conn
@pytest.fixture
def test_user():
return AuthUser(user_id="testuser")
@pytest.fixture
def client(mock_pool, test_user):
"""Create an AsyncClient with mocked dependencies."""
pool, conn = mock_pool
# Reload modules with test API key
with patch.dict(os.environ, {"API_KEY": "test-key", "API_KEYS": "", "DATABASE_URL": "postgresql://test"}):
import claude_memory.api.auth as auth_mod
import claude_memory.api.database as db_mod
import claude_memory.api.app as app_mod
importlib.reload(auth_mod)
importlib.reload(db_mod)
importlib.reload(app_mod)
# Override database pool
db_mod.pool = pool
# Override auth to return our test user
async def mock_get_user(authorization: str = ""):
return test_user
app_mod.app.dependency_overrides[auth_mod.get_current_user] = mock_get_user
transport = ASGITransport(app=app_mod.app)
return AsyncClient(transport=transport, base_url="http://test"), conn, app_mod
@pytest.mark.asyncio
async def test_health_endpoint_no_auth(client):
ac, conn, app_mod = client
async with ac:
resp = await ac.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_store_memory_creates_record_with_user_id(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(id=42, category="facts", importance=0.7)
async with ac:
resp = await ac.post(
"/api/memories",
json={"content": "Python is great", "category": "facts", "importance": 0.7},
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == 42
assert data["category"] == "facts"
assert data["importance"] == 0.7
# Verify INSERT was called with user_id
call_args = conn.fetchrow.call_args
assert call_args[0][1] == "testuser" # user_id is the second positional arg
@pytest.mark.asyncio
async def test_recall_returns_only_user_memories(client):
ac, conn, app_mod = client
conn.fetch.return_value = [
_make_memory_row(id=1, content="user memory", is_sensitive=False),
]
async with ac:
resp = await ac.post(
"/api/memories/recall",
json={"context": "test query"},
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
results = resp.json()
assert len(results) == 1
assert results[0]["content"] == "user memory"
# Verify query includes user_id filter
call_args = conn.fetch.call_args
assert call_args[0][1] == "testuser"
@pytest.mark.asyncio
async def test_recall_redacts_sensitive_memories(client):
ac, conn, app_mod = client
conn.fetch.return_value = [
_make_memory_row(id=5, content="[REDACTED]", is_sensitive=True),
]
async with ac:
resp = await ac.post(
"/api/memories/recall",
json={"context": "secrets"},
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
results = resp.json()
assert "[SENSITIVE" in results[0]["content"]
assert "secret_get(id=5)" in results[0]["content"]
@pytest.mark.asyncio
async def test_list_returns_only_user_memories(client):
ac, conn, app_mod = client
conn.fetch.return_value = [
_make_memory_row(id=1, content="mem1"),
_make_memory_row(id=2, content="mem2"),
]
async with ac:
resp = await ac.get(
"/api/memories",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
results = resp.json()
assert len(results) == 2
# Verify user_id filter
call_args = conn.fetch.call_args
assert call_args[0][1] == "testuser"
@pytest.mark.asyncio
async def test_delete_only_user_memories(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None)
conn.execute.return_value = None
async with ac:
resp = await ac.delete(
"/api/memories/10",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
assert resp.json() == {"deleted": 10}
# Verify both SELECT and DELETE include user_id
fetchrow_args = conn.fetchrow.call_args
assert fetchrow_args[0][1] == 10 # memory_id
assert fetchrow_args[0][2] == "testuser" # user_id
@pytest.mark.asyncio
async def test_delete_nonexistent_memory_returns_404(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = None
async with ac:
resp = await ac.delete(
"/api/memories/999",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_secret_endpoint_returns_plaintext(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(
id=7, content="my secret value", is_sensitive=False,
vault_path=None, encrypted_content=None,
)
async with ac:
resp = await ac.post(
"/api/memories/7/secret",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == 7
assert data["content"] == "my secret value"
assert data["source"] == "plaintext"
@pytest.mark.asyncio
async def test_secret_endpoint_returns_vault_content(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(
id=8, content="[REDACTED]", is_sensitive=True,
vault_path="claude-memory/testuser/mem-8", encrypted_content=None,
)
with patch("claude_memory.api.app.get_secret", return_value="actual-secret-from-vault"):
async with ac:
resp = await ac.post(
"/api/memories/8/secret",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["content"] == "actual-secret-from-vault"
assert data["source"] == "vault"
@pytest.mark.asyncio
async def test_secret_endpoint_nonexistent_returns_404(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = None
async with ac:
resp = await ac.post(
"/api/memories/999/secret",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_import_memories(client):
ac, conn, app_mod = client
conn.fetchrow.side_effect = [
_make_memory_row(id=100, category="facts", importance=0.5),
_make_memory_row(id=101, category="preferences", importance=0.8),
]
async with ac:
resp = await ac.post(
"/api/memories/import",
json=[
{"content": "fact one", "category": "facts"},
{"content": "pref one", "category": "preferences", "importance": 0.8},
],
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert data[0]["id"] == 100
assert data[1]["id"] == 101

87
tests/test_auth.py Normal file
View file

@ -0,0 +1,87 @@
"""Tests for multi-user authentication."""
import importlib
import os
from unittest.mock import patch
import pytest
from fastapi import HTTPException
def _reload_auth(env_vars: dict):
"""Reload the auth module with given environment variables."""
with patch.dict(os.environ, env_vars, clear=False):
# Clear existing env vars that might interfere
for key in ("API_KEY", "API_KEYS"):
os.environ.pop(key, None)
for key, val in env_vars.items():
os.environ[key] = val
import claude_memory.api.auth as auth_mod
importlib.reload(auth_mod)
return auth_mod
@pytest.mark.asyncio
async def test_single_api_key_maps_to_default():
auth = _reload_auth({"API_KEY": "test-key-123", "API_KEYS": ""})
user = await auth.get_current_user(authorization="Bearer test-key-123")
assert user.user_id == "default"
@pytest.mark.asyncio
async def test_multi_api_keys_maps_to_correct_user():
auth = _reload_auth({
"API_KEYS": '{"viktor": "key-viktor", "alice": "key-alice"}',
"API_KEY": "",
})
user_v = await auth.get_current_user(authorization="Bearer key-viktor")
assert user_v.user_id == "viktor"
user_a = await auth.get_current_user(authorization="Bearer key-alice")
assert user_a.user_id == "alice"
@pytest.mark.asyncio
async def test_invalid_key_returns_401():
auth = _reload_auth({"API_KEY": "valid-key", "API_KEYS": ""})
with pytest.raises(HTTPException) as exc_info:
await auth.get_current_user(authorization="Bearer wrong-key")
assert exc_info.value.status_code == 401
@pytest.mark.asyncio
async def test_missing_bearer_prefix_still_works():
auth = _reload_auth({"API_KEY": "my-key", "API_KEYS": ""})
# Without Bearer prefix, removeprefix("Bearer ") returns "my-key" unchanged
# so the raw token still matches the key
user = await auth.get_current_user(authorization="my-key")
assert user.user_id == "default"
# With proper Bearer prefix it also works
user = await auth.get_current_user(authorization="Bearer my-key")
assert user.user_id == "default"
@pytest.mark.asyncio
async def test_missing_authorization_header_raises_422():
"""FastAPI raises 422 when required Header is missing.
This is tested via the app integration, not the function directly,
since FastAPI handles the missing header before the function runs.
"""
from httpx import ASGITransport, AsyncClient
# Need to reload with valid keys so the app can start
_reload_auth({"API_KEY": "test-key", "API_KEYS": ""})
# Import app after auth is configured
import claude_memory.api.app as app_mod
importlib.reload(app_mod)
transport = ASGITransport(app=app_mod.app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
# Skip lifespan since we don't have a real DB
resp = await client.get("/api/memories")
assert resp.status_code == 422

View file

@ -0,0 +1,132 @@
"""Tests for credential detection and redaction."""
import pytest
from claude_memory.credential_detector import (
DetectedCredential,
detect_credentials,
is_sensitive,
redact_credentials,
)
class TestDetectCredentials:
def test_detect_postgres_connection_string(self):
text = "db_url = postgres://user:pass@localhost:5432/mydb"
creds = detect_credentials(text)
assert len(creds) == 1
assert creds[0].type == "connection_string"
assert creds[0].confidence == 0.9
assert "postgres://" in creds[0].matched_text
def test_detect_password_assignment(self):
text = 'password = "my_super_secret_pw"'
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "password" in types
def test_detect_api_key(self):
text = "api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "api_key" in types
def test_detect_private_key(self):
text = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3xfn/ygWep4PAtGoSo\n-----END RSA PRIVATE KEY-----"
creds = detect_credentials(text)
assert len(creds) == 1
assert creds[0].type == "private_key"
assert creds[0].confidence == 0.95
def test_detect_bearer_token(self):
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkw"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "bearer_token" in types
def test_detect_aws_key(self):
text = "aws_access_key_id = AKIAIOSFODNN7EXAMPLE"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "aws_key" in types
def test_detect_github_token(self):
text = "GITHUB_TOKEN=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "github_token" in types
def test_no_false_positives_on_normal_text(self):
text = "This is a normal paragraph about programming. It discusses variables, functions, and classes."
creds = detect_credentials(text)
assert len(creds) == 0
def test_no_false_positives_on_short_password(self):
# password values shorter than 8 chars should not match
text = 'password = "short"'
creds = detect_credentials(text)
assert len(creds) == 0
def test_min_confidence_filtering(self):
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
all_creds = detect_credentials(text, min_confidence=0.5)
high_creds = detect_credentials(text, min_confidence=0.9)
assert len(all_creds) >= len(high_creds)
def test_overlapping_matches_keep_highest_confidence(self):
# A text that could match both token and generic_secret
text = 'secret = "abcdefghijklmnopqrstuvwxyz1234567890"'
creds = detect_credentials(text, min_confidence=0.5)
# Should not have overlapping ranges for the same span
for i, c1 in enumerate(creds):
for c2 in creds[i + 1:]:
# No credential should be fully contained within another
assert not (c1.start <= c2.start and c1.end >= c2.end)
class TestRedactCredentials:
def test_redaction_replaces_with_marker(self):
text = "db_url = postgres://user:pass@localhost:5432/mydb"
creds = detect_credentials(text)
redacted = redact_credentials(text, creds)
assert "[REDACTED:connection_string]" in redacted
assert "postgres://" not in redacted
def test_redaction_preserves_surrounding_text(self):
text = "before postgres://user:pass@localhost/db after"
creds = detect_credentials(text)
redacted = redact_credentials(text, creds)
assert redacted.startswith("before ")
assert redacted.endswith(" after")
def test_redaction_no_credentials(self):
text = "nothing sensitive here"
redacted = redact_credentials(text, [])
assert redacted == text
def test_redaction_multiple_credentials(self):
text = 'password = "mysecretpw123" and api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
creds = detect_credentials(text)
redacted = redact_credentials(text, creds)
assert "mysecretpw123" not in redacted
assert "[REDACTED:" in redacted
class TestIsSensitive:
def test_sensitive_text(self):
assert is_sensitive("password = supersecretvalue123")
def test_non_sensitive_text(self):
assert not is_sensitive("just a normal log message")
def test_respects_min_confidence(self):
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
# Low confidence should detect
assert is_sensitive(text, min_confidence=0.5)
# Very high confidence should not detect generic_secret
assert not is_sensitive(text, min_confidence=0.95)

134
tests/test_crypto.py Normal file
View file

@ -0,0 +1,134 @@
"""Tests for AES-256-GCM encryption module."""
import hashlib
import os
import pytest
from claude_memory.crypto import (
ENCRYPTION_KEY_ENV,
decrypt,
decrypt_b64,
encrypt,
encrypt_b64,
is_encryption_configured,
)
# A valid 32-byte hex key for testing
TEST_HEX_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
TEST_PASSPHRASE = "my-test-passphrase"
@pytest.fixture
def hex_key_env(monkeypatch):
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_HEX_KEY)
@pytest.fixture
def passphrase_env(monkeypatch):
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_PASSPHRASE)
@pytest.fixture
def no_key_env(monkeypatch):
monkeypatch.delenv(ENCRYPTION_KEY_ENV, raising=False)
class TestEncryptionConfigured:
def test_configured_with_hex_key(self, hex_key_env):
assert is_encryption_configured() is True
def test_configured_with_passphrase(self, passphrase_env):
assert is_encryption_configured() is True
def test_not_configured_without_env(self, no_key_env):
assert is_encryption_configured() is False
class TestEncryptDecrypt:
def test_roundtrip_with_hex_key(self, hex_key_env):
plaintext = "Hello, this is a secret message!"
encrypted = encrypt(plaintext)
decrypted = decrypt(encrypted)
assert decrypted == plaintext
def test_roundtrip_with_passphrase(self, passphrase_env):
plaintext = "Another secret message with passphrase key"
encrypted = encrypt(plaintext)
decrypted = decrypt(encrypted)
assert decrypted == plaintext
def test_different_plaintexts_produce_different_ciphertexts(self, hex_key_env):
ct1 = encrypt("message one")
ct2 = encrypt("message two")
assert ct1 != ct2
def test_same_plaintext_produces_different_ciphertexts(self, hex_key_env):
"""Due to random nonce, encrypting the same text twice gives different results."""
ct1 = encrypt("same message")
ct2 = encrypt("same message")
assert ct1 != ct2
def test_missing_key_raises_on_encrypt(self, no_key_env):
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
encrypt("test")
def test_missing_key_raises_on_decrypt(self, no_key_env):
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
decrypt(b"\x00" * 28)
def test_decrypt_with_wrong_key_fails(self, hex_key_env, monkeypatch):
plaintext = "secret data"
encrypted = encrypt(plaintext)
# Change to a different key
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "ff" * 32)
with pytest.raises(Exception):
decrypt(encrypted)
def test_encrypted_data_format(self, hex_key_env):
"""Encrypted data should be at least 12 (nonce) + 16 (tag) bytes."""
encrypted = encrypt("x")
assert len(encrypted) >= 28 # 12 nonce + 1 plaintext + 16 tag = 29 minimum
def test_unicode_roundtrip(self, hex_key_env):
plaintext = "Unicode test: cafe\u0301, \u00fc\u00f6\u00e4, \U0001f512"
decrypted = decrypt(encrypt(plaintext))
assert decrypted == plaintext
class TestBase64Variants:
def test_b64_roundtrip(self, hex_key_env):
plaintext = "base64 test message"
encrypted_b64 = encrypt_b64(plaintext)
assert isinstance(encrypted_b64, str)
decrypted = decrypt_b64(encrypted_b64)
assert decrypted == plaintext
def test_b64_output_is_valid_base64(self, hex_key_env):
import base64
encrypted_b64 = encrypt_b64("test")
# Should not raise
decoded = base64.b64decode(encrypted_b64)
assert len(decoded) >= 28
class TestKeyDerivation:
def test_hex_key_used_directly(self, hex_key_env):
"""A valid 64-char hex string should be used as-is (32 bytes)."""
ct = encrypt("test")
pt = decrypt(ct)
assert pt == "test"
def test_passphrase_derived_via_sha256(self, passphrase_env):
"""Non-hex strings should be derived via SHA-256."""
ct = encrypt("test")
pt = decrypt(ct)
assert pt == "test"
def test_short_hex_treated_as_passphrase(self, monkeypatch):
"""Hex string that's not exactly 32 bytes should be treated as passphrase."""
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "abcd1234")
ct = encrypt("test")
pt = decrypt(ct)
assert pt == "test"

342
tests/test_mcp_server.py Normal file
View file

@ -0,0 +1,342 @@
"""Tests for the Claude Memory MCP server."""
import json
import os
import sys
import pytest
# Force SQLite fallback mode for all tests
os.environ.pop("MEMORY_API_KEY", None)
os.environ.pop("CLAUDE_MEMORY_API_KEY", None)
# Add src to path so we can import without installing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from claude_memory.mcp_server import MemoryServer, TOOLS, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
@pytest.fixture
def server(tmp_path):
"""Create a MemoryServer with a temporary SQLite database."""
db_path = str(tmp_path / "test_memory.db")
srv = MemoryServer(sqlite_db_path=db_path)
yield srv
if srv.sqlite_conn:
srv.sqlite_conn.close()
class TestSQLiteInit:
def test_creates_database(self, tmp_path):
db_path = str(tmp_path / "sub" / "test.db")
srv = MemoryServer(sqlite_db_path=db_path)
assert os.path.exists(db_path)
# Verify tables exist
cursor = srv.sqlite_conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'")
assert cursor.fetchone() is not None
srv.sqlite_conn.close()
def test_creates_fts_table(self, tmp_path):
db_path = str(tmp_path / "test.db")
srv = MemoryServer(sqlite_db_path=db_path)
cursor = srv.sqlite_conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'")
assert cursor.fetchone() is not None
srv.sqlite_conn.close()
class TestMemoryStore:
def test_store_basic(self, server):
result = server.memory_store({
"content": "User prefers dark mode",
"expanded_keywords": "dark mode theme preference ui",
})
assert "Stored memory #1" in result
assert "facts" in result
def test_store_with_category(self, server):
result = server.memory_store({
"content": "User likes Python",
"category": "preferences",
"expanded_keywords": "python programming language preference",
})
assert "preferences" in result
def test_store_with_importance(self, server):
result = server.memory_store({
"content": "Critical info",
"importance": 0.9,
"expanded_keywords": "critical important info",
})
assert "0.9" in result
def test_store_requires_content(self, server):
with pytest.raises(ValueError, match="content is required"):
server.memory_store({"expanded_keywords": "test"})
def test_store_force_sensitive(self, server):
result = server.memory_store({
"content": "API key: sk-1234",
"force_sensitive": True,
"expanded_keywords": "api key secret credential",
})
assert "Stored memory #1" in result
# Verify is_sensitive flag is set
cursor = server.sqlite_conn.cursor()
cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1")
row = cursor.fetchone()
assert row["is_sensitive"] == 1
class TestMemoryRecall:
def test_recall_finds_memory(self, server):
server.memory_store({
"content": "User works at Acme Corp",
"expanded_keywords": "acme corp company work employer",
})
result = server.memory_recall({
"context": "work",
"expanded_query": "company employer job",
})
assert "Acme Corp" in result
assert "Found 1 memories" in result
def test_recall_no_results(self, server):
result = server.memory_recall({
"context": "nonexistent topic",
"expanded_query": "nothing here at all",
})
assert "No memories found" in result
def test_recall_with_category_filter(self, server):
server.memory_store({
"content": "User prefers vim",
"category": "preferences",
"expanded_keywords": "vim editor preference text",
})
server.memory_store({
"content": "Project uses React",
"category": "projects",
"expanded_keywords": "react project frontend framework",
})
result = server.memory_recall({
"context": "preferences",
"expanded_query": "vim editor",
"category": "preferences",
})
assert "vim" in result
assert "React" not in result
def test_recall_requires_context(self, server):
with pytest.raises(ValueError, match="context is required"):
server.memory_recall({"expanded_query": "test"})
class TestMemoryList:
def test_list_empty(self, server):
result = server.memory_list({})
assert "No memories stored yet" in result
def test_list_with_memories(self, server):
server.memory_store({
"content": "Memory one",
"expanded_keywords": "one first test",
})
server.memory_store({
"content": "Memory two",
"expanded_keywords": "two second test",
})
result = server.memory_list({})
assert "Memory one" in result
assert "Memory two" in result
assert "2 shown" in result
def test_list_with_category(self, server):
server.memory_store({
"content": "A fact",
"category": "facts",
"expanded_keywords": "fact test",
})
server.memory_store({
"content": "A preference",
"category": "preferences",
"expanded_keywords": "preference test",
})
result = server.memory_list({"category": "facts"})
assert "A fact" in result
assert "A preference" not in result
def test_list_empty_category(self, server):
result = server.memory_list({"category": "projects"})
assert "No memories in category 'projects'" in result
def test_list_respects_limit(self, server):
for i in range(5):
server.memory_store({
"content": f"Memory {i}",
"expanded_keywords": f"memory number {i}",
})
result = server.memory_list({"limit": 2})
assert "2 shown" in result
class TestMemoryDelete:
def test_delete_existing(self, server):
server.memory_store({
"content": "To be deleted",
"expanded_keywords": "delete remove test",
})
result = server.memory_delete({"id": 1})
assert "Deleted memory #1" in result
assert "To be deleted" in result
def test_delete_nonexistent(self, server):
result = server.memory_delete({"id": 999})
assert "not found" in result
def test_delete_requires_id(self, server):
with pytest.raises(ValueError, match="id is required"):
server.memory_delete({})
class TestSecretGet:
def test_secret_get_sensitive(self, server):
server.memory_store({
"content": "secret password 12345",
"force_sensitive": True,
"expanded_keywords": "password secret credential",
})
result = server.secret_get({"id": 1})
assert "secret password 12345" in result
def test_secret_get_not_sensitive(self, server):
server.memory_store({
"content": "public info",
"expanded_keywords": "public info test",
})
result = server.secret_get({"id": 1})
assert "not marked as sensitive" in result
def test_secret_get_nonexistent(self, server):
result = server.secret_get({"id": 999})
assert "not found" in result
def test_secret_get_requires_id(self, server):
with pytest.raises(ValueError, match="id is required"):
server.secret_get({})
class TestMCPProtocol:
def test_handle_initialize(self, server):
result = server.handle_initialize({})
assert result["protocolVersion"] == PROTOCOL_VERSION
assert result["serverInfo"]["name"] == SERVER_NAME
assert result["serverInfo"]["version"] == SERVER_VERSION
assert "tools" in result["capabilities"]
def test_handle_tools_list(self, server):
result = server.handle_tools_list({})
tools = result["tools"]
assert len(tools) == 5
names = {t["name"] for t in tools}
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"}
def test_handle_tools_call_store(self, server):
result = server.handle_tools_call({
"name": "memory_store",
"arguments": {
"content": "test memory",
"expanded_keywords": "test memory keywords",
},
})
assert not result.get("isError", False)
assert "Stored memory" in result["content"][0]["text"]
def test_handle_tools_call_unknown(self, server):
result = server.handle_tools_call({
"name": "nonexistent_tool",
"arguments": {},
})
assert result["isError"] is True
assert "Unknown tool" in result["content"][0]["text"]
def test_handle_tools_call_error(self, server):
result = server.handle_tools_call({
"name": "memory_store",
"arguments": {}, # missing content
})
assert result["isError"] is True
assert "Error" in result["content"][0]["text"]
class TestProcessMessage:
def test_initialize(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {},
})
assert response["jsonrpc"] == "2.0"
assert response["id"] == 1
assert "result" in response
assert response["result"]["serverInfo"]["name"] == SERVER_NAME
def test_tools_list(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
"params": {},
})
assert "result" in response
assert len(response["result"]["tools"]) == 5
def test_tools_call(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {
"name": "memory_store",
"arguments": {
"content": "via process_message",
"expanded_keywords": "process message test",
},
},
})
assert "result" in response
assert "Stored memory" in response["result"]["content"][0]["text"]
def test_unknown_method(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 4,
"method": "unknown/method",
"params": {},
})
assert "error" in response
assert response["error"]["code"] == -32601
assert "Method not found" in response["error"]["message"]
def test_notification_no_id(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
})
assert response is None
def test_jsonrpc_response_format(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 5,
"method": "initialize",
"params": {},
})
# Verify it's valid JSON when serialized
serialized = json.dumps(response)
parsed = json.loads(serialized)
assert parsed["jsonrpc"] == "2.0"
assert parsed["id"] == 5

154
tests/test_vault_client.py Normal file
View file

@ -0,0 +1,154 @@
"""Tests for Vault KV v2 client with mocked urllib."""
import json
import os
from io import BytesIO
from unittest.mock import MagicMock, mock_open, patch
import pytest
from claude_memory.vault_client import VaultClient
@pytest.fixture
def vault_env(monkeypatch):
monkeypatch.setenv("VAULT_ADDR", "http://vault.example.com:8200")
monkeypatch.setenv("VAULT_TOKEN", "s.testtoken123")
class TestVaultClientInit:
def test_missing_addr_raises_value_error(self, monkeypatch):
monkeypatch.delenv("VAULT_ADDR", raising=False)
monkeypatch.delenv("VAULT_TOKEN", raising=False)
with pytest.raises(ValueError, match="Vault address not configured"):
VaultClient()
def test_init_with_explicit_args(self):
client = VaultClient(addr="http://localhost:8200", token="mytoken")
assert client.addr == "http://localhost:8200"
assert client.token == "mytoken"
assert client.mount == "secret"
def test_init_from_env(self, vault_env):
client = VaultClient()
assert client.addr == "http://vault.example.com:8200"
assert client.token == "s.testtoken123"
def test_addr_trailing_slash_stripped(self):
client = VaultClient(addr="http://localhost:8200/", token="t")
assert client.addr == "http://localhost:8200"
@patch("os.path.exists", return_value=True)
@patch("builtins.open", mock_open(read_data="fake-jwt-token"))
@patch("urllib.request.urlopen")
def test_kubernetes_sa_token_auto_detection(self, mock_urlopen, mock_exists, monkeypatch):
monkeypatch.setenv("VAULT_ADDR", "http://vault:8200")
monkeypatch.delenv("VAULT_TOKEN", raising=False)
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"auth": {"client_token": "s.k8s-token-abc"}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
assert client.token == "s.k8s-token-abc"
class TestVaultRead:
@patch("urllib.request.urlopen")
def test_read_secret_returns_data(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"data": {"data": {"username": "admin", "password": "secret"}}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
result = client.read("myapp/config")
assert result == {"username": "admin", "password": "secret"}
@patch("urllib.request.urlopen")
def test_read_returns_none_for_404(self, mock_urlopen, vault_env):
import urllib.error
mock_urlopen.side_effect = urllib.error.HTTPError(
url="http://vault:8200/v1/secret/data/missing",
code=404,
msg="Not Found",
hdrs={},
fp=BytesIO(b""),
)
client = VaultClient()
result = client.read("missing/path")
assert result is None
class TestVaultWrite:
@patch("urllib.request.urlopen")
def test_write_secret_sends_correct_request(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"data": {"created_time": "2024-01-01T00:00:00Z", "version": 1}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
result = client.write("myapp/config", {"key": "value"})
# Verify the request was made with correct data
call_args = mock_urlopen.call_args
request = call_args[0][0]
assert request.full_url == "http://vault.example.com:8200/v1/secret/data/myapp/config"
assert request.method == "POST"
body = json.loads(request.data.decode())
assert body == {"data": {"key": "value"}}
class TestVaultDelete:
@patch("urllib.request.urlopen")
def test_delete_returns_true_on_success(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = b"{}"
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
assert client.delete("myapp/config") is True
@patch("urllib.request.urlopen")
def test_delete_returns_false_on_error(self, mock_urlopen, vault_env):
import urllib.error
mock_urlopen.side_effect = urllib.error.HTTPError(
url="http://vault:8200/v1/secret/data/missing",
code=500,
msg="Internal Server Error",
hdrs={},
fp=BytesIO(b"error"),
)
client = VaultClient()
assert client.delete("missing/path") is False
class TestVaultListSecrets:
@patch("urllib.request.urlopen")
def test_list_secrets(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"data": {"keys": ["secret1", "secret2/"]}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
result = client.list_secrets("myapp")
assert result == ["secret1", "secret2/"]