Fix real-time task progress by closing WS on pubsub exit and keeping polling active
Three interconnected bugs prevented progress updates from reaching the frontend: 1. _forward_pubsub could exit silently while _handle_client_messages kept the WebSocket alive (responding to pings), so the client never detected the broken forwarding path. Replace asyncio.gather with asyncio.wait (FIRST_COMPLETED) so both coroutines are cancelled together. 2. Polling was stopped on WS connect with no fallback if forwarding broke. Now polling runs always alongside WebSocket as a safety net. 3. Redis publish failures in task_progress_publisher were logged at DEBUG and the broken client was reused forever. Log at WARNING and reset the client so the next call reconnects.
This commit is contained in:
parent
8d52bdf99d
commit
791b5a9d55
3 changed files with 362 additions and 19 deletions
|
|
@ -42,6 +42,11 @@ async def _authenticate_ws(token: str) -> User | None:
|
|||
|
||||
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
|
||||
"""Build a snapshot of a task's current status for the init message."""
|
||||
return await asyncio.to_thread(_build_task_snapshot_sync, task_id)
|
||||
|
||||
|
||||
def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
|
||||
"""Synchronous helper — runs in a thread to avoid blocking the loop."""
|
||||
status = task_service.get_task_status(task_id)
|
||||
result: dict[str, Any] = {
|
||||
"task_id": status.task_id,
|
||||
|
|
@ -70,21 +75,11 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
|
||||
await websocket.accept()
|
||||
|
||||
# Get user's tasks and send initial snapshot
|
||||
task_ids = task_service.get_user_tasks(user.email)
|
||||
snapshots = []
|
||||
for tid in task_ids:
|
||||
try:
|
||||
snapshots.append(await _build_task_snapshot(tid))
|
||||
except Exception:
|
||||
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "init", "tasks": snapshots})
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# Subscribe to Redis pub/sub channels for each task
|
||||
# Subscribe to Redis pub/sub FIRST so no updates are lost while
|
||||
# building snapshots. Messages that arrive between subscribe and
|
||||
# the init send are buffered by Redis and forwarded afterwards.
|
||||
redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
|
||||
pubsub = redis_client.pubsub()
|
||||
|
||||
|
|
@ -94,6 +89,24 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
await pubsub.subscribe(channel)
|
||||
subscribed_channels.add(channel)
|
||||
|
||||
# Now build snapshots (safe — pub/sub is already active)
|
||||
# _build_task_snapshot calls synchronous Celery APIs, so run in a
|
||||
# thread to avoid blocking the event loop.
|
||||
snapshots = []
|
||||
for tid in task_ids:
|
||||
try:
|
||||
snapshots.append(await asyncio.to_thread(_build_task_snapshot_sync, tid))
|
||||
except Exception:
|
||||
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "init", "tasks": snapshots})
|
||||
except Exception:
|
||||
await pubsub.unsubscribe(*subscribed_channels)
|
||||
await pubsub.close()
|
||||
await redis_client.aclose()
|
||||
return
|
||||
|
||||
async def _forward_pubsub() -> None:
|
||||
"""Read from Redis pub/sub and forward to the WebSocket."""
|
||||
while True:
|
||||
|
|
@ -103,6 +116,10 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
if message and message["type"] == "message":
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed pubsub message, skipping")
|
||||
continue
|
||||
try:
|
||||
await websocket.send_json({"type": "task_update", **data})
|
||||
except Exception:
|
||||
break
|
||||
|
|
@ -141,10 +158,20 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
break
|
||||
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_forward_pubsub(),
|
||||
_handle_client_messages(),
|
||||
ws_tasks = [
|
||||
asyncio.create_task(_forward_pubsub()),
|
||||
asyncio.create_task(_handle_client_messages()),
|
||||
]
|
||||
done, pending = await asyncio.wait(
|
||||
ws_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
# Log non-trivial errors from the completed task(s)
|
||||
for t in done:
|
||||
exc = t.exception()
|
||||
if exc and not isinstance(exc, (WebSocketDisconnect, asyncio.CancelledError)):
|
||||
logger.debug("WS task ended with error: %s", exc)
|
||||
except (WebSocketDisconnect, Exception):
|
||||
pass
|
||||
finally:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue