import asyncio

from api.config import (
    AUTHENTIK_URL,
    OIDC_CACHE_TTL,
    OIDC_CLIENT_ID,
    OIDC_METADATA_URL,
    JWT_SECRET,
    JWT_ALGORITHM,
    JWT_ISSUER,
)
from cachetools import TTLCache
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from httpx import AsyncClient
import jwt
from pydantic import BaseModel, ValidationError

# HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs)
http_bearer = HTTPBearer()

# Single-entry caches (maxsize=1); each value is kept for OIDC_CACHE_TTL seconds.
JWKS_CACHE = TTLCache(maxsize=1, ttl=OIDC_CACHE_TTL)
OIDC_METADATA_CACHE = TTLCache(maxsize=1, ttl=OIDC_CACHE_TTL)

# Attached to every 401 so clients know a Bearer token is expected (RFC 6750).
_UNAUTHORIZED_HEADERS = {"WWW-Authenticate": "Bearer"}


class User(BaseModel):
    """Authenticated principal built from a verified JWT's claims."""

    sub: str  # User ID
    email: str
    name: str


async def get_oidc_metadata() -> dict:  # type: ignore[type-arg]
    """Return the OIDC discovery document from OIDC_METADATA_URL (cached).

    Raises:
        httpx.HTTPStatusError: if the discovery endpoint answers with a
            non-2xx status. Nothing is cached in that case, so the next
            call retries instead of serving the error body for a full TTL.
    """
    # Single lookup: a check-then-index on a TTLCache can raise KeyError if
    # the entry expires between the two operations.
    metadata = OIDC_METADATA_CACHE.get("oidc_metadata")
    if metadata is None:
        async with AsyncClient() as client:
            resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)
            resp.raise_for_status()
            metadata = resp.json()
        OIDC_METADATA_CACHE["oidc_metadata"] = metadata
    return metadata


async def get_cached_jwks_client() -> jwt.PyJWKClient:
    """Return a PyJWKClient for the provider's ``jwks_uri`` (cached)."""
    jwks_client = JWKS_CACHE.get("jwks_client")
    if jwks_client is None:
        metadata = await get_oidc_metadata()
        jwks_client = jwt.PyJWKClient(
            metadata["jwks_uri"],
            cache_keys=True,  # PyJWT's built-in key caching
            max_cached_keys=5,
        )
        JWKS_CACHE["jwks_client"] = jwks_client
    return jwks_client


def _user_from_claims(payload: dict) -> User:  # type: ignore[type-arg]
    """Build a User from already-verified claims; ``name`` falls back to ``email``.

    Claim problems are raised as PyJWT errors so the caller maps them to a
    401 instead of letting a KeyError/ValidationError surface as a 500.

    Raises:
        jwt.MissingRequiredClaimError: if ``sub`` or ``email`` is absent.
        jwt.InvalidTokenError: if a claim has a type User cannot accept.
    """
    for claim in ("sub", "email"):
        if claim not in payload:
            raise jwt.MissingRequiredClaimError(claim)
    try:
        return User(
            sub=payload["sub"],
            email=payload["email"],
            name=payload.get("name", payload["email"]),
        )
    except ValidationError as e:
        raise jwt.InvalidTokenError("Token has malformed user claims") from e


async def _verify_authentik_token(token: str) -> User:
    """Verify a token issued by Authentik (RS256 via JWKS).

    Raises:
        jwt.PyJWTError: if the signing key cannot be resolved, the token fails
            signature/audience/issuer/expiry validation, or its user claims
            are missing or malformed.
    """
    metadata = await get_oidc_metadata()
    jwks_client = await get_cached_jwks_client()
    # get_signing_key_from_jwt is synchronous and may fetch the JWKS over the
    # network (first use, cache expiry, unknown "kid"); run it in a worker
    # thread so it does not stall the event loop.
    loop = asyncio.get_running_loop()
    signing_key = await loop.run_in_executor(
        None, jwks_client.get_signing_key_from_jwt, token
    )
    payload = jwt.decode(
        token,
        signing_key,
        algorithms=["RS256"],
        audience=OIDC_CLIENT_ID,
        issuer=metadata["issuer"],
    )
    return _user_from_claims(payload)


def _verify_passkey_token(token: str) -> User:
    """Verify a token issued by the passkey service (JWT_SECRET / JWT_ALGORITHM).

    Raises:
        jwt.PyJWTError: if verification fails or user claims are missing or
            malformed.
    """
    payload = jwt.decode(
        token,
        JWT_SECRET,
        algorithms=[JWT_ALGORITHM],
        issuer=JWT_ISSUER,
    )
    return _user_from_claims(payload)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
) -> User:
    """FastAPI dependency that authenticates the request's Bearer token.

    Routes on the token's unverified ``iss`` claim:
    - tokens whose issuer equals JWT_ISSUER take the passkey path;
    - all others take the Authentik path.

    Each path pins its own key and algorithm list, so a forged ``iss`` only
    selects which verification runs; it cannot skip verification.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token is
            malformed, fails verification, or lacks usable user claims.
    """
    token = credentials.credentials
    try:
        # Decode WITHOUT verification just to read the "iss" claim for routing.
        # The security check happens in the verified decode of the chosen branch.
        unverified = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        if unverified.get("iss", "") == JWT_ISSUER:
            return _verify_passkey_token(token)
        return await _verify_authentik_token(token)
    except jwt.PyJWTError as e:
        raise HTTPException(
            status_code=401,
            detail=f"Invalid token: {e}",
            headers=_UNAUTHORIZED_HEADERS,
        ) from e