"""Postgres reporter — write MC results into `mc_run`, `projection_yearly`, `mc_path` (sparse), `scenario_summary`.""" from __future__ import annotations import time from dataclasses import dataclass from decimal import Decimal from typing import Any import numpy as np from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.ext.asyncio import AsyncSession from fire_planner.db import McPath, McRun, ProjectionYearly, Scenario, ScenarioSummary from fire_planner.scenarios import ScenarioSpec from fire_planner.simulator import SimulationResult def _dialect_insert(session: AsyncSession) -> Any: bind = session.get_bind() if bind.dialect.name == "sqlite": return sqlite_insert return pg_insert @dataclass(frozen=True) class WriteSummary: scenario_id: int mc_run_id: int elapsed_seconds: float success_rate: float def _to_dec(x: float | int) -> Decimal: return Decimal(str(round(float(x), 4))) async def upsert_scenario(session: AsyncSession, spec: ScenarioSpec) -> int: insert_ = _dialect_insert(session) stmt = insert_(Scenario).values( external_id=spec.external_id, jurisdiction=spec.jurisdiction, strategy=spec.strategy, leave_uk_year=spec.leave_uk_year, glide_path=spec.glide_path, spending_gbp=spec.spending_gbp, horizon_years=spec.horizon_years, nw_seed_gbp=spec.nw_seed_gbp, savings_per_year_gbp=spec.savings_per_year_gbp, config_json=spec.config or {}, ) stmt = stmt.on_conflict_do_update( index_elements=["external_id"], set_={ "spending_gbp": stmt.excluded.spending_gbp, "horizon_years": stmt.excluded.horizon_years, "nw_seed_gbp": stmt.excluded.nw_seed_gbp, "savings_per_year_gbp": stmt.excluded.savings_per_year_gbp, "config_json": stmt.excluded.config_json, }, ) await session.execute(stmt) await session.flush() row = await session.execute(select(Scenario.id).where(Scenario.external_id == spec.external_id)) scenario_id = row.scalar_one() return int(scenario_id) async def write_run( session: AsyncSession, spec: ScenarioSpec, result: SimulationResult, *, seed: int, elapsed_seconds: float, bucket_quantiles: tuple[int, int, int] = (10, 50, 90), ) -> WriteSummary: """Upsert scenario, append a new mc_run, persist projection_yearly, save sparse mc_path rows, and refresh scenario_summary. """ started = time.perf_counter() scenario_id = await upsert_scenario(session, spec) success_rate = result.success_rate p10, p50, p90 = (result.ending_percentile(p) for p in bucket_quantiles) median_tax = result.median_lifetime_tax() years_to_ruin = result.median_years_to_ruin() seq_corr = result.sequence_risk_correlation() run_row = McRun( scenario_id=scenario_id, n_paths=result.n_paths, seed=seed, success_rate=_to_dec(success_rate), p10_ending_gbp=_to_dec(p10), p50_ending_gbp=_to_dec(p50), p90_ending_gbp=_to_dec(p90), median_lifetime_tax_gbp=_to_dec(median_tax), median_years_to_ruin=_to_dec(years_to_ruin) if years_to_ruin is not None else None, elapsed_seconds=_to_dec(elapsed_seconds), sequence_risk_correlation=_to_dec(seq_corr), ) session.add(run_row) await session.flush() mc_run_id = int(run_row.id) await _write_projection(session, mc_run_id, result) await _write_sparse_paths(session, mc_run_id, result) await _upsert_summary(session, scenario_id, mc_run_id, spec, result) await session.flush() write_elapsed = time.perf_counter() - started del write_elapsed # surface via tracing if needed return WriteSummary( scenario_id=scenario_id, mc_run_id=mc_run_id, elapsed_seconds=elapsed_seconds, success_rate=success_rate, ) async def _write_projection(session: AsyncSession, mc_run_id: int, result: SimulationResult) -> None: n_years = result.n_years portfolios = result.portfolio_real # (n_paths, n_years+1) p10 = np.percentile(portfolios, 10, axis=0) p25 = np.percentile(portfolios, 25, axis=0) p50 = np.percentile(portfolios, 50, axis=0) p75 = np.percentile(portfolios, 75, axis=0) p90 = np.percentile(portfolios, 90, axis=0) withdrawals = result.withdrawal_real taxes = result.tax_real survival = (portfolios[:, 1:] > 0).mean(axis=0) rows = [] for y in range(n_years): rows.append( ProjectionYearly( mc_run_id=mc_run_id, year_idx=y, p10_portfolio_gbp=_to_dec(p10[y + 1]), p25_portfolio_gbp=_to_dec(p25[y + 1]), p50_portfolio_gbp=_to_dec(p50[y + 1]), p75_portfolio_gbp=_to_dec(p75[y + 1]), p90_portfolio_gbp=_to_dec(p90[y + 1]), p50_withdrawal_gbp=_to_dec(np.median(withdrawals[:, y])), p50_tax_gbp=_to_dec(np.median(taxes[:, y])), survival_rate=_to_dec(float(survival[y])), )) session.add_all(rows) async def _write_sparse_paths(session: AsyncSession, mc_run_id: int, result: SimulationResult) -> None: """Persist top-decile, bottom-decile, and median path indices. Picks 3 representative path indices per bucket to keep storage low. """ ending = result.portfolio_real[:, -1] order = np.argsort(ending) n = len(order) buckets = { "bottom": order[:max(3, n // 20)][:3], "median": order[n // 2:n // 2 + 3], "top": order[-max(3, n // 20):][:3], } rows: list[McPath] = [] for bucket_name, idxs in buckets.items(): for path_idx in idxs: for y in range(result.n_years): rows.append( McPath( mc_run_id=mc_run_id, path_idx=int(path_idx), bucket=bucket_name, year_idx=y, portfolio_gbp=_to_dec(result.portfolio_real[path_idx, y + 1]), withdrawal_gbp=_to_dec(result.withdrawal_real[path_idx, y]), tax_paid_gbp=_to_dec(result.tax_real[path_idx, y]), real_portfolio_gbp=_to_dec(result.portfolio_real[path_idx, y + 1]), )) session.add_all(rows) async def _upsert_summary( session: AsyncSession, scenario_id: int, mc_run_id: int, spec: ScenarioSpec, result: SimulationResult, ) -> None: insert_ = _dialect_insert(session) stmt = insert_(ScenarioSummary).values( scenario_id=scenario_id, mc_run_id=mc_run_id, jurisdiction=spec.jurisdiction, strategy=spec.strategy, leave_uk_year=spec.leave_uk_year, glide_path=spec.glide_path, spending_gbp=spec.spending_gbp, success_rate=_to_dec(result.success_rate), p10_ending_gbp=_to_dec(result.ending_percentile(10)), p50_ending_gbp=_to_dec(result.ending_percentile(50)), p90_ending_gbp=_to_dec(result.ending_percentile(90)), median_lifetime_tax_gbp=_to_dec(result.median_lifetime_tax()), median_years_to_ruin=(_to_dec(ytr) if (ytr := result.median_years_to_ruin()) is not None else None), ) stmt = stmt.on_conflict_do_update( index_elements=["scenario_id"], set_={ "mc_run_id": stmt.excluded.mc_run_id, "success_rate": stmt.excluded.success_rate, "p10_ending_gbp": stmt.excluded.p10_ending_gbp, "p50_ending_gbp": stmt.excluded.p50_ending_gbp, "p90_ending_gbp": stmt.excluded.p90_ending_gbp, "median_lifetime_tax_gbp": stmt.excluded.median_lifetime_tax_gbp, "median_years_to_ruin": stmt.excluded.median_years_to_ruin, }, ) await session.execute(stmt)