feat: standalone claude-memory-mcp with multi-user support and Vault integration

Extracted from private infra repo into standalone open-source project.

Three operating modes:
- Local: SQLite + FTS5 (zero dependencies)
- Server: PostgreSQL via HTTP API with multi-user auth
- Full: PostgreSQL + HashiCorp Vault for secret management

Features:
- MCP stdio server with 5 tools (store/recall/list/delete/secret_get)
- FastAPI HTTP API with multi-user Bearer token auth (API_KEYS JSON map)
- Regex-based credential detection with auto-redaction
- AES-256-GCM encryption fallback for non-Vault deployments
- Vault KV v2 client (stdlib urllib, K8s SA auto-auth)
- Per-user data isolation (all queries scoped by user_id)
- Secret migration endpoint for existing plain-text credentials
- Backward-compatible env var aliases (CLAUDE_MEMORY_API_URL)

Infrastructure:
- Docker + docker-compose (API + PostgreSQL + optional Vault)
- Woodpecker CI (test → build → push → kubectl deploy)
- GitHub Actions CI (Python 3.11/3.12/3.13) + Release (GHCR + PyPI)
- Helm chart + raw Kubernetes manifests

96 tests passing across 6 test files.
Author: Viktor Barzin
Date: 2026-03-14 09:42:05 +00:00
Commit: 0ed5e1e016 (GPG key ID: 0EB088298288D958; no known key found for this signature in database)
40 changed files with 3381 additions and 0 deletions

.github/workflows/ci.yml (new file)

@@ -0,0 +1,36 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: pip install -e ".[api,dev]"
- run: ruff check src/ tests/
- run: mypy src/claude_memory/
- run: pytest tests/ -v --tb=short
docker:
runs-on: ubuntu-latest
needs: test
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- uses: docker/build-push-action@v6
with:
context: .
file: docker/Dockerfile
push: false
tags: claude-memory-mcp:test

.github/workflows/release.yml (new file)

@@ -0,0 +1,40 @@
name: Release
on:
push:
tags: ["v*"]
jobs:
docker:
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- uses: docker/build-push-action@v6
with:
context: .
file: docker/Dockerfile
push: true
tags: |
ghcr.io/${{ github.repository }}:${{ github.ref_name }}
ghcr.io/${{ github.repository }}:latest
pypi:
runs-on: ubuntu-latest
permissions:
id-token: write
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- run: pip install build
- run: python -m build
- uses: pypa/gh-action-pypi-publish@release/v1

.gitignore (new file)

@@ -0,0 +1,45 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg-info/
*.egg
dist/
build/
.eggs/
# Virtual environments
.venv/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Testing
.pytest_cache/
.coverage
htmlcov/
.mypy_cache/
.ruff_cache/
# Environment
.env
.env.local
.env.production
# OS
.DS_Store
Thumbs.db
# Docker
docker/pgdata/
# Database
*.db
*.sqlite3

.woodpecker.yml (new file)

@@ -0,0 +1,43 @@
when:
- event: push
branch: main
clone:
git:
image: woodpeckerci/plugin-git
settings:
attempts: 5
backoff: 10s
steps:
- name: test
image: python:3.12-slim
commands:
- pip install -e ".[api,dev]"
- ruff check src/ tests/
- pytest tests/ -v --tb=short
- name: build-and-push
image: woodpeckerci/plugin-docker-buildx
depends_on:
- test
settings:
username: viktorbarzin
password:
from_secret: dockerhub-token
repo: viktorbarzin/claude-memory-mcp
dockerfile: docker/Dockerfile
context: .
platforms:
- linux/amd64
tags:
- "${CI_PIPELINE_NUMBER}"
- latest
- name: deploy
image: bitnami/kubectl:latest
depends_on:
- build-and-push
commands:
- kubectl rollout restart deployment/claude-memory -n claude-memory
- kubectl rollout status deployment/claude-memory -n claude-memory --timeout=120s

LICENSE (new file)

@@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to the Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by the Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding any notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright 2026 Viktor Barzin
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md (new file)

@@ -0,0 +1,224 @@
# Claude Memory MCP
A persistent memory layer for Claude Code that stores knowledge across sessions. Operates as an MCP (Model Context Protocol) server with optional API and database backends.
## Operating Modes
| Mode | Storage | Auth | Use Case |
|------|---------|------|----------|
| **Local** | SQLite file | None | Single user, local Claude Code |
| **Server** | SQLite file | API key | Remote access, single user |
| **Full** | PostgreSQL | API keys + Vault | Multi-user, team deployment |
## Quick Start
### Local Mode (MCP stdio)
Install and configure Claude Code to use the MCP server directly:
```bash
pip install claude-memory-mcp
```
Add to your Claude Code MCP config (`~/.claude/plugins/`):
```json
{
"mcpServers": {
"memory": {
"command": "python",
"args": ["-m", "claude_memory.mcp_server"],
"env": {
"MEMORY_DB_PATH": "~/.claude/memory.db"
}
}
}
}
```
### Server Mode (API)
Run the API server with SQLite:
```bash
pip install claude-memory-mcp[api]
export DATABASE_URL="sqlite:///./memory.db"
export API_KEY="your-secret-key"
uvicorn claude_memory.api.app:app --host 0.0.0.0 --port 8000
```
Configure Claude Code to connect via HTTP:
```json
{
"mcpServers": {
"memory": {
"command": "python",
"args": ["-m", "claude_memory.mcp_server"],
"env": {
"MEMORY_API_URL": "http://localhost:8000",
"MEMORY_API_KEY": "your-secret-key"
}
}
}
}
```
### Full Mode (Docker Compose)
```bash
cd docker
docker compose up -d
```
This starts the API server with PostgreSQL. See [Docker Compose](#docker-compose) for details.
## Docker Compose
The dev environment includes the API server and PostgreSQL:
```bash
cd docker
docker compose up -d
```
To include HashiCorp Vault for secret management:
```bash
docker compose --profile vault up -d
```
### Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
| `DATABASE_URL` | Database connection string | `sqlite:///./memory.db` |
| `API_KEY` | Single-user API key | None |
| `API_KEYS` | Multi-user JSON map `{"user": "key"}` | None |
| `VAULT_ADDR` | Vault server address | None |
| `VAULT_TOKEN` | Vault authentication token | None |
## Multi-User Setup
For team deployments, use the `API_KEYS` environment variable with a JSON mapping of usernames to API keys:
```bash
export API_KEYS='{"alice": "key-alice-xxx", "bob": "key-bob-yyy"}'
```
Each user gets isolated memory storage. The username is extracted from the API key on each request.
## Vault Integration
For production deployments, store API keys in HashiCorp Vault instead of environment variables:
```bash
export VAULT_ADDR="https://vault.example.com"
export VAULT_TOKEN="s.xxxxxxxxxxxx"
```
The server reads API keys from the Vault KV store at `secret/claude-memory/api-keys`.
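The commit notes that the Vault client uses only the stdlib (`urllib`). A minimal KV v2 read under that assumption might look like this — the helper names are illustrative, not the project's actual `vault_service` API:

```python
import json
import urllib.request


def kv2_url(vault_addr: str, mount: str, path: str) -> str:
    """KV v2 inserts /data/ between the mount point and the secret path."""
    return f"{vault_addr}/v1/{mount}/data/{path}"


def read_kv2_secret(vault_addr: str, token: str, mount: str, path: str) -> dict:
    """Read a secret from Vault's KV v2 engine using only the stdlib.

    KV v2 responses nest the payload under data.data.
    """
    req = urllib.request.Request(
        kv2_url(vault_addr, mount, path),
        headers={"X-Vault-Token": token},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["data"]["data"]


# The path from this README maps to the following URL:
print(kv2_url("https://vault.example.com", "secret", "claude-memory/api-keys"))
# -> https://vault.example.com/v1/secret/data/claude-memory/api-keys
```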
## API Reference
The paths below match the routes implemented in `src/claude_memory/api/app.py`.
### Health Check
```
GET /health
```
### Store Memory
```
POST /api/memories
Authorization: Bearer <api-key>
Content-Type: application/json
{
  "content": "The user prefers dark mode",
  "category": "preferences",
  "tags": ["preferences", "ui"],
  "importance": 3
}
```
### Recall Memories
```
POST /api/memories/recall
Authorization: Bearer <api-key>
Content-Type: application/json
{
  "context": "dark mode",
  "limit": 10
}
```
### List Memories
```
GET /api/memories?category=preferences&limit=20
Authorization: Bearer <api-key>
```
### Retrieve Secret
```
POST /api/memories/{id}/secret
Authorization: Bearer <api-key>
```
### Delete Memory
```
DELETE /api/memories/{id}
Authorization: Bearer <api-key>
```
## Kubernetes Deployment
### Helm
```bash
helm install claude-memory deploy/helm/claude-memory \
--set env.DATABASE_URL="postgresql://user:pass@host:5432/db" \
--set env.API_KEY="your-key" \
--set ingress.host="claude-memory.yourdomain.com"
```
### Raw Manifests
```bash
kubectl apply -f deploy/kubernetes/namespace.yaml
# Create secret with your credentials first:
kubectl create secret generic claude-memory-secrets \
-n claude-memory \
--from-literal=database-url="postgresql://user:pass@host:5432/db" \
--from-literal=api-key="your-key"
kubectl apply -f deploy/kubernetes/
```
## Development
### Setup
```bash
git clone https://github.com/viktorbarzin/claude-memory-mcp.git
cd claude-memory-mcp
python -m venv .venv
source .venv/bin/activate
pip install -e ".[api,dev]"
```
### Running Tests
```bash
pytest tests/ -v
```
### Linting
```bash
ruff check src/ tests/
mypy src/claude_memory/
```
### Building
```bash
pip install build
python -m build
```
## License
Apache License 2.0. See [LICENSE](LICENSE) for details.

@@ -0,0 +1,6 @@
apiVersion: v2
name: claude-memory
description: Claude Memory MCP API server
type: application
version: 1.0.0
appVersion: "1.0.0"

@@ -0,0 +1,35 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ .Release.Name }}
labels:
app: {{ .Release.Name }}
spec:
replicas: {{ .Values.replicaCount }}
selector:
matchLabels:
app: {{ .Release.Name }}
template:
metadata:
labels:
app: {{ .Release.Name }}
spec:
containers:
- name: {{ .Release.Name }}
image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
imagePullPolicy: {{ .Values.image.pullPolicy }}
ports:
- containerPort: {{ .Values.service.targetPort }}
env:
{{- range $key, $value := .Values.env }}
{{- if $value }}
- name: {{ $key }}
  value: {{ $value | quote }}
{{- end }}
{{- end }}
livenessProbe:
{{- toYaml .Values.livenessProbe | nindent 12 }}
readinessProbe:
{{- toYaml .Values.readinessProbe | nindent 12 }}
resources:
{{- toYaml .Values.resources | nindent 12 }}

@@ -0,0 +1,25 @@
{{- if .Values.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ .Release.Name }}
annotations:
nginx.ingress.kubernetes.io/ssl-redirect: "true"
spec:
ingressClassName: {{ .Values.ingress.className }}
tls:
- hosts:
- {{ .Values.ingress.host }}
secretName: {{ .Values.ingress.tls.secretName }}
rules:
- host: {{ .Values.ingress.host }}
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: {{ .Release.Name }}
port:
number: {{ .Values.service.port }}
{{- end }}

@@ -0,0 +1,12 @@
apiVersion: v1
kind: Service
metadata:
name: {{ .Release.Name }}
spec:
type: {{ .Values.service.type }}
ports:
- port: {{ .Values.service.port }}
targetPort: {{ .Values.service.targetPort }}
protocol: TCP
selector:
app: {{ .Release.Name }}

@@ -0,0 +1,46 @@
replicaCount: 1
image:
repository: viktorbarzin/claude-memory-mcp
tag: latest
pullPolicy: Always
service:
type: ClusterIP
port: 80
targetPort: 8000
ingress:
enabled: true
className: nginx
host: claude-memory.example.com
tls:
secretName: tls-secret
resources:
requests:
memory: 32Mi
cpu: 10m
limits:
memory: 128Mi
env:
DATABASE_URL: ""
API_KEY: ""
# API_KEYS: '{}'
# VAULT_ADDR: ""
# VAULT_TOKEN: ""
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 5
periodSeconds: 30
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 3
periodSeconds: 10

@@ -0,0 +1,52 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: claude-memory
namespace: claude-memory
labels:
app: claude-memory
spec:
replicas: 1
selector:
matchLabels:
app: claude-memory
template:
metadata:
labels:
app: claude-memory
spec:
containers:
- name: claude-memory
image: viktorbarzin/claude-memory-mcp:latest
imagePullPolicy: Always
ports:
- containerPort: 8000
env:
- name: DATABASE_URL
valueFrom:
secretKeyRef:
name: claude-memory-secrets
key: database-url
- name: API_KEY
valueFrom:
secretKeyRef:
name: claude-memory-secrets
key: api-key
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 5
periodSeconds: 30
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 3
periodSeconds: 10
resources:
requests:
memory: 32Mi
cpu: 10m
limits:
memory: 128Mi

@@ -0,0 +1,24 @@
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: claude-memory
namespace: claude-memory
annotations:
nginx.ingress.kubernetes.io/ssl-redirect: "true"
spec:
ingressClassName: nginx
tls:
- hosts:
- claude-memory.example.com
secretName: tls-secret
rules:
- host: claude-memory.example.com
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: claude-memory
port:
number: 80

@@ -0,0 +1,4 @@
apiVersion: v1
kind: Namespace
metadata:
name: claude-memory

@@ -0,0 +1,13 @@
apiVersion: v1
kind: Service
metadata:
name: claude-memory
namespace: claude-memory
spec:
type: ClusterIP
ports:
- port: 80
targetPort: 8000
protocol: TCP
selector:
app: claude-memory

docker/Dockerfile (new file)

@@ -0,0 +1,15 @@
FROM python:3.12-slim AS base
WORKDIR /app
COPY pyproject.toml .
COPY src/ src/
RUN pip install --no-cache-dir ".[api]"
RUN useradd -r -u 1000 app
USER app
EXPOSE 8000
CMD ["uvicorn", "claude_memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]

docker/docker-compose.yml (new file)

@@ -0,0 +1,49 @@
services:
api:
build:
context: ..
dockerfile: docker/Dockerfile
ports:
- "8000:8000"
environment:
DATABASE_URL: postgresql://claude_memory:devpassword@postgres:5432/claude_memory
API_KEY: dev-api-key
# Multi-user mode (uncomment to test):
# API_KEYS: '{"viktor": "key1", "testuser": "key2"}'
# Vault (uncomment to test):
# VAULT_ADDR: http://vault:8200
# VAULT_TOKEN: dev-root-token
depends_on:
postgres:
condition: service_healthy
postgres:
image: postgres:16-alpine
environment:
POSTGRES_DB: claude_memory
POSTGRES_USER: claude_memory
POSTGRES_PASSWORD: devpassword
ports:
- "5432:5432"
volumes:
- pgdata:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U claude_memory"]
interval: 5s
timeout: 3s
retries: 5
vault:
image: hashicorp/vault:1.15
ports:
- "8200:8200"
environment:
VAULT_DEV_ROOT_TOKEN_ID: dev-root-token
VAULT_DEV_LISTEN_ADDRESS: 0.0.0.0:8200
cap_add:
- IPC_LOCK
profiles:
- vault
volumes:
pgdata:

@@ -0,0 +1,10 @@
{
"mcpServers": {
"memory": {
"type": "stdio",
"command": "python3",
"args": ["-m", "claude_memory.mcp_server"],
"env": {}
}
}
}

@@ -0,0 +1,13 @@
{
"mcpServers": {
"memory": {
"type": "stdio",
"command": "python3",
"args": ["-m", "claude_memory.mcp_server"],
"env": {
"MEMORY_API_URL": "https://claude-memory.example.com",
"MEMORY_API_KEY": "your-api-key-here"
}
}
}
}

@@ -0,0 +1,15 @@
{
"mcpServers": {
"memory": {
"type": "stdio",
"command": "python3",
"args": ["-m", "claude_memory.mcp_server"],
"env": {
"MEMORY_API_URL": "https://claude-memory.example.com",
"MEMORY_API_KEY": "your-api-key-here",
"VAULT_ADDR": "https://vault.example.com",
"VAULT_TOKEN": "your-vault-token"
}
}
}
}

pyproject.toml (new file)

@@ -0,0 +1,35 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "claude-memory-mcp"
version = "1.0.0"
description = "Standalone MCP memory server with multi-user support and Vault integration"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11"
classifiers = [
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
]
[project.optional-dependencies]
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0"]
vault = ["hvac>=2.0"]
dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28"]
[project.scripts]
claude-memory-server = "claude_memory.mcp_server:main"
[tool.hatch.build.targets.wheel]
packages = ["src/claude_memory"]
[tool.ruff]
target-version = "py311"
line-length = 120
[tool.mypy]
python_version = "3.11"
strict = true

@@ -0,0 +1,3 @@
"""Claude Memory MCP — standalone memory server with multi-user support."""
__version__ = "1.0.0"

@@ -0,0 +1,4 @@
"""Allow running as `python -m claude_memory`."""
from claude_memory.mcp_server import main
main()

@@ -0,0 +1,337 @@
"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search."""
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Optional
from fastapi import Depends, FastAPI, HTTPException
from claude_memory.api.auth import AuthUser, get_current_user
from claude_memory.api.database import close_pool, get_pool, init_pool
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse
from claude_memory.api.vault_service import (
delete_secret,
get_secret,
is_vault_configured,
store_secret,
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
await init_pool()
yield
await close_pool()
app = FastAPI(title="Claude Memory API", lifespan=lifespan)
def _detect_sensitive(content: str) -> bool:
"""Check if content contains credentials using the credential detector."""
try:
from claude_memory.credential_detector import detect_credentials
findings = detect_credentials(content)
return len(findings) > 0
except ImportError:
return False
def _redact_content(content: str) -> str:
"""Redact sensitive content for storage in the main DB."""
try:
from claude_memory.credential_detector import detect_credentials, redact_credentials
creds = detect_credentials(content)
if creds:
return redact_credentials(content, creds)
return content
except ImportError:
return "[REDACTED]"
@app.get("/health")
async def health():
return {"status": "ok"}
@app.post("/api/memories", response_model=MemoryResponse)
async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)):
pool = await get_pool()
is_sensitive = body.force_sensitive or _detect_sensitive(body.content)
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""
INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, category, importance
""",
user.user_id,
body.content if not is_sensitive else _redact_content(body.content),
body.category,
body.tags,
body.expanded_keywords,
body.importance,
is_sensitive,
)
memory_id = row["id"]
if is_sensitive and is_vault_configured():
vault_path = await store_secret(user.user_id, memory_id, body.content)
await conn.execute(
"UPDATE memories SET vault_path = $1 WHERE id = $2",
vault_path,
memory_id,
)
return MemoryResponse(id=row["id"], category=row["category"], importance=row["importance"])
@app.post("/api/memories/recall")
async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_current_user)):
pool = await get_pool()
query_text = f"{body.context} {body.expanded_query}".strip()
order_clause = "ts_rank(search_vector, query) DESC"
if body.sort_by == "importance":
order_clause = "importance DESC, ts_rank(search_vector, query) DESC"
elif body.sort_by == "recency":
order_clause = "created_at DESC"
category_filter = ""
params: list = [user.user_id, query_text, body.limit]
if body.category:
category_filter = "AND category = $4"
params.append(body.category)
async with pool.acquire() as conn:
rows = await conn.fetch(
f"""
SELECT id, content, category, tags, importance, is_sensitive,
ts_rank(search_vector, query) AS rank,
created_at, updated_at
FROM memories, plainto_tsquery('english', $2) query
WHERE user_id = $1
AND (search_vector @@ query OR $2 = '')
{category_filter}
ORDER BY {order_clause}
LIMIT $3
""",
*params,
)
results = []
for row in rows:
content = row["content"]
if row["is_sensitive"]:
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append(
{
"id": row["id"],
"content": content,
"category": row["category"],
"tags": row["tags"],
"importance": row["importance"],
"is_sensitive": row["is_sensitive"],
"rank": float(row["rank"]),
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
}
)
return results
@app.get("/api/memories")
async def list_memories(
category: Optional[str] = None,
limit: int = 50,
user: AuthUser = Depends(get_current_user),
):
pool = await get_pool()
if category:
query = """
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
FROM memories WHERE user_id = $1 AND category = $2
ORDER BY importance DESC LIMIT $3
"""
params: list = [user.user_id, category, limit]
else:
query = """
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
FROM memories WHERE user_id = $1
ORDER BY importance DESC LIMIT $2
"""
params = [user.user_id, limit]
async with pool.acquire() as conn:
rows = await conn.fetch(query, *params)
results = []
for row in rows:
content = row["content"]
if row["is_sensitive"]:
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
results.append(
{
"id": row["id"],
"content": content,
"category": row["category"],
"tags": row["tags"],
"importance": row["importance"],
"is_sensitive": row["is_sensitive"],
"created_at": row["created_at"].isoformat(),
"updated_at": row["updated_at"].isoformat(),
}
)
return results
@app.delete("/api/memories/{memory_id}")
async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_user)):
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT id, vault_path FROM memories WHERE id = $1 AND user_id = $2",
memory_id,
user.user_id,
)
if not row:
raise HTTPException(status_code=404, detail="Memory not found")
if row["vault_path"]:
await delete_secret(user.user_id, row["vault_path"])
await conn.execute(
"DELETE FROM memories WHERE id = $1 AND user_id = $2",
memory_id,
user.user_id,
)
return {"deleted": memory_id}
@app.post("/api/memories/{memory_id}/secret", response_model=SecretResponse)
async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current_user)):
pool = await get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT id, content, is_sensitive, vault_path, encrypted_content
FROM memories WHERE id = $1 AND user_id = $2
""",
memory_id,
user.user_id,
)
if not row:
raise HTTPException(status_code=404, detail="Memory not found")
if not row["is_sensitive"]:
return SecretResponse(id=row["id"], content=row["content"], source="plaintext")
if row["vault_path"]:
secret = await get_secret(user.user_id, row["vault_path"])
if secret:
return SecretResponse(id=row["id"], content=secret, source="vault")
if row["encrypted_content"]:
return SecretResponse(
id=row["id"],
content="[ENCRYPTED - decryption not available]",
source="encrypted",
)
return SecretResponse(id=row["id"], content=row["content"], source="plaintext")
@app.post("/api/memories/migrate-secrets")
async def migrate_secrets(user: AuthUser = Depends(get_current_user)):
pool = await get_pool()
migrated = 0
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, content FROM memories
WHERE user_id = $1 AND is_sensitive = FALSE
""",
user.user_id,
)
for row in rows:
if _detect_sensitive(row["content"]):
original_content = row["content"]
redacted = _redact_content(original_content)
vault_path = None
if is_vault_configured():
vault_path = await store_secret(user.user_id, row["id"], original_content)
await conn.execute(
"""
UPDATE memories
SET is_sensitive = TRUE, content = $1, vault_path = $2,
updated_at = NOW()
WHERE id = $3 AND user_id = $4
""",
redacted,
vault_path,
row["id"],
user.user_id,
)
migrated += 1
return {"migrated": migrated}
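The migration endpoint can be exercised with a plain stdlib request, matching the server's Bearer auth; the URL and API key below are placeholders, not values from this repo:

```python
import json
import urllib.request

# Build a POST to the migration endpoint; URL and key are placeholders.
req = urllib.request.Request(
    "http://localhost:8080/api/memories/migrate-secrets",
    method="POST",
    headers={"Authorization": "Bearer my-api-key"},
)
print(req.get_method())  # POST
# Sending it requires a running server:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))  # e.g. {"migrated": 3}
```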
@app.post("/api/memories/import")
async def import_memories(
memories: list[MemoryStore], user: AuthUser = Depends(get_current_user)
):
pool = await get_pool()
imported = []
async with pool.acquire() as conn:
for mem in memories:
is_sensitive = mem.force_sensitive or _detect_sensitive(mem.content)
content = mem.content if not is_sensitive else _redact_content(mem.content)
row = await conn.fetchrow(
"""
INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, category, importance
""",
user.user_id,
content,
mem.category,
mem.tags,
mem.expanded_keywords,
mem.importance,
is_sensitive,
)
if is_sensitive and is_vault_configured():
vault_path = await store_secret(user.user_id, row["id"], mem.content)
await conn.execute(
"UPDATE memories SET vault_path = $1 WHERE id = $2",
vault_path,
row["id"],
)
imported.append(
MemoryResponse(id=row["id"], category=row["category"], importance=row["importance"])
)
return imported

@@ -0,0 +1,32 @@
import json
import os
from dataclasses import dataclass
from fastapi import Header, HTTPException
@dataclass
class AuthUser:
user_id: str
# Multi-user mode: API_KEYS='{"viktor": "key1", "user2": "key2"}'
# Single-user mode: API_KEY="some-key" (backward compatible, user_id="default")
_api_keys_json = os.environ.get("API_KEYS", "")
_api_key_single = os.environ.get("API_KEY", "")
_key_to_user: dict[str, str] = {}
if _api_keys_json:
_user_to_key = json.loads(_api_keys_json)
_key_to_user = {v: k for k, v in _user_to_key.items()}
elif _api_key_single:
_key_to_user = {_api_key_single: "default"}
async def get_current_user(authorization: str = Header(...)) -> AuthUser:
token = authorization.removeprefix("Bearer ").strip()
user_id = _key_to_user.get(token)
if user_id is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return AuthUser(user_id=user_id)
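For reference, the reverse key-to-user mapping built at import time behaves like this; the key values are invented for illustration:

```python
import json

# Sample API_KEYS value (keys are made up).
api_keys_json = '{"viktor": "key-abc123", "alice": "key-def456"}'
user_to_key = json.loads(api_keys_json)
key_to_user = {v: k for k, v in user_to_key.items()}

# A presented bearer token resolves to its owning user_id.
print(key_to_user.get("key-def456"))  # alice
print(key_to_user.get("wrong-key"))   # None -> 401 in get_current_user
```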

@@ -0,0 +1,55 @@
import os
import asyncpg
DATABASE_URL = os.environ.get("DATABASE_URL", "")
pool: asyncpg.Pool | None = None
async def init_pool() -> asyncpg.Pool:
    global pool
    if not DATABASE_URL:
        raise RuntimeError("DATABASE_URL not set")
    pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
async with pool.acquire() as conn:
await conn.execute("""
CREATE TABLE IF NOT EXISTS memories (
id SERIAL PRIMARY KEY,
user_id VARCHAR(100) NOT NULL DEFAULT 'default',
content TEXT NOT NULL,
category VARCHAR(50) DEFAULT 'facts',
tags TEXT DEFAULT '',
expanded_keywords TEXT DEFAULT '',
importance REAL DEFAULT 0.5,
is_sensitive BOOLEAN DEFAULT FALSE,
vault_path TEXT DEFAULT NULL,
encrypted_content BYTEA DEFAULT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
search_vector tsvector GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(content, '')), 'A') ||
setweight(to_tsvector('english', coalesce(expanded_keywords, '')), 'B') ||
setweight(to_tsvector('english', coalesce(tags, '')), 'C') ||
setweight(to_tsvector('english', coalesce(category, '')), 'D')
) STORED
)
""")
await conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memories_search ON memories USING GIN(search_vector)"
)
await conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id)"
)
return pool
async def close_pool():
global pool
if pool:
await pool.close()
pool = None
async def get_pool() -> asyncpg.Pool:
if pool is None:
raise RuntimeError("Database pool not initialized")
return pool

@@ -0,0 +1,32 @@
from typing import Optional
from pydantic import BaseModel, Field
class MemoryStore(BaseModel):
content: str
category: str = "facts"
tags: str = ""
expanded_keywords: str = ""
importance: float = Field(default=0.5, ge=0.0, le=1.0)
force_sensitive: bool = False
class MemoryRecall(BaseModel):
context: str
expanded_query: str = ""
category: Optional[str] = None
sort_by: str = "importance"
limit: int = 10
class MemoryResponse(BaseModel):
id: int
category: str
importance: float
class SecretResponse(BaseModel):
id: int
content: str
source: str # "vault", "encrypted", "plaintext"

@@ -0,0 +1,51 @@
import os
import logging
logger = logging.getLogger(__name__)
VAULT_ADDR = os.environ.get("VAULT_ADDR", "")
VAULT_TOKEN = os.environ.get("VAULT_TOKEN", "")
VAULT_MOUNT = os.environ.get("VAULT_MOUNT", "secret")
VAULT_PREFIX = os.environ.get("VAULT_PREFIX", "claude-memory")
def is_vault_configured() -> bool:
return bool(VAULT_ADDR and VAULT_TOKEN)
async def store_secret(user_id: str, memory_id: int, content: str) -> str:
"""Store secret content in Vault. Returns the vault path."""
if not is_vault_configured():
raise RuntimeError("Vault not configured")
from claude_memory.vault_client import VaultClient
client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT)
path = f"{VAULT_PREFIX}/{user_id}/mem-{memory_id}"
client.write(path, {"content": content})
return path
async def get_secret(user_id: str, vault_path: str) -> str | None:
"""Retrieve secret content from Vault."""
if not is_vault_configured():
return None
from claude_memory.vault_client import VaultClient
client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT)
data = client.read(vault_path)
if data:
return data.get("content")
return None
async def delete_secret(user_id: str, vault_path: str) -> bool:
"""Delete secret from Vault."""
if not is_vault_configured():
return False
from claude_memory.vault_client import VaultClient
client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT)
return client.delete(vault_path)

@@ -0,0 +1,76 @@
"""Detect credentials and secrets in text content."""
import re
from dataclasses import dataclass
@dataclass
class DetectedCredential:
type: str # e.g. "password", "api_key", "private_key", "connection_string", "token"
confidence: float # 0.0 to 1.0
start: int # position in text
end: int # position in text
matched_text: str # the actual matched text (for redaction)
# Patterns ordered by confidence
_PATTERNS: list[tuple[str, str, float]] = [
# High confidence (0.9+)
("private_key", r"-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----[\s\S]*?-----END (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----", 0.95),
("connection_string", r"(?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^\s'\"]+", 0.9),
    # (?-i:...) keeps these case-sensitive even though the scan uses re.IGNORECASE
    ("aws_key", r"(?-i:(?:AKIA|ASIA)[A-Z0-9]{16})", 0.95),
    ("github_token", r"(?-i:gh[pousr]_[A-Za-z0-9_]{36,})", 0.95),
# Medium confidence (0.7-0.89)
("api_key", r"(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_\-]{20,})['\"]?", 0.8),
("password", r"(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]{8,})['\"]?", 0.8),
("token", r"(?:token|secret|bearer)\s*[:=]\s*['\"]?([A-Za-z0-9_\-\.]{20,})['\"]?", 0.75),
("basic_auth", r"(?:Basic\s+)[A-Za-z0-9+/=]{20,}", 0.85),
("bearer_token", r"Bearer\s+[A-Za-z0-9_\-\.]{20,}", 0.85),
# Lower confidence (0.5-0.69)
("generic_secret", r"(?:secret|credential|auth)\s*[:=]\s*['\"]?([^\s'\"]{12,})['\"]?", 0.6),
("hex_key", r"(?:key|secret)\s*[:=]\s*['\"]?([0-9a-fA-F]{32,})['\"]?", 0.65),
]
def detect_credentials(text: str, min_confidence: float = 0.5) -> list[DetectedCredential]:
"""Scan text for potential credentials and secrets."""
results: list[DetectedCredential] = []
for cred_type, pattern, confidence in _PATTERNS:
if confidence < min_confidence:
continue
for match in re.finditer(pattern, text, re.IGNORECASE):
results.append(DetectedCredential(
type=cred_type,
confidence=confidence,
start=match.start(),
end=match.end(),
matched_text=match.group(0),
))
    # Drop matches fully contained within an already-kept higher-confidence match
results.sort(key=lambda c: (-c.confidence, c.start))
filtered: list[DetectedCredential] = []
for cred in results:
if not any(c.start <= cred.start and c.end >= cred.end for c in filtered):
filtered.append(cred)
return sorted(filtered, key=lambda c: c.start)
def redact_credentials(text: str, credentials: list[DetectedCredential]) -> str:
"""Replace detected credentials with [REDACTED] markers."""
if not credentials:
return text
parts: list[str] = []
last_end = 0
for cred in sorted(credentials, key=lambda c: c.start):
parts.append(text[last_end:cred.start])
parts.append(f"[REDACTED:{cred.type}]")
last_end = cred.end
parts.append(text[last_end:])
return "".join(parts)
def is_sensitive(text: str, min_confidence: float = 0.7) -> bool:
"""Quick check if text likely contains credentials."""
return len(detect_credentials(text, min_confidence)) > 0
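A quick standalone illustration of two of the patterns in the table above (copied verbatim); the sample DSN and password are invented for the example:

```python
import re

# Patterns from the _PATTERNS table above; sample text is made up.
CONN_STR = r"(?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^\s'\"]+"
PASSWORD = r"(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]{8,})['\"]?"

text = "dsn postgres://admin:hunter22@db.internal:5432/app and password = s3cretpass"
conn = re.search(CONN_STR, text, re.IGNORECASE)
pw = re.search(PASSWORD, text, re.IGNORECASE)
print(conn.group(0))  # postgres://admin:hunter22@db.internal:5432/app
print(pw.group(1))    # s3cretpass
```

In the real module, both spans would then be replaced with `[REDACTED:connection_string]` / `[REDACTED:password]` markers by `redact_credentials`.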

@@ -0,0 +1,71 @@
"""AES-256-GCM encryption for memory content when Vault is not available."""
import base64
import hashlib
import os
ENCRYPTION_KEY_ENV = "MEMORY_ENCRYPTION_KEY"
def _get_key() -> bytes | None:
"""Get 32-byte encryption key from environment."""
raw = os.environ.get(ENCRYPTION_KEY_ENV)
if not raw:
return None
# Accept hex-encoded 32-byte key or derive from passphrase
try:
key = bytes.fromhex(raw)
if len(key) == 32:
return key
except ValueError:
pass
# Derive key from passphrase using SHA-256
return hashlib.sha256(raw.encode()).digest()
def is_encryption_configured() -> bool:
return _get_key() is not None
def encrypt(plaintext: str) -> bytes:
"""Encrypt text using AES-256-GCM. Returns nonce + ciphertext + tag."""
key = _get_key()
if key is None:
raise RuntimeError(f"{ENCRYPTION_KEY_ENV} not set")
try:
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
    except ImportError as e:
        raise RuntimeError("cryptography package required for encryption: pip install cryptography") from e
nonce = os.urandom(12)
aesgcm = AESGCM(key)
ciphertext = aesgcm.encrypt(nonce, plaintext.encode(), None)
return nonce + ciphertext # 12 bytes nonce + ciphertext + 16 bytes tag
def decrypt(data: bytes) -> str:
"""Decrypt AES-256-GCM encrypted data."""
key = _get_key()
if key is None:
raise RuntimeError(f"{ENCRYPTION_KEY_ENV} not set")
try:
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
    except ImportError as e:
        raise RuntimeError("cryptography package required for encryption: pip install cryptography") from e
nonce = data[:12]
ciphertext = data[12:]
aesgcm = AESGCM(key)
return aesgcm.decrypt(nonce, ciphertext, None).decode()
def encrypt_b64(plaintext: str) -> str:
"""Encrypt and return base64-encoded string."""
return base64.b64encode(encrypt(plaintext)).decode()
def decrypt_b64(data: str) -> str:
"""Decrypt from base64-encoded string."""
return decrypt(base64.b64decode(data))
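The key-derivation fallback in `_get_key` can be sketched standalone with just the stdlib; the passphrase below is a made-up example, not a real key:

```python
import hashlib

# A value that is not a 64-char hex string falls through to the
# SHA-256 passphrase-derivation path, yielding a 32-byte AES key.
raw = "correct horse battery staple"
try:
    key = bytes.fromhex(raw)
    if len(key) != 32:
        raise ValueError("not a 32-byte hex key")
except ValueError:
    key = hashlib.sha256(raw.encode()).digest()
print(len(key))  # 32

# A 64-char hex string is decoded and used directly as the key.
hex_key = "00" * 32
assert len(bytes.fromhex(hex_key)) == 32
```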

@@ -0,0 +1,550 @@
#!/usr/bin/env python3
"""
Claude Memory MCP Server: a standalone memory server with multi-user support.
Supports two modes:
1. HTTP API mode: connects to a shared PostgreSQL-backed API server
2. SQLite fallback: local file-based storage when no API key is configured
Uses only the stdlib (urllib); no pip install required.
"""
import json
import logging
import os
import sys
import urllib.error
import urllib.request
from typing import Any
logger = logging.getLogger(__name__)
PROTOCOL_VERSION = "2024-11-05"
SERVER_NAME = "claude-memory"
SERVER_VERSION = "1.0.0"
# API configuration — support both MEMORY_* (primary) and CLAUDE_MEMORY_* (fallback) env vars
API_BASE_URL = os.environ.get("MEMORY_API_URL") or os.environ.get("CLAUDE_MEMORY_API_URL", "http://localhost:8080")
API_KEY = os.environ.get("MEMORY_API_KEY") or os.environ.get("CLAUDE_MEMORY_API_KEY", "")
# Fallback to SQLite if API is not configured
SQLITE_FALLBACK = not API_KEY
def _api_request(method: str, path: str, body: dict | None = None) -> dict:
"""Make an HTTP request to the memory API."""
url = f"{API_BASE_URL}{path}"
data = json.dumps(body).encode() if body else None
req = urllib.request.Request(
url,
data=data,
method=method,
headers={
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json",
},
)
try:
with urllib.request.urlopen(req, timeout=15) as resp:
return json.loads(resp.read().decode())
except urllib.error.HTTPError as e:
error_body = e.read().decode() if e.fp else str(e)
raise RuntimeError(f"API error {e.code}: {error_body}") from e
except urllib.error.URLError as e:
raise RuntimeError(f"API connection error: {e.reason}") from e
# ─── SQLite fallback (local storage when API not configured) ─────────────────
def _init_sqlite(db_path: str | None = None):
"""Initialize SQLite database as fallback."""
import sqlite3
from pathlib import Path
if db_path is None:
memory_home = os.path.expandvars(
os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory"))
)
db_path = os.environ.get(
"MEMORY_DB",
os.path.join(memory_home, "memory", "memory.db"),
)
db_path = os.path.expandvars(os.path.expanduser(db_path))
Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(db_path, timeout=30.0)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=30000")
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS memories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
content TEXT NOT NULL,
category TEXT DEFAULT 'facts',
tags TEXT DEFAULT '',
expanded_keywords TEXT DEFAULT '',
importance REAL DEFAULT 0.5,
is_sensitive INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)
""")
cursor.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
content, category, tags, expanded_keywords,
content='memories', content_rowid='id'
)
""")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, content, category, tags, expanded_keywords)
VALUES (new.id, new.content, new.category, new.tags, new.expanded_keywords);
END
""")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, content, category, tags, expanded_keywords)
VALUES ('delete', old.id, old.content, old.category, old.tags, old.expanded_keywords);
END
""")
cursor.execute("""
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, content, category, tags, expanded_keywords)
VALUES ('delete', old.id, old.content, old.category, old.tags, old.expanded_keywords);
INSERT INTO memories_fts(rowid, content, category, tags, expanded_keywords)
VALUES (new.id, new.content, new.category, new.tags, new.expanded_keywords);
END
""")
conn.commit()
return conn
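The external-content FTS5 setup above keeps the index in sync through triggers rather than duplicating content. A minimal sketch of the same pattern (requires an SQLite build with FTS5 compiled in, which CPython's bundled SQLite normally has):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE memories(id INTEGER PRIMARY KEY, content TEXT);
CREATE VIRTUAL TABLE memories_fts USING fts5(
    content, content='memories', content_rowid='id'
);
CREATE TRIGGER memories_ai AFTER INSERT ON memories BEGIN
    INSERT INTO memories_fts(rowid, content) VALUES (new.id, new.content);
END;
""")
conn.execute("INSERT INTO memories(content) VALUES ('prefers vault for secrets')")
conn.commit()
# The trigger indexed the row, so a MATCH query finds it with no manual sync.
rows = conn.execute(
    "SELECT m.id FROM memories m JOIN memories_fts f ON m.id = f.rowid "
    "WHERE memories_fts MATCH 'vault'"
).fetchall()
print(rows)  # [(1,)]
```

The `'delete'` commands in the UPDATE/DELETE triggers are the FTS5 external-content protocol for removing stale index entries before re-inserting.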
# ─── Tool definitions ────────────────────────────────────────────────────────
TOOLS = [
{
"name": "memory_store",
"description": "Store a fact or memory in persistent storage. Use this to remember important information about the user, their preferences, projects, decisions, or people they mention.",
"inputSchema": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "The fact or memory to store"},
"category": {
"type": "string",
"enum": ["facts", "preferences", "projects", "people", "decisions"],
"description": "Category for organizing the memory",
"default": "facts",
},
"tags": {"type": "string", "description": "Comma-separated tags", "default": ""},
"importance": {
"type": "number",
"description": "Importance 0.0-1.0",
"default": 0.5,
"minimum": 0.0,
"maximum": 1.0,
},
"expanded_keywords": {
"type": "string",
"description": "REQUIRED. Space-separated semantically related search terms (MINIMUM 5 words). Generate keywords that someone might search for when this memory would be relevant. Include synonyms, related concepts, and adjacent topics.",
},
"force_sensitive": {
"type": "boolean",
"description": "If true, mark this memory as sensitive regardless of auto-detection. Sensitive memories have their content encrypted at rest.",
"default": False,
},
},
"required": ["content", "expanded_keywords"],
},
},
{
"name": "memory_recall",
"description": "Retrieve relevant memories based on context. Uses full-text search to find stored memories.",
"inputSchema": {
"type": "object",
"properties": {
"context": {"type": "string", "description": "The context or topic to recall memories about"},
"expanded_query": {
"type": "string",
"description": "REQUIRED. Space-separated semantically related search terms (MINIMUM 5 words).",
},
"category": {
"type": "string",
"enum": ["facts", "preferences", "projects", "people", "decisions"],
"description": "Optional: filter results to a specific category",
},
"sort_by": {
"type": "string",
"enum": ["importance", "relevance"],
"description": "Sort order",
"default": "importance",
},
"limit": {"type": "integer", "description": "Max results", "default": 10},
},
"required": ["context", "expanded_query"],
},
},
{
"name": "memory_list",
"description": "List recent memories, optionally filtered by category.",
"inputSchema": {
"type": "object",
"properties": {
"category": {
"type": "string",
"enum": ["facts", "preferences", "projects", "people", "decisions"],
},
"limit": {"type": "integer", "default": 20},
},
},
},
{
"name": "memory_delete",
"description": "Delete a memory by ID.",
"inputSchema": {
"type": "object",
"properties": {
"id": {"type": "integer", "description": "The ID of the memory to delete"},
},
"required": ["id"],
},
},
{
"name": "secret_get",
"description": "Retrieve the decrypted content of a sensitive memory. Only works for memories marked as sensitive.",
"inputSchema": {
"type": "object",
"properties": {
"id": {"type": "integer", "description": "The ID of the sensitive memory to retrieve"},
},
"required": ["id"],
},
},
]
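For reference, this is the shape of the JSON-RPC message a client sends to invoke the first tool over stdio; the argument values are illustrative, not from a real session:

```python
import json

msg = {
    "jsonrpc": "2.0",
    "id": 2,
    "method": "tools/call",
    "params": {
        "name": "memory_store",
        "arguments": {
            "content": "User prefers dark mode",
            "expanded_keywords": "theme ui appearance display color settings",
            "category": "preferences",
        },
    },
}
# The stdio transport is newline-delimited JSON: one message per line.
line = json.dumps(msg)
assert "\n" not in line
print(json.loads(line)["params"]["name"])  # memory_store
```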
class MemoryServer:
"""MCP server for persistent memory management."""
def __init__(self, sqlite_db_path: str | None = None) -> None:
self.sqlite_conn = None
if SQLITE_FALLBACK:
self.sqlite_conn = _init_sqlite(sqlite_db_path)
# ── HTTP-backed methods ──────────────────────────────────────────
def memory_store(self, args: dict[str, Any]) -> str:
content = args.get("content")
if not content:
raise ValueError("content is required")
category = args.get("category", "facts")
tags = args.get("tags", "")
importance = max(0.0, min(1.0, float(args.get("importance", 0.5))))
expanded_keywords = args.get("expanded_keywords", "")
force_sensitive = bool(args.get("force_sensitive", False))
if SQLITE_FALLBACK:
return self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive)
result = _api_request("POST", "/api/memories", {
"content": content,
"category": category,
"tags": tags,
"expanded_keywords": expanded_keywords,
"importance": importance,
"force_sensitive": force_sensitive,
})
return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}"
def memory_recall(self, args: dict[str, Any]) -> str:
context = args.get("context")
if not context:
raise ValueError("context is required")
expanded_query = args.get("expanded_query", "")
category = args.get("category")
sort_by = args.get("sort_by", "importance")
limit = args.get("limit", 10)
if SQLITE_FALLBACK:
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
result = _api_request("POST", "/api/memories/recall", {
"context": context,
"expanded_query": expanded_query,
"category": category,
"sort_by": sort_by,
"limit": limit,
})
rows = result.get("memories", [])
if not rows:
filter_desc = f" in category '{category}'" if category else ""
return f"No memories found matching: {context}{filter_desc}"
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
filter_desc = f" in '{category}'" if category else ""
results = []
for row in rows:
results.append(
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
)
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
def memory_list(self, args: dict[str, Any]) -> str:
category = args.get("category")
limit = args.get("limit", 20)
if SQLITE_FALLBACK:
return self._sqlite_list(category, limit)
params = f"?limit={limit}"
if category:
params += f"&category={category}"
result = _api_request("GET", f"/api/memories{params}")
rows = result.get("memories", [])
if not rows:
return f"No memories in category '{category}'" if category else "No memories stored yet"
results = []
for row in rows:
results.append(
f"#{row['id']} [{row['category']}] {row['content']}"
f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
)
header = "Recent memories"
if category:
header += f" in '{category}'"
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
def memory_delete(self, args: dict[str, Any]) -> str:
memory_id = args.get("id")
if memory_id is None:
raise ValueError("id is required")
if SQLITE_FALLBACK:
return self._sqlite_delete(memory_id)
        result = _api_request("DELETE", f"/api/memories/{memory_id}")
        return f"Deleted memory #{result['deleted']}"
def secret_get(self, args: dict[str, Any]) -> str:
memory_id = args.get("id")
if memory_id is None:
raise ValueError("id is required")
if SQLITE_FALLBACK:
return self._sqlite_secret_get(memory_id)
result = _api_request("POST", f"/api/memories/{memory_id}/secret")
        return f"#{result['id']} ({result['source']}) {result['content']}"
# ── SQLite fallback methods ──────────────────────────────────────
def _sqlite_store(self, content, category, tags, importance, expanded_keywords, force_sensitive=False):
from datetime import datetime, timezone
now = datetime.now(timezone.utc).isoformat()
is_sensitive = 1 if force_sensitive else 0
cursor = self.sqlite_conn.cursor()
cursor.execute(
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(content, category, tags, expanded_keywords, importance, is_sensitive, now, now),
)
self.sqlite_conn.commit()
return f"Stored memory #{cursor.lastrowid} in category '{category}' with importance {importance:.1f}"
def _sqlite_recall(self, context, expanded_query, category, sort_by, limit):
import sqlite3
all_terms = f"{context} {expanded_query}".strip()
words = all_terms.split()
fts_query = " OR ".join(f'"{w.replace(chr(34), "")}"' for w in words if w)
order = (
"bm25(memories_fts), m.importance DESC"
if sort_by == "relevance"
else "m.importance DESC, m.created_at DESC"
)
cursor = self.sqlite_conn.cursor()
try:
if category:
cursor.execute(
f"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
f"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
f"WHERE memories_fts MATCH ? AND m.category = ? ORDER BY {order} LIMIT ?",
(fts_query, category, limit),
)
else:
cursor.execute(
f"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
f"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
f"WHERE memories_fts MATCH ? ORDER BY {order} LIMIT ?",
(fts_query, limit),
)
rows = cursor.fetchall()
except sqlite3.OperationalError:
like = f"%{context}%"
if category:
cursor.execute(
"SELECT id, content, category, tags, importance, created_at FROM memories "
"WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?",
(like, like, category, limit),
)
else:
cursor.execute(
"SELECT id, content, category, tags, importance, created_at FROM memories "
"WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?",
(like, like, limit),
)
rows = cursor.fetchall()
if not rows:
return f"No memories found matching: {context}"
results = []
for row in rows:
results.append(
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
f"\n Tags: {row['tags'] or 'none'} | Stored: {row['created_at']}"
)
return (
f"Found {len(rows)} memories (by {'relevance' if sort_by == 'relevance' else 'importance'}):\n\n"
+ "\n\n".join(results)
)
def _sqlite_list(self, category, limit):
cursor = self.sqlite_conn.cursor()
if category:
cursor.execute(
"SELECT id, content, category, tags, importance, created_at FROM memories "
"WHERE category = ? ORDER BY created_at DESC LIMIT ?",
(category, limit),
)
else:
cursor.execute(
"SELECT id, content, category, tags, importance, created_at FROM memories "
"ORDER BY created_at DESC LIMIT ?",
(limit,),
)
rows = cursor.fetchall()
if not rows:
return f"No memories in category '{category}'" if category else "No memories stored yet"
results = []
for row in rows:
results.append(
f"#{row['id']} [{row['category']}] {row['content']}"
f"\n Importance: {row['importance']:.1f} | Tags: {row['tags'] or 'none'} | Stored: {row['created_at']}"
)
header = "Recent memories" + (f" in '{category}'" if category else "")
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
def _sqlite_delete(self, memory_id):
cursor = self.sqlite_conn.cursor()
cursor.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,))
row = cursor.fetchone()
if not row:
return f"Memory #{memory_id} not found"
preview = row["content"][:50]
cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
self.sqlite_conn.commit()
return f"Deleted memory #{memory_id}: {preview}..."
def _sqlite_secret_get(self, memory_id):
cursor = self.sqlite_conn.cursor()
cursor.execute(
"SELECT id, content, category, is_sensitive FROM memories WHERE id = ?",
(memory_id,),
)
row = cursor.fetchone()
if not row:
return f"Memory #{memory_id} not found"
if not row["is_sensitive"]:
return f"Memory #{memory_id} is not marked as sensitive"
return f"#{row['id']} [{row['category']}] {row['content']}"
# ── MCP protocol ─────────────────────────────────────────────────
def handle_initialize(self, params: dict[str, Any]) -> dict[str, Any]:
return {
"protocolVersion": PROTOCOL_VERSION,
"capabilities": {"tools": {}},
"serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION},
}
def handle_tools_list(self, params: dict[str, Any]) -> dict[str, Any]:
return {"tools": TOOLS}
def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]:
tool_name = params.get("name")
arguments = params.get("arguments", {})
try:
handler = {
"memory_store": self.memory_store,
"memory_recall": self.memory_recall,
"memory_list": self.memory_list,
"memory_delete": self.memory_delete,
"secret_get": self.secret_get,
}.get(tool_name)
if handler is None:
return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True}
result = handler(arguments)
return {"content": [{"type": "text", "text": result}]}
except Exception as e:
return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True}
def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
method = message.get("method")
params = message.get("params", {})
msg_id = message.get("id")
if msg_id is None:
return None
result = None
error = None
try:
if method == "initialize":
result = self.handle_initialize(params)
elif method == "tools/list":
result = self.handle_tools_list(params)
elif method == "tools/call":
result = self.handle_tools_call(params)
else:
error = {"code": -32601, "message": f"Method not found: {method}"}
except Exception as e:
error = {"code": -32603, "message": str(e)}
response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id}
if error:
response["error"] = error
else:
response["result"] = result
return response
def run(self) -> None:
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
message = json.loads(line)
except json.JSONDecodeError as e:
print(
json.dumps({
"jsonrpc": "2.0",
"id": None,
"error": {"code": -32700, "message": f"Parse error: {e}"},
}),
flush=True,
)
continue
response = self.process_message(message)
if response is not None:
print(json.dumps(response), flush=True)
def main() -> None:
server = MemoryServer()
server.run()
if __name__ == "__main__":
main()

@@ -0,0 +1,84 @@
"""HashiCorp Vault KV v2 client using stdlib urllib."""
import json
import logging
import os
import urllib.error
import urllib.request
from typing import Any
logger = logging.getLogger(__name__)
class VaultClient:
"""Simple Vault KV v2 client using stdlib."""
def __init__(
self,
addr: str | None = None,
token: str | None = None,
mount: str = "secret",
):
self.addr = (addr or os.environ.get("VAULT_ADDR", "")).rstrip("/")
self.token = token or os.environ.get("VAULT_TOKEN", "")
self.mount = mount
if not self.addr:
raise ValueError("Vault address not configured (set VAULT_ADDR)")
# Auto-detect Kubernetes SA token
if not self.token:
sa_token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
if os.path.exists(sa_token_path):
self._login_kubernetes(sa_token_path)
def _login_kubernetes(self, sa_token_path: str) -> None:
"""Authenticate with Vault using Kubernetes service account."""
with open(sa_token_path) as f:
jwt = f.read().strip()
role = os.environ.get("VAULT_ROLE", "claude-memory")
resp = self._request("POST", "/v1/auth/kubernetes/login", {"jwt": jwt, "role": role})
self.token = resp.get("auth", {}).get("client_token", "")
def _request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
"""Make HTTP request to Vault."""
url = f"{self.addr}{path}"
data = json.dumps(body).encode() if body else None
req = urllib.request.Request(
url, data=data, method=method,
headers={"X-Vault-Token": self.token, "Content-Type": "application/json"},
)
try:
with urllib.request.urlopen(req, timeout=10) as resp:
return json.loads(resp.read().decode())
except urllib.error.HTTPError as e:
if e.code == 404:
return {}
error_body = e.read().decode() if e.fp else str(e)
raise RuntimeError(f"Vault error {e.code}: {error_body}") from e
def read(self, path: str) -> dict[str, Any] | None:
"""Read a secret from KV v2."""
resp = self._request("GET", f"/v1/{self.mount}/data/{path}")
data = resp.get("data", {})
return data.get("data") if data else None
def write(self, path: str, data: dict[str, Any]) -> dict[str, Any]:
"""Write a secret to KV v2."""
return self._request("POST", f"/v1/{self.mount}/data/{path}", {"data": data})
def delete(self, path: str) -> bool:
"""Delete a secret from KV v2."""
try:
self._request("DELETE", f"/v1/{self.mount}/data/{path}")
return True
except RuntimeError:
return False
def list_secrets(self, path: str) -> list[str]:
"""List secrets at a path."""
try:
resp = self._request("LIST", f"/v1/{self.mount}/metadata/{path}")
return resp.get("data", {}).get("keys", [])
except RuntimeError:
return []
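The path layout the client assumes is the standard KV v2 one: secret payloads live under a `data/` segment and listing/metadata under `metadata/`, both inside the mount. A small sketch using the module's defaults (the user and memory id are example values):

```python
mount = "secret"          # VAULT_MOUNT default
prefix = "claude-memory"  # VAULT_PREFIX default
path = f"{prefix}/viktor/mem-42"  # example user_id and memory id

read_write_url = f"/v1/{mount}/data/{path}"
list_url = f"/v1/{mount}/metadata/{prefix}/viktor"
print(read_write_url)  # /v1/secret/data/claude-memory/viktor/mem-42
print(list_url)        # /v1/secret/metadata/claude-memory/viktor
```

Note that `DELETE` on the `data/` path only soft-deletes the latest version; permanent removal would go through the `metadata/` path.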

0 tests/__init__.py Normal file

304 tests/test_api.py Normal file

@@ -0,0 +1,304 @@
"""Tests for the Claude Memory API endpoints."""
import importlib
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from claude_memory.api.auth import AuthUser
# Helpers to build mock asyncpg rows (they behave like dicts with attribute access)
class MockRow(dict):
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(key)
def _make_memory_row(**overrides):
now = datetime.now(timezone.utc)
defaults = {
"id": 1,
"user_id": "testuser",
"content": "test content",
"category": "facts",
"tags": "",
"expanded_keywords": "",
"importance": 0.5,
"is_sensitive": False,
"vault_path": None,
"encrypted_content": None,
"rank": 0.5,
"created_at": now,
"updated_at": now,
}
defaults.update(overrides)
return MockRow(defaults)
@pytest.fixture
def mock_pool():
"""Create a mock asyncpg pool with connection context manager."""
pool = MagicMock()
conn = AsyncMock()
# pool.acquire() returns an async context manager yielding conn
acm = MagicMock()
acm.__aenter__ = AsyncMock(return_value=conn)
acm.__aexit__ = AsyncMock(return_value=False)
pool.acquire.return_value = acm
return pool, conn
@pytest.fixture
def test_user():
return AuthUser(user_id="testuser")
@pytest.fixture
def client(mock_pool, test_user):
"""Create an AsyncClient with mocked dependencies."""
pool, conn = mock_pool
# Reload modules with test API key
with patch.dict(os.environ, {"API_KEY": "test-key", "API_KEYS": "", "DATABASE_URL": "postgresql://test"}):
import claude_memory.api.auth as auth_mod
import claude_memory.api.database as db_mod
import claude_memory.api.app as app_mod
importlib.reload(auth_mod)
importlib.reload(db_mod)
importlib.reload(app_mod)
# Override database pool
db_mod.pool = pool
# Override auth to return our test user
async def mock_get_user(authorization: str = ""):
return test_user
app_mod.app.dependency_overrides[auth_mod.get_current_user] = mock_get_user
transport = ASGITransport(app=app_mod.app)
return AsyncClient(transport=transport, base_url="http://test"), conn, app_mod
@pytest.mark.asyncio
async def test_health_endpoint_no_auth(client):
ac, conn, app_mod = client
async with ac:
resp = await ac.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_store_memory_creates_record_with_user_id(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(id=42, category="facts", importance=0.7)
async with ac:
resp = await ac.post(
"/api/memories",
json={"content": "Python is great", "category": "facts", "importance": 0.7},
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == 42
assert data["category"] == "facts"
assert data["importance"] == 0.7
# Verify INSERT was called with user_id
call_args = conn.fetchrow.call_args
assert call_args[0][1] == "testuser" # user_id is the second positional arg
@pytest.mark.asyncio
async def test_recall_returns_only_user_memories(client):
ac, conn, app_mod = client
conn.fetch.return_value = [
_make_memory_row(id=1, content="user memory", is_sensitive=False),
]
async with ac:
resp = await ac.post(
"/api/memories/recall",
json={"context": "test query"},
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
results = resp.json()
assert len(results) == 1
assert results[0]["content"] == "user memory"
# Verify query includes user_id filter
call_args = conn.fetch.call_args
assert call_args[0][1] == "testuser"
@pytest.mark.asyncio
async def test_recall_redacts_sensitive_memories(client):
ac, conn, app_mod = client
conn.fetch.return_value = [
_make_memory_row(id=5, content="[REDACTED]", is_sensitive=True),
]
async with ac:
resp = await ac.post(
"/api/memories/recall",
json={"context": "secrets"},
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
results = resp.json()
assert "[SENSITIVE" in results[0]["content"]
assert "secret_get(id=5)" in results[0]["content"]
@pytest.mark.asyncio
async def test_list_returns_only_user_memories(client):
ac, conn, app_mod = client
conn.fetch.return_value = [
_make_memory_row(id=1, content="mem1"),
_make_memory_row(id=2, content="mem2"),
]
async with ac:
resp = await ac.get(
"/api/memories",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
results = resp.json()
assert len(results) == 2
# Verify user_id filter
call_args = conn.fetch.call_args
assert call_args[0][1] == "testuser"
@pytest.mark.asyncio
async def test_delete_only_user_memories(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None)
conn.execute.return_value = None
async with ac:
resp = await ac.delete(
"/api/memories/10",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
assert resp.json() == {"deleted": 10}
# Verify both SELECT and DELETE include user_id
fetchrow_args = conn.fetchrow.call_args
assert fetchrow_args[0][1] == 10 # memory_id
assert fetchrow_args[0][2] == "testuser" # user_id
@pytest.mark.asyncio
async def test_delete_nonexistent_memory_returns_404(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = None
async with ac:
resp = await ac.delete(
"/api/memories/999",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_secret_endpoint_returns_plaintext(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(
id=7, content="my secret value", is_sensitive=False,
vault_path=None, encrypted_content=None,
)
async with ac:
resp = await ac.post(
"/api/memories/7/secret",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == 7
assert data["content"] == "my secret value"
assert data["source"] == "plaintext"
@pytest.mark.asyncio
async def test_secret_endpoint_returns_vault_content(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = _make_memory_row(
id=8, content="[REDACTED]", is_sensitive=True,
vault_path="claude-memory/testuser/mem-8", encrypted_content=None,
)
with patch("claude_memory.api.app.get_secret", return_value="actual-secret-from-vault"):
async with ac:
resp = await ac.post(
"/api/memories/8/secret",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert data["content"] == "actual-secret-from-vault"
assert data["source"] == "vault"
@pytest.mark.asyncio
async def test_secret_endpoint_nonexistent_returns_404(client):
ac, conn, app_mod = client
conn.fetchrow.return_value = None
async with ac:
resp = await ac.post(
"/api/memories/999/secret",
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_import_memories(client):
ac, conn, app_mod = client
conn.fetchrow.side_effect = [
_make_memory_row(id=100, category="facts", importance=0.5),
_make_memory_row(id=101, category="preferences", importance=0.8),
]
async with ac:
resp = await ac.post(
"/api/memories/import",
json=[
{"content": "fact one", "category": "facts"},
{"content": "pref one", "category": "preferences", "importance": 0.8},
],
headers={"Authorization": "Bearer test-key"},
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert data[0]["id"] == 100
assert data[1]["id"] == 101

87
tests/test_auth.py Normal file

@@ -0,0 +1,87 @@
"""Tests for multi-user authentication."""
import importlib
import os
from unittest.mock import patch
import pytest
from fastapi import HTTPException
def _reload_auth(env_vars: dict):
"""Reload the auth module with given environment variables."""
with patch.dict(os.environ, env_vars, clear=False):
# Clear existing env vars that might interfere
for key in ("API_KEY", "API_KEYS"):
os.environ.pop(key, None)
for key, val in env_vars.items():
os.environ[key] = val
import claude_memory.api.auth as auth_mod
importlib.reload(auth_mod)
return auth_mod
@pytest.mark.asyncio
async def test_single_api_key_maps_to_default():
auth = _reload_auth({"API_KEY": "test-key-123", "API_KEYS": ""})
user = await auth.get_current_user(authorization="Bearer test-key-123")
assert user.user_id == "default"
@pytest.mark.asyncio
async def test_multi_api_keys_maps_to_correct_user():
auth = _reload_auth({
"API_KEYS": '{"viktor": "key-viktor", "alice": "key-alice"}',
"API_KEY": "",
})
user_v = await auth.get_current_user(authorization="Bearer key-viktor")
assert user_v.user_id == "viktor"
user_a = await auth.get_current_user(authorization="Bearer key-alice")
assert user_a.user_id == "alice"
@pytest.mark.asyncio
async def test_invalid_key_returns_401():
auth = _reload_auth({"API_KEY": "valid-key", "API_KEYS": ""})
with pytest.raises(HTTPException) as exc_info:
await auth.get_current_user(authorization="Bearer wrong-key")
assert exc_info.value.status_code == 401
@pytest.mark.asyncio
async def test_missing_bearer_prefix_still_works():
auth = _reload_auth({"API_KEY": "my-key", "API_KEYS": ""})
# Without Bearer prefix, removeprefix("Bearer ") returns "my-key" unchanged
# so the raw token still matches the key
user = await auth.get_current_user(authorization="my-key")
assert user.user_id == "default"
# With proper Bearer prefix it also works
user = await auth.get_current_user(authorization="Bearer my-key")
assert user.user_id == "default"
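The test above leans on `str.removeprefix` (Python 3.9+) being a no-op when the prefix is absent, which is why a bare token and a `Bearer `-prefixed one resolve to the same key:

```python
# removeprefix only strips the prefix when present; otherwise the string is unchanged.
token_with_prefix = "Bearer my-key".removeprefix("Bearer ")
token_without_prefix = "my-key".removeprefix("Bearer ")
assert token_with_prefix == token_without_prefix == "my-key"
```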
@pytest.mark.asyncio
async def test_missing_authorization_header_raises_422():
"""FastAPI raises 422 when required Header is missing.
This is tested via the app integration, not the function directly,
since FastAPI handles the missing header before the function runs.
"""
from httpx import ASGITransport, AsyncClient
# Need to reload with valid keys so the app can start
_reload_auth({"API_KEY": "test-key", "API_KEYS": ""})
# Import app after auth is configured
import claude_memory.api.app as app_mod
importlib.reload(app_mod)
transport = ASGITransport(app=app_mod.app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
# Skip lifespan since we don't have a real DB
resp = await client.get("/api/memories")
assert resp.status_code == 422


@@ -0,0 +1,132 @@
"""Tests for credential detection and redaction."""
import pytest
from claude_memory.credential_detector import (
DetectedCredential,
detect_credentials,
is_sensitive,
redact_credentials,
)
class TestDetectCredentials:
def test_detect_postgres_connection_string(self):
text = "db_url = postgres://user:pass@localhost:5432/mydb"
creds = detect_credentials(text)
assert len(creds) == 1
assert creds[0].type == "connection_string"
assert creds[0].confidence == 0.9
assert "postgres://" in creds[0].matched_text
def test_detect_password_assignment(self):
text = 'password = "my_super_secret_pw"'
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "password" in types
def test_detect_api_key(self):
text = "api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "api_key" in types
def test_detect_private_key(self):
text = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3xfn/ygWep4PAtGoSo\n-----END RSA PRIVATE KEY-----"
creds = detect_credentials(text)
assert len(creds) == 1
assert creds[0].type == "private_key"
assert creds[0].confidence == 0.95
def test_detect_bearer_token(self):
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkw"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "bearer_token" in types
def test_detect_aws_key(self):
text = "aws_access_key_id = AKIAIOSFODNN7EXAMPLE"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "aws_key" in types
def test_detect_github_token(self):
text = "GITHUB_TOKEN=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn"
creds = detect_credentials(text)
assert len(creds) >= 1
types = [c.type for c in creds]
assert "github_token" in types
def test_no_false_positives_on_normal_text(self):
text = "This is a normal paragraph about programming. It discusses variables, functions, and classes."
creds = detect_credentials(text)
assert len(creds) == 0
def test_no_false_positives_on_short_password(self):
# password values shorter than 8 chars should not match
text = 'password = "short"'
creds = detect_credentials(text)
assert len(creds) == 0
def test_min_confidence_filtering(self):
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
all_creds = detect_credentials(text, min_confidence=0.5)
high_creds = detect_credentials(text, min_confidence=0.9)
assert len(all_creds) >= len(high_creds)
def test_overlapping_matches_keep_highest_confidence(self):
# A text that could match both token and generic_secret
text = 'secret = "abcdefghijklmnopqrstuvwxyz1234567890"'
creds = detect_credentials(text, min_confidence=0.5)
# Should not have overlapping ranges for the same span
for i, c1 in enumerate(creds):
for c2 in creds[i + 1:]:
# No credential should be fully contained within another
assert not (c1.start <= c2.start and c1.end >= c2.end)
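The containment assertion above presumes the detector resolves overlaps by keeping the higher-confidence match and dropping any match fully contained within it. A sketch of that de-overlap step, under the assumption that matches are ranked by confidence (this helper is illustrative, not the module's actual implementation):

```python
from dataclasses import dataclass

@dataclass
class Match:
    type: str
    start: int
    end: int
    confidence: float

def drop_contained(matches: list[Match]) -> list[Match]:
    """Keep higher-confidence matches; drop any match fully inside one already kept."""
    kept: list[Match] = []
    for m in sorted(matches, key=lambda m: -m.confidence):
        if not any(k.start <= m.start and k.end >= m.end for k in kept):
            kept.append(m)
    return sorted(kept, key=lambda m: m.start)

ms = [Match("token", 0, 30, 0.6), Match("generic_secret", 5, 20, 0.4)]
print([m.type for m in drop_contained(ms)])  # ['token']
```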
class TestRedactCredentials:
def test_redaction_replaces_with_marker(self):
text = "db_url = postgres://user:pass@localhost:5432/mydb"
creds = detect_credentials(text)
redacted = redact_credentials(text, creds)
assert "[REDACTED:connection_string]" in redacted
assert "postgres://" not in redacted
def test_redaction_preserves_surrounding_text(self):
text = "before postgres://user:pass@localhost/db after"
creds = detect_credentials(text)
redacted = redact_credentials(text, creds)
assert redacted.startswith("before ")
assert redacted.endswith(" after")
def test_redaction_no_credentials(self):
text = "nothing sensitive here"
redacted = redact_credentials(text, [])
assert redacted == text
def test_redaction_multiple_credentials(self):
text = 'password = "mysecretpw123" and api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
creds = detect_credentials(text)
redacted = redact_credentials(text, creds)
assert "mysecretpw123" not in redacted
assert "[REDACTED:" in redacted
class TestIsSensitive:
def test_sensitive_text(self):
assert is_sensitive("password = supersecretvalue123")
def test_non_sensitive_text(self):
assert not is_sensitive("just a normal log message")
def test_respects_min_confidence(self):
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
# Low confidence should detect
assert is_sensitive(text, min_confidence=0.5)
# Very high confidence should not detect generic_secret
assert not is_sensitive(text, min_confidence=0.95)

134
tests/test_crypto.py Normal file

@@ -0,0 +1,134 @@
"""Tests for AES-256-GCM encryption module."""
import hashlib
import os
import pytest
from claude_memory.crypto import (
ENCRYPTION_KEY_ENV,
decrypt,
decrypt_b64,
encrypt,
encrypt_b64,
is_encryption_configured,
)
# A valid 32-byte hex key for testing
TEST_HEX_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
TEST_PASSPHRASE = "my-test-passphrase"
@pytest.fixture
def hex_key_env(monkeypatch):
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_HEX_KEY)
@pytest.fixture
def passphrase_env(monkeypatch):
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_PASSPHRASE)
@pytest.fixture
def no_key_env(monkeypatch):
monkeypatch.delenv(ENCRYPTION_KEY_ENV, raising=False)
class TestEncryptionConfigured:
def test_configured_with_hex_key(self, hex_key_env):
assert is_encryption_configured() is True
def test_configured_with_passphrase(self, passphrase_env):
assert is_encryption_configured() is True
def test_not_configured_without_env(self, no_key_env):
assert is_encryption_configured() is False
class TestEncryptDecrypt:
def test_roundtrip_with_hex_key(self, hex_key_env):
plaintext = "Hello, this is a secret message!"
encrypted = encrypt(plaintext)
decrypted = decrypt(encrypted)
assert decrypted == plaintext
def test_roundtrip_with_passphrase(self, passphrase_env):
plaintext = "Another secret message with passphrase key"
encrypted = encrypt(plaintext)
decrypted = decrypt(encrypted)
assert decrypted == plaintext
def test_different_plaintexts_produce_different_ciphertexts(self, hex_key_env):
ct1 = encrypt("message one")
ct2 = encrypt("message two")
assert ct1 != ct2
def test_same_plaintext_produces_different_ciphertexts(self, hex_key_env):
"""Due to random nonce, encrypting the same text twice gives different results."""
ct1 = encrypt("same message")
ct2 = encrypt("same message")
assert ct1 != ct2
def test_missing_key_raises_on_encrypt(self, no_key_env):
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
encrypt("test")
def test_missing_key_raises_on_decrypt(self, no_key_env):
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
decrypt(b"\x00" * 28)
def test_decrypt_with_wrong_key_fails(self, hex_key_env, monkeypatch):
plaintext = "secret data"
encrypted = encrypt(plaintext)
# Change to a different key
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "ff" * 32)
with pytest.raises(Exception):
decrypt(encrypted)
    def test_encrypted_data_format(self, hex_key_env):
        """Encrypted output is nonce (12) + ciphertext (>= 1) + tag (16), so at least 29 bytes."""
        encrypted = encrypt("x")
        assert len(encrypted) >= 29  # 12 nonce + 1 plaintext + 16 tag
def test_unicode_roundtrip(self, hex_key_env):
plaintext = "Unicode test: cafe\u0301, \u00fc\u00f6\u00e4, \U0001f512"
decrypted = decrypt(encrypt(plaintext))
assert decrypted == plaintext
class TestBase64Variants:
def test_b64_roundtrip(self, hex_key_env):
plaintext = "base64 test message"
encrypted_b64 = encrypt_b64(plaintext)
assert isinstance(encrypted_b64, str)
decrypted = decrypt_b64(encrypted_b64)
assert decrypted == plaintext
def test_b64_output_is_valid_base64(self, hex_key_env):
import base64
encrypted_b64 = encrypt_b64("test")
# Should not raise
decoded = base64.b64decode(encrypted_b64)
assert len(decoded) >= 28
class TestKeyDerivation:
def test_hex_key_used_directly(self, hex_key_env):
"""A valid 64-char hex string should be used as-is (32 bytes)."""
ct = encrypt("test")
pt = decrypt(ct)
assert pt == "test"
def test_passphrase_derived_via_sha256(self, passphrase_env):
"""Non-hex strings should be derived via SHA-256."""
ct = encrypt("test")
pt = decrypt(ct)
assert pt == "test"
def test_short_hex_treated_as_passphrase(self, monkeypatch):
"""Hex string that's not exactly 32 bytes should be treated as passphrase."""
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "abcd1234")
ct = encrypt("test")
pt = decrypt(ct)
assert pt == "test"
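The three derivation tests above imply the key rule: a 64-character hex string decodes directly to the 32-byte AES key, and anything else (including short hex like `abcd1234`) is hashed with SHA-256. A stdlib sketch of that rule (the function name is illustrative, not the module's API):

```python
import hashlib

def derive_key(value: str) -> bytes:
    """Return a 32-byte AES-256 key: raw hex if exactly 64 hex chars, else SHA-256 of the passphrase."""
    if len(value) == 64:
        try:
            return bytes.fromhex(value)
        except ValueError:
            pass  # 64 chars but not valid hex: treat as a passphrase below
    return hashlib.sha256(value.encode()).digest()

assert len(derive_key("0123456789abcdef" * 4)) == 32  # hex path
assert len(derive_key("my-test-passphrase")) == 32     # SHA-256 path
assert len(derive_key("abcd1234")) == 32               # short hex -> passphrase path
```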

342
tests/test_mcp_server.py Normal file

@@ -0,0 +1,342 @@
"""Tests for the Claude Memory MCP server."""
import json
import os
import sys
import pytest
# Force SQLite fallback mode for all tests
os.environ.pop("MEMORY_API_KEY", None)
os.environ.pop("CLAUDE_MEMORY_API_KEY", None)
# Add src to path so we can import without installing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from claude_memory.mcp_server import MemoryServer, TOOLS, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
@pytest.fixture
def server(tmp_path):
"""Create a MemoryServer with a temporary SQLite database."""
db_path = str(tmp_path / "test_memory.db")
srv = MemoryServer(sqlite_db_path=db_path)
yield srv
if srv.sqlite_conn:
srv.sqlite_conn.close()
class TestSQLiteInit:
def test_creates_database(self, tmp_path):
db_path = str(tmp_path / "sub" / "test.db")
srv = MemoryServer(sqlite_db_path=db_path)
assert os.path.exists(db_path)
# Verify tables exist
cursor = srv.sqlite_conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'")
assert cursor.fetchone() is not None
srv.sqlite_conn.close()
def test_creates_fts_table(self, tmp_path):
db_path = str(tmp_path / "test.db")
srv = MemoryServer(sqlite_db_path=db_path)
cursor = srv.sqlite_conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'")
assert cursor.fetchone() is not None
srv.sqlite_conn.close()
class TestMemoryStore:
def test_store_basic(self, server):
result = server.memory_store({
"content": "User prefers dark mode",
"expanded_keywords": "dark mode theme preference ui",
})
assert "Stored memory #1" in result
assert "facts" in result
def test_store_with_category(self, server):
result = server.memory_store({
"content": "User likes Python",
"category": "preferences",
"expanded_keywords": "python programming language preference",
})
assert "preferences" in result
def test_store_with_importance(self, server):
result = server.memory_store({
"content": "Critical info",
"importance": 0.9,
"expanded_keywords": "critical important info",
})
assert "0.9" in result
def test_store_requires_content(self, server):
with pytest.raises(ValueError, match="content is required"):
server.memory_store({"expanded_keywords": "test"})
def test_store_force_sensitive(self, server):
result = server.memory_store({
"content": "API key: sk-1234",
"force_sensitive": True,
"expanded_keywords": "api key secret credential",
})
assert "Stored memory #1" in result
# Verify is_sensitive flag is set
cursor = server.sqlite_conn.cursor()
cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1")
row = cursor.fetchone()
assert row["is_sensitive"] == 1
class TestMemoryRecall:
def test_recall_finds_memory(self, server):
server.memory_store({
"content": "User works at Acme Corp",
"expanded_keywords": "acme corp company work employer",
})
result = server.memory_recall({
"context": "work",
"expanded_query": "company employer job",
})
assert "Acme Corp" in result
assert "Found 1 memories" in result
def test_recall_no_results(self, server):
result = server.memory_recall({
"context": "nonexistent topic",
"expanded_query": "nothing here at all",
})
assert "No memories found" in result
def test_recall_with_category_filter(self, server):
server.memory_store({
"content": "User prefers vim",
"category": "preferences",
"expanded_keywords": "vim editor preference text",
})
server.memory_store({
"content": "Project uses React",
"category": "projects",
"expanded_keywords": "react project frontend framework",
})
result = server.memory_recall({
"context": "preferences",
"expanded_query": "vim editor",
"category": "preferences",
})
assert "vim" in result
assert "React" not in result
def test_recall_requires_context(self, server):
with pytest.raises(ValueError, match="context is required"):
server.memory_recall({"expanded_query": "test"})
class TestMemoryList:
def test_list_empty(self, server):
result = server.memory_list({})
assert "No memories stored yet" in result
def test_list_with_memories(self, server):
server.memory_store({
"content": "Memory one",
"expanded_keywords": "one first test",
})
server.memory_store({
"content": "Memory two",
"expanded_keywords": "two second test",
})
result = server.memory_list({})
assert "Memory one" in result
assert "Memory two" in result
assert "2 shown" in result
def test_list_with_category(self, server):
server.memory_store({
"content": "A fact",
"category": "facts",
"expanded_keywords": "fact test",
})
server.memory_store({
"content": "A preference",
"category": "preferences",
"expanded_keywords": "preference test",
})
result = server.memory_list({"category": "facts"})
assert "A fact" in result
assert "A preference" not in result
def test_list_empty_category(self, server):
result = server.memory_list({"category": "projects"})
assert "No memories in category 'projects'" in result
def test_list_respects_limit(self, server):
for i in range(5):
server.memory_store({
"content": f"Memory {i}",
"expanded_keywords": f"memory number {i}",
})
result = server.memory_list({"limit": 2})
assert "2 shown" in result
class TestMemoryDelete:
def test_delete_existing(self, server):
server.memory_store({
"content": "To be deleted",
"expanded_keywords": "delete remove test",
})
result = server.memory_delete({"id": 1})
assert "Deleted memory #1" in result
assert "To be deleted" in result
def test_delete_nonexistent(self, server):
result = server.memory_delete({"id": 999})
assert "not found" in result
def test_delete_requires_id(self, server):
with pytest.raises(ValueError, match="id is required"):
server.memory_delete({})
class TestSecretGet:
def test_secret_get_sensitive(self, server):
server.memory_store({
"content": "secret password 12345",
"force_sensitive": True,
"expanded_keywords": "password secret credential",
})
result = server.secret_get({"id": 1})
assert "secret password 12345" in result
def test_secret_get_not_sensitive(self, server):
server.memory_store({
"content": "public info",
"expanded_keywords": "public info test",
})
result = server.secret_get({"id": 1})
assert "not marked as sensitive" in result
def test_secret_get_nonexistent(self, server):
result = server.secret_get({"id": 999})
assert "not found" in result
def test_secret_get_requires_id(self, server):
with pytest.raises(ValueError, match="id is required"):
server.secret_get({})
class TestMCPProtocol:
def test_handle_initialize(self, server):
result = server.handle_initialize({})
assert result["protocolVersion"] == PROTOCOL_VERSION
assert result["serverInfo"]["name"] == SERVER_NAME
assert result["serverInfo"]["version"] == SERVER_VERSION
assert "tools" in result["capabilities"]
def test_handle_tools_list(self, server):
result = server.handle_tools_list({})
tools = result["tools"]
assert len(tools) == 5
names = {t["name"] for t in tools}
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"}
def test_handle_tools_call_store(self, server):
result = server.handle_tools_call({
"name": "memory_store",
"arguments": {
"content": "test memory",
"expanded_keywords": "test memory keywords",
},
})
assert not result.get("isError", False)
assert "Stored memory" in result["content"][0]["text"]
def test_handle_tools_call_unknown(self, server):
result = server.handle_tools_call({
"name": "nonexistent_tool",
"arguments": {},
})
assert result["isError"] is True
assert "Unknown tool" in result["content"][0]["text"]
def test_handle_tools_call_error(self, server):
result = server.handle_tools_call({
"name": "memory_store",
"arguments": {}, # missing content
})
assert result["isError"] is True
assert "Error" in result["content"][0]["text"]
class TestProcessMessage:
def test_initialize(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {},
})
assert response["jsonrpc"] == "2.0"
assert response["id"] == 1
assert "result" in response
assert response["result"]["serverInfo"]["name"] == SERVER_NAME
def test_tools_list(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
"params": {},
})
assert "result" in response
assert len(response["result"]["tools"]) == 5
def test_tools_call(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {
"name": "memory_store",
"arguments": {
"content": "via process_message",
"expanded_keywords": "process message test",
},
},
})
assert "result" in response
assert "Stored memory" in response["result"]["content"][0]["text"]
def test_unknown_method(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 4,
"method": "unknown/method",
"params": {},
})
assert "error" in response
assert response["error"]["code"] == -32601
assert "Method not found" in response["error"]["message"]
def test_notification_no_id(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
})
assert response is None
def test_jsonrpc_response_format(self, server):
response = server.process_message({
"jsonrpc": "2.0",
"id": 5,
"method": "initialize",
"params": {},
})
# Verify it's valid JSON when serialized
serialized = json.dumps(response)
parsed = json.loads(serialized)
assert parsed["jsonrpc"] == "2.0"
assert parsed["id"] == 5

154
tests/test_vault_client.py Normal file

@@ -0,0 +1,154 @@
"""Tests for Vault KV v2 client with mocked urllib."""
import json
import os
from io import BytesIO
from unittest.mock import MagicMock, mock_open, patch
import pytest
from claude_memory.vault_client import VaultClient
@pytest.fixture
def vault_env(monkeypatch):
monkeypatch.setenv("VAULT_ADDR", "http://vault.example.com:8200")
monkeypatch.setenv("VAULT_TOKEN", "s.testtoken123")
class TestVaultClientInit:
def test_missing_addr_raises_value_error(self, monkeypatch):
monkeypatch.delenv("VAULT_ADDR", raising=False)
monkeypatch.delenv("VAULT_TOKEN", raising=False)
with pytest.raises(ValueError, match="Vault address not configured"):
VaultClient()
def test_init_with_explicit_args(self):
client = VaultClient(addr="http://localhost:8200", token="mytoken")
assert client.addr == "http://localhost:8200"
assert client.token == "mytoken"
assert client.mount == "secret"
def test_init_from_env(self, vault_env):
client = VaultClient()
assert client.addr == "http://vault.example.com:8200"
assert client.token == "s.testtoken123"
def test_addr_trailing_slash_stripped(self):
client = VaultClient(addr="http://localhost:8200/", token="t")
assert client.addr == "http://localhost:8200"
@patch("os.path.exists", return_value=True)
@patch("builtins.open", mock_open(read_data="fake-jwt-token"))
@patch("urllib.request.urlopen")
def test_kubernetes_sa_token_auto_detection(self, mock_urlopen, mock_exists, monkeypatch):
monkeypatch.setenv("VAULT_ADDR", "http://vault:8200")
monkeypatch.delenv("VAULT_TOKEN", raising=False)
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"auth": {"client_token": "s.k8s-token-abc"}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
assert client.token == "s.k8s-token-abc"
class TestVaultRead:
@patch("urllib.request.urlopen")
def test_read_secret_returns_data(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"data": {"data": {"username": "admin", "password": "secret"}}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
result = client.read("myapp/config")
assert result == {"username": "admin", "password": "secret"}
@patch("urllib.request.urlopen")
def test_read_returns_none_for_404(self, mock_urlopen, vault_env):
import urllib.error
mock_urlopen.side_effect = urllib.error.HTTPError(
url="http://vault:8200/v1/secret/data/missing",
code=404,
msg="Not Found",
hdrs={},
fp=BytesIO(b""),
)
client = VaultClient()
result = client.read("missing/path")
assert result is None
class TestVaultWrite:
@patch("urllib.request.urlopen")
def test_write_secret_sends_correct_request(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"data": {"created_time": "2024-01-01T00:00:00Z", "version": 1}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
result = client.write("myapp/config", {"key": "value"})
# Verify the request was made with correct data
call_args = mock_urlopen.call_args
request = call_args[0][0]
assert request.full_url == "http://vault.example.com:8200/v1/secret/data/myapp/config"
assert request.method == "POST"
body = json.loads(request.data.decode())
assert body == {"data": {"key": "value"}}
class TestVaultDelete:
@patch("urllib.request.urlopen")
def test_delete_returns_true_on_success(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = b"{}"
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
assert client.delete("myapp/config") is True
@patch("urllib.request.urlopen")
def test_delete_returns_false_on_error(self, mock_urlopen, vault_env):
import urllib.error
mock_urlopen.side_effect = urllib.error.HTTPError(
url="http://vault:8200/v1/secret/data/missing",
code=500,
msg="Internal Server Error",
hdrs={},
fp=BytesIO(b"error"),
)
client = VaultClient()
assert client.delete("missing/path") is False
class TestVaultListSecrets:
@patch("urllib.request.urlopen")
def test_list_secrets(self, mock_urlopen, vault_env):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({
"data": {"keys": ["secret1", "secret2/"]}
}).encode()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_urlopen.return_value = mock_response
client = VaultClient()
result = client.list_secrets("myapp")
assert result == ["secret1", "secret2/"]