From 791b5a9d55a84747ce883afd10e840a304cb48e1 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Mon, 9 Feb 2026 22:48:57 +0000 Subject: [PATCH] Fix real-time task progress by closing WS on pubsub exit and keeping polling active Three interconnected bugs prevented progress updates from reaching the frontend: 1. _forward_pubsub could exit silently while _handle_client_messages kept the WebSocket alive (responding to pings), so the client never detected the broken forwarding path. Replace asyncio.gather with asyncio.wait (FIRST_COMPLETED) so both coroutines are cancelled together. 2. Polling was stopped on WS connect with no fallback if forwarding broke. Now polling runs always alongside WebSocket as a safety net. 3. Redis publish failures in task_progress_publisher were logged at DEBUG and the broken client was reused forever. Log at WARNING and reset the client so the next call reconnects. --- api/ws_routes.py | 59 +++-- frontend/src/hooks/useTaskProgress.ts | 313 ++++++++++++++++++++++++++ services/task_progress_publisher.py | 9 +- 3 files changed, 362 insertions(+), 19 deletions(-) create mode 100644 frontend/src/hooks/useTaskProgress.ts diff --git a/api/ws_routes.py b/api/ws_routes.py index 1e90741..c279858 100644 --- a/api/ws_routes.py +++ b/api/ws_routes.py @@ -42,6 +42,11 @@ async def _authenticate_ws(token: str) -> User | None: async def _build_task_snapshot(task_id: str) -> dict[str, Any]: """Build a snapshot of a task's current status for the init message.""" + return await asyncio.to_thread(_build_task_snapshot_sync, task_id) + + +def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]: + """Synchronous helper — runs in a thread to avoid blocking the loop.""" status = task_service.get_task_status(task_id) result: dict[str, Any] = { "task_id": status.task_id, @@ -70,21 +75,11 @@ async def ws_task_progress(websocket: WebSocket) -> None: await websocket.accept() - # Get user's tasks and send initial snapshot task_ids = 
task_service.get_user_tasks(user.email) - snapshots = [] - for tid in task_ids: - try: - snapshots.append(await _build_task_snapshot(tid)) - except Exception: - logger.debug("Failed to build snapshot for task %s", tid, exc_info=True) - try: - await websocket.send_json({"type": "init", "tasks": snapshots}) - except Exception: - return - - # Subscribe to Redis pub/sub channels for each task + # Subscribe to Redis pub/sub FIRST so no updates are lost while + # building snapshots. Messages that arrive between subscribe and + # the init send are buffered by Redis and forwarded afterwards. redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True) pubsub = redis_client.pubsub() @@ -94,6 +89,24 @@ async def ws_task_progress(websocket: WebSocket) -> None: await pubsub.subscribe(channel) subscribed_channels.add(channel) + # Now build snapshots (safe — pub/sub is already active) + # _build_task_snapshot calls synchronous Celery APIs, so run in a + # thread to avoid blocking the event loop. 
+ snapshots = [] + for tid in task_ids: + try: + snapshots.append(await asyncio.to_thread(_build_task_snapshot_sync, tid)) + except Exception: + logger.debug("Failed to build snapshot for task %s", tid, exc_info=True) + + try: + await websocket.send_json({"type": "init", "tasks": snapshots}) + except Exception: + await pubsub.unsubscribe(*subscribed_channels) + await pubsub.close() + await redis_client.aclose() + return + async def _forward_pubsub() -> None: """Read from Redis pub/sub and forward to the WebSocket.""" while True: @@ -103,6 +116,10 @@ async def ws_task_progress(websocket: WebSocket) -> None: if message and message["type"] == "message": try: data = json.loads(message["data"]) + except (json.JSONDecodeError, ValueError): + logger.debug("Malformed pubsub message, skipping") + continue + try: await websocket.send_json({"type": "task_update", **data}) except Exception: break @@ -141,10 +158,20 @@ async def ws_task_progress(websocket: WebSocket) -> None: break try: - await asyncio.gather( - _forward_pubsub(), - _handle_client_messages(), + ws_tasks = [ + asyncio.create_task(_forward_pubsub()), + asyncio.create_task(_handle_client_messages()), + ] + done, pending = await asyncio.wait( + ws_tasks, return_when=asyncio.FIRST_COMPLETED ) + for t in pending: + t.cancel() + # Log non-trivial errors from the completed task(s) + for t in done: + exc = t.exception() + if exc and not isinstance(exc, (WebSocketDisconnect, asyncio.CancelledError)): + logger.debug("WS task ended with error: %s", exc) except (WebSocketDisconnect, Exception): pass finally: diff --git a/frontend/src/hooks/useTaskProgress.ts b/frontend/src/hooks/useTaskProgress.ts new file mode 100644 index 0000000..0282210 --- /dev/null +++ b/frontend/src/hooks/useTaskProgress.ts @@ -0,0 +1,313 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import type { AuthUser } from '@/auth/types'; +import type { TaskState, TaskStatusResponse, WSMessage } from '@/types'; +import { WS_TASKS_PATH, 
POLLING_INTERVALS } from '@/constants'; +import { fetchTasksForUser, fetchTaskStatus } from '@/services'; + +const KEEPALIVE_MS = 30_000; +const MAX_RECONNECT_DELAY_MS = 30_000; + +function wsUrl(token: string): string { + const proto = window.location.protocol === 'https:' ? 'wss' : 'ws'; + return `${proto}://${window.location.host}${WS_TASKS_PATH}?token=${encodeURIComponent(token)}`; +} + +function isTerminalStatus(status: string): boolean { + return status === 'SUCCESS' || status === 'FAILURE' || status === 'REVOKED'; +} + +/** Convert an HTTP TaskStatusResponse into the canonical TaskState shape. */ +function httpResponseToTaskState(resp: TaskStatusResponse): TaskState { + const state: TaskState = { + task_id: resp.task_id, + status: resp.status, + progress: resp.progress ?? undefined, + processed: resp.processed ?? undefined, + total: resp.total ?? undefined, + message: resp.message ?? undefined, + }; + + // Parse the result JSON for detailed phase info + if (resp.result) { + try { + const parsed = JSON.parse(resp.result); + if (typeof parsed === 'object' && parsed !== null) { + if (parsed.phase) state.phase = parsed.phase; + if (parsed.message && !state.message) state.message = parsed.message; + if (parsed.progress !== undefined && state.progress === undefined) + state.progress = parsed.progress; + if (parsed.processed !== undefined && state.processed === undefined) + state.processed = parsed.processed; + if (parsed.total !== undefined && state.total === undefined) + state.total = parsed.total; + if (parsed.subqueries_probed !== undefined) + state.subqueries_probed = parsed.subqueries_probed; + if (parsed.subqueries_initial !== undefined) + state.subqueries_initial = parsed.subqueries_initial; + if (parsed.estimated_results !== undefined) + state.estimated_results = parsed.estimated_results; + if (parsed.subqueries_total !== undefined) + state.subqueries_total = parsed.subqueries_total; + if (parsed.subqueries_completed !== undefined) + 
state.subqueries_completed = parsed.subqueries_completed; + if (parsed.ids_collected !== undefined) + state.ids_collected = parsed.ids_collected; + if (parsed.pages_fetched !== undefined) + state.pages_fetched = parsed.pages_fetched; + if (parsed.fetching_done !== undefined) + state.fetching_done = parsed.fetching_done; + if (parsed.details_fetched !== undefined) + state.details_fetched = parsed.details_fetched; + if (parsed.images_downloaded !== undefined) + state.images_downloaded = parsed.images_downloaded; + if (parsed.ocr_completed !== undefined) + state.ocr_completed = parsed.ocr_completed; + if (parsed.failed !== undefined) state.failed = parsed.failed; + if (parsed.elapsed_seconds !== undefined) + state.elapsed_seconds = parsed.elapsed_seconds; + if (parsed.rate_per_second !== undefined) + state.rate_per_second = parsed.rate_per_second; + if (parsed.eta_seconds !== undefined) + state.eta_seconds = parsed.eta_seconds; + if (parsed.distances_computed !== undefined) + state.distances_computed = parsed.distances_computed; + if (Array.isArray(parsed.logs)) state.logs = parsed.logs; + } + } catch { + // Ignore parse errors + } + } + + return state; +} + +export interface UseTaskProgressReturn { + tasks: Record; + isConnected: boolean; + subscribe: (taskId: string) => void; + cancelTask: (taskId: string) => Promise; + clearAllTasks: () => Promise; +} + +export function useTaskProgress(user: AuthUser | null): UseTaskProgressReturn { + const [tasks, setTasks] = useState>({}); + const [isConnected, setIsConnected] = useState(false); + + const wsRef = useRef(null); + const reconnectAttempt = useRef(0); + const reconnectTimer = useRef | null>(null); + const keepaliveTimer = useRef | null>(null); + const pollingTimer = useRef | null>(null); + const mountedRef = useRef(true); + const pendingSubscriptions = useRef>(new Set()); + // Refs to break the dependency cycle: callbacks read current values + // from refs so they don't need reactive deps that trigger re-creation. 
+ const userRef = useRef(user); + userRef.current = user; + const tasksRef = useRef(tasks); + tasksRef.current = tasks; + + // ---- Polling fallback ---- + + const stopPolling = useCallback(() => { + if (pollingTimer.current) { + clearInterval(pollingTimer.current); + pollingTimer.current = null; + } + }, []); + + const startPolling = useCallback(() => { + const currentUser = userRef.current; + if (!currentUser || pollingTimer.current) return; + + const fetchAndPoll = async () => { + const u = userRef.current; + if (!mountedRef.current || !u) return; + try { + const taskIds = await fetchTasksForUser(u); + if (!mountedRef.current) return; + + // Also include locally-known non-terminal task IDs + const localIds = Object.entries(tasksRef.current) + .filter(([, t]) => !isTerminalStatus(t.status)) + .map(([id]) => id); + + const allIds = [...new Set([...taskIds, ...localIds])]; + if (allIds.length === 0) return; + + const results = await Promise.allSettled( + allIds.map((tid) => fetchTaskStatus(u, tid)), + ); + + if (!mountedRef.current) return; + + setTasks((prev) => { + const next = { ...prev }; + for (const r of results) { + if (r.status === 'fulfilled') { + const state = httpResponseToTaskState(r.value); + next[state.task_id] = { ...prev[state.task_id], ...state }; + } + } + return next; + }); + } catch { + // Ignore polling errors + } + }; + + fetchAndPoll(); + pollingTimer.current = setInterval(fetchAndPoll, POLLING_INTERVALS.TASK_STATUS_MS); + }, []); // No reactive deps — reads from refs + + // ---- WebSocket connection ---- + // Stable connect function — only re-created when user identity changes. 
+ const connect = useCallback(() => { + if (!user) return; + + const ws = new WebSocket(wsUrl(user.accessToken)); + wsRef.current = ws; + + ws.onopen = () => { + if (!mountedRef.current) return; + setIsConnected(true); + reconnectAttempt.current = 0; + + // Start keepalive pings + keepaliveTimer.current = setInterval(() => { + if (ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: 'ping' })); + } + }, KEEPALIVE_MS); + + // Send pending subscriptions + for (const taskId of pendingSubscriptions.current) { + ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId })); + } + pendingSubscriptions.current.clear(); + }; + + ws.onmessage = (event) => { + if (!mountedRef.current) return; + try { + const msg: WSMessage = JSON.parse(event.data); + + if (msg.type === 'init') { + const initial: Record = {}; + for (const t of msg.tasks) { + initial[t.task_id] = t; + } + setTasks(initial); + } else if (msg.type === 'task_update') { + const { type: _, ...taskData } = msg; + setTasks((prev) => ({ + ...prev, + [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState, + })); + } + } catch { + // Ignore malformed messages + } + }; + + ws.onclose = () => { + if (!mountedRef.current) return; + setIsConnected(false); + + if (keepaliveTimer.current) { + clearInterval(keepaliveTimer.current); + keepaliveTimer.current = null; + } + + // Fallback to polling while WS is down + startPolling(); + + // Exponential backoff reconnect + const delay = Math.min( + 1000 * 2 ** reconnectAttempt.current, + MAX_RECONNECT_DELAY_MS, + ); + reconnectAttempt.current += 1; + reconnectTimer.current = setTimeout(() => { + if (mountedRef.current && userRef.current) connect(); + }, delay); + }; + + ws.onerror = () => { + // onclose will fire after this, triggering reconnect + }; + }, [user, startPolling]); // startPolling is stable (no deps) + + // Mount/unmount + reconnect when user changes + useEffect(() => { + mountedRef.current = true; + connect(); + startPolling(); // Always 
run polling as safety net alongside WebSocket + + return () => { + mountedRef.current = false; + if (reconnectTimer.current) { + clearTimeout(reconnectTimer.current); + reconnectTimer.current = null; + } + if (keepaliveTimer.current) { + clearInterval(keepaliveTimer.current); + keepaliveTimer.current = null; + } + stopPolling(); + if (wsRef.current) { + wsRef.current.close(); + wsRef.current = null; + } + }; + }, [connect, startPolling, stopPolling]); + + // ---- Public API ---- + + const subscribe = useCallback((taskId: string) => { + const ws = wsRef.current; + if (ws && ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId })); + } else { + pendingSubscriptions.current.add(taskId); + } + }, []); + + const cancelTask = useCallback( + async (taskId: string): Promise => { + if (!user) return false; + try { + const { cancelTask: cancel } = await import('@/services'); + const result = await cancel(user, taskId); + if (result.success) { + setTasks((prev) => ({ + ...prev, + [taskId]: { ...prev[taskId], task_id: taskId, status: 'REVOKED' }, + })); + return true; + } + return false; + } catch { + return false; + } + }, + [user], + ); + + const clearAllTasks = useCallback(async (): Promise => { + if (!user) return false; + try { + const { clearAllTasks: clearAll } = await import('@/services'); + const result = await clearAll(user); + if (result.success) { + setTasks({}); + return true; + } + return false; + } catch { + return false; + } + }, [user]); + + return { tasks, isConnected, subscribe, cancelTask, clearAllTasks }; +} diff --git a/services/task_progress_publisher.py b/services/task_progress_publisher.py index 52a25a3..d746211 100644 --- a/services/task_progress_publisher.py +++ b/services/task_progress_publisher.py @@ -35,8 +35,9 @@ def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> Non state: Celery state string (e.g. 'PROGRESS', 'SUCCESS'). 
meta: Metadata dict (progress, phase, logs, counters, etc.). - Failures are caught and logged at DEBUG level so they never break the - critical task execution path. + Failures are caught and logged at WARNING level so they never break the + critical task execution path. The Redis client is reset on failure so + subsequent calls can reconnect. """ try: client = _get_redis_client() @@ -47,4 +48,8 @@ def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> Non }) client.publish(f"task_progress:{task_id}", payload) except Exception: - logger.debug("Failed to publish task progress for %s", task_id, exc_info=True) + logger.warning("Failed to publish task progress for %s", task_id, exc_info=True) + # Reset client so next call creates a fresh connection. + # ``global`` is required: a bare assignment would only bind a local name. + global _redis_client + _redis_client = None