fire-planner/fire_planner/db.py
2026-05-07 17:06:19 +00:00

165 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from datetime import date, datetime
from decimal import Decimal
from typing import Any
from sqlalchemy import JSON, TIMESTAMP, Date, Integer, Numeric, String, func, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
# Database schema name; every model in this module passes it as the "schema"
# entry of its __table_args__.
SCHEMA_NAME = "fire_planner"
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model in this module."""
# Column type for JSON payload columns: JSONB on Postgres, plain JSON on
# SQLite — tests use SQLite, prod uses Postgres. The variant is selected by
# dialect name ("sqlite"); any other dialect gets the base JSONB type.
JSON_TYPE = JSONB().with_variant(JSON(), "sqlite")
class AccountSnapshot(Base):
    """Daily NW snapshot per account, sourced from Wealthfolio by the ingest.

    ``external_id`` has the form ``wealthfolio:{account_id}:{date}`` and is
    unique, so re-running the ingest on the same day is idempotent —
    Wealthfolio keeps a single snapshot per account per day.
    """

    __tablename__ = "account_snapshot"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    # Row identity.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    external_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)

    # Which account, and on which day (both indexed).
    snapshot_date: Mapped[date] = mapped_column(Date, index=True, nullable=False)
    account_id: Mapped[str] = mapped_column(String, index=True, nullable=False)
    account_name: Mapped[str] = mapped_column(String, nullable=False)
    account_type: Mapped[str] = mapped_column(String, nullable=False)

    # Three-character currency code; the database defaults it to "GBP".
    currency: Mapped[str] = mapped_column(
        String(3),
        server_default="GBP",
        nullable=False,
    )

    # Monetary amounts, stored with two decimal places.
    market_value: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    market_value_gbp: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    cost_basis_gbp: Mapped[Decimal | None] = mapped_column(Numeric(14, 2), nullable=True)

    # Optional JSON payload (presumably the raw upstream record — confirm
    # against the ingest code).
    raw_extraction: Mapped[dict[str, Any] | None] = mapped_column(
        JSON_TYPE,
        nullable=True,
    )

    # Insert timestamp, filled in by the database.
    created_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
class Scenario(Base):
    """One simulation scenario.

    Each row is a single point in the Cartesian space
    (jurisdiction × strategy × leave_year × glide × spending). The product is
    rebuilt from ``scenarios.py`` on every recompute, and rows are upserted
    on ``external_id``.
    """

    __tablename__ = "scenario"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    # Row identity; external_id is the upsert key.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    external_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)

    # Scenario dimensions.
    jurisdiction: Mapped[str] = mapped_column(String(32), index=True, nullable=False)
    strategy: Mapped[str] = mapped_column(String(32), index=True, nullable=False)
    leave_uk_year: Mapped[int] = mapped_column(Integer, nullable=False)
    glide_path: Mapped[str] = mapped_column(String(32), nullable=False)
    spending_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)

    # Simulation inputs; horizon_years and savings_per_year_gbp have
    # database-side defaults of 60 and 0 respectively.
    horizon_years: Mapped[int] = mapped_column(
        Integer,
        server_default=text("60"),
        nullable=False,
    )
    nw_seed_gbp: Mapped[Decimal] = mapped_column(Numeric(14, 2), nullable=False)
    savings_per_year_gbp: Mapped[Decimal] = mapped_column(
        Numeric(12, 2),
        server_default=text("0"),
        nullable=False,
    )

    # Required JSON configuration blob for the scenario.
    config_json: Mapped[dict[str, Any]] = mapped_column(JSON_TYPE, nullable=False)

    # Insert timestamp, filled in by the database.
    created_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
class McRun(Base):
    """A single MC execution, one row per (scenario, run_at).

    Holds execution metadata plus summary statistics — enough to populate a
    Grafana cell without touching the per-path tables.
    """

    __tablename__ = "mc_run"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Plain indexed integer with no ForeignKey declared here
    # (presumably references scenario.id — confirm).
    scenario_id: Mapped[int] = mapped_column(Integer, index=True, nullable=False)

    # Execution metadata; run_at is filled in by the database.
    run_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    n_paths: Mapped[int] = mapped_column(Integer, nullable=False)
    seed: Mapped[int] = mapped_column(Integer, nullable=False)

    # Summary statistics.
    success_rate: Mapped[Decimal] = mapped_column(Numeric(6, 4), nullable=False)
    p10_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p50_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p90_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    median_lifetime_tax_gbp: Mapped[Decimal] = mapped_column(
        Numeric(14, 2),
        nullable=False,
    )
    median_years_to_ruin: Mapped[Decimal | None] = mapped_column(
        Numeric(6, 2),
        nullable=True,
    )
    elapsed_seconds: Mapped[Decimal] = mapped_column(Numeric(8, 3), nullable=False)
    sequence_risk_correlation: Mapped[Decimal | None] = mapped_column(
        Numeric(6, 4),
        nullable=True,
    )

    # Optional free-form JSON for anything without a dedicated column.
    extra: Mapped[dict[str, Any] | None] = mapped_column(JSON_TYPE, nullable=True)
class McPath(Base):
    """Sparse per-path storage for an MC run.

    Only the top-decile, bottom-decile and median paths are stored in full —
    enough for a fan chart, rather than 10k×60 ≈ 600k rows.
    """

    __tablename__ = "mc_path"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Plain indexed integer with no ForeignKey declared here
    # (presumably references mc_run.id — confirm).
    mc_run_id: Mapped[int] = mapped_column(Integer, index=True, nullable=False)

    # Position of the row within the run. `bucket` presumably names which of
    # the stored groups (decile / median) the path belongs to — confirm.
    path_idx: Mapped[int] = mapped_column(Integer, nullable=False)
    bucket: Mapped[str] = mapped_column(String(16), nullable=False)
    year_idx: Mapped[int] = mapped_column(Integer, nullable=False)

    # Per-year amounts, stored with two decimal places.
    portfolio_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    withdrawal_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    tax_paid_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    real_portfolio_gbp: Mapped[Decimal] = mapped_column(
        Numeric(16, 2),
        nullable=False,
    )
class ProjectionYearly(Base):
    """Per-year projection estimates that drive fan charts and the per-year
    Grafana table.

    Each row carries portfolio percentile bands (p10–p90), the median
    withdrawal and tax, and a survival rate for one ``year_idx``.

    NOTE(review): rows are keyed by ``mc_run_id`` rather than a scenario id,
    so the grain appears to be one row per (mc_run, year) — confirm.
    """

    __tablename__ = "projection_yearly"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Plain indexed integer with no ForeignKey declared here.
    mc_run_id: Mapped[int] = mapped_column(Integer, index=True, nullable=False)
    year_idx: Mapped[int] = mapped_column(Integer, nullable=False)

    # Portfolio value percentile bands.
    p10_portfolio_gbp: Mapped[Decimal] = mapped_column(
        Numeric(16, 2),
        nullable=False,
    )
    p25_portfolio_gbp: Mapped[Decimal] = mapped_column(
        Numeric(16, 2),
        nullable=False,
    )
    p50_portfolio_gbp: Mapped[Decimal] = mapped_column(
        Numeric(16, 2),
        nullable=False,
    )
    p75_portfolio_gbp: Mapped[Decimal] = mapped_column(
        Numeric(16, 2),
        nullable=False,
    )
    p90_portfolio_gbp: Mapped[Decimal] = mapped_column(
        Numeric(16, 2),
        nullable=False,
    )

    # Median cash flows and the survival rate for the year.
    p50_withdrawal_gbp: Mapped[Decimal] = mapped_column(
        Numeric(12, 2),
        nullable=False,
    )
    p50_tax_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)
    survival_rate: Mapped[Decimal] = mapped_column(Numeric(6, 4), nullable=False)
class ScenarioSummary(Base):
    """Denormalised fast-read table for Grafana.

    One row per (scenario, latest run); ``scenario_id`` is unique.
    """

    __tablename__ = "scenario_summary"
    __table_args__ = {"schema": SCHEMA_NAME}  # noqa: RUF012

    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Plain integers with no ForeignKey declared here; scenario_id is unique.
    scenario_id: Mapped[int] = mapped_column(Integer, nullable=False, unique=True)
    mc_run_id: Mapped[int] = mapped_column(Integer, nullable=False)

    # Scenario dimensions (same names and types as the columns on Scenario).
    jurisdiction: Mapped[str] = mapped_column(String(32), index=True, nullable=False)
    strategy: Mapped[str] = mapped_column(String(32), index=True, nullable=False)
    leave_uk_year: Mapped[int] = mapped_column(Integer, nullable=False)
    glide_path: Mapped[str] = mapped_column(String(32), nullable=False)
    spending_gbp: Mapped[Decimal] = mapped_column(Numeric(12, 2), nullable=False)

    # Run statistics (same names and types as the columns on McRun).
    success_rate: Mapped[Decimal] = mapped_column(Numeric(6, 4), nullable=False)
    p10_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p50_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    p90_ending_gbp: Mapped[Decimal] = mapped_column(Numeric(16, 2), nullable=False)
    median_lifetime_tax_gbp: Mapped[Decimal] = mapped_column(
        Numeric(14, 2),
        nullable=False,
    )
    median_years_to_ruin: Mapped[Decimal | None] = mapped_column(
        Numeric(6, 2),
        nullable=True,
    )

    # Only an insert-time server default is declared; no `onupdate` is set,
    # so writers presumably set this explicitly when refreshing a row —
    # confirm against the upsert code.
    updated_at: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
def create_engine_from_env() -> AsyncEngine:
    """Build an async engine from the ``DB_CONNECTION_STRING`` environment variable.

    The engine is created with ``pool_pre_ping=True``.

    Raises:
        KeyError: if ``DB_CONNECTION_STRING`` is not set.
    """
    connection_string = os.environ["DB_CONNECTION_STRING"]
    engine = create_async_engine(connection_string, pool_pre_ping=True)
    return engine
def make_session_factory(engine: AsyncEngine) -> async_sessionmaker[Any]:
    """Return an ``async_sessionmaker`` bound to *engine*.

    Sessions are created with ``expire_on_commit=False``.
    """
    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    return session_factory