152 lines
6.2 KiB
Python
152 lines
6.2 KiB
Python
|
|
"""Unit tests for api/auth.py."""
|
||
|
|
from datetime import datetime, timedelta, timezone
|
||
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||
|
|
|
||
|
|
import jwt as pyjwt
|
||
|
|
import pytest
|
||
|
|
from fastapi import HTTPException
|
||
|
|
from fastapi.security import HTTPAuthorizationCredentials
|
||
|
|
|
||
|
|
from api.auth import (
|
||
|
|
User,
|
||
|
|
_verify_passkey_token,
|
||
|
|
_verify_authentik_token,
|
||
|
|
get_current_user,
|
||
|
|
)
|
||
|
|
from api.config import JWT_SECRET, JWT_ALGORITHM, JWT_ISSUER
|
||
|
|
|
||
|
|
|
||
|
|
def _make_passkey_token(
    sub: str = "user-123",
    email: str = "test@example.com",
    name: str = "Test User",
    issuer: str = JWT_ISSUER,
    secret: str = JWT_SECRET,
    algorithm: str = JWT_ALGORITHM,
    expires_delta: timedelta | None = timedelta(hours=1),
) -> str:
    """Mint a passkey-style JWT for tests.

    Args:
        sub: Value of the ``sub`` claim.
        email: Value of the ``email`` claim.
        name: Value of the ``name`` claim.
        issuer: Value of the ``iss`` claim; defaults to the passkey issuer.
        secret: Key the token is signed with.
        algorithm: Signing algorithm handed to ``jwt.encode``.
        expires_delta: Offset from "now" (UTC) for the ``exp`` claim. A
            negative offset yields an already-expired token; ``None`` omits
            ``exp`` altogether.

    Returns:
        The encoded token string.
    """
    claims: dict = dict(sub=sub, email=email, name=name, iss=issuer)

    # No expiry requested: sign the identity claims as-is.
    if expires_delta is None:
        return pyjwt.encode(claims, secret, algorithm=algorithm)

    expiry = datetime.now(timezone.utc) + expires_delta
    return pyjwt.encode({**claims, "exp": expiry}, secret, algorithm=algorithm)
|
||
|
|
|
||
|
|
|
||
|
|
class TestVerifyPasskeyToken:
    """Exercise _verify_passkey_token() with valid and tampered passkey JWTs."""

    def test_valid_token_returns_user(self) -> None:
        """A well-formed token decodes into a User carrying its identity claims."""
        result = _verify_passkey_token(_make_passkey_token())

        assert isinstance(result, User)
        assert (result.sub, result.email, result.name) == (
            "user-123",
            "test@example.com",
            "Test User",
        )

    def test_valid_token_without_name_uses_email(self) -> None:
        """With no ``name`` claim present, the email is used as the display name."""
        expiry = datetime.now(timezone.utc) + timedelta(hours=1)
        claims = dict(
            sub="user-456",
            email="noname@example.com",
            iss=JWT_ISSUER,
            exp=expiry,
        )
        encoded = pyjwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)

        assert _verify_passkey_token(encoded).name == "noname@example.com"

    def test_rejects_expired_token(self) -> None:
        """A token whose ``exp`` lies in the past is refused."""
        stale = _make_passkey_token(expires_delta=timedelta(hours=-1))
        with pytest.raises(pyjwt.ExpiredSignatureError):
            _verify_passkey_token(stale)

    def test_rejects_wrong_secret(self) -> None:
        """A token signed with a different key fails signature verification."""
        forged = _make_passkey_token(secret="wrong-secret")
        with pytest.raises(pyjwt.InvalidSignatureError):
            _verify_passkey_token(forged)

    def test_rejects_wrong_issuer(self) -> None:
        """A token from a foreign issuer is refused by the passkey verifier."""
        foreign = _make_passkey_token(issuer="some-other-issuer")
        with pytest.raises(pyjwt.InvalidIssuerError):
            _verify_passkey_token(foreign)
