"""Unit tests for api/rate_limiter.py.""" from unittest import mock import pytest from starlette.testclient import TestClient from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route from api.rate_limit_config import EndpointLimit, RateLimitConfig from api.rate_limiter import ( RateLimitMiddleware, _extract_user_email, _match_endpoint, EXEMPT_PATHS, ) def _make_config(**overrides: object) -> RateLimitConfig: """Create a RateLimitConfig with defaults for testing.""" defaults: dict[str, object] = { "endpoint_limits": { "/api/listing": EndpointLimit(3, 60), "/api/passkey": EndpointLimit(2, 60), }, "listing_limit_cap": 100, "geojson_limit_cap": 5000, "geojson_stream_limit_cap": 10000, "geojson_stream_batch_size_cap": 200, "rate_limit_redis_db": 3, "metrics_allowed_ips": "127.0.0.1", } defaults.update(overrides) return RateLimitConfig(**defaults) # type: ignore[arg-type] async def _ok_endpoint(request: Request) -> JSONResponse: return JSONResponse({"ok": True}) def _build_app(config: RateLimitConfig) -> Starlette: """Build a minimal Starlette app with the rate limiter.""" app = Starlette(routes=[ Route("/api/listing", _ok_endpoint), Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]), Route("/api/status", _ok_endpoint), ]) app.add_middleware(RateLimitMiddleware, config=config) return app class TestExtractUserEmail: """Tests for _extract_user_email.""" def test_no_auth_header(self) -> None: scope = {"type": "http", "headers": []} request = Request(scope) assert _extract_user_email(request) is None def test_invalid_token(self) -> None: scope = { "type": "http", "headers": [(b"authorization", b"Bearer not-a-jwt")], } request = Request(scope) assert _extract_user_email(request) is None def test_valid_jwt(self) -> None: import jwt as pyjwt token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256") scope = { "type": "http", "headers": [(b"authorization", f"Bearer {token}".encode())], } request = Request(scope) assert _extract_user_email(request) == "test@example.com" class TestMatchEndpoint: """Tests for _match_endpoint.""" def test_exact_match(self) -> None: config = _make_config() limit = _match_endpoint("/api/listing", config) assert limit is not None assert limit.max_requests == 3 def test_passkey_prefix_match(self) -> None: config = _make_config() limit = _match_endpoint("/api/passkey/login/begin", config) assert limit is not None assert limit.max_requests == 2 def test_no_match(self) -> None: config = _make_config() assert _match_endpoint("/api/unknown", config) is None class TestRateLimitMiddleware: """Integration tests for RateLimitMiddleware.""" @mock.patch("api.rate_limiter._get_rate_limit_redis") def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None: mock_redis = mock.MagicMock() mock_pipe = mock.MagicMock() mock_pipe.execute.return_value = [1, -1] # first request, no TTL yet mock_redis.pipeline.return_value = mock_pipe mock_redis.ping.return_value = True mock_get_redis.return_value = mock_redis config = _make_config() app = _build_app(config) client = TestClient(app) resp = client.get("/api/listing") assert resp.status_code == 200 assert "X-RateLimit-Limit" in resp.headers assert resp.headers["X-RateLimit-Limit"] == "3" assert resp.headers["X-RateLimit-Remaining"] == "2" @mock.patch("api.rate_limiter._get_rate_limit_redis") def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None: mock_redis = mock.MagicMock() mock_pipe = mock.MagicMock() # 4th request in window (limit=3), TTL=45s remaining mock_pipe.execute.return_value = [4, 45] mock_redis.pipeline.return_value = mock_pipe mock_redis.ping.return_value = True mock_get_redis.return_value = mock_redis config = _make_config() app = _build_app(config) client = TestClient(app) resp = client.get("/api/listing") assert resp.status_code == 429 assert resp.headers["Retry-After"] == "45" assert resp.headers["X-RateLimit-Remaining"] == "0" @mock.patch("api.rate_limiter._get_rate_limit_redis") def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None: mock_redis = mock.MagicMock() mock_redis.ping.return_value = True mock_get_redis.return_value = mock_redis config = _make_config() app = _build_app(config) client = TestClient(app) resp = client.get("/api/status") assert resp.status_code == 200 assert "X-RateLimit-Limit" not in resp.headers # Redis pipeline should never be called for exempt paths mock_redis.pipeline.assert_not_called() @mock.patch("api.rate_limiter._get_rate_limit_redis") def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None: """When Redis raises an error, requests should be allowed through.""" import redis mock_redis = mock.MagicMock() mock_pipe = mock.MagicMock() mock_pipe.execute.side_effect = redis.RedisError("connection lost") mock_redis.pipeline.return_value = mock_pipe mock_redis.ping.return_value = True mock_get_redis.return_value = mock_redis config = _make_config() app = _build_app(config) client = TestClient(app) resp = client.get("/api/listing") assert resp.status_code == 200 @mock.patch("api.rate_limiter._get_rate_limit_redis") def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None: """When Redis is unavailable at startup, requests pass through.""" import redis mock_get_redis.side_effect = redis.RedisError("connection refused") config = _make_config() app = _build_app(config) client = TestClient(app) resp = client.get("/api/listing") assert resp.status_code == 200 @mock.patch("api.rate_limiter._get_rate_limit_redis") def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None: """Endpoints not in the config are not rate limited.""" mock_redis = mock.MagicMock() mock_redis.ping.return_value = True mock_get_redis.return_value = mock_redis config = _make_config() app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)]) app.add_middleware(RateLimitMiddleware, config=config) client = TestClient(app) resp = client.get("/api/unknown") assert resp.status_code == 200 assert "X-RateLimit-Limit" not in resp.headers class TestRateLimitConfig: """Tests for RateLimitConfig.from_env.""" def test_defaults(self) -> None: with mock.patch.dict("os.environ", {}, clear=True): config = RateLimitConfig.from_env() assert config.listing_limit_cap == 100 assert config.geojson_limit_cap == 5000 assert config.rate_limit_redis_db == 3 def test_custom_env_vars(self) -> None: env = { "RATE_LIMIT_LISTING": "50/120", "EXPORT_LISTING_LIMIT_CAP": "200", "RATE_LIMIT_REDIS_DB": "5", "METRICS_ALLOWED_IPS": "10.0.0.1", } with mock.patch.dict("os.environ", env, clear=True): config = RateLimitConfig.from_env() assert config.endpoint_limits["/api/listing"].max_requests == 50 assert config.endpoint_limits["/api/listing"].window_seconds == 120 assert config.listing_limit_cap == 200 assert config.rate_limit_redis_db == 5 assert config.metrics_allowed_ips == "10.0.0.1" def test_invalid_limit_format_uses_defaults(self) -> None: env = {"RATE_LIMIT_LISTING": "invalid"} with mock.patch.dict("os.environ", env, clear=True): config = RateLimitConfig.from_env() # Should fall back to default assert config.endpoint_limits["/api/listing"].max_requests == 30 assert config.endpoint_limits["/api/listing"].window_seconds == 60