wrongmove/api/rate_limiter.py
Viktor Barzin f833309297
Refactor backend for cleaner error handling, DRY, and type safety
- Extract rate limiter DRY: consolidate 3 duplicated check/respond paths
  into _check_counter and _enforce_limit helpers, add proper type annotations
- Replace bare Exception raises with FloorplanDownloadError and
  RightmoveApiError; narrow catch clauses to specific exception types;
  fix Step base class to inherit from ABC
- Consolidate MAX_OCR_WORKERS into config/scraper_config.py; extract
  _find_tenure_value helper to deduplicate tenure parsing
- Extract _build_poi_distances_lookup from stream endpoint to reduce nesting
- Fix csv_exporter: optional decisions.json, NaN instead of -1 sentinels,
  guard against division by zero on missing square meters
- Fix notifications.py broken list[Surface]() constructor, database.py
  stale comments and missing type annotation, auth.py type:ignore,
  ui_exporter.py stale TODO
- Fix 3 pre-existing test failures: mock cache layer in streaming tests,
  bypass rate limiter for test isolation, fix cache invalidation test to
  account for two-pattern scan loop
2026-02-10 22:19:24 +00:00

172 lines
6.6 KiB
Python

"""Per-user rate limiting middleware using Redis fixed-window counters."""
from __future__ import annotations
import logging
import os
import time
from collections.abc import Awaitable, Callable
from urllib.parse import urlparse, urlunparse
import jwt
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp
from api.rate_limit_config import EndpointLimit, RateLimitConfig
# Module logger; records are emitted under the "uvicorn" logger name.
logger = logging.getLogger("uvicorn")
# Request paths that bypass rate limiting entirely. The middleware compares
# the request path against this set by exact string equality.
EXEMPT_PATHS = {"/api/status", "/metrics"}
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis:  # type: ignore[type-arg]
    """Build the Redis client that stores rate-limit counters.

    The connection URL is taken from the ``CELERY_BROKER_URL`` environment
    variable (defaulting to a local Redis), with its path component replaced
    by ``config.rate_limit_redis_db`` so the counters use that database.
    Responses are decoded to ``str``.
    """
    base = urlparse(os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0"))
    # Swap only the database index; every other URL component is kept.
    target = base._replace(path=f"/{config.rate_limit_redis_db}")
    return redis.from_url(urlunparse(target), decode_responses=True)
def _extract_user_email(request: Request) -> str | None:
    """Return the ``email`` claim of the request's bearer JWT, if any.

    The token is decoded WITHOUT verifying its signature or expiry, so the
    returned value is fully attacker-controllable. It is only suitable as a
    rate-limit bucket key and must never be used for authorization.

    Args:
        request: Incoming request whose ``Authorization`` header is inspected.

    Returns:
        The non-empty string ``email`` claim, or ``None`` when the header is
        missing, does not use the bearer scheme, carries no token, cannot be
        decoded, or has no usable ``email`` claim.
    """
    auth_header = request.headers.get("authorization", "")
    scheme, _, credentials = auth_header.partition(" ")
    # HTTP authentication scheme names are case-insensitive, so "bearer" and
    # "BEARER" are accepted alongside "Bearer".
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    if not token:
        # Nothing to decode; skip the decoder round-trip.
        return None
    try:
        payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
    except jwt.PyJWTError:
        return None
    email = payload.get("email")
    # Honour the declared return type: a non-string or empty claim is not a
    # usable identity, so the caller falls back to another key.
    return email if isinstance(email, str) and email else None
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
    """Look up the configured limit that applies to ``path``.

    Any path beginning with ``/api/passkey`` is resolved through the single
    ``/api/passkey`` entry; every other path must match an entry exactly.
    Returns ``None`` when no entry is configured for the path.
    """
    limits = config.endpoint_limits
    lookup_key = "/api/passkey" if path.startswith("/api/passkey") else path
    return limits.get(lookup_key)
def _client_ip(request: Request, depth: int = 1) -> str:
    """Best-effort client IP from X-Forwarded-For or the connection peer.

    Args:
        request: Incoming request.
        depth: Number of trusted proxies in front of the application. The
            address is taken ``depth`` entries from the right of the
            ``X-Forwarded-For`` list, clamped to the leftmost entry when the
            list is shorter. A depth of zero or less means no proxy is
            trusted, so the header is ignored.

    Returns:
        The selected forwarded address, else the connection's peer host, else
        the literal ``"unknown"`` when the request has no client info.
    """
    # With depth <= 0 the index would fall past the end of the list
    # (IndexError), and the header is untrusted anyway, so skip it.
    if depth > 0:
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            parts = [p.strip() for p in forwarded.split(",")]
            candidate = parts[max(0, len(parts) - depth)]
            # A blank entry (e.g. a trailing comma) is not an identity; fall
            # through to the connection peer rather than keying on "".
            if candidate:
                return candidate
    client = request.client
    return client.host if client else "unknown"
class _InMemoryCounter:
    """Simple fixed-window counter for rate limiting when Redis is unavailable.

    State is held in plain dicts on the instance. Keys are derived from
    request data, so expired windows are pruned once the table grows past a
    threshold; otherwise the table would grow without bound.
    """

    def __init__(self, max_keys: int = 10_000) -> None:
        """Create an empty counter.

        Args:
            max_keys: Number of tracked keys above which expired windows are
                swept out. Live (unexpired) windows are never evicted, so this
                is a pruning trigger rather than a hard cap.
        """
        # key -> (requests seen in the current window, window start on the
        # monotonic clock)
        self._windows: dict[str, tuple[int, float]] = {}
        # key -> window length most recently used for that key; needed to
        # decide expiry during a sweep.
        self._window_lengths: dict[str, int] = {}
        self._max_keys = max_keys
        self._prune_at = max_keys

    def check(self, key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
        """Count one request against ``key`` and report whether it is allowed.

        Args:
            key: Bucket identifier.
            max_requests: Requests permitted per window.
            window_seconds: Window length in seconds.

        Returns:
            ``(allowed, remaining)`` where ``remaining`` is clamped at zero.
            The counter is incremented even when the request is rejected.
        """
        now = time.monotonic()
        count, window_start = self._windows.get(key, (0, now))
        if now - window_start >= window_seconds:
            # The previous window has elapsed; start a fresh one.
            count, window_start = 0, now
        count += 1
        self._windows[key] = (count, window_start)
        self._window_lengths[key] = window_seconds
        if len(self._windows) > self._prune_at:
            self._prune(now)
        remaining = max(0, max_requests - count)
        return count <= max_requests, remaining

    def _prune(self, now: float) -> None:
        """Drop every window that has already expired at time ``now``.

        Removing an expired entry does not change results: the next ``check``
        for that key would have reset it to a fresh window anyway.
        """
        expired = []
        for key, length in self._window_lengths.items():
            entry = self._windows.get(key)
            if entry is None or now - entry[1] >= length:
                expired.append(key)
        for key in expired:
            self._windows.pop(key, None)
            del self._window_lengths[key]
        # Double the trigger when most entries are still live so sweeps stay
        # amortised instead of running on every request.
        self._prune_at = max(self._max_keys, 2 * len(self._windows))
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware enforcing per-identity fixed-window rate limits.

    Counters are kept in Redis. When Redis is unreachable, either at startup
    or on an individual request, counting falls back to a process-local
    in-memory counter; requests are still limited rather than waved through.
    """

    def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None:
        """Wrap ``app`` and connect to Redis.

        Args:
            app: The downstream ASGI application.
            config: Rate-limit settings; loaded via
                ``RateLimitConfig.from_env()`` when omitted.
        """
        super().__init__(app)
        self.config = config or RateLimitConfig.from_env()
        self._fallback = _InMemoryCounter()
        try:
            self._redis: redis.Redis | None = _get_rate_limit_redis(self.config)  # type: ignore[type-arg]
            self._redis.ping()
        except redis.RedisError:
            # With no client, every check goes to the in-memory counter.
            logger.warning("Rate limiter: Redis unavailable at startup, using in-memory fallback")
            self._redis = None

    def _check_fallback(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
        """Count the request in the in-memory counter (no Retry-After value)."""
        allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
        return allowed, remaining, None

    def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
        """Check rate limit counter, returning (allowed, remaining, retry_after).

        Tries Redis first; falls back to the in-memory counter when no Redis
        client is available or a Redis error occurs. ``retry_after`` is only
        set for rejected requests counted in Redis; it is ``None`` for allowed
        requests and for in-memory counters (no TTL available).
        """
        if self._redis is None:
            return self._check_fallback(key, limit)
        try:
            pipe = self._redis.pipeline(transaction=True)
            pipe.incr(key)
            pipe.ttl(key)
            result = pipe.execute()
            current_count: int = result[0]
            ttl: int = result[1]
            # A TTL of -1 means the key has no expiry yet, i.e. this is the
            # first request of a window: start the window clock now.
            if ttl == -1:
                self._redis.expire(key, limit.window_seconds)
                ttl = limit.window_seconds
            remaining = max(0, limit.max_requests - current_count)
            allowed = current_count <= limit.max_requests
            # Never advertise a Retry-After below one second.
            retry_after = max(1, ttl) if not allowed else None
            return allowed, remaining, retry_after
        except redis.RedisError as e:
            logger.warning("Rate limiter Redis error, using in-memory fallback: %s", e)
            return self._check_fallback(key, limit)

    async def _enforce_limit(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
        limit: EndpointLimit,
        key: str,
    ) -> Response:
        """Check the rate limit and either reject with 429 or forward with headers.

        Rejected requests get ``X-RateLimit-Limit``, ``X-RateLimit-Remaining``
        (always ``"0"``) and, when known, ``Retry-After``. Forwarded requests
        get the two ``X-RateLimit-*`` headers added to the downstream response.
        """
        # NOTE(review): _check_counter uses the synchronous Redis client from
        # inside this coroutine - confirm the blocking round-trip is acceptable.
        allowed, remaining, retry_after = self._check_counter(key, limit)
        if not allowed:
            headers: dict[str, str] = {
                "X-RateLimit-Limit": str(limit.max_requests),
                "X-RateLimit-Remaining": "0",
            }
            if retry_after is not None:
                headers["Retry-After"] = str(retry_after)
            return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"}, headers=headers)
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """Apply the configured limit for the request path, if there is one.

        Exempt paths and paths without a configured limit pass straight
        through. Otherwise the request is counted under a key built from the
        caller identity (JWT email claim, else client IP) and the request path.
        """
        path = request.url.path
        if path in EXEMPT_PATHS:
            return await call_next(request)
        limit = _match_endpoint(path, self.config)
        if limit is None:
            return await call_next(request)
        identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
        key = f"ratelimit:{identity}:{path}"
        return await self._enforce_limit(request, call_next, limit, key)