payslip-ingest/payslip_ingest/processor.py

104 lines
3.5 KiB
Python
Raw Normal View History

import json
import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Protocol
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker
from payslip_ingest.db import Payslip
from payslip_ingest.extractor import ClaudeExtractor
from payslip_ingest.paperless import PaperlessClient
from payslip_ingest.schema import ExtractedPayslip, validate_totals
from payslip_ingest.tax_year import derive_tax_year
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
class _SessionFactory(Protocol):
    """Structural type for a zero-argument callable that produces a session.

    Accepted alongside ``async_sessionmaker`` wherever this module needs a
    session factory. The functions in this module use the returned object
    as an async context manager (``async with factory() as session``).
    """

    def __call__(self) -> Any:
        """Return a new session object."""
@dataclass
class ProcessResult:
doc_id: int
status: str
payslip_id: int | None = None
validated: bool | None = None
async def process_document(
    doc_id: int,
    db_session_factory: async_sessionmaker[Any] | _SessionFactory,
    paperless: PaperlessClient,
    extractor: ClaudeExtractor,
) -> ProcessResult:
    """Ingest one Paperless document as a payslip row unless it already exists.

    Args:
        doc_id: Paperless document id to ingest.
        db_session_factory: Callable returning a session usable as an async
            context manager.
        paperless: Client used to fetch the document's metadata and content.
        extractor: Extractor that turns the downloaded bytes plus metadata
            into an ``ExtractedPayslip``.

    Returns:
        A ``ProcessResult`` whose status is ``"skipped"`` when a row for
        ``doc_id`` already exists (either at the initial check or at insert
        time) and ``"inserted"`` otherwise. A totals mismatch does not block
        the insert; it is logged and recorded through ``validated``.
    """
    # Pre-check first so an already-ingested document costs no download or
    # extraction call.
    already_ingested = select(Payslip.id).where(Payslip.paperless_doc_id == doc_id)
    async with db_session_factory() as session:
        lookup = await session.execute(already_ingested)
        if lookup.scalar() is not None:
            log.info("skipping doc_id=%s — already ingested", doc_id)
            return ProcessResult(doc_id=doc_id, status="skipped")

    metadata = await paperless.get_document(doc_id)
    pdf_bytes = await paperless.download_document(doc_id)
    extracted = await extractor.extract(pdf_bytes, metadata)

    totals_ok = validate_totals(extracted)
    if not totals_ok:
        log.warning(
            "totals mismatch for doc_id=%s gross=%s net=%s — storing validated=False",
            doc_id,
            extracted.gross_pay,
            extracted.net_pay,
        )

    # _insert_payslip returns None when it finds an existing row for doc_id.
    new_id = await _insert_payslip(db_session_factory, doc_id, extracted, totals_ok)
    if new_id is None:
        outcome = "skipped"
    else:
        outcome = "inserted"
    return ProcessResult(doc_id=doc_id, status=outcome, payslip_id=new_id, validated=totals_ok)
async def _insert_payslip(
    db_session_factory: async_sessionmaker[Any] | _SessionFactory,
    doc_id: int,
    extracted: ExtractedPayslip,
    validated: bool,
) -> int | None:
    """Insert a ``Payslip`` row for ``doc_id`` unless one already exists.

    Args:
        db_session_factory: Callable returning a session usable as an async
            context manager.
        doc_id: Paperless document id the row is keyed on.
        extracted: Extracted payslip values to persist.
        validated: Value stored in the row's ``validated`` column.

    Returns:
        The new row's id, or ``None`` when a row with this
        ``paperless_doc_id`` was already present and nothing was inserted.
    """
    # Round-trip through JSON text so the stored snapshot holds only the
    # plain values json.loads produces.
    raw_snapshot = json.loads(extracted.model_dump_json())
    duplicate_check = select(Payslip.id).where(Payslip.paperless_doc_id == doc_id)

    async with db_session_factory() as session, session.begin():
        # Checked again here, in the same transaction as the insert; the
        # caller's earlier check ran in a separate session.
        if (await session.execute(duplicate_check)).scalar() is not None:
            return None

        payslip = Payslip(
            paperless_doc_id=doc_id,
            pay_date=extracted.pay_date,
            pay_period_start=extracted.pay_period_start,
            pay_period_end=extracted.pay_period_end,
            employer=extracted.employer,
            currency=extracted.currency,
            gross_pay=extracted.gross_pay,
            income_tax=extracted.income_tax,
            national_insurance=extracted.national_insurance,
            pension_employee=extracted.pension_employee,
            pension_employer=extracted.pension_employer,
            student_loan=extracted.student_loan,
            other_deductions=_decimals_to_float(extracted.other_deductions),
            net_pay=extracted.net_pay,
            tax_year=derive_tax_year(extracted.pay_date),
            raw_extraction=raw_snapshot,
            validated=validated,
        )
        session.add(payslip)
        # Flush before reading the id so the pending INSERT has been sent;
        # the id is read while the transaction block is still open.
        await session.flush()
        return payslip.id
def _decimals_to_float(mapping: dict[str, Decimal]) -> dict[str, float]:
    """Return a new dict with the same keys and each value passed through ``float``.

    NOTE(review): presumably needed because the target column cannot store
    ``Decimal`` values directly — confirm against the ``Payslip`` model.
    """
    # keys() and values() iterate in the same order, so zip pairs them up.
    return dict(zip(mapping.keys(), map(float, mapping.values())))