"""Scenario CRUD + projection read. Mixed surface: - GET /scenarios list (filter by kind) - GET /scenarios/{id} single - POST /scenarios create user scenario - PATCH /scenarios/{id} update user scenario - DELETE /scenarios/{id} delete user scenario (cartesian protected) - GET /scenarios/{id}/projection latest MC run + per-year fan """ from __future__ import annotations import uuid from decimal import Decimal from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from fire_planner.api.dependencies import get_session from fire_planner.api.schemas import ( GoalProbability, ProjectionPoint, ScenarioCreate, ScenarioOut, ScenarioPatch, ScenarioProjection, ) from fire_planner.db import McRun, ProjectionYearly, RetirementGoal, Scenario router = APIRouter(prefix="/scenarios", tags=["scenarios"]) def _approx_prob_above(yearly_row: ProjectionYearly, target_amount: float) -> float: """Approximate the fraction of paths whose portfolio at the row's year is at or above ``target_amount``, using linear interpolation across the persisted p10/p25/p50/p75/p90 cells. Exact only when target_amount lands on a stored quantile, otherwise a piecewise-linear estimate. Good enough for Wave 1 — the live /simulate endpoint computes the exact path-by-path probability. """ from itertools import pairwise cells = [ (10, float(yearly_row.p10_portfolio_gbp)), (25, float(yearly_row.p25_portfolio_gbp)), (50, float(yearly_row.p50_portfolio_gbp)), (75, float(yearly_row.p75_portfolio_gbp)), (90, float(yearly_row.p90_portfolio_gbp)), ] cells.sort(key=lambda kv: kv[1]) if target_amount <= cells[0][1]: return 1.0 if target_amount >= cells[-1][1]: return 0.0 for (lo_pct, lo_v), (hi_pct, hi_v) in pairwise(cells): if lo_v <= target_amount <= hi_v: span = hi_v - lo_v if span == 0: return max(0.0, 1.0 - lo_pct / 100.0) t = (target_amount - lo_v) / span pct = lo_pct + t * (hi_pct - lo_pct) return max(0.0, min(1.0, 1.0 - pct / 100.0)) return 0.0 def _evaluate_goal_against_fan( goal: RetirementGoal, yearly_rows: list[ProjectionYearly], horizon_years: int, ) -> GoalProbability | None: if not goal.enabled: return None threshold = float(goal.success_threshold) def _result(prob: float) -> GoalProbability: return GoalProbability( goal_id=goal.id, name=goal.name, kind=goal.kind, probability=Decimal(str(round(prob, 4))), threshold=Decimal(str(round(threshold, 4))), passed=prob >= threshold, ) by_year = {row.year_idx: row for row in yearly_rows} if not by_year: return _result(0.0) last_year = max(by_year) if goal.kind == "target_nw_by_year": if goal.target_year is None or goal.target_amount_gbp is None: return _result(0.0) y = max(0, min(int(goal.target_year), last_year)) row = by_year.get(y) if row is None: return _result(0.0) prob = _approx_prob_above(row, float(goal.target_amount_gbp)) return _result(prob) if goal.kind == "never_run_out": end = int(goal.target_year) if goal.target_year is not None else horizon_years end = max(0, min(end, last_year)) row = by_year.get(end) if row is None: return _result(0.0) return _result(float(row.survival_rate)) if goal.kind == "target_real_income": if goal.target_amount_gbp is None: return _result(0.0) target = float(goal.target_amount_gbp) start_y = int(goal.target_year) if goal.target_year is not None else 0 window = [r for r in yearly_rows if r.year_idx >= start_y] if not window: return _result(0.0) median_wd = sorted(float(r.p50_withdrawal_gbp) for r in window) mid = median_wd[len(median_wd) // 2] return _result(1.0 if mid >= target else 0.0) return _result(0.0) @router.get("", response_model=list[ScenarioOut]) async def list_scenarios( kind: str | None = None, session: AsyncSession = Depends(get_session), ) -> list[Scenario]: """List all scenarios. Filter `kind=user` or `kind=cartesian`.""" stmt = select(Scenario).order_by(Scenario.id) if kind is not None: stmt = stmt.where(Scenario.kind == kind) rows = (await session.execute(stmt)).scalars().all() return list(rows) @router.get("/{scenario_id}", response_model=ScenarioOut) async def get_scenario( scenario_id: int, session: AsyncSession = Depends(get_session), ) -> Scenario: scen = await session.get(Scenario, scenario_id) if scen is None: raise HTTPException(status_code=404, detail="Scenario not found") return scen @router.post( "", response_model=ScenarioOut, status_code=201, ) async def create_scenario( payload: ScenarioCreate, session: AsyncSession = Depends(get_session), ) -> Scenario: """Create a user scenario. Cartesian scenarios come from the engine, not the API.""" if payload.parent_scenario_id is not None: parent = await session.get(Scenario, payload.parent_scenario_id) if parent is None: raise HTTPException(status_code=400, detail="parent_scenario_id not found") scen = Scenario( external_id=f"user-{uuid.uuid4().hex[:12]}", kind="user", name=payload.name, description=payload.description, parent_scenario_id=payload.parent_scenario_id, jurisdiction=payload.jurisdiction, strategy=payload.strategy, leave_uk_year=payload.leave_uk_year, glide_path=payload.glide_path, spending_gbp=payload.spending_gbp, horizon_years=payload.horizon_years, nw_seed_gbp=payload.nw_seed_gbp, savings_per_year_gbp=payload.savings_per_year_gbp, config_json=payload.config_json, ) session.add(scen) await session.commit() await session.refresh(scen) return scen @router.patch( "/{scenario_id}", response_model=ScenarioOut, ) async def patch_scenario( scenario_id: int, payload: ScenarioPatch, session: AsyncSession = Depends(get_session), ) -> Scenario: scen = await session.get(Scenario, scenario_id) if scen is None: raise HTTPException(status_code=404, detail="Scenario not found") updates = payload.model_dump(exclude_unset=True) if scen.kind != "user": # Cartesian scenarios are rebuilt on every recompute — most core # fields would be wiped by the next run, so we only allow updates # to free-form metadata that we want to preserve across recomputes # (notes, flex_rules, rate overrides). Hard-block edits to the # parameters that define the scenario shape. allowed_for_cartesian = {"config_json", "name", "description"} bad = set(updates) - allowed_for_cartesian if bad: raise HTTPException( status_code=400, detail=("Cannot patch cartesian scenario fields {sorted(bad)} — " "they're auto-generated. Only config_json/name/description " "may be updated."), ) for k, v in updates.items(): setattr(scen, k, v) await session.commit() await session.refresh(scen) return scen @router.delete( "/{scenario_id}", status_code=204, response_model=None, ) async def delete_scenario( scenario_id: int, session: AsyncSession = Depends(get_session), ) -> None: scen = await session.get(Scenario, scenario_id) if scen is None: raise HTTPException(status_code=404, detail="Scenario not found") if scen.kind != "user": raise HTTPException( status_code=400, detail="Cannot delete cartesian scenarios — they re-appear on recompute", ) await session.execute(delete(Scenario).where(Scenario.id == scenario_id)) await session.commit() @router.get("/{scenario_id}/projection", response_model=ScenarioProjection) async def get_scenario_projection( scenario_id: int, session: AsyncSession = Depends(get_session), ) -> ScenarioProjection: """Latest MC run for this scenario + the per-year fan series.""" scen = await session.get(Scenario, scenario_id) if scen is None: raise HTTPException(status_code=404, detail="Scenario not found") run = (await session.execute( select(McRun).where(McRun.scenario_id == scenario_id).order_by( McRun.run_at.desc()).limit(1))).scalar_one_or_none() if run is None: raise HTTPException(status_code=404, detail="No MC runs persisted for this scenario yet") yearly_rows = list((await session.execute( select(ProjectionYearly).where(ProjectionYearly.mc_run_id == run.id).order_by( ProjectionYearly.year_idx))).scalars().all()) goals_rows = list((await session.execute( select(RetirementGoal).where( RetirementGoal.scenario_id == scenario_id))).scalars().all()) goals_probability: list[GoalProbability] = [] for goal in goals_rows: evaluation = _evaluate_goal_against_fan(goal, yearly_rows, scen.horizon_years) if evaluation is not None: goals_probability.append(evaluation) return ScenarioProjection( scenario_id=scen.id, external_id=scen.external_id, mc_run_id=run.id, run_at=run.run_at, n_paths=run.n_paths, success_rate=run.success_rate, p10_ending_gbp=run.p10_ending_gbp, p50_ending_gbp=run.p50_ending_gbp, p90_ending_gbp=run.p90_ending_gbp, median_lifetime_tax_gbp=run.median_lifetime_tax_gbp, median_years_to_ruin=run.median_years_to_ruin, yearly=[ProjectionPoint.model_validate(y) for y in yearly_rows], goals_probability=goals_probability, )