import json
import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Protocol

from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker

from payslip_ingest.db import Payslip
from payslip_ingest.extractor import ClaudeExtractor
from payslip_ingest.paperless import PaperlessClient
from payslip_ingest.schema import ExtractedPayslip, validate_totals
from payslip_ingest.tax_year import derive_tax_year

log = logging.getLogger(__name__)


class _SessionFactory(Protocol):
    """Structural type for a zero-argument callable that produces a session.

    Accepted alongside ``async_sessionmaker`` so callers can supply any
    compatible factory. This module uses the returned object as an async
    context manager.
    """

    def __call__(self) -> Any: ...


@dataclass
class ProcessResult:
    """Outcome of a single ``process_document`` call."""

    # Paperless document id that was processed.
    doc_id: int
    # Either "skipped" or "inserted".
    status: str
    # Primary key of the stored row; None when nothing was inserted.
    payslip_id: int | None = None
    # Result of ``validate_totals``; None when extraction never ran.
    validated: bool | None = None


async def _lookup_payslip_id(session: Any, doc_id: int) -> int | None:
    """Return the id of the payslip stored for *doc_id*, or ``None`` if absent."""
    query = select(Payslip.id).where(Payslip.paperless_doc_id == doc_id)
    outcome = await session.execute(query)
    return outcome.scalar()


async def process_document(
    doc_id: int,
    db_session_factory: async_sessionmaker[Any] | _SessionFactory,
    paperless: PaperlessClient,
    extractor: ClaudeExtractor,
) -> ProcessResult:
    """Ingest one Paperless document as a payslip row.

    Steps, in order:

    1. Look for an existing ``Payslip`` with this ``paperless_doc_id``; if
       one exists, log and return a ``"skipped"`` result without fetching
       anything.
    2. Fetch the document metadata and PDF bytes from Paperless and run the
       extractor on them.
    3. Run ``validate_totals``; a falsy result is logged as a warning but the
       payslip is still stored, with ``validated`` set accordingly.
    4. Insert the row via ``_insert_payslip``.

    Args:
        doc_id: Paperless document id.
        db_session_factory: Callable returning an async-context-manager
            session.
        paperless: Client used to fetch metadata and the PDF.
        extractor: Extractor that turns the PDF and metadata into an
            ``ExtractedPayslip``.

    Returns:
        A ``ProcessResult`` whose ``status`` is ``"inserted"`` when a row was
        written, or ``"skipped"`` when a row for ``doc_id`` already existed
        (either at the initial check or at insert time).
    """
    # Cheap pre-check so already-ingested documents never reach the extractor.
    async with db_session_factory() as session:
        if await _lookup_payslip_id(session, doc_id) is not None:
            log.info("skipping doc_id=%s — already ingested", doc_id)
            return ProcessResult(doc_id=doc_id, status="skipped")

    doc_meta = await paperless.get_document(doc_id)
    pdf_payload = await paperless.download_document(doc_id)
    payslip = await extractor.extract(pdf_payload, doc_meta)

    totals_ok = validate_totals(payslip)
    if not totals_ok:
        log.warning(
            "totals mismatch for doc_id=%s gross=%s net=%s — storing validated=False",
            doc_id,
            payslip.gross_pay,
            payslip.net_pay,
        )

    new_id = await _insert_payslip(db_session_factory, doc_id, payslip, totals_ok)
    if new_id is None:
        status = "skipped"
    else:
        status = "inserted"
    return ProcessResult(
        doc_id=doc_id,
        status=status,
        payslip_id=new_id,
        validated=totals_ok,
    )


async def _insert_payslip(
    db_session_factory: async_sessionmaker[Any] | _SessionFactory,
    doc_id: int,
    extracted: ExtractedPayslip,
    validated: bool,
) -> int | None:
    """Store *extracted* as a ``Payslip`` row unless one already exists.

    The existence check is repeated inside the transaction; presumably this
    covers a row being written for the same document while extraction was in
    progress (NOTE(review): confirm whether a unique constraint backs this).

    Args:
        db_session_factory: Callable returning an async-context-manager
            session.
        doc_id: Paperless document id stored as ``paperless_doc_id``.
        extracted: Extracted payslip fields to persist.
        validated: Value stored in the row's ``validated`` column.

    Returns:
        The new row's ``id``, or ``None`` when a row for ``doc_id`` was
        already present and nothing was inserted.
    """
    # Round-trip through JSON text to get a plain JSON-compatible structure
    # for the ``raw_extraction`` column.
    raw_payload = json.loads(extracted.model_dump_json())

    async with db_session_factory() as session:
        async with session.begin():
            if await _lookup_payslip_id(session, doc_id) is not None:
                return None

            # presumably floats are needed so the mapping can be stored as
            # JSON — confirm against the Payslip model.
            deductions = _decimals_to_float(extracted.other_deductions)
            tax_year = derive_tax_year(extracted.pay_date)

            row = Payslip(
                paperless_doc_id=doc_id,
                pay_date=extracted.pay_date,
                pay_period_start=extracted.pay_period_start,
                pay_period_end=extracted.pay_period_end,
                employer=extracted.employer,
                currency=extracted.currency,
                gross_pay=extracted.gross_pay,
                income_tax=extracted.income_tax,
                national_insurance=extracted.national_insurance,
                pension_employee=extracted.pension_employee,
                pension_employer=extracted.pension_employer,
                student_loan=extracted.student_loan,
                other_deductions=deductions,
                net_pay=extracted.net_pay,
                tax_year=tax_year,
                raw_extraction=raw_payload,
                validated=validated,
            )
            session.add(row)
            # Flush before reading ``row.id`` so the pending row has been sent
            # to the database; the id is read while the transaction is open.
            await session.flush()
            return row.id


def _decimals_to_float(mapping: dict[str, Decimal]) -> dict[str, float]:
    """Return a copy of *mapping* with every value converted via ``float``."""
    converted: dict[str, float] = {}
    for label, amount in mapping.items():
        converted[label] = float(amount)
    return converted