224 lines
8 KiB
Python
224 lines
8 KiB
Python
"""Postgres reporter — write MC results into `mc_run`,
|
|
`projection_yearly`, `mc_path` (sparse), `scenario_summary`."""
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from dataclasses import dataclass
|
|
from decimal import Decimal
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from sqlalchemy import select
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from fire_planner.db import McPath, McRun, ProjectionYearly, Scenario, ScenarioSummary
|
|
from fire_planner.scenarios import ScenarioSpec
|
|
from fire_planner.simulator import SimulationResult
|
|
|
|
|
|
def _dialect_insert(session: AsyncSession) -> Any:
    """Pick the dialect-specific ``insert`` construct for *session*'s bind.

    Returns SQLAlchemy's SQLite ``insert`` when the bound dialect is named
    ``"sqlite"`` and the PostgreSQL ``insert`` for every other dialect. The
    upserts in this module call ``on_conflict_do_update``/``excluded`` on the
    returned construct.
    """
    dialect_name = session.get_bind().dialect.name
    return sqlite_insert if dialect_name == "sqlite" else pg_insert
|
|
|
|
|
|
@dataclass(frozen=True)
class WriteSummary:
    """Immutable record returned by ``write_run`` after persisting a run.

    Instances are frozen, so they compare and hash by value.
    """

    # Id of the upserted ``scenario`` row.
    scenario_id: int
    # Id of the newly inserted ``mc_run`` row.
    mc_run_id: int
    # Elapsed time passed in by the caller, echoed back unchanged.
    elapsed_seconds: float
    # ``result.success_rate`` of the persisted run.
    success_rate: float
|
|
|
|
|
|
def _to_dec(x: float | int) -> Decimal:
|
|
return Decimal(str(round(float(x), 4)))
|
|
|
|
|
|
async def upsert_scenario(session: AsyncSession, spec: ScenarioSpec) -> int:
    """Insert or refresh the ``scenario`` row for *spec* and return its id.

    The row is keyed on ``external_id``. When a row with that id already
    exists, only ``spending_gbp``, ``horizon_years``, ``nw_seed_gbp``,
    ``savings_per_year_gbp`` and ``config_json`` are overwritten; the other
    columns keep their stored values. A falsy ``spec.config`` is stored as
    an empty dict. The id is read back with a separate ``SELECT`` after a
    flush.
    """
    insert_stmt = _dialect_insert(session)(Scenario).values(
        **dict(
            external_id=spec.external_id,
            jurisdiction=spec.jurisdiction,
            strategy=spec.strategy,
            leave_uk_year=spec.leave_uk_year,
            glide_path=spec.glide_path,
            spending_gbp=spec.spending_gbp,
            horizon_years=spec.horizon_years,
            nw_seed_gbp=spec.nw_seed_gbp,
            savings_per_year_gbp=spec.savings_per_year_gbp,
            config_json=spec.config or {},
        )
    )

    # Columns refreshed from the incoming values on an external_id conflict.
    refreshed_columns = (
        "spending_gbp",
        "horizon_years",
        "nw_seed_gbp",
        "savings_per_year_gbp",
        "config_json",
    )
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["external_id"],
        set_={name: getattr(insert_stmt.excluded, name) for name in refreshed_columns},
    )

    await session.execute(upsert_stmt)
    await session.flush()

    id_query = select(Scenario.id).where(Scenario.external_id == spec.external_id)
    id_result = await session.execute(id_query)
    return int(id_result.scalar_one())
