"""Client that asks claude-agent-service to extract fields from a payslip PDF."""

import asyncio
import base64
import json
from typing import Any

import httpx
from pydantic import ValidationError

from payslip_ingest.schema import ExtractedPayslip

AGENT_PATH = ".claude/agents/payslip-extractor"

EXTRACTION_PROMPT = (
    "You are extracting fields from a UK payslip PDF. Return ONLY a single JSON object "
    "matching this exact schema — no prose, no markdown fences.\n"
    "\n"
    "Schema:\n"
    "{\n"
    ' "pay_date": "YYYY-MM-DD",\n'
    ' "pay_period_start": "YYYY-MM-DD or null",\n'
    ' "pay_period_end": "YYYY-MM-DD or null",\n'
    ' "employer": "string or null",\n'
    ' "currency": "GBP",\n'
    ' "gross_pay": number,\n'
    ' "income_tax": number,\n'
    ' "national_insurance": number,\n'
    ' "pension_employee": number,\n'
    ' "pension_employer": number,\n'
    ' "student_loan": number,\n'
    ' "other_deductions": {"label": number, ...},\n'
    ' "net_pay": number\n'
    "}\n"
    "\n"
    "Rules:\n"
    "- Report numbers as the payslip shows them; do not compute sums.\n"
    "- Unknown numeric fields → 0, not null.\n"
    "- `other_deductions` covers cycle-to-work, share-save, benefits-in-kind, court orders, "
    "anything not in the main fields.\n"
    "- All money in GBP unless the payslip is denominated otherwise.\n"
    '- If a field\'s value is ambiguous, pick the value from the "this period" column, not YTD.')

# Job-status polling: seconds between polls and the total sleep budget.
POLL_INTERVAL_SECONDS = 2
MAX_POLL_SECONDS = 120
# Submission back-off while the service answers 409 (busy).
BUSY_RETRY_DELAY_SECONDS = 5
MAX_BUSY_RETRIES = 10
# Limits forwarded to the agent service in the /execute request body.
DEFAULT_MAX_BUDGET_USD = 1.0
DEFAULT_TIMEOUT_SECONDS = 300
# Job statuses after which polling stops; only "completed" is a success.
TERMINAL_STATUSES = {"completed", "failed", "timeout", "error"}


class ExtractorError(RuntimeError):
    """Raised when submission, job execution, or output parsing fails."""


class ClaudeExtractor:
    """Calls claude-agent-service to extract structured fields from a payslip PDF.

    The agent service serializes execution (one job at a time, 409 when busy);
    we back off and retry so the caller-side queue doesn't have to know.
    """

    def __init__(
        self,
        base_url: str,
        bearer_token: str,
        client: httpx.AsyncClient | None = None,
    ):
        """Configure the extractor.

        Args:
            base_url: Root URL of the agent service; trailing slashes are stripped.
            bearer_token: Token sent in the ``Authorization: Bearer`` header.
            client: Optional pre-built HTTP client. When omitted, a client with
                a 60 second timeout is created and later closed by ``aclose``.
        """
        self._base_url = base_url.rstrip("/")
        self._headers = {"Authorization": f"Bearer {bearer_token}"}
        # Decide ownership with the same ``is None`` test used to pick the
        # client, so a falsy injected client can never cause an unowned
        # (and therefore never closed) client to be created.
        self._owns_client = client is None
        self._client = httpx.AsyncClient(timeout=60.0) if client is None else client

    async def aclose(self) -> None:
        """Close the HTTP client, but only if this instance created it."""
        if self._owns_client:
            await self._client.aclose()

    async def __aenter__(self) -> "ClaudeExtractor":
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.aclose()

    async def extract(self, pdf_bytes: bytes, doc_metadata: dict[str, Any]) -> ExtractedPayslip:
        """Submit the PDF, wait for the job, and validate the returned payload.

        Args:
            pdf_bytes: Raw PDF content.
            doc_metadata: Document metadata; its ``"id"`` entry is forwarded
                to the service as ``paperless_doc_id``.

        Returns:
            The payload validated as an ``ExtractedPayslip``.

        Raises:
            ExtractorError: On a failed job, unparseable output, or a payload
                that does not validate against the schema.
            TimeoutError: If the job does not finish within the polling budget.
        """
        job_id = await self._submit_job(pdf_bytes, doc_metadata)
        output_lines = await self._poll_until_done(job_id)
        payload = _parse_output(output_lines)
        try:
            return ExtractedPayslip.model_validate(payload)
        except ValidationError as exc:
            raise ExtractorError(f"Extracted payload failed schema validation: {exc}") from exc

    async def _submit_job(self, pdf_bytes: bytes, doc_metadata: dict[str, Any]) -> str:
        """POST the extraction job to ``/execute`` and return its job id.

        A 409 response means the service is busy; the request is retried up to
        ``MAX_BUSY_RETRIES`` times with ``BUSY_RETRY_DELAY_SECONDS`` between
        attempts.

        Raises:
            ExtractorError: If the service stays busy for every attempt or the
                response carries no string ``job_id``.
            httpx.HTTPStatusError: For any other non-success status.
        """
        encoded = base64.b64encode(pdf_bytes).decode("ascii")
        prompt = f"{EXTRACTION_PROMPT}\n\nPDF_BASE64:\n{encoded}\n"
        body = {
            "prompt": prompt,
            "agent": AGENT_PATH,
            "max_budget_usd": DEFAULT_MAX_BUDGET_USD,
            "timeout_seconds": DEFAULT_TIMEOUT_SECONDS,
            "metadata": {
                "paperless_doc_id": doc_metadata.get("id")
            },
        }
        for attempt in range(MAX_BUSY_RETRIES):
            resp = await self._client.post(
                f"{self._base_url}/execute", headers=self._headers, json=body
            )
            if resp.status_code == 409:
                # No point sleeping after the final attempt: we are about to
                # give up, so the delay would only postpone the error.
                if attempt < MAX_BUSY_RETRIES - 1:
                    await asyncio.sleep(BUSY_RETRY_DELAY_SECONDS)
                continue
            resp.raise_for_status()
            data = resp.json()
            job_id = data.get("job_id") if isinstance(data, dict) else None
            if not isinstance(job_id, str):
                raise ExtractorError(f"Missing job_id in response: {data}")
            return job_id
        raise ExtractorError(f"Agent service remained busy after {MAX_BUSY_RETRIES} retries")

    async def _poll_until_done(self, job_id: str) -> list[str]:
        """Poll ``/jobs/{job_id}`` until the job reaches a terminal status.

        The budget is ``MAX_POLL_SECONDS // POLL_INTERVAL_SECONDS`` polls (at
        least one), sleeping ``POLL_INTERVAL_SECONDS`` after each non-terminal
        poll; time spent inside the HTTP requests is not counted.

        Returns:
            The job's ``output`` entries, each converted with ``str``.

        Raises:
            ExtractorError: If the job ends in a non-"completed" terminal
                status, or the response/output has an unexpected shape.
            TimeoutError: If no terminal status is seen within the budget.
            httpx.HTTPStatusError: If a poll returns a non-success status.
        """
        max_iterations = max(1, MAX_POLL_SECONDS // max(1, POLL_INTERVAL_SECONDS))
        for _ in range(max_iterations):
            resp = await self._client.get(
                f"{self._base_url}/jobs/{job_id}", headers=self._headers
            )
            resp.raise_for_status()
            job = resp.json()
            if not isinstance(job, dict):
                raise ExtractorError(f"Job {job_id} response is not an object: {job!r}")
            status = job.get("status")
            if status in TERMINAL_STATUSES:
                if status != "completed":
                    raise ExtractorError(f"Job {job_id} terminated with status={status}: {job}")
                output = job.get("output", [])
                if not isinstance(output, list):
                    raise ExtractorError(f"Job {job_id} output is not a list: {output!r}")
                return [str(line) for line in output]
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
        raise TimeoutError(f"Job {job_id} did not complete within {MAX_POLL_SECONDS}s")


def _parse_output(output_lines: list[str]) -> dict[str, Any]:
    """Extract the JSON payload from claude CLI --output-format json stream.

    The CLI emits one JSON object per line; the final 'result' message holds
    the assistant's final text. We walk from the end, parse each line, and
    return the first embedded JSON object we can recover from the assistant
    response.

    Raises:
        ExtractorError: If there is no non-blank output, or no JSON object can
            be recovered from it.
    """
    non_empty = [line.strip() for line in output_lines if line.strip()]
    if not non_empty:
        raise ExtractorError("Agent produced no output")

    for line in reversed(non_empty):
        try:
            parsed = json.loads(line)
        except json.JSONDecodeError:
            # Not a JSON envelope line; keep looking further back.
            continue
        text = _extract_assistant_text(parsed)
        if text is None:
            continue
        payload = _first_json_object(text)
        if payload is not None:
            return payload

    # Fallback: the last line itself might be the JSON object.
    try:
        candidate = json.loads(non_empty[-1])
    except json.JSONDecodeError as exc:
        raise ExtractorError(f"Could not parse JSON from agent output: {exc}") from exc
    if isinstance(candidate, dict):
        return candidate
    raise ExtractorError(f"Last agent line is not a JSON object: {candidate!r}")


def _extract_assistant_text(parsed: Any) -> str | None:
    """Return the assistant text carried by one decoded output line, if any.

    Shapes are checked in this order: a ``type == "result"`` message with a
    string ``result``; a ``message`` whose ``content`` is a list of blocks
    (the ``text`` of every ``type == "text"`` block is concatenated) or a
    plain string; and finally a top-level string ``text``. Returns ``None``
    when *parsed* is not a dict or matches none of these.
    """
    if not isinstance(parsed, dict):
        return None

    result = parsed.get("result")
    if parsed.get("type") == "result" and isinstance(result, str):
        return result

    message = parsed.get("message")
    if isinstance(message, dict):
        content = message.get("content")
        if isinstance(content, list):
            texts = [
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            combined = "".join(str(t) for t in texts)
            # An empty concatenation falls through to the top-level "text" key.
            if combined:
                return combined
        if isinstance(content, str):
            return content

    text = parsed.get("text")
    if isinstance(text, str):
        return text
    return None


def _first_json_object(text: str) -> dict[str, Any] | None:
    """Return the first JSON object embedded anywhere in *text*, or ``None``.

    Every ``{`` is tried in order as a possible start. Decoding is delegated
    to ``json.JSONDecoder.raw_decode`` so that braces inside JSON string
    values (e.g. ``{"employer": "ACME }"}``) do not confuse the scan, and any
    text following the object is ignored.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one.
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    return None