Harden backend security: IDOR fix, error sanitization, rate limiter fallback, security headers

- Fix task status IDOR by adding ownership check; suppress traceback/error in production
- Passkey routes: return generic error messages for internal exceptions, keep ValueError for user-facing
- JWT_SECRET and OIDC_CLIENT_ID: raise RuntimeError in production when using defaults
- Rate limiter: add in-memory fallback counter when Redis is unavailable
- Fix X-Forwarded-For IP spoofing with trusted_proxy_depth (rightmost-N selection)
- Add SecurityHeadersMiddleware (X-Content-Type-Options, X-Frame-Options, CSP, conditional HSTS)
- CORS: add PUT/DELETE methods for POI routes
- POI input validation: field length and coordinate range constraints
- QueryParameters: add min_sqm <= max_sqm validation
This commit is contained in:
Viktor Barzin 2026-02-08 19:42:30 +00:00
parent e431eaf2aa
commit 0a9a83507e
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 133 additions and 32 deletions

View file

@ -5,15 +5,16 @@ import logging
import logging.config
from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV
from api.passkey_routes import passkey_router
from api.poi_routes import poi_router
from api.rate_limit_config import RateLimitConfig
from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware
from api.metrics_guard import MetricsGuardMiddleware
from api.security_headers import SecurityHeadersMiddleware
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Query
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from api.auth import User
from models.listing import QueryParameters, ListingType, FurnishType
@ -103,7 +104,7 @@ hist = meter.create_histogram(
app.add_middleware(
CORSMiddleware,
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
allow_methods=["GET", "POST"],
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Authorization", "Content-Type"],
)
@ -114,6 +115,8 @@ app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
# 1. Audit logging — logs everything including 429s and 403s
app.add_middleware(AuditLogMiddleware)
# 0. Security headers — adds standard security headers to all responses
app.add_middleware(SecurityHeadersMiddleware)
@app.get("/api/status")
@ -324,6 +327,9 @@ async def get_task_status(
task_id: str,
) -> dict[str, str | int | float | None]:
"""Get the status of a background task."""
user_tasks = task_service.get_user_tasks(user.email)
if task_id not in user_tasks:
raise HTTPException(status_code=404, detail="Task not found")
status = task_service.get_task_status(task_id)
return {
"task_id": status.task_id,
@ -333,8 +339,8 @@ async def get_task_status(
"processed": status.processed,
"total": status.total,
"message": status.message,
"error": status.error,
"traceback": status.traceback,
"error": status.error if APP_ENV != "production" else None,
"traceback": status.traceback if APP_ENV != "production" else None,
}

View file

@ -5,9 +5,15 @@ import os
_logger = logging.getLogger(__name__)
APP_ENV = os.getenv("APP_ENV", "development")
# Authentik OIDC Configuration
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe")
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "")
if APP_ENV == "production" and not OIDC_CLIENT_ID:
raise RuntimeError("OIDC_CLIENT_ID must be set in production")
if not OIDC_CLIENT_ID:
_logger.warning("OIDC_CLIENT_ID not set; OIDC login will not work")
OIDC_METADATA_URL = (
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
)
@ -27,6 +33,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
# JWT Configuration (for passkey-issued tokens)
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
if JWT_SECRET == "change-me-in-production":
if APP_ENV == "production":
raise RuntimeError("JWT_SECRET must be changed from default in production")
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))

View file

@ -44,9 +44,11 @@ async def register_begin(body: RegisterBeginRequest) -> RegisterBeginResponse:
body.email, user_repo
)
return RegisterBeginResponse(options=options, session_id=session_id)
except Exception as e:
logger.error(f"Registration begin failed: {e}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Registration begin failed")
raise HTTPException(status_code=400, detail="Registration failed. Please try again.")
@passkey_router.post("/register/complete", response_model=AuthTokenResponse)
@ -60,9 +62,9 @@ async def register_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
return AuthTokenResponse(token=token)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Registration complete failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Registration complete failed")
raise HTTPException(status_code=400, detail="Registration could not be completed.")
@passkey_router.post("/login/begin", response_model=LoginBeginResponse)
@ -72,9 +74,11 @@ async def login_begin() -> LoginBeginResponse:
user_repo = UserRepository(engine)
options, session_id = passkey_service.begin_authentication(user_repo)
return LoginBeginResponse(options=options, session_id=session_id)
except Exception as e:
logger.error(f"Login begin failed: {e}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Login begin failed")
raise HTTPException(status_code=400, detail="Login initiation failed. Please try again.")
@passkey_router.post("/login/complete", response_model=AuthTokenResponse)
@ -88,6 +92,6 @@ async def login_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
return AuthTokenResponse(token=token)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Login complete failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception:
logger.exception("Login complete failed")
raise HTTPException(status_code=400, detail="Login could not be completed.")

View file

@ -2,7 +2,7 @@ import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from pydantic import BaseModel, Field
from api.auth import User, get_current_user
from database import engine
@ -17,17 +17,17 @@ poi_router = APIRouter(prefix="/api/poi", tags=["poi"])
class CreatePOIRequest(BaseModel):
name: str
address: str
latitude: float
longitude: float
name: str = Field(max_length=200)
address: str = Field(max_length=500)
latitude: float = Field(ge=-90, le=90)
longitude: float = Field(ge=-180, le=180)
class UpdatePOIRequest(BaseModel):
name: str | None = None
address: str | None = None
latitude: float | None = None
longitude: float | None = None
name: str | None = Field(default=None, max_length=200)
address: str | None = Field(default=None, max_length=500)
latitude: float | None = Field(default=None, ge=-90, le=90)
longitude: float | None = Field(default=None, ge=-180, le=180)
class POIResponse(BaseModel):

View file

@ -47,6 +47,9 @@ class RateLimitConfig:
# Metrics endpoint IP allowlist (comma-separated CIDRs)
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
# X-Forwarded-For trusted proxy depth
trusted_proxy_depth: int = 1
@classmethod
def from_env(cls) -> Self:
"""Load configuration from environment variables.

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import os
import time
from urllib.parse import urlparse, urlunparse
import jwt
@ -54,21 +55,42 @@ def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
return config.endpoint_limits.get(path)
def _client_ip(request: Request) -> str:
def _client_ip(request: Request, depth: int = 1) -> str:
"""Best-effort client IP from X-Forwarded-For or connection."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
parts = [p.strip() for p in forwarded.split(",")]
idx = max(0, len(parts) - depth)
return parts[idx]
client = request.client
return client.host if client else "unknown"
class _InMemoryCounter:
"""Simple fixed-window counter for rate limiting when Redis is unavailable."""
def __init__(self) -> None:
self._windows: dict[str, tuple[int, float]] = {}
def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
"""Returns (allowed, remaining). Increments counter."""
now = time.monotonic()
count, window_start = self._windows.get(key, (0, now))
if now - window_start >= window_seconds:
count, window_start = 0, now
count += 1
self._windows[key] = (count, window_start)
remaining = max(0, max_requests - count)
return count <= max_requests, remaining
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
super().__init__(app)
self.config = config or RateLimitConfig.from_env()
self._fallback = _InMemoryCounter()
try:
self._redis = _get_rate_limit_redis(self.config)
self._redis.ping()
@ -88,11 +110,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# Determine identity for the counter key
identity = _extract_user_email(request) or _client_ip(request)
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
# If Redis is unavailable, fail open
# If Redis is unavailable, use in-memory fallback
if self._redis is None:
return await call_next(request)
fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
redis_key = f"ratelimit:{identity}:{path}"
try:
@ -128,5 +161,16 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
return response
except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, failing open: {e}")
return await call_next(request)
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response

28
api/security_headers.py Normal file
View file

@ -0,0 +1,28 @@
"""Security headers middleware."""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Add standard security headers to every response."""
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self'; "
"style-src 'self' 'unsafe-inline'; "
"connect-src 'self' https://*.mapbox.com; "
"img-src 'self' data: https://*.mapbox.com https://media.rightmove.co.uk; "
"frame-ancestors 'none'"
)
# Only add HSTS when behind TLS-terminating proxy
if request.headers.get("x-forwarded-proto") == "https":
response.headers["Strict-Transport-Security"] = (
"max-age=63072000; includeSubDomains"
)
return response

View file

@ -248,4 +248,12 @@ class QueryParameters(BaseModel):
raise ValueError(
f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})"
)
if (
self.min_sqm is not None
and self.max_sqm is not None
and self.min_sqm > self.max_sqm
):
raise ValueError(
f"min_sqm ({self.min_sqm}) must be <= max_sqm ({self.max_sqm})"
)
return self