Harden backend security: IDOR fix, error sanitization, rate limiter fallback, security headers

- Fix task status IDOR by adding ownership check; suppress traceback/error in production
- Passkey routes: return generic error messages for internal exceptions, keep ValueError for user-facing
- JWT_SECRET and OIDC_CLIENT_ID: raise RuntimeError in production when using defaults
- Rate limiter: add in-memory fallback counter when Redis is unavailable
- Fix X-Forwarded-For IP spoofing with trusted_proxy_depth (rightmost-N selection)
- Add SecurityHeadersMiddleware (X-Content-Type-Options, X-Frame-Options, CSP, conditional HSTS)
- CORS: add PUT/DELETE methods for POI routes
- POI input validation: field length and coordinate range constraints
- QueryParameters: add min_sqm <= max_sqm validation
This commit is contained in:
Viktor Barzin 2026-02-08 19:42:30 +00:00
parent e431eaf2aa
commit 0a9a83507e
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 133 additions and 32 deletions

View file

@ -5,15 +5,16 @@ import logging
import logging.config import logging.config
from typing import Annotated, AsyncGenerator, Optional from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV
from api.passkey_routes import passkey_router from api.passkey_routes import passkey_router
from api.poi_routes import poi_router from api.poi_routes import poi_router
from api.rate_limit_config import RateLimitConfig from api.rate_limit_config import RateLimitConfig
from api.rate_limiter import RateLimitMiddleware from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware from api.audit_middleware import AuditLogMiddleware
from api.metrics_guard import MetricsGuardMiddleware from api.metrics_guard import MetricsGuardMiddleware
from api.security_headers import SecurityHeadersMiddleware
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Query from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from api.auth import User from api.auth import User
from models.listing import QueryParameters, ListingType, FurnishType from models.listing import QueryParameters, ListingType, FurnishType
@ -103,7 +104,7 @@ hist = meter.create_histogram(
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS], allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
allow_methods=["GET", "POST"], allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Authorization", "Content-Type"], allow_headers=["Authorization", "Content-Type"],
) )
@ -114,6 +115,8 @@ app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config) app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
# 1. Audit logging — logs everything including 429s and 403s # 1. Audit logging — logs everything including 429s and 403s
app.add_middleware(AuditLogMiddleware) app.add_middleware(AuditLogMiddleware)
# 0. Security headers — adds standard security headers to all responses
app.add_middleware(SecurityHeadersMiddleware)
@app.get("/api/status") @app.get("/api/status")
@ -324,6 +327,9 @@ async def get_task_status(
task_id: str, task_id: str,
) -> dict[str, str | int | float | None]: ) -> dict[str, str | int | float | None]:
"""Get the status of a background task.""" """Get the status of a background task."""
user_tasks = task_service.get_user_tasks(user.email)
if task_id not in user_tasks:
raise HTTPException(status_code=404, detail="Task not found")
status = task_service.get_task_status(task_id) status = task_service.get_task_status(task_id)
return { return {
"task_id": status.task_id, "task_id": status.task_id,
@ -333,8 +339,8 @@ async def get_task_status(
"processed": status.processed, "processed": status.processed,
"total": status.total, "total": status.total,
"message": status.message, "message": status.message,
"error": status.error, "error": status.error if APP_ENV != "production" else None,
"traceback": status.traceback, "traceback": status.traceback if APP_ENV != "production" else None,
} }

View file

@ -5,9 +5,15 @@ import os
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
APP_ENV = os.getenv("APP_ENV", "development")
# Authentik OIDC Configuration # Authentik OIDC Configuration
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me") AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe") OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "")
if APP_ENV == "production" and not OIDC_CLIENT_ID:
raise RuntimeError("OIDC_CLIENT_ID must be set in production")
if not OIDC_CLIENT_ID:
_logger.warning("OIDC_CLIENT_ID not set; OIDC login will not work")
OIDC_METADATA_URL = ( OIDC_METADATA_URL = (
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration" f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
) )
@ -27,6 +33,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
# JWT Configuration (for passkey-issued tokens) # JWT Configuration (for passkey-issued tokens)
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production") JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
if JWT_SECRET == "change-me-in-production": if JWT_SECRET == "change-me-in-production":
if APP_ENV == "production":
raise RuntimeError("JWT_SECRET must be changed from default in production")
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.") _logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256") JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24")) JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))

View file

@ -44,9 +44,11 @@ async def register_begin(body: RegisterBeginRequest) -> RegisterBeginResponse:
body.email, user_repo body.email, user_repo
) )
return RegisterBeginResponse(options=options, session_id=session_id) return RegisterBeginResponse(options=options, session_id=session_id)
except Exception as e: except ValueError as e:
logger.error(f"Registration begin failed: {e}")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Registration begin failed")
raise HTTPException(status_code=400, detail="Registration failed. Please try again.")
@passkey_router.post("/register/complete", response_model=AuthTokenResponse) @passkey_router.post("/register/complete", response_model=AuthTokenResponse)
@ -60,9 +62,9 @@ async def register_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
return AuthTokenResponse(token=token) return AuthTokenResponse(token=token)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception:
logger.error(f"Registration complete failed: {e}") logger.exception("Registration complete failed")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail="Registration could not be completed.")
@passkey_router.post("/login/begin", response_model=LoginBeginResponse) @passkey_router.post("/login/begin", response_model=LoginBeginResponse)
@ -72,9 +74,11 @@ async def login_begin() -> LoginBeginResponse:
user_repo = UserRepository(engine) user_repo = UserRepository(engine)
options, session_id = passkey_service.begin_authentication(user_repo) options, session_id = passkey_service.begin_authentication(user_repo)
return LoginBeginResponse(options=options, session_id=session_id) return LoginBeginResponse(options=options, session_id=session_id)
except Exception as e: except ValueError as e:
logger.error(f"Login begin failed: {e}")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Login begin failed")
raise HTTPException(status_code=400, detail="Login initiation failed. Please try again.")
@passkey_router.post("/login/complete", response_model=AuthTokenResponse) @passkey_router.post("/login/complete", response_model=AuthTokenResponse)
@ -88,6 +92,6 @@ async def login_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
return AuthTokenResponse(token=token) return AuthTokenResponse(token=token)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception:
logger.error(f"Login complete failed: {e}") logger.exception("Login complete failed")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail="Login could not be completed.")

