wrongmove/tests/unit/test_rate_limiter.py
Viktor Barzin 492921424e
Add security regression tests for all hardening fixes
- New: test_security_headers.py — verify all headers present, HSTS conditional on HTTPS
- New: test_passkey_error_handling.py — generic vs user-facing error messages
- New: test_poi_validation.py — field length and coordinate range constraints
- Extend test_rate_limiter.py — client IP depth selection, in-memory fallback enforcement
- Extend test_models.py — sqm range validation
- Extend test_task_service.py — IDOR 404, ownership 200, traceback suppression in production
2026-02-08 19:42:53 +00:00

328 lines
12 KiB
Python

"""Unit tests for api/rate_limiter.py."""
from unittest import mock

import pytest
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient

from api.rate_limit_config import EndpointLimit, RateLimitConfig
from api.rate_limiter import (
    EXEMPT_PATHS,
    RateLimitMiddleware,
    _client_ip,
    _extract_user_email,
    _match_endpoint,
)
def _make_config(**overrides: object) -> RateLimitConfig:
    """Build a RateLimitConfig for tests from small, fixed defaults.

    Each keyword argument replaces the default field of the same name
    wholesale (e.g. ``endpoint_limits=...`` swaps the whole mapping).
    """
    settings: dict[str, object] = dict(
        endpoint_limits={
            "/api/listing": EndpointLimit(3, 60),
            "/api/passkey": EndpointLimit(2, 60),
        },
        listing_limit_cap=100,
        geojson_limit_cap=5000,
        geojson_stream_limit_cap=10000,
        geojson_stream_batch_size_cap=200,
        rate_limit_redis_db=3,
        metrics_allowed_ips="127.0.0.1",
        trusted_proxy_depth=1,
    )
    settings.update(overrides)
    return RateLimitConfig(**settings)  # type: ignore[arg-type]
async def _ok_endpoint(request: Request) -> JSONResponse:
    """Trivial handler: ignore the request and answer with ``{"ok": true}``."""
    payload = {"ok": True}
    return JSONResponse(payload)
def _build_app(config: RateLimitConfig) -> Starlette:
    """Return a minimal Starlette app wrapped in RateLimitMiddleware.

    Routes: ``/api/listing`` and ``/api/status`` (default methods) plus
    ``/api/passkey/login/begin`` (POST), all served by ``_ok_endpoint``.
    """
    routes = [
        Route("/api/listing", _ok_endpoint),
        Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]),
        Route("/api/status", _ok_endpoint),
    ]
    application = Starlette(routes=routes)
    application.add_middleware(RateLimitMiddleware, config=config)
    return application
class TestExtractUserEmail:
    """Tests for _extract_user_email."""

    def test_no_auth_header(self) -> None:
        req = Request({"type": "http", "headers": []})
        assert _extract_user_email(req) is None

    def test_invalid_token(self) -> None:
        headers = [(b"authorization", b"Bearer not-a-jwt")]
        req = Request({"type": "http", "headers": headers})
        assert _extract_user_email(req) is None

    def test_valid_jwt(self) -> None:
        import jwt as pyjwt

        # Signed with an arbitrary key; the test only checks that the email claim is read.
        token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256")
        headers = [(b"authorization", f"Bearer {token}".encode())]
        req = Request({"type": "http", "headers": headers})
        assert _extract_user_email(req) == "test@example.com"
class TestMatchEndpoint:
    """Tests for _match_endpoint."""

    def test_exact_match(self) -> None:
        matched = _match_endpoint("/api/listing", _make_config())
        assert matched is not None
        assert matched.max_requests == 3

    def test_passkey_prefix_match(self) -> None:
        # "/api/passkey" is configured; a deeper sub-path must resolve to it.
        matched = _match_endpoint("/api/passkey/login/begin", _make_config())
        assert matched is not None
        assert matched.max_requests == 2

    def test_no_match(self) -> None:
        matched = _match_endpoint("/api/unknown", _make_config())
        assert matched is None
class TestRateLimitMiddleware:
    """Integration tests for RateLimitMiddleware."""

    @staticmethod
    def _install_fake_redis(
        mock_get_redis: mock.MagicMock,
        *,
        execute_result: object = None,
        execute_error: object = None,
    ) -> mock.MagicMock:
        """Make the patched ``_get_rate_limit_redis`` return a fake client.

        The fake answers ``ping()`` with True and hands out a single pipeline
        mock whose ``execute()`` returns *execute_result* or, when
        *execute_error* is given, raises it. The fake client is returned so
        tests can inspect how it was used.
        """
        fake_redis = mock.MagicMock()
        fake_pipe = mock.MagicMock()
        if execute_error is not None:
            fake_pipe.execute.side_effect = execute_error
        elif execute_result is not None:
            fake_pipe.execute.return_value = execute_result
        fake_redis.pipeline.return_value = fake_pipe
        fake_redis.ping.return_value = True
        mock_get_redis.return_value = fake_redis
        return fake_redis

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None:
        # Pipeline result is [hit count, TTL]: first request, no TTL yet.
        self._install_fake_redis(mock_get_redis, execute_result=[1, -1])
        client = TestClient(_build_app(_make_config()))
        resp = client.get("/api/listing")
        assert resp.status_code == 200
        assert "X-RateLimit-Limit" in resp.headers
        assert resp.headers["X-RateLimit-Limit"] == "3"
        assert resp.headers["X-RateLimit-Remaining"] == "2"

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None:
        # 4th request in window (limit=3), TTL=45s remaining
        self._install_fake_redis(mock_get_redis, execute_result=[4, 45])
        client = TestClient(_build_app(_make_config()))
        resp = client.get("/api/listing")
        assert resp.status_code == 429
        assert resp.headers["Retry-After"] == "45"
        assert resp.headers["X-RateLimit-Remaining"] == "0"

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None:
        """Exempt paths bypass limiting even when a limit is configured for them."""
        # Guard the premise: the path under test must actually be exempt.
        assert "/api/status" in EXEMPT_PATHS
        fake_redis = self._install_fake_redis(mock_get_redis)
        # Give /api/status a (tight) limit of its own. Without it the path would
        # be skipped merely for being unmatched, and this test would never
        # exercise the EXEMPT_PATHS check.
        config = _make_config(
            endpoint_limits={
                "/api/listing": EndpointLimit(3, 60),
                "/api/passkey": EndpointLimit(2, 60),
                "/api/status": EndpointLimit(1, 60),
            },
        )
        client = TestClient(_build_app(config))
        # Two requests against a limit of 1: both must pass if the path is exempt.
        for _ in range(2):
            resp = client.get("/api/status")
            assert resp.status_code == 200
            assert "X-RateLimit-Limit" not in resp.headers
        # Redis pipeline should never be called for exempt paths
        fake_redis.pipeline.assert_not_called()

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None:
        """A Redis error mid-request must not reject a request that is under the limit."""
        import redis

        self._install_fake_redis(
            mock_get_redis, execute_error=redis.RedisError("connection lost")
        )
        client = TestClient(_build_app(_make_config()))
        resp = client.get("/api/listing")
        assert resp.status_code == 200

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None:
        """When Redis is unavailable at startup, a request under the limit still passes.

        Limits are not dropped in this case; enforcement by the in-memory
        fallback is covered by TestInMemoryFallback.
        """
        import redis

        mock_get_redis.side_effect = redis.RedisError("connection refused")
        client = TestClient(_build_app(_make_config()))
        resp = client.get("/api/listing")
        assert resp.status_code == 200

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None:
        """Endpoints not in the config are not rate limited."""
        fake_redis = self._install_fake_redis(mock_get_redis)
        app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)])
        app.add_middleware(RateLimitMiddleware, config=_make_config())
        client = TestClient(app)
        resp = client.get("/api/unknown")
        assert resp.status_code == 200
        assert "X-RateLimit-Limit" not in resp.headers
        # No limit applies, so no counter should be touched in Redis either.
        fake_redis.pipeline.assert_not_called()
