Harden backend security: IDOR fix, error sanitization, rate limiter fallback, security headers
- Fix task status IDOR by adding ownership check; suppress traceback/error in production - Passkey routes: return generic error messages for internal exceptions, keep ValueError for user-facing - JWT_SECRET and OIDC_CLIENT_ID: raise RuntimeError in production when using defaults - Rate limiter: add in-memory fallback counter when Redis is unavailable - Fix X-Forwarded-For IP spoofing with trusted_proxy_depth (rightmost-N selection) - Add SecurityHeadersMiddleware (X-Content-Type-Options, X-Frame-Options, CSP, conditional HSTS) - CORS: add PUT/DELETE methods for POI routes - POI input validation: field length and coordinate range constraints - QueryParameters: add min_sqm <= max_sqm validation
This commit is contained in:
parent
e431eaf2aa
commit
0a9a83507e
8 changed files with 133 additions and 32 deletions
16
api/app.py
16
api/app.py
|
|
@ -5,15 +5,16 @@ import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
from typing import Annotated, AsyncGenerator, Optional
|
from typing import Annotated, AsyncGenerator, Optional
|
||||||
from api.auth import get_current_user
|
from api.auth import get_current_user
|
||||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV
|
||||||
from api.passkey_routes import passkey_router
|
from api.passkey_routes import passkey_router
|
||||||
from api.poi_routes import poi_router
|
from api.poi_routes import poi_router
|
||||||
from api.rate_limit_config import RateLimitConfig
|
from api.rate_limit_config import RateLimitConfig
|
||||||
from api.rate_limiter import RateLimitMiddleware
|
from api.rate_limiter import RateLimitMiddleware
|
||||||
from api.audit_middleware import AuditLogMiddleware
|
from api.audit_middleware import AuditLogMiddleware
|
||||||
from api.metrics_guard import MetricsGuardMiddleware
|
from api.metrics_guard import MetricsGuardMiddleware
|
||||||
|
from api.security_headers import SecurityHeadersMiddleware
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import Depends, FastAPI, Query
|
from fastapi import Depends, FastAPI, HTTPException, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from api.auth import User
|
from api.auth import User
|
||||||
from models.listing import QueryParameters, ListingType, FurnishType
|
from models.listing import QueryParameters, ListingType, FurnishType
|
||||||
|
|
@ -103,7 +104,7 @@ hist = meter.create_histogram(
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||||
allow_methods=["GET", "POST"],
|
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
allow_headers=["Authorization", "Content-Type"],
|
allow_headers=["Authorization", "Content-Type"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -114,6 +115,8 @@ app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
||||||
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
||||||
# 1. Audit logging — logs everything including 429s and 403s
|
# 1. Audit logging — logs everything including 429s and 403s
|
||||||
app.add_middleware(AuditLogMiddleware)
|
app.add_middleware(AuditLogMiddleware)
|
||||||
|
# 0. Security headers — adds standard security headers to all responses
|
||||||
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/status")
|
@app.get("/api/status")
|
||||||
|
|
@ -324,6 +327,9 @@ async def get_task_status(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
) -> dict[str, str | int | float | None]:
|
) -> dict[str, str | int | float | None]:
|
||||||
"""Get the status of a background task."""
|
"""Get the status of a background task."""
|
||||||
|
user_tasks = task_service.get_user_tasks(user.email)
|
||||||
|
if task_id not in user_tasks:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
status = task_service.get_task_status(task_id)
|
status = task_service.get_task_status(task_id)
|
||||||
return {
|
return {
|
||||||
"task_id": status.task_id,
|
"task_id": status.task_id,
|
||||||
|
|
@ -333,8 +339,8 @@ async def get_task_status(
|
||||||
"processed": status.processed,
|
"processed": status.processed,
|
||||||
"total": status.total,
|
"total": status.total,
|
||||||
"message": status.message,
|
"message": status.message,
|
||||||
"error": status.error,
|
"error": status.error if APP_ENV != "production" else None,
|
||||||
"traceback": status.traceback,
|
"traceback": status.traceback if APP_ENV != "production" else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,15 @@ import os
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
APP_ENV = os.getenv("APP_ENV", "development")
|
||||||
|
|
||||||
# Authentik OIDC Configuration
|
# Authentik OIDC Configuration
|
||||||
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
|
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
|
||||||
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe")
|
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "")
|
||||||
|
if APP_ENV == "production" and not OIDC_CLIENT_ID:
|
||||||
|
raise RuntimeError("OIDC_CLIENT_ID must be set in production")
|
||||||
|
if not OIDC_CLIENT_ID:
|
||||||
|
_logger.warning("OIDC_CLIENT_ID not set; OIDC login will not work")
|
||||||
OIDC_METADATA_URL = (
|
OIDC_METADATA_URL = (
|
||||||
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
|
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
|
||||||
)
|
)
|
||||||
|
|
@ -27,6 +33,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
|
||||||
# JWT Configuration (for passkey-issued tokens)
|
# JWT Configuration (for passkey-issued tokens)
|
||||||
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
|
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
|
||||||
if JWT_SECRET == "change-me-in-production":
|
if JWT_SECRET == "change-me-in-production":
|
||||||
|
if APP_ENV == "production":
|
||||||
|
raise RuntimeError("JWT_SECRET must be changed from default in production")
|
||||||
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
|
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
|
||||||
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
|
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
|
||||||
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
|
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,11 @@ async def register_begin(body: RegisterBeginRequest) -> RegisterBeginResponse:
|
||||||
body.email, user_repo
|
body.email, user_repo
|
||||||
)
|
)
|
||||||
return RegisterBeginResponse(options=options, session_id=session_id)
|
return RegisterBeginResponse(options=options, session_id=session_id)
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
logger.error(f"Registration begin failed: {e}")
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Registration begin failed")
|
||||||
|
raise HTTPException(status_code=400, detail="Registration failed. Please try again.")
|
||||||
|
|
||||||
|
|
||||||
@passkey_router.post("/register/complete", response_model=AuthTokenResponse)
|
@passkey_router.post("/register/complete", response_model=AuthTokenResponse)
|
||||||
|
|
@ -60,9 +62,9 @@ async def register_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
|
||||||
return AuthTokenResponse(token=token)
|
return AuthTokenResponse(token=token)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Registration complete failed: {e}")
|
logger.exception("Registration complete failed")
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail="Registration could not be completed.")
|
||||||
|
|
||||||
|
|
||||||
@passkey_router.post("/login/begin", response_model=LoginBeginResponse)
|
@passkey_router.post("/login/begin", response_model=LoginBeginResponse)
|
||||||
|
|
@ -72,9 +74,11 @@ async def login_begin() -> LoginBeginResponse:
|
||||||
user_repo = UserRepository(engine)
|
user_repo = UserRepository(engine)
|
||||||
options, session_id = passkey_service.begin_authentication(user_repo)
|
options, session_id = passkey_service.begin_authentication(user_repo)
|
||||||
return LoginBeginResponse(options=options, session_id=session_id)
|
return LoginBeginResponse(options=options, session_id=session_id)
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
logger.error(f"Login begin failed: {e}")
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Login begin failed")
|
||||||
|
raise HTTPException(status_code=400, detail="Login initiation failed. Please try again.")
|
||||||
|
|
||||||
|
|
||||||
@passkey_router.post("/login/complete", response_model=AuthTokenResponse)
|
@passkey_router.post("/login/complete", response_model=AuthTokenResponse)
|
||||||
|
|
@ -88,6 +92,6 @@ async def login_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse:
|
||||||
return AuthTokenResponse(token=token)
|
return AuthTokenResponse(token=token)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Login complete failed: {e}")
|
logger.exception("Login complete failed")
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail="Login could not be completed.")
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from api.auth import User, get_current_user
|
from api.auth import User, get_current_user
|
||||||
from database import engine
|
from database import engine
|
||||||
|
|
@ -17,17 +17,17 @@ poi_router = APIRouter(prefix="/api/poi", tags=["poi"])
|
||||||
|
|
||||||
|
|
||||||
class CreatePOIRequest(BaseModel):
|
class CreatePOIRequest(BaseModel):
|
||||||
name: str
|
name: str = Field(max_length=200)
|
||||||
address: str
|
address: str = Field(max_length=500)
|
||||||
latitude: float
|
latitude: float = Field(ge=-90, le=90)
|
||||||
longitude: float
|
longitude: float = Field(ge=-180, le=180)
|
||||||
|
|
||||||
|
|
||||||
class UpdatePOIRequest(BaseModel):
|
class UpdatePOIRequest(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = Field(default=None, max_length=200)
|
||||||
address: str | None = None
|
address: str | None = Field(default=None, max_length=500)
|
||||||
latitude: float | None = None
|
latitude: float | None = Field(default=None, ge=-90, le=90)
|
||||||
longitude: float | None = None
|
longitude: float | None = Field(default=None, ge=-180, le=180)
|
||||||
|
|
||||||
|
|
||||||
class POIResponse(BaseModel):
|
class POIResponse(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,9 @@ class RateLimitConfig:
|
||||||
# Metrics endpoint IP allowlist (comma-separated CIDRs)
|
# Metrics endpoint IP allowlist (comma-separated CIDRs)
|
||||||
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
|
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
|
||||||
|
|
||||||
|
# X-Forwarded-For trusted proxy depth
|
||||||
|
trusted_proxy_depth: int = 1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_env(cls) -> Self:
|
def from_env(cls) -> Self:
|
||||||
"""Load configuration from environment variables.
|
"""Load configuration from environment variables.
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
@ -54,21 +55,42 @@ def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
|
||||||
return config.endpoint_limits.get(path)
|
return config.endpoint_limits.get(path)
|
||||||
|
|
||||||
|
|
||||||
def _client_ip(request: Request) -> str:
|
def _client_ip(request: Request, depth: int = 1) -> str:
|
||||||
"""Best-effort client IP from X-Forwarded-For or connection."""
|
"""Best-effort client IP from X-Forwarded-For or connection."""
|
||||||
forwarded = request.headers.get("x-forwarded-for")
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
if forwarded:
|
if forwarded:
|
||||||
return forwarded.split(",")[0].strip()
|
parts = [p.strip() for p in forwarded.split(",")]
|
||||||
|
idx = max(0, len(parts) - depth)
|
||||||
|
return parts[idx]
|
||||||
client = request.client
|
client = request.client
|
||||||
return client.host if client else "unknown"
|
return client.host if client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class _InMemoryCounter:
|
||||||
|
"""Simple fixed-window counter for rate limiting when Redis is unavailable."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._windows: dict[str, tuple[int, float]] = {}
|
||||||
|
|
||||||
|
def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
|
||||||
|
"""Returns (allowed, remaining). Increments counter."""
|
||||||
|
now = time.monotonic()
|
||||||
|
count, window_start = self._windows.get(key, (0, now))
|
||||||
|
if now - window_start >= window_seconds:
|
||||||
|
count, window_start = 0, now
|
||||||
|
count += 1
|
||||||
|
self._windows[key] = (count, window_start)
|
||||||
|
remaining = max(0, max_requests - count)
|
||||||
|
return count <= max_requests, remaining
|
||||||
|
|
||||||
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||||
|
|
||||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.config = config or RateLimitConfig.from_env()
|
self.config = config or RateLimitConfig.from_env()
|
||||||
|
self._fallback = _InMemoryCounter()
|
||||||
try:
|
try:
|
||||||
self._redis = _get_rate_limit_redis(self.config)
|
self._redis = _get_rate_limit_redis(self.config)
|
||||||
self._redis.ping()
|
self._redis.ping()
|
||||||
|
|
@ -88,11 +110,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Determine identity for the counter key
|
# Determine identity for the counter key
|
||||||
identity = _extract_user_email(request) or _client_ip(request)
|
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
|
||||||
|
|
||||||
# If Redis is unavailable, fail open
|
# If Redis is unavailable, use in-memory fallback
|
||||||
if self._redis is None:
|
if self._redis is None:
|
||||||
return await call_next(request)
|
fallback_key = f"ratelimit:{identity}:{path}"
|
||||||
|
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
||||||
|
if not allowed:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={"detail": "Rate limit exceeded"},
|
||||||
|
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
||||||
|
)
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||||
|
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||||
|
return response
|
||||||
|
|
||||||
redis_key = f"ratelimit:{identity}:{path}"
|
redis_key = f"ratelimit:{identity}:{path}"
|
||||||
try:
|
try:
|
||||||
|
|
@ -128,5 +161,16 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except redis.RedisError as e:
|
except redis.RedisError as e:
|
||||||
logger.warning(f"Rate limiter Redis error, failing open: {e}")
|
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
|
||||||
return await call_next(request)
|
fallback_key = f"ratelimit:{identity}:{path}"
|
||||||
|
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
||||||
|
if not allowed:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={"detail": "Rate limit exceeded"},
|
||||||
|
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
||||||
|
)
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||||
|
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||||
|
return response
|
||||||
|
|
|
||||||
28
api/security_headers.py
Normal file
28
api/security_headers.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Security headers middleware."""
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Add standard security headers to every response."""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||||
|
response.headers["Content-Security-Policy"] = (
|
||||||
|
"default-src 'self'; "
|
||||||
|
"script-src 'self'; "
|
||||||
|
"style-src 'self' 'unsafe-inline'; "
|
||||||
|
"connect-src 'self' https://*.mapbox.com; "
|
||||||
|
"img-src 'self' data: https://*.mapbox.com https://media.rightmove.co.uk; "
|
||||||
|
"frame-ancestors 'none'"
|
||||||
|
)
|
||||||
|
# Only add HSTS when behind TLS-terminating proxy
|
||||||
|
if request.headers.get("x-forwarded-proto") == "https":
|
||||||
|
response.headers["Strict-Transport-Security"] = (
|
||||||
|
"max-age=63072000; includeSubDomains"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
@ -248,4 +248,12 @@ class QueryParameters(BaseModel):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})"
|
f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})"
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
self.min_sqm is not None
|
||||||
|
and self.max_sqm is not None
|
||||||
|
and self.min_sqm > self.max_sqm
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"min_sqm ({self.min_sqm}) must be <= max_sqm ({self.max_sqm})"
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue