"""Per-user rate limiting middleware using Redis fixed-window counters."""

from __future__ import annotations

import logging
import os
from urllib.parse import urlparse, urlunparse

import jwt
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from api.rate_limit_config import EndpointLimit, RateLimitConfig

logger = logging.getLogger("uvicorn")

# Paths exempt from rate limiting
EXEMPT_PATHS = {"/api/status", "/metrics"}


def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis:  # type: ignore[type-arg]
    """Get a Redis client for rate limit counters, using the configured DB.

    The connection URL is read from ``CELERY_BROKER_URL`` (default
    ``redis://localhost:6379/0``) and its path component is replaced with
    ``config.rate_limit_redis_db`` so counters live in their own database.
    Responses are decoded to ``str``.
    """
    broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
    parsed = urlparse(broker_url)
    url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}"))
    return redis.from_url(url, decode_responses=True)


def _extract_user_email(request: Request) -> str | None:
    """Extract user email from JWT without full verification.

    Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
    Only used for rate-limit keying, not for authorization.

    Returns:
        The ``email`` claim when the request carries a decodable Bearer token
        whose claim is a non-empty string; otherwise ``None``.

    NOTE(review): the signature is not verified, so the claim is entirely
    client-controlled. A caller can vary it to obtain a fresh counter per
    request, or set it to another user's address to consume that user's quota.
    """
    auth_header = request.headers.get("authorization", "")
    if not auth_header.startswith("Bearer "):
        return None

    token = auth_header[7:]
    try:
        payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
    except jwt.PyJWTError:
        return None

    email = payload.get("email")
    # Only a non-empty string is a usable key component; anything else
    # (missing claim, list, number, "") falls back to IP-based keying.
    if isinstance(email, str) and email:
        return email
    return None


def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
    """Find the rate limit for a request path.

    Passkey routes (/api/passkey/*) all share the /api/passkey limit.
    Any other path is looked up verbatim in ``config.endpoint_limits``;
    ``None`` means the path is not rate limited.

    NOTE(review): the prefix test has no trailing-slash boundary, so a sibling
    path such as ``/api/passkeys`` is also treated as a passkey route.
    """
    if path.startswith("/api/passkey"):
        return config.endpoint_limits.get("/api/passkey")
    return config.endpoint_limits.get(path)


def _client_ip(request: Request) -> str:
    """Best-effort client IP from X-Forwarded-For or connection.

    Uses the first entry of ``X-Forwarded-For`` when it is non-empty, then the
    connection's peer host, then the literal ``"unknown"``.

    NOTE(review): ``X-Forwarded-For`` is client-supplied unless a trusted proxy
    overwrites it; confirm the deployment strips or sets this header.
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        # A header like ", 10.0.0.1" yields an empty first entry; do not use
        # the empty string as an identity, fall through to the connection.
        if first_hop:
            return first_hop
    client = request.client
    return client.host if client else "unknown"


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware enforcing per-user fixed-window rate limits via Redis."""

    def __init__(self, app, config: RateLimitConfig | None = None) -> None:  # type: ignore[no-untyped-def]
        """Create the middleware and probe Redis once.

        Args:
            app: The ASGI application being wrapped.
            config: Rate limit settings; loaded via ``RateLimitConfig.from_env()``
                when omitted.

        If the initial ``PING`` fails, the client is dropped and every request
        is allowed through (fail open) for the lifetime of this instance.
        """
        super().__init__(app)
        self.config = config or RateLimitConfig.from_env()
        try:
            self._redis = _get_rate_limit_redis(self.config)
            self._redis.ping()
        except redis.RedisError:
            logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
            self._redis = None

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Count the request against its window and either forward or reject it.

        Exempt paths, paths without a configured limit, and all requests while
        Redis is unavailable are forwarded untouched. Otherwise the counter
        ``ratelimit:{identity}:{path}`` is incremented; once it exceeds
        ``limit.max_requests`` a 429 JSON response with ``Retry-After`` is
        returned, and successful responses gain ``X-RateLimit-*`` headers.

        NOTE(review): the Redis client is synchronous, so these calls run on
        the event-loop thread. The counter is keyed by the literal request
        path, so each passkey sub-route gets its own counter even though they
        share one limit definition.
        """
        path = request.url.path

        # Skip exempt paths
        if path in EXEMPT_PATHS:
            return await call_next(request)

        limit = _match_endpoint(path, self.config)
        if limit is None:
            return await call_next(request)

        # If Redis is unavailable, fail open
        if self._redis is None:
            return await call_next(request)

        # Determine identity for the counter key
        identity = _extract_user_email(request) or _client_ip(request)
        redis_key = f"ratelimit:{identity}:{path}"

        # Only the counter operations are guarded. Keeping call_next() out of
        # this try block ensures a RedisError raised by the downstream
        # application propagates instead of triggering a second dispatch.
        try:
            pipe = self._redis.pipeline(transaction=True)
            pipe.incr(redis_key)
            pipe.ttl(redis_key)
            result = pipe.execute()
            current_count: int = result[0]
            ttl: int = result[1]

            # Set expiry on first request in window
            if ttl == -1:
                self._redis.expire(redis_key, limit.window_seconds)
                ttl = limit.window_seconds
        except redis.RedisError as e:
            logger.warning("Rate limiter Redis error, failing open: %s", e)
            return await call_next(request)

        if current_count > limit.max_requests:
            # Never advertise a zero or negative wait.
            retry_after = max(1, ttl)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(limit.max_requests),
                    "X-RateLimit-Remaining": "0",
                },
            )

        remaining = max(0, limit.max_requests - current_count)
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response