|
|
|
|
|
|
async def write_run(
    session: AsyncSession,
    spec: ScenarioSpec,
    result: SimulationResult,
    *,
    seed: int,
    elapsed_seconds: float,
    bucket_quantiles: tuple[int, int, int] = (10, 50, 90),
) -> WriteSummary:
    """Persist one Monte Carlo run and everything derived from it.

    All writes go through *session*. They are flushed but never committed
    here, so the caller owns the transaction. In order, this:

    1. upserts the ``scenario`` row for *spec*;
    2. inserts a new ``mc_run`` row with the run's headline statistics;
    3. adds per-year ``projection_yearly`` rows;
    4. adds sparse ``mc_path`` rows for a few representative paths;
    5. upserts ``scenario_summary`` so it points at the new run.

    Args:
        session: Async session to write through.
        spec: Scenario the run was simulated for.
        result: Simulation output to persist.
        seed: RNG seed recorded on the ``mc_run`` row.
        elapsed_seconds: Elapsed time supplied by the caller; stored on
            ``mc_run`` and echoed back in the returned summary.
        bucket_quantiles: Percentiles of the ending portfolio stored, in
            order, in the ``p10``/``p50``/``p90`` ending columns of
            ``mc_run``. NOTE(review): ``scenario_summary`` always stores the
            10th/50th/90th percentiles, so with non-default values the
            summary and the run it points at disagree.

    Returns:
        A :class:`WriteSummary` with the scenario id, the new run id, the
        given *elapsed_seconds*, and the run's success rate.
    """
    scenario_id = await upsert_scenario(session, spec)

    success_rate = result.success_rate
    p10, p50, p90 = (result.ending_percentile(q) for q in bucket_quantiles)
    median_tax = result.median_lifetime_tax()
    years_to_ruin = result.median_years_to_ruin()
    seq_corr = result.sequence_risk_correlation()

    run_row = McRun(
        scenario_id=scenario_id,
        n_paths=result.n_paths,
        seed=seed,
        success_rate=_to_dec(success_rate),
        p10_ending_gbp=_to_dec(p10),
        p50_ending_gbp=_to_dec(p50),
        p90_ending_gbp=_to_dec(p90),
        median_lifetime_tax_gbp=_to_dec(median_tax),
        # median_years_to_ruin() may return None; store NULL in that case.
        median_years_to_ruin=(
            _to_dec(years_to_ruin) if years_to_ruin is not None else None
        ),
        elapsed_seconds=_to_dec(elapsed_seconds),
        sequence_risk_correlation=_to_dec(seq_corr),
    )
    session.add(run_row)
    # Flush so run_row.id is populated before child rows reference it.
    await session.flush()
    mc_run_id = int(run_row.id)

    await _write_projection(session, mc_run_id, result)
    await _write_sparse_paths(session, mc_run_id, result)
    await _upsert_summary(session, scenario_id, mc_run_id, spec, result)
    await session.flush()

    return WriteSummary(
        scenario_id=scenario_id,
        mc_run_id=mc_run_id,
        elapsed_seconds=elapsed_seconds,
        success_rate=success_rate,
    )
|
|
|
|
|
|
async def _write_projection(session: AsyncSession, mc_run_id: int,
                            result: SimulationResult) -> None:
    """Add one ``projection_yearly`` row per simulated year for *mc_run_id*.

    ``result.portfolio_real`` carries one more column than there are years
    (see the shape note below), so the portfolio bands for year ``y`` are
    read from column ``y + 1``. Withdrawal and tax medians come from column
    ``y`` of their arrays. Survival for year ``y`` is the fraction of paths
    whose portfolio is strictly positive in column ``y + 1``. Rows are added
    to the session but not flushed.
    """
    portfolios = result.portfolio_real  # (n_paths, n_years+1)

    # All five bands from a single percentile call; each has one entry per
    # portfolio column.
    p10, p25, p50, p75, p90 = np.percentile(
        portfolios, [10, 25, 50, 75, 90], axis=0)
    # Column-wise medians computed once rather than once per year in the loop.
    median_withdrawal = np.median(result.withdrawal_real, axis=0)
    median_tax = np.median(result.tax_real, axis=0)
    survival = (portfolios[:, 1:] > 0).mean(axis=0)

    session.add_all([
        ProjectionYearly(
            mc_run_id=mc_run_id,
            year_idx=y,
            p10_portfolio_gbp=_to_dec(p10[y + 1]),
            p25_portfolio_gbp=_to_dec(p25[y + 1]),
            p50_portfolio_gbp=_to_dec(p50[y + 1]),
            p75_portfolio_gbp=_to_dec(p75[y + 1]),
            p90_portfolio_gbp=_to_dec(p90[y + 1]),
            p50_withdrawal_gbp=_to_dec(median_withdrawal[y]),
            p50_tax_gbp=_to_dec(median_tax[y]),
            survival_rate=_to_dec(float(survival[y])),
        )
        for y in range(result.n_years)
    ])
|
|
|
|
|
|
async def _write_sparse_paths(session: AsyncSession, mc_run_id: int,
                              result: SimulationResult) -> None:
    """Add ``mc_path`` rows for a few representative simulation paths.

    Paths are ranked by their ending real portfolio value, and up to three
    path indices are kept per bucket:

    * ``"bottom"`` — the three least-extreme paths inside the worst 5% tail;
    * ``"median"`` — the three paths starting at the median rank;
    * ``"top"`` — the three least-extreme paths inside the best 5% tail.

    The bottom and top selections mirror each other. With fewer than 60
    paths each tail is only three paths wide, so ``"bottom"`` and ``"top"``
    become the three worst and three best paths. Each kept path contributes
    one row per simulated year. Rows are added to the session but not
    flushed.
    """
    per_bucket = 3
    portfolio = result.portfolio_real
    order = np.argsort(portfolio[:, -1])
    n = len(order)
    # Width of each 5% tail, but never fewer paths than we keep per bucket.
    tail = max(per_bucket, n // 20)
    buckets = {
        # Last `per_bucket` ranks of the bottom tail, the mirror image of
        # the first `per_bucket` ranks of the top tail below.
        "bottom": order[:tail][-per_bucket:],
        "median": order[n // 2:n // 2 + per_bucket],
        "top": order[-tail:][:per_bucket],
    }

    rows: list[McPath] = []
    for bucket_name, idxs in buckets.items():
        for path_idx in idxs:
            i = int(path_idx)
            for y in range(result.n_years):
                real_value = _to_dec(portfolio[i, y + 1])
                rows.append(
                    McPath(
                        mc_run_id=mc_run_id,
                        path_idx=i,
                        bucket=bucket_name,
                        year_idx=y,
                        # NOTE(review): portfolio_gbp and real_portfolio_gbp
                        # are both filled from portfolio_real; confirm whether
                        # portfolio_gbp is meant to hold a nominal value.
                        portfolio_gbp=real_value,
                        withdrawal_gbp=_to_dec(result.withdrawal_real[i, y]),
                        tax_paid_gbp=_to_dec(result.tax_real[i, y]),
                        real_portfolio_gbp=real_value,
                    ))
    session.add_all(rows)
|
|
|
|
|
|
async def _upsert_summary(
    session: AsyncSession,
    scenario_id: int,
    mc_run_id: int,
    spec: ScenarioSpec,
    result: SimulationResult,
) -> None:
    """Insert or refresh the ``scenario_summary`` row for *scenario_id*.

    The row is keyed on ``scenario_id`` and describes run *mc_run_id*. On
    conflict, every non-key column is overwritten. This includes the scenario
    attributes copied from *spec*, so that e.g. ``spending_gbp`` (which
    ``upsert_scenario`` also refreshes under the same ``external_id``) always
    matches the run the row points at. Ending percentiles stored here are
    always the 10th, 50th and 90th.
    """
    years_to_ruin = result.median_years_to_ruin()

    insert_ = _dialect_insert(session)
    stmt = insert_(ScenarioSummary).values(
        scenario_id=scenario_id,
        mc_run_id=mc_run_id,
        jurisdiction=spec.jurisdiction,
        strategy=spec.strategy,
        leave_uk_year=spec.leave_uk_year,
        glide_path=spec.glide_path,
        spending_gbp=spec.spending_gbp,
        success_rate=_to_dec(result.success_rate),
        p10_ending_gbp=_to_dec(result.ending_percentile(10)),
        p50_ending_gbp=_to_dec(result.ending_percentile(50)),
        p90_ending_gbp=_to_dec(result.ending_percentile(90)),
        median_lifetime_tax_gbp=_to_dec(result.median_lifetime_tax()),
        # median_years_to_ruin() may return None; store NULL in that case.
        median_years_to_ruin=(
            _to_dec(years_to_ruin) if years_to_ruin is not None else None
        ),
    )
    excluded = stmt.excluded
    stmt = stmt.on_conflict_do_update(
        index_elements=["scenario_id"],
        set_={
            "mc_run_id": excluded.mc_run_id,
            # Scenario attributes copied from the spec that produced this
            # run; refreshed so they never disagree with the stats below.
            "jurisdiction": excluded.jurisdiction,
            "strategy": excluded.strategy,
            "leave_uk_year": excluded.leave_uk_year,
            "glide_path": excluded.glide_path,
            "spending_gbp": excluded.spending_gbp,
            # Run statistics.
            "success_rate": excluded.success_rate,
            "p10_ending_gbp": excluded.p10_ending_gbp,
            "p50_ending_gbp": excluded.p50_ending_gbp,
            "p90_ending_gbp": excluded.p90_ending_gbp,
            "median_lifetime_tax_gbp": excluded.median_lifetime_tax_gbp,
            "median_years_to_ruin": excluded.median_years_to_ruin,
        },
    )
    await session.execute(stmt)
|