import { useCallback, useEffect, useRef, useState } from 'react';
import type { AuthUser } from '@/auth/types';
import type { TaskState, TaskStatusResponse, WSMessage } from '@/types';
import { WS_TASKS_PATH } from '@/constants';
import { fetchTasksForUser, fetchTaskStatus } from '@/services';

const KEEPALIVE_MS = 30_000;
const MAX_RECONNECT_DELAY_MS = 30_000;
const POLL_FAST_MS = 5_000; // WS down — polling is the primary update source
const POLL_SLOW_MS = 30_000; // WS up — polling is just a safety net

/**
 * Build the task WebSocket URL for the current page origin.
 * Uses `wss` when the page is served over https, `ws` otherwise, and passes
 * the access token URL-encoded in the `token` query parameter.
 */
function wsUrl(token: string): string {
  const proto = window.location.protocol === 'https:' ? 'wss' : 'ws';
  return `${proto}://${window.location.host}${WS_TASKS_PATH}?token=${encodeURIComponent(token)}`;
}

/** True for the statuses this module treats as final: SUCCESS, FAILURE, REVOKED. */
function isTerminalStatus(status: string): boolean {
  return status === 'SUCCESS' || status === 'FAILURE' || status === 'REVOKED';
}

/**
 * Convert an HTTP TaskStatusResponse into the canonical TaskState shape.
 *
 * Top-level response fields win; `resp.result` (a JSON string) is parsed for
 * additional detail and only fills `message`, `progress`, `processed` and
 * `total` when the response itself did not provide them. A `result` that is
 * not valid JSON, or not a JSON object, is ignored.
 */
function httpResponseToTaskState(resp: TaskStatusResponse): TaskState {
  const state: TaskState = {
    task_id: resp.task_id,
    status: resp.status,
    progress: resp.progress ?? undefined,
    processed: resp.processed ?? undefined,
    total: resp.total ?? undefined,
    message: resp.message ?? undefined,
  };

  // Parse the result JSON for detailed phase info
  if (resp.result) {
    try {
      const parsed = JSON.parse(resp.result);
      if (typeof parsed === 'object' && parsed !== null) {
        if (parsed.phase) state.phase = parsed.phase;

        // Fallbacks: only used when the top-level response left the field unset.
        if (parsed.message && !state.message) state.message = parsed.message;
        if (parsed.progress !== undefined && state.progress === undefined) state.progress = parsed.progress;
        if (parsed.processed !== undefined && state.processed === undefined) state.processed = parsed.processed;
        if (parsed.total !== undefined && state.total === undefined) state.total = parsed.total;

        // Detail fields: copied through whenever present in the result payload.
        if (parsed.subqueries_probed !== undefined) state.subqueries_probed = parsed.subqueries_probed;
        if (parsed.subqueries_initial !== undefined) state.subqueries_initial = parsed.subqueries_initial;
        if (parsed.estimated_results !== undefined) state.estimated_results = parsed.estimated_results;
        if (parsed.subqueries_total !== undefined) state.subqueries_total = parsed.subqueries_total;
        if (parsed.subqueries_completed !== undefined) state.subqueries_completed = parsed.subqueries_completed;
        if (parsed.ids_collected !== undefined) state.ids_collected = parsed.ids_collected;
        if (parsed.pages_fetched !== undefined) state.pages_fetched = parsed.pages_fetched;
        if (parsed.fetching_done !== undefined) state.fetching_done = parsed.fetching_done;
        if (parsed.details_fetched !== undefined) state.details_fetched = parsed.details_fetched;
        if (parsed.images_downloaded !== undefined) state.images_downloaded = parsed.images_downloaded;
        if (parsed.ocr_completed !== undefined) state.ocr_completed = parsed.ocr_completed;
        if (parsed.failed !== undefined) state.failed = parsed.failed;
        if (parsed.elapsed_seconds !== undefined) state.elapsed_seconds = parsed.elapsed_seconds;
        if (parsed.rate_per_second !== undefined) state.rate_per_second = parsed.rate_per_second;
        if (parsed.eta_seconds !== undefined) state.eta_seconds = parsed.eta_seconds;
        if (parsed.distances_computed !== undefined) state.distances_computed = parsed.distances_computed;
        if (Array.isArray(parsed.logs)) state.logs = parsed.logs;
      }
    } catch {
      // Ignore parse errors
    }
  }

  return state;
}

