trading/services/api_gateway/auth/routes.py

410 lines
13 KiB
Python

"""Auth routes — WebAuthn passkey registration/login and JWT refresh."""
from __future__ import annotations
import base64
import logging
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
from webauthn import (
generate_authentication_options,
generate_registration_options,
verify_authentication_response,
verify_registration_response,
)
from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria,
PublicKeyCredentialDescriptor,
ResidentKeyRequirement,
UserVerificationRequirement,
)
import jwt as pyjwt
from services.api_gateway.auth.jwt import (
create_access_token,
create_refresh_token,
decode_token,
)
from services.api_gateway.auth.middleware import get_config
from services.api_gateway.config import ApiGatewayConfig
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
# ---------------------------------------------------------------------------
# Helpers — DB and Redis access
# ---------------------------------------------------------------------------
async def _get_redis(request: Request):
    """Return the Redis client that the application keeps on ``app.state``.

    Args:
        request: The incoming request; only ``request.app.state.redis`` is read.

    Returns:
        Whatever object is stored in ``request.app.state.redis``.
    """
    app_state = request.app.state
    return app_state.redis
async def _get_db(request: Request):
    """Yield one async DB session built from the app-level session factory.

    The factory is read from ``request.app.state.db_session_factory``; the
    session it produces is entered as an async context manager, yielded once,
    and exited when the generator is resumed or closed.

    Args:
        request: The incoming request carrying the application state.

    Yields:
        The object returned by entering ``db_session_factory()``.
    """
    make_session = request.app.state.db_session_factory
    async with make_session() as db:
        yield db
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
@router.post("/register/begin")
async def register_begin(
    body: RegisterRequest,
    request: Request,
    config: ApiGatewayConfig = Depends(get_config),
) -> dict[str, Any]:
    """Start a WebAuthn passkey registration for a new username.

    The username is first looked up in the users table; an existing row
    aborts the request. Otherwise a fresh UUID is minted for the user,
    registration options are generated for the configured relying party, and
    the challenge is parked in Redis under ``webauthn:register:<username>``
    for 300 seconds so ``/register/complete`` can verify it.

    Args:
        body: Registration request; ``username`` and ``display_name`` are read.
        request: Used to reach the Redis client and DB session factory.
        config: Gateway configuration supplying ``rp_id`` and ``rp_name``.

    Returns:
        ``{"options": ...}`` where the value is the JSON-decoded output of
        ``options_to_json`` for the generated options.

    Raises:
        HTTPException: 409 when the username is already registered.
    """
    import json as _json

    from sqlalchemy import select
    from webauthn.helpers import options_to_json

    from shared.models.auth import User

    redis = await _get_redis(request)
    session_factory = request.app.state.db_session_factory

    # Refuse to hand out a challenge for a name that is already taken.
    async with session_factory() as session:
        lookup = await session.execute(
            select(User).where(User.username == body.username)
        )
        already_there = lookup.scalar_one_or_none()
    if already_there is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already registered",
        )

    new_user_id = uuid.uuid4()
    shown_name = body.display_name or body.username
    selection = AuthenticatorSelectionCriteria(
        resident_key=ResidentKeyRequirement.PREFERRED,
        user_verification=UserVerificationRequirement.PREFERRED,
    )
    options = generate_registration_options(
        rp_id=config.rp_id,
        rp_name=config.rp_name,
        user_id=str(new_user_id).encode(),
        user_name=body.username,
        user_display_name=shown_name,
        authenticator_selection=selection,
    )

    # Keep the challenge (plus the pre-allocated user id) for five minutes.
    pending = {
        "challenge": base64.urlsafe_b64encode(options.challenge).decode(),
        "user_id": str(new_user_id),
        "display_name": shown_name,
    }
    await redis.setex(
        f"webauthn:register:{body.username}",
        300,
        _json.dumps(pending),
    )

    # options_to_json produces a string; decode it so FastAPI emits an object.
    options_payload = _json.loads(options_to_json(options))
    return {"options": options_payload}
