Add API anti-abuse hardening: disable docs in prod, origin validator, exception handler

- Disable OpenAPI docs/redoc/openapi.json when APP_ENV=production
- Strip uvicorn Server header with --no-server-header in Dockerfile and docker-compose.yml
- Add OriginValidatorMiddleware to reject state-changing requests from disallowed origins
- Add global exception handler to prevent stack trace leakage on unhandled errors
- Add tests for all new security features (OpenAPI, origin validation, exception handler, server header)
This commit is contained in:
Viktor Barzin 2026-02-08 20:06:46 +00:00
parent 162d9a886d
commit 1ace45353a
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 252 additions and 4 deletions

View file

@ -0,0 +1,81 @@
"""Unit tests for api/origin_validator.py."""
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.origin_validator import OriginValidatorMiddleware
ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com/"]
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
def _build_app() -> Starlette:
app = Starlette(routes=[
Route("/test", _ok_endpoint, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
])
app.add_middleware(OriginValidatorMiddleware, allowed_origins=ALLOWED_ORIGINS)
return app
class TestOriginValidator:
"""Tests for OriginValidatorMiddleware."""
def test_get_request_bypasses_check(self) -> None:
"""GET requests should not be checked for origin."""
client = TestClient(_build_app())
resp = client.get("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 200
def test_post_with_allowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://example.com"})
assert resp.status_code == 200
def test_post_with_allowed_origin_trailing_slash(self) -> None:
"""Trailing slash on origin should be normalized."""
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://app.example.com/"})
assert resp.status_code == 200
def test_post_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.post("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
assert resp.json() == {"detail": "Origin not allowed"}
def test_put_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.put("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_delete_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.delete("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_patch_with_disallowed_origin(self) -> None:
client = TestClient(_build_app())
resp = client.patch("/test", headers={"Origin": "https://evil.com"})
assert resp.status_code == 403
def test_post_without_origin_passes(self) -> None:
"""Requests without an Origin header should pass through."""
client = TestClient(_build_app())
resp = client.post("/test")
assert resp.status_code == 200
def test_no_allowed_origins_configured(self) -> None:
"""If no allowed origins are given, all origins are rejected."""
app = Starlette(routes=[
Route("/test", _ok_endpoint, methods=["POST"]),
])
app.add_middleware(OriginValidatorMiddleware, allowed_origins=[])
client = TestClient(app)
resp = client.post("/test", headers={"Origin": "https://example.com"})
assert resp.status_code == 403