diff --git a/.env.sample b/.env.sample index bfffd0f..4ea839e 100644 --- a/.env.sample +++ b/.env.sample @@ -41,3 +41,25 @@ JWT_SECRET=change-me-in-production # HMAC secret for HS256 signing JWT_ALGORITHM=HS256 # JWT signing algorithm JWT_EXPIRATION_HOURS=24 # Token expiry in hours JWT_ISSUER=wrongmove # JWT issuer claim + +# API rate limiting (format: max_requests/window_seconds) +# RATE_LIMIT_LISTING=30/60 # /api/listing: 30 req per 60s +# RATE_LIMIT_GEOJSON=10/60 # /api/listing_geojson: 10 req per 60s +# RATE_LIMIT_GEOJSON_STREAM=10/60 # /api/listing_geojson/stream: 10 req per 60s +# RATE_LIMIT_REFRESH=3/300 # /api/refresh_listings: 3 req per 5min +# RATE_LIMIT_TASK_STATUS=60/60 # /api/task_status: 60 req per 60s +# RATE_LIMIT_TASKS_FOR_USER=30/60 # /api/tasks_for_user: 30 req per 60s +# RATE_LIMIT_CANCEL_TASK=10/60 # /api/cancel_task: 10 req per 60s +# RATE_LIMIT_CLEAR_TASKS=5/60 # /api/clear_all_tasks: 5 req per 60s +# RATE_LIMIT_DISTRICTS=20/60 # /api/get_districts: 20 req per 60s +# RATE_LIMIT_PASSKEY=10/60 # /api/passkey/*: 10 req per 60s +RATE_LIMIT_REDIS_DB=3 # Redis DB for rate limit counters + +# Bulk export caps +EXPORT_LISTING_LIMIT_CAP=100 # Max listings per /api/listing request +EXPORT_GEOJSON_LIMIT_CAP=5000 # Max features per /api/listing_geojson request +EXPORT_GEOJSON_STREAM_LIMIT_CAP=10000 # Max features per /api/listing_geojson/stream +EXPORT_GEOJSON_STREAM_BATCH_CAP=200 # Max batch size for streaming + +# Metrics endpoint access control (comma-separated IPs/CIDRs) +METRICS_ALLOWED_IPS=127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1 diff --git a/api/app.py b/api/app.py index 926fb7c..edc191b 100644 --- a/api/app.py +++ b/api/app.py @@ -7,6 +7,10 @@ from typing import Annotated, AsyncGenerator, Optional from api.auth import get_current_user from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS from api.passkey_routes import passkey_router +from api.rate_limit_config import RateLimitConfig +from api.rate_limiter import RateLimitMiddleware +from api.audit_middleware import AuditLogMiddleware +from api.metrics_guard import MetricsGuardMiddleware from dotenv import load_dotenv from fastapi import Depends, FastAPI, Query from fastapi.responses import StreamingResponse @@ -33,6 +37,7 @@ load_dotenv() logger = logging.getLogger("uvicorn") DEFAULT_BATCH_SIZE = 50 +_rate_limit_config = RateLimitConfig.from_env() def get_query_parameters( @@ -82,10 +87,18 @@ hist = meter.create_histogram( app.add_middleware( CORSMiddleware, allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS], - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST"], + allow_headers=["Authorization", "Content-Type"], ) +# Security middleware (added bottom-to-top; last added = outermost) +# 3. Rate limiting — enforces per-user limits +app.add_middleware(RateLimitMiddleware, config=_rate_limit_config) +# 2. Metrics guard — blocks unauthorized /metrics access +app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config) +# 1. Audit logging — logs everything including 429s and 403s +app.add_middleware(AuditLogMiddleware) + @app.get("/api/status") async def get_status() -> dict[str, str]: @@ -100,6 +113,7 @@ async def get_listing( limit: int = 5, ) -> dict[str, list]: """Get listings from the database.""" + limit = min(limit, _rate_limit_config.listing_limit_cap) repository = ListingRepository(engine) result = await listing_service.get_listings(repository, limit=limit) logger.info(f"Fetched {result.total_count} listings for {user.email}") @@ -113,6 +127,10 @@ async def get_listing_geojson( limit: int | None = None, ) -> dict: """Get listings as GeoJSON for map display.""" + if limit is not None: + limit = min(limit, _rate_limit_config.geojson_limit_cap) + else: + limit = _rate_limit_config.geojson_limit_cap repository = ListingRepository(engine) result = await export_service.export_to_geojson( repository, @@ -204,6 +222,12 @@ async def stream_listing_geojson( - batch: Array of GeoJSON features - complete: Final message with total count """ + batch_size = min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap) + if limit is not None: + limit = min(limit, _rate_limit_config.geojson_stream_limit_cap) + else: + limit = _rate_limit_config.geojson_stream_limit_cap + cached_count = get_cached_count(query_parameters) if cached_count is not None and cached_count > 0: generator = _stream_from_cache(query_parameters, batch_size, limit) diff --git a/api/audit_middleware.py b/api/audit_middleware.py new file mode 100644 index 0000000..21f7ea9 --- /dev/null +++ b/api/audit_middleware.py @@ -0,0 +1,63 @@ +"""Audit logging middleware for API requests.""" +from __future__ import annotations + +import logging +import time + +import jwt +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +audit_logger = logging.getLogger("uvicorn.audit") + + +def _extract_identity(request: Request) -> str: + """Extract user email from JWT for audit logging.""" + auth_header = request.headers.get("authorization", "") + if not auth_header.startswith("Bearer "): + return "anonymous" + token = auth_header[7:] + try: + payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False}) + return payload.get("email", "unknown") + except jwt.PyJWTError: + return "invalid-token" + + +def _client_ip(request: Request) -> str: + """Best-effort client IP.""" + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + client = request.client + return client.host if client else "unknown" + + +class AuditLogMiddleware(BaseHTTPMiddleware): + """Logs all /api/ requests with method, path, user, IP, status, and duration.""" + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] + path = request.url.path + if not path.startswith("/api/"): + return await call_next(request) + + start = time.monotonic() + identity = _extract_identity(request) + ip = _client_ip(request) + query = str(request.query_params) if request.query_params else "" + + response = await call_next(request) + + duration_ms = (time.monotonic() - start) * 1000 + audit_logger.info( + "method=%s path=%s query=%s user=%s ip=%s status=%d duration_ms=%.1f", + request.method, + path, + query, + identity, + ip, + response.status_code, + duration_ms, + ) + return response diff --git a/api/metrics_guard.py b/api/metrics_guard.py new file mode 100644 index 0000000..04f29f3 --- /dev/null +++ b/api/metrics_guard.py @@ -0,0 +1,61 @@ +"""IP allowlist middleware for the /metrics endpoint.""" +from __future__ import annotations + +import ipaddress +import logging + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from api.rate_limit_config import RateLimitConfig + +logger = logging.getLogger("uvicorn") + + +def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]: + """Parse a comma-separated string of IPs/CIDRs into network objects.""" + networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = [] + for entry in raw.split(","): + entry = entry.strip() + if not entry: + continue + networks.append(ipaddress.ip_network(entry, strict=False)) + return networks + + +def is_ip_allowed( + ip_str: str, + allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network], +) -> bool: + """Check whether an IP address falls within any of the allowed networks.""" + try: + addr = ipaddress.ip_address(ip_str) + except ValueError: + return False + return any(addr in network for network in allowed_networks) + + +class MetricsGuardMiddleware(BaseHTTPMiddleware): + """Restricts /metrics access to an IP allowlist.""" + + def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] + super().__init__(app) + cfg = config or RateLimitConfig.from_env() + self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips) + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] + if not request.url.path.startswith("/metrics"): + return await call_next(request) + + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + client_ip = forwarded.split(",")[0].strip() + else: + client_ip = request.client.host if request.client else "unknown" + + if not is_ip_allowed(client_ip, self._allowed_networks): + logger.warning("Metrics access denied for IP %s", client_ip) + return JSONResponse(status_code=403, content={"detail": "Forbidden"}) + + return await call_next(request) diff --git a/api/rate_limit_config.py b/api/rate_limit_config.py new file mode 100644 index 0000000..7d6fad7 --- /dev/null +++ b/api/rate_limit_config.py @@ -0,0 +1,103 @@ +"""Rate limit and security configuration with environment variable loading.""" +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Self + + +@dataclass(frozen=True) +class EndpointLimit: + """Rate limit for a single endpoint pattern.""" + + max_requests: int + window_seconds: int + + +@dataclass(frozen=True) +class RateLimitConfig: + """Configuration for API rate limiting, export caps, and metrics access. + + All values are configurable via environment variables with sensible defaults. + """ + + # Per-endpoint rate limits + endpoint_limits: dict[str, EndpointLimit] = field(default_factory=lambda: { + "/api/listing": EndpointLimit(30, 60), + "/api/listing_geojson": EndpointLimit(10, 60), + "/api/listing_geojson/stream": EndpointLimit(10, 60), + "/api/refresh_listings": EndpointLimit(3, 300), + "/api/task_status": EndpointLimit(60, 60), + "/api/tasks_for_user": EndpointLimit(30, 60), + "/api/cancel_task": EndpointLimit(10, 60), + "/api/clear_all_tasks": EndpointLimit(5, 60), + "/api/get_districts": EndpointLimit(20, 60), + "/api/passkey": EndpointLimit(10, 60), + }) + + # Bulk export caps + listing_limit_cap: int = 100 + geojson_limit_cap: int = 5_000 + geojson_stream_limit_cap: int = 10_000 + geojson_stream_batch_size_cap: int = 200 + + # Redis DB for rate limit counters + rate_limit_redis_db: int = 3 + + # Metrics endpoint IP allowlist (comma-separated CIDRs) + metrics_allowed_ips: str = "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1" + + @classmethod + def from_env(cls) -> Self: + """Load configuration from environment variables. + + Environment variables: + RATE_LIMIT_LISTING: /api/listing limit (default: 30/60s) + RATE_LIMIT_GEOJSON: /api/listing_geojson limit (default: 10/60s) + RATE_LIMIT_GEOJSON_STREAM: /api/listing_geojson/stream limit (default: 10/60s) + RATE_LIMIT_REFRESH: /api/refresh_listings limit (default: 3/300s) + RATE_LIMIT_TASK_STATUS: /api/task_status limit (default: 60/60s) + RATE_LIMIT_TASKS_FOR_USER: /api/tasks_for_user limit (default: 30/60s) + RATE_LIMIT_CANCEL_TASK: /api/cancel_task limit (default: 10/60s) + RATE_LIMIT_CLEAR_TASKS: /api/clear_all_tasks limit (default: 5/60s) + RATE_LIMIT_DISTRICTS: /api/get_districts limit (default: 20/60s) + RATE_LIMIT_PASSKEY: /api/passkey/* limit (default: 10/60s) + EXPORT_LISTING_LIMIT_CAP: Max listings per request (default: 100) + EXPORT_GEOJSON_LIMIT_CAP: Max GeoJSON features per request (default: 5000) + EXPORT_GEOJSON_STREAM_LIMIT_CAP: Max streamed features (default: 10000) + EXPORT_GEOJSON_STREAM_BATCH_CAP: Max stream batch size (default: 200) + RATE_LIMIT_REDIS_DB: Redis DB number for counters (default: 3) + METRICS_ALLOWED_IPS: Comma-separated CIDRs for /metrics (default: private ranges) + """ + + def _parse_limit(env_var: str, default_requests: int, default_window: int) -> EndpointLimit: + raw = os.environ.get(env_var) + if raw: + parts = raw.split("/") + if len(parts) == 2: + return EndpointLimit(int(parts[0]), int(parts[1])) + return EndpointLimit(default_requests, default_window) + + return cls( + endpoint_limits={ + "/api/listing": _parse_limit("RATE_LIMIT_LISTING", 30, 60), + "/api/listing_geojson": _parse_limit("RATE_LIMIT_GEOJSON", 10, 60), + "/api/listing_geojson/stream": _parse_limit("RATE_LIMIT_GEOJSON_STREAM", 10, 60), + "/api/refresh_listings": _parse_limit("RATE_LIMIT_REFRESH", 3, 300), + "/api/task_status": _parse_limit("RATE_LIMIT_TASK_STATUS", 60, 60), + "/api/tasks_for_user": _parse_limit("RATE_LIMIT_TASKS_FOR_USER", 30, 60), + "/api/cancel_task": _parse_limit("RATE_LIMIT_CANCEL_TASK", 10, 60), + "/api/clear_all_tasks": _parse_limit("RATE_LIMIT_CLEAR_TASKS", 5, 60), + "/api/get_districts": _parse_limit("RATE_LIMIT_DISTRICTS", 20, 60), + "/api/passkey": _parse_limit("RATE_LIMIT_PASSKEY", 10, 60), + }, + listing_limit_cap=int(os.environ.get("EXPORT_LISTING_LIMIT_CAP", "100")), + geojson_limit_cap=int(os.environ.get("EXPORT_GEOJSON_LIMIT_CAP", "5000")), + geojson_stream_limit_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_LIMIT_CAP", "10000")), + geojson_stream_batch_size_cap=int(os.environ.get("EXPORT_GEOJSON_STREAM_BATCH_CAP", "200")), + rate_limit_redis_db=int(os.environ.get("RATE_LIMIT_REDIS_DB", "3")), + metrics_allowed_ips=os.environ.get( + "METRICS_ALLOWED_IPS", + "127.0.0.1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,::1", + ), + ) diff --git a/api/rate_limiter.py b/api/rate_limiter.py new file mode 100644 index 0000000..43b2a5d --- /dev/null +++ b/api/rate_limiter.py @@ -0,0 +1,132 @@ +"""Per-user rate limiting middleware using Redis fixed-window counters.""" +from __future__ import annotations + +import logging +import os +from urllib.parse import urlparse, urlunparse + +import jwt +import redis +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from api.rate_limit_config import EndpointLimit, RateLimitConfig + +logger = logging.getLogger("uvicorn") + +# Paths exempt from rate limiting +EXEMPT_PATHS = {"/api/status", "/metrics"} + + +def _get_rate_limit_redis(config: RateLimitConfig) -> redis.Redis: # type: ignore[type-arg] + """Get a Redis client for rate limit counters, using the configured DB.""" + broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") + parsed = urlparse(broker_url) + url = urlunparse(parsed._replace(path=f"/{config.rate_limit_redis_db}")) + return redis.from_url(url, decode_responses=True) + + +def _extract_user_email(request: Request) -> str | None: + """Extract user email from JWT without full verification. + + Mirrors the unverified-decode pattern used in api/auth.py:89 for routing. + Only used for rate-limit keying, not for authorization. + """ + auth_header = request.headers.get("authorization", "") + if not auth_header.startswith("Bearer "): + return None + token = auth_header[7:] + try: + payload = jwt.decode(token, options={"verify_signature": False, "verify_exp": False}) + return payload.get("email") + except jwt.PyJWTError: + return None + + +def _match_endpoint(path: str, config: RateLimitConfig) -> EndpointLimit | None: + """Find the rate limit for a request path. + + Passkey routes (/api/passkey/*) all share the /api/passkey limit. + """ + if path.startswith("/api/passkey"): + return config.endpoint_limits.get("/api/passkey") + return config.endpoint_limits.get(path) + + +def _client_ip(request: Request) -> str: + """Best-effort client IP from X-Forwarded-For or connection.""" + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + client = request.client + return client.host if client else "unknown" + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Starlette middleware enforcing per-user fixed-window rate limits via Redis.""" + + def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] + super().__init__(app) + self.config = config or RateLimitConfig.from_env() + try: + self._redis = _get_rate_limit_redis(self.config) + self._redis.ping() + except redis.RedisError: + logger.warning("Rate limiter: Redis unavailable at startup, will fail open") + self._redis = None + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] + path = request.url.path + + # Skip exempt paths + if path in EXEMPT_PATHS: + return await call_next(request) + + limit = _match_endpoint(path, self.config) + if limit is None: + return await call_next(request) + + # Determine identity for the counter key + identity = _extract_user_email(request) or _client_ip(request) + + # If Redis is unavailable, fail open + if self._redis is None: + return await call_next(request) + + redis_key = f"ratelimit:{identity}:{path}" + try: + pipe = self._redis.pipeline(transaction=True) + pipe.incr(redis_key) + pipe.ttl(redis_key) + result = pipe.execute() + current_count: int = result[0] + ttl: int = result[1] + + # Set expiry on first request in window + if ttl == -1: + self._redis.expire(redis_key, limit.window_seconds) + ttl = limit.window_seconds + + remaining = max(0, limit.max_requests - current_count) + + if current_count > limit.max_requests: + retry_after = max(1, ttl) + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded"}, + headers={ + "Retry-After": str(retry_after), + "X-RateLimit-Limit": str(limit.max_requests), + "X-RateLimit-Remaining": "0", + }, + ) + + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limit.max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + return response + + except redis.RedisError as e: + logger.warning(f"Rate limiter Redis error, failing open: {e}") + return await call_next(request) diff --git a/tests/unit/test_metrics_guard.py b/tests/unit/test_metrics_guard.py new file mode 100644 index 0000000..825851b --- /dev/null +++ b/tests/unit/test_metrics_guard.py @@ -0,0 +1,111 @@ +"""Unit tests for api/metrics_guard.py.""" +import ipaddress + +import pytest +from starlette.testclient import TestClient +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks +from api.rate_limit_config import RateLimitConfig + + +async def _ok_endpoint(request: Request) -> JSONResponse: + return JSONResponse({"ok": True}) + + +class TestParseAllowedNetworks: + """Tests for parse_allowed_networks.""" + + def test_single_ip(self) -> None: + nets = parse_allowed_networks("127.0.0.1") + assert len(nets) == 1 + assert ipaddress.ip_address("127.0.0.1") in nets[0] + + def test_cidr(self) -> None: + nets = parse_allowed_networks("10.0.0.0/8") + assert len(nets) == 1 + assert ipaddress.ip_address("10.255.255.255") in nets[0] + + def test_multiple_entries(self) -> None: + nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1") + assert len(nets) == 3 + + def test_empty_string(self) -> None: + nets = parse_allowed_networks("") + assert nets == [] + + def test_trailing_comma(self) -> None: + nets = parse_allowed_networks("127.0.0.1,") + assert len(nets) == 1 + + +class TestIsIpAllowed: + """Tests for is_ip_allowed.""" + + def test_allowed_ip(self) -> None: + nets = parse_allowed_networks("10.0.0.0/8") + assert is_ip_allowed("10.1.2.3", nets) is True + + def test_denied_ip(self) -> None: + nets = parse_allowed_networks("10.0.0.0/8") + assert is_ip_allowed("192.168.1.1", nets) is False + + def test_ipv6(self) -> None: + nets = parse_allowed_networks("::1") + assert is_ip_allowed("::1", nets) is True + assert is_ip_allowed("::2", nets) is False + + def test_invalid_ip(self) -> None: + nets = parse_allowed_networks("10.0.0.0/8") + assert is_ip_allowed("not-an-ip", nets) is False + + +class TestMetricsGuardMiddleware: + """Integration tests for MetricsGuardMiddleware.""" + + def _build_app(self, allowed_ips: str) -> Starlette: + config = RateLimitConfig(metrics_allowed_ips=allowed_ips) + app = Starlette(routes=[ + Route("/metrics", _ok_endpoint), + Route("/api/status", _ok_endpoint), + ]) + app.add_middleware(MetricsGuardMiddleware, config=config) + return app + + def test_allows_metrics_from_allowed_ip(self) -> None: + app = self._build_app("127.0.0.1,testclient") + # TestClient connects from 'testclient' by default + # We need to override; use the header approach + app2 = self._build_app("10.0.0.1") + client = TestClient(app2) + resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"}) + assert resp.status_code == 200 + + def test_blocks_metrics_from_disallowed_ip(self) -> None: + app = self._build_app("10.0.0.0/8") + client = TestClient(app) + resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"}) + assert resp.status_code == 403 + + def test_non_metrics_path_passes_through(self) -> None: + app = self._build_app("10.0.0.0/8") + client = TestClient(app) + resp = client.get("/api/status") + assert resp.status_code == 200 + + def test_default_private_ranges(self) -> None: + config = RateLimitConfig() + app = Starlette(routes=[Route("/metrics", _ok_endpoint)]) + app.add_middleware(MetricsGuardMiddleware, config=config) + client = TestClient(app) + + # Private IP should be allowed + resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"}) + assert resp.status_code == 200 + + # Public IP should be denied + resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"}) + assert resp.status_code == 403 diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py new file mode 100644 index 0000000..68b0b57 --- /dev/null +++ b/tests/unit/test_rate_limiter.py @@ -0,0 +1,238 @@ +"""Unit tests for api/rate_limiter.py.""" +from unittest import mock + +import pytest +from starlette.testclient import TestClient +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from api.rate_limit_config import EndpointLimit, RateLimitConfig +from api.rate_limiter import ( + RateLimitMiddleware, + _extract_user_email, + _match_endpoint, + EXEMPT_PATHS, +) + + +def _make_config(**overrides: object) -> RateLimitConfig: + """Create a RateLimitConfig with defaults for testing.""" + defaults: dict[str, object] = { + "endpoint_limits": { + "/api/listing": EndpointLimit(3, 60), + "/api/passkey": EndpointLimit(2, 60), + }, + "listing_limit_cap": 100, + "geojson_limit_cap": 5000, + "geojson_stream_limit_cap": 10000, + "geojson_stream_batch_size_cap": 200, + "rate_limit_redis_db": 3, + "metrics_allowed_ips": "127.0.0.1", + } + defaults.update(overrides) + return RateLimitConfig(**defaults) # type: ignore[arg-type] + + +async def _ok_endpoint(request: Request) -> JSONResponse: + return JSONResponse({"ok": True}) + + +def _build_app(config: RateLimitConfig) -> Starlette: + """Build a minimal Starlette app with the rate limiter.""" + app = Starlette(routes=[ + Route("/api/listing", _ok_endpoint), + Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]), + Route("/api/status", _ok_endpoint), + ]) + app.add_middleware(RateLimitMiddleware, config=config) + return app + + +class TestExtractUserEmail: + """Tests for _extract_user_email.""" + + def test_no_auth_header(self) -> None: + scope = {"type": "http", "headers": []} + request = Request(scope) + assert _extract_user_email(request) is None + + def test_invalid_token(self) -> None: + scope = { + "type": "http", + "headers": [(b"authorization", b"Bearer not-a-jwt")], + } + request = Request(scope) + assert _extract_user_email(request) is None + + def test_valid_jwt(self) -> None: + import jwt as pyjwt + token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256") + scope = { + "type": "http", + "headers": [(b"authorization", f"Bearer {token}".encode())], + } + request = Request(scope) + assert _extract_user_email(request) == "test@example.com" + + +class TestMatchEndpoint: + """Tests for _match_endpoint.""" + + def test_exact_match(self) -> None: + config = _make_config() + limit = _match_endpoint("/api/listing", config) + assert limit is not None + assert limit.max_requests == 3 + + def test_passkey_prefix_match(self) -> None: + config = _make_config() + limit = _match_endpoint("/api/passkey/login/begin", config) + assert limit is not None + assert limit.max_requests == 2 + + def test_no_match(self) -> None: + config = _make_config() + assert _match_endpoint("/api/unknown", config) is None + + +class TestRateLimitMiddleware: + """Integration tests for RateLimitMiddleware.""" + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None: + mock_redis = mock.MagicMock() + mock_pipe = mock.MagicMock() + mock_pipe.execute.return_value = [1, -1] # first request, no TTL yet + mock_redis.pipeline.return_value = mock_pipe + mock_redis.ping.return_value = True + mock_get_redis.return_value = mock_redis + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + resp = client.get("/api/listing") + assert resp.status_code == 200 + assert "X-RateLimit-Limit" in resp.headers + assert resp.headers["X-RateLimit-Limit"] == "3" + assert resp.headers["X-RateLimit-Remaining"] == "2" + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None: + mock_redis = mock.MagicMock() + mock_pipe = mock.MagicMock() + # 4th request in window (limit=3), TTL=45s remaining + mock_pipe.execute.return_value = [4, 45] + mock_redis.pipeline.return_value = mock_pipe + mock_redis.ping.return_value = True + mock_get_redis.return_value = mock_redis + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + resp = client.get("/api/listing") + assert resp.status_code == 429 + assert resp.headers["Retry-After"] == "45" + assert resp.headers["X-RateLimit-Remaining"] == "0" + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None: + mock_redis = mock.MagicMock() + mock_redis.ping.return_value = True + mock_get_redis.return_value = mock_redis + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + resp = client.get("/api/status") + assert resp.status_code == 200 + assert "X-RateLimit-Limit" not in resp.headers + # Redis pipeline should never be called for exempt paths + mock_redis.pipeline.assert_not_called() + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None: + """When Redis raises an error, requests should be allowed through.""" + import redis + + mock_redis = mock.MagicMock() + mock_pipe = mock.MagicMock() + mock_pipe.execute.side_effect = redis.RedisError("connection lost") + mock_redis.pipeline.return_value = mock_pipe + mock_redis.ping.return_value = True + mock_get_redis.return_value = mock_redis + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + resp = client.get("/api/listing") + assert resp.status_code == 200 + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None: + """When Redis is unavailable at startup, requests pass through.""" + import redis + + mock_get_redis.side_effect = redis.RedisError("connection refused") + + config = _make_config() + app = _build_app(config) + client = TestClient(app) + + resp = client.get("/api/listing") + assert resp.status_code == 200 + + @mock.patch("api.rate_limiter._get_rate_limit_redis") + def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None: + """Endpoints not in the config are not rate limited.""" + mock_redis = mock.MagicMock() + mock_redis.ping.return_value = True + mock_get_redis.return_value = mock_redis + + config = _make_config() + app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)]) + app.add_middleware(RateLimitMiddleware, config=config) + client = TestClient(app) + + resp = client.get("/api/unknown") + assert resp.status_code == 200 + assert "X-RateLimit-Limit" not in resp.headers + + +class TestRateLimitConfig: + """Tests for RateLimitConfig.from_env.""" + + def test_defaults(self) -> None: + with mock.patch.dict("os.environ", {}, clear=True): + config = RateLimitConfig.from_env() + assert config.listing_limit_cap == 100 + assert config.geojson_limit_cap == 5000 + assert config.rate_limit_redis_db == 3 + + def test_custom_env_vars(self) -> None: + env = { + "RATE_LIMIT_LISTING": "50/120", + "EXPORT_LISTING_LIMIT_CAP": "200", + "RATE_LIMIT_REDIS_DB": "5", + "METRICS_ALLOWED_IPS": "10.0.0.1", + } + with mock.patch.dict("os.environ", env, clear=True): + config = RateLimitConfig.from_env() + assert config.endpoint_limits["/api/listing"].max_requests == 50 + assert config.endpoint_limits["/api/listing"].window_seconds == 120 + assert config.listing_limit_cap == 200 + assert config.rate_limit_redis_db == 5 + assert config.metrics_allowed_ips == "10.0.0.1" + + def test_invalid_limit_format_uses_defaults(self) -> None: + env = {"RATE_LIMIT_LISTING": "invalid"} + with mock.patch.dict("os.environ", env, clear=True): + config = RateLimitConfig.from_env() + # Should fall back to default + assert config.endpoint_limits["/api/listing"].max_requests == 30 + assert config.endpoint_limits["/api/listing"].window_seconds == 60