Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis sliding window, IP-restricted /metrics endpoint, audit logging of all requests, CORS tightening, and export caps on listing/geojson endpoints.
This commit is contained in:
parent
08ac72bbfc
commit
87b5bd8676
8 changed files with 756 additions and 2 deletions
28
api/app.py
28
api/app.py
|
|
@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional
|
|||
from api.auth import get_current_user
|
||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||
from api.passkey_routes import passkey_router
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
from api.rate_limiter import RateLimitMiddleware
|
||||
from api.audit_middleware import AuditLogMiddleware
|
||||
from api.metrics_guard import MetricsGuardMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, FastAPI, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -33,6 +37,7 @@ load_dotenv()
|
|||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
DEFAULT_BATCH_SIZE = 50
|
||||
_rate_limit_config = RateLimitConfig.from_env()
|
||||
|
||||
|
||||
def get_query_parameters(
|
||||
|
|
@ -82,10 +87,18 @@ hist = meter.create_histogram(
|
|||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
# Security middleware (added bottom-to-top; last added = outermost)
|
||||
# 3. Rate limiting — enforces per-user limits
|
||||
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
||||
# 2. Metrics guard — blocks unauthorized /metrics access
|
||||
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
||||
# 1. Audit logging — logs everything including 429s and 403s
|
||||
app.add_middleware(AuditLogMiddleware)
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status() -> dict[str, str]:
|
||||
|
|
@ -100,6 +113,7 @@ async def get_listing(
|
|||
limit: int = 5,
|
||||
) -> dict[str, list]:
|
||||
"""Get listings from the database."""
|
||||
limit = min(limit, _rate_limit_config.listing_limit_cap)
|
||||
repository = ListingRepository(engine)
|
||||
result = await listing_service.get_listings(repository, limit=limit)
|
||||
logger.info(f"Fetched {result.total_count} listings for {user.email}")
|
||||
|
|
@ -113,6 +127,10 @@ async def get_listing_geojson(
|
|||
limit: int | None = None,
|
||||
) -> dict:
|
||||
"""Get listings as GeoJSON for map display."""
|
||||
if limit is not None:
|
||||
limit = min(limit, _rate_limit_config.geojson_limit_cap)
|
||||
else:
|
||||
limit = _rate_limit_config.geojson_limit_cap
|
||||
repository = ListingRepository(engine)
|
||||
result = await export_service.export_to_geojson(
|
||||
repository,
|
||||
|
|
@ -204,6 +222,12 @@ async def stream_listing_geojson(
|
|||
- batch: Array of GeoJSON features
|
||||
- complete: Final message with total count
|
||||
"""
|
||||
batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap)
|
||||
if limit is not None:
|
||||
limit = min(limit, _rate_limit_config.geojson_stream_limit_cap)
|
||||
else:
|
||||
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||
|
||||
cached_count = get_cached_count(query_parameters)
|
||||
if cached_count is not None and cached_count > 0:
|
||||
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
||||
|
|
|
|||
63
api/audit_middleware.py
Normal file
63
api/audit_middleware.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
"""Audit logging middleware for API requests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import jwt
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
audit_logger = logging.getLogger("uvicorn.audit")
|
||||
|
||||
|
||||
def _extract_identity(request: Request) -> str:
|
||||
"""Extract user email from JWT for audit logging."""
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return "anonymous"
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
|
||||
return payload.get("email", "unknown")
|
||||
except jwt.PyJWTError:
|
||||
return "invalid-token"
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
"""Best-effort client IP."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
client = request.client
|
||||
return client.host if client else "unknown"
|
||||
|
||||
|
||||
class AuditLogMiddleware(BaseHTTPMiddleware):
|
||||
"""Logs all /api/ requests with method, path, user, IP, status, and duration."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
path = request.url.path
|
||||
if not path.startswith("/api/"):
|
||||
return await call_next(request)
|
||||
|
||||
start = time.monotonic()
|
||||
identity = _extract_identity(request)
|
||||
ip = _client_ip(request)
|
||||
query = str(request.query_params) if request.query_params else ""
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
duration_ms = (time.monotonic() - start) * 1000
|
||||
audit_logger.info(
|
||||
"method=%s path=%s query=%s user=%s ip=%s status=%d duration_ms=%.1f",
|
||||
request.method,
|
||||
path,
|
||||
query,
|
||||
identity,
|
||||
ip,
|
||||
response.status_code,
|
||||
duration_ms,
|
||||
)
|
||||
return response
|
||||
61
api/metrics_guard.py
Normal file
61
api/metrics_guard.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""IP allowlist middleware for the /metrics endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
|
||||
"""Parse a comma-separated string of IPs/CIDRs into network objects."""
|
||||
networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
|
||||
for entry in raw.split(","):
|
||||
entry = entry.strip()
|
||||
if not entry:
|
||||
continue
|
||||
networks.append(ipaddress.ip_network(entry, strict=False))
|
||||
return networks
|
||||
|
||||
|
||||
def is_ip_allowed(
|
||||
ip_str: str,
|
||||
allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
|
||||
) -> bool:
|
||||
"""Check whether an IP address falls within any of the allowed networks."""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
return any(addr in network for network in allowed_networks)
|
||||
|
||||
|
||||
class MetricsGuardMiddleware(BaseHTTPMiddleware):
|
||||
"""Restricts /metrics access to an IP allowlist."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(app)
|
||||
cfg = config or RateLimitConfig.from_env()
|
||||
self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
if not request.url.path.startswith("/metrics"):
|
||||
return await call_next(request)
|
||||
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
client_ip = forwarded.split(",")[0].strip()
|
||||
else:
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
if not is_ip_allowed(client_ip, self._allowed_networks):
|
||||
logger.warning("Metrics access denied for IP %s", client_ip)
|
||||
return JSONResponse(status_code=403, content={"detail": "Forbidden"})
|
||||
|
||||
return await call_next(request)
|
||||
103
api/rate_limit_config.py
Normal file
103
api/rate_limit_config.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Rate limit and security configuration with environment variable loading."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EndpointLimit:
|
||||
"""Rate limit for a single endpoint pattern."""
|
||||
|
||||
max_requests: int
|
||||
window_seconds: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Configuration for API rate limiting, export caps, and metrics access.
|
||||
|
||||
All values are configurable via environment variables with sensible defaults.
|
||||
"""
|
||||
|
||||
# Per-endpoint rate limits
|
||||
endpoint_limits: dict[str, EndpointLimit] = field(default_factory=lambda: {
|
||||
"/api/listing": EndpointLimit(30, 60),
|
||||
"/api/listing_geojson": EndpointLimit(10, 60),
|
||||
"/api/listing_geojson/stream": EndpointLimit(10, 60),
|
||||
"/api/refresh_listings": EndpointLimit(3, 300),
|
||||
"/api/task_status": EndpointLimit(60, 60),
|
||||
"/api/tasks_for_user": EndpointLimit(30, 60),
|
||||
"/api/cancel_task": EndpointLimit(10, 60),
|
||||
"/api/clear_all_tasks": EndpointLimit(5, 60),
|
||||
"/api/get_districts": EndpointLimit(20, 60),
|
||||
"/api/passkey": EndpointLimit(10, 60),
|
||||
})
|
||||
|
||||
# Bulk export caps
|
||||
listing_limit_cap: int = 100
|
||||
geojson_limit_cap: int = 5_000
|
||||
geojson_stream_limit_cap: int = 10_000
|
||||
geojson_stream_batch_size_cap: int = 200
|
||||
|
||||
# Redis DB for rate limit counters
|
||||
rate_limit_redis_db: int = 3
|
||||
|
||||
# Metrics endpoint IP allowlist (comma-separated CIDRs)
|
||||
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> Self:
|
||||
"""Load configuration from environment variables.
|
||||
|
||||
Environment variables:
|
||||
RATE_LIMIT_LISTING: /api/listing limit (default: 30/60s)
|
||||
RATE_LIMIT_GEOJSON: /api/listing_geojson limit (default: 10/60s)
|
||||
RATE_LIMIT_GEOJSON_STREAM: /api/listing_geojson/stream limit (default: 10/60s)
|
||||
RATE_LIMIT_REFRESH: /api/refresh_listings limit (default: 3/300s)
|
||||
RATE_LIMIT_TASK_STATUS: /api/task_status limit (default: 60/60s)
|
||||
RATE_LIMIT_TASKS_FOR_USER: /api/tasks_for_user limit (default: 30/60s)
|
||||
RATE_LIMIT_CANCEL_TASK: /api/cancel_task limit (default: 10/60s)
|
||||
RATE_LIMIT_CLEAR_TASKS: /api/clear_all_tasks limit (default: 5/60s)
|
||||
RATE_LIMIT_DISTRICTS: /api/get_districts limit (default: 20/60s)
|
||||
RATE_LIMIT_PASSKEY: /api/passkey/* limit (default: 10/60s)
|
||||
EXPORT_LISTING_LIMIT_CAP: Max listings per request (default: 100)
|
||||
EXPORT_GEOJSON_LIMIT_CAP: Max GeoJSON features per request (default: 5000)
|
||||
EXPORT_GEOJSON_STREAM_LIMIT_CAP: Max streamed features (default: 10000)
|
||||
EXPORT_GEOJSON_STREAM_BATCH_CAP: Max stream batch size (default: 200)
|
||||
RATE_LIMIT_REDIS_DB: Redis DB number for counters (default: 3)
|
||||
METRICS_ALLOWED_IPS: Comma-separated CIDRs for /metrics (default: private ranges)
|
||||
"""
|
||||
|
||||
def _parse_limit(env_var: str, default_requests: int, default_window: int) -> EndpointLimit:
|
||||
raw = os.environ.get(env_var)
|
||||
if raw:
|
||||
parts = raw.split("/")
|
||||
if len(parts) == 2:
|
||||
return EndpointLimit(int(parts[0]), int(parts[1]))
|
||||
return EndpointLimit(default_requests, default_window)
|
||||
|
||||
return cls(
|
||||
endpoint_limits={
|
||||
"/api/listing": _parse_limit("RATE_LIMIT_LISTING", 30, 60),
|
||||
"/api/listing_geojson": _parse_limit("RATE_LIMIT_GEOJSON", 10, 60),
|
||||
"/api/listing_geojson/stream": _parse_limit("RATE_LIMIT_GEOJSON_STREAM", 10, 60),
|
||||
"/api/refresh_listings": _parse_limit("RATE_LIMIT_REFRESH", 3, 300),
|
||||
"/api/task_status": _parse_limit("RATE_LIMIT_TASK_STATUS", 60, 60),
|
||||
"/api/tasks_for_user": _parse_limit("RATE_LIMIT_TASKS_FOR_USER", 30, 60),
|
||||
"/api/cancel_task": _parse_limit("RATE_LIMIT_CANCEL_TASK", 10, 60),
|
||||
"/api/clear_all_tasks": _parse_limit("RATE_LIMIT_CLEAR_TASKS", 5, 60),
|
||||
"/api/get_districts": _parse_limit("RATE_LIMIT_DISTRICTS", 20, 60),
|
||||
"/api/passkey": _parse_limit("RATE_LIMIT_PASSKEY", 10, 60),
|
||||
},
|
||||
listing_limit_cap=int(os.environ.get("EXPORT_LISTING_LIMIT_CAP", "100")),
|
||||
geojson_limit_cap=int(os.environ.get("EXPORT_GEOJSON_LIMIT_CAP", "5000")),
|
||||
geojson_stream_limit_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_LIMIT_CAP", "10000")),
|
||||
geojson_stream_batch_size_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_BATCH_CAP", "200")),
|
||||
rate_limit_redis_db=int(os.environ.get("RATE_LIMIT_REDIS_DB", "3")),
|
||||
metrics_allowed_ips=os.environ.get(
|
||||
"METRICS_ALLOWED_IPS",
|
||||
"127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1",
|
||||
),
|
||||
)
|
||||
132
api/rate_limiter.py
Normal file
132
api/rate_limiter.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Per-user rate limiting middleware using Redis fixed-window counters."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
# Paths exempt from rate limiting
|
||||
EXEMPT_PATHS = {"/api/status", "/metrics"}
|
||||
|
||||
|
||||
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis: # type: ignore[type-arg]
|
||||
"""Get a Redis client for rate limit counters, using the configured DB."""
|
||||
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
parsed = urlparse(broker_url)
|
||||
url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}"))
|
||||
return redis.from_url(url, decode_responses=True)
|
||||
|
||||
|
||||
def _extract_user_email(request: Request) -> str | None:
|
||||
"""Extract user email from JWT without full verification.
|
||||
|
||||
Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
|
||||
Only used for rate-limit keying, not for authorization.
|
||||
"""
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
|
||||
return payload.get("email")
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
|
||||
|
||||
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
|
||||
"""Find the rate limit for a request path.
|
||||
|
||||
Passkey routes (/api/passkey/*) all share the /api/passkey limit.
|
||||
"""
|
||||
if path.startswith("/api/passkey"):
|
||||
return config.endpoint_limits.get("/api/passkey")
|
||||
return config.endpoint_limits.get(path)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
"""Best-effort client IP from X-Forwarded-For or connection."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
client = request.client
|
||||
return client.host if client else "unknown"
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(app)
|
||||
self.config = config or RateLimitConfig.from_env()
|
||||
try:
|
||||
self._redis = _get_rate_limit_redis(self.config)
|
||||
self._redis.ping()
|
||||
except redis.RedisError:
|
||||
logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
|
||||
self._redis = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
path = request.url.path
|
||||
|
||||
# Skip exempt paths
|
||||
if path in EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _match_endpoint(path, self.config)
|
||||
if limit is None:
|
||||
return await call_next(request)
|
||||
|
||||
# Determine identity for the counter key
|
||||
identity = _extract_user_email(request) or _client_ip(request)
|
||||
|
||||
# If Redis is unavailable, fail open
|
||||
if self._redis is None:
|
||||
return await call_next(request)
|
||||
|
||||
redis_key = f"ratelimit:{identity}:{path}"
|
||||
try:
|
||||
pipe = self._redis.pipeline(transaction=True)
|
||||
pipe.incr(redis_key)
|
||||
pipe.ttl(redis_key)
|
||||
result = pipe.execute()
|
||||
current_count: int = result[0]
|
||||
ttl: int = result[1]
|
||||
|
||||
# Set expiry on first request in window
|
||||
if ttl == -1:
|
||||
self._redis.expire(redis_key, limit.window_seconds)
|
||||
ttl = limit.window_seconds
|
||||
|
||||
remaining = max(0, limit.max_requests - current_count)
|
||||
|
||||
if current_count > limit.max_requests:
|
||||
retry_after = max(1, ttl)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limit.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Rate limiter Redis error, failing open: {e}")
|
||||
return await call_next(request)
|
||||
Loading…
Add table
Add a link
Reference in a new issue