fire-planner/fire_planner/app.py

178 lines
6.6 KiB
Python
Raw Normal View History

"""FastAPI application — wires routers + middleware + lifespan.
Two deployment shapes:
- **Dev / tests**: `FRONTEND_DIST` unset. API routers mount at root
(e.g. `/networth`, `/scenarios`). The Vite dev server (port 5173)
hits `/api/*` and proxies to the backend stripping the prefix.
- **Prod**: `FRONTEND_DIST=/path/to/frontend/dist` set. API routers
mount under `/api/*`; the SPA static bundle mounts at `/` with
HTML fallback so React Router can own the URL.
Operational endpoints (`/healthz`, `/metrics`, `/recompute`) stay at
root in both shapes.
Auth: write/compute paths take Bearer auth via the `require_bearer`
dependency when `API_BEARER_TOKEN` is set. Read paths skip auth so the
local frontend can hit them without juggling tokens; production
deploys lock those down via Authentik-fronted ingress.
CORS: enabled for the frontend dev server. Comma-separated origins
in `FRONTEND_ORIGINS` (defaults to a typical Vite localhost).
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
2026-05-07 17:06:19 +00:00
from typing import Any
from fastapi import Depends, FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
2026-05-07 17:06:19 +00:00
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.types import Scope
2026-05-07 17:06:19 +00:00
from fire_planner.api.auth import require_bearer
from fire_planner.api.goals import router as goals_router
from fire_planner.api.life_events import router as life_events_router
from fire_planner.api.networth import router as networth_router
from fire_planner.api.scenarios import router as scenarios_router
from fire_planner.api.simulate import router as simulate_router
from fire_planner.api.spending import router as spending_router
from fire_planner.db import create_engine_from_env, make_session_factory
2026-05-07 17:06:19 +00:00
log = logging.getLogger(__name__)
2026-05-07 17:06:19 +00:00
def _frontend_origins() -> list[str]:
    """Return the list of origins allowed by CORS.

    Reads the comma-separated ``FRONTEND_ORIGINS`` environment variable,
    falling back to the usual Vite dev/preview localhost origins when it
    is unset. Each entry is whitespace-trimmed and empty entries are
    dropped.

    Trailing slashes are removed as well: a browser's ``Origin`` header
    never carries one, so an entry such as ``http://localhost:5173/``
    would otherwise be silently ineffective when compared verbatim.

    Returns:
        The cleaned origins, in the order they were configured.
    """
    raw = os.environ.get(
        "FRONTEND_ORIGINS",
        "http://localhost:5173,http://localhost:4173,http://127.0.0.1:5173",
    )
    # Normalise first, then filter, so an entry that is only "/" or only
    # whitespace disappears instead of surviving as an empty origin.
    cleaned = (entry.strip().rstrip("/") for entry in raw.split(","))
    return [origin for origin in cleaned if origin]
2026-05-07 17:06:19 +00:00
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Set up and tear down per-process application state.

    Startup: creates the recompute queue on ``app.state.queue``, wires
    ``app.state.engine`` / ``app.state.session_factory`` when
    ``DB_CONNECTION_STRING`` is set, and starts the ``_drain_queue``
    background task (kept on ``app.state._worker``).

    Shutdown: cancels the background task, waits for it to finish, then
    disposes whatever engine is present on ``app.state``.
    """
    jobs: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    app.state.queue = jobs

    if not os.environ.get("DB_CONNECTION_STRING"):
        # Tests inject these via dependency_overrides; nothing to wire.
        log.warning("DB_CONNECTION_STRING unset; skipping engine init")
    else:
        db_engine = create_engine_from_env()
        app.state.engine = db_engine
        app.state.session_factory = make_session_factory(db_engine)

    drain_task = asyncio.create_task(_drain_queue(app))
    app.state._worker = drain_task
    try:
        yield
    finally:
        # Stop the consumer first; its cancellation is expected, so the
        # resulting CancelledError is swallowed rather than propagated.
        drain_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await drain_task
        # Looked up on app.state (not the local above) so an engine that
        # was attached by other means is disposed too.
        live_engine = getattr(app.state, "engine", None)
        if live_engine is not None:
            await live_engine.dispose()
