Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis fixed-window counters, an IP-restricted /metrics endpoint, audit logging of all /api/ requests, CORS tightening, and export caps on the listing/GeoJSON endpoints.
This commit is contained in:
parent
08ac72bbfc
commit
87b5bd8676
8 changed files with 756 additions and 2 deletions
22
.env.sample
22
.env.sample
|
|
@ -41,3 +41,25 @@ JWT_SECRET=change-me-in-production # HMAC secret for HS256 signing
|
||||||
JWT_ALGORITHM=HS256 # JWT signing algorithm
|
JWT_ALGORITHM=HS256 # JWT signing algorithm
|
||||||
JWT_EXPIRATION_HOURS=24 # Token expiry in hours
|
JWT_EXPIRATION_HOURS=24 # Token expiry in hours
|
||||||
JWT_ISSUER=wrongmove # JWT issuer claim
|
JWT_ISSUER=wrongmove # JWT issuer claim
|
||||||
|
|
||||||
|
# API rate limiting (format: max_requests/window_seconds)
|
||||||
|
# RATE_LIMIT_LISTING=30/60 # /api/listing: 30 req per 60s
|
||||||
|
# RATE_LIMIT_GEOJSON=10/60 # /api/listing_geojson: 10 req per 60s
|
||||||
|
# RATE_LIMIT_GEOJSON_STREAM=10/60 # /api/listing_geojson/stream: 10 req per 60s
|
||||||
|
# RATE_LIMIT_REFRESH=3/300 # /api/refresh_listings: 3 req per 5min
|
||||||
|
# RATE_LIMIT_TASK_STATUS=60/60 # /api/task_status: 60 req per 60s
|
||||||
|
# RATE_LIMIT_TASKS_FOR_USER=30/60 # /api/tasks_for_user: 30 req per 60s
|
||||||
|
# RATE_LIMIT_CANCEL_TASK=10/60 # /api/cancel_task: 10 req per 60s
|
||||||
|
# RATE_LIMIT_CLEAR_TASKS=5/60 # /api/clear_all_tasks: 5 req per 60s
|
||||||
|
# RATE_LIMIT_DISTRICTS=20/60 # /api/get_districts: 20 req per 60s
|
||||||
|
# RATE_LIMIT_PASSKEY=10/60 # /api/passkey/*: 10 req per 60s
|
||||||
|
RATE_LIMIT_REDIS_DB=3 # Redis DB for rate limit counters
|
||||||
|
|
||||||
|
# Bulk export caps
|
||||||
|
EXPORT_LISTING_LIMIT_CAP=100 # Max listings per /api/listing request
|
||||||
|
EXPORT_GEOJSON_LIMIT_CAP=5000 # Max features per /api/listing_geojson request
|
||||||
|
EXPORT_GEOJSON_STREAM_LIMIT_CAP=10000 # Max features per /api/listing_geojson/stream
|
||||||
|
EXPORT_GEOJSON_STREAM_BATCH_CAP=200 # Max batch size for streaming
|
||||||
|
|
||||||
|
# Metrics endpoint access control (comma-separated IPs/CIDRs)
|
||||||
|
METRICS_ALLOWED_IPS=127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1
|
||||||
|
|
|
||||||
28
api/app.py
28
api/app.py
|
|
@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional
|
||||||
from api.auth import get_current_user
|
from api.auth import get_current_user
|
||||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||||
from api.passkey_routes import passkey_router
|
from api.passkey_routes import passkey_router
|
||||||
|
from api.rate_limit_config import RateLimitConfig
|
||||||
|
from api.rate_limiter import RateLimitMiddleware
|
||||||
|
from api.audit_middleware import AuditLogMiddleware
|
||||||
|
from api.metrics_guard import MetricsGuardMiddleware
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import Depends, FastAPI, Query
|
from fastapi import Depends, FastAPI, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
@ -33,6 +37,7 @@ load_dotenv()
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
DEFAULT_BATCH_SIZE = 50
|
DEFAULT_BATCH_SIZE = 50
|
||||||
|
_rate_limit_config = RateLimitConfig.from_env()
|
||||||
|
|
||||||
|
|
||||||
def get_query_parameters(
|
def get_query_parameters(
|
||||||
|
|
@ -82,10 +87,18 @@ hist = meter.create_histogram(
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||||
allow_methods=["*"],
|
allow_methods=["GET", "POST"],
|
||||||
allow_headers=["*"],
|
allow_headers=["Authorization", "Content-Type"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Security middleware (added bottom-to-top; last added = outermost)
|
||||||
|
# 3. Rate limiting — enforces per-user limits
|
||||||
|
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
||||||
|
# 2. Metrics guard — blocks unauthorized /metrics access
|
||||||
|
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
||||||
|
# 1. Audit logging — logs everything including 429s and 403s
|
||||||
|
app.add_middleware(AuditLogMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/status")
|
@app.get("/api/status")
|
||||||
async def get_status() -> dict[str, str]:
|
async def get_status() -> dict[str, str]:
|
||||||
|
|
@ -100,6 +113,7 @@ async def get_listing(
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
) -> dict[str, list]:
|
) -> dict[str, list]:
|
||||||
"""Get listings from the database."""
|
"""Get listings from the database."""
|
||||||
|
limit = min(limit, _rate_limit_config.listing_limit_cap)
|
||||||
repository = ListingRepository(engine)
|
repository = ListingRepository(engine)
|
||||||
result = await listing_service.get_listings(repository, limit=limit)
|
result = await listing_service.get_listings(repository, limit=limit)
|
||||||
logger.info(f"Fetched {result.total_count} listings for {user.email}")
|
logger.info(f"Fetched {result.total_count} listings for {user.email}")
|
||||||
|
|
@ -113,6 +127,10 @@ async def get_listing_geojson(
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Get listings as GeoJSON for map display."""
|
"""Get listings as GeoJSON for map display."""
|
||||||
|
if limit is not None:
|
||||||
|
limit = min(limit, _rate_limit_config.geojson_limit_cap)
|
||||||
|
else:
|
||||||
|
limit = _rate_limit_config.geojson_limit_cap
|
||||||
repository = ListingRepository(engine)
|
repository = ListingRepository(engine)
|
||||||
result = await export_service.export_to_geojson(
|
result = await export_service.export_to_geojson(
|
||||||
repository,
|
repository,
|
||||||
|
|
@ -204,6 +222,12 @@ async def stream_listing_geojson(
|
||||||
- batch: Array of GeoJSON features
|
- batch: Array of GeoJSON features
|
||||||
- complete: Final message with total count
|
- complete: Final message with total count
|
||||||
"""
|
"""
|
||||||
|
batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap)
|
||||||
|
if limit is not None:
|
||||||
|
limit = min(limit, _rate_limit_config.geojson_stream_limit_cap)
|
||||||
|
else:
|
||||||
|
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||||
|
|
||||||
cached_count = get_cached_count(query_parameters)
|
cached_count = get_cached_count(query_parameters)
|
||||||
if cached_count is not None and cached_count > 0:
|
if cached_count is not None and cached_count > 0:
|
||||||
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
||||||
|
|
|
||||||
63
api/audit_middleware.py
Normal file
63
api/audit_middleware.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Audit logging middleware for API requests."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
audit_logger = logging.getLogger("uvicorn.audit")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_identity(request: Request) -> str:
    """Return a caller label for audit log lines.

    The bearer token is decoded with signature and expiry checks disabled,
    so the label is an unauthenticated hint and must not be used for
    authorization decisions.

    Returns:
        ``"anonymous"`` when there is no ``Bearer`` authorization header,
        ``"invalid-token"`` when the token cannot be decoded, otherwise the
        token's ``email`` claim (``"unknown"`` when the claim is absent).
    """
    scheme = "Bearer "
    header_value = request.headers.get("authorization", "")
    if header_value.startswith(scheme):
        raw_token = header_value[len(scheme):]
        try:
            claims = jwt.decode(
                raw_token,
                options={"verify_signature": False, "verify_exp": False},
            )
        except jwt.PyJWTError:
            return "invalid-token"
        return claims.get("email", "unknown")
    return "anonymous"
|
||||||
|
|
||||||
|
|
||||||
|
def _client_ip(request: Request) -> str:
    """Best-effort client IP for audit log lines.

    Uses the first entry of ``X-Forwarded-For`` when it is non-empty and
    otherwise the address of the directly connected peer.  The header is
    supplied by the client (or an upstream proxy), so the value is only a
    hint.

    Returns:
        The address as a string, or ``"unknown"`` when neither the header
        nor the connection provides one.
    """
    forwarded = request.headers.get("x-forwarded-for") or ""
    first_hop = forwarded.split(",", 1)[0].strip()
    # A header such as ", 1.2.3.4" has a blank first entry; fall back to the
    # peer address instead of logging an empty string.
    if first_hop:
        return first_hop
    client = request.client
    return client.host if client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogMiddleware(BaseHTTPMiddleware):
    """Logs all /api/ requests with method, path, user, IP, status, and duration.

    Requests whose path does not start with ``/api/`` are passed through
    without a log line.  The ``user`` field comes from an unverified JWT
    decode and the ``ip`` field may come from ``X-Forwarded-For``; both are
    hints supplied by the caller.
    """

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Forward the request and emit one audit line for ``/api/`` paths.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware/app stack.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Anything raised downstream is re-raised after being
                audited.
        """
        path = request.url.path
        if not path.startswith("/api/"):
            return await call_next(request)

        start = time.monotonic()
        # Capture request details before dispatching downstream.
        identity = _extract_identity(request)
        ip = _client_ip(request)
        query = str(request.query_params) if request.query_params else ""

        def _record(status_code: int) -> None:
            """Write the audit line with the elapsed time since ``start``."""
            audit_logger.info(
                "method=%s path=%s query=%s user=%s ip=%s status=%d duration_ms=%.1f",
                request.method,
                path,
                query,
                identity,
                ip,
                status_code,
                (time.monotonic() - start) * 1000,
            )

        try:
            response = await call_next(request)
        except Exception:
            # Without this branch a handler that raises would leave no audit
            # trail.  The status is recorded as 500 -- presumably what the
            # client receives for an unhandled error; confirm against the
            # app's exception handlers.
            _record(500)
            raise

        _record(response.status_code)
        return response
|
||||||
61
api/metrics_guard.py
Normal file
61
api/metrics_guard.py
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
"""IP allowlist middleware for the /metrics endpoint."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
from api.rate_limit_config import RateLimitConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
    """Convert a comma-separated allowlist string into network objects.

    Entries are whitespace-trimmed and blank entries (for example from a
    trailing comma) are skipped.  A bare address becomes a single-host
    network, and host bits in a CIDR are tolerated (``strict=False``).

    Args:
        raw: Comma-separated IP addresses and/or CIDR blocks.

    Returns:
        The parsed networks, in input order.

    Raises:
        ValueError: If a non-blank entry is not a valid address or CIDR.
    """
    trimmed = (piece.strip() for piece in raw.split(","))
    return [ipaddress.ip_network(item, strict=False) for item in trimmed if item]
|
||||||
|
|
||||||
|
|
||||||
|
def is_ip_allowed(
    ip_str: str,
    allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
) -> bool:
    """Check whether an IP address falls within any of the allowed networks.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are additionally matched
    by their embedded IPv4 address, so an IPv4 allowlist entry still applies
    when the peer is reported in mapped form.

    Args:
        ip_str: Textual IPv4 or IPv6 address.
        allowed_networks: Networks to test membership against.

    Returns:
        ``True`` if the address (or its embedded IPv4 form) is inside at
        least one network; ``False`` otherwise, including when ``ip_str`` is
        not a valid address.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    candidates = [addr]
    # Only IPv6 addresses expose ``ipv4_mapped``; it is None unless the
    # address is in the ::ffff:0:0/96 range.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        candidates.append(mapped)

    return any(
        candidate in network
        for candidate in candidates
        for network in allowed_networks
    )
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsGuardMiddleware(BaseHTTPMiddleware):
    """Restricts /metrics access to an IP allowlist.

    Paths that do not start with ``/metrics`` are passed through untouched.
    For guarded paths, every address listed in ``X-Forwarded-For`` must be in
    the allowlist; without the header the connection's peer address is
    checked instead.

    NOTE(review): when the application is reachable without a proxy that
    appends to ``X-Forwarded-For``, the header is entirely client-controlled
    and this check can still be satisfied by a forged value.  Confirm the
    deployment always fronts the app with such a proxy.
    """

    def __init__(self, app, config: RateLimitConfig | None = None) -> None:  # type: ignore[no-untyped-def]
        """Parse the allowlist once at construction time.

        Args:
            app: The wrapped ASGI application.
            config: Source of ``metrics_allowed_ips``; loaded from the
                environment when omitted.

        Raises:
            ValueError: If the allowlist contains an invalid address or CIDR.
        """
        super().__init__(app)
        cfg = config or RateLimitConfig.from_env()
        self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips)

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Return 403 for disallowed callers of ``/metrics``; otherwise forward.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware/app stack.

        Returns:
            A 403 JSON response when any examined address is outside the
            allowlist, otherwise the downstream response.
        """
        if not request.url.path.startswith("/metrics"):
            return await call_next(request)

        forwarded = request.headers.get("x-forwarded-for") or ""
        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        if not hops:
            hops = [request.client.host if request.client else "unknown"]

        # Checking only the leftmost hop lets a caller prepend an allowed
        # address ahead of the one a proxy appends; require every hop to pass.
        denied = next(
            (hop for hop in hops if not is_ip_allowed(hop, self._allowed_networks)),
            None,
        )
        if denied is not None:
            logger.warning("Metrics access denied for IP %s", denied)
            return JSONResponse(status_code=403, content={"detail": "Forbidden"})

        return await call_next(request)
|
||||||
103
api/rate_limit_config.py
Normal file
103
api/rate_limit_config.py
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
"""Rate limit and security configuration with environment variable loading."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class EndpointLimit:
    """Request budget applied to one endpoint pattern (immutable)."""

    # Number of requests permitted within one window.
    max_requests: int
    # Length of the counting window, in seconds.
    window_seconds: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RateLimitConfig:
|
||||||
|
"""Configuration for API rate limiting, export caps, and metrics access.
|
||||||
|
|
||||||
|
All values are configurable via environment variables with sensible defaults.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Per-endpoint rate limits
|
||||||
|
endpoint_limits: dict[str, EndpointLimit] = field(default_factory=lambda: {
|
||||||
|
"/api/listing": EndpointLimit(30, 60),
|
||||||
|
"/api/listing_geojson": EndpointLimit(10, 60),
|
||||||
|
"/api/listing_geojson/stream": EndpointLimit(10, 60),
|
||||||
|
"/api/refresh_listings": EndpointLimit(3, 300),
|
||||||
|
"/api/task_status": EndpointLimit(60, 60),
|
||||||
|
"/api/tasks_for_user": EndpointLimit(30, 60),
|
||||||
|
"/api/cancel_task": EndpointLimit(10, 60),
|
||||||
|
"/api/clear_all_tasks": EndpointLimit(5, 60),
|
||||||
|
"/api/get_districts": EndpointLimit(20, 60),
|
||||||
|
"/api/passkey": EndpointLimit(10, 60),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Bulk export caps
|
||||||
|
listing_limit_cap: int = 100
|
||||||
|
geojson_limit_cap: int = 5_000
|
||||||
|
geojson_stream_limit_cap: int = 10_000
|
||||||
|
geojson_stream_batch_size_cap: int = 200
|
||||||
|
|
||||||
|
# Redis DB for rate limit counters
|
||||||
|
rate_limit_redis_db: int = 3
|
||||||
|
|
||||||
|
# Metrics endpoint IP allowlist (comma-separated CIDRs)
|
||||||
|
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> Self:
|
||||||
|
"""Load configuration from environment variables.
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
RATE_LIMIT_LISTING: /api/listing limit (default: 30/60s)
|
||||||
|
RATE_LIMIT_GEOJSON: /api/listing_geojson limit (default: 10/60s)
|
||||||
|
RATE_LIMIT_GEOJSON_STREAM: /api/listing_geojson/stream limit (default: 10/60s)
|
||||||
|
RATE_LIMIT_REFRESH: /api/refresh_listings limit (default: 3/300s)
|
||||||
|
RATE_LIMIT_TASK_STATUS: /api/task_status limit (default: 60/60s)
|
||||||
|
RATE_LIMIT_TASKS_FOR_USER: /api/tasks_for_user limit (default: 30/60s)
|
||||||
|
RATE_LIMIT_CANCEL_TASK: /api/cancel_task limit (default: 10/60s)
|
||||||
|
RATE_LIMIT_CLEAR_TASKS: /api/clear_all_tasks limit (default: 5/60s)
|
||||||
|
RATE_LIMIT_DISTRICTS: /api/get_districts limit (default: 20/60s)
|
||||||
|
RATE_LIMIT_PASSKEY: /api/passkey/* limit (default: 10/60s)
|
||||||
|
EXPORT_LISTING_LIMIT_CAP: Max listings per request (default: 100)
|
||||||
|
EXPORT_GEOJSON_LIMIT_CAP: Max GeoJSON features per request (default: 5000)
|
||||||
|
EXPORT_GEOJSON_STREAM_LIMIT_CAP: Max streamed features (default: 10000)
|
||||||
|
EXPORT_GEOJSON_STREAM_BATCH_CAP: Max stream batch size (default: 200)
|
||||||
|
RATE_LIMIT_REDIS_DB: Redis DB number for counters (default: 3)
|
||||||
|
METRICS_ALLOWED_IPS: Comma-separated CIDRs for /metrics (default: private ranges)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _parse_limit(env_var: str, default_requests: int, default_window: int) -> EndpointLimit:
|
||||||
|
raw = os.environ.get(env_var)
|
||||||
|
if raw:
|
||||||
|
parts = raw.split("/")
|
||||||
|
if len(parts) == 2:
|
||||||
|
return EndpointLimit(int(parts[0]), int(parts[1]))
|
||||||
|
return EndpointLimit(default_requests, default_window)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
endpoint_limits={
|
||||||
|
"/api/listing": _parse_limit("RATE_LIMIT_LISTING", 30, 60),
|
||||||
|
"/api/listing_geojson": _parse_limit("RATE_LIMIT_GEOJSON", 10, 60),
|
||||||
|
"/api/listing_geojson/stream": _parse_limit("RATE_LIMIT_GEOJSON_STREAM", 10, 60),
|
||||||
|
"/api/refresh_listings": _parse_limit("RATE_LIMIT_REFRESH", 3, 300),
|
||||||
|
"/api/task_status": _parse_limit("RATE_LIMIT_TASK_STATUS", 60, 60),
|
||||||
|
"/api/tasks_for_user": _parse_limit("RATE_LIMIT_TASKS_FOR_USER", 30, 60),
|
||||||
|
"/api/cancel_task": _parse_limit("RATE_LIMIT_CANCEL_TASK", 10, 60),
|
||||||
|
"/api/clear_all_tasks": _parse_limit("RATE_LIMIT_CLEAR_TASKS", 5, 60),
|
||||||
|
"/api/get_districts": _parse_limit("RATE_LIMIT_DISTRICTS", 20, 60),
|
||||||
|
"/api/passkey": _parse_limit("RATE_LIMIT_PASSKEY", 10, 60),
|
||||||
|
},
|
||||||
|
listing_limit_cap=int(os.environ.get("EXPORT_LISTING_LIMIT_CAP", "100")),
|
||||||
|
geojson_limit_cap=int(os.environ.get("EXPORT_GEOJSON_LIMIT_CAP", "5000")),
|
||||||
|
geojson_stream_limit_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_LIMIT_CAP", "10000")),
|
||||||
|
geojson_stream_batch_size_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_BATCH_CAP", "200")),
|
||||||
|
rate_limit_redis_db=int(os.environ.get("RATE_LIMIT_REDIS_DB", "3")),
|
||||||
|
metrics_allowed_ips=os.environ.get(
|
||||||
|
"METRICS_ALLOWED_IPS",
|
||||||
|
"127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1",
|
||||||
|
),
|
||||||
|
)
|
||||||
132
api/rate_limiter.py
Normal file
132
api/rate_limiter.py
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
"""Per-user rate limiting middleware using Redis fixed-window counters."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import redis
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
# Paths exempt from rate limiting
|
||||||
|
EXEMPT_PATHS = {"/api/status", "/metrics"}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis:  # type: ignore[type-arg]
    """Build the Redis client that stores rate-limit counters.

    The connection URL is taken from ``CELERY_BROKER_URL`` (falling back to
    ``redis://localhost:6379/0``) with its path replaced by
    ``config.rate_limit_redis_db``.

    Args:
        config: Supplies the database number for the URL path.

    Returns:
        A client created with ``decode_responses=True``.
    """
    base = urlparse(os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"))
    db_path = f"/{config.rate_limit_redis_db}"
    target_url = urlunparse(base._replace(path=db_path))
    return redis.from_url(target_url, decode_responses=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_user_email(request: Request) -> str | None:
    """Read the ``email`` claim from the bearer JWT without verifying it.

    Signature and expiry checks are disabled, so the value is attacker
    controllable.  Per the original note this mirrors the unverified-decode
    pattern used in api/auth.py:89 for routing; it is only used for
    rate-limit keying, not for authorization.

    Returns:
        The ``email`` claim, or ``None`` when there is no ``Bearer`` header,
        the token cannot be decoded, or the claim is absent.
    """
    scheme = "Bearer "
    header_value = request.headers.get("authorization", "")
    if header_value.startswith(scheme):
        raw_token = header_value[len(scheme):]
        try:
            claims = jwt.decode(
                raw_token,
                options={"verify_signature": False, "verify_exp": False},
            )
        except jwt.PyJWTError:
            return None
        return claims.get("email")
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
    """Look up the rate limit that applies to ``path``.

    Any path beginning with ``/api/passkey`` shares the ``/api/passkey``
    entry; every other path must match a configured key exactly.

    Returns:
        The matching limit, or ``None`` when the path has no configured limit.
    """
    passkey_prefix = "/api/passkey"
    lookup_key = passkey_prefix if path.startswith(passkey_prefix) else path
    return config.endpoint_limits.get(lookup_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _client_ip(request: Request) -> str:
    """Best-effort client IP from X-Forwarded-For or connection.

    Uses the first entry of ``X-Forwarded-For`` when it is non-empty and
    otherwise the address of the directly connected peer.  The header is
    supplied by the client (or an upstream proxy), so the value is only a
    hint.

    Returns:
        The address as a string, or ``"unknown"`` when neither the header
        nor the connection provides one.
    """
    forwarded = request.headers.get("x-forwarded-for") or ""
    first_hop = forwarded.split(",", 1)[0].strip()
    # A blank first entry (e.g. ", 1.2.3.4") would otherwise yield an empty
    # identity shared by every such caller; use the peer address instead.
    if first_hop:
        return first_hop
    client = request.client
    return client.host if client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware enforcing per-user fixed-window rate limits via Redis.

    Counters are keyed by ``ratelimit:{identity}:{path}`` where identity is
    the unverified JWT ``email`` claim or, failing that, the client IP.  The
    middleware fails open: if Redis is unavailable the request is forwarded
    without being counted.

    NOTE(review): the Redis client is called synchronously from the async
    ``dispatch``; confirm this is acceptable for the event loop.
    NOTE(review): if Redis is down at startup, ``_redis`` stays ``None`` and
    limiting remains disabled for the lifetime of this instance.
    """

    def __init__(self, app, config: RateLimitConfig | None = None) -> None:  # type: ignore[no-untyped-def]
        """Create the Redis client and probe it once.

        Args:
            app: The wrapped ASGI application.
            config: Rate-limit settings; loaded from the environment when
                omitted.
        """
        super().__init__(app)
        self.config = config or RateLimitConfig.from_env()
        try:
            self._redis = _get_rate_limit_redis(self.config)
            self._redis.ping()
        except redis.RedisError:
            logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
            self._redis = None

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[no-untyped-def]
        """Count the request and either reject it with 429 or forward it.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware/app stack.

        Returns:
            A 429 JSON response with ``Retry-After`` and ``X-RateLimit-*``
            headers when the caller is over the limit; otherwise the
            downstream response, annotated with ``X-RateLimit-*`` headers
            when the request was counted.
        """
        path = request.url.path

        # Skip exempt paths
        if path in EXEMPT_PATHS:
            return await call_next(request)

        limit = _match_endpoint(path, self.config)
        # No configured limit, or Redis unavailable: fail open.
        if limit is None or self._redis is None:
            return await call_next(request)

        # Determine identity for the counter key
        identity = _extract_user_email(request) or _client_ip(request)
        redis_key = f"ratelimit:{identity}:{path}"

        # Only the Redis calls are guarded.  Keeping ``call_next`` outside the
        # ``try`` means a RedisError raised by the downstream handler is not
        # mistaken for a limiter failure and the request is never dispatched
        # twice.
        try:
            pipe = self._redis.pipeline(transaction=True)
            pipe.incr(redis_key)
            pipe.ttl(redis_key)
            result = pipe.execute()
            current_count: int = result[0]
            ttl: int = result[1]

            # Set expiry on first request in window
            if ttl == -1:
                self._redis.expire(redis_key, limit.window_seconds)
                ttl = limit.window_seconds
        except redis.RedisError as e:
            logger.warning("Rate limiter Redis error, failing open: %s", e)
            return await call_next(request)

        if current_count > limit.max_requests:
            retry_after = max(1, ttl)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(limit.max_requests),
                    "X-RateLimit-Remaining": "0",
                },
            )

        remaining = max(0, limit.max_requests - current_count)
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response
|
||||||
111
tests/unit/test_metrics_guard.py
Normal file
111
tests/unit/test_metrics_guard.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Unit tests for api/metrics_guard.py."""
|
||||||
|
import ipaddress
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
|
||||||
|
from api.rate_limit_config import RateLimitConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def _ok_endpoint(request: Request) -> JSONResponse:
    """Stub route handler: ignores ``request`` and answers ``{"ok": true}``."""
    payload = {"ok": True}
    return JSONResponse(payload)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseAllowedNetworks:
    """Checks parse_allowed_networks on single, multiple and degenerate inputs."""

    def test_single_ip(self) -> None:
        """A bare address yields one network that contains it."""
        parsed = parse_allowed_networks("127.0.0.1")
        assert len(parsed) == 1
        loopback = ipaddress.ip_address("127.0.0.1")
        assert loopback in parsed[0]

    def test_cidr(self) -> None:
        """A CIDR block yields one network covering its last address."""
        parsed = parse_allowed_networks("10.0.0.0/8")
        assert len(parsed) == 1
        last_address = ipaddress.ip_address("10.255.255.255")
        assert last_address in parsed[0]

    def test_multiple_entries(self) -> None:
        """Entries separated by comma and space are each parsed."""
        parsed = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
        assert len(parsed) == 3

    def test_empty_string(self) -> None:
        """An empty string produces no networks."""
        parsed = parse_allowed_networks("")
        assert parsed == []

    def test_trailing_comma(self) -> None:
        """A trailing comma does not add an entry."""
        parsed = parse_allowed_networks("127.0.0.1,")
        assert len(parsed) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsIpAllowed:
    """Checks is_ip_allowed for matching, non-matching, IPv6 and malformed input."""

    def test_allowed_ip(self) -> None:
        """An address inside the allowed block is accepted."""
        allowlist = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("10.1.2.3", allowlist)
        assert verdict is True

    def test_denied_ip(self) -> None:
        """An address outside the allowed block is rejected."""
        allowlist = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("192.168.1.1", allowlist)
        assert verdict is False

    def test_ipv6(self) -> None:
        """Only the exact IPv6 loopback matches a ``::1`` entry."""
        allowlist = parse_allowed_networks("::1")
        assert is_ip_allowed("::1", allowlist) is True
        assert is_ip_allowed("::2", allowlist) is False

    def test_invalid_ip(self) -> None:
        """A string that is not an IP address is rejected."""
        allowlist = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("not-an-ip", allowlist)
        assert verdict is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsGuardMiddleware:
    """Integration tests for MetricsGuardMiddleware."""

    def _build_app(self, allowed_ips: str) -> Starlette:
        """Build an app with ``/metrics`` and ``/api/status`` behind the guard.

        Args:
            allowed_ips: Comma-separated allowlist passed to the middleware.
        """
        config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
        app = Starlette(routes=[
            Route("/metrics", _ok_endpoint),
            Route("/api/status", _ok_endpoint),
        ])
        app.add_middleware(MetricsGuardMiddleware, config=config)
        return app

    def test_allows_metrics_from_allowed_ip(self) -> None:
        """A caller whose forwarded address is allowlisted gets through."""
        # The caller address is supplied through X-Forwarded-For, which the
        # middleware consults before the connection's peer address.
        app = self._build_app("10.0.0.1")
        client = TestClient(app)
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200

    def test_blocks_metrics_from_disallowed_ip(self) -> None:
        """A forwarded address outside the allowlist receives 403."""
        app = self._build_app("10.0.0.0/8")
        client = TestClient(app)
        resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
        assert resp.status_code == 403

    def test_non_metrics_path_passes_through(self) -> None:
        """Paths other than /metrics are not subject to the allowlist."""
        app = self._build_app("10.0.0.0/8")
        client = TestClient(app)
        resp = client.get("/api/status")
        assert resp.status_code == 200

    def test_default_private_ranges(self) -> None:
        """The default allowlist admits a private address and rejects a public one."""
        config = RateLimitConfig()
        app = Starlette(routes=[Route("/metrics", _ok_endpoint)])
        app.add_middleware(MetricsGuardMiddleware, config=config)
        client = TestClient(app)

        # Private IP should be allowed
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200

        # Public IP should be denied
        resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
        assert resp.status_code == 403
|
||||||
238
tests/unit/test_rate_limiter.py
Normal file
238
tests/unit/test_rate_limiter.py
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
"""Unit tests for api/rate_limiter.py."""
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||||
|
from api.rate_limiter import (
|
||||||
|
RateLimitMiddleware,
|
||||||
|
_extract_user_email,
|
||||||
|
_match_endpoint,
|
||||||
|
EXEMPT_PATHS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(**overrides: object) -> RateLimitConfig:
    """Build a RateLimitConfig for tests.

    Starts from small fixed limits (3 requests/60s for ``/api/listing``,
    2 requests/60s for ``/api/passkey``) and applies ``overrides`` on top.
    """
    base: dict[str, object] = dict(
        endpoint_limits={
            "/api/listing": EndpointLimit(3, 60),
            "/api/passkey": EndpointLimit(2, 60),
        },
        listing_limit_cap=100,
        geojson_limit_cap=5000,
        geojson_stream_limit_cap=10000,
        geojson_stream_batch_size_cap=200,
        rate_limit_redis_db=3,
        metrics_allowed_ips="127.0.0.1",
    )
    settings = {**base, **overrides}
    return RateLimitConfig(**settings)  # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
async def _ok_endpoint(request: Request) -> JSONResponse:
    """Stub route handler: ignores ``request`` and answers ``{"ok": true}``."""
    payload = {"ok": True}
    return JSONResponse(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_app(config: RateLimitConfig) -> Starlette:
    """Assemble a three-route Starlette app wrapped by the rate limiter.

    Routes: ``/api/listing``, ``/api/passkey/login/begin`` (POST only) and
    ``/api/status``, all served by the stub handler.
    """
    routes = [
        Route("/api/listing", _ok_endpoint),
        Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]),
        Route("/api/status", _ok_endpoint),
    ]
    application = Starlette(routes=routes)
    application.add_middleware(RateLimitMiddleware, config=config)
    return application
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractUserEmail:
    """Checks _extract_user_email for missing, malformed and well-formed tokens."""

    @staticmethod
    def _request(headers: list) -> Request:
        """Build a minimal HTTP request carrying the given raw header pairs."""
        return Request({"type": "http", "headers": headers})

    def test_no_auth_header(self) -> None:
        """No Authorization header means no email."""
        request = self._request([])
        assert _extract_user_email(request) is None

    def test_invalid_token(self) -> None:
        """A bearer value that is not a JWT means no email."""
        request = self._request([(b"authorization", b"Bearer not-a-jwt")])
        assert _extract_user_email(request) is None

    def test_valid_jwt(self) -> None:
        """The email claim of a decodable JWT is returned."""
        import jwt as pyjwt

        token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256")
        header_value = f"Bearer {token}".encode()
        request = self._request([(b"authorization", header_value)])
        assert _extract_user_email(request) == "test@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMatchEndpoint:
    """Checks _match_endpoint for exact, passkey-prefix and unknown paths."""

    def test_exact_match(self) -> None:
        """A configured path returns its own limit."""
        matched = _match_endpoint("/api/listing", _make_config())
        assert matched is not None
        assert matched.max_requests == 3

    def test_passkey_prefix_match(self) -> None:
        """A passkey sub-route falls under the shared /api/passkey limit."""
        matched = _match_endpoint("/api/passkey/login/begin", _make_config())
        assert matched is not None
        assert matched.max_requests == 2

    def test_no_match(self) -> None:
        """An unconfigured path has no limit."""
        matched = _match_endpoint("/api/unknown", _make_config())
        assert matched is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitMiddleware:
    """Drive ``RateLimitMiddleware`` through a ``TestClient``.

    Each test patches ``api.rate_limiter._get_rate_limit_redis`` so that the
    middleware receives a mock instead of a real Redis client.
    """

    @staticmethod
    def _install_fake_redis(get_redis: mock.MagicMock) -> mock.MagicMock:
        """Make the patched factory hand out a mock client whose ping succeeds."""
        fake = mock.MagicMock()
        fake.ping.return_value = True
        get_redis.return_value = fake
        return fake

    @staticmethod
    def _attach_pipeline(fake: mock.MagicMock) -> mock.MagicMock:
        """Give *fake* a mock pipeline and return that pipeline for configuration."""
        pipe = mock.MagicMock()
        fake.pipeline.return_value = pipe
        return pipe

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None:
        """A first request passes and reports limit / remaining headers."""
        pipe = self._attach_pipeline(self._install_fake_redis(mock_get_redis))
        pipe.execute.return_value = [1, -1]  # first request, no TTL yet

        http = TestClient(_build_app(_make_config()))
        response = http.get("/api/listing")

        assert response.status_code == 200
        assert "X-RateLimit-Limit" in response.headers
        assert response.headers["X-RateLimit-Limit"] == "3"
        assert response.headers["X-RateLimit-Remaining"] == "2"

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None:
        """A request beyond the limit is rejected with 429 and Retry-After."""
        pipe = self._attach_pipeline(self._install_fake_redis(mock_get_redis))
        # 4th request in window (limit=3), TTL=45s remaining
        pipe.execute.return_value = [4, 45]

        http = TestClient(_build_app(_make_config()))
        response = http.get("/api/listing")

        assert response.status_code == 429
        assert response.headers["Retry-After"] == "45"
        assert response.headers["X-RateLimit-Remaining"] == "0"

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None:
        """``/api/status`` is served without headers or any pipeline use."""
        fake = self._install_fake_redis(mock_get_redis)

        http = TestClient(_build_app(_make_config()))
        response = http.get("/api/status")

        assert response.status_code == 200
        assert "X-RateLimit-Limit" not in response.headers
        # Redis pipeline should never be called for exempt paths
        fake.pipeline.assert_not_called()

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None:
        """When Redis raises an error, requests should be allowed through."""
        import redis

        pipe = self._attach_pipeline(self._install_fake_redis(mock_get_redis))
        pipe.execute.side_effect = redis.RedisError("connection lost")

        http = TestClient(_build_app(_make_config()))
        response = http.get("/api/listing")

        assert response.status_code == 200

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None:
        """When Redis is unavailable at startup, requests pass through."""
        import redis

        # The factory itself fails, so no client object is ever produced.
        mock_get_redis.side_effect = redis.RedisError("connection refused")

        http = TestClient(_build_app(_make_config()))
        response = http.get("/api/listing")

        assert response.status_code == 200

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None:
        """Endpoints not in the config are not rate limited."""
        self._install_fake_redis(mock_get_redis)

        limits = _make_config()
        # Built by hand because this test needs a route for /api/unknown.
        bare_app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)])
        bare_app.add_middleware(RateLimitMiddleware, config=limits)

        response = TestClient(bare_app).get("/api/unknown")

        assert response.status_code == 200
        assert "X-RateLimit-Limit" not in response.headers
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitConfig:
    """Checks for ``RateLimitConfig.from_env`` under a controlled environment."""

    def test_defaults(self) -> None:
        """With an empty environment the built-in defaults are used."""
        with mock.patch.dict("os.environ", {}, clear=True):
            cfg = RateLimitConfig.from_env()
            assert cfg.listing_limit_cap == 100
            assert cfg.geojson_limit_cap == 5000
            assert cfg.rate_limit_redis_db == 3

    def test_custom_env_vars(self) -> None:
        """Values supplied through the environment override the defaults."""
        overrides = {
            "RATE_LIMIT_LISTING": "50/120",
            "EXPORT_LISTING_LIMIT_CAP": "200",
            "RATE_LIMIT_REDIS_DB": "5",
            "METRICS_ALLOWED_IPS": "10.0.0.1",
        }
        with mock.patch.dict("os.environ", overrides, clear=True):
            cfg = RateLimitConfig.from_env()
            listing = cfg.endpoint_limits["/api/listing"]
            assert listing.max_requests == 50
            assert listing.window_seconds == 120
            assert cfg.listing_limit_cap == 200
            assert cfg.rate_limit_redis_db == 5
            assert cfg.metrics_allowed_ips == "10.0.0.1"

    def test_invalid_limit_format_uses_defaults(self) -> None:
        """A limit string that is not ``max/window`` leaves the default in place."""
        with mock.patch.dict("os.environ", {"RATE_LIMIT_LISTING": "invalid"}, clear=True):
            cfg = RateLimitConfig.from_env()
            # Should fall back to default
            listing = cfg.endpoint_limits["/api/listing"]
            assert listing.max_requests == 30
            assert listing.window_seconds == 60
|
||||||
Loading…
Add table
Add a link
Reference in a new issue