from datetime import date
from decimal import Decimal
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from payslip_ingest.processor import process_document
from payslip_ingest.schema import ExtractedPayslip


def _sample_extraction() -> ExtractedPayslip:
    """Return a fixed ExtractedPayslip used as the extractor's default output.

    The figures are arithmetically consistent: gross pay minus the
    employee-side deductions (5000 - 800 - 350 - 250 - 100 - 50) equals the
    stated net pay of 3450.
    """
    return ExtractedPayslip(
        pay_date=date(2026, 3, 28),
        pay_period_start=date(2026, 3, 1),
        pay_period_end=date(2026, 3, 31),
        employer="Acme Ltd",
        currency="GBP",
        gross_pay=Decimal("5000.00"),
        income_tax=Decimal("800.00"),
        national_insurance=Decimal("350.00"),
        pension_employee=Decimal("250.00"),
        pension_employer=Decimal("150.00"),
        student_loan=Decimal("100.00"),
        other_deductions={"cycle_to_work": Decimal("50.00")},
        net_pay=Decimal("3450.00"),
    )


class _FakeSession:
    """Minimal AsyncSession stand-in that records added rows and begin() calls.

    Each ``execute`` call yields a result whose ``scalar()`` returns the next
    value from ``existing_ids``, or ``None`` once that queue is exhausted.
    The object acts as its own async context manager and as the value
    returned by ``begin()``, so both ``async with session`` and
    ``async with session.begin()`` resolve to the same instance.
    """

    def __init__(self, existing_ids: list[int]):
        # Copy so that draining the queue in execute() never mutates the
        # caller's list (mirrors the defensive copy in _SessionFactory).
        self._existing_ids = list(existing_ids)
        self.added: list[Any] = []
        self.begin_calls = 0

    async def __aenter__(self) -> "_FakeSession":
        return self

    async def __aexit__(self, *exc: object) -> None:
        return None

    def begin(self) -> "_FakeSession":
        """Count the call and return self as the transaction context."""
        self.begin_calls += 1
        return self

    async def execute(self, stmt: Any) -> Any:
        """Return a mock result; ``stmt`` itself is ignored."""
        result = MagicMock()
        # scalar() returns None when we treat the row as missing.
        result.scalar.return_value = self._existing_ids.pop(0) if self._existing_ids else None
        return result

    def add(self, row: Any) -> None:
        """Record ``row`` after stamping it with a fixed id of 1."""
        # Presumably stands in for the primary key a real flush would assign.
        row.id = 1
        self.added.append(row)

    async def flush(self) -> None:
        return None


class _SessionFactory:
    """Callable that hands out pre-built fake sessions in order.

    Every session returned is also appended to ``used`` so tests can inspect
    what each one received. Calling it more times than sessions were supplied
    raises ``IndexError`` from the underlying ``pop(0)``.
    """

    def __init__(self, sessions: list[_FakeSession]):
        self._sessions = list(sessions)
        self.used: list[_FakeSession] = []

    def __call__(self) -> _FakeSession:
        session = self._sessions.pop(0)
        self.used.append(session)
        return session


@pytest.fixture()
def paperless() -> AsyncMock:
    """Async mock client returning a canned document record and PDF bytes."""
    mock = AsyncMock()
    mock.get_document.return_value = {"id": 42, "title": "Payslip"}
    mock.download_document.return_value = b"PDFDATA"
    return mock


@pytest.fixture()
def extractor() -> AsyncMock:
    """Async mock extractor whose ``extract`` returns the sample payslip."""
    mock = AsyncMock()
    mock.extract.return_value = _sample_extraction()
    return mock


async def test_process_document_inserts_new(paperless: AsyncMock, extractor: AsyncMock) -> None:
    """With no existing row, the document is fetched, extracted and inserted."""
    factory = _SessionFactory([_FakeSession(existing_ids=[]), _FakeSession(existing_ids=[])])

    result = await process_document(42, factory, paperless, extractor)

    assert result.status == "inserted"
    assert result.validated is True
    paperless.get_document.assert_awaited_once_with(42)
    paperless.download_document.assert_awaited_once_with(42)
    extractor.extract.assert_awaited_once()

    # The row is expected on the second session handed out by the factory.
    inserted_row = factory.used[1].added[0]
    assert inserted_row.paperless_doc_id == 42
    # A 28 March 2026 pay date is expected to be labelled 2025/26
    # (presumably because the UK tax year rolls over in April).
    assert inserted_row.tax_year == "2025/26"


async def test_process_document_skips_existing(paperless: AsyncMock, extractor: AsyncMock) -> None:
    """When the lookup finds an existing id, nothing is fetched or extracted."""
    factory = _SessionFactory([_FakeSession(existing_ids=[99])])

    result = await process_document(42, factory, paperless, extractor)

    assert result.status == "skipped"
    paperless.get_document.assert_not_called()
    extractor.extract.assert_not_called()


async def test_process_document_flags_validation_failure(
    paperless: AsyncMock, extractor: AsyncMock
) -> None:
    """An extraction with a non-reconciling net pay is inserted but unvalidated."""
    bad = _sample_extraction()
    bad_dict = bad.model_dump()
    # 9999.00 no longer matches gross pay minus the listed deductions.
    bad_dict["net_pay"] = Decimal("9999.00")
    extractor.extract.return_value = ExtractedPayslip.model_validate(bad_dict)
    factory = _SessionFactory([_FakeSession(existing_ids=[]), _FakeSession(existing_ids=[])])

    result = await process_document(42, factory, paperless, extractor)

    assert result.status == "inserted"
    assert result.validated is False
    assert factory.used[1].added[0].validated is False