"""WebAuthn (passkey) registration and authentication ceremonies.

Each ceremony is split into a ``begin_*`` step, which generates options and
stores the challenge in Redis under a random session id, and a ``complete_*``
step, which verifies the browser's response against that challenge and issues
a JWT on success.
"""

import base64
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

import jwt
from webauthn import (
    generate_registration_options,
    verify_registration_response,
    generate_authentication_options,
    verify_authentication_response,
)
from webauthn.helpers import (
    options_to_json,
    parse_registration_credential_json,
    parse_authentication_credential_json,
)
from webauthn.helpers.structs import (
    AuthenticatorSelectionCriteria,
    PublicKeyCredentialDescriptor,
    AuthenticatorTransport,
    ResidentKeyRequirement,
    UserVerificationRequirement,
)
from webauthn.helpers.cose import COSEAlgorithmIdentifier

from api.config import (
    WEBAUTHN_RP_ID,
    WEBAUTHN_RP_NAME,
    WEBAUTHN_ORIGIN,
    JWT_SECRET,
    JWT_ALGORITHM,
    JWT_EXPIRATION_HOURS,
    JWT_ISSUER,
)
from models.passkey_credential import PasskeyCredential
from repositories.user_repository import UserRepository
from redis_repository import RedisRepository

logger = logging.getLogger("uvicorn")

CHALLENGE_TTL = timedelta(minutes=5)
CHALLENGE_KEY_PREFIX = "webauthn:challenge:"

# Stored in place of a used challenge. Its "type" matches no ceremony, so the
# type checks in complete_* reject any second attempt with the same session.
_CONSUMED_CHALLENGE: dict[str, Any] = {"type": "consumed"}


def _b64url_encode(data: bytes) -> str:
    """Encode *data* as base64url text without trailing ``=`` padding."""
    return base64.urlsafe_b64encode(data).decode().rstrip("=")


def _b64url_decode(text: str) -> bytes:
    """Decode base64url *text*, accepting input with or without ``=`` padding."""
    # Restore exactly the padding needed; already-padded input gets none added.
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def _store_challenge(session_id: str, data: dict[str, Any]) -> None:
    """Save ceremony state for *session_id* in Redis for CHALLENGE_TTL."""
    redis = RedisRepository.instance()
    redis.set_key(f"{CHALLENGE_KEY_PREFIX}{session_id}", data, ttl=CHALLENGE_TTL)


def _get_challenge(session_id: str) -> dict[str, Any] | None:
    """Return the ceremony state stored for *session_id*, or None if absent."""
    redis = RedisRepository.instance()
    return redis.get_key(f"{CHALLENGE_KEY_PREFIX}{session_id}")  # type: ignore[return-value]


def _consume_challenge(session_id: str) -> dict[str, Any] | None:
    """Fetch the ceremony state for *session_id* and invalidate it.

    WebAuthn challenges must be single-use. Without this, a captured response
    could be replayed against the same session until the TTL expires. A failed
    verification also burns the session; the client must restart the ceremony.
    """
    data = _get_challenge(session_id)
    if data:
        # NOTE(review): get-then-overwrite is not atomic, so two concurrent
        # requests can both read the original value. If RedisRepository
        # exposes an atomic get-and-delete (Redis GETDEL), prefer that.
        _store_challenge(session_id, _CONSUMED_CHALLENGE)
    return data


def _parse_transports(raw: str | None) -> list[AuthenticatorTransport]:
    """Convert a stored JSON list of transport names into enum values.

    Unparseable or unknown entries are skipped and logged instead of raising.
    The stored value originates from client-supplied registration data, so one
    bad entry would otherwise make begin_registration fail for that user.
    """
    if not raw:
        return []
    try:
        values = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("Ignoring unparseable stored transports: %r", raw)
        return []
    if not isinstance(values, list):
        logger.warning("Ignoring non-list stored transports: %r", raw)
        return []
    transports: list[AuthenticatorTransport] = []
    for value in values:
        try:
            transports.append(AuthenticatorTransport(value))
        except (ValueError, TypeError):
            logger.warning("Ignoring unknown authenticator transport: %r", value)
    return transports


def _issue_jwt(user_id: int, email: str) -> str:
    """Return a signed JWT for *user_id* that expires after JWT_EXPIRATION_HOURS."""
    now = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "email": email,
        "name": email,
        "iss": JWT_ISSUER,
        "iat": now,
        "exp": now + timedelta(hours=JWT_EXPIRATION_HOURS),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


def begin_registration(
    email: str, user_repo: UserRepository
) -> tuple[dict[str, Any], str]:
    """Start WebAuthn registration ceremony.

    Looks up the user by *email*, creating one if none exists. Builds
    registration options that require a resident (discoverable) key and
    exclude the user's already-registered credentials. Stores the challenge
    under a new session id.

    NOTE(review): nothing here authenticates the caller. If a user with this
    email already exists, anyone who knows the address can register their own
    passkey on that account, and complete_registration will issue a JWT for it.
    Adding a passkey to an existing account should require an authenticated
    session, and new accounts should require proof of email ownership. Confirm
    how the calling route guards this.

    Returns:
        (options_dict, session_id): JSON-compatible options for the browser, and
        the id the client must send back to complete_registration.
    """
    user = user_repo.get_user_by_email(email)
    if user is None:
        # NOTE(review): the user row is created before the ceremony completes,
        # so abandoned registrations leave users without credentials.
        user = user_repo.create_user(email)

    # Stop the authenticator from registering a second credential for this
    # user on a device that already holds one.
    exclude_credentials = [
        PublicKeyCredentialDescriptor(
            id=_b64url_decode(cred.credential_id),
            transports=_parse_transports(cred.transports),
        )
        for cred in user_repo.get_credentials_for_user(user.id)
    ]

    # NOTE(review): user verification is only PREFERRED (and the verify_*
    # calls below do not pass require_user_verification=True), so a passkey
    # that skips PIN/biometric still signs in. Confirm this is intended.
    options = generate_registration_options(
        rp_id=WEBAUTHN_RP_ID,
        rp_name=WEBAUTHN_RP_NAME,
        user_id=str(user.id).encode(),
        user_name=email,
        user_display_name=email,
        exclude_credentials=exclude_credentials,
        authenticator_selection=AuthenticatorSelectionCriteria(
            resident_key=ResidentKeyRequirement.REQUIRED,
            user_verification=UserVerificationRequirement.PREFERRED,
        ),
        supported_pub_key_algs=[
            COSEAlgorithmIdentifier.ECDSA_SHA_256,
            COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
        ],
    )

    session_id = str(uuid.uuid4())
    _store_challenge(session_id, {
        "challenge": _b64url_encode(options.challenge),
        "user_id": user.id,
        "email": email,
        "type": "registration",
    })

    options_json = json.loads(options_to_json(options))
    return options_json, session_id


def complete_registration(
    session_id: str,
    credential: dict[str, Any],
    user_repo: UserRepository,
) -> str:
    """Complete WebAuthn registration ceremony.

    Verifies the browser's registration response against the challenge
    stored for *session_id*. On success, saves the new credential for the
    user recorded at begin_registration time.

    Args:
        session_id: Session id returned by begin_registration.
        credential: The browser's registration response as a JSON-compatible
            dict.
        user_repo: Repository used to persist the new credential.

    Returns:
        A JWT string for the registering user.

    Raises:
        ValueError: If the session is missing, expired, already used, or not
            a registration session.

    Exceptions raised by py_webauthn while parsing or verifying *credential*
    propagate unchanged.
    """
    challenge_data = _consume_challenge(session_id)
    if not challenge_data or challenge_data.get("type") != "registration":
        raise ValueError("Invalid or expired registration session")

    expected_challenge = _b64url_decode(challenge_data["challenge"])
    registration_credential = parse_registration_credential_json(
        json.dumps(credential)
    )
    verification = verify_registration_response(
        credential=registration_credential,
        expected_challenge=expected_challenge,
        expected_rp_id=WEBAUTHN_RP_ID,
        expected_origin=WEBAUTHN_ORIGIN,
    )

    # Transports are optional hints from the client. Store them as JSON when
    # present so begin_registration can pass them back in exclude lists.
    response = credential.get("response")
    transports = response.get("transports") if isinstance(response, dict) else None

    passkey_cred = PasskeyCredential(
        credential_id=_b64url_encode(verification.credential_id),
        public_key=_b64url_encode(verification.credential_public_key),
        sign_count=verification.sign_count,
        transports=json.dumps(transports) if transports else None,
        user_id=challenge_data["user_id"],
    )
    user_repo.save_credential(passkey_cred)

    return _issue_jwt(challenge_data["user_id"], challenge_data["email"])


