feat: API gateway with passkey (WebAuthn) authentication

This commit is contained in:
Viktor Barzin 2026-02-22 15:53:48 +00:00
parent f218865872
commit e0d138c457
No known key found for this signature in database
GPG key ID: 0EB088298288D958
9 changed files with 907 additions and 2 deletions

View file

@ -15,12 +15,12 @@ dependencies = [
]
[project.optional-dependencies]
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "py-webauthn>=2.0", "pyjwt[crypto]>=2.8"]
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "webauthn>=2.0", "pyjwt[crypto]>=2.8"]
news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"]
sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
trading = ["alpaca-py>=0.21"]
backtester = ["numpy>=1.26", "pandas>=2.2"]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8"]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8", "httpx>=0.27"]
[build-system]
requires = ["setuptools>=70.0"]

View file

@ -0,0 +1 @@
"""API Gateway service — FastAPI application serving the trading bot dashboard."""

View file

@ -0,0 +1 @@
"""Auth sub-package for the API Gateway."""

View file

@ -0,0 +1,98 @@
"""JWT utilities — token creation and verification."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
import jwt
from services.api_gateway.config import ApiGatewayConfig
def create_access_token(
user_id: str,
username: str,
config: ApiGatewayConfig,
) -> str:
"""Create a short-lived JWT access token.
Parameters
----------
user_id:
UUID of the authenticated user (stored as ``sub`` claim).
username:
Username (stored as ``username`` claim).
config:
Gateway configuration with secret key and algorithm.
Returns
-------
str
Encoded JWT string.
"""
now = datetime.now(timezone.utc)
payload = {
"sub": user_id,
"username": username,
"type": "access",
"iat": now,
"exp": now + timedelta(minutes=config.access_token_expire_minutes),
}
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
def create_refresh_token(
user_id: str,
config: ApiGatewayConfig,
) -> str:
"""Create a longer-lived JWT refresh token.
Parameters
----------
user_id:
UUID of the authenticated user (stored as ``sub`` claim).
config:
Gateway configuration with secret key and algorithm.
Returns
-------
str
Encoded JWT string.
"""
now = datetime.now(timezone.utc)
payload = {
"sub": user_id,
"type": "refresh",
"iat": now,
"exp": now + timedelta(days=config.refresh_token_expire_days),
}
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
def decode_token(token: str, config: ApiGatewayConfig) -> dict:
"""Decode and verify a JWT token.
Parameters
----------
token:
The JWT string to decode.
config:
Gateway configuration with secret key and algorithm.
Returns
-------
dict
The decoded payload.
Raises
------
jwt.ExpiredSignatureError
If the token has expired.
jwt.InvalidTokenError
If the token is malformed or signature verification fails.
"""
return jwt.decode(
token,
config.jwt_secret_key,
algorithms=[config.jwt_algorithm],
)

View file

@ -0,0 +1,68 @@
"""Auth middleware — FastAPI dependency for JWT-based authentication."""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import jwt as pyjwt
from services.api_gateway.auth.jwt import decode_token
from services.api_gateway.config import ApiGatewayConfig
# Shared config instance — injected via FastAPI dependency override in tests.
_config: ApiGatewayConfig | None = None
security = HTTPBearer(auto_error=False)
def get_config() -> ApiGatewayConfig:
"""Return the singleton config. Overridden in tests."""
global _config
if _config is None:
_config = ApiGatewayConfig()
return _config
async def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(security),
config: ApiGatewayConfig = Depends(get_config),
) -> dict:
"""FastAPI dependency that extracts and validates a Bearer JWT.
Returns the decoded token payload (contains ``sub``, ``username``, etc.)
on success. Raises a 401 ``HTTPException`` for missing, expired, or
invalid tokens.
"""
if credentials is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
try:
payload = decode_token(token, config)
except pyjwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has expired",
headers={"WWW-Authenticate": "Bearer"},
)
except pyjwt.InvalidTokenError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
)
# Ensure it is an access token, not a refresh token
if payload.get("type") != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
headers={"WWW-Authenticate": "Bearer"},
)
return payload

View file

@ -0,0 +1,410 @@
"""Auth routes — WebAuthn passkey registration/login and JWT refresh."""
from __future__ import annotations
import base64
import logging
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
from webauthn import (
generate_authentication_options,
generate_registration_options,
verify_authentication_response,
verify_registration_response,
)
from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria,
PublicKeyCredentialDescriptor,
ResidentKeyRequirement,
UserVerificationRequirement,
)
import jwt as pyjwt
from services.api_gateway.auth.jwt import (
create_access_token,
create_refresh_token,
decode_token,
)
from services.api_gateway.auth.middleware import get_config
from services.api_gateway.config import ApiGatewayConfig
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
# ---------------------------------------------------------------------------
# Helpers — DB and Redis access
# ---------------------------------------------------------------------------
async def _get_redis(request: Request):
"""Retrieve the Redis client from application state."""
return request.app.state.redis
async def _get_db(request: Request):
"""Retrieve an async DB session from application state."""
session_factory = request.app.state.db_session_factory
async with session_factory() as session:
yield session
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
@router.post("/register/begin")
async def register_begin(
body: RegisterRequest,
request: Request,
config: ApiGatewayConfig = Depends(get_config),
) -> dict[str, Any]:
"""Generate WebAuthn registration options (challenge + relying party info).
Stores the challenge in Redis with a 5-minute TTL so we can verify it
in ``/register/complete``.
"""
redis = await _get_redis(request)
db_session = request.app.state.db_session_factory
# Check if username already exists
from sqlalchemy import select
from shared.models.auth import User
async with db_session() as session:
existing = (
await session.execute(
select(User).where(User.username == body.username)
)
).scalar_one_or_none()
if existing is not None:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Username already registered",
)
user_id = uuid.uuid4()
options = generate_registration_options(
rp_id=config.rp_id,
rp_name=config.rp_name,
user_id=str(user_id).encode(),
user_name=body.username,
user_display_name=body.display_name or body.username,
authenticator_selection=AuthenticatorSelectionCriteria(
resident_key=ResidentKeyRequirement.PREFERRED,
user_verification=UserVerificationRequirement.PREFERRED,
),
)
# Store challenge in Redis (5 min TTL)
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
redis_key = f"webauthn:register:{body.username}"
import json as _json
await redis.setex(
redis_key,
300,
_json.dumps({
"challenge": challenge_b64,
"user_id": str(user_id),
"display_name": body.display_name or body.username,
}),
)
# Serialize options to a JSON-safe dict
from webauthn.helpers import options_to_json
return {"options": _json.loads(options_to_json(options))}
@router.post("/register/complete")
async def register_complete(
request: Request,
config: ApiGatewayConfig = Depends(get_config),
) -> TokenResponse:
"""Verify WebAuthn registration response and store credential.
Expects the browser's ``navigator.credentials.create()`` response
in the request body.
"""
import json as _json
redis = await _get_redis(request)
body = await request.json()
username = body.get("username", "")
redis_key = f"webauthn:register:{username}"
stored_raw = await redis.get(redis_key)
if stored_raw is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration challenge expired or not found",
)
stored = _json.loads(stored_raw)
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
user_id_str = stored["user_id"]
display_name = stored["display_name"]
try:
verification = verify_registration_response(
credential=body.get("credential", body),
expected_challenge=expected_challenge,
expected_rp_id=config.rp_id,
expected_origin=config.rp_origin,
)
except Exception as exc:
logger.warning("Registration verification failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Registration verification failed: {exc}",
)
# Store user + credential in DB
from shared.models.auth import User, UserCredential
from sqlalchemy import select
db_session = request.app.state.db_session_factory
async with db_session() as session:
user = User(
id=uuid.UUID(user_id_str),
username=username,
display_name=display_name,
)
session.add(user)
credential = UserCredential(
user_id=uuid.UUID(user_id_str),
credential_id=base64.urlsafe_b64encode(
verification.credential_id
).decode(),
public_key=verification.credential_public_key,
sign_count=verification.sign_count,
)
session.add(credential)
await session.commit()
# Clean up challenge
await redis.delete(redis_key)
# Issue JWT
access_token = create_access_token(user_id_str, username, config)
refresh_token = create_refresh_token(user_id_str, config)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=config.access_token_expire_minutes * 60,
)
# ---------------------------------------------------------------------------
# Login (authentication)
# ---------------------------------------------------------------------------
@router.post("/login/begin")
async def login_begin(
body: LoginRequest,
request: Request,
config: ApiGatewayConfig = Depends(get_config),
) -> dict[str, Any]:
"""Generate WebAuthn authentication options for an existing user."""
import json as _json
redis = await _get_redis(request)
db_session = request.app.state.db_session_factory
from sqlalchemy import select
from shared.models.auth import User, UserCredential
async with db_session() as session:
user = (
await session.execute(
select(User).where(User.username == body.username)
)
).scalar_one_or_none()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
creds = (
await session.execute(
select(UserCredential).where(
UserCredential.user_id == user.id
)
)
).scalars().all()
allow_credentials = [
PublicKeyCredentialDescriptor(
id=base64.urlsafe_b64decode(c.credential_id),
)
for c in creds
]
options = generate_authentication_options(
rp_id=config.rp_id,
allow_credentials=allow_credentials,
user_verification=UserVerificationRequirement.PREFERRED,
)
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
redis_key = f"webauthn:login:{body.username}"
await redis.setex(
redis_key,
300,
_json.dumps({
"challenge": challenge_b64,
"user_id": str(user.id),
"username": user.username,
}),
)
from webauthn.helpers import options_to_json
return {"options": _json.loads(options_to_json(options))}
@router.post("/login/complete")
async def login_complete(
request: Request,
config: ApiGatewayConfig = Depends(get_config),
) -> TokenResponse:
"""Verify WebAuthn authentication response and issue JWT."""
import json as _json
redis = await _get_redis(request)
body = await request.json()
username = body.get("username", "")
redis_key = f"webauthn:login:{username}"
stored_raw = await redis.get(redis_key)
if stored_raw is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Authentication challenge expired or not found",
)
stored = _json.loads(stored_raw)
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
user_id_str = stored["user_id"]
# Look up the credential used
from sqlalchemy import select
from shared.models.auth import UserCredential
credential_id_b64 = body.get("credential", body).get("id", "")
db_session = request.app.state.db_session_factory
async with db_session() as session:
cred = (
await session.execute(
select(UserCredential).where(
UserCredential.credential_id == credential_id_b64
)
)
).scalar_one_or_none()
if cred is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential not found",
)
try:
verification = verify_authentication_response(
credential=body.get("credential", body),
expected_challenge=expected_challenge,
expected_rp_id=config.rp_id,
expected_origin=config.rp_origin,
credential_public_key=cred.public_key,
credential_current_sign_count=cred.sign_count,
)
except Exception as exc:
logger.warning("Authentication verification failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Authentication verification failed: {exc}",
)
# Update sign count
cred.sign_count = verification.new_sign_count
await session.commit()
# Clean up challenge
await redis.delete(redis_key)
# Issue JWT
access_token = create_access_token(user_id_str, username, config)
refresh_token = create_refresh_token(user_id_str, config)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=config.access_token_expire_minutes * 60,
)
# ---------------------------------------------------------------------------
# Refresh
# ---------------------------------------------------------------------------
@router.post("/refresh")
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse:
"""Exchange a valid refresh token for a new access token."""
body = await request.json()
refresh_token = body.get("refresh_token", "")
try:
payload = decode_token(refresh_token, config)
except pyjwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has expired",
)
except pyjwt.InvalidTokenError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
if payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
)
user_id = payload["sub"]
# Look up username from DB
from sqlalchemy import select
from shared.models.auth import User
db_session = request.app.state.db_session_factory
async with db_session() as session:
user = (
await session.execute(
select(User).where(User.id == uuid.UUID(user_id))
)
).scalar_one_or_none()
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
access_token = create_access_token(user_id, user.username, config)
new_refresh_token = create_refresh_token(user_id, config)
return TokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
expires_in=config.access_token_expire_minutes * 60,
)

View file

@ -0,0 +1,25 @@
"""API Gateway configuration — extends shared BaseConfig with JWT, CORS, and WebAuthn settings."""
from shared.config import BaseConfig
class ApiGatewayConfig(BaseConfig):
"""Configuration for the API Gateway service.
All settings can be overridden via environment variables
prefixed with ``TRADING_``.
"""
# JWT settings
jwt_secret_key: str = "CHANGE-ME-IN-PRODUCTION"
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 15
refresh_token_expire_days: int = 7
# CORS settings
cors_origins: list[str] = ["http://localhost:5173"]
# WebAuthn (passkey) relying party settings
rp_id: str = "localhost"
rp_name: str = "Trading Bot"
rp_origin: str = "http://localhost:5173"

View file

@ -0,0 +1,81 @@
"""FastAPI application — API Gateway for the trading bot."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from redis.asyncio import Redis
from services.api_gateway.auth.routes import router as auth_router
from services.api_gateway.config import ApiGatewayConfig
from shared.db import create_db
logger = logging.getLogger(__name__)
def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
"""Build and configure the FastAPI application.
Parameters
----------
config:
Optional config override (useful for testing). If ``None``, a new
:class:`ApiGatewayConfig` is created from environment variables.
"""
if config is None:
config = ApiGatewayConfig()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Start-up / shutdown hook — connect DB and Redis."""
# Database
engine, session_factory = create_db(config)
app.state.db_engine = engine
app.state.db_session_factory = session_factory
# Redis
app.state.redis = Redis.from_url(
config.redis_url, decode_responses=True
)
app.state.config = config
logger.info("API Gateway started")
yield
# Cleanup
await app.state.redis.aclose()
await engine.dispose()
logger.info("API Gateway stopped")
app = FastAPI(
title="Trading Bot API",
version="0.1.0",
lifespan=lifespan,
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=config.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Auth routes (unauthenticated)
app.include_router(auth_router)
# Health check
@app.get("/health", tags=["health"])
async def health() -> dict:
return {"status": "ok"}
return app
# Convenience: allow ``uvicorn services.api_gateway.main:app``
app = create_app()

View file

@ -0,0 +1,221 @@
"""Tests for API Gateway auth — JWT, middleware, and health endpoint."""
from __future__ import annotations
import time
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from services.api_gateway.auth.jwt import (
create_access_token,
create_refresh_token,
decode_token,
)
from services.api_gateway.auth.middleware import get_config, get_current_user
from services.api_gateway.config import ApiGatewayConfig
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def config() -> ApiGatewayConfig:
return ApiGatewayConfig(
jwt_secret_key="test-secret-key-for-unit-tests",
jwt_algorithm="HS256",
access_token_expire_minutes=15,
refresh_token_expire_days=7,
database_url="sqlite+aiosqlite:///:memory:",
redis_url="redis://localhost:6379/0",
)
@pytest.fixture()
def app(config: ApiGatewayConfig) -> FastAPI:
"""Create a minimal FastAPI app for testing the auth middleware."""
from fastapi import Depends
app = FastAPI()
app.dependency_overrides[get_config] = lambda: config
@app.get("/protected")
async def protected(user: dict = Depends(get_current_user)):
return {"user_id": user["sub"], "username": user["username"]}
@app.get("/health")
async def health():
return {"status": "ok"}
return app
@pytest.fixture()
def client(app: FastAPI) -> TestClient:
return TestClient(app)
# ---------------------------------------------------------------------------
# JWT Tests
# ---------------------------------------------------------------------------
class TestJWTCreateAndDecode:
"""test_jwt_create_and_decode — round-trip create + decode."""
def test_access_token_round_trip(self, config: ApiGatewayConfig) -> None:
token = create_access_token("user-123", "alice", config)
payload = decode_token(token, config)
assert payload["sub"] == "user-123"
assert payload["username"] == "alice"
assert payload["type"] == "access"
assert "exp" in payload
assert "iat" in payload
def test_refresh_token_round_trip(self, config: ApiGatewayConfig) -> None:
token = create_refresh_token("user-456", config)
payload = decode_token(token, config)
assert payload["sub"] == "user-456"
assert payload["type"] == "refresh"
assert "exp" in payload
def test_access_token_expiry_time(self, config: ApiGatewayConfig) -> None:
token = create_access_token("u1", "bob", config)
payload = decode_token(token, config)
exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
iat = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
delta = exp - iat
assert timedelta(minutes=14) < delta <= timedelta(minutes=16)
class TestJWTExpiredToken:
"""test_jwt_expired_token_rejected."""
def test_expired_access_token_raises(self, config: ApiGatewayConfig) -> None:
# Manually create a token that already expired
now = datetime.now(timezone.utc)
payload = {
"sub": "user-expired",
"username": "expired",
"type": "access",
"iat": now - timedelta(hours=2),
"exp": now - timedelta(hours=1),
}
token = pyjwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
with pytest.raises(pyjwt.ExpiredSignatureError):
decode_token(token, config)
class TestJWTInvalidToken:
"""test_jwt_invalid_token_rejected."""
def test_wrong_secret_raises(self, config: ApiGatewayConfig) -> None:
now = datetime.now(timezone.utc)
payload = {
"sub": "user-bad",
"type": "access",
"exp": now + timedelta(hours=1),
}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
with pytest.raises(pyjwt.InvalidSignatureError):
decode_token(token, config)
def test_malformed_token_raises(self, config: ApiGatewayConfig) -> None:
with pytest.raises(pyjwt.DecodeError):
decode_token("not.a.real.token", config)
def test_completely_garbage_raises(self, config: ApiGatewayConfig) -> None:
with pytest.raises(pyjwt.DecodeError):
decode_token("garbage", config)
# ---------------------------------------------------------------------------
# Auth Middleware Tests
# ---------------------------------------------------------------------------
class TestAuthMiddlewareValidToken:
"""test_auth_middleware_valid_token."""
def test_protected_route_with_valid_token(
self, client: TestClient, config: ApiGatewayConfig
) -> None:
token = create_access_token("user-42", "charlie", config)
resp = client.get(
"/protected", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 200
data = resp.json()
assert data["user_id"] == "user-42"
assert data["username"] == "charlie"
class TestAuthMiddlewareMissingToken:
"""test_auth_middleware_missing_token."""
def test_protected_route_no_header(self, client: TestClient) -> None:
resp = client.get("/protected")
assert resp.status_code == 401
assert "Missing authorization header" in resp.json()["detail"]
def test_protected_route_expired_token(
self, client: TestClient, config: ApiGatewayConfig
) -> None:
now = datetime.now(timezone.utc)
payload = {
"sub": "user-old",
"username": "old",
"type": "access",
"iat": now - timedelta(hours=2),
"exp": now - timedelta(hours=1),
}
token = pyjwt.encode(
payload, config.jwt_secret_key, algorithm=config.jwt_algorithm
)
resp = client.get(
"/protected", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 401
assert "expired" in resp.json()["detail"].lower()
def test_protected_route_invalid_token(self, client: TestClient) -> None:
resp = client.get(
"/protected",
headers={"Authorization": "Bearer garbage-token"},
)
assert resp.status_code == 401
assert "Invalid token" in resp.json()["detail"]
def test_refresh_token_rejected_as_access(
self, client: TestClient, config: ApiGatewayConfig
) -> None:
token = create_refresh_token("user-99", config)
resp = client.get(
"/protected", headers={"Authorization": f"Bearer {token}"}
)
assert resp.status_code == 401
assert "Invalid token type" in resp.json()["detail"]
# ---------------------------------------------------------------------------
# Health Endpoint Test
# ---------------------------------------------------------------------------
class TestHealthEndpoint:
"""test_health_endpoint."""
def test_health_returns_ok(self, client: TestClient) -> None:
resp = client.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}