Three interconnected bugs prevented progress updates from reaching the frontend: 1. _forward_pubsub could exit silently while _handle_client_messages kept the WebSocket alive (responding to pings), so the client never detected the broken forwarding path. Replace asyncio.gather with asyncio.wait (FIRST_COMPLETED) so both coroutines are cancelled together. 2. Polling was stopped on WS connect with no fallback if forwarding broke. Now polling runs always alongside WebSocket as a safety net. 3. Redis publish failures in task_progress_publisher were logged at DEBUG and the broken client was reused forever. Log at WARNING and reset the client so the next call reconnects.
180 lines
6.2 KiB
Python
180 lines
6.2 KiB
Python
"""WebSocket endpoint for real-time task progress updates.
|
|
|
|
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
|
|
messages published by Celery workers via Redis pub/sub.
|
|
"""
|
|
import asyncio
import json
import logging
import os
from typing import Any

import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service

logger = logging.getLogger(__name__)

# Router that carries the WebSocket route(s) declared in this module.
ws_router = APIRouter()

# Reuse the Celery broker URL for the async Redis client. Falls back to the
# "redis" host on the default port/db when CELERY_BROKER_URL is unset.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
|
|
|
|
|
|
async def _authenticate_ws(token: str) -> User | None:
    """Resolve the user behind a WebSocket JWT using the same logic as api/auth.py.

    The token is first decoded *without* signature or expiry checks, purely
    to read its ``iss`` claim and choose a verifier: tokens issued by
    ``JWT_ISSUER`` go to ``_verify_passkey_token``; everything else goes to
    ``_verify_authentik_token``. The chosen verifier's result is returned
    as-is.

    Returns:
        Whatever the selected verifier returns, or ``None`` if any step
        raises (malformed token, verifier error, ...).
    """
    # Peek-only decode: nothing read here is trusted beyond picking a verifier.
    peek_options = {"verify_signature": False, "verify_exp": False}
    try:
        claims = jwt.decode(token, options=peek_options)
        if claims.get("iss", "") != JWT_ISSUER:
            return await _verify_authentik_token(token)
        return _verify_passkey_token(token)
    except Exception:
        # Any failure is reported to the caller as "not authenticated".
        return None