/** Value returned by {@link useTaskProgress}. */
export interface UseTaskProgressReturn {
  /** Latest known state per task, keyed by task id. */
  tasks: Record<string, TaskState>;
  /** True while the hook's current WebSocket is open. */
  isConnected: boolean;
  /** Subscribe to a task; queued until the socket opens if it is not open yet. */
  subscribe: (taskId: string) => void;
  /** Request cancellation; resolves to true when the service reports success. */
  cancelTask: (taskId: string) => Promise<boolean>;
  /** Clear all tasks; resolves to true when the service reports success. */
  clearAllTasks: () => Promise<boolean>;
}

/**
 * Track task progress for `user` over a WebSocket, with HTTP polling running
 * alongside it (slow while the socket is open, fast while it is down).
 *
 * The socket reconnects with exponential backoff capped at
 * MAX_RECONNECT_DELAY_MS. When `user` is null nothing is connected or polled.
 *
 * @param user - Authenticated user, or null when signed out.
 */
export function useTaskProgress(user: AuthUser | null): UseTaskProgressReturn {
  const [tasks, setTasks] = useState<Record<string, TaskState>>({});
  const [isConnected, setIsConnected] = useState(false);

  const wsRef = useRef<WebSocket | null>(null);
  const reconnectAttempt = useRef(0);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const keepaliveTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const pollingTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const mountedRef = useRef(true);
  const pendingSubscriptions = useRef<Set<string>>(new Set());

  // Refs to break the dependency cycle: callbacks read current values
  // from refs so they don't need reactive deps that trigger re-creation.
  const userRef = useRef(user);
  userRef.current = user;
  const tasksRef = useRef(tasks);
  tasksRef.current = tasks;

  // ---- Polling fallback ----

  const stopPolling = useCallback(() => {
    if (pollingTimer.current) {
      clearInterval(pollingTimer.current);
      pollingTimer.current = null;
    }
  }, []);

  const startPolling = useCallback((intervalMs: number = POLL_SLOW_MS) => {
    const currentUser = userRef.current;
    if (!currentUser) return;

    // Stop existing polling before (re)starting with the requested interval
    if (pollingTimer.current) {
      clearInterval(pollingTimer.current);
      pollingTimer.current = null;
    }

    const fetchAndPoll = async (): Promise<void> => {
      const u = userRef.current;
      if (!mountedRef.current || !u) return;
      try {
        const taskIds = await fetchTasksForUser(u);
        if (!mountedRef.current) return;

        // Also include locally-known non-terminal task IDs
        const localIds = Object.entries(tasksRef.current)
          .filter(([, t]) => !isTerminalStatus(t.status))
          .map(([id]) => id);
        const allIds = [...new Set([...taskIds, ...localIds])];
        if (allIds.length === 0) return;

        // allSettled: one failing status request must not discard the others.
        const results = await Promise.allSettled(
          allIds.map((tid) => fetchTaskStatus(u, tid)),
        );
        if (!mountedRef.current) return;

        setTasks((prev) => {
          const next = { ...prev };
          for (const r of results) {
            if (r.status === 'fulfilled') {
              const state = httpResponseToTaskState(r.value);
              next[state.task_id] = { ...prev[state.task_id], ...state };
            }
          }
          return next;
        });
      } catch {
        // Ignore polling errors
      }
    };

    // Poll once immediately, then on the interval. fetchAndPoll never rejects.
    void fetchAndPoll();
    pollingTimer.current = setInterval(fetchAndPoll, intervalMs);
  }, []); // No reactive deps — reads from refs

  // ---- WebSocket connection ----

  // Stable connect function — only re-created when user identity changes.
  const connect = useCallback(() => {
    // A fresh attempt is not connected yet. This also resets the flag when the
    // user changed or signed out, since events from the previous socket are
    // ignored below and can no longer do it.
    setIsConnected(false);
    if (!user) return;

    const ws = new WebSocket(wsUrl(user.accessToken));
    wsRef.current = ws;

    // Events can arrive after this socket has been replaced or the effect has
    // been cleaned up (user change, StrictMode remount). `mountedRef` alone
    // cannot detect that because the next effect run sets it back to true, so
    // every handler also checks that `ws` is still the hook's current socket.
    const isCurrent = (): boolean => mountedRef.current && wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectAttempt.current = 0;

      // WS is primary — slow down polling to safety-net rate
      startPolling(POLL_SLOW_MS);

      // Start keepalive pings (dropping any interval left from an earlier socket)
      if (keepaliveTimer.current) clearInterval(keepaliveTimer.current);
      keepaliveTimer.current = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
          ws.send(JSON.stringify({ type: 'ping' }));
        }
      }, KEEPALIVE_MS);

      // Send pending subscriptions
      for (const taskId of pendingSubscriptions.current) {
        ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
      }
      pendingSubscriptions.current.clear();
    };

    ws.onmessage = (event) => {
      if (!isCurrent()) return;
      try {
        const msg: WSMessage = JSON.parse(event.data);
        if (msg.type === 'init') {
          // 'init' replaces the whole task map with the tasks it carries.
          const initial: Record<string, TaskState> = {};
          for (const t of msg.tasks) {
            initial[t.task_id] = t;
          }
          setTasks(initial);
        } else if (msg.type === 'task_update') {
          // Merge the update into the existing entry, dropping the envelope `type`.
          const { type: _, ...taskData } = msg;
          setTasks((prev) => ({
            ...prev,
            [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState,
          }));
        }
      } catch {
        // Ignore malformed messages
      }
    };

    ws.onclose = () => {
      // A socket closed by cleanup must not touch the new connection's timers,
      // flip isConnected, or schedule a reconnect with the previous user's token.
      if (!isCurrent()) return;
      setIsConnected(false);

      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }

      // WS is down — poll at fast rate as primary update source
      startPolling(POLL_FAST_MS);

      // Exponential backoff reconnect
      const delay = Math.min(
        1000 * 2 ** reconnectAttempt.current,
        MAX_RECONNECT_DELAY_MS,
      );
      reconnectAttempt.current += 1;
      reconnectTimer.current = setTimeout(() => {
        if (mountedRef.current && userRef.current) connect();
      }, delay);
    };

    ws.onerror = () => {
      // onclose will fire after this, triggering reconnect
    };
  }, [user, startPolling]); // startPolling is stable (no deps)

  // Mount/unmount + reconnect when user changes
  useEffect(() => {
    mountedRef.current = true;
    connect();
    startPolling(); // Always run polling as safety net alongside WebSocket

    return () => {
      mountedRef.current = false;

      if (reconnectTimer.current) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }
      stopPolling();

      // Detach before closing so the socket's late close event is seen as stale.
      const ws = wsRef.current;
      wsRef.current = null;
      if (ws) ws.close();
    };
  }, [connect, startPolling, stopPolling]);

  // ---- Public API ----

  const subscribe = useCallback((taskId: string) => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
    } else {
      // Not open yet — flushed by onopen.
      pendingSubscriptions.current.add(taskId);
    }
  }, []);

  const cancelTask = useCallback(
    async (taskId: string): Promise<boolean> => {
      if (!user) return false;
      try {
        const { cancelTask: cancel } = await import('@/services');
        const result = await cancel(user, taskId);
        if (result.success) {
          // Reflect the cancellation locally without waiting for the next update.
          setTasks((prev) => ({
            ...prev,
            [taskId]: { ...prev[taskId], task_id: taskId, status: 'REVOKED' },
          }));
          return true;
        }
        return false;
      } catch {
        return false;
      }
    },
    [user],
  );

  const clearAllTasks = useCallback(async (): Promise<boolean> => {
    if (!user) return false;
    try {
      const { clearAllTasks: clearAll } = await import('@/services');
      const result = await clearAll(user);
      if (result.success) {
        setTasks({});
        return true;
      }
      return false;
    } catch {
      return false;
    }
  }, [user]);

  return { tasks, isConnected, subscribe, cancelTask, clearAllTasks };
}