api: expand FastAPI surface for scenarios, networth, life-events, goals, simulate
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
Adds the read+write endpoints the frontend needs to drive a
ProjectionLab-style UX on top of the existing engine.
- /networth, /networth/history — NW total + per-account from
account_snapshot (frontend chart)
- /scenarios CRUD + projection — list/get/create/patch/delete user
scenarios; cartesian read-only
- /scenarios/{id}/life-events — life event CRUD nested under scenario
- /life-events/{id} — patch + delete by id
- /scenarios/{id}/goals,
/goals/{id} — retirement goal CRUD
- /simulate, /compare — sync, no-DB-write what-if endpoints
Auth: Bearer-token dependency on writes + simulate when API_BEARER_TOKEN
is set; reads always open (lock down via Authentik-fronted ingress in
prod). Existing /recompute keeps its bearer auth.
CORS middleware reads FRONTEND_ORIGINS (comma-separated) for the dev
SPA. Lifespan now provisions the SQLAlchemy engine + session_factory
on app.state and disposes them on shutdown.
40 new tests covering happy paths and validation. 172 tests total.
mypy strict + ruff clean (B008 ignore added — Depends() in defaults
is the canonical FastAPI pattern, not a bug).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
31193faf08
commit
ee6ed1d3c4
15 changed files with 1570 additions and 74 deletions
1
fire_planner/api/__init__.py
Normal file
1
fire_planner/api/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""HTTP API surface — read + write endpoints over the engine + DB."""
|
||||
42
fire_planner/api/auth.py
Normal file
42
fire_planner/api/auth.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Bearer-token auth shared across routers.
|
||||
|
||||
Two modes, picked at startup from env:
|
||||
- API_BEARER_TOKEN set → enforce Bearer auth on all write/compute paths
|
||||
- API_BEARER_TOKEN unset (dev) → no auth, log a one-time warning
|
||||
|
||||
Read endpoints (`/networth`, `/scenarios`, ...) skip auth entirely so
|
||||
the frontend can render without juggling tokens during dev. Lock those
|
||||
down later via Authentik-fronted ingress when we deploy.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import Header, HTTPException
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_warned_unauth = False
|
||||
|
||||
|
||||
def _read_token() -> str | None:
|
||||
return os.environ.get("API_BEARER_TOKEN") or os.environ.get("RECOMPUTE_BEARER_TOKEN")
|
||||
|
||||
|
||||
async def require_bearer(authorization: str | None = Header(default=None)) -> None:
|
||||
"""FastAPI dependency: enforce bearer auth IF API_BEARER_TOKEN is set."""
|
||||
expected = _read_token()
|
||||
if not expected:
|
||||
global _warned_unauth
|
||||
if not _warned_unauth:
|
||||
log.warning("API_BEARER_TOKEN unset — write endpoints are open. "
|
||||
"Set it before exposing this service.")
|
||||
_warned_unauth = True
|
||||
return
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing bearer token")
|
||||
token = authorization.removeprefix("Bearer ")
|
||||
if not hmac.compare_digest(token, expected):
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
18
fire_planner/api/dependencies.py
Normal file
18
fire_planner/api/dependencies.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Shared FastAPI dependencies — DB session per request."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
async def get_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
"""Yield an AsyncSession bound to the engine on app.state.
|
||||
|
||||
The engine + session factory are wired up in `app.lifespan`. Tests
|
||||
swap them out via dependency_overrides.
|
||||
"""
|
||||
factory: async_sessionmaker[AsyncSession] = request.app.state.session_factory
|
||||
async with factory() as session:
|
||||
yield session
|
||||
68
fire_planner/api/goals.py
Normal file
68
fire_planner/api/goals.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Retirement-goal CRUD nested under a scenario."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import GoalCreate, GoalOut
|
||||
from fire_planner.db import RetirementGoal, Scenario
|
||||
|
||||
router = APIRouter(tags=["goals"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scenarios/{scenario_id}/goals",
|
||||
response_model=list[GoalOut],
|
||||
)
|
||||
async def list_goals(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[RetirementGoal]:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
rows = (await session.execute(
|
||||
select(RetirementGoal).where(RetirementGoal.scenario_id == scenario_id).order_by(
|
||||
RetirementGoal.id))).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scenarios/{scenario_id}/goals",
|
||||
response_model=GoalOut,
|
||||
status_code=201,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def create_goal(
|
||||
scenario_id: int,
|
||||
payload: GoalCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> RetirementGoal:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
goal = RetirementGoal(scenario_id=scenario_id, **payload.model_dump())
|
||||
session.add(goal)
|
||||
await session.commit()
|
||||
await session.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/goals/{goal_id}",
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def delete_goal(
|
||||
goal_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
goal = await session.get(RetirementGoal, goal_id)
|
||||
if goal is None:
|
||||
raise HTTPException(status_code=404, detail="Goal not found")
|
||||
await session.execute(delete(RetirementGoal).where(RetirementGoal.id == goal_id))
|
||||
await session.commit()
|
||||
93
fire_planner/api/life_events.py
Normal file
93
fire_planner/api/life_events.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Life-event CRUD nested under a scenario."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import LifeEventCreate, LifeEventOut, LifeEventPatch
|
||||
from fire_planner.db import LifeEvent, Scenario
|
||||
|
||||
router = APIRouter(tags=["life-events"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scenarios/{scenario_id}/life-events",
|
||||
response_model=list[LifeEventOut],
|
||||
)
|
||||
async def list_events(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[LifeEvent]:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
rows = (await session.execute(
|
||||
select(LifeEvent).where(LifeEvent.scenario_id == scenario_id).order_by(
|
||||
LifeEvent.year_start, LifeEvent.id))).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scenarios/{scenario_id}/life-events",
|
||||
response_model=LifeEventOut,
|
||||
status_code=201,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def create_event(
|
||||
scenario_id: int,
|
||||
payload: LifeEventCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> LifeEvent:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
if payload.year_end is not None and payload.year_end < payload.year_start:
|
||||
raise HTTPException(status_code=400, detail="year_end < year_start")
|
||||
ev = LifeEvent(scenario_id=scenario_id, **payload.model_dump())
|
||||
session.add(ev)
|
||||
await session.commit()
|
||||
await session.refresh(ev)
|
||||
return ev
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/life-events/{event_id}",
|
||||
response_model=LifeEventOut,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def patch_event(
|
||||
event_id: int,
|
||||
payload: LifeEventPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> LifeEvent:
|
||||
ev = await session.get(LifeEvent, event_id)
|
||||
if ev is None:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for k, v in updates.items():
|
||||
setattr(ev, k, v)
|
||||
if ev.year_end is not None and ev.year_end < ev.year_start:
|
||||
raise HTTPException(status_code=400, detail="year_end < year_start")
|
||||
await session.commit()
|
||||
await session.refresh(ev)
|
||||
return ev
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/life-events/{event_id}",
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def delete_event(
|
||||
event_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
ev = await session.get(LifeEvent, event_id)
|
||||
if ev is None:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
await session.execute(delete(LifeEvent).where(LifeEvent.id == event_id))
|
||||
await session.commit()
|
||||
78
fire_planner/api/networth.py
Normal file
78
fire_planner/api/networth.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Net-worth read endpoints.
|
||||
|
||||
Reads from `fire_planner.account_snapshot` (populated hourly by the
|
||||
wealthfolio ingest). Two views:
|
||||
- GET /networth → latest snapshot per account, totals
|
||||
- GET /networth/history → daily totals + per-account series, for charts
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import (
|
||||
AccountSnapshotOut,
|
||||
NetWorthCurrent,
|
||||
NetWorthHistory,
|
||||
NetWorthHistoryPoint,
|
||||
)
|
||||
from fire_planner.db import AccountSnapshot
|
||||
|
||||
router = APIRouter(prefix="/networth", tags=["networth"])
|
||||
|
||||
|
||||
@router.get("", response_model=NetWorthCurrent)
|
||||
async def current_networth(session: AsyncSession = Depends(get_session)) -> NetWorthCurrent:
|
||||
"""Latest snapshot per account + GBP total."""
|
||||
latest_date = (await session.execute(
|
||||
select(AccountSnapshot.snapshot_date).order_by(
|
||||
AccountSnapshot.snapshot_date.desc()).limit(1))).scalar()
|
||||
if latest_date is None:
|
||||
return NetWorthCurrent(snapshot_date=date.today(), total_gbp=Decimal("0"), accounts=[])
|
||||
rows = (await session.execute(
|
||||
select(AccountSnapshot).where(
|
||||
AccountSnapshot.snapshot_date == latest_date))).scalars().all()
|
||||
accounts = [AccountSnapshotOut.model_validate(r) for r in rows]
|
||||
total = sum((a.market_value_gbp for a in accounts), Decimal("0"))
|
||||
return NetWorthCurrent(snapshot_date=latest_date, total_gbp=total, accounts=accounts)
|
||||
|
||||
|
||||
@router.get("/history", response_model=NetWorthHistory)
|
||||
async def networth_history(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
days: int = Query(default=365, ge=1, le=3650, description="Look-back window."),
|
||||
) -> NetWorthHistory:
|
||||
"""Daily NW total + per-account breakdown for a stacked area chart.
|
||||
|
||||
Picks one row per (account_id, snapshot_date) — wealthfolio ingest
|
||||
upserts daily so this is already de-duped, but we group defensively.
|
||||
"""
|
||||
rows = (await session.execute(
|
||||
select(
|
||||
AccountSnapshot.snapshot_date,
|
||||
AccountSnapshot.account_name,
|
||||
AccountSnapshot.market_value_gbp,
|
||||
).order_by(AccountSnapshot.snapshot_date))).all()
|
||||
if not rows:
|
||||
return NetWorthHistory(points=[])
|
||||
|
||||
by_date: dict[date, dict[str, Decimal]] = defaultdict(lambda: defaultdict(lambda: Decimal("0")))
|
||||
for snap_date, name, value in rows:
|
||||
by_date[snap_date][name] += Decimal(str(value))
|
||||
|
||||
cutoff_idx = max(0, len(by_date) - days)
|
||||
sorted_dates = sorted(by_date.keys())[cutoff_idx:]
|
||||
points = [
|
||||
NetWorthHistoryPoint(
|
||||
snapshot_date=d,
|
||||
total_gbp=sum(by_date[d].values(), Decimal("0")),
|
||||
by_account=dict(by_date[d]),
|
||||
) for d in sorted_dates
|
||||
]
|
||||
return NetWorthHistory(points=points)
|
||||
172
fire_planner/api/scenarios.py
Normal file
172
fire_planner/api/scenarios.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""Scenario CRUD + projection read.
|
||||
|
||||
Mixed surface:
|
||||
- GET /scenarios list (filter by kind)
|
||||
- GET /scenarios/{id} single
|
||||
- POST /scenarios create user scenario
|
||||
- PATCH /scenarios/{id} update user scenario
|
||||
- DELETE /scenarios/{id} delete user scenario (cartesian protected)
|
||||
- GET /scenarios/{id}/projection latest MC run + per-year fan
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import (
|
||||
ProjectionPoint,
|
||||
ScenarioCreate,
|
||||
ScenarioOut,
|
||||
ScenarioPatch,
|
||||
ScenarioProjection,
|
||||
)
|
||||
from fire_planner.db import McRun, ProjectionYearly, Scenario
|
||||
|
||||
router = APIRouter(prefix="/scenarios", tags=["scenarios"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ScenarioOut])
|
||||
async def list_scenarios(
|
||||
kind: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Scenario]:
|
||||
"""List all scenarios. Filter `kind=user` or `kind=cartesian`."""
|
||||
stmt = select(Scenario).order_by(Scenario.id)
|
||||
if kind is not None:
|
||||
stmt = stmt.where(Scenario.kind == kind)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
|
||||
@router.get("/{scenario_id}", response_model=ScenarioOut)
|
||||
async def get_scenario(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Scenario:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
return scen
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ScenarioOut,
|
||||
status_code=201,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def create_scenario(
|
||||
payload: ScenarioCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Scenario:
|
||||
"""Create a user scenario. Cartesian scenarios come from the engine,
|
||||
not the API."""
|
||||
if payload.parent_scenario_id is not None:
|
||||
parent = await session.get(Scenario, payload.parent_scenario_id)
|
||||
if parent is None:
|
||||
raise HTTPException(status_code=400, detail="parent_scenario_id not found")
|
||||
|
||||
scen = Scenario(
|
||||
external_id=f"user-{uuid.uuid4().hex[:12]}",
|
||||
kind="user",
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
parent_scenario_id=payload.parent_scenario_id,
|
||||
jurisdiction=payload.jurisdiction,
|
||||
strategy=payload.strategy,
|
||||
leave_uk_year=payload.leave_uk_year,
|
||||
glide_path=payload.glide_path,
|
||||
spending_gbp=payload.spending_gbp,
|
||||
horizon_years=payload.horizon_years,
|
||||
nw_seed_gbp=payload.nw_seed_gbp,
|
||||
savings_per_year_gbp=payload.savings_per_year_gbp,
|
||||
config_json=payload.config_json,
|
||||
)
|
||||
session.add(scen)
|
||||
await session.commit()
|
||||
await session.refresh(scen)
|
||||
return scen
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{scenario_id}",
|
||||
response_model=ScenarioOut,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def patch_scenario(
|
||||
scenario_id: int,
|
||||
payload: ScenarioPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Scenario:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
if scen.kind != "user":
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Cannot patch cartesian scenarios — they're auto-generated")
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for k, v in updates.items():
|
||||
setattr(scen, k, v)
|
||||
await session.commit()
|
||||
await session.refresh(scen)
|
||||
return scen
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{scenario_id}",
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def delete_scenario(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
if scen.kind != "user":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete cartesian scenarios — they re-appear on recompute",
|
||||
)
|
||||
await session.execute(delete(Scenario).where(Scenario.id == scenario_id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
@router.get("/{scenario_id}/projection", response_model=ScenarioProjection)
|
||||
async def get_scenario_projection(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> ScenarioProjection:
|
||||
"""Latest MC run for this scenario + the per-year fan series."""
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
run = (await session.execute(
|
||||
select(McRun).where(McRun.scenario_id == scenario_id).order_by(
|
||||
McRun.run_at.desc()).limit(1))).scalar_one_or_none()
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail="No MC runs persisted for this scenario yet")
|
||||
yearly_rows = (await session.execute(
|
||||
select(ProjectionYearly).where(ProjectionYearly.mc_run_id == run.id).order_by(
|
||||
ProjectionYearly.year_idx))).scalars().all()
|
||||
return ScenarioProjection(
|
||||
scenario_id=scen.id,
|
||||
external_id=scen.external_id,
|
||||
mc_run_id=run.id,
|
||||
run_at=run.run_at,
|
||||
n_paths=run.n_paths,
|
||||
success_rate=run.success_rate,
|
||||
p10_ending_gbp=run.p10_ending_gbp,
|
||||
p50_ending_gbp=run.p50_ending_gbp,
|
||||
p90_ending_gbp=run.p90_ending_gbp,
|
||||
median_lifetime_tax_gbp=run.median_lifetime_tax_gbp,
|
||||
median_years_to_ruin=run.median_years_to_ruin,
|
||||
yearly=[ProjectionPoint.model_validate(y) for y in yearly_rows],
|
||||
)
|
||||
237
fire_planner/api/schemas.py
Normal file
237
fire_planner/api/schemas.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""Pydantic response/request schemas for the HTTP API.
|
||||
|
||||
Mirror the SQLAlchemy ORM but keep them de-coupled — the API surface is a
|
||||
contract for the frontend; we don't want migrations to silently change
|
||||
JSON shape.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class _Base(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ── scenarios ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ScenarioOut(_Base):
|
||||
id: int
|
||||
external_id: str
|
||||
kind: str
|
||||
name: str | None
|
||||
description: str | None
|
||||
parent_scenario_id: int | None
|
||||
jurisdiction: str
|
||||
strategy: str
|
||||
leave_uk_year: int
|
||||
glide_path: str
|
||||
spending_gbp: Decimal
|
||||
horizon_years: int
|
||||
nw_seed_gbp: Decimal
|
||||
savings_per_year_gbp: Decimal
|
||||
config_json: dict[str, Any]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ScenarioCreate(BaseModel):
|
||||
"""Body for POST /scenarios — user-defined scenario."""
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
description: str | None = None
|
||||
parent_scenario_id: int | None = None
|
||||
jurisdiction: str
|
||||
strategy: str
|
||||
leave_uk_year: int = Field(ge=0, le=60)
|
||||
glide_path: str
|
||||
spending_gbp: Decimal = Field(gt=0)
|
||||
horizon_years: int = Field(ge=5, le=100, default=60)
|
||||
nw_seed_gbp: Decimal = Field(ge=0)
|
||||
savings_per_year_gbp: Decimal = Field(ge=0, default=Decimal("0"))
|
||||
config_json: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ScenarioPatch(BaseModel):
|
||||
"""Body for PATCH /scenarios/{id} — all fields optional."""
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
jurisdiction: str | None = None
|
||||
strategy: str | None = None
|
||||
leave_uk_year: int | None = None
|
||||
glide_path: str | None = None
|
||||
spending_gbp: Decimal | None = None
|
||||
horizon_years: int | None = None
|
||||
nw_seed_gbp: Decimal | None = None
|
||||
savings_per_year_gbp: Decimal | None = None
|
||||
config_json: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ── projections ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ProjectionPoint(_Base):
|
||||
year_idx: int
|
||||
p10_portfolio_gbp: Decimal
|
||||
p25_portfolio_gbp: Decimal
|
||||
p50_portfolio_gbp: Decimal
|
||||
p75_portfolio_gbp: Decimal
|
||||
p90_portfolio_gbp: Decimal
|
||||
p50_withdrawal_gbp: Decimal
|
||||
p50_tax_gbp: Decimal
|
||||
survival_rate: Decimal
|
||||
|
||||
|
||||
class ScenarioProjection(BaseModel):
|
||||
"""Latest MC run + per-year fan-chart series for a scenario."""
|
||||
scenario_id: int
|
||||
external_id: str
|
||||
mc_run_id: int
|
||||
run_at: datetime
|
||||
n_paths: int
|
||||
success_rate: Decimal
|
||||
p10_ending_gbp: Decimal
|
||||
p50_ending_gbp: Decimal
|
||||
p90_ending_gbp: Decimal
|
||||
median_lifetime_tax_gbp: Decimal
|
||||
median_years_to_ruin: Decimal | None
|
||||
yearly: list[ProjectionPoint]
|
||||
|
||||
|
||||
# ── net worth ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AccountSnapshotOut(_Base):
|
||||
account_id: str
|
||||
account_name: str
|
||||
account_type: str
|
||||
currency: str
|
||||
snapshot_date: date
|
||||
market_value: Decimal
|
||||
market_value_gbp: Decimal
|
||||
cost_basis_gbp: Decimal | None
|
||||
|
||||
|
||||
class NetWorthCurrent(BaseModel):
|
||||
"""Snapshot at one point in time (latest by default)."""
|
||||
snapshot_date: date
|
||||
total_gbp: Decimal
|
||||
accounts: list[AccountSnapshotOut]
|
||||
|
||||
|
||||
class NetWorthHistoryPoint(BaseModel):
|
||||
snapshot_date: date
|
||||
total_gbp: Decimal
|
||||
by_account: dict[str, Decimal]
|
||||
|
||||
|
||||
class NetWorthHistory(BaseModel):
|
||||
"""Per-day NW totals + per-account breakdown for a stacked area chart."""
|
||||
points: list[NetWorthHistoryPoint]
|
||||
|
||||
|
||||
# ── life events ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class LifeEventOut(_Base):
|
||||
id: int
|
||||
scenario_id: int
|
||||
kind: str
|
||||
name: str
|
||||
year_start: int
|
||||
year_end: int | None
|
||||
delta_gbp_per_year: Decimal
|
||||
one_time_amount_gbp: Decimal | None
|
||||
enabled: bool
|
||||
payload: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class LifeEventCreate(BaseModel):
|
||||
kind: str
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
year_start: int = Field(ge=0, le=100)
|
||||
year_end: int | None = Field(default=None, ge=0, le=100)
|
||||
delta_gbp_per_year: Decimal = Decimal("0")
|
||||
one_time_amount_gbp: Decimal | None = None
|
||||
enabled: bool = True
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LifeEventPatch(BaseModel):
|
||||
kind: str | None = None
|
||||
name: str | None = None
|
||||
year_start: int | None = None
|
||||
year_end: int | None = None
|
||||
delta_gbp_per_year: Decimal | None = None
|
||||
one_time_amount_gbp: Decimal | None = None
|
||||
enabled: bool | None = None
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ── goals ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GoalOut(_Base):
|
||||
id: int
|
||||
scenario_id: int
|
||||
kind: str
|
||||
name: str
|
||||
target_amount_gbp: Decimal | None
|
||||
target_year: int | None
|
||||
comparator: str
|
||||
success_threshold: Decimal
|
||||
enabled: bool
|
||||
payload: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class GoalCreate(BaseModel):
|
||||
kind: str
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
target_amount_gbp: Decimal | None = None
|
||||
target_year: int | None = Field(default=None, ge=0, le=100)
|
||||
comparator: str = ">="
|
||||
success_threshold: Decimal = Field(default=Decimal("0.95"), ge=0, le=1)
|
||||
enabled: bool = True
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ── simulate / compare ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class SimulateRequest(BaseModel):
|
||||
"""Sync, non-persisted simulate. Used by the React UI for what-if."""
|
||||
jurisdiction: str
|
||||
strategy: str
|
||||
leave_uk_year: int = Field(ge=0, le=60)
|
||||
glide_path: str = "rising"
|
||||
spending_gbp: Decimal = Field(gt=0)
|
||||
nw_seed_gbp: Decimal = Field(ge=0)
|
||||
savings_per_year_gbp: Decimal = Decimal("0")
|
||||
horizon_years: int = Field(ge=5, le=100, default=60)
|
||||
floor_gbp: Decimal | None = None
|
||||
n_paths: int = Field(ge=100, le=50_000, default=5_000)
|
||||
seed: int = 42
|
||||
|
||||
|
||||
class SimulateResult(BaseModel):
|
||||
success_rate: Decimal
|
||||
p10_ending_gbp: Decimal
|
||||
p50_ending_gbp: Decimal
|
||||
p90_ending_gbp: Decimal
|
||||
median_lifetime_tax_gbp: Decimal
|
||||
median_years_to_ruin: Decimal | None
|
||||
elapsed_seconds: Decimal
|
||||
yearly: list[ProjectionPoint]
|
||||
|
||||
|
||||
class CompareRequest(BaseModel):
|
||||
scenarios: list[SimulateRequest] = Field(min_length=2, max_length=5)
|
||||
|
||||
|
||||
class CompareResult(BaseModel):
|
||||
results: list[SimulateResult]
|
||||
125
fire_planner/api/simulate.py
Normal file
125
fire_planner/api/simulate.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Sync simulate + multi-scenario compare.
|
||||
|
||||
Unlike the persisted Cartesian recompute (`/recompute`), these run a
|
||||
single scenario inline and return the result immediately. The React UI
|
||||
uses these for what-if exploration — no DB write.
|
||||
|
||||
Returns a fan-chart series in the same shape as
|
||||
`GET /scenarios/{id}/projection`, so frontend chart code is shared.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.schemas import (
|
||||
CompareRequest,
|
||||
CompareResult,
|
||||
ProjectionPoint,
|
||||
SimulateRequest,
|
||||
SimulateResult,
|
||||
)
|
||||
from fire_planner.glide_path import get as get_glide
|
||||
from fire_planner.returns.bootstrap import block_bootstrap
|
||||
from fire_planner.returns.shiller import load_from_csv, synthetic_returns
|
||||
from fire_planner.scenarios import build_regime_schedule, build_strategy
|
||||
from fire_planner.simulator import SimulationResult, simulate
|
||||
|
||||
router = APIRouter(tags=["simulate"], dependencies=[Depends(require_bearer)])
|
||||
|
||||
_RETURNS_CSV = Path("/data/shiller_returns.csv")
|
||||
|
||||
|
||||
def _load_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray:
|
||||
bundle = (load_from_csv(_RETURNS_CSV) if _RETURNS_CSV.exists() else synthetic_returns(seed=42))
|
||||
rng = np.random.default_rng(seed)
|
||||
return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=5, rng=rng)
|
||||
|
||||
|
||||
def _project(req: SimulateRequest) -> tuple[SimulationResult, float]:
|
||||
paths = _load_paths(req.seed, req.n_paths, req.horizon_years)
|
||||
annual_savings = (np.full(req.horizon_years, float(req.savings_per_year_gbp), dtype=np.float64)
|
||||
if req.savings_per_year_gbp > 0 else None)
|
||||
floor = float(req.floor_gbp) if req.floor_gbp is not None else None
|
||||
started = time.perf_counter()
|
||||
result = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=float(req.nw_seed_gbp),
|
||||
spending_target=float(req.spending_gbp),
|
||||
glide=get_glide(req.glide_path),
|
||||
strategy=build_strategy(req.strategy, floor=floor),
|
||||
regime=build_regime_schedule(req.jurisdiction, req.leave_uk_year),
|
||||
horizon_years=req.horizon_years,
|
||||
annual_savings=annual_savings,
|
||||
)
|
||||
elapsed = time.perf_counter() - started
|
||||
return result, elapsed
|
||||
|
||||
|
||||
def _to_response(result: SimulationResult, elapsed: float) -> SimulateResult:
|
||||
# portfolio_real has n_years+1 columns (year 0 = seed, year k = end-of-year k).
|
||||
# withdrawal_real / tax_real have n_years columns (year k = withdrawn in year k+1).
|
||||
# Yearly point k describes "end of year k+1": portfolio after withdrawal & growth.
|
||||
pcts = [10, 25, 50, 75, 90]
|
||||
portfolio_quantiles = {p: np.percentile(result.portfolio_real, p, axis=0) for p in pcts}
|
||||
median_wd = np.percentile(result.withdrawal_real, 50, axis=0)
|
||||
median_tax = np.percentile(result.tax_real, 50, axis=0)
|
||||
n_years = result.n_years
|
||||
survival_path = (result.success_mask.astype(np.float64).mean(axis=0) if
|
||||
result.success_mask.ndim == 2 else np.ones(n_years))
|
||||
|
||||
yearly = [
|
||||
ProjectionPoint(
|
||||
year_idx=y,
|
||||
p10_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[10][y + 1]), 2))),
|
||||
p25_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[25][y + 1]), 2))),
|
||||
p50_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[50][y + 1]), 2))),
|
||||
p75_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[75][y + 1]), 2))),
|
||||
p90_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[90][y + 1]), 2))),
|
||||
p50_withdrawal_gbp=Decimal(str(round(float(median_wd[y]), 2))),
|
||||
p50_tax_gbp=Decimal(str(round(float(median_tax[y]), 2))),
|
||||
survival_rate=Decimal(str(round(float(survival_path[y]), 4))),
|
||||
) for y in range(n_years)
|
||||
]
|
||||
median_ytr = result.median_years_to_ruin()
|
||||
return SimulateResult(
|
||||
success_rate=Decimal(str(round(float(result.success_rate), 4))),
|
||||
p10_ending_gbp=Decimal(str(round(float(result.ending_percentile(10)), 2))),
|
||||
p50_ending_gbp=Decimal(str(round(float(result.ending_percentile(50)), 2))),
|
||||
p90_ending_gbp=Decimal(str(round(float(result.ending_percentile(90)), 2))),
|
||||
median_lifetime_tax_gbp=Decimal(str(round(float(result.median_lifetime_tax()), 2))),
|
||||
median_years_to_ruin=(Decimal(str(round(float(median_ytr), 2)))
|
||||
if median_ytr is not None else None),
|
||||
elapsed_seconds=Decimal(str(round(elapsed, 3))),
|
||||
yearly=yearly,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/simulate", response_model=SimulateResult)
|
||||
async def simulate_one(req: SimulateRequest) -> SimulateResult:
|
||||
"""Run one scenario synchronously, no DB write. ~1-3s for 5k paths."""
|
||||
try:
|
||||
result, elapsed = await asyncio.to_thread(_project, req)
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None
|
||||
return _to_response(result, elapsed)
|
||||
|
||||
|
||||
@router.post("/compare", response_model=CompareResult)
|
||||
async def compare_scenarios(req: CompareRequest) -> CompareResult:
|
||||
"""Run 2-5 scenarios in parallel, return all results."""
|
||||
async def one(s: SimulateRequest) -> SimulateResult:
|
||||
result, elapsed = await asyncio.to_thread(_project, s)
|
||||
return _to_response(result, elapsed)
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*(one(s) for s in req.scenarios))
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None
|
||||
return CompareResult(results=results)
|
||||
Loading…
Add table
Add a link
Reference in a new issue