feat: standalone claude-memory-mcp with multi-user support and Vault integration
Extracted from private infra repo into standalone open-source project. Three operating modes: - Local: SQLite + FTS5 (zero dependencies) - Server: PostgreSQL via HTTP API with multi-user auth - Full: PostgreSQL + HashiCorp Vault for secret management Features: - MCP stdio server with 5 tools (store/recall/list/delete/secret_get) - FastAPI HTTP API with multi-user Bearer token auth (API_KEYS JSON map) - Regex-based credential detection with auto-redaction - AES-256-GCM encryption fallback for non-Vault deployments - Vault KV v2 client (stdlib urllib, K8s SA auto-auth) - Per-user data isolation (all queries scoped by user_id) - Secret migration endpoint for existing plain-text credentials - Backward-compatible env var aliases (CLAUDE_MEMORY_API_URL) Infrastructure: - Docker + docker-compose (API + PostgreSQL + optional Vault) - Woodpecker CI (test → build → push → kubectl deploy) - GitHub Actions CI (Python 3.11/3.12/3.13) + Release (GHCR + PyPI) - Helm chart + raw Kubernetes manifests 96 tests passing across 6 test files.
This commit is contained in:
commit
0ed5e1e016
40 changed files with 3381 additions and 0 deletions
36
.github/workflows/ci.yml
vendored
Normal file
36
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- run: pip install -e ".[api,dev]"
|
||||
- run: ruff check src/ tests/
|
||||
- run: mypy src/claude_memory/
|
||||
- run: pytest tests/ -v --tb=short
|
||||
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
- uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
push: false
|
||||
tags: claude-memory-mcp:test
|
||||
40
.github/workflows/release.yml
vendored
Normal file
40
.github/workflows/release.yml
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags: ["v*"]
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
- uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
push: true
|
||||
tags: |
|
||||
ghcr.io/${{ github.repository }}:${{ github.ref_name }}
|
||||
ghcr.io/${{ github.repository }}:latest
|
||||
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
id-token: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: pip install build
|
||||
- run: python -m build
|
||||
- uses: pypa/gh-action-pypi-publish@release/v1
|
||||
45
.gitignore
vendored
Normal file
45
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
*.egg-info/
|
||||
*.egg
|
||||
dist/
|
||||
build/
|
||||
.eggs/
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.production
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Docker
|
||||
docker/pgdata/
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite3
|
||||
43
.woodpecker.yml
Normal file
43
.woodpecker.yml
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
when:
|
||||
- event: push
|
||||
branch: main
|
||||
|
||||
clone:
|
||||
git:
|
||||
image: woodpeckerci/plugin-git
|
||||
settings:
|
||||
attempts: 5
|
||||
backoff: 10s
|
||||
|
||||
steps:
|
||||
- name: test
|
||||
image: python:3.12-slim
|
||||
commands:
|
||||
- pip install -e ".[api,dev]"
|
||||
- ruff check src/ tests/
|
||||
- pytest tests/ -v --tb=short
|
||||
|
||||
- name: build-and-push
|
||||
image: woodpeckerci/plugin-docker-buildx
|
||||
depends_on:
|
||||
- test
|
||||
settings:
|
||||
username: viktorbarzin
|
||||
password:
|
||||
from_secret: dockerhub-token
|
||||
repo: viktorbarzin/claude-memory-mcp
|
||||
dockerfile: docker/Dockerfile
|
||||
context: .
|
||||
platforms:
|
||||
- linux/amd64
|
||||
tags:
|
||||
- "${CI_PIPELINE_NUMBER}"
|
||||
- latest
|
||||
|
||||
- name: deploy
|
||||
image: bitnami/kubectl:latest
|
||||
depends_on:
|
||||
- build-and-push
|
||||
commands:
|
||||
- kubectl rollout restart deployment/claude-memory -n claude-memory
|
||||
- kubectl rollout status deployment/claude-memory -n claude-memory --timeout=120s
|
||||
191
LICENSE
Normal file
191
LICENSE
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to the Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by the Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding any notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Copyright 2026 Viktor Barzin
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
224
README.md
Normal file
224
README.md
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
# Claude Memory MCP
|
||||
|
||||
A persistent memory layer for Claude Code that stores knowledge across sessions. Operates as an MCP (Model Context Protocol) server with optional API and database backends.
|
||||
|
||||
## Operating Modes
|
||||
|
||||
| Mode | Storage | Auth | Use Case |
|
||||
|------|---------|------|----------|
|
||||
| **Local** | SQLite file | None | Single user, local Claude Code |
|
||||
| **Server** | SQLite file | API key | Remote access, single user |
|
||||
| **Full** | PostgreSQL | API keys + Vault | Multi-user, team deployment |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Local Mode (MCP stdio)
|
||||
|
||||
Install and configure Claude Code to use the MCP server directly:
|
||||
|
||||
```bash
|
||||
pip install claude-memory-mcp
|
||||
```
|
||||
|
||||
Add to your Claude Code MCP config (`~/.claude/plugins/`):
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"command": "python",
|
||||
"args": ["-m", "claude_memory.mcp_server"],
|
||||
"env": {
|
||||
"MEMORY_DB_PATH": "~/.claude/memory.db"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Server Mode (API)
|
||||
|
||||
Run the API server with SQLite:
|
||||
|
||||
```bash
|
||||
pip install claude-memory-mcp[api]
|
||||
|
||||
export DATABASE_URL="sqlite:///./memory.db"
|
||||
export API_KEY="your-secret-key"
|
||||
|
||||
uvicorn claude_memory.api.app:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
Configure Claude Code to connect via HTTP:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"command": "python",
|
||||
"args": ["-m", "claude_memory.mcp_server"],
|
||||
"env": {
|
||||
"MEMORY_API_URL": "http://localhost:8000",
|
||||
"MEMORY_API_KEY": "your-secret-key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Full Mode (Docker Compose)
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
This starts the API server with PostgreSQL. See [Docker Compose](#docker-compose) for details.
|
||||
|
||||
## Docker Compose
|
||||
|
||||
The dev environment includes the API server and PostgreSQL:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
To include HashiCorp Vault for secret management:
|
||||
|
||||
```bash
|
||||
docker compose --profile vault up -d
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `DATABASE_URL` | Database connection string | `sqlite:///./memory.db` |
|
||||
| `API_KEY` | Single-user API key | None |
|
||||
| `API_KEYS` | Multi-user JSON map `{"user": "key"}` | None |
|
||||
| `VAULT_ADDR` | Vault server address | None |
|
||||
| `VAULT_TOKEN` | Vault authentication token | None |
|
||||
|
||||
## Multi-User Setup
|
||||
|
||||
For team deployments, use the `API_KEYS` environment variable with a JSON mapping of usernames to API keys:
|
||||
|
||||
```bash
|
||||
export API_KEYS='{"alice": "key-alice-xxx", "bob": "key-bob-yyy"}'
|
||||
```
|
||||
|
||||
Each user gets isolated memory storage. The username is extracted from the API key on each request.
|
||||
|
||||
## Vault Integration
|
||||
|
||||
For production deployments, store API keys in HashiCorp Vault instead of environment variables:
|
||||
|
||||
```bash
|
||||
export VAULT_ADDR="https://vault.example.com"
|
||||
export VAULT_TOKEN="s.xxxxxxxxxxxx"
|
||||
```
|
||||
|
||||
The server reads API keys from the Vault KV store at `secret/claude-memory/api-keys`.
|
||||
|
||||
## API Reference
|
||||
|
||||
### Health Check
|
||||
|
||||
```
|
||||
GET /health
|
||||
```
|
||||
|
||||
### Store Memory
|
||||
|
||||
```
|
||||
POST /api/v1/memories
|
||||
Authorization: Bearer <api-key>
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"content": "The user prefers dark mode",
|
||||
"tags": ["preferences", "ui"],
|
||||
"source": "conversation"
|
||||
}
|
||||
```
|
||||
|
||||
### Recall Memories
|
||||
|
||||
```
|
||||
GET /api/v1/memories?q=dark+mode&limit=10
|
||||
Authorization: Bearer <api-key>
|
||||
```
|
||||
|
||||
### List Memories
|
||||
|
||||
```
|
||||
GET /api/v1/memories?tags=preferences&limit=20
|
||||
Authorization: Bearer <api-key>
|
||||
```
|
||||
|
||||
### Delete Memory
|
||||
|
||||
```
|
||||
DELETE /api/v1/memories/{id}
|
||||
Authorization: Bearer <api-key>
|
||||
```
|
||||
|
||||
## Kubernetes Deployment
|
||||
|
||||
### Helm
|
||||
|
||||
```bash
|
||||
helm install claude-memory deploy/helm/claude-memory \
|
||||
--set env.DATABASE_URL="postgresql://user:pass@host:5432/db" \
|
||||
--set env.API_KEY="your-key" \
|
||||
--set ingress.host="claude-memory.yourdomain.com"
|
||||
```
|
||||
|
||||
### Raw Manifests
|
||||
|
||||
```bash
|
||||
kubectl apply -f deploy/kubernetes/namespace.yaml
|
||||
# Create secret with your credentials first:
|
||||
kubectl create secret generic claude-memory-secrets \
|
||||
-n claude-memory \
|
||||
--from-literal=database-url="postgresql://user:pass@host:5432/db" \
|
||||
--from-literal=api-key="your-key"
|
||||
kubectl apply -f deploy/kubernetes/
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
git clone https://github.com/viktorbarzin/claude-memory-mcp.git
|
||||
cd claude-memory-mcp
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -e ".[api,dev]"
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
### Linting
|
||||
|
||||
```bash
|
||||
ruff check src/ tests/
|
||||
mypy src/claude_memory/
|
||||
```
|
||||
|
||||
### Building
|
||||
|
||||
```bash
|
||||
pip install build
|
||||
python -m build
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Apache License 2.0. See [LICENSE](LICENSE) for details.
|
||||
6
deploy/helm/claude-memory/Chart.yaml
Normal file
6
deploy/helm/claude-memory/Chart.yaml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
apiVersion: v2
|
||||
name: claude-memory
|
||||
description: Claude Memory MCP API server
|
||||
type: application
|
||||
version: 1.0.0
|
||||
appVersion: "1.0.0"
|
||||
35
deploy/helm/claude-memory/templates/deployment.yaml
Normal file
35
deploy/helm/claude-memory/templates/deployment.yaml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ .Release.Name }}
|
||||
labels:
|
||||
app: {{ .Release.Name }}
|
||||
spec:
|
||||
replicas: {{ .Values.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Release.Name }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: {{ .Release.Name }}
|
||||
spec:
|
||||
containers:
|
||||
- name: {{ .Release.Name }}
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
|
||||
imagePullPolicy: {{ .Values.image.pullPolicy }}
|
||||
ports:
|
||||
- containerPort: {{ .Values.service.targetPort }}
|
||||
{{- range $key, $value := .Values.env }}
|
||||
{{- if $value }}
|
||||
env:
|
||||
- name: {{ $key }}
|
||||
value: {{ $value | quote }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
livenessProbe:
|
||||
{{- toYaml .Values.livenessProbe | nindent 12 }}
|
||||
readinessProbe:
|
||||
{{- toYaml .Values.readinessProbe | nindent 12 }}
|
||||
resources:
|
||||
{{- toYaml .Values.resources | nindent 12 }}
|
||||
25
deploy/helm/claude-memory/templates/ingress.yaml
Normal file
25
deploy/helm/claude-memory/templates/ingress.yaml
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
{{- if .Values.ingress.enabled }}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ .Release.Name }}
|
||||
annotations:
|
||||
nginx.ingress.kubernetes.io/ssl-redirect: "true"
|
||||
spec:
|
||||
ingressClassName: {{ .Values.ingress.className }}
|
||||
tls:
|
||||
- hosts:
|
||||
- {{ .Values.ingress.host }}
|
||||
secretName: {{ .Values.ingress.tls.secretName }}
|
||||
rules:
|
||||
- host: {{ .Values.ingress.host }}
|
||||
http:
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: {{ .Release.Name }}
|
||||
port:
|
||||
number: {{ .Values.service.port }}
|
||||
{{- end }}
|
||||
12
deploy/helm/claude-memory/templates/service.yaml
Normal file
12
deploy/helm/claude-memory/templates/service.yaml
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ .Release.Name }}
|
||||
spec:
|
||||
type: {{ .Values.service.type }}
|
||||
ports:
|
||||
- port: {{ .Values.service.port }}
|
||||
targetPort: {{ .Values.service.targetPort }}
|
||||
protocol: TCP
|
||||
selector:
|
||||
app: {{ .Release.Name }}
|
||||
46
deploy/helm/claude-memory/values.yaml
Normal file
46
deploy/helm/claude-memory/values.yaml
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
replicaCount: 1
|
||||
|
||||
image:
|
||||
repository: viktorbarzin/claude-memory-mcp
|
||||
tag: latest
|
||||
pullPolicy: Always
|
||||
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 80
|
||||
targetPort: 8000
|
||||
|
||||
ingress:
|
||||
enabled: true
|
||||
className: nginx
|
||||
host: claude-memory.example.com
|
||||
tls:
|
||||
secretName: tls-secret
|
||||
|
||||
resources:
|
||||
requests:
|
||||
memory: 32Mi
|
||||
cpu: 10m
|
||||
limits:
|
||||
memory: 128Mi
|
||||
|
||||
env:
|
||||
DATABASE_URL: ""
|
||||
API_KEY: ""
|
||||
# API_KEYS: '{}'
|
||||
# VAULT_ADDR: ""
|
||||
# VAULT_TOKEN: ""
|
||||
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8000
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 30
|
||||
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8000
|
||||
initialDelaySeconds: 3
|
||||
periodSeconds: 10
|
||||
52
deploy/kubernetes/deployment.yaml
Normal file
52
deploy/kubernetes/deployment.yaml
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: claude-memory
|
||||
namespace: claude-memory
|
||||
labels:
|
||||
app: claude-memory
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: claude-memory
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: claude-memory
|
||||
spec:
|
||||
containers:
|
||||
- name: claude-memory
|
||||
image: viktorbarzin/claude-memory-mcp:latest
|
||||
imagePullPolicy: Always
|
||||
ports:
|
||||
- containerPort: 8000
|
||||
env:
|
||||
- name: DATABASE_URL
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: claude-memory-secrets
|
||||
key: database-url
|
||||
- name: API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: claude-memory-secrets
|
||||
key: api-key
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8000
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 30
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8000
|
||||
initialDelaySeconds: 3
|
||||
periodSeconds: 10
|
||||
resources:
|
||||
requests:
|
||||
memory: 32Mi
|
||||
cpu: 10m
|
||||
limits:
|
||||
memory: 128Mi
|
||||
24
deploy/kubernetes/ingress.yaml
Normal file
24
deploy/kubernetes/ingress.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: claude-memory
|
||||
namespace: claude-memory
|
||||
annotations:
|
||||
nginx.ingress.kubernetes.io/ssl-redirect: "true"
|
||||
spec:
|
||||
ingressClassName: nginx
|
||||
tls:
|
||||
- hosts:
|
||||
- claude-memory.example.com
|
||||
secretName: tls-secret
|
||||
rules:
|
||||
- host: claude-memory.example.com
|
||||
http:
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: claude-memory
|
||||
port:
|
||||
number: 80
|
||||
4
deploy/kubernetes/namespace.yaml
Normal file
4
deploy/kubernetes/namespace.yaml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: claude-memory
|
||||
13
deploy/kubernetes/service.yaml
Normal file
13
deploy/kubernetes/service.yaml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: claude-memory
|
||||
namespace: claude-memory
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 8000
|
||||
protocol: TCP
|
||||
selector:
|
||||
app: claude-memory
|
||||
15
docker/Dockerfile
Normal file
15
docker/Dockerfile
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
FROM python:3.12-slim AS base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml .
|
||||
COPY src/ src/
|
||||
|
||||
RUN pip install --no-cache-dir ".[api]"
|
||||
|
||||
RUN useradd -r -u 1000 app
|
||||
USER app
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "claude_memory.api.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
49
docker/docker-compose.yml
Normal file
49
docker/docker-compose.yml
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
services:
|
||||
api:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
DATABASE_URL: postgresql://claude_memory:devpassword@postgres:5432/claude_memory
|
||||
API_KEY: dev-api-key
|
||||
# Multi-user mode (uncomment to test):
|
||||
# API_KEYS: '{"viktor": "key1", "testuser": "key2"}'
|
||||
# Vault (uncomment to test):
|
||||
# VAULT_ADDR: http://vault:8200
|
||||
# VAULT_TOKEN: dev-root-token
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
environment:
|
||||
POSTGRES_DB: claude_memory
|
||||
POSTGRES_USER: claude_memory
|
||||
POSTGRES_PASSWORD: devpassword
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U claude_memory"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
vault:
|
||||
image: hashicorp/vault:1.15
|
||||
ports:
|
||||
- "8200:8200"
|
||||
environment:
|
||||
VAULT_DEV_ROOT_TOKEN_ID: dev-root-token
|
||||
VAULT_DEV_LISTEN_ADDRESS: 0.0.0.0:8200
|
||||
cap_add:
|
||||
- IPC_LOCK
|
||||
profiles:
|
||||
- vault
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
10
examples/mcp-config-local.json
Normal file
10
examples/mcp-config-local.json
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"type": "stdio",
|
||||
"command": "python3",
|
||||
"args": ["-m", "claude_memory.mcp_server"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
13
examples/mcp-config-server.json
Normal file
13
examples/mcp-config-server.json
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"type": "stdio",
|
||||
"command": "python3",
|
||||
"args": ["-m", "claude_memory.mcp_server"],
|
||||
"env": {
|
||||
"MEMORY_API_URL": "https://claude-memory.example.com",
|
||||
"MEMORY_API_KEY": "your-api-key-here"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
examples/mcp-config-vault.json
Normal file
15
examples/mcp-config-vault.json
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"type": "stdio",
|
||||
"command": "python3",
|
||||
"args": ["-m", "claude_memory.mcp_server"],
|
||||
"env": {
|
||||
"MEMORY_API_URL": "https://claude-memory.example.com",
|
||||
"MEMORY_API_KEY": "your-api-key-here",
|
||||
"VAULT_ADDR": "https://vault.example.com",
|
||||
"VAULT_TOKEN": "your-vault-token"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
35
pyproject.toml
Normal file
35
pyproject.toml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "claude-memory-mcp"
|
||||
version = "1.0.0"
|
||||
description = "Standalone MCP memory server with multi-user support and Vault integration"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
requires-python = ">=3.11"
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = ["fastapi>=0.115", "asyncpg>=0.30", "uvicorn>=0.34", "pydantic>=2.0"]
|
||||
vault = ["hvac>=2.0"]
|
||||
dev = ["pytest>=8.0", "pytest-asyncio>=0.24", "ruff>=0.8", "mypy>=1.13", "httpx>=0.28"]
|
||||
|
||||
[project.scripts]
|
||||
claude-memory-server = "claude_memory.mcp_server:main"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/claude_memory"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 120
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
strict = true
|
||||
3
src/claude_memory/__init__.py
Normal file
3
src/claude_memory/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""Claude Memory MCP — standalone memory server with multi-user support."""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
4
src/claude_memory/__main__.py
Normal file
4
src/claude_memory/__main__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""Allow running as `python -m claude_memory`."""
|
||||
from claude_memory.mcp_server import main
|
||||
|
||||
main()
|
||||
0
src/claude_memory/api/__init__.py
Normal file
0
src/claude_memory/api/__init__.py
Normal file
337
src/claude_memory/api/app.py
Normal file
337
src/claude_memory/api/app.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
"""Claude Memory API -- shared persistent memory with PostgreSQL full-text search."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
|
||||
from claude_memory.api.auth import AuthUser, get_current_user
|
||||
from claude_memory.api.database import close_pool, get_pool, init_pool
|
||||
from claude_memory.api.models import MemoryRecall, MemoryResponse, MemoryStore, SecretResponse
|
||||
from claude_memory.api.vault_service import (
|
||||
delete_secret,
|
||||
get_secret,
|
||||
is_vault_configured,
|
||||
store_secret,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await init_pool()
|
||||
yield
|
||||
await close_pool()
|
||||
|
||||
|
||||
app = FastAPI(title="Claude Memory API", lifespan=lifespan)
|
||||
|
||||
|
||||
def _detect_sensitive(content: str) -> bool:
|
||||
"""Check if content contains credentials using the credential detector."""
|
||||
try:
|
||||
from claude_memory.credential_detector import detect_credentials
|
||||
|
||||
findings = detect_credentials(content)
|
||||
return len(findings) > 0
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _redact_content(content: str) -> str:
|
||||
"""Redact sensitive content for storage in the main DB."""
|
||||
try:
|
||||
from claude_memory.credential_detector import detect_credentials, redact_credentials
|
||||
|
||||
creds = detect_credentials(content)
|
||||
if creds:
|
||||
return redact_credentials(content, creds)
|
||||
return content
|
||||
except ImportError:
|
||||
return "[REDACTED]"
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/memories", response_model=MemoryResponse)
|
||||
async def store_memory(body: MemoryStore, user: AuthUser = Depends(get_current_user)):
|
||||
pool = await get_pool()
|
||||
is_sensitive = body.force_sensitive or _detect_sensitive(body.content)
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, category, importance
|
||||
""",
|
||||
user.user_id,
|
||||
body.content if not is_sensitive else _redact_content(body.content),
|
||||
body.category,
|
||||
body.tags,
|
||||
body.expanded_keywords,
|
||||
body.importance,
|
||||
is_sensitive,
|
||||
)
|
||||
memory_id = row["id"]
|
||||
|
||||
if is_sensitive and is_vault_configured():
|
||||
vault_path = await store_secret(user.user_id, memory_id, body.content)
|
||||
await conn.execute(
|
||||
"UPDATE memories SET vault_path = $1 WHERE id = $2",
|
||||
vault_path,
|
||||
memory_id,
|
||||
)
|
||||
|
||||
return MemoryResponse(id=row["id"], category=row["category"], importance=row["importance"])
|
||||
|
||||
|
||||
@app.post("/api/memories/recall")
|
||||
async def recall_memories(body: MemoryRecall, user: AuthUser = Depends(get_current_user)):
|
||||
pool = await get_pool()
|
||||
|
||||
query_text = f"{body.context} {body.expanded_query}".strip()
|
||||
|
||||
order_clause = "ts_rank(search_vector, query) DESC"
|
||||
if body.sort_by == "importance":
|
||||
order_clause = "importance DESC, ts_rank(search_vector, query) DESC"
|
||||
elif body.sort_by == "recency":
|
||||
order_clause = "created_at DESC"
|
||||
|
||||
category_filter = ""
|
||||
params: list = [user.user_id, query_text, body.limit]
|
||||
if body.category:
|
||||
category_filter = "AND category = $4"
|
||||
params.append(body.category)
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
f"""
|
||||
SELECT id, content, category, tags, importance, is_sensitive,
|
||||
ts_rank(search_vector, query) AS rank,
|
||||
created_at, updated_at
|
||||
FROM memories, plainto_tsquery('english', $2) query
|
||||
WHERE user_id = $1
|
||||
AND (search_vector @@ query OR $2 = '')
|
||||
{category_filter}
|
||||
ORDER BY {order_clause}
|
||||
LIMIT $3
|
||||
""",
|
||||
*params,
|
||||
)
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
content = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
results.append(
|
||||
{
|
||||
"id": row["id"],
|
||||
"content": content,
|
||||
"category": row["category"],
|
||||
"tags": row["tags"],
|
||||
"importance": row["importance"],
|
||||
"is_sensitive": row["is_sensitive"],
|
||||
"rank": float(row["rank"]),
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/memories")
|
||||
async def list_memories(
|
||||
category: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
user: AuthUser = Depends(get_current_user),
|
||||
):
|
||||
pool = await get_pool()
|
||||
|
||||
if category:
|
||||
query = """
|
||||
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
FROM memories WHERE user_id = $1 AND category = $2
|
||||
ORDER BY importance DESC LIMIT $3
|
||||
"""
|
||||
params: list = [user.user_id, category, limit]
|
||||
else:
|
||||
query = """
|
||||
SELECT id, content, category, tags, importance, is_sensitive, created_at, updated_at
|
||||
FROM memories WHERE user_id = $1
|
||||
ORDER BY importance DESC LIMIT $2
|
||||
"""
|
||||
params = [user.user_id, limit]
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
content = row["content"]
|
||||
if row["is_sensitive"]:
|
||||
content = f"[SENSITIVE - use secret_get(id={row['id']})]"
|
||||
results.append(
|
||||
{
|
||||
"id": row["id"],
|
||||
"content": content,
|
||||
"category": row["category"],
|
||||
"tags": row["tags"],
|
||||
"importance": row["importance"],
|
||||
"is_sensitive": row["is_sensitive"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
"updated_at": row["updated_at"].isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@app.delete("/api/memories/{memory_id}")
|
||||
async def delete_memory(memory_id: int, user: AuthUser = Depends(get_current_user)):
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT id, vault_path FROM memories WHERE id = $1 AND user_id = $2",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
if row["vault_path"]:
|
||||
await delete_secret(user.user_id, row["vault_path"])
|
||||
|
||||
await conn.execute(
|
||||
"DELETE FROM memories WHERE id = $1 AND user_id = $2",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
return {"deleted": memory_id}
|
||||
|
||||
|
||||
@app.post("/api/memories/{memory_id}/secret", response_model=SecretResponse)
|
||||
async def get_memory_secret(memory_id: int, user: AuthUser = Depends(get_current_user)):
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT id, content, is_sensitive, vault_path, encrypted_content
|
||||
FROM memories WHERE id = $1 AND user_id = $2
|
||||
""",
|
||||
memory_id,
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
if not row["is_sensitive"]:
|
||||
return SecretResponse(id=row["id"], content=row["content"], source="plaintext")
|
||||
|
||||
if row["vault_path"]:
|
||||
secret = await get_secret(user.user_id, row["vault_path"])
|
||||
if secret:
|
||||
return SecretResponse(id=row["id"], content=secret, source="vault")
|
||||
|
||||
if row["encrypted_content"]:
|
||||
return SecretResponse(
|
||||
id=row["id"],
|
||||
content="[ENCRYPTED - decryption not available]",
|
||||
source="encrypted",
|
||||
)
|
||||
|
||||
return SecretResponse(id=row["id"], content=row["content"], source="plaintext")
|
||||
|
||||
|
||||
@app.post("/api/memories/migrate-secrets")
|
||||
async def migrate_secrets(user: AuthUser = Depends(get_current_user)):
|
||||
pool = await get_pool()
|
||||
migrated = 0
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, content FROM memories
|
||||
WHERE user_id = $1 AND is_sensitive = FALSE
|
||||
""",
|
||||
user.user_id,
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
if _detect_sensitive(row["content"]):
|
||||
original_content = row["content"]
|
||||
redacted = _redact_content(original_content)
|
||||
|
||||
vault_path = None
|
||||
if is_vault_configured():
|
||||
vault_path = await store_secret(user.user_id, row["id"], original_content)
|
||||
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE memories
|
||||
SET is_sensitive = TRUE, content = $1, vault_path = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND user_id = $4
|
||||
""",
|
||||
redacted,
|
||||
vault_path,
|
||||
row["id"],
|
||||
user.user_id,
|
||||
)
|
||||
migrated += 1
|
||||
|
||||
return {"migrated": migrated}
|
||||
|
||||
|
||||
@app.post("/api/memories/import")
|
||||
async def import_memories(
|
||||
memories: list[MemoryStore], user: AuthUser = Depends(get_current_user)
|
||||
):
|
||||
pool = await get_pool()
|
||||
imported = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
for mem in memories:
|
||||
is_sensitive = mem.force_sensitive or _detect_sensitive(mem.content)
|
||||
content = mem.content if not is_sensitive else _redact_content(mem.content)
|
||||
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
INSERT INTO memories (user_id, content, category, tags, expanded_keywords, importance, is_sensitive)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, category, importance
|
||||
""",
|
||||
user.user_id,
|
||||
content,
|
||||
mem.category,
|
||||
mem.tags,
|
||||
mem.expanded_keywords,
|
||||
mem.importance,
|
||||
is_sensitive,
|
||||
)
|
||||
|
||||
if is_sensitive and is_vault_configured():
|
||||
vault_path = await store_secret(user.user_id, row["id"], mem.content)
|
||||
await conn.execute(
|
||||
"UPDATE memories SET vault_path = $1 WHERE id = $2",
|
||||
vault_path,
|
||||
row["id"],
|
||||
)
|
||||
|
||||
imported.append(
|
||||
MemoryResponse(id=row["id"], category=row["category"], importance=row["importance"])
|
||||
)
|
||||
|
||||
return imported
|
||||
32
src/claude_memory/api/auth.py
Normal file
32
src/claude_memory/api/auth.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Header, HTTPException
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthUser:
|
||||
user_id: str
|
||||
|
||||
|
||||
# Multi-user mode: API_KEYS='{"viktor": "key1", "user2": "key2"}'
|
||||
# Single-user mode: API_KEY="some-key" (backward compatible, user_id="default")
|
||||
_api_keys_json = os.environ.get("API_KEYS", "")
|
||||
_api_key_single = os.environ.get("API_KEY", "")
|
||||
|
||||
_key_to_user: dict[str, str] = {}
|
||||
|
||||
if _api_keys_json:
|
||||
_user_to_key = json.loads(_api_keys_json)
|
||||
_key_to_user = {v: k for k, v in _user_to_key.items()}
|
||||
elif _api_key_single:
|
||||
_key_to_user = {_api_key_single: "default"}
|
||||
|
||||
|
||||
async def get_current_user(authorization: str = Header(...)) -> AuthUser:
|
||||
token = authorization.removeprefix("Bearer ").strip()
|
||||
user_id = _key_to_user.get(token)
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
return AuthUser(user_id=user_id)
|
||||
55
src/claude_memory/api/database.py
Normal file
55
src/claude_memory/api/database.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
|
||||
import asyncpg
|
||||
|
||||
DATABASE_URL = os.environ.get("DATABASE_URL", "")
|
||||
|
||||
pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
async def init_pool() -> asyncpg.Pool:
|
||||
global pool
|
||||
pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id VARCHAR(100) NOT NULL DEFAULT 'default',
|
||||
content TEXT NOT NULL,
|
||||
category VARCHAR(50) DEFAULT 'facts',
|
||||
tags TEXT DEFAULT '',
|
||||
expanded_keywords TEXT DEFAULT '',
|
||||
importance REAL DEFAULT 0.5,
|
||||
is_sensitive BOOLEAN DEFAULT FALSE,
|
||||
vault_path TEXT DEFAULT NULL,
|
||||
encrypted_content BYTEA DEFAULT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
search_vector tsvector GENERATED ALWAYS AS (
|
||||
setweight(to_tsvector('english', coalesce(content, '')), 'A') ||
|
||||
setweight(to_tsvector('english', coalesce(expanded_keywords, '')), 'B') ||
|
||||
setweight(to_tsvector('english', coalesce(tags, '')), 'C') ||
|
||||
setweight(to_tsvector('english', coalesce(category, '')), 'D')
|
||||
) STORED
|
||||
)
|
||||
""")
|
||||
await conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memories_search ON memories USING GIN(search_vector)"
|
||||
)
|
||||
await conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id)"
|
||||
)
|
||||
return pool
|
||||
|
||||
|
||||
async def close_pool():
|
||||
global pool
|
||||
if pool:
|
||||
await pool.close()
|
||||
pool = None
|
||||
|
||||
|
||||
async def get_pool() -> asyncpg.Pool:
|
||||
if pool is None:
|
||||
raise RuntimeError("Database pool not initialized")
|
||||
return pool
|
||||
32
src/claude_memory/api/models.py
Normal file
32
src/claude_memory/api/models.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MemoryStore(BaseModel):
|
||||
content: str
|
||||
category: str = "facts"
|
||||
tags: str = ""
|
||||
expanded_keywords: str = ""
|
||||
importance: float = Field(default=0.5, ge=0.0, le=1.0)
|
||||
force_sensitive: bool = False
|
||||
|
||||
|
||||
class MemoryRecall(BaseModel):
|
||||
context: str
|
||||
expanded_query: str = ""
|
||||
category: Optional[str] = None
|
||||
sort_by: str = "importance"
|
||||
limit: int = 10
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
id: int
|
||||
category: str
|
||||
importance: float
|
||||
|
||||
|
||||
class SecretResponse(BaseModel):
|
||||
id: int
|
||||
content: str
|
||||
source: str # "vault", "encrypted", "plaintext"
|
||||
51
src/claude_memory/api/vault_service.py
Normal file
51
src/claude_memory/api/vault_service.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VAULT_ADDR = os.environ.get("VAULT_ADDR", "")
|
||||
VAULT_TOKEN = os.environ.get("VAULT_TOKEN", "")
|
||||
VAULT_MOUNT = os.environ.get("VAULT_MOUNT", "secret")
|
||||
VAULT_PREFIX = os.environ.get("VAULT_PREFIX", "claude-memory")
|
||||
|
||||
|
||||
def is_vault_configured() -> bool:
|
||||
return bool(VAULT_ADDR and VAULT_TOKEN)
|
||||
|
||||
|
||||
async def store_secret(user_id: str, memory_id: int, content: str) -> str:
|
||||
"""Store secret content in Vault. Returns the vault path."""
|
||||
if not is_vault_configured():
|
||||
raise RuntimeError("Vault not configured")
|
||||
|
||||
from claude_memory.vault_client import VaultClient
|
||||
|
||||
client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT)
|
||||
path = f"{VAULT_PREFIX}/{user_id}/mem-{memory_id}"
|
||||
client.write(path, {"content": content})
|
||||
return path
|
||||
|
||||
|
||||
async def get_secret(user_id: str, vault_path: str) -> str | None:
|
||||
"""Retrieve secret content from Vault."""
|
||||
if not is_vault_configured():
|
||||
return None
|
||||
|
||||
from claude_memory.vault_client import VaultClient
|
||||
|
||||
client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT)
|
||||
data = client.read(vault_path)
|
||||
if data:
|
||||
return data.get("content")
|
||||
return None
|
||||
|
||||
|
||||
async def delete_secret(user_id: str, vault_path: str) -> bool:
|
||||
"""Delete secret from Vault."""
|
||||
if not is_vault_configured():
|
||||
return False
|
||||
|
||||
from claude_memory.vault_client import VaultClient
|
||||
|
||||
client = VaultClient(VAULT_ADDR, VAULT_TOKEN, VAULT_MOUNT)
|
||||
return client.delete(vault_path)
|
||||
76
src/claude_memory/credential_detector.py
Normal file
76
src/claude_memory/credential_detector.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Detect credentials and secrets in text content."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectedCredential:
|
||||
type: str # e.g. "password", "api_key", "private_key", "connection_string", "token"
|
||||
confidence: float # 0.0 to 1.0
|
||||
start: int # position in text
|
||||
end: int # position in text
|
||||
matched_text: str # the actual matched text (for redaction)
|
||||
|
||||
|
||||
# Patterns ordered by confidence
|
||||
_PATTERNS: list[tuple[str, str, float]] = [
|
||||
# High confidence (0.9+)
|
||||
("private_key", r"-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----[\s\S]*?-----END (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----", 0.95),
|
||||
("connection_string", r"(?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^\s'\"]+", 0.9),
|
||||
("aws_key", r"(?:AKIA|ASIA)[A-Z0-9]{16}", 0.95),
|
||||
("github_token", r"gh[pousr]_[A-Za-z0-9_]{36,}", 0.95),
|
||||
|
||||
# Medium confidence (0.7-0.89)
|
||||
("api_key", r"(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_\-]{20,})['\"]?", 0.8),
|
||||
("password", r"(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]{8,})['\"]?", 0.8),
|
||||
("token", r"(?:token|secret|bearer)\s*[:=]\s*['\"]?([A-Za-z0-9_\-\.]{20,})['\"]?", 0.75),
|
||||
("basic_auth", r"(?:Basic\s+)[A-Za-z0-9+/=]{20,}", 0.85),
|
||||
("bearer_token", r"Bearer\s+[A-Za-z0-9_\-\.]{20,}", 0.85),
|
||||
|
||||
# Lower confidence (0.5-0.69)
|
||||
("generic_secret", r"(?:secret|credential|auth)\s*[:=]\s*['\"]?([^\s'\"]{12,})['\"]?", 0.6),
|
||||
("hex_key", r"(?:key|secret)\s*[:=]\s*['\"]?([0-9a-fA-F]{32,})['\"]?", 0.65),
|
||||
]
|
||||
|
||||
|
||||
def detect_credentials(text: str, min_confidence: float = 0.5) -> list[DetectedCredential]:
|
||||
"""Scan text for potential credentials and secrets."""
|
||||
results: list[DetectedCredential] = []
|
||||
for cred_type, pattern, confidence in _PATTERNS:
|
||||
if confidence < min_confidence:
|
||||
continue
|
||||
for match in re.finditer(pattern, text, re.IGNORECASE):
|
||||
results.append(DetectedCredential(
|
||||
type=cred_type,
|
||||
confidence=confidence,
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
matched_text=match.group(0),
|
||||
))
|
||||
# Deduplicate overlapping matches, keeping highest confidence
|
||||
results.sort(key=lambda c: (-c.confidence, c.start))
|
||||
filtered: list[DetectedCredential] = []
|
||||
for cred in results:
|
||||
if not any(c.start <= cred.start and c.end >= cred.end for c in filtered):
|
||||
filtered.append(cred)
|
||||
return sorted(filtered, key=lambda c: c.start)
|
||||
|
||||
|
||||
def redact_credentials(text: str, credentials: list[DetectedCredential]) -> str:
|
||||
"""Replace detected credentials with [REDACTED] markers."""
|
||||
if not credentials:
|
||||
return text
|
||||
parts: list[str] = []
|
||||
last_end = 0
|
||||
for cred in sorted(credentials, key=lambda c: c.start):
|
||||
parts.append(text[last_end:cred.start])
|
||||
parts.append(f"[REDACTED:{cred.type}]")
|
||||
last_end = cred.end
|
||||
parts.append(text[last_end:])
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def is_sensitive(text: str, min_confidence: float = 0.7) -> bool:
|
||||
"""Quick check if text likely contains credentials."""
|
||||
return len(detect_credentials(text, min_confidence)) > 0
|
||||
71
src/claude_memory/crypto.py
Normal file
71
src/claude_memory/crypto.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""AES-256-GCM encryption for memory content when Vault is not available."""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
ENCRYPTION_KEY_ENV = "MEMORY_ENCRYPTION_KEY"
|
||||
|
||||
|
||||
def _get_key() -> bytes | None:
|
||||
"""Get 32-byte encryption key from environment."""
|
||||
raw = os.environ.get(ENCRYPTION_KEY_ENV)
|
||||
if not raw:
|
||||
return None
|
||||
# Accept hex-encoded 32-byte key or derive from passphrase
|
||||
try:
|
||||
key = bytes.fromhex(raw)
|
||||
if len(key) == 32:
|
||||
return key
|
||||
except ValueError:
|
||||
pass
|
||||
# Derive key from passphrase using SHA-256
|
||||
return hashlib.sha256(raw.encode()).digest()
|
||||
|
||||
|
||||
def is_encryption_configured() -> bool:
|
||||
return _get_key() is not None
|
||||
|
||||
|
||||
def encrypt(plaintext: str) -> bytes:
|
||||
"""Encrypt text using AES-256-GCM. Returns nonce + ciphertext + tag."""
|
||||
key = _get_key()
|
||||
if key is None:
|
||||
raise RuntimeError(f"{ENCRYPTION_KEY_ENV} not set")
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
except ImportError:
|
||||
raise RuntimeError("cryptography package required for encryption: pip install cryptography")
|
||||
|
||||
nonce = os.urandom(12)
|
||||
aesgcm = AESGCM(key)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext.encode(), None)
|
||||
return nonce + ciphertext # 12 bytes nonce + ciphertext + 16 bytes tag
|
||||
|
||||
|
||||
def decrypt(data: bytes) -> str:
|
||||
"""Decrypt AES-256-GCM encrypted data."""
|
||||
key = _get_key()
|
||||
if key is None:
|
||||
raise RuntimeError(f"{ENCRYPTION_KEY_ENV} not set")
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
except ImportError:
|
||||
raise RuntimeError("cryptography package required for encryption: pip install cryptography")
|
||||
|
||||
nonce = data[:12]
|
||||
ciphertext = data[12:]
|
||||
aesgcm = AESGCM(key)
|
||||
return aesgcm.decrypt(nonce, ciphertext, None).decode()
|
||||
|
||||
|
||||
def encrypt_b64(plaintext: str) -> str:
|
||||
"""Encrypt and return base64-encoded string."""
|
||||
return base64.b64encode(encrypt(plaintext)).decode()
|
||||
|
||||
|
||||
def decrypt_b64(data: str) -> str:
|
||||
"""Decrypt from base64-encoded string."""
|
||||
return decrypt(base64.b64decode(data))
|
||||
550
src/claude_memory/mcp_server.py
Normal file
550
src/claude_memory/mcp_server.py
Normal file
|
|
@ -0,0 +1,550 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Claude Memory MCP Server — standalone memory server with multi-user support.
|
||||
|
||||
Supports two modes:
|
||||
1. HTTP API mode: connects to a shared PostgreSQL-backed API server
|
||||
2. SQLite fallback: local file-based storage when no API key is configured
|
||||
|
||||
Uses only stdlib (urllib) — no pip install required.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROTOCOL_VERSION = "2024-11-05"
|
||||
SERVER_NAME = "claude-memory"
|
||||
SERVER_VERSION = "1.0.0"
|
||||
|
||||
# API configuration — support both MEMORY_* (primary) and CLAUDE_MEMORY_* (fallback) env vars
|
||||
API_BASE_URL = os.environ.get("MEMORY_API_URL") or os.environ.get("CLAUDE_MEMORY_API_URL", "http://localhost:8080")
|
||||
API_KEY = os.environ.get("MEMORY_API_KEY") or os.environ.get("CLAUDE_MEMORY_API_KEY", "")
|
||||
|
||||
# Fallback to SQLite if API is not configured
|
||||
SQLITE_FALLBACK = not API_KEY
|
||||
|
||||
|
||||
def _api_request(method: str, path: str, body: dict | None = None) -> dict:
|
||||
"""Make an HTTP request to the memory API."""
|
||||
url = f"{API_BASE_URL}{path}"
|
||||
data = json.dumps(body).encode() if body else None
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=data,
|
||||
method=method,
|
||||
headers={
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
return json.loads(resp.read().decode())
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode() if e.fp else str(e)
|
||||
raise RuntimeError(f"API error {e.code}: {error_body}") from e
|
||||
except urllib.error.URLError as e:
|
||||
raise RuntimeError(f"API connection error: {e.reason}") from e
|
||||
|
||||
|
||||
# ─── SQLite fallback (local storage when API not configured) ─────────────────
|
||||
|
||||
def _init_sqlite(db_path: str | None = None):
|
||||
"""Initialize SQLite database as fallback."""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
if db_path is None:
|
||||
memory_home = os.path.expandvars(
|
||||
os.path.expanduser(os.environ.get("MEMORY_HOME", "~/.claude/claude-memory"))
|
||||
)
|
||||
db_path = os.environ.get(
|
||||
"MEMORY_DB",
|
||||
os.path.join(memory_home, "memory", "memory.db"),
|
||||
)
|
||||
db_path = os.path.expandvars(os.path.expanduser(db_path))
|
||||
|
||||
Path(os.path.dirname(db_path)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(db_path, timeout=30.0)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=30000")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
content TEXT NOT NULL,
|
||||
category TEXT DEFAULT 'facts',
|
||||
tags TEXT DEFAULT '',
|
||||
expanded_keywords TEXT DEFAULT '',
|
||||
importance REAL DEFAULT 0.5,
|
||||
is_sensitive INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
||||
content, category, tags, expanded_keywords,
|
||||
content='memories', content_rowid='id'
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
|
||||
INSERT INTO memories_fts(rowid, content, category, tags, expanded_keywords)
|
||||
VALUES (new.id, new.content, new.category, new.tags, new.expanded_keywords);
|
||||
END
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
|
||||
INSERT INTO memories_fts(memories_fts, rowid, content, category, tags, expanded_keywords)
|
||||
VALUES ('delete', old.id, old.content, old.category, old.tags, old.expanded_keywords);
|
||||
END
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
|
||||
INSERT INTO memories_fts(memories_fts, rowid, content, category, tags, expanded_keywords)
|
||||
VALUES ('delete', old.id, old.content, old.category, old.tags, old.expanded_keywords);
|
||||
INSERT INTO memories_fts(rowid, content, category, tags, expanded_keywords)
|
||||
VALUES (new.id, new.content, new.category, new.tags, new.expanded_keywords);
|
||||
END
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
# ─── Tool definitions ────────────────────────────────────────────────────────
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"name": "memory_store",
|
||||
"description": "Store a fact or memory in persistent storage. Use this to remember important information about the user, their preferences, projects, decisions, or people they mention.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "The fact or memory to store"},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["facts", "preferences", "projects", "people", "decisions"],
|
||||
"description": "Category for organizing the memory",
|
||||
"default": "facts",
|
||||
},
|
||||
"tags": {"type": "string", "description": "Comma-separated tags", "default": ""},
|
||||
"importance": {
|
||||
"type": "number",
|
||||
"description": "Importance 0.0-1.0",
|
||||
"default": 0.5,
|
||||
"minimum": 0.0,
|
||||
"maximum": 1.0,
|
||||
},
|
||||
"expanded_keywords": {
|
||||
"type": "string",
|
||||
"description": "REQUIRED. Space-separated semantically related search terms (MINIMUM 5 words). Generate keywords that someone might search for when this memory would be relevant. Include synonyms, related concepts, and adjacent topics.",
|
||||
},
|
||||
"force_sensitive": {
|
||||
"type": "boolean",
|
||||
"description": "If true, mark this memory as sensitive regardless of auto-detection. Sensitive memories have their content encrypted at rest.",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["content", "expanded_keywords"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "memory_recall",
|
||||
"description": "Retrieve relevant memories based on context. Uses full-text search to find stored memories.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"context": {"type": "string", "description": "The context or topic to recall memories about"},
|
||||
"expanded_query": {
|
||||
"type": "string",
|
||||
"description": "REQUIRED. Space-separated semantically related search terms (MINIMUM 5 words).",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["facts", "preferences", "projects", "people", "decisions"],
|
||||
"description": "Optional: filter results to a specific category",
|
||||
},
|
||||
"sort_by": {
|
||||
"type": "string",
|
||||
"enum": ["importance", "relevance"],
|
||||
"description": "Sort order",
|
||||
"default": "importance",
|
||||
},
|
||||
"limit": {"type": "integer", "description": "Max results", "default": 10},
|
||||
},
|
||||
"required": ["context", "expanded_query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "memory_list",
|
||||
"description": "List recent memories, optionally filtered by category.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["facts", "preferences", "projects", "people", "decisions"],
|
||||
},
|
||||
"limit": {"type": "integer", "default": 20},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "memory_delete",
|
||||
"description": "Delete a memory by ID.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer", "description": "The ID of the memory to delete"},
|
||||
},
|
||||
"required": ["id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "secret_get",
|
||||
"description": "Retrieve the decrypted content of a sensitive memory. Only works for memories marked as sensitive.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer", "description": "The ID of the sensitive memory to retrieve"},
|
||||
},
|
||||
"required": ["id"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class MemoryServer:
|
||||
"""MCP server for persistent memory management."""
|
||||
|
||||
def __init__(self, sqlite_db_path: str | None = None) -> None:
|
||||
self.sqlite_conn = None
|
||||
if SQLITE_FALLBACK:
|
||||
self.sqlite_conn = _init_sqlite(sqlite_db_path)
|
||||
|
||||
# ── HTTP-backed methods ──────────────────────────────────────────
|
||||
|
||||
def memory_store(self, args: dict[str, Any]) -> str:
|
||||
content = args.get("content")
|
||||
if not content:
|
||||
raise ValueError("content is required")
|
||||
category = args.get("category", "facts")
|
||||
tags = args.get("tags", "")
|
||||
importance = max(0.0, min(1.0, float(args.get("importance", 0.5))))
|
||||
expanded_keywords = args.get("expanded_keywords", "")
|
||||
force_sensitive = bool(args.get("force_sensitive", False))
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_store(content, category, tags, importance, expanded_keywords, force_sensitive)
|
||||
|
||||
result = _api_request("POST", "/api/memories", {
|
||||
"content": content,
|
||||
"category": category,
|
||||
"tags": tags,
|
||||
"expanded_keywords": expanded_keywords,
|
||||
"importance": importance,
|
||||
"force_sensitive": force_sensitive,
|
||||
})
|
||||
return f"Stored memory #{result['id']} in category '{result['category']}' with importance {result['importance']:.1f}"
|
||||
|
||||
def memory_recall(self, args: dict[str, Any]) -> str:
|
||||
context = args.get("context")
|
||||
if not context:
|
||||
raise ValueError("context is required")
|
||||
expanded_query = args.get("expanded_query", "")
|
||||
category = args.get("category")
|
||||
sort_by = args.get("sort_by", "importance")
|
||||
limit = args.get("limit", 10)
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_recall(context, expanded_query, category, sort_by, limit)
|
||||
|
||||
result = _api_request("POST", "/api/memories/recall", {
|
||||
"context": context,
|
||||
"expanded_query": expanded_query,
|
||||
"category": category,
|
||||
"sort_by": sort_by,
|
||||
"limit": limit,
|
||||
})
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
filter_desc = f" in category '{category}'" if category else ""
|
||||
return f"No memories found matching: {context}{filter_desc}"
|
||||
|
||||
sort_desc = "by relevance" if sort_by == "relevance" else "by importance"
|
||||
filter_desc = f" in '{category}'" if category else ""
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
||||
f"\n Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
return f"Found {len(rows)} memories{filter_desc} ({sort_desc}):\n\n" + "\n\n".join(results)
|
||||
|
||||
def memory_list(self, args: dict[str, Any]) -> str:
|
||||
category = args.get("category")
|
||||
limit = args.get("limit", 20)
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_list(category, limit)
|
||||
|
||||
params = f"?limit={limit}"
|
||||
if category:
|
||||
params += f"&category={category}"
|
||||
result = _api_request("GET", f"/api/memories{params}")
|
||||
rows = result.get("memories", [])
|
||||
if not rows:
|
||||
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] {row['content']}"
|
||||
f"\n Importance: {row['importance']:.1f} | Tags: {row.get('tags') or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
header = "Recent memories"
|
||||
if category:
|
||||
header += f" in '{category}'"
|
||||
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
||||
|
||||
def memory_delete(self, args: dict[str, Any]) -> str:
|
||||
memory_id = args.get("id")
|
||||
if memory_id is None:
|
||||
raise ValueError("id is required")
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_delete(memory_id)
|
||||
|
||||
result = _api_request("DELETE", f"/api/memories/{memory_id}")
|
||||
return f"Deleted memory #{result['deleted']}: {result['preview']}..."
|
||||
|
||||
def secret_get(self, args: dict[str, Any]) -> str:
|
||||
memory_id = args.get("id")
|
||||
if memory_id is None:
|
||||
raise ValueError("id is required")
|
||||
|
||||
if SQLITE_FALLBACK:
|
||||
return self._sqlite_secret_get(memory_id)
|
||||
|
||||
result = _api_request("POST", f"/api/memories/{memory_id}/secret")
|
||||
return f"#{result['id']} [{result['category']}] {result['content']}"
|
||||
|
||||
# ── SQLite fallback methods ──────────────────────────────────────
|
||||
|
||||
def _sqlite_store(self, content, category, tags, importance, expanded_keywords, force_sensitive=False):
|
||||
from datetime import datetime, timezone
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
is_sensitive = 1 if force_sensitive else 0
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO memories (content, category, tags, expanded_keywords, importance, is_sensitive, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(content, category, tags, expanded_keywords, importance, is_sensitive, now, now),
|
||||
)
|
||||
self.sqlite_conn.commit()
|
||||
return f"Stored memory #{cursor.lastrowid} in category '{category}' with importance {importance:.1f}"
|
||||
|
||||
def _sqlite_recall(self, context, expanded_query, category, sort_by, limit):
|
||||
import sqlite3
|
||||
|
||||
all_terms = f"{context} {expanded_query}".strip()
|
||||
words = all_terms.split()
|
||||
fts_query = " OR ".join(f'"{w.replace(chr(34), "")}"' for w in words if w)
|
||||
order = (
|
||||
"bm25(memories_fts), m.importance DESC"
|
||||
if sort_by == "relevance"
|
||||
else "m.importance DESC, m.created_at DESC"
|
||||
)
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
try:
|
||||
if category:
|
||||
cursor.execute(
|
||||
f"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
|
||||
f"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
|
||||
f"WHERE memories_fts MATCH ? AND m.category = ? ORDER BY {order} LIMIT ?",
|
||||
(fts_query, category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
f"SELECT m.id, m.content, m.category, m.tags, m.importance, m.created_at "
|
||||
f"FROM memories m JOIN memories_fts fts ON m.id = fts.rowid "
|
||||
f"WHERE memories_fts MATCH ? ORDER BY {order} LIMIT ?",
|
||||
(fts_query, limit),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
like = f"%{context}%"
|
||||
if category:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE (content LIKE ? OR tags LIKE ?) AND category = ? ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE content LIKE ? OR tags LIKE ? ORDER BY importance DESC LIMIT ?",
|
||||
(like, like, limit),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
return f"No memories found matching: {context}"
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] (importance: {row['importance']:.1f}) {row['content']}"
|
||||
f"\n Tags: {row['tags'] or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
return (
|
||||
f"Found {len(rows)} memories (by {'relevance' if sort_by == 'relevance' else 'importance'}):\n\n"
|
||||
+ "\n\n".join(results)
|
||||
)
|
||||
|
||||
def _sqlite_list(self, category, limit):
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
if category:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"WHERE category = ? ORDER BY created_at DESC LIMIT ?",
|
||||
(category, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, tags, importance, created_at FROM memories "
|
||||
"ORDER BY created_at DESC LIMIT ?",
|
||||
(limit,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
if not rows:
|
||||
return f"No memories in category '{category}'" if category else "No memories stored yet"
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append(
|
||||
f"#{row['id']} [{row['category']}] {row['content']}"
|
||||
f"\n Importance: {row['importance']:.1f} | Tags: {row['tags'] or 'none'} | Stored: {row['created_at']}"
|
||||
)
|
||||
header = "Recent memories" + (f" in '{category}'" if category else "")
|
||||
return header + f" ({len(rows)} shown):\n\n" + "\n\n".join(results)
|
||||
|
||||
def _sqlite_delete(self, memory_id):
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT id, content FROM memories WHERE id = ?", (memory_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return f"Memory #{memory_id} not found"
|
||||
preview = row["content"][:50]
|
||||
cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
|
||||
self.sqlite_conn.commit()
|
||||
return f"Deleted memory #{memory_id}: {preview}..."
|
||||
|
||||
def _sqlite_secret_get(self, memory_id):
|
||||
cursor = self.sqlite_conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT id, content, category, is_sensitive FROM memories WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return f"Memory #{memory_id} not found"
|
||||
if not row["is_sensitive"]:
|
||||
return f"Memory #{memory_id} is not marked as sensitive"
|
||||
return f"#{row['id']} [{row['category']}] {row['content']}"
|
||||
|
||||
# ── MCP protocol ─────────────────────────────────────────────────
|
||||
|
||||
def handle_initialize(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"protocolVersion": PROTOCOL_VERSION,
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION},
|
||||
}
|
||||
|
||||
def handle_tools_list(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"tools": TOOLS}
|
||||
|
||||
def handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
try:
|
||||
handler = {
|
||||
"memory_store": self.memory_store,
|
||||
"memory_recall": self.memory_recall,
|
||||
"memory_list": self.memory_list,
|
||||
"memory_delete": self.memory_delete,
|
||||
"secret_get": self.secret_get,
|
||||
}.get(tool_name)
|
||||
if handler is None:
|
||||
return {"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], "isError": True}
|
||||
result = handler(arguments)
|
||||
return {"content": [{"type": "text", "text": result}]}
|
||||
except Exception as e:
|
||||
return {"content": [{"type": "text", "text": f"Error: {e!s}"}], "isError": True}
|
||||
|
||||
def process_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
msg_id = message.get("id")
|
||||
if msg_id is None:
|
||||
return None
|
||||
result = None
|
||||
error = None
|
||||
try:
|
||||
if method == "initialize":
|
||||
result = self.handle_initialize(params)
|
||||
elif method == "tools/list":
|
||||
result = self.handle_tools_list(params)
|
||||
elif method == "tools/call":
|
||||
result = self.handle_tools_call(params)
|
||||
else:
|
||||
error = {"code": -32601, "message": f"Method not found: {method}"}
|
||||
except Exception as e:
|
||||
error = {"code": -32603, "message": str(e)}
|
||||
response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id}
|
||||
if error:
|
||||
response["error"] = error
|
||||
else:
|
||||
response["result"] = result
|
||||
return response
|
||||
|
||||
def run(self) -> None:
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
message = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": -32700, "message": f"Parse error: {e}"},
|
||||
}),
|
||||
flush=True,
|
||||
)
|
||||
continue
|
||||
response = self.process_message(message)
|
||||
if response is not None:
|
||||
print(json.dumps(response), flush=True)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
server = MemoryServer()
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
84
src/claude_memory/vault_client.py
Normal file
84
src/claude_memory/vault_client.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""HashiCorp Vault KV v2 client using stdlib urllib."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VaultClient:
|
||||
"""Simple Vault KV v2 client using stdlib."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
addr: str | None = None,
|
||||
token: str | None = None,
|
||||
mount: str = "secret",
|
||||
):
|
||||
self.addr = (addr or os.environ.get("VAULT_ADDR", "")).rstrip("/")
|
||||
self.token = token or os.environ.get("VAULT_TOKEN", "")
|
||||
self.mount = mount
|
||||
|
||||
if not self.addr:
|
||||
raise ValueError("Vault address not configured (set VAULT_ADDR)")
|
||||
|
||||
# Auto-detect Kubernetes SA token
|
||||
if not self.token:
|
||||
sa_token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
|
||||
if os.path.exists(sa_token_path):
|
||||
self._login_kubernetes(sa_token_path)
|
||||
|
||||
def _login_kubernetes(self, sa_token_path: str) -> None:
|
||||
"""Authenticate with Vault using Kubernetes service account."""
|
||||
with open(sa_token_path) as f:
|
||||
jwt = f.read().strip()
|
||||
role = os.environ.get("VAULT_ROLE", "claude-memory")
|
||||
resp = self._request("POST", "/v1/auth/kubernetes/login", {"jwt": jwt, "role": role})
|
||||
self.token = resp.get("auth", {}).get("client_token", "")
|
||||
|
||||
def _request(self, method: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Make HTTP request to Vault."""
|
||||
url = f"{self.addr}{path}"
|
||||
data = json.dumps(body).encode() if body else None
|
||||
req = urllib.request.Request(
|
||||
url, data=data, method=method,
|
||||
headers={"X-Vault-Token": self.token, "Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
return json.loads(resp.read().decode())
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
return {}
|
||||
error_body = e.read().decode() if e.fp else str(e)
|
||||
raise RuntimeError(f"Vault error {e.code}: {error_body}") from e
|
||||
|
||||
def read(self, path: str) -> dict[str, Any] | None:
|
||||
"""Read a secret from KV v2."""
|
||||
resp = self._request("GET", f"/v1/{self.mount}/data/{path}")
|
||||
data = resp.get("data", {})
|
||||
return data.get("data") if data else None
|
||||
|
||||
def write(self, path: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Write a secret to KV v2."""
|
||||
return self._request("POST", f"/v1/{self.mount}/data/{path}", {"data": data})
|
||||
|
||||
def delete(self, path: str) -> bool:
|
||||
"""Delete a secret from KV v2."""
|
||||
try:
|
||||
self._request("DELETE", f"/v1/{self.mount}/data/{path}")
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
def list_secrets(self, path: str) -> list[str]:
|
||||
"""List secrets at a path."""
|
||||
try:
|
||||
resp = self._request("LIST", f"/v1/{self.mount}/metadata/{path}")
|
||||
return resp.get("data", {}).get("keys", [])
|
||||
except RuntimeError:
|
||||
return []
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
304
tests/test_api.py
Normal file
304
tests/test_api.py
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
"""Tests for the Claude Memory API endpoints."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from claude_memory.api.auth import AuthUser
|
||||
|
||||
|
||||
# Helpers to build mock asyncpg rows (they behave like dicts with attribute access)
|
||||
class MockRow(dict):
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
|
||||
def _make_memory_row(**overrides):
|
||||
now = datetime.now(timezone.utc)
|
||||
defaults = {
|
||||
"id": 1,
|
||||
"user_id": "testuser",
|
||||
"content": "test content",
|
||||
"category": "facts",
|
||||
"tags": "",
|
||||
"expanded_keywords": "",
|
||||
"importance": 0.5,
|
||||
"is_sensitive": False,
|
||||
"vault_path": None,
|
||||
"encrypted_content": None,
|
||||
"rank": 0.5,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MockRow(defaults)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pool():
|
||||
"""Create a mock asyncpg pool with connection context manager."""
|
||||
pool = MagicMock()
|
||||
conn = AsyncMock()
|
||||
|
||||
# pool.acquire() returns an async context manager yielding conn
|
||||
acm = MagicMock()
|
||||
acm.__aenter__ = AsyncMock(return_value=conn)
|
||||
acm.__aexit__ = AsyncMock(return_value=False)
|
||||
pool.acquire.return_value = acm
|
||||
|
||||
return pool, conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
return AuthUser(user_id="testuser")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_pool, test_user):
|
||||
"""Create an AsyncClient with mocked dependencies."""
|
||||
pool, conn = mock_pool
|
||||
|
||||
# Reload modules with test API key
|
||||
with patch.dict(os.environ, {"API_KEY": "test-key", "API_KEYS": "", "DATABASE_URL": "postgresql://test"}):
|
||||
import claude_memory.api.auth as auth_mod
|
||||
import claude_memory.api.database as db_mod
|
||||
import claude_memory.api.app as app_mod
|
||||
|
||||
importlib.reload(auth_mod)
|
||||
importlib.reload(db_mod)
|
||||
importlib.reload(app_mod)
|
||||
|
||||
# Override database pool
|
||||
db_mod.pool = pool
|
||||
|
||||
# Override auth to return our test user
|
||||
async def mock_get_user(authorization: str = ""):
|
||||
return test_user
|
||||
|
||||
app_mod.app.dependency_overrides[auth_mod.get_current_user] = mock_get_user
|
||||
|
||||
transport = ASGITransport(app=app_mod.app)
|
||||
return AsyncClient(transport=transport, base_url="http://test"), conn, app_mod
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint_no_auth(client):
|
||||
ac, conn, app_mod = client
|
||||
async with ac:
|
||||
resp = await ac.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_memory_creates_record_with_user_id(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(id=42, category="facts", importance=0.7)
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories",
|
||||
json={"content": "Python is great", "category": "facts", "importance": 0.7},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 42
|
||||
assert data["category"] == "facts"
|
||||
assert data["importance"] == 0.7
|
||||
|
||||
# Verify INSERT was called with user_id
|
||||
call_args = conn.fetchrow.call_args
|
||||
assert call_args[0][1] == "testuser" # user_id is the second positional arg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_returns_only_user_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=1, content="user memory", is_sensitive=False),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/recall",
|
||||
json={"context": "test query"},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "user memory"
|
||||
|
||||
# Verify query includes user_id filter
|
||||
call_args = conn.fetch.call_args
|
||||
assert call_args[0][1] == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_redacts_sensitive_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=5, content="[REDACTED]", is_sensitive=True),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/recall",
|
||||
json={"context": "secrets"},
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert "[SENSITIVE" in results[0]["content"]
|
||||
assert "secret_get(id=5)" in results[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_returns_only_user_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetch.return_value = [
|
||||
_make_memory_row(id=1, content="mem1"),
|
||||
_make_memory_row(id=2, content="mem2"),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.get(
|
||||
"/api/memories",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert len(results) == 2
|
||||
|
||||
# Verify user_id filter
|
||||
call_args = conn.fetch.call_args
|
||||
assert call_args[0][1] == "testuser"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_only_user_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(id=10, vault_path=None)
|
||||
conn.execute.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.delete(
|
||||
"/api/memories/10",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": 10}
|
||||
|
||||
# Verify both SELECT and DELETE include user_id
|
||||
fetchrow_args = conn.fetchrow.call_args
|
||||
assert fetchrow_args[0][1] == 10 # memory_id
|
||||
assert fetchrow_args[0][2] == "testuser" # user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_memory_returns_404(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.delete(
|
||||
"/api/memories/999",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_endpoint_returns_plaintext(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(
|
||||
id=7, content="my secret value", is_sensitive=False,
|
||||
vault_path=None, encrypted_content=None,
|
||||
)
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/7/secret",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 7
|
||||
assert data["content"] == "my secret value"
|
||||
assert data["source"] == "plaintext"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_endpoint_returns_vault_content(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = _make_memory_row(
|
||||
id=8, content="[REDACTED]", is_sensitive=True,
|
||||
vault_path="claude-memory/testuser/mem-8", encrypted_content=None,
|
||||
)
|
||||
|
||||
with patch("claude_memory.api.app.get_secret", return_value="actual-secret-from-vault"):
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/8/secret",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["content"] == "actual-secret-from-vault"
|
||||
assert data["source"] == "vault"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secret_endpoint_nonexistent_returns_404(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.return_value = None
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/999/secret",
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_memories(client):
|
||||
ac, conn, app_mod = client
|
||||
conn.fetchrow.side_effect = [
|
||||
_make_memory_row(id=100, category="facts", importance=0.5),
|
||||
_make_memory_row(id=101, category="preferences", importance=0.8),
|
||||
]
|
||||
|
||||
async with ac:
|
||||
resp = await ac.post(
|
||||
"/api/memories/import",
|
||||
json=[
|
||||
{"content": "fact one", "category": "facts"},
|
||||
{"content": "pref one", "category": "preferences", "importance": 0.8},
|
||||
],
|
||||
headers={"Authorization": "Bearer test-key"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["id"] == 100
|
||||
assert data[1]["id"] == 101
|
||||
87
tests/test_auth.py
Normal file
87
tests/test_auth.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Tests for multi-user authentication."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def _reload_auth(env_vars: dict):
|
||||
"""Reload the auth module with given environment variables."""
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
# Clear existing env vars that might interfere
|
||||
for key in ("API_KEY", "API_KEYS"):
|
||||
os.environ.pop(key, None)
|
||||
for key, val in env_vars.items():
|
||||
os.environ[key] = val
|
||||
|
||||
import claude_memory.api.auth as auth_mod
|
||||
|
||||
importlib.reload(auth_mod)
|
||||
return auth_mod
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_api_key_maps_to_default():
|
||||
auth = _reload_auth({"API_KEY": "test-key-123", "API_KEYS": ""})
|
||||
user = await auth.get_current_user(authorization="Bearer test-key-123")
|
||||
assert user.user_id == "default"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_api_keys_maps_to_correct_user():
|
||||
auth = _reload_auth({
|
||||
"API_KEYS": '{"viktor": "key-viktor", "alice": "key-alice"}',
|
||||
"API_KEY": "",
|
||||
})
|
||||
user_v = await auth.get_current_user(authorization="Bearer key-viktor")
|
||||
assert user_v.user_id == "viktor"
|
||||
|
||||
user_a = await auth.get_current_user(authorization="Bearer key-alice")
|
||||
assert user_a.user_id == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_key_returns_401():
|
||||
auth = _reload_auth({"API_KEY": "valid-key", "API_KEYS": ""})
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth.get_current_user(authorization="Bearer wrong-key")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_bearer_prefix_still_works():
|
||||
auth = _reload_auth({"API_KEY": "my-key", "API_KEYS": ""})
|
||||
# Without Bearer prefix, removeprefix("Bearer ") returns "my-key" unchanged
|
||||
# so the raw token still matches the key
|
||||
user = await auth.get_current_user(authorization="my-key")
|
||||
assert user.user_id == "default"
|
||||
|
||||
# With proper Bearer prefix it also works
|
||||
user = await auth.get_current_user(authorization="Bearer my-key")
|
||||
assert user.user_id == "default"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_authorization_header_raises_422():
|
||||
"""FastAPI raises 422 when required Header is missing.
|
||||
This is tested via the app integration, not the function directly,
|
||||
since FastAPI handles the missing header before the function runs.
|
||||
"""
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Need to reload with valid keys so the app can start
|
||||
_reload_auth({"API_KEY": "test-key", "API_KEYS": ""})
|
||||
|
||||
# Import app after auth is configured
|
||||
import claude_memory.api.app as app_mod
|
||||
|
||||
importlib.reload(app_mod)
|
||||
|
||||
transport = ASGITransport(app=app_mod.app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
# Skip lifespan since we don't have a real DB
|
||||
resp = await client.get("/api/memories")
|
||||
assert resp.status_code == 422
|
||||
132
tests/test_credential_detector.py
Normal file
132
tests/test_credential_detector.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Tests for credential detection and redaction."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_memory.credential_detector import (
|
||||
DetectedCredential,
|
||||
detect_credentials,
|
||||
is_sensitive,
|
||||
redact_credentials,
|
||||
)
|
||||
|
||||
|
||||
class TestDetectCredentials:
|
||||
def test_detect_postgres_connection_string(self):
|
||||
text = "db_url = postgres://user:pass@localhost:5432/mydb"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 1
|
||||
assert creds[0].type == "connection_string"
|
||||
assert creds[0].confidence == 0.9
|
||||
assert "postgres://" in creds[0].matched_text
|
||||
|
||||
def test_detect_password_assignment(self):
|
||||
text = 'password = "my_super_secret_pw"'
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "password" in types
|
||||
|
||||
def test_detect_api_key(self):
|
||||
text = "api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "api_key" in types
|
||||
|
||||
def test_detect_private_key(self):
|
||||
text = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3xfn/ygWep4PAtGoSo\n-----END RSA PRIVATE KEY-----"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 1
|
||||
assert creds[0].type == "private_key"
|
||||
assert creds[0].confidence == 0.95
|
||||
|
||||
def test_detect_bearer_token(self):
|
||||
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkw"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "bearer_token" in types
|
||||
|
||||
def test_detect_aws_key(self):
|
||||
text = "aws_access_key_id = AKIAIOSFODNN7EXAMPLE"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "aws_key" in types
|
||||
|
||||
def test_detect_github_token(self):
|
||||
text = "GITHUB_TOKEN=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn"
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) >= 1
|
||||
types = [c.type for c in creds]
|
||||
assert "github_token" in types
|
||||
|
||||
def test_no_false_positives_on_normal_text(self):
|
||||
text = "This is a normal paragraph about programming. It discusses variables, functions, and classes."
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 0
|
||||
|
||||
def test_no_false_positives_on_short_password(self):
|
||||
# password values shorter than 8 chars should not match
|
||||
text = 'password = "short"'
|
||||
creds = detect_credentials(text)
|
||||
assert len(creds) == 0
|
||||
|
||||
def test_min_confidence_filtering(self):
|
||||
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
|
||||
all_creds = detect_credentials(text, min_confidence=0.5)
|
||||
high_creds = detect_credentials(text, min_confidence=0.9)
|
||||
assert len(all_creds) >= len(high_creds)
|
||||
|
||||
def test_overlapping_matches_keep_highest_confidence(self):
|
||||
# A text that could match both token and generic_secret
|
||||
text = 'secret = "abcdefghijklmnopqrstuvwxyz1234567890"'
|
||||
creds = detect_credentials(text, min_confidence=0.5)
|
||||
# Should not have overlapping ranges for the same span
|
||||
for i, c1 in enumerate(creds):
|
||||
for c2 in creds[i + 1:]:
|
||||
# No credential should be fully contained within another
|
||||
assert not (c1.start <= c2.start and c1.end >= c2.end)
|
||||
|
||||
|
||||
class TestRedactCredentials:
|
||||
def test_redaction_replaces_with_marker(self):
|
||||
text = "db_url = postgres://user:pass@localhost:5432/mydb"
|
||||
creds = detect_credentials(text)
|
||||
redacted = redact_credentials(text, creds)
|
||||
assert "[REDACTED:connection_string]" in redacted
|
||||
assert "postgres://" not in redacted
|
||||
|
||||
def test_redaction_preserves_surrounding_text(self):
|
||||
text = "before postgres://user:pass@localhost/db after"
|
||||
creds = detect_credentials(text)
|
||||
redacted = redact_credentials(text, creds)
|
||||
assert redacted.startswith("before ")
|
||||
assert redacted.endswith(" after")
|
||||
|
||||
def test_redaction_no_credentials(self):
|
||||
text = "nothing sensitive here"
|
||||
redacted = redact_credentials(text, [])
|
||||
assert redacted == text
|
||||
|
||||
def test_redaction_multiple_credentials(self):
|
||||
text = 'password = "mysecretpw123" and api_key = ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
|
||||
creds = detect_credentials(text)
|
||||
redacted = redact_credentials(text, creds)
|
||||
assert "mysecretpw123" not in redacted
|
||||
assert "[REDACTED:" in redacted
|
||||
|
||||
|
||||
class TestIsSensitive:
|
||||
def test_sensitive_text(self):
|
||||
assert is_sensitive("password = supersecretvalue123")
|
||||
|
||||
def test_non_sensitive_text(self):
|
||||
assert not is_sensitive("just a normal log message")
|
||||
|
||||
def test_respects_min_confidence(self):
|
||||
text = 'secret = "abcdefghijklmnopqrstuvwxyz"'
|
||||
# Low confidence should detect
|
||||
assert is_sensitive(text, min_confidence=0.5)
|
||||
# Very high confidence should not detect generic_secret
|
||||
assert not is_sensitive(text, min_confidence=0.95)
|
||||
134
tests/test_crypto.py
Normal file
134
tests/test_crypto.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Tests for AES-256-GCM encryption module."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_memory.crypto import (
|
||||
ENCRYPTION_KEY_ENV,
|
||||
decrypt,
|
||||
decrypt_b64,
|
||||
encrypt,
|
||||
encrypt_b64,
|
||||
is_encryption_configured,
|
||||
)
|
||||
|
||||
# A valid 32-byte hex key for testing
|
||||
TEST_HEX_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
TEST_PASSPHRASE = "my-test-passphrase"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hex_key_env(monkeypatch):
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_HEX_KEY)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def passphrase_env(monkeypatch):
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, TEST_PASSPHRASE)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_key_env(monkeypatch):
|
||||
monkeypatch.delenv(ENCRYPTION_KEY_ENV, raising=False)
|
||||
|
||||
|
||||
class TestEncryptionConfigured:
|
||||
def test_configured_with_hex_key(self, hex_key_env):
|
||||
assert is_encryption_configured() is True
|
||||
|
||||
def test_configured_with_passphrase(self, passphrase_env):
|
||||
assert is_encryption_configured() is True
|
||||
|
||||
def test_not_configured_without_env(self, no_key_env):
|
||||
assert is_encryption_configured() is False
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
def test_roundtrip_with_hex_key(self, hex_key_env):
|
||||
plaintext = "Hello, this is a secret message!"
|
||||
encrypted = encrypt(plaintext)
|
||||
decrypted = decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_roundtrip_with_passphrase(self, passphrase_env):
|
||||
plaintext = "Another secret message with passphrase key"
|
||||
encrypted = encrypt(plaintext)
|
||||
decrypted = decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_different_plaintexts_produce_different_ciphertexts(self, hex_key_env):
|
||||
ct1 = encrypt("message one")
|
||||
ct2 = encrypt("message two")
|
||||
assert ct1 != ct2
|
||||
|
||||
def test_same_plaintext_produces_different_ciphertexts(self, hex_key_env):
|
||||
"""Due to random nonce, encrypting the same text twice gives different results."""
|
||||
ct1 = encrypt("same message")
|
||||
ct2 = encrypt("same message")
|
||||
assert ct1 != ct2
|
||||
|
||||
def test_missing_key_raises_on_encrypt(self, no_key_env):
|
||||
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
|
||||
encrypt("test")
|
||||
|
||||
def test_missing_key_raises_on_decrypt(self, no_key_env):
|
||||
with pytest.raises(RuntimeError, match=ENCRYPTION_KEY_ENV):
|
||||
decrypt(b"\x00" * 28)
|
||||
|
||||
def test_decrypt_with_wrong_key_fails(self, hex_key_env, monkeypatch):
|
||||
plaintext = "secret data"
|
||||
encrypted = encrypt(plaintext)
|
||||
|
||||
# Change to a different key
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "ff" * 32)
|
||||
with pytest.raises(Exception):
|
||||
decrypt(encrypted)
|
||||
|
||||
def test_encrypted_data_format(self, hex_key_env):
|
||||
"""Encrypted data should be at least 12 (nonce) + 16 (tag) bytes."""
|
||||
encrypted = encrypt("x")
|
||||
assert len(encrypted) >= 28 # 12 nonce + 1 plaintext + 16 tag = 29 minimum
|
||||
|
||||
def test_unicode_roundtrip(self, hex_key_env):
|
||||
plaintext = "Unicode test: cafe\u0301, \u00fc\u00f6\u00e4, \U0001f512"
|
||||
decrypted = decrypt(encrypt(plaintext))
|
||||
assert decrypted == plaintext
|
||||
|
||||
|
||||
class TestBase64Variants:
|
||||
def test_b64_roundtrip(self, hex_key_env):
|
||||
plaintext = "base64 test message"
|
||||
encrypted_b64 = encrypt_b64(plaintext)
|
||||
assert isinstance(encrypted_b64, str)
|
||||
decrypted = decrypt_b64(encrypted_b64)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_b64_output_is_valid_base64(self, hex_key_env):
|
||||
import base64
|
||||
encrypted_b64 = encrypt_b64("test")
|
||||
# Should not raise
|
||||
decoded = base64.b64decode(encrypted_b64)
|
||||
assert len(decoded) >= 28
|
||||
|
||||
|
||||
class TestKeyDerivation:
|
||||
def test_hex_key_used_directly(self, hex_key_env):
|
||||
"""A valid 64-char hex string should be used as-is (32 bytes)."""
|
||||
ct = encrypt("test")
|
||||
pt = decrypt(ct)
|
||||
assert pt == "test"
|
||||
|
||||
def test_passphrase_derived_via_sha256(self, passphrase_env):
|
||||
"""Non-hex strings should be derived via SHA-256."""
|
||||
ct = encrypt("test")
|
||||
pt = decrypt(ct)
|
||||
assert pt == "test"
|
||||
|
||||
def test_short_hex_treated_as_passphrase(self, monkeypatch):
|
||||
"""Hex string that's not exactly 32 bytes should be treated as passphrase."""
|
||||
monkeypatch.setenv(ENCRYPTION_KEY_ENV, "abcd1234")
|
||||
ct = encrypt("test")
|
||||
pt = decrypt(ct)
|
||||
assert pt == "test"
|
||||
342
tests/test_mcp_server.py
Normal file
342
tests/test_mcp_server.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
"""Tests for the Claude Memory MCP server."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Force SQLite fallback mode for all tests
|
||||
os.environ.pop("MEMORY_API_KEY", None)
|
||||
os.environ.pop("CLAUDE_MEMORY_API_KEY", None)
|
||||
|
||||
# Add src to path so we can import without installing
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from claude_memory.mcp_server import MemoryServer, TOOLS, SERVER_NAME, SERVER_VERSION, PROTOCOL_VERSION
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server(tmp_path):
|
||||
"""Create a MemoryServer with a temporary SQLite database."""
|
||||
db_path = str(tmp_path / "test_memory.db")
|
||||
srv = MemoryServer(sqlite_db_path=db_path)
|
||||
yield srv
|
||||
if srv.sqlite_conn:
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
|
||||
class TestSQLiteInit:
|
||||
def test_creates_database(self, tmp_path):
|
||||
db_path = str(tmp_path / "sub" / "test.db")
|
||||
srv = MemoryServer(sqlite_db_path=db_path)
|
||||
assert os.path.exists(db_path)
|
||||
# Verify tables exist
|
||||
cursor = srv.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories'")
|
||||
assert cursor.fetchone() is not None
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
def test_creates_fts_table(self, tmp_path):
|
||||
db_path = str(tmp_path / "test.db")
|
||||
srv = MemoryServer(sqlite_db_path=db_path)
|
||||
cursor = srv.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='memories_fts'")
|
||||
assert cursor.fetchone() is not None
|
||||
srv.sqlite_conn.close()
|
||||
|
||||
|
||||
class TestMemoryStore:
|
||||
def test_store_basic(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "User prefers dark mode",
|
||||
"expanded_keywords": "dark mode theme preference ui",
|
||||
})
|
||||
assert "Stored memory #1" in result
|
||||
assert "facts" in result
|
||||
|
||||
def test_store_with_category(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "User likes Python",
|
||||
"category": "preferences",
|
||||
"expanded_keywords": "python programming language preference",
|
||||
})
|
||||
assert "preferences" in result
|
||||
|
||||
def test_store_with_importance(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "Critical info",
|
||||
"importance": 0.9,
|
||||
"expanded_keywords": "critical important info",
|
||||
})
|
||||
assert "0.9" in result
|
||||
|
||||
def test_store_requires_content(self, server):
|
||||
with pytest.raises(ValueError, match="content is required"):
|
||||
server.memory_store({"expanded_keywords": "test"})
|
||||
|
||||
def test_store_force_sensitive(self, server):
|
||||
result = server.memory_store({
|
||||
"content": "API key: sk-1234",
|
||||
"force_sensitive": True,
|
||||
"expanded_keywords": "api key secret credential",
|
||||
})
|
||||
assert "Stored memory #1" in result
|
||||
# Verify is_sensitive flag is set
|
||||
cursor = server.sqlite_conn.cursor()
|
||||
cursor.execute("SELECT is_sensitive FROM memories WHERE id = 1")
|
||||
row = cursor.fetchone()
|
||||
assert row["is_sensitive"] == 1
|
||||
|
||||
|
||||
class TestMemoryRecall:
|
||||
def test_recall_finds_memory(self, server):
|
||||
server.memory_store({
|
||||
"content": "User works at Acme Corp",
|
||||
"expanded_keywords": "acme corp company work employer",
|
||||
})
|
||||
result = server.memory_recall({
|
||||
"context": "work",
|
||||
"expanded_query": "company employer job",
|
||||
})
|
||||
assert "Acme Corp" in result
|
||||
assert "Found 1 memories" in result
|
||||
|
||||
def test_recall_no_results(self, server):
|
||||
result = server.memory_recall({
|
||||
"context": "nonexistent topic",
|
||||
"expanded_query": "nothing here at all",
|
||||
})
|
||||
assert "No memories found" in result
|
||||
|
||||
def test_recall_with_category_filter(self, server):
|
||||
server.memory_store({
|
||||
"content": "User prefers vim",
|
||||
"category": "preferences",
|
||||
"expanded_keywords": "vim editor preference text",
|
||||
})
|
||||
server.memory_store({
|
||||
"content": "Project uses React",
|
||||
"category": "projects",
|
||||
"expanded_keywords": "react project frontend framework",
|
||||
})
|
||||
result = server.memory_recall({
|
||||
"context": "preferences",
|
||||
"expanded_query": "vim editor",
|
||||
"category": "preferences",
|
||||
})
|
||||
assert "vim" in result
|
||||
assert "React" not in result
|
||||
|
||||
def test_recall_requires_context(self, server):
|
||||
with pytest.raises(ValueError, match="context is required"):
|
||||
server.memory_recall({"expanded_query": "test"})
|
||||
|
||||
|
||||
class TestMemoryList:
|
||||
def test_list_empty(self, server):
|
||||
result = server.memory_list({})
|
||||
assert "No memories stored yet" in result
|
||||
|
||||
def test_list_with_memories(self, server):
|
||||
server.memory_store({
|
||||
"content": "Memory one",
|
||||
"expanded_keywords": "one first test",
|
||||
})
|
||||
server.memory_store({
|
||||
"content": "Memory two",
|
||||
"expanded_keywords": "two second test",
|
||||
})
|
||||
result = server.memory_list({})
|
||||
assert "Memory one" in result
|
||||
assert "Memory two" in result
|
||||
assert "2 shown" in result
|
||||
|
||||
def test_list_with_category(self, server):
|
||||
server.memory_store({
|
||||
"content": "A fact",
|
||||
"category": "facts",
|
||||
"expanded_keywords": "fact test",
|
||||
})
|
||||
server.memory_store({
|
||||
"content": "A preference",
|
||||
"category": "preferences",
|
||||
"expanded_keywords": "preference test",
|
||||
})
|
||||
result = server.memory_list({"category": "facts"})
|
||||
assert "A fact" in result
|
||||
assert "A preference" not in result
|
||||
|
||||
def test_list_empty_category(self, server):
|
||||
result = server.memory_list({"category": "projects"})
|
||||
assert "No memories in category 'projects'" in result
|
||||
|
||||
def test_list_respects_limit(self, server):
|
||||
for i in range(5):
|
||||
server.memory_store({
|
||||
"content": f"Memory {i}",
|
||||
"expanded_keywords": f"memory number {i}",
|
||||
})
|
||||
result = server.memory_list({"limit": 2})
|
||||
assert "2 shown" in result
|
||||
|
||||
|
||||
class TestMemoryDelete:
|
||||
def test_delete_existing(self, server):
|
||||
server.memory_store({
|
||||
"content": "To be deleted",
|
||||
"expanded_keywords": "delete remove test",
|
||||
})
|
||||
result = server.memory_delete({"id": 1})
|
||||
assert "Deleted memory #1" in result
|
||||
assert "To be deleted" in result
|
||||
|
||||
def test_delete_nonexistent(self, server):
|
||||
result = server.memory_delete({"id": 999})
|
||||
assert "not found" in result
|
||||
|
||||
def test_delete_requires_id(self, server):
|
||||
with pytest.raises(ValueError, match="id is required"):
|
||||
server.memory_delete({})
|
||||
|
||||
|
||||
class TestSecretGet:
|
||||
def test_secret_get_sensitive(self, server):
|
||||
server.memory_store({
|
||||
"content": "secret password 12345",
|
||||
"force_sensitive": True,
|
||||
"expanded_keywords": "password secret credential",
|
||||
})
|
||||
result = server.secret_get({"id": 1})
|
||||
assert "secret password 12345" in result
|
||||
|
||||
def test_secret_get_not_sensitive(self, server):
|
||||
server.memory_store({
|
||||
"content": "public info",
|
||||
"expanded_keywords": "public info test",
|
||||
})
|
||||
result = server.secret_get({"id": 1})
|
||||
assert "not marked as sensitive" in result
|
||||
|
||||
def test_secret_get_nonexistent(self, server):
|
||||
result = server.secret_get({"id": 999})
|
||||
assert "not found" in result
|
||||
|
||||
def test_secret_get_requires_id(self, server):
|
||||
with pytest.raises(ValueError, match="id is required"):
|
||||
server.secret_get({})
|
||||
|
||||
|
||||
class TestMCPProtocol:
|
||||
def test_handle_initialize(self, server):
|
||||
result = server.handle_initialize({})
|
||||
assert result["protocolVersion"] == PROTOCOL_VERSION
|
||||
assert result["serverInfo"]["name"] == SERVER_NAME
|
||||
assert result["serverInfo"]["version"] == SERVER_VERSION
|
||||
assert "tools" in result["capabilities"]
|
||||
|
||||
def test_handle_tools_list(self, server):
|
||||
result = server.handle_tools_list({})
|
||||
tools = result["tools"]
|
||||
assert len(tools) == 5
|
||||
names = {t["name"] for t in tools}
|
||||
assert names == {"memory_store", "memory_recall", "memory_list", "memory_delete", "secret_get"}
|
||||
|
||||
def test_handle_tools_call_store(self, server):
|
||||
result = server.handle_tools_call({
|
||||
"name": "memory_store",
|
||||
"arguments": {
|
||||
"content": "test memory",
|
||||
"expanded_keywords": "test memory keywords",
|
||||
},
|
||||
})
|
||||
assert not result.get("isError", False)
|
||||
assert "Stored memory" in result["content"][0]["text"]
|
||||
|
||||
def test_handle_tools_call_unknown(self, server):
|
||||
result = server.handle_tools_call({
|
||||
"name": "nonexistent_tool",
|
||||
"arguments": {},
|
||||
})
|
||||
assert result["isError"] is True
|
||||
assert "Unknown tool" in result["content"][0]["text"]
|
||||
|
||||
def test_handle_tools_call_error(self, server):
|
||||
result = server.handle_tools_call({
|
||||
"name": "memory_store",
|
||||
"arguments": {}, # missing content
|
||||
})
|
||||
assert result["isError"] is True
|
||||
assert "Error" in result["content"][0]["text"]
|
||||
|
||||
|
||||
class TestProcessMessage:
|
||||
def test_initialize(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
assert response["jsonrpc"] == "2.0"
|
||||
assert response["id"] == 1
|
||||
assert "result" in response
|
||||
assert response["result"]["serverInfo"]["name"] == SERVER_NAME
|
||||
|
||||
def test_tools_list(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {},
|
||||
})
|
||||
assert "result" in response
|
||||
assert len(response["result"]["tools"]) == 5
|
||||
|
||||
def test_tools_call(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "memory_store",
|
||||
"arguments": {
|
||||
"content": "via process_message",
|
||||
"expanded_keywords": "process message test",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert "result" in response
|
||||
assert "Stored memory" in response["result"]["content"][0]["text"]
|
||||
|
||||
def test_unknown_method(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "unknown/method",
|
||||
"params": {},
|
||||
})
|
||||
assert "error" in response
|
||||
assert response["error"]["code"] == -32601
|
||||
assert "Method not found" in response["error"]["message"]
|
||||
|
||||
def test_notification_no_id(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
"params": {},
|
||||
})
|
||||
assert response is None
|
||||
|
||||
def test_jsonrpc_response_format(self, server):
|
||||
response = server.process_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
# Verify it's valid JSON when serialized
|
||||
serialized = json.dumps(response)
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["jsonrpc"] == "2.0"
|
||||
assert parsed["id"] == 5
|
||||
154
tests/test_vault_client.py
Normal file
154
tests/test_vault_client.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Tests for Vault KV v2 client with mocked urllib."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_memory.vault_client import VaultClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vault_env(monkeypatch):
|
||||
monkeypatch.setenv("VAULT_ADDR", "http://vault.example.com:8200")
|
||||
monkeypatch.setenv("VAULT_TOKEN", "s.testtoken123")
|
||||
|
||||
|
||||
class TestVaultClientInit:
|
||||
def test_missing_addr_raises_value_error(self, monkeypatch):
|
||||
monkeypatch.delenv("VAULT_ADDR", raising=False)
|
||||
monkeypatch.delenv("VAULT_TOKEN", raising=False)
|
||||
with pytest.raises(ValueError, match="Vault address not configured"):
|
||||
VaultClient()
|
||||
|
||||
def test_init_with_explicit_args(self):
|
||||
client = VaultClient(addr="http://localhost:8200", token="mytoken")
|
||||
assert client.addr == "http://localhost:8200"
|
||||
assert client.token == "mytoken"
|
||||
assert client.mount == "secret"
|
||||
|
||||
def test_init_from_env(self, vault_env):
|
||||
client = VaultClient()
|
||||
assert client.addr == "http://vault.example.com:8200"
|
||||
assert client.token == "s.testtoken123"
|
||||
|
||||
def test_addr_trailing_slash_stripped(self):
|
||||
client = VaultClient(addr="http://localhost:8200/", token="t")
|
||||
assert client.addr == "http://localhost:8200"
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", mock_open(read_data="fake-jwt-token"))
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_kubernetes_sa_token_auto_detection(self, mock_urlopen, mock_exists, monkeypatch):
|
||||
monkeypatch.setenv("VAULT_ADDR", "http://vault:8200")
|
||||
monkeypatch.delenv("VAULT_TOKEN", raising=False)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"auth": {"client_token": "s.k8s-token-abc"}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
assert client.token == "s.k8s-token-abc"
|
||||
|
||||
|
||||
class TestVaultRead:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_read_secret_returns_data(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"data": {"data": {"username": "admin", "password": "secret"}}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
result = client.read("myapp/config")
|
||||
assert result == {"username": "admin", "password": "secret"}
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_read_returns_none_for_404(self, mock_urlopen, vault_env):
|
||||
import urllib.error
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError(
|
||||
url="http://vault:8200/v1/secret/data/missing",
|
||||
code=404,
|
||||
msg="Not Found",
|
||||
hdrs={},
|
||||
fp=BytesIO(b""),
|
||||
)
|
||||
|
||||
client = VaultClient()
|
||||
result = client.read("missing/path")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestVaultWrite:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_write_secret_sends_correct_request(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"data": {"created_time": "2024-01-01T00:00:00Z", "version": 1}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
result = client.write("myapp/config", {"key": "value"})
|
||||
|
||||
# Verify the request was made with correct data
|
||||
call_args = mock_urlopen.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.full_url == "http://vault.example.com:8200/v1/secret/data/myapp/config"
|
||||
assert request.method == "POST"
|
||||
body = json.loads(request.data.decode())
|
||||
assert body == {"data": {"key": "value"}}
|
||||
|
||||
|
||||
class TestVaultDelete:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_delete_returns_true_on_success(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"{}"
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
assert client.delete("myapp/config") is True
|
||||
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_delete_returns_false_on_error(self, mock_urlopen, vault_env):
|
||||
import urllib.error
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError(
|
||||
url="http://vault:8200/v1/secret/data/missing",
|
||||
code=500,
|
||||
msg="Internal Server Error",
|
||||
hdrs={},
|
||||
fp=BytesIO(b"error"),
|
||||
)
|
||||
|
||||
client = VaultClient()
|
||||
assert client.delete("missing/path") is False
|
||||
|
||||
|
||||
class TestVaultListSecrets:
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_list_secrets(self, mock_urlopen, vault_env):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps({
|
||||
"data": {"keys": ["secret1", "secret2/"]}
|
||||
}).encode()
|
||||
mock_response.__enter__ = lambda s: s
|
||||
mock_response.__exit__ = MagicMock(return_value=False)
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
client = VaultClient()
|
||||
result = client.list_secrets("myapp")
|
||||
assert result == ["secret1", "secret2/"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue