api: expand FastAPI surface for scenarios, networth, life-events, goals, simulate
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful

Adds the read+write endpoints the frontend needs to drive a
ProjectionLab-style UX on top of the existing engine.

- /networth, /networth/history    — NW total + per-account from
                                     account_snapshot (frontend chart)
- /scenarios CRUD + projection    — list/get/create/patch/delete user
                                     scenarios; cartesian read-only
- /scenarios/{id}/life-events     — life event CRUD nested under scenario
- /life-events/{id}               — patch + delete by id
- /scenarios/{id}/goals,
  /goals/{id}                     — retirement goal CRUD
- /simulate, /compare             — sync, no-DB-write what-if endpoints

Auth: Bearer-token dependency on writes + simulate when API_BEARER_TOKEN
is set; reads always open (lock down via Authentik-fronted ingress in
prod). Existing /recompute keeps its bearer auth.

CORS middleware reads FRONTEND_ORIGINS (comma-separated) for the dev
SPA. Lifespan now provisions the SQLAlchemy engine + session_factory
on app.state and disposes them on shutdown.

40 new tests covering happy paths and validation. 172 tests total.
mypy strict + ruff clean (B008 ignore added — Depends() in defaults
is the canonical FastAPI pattern, not a bug).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Viktor Barzin 2026-05-09 21:48:36 +00:00
parent 31193faf08
commit ee6ed1d3c4
15 changed files with 1570 additions and 74 deletions

View file

@ -0,0 +1 @@
"""HTTP API surface — read + write endpoints over the engine + DB."""

42
fire_planner/api/auth.py Normal file
View file

@ -0,0 +1,42 @@
"""Bearer-token auth shared across routers.
Two modes, picked at startup from env:
- API_BEARER_TOKEN set enforce Bearer auth on all write/compute paths
- API_BEARER_TOKEN unset (dev) no auth, log a one-time warning
Read endpoints (`/networth`, `/scenarios`, ...) skip auth entirely so
the frontend can render without juggling tokens during dev. Lock those
down later via Authentik-fronted ingress when we deploy.
"""
from __future__ import annotations
import hmac
import logging
import os
from fastapi import Header, HTTPException
log = logging.getLogger(__name__)
_warned_unauth = False
def _read_token() -> str | None:
return os.environ.get("API_BEARER_TOKEN") or os.environ.get("RECOMPUTE_BEARER_TOKEN")
async def require_bearer(authorization: str | None = Header(default=None)) -> None:
"""FastAPI dependency: enforce bearer auth IF API_BEARER_TOKEN is set."""
expected = _read_token()
if not expected:
global _warned_unauth
if not _warned_unauth:
log.warning("API_BEARER_TOKEN unset — write endpoints are open. "
"Set it before exposing this service.")
_warned_unauth = True
return
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing bearer token")
token = authorization.removeprefix("Bearer ")
if not hmac.compare_digest(token, expected):
raise HTTPException(status_code=401, detail="Invalid token")

View file

@ -0,0 +1,18 @@
"""Shared FastAPI dependencies — DB session per request."""
from __future__ import annotations
from collections.abc import AsyncIterator
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
async def get_session(request: Request) -> AsyncIterator[AsyncSession]:
"""Yield an AsyncSession bound to the engine on app.state.
The engine + session factory are wired up in `app.lifespan`. Tests
swap them out via dependency_overrides.
"""
factory: async_sessionmaker[AsyncSession] = request.app.state.session_factory
async with factory() as session:
yield session

68
fire_planner/api/goals.py Normal file
View file

