fire-planner/fire_planner/reporters/pg.py
2026-05-07 17:06:19 +00:00

224 lines
8 KiB
Python

"""Postgres reporter — write MC results into `mc_run`,
`projection_yearly`, `mc_path` (sparse), `scenario_summary`."""
from __future__ import annotations
import time
from dataclasses import dataclass
from decimal import Decimal
from typing import Any
import numpy as np
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import McPath, McRun, ProjectionYearly, Scenario, ScenarioSummary
from fire_planner.scenarios import ScenarioSpec
from fire_planner.simulator import SimulationResult
def _dialect_insert(session: AsyncSession) -> Any:
    """Pick the dialect-specific ``insert`` construct for ``session``.

    SQLite-bound sessions get ``sqlalchemy.dialects.sqlite.insert``; any other
    dialect falls back to the PostgreSQL ``insert``. Both variants provide
    ``on_conflict_do_update``, which the upsert helpers in this module use.
    """
    dialect_name = session.get_bind().dialect.name
    return sqlite_insert if dialect_name == "sqlite" else pg_insert
@dataclass(frozen=True)
class WriteSummary:
    """Immutable receipt describing one persisted Monte-Carlo run.

    Instances are hashable and compare by value (``frozen=True``).
    """

    # Primary key of the upserted ``Scenario`` row.
    scenario_id: int
    # Primary key of the newly inserted ``McRun`` row.
    mc_run_id: int
    # Simulation wall-clock time as supplied by the caller, in seconds.
    elapsed_seconds: float
    # Fraction of simulated paths counted as successful.
    success_rate: float
def _to_dec(x: float | int) -> Decimal:
return Decimal(str(round(float(x), 4)))
async def upsert_scenario(session: AsyncSession, spec: ScenarioSpec) -> int:
    """Insert the scenario described by ``spec`` or refresh the existing row.

    Rows are keyed by ``external_id``. On conflict only the spending, horizon,
    net-worth seed, yearly savings and config columns are overwritten; the
    jurisdiction, strategy, leave-UK year and glide path keep their stored
    values.

    Returns:
        The primary key of the scenario row, re-read after a flush.
    """
    insert_ = _dialect_insert(session)
    row_values = {
        "external_id": spec.external_id,
        "jurisdiction": spec.jurisdiction,
        "strategy": spec.strategy,
        "leave_uk_year": spec.leave_uk_year,
        "glide_path": spec.glide_path,
        "spending_gbp": spec.spending_gbp,
        "horizon_years": spec.horizon_years,
        "nw_seed_gbp": spec.nw_seed_gbp,
        "savings_per_year_gbp": spec.savings_per_year_gbp,
        "config_json": spec.config or {},
    }
    insert_stmt = insert_(Scenario).values(**row_values)
    refreshed_columns = (
        "spending_gbp",
        "horizon_years",
        "nw_seed_gbp",
        "savings_per_year_gbp",
        "config_json",
    )
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["external_id"],
        set_={column: insert_stmt.excluded[column] for column in refreshed_columns},
    )
    await session.execute(upsert_stmt)
    await session.flush()
    id_lookup = select(Scenario.id).where(Scenario.external_id == spec.external_id)
    fetched = await session.execute(id_lookup)
    return int(fetched.scalar_one())
async def write_run(
    session: AsyncSession,
    spec: ScenarioSpec,
    result: SimulationResult,
    *,
    seed: int,
    elapsed_seconds: float,
    bucket_quantiles: tuple[int, int, int] = (10, 50, 90),
) -> WriteSummary:
    """Persist one Monte-Carlo run and all data derived from it.

    Steps, in order:

    1. Upsert the scenario row.
    2. Insert a new ``McRun`` row.
    3. Add its ``ProjectionYearly`` rows and sparse ``McPath`` rows.
    4. Upsert the per-scenario ``ScenarioSummary`` row so it points at this run.

    The session is flushed (so generated ids exist) but is not committed here.

    Args:
        session: Async session the rows are written through.
        spec: Scenario definition the run was produced from.
        result: Simulation output to persist.
        seed: RNG seed recorded on the ``McRun`` row.
        elapsed_seconds: Simulation wall time. It is stored on ``McRun`` and
            echoed back in the returned summary.
        bucket_quantiles: Percentiles of the ending portfolio written to the
            run's ``p10_ending_gbp``/``p50_ending_gbp``/``p90_ending_gbp``
            columns, in that order. NOTE: the scenario summary always uses
            10/50/90 regardless of this argument.

    Returns:
        ``WriteSummary`` with the scenario id, the new run id, the
        caller-supplied elapsed time and the run's success rate.

    Raises:
        ValueError: if ``bucket_quantiles`` does not hold exactly three values
            (raised by the tuple unpacking).
    """
    scenario_id = await upsert_scenario(session, spec)

    success_rate = result.success_rate
    p10, p50, p90 = (result.ending_percentile(p) for p in bucket_quantiles)
    median_tax = result.median_lifetime_tax()
    years_to_ruin = result.median_years_to_ruin()
    seq_corr = result.sequence_risk_correlation()

    run_row = McRun(
        scenario_id=scenario_id,
        n_paths=result.n_paths,
        seed=seed,
        success_rate=_to_dec(success_rate),
        p10_ending_gbp=_to_dec(p10),
        p50_ending_gbp=_to_dec(p50),
        p90_ending_gbp=_to_dec(p90),
        median_lifetime_tax_gbp=_to_dec(median_tax),
        median_years_to_ruin=_to_dec(years_to_ruin) if years_to_ruin is not None else None,
        elapsed_seconds=_to_dec(elapsed_seconds),
        sequence_risk_correlation=_to_dec(seq_corr),
    )
    session.add(run_row)
    # Flush so the database assigns run_row.id before child rows reference it.
    await session.flush()
    mc_run_id = int(run_row.id)

    await _write_projection(session, mc_run_id, result)
    await _write_sparse_paths(session, mc_run_id, result)
    await _upsert_summary(session, scenario_id, mc_run_id, spec, result)
    await session.flush()

    return WriteSummary(
        scenario_id=scenario_id,
        mc_run_id=mc_run_id,
        elapsed_seconds=elapsed_seconds,
        success_rate=success_rate,
    )
async def _write_projection(session: AsyncSession, mc_run_id: int,
                            result: SimulationResult) -> None:
    """Add one ``ProjectionYearly`` row per simulated year to the session.

    Each row stores cross-path statistics for year ``y``:

    * the 10/25/50/75/90th percentiles of the real portfolio (column ``y + 1``),
    * the median real withdrawal and tax for that year,
    * the fraction of paths whose portfolio is still above zero.

    Rows are only added to the session; flushing is left to the caller.
    """
    # (n_paths, n_years + 1). Column 0 is presumably the starting balance,
    # which is why year y reads column y + 1. TODO confirm with the simulator.
    portfolios = result.portfolio_real

    # One percentile call yields all five curves, each of length n_years + 1.
    p10, p25, p50, p75, p90 = np.percentile(portfolios, [10, 25, 50, 75, 90], axis=0)
    # Per-year medians across paths, computed once instead of once per row.
    median_withdrawal = np.median(result.withdrawal_real, axis=0)
    median_tax = np.median(result.tax_real, axis=0)
    survival = (portfolios[:, 1:] > 0).mean(axis=0)

    rows = [
        ProjectionYearly(
            mc_run_id=mc_run_id,
            year_idx=y,
            p10_portfolio_gbp=_to_dec(p10[y + 1]),
            p25_portfolio_gbp=_to_dec(p25[y + 1]),
            p50_portfolio_gbp=_to_dec(p50[y + 1]),
            p75_portfolio_gbp=_to_dec(p75[y + 1]),
            p90_portfolio_gbp=_to_dec(p90[y + 1]),
            p50_withdrawal_gbp=_to_dec(median_withdrawal[y]),
            p50_tax_gbp=_to_dec(median_tax[y]),
            survival_rate=_to_dec(survival[y]),
        )
        for y in range(result.n_years)
    ]
    session.add_all(rows)
async def _write_sparse_paths(session: AsyncSession, mc_run_id: int,
                              result: SimulationResult) -> None:
    """Add yearly ``McPath`` rows for a few representative simulation paths.

    Paths are ranked by their final real portfolio value. Up to three are kept
    per bucket:

    * ``"bottom"``: the three lowest-ending paths,
    * ``"median"``: three paths starting at the median rank,
    * ``"top"``: the three highest-ending paths.

    Only these paths are stored, to keep storage low. Rows are added to the
    session; flushing is left to the caller.
    """
    per_bucket = 3
    portfolios = result.portfolio_real
    withdrawals = result.withdrawal_real
    taxes = result.tax_real

    order = np.argsort(portfolios[:, -1])
    n = len(order)
    # Both tails take the extreme paths so "top" mirrors "bottom" for every n.
    # The old form order[-max(3, n // 20):][:3] switched to the lowest three
    # of the top 5% once n >= 60.
    # NOTE(review): with fewer than ~9 paths the buckets overlap, so the same
    # path_idx can appear under several buckets. Confirm the mc_path key
    # includes `bucket`.
    buckets = {
        "bottom": order[:per_bucket],
        "median": order[n // 2:n // 2 + per_bucket],
        "top": order[-per_bucket:],
    }

    rows: list[McPath] = []
    for bucket_name, idxs in buckets.items():
        for path_idx in idxs:
            for y in range(result.n_years):
                rows.append(
                    McPath(
                        mc_run_id=mc_run_id,
                        path_idx=int(path_idx),
                        bucket=bucket_name,
                        year_idx=y,
                        # NOTE(review): portfolio_gbp is also filled from the
                        # real series. Confirm whether a nominal value is intended.
                        portfolio_gbp=_to_dec(portfolios[path_idx, y + 1]),
                        withdrawal_gbp=_to_dec(withdrawals[path_idx, y]),
                        tax_paid_gbp=_to_dec(taxes[path_idx, y]),
                        real_portfolio_gbp=_to_dec(portfolios[path_idx, y + 1]),
                    ))
    session.add_all(rows)
async def _upsert_summary(
    session: AsyncSession,
    scenario_id: int,
    mc_run_id: int,
    spec: ScenarioSpec,
    result: SimulationResult,
) -> None:
    """Insert or refresh the ``ScenarioSummary`` row for ``scenario_id``.

    The row points at the given run (``mc_run_id``). It also keeps a
    denormalized copy of the scenario's descriptive fields and the run's
    headline statistics:

    * success rate,
    * 10/50/90th percentile ending value,
    * median lifetime tax,
    * median years to ruin.

    On conflict the run pointer, spending level and statistics are
    overwritten. Jurisdiction, strategy, leave-UK year and glide path keep
    their stored values, matching what ``upsert_scenario`` refreshes on the
    ``Scenario`` row.
    """
    insert_ = _dialect_insert(session)
    years_to_ruin = result.median_years_to_ruin()
    stmt = insert_(ScenarioSummary).values(
        scenario_id=scenario_id,
        mc_run_id=mc_run_id,
        jurisdiction=spec.jurisdiction,
        strategy=spec.strategy,
        leave_uk_year=spec.leave_uk_year,
        glide_path=spec.glide_path,
        spending_gbp=spec.spending_gbp,
        success_rate=_to_dec(result.success_rate),
        p10_ending_gbp=_to_dec(result.ending_percentile(10)),
        p50_ending_gbp=_to_dec(result.ending_percentile(50)),
        p90_ending_gbp=_to_dec(result.ending_percentile(90)),
        median_lifetime_tax_gbp=_to_dec(result.median_lifetime_tax()),
        median_years_to_ruin=_to_dec(years_to_ruin) if years_to_ruin is not None else None,
    )
    stmt = stmt.on_conflict_do_update(
        index_elements=["scenario_id"],
        set_={
            "mc_run_id": stmt.excluded.mc_run_id,
            # upsert_scenario overwrites Scenario.spending_gbp on conflict.
            # Refresh the copy here too, or the summary would pair the new
            # run's results with a stale spending level.
            "spending_gbp": stmt.excluded.spending_gbp,
            "success_rate": stmt.excluded.success_rate,
            "p10_ending_gbp": stmt.excluded.p10_ending_gbp,
            "p50_ending_gbp": stmt.excluded.p50_ending_gbp,
            "p90_ending_gbp": stmt.excluded.p90_ending_gbp,
            "median_lifetime_tax_gbp": stmt.excluded.median_lifetime_tax_gbp,
            "median_years_to_ruin": stmt.excluded.median_years_to_ruin,
        },
    )
    await session.execute(stmt)