From 0ed5e1e016901abedae2c33049748e938bdfcf67 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sat, 14 Mar 2026 09:42:05 +0000 Subject: [PATCH] feat: standalone claude-memory-mcp with multi-user support and Vault integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extracted from private infra repo into standalone open-source project. Three operating modes: - Local: SQLite + FTS5 (zero dependencies) - Server: PostgreSQL via HTTP API with multi-user auth - Full: PostgreSQL + HashiCorp Vault for secret management Features: - MCP stdio server with 5 tools (store/recall/list/delete/secret_get) - FastAPI HTTP API with multi-user Bearer token auth (API_KEYS JSON map) - Regex-based credential detection with auto-redaction - AES-256-GCM encryption fallback for non-Vault deployments - Vault KV v2 client (stdlib urllib, K8s SA auto-auth) - Per-user data isolation (all queries scoped by user_id) - Secret migration endpoint for existing plain-text credentials - Backward-compatible env var aliases (CLAUDE_MEMORY_API_URL) Infrastructure: - Docker + docker-compose (API + PostgreSQL + optional Vault) - Woodpecker CI (test → build → push → kubectl deploy) - GitHub Actions CI (Python 3.11/3.12/3.13) + Release (GHCR + PyPI) - Helm chart + raw Kubernetes manifests 96 tests passing across 6 test files. --- .github/workflows/ci.yml | 36 ++ .github/workflows/release.yml | 40 ++ .gitignore | 45 ++ .woodpecker.yml | 43 ++ LICENSE | 191 ++++++ README.md | 224 +++++++ deploy/helm/claude-memory/Chart.yaml | 6 + .../claude-memory/templates/deployment.yaml | 35 ++ .../helm/claude-memory/templates/ingress.yaml | 25 + .../helm/claude-memory/templates/service.yaml | 12 + deploy/helm/claude-memory/values.yaml | 46 ++ deploy/kubernetes/deployment.yaml | 52 ++ deploy/kubernetes/ingress.yaml | 24 + deploy/kubernetes/namespace.yaml | 4 + deploy/kubernetes/service.yaml | 13 + docker/Dockerfile | 15 + docker/docker-compose.yml | 49 ++ examples/mcp-config-local.json | 10 + examples/mcp-config-server.json | 13 + examples/mcp-config-vault.json | 15 + pyproject.toml | 35 ++ src/claude_memory/__init__.py | 3 + src/claude_memory/__main__.py | 4 + src/claude_memory/api/__init__.py | 0 src/claude_memory/api/app.py | 337 +++++++++++ src/claude_memory/api/auth.py | 32 + src/claude_memory/api/database.py | 55 ++ src/claude_memory/api/models.py | 32 + src/claude_memory/api/vault_service.py | 51 ++ src/claude_memory/credential_detector.py | 76 +++ src/claude_memory/crypto.py | 71 +++ src/claude_memory/mcp_server.py | 550 ++++++++++++++++++ src/claude_memory/vault_client.py | 84 +++ tests/__init__.py | 0 tests/test_api.py | 304 ++++++++++ tests/test_auth.py | 87 +++ tests/test_credential_detector.py | 132 +++++ tests/test_crypto.py | 134 +++++ tests/test_mcp_server.py | 342 +++++++++++ tests/test_vault_client.py | 154 +++++ 40 files changed, 3381 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/release.yml create mode 100644 .gitignore create mode 100644 .woodpecker.yml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 deploy/helm/claude-memory/Chart.yaml create mode 100644 deploy/helm/claude-memory/templates/deployment.yaml create mode 100644 deploy/helm/claude-memory/templates/ingress.yaml create mode 100644 deploy/helm/claude-memory/templates/service.yaml create mode 100644 deploy/helm/claude-memory/values.yaml create mode 100644 deploy/kubernetes/deployment.yaml create mode 100644 deploy/kubernetes/ingress.yaml create mode 100644 deploy/kubernetes/namespace.yaml create mode 100644 deploy/kubernetes/service.yaml create mode 100644 docker/Dockerfile create mode 100644 docker/docker-compose.yml create mode 100644 examples/mcp-config-local.json create mode 100644 examples/mcp-config-server.json create mode 100644 examples/mcp-config-vault.json create mode 100644 pyproject.toml create mode 100644 src/claude_memory/__init__.py create mode 100644 src/claude_memory/__main__.py create mode 100644 src/claude_memory/api/__init__.py create mode 100644 src/claude_memory/api/app.py create mode 100644 src/claude_memory/api/auth.py create mode 100644 src/claude_memory/api/database.py create mode 100644 src/claude_memory/api/models.py create mode 100644 src/claude_memory/api/vault_service.py create mode 100644 src/claude_memory/credential_detector.py create mode 100644 src/claude_memory/crypto.py create mode 100644 src/claude_memory/mcp_server.py create mode 100644 src/claude_memory/vault_client.py create mode 100644 tests/__init__.py create mode 100644 tests/test_api.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_credential_detector.py create mode 100644 tests/test_crypto.py create mode 100644 tests/test_mcp_server.py create mode 100644 tests/test_vault_client.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..806e542 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - run: pip install -e ".[api,dev]" + - run: ruff check src/ tests/ + - run: mypy src/claude_memory/ + - run: pytest tests/ -v --tb=short + + docker: + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + - uses: docker/setup-buildx-action@v3 + - uses: docker/build-push-action@v6 + with: + context: . + file: docker/Dockerfile + push: false + tags: claude-memory-mcp:test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..b392950 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,40 @@ +name: Release + +on: + push: + tags: ["v*"] + +jobs: + docker: + runs-on: ubuntu-latest + permissions: + packages: write + steps: + - uses: actions/checkout@v4 + - uses: docker/setup-buildx-action@v3 + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - uses: docker/build-push-action@v6 + with: + context: . + file: docker/Dockerfile + push: true + tags: | + ghcr.io/${{ github.repository }}:${{ github.ref_name }} + ghcr.io/${{ github.repository }}:latest + + pypi: + runs-on: ubuntu-latest + permissions: + id-token: write + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install build + - run: python -m build + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9605f28 --- /dev/null +++ b/.gitignore @@ -0,0 +1,45 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +*.egg-info/ +*.egg +dist/ +build/ +.eggs/ + +# Virtual environments +.venv/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.mypy_cache/ +.ruff_cache/ + +# Environment +.env +.env.local +.env.production + +# OS +.DS_Store +Thumbs.db + +# Docker +docker/pgdata/ + +# Database +*.db +*.sqlite3 diff --git a/.woodpecker.yml b/.woodpecker.yml new file mode 100644 index 0000000..b69cc68 --- /dev/null +++ b/.woodpecker.yml @@ -0,0 +1,43 @@ +when: + - event: push + branch: main + +clone: + git: + image: woodpeckerci/plugin-git + settings: + attempts: 5 + backoff: 10s + +steps: + - name: test + image: python:3.12-slim + commands: + - pip install -e ".[api,dev]" + - ruff check src/ tests/ + - pytest tests/ -v --tb=short + + - name: build-and-push + image: woodpeckerci/plugin-docker-buildx + depends_on: + - test + settings: + username: viktorbarzin + password: + from_secret: dockerhub-token + repo: viktorbarzin/claude-memory-mcp + dockerfile: docker/Dockerfile + context: . + platforms: + - linux/amd64 + tags: + - "${CI_PIPELINE_NUMBER}" + - latest + + - name: deploy + image: bitnami/kubectl:latest + depends_on: + - build-and-push + commands: + - kubectl rollout restart deployment/claude-memory -n claude-memory + - kubectl rollout status deployment/claude-memory -n claude-memory --timeout=120s diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..134f341 --- /dev/null +++ b/LICENSE @@ -0,0 +1,191 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to the Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by the Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding any notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2026 Viktor Barzin + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..9170ae9 --- /dev/null +++ b/README.md @@ -0,0 +1,224 @@ +# Claude Memory MCP + +A persistent memory layer for Claude Code that stores knowledge across sessions. Operates as an MCP (Model Context Protocol) server with optional API and database backends. + +## Operating Modes + +| Mode | Storage | Auth | Use Case | +|------|---------|------|----------| +| **Local** | SQLite file | None | Single user, local Claude Code | +| **Server** | SQLite file | API key | Remote access, single user | +| **Full** | PostgreSQL | API keys + Vault | Multi-user, team deployment | + +## Quick Start + +### Local Mode (MCP stdio) + +Install and configure Claude Code to use the MCP server directly: + +```bash +pip install claude-memory-mcp +``` + +Add to your Claude Code MCP config (`~/.claude/plugins/`): + +```json +{ + "mcpServers": { + "memory": { + "command": "python", + "args": ["-m", "claude_memory.mcp_server"], + "env": { + "MEMORY_DB_PATH": "~/.claude/memory.db" + } + } + } +} +``` + +### Server Mode (API) + +Run the API server with SQLite: + +```bash +pip install claude-memory-mcp[api] + +export DATABASE_URL="sqlite:///./memory.db" +export API_KEY="your-secret-key" + +uvicorn claude_memory.api.app:app --host 0.0.0.0 --port 8000 +``` + +Configure Claude Code to connect via HTTP: + +```json +{ + "mcpServers": { + "memory": { + "command": "python", + "args": ["-m", "claude_memory.mcp_server"], + "env": { + "MEMORY_API_URL": "http://localhost:8000", + "MEMORY_API_KEY": "your-secret-key" + } + } + } +} +``` + +### Full Mode (Docker Compose) + +```bash +cd docker +docker compose up -d +``` + +This starts the API server with PostgreSQL. See [Docker Compose](#docker-compose) for details. + +## Docker Compose + +The dev environment includes the API server and PostgreSQL: + +```bash +cd docker +docker compose up -d +``` + +To include HashiCorp Vault for secret management: + +```bash +docker compose --profile vault up -d +``` + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `DATABASE_URL` | Database connection string | `sqlite:///./memory.db` | +| `API_KEY` | Single-user API key | None | +| `API_KEYS` | Multi-user JSON map `{"user": "key"}` | None | +| `VAULT_ADDR` | Vault server address | None | +| `VAULT_TOKEN` | Vault authentication token | None | + +## Multi-User Setup + +For team deployments, use the `API_KEYS` environment variable with a JSON mapping of usernames to API keys: + +```bash +export API_KEYS='{"alice": "key-alice-xxx", "bob": "key-bob-yyy"}' +``` + +Each user gets isolated memory storage. The username is extracted from the API key on each request. + +## Vault Integration + +For production deployments, store API keys in HashiCorp Vault instead of environment variables: + +```bash +export VAULT_ADDR="https://vault.example.com" +export VAULT_TOKEN="s.xxxxxxxxxxxx" +``` + +The server reads API keys from the Vault KV store at `secret/claude-memory/api-keys`. + +## API Reference + +### Health Check + +``` +GET /health +``` + +### Store Memory + +``` +POST /api/v1/memories +Authorization: Bearer +Content-Type: application/json + +{ + "content": "The user prefers dark mode", + "tags": ["preferences", "ui"], + "source": "conversation" +} +``` + +### Recall Memories + +``` +GET /api/v1/memories?q=dark+mode&limit=10 +Authorization: Bearer +``` + +### List Memories + +``` +GET /api/v1/memories?tags=preferences&limit=20 +Authorization: Bearer +``` + +### Delete Memory + +``` +DELETE /api/v1/memories/{id} +Authorization: Bearer +``` + +## Kubernetes Deployment + +### Helm + +```bash +helm install claude-memory deploy/helm/claude-memory \ + --set env.DATABASE_URL="postgresql://user:pass@host:5432/db" \ + --set env.API_KEY="your-key" \ + --set ingress.host="claude-memory.yourdomain.com" +``` + +### Raw Manifests + +```bash +kubectl apply -f deploy/kubernetes/namespace.yaml +# Create secret with your credentials first: +kubectl create secret generic claude-memory-secrets \ + -n claude-memory \ + --from-literal=database-url="postgresql://user:pass@host:5432/db" \ + --from-literal=api-key="your-key" +kubectl apply -f deploy/kubernetes/ +``` + +## Development + +### Setup + +```bash +git clone https://github.com/viktorbarzin/claude-memory-mcp.git +cd claude-memory-mcp +python -m venv .venv +source .venv/bin/activate +pip install -e ".[api,dev]" +``` + +### Running Tests + +```bash +pytest tests/ -v +``` + +### Linting + +```bash +ruff check src/ tests/ +mypy src/claude_memory/ +``` + +### Building + +```bash +pip install build +python -m build +``` + +## License + +Apache License 2.0. See [LICENSE](LICENSE) for details. diff --git a/deploy/helm/claude-memory/Chart.yaml b/deploy/helm/claude-memory/Chart.yaml new file mode 100644 index 0000000..2a56c9b --- /dev/null +++ b/deploy/helm/claude-memory/Chart.yaml @@ -0,0 +1,6 @@ +apiVersion: v2 +name: claude-memory +description: Claude Memory MCP API server +type: application +version: 1.0.0 +appVersion: "1.0.0" diff --git a/deploy/helm/claude-memory/templates/deployment.yaml b/deploy/helm/claude-memory/templates/deployment.yaml new file mode 100644 index 0000000..ddb0954 --- /dev/null +++ b/deploy/helm/claude-memory/templates/deployment.yaml @@ -0,0 +1,35 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ .Release.Name }} + labels: + app: {{ .Release.Name }} +spec: + replicas: {{ .Values.replicaCount }} + selector: + matchLabels: + app: {{ .Release.Name }} + template: + metadata: + labels: + app: {{ .Release.Name }} + spec: + containers: + - name: {{ .Release.Name }} + image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + ports: + - containerPort: {{ .Values.service.targetPort }} + {{- range $key, $value := .Values.env }} + {{- if $value }} + env: + - name: {{ $key }} + value: {{ $value | quote }} + {{- end }} + {{- end }} + livenessProbe: + {{- toYaml .Values.livenessProbe | nindent 12 }} + readinessProbe: + {{- toYaml .Values.readinessProbe | nindent 12 }} + resources: + {{- toYaml .Values.resources | nindent 12 }} diff --git a/deploy/helm/claude-memory/templates/ingress.yaml b/deploy/helm/claude-memory/templates/ingress.yaml new file mode 100644 index 0000000..e4f7d39 --- /dev/null +++ b/deploy/helm/claude-memory/templates/ingress.yaml @@ -0,0 +1,25 @@ +{{- if .Values.ingress.enabled }} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: {{ .Release.Name }} + annotations: + nginx.ingress.kubernetes.io/ssl-redirect: "true" +spec: + ingressClassName: {{ .Values.ingress.className }} + tls: + - hosts: + - {{ .Values.ingress.host }} + secretName: {{ .Values.ingress.tls.secretName }} + rules: + - host: {{ .Values.ingress.host }} + http: + paths: + - path: / + pathType: Prefix + backend: + service: + name: {{ .Release.Name }} + port: + number: {{ .Values.service.port }} +{{- end }} diff --git a/deploy/helm/claude-memory/templates/service.yaml b/deploy/helm/claude-memory/templates/service.yaml new file mode 100644 index 0000000..ca0fef3 --- /dev/null +++ b/deploy/helm/claude-memory/templates/service.yaml @@ -0,0 +1,12 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ .Release.Name }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ .Values.service.port }} + targetPort: {{ .Values.service.targetPort }} + protocol: TCP + selector: + app: {{ .Release.Name }} diff --git a/deploy/helm/claude-memory/values.yaml b/deploy/helm/claude-memory/values.yaml new file mode 100644 index 0000000..d07c0de --- /dev/null +++ b/deploy/helm/claude-memory/values.yaml @@ -0,0 +1,46 @@ +replicaCount: 1 + +image: + repository: viktorbarzin/claude-memory-mcp + tag: latest + pullPolicy: Always + +service: + type: ClusterIP + port: 80 + targetPort: 8000 + +ingress: + enabled: true + className: nginx + host: claude-memory.example.com + tls: + secretName: tls-secret + +resources: + requests: + memory: 32Mi + cpu: 10m + limits: + memory: 128Mi + +env: + DATABASE_URL: "" + API_KEY: "" + # API_KEYS: '{}' + # VAULT_ADDR: "" + # VAULT_TOKEN: "" + +livenessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 5 + periodSeconds: 30 + +readinessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 3 + periodSeconds: 10 diff --git a/deploy/kubernetes/deployment.yaml b/deploy/kubernetes/deployment.yaml new file mode 100644 index 0000000..97e6043 --- /dev/null +++ b/deploy/kubernetes/deployment.yaml @@ -0,0 +1,52 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: claude-memory + namespace: claude-memory + labels: + app: claude-memory +spec: + replicas: 1 + selector: + matchLabels: + app: claude-memory + template: + metadata: + labels: + app: claude-memory + spec: + containers: + - name: claude-memory + image: viktorbarzin/claude-memory-mcp:latest + imagePullPolicy: Always + ports: + - containerPort: 8000 + env: + - name: DATABASE_URL + valueFrom: + secretKeyRef: + name: claude-memory-secrets + key: database-url + - name: API_KEY + valueFrom: + secretKeyRef: + name: claude-memory-secrets + key: api-key + livenessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 5 + periodSeconds: 30 + readinessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 3 + periodSeconds: 10 + resources: + requests: + memory: 32Mi + cpu: 10m + limits: + memory: 128Mi diff --git a/deploy/kubernetes/ingress.yaml b/deploy/kubernetes/ingress.yaml new file mode 100644 index 0000000..20c3194 --- /dev/null +++ b/deploy/kubernetes/ingress.yaml @@ -0,0 +1,24 @@ +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: claude-memory + namespace: claude-memory + annotations: + nginx.ingress.kubernetes.io/ssl-redirect: "true" +spec: + ingressClassName: nginx + tls: + - hosts: + - claude-memory.example.com + secretName: tls-secret + rules: + - host: claude-memory.example.com + http: + paths: + - path: / + pathType: Prefix + backend: + service: + name: claude-memory + port: + number: 80 diff --git a/deploy/kubernetes/namespace.yaml b/deploy/kubernetes/namespace.yaml new file mode 100644 index 0000000..beb6255 --- /dev/null +++ b/deploy/kubernetes/namespace.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: claude-memory diff --git a/deploy/kubernetes/service.yaml b/deploy/kubernetes/service.yaml new file mode 100644 index 0000000..09a90d8 --- /dev/null +++ b/deploy/kubernetes/service.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Service +metadata: + name: claude-memory + namespace: claude-memory +spec: + type: ClusterIP + ports: + - port: 80 + targetPort: 8000 + protocol: TCP + selector: + app: claude-memory diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..230a576 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.12-slim AS base + +WORKDIR /app + +COPY pyproject.toml . +COPY src/ src/ + +RUN pip install --no-cache-dir ".[api]" + +RUN useradd -r -u 1000 app +USER app + +EXPOSE 8000 + +CMD ["uvicorn", "claude_memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000..031765b --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,49 @@ +services: + api: + build: + context: .. + dockerfile: docker/Dockerfile + ports: + - "8000:8000" + environment: + DATABASE_URL: postgresql://claude_memory:devpassword@postgres:5432/claude_memory + API_KEY: dev-api-key + # Multi-user mode (uncomment to test): + # API_KEYS: '{"viktor": "key1", "testuser": "key2"}' + # Vault (uncomment to test): + # VAULT_ADDR: http://vault:8200 + # VAULT_TOKEN: dev-root-token + depends_on: + postgres: + condition: service_healthy + + postgres: + image: postgres:16-alpine + environment: + POSTGRES_DB: claude_memory + POSTGRES_USER: claude_memory + POSTGRES_PASSWORD: devpassword + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U claude_memory"] + interval: 5s + timeout: 3s + retries: 5 + + vault: + image: hashicorp/vault:1.15 + ports: + - "8200:8200" + environment: + VAULT_DEV_ROOT_TOKEN_ID: dev-root-token + VAULT_DEV_LISTEN_ADDRESS: 0.0.0.0:8200 + cap_add: + - IPC_LOCK + profiles: + - vault + +volumes: + pgdata: diff --git a/examples/mcp-config-local.json b/examples/mcp-config-local.json new file mode 100644 index 0000000..c22cce3 --- /dev/null +++ b/examples/mcp-config-local.json @@ -0,0 +1,10 @@ +{ + "mcpServers": { + "memory": { + "type": "stdio", + "command": "python3", + "args": ["-m", "claude_memory.mcp_server"], + "env": {} + } + } +} diff --git a/examples/mcp-config-server.json b/examples/mcp-config-server.json new file mode 100644 index 0000000..1f291a7 --- /dev/null +++ b/examples/mcp-config-server.json @@ -0,0 +1,13 @@ +{ + "mcpServers": { + "memory": { + "type": "stdio", + "command": "python3", + "args": ["-m", "claude_memory.mcp_server"], + "env": { + "MEMORY_API_URL": "https://claude-memory.example.com", + "MEMORY_API_KEY": "your-api-key-here" + } + } + } +} diff --git a/examples/mcp-config-vault.json b/examples/mcp-config-vault.json new file mode 100644 index 0000000..cb11cf1 --- /dev/null +++ b/examples/mcp-config-vault.json @@ -0,0 +1,15 @@ +{ + "mcpServers": { + "memory": { + "type": "stdio", + "command": "python3", + "args": ["-m", "claude_memory.mcp_server"], + "env": { + "MEMORY_API_URL": "https://claude-memory.example.com", + "MEMORY_API_KEY": "your-api-key-here", + "VAULT_ADDR": "https://vault.example.com", + "VAULT_TOKEN": "your-vault-token" + } + } + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..91e0ada --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "claude-memory-mcp" +version = "1.0.0" +description = "Standalone MCP memory server with multi-user support and Vault integration" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.11" +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", +] + +[project.optional-dependencies] +api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0"] +vault = ["hvac>=2.0"] +dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28"] + +[project.scripts] +claude-memory-server = "claude_memory.mcp_server:main" + +[tool.hatch.build.targets.wheel] +packages = ["src/claude_memory"] + +[tool.ruff] +target-version = "py311" +line-length = 120 + +[tool.mypy] +python_version = "3.11" +strict = true diff --git a/src/claude_memory/__init__.py b/src/claude_memory/__init__.py new file mode 100644 index 0000000..2865899 --- /dev/null +++ b/src/claude_memory/__init__.py @@ -0,0 +1,3 @@ +"""Claude Memory MCP — standalone memory server with multi-user support.""" + +__version__ = "1.0.0" diff --git a/src/claude_memory/__main__.py b/src/claude_memory/__main__.py new file mode 100644 index 0000000..eb65e9b --- /dev/null +++ b/src/claude_memory/__main__.py @@ -0,0 +1,4 @@ +"""Allow running as `python -m claude_memory`.""" +from claude_memory.mcp_server import main + +main() diff --git a/src/claude_memory/api/__init__.py b/src/claude_memory/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/claude_memory/api/app.py b/src/claude_memory/api/app.py new file mode 100644 index 0000000..89529eb --- /dev/null +++ b/src/claude_memory/api/app.py @@ -0,0 +1,337 @@ +"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search.""" + +import logging +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from typing import Optional + +from fastapi import Depends, FastAPI, HTTPException + +from claude_memory.api.auth import AuthUser, get_current_user +from claude_memory.api.database import close_pool, get_pool, init_pool +from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse +from claude_memory.api.vault_service import ( + delete_secret, + get_secret, + is_vault_configured, + store_secret, +) + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await init_pool() + yield + await close_pool() + + +app = FastAPI(title="Claude Memory API", lifespan=lifespan) + + +def _detect_sensitive(content: str) -> bool: + """Check if content contains credentials using the credential detector.""" + try: + from claude_memory.credential_detector import detect_credentials + + findings = detect_credentials(content) + return len(findings) > 0 + except ImportError: + return False + + +def _redact_content(content: str) -> str: + """Redact sensitive content for storage in the main DB.""" + try: + from claude_memory.credential_detector import detect_credentials, redact_credentials + + creds = detect_credentials(content) + if creds: + return redact_credentials(content, creds) + return content + except ImportError: + return "[REDACTED]" + + +@app.get("/health") +async def health(): + return {"status": "ok"} + + +@app.post("/api/memories", response_model=MemoryResponse) +async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)): + pool = await get_pool() + is_sensitive = body.force_sensitive or _detect_sensitive(body.content) + + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id, category, importance + """, + user.user_id, + body.content if not is_sensitive else _redact_content(body.content), + body.category, + body.tags, + body.expanded_keywords, + body.importance, + is_sensitive, + ) + memory_id = row["id"] + + if is_sensitive and is_vault_configured(): + vault_path = await store_secret(user.user_id, memory_id, body.content) + await conn.execute( + "UPDATE memories SET vault_path = $1 WHERE id = $2", + vault_path, + memory_id, + ) + + return MemoryResponse(id=row["id"], category=row["category"], importance=row["importance"]) + + +@app.post("/api/memories/recall") +async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_current_user)): + pool = await get_pool() + + query_text = f"{body.context} {body.expanded_query}".strip() + + order_clause = "ts_rank(search_vector, query) DESC" + if body.sort_by == "importance": + order_clause = "importance DESC, ts_rank(search_vector, query) DESC" + elif body.sort_by == "recency": + order_clause = "created_at DESC" + + category_filter = "" + params: list = [user.user_id, query_text, body.limit] + if body.category: + category_filter = "AND category = $4" + params.append(body.category) + + async with pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT id, content, category, tags, importance, is_sensitive, + ts_rank(search_vector, query) AS rank, + created_at, updated_at + FROM memories, plainto_tsquery('english', $2) query + WHERE user_id = $1 + AND (search_vector @@ query OR $2 = '') + {category_filter} + ORDER BY {order_clause} + LIMIT $3 + """, + *params, + ) + + results = [] + for row in rows: + content = row["content"] + if row["is_sensitive"]: + content = f"[SENSITIVE - use secret_get(id={row['id']})]" + results.append( + { + "id": row["id"], + "content": content, + "category": row["category"], + "tags": row["tags"], + "importance": row["importance"], + "is_sensitive": row["is_sensitive"], + "rank": float(row["rank"]), + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + ) + + return results + + +@app.get("/api/memories") +async def list_memories( + category: Optional[str] = None, + limit: int = 50, + user: AuthUser = Depends(get_current_user), +): + pool = await get_pool() + + if category: + query = """ + SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at + FROM memories WHERE user_id = $1 AND category = $2 + ORDER BY importance DESC LIMIT $3 + """ + params: list = [user.user_id, category, limit] + else: + query = """ + SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at + FROM memories WHERE user_id = $1 + ORDER BY importance DESC LIMIT $2 + """ + params = [user.user_id, limit] + + async with pool.acquire() as conn: + rows = await conn.fetch(query, *params) + + results = [] + for row in rows: + content = row["content"] + if row["is_sensitive"]: + content = f"[SENSITIVE - use secret_get(id={row['id']})]" + results.append( + { + "id": row["id"], + "content": content, + "category": row["category"], + "tags": row["tags"], + "importance": row["importance"], + "is_sensitive": row["is_sensitive"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + ) + + return results + + +@app.delete("/api/memories/{memory_id}") +async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_user)): + pool = await get_pool() + + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT id, vault_path FROM memories WHERE id = $1 AND user_id = $2", + memory_id, + user.user_id, + ) + if not row: + raise HTTPException(status_code=404, detail="Memory not found") + + if row["vault_path"]: + await delete_secret(user.user_id, row["vault_path"]) + + await conn.execute( + "DELETE FROM memories WHERE id = $1 AND user_id = $2", + memory_id, + user.user_id, + ) + + return {"deleted": memory_id} + + +@app.post("/api/memories/{memory_id}/secret", response_model=SecretResponse) +async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current_user)): + pool = await get_pool() + + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT id, content, is_sensitive, vault_path, encrypted_content + FROM memories WHERE id = $1 AND user_id = $2 + """, + memory_id, + user.user_id, + ) + + if not row: + raise HTTPException(status_code=404, detail="Memory not found") + + if not row["is_sensitive"]: + return SecretResponse(id=row["id"], content=row["content"], source="plaintext") + + if row["vault_path"]: + secret = await get_secret(user.user_id, row["vault_path"]) + if secret: + return SecretResponse(id=row["id"], content=secret, source="vault") + + if row["encrypted_content"]: + return SecretResponse( + id=row["id"], + content="[ENCRYPTED - decryption not available]", + source="encrypted", + ) + + return SecretResponse(id=row["id"], content=row["content"], source="plaintext") + + +@app.post("/api/memories/migrate-secrets") +async def migrate_secrets(user: AuthUser = Depends(get_current_user)): + pool = await get_pool() + migrated = 0 + + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, content FROM memories + WHERE user_id = $1 AND is_sensitive = FALSE + """, + user.user_id, + ) + + for row in rows: + if _detect_sensitive(row["content"]): + original_content = row["content"] + redacted = _redact_content(original_content) + + vault_path = None + if is_vault_configured(): + vault_path = await store_secret(user.user_id, row["id"], original_content) + + await conn.execute( + """ + UPDATE memories + SET is_sensitive = TRUE, content = $1, vault_path = $2, + updated_at = NOW() + WHERE id = $3 AND user_id = $4 + """, + redacted, + vault_path, + row["id"], + user.user_id, + ) + migrated += 1 + + return {"migrated": migrated} + + +@app.post("/api/memories/import") +async def import_memories( + memories: list[MemoryStore], user: AuthUser = Depends(get_current_user) +): + pool = await get_pool() + imported = [] + + async with pool.acquire() as conn: + for mem in memories: + is_sensitive = mem.force_sensitive or _detect_sensitive(mem.content) + content = mem.content if not is_sensitive else _redact_content(mem.content) + + row = await conn.fetchrow( + """ + INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id, category, importance + """, + user.user_id, + content, + mem.category, + mem.tags, + mem.expanded_keywords, + mem.importance, + is_sensitive, + ) + + if is_sensitive and is_vault_configured(): + vault_path = await store_secret(user.user_id, row["id"], mem.content) + await conn.execute( + "UPDATE memories SET vault_path = $1 WHERE id = $2", + vault_path, + row["id"], + ) + + imported.append( + MemoryResponse(id=row["id"], category=row["category"], importance=row["importance"]) + ) + + return imported diff --git a/src/claude_memory/api/auth.py b/src/claude_memory/api/auth.py new file mode 100644 index 0000000..0667b33 --- /dev/null +++ b/src/claude_memory/api/auth.py @@ -0,0 +1,32 @@ +import json +import os +from dataclasses import dataclass + +from fastapi import Header, HTTPException + + +@dataclass +class AuthUser: + user_id: str + + +# Multi-user mode: API_KEYS='{"viktor": "key1", "user2": "key2"}' +# Single-user mode: API_KEY="some-key" (backward compatible, user_id="default") +_api_keys_json = os.environ.get("API_KEYS", "") +_api_key_single = os.environ.get("API_KEY", "") + +_key_to_user: dict[str, str] = {} + +if _api_keys_json: + _user_to_key = json.loads(_api_keys_json) + _key_to_user = {v: k for k, v in _user_to_key.items()} +elif _api_key_single: + _key_to_user = {_api_key_single: "default"} + + +async def get_current_user(authorization: str = Header(...)) -> AuthUser: + token = authorization.removeprefix("Bearer ").strip() + user_id = _key_to_user.get(token) + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid API key") + return AuthUser(user_id=user_id) diff --git a/src/claude_memory/api/database.py b/src/claude_memory/api/database.py new file mode 100644 index 0000000..df3ee3d --- /dev/null +++ b/src/claude_memory/api/database.py @@ -0,0 +1,55 @@ +import os + +import asyncpg + +DATABASE_URL = os.environ.get("DATABASE_URL", "") + +pool: asyncpg.Pool | None = None + + +async def init_pool() -> asyncpg.Pool: + global pool + pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) + async with pool.acquire() as conn: + await conn.execute(""" + CREATE TABLE IF NOT EXISTS memories ( + id SERIAL PRIMARY KEY, + user_id VARCHAR(100) NOT NULL DEFAULT 'default', + content TEXT NOT NULL, + category VARCHAR(50) DEFAULT 'facts', + tags TEXT DEFAULT '', + expanded_keywords TEXT DEFAULT '', + importance REAL DEFAULT 0.5, + is_sensitive BOOLEAN DEFAULT FALSE, + vault_path TEXT DEFAULT NULL, + encrypted_content BYTEA DEFAULT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + search_vector tsvector GENERATED ALWAYS AS ( + setweight(to_tsvector('english', coalesce(content, '')), 'A') || + setweight(to_tsvector('english', coalesce(expanded_keywords, '')), 'B') || + setweight(to_tsvector('english', coalesce(tags, '')), 'C') || + setweight(to_tsvector('english', coalesce(category, '')), 'D') + ) STORED + ) + """) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_memories_search ON memories USING GIN(search_vector)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id)" + ) + return pool + + +async def close_pool(): + global pool + if pool: + await pool.close() + pool = None + + +async def get_pool() -> asyncpg.Pool: + if pool is None: + raise RuntimeError("Database pool not initialized") + return pool diff --git a/src/claude_memory/api/models.py b/src/claude_memory/api/models.py new file mode 100644 index 0000000..d4fb804 --- /dev/null +++ b/src/claude_memory/api/models.py @@ -0,0 +1,32 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class MemoryStore(BaseModel): + content: str + category: str = "facts" + tags: str = "" + expanded_keywords: str = "" + importance: float = Field(default=0.5, ge=0.0, le=1.0) + force_sensitive: bool = False + + +class MemoryRecall(BaseModel): + context: str + expanded_query: str = "" + category: Optional[str] = None + sort_by: str = "importance" + limit: int = 10 + + +class MemoryResponse(BaseModel): + id: int + category: str + importance: float + + +class SecretResponse(BaseModel): + id: int + content: str + source: str # "vault", "encrypted", "plaintext" diff --git a/src/claude_memory/api/vault_service.py b/src/claude_memory/api/vault_service.py new file mode 100644 index 0000000..d329f04 --- /dev/null +++ b/src/claude_memory/api/vault_service.py @@ -0,0 +1,51 @@ +import os +import logging + +logger = logging.getLogger(__name__) + +VAULT_ADDR = os.environ.get("VAULT_ADDR", "") +VAULT_TOKEN = os.environ.get("VAULT_TOKEN", "") +VAULT_MOUNT = os.environ.get("VAULT_MOUNT", "secret") +VAULT_PREFIX = os.environ.get("VAULT_PREFIX", "claude-memory") + + +def is_vault_configured() -> bool: + return bool(VAULT_ADDR and VAULT_TOKEN) + + +async def store_secret(user_id: str, memory_id: int, content: str) -> str: + """Store secret content in Vault. Returns the vault path.""" + if not is_vault_configured(): + raise RuntimeError("Vault not configured") + + from claude_memory.vault_client import VaultClient + + client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT) + path = f"{VAULT_PREFIX}/{user_id}/mem-{memory_id}" + client.write(path, {"content": content}) + return path + + +async def get_secret(user_id: str, vault_path: str) -> str | None: + """Retrieve secret content from Vault.""" + if not is_vault_configured(): + return None + + from claude_memory.vault_client import VaultClient + + client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT) + data = client.read(vault_path) + if data: + return data.get("content") + return None + + +async def delete_secret(user_id: str, vault_path: str) -> bool: + """Delete secret from Vault.""" + if not is_vault_configured(): + return False + + from claude_memory.vault_client import VaultClient + + client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT) + return client.delete(vault_path) diff --git a/src/claude_memory/credential_detector.py b/src/claude_memory/credential_detector.py new file mode 100644 index 0000000..ea69e53 --- /dev/null +++ b/src/claude_memory/credential_detector.py @@ -0,0 +1,76 @@ +"""Detect credentials and secrets in text content.""" + +import re +from dataclasses import dataclass + + +@dataclass +class DetectedCredential: + type: str # e.g. "password", "api_key", "private_key", "connection_string", "token" + confidence: float # 0.0 to 1.0 + start: int # position in text + end: int # position in text + matched_text: str # the actual matched text (for redaction) + + +# Patterns ordered by confidence +_PATTERNS: list[tuple[str, str, float]] = [ + # High confidence (0.9+) + ("private_key", r"-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----[\s\S]*?-----END (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----", 0.95), + ("connection_string", r"(?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^\s'\"]+", 0.9), + ("aws_key", r"(?:AKIA|ASIA)[A-Z0-9]{16}", 0.95), + ("github_token", r"gh[pousr]_[A-Za-z0-9_]{36,}", 0.95), + + # Medium confidence (0.7-0.89) + ("api_key", r"(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_\-]{20,})['\"]?", 0.8), + ("password", r"(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]{8,})['\"]?", 0.8), + ("token", r"(?:token|secret|bearer)\s*[:=]\s*['\"]?([A-Za-z0-9_\-\.]{20,})['\"]?", 0.75), + ("basic_auth", r"(?:Basic\s+)[A-Za-z0-9+/=]{20,}", 0.85), + ("bearer_token", r"Bearer\s+[A-Za-z0-9_\-\.]{20,}", 0.85), + + # Lower confidence (0.5-0.69) + ("generic_secret", r"(?:secret|credential|auth)\s*[:=]\s*['\"]?([^\s'\"]{12,})['\"]?", 0.6), + ("hex_key", r"(?:key|secret)\s*[:=]\s*['\"]?([0-9a-fA-F]{32,})['\"]?", 0.65), +] + + +def detect_credentials(text: str, min_confidence: float = 0.5) -> list[DetectedCredential]: + """Scan text for potential credentials and secrets.""" + results: list[DetectedCredential] = [] + for cred_type, pattern, confidence in _PATTERNS: + if confidence < min_confidence: + continue + for match in re.finditer(pattern, text, re.IGNORECASE): + results.append(DetectedCredential( + type=cred_type, + confidence=confidence, + start=match.start(), + end=match.end(), + matched_text=match.group(0), + )) + # Deduplicate overlapping matches, keeping highest confidence + results.sort(key=lambda c: (-c.confidence, c.start)) + filtered: list[DetectedCredential] = [] + for cred in results: + if not any(c.start <= cred.start and c.end >= cred.end for c in filtered): + filtered.append(cred) + return sorted(filtered, key=lambda c: c.start) + + +def redact_credentials(text: str, credentials: list[DetectedCredential]) -> str: + """Replace detected credentials with [REDACTED] markers.""" + if not credentials: + return text + parts: list[str] = [] + last_end = 0 + for cred in sorted(credentials, key=lambda c: c.start): + parts.append(text[last_end:cred.start]) + parts.append(f"[REDACTED:{cred.type}]") + last_end = cred.end + parts.append(text[last_end:]) + return "".join(parts) + + +def is_sensitive(text: str, min_confidence: float = 0.7) -> bool: + """Quick check if text likely contains credentials.""" + return len(detect_credentials(text, min_confidence)) > 0 diff --git a/src/claude_memory/crypto.py b/src/claude_memory/crypto.py new file mode 100644 index 0000000..488fc47 --- /dev/null +++ b/src/claude_memory/crypto.py @@ -0,0 +1,71 @@ +"""AES-256-GCM encryption for memory content when Vault is not available.""" + +import base64 +import hashlib +import os + +ENCRYPTION_KEY_ENV = "MEMORY_ENCRYPTION_KEY" + + +def _get_key() -> bytes | None: + """Get 32-byte encryption key from environment.""" + raw = os.environ.get(ENCRYPTION_KEY_ENV) + if not raw: + return None + # Accept hex-encoded 32-byte key or derive from passphrase + try: + key = bytes.fromhex(raw) + if len(key) == 32: + return key + except ValueError: + pass + # Derive key from passphrase using SHA-256 + return hashlib.sha256(raw.encode()).digest() + + +def is_encryption_configured() -> bool: + return _get_key() is not None + + +def encrypt(plaintext: str) -> bytes: + """Encrypt text using AES-256-GCM. Returns nonce + ciphertext + tag.""" + key = _get_key() + if key is None: + raise RuntimeError(f"{ENCRYPTION_KEY_ENV} not set") + + try: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + except ImportError: + raise RuntimeError("cryptography package required for encryption: pip install cryptography") + + nonce = os.urandom(12) + aesgcm = AESGCM(key) + ciphertext = aesgcm.encrypt(nonce, plaintext.encode(), None) + return nonce + ciphertext # 12 bytes nonce + ciphertext + 16 bytes tag + + +def decrypt(data: bytes) -> str: + """Decrypt AES-256-GCM encrypted data.""" + key = _get_key() + if key is None: + raise RuntimeError(f"{ENCRYPTION_KEY_ENV} not set") + + try: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + except ImportError: + raise RuntimeError("cryptography package required for encryption: pip install cryptography") + + nonce = data[:12] + ciphertext = data[12:] + aesgcm = AESGCM(key) + return aesgcm.decrypt(nonce, ciphertext, None).decode() + + +def encrypt_b64(plaintext: str) -> str: + """Encrypt and return base64-encoded string.""" + return base64.b64encode(encrypt(plaintext)).decode() + + +def decrypt_b64(data: str) -> str: + """Decrypt from base64-encoded string.""" + return decrypt(base64.b64decode(data)) diff --git a/src/claude_memory/mcp_server.py b/src/claude_memory/mcp_server.py new file mode 100644 index 0000000..21170b8 --- /dev/null +++ b/src/claude_memory/mcp_server.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python3 +""" +Claude Memory MCP Server — standalone memory server with multi-user support. + +Supports two modes: + 1. HTTP API mode: connects to a shared PostgreSQL-backed API server + 2. SQLite fallback: local file-based storage when no API key is configured + +Uses only stdlib (urllib) — no pip install required. +""" + +import json +import logging +import os +import sys +import urllib.error +import urllib.request +from typing import Any + +logger = logging.getLogger(__name__) + +PROTOCOL_VERSION = "2024-11-05" +SERVER_NAME = "claude-memory" +SERVER_VERSION = "1.0.0" + +# API configuration — support both MEMORY_* (primary) and CLAUDE_MEMORY_* (fallback) env vars +API_BASE_URL = os.environ.get("MEMORY_API_URL") or os.environ.get("CLAUDE_MEMORY_API_URL", "http://localhost:8080") +API_KEY = os.environ.get("MEMORY_API_KEY") or os.environ.get("CLAUDE_MEMORY_API_KEY", "") + +# Fallback to SQLite if API is not configured +SQLITE_FALLBACK = not API_KEY + + +def _api_request(method: str, path: str, body: dict | None = None) -> dict: + """Make an HTTP request to the memory API.""" + url = f"{API_BASE_URL}{path}" + data = json.dumps(body).encode() if body else None + req = urllib.request.Request( + url, + data=data, + method=method, + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + }, + ) + try: + with urllib.request.urlopen(req, timeout=15) as resp: + return json.loads(resp.read().decode()) + except urllib.error.HTTPError as e: + error_body = e.read().decode() if e.fp else str(e) + raise RuntimeError(f"API error {e.code}: {error_body}") from e + except urllib.error.URLError as e: + raise RuntimeError(f"API connection error: {e.reason}") from e + + +# ─── SQLite fallback (local storage when API not configured) ───────────────── + +def _init_sqlite(db_path: str | None = None): + """Initialize SQLite database as fallback.""" + import sqlite3 + from pathlib import Path + + if db_path is None: + memory_home = os.path.expandvars( + os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory")) + ) + db_path = os.environ.get( + "MEMORY_DB", + os.path.join(memory_home, "memory", "memory.db"), + ) + db_path = os.path.expandvars(os.path.expanduser(db_path)) + + Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True) + + conn = sqlite3.connect(db_path, timeout=30.0) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=30000") + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memories ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + content TEXT NOT NULL, + category TEXT DEFAULT 'facts', + tags TEXT DEFAULT '', + expanded_keywords TEXT DEFAULT '', + importance REAL DEFAULT 0.5, + is_sensitive INTEGER DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """) + cursor.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5( + content, category, tags, expanded_keywords, + content='memories', content_rowid='id' + ) + """) + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN + INSERT INTO memories_fts(rowid, content, category, tags, expanded_keywords) + VALUES (new.id, new.content, new.category, new.tags, new.expanded_keywords); + END + """) + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, content, category, tags, expanded_keywords) + VALUES ('delete', old.id, old.content, old.category, old.tags, old.expanded_keywords); + END + """) + cursor.execute(""" + CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, content, category, tags, expanded_keywords) + VALUES ('delete', old.id, old.content, old.category, old.tags, old.expanded_keywords); + INSERT INTO memories_fts(rowid, content, category, tags, expanded_keywords) + VALUES (new.id, new.content, new.category, new.tags, new.expanded_keywords); + END + """) + conn.commit() + return conn + + +# ─── Tool definitions ──────────────────────────────────────────────────────── + +TOOLS = [ + { + "name": "memory_store", + "description": "Store a fact or memory in persistent storage. Use this to remember important information about the user, their preferences, projects, decisions, or people they mention.", + "inputSchema": { + "type": "object", + "properties": { + "content": {"type": "string", "description": "The fact or memory to store"}, + "category": { + "type": "string", + "enum": ["facts", "preferences", "projects", "people", "decisions"], + "description": "Category for organizing the memory", + "default": "facts", + }, + "tags": {"type": "string", "description": "Comma-separated tags", "default": ""}, + "importance": { + "type": "number", + "description": "Importance 0.0-1.0", + "default": 0.5, + "minimum": 0.0, + "maximum": 1.0, + }, + "expanded_keywords": { + "type": "string", + "description": "REQUIRED. Space-separated semantically related search terms (MINIMUM 5 words). Generate keywords that someone might search for when this memory would be relevant. Include synonyms, related concepts, and adjacent topics.", + }, + "force_sensitive": { + "type": "boolean", + "description": "If true, mark this memory as sensitive regardless of auto-detection. Sensitive memories have their content encrypted at rest.", + "default": False, + }, + }, + "required": ["content", "expanded_keywords"], + }, + }, + { + "name": "memory_recall", + "description": "Retrieve relevant memories based on context. Uses full-text search to find stored memories.", + "inputSchema": { + "type": "object", + "properties": { + "context": {"type": "string", "description": "The context or topic to recall memories about"}, + "expanded_query": { + "type": "string", + "description": "REQUIRED. Space-separated semantically related search terms (MINIMUM 5 words).", + }, + "category": { + "type": "string", + "enum": ["facts", "preferences", "projects", "people", "decisions"], + "description": "Optional: filter results to a specific category", + }, + "sort_by": { + "type": "string", + "enum": ["importance", "relevance"], + "description": "Sort order", + "default": "importance", + }, + "limit": {"type": "integer", "description": "Max results", "default": 10}, + }, + "required": ["context", "expanded_query"], + }, + }, + { + "name": "memory_list", + "description": "List recent memories, optionally filtered by category.", + "inputSchema": { + "type": "object", + "properties": { + "category": { + "type": "string", + "enum": ["facts", "preferences", "projects", "people", "decisions"], + }, + "limit": {"type": "integer", "default": 20}, + }, + }, + }, + { + "name": "memory_delete", + "description": "Delete a memory by ID.", + "inputSchema": { + "type": "object", + "properties": { + "id": {"type": "integer", "description": "The ID of the memory to delete"}, + }, + "required": ["id"], + }, + }, + { + "name": "secret_get", + "description": "Retrieve the decrypted content of a sensitive memory. Only works for memories marked as sensitive.", + "inputSchema": { + "type": "object", + "properties": { + "id": {"type": "integer", "description": "The ID of the sensitive memory to retrieve"}, + }, + "required": ["id"], + }, + }, +] + + +class MemoryServer: + """MCP server for persistent memory management.""" + + def __init__(self, sqlite_db_path: str | None = None) -> None: + self.sqlite_conn = None + if SQLITE_FALLBACK: + self.sqlite_conn = _init_sqlite(sqlite_db_path) + + # ── HTTP-backed methods ────────────────────────────────────────── + + def memory_store(self, args: dict[str, Any]) -> str: + content = args.get("content") + if not content: + raise ValueError("content is required") + category = args.get("category", "facts") + tags = args.get("tags", "") + importance = max(0.0, min(1.0, float(args.get("importance", 0.5)))) + expanded_keywords = args.get("expanded_keywords", "") + force_sensitive = bool(args.get("force_sensitive", False)) + + if SQLITE_FALLBACK: + return self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive) + + result = _api_request("POST", "/api/memories", { + "content": content, + "category": category, + "tags": tags, + "expanded_keywords": expanded_keywords, + "importance": importance, + "force_sensitive": force_sensitive, + }) + return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}" + + def memory_recall(self, args: dict[str, Any]) -> str: + context = args.get("context") + if not context: + raise ValueError("context is required") + expanded_query = args.get("expanded_query", "") + category = args.get("category") + sort_by = args.get("sort_by", "importance") + limit = args.get("limit", 10) + + if SQLITE_FALLBACK: + return self._sqlite_recall(context, expanded_query, category, sort_by, limit) + + result = _api_request("POST", "/api/memories/recall", { + "context": context, + "expanded_query": expanded_query, + "category": category, + "sort_by": sort_by, + "limit": limit, + }) + rows = result.get("memories", []) + if not rows: + filter_desc = f" in category '{category}'" if category else "" + return f"No memories found matching: {context}{filter_desc}" + + sort_desc = "by relevance" if sort_by == "relevance" else "by importance" + filter_desc = f" in '{category}'" if category else "" + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}" + f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" + ) + return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results) + + def memory_list(self, args: dict[str, Any]) -> str: + category = args.get("category") + limit = args.get("limit", 20) + + if SQLITE_FALLBACK: + return self._sqlite_list(category, limit) + + params = f"?limit={limit}" + if category: + params += f"&category={category}" + result = _api_request("GET", f"/api/memories{params}") + rows = result.get("memories", []) + if not rows: + return f"No memories in category '{category}'" if category else "No memories stored yet" + + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] {row['content']}" + f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}" + ) + header = "Recent memories" + if category: + header += f" in '{category}'" + return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results) + + def memory_delete(self, args: dict[str, Any]) -> str: + memory_id = args.get("id") + if memory_id is None: + raise ValueError("id is required") + + if SQLITE_FALLBACK: + return self._sqlite_delete(memory_id) + + result = _api_request("DELETE", f"/api/memories/{memory_id}") + return f"Deleted memory #{result['deleted']}: {result['preview']}..." + + def secret_get(self, args: dict[str, Any]) -> str: + memory_id = args.get("id") + if memory_id is None: + raise ValueError("id is required") + + if SQLITE_FALLBACK: + return self._sqlite_secret_get(memory_id) + + result = _api_request("POST", f"/api/memories/{memory_id}/secret") + return f"#{result['id']} [{result['category']}] {result['content']}" + + # ── SQLite fallback methods ────────────────────────────────────── + + def _sqlite_store(self, content, category, tags, importance, expanded_keywords, force_sensitive=False): + from datetime import datetime, timezone + + now = datetime.now(timezone.utc).isoformat() + is_sensitive = 1 if force_sensitive else 0 + cursor = self.sqlite_conn.cursor() + cursor.execute( + "INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (content, category, tags, expanded_keywords, importance, is_sensitive, now, now), + ) + self.sqlite_conn.commit() + return f"Stored memory #{cursor.lastrowid} in category '{category}' with importance {importance:.1f}" + + def _sqlite_recall(self, context, expanded_query, category, sort_by, limit): + import sqlite3 + + all_terms = f"{context} {expanded_query}".strip() + words = all_terms.split() + fts_query = " OR ".join(f'"{w.replace(chr(34), "")}"' for w in words if w) + order = ( + "bm25(memories_fts), m.importance DESC" + if sort_by == "relevance" + else "m.importance DESC, m.created_at DESC" + ) + cursor = self.sqlite_conn.cursor() + try: + if category: + cursor.execute( + f"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at " + f"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid " + f"WHERE memories_fts MATCH ? AND m.category = ? ORDER BY {order} LIMIT ?", + (fts_query, category, limit), + ) + else: + cursor.execute( + f"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at " + f"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid " + f"WHERE memories_fts MATCH ? ORDER BY {order} LIMIT ?", + (fts_query, limit), + ) + rows = cursor.fetchall() + except sqlite3.OperationalError: + like = f"%{context}%" + if category: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?", + (like, like, category, limit), + ) + else: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?", + (like, like, limit), + ) + rows = cursor.fetchall() + + if not rows: + return f"No memories found matching: {context}" + + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}" + f"\n Tags: {row['tags'] or 'none'} | Stored: {row['created_at']}" + ) + return ( + f"Found {len(rows)} memories (by {'relevance' if sort_by == 'relevance' else 'importance'}):\n\n" + + "\n\n".join(results) + ) + + def _sqlite_list(self, category, limit): + cursor = self.sqlite_conn.cursor() + if category: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "WHERE category = ? ORDER BY created_at DESC LIMIT ?", + (category, limit), + ) + else: + cursor.execute( + "SELECT id, content, category, tags, importance, created_at FROM memories " + "ORDER BY created_at DESC LIMIT ?", + (limit,), + ) + rows = cursor.fetchall() + if not rows: + return f"No memories in category '{category}'" if category else "No memories stored yet" + + results = [] + for row in rows: + results.append( + f"#{row['id']} [{row['category']}] {row['content']}" + f"\n Importance: {row['importance']:.1f} | Tags: {row['tags'] or 'none'} | Stored: {row['created_at']}" + ) + header = "Recent memories" + (f" in '{category}'" if category else "") + return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results) + + def _sqlite_delete(self, memory_id): + cursor = self.sqlite_conn.cursor() + cursor.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,)) + row = cursor.fetchone() + if not row: + return f"Memory #{memory_id} not found" + preview = row["content"][:50] + cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) + self.sqlite_conn.commit() + return f"Deleted memory #{memory_id}: {preview}..." + + def _sqlite_secret_get(self, memory_id): + cursor = self.sqlite_conn.cursor() + cursor.execute( + "SELECT id, content, category, is_sensitive FROM memories WHERE id = ?", + (memory_id,), + ) + row = cursor.fetchone() + if not row: + return f"Memory #{memory_id} not found" + if not row["is_sensitive"]: + return f"Memory #{memory_id} is not marked as sensitive" + return f"#{row['id']} [{row['category']}] {row['content']}" + + # ── MCP protocol ───────────────────────────────────────────────── + + def handle_initialize(self, params: dict[str, Any]) -> dict[str, Any]: + return { + "protocolVersion": PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION}, + } + + def handle_tools_list(self, params: dict[str, Any]) -> dict[str, Any]: + return {"tools": TOOLS} + + def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]: + tool_name = params.get("name") + arguments = params.get("arguments", {}) + try: + handler = { + "memory_store": self.memory_store, + "memory_recall": self.memory_recall, + "memory_list": self.memory_list, + "memory_delete": self.memory_delete, + "secret_get": self.secret_get, + }.get(tool_name) + if handler is None: + return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True} + result = handler(arguments) + return {"content": [{"type": "text", "text": result}]} + except Exception as e: + return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True} + + def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None: + method = message.get("method") + params = message.get("params", {}) + msg_id = message.get("id") + if msg_id is None: + return None + result = None + error = None + try: + if method == "initialize": + result = self.handle_initialize(params) + elif method == "tools/list": + result = self.handle_tools_list(params) + elif method == "tools/call": + result = self.handle_tools_call(params) + else: + error = {"code": -32601, "message": f"Method not found: {method}"} + except Exception as e: + error = {"code": -32603, "message": str(e)} + response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id} + if error: + response["error"] = error + else: + response["result"] = result + return response + + def run(self) -> None: + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + message = json.loads(line) + except json.JSONDecodeError as e: + print( + json.dumps({ + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": f"Parse error: {e}"}, + }), + flush=True, + ) + continue + response = self.process_message(message) + if response is not None: + print(json.dumps(response), flush=True) + + +def main() -> None: + server = MemoryServer() + server.run() + + +if __name__ == "__main__": + main() diff --git a/src/claude_memory/vault_client.py b/src/claude_memory/vault_client.py new file mode 100644 index 0000000..5a82be7 --- /dev/null +++ b/src/claude_memory/vault_client.py @@ -0,0 +1,84 @@ +"""HashiCorp Vault KV v2 client using stdlib urllib.""" + +import json +import logging +import os +import urllib.error +import urllib.request +from typing import Any + +logger = logging.getLogger(__name__) + + +class VaultClient: + """Simple Vault KV v2 client using stdlib.""" + + def __init__( + self, + addr: str | None = None, + token: str | None = None, + mount: str = "secret", + ): + self.addr = (addr or os.environ.get("VAULT_ADDR", "")).rstrip("/") + self.token = token or os.environ.get("VAULT_TOKEN", "") + self.mount = mount + + if not self.addr: + raise ValueError("Vault address not configured (set VAULT_ADDR)") + + # Auto-detect Kubernetes SA token + if not self.token: + sa_token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token" + if os.path.exists(sa_token_path): + self._login_kubernetes(sa_token_path) + + def _login_kubernetes(self, sa_token_path: str) -> None: + """Authenticate with Vault using Kubernetes service account.""" + with open(sa_token_path) as f: + jwt = f.read().strip() + role = os.environ.get("VAULT_ROLE", "claude-memory") + resp = self._request("POST", "/v1/auth/kubernetes/login", {"jwt": jwt, "role": role}) + self.token = resp.get("auth", {}).get("client_token", "") + + def _request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]: + """Make HTTP request to Vault.""" + url = f"{self.addr}{path}" + data = json.dumps(body).encode() if body else None + req = urllib.request.Request( + url, data=data, method=method, + headers={"X-Vault-Token": self.token, "Content-Type": "application/json"}, + ) + try: + with urllib.request.urlopen(req, timeout=10) as resp: + return json.loads(resp.read().decode()) + except urllib.error.HTTPError as e: + if e.code == 404: + return {} + error_body = e.read().decode() if e.fp else str(e) + raise RuntimeError(f"Vault error {e.code}: {error_body}") from e + + def read(self, path: str) -> dict[str, Any] | None: + """Read a secret from KV v2.""" + resp = self._request("GET", f"/v1/{self.mount}/data/{path}") + data = resp.get("data", {}) + return data.get("data") if data else None + + def write(self, path: str, data: dict[str, Any]) -> dict[str, Any]: + """Write a secret to KV v2.""" + return self._request("POST", f"/v1/{self.mount}/data/{path}", {"data": data}) + + def delete(self, path: str) -> bool: + """Delete a secret from KV v2.""" + try: + self._request("DELETE", f"/v1/{self.mount}/data/{path}") + return True + except RuntimeError: + return False + + def list_secrets(self, path: str) -> list[str]: + """List secrets at a path.""" + try: + resp = self._request("LIST", f"/v1/{self.mount}/metadata/{path}") + return resp.get("data", {}).get("keys", []) + except RuntimeError: + return [] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..ae500bc --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,304 @@ +"""Tests for the Claude Memory API endpoints.""" + +import importlib +import os +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from claude_memory.api.auth import AuthUser + + +# Helpers to build mock asyncpg rows (they behave like dicts with attribute access) +class MockRow(dict): + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(key) + + +def _make_memory_row(**overrides): + now = datetime.now(timezone.utc) + defaults = { + "id": 1, + "user_id": "testuser", + "content": "test content", + "category": "facts", + "tags": "", + "expanded_keywords": "", + "importance": 0.5, + "is_sensitive": False, + "vault_path": None, + "encrypted_content": None, + "rank": 0.5, + "created_at": now, + "updated_at": now, + } + defaults.update(overrides) + return MockRow(defaults) + + +@pytest.fixture +def mock_pool(): + """Create a mock asyncpg pool with connection context manager.""" + pool = MagicMock() + conn = AsyncMock() + + # pool.acquire() returns an async context manager yielding conn + acm = MagicMock() + acm.__aenter__ = AsyncMock(return_value=conn) + acm.__aexit__ = AsyncMock(return_value=False) + pool.acquire.return_value = acm + + return pool, conn + + +@pytest.fixture +def test_user(): + return AuthUser(user_id="testuser") + + +@pytest.fixture +def client(mock_pool, test_user): + """Create an AsyncClient with mocked dependencies.""" + pool, conn = mock_pool + + # Reload modules with test API key + with patch.dict(os.environ, {"API_KEY": "test-key", "API_KEYS": "", "DATABASE_URL": "postgresql://test"}): + import claude_memory.api.auth as auth_mod + import claude_memory.api.database as db_mod + import claude_memory.api.app as app_mod + + importlib.reload(auth_mod) + importlib.reload(db_mod) + importlib.reload(app_mod) + + # Override database pool + db_mod.pool = pool + + # Override auth to return our test user + async def mock_get_user(authorization: str = ""): + return test_user + + app_mod.app.dependency_overrides[auth_mod.get_current_user] = mock_get_user + + transport = ASGITransport(app=app_mod.app) + return AsyncClient(transport=transport, base_url="http://test"), conn, app_mod + + +@pytest.mark.asyncio +async def test_health_endpoint_no_auth(client): + ac, conn, app_mod = client + async with ac: + resp = await ac.get("/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + +@pytest.mark.asyncio +async def test_store_memory_creates_record_with_user_id(client): + ac, conn, app_mod = client + conn.fetchrow.return_value = _make_memory_row(id=42, category="facts", importance=0.7) + + async with ac: + resp = await ac.post( + "/api/memories", + json={"content": "Python is great", "category": "facts", "importance": 0.7}, + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == 42 + assert data["category"] == "facts" + assert data["importance"] == 0.7 + + # Verify INSERT was called with user_id + call_args = conn.fetchrow.call_args + assert call_args[0][1] == "testuser" # user_id is the second positional arg + + +@pytest.mark.asyncio +async def test_recall_returns_only_user_memories(client): + ac, conn, app_mod = client + conn.fetch.return_value = [ + _make_memory_row(id=1, content="user memory", is_sensitive=False), + ] + + async with ac: + resp = await ac.post( + "/api/memories/recall", + json={"context": "test query"}, + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + results = resp.json() + assert len(results) == 1 + assert results[0]["content"] == "user memory" + + # Verify query includes user_id filter + call_args = conn.fetch.call_args + assert call_args[0][1] == "testuser" + + +@pytest.mark.asyncio +async def test_recall_redacts_sensitive_memories(client): + ac, conn, app_mod = client + conn.fetch.return_value = [ + _make_memory_row(id=5, content="[REDACTED]", is_sensitive=True), + ] + + async with ac: + resp = await ac.post( + "/api/memories/recall", + json={"context": "secrets"}, + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + results = resp.json() + assert "[SENSITIVE" in results[0]["content"] + assert "secret_get(id=5)" in results[0]["content"] + + +@pytest.mark.asyncio +async def test_list_returns_only_user_memories(client): + ac, conn, app_mod = client + conn.fetch.return_value = [ + _make_memory_row(id=1, content="mem1"), + _make_memory_row(id=2, content="mem2"), + ] + + async with ac: + resp = await ac.get( + "/api/memories", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + results = resp.json() + assert len(results) == 2 + + # Verify user_id filter + call_args = conn.fetch.call_args + assert call_args[0][1] == "testuser" + + +@pytest.mark.asyncio +async def test_delete_only_user_memories(client): + ac, conn, app_mod = client + conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None) + conn.execute.return_value = None + + async with ac: + resp = await ac.delete( + "/api/memories/10", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + assert resp.json() == {"deleted": 10} + + # Verify both SELECT and DELETE include user_id + fetchrow_args = conn.fetchrow.call_args + assert fetchrow_args[0][1] == 10 # memory_id + assert fetchrow_args[0][2] == "testuser" # user_id + + +@pytest.mark.asyncio +async def test_delete_nonexistent_memory_returns_404(client): + ac, conn, app_mod = client + conn.fetchrow.return_value = None + + async with ac: + resp = await ac.delete( + "/api/memories/999", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_secret_endpoint_returns_plaintext(client): + ac, conn, app_mod = client + conn.fetchrow.return_value = _make_memory_row( + id=7, content="my secret value", is_sensitive=False, + vault_path=None, encrypted_content=None, + ) + + async with ac: + resp = await ac.post( + "/api/memories/7/secret", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == 7 + assert data["content"] == "my secret value" + assert data["source"] == "plaintext" + + +@pytest.mark.asyncio +async def test_secret_endpoint_returns_vault_content(client): + ac, conn, app_mod = client + conn.fetchrow.return_value = _make_memory_row( + id=8, content="[REDACTED]", is_sensitive=True, + vault_path="claude-memory/testuser/mem-8", encrypted_content=None, + ) + + with patch("claude_memory.api.app.get_secret", return_value="actual-secret-from-vault"): + async with ac: + resp = await ac.post( + "/api/memories/8/secret", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["content"] == "actual-secret-from-vault" + assert data["source"] == "vault" + + +@pytest.mark.asyncio +async def test_secret_endpoint_nonexistent_returns_404(client): + ac, conn, app_mod = client + conn.fetchrow.return_value = None + + async with ac: + resp = await ac.post( + "/api/memories/999/secret", + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_import_memories(client): + ac, conn, app_mod = client + conn.fetchrow.side_effect = [ + _make_memory_row(id=100, category="facts", importance=0.5), + _make_memory_row(id=101, category="preferences", importance=0.8), + ] + + async with ac: + resp = await ac.post( + "/api/memories/import", + json=[ + {"content": "fact one", "category": "facts"}, + {"content": "pref one", "category": "preferences", "importance": 0.8}, + ], + headers={"Authorization": "Bearer test-key"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 2 + assert data[0]["id"] == 100 + assert data[1]["id"] == 101 diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..f17c923 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,87 @@ +"""Tests for multi-user authentication.""" + +import importlib +import os +from unittest.mock import patch + +import pytest +from fastapi import HTTPException + + +def _reload_auth(env_vars: dict): + """Reload the auth module with given environment variables.""" + with patch.dict(os.environ, env_vars, clear=False): + # Clear existing env vars that might interfere + for key in ("API_KEY", "API_KEYS"): + os.environ.pop(key, None) + for key, val in env_vars.items(): + os.environ[key] = val + + import claude_memory.api.auth as auth_mod + + importlib.reload(auth_mod) + return auth_mod + + +@pytest.mark.asyncio +async def test_single_api_key_maps_to_default(): + auth = _reload_auth({"API_KEY": "test-key-123", "API_KEYS": ""}) + user = await auth.get_current_user(authorization="Bearer test-key-123") + assert user.user_id == "default" + + +@pytest.mark.asyncio +async def test_multi_api_keys_maps_to_correct_user(): + auth = _reload_auth({ + "API_KEYS": '{"viktor": "key-viktor", "alice": "key-alice"}', + "API_KEY": "", + }) + user_v = await auth.get_current_user(authorization="Bearer key-viktor") + assert user_v.user_id == "viktor" + + user_a = await auth.get_current_user(authorization="Bearer key-alice") + assert user_a.user_id == "alice" + + +@pytest.mark.asyncio +async def test_invalid_key_returns_401(): + auth = _reload_auth({"API_KEY": "valid-key", "API_KEYS": ""}) + with pytest.raises(HTTPException) as exc_info: + await auth.get_current_user(authorization="Bearer wrong-key") + assert exc_info.value.status_code == 401 + + +@pytest.mark.asyncio +async def test_missing_bearer_prefix_still_works(): + auth = _reload_auth({"API_KEY": "my-key", "API_KEYS": ""}) + # Without Bearer prefix, removeprefix("Bearer ") returns "my-key" unchanged + # so the raw token still matches the key + user = await auth.get_current_user(authorization="my-key") + assert user.user_id == "default" + + # With proper Bearer prefix it also works + user = await auth.get_current_user(authorization="Bearer my-key") + assert user.user_id == "default" + + +@pytest.mark.asyncio +async def test_missing_authorization_header_raises_422(): + """FastAPI raises 422 when required Header is missing. + This is tested via the app integration, not the function directly, + since FastAPI handles the missing header before the function runs. + """ + from httpx import ASGITransport, AsyncClient + + # Need to reload with valid keys so the app can start + _reload_auth({"API_KEY": "test-key", "API_KEYS": ""}) + + # Import app after auth is configured + import claude_memory.api.app as app_mod + + importlib.reload(app_mod) + + transport = ASGITransport(app=app_mod.app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # Skip lifespan since we don't have a real DB + resp = await client.get("/api/memories") + assert resp.status_code == 422 diff --git a/tests/test_credential_detector.py b/tests/test_credential_detector.py new file mode 100644 index 0000000..9de7818 --- /dev/null +++ b/tests/test_credential_detector.py @@ -0,0 +1,132 @@ +"""Tests for credential detection and redaction.""" + +import pytest + +from claude_memory.credential_detector import ( + DetectedCredential, + detect_credentials, + is_sensitive, + redact_credentials, +) + + +class TestDetectCredentials: + def test_detect_postgres_connection_string(self): + text = "db_url = postgres://user:pass@localhost:5432/mydb" + creds = detect_credentials(text) + assert len(creds) == 1 + assert creds[0].type == "connection_string" + assert creds[0].confidence == 0.9 + assert "postgres://" in creds[0].matched_text + + def test_detect_password_assignment(self): + text = 'password = "my_super_secret_pw"' + creds = detect_credentials(text) + assert len(creds) >= 1 + types = [c.type for c in creds] + assert "password" in types + + def test_detect_api_key(self): + text = "api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" + creds = detect_credentials(text) + assert len(creds) >= 1 + types = [c.type for c in creds] + assert "api_key" in types + + def test_detect_private_key(self): + text = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3xfn/ygWep4PAtGoSo\n-----END RSA PRIVATE KEY-----" + creds = detect_credentials(text) + assert len(creds) == 1 + assert creds[0].type == "private_key" + assert creds[0].confidence == 0.95 + + def test_detect_bearer_token(self): + text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkw" + creds = detect_credentials(text) + assert len(creds) >= 1 + types = [c.type for c in creds] + assert "bearer_token" in types + + def test_detect_aws_key(self): + text = "aws_access_key_id = AKIAIOSFODNN7EXAMPLE" + creds = detect_credentials(text) + assert len(creds) >= 1 + types = [c.type for c in creds] + assert "aws_key" in types + + def test_detect_github_token(self): + text = "GITHUB_TOKEN=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn" + creds = detect_credentials(text) + assert len(creds) >= 1 + types = [c.type for c in creds] + assert "github_token" in types + + def test_no_false_positives_on_normal_text(self): + text = "This is a normal paragraph about programming. It discusses variables, functions, and classes." + creds = detect_credentials(text) + assert len(creds) == 0 + + def test_no_false_positives_on_short_password(self): + # password values shorter than 8 chars should not match + text = 'password = "short"' + creds = detect_credentials(text) + assert len(creds) == 0 + + def test_min_confidence_filtering(self): + text = 'secret = "abcdefghijklmnopqrstuvwxyz"' + all_creds = detect_credentials(text, min_confidence=0.5) + high_creds = detect_credentials(text, min_confidence=0.9) + assert len(all_creds) >= len(high_creds) + + def test_overlapping_matches_keep_highest_confidence(self): + # A text that could match both token and generic_secret + text = 'secret = "abcdefghijklmnopqrstuvwxyz1234567890"' + creds = detect_credentials(text, min_confidence=0.5) + # Should not have overlapping ranges for the same span + for i, c1 in enumerate(creds): + for c2 in creds[i + 1:]: + # No credential should be fully contained within another + assert not (c1.start <= c2.start and c1.end >= c2.end) + + +class TestRedactCredentials: + def test_redaction_replaces_with_marker(self): + text = "db_url = postgres://user:pass@localhost:5432/mydb" + creds = detect_credentials(text) + redacted = redact_credentials(text, creds) + assert "[REDACTED:connection_string]" in redacted + assert "postgres://" not in redacted + + def test_redaction_preserves_surrounding_text(self): + text = "before postgres://user:pass@localhost/db after" + creds = detect_credentials(text) + redacted = redact_credentials(text, creds) + assert redacted.startswith("before ") + assert redacted.endswith(" after") + + def test_redaction_no_credentials(self): + text = "nothing sensitive here" + redacted = redact_credentials(text, []) + assert redacted == text + + def test_redaction_multiple_credentials(self): + text = 'password = "mysecretpw123" and api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890' + creds = detect_credentials(text) + redacted = redact_credentials(text, creds) + assert "mysecretpw123" not in redacted + assert "[REDACTED:" in redacted + + +class TestIsSensitive: + def test_sensitive_text(self): + assert is_sensitive("password = supersecretvalue123") + + def test_non_sensitive_text(self): + assert not is_sensitive("just a normal log message") + + def test_respects_min_confidence(self): + text = 'secret = "abcdefghijklmnopqrstuvwxyz"' + # Low confidence should detect + assert is_sensitive(text, min_confidence=0.5) + # Very high confidence should not detect generic_secret + assert not is_sensitive(text, min_confidence=0.95) diff --git a/tests/test_crypto.py b/tests/test_crypto.py new file mode 100644 index 0000000..4470abd --- /dev/null +++ b/tests/test_crypto.py @@ -0,0 +1,134 @@ +"""Tests for AES-256-GCM encryption module.""" + +import hashlib +import os + +import pytest + +from claude_memory.crypto import ( + ENCRYPTION_KEY_ENV, + decrypt, + decrypt_b64, + encrypt, + encrypt_b64, + is_encryption_configured, +) + +# A valid 32-byte hex key for testing +TEST_HEX_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" +TEST_PASSPHRASE = "my-test-passphrase" + + +@pytest.fixture +def hex_key_env(monkeypatch): + monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_HEX_KEY) + + +@pytest.fixture +def passphrase_env(monkeypatch): + monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_PASSPHRASE) + + +@pytest.fixture +def no_key_env(monkeypatch): + monkeypatch.delenv(ENCRYPTION_KEY_ENV, raising=False) + + +class TestEncryptionConfigured: + def test_configured_with_hex_key(self, hex_key_env): + assert is_encryption_configured() is True + + def test_configured_with_passphrase(self, passphrase_env): + assert is_encryption_configured() is True + + def test_not_configured_without_env(self, no_key_env): + assert is_encryption_configured() is False + + +class TestEncryptDecrypt: + def test_roundtrip_with_hex_key(self, hex_key_env): + plaintext = "Hello, this is a secret message!" + encrypted = encrypt(plaintext) + decrypted = decrypt(encrypted) + assert decrypted == plaintext + + def test_roundtrip_with_passphrase(self, passphrase_env): + plaintext = "Another secret message with passphrase key" + encrypted = encrypt(plaintext) + decrypted = decrypt(encrypted) + assert decrypted == plaintext + + def test_different_plaintexts_produce_different_ciphertexts(self, hex_key_env): + ct1 = encrypt("message one") + ct2 = encrypt("message two") + assert ct1 != ct2 + + def test_same_plaintext_produces_different_ciphertexts(self, hex_key_env): + """Due to random nonce, encrypting the same text twice gives different results.""" + ct1 = encrypt("same message") + ct2 = encrypt("same message") + assert ct1 != ct2 + + def test_missing_key_raises_on_encrypt(self, no_key_env): + with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV): + encrypt("test") + + def test_missing_key_raises_on_decrypt(self, no_key_env): + with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV): + decrypt(b"\x00" * 28) + + def test_decrypt_with_wrong_key_fails(self, hex_key_env, monkeypatch): + plaintext = "secret data" + encrypted = encrypt(plaintext) + + # Change to a different key + monkeypatch.setenv(ENCRYPTION_KEY_ENV, "ff" * 32) + with pytest.raises(Exception): + decrypt(encrypted) + + def test_encrypted_data_format(self, hex_key_env): + """Encrypted data should be at least 12 (nonce) + 16 (tag) bytes.""" + encrypted = encrypt("x") + assert len(encrypted) >= 28 # 12 nonce + 1 plaintext + 16 tag = 29 minimum + + def test_unicode_roundtrip(self, hex_key_env): + plaintext = "Unicode test: cafe\u0301, \u00fc\u00f6\u00e4, \U0001f512" + decrypted = decrypt(encrypt(plaintext)) + assert decrypted == plaintext + + +class TestBase64Variants: + def test_b64_roundtrip(self, hex_key_env): + plaintext = "base64 test message" + encrypted_b64 = encrypt_b64(plaintext) + assert isinstance(encrypted_b64, str) + decrypted = decrypt_b64(encrypted_b64) + assert decrypted == plaintext + + def test_b64_output_is_valid_base64(self, hex_key_env): + import base64 + encrypted_b64 = encrypt_b64("test") + # Should not raise + decoded = base64.b64decode(encrypted_b64) + assert len(decoded) >= 28 + + +class TestKeyDerivation: + def test_hex_key_used_directly(self, hex_key_env): + """A valid 64-char hex string should be used as-is (32 bytes).""" + ct = encrypt("test") + pt = decrypt(ct) + assert pt == "test" + + def test_passphrase_derived_via_sha256(self, passphrase_env): + """Non-hex strings should be derived via SHA-256.""" + ct = encrypt("test") + pt = decrypt(ct) + assert pt == "test" + + def test_short_hex_treated_as_passphrase(self, monkeypatch): + """Hex string that's not exactly 32 bytes should be treated as passphrase.""" + monkeypatch.setenv(ENCRYPTION_KEY_ENV, "abcd1234") + ct = encrypt("test") + pt = decrypt(ct) + assert pt == "test" diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..3a429e4 --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,342 @@ +"""Tests for the Claude Memory MCP server.""" + +import json +import os +import sys + +import pytest + +# Force SQLite fallback mode for all tests +os.environ.pop("MEMORY_API_KEY", None) +os.environ.pop("CLAUDE_MEMORY_API_KEY", None) + +# Add src to path so we can import without installing +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from claude_memory.mcp_server import MemoryServer, TOOLS, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION + + +@pytest.fixture +def server(tmp_path): + """Create a MemoryServer with a temporary SQLite database.""" + db_path = str(tmp_path / "test_memory.db") + srv = MemoryServer(sqlite_db_path=db_path) + yield srv + if srv.sqlite_conn: + srv.sqlite_conn.close() + + +class TestSQLiteInit: + def test_creates_database(self, tmp_path): + db_path = str(tmp_path / "sub" / "test.db") + srv = MemoryServer(sqlite_db_path=db_path) + assert os.path.exists(db_path) + # Verify tables exist + cursor = srv.sqlite_conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'") + assert cursor.fetchone() is not None + srv.sqlite_conn.close() + + def test_creates_fts_table(self, tmp_path): + db_path = str(tmp_path / "test.db") + srv = MemoryServer(sqlite_db_path=db_path) + cursor = srv.sqlite_conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'") + assert cursor.fetchone() is not None + srv.sqlite_conn.close() + + +class TestMemoryStore: + def test_store_basic(self, server): + result = server.memory_store({ + "content": "User prefers dark mode", + "expanded_keywords": "dark mode theme preference ui", + }) + assert "Stored memory #1" in result + assert "facts" in result + + def test_store_with_category(self, server): + result = server.memory_store({ + "content": "User likes Python", + "category": "preferences", + "expanded_keywords": "python programming language preference", + }) + assert "preferences" in result + + def test_store_with_importance(self, server): + result = server.memory_store({ + "content": "Critical info", + "importance": 0.9, + "expanded_keywords": "critical important info", + }) + assert "0.9" in result + + def test_store_requires_content(self, server): + with pytest.raises(ValueError, match="content is required"): + server.memory_store({"expanded_keywords": "test"}) + + def test_store_force_sensitive(self, server): + result = server.memory_store({ + "content": "API key: sk-1234", + "force_sensitive": True, + "expanded_keywords": "api key secret credential", + }) + assert "Stored memory #1" in result + # Verify is_sensitive flag is set + cursor = server.sqlite_conn.cursor() + cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1") + row = cursor.fetchone() + assert row["is_sensitive"] == 1 + + +class TestMemoryRecall: + def test_recall_finds_memory(self, server): + server.memory_store({ + "content": "User works at Acme Corp", + "expanded_keywords": "acme corp company work employer", + }) + result = server.memory_recall({ + "context": "work", + "expanded_query": "company employer job", + }) + assert "Acme Corp" in result + assert "Found 1 memories" in result + + def test_recall_no_results(self, server): + result = server.memory_recall({ + "context": "nonexistent topic", + "expanded_query": "nothing here at all", + }) + assert "No memories found" in result + + def test_recall_with_category_filter(self, server): + server.memory_store({ + "content": "User prefers vim", + "category": "preferences", + "expanded_keywords": "vim editor preference text", + }) + server.memory_store({ + "content": "Project uses React", + "category": "projects", + "expanded_keywords": "react project frontend framework", + }) + result = server.memory_recall({ + "context": "preferences", + "expanded_query": "vim editor", + "category": "preferences", + }) + assert "vim" in result + assert "React" not in result + + def test_recall_requires_context(self, server): + with pytest.raises(ValueError, match="context is required"): + server.memory_recall({"expanded_query": "test"}) + + +class TestMemoryList: + def test_list_empty(self, server): + result = server.memory_list({}) + assert "No memories stored yet" in result + + def test_list_with_memories(self, server): + server.memory_store({ + "content": "Memory one", + "expanded_keywords": "one first test", + }) + server.memory_store({ + "content": "Memory two", + "expanded_keywords": "two second test", + }) + result = server.memory_list({}) + assert "Memory one" in result + assert "Memory two" in result + assert "2 shown" in result + + def test_list_with_category(self, server): + server.memory_store({ + "content": "A fact", + "category": "facts", + "expanded_keywords": "fact test", + }) + server.memory_store({ + "content": "A preference", + "category": "preferences", + "expanded_keywords": "preference test", + }) + result = server.memory_list({"category": "facts"}) + assert "A fact" in result + assert "A preference" not in result + + def test_list_empty_category(self, server): + result = server.memory_list({"category": "projects"}) + assert "No memories in category 'projects'" in result + + def test_list_respects_limit(self, server): + for i in range(5): + server.memory_store({ + "content": f"Memory {i}", + "expanded_keywords": f"memory number {i}", + }) + result = server.memory_list({"limit": 2}) + assert "2 shown" in result + + +class TestMemoryDelete: + def test_delete_existing(self, server): + server.memory_store({ + "content": "To be deleted", + "expanded_keywords": "delete remove test", + }) + result = server.memory_delete({"id": 1}) + assert "Deleted memory #1" in result + assert "To be deleted" in result + + def test_delete_nonexistent(self, server): + result = server.memory_delete({"id": 999}) + assert "not found" in result + + def test_delete_requires_id(self, server): + with pytest.raises(ValueError, match="id is required"): + server.memory_delete({}) + + +class TestSecretGet: + def test_secret_get_sensitive(self, server): + server.memory_store({ + "content": "secret password 12345", + "force_sensitive": True, + "expanded_keywords": "password secret credential", + }) + result = server.secret_get({"id": 1}) + assert "secret password 12345" in result + + def test_secret_get_not_sensitive(self, server): + server.memory_store({ + "content": "public info", + "expanded_keywords": "public info test", + }) + result = server.secret_get({"id": 1}) + assert "not marked as sensitive" in result + + def test_secret_get_nonexistent(self, server): + result = server.secret_get({"id": 999}) + assert "not found" in result + + def test_secret_get_requires_id(self, server): + with pytest.raises(ValueError, match="id is required"): + server.secret_get({}) + + +class TestMCPProtocol: + def test_handle_initialize(self, server): + result = server.handle_initialize({}) + assert result["protocolVersion"] == PROTOCOL_VERSION + assert result["serverInfo"]["name"] == SERVER_NAME + assert result["serverInfo"]["version"] == SERVER_VERSION + assert "tools" in result["capabilities"] + + def test_handle_tools_list(self, server): + result = server.handle_tools_list({}) + tools = result["tools"] + assert len(tools) == 5 + names = {t["name"] for t in tools} + assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"} + + def test_handle_tools_call_store(self, server): + result = server.handle_tools_call({ + "name": "memory_store", + "arguments": { + "content": "test memory", + "expanded_keywords": "test memory keywords", + }, + }) + assert not result.get("isError", False) + assert "Stored memory" in result["content"][0]["text"] + + def test_handle_tools_call_unknown(self, server): + result = server.handle_tools_call({ + "name": "nonexistent_tool", + "arguments": {}, + }) + assert result["isError"] is True + assert "Unknown tool" in result["content"][0]["text"] + + def test_handle_tools_call_error(self, server): + result = server.handle_tools_call({ + "name": "memory_store", + "arguments": {}, # missing content + }) + assert result["isError"] is True + assert "Error" in result["content"][0]["text"] + + +class TestProcessMessage: + def test_initialize(self, server): + response = server.process_message({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {}, + }) + assert response["jsonrpc"] == "2.0" + assert response["id"] == 1 + assert "result" in response + assert response["result"]["serverInfo"]["name"] == SERVER_NAME + + def test_tools_list(self, server): + response = server.process_message({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {}, + }) + assert "result" in response + assert len(response["result"]["tools"]) == 5 + + def test_tools_call(self, server): + response = server.process_message({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "memory_store", + "arguments": { + "content": "via process_message", + "expanded_keywords": "process message test", + }, + }, + }) + assert "result" in response + assert "Stored memory" in response["result"]["content"][0]["text"] + + def test_unknown_method(self, server): + response = server.process_message({ + "jsonrpc": "2.0", + "id": 4, + "method": "unknown/method", + "params": {}, + }) + assert "error" in response + assert response["error"]["code"] == -32601 + assert "Method not found" in response["error"]["message"] + + def test_notification_no_id(self, server): + response = server.process_message({ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + }) + assert response is None + + def test_jsonrpc_response_format(self, server): + response = server.process_message({ + "jsonrpc": "2.0", + "id": 5, + "method": "initialize", + "params": {}, + }) + # Verify it's valid JSON when serialized + serialized = json.dumps(response) + parsed = json.loads(serialized) + assert parsed["jsonrpc"] == "2.0" + assert parsed["id"] == 5 diff --git a/tests/test_vault_client.py b/tests/test_vault_client.py new file mode 100644 index 0000000..372781a --- /dev/null +++ b/tests/test_vault_client.py @@ -0,0 +1,154 @@ +"""Tests for Vault KV v2 client with mocked urllib.""" + +import json +import os +from io import BytesIO +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from claude_memory.vault_client import VaultClient + + +@pytest.fixture +def vault_env(monkeypatch): + monkeypatch.setenv("VAULT_ADDR", "http://vault.example.com:8200") + monkeypatch.setenv("VAULT_TOKEN", "s.testtoken123") + + +class TestVaultClientInit: + def test_missing_addr_raises_value_error(self, monkeypatch): + monkeypatch.delenv("VAULT_ADDR", raising=False) + monkeypatch.delenv("VAULT_TOKEN", raising=False) + with pytest.raises(ValueError, match="Vault address not configured"): + VaultClient() + + def test_init_with_explicit_args(self): + client = VaultClient(addr="http://localhost:8200", token="mytoken") + assert client.addr == "http://localhost:8200" + assert client.token == "mytoken" + assert client.mount == "secret" + + def test_init_from_env(self, vault_env): + client = VaultClient() + assert client.addr == "http://vault.example.com:8200" + assert client.token == "s.testtoken123" + + def test_addr_trailing_slash_stripped(self): + client = VaultClient(addr="http://localhost:8200/", token="t") + assert client.addr == "http://localhost:8200" + + @patch("os.path.exists", return_value=True) + @patch("builtins.open", mock_open(read_data="fake-jwt-token")) + @patch("urllib.request.urlopen") + def test_kubernetes_sa_token_auto_detection(self, mock_urlopen, mock_exists, monkeypatch): + monkeypatch.setenv("VAULT_ADDR", "http://vault:8200") + monkeypatch.delenv("VAULT_TOKEN", raising=False) + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps({ + "auth": {"client_token": "s.k8s-token-abc"} + }).encode() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + client = VaultClient() + assert client.token == "s.k8s-token-abc" + + +class TestVaultRead: + @patch("urllib.request.urlopen") + def test_read_secret_returns_data(self, mock_urlopen, vault_env): + mock_response = MagicMock() + mock_response.read.return_value = json.dumps({ + "data": {"data": {"username": "admin", "password": "secret"}} + }).encode() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + client = VaultClient() + result = client.read("myapp/config") + assert result == {"username": "admin", "password": "secret"} + + @patch("urllib.request.urlopen") + def test_read_returns_none_for_404(self, mock_urlopen, vault_env): + import urllib.error + mock_urlopen.side_effect = urllib.error.HTTPError( + url="http://vault:8200/v1/secret/data/missing", + code=404, + msg="Not Found", + hdrs={}, + fp=BytesIO(b""), + ) + + client = VaultClient() + result = client.read("missing/path") + assert result is None + + +class TestVaultWrite: + @patch("urllib.request.urlopen") + def test_write_secret_sends_correct_request(self, mock_urlopen, vault_env): + mock_response = MagicMock() + mock_response.read.return_value = json.dumps({ + "data": {"created_time": "2024-01-01T00:00:00Z", "version": 1} + }).encode() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + client = VaultClient() + result = client.write("myapp/config", {"key": "value"}) + + # Verify the request was made with correct data + call_args = mock_urlopen.call_args + request = call_args[0][0] + assert request.full_url == "http://vault.example.com:8200/v1/secret/data/myapp/config" + assert request.method == "POST" + body = json.loads(request.data.decode()) + assert body == {"data": {"key": "value"}} + + +class TestVaultDelete: + @patch("urllib.request.urlopen") + def test_delete_returns_true_on_success(self, mock_urlopen, vault_env): + mock_response = MagicMock() + mock_response.read.return_value = b"{}" + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + client = VaultClient() + assert client.delete("myapp/config") is True + + @patch("urllib.request.urlopen") + def test_delete_returns_false_on_error(self, mock_urlopen, vault_env): + import urllib.error + mock_urlopen.side_effect = urllib.error.HTTPError( + url="http://vault:8200/v1/secret/data/missing", + code=500, + msg="Internal Server Error", + hdrs={}, + fp=BytesIO(b"error"), + ) + + client = VaultClient() + assert client.delete("missing/path") is False + + +class TestVaultListSecrets: + @patch("urllib.request.urlopen") + def test_list_secrets(self, mock_urlopen, vault_env): + mock_response = MagicMock() + mock_response.read.return_value = json.dumps({ + "data": {"keys": ["secret1", "secret2/"]} + }).encode() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + client = VaultClient() + result = client.list_secrets("myapp") + assert result == ["secret1", "secret2/"]