def begin_authentication(
    user_repo: UserRepository,
) -> tuple[dict[str, Any], str]:
    """Start WebAuthn authentication ceremony (discoverable credentials).

    No allowed-credentials list is sent, so the authenticator chooses which
    resident credential to use. *user_repo* is accepted for interface symmetry
    and is not used here.

    Returns:
        (options_dict, session_id): JSON-compatible options for the browser, and
        the id the client must send back to complete_authentication.
    """
    options = generate_authentication_options(
        rp_id=WEBAUTHN_RP_ID,
        user_verification=UserVerificationRequirement.PREFERRED,
    )

    session_id = str(uuid.uuid4())
    _store_challenge(session_id, {
        "challenge": _b64url_encode(options.challenge),
        "type": "authentication",
    })

    options_json = json.loads(options_to_json(options))
    return options_json, session_id


def complete_authentication(
    session_id: str,
    credential: dict[str, Any],
    user_repo: UserRepository,
) -> str:
    """Complete WebAuthn authentication ceremony.

    Looks up the stored credential by the response's credential id and
    verifies the assertion against its public key and the stored challenge.
    On success, records the new signature counter.

    Args:
        session_id: Session id returned by begin_authentication.
        credential: The browser's authentication response as a JSON-compatible
            dict.
        user_repo: Repository used to look up and update the credential and
            user.

    Returns:
        A JWT string for the credential's owner.

    Raises:
        ValueError: If the session is missing, expired, already used, or not
            an authentication session; or if the credential or its user is
            not found.

    Exceptions raised by py_webauthn while parsing or verifying *credential*
    propagate unchanged.
    """
    challenge_data = _consume_challenge(session_id)
    if not challenge_data or challenge_data.get("type") != "authentication":
        raise ValueError("Invalid or expired authentication session")

    expected_challenge = _b64url_decode(challenge_data["challenge"])

    # Discoverable-credential flow: the server named no user, so the
    # credential id in the response is what identifies the account. Stored ids
    # are unpadded base64url, so strip any padding a client may have added.
    raw_id = credential.get("rawId") or credential.get("id") or ""
    stored_cred = None
    if isinstance(raw_id, str) and raw_id:
        stored_cred = user_repo.get_credential_by_id(raw_id.rstrip("="))
    if not stored_cred:
        raise ValueError("Credential not found")

    stored_public_key = _b64url_decode(stored_cred.public_key)
    auth_credential = parse_authentication_credential_json(
        json.dumps(credential)
    )
    verification = verify_authentication_response(
        credential=auth_credential,
        expected_challenge=expected_challenge,
        expected_rp_id=WEBAUTHN_RP_ID,
        expected_origin=WEBAUTHN_ORIGIN,
        credential_public_key=stored_public_key,
        credential_current_sign_count=stored_cred.sign_count,
    )

    user_repo.update_credential_sign_count(
        stored_cred.credential_id, verification.new_sign_count
    )

    user = user_repo.get_user_by_id(stored_cred.user_id)
    if not user:
        raise ValueError("User not found")

    return _issue_jwt(user.id, user.email)