"""SQLAlchemy ORM models for the `fire_planner` schema, plus async engine/session helpers."""

import os
from datetime import date, datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import JSON, TIMESTAMP, Date, Integer, Numeric, String, func, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

SCHEMA_NAME = "fire_planner"


class Base(DeclarativeBase):
    """Declarative base shared by every model in this module."""


# JSONB on Postgres, plain JSON on SQLite — tests use SQLite, prod uses Postgres.
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")


class AccountSnapshot(Base):
    """Daily NW per account from Wealthfolio (filled by ingest).

    `external_id` is `wealthfolio:{account_id}:{date}` so re-runs on the same
    day are idempotent — Wealthfolio keeps one snapshot per account per day.
    """

    __tablename__ = "account_snapshot"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Unique natural key; see the class docstring for its format.
    external_id: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    snapshot_date: Mapped[date] = mapped_column(Date, nullable=False, index=True)
    account_id: Mapped[str] = mapped_column(String, nullable=False, index=True)
    account_name: Mapped[str] = mapped_column(String, nullable=False)
    account_type: Mapped[str] = mapped_column(String, nullable=False)
    currency: Mapped[str] = mapped_column(String(3), nullable=False, server_default="GBP")
    # Value in the account's own `currency`; `market_value_gbp` is the GBP-converted figure.
    market_value: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    market_value_gbp: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    cost_basis_gbp: Mapped[Decimal | None] = mapped_column(Numeric(14, 2), nullable=True)
    raw_extraction: Mapped[dict[str, Any] | None] = mapped_column(JSON_TYPE, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
    )


class Scenario(Base):
    """A simulation scenario — a Cartesian point in
    (jurisdiction × strategy × leave_year × glide × spending) space.

    The Cartesian product is rebuilt from `scenarios.py` on every recompute;
    rows are upserted on `external_id`.
    """

    __tablename__ = "scenario"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Upsert key (see class docstring).
    external_id: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    jurisdiction: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    strategy: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    leave_uk_year: Mapped[int] = mapped_column(Integer, nullable=False)
    glide_path: Mapped[str] = mapped_column(String(32), nullable=False)
    spending_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    horizon_years: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("60"))
    nw_seed_gbp: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    savings_per_year_gbp: Mapped[Decimal] = mapped_column(
        Numeric(12, 2), nullable=False, server_default=text("0")
    )
    config_json: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
    )


class McRun(Base):
    """One MC execution per (scenario, run_at).

    Stores execution metadata + summary statistics — enough to populate a
    Grafana cell without touching the per-path tables.
    """

    __tablename__ = "mc_run"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): presumably references scenario.id; no ForeignKey is declared,
    # so the database does not enforce this link.
    scenario_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    run_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
    )
    n_paths: Mapped[int] = mapped_column(Integer, nullable=False)
    # NOTE(review): Integer is 32-bit signed on Postgres; if RNG seeds can exceed
    # 2**31 - 1 this needs BigInteger (plus a migration).
    seed: Mapped[int] = mapped_column(Integer, nullable=False)
    success_rate: Mapped[Decimal] = mapped_column(Numeric(6, 4), nullable=False)
    p10_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p50_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p90_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    median_lifetime_tax_gbp: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    median_years_to_ruin: Mapped[Decimal | None] = mapped_column(Numeric(6, 2), nullable=True)
    elapsed_seconds: Mapped[Decimal] = mapped_column(Numeric(8, 3), nullable=False)
    sequence_risk_correlation: Mapped[Decimal | None] = mapped_column(Numeric(6, 4), nullable=True)
    extra: Mapped[dict[str, Any] | None] = mapped_column(JSON_TYPE, nullable=True)


class McPath(Base):
    """Sparse per-path storage.

    Top decile, bottom decile, and median paths are fully stored — enough for
    a fan chart, without persisting all 10k×60 ≈ 600k rows.
    """

    __tablename__ = "mc_path"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): presumably references mc_run.id; no ForeignKey is declared.
    mc_run_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    path_idx: Mapped[int] = mapped_column(Integer, nullable=False)
    # Which stored subset this path belongs to (see class docstring).
    bucket: Mapped[str] = mapped_column(String(16), nullable=False)
    year_idx: Mapped[int] = mapped_column(Integer, nullable=False)
    portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    withdrawal_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    tax_paid_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    real_portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)


class ProjectionYearly(Base):
    """Per-year percentile bands for one MC run.

    Each row holds the p10/p25/p50/p75/p90 portfolio values, the median
    withdrawal and tax, and the survival rate for a single `year_idx` of the
    run identified by `mc_run_id`. These drive fan charts and the per-year
    Grafana table. Intended as one row per (mc_run_id, year_idx); no unique
    constraint enforces this.
    """

    __tablename__ = "projection_yearly"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): presumably references mc_run.id; no ForeignKey is declared.
    mc_run_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    year_idx: Mapped[int] = mapped_column(Integer, nullable=False)
    p10_portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p25_portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p50_portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p75_portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p90_portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p50_withdrawal_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    p50_tax_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    # Presumably the share of paths not yet ruined at this year — confirm with the writer.
    survival_rate: Mapped[Decimal] = mapped_column(Numeric(6, 4), nullable=False)


class ScenarioSummary(Base):
    """Denormalised fast-read table for Grafana.

    One row per scenario (`scenario_id` is unique), pointing at that scenario's
    latest MC run via `mc_run_id`.
    """

    __tablename__ = "scenario_summary"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    scenario_id: Mapped[int] = mapped_column(Integer, unique=True, nullable=False)
    mc_run_id: Mapped[int] = mapped_column(Integer, nullable=False)
    # Scenario dimensions copied from `scenario` so dashboards can filter without a join.
    jurisdiction: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    strategy: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    leave_uk_year: Mapped[int] = mapped_column(Integer, nullable=False)
    glide_path: Mapped[str] = mapped_column(String(32), nullable=False)
    spending_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    # Summary statistics copied from the referenced `mc_run` row.
    success_rate: Mapped[Decimal] = mapped_column(Numeric(6, 4), nullable=False)
    p10_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p50_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p90_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    median_lifetime_tax_gbp: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    median_years_to_ruin: Mapped[Decimal | None] = mapped_column(Numeric(6, 2), nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
    )


def create_engine_from_env() -> AsyncEngine:
    """Build an async engine from the `DB_CONNECTION_STRING` environment variable.

    Returns:
        An `AsyncEngine` with `pool_pre_ping=True`, so pooled connections are
        liveness-checked on checkout and stale ones are transparently replaced.

    Raises:
        KeyError: if `DB_CONNECTION_STRING` is not set.
    """
    url = os.environ["DB_CONNECTION_STRING"]
    return create_async_engine(url, pool_pre_ping=True)


def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Return a factory producing `AsyncSession` objects bound to *engine*.

    `expire_on_commit=False` keeps loaded attributes readable after commit;
    with the default (True), touching an expired attribute would trigger an
    implicit lazy load, which AsyncSession does not allow.
    """
    return async_sessionmaker(engine, expire_on_commit=False)