Fix real-time task progress by closing WS on pubsub exit and keeping polling active

Three interconnected bugs prevented progress updates from reaching the frontend:

1. _forward_pubsub could exit silently while _handle_client_messages kept
   the WebSocket alive (responding to pings), so the client never detected
   the broken forwarding path. Replace asyncio.gather with asyncio.wait
   (FIRST_COMPLETED) so both coroutines are cancelled together.

2. Polling was stopped on WS connect, with no fallback if forwarding broke.
   Polling now always runs alongside the WebSocket as a safety net.

3. Redis publish failures in task_progress_publisher were logged at DEBUG
   and the broken client was reused forever. Log at WARNING and reset the
   client so the next call reconnects.
This commit is contained in:
Viktor Barzin 2026-02-09 22:48:57 +00:00
parent 8d52bdf99d
commit 791b5a9d55
No known key found for this signature in database
GPG key ID: 0EB088298288D958
3 changed files with 362 additions and 19 deletions

View file

@ -0,0 +1,313 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import type { AuthUser } from '@/auth/types';
import type { TaskState, TaskStatusResponse, WSMessage } from '@/types';
import { WS_TASKS_PATH, POLLING_INTERVALS } from '@/constants';
import { fetchTasksForUser, fetchTaskStatus } from '@/services';
/** Interval, in milliseconds, between keepalive `ping` messages sent on an open socket. */
const KEEPALIVE_MS = 30_000;
/** Upper bound, in milliseconds, on the exponential-backoff delay between reconnect attempts. */
const MAX_RECONNECT_DELAY_MS = 30_000;
/**
 * Build the task-progress WebSocket URL for the current page origin.
 *
 * The scheme follows the page protocol (`wss` on https, `ws` otherwise) and
 * the access token is passed URL-encoded in the `token` query parameter.
 */
function wsUrl(token: string): string {
  const { protocol, host } = window.location;
  const scheme = protocol === 'https:' ? 'wss' : 'ws';
  const query = `token=${encodeURIComponent(token)}`;
  return `${scheme}://${host}${WS_TASKS_PATH}?${query}`;
}
/** Report whether a task status is final (`SUCCESS`, `FAILURE` or `REVOKED`). */
function isTerminalStatus(status: string): boolean {
  switch (status) {
    case 'SUCCESS':
    case 'FAILURE':
    case 'REVOKED':
      return true;
    default:
      return false;
  }
}
/**
 * Convert an HTTP TaskStatusResponse into the canonical TaskState shape.
 *
 * Top-level response fields take precedence. When `resp.result` holds a JSON
 * object, its `message`/`progress`/`processed`/`total` only fill in values the
 * response itself did not provide, and the remaining detail fields are copied
 * across whenever they are present. A `result` that is not valid JSON, or
 * does not decode to an object, is ignored.
 */
