"""Sync simulate + multi-scenario compare. Unlike the persisted Cartesian recompute (`/recompute`), these run a single scenario inline and return the result immediately. The React UI uses these for what-if exploration — no DB write. Returns a fan-chart series in the same shape as `GET /scenarios/{id}/projection`, so frontend chart code is shared. """ from __future__ import annotations import asyncio import time from decimal import Decimal from pathlib import Path import numpy as np from fastapi import APIRouter, HTTPException from sqlalchemy.ext.asyncio import async_sessionmaker from fire_planner.api.schemas import ( CompareRequest, CompareResult, GoalProbability, ProjectionPoint, SimulateRequest, SimulateResult, ) from fire_planner.flex_spending import FlexRule as EngineFlexRule from fire_planner.glide_path import static from fire_planner.goals_eval import evaluate_goals from fire_planner.income_streams import IncomeStreamInput, streams_to_arrays from fire_planner.ingest.wealthfolio_pg import create_wf_sync_engine_from_env from fire_planner.life_events import ( EventInput, events_to_cashflow_array, events_to_category_outflows, ) from fire_planner.returns.bootstrap import block_bootstrap from fire_planner.returns.shiller import load_from_csv, synthetic_returns from fire_planner.returns.wealthfolio_returns import ( compute_annual_returns_from_pg, constant_real_return_paths, ) from fire_planner.scenarios import build_regime_schedule, build_strategy from fire_planner.simulator import SimulationResult, build_fixed_paths, simulate router = APIRouter(tags=["simulate"]) _RETURNS_CSV = Path("/data/shiller_returns.csv") def _shiller_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray: bundle = (load_from_csv(_RETURNS_CSV) if _RETURNS_CSV.exists() else synthetic_returns(seed=42)) rng = np.random.default_rng(seed) return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=5, rng=rng) async def _wealthfolio_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray: """Block-bootstrap the user's actual blended real returns. With typically <10 distinct annual samples, block_size=1 is appropriate — there's no serial-correlation signal to preserve.""" eng = create_wf_sync_engine_from_env() try: factory = async_sessionmaker(eng, expire_on_commit=False) async with factory() as wf_sess: bundle = await compute_annual_returns_from_pg(wf_sess) finally: await eng.dispose() rng = np.random.default_rng(seed) return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=1, rng=rng) async def _build_paths(req: SimulateRequest) -> np.ndarray: if req.rates_mode == "fixed": return build_fixed_paths( n_paths=req.n_paths, n_years=req.horizon_years, inflation_pct=float(req.inflation_pct), stocks_growth_pct=float(req.stocks_growth_pct), stocks_dividend_pct=float(req.stocks_dividend_pct), bonds_growth_pct=float(req.bonds_growth_pct), bonds_dividend_pct=float(req.bonds_dividend_pct), ) if req.returns_mode == "manual": if req.manual_real_return_pct is None: raise HTTPException( status_code=400, detail="manual_real_return_pct is required when returns_mode='manual'", ) return constant_real_return_paths( n_paths=req.n_paths, n_years=req.horizon_years, real_return_pct=float(req.manual_real_return_pct), ) if req.returns_mode == "wealthfolio": try: return await _wealthfolio_paths(req.seed, req.n_paths, req.horizon_years) except ValueError as e: raise HTTPException( status_code=400, detail=f"Wealthfolio history insufficient: {e}", ) from e return _shiller_paths(req.seed, req.n_paths, req.horizon_years) def _project(req: SimulateRequest, paths: np.ndarray) -> tuple[SimulationResult, float]: annual_savings = (np.full(req.horizon_years, float(req.savings_per_year_gbp), dtype=np.float64) if req.savings_per_year_gbp > 0 else None) floor = float(req.floor_gbp) if req.floor_gbp is not None else None cashflow_adjustments = None discretionary_outflows = None extra_outflows = None if req.life_events: engine_events = [ EventInput( year_start=ev.year_start, year_end=ev.year_end, delta_gbp_per_year=float(ev.delta_gbp_per_year), one_time_amount_gbp=(float(ev.one_time_amount_gbp) if ev.one_time_amount_gbp is not None else None), category=ev.category, enabled=ev.enabled, ) for ev in req.life_events ] cashflow_adjustments = events_to_cashflow_array(engine_events, req.horizon_years) category_outflows = events_to_category_outflows(engine_events, req.horizon_years) discretionary_outflows = category_outflows.get("discretionary") # extra_outflows feeds the withdrawal-trace display: total of # essential + discretionary spending events surfaces alongside # the strategy's draw on the chart. essential = category_outflows.get("essential") if essential is not None and discretionary_outflows is not None: extra_outflows = essential + discretionary_outflows engine_flex = [ EngineFlexRule( from_ath_pct=float(r.from_ath_pct), cut_discretionary_pct=float(r.cut_discretionary_pct), ) for r in req.flex_rules ] if req.flex_rules else None income_inflows = None income_taxable = None if req.income_streams: engine_streams = [ IncomeStreamInput( kind=s.kind, start_year=s.start_year, end_year=s.end_year, amount_gbp_per_year=float(s.amount_gbp_per_year), growth_pct=float(s.growth_pct), tax_treatment=s.tax_treatment, enabled=s.enabled, ) for s in req.income_streams ] income_inflows, income_taxable = streams_to_arrays(engine_streams, req.horizon_years) strategy = build_strategy( req.strategy, floor=floor, annual_real_adjust_pct=float(req.annual_real_adjust_pct), guardrail_threshold_pct=(float(req.guardrail_threshold_pct) if req.guardrail_threshold_pct is not None else None), guardrail_cut_pct=float(req.guardrail_cut_pct), ) glide_alloc = float(req.stocks_allocation) if req.rates_mode == "fixed" else 1.0 started = time.perf_counter() result = simulate( paths=paths, initial_portfolio=float(req.nw_seed_gbp), spending_target=float(req.spending_gbp), glide=static(glide_alloc), strategy=strategy, regime=build_regime_schedule(req.jurisdiction, req.leave_uk_year), horizon_years=req.horizon_years, annual_savings=annual_savings, cashflow_adjustments=cashflow_adjustments, income_inflows=income_inflows, income_taxable=income_taxable, discretionary_outflows=discretionary_outflows, extra_outflows=extra_outflows, flex_rules=engine_flex, ) elapsed = time.perf_counter() - started return result, elapsed def _to_response( result: SimulationResult, elapsed: float, req: SimulateRequest | None = None, ) -> SimulateResult: # portfolio_real has n_years+1 columns (year 0 = seed, year k = end-of-year k). # withdrawal_real / tax_real have n_years columns (year k = withdrawn in year k+1). # Yearly point k describes "end of year k+1": portfolio after withdrawal & growth. pcts = [10, 25, 50, 75, 90] portfolio_quantiles = {p: np.percentile(result.portfolio_real, p, axis=0) for p in pcts} median_wd = np.percentile(result.withdrawal_real, 50, axis=0) median_tax = np.percentile(result.tax_real, 50, axis=0) n_years = result.n_years survival_path = (result.success_mask.astype(np.float64).mean(axis=0) if result.success_mask.ndim == 2 else np.ones(n_years)) yearly = [ ProjectionPoint( year_idx=y, p10_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[10][y + 1]), 2))), p25_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[25][y + 1]), 2))), p50_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[50][y + 1]), 2))), p75_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[75][y + 1]), 2))), p90_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[90][y + 1]), 2))), p50_withdrawal_gbp=Decimal(str(round(float(median_wd[y]), 2))), p50_tax_gbp=Decimal(str(round(float(median_tax[y]), 2))), survival_rate=Decimal(str(round(float(survival_path[y]), 4))), ) for y in range(n_years) ] median_ytr = result.median_years_to_ruin() goals_probability: list[GoalProbability] = [] if req is not None and req.goals: evaluations = evaluate_goals(result, req.goals, req.horizon_years) goals_probability = [ GoalProbability( goal_id=None, name=ev.name, kind=ev.kind, probability=Decimal(str(round(ev.probability, 4))), threshold=Decimal(str(round(ev.threshold, 4))), passed=ev.passed, ) for ev in evaluations ] return SimulateResult( success_rate=Decimal(str(round(float(result.success_rate), 4))), p10_ending_gbp=Decimal(str(round(float(result.ending_percentile(10)), 2))), p50_ending_gbp=Decimal(str(round(float(result.ending_percentile(50)), 2))), p90_ending_gbp=Decimal(str(round(float(result.ending_percentile(90)), 2))), median_lifetime_tax_gbp=Decimal(str(round(float(result.median_lifetime_tax()), 2))), median_years_to_ruin=(Decimal(str(round(float(median_ytr), 2))) if median_ytr is not None else None), elapsed_seconds=Decimal(str(round(elapsed, 3))), yearly=yearly, goals_probability=goals_probability, ) @router.post("/simulate", response_model=SimulateResult) async def simulate_one(req: SimulateRequest) -> SimulateResult: """Run one scenario synchronously, no DB write. ~1-3s for 5k paths.""" paths = await _build_paths(req) try: result, elapsed = await asyncio.to_thread(_project, req, paths) except KeyError as e: raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None return _to_response(result, elapsed, req) @router.post("/compare", response_model=CompareResult) async def compare_scenarios(req: CompareRequest) -> CompareResult: """Run 2-5 scenarios in parallel, return all results.""" async def one(s: SimulateRequest) -> SimulateResult: paths = await _build_paths(s) result, elapsed = await asyncio.to_thread(_project, s, paths) return _to_response(result, elapsed, s) try: results = await asyncio.gather(*(one(s) for s in req.scenarios)) except KeyError as e: raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None return CompareResult(results=results)