Add API rate limiting, metrics guard, and audit middleware

Per-user rate limits via Redis fixed-window counters, IP-restricted
/metrics endpoint, audit logging of all /api/ requests, CORS tightening,
and export caps on listing/geojson endpoints.
This commit is contained in:
Viktor Barzin 2026-02-08 00:45:43 +00:00
parent 08ac72bbfc
commit 87b5bd8676
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 756 additions and 2 deletions

View file

@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
from api.passkey_routes import passkey_router
from api.rate_limit_config import RateLimitConfig
from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware
from api.metrics_guard import MetricsGuardMiddleware
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Query
from fastapi.responses import StreamingResponse
@ -33,6 +37,7 @@ load_dotenv()
logger = logging.getLogger("uvicorn")
DEFAULT_BATCH_SIZE = 50
_rate_limit_config = RateLimitConfig.from_env()
def get_query_parameters(
@ -82,10 +87,18 @@ hist = meter.create_histogram(
app.add_middleware(
CORSMiddleware,
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST"],
allow_headers=["Authorization", "Content-Type"],
)
# Security middleware. Each add_middleware() call wraps everything registered
# before it, so the LAST call below is the OUTERMOST layer and sees the request
# first / the response last. Read the numbered steps in request order: 1, 2, 3.
# NOTE(review): CORSMiddleware is registered before these three, which puts it
# INSIDE them. 429/403 responses produced here therefore never pass through
# CORS (no Access-Control-* headers), and CORS preflight requests reach the
# rate limiter first — confirm this ordering is intended.
# 3. Rate limiting — innermost of the three; enforces per-user limits
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
# 2. Metrics guard — blocks unauthorized /metrics access
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
# 1. Audit logging — outermost, so it also records the 429s and 403s
#    generated by the two middlewares above
app.add_middleware(AuditLogMiddleware)
@app.get("/api/status")
async def get_status() -> dict[str, str]:
@ -100,6 +113,7 @@ async def get_listing(
limit: int = 5,
) -> dict[str, list]:
"""Get listings from the database."""
limit = min(limit, _rate_limit_config.listing_limit_cap)
repository = ListingRepository(engine)
result = await listing_service.get_listings(repository, limit=limit)
logger.info(f"Fetched {result.total_count} listings for {user.email}")
@ -113,6 +127,10 @@ async def get_listing_geojson(
limit: int | None = None,
) -> dict:
"""Get listings as GeoJSON for map display."""
if limit is not None:
limit = min(limit, _rate_limit_config.geojson_limit_cap)
else:
limit = _rate_limit_config.geojson_limit_cap
repository = ListingRepository(engine)
result = await export_service.export_to_geojson(
repository,
@ -204,6 +222,12 @@ async def stream_listing_geojson(
- batch: Array of GeoJSON features
- complete: Final message with total count
"""
batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap)
if limit is not None:
limit = min(limit, _rate_limit_config.geojson_stream_limit_cap)
else:
limit = _rate_limit_config.geojson_stream_limit_cap
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0:
generator = _stream_from_cache(query_parameters, batch_size, limit)

63
api/audit_middleware.py Normal file
View file

@ -0,0 +1,63 @@
"""Audit logging middleware for API requests."""
from __future__ import annotations
import logging
import time
import jwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
audit_logger = logging.getLogger("uvicorn.audit")
def _extract_identity(request: Request) -> str:
    """Return a log-safe identity label for the caller of *request*.

    The bearer token is decoded WITHOUT signature or expiry verification, so
    the value is an unauthenticated claim: good enough to label an audit
    record, never to be used for authorization.

    Args:
        request: Incoming request whose ``Authorization`` header is inspected.

    Returns:
        ``"anonymous"`` when no ``Bearer`` token is present,
        ``"invalid-token"`` when the token cannot be decoded,
        ``"unknown"`` when the payload has no string ``email`` claim,
        otherwise the ``email`` claim with every non-printable character
        replaced by ``"?"``.
    """
    auth_header = request.headers.get("authorization", "")
    if not auth_header.startswith("Bearer "):
        return "anonymous"

    token = auth_header[7:]
    try:
        payload = jwt.decode(
            token,
            options={"verify_signature": False, "verify_exp": False},
        )
    except jwt.PyJWTError:
        return "invalid-token"

    email = payload.get("email")
    if not isinstance(email, str):
        # Missing, null or non-string claims must not leak out as a non-str
        # value: callers format and log the result as text.
        return "unknown"

    # The claim comes from an unverified token, i.e. it is caller-controlled.
    # Neutralise CR/LF and other control characters so a crafted token cannot
    # forge additional audit log lines.
    return "".join(ch if ch.isprintable() else "?" for ch in email)
def _client_ip(request: Request) -> str:
    """Return a best-effort client address for audit records.

    The first entry of ``X-Forwarded-For`` wins when the header is present
    and non-empty; otherwise the connection's peer host is used, and
    ``"unknown"`` when the connection carries no client information.
    The header is taken at face value and is not validated.
    """
    xff_header = request.headers.get("x-forwarded-for")
    if not xff_header:
        peer = request.client
        if peer:
            return peer.host
        return "unknown"

    first_hop, _, _ = xff_header.partition(",")
    return first_hop.strip()
class AuditLogMiddleware(BaseHTTPMiddleware):
    """Write one audit record for every request under ``/api/``.

    Each record carries the HTTP method, path, query string, caller identity
    (unverified, see ``_extract_identity``), client IP, response status and
    the wall-clock duration in milliseconds. Requests outside ``/api/`` are
    passed through without logging.
    """

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Forward the request and log its outcome.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware stack.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Whatever the downstream stack raises is re-raised
                unchanged after the audit record has been written.
        """
        path = request.url.path
        if not path.startswith("/api/"):
            return await call_next(request)

        start = time.monotonic()
        identity = _extract_identity(request)
        ip = _client_ip(request)
        query = str(request.query_params) if request.query_params else ""

        def _record(status_code: int) -> None:
            """Emit the audit line for this request with the given status."""
            audit_logger.info(
                "method=%s path=%s query=%s user=%s ip=%s status=%d duration_ms=%.1f",
                request.method,
                path,
                query,
                identity,
                ip,
                status_code,
                (time.monotonic() - start) * 1000,
            )

        try:
            response = await call_next(request)
        except Exception:
            # Without this branch a crashing handler would leave no audit
            # trail at all. The error is logged as 500 because an unhandled
            # exception is normally turned into a 500 further out in the
            # stack; it is then re-raised so error handling is unchanged.
            _record(500)
            raise

        _record(response.status_code)
        return response

61
api/metrics_guard.py Normal file
View file

@ -0,0 +1,61 @@
"""IP allowlist middleware for the /metrics endpoint."""
from __future__ import annotations
import ipaddress
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from api.rate_limit_config import RateLimitConfig
logger = logging.getLogger("uvicorn")
def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
    """Convert a comma-separated allowlist string into network objects.

    Entries are trimmed and blank ones are dropped. Single addresses become
    host networks, and CIDRs with host bits set are accepted and masked
    (``strict=False``).

    Args:
        raw: Text such as ``"127.0.0.1, 10.0.0.0/8, ::1"``.

    Returns:
        The parsed networks, in the order they appear in *raw*.

    Raises:
        ValueError: If a non-blank entry is not a valid IP address or CIDR.
    """
    trimmed_entries = (piece.strip() for piece in raw.split(","))
    return [
        ipaddress.ip_network(item, strict=False)
        for item in trimmed_entries
        if item
    ]
def is_ip_allowed(
    ip_str: str,
    allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
) -> bool:
    """Return whether *ip_str* falls inside any of *allowed_networks*.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``), which is how an IPv4
    peer can be reported on a dual-stack socket, are also checked in their
    embedded IPv4 form so they match IPv4 allowlist entries.

    Args:
        ip_str: Textual IP address to test.
        allowed_networks: Networks that grant access.

    Returns:
        ``True`` when the address (or its embedded IPv4 address) belongs to
        at least one network; ``False`` otherwise, including when *ip_str*
        is not a valid IP address or the allowlist is empty.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        # Fail closed on anything that is not an IP literal (e.g. "unknown").
        return False

    candidates = [addr]
    # Only IPv6 addresses have ``ipv4_mapped``; it is None unless the address
    # is of the ::ffff:0:0/96 form.
    embedded_v4 = getattr(addr, "ipv4_mapped", None)
    if embedded_v4 is not None:
        candidates.append(embedded_v4)

    # Membership across IP versions is simply False, so mixed lists are safe.
    return any(
        candidate in network
        for network in allowed_networks
        for candidate in candidates
    )
class MetricsGuardMiddleware(BaseHTTPMiddleware):
    """Restrict paths starting with ``/metrics`` to an IP allowlist.

    Access is granted only when the connection's peer address AND every
    address listed in ``X-Forwarded-For`` are inside the allowlist.
    ``X-Forwarded-For`` is supplied by the caller, so trusting a single entry
    of it would let anyone claim to be ``127.0.0.1``; requiring all hops to be
    allowed means a forged entry can only ever remove access, never grant it.
    """

    def __init__(self, app, config: RateLimitConfig | None = None) -> None:  # type: ignore[no-untyped-def]
        """Parse the allowlist once at construction time.

        Args:
            app: The wrapped ASGI application.
            config: Configuration providing ``metrics_allowed_ips``; loaded
                from the environment when omitted.

        Raises:
            ValueError: If the configured allowlist contains an invalid entry.
        """
        super().__init__(app)
        cfg = config or RateLimitConfig.from_env()
        self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips)

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Reject ``/metrics`` requests from outside the allowlist with 403.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware stack.

        Returns:
            A 403 JSON response when any hop is not allowed, otherwise the
            downstream response.
        """
        if not request.url.path.startswith("/metrics"):
            return await call_next(request)

        hops: list[str] = []
        if request.client:
            # The socket peer is the only address the caller cannot choose.
            hops.append(request.client.host)
        forwarded = request.headers.get("x-forwarded-for", "")
        hops.extend(hop.strip() for hop in forwarded.split(",") if hop.strip())
        if not hops:
            # No address information at all: fail closed.
            hops.append("unknown")

        denied_ip = next(
            (hop for hop in hops if not is_ip_allowed(hop, self._allowed_networks)),
            None,
        )
        if denied_ip is not None:
            logger.warning("Metrics access denied for IP %s", denied_ip)
            return JSONResponse(status_code=403, content={"detail": "Forbidden"})

        return await call_next(request)

103
api/rate_limit_config.py Normal file
View file

@ -0,0 +1,103 @@
"""Rate limit and security configuration with environment variable loading."""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Self
@dataclass(frozen=True)
class EndpointLimit:
"""Rate limit for a single endpoint pattern."""
max_requests: int
window_seconds: int
@dataclass(frozen=True)
class RateLimitConfig:
"""Configuration for API rate limiting, export caps, and metrics access.
All values are configurable via environment variables with sensible defaults.
"""
# Per-endpoint rate limits
endpoint_limits: dict[str, EndpointLimit] = field(default_factory=lambda: {
"/api/listing": EndpointLimit(30, 60),
"/api/listing_geojson": EndpointLimit(10, 60),
"/api/listing_geojson/stream": EndpointLimit(10, 60),
"/api/refresh_listings": EndpointLimit(3, 300),
"/api/task_status": EndpointLimit(60, 60),
"/api/tasks_for_user": EndpointLimit(30, 60),
"/api/cancel_task": EndpointLimit(10, 60),
"/api/clear_all_tasks": EndpointLimit(5, 60),
"/api/get_districts": EndpointLimit(20, 60),
"/api/passkey": EndpointLimit(10, 60),
})
# Bulk export caps
listing_limit_cap: int = 100
geojson_limit_cap: int = 5_000
geojson_stream_limit_cap: int = 10_000
geojson_stream_batch_size_cap: int = 200
# Redis DB for rate limit counters
rate_limit_redis_db: int = 3
# Metrics endpoint IP allowlist (comma-separated CIDRs)
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
@classmethod
def from_env(cls) -> Self:
"""Load configuration from environment variables.
Environment variables:
RATE_LIMIT_LISTING: /api/listing limit (default: 30/60s)
RATE_LIMIT_GEOJSON: /api/listing_geojson limit (default: 10/60s)
RATE_LIMIT_GEOJSON_STREAM: /api/listing_geojson/stream limit (default: 10/60s)
RATE_LIMIT_REFRESH: /api/refresh_listings limit (default: 3/300s)
RATE_LIMIT_TASK_STATUS: /api/task_status limit (default: 60/60s)
RATE_LIMIT_TASKS_FOR_USER: /api/tasks_for_user limit (default: 30/60s)
RATE_LIMIT_CANCEL_TASK: /api/cancel_task limit (default: 10/60s)
RATE_LIMIT_CLEAR_TASKS: /api/clear_all_tasks limit (default: 5/60s)
RATE_LIMIT_DISTRICTS: /api/get_districts limit (default: 20/60s)
RATE_LIMIT_PASSKEY: /api/passkey/* limit (default: 10/60s)
EXPORT_LISTING_LIMIT_CAP: Max listings per request (default: 100)
EXPORT_GEOJSON_LIMIT_CAP: Max GeoJSON features per request (default: 5000)
EXPORT_GEOJSON_STREAM_LIMIT_CAP: Max streamed features (default: 10000)
EXPORT_GEOJSON_STREAM_BATCH_CAP: Max stream batch size (default: 200)
RATE_LIMIT_REDIS_DB: Redis DB number for counters (default: 3)
METRICS_ALLOWED_IPS: Comma-separated CIDRs for /metrics (default: private ranges)
"""
def _parse_limit(env_var: str, default_requests: int, default_window: int) -> EndpointLimit:
raw = os.environ.get(env_var)
if raw:
parts = raw.split("/")
if len(parts) == 2:
return EndpointLimit(int(parts[0]), int(parts[1]))
return EndpointLimit(default_requests, default_window)
return cls(
endpoint_limits={
"/api/listing": _parse_limit("RATE_LIMIT_LISTING", 30, 60),
"/api/listing_geojson": _parse_limit("RATE_LIMIT_GEOJSON", 10, 60),
"/api/listing_geojson/stream": _parse_limit("RATE_LIMIT_GEOJSON_STREAM", 10, 60),
"/api/refresh_listings": _parse_limit("RATE_LIMIT_REFRESH", 3, 300),
"/api/task_status": _parse_limit("RATE_LIMIT_TASK_STATUS", 60, 60),
"/api/tasks_for_user": _parse_limit("RATE_LIMIT_TASKS_FOR_USER", 30, 60),
"/api/cancel_task": _parse_limit("RATE_LIMIT_CANCEL_TASK", 10, 60),
"/api/clear_all_tasks": _parse_limit("RATE_LIMIT_CLEAR_TASKS", 5, 60),
"/api/get_districts": _parse_limit("RATE_LIMIT_DISTRICTS", 20, 60),
"/api/passkey": _parse_limit("RATE_LIMIT_PASSKEY", 10, 60),
},
listing_limit_cap=int(os.environ.get("EXPORT_LISTING_LIMIT_CAP", "100")),
geojson_limit_cap=int(os.environ.get("EXPORT_GEOJSON_LIMIT_CAP", "5000")),
geojson_stream_limit_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_LIMIT_CAP", "10000")),
geojson_stream_batch_size_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_BATCH_CAP", "200")),
rate_limit_redis_db=int(os.environ.get("RATE_LIMIT_REDIS_DB", "3")),
metrics_allowed_ips=os.environ.get(
"METRICS_ALLOWED_IPS",
"127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1",
),
)

132
api/rate_limiter.py Normal file
View file

@ -0,0 +1,132 @@
"""Per-user rate limiting middleware using Redis fixed-window counters."""
from __future__ import annotations
import logging
import os
from urllib.parse import urlparse, urlunparse
import jwt
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from api.rate_limit_config import EndpointLimit, RateLimitConfig
logger = logging.getLogger("uvicorn")
# Paths exempt from rate limiting
EXEMPT_PATHS = {"/api/status", "/metrics"}
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis:  # type: ignore[type-arg]
    """Build the Redis client that stores rate limit counters.

    The connection details come from ``CELERY_BROKER_URL`` (falling back to
    ``redis://localhost:6379/0``); only the path component is swapped so the
    client targets ``config.rate_limit_redis_db``. Responses are decoded to
    ``str``.
    """
    broker = urlparse(os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"))
    counters_location = broker._replace(path=f"/{config.rate_limit_redis_db}")
    return redis.from_url(urlunparse(counters_location), decode_responses=True)
def _extract_user_email(request: Request) -> str | None:
    """Read the ``email`` claim from the request's bearer token, if any.

    The token is decoded with signature and expiry checks disabled; the
    value is meant for rate-limit keying only, never for authorization.

    Returns:
        The ``email`` claim, or ``None`` when there is no ``Bearer`` token,
        the token cannot be decoded, or the claim is absent.
    """
    # NOTE(review): because the signature is not checked, the claim is chosen
    # by the caller — confirm that keying counters on it is acceptable.
    scheme, separator, token = request.headers.get("authorization", "").partition(" ")
    if scheme != "Bearer" or not separator:
        return None

    try:
        claims = jwt.decode(
            token,
            options={"verify_signature": False, "verify_exp": False},
        )
    except jwt.PyJWTError:
        return None
    return claims.get("email")
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
    """Look up the rate limit that applies to *path*.

    Any path beginning with ``/api/passkey`` resolves to the ``/api/passkey``
    entry; every other path must match a configured key exactly.

    Returns:
        The matching ``EndpointLimit``, or ``None`` when the path has no
        configured limit.
    """
    lookup_key = "/api/passkey" if path.startswith("/api/passkey") else path
    return config.endpoint_limits.get(lookup_key)
def _client_ip(request: Request) -> str:
    """Pick the address used to key counters for callers without an email.

    Uses the leftmost ``X-Forwarded-For`` entry when the header is present
    and non-empty, otherwise the connection's peer host, otherwise
    ``"unknown"``. The header value is not validated.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        leftmost = forwarded_for.split(",", 1)[0]
        return leftmost.strip()

    if request.client:
        return request.client.host
    return "unknown"
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware enforcing per-user fixed-window rate limits via Redis.

    Counters are keyed by ``ratelimit:{identity}:{path}`` where the identity
    is the email claim of the bearer token, or the client IP when there is
    none. The limiter fails open: if Redis is unreachable, requests are let
    through unlimited.
    """

    def __init__(self, app, config: RateLimitConfig | None = None) -> None:  # type: ignore[no-untyped-def]
        """Connect to Redis, or disable limiting when that is impossible.

        Args:
            app: The wrapped ASGI application.
            config: Limits and Redis DB number; loaded from the environment
                when omitted.
        """
        super().__init__(app)
        self.config = config or RateLimitConfig.from_env()
        try:
            self._redis = _get_rate_limit_redis(self.config)
            self._redis.ping()
        except redis.RedisError:
            logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
            self._redis = None
        except ValueError as exc:
            # redis.from_url() rejects URLs that are not Redis URLs (for
            # example a non-Redis CELERY_BROKER_URL) with ValueError. Treat it
            # like an unreachable server instead of crashing the application.
            logger.warning("Rate limiter: invalid Redis URL (%s), will fail open", exc)
            self._redis = None

    def _count_request(self, redis_key: str, window_seconds: int) -> tuple[int, int]:
        """Increment the counter at *redis_key* and return ``(count, ttl)``.

        The key's expiry is set to *window_seconds* whenever the key has none
        (TTL of -1), which starts the window on its first request.

        Raises:
            redis.RedisError: If any Redis command fails.
        """
        pipe = self._redis.pipeline(transaction=True)
        pipe.incr(redis_key)
        pipe.ttl(redis_key)
        current_count, ttl = pipe.execute()
        if ttl == -1:
            self._redis.expire(redis_key, window_seconds)
            ttl = window_seconds
        return current_count, ttl

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Count the request and either forward it or answer with 429.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware stack.

        Returns:
            A 429 JSON response with ``Retry-After`` and ``X-RateLimit-*``
            headers when the limit is exceeded; otherwise the downstream
            response, annotated with ``X-RateLimit-*`` headers when the
            request was counted.
        """
        path = request.url.path

        # Skip exempt paths
        if path in EXEMPT_PATHS:
            return await call_next(request)

        limit = _match_endpoint(path, self.config)
        # No configured limit, or Redis was unavailable at startup: fail open.
        if limit is None or self._redis is None:
            return await call_next(request)

        identity = _extract_user_email(request) or _client_ip(request)
        redis_key = f"ratelimit:{identity}:{path}"

        # Only the Redis calls are guarded. The downstream call must stay
        # outside this try: otherwise a RedisError raised by the handler
        # itself would be mistaken for a limiter failure and the request
        # would be executed a second time.
        # NOTE(review): these are synchronous Redis calls made from an async
        # method — confirm the added latency on the event loop is acceptable.
        try:
            current_count, ttl = self._count_request(redis_key, limit.window_seconds)
        except redis.RedisError as e:
            logger.warning("Rate limiter Redis error, failing open: %s", e)
            return await call_next(request)

        if current_count > limit.max_requests:
            retry_after = max(1, ttl)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(limit.max_requests),
                    "X-RateLimit-Remaining": "0",
                },
            )

        remaining = max(0, limit.max_requests - current_count)
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response