"""Sync simulate + multi-scenario compare. Unlike the persisted Cartesian recompute (`/recompute`), these run a single scenario inline and return the result immediately. The React UI uses these for what-if exploration — no DB write. Returns a fan-chart series in the same shape as `GET /scenarios/{id}/projection`, so frontend chart code is shared. """ from __future__ import annotations import asyncio import logging import time from decimal import Decimal from pathlib import Path import numpy as np from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from fire_planner.api.dependencies import get_session from fire_planner.api.schemas import ( CompareRequest, CompareResult, ExamplesOverlay, GoalProbability, ProjectionPoint, SimulateRequest, SimulateResult, ) from fire_planner.col import compute_col_ratio, representative_city_for from fire_planner.examples.service import summary_for_country from fire_planner.flex_spending import FlexRule as EngineFlexRule from fire_planner.glide_path import static from fire_planner.goals_eval import evaluate_goals from fire_planner.income_streams import IncomeStreamInput, streams_to_arrays from fire_planner.ingest.wealthfolio_pg import create_wf_sync_engine_from_env from fire_planner.life_events import ( EventInput, events_to_cashflow_array, events_to_category_outflows, ) from fire_planner.returns.bootstrap import block_bootstrap from fire_planner.returns.shiller import load_from_csv, synthetic_returns from fire_planner.returns.wealthfolio_returns import ( compute_annual_returns_from_pg, constant_real_return_paths, ) from fire_planner.scenarios import build_regime_schedule, build_strategy from fire_planner.simulator import SimulationResult, build_fixed_paths, simulate router = APIRouter(tags=["simulate"]) log = logging.getLogger(__name__) _RETURNS_CSV = Path("/data/shiller_returns.csv") # Maps `SimulateRequest.jurisdiction` (lowercase slug used throughout the # planner — e.g. "thailand") to the country name as stored in # `fire_example.country` (e.g. "Thailand"). The keys mirror # `JURISDICTION_REPRESENTATIVE_CITY` so the overlay covers every # jurisdiction with a fixed country. `nomad` has no fixed country and is # intentionally absent. _JURISDICTION_COUNTRY: dict[str, str] = { "uk": "United Kingdom", "cyprus": "Cyprus", "bulgaria": "Bulgaria", "uae": "United Arab Emirates", "malaysia": "Malaysia", "thailand": "Thailand", } def _resolve_target_country_for_examples(req: SimulateRequest) -> str | None: return _JURISDICTION_COUNTRY.get(req.jurisdiction.lower()) async def _build_examples_overlay( session: AsyncSession, req: SimulateRequest, ) -> ExamplesOverlay | None: """Look up real-world Reddit examples for the scenario's target country. Returns None when the jurisdiction has no fixed country (e.g. nomad), when no examples are stored, or when the lookup fails for any reason — examples are informational and must never sink a successful simulation.""" try: country = _resolve_target_country_for_examples(req) if country is None: return None summary = await summary_for_country(session, country) if summary.count == 0: return None return ExamplesOverlay( country=summary.country, count=summary.count, portfolio_gbp_median=summary.portfolio_gbp.median, portfolio_gbp_p25=summary.portfolio_gbp.p25, portfolio_gbp_p75=summary.portfolio_gbp.p75, annual_exp_gbp_median=summary.annual_exp_gbp.median, sample_links=summary.sample_links, ) except Exception: log.warning("examples_overlay lookup failed", exc_info=True) return None def _resolve_col_adjustment( req: SimulateRequest, ) -> tuple[SimulateRequest, Decimal | None, Decimal | None, str | None]: """Apply cost-of-living adjustment to `req.spending_gbp` when enabled. Returns the (possibly modified) request, the multiplier applied (or None), the post-adjustment spending GBP (or None), and the resolved target city slug (or None). Skipped silently when: - col_auto_adjust is False - the jurisdiction has no representative city (e.g. nomad) - baseline_city == resolved target city (identity transform) - either city is unknown to the baseline lookup (degrade gracefully rather than 400 — a future Phase-2 scraper will close the gap) """ if not req.col_auto_adjust: return req, None, None, None target = req.col_target_city or representative_city_for(req.jurisdiction) if target is None: return req, None, None, None if target == req.col_baseline_city: return req, None, None, target try: ratio = compute_col_ratio(req.col_baseline_city, target) except KeyError: return req, None, None, target adjusted_spend = req.spending_gbp * ratio adjusted_req = req.model_copy(update={"spending_gbp": adjusted_spend}) return adjusted_req, ratio, adjusted_spend, target def _shiller_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray: bundle = (load_from_csv(_RETURNS_CSV) if _RETURNS_CSV.exists() else synthetic_returns(seed=42)) rng = np.random.default_rng(seed) return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=5, rng=rng) async def _wealthfolio_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray: """Block-bootstrap the user's actual blended real returns. With typically <10 distinct annual samples, block_size=1 is appropriate — there's no serial-correlation signal to preserve.""" eng = create_wf_sync_engine_from_env() try: factory = async_sessionmaker(eng, expire_on_commit=False) async with factory() as wf_sess: bundle = await compute_annual_returns_from_pg(wf_sess) finally: await eng.dispose() rng = np.random.default_rng(seed) return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=1, rng=rng) async def _build_paths(req: SimulateRequest) -> np.ndarray: if req.rates_mode == "fixed": return build_fixed_paths( n_paths=req.n_paths, n_years=req.horizon_years, inflation_pct=float(req.inflation_pct), stocks_growth_pct=float(req.stocks_growth_pct), stocks_dividend_pct=float(req.stocks_dividend_pct), bonds_growth_pct=float(req.bonds_growth_pct), bonds_dividend_pct=float(req.bonds_dividend_pct), ) if req.returns_mode == "manual": if req.manual_real_return_pct is None: raise HTTPException( status_code=400, detail="manual_real_return_pct is required when returns_mode='manual'", ) return constant_real_return_paths( n_paths=req.n_paths, n_years=req.horizon_years, real_return_pct=float(req.manual_real_return_pct), ) if req.returns_mode == "wealthfolio": try: return await _wealthfolio_paths(req.seed, req.n_paths, req.horizon_years) except ValueError as e: raise HTTPException( status_code=400, detail=f"Wealthfolio history insufficient: {e}", ) from e return _shiller_paths(req.seed, req.n_paths, req.horizon_years) def _project(req: SimulateRequest, paths: np.ndarray) -> tuple[SimulationResult, float]: annual_savings = (np.full(req.horizon_years, float(req.savings_per_year_gbp), dtype=np.float64) if req.savings_per_year_gbp > 0 else None) floor = float(req.floor_gbp) if req.floor_gbp is not None else None cashflow_adjustments = None discretionary_outflows = None extra_outflows = None if req.life_events: engine_events = [ EventInput( year_start=ev.year_start, year_end=ev.year_end, delta_gbp_per_year=float(ev.delta_gbp_per_year), one_time_amount_gbp=(float(ev.one_time_amount_gbp) if ev.one_time_amount_gbp is not None else None), category=ev.category, enabled=ev.enabled, ) for ev in req.life_events ] cashflow_adjustments = events_to_cashflow_array(engine_events, req.horizon_years) category_outflows = events_to_category_outflows(engine_events, req.horizon_years) discretionary_outflows = category_outflows.get("discretionary") # extra_outflows feeds the withdrawal-trace display: total of # essential + discretionary spending events surfaces alongside # the strategy's draw on the chart. essential = category_outflows.get("essential") if essential is not None and discretionary_outflows is not None: extra_outflows = essential + discretionary_outflows engine_flex = [ EngineFlexRule( from_ath_pct=float(r.from_ath_pct), cut_discretionary_pct=float(r.cut_discretionary_pct), ) for r in req.flex_rules ] if req.flex_rules else None income_inflows = None income_taxable = None if req.income_streams: engine_streams = [ IncomeStreamInput( kind=s.kind, start_year=s.start_year, end_year=s.end_year, amount_gbp_per_year=float(s.amount_gbp_per_year), growth_pct=float(s.growth_pct), tax_treatment=s.tax_treatment, enabled=s.enabled, ) for s in req.income_streams ] income_inflows, income_taxable = streams_to_arrays(engine_streams, req.horizon_years) strategy = build_strategy( req.strategy, floor=floor, annual_real_adjust_pct=float(req.annual_real_adjust_pct), guardrail_threshold_pct=(float(req.guardrail_threshold_pct) if req.guardrail_threshold_pct is not None else None), guardrail_cut_pct=float(req.guardrail_cut_pct), ) glide_alloc = float(req.stocks_allocation) if req.rates_mode == "fixed" else 1.0 started = time.perf_counter() result = simulate( paths=paths, initial_portfolio=float(req.nw_seed_gbp), spending_target=float(req.spending_gbp), glide=static(glide_alloc), strategy=strategy, regime=build_regime_schedule(req.jurisdiction, req.leave_uk_year), horizon_years=req.horizon_years, annual_savings=annual_savings, cashflow_adjustments=cashflow_adjustments, income_inflows=income_inflows, income_taxable=income_taxable, discretionary_outflows=discretionary_outflows, extra_outflows=extra_outflows, flex_rules=engine_flex, ) elapsed = time.perf_counter() - started return result, elapsed def _to_response( result: SimulationResult, elapsed: float, req: SimulateRequest | None = None, col_multiplier: Decimal | None = None, col_adjusted_spend: Decimal | None = None, col_target_city: str | None = None, examples_overlay: ExamplesOverlay | None = None, ) -> SimulateResult: # portfolio_real has n_years+1 columns (year 0 = seed, year k = end-of-year k). # withdrawal_real / tax_real have n_years columns (year k = withdrawn in year k+1). # Yearly point k describes "end of year k+1": portfolio after withdrawal & growth. pcts = [10, 25, 50, 75, 90] portfolio_quantiles = {p: np.percentile(result.portfolio_real, p, axis=0) for p in pcts} median_wd = np.percentile(result.withdrawal_real, 50, axis=0) median_tax = np.percentile(result.tax_real, 50, axis=0) n_years = result.n_years survival_path = (result.success_mask.astype(np.float64).mean(axis=0) if result.success_mask.ndim == 2 else np.ones(n_years)) yearly = [ ProjectionPoint( year_idx=y, p10_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[10][y + 1]), 2))), p25_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[25][y + 1]), 2))), p50_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[50][y + 1]), 2))), p75_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[75][y + 1]), 2))), p90_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[90][y + 1]), 2))), p50_withdrawal_gbp=Decimal(str(round(float(median_wd[y]), 2))), p50_tax_gbp=Decimal(str(round(float(median_tax[y]), 2))), survival_rate=Decimal(str(round(float(survival_path[y]), 4))), ) for y in range(n_years) ] median_ytr = result.median_years_to_ruin() goals_probability: list[GoalProbability] = [] if req is not None and req.goals: evaluations = evaluate_goals(result, req.goals, req.horizon_years) goals_probability = [ GoalProbability( goal_id=None, name=ev.name, kind=ev.kind, probability=Decimal(str(round(ev.probability, 4))), threshold=Decimal(str(round(ev.threshold, 4))), passed=ev.passed, ) for ev in evaluations ] return SimulateResult( success_rate=Decimal(str(round(float(result.success_rate), 4))), p10_ending_gbp=Decimal(str(round(float(result.ending_percentile(10)), 2))), p50_ending_gbp=Decimal(str(round(float(result.ending_percentile(50)), 2))), p90_ending_gbp=Decimal(str(round(float(result.ending_percentile(90)), 2))), median_lifetime_tax_gbp=Decimal(str(round(float(result.median_lifetime_tax()), 2))), median_years_to_ruin=(Decimal(str(round(float(median_ytr), 2))) if median_ytr is not None else None), elapsed_seconds=Decimal(str(round(elapsed, 3))), yearly=yearly, goals_probability=goals_probability, col_multiplier_applied=(Decimal(str(round(float(col_multiplier), 6))) if col_multiplier is not None else None), col_adjusted_spending_gbp=(Decimal(str(round(float(col_adjusted_spend), 2))) if col_adjusted_spend is not None else None), col_target_city=col_target_city, examples_overlay=examples_overlay, ) @router.post("/simulate", response_model=SimulateResult) async def simulate_one( req: SimulateRequest, session: AsyncSession = Depends(get_session), ) -> SimulateResult: """Run one scenario synchronously, no DB write. ~1-3s for 5k paths.""" adjusted_req, mult, adj_spend, target_city = _resolve_col_adjustment(req) paths = await _build_paths(adjusted_req) try: result, elapsed = await asyncio.to_thread(_project, adjusted_req, paths) except KeyError as e: raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None overlay = await _build_examples_overlay(session, adjusted_req) return _to_response( result, elapsed, adjusted_req, mult, adj_spend, target_city, overlay) @router.post("/compare", response_model=CompareResult) async def compare_scenarios( req: CompareRequest, session: AsyncSession = Depends(get_session), ) -> CompareResult: """Run 2-5 scenarios in parallel, return all results.""" async def one(s: SimulateRequest) -> tuple[SimulationResult, float, SimulateRequest, Decimal | None, Decimal | None, str | None]: adjusted_s, mult, adj_spend, target_city = _resolve_col_adjustment(s) paths = await _build_paths(adjusted_s) result, elapsed = await asyncio.to_thread(_project, adjusted_s, paths) return result, elapsed, adjusted_s, mult, adj_spend, target_city try: projected = await asyncio.gather(*(one(s) for s in req.scenarios)) except KeyError as e: raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None # Overlay lookups must run sequentially — AsyncSession is not safe for # concurrent use. The lookup is fast (single SELECT) and informational # only, so per-scenario serial cost is negligible. results = [] for result, elapsed, adjusted_s, mult, adj_spend, target_city in projected: overlay = await _build_examples_overlay(session, adjusted_s) results.append(_to_response( result, elapsed, adjusted_s, mult, adj_spend, target_city, overlay)) return CompareResult(results=results)