221 lines
7.3 KiB
Python
221 lines
7.3 KiB
Python
"""Tests for API Gateway auth — JWT, middleware, and health endpoint."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import jwt as pyjwt
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from services.api_gateway.auth.jwt import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_token,
|
|
)
|
|
from services.api_gateway.auth.middleware import get_config, get_current_user
|
|
from services.api_gateway.config import ApiGatewayConfig
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture()
def config() -> ApiGatewayConfig:
    """Build the gateway configuration shared by the auth tests.

    Tokens are signed with a fixed HS256 test secret. Access tokens are
    configured for 15 minutes and refresh tokens for 7 days. The database
    URL points at an in-memory SQLite database, and the Redis URL points at
    localhost.
    """
    settings = {
        "jwt_secret_key": "test-secret-key-for-unit-tests",
        "jwt_algorithm": "HS256",
        "access_token_expire_minutes": 15,
        "refresh_token_expire_days": 7,
        "database_url": "sqlite+aiosqlite:///:memory:",
        "redis_url": "redis://localhost:6379/0",
    }
    return ApiGatewayConfig(**settings)
|
|
|
|
|
|
@pytest.fixture()
def app(config: ApiGatewayConfig) -> FastAPI:
    """Create a minimal FastAPI app for exercising the auth dependency.

    ``get_config`` is overridden to return the test ``config`` fixture.
    NOTE(review): this presumably feeds the test secret into
    ``get_current_user`` via ``get_config``; confirm against the middleware.

    Registered routes:

    * ``GET /protected``: guarded by ``get_current_user``. It echoes the
      ``sub`` and ``username`` claims back as ``user_id`` and ``username``.
    * ``GET /health``: unauthenticated. It always returns
      ``{"status": "ok"}``.
    """
    from fastapi import Depends

    test_app = FastAPI()

    def _config_override():
        return config

    test_app.dependency_overrides[get_config] = _config_override

    @test_app.get("/protected")
    async def protected(user: dict = Depends(get_current_user)):
        user_id, username = user["sub"], user["username"]
        return {"user_id": user_id, "username": username}

    @test_app.get("/health")
    async def health():
        return {"status": "ok"}

    return test_app
|
|
|
|
|
|
@pytest.fixture()
def client(app: FastAPI) -> TestClient:
    """Return a synchronous test client bound to the ``app`` fixture.

    The client is not used as a context manager. Starlette only runs
    lifespan handlers when the client is entered via ``with``.
    """
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# JWT Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestJWTCreateAndDecode:
    """test_jwt_create_and_decode: round-trip create + decode.

    Tokens minted by ``create_access_token`` / ``create_refresh_token`` must
    decode with ``decode_token`` under the same config and carry the
    expected claims.
    """

    def test_access_token_round_trip(self, config: ApiGatewayConfig) -> None:
        """An access token preserves subject, username and ``access`` type."""
        token = create_access_token("user-123", "alice", config)
        payload = decode_token(token, config)

        assert payload["sub"] == "user-123"
        assert payload["username"] == "alice"
        assert payload["type"] == "access"
        assert "exp" in payload
        assert "iat" in payload

    def test_refresh_token_round_trip(self, config: ApiGatewayConfig) -> None:
        """A refresh token preserves its subject and is typed ``refresh``."""
        token = create_refresh_token("user-456", config)
        payload = decode_token(token, config)

        assert payload["sub"] == "user-456"
        assert payload["type"] == "refresh"
        assert "exp" in payload

    def test_access_token_expiry_time(self, config: ApiGatewayConfig) -> None:
        """Access-token lifetime (``exp - iat``) matches the configured value.

        The expected lifetime comes from ``config`` rather than a hard-coded
        15 minutes, so editing the fixture cannot silently invalidate this
        test. The +/- one-minute window (same width as before) leaves room
        for ``iat`` and ``exp`` being computed or rounded independently.
        """
        token = create_access_token("u1", "bob", config)
        payload = decode_token(token, config)
        exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        iat = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
        lifetime = exp - iat

        expected = timedelta(minutes=config.access_token_expire_minutes)
        tolerance = timedelta(minutes=1)
        assert expected - tolerance < lifetime <= expected + tolerance
|
|
|
|
|
|
class TestJWTExpiredToken:
    """test_jwt_expired_token_rejected.

    For a correctly signed token whose ``exp`` lies in the past,
    ``decode_token`` must let PyJWT's ``ExpiredSignatureError`` through.
    """

    def test_expired_access_token_raises(self, config: ApiGatewayConfig) -> None:
        # Mint the token directly with PyJWT so its expiry can be set in the
        # past: issued two hours ago, expired one hour ago.
        issued_at = datetime.now(timezone.utc) - timedelta(hours=2)
        claims = dict(
            sub="user-expired",
            username="expired",
            type="access",
            iat=issued_at,
            exp=issued_at + timedelta(hours=1),
        )
        signed = pyjwt.encode(
            claims, config.jwt_secret_key, algorithm=config.jwt_algorithm
        )

        with pytest.raises(pyjwt.ExpiredSignatureError):
            decode_token(signed, config)
|
|
|
|
|
|
class TestJWTInvalidToken:
    """test_jwt_invalid_token_rejected.

    ``decode_token`` must raise for tokens signed with a foreign key and for
    strings that are not well-formed JWTs at all.
    """

    def test_wrong_secret_raises(self, config: ApiGatewayConfig) -> None:
        # Unexpired, structurally valid token signed with a key other than
        # the configured test secret.
        expires = datetime.now(timezone.utc) + timedelta(hours=1)
        forged = pyjwt.encode(
            {"sub": "user-bad", "type": "access", "exp": expires},
            "wrong-secret",
            algorithm="HS256",
        )

        with pytest.raises(pyjwt.InvalidSignatureError):
            decode_token(forged, config)

    def test_malformed_token_raises(self, config: ApiGatewayConfig) -> None:
        # Dot-separated, so it looks token-shaped, but the segments are not
        # encoded JWT parts.
        bogus = "not.a.real.token"
        with pytest.raises(pyjwt.DecodeError):
            decode_token(bogus, config)

    def test_completely_garbage_raises(self, config: ApiGatewayConfig) -> None:
        # A single segment with no dots at all.
        bogus = "garbage"
        with pytest.raises(pyjwt.DecodeError):
            decode_token(bogus, config)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth Middleware Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAuthMiddlewareValidToken:
    """test_auth_middleware_valid_token.

    A freshly issued access token in a ``Bearer`` header unlocks
    ``/protected``, which echoes the token's identity claims back.
    """

    def test_protected_route_with_valid_token(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        access = create_access_token("user-42", "charlie", config)
        bearer = f"Bearer {access}"

        response = client.get("/protected", headers={"Authorization": bearer})

        assert response.status_code == 200
        body = response.json()
        assert (body["user_id"], body["username"]) == ("user-42", "charlie")
|
|
|
|
|
|
class TestAuthMiddlewareMissingToken:
    """test_auth_middleware_missing_token.

    Despite the class name, this groups several rejection paths of the auth
    dependency: a missing header, an expired token, an undecodable token,
    and a refresh token presented where an access token is required. Each
    must produce HTTP 401 with an explanatory ``detail`` message.
    """

    def test_protected_route_no_header(self, client: TestClient) -> None:
        response = client.get("/protected")

        assert response.status_code == 401
        assert "Missing authorization header" in response.json()["detail"]

    def test_protected_route_expired_token(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        # Signed with the real test secret, but issued two hours ago and
        # expired one hour ago.
        issued_at = datetime.now(timezone.utc) - timedelta(hours=2)
        stale = pyjwt.encode(
            {
                "sub": "user-old",
                "username": "old",
                "type": "access",
                "iat": issued_at,
                "exp": issued_at + timedelta(hours=1),
            },
            config.jwt_secret_key,
            algorithm=config.jwt_algorithm,
        )

        response = client.get(
            "/protected", headers={"Authorization": f"Bearer {stale}"}
        )

        assert response.status_code == 401
        assert "expired" in response.json()["detail"].lower()

    def test_protected_route_invalid_token(self, client: TestClient) -> None:
        response = client.get(
            "/protected", headers={"Authorization": "Bearer garbage-token"}
        )

        assert response.status_code == 401
        assert "Invalid token" in response.json()["detail"]

    def test_refresh_token_rejected_as_access(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        # A genuine, unexpired refresh token must still be refused here.
        refresh = create_refresh_token("user-99", config)

        response = client.get(
            "/protected", headers={"Authorization": f"Bearer {refresh}"}
        )

        assert response.status_code == 401
        assert "Invalid token type" in response.json()["detail"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Health Endpoint Test
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHealthEndpoint:
    """test_health_endpoint.

    NOTE(review): this calls the stub ``/health`` route registered by the
    ``app`` fixture, not the gateway's production health handler. It
    therefore only proves the test harness responds; consider pointing it
    at the real application.
    """

    def test_health_returns_ok(self, client: TestClient) -> None:
        response = client.get("/health")

        assert response.status_code == 200, response.text
        assert response.json() == {"status": "ok"}
|