Add API anti-abuse hardening: disable docs in prod, origin validator, exception handler

- Disable OpenAPI docs/redoc/openapi.json when APP_ENV=production
- Strip uvicorn Server header with --no-server-header in Dockerfile and docker-compose.yml
- Add OriginValidatorMiddleware to reject state-changing requests from disallowed origins
- Add global exception handler to prevent stack trace leakage on unhandled errors
- Add tests for all new security features (OpenAPI, origin validation, exception handler, server header)
This commit is contained in:
Viktor Barzin 2026-02-08 20:06:46 +00:00
parent 162d9a886d
commit 1ace45353a
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 252 additions and 4 deletions

View file

@ -42,4 +42,4 @@ ENV PATH="/app/.venv/bin:$PATH"
COPY . .
EXPOSE 5001
CMD ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001"]
CMD ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --no-server-header"]

View file

@ -13,9 +13,11 @@ from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware
from api.metrics_guard import MetricsGuardMiddleware
from api.security_headers import SecurityHeadersMiddleware
from api.origin_validator import OriginValidatorMiddleware
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.requests import Request
from api.auth import User
from models.listing import QueryParameters, ListingType, FurnishType
from notifications import send_notification
@ -85,7 +87,11 @@ def get_query_parameters(
)
app = FastAPI()
app = FastAPI(
docs_url=None if APP_ENV == "production" else "/docs",
redoc_url=None if APP_ENV == "production" else "/redoc",
openapi_url=None if APP_ENV == "production" else "/openapi.json",
)
app.include_router(passkey_router)
app.include_router(poi_router)
app.mount("/metrics", metrics_app)
@ -108,6 +114,11 @@ app.add_middleware(
allow_headers=["Authorization", "Content-Type"],
)
app.add_middleware(
OriginValidatorMiddleware,
allowed_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
)
# Security middleware (added bottom-to-top; last added = outermost)
# 3. Rate limiting — enforces per-user limits
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
@ -119,6 +130,15 @@ app.add_middleware(AuditLogMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
logger.exception("Unhandled exception")
return JSONResponse(
status_code=500,
content={"detail": "Internal server error"},
)
@app.get("/api/status")
async def get_status() -> dict[str, str]:
request_counter.add(1, {"method": "GET", "path": "/status"})

34
api/origin_validator.py Normal file
View file

@ -0,0 +1,34 @@
"""Origin validation middleware for state-changing requests."""
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
logger = logging.getLogger("uvicorn")
STATE_CHANGING_METHODS = {"POST", "PUT", "DELETE", "PATCH"}
class OriginValidatorMiddleware(BaseHTTPMiddleware):
"""Reject state-changing requests with mismatched Origin header."""
def __init__(self, app, allowed_origins: list[str] | None = None) -> None:
super().__init__(app)
self._allowed = {o.rstrip("/") for o in (allowed_origins or [])}
async def dispatch(self, request: Request, call_next) -> Response:
if request.method not in STATE_CHANGING_METHODS:
return await call_next(request)
origin = request.headers.get("origin")
if origin is None:
return await call_next(request)
if origin.rstrip("/") not in self._allowed:
logger.warning(f"Rejected request from origin: {origin}")
return JSONResponse(
status_code=403,
content={"detail": "Origin not allowed"},
)
return await call_next(request)

View file

@ -60,7 +60,7 @@ services:
condition: service_healthy
mysql:
condition: service_healthy
command: ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-dir api --reload-dir services --reload-dir repositories --reload-dir models"]
command: ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-dir api --reload-dir services --reload-dir repositories --reload-dir models --no-server-header"]
networks:
- rec-network

View file

@ -0,0 +1,42 @@
"""Unit tests for OpenAPI docs disabled in production."""
from unittest import mock
from starlette.testclient import TestClient
class TestOpenAPIDisabledInProduction:
"""Verify docs/redoc/openapi are disabled when APP_ENV=production."""
def test_docs_disabled_in_production(self) -> None:
prod_env = {
"APP_ENV": "production",
"OIDC_CLIENT_ID": "test-client-id",
"JWT_SECRET": "test-secret-not-default",
}
with mock.patch.dict("os.environ", prod_env):
# Re-import to pick up the new env
import importlib
import api.config
importlib.reload(api.config)
import api.app
importlib.reload(api.app)
app = api.app.app
client = TestClient(app, raise_server_exceptions=False)
assert client.get("/docs").status_code == 404
assert client.get("/redoc").status_code == 404
assert client.get("/openapi.json").status_code == 404
def test_docs_enabled_in_dev(self) -> None:
with mock.patch.dict("os.environ", {"APP_ENV": "dev"}):
import importlib
import api.config
importlib.reload(api.config)
import api.app
importlib.reload(api.app)
app = api.app.app
client = TestClient(app, raise_server_exceptions=False)
# /docs should return 200 (Swagger UI) in non-production
assert client.get("/docs").status_code == 200
assert client.get("/openapi.json").status_code == 200

View file

@ -0,0 +1,81 @@
"""Unit tests for api/origin_validator.py."""
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.origin_validator import OriginValidatorMiddleware
ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com/"]
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
def _build_app() -> Starlette:
app = Starlette(routes=[
Route("/test", _ok_endpoint, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
])
app.add_middleware(OriginValidatorMiddleware, allowed_origins=ALLOWED_ORIGINS)
return app
class TestOriginValidator:
"""Tests for OriginValidatorMiddleware."""
def test_get_request_bypasses_check(self) -> None:
"""GET requests should not be checked for origin."""
client = TestClient(_build_app())
resp = client.get("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 200
def test_post_with_allowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://example.com"})
assert resp.status_code == 200
def test_post_with_allowed_origin_trailing_slash(self) -> None:
"""Trailing slash on origin should be normalized."""
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://app.example.com/"})
assert resp.status_code == 200
def test_post_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
assert resp.json() == {"detail": "Origin not allowed"}
def test_put_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.put("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_delete_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.delete("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_patch_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.patch("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_post_without_origin_passes(self) -> None:
"""Requests without an Origin header should pass through."""
client = TestClient(_build_app())
resp = client.post("/test")
assert resp.status_code == 200
def test_no_allowed_origins_configured(self) -> None:
"""If no allowed origins are given, all origins are rejected."""
app = Starlette(routes=[
Route("/test", _ok_endpoint, methods=["POST"]),
])
app.add_middleware(OriginValidatorMiddleware, allowed_origins=[])
client = TestClient(app)
resp = client.post("/test", headers={"Origin": "https://example.com"})
assert resp.status_code == 403

View file

@ -54,3 +54,9 @@ class TestSecurityHeaders:
client = TestClient(_build_app())
resp = client.get("/test")
assert "Strict-Transport-Security" not in resp.headers
def test_server_header_not_present(self) -> None:
"""The Server header should not leak server software info."""
client = TestClient(_build_app())
resp = client.get("/test")
assert "Server" not in resp.headers or "uvicorn" not in resp.headers.get("Server", "").lower()

View file

@ -0,0 +1,65 @@
"""Unit tests for the unhandled exception handler in api/app.py."""
from unittest import mock
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from starlette.requests import Request
from starlette.testclient import TestClient
def _build_app() -> FastAPI:
"""Build a minimal FastAPI app with the unhandled exception handler."""
app = FastAPI()
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
return JSONResponse(
status_code=500,
content={"detail": "Internal server error"},
)
@app.get("/ok")
async def ok() -> dict[str, str]:
return {"status": "ok"}
@app.get("/crash")
async def crash() -> None:
raise RuntimeError("something unexpected")
@app.get("/http-error")
async def http_error() -> None:
raise HTTPException(status_code=403, detail="forbidden")
return app
class TestUnhandledExceptionHandler:
"""Tests for the global exception handler."""
def test_normal_request_unaffected(self) -> None:
client = TestClient(_build_app())
resp = client.get("/ok")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
def test_unhandled_exception_returns_generic_500(self) -> None:
client = TestClient(_build_app(), raise_server_exceptions=False)
resp = client.get("/crash")
assert resp.status_code == 500
assert resp.json() == {"detail": "Internal server error"}
def test_http_exception_not_caught(self) -> None:
"""HTTPException should still be handled by FastAPI's built-in handler."""
client = TestClient(_build_app())
resp = client.get("/http-error")
assert resp.status_code == 403
assert resp.json() == {"detail": "forbidden"}
def test_error_body_does_not_leak_details(self) -> None:
"""The response must not contain the actual exception message."""
client = TestClient(_build_app(), raise_server_exceptions=False)
resp = client.get("/crash")
body = resp.text
assert "something unexpected" not in body
assert "RuntimeError" not in body