feat: API gateway with passkey (WebAuthn) authentication
This commit is contained in:
parent
f218865872
commit
e0d138c457
9 changed files with 907 additions and 2 deletions
|
|
@ -15,12 +15,12 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "py-webauthn>=2.0", "pyjwt[crypto]>=2.8"]
|
api = ["fastapi>=0.110", "uvicorn[standard]>=0.27", "websockets>=12.0", "webauthn>=2.0", "pyjwt[crypto]>=2.8"]
|
||||||
news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"]
|
news = ["feedparser>=6.0", "praw>=7.7", "asyncpraw>=7.7", "httpx>=0.27"]
|
||||||
sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
|
sentiment = ["transformers>=4.38", "torch>=2.2", "ollama>=0.1"]
|
||||||
trading = ["alpaca-py>=0.21"]
|
trading = ["alpaca-py>=0.21"]
|
||||||
backtester = ["numpy>=1.26", "pandas>=2.2"]
|
backtester = ["numpy>=1.26", "pandas>=2.2"]
|
||||||
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8"]
|
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8", "httpx>=0.27"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=70.0"]
|
requires = ["setuptools>=70.0"]
|
||||||
|
|
|
||||||
1
services/api_gateway/__init__.py
Normal file
1
services/api_gateway/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""API Gateway service — FastAPI application serving the trading bot dashboard."""
|
||||||
1
services/api_gateway/auth/__init__.py
Normal file
1
services/api_gateway/auth/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Auth sub-package for the API Gateway."""
|
||||||
98
services/api_gateway/auth/jwt.py
Normal file
98
services/api_gateway/auth/jwt.py
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
"""JWT utilities — token creation and verification."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(
|
||||||
|
user_id: str,
|
||||||
|
username: str,
|
||||||
|
config: ApiGatewayConfig,
|
||||||
|
) -> str:
|
||||||
|
"""Create a short-lived JWT access token.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
user_id:
|
||||||
|
UUID of the authenticated user (stored as ``sub`` claim).
|
||||||
|
username:
|
||||||
|
Username (stored as ``username`` claim).
|
||||||
|
config:
|
||||||
|
Gateway configuration with secret key and algorithm.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Encoded JWT string.
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"username": username,
|
||||||
|
"type": "access",
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(minutes=config.access_token_expire_minutes),
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(
|
||||||
|
user_id: str,
|
||||||
|
config: ApiGatewayConfig,
|
||||||
|
) -> str:
|
||||||
|
"""Create a longer-lived JWT refresh token.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
user_id:
|
||||||
|
UUID of the authenticated user (stored as ``sub`` claim).
|
||||||
|
config:
|
||||||
|
Gateway configuration with secret key and algorithm.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Encoded JWT string.
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"type": "refresh",
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(days=config.refresh_token_expire_days),
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str, config: ApiGatewayConfig) -> dict:
|
||||||
|
"""Decode and verify a JWT token.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
token:
|
||||||
|
The JWT string to decode.
|
||||||
|
config:
|
||||||
|
Gateway configuration with secret key and algorithm.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
The decoded payload.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
jwt.ExpiredSignatureError
|
||||||
|
If the token has expired.
|
||||||
|
jwt.InvalidTokenError
|
||||||
|
If the token is malformed or signature verification fails.
|
||||||
|
"""
|
||||||
|
return jwt.decode(
|
||||||
|
token,
|
||||||
|
config.jwt_secret_key,
|
||||||
|
algorithms=[config.jwt_algorithm],
|
||||||
|
)
|
||||||
68
services/api_gateway/auth/middleware.py
Normal file
68
services/api_gateway/auth/middleware.py
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
"""Auth middleware — FastAPI dependency for JWT-based authentication."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
from services.api_gateway.auth.jwt import decode_token
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
|
||||||
|
# Shared config instance — injected via FastAPI dependency override in tests.
|
||||||
|
_config: ApiGatewayConfig | None = None
|
||||||
|
|
||||||
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_config() -> ApiGatewayConfig:
|
||||||
|
"""Return the singleton config. Overridden in tests."""
|
||||||
|
global _config
|
||||||
|
if _config is None:
|
||||||
|
_config = ApiGatewayConfig()
|
||||||
|
return _config
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||||
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
) -> dict:
|
||||||
|
"""FastAPI dependency that extracts and validates a Bearer JWT.
|
||||||
|
|
||||||
|
Returns the decoded token payload (contains ``sub``, ``username``, etc.)
|
||||||
|
on success. Raises a 401 ``HTTPException`` for missing, expired, or
|
||||||
|
invalid tokens.
|
||||||
|
"""
|
||||||
|
if credentials is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing authorization header",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
token = credentials.credentials
|
||||||
|
try:
|
||||||
|
payload = decode_token(token, config)
|
||||||
|
except pyjwt.ExpiredSignatureError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token has expired",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except pyjwt.InvalidTokenError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure it is an access token, not a refresh token
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token type",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return payload
|
||||||
410
services/api_gateway/auth/routes.py
Normal file
410
services/api_gateway/auth/routes.py
Normal file
|
|
@ -0,0 +1,410 @@
|
||||||
|
"""Auth routes — WebAuthn passkey registration/login and JWT refresh."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from webauthn import (
|
||||||
|
generate_authentication_options,
|
||||||
|
generate_registration_options,
|
||||||
|
verify_authentication_response,
|
||||||
|
verify_registration_response,
|
||||||
|
)
|
||||||
|
from webauthn.helpers.structs import (
|
||||||
|
AuthenticatorSelectionCriteria,
|
||||||
|
PublicKeyCredentialDescriptor,
|
||||||
|
ResidentKeyRequirement,
|
||||||
|
UserVerificationRequirement,
|
||||||
|
)
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
from services.api_gateway.auth.jwt import (
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_token,
|
||||||
|
)
|
||||||
|
from services.api_gateway.auth.middleware import get_config
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers — DB and Redis access
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_redis(request: Request):
|
||||||
|
"""Retrieve the Redis client from application state."""
|
||||||
|
return request.app.state.redis
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_db(request: Request):
|
||||||
|
"""Retrieve an async DB session from application state."""
|
||||||
|
session_factory = request.app.state.db_session_factory
|
||||||
|
async with session_factory() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register/begin")
|
||||||
|
async def register_begin(
|
||||||
|
body: RegisterRequest,
|
||||||
|
request: Request,
|
||||||
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Generate WebAuthn registration options (challenge + relying party info).
|
||||||
|
|
||||||
|
Stores the challenge in Redis with a 5-minute TTL so we can verify it
|
||||||
|
in ``/register/complete``.
|
||||||
|
"""
|
||||||
|
redis = await _get_redis(request)
|
||||||
|
db_session = request.app.state.db_session_factory
|
||||||
|
|
||||||
|
# Check if username already exists
|
||||||
|
from sqlalchemy import select
|
||||||
|
from shared.models.auth import User
|
||||||
|
|
||||||
|
async with db_session() as session:
|
||||||
|
existing = (
|
||||||
|
await session.execute(
|
||||||
|
select(User).where(User.username == body.username)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Username already registered",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
options = generate_registration_options(
|
||||||
|
rp_id=config.rp_id,
|
||||||
|
rp_name=config.rp_name,
|
||||||
|
user_id=str(user_id).encode(),
|
||||||
|
user_name=body.username,
|
||||||
|
user_display_name=body.display_name or body.username,
|
||||||
|
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||||
|
resident_key=ResidentKeyRequirement.PREFERRED,
|
||||||
|
user_verification=UserVerificationRequirement.PREFERRED,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store challenge in Redis (5 min TTL)
|
||||||
|
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
||||||
|
redis_key = f"webauthn:register:{body.username}"
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
await redis.setex(
|
||||||
|
redis_key,
|
||||||
|
300,
|
||||||
|
_json.dumps({
|
||||||
|
"challenge": challenge_b64,
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"display_name": body.display_name or body.username,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Serialize options to a JSON-safe dict
|
||||||
|
from webauthn.helpers import options_to_json
|
||||||
|
|
||||||
|
return {"options": _json.loads(options_to_json(options))}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register/complete")
|
||||||
|
async def register_complete(
|
||||||
|
request: Request,
|
||||||
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
) -> TokenResponse:
|
||||||
|
"""Verify WebAuthn registration response and store credential.
|
||||||
|
|
||||||
|
Expects the browser's ``navigator.credentials.create()`` response
|
||||||
|
in the request body.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
redis = await _get_redis(request)
|
||||||
|
body = await request.json()
|
||||||
|
username = body.get("username", "")
|
||||||
|
|
||||||
|
redis_key = f"webauthn:register:{username}"
|
||||||
|
stored_raw = await redis.get(redis_key)
|
||||||
|
if stored_raw is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Registration challenge expired or not found",
|
||||||
|
)
|
||||||
|
stored = _json.loads(stored_raw)
|
||||||
|
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
|
||||||
|
user_id_str = stored["user_id"]
|
||||||
|
display_name = stored["display_name"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
verification = verify_registration_response(
|
||||||
|
credential=body.get("credential", body),
|
||||||
|
expected_challenge=expected_challenge,
|
||||||
|
expected_rp_id=config.rp_id,
|
||||||
|
expected_origin=config.rp_origin,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Registration verification failed: %s", exc)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Registration verification failed: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store user + credential in DB
|
||||||
|
from shared.models.auth import User, UserCredential
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
db_session = request.app.state.db_session_factory
|
||||||
|
async with db_session() as session:
|
||||||
|
user = User(
|
||||||
|
id=uuid.UUID(user_id_str),
|
||||||
|
username=username,
|
||||||
|
display_name=display_name,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
|
||||||
|
credential = UserCredential(
|
||||||
|
user_id=uuid.UUID(user_id_str),
|
||||||
|
credential_id=base64.urlsafe_b64encode(
|
||||||
|
verification.credential_id
|
||||||
|
).decode(),
|
||||||
|
public_key=verification.credential_public_key,
|
||||||
|
sign_count=verification.sign_count,
|
||||||
|
)
|
||||||
|
session.add(credential)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Clean up challenge
|
||||||
|
await redis.delete(redis_key)
|
||||||
|
|
||||||
|
# Issue JWT
|
||||||
|
access_token = create_access_token(user_id_str, username, config)
|
||||||
|
refresh_token = create_refresh_token(user_id_str, config)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
expires_in=config.access_token_expire_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Login (authentication)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login/begin")
|
||||||
|
async def login_begin(
|
||||||
|
body: LoginRequest,
|
||||||
|
request: Request,
|
||||||
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Generate WebAuthn authentication options for an existing user."""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
redis = await _get_redis(request)
|
||||||
|
db_session = request.app.state.db_session_factory
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from shared.models.auth import User, UserCredential
|
||||||
|
|
||||||
|
async with db_session() as session:
|
||||||
|
user = (
|
||||||
|
await session.execute(
|
||||||
|
select(User).where(User.username == body.username)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
creds = (
|
||||||
|
await session.execute(
|
||||||
|
select(UserCredential).where(
|
||||||
|
UserCredential.user_id == user.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
allow_credentials = [
|
||||||
|
PublicKeyCredentialDescriptor(
|
||||||
|
id=base64.urlsafe_b64decode(c.credential_id),
|
||||||
|
)
|
||||||
|
for c in creds
|
||||||
|
]
|
||||||
|
|
||||||
|
options = generate_authentication_options(
|
||||||
|
rp_id=config.rp_id,
|
||||||
|
allow_credentials=allow_credentials,
|
||||||
|
user_verification=UserVerificationRequirement.PREFERRED,
|
||||||
|
)
|
||||||
|
|
||||||
|
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
||||||
|
redis_key = f"webauthn:login:{body.username}"
|
||||||
|
await redis.setex(
|
||||||
|
redis_key,
|
||||||
|
300,
|
||||||
|
_json.dumps({
|
||||||
|
"challenge": challenge_b64,
|
||||||
|
"user_id": str(user.id),
|
||||||
|
"username": user.username,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
from webauthn.helpers import options_to_json
|
||||||
|
|
||||||
|
return {"options": _json.loads(options_to_json(options))}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login/complete")
|
||||||
|
async def login_complete(
|
||||||
|
request: Request,
|
||||||
|
config: ApiGatewayConfig = Depends(get_config),
|
||||||
|
) -> TokenResponse:
|
||||||
|
"""Verify WebAuthn authentication response and issue JWT."""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
redis = await _get_redis(request)
|
||||||
|
body = await request.json()
|
||||||
|
username = body.get("username", "")
|
||||||
|
|
||||||
|
redis_key = f"webauthn:login:{username}"
|
||||||
|
stored_raw = await redis.get(redis_key)
|
||||||
|
if stored_raw is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Authentication challenge expired or not found",
|
||||||
|
)
|
||||||
|
stored = _json.loads(stored_raw)
|
||||||
|
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
|
||||||
|
user_id_str = stored["user_id"]
|
||||||
|
|
||||||
|
# Look up the credential used
|
||||||
|
from sqlalchemy import select
|
||||||
|
from shared.models.auth import UserCredential
|
||||||
|
|
||||||
|
credential_id_b64 = body.get("credential", body).get("id", "")
|
||||||
|
db_session = request.app.state.db_session_factory
|
||||||
|
|
||||||
|
async with db_session() as session:
|
||||||
|
cred = (
|
||||||
|
await session.execute(
|
||||||
|
select(UserCredential).where(
|
||||||
|
UserCredential.credential_id == credential_id_b64
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if cred is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Credential not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
verification = verify_authentication_response(
|
||||||
|
credential=body.get("credential", body),
|
||||||
|
expected_challenge=expected_challenge,
|
||||||
|
expected_rp_id=config.rp_id,
|
||||||
|
expected_origin=config.rp_origin,
|
||||||
|
credential_public_key=cred.public_key,
|
||||||
|
credential_current_sign_count=cred.sign_count,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Authentication verification failed: %s", exc)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Authentication verification failed: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update sign count
|
||||||
|
cred.sign_count = verification.new_sign_count
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Clean up challenge
|
||||||
|
await redis.delete(redis_key)
|
||||||
|
|
||||||
|
# Issue JWT
|
||||||
|
access_token = create_access_token(user_id_str, username, config)
|
||||||
|
refresh_token = create_refresh_token(user_id_str, config)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
expires_in=config.access_token_expire_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Refresh
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh")
|
||||||
|
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse:
|
||||||
|
"""Exchange a valid refresh token for a new access token."""
|
||||||
|
body = await request.json()
|
||||||
|
refresh_token = body.get("refresh_token", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = decode_token(refresh_token, config)
|
||||||
|
except pyjwt.ExpiredSignatureError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Refresh token has expired",
|
||||||
|
)
|
||||||
|
except pyjwt.InvalidTokenError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
if payload.get("type") != "refresh":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token type",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = payload["sub"]
|
||||||
|
|
||||||
|
# Look up username from DB
|
||||||
|
from sqlalchemy import select
|
||||||
|
from shared.models.auth import User
|
||||||
|
|
||||||
|
db_session = request.app.state.db_session_factory
|
||||||
|
async with db_session() as session:
|
||||||
|
user = (
|
||||||
|
await session.execute(
|
||||||
|
select(User).where(User.id == uuid.UUID(user_id))
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token = create_access_token(user_id, user.username, config)
|
||||||
|
new_refresh_token = create_refresh_token(user_id, config)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=new_refresh_token,
|
||||||
|
expires_in=config.access_token_expire_minutes * 60,
|
||||||
|
)
|
||||||
25
services/api_gateway/config.py
Normal file
25
services/api_gateway/config.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""API Gateway configuration — extends shared BaseConfig with JWT, CORS, and WebAuthn settings."""
|
||||||
|
|
||||||
|
from shared.config import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ApiGatewayConfig(BaseConfig):
|
||||||
|
"""Configuration for the API Gateway service.
|
||||||
|
|
||||||
|
All settings can be overridden via environment variables
|
||||||
|
prefixed with ``TRADING_``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# JWT settings
|
||||||
|
jwt_secret_key: str = "CHANGE-ME-IN-PRODUCTION"
|
||||||
|
jwt_algorithm: str = "HS256"
|
||||||
|
access_token_expire_minutes: int = 15
|
||||||
|
refresh_token_expire_days: int = 7
|
||||||
|
|
||||||
|
# CORS settings
|
||||||
|
cors_origins: list[str] = ["http://localhost:5173"]
|
||||||
|
|
||||||
|
# WebAuthn (passkey) relying party settings
|
||||||
|
rp_id: str = "localhost"
|
||||||
|
rp_name: str = "Trading Bot"
|
||||||
|
rp_origin: str = "http://localhost:5173"
|
||||||
81
services/api_gateway/main.py
Normal file
81
services/api_gateway/main.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
"""FastAPI application — API Gateway for the trading bot."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from services.api_gateway.auth.routes import router as auth_router
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
from shared.db import create_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(config: ApiGatewayConfig | None = None) -> FastAPI:
|
||||||
|
"""Build and configure the FastAPI application.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config:
|
||||||
|
Optional config override (useful for testing). If ``None``, a new
|
||||||
|
:class:`ApiGatewayConfig` is created from environment variables.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = ApiGatewayConfig()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||||
|
"""Start-up / shutdown hook — connect DB and Redis."""
|
||||||
|
# Database
|
||||||
|
engine, session_factory = create_db(config)
|
||||||
|
app.state.db_engine = engine
|
||||||
|
app.state.db_session_factory = session_factory
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
app.state.redis = Redis.from_url(
|
||||||
|
config.redis_url, decode_responses=True
|
||||||
|
)
|
||||||
|
app.state.config = config
|
||||||
|
|
||||||
|
logger.info("API Gateway started")
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await app.state.redis.aclose()
|
||||||
|
await engine.dispose()
|
||||||
|
logger.info("API Gateway stopped")
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Trading Bot API",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=config.cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auth routes (unauthenticated)
|
||||||
|
app.include_router(auth_router)
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
@app.get("/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience: allow ``uvicorn services.api_gateway.main:app``
|
||||||
|
app = create_app()
|
||||||
221
tests/services/test_api_auth.py
Normal file
221
tests/services/test_api_auth.py
Normal file
|
|
@ -0,0 +1,221 @@
|
||||||
|
"""Tests for API Gateway auth — JWT, middleware, and health endpoint."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from services.api_gateway.auth.jwt import (
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_token,
|
||||||
|
)
|
||||||
|
from services.api_gateway.auth.middleware import get_config, get_current_user
|
||||||
|
from services.api_gateway.config import ApiGatewayConfig
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def config() -> ApiGatewayConfig:
|
||||||
|
return ApiGatewayConfig(
|
||||||
|
jwt_secret_key="test-secret-key-for-unit-tests",
|
||||||
|
jwt_algorithm="HS256",
|
||||||
|
access_token_expire_minutes=15,
|
||||||
|
refresh_token_expire_days=7,
|
||||||
|
database_url="sqlite+aiosqlite:///:memory:",
|
||||||
|
redis_url="redis://localhost:6379/0",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def app(config: ApiGatewayConfig) -> FastAPI:
|
||||||
|
"""Create a minimal FastAPI app for testing the auth middleware."""
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.dependency_overrides[get_config] = lambda: config
|
||||||
|
|
||||||
|
@app.get("/protected")
|
||||||
|
async def protected(user: dict = Depends(get_current_user)):
|
||||||
|
return {"user_id": user["sub"], "username": user["username"]}
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def client(app: FastAPI) -> TestClient:
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JWT Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTCreateAndDecode:
|
||||||
|
"""test_jwt_create_and_decode — round-trip create + decode."""
|
||||||
|
|
||||||
|
def test_access_token_round_trip(self, config: ApiGatewayConfig) -> None:
|
||||||
|
token = create_access_token("user-123", "alice", config)
|
||||||
|
payload = decode_token(token, config)
|
||||||
|
|
||||||
|
assert payload["sub"] == "user-123"
|
||||||
|
assert payload["username"] == "alice"
|
||||||
|
assert payload["type"] == "access"
|
||||||
|
assert "exp" in payload
|
||||||
|
assert "iat" in payload
|
||||||
|
|
||||||
|
def test_refresh_token_round_trip(self, config: ApiGatewayConfig) -> None:
|
||||||
|
token = create_refresh_token("user-456", config)
|
||||||
|
payload = decode_token(token, config)
|
||||||
|
|
||||||
|
assert payload["sub"] == "user-456"
|
||||||
|
assert payload["type"] == "refresh"
|
||||||
|
assert "exp" in payload
|
||||||
|
|
||||||
|
def test_access_token_expiry_time(self, config: ApiGatewayConfig) -> None:
|
||||||
|
token = create_access_token("u1", "bob", config)
|
||||||
|
payload = decode_token(token, config)
|
||||||
|
exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||||
|
iat = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
|
||||||
|
delta = exp - iat
|
||||||
|
assert timedelta(minutes=14) < delta <= timedelta(minutes=16)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTExpiredToken:
|
||||||
|
"""test_jwt_expired_token_rejected."""
|
||||||
|
|
||||||
|
def test_expired_access_token_raises(self, config: ApiGatewayConfig) -> None:
|
||||||
|
# Manually create a token that already expired
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": "user-expired",
|
||||||
|
"username": "expired",
|
||||||
|
"type": "access",
|
||||||
|
"iat": now - timedelta(hours=2),
|
||||||
|
"exp": now - timedelta(hours=1),
|
||||||
|
}
|
||||||
|
token = pyjwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
|
||||||
|
|
||||||
|
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||||
|
decode_token(token, config)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTInvalidToken:
|
||||||
|
"""test_jwt_invalid_token_rejected."""
|
||||||
|
|
||||||
|
def test_wrong_secret_raises(self, config: ApiGatewayConfig) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": "user-bad",
|
||||||
|
"type": "access",
|
||||||
|
"exp": now + timedelta(hours=1),
|
||||||
|
}
|
||||||
|
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||||
|
|
||||||
|
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||||
|
decode_token(token, config)
|
||||||
|
|
||||||
|
def test_malformed_token_raises(self, config: ApiGatewayConfig) -> None:
|
||||||
|
with pytest.raises(pyjwt.DecodeError):
|
||||||
|
decode_token("not.a.real.token", config)
|
||||||
|
|
||||||
|
def test_completely_garbage_raises(self, config: ApiGatewayConfig) -> None:
|
||||||
|
with pytest.raises(pyjwt.DecodeError):
|
||||||
|
decode_token("garbage", config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auth Middleware Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthMiddlewareValidToken:
|
||||||
|
"""test_auth_middleware_valid_token."""
|
||||||
|
|
||||||
|
def test_protected_route_with_valid_token(
|
||||||
|
self, client: TestClient, config: ApiGatewayConfig
|
||||||
|
) -> None:
|
||||||
|
token = create_access_token("user-42", "charlie", config)
|
||||||
|
resp = client.get(
|
||||||
|
"/protected", headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["user_id"] == "user-42"
|
||||||
|
assert data["username"] == "charlie"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthMiddlewareMissingToken:
|
||||||
|
"""test_auth_middleware_missing_token."""
|
||||||
|
|
||||||
|
def test_protected_route_no_header(self, client: TestClient) -> None:
|
||||||
|
resp = client.get("/protected")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
assert "Missing authorization header" in resp.json()["detail"]
|
||||||
|
|
||||||
|
def test_protected_route_expired_token(
|
||||||
|
self, client: TestClient, config: ApiGatewayConfig
|
||||||
|
) -> None:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": "user-old",
|
||||||
|
"username": "old",
|
||||||
|
"type": "access",
|
||||||
|
"iat": now - timedelta(hours=2),
|
||||||
|
"exp": now - timedelta(hours=1),
|
||||||
|
}
|
||||||
|
token = pyjwt.encode(
|
||||||
|
payload, config.jwt_secret_key, algorithm=config.jwt_algorithm
|
||||||
|
)
|
||||||
|
resp = client.get(
|
||||||
|
"/protected", headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
assert "expired" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
def test_protected_route_invalid_token(self, client: TestClient) -> None:
|
||||||
|
resp = client.get(
|
||||||
|
"/protected",
|
||||||
|
headers={"Authorization": "Bearer garbage-token"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
assert "Invalid token" in resp.json()["detail"]
|
||||||
|
|
||||||
|
def test_refresh_token_rejected_as_access(
|
||||||
|
self, client: TestClient, config: ApiGatewayConfig
|
||||||
|
) -> None:
|
||||||
|
token = create_refresh_token("user-99", config)
|
||||||
|
resp = client.get(
|
||||||
|
"/protected", headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
assert "Invalid token type" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Health Endpoint Test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthEndpoint:
|
||||||
|
"""test_health_endpoint."""
|
||||||
|
|
||||||
|
def test_health_returns_ok(self, client: TestClient) -> None:
|
||||||
|
resp = client.get("/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"status": "ok"}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue