Harden backend security: IDOR fix, error sanitization, rate limiter fallback, security headers
- Fix task status IDOR by adding ownership check; suppress traceback/error in production - Passkey routes: return generic error messages for internal exceptions, keep ValueError for user-facing - JWT_SECRET and OIDC_CLIENT_ID: raise RuntimeError in production when using defaults - Rate limiter: add in-memory fallback counter when Redis is unavailable - Fix X-Forwarded-For IP spoofing with trusted_proxy_depth (rightmost-N selection) - Add SecurityHeadersMiddleware (X-Content-Type-Options, X-Frame-Options, CSP, conditional HSTS) - CORS: add PUT/DELETE methods for POI routes - POI input validation: field length and coordinate range constraints - QueryParameters: add min_sqm <= max_sqm validation
This commit is contained in:
parent
e431eaf2aa
commit
0a9a83507e
8 changed files with 133 additions and 32 deletions
16
api/app.py
16
api/app.py
|
|
@ -5,15 +5,16 @@ import logging
|
|||
import logging.config
|
||||
from typing import Annotated, AsyncGenerator, Optional
|
||||
from api.auth import get_current_user
|
||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV
|
||||
from api.passkey_routes import passkey_router
|
||||
from api.poi_routes import poi_router
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
from api.rate_limiter import RateLimitMiddleware
|
||||
from api.audit_middleware import AuditLogMiddleware
|
||||
from api.metrics_guard import MetricsGuardMiddleware
|
||||
from api.security_headers import SecurityHeadersMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, FastAPI, Query
|
||||
from fastapi import Depends, FastAPI, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from api.auth import User
|
||||
from models.listing import QueryParameters, ListingType, FurnishType
|
||||
|
|
@ -103,7 +104,7 @@ hist = meter.create_histogram(
|
|||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
|
|
@ -114,6 +115,8 @@ app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
|||
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
||||
# 1. Audit logging — logs everything including 429s and 403s
|
||||
app.add_middleware(AuditLogMiddleware)
|
||||
# 0. Security headers — adds standard security headers to all responses
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
|
|
@ -324,6 +327,9 @@ async def get_task_status(
|
|||
task_id: str,
|
||||
) -> dict[str, str | int | float | None]:
|
||||
"""Get the status of a background task."""
|
||||
user_tasks = task_service.get_user_tasks(user.email)
|
||||
if task_id not in user_tasks:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
status = task_service.get_task_status(task_id)
|
||||
return {
|
||||
"task_id": status.task_id,
|
||||
|
|
@ -333,8 +339,8 @@ async def get_task_status(
|
|||
"processed": status.processed,
|
||||
"total": status.total,
|
||||
"message": status.message,
|
||||
"error": status.error,
|
||||
"traceback": status.traceback,
|
||||
"error": status.error if APP_ENV != "production" else None,
|
||||
"traceback": status.traceback if APP_ENV != "production" else None,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,15 @@ import os
|
|||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ENV = os.getenv("APP_ENV", "development")
|
||||
|
||||
# Authentik OIDC Configuration
|
||||
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
|
||||
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe")
|
||||
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "")
|
||||
if APP_ENV == "production" and not OIDC_CLIENT_ID:
|
||||
raise RuntimeError("OIDC_CLIENT_ID must be set in production")
|
||||
if not OIDC_CLIENT_ID:
|
||||
_logger.warning("OIDC_CLIENT_ID not set; OIDC login will not work")
|
||||
OIDC_METADATA_URL = (
|
||||
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
|
||||
)
|
||||
|
|
@ -27,6 +33,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
|
|||
# JWT Configuration (for passkey-issued tokens)
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
|
||||
if JWT_SECRET == "change-me-in-production":
|
||||
if APP_ENV == "production":
|
||||
raise RuntimeError("JWT_SECRET must be changed from default in production")
|
||||
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
|
||||
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
|
||||
|
|
|
|||
|
|
@ -44,9 +44,11 @@ async def register_begin(body: RegisterBeginRequest) -> RegisterBeginResponse:
|
|||
body.email, user_repo
|
||||
)
|
||||
return RegisterBeginResponse(options=options, session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Registration begin failed: {e}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
logger.exception("Registration begin failed")
|
||||
raise HTTPException(status_code=400, detail="Registration failed. Please try again.")
|
||||
|
||||
|
||||
@passkey_router.post("/register/complete", response_model=AuthTokenResponse)
|
||||
|
|
@ -60,9 +62,9 @@ async def register_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
|
|||
return AuthTokenResponse(token=token)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Registration complete failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
logger.exception("Registration complete failed")
|
||||
raise HTTPException(status_code=400, detail="Registration could not be completed.")
|
||||
|
||||
|
||||
@passkey_router.post("/login/begin", response_model=LoginBeginResponse)
|
||||
|
|
@ -72,9 +74,11 @@ async def login_begin() -> LoginBeginResponse:
|
|||
user_repo = UserRepository(engine)
|
||||
options, session_id = passkey_service.begin_authentication(user_repo)
|
||||
return LoginBeginResponse(options=options, session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Login begin failed: {e}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
logger.exception("Login begin failed")
|
||||
raise HTTPException(status_code=400, detail="Login initiation failed. Please try again.")
|
||||
|
||||
|
||||
@passkey_router.post("/login/complete", response_model=AuthTokenResponse)
|
||||
|
|
@ -88,6 +92,6 @@ async def login_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
|
|||
return AuthTokenResponse(token=token)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Login complete failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
logger.exception("Login complete failed")
|
||||
raise HTTPException(status_code=400, detail="Login could not be completed.")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.auth import User, get_current_user
|
||||
from database import engine
|
||||
|
|
@ -17,17 +17,17 @@ poi_router = APIRouter(prefix="/api/poi", tags=["poi"])
|
|||
|
||||
|
||||
class CreatePOIRequest(BaseModel):
|
||||
name: str
|
||||
address: str
|
||||
latitude: float
|
||||
longitude: float
|
||||
name: str = Field(max_length=200)
|
||||
address: str = Field(max_length=500)
|
||||
latitude: float = Field(ge=-90, le=90)
|
||||
longitude: float = Field(ge=-180, le=180)
|
||||
|
||||
|
||||
class UpdatePOIRequest(BaseModel):
|
||||
name: str | None = None
|
||||
address: str | None = None
|
||||
latitude: float | None = None
|
||||
longitude: float | None = None
|
||||
name: str | None = Field(default=None, max_length=200)
|
||||
address: str | None = Field(default=None, max_length=500)
|
||||
latitude: float | None = Field(default=None, ge=-90, le=90)
|
||||
longitude: float | None = Field(default=None, ge=-180, le=180)
|
||||
|
||||
|
||||
class POIResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ class RateLimitConfig:
|
|||
# Metrics endpoint IP allowlist (comma-separated CIDRs)
|
||||
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
|
||||
|
||||
# X-Forwarded-For trusted proxy depth
|
||||
trusted_proxy_depth: int = 1
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> Self:
|
||||
"""Load configuration from environment variables.
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
|
|
@ -54,21 +55,42 @@ def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
|
|||
return config.endpoint_limits.get(path)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
def _client_ip(request: Request, depth: int = 1) -> str:
|
||||
"""Best-effort client IP from X-Forwarded-For or connection."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
parts = [p.strip() for p in forwarded.split(",")]
|
||||
idx = max(0, len(parts) - depth)
|
||||
return parts[idx]
|
||||
client = request.client
|
||||
return client.host if client else "unknown"
|
||||
|
||||
|
||||
class _InMemoryCounter:
|
||||
"""Simple fixed-window counter for rate limiting when Redis is unavailable."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._windows: dict[str, tuple[int, float]] = {}
|
||||
|
||||
def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
|
||||
"""Returns (allowed, remaining). Increments counter."""
|
||||
now = time.monotonic()
|
||||
count, window_start = self._windows.get(key, (0, now))
|
||||
if now - window_start >= window_seconds:
|
||||
count, window_start = 0, now
|
||||
count += 1
|
||||
self._windows[key] = (count, window_start)
|
||||
remaining = max(0, max_requests - count)
|
||||
return count <= max_requests, remaining
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(app)
|
||||
self.config = config or RateLimitConfig.from_env()
|
||||
self._fallback = _InMemoryCounter()
|
||||
try:
|
||||
self._redis = _get_rate_limit_redis(self.config)
|
||||
self._redis.ping()
|
||||
|
|
@ -88,11 +110,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||
return await call_next(request)
|
||||
|
||||
# Determine identity for the counter key
|
||||
identity = _extract_user_email(request) or _client_ip(request)
|
||||
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
|
||||
|
||||
# If Redis is unavailable, fail open
|
||||
# If Redis is unavailable, use in-memory fallback
|
||||
if self._redis is None:
|
||||
return await call_next(request)
|
||||
fallback_key = f"ratelimit:{identity}:{path}"
|
||||
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
redis_key = f"ratelimit:{identity}:{path}"
|
||||
try:
|
||||
|
|
@ -128,5 +161,16 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||
return response
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Rate limiter Redis error, failing open: {e}")
|
||||
return await call_next(request)
|
||||
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
|
||||
fallback_key = f"ratelimit:{identity}:{path}"
|
||||
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
|
|
|||
28
api/security_headers.py
Normal file
28
api/security_headers.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Security headers middleware."""
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add standard security headers to every response."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
response = await call_next(request)
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"connect-src 'self' https://*.mapbox.com; "
|
||||
"img-src 'self' data: https://*.mapbox.com https://media.rightmove.co.uk; "
|
||||
"frame-ancestors 'none'"
|
||||
)
|
||||
# Only add HSTS when behind TLS-terminating proxy
|
||||
if request.headers.get("x-forwarded-proto") == "https":
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=63072000; includeSubDomains"
|
||||
)
|
||||
return response
|
||||
|
|
@ -248,4 +248,12 @@ class QueryParameters(BaseModel):
|
|||
raise ValueError(
|
||||
f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})"
|
||||
)
|
||||
if (
|
||||
self.min_sqm is not None
|
||||
and self.max_sqm is not None
|
||||
and self.min_sqm > self.max_sqm
|
||||
):
|
||||
raise ValueError(
|
||||
f"min_sqm ({self.min_sqm}) must be <= max_sqm ({self.max_sqm})"
|
||||
)
|
||||
return self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue