"""WebSocket endpoint for real-time task progress updates.

Clients connect to ``/api/ws/tasks?token=<JWT>`` and receive live progress
messages published by Celery workers via Redis pub/sub.

Server -> client frames:

* ``{"type": "init", "tasks": [...]}`` -- sent once after the handshake, one
  snapshot per task returned by ``task_service.get_user_tasks`` for the user.
* ``{"type": "task_update", ...}`` -- a pub/sub payload or a fresh snapshot.
* ``{"type": "pong"}`` -- reply to a client ping.

Client -> server frames:

* ``{"type": "subscribe", "task_id": "<id>"}`` -- follow another task; its
  current snapshot is sent immediately.
* ``{"type": "ping"}``
"""

import asyncio
import json
import logging
import os
from typing import Any

import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service

logger = logging.getLogger(__name__)

ws_router = APIRouter()

# Reuse the Celery broker URL for the async Redis pub/sub client.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")

# Seconds the forwarder waits per poll of the pub/sub connection.
_PUBSUB_POLL_TIMEOUT = 1.0


async def _authenticate_ws(token: str) -> User | None:
    """Verify a JWT using the same issuer dispatch as ``api/auth.py``.

    The unverified ``iss`` claim only selects the verifier: tokens whose
    issuer is ``JWT_ISSUER`` go to the passkey verifier, all others to the
    Authentik verifier (which presumably perform the signature check, since
    the first decode here skips it).

    Returns:
        The authenticated ``User``, or ``None`` if decoding or verification
        raised.
    """
    try:
        unverified = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        if unverified.get("iss", "") == JWT_ISSUER:
            return _verify_passkey_token(token)
        return await _verify_authentik_token(token)
    except Exception:
        # Any failure means "not authenticated"; the caller closes the socket.
        logger.debug("WebSocket token verification failed", exc_info=True)
        return None


async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Build a task snapshot in a worker thread so the event loop is not blocked."""
    return await asyncio.to_thread(_build_task_snapshot_sync, task_id)


def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
    """Return a dict describing the current status of ``task_id``.

    Calls ``task_service.get_task_status`` synchronously; use
    ``_build_task_snapshot`` from async code. When the task result is a dict,
    its keys are merged in and override the base fields.
    """
    status = task_service.get_task_status(task_id)
    snapshot: dict[str, Any] = {
        "task_id": status.task_id,
        "status": status.status,
        "progress": status.progress,
        "processed": status.processed,
        "total": status.total,
        "message": status.message,
    }
    if isinstance(status.result, dict):
        snapshot.update(status.result)
    return snapshot


async def _forward_pubsub(
    websocket: WebSocket, pubsub: "aioredis.client.PubSub"
) -> None:
    """Relay Redis pub/sub messages to the client as ``task_update`` frames.

    Returns when sending to the client fails; Redis errors propagate.
    """
    while True:
        if pubsub.connection is None:
            # Nothing subscribed yet (the user had no tasks at connect time).
            # redis.asyncio's get_message() raises RuntimeError without a
            # connection, so wait for the client to send a "subscribe".
            await asyncio.sleep(_PUBSUB_POLL_TIMEOUT)
            continue
        message = await pubsub.get_message(
            ignore_subscribe_messages=True, timeout=_PUBSUB_POLL_TIMEOUT
        )
        if not message or message["type"] != "message":
            continue
        try:
            data = json.loads(message["data"])
        except ValueError:  # includes json.JSONDecodeError
            logger.debug("Malformed pubsub message, skipping")
            continue
        if not isinstance(data, dict):
            # Only JSON objects can be merged into the outgoing frame.
            logger.debug("Malformed pubsub message, skipping")
            continue
        try:
            await websocket.send_json({"type": "task_update", **data})
        except Exception:
            return


async def _handle_client_messages(
    websocket: WebSocket,
    pubsub: "aioredis.client.PubSub",
    subscribed_channels: set[str],
) -> None:
    """Serve client ``subscribe`` and ``ping`` frames until the client leaves.

    ``subscribed_channels`` is updated in place with every new channel so the
    caller can unsubscribe on teardown.

    Raises:
        WebSocketDisconnect: when the client disconnects.
    """
    while True:
        try:
            raw = await websocket.receive_text()
        except WebSocketDisconnect:
            raise
        except RuntimeError:
            # Starlette raises RuntimeError (without awaiting) once the socket
            # can no longer be read; retrying would block the event loop.
            return
        except Exception:
            # E.g. a non-text frame: ignore it and keep serving.
            continue
        try:
            msg = json.loads(raw)
        except Exception:
            continue
        if not isinstance(msg, dict):
            continue

        msg_type = msg.get("type")
        if msg_type == "subscribe":
            task_id = msg.get("task_id")
            if not isinstance(task_id, str) or not task_id:
                continue
            # NOTE(review): any task id is accepted here, not only ids from
            # get_user_tasks() for this user, so a client that knows another
            # user's task id can follow it. Confirm whether an ownership check
            # belongs here.
            channel = f"task_progress:{task_id}"
            if channel not in subscribed_channels:
                await pubsub.subscribe(channel)
                subscribed_channels.add(channel)
            # Send the current state now instead of waiting for the next publish.
            try:
                snapshot = await _build_task_snapshot(task_id)
                await websocket.send_json({"type": "task_update", **snapshot})
            except Exception:
                logger.debug(
                    "Failed to send snapshot for task %s", task_id, exc_info=True
                )
        elif msg_type == "ping":
            try:
                await websocket.send_json({"type": "pong"})
            except Exception:
                return


async def _close_pubsub(
    redis_client: "aioredis.Redis",
    pubsub: "aioredis.client.PubSub",
    channels: set[str],
) -> None:
    """Best-effort teardown of per-connection Redis state; never raises."""
    try:
        # Skip when nothing was subscribed: unsubscribe() would otherwise
        # open a connection just to tear it down.
        if channels and pubsub.connection is not None:
            await pubsub.unsubscribe(*channels)
    except Exception:
        logger.debug("Failed to unsubscribe pubsub channels", exc_info=True)
    try:
        await pubsub.aclose()
    except Exception:
        logger.debug("Failed to close pubsub", exc_info=True)
    try:
        await redis_client.aclose()
    except Exception:
        logger.debug("Failed to close Redis client", exc_info=True)


@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream progress updates for the authenticated user's tasks.

    Before the handshake is accepted the socket is closed with code 4001 if
    the ``token`` query parameter is missing, or 4003 if it fails
    verification.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return
    user = await _authenticate_ws(token)
    if user is None:
        await websocket.close(code=4003, reason="Invalid token")
        return

    await websocket.accept()
    # get_user_tasks is a synchronous call, so run it off the loop like the
    # snapshots; list() because the ids are iterated twice below.
    task_ids = list(
        await asyncio.to_thread(task_service.get_user_tasks, user.email)
    )

    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()
    ws_tasks: list[asyncio.Task[None]] = []
    try:
        # Subscribe FIRST so updates published while snapshots are being
        # built are not lost: they queue on the pub/sub connection and are
        # forwarded after the init message.
        channels = {f"task_progress:{tid}" for tid in task_ids}
        if channels:
            await pubsub.subscribe(*channels)
            subscribed_channels.update(channels)

        snapshots = []
        for tid in task_ids:
            try:
                snapshots.append(await _build_task_snapshot(tid))
            except Exception:
                logger.debug(
                    "Failed to build snapshot for task %s", tid, exc_info=True
                )

        try:
            await websocket.send_json({"type": "init", "tasks": snapshots})
        except Exception:
            return  # client already gone; cleanup happens in `finally`

        ws_tasks = [
            asyncio.create_task(_forward_pubsub(websocket, pubsub)),
            asyncio.create_task(
                _handle_client_messages(websocket, pubsub, subscribed_channels)
            ),
        ]
        done, _pending = await asyncio.wait(
            ws_tasks, return_when=asyncio.FIRST_COMPLETED
        )
        # Log non-trivial errors from the task(s) that ended the session.
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.debug("WS task ended with error: %s", exc, exc_info=exc)
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.warning("WebSocket task stream failed", exc_info=True)
    finally:
        # asyncio.wait() does not cancel its tasks (not even when this
        # coroutine is cancelled), so stop both and let them finish before
        # the pub/sub connection is closed underneath them.
        for task in ws_tasks:
            task.cancel()
        if ws_tasks:
            await asyncio.gather(*ws_tasks, return_exceptions=True)
        await _close_pubsub(redis_client, pubsub, subscribed_channels)