diff --git a/api/app.py b/api/app.py index b13bd59..921385f 100644 --- a/api/app.py +++ b/api/app.py @@ -5,15 +5,16 @@ import logging import logging.config from typing import Annotated, AsyncGenerator, Optional from api.auth import get_current_user -from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS +from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV from api.passkey_routes import passkey_router from api.poi_routes import poi_router from api.rate_limit_config import RateLimitConfig from api.rate_limiter import RateLimitMiddleware from api.audit_middleware import AuditLogMiddleware from api.metrics_guard import MetricsGuardMiddleware +from api.security_headers import SecurityHeadersMiddleware from dotenv import load_dotenv -from fastapi import Depends, FastAPI, Query +from fastapi import Depends, FastAPI, HTTPException, Query from fastapi.responses import StreamingResponse from api.auth import User from models.listing import QueryParameters, ListingType, FurnishType @@ -103,7 +104,7 @@ hist = meter.create_histogram( app.add_middleware( CORSMiddleware, allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS], - allow_methods=["GET", "POST"], + allow_methods=["GET", "POST", "PUT", "DELETE"], allow_headers=["Authorization", "Content-Type"], ) @@ -114,6 +115,8 @@ app.add_middleware(RateLimitMiddleware, config=_rate_limit_config) app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config) # 1. Audit logging — logs everything including 429s and 403s app.add_middleware(AuditLogMiddleware) +# 0. Security headers — adds standard security headers to all responses +app.add_middleware(SecurityHeadersMiddleware) @app.get("/api/status") @@ -324,6 +327,9 @@ async def get_task_status( task_id: str, ) -> dict[str, str | int | float | None]: """Get the status of a background task.""" + user_tasks = task_service.get_user_tasks(user.email) + if task_id not in user_tasks: + raise HTTPException(status_code=404, detail="Task not found") status = task_service.get_task_status(task_id) return { "task_id": status.task_id, @@ -333,8 +339,8 @@ async def get_task_status( "processed": status.processed, "total": status.total, "message": status.message, - "error": status.error, - "traceback": status.traceback, + "error": status.error if APP_ENV != "production" else None, + "traceback": status.traceback if APP_ENV != "production" else None, } diff --git a/api/config.py b/api/config.py index 8821552..3ab7a24 100644 --- a/api/config.py +++ b/api/config.py @@ -5,9 +5,15 @@ import os _logger = logging.getLogger(__name__) +APP_ENV = os.getenv("APP_ENV", "development") + # Authentik OIDC Configuration AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me") -OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe") +OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "") +if APP_ENV == "production" and not OIDC_CLIENT_ID: + raise RuntimeError("OIDC_CLIENT_ID must be set in production") +if not OIDC_CLIENT_ID: + _logger.warning("OIDC_CLIENT_ID not set; OIDC login will not work") OIDC_METADATA_URL = ( f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration" ) @@ -27,6 +33,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost") # JWT Configuration (for passkey-issued tokens) JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production") if JWT_SECRET == "change-me-in-production": + if APP_ENV == "production": + raise RuntimeError("JWT_SECRET must be changed from default in production") _logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.") JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256") JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24")) diff --git a/api/passkey_routes.py b/api/passkey_routes.py index 6d2a4dc..97f0f04 100644 --- a/api/passkey_routes.py +++ b/api/passkey_routes.py @@ -44,9 +44,11 @@ async def register_begin(body: RegisterBeginRequest) -> RegisterBeginResponse: body.email, user_repo ) return RegisterBeginResponse(options=options, session_id=session_id) - except Exception as e: - logger.error(f"Registration begin failed: {e}") + except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + except Exception: + logger.exception("Registration begin failed") + raise HTTPException(status_code=400, detail="Registration failed. Please try again.") @passkey_router.post("/register/complete", response_model=AuthTokenResponse) @@ -60,9 +62,9 @@ async def register_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse: return AuthTokenResponse(token=token) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - logger.error(f"Registration complete failed: {e}") - raise HTTPException(status_code=400, detail=str(e)) + except Exception: + logger.exception("Registration complete failed") + raise HTTPException(status_code=400, detail="Registration could not be completed.") @passkey_router.post("/login/begin", response_model=LoginBeginResponse) @@ -72,9 +74,11 @@ async def login_begin() -> LoginBeginResponse: user_repo = UserRepository(engine) options, session_id = passkey_service.begin_authentication(user_repo) return LoginBeginResponse(options=options, session_id=session_id) - except Exception as e: - logger.error(f"Login begin failed: {e}") + except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + except Exception: + logger.exception("Login begin failed") + raise HTTPException(status_code=400, detail="Login initiation failed. Please try again.") @passkey_router.post("/login/complete", response_model=AuthTokenResponse) @@ -88,6 +92,6 @@ async def login_complete(body: CeremonyCompleteRequest) -> AuthTokenResponse: return AuthTokenResponse(token=token) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - logger.error(f"Login complete failed: {e}") - raise HTTPException(status_code=400, detail=str(e)) + except Exception: + logger.exception("Login complete failed") + raise HTTPException(status_code=400, detail="Login could not be completed.") diff --git a/api/poi_routes.py b/api/poi_routes.py index 07b57ac..b701980 100644 --- a/api/poi_routes.py +++ b/api/poi_routes.py @@ -2,7 +2,7 @@ import logging from typing import Annotated from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel +from pydantic import BaseModel, Field from api.auth import User, get_current_user from database import engine @@ -17,17 +17,17 @@ poi_router = APIRouter(prefix="/api/poi", tags=["poi"]) class CreatePOIRequest(BaseModel): - name: str - address: str - latitude: float - longitude: float + name: str = Field(max_length=200) + address: str = Field(max_length=500) + latitude: float = Field(ge=-90, le=90) + longitude: float = Field(ge=-180, le=180) class UpdatePOIRequest(BaseModel): - name: str | None = None - address: str | None = None - latitude: float | None = None - longitude: float | None = None + name: str | None = Field(default=None, max_length=200) + address: str | None = Field(default=None, max_length=500) + latitude: float | None = Field(default=None, ge=-90, le=90) + longitude: float | None = Field(default=None, ge=-180, le=180) class POIResponse(BaseModel): diff --git a/api/rate_limit_config.py b/api/rate_limit_config.py index 7d6fad7..0f56375 100644 --- a/api/rate_limit_config.py +++ b/api/rate_limit_config.py @@ -47,6 +47,9 @@ class RateLimitConfig: # Metrics endpoint IP allowlist (comma-separated CIDRs) metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1" + # X-Forwarded-For trusted proxy depth + trusted_proxy_depth: int = 1 + @classmethod def from_env(cls) -> Self: """Load configuration from environment variables. diff --git a/api/rate_limiter.py b/api/rate_limiter.py index 43b2a5d..7f6e9da 100644 --- a/api/rate_limiter.py +++ b/api/rate_limiter.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging import os +import time from urllib.parse import urlparse, urlunparse import jwt @@ -54,21 +55,42 @@ def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None: return config.endpoint_limits.get(path) -def _client_ip(request: Request) -> str: +def _client_ip(request: Request, depth: int = 1) -> str: """Best-effort client IP from X-Forwarded-For or connection.""" forwarded = request.headers.get("x-forwarded-for") if forwarded: - return forwarded.split(",")[0].strip() + parts = [p.strip() for p in forwarded.split(",")] + idx = max(0, len(parts) - depth) + return parts[idx] client = request.client return client.host if client else "unknown" +class _InMemoryCounter: + """Simple fixed-window counter for rate limiting when Redis is unavailable.""" + + def __init__(self) -> None: + self._windows: dict[str, tuple[int, float]] = {} + + def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]: + """Returns (allowed, remaining). Increments counter.""" + now = time.monotonic() + count, window_start = self._windows.get(key, (0, now)) + if now - window_start >= window_seconds: + count, window_start = 0, now + count += 1 + self._windows[key] = (count, window_start) + remaining = max(0, max_requests - count) + return count <= max_requests, remaining + + class RateLimitMiddleware(BaseHTTPMiddleware): """Starlette middleware enforcing per-user fixed-window rate limits via Redis.""" def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] super().__init__(app) self.config = config or RateLimitConfig.from_env() + self._fallback = _InMemoryCounter() try: self._redis = _get_rate_limit_redis(self.config) self._redis.ping() @@ -88,11 +110,22 @@ class RateLimitMiddleware(BaseHTTPMiddleware): return await call_next(request) # Determine identity for the counter key - identity = _extract_user_email(request) or _client_ip(request) + identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth) - # If Redis is unavailable, fail open + # If Redis is unavailable, use in-memory fallback if self._redis is None: - return await call_next(request) + fallback_key = f"ratelimit:{identity}:{path}" + allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds) + if not allowed: + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded"}, + headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"}, + ) + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limit.max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + return response redis_key = f"ratelimit:{identity}:{path}" try: @@ -128,5 +161,16 @@ class RateLimitMiddleware(BaseHTTPMiddleware): return response except redis.RedisError as e: - logger.warning(f"Rate limiter Redis error, failing open: {e}") - return await call_next(request) + logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}") + fallback_key = f"ratelimit:{identity}:{path}" + allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds) + if not allowed: + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded"}, + headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"}, + ) + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limit.max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + return response diff --git a/api/security_headers.py b/api/security_headers.py new file mode 100644 index 0000000..9f091db --- /dev/null +++ b/api/security_headers.py @@ -0,0 +1,28 @@ +"""Security headers middleware.""" +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Add standard security headers to every response.""" + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] + response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers["Content-Security-Policy"] = ( + "default-src 'self'; " + "script-src 'self'; " + "style-src 'self' 'unsafe-inline'; " + "connect-src 'self' https://*.mapbox.com; " + "img-src 'self' data: https://*.mapbox.com https://media.rightmove.co.uk; " + "frame-ancestors 'none'" + ) + # Only add HSTS when behind TLS-terminating proxy + if request.headers.get("x-forwarded-proto") == "https": + response.headers["Strict-Transport-Security"] = ( + "max-age=63072000; includeSubDomains" + ) + return response diff --git a/models/listing.py b/models/listing.py index 2af77e2..9d969ec 100644 --- a/models/listing.py +++ b/models/listing.py @@ -248,4 +248,12 @@ class QueryParameters(BaseModel): raise ValueError( f"min_price_per_sqm ({self.min_price_per_sqm}) must be <= max_price_per_sqm ({self.max_price_per_sqm})" ) + if ( + self.min_sqm is not None + and self.max_sqm is not None + and self.min_sqm > self.max_sqm + ): + raise ValueError( + f"min_sqm ({self.min_sqm}) must be <= max_sqm ({self.max_sqm})" + ) return self