Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis sliding window, IP-restricted /metrics endpoint, audit logging of all requests, CORS tightening, and export caps on listing/geojson endpoints.
This commit is contained in:
parent
08ac72bbfc
commit
87b5bd8676
8 changed files with 756 additions and 2 deletions
22
.env.sample
22
.env.sample
|
|
@ -41,3 +41,25 @@ JWT_SECRET=change-me-in-production # HMAC secret for HS256 signing
|
|||
JWT_ALGORITHM=HS256 # JWT signing algorithm
|
||||
JWT_EXPIRATION_HOURS=24 # Token expiry in hours
|
||||
JWT_ISSUER=wrongmove # JWT issuer claim
|
||||
|
||||
# API rate limiting (format: max_requests/window_seconds)
|
||||
# RATE_LIMIT_LISTING=30/60 # /api/listing: 30 req per 60s
|
||||
# RATE_LIMIT_GEOJSON=10/60 # /api/listing_geojson: 10 req per 60s
|
||||
# RATE_LIMIT_GEOJSON_STREAM=10/60 # /api/listing_geojson/stream: 10 req per 60s
|
||||
# RATE_LIMIT_REFRESH=3/300 # /api/refresh_listings: 3 req per 5min
|
||||
# RATE_LIMIT_TASK_STATUS=60/60 # /api/task_status: 60 req per 60s
|
||||
# RATE_LIMIT_TASKS_FOR_USER=30/60 # /api/tasks_for_user: 30 req per 60s
|
||||
# RATE_LIMIT_CANCEL_TASK=10/60 # /api/cancel_task: 10 req per 60s
|
||||
# RATE_LIMIT_CLEAR_TASKS=5/60 # /api/clear_all_tasks: 5 req per 60s
|
||||
# RATE_LIMIT_DISTRICTS=20/60 # /api/get_districts: 20 req per 60s
|
||||
# RATE_LIMIT_PASSKEY=10/60 # /api/passkey/*: 10 req per 60s
|
||||
RATE_LIMIT_REDIS_DB=3 # Redis DB for rate limit counters
|
||||
|
||||
# Bulk export caps
|
||||
EXPORT_LISTING_LIMIT_CAP=100 # Max listings per /api/listing request
|
||||
EXPORT_GEOJSON_LIMIT_CAP=5000 # Max features per /api/listing_geojson request
|
||||
EXPORT_GEOJSON_STREAM_LIMIT_CAP=10000 # Max features per /api/listing_geojson/stream
|
||||
EXPORT_GEOJSON_STREAM_BATCH_CAP=200 # Max batch size for streaming
|
||||
|
||||
# Metrics endpoint access control (comma-separated IPs/CIDRs)
|
||||
METRICS_ALLOWED_IPS=127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1
|
||||
|
|
|
|||
28
api/app.py
28
api/app.py
|
|
@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional
|
|||
from api.auth import get_current_user
|
||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||
from api.passkey_routes import passkey_router
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
from api.rate_limiter import RateLimitMiddleware
|
||||
from api.audit_middleware import AuditLogMiddleware
|
||||
from api.metrics_guard import MetricsGuardMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, FastAPI, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -33,6 +37,7 @@ load_dotenv()
|
|||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
DEFAULT_BATCH_SIZE = 50
|
||||
_rate_limit_config = RateLimitConfig.from_env()
|
||||
|
||||
|
||||
def get_query_parameters(
|
||||
|
|
@ -82,10 +87,18 @@ hist = meter.create_histogram(
|
|||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
# Security middleware (added bottom-to-top; last added = outermost)
|
||||
# 3. Rate limiting — enforces per-user limits
|
||||
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
|
||||
# 2. Metrics guard — blocks unauthorized /metrics access
|
||||
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
|
||||
# 1. Audit logging — logs everything including 429s and 403s
|
||||
app.add_middleware(AuditLogMiddleware)
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status() -> dict[str, str]:
|
||||
|
|
@ -100,6 +113,7 @@ async def get_listing(
|
|||
limit: int = 5,
|
||||
) -> dict[str, list]:
|
||||
"""Get listings from the database."""
|
||||
limit = min(limit, _rate_limit_config.listing_limit_cap)
|
||||
repository = ListingRepository(engine)
|
||||
result = await listing_service.get_listings(repository, limit=limit)
|
||||
logger.info(f"Fetched {result.total_count} listings for {user.email}")
|
||||
|
|
@ -113,6 +127,10 @@ async def get_listing_geojson(
|
|||
limit: int | None = None,
|
||||
) -> dict:
|
||||
"""Get listings as GeoJSON for map display."""
|
||||
if limit is not None:
|
||||
limit = min(limit, _rate_limit_config.geojson_limit_cap)
|
||||
else:
|
||||
limit = _rate_limit_config.geojson_limit_cap
|
||||
repository = ListingRepository(engine)
|
||||
result = await export_service.export_to_geojson(
|
||||
repository,
|
||||
|
|
@ -204,6 +222,12 @@ async def stream_listing_geojson(
|
|||
- batch: Array of GeoJSON features
|
||||
- complete: Final message with total count
|
||||
"""
|
||||
batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap)
|
||||
if limit is not None:
|
||||
limit = min(limit, _rate_limit_config.geojson_stream_limit_cap)
|
||||
else:
|
||||
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||
|
||||
cached_count = get_cached_count(query_parameters)
|
||||
if cached_count is not None and cached_count > 0:
|
||||
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
||||
|
|
|
|||
63
api/audit_middleware.py
Normal file
63
api/audit_middleware.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
"""Audit logging middleware for API requests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import jwt
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
audit_logger = logging.getLogger("uvicorn.audit")
|
||||
|
||||
|
||||
def _extract_identity(request: Request) -> str:
|
||||
"""Extract user email from JWT for audit logging."""
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return "anonymous"
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
|
||||
return payload.get("email", "unknown")
|
||||
except jwt.PyJWTError:
|
||||
return "invalid-token"
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
"""Best-effort client IP."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
client = request.client
|
||||
return client.host if client else "unknown"
|
||||
|
||||
|
||||
class AuditLogMiddleware(BaseHTTPMiddleware):
|
||||
"""Logs all /api/ requests with method, path, user, IP, status, and duration."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
path = request.url.path
|
||||
if not path.startswith("/api/"):
|
||||
return await call_next(request)
|
||||
|
||||
start = time.monotonic()
|
||||
identity = _extract_identity(request)
|
||||
ip = _client_ip(request)
|
||||
query = str(request.query_params) if request.query_params else ""
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
duration_ms = (time.monotonic() - start) * 1000
|
||||
audit_logger.info(
|
||||
"method=%s path=%s query=%s user=%s ip=%s status=%d duration_ms=%.1f",
|
||||
request.method,
|
||||
path,
|
||||
query,
|
||||
identity,
|
||||
ip,
|
||||
response.status_code,
|
||||
duration_ms,
|
||||
)
|
||||
return response
|
||||
61
api/metrics_guard.py
Normal file
61
api/metrics_guard.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
"""IP allowlist middleware for the /metrics endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
|
||||
"""Parse a comma-separated string of IPs/CIDRs into network objects."""
|
||||
networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
|
||||
for entry in raw.split(","):
|
||||
entry = entry.strip()
|
||||
if not entry:
|
||||
continue
|
||||
networks.append(ipaddress.ip_network(entry, strict=False))
|
||||
return networks
|
||||
|
||||
|
||||
def is_ip_allowed(
|
||||
ip_str: str,
|
||||
allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
|
||||
) -> bool:
|
||||
"""Check whether an IP address falls within any of the allowed networks."""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
return any(addr in network for network in allowed_networks)
|
||||
|
||||
|
||||
class MetricsGuardMiddleware(BaseHTTPMiddleware):
|
||||
"""Restricts /metrics access to an IP allowlist."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(app)
|
||||
cfg = config or RateLimitConfig.from_env()
|
||||
self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
if not request.url.path.startswith("/metrics"):
|
||||
return await call_next(request)
|
||||
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
client_ip = forwarded.split(",")[0].strip()
|
||||
else:
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
if not is_ip_allowed(client_ip, self._allowed_networks):
|
||||
logger.warning("Metrics access denied for IP %s", client_ip)
|
||||
return JSONResponse(status_code=403, content={"detail": "Forbidden"})
|
||||
|
||||
return await call_next(request)
|
||||
103
api/rate_limit_config.py
Normal file
103
api/rate_limit_config.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Rate limit and security configuration with environment variable loading."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EndpointLimit:
|
||||
"""Rate limit for a single endpoint pattern."""
|
||||
|
||||
max_requests: int
|
||||
window_seconds: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RateLimitConfig:
|
||||
"""Configuration for API rate limiting, export caps, and metrics access.
|
||||
|
||||
All values are configurable via environment variables with sensible defaults.
|
||||
"""
|
||||
|
||||
# Per-endpoint rate limits
|
||||
endpoint_limits: dict[str, EndpointLimit] = field(default_factory=lambda: {
|
||||
"/api/listing": EndpointLimit(30, 60),
|
||||
"/api/listing_geojson": EndpointLimit(10, 60),
|
||||
"/api/listing_geojson/stream": EndpointLimit(10, 60),
|
||||
"/api/refresh_listings": EndpointLimit(3, 300),
|
||||
"/api/task_status": EndpointLimit(60, 60),
|
||||
"/api/tasks_for_user": EndpointLimit(30, 60),
|
||||
"/api/cancel_task": EndpointLimit(10, 60),
|
||||
"/api/clear_all_tasks": EndpointLimit(5, 60),
|
||||
"/api/get_districts": EndpointLimit(20, 60),
|
||||
"/api/passkey": EndpointLimit(10, 60),
|
||||
})
|
||||
|
||||
# Bulk export caps
|
||||
listing_limit_cap: int = 100
|
||||
geojson_limit_cap: int = 5_000
|
||||
geojson_stream_limit_cap: int = 10_000
|
||||
geojson_stream_batch_size_cap: int = 200
|
||||
|
||||
# Redis DB for rate limit counters
|
||||
rate_limit_redis_db: int = 3
|
||||
|
||||
# Metrics endpoint IP allowlist (comma-separated CIDRs)
|
||||
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> Self:
|
||||
"""Load configuration from environment variables.
|
||||
|
||||
Environment variables:
|
||||
RATE_LIMIT_LISTING: /api/listing limit (default: 30/60s)
|
||||
RATE_LIMIT_GEOJSON: /api/listing_geojson limit (default: 10/60s)
|
||||
RATE_LIMIT_GEOJSON_STREAM: /api/listing_geojson/stream limit (default: 10/60s)
|
||||
RATE_LIMIT_REFRESH: /api/refresh_listings limit (default: 3/300s)
|
||||
RATE_LIMIT_TASK_STATUS: /api/task_status limit (default: 60/60s)
|
||||
RATE_LIMIT_TASKS_FOR_USER: /api/tasks_for_user limit (default: 30/60s)
|
||||
RATE_LIMIT_CANCEL_TASK: /api/cancel_task limit (default: 10/60s)
|
||||
RATE_LIMIT_CLEAR_TASKS: /api/clear_all_tasks limit (default: 5/60s)
|
||||
RATE_LIMIT_DISTRICTS: /api/get_districts limit (default: 20/60s)
|
||||
RATE_LIMIT_PASSKEY: /api/passkey/* limit (default: 10/60s)
|
||||
EXPORT_LISTING_LIMIT_CAP: Max listings per request (default: 100)
|
||||
EXPORT_GEOJSON_LIMIT_CAP: Max GeoJSON features per request (default: 5000)
|
||||
EXPORT_GEOJSON_STREAM_LIMIT_CAP: Max streamed features (default: 10000)
|
||||
EXPORT_GEOJSON_STREAM_BATCH_CAP: Max stream batch size (default: 200)
|
||||
RATE_LIMIT_REDIS_DB: Redis DB number for counters (default: 3)
|
||||
METRICS_ALLOWED_IPS: Comma-separated CIDRs for /metrics (default: private ranges)
|
||||
"""
|
||||
|
||||
def _parse_limit(env_var: str, default_requests: int, default_window: int) -> EndpointLimit:
|
||||
raw = os.environ.get(env_var)
|
||||
if raw:
|
||||
parts = raw.split("/")
|
||||
if len(parts) == 2:
|
||||
return EndpointLimit(int(parts[0]), int(parts[1]))
|
||||
return EndpointLimit(default_requests, default_window)
|
||||
|
||||
return cls(
|
||||
endpoint_limits={
|
||||
"/api/listing": _parse_limit("RATE_LIMIT_LISTING", 30, 60),
|
||||
"/api/listing_geojson": _parse_limit("RATE_LIMIT_GEOJSON", 10, 60),
|
||||
"/api/listing_geojson/stream": _parse_limit("RATE_LIMIT_GEOJSON_STREAM", 10, 60),
|
||||
"/api/refresh_listings": _parse_limit("RATE_LIMIT_REFRESH", 3, 300),
|
||||
"/api/task_status": _parse_limit("RATE_LIMIT_TASK_STATUS", 60, 60),
|
||||
"/api/tasks_for_user": _parse_limit("RATE_LIMIT_TASKS_FOR_USER", 30, 60),
|
||||
"/api/cancel_task": _parse_limit("RATE_LIMIT_CANCEL_TASK", 10, 60),
|
||||
"/api/clear_all_tasks": _parse_limit("RATE_LIMIT_CLEAR_TASKS", 5, 60),
|
||||
"/api/get_districts": _parse_limit("RATE_LIMIT_DISTRICTS", 20, 60),
|
||||
"/api/passkey": _parse_limit("RATE_LIMIT_PASSKEY", 10, 60),
|
||||
},
|
||||
listing_limit_cap=int(os.environ.get("EXPORT_LISTING_LIMIT_CAP", "100")),
|
||||
geojson_limit_cap=int(os.environ.get("EXPORT_GEOJSON_LIMIT_CAP", "5000")),
|
||||
geojson_stream_limit_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_LIMIT_CAP", "10000")),
|
||||
geojson_stream_batch_size_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_BATCH_CAP", "200")),
|
||||
rate_limit_redis_db=int(os.environ.get("RATE_LIMIT_REDIS_DB", "3")),
|
||||
metrics_allowed_ips=os.environ.get(
|
||||
"METRICS_ALLOWED_IPS",
|
||||
"127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1",
|
||||
),
|
||||
)
|
||||
132
api/rate_limiter.py
Normal file
132
api/rate_limiter.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""Per-user rate limiting middleware using Redis fixed-window counters."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
# Paths exempt from rate limiting
|
||||
EXEMPT_PATHS = {"/api/status", "/metrics"}
|
||||
|
||||
|
||||
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis: # type: ignore[type-arg]
|
||||
"""Get a Redis client for rate limit counters, using the configured DB."""
|
||||
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
parsed = urlparse(broker_url)
|
||||
url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}"))
|
||||
return redis.from_url(url, decode_responses=True)
|
||||
|
||||
|
||||
def _extract_user_email(request: Request) -> str | None:
|
||||
"""Extract user email from JWT without full verification.
|
||||
|
||||
Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
|
||||
Only used for rate-limit keying, not for authorization.
|
||||
"""
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
|
||||
return payload.get("email")
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
|
||||
|
||||
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
|
||||
"""Find the rate limit for a request path.
|
||||
|
||||
Passkey routes (/api/passkey/*) all share the /api/passkey limit.
|
||||
"""
|
||||
if path.startswith("/api/passkey"):
|
||||
return config.endpoint_limits.get("/api/passkey")
|
||||
return config.endpoint_limits.get(path)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
"""Best-effort client IP from X-Forwarded-For or connection."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
client = request.client
|
||||
return client.host if client else "unknown"
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
super().__init__(app)
|
||||
self.config = config or RateLimitConfig.from_env()
|
||||
try:
|
||||
self._redis = _get_rate_limit_redis(self.config)
|
||||
self._redis.ping()
|
||||
except redis.RedisError:
|
||||
logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
|
||||
self._redis = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
path = request.url.path
|
||||
|
||||
# Skip exempt paths
|
||||
if path in EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _match_endpoint(path, self.config)
|
||||
if limit is None:
|
||||
return await call_next(request)
|
||||
|
||||
# Determine identity for the counter key
|
||||
identity = _extract_user_email(request) or _client_ip(request)
|
||||
|
||||
# If Redis is unavailable, fail open
|
||||
if self._redis is None:
|
||||
return await call_next(request)
|
||||
|
||||
redis_key = f"ratelimit:{identity}:{path}"
|
||||
try:
|
||||
pipe = self._redis.pipeline(transaction=True)
|
||||
pipe.incr(redis_key)
|
||||
pipe.ttl(redis_key)
|
||||
result = pipe.execute()
|
||||
current_count: int = result[0]
|
||||
ttl: int = result[1]
|
||||
|
||||
# Set expiry on first request in window
|
||||
if ttl == -1:
|
||||
self._redis.expire(redis_key, limit.window_seconds)
|
||||
ttl = limit.window_seconds
|
||||
|
||||
remaining = max(0, limit.max_requests - current_count)
|
||||
|
||||
if current_count > limit.max_requests:
|
||||
retry_after = max(1, ttl)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limit.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Rate limiter Redis error, failing open: {e}")
|
||||
return await call_next(request)
|
||||
111
tests/unit/test_metrics_guard.py
Normal file
111
tests/unit/test_metrics_guard.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Unit tests for api/metrics_guard.py."""
|
||||
import ipaddress
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
|
||||
|
||||
async def _ok_endpoint(request: Request) -> JSONResponse:
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
class TestParseAllowedNetworks:
|
||||
"""Tests for parse_allowed_networks."""
|
||||
|
||||
def test_single_ip(self) -> None:
|
||||
nets = parse_allowed_networks("127.0.0.1")
|
||||
assert len(nets) == 1
|
||||
assert ipaddress.ip_address("127.0.0.1") in nets[0]
|
||||
|
||||
def test_cidr(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert len(nets) == 1
|
||||
assert ipaddress.ip_address("10.255.255.255") in nets[0]
|
||||
|
||||
def test_multiple_entries(self) -> None:
|
||||
nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
|
||||
assert len(nets) == 3
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
nets = parse_allowed_networks("")
|
||||
assert nets == []
|
||||
|
||||
def test_trailing_comma(self) -> None:
|
||||
nets = parse_allowed_networks("127.0.0.1,")
|
||||
assert len(nets) == 1
|
||||
|
||||
|
||||
class TestIsIpAllowed:
|
||||
"""Tests for is_ip_allowed."""
|
||||
|
||||
def test_allowed_ip(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert is_ip_allowed("10.1.2.3", nets) is True
|
||||
|
||||
def test_denied_ip(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert is_ip_allowed("192.168.1.1", nets) is False
|
||||
|
||||
def test_ipv6(self) -> None:
|
||||
nets = parse_allowed_networks("::1")
|
||||
assert is_ip_allowed("::1", nets) is True
|
||||
assert is_ip_allowed("::2", nets) is False
|
||||
|
||||
def test_invalid_ip(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert is_ip_allowed("not-an-ip", nets) is False
|
||||
|
||||
|
||||
class TestMetricsGuardMiddleware:
|
||||
"""Integration tests for MetricsGuardMiddleware."""
|
||||
|
||||
def _build_app(self, allowed_ips: str) -> Starlette:
|
||||
config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
|
||||
app = Starlette(routes=[
|
||||
Route("/metrics", _ok_endpoint),
|
||||
Route("/api/status", _ok_endpoint),
|
||||
])
|
||||
app.add_middleware(MetricsGuardMiddleware, config=config)
|
||||
return app
|
||||
|
||||
def test_allows_metrics_from_allowed_ip(self) -> None:
|
||||
app = self._build_app("127.0.0.1,testclient")
|
||||
# TestClient connects from 'testclient' by default
|
||||
# We need to override; use the header approach
|
||||
app2 = self._build_app("10.0.0.1")
|
||||
client = TestClient(app2)
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_blocks_metrics_from_disallowed_ip(self) -> None:
|
||||
app = self._build_app("10.0.0.0/8")
|
||||
client = TestClient(app)
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_metrics_path_passes_through(self) -> None:
|
||||
app = self._build_app("10.0.0.0/8")
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/status")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_default_private_ranges(self) -> None:
|
||||
config = RateLimitConfig()
|
||||
app = Starlette(routes=[Route("/metrics", _ok_endpoint)])
|
||||
app.add_middleware(MetricsGuardMiddleware, config=config)
|
||||
client = TestClient(app)
|
||||
|
||||
# Private IP should be allowed
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Public IP should be denied
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
|
||||
assert resp.status_code == 403
|
||||
238
tests/unit/test_rate_limiter.py
Normal file
238
tests/unit/test_rate_limiter.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Unit tests for api/rate_limiter.py."""
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||
from api.rate_limiter import (
|
||||
RateLimitMiddleware,
|
||||
_extract_user_email,
|
||||
_match_endpoint,
|
||||
EXEMPT_PATHS,
|
||||
)
|
||||
|
||||
|
||||
def _make_config(**overrides: object) -> RateLimitConfig:
|
||||
"""Create a RateLimitConfig with defaults for testing."""
|
||||
defaults: dict[str, object] = {
|
||||
"endpoint_limits": {
|
||||
"/api/listing": EndpointLimit(3, 60),
|
||||
"/api/passkey": EndpointLimit(2, 60),
|
||||
},
|
||||
"listing_limit_cap": 100,
|
||||
"geojson_limit_cap": 5000,
|
||||
"geojson_stream_limit_cap": 10000,
|
||||
"geojson_stream_batch_size_cap": 200,
|
||||
"rate_limit_redis_db": 3,
|
||||
"metrics_allowed_ips": "127.0.0.1",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return RateLimitConfig(**defaults) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def _ok_endpoint(request: Request) -> JSONResponse:
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
def _build_app(config: RateLimitConfig) -> Starlette:
|
||||
"""Build a minimal Starlette app with the rate limiter."""
|
||||
app = Starlette(routes=[
|
||||
Route("/api/listing", _ok_endpoint),
|
||||
Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]),
|
||||
Route("/api/status", _ok_endpoint),
|
||||
])
|
||||
app.add_middleware(RateLimitMiddleware, config=config)
|
||||
return app
|
||||
|
||||
|
||||
class TestExtractUserEmail:
|
||||
"""Tests for _extract_user_email."""
|
||||
|
||||
def test_no_auth_header(self) -> None:
|
||||
scope = {"type": "http", "headers": []}
|
||||
request = Request(scope)
|
||||
assert _extract_user_email(request) is None
|
||||
|
||||
def test_invalid_token(self) -> None:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Bearer not-a-jwt")],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _extract_user_email(request) is None
|
||||
|
||||
def test_valid_jwt(self) -> None:
|
||||
import jwt as pyjwt
|
||||
token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256")
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", f"Bearer {token}".encode())],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _extract_user_email(request) == "test@example.com"
|
||||
|
||||
|
||||
class TestMatchEndpoint:
|
||||
"""Tests for _match_endpoint."""
|
||||
|
||||
def test_exact_match(self) -> None:
|
||||
config = _make_config()
|
||||
limit = _match_endpoint("/api/listing", config)
|
||||
assert limit is not None
|
||||
assert limit.max_requests == 3
|
||||
|
||||
def test_passkey_prefix_match(self) -> None:
|
||||
config = _make_config()
|
||||
limit = _match_endpoint("/api/passkey/login/begin", config)
|
||||
assert limit is not None
|
||||
assert limit.max_requests == 2
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
config = _make_config()
|
||||
assert _match_endpoint("/api/unknown", config) is None
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Integration tests for RateLimitMiddleware."""
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
mock_pipe.execute.return_value = [1, -1] # first request, no TTL yet
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" in resp.headers
|
||||
assert resp.headers["X-RateLimit-Limit"] == "3"
|
||||
assert resp.headers["X-RateLimit-Remaining"] == "2"
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
# 4th request in window (limit=3), TTL=45s remaining
|
||||
mock_pipe.execute.return_value = [4, 45]
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 429
|
||||
assert resp.headers["Retry-After"] == "45"
|
||||
assert resp.headers["X-RateLimit-Remaining"] == "0"
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/status")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" not in resp.headers
|
||||
# Redis pipeline should never be called for exempt paths
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""When Redis raises an error, requests should be allowed through."""
|
||||
import redis
|
||||
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
mock_pipe.execute.side_effect = redis.RedisError("connection lost")
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""When Redis is unavailable at startup, requests pass through."""
|
||||
import redis
|
||||
|
||||
mock_get_redis.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""Endpoints not in the config are not rate limited."""
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)])
|
||||
app.add_middleware(RateLimitMiddleware, config=config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/unknown")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" not in resp.headers
|
||||
|
||||
|
||||
class TestRateLimitConfig:
|
||||
"""Tests for RateLimitConfig.from_env."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
with mock.patch.dict("os.environ", {}, clear=True):
|
||||
config = RateLimitConfig.from_env()
|
||||
assert config.listing_limit_cap == 100
|
||||
assert config.geojson_limit_cap == 5000
|
||||
assert config.rate_limit_redis_db == 3
|
||||
|
||||
def test_custom_env_vars(self) -> None:
|
||||
env = {
|
||||
"RATE_LIMIT_LISTING": "50/120",
|
||||
"EXPORT_LISTING_LIMIT_CAP": "200",
|
||||
"RATE_LIMIT_REDIS_DB": "5",
|
||||
"METRICS_ALLOWED_IPS": "10.0.0.1",
|
||||
}
|
||||
with mock.patch.dict("os.environ", env, clear=True):
|
||||
config = RateLimitConfig.from_env()
|
||||
assert config.endpoint_limits["/api/listing"].max_requests == 50
|
||||
assert config.endpoint_limits["/api/listing"].window_seconds == 120
|
||||
assert config.listing_limit_cap == 200
|
||||
assert config.rate_limit_redis_db == 5
|
||||
assert config.metrics_allowed_ips == "10.0.0.1"
|
||||
|
||||
def test_invalid_limit_format_uses_defaults(self) -> None:
|
||||
env = {"RATE_LIMIT_LISTING": "invalid"}
|
||||
with mock.patch.dict("os.environ", env, clear=True):
|
||||
config = RateLimitConfig.from_env()
|
||||
# Should fall back to default
|
||||
assert config.endpoint_limits["/api/listing"].max_requests == 30
|
||||
assert config.endpoint_limits["/api/listing"].window_seconds == 60
|
||||
Loading…
Add table
Add a link
Reference in a new issue