112 lines
3.8 KiB
Python
112 lines
3.8 KiB
Python
"""FastAPI on-demand /recompute endpoint.
|
|
|
|
Single deployment. Bearer-token auth (matches the payslip-ingest pattern).

POST /recompute enqueues a full 120-scenario Cartesian recompute; a
background worker then runs it against whatever the latest Wealthfolio
snapshot in `account_snapshot` is.

For dev / smoke tests, a `/healthz` endpoint reports queue depth.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import hmac
|
|
import logging
|
|
import os
|
|
from collections.abc import AsyncIterator
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, Header, HTTPException, status
|
|
from prometheus_fastapi_instrumentator import Instrumentator
|
|
|
|
# Module logger; the recompute worker reports failed runs through it.
log = logging.getLogger(name=__name__)

# Environment variables that must be present and non-empty before startup.
REQUIRED_ENV = [
    "DB_CONNECTION_STRING",
    "RECOMPUTE_BEARER_TOKEN",
]
|
|
|
|
|
|
def _verify_env() -> None:
    """Fail fast when any variable named in ``REQUIRED_ENV`` is unset or empty.

    Raises:
        RuntimeError: listing every offending variable, comma-separated, in
            ``REQUIRED_ENV`` order.
    """
    unset = [var for var in REQUIRED_ENV if not os.environ.get(var)]
    if not unset:
        return
    raise RuntimeError(f"Missing required env vars: {', '.join(unset)}")
|
|
|
|
|
|
def _verify_bearer(authorization: str | None, expected: str) -> None:
    """Reject the request unless ``authorization`` carries the expected bearer token.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` if absent.
        expected: The configured shared secret. An empty value rejects every
            request.

    Raises:
        HTTPException: 401 when no secret is configured, when the header is
            missing or not a Bearer credential, or when the token does not match.
    """
    if not expected:
        # No configured secret: refuse everything rather than accept anything.
        raise HTTPException(status_code=401, detail="Service unauthenticated")

    # RFC 7235: the auth-scheme is case-insensitive ("Bearer" == "bearer").
    # Everything after the first space is the credential.
    scheme, sep, token = (authorization or "").partition(" ")
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Missing bearer token")

    # Compare bytes, not str: hmac.compare_digest raises TypeError for non-ASCII
    # str input. That would turn a malformed header into a 500 instead of a 401.
    if not hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Own startup/shutdown: validate config, create the queue, run the drain worker.

    FastAPI does not call ``@app.on_event`` startup/shutdown handlers once an
    app is built with ``lifespan=`` (the module-level ``app`` is). So the
    recompute worker must be started and stopped here. Otherwise
    ``/recompute`` enqueues jobs that are never processed.

    Raises:
        RuntimeError: from ``_verify_env`` when required env vars are missing,
            which aborts startup.
    """
    _verify_env()
    queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    app.state.queue = queue
    # `_drain_loop` / `_stop_worker` act on the module-level `app`; this
    # assumes the `app` passed in here is that same instance.
    await _drain_loop()
    try:
        yield
    finally:
        # Cancel the worker even if the server exits abnormally.
        await _stop_worker()
|
|
|
|
|
|
# The application; `lifespan` runs around the server's startup and shutdown.
app = FastAPI(lifespan=lifespan, title="fire-planner")

# Attach Prometheus request instrumentation and serve it at /metrics.
Instrumentator().instrument(app).expose(
    app,
    endpoint="/metrics",
)
|
|
|
|
|
|
@app.post("/recompute", status_code=status.HTTP_202_ACCEPTED)
async def recompute(
    payload: dict[str, Any] | None = None,
    authorization: str | None = Header(default=None),
) -> dict[str, Any]:
    """Authenticate the caller and enqueue one recompute request.

    The optional JSON body is queued as-is; an absent or empty body is queued
    as ``{}``. The drain worker reads per-run overrides from it. Responds 202
    with the queue depth measured after enqueueing.
    """
    secret = os.environ.get("RECOMPUTE_BEARER_TOKEN", "")
    _verify_bearer(authorization, secret)

    pending: asyncio.Queue[dict[str, Any]] = app.state.queue
    job: dict[str, Any] = payload if payload else {}
    await pending.put(job)
    return {"status": "accepted", "depth": pending.qsize()}
|
|
|
|
|
|
@app.get("/healthz")
async def healthz() -> dict[str, Any]:
    """Liveness probe reporting how many recompute requests are still queued.

    Reports a depth of 0 when the queue has not been created yet.
    """
    pending = getattr(app.state, "queue", None)
    if pending is None:
        return {"status": "ok", "queue_depth": 0}
    return {"status": "ok", "queue_depth": pending.qsize()}
|
|
|
|
|
|
@app.on_event("startup")
async def _drain_loop() -> None:
    """Start the background task that drains the recompute queue.

    Each queued item kicks a full Cartesian recompute via
    ``fire_planner.__main__._recompute_all``. Keys missing from the item fall
    back to the defaults below. Errors are logged but don't crash the worker.

    Idempotent: if a worker task is already running, no second one is
    started. This ensures the queue is never drained by two concurrent
    workers.

    NOTE: FastAPI does not call ``on_event`` handlers for an app constructed
    with ``lifespan=`` (as the module-level ``app`` is). The decorator alone
    therefore never starts the worker; the lifespan has to await this
    coroutine.
    """
    running = getattr(app.state, "_worker", None)
    if running is not None and not running.done():
        return

    queue: asyncio.Queue[dict[str, Any]] = app.state.queue

    async def worker() -> None:
        while True:
            item = await queue.get()
            try:
                # Avoid heavy import unless we actually have work.
                from fire_planner.__main__ import _recompute_all

                await _recompute_all(
                    n_paths=int(item.get("n_paths", 10_000)),
                    horizon=int(item.get("horizon", 60)),
                    spending=float(item.get("spending", 100_000.0)),
                    nw_seed=float(item.get("nw_seed", 1_000_000.0)),
                    savings=float(item.get("savings", 0.0)),
                    floor=(float(item["floor"]) if item.get("floor") is not None else None),
                    returns_csv=item.get("returns_csv"),
                    seed=int(item.get("seed", 42)),
                )
            except Exception:
                # One bad request (bad override, failed run) must not kill the worker.
                log.exception("recompute failed")
            finally:
                # Balance every get() so Queue.join() can complete.
                queue.task_done()

    # Keep a reference on app.state so the task is not garbage-collected
    # and can be cancelled on shutdown.
    app.state._worker = asyncio.create_task(worker(), name="recompute-drain")
|
|
|
|
|
|
@app.on_event("shutdown")
async def _stop_worker() -> None:
    """Cancel the recompute drain task, if one was started, and wait for it to exit."""
    worker_task = getattr(app.state, "_worker", None)
    if worker_task is None:
        return
    worker_task.cancel()
    try:
        await worker_task
    except asyncio.CancelledError:
        # Expected: the task ends by being cancelled.
        pass
|