hmrc-sync/hmrc_sync/db.py

71 lines
3.2 KiB
Python
Raw Permalink Normal View History

2026-05-07 17:06:11 +00:00
import os
from datetime import datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import JSON, TIMESTAMP, Integer, Numeric, String, func, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
# Database schema that the tables declared in this module are placed in
# (passed to each model through ``__table_args__``).
SCHEMA_NAME = "hmrc_sync"
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""
# JSON column type: PostgreSQL JSONB by default, with the generic JSON type
# substituted when the dialect is "sqlite".
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")
class TaxYearSnapshot(Base):
    """Snapshot of the HMRC-held annual PAYE/NI figures for one employment.

    One row per (tax_year, employer_paye_ref, snapshot_date).

    HMRC returns the `hmrc-held` view of annual PAYE/NI for a given
    employment. Taking a daily snapshot lets us see HMRC's figures evolve
    as late RTI filings land, and lets the dashboard always show the
    latest value by snapshot_date.

    NOTE(review): no unique constraint backs the "one row per" key in this
    model -- confirm whether the writer is expected to enforce it.
    """

    __tablename__ = "tax_year_snapshot"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    # Surrogate auto-incrementing primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)

    # Components of the logical snapshot key; only tax_year is indexed.
    tax_year: Mapped[str] = mapped_column(String, nullable=False, index=True)
    employer_paye_ref: Mapped[str] = mapped_column(String, nullable=False)
    snapshot_date: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        nullable=False,
    )

    # Annual figures stored as fixed-point NUMERIC(12, 2) so they round-trip
    # as Decimal rather than float.
    gross_pay: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    income_tax: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    ni_contributions: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)

    # Origin of the figures; the database fills in "hmrc-held" when omitted.
    source: Mapped[str] = mapped_column(
        String,
        nullable=False,
        server_default="hmrc-held",
    )

    # JSON payload kept alongside the parsed figures.
    # NOTE(review): presumably the unmodified HMRC response body -- confirm
    # against the code that writes this row.
    raw_response: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)

    # Filled in by the database at insert time. func.now() still renders as
    # now() on PostgreSQL, but lets other dialects (such as the sqlite variant
    # JSON_TYPE caters for) emit their own current-timestamp expression
    # instead of a PostgreSQL-only function call.
    fetched_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
class FetchLog(Base):
    """Audit trail of every HMRC API call -- for fraud-header compliance review.

    Each row records one call: the endpoint, the HTTP status code, the
    identifiers and fraud-prevention headers associated with it, and how
    long it took.
    """

    __tablename__ = "fetch_log"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    # Surrogate auto-incrementing primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)

    # What was called and the status code that came back.
    endpoint: Mapped[str] = mapped_column(String, nullable=False)
    status_code: Mapped[int] = mapped_column(Integer, nullable=False)

    # Optional tracing identifiers; either may be NULL.
    request_id: Mapped[str | None] = mapped_column(String, nullable=True)
    correlation_id: Mapped[str | None] = mapped_column(String, nullable=True)

    # Fraud-prevention headers that accompanied the request, stored as JSON.
    fraud_headers_sent: Mapped[dict[str, Any]] = mapped_column(
        JSON_TYPE,
        nullable=False,
    )

    # Optional text kept from the response.
    # NOTE(review): presumably a truncated body excerpt -- confirm the length
    # limit (if any) with the code that writes this row.
    response_snippet: Mapped[str | None] = mapped_column(String, nullable=True)

    # Call duration; milliseconds going by the column name.
    duration_ms: Mapped[int] = mapped_column(Integer, nullable=False)

    # Filled in by the database at insert time. func.now() still renders as
    # now() on PostgreSQL, but lets other dialects (such as the sqlite variant
    # JSON_TYPE caters for) emit their own current-timestamp expression
    # instead of a PostgreSQL-only function call.
    fetched_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
def create_engine_from_env() -> AsyncEngine:
    """Build an async engine from the ``DB_CONNECTION_STRING`` environment variable.

    Returns:
        An ``AsyncEngine`` created with ``pool_pre_ping=True``.

    Raises:
        KeyError: If ``DB_CONNECTION_STRING`` is not set in the environment.
    """
    # Read the variable first so a missing setting fails before any engine
    # construction is attempted.
    connection_string = os.environ["DB_CONNECTION_STRING"]
    engine = create_async_engine(connection_string, pool_pre_ping=True)
    return engine
def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[Any]:
    """Return an async session factory bound to *engine*.

    The factory is configured with ``expire_on_commit=False``, so attributes
    of loaded objects are not expired when a session commits.

    Args:
        engine: The async engine that sessions from this factory will use.

    Returns:
        An ``async_sessionmaker`` bound to ``engine``.
    """
    session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
    return session_factory