wrongmove/api/rate_limiter.py

177 lines
6.9 KiB
Python
Raw Normal View History

"""Per-user rate limiting middleware using Redis fixed-window counters."""
from __future__ import annotations
import logging
import os
import time
from urllib.parse import urlparse, urlunparse
import jwt
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from api.rate_limit_config import EndpointLimit, RateLimitConfig
# Log records go to the "uvicorn" named logger rather than a per-module logger.
logger = logging.getLogger("uvicorn")

# Requests whose path is exactly one of these are never rate limited.
EXEMPT_PATHS: set[str] = {"/metrics", "/api/status"}
def _rate_limit_redis_url(broker_url: str, db: int) -> str:
    """Return ``broker_url`` with its path replaced by ``/<db>``.

    Scheme, credentials, host, port and query string are preserved; only the
    path segment, which selects the Redis logical database, is rewritten.
    """
    return urlunparse(urlparse(broker_url)._replace(path=f"/{db}"))


def _get_rate_limit_redis(
    config: RateLimitConfig,
    *,
    socket_timeout: float | None = 1.0,
    socket_connect_timeout: float | None = 1.0,
) -> redis.Redis:  # type: ignore[type-arg]
    """Get a Redis client for rate limit counters, using the configured DB.

    The host/credentials come from ``CELERY_BROKER_URL`` (default
    ``redis://localhost:6379/0``); the DB index is replaced with
    ``config.rate_limit_redis_db`` so limiter keys live apart from the broker's.

    Args:
        config: Rate-limit settings; only ``rate_limit_redis_db`` is read.
        socket_timeout: Seconds to wait on a Redis command before raising.
        socket_connect_timeout: Seconds to wait when opening a connection.

    The middleware issues Redis commands synchronously from its async
    ``dispatch``, so the timeouts are bounded by default: a hung Redis then
    surfaces as a ``RedisError`` (and the in-memory fallback) instead of
    stalling the event loop indefinitely. Pass ``None`` to wait forever.
    """
    broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
    url = _rate_limit_redis_url(broker_url, config.rate_limit_redis_db)
    return redis.from_url(
        url,
        decode_responses=True,
        socket_timeout=socket_timeout,
        socket_connect_timeout=socket_connect_timeout,
    )
def _extract_user_email(request: Request) -> str | None:
    """Extract the user's email from the Bearer JWT without verifying it.

    Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
    Only used for rate-limit keying, not for authorization.

    Returns None when there is no Bearer token, the token cannot be decoded,
    or the ``email`` claim is missing, empty, or not a string. The auth scheme
    is matched case-insensitively, as HTTP authentication schemes are.

    NOTE(review): since the signature is not checked, a client can mint tokens
    with arbitrary ``email`` claims. It can rotate them to get a fresh bucket on
    every request, or use another user's address to exhaust that user's budget.
    Consider verifying the token here, or also enforcing a per-IP limit.
    """
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    if scheme.lower() != "bearer" or not token:
        return None
    try:
        payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
    except jwt.PyJWTError:
        return None
    email = payload.get("email")
    # Guard the declared return type: a forged token may carry any JSON value here.
    return email if isinstance(email, str) and email else None
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
    """Find the rate limit for a request path, or None if it is unlimited.

    Paths are looked up exactly in ``config.endpoint_limits``, except that
    ``/api/passkey`` itself and every route under ``/api/passkey/`` use the
    single ``/api/passkey`` entry. Paths that merely share the prefix (e.g.
    ``/api/passkeys``) are looked up exactly like any other path.
    """
    passkey_prefix = "/api/passkey"
    if path == passkey_prefix or path.startswith(passkey_prefix + "/"):
        return config.endpoint_limits.get(passkey_prefix)
    return config.endpoint_limits.get(path)
def _client_ip(request: Request, depth: int = 1) -> str:
    """Best-effort client IP from X-Forwarded-For or the connection.

    Args:
        request: The incoming request.
        depth: Number of trusted reverse proxies in front of the app. By
            X-Forwarded-For convention each proxy appends the address it
            received the request from, so the client is ``depth`` entries from
            the right; entries further left are client-supplied and spoofable.

    With ``depth <= 0`` (no trusted proxies) the header is ignored entirely,
    since every entry would be client-controlled. If the header has fewer
    entries than ``depth``, the leftmost entry is used. An empty selected entry
    falls through to the socket peer address, and finally to ``"unknown"``.
    """
    if depth > 0:
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            hops = [hop.strip() for hop in forwarded.split(",")]
            candidate = hops[max(0, len(hops) - depth)]
            if candidate:
                return candidate
    client = request.client
    return client.host if client else "unknown"