async def _drain_queue(app: FastAPI) -> None:
    """Consume recompute requests from ``app.state.queue`` forever.

    Each dequeued item is converted into keyword arguments for
    ``_recompute_all``; keys the item omits fall back to the defaults
    spelled out below. A failing job is logged and the loop moves on to
    the next item, so one bad request does not stop the worker.
    """
    pending: asyncio.Queue[dict[str, Any]] = app.state.queue
    while True:
        job = await pending.get()
        try:
            # Imported lazily inside the loop body — presumably to avoid an
            # import cycle with the package entry point (confirm).
            from fire_planner.__main__ import _recompute_all

            # Coercions happen in listing order and inside the ``try`` so
            # a malformed value is logged instead of killing the task.
            kwargs = {
                "n_paths": int(job.get("n_paths", 10_000)),
                "horizon": int(job.get("horizon", 60)),
                "spending": float(job.get("spending", 100_000.0)),
                "nw_seed": float(job.get("nw_seed", 1_000_000.0)),
                "savings": float(job.get("savings", 0.0)),
                "floor": None if job.get("floor") is None else float(job["floor"]),
                "returns_csv": job.get("returns_csv"),
                "seed": int(job.get("seed", 42)),
            }
            await _recompute_all(**kwargs)
        except Exception:
            log.exception("recompute failed")
        finally:
            # Always balance the get() above, success or failure.
            pending.task_done()
2026-05-07 17:06:19 +00:00
# Deployment shape is fixed at import time: when FRONTEND_DIST is set the
# API routers live under /api, otherwise they are mounted at the root.
_FRONTEND_DIST = os.getenv("FRONTEND_DIST")
_API_PREFIX = "/api" if _FRONTEND_DIST else ""

app = FastAPI(title="fire-planner", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=_frontend_origins(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prometheus metrics are exposed at the root regardless of the API prefix.
Instrumentator().instrument(app).expose(app, endpoint="/metrics")

# Registration order is kept as-is: networth, scenarios, life events,
# goals, simulate, spending.
for _api_router in (
    networth_router,
    scenarios_router,
    life_events_router,
    goals_router,
    simulate_router,
    spending_router,
):
    app.include_router(_api_router, prefix=_API_PREFIX)
del _api_router  # do not leave the loop variable in the module namespace
2026-05-07 17:06:19 +00:00
@app.post(
    "/recompute",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[Depends(require_bearer)],
)
async def recompute(payload: dict[str, Any] | None = None) -> dict[str, Any]:
    """Queue a full Cartesian recompute (async, persisted). Returns 202."""
    # A missing or empty body is enqueued as an empty dict, which makes the
    # background worker fall back to its own defaults for every parameter.
    job = payload if payload else {}
    pending: asyncio.Queue[dict[str, Any]] = app.state.queue
    await pending.put(job)
    # "depth" is the queue size observed right after enqueueing this job.
    return {"status": "accepted", "depth": pending.qsize()}
@app.get("/healthz")
async def healthz() -> dict[str, Any]:
    # Liveness probe. The queue only exists once the lifespan startup has
    # run, so a missing queue is reported as depth 0 rather than an error.
    pending = getattr(app.state, "queue", None)
    return {
        "status": "ok",
        "queue_depth": 0 if pending is None else pending.qsize(),
    }
class _SPAStaticFiles(StaticFiles):
    """StaticFiles variant that falls back to ``index.html`` on a 404.

    Client-side routes such as /scenarios, /what-if or /scenarios/123 do
    not exist as files in the bundle, so the server has to hand back the
    SPA shell and let React Router resolve them in the browser.

    `StaticFiles(html=True)` only serves index.html for *directories*,
    not for arbitrary not-found paths — that's not enough for a SPA.
    """

    async def get_response(self, path: str, scope: Scope):  # type: ignore[no-untyped-def]
        """Serve *path* from the bundle, or ``index.html`` if it is missing.

        Any HTTP error other than 404 is re-raised untouched, as is the
        404 itself when the bundle has no ``index.html`` to fall back to.
        """
        try:
            response = await super().get_response(path, scope)
        except StarletteHTTPException as http_err:
            if http_err.status_code != 404:
                raise
            spa_shell = Path(self.directory) / "index.html"  # type: ignore[arg-type]
            if not spa_shell.is_file():
                raise
            return FileResponse(spa_shell)
        return response
# The SPA mount goes last so it only sees paths no API route has claimed.
if _FRONTEND_DIST:
    if Path(_FRONTEND_DIST).is_dir():
        app.mount("/", _SPAStaticFiles(directory=_FRONTEND_DIST, html=True), name="spa")
        log.info("Mounted SPA from %s; API at %s/*", _FRONTEND_DIST, _API_PREFIX)
    else:
        # Configured but missing on disk: keep serving the API, skip the SPA.
        log.warning("FRONTEND_DIST=%s does not exist; SPA not mounted", _FRONTEND_DIST)