Add API anti-abuse hardening: disable docs in prod, origin validator, exception handler
- Disable OpenAPI docs/redoc/openapi.json when APP_ENV=production - Strip uvicorn Server header with --no-server-header in Dockerfile and docker-compose.yml - Add OriginValidatorMiddleware to reject state-changing requests from disallowed origins - Add global exception handler to prevent stack trace leakage on unhandled errors - Add tests for all new security features (OpenAPI, origin validation, exception handler, server header)
This commit is contained in:
parent
162d9a886d
commit
1ace45353a
8 changed files with 252 additions and 4 deletions
|
|
@ -42,4 +42,4 @@ ENV PATH="/app/.venv/bin:$PATH"
|
|||
COPY . .
|
||||
|
||||
EXPOSE 5001
|
||||
CMD ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001"]
|
||||
CMD ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --no-server-header"]
|
||||
|
|
|
|||
24
api/app.py
24
api/app.py
|
|
@ -13,9 +13,11 @@ from api.rate_limiter import RateLimitMiddleware
|
|||
from api.audit_middleware import AuditLogMiddleware
|
||||
from api.metrics_guard import MetricsGuardMiddleware
|
||||
from api.security_headers import SecurityHeadersMiddleware
|
||||
from api.origin_validator import OriginValidatorMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, FastAPI, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.requests import Request
|
||||
from api.auth import User
|
||||
from models.listing import QueryParameters, ListingType, FurnishType
|
||||
from notifications import send_notification
|
||||
|
|
@ -85,7 +87,11 @@ def get_query_parameters(
|
|||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app = FastAPI(
|
||||
docs_url=None if APP_ENV == "production" else "/docs",
|
||||
redoc_url=None if APP_ENV == "production" else "/redoc",
|
||||
openapi_url=None if APP_ENV == "production" else "/openapi.json",
|
||||
)
|
||||
app.include_router(passkey_router)
|
||||
app.include_router(poi_router)
|
||||
app.mount("/metrics", metrics_app)
|
||||
|
|
@ -108,6 +114,11 @@ app.add_middleware(
|
|||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
OriginValidatorMiddleware,
|
||||
allowed_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||
)
|
||||
|
||||
# Security middleware (added bottom-to-top; last added = outermost)
|
||||
# 3. Rate limiting — enforces per-user limits
|
||||
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
||||
|
|
@ -119,6 +130,15 @@ app.add_middleware(AuditLogMiddleware)
|
|||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
logger.exception("Unhandled exception")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status() -> dict[str, str]:
|
||||
request_counter.add(1, {"method": "GET", "path": "/status"})
|
||||
|
|
|
|||
34
api/origin_validator.py
Normal file
34
api/origin_validator.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""Origin validation middleware for state-changing requests."""
|
||||
import logging
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
STATE_CHANGING_METHODS = {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
|
||||
class OriginValidatorMiddleware(BaseHTTPMiddleware):
|
||||
"""Reject state-changing requests with mismatched Origin header."""
|
||||
|
||||
def __init__(self, app, allowed_origins: list[str] | None = None) -> None:
|
||||
super().__init__(app)
|
||||
self._allowed = {o.rstrip("/") for o in (allowed_origins or [])}
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if request.method not in STATE_CHANGING_METHODS:
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("origin")
|
||||
if origin is None:
|
||||
return await call_next(request)
|
||||
|
||||
if origin.rstrip("/") not in self._allowed:
|
||||
logger.warning(f"Rejected request from origin: {origin}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Origin not allowed"},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
|
@ -60,7 +60,7 @@ services:
|
|||
condition: service_healthy
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
command: ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-dir api --reload-dir services --reload-dir repositories --reload-dir models"]
|
||||
command: ["sh", "-c", "alembic upgrade head && uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-dir api --reload-dir services --reload-dir repositories --reload-dir models --no-server-header"]
|
||||
networks:
|
||||
- rec-network
|
||||
|
||||
|
|
|
|||
42
tests/unit/test_openapi_disabled.py
Normal file
42
tests/unit/test_openapi_disabled.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Unit tests for OpenAPI docs disabled in production."""
|
||||
from unittest import mock
|
||||
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
class TestOpenAPIDisabledInProduction:
|
||||
"""Verify docs/redoc/openapi are disabled when APP_ENV=production."""
|
||||
|
||||
def test_docs_disabled_in_production(self) -> None:
|
||||
prod_env = {
|
||||
"APP_ENV": "production",
|
||||
"OIDC_CLIENT_ID": "test-client-id",
|
||||
"JWT_SECRET": "test-secret-not-default",
|
||||
}
|
||||
with mock.patch.dict("os.environ", prod_env):
|
||||
# Re-import to pick up the new env
|
||||
import importlib
|
||||
import api.config
|
||||
importlib.reload(api.config)
|
||||
import api.app
|
||||
importlib.reload(api.app)
|
||||
app = api.app.app
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
assert client.get("/docs").status_code == 404
|
||||
assert client.get("/redoc").status_code == 404
|
||||
assert client.get("/openapi.json").status_code == 404
|
||||
|
||||
def test_docs_enabled_in_dev(self) -> None:
|
||||
with mock.patch.dict("os.environ", {"APP_ENV": "dev"}):
|
||||
import importlib
|
||||
import api.config
|
||||
importlib.reload(api.config)
|
||||
import api.app
|
||||
importlib.reload(api.app)
|
||||
app = api.app.app
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
# /docs should return 200 (Swagger UI) in non-production
|
||||
assert client.get("/docs").status_code == 200
|
||||
assert client.get("/openapi.json").status_code == 200
|
||||
81
tests/unit/test_origin_validator.py
Normal file
81
tests/unit/test_origin_validator.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""Unit tests for api/origin_validator.py."""
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from api.origin_validator import OriginValidatorMiddleware
|
||||
|
||||
|
||||
ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com/"]
|
||||
|
||||
|
||||
async def _ok_endpoint(request: Request) -> JSONResponse:
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
def _build_app() -> Starlette:
|
||||
app = Starlette(routes=[
|
||||
Route("/test", _ok_endpoint, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
|
||||
])
|
||||
app.add_middleware(OriginValidatorMiddleware, allowed_origins=ALLOWED_ORIGINS)
|
||||
return app
|
||||
|
||||
|
||||
class TestOriginValidator:
|
||||
"""Tests for OriginValidatorMiddleware."""
|
||||
|
||||
def test_get_request_bypasses_check(self) -> None:
|
||||
"""GET requests should not be checked for origin."""
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test", headers={"Origin": "https://evil.com"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_post_with_allowed_origin(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post("/test", headers={"Origin": "https://example.com"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_post_with_allowed_origin_trailing_slash(self) -> None:
|
||||
"""Trailing slash on origin should be normalized."""
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post("/test", headers={"Origin": "https://app.example.com/"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_post_with_disallowed_origin(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post("/test", headers={"Origin": "https://evil.com"})
|
||||
assert resp.status_code == 403
|
||||
assert resp.json() == {"detail": "Origin not allowed"}
|
||||
|
||||
def test_put_with_disallowed_origin(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.put("/test", headers={"Origin": "https://evil.com"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_delete_with_disallowed_origin(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.delete("/test", headers={"Origin": "https://evil.com"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_patch_with_disallowed_origin(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.patch("/test", headers={"Origin": "https://evil.com"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_post_without_origin_passes(self) -> None:
|
||||
"""Requests without an Origin header should pass through."""
|
||||
client = TestClient(_build_app())
|
||||
resp = client.post("/test")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_no_allowed_origins_configured(self) -> None:
|
||||
"""If no allowed origins are given, all origins are rejected."""
|
||||
app = Starlette(routes=[
|
||||
Route("/test", _ok_endpoint, methods=["POST"]),
|
||||
])
|
||||
app.add_middleware(OriginValidatorMiddleware, allowed_origins=[])
|
||||
client = TestClient(app)
|
||||
resp = client.post("/test", headers={"Origin": "https://example.com"})
|
||||
assert resp.status_code == 403
|
||||
|
|
@ -54,3 +54,9 @@ class TestSecurityHeaders:
|
|||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert "Strict-Transport-Security" not in resp.headers
|
||||
|
||||
def test_server_header_not_present(self) -> None:
|
||||
"""The Server header should not leak server software info."""
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/test")
|
||||
assert "Server" not in resp.headers or "uvicorn" not in resp.headers.get("Server", "").lower()
|
||||
|
|
|
|||
65
tests/unit/test_unhandled_exception.py
Normal file
65
tests/unit/test_unhandled_exception.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""Unit tests for the unhandled exception handler in api/app.py."""
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
|
||||
"""Build a minimal FastAPI app with the unhandled exception handler."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
|
||||
@app.get("/ok")
|
||||
async def ok() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/crash")
|
||||
async def crash() -> None:
|
||||
raise RuntimeError("something unexpected")
|
||||
|
||||
@app.get("/http-error")
|
||||
async def http_error() -> None:
|
||||
raise HTTPException(status_code=403, detail="forbidden")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class TestUnhandledExceptionHandler:
|
||||
"""Tests for the global exception handler."""
|
||||
|
||||
def test_normal_request_unaffected(self) -> None:
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/ok")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
|
||||
def test_unhandled_exception_returns_generic_500(self) -> None:
|
||||
client = TestClient(_build_app(), raise_server_exceptions=False)
|
||||
resp = client.get("/crash")
|
||||
assert resp.status_code == 500
|
||||
assert resp.json() == {"detail": "Internal server error"}
|
||||
|
||||
def test_http_exception_not_caught(self) -> None:
|
||||
"""HTTPException should still be handled by FastAPI's built-in handler."""
|
||||
client = TestClient(_build_app())
|
||||
resp = client.get("/http-error")
|
||||
assert resp.status_code == 403
|
||||
assert resp.json() == {"detail": "forbidden"}
|
||||
|
||||
def test_error_body_does_not_leak_details(self) -> None:
|
||||
"""The response must not contain the actual exception message."""
|
||||
client = TestClient(_build_app(), raise_server_exceptions=False)
|
||||
resp = client.get("/crash")
|
||||
body = resp.text
|
||||
assert "something unexpected" not in body
|
||||
assert "RuntimeError" not in body
|
||||
Loading…
Add table
Add a link
Reference in a new issue