fire-planner/fire_planner/reporters/pg.py

268 lines
9.7 KiB
Python
Raw Normal View History

2026-05-07 17:06:19 +00:00
"""Postgres reporter — write MC results into `mc_run`,
`projection_yearly`, `mc_path` (sparse), `scenario_summary`."""
from __future__ import annotations
import time
from dataclasses import dataclass
from decimal import Decimal
from typing import Any
import numpy as np
from sqlalchemy import func, select
2026-05-07 17:06:19 +00:00
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.db import FireTarget, McPath, McRun, ProjectionYearly, Scenario, ScenarioSummary
from fire_planner.fire_target import SolveResult, TargetInputs
2026-05-07 17:06:19 +00:00
from fire_planner.scenarios import ScenarioSpec
from fire_planner.simulator import SimulationResult
def _dialect_insert(session: AsyncSession) -> Any:
    """Return the dialect-specific ``insert`` construct for *session*'s bind.

    The SQLite construct is used when the bound dialect is named ``"sqlite"``.
    Every other dialect gets the PostgreSQL construct. Callers rely on the
    returned object supporting ``on_conflict_do_update``.
    """
    dialect_name = session.get_bind().dialect.name
    return sqlite_insert if dialect_name == "sqlite" else pg_insert
@dataclass(frozen=True)
class WriteSummary:
    """Immutable record returned by ``write_run`` describing one persisted run."""

    # Primary key of the upserted ``scenario`` row.
    scenario_id: int
    # Primary key of the newly appended ``mc_run`` row.
    mc_run_id: int
    # Elapsed time as supplied by the caller of ``write_run`` (not measured here).
    elapsed_seconds: float
    # ``success_rate`` taken directly from the simulation result.
    success_rate: float
def _to_dec(x: float | int) -> Decimal:
return Decimal(str(round(float(x), 4)))
async def upsert_scenario(session: AsyncSession, spec: ScenarioSpec) -> int:
    """Insert *spec* into ``scenario`` (or refresh it) and return the row id.

    Rows are matched on ``external_id``. On conflict only the spending,
    horizon, seed net-worth, savings and config columns are overwritten. The
    descriptor columns (jurisdiction, strategy, leave_uk_year, glide_path)
    keep their first-inserted values. The id is read back with a separate
    ``SELECT`` after flushing.
    """
    insert_stmt = _dialect_insert(session)(Scenario).values(
        external_id=spec.external_id,
        jurisdiction=spec.jurisdiction,
        strategy=spec.strategy,
        leave_uk_year=spec.leave_uk_year,
        glide_path=spec.glide_path,
        spending_gbp=spec.spending_gbp,
        horizon_years=spec.horizon_years,
        nw_seed_gbp=spec.nw_seed_gbp,
        savings_per_year_gbp=spec.savings_per_year_gbp,
        config_json=spec.config or {},
    )
    refreshed_columns = (
        "spending_gbp",
        "horizon_years",
        "nw_seed_gbp",
        "savings_per_year_gbp",
        "config_json",
    )
    upsert = insert_stmt.on_conflict_do_update(
        index_elements=["external_id"],
        set_={col: getattr(insert_stmt.excluded, col) for col in refreshed_columns},
    )
    await session.execute(upsert)
    await session.flush()
    lookup = select(Scenario.id).where(Scenario.external_id == spec.external_id)
    found = await session.execute(lookup)
    return int(found.scalar_one())
