From 2d86213db5aea9f7c5c6e1c57d6d5d110c8c3118 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Mon, 9 Feb 2026 23:02:24 +0000 Subject: [PATCH] Refactor task progress to unified useTaskProgress hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace WebSocket-only useTaskWebSocket with useTaskProgress that provides a unified task state interface. TaskIndicator no longer manages its own polling or auth — it receives task state from the parent via props. Rename wsTasks prop to tasks throughout. --- frontend/src/App.tsx | 74 +++--- frontend/src/components/Header.tsx | 30 +-- frontend/src/components/TaskIndicator.tsx | 236 ++++-------------- .../src/components/TaskProgressDrawer.tsx | 12 +- frontend/src/hooks/useTaskWebSocket.ts | 129 ---------- frontend/src/types/index.ts | 12 +- 6 files changed, 130 insertions(+), 363 deletions(-) delete mode 100644 frontend/src/hooks/useTaskWebSocket.ts diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 8bd43de..1f3b4a8 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -17,13 +17,16 @@ import { Sheet, SheetContent, SheetTrigger } from './components/ui/sheet'; import { Button } from './components/ui/button'; import { Filter } from 'lucide-react'; import type { GeoJSONFeatureCollection, PropertyProperties, PropertyFeature, POI, POITravelFilter } from '@/types'; -import { refreshListings, fetchTasksForUser, streamListingGeoJSON, fetchUserPOIs, type StreamingProgress } from '@/services'; +import { refreshListings, streamListingGeoJSON, fetchUserPOIs, type StreamingProgress } from '@/services'; import { poiMetricPropertyName, injectPoiMetricProperty } from '@/utils/poiUtils'; -import { useTaskWebSocket } from '@/hooks/useTaskWebSocket'; +import { useTaskProgress } from '@/hooks/useTaskProgress'; + +function isTerminalStatus(status: string): boolean { + return status === 'SUCCESS' || status === 'FAILURE' || status === 'REVOKED'; +} function App() { const [listingData, setListingData] = useState(null); - const [taskID, setTaskID] = useState(null); const [user, setUser] = useState(null); const [queryParameters, setQueryParameters] = useState(null); const [submitError, setSubmitError] = useState(null); @@ -44,8 +47,26 @@ function App() { const [poiTravelFilters, setPoiTravelFilters] = useState>({}); const [currentMetric, setCurrentMetric] = useState(DEFAULT_FILTER_VALUES.metric); - // WebSocket-based real-time task progress - const { tasks: wsTasks, isConnected: wsConnected, subscribe: wsSubscribe } = useTaskWebSocket(user); + // Explicit task ID set by fetch-data action (to track as "active") + const [explicitTaskId, setExplicitTaskId] = useState(null); + + // Unified task progress: WS primary, polling fallback + const { tasks, isConnected, subscribe, cancelTask, clearAllTasks } = useTaskProgress(user); + + // Derive activeTaskId: explicit ID if set, else most recent non-terminal task + const activeTaskId = useMemo(() => { + if (explicitTaskId && tasks[explicitTaskId]) return explicitTaskId; + // Fall back to any non-terminal task + const nonTerminal = Object.entries(tasks).filter( + ([, t]) => !isTerminalStatus(t.status), + ); + if (nonTerminal.length > 0) return nonTerminal[0][0]; + // Fall back to explicit even if terminal (to show final status) + if (explicitTaskId && tasks[explicitTaskId]) return explicitTaskId; + // Show most recent task if any + const allIds = Object.keys(tasks); + return allIds.length > 0 ? allIds[allIds.length - 1] : null; + }, [explicitTaskId, tasks]); // Ref to track accumulated features during streaming const accumulatedFeaturesRef = useRef([]); @@ -77,17 +98,6 @@ function App() { setUser(passkeyUser); }; - useEffect(() => { - if (!user) { - return; - } - fetchTasksForUser(user).then((tasks) => { - if (tasks && tasks.length > 0) { - setTaskID(tasks[0]); - } - }); - }, [user, taskID]); - // Load user's POIs useEffect(() => { if (!user) return; @@ -235,6 +245,10 @@ function App() { } }, [queryParameters, loadListings]); + const handleTaskCancelled = useCallback(() => { + setExplicitTaskId(null); + }, []); + if (!user) { return ; } @@ -248,8 +262,8 @@ function App() { setIsLoading(true); try { const data = await refreshListings(user!, parameters); - setTaskID(data.task_id); - if (data.task_id) wsSubscribe(data.task_id); + setExplicitTaskId(data.task_id); + if (data.task_id) subscribe(data.task_id); } catch (error) { if (error instanceof Error) { setSubmitError(error.message); @@ -347,13 +361,9 @@ function App() { ); }; - const handleTaskCancelled = () => { - setTaskID(null); - }; - const handlePOITaskCreated = (taskId: string) => { - setTaskID(taskId); - if (taskId) wsSubscribe(taskId); + setExplicitTaskId(taskId); + if (taskId) subscribe(taskId); // Refresh POI list in case new ones were created if (user) { fetchUserPOIs(user).then(setUserPOIs).catch(() => {}); @@ -379,12 +389,18 @@ function App() { {/* Header */}
{ + const result = await clearAllTasks(); + if (result) { + handleTaskCancelled(); + } + return result; + }} onTaskCompleted={handleTaskCompleted} - wsTasks={wsTasks} - wsConnected={wsConnected} - wsSubscribe={wsSubscribe} /> {/* Main content area */} diff --git a/frontend/src/components/Header.tsx b/frontend/src/components/Header.tsx index f2e5c42..d392eb0 100644 --- a/frontend/src/components/Header.tsx +++ b/frontend/src/components/Header.tsx @@ -11,28 +11,29 @@ import { TaskIndicator } from './TaskIndicator'; interface HeaderProps { user: AuthUser; activeFilterCount?: number; - taskID?: string | null; isLoading?: boolean; onToggleFilters?: () => void; showFilterToggle?: boolean; - onTaskCancelled?: () => void; + // Task progress (unified) + tasks: Record; + activeTaskId: string | null; + isConnected: boolean; + onCancelTask: (taskId: string) => Promise; + onClearAllTasks: () => Promise; onTaskCompleted?: () => void; - wsTasks?: Record; - wsConnected?: boolean; - wsSubscribe?: (taskId: string) => void; } export function Header({ user, activeFilterCount = 0, - taskID, onToggleFilters, showFilterToggle = false, - onTaskCancelled, + tasks, + activeTaskId, + isConnected, + onCancelTask, + onClearAllTasks, onTaskCompleted, - wsTasks, - wsConnected, - wsSubscribe, }: HeaderProps) { const handleLogout = async () => { if (user.provider === 'passkey') { @@ -58,11 +59,12 @@ export function Header({ {/* Task Indicator */} {/* Filter Toggle (mobile) */} diff --git a/frontend/src/components/TaskIndicator.tsx b/frontend/src/components/TaskIndicator.tsx index 36afba1..2672489 100644 --- a/frontend/src/components/TaskIndicator.tsx +++ b/frontend/src/components/TaskIndicator.tsx @@ -1,8 +1,3 @@ -import { getUser } from '@/auth/authService'; -import { getStoredPasskeyUser } from '@/auth/passkeyService'; -import { fromOidcUser, type AuthUser } from '@/auth/types'; -import { POLLING_INTERVALS } from '@/constants'; -import { fetchTaskStatus, cancelTask, clearAllTasks } from '@/services'; import { TaskStatus, type TaskResult, type TaskState } from '@/types'; import { useEffect, useState, useRef, useMemo } from 'react'; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './ui/tooltip'; @@ -11,14 +6,15 @@ import { Loader2, CheckCircle2, XCircle, X, Trash2 } from 'lucide-react'; import { TaskProgressDrawer } from './TaskProgressDrawer'; interface TaskIndicatorProps { - taskID: string | null; - onTaskCancelled?: () => void; + tasks: Record; + activeTaskId: string | null; + isConnected: boolean; + onCancelTask: (taskId: string) => Promise; + onClearAllTasks: () => Promise; onTaskCompleted?: () => void; - wsTasks?: Record; - wsConnected?: boolean; } -/** Convert a TaskState (from WS) into a TaskResult (for the drawer). */ +/** Convert a TaskState into a TaskResult (for the drawer). */ function taskStateToResult(ts: TaskState): TaskResult { return { progress: ts.progress ?? 0, @@ -50,215 +46,91 @@ function isTerminalStatus(status: string): boolean { } export function TaskIndicator({ - taskID, - onTaskCancelled, + tasks, + activeTaskId, + isConnected: _isConnected, + onCancelTask, + onClearAllTasks, onTaskCompleted, - wsTasks, - wsConnected, }: TaskIndicatorProps) { - const [user, setUser] = useState(null); - const [progressPercentage, setProgressPercentage] = useState(0); - const [processed, setProcessed] = useState(null); - const [total, setTotal] = useState(null); - const [taskStatus, setTaskStatus] = useState(null); - const [taskResult, setTaskResult] = useState(null); const [isCancelling, setIsCancelling] = useState(false); const [isClearing, setIsClearing] = useState(false); const [drawerOpen, setDrawerOpen] = useState(false); const [selectedTaskId, setSelectedTaskId] = useState(null); - // Prevents WS effect from overwriting local cancel/clear state before - // the parent's setTaskID(null) propagates. Reset when taskID changes. - const cancelledRef = useRef(false); - const onTaskCompletedRef = useRef(onTaskCompleted); useEffect(() => { onTaskCompletedRef.current = onTaskCompleted; }, [onTaskCompleted]); + // Track the currently-viewed task in the drawer; default to the externally-provided activeTaskId useEffect(() => { - const passkeyUser = getStoredPasskeyUser(); - if (passkeyUser) { - setUser(passkeyUser); - } else { - getUser().then((oidcUser) => { - if (oidcUser) setUser(fromOidcUser(oidcUser)); - }); + if (activeTaskId) { + setSelectedTaskId(activeTaskId); } - }, []); + }, [activeTaskId]); - // Track the currently-viewed task in the drawer; default to the externally-provided taskID + // Fire onTaskCompleted when the active task transitions to SUCCESS + const prevStatusRef = useRef(null); useEffect(() => { - if (taskID) { - setSelectedTaskId(taskID); - cancelledRef.current = false; // new task, reset cancelled guard - } - }, [taskID]); - - // Count active (non-terminal) tasks from WS - const activeWsTaskCount = useMemo(() => { - if (!wsTasks) return 0; - return Object.values(wsTasks).filter( - (t) => !isTerminalStatus(t.status), - ).length; - }, [wsTasks]); - - // ----- WebSocket-driven state updates ----- - // When wsConnected, derive taskStatus/taskResult/progress from wsTasks - useEffect(() => { - if (!wsConnected || !wsTasks || !taskID) return; - // Don't let WS overwrite local cancel/clear state - if (cancelledRef.current) return; - const wsTask = wsTasks[taskID]; - if (!wsTask) return; - - const status = wsTask.status as TaskStatus; - setTaskStatus(status); - - if (wsTask.phase) { - setTaskResult(taskStateToResult(wsTask)); - } - if (wsTask.progress !== undefined) { - setProgressPercentage(wsTask.progress * 100); - } - if (wsTask.processed !== undefined) { - setProcessed(wsTask.processed); - } - if (wsTask.total !== undefined) { - setTotal(wsTask.total); - } - - if (status === TaskStatus.SUCCESS) { - setProgressPercentage(100); - onTaskCompletedRef.current?.(); - } - }, [wsConnected, wsTasks, taskID]); - - // ----- Polling (always active as baseline; WS provides faster updates on top) ----- - useEffect(() => { - if (!user || !taskID) { - setTaskStatus(null); - setTaskResult(null); + if (!activeTaskId) { + prevStatusRef.current = null; return; } + const task = tasks[activeTaskId]; + const currentStatus = task?.status ?? null; + if ( + currentStatus === 'SUCCESS' && + prevStatusRef.current !== null && + prevStatusRef.current !== 'SUCCESS' + ) { + onTaskCompletedRef.current?.(); + } + prevStatusRef.current = currentStatus; + }, [activeTaskId, tasks]); - // Reset state for new task - setTaskStatus(TaskStatus.PENDING); - setProgressPercentage(0); - setProcessed(null); - setTotal(null); - setTaskResult(null); + // Derive display data from the active task + const activeTask = activeTaskId ? tasks[activeTaskId] : undefined; + const taskStatus = activeTask ? (activeTask.status as TaskStatus) : null; + const taskResult = activeTask?.phase ? taskStateToResult(activeTask) : null; + const progressPercentage = (activeTask?.progress ?? 0) * 100; + const processed = activeTask?.processed ?? null; + const total = activeTask?.total ?? null; - const pollTaskStatus = async () => { - // Skip this poll cycle if cancelled locally - if (cancelledRef.current) return true; - try { - const data = await fetchTaskStatus(user, taskID); - const status = data.status as TaskStatus; - setTaskStatus(status); - - if (status === TaskStatus.SUCCESS) { - setProgressPercentage(100); - if (data.result) { - try { - const parsedResult: TaskResult = JSON.parse(data.result); - if (parsedResult.phase) { - setTaskResult(parsedResult); - } - } catch { - // Ignore parsing errors - } - } - onTaskCompletedRef.current?.(); - return true; - } - - if (status === TaskStatus.FAILURE || status === TaskStatus.REVOKED) { - return true; - } - - if (data.result) { - try { - const parsedResult: TaskResult = JSON.parse(data.result); - if (parsedResult.phase) { - setTaskResult(parsedResult); - } - if (parsedResult.progress !== undefined) { - setProgressPercentage(parsedResult.progress * 100); - } - if (parsedResult.processed !== undefined) { - setProcessed(parsedResult.processed); - } - if (parsedResult.total !== undefined) { - setTotal(parsedResult.total); - } - } catch { - // Ignore parsing errors - } - } - return false; - } catch { - setTaskStatus(TaskStatus.FAILURE); - return true; - } - }; - - pollTaskStatus(); - - const interval = setInterval(async () => { - const shouldStop = await pollTaskStatus(); - if (shouldStop) { - clearInterval(interval); - } - }, POLLING_INTERVALS.TASK_STATUS_MS); - - return () => clearInterval(interval); - }, [taskID, user]); + // Count active (non-terminal) tasks + const activeTaskCount = useMemo(() => { + return Object.values(tasks).filter( + (t) => !isTerminalStatus(t.status), + ).length; + }, [tasks]); const handleCancel = async () => { - if (!user || !taskID || isCancelling) return; + if (!activeTaskId || isCancelling) return; setIsCancelling(true); try { - const result = await cancelTask(user, taskID); - if (result.success) { - cancelledRef.current = true; - setTaskStatus(TaskStatus.REVOKED); - onTaskCancelled?.(); - } - } catch { - // Ignore cancel errors + await onCancelTask(activeTaskId); } finally { setIsCancelling(false); } }; const handleClearAll = async () => { - if (!user || isClearing) return; + if (isClearing) return; setIsClearing(true); try { - const result = await clearAllTasks(user); - if (result.success) { - cancelledRef.current = true; - setTaskStatus(null); - setTaskResult(null); - onTaskCancelled?.(); - } - } catch { - // Ignore clear errors + await onClearAllTasks(); } finally { setIsClearing(false); } }; - if (!taskID || !taskStatus) { + if (!activeTaskId || !taskStatus) { return null; } - const isInProgress = taskStatus !== TaskStatus.SUCCESS && - taskStatus !== TaskStatus.FAILURE && - taskStatus !== TaskStatus.REVOKED; + const isInProgress = !isTerminalStatus(taskStatus); const getStatusIcon = () => { if (isInProgress) { @@ -329,16 +201,16 @@ export function TaskIndicator({ {taskStatus} )} - {activeWsTaskCount > 1 && ( + {activeTaskCount > 1 && ( - {activeWsTaskCount} + {activeTaskCount} )}

{getTooltipContent()}

-

ID: {taskID.slice(0, 8)}...

+

ID: {activeTaskId.slice(0, 8)}...

{isInProgress && ( @@ -381,10 +253,10 @@ export function TaskIndicator({ onOpenChange={setDrawerOpen} taskResult={taskResult} taskStatus={taskStatus} - taskID={selectedTaskId ?? taskID} + taskID={selectedTaskId ?? activeTaskId} onCancel={handleCancel} isCancelling={isCancelling} - wsTasks={wsTasks} + tasks={tasks} selectedTaskId={selectedTaskId} onSelectTask={setSelectedTaskId} /> diff --git a/frontend/src/components/TaskProgressDrawer.tsx b/frontend/src/components/TaskProgressDrawer.tsx index a7df9ef..97b8f4d 100644 --- a/frontend/src/components/TaskProgressDrawer.tsx +++ b/frontend/src/components/TaskProgressDrawer.tsx @@ -19,7 +19,7 @@ interface TaskProgressDrawerProps { taskID: string | null; onCancel: () => void; isCancelling: boolean; - wsTasks?: Record; + tasks?: Record; selectedTaskId?: string | null; onSelectTask?: (taskId: string) => void; } @@ -415,16 +415,16 @@ export function TaskProgressDrawer({ taskID, onCancel, isCancelling, - wsTasks, + tasks, selectedTaskId, onSelectTask, }: TaskProgressDrawerProps) { // Determine which task's data to show - const hasMultipleTasks = wsTasks && Object.keys(wsTasks).length > 1; + const hasMultipleTasks = tasks && Object.keys(tasks).length > 1; const effectiveTaskId = selectedTaskId ?? taskID; - // Derive the active task data from wsTasks if available, else fall back to props - const activeWsTask = effectiveTaskId && wsTasks ? wsTasks[effectiveTaskId] : undefined; + // Derive the active task data from tasks if available, else fall back to props + const activeWsTask = effectiveTaskId && tasks ? tasks[effectiveTaskId] : undefined; const effectiveResult = activeWsTask ? taskStateToResult(activeWsTask) : taskResult; const effectiveStatus = activeWsTask ? (activeWsTask.status as TaskStatus) : taskStatus; const effectiveTaskType = activeWsTask @@ -462,7 +462,7 @@ export function TaskProgressDrawer({ {/* Multi-job tab bar */} {hasMultipleTasks && onSelectTask && ( diff --git a/frontend/src/hooks/useTaskWebSocket.ts b/frontend/src/hooks/useTaskWebSocket.ts deleted file mode 100644 index f7c491b..0000000 --- a/frontend/src/hooks/useTaskWebSocket.ts +++ /dev/null @@ -1,129 +0,0 @@ -import { useCallback, useEffect, useRef, useState } from 'react'; -import type { AuthUser } from '@/auth/types'; -import type { TaskState, WSMessage } from '@/types'; -import { WS_TASKS_PATH } from '@/constants'; - -const KEEPALIVE_MS = 30_000; -const MAX_RECONNECT_DELAY_MS = 30_000; - -function wsUrl(token: string): string { - const proto = window.location.protocol === 'https:' ? 'wss' : 'ws'; - return `${proto}://${window.location.host}${WS_TASKS_PATH}?token=${encodeURIComponent(token)}`; -} - -export interface UseTaskWebSocketReturn { - tasks: Record; - isConnected: boolean; - subscribe: (taskId: string) => void; -} - -export function useTaskWebSocket(user: AuthUser | null): UseTaskWebSocketReturn { - const [tasks, setTasks] = useState>({}); - const [isConnected, setIsConnected] = useState(false); - const wsRef = useRef(null); - const reconnectAttempt = useRef(0); - const reconnectTimer = useRef | null>(null); - const keepaliveTimer = useRef | null>(null); - const mountedRef = useRef(true); - - const clearTimers = useCallback(() => { - if (reconnectTimer.current) { - clearTimeout(reconnectTimer.current); - reconnectTimer.current = null; - } - if (keepaliveTimer.current) { - clearInterval(keepaliveTimer.current); - keepaliveTimer.current = null; - } - }, []); - - const connect = useCallback(() => { - if (!user) return; - - const ws = new WebSocket(wsUrl(user.accessToken)); - wsRef.current = ws; - - ws.onopen = () => { - if (!mountedRef.current) return; - setIsConnected(true); - reconnectAttempt.current = 0; - - // Start keepalive pings - keepaliveTimer.current = setInterval(() => { - if (ws.readyState === WebSocket.OPEN) { - ws.send(JSON.stringify({ type: 'ping' })); - } - }, KEEPALIVE_MS); - }; - - ws.onmessage = (event) => { - if (!mountedRef.current) return; - try { - const msg: WSMessage = JSON.parse(event.data); - - if (msg.type === 'init') { - const initial: Record = {}; - for (const t of msg.tasks) { - initial[t.task_id] = t; - } - setTasks(initial); - } else if (msg.type === 'task_update') { - const { type: _, ...taskData } = msg; - setTasks((prev) => ({ - ...prev, - [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState, - })); - } - // pong messages are ignored - } catch { - // Ignore malformed messages - } - }; - - ws.onclose = () => { - if (!mountedRef.current) return; - setIsConnected(false); - if (keepaliveTimer.current) { - clearInterval(keepaliveTimer.current); - keepaliveTimer.current = null; - } - - // Exponential backoff reconnect - const delay = Math.min( - 1000 * 2 ** reconnectAttempt.current, - MAX_RECONNECT_DELAY_MS, - ); - reconnectAttempt.current += 1; - reconnectTimer.current = setTimeout(() => { - if (mountedRef.current) connect(); - }, delay); - }; - - ws.onerror = () => { - // onclose will fire after this, triggering reconnect - }; - }, [user, clearTimers]); - - useEffect(() => { - mountedRef.current = true; - connect(); - - return () => { - mountedRef.current = false; - clearTimers(); - if (wsRef.current) { - wsRef.current.close(); - wsRef.current = null; - } - }; - }, [connect, clearTimers]); - - const subscribe = useCallback((taskId: string) => { - const ws = wsRef.current; - if (ws && ws.readyState === WebSocket.OPEN) { - ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId })); - } - }, []); - - return { tasks, isConnected, subscribe }; -} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 025e80e..e8b17f1 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -48,9 +48,15 @@ export enum TaskStatus { } export interface TaskStatusResponse { - status: TaskStatus; - result: string; // JSON string containing TaskResult - message?: string; + task_id: string; + status: string; + result: string | null; // JSON string containing TaskResult, or null + progress: number | null; + processed: number | null; + total: number | null; + message: string | null; + error: string | null; + traceback: string | null; } export type TaskPhase = 'splitting' | 'splitting_complete' | 'fetching' | 'processing' | 'completed';