"""IP allowlist middleware for the /metrics endpoint.""" from __future__ import annotations import ipaddress import logging from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from api.rate_limit_config import RateLimitConfig logger = logging.getLogger("uvicorn") def parse_allowed_networks(raw: str) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]: """Parse a comma-separated string of IPs/CIDRs into network objects.""" networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = [] for entry in raw.split(","): entry = entry.strip() if not entry: continue networks.append(ipaddress.ip_network(entry, strict=False)) return networks def is_ip_allowed( ip_str: str, allowed_networks: list[ipaddress.IPv4Network | ipaddress.IPv6Network], ) -> bool: """Check whether an IP address falls within any of the allowed networks.""" try: addr = ipaddress.ip_address(ip_str) except ValueError: return False return any(addr in network for network in allowed_networks) class MetricsGuardMiddleware(BaseHTTPMiddleware): """Restricts /metrics access to an IP allowlist.""" def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] super().__init__(app) cfg = config or RateLimitConfig.from_env() self._allowed_networks = parse_allowed_networks(cfg.metrics_allowed_ips) async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] if not request.url.path.startswith("/metrics"): return await call_next(request) forwarded = request.headers.get("x-forwarded-for") if forwarded: client_ip = forwarded.split(",")[0].strip() else: client_ip = request.client.host if request.client else "unknown" if not is_ip_allowed(client_ip, self._allowed_networks): logger.warning("Metrics access denied for IP %s", client_ip) return JSONResponse(status_code=403, content={"detail": "Forbidden"}) return await call_next(request)