trading/tests/services/test_api_auth.py

221 lines
7.3 KiB
Python

"""Tests for API Gateway auth — JWT, middleware, and health endpoint."""
from __future__ import annotations
import time
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from services.api_gateway.auth.jwt import (
create_access_token,
create_refresh_token,
decode_token,
)
from services.api_gateway.auth.middleware import get_config, get_current_user
from services.api_gateway.config import ApiGatewayConfig
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def config() -> ApiGatewayConfig:
    """Provide the gateway settings shared by every test in this module.

    The fixed HS256 secret lets tokens minted by the tests (directly via
    PyJWT or via the auth helpers) be verified by the code under test.
    Lifetimes are 15 minutes for access tokens and 7 days for refresh tokens.
    """
    settings = {
        "jwt_secret_key": "test-secret-key-for-unit-tests",
        "jwt_algorithm": "HS256",
        "access_token_expire_minutes": 15,
        "refresh_token_expire_days": 7,
        # In-memory SQLite; the Redis URL is a localhost placeholder —
        # presumably never contacted by these tests (TODO confirm).
        "database_url": "sqlite+aiosqlite:///:memory:",
        "redis_url": "redis://localhost:6379/0",
    }
    return ApiGatewayConfig(**settings)
@pytest.fixture()
def app(config: ApiGatewayConfig) -> FastAPI:
    """Build a minimal FastAPI app for exercising the auth dependency.

    Exposes ``/protected`` (guarded by ``get_current_user``) and an open
    ``/health`` route. ``get_config`` is overridden so the dependency
    validates tokens against the test ``config`` fixture.
    """
    from fastapi import Depends

    def _use_test_config() -> ApiGatewayConfig:
        return config

    test_app = FastAPI()
    test_app.dependency_overrides[get_config] = _use_test_config

    @test_app.get("/protected")
    async def protected(user: dict = Depends(get_current_user)):
        # Echo the decoded identity claims so tests can inspect them.
        return {"user_id": user["sub"], "username": user["username"]}

    @test_app.get("/health")
    async def health():
        return {"status": "ok"}

    return test_app
@pytest.fixture()
def client(app: FastAPI) -> TestClient:
    """Return a synchronous test client bound to the ``app`` fixture."""
    http_client = TestClient(app)
    return http_client
# ---------------------------------------------------------------------------
# JWT Tests
# ---------------------------------------------------------------------------
class TestJWTCreateAndDecode:
    """test_jwt_create_and_decode — round-trip create + decode."""

    def test_access_token_round_trip(self, config: ApiGatewayConfig) -> None:
        """An access token decodes back to its subject, username and type."""
        token = create_access_token("user-123", "alice", config)
        payload = decode_token(token, config)
        assert payload["sub"] == "user-123"
        assert payload["username"] == "alice"
        assert payload["type"] == "access"
        assert "exp" in payload
        assert "iat" in payload

    def test_refresh_token_round_trip(self, config: ApiGatewayConfig) -> None:
        """A refresh token decodes back to its subject and is typed 'refresh'."""
        token = create_refresh_token("user-456", config)
        payload = decode_token(token, config)
        assert payload["sub"] == "user-456"
        assert payload["type"] == "refresh"
        assert "exp" in payload

    def test_access_token_expiry_time(self, config: ApiGatewayConfig) -> None:
        """The access-token lifetime (exp - iat) matches the configured minutes."""
        token = create_access_token("u1", "bob", config)
        payload = decode_token(token, config)
        exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        iat = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
        # Derive the expected lifetime from the fixture instead of hard-coding
        # 15 minutes, so editing access_token_expire_minutes cannot leave this
        # assertion stale. The +/- 1 minute tolerance absorbs the whole-second
        # truncation of JWT timestamps.
        expected = timedelta(minutes=config.access_token_expire_minutes)
        tolerance = timedelta(minutes=1)
        assert expected - tolerance < exp - iat <= expected + tolerance
class TestJWTExpiredToken:
    """test_jwt_expired_token_rejected."""

    def test_expired_access_token_raises(self, config: ApiGatewayConfig) -> None:
        """decode_token lets PyJWT's ExpiredSignatureError propagate."""
        issued_at = datetime.now(timezone.utc) - timedelta(hours=2)
        # Issued two hours ago with a one-hour lifetime: lapsed an hour ago.
        claims = {
            "sub": "user-expired",
            "username": "expired",
            "type": "access",
            "iat": issued_at,
            "exp": issued_at + timedelta(hours=1),
        }
        expired_token = pyjwt.encode(
            claims, config.jwt_secret_key, algorithm=config.jwt_algorithm
        )
        with pytest.raises(pyjwt.ExpiredSignatureError):
            decode_token(expired_token, config)
class TestJWTInvalidToken:
    """test_jwt_invalid_token_rejected.

    Covers signature failures (wrong secret, spliced signature), the
    unsigned ``alg: none`` downgrade, and structurally broken tokens.
    """

    def test_wrong_secret_raises(self, config: ApiGatewayConfig) -> None:
        """A token signed with a different HMAC secret fails verification."""
        now = datetime.now(timezone.utc)
        payload = {
            "sub": "user-bad",
            "type": "access",
            "exp": now + timedelta(hours=1),
        }
        token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
        with pytest.raises(pyjwt.InvalidSignatureError):
            decode_token(token, config)

    def test_spliced_signature_raises(self, config: ApiGatewayConfig) -> None:
        """Claims from one token paired with another token's signature fail."""
        genuine = create_access_token("user-1", "alice", config)
        forged_source = create_access_token("admin-1", "root", config)
        # Keep the forged header+claims but reuse the genuine signature: the
        # HMAC no longer covers the presented claims.
        header, claims, _ = forged_source.split(".")
        signature = genuine.split(".")[2]
        with pytest.raises(pyjwt.InvalidSignatureError):
            decode_token(f"{header}.{claims}.{signature}", config)

    def test_unsigned_alg_none_token_raises(self, config: ApiGatewayConfig) -> None:
        """An unsigned ``alg: none`` token must never be accepted."""
        now = datetime.now(timezone.utc)
        payload = {
            "sub": "user-none",
            "username": "mallory",
            "type": "access",
            "exp": now + timedelta(hours=1),
        }
        token = pyjwt.encode(payload, None, algorithm="none")
        # InvalidTokenError is PyJWT's base class, so this holds whether the
        # rejection comes from algorithm allow-listing or from a decode error.
        with pytest.raises(pyjwt.InvalidTokenError):
            decode_token(token, config)

    def test_malformed_token_raises(self, config: ApiGatewayConfig) -> None:
        """A dotted string whose segments are not valid JWT parts is rejected."""
        with pytest.raises(pyjwt.DecodeError):
            decode_token("not.a.real.token", config)

    def test_completely_garbage_raises(self, config: ApiGatewayConfig) -> None:
        """A string with no JWT segment structure at all is rejected."""
        with pytest.raises(pyjwt.DecodeError):
            decode_token("garbage", config)
# ---------------------------------------------------------------------------
# Auth Middleware Tests
# ---------------------------------------------------------------------------
class TestAuthMiddlewareValidToken:
    """test_auth_middleware_valid_token."""

    def test_protected_route_with_valid_token(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        """A freshly minted access token unlocks /protected and echoes its claims."""
        access_token = create_access_token("user-42", "charlie", config)
        auth_headers = {"Authorization": f"Bearer {access_token}"}

        response = client.get("/protected", headers=auth_headers)

        assert response.status_code == 200
        body = response.json()
        assert (body["user_id"], body["username"]) == ("user-42", "charlie")
class TestAuthMiddlewareMissingToken:
    """test_auth_middleware_missing_token.

    Despite the class name, this covers several rejection paths of the
    ``/protected`` route: no Authorization header, an expired token, an
    undecodable token, and a refresh token presented where an access token
    is required. Each case is expected to yield HTTP 401 with an
    explanatory ``detail`` message.
    """

    @staticmethod
    def _call_protected(client: TestClient, token: str):
        """GET /protected presenting *token* as a Bearer credential."""
        return client.get(
            "/protected", headers={"Authorization": f"Bearer {token}"}
        )

    def test_protected_route_no_header(self, client: TestClient) -> None:
        response = client.get("/protected")
        assert response.status_code == 401
        assert "Missing authorization header" in response.json()["detail"]

    def test_protected_route_expired_token(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        issued_at = datetime.now(timezone.utc) - timedelta(hours=2)
        # Issued two hours ago with a one-hour lifetime: lapsed an hour ago.
        claims = {
            "sub": "user-old",
            "username": "old",
            "type": "access",
            "iat": issued_at,
            "exp": issued_at + timedelta(hours=1),
        }
        stale_token = pyjwt.encode(
            claims, config.jwt_secret_key, algorithm=config.jwt_algorithm
        )
        response = self._call_protected(client, stale_token)
        assert response.status_code == 401
        assert "expired" in response.json()["detail"].lower()

    def test_protected_route_invalid_token(self, client: TestClient) -> None:
        response = self._call_protected(client, "garbage-token")
        assert response.status_code == 401
        assert "Invalid token" in response.json()["detail"]

    def test_refresh_token_rejected_as_access(
        self, client: TestClient, config: ApiGatewayConfig
    ) -> None:
        refresh_token = create_refresh_token("user-99", config)
        response = self._call_protected(client, refresh_token)
        assert response.status_code == 401
        assert "Invalid token type" in response.json()["detail"]
# ---------------------------------------------------------------------------
# Health Endpoint Test
# ---------------------------------------------------------------------------
class TestHealthEndpoint:
    """test_health_endpoint."""

    def test_health_returns_ok(self, client: TestClient) -> None:
        """/health answers without credentials and reports an ok status."""
        response = client.get("/health")
        assert response.status_code == 200, response.text
        assert response.json() == {"status": "ok"}