wrongmove/api/rate_limiter.py

173 lines
6.6 KiB
Python
Raw Normal View History

"""Per-user rate limiting middleware using Redis fixed-window counters."""
from __future__ import annotations
import logging
import os
import time
from collections.abc import Awaitable, Callable
from urllib.parse import urlparse, urlunparse
import jwt
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp
from api.rate_limit_config import EndpointLimit, RateLimitConfig
# Log through the shared "uvicorn" logger rather than a module-specific one.
logger = logging.getLogger("uvicorn")
# Paths exempt from rate limiting: requests whose URL path is exactly one of
# these are forwarded by the middleware without touching any counter.
EXEMPT_PATHS = {"/api/status", "/metrics"}
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis:  # type: ignore[type-arg]
    """Build the Redis client that stores the rate-limit counters.

    Connection details are taken from the ``CELERY_BROKER_URL`` environment
    variable (default ``redis://localhost:6379/0``). Only the URL path, i.e.
    the database number, is replaced with ``config.rate_limit_redis_db``.

    Args:
        config: Rate-limit settings; ``rate_limit_redis_db`` selects the DB.

    Returns:
        A client created via ``redis.from_url`` with ``decode_responses=True``.
    """
    base = urlparse(os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"))
    counters_db = base._replace(path=f"/{config.rate_limit_redis_db}")
    return redis.from_url(urlunparse(counters_db), decode_responses=True)
def _extract_user_email(request: Request) -> str | None:
"""Extract user email from JWT without full verification.
Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
Only used for rate-limit keying, not for authorization.
"""
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header[7:]
try:
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
return payload.get("email")
except jwt.PyJWTError:
return None
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
"""Find the rate limit for a request path.
Passkey routes (/api/passkey/*) all share the /api/passkey limit.
"""
if path.startswith("/api/passkey"):
return config.endpoint_limits.get("/api/passkey")
return config.endpoint_limits.get(path)
def _client_ip(request: Request, depth: int = 1) -> str:
"""Best-effort client IP from X-Forwarded-For or connection."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
parts = [p.strip() for p in forwarded.split(",")]
idx = max(0, len(parts) - depth)
return parts[idx]
client = request.client
return client.host if client else "unknown"
class _InMemoryCounter:
"""Simple fixed-window counter for rate limiting when Redis is unavailable."""
def __init__(self) -> None:
self._windows: dict[str, tuple[int, float]] = {}
def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
"""Returns (allowed, remaining). Increments counter."""
now = time.monotonic()
count, window_start = self._windows.get(key, (0, now))
if now - window_start >= window_seconds:
count, window_start = 0, now
count += 1
self._windows[key] = (count, window_start)
remaining = max(0, max_requests - count)
return count <= max_requests, remaining
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware enforcing per-user fixed-window rate limits via Redis.

    Each limited request is counted under ``ratelimit:{identity}:{path}``,
    where the identity is the (unverified) JWT ``email`` claim when present
    and otherwise the client IP. Counters are kept in Redis; when Redis is
    unavailable an in-process counter is used instead, so limits are still
    enforced, only per process.
    """

    def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None:
        """Set up the middleware and probe Redis once.

        Args:
            app: The wrapped ASGI application.
            config: Rate-limit settings; loaded with
                ``RateLimitConfig.from_env()`` when omitted.
        """
        super().__init__(app)
        self.config = config or RateLimitConfig.from_env()
        self._fallback = _InMemoryCounter()
        try:
            self._redis: redis.Redis | None = _get_rate_limit_redis(self.config)  # type: ignore[type-arg]
            self._redis.ping()
        except redis.RedisError:
            # With no client, _check_counter goes straight to the in-memory
            # counter; requests are still limited, not waved through.
            logger.warning("Rate limiter: Redis unavailable at startup, using in-memory fallback")
            self._redis = None

    def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
        """Count one request against ``key`` and evaluate it against ``limit``.

        Tries Redis first and falls back to the in-memory counter on Redis
        errors, or immediately when no Redis client is configured.

        Args:
            key: Counter key for this identity and path.
            limit: Allowed request count and window length.

        Returns:
            ``(allowed, remaining, retry_after)``. ``retry_after`` is the
            number of seconds (at least 1) until the window expires when the
            request is rejected via Redis. It is ``None`` when the request is
            allowed or when the in-memory counter was used (no TTL available).
        """
        if self._redis is None:
            allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
            return allowed, remaining, None
        # NOTE(review): these are synchronous Redis round trips made from the
        # async dispatch path; confirm the blocking is acceptable here.
        try:
            pipe = self._redis.pipeline(transaction=True)
            pipe.incr(key)
            pipe.ttl(key)
            current_count, ttl = pipe.execute()
            # A TTL of -1 means the key has no expiry yet, which is the case
            # right after the first INCR of a window: start the window now.
            if ttl == -1:
                self._redis.expire(key, limit.window_seconds)
                ttl = limit.window_seconds
            remaining = max(0, limit.max_requests - current_count)
            allowed = current_count <= limit.max_requests
            retry_after = max(1, ttl) if not allowed else None
            return allowed, remaining, retry_after
        except redis.RedisError as e:
            logger.warning("Rate limiter Redis error, using in-memory fallback: %s", e)
            allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
            return allowed, remaining, None

    async def _enforce_limit(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
        limit: EndpointLimit,
        key: str,
    ) -> Response:
        """Check the rate limit and either reject with 429 or forward with headers.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.
            limit: Limit that applies to this request.
            key: Counter key for this identity and path.

        Returns:
            A 429 JSON response carrying ``X-RateLimit-*`` headers (and
            ``Retry-After`` when known) if the limit is exceeded; otherwise
            the downstream response with ``X-RateLimit-Limit`` and
            ``X-RateLimit-Remaining`` added.
        """
        allowed, remaining, retry_after = self._check_counter(key, limit)
        if not allowed:
            headers: dict[str, str] = {
                "X-RateLimit-Limit": str(limit.max_requests),
                "X-RateLimit-Remaining": "0",
            }
            if retry_after is not None:
                headers["Retry-After"] = str(retry_after)
            return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"}, headers=headers)
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """Apply the configured limit to a request, if it has one.

        Exempt paths and paths with no configured limit are forwarded
        untouched. Everything else is counted per identity and path.
        """
        path = request.url.path
        if path in EXEMPT_PATHS:
            return await call_next(request)
        limit = _match_endpoint(path, self.config)
        if limit is None:
            return await call_next(request)
        identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
        # NOTE(review): the key uses the concrete request path, so routes that
        # resolve to the same EndpointLimit (the passkey routes) are counted
        # separately; confirm that is intended.
        key = f"ratelimit:{identity}:{path}"
        return await self._enforce_limit(request, call_next, limit, key)