Add API rate limiting, metrics guard, and audit middleware

Per-user rate limits via Redis sliding window, IP-restricted /metrics
endpoint, audit logging of all requests, CORS tightening, and export
caps on listing/geojson endpoints.
This commit is contained in:
Viktor Barzin 2026-02-08 00:45:43 +00:00
parent 08ac72bbfc
commit 87b5bd8676
No known key found for this signature in database
GPG key ID: 0EB088298288D958
8 changed files with 756 additions and 2 deletions

View file

@ -0,0 +1,111 @@
"""Unit tests for api/metrics_guard.py."""
import ipaddress
import pytest
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
from api.rate_limit_config import RateLimitConfig
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
class TestParseAllowedNetworks:
"""Tests for parse_allowed_networks."""
def test_single_ip(self) -> None:
nets = parse_allowed_networks("127.0.0.1")
assert len(nets) == 1
assert ipaddress.ip_address("127.0.0.1") in nets[0]
def test_cidr(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert len(nets) == 1
assert ipaddress.ip_address("10.255.255.255") in nets[0]
def test_multiple_entries(self) -> None:
nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
assert len(nets) == 3
def test_empty_string(self) -> None:
nets = parse_allowed_networks("")
assert nets == []
def test_trailing_comma(self) -> None:
nets = parse_allowed_networks("127.0.0.1,")
assert len(nets) == 1
class TestIsIpAllowed:
"""Tests for is_ip_allowed."""
def test_allowed_ip(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert is_ip_allowed("10.1.2.3", nets) is True
def test_denied_ip(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert is_ip_allowed("192.168.1.1", nets) is False
def test_ipv6(self) -> None:
nets = parse_allowed_networks("::1")
assert is_ip_allowed("::1", nets) is True
assert is_ip_allowed("::2", nets) is False
def test_invalid_ip(self) -> None:
nets = parse_allowed_networks("10.0.0.0/8")
assert is_ip_allowed("not-an-ip", nets) is False
class TestMetricsGuardMiddleware:
"""Integration tests for MetricsGuardMiddleware."""
def _build_app(self, allowed_ips: str) -> Starlette:
config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
app = Starlette(routes=[
Route("/metrics", _ok_endpoint),
Route("/api/status", _ok_endpoint),
])
app.add_middleware(MetricsGuardMiddleware, config=config)
return app
def test_allows_metrics_from_allowed_ip(self) -> None:
app = self._build_app("127.0.0.1,testclient")
# TestClient connects from 'testclient' by default
# We need to override; use the header approach
app2 = self._build_app("10.0.0.1")
client = TestClient(app2)
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
assert resp.status_code == 200
def test_blocks_metrics_from_disallowed_ip(self) -> None:
app = self._build_app("10.0.0.0/8")
client = TestClient(app)
resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
assert resp.status_code == 403
def test_non_metrics_path_passes_through(self) -> None:
app = self._build_app("10.0.0.0/8")
client = TestClient(app)
resp = client.get("/api/status")
assert resp.status_code == 200
def test_default_private_ranges(self) -> None:
config = RateLimitConfig()
app = Starlette(routes=[Route("/metrics", _ok_endpoint)])
app.add_middleware(MetricsGuardMiddleware, config=config)
client = TestClient(app)
# Private IP should be allowed
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
assert resp.status_code == 200
# Public IP should be denied
resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
assert resp.status_code == 403

View file

@ -0,0 +1,238 @@
"""Unit tests for api/rate_limiter.py."""
from unittest import mock
import pytest
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.rate_limit_config import EndpointLimit, RateLimitConfig
from api.rate_limiter import (
RateLimitMiddleware,
_extract_user_email,
_match_endpoint,
EXEMPT_PATHS,
)
def _make_config(**overrides: object) -> RateLimitConfig:
"""Create a RateLimitConfig with defaults for testing."""
defaults: dict[str, object] = {
"endpoint_limits": {
"/api/listing": EndpointLimit(3, 60),
"/api/passkey": EndpointLimit(2, 60),
},
"listing_limit_cap": 100,
"geojson_limit_cap": 5000,
"geojson_stream_limit_cap": 10000,
"geojson_stream_batch_size_cap": 200,
"rate_limit_redis_db": 3,
"metrics_allowed_ips": "127.0.0.1",
}
defaults.update(overrides)
return RateLimitConfig(**defaults) # type: ignore[arg-type]
async def _ok_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
def _build_app(config: RateLimitConfig) -> Starlette:
"""Build a minimal Starlette app with the rate limiter."""
app = Starlette(routes=[
Route("/api/listing", _ok_endpoint),
Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]),
Route("/api/status", _ok_endpoint),
])
app.add_middleware(RateLimitMiddleware, config=config)
return app
class TestExtractUserEmail:
"""Tests for _extract_user_email."""
def test_no_auth_header(self) -> None:
scope = {"type": "http", "headers": []}
request = Request(scope)
assert _extract_user_email(request) is None
def test_invalid_token(self) -> None:
scope = {
"type": "http",
"headers": [(b"authorization", b"Bearer not-a-jwt")],
}
request = Request(scope)
assert _extract_user_email(request) is None
def test_valid_jwt(self) -> None:
import jwt as pyjwt
token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256")
scope = {
"type": "http",
"headers": [(b"authorization", f"Bearer {token}".encode())],
}
request = Request(scope)
assert _extract_user_email(request) == "test@example.com"
class TestMatchEndpoint:
"""Tests for _match_endpoint."""
def test_exact_match(self) -> None:
config = _make_config()
limit = _match_endpoint("/api/listing", config)
assert limit is not None
assert limit.max_requests == 3
def test_passkey_prefix_match(self) -> None:
config = _make_config()
limit = _match_endpoint("/api/passkey/login/begin", config)
assert limit is not None
assert limit.max_requests == 2
def test_no_match(self) -> None:
config = _make_config()
assert _match_endpoint("/api/unknown", config) is None
class TestRateLimitMiddleware:
"""Integration tests for RateLimitMiddleware."""
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None:
mock_redis = mock.MagicMock()
mock_pipe = mock.MagicMock()
mock_pipe.execute.return_value = [1, -1] # first request, no TTL yet
mock_redis.pipeline.return_value = mock_pipe
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 200
assert "X-RateLimit-Limit" in resp.headers
assert resp.headers["X-RateLimit-Limit"] == "3"
assert resp.headers["X-RateLimit-Remaining"] == "2"
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None:
mock_redis = mock.MagicMock()
mock_pipe = mock.MagicMock()
# 4th request in window (limit=3), TTL=45s remaining
mock_pipe.execute.return_value = [4, 45]
mock_redis.pipeline.return_value = mock_pipe
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 429
assert resp.headers["Retry-After"] == "45"
assert resp.headers["X-RateLimit-Remaining"] == "0"
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None:
mock_redis = mock.MagicMock()
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/status")
assert resp.status_code == 200
assert "X-RateLimit-Limit" not in resp.headers
# Redis pipeline should never be called for exempt paths
mock_redis.pipeline.assert_not_called()
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None:
"""When Redis raises an error, requests should be allowed through."""
import redis
mock_redis = mock.MagicMock()
mock_pipe = mock.MagicMock()
mock_pipe.execute.side_effect = redis.RedisError("connection lost")
mock_redis.pipeline.return_value = mock_pipe
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 200
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None:
"""When Redis is unavailable at startup, requests pass through."""
import redis
mock_get_redis.side_effect = redis.RedisError("connection refused")
config = _make_config()
app = _build_app(config)
client = TestClient(app)
resp = client.get("/api/listing")
assert resp.status_code == 200
@mock.patch("api.rate_limiter._get_rate_limit_redis")
def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None:
"""Endpoints not in the config are not rate limited."""
mock_redis = mock.MagicMock()
mock_redis.ping.return_value = True
mock_get_redis.return_value = mock_redis
config = _make_config()
app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)])
app.add_middleware(RateLimitMiddleware, config=config)
client = TestClient(app)
resp = client.get("/api/unknown")
assert resp.status_code == 200
assert "X-RateLimit-Limit" not in resp.headers
class TestRateLimitConfig:
"""Tests for RateLimitConfig.from_env."""
def test_defaults(self) -> None:
with mock.patch.dict("os.environ", {}, clear=True):
config = RateLimitConfig.from_env()
assert config.listing_limit_cap == 100
assert config.geojson_limit_cap == 5000
assert config.rate_limit_redis_db == 3
def test_custom_env_vars(self) -> None:
env = {
"RATE_LIMIT_LISTING": "50/120",
"EXPORT_LISTING_LIMIT_CAP": "200",
"RATE_LIMIT_REDIS_DB": "5",
"METRICS_ALLOWED_IPS": "10.0.0.1",
}
with mock.patch.dict("os.environ", env, clear=True):
config = RateLimitConfig.from_env()
assert config.endpoint_limits["/api/listing"].max_requests == 50
assert config.endpoint_limits["/api/listing"].window_seconds == 120
assert config.listing_limit_cap == 200
assert config.rate_limit_redis_db == 5
assert config.metrics_allowed_ips == "10.0.0.1"
def test_invalid_limit_format_uses_defaults(self) -> None:
env = {"RATE_LIMIT_LISTING": "invalid"}
with mock.patch.dict("os.environ", env, clear=True):
config = RateLimitConfig.from_env()
# Should fall back to default
assert config.endpoint_limits["/api/listing"].max_requests == 30
assert config.endpoint_limits["/api/listing"].window_seconds == 60