View file

@ -2,7 +2,7 @@ import logging
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel, Field
from api.auth import User, get_current_user from api.auth import User, get_current_user
from database import engine from database import engine
@ -17,17 +17,17 @@ poi_router = APIRouter(prefix="/api/poi", tags=["poi"])
class CreatePOIRequest(BaseModel): class CreatePOIRequest(BaseModel):
name: str name: str = Field(max_length=200)
address: str address: str = Field(max_length=500)
latitude: float latitude: float = Field(ge=-90, le=90)
longitude: float longitude: float = Field(ge=-180, le=180)
class UpdatePOIRequest(BaseModel): class UpdatePOIRequest(BaseModel):
name: str | None = None name: str | None = Field(default=None, max_length=200)
address: str | None = None address: str | None = Field(default=None, max_length=500)
latitude: float | None = None latitude: float | None = Field(default=None, ge=-90, le=90)
longitude: float | None = None longitude: float | None = Field(default=None, ge=-180, le=180)
class POIResponse(BaseModel): class POIResponse(BaseModel):

View file

@ -47,6 +47,9 @@ class RateLimitConfig:
# Metrics endpoint IP allowlist (comma-separated CIDRs) # Metrics endpoint IP allowlist (comma-separated CIDRs)
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1" metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
# X-Forwarded-For trusted proxy depth
trusted_proxy_depth: int = 1
@classmethod @classmethod
def from_env(cls) -> Self: def from_env(cls) -> Self:
"""Load configuration from environment variables. """Load configuration from environment variables.

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
import os import os
import time
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
import jwt import jwt
@ -54,21 +55,42 @@ def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
return config.endpoint_limits.get(path) return config.endpoint_limits.get(path)
def _client_ip(request: Request) -> str: def _client_ip(request: Request, depth: int = 1) -> str:
"""Best-effort client IP from X-Forwarded-For or connection.""" """Best-effort client IP from X-Forwarded-For or connection."""
forwarded = request.headers.get("x-forwarded-for") forwarded = request.headers.get("x-forwarded-for")
if forwarded: if forwarded:
return forwarded.split(",")[0].strip() parts = [p.strip() for p in forwarded.split(",")]
idx = max(0, len(parts) - depth)
return parts[idx]
client = request.client client = request.client
return client.host if client else "unknown" return client.host if client else "unknown"
class _InMemoryCounter:
"""Simple fixed-window counter for rate limiting when Redis is unavailable."""
def __init__(self) -> None:
self._windows: dict[str, tuple[int, float]] = {}
def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
"""Returns (allowed, remaining). Increments counter."""
now = time.monotonic()
count, window_start = self._windows.get(key, (0, now))
if now - window_start >= window_seconds:
count, window_start = 0, now
count += 1
self._windows[key] = (count, window_start)
remaining = max(0, max_requests - count)
return count <= max_requests, remaining
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis.""" """Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
super().__init__(app) super().__init__(app)
self.config = config or RateLimitConfig.from_env() self.config = config or RateLimitConfig.from_env()
self._fallback = _InMemoryCounter()
try: try:
self._redis = _get_rate_limit_redis(self.config) self._redis = _get_rate_limit_redis(self.config)
self._redis.ping() self._redis.ping()
@ -88,11 +110,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return await call_next(request) return await call_next(request)
# Determine identity for the counter key # Determine identity for the counter key
identity = _extract_user_email(request) or _client_ip(request) identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
# If Redis is unavailable, fail open # If Redis is unavailable, use in-memory fallback
if self._redis is None: if self._redis is None:
return await call_next(request) fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
redis_key = f"ratelimit:{identity}:{path}" redis_key = f"ratelimit:{identity}:{path}"
try: try:
@ -128,5 +161,16 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return response return response
except redis.RedisError as e: except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, failing open: {e}") logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
return await call_next(request) fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response

28
api/security_headers.py Normal file
View file

@ -0,0 +1,28 @@
"""Security headers middleware."""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Add standard security headers to every response."""
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self'; "
"style-src 'self' 'unsafe-inline'; "
"connect-src 'self' https://*.mapbox.com; "
"img-src 'self' data: https://*.mapbox.com https://media.rightmove.co.uk; "
"frame-ancestors 'none'"
)
# Only add HSTS when behind TLS-terminating proxy
if request.headers.get("x-forwarded-proto") == "https":
response.headers["Strict-Transport-Security"] = (
"max-age=63072000; includeSubDomains"
)
return response

View file

@ -248,4 +248,12 @@ class QueryParameters(BaseModel):
raise ValueError( raise ValueError(
f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})" f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})"
) )
if (
self.min_sqm is not None
and self.max_sqm is not None
and self.min_sqm > self.max_sqm
):
raise ValueError(
f"min_sqm ({self.min_sqm}) must be <= max_sqm ({self.max_sqm})"
)
return self return self