feat: API gateway with passkey (WebAuthn) authentication
This commit is contained in:
parent
f218865872
commit
e0d138c457
9 changed files with 907 additions and 2 deletions
1
services/api_gateway/auth/__init__.py
Normal file
1
services/api_gateway/auth/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Auth sub-package for the API Gateway."""
|
||||
98
services/api_gateway/auth/jwt.py
Normal file
98
services/api_gateway/auth/jwt.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""JWT utilities — token creation and verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
username: str,
|
||||
config: ApiGatewayConfig,
|
||||
) -> str:
|
||||
"""Create a short-lived JWT access token.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
user_id:
|
||||
UUID of the authenticated user (stored as ``sub`` claim).
|
||||
username:
|
||||
Username (stored as ``username`` claim).
|
||||
config:
|
||||
Gateway configuration with secret key and algorithm.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Encoded JWT string.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"username": username,
|
||||
"type": "access",
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=config.access_token_expire_minutes),
|
||||
}
|
||||
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
user_id: str,
|
||||
config: ApiGatewayConfig,
|
||||
) -> str:
|
||||
"""Create a longer-lived JWT refresh token.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
user_id:
|
||||
UUID of the authenticated user (stored as ``sub`` claim).
|
||||
config:
|
||||
Gateway configuration with secret key and algorithm.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Encoded JWT string.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"type": "refresh",
|
||||
"iat": now,
|
||||
"exp": now + timedelta(days=config.refresh_token_expire_days),
|
||||
}
|
||||
return jwt.encode(payload, config.jwt_secret_key, algorithm=config.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_token(token: str, config: ApiGatewayConfig) -> dict:
|
||||
"""Decode and verify a JWT token.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token:
|
||||
The JWT string to decode.
|
||||
config:
|
||||
Gateway configuration with secret key and algorithm.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The decoded payload.
|
||||
|
||||
Raises
|
||||
------
|
||||
jwt.ExpiredSignatureError
|
||||
If the token has expired.
|
||||
jwt.InvalidTokenError
|
||||
If the token is malformed or signature verification fails.
|
||||
"""
|
||||
return jwt.decode(
|
||||
token,
|
||||
config.jwt_secret_key,
|
||||
algorithms=[config.jwt_algorithm],
|
||||
)
|
||||
68
services/api_gateway/auth/middleware.py
Normal file
68
services/api_gateway/auth/middleware.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Auth middleware — FastAPI dependency for JWT-based authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from services.api_gateway.auth.jwt import decode_token
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
|
||||
# Shared config instance — injected via FastAPI dependency override in tests.
|
||||
_config: ApiGatewayConfig | None = None
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def get_config() -> ApiGatewayConfig:
|
||||
"""Return the singleton config. Overridden in tests."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = ApiGatewayConfig()
|
||||
return _config
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
) -> dict:
|
||||
"""FastAPI dependency that extracts and validates a Bearer JWT.
|
||||
|
||||
Returns the decoded token payload (contains ``sub``, ``username``, etc.)
|
||||
on success. Raises a 401 ``HTTPException`` for missing, expired, or
|
||||
invalid tokens.
|
||||
"""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authorization header",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
try:
|
||||
payload = decode_token(token, config)
|
||||
except pyjwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except pyjwt.InvalidTokenError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Ensure it is an access token, not a refresh token
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return payload
|
||||
410
services/api_gateway/auth/routes.py
Normal file
410
services/api_gateway/auth/routes.py
Normal file
|
|
@ -0,0 +1,410 @@
|
|||
"""Auth routes — WebAuthn passkey registration/login and JWT refresh."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from webauthn import (
|
||||
generate_authentication_options,
|
||||
generate_registration_options,
|
||||
verify_authentication_response,
|
||||
verify_registration_response,
|
||||
)
|
||||
from webauthn.helpers.structs import (
|
||||
AuthenticatorSelectionCriteria,
|
||||
PublicKeyCredentialDescriptor,
|
||||
ResidentKeyRequirement,
|
||||
UserVerificationRequirement,
|
||||
)
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from services.api_gateway.auth.jwt import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
)
|
||||
from services.api_gateway.auth.middleware import get_config
|
||||
from services.api_gateway.config import ApiGatewayConfig
|
||||
from shared.schemas.auth import LoginRequest, RegisterRequest, TokenResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — DB and Redis access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_redis(request: Request):
|
||||
"""Retrieve the Redis client from application state."""
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
async def _get_db(request: Request):
|
||||
"""Retrieve an async DB session from application state."""
|
||||
session_factory = request.app.state.db_session_factory
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/register/begin")
|
||||
async def register_begin(
|
||||
body: RegisterRequest,
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
) -> dict[str, Any]:
|
||||
"""Generate WebAuthn registration options (challenge + relying party info).
|
||||
|
||||
Stores the challenge in Redis with a 5-minute TTL so we can verify it
|
||||
in ``/register/complete``.
|
||||
"""
|
||||
redis = await _get_redis(request)
|
||||
db_session = request.app.state.db_session_factory
|
||||
|
||||
# Check if username already exists
|
||||
from sqlalchemy import select
|
||||
from shared.models.auth import User
|
||||
|
||||
async with db_session() as session:
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(User).where(User.username == body.username)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Username already registered",
|
||||
)
|
||||
|
||||
user_id = uuid.uuid4()
|
||||
options = generate_registration_options(
|
||||
rp_id=config.rp_id,
|
||||
rp_name=config.rp_name,
|
||||
user_id=str(user_id).encode(),
|
||||
user_name=body.username,
|
||||
user_display_name=body.display_name or body.username,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
resident_key=ResidentKeyRequirement.PREFERRED,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
),
|
||||
)
|
||||
|
||||
# Store challenge in Redis (5 min TTL)
|
||||
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
||||
redis_key = f"webauthn:register:{body.username}"
|
||||
import json as _json
|
||||
|
||||
await redis.setex(
|
||||
redis_key,
|
||||
300,
|
||||
_json.dumps({
|
||||
"challenge": challenge_b64,
|
||||
"user_id": str(user_id),
|
||||
"display_name": body.display_name or body.username,
|
||||
}),
|
||||
)
|
||||
|
||||
# Serialize options to a JSON-safe dict
|
||||
from webauthn.helpers import options_to_json
|
||||
|
||||
return {"options": _json.loads(options_to_json(options))}
|
||||
|
||||
|
||||
@router.post("/register/complete")
|
||||
async def register_complete(
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
) -> TokenResponse:
|
||||
"""Verify WebAuthn registration response and store credential.
|
||||
|
||||
Expects the browser's ``navigator.credentials.create()`` response
|
||||
in the request body.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
redis = await _get_redis(request)
|
||||
body = await request.json()
|
||||
username = body.get("username", "")
|
||||
|
||||
redis_key = f"webauthn:register:{username}"
|
||||
stored_raw = await redis.get(redis_key)
|
||||
if stored_raw is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration challenge expired or not found",
|
||||
)
|
||||
stored = _json.loads(stored_raw)
|
||||
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
|
||||
user_id_str = stored["user_id"]
|
||||
display_name = stored["display_name"]
|
||||
|
||||
try:
|
||||
verification = verify_registration_response(
|
||||
credential=body.get("credential", body),
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=config.rp_id,
|
||||
expected_origin=config.rp_origin,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Registration verification failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Registration verification failed: {exc}",
|
||||
)
|
||||
|
||||
# Store user + credential in DB
|
||||
from shared.models.auth import User, UserCredential
|
||||
from sqlalchemy import select
|
||||
|
||||
db_session = request.app.state.db_session_factory
|
||||
async with db_session() as session:
|
||||
user = User(
|
||||
id=uuid.UUID(user_id_str),
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
credential = UserCredential(
|
||||
user_id=uuid.UUID(user_id_str),
|
||||
credential_id=base64.urlsafe_b64encode(
|
||||
verification.credential_id
|
||||
).decode(),
|
||||
public_key=verification.credential_public_key,
|
||||
sign_count=verification.sign_count,
|
||||
)
|
||||
session.add(credential)
|
||||
await session.commit()
|
||||
|
||||
# Clean up challenge
|
||||
await redis.delete(redis_key)
|
||||
|
||||
# Issue JWT
|
||||
access_token = create_access_token(user_id_str, username, config)
|
||||
refresh_token = create_refresh_token(user_id_str, config)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=config.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Login (authentication)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/login/begin")
|
||||
async def login_begin(
|
||||
body: LoginRequest,
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
) -> dict[str, Any]:
|
||||
"""Generate WebAuthn authentication options for an existing user."""
|
||||
import json as _json
|
||||
|
||||
redis = await _get_redis(request)
|
||||
db_session = request.app.state.db_session_factory
|
||||
|
||||
from sqlalchemy import select
|
||||
from shared.models.auth import User, UserCredential
|
||||
|
||||
async with db_session() as session:
|
||||
user = (
|
||||
await session.execute(
|
||||
select(User).where(User.username == body.username)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
creds = (
|
||||
await session.execute(
|
||||
select(UserCredential).where(
|
||||
UserCredential.user_id == user.id
|
||||
)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
allow_credentials = [
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(c.credential_id),
|
||||
)
|
||||
for c in creds
|
||||
]
|
||||
|
||||
options = generate_authentication_options(
|
||||
rp_id=config.rp_id,
|
||||
allow_credentials=allow_credentials,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
)
|
||||
|
||||
challenge_b64 = base64.urlsafe_b64encode(options.challenge).decode()
|
||||
redis_key = f"webauthn:login:{body.username}"
|
||||
await redis.setex(
|
||||
redis_key,
|
||||
300,
|
||||
_json.dumps({
|
||||
"challenge": challenge_b64,
|
||||
"user_id": str(user.id),
|
||||
"username": user.username,
|
||||
}),
|
||||
)
|
||||
|
||||
from webauthn.helpers import options_to_json
|
||||
|
||||
return {"options": _json.loads(options_to_json(options))}
|
||||
|
||||
|
||||
@router.post("/login/complete")
|
||||
async def login_complete(
|
||||
request: Request,
|
||||
config: ApiGatewayConfig = Depends(get_config),
|
||||
) -> TokenResponse:
|
||||
"""Verify WebAuthn authentication response and issue JWT."""
|
||||
import json as _json
|
||||
|
||||
redis = await _get_redis(request)
|
||||
body = await request.json()
|
||||
username = body.get("username", "")
|
||||
|
||||
redis_key = f"webauthn:login:{username}"
|
||||
stored_raw = await redis.get(redis_key)
|
||||
if stored_raw is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Authentication challenge expired or not found",
|
||||
)
|
||||
stored = _json.loads(stored_raw)
|
||||
expected_challenge = base64.urlsafe_b64decode(stored["challenge"])
|
||||
user_id_str = stored["user_id"]
|
||||
|
||||
# Look up the credential used
|
||||
from sqlalchemy import select
|
||||
from shared.models.auth import UserCredential
|
||||
|
||||
credential_id_b64 = body.get("credential", body).get("id", "")
|
||||
db_session = request.app.state.db_session_factory
|
||||
|
||||
async with db_session() as session:
|
||||
cred = (
|
||||
await session.execute(
|
||||
select(UserCredential).where(
|
||||
UserCredential.credential_id == credential_id_b64
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if cred is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Credential not found",
|
||||
)
|
||||
|
||||
try:
|
||||
verification = verify_authentication_response(
|
||||
credential=body.get("credential", body),
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=config.rp_id,
|
||||
expected_origin=config.rp_origin,
|
||||
credential_public_key=cred.public_key,
|
||||
credential_current_sign_count=cred.sign_count,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Authentication verification failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Authentication verification failed: {exc}",
|
||||
)
|
||||
|
||||
# Update sign count
|
||||
cred.sign_count = verification.new_sign_count
|
||||
await session.commit()
|
||||
|
||||
# Clean up challenge
|
||||
await redis.delete(redis_key)
|
||||
|
||||
# Issue JWT
|
||||
access_token = create_access_token(user_id_str, username, config)
|
||||
refresh_token = create_refresh_token(user_id_str, config)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=config.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh(request: Request, config: ApiGatewayConfig = Depends(get_config)) -> TokenResponse:
|
||||
"""Exchange a valid refresh token for a new access token."""
|
||||
body = await request.json()
|
||||
refresh_token = body.get("refresh_token", "")
|
||||
|
||||
try:
|
||||
payload = decode_token(refresh_token, config)
|
||||
except pyjwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has expired",
|
||||
)
|
||||
except pyjwt.InvalidTokenError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
)
|
||||
|
||||
user_id = payload["sub"]
|
||||
|
||||
# Look up username from DB
|
||||
from sqlalchemy import select
|
||||
from shared.models.auth import User
|
||||
|
||||
db_session = request.app.state.db_session_factory
|
||||
async with db_session() as session:
|
||||
user = (
|
||||
await session.execute(
|
||||
select(User).where(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
access_token = create_access_token(user_id, user.username, config)
|
||||
new_refresh_token = create_refresh_token(user_id, config)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
expires_in=config.access_token_expire_minutes * 60,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue