"""OIDC bearer-token authentication for FastAPI against Authentik.

Incoming bearer tokens are verified against the provider's JWKS (located via
the OIDC discovery document at ``OIDC_METADATA_URL``) and decoded into a
``User``.
"""

from datetime import timedelta

from api.config import AUTHENTIK_URL, OIDC_CACHE_TTL, OIDC_CLIENT_ID, OIDC_METADATA_URL
from cachetools import TTLCache
from fastapi import Depends, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2AuthorizationCodeBearer
from httpx import AsyncClient
import jwt
from pydantic import BaseModel, ValidationError

# OAuth2 Scheme: pulls the bearer token from the request and advertises
# Authentik's authorize/token endpoints in the OpenAPI schema.
oauth2_scheme = OAuth2AuthorizationCodeBearer(
    authorizationUrl=f"{AUTHENTIK_URL}/application/o/authorize/",
    tokenUrl=f"{AUTHENTIK_URL}/application/o/token/",
)

# Single-entry caches whose value expires after OIDC_CACHE_TTL.
JWKS_CACHE = TTLCache(maxsize=1, ttl=OIDC_CACHE_TTL)
OIDC_METADATA_CACHE = TTLCache(maxsize=1, ttl=OIDC_CACHE_TTL)


class User(BaseModel):
    """Authenticated user, built from the claims of a verified token."""

    sub: str  # User ID (the token's subject claim)
    email: str
    name: str


async def get_oidc_metadata() -> dict:
    """Return the provider's OIDC discovery document, cached for OIDC_CACHE_TTL.

    Raises:
        httpx.HTTPStatusError: if the metadata endpoint answers with a 4xx/5xx
            status. Error responses are never cached.
    """
    # Read the cache once into a local: an ``in`` check followed by ``[]`` can
    # raise KeyError if the TTL entry expires between the two lookups.
    metadata = OIDC_METADATA_CACHE.get("oidc_metadata")
    if metadata is None:
        async with AsyncClient() as client:
            resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)
            # Fail before caching so an error body is not served for a full TTL.
            resp.raise_for_status()
            metadata = resp.json()
        OIDC_METADATA_CACHE["oidc_metadata"] = metadata
    return metadata


async def get_cached_jwks_client() -> jwt.PyJWKClient:
    """Return a ``PyJWKClient`` for the provider's ``jwks_uri``, cached for OIDC_CACHE_TTL."""
    jwks_client = JWKS_CACHE.get("jwks_client")
    if jwks_client is None:
        metadata = await get_oidc_metadata()
        jwks_client = jwt.PyJWKClient(
            metadata["jwks_uri"],
            cache_keys=True,  # PyJWT's built-in key caching
            max_cached_keys=5,
        )
        JWKS_CACHE["jwks_client"] = jwks_client
    return jwks_client


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` header (RFC 6750)."""
    return HTTPException(
        status_code=401,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """FastAPI dependency: verify the bearer token and return the ``User`` it describes.

    The token must meet all of these conditions:

    - It is an RS256 JWT signed by a key from the provider's JWKS.
    - Its issuer is the provider's ``issuer``.
    - Its audience is ``OIDC_CLIENT_ID``.
    - It carries an ``exp`` claim that has not passed.

    Raises:
        HTTPException: 401 if verification fails, or if the token lacks the
            claims needed to build a ``User``.
    """
    try:
        # Fetch JWKS keys from Authentik
        metadata = await get_oidc_metadata()
        jwks_client = await get_cached_jwks_client()
        # PyJWKClient fetches keys with blocking I/O on a cache miss; run it in
        # a worker thread so the event loop is not stalled.
        signing_key = await run_in_threadpool(
            jwks_client.get_signing_key_from_jwt, token
        )

        # Decode and verify JWT. Expiry is checked by PyJWT by default; also
        # require "exp" so a token without one is not valid forever.
        payload = jwt.decode(
            token,
            signing_key.key,
            algorithms=["RS256"],
            audience=OIDC_CLIENT_ID,
            issuer=metadata["issuer"],
            options={"require": ["exp"]},
        )
        return User(**payload)
    except jwt.PyJWTError as e:
        raise _unauthorized(f"Invalid token: {e}") from e
    except ValidationError as e:
        # A validly signed token that lacks sub/email/name is still unusable.
        raise _unauthorized(
            "Invalid token: required user claims missing or malformed"
        ) from e