function httpResponseToTaskState(resp: TaskStatusResponse): TaskState {
  const state: TaskState = {
    task_id: resp.task_id,
    status: resp.status,
    progress: resp.progress ?? undefined,
    processed: resp.processed ?? undefined,
    total: resp.total ?? undefined,
    message: resp.message ?? undefined,
  };

  // Nothing further to extract without a result payload.
  if (!resp.result) return state;

  try {
    const detail = JSON.parse(resp.result);
    if (detail === null || typeof detail !== 'object') return state;

    // Fields that may already have been supplied by the response itself.
    if (detail.phase) state.phase = detail.phase;
    if (!state.message && detail.message) state.message = detail.message;
    if (state.progress === undefined && detail.progress !== undefined) {
      state.progress = detail.progress;
    }
    if (state.processed === undefined && detail.processed !== undefined) {
      state.processed = detail.processed;
    }
    if (state.total === undefined && detail.total !== undefined) {
      state.total = detail.total;
    }

    // Detail counters that only ever come from the result payload.
    const {
      subqueries_probed,
      subqueries_initial,
      estimated_results,
      subqueries_total,
      subqueries_completed,
      ids_collected,
      pages_fetched,
      fetching_done,
      details_fetched,
      images_downloaded,
      ocr_completed,
      failed,
      elapsed_seconds,
      rate_per_second,
      eta_seconds,
      distances_computed,
      logs,
    } = detail;

    if (subqueries_probed !== undefined) state.subqueries_probed = subqueries_probed;
    if (subqueries_initial !== undefined) state.subqueries_initial = subqueries_initial;
    if (estimated_results !== undefined) state.estimated_results = estimated_results;
    if (subqueries_total !== undefined) state.subqueries_total = subqueries_total;
    if (subqueries_completed !== undefined) state.subqueries_completed = subqueries_completed;
    if (ids_collected !== undefined) state.ids_collected = ids_collected;
    if (pages_fetched !== undefined) state.pages_fetched = pages_fetched;
    if (fetching_done !== undefined) state.fetching_done = fetching_done;
    if (details_fetched !== undefined) state.details_fetched = details_fetched;
    if (images_downloaded !== undefined) state.images_downloaded = images_downloaded;
    if (ocr_completed !== undefined) state.ocr_completed = ocr_completed;
    if (failed !== undefined) state.failed = failed;
    if (elapsed_seconds !== undefined) state.elapsed_seconds = elapsed_seconds;
    if (rate_per_second !== undefined) state.rate_per_second = rate_per_second;
    if (eta_seconds !== undefined) state.eta_seconds = eta_seconds;
    if (distances_computed !== undefined) state.distances_computed = distances_computed;
    if (Array.isArray(logs)) state.logs = logs;
  } catch {
    // Ignore parse errors
  }

  return state;
}
/** Value returned by the `useTaskProgress` hook. */
export interface UseTaskProgressReturn {
  /** Latest known state of each task, keyed by task id. */
  tasks: Record<string, TaskState>;
  /** True from the moment the WebSocket opens until it closes. */
  isConnected: boolean;
  /**
   * Subscribe to updates for a task. Sent immediately when the socket is
   * open; otherwise queued and sent the next time the socket opens.
   */
  subscribe: (taskId: string) => void;
  /**
   * Ask the backend to cancel a task. Resolves `true` (and marks the task
   * `REVOKED` locally) when the service reports success; resolves `false`
   * when there is no user, the service reports failure, or the call throws.
   */
  cancelTask: (taskId: string) => Promise<boolean>;
  /**
   * Ask the backend to clear all tasks. Resolves `true` (and empties the
   * local task map) on success; `false` otherwise.
   */
  clearAllTasks: () => Promise<boolean>;
}
/**
 * Merge a polled task snapshot into the previously known state.
 *
 * Only fields of `incoming` that are not `undefined` overwrite `previous`.
 * The HTTP conversion always emits keys such as `progress` even when the
 * response carried no value, and a plain spread would erase richer data
 * already delivered over the WebSocket.
 */
function mergeDefinedFields(
  previous: TaskState | undefined,
  incoming: TaskState,
): TaskState {
  const defined = Object.fromEntries(
    Object.entries(incoming).filter(([, value]) => value !== undefined),
  );
  return { ...previous, ...defined } as TaskState;
}

/**
 * Track task progress for a user over a WebSocket, with HTTP polling running
 * alongside it as a safety net.
 *
 * While `user` is non-null the hook keeps one WebSocket open (reconnecting
 * with exponential backoff when it closes), sends periodic keepalive pings,
 * and polls task status on an interval. Updates from both sources are merged
 * into `tasks`. Everything is torn down when the component unmounts or the
 * `user` object changes.
 *
 * @param user Authenticated user, or `null` to stay disconnected.
 * @returns Current task map, connection flag and task actions.
 */
