From 1ace45353a8e8d6c552bec2914a33563c0747826 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sun, 8 Feb 2026 20:06:46 +0000 Subject: [PATCH] Add API anti-abuse hardening: disable docs in prod, origin validator, exception handler - Disable OpenAPI docs/redoc/openapi.json when APP_ENV=production - Strip uvicorn Server header with --no-server-header in Dockerfile and docker-compose.yml - Add OriginValidatorMiddleware to reject state-changing requests from disallowed origins - Add global exception handler to prevent stack trace leakage on unhandled errors - Add tests for all new security features (OpenAPI, origin validation, exception handler, server header) --- Dockerfile | 2 +- api/app.py | 24 +++++++- api/origin_validator.py | 34 +++++++++++ docker-compose.yml | 2 +- tests/unit/test_openapi_disabled.py | 42 +++++++++++++ tests/unit/test_origin_validator.py | 81 ++++++++++++++++++++++++++ tests/unit/test_security_headers.py | 6 ++ tests/unit/test_unhandled_exception.py | 65 +++++++++++++++++++++ 8 files changed, 252 insertions(+), 4 deletions(-) create mode 100644 api/origin_validator.py create mode 100644 tests/unit/test_openapi_disabled.py create mode 100644 tests/unit/test_origin_validator.py create mode 100644 tests/unit/test_unhandled_exception.py diff --git a/Dockerfile b/Dockerfile index 465d0b5..0a7490b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,4 +42,4 @@ ENV PATH="/app/.venv/bin:$PATH" COPY . . EXPOSE 5001 -CMD ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001"] +CMD ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --no-server-header"] diff --git a/api/app.py b/api/app.py index 921385f..00c2cce 100644 --- a/api/app.py +++ b/api/app.py @@ -13,9 +13,11 @@ from api.rate_limiter import RateLimitMiddleware from api.audit_middleware import AuditLogMiddleware from api.metrics_guard import MetricsGuardMiddleware from api.security_headers import SecurityHeadersMiddleware +from api.origin_validator import OriginValidatorMiddleware from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, Query -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse +from starlette.requests import Request from api.auth import User from models.listing import QueryParameters, ListingType, FurnishType from notifications import send_notification @@ -85,7 +87,11 @@ def get_query_parameters( ) -app = FastAPI() +app = FastAPI( + docs_url=None if APP_ENV == "production" else "/docs", + redoc_url=None if APP_ENV == "production" else "/redoc", + openapi_url=None if APP_ENV == "production" else "/openapi.json", +) app.include_router(passkey_router) app.include_router(poi_router) app.mount("/metrics", metrics_app) @@ -108,6 +114,11 @@ app.add_middleware( allow_headers=["Authorization", "Content-Type"], ) +app.add_middleware( + OriginValidatorMiddleware, + allowed_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS], +) + # Security middleware (added bottom-to-top; last added = outermost) # 3. Rate limiting — enforces per-user limits app.add_middleware(RateLimitMiddleware, config=_rate_limit_config) @@ -119,6 +130,15 @@ app.add_middleware(AuditLogMiddleware) app.add_middleware(SecurityHeadersMiddleware) +@app.exception_handler(Exception) +async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + logger.exception("Unhandled exception") + return JSONResponse( + status_code=500, + content={"detail": "Internal server error"}, + ) + + @app.get("/api/status") async def get_status() -> dict[str, str]: request_counter.add(1, {"method": "GET", "path": "/status"}) diff --git a/api/origin_validator.py b/api/origin_validator.py new file mode 100644 index 0000000..3dcd821 --- /dev/null +++ b/api/origin_validator.py @@ -0,0 +1,34 @@ +"""Origin validation middleware for state-changing requests.""" +import logging +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +logger = logging.getLogger("uvicorn") + +STATE_CHANGING_METHODS = {"POST", "PUT", "DELETE", "PATCH"} + + +class OriginValidatorMiddleware(BaseHTTPMiddleware): + """Reject state-changing requests with mismatched Origin header.""" + + def __init__(self, app, allowed_origins: list[str] | None = None) -> None: + super().__init__(app) + self._allowed = {o.rstrip("/") for o in (allowed_origins or [])} + + async def dispatch(self, request: Request, call_next) -> Response: + if request.method not in STATE_CHANGING_METHODS: + return await call_next(request) + + origin = request.headers.get("origin") + if origin is None: + return await call_next(request) + + if origin.rstrip("/") not in self._allowed: + logger.warning(f"Rejected request from origin: {origin}") + return JSONResponse( + status_code=403, + content={"detail": "Origin not allowed"}, + ) + + return await call_next(request) diff --git a/docker-compose.yml b/docker-compose.yml index 1cc1e16..1248c66 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -60,7 +60,7 @@ services: condition: service_healthy mysql: condition: service_healthy - command: ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-dir api --reload-dir services --reload-dir repositories --reload-dir models"] + command: ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-dir api --reload-dir services --reload-dir repositories --reload-dir models --no-server-header"] networks: - rec-network diff --git a/tests/unit/test_openapi_disabled.py b/tests/unit/test_openapi_disabled.py new file mode 100644 index 0000000..23b5d40 --- /dev/null +++ b/tests/unit/test_openapi_disabled.py @@ -0,0 +1,42 @@ +"""Unit tests for OpenAPI docs disabled in production.""" +from unittest import mock + +from starlette.testclient import TestClient + + +class TestOpenAPIDisabledInProduction: + """Verify docs/redoc/openapi are disabled when APP_ENV=production.""" + + def test_docs_disabled_in_production(self) -> None: + prod_env = { + "APP_ENV": "production", + "OIDC_CLIENT_ID": "test-client-id", + "JWT_SECRET": "test-secret-not-default", + } + with mock.patch.dict("os.environ", prod_env): + # Re-import to pick up the new env + import importlib + import api.config + importlib.reload(api.config) + import api.app + importlib.reload(api.app) + app = api.app.app + + client = TestClient(app, raise_server_exceptions=False) + assert client.get("/docs").status_code == 404 + assert client.get("/redoc").status_code == 404 + assert client.get("/openapi.json").status_code == 404 + + def test_docs_enabled_in_dev(self) -> None: + with mock.patch.dict("os.environ", {"APP_ENV": "dev"}): + import importlib + import api.config + importlib.reload(api.config) + import api.app + importlib.reload(api.app) + app = api.app.app + + client = TestClient(app, raise_server_exceptions=False) + # /docs should return 200 (Swagger UI) in non-production + assert client.get("/docs").status_code == 200 + assert client.get("/openapi.json").status_code == 200 diff --git a/tests/unit/test_origin_validator.py b/tests/unit/test_origin_validator.py new file mode 100644 index 0000000..da2b9ea --- /dev/null +++ b/tests/unit/test_origin_validator.py @@ -0,0 +1,81 @@ +"""Unit tests for api/origin_validator.py.""" +from starlette.testclient import TestClient +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from api.origin_validator import OriginValidatorMiddleware + + +ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com/"] + + +async def _ok_endpoint(request: Request) -> JSONResponse: + return JSONResponse({"ok": True}) + + +def _build_app() -> Starlette: + app = Starlette(routes=[ + Route("/test", _ok_endpoint, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]), + ]) + app.add_middleware(OriginValidatorMiddleware, allowed_origins=ALLOWED_ORIGINS) + return app + + +class TestOriginValidator: + """Tests for OriginValidatorMiddleware.""" + + def test_get_request_bypasses_check(self) -> None: + """GET requests should not be checked for origin.""" + client = TestClient(_build_app()) + resp = client.get("/test", headers={"Origin": "https://evil.com"}) + assert resp.status_code == 200 + + def test_post_with_allowed_origin(self) -> None: + client = TestClient(_build_app()) + resp = client.post("/test", headers={"Origin": "https://example.com"}) + assert resp.status_code == 200 + + def test_post_with_allowed_origin_trailing_slash(self) -> None: + """Trailing slash on origin should be normalized.""" + client = TestClient(_build_app()) + resp = client.post("/test", headers={"Origin": "https://app.example.com/"}) + assert resp.status_code == 200 + + def test_post_with_disallowed_origin(self) -> None: + client = TestClient(_build_app()) + resp = client.post("/test", headers={"Origin": "https://evil.com"}) + assert resp.status_code == 403 + assert resp.json() == {"detail": "Origin not allowed"} + + def test_put_with_disallowed_origin(self) -> None: + client = TestClient(_build_app()) + resp = client.put("/test", headers={"Origin": "https://evil.com"}) + assert resp.status_code == 403 + + def test_delete_with_disallowed_origin(self) -> None: + client = TestClient(_build_app()) + resp = client.delete("/test", headers={"Origin": "https://evil.com"}) + assert resp.status_code == 403 + + def test_patch_with_disallowed_origin(self) -> None: + client = TestClient(_build_app()) + resp = client.patch("/test", headers={"Origin": "https://evil.com"}) + assert resp.status_code == 403 + + def test_post_without_origin_passes(self) -> None: + """Requests without an Origin header should pass through.""" + client = TestClient(_build_app()) + resp = client.post("/test") + assert resp.status_code == 200 + + def test_no_allowed_origins_configured(self) -> None: + """If no allowed origins are given, all origins are rejected.""" + app = Starlette(routes=[ + Route("/test", _ok_endpoint, methods=["POST"]), + ]) + app.add_middleware(OriginValidatorMiddleware, allowed_origins=[]) + client = TestClient(app) + resp = client.post("/test", headers={"Origin": "https://example.com"}) + assert resp.status_code == 403 diff --git a/tests/unit/test_security_headers.py b/tests/unit/test_security_headers.py index 1e61423..33accaa 100644 --- a/tests/unit/test_security_headers.py +++ b/tests/unit/test_security_headers.py @@ -54,3 +54,9 @@ class TestSecurityHeaders: client = TestClient(_build_app()) resp = client.get("/test") assert "Strict-Transport-Security" not in resp.headers + + def test_server_header_not_present(self) -> None: + """The Server header should not leak server software info.""" + client = TestClient(_build_app()) + resp = client.get("/test") + assert "Server" not in resp.headers or "uvicorn" not in resp.headers.get("Server", "").lower() diff --git a/tests/unit/test_unhandled_exception.py b/tests/unit/test_unhandled_exception.py new file mode 100644 index 0000000..5d128c6 --- /dev/null +++ b/tests/unit/test_unhandled_exception.py @@ -0,0 +1,65 @@ +"""Unit tests for the unhandled exception handler in api/app.py.""" +from unittest import mock + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from starlette.requests import Request +from starlette.testclient import TestClient + + +def _build_app() -> FastAPI: + """Build a minimal FastAPI app with the unhandled exception handler.""" + app = FastAPI() + + @app.exception_handler(Exception) + async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=500, + content={"detail": "Internal server error"}, + ) + + @app.get("/ok") + async def ok() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/crash") + async def crash() -> None: + raise RuntimeError("something unexpected") + + @app.get("/http-error") + async def http_error() -> None: + raise HTTPException(status_code=403, detail="forbidden") + + return app + + +class TestUnhandledExceptionHandler: + """Tests for the global exception handler.""" + + def test_normal_request_unaffected(self) -> None: + client = TestClient(_build_app()) + resp = client.get("/ok") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + def test_unhandled_exception_returns_generic_500(self) -> None: + client = TestClient(_build_app(), raise_server_exceptions=False) + resp = client.get("/crash") + assert resp.status_code == 500 + assert resp.json() == {"detail": "Internal server error"} + + def test_http_exception_not_caught(self) -> None: + """HTTPException should still be handled by FastAPI's built-in handler.""" + client = TestClient(_build_app()) + resp = client.get("/http-error") + assert resp.status_code == 403 + assert resp.json() == {"detail": "forbidden"} + + def test_error_body_does_not_leak_details(self) -> None: + """The response must not contain the actual exception message.""" + client = TestClient(_build_app(), raise_server_exceptions=False) + resp = client.get("/crash") + body = resp.text + assert "something unexpected" not in body + assert "RuntimeError" not in body