@router.post("/register/complete")
async def register_complete(
    request: Request,
    config: ApiGatewayConfig = Depends(get_config),
) -> TokenResponse:
    """Verify a WebAuthn registration response, persist the user, issue tokens.

    Expects a JSON object containing ``username`` and the browser's
    ``navigator.credentials.create()`` result, either under ``credential`` or
    as the body itself. The challenge stored by ``/register/begin`` is loaded
    from Redis, the response is verified against it, and a ``User`` plus
    ``UserCredential`` row are written in one transaction.

    Args:
        request: Carries the JSON body and the application state.
        config: Gateway configuration (``rp_id``, ``rp_origin``, token expiry).

    Returns:
        A ``TokenResponse`` with a fresh access and refresh token.

    Raises:
        HTTPException: 400 for a malformed body, a missing/expired challenge
            or a failed verification; 409 when the user or credential row
            cannot be inserted because it already exists.
    """
    import json as _json

    from sqlalchemy.exc import IntegrityError

    from shared.models.auth import User, UserCredential

    redis = await _get_redis(request)

    # The body is untrusted: reject anything that is not a JSON object with a
    # 400 instead of crashing on ``.get`` further down.
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must be valid JSON",
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must be a JSON object",
        )

    username = body.get("username", "")
    redis_key = f"webauthn:register:{username}"
    stored_raw = await redis.get(redis_key)
    if stored_raw is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Registration challenge expired or not found",
        )

    stored = _json.loads(stored_raw)
    expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
    user_id_str = stored["user_id"]
    display_name = stored["display_name"]

    try:
        verification = verify_registration_response(
            credential=body.get("credential", body),
            expected_challenge=expected_challenge,
            expected_rp_id=config.rp_id,
            expected_origin=config.rp_origin,
        )
    except Exception as exc:
        # The library raises several exception types for bad input; all of
        # them mean the client's response is unusable.
        logger.warning("Registration verification failed: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Registration verification failed: {exc}",
        ) from exc

    # Store user + credential in DB
    user_uuid = uuid.UUID(user_id_str)
    db_session = request.app.state.db_session_factory
    async with db_session() as session:
        session.add(
            User(
                id=user_uuid,
                username=username,
                display_name=display_name,
            )
        )
        session.add(
            UserCredential(
                user_id=user_uuid,
                # Stored as padded urlsafe base64; login_begin decodes it back.
                credential_id=base64.urlsafe_b64encode(
                    verification.credential_id
                ).decode(),
                public_key=verification.credential_public_key,
                sign_count=verification.sign_count,
            )
        )
        try:
            await session.commit()
        except IntegrityError as exc:
            # The username was free at /register/begin but a row was inserted
            # since then (concurrent registration or a repeated submit).
            await session.rollback()
            await redis.delete(redis_key)
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Username or credential already registered",
            ) from exc

    # Clean up challenge
    await redis.delete(redis_key)

    # Issue JWT
    access_token = create_access_token(user_id_str, username, config)
    refresh_token = create_refresh_token(user_id_str, config)
    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=config.access_token_expire_minutes * 60,
    )
# ---------------------------------------------------------------------------
# Login (authentication)
# ---------------------------------------------------------------------------
@router.post("/login/begin")
async def login_begin(
    body: LoginRequest,
    request: Request,
    config: ApiGatewayConfig = Depends(get_config),
) -> dict[str, Any]:
    """Start a WebAuthn login by issuing authentication options for a user.

    Looks the user up by ``body.username``, collects every credential row
    owned by that user into the ``allow_credentials`` list, generates the
    authentication options and stores the challenge in Redis under
    ``webauthn:login:<username>`` for 300 seconds.

    Args:
        body: Login request; only ``username`` is read.
        request: Used to reach the Redis client and DB session factory.
        config: Gateway configuration supplying ``rp_id``.

    Returns:
        ``{"options": ...}`` where the value is the JSON-decoded output of
        ``options_to_json`` for the generated options.

    Raises:
        HTTPException: 404 when no user has the given username.
    """
    import json as _json

    from sqlalchemy import select
    from webauthn.helpers import options_to_json

    from shared.models.auth import User, UserCredential

    redis = await _get_redis(request)
    session_factory = request.app.state.db_session_factory

    async with session_factory() as session:
        user_result = await session.execute(
            select(User).where(User.username == body.username)
        )
        user = user_result.scalar_one_or_none()
        if user is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found",
            )

        cred_result = await session.execute(
            select(UserCredential).where(UserCredential.user_id == user.id)
        )
        # Credential ids are persisted as urlsafe base64 text; the WebAuthn
        # descriptor wants the raw bytes.
        descriptors = []
        for row in cred_result.scalars().all():
            raw_id = base64.urlsafe_b64decode(row.credential_id)
            descriptors.append(PublicKeyCredentialDescriptor(id=raw_id))

    options = generate_authentication_options(
        rp_id=config.rp_id,
        allow_credentials=descriptors,
        user_verification=UserVerificationRequirement.PREFERRED,
    )

    # Remember the challenge and who it was issued for (5 minute TTL).
    pending = {
        "challenge": base64.urlsafe_b64encode(options.challenge).decode(),
        "user_id": str(user.id),
        "username": user.username,
    }
    await redis.setex(
        f"webauthn:login:{body.username}",
        300,
        _json.dumps(pending),
    )

    options_payload = _json.loads(options_to_json(options))
    return {"options": options_payload}
def _canonical_credential_id(raw_id: Any) -> str | None:
    """Convert a client-supplied credential id to the stored representation.

    Browsers send the id as base64url *without* padding, whereas
    ``/register/complete`` stores ``base64.urlsafe_b64encode(...)`` output,
    which is padded. Decoding and re-encoding yields the stored form.

    Returns:
        The padded urlsafe-base64 string, or ``None`` when ``raw_id`` is not
        a non-empty, decodable string.
    """
    if not isinstance(raw_id, str) or not raw_id:
        return None
    padded = raw_id + "=" * (-len(raw_id) % 4)
    try:
        return base64.urlsafe_b64encode(base64.urlsafe_b64decode(padded)).decode()
    except ValueError:  # binascii.Error is a ValueError subclass
        return None


