Fix real-time task progress by closing WS on pubsub exit and keeping polling active

Three interconnected bugs prevented progress updates from reaching the frontend:

1. _forward_pubsub could exit silently while _handle_client_messages kept
   the WebSocket alive (responding to pings), so the client never detected
   the broken forwarding path. Replace asyncio.gather with asyncio.wait
   (FIRST_COMPLETED) so that as soon as either coroutine finishes, the
   other one is cancelled and the WebSocket is closed.

2. Polling was stopped on WS connect, leaving no fallback if forwarding broke.
   Polling now always runs alongside the WebSocket as a safety net.

3. Redis publish failures in task_progress_publisher were logged at DEBUG
   and the broken client was reused forever. Log at WARNING and reset the
   client so the next call reconnects.
This commit is contained in:
Viktor Barzin 2026-02-09 22:48:57 +00:00
parent 8d52bdf99d
commit 791b5a9d55
No known key found for this signature in database
GPG key ID: 0EB088298288D958
3 changed files with 362 additions and 19 deletions

View file

@ -42,6 +42,11 @@ async def _authenticate_ws(token: str) -> User | None:
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
"""Build a snapshot of a task's current status for the init message."""
return await asyncio.to_thread(_build_task_snapshot_sync, task_id)
def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
"""Synchronous helper — runs in a thread to avoid blocking the loop."""
status = task_service.get_task_status(task_id)
result: dict[str, Any] = {
"task_id": status.task_id,
@ -70,21 +75,11 @@ async def ws_task_progress(websocket: WebSocket) -> None:
await websocket.accept()
# Get user's tasks and send initial snapshot
task_ids = task_service.get_user_tasks(user.email)
snapshots = []
for tid in task_ids:
try:
snapshots.append(await _build_task_snapshot(tid))
except Exception:
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
try:
await websocket.send_json({"type": "init", "tasks": snapshots})
except Exception:
return
# Subscribe to Redis pub/sub channels for each task
# Subscribe to Redis pub/sub FIRST so no updates are lost while
# building snapshots. Messages that arrive between subscribe and
# the init send are buffered by Redis and forwarded afterwards.
redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
pubsub = redis_client.pubsub()
@ -94,6 +89,24 @@ async def ws_task_progress(websocket: WebSocket) -> None:
await pubsub.subscribe(channel)
subscribed_channels.add(channel)
# Now build snapshots (safe — pub/sub is already active)
# _build_task_snapshot calls synchronous Celery APIs, so run in a
# thread to avoid blocking the event loop.
snapshots = []
for tid in task_ids:
try:
snapshots.append(await asyncio.to_thread(_build_task_snapshot_sync, tid))
except Exception:
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
try:
await websocket.send_json({"type": "init", "tasks": snapshots})
except Exception:
await pubsub.unsubscribe(*subscribed_channels)
await pubsub.close()
await redis_client.aclose()
return
async def _forward_pubsub() -> None:
"""Read from Redis pub/sub and forward to the WebSocket."""
while True:
@ -103,6 +116,10 @@ async def ws_task_progress(websocket: WebSocket) -> None:
if message and message["type"] == "message":
try:
data = json.loads(message["data"])
except (json.JSONDecodeError, ValueError):
logger.debug("Malformed pubsub message, skipping")
continue
try:
await websocket.send_json({"type": "task_update", **data})
except Exception:
break
@ -141,10 +158,20 @@ async def ws_task_progress(websocket: WebSocket) -> None:
break
try:
await asyncio.gather(
_forward_pubsub(),
_handle_client_messages(),
ws_tasks = [
asyncio.create_task(_forward_pubsub()),
asyncio.create_task(_handle_client_messages()),
]
done, pending = await asyncio.wait(
ws_tasks, return_when=asyncio.FIRST_COMPLETED
)
for t in pending:
t.cancel()
# Log non-trivial errors from the completed task(s)
for t in done:
exc = t.exception()
if exc and not isinstance(exc, (WebSocketDisconnect, asyncio.CancelledError)):
logger.debug("WS task ended with error: %s", exc)
except (WebSocketDisconnect, Exception):
pass
finally:

View file

@ -0,0 +1,313 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import type { AuthUser } from '@/auth/types';
import type { TaskState, TaskStatusResponse, WSMessage } from '@/types';
import { WS_TASKS_PATH, POLLING_INTERVALS } from '@/constants';
import { fetchTasksForUser, fetchTaskStatus } from '@/services';
/** Interval, in milliseconds, between keepalive `ping` messages sent over an open WebSocket. */
const KEEPALIVE_MS = 30_000;
/** Upper bound, in milliseconds, on the exponential-backoff delay between WebSocket reconnect attempts. */
const MAX_RECONNECT_DELAY_MS = 30_000;
/**
 * Build the task-progress WebSocket URL for the page's current host.
 *
 * The scheme follows the page: `wss` when the page was loaded over HTTPS,
 * `ws` otherwise. The access token is passed URL-encoded in the `token`
 * query parameter.
 *
 * @param token - Access token to authenticate the socket with.
 * @returns Absolute WebSocket URL including the encoded token.
 */
function wsUrl(token: string): string {
  const isSecurePage = window.location.protocol === 'https:';
  const origin = `${isSecurePage ? 'wss' : 'ws'}://${window.location.host}`;
  const encodedToken = encodeURIComponent(token);
  return `${origin}${WS_TASKS_PATH}?token=${encodedToken}`;
}
/**
 * Tell whether a task status string denotes a finished task.
 *
 * @param status - Status string to classify (compared case-sensitively).
 * @returns `true` for `SUCCESS`, `FAILURE` or `REVOKED`; `false` for anything else.
 */
function isTerminalStatus(status: string): boolean {
  switch (status) {
    case 'SUCCESS':
    case 'FAILURE':
    case 'REVOKED':
      return true;
    default:
      return false;
  }
}
/**
 * Convert an HTTP `TaskStatusResponse` into the canonical `TaskState` shape.
 *
 * The top-level response fields are copied first, with `null` normalised to
 * `undefined`. When `resp.result` is a JSON string that parses to an object,
 * that object then contributes:
 *  - `phase`, when truthy;
 *  - `message`, `progress`, `processed` and `total`, only as fallbacks for
 *    values the top-level response did not provide;
 *  - the detailed counters (sub-queries, pages, images, OCR, timing, ...),
 *    whenever they are present;
 *  - `logs`, when it is an array.
 *
 * A missing, unparseable or non-object `result` is ignored.
 *
 * @param resp - Task status as returned by the HTTP API.
 * @returns The equivalent task state.
 */
function httpResponseToTaskState(resp: TaskStatusResponse): TaskState {
  const { task_id, status, progress, processed, total, message } = resp;
  const state: TaskState = {
    task_id,
    status,
    progress: progress ?? undefined,
    processed: processed ?? undefined,
    total: total ?? undefined,
    message: message ?? undefined,
  };

  // Nothing further to merge without a result payload.
  if (!resp.result) {
    return state;
  }

  try {
    const details = JSON.parse(resp.result);
    if (details === null || typeof details !== 'object') {
      return state;
    }

    if (details.phase) {
      state.phase = details.phase;
    }

    // Fallbacks: only used when the top-level response had no value.
    if (!state.message && details.message) {
      state.message = details.message;
    }
    if (state.progress === undefined && details.progress !== undefined) {
      state.progress = details.progress;
    }
    if (state.processed === undefined && details.processed !== undefined) {
      state.processed = details.processed;
    }
    if (state.total === undefined && details.total !== undefined) {
      state.total = details.total;
    }

    // Detailed counters: copied whenever the payload defines them.
    if (details.subqueries_probed !== undefined) {
      state.subqueries_probed = details.subqueries_probed;
    }
    if (details.subqueries_initial !== undefined) {
      state.subqueries_initial = details.subqueries_initial;
    }
    if (details.estimated_results !== undefined) {
      state.estimated_results = details.estimated_results;
    }
    if (details.subqueries_total !== undefined) {
      state.subqueries_total = details.subqueries_total;
    }
    if (details.subqueries_completed !== undefined) {
      state.subqueries_completed = details.subqueries_completed;
    }
    if (details.ids_collected !== undefined) {
      state.ids_collected = details.ids_collected;
    }
    if (details.pages_fetched !== undefined) {
      state.pages_fetched = details.pages_fetched;
    }
    if (details.fetching_done !== undefined) {
      state.fetching_done = details.fetching_done;
    }
    if (details.details_fetched !== undefined) {
      state.details_fetched = details.details_fetched;
    }
    if (details.images_downloaded !== undefined) {
      state.images_downloaded = details.images_downloaded;
    }
    if (details.ocr_completed !== undefined) {
      state.ocr_completed = details.ocr_completed;
    }
    if (details.failed !== undefined) {
      state.failed = details.failed;
    }
    if (details.elapsed_seconds !== undefined) {
      state.elapsed_seconds = details.elapsed_seconds;
    }
    if (details.rate_per_second !== undefined) {
      state.rate_per_second = details.rate_per_second;
    }
    if (details.eta_seconds !== undefined) {
      state.eta_seconds = details.eta_seconds;
    }
    if (details.distances_computed !== undefined) {
      state.distances_computed = details.distances_computed;
    }
    if (Array.isArray(details.logs)) {
      state.logs = details.logs;
    }
  } catch {
    // Ignore parse errors
  }

  return state;
}
/** Value returned by the `useTaskProgress` hook. */
export interface UseTaskProgressReturn {
/** Latest known state of each task, keyed by task id. */
tasks: Record<string, TaskState>;
/** `true` after the WebSocket has opened, `false` again once it closes. */
isConnected: boolean;
/**
 * Send a `subscribe` message for `taskId` over the WebSocket. If the socket
 * is not open, the id is queued and sent when the socket next opens.
 */
subscribe: (taskId: string) => void;
/**
 * Cancel a task through the services API. Resolves to `true` (and marks the
 * task `REVOKED` locally) when the service reports success; resolves to
 * `false` otherwise, including when there is no user or the call throws.
 */
cancelTask: (taskId: string) => Promise<boolean>;
/**
 * Clear all tasks through the services API. Resolves to `true` (and empties
 * the local task map) when the service reports success; resolves to `false`
 * otherwise, including when there is no user or the call throws.
 */
clearAllTasks: () => Promise<boolean>;
}
/**
 * React hook that tracks the signed-in user's background tasks.
 *
 * Two update paths run side by side while a user is present:
 *  - a WebSocket that delivers an `init` snapshot followed by incremental
 *    `task_update` messages, reconnecting with exponential backoff
 *    (capped at `MAX_RECONNECT_DELAY_MS`) whenever it closes;
 *  - HTTP polling every `POLLING_INTERVALS.TASK_STATUS_MS`, kept running as
 *    a safety net even while the socket is open.
 *
 * Socket event handlers only act while their socket is the one currently
 * stored in `wsRef`. Without that check, the `close` event of a socket that
 * was closed by effect cleanup (user change, or the development-only
 * StrictMode re-run of effects) arrives after the effect has already set
 * `mountedRef` back to `true`, and would schedule a reconnect using the
 * previous user's token.
 *
 * @param user - Authenticated user, or `null` when signed out (no socket and
 *   no polling are started in that case).
 * @returns Task map, connection flag and task actions.
 */
export function useTaskProgress(user: AuthUser | null): UseTaskProgressReturn {
  const [tasks, setTasks] = useState<Record<string, TaskState>>({});
  const [isConnected, setIsConnected] = useState(false);

  const wsRef = useRef<WebSocket | null>(null);
  const reconnectAttempt = useRef(0);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const keepaliveTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const pollingTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const mountedRef = useRef(true);
  const pendingSubscriptions = useRef<Set<string>>(new Set());

  // Refs to break the dependency cycle: callbacks read current values
  // from refs so they don't need reactive deps that trigger re-creation.
  const userRef = useRef(user);
  userRef.current = user;
  const tasksRef = useRef(tasks);
  tasksRef.current = tasks;

  // ---- Polling fallback ----

  const stopPolling = useCallback(() => {
    if (pollingTimer.current) {
      clearInterval(pollingTimer.current);
      pollingTimer.current = null;
    }
  }, []);

  const startPolling = useCallback(() => {
    const currentUser = userRef.current;
    // Idempotent: does nothing when signed out or when a poll loop already runs.
    if (!currentUser || pollingTimer.current) return;

    const fetchAndPoll = async (): Promise<void> => {
      const u = userRef.current;
      if (!mountedRef.current || !u) return;
      try {
        const taskIds = await fetchTasksForUser(u);
        if (!mountedRef.current) return;

        // Also include locally-known non-terminal task IDs
        const localIds = Object.entries(tasksRef.current)
          .filter(([, t]) => !isTerminalStatus(t.status))
          .map(([id]) => id);
        const allIds = [...new Set([...taskIds, ...localIds])];
        if (allIds.length === 0) return;

        // allSettled: one failing status request must not discard the others.
        const results = await Promise.allSettled(
          allIds.map((tid) => fetchTaskStatus(u, tid)),
        );
        if (!mountedRef.current) return;

        setTasks((prev) => {
          const next = { ...prev };
          for (const r of results) {
            if (r.status === 'fulfilled') {
              const state = httpResponseToTaskState(r.value);
              next[state.task_id] = { ...prev[state.task_id], ...state };
            }
          }
          return next;
        });
      } catch {
        // Ignore polling errors
      }
    };

    void fetchAndPoll();
    pollingTimer.current = setInterval(fetchAndPoll, POLLING_INTERVALS.TASK_STATUS_MS);
  }, []); // No reactive deps — reads from refs

  // ---- WebSocket connection ----

  // Stable connect function — only re-created when user identity changes.
  const connect = useCallback(() => {
    if (!user) return;

    const ws = new WebSocket(wsUrl(user.accessToken));
    wsRef.current = ws;

    // A socket's events may only touch hook state while the hook is mounted
    // AND that socket is still the one the hook owns. Events from a socket
    // that cleanup has already closed and replaced are ignored.
    const isCurrent = (): boolean => mountedRef.current && wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectAttempt.current = 0;

      // Start keepalive pings (dropping any leftover timer first)
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
      }
      keepaliveTimer.current = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
          ws.send(JSON.stringify({ type: 'ping' }));
        }
      }, KEEPALIVE_MS);

      // Send pending subscriptions
      for (const taskId of pendingSubscriptions.current) {
        ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
      }
      pendingSubscriptions.current.clear();
    };

    ws.onmessage = (event) => {
      if (!isCurrent()) return;
      try {
        const msg: WSMessage = JSON.parse(event.data);
        if (msg.type === 'init') {
          // The init snapshot replaces the whole task map.
          const initial: Record<string, TaskState> = {};
          for (const t of msg.tasks) {
            initial[t.task_id] = t;
          }
          setTasks(initial);
        } else if (msg.type === 'task_update') {
          const { type: _, ...taskData } = msg;
          setTasks((prev) => ({
            ...prev,
            [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState,
          }));
        }
      } catch {
        // Ignore malformed messages
      }
    };

    ws.onclose = () => {
      if (!isCurrent()) return;
      setIsConnected(false);
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }

      // Fallback to polling while WS is down
      startPolling();

      // Exponential backoff reconnect
      const delay = Math.min(
        1000 * 2 ** reconnectAttempt.current,
        MAX_RECONNECT_DELAY_MS,
      );
      reconnectAttempt.current += 1;
      reconnectTimer.current = setTimeout(() => {
        if (mountedRef.current && userRef.current) connect();
      }, delay);
    };

    ws.onerror = () => {
      // onclose will fire after this, triggering reconnect
    };
  }, [user, startPolling]); // startPolling is stable (no deps)

  // Mount/unmount + reconnect when user changes
  useEffect(() => {
    mountedRef.current = true;
    connect();
    startPolling(); // Always run polling as safety net alongside WebSocket

    return () => {
      mountedRef.current = false;
      if (reconnectTimer.current) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }
      stopPolling();
      if (wsRef.current) {
        wsRef.current.close();
        // Clearing the ref makes the closed socket's pending events stale,
        // so its onclose cannot schedule a reconnect for the old user.
        wsRef.current = null;
      }
      // The stale onclose no longer runs, so reset the flag here instead.
      setIsConnected(false);
    };
  }, [connect, startPolling, stopPolling]);

  // ---- Public API ----

  const subscribe = useCallback((taskId: string) => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
    } else {
      // Not connected yet: flushed by onopen of the next socket.
      pendingSubscriptions.current.add(taskId);
    }
  }, []);

  const cancelTask = useCallback(
    async (taskId: string): Promise<boolean> => {
      if (!user) return false;
      try {
        const { cancelTask: cancel } = await import('@/services');
        const result = await cancel(user, taskId);
        if (result.success) {
          // Reflect the cancellation locally without waiting for an update.
          setTasks((prev) => ({
            ...prev,
            [taskId]: { ...prev[taskId], task_id: taskId, status: 'REVOKED' },
          }));
          return true;
        }
        return false;
      } catch {
        return false;
      }
    },
    [user],
  );

  const clearAllTasks = useCallback(async (): Promise<boolean> => {
    if (!user) return false;
    try {
      const { clearAllTasks: clearAll } = await import('@/services');
      const result = await clearAll(user);
      if (result.success) {
        setTasks({});
        return true;
      }
      return false;
    } catch {
      return false;
    }
  }, [user]);

  return { tasks, isConnected, subscribe, cancelTask, clearAllTasks };
}

View file

@ -35,8 +35,9 @@ def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> Non
state: Celery state string (e.g. 'PROGRESS', 'SUCCESS').
meta: Metadata dict (progress, phase, logs, counters, etc.).
Failures are caught and logged at DEBUG level so they never break the
critical task execution path.
Failures are caught and logged at WARNING level so they never break the
critical task execution path. The Redis client is reset on failure so
subsequent calls can reconnect.
"""
try:
client = _get_redis_client()
@ -47,4 +48,6 @@ def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> Non
})
client.publish(f"task_progress:{task_id}", payload)
except Exception:
logger.debug("Failed to publish task progress for %s", task_id, exc_info=True)
logger.warning("Failed to publish task progress for %s", task_id, exc_info=True)
# Reset client so next call creates a fresh connection
_redis_client = None