async def write_run(
    session: AsyncSession,
    spec: ScenarioSpec,
    result: SimulationResult,
    *,
    seed: int,
    elapsed_seconds: float,
    bucket_quantiles: tuple[int, int, int] = (10, 50, 90),
) -> WriteSummary:
    """Upsert scenario, append a new mc_run, persist projection_yearly,
    save sparse mc_path rows, and refresh scenario_summary.

    All writes are flushed to *session*; nothing is committed here.

    Args:
        session: Async session that every row is written through.
        spec: Scenario definition, upserted by ``external_id``.
        result: Simulation output being persisted.
        seed: RNG seed recorded on the ``mc_run`` row.
        elapsed_seconds: Simulation time supplied by the caller. It is stored
            on ``mc_run`` and echoed in the returned summary.
        bucket_quantiles: The three ending-wealth percentiles stored in the
            ``p10`` / ``p50`` / ``p90`` ending columns of ``mc_run``.

    Returns:
        WriteSummary holding the scenario id, the new ``mc_run`` id, the
        caller's ``elapsed_seconds`` and ``result.success_rate``.
    """
    scenario_id = await upsert_scenario(session, spec)
    success_rate = result.success_rate
    # NOTE(review): scenario_summary always uses the 10/50/90 percentiles, so a
    # non-default bucket_quantiles makes mc_run and scenario_summary disagree
    # on the same-named p10/p50/p90 columns.
    p10, p50, p90 = (result.ending_percentile(p) for p in bucket_quantiles)
    median_tax = result.median_lifetime_tax()
    years_to_ruin = result.median_years_to_ruin()
    seq_corr = result.sequence_risk_correlation()
    run_row = McRun(
        scenario_id=scenario_id,
        n_paths=result.n_paths,
        seed=seed,
        success_rate=_to_dec(success_rate),
        p10_ending_gbp=_to_dec(p10),
        p50_ending_gbp=_to_dec(p50),
        p90_ending_gbp=_to_dec(p90),
        median_lifetime_tax_gbp=_to_dec(median_tax),
        median_years_to_ruin=_to_dec(years_to_ruin) if years_to_ruin is not None else None,
        elapsed_seconds=_to_dec(elapsed_seconds),
        sequence_risk_correlation=_to_dec(seq_corr),
    )
    session.add(run_row)
    # Flush so the database assigns run_row.id before the child rows need it.
    await session.flush()
    mc_run_id = int(run_row.id)
    await _write_projection(session, mc_run_id, result)
    await _write_sparse_paths(session, mc_run_id, result)
    await _upsert_summary(session, scenario_id, mc_run_id, spec, result)
    await session.flush()
    return WriteSummary(
        scenario_id=scenario_id,
        mc_run_id=mc_run_id,
        elapsed_seconds=elapsed_seconds,
        success_rate=success_rate,
    )
async def _write_projection(session: AsyncSession, mc_run_id: int,
                            result: SimulationResult) -> None:
    """Stage one ``ProjectionYearly`` row per simulated year of *mc_run_id*.

    For year ``y`` the portfolio percentiles and survival use portfolio
    column ``y + 1``, while the median withdrawal and tax use column ``y``.
    Rows are only added to the session; flushing is left to the caller.
    """
    horizon = result.n_years
    wealth = result.portfolio_real  # (n_paths, n_years+1)
    bands = {q: np.percentile(wealth, q, axis=0) for q in (10, 25, 50, 75, 90)}
    withdrawal_matrix = result.withdrawal_real
    tax_matrix = result.tax_real
    # Share of paths whose portfolio is strictly positive, per column 1..n_years.
    alive_share = (wealth[:, 1:] > 0).mean(axis=0)
    session.add_all([
        ProjectionYearly(
            mc_run_id=mc_run_id,
            year_idx=year,
            p10_portfolio_gbp=_to_dec(bands[10][year + 1]),
            p25_portfolio_gbp=_to_dec(bands[25][year + 1]),
            p50_portfolio_gbp=_to_dec(bands[50][year + 1]),
            p75_portfolio_gbp=_to_dec(bands[75][year + 1]),
            p90_portfolio_gbp=_to_dec(bands[90][year + 1]),
            p50_withdrawal_gbp=_to_dec(np.median(withdrawal_matrix[:, year])),
            p50_tax_gbp=_to_dec(np.median(tax_matrix[:, year])),
            survival_rate=_to_dec(float(alive_share[year])),
        )
        for year in range(horizon)
    ])
