82 lines
3.1 KiB
Python
82 lines
3.1 KiB
Python
|
|
"""Unit tests for api/origin_validator.py."""
|
||
|
|
from starlette.testclient import TestClient
|
||
|
|
from starlette.applications import Starlette
|
||
|
|
from starlette.requests import Request
|
||
|
|
from starlette.responses import JSONResponse
|
||
|
|
from starlette.routing import Route
|
||
|
|
|
||
|
|
from api.origin_validator import OriginValidatorMiddleware
|
||
|
|
|
||
|
|
|
||
|
|
# Origins the middleware under test is configured to accept in most tests.
# The trailing slash on the second entry appears deliberate: the
# trailing-slash test relies on it to exercise origin normalization.
ALLOWED_ORIGINS = [
    "https://example.com",
    "https://app.example.com/",
]
|
||
|
|
|
||
|
|
|
||
|
|
async def _ok_endpoint(request: Request) -> JSONResponse:
    """Respond with a constant ``{"ok": true}`` JSON body.

    The incoming request is not inspected; this handler only gives the
    middleware under test a route to wrap, so a 200 with this body means the
    request was let through.
    """
    body = {"ok": True}
    return JSONResponse(body)
|
||
|
|
|
||
|
|
|
||
|
|
def _build_app(allowed_origins: "list[str] | None" = None) -> Starlette:
    """Build a minimal Starlette app wrapped in ``OriginValidatorMiddleware``.

    The app exposes a single ``/test`` route accepting GET, POST, PUT, DELETE
    and PATCH, all served by ``_ok_endpoint``.

    Args:
        allowed_origins: Origins handed to the middleware. ``None`` (the
            default) uses the module-level ``ALLOWED_ORIGINS``; an explicit
            empty list is passed through as-is rather than replaced by the
            default.

    Returns:
        The configured application, ready to be wrapped in a ``TestClient``.
    """
    # ``is None`` (not truthiness) so that ``[]`` still means "allow nothing".
    if allowed_origins is None:
        allowed_origins = ALLOWED_ORIGINS
    app = Starlette(routes=[
        Route(
            "/test",
            _ok_endpoint,
            methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
        ),
    ])
    # Pass a copy so the middleware can never mutate the shared module list
    # and leak state between tests.
    app.add_middleware(
        OriginValidatorMiddleware, allowed_origins=list(allowed_origins)
    )
    return app
|
||
|
|
|
||
|
|
|
||
|
|
class TestOriginValidator:
    """Tests for OriginValidatorMiddleware."""

    def test_get_request_bypasses_check(self) -> None:
        """GET requests should not be checked for origin."""
        client = TestClient(_build_app())
        resp = client.get("/test", headers={"Origin": "https://evil.com"})
        assert resp.status_code == 200

    def test_post_with_allowed_origin(self) -> None:
        """A POST from an exactly-configured origin is let through."""
        client = TestClient(_build_app())
        resp = client.post("/test", headers={"Origin": "https://example.com"})
        assert resp.status_code == 200

    def test_post_with_allowed_origin_trailing_slash(self) -> None:
        """Trailing slash on origin should be normalized.

        Covers three pairings against ``ALLOWED_ORIGINS``:

        * header and configured entry both carry a slash;
        * header has no slash (what browsers actually send) while the
          configured entry ``"https://app.example.com/"`` does;
        * header has a stray slash while the configured entry
          ``"https://example.com"`` does not.
        """
        client = TestClient(_build_app())
        for origin in (
            "https://app.example.com/",
            "https://app.example.com",
            "https://example.com/",
        ):
            resp = client.post("/test", headers={"Origin": origin})
            assert resp.status_code == 200, origin

    def test_post_with_disallowed_origin(self) -> None:
        """A POST from an unknown origin is rejected with a 403 JSON body."""
        client = TestClient(_build_app())
        resp = client.post("/test", headers={"Origin": "https://evil.com"})
        assert resp.status_code == 403
        assert resp.json() == {"detail": "Origin not allowed"}

    def test_post_with_lookalike_origin_rejected(self) -> None:
        """Origins that merely resemble an allowed one must not match.

        Guards against prefix/suffix/substring matching and against ignoring
        the scheme when comparing origins.
        """
        client = TestClient(_build_app())
        for origin in (
            "https://example.com.evil.com",  # allowed origin as a prefix
            "https://evilexample.com",  # allowed host as a suffix
            "http://example.com",  # scheme mismatch
        ):
            resp = client.post("/test", headers={"Origin": origin})
            assert resp.status_code == 403, origin

    def test_put_with_disallowed_origin(self) -> None:
        """PUT is subject to the origin check like POST."""
        client = TestClient(_build_app())
        resp = client.put("/test", headers={"Origin": "https://evil.com"})
        assert resp.status_code == 403
        assert resp.json() == {"detail": "Origin not allowed"}

    def test_delete_with_disallowed_origin(self) -> None:
        """DELETE is subject to the origin check like POST."""
        client = TestClient(_build_app())
        resp = client.delete("/test", headers={"Origin": "https://evil.com"})
        assert resp.status_code == 403
        assert resp.json() == {"detail": "Origin not allowed"}

    def test_patch_with_disallowed_origin(self) -> None:
        """PATCH is subject to the origin check like POST."""
        client = TestClient(_build_app())
        resp = client.patch("/test", headers={"Origin": "https://evil.com"})
        assert resp.status_code == 403
        assert resp.json() == {"detail": "Origin not allowed"}

    def test_post_without_origin_passes(self) -> None:
        """Requests without an Origin header should pass through."""
        client = TestClient(_build_app())
        resp = client.post("/test")
        assert resp.status_code == 200

    def test_no_allowed_origins_configured(self) -> None:
        """If no allowed origins are given, all origins are rejected."""
        app = Starlette(routes=[
            Route("/test", _ok_endpoint, methods=["POST"]),
        ])
        app.add_middleware(OriginValidatorMiddleware, allowed_origins=[])
        client = TestClient(app)
        resp = client.post("/test", headers={"Origin": "https://example.com"})
        assert resp.status_code == 403
|