class _InMemoryCounter:
    """Simple per-process fixed-window counter used when Redis is unavailable.

    Windows are measured with ``time.monotonic()``, so wall-clock changes do not
    affect them. Expired windows are swept out at most once every
    ``sweep_interval`` seconds. Without that sweep, one entry per distinct key
    would accumulate for as long as the fallback is in use.
    """

    def __init__(self, sweep_interval: float = 60.0) -> None:
        # key -> (requests counted, window start, window length in seconds)
        self._windows: dict[str, tuple[int, float, float]] = {}
        self._sweep_interval = sweep_interval
        self._next_sweep = time.monotonic() + sweep_interval

    def _sweep(self, now: float) -> None:
        """Drop every window that has already expired and schedule the next sweep."""
        expired = [
            key
            for key, (_, start, length) in self._windows.items()
            if now - start >= length
        ]
        for key in expired:
            del self._windows[key]
        self._next_sweep = now + self._sweep_interval

    def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
        """Count one request for ``key`` and return ``(allowed, remaining)``.

        The request is counted even when rejected, so a client that keeps
        retrying stays blocked until its window expires.
        """
        now = time.monotonic()
        if now >= self._next_sweep:
            self._sweep(now)
        count, window_start, _ = self._windows.get(key, (0, now, window_seconds))
        if now - window_start >= window_seconds:
            count, window_start = 0, now
        count += 1
        self._windows[key] = (count, window_start, window_seconds)
        remaining = max(0, max_requests - count)
        return count <= max_requests, remaining
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware enforcing per-user fixed-window rate limits.

    Each limited request increments a counter keyed by identity (the JWT
    ``email`` claim, else the client IP) and request path. Counters live in
    Redis, so every process using the same Redis DB shares them. Whenever Redis
    is unreachable, the middleware degrades to a per-process in-memory counter
    rather than failing open, and tries Redis again after
    ``redis_retry_seconds``.

    Exempt paths and paths without a configured limit pass straight through.
    Limited responses carry ``X-RateLimit-Limit`` / ``X-RateLimit-Remaining``.
    Rejected requests get a 429 JSON body, plus ``Retry-After`` when the
    window's remaining TTL is known from Redis.
    """

    # Seconds to route around Redis after a failure before trying it again. During an
    # outage this costs one slow Redis attempt per interval rather than one per request.
    redis_retry_seconds: float = 10.0

    def __init__(self, app, config: RateLimitConfig | None = None) -> None:  # type: ignore[no-untyped-def]
        super().__init__(app)
        self.config = config or RateLimitConfig.from_env()
        self._fallback = _InMemoryCounter()
        self._redis: redis.Redis | None = None  # type: ignore[type-arg]
        # time.monotonic() deadline before which Redis is not attempted.
        self._redis_retry_at = 0.0
        self._connect_redis()

    def _connect_redis(self) -> None:
        """Create the Redis client and PING it; on failure clear it and schedule a retry."""
        try:
            client = _get_rate_limit_redis(self.config)
            client.ping()
        except redis.RedisError as exc:
            logger.warning(
                "Rate limiter: Redis unavailable (%s), using in-memory fallback; retrying in %ss",
                exc,
                self.redis_retry_seconds,
            )
            self._redis = None
            self._redis_retry_at = time.monotonic() + self.redis_retry_seconds
        else:
            self._redis = client

    def _redis_ready(self) -> bool:
        """Return True if Redis should be used for this request, reconnecting when due."""
        if time.monotonic() < self._redis_retry_at:
            return False
        if self._redis is None:
            self._connect_redis()
        return self._redis is not None

    def _count_in_redis(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int] | None:
        """Count one request in Redis.

        Returns ``(allowed, remaining, retry_after_seconds)``, or None after a
        Redis error. The error is logged, and Redis is skipped until the retry
        deadline.
        """
        client = self._redis
        if client is None:
            return None
        try:
            pipe = client.pipeline(transaction=True)
            pipe.incr(key)
            pipe.ttl(key)
            count, ttl = pipe.execute()
            # TTL -1: the key exists without an expiry (normally INCR just created
            # it), so start the window now.
            if ttl == -1:
                client.expire(key, limit.window_seconds)
                ttl = limit.window_seconds
        except redis.RedisError as exc:
            logger.warning("Rate limiter Redis error, using in-memory fallback: %s", exc)
            self._redis_retry_at = time.monotonic() + self.redis_retry_seconds
            return None
        remaining = max(0, limit.max_requests - count)
        return count <= limit.max_requests, remaining, max(1, ttl)

    @staticmethod
    def _too_many_requests(limit: EndpointLimit, retry_after: int | None) -> JSONResponse:
        """Build the 429 response; ``Retry-After`` is only sent when the TTL is known."""
        headers: dict[str, str] = {}
        if retry_after is not None:
            headers["Retry-After"] = str(retry_after)
        headers["X-RateLimit-Limit"] = str(limit.max_requests)
        headers["X-RateLimit-Remaining"] = "0"
        return JSONResponse(
            status_code=429,
            content={"detail": "Rate limit exceeded"},
            headers=headers,
        )

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Apply the endpoint's rate limit, then forward the request or return 429."""
        path = request.url.path
        if path in EXEMPT_PATHS:
            return await call_next(request)
        limit = _match_endpoint(path, self.config)
        if limit is None:
            return await call_next(request)

        identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
        # NOTE(review): keyed on the raw path, so each /api/passkey/* sub-route gets its
        # own counter even though they share one configured limit. Confirm this is intended.
        key = f"ratelimit:{identity}:{path}"

        outcome = self._count_in_redis(key, limit) if self._redis_ready() else None
        if outcome is None:
            allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
            retry_after: int | None = None
        else:
            allowed, remaining, retry_after = outcome

        if not allowed:
            return self._too_many_requests(limit, retry_after)

        # Deliberately outside any RedisError handler. A Redis failure inside the
        # downstream app must propagate; it must not be mistaken for a limiter error
        # that re-dispatches (re-runs) the request.
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response