Initial commit: event-driven UK payslip ingest service
Extracted from the /home/wizard/code monorepo into its own repo so Woodpecker CI
can watch it. Identical content to /home/wizard/code commit e426028. See
README.md for overview, env vars, and Paperless workflow config.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
commit 57484619c1

27 changed files with 2878 additions and 0 deletions
.gitignore (vendored, new file, +10)
@@ -0,0 +1,10 @@
__pycache__/
*.pyc
.venv/
.pytest_cache/
.mypy_cache/
.ruff_cache/
dist/
*.egg-info/
.coverage
htmlcov/
.woodpecker.yml (new file, +30)
@@ -0,0 +1,30 @@
when:
  event: push
  path: "payslip-ingest/**"

clone:
  git:
    image: woodpeckerci/plugin-git
    settings:
      attempts: 5
      backoff: 10s

steps:
  - name: build-and-push
    image: woodpeckerci/plugin-docker-buildx
    settings:
      username: "viktorbarzin"
      password:
        from_secret: dockerhub-pat
      repo:
        - registry.viktorbarzin.me/payslip-ingest
      logins:
        - registry: registry.viktorbarzin.me
          username: viktorbarzin
          password:
            from_secret: registry-password
      dockerfile: payslip-ingest/Dockerfile
      context: payslip-ingest
      auto_tag: true
      platforms:
        - linux/amd64
Dockerfile (new file, +33)
@@ -0,0 +1,33 @@
FROM python:3.12-slim AS builder

ENV POETRY_VERSION=1.8.4 \
    POETRY_VIRTUALENVS_IN_PROJECT=true \
    PIP_NO_CACHE_DIR=1

RUN pip install --no-cache-dir "poetry==${POETRY_VERSION}"

WORKDIR /app
COPY pyproject.toml poetry.lock* README.md ./
RUN poetry install --only main --no-root

COPY payslip_ingest ./payslip_ingest
COPY alembic ./alembic
COPY alembic.ini ./alembic.ini
RUN poetry install --only main


FROM python:3.12-slim

WORKDIR /app

RUN useradd --system --uid 10001 --home /app --shell /usr/sbin/nologin payslip

COPY --from=builder --chown=payslip:payslip /app /app

ENV PATH="/app/.venv/bin:${PATH}" \
    PYTHONUNBUFFERED=1

EXPOSE 8080
USER payslip
ENTRYPOINT ["python", "-m", "payslip_ingest"]
CMD ["serve"]
README.md (new file, +78)
@@ -0,0 +1,78 @@
# payslip-ingest

Event-driven UK payslip ingest: Paperless-ngx fires a webhook when a document
is tagged `payslip`; this service fetches the PDF, calls `claude-agent-service`
to extract structured fields, and upserts into Postgres keyed by
`paperless_doc_id` (idempotent). A CLI `backfill` mode enumerates every
existing payslip in Paperless for initial population.

## Local dev

```bash
poetry install
poetry run pytest -q
poetry run mypy .
poetry run ruff check .

# Smoke-test extraction against a real PDF (no DB writes):
export CLAUDE_AGENT_URL=http://claude-agent-service.claude-agent.svc.cluster.local:8080
export CLAUDE_AGENT_BEARER_TOKEN=...
poetry run python -m payslip_ingest extract-one /tmp/sample.pdf
```

## Env vars

| Variable | Purpose |
|---|---|
| `PAPERLESS_URL` | Paperless-ngx base URL (e.g. `https://paperless.viktorbarzin.me`) |
| `PAPERLESS_API_TOKEN` | Paperless API token (User → My Profile → API Auth Token) |
| `CLAUDE_AGENT_URL` | claude-agent-service URL (`http://claude-agent-service.claude-agent.svc.cluster.local:8080`) |
| `CLAUDE_AGENT_BEARER_TOKEN` | Vault `secret/claude-agent-service` → `api_bearer_token` |
| `DB_CONNECTION_STRING` | SQLAlchemy async URL: `postgresql+asyncpg://user:pass@host/db` |
| `WEBHOOK_BEARER_TOKEN` | Shared secret Paperless sends in `Authorization: Bearer ...` |

## Paperless workflow configuration

In Paperless-ngx, create a workflow:

- **Name**: `payslip-ingest`
- **Trigger**: Document Added, matching tag `payslip`
- **Action type**: Webhook
- **URL**: `http://payslip-ingest.payslip-ingest.svc.cluster.local:8080/webhook`
- **Method**: `POST`
- **Headers**:
  - `Authorization: Bearer <WEBHOOK_BEARER_TOKEN>`
  - `Content-Type: application/json`
- **Body** (template):
  ```json
  {"document_id": {{ document_id }}}
  ```

## Deployment

Ship to the `payslip-ingest` namespace. The service serializes incoming
webhooks onto an in-process queue so it never collides with
`claude-agent-service`'s single-job lock.

Run the initial backfill once the deployment is live:

```bash
kubectl -n payslip-ingest create job \
  --from=deployment/payslip-ingest \
  payslip-backfill-$(date +%s) \
  -- python -m payslip_ingest backfill --all
```

## Architecture notes

- `extract-one` never touches the DB — safe for ad-hoc re-extraction of PDFs on disk.
- The `backfill` command is idempotent (skips rows whose `paperless_doc_id`
  already exists) so it can be re-run freely.
- Totals validation is a best-effort sanity check; mismatches are stored with
  `validated=false` and `raw_extraction` retained for manual review, rather
  than rejected.
- The agent service is **single-threaded**. The webhook handler enqueues and
  returns 202; a single background worker drains the queue one at a time and
  absorbs 409-busy responses from the agent with retry-with-backoff.
- New agent prompt lives at `.claude/agents/payslip-extractor` in the `infra`
  repo — this is a separate deliverable (see TODOs).
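The idempotency contract above (upsert keyed on `paperless_doc_id`) is implemented in the processor module, which this excerpt does not reach. As a minimal stand-alone illustration only — not the service's actual code — the same `ON CONFLICT ... DO UPDATE` shape can be shown with stdlib SQLite, so a webhook re-delivery updates the existing row instead of duplicating it:

```python
import sqlite3

# Illustration of the upsert keyed on paperless_doc_id: a webhook re-delivery
# for the same document updates the row in place rather than inserting a second one.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE payslip ("
    " id INTEGER PRIMARY KEY,"
    " paperless_doc_id INTEGER NOT NULL UNIQUE,"
    " gross_pay NUMERIC NOT NULL)"
)

def upsert(doc_id: int, gross: float) -> None:
    conn.execute(
        "INSERT INTO payslip (paperless_doc_id, gross_pay) VALUES (?, ?)"
        " ON CONFLICT(paperless_doc_id) DO UPDATE SET gross_pay = excluded.gross_pay",
        (doc_id, gross),
    )

upsert(123, 5000.25)
upsert(123, 5100.50)  # re-delivery with corrected figures
rows = conn.execute("SELECT paperless_doc_id, gross_pay FROM payslip").fetchall()
# rows == [(123, 5100.5)]: one row, latest values
```

The production path goes through SQLAlchemy's PostgreSQL dialect, but the conflict-target semantics are the same.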
alembic.ini (new file, +37)
@@ -0,0 +1,37 @@
[alembic]
script_location = alembic
sqlalchemy.url = placeholder

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
alembic/env.py (new file, +57)
@@ -0,0 +1,57 @@
import asyncio
import os
from logging.config import fileConfig

from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context
from payslip_ingest.db import SCHEMA_NAME, Base

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

db_url = os.environ.get("DB_CONNECTION_STRING")
if db_url:
    config.set_main_option("sqlalchemy.url", db_url)

target_metadata = Base.metadata


def do_run_migrations(connection: Connection) -> None:
    context.configure(
        connection=connection,
        target_metadata=target_metadata,
        version_table_schema=SCHEMA_NAME,
        include_schemas=True,
    )
    with context.begin_transaction():
        context.run_migrations()


async def run_migrations_online() -> None:
    configuration = config.get_section(config.config_ini_section, {})
    connectable = async_engine_from_config(configuration, prefix="sqlalchemy.")
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


def run_migrations_offline() -> None:
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        version_table_schema=SCHEMA_NAME,
        include_schemas=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    asyncio.run(run_migrations_online())
alembic/script.py.mako (new file, +25)
@@ -0,0 +1,25 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
alembic/versions/0001_initial.py (new file, +72)
@@ -0,0 +1,72 @@
"""initial schema

Revision ID: 0001
Revises:
Create Date: 2026-04-18 00:00:00.000000

"""
from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

revision: str = "0001"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

SCHEMA = "payslip_ingest"


def upgrade() -> None:
    op.execute(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA}")

    op.create_table(
        "payslip",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("paperless_doc_id", sa.Integer(), nullable=False, unique=True),
        sa.Column("pay_date", sa.Date(), nullable=False),
        sa.Column("pay_period_start", sa.Date(), nullable=True),
        sa.Column("pay_period_end", sa.Date(), nullable=True),
        sa.Column("employer", sa.Text(), nullable=True),
        sa.Column("currency", sa.CHAR(3), nullable=False, server_default="GBP"),
        sa.Column("gross_pay", sa.Numeric(12, 2), nullable=False),
        sa.Column("income_tax", sa.Numeric(12, 2), nullable=False, server_default=sa.text("0")),
        sa.Column(
            "national_insurance", sa.Numeric(12, 2), nullable=False, server_default=sa.text("0")
        ),
        sa.Column(
            "pension_employee", sa.Numeric(12, 2), nullable=False, server_default=sa.text("0")
        ),
        sa.Column(
            "pension_employer", sa.Numeric(12, 2), nullable=False, server_default=sa.text("0")
        ),
        sa.Column("student_loan", sa.Numeric(12, 2), nullable=False, server_default=sa.text("0")),
        sa.Column("other_deductions", postgresql.JSONB(), nullable=True),
        sa.Column("net_pay", sa.Numeric(12, 2), nullable=False),
        sa.Column("tax_year", sa.Text(), nullable=False),
        sa.Column("raw_extraction", postgresql.JSONB(), nullable=False),
        sa.Column("validated", sa.Boolean(), nullable=False, server_default=sa.text("true")),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        schema=SCHEMA,
    )
    op.create_index(
        "idx_payslip_pay_date", "payslip", ["pay_date"], schema=SCHEMA
    )
    op.create_index(
        "idx_payslip_tax_year", "payslip", ["tax_year"], schema=SCHEMA
    )


def downgrade() -> None:
    op.drop_index("idx_payslip_tax_year", table_name="payslip", schema=SCHEMA)
    op.drop_index("idx_payslip_pay_date", table_name="payslip", schema=SCHEMA)
    op.drop_table("payslip", schema=SCHEMA)
    op.execute(f"DROP SCHEMA IF EXISTS {SCHEMA}")
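The `tax_year` column above is plain text; the actual derivation lives in the extraction/validation code, which this excerpt does not show. As an illustration only (the helper name and label format are assumptions, not the service's API): UK tax years run 6 April to 5 April, so a plausible labelling from `pay_date` is:

```python
from datetime import date

def uk_tax_year(pay_date: date) -> str:
    # Hypothetical helper: label the UK tax year containing pay_date as "YYYY/YY".
    # A new tax year starts on 6 April.
    start_year = pay_date.year if (pay_date.month, pay_date.day) >= (4, 6) else pay_date.year - 1
    return f"{start_year}/{(start_year + 1) % 100:02d}"

# The boundary dates land in different tax years:
labels = [uk_tax_year(date(2026, 4, 5)), uk_tax_year(date(2026, 4, 6))]
# labels == ["2025/26", "2026/27"]
```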
payslip_ingest/__init__.py (new file, +1)
@@ -0,0 +1 @@
__version__ = "0.1.0"
payslip_ingest/__main__.py (new file, +101)
@@ -0,0 +1,101 @@
import asyncio
import json
import logging
import os
import subprocess
import sys
from pathlib import Path

import click
import uvicorn

from payslip_ingest.db import create_engine_from_env, make_session_factory
from payslip_ingest.extractor import ClaudeExtractor
from payslip_ingest.paperless import PaperlessClient
from payslip_ingest.processor import process_document
from payslip_ingest.schema import validate_totals

log = logging.getLogger(__name__)


@click.group()
def cli() -> None:
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))


@cli.command()
def serve() -> None:
    """Run the webhook HTTP server (K8s entrypoint)."""
    uvicorn.run("payslip_ingest.app:app", host="0.0.0.0", port=8080)


@cli.command()
@click.option("--all", "process_all", is_flag=True, help="Process every payslip-tagged doc.")
@click.option("--limit", type=int, default=None, help="Cap the number of documents processed.")
@click.option("--tag", default="payslip", help="Paperless tag name to enumerate.")
def backfill(process_all: bool, limit: int | None, tag: str) -> None:
    """Enumerate every payslip-tagged Paperless doc and process sequentially."""
    if not process_all:
        raise click.UsageError("pass --all to opt in to the full enumeration")
    asyncio.run(_backfill(tag, limit))


async def _backfill(tag: str, limit: int | None) -> None:
    engine = create_engine_from_env()
    session_factory = make_session_factory(engine)
    paperless = PaperlessClient(
        base_url=os.environ["PAPERLESS_URL"],
        api_token=os.environ["PAPERLESS_API_TOKEN"],
    )
    extractor = ClaudeExtractor(
        base_url=os.environ["CLAUDE_AGENT_URL"],
        bearer_token=os.environ["CLAUDE_AGENT_BEARER_TOKEN"],
    )
    processed = 0
    try:
        async for doc in paperless.list_tagged_documents(tag):
            if limit is not None and processed >= limit:
                break
            doc_id = int(doc["id"])
            result = await process_document(doc_id, session_factory, paperless, extractor)
            click.echo(f"doc_id={doc_id} status={result.status} validated={result.validated}")
            processed += 1
    finally:
        await paperless.aclose()
        await extractor.aclose()
        await engine.dispose()


@cli.command("extract-one")
@click.argument("path", type=click.Path(exists=True, dir_okay=False, path_type=Path))
def extract_one(path: Path) -> None:
    """Smoke-test extraction on a local PDF — no DB writes."""
    asyncio.run(_extract_one(path))


async def _extract_one(path: Path) -> None:
    pdf_bytes = path.read_bytes()
    extractor = ClaudeExtractor(
        base_url=os.environ["CLAUDE_AGENT_URL"],
        bearer_token=os.environ["CLAUDE_AGENT_BEARER_TOKEN"],
    )
    try:
        extracted = await extractor.extract(pdf_bytes, {"id": None, "source": str(path)})
    finally:
        await extractor.aclose()
    click.echo(extracted.model_dump_json(indent=2))
    ok = validate_totals(extracted)
    click.echo(json.dumps({"totals_validated": ok}))
    if not ok:
        sys.exit(1)


@cli.command()
def migrate() -> None:
    """Run `alembic upgrade head`."""
    result = subprocess.run(["alembic", "upgrade", "head"], check=False)
    sys.exit(result.returncode)


if __name__ == "__main__":
    cli()
payslip_ingest/app.py (new file, +118)
@@ -0,0 +1,118 @@
import asyncio
import contextlib
import hmac
import logging
import os
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI, Header, HTTPException, status
from prometheus_fastapi_instrumentator import Instrumentator
from sqlalchemy.ext.asyncio import async_sessionmaker

from payslip_ingest.db import create_engine_from_env, make_session_factory
from payslip_ingest.extractor import ClaudeExtractor
from payslip_ingest.paperless import PaperlessClient
from payslip_ingest.processor import process_document
from payslip_ingest.schema import WebhookPayload

log = logging.getLogger(__name__)

REQUIRED_ENV = [
    "PAPERLESS_URL",
    "PAPERLESS_API_TOKEN",
    "CLAUDE_AGENT_URL",
    "CLAUDE_AGENT_BEARER_TOKEN",
    "DB_CONNECTION_STRING",
    "WEBHOOK_BEARER_TOKEN",
]

# Type alias for the processor function — makes monkeypatching in tests explicit.
ProcessorFn = Callable[
    [int, async_sessionmaker[Any], PaperlessClient, ClaudeExtractor],
    Awaitable[Any],
]


def _verify_env() -> None:
    missing = [k for k in REQUIRED_ENV if not os.environ.get(k)]
    if missing:
        raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")


def _verify_bearer(authorization: str | None, expected: str) -> None:
    if not expected:
        raise HTTPException(status_code=401, detail="Service unauthenticated")
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing bearer token")
    token = authorization.removeprefix("Bearer ")
    if not hmac.compare_digest(token, expected):
        raise HTTPException(status_code=401, detail="Invalid token")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    _verify_env()

    engine = create_engine_from_env()
    session_factory = make_session_factory(engine)
    paperless = PaperlessClient(
        base_url=os.environ["PAPERLESS_URL"],
        api_token=os.environ["PAPERLESS_API_TOKEN"],
    )
    extractor = ClaudeExtractor(
        base_url=os.environ["CLAUDE_AGENT_URL"],
        bearer_token=os.environ["CLAUDE_AGENT_BEARER_TOKEN"],
    )
    queue: asyncio.Queue[int] = asyncio.Queue()

    processor: ProcessorFn = app.state.__dict__.get("processor_fn", process_document)

    async def worker() -> None:
        while True:
            doc_id = await queue.get()
            try:
                await processor(doc_id, session_factory, paperless, extractor)
            except Exception:
                log.exception("processing failed for doc_id=%s", doc_id)
            finally:
                queue.task_done()

    worker_task = asyncio.create_task(worker())
    app.state.queue = queue
    app.state.session_factory = session_factory
    app.state.paperless = paperless
    app.state.extractor = extractor

    try:
        yield
    finally:
        worker_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await worker_task
        await paperless.aclose()
        await extractor.aclose()
        await engine.dispose()


app = FastAPI(title="Payslip Ingest", lifespan=lifespan)
Instrumentator().instrument(app).expose(app, endpoint="/metrics")


@app.post("/webhook", status_code=status.HTTP_202_ACCEPTED)
async def webhook(
    payload: WebhookPayload,
    authorization: str | None = Header(default=None),
) -> dict[str, Any]:
    _verify_bearer(authorization, os.environ.get("WEBHOOK_BEARER_TOKEN", ""))
    queue: asyncio.Queue[int] = app.state.queue
    await queue.put(payload.document_id)
    return {"status": "accepted", "document_id": payload.document_id}


@app.get("/healthz")
async def healthz() -> dict[str, Any]:
    queue: asyncio.Queue[int] | None = getattr(app.state, "queue", None)
    depth = queue.qsize() if queue is not None else 0
    return {"status": "ok", "queue_depth": depth}
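The lifespan hook above wires a single background worker to an `asyncio.Queue`, so the webhook handler can return 202 immediately while documents are processed strictly one at a time. The pattern in isolation, with a stub list in place of `process_document`:

```python
import asyncio

async def demo() -> list[int]:
    queue: asyncio.Queue[int] = asyncio.Queue()
    processed: list[int] = []

    async def worker() -> None:
        # Single consumer: drains the queue one document at a time.
        while True:
            doc_id = await queue.get()
            try:
                processed.append(doc_id)  # stand-in for process_document(...)
            finally:
                queue.task_done()

    task = asyncio.create_task(worker())
    for doc_id in (101, 102, 103):  # three webhooks landing back-to-back
        await queue.put(doc_id)
    await queue.join()  # block until the worker has drained everything
    task.cancel()
    return processed

order = asyncio.run(demo())
# order == [101, 102, 103]: FIFO, never concurrent
```

`queue.join()` only returns once every `task_done()` has been called, which is why the worker marks completion in a `finally` block even when processing raises.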
payslip_ingest/db.py (new file, +65)
@@ -0,0 +1,65 @@
import os
from datetime import date, datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import JSON, TIMESTAMP, Boolean, Date, Integer, Numeric, String, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

SCHEMA_NAME = "payslip_ingest"


class Base(DeclarativeBase):
    pass


# JSONB on Postgres, plain JSON (as text) on SQLite — tests use SQLite, prod uses Postgres.
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")


class Payslip(Base):
    __tablename__ = "payslip"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    paperless_doc_id: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
    pay_date: Mapped[date] = mapped_column(Date, nullable=False)
    pay_period_start: Mapped[date | None] = mapped_column(Date, nullable=True)
    pay_period_end: Mapped[date | None] = mapped_column(Date, nullable=True)
    employer: Mapped[str | None] = mapped_column(String, nullable=True)
    currency: Mapped[str] = mapped_column(String(3), nullable=False, server_default="GBP")
    gross_pay: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    income_tax: Mapped[Decimal] = mapped_column(
        Numeric(12, 2), nullable=False, server_default=text("0")
    )
    national_insurance: Mapped[Decimal] = mapped_column(
        Numeric(12, 2), nullable=False, server_default=text("0")
    )
    pension_employee: Mapped[Decimal] = mapped_column(
        Numeric(12, 2), nullable=False, server_default=text("0")
    )
    pension_employer: Mapped[Decimal] = mapped_column(
        Numeric(12, 2), nullable=False, server_default=text("0")
    )
    student_loan: Mapped[Decimal] = mapped_column(
        Numeric(12, 2), nullable=False, server_default=text("0")
    )
    other_deductions: Mapped[dict[str, Any] | None] = mapped_column(JSON_TYPE, nullable=True)
    net_pay: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    tax_year: Mapped[str] = mapped_column(String, nullable=False)
    raw_extraction: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
    validated: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default=text("true"))
    created_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=text("now()")
    )


def create_engine_from_env() -> AsyncEngine:
    url = os.environ["DB_CONNECTION_STRING"]
    return create_async_engine(url, pool_pre_ping=True)


def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[Any]:
    return async_sessionmaker(engine, expire_on_commit=False)
payslip_ingest/extractor.py (new file, +215)
@@ -0,0 +1,215 @@
import asyncio
import base64
import json
from typing import Any

import httpx
from pydantic import ValidationError

from payslip_ingest.schema import ExtractedPayslip

AGENT_PATH = ".claude/agents/payslip-extractor"

EXTRACTION_PROMPT = (
    "You are extracting fields from a UK payslip PDF. Return ONLY a single JSON object "
    "matching this exact schema — no prose, no markdown fences.\n"
    "\n"
    "Schema:\n"
    "{\n"
    '  "pay_date": "YYYY-MM-DD",\n'
    '  "pay_period_start": "YYYY-MM-DD or null",\n'
    '  "pay_period_end": "YYYY-MM-DD or null",\n'
    '  "employer": "string or null",\n'
    '  "currency": "GBP",\n'
    '  "gross_pay": number,\n'
    '  "income_tax": number,\n'
    '  "national_insurance": number,\n'
    '  "pension_employee": number,\n'
    '  "pension_employer": number,\n'
    '  "student_loan": number,\n'
    '  "other_deductions": {"label": number, ...},\n'
    '  "net_pay": number\n'
    "}\n"
    "\n"
    "Rules:\n"
    "- Report numbers as the payslip shows them; do not compute sums.\n"
    "- Unknown numeric fields → 0, not null.\n"
    "- `other_deductions` covers cycle-to-work, share-save, benefits-in-kind, court orders, "
    "anything not in the main fields.\n"
    "- All money in GBP unless the payslip is denominated otherwise.\n"
    '- If a field\'s value is ambiguous, pick the value from the "this period" column, not YTD.'
)

POLL_INTERVAL_SECONDS = 2
MAX_POLL_SECONDS = 120
BUSY_RETRY_DELAY_SECONDS = 5
MAX_BUSY_RETRIES = 10
DEFAULT_MAX_BUDGET_USD = 1.0
DEFAULT_TIMEOUT_SECONDS = 300

TERMINAL_STATUSES = {"completed", "failed", "timeout", "error"}


class ExtractorError(RuntimeError):
    pass


class ClaudeExtractor:
    """Calls claude-agent-service to extract structured fields from a payslip PDF.

    The agent service serializes execution (one job at a time, 409 when busy);
    we back off and retry so the caller-side queue doesn't have to know.
    """

    def __init__(
        self,
        base_url: str,
        bearer_token: str,
        client: httpx.AsyncClient | None = None,
    ):
        self._base_url = base_url.rstrip("/")
        self._headers = {"Authorization": f"Bearer {bearer_token}"}
        self._client = client or httpx.AsyncClient(timeout=60.0)
        self._owns_client = client is None

    async def aclose(self) -> None:
        if self._owns_client:
            await self._client.aclose()

    async def __aenter__(self) -> "ClaudeExtractor":
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.aclose()

    async def extract(self, pdf_bytes: bytes, doc_metadata: dict[str, Any]) -> ExtractedPayslip:
        job_id = await self._submit_job(pdf_bytes, doc_metadata)
        output_lines = await self._poll_until_done(job_id)
        payload = _parse_output(output_lines)
        try:
            return ExtractedPayslip.model_validate(payload)
        except ValidationError as exc:
            raise ExtractorError(f"Extracted payload failed schema validation: {exc}") from exc

    async def _submit_job(self, pdf_bytes: bytes, doc_metadata: dict[str, Any]) -> str:
        encoded = base64.b64encode(pdf_bytes).decode("ascii")
        prompt = f"{EXTRACTION_PROMPT}\n\nPDF_BASE64:\n{encoded}\n"
        body = {
            "prompt": prompt,
            "agent": AGENT_PATH,
            "max_budget_usd": DEFAULT_MAX_BUDGET_USD,
            "timeout_seconds": DEFAULT_TIMEOUT_SECONDS,
            "metadata": {
                "paperless_doc_id": doc_metadata.get("id"),
            },
        }
        for _ in range(MAX_BUSY_RETRIES):
            resp = await self._client.post(
                f"{self._base_url}/execute", headers=self._headers, json=body
            )
            if resp.status_code == 409:
                await asyncio.sleep(BUSY_RETRY_DELAY_SECONDS)
                continue
            resp.raise_for_status()
            job_id = resp.json().get("job_id")
            if not isinstance(job_id, str):
                raise ExtractorError(f"Missing job_id in response: {resp.json()}")
            return job_id
        raise ExtractorError(f"Agent service remained busy after {MAX_BUSY_RETRIES} retries")

    async def _poll_until_done(self, job_id: str) -> list[str]:
        max_iterations = max(1, MAX_POLL_SECONDS // max(1, POLL_INTERVAL_SECONDS))
        for _ in range(max_iterations):
            resp = await self._client.get(f"{self._base_url}/jobs/{job_id}", headers=self._headers)
            resp.raise_for_status()
            job = resp.json()
            status = job.get("status")
            if status in TERMINAL_STATUSES:
                if status != "completed":
                    raise ExtractorError(f"Job {job_id} terminated with status={status}: {job}")
                output = job.get("output", [])
                if not isinstance(output, list):
                    raise ExtractorError(f"Job {job_id} output is not a list: {output!r}")
                return [str(line) for line in output]
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
        raise TimeoutError(f"Job {job_id} did not complete within {MAX_POLL_SECONDS}s")


def _parse_output(output_lines: list[str]) -> dict[str, Any]:
    """Extract the JSON payload from claude CLI --output-format json stream.

    The CLI emits one JSON object per line; the final 'result' message holds the
    assistant's final text. We walk from the end, parse each line, and return
    the first embedded JSON object we can recover from the assistant response.
    """
    non_empty = [line.strip() for line in output_lines if line.strip()]
    if not non_empty:
        raise ExtractorError("Agent produced no output")

    for line in reversed(non_empty):
        try:
            parsed = json.loads(line)
        except json.JSONDecodeError:
            continue
        text = _extract_assistant_text(parsed)
        if text is None:
            continue
        payload = _first_json_object(text)
        if payload is not None:
            return payload

    # Fallback: the last line itself might be the JSON object.
    try:
        candidate = json.loads(non_empty[-1])
    except json.JSONDecodeError as exc:
        raise ExtractorError(f"Could not parse JSON from agent output: {exc}") from exc
    if isinstance(candidate, dict):
        return candidate
    raise ExtractorError(f"Last agent line is not a JSON object: {candidate!r}")


def _extract_assistant_text(parsed: Any) -> str | None:
    if not isinstance(parsed, dict):
        return None
    result = parsed.get("result")
    if parsed.get("type") == "result" and isinstance(result, str):
        return result
    message = parsed.get("message")
    if isinstance(message, dict):
        content = message.get("content")
        if isinstance(content, list):
            texts = [
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            combined = "".join(str(t) for t in texts)
            if combined:
                return combined
        if isinstance(content, str):
            return content
    text = parsed.get("text")
    if isinstance(text, str):
        return text
    return None


def _first_json_object(text: str) -> dict[str, Any] | None:
    start = text.find("{")
    while start != -1:
        depth = 0
        for i in range(start, len(text)):
            ch = text[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    candidate = text[start:i + 1]
                    try:
                        obj = json.loads(candidate)
                    except json.JSONDecodeError:
                        break
                    if isinstance(obj, dict):
                        return obj
                    break
        start = text.find("{", start + 1)
    return None
74
payslip_ingest/paperless.py
Normal file

@@ -0,0 +1,74 @@
from collections.abc import AsyncIterator
from typing import Any

import httpx


class PaperlessError(RuntimeError):
    pass


class PaperlessClient:
    """Async client for the Paperless-ngx REST API.

    Auth uses a long-lived API token: Authorization: Token <token>.
    """

    def __init__(self, base_url: str, api_token: str, client: httpx.AsyncClient | None = None):
        self._base_url = base_url.rstrip("/")
        self._headers = {"Authorization": f"Token {api_token}"}
        self._client = client or httpx.AsyncClient(timeout=60.0)
        self._owns_client = client is None

    async def aclose(self) -> None:
        if self._owns_client:
            await self._client.aclose()

    async def __aenter__(self) -> "PaperlessClient":
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.aclose()

    async def get_document(self, doc_id: int) -> dict[str, Any]:
        resp = await self._client.get(f"{self._base_url}/api/documents/{doc_id}/",
                                      headers=self._headers)
        resp.raise_for_status()
        data = resp.json()
        if not isinstance(data, dict):
            raise PaperlessError(f"Unexpected document payload for {doc_id}: {type(data)}")
        return data

    async def download_document(self, doc_id: int) -> bytes:
        resp = await self._client.get(f"{self._base_url}/api/documents/{doc_id}/download/",
                                      headers=self._headers)
        resp.raise_for_status()
        return resp.content

    async def get_tag_id(self, tag_name: str) -> int:
        resp = await self._client.get(
            f"{self._base_url}/api/tags/",
            headers=self._headers,
            params={"name__iexact": tag_name},
        )
        resp.raise_for_status()
        results = resp.json().get("results", [])
        if len(results) == 0:
            raise PaperlessError(f"No tag named {tag_name!r}")
        if len(results) > 1:
            raise PaperlessError(f"Multiple tags matched {tag_name!r}: {len(results)}")
        tag_id = results[0]["id"]
        if not isinstance(tag_id, int):
            raise PaperlessError(f"Tag id is not int: {tag_id!r}")
        return tag_id

    async def list_tagged_documents(self, tag_name: str) -> AsyncIterator[dict[str, Any]]:
        tag_id = await self.get_tag_id(tag_name)
        next_url: str | None = f"{self._base_url}/api/documents/?tags__id={tag_id}"
        while next_url:
            resp = await self._client.get(next_url, headers=self._headers)
            resp.raise_for_status()
            page = resp.json()
            for item in page.get("results", []):
                yield item
            next_url = page.get("next")
103
payslip_ingest/processor.py
Normal file

@@ -0,0 +1,103 @@
import json
import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Protocol

from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker

from payslip_ingest.db import Payslip
from payslip_ingest.extractor import ClaudeExtractor
from payslip_ingest.paperless import PaperlessClient
from payslip_ingest.schema import ExtractedPayslip, validate_totals
from payslip_ingest.tax_year import derive_tax_year

log = logging.getLogger(__name__)


class _SessionFactory(Protocol):

    def __call__(self) -> Any:
        ...


@dataclass
class ProcessResult:
    doc_id: int
    status: str
    payslip_id: int | None = None
    validated: bool | None = None


async def process_document(
    doc_id: int,
    db_session_factory: async_sessionmaker[Any] | _SessionFactory,
    paperless: PaperlessClient,
    extractor: ClaudeExtractor,
) -> ProcessResult:
    async with db_session_factory() as session:
        existing = await session.execute(
            select(Payslip.id).where(Payslip.paperless_doc_id == doc_id))
        if existing.scalar() is not None:
            log.info("skipping doc_id=%s — already ingested", doc_id)
            return ProcessResult(doc_id=doc_id, status="skipped")

    metadata = await paperless.get_document(doc_id)
    pdf_bytes = await paperless.download_document(doc_id)
    extracted = await extractor.extract(pdf_bytes, metadata)

    validated = validate_totals(extracted)
    if not validated:
        log.warning(
            "totals mismatch for doc_id=%s gross=%s net=%s — storing validated=False",
            doc_id,
            extracted.gross_pay,
            extracted.net_pay,
        )

    payslip_id = await _insert_payslip(db_session_factory, doc_id, extracted, validated)
    status = "inserted" if payslip_id is not None else "skipped"
    return ProcessResult(doc_id=doc_id, status=status, payslip_id=payslip_id, validated=validated)


async def _insert_payslip(
    db_session_factory: async_sessionmaker[Any] | _SessionFactory,
    doc_id: int,
    extracted: ExtractedPayslip,
    validated: bool,
) -> int | None:
    raw = json.loads(extracted.model_dump_json())
    async with db_session_factory() as session, session.begin():
        existing = await session.execute(
            select(Payslip.id).where(Payslip.paperless_doc_id == doc_id))
        existing_id = existing.scalar()
        if existing_id is not None:
            return None

        row = Payslip(
            paperless_doc_id=doc_id,
            pay_date=extracted.pay_date,
            pay_period_start=extracted.pay_period_start,
            pay_period_end=extracted.pay_period_end,
            employer=extracted.employer,
            currency=extracted.currency,
            gross_pay=extracted.gross_pay,
            income_tax=extracted.income_tax,
            national_insurance=extracted.national_insurance,
            pension_employee=extracted.pension_employee,
            pension_employer=extracted.pension_employer,
            student_loan=extracted.student_loan,
            other_deductions=_decimals_to_float(extracted.other_deductions),
            net_pay=extracted.net_pay,
            tax_year=derive_tax_year(extracted.pay_date),
            raw_extraction=raw,
            validated=validated,
        )
        session.add(row)
        await session.flush()
        return row.id


def _decimals_to_float(mapping: dict[str, Decimal]) -> dict[str, float]:
    return {k: float(v) for k, v in mapping.items()}
42
payslip_ingest/schema.py
Normal file

@@ -0,0 +1,42 @@
from datetime import date
from decimal import Decimal

from pydantic import BaseModel, ConfigDict, Field

TOTALS_TOLERANCE = Decimal("0.02")


class ExtractedPayslip(BaseModel):
    model_config = ConfigDict(extra="forbid")

    pay_date: date
    pay_period_start: date | None = None
    pay_period_end: date | None = None
    employer: str | None = None
    currency: str = "GBP"
    gross_pay: Decimal
    income_tax: Decimal = Field(default=Decimal("0"))
    national_insurance: Decimal = Field(default=Decimal("0"))
    pension_employee: Decimal = Field(default=Decimal("0"))
    pension_employer: Decimal = Field(default=Decimal("0"))
    student_loan: Decimal = Field(default=Decimal("0"))
    other_deductions: dict[str, Decimal] = Field(default_factory=dict)
    net_pay: Decimal


class WebhookPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    document_id: int


def validate_totals(p: ExtractedPayslip) -> bool:
    """Check that gross - deductions ≈ net within a 2p tolerance.

    Employer pension is excluded — it never leaves the employer's books and
    doesn't affect take-home pay arithmetic.
    """
    deductions = (p.income_tax + p.national_insurance + p.pension_employee + p.student_loan +
                  sum(p.other_deductions.values(), start=Decimal("0")))
    diff = abs(p.gross_pay - deductions - p.net_pay)
    return diff < TOTALS_TOLERANCE
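The arithmetic behind `validate_totals` can be checked by hand. A minimal `Decimal` sketch with hypothetical figures (the same shape as the test fixtures: employer pension left out of the deduction sum):

```python
from decimal import Decimal

gross = Decimal("5000.00")
deductions = (Decimal("800.00")    # income tax
              + Decimal("350.00")  # national insurance
              + Decimal("250.00")  # employee pension (employer pension excluded)
              + Decimal("100.00")  # student loan
              + Decimal("50.00"))  # other deductions, e.g. cycle_to_work
net = Decimal("3450.00")

diff = abs(gross - deductions - net)
print(diff, diff < Decimal("0.02"))  # → 0.00 True
```

Using `Decimal` throughout keeps penny-level comparisons exact, which is why the schema stores money as `Decimal` rather than `float`.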
16
payslip_ingest/tax_year.py
Normal file

@@ -0,0 +1,16 @@
from datetime import date

UK_TAX_YEAR_START_MONTH = 4
UK_TAX_YEAR_START_DAY = 6


def derive_tax_year(pay_date: date) -> str:
    """Return the UK tax year label (e.g. "2026/27") for a given pay date.

    UK tax years run 6 April to 5 April. A pay_date on 5 Apr belongs to the
    previous start-year; on 6 Apr it belongs to the current start-year.
    """
    boundary = date(pay_date.year, UK_TAX_YEAR_START_MONTH, UK_TAX_YEAR_START_DAY)
    start_year = pay_date.year if pay_date >= boundary else pay_date.year - 1
    end_year_suffix = str(start_year + 1)[-2:]
    return f"{start_year}/{end_year_suffix}"
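The boundary logic is easy to sanity-check standalone. This self-contained copy of the derivation (a sketch mirroring `derive_tax_year` above) shows the 5/6 April flip:

```python
from datetime import date


def derive_tax_year(pay_date: date) -> str:
    # UK tax years run 6 April to 5 April: a date on or after 6 April belongs
    # to the tax year starting that calendar year, otherwise the prior one.
    boundary = date(pay_date.year, 4, 6)
    start_year = pay_date.year if pay_date >= boundary else pay_date.year - 1
    return f"{start_year}/{str(start_year + 1)[-2:]}"


print(derive_tax_year(date(2026, 4, 5)))  # → 2025/26
print(derive_tax_year(date(2026, 4, 6)))  # → 2026/27
```

Note the two-digit suffix comes from slicing the end year, so a century wrap would render as e.g. "2099/00".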
1191
poetry.lock
generated
Normal file
File diff suppressed because it is too large
56
pyproject.toml
Normal file

@@ -0,0 +1,56 @@
[tool.poetry]
name = "payslip-ingest"
version = "0.1.0"
description = "Event-driven UK payslip ingest from Paperless-ngx via claude-agent-service extraction"
authors = ["Viktor Barzin <viktorbarzin@meta.com>"]
readme = "README.md"
packages = [{ include = "payslip_ingest" }]

[tool.poetry.dependencies]
python = ">=3.12,<3.13"
fastapi = "^0.115"
uvicorn = "^0.32"
httpx = "^0.27"
pydantic = "^2.9"
sqlalchemy = { extras = ["asyncio"], version = "^2.0" }
asyncpg = "^0.29"
alembic = "^1.13"
click = "^8.1"
prometheus-fastapi-instrumentator = "^7.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3"
pytest-asyncio = "^0.23"
mypy = "^1.11"
ruff = "^0.6"
yapf = "^0.43"
respx = "^0.21"
aiosqlite = "^0.20"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

[tool.mypy]
python_version = "3.12"
strict = true
files = ["payslip_ingest", "tests"]

[[tool.mypy.overrides]]
module = ["respx.*"]
ignore_missing_imports = true

[tool.ruff]
line-length = 100
target-version = "py312"

[tool.ruff.lint]
select = ["E", "F", "W", "I", "UP", "B", "SIM", "RUF"]

[tool.yapf]
based_on_style = "pep8"
column_limit = 100
0
tests/__init__.py
Normal file
8
tests/conftest.py
Normal file

@@ -0,0 +1,8 @@
import os

os.environ.setdefault("PAPERLESS_URL", "http://paperless.test")
os.environ.setdefault("PAPERLESS_API_TOKEN", "test-paperless-token")
os.environ.setdefault("CLAUDE_AGENT_URL", "http://agent.test")
os.environ.setdefault("CLAUDE_AGENT_BEARER_TOKEN", "test-agent-token")
os.environ.setdefault("DB_CONNECTION_STRING", "sqlite+aiosqlite:///:memory:")
os.environ.setdefault("WEBHOOK_BEARER_TOKEN", "test-webhook-token")
138
tests/test_extractor.py
Normal file

@@ -0,0 +1,138 @@
import json

import httpx
import pytest
import respx

from payslip_ingest import extractor as extractor_module
from payslip_ingest.extractor import ClaudeExtractor, ExtractorError


@pytest.fixture(autouse=True)
def _tighten_retries(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(extractor_module, "POLL_INTERVAL_SECONDS", 0)
    monkeypatch.setattr(extractor_module, "MAX_POLL_SECONDS", 1)
    monkeypatch.setattr(extractor_module, "BUSY_RETRY_DELAY_SECONDS", 0)


def _sample_extraction() -> dict[str, object]:
    return {
        "pay_date": "2026-03-28",
        "pay_period_start": "2026-03-01",
        "pay_period_end": "2026-03-31",
        "employer": "Acme Ltd",
        "currency": "GBP",
        "gross_pay": 5000.0,
        "income_tax": 800.0,
        "national_insurance": 350.0,
        "pension_employee": 250.0,
        "pension_employer": 150.0,
        "student_loan": 100.0,
        "other_deductions": {
            "cycle_to_work": 50.0
        },
        "net_pay": 3450.0,
    }


def _agent_output(payload: dict[str, object]) -> list[str]:
    """Simulate claude CLI --output-format json stdout."""
    return [
        json.dumps({
            "type": "system",
            "subtype": "init"
        }) + "\n",
        json.dumps({
            "type": "assistant",
            "message": {
                "content": [{
                    "type": "text",
                    "text": json.dumps(payload)
                }],
            },
        }) + "\n",
        json.dumps({
            "type": "result",
            "result": json.dumps(payload)
        }) + "\n",
    ]


@pytest.fixture()
def client() -> ClaudeExtractor:
    return ClaudeExtractor(base_url="http://agent.test", bearer_token="tok")


async def test_extract_happy_path(client: ClaudeExtractor) -> None:
    payload = _sample_extraction()
    with respx.mock(base_url="http://agent.test") as mock:
        mock.post("/execute").mock(
            return_value=httpx.Response(202, json={
                "job_id": "abc123",
                "status": "running"
            }))
        mock.get("/jobs/abc123").mock(return_value=httpx.Response(
            200, json={
                "status": "completed",
                "output": _agent_output(payload)
            }))
        extracted = await client.extract(b"PDFDATA", {"id": 42})
    assert float(extracted.gross_pay) == 5000.0
    assert extracted.employer == "Acme Ltd"


async def test_extract_retries_on_409(client: ClaudeExtractor) -> None:
    payload = _sample_extraction()
    with respx.mock(base_url="http://agent.test") as mock:
        route = mock.post("/execute")
        route.side_effect = [
            httpx.Response(409, json={"detail": "busy"}),
            httpx.Response(202, json={"job_id": "abc123"}),
        ]
        mock.get("/jobs/abc123").mock(return_value=httpx.Response(
            200, json={
                "status": "completed",
                "output": _agent_output(payload)
            }))
        extracted = await client.extract(b"PDFDATA", {"id": 42})
    assert extracted.net_pay.is_finite()
    assert route.call_count == 2


async def test_extract_polling_timeout_raises(client: ClaudeExtractor) -> None:
    with respx.mock(base_url="http://agent.test") as mock:
        mock.post("/execute").mock(return_value=httpx.Response(202, json={"job_id": "abc123"}))
        mock.get("/jobs/abc123").mock(
            return_value=httpx.Response(200, json={
                "status": "running",
                "output": []
            }))
        with pytest.raises(TimeoutError):
            await client.extract(b"PDFDATA", {"id": 42})


async def test_extract_malformed_json_raises(client: ClaudeExtractor) -> None:
    with respx.mock(base_url="http://agent.test") as mock:
        mock.post("/execute").mock(return_value=httpx.Response(202, json={"job_id": "abc123"}))
        mock.get("/jobs/abc123").mock(return_value=httpx.Response(
            200,
            json={
                "status": "completed",
                "output": ["this is not json\n", "still not json\n"],
            },
        ))
        with pytest.raises(ExtractorError):
            await client.extract(b"PDFDATA", {"id": 42})


async def test_extract_failed_status_raises(client: ClaudeExtractor) -> None:
    with respx.mock(base_url="http://agent.test") as mock:
        mock.post("/execute").mock(return_value=httpx.Response(202, json={"job_id": "abc123"}))
        mock.get("/jobs/abc123").mock(return_value=httpx.Response(200,
                                                                  json={
                                                                      "status": "failed",
                                                                      "output": [],
                                                                      "exit_code": 1
                                                                  }))
        with pytest.raises(ExtractorError):
            await client.extract(b"PDFDATA", {"id": 42})
96
tests/test_paperless.py
Normal file

@@ -0,0 +1,96 @@
import httpx
import pytest
import respx

from payslip_ingest.paperless import PaperlessClient, PaperlessError


@pytest.fixture()
def client() -> PaperlessClient:
    return PaperlessClient(base_url="http://paperless.test", api_token="tok")


async def test_get_tag_id_happy_path(client: PaperlessClient) -> None:
    with respx.mock(base_url="http://paperless.test") as mock:
        mock.get("/api/tags/", params={
            "name__iexact": "payslip"
        }).mock(return_value=httpx.Response(200, json={"results": [{
            "id": 7,
            "name": "payslip"
        }]}))
        assert await client.get_tag_id("payslip") == 7


async def test_get_tag_id_zero_results_raises(client: PaperlessClient) -> None:
    with respx.mock(base_url="http://paperless.test") as mock:
        mock.get("/api/tags/").mock(return_value=httpx.Response(200, json={"results": []}))
        with pytest.raises(PaperlessError):
            await client.get_tag_id("payslip")


async def test_get_tag_id_many_results_raises(client: PaperlessClient) -> None:
    with respx.mock(base_url="http://paperless.test") as mock:
        mock.get("/api/tags/").mock(return_value=httpx.Response(
            200,
            json={"results": [{
                "id": 1,
                "name": "payslip"
            }, {
                "id": 2,
                "name": "Payslip"
            }]},
        ))
        with pytest.raises(PaperlessError):
            await client.get_tag_id("payslip")


async def test_list_tagged_documents_paginates(client: PaperlessClient) -> None:
    with respx.mock() as mock:
        mock.get("http://paperless.test/api/tags/").mock(
            return_value=httpx.Response(200, json={"results": [{
                "id": 7,
                "name": "payslip"
            }]}))
        mock.get("http://paperless.test/api/documents/?tags__id=7").mock(
            return_value=httpx.Response(
                200,
                json={
                    "results": [{
                        "id": 1
                    }, {
                        "id": 2
                    }],
                    "next": "http://paperless.test/api/documents/?tags__id=7&page=2",
                },
            ))
        mock.get("http://paperless.test/api/documents/?tags__id=7&page=2").mock(
            return_value=httpx.Response(
                200,
                json={
                    "results": [{
                        "id": 3
                    }],
                    "next": None
                },
            ))
        ids = [doc["id"] async for doc in client.list_tagged_documents("payslip")]
    assert ids == [1, 2, 3]


async def test_get_document_returns_metadata(client: PaperlessClient) -> None:
    with respx.mock(base_url="http://paperless.test") as mock:
        mock.get("/api/documents/42/").mock(
            return_value=httpx.Response(200, json={
                "id": 42,
                "title": "Payslip Mar"
            }))
        data = await client.get_document(42)
    assert data["title"] == "Payslip Mar"


async def test_download_document_returns_bytes(client: PaperlessClient) -> None:
    with respx.mock(base_url="http://paperless.test") as mock:
        mock.get("/api/documents/42/download/").mock(
            return_value=httpx.Response(200, content=b"PDFDATA"))
        data = await client.download_document(42)
    assert data == b"PDFDATA"
127
tests/test_processor.py
Normal file

@@ -0,0 +1,127 @@
from datetime import date
from decimal import Decimal
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from payslip_ingest.processor import process_document
from payslip_ingest.schema import ExtractedPayslip


def _sample_extraction() -> ExtractedPayslip:
    return ExtractedPayslip(
        pay_date=date(2026, 3, 28),
        pay_period_start=date(2026, 3, 1),
        pay_period_end=date(2026, 3, 31),
        employer="Acme Ltd",
        currency="GBP",
        gross_pay=Decimal("5000.00"),
        income_tax=Decimal("800.00"),
        national_insurance=Decimal("350.00"),
        pension_employee=Decimal("250.00"),
        pension_employer=Decimal("150.00"),
        student_loan=Decimal("100.00"),
        other_deductions={"cycle_to_work": Decimal("50.00")},
        net_pay=Decimal("3450.00"),
    )


class _FakeSession:
    """Minimal AsyncSession stand-in that records flushes and execute calls."""

    def __init__(self, existing_ids: list[int]):
        self._existing_ids = existing_ids
        self.added: list[Any] = []
        self.begin_calls = 0

    async def __aenter__(self) -> "_FakeSession":
        return self

    async def __aexit__(self, *exc: object) -> None:
        return None

    def begin(self) -> "_FakeSession":
        self.begin_calls += 1
        return self

    async def execute(self, stmt: Any) -> Any:
        result = MagicMock()
        # scalar() returns None when we treat the row as missing.
        result.scalar.return_value = self._existing_ids.pop(0) if self._existing_ids else None
        return result

    def add(self, row: Any) -> None:
        row.id = 1
        self.added.append(row)

    async def flush(self) -> None:
        return None


class _SessionFactory:

    def __init__(self, sessions: list[_FakeSession]):
        self._sessions = list(sessions)
        self.used: list[_FakeSession] = []

    def __call__(self) -> _FakeSession:
        session = self._sessions.pop(0)
        self.used.append(session)
        return session


@pytest.fixture()
def paperless() -> AsyncMock:
    mock = AsyncMock()
    mock.get_document.return_value = {"id": 42, "title": "Payslip"}
    mock.download_document.return_value = b"PDFDATA"
    return mock


@pytest.fixture()
def extractor() -> AsyncMock:
    mock = AsyncMock()
    mock.extract.return_value = _sample_extraction()
    return mock


async def test_process_document_inserts_new(paperless: AsyncMock, extractor: AsyncMock) -> None:
    factory = _SessionFactory([_FakeSession(existing_ids=[]), _FakeSession(existing_ids=[])])

    result = await process_document(42, factory, paperless, extractor)

    assert result.status == "inserted"
    assert result.validated is True
    paperless.get_document.assert_awaited_once_with(42)
    paperless.download_document.assert_awaited_once_with(42)
    extractor.extract.assert_awaited_once()
    inserted_row = factory.used[1].added[0]
    assert inserted_row.paperless_doc_id == 42
    assert inserted_row.tax_year == "2025/26"


async def test_process_document_skips_existing(paperless: AsyncMock, extractor: AsyncMock) -> None:
    factory = _SessionFactory([_FakeSession(existing_ids=[99])])

    result = await process_document(42, factory, paperless, extractor)

    assert result.status == "skipped"
    paperless.get_document.assert_not_called()
    extractor.extract.assert_not_called()


async def test_process_document_flags_validation_failure(paperless: AsyncMock,
                                                         extractor: AsyncMock) -> None:
    bad = _sample_extraction()
    bad_dict = bad.model_dump()
    bad_dict["net_pay"] = Decimal("9999.00")
    extractor.extract.return_value = ExtractedPayslip.model_validate(bad_dict)

    factory = _SessionFactory([_FakeSession(existing_ids=[]), _FakeSession(existing_ids=[])])

    result = await process_document(42, factory, paperless, extractor)

    assert result.status == "inserted"
    assert result.validated is False
    assert factory.used[1].added[0].validated is False
52
tests/test_schema.py
Normal file

@@ -0,0 +1,52 @@
from decimal import Decimal

import pytest
from pydantic import ValidationError

from payslip_ingest.schema import ExtractedPayslip, validate_totals


def _sample_payload() -> dict[str, object]:
    return {
        "pay_date": "2026-03-28",
        "pay_period_start": "2026-03-01",
        "pay_period_end": "2026-03-31",
        "employer": "Acme Ltd",
        "currency": "GBP",
        "gross_pay": "5000.00",
        "income_tax": "800.00",
        "national_insurance": "350.00",
        "pension_employee": "250.00",
        "pension_employer": "150.00",
        "student_loan": "100.00",
        "other_deductions": {
            "cycle_to_work": "50.00"
        },
        "net_pay": "3450.00",
    }


def test_schema_accepts_realistic_payload() -> None:
    model = ExtractedPayslip.model_validate(_sample_payload())
    assert model.employer == "Acme Ltd"
    assert model.gross_pay == Decimal("5000.00")
    assert model.other_deductions == {"cycle_to_work": Decimal("50.00")}


def test_schema_rejects_extra_fields() -> None:
    payload = _sample_payload()
    payload["bonus_field"] = "not allowed"
    with pytest.raises(ValidationError):
        ExtractedPayslip.model_validate(payload)


def test_validate_totals_true_for_matched_numbers() -> None:
    model = ExtractedPayslip.model_validate(_sample_payload())
    assert validate_totals(model) is True


def test_validate_totals_false_for_mismatch() -> None:
    payload = _sample_payload()
    payload["net_pay"] = "4000.00"
    model = ExtractedPayslip.model_validate(payload)
    assert validate_totals(model) is False
22
tests/test_tax_year.py
Normal file

@@ -0,0 +1,22 @@
from datetime import date

import pytest

from payslip_ingest.tax_year import derive_tax_year


@pytest.mark.parametrize(
    ("pay_date", "expected"),
    [
        (date(2025, 4, 5), "2024/25"),
        (date(2025, 4, 6), "2025/26"),
        (date(2026, 4, 5), "2025/26"),
        (date(2026, 4, 6), "2026/27"),
        (date(2026, 12, 31), "2026/27"),
        (date(2027, 1, 1), "2026/27"),
        (date(2027, 4, 5), "2026/27"),
        (date(2027, 4, 6), "2027/28"),
    ],
)
def test_derive_tax_year(pay_date: date, expected: str) -> None:
    assert derive_tax_year(pay_date) == expected
111
tests/test_webhook.py
Normal file

@@ -0,0 +1,111 @@
import asyncio
import contextlib
import os
import time
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager

import pytest
from fastapi import FastAPI, Header, HTTPException, status
from fastapi.testclient import TestClient

from payslip_ingest.app import _verify_bearer
from payslip_ingest.schema import WebhookPayload


def _build_app() -> tuple[FastAPI, list[int]]:
    """Build a minimal FastAPI app that mirrors the real /webhook behaviour.

    Mirroring rather than importing lets us avoid booting SQLAlchemy / httpx
    clients that the real `lifespan` constructs on startup.
    """
    seen: list[int] = []

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        queue: asyncio.Queue[int] = asyncio.Queue()
        app.state.queue = queue

        async def worker() -> None:
            while True:
                doc_id = await queue.get()
                seen.append(doc_id)
                queue.task_done()

        task = asyncio.create_task(worker())
        try:
            yield
        finally:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

    app = FastAPI(lifespan=lifespan)

    @app.post("/webhook", status_code=status.HTTP_202_ACCEPTED)
    async def webhook(
        payload: WebhookPayload,
        authorization: str | None = Header(default=None),
    ) -> dict[str, object]:
        _verify_bearer(authorization, os.environ.get("WEBHOOK_BEARER_TOKEN", ""))
        queue: asyncio.Queue[int] = app.state.queue
        await queue.put(payload.document_id)
        return {"status": "accepted", "document_id": payload.document_id}

    return app, seen


@pytest.fixture()
def client() -> Iterator[TestClient]:
    app, seen = _build_app()
    app.state.seen = seen
    with TestClient(app) as tc:
        yield tc


def test_webhook_rejects_missing_auth(client: TestClient) -> None:
    resp = client.post("/webhook", json={"document_id": 42})
    assert resp.status_code == 401


def test_webhook_rejects_wrong_bearer(client: TestClient) -> None:
    resp = client.post(
        "/webhook",
        json={"document_id": 42},
        headers={"Authorization": "Bearer wrong"},
    )
    assert resp.status_code == 401


def test_webhook_accepts_valid_request(client: TestClient) -> None:
    resp = client.post(
        "/webhook",
        json={"document_id": 42},
        headers={"Authorization": f"Bearer {os.environ['WEBHOOK_BEARER_TOKEN']}"},
    )
    assert resp.status_code == 202
    assert resp.json() == {"status": "accepted", "document_id": 42}

    # The queue and worker live on the TestClient's event loop; awaiting
    # queue.join() from a freshly created loop in this thread would raise
    # "attached to a different loop". Poll the shared list instead.
    seen: list[int] = client.app.state.seen  # type: ignore[attr-defined]
    deadline = time.monotonic() + 2.0
    while 42 not in seen and time.monotonic() < deadline:
        time.sleep(0.01)

    assert 42 in seen


def test_webhook_rejects_malformed_body(client: TestClient) -> None:
    resp = client.post(
        "/webhook",
        json={"document_id": "not-an-int"},
        headers={"Authorization": f"Bearer {os.environ['WEBHOOK_BEARER_TOKEN']}"},
    )
    assert resp.status_code == 422


def test_verify_bearer_rejects_unconfigured_service() -> None:
    with pytest.raises(HTTPException):
        _verify_bearer("Bearer anything", "")