@router.post("/login/complete")
async def login_complete(
    request: Request,
    config: ApiGatewayConfig = Depends(get_config),
) -> TokenResponse:
    """Verify a WebAuthn authentication response and issue JWTs.

    Expects a JSON object containing ``username`` and the browser's
    ``navigator.credentials.get()`` result, either under ``credential`` or as
    the body itself. The challenge stored by ``/login/begin`` is loaded from
    Redis, the credential is looked up *for the user the challenge was issued
    to*, the assertion is verified, and the stored sign count is updated.

    Args:
        request: Carries the JSON body and the application state.
        config: Gateway configuration (``rp_id``, ``rp_origin``, token expiry).

    Returns:
        A ``TokenResponse`` with a fresh access and refresh token.

    Raises:
        HTTPException: 400 for a malformed body, a missing/expired challenge,
            an unknown credential (including one owned by a different user)
            or a failed verification.
    """
    import json as _json

    from sqlalchemy import select

    from shared.models.auth import UserCredential

    redis = await _get_redis(request)

    # The body is untrusted: reject anything that is not a JSON object.
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must be valid JSON",
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must be a JSON object",
        )

    username = body.get("username", "")
    redis_key = f"webauthn:login:{username}"
    stored_raw = await redis.get(redis_key)
    if stored_raw is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Authentication challenge expired or not found",
        )

    stored = _json.loads(stored_raw)
    expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
    user_id_str = stored["user_id"]

    # Look up the credential used
    credential_payload = body.get("credential", body)
    credential_id_b64 = (
        _canonical_credential_id(credential_payload.get("id", ""))
        if isinstance(credential_payload, dict)
        else None
    )
    if credential_id_b64 is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Credential not found",
        )

    db_session = request.app.state.db_session_factory
    async with db_session() as session:
        # SECURITY: the credential must belong to the user the challenge was
        # issued for. Matching on credential_id alone would let any passkey
        # holder sign a challenge requested for someone else's username and
        # receive that user's tokens.
        cred = (
            await session.execute(
                select(UserCredential).where(
                    UserCredential.credential_id == credential_id_b64,
                    UserCredential.user_id == uuid.UUID(user_id_str),
                )
            )
        ).scalar_one_or_none()
        if cred is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Credential not found",
            )

        try:
            verification = verify_authentication_response(
                credential=credential_payload,
                expected_challenge=expected_challenge,
                expected_rp_id=config.rp_id,
                expected_origin=config.rp_origin,
                credential_public_key=cred.public_key,
                credential_current_sign_count=cred.sign_count,
            )
        except Exception as exc:
            # The library raises several exception types for bad input; all
            # of them mean the assertion is unusable.
            logger.warning("Authentication verification failed: %s", exc)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Authentication verification failed: {exc}",
            ) from exc

        # Update sign count
        cred.sign_count = verification.new_sign_count
        await session.commit()

    # Clean up challenge
    await redis.delete(redis_key)

    # Issue JWT. Prefer the username recorded by /login/begin (taken from the
    # DB row) over the client-supplied spelling.
    token_username = stored.get("username", username)
    access_token = create_access_token(user_id_str, token_username, config)
    refresh_token = create_refresh_token(user_id_str, config)
    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=config.access_token_expire_minutes * 60,
    )
# ---------------------------------------------------------------------------
# Refresh
# ---------------------------------------------------------------------------
@router.post("/refresh")
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse:
    """Exchange a valid refresh token for a new access/refresh token pair.

    Expects a JSON object with a ``refresh_token`` field. The token is decoded
    with ``decode_token``, must carry ``type == "refresh"``, and its ``sub``
    claim must be the UUID of an existing user.

    Args:
        request: Carries the JSON body and the application state.
        config: Gateway configuration used for decoding and issuing tokens.

    Returns:
        A ``TokenResponse`` with a new access token and a new refresh token.

    Raises:
        HTTPException: 400 when the body is not valid JSON; 401 when the
            token is expired, invalid, of the wrong type, has an unusable
            ``sub`` claim, or refers to a user that does not exist.
    """
    from sqlalchemy import select

    from shared.models.auth import User

    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must be valid JSON",
        ) from exc
    # A non-object body simply carries no token; let the decoder reject "".
    refresh_token = body.get("refresh_token", "") if isinstance(body, dict) else ""

    try:
        payload = decode_token(refresh_token, config)
    except pyjwt.ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token has expired",
        ) from exc
    except pyjwt.InvalidTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from exc

    if payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )

    # A correctly signed token can still lack ``sub`` or carry a non-UUID
    # value; treat that as an invalid token rather than a server error.
    user_id = payload.get("sub")
    try:
        user_uuid = uuid.UUID(user_id)
    except (AttributeError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from exc

    # Look up username from DB
    db_session = request.app.state.db_session_factory
    async with db_session() as session:
        user = (
            await session.execute(select(User).where(User.id == user_uuid))
        ).scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    access_token = create_access_token(user_id, user.username, config)
    new_refresh_token = create_refresh_token(user_id, config)
    return TokenResponse(
        access_token=access_token,
        refresh_token=new_refresh_token,
        expires_in=config.access_token_expire_minutes * 60,
    )