|
|
|
|
|
|
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Return the current status snapshot for *task_id*.

    The actual work is done by ``_build_task_snapshot_sync``; it is pushed
    onto a worker thread with ``asyncio.to_thread`` so the event loop is not
    blocked while it runs.
    """
    snapshot = await asyncio.to_thread(_build_task_snapshot_sync, task_id)
    return snapshot
|
|
|
|
|
|
def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
    """Synchronous snapshot builder, meant to be run in a worker thread.

    Looks the task up through ``task_service.get_task_status`` and copies its
    ``task_id``, ``status``, ``progress``, ``processed``, ``total`` and
    ``message`` attributes into a dict. When the status carries a non-empty
    dict ``result``, its entries are merged on top (so they can override the
    core fields).
    """
    status = task_service.get_task_status(task_id)

    core_fields = ("task_id", "status", "progress", "processed", "total", "message")
    snapshot: dict[str, Any] = {name: getattr(status, name) for name in core_fields}

    extra = status.result
    if extra and isinstance(extra, dict):
        snapshot.update(extra)
    return snapshot
|
|
|
|
|
|
async def _release_pubsub(redis_client: Any, pubsub: Any, channels: set[str]) -> None:
    """Best-effort teardown of the per-connection Redis pub/sub resources.

    Every step runs even if an earlier one fails, so a broken Redis
    connection cannot prevent the pub/sub object and the client from being
    closed.
    """
    if channels:
        try:
            await pubsub.unsubscribe(*channels)
        except Exception:
            logger.debug("Failed to unsubscribe from task channels", exc_info=True)
    try:
        await pubsub.close()
    except Exception:
        logger.debug("Failed to close Redis pub/sub", exc_info=True)
    try:
        await redis_client.aclose()
    except Exception:
        logger.debug("Failed to close Redis client", exc_info=True)


@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream live task progress to an authenticated WebSocket client.

    Flow:
      1. The JWT is read from the ``token`` query parameter. A missing token
         closes the socket with code 4001, a rejected one with 4003.
      2. After ``accept()``, the user's task ids are fetched and one Redis
         channel ``task_progress:<task_id>`` is subscribed per task.
      3. An ``{"type": "init", "tasks": [...]}`` message with a snapshot of
         each task is sent.
      4. Two coroutines then run until either finishes: one forwards pub/sub
         messages as ``{"type": "task_update", ...}``, the other serves
         client messages (``subscribe`` with a ``task_id``, and ``ping`` which
         is answered with ``pong``).

    Redis resources and both background tasks are always released when the
    handler exits, whatever the reason.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return

    user = await _authenticate_ws(token)
    if user is None:
        await websocket.close(code=4003, reason="Invalid token")
        return

    await websocket.accept()

    # NOTE(review): this is a direct synchronous call inside the handler; if
    # get_user_tasks does blocking I/O it should move to asyncio.to_thread
    # like the snapshot builder — confirm against task_service.
    task_ids = task_service.get_user_tasks(user.email)

    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()
    ws_tasks: list[asyncio.Task[None]] = []

    async def _forward_pubsub() -> None:
        """Read from Redis pub/sub and forward to the WebSocket."""
        while True:
            if not subscribed_channels:
                # Nothing subscribed yet (user had no tasks at connect time).
                # Polling an unsubscribed PubSub is believed to raise in
                # redis-py's asyncio client, which would end this task and
                # tear the socket down — so idle until the client subscribes.
                await asyncio.sleep(0.2)
                continue

            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if not message or message["type"] != "message":
                continue

            try:
                data = json.loads(message["data"])
            except (json.JSONDecodeError, ValueError):
                logger.debug("Malformed pubsub message, skipping")
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object: it cannot be merged into the
                # update envelope, so skip it instead of ending forwarding.
                logger.debug("Malformed pubsub message, skipping")
                continue

            try:
                await websocket.send_json({"type": "task_update", **data})
            except Exception:
                # The socket is no longer writable; finishing this task makes
                # the handler shut the whole connection down.
                break

    async def _handle_client_messages() -> None:
        """Read messages from the client (subscribe, ping)."""
        while True:
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                raise
            except RuntimeError:
                # Presumably raised by Starlette once the socket is no longer
                # connected; retrying would loop without ever yielding.
                return
            except Exception:
                # Unreadable frame (e.g. not text): ignore it, but guarantee a
                # yield point so this task stays cancellable.
                await asyncio.sleep(0)
                continue

            try:
                msg = json.loads(raw)
            except (TypeError, ValueError):
                continue
            if not isinstance(msg, dict):
                # Only JSON objects carry a "type"; ignore anything else.
                continue

            msg_type = msg.get("type")
            if msg_type == "subscribe":
                new_task_id = msg.get("task_id")
                if not new_task_id:
                    continue
                # NOTE(review): no ownership check is visible here — any
                # authenticated client can subscribe to an arbitrary task id.
                # Confirm whether this should be validated against the user.
                channel = f"task_progress:{new_task_id}"
                if channel in subscribed_channels:
                    continue
                await pubsub.subscribe(channel)
                subscribed_channels.add(channel)

                # Send current snapshot for the new task
                try:
                    snapshot = await _build_task_snapshot(new_task_id)
                except Exception:
                    logger.debug(
                        "Failed to build snapshot for task %s",
                        new_task_id,
                        exc_info=True,
                    )
                    continue
                try:
                    await websocket.send_json({"type": "task_update", **snapshot})
                except Exception:
                    # Same policy as the ping branch: a failed send means the
                    # socket is gone, so stop reading.
                    return
            elif msg_type == "ping":
                try:
                    await websocket.send_json({"type": "pong"})
                except Exception:
                    break

    try:
        # Subscribe to Redis pub/sub FIRST so no updates are lost while
        # building snapshots. Messages that arrive between subscribe and
        # the init send are buffered by Redis and forwarded afterwards.
        for tid in task_ids:
            channel = f"task_progress:{tid}"
            await pubsub.subscribe(channel)
            subscribed_channels.add(channel)

        # Now build snapshots (safe — pub/sub is already active). The
        # builder runs in a thread to avoid blocking the event loop.
        snapshots = []
        for tid in task_ids:
            try:
                snapshots.append(await _build_task_snapshot(tid))
            except Exception:
                logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)

        try:
            await websocket.send_json({"type": "init", "tasks": snapshots})
        except Exception:
            # Client went away before the init message; cleanup is in finally.
            return

        ws_tasks.append(asyncio.create_task(_forward_pubsub()))
        ws_tasks.append(asyncio.create_task(_handle_client_messages()))

        # Stop as soon as EITHER side ends, so a dead forwarding path can
        # never hide behind a reader that still answers pings.
        done, _pending = await asyncio.wait(
            ws_tasks, return_when=asyncio.FIRST_COMPLETED
        )

        # Log non-trivial errors from the completed task(s)
        for finished in done:
            if finished.cancelled():
                # .exception() would raise CancelledError on a cancelled task.
                continue
            exc = finished.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.warning("WS task ended with error: %s", exc)
    finally:
        # Cancel whatever is still running and wait for it to unwind BEFORE
        # closing the pub/sub object those tasks are using. This also covers
        # the case where the handler itself is cancelled mid-wait.
        for task in ws_tasks:
            task.cancel()
        if ws_tasks:
            await asyncio.gather(*ws_tasks, return_exceptions=True)
        await _release_pubsub(redis_client, pubsub, subscribed_channels)
|