Add API anti-abuse hardening: disable docs in prod, origin validator, exception handler

- Disable OpenAPI docs/redoc/openapi.json when APP_ENV=production
- Strip uvicorn Server header with --no-server-header in Dockerfile and docker-compose.yml
- Add OriginValidatorMiddleware to reject state-changing requests from disallowed origins
- Add global exception handler to prevent stack trace leakage on unhandled errors
- Add tests for all new security features (OpenAPI, origin validation, exception handler, server header)
This commit is contained in:
Viktor Barzin 2026-02-08 20:06:46 +00:00
parent 162d9a886d
commit 1ace45353a
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 252 additions and 4 deletions

View file

@ -0,0 +1,42 @@
"""Unit tests for OpenAPI docs disabled in production."""
from unittest import mock
from starlette.testclient import TestClient
class TestOpenAPIDisabledInProduction:
"""Verify docs/redoc/openapi are disabled when APP_ENV=production."""
def test_docs_disabled_in_production(self) -> None:
prod_env = {
"APP_ENV": "production",
"OIDC_CLIENT_ID": "test-client-id",
"JWT_SECRET": "test-secret-not-default",
}
with mock.patch.dict("os.environ", prod_env):
# Re-import to pick up the new env
import importlib
import api.config
importlib.reload(api.config)
import api.app
importlib.reload(api.app)
app = api.app.app
client = TestClient(app, raise_server_exceptions=False)
assert client.get("/docs").status_code == 404
assert client.get("/redoc").status_code == 404
assert client.get("/openapi.json").status_code == 404
def test_docs_enabled_in_dev(self) -> None:
with mock.patch.dict("os.environ", {"APP_ENV": "dev"}):
import importlib
import api.config
importlib.reload(api.config)
import api.app
importlib.reload(api.app)
app = api.app.app
client = TestClient(app, raise_server_exceptions=False)
# /docs should return 200 (Swagger UI) in non-production
assert client.get("/docs").status_code == 200
assert client.get("/openapi.json").status_code == 200

View file

@ -0,0 +1,81 @@
"""Unit tests for api/origin_validator.py."""
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.origin_validator import OriginValidatorMiddleware
ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com/"]
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
def _build_app() -> Starlette:
app = Starlette(routes=[
Route("/test", _ok_endpoint, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
])
app.add_middleware(OriginValidatorMiddleware, allowed_origins=ALLOWED_ORIGINS)
return app
class TestOriginValidator:
"""Tests for OriginValidatorMiddleware."""
def test_get_request_bypasses_check(self) -> None:
"""GET requests should not be checked for origin."""
client = TestClient(_build_app())
resp = client.get("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 200
def test_post_with_allowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://example.com"})
assert resp.status_code == 200
def test_post_with_allowed_origin_trailing_slash(self) -> None:
"""Trailing slash on origin should be normalized."""
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://app.example.com/"})
assert resp.status_code == 200
def test_post_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
assert resp.json() == {"detail": "Origin not allowed"}
def test_put_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.put("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_delete_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.delete("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_patch_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.patch("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_post_without_origin_passes(self) -> None:
"""Requests without an Origin header should pass through."""
client = TestClient(_build_app())
resp = client.post("/test")
assert resp.status_code == 200
def test_no_allowed_origins_configured(self) -> None:
"""If no allowed origins are given, all origins are rejected."""
app = Starlette(routes=[
Route("/test", _ok_endpoint, methods=["POST"]),
])
app.add_middleware(OriginValidatorMiddleware, allowed_origins=[])
client = TestClient(app)
resp = client.post("/test", headers={"Origin": "https://example.com"})
assert resp.status_code == 403

View file

@ -54,3 +54,9 @@ class TestSecurityHeaders:
client = TestClient(_build_app())
resp = client.get("/test")
assert "Strict-Transport-Security" not in resp.headers
def test_server_header_not_present(self) -> None:
"""The Server header should not leak server software info."""
client = TestClient(_build_app())
resp = client.get("/test")
assert "Server" not in resp.headers or "uvicorn" not in resp.headers.get("Server", "").lower()

View file

@ -0,0 +1,65 @@
"""Unit tests for the unhandled exception handler in api/app.py."""
from unittest import mock
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from starlette.requests import Request
from starlette.testclient import TestClient
def _build_app() -> FastAPI:
"""Build a minimal FastAPI app with the unhandled exception handler."""
app = FastAPI()
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
return JSONResponse(
status_code=500,
content={"detail": "Internal server error"},
)
@app.get("/ok")
async def ok() -> dict[str, str]:
return {"status": "ok"}
@app.get("/crash")
async def crash() -> None:
raise RuntimeError("something unexpected")
@app.get("/http-error")
async def http_error() -> None:
raise HTTPException(status_code=403, detail="forbidden")
return app
class TestUnhandledExceptionHandler:
"""Tests for the global exception handler."""
def test_normal_request_unaffected(self) -> None:
client = TestClient(_build_app())
resp = client.get("/ok")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
def test_unhandled_exception_returns_generic_500(self) -> None:
client = TestClient(_build_app(), raise_server_exceptions=False)
resp = client.get("/crash")
assert resp.status_code == 500
assert resp.json() == {"detail": "Internal server error"}
def test_http_exception_not_caught(self) -> None:
"""HTTPException should still be handled by FastAPI's built-in handler."""
client = TestClient(_build_app())
resp = client.get("/http-error")
assert resp.status_code == 403
assert resp.json() == {"detail": "forbidden"}
def test_error_body_does_not_leak_details(self) -> None:
"""The response must not contain the actual exception message."""
client = TestClient(_build_app(), raise_server_exceptions=False)
resp = client.get("/crash")
body = resp.text
assert "something unexpected" not in body
assert "RuntimeError" not in body