"""Unit tests for api/origin_validator.py.""" from starlette.testclient import TestClient from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route from api.origin_validator import OriginValidatorMiddleware ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com/"] async def _ok_endpoint(request: Request) -> JSONResponse: return JSONResponse({"ok": True}) def _build_app() -> Starlette: app = Starlette(routes=[ Route("/test", _ok_endpoint, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]), ]) app.add_middleware(OriginValidatorMiddleware, allowed_origins=ALLOWED_ORIGINS) return app class TestOriginValidator: """Tests for OriginValidatorMiddleware.""" def test_get_request_bypasses_check(self) -> None: """GET requests should not be checked for origin.""" client = TestClient(_build_app()) resp = client.get("/test", headers={"Origin": "https://evil.com"}) assert resp.status_code == 200 def test_post_with_allowed_origin(self) -> None: client = TestClient(_build_app()) resp = client.post("/test", headers={"Origin": "https://example.com"}) assert resp.status_code == 200 def test_post_with_allowed_origin_trailing_slash(self) -> None: """Trailing slash on origin should be normalized.""" client = TestClient(_build_app()) resp = client.post("/test", headers={"Origin": "https://app.example.com/"}) assert resp.status_code == 200 def test_post_with_disallowed_origin(self) -> None: client = TestClient(_build_app()) resp = client.post("/test", headers={"Origin": "https://evil.com"}) assert resp.status_code == 403 assert resp.json() == {"detail": "Origin not allowed"} def test_put_with_disallowed_origin(self) -> None: client = TestClient(_build_app()) resp = client.put("/test", headers={"Origin": "https://evil.com"}) assert resp.status_code == 403 def test_delete_with_disallowed_origin(self) -> None: client = TestClient(_build_app()) resp = client.delete("/test", headers={"Origin": "https://evil.com"}) assert resp.status_code == 403 def test_patch_with_disallowed_origin(self) -> None: client = TestClient(_build_app()) resp = client.patch("/test", headers={"Origin": "https://evil.com"}) assert resp.status_code == 403 def test_post_without_origin_passes(self) -> None: """Requests without an Origin header should pass through.""" client = TestClient(_build_app()) resp = client.post("/test") assert resp.status_code == 200 def test_no_allowed_origins_configured(self) -> None: """If no allowed origins are given, all origins are rejected.""" app = Starlette(routes=[ Route("/test", _ok_endpoint, methods=["POST"]), ]) app.add_middleware(OriginValidatorMiddleware, allowed_origins=[]) client = TestClient(app) resp = client.post("/test", headers={"Origin": "https://example.com"}) assert resp.status_code == 403