async def _write_sparse_paths(session: AsyncSession, mc_run_id: int,
result: SimulationResult) -> None:
"""Persist top-decile, bottom-decile, and median path indices.
Picks 3 representative path indices per bucket to keep storage low.
"""
ending = result.portfolio_real[:, -1]
order = np.argsort(ending)
n = len(order)
buckets = {
"bottom": order[:max(3, n // 20)][:3],
"median": order[n // 2:n // 2 + 3],
"top": order[-max(3, n // 20):][:3],
}
rows: list[McPath] = []
for bucket_name, idxs in buckets.items():
for path_idx in idxs:
for y in range(result.n_years):
rows.append(
McPath(
mc_run_id=mc_run_id,
path_idx=int(path_idx),
bucket=bucket_name,
year_idx=y,
portfolio_gbp=_to_dec(result.portfolio_real[path_idx, y + 1]),
withdrawal_gbp=_to_dec(result.withdrawal_real[path_idx, y]),
tax_paid_gbp=_to_dec(result.tax_real[path_idx, y]),
real_portfolio_gbp=_to_dec(result.portfolio_real[path_idx, y + 1]),
))
session.add_all(rows)
async def _upsert_summary(
    session: AsyncSession,
    scenario_id: int,
    mc_run_id: int,
    spec: ScenarioSpec,
    result: SimulationResult,
) -> None:
    """Insert or refresh the ``scenario_summary`` row keyed by *scenario_id*.

    The row is pointed at *mc_run_id*, and its headline metrics are rewritten
    from *result*. The metrics are the success rate, the 10th/50th/90th ending
    percentiles, the median lifetime tax and the median years to ruin.

    ``spending_gbp`` is refreshed on conflict too, because ``upsert_scenario``
    updates it on the scenario row. Without this, the summary would keep the
    spending level from the first run. The other descriptor columns
    (jurisdiction, strategy, leave_uk_year, glide_path) are only written on
    first insert, and ``upsert_scenario`` does not update them on conflict
    either.
    """
    insert_ = _dialect_insert(session)
    years_to_ruin = result.median_years_to_ruin()
    stmt = insert_(ScenarioSummary).values(
        scenario_id=scenario_id,
        mc_run_id=mc_run_id,
        jurisdiction=spec.jurisdiction,
        strategy=spec.strategy,
        leave_uk_year=spec.leave_uk_year,
        glide_path=spec.glide_path,
        spending_gbp=spec.spending_gbp,
        success_rate=_to_dec(result.success_rate),
        p10_ending_gbp=_to_dec(result.ending_percentile(10)),
        p50_ending_gbp=_to_dec(result.ending_percentile(50)),
        p90_ending_gbp=_to_dec(result.ending_percentile(90)),
        median_lifetime_tax_gbp=_to_dec(result.median_lifetime_tax()),
        median_years_to_ruin=(_to_dec(years_to_ruin)
                              if years_to_ruin is not None else None),
    )
    stmt = stmt.on_conflict_do_update(
        index_elements=["scenario_id"],
        set_={
            "mc_run_id": stmt.excluded.mc_run_id,
            "spending_gbp": stmt.excluded.spending_gbp,
            "success_rate": stmt.excluded.success_rate,
            "p10_ending_gbp": stmt.excluded.p10_ending_gbp,
            "p50_ending_gbp": stmt.excluded.p50_ending_gbp,
            "p90_ending_gbp": stmt.excluded.p90_ending_gbp,
            "median_lifetime_tax_gbp": stmt.excluded.median_lifetime_tax_gbp,
            "median_years_to_ruin": stmt.excluded.median_years_to_ruin,
        },
    )
    await session.execute(stmt)
async def upsert_fire_target(
    session: AsyncSession,
    inp: TargetInputs,
    result: SolveResult,
    n_paths: int,
) -> None:
    """Upsert one solved FIRE target on (case, country, with_home, bar).

    On conflict, the solved figures and descriptive columns are overwritten
    from the new insert, and ``updated_at`` is set to the database's
    ``now()``. ``strategy`` is always written as ``"guyton_klinger"`` and is
    not part of the conflict update.
    """
    insert_factory = _dialect_insert(session)
    row_values = {
        "case": inp.case.value,
        "country_slug": inp.country_slug,
        "country_display": inp.country_display,
        "jurisdiction": inp.jurisdiction,
        "with_home": inp.with_home,
        "bar": _to_dec(inp.bar),
        "strategy": "guyton_klinger",
        "annual_spend_gbp": _to_dec(inp.annual_spend_gbp),
        "target_nw_gbp": _to_dec(result.target_nw_gbp),
        "pension_at_unlock_gbp": _to_dec(result.pension_at_unlock_gbp),
        "success_at_target": _to_dec(result.success_at_target),
        "reached_bar": result.reached_bar,
        "horizon_years": inp.horizon_years,
        "n_paths": n_paths,
    }
    insert_stmt = insert_factory(FireTarget).values(**row_values)
    overwritten = (
        "country_display",
        "jurisdiction",
        "annual_spend_gbp",
        "target_nw_gbp",
        "pension_at_unlock_gbp",
        "success_at_target",
        "reached_bar",
        "horizon_years",
        "n_paths",
    )
    on_conflict_set = {col: getattr(insert_stmt.excluded, col) for col in overwritten}
    on_conflict_set["updated_at"] = func.now()
    await session.execute(
        insert_stmt.on_conflict_do_update(
            index_elements=["case", "country_slug", "with_home", "bar"],
            set_=on_conflict_set,
        ))