"""Unit tests for api/metrics_guard.py."""

import ipaddress

import pytest
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient

from api.metrics_guard import MetricsGuardMiddleware, is_ip_allowed, parse_allowed_networks
from api.rate_limit_config import RateLimitConfig


async def _ok_endpoint(request: Request) -> JSONResponse:
    """Return ``{"ok": True}`` so tests can tell a request reached the route."""
    return JSONResponse({"ok": True})


class TestParseAllowedNetworks:
    """Tests for parse_allowed_networks."""

    def test_single_ip(self) -> None:
        nets = parse_allowed_networks("127.0.0.1")
        assert len(nets) == 1
        assert ipaddress.ip_address("127.0.0.1") in nets[0]

    def test_cidr(self) -> None:
        nets = parse_allowed_networks("10.0.0.0/8")
        assert len(nets) == 1
        assert ipaddress.ip_address("10.255.255.255") in nets[0]

    def test_multiple_entries(self) -> None:
        # Mixed IPv4 host, IPv4 CIDR and IPv6 host, with spaces after commas.
        nets = parse_allowed_networks("127.0.0.1, 10.0.0.0/8, ::1")
        assert len(nets) == 3

    def test_empty_string(self) -> None:
        nets = parse_allowed_networks("")
        assert nets == []

    def test_trailing_comma(self) -> None:
        # The empty entry after the trailing comma must be ignored, not rejected.
        nets = parse_allowed_networks("127.0.0.1,")
        assert len(nets) == 1


class TestIsIpAllowed:
    """Tests for is_ip_allowed."""

    def test_allowed_ip(self) -> None:
        nets = parse_allowed_networks("10.0.0.0/8")
        assert is_ip_allowed("10.1.2.3", nets) is True

    def test_denied_ip(self) -> None:
        nets = parse_allowed_networks("10.0.0.0/8")
        assert is_ip_allowed("192.168.1.1", nets) is False

    def test_ipv6(self) -> None:
        nets = parse_allowed_networks("::1")
        assert is_ip_allowed("::1", nets) is True
        assert is_ip_allowed("::2", nets) is False

    def test_invalid_ip(self) -> None:
        # Unparseable addresses must be denied rather than raise.
        nets = parse_allowed_networks("10.0.0.0/8")
        assert is_ip_allowed("not-an-ip", nets) is False


class TestMetricsGuardMiddleware:
    """Integration tests for MetricsGuardMiddleware.

    Starlette's TestClient reports its peer host as the string "testclient",
    which is not an IP address. These tests therefore supply the client
    address through the ``X-Forwarded-For`` header.

    NOTE(review): this means the middleware, as tested, honours
    X-Forwarded-For. That header is client-controlled unless a trusted proxy
    overwrites it. Confirm the deployment guarantees this, otherwise the
    allow-list can be bypassed.
    """

    def _build_app(self, allowed_ips: str) -> Starlette:
        """Build an app with /metrics and /api/status behind MetricsGuardMiddleware.

        ``allowed_ips`` is passed verbatim as ``RateLimitConfig.metrics_allowed_ips``.
        """
        config = RateLimitConfig(metrics_allowed_ips=allowed_ips)
        app = Starlette(
            routes=[
                Route("/metrics", _ok_endpoint),
                Route("/api/status", _ok_endpoint),
            ]
        )
        app.add_middleware(MetricsGuardMiddleware, config=config)
        return app

    def test_allows_metrics_from_allowed_ip(self) -> None:
        # The client address comes from X-Forwarded-For because TestClient's
        # own peer host ("testclient") is not a valid IP.
        app = self._build_app("10.0.0.1")
        client = TestClient(app)
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200

    def test_blocks_metrics_from_disallowed_ip(self) -> None:
        app = self._build_app("10.0.0.0/8")
        client = TestClient(app)
        resp = client.get("/metrics", headers={"X-Forwarded-For": "192.168.1.1"})
        assert resp.status_code == 403

    def test_non_metrics_path_passes_through(self) -> None:
        # No X-Forwarded-For header: non-metrics routes must not be guarded.
        app = self._build_app("10.0.0.0/8")
        client = TestClient(app)
        resp = client.get("/api/status")
        assert resp.status_code == 200

    def test_default_private_ranges(self) -> None:
        # Uses RateLimitConfig's default allow-list. The assertions below pin
        # that it admits 10.0.0.1 and rejects 8.8.8.8.
        config = RateLimitConfig()
        app = Starlette(routes=[Route("/metrics", _ok_endpoint)])
        app.add_middleware(MetricsGuardMiddleware, config=config)
        client = TestClient(app)

        # Private IP should be allowed
        resp = client.get("/metrics", headers={"X-Forwarded-For": "10.0.0.1"})
        assert resp.status_code == 200

        # Public IP should be denied
        resp = client.get("/metrics", headers={"X-Forwarded-For": "8.8.8.8"})
        assert resp.status_code == 403