fire-planner/fire_planner/app.py
2026-05-07 17:06:19 +00:00

112 lines
3.8 KiB
Python

"""FastAPI on-demand /recompute endpoint.
Single deployment. Bearer-token auth (matches payslip-ingest pattern).
The endpoint queues the full 120-scenario Cartesian recompute against
whatever the latest Wealthfolio snapshot is in `account_snapshot`.
For dev / smoke tests, a `/healthz` endpoint reports queue depth.
"""
from __future__ import annotations
import asyncio
import contextlib
import hmac
import logging
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI, Header, HTTPException, status
from prometheus_fastapi_instrumentator import Instrumentator
log = logging.getLogger(__name__)
REQUIRED_ENV = ["DB_CONNECTION_STRING", "RECOMPUTE_BEARER_TOKEN"]
def _verify_env() -> None:
missing = [k for k in REQUIRED_ENV if not os.environ.get(k)]
if missing:
raise RuntimeError(f"Missing required env vars: {', '.join(missing)}")
def _verify_bearer(authorization: str | None, expected: str) -> None:
if not expected:
raise HTTPException(status_code=401, detail="Service unauthenticated")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing bearer token")
token = authorization.removeprefix("Bearer ")
if not hmac.compare_digest(token, expected):
raise HTTPException(status_code=401, detail="Invalid token")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan: validate config, create the queue, run the worker.

    On startup: verify required env vars, attach a fresh recompute queue to
    ``app.state.queue`` and start the background drain worker. On shutdown
    (including after an error while serving): cancel that worker.
    """
    _verify_env()
    queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    app.state.queue = queue
    # Starlette/FastAPI do not call ``@app.on_event`` handlers once a
    # ``lifespan`` is supplied ("all lifespan or all events"). Without these
    # explicit calls, ``_drain_loop`` never starts the worker and every
    # queued recompute would sit in the queue forever.
    await _drain_loop()
    try:
        yield
    finally:
        await _stop_worker()
# ASGI application; startup/shutdown behaviour is supplied by ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="fire-planner",
)
# Prometheus instrumentation, exposed at /metrics.
Instrumentator().instrument(app).expose(app, endpoint="/metrics")
# POST /recompute: authenticate, enqueue the (optional) JSON body for the
# background worker, and acknowledge with 202 plus the resulting queue depth.
# An absent body is enqueued as an empty dict. (Comments rather than a
# docstring so the generated OpenAPI description is unchanged.)
@app.post("/recompute", status_code=status.HTTP_202_ACCEPTED)
async def recompute(
    payload: dict[str, Any] | None = None,
    authorization: str | None = Header(default=None),
) -> dict[str, Any]:
    expected_token = os.environ.get("RECOMPUTE_BEARER_TOKEN", "")
    _verify_bearer(authorization, expected_token)
    pending: asyncio.Queue[dict[str, Any]] = app.state.queue
    request_body = payload if payload else {}
    await pending.put(request_body)
    return {"status": "accepted", "depth": pending.qsize()}
# GET /healthz: always reports "ok" together with the number of pending
# recompute requests. If no queue has been attached to app.state yet, the
# depth is reported as 0.
@app.get("/healthz")
async def healthz() -> dict[str, Any]:
    pending = getattr(app.state, "queue", None)
    if pending is None:
        return {"status": "ok", "queue_depth": 0}
    return {"status": "ok", "queue_depth": pending.qsize()}
def _recompute_kwargs(item: dict[str, Any]) -> dict[str, Any]:
    """Translate a queued request body into ``_recompute_all`` keyword args.

    A missing key and an explicit JSON ``null`` both mean "use the default",
    matching how ``floor`` has always been treated. Values are coerced with
    ``int``/``float``, so malformed input raises ``ValueError``/``TypeError``.
    """

    def pick(key: str, default: Any) -> Any:
        # Treat explicit null the same as an absent key.
        value = item.get(key)
        return default if value is None else value

    floor = item.get("floor")
    return {
        "n_paths": int(pick("n_paths", 10_000)),
        "horizon": int(pick("horizon", 60)),
        "spending": float(pick("spending", 100_000.0)),
        "nw_seed": float(pick("nw_seed", 1_000_000.0)),
        "savings": float(pick("savings", 0.0)),
        "floor": float(floor) if floor is not None else None,
        "returns_csv": item.get("returns_csv"),
        "seed": int(pick("seed", 42)),
    }


@app.on_event("startup")
async def _drain_loop() -> None:
    """Start the background task that drains the recompute queue.

    Each queued item triggers one full Cartesian recompute via
    ``_recompute_all``. Failures, including malformed request bodies, are
    logged and the worker moves on to the next item. If a worker from an
    earlier call is still running, this is a no-op, so a second invocation
    cannot orphan a running task.

    NOTE: Starlette does not run ``on_event`` handlers when the app is built
    with a ``lifespan``, so this only runs when something calls it
    explicitly (e.g. the lifespan).
    """
    running = getattr(app.state, "_worker", None)
    if running is not None and not running.done():
        return
    queue: asyncio.Queue[dict[str, Any]] = app.state.queue

    async def worker() -> None:
        while True:
            item = await queue.get()
            try:
                # Avoid heavy import unless we actually have work.
                from fire_planner.__main__ import _recompute_all

                await _recompute_all(**_recompute_kwargs(item))
            except Exception:
                # Best-effort: one bad request must not kill the worker.
                log.exception("recompute failed")
            finally:
                queue.task_done()

    # Keep a strong reference on app.state so the task is not garbage
    # collected mid-run and so the shutdown hook can cancel it.
    app.state._worker = asyncio.create_task(worker(), name="recompute-worker")
@app.on_event("shutdown")
async def _stop_worker() -> None:
    """Cancel the background recompute worker, if one was started.

    Cancellation is delivered to whatever the worker is awaiting (the next
    ``queue.get()`` or an in-flight recompute); the resulting
    ``CancelledError`` is absorbed after waiting for the task to unwind.

    NOTE: Starlette does not run ``on_event`` handlers when the app is built
    with a ``lifespan``, so this only runs when something calls it
    explicitly.
    """
    worker_task = getattr(app.state, "_worker", None)
    if worker_task is None:
        return
    worker_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await worker_task