Per-user rate limits via a Redis sliding window, an IP-restricted /metrics endpoint, audit logging of all requests, tightened CORS, and export caps on the listing/geojson endpoints.
111 lines
3.8 KiB
Python
111 lines
3.8 KiB
Python
"""Unit tests for api/metrics_guard.py."""
|
|
import ipaddress
|
|
|
|
import pytest
|
|
from starlette.testclient import TestClient
|
|
from starlette.applications import Starlette
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Route
|
|
|
|
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
|
|
from api.rate_limit_config import RateLimitConfig
|
|
|
|
|
|
async def _ok_endpoint(request: Request) -> JSONResponse:
    """Minimal downstream handler used behind the middleware under test.

    The incoming request is ignored; every call answers with the JSON
    body ``{"ok": true}`` using JSONResponse's default status code.
    """
    body = {"ok": True}
    return JSONResponse(body)
|
|
|
|
|
|
class TestParseAllowedNetworks:
    """Tests for parse_allowed_networks."""

    def test_single_ip(self) -> None:
        nets = parse_allowed_networks("127.0.0.1")
        assert len(nets) == 1
        assert ipaddress.ip_address("127.0.0.1") in nets[0]

    def test_cidr(self) -> None:
        nets = parse_allowed_networks("10.0.0.0/8")
        assert len(nets) == 1
        assert ipaddress.ip_address("10.255.255.255") in nets[0]

    def test_multiple_entries(self) -> None:
        nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
        assert len(nets) == 3
        # A count alone does not prove each comma-separated entry was parsed
        # into the intended network, so also check that every listed
        # address/range is covered by some returned network. Membership
        # tests on ipaddress networks return False (rather than raising)
        # for a mismatched IP version, so mixing v4 and v6 here is safe.
        for addr in ("127.0.0.1", "10.1.2.3", "::1"):
            ip = ipaddress.ip_address(addr)
            assert any(ip in net for net in nets), addr

    def test_empty_string(self) -> None:
        nets = parse_allowed_networks("")
        assert nets == []

    def test_trailing_comma(self) -> None:
        # The empty element after the trailing comma must be ignored, and
        # the real entry must still be parsed correctly.
        nets = parse_allowed_networks("127.0.0.1,")
        assert len(nets) == 1
        assert ipaddress.ip_address("127.0.0.1") in nets[0]
|
|
|
|
|
|
class TestIsIpAllowed:
    """Checks is_ip_allowed against allow-lists built by parse_allowed_networks."""

    def test_allowed_ip(self) -> None:
        ten_slash_eight = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("10.1.2.3", ten_slash_eight)
        assert verdict is True

    def test_denied_ip(self) -> None:
        ten_slash_eight = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("192.168.1.1", ten_slash_eight)
        assert verdict is False

    def test_ipv6(self) -> None:
        loopback_only = parse_allowed_networks("::1")
        # Exactly the listed IPv6 host is accepted; its neighbour is not.
        assert is_ip_allowed("::1", loopback_only) is True
        assert is_ip_allowed("::2", loopback_only) is False

    def test_invalid_ip(self) -> None:
        # An unparseable client address is rejected with False, not an error.
        ten_slash_eight = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("not-an-ip", ten_slash_eight)
        assert verdict is False
|
|
|
|
|
|
class TestMetricsGuardMiddleware:
    """Integration tests for MetricsGuardMiddleware.

    TestClient reports its peer host as the string ``"testclient"``, which
    is not an IP address. The tests therefore supply the client address via
    the ``X-Forwarded-For`` header; they rely on the middleware honouring
    that header when deciding whether /metrics is allowed.
    """

    @staticmethod
    def _app_with_config(config: RateLimitConfig) -> Starlette:
        """Build a small app exposing /metrics and /api/status behind the guard."""
        app = Starlette(routes=[
            Route("/metrics", _ok_endpoint),
            Route("/api/status", _ok_endpoint),
        ])
        app.add_middleware(MetricsGuardMiddleware, config=config)
        return app

    def _build_app(self, allowed_ips: str) -> Starlette:
        """Build the test app with an explicit ``metrics_allowed_ips`` value."""
        config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
        return self._app_with_config(config)

    def test_allows_metrics_from_allowed_ip(self) -> None:
        client = TestClient(self._build_app("10.0.0.1"))
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200

    def test_blocks_metrics_from_disallowed_ip(self) -> None:
        client = TestClient(self._build_app("10.0.0.0/8"))
        resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
        assert resp.status_code == 403

    def test_non_metrics_path_passes_through(self) -> None:
        # No X-Forwarded-For header: the request comes from TestClient's
        # non-IP peer host, so a 200 here shows the guard leaves paths other
        # than /metrics alone.
        client = TestClient(self._build_app("10.0.0.0/8"))
        resp = client.get("/api/status")
        assert resp.status_code == 200

    def test_default_private_ranges(self) -> None:
        # RateLimitConfig() with no arguments uses its default allow-list.
        client = TestClient(self._app_with_config(RateLimitConfig()))

        # Private IP should be allowed
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200

        # Public IP should be denied
        resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
        assert resp.status_code == 403
|