export function useTaskProgress(user: AuthUser | null): UseTaskProgressReturn {
  const [tasks, setTasks] = useState<Record<string, TaskState>>({});
  const [isConnected, setIsConnected] = useState(false);
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectAttempt = useRef(0);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const keepaliveTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const pollingTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const mountedRef = useRef(true);
  const pendingSubscriptions = useRef<Set<string>>(new Set());

  // Refs to break the dependency cycle: callbacks read current values
  // from refs so they don't need reactive deps that trigger re-creation.
  const userRef = useRef(user);
  userRef.current = user;
  const tasksRef = useRef(tasks);
  tasksRef.current = tasks;

  // ---- Polling fallback ----
  const stopPolling = useCallback(() => {
    if (pollingTimer.current) {
      clearInterval(pollingTimer.current);
      pollingTimer.current = null;
    }
  }, []);

  const startPolling = useCallback(() => {
    const currentUser = userRef.current;
    // No-op without a user or when a polling interval is already running.
    if (!currentUser || pollingTimer.current) return;

    const fetchAndPoll = async () => {
      const u = userRef.current;
      if (!mountedRef.current || !u) return;
      try {
        const taskIds = await fetchTasksForUser(u);
        if (!mountedRef.current) return;
        // Also include locally-known non-terminal task IDs
        const localIds = Object.entries(tasksRef.current)
          .filter(([, t]) => !isTerminalStatus(t.status))
          .map(([id]) => id);
        const allIds = [...new Set([...taskIds, ...localIds])];
        if (allIds.length === 0) return;
        const results = await Promise.allSettled(
          allIds.map((tid) => fetchTaskStatus(u, tid)),
        );
        if (!mountedRef.current) return;
        setTasks((prev) => {
          const next = { ...prev };
          for (const r of results) {
            if (r.status === 'fulfilled') {
              const state = httpResponseToTaskState(r.value);
              // Keep fields the poll did not report instead of blanking them.
              next[state.task_id] = mergeDefinedFields(prev[state.task_id], state);
            }
          }
          return next;
        });
      } catch {
        // Ignore polling errors
      }
    };

    void fetchAndPoll();
    pollingTimer.current = setInterval(fetchAndPoll, POLLING_INTERVALS.TASK_STATUS_MS);
  }, []); // No reactive deps — reads from refs

  // ---- WebSocket connection ----
  // Stable connect function — only re-created when user identity changes.
  const connect = useCallback(() => {
    // A fresh attempt (or a missing user) means no socket is live right now.
    setIsConnected(false);
    if (!user) return;

    const ws = new WebSocket(wsUrl(user.accessToken));
    wsRef.current = ws;

    // Events from a socket that has been replaced or torn down must be
    // ignored: otherwise a late `close` from the old socket would clear the
    // new socket's keepalive timer and schedule a duplicate reconnect.
    const isStale = () => !mountedRef.current || wsRef.current !== ws;

    ws.onopen = () => {
      if (isStale()) return;
      setIsConnected(true);
      reconnectAttempt.current = 0;
      // Start keepalive pings (dropping any leftover interval first)
      if (keepaliveTimer.current) clearInterval(keepaliveTimer.current);
      keepaliveTimer.current = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
          ws.send(JSON.stringify({ type: 'ping' }));
        }
      }, KEEPALIVE_MS);
      // Send pending subscriptions
      for (const taskId of pendingSubscriptions.current) {
        ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
      }
      pendingSubscriptions.current.clear();
    };

    ws.onmessage = (event) => {
      if (isStale()) return;
      try {
        const msg: WSMessage = JSON.parse(event.data);
        if (msg.type === 'init') {
          const initial: Record<string, TaskState> = {};
          for (const t of msg.tasks) {
            initial[t.task_id] = t;
          }
          setTasks(initial);
        } else if (msg.type === 'task_update') {
          const { type: _, ...taskData } = msg;
          setTasks((prev) => ({
            ...prev,
            [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState,
          }));
        }
      } catch {
        // Ignore malformed messages
      }
    };

    ws.onclose = () => {
      if (isStale()) return;
      setIsConnected(false);
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }
      // Fallback to polling while WS is down
      startPolling();
      // Exponential backoff reconnect
      const delay = Math.min(
        1000 * 2 ** reconnectAttempt.current,
        MAX_RECONNECT_DELAY_MS,
      );
      reconnectAttempt.current += 1;
      reconnectTimer.current = setTimeout(() => {
        if (mountedRef.current && userRef.current) connect();
      }, delay);
    };

    ws.onerror = () => {
      // onclose will fire after this, triggering reconnect
    };
  }, [user, startPolling]); // startPolling is stable (no deps)

  // Mount/unmount + reconnect when user changes
  useEffect(() => {
    mountedRef.current = true;
    connect();
    startPolling(); // Always run polling as safety net alongside WebSocket

    return () => {
      mountedRef.current = false;
      if (reconnectTimer.current) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }
      stopPolling();
      const ws = wsRef.current;
      if (ws) {
        wsRef.current = null;
        // Detach handlers before closing: the close event fires
        // asynchronously, possibly after this effect has already re-run.
        ws.onopen = null;
        ws.onmessage = null;
        ws.onclose = null;
        ws.onerror = null;
        ws.close();
      }
    };
  }, [connect, startPolling, stopPolling]);

  // ---- Public API ----
  const subscribe = useCallback((taskId: string) => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
    } else {
      // Queued until the next successful open.
      pendingSubscriptions.current.add(taskId);
    }
  }, []);

  const cancelTask = useCallback(
    async (taskId: string): Promise<boolean> => {
      if (!user) return false;
      try {
        const { cancelTask: cancel } = await import('@/services');
        const result = await cancel(user, taskId);
        if (result.success) {
          // Reflect the cancellation locally without waiting for an update.
          setTasks((prev) => ({
            ...prev,
            [taskId]: { ...prev[taskId], task_id: taskId, status: 'REVOKED' },
          }));
          return true;
        }
        return false;
      } catch {
        return false;
      }
    },
    [user],
  );

  const clearAllTasks = useCallback(async (): Promise<boolean> => {
    if (!user) return false;
    try {
      const { clearAllTasks: clearAll } = await import('@/services');
      const result = await clearAll(user);
      if (result.success) {
        setTasks({});
        return true;
      }
      return false;
    } catch {
      return false;
    }
  }, [user]);

  return { tasks, isConnected, subscribe, cancelTask, clearAllTasks };
}