Add API rate limiting, metrics guard, and audit middleware

Per-user rate limits via Redis sliding window, IP-restricted /metrics
endpoint, audit logging of all requests, CORS tightening, and export
caps on listing/geojson endpoints.
This commit is contained in:
Viktor Barzin 2026-02-08 00:45:43 +00:00
parent 08ac72bbfc
commit 87b5bd8676
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 756 additions and 2 deletions

View file

@ -41,3 +41,25 @@ JWT_SECRET=change-me-in-production # HMAC secret for HS256 signing
JWT_ALGORITHM=HS256 # JWT signing algorithm
JWT_EXPIRATION_HOURS=24 # Token expiry in hours
JWT_ISSUER=wrongmove # JWT issuer claim
# API rate limiting (format: max_requests/window_seconds)
# RATE_LIMIT_LISTING=30/60 # /api/listing: 30 req per 60s
# RATE_LIMIT_GEOJSON=10/60 # /api/listing_geojson: 10 req per 60s
# RATE_LIMIT_GEOJSON_STREAM=10/60 # /api/listing_geojson/stream: 10 req per 60s
# RATE_LIMIT_REFRESH=3/300 # /api/refresh_listings: 3 req per 5min
# RATE_LIMIT_TASK_STATUS=60/60 # /api/task_status: 60 req per 60s
# RATE_LIMIT_TASKS_FOR_USER=30/60 # /api/tasks_for_user: 30 req per 60s
# RATE_LIMIT_CANCEL_TASK=10/60 # /api/cancel_task: 10 req per 60s
# RATE_LIMIT_CLEAR_TASKS=5/60 # /api/clear_all_tasks: 5 req per 60s
# RATE_LIMIT_DISTRICTS=20/60 # /api/get_districts: 20 req per 60s
# RATE_LIMIT_PASSKEY=10/60 # /api/passkey/*: 10 req per 60s
RATE_LIMIT_REDIS_DB=3 # Redis DB for rate limit counters
# Bulk export caps
EXPORT_LISTING_LIMIT_CAP=100 # Max listings per /api/listing request
EXPORT_GEOJSON_LIMIT_CAP=5000 # Max features per /api/listing_geojson request
EXPORT_GEOJSON_STREAM_LIMIT_CAP=10000 # Max features per /api/listing_geojson/stream
EXPORT_GEOJSON_STREAM_BATCH_CAP=200 # Max batch size for streaming
# Metrics endpoint access control (comma-separated IPs/CIDRs)
METRICS_ALLOWED_IPS=127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1

View file

