Fix real-time task progress by closing WS on pubsub exit and keeping polling active
Three interconnected bugs prevented progress updates from reaching the frontend: 1. _forward_pubsub could exit silently while _handle_client_messages kept the WebSocket alive (it kept answering pings), so the client never detected that forwarding had stopped. Replace asyncio.gather with asyncio.wait(return_when=FIRST_COMPLETED) so that when either coroutine finishes, the other is cancelled and the WebSocket is closed. 2. Polling was stopped as soon as the WebSocket connected, leaving no fallback if forwarding broke. Polling now always runs alongside the WebSocket as a safety net. 3. Redis publish failures in task_progress_publisher were logged at DEBUG and the broken client was reused indefinitely. Log them at WARNING and reset the client so the next call reconnects.
This commit is contained in:
parent
8d52bdf99d
commit
791b5a9d55
3 changed files with 362 additions and 19 deletions
|
|
@ -42,6 +42,11 @@ async def _authenticate_ws(token: str) -> User | None:
|
|||
|
||||
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Return the status snapshot of ``task_id`` for the WebSocket init message.

    The work is delegated to ``_build_task_snapshot_sync`` through
    ``asyncio.to_thread`` so the synchronous lookup runs in a worker thread.
    """
    snapshot = await asyncio.to_thread(_build_task_snapshot_sync, task_id)
    return snapshot
|
||||
|
||||
|
||||
def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
|
||||
"""Synchronous helper — runs in a thread to avoid blocking the loop."""
|
||||
status = task_service.get_task_status(task_id)
|
||||
result: dict[str, Any] = {
|
||||
"task_id": status.task_id,
|
||||
|
|
@ -70,21 +75,11 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
|
||||
await websocket.accept()
|
||||
|
||||
# Get user's tasks and send initial snapshot
|
||||
task_ids = task_service.get_user_tasks(user.email)
|
||||
snapshots = []
|
||||
for tid in task_ids:
|
||||
try:
|
||||
snapshots.append(await _build_task_snapshot(tid))
|
||||
except Exception:
|
||||
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "init", "tasks": snapshots})
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# Subscribe to Redis pub/sub channels for each task
|
||||
# Subscribe to Redis pub/sub FIRST so no updates are lost while
|
||||
# building snapshots. Messages that arrive between subscribe and
|
||||
# the init send are buffered by Redis and forwarded afterwards.
|
||||
redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
|
||||
pubsub = redis_client.pubsub()
|
||||
|
||||
|
|
@ -94,6 +89,24 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
await pubsub.subscribe(channel)
|
||||
subscribed_channels.add(channel)
|
||||
|
||||
# Now build snapshots (safe — pub/sub is already active)
|
||||
# _build_task_snapshot calls synchronous Celery APIs, so run in a
|
||||
# thread to avoid blocking the event loop.
|
||||
snapshots = []
|
||||
for tid in task_ids:
|
||||
try:
|
||||
snapshots.append(await asyncio.to_thread(_build_task_snapshot_sync, tid))
|
||||
except Exception:
|
||||
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "init", "tasks": snapshots})
|
||||
except Exception:
|
||||
await pubsub.unsubscribe(*subscribed_channels)
|
||||
await pubsub.close()
|
||||
await redis_client.aclose()
|
||||
return
|
||||
|
||||
async def _forward_pubsub() -> None:
|
||||
"""Read from Redis pub/sub and forward to the WebSocket."""
|
||||
while True:
|
||||
|
|
@ -103,6 +116,10 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
if message and message["type"] == "message":
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed pubsub message, skipping")
|
||||
continue
|
||||
try:
|
||||
await websocket.send_json({"type": "task_update", **data})
|
||||
except Exception:
|
||||
break
|
||||
|
|
@ -141,10 +158,20 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
break
|
||||
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_forward_pubsub(),
|
||||
_handle_client_messages(),
|
||||
ws_tasks = [
|
||||
asyncio.create_task(_forward_pubsub()),
|
||||
asyncio.create_task(_handle_client_messages()),
|
||||
]
|
||||
done, pending = await asyncio.wait(
|
||||
ws_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
# Log non-trivial errors from the completed task(s)
|
||||
for t in done:
|
||||
exc = t.exception()
|
||||
if exc and not isinstance(exc, (WebSocketDisconnect, asyncio.CancelledError)):
|
||||
logger.debug("WS task ended with error: %s", exc)
|
||||
except (WebSocketDisconnect, Exception):
|
||||
pass
|
||||
finally:
|
||||
|
|
|
|||
313
frontend/src/hooks/useTaskProgress.ts
Normal file
313
frontend/src/hooks/useTaskProgress.ts
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import type { AuthUser } from '@/auth/types';
|
||||
import type { TaskState, TaskStatusResponse, WSMessage } from '@/types';
|
||||
import { WS_TASKS_PATH, POLLING_INTERVALS } from '@/constants';
|
||||
import { fetchTasksForUser, fetchTaskStatus } from '@/services';
|
||||
|
||||
const KEEPALIVE_MS = 30_000;
|
||||
const MAX_RECONNECT_DELAY_MS = 30_000;
|
||||
|
||||
/**
 * Build the task-progress WebSocket URL for the current page origin.
 *
 * Uses `wss` when the page was loaded over HTTPS and `ws` otherwise, and
 * passes the access token as a URL-encoded `token` query parameter.
 */
function wsUrl(token: string): string {
  const { protocol, host } = window.location;
  const scheme = protocol === 'https:' ? 'wss' : 'ws';
  const query = `token=${encodeURIComponent(token)}`;
  return `${scheme}://${host}${WS_TASKS_PATH}?${query}`;
}
|
||||
|
||||
/** Report whether `status` is one of the final states: SUCCESS, FAILURE or REVOKED. */
function isTerminalStatus(status: string): boolean {
  switch (status) {
    case 'SUCCESS':
    case 'FAILURE':
    case 'REVOKED':
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/** Top-level response fields that are copied only when the server supplied a value. */
const RESPONSE_VALUE_FIELDS = ['progress', 'processed', 'total', 'message'] as const;

/** Counters that the parsed result may supply when the response itself did not. */
const RESULT_FALLBACK_FIELDS = ['progress', 'processed', 'total'] as const;

/** Fields of the parsed result that are copied verbatim whenever they are present. */
const RESULT_PASSTHROUGH_FIELDS = [
  'subqueries_probed',
  'subqueries_initial',
  'estimated_results',
  'subqueries_total',
  'subqueries_completed',
  'ids_collected',
  'pages_fetched',
  'fetching_done',
  'details_fetched',
  'images_downloaded',
  'ocr_completed',
  'failed',
  'elapsed_seconds',
  'rate_per_second',
  'eta_seconds',
  'distances_computed',
] as const;

/**
 * Convert an HTTP TaskStatusResponse into the canonical TaskState shape.
 *
 * Fields the server did not provide are left out of the returned object
 * instead of being set to `undefined`. Callers merge the result over the
 * previous state with an object spread, and an explicit `undefined` key would
 * erase a value already learned from another source (e.g. a WebSocket update).
 *
 * @param resp - Status payload returned by the HTTP task-status endpoint.
 * @returns A TaskState holding `task_id`, `status` and every field that the
 *   response, or the JSON document in `resp.result`, actually carried.
 */
function httpResponseToTaskState(resp: TaskStatusResponse): TaskState {
  const state: TaskState = { task_id: resp.task_id, status: resp.status };
  const fields: Record<string, unknown> = {};

  for (const key of RESPONSE_VALUE_FIELDS) {
    const value = resp[key];
    if (value != null) fields[key] = value;
  }

  // The result string optionally carries a JSON object with detailed phase info.
  if (resp.result) {
    let raw: unknown;
    try {
      raw = JSON.parse(resp.result);
    } catch {
      // A result that is not valid JSON simply contributes no extra fields.
      raw = undefined;
    }

    if (typeof raw === 'object' && raw !== null) {
      const parsed = raw as Record<string, unknown>;

      if (parsed['phase']) fields['phase'] = parsed['phase'];
      // An empty top-level message counts as missing, so the result may fill it in.
      if (parsed['message'] && !fields['message']) fields['message'] = parsed['message'];

      for (const key of RESULT_FALLBACK_FIELDS) {
        if (parsed[key] !== undefined && fields[key] === undefined) fields[key] = parsed[key];
      }
      for (const key of RESULT_PASSTHROUGH_FIELDS) {
        if (parsed[key] !== undefined) fields[key] = parsed[key];
      }
      if (Array.isArray(parsed['logs'])) fields['logs'] = parsed['logs'];
    }
  }

  return Object.assign(state, fields);
}
|
||||
|
||||
/** Values and actions exposed by the `useTaskProgress` hook. */
export interface UseTaskProgressReturn {
  /** Latest known state of every tracked task, keyed by task id. */
  tasks: Record<string, TaskState>;
  /** True from the moment the WebSocket opens until it closes. */
  isConnected: boolean;
  /** Ask for updates on `taskId`; the request is queued while the socket is not open. */
  subscribe: (taskId: string) => void;
  /** Request cancellation of `taskId`; resolves to true when the service reports success. */
  cancelTask: (taskId: string) => Promise<boolean>;
  /** Request removal of all tasks; resolves to true when the service reports success. */
  clearAllTasks: () => Promise<boolean>;
}
|
||||
|
||||
/**
 * Track the given user's background tasks.
 *
 * Task state is fed from two sources that run side by side: a WebSocket that
 * pushes `init` / `task_update` messages, and an HTTP poll that acts as a
 * safety net when the push path is down or silently broken. The socket
 * reconnects with exponential backoff (capped at MAX_RECONNECT_DELAY_MS).
 *
 * @param user - The signed-in user, or `null` when nobody is signed in; with
 *   `null` neither the socket nor the poll is started.
 * @returns The task map, the connection flag and the subscribe / cancel /
 *   clear actions.
 */
export function useTaskProgress(user: AuthUser | null): UseTaskProgressReturn {
  const [tasks, setTasks] = useState<Record<string, TaskState>>({});
  const [isConnected, setIsConnected] = useState(false);

  const wsRef = useRef<WebSocket | null>(null);
  const reconnectAttempt = useRef(0);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const keepaliveTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const pollingTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const mountedRef = useRef(true);
  const pendingSubscriptions = useRef<Set<string>>(new Set());
  // Refs to break the dependency cycle: callbacks read current values
  // from refs so they don't need reactive deps that trigger re-creation.
  const userRef = useRef(user);
  userRef.current = user;
  const tasksRef = useRef(tasks);
  tasksRef.current = tasks;

  // ---- Polling fallback ----

  const stopPolling = useCallback(() => {
    if (pollingTimer.current) {
      clearInterval(pollingTimer.current);
      pollingTimer.current = null;
    }
  }, []);

  const startPolling = useCallback(() => {
    if (!userRef.current || pollingTimer.current) return;

    // Scoped to this polling session so a slow request from a previous
    // session cannot block the first poll of a new one.
    let inFlight = false;

    const fetchAndPoll = async (): Promise<void> => {
      const u = userRef.current;
      // Skip the tick when the previous round is still running; overlapping
      // rounds could apply their results out of order.
      if (!mountedRef.current || !u || inFlight) return;
      inFlight = true;

      // Results are stale once the hook unmounted or the user changed while
      // the request was in flight; they must not be written into state.
      const isStale = (): boolean => !mountedRef.current || userRef.current !== u;

      try {
        const taskIds = await fetchTasksForUser(u);
        if (isStale()) return;

        // Also include locally-known non-terminal task IDs
        const localIds = Object.entries(tasksRef.current)
          .filter(([, t]) => !isTerminalStatus(t.status))
          .map(([id]) => id);

        const allIds = [...new Set([...taskIds, ...localIds])];
        if (allIds.length === 0) return;

        const results = await Promise.allSettled(
          allIds.map((tid) => fetchTaskStatus(u, tid)),
        );
        if (isStale()) return;

        setTasks((prev) => {
          const next = { ...prev };
          for (const r of results) {
            if (r.status === 'fulfilled') {
              const state = httpResponseToTaskState(r.value);
              next[state.task_id] = { ...prev[state.task_id], ...state };
            }
          }
          return next;
        });
      } catch {
        // Polling is best-effort; the next tick retries.
      } finally {
        inFlight = false;
      }
    };

    void fetchAndPoll();
    pollingTimer.current = setInterval(() => {
      void fetchAndPoll();
    }, POLLING_INTERVALS.TASK_STATUS_MS);
  }, []); // No reactive deps — reads from refs

  // ---- WebSocket connection ----
  // Stable connect function — only re-created when user identity changes.
  const connect = useCallback(() => {
    if (!user) return;

    const ws = new WebSocket(wsUrl(user.accessToken));
    wsRef.current = ws;

    // Events of a superseded socket (after a user change or a remount) arrive
    // asynchronously, possibly after a newer socket took over the shared
    // refs. Only the socket currently stored in wsRef may touch shared state.
    const isCurrent = (): boolean => mountedRef.current && wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectAttempt.current = 0;

      // Start keepalive pings, replacing any interval that is still registered
      // so its handle is never overwritten and leaked.
      if (keepaliveTimer.current) clearInterval(keepaliveTimer.current);
      keepaliveTimer.current = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
          ws.send(JSON.stringify({ type: 'ping' }));
        }
      }, KEEPALIVE_MS);

      // Send pending subscriptions
      for (const taskId of pendingSubscriptions.current) {
        ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
      }
      pendingSubscriptions.current.clear();
    };

    ws.onmessage = (event) => {
      if (!isCurrent()) return;
      try {
        const msg: WSMessage = JSON.parse(event.data);

        if (msg.type === 'init') {
          const initial: Record<string, TaskState> = {};
          for (const t of msg.tasks) {
            initial[t.task_id] = t;
          }
          setTasks(initial);
        } else if (msg.type === 'task_update') {
          const { type: _, ...taskData } = msg;
          setTasks((prev) => ({
            ...prev,
            [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState,
          }));
        }
      } catch {
        // Ignore malformed messages
      }
    };

    ws.onclose = () => {
      // A stale socket must not clear the live keepalive timer or schedule a
      // reconnect that would open a duplicate connection.
      if (!isCurrent()) return;
      setIsConnected(false);

      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }

      // Fallback to polling while WS is down (no-op when already running)
      startPolling();

      // Exponential backoff reconnect
      const delay = Math.min(
        1000 * 2 ** reconnectAttempt.current,
        MAX_RECONNECT_DELAY_MS,
      );
      reconnectAttempt.current += 1;
      reconnectTimer.current = setTimeout(() => {
        if (mountedRef.current && userRef.current) connect();
      }, delay);
    };

    ws.onerror = () => {
      // onclose will fire after this, triggering reconnect
    };
  }, [user, startPolling]); // startPolling is stable (no deps)

  // Mount/unmount + reconnect when user changes
  useEffect(() => {
    mountedRef.current = true;
    connect();
    startPolling(); // Always run polling as safety net alongside WebSocket

    return () => {
      mountedRef.current = false;
      if (reconnectTimer.current) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }
      stopPolling();
      if (wsRef.current) {
        wsRef.current.close();
        wsRef.current = null;
      }
      // The closed socket's own onclose is ignored from here on, so the
      // connection flag is reset here instead.
      setIsConnected(false);
    };
  }, [connect, startPolling, stopPolling]);

  // ---- Public API ----

  const subscribe = useCallback((taskId: string) => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
    } else {
      // Queued until the next successful onopen flushes it.
      pendingSubscriptions.current.add(taskId);
    }
  }, []);

  const cancelTask = useCallback(
    async (taskId: string): Promise<boolean> => {
      if (!user) return false;
      try {
        const { cancelTask: cancel } = await import('@/services');
        const result = await cancel(user, taskId);
        if (result.success) {
          // Reflect the cancellation locally without waiting for the next update.
          setTasks((prev) => ({
            ...prev,
            [taskId]: { ...prev[taskId], task_id: taskId, status: 'REVOKED' },
          }));
          return true;
        }
        return false;
      } catch {
        return false;
      }
    },
    [user],
  );

  const clearAllTasks = useCallback(async (): Promise<boolean> => {
    if (!user) return false;
    try {
      const { clearAllTasks: clearAll } = await import('@/services');
      const result = await clearAll(user);
      if (result.success) {
        setTasks({});
        return true;
      }
      return false;
    } catch {
      return false;
    }
  }, [user]);

  return { tasks, isConnected, subscribe, cancelTask, clearAllTasks };
}
|
||||
|
|
@ -35,8 +35,9 @@ def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> Non
|
|||
state: Celery state string (e.g. 'PROGRESS', 'SUCCESS').
|
||||
meta: Metadata dict (progress, phase, logs, counters, etc.).
|
||||
|
||||
Failures are caught and logged at DEBUG level so they never break the
|
||||
critical task execution path.
|
||||
Failures are caught and logged at WARNING level so they never break the
|
||||
critical task execution path. The Redis client is reset on failure so
|
||||
subsequent calls can reconnect.
|
||||
"""
|
||||
try:
|
||||
client = _get_redis_client()
|
||||
|
|
@ -47,4 +48,6 @@ def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> Non
|
|||
})
|
||||
client.publish(f"task_progress:{task_id}", payload)
|
||||
except Exception:
|
||||
logger.debug("Failed to publish task progress for %s", task_id, exc_info=True)
|
||||
logger.warning("Failed to publish task progress for %s", task_id, exc_info=True)
|
||||
# Reset client so next call creates a fresh connection
|
||||
_redis_client = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue