api: expand FastAPI surface for scenarios, networth, life-events, goals, simulate
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
Adds the read+write endpoints the frontend needs to drive a
ProjectionLab-style UX on top of the existing engine.
- /networth, /networth/history — NW total + per-account from
account_snapshot (frontend chart)
- /scenarios CRUD + projection — list/get/create/patch/delete user
scenarios; cartesian read-only
- /scenarios/{id}/life-events — life event CRUD nested under scenario
- /life-events/{id} — patch + delete by id
- /scenarios/{id}/goals,
/goals/{id} — retirement goal CRUD
- /simulate, /compare — sync, no-DB-write what-if endpoints
Auth: Bearer-token dependency on writes + simulate when API_BEARER_TOKEN
is set; reads always open (lock down via Authentik-fronted ingress in
prod). Existing /recompute keeps its bearer auth.
CORS middleware reads FRONTEND_ORIGINS (comma-separated) for the dev
SPA. Lifespan now provisions the SQLAlchemy engine + session_factory
on app.state and disposes them on shutdown.
40 new tests covering happy paths and validation. 172 tests total.
mypy strict + ruff clean (B008 ignore added — Depends() in defaults
is the canonical FastAPI pattern, not a bug).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
31193faf08
commit
ee6ed1d3c4
15 changed files with 1570 additions and 74 deletions
1
fire_planner/api/__init__.py
Normal file
1
fire_planner/api/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""HTTP API surface — read + write endpoints over the engine + DB."""
|
||||
42
fire_planner/api/auth.py
Normal file
42
fire_planner/api/auth.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Bearer-token auth shared across routers.
|
||||
|
||||
Two modes, picked at startup from env:
|
||||
- API_BEARER_TOKEN set → enforce Bearer auth on all write/compute paths
|
||||
- API_BEARER_TOKEN unset (dev) → no auth, log a one-time warning
|
||||
|
||||
Read endpoints (`/networth`, `/scenarios`, ...) skip auth entirely so
|
||||
the frontend can render without juggling tokens during dev. Lock those
|
||||
down later via Authentik-fronted ingress when we deploy.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import Header, HTTPException
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_warned_unauth = False
|
||||
|
||||
|
||||
def _read_token() -> str | None:
|
||||
return os.environ.get("API_BEARER_TOKEN") or os.environ.get("RECOMPUTE_BEARER_TOKEN")
|
||||
|
||||
|
||||
async def require_bearer(authorization: str | None = Header(default=None)) -> None:
|
||||
"""FastAPI dependency: enforce bearer auth IF API_BEARER_TOKEN is set."""
|
||||
expected = _read_token()
|
||||
if not expected:
|
||||
global _warned_unauth
|
||||
if not _warned_unauth:
|
||||
log.warning("API_BEARER_TOKEN unset — write endpoints are open. "
|
||||
"Set it before exposing this service.")
|
||||
_warned_unauth = True
|
||||
return
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing bearer token")
|
||||
token = authorization.removeprefix("Bearer ")
|
||||
if not hmac.compare_digest(token, expected):
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
18
fire_planner/api/dependencies.py
Normal file
18
fire_planner/api/dependencies.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Shared FastAPI dependencies — DB session per request."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
async def get_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
"""Yield an AsyncSession bound to the engine on app.state.
|
||||
|
||||
The engine + session factory are wired up in `app.lifespan`. Tests
|
||||
swap them out via dependency_overrides.
|
||||
"""
|
||||
factory: async_sessionmaker[AsyncSession] = request.app.state.session_factory
|
||||
async with factory() as session:
|
||||
yield session
|
||||
68
fire_planner/api/goals.py
Normal file
68
fire_planner/api/goals.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Retirement-goal CRUD nested under a scenario."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import GoalCreate, GoalOut
|
||||
from fire_planner.db import RetirementGoal, Scenario
|
||||
|
||||
router = APIRouter(tags=["goals"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scenarios/{scenario_id}/goals",
|
||||
response_model=list[GoalOut],
|
||||
)
|
||||
async def list_goals(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[RetirementGoal]:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
rows = (await session.execute(
|
||||
select(RetirementGoal).where(RetirementGoal.scenario_id == scenario_id).order_by(
|
||||
RetirementGoal.id))).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scenarios/{scenario_id}/goals",
|
||||
response_model=GoalOut,
|
||||
status_code=201,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def create_goal(
|
||||
scenario_id: int,
|
||||
payload: GoalCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> RetirementGoal:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
goal = RetirementGoal(scenario_id=scenario_id, **payload.model_dump())
|
||||
session.add(goal)
|
||||
await session.commit()
|
||||
await session.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/goals/{goal_id}",
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def delete_goal(
|
||||
goal_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
goal = await session.get(RetirementGoal, goal_id)
|
||||
if goal is None:
|
||||
raise HTTPException(status_code=404, detail="Goal not found")
|
||||
await session.execute(delete(RetirementGoal).where(RetirementGoal.id == goal_id))
|
||||
await session.commit()
|
||||
93
fire_planner/api/life_events.py
Normal file
93
fire_planner/api/life_events.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Life-event CRUD nested under a scenario."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import LifeEventCreate, LifeEventOut, LifeEventPatch
|
||||
from fire_planner.db import LifeEvent, Scenario
|
||||
|
||||
router = APIRouter(tags=["life-events"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scenarios/{scenario_id}/life-events",
|
||||
response_model=list[LifeEventOut],
|
||||
)
|
||||
async def list_events(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[LifeEvent]:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
rows = (await session.execute(
|
||||
select(LifeEvent).where(LifeEvent.scenario_id == scenario_id).order_by(
|
||||
LifeEvent.year_start, LifeEvent.id))).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scenarios/{scenario_id}/life-events",
|
||||
response_model=LifeEventOut,
|
||||
status_code=201,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def create_event(
|
||||
scenario_id: int,
|
||||
payload: LifeEventCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> LifeEvent:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
if payload.year_end is not None and payload.year_end < payload.year_start:
|
||||
raise HTTPException(status_code=400, detail="year_end < year_start")
|
||||
ev = LifeEvent(scenario_id=scenario_id, **payload.model_dump())
|
||||
session.add(ev)
|
||||
await session.commit()
|
||||
await session.refresh(ev)
|
||||
return ev
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/life-events/{event_id}",
|
||||
response_model=LifeEventOut,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def patch_event(
|
||||
event_id: int,
|
||||
payload: LifeEventPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> LifeEvent:
|
||||
ev = await session.get(LifeEvent, event_id)
|
||||
if ev is None:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for k, v in updates.items():
|
||||
setattr(ev, k, v)
|
||||
if ev.year_end is not None and ev.year_end < ev.year_start:
|
||||
raise HTTPException(status_code=400, detail="year_end < year_start")
|
||||
await session.commit()
|
||||
await session.refresh(ev)
|
||||
return ev
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/life-events/{event_id}",
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def delete_event(
|
||||
event_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
ev = await session.get(LifeEvent, event_id)
|
||||
if ev is None:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
await session.execute(delete(LifeEvent).where(LifeEvent.id == event_id))
|
||||
await session.commit()
|
||||
78
fire_planner/api/networth.py
Normal file
78
fire_planner/api/networth.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Net-worth read endpoints.
|
||||
|
||||
Reads from `fire_planner.account_snapshot` (populated hourly by the
|
||||
wealthfolio ingest). Two views:
|
||||
- GET /networth → latest snapshot per account, totals
|
||||
- GET /networth/history → daily totals + per-account series, for charts
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import (
|
||||
AccountSnapshotOut,
|
||||
NetWorthCurrent,
|
||||
NetWorthHistory,
|
||||
NetWorthHistoryPoint,
|
||||
)
|
||||
from fire_planner.db import AccountSnapshot
|
||||
|
||||
router = APIRouter(prefix="/networth", tags=["networth"])
|
||||
|
||||
|
||||
@router.get("", response_model=NetWorthCurrent)
|
||||
async def current_networth(session: AsyncSession = Depends(get_session)) -> NetWorthCurrent:
|
||||
"""Latest snapshot per account + GBP total."""
|
||||
latest_date = (await session.execute(
|
||||
select(AccountSnapshot.snapshot_date).order_by(
|
||||
AccountSnapshot.snapshot_date.desc()).limit(1))).scalar()
|
||||
if latest_date is None:
|
||||
return NetWorthCurrent(snapshot_date=date.today(), total_gbp=Decimal("0"), accounts=[])
|
||||
rows = (await session.execute(
|
||||
select(AccountSnapshot).where(
|
||||
AccountSnapshot.snapshot_date == latest_date))).scalars().all()
|
||||
accounts = [AccountSnapshotOut.model_validate(r) for r in rows]
|
||||
total = sum((a.market_value_gbp for a in accounts), Decimal("0"))
|
||||
return NetWorthCurrent(snapshot_date=latest_date, total_gbp=total, accounts=accounts)
|
||||
|
||||
|
||||
@router.get("/history", response_model=NetWorthHistory)
|
||||
async def networth_history(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
days: int = Query(default=365, ge=1, le=3650, description="Look-back window."),
|
||||
) -> NetWorthHistory:
|
||||
"""Daily NW total + per-account breakdown for a stacked area chart.
|
||||
|
||||
Picks one row per (account_id, snapshot_date) — wealthfolio ingest
|
||||
upserts daily so this is already de-duped, but we group defensively.
|
||||
"""
|
||||
rows = (await session.execute(
|
||||
select(
|
||||
AccountSnapshot.snapshot_date,
|
||||
AccountSnapshot.account_name,
|
||||
AccountSnapshot.market_value_gbp,
|
||||
).order_by(AccountSnapshot.snapshot_date))).all()
|
||||
if not rows:
|
||||
return NetWorthHistory(points=[])
|
||||
|
||||
by_date: dict[date, dict[str, Decimal]] = defaultdict(lambda: defaultdict(lambda: Decimal("0")))
|
||||
for snap_date, name, value in rows:
|
||||
by_date[snap_date][name] += Decimal(str(value))
|
||||
|
||||
cutoff_idx = max(0, len(by_date) - days)
|
||||
sorted_dates = sorted(by_date.keys())[cutoff_idx:]
|
||||
points = [
|
||||
NetWorthHistoryPoint(
|
||||
snapshot_date=d,
|
||||
total_gbp=sum(by_date[d].values(), Decimal("0")),
|
||||
by_account=dict(by_date[d]),
|
||||
) for d in sorted_dates
|
||||
]
|
||||
return NetWorthHistory(points=points)
|
||||
172
fire_planner/api/scenarios.py
Normal file
172
fire_planner/api/scenarios.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""Scenario CRUD + projection read.
|
||||
|
||||
Mixed surface:
|
||||
- GET /scenarios list (filter by kind)
|
||||
- GET /scenarios/{id} single
|
||||
- POST /scenarios create user scenario
|
||||
- PATCH /scenarios/{id} update user scenario
|
||||
- DELETE /scenarios/{id} delete user scenario (cartesian protected)
|
||||
- GET /scenarios/{id}/projection latest MC run + per-year fan
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.api.schemas import (
|
||||
ProjectionPoint,
|
||||
ScenarioCreate,
|
||||
ScenarioOut,
|
||||
ScenarioPatch,
|
||||
ScenarioProjection,
|
||||
)
|
||||
from fire_planner.db import McRun, ProjectionYearly, Scenario
|
||||
|
||||
router = APIRouter(prefix="/scenarios", tags=["scenarios"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ScenarioOut])
|
||||
async def list_scenarios(
|
||||
kind: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Scenario]:
|
||||
"""List all scenarios. Filter `kind=user` or `kind=cartesian`."""
|
||||
stmt = select(Scenario).order_by(Scenario.id)
|
||||
if kind is not None:
|
||||
stmt = stmt.where(Scenario.kind == kind)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return list(rows)
|
||||
|
||||
|
||||
@router.get("/{scenario_id}", response_model=ScenarioOut)
|
||||
async def get_scenario(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Scenario:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
return scen
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ScenarioOut,
|
||||
status_code=201,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def create_scenario(
|
||||
payload: ScenarioCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Scenario:
|
||||
"""Create a user scenario. Cartesian scenarios come from the engine,
|
||||
not the API."""
|
||||
if payload.parent_scenario_id is not None:
|
||||
parent = await session.get(Scenario, payload.parent_scenario_id)
|
||||
if parent is None:
|
||||
raise HTTPException(status_code=400, detail="parent_scenario_id not found")
|
||||
|
||||
scen = Scenario(
|
||||
external_id=f"user-{uuid.uuid4().hex[:12]}",
|
||||
kind="user",
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
parent_scenario_id=payload.parent_scenario_id,
|
||||
jurisdiction=payload.jurisdiction,
|
||||
strategy=payload.strategy,
|
||||
leave_uk_year=payload.leave_uk_year,
|
||||
glide_path=payload.glide_path,
|
||||
spending_gbp=payload.spending_gbp,
|
||||
horizon_years=payload.horizon_years,
|
||||
nw_seed_gbp=payload.nw_seed_gbp,
|
||||
savings_per_year_gbp=payload.savings_per_year_gbp,
|
||||
config_json=payload.config_json,
|
||||
)
|
||||
session.add(scen)
|
||||
await session.commit()
|
||||
await session.refresh(scen)
|
||||
return scen
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{scenario_id}",
|
||||
response_model=ScenarioOut,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def patch_scenario(
|
||||
scenario_id: int,
|
||||
payload: ScenarioPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Scenario:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
if scen.kind != "user":
|
||||
raise HTTPException(status_code=400,
|
||||
detail="Cannot patch cartesian scenarios — they're auto-generated")
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for k, v in updates.items():
|
||||
setattr(scen, k, v)
|
||||
await session.commit()
|
||||
await session.refresh(scen)
|
||||
return scen
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{scenario_id}",
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def delete_scenario(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
if scen.kind != "user":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete cartesian scenarios — they re-appear on recompute",
|
||||
)
|
||||
await session.execute(delete(Scenario).where(Scenario.id == scenario_id))
|
||||
await session.commit()
|
||||
|
||||
|
||||
@router.get("/{scenario_id}/projection", response_model=ScenarioProjection)
|
||||
async def get_scenario_projection(
|
||||
scenario_id: int,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> ScenarioProjection:
|
||||
"""Latest MC run for this scenario + the per-year fan series."""
|
||||
scen = await session.get(Scenario, scenario_id)
|
||||
if scen is None:
|
||||
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||
run = (await session.execute(
|
||||
select(McRun).where(McRun.scenario_id == scenario_id).order_by(
|
||||
McRun.run_at.desc()).limit(1))).scalar_one_or_none()
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail="No MC runs persisted for this scenario yet")
|
||||
yearly_rows = (await session.execute(
|
||||
select(ProjectionYearly).where(ProjectionYearly.mc_run_id == run.id).order_by(
|
||||
ProjectionYearly.year_idx))).scalars().all()
|
||||
return ScenarioProjection(
|
||||
scenario_id=scen.id,
|
||||
external_id=scen.external_id,
|
||||
mc_run_id=run.id,
|
||||
run_at=run.run_at,
|
||||
n_paths=run.n_paths,
|
||||
success_rate=run.success_rate,
|
||||
p10_ending_gbp=run.p10_ending_gbp,
|
||||
p50_ending_gbp=run.p50_ending_gbp,
|
||||
p90_ending_gbp=run.p90_ending_gbp,
|
||||
median_lifetime_tax_gbp=run.median_lifetime_tax_gbp,
|
||||
median_years_to_ruin=run.median_years_to_ruin,
|
||||
yearly=[ProjectionPoint.model_validate(y) for y in yearly_rows],
|
||||
)
|
||||
237
fire_planner/api/schemas.py
Normal file
237
fire_planner/api/schemas.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""Pydantic response/request schemas for the HTTP API.
|
||||
|
||||
Mirror the SQLAlchemy ORM but keep them de-coupled — the API surface is a
|
||||
contract for the frontend; we don't want migrations to silently change
|
||||
JSON shape.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class _Base(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ── scenarios ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ScenarioOut(_Base):
|
||||
id: int
|
||||
external_id: str
|
||||
kind: str
|
||||
name: str | None
|
||||
description: str | None
|
||||
parent_scenario_id: int | None
|
||||
jurisdiction: str
|
||||
strategy: str
|
||||
leave_uk_year: int
|
||||
glide_path: str
|
||||
spending_gbp: Decimal
|
||||
horizon_years: int
|
||||
nw_seed_gbp: Decimal
|
||||
savings_per_year_gbp: Decimal
|
||||
config_json: dict[str, Any]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ScenarioCreate(BaseModel):
|
||||
"""Body for POST /scenarios — user-defined scenario."""
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
description: str | None = None
|
||||
parent_scenario_id: int | None = None
|
||||
jurisdiction: str
|
||||
strategy: str
|
||||
leave_uk_year: int = Field(ge=0, le=60)
|
||||
glide_path: str
|
||||
spending_gbp: Decimal = Field(gt=0)
|
||||
horizon_years: int = Field(ge=5, le=100, default=60)
|
||||
nw_seed_gbp: Decimal = Field(ge=0)
|
||||
savings_per_year_gbp: Decimal = Field(ge=0, default=Decimal("0"))
|
||||
config_json: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ScenarioPatch(BaseModel):
|
||||
"""Body for PATCH /scenarios/{id} — all fields optional."""
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
jurisdiction: str | None = None
|
||||
strategy: str | None = None
|
||||
leave_uk_year: int | None = None
|
||||
glide_path: str | None = None
|
||||
spending_gbp: Decimal | None = None
|
||||
horizon_years: int | None = None
|
||||
nw_seed_gbp: Decimal | None = None
|
||||
savings_per_year_gbp: Decimal | None = None
|
||||
config_json: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ── projections ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ProjectionPoint(_Base):
|
||||
year_idx: int
|
||||
p10_portfolio_gbp: Decimal
|
||||
p25_portfolio_gbp: Decimal
|
||||
p50_portfolio_gbp: Decimal
|
||||
p75_portfolio_gbp: Decimal
|
||||
p90_portfolio_gbp: Decimal
|
||||
p50_withdrawal_gbp: Decimal
|
||||
p50_tax_gbp: Decimal
|
||||
survival_rate: Decimal
|
||||
|
||||
|
||||
class ScenarioProjection(BaseModel):
|
||||
"""Latest MC run + per-year fan-chart series for a scenario."""
|
||||
scenario_id: int
|
||||
external_id: str
|
||||
mc_run_id: int
|
||||
run_at: datetime
|
||||
n_paths: int
|
||||
success_rate: Decimal
|
||||
p10_ending_gbp: Decimal
|
||||
p50_ending_gbp: Decimal
|
||||
p90_ending_gbp: Decimal
|
||||
median_lifetime_tax_gbp: Decimal
|
||||
median_years_to_ruin: Decimal | None
|
||||
yearly: list[ProjectionPoint]
|
||||
|
||||
|
||||
# ── net worth ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AccountSnapshotOut(_Base):
|
||||
account_id: str
|
||||
account_name: str
|
||||
account_type: str
|
||||
currency: str
|
||||
snapshot_date: date
|
||||
market_value: Decimal
|
||||
market_value_gbp: Decimal
|
||||
cost_basis_gbp: Decimal | None
|
||||
|
||||
|
||||
class NetWorthCurrent(BaseModel):
|
||||
"""Snapshot at one point in time (latest by default)."""
|
||||
snapshot_date: date
|
||||
total_gbp: Decimal
|
||||
accounts: list[AccountSnapshotOut]
|
||||
|
||||
|
||||
class NetWorthHistoryPoint(BaseModel):
|
||||
snapshot_date: date
|
||||
total_gbp: Decimal
|
||||
by_account: dict[str, Decimal]
|
||||
|
||||
|
||||
class NetWorthHistory(BaseModel):
|
||||
"""Per-day NW totals + per-account breakdown for a stacked area chart."""
|
||||
points: list[NetWorthHistoryPoint]
|
||||
|
||||
|
||||
# ── life events ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class LifeEventOut(_Base):
|
||||
id: int
|
||||
scenario_id: int
|
||||
kind: str
|
||||
name: str
|
||||
year_start: int
|
||||
year_end: int | None
|
||||
delta_gbp_per_year: Decimal
|
||||
one_time_amount_gbp: Decimal | None
|
||||
enabled: bool
|
||||
payload: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class LifeEventCreate(BaseModel):
|
||||
kind: str
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
year_start: int = Field(ge=0, le=100)
|
||||
year_end: int | None = Field(default=None, ge=0, le=100)
|
||||
delta_gbp_per_year: Decimal = Decimal("0")
|
||||
one_time_amount_gbp: Decimal | None = None
|
||||
enabled: bool = True
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LifeEventPatch(BaseModel):
|
||||
kind: str | None = None
|
||||
name: str | None = None
|
||||
year_start: int | None = None
|
||||
year_end: int | None = None
|
||||
delta_gbp_per_year: Decimal | None = None
|
||||
one_time_amount_gbp: Decimal | None = None
|
||||
enabled: bool | None = None
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ── goals ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GoalOut(_Base):
|
||||
id: int
|
||||
scenario_id: int
|
||||
kind: str
|
||||
name: str
|
||||
target_amount_gbp: Decimal | None
|
||||
target_year: int | None
|
||||
comparator: str
|
||||
success_threshold: Decimal
|
||||
enabled: bool
|
||||
payload: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class GoalCreate(BaseModel):
|
||||
kind: str
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
target_amount_gbp: Decimal | None = None
|
||||
target_year: int | None = Field(default=None, ge=0, le=100)
|
||||
comparator: str = ">="
|
||||
success_threshold: Decimal = Field(default=Decimal("0.95"), ge=0, le=1)
|
||||
enabled: bool = True
|
||||
payload: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ── simulate / compare ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class SimulateRequest(BaseModel):
|
||||
"""Sync, non-persisted simulate. Used by the React UI for what-if."""
|
||||
jurisdiction: str
|
||||
strategy: str
|
||||
leave_uk_year: int = Field(ge=0, le=60)
|
||||
glide_path: str = "rising"
|
||||
spending_gbp: Decimal = Field(gt=0)
|
||||
nw_seed_gbp: Decimal = Field(ge=0)
|
||||
savings_per_year_gbp: Decimal = Decimal("0")
|
||||
horizon_years: int = Field(ge=5, le=100, default=60)
|
||||
floor_gbp: Decimal | None = None
|
||||
n_paths: int = Field(ge=100, le=50_000, default=5_000)
|
||||
seed: int = 42
|
||||
|
||||
|
||||
class SimulateResult(BaseModel):
|
||||
success_rate: Decimal
|
||||
p10_ending_gbp: Decimal
|
||||
p50_ending_gbp: Decimal
|
||||
p90_ending_gbp: Decimal
|
||||
median_lifetime_tax_gbp: Decimal
|
||||
median_years_to_ruin: Decimal | None
|
||||
elapsed_seconds: Decimal
|
||||
yearly: list[ProjectionPoint]
|
||||
|
||||
|
||||
class CompareRequest(BaseModel):
|
||||
scenarios: list[SimulateRequest] = Field(min_length=2, max_length=5)
|
||||
|
||||
|
||||
class CompareResult(BaseModel):
|
||||
results: list[SimulateResult]
|
||||
125
fire_planner/api/simulate.py
Normal file
125
fire_planner/api/simulate.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Sync simulate + multi-scenario compare.
|
||||
|
||||
Unlike the persisted Cartesian recompute (`/recompute`), these run a
|
||||
single scenario inline and return the result immediately. The React UI
|
||||
uses these for what-if exploration — no DB write.
|
||||
|
||||
Returns a fan-chart series in the same shape as
|
||||
`GET /scenarios/{id}/projection`, so frontend chart code is shared.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.schemas import (
|
||||
CompareRequest,
|
||||
CompareResult,
|
||||
ProjectionPoint,
|
||||
SimulateRequest,
|
||||
SimulateResult,
|
||||
)
|
||||
from fire_planner.glide_path import get as get_glide
|
||||
from fire_planner.returns.bootstrap import block_bootstrap
|
||||
from fire_planner.returns.shiller import load_from_csv, synthetic_returns
|
||||
from fire_planner.scenarios import build_regime_schedule, build_strategy
|
||||
from fire_planner.simulator import SimulationResult, simulate
|
||||
|
||||
router = APIRouter(tags=["simulate"], dependencies=[Depends(require_bearer)])
|
||||
|
||||
_RETURNS_CSV = Path("/data/shiller_returns.csv")
|
||||
|
||||
|
||||
def _load_paths(seed: int, n_paths: int, n_years: int) -> np.ndarray:
|
||||
bundle = (load_from_csv(_RETURNS_CSV) if _RETURNS_CSV.exists() else synthetic_returns(seed=42))
|
||||
rng = np.random.default_rng(seed)
|
||||
return block_bootstrap(bundle, n_paths=n_paths, n_years=n_years, block_size=5, rng=rng)
|
||||
|
||||
|
||||
def _project(req: SimulateRequest) -> tuple[SimulationResult, float]:
|
||||
paths = _load_paths(req.seed, req.n_paths, req.horizon_years)
|
||||
annual_savings = (np.full(req.horizon_years, float(req.savings_per_year_gbp), dtype=np.float64)
|
||||
if req.savings_per_year_gbp > 0 else None)
|
||||
floor = float(req.floor_gbp) if req.floor_gbp is not None else None
|
||||
started = time.perf_counter()
|
||||
result = simulate(
|
||||
paths=paths,
|
||||
initial_portfolio=float(req.nw_seed_gbp),
|
||||
spending_target=float(req.spending_gbp),
|
||||
glide=get_glide(req.glide_path),
|
||||
strategy=build_strategy(req.strategy, floor=floor),
|
||||
regime=build_regime_schedule(req.jurisdiction, req.leave_uk_year),
|
||||
horizon_years=req.horizon_years,
|
||||
annual_savings=annual_savings,
|
||||
)
|
||||
elapsed = time.perf_counter() - started
|
||||
return result, elapsed
|
||||
|
||||
|
||||
def _to_response(result: SimulationResult, elapsed: float) -> SimulateResult:
|
||||
# portfolio_real has n_years+1 columns (year 0 = seed, year k = end-of-year k).
|
||||
# withdrawal_real / tax_real have n_years columns (year k = withdrawn in year k+1).
|
||||
# Yearly point k describes "end of year k+1": portfolio after withdrawal & growth.
|
||||
pcts = [10, 25, 50, 75, 90]
|
||||
portfolio_quantiles = {p: np.percentile(result.portfolio_real, p, axis=0) for p in pcts}
|
||||
median_wd = np.percentile(result.withdrawal_real, 50, axis=0)
|
||||
median_tax = np.percentile(result.tax_real, 50, axis=0)
|
||||
n_years = result.n_years
|
||||
survival_path = (result.success_mask.astype(np.float64).mean(axis=0) if
|
||||
result.success_mask.ndim == 2 else np.ones(n_years))
|
||||
|
||||
yearly = [
|
||||
ProjectionPoint(
|
||||
year_idx=y,
|
||||
p10_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[10][y + 1]), 2))),
|
||||
p25_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[25][y + 1]), 2))),
|
||||
p50_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[50][y + 1]), 2))),
|
||||
p75_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[75][y + 1]), 2))),
|
||||
p90_portfolio_gbp=Decimal(str(round(float(portfolio_quantiles[90][y + 1]), 2))),
|
||||
p50_withdrawal_gbp=Decimal(str(round(float(median_wd[y]), 2))),
|
||||
p50_tax_gbp=Decimal(str(round(float(median_tax[y]), 2))),
|
||||
survival_rate=Decimal(str(round(float(survival_path[y]), 4))),
|
||||
) for y in range(n_years)
|
||||
]
|
||||
median_ytr = result.median_years_to_ruin()
|
||||
return SimulateResult(
|
||||
success_rate=Decimal(str(round(float(result.success_rate), 4))),
|
||||
p10_ending_gbp=Decimal(str(round(float(result.ending_percentile(10)), 2))),
|
||||
p50_ending_gbp=Decimal(str(round(float(result.ending_percentile(50)), 2))),
|
||||
p90_ending_gbp=Decimal(str(round(float(result.ending_percentile(90)), 2))),
|
||||
median_lifetime_tax_gbp=Decimal(str(round(float(result.median_lifetime_tax()), 2))),
|
||||
median_years_to_ruin=(Decimal(str(round(float(median_ytr), 2)))
|
||||
if median_ytr is not None else None),
|
||||
elapsed_seconds=Decimal(str(round(elapsed, 3))),
|
||||
yearly=yearly,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/simulate", response_model=SimulateResult)
|
||||
async def simulate_one(req: SimulateRequest) -> SimulateResult:
|
||||
"""Run one scenario synchronously, no DB write. ~1-3s for 5k paths."""
|
||||
try:
|
||||
result, elapsed = await asyncio.to_thread(_project, req)
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None
|
||||
return _to_response(result, elapsed)
|
||||
|
||||
|
||||
@router.post("/compare", response_model=CompareResult)
|
||||
async def compare_scenarios(req: CompareRequest) -> CompareResult:
|
||||
"""Run 2-5 scenarios in parallel, return all results."""
|
||||
async def one(s: SimulateRequest) -> SimulateResult:
|
||||
result, elapsed = await asyncio.to_thread(_project, s)
|
||||
return _to_response(result, elapsed)
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*(one(s) for s in req.scenarios))
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown name: {e}") from None
|
||||
return CompareResult(results=results)
|
||||
|
|
@ -1,67 +1,131 @@
|
|||
"""FastAPI on-demand /recompute endpoint.
|
||||
"""FastAPI application — wires routers + middleware + lifespan.
|
||||
|
||||
Single deployment. Bearer-token auth (matches payslip-ingest pattern).
|
||||
The endpoint kicks the full 120-scenario Cartesian recompute against
|
||||
whatever the latest Wealthfolio snapshot is in `account_snapshot`.
|
||||
Routers:
|
||||
- /healthz, /metrics, /recompute — operational
|
||||
- /networth, /networth/history — read NW from account_snapshot
|
||||
- /scenarios/... — scenario CRUD + projection
|
||||
- /scenarios/{id}/life-events,
|
||||
/life-events/{id} — life event CRUD
|
||||
- /scenarios/{id}/goals,
|
||||
/goals/{id} — retirement goal CRUD
|
||||
- /simulate, /compare — sync simulate (no DB write)
|
||||
|
||||
For dev / smoke tests, a `/healthz` endpoint reports queue depth.
|
||||
Auth: write/compute paths take Bearer auth via the `require_bearer`
|
||||
dependency when `API_BEARER_TOKEN` is set. Read paths skip auth so the
|
||||
local frontend can hit them without juggling tokens — production
|
||||
deploys lock those down via Authentik-fronted ingress.
|
||||
|
||||
CORS: enabled for the frontend dev server. Comma-separated origins
|
||||
in `FRONTEND_ORIGINS` (defaults to a typical Vite localhost).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, status
|
||||
from fastapi import Depends, FastAPI, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
from fire_planner.api.auth import require_bearer
|
||||
from fire_planner.api.goals import router as goals_router
|
||||
from fire_planner.api.life_events import router as life_events_router
|
||||
from fire_planner.api.networth import router as networth_router
|
||||
from fire_planner.api.scenarios import router as scenarios_router
|
||||
from fire_planner.api.simulate import router as simulate_router
|
||||
from fire_planner.db import create_engine_from_env, make_session_factory
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
REQUIRED_ENV = ["DB_CONNECTION_STRING", "RECOMPUTE_BEARER_TOKEN"]
|
||||
|
||||
|
||||
def _verify_env() -> None:
|
||||
missing = [k for k in REQUIRED_ENV if not os.environ.get(k)]
|
||||
if missing:
|
||||
raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")
|
||||
|
||||
|
||||
def _verify_bearer(authorization: str | None, expected: str) -> None:
|
||||
if not expected:
|
||||
raise HTTPException(status_code=401, detail="Service unauthenticated")
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing bearer token")
|
||||
token = authorization.removeprefix("Bearer ")
|
||||
if not hmac.compare_digest(token, expected):
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
def _frontend_origins() -> list[str]:
|
||||
raw = os.environ.get(
|
||||
"FRONTEND_ORIGINS",
|
||||
"http://localhost:5173,http://localhost:4173,http://127.0.0.1:5173",
|
||||
)
|
||||
return [s.strip() for s in raw.split(",") if s.strip()]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
_verify_env()
|
||||
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
app.state.queue = queue
|
||||
yield
|
||||
if os.environ.get("DB_CONNECTION_STRING"):
|
||||
engine = create_engine_from_env()
|
||||
app.state.engine = engine
|
||||
app.state.session_factory = make_session_factory(engine)
|
||||
else:
|
||||
# Tests inject these via dependency_overrides; nothing to wire.
|
||||
log.warning("DB_CONNECTION_STRING unset; skipping engine init")
|
||||
|
||||
worker = asyncio.create_task(_drain_queue(app))
|
||||
app.state._worker = worker
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await worker
|
||||
eng = getattr(app.state, "engine", None)
|
||||
if eng is not None:
|
||||
await eng.dispose()
|
||||
|
||||
|
||||
async def _drain_queue(app: FastAPI) -> None:
|
||||
"""Background task draining the recompute queue. Each item kicks
|
||||
a full Cartesian recompute. Errors logged, don't crash."""
|
||||
queue: asyncio.Queue[dict[str, Any]] = app.state.queue
|
||||
while True:
|
||||
item = await queue.get()
|
||||
try:
|
||||
from fire_planner.__main__ import _recompute_all
|
||||
await _recompute_all(
|
||||
n_paths=int(item.get("n_paths", 10_000)),
|
||||
horizon=int(item.get("horizon", 60)),
|
||||
spending=float(item.get("spending", 100_000.0)),
|
||||
nw_seed=float(item.get("nw_seed", 1_000_000.0)),
|
||||
savings=float(item.get("savings", 0.0)),
|
||||
floor=(float(item["floor"]) if item.get("floor") is not None else None),
|
||||
returns_csv=item.get("returns_csv"),
|
||||
seed=int(item.get("seed", 42)),
|
||||
)
|
||||
except Exception:
|
||||
log.exception("recompute failed")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
|
||||
app = FastAPI(title="fire-planner", lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_frontend_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
Instrumentator().instrument(app).expose(app, endpoint="/metrics")
|
||||
|
||||
app.include_router(networth_router)
|
||||
app.include_router(scenarios_router)
|
||||
app.include_router(life_events_router)
|
||||
app.include_router(goals_router)
|
||||
app.include_router(simulate_router)
|
||||
|
||||
@app.post("/recompute", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def recompute(
|
||||
payload: dict[str, Any] | None = None,
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> dict[str, Any]:
|
||||
_verify_bearer(authorization, os.environ.get("RECOMPUTE_BEARER_TOKEN", ""))
|
||||
|
||||
@app.post(
|
||||
"/recompute",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
dependencies=[Depends(require_bearer)],
|
||||
)
|
||||
async def recompute(payload: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Queue a full Cartesian recompute (async, persisted). Returns 202."""
|
||||
queue: asyncio.Queue[dict[str, Any]] = app.state.queue
|
||||
body = payload or {}
|
||||
await queue.put(body)
|
||||
await queue.put(payload or {})
|
||||
return {"status": "accepted", "depth": queue.qsize()}
|
||||
|
||||
|
||||
|
|
@ -70,43 +134,3 @@ async def healthz() -> dict[str, Any]:
|
|||
queue = getattr(app.state, "queue", None)
|
||||
depth = queue.qsize() if queue is not None else 0
|
||||
return {"status": "ok", "queue_depth": depth}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _drain_loop() -> None:
|
||||
"""Background task to drain the recompute queue. Each item kicks
|
||||
a full Cartesian recompute. Errors get logged but don't crash."""
|
||||
queue: asyncio.Queue[dict[str, Any]] = app.state.queue
|
||||
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
item = await queue.get()
|
||||
try:
|
||||
# Avoid heavy import unless we actually have work.
|
||||
from fire_planner.__main__ import _recompute_all
|
||||
await _recompute_all(
|
||||
n_paths=int(item.get("n_paths", 10_000)),
|
||||
horizon=int(item.get("horizon", 60)),
|
||||
spending=float(item.get("spending", 100_000.0)),
|
||||
nw_seed=float(item.get("nw_seed", 1_000_000.0)),
|
||||
savings=float(item.get("savings", 0.0)),
|
||||
floor=(float(item["floor"]) if item.get("floor") is not None else None),
|
||||
returns_csv=item.get("returns_csv"),
|
||||
seed=int(item.get("seed", 42)),
|
||||
)
|
||||
except Exception:
|
||||
log.exception("recompute failed")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
task = asyncio.create_task(worker())
|
||||
app.state._worker = task
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def _stop_worker() -> None:
|
||||
task = getattr(app.state, "_worker", None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
|
|
|||
|
|
@ -55,7 +55,9 @@ target-version = "py312"
|
|||
select = ["E", "F", "W", "I", "UP", "B", "SIM", "RUF"]
|
||||
# RUF002 / RUF003 flag ambiguous unicode characters (×, —, etc.) in
|
||||
# docstrings and comments — we use them intentionally for readability.
|
||||
ignore = ["RUF002", "RUF003"]
|
||||
# B008 trips on `Depends(...)` in argument defaults — that's the
|
||||
# idiomatic FastAPI pattern, not a bug.
|
||||
ignore = ["RUF002", "RUF003", "B008"]
|
||||
|
||||
[tool.yapf]
|
||||
based_on_style = "pep8"
|
||||
|
|
|
|||
151
tests/test_api_life_events_goals.py
Normal file
151
tests/test_api_life_events_goals.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Tests for /life-events and /goals."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.app import app
|
||||
from fire_planner.db import Scenario
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(engine: AsyncEngine,
|
||||
session: AsyncSession) -> AsyncIterator[AsyncClient]:
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async def _override() -> AsyncIterator[AsyncSession]:
|
||||
async with factory() as s:
|
||||
yield s
|
||||
|
||||
app.dependency_overrides[get_session] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _seed_scenario(session: AsyncSession) -> int:
|
||||
scen = Scenario(
|
||||
external_id="user-host",
|
||||
kind="user",
|
||||
name="Host plan",
|
||||
jurisdiction="uk",
|
||||
strategy="trinity",
|
||||
leave_uk_year=0,
|
||||
glide_path="static",
|
||||
spending_gbp=Decimal("60000"),
|
||||
nw_seed_gbp=Decimal("1000000"),
|
||||
savings_per_year_gbp=Decimal("0"),
|
||||
config_json={},
|
||||
)
|
||||
session.add(scen)
|
||||
await session.commit()
|
||||
await session.refresh(scen)
|
||||
return scen.id
|
||||
|
||||
|
||||
# ── life events ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_create_and_list_life_events(client: AsyncClient, session: AsyncSession) -> None:
|
||||
sid = await _seed_scenario(session)
|
||||
create = await client.post(
|
||||
f"/scenarios/{sid}/life-events",
|
||||
json={
|
||||
"kind": "retirement",
|
||||
"name": "Retire at 50",
|
||||
"year_start": 15,
|
||||
"year_end": 15,
|
||||
},
|
||||
)
|
||||
assert create.status_code == 201, create.text
|
||||
listed = await client.get(f"/scenarios/{sid}/life-events")
|
||||
assert listed.status_code == 200
|
||||
body = listed.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["name"] == "Retire at 50"
|
||||
|
||||
|
||||
async def test_life_event_year_validation(client: AsyncClient, session: AsyncSession) -> None:
|
||||
sid = await _seed_scenario(session)
|
||||
resp = await client.post(
|
||||
f"/scenarios/{sid}/life-events",
|
||||
json={
|
||||
"kind": "expense_range",
|
||||
"name": "Bad range",
|
||||
"year_start": 20,
|
||||
"year_end": 5,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
async def test_life_event_unknown_scenario(client: AsyncClient) -> None:
|
||||
resp = await client.get("/scenarios/9999/life-events")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_patch_life_event(client: AsyncClient, session: AsyncSession) -> None:
|
||||
sid = await _seed_scenario(session)
|
||||
create = await client.post(
|
||||
f"/scenarios/{sid}/life-events",
|
||||
json={"kind": "retirement", "name": "Retire", "year_start": 15},
|
||||
)
|
||||
eid = create.json()["id"]
|
||||
resp = await client.patch(f"/life-events/{eid}",
|
||||
json={"year_start": 20, "name": "Retire at 55"})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["year_start"] == 20
|
||||
assert body["name"] == "Retire at 55"
|
||||
|
||||
|
||||
async def test_delete_life_event(client: AsyncClient, session: AsyncSession) -> None:
|
||||
sid = await _seed_scenario(session)
|
||||
create = await client.post(
|
||||
f"/scenarios/{sid}/life-events",
|
||||
json={"kind": "retirement", "name": "X", "year_start": 5},
|
||||
)
|
||||
eid = create.json()["id"]
|
||||
resp = await client.delete(f"/life-events/{eid}")
|
||||
assert resp.status_code == 204
|
||||
listed = await client.get(f"/scenarios/{sid}/life-events")
|
||||
assert listed.json() == []
|
||||
|
||||
|
||||
# ── goals ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_create_and_list_goals(client: AsyncClient, session: AsyncSession) -> None:
|
||||
sid = await _seed_scenario(session)
|
||||
create = await client.post(
|
||||
f"/scenarios/{sid}/goals",
|
||||
json={
|
||||
"kind": "target_nw",
|
||||
"name": "≥ £2M at 50",
|
||||
"target_amount_gbp": "2000000",
|
||||
"target_year": 15,
|
||||
"comparator": ">=",
|
||||
"success_threshold": "0.90",
|
||||
},
|
||||
)
|
||||
assert create.status_code == 201, create.text
|
||||
listed = await client.get(f"/scenarios/{sid}/goals")
|
||||
assert len(listed.json()) == 1
|
||||
assert Decimal(listed.json()[0]["target_amount_gbp"]) == Decimal("2000000")
|
||||
|
||||
|
||||
async def test_delete_goal(client: AsyncClient, session: AsyncSession) -> None:
|
||||
sid = await _seed_scenario(session)
|
||||
create = await client.post(
|
||||
f"/scenarios/{sid}/goals",
|
||||
json={"kind": "never_run_out", "name": "Last to 95", "target_year": 65},
|
||||
)
|
||||
gid = create.json()["id"]
|
||||
resp = await client.delete(f"/goals/{gid}")
|
||||
assert resp.status_code == 204
|
||||
122
tests/test_api_networth.py
Normal file
122
tests/test_api_networth.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Tests for /networth and /networth/history."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.app import app
|
||||
from fire_planner.db import AccountSnapshot
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(engine: AsyncEngine,
|
||||
session: AsyncSession) -> AsyncIterator[AsyncClient]:
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async def _override() -> AsyncIterator[AsyncSession]:
|
||||
async with factory() as s:
|
||||
yield s
|
||||
|
||||
app.dependency_overrides[get_session] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _seed_snapshots(session: AsyncSession) -> None:
|
||||
rows = [
|
||||
AccountSnapshot(
|
||||
external_id="wealthfolio:isa:2026-04-23",
|
||||
snapshot_date=date(2026, 4, 23),
|
||||
account_id="isa",
|
||||
account_name="ISA",
|
||||
account_type="ISA",
|
||||
currency="GBP",
|
||||
market_value=Decimal("280000"),
|
||||
market_value_gbp=Decimal("280000"),
|
||||
),
|
||||
AccountSnapshot(
|
||||
external_id="wealthfolio:schwab:2026-04-23",
|
||||
snapshot_date=date(2026, 4, 23),
|
||||
account_id="schwab",
|
||||
account_name="Schwab",
|
||||
account_type="BROKERAGE",
|
||||
currency="USD",
|
||||
market_value=Decimal("780000"),
|
||||
market_value_gbp=Decimal("615000"),
|
||||
),
|
||||
AccountSnapshot(
|
||||
external_id="wealthfolio:isa:2026-04-25",
|
||||
snapshot_date=date(2026, 4, 25),
|
||||
account_id="isa",
|
||||
account_name="ISA",
|
||||
account_type="ISA",
|
||||
currency="GBP",
|
||||
market_value=Decimal("300000"),
|
||||
market_value_gbp=Decimal("300000"),
|
||||
),
|
||||
AccountSnapshot(
|
||||
external_id="wealthfolio:schwab:2026-04-25",
|
||||
snapshot_date=date(2026, 4, 25),
|
||||
account_id="schwab",
|
||||
account_name="Schwab",
|
||||
account_type="BROKERAGE",
|
||||
currency="USD",
|
||||
market_value=Decimal("800000"),
|
||||
market_value_gbp=Decimal("640000"),
|
||||
),
|
||||
]
|
||||
for r in rows:
|
||||
session.add(r)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def test_get_networth_returns_latest(client: AsyncClient, session: AsyncSession) -> None:
|
||||
await _seed_snapshots(session)
|
||||
resp = await client.get("/networth")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["snapshot_date"] == "2026-04-25"
|
||||
assert Decimal(body["total_gbp"]) == Decimal("940000")
|
||||
by_id = {a["account_id"]: a for a in body["accounts"]}
|
||||
assert Decimal(by_id["isa"]["market_value_gbp"]) == Decimal("300000")
|
||||
assert Decimal(by_id["schwab"]["market_value_gbp"]) == Decimal("640000")
|
||||
|
||||
|
||||
async def test_get_networth_empty_when_no_snapshots(client: AsyncClient) -> None:
|
||||
resp = await client.get("/networth")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["accounts"] == []
|
||||
assert Decimal(body["total_gbp"]) == Decimal("0")
|
||||
|
||||
|
||||
async def test_networth_history_returns_per_date(client: AsyncClient,
|
||||
session: AsyncSession) -> None:
|
||||
await _seed_snapshots(session)
|
||||
resp = await client.get("/networth/history")
|
||||
assert resp.status_code == 200
|
||||
points = resp.json()["points"]
|
||||
assert len(points) == 2
|
||||
by_date = {p["snapshot_date"]: p for p in points}
|
||||
assert Decimal(by_date["2026-04-23"]["total_gbp"]) == Decimal("895000")
|
||||
assert Decimal(by_date["2026-04-25"]["total_gbp"]) == Decimal("940000")
|
||||
assert Decimal(by_date["2026-04-25"]["by_account"]["ISA"]) == Decimal("300000")
|
||||
|
||||
|
||||
async def test_networth_history_respects_days_filter(
|
||||
client: AsyncClient,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
await _seed_snapshots(session)
|
||||
resp = await client.get("/networth/history?days=1")
|
||||
assert resp.status_code == 200
|
||||
# days=1 ⇒ only the latest 1 distinct date
|
||||
assert len(resp.json()["points"]) == 1
|
||||
232
tests/test_api_scenarios.py
Normal file
232
tests/test_api_scenarios.py
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
"""Tests for /scenarios CRUD + projection."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.app import app
|
||||
from fire_planner.db import McRun, ProjectionYearly, Scenario
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(engine: AsyncEngine,
|
||||
session: AsyncSession) -> AsyncIterator[AsyncClient]:
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async def _override() -> AsyncIterator[AsyncSession]:
|
||||
async with factory() as s:
|
||||
yield s
|
||||
|
||||
app.dependency_overrides[get_session] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def _seed(session: AsyncSession) -> Scenario:
|
||||
scen = Scenario(
|
||||
external_id="cyprus-vpw-leave-y3-glide-rising",
|
||||
kind="cartesian",
|
||||
jurisdiction="cyprus",
|
||||
strategy="vpw",
|
||||
leave_uk_year=3,
|
||||
glide_path="rising",
|
||||
spending_gbp=Decimal("60000"),
|
||||
nw_seed_gbp=Decimal("1500000"),
|
||||
savings_per_year_gbp=Decimal("0"),
|
||||
config_json={"horizon_years": 60},
|
||||
)
|
||||
session.add(scen)
|
||||
await session.commit()
|
||||
await session.refresh(scen)
|
||||
return scen
|
||||
|
||||
|
||||
async def test_list_scenarios_empty(client: AsyncClient) -> None:
|
||||
resp = await client.get("/scenarios")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
async def test_list_and_filter_by_kind(client: AsyncClient, session: AsyncSession) -> None:
|
||||
base = await _seed(session)
|
||||
user = Scenario(
|
||||
external_id="user-abc",
|
||||
kind="user",
|
||||
name="My plan",
|
||||
parent_scenario_id=base.id,
|
||||
jurisdiction="cyprus",
|
||||
strategy="vpw",
|
||||
leave_uk_year=3,
|
||||
glide_path="rising",
|
||||
spending_gbp=Decimal("80000"),
|
||||
nw_seed_gbp=Decimal("1500000"),
|
||||
savings_per_year_gbp=Decimal("0"),
|
||||
config_json={},
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
all_resp = await client.get("/scenarios")
|
||||
assert len(all_resp.json()) == 2
|
||||
|
||||
user_resp = await client.get("/scenarios?kind=user")
|
||||
assert len(user_resp.json()) == 1
|
||||
assert user_resp.json()[0]["name"] == "My plan"
|
||||
|
||||
|
||||
async def test_get_scenario(client: AsyncClient, session: AsyncSession) -> None:
|
||||
scen = await _seed(session)
|
||||
resp = await client.get(f"/scenarios/{scen.id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["jurisdiction"] == "cyprus"
|
||||
|
||||
|
||||
async def test_get_scenario_404(client: AsyncClient) -> None:
|
||||
resp = await client.get("/scenarios/9999")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_create_user_scenario(client: AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/scenarios",
|
||||
json={
|
||||
"name": "Aggressive FIRE",
|
||||
"description": "Cyprus, lower spend",
|
||||
"jurisdiction": "cyprus",
|
||||
"strategy": "vpw",
|
||||
"leave_uk_year": 2,
|
||||
"glide_path": "rising",
|
||||
"spending_gbp": "50000",
|
||||
"horizon_years": 60,
|
||||
"nw_seed_gbp": "1500000",
|
||||
"savings_per_year_gbp": "0",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
body = resp.json()
|
||||
assert body["kind"] == "user"
|
||||
assert body["name"] == "Aggressive FIRE"
|
||||
assert body["external_id"].startswith("user-")
|
||||
|
||||
|
||||
async def test_create_with_invalid_parent_id(client: AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/scenarios",
|
||||
json={
|
||||
"name": "X",
|
||||
"parent_scenario_id": 9999,
|
||||
"jurisdiction": "uk",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1000000",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
async def test_patch_user_scenario(client: AsyncClient) -> None:
|
||||
create = await client.post("/scenarios",
|
||||
json={
|
||||
"name": "Plan A",
|
||||
"jurisdiction": "uk",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1000000",
|
||||
})
|
||||
sid = create.json()["id"]
|
||||
resp = await client.patch(f"/scenarios/{sid}", json={"name": "Plan A v2", "leave_uk_year": 2})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["name"] == "Plan A v2"
|
||||
assert body["leave_uk_year"] == 2
|
||||
|
||||
|
||||
async def test_patch_cartesian_blocked(client: AsyncClient, session: AsyncSession) -> None:
|
||||
cart = await _seed(session)
|
||||
resp = await client.patch(f"/scenarios/{cart.id}", json={"name": "Renamed"})
|
||||
assert resp.status_code == 400
|
||||
assert "cartesian" in resp.json()["detail"]
|
||||
|
||||
|
||||
async def test_delete_user_scenario(client: AsyncClient) -> None:
|
||||
create = await client.post("/scenarios",
|
||||
json={
|
||||
"name": "Throwaway",
|
||||
"jurisdiction": "uk",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1000000",
|
||||
})
|
||||
sid = create.json()["id"]
|
||||
resp = await client.delete(f"/scenarios/{sid}")
|
||||
assert resp.status_code == 204
|
||||
assert (await client.get(f"/scenarios/{sid}")).status_code == 404
|
||||
|
||||
|
||||
async def test_delete_cartesian_blocked(client: AsyncClient, session: AsyncSession) -> None:
|
||||
cart = await _seed(session)
|
||||
resp = await client.delete(f"/scenarios/{cart.id}")
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
async def test_projection_404_when_no_run(client: AsyncClient, session: AsyncSession) -> None:
|
||||
scen = await _seed(session)
|
||||
resp = await client.get(f"/scenarios/{scen.id}/projection")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_projection_returns_yearly_series(client: AsyncClient,
|
||||
session: AsyncSession) -> None:
|
||||
scen = await _seed(session)
|
||||
run = McRun(
|
||||
scenario_id=scen.id,
|
||||
run_at=datetime(2026, 5, 1, tzinfo=UTC),
|
||||
n_paths=1000,
|
||||
seed=42,
|
||||
success_rate=Decimal("0.9050"),
|
||||
p10_ending_gbp=Decimal("100000"),
|
||||
p50_ending_gbp=Decimal("3000000"),
|
||||
p90_ending_gbp=Decimal("9000000"),
|
||||
median_lifetime_tax_gbp=Decimal("750000"),
|
||||
elapsed_seconds=Decimal("12.500"),
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
await session.refresh(run)
|
||||
for y in range(5):
|
||||
session.add(
|
||||
ProjectionYearly(
|
||||
mc_run_id=run.id,
|
||||
year_idx=y,
|
||||
p10_portfolio_gbp=Decimal("900000"),
|
||||
p25_portfolio_gbp=Decimal("950000"),
|
||||
p50_portfolio_gbp=Decimal("1000000"),
|
||||
p75_portfolio_gbp=Decimal("1100000"),
|
||||
p90_portfolio_gbp=Decimal("1200000"),
|
||||
p50_withdrawal_gbp=Decimal("60000"),
|
||||
p50_tax_gbp=Decimal("8000"),
|
||||
survival_rate=Decimal("1.0"),
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
resp = await client.get(f"/scenarios/{scen.id}/projection")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["scenario_id"] == scen.id
|
||||
assert body["n_paths"] == 1000
|
||||
assert len(body["yearly"]) == 5
|
||||
assert Decimal(body["yearly"][0]["p50_portfolio_gbp"]) == Decimal("1000000")
|
||||
131
tests/test_api_simulate.py
Normal file
131
tests/test_api_simulate.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""Smoke-tests for /simulate and /compare.
|
||||
|
||||
Uses very small n_paths (100) to keep tests fast — accuracy isn't the
|
||||
point, the point is the endpoint produces a valid response shape.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
from fire_planner.api.dependencies import get_session
|
||||
from fire_planner.app import app
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(engine: AsyncEngine,
|
||||
session: AsyncSession) -> AsyncIterator[AsyncClient]:
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async def _override() -> AsyncIterator[AsyncSession]:
|
||||
async with factory() as s:
|
||||
yield s
|
||||
|
||||
app.dependency_overrides[get_session] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test", timeout=30) as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_simulate_runs_and_returns_yearly_fan(client: AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/simulate",
|
||||
json={
|
||||
"jurisdiction": "uk",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static_60_40",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1500000",
|
||||
"horizon_years": 30,
|
||||
"n_paths": 100,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert "success_rate" in body
|
||||
assert len(body["yearly"]) == 30
|
||||
yp = body["yearly"][0]
|
||||
# Quantiles must be monotone non-decreasing
|
||||
p10, p25, p50, p75, p90 = (
|
||||
Decimal(yp[k])
|
||||
for k in ("p10_portfolio_gbp", "p25_portfolio_gbp", "p50_portfolio_gbp",
|
||||
"p75_portfolio_gbp", "p90_portfolio_gbp"))
|
||||
assert p10 <= p25 <= p50 <= p75 <= p90
|
||||
|
||||
|
||||
async def test_simulate_validates_unknown_jurisdiction(client: AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/simulate",
|
||||
json={
|
||||
"jurisdiction": "atlantis",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static_60_40",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1000000",
|
||||
"horizon_years": 10,
|
||||
"n_paths": 100,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
async def test_compare_runs_two_scenarios(client: AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/compare",
|
||||
json={
|
||||
"scenarios": [
|
||||
{
|
||||
"jurisdiction": "uk",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static_60_40",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1500000",
|
||||
"horizon_years": 20,
|
||||
"n_paths": 100,
|
||||
"seed": 42,
|
||||
},
|
||||
{
|
||||
"jurisdiction": "cyprus",
|
||||
"strategy": "guyton_klinger",
|
||||
"leave_uk_year": 2,
|
||||
"glide_path": "rising",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1500000",
|
||||
"horizon_years": 20,
|
||||
"n_paths": 100,
|
||||
"seed": 42,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
results = resp.json()["results"]
|
||||
assert len(results) == 2
|
||||
assert all(len(r["yearly"]) == 20 for r in results)
|
||||
|
||||
|
||||
async def test_compare_rejects_single_scenario(client: AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/compare",
|
||||
json={
|
||||
"scenarios": [{
|
||||
"jurisdiction": "uk",
|
||||
"strategy": "trinity",
|
||||
"leave_uk_year": 0,
|
||||
"glide_path": "static_60_40",
|
||||
"spending_gbp": "60000",
|
||||
"nw_seed_gbp": "1500000",
|
||||
"n_paths": 100,
|
||||
}]
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422 # pydantic validation
|
||||
Loading…
Add table
Add a link
Reference in a new issue