class TestRateLimitConfig:
    """Tests for RateLimitConfig.from_env."""

    @staticmethod
    def _from_env(env: dict) -> RateLimitConfig:
        """Run ``RateLimitConfig.from_env`` with ``os.environ`` replaced by *env*."""
        with mock.patch.dict("os.environ", env, clear=True):
            return RateLimitConfig.from_env()

    def test_defaults(self) -> None:
        cfg = self._from_env({})
        assert cfg.listing_limit_cap == 100
        assert cfg.geojson_limit_cap == 5000
        assert cfg.rate_limit_redis_db == 3

    def test_custom_env_vars(self) -> None:
        cfg = self._from_env(
            {
                "RATE_LIMIT_LISTING": "50/120",
                "EXPORT_LISTING_LIMIT_CAP": "200",
                "RATE_LIMIT_REDIS_DB": "5",
                "METRICS_ALLOWED_IPS": "10.0.0.1",
            }
        )
        listing = cfg.endpoint_limits["/api/listing"]
        assert listing.max_requests == 50
        assert listing.window_seconds == 120
        assert cfg.listing_limit_cap == 200
        assert cfg.rate_limit_redis_db == 5
        assert cfg.metrics_allowed_ips == "10.0.0.1"

    def test_invalid_limit_format_uses_defaults(self) -> None:
        cfg = self._from_env({"RATE_LIMIT_LISTING": "invalid"})
        # An unparseable value must fall back to the built-in default (30/60).
        listing = cfg.endpoint_limits["/api/listing"]
        assert listing.max_requests == 30
        assert listing.window_seconds == 60
class TestClientIp:
    """Tests for _client_ip with trusted proxy depth.

    ``depth`` selects which X-Forwarded-For entry, counted from the right, is
    taken as the client address; entries further left (``spoofed.ip`` below)
    are not used.
    """

    @staticmethod
    def _forwarded_request(forwarded_for: bytes) -> Request:
        """Build a request whose only header is the given X-Forwarded-For value."""
        return Request(
            {"type": "http", "headers": [(b"x-forwarded-for", forwarded_for)]}
        )

    def test_uses_rightmost_ip_with_depth_1(self) -> None:
        request = self._forwarded_request(b"spoofed.ip, proxy1, real.ip")
        assert _client_ip(request, depth=1) == "real.ip"

    def test_uses_second_from_right_with_depth_2(self) -> None:
        request = self._forwarded_request(b"spoofed.ip, real.ip, proxy1")
        assert _client_ip(request, depth=2) == "real.ip"

    def test_falls_back_to_connection_ip(self) -> None:
        scope = {
            "type": "http",
            "headers": [],
            "client": ("192.168.1.1", 12345),
        }
        assert _client_ip(Request(scope)) == "192.168.1.1"

    def test_no_client_returns_unknown(self) -> None:
        request = Request({"type": "http", "headers": []})
        assert _client_ip(request) == "unknown"

    def test_single_ip_with_depth_1(self) -> None:
        request = self._forwarded_request(b"single.ip")
        assert _client_ip(request, depth=1) == "single.ip"
class TestInMemoryFallback:
    """Tests for in-memory rate limit fallback when Redis is unavailable."""

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_fallback_enforces_limits_when_redis_unavailable(self, mock_get_redis: mock.MagicMock) -> None:
        """When Redis is unavailable at startup, in-memory fallback should enforce limits."""
        import redis as redis_lib

        mock_get_redis.side_effect = redis_lib.RedisError("connection refused")
        client = TestClient(_build_app(_make_config()))
        # Should allow first 3 requests (limit=3)
        for i in range(3):
            resp = client.get("/api/listing")
            assert resp.status_code == 200, f"Request {i+1} should be allowed"
        # 4th request should be rate limited
        resp = client.get("/api/listing")
        assert resp.status_code == 429

    @mock.patch("api.rate_limiter._get_rate_limit_redis")
    def test_fallback_activates_on_redis_error_during_request(self, mock_get_redis: mock.MagicMock) -> None:
        """When Redis errors during request handling, fallback should activate.

        Activation alone is not enough: the fallback must also enforce the
        configured limit (3 for /api/listing) rather than fail open.
        """
        import redis as redis_lib

        mock_redis = mock.MagicMock()
        mock_pipe = mock.MagicMock()
        mock_pipe.execute.side_effect = redis_lib.RedisError("connection lost")
        mock_redis.pipeline.return_value = mock_pipe
        mock_redis.ping.return_value = True
        mock_get_redis.return_value = mock_redis
        client = TestClient(_build_app(_make_config()))
        # First request should succeed via fallback
        resp = client.get("/api/listing")
        assert resp.status_code == 200
        assert "X-RateLimit-Limit" in resp.headers
        # Requests 2 and 3 still fit within the limit of 3.
        for n in (2, 3):
            resp = client.get("/api/listing")
            assert resp.status_code == 200, f"Request {n} should be allowed"
        # The 4th request proves the fallback counts hits instead of passing everything.
        resp = client.get("/api/listing")
        assert resp.status_code == 429