Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis sliding window, IP-restricted /metrics endpoint, audit logging of all requests, CORS tightening, and export caps on listing/geojson endpoints.
This commit is contained in:
parent
08ac72bbfc
commit
87b5bd8676
8 changed files with 756 additions and 2 deletions
132
api/rate_limiter.py
Normal file
132
api/rate_limiter.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Per-user rate limiting middleware using Redis fixed-window counters."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
# Paths exempt from rate limiting
|
||||
EXEMPT_PATHS = {"/api/status", "/metrics"}
|
||||
|
||||
|
||||
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis: # type: ignore[type-arg]
|
||||
"""Get a Redis client for rate limit counters, using the configured DB."""
|
||||
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
parsed = urlparse(broker_url)
|
||||
url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}"))
|
||||
return redis.from_url(url, decode_responses=True)
|
||||
|
||||
|
||||
def _extract_user_email(request: Request) -> str | None:
|
||||
"""Extract user email from JWT without full verification.
|
||||
|
||||
Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
|
||||
Only used for rate-limit keying, not for authorization.
|
||||
"""
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
|
||||
return payload.get("email")
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
|
||||
|
||||
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
|
||||
"""Find the rate limit for a request path.
|
||||
|
||||
Passkey routes (/api/passkey/*) all share the /api/passkey limit.
|
||||
"""
|
||||
if path.startswith("/api/passkey"):
|
||||
return config.endpoint_limits.get("/api/passkey")
|
||||
return config.endpoint_limits.get(path)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
"""Best-effort client IP from X-Forwarded-For or connection."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
client = request.client
|
||||
return client.host if client else "unknown"
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(app)
|
||||
self.config = config or RateLimitConfig.from_env()
|
||||
try:
|
||||
self._redis = _get_rate_limit_redis(self.config)
|
||||
self._redis.ping()
|
||||
except redis.RedisError:
|
||||
logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
|
||||
self._redis = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
path = request.url.path
|
||||
|
||||
# Skip exempt paths
|
||||
if path in EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _match_endpoint(path, self.config)
|
||||
if limit is None:
|
||||
return await call_next(request)
|
||||
|
||||
# Determine identity for the counter key
|
||||
identity = _extract_user_email(request) or _client_ip(request)
|
||||
|
||||
# If Redis is unavailable, fail open
|
||||
if self._redis is None:
|
||||
return await call_next(request)
|
||||
|
||||
redis_key = f"ratelimit:{identity}:{path}"
|
||||
try:
|
||||
pipe = self._redis.pipeline(transaction=True)
|
||||
pipe.incr(redis_key)
|
||||
pipe.ttl(redis_key)
|
||||
result = pipe.execute()
|
||||
current_count: int = result[0]
|
||||
ttl: int = result[1]
|
||||
|
||||
# Set expiry on first request in window
|
||||
if ttl == -1:
|
||||
self._redis.expire(redis_key, limit.window_seconds)
|
||||
ttl = limit.window_seconds
|
||||
|
||||
remaining = max(0, limit.max_requests - current_count)
|
||||
|
||||
if current_count > limit.max_requests:
|
||||
retry_after = max(1, ttl)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limit.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Rate limiter Redis error, failing open: {e}")
|
||||
return await call_next(request)
|
||||
Loading…
Add table
Add a link
Reference in a new issue