"""Per-user rate limiting middleware using Redis fixed-window counters.""" from __future__ import annotations import logging import os import time from urllib.parse import urlparse, urlunparse import jwt import redis from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from api.rate_limit_config import EndpointLimit, RateLimitConfig logger = logging.getLogger("uvicorn") # Paths exempt from rate limiting EXEMPT_PATHS = {"/api/status", "/metrics"} def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis: # type: ignore[type-arg] """Get a Redis client for rate limit counters, using the configured DB.""" broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") parsed = urlparse(broker_url) url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}")) return redis.from_url(url, decode_responses=True) def _extract_user_email(request: Request) -> str | None: """Extract user email from JWT without full verification. Mirrors the unverified-decode pattern used in api/auth.py:89 for routing. Only used for rate-limit keying, not for authorization. """ auth_header = request.headers.get("authorization", "") if not auth_header.startswith("Bearer "): return None token = auth_header[7:] try: payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False}) return payload.get("email") except jwt.PyJWTError: return None def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None: """Find the rate limit for a request path. Passkey routes (/api/passkey/*) all share the /api/passkey limit. """ if path.startswith("/api/passkey"): return config.endpoint_limits.get("/api/passkey") return config.endpoint_limits.get(path) def _client_ip(request: Request, depth: int = 1) -> str: """Best-effort client IP from X-Forwarded-For or connection.""" forwarded = request.headers.get("x-forwarded-for") if forwarded: parts = [p.strip() for p in forwarded.split(",")] idx = max(0, len(parts) - depth) return parts[idx] client = request.client return client.host if client else "unknown" class _InMemoryCounter: """Simple fixed-window counter for rate limiting when Redis is unavailable.""" def __init__(self) -> None: self._windows: dict[str, tuple[int, float]] = {} def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]: """Returns (allowed, remaining). Increments counter.""" now = time.monotonic() count, window_start = self._windows.get(key, (0, now)) if now - window_start >= window_seconds: count, window_start = 0, now count += 1 self._windows[key] = (count, window_start) remaining = max(0, max_requests - count) return count <= max_requests, remaining class RateLimitMiddleware(BaseHTTPMiddleware): """Starlette middleware enforcing per-user fixed-window rate limits via Redis.""" def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] super().__init__(app) self.config = config or RateLimitConfig.from_env() self._fallback = _InMemoryCounter() try: self._redis = _get_rate_limit_redis(self.config) self._redis.ping() except redis.RedisError: logger.warning("Rate limiter: Redis unavailable at startup, will fail open") self._redis = None async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] path = request.url.path # Skip exempt paths if path in EXEMPT_PATHS: return await call_next(request) limit = _match_endpoint(path, self.config) if limit is None: return await call_next(request) # Determine identity for the counter key identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth) # If Redis is unavailable, use in-memory fallback if self._redis is None: fallback_key = f"ratelimit:{identity}:{path}" allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds) if not allowed: return JSONResponse( status_code=429, content={"detail": "Rate limit exceeded"}, headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"}, ) response = await call_next(request) response.headers["X-RateLimit-Limit"] = str(limit.max_requests) response.headers["X-RateLimit-Remaining"] = str(remaining) return response redis_key = f"ratelimit:{identity}:{path}" try: pipe = self._redis.pipeline(transaction=True) pipe.incr(redis_key) pipe.ttl(redis_key) result = pipe.execute() current_count: int = result[0] ttl: int = result[1] # Set expiry on first request in window if ttl == -1: self._redis.expire(redis_key, limit.window_seconds) ttl = limit.window_seconds remaining = max(0, limit.max_requests - current_count) if current_count > limit.max_requests: retry_after = max(1, ttl) return JSONResponse( status_code=429, content={"detail": "Rate limit exceeded"}, headers={ "Retry-After": str(retry_after), "X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0", }, ) response = await call_next(request) response.headers["X-RateLimit-Limit"] = str(limit.max_requests) response.headers["X-RateLimit-Remaining"] = str(remaining) return response except redis.RedisError as e: logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}") fallback_key = f"ratelimit:{identity}:{path}" allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds) if not allowed: return JSONResponse( status_code=429, content={"detail": "Rate limit exceeded"}, headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"}, ) response = await call_next(request) response.headers["X-RateLimit-Limit"] = str(limit.max_requests) response.headers["X-RateLimit-Remaining"] = str(remaining) return response