@ -0,0 +1,68 @@
"""Retirement-goal CRUD nested under a scenario."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.api.auth import require_bearer
from fire_planner.api.dependencies import get_session
from fire_planner.api.schemas import GoalCreate, GoalOut
from fire_planner.db import RetirementGoal, Scenario
router = APIRouter(tags=["goals"])
@router.get(
"/scenarios/{scenario_id}/goals",
response_model=list[GoalOut],
)
async def list_goals(
scenario_id: int,
session: AsyncSession = Depends(get_session),
) -> list[RetirementGoal]:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
rows = (await session.execute(
select(RetirementGoal).where(RetirementGoal.scenario_id == scenario_id).order_by(
RetirementGoal.id))).scalars().all()
return list(rows)
@router.post(
"/scenarios/{scenario_id}/goals",
response_model=GoalOut,
status_code=201,
dependencies=[Depends(require_bearer)],
)
async def create_goal(
scenario_id: int,
payload: GoalCreate,
session: AsyncSession = Depends(get_session),
) -> RetirementGoal:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
goal = RetirementGoal(scenario_id=scenario_id, **payload.model_dump())
session.add(goal)
await session.commit()
await session.refresh(goal)
return goal
@router.delete(
"/goals/{goal_id}",
status_code=204,
response_model=None,
dependencies=[Depends(require_bearer)],
)
async def delete_goal(
goal_id: int,
session: AsyncSession = Depends(get_session),
) -> None:
goal = await session.get(RetirementGoal, goal_id)
if goal is None:
raise HTTPException(status_code=404, detail="Goal not found")
await session.execute(delete(RetirementGoal).where(RetirementGoal.id == goal_id))
await session.commit()

View file

@ -0,0 +1,93 @@
"""Life-event CRUD nested under a scenario."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.api.auth import require_bearer
from fire_planner.api.dependencies import get_session
from fire_planner.api.schemas import LifeEventCreate, LifeEventOut, LifeEventPatch
from fire_planner.db import LifeEvent, Scenario
router = APIRouter(tags=["life-events"])
@router.get(
"/scenarios/{scenario_id}/life-events",
response_model=list[LifeEventOut],
)
async def list_events(
scenario_id: int,
session: AsyncSession = Depends(get_session),
) -> list[LifeEvent]:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
rows = (await session.execute(
select(LifeEvent).where(LifeEvent.scenario_id == scenario_id).order_by(
LifeEvent.year_start, LifeEvent.id))).scalars().all()
return list(rows)
@router.post(
"/scenarios/{scenario_id}/life-events",
response_model=LifeEventOut,
status_code=201,
dependencies=[Depends(require_bearer)],
)
async def create_event(
scenario_id: int,
payload: LifeEventCreate,
session: AsyncSession = Depends(get_session),
) -> LifeEvent:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
if payload.year_end is not None and payload.year_end < payload.year_start:
raise HTTPException(status_code=400, detail="year_end < year_start")
ev = LifeEvent(scenario_id=scenario_id, **payload.model_dump())
session.add(ev)
await session.commit()
await session.refresh(ev)
return ev
@router.patch(
"/life-events/{event_id}",
response_model=LifeEventOut,
dependencies=[Depends(require_bearer)],
)
async def patch_event(
event_id: int,
payload: LifeEventPatch,
session: AsyncSession = Depends(get_session),
) -> LifeEvent:
ev = await session.get(LifeEvent, event_id)
if ev is None:
raise HTTPException(status_code=404, detail="Event not found")
updates = payload.model_dump(exclude_unset=True)
for k, v in updates.items():
setattr(ev, k, v)
if ev.year_end is not None and ev.year_end < ev.year_start:
raise HTTPException(status_code=400, detail="year_end < year_start")
await session.commit()
await session.refresh(ev)
return ev
@router.delete(
"/life-events/{event_id}",
status_code=204,
response_model=None,
dependencies=[Depends(require_bearer)],
)
async def delete_event(
event_id: int,
session: AsyncSession = Depends(get_session),
) -> None:
ev = await session.get(LifeEvent, event_id)
if ev is None:
raise HTTPException(status_code=404, detail="Event not found")
await session.execute(delete(LifeEvent).where(LifeEvent.id == event_id))
await session.commit()

View file

@ -0,0 +1,78 @@
"""Net-worth read endpoints.
Reads from `fire_planner.account_snapshot` (populated hourly by the
wealthfolio ingest). Two views:
- GET /networth latest snapshot per account, totals
- GET /networth/history daily totals + per-account series, for charts
"""
from __future__ import annotations
from collections import defaultdict
from datetime import date
from decimal import Decimal
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.api.dependencies import get_session
from fire_planner.api.schemas import (
AccountSnapshotOut,
NetWorthCurrent,
NetWorthHistory,
NetWorthHistoryPoint,
)
from fire_planner.db import AccountSnapshot
router = APIRouter(prefix="/networth", tags=["networth"])
@router.get("", response_model=NetWorthCurrent)
async def current_networth(session: AsyncSession = Depends(get_session)) -> NetWorthCurrent:
"""Latest snapshot per account + GBP total."""
latest_date = (await session.execute(
select(AccountSnapshot.snapshot_date).order_by(
AccountSnapshot.snapshot_date.desc()).limit(1))).scalar()
if latest_date is None:
return NetWorthCurrent(snapshot_date=date.today(), total_gbp=Decimal("0"), accounts=[])
rows = (await session.execute(
select(AccountSnapshot).where(
AccountSnapshot.snapshot_date == latest_date))).scalars().all()
accounts = [AccountSnapshotOut.model_validate(r) for r in rows]
total = sum((a.market_value_gbp for a in accounts), Decimal("0"))
return NetWorthCurrent(snapshot_date=latest_date, total_gbp=total, accounts=accounts)
@router.get("/history", response_model=NetWorthHistory)
async def networth_history(
session: AsyncSession = Depends(get_session),
days: int = Query(default=365, ge=1, le=3650, description="Look-back window."),
) -> NetWorthHistory:
"""Daily NW total + per-account breakdown for a stacked area chart.
Picks one row per (account_id, snapshot_date) wealthfolio ingest
upserts daily so this is already de-duped, but we group defensively.
"""
rows = (await session.execute(
select(
AccountSnapshot.snapshot_date,
AccountSnapshot.account_name,
AccountSnapshot.market_value_gbp,
).order_by(AccountSnapshot.snapshot_date))).all()
if not rows:
return NetWorthHistory(points=[])
by_date: dict[date, dict[str, Decimal]] = defaultdict(lambda: defaultdict(lambda: Decimal("0")))
for snap_date, name, value in rows:
by_date[snap_date][name] += Decimal(str(value))
cutoff_idx = max(0, len(by_date) - days)
sorted_dates = sorted(by_date.keys())[cutoff_idx:]
points = [
NetWorthHistoryPoint(
snapshot_date=d,
total_gbp=sum(by_date[d].values(), Decimal("0")),
by_account=dict(by_date[d]),
) for d in sorted_dates
]
return NetWorthHistory(points=points)

View file

@ -0,0 +1,172 @@
"""Scenario CRUD + projection read.
Mixed surface:
- GET /scenarios list (filter by kind)
- GET /scenarios/{id} single
- POST /scenarios create user scenario
- PATCH /scenarios/{id} update user scenario
- DELETE /scenarios/{id} delete user scenario (cartesian protected)
- GET /scenarios/{id}/projection latest MC run + per-year fan
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from fire_planner.api.auth import require_bearer
from fire_planner.api.dependencies import get_session
from fire_planner.api.schemas import (
ProjectionPoint,
ScenarioCreate,
ScenarioOut,
ScenarioPatch,
ScenarioProjection,
)
from fire_planner.db import McRun, ProjectionYearly, Scenario
router = APIRouter(prefix="/scenarios", tags=["scenarios"])
@router.get("", response_model=list[ScenarioOut])
async def list_scenarios(
kind: str | None = None,
session: AsyncSession = Depends(get_session),
) -> list[Scenario]:
"""List all scenarios. Filter `kind=user` or `kind=cartesian`."""
stmt = select(Scenario).order_by(Scenario.id)
if kind is not None:
stmt = stmt.where(Scenario.kind == kind)
rows = (await session.execute(stmt)).scalars().all()
return list(rows)
@router.get("/{scenario_id}", response_model=ScenarioOut)
async def get_scenario(
scenario_id: int,
session: AsyncSession = Depends(get_session),
) -> Scenario:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
return scen
@router.post(
"",
response_model=ScenarioOut,
status_code=201,
dependencies=[Depends(require_bearer)],
)
async def create_scenario(
payload: ScenarioCreate,
session: AsyncSession = Depends(get_session),
) -> Scenario:
"""Create a user scenario. Cartesian scenarios come from the engine,
not the API."""
if payload.parent_scenario_id is not None:
parent = await session.get(Scenario, payload.parent_scenario_id)
if parent is None:
raise HTTPException(status_code=400, detail="parent_scenario_id not found")
scen = Scenario(
external_id=f"user-{uuid.uuid4().hex[:12]}",
kind="user",
name=payload.name,
description=payload.description,
parent_scenario_id=payload.parent_scenario_id,
jurisdiction=payload.jurisdiction,
strategy=payload.strategy,
leave_uk_year=payload.leave_uk_year,
glide_path=payload.glide_path,
spending_gbp=payload.spending_gbp,
horizon_years=payload.horizon_years,
nw_seed_gbp=payload.nw_seed_gbp,
savings_per_year_gbp=payload.savings_per_year_gbp,
config_json=payload.config_json,
)
session.add(scen)
await session.commit()
await session.refresh(scen)
return scen
@router.patch(
"/{scenario_id}",
response_model=ScenarioOut,
dependencies=[Depends(require_bearer)],
)
async def patch_scenario(
scenario_id: int,
payload: ScenarioPatch,
session: AsyncSession = Depends(get_session),
) -> Scenario:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
if scen.kind != "user":
raise HTTPException(status_code=400,
detail="Cannot patch cartesian scenarios — they're auto-generated")
updates = payload.model_dump(exclude_unset=True)
for k, v in updates.items():
setattr(scen, k, v)
await session.commit()
await session.refresh(scen)
return scen
@router.delete(
"/{scenario_id}",
status_code=204,
response_model=None,
dependencies=[Depends(require_bearer)],
)
async def delete_scenario(
scenario_id: int,
session: AsyncSession = Depends(get_session),
) -> None:
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
if scen.kind != "user":
raise HTTPException(
status_code=400,
detail="Cannot delete cartesian scenarios — they re-appear on recompute",
)
await session.execute(delete(Scenario).where(Scenario.id == scenario_id))
await session.commit()
@router.get("/{scenario_id}/projection", response_model=ScenarioProjection)
async def get_scenario_projection(
scenario_id: int,
session: AsyncSession = Depends(get_session),
) -> ScenarioProjection:
"""Latest MC run for this scenario + the per-year fan series."""
scen = await session.get(Scenario, scenario_id)
if scen is None:
raise HTTPException(status_code=404, detail="Scenario not found")
run = (await session.execute(
select(McRun).where(McRun.scenario_id == scenario_id).order_by(
McRun.run_at.desc()).limit(1))).scalar_one_or_none()
if run is None:
raise HTTPException(status_code=404, detail="No MC runs persisted for this scenario yet")
yearly_rows = (await session.execute(
select(ProjectionYearly).where(ProjectionYearly.mc_run_id == run.id).order_by(
ProjectionYearly.year_idx))).scalars().all()
return ScenarioProjection(
scenario_id=scen.id,
external_id=scen.external_id,
mc_run_id=run.id,
run_at=run.run_at,
n_paths=run.n_paths,
success_rate=run.success_rate,
p10_ending_gbp=run.p10_ending_gbp,
p50_ending_gbp=run.p50_ending_gbp,
p90_ending_gbp=run.p90_ending_gbp,
median_lifetime_tax_gbp=run.median_lifetime_tax_gbp,
median_years_to_ruin=run.median_years_to_ruin,
yearly=[ProjectionPoint.model_validate(y) for y in yearly_rows],
)

237
fire_planner/api/schemas.py Normal file
View file

@ -0,0 +1,237 @@
"""Pydantic response/request schemas for the HTTP API.
Mirror the SQLAlchemy ORM but keep them de-coupled the API surface is a
contract for the frontend; we don't want migrations to silently change
JSON shape.
"""
from __future__ import annotations
from datetime import date, datetime
from decimal import Decimal
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class _Base(BaseModel):
model_config = ConfigDict(from_attributes=True)
# ── scenarios ────────────────────────────────────────────────────────
class ScenarioOut(_Base):
id: int
external_id: str
kind: str
name: str | None
description: str | None
parent_scenario_id: int | None
jurisdiction: str
strategy: str
leave_uk_year: int
glide_path: str
spending_gbp: Decimal
horizon_years: int
nw_seed_gbp: Decimal
savings_per_year_gbp: Decimal
config_json: dict[str, Any]
created_at: datetime
class ScenarioCreate(BaseModel):
"""Body for POST /scenarios — user-defined scenario."""
name: str = Field(min_length=1, max_length=200)
description: str | None = None
parent_scenario_id: int | None = None
jurisdiction: str
strategy: str
leave_uk_year: int = Field(ge=0, le=60)
glide_path: str
spending_gbp: Decimal = Field(gt=0)
horizon_years: int = Field(ge=5, le=100, default=60)
nw_seed_gbp: Decimal = Field(ge=0)
savings_per_year_gbp: Decimal = Field(ge=0, default=Decimal("0"))
config_json: dict[str, Any] = Field(default_factory=dict)
class ScenarioPatch(BaseModel):
"""Body for PATCH /scenarios/{id} — all fields optional."""
name: str | None = None
description: str | None = None
jurisdiction: str | None = None
strategy: str | None = None
leave_uk_year: int | None = None
glide_path: str | None = None
spending_gbp: Decimal | None = None
horizon_years: int | None = None
nw_seed_gbp: Decimal | None = None
savings_per_year_gbp: Decimal | None = None
config_json: dict[str, Any] | None = None
# ── projections ──────────────────────────────────────────────────────
class ProjectionPoint(_Base):
year_idx: int
p10_portfolio_gbp: Decimal
p25_portfolio_gbp: Decimal
p50_portfolio_gbp: Decimal
p75_portfolio_gbp: Decimal
p90_portfolio_gbp: Decimal
p50_withdrawal_gbp: Decimal
p50_tax_gbp: Decimal
survival_rate: Decimal
class ScenarioProjection(BaseModel):
"""Latest MC run + per-year fan-chart series for a scenario."""
scenario_id: int
external_id: str
mc_run_id: int
run_at: datetime
n_paths: int
success_rate: Decimal
p10_ending_gbp: Decimal
p50_ending_gbp: Decimal
p90_ending_gbp: Decimal
median_lifetime_tax_gbp: Decimal
median_years_to_ruin: Decimal | None
yearly: list[ProjectionPoint]
# ── net worth ────────────────────────────────────────────────────────
class AccountSnapshotOut(_Base):
account_id: str
account_name: str
account_type: str
currency: str
snapshot_date: date
market_value: Decimal
market_value_gbp: Decimal
cost_basis_gbp: Decimal | None
class NetWorthCurrent(BaseModel):
"""Snapshot at one point in time (latest by default)."""
snapshot_date: date
total_gbp: Decimal
accounts: list[AccountSnapshotOut]
class NetWorthHistoryPoint(BaseModel):
snapshot_date: date
total_gbp: Decimal
by_account: dict[str, Decimal]
class NetWorthHistory(BaseModel):
"""Per-day NW totals + per-account breakdown for a stacked area chart."""
points: list[NetWorthHistoryPoint]
# ── life events ──────────────────────────────────────────────────────
class LifeEventOut(_Base):
id: int
scenario_id: int
kind: str
name: str
year_start: int
year_end: int | None
delta_gbp_per_year: Decimal
one_time_amount_gbp: Decimal | None
enabled: bool
payload: dict[str, Any] | None
created_at: datetime
class LifeEventCreate(BaseModel):
kind: str
name: str = Field(min_length=1, max_length=200)
year_start: int = Field(ge=0, le=100)
year_end: int | None = Field(default=None, ge=0, le=100)
delta_gbp_per_year: Decimal = Decimal("0")
one_time_amount_gbp: Decimal | None = None
enabled: bool = True
payload: dict[str, Any] | None = None
class LifeEventPatch(BaseModel):
kind: str | None = None
name: str | None = None
year_start: int | None = None
year_end: int | None = None
delta_gbp_per_year: Decimal | None = None
one_time_amount_gbp: Decimal | None = None
enabled: bool | None = None
payload: dict[str, Any] | None = None
# ── goals ────────────────────────────────────────────────────────────
class GoalOut(_Base):
id: int
scenario_id: int
kind: str
name: str
target_amount_gbp: Decimal | None
target_year: int | None
comparator: str
success_threshold: Decimal
enabled: bool
payload: dict[str, Any] | None
created_at: datetime
class GoalCreate(BaseModel):
kind: str
name: str = Field(min_length=1, max_length=200)
target_amount_gbp: Decimal | None = None
target_year: int | None = Field(default=None, ge=0, le=100)
comparator: str = ">="
success_threshold: Decimal = Field(default=Decimal("0.95"), ge=0, le=1)
enabled: bool = True
payload: dict[str, Any] | None = None
# ── simulate / compare ───────────────────────────────────────────────
class SimulateRequest(BaseModel):
"""Sync, non-persisted simulate. Used by the React UI for what-if."""
jurisdiction: str
strategy: str
leave_uk_year: int = Field(ge=0, le=60)
glide_path: str = "rising"
spending_gbp: Decimal = Field(gt=0)
nw_seed_gbp: Decimal = Field(ge=0)
savings_per_year_gbp: Decimal = Decimal("0")
horizon_years: int = Field(ge=5, le=100, default=60)
floor_gbp: Decimal | None = None
n_paths: int = Field(ge=100, le=50_000, default=5_000)
seed: int = 42
class SimulateResult(BaseModel):
success_rate: Decimal
p10_ending_gbp: Decimal
p50_ending_gbp: Decimal
p90_ending_gbp: Decimal
median_lifetime_tax_gbp: Decimal
median_years_to_ruin: Decimal | None
elapsed_seconds: Decimal
yearly: list[ProjectionPoint]
class CompareRequest(BaseModel):
scenarios: list[SimulateRequest] = Field(min_length=2, max_length=5)
class CompareResult(BaseModel):
results: list[SimulateResult]

View file

@ -0,0 +1,125 @@
"""Sync simulate + multi-scenario compare.
Unlike the persisted Cartesian recompute (`/recompute`), these run a
single scenario inline and return the result immediately. The React UI
uses these for what-if exploration no DB write.
Returns a fan-chart series in the same shape as
`GET /scenarios/{id}/projection`, so frontend chart code is shared.
"""
from __future__ import annotations
import asyncio
import time
from decimal import Decimal
from pathlib import Path
import numpy as np
from fastapi import APIRouter, Depends, HTTPException
from fire_planner.api.auth import require_bearer
from fire_planner.api.schemas import (
CompareRequest,
CompareResult,
ProjectionPoint,
SimulateRequest,
SimulateResult,
)
from fire_planner.glide_path import get as get_glide
from fire_planner.returns.bootstrap import block_bootstrap
from fire_planner.returns.shiller import load_from_csv, synthetic_returns
from fire_planner.scenarios import build_regime_schedule, build_strategy
from fire_planner.simulator import SimulationResult, simulate
router = APIRouter(tags=["simulate"], dependencies=[Depends(require_bearer)])
_RETURNS_CSV = Path("/data/shiller_returns.csv")
def _load_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray:
bundle = (load_from_csv(_RETURNS_CSV) if _RETURNS_CSV.exists() else synthetic_returns(seed=42))
rng = np.random.default_rng(seed)
return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=5, rng=rng)
def _project(req: SimulateRequest) -> tuple[SimulationResult, float]:
paths = _load_paths(req.seed, req.n_paths, req.horizon_years)
annual_savings = (np.full(req.horizon_years, float(req.savings_per_year_gbp), dtype=np.float64)
if req.savings_per_year_gbp > 0 else None)
floor = float(req.floor_gbp) if req.floor_gbp is not None else None
started = time.perf_counter()
result = simulate(
paths=paths,
initial_portfolio=float(req.nw_seed_gbp),
spending_target=float(req.spending_gbp),
glide=get_glide(req.glide_path),
strategy=build_strategy(req.strategy, floor=floor),
regime=build_regime_schedule(req.jurisdiction, req.leave_uk_year),
horizon_years=req.horizon_years,
annual_savings=annual_savings,
)
elapsed = time.perf_counter() - started
return result, elapsed
def _to_response(result: SimulationResult, elapsed: float) -> SimulateResult:
# portfolio_real has n_years+1 columns (year 0 = seed, year k = end-of-year k).
# withdrawal_real / tax_real have n_years columns (year k = withdrawn in year k+1).
# Yearly point k describes "end of year k+1": portfolio after withdrawal & growth.
pcts = [10, 25, 50, 75, 90]
portfolio_quantiles = {p: np.percentile(result.portfolio_real, p, axis=0) for p in pcts}
median_wd = np.percentile(result.withdrawal_real, 50, axis=0)
median_tax = np.percentile(result.tax_real, 50, axis=0)
n_years = result.n_years
survival_path = (result.success_mask.astype(np.float64).mean(axis=0) if
result.success_mask.ndim == 2 else np.ones(n_years))
yearly = [
ProjectionPoint(
year_idx=y,
p10_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[10][y + 1]), 2))),
p25_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[25][y + 1]), 2))),
p50_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[50][y + 1]), 2))),
p75_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[75][y + 1]), 2))),
p90_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[90][y + 1]), 2))),
p50_withdrawal_gbp=Decimal(str(round(float(median_wd[y]), 2))),
p50_tax_gbp=Decimal(str(round(float(median_tax[y]), 2))),
survival_rate=Decimal(str(round(float(survival_path[y]), 4))),
) for y in range(n_years)
]
median_ytr = result.median_years_to_ruin()
return SimulateResult(
success_rate=Decimal(str(round(float(result.success_rate), 4))),
p10_ending_gbp=Decimal(str(round(float(result.ending_percentile(10)), 2))),
p50_ending_gbp=Decimal(str(round(float(result.ending_percentile(50)), 2))),
p90_ending_gbp=Decimal(str(round(float(result.ending_percentile(90)), 2))),
median_lifetime_tax_gbp=Decimal(str(round(float(result.median_lifetime_tax()), 2))),
median_years_to_ruin=(Decimal(str(round(float(median_ytr), 2)))
if median_ytr is not None else None),
elapsed_seconds=Decimal(str(round(elapsed, 3))),
yearly=yearly,
)
@router.post("/simulate", response_model=SimulateResult)
async def simulate_one(req: SimulateRequest) -> SimulateResult:
"""Run one scenario synchronously, no DB write. ~1-3s for 5k paths."""
try:
result, elapsed = await asyncio.to_thread(_project, req)
except KeyError as e:
raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None
return _to_response(result, elapsed)
@router.post("/compare", response_model=CompareResult)
async def compare_scenarios(req: CompareRequest) -> CompareResult:
"""Run 2-5 scenarios in parallel, return all results."""
async def one(s: SimulateRequest) -> SimulateResult:
result, elapsed = await asyncio.to_thread(_project, s)
return _to_response(result, elapsed)
try:
results = await asyncio.gather(*(one(s) for s in req.scenarios))
except KeyError as e:
raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None
return CompareResult(results=results)