|
||
|
|
|
||
|
|
|
||
|
|
class TestVerifyAuthentikToken:
    """Tests for _verify_authentik_token() — specifically that expiration is verified."""

    async def test_verifies_expiration_after_fix(self) -> None:
        """After removing verify_exp: False, expired Authentik tokens should be rejected.

        The token is signed with a throwaway RSA key, and both the OIDC
        metadata lookup and the JWKS client are mocked, so no network access
        is needed.
        """
        from cryptography.hazmat.primitives.asymmetric import rsa

        # Throwaway key pair: the private half signs the token below, the
        # public half is served back by the mocked JWKS client.
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        public_key = private_key.public_key()

        issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
        payload = {
            "sub": "authentik-user",
            "email": "auth@example.com",
            "name": "Auth User",
            "iss": issuer,
            # NOTE(review): presumably the Authentik client ID that
            # _verify_authentik_token expects as audience — confirm in api.config.
            "aud": "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe",
            "exp": datetime.now(timezone.utc) - timedelta(hours=1),  # expired
        }
        token = pyjwt.encode(payload, private_key, algorithm="RS256")

        # Build a real PyJWK-compatible signing key mock so jwt.decode
        # takes the PyJWK code path (uses key.key directly, skips prepare_key)
        mock_signing_key = MagicMock(spec=pyjwt.PyJWK)
        mock_signing_key.key = public_key
        mock_signing_key.algorithm_name = "RS256"
        mock_signing_key.Algorithm = pyjwt.get_algorithm_by_name("RS256")

        mock_jwks_client = MagicMock()
        mock_jwks_client.get_signing_key_from_jwt.return_value = mock_signing_key

        mock_metadata = {
            "issuer": issuer,
            "jwks_uri": f"{issuer}jwks/",
        }

        with (
            patch("api.auth.get_oidc_metadata", new_callable=AsyncMock, return_value=mock_metadata),
            patch("api.auth.get_cached_jwks_client", new_callable=AsyncMock, return_value=mock_jwks_client),
            pytest.raises(pyjwt.ExpiredSignatureError),
        ):
            await _verify_authentik_token(token)
|
||
|
|
|
||
|
|
|
||
|
|
class TestGetCurrentUser:
    """Tests for get_current_user()."""

    async def test_routes_to_passkey_verifier_for_matching_issuer(self) -> None:
        """A token whose issuer is JWT_ISSUER is verified as a passkey token."""
        token = _make_passkey_token()
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
        user = await get_current_user(credentials)
        assert user.sub == "user-123"
        assert user.email == "test@example.com"

    async def test_routes_to_authentik_for_other_issuer(self) -> None:
        """When issuer != JWT_ISSUER, should route to Authentik verifier."""
        token = _make_passkey_token(issuer="https://authentik.viktorbarzin.me/application/o/wrongmove/")
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

        mock_user = User(sub="authentik-user", email="auth@example.com", name="Auth User")
        with patch(
            "api.auth._verify_authentik_token",
            new_callable=AsyncMock,
            return_value=mock_user,
        ) as mock_verify:
            user = await get_current_user(credentials)

        # Pin the routing decision itself, not only the value it produced.
        mock_verify.assert_awaited_once()
        assert user.email == "auth@example.com"

    async def test_raises_http_exception_for_expired_passkey_token(self) -> None:
        """An expired passkey token must surface as a 401, not an unhandled JWT error."""
        token = _make_passkey_token(expires_delta=timedelta(hours=-1))
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(credentials)
        assert exc_info.value.status_code == 401

    async def test_raises_http_exception_for_invalid_token(self) -> None:
        """A structurally malformed token is rejected with 401 'Invalid token'."""
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="not.a.valid.token")
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(credentials)
        assert exc_info.value.status_code == 401
        assert "Invalid token" in exc_info.value.detail

    async def test_raises_http_exception_for_garbage_token(self) -> None:
        """A token with no JWT structure at all is rejected with 401."""
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="totalgarbage")
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(credentials)
        assert exc_info.value.status_code == 401
|