feat: standalone claude-memory-mcp with multi-user support and Vault integration
Extracted from private infra repo into standalone open-source project. Three operating modes: - Local: SQLite + FTS5 (zero dependencies) - Server: PostgreSQL via HTTP API with multi-user auth - Full: PostgreSQL + HashiCorp Vault for secret management Features: - MCP stdio server with 5 tools (store/recall/list/delete/secret_get) - FastAPI HTTP API with multi-user Bearer token auth (API_KEYS JSON map) - Regex-based credential detection with auto-redaction - AES-256-GCM encryption fallback for non-Vault deployments - Vault KV v2 client (stdlib urllib, K8s SA auto-auth) - Per-user data isolation (all queries scoped by user_id) - Secret migration endpoint for existing plain-text credentials - Backward-compatible env var aliases (CLAUDE_MEMORY_API_URL) Infrastructure: - Docker + docker-compose (API + PostgreSQL + optional Vault) - Woodpecker CI (test → build → push → kubectl deploy) - GitHub Actions CI (Python 3.11/3.12/3.13) + Release (GHCR + PyPI) - Helm chart + raw Kubernetes manifests 96 tests passing across 6 test files.
This commit is contained in:
commit
0ed5e1e016
40 changed files with 3381 additions and 0 deletions
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
304
tests/test_api.py
Normal file
304
tests/test_api.py
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
"""Tests for the Claude Memory API endpoints."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from claude_memory.api.auth import AuthUser
|
||||
|
||||
|
||||
# Helpers to build mock asyncpg rows (they behave like dicts with attribute access)
|
||||
class MockRow(dict):
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
|
||||
def _make_memory_row(**overrides):
|
||||
now = datetime.now(timezone.utc)
|
||||
defaults = {
|
||||
"id": 1,
|
||||
"user_id": "testuser",
|
||||
"content": "test content",
|
||||
"category": "facts",
|
||||
"tags": "",
|
||||
"expanded_keywords": "",
|
||||
"importance": 0.5,
|
||||
"is_sensitive": False,
|
||||
"vault_path": None,
|
||||
"encrypted_content": None,
|
||||
"rank": 0.5,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MockRow(defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pool():
|
||||
"""Create a mock asyncpg pool with connection context manager."""
|
||||
pool = MagicMock()
|
||||
conn = AsyncMock()
|
||||
|
||||
# pool.acquire() returns an async context manager yielding conn
|
||||
acm = MagicMock()
|
||||
acm.__aenter__ = AsyncMock(return_value=conn)
|
||||
acm.__aexit__ = AsyncMock(return_value=False)
|
||||
pool.acquire.return_value = acm
|
||||
|
||||
return pool, conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
return AuthUser(user_id="testuser")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_pool, test_user):
|
||||
"""Create an AsyncClient with mocked dependencies."""
|
||||
pool, conn = mock_pool
|
||||
|
||||
# Reload modules with test API key
|
||||
with patch.dict(os.environ, {"API_KEY": "test-key", "API_KEYS": "", "DATABASE_URL": "postgresql://test"}):
|
||||
import claude_memory.api.auth as auth_mod
|
||||
import claude_memory.api.database as db_mod
|
||||
import claude_memory.api.app as app_mod
|
||||
|
||||
importlib.reload(auth_mod)
|
||||
importlib.reload(db_mod)
|
||||
importlib.reload(app_mod)
|
||||
|
||||
# Override database pool
|
||||
db_mod.pool = pool
|
||||
|
||||
# Override auth to return our test user
|
||||
async def mock_get_user(authorization: str = ""):
|
||||
return test_user
|
||||
|
||||
app_mod.app.dependency_overrides[auth_mod.get_current_user] = mock_get_user
|
||||
|
||||
transport = ASGITransport(app=app_mod.app)
|
||||
return AsyncClient(transport=transport, base_url="http://test"), conn, app_mod
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint_no_auth(client):
|
||||
ac, conn, app_mod = client
|
||||
async with ac:
|
||||
resp = await ac.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_memory_creates_record_with_user_id(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(id=42, category="facts", importance=0.7)
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories",
|
||||
json={"content": "Python is great", "category": "facts", "importance": 0.7},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 42
|
||||
assert data["category"] == "facts"
|
||||
assert data["importance"] == 0.7
|
||||
|
||||
# Verify INSERT was called with user_id
|
||||
call_args = conn.fetchrow.call_args
|
||||
assert call_args[0][1] == "testuser" # user_id is the second positional arg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_returns_only_user_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=1, content="user memory", is_sensitive=False),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/recall",
|
||||
json={"context": "test query"},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "user memory"
|
||||
|
||||
# Verify query includes user_id filter
|
||||
call_args = conn.fetch.call_args
|
||||
assert call_args[0][1] == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_redacts_sensitive_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=5, content="[REDACTED]", is_sensitive=True),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/recall",
|
||||
json={"context": "secrets"},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert "[SENSITIVE" in results[0]["content"]
|
||||
assert "secret_get(id=5)" in results[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_returns_only_user_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=1, content="mem1"),
|
||||
_make_memory_row(id=2, content="mem2"),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.get(
|
||||
"/api/memories",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert len(results) == 2
|
||||
|
||||
# Verify user_id filter
|
||||
call_args = conn.fetch.call_args
|
||||
assert call_args[0][1] == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_only_user_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None)
|
||||
conn.execute.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.delete(
|
||||
"/api/memories/10",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": 10}
|
||||
|
||||
# Verify both SELECT and DELETE include user_id
|
||||
fetchrow_args = conn.fetchrow.call_args
|
||||
assert fetchrow_args[0][1] == 10 # memory_id
|
||||
assert fetchrow_args[0][2] == "testuser" # user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_memory_returns_404(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.delete(
|
||||
"/api/memories/999",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_endpoint_returns_plaintext(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(
|
||||
id=7, content="my secret value", is_sensitive=False,
|
||||
vault_path=None, encrypted_content=None,
|
||||
)
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/7/secret",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 7
|
||||
assert data["content"] == "my secret value"
|
||||
assert data["source"] == "plaintext"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_endpoint_returns_vault_content(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(
|
||||
id=8, content="[REDACTED]", is_sensitive=True,
|
||||
vault_path="claude-memory/testuser/mem-8", encrypted_content=None,
|
||||
)
|
||||
|
||||
with patch("claude_memory.api.app.get_secret", return_value="actual-secret-from-vault"):
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/8/secret",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["content"] == "actual-secret-from-vault"
|
||||
assert data["source"] == "vault"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_endpoint_nonexistent_returns_404(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/999/secret",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.side_effect = [
|
||||
_make_memory_row(id=100, category="facts", importance=0.5),
|
||||
_make_memory_row(id=101, category="preferences", importance=0.8),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/import",
|
||||
json=[
|
||||
{"content": "fact one", "category": "facts"},
|
||||
{"content": "pref one", "category": "preferences", "importance": 0.8},
|
||||
],
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["id"] == 100
|
||||
assert data[1]["id"] == 101
|
||||
87
tests/test_auth.py
Normal file
87
tests/test_auth.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Tests for multi-user authentication."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def _reload_auth(env_vars: dict):
|
||||
"""Reload the auth module with given environment variables."""
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
# Clear existing env vars that might interfere
|
||||
for key in ("API_KEY", "API_KEYS"):
|
||||
os.environ.pop(key, None)
|
||||
for key, val in env_vars.items():
|
||||
os.environ[key] = val
|
||||
|
||||
import claude_memory.api.auth as auth_mod
|
||||
|
||||
importlib.reload(auth_mod)
|
||||
return auth_mod
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_api_key_maps_to_default():
|
||||
auth = _reload_auth({"API_KEY": "test-key-123", "API_KEYS": ""})
|
||||
user = await auth.get_current_user(authorization="Bearer test-key-123")
|
||||
assert user.user_id == "default"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_api_keys_maps_to_correct_user():
|
||||
auth = _reload_auth({
|
||||
"API_KEYS": '{"viktor": "key-viktor", "alice": "key-alice"}',
|
||||
"API_KEY": "",
|
||||
})
|
||||
user_v = await auth.get_current_user(authorization="Bearer key-viktor")
|
||||
assert user_v.user_id == "viktor"
|
||||
|
||||
user_a = await auth.get_current_user(authorization="Bearer key-alice")
|
||||
assert user_a.user_id == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_key_returns_401():
|
||||
auth = _reload_auth({"API_KEY": "valid-key", "API_KEYS": ""})
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth.get_current_user(authorization="Bearer wrong-key")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_bearer_prefix_still_works():
|
||||
auth = _reload_auth({"API_KEY": "my-key", "API_KEYS": ""})
|
||||
# Without Bearer prefix, removeprefix("Bearer ") returns "my-key" unchanged
|
||||
# so the raw token still matches the key
|
||||
user = await auth.get_current_user(authorization="my-key")
|
||||
assert user.user_id == "default"
|
||||
|
||||
# With proper Bearer prefix it also works
|
||||
user = await auth.get_current_user(authorization="Bearer my-key")
|
||||
assert user.user_id == "default"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_authorization_header_raises_422():
|
||||
"""FastAPI raises 422 when required Header is missing.
|
||||
This is tested via the app integration, not the function directly,
|
||||
since FastAPI handles the missing header before the function runs.
|
||||
"""
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Need to reload with valid keys so the app can start
|
||||
_reload_auth({"API_KEY": "test-key", "API_KEYS": ""})
|
||||
|
||||
# Import app after auth is configured
|
||||
import claude_memory.api.app as app_mod
|
||||
|
||||
importlib.reload(app_mod)
|
||||
|
||||
transport = ASGITransport(app=app_mod.app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
# Skip lifespan since we don't have a real DB
|
||||
resp = await client.get("/api/memories")
|
||||
assert resp.status_code == 422
|
||||
132
tests/test_credential_detector.py
Normal file
132
tests/test_credential_detector.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Tests for credential detection and redaction."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_memory.credential_detector import (
|
||||
DetectedCredential,
|
||||
detect_credentials,
|
||||
is_sensitive,
|
||||
redact_credentials,
|
||||
)
|
||||
|
||||
|
||||
class TestDetectCredentials:
|
||||
def test_detect_postgres_connection_string(self):
|
||||
text = "db_url = postgres://user:pass@localhost:5432/mydb"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 1
|
||||
assert creds[0].type == "connection_string"
|
||||
assert creds[0].confidence == 0.9
|
||||
assert "postgres://" in creds[0].matched_text
|
||||
|
||||
def test_detect_password_assignment(self):
|
||||
text = 'password = "my_super_secret_pw"'
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "password" in types
|
||||
|
||||
def test_detect_api_key(self):
|
||||
text = "api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "api_key" in types
|
||||
|
||||
def test_detect_private_key(self):
|
||||
text = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3xfn/ygWep4PAtGoSo\n-----END RSA PRIVATE KEY-----"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 1
|
||||
assert creds[0].type == "private_key"
|
||||
assert creds[0].confidence == 0.95
|
||||
|
||||
def test_detect_bearer_token(self):
|
||||
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkw"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "bearer_token" in types
|
||||
|
||||
def test_detect_aws_key(self):
|
||||
text = "aws_access_key_id = AKIAIOSFODNN7EXAMPLE"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "aws_key" in types
|
||||
|
||||
def test_detect_github_token(self):
|
||||
text = "GITHUB_TOKEN=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "github_token" in types
|
||||
|
||||
def test_no_false_positives_on_normal_text(self):
|
||||
text = "This is a normal paragraph about programming. It discusses variables, functions, and classes."
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 0
|
||||
|
||||
def test_no_false_positives_on_short_password(self):
|
||||
# password values shorter than 8 chars should not match
|
||||
text = 'password = "short"'
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 0
|
||||
|
||||
def test_min_confidence_filtering(self):
|
||||
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
|
||||
all_creds = detect_credentials(text, min_confidence=0.5)
|
||||
high_creds = detect_credentials(text, min_confidence=0.9)
|
||||
assert len(all_creds) >= len(high_creds)
|
||||
|
||||
def test_overlapping_matches_keep_highest_confidence(self):
|
||||
# A text that could match both token and generic_secret
|
||||
text = 'secret = "abcdefghijklmnopqrstuvwxyz1234567890"'
|
||||
creds = detect_credentials(text, min_confidence=0.5)
|
||||
# Should not have overlapping ranges for the same span
|
||||
for i, c1 in enumerate(creds):
|
||||
for c2 in creds[i + 1:]:
|
||||
# No credential should be fully contained within another
|
||||
assert not (c1.start <= c2.start and c1.end >= c2.end)
|
||||
|
||||
|
||||
class TestRedactCredentials:
|
||||
def test_redaction_replaces_with_marker(self):
|
||||
text = "db_url = postgres://user:pass@localhost:5432/mydb"
|
||||
creds = detect_credentials(text)
|
||||
redacted = redact_credentials(text, creds)
|
||||
assert "[REDACTED:connection_string]" in redacted
|
||||
assert "postgres://" not in redacted
|
||||
|
||||
def test_redaction_preserves_surrounding_text(self):
|
||||
text = "before postgres://user:pass@localhost/db after"
|
||||
creds = detect_credentials(text)
|
||||
redacted = redact_credentials(text, creds)
|
||||
assert redacted.startswith("before ")
|
||||
assert redacted.endswith(" after")
|
||||
|
||||
def test_redaction_no_credentials(self):
|
||||
text = "nothing sensitive here"
|
||||
redacted = redact_credentials(text, [])
|
||||
assert redacted == text
|
||||
|
||||
def test_redaction_multiple_credentials(self):
|
||||
text = 'password = "mysecretpw123" and api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
|
||||
creds = detect_credentials(text)
|
||||
redacted = redact_credentials(text, creds)
|
||||
assert "mysecretpw123" not in redacted
|
||||
assert "[REDACTED:" in redacted
|
||||
|
||||
|
||||
class TestIsSensitive:
|
||||
def test_sensitive_text(self):
|
||||
assert is_sensitive("password = supersecretvalue123")
|
||||
|
||||
def test_non_sensitive_text(self):
|
||||
assert not is_sensitive("just a normal log message")
|
||||
|
||||
def test_respects_min_confidence(self):
|
||||
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
|
||||
# Low confidence should detect
|
||||
assert is_sensitive(text, min_confidence=0.5)
|
||||
# Very high confidence should not detect generic_secret
|
||||
assert not is_sensitive(text, min_confidence=0.95)
|
||||
134
tests/test_crypto.py
Normal file
134
tests/test_crypto.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Tests for AES-256-GCM encryption module."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_memory.crypto import (
|
||||
ENCRYPTION_KEY_ENV,
|
||||
decrypt,
|
||||
decrypt_b64,
|
||||
encrypt,
|
||||
encrypt_b64,
|
||||
is_encryption_configured,
|
||||
)
|
||||
|
||||
# A valid 32-byte hex key for testing
|
||||
TEST_HEX_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
TEST_PASSPHRASE = "my-test-passphrase"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hex_key_env(monkeypatch):
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_HEX_KEY)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def passphrase_env(monkeypatch):
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_PASSPHRASE)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_key_env(monkeypatch):
|
||||
monkeypatch.delenv(ENCRYPTION_KEY_ENV, raising=False)
|
||||
|
||||
|
||||
class TestEncryptionConfigured:
|
||||
def test_configured_with_hex_key(self, hex_key_env):
|
||||
assert is_encryption_configured() is True
|
||||
|
||||
def test_configured_with_passphrase(self, passphrase_env):
|
||||
assert is_encryption_configured() is True
|
||||
|
||||
def test_not_configured_without_env(self, no_key_env):
|
||||
assert is_encryption_configured() is False
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
def test_roundtrip_with_hex_key(self, hex_key_env):
|
||||
plaintext = "Hello, this is a secret message!"
|
||||
encrypted = encrypt(plaintext)
|
||||
decrypted = decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_roundtrip_with_passphrase(self, passphrase_env):
|
||||
plaintext = "Another secret message with passphrase key"
|
||||
encrypted = encrypt(plaintext)
|
||||
decrypted = decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_different_plaintexts_produce_different_ciphertexts(self, hex_key_env):
|
||||
ct1 = encrypt("message one")
|
||||
ct2 = encrypt("message two")
|
||||
assert ct1 != ct2
|
||||
|
||||
def test_same_plaintext_produces_different_ciphertexts(self, hex_key_env):
|
||||
"""Due to random nonce, encrypting the same text twice gives different results."""
|
||||
ct1 = encrypt("same message")
|
||||
ct2 = encrypt("same message")
|
||||
assert ct1 != ct2
|
||||
|
||||
def test_missing_key_raises_on_encrypt(self, no_key_env):
|
||||
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
|
||||
encrypt("test")
|
||||
|
||||
def test_missing_key_raises_on_decrypt(self, no_key_env):
|
||||
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
|
||||
decrypt(b"\x00" * 28)
|
||||
|
||||
def test_decrypt_with_wrong_key_fails(self, hex_key_env, monkeypatch):
|
||||
plaintext = "secret data"
|
||||
encrypted = encrypt(plaintext)
|
||||
|
||||
# Change to a different key
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "ff" * 32)
|
||||
with pytest.raises(Exception):
|
||||
decrypt(encrypted)
|
||||
|
||||
def test_encrypted_data_format(self, hex_key_env):
|
||||
"""Encrypted data should be at least 12 (nonce) + 16 (tag) bytes."""
|
||||
encrypted = encrypt("x")
|
||||
assert len(encrypted) >= 28 # 12 nonce + 1 plaintext + 16 tag = 29 minimum
|
||||
|
||||
def test_unicode_roundtrip(self, hex_key_env):
|
||||
plaintext = "Unicode test: cafe\u0301, \u00fc\u00f6\u00e4, \U0001f512"
|
||||
decrypted = decrypt(encrypt(plaintext))
|
||||
assert decrypted == plaintext
|
||||
|
||||
|
||||
class TestBase64Variants:
|
||||
def test_b64_roundtrip(self, hex_key_env):
|
||||
plaintext = "base64 test message"
|
||||
encrypted_b64 = encrypt_b64(plaintext)
|
||||
assert isinstance(encrypted_b64, str)
|
||||
decrypted = decrypt_b64(encrypted_b64)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_b64_output_is_valid_base64(self, hex_key_env):
|
||||
import base64
|
||||
encrypted_b64 = encrypt_b64("test")
|
||||
# Should not raise
|
||||
decoded = base64.b64decode(encrypted_b64)
|
||||
assert len(decoded) >= 28
|
||||
|
||||
|
||||
class TestKeyDerivation:
|
||||
def test_hex_key_used_directly(self, hex_key_env):
|
||||
"""A valid 64-char hex string should be used as-is (32 bytes)."""
|
||||
ct = encrypt("test")
|
||||
pt = decrypt(ct)
|
||||
assert pt == "test"
|
||||
|
||||
def test_passphrase_derived_via_sha256(self, passphrase_env):
|
||||
"""Non-hex strings should be derived via SHA-256."""
|
||||
ct = encrypt("test")
|
||||
pt = decrypt(ct)
|
||||
assert pt == "test"
|
||||
|
||||
def test_short_hex_treated_as_passphrase(self, monkeypatch):
|
||||
"""Hex string that's not exactly 32 bytes should be treated as passphrase."""
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "abcd1234")
|
||||
ct = encrypt("test")
|
||||
pt = decrypt(ct)
|
||||
assert pt == "test"
|
||||
342
tests/test_mcp_server.py
Normal file
342
tests/test_mcp_server.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
"""Tests for the Claude Memory MCP server."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Force SQLite fallback mode for all tests
|
||||
os.environ.pop("MEMORY_API_KEY", None)
|
||||
os.environ.pop("CLAUDE_MEMORY_API_KEY", None)
|
||||
|
||||
# Add src to path so we can import without installing
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from claude_memory.mcp_server import MemoryServer, TOOLS, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server(tmp_path):
|
||||
"""Create a MemoryServer with a temporary SQLite database."""
|
||||
db_path = str(tmp_path / "test_memory.db")
|
||||
srv = MemoryServer(sqlite_db_path=db_path)
|
||||
yield srv
|
||||
if srv.sqlite_conn:
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
|
||||
class TestSQLiteInit:
|
||||
def test_creates_database(self, tmp_path):
|
||||
db_path = str(tmp_path / "sub" / "test.db")
|
||||
srv = MemoryServer(sqlite_db_path=db_path)
|
||||
assert os.path.exists(db_path)
|
||||
# Verify tables exist
|
||||
cursor = srv.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'")
|
||||
assert cursor.fetchone() is not None
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
def test_creates_fts_table(self, tmp_path):
|
||||
db_path = str(tmp_path / "test.db")
|
||||
srv = MemoryServer(sqlite_db_path=db_path)
|
||||
cursor = srv.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'")
|
||||
assert cursor.fetchone() is not None
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
|
||||
class TestMemoryStore:
|
||||
def test_store_basic(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "User prefers dark mode",
|
||||
"expanded_keywords": "dark mode theme preference ui",
|
||||
})
|
||||
assert "Stored memory #1" in result
|
||||
assert "facts" in result
|
||||
|
||||
def test_store_with_category(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "User likes Python",
|
||||
"category": "preferences",
|
||||
"expanded_keywords": "python programming language preference",
|
||||
})
|
||||
assert "preferences" in result
|
||||
|
||||
def test_store_with_importance(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "Critical info",
|
||||
"importance": 0.9,
|
||||
"expanded_keywords": "critical important info",
|
||||
})
|
||||
assert "0.9" in result
|
||||
|
||||
def test_store_requires_content(self, server):
|
||||
with pytest.raises(ValueError, match="content is required"):
|
||||
server.memory_store({"expanded_keywords": "test"})
|
||||
|
||||
def test_store_force_sensitive(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "API key: sk-1234",
|
||||
"force_sensitive": True,
|
||||
"expanded_keywords": "api key secret credential",
|
||||
})
|
||||
assert "Stored memory #1" in result
|
||||
# Verify is_sensitive flag is set
|
||||
cursor = server.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1")
|
||||
row = cursor.fetchone()
|
||||
assert row["is_sensitive"] == 1
|
||||
|
||||
|
||||
class TestMemoryRecall:
|
||||
def test_recall_finds_memory(self, server):
|
||||
server.memory_store({
|
||||
"content": "User works at Acme Corp",
|
||||
"expanded_keywords": "acme corp company work employer",
|
||||
})
|
||||
result = server.memory_recall({
|
||||
"context": "work",
|
||||
"expanded_query": "company employer job",
|
||||
})
|
||||
assert "Acme Corp" in result
|
||||
assert "Found 1 memories" in result
|
||||
|
||||
def test_recall_no_results(self, server):
|
||||
result = server.memory_recall({
|
||||
"context": "nonexistent topic",
|
||||
"expanded_query": "nothing here at all",
|
||||
})
|
||||
assert "No memories found" in result
|
||||
|
||||
def test_recall_with_category_filter(self, server):
|
||||
server.memory_store({
|
||||
"content": "User prefers vim",
|
||||
"category": "preferences",
|
||||
"expanded_keywords": "vim editor preference text",
|
||||
})
|
||||
server.memory_store({
|
||||
"content": "Project uses React",
|
||||
"category": "projects",
|
||||
"expanded_keywords": "react project frontend framework",
|
||||
})
|
||||
result = server.memory_recall({
|
||||
"context": "preferences",
|
||||
"expanded_query": "vim editor",
|
||||
"category": "preferences",
|
||||
})
|
||||
assert "vim" in result
|
||||
assert "React" not in result
|
||||
|
||||
def test_recall_requires_context(self, server):
|
||||
with pytest.raises(ValueError, match="context is required"):
|
||||
server.memory_recall({"expanded_query": "test"})
|
||||
|
||||
|
||||
class TestMemoryList:
|
||||
def test_list_empty(self, server):
|
||||
result = server.memory_list({})
|
||||
assert "No memories stored yet" in result
|
||||
|
||||
def test_list_with_memories(self, server):
|
||||
server.memory_store({
|
||||
"content": "Memory one",
|
||||
"expanded_keywords": "one first test",
|
||||
})
|
||||
server.memory_store({
|
||||
"content": "Memory two",
|
||||
"expanded_keywords": "two second test",
|
||||
})
|
||||
result = server.memory_list({})
|
||||
assert "Memory one" in result
|
||||
assert "Memory two" in result
|
||||
assert "2 shown" in result
|
||||
|
||||
def test_list_with_category(self, server):
|
||||
server.memory_store({
|
||||
"content": "A fact",
|
||||
"category": "facts",
|
||||
"expanded_keywords": "fact test",
|
||||
})
|
||||
server.memory_store({
|
||||
"content": "A preference",
|
||||
"category": "preferences",
|
||||
"expanded_keywords": "preference test",
|
||||
})
|
||||
result = server.memory_list({"category": "facts"})
|
||||
assert "A fact" in result
|
||||
assert "A preference" not in result
|
||||
|
||||
def test_list_empty_category(self, server):
|
||||
result = server.memory_list({"category": "projects"})
|
||||
assert "No memories in category 'projects'" in result
|
||||
|
||||
def test_list_respects_limit(self, server):
|
||||
for i in range(5):
|
||||
server.memory_store({
|
||||
"content": f"Memory {i}",
|
||||
"expanded_keywords": f"memory number {i}",
|
||||
})
|
||||
result = server.memory_list({"limit": 2})
|
||||
assert "2 shown" in result
|
||||
|
||||
|
||||
class TestMemoryDelete:
|
||||
def test_delete_existing(self, server):
|
||||
server.memory_store({
|
||||
"content": "To be deleted",
|
||||
"expanded_keywords": "delete remove test",
|
||||
})
|
||||
result = server.memory_delete({"id": 1})
|
||||
assert "Deleted memory #1" in result
|
||||
assert "To be deleted" in result
|
||||
|
||||
def test_delete_nonexistent(self, server):
|
||||
result = server.memory_delete({"id": 999})
|
||||
assert "not found" in result
|
||||
|
||||
def test_delete_requires_id(self, server):
|
||||
with pytest.raises(ValueError, match="id is required"):
|
||||
server.memory_delete({})
|
||||
|
||||
|
||||
class TestSecretGet:
|
||||
def test_secret_get_sensitive(self, server):
|
||||
server.memory_store({
|
||||
"content": "secret password 12345",
|
||||
"force_sensitive": True,
|
||||
"expanded_keywords": "password secret credential",
|
||||
})
|
||||
result = server.secret_get({"id": 1})
|
||||
assert "secret password 12345" in result
|
||||
|
||||
def test_secret_get_not_sensitive(self, server):
|
||||
server.memory_store({
|
||||
"content": "public info",
|
||||
"expanded_keywords": "public info test",
|
||||
})
|
||||
result = server.secret_get({"id": 1})
|
||||
assert "not marked as sensitive" in result
|
||||
|
||||
def test_secret_get_nonexistent(self, server):
|
||||
result = server.secret_get({"id": 999})
|
||||
assert "not found" in result
|
||||
|
||||
def test_secret_get_requires_id(self, server):
|
||||
with pytest.raises(ValueError, match="id is required"):
|
||||
server.secret_get({})
|
||||
|
||||
|
||||
class TestMCPProtocol:
|
||||
def test_handle_initialize(self, server):
|
||||
result = server.handle_initialize({})
|
||||
assert result["protocolVersion"] == PROTOCOL_VERSION
|
||||
assert result["serverInfo"]["name"] == SERVER_NAME
|
||||
assert result["serverInfo"]["version"] == SERVER_VERSION
|
||||
assert "tools" in result["capabilities"]
|
||||
|
||||
def test_handle_tools_list(self, server):
|
||||
result = server.handle_tools_list({})
|
||||
tools = result["tools"]
|
||||
assert len(tools) == 5
|
||||
names = {t["name"] for t in tools}
|
||||
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"}
|
||||
|
||||
def test_handle_tools_call_store(self, server):
|
||||
result = server.handle_tools_call({
|
||||
"name": "memory_store",
|
||||
"arguments": {
|
||||
"content": "test memory",
|
||||
"expanded_keywords": "test memory keywords",
|
||||
},
|
||||
})
|
||||
assert not result.get("isError", False)
|
||||
assert "Stored memory" in result["content"][0]["text"]
|
||||
|
||||
def test_handle_tools_call_unknown(self, server):
|
||||
result = server.handle_tools_call({
|
||||
"name": "nonexistent_tool",
|
||||
"arguments": {},
|
||||
})
|
||||
assert result["isError"] is True
|
||||
assert "Unknown tool" in result["content"][0]["text"]
|
||||
|
||||
def test_handle_tools_call_error(self, server):
|
||||
result = server.handle_tools_call({
|
||||
"name": "memory_store",
|
||||
"arguments": {}, # missing content
|
||||
})
|
||||
assert result["isError"] is True
|
||||
assert "Error" in result["content"][0]["text"]
|
||||
|
||||
|
||||
class TestProcessMessage:
|
||||
def test_initialize(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
assert response["jsonrpc"] == "2.0"
|
||||
assert response["id"] == 1
|
||||
assert "result" in response
|
||||
assert response["result"]["serverInfo"]["name"] == SERVER_NAME
|
||||
|
||||
def test_tools_list(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {},
|
||||
})
|
||||
assert "result" in response
|
||||
assert len(response["result"]["tools"]) == 5
|
||||
|
||||
def test_tools_call(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "memory_store",
|
||||
"arguments": {
|
||||
"content": "via process_message",
|
||||
"expanded_keywords": "process message test",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert "result" in response
|
||||
assert "Stored memory" in response["result"]["content"][0]["text"]
|
||||
|
||||
def test_unknown_method(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "unknown/method",
|
||||
"params": {},
|
||||
})
|
||||
assert "error" in response
|
||||
assert response["error"]["code"] == -32601
|
||||
assert "Method not found" in response["error"]["message"]
|
||||
|
||||
def test_notification_no_id(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
"params": {},
|
||||
})
|
||||
assert response is None
|
||||
|
||||
def test_jsonrpc_response_format(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
# Verify it's valid JSON when serialized
|
||||
serialized = json.dumps(response)
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["jsonrpc"] == "2.0"
|
||||
assert parsed["id"] == 5
|
||||
154
tests/test_vault_client.py
Normal file
154
tests/test_vault_client.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Tests for Vault KV v2 client with mocked urllib."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_memory.vault_client import VaultClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vault_env(monkeypatch):
|
||||
monkeypatch.setenv("VAULT_ADDR", "http://vault.example.com:8200")
|
||||
monkeypatch.setenv("VAULT_TOKEN", "s.testtoken123")
|
||||
|
||||
|
||||
class TestVaultClientInit:
|
||||
def test_missing_addr_raises_value_error(self, monkeypatch):
|
||||
monkeypatch.delenv("VAULT_ADDR", raising=False)
|
||||
monkeypatch.delenv("VAULT_TOKEN", raising=False)
|
||||
with pytest.raises(ValueError, match="Vault address not configured"):
|
||||
VaultClient()
|
||||
|
||||
def test_init_with_explicit_args(self):
|
||||
client = VaultClient(addr="http://localhost:8200", token="mytoken")
|
||||
assert client.addr == "http://localhost:8200"
|
||||
assert client.token == "mytoken"
|
||||
assert client.mount == "secret"
|
||||
|
||||
def test_init_from_env(self, vault_env):
|
||||
client = VaultClient()
|
||||
assert client.addr == "http://vault.example.com:8200"
|
||||
assert client.token == "s.testtoken123"
|
||||
|
||||
def test_addr_trailing_slash_stripped(self):
|
||||
client = VaultClient(addr="http://localhost:8200/", token="t")
|
||||
assert client.addr == "http://localhost:8200"
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", mock_open(read_data="fake-jwt-token"))
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_kubernetes_sa_token_auto_detection(self, mock_urlopen, mock_exists, monkeypatch):
|
||||
monkeypatch.setenv("VAULT_ADDR", "http://vault:8200")
|
||||
monkeypatch.delenv("VAULT_TOKEN", raising=False)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"auth": {"client_token": "s.k8s-token-abc"}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
assert client.token == "s.k8s-token-abc"
|
||||
|
||||
|
||||
class TestVaultRead:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_read_secret_returns_data(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"data": {"data": {"username": "admin", "password": "secret"}}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
result = client.read("myapp/config")
|
||||
assert result == {"username": "admin", "password": "secret"}
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_read_returns_none_for_404(self, mock_urlopen, vault_env):
|
||||
import urllib.error
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError(
|
||||
url="http://vault:8200/v1/secret/data/missing",
|
||||
code=404,
|
||||
msg="Not Found",
|
||||
hdrs={},
|
||||
fp=BytesIO(b""),
|
||||
)
|
||||
|
||||
client = VaultClient()
|
||||
result = client.read("missing/path")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestVaultWrite:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_write_secret_sends_correct_request(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"data": {"created_time": "2024-01-01T00:00:00Z", "version": 1}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
result = client.write("myapp/config", {"key": "value"})
|
||||
|
||||
# Verify the request was made with correct data
|
||||
call_args = mock_urlopen.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.full_url == "http://vault.example.com:8200/v1/secret/data/myapp/config"
|
||||
assert request.method == "POST"
|
||||
body = json.loads(request.data.decode())
|
||||
assert body == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
class TestVaultDelete:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_delete_returns_true_on_success(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"{}"
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
assert client.delete("myapp/config") is True
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_delete_returns_false_on_error(self, mock_urlopen, vault_env):
|
||||
import urllib.error
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError(
|
||||
url="http://vault:8200/v1/secret/data/missing",
|
||||
code=500,
|
||||
msg="Internal Server Error",
|
||||
hdrs={},
|
||||
fp=BytesIO(b"error"),
|
||||
)
|
||||
|
||||
client = VaultClient()
|
||||
assert client.delete("missing/path") is False
|
||||
|
||||
|
||||
class TestVaultListSecrets:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_list_secrets(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"data": {"keys": ["secret1", "secret2/"]}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
result = client.list_secrets("myapp")
|
||||
assert result == ["secret1", "secret2/"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue