From e0d138c4574e9974e73f42c167e31eed39af4f6a Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Sun, 22 Feb 2026 15:53:48 +0000 Subject: [PATCH] feat: API gateway with passkey (WebAuthn) authentication --- pyproject.toml | 4 +- services/api_gateway/__init__.py | 1 + services/api_gateway/auth/__init__.py | 1 + services/api_gateway/auth/jwt.py | 98 ++++++ services/api_gateway/auth/middleware.py | 68 ++++ services/api_gateway/auth/routes.py | 410 ++++++++++++++++++++++++ services/api_gateway/config.py | 25 ++ services/api_gateway/main.py | 81 +++++ tests/services/test_api_auth.py | 221 +++++++++++++ 9 files changed, 907 insertions(+), 2 deletions(-) create mode 100644 services/api_gateway/__init__.py create mode 100644 services/api_gateway/auth/__init__.py create mode 100644 services/api_gateway/auth/jwt.py create mode 100644 services/api_gateway/auth/middleware.py create mode 100644 services/api_gateway/auth/routes.py create mode 100644 services/api_gateway/config.py create mode 100644 services/api_gateway/main.py create mode 100644 tests/services/test_api_auth.py diff --git a/pyproject.toml b/pyproject.toml index 36ccccf..fbd6349 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,12 @@ dependencies = [ ] [project.optional-dependencies] -api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "py-webauthn>=2.0", "pyjwt[crypto]>=2.8"] +api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "webauthn>=2.0", "pyjwt[crypto]>=2.8"] news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"] sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"] trading = ["alpaca-py>=0.21"] backtester = ["numpy>=1.26", "pandas>=2.2"] -dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8"] +dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8", "httpx>=0.27"] [build-system] requires = ["setuptools>=70.0"] diff --git a/services/api_gateway/__init__.py b/services/api_gateway/__init__.py new file mode 100644 index 0000000..0b5caf9 --- /dev/null +++ b/services/api_gateway/__init__.py @@ -0,0 +1 @@ +"""API Gateway service — FastAPI application serving the trading bot dashboard.""" diff --git a/services/api_gateway/auth/__init__.py b/services/api_gateway/auth/__init__.py new file mode 100644 index 0000000..d24f224 --- /dev/null +++ b/services/api_gateway/auth/__init__.py @@ -0,0 +1 @@ +"""Auth sub-package for the API Gateway.""" diff --git a/services/api_gateway/auth/jwt.py b/services/api_gateway/auth/jwt.py new file mode 100644 index 0000000..c117a5c --- /dev/null +++ b/services/api_gateway/auth/jwt.py @@ -0,0 +1,98 @@ +"""JWT utilities — token creation and verification.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import jwt + +from services.api_gateway.config import ApiGatewayConfig + + +def create_access_token( + user_id: str, + username: str, + config: ApiGatewayConfig, +) -> str: + """Create a short-lived JWT access token. + + Parameters + ---------- + user_id: + UUID of the authenticated user (stored as ``sub`` claim). + username: + Username (stored as ``username`` claim). + config: + Gateway configuration with secret key and algorithm. + + Returns + ------- + str + Encoded JWT string. + """ + now = datetime.now(timezone.utc) + payload = { + "sub": user_id, + "username": username, + "type": "access", + "iat": now, + "exp": now + timedelta(minutes=config.access_token_expire_minutes), + } + return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm) + + +def create_refresh_token( + user_id: str, + config: ApiGatewayConfig, +) -> str: + """Create a longer-lived JWT refresh token. + + Parameters + ---------- + user_id: + UUID of the authenticated user (stored as ``sub`` claim). + config: + Gateway configuration with secret key and algorithm. + + Returns + ------- + str + Encoded JWT string. + """ + now = datetime.now(timezone.utc) + payload = { + "sub": user_id, + "type": "refresh", + "iat": now, + "exp": now + timedelta(days=config.refresh_token_expire_days), + } + return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm) + + +def decode_token(token: str, config: ApiGatewayConfig) -> dict: + """Decode and verify a JWT token. + + Parameters + ---------- + token: + The JWT string to decode. + config: + Gateway configuration with secret key and algorithm. + + Returns + ------- + dict + The decoded payload. + + Raises + ------ + jwt.ExpiredSignatureError + If the token has expired. + jwt.InvalidTokenError + If the token is malformed or signature verification fails. + """ + return jwt.decode( + token, + config.jwt_secret_key, + algorithms=[config.jwt_algorithm], + ) diff --git a/services/api_gateway/auth/middleware.py b/services/api_gateway/auth/middleware.py new file mode 100644 index 0000000..ee25863 --- /dev/null +++ b/services/api_gateway/auth/middleware.py @@ -0,0 +1,68 @@ +"""Auth middleware — FastAPI dependency for JWT-based authentication.""" + +from __future__ import annotations + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +import jwt as pyjwt + +from services.api_gateway.auth.jwt import decode_token +from services.api_gateway.config import ApiGatewayConfig + +# Shared config instance — injected via FastAPI dependency override in tests. +_config: ApiGatewayConfig | None = None + +security = HTTPBearer(auto_error=False) + + +def get_config() -> ApiGatewayConfig: + """Return the singleton config. Overridden in tests.""" + global _config + if _config is None: + _config = ApiGatewayConfig() + return _config + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials | None = Depends(security), + config: ApiGatewayConfig = Depends(get_config), +) -> dict: + """FastAPI dependency that extracts and validates a Bearer JWT. + + Returns the decoded token payload (contains ``sub``, ``username``, etc.) + on success. Raises a 401 ``HTTPException`` for missing, expired, or + invalid tokens. + """ + if credentials is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authorization header", + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = credentials.credentials + try: + payload = decode_token(token, config) + except pyjwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired", + headers={"WWW-Authenticate": "Bearer"}, + ) + except pyjwt.InvalidTokenError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Ensure it is an access token, not a refresh token + if payload.get("type") != "access": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token type", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return payload diff --git a/services/api_gateway/auth/routes.py b/services/api_gateway/auth/routes.py new file mode 100644 index 0000000..5d9ac5c --- /dev/null +++ b/services/api_gateway/auth/routes.py @@ -0,0 +1,410 @@ +"""Auth routes — WebAuthn passkey registration/login and JWT refresh.""" + +from __future__ import annotations + +import base64 +import logging +import uuid +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from webauthn import ( + generate_authentication_options, + generate_registration_options, + verify_authentication_response, + verify_registration_response, +) +from webauthn.helpers.structs import ( + AuthenticatorSelectionCriteria, + PublicKeyCredentialDescriptor, + ResidentKeyRequirement, + UserVerificationRequirement, +) + +import jwt as pyjwt + +from services.api_gateway.auth.jwt import ( + create_access_token, + create_refresh_token, + decode_token, +) +from services.api_gateway.auth.middleware import get_config +from services.api_gateway.config import ApiGatewayConfig +from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth", tags=["auth"]) + +# --------------------------------------------------------------------------- +# Helpers — DB and Redis access +# --------------------------------------------------------------------------- + +async def _get_redis(request: Request): + """Retrieve the Redis client from application state.""" + return request.app.state.redis + + +async def _get_db(request: Request): + """Retrieve an async DB session from application state.""" + session_factory = request.app.state.db_session_factory + async with session_factory() as session: + yield session + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +@router.post("/register/begin") +async def register_begin( + body: RegisterRequest, + request: Request, + config: ApiGatewayConfig = Depends(get_config), +) -> dict[str, Any]: + """Generate WebAuthn registration options (challenge + relying party info). + + Stores the challenge in Redis with a 5-minute TTL so we can verify it + in ``/register/complete``. + """ + redis = await _get_redis(request) + db_session = request.app.state.db_session_factory + + # Check if username already exists + from sqlalchemy import select + from shared.models.auth import User + + async with db_session() as session: + existing = ( + await session.execute( + select(User).where(User.username == body.username) + ) + ).scalar_one_or_none() + if existing is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Username already registered", + ) + + user_id = uuid.uuid4() + options = generate_registration_options( + rp_id=config.rp_id, + rp_name=config.rp_name, + user_id=str(user_id).encode(), + user_name=body.username, + user_display_name=body.display_name or body.username, + authenticator_selection=AuthenticatorSelectionCriteria( + resident_key=ResidentKeyRequirement.PREFERRED, + user_verification=UserVerificationRequirement.PREFERRED, + ), + ) + + # Store challenge in Redis (5 min TTL) + challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode() + redis_key = f"webauthn:register:{body.username}" + import json as _json + + await redis.setex( + redis_key, + 300, + _json.dumps({ + "challenge": challenge_b64, + "user_id": str(user_id), + "display_name": body.display_name or body.username, + }), + ) + + # Serialize options to a JSON-safe dict + from webauthn.helpers import options_to_json + + return {"options": _json.loads(options_to_json(options))} + + +@router.post("/register/complete") +async def register_complete( + request: Request, + config: ApiGatewayConfig = Depends(get_config), +) -> TokenResponse: + """Verify WebAuthn registration response and store credential. + + Expects the browser's ``navigator.credentials.create()`` response + in the request body. + """ + import json as _json + + redis = await _get_redis(request) + body = await request.json() + username = body.get("username", "") + + redis_key = f"webauthn:register:{username}" + stored_raw = await redis.get(redis_key) + if stored_raw is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Registration challenge expired or not found", + ) + stored = _json.loads(stored_raw) + expected_challenge = base64.urlsafe_b64decode(stored["challenge"]) + user_id_str = stored["user_id"] + display_name = stored["display_name"] + + try: + verification = verify_registration_response( + credential=body.get("credential", body), + expected_challenge=expected_challenge, + expected_rp_id=config.rp_id, + expected_origin=config.rp_origin, + ) + except Exception as exc: + logger.warning("Registration verification failed: %s", exc) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Registration verification failed: {exc}", + ) + + # Store user + credential in DB + from shared.models.auth import User, UserCredential + from sqlalchemy import select + + db_session = request.app.state.db_session_factory + async with db_session() as session: + user = User( + id=uuid.UUID(user_id_str), + username=username, + display_name=display_name, + ) + session.add(user) + + credential = UserCredential( + user_id=uuid.UUID(user_id_str), + credential_id=base64.urlsafe_b64encode( + verification.credential_id + ).decode(), + public_key=verification.credential_public_key, + sign_count=verification.sign_count, + ) + session.add(credential) + await session.commit() + + # Clean up challenge + await redis.delete(redis_key) + + # Issue JWT + access_token = create_access_token(user_id_str, username, config) + refresh_token = create_refresh_token(user_id_str, config) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + expires_in=config.access_token_expire_minutes * 60, + ) + + +# --------------------------------------------------------------------------- +# Login (authentication) +# --------------------------------------------------------------------------- + + +@router.post("/login/begin") +async def login_begin( + body: LoginRequest, + request: Request, + config: ApiGatewayConfig = Depends(get_config), +) -> dict[str, Any]: + """Generate WebAuthn authentication options for an existing user.""" + import json as _json + + redis = await _get_redis(request) + db_session = request.app.state.db_session_factory + + from sqlalchemy import select + from shared.models.auth import User, UserCredential + + async with db_session() as session: + user = ( + await session.execute( + select(User).where(User.username == body.username) + ) + ).scalar_one_or_none() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + + creds = ( + await session.execute( + select(UserCredential).where( + UserCredential.user_id == user.id + ) + ) + ).scalars().all() + + allow_credentials = [ + PublicKeyCredentialDescriptor( + id=base64.urlsafe_b64decode(c.credential_id), + ) + for c in creds + ] + + options = generate_authentication_options( + rp_id=config.rp_id, + allow_credentials=allow_credentials, + user_verification=UserVerificationRequirement.PREFERRED, + ) + + challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode() + redis_key = f"webauthn:login:{body.username}" + await redis.setex( + redis_key, + 300, + _json.dumps({ + "challenge": challenge_b64, + "user_id": str(user.id), + "username": user.username, + }), + ) + + from webauthn.helpers import options_to_json + + return {"options": _json.loads(options_to_json(options))} + + +@router.post("/login/complete") +async def login_complete( + request: Request, + config: ApiGatewayConfig = Depends(get_config), +) -> TokenResponse: + """Verify WebAuthn authentication response and issue JWT.""" + import json as _json + + redis = await _get_redis(request) + body = await request.json() + username = body.get("username", "") + + redis_key = f"webauthn:login:{username}" + stored_raw = await redis.get(redis_key) + if stored_raw is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Authentication challenge expired or not found", + ) + stored = _json.loads(stored_raw) + expected_challenge = base64.urlsafe_b64decode(stored["challenge"]) + user_id_str = stored["user_id"] + + # Look up the credential used + from sqlalchemy import select + from shared.models.auth import UserCredential + + credential_id_b64 = body.get("credential", body).get("id", "") + db_session = request.app.state.db_session_factory + + async with db_session() as session: + cred = ( + await session.execute( + select(UserCredential).where( + UserCredential.credential_id == credential_id_b64 + ) + ) + ).scalar_one_or_none() + + if cred is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Credential not found", + ) + + try: + verification = verify_authentication_response( + credential=body.get("credential", body), + expected_challenge=expected_challenge, + expected_rp_id=config.rp_id, + expected_origin=config.rp_origin, + credential_public_key=cred.public_key, + credential_current_sign_count=cred.sign_count, + ) + except Exception as exc: + logger.warning("Authentication verification failed: %s", exc) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Authentication verification failed: {exc}", + ) + + # Update sign count + cred.sign_count = verification.new_sign_count + await session.commit() + + # Clean up challenge + await redis.delete(redis_key) + + # Issue JWT + access_token = create_access_token(user_id_str, username, config) + refresh_token = create_refresh_token(user_id_str, config) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + expires_in=config.access_token_expire_minutes * 60, + ) + + +# --------------------------------------------------------------------------- +# Refresh +# --------------------------------------------------------------------------- + + +@router.post("/refresh") +async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse: + """Exchange a valid refresh token for a new access token.""" + body = await request.json() + refresh_token = body.get("refresh_token", "") + + try: + payload = decode_token(refresh_token, config) + except pyjwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token has expired", + ) + except pyjwt.InvalidTokenError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) + + if payload.get("type") != "refresh": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token type", + ) + + user_id = payload["sub"] + + # Look up username from DB + from sqlalchemy import select + from shared.models.auth import User + + db_session = request.app.state.db_session_factory + async with db_session() as session: + user = ( + await session.execute( + select(User).where(User.id == uuid.UUID(user_id)) + ) + ).scalar_one_or_none() + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + ) + + access_token = create_access_token(user_id, user.username, config) + new_refresh_token = create_refresh_token(user_id, config) + + return TokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + expires_in=config.access_token_expire_minutes * 60, + ) diff --git a/services/api_gateway/config.py b/services/api_gateway/config.py new file mode 100644 index 0000000..dcb7159 --- /dev/null +++ b/services/api_gateway/config.py @@ -0,0 +1,25 @@ +"""API Gateway configuration — extends shared BaseConfig with JWT, CORS, and WebAuthn settings.""" + +from shared.config import BaseConfig + + +class ApiGatewayConfig(BaseConfig): + """Configuration for the API Gateway service. + + All settings can be overridden via environment variables + prefixed with ``TRADING_``. + """ + + # JWT settings + jwt_secret_key: str = "CHANGE-ME-IN-PRODUCTION" + jwt_algorithm: str = "HS256" + access_token_expire_minutes: int = 15 + refresh_token_expire_days: int = 7 + + # CORS settings + cors_origins: list[str] = ["http://localhost:5173"] + + # WebAuthn (passkey) relying party settings + rp_id: str = "localhost" + rp_name: str = "Trading Bot" + rp_origin: str = "http://localhost:5173" diff --git a/services/api_gateway/main.py b/services/api_gateway/main.py new file mode 100644 index 0000000..6edf979 --- /dev/null +++ b/services/api_gateway/main.py @@ -0,0 +1,81 @@ +"""FastAPI application — API Gateway for the trading bot.""" + +from __future__ import annotations + +import logging +from contextlib import asynccontextmanager +from typing import AsyncIterator + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from redis.asyncio import Redis + +from services.api_gateway.auth.routes import router as auth_router +from services.api_gateway.config import ApiGatewayConfig +from shared.db import create_db + +logger = logging.getLogger(__name__) + + +def create_app(config: ApiGatewayConfig | None = None) -> FastAPI: + """Build and configure the FastAPI application. + + Parameters + ---------- + config: + Optional config override (useful for testing). If ``None``, a new + :class:`ApiGatewayConfig` is created from environment variables. + """ + if config is None: + config = ApiGatewayConfig() + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[None]: + """Start-up / shutdown hook — connect DB and Redis.""" + # Database + engine, session_factory = create_db(config) + app.state.db_engine = engine + app.state.db_session_factory = session_factory + + # Redis + app.state.redis = Redis.from_url( + config.redis_url, decode_responses=True + ) + app.state.config = config + + logger.info("API Gateway started") + yield + + # Cleanup + await app.state.redis.aclose() + await engine.dispose() + logger.info("API Gateway stopped") + + app = FastAPI( + title="Trading Bot API", + version="0.1.0", + lifespan=lifespan, + ) + + # CORS + app.add_middleware( + CORSMiddleware, + allow_origins=config.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Auth routes (unauthenticated) + app.include_router(auth_router) + + # Health check + @app.get("/health", tags=["health"]) + async def health() -> dict: + return {"status": "ok"} + + return app + + +# Convenience: allow ``uvicorn services.api_gateway.main:app`` +app = create_app() diff --git a/tests/services/test_api_auth.py b/tests/services/test_api_auth.py new file mode 100644 index 0000000..7a63275 --- /dev/null +++ b/tests/services/test_api_auth.py @@ -0,0 +1,221 @@ +"""Tests for API Gateway auth — JWT, middleware, and health endpoint.""" + +from __future__ import annotations + +import time +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import jwt as pyjwt +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from services.api_gateway.auth.jwt import ( + create_access_token, + create_refresh_token, + decode_token, +) +from services.api_gateway.auth.middleware import get_config, get_current_user +from services.api_gateway.config import ApiGatewayConfig + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def config() -> ApiGatewayConfig: + return ApiGatewayConfig( + jwt_secret_key="test-secret-key-for-unit-tests", + jwt_algorithm="HS256", + access_token_expire_minutes=15, + refresh_token_expire_days=7, + database_url="sqlite+aiosqlite:///:memory:", + redis_url="redis://localhost:6379/0", + ) + + +@pytest.fixture() +def app(config: ApiGatewayConfig) -> FastAPI: + """Create a minimal FastAPI app for testing the auth middleware.""" + from fastapi import Depends + + app = FastAPI() + app.dependency_overrides[get_config] = lambda: config + + @app.get("/protected") + async def protected(user: dict = Depends(get_current_user)): + return {"user_id": user["sub"], "username": user["username"]} + + @app.get("/health") + async def health(): + return {"status": "ok"} + + return app + + +@pytest.fixture() +def client(app: FastAPI) -> TestClient: + return TestClient(app) + + +# --------------------------------------------------------------------------- +# JWT Tests +# --------------------------------------------------------------------------- + + +class TestJWTCreateAndDecode: + """test_jwt_create_and_decode — round-trip create + decode.""" + + def test_access_token_round_trip(self, config: ApiGatewayConfig) -> None: + token = create_access_token("user-123", "alice", config) + payload = decode_token(token, config) + + assert payload["sub"] == "user-123" + assert payload["username"] == "alice" + assert payload["type"] == "access" + assert "exp" in payload + assert "iat" in payload + + def test_refresh_token_round_trip(self, config: ApiGatewayConfig) -> None: + token = create_refresh_token("user-456", config) + payload = decode_token(token, config) + + assert payload["sub"] == "user-456" + assert payload["type"] == "refresh" + assert "exp" in payload + + def test_access_token_expiry_time(self, config: ApiGatewayConfig) -> None: + token = create_access_token("u1", "bob", config) + payload = decode_token(token, config) + exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + iat = datetime.fromtimestamp(payload["iat"], tz=timezone.utc) + delta = exp - iat + assert timedelta(minutes=14) < delta <= timedelta(minutes=16) + + +class TestJWTExpiredToken: + """test_jwt_expired_token_rejected.""" + + def test_expired_access_token_raises(self, config: ApiGatewayConfig) -> None: + # Manually create a token that already expired + now = datetime.now(timezone.utc) + payload = { + "sub": "user-expired", + "username": "expired", + "type": "access", + "iat": now - timedelta(hours=2), + "exp": now - timedelta(hours=1), + } + token = pyjwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm) + + with pytest.raises(pyjwt.ExpiredSignatureError): + decode_token(token, config) + + +class TestJWTInvalidToken: + """test_jwt_invalid_token_rejected.""" + + def test_wrong_secret_raises(self, config: ApiGatewayConfig) -> None: + now = datetime.now(timezone.utc) + payload = { + "sub": "user-bad", + "type": "access", + "exp": now + timedelta(hours=1), + } + token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256") + + with pytest.raises(pyjwt.InvalidSignatureError): + decode_token(token, config) + + def test_malformed_token_raises(self, config: ApiGatewayConfig) -> None: + with pytest.raises(pyjwt.DecodeError): + decode_token("not.a.real.token", config) + + def test_completely_garbage_raises(self, config: ApiGatewayConfig) -> None: + with pytest.raises(pyjwt.DecodeError): + decode_token("garbage", config) + + +# --------------------------------------------------------------------------- +# Auth Middleware Tests +# --------------------------------------------------------------------------- + + +class TestAuthMiddlewareValidToken: + """test_auth_middleware_valid_token.""" + + def test_protected_route_with_valid_token( + self, client: TestClient, config: ApiGatewayConfig + ) -> None: + token = create_access_token("user-42", "charlie", config) + resp = client.get( + "/protected", headers={"Authorization": f"Bearer {token}"} + ) + assert resp.status_code == 200 + data = resp.json() + assert data["user_id"] == "user-42" + assert data["username"] == "charlie" + + +class TestAuthMiddlewareMissingToken: + """test_auth_middleware_missing_token.""" + + def test_protected_route_no_header(self, client: TestClient) -> None: + resp = client.get("/protected") + assert resp.status_code == 401 + assert "Missing authorization header" in resp.json()["detail"] + + def test_protected_route_expired_token( + self, client: TestClient, config: ApiGatewayConfig + ) -> None: + now = datetime.now(timezone.utc) + payload = { + "sub": "user-old", + "username": "old", + "type": "access", + "iat": now - timedelta(hours=2), + "exp": now - timedelta(hours=1), + } + token = pyjwt.encode( + payload, config.jwt_secret_key, algorithm=config.jwt_algorithm + ) + resp = client.get( + "/protected", headers={"Authorization": f"Bearer {token}"} + ) + assert resp.status_code == 401 + assert "expired" in resp.json()["detail"].lower() + + def test_protected_route_invalid_token(self, client: TestClient) -> None: + resp = client.get( + "/protected", + headers={"Authorization": "Bearer garbage-token"}, + ) + assert resp.status_code == 401 + assert "Invalid token" in resp.json()["detail"] + + def test_refresh_token_rejected_as_access( + self, client: TestClient, config: ApiGatewayConfig + ) -> None: + token = create_refresh_token("user-99", config) + resp = client.get( + "/protected", headers={"Authorization": f"Bearer {token}"} + ) + assert resp.status_code == 401 + assert "Invalid token type" in resp.json()["detail"] + + +# --------------------------------------------------------------------------- +# Health Endpoint Test +# --------------------------------------------------------------------------- + + +class TestHealthEndpoint: + """test_health_endpoint.""" + + def test_health_returns_ok(self, client: TestClient) -> None: + resp = client.get("/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"}