Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis sliding window, IP-restricted /metrics endpoint, audit logging of all requests, CORS tightening, and export caps on listing/geojson endpoints.
This commit is contained in:
parent
08ac72bbfc
commit
87b5bd8676
8 changed files with 756 additions and 2 deletions
111
tests/unit/test_metrics_guard.py
Normal file
111
tests/unit/test_metrics_guard.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Unit tests for api/metrics_guard.py."""
|
||||
import ipaddress
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
|
||||
from api.rate_limit_config import RateLimitConfig
|
||||
|
||||
|
||||
async def _ok_endpoint(request: Request) -> JSONResponse:
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
class TestParseAllowedNetworks:
|
||||
"""Tests for parse_allowed_networks."""
|
||||
|
||||
def test_single_ip(self) -> None:
|
||||
nets = parse_allowed_networks("127.0.0.1")
|
||||
assert len(nets) == 1
|
||||
assert ipaddress.ip_address("127.0.0.1") in nets[0]
|
||||
|
||||
def test_cidr(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert len(nets) == 1
|
||||
assert ipaddress.ip_address("10.255.255.255") in nets[0]
|
||||
|
||||
def test_multiple_entries(self) -> None:
|
||||
nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
|
||||
assert len(nets) == 3
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
nets = parse_allowed_networks("")
|
||||
assert nets == []
|
||||
|
||||
def test_trailing_comma(self) -> None:
|
||||
nets = parse_allowed_networks("127.0.0.1,")
|
||||
assert len(nets) == 1
|
||||
|
||||
|
||||
class TestIsIpAllowed:
|
||||
"""Tests for is_ip_allowed."""
|
||||
|
||||
def test_allowed_ip(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert is_ip_allowed("10.1.2.3", nets) is True
|
||||
|
||||
def test_denied_ip(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert is_ip_allowed("192.168.1.1", nets) is False
|
||||
|
||||
def test_ipv6(self) -> None:
|
||||
nets = parse_allowed_networks("::1")
|
||||
assert is_ip_allowed("::1", nets) is True
|
||||
assert is_ip_allowed("::2", nets) is False
|
||||
|
||||
def test_invalid_ip(self) -> None:
|
||||
nets = parse_allowed_networks("10.0.0.0/8")
|
||||
assert is_ip_allowed("not-an-ip", nets) is False
|
||||
|
||||
|
||||
class TestMetricsGuardMiddleware:
|
||||
"""Integration tests for MetricsGuardMiddleware."""
|
||||
|
||||
def _build_app(self, allowed_ips: str) -> Starlette:
|
||||
config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
|
||||
app = Starlette(routes=[
|
||||
Route("/metrics", _ok_endpoint),
|
||||
Route("/api/status", _ok_endpoint),
|
||||
])
|
||||
app.add_middleware(MetricsGuardMiddleware, config=config)
|
||||
return app
|
||||
|
||||
def test_allows_metrics_from_allowed_ip(self) -> None:
|
||||
app = self._build_app("127.0.0.1,testclient")
|
||||
# TestClient connects from 'testclient' by default
|
||||
# We need to override; use the header approach
|
||||
app2 = self._build_app("10.0.0.1")
|
||||
client = TestClient(app2)
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_blocks_metrics_from_disallowed_ip(self) -> None:
|
||||
app = self._build_app("10.0.0.0/8")
|
||||
client = TestClient(app)
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_metrics_path_passes_through(self) -> None:
|
||||
app = self._build_app("10.0.0.0/8")
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/status")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_default_private_ranges(self) -> None:
|
||||
config = RateLimitConfig()
|
||||
app = Starlette(routes=[Route("/metrics", _ok_endpoint)])
|
||||
app.add_middleware(MetricsGuardMiddleware, config=config)
|
||||
client = TestClient(app)
|
||||
|
||||
# Private IP should be allowed
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Public IP should be denied
|
||||
resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
|
||||
assert resp.status_code == 403
|
||||
238
tests/unit/test_rate_limiter.py
Normal file
238
tests/unit/test_rate_limiter.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Unit tests for api/rate_limiter.py."""
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||
from api.rate_limiter import (
|
||||
RateLimitMiddleware,
|
||||
_extract_user_email,
|
||||
_match_endpoint,
|
||||
EXEMPT_PATHS,
|
||||
)
|
||||
|
||||
|
||||
def _make_config(**overrides: object) -> RateLimitConfig:
|
||||
"""Create a RateLimitConfig with defaults for testing."""
|
||||
defaults: dict[str, object] = {
|
||||
"endpoint_limits": {
|
||||
"/api/listing": EndpointLimit(3, 60),
|
||||
"/api/passkey": EndpointLimit(2, 60),
|
||||
},
|
||||
"listing_limit_cap": 100,
|
||||
"geojson_limit_cap": 5000,
|
||||
"geojson_stream_limit_cap": 10000,
|
||||
"geojson_stream_batch_size_cap": 200,
|
||||
"rate_limit_redis_db": 3,
|
||||
"metrics_allowed_ips": "127.0.0.1",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return RateLimitConfig(**defaults) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def _ok_endpoint(request: Request) -> JSONResponse:
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
|
||||
def _build_app(config: RateLimitConfig) -> Starlette:
|
||||
"""Build a minimal Starlette app with the rate limiter."""
|
||||
app = Starlette(routes=[
|
||||
Route("/api/listing", _ok_endpoint),
|
||||
Route("/api/passkey/login/begin", _ok_endpoint, methods=["POST"]),
|
||||
Route("/api/status", _ok_endpoint),
|
||||
])
|
||||
app.add_middleware(RateLimitMiddleware, config=config)
|
||||
return app
|
||||
|
||||
|
||||
class TestExtractUserEmail:
|
||||
"""Tests for _extract_user_email."""
|
||||
|
||||
def test_no_auth_header(self) -> None:
|
||||
scope = {"type": "http", "headers": []}
|
||||
request = Request(scope)
|
||||
assert _extract_user_email(request) is None
|
||||
|
||||
def test_invalid_token(self) -> None:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", b"Bearer not-a-jwt")],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _extract_user_email(request) is None
|
||||
|
||||
def test_valid_jwt(self) -> None:
|
||||
import jwt as pyjwt
|
||||
token = pyjwt.encode({"email": "test@example.com"}, "secret", algorithm="HS256")
|
||||
scope = {
|
||||
"type": "http",
|
||||
"headers": [(b"authorization", f"Bearer {token}".encode())],
|
||||
}
|
||||
request = Request(scope)
|
||||
assert _extract_user_email(request) == "test@example.com"
|
||||
|
||||
|
||||
class TestMatchEndpoint:
|
||||
"""Tests for _match_endpoint."""
|
||||
|
||||
def test_exact_match(self) -> None:
|
||||
config = _make_config()
|
||||
limit = _match_endpoint("/api/listing", config)
|
||||
assert limit is not None
|
||||
assert limit.max_requests == 3
|
||||
|
||||
def test_passkey_prefix_match(self) -> None:
|
||||
config = _make_config()
|
||||
limit = _match_endpoint("/api/passkey/login/begin", config)
|
||||
assert limit is not None
|
||||
assert limit.max_requests == 2
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
config = _make_config()
|
||||
assert _match_endpoint("/api/unknown", config) is None
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Integration tests for RateLimitMiddleware."""
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_allows_requests_under_limit(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
mock_pipe.execute.return_value = [1, -1] # first request, no TTL yet
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" in resp.headers
|
||||
assert resp.headers["X-RateLimit-Limit"] == "3"
|
||||
assert resp.headers["X-RateLimit-Remaining"] == "2"
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_returns_429_over_limit(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
# 4th request in window (limit=3), TTL=45s remaining
|
||||
mock_pipe.execute.return_value = [4, 45]
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 429
|
||||
assert resp.headers["Retry-After"] == "45"
|
||||
assert resp.headers["X-RateLimit-Remaining"] == "0"
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_exempt_paths_skip_rate_limiting(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/status")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" not in resp.headers
|
||||
# Redis pipeline should never be called for exempt paths
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_fails_open_on_redis_error(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""When Redis raises an error, requests should be allowed through."""
|
||||
import redis
|
||||
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_pipe = mock.MagicMock()
|
||||
mock_pipe.execute.side_effect = redis.RedisError("connection lost")
|
||||
mock_redis.pipeline.return_value = mock_pipe
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_fails_open_when_redis_unavailable_at_startup(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""When Redis is unavailable at startup, requests pass through."""
|
||||
import redis
|
||||
|
||||
mock_get_redis.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
config = _make_config()
|
||||
app = _build_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/listing")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@mock.patch("api.rate_limiter._get_rate_limit_redis")
|
||||
def test_unmatched_endpoint_skips_limiting(self, mock_get_redis: mock.MagicMock) -> None:
|
||||
"""Endpoints not in the config are not rate limited."""
|
||||
mock_redis = mock.MagicMock()
|
||||
mock_redis.ping.return_value = True
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
config = _make_config()
|
||||
app = Starlette(routes=[Route("/api/unknown", _ok_endpoint)])
|
||||
app.add_middleware(RateLimitMiddleware, config=config)
|
||||
client = TestClient(app)
|
||||
|
||||
resp = client.get("/api/unknown")
|
||||
assert resp.status_code == 200
|
||||
assert "X-RateLimit-Limit" not in resp.headers
|
||||
|
||||
|
||||
class TestRateLimitConfig:
|
||||
"""Tests for RateLimitConfig.from_env."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
with mock.patch.dict("os.environ", {}, clear=True):
|
||||
config = RateLimitConfig.from_env()
|
||||
assert config.listing_limit_cap == 100
|
||||
assert config.geojson_limit_cap == 5000
|
||||
assert config.rate_limit_redis_db == 3
|
||||
|
||||
def test_custom_env_vars(self) -> None:
|
||||
env = {
|
||||
"RATE_LIMIT_LISTING": "50/120",
|
||||
"EXPORT_LISTING_LIMIT_CAP": "200",
|
||||
"RATE_LIMIT_REDIS_DB": "5",
|
||||
"METRICS_ALLOWED_IPS": "10.0.0.1",
|
||||
}
|
||||
with mock.patch.dict("os.environ", env, clear=True):
|
||||
config = RateLimitConfig.from_env()
|
||||
assert config.endpoint_limits["/api/listing"].max_requests == 50
|
||||
assert config.endpoint_limits["/api/listing"].window_seconds == 120
|
||||
assert config.listing_limit_cap == 200
|
||||
assert config.rate_limit_redis_db == 5
|
||||
assert config.metrics_allowed_ips == "10.0.0.1"
|
||||
|
||||
def test_invalid_limit_format_uses_defaults(self) -> None:
|
||||
env = {"RATE_LIMIT_LISTING": "invalid"}
|
||||
with mock.patch.dict("os.environ", env, clear=True):
|
||||
config = RateLimitConfig.from_env()
|
||||
# Should fall back to default
|
||||
assert config.endpoint_limits["/api/listing"].max_requests == 30
|
||||
assert config.endpoint_limits["/api/listing"].window_seconds == 60
|
||||
Loading…
Add table
Add a link
Reference in a new issue