"""SQLAlchemy ORM models and async engine/session helpers for the ``hmrc_sync`` schema."""

import os
from datetime import datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import JSON, TIMESTAMP, Integer, Numeric, String, func, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# Database schema that every table in this module is created in.
SCHEMA_NAME = "hmrc_sync"


class Base(DeclarativeBase):
    """Declarative base shared by the models in this module."""


# JSONB is PostgreSQL-specific; fall back to the generic JSON type on SQLite.
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")

# NOTE: timestamp server defaults below use ``func.now()`` rather than
# ``text("now()")``. A ``text()`` default is emitted verbatim, and SQLite has no
# ``now()`` function; ``func.now()`` compiles per dialect (``now()`` on
# PostgreSQL, ``CURRENT_TIMESTAMP`` on SQLite), matching the SQLite support
# that JSON_TYPE above already provides.


class TaxYearSnapshot(Base):
    """One row per (tax_year, employer_paye_ref, snapshot_date).

    HMRC returns the `hmrc-held` view of annual PAYE/NI for a given
    employment. Taking a daily snapshot lets us see HMRC's figures evolve as
    late RTI filings land, and lets the dashboard always show the latest value
    by snapshot_date.

    NOTE(review): the one-row-per-key rule above is not enforced by a unique
    constraint in this model; confirm whether the writer or a migration
    guarantees it.
    """

    __tablename__ = "tax_year_snapshot"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    tax_year: Mapped[str] = mapped_column(String, nullable=False, index=True)
    employer_paye_ref: Mapped[str] = mapped_column(String, nullable=False)
    snapshot_date: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False)
    # Monetary figures: NUMERIC(12, 2), i.e. up to 12 digits with 2 after the point.
    gross_pay: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    income_tax: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    ni_contributions: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    # Label of the HMRC view the figures came from; the database fills in
    # "hmrc-held" when no value is supplied.
    source: Mapped[str] = mapped_column(String, nullable=False, server_default="hmrc-held")
    # The HMRC response payload stored as JSON/JSONB.
    raw_response: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
    # Defaults to the database's current timestamp on insert.
    fetched_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
    )


class FetchLog(Base):
    """Audit trail of every HMRC API call — for fraud-header compliance review."""

    __tablename__ = "fetch_log"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    endpoint: Mapped[str] = mapped_column(String, nullable=False)
    status_code: Mapped[int] = mapped_column(Integer, nullable=False)
    request_id: Mapped[str | None] = mapped_column(String, nullable=True)
    correlation_id: Mapped[str | None] = mapped_column(String, nullable=True)
    # The fraud-prevention headers sent with the request, stored as JSON/JSONB.
    fraud_headers_sent: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
    # Optional excerpt of the response body; presumably truncated by the writer — confirm.
    response_snippet: Mapped[str | None] = mapped_column(String, nullable=True)
    duration_ms: Mapped[int] = mapped_column(Integer, nullable=False)
    # Defaults to the database's current timestamp on insert.
    fetched_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
    )


def create_engine_from_env() -> AsyncEngine:
    """Create an async engine from the ``DB_CONNECTION_STRING`` environment variable.

    ``pool_pre_ping=True`` makes the pool test each connection on checkout so
    stale connections are replaced instead of surfacing as errors.

    Returns:
        A new ``AsyncEngine`` for that URL.

    Raises:
        KeyError: if ``DB_CONNECTION_STRING`` is not set.
    """
    url = os.environ["DB_CONNECTION_STRING"]
    return create_async_engine(url, pool_pre_ping=True)


def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Return a factory that produces ``AsyncSession`` objects bound to *engine*.

    ``expire_on_commit=False`` keeps already-loaded attributes readable after
    ``commit()``; with asyncio, accessing an expired attribute would otherwise
    attempt an implicit (non-awaited) refresh.
    """
    return async_sessionmaker(engine, expire_on_commit=False)