@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
from api.passkey_routes import passkey_router
from api.rate_limit_config import RateLimitConfig
from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware
from api.metrics_guard import MetricsGuardMiddleware
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Query
from fastapi.responses import StreamingResponse
@ -33,6 +37,7 @@ load_dotenv()
logger = logging.getLogger("uvicorn")
DEFAULT_BATCH_SIZE = 50
_rate_limit_config = RateLimitConfig.from_env()
def get_query_parameters(
@ -82,10 +87,18 @@ hist = meter.create_histogram(
app.add_middleware(
CORSMiddleware,
allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST"],
allow_headers=["Authorization", "Content-Type"],
)
# Security middleware (added bottom-to-top; last added = outermost)
# 3. Rate limiting — enforces per-user limits
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
# 2. Metrics guard — blocks unauthorized /metrics access
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
# 1. Audit logging — logs everything including 429s and 403s
app.add_middleware(AuditLogMiddleware)
@app.get("/api/status")
async def get_status() -> dict[str, str]:
@ -100,6 +113,7 @@ async def get_listing(
limit: int = 5,
) -> dict[str, list]:
"""Get listings from the database."""
limit = min(limit, _rate_limit_config.listing_limit_cap)
repository = ListingRepository(engine)
result = await listing_service.get_listings(repository, limit=limit)
logger.info(f"Fetched {result.total_count} listings for {user.email}")
@ -113,6 +127,10 @@ async def get_listing_geojson(
limit: int | None = None,
) -> dict:
"""Get listings as GeoJSON for map display."""
if limit is not None:
limit = min(limit, _rate_limit_config.geojson_limit_cap)
else:
limit = _rate_limit_config.geojson_limit_cap
repository = ListingRepository(engine)
result = await export_service.export_to_geojson(
repository,
@ -204,6 +222,12 @@ async def stream_listing_geojson(
- batch: Array of GeoJSON features
- complete: Final message with total count
"""
batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap)
if limit is not None:
limit = min(limit, _rate_limit_config.geojson_stream_limit_cap)
else:
limit = _rate_limit_config.geojson_stream_limit_cap
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0:
generator = _stream_from_cache(query_parameters, batch_size, limit)

63
api/audit_middleware.py Normal file
View file

@ -0,0 +1,63 @@
"""Audit logging middleware for API requests."""
from __future__ import annotations
import logging
import time
import jwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
audit_logger = logging.getLogger("uvicorn.audit")
def _extract_identity(request: Request) -> str:
"""Extract user email from JWT for audit logging."""
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
return "anonymous"
token = auth_header[7:]
try:
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
return payload.get("email", "unknown")
except jwt.PyJWTError:
return "invalid-token"
def _client_ip(request: Request) -> str:
"""Best-effort client IP."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
client = request.client
return client.host if client else "unknown"
class AuditLogMiddleware(BaseHTTPMiddleware):
"""Logs all /api/ requests with method, path, user, IP, status, and duration."""
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
path = request.url.path
if not path.startswith("/api/"):
return await call_next(request)
start = time.monotonic()
identity = _extract_identity(request)
ip = _client_ip(request)
query = str(request.query_params) if request.query_params else ""
response = await call_next(request)
duration_ms = (time.monotonic() - start) * 1000
audit_logger.info(
"method=%s path=%s query=%s user=%s ip=%s status=%d duration_ms=%.1f",
request.method,
path,
query,
identity,
ip,
response.status_code,
duration_ms,
)
return response

61
api/metrics_guard.py Normal file
View file

@ -0,0 +1,61 @@
"""IP allowlist middleware for the /metrics endpoint."""
from __future__ import annotations
import ipaddress
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from api.rate_limit_config import RateLimitConfig
logger = logging.getLogger("uvicorn")
def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
"""Parse a comma-separated string of IPs/CIDRs into network objects."""
networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
networks.append(ipaddress.ip_network(entry, strict=False))
return networks
def is_ip_allowed(
ip_str: str,
allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
) -> bool:
"""Check whether an IP address falls within any of the allowed networks."""
try:
addr = ipaddress.ip_address(ip_str)
except ValueError:
return False
return any(addr in network for network in allowed_networks)
class MetricsGuardMiddleware(BaseHTTPMiddleware):
"""Restricts /metrics access to an IP allowlist."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
super().__init__(app)
cfg = config or RateLimitConfig.from_env()
self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
if not request.url.path.startswith("/metrics"):
return await call_next(request)
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
client_ip = forwarded.split(",")[0].strip()
else:
client_ip = request.client.host if request.client else "unknown"
if not is_ip_allowed(client_ip, self._allowed_networks):
logger.warning("Metrics access denied for IP %s", client_ip)
return JSONResponse(status_code=403, content={"detail": "Forbidden"})
return await call_next(request)

103
api/rate_limit_config.py Normal file
View file

@ -0,0 +1,103 @@
"""Rate limit and security configuration with environment variable loading."""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Self
@dataclass(frozen=True)
class EndpointLimit:
"""Rate limit for a single endpoint pattern."""
max_requests: int
window_seconds: int
@dataclass(frozen=True)
class RateLimitConfig:
"""Configuration for API rate limiting, export caps, and metrics access.
All values are configurable via environment variables with sensible defaults.
"""
# Per-endpoint rate limits
endpoint_limits: dict[str, EndpointLimit] = field(default_factory=lambda: {
"/api/listing": EndpointLimit(30, 60),
"/api/listing_geojson": EndpointLimit(10, 60),
"/api/listing_geojson/stream": EndpointLimit(10, 60),
"/api/refresh_listings": EndpointLimit(3, 300),
"/api/task_status": EndpointLimit(60, 60),
"/api/tasks_for_user": EndpointLimit(30, 60),
"/api/cancel_task": EndpointLimit(10, 60),
"/api/clear_all_tasks": EndpointLimit(5, 60),
"/api/get_districts": EndpointLimit(20, 60),
"/api/passkey": EndpointLimit(10, 60),
})
# Bulk export caps
listing_limit_cap: int = 100
geojson_limit_cap: int = 5_000
geojson_stream_limit_cap: int = 10_000
geojson_stream_batch_size_cap: int = 200
# Redis DB for rate limit counters
rate_limit_redis_db: int = 3
# Metrics endpoint IP allowlist (comma-separated CIDRs)
metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1"
@classmethod
def from_env(cls) -> Self:
"""Load configuration from environment variables.
Environment variables:
RATE_LIMIT_LISTING: /api/listing limit (default: 30/60s)
RATE_LIMIT_GEOJSON: /api/listing_geojson limit (default: 10/60s)
RATE_LIMIT_GEOJSON_STREAM: /api/listing_geojson/stream limit (default: 10/60s)
RATE_LIMIT_REFRESH: /api/refresh_listings limit (default: 3/300s)
RATE_LIMIT_TASK_STATUS: /api/task_status limit (default: 60/60s)
RATE_LIMIT_TASKS_FOR_USER: /api/tasks_for_user limit (default: 30/60s)
RATE_LIMIT_CANCEL_TASK: /api/cancel_task limit (default: 10/60s)
RATE_LIMIT_CLEAR_TASKS: /api/clear_all_tasks limit (default: 5/60s)
RATE_LIMIT_DISTRICTS: /api/get_districts limit (default: 20/60s)
RATE_LIMIT_PASSKEY: /api/passkey/* limit (default: 10/60s)
EXPORT_LISTING_LIMIT_CAP: Max listings per request (default: 100)
EXPORT_GEOJSON_LIMIT_CAP: Max GeoJSON features per request (default: 5000)
EXPORT_GEOJSON_STREAM_LIMIT_CAP: Max streamed features (default: 10000)
EXPORT_GEOJSON_STREAM_BATCH_CAP: Max stream batch size (default: 200)
RATE_LIMIT_REDIS_DB: Redis DB number for counters (default: 3)
METRICS_ALLOWED_IPS: Comma-separated CIDRs for /metrics (default: private ranges)
"""
def _parse_limit(env_var: str, default_requests: int, default_window: int) -> EndpointLimit:
raw = os.environ.get(env_var)
if raw:
parts = raw.split("/")
if len(parts) == 2:
return EndpointLimit(int(parts[0]), int(parts[1]))
return EndpointLimit(default_requests, default_window)
return cls(
endpoint_limits={
"/api/listing": _parse_limit("RATE_LIMIT_LISTING", 30, 60),
"/api/listing_geojson": _parse_limit("RATE_LIMIT_GEOJSON", 10, 60),
"/api/listing_geojson/stream": _parse_limit("RATE_LIMIT_GEOJSON_STREAM", 10, 60),
"/api/refresh_listings": _parse_limit("RATE_LIMIT_REFRESH", 3, 300),
"/api/task_status": _parse_limit("RATE_LIMIT_TASK_STATUS", 60, 60),
"/api/tasks_for_user": _parse_limit("RATE_LIMIT_TASKS_FOR_USER", 30, 60),
"/api/cancel_task": _parse_limit("RATE_LIMIT_CANCEL_TASK", 10, 60),
"/api/clear_all_tasks": _parse_limit("RATE_LIMIT_CLEAR_TASKS", 5, 60),
"/api/get_districts": _parse_limit("RATE_LIMIT_DISTRICTS", 20, 60),
"/api/passkey": _parse_limit("RATE_LIMIT_PASSKEY", 10, 60),
},
listing_limit_cap=int(os.environ.get("EXPORT_LISTING_LIMIT_CAP", "100")),
geojson_limit_cap=int(os.environ.get("EXPORT_GEOJSON_LIMIT_CAP", "5000")),
geojson_stream_limit_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_LIMIT_CAP", "10000")),
geojson_stream_batch_size_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_BATCH_CAP", "200")),
rate_limit_redis_db=int(os.environ.get("RATE_LIMIT_REDIS_DB", "3")),
metrics_allowed_ips=os.environ.get(
"METRICS_ALLOWED_IPS",
"127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1",
),
)

132
api/rate_limiter.py Normal file
View file

@ -0,0 +1,132 @@
"""Per-user rate limiting middleware using Redis fixed-window counters."""
from __future__ import annotations
import logging
import os
from urllib.parse import urlparse, urlunparse
import jwt
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from api.rate_limit_config import EndpointLimit, RateLimitConfig
logger = logging.getLogger("uvicorn")
# Paths exempt from rate limiting
EXEMPT_PATHS = {"/api/status", "/metrics"}
def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis: # type: ignore[type-arg]
"""Get a Redis client for rate limit counters, using the configured DB."""
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
parsed = urlparse(broker_url)
url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}"))
return redis.from_url(url, decode_responses=True)
def _extract_user_email(request: Request) -> str | None:
"""Extract user email from JWT without full verification.
Mirrors the unverified-decode pattern used in api/auth.py:89 for routing.
Only used for rate-limit keying, not for authorization.
"""
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header[7:]
try:
payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
return payload.get("email")
except jwt.PyJWTError:
return None
def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None:
"""Find the rate limit for a request path.
Passkey routes (/api/passkey/*) all share the /api/passkey limit.
"""
if path.startswith("/api/passkey"):
return config.endpoint_limits.get("/api/passkey")
return config.endpoint_limits.get(path)
def _client_ip(request: Request) -> str:
"""Best-effort client IP from X-Forwarded-For or connection."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
client = request.client
return client.host if client else "unknown"
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
super().__init__(app)
self.config = config or RateLimitConfig.from_env()
try:
self._redis = _get_rate_limit_redis(self.config)
self._redis.ping()
except redis.RedisError:
logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
self._redis = None
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
path = request.url.path
# Skip exempt paths
if path in EXEMPT_PATHS:
return await call_next(request)
limit = _match_endpoint(path, self.config)
if limit is None:
return await call_next(request)
# Determine identity for the counter key
identity = _extract_user_email(request) or _client_ip(request)
# If Redis is unavailable, fail open
if self._redis is None:
return await call_next(request)
redis_key = f"ratelimit:{identity}:{path}"
try:
pipe = self._redis.pipeline(transaction=True)
pipe.incr(redis_key)
pipe.ttl(redis_key)
result = pipe.execute()
current_count: int = result[0]
ttl: int = result[1]
# Set expiry on first request in window
if ttl == -1:
self._redis.expire(redis_key, limit.window_seconds)
ttl = limit.window_seconds
remaining = max(0, limit.max_requests - current_count)
if current_count > limit.max_requests:
retry_after = max(1, ttl)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": str(limit.max_requests),
"X-RateLimit-Remaining": "0",
},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, failing open: {e}")
return await call_next(request)

View file

@ -0,0 +1,111 @@
"""Unit tests for api/metrics_guard.py."""
import ipaddress
import pytest
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
from api.rate_limit_config import RateLimitConfig
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
class TestParseAllowedNetworks:
"""Tests for parse_allowed_networks."""
def test_single_ip(self) -> None:
nets = parse_allowed_networks("127.0.0.1")
assert len(nets) == 1
assert ipaddress.ip_address("127.0.0.1") in nets[0]
def test_cidr(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert len(nets) == 1
assert ipaddress.ip_address("10.255.255.255") in nets[0]
def test_multiple_entries(self) -> None:
nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
assert len(nets) == 3
def test_empty_string(self) -> None:
nets = parse_allowed_networks("")
assert nets == []
def test_trailing_comma(self) -> None:
nets = parse_allowed_networks("127.0.0.1,")
assert len(nets) == 1
class TestIsIpAllowed:
"""Tests for is_ip_allowed."""
def test_allowed_ip(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert is_ip_allowed("10.1.2.3", nets) is True
def test_denied_ip(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert is_ip_allowed("192.168.1.1", nets) is False
def test_ipv6(self) -> None:
nets = parse_allowed_networks("::1")
assert is_ip_allowed("::1", nets) is True
assert is_ip_allowed("::2", nets) is False
def test_invalid_ip(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert is_ip_allowed("not-an-ip", nets) is False
class TestMetricsGuardMiddleware:
"""Integration tests for MetricsGuardMiddleware."""
def _build_app(self, allowed_ips: str) -> Starlette:
config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
app = Starlette(routes=[
Route("/metrics", _ok_endpoint),
Route("/api/status", _ok_endpoint),
])
app.add_middleware(MetricsGuardMiddleware, config=config)
return app
def test_allows_metrics_from_allowed_ip(self) -> None:
app = self._build_app("127.0.0.1,testclient")
# TestClient connects from 'testclient' by default
# We need to override; use the header approach
app2 = self._build_app("10.0.0.1")
client = TestClient(app2)
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
assert resp.status_code == 200
def test_blocks_metrics_from_disallowed_ip(self) -> None:
app = self._build_app("10.0.0.0/8")
client = TestClient(app)
resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
assert resp.status_code == 403
def test_non_metrics_path_passes_through(self) -> None:
app = self._build_app("10.0.0.0/8")
client = TestClient(app)
resp = client.get("/api/status")
assert resp.status_code == 200
def test_default_private_ranges(self) -> None:
config = RateLimitConfig()
app = Starlette(routes=[Route("/metrics", _ok_endpoint)])
app.add_middleware(MetricsGuardMiddleware, config=config)
client = TestClient(app)
# Private IP should be allowed
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
assert resp.status_code == 200
# Public IP should be denied
resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
assert resp.status_code == 403

View file

@ -0,0 +1,238 @@
"""Unit tests for api/rate_limiter.py."""
from unittest import mock
import pytest
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.rate_limit_config import EndpointLimit, RateLimitConfig
from api.rate_limiter import (
RateLimitMiddleware,
_extract_user_email,
_match_endpoint,
EXEMPT_PATHS,
)
def _make_config(**overrides: object) -> RateLimitConfig:
"""Create a RateLimitConfig with defaults for testing."""
defaults: dict[str, object] = {
"endpoint_limits": {
"/api/listing": EndpointLimit(3, 60),
"/api/passkey": EndpointLimit(2, 60),
},
"listing_limit_cap": 100,
"geojson_limit_cap": 5000,
"geojson_stream_limit_cap": 10000,
"geojson_stream_batch_size_cap": 200,
"rate_limit_redis_db": 3,
"metrics_allowed_ips": "127.0.0.1",
}
defaults.update(overrides)
return RateLimitConfig(**defaults) # type: ignore[arg-type]
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
def _build_app(config: RateLimitConfig) -> Starlette:
"""Build a minimal Starlette app with the rate limiter."""
app = Starlette(routes=[
Route("/api/listing", _ok_endpoint),
Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]),
Route("/api/status", _ok_endpoint),
])
app.add_middleware(RateLimitMiddleware, config=config)
return app
class TestExtractUserEmail:
"""Tests for _extract_user_email."""
def test_no_auth_header(self) -> None:
scope = {"type": "http", "headers": []}
request = Request(scope)
assert _extract_user_email(request) is None
def test_invalid_token(self) -> None:
scope = {
"type": "http",
"headers": [(b"authorization", b"Bearer not-a-jwt")],
}
request = Request(scope)
assert _extract_user_email(request) is None
def test_valid_jwt(self) -> None:
import jwt as pyjwt
token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256")
scope = {
"type": "http",
"headers": [(b"authorization", f"Bearer {token}".encode())],
}
request = Request(scope)
assert _extract_user_email(request) == "test@example.com"
class TestMatchEndpoint:
"""Tests for _match_endpoint."""
def test_exact_match(self) -> None:
config = _make_config()
limit = _match_endpoint("/api/listing", config)
assert limit is not None
assert limit.max_requests == 3
def test_passkey_prefix_match(self) -> None:
config = _make_config()
limit = _match_endpoint("/api/passkey/login/begin", config)
assert limit is not None
assert limit.max_requests == 2
def test_no_match(self) -> None:
config = _make_config()
assert _match_endpoint("/api/unknown", config) is None
class TestRateLimitMiddleware:
"""Integration tests for RateLimitMiddleware."""
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None:
mock_redis = mock.MagicMock()
mock_pipe = mock.MagicMock()
mock_pipe.execute.return_value = [1, -1] # first request, no TTL yet
mock_redis.pipeline.return_value = mock_pipe
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 200
assert "X-RateLimit-Limit" in resp.headers
assert resp.headers["X-RateLimit-Limit"] == "3"
assert resp.headers["X-RateLimit-Remaining"] == "2"
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None:
mock_redis = mock.MagicMock()
mock_pipe = mock.MagicMock()
# 4th request in window (limit=3), TTL=45s remaining
mock_pipe.execute.return_value = [4, 45]
mock_redis.pipeline.return_value = mock_pipe
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 429
assert resp.headers["Retry-After"] == "45"
assert resp.headers["X-RateLimit-Remaining"] == "0"
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None:
mock_redis = mock.MagicMock()
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/status")
assert resp.status_code == 200
assert "X-RateLimit-Limit" not in resp.headers
# Redis pipeline should never be called for exempt paths
mock_redis.pipeline.assert_not_called()
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None:
"""When Redis raises an error, requests should be allowed through."""
import redis
mock_redis = mock.MagicMock()
mock_pipe = mock.MagicMock()
mock_pipe.execute.side_effect = redis.RedisError("connection lost")
mock_redis.pipeline.return_value = mock_pipe
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 200
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None:
"""When Redis is unavailable at startup, requests pass through."""
import redis
mock_get_redis.side_effect = redis.RedisError("connection refused")
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 200
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None:
"""Endpoints not in the config are not rate limited."""
mock_redis = mock.MagicMock()
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)])
app.add_middleware(RateLimitMiddleware, config=config)
client = TestClient(app)
resp = client.get("/api/unknown")
assert resp.status_code == 200
assert "X-RateLimit-Limit" not in resp.headers
class TestRateLimitConfig:
"""Tests for RateLimitConfig.from_env."""
def test_defaults(self) -> None:
with mock.patch.dict("os.environ", {}, clear=True):
config = RateLimitConfig.from_env()
assert config.listing_limit_cap == 100
assert config.geojson_limit_cap == 5000
assert config.rate_limit_redis_db == 3
def test_custom_env_vars(self) -> None:
env = {
"RATE_LIMIT_LISTING": "50/120",
"EXPORT_LISTING_LIMIT_CAP": "200",
"RATE_LIMIT_REDIS_DB": "5",
"METRICS_ALLOWED_IPS": "10.0.0.1",
}
with mock.patch.dict("os.environ", env, clear=True):
config = RateLimitConfig.from_env()
assert config.endpoint_limits["/api/listing"].max_requests == 50
assert config.endpoint_limits["/api/listing"].window_seconds == 120
assert config.listing_limit_cap == 200
assert config.rate_limit_redis_db == 5
assert config.metrics_allowed_ips == "10.0.0.1"
def test_invalid_limit_format_uses_defaults(self) -> None:
env = {"RATE_LIMIT_LISTING": "invalid"}
with mock.patch.dict("os.environ", env, clear=True):
config = RateLimitConfig.from_env()
# Should fall back to default
assert config.endpoint_limits["/api/listing"].max_requests == 30
assert config.endpoint_limits["/api/listing"].window_seconds == 60