"""Tests for API Gateway auth — JWT, middleware, and health endpoint.""" from __future__ import annotations import time from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import jwt as pyjwt import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from services.api_gateway.auth.jwt import ( create_access_token, create_refresh_token, decode_token, ) from services.api_gateway.auth.middleware import get_config, get_current_user from services.api_gateway.config import ApiGatewayConfig # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture() def config() -> ApiGatewayConfig: return ApiGatewayConfig( jwt_secret_key="test-secret-key-for-unit-tests", jwt_algorithm="HS256", access_token_expire_minutes=15, refresh_token_expire_days=7, database_url="sqlite+aiosqlite:///:memory:", redis_url="redis://localhost:6379/0", ) @pytest.fixture() def app(config: ApiGatewayConfig) -> FastAPI: """Create a minimal FastAPI app for testing the auth middleware.""" from fastapi import Depends app = FastAPI() app.dependency_overrides[get_config] = lambda: config @app.get("/protected") async def protected(user: dict = Depends(get_current_user)): return {"user_id": user["sub"], "username": user["username"]} @app.get("/health") async def health(): return {"status": "ok"} return app @pytest.fixture() def client(app: FastAPI) -> TestClient: return TestClient(app) # --------------------------------------------------------------------------- # JWT Tests # --------------------------------------------------------------------------- class TestJWTCreateAndDecode: """test_jwt_create_and_decode — round-trip create + decode.""" def test_access_token_round_trip(self, config: ApiGatewayConfig) -> None: token = create_access_token("user-123", "alice", config) payload = decode_token(token, config) assert 
payload["sub"] == "user-123" assert payload["username"] == "alice" assert payload["type"] == "access" assert "exp" in payload assert "iat" in payload def test_refresh_token_round_trip(self, config: ApiGatewayConfig) -> None: token = create_refresh_token("user-456", config) payload = decode_token(token, config) assert payload["sub"] == "user-456" assert payload["type"] == "refresh" assert "exp" in payload def test_access_token_expiry_time(self, config: ApiGatewayConfig) -> None: token = create_access_token("u1", "bob", config) payload = decode_token(token, config) exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) iat = datetime.fromtimestamp(payload["iat"], tz=timezone.utc) delta = exp - iat assert timedelta(minutes=14) < delta <= timedelta(minutes=16) class TestJWTExpiredToken: """test_jwt_expired_token_rejected.""" def test_expired_access_token_raises(self, config: ApiGatewayConfig) -> None: # Manually create a token that already expired now = datetime.now(timezone.utc) payload = { "sub": "user-expired", "username": "expired", "type": "access", "iat": now - timedelta(hours=2), "exp": now - timedelta(hours=1), } token = pyjwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm) with pytest.raises(pyjwt.ExpiredSignatureError): decode_token(token, config) class TestJWTInvalidToken: """test_jwt_invalid_token_rejected.""" def test_wrong_secret_raises(self, config: ApiGatewayConfig) -> None: now = datetime.now(timezone.utc) payload = { "sub": "user-bad", "type": "access", "exp": now + timedelta(hours=1), } token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256") with pytest.raises(pyjwt.InvalidSignatureError): decode_token(token, config) def test_malformed_token_raises(self, config: ApiGatewayConfig) -> None: with pytest.raises(pyjwt.DecodeError): decode_token("not.a.real.token", config) def test_completely_garbage_raises(self, config: ApiGatewayConfig) -> None: with pytest.raises(pyjwt.DecodeError): decode_token("garbage", config) 
# ---------------------------------------------------------------------------
# Auth Middleware Tests
# ---------------------------------------------------------------------------


class TestAuthMiddlewareValidToken:
    """test_auth_middleware_valid_token — a valid access token is accepted."""

    def test_protected_route_with_valid_token(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        token = create_access_token("user-42", "charlie", config)
        resp = client.get(
            "/protected", headers={"Authorization": f"Bearer {token}"}
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["user_id"] == "user-42"
        assert data["username"] == "charlie"


class TestAuthMiddlewareMissingToken:
    """test_auth_middleware_missing_token — no Authorization header gives 401."""

    def test_protected_route_no_header(self, client: TestClient) -> None:
        resp = client.get("/protected")
        assert resp.status_code == 401
        assert "Missing authorization header" in resp.json()["detail"]


class TestAuthMiddlewareRejectedToken:
    """A bearer token that is present but unusable is rejected with 401.

    Covers expired, undecodable, and wrong-type (refresh instead of access)
    tokens.
    """

    def test_protected_route_expired_token(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        # Correctly signed, but ``exp`` is an hour in the past.
        now = datetime.now(timezone.utc)
        payload = {
            "sub": "user-old",
            "username": "old",
            "type": "access",
            "iat": now - timedelta(hours=2),
            "exp": now - timedelta(hours=1),
        }
        token = pyjwt.encode(
            payload, config.jwt_secret_key, algorithm=config.jwt_algorithm
        )
        resp = client.get(
            "/protected", headers={"Authorization": f"Bearer {token}"}
        )
        assert resp.status_code == 401
        assert "expired" in resp.json()["detail"].lower()

    def test_protected_route_invalid_token(self, client: TestClient) -> None:
        resp = client.get(
            "/protected",
            headers={"Authorization": "Bearer garbage-token"},
        )
        assert resp.status_code == 401
        assert "Invalid token" in resp.json()["detail"]

    def test_refresh_token_rejected_as_access(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        # A valid refresh token must not grant access to a protected route.
        token = create_refresh_token("user-99", config)
        resp = client.get(
            "/protected", headers={"Authorization": f"Bearer {token}"}
        )
        assert resp.status_code == 401
        assert "Invalid token type" in resp.json()["detail"]


# ---------------------------------------------------------------------------
# Health Endpoint Test
# ---------------------------------------------------------------------------


class TestHealthEndpoint:
    """test_health_endpoint."""

    def test_health_returns_ok(self, client: TestClient) -> None:
        # NOTE(review): ``/health`` is the stub route registered by the
        # ``app`` fixture, not the gateway's real health endpoint, so this
        # only verifies the fixture wiring. Consider targeting the real app.
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "ok"}