wrongmove/tests/unit/test_metrics_guard.py
Viktor Barzin 87b5bd8676
Add API rate limiting, metrics guard, and audit middleware
Per-user rate limits via Redis sliding window, IP-restricted /metrics
endpoint, audit logging of all requests, CORS tightening, and export
caps on listing/geojson endpoints.
2026-02-08 00:45:43 +00:00

111 lines
3.8 KiB
Python

"""Unit tests for api/metrics_guard.py."""
import ipaddress
import pytest
from starlette.testclient import TestClient
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
from api.rate_limit_config import RateLimitConfig
async def _ok_endpoint(request: Request) -> JSONResponse:
    """Downstream route used by the middleware tests.

    Ignores the incoming request and replies with the JSON body
    ``{"ok": true}``, so tests can tell that a request made it past the
    middleware to the actual endpoint.
    """
    body = {"ok": True}
    return JSONResponse(body)
class TestParseAllowedNetworks:
    """Tests for parse_allowed_networks."""

    def test_single_ip(self) -> None:
        """A bare address parses to one network that contains that address."""
        nets = parse_allowed_networks("127.0.0.1")
        assert len(nets) == 1
        assert ipaddress.ip_address("127.0.0.1") in nets[0]

    def test_cidr(self) -> None:
        """A CIDR entry covers its whole range, up to the last address."""
        nets = parse_allowed_networks("10.0.0.0/8")
        assert len(nets) == 1
        assert ipaddress.ip_address("10.255.255.255") in nets[0]

    def test_multiple_entries(self) -> None:
        """Comma-separated entries (with spaces after the commas) all parse.

        Checking only the count would let a parser that mangles individual
        entries slip through, so also verify that each entry's address is
        covered by some returned network and that an unrelated address is
        not. The membership check is order-agnostic on purpose: the test
        should not depend on the order in which networks are returned.
        """
        nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
        assert len(nets) == 3
        for addr in ("127.0.0.1", "10.1.2.3", "::1"):
            ip = ipaddress.ip_address(addr)
            assert any(ip in net for net in nets), addr
        outsider = ipaddress.ip_address("192.168.1.1")
        assert not any(outsider in net for net in nets)

    def test_empty_string(self) -> None:
        """An empty specification yields no networks at all."""
        nets = parse_allowed_networks("")
        assert nets == []

    def test_trailing_comma(self) -> None:
        """A trailing comma is ignored and the real entry still parses."""
        nets = parse_allowed_networks("127.0.0.1,")
        assert len(nets) == 1
        assert ipaddress.ip_address("127.0.0.1") in nets[0]
class TestIsIpAllowed:
    """Checks the allow/deny decision made by is_ip_allowed."""

    def test_allowed_ip(self) -> None:
        # An address inside the configured 10.0.0.0/8 range is accepted.
        private_range = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("10.1.2.3", private_range)
        assert verdict is True

    def test_denied_ip(self) -> None:
        # An address outside every configured range is rejected.
        private_range = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("192.168.1.1", private_range)
        assert verdict is False

    def test_ipv6(self) -> None:
        # IPv6 entries work too: the exact address matches, another does not.
        loopback_v6 = parse_allowed_networks("::1")
        assert is_ip_allowed("::1", loopback_v6) is True
        assert is_ip_allowed("::2", loopback_v6) is False

    def test_invalid_ip(self) -> None:
        # A string that is not an IP address is rejected (False), not raised on.
        private_range = parse_allowed_networks("10.0.0.0/8")
        verdict = is_ip_allowed("not-an-ip", private_range)
        assert verdict is False
class TestMetricsGuardMiddleware:
    """Integration tests for MetricsGuardMiddleware.

    Each test mounts the middleware on a minimal Starlette app and drives it
    with ``TestClient``. The client address is supplied through the
    ``X-Forwarded-For`` header, because TestClient's own peer host is the
    literal string ``"testclient"`` rather than an IP address.

    NOTE(review): these tests show that the guard decides access from
    ``X-Forwarded-For``. That header is client-controlled unless a trusted
    proxy overwrites it, so confirm the deployment strips/overwrites it
    before requests reach the app; otherwise /metrics can be unlocked by
    spoofing the header.
    """

    @staticmethod
    def _build_app_from_config(config: RateLimitConfig) -> Starlette:
        """Return an app with /metrics and /api/status behind the guard."""
        app = Starlette(routes=[
            Route("/metrics", _ok_endpoint),
            Route("/api/status", _ok_endpoint),
        ])
        app.add_middleware(MetricsGuardMiddleware, config=config)
        return app

    def _build_app(self, allowed_ips: str) -> Starlette:
        """Build a guarded app whose /metrics allow-list is ``allowed_ips``."""
        config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
        return self._build_app_from_config(config)

    def test_allows_metrics_from_allowed_ip(self) -> None:
        """A /metrics request from an allow-listed address reaches the endpoint."""
        client = TestClient(self._build_app("10.0.0.1"))
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200
        assert resp.json() == {"ok": True}

    def test_blocks_metrics_from_disallowed_ip(self) -> None:
        """A /metrics request from outside the allow-list is refused with 403."""
        client = TestClient(self._build_app("10.0.0.0/8"))
        resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
        assert resp.status_code == 403

    def test_non_metrics_path_passes_through(self) -> None:
        """Paths other than /metrics are not subject to the IP allow-list."""
        client = TestClient(self._build_app("10.0.0.0/8"))
        # No forwarded header: the peer is the non-IP "testclient" host.
        resp = client.get("/api/status")
        assert resp.status_code == 200
        # Even an address that /metrics would reject is let through here.
        resp = client.get(
            "/api/status", headers={"X-Forwarded-For": "192.168.1.1"}
        )
        assert resp.status_code == 200
        assert resp.json() == {"ok": True}

    def test_default_private_ranges(self) -> None:
        """With the default config, a private IP is allowed and a public one denied."""
        client = TestClient(self._build_app_from_config(RateLimitConfig()))
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200
        resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
        assert resp.status_code == 403