Refactor task progress to unified useTaskProgress hook

Replace WebSocket-only useTaskWebSocket with useTaskProgress that
provides a unified task state interface. TaskIndicator no longer
manages its own polling or auth — it receives task state from the
parent via props. Rename wsTasks prop to tasks throughout.
This commit is contained in:
Viktor Barzin 2026-02-09 23:02:24 +00:00
parent 3616e678ac
commit 2d86213db5
No known key found for this signature in database
GPG key ID: 0EB088298288D958
6 changed files with 130 additions and 363 deletions

View file

@ -17,13 +17,16 @@ import { Sheet, SheetContent, SheetTrigger } from './components/ui/sheet';
import { Button } from './components/ui/button';
import { Filter } from 'lucide-react';
import type { GeoJSONFeatureCollection, PropertyProperties, PropertyFeature, POI, POITravelFilter } from '@/types';
import { refreshListings, fetchTasksForUser, streamListingGeoJSON, fetchUserPOIs, type StreamingProgress } from '@/services';
import { refreshListings, streamListingGeoJSON, fetchUserPOIs, type StreamingProgress } from '@/services';
import { poiMetricPropertyName, injectPoiMetricProperty } from '@/utils/poiUtils';
import { useTaskWebSocket } from '@/hooks/useTaskWebSocket';
import { useTaskProgress } from '@/hooks/useTaskProgress';
function isTerminalStatus(status: string): boolean {
return status === 'SUCCESS' || status === 'FAILURE' || status === 'REVOKED';
}
function App() {
const [listingData, setListingData] = useState<GeoJSONFeatureCollection | null>(null);
const [taskID, setTaskID] = useState<string | null>(null);
const [user, setUser] = useState<AuthUser | null>(null);
const [queryParameters, setQueryParameters] = useState<ParameterValues | null>(null);
const [submitError, setSubmitError] = useState<string | null>(null);
@ -44,8 +47,26 @@ function App() {
const [poiTravelFilters, setPoiTravelFilters] = useState<Record<number, POITravelFilter>>({});
const [currentMetric, setCurrentMetric] = useState<Metric>(DEFAULT_FILTER_VALUES.metric);
// WebSocket-based real-time task progress
const { tasks: wsTasks, isConnected: wsConnected, subscribe: wsSubscribe } = useTaskWebSocket(user);
// Explicit task ID set by fetch-data action (to track as "active")
const [explicitTaskId, setExplicitTaskId] = useState<string | null>(null);
// Unified task progress: WS primary, polling fallback
const { tasks, isConnected, subscribe, cancelTask, clearAllTasks } = useTaskProgress(user);
// Derive activeTaskId: explicit ID if set, else most recent non-terminal task
const activeTaskId = useMemo(() => {
if (explicitTaskId && tasks[explicitTaskId]) return explicitTaskId;
// Fall back to any non-terminal task
const nonTerminal = Object.entries(tasks).filter(
([, t]) => !isTerminalStatus(t.status),
);
if (nonTerminal.length > 0) return nonTerminal[0][0];
// Fall back to explicit even if terminal (to show final status)
if (explicitTaskId && tasks[explicitTaskId]) return explicitTaskId;
// Show most recent task if any
const allIds = Object.keys(tasks);
return allIds.length > 0 ? allIds[allIds.length - 1] : null;
}, [explicitTaskId, tasks]);
// Ref to track accumulated features during streaming
const accumulatedFeaturesRef = useRef<PropertyFeature[]>([]);
@ -77,17 +98,6 @@ function App() {
setUser(passkeyUser);
};
useEffect(() => {
if (!user) {
return;
}
fetchTasksForUser(user).then((tasks) => {
if (tasks && tasks.length > 0) {
setTaskID(tasks[0]);
}
});
}, [user, taskID]);
// Load user's POIs
useEffect(() => {
if (!user) return;
@ -235,6 +245,10 @@ function App() {
}
}, [queryParameters, loadListings]);
const handleTaskCancelled = useCallback(() => {
setExplicitTaskId(null);
}, []);
if (!user) {
return <LoginModal isOpen={user === null} onPasskeyLogin={handlePasskeyLogin} />;
}
@ -248,8 +262,8 @@ function App() {
setIsLoading(true);
try {
const data = await refreshListings(user!, parameters);
setTaskID(data.task_id);
if (data.task_id) wsSubscribe(data.task_id);
setExplicitTaskId(data.task_id);
if (data.task_id) subscribe(data.task_id);
} catch (error) {
if (error instanceof Error) {
setSubmitError(error.message);
@ -347,13 +361,9 @@ function App() {
);
};
const handleTaskCancelled = () => {
setTaskID(null);
};
const handlePOITaskCreated = (taskId: string) => {
setTaskID(taskId);
if (taskId) wsSubscribe(taskId);
setExplicitTaskId(taskId);
if (taskId) subscribe(taskId);
// Refresh POI list in case new ones were created
if (user) {
fetchUserPOIs(user).then(setUserPOIs).catch(() => {});
@ -379,12 +389,18 @@ function App() {
{/* Header */}
<Header
user={user}
taskID={taskID}
onTaskCancelled={handleTaskCancelled}
tasks={tasks}
activeTaskId={activeTaskId}
isConnected={isConnected}
onCancelTask={cancelTask}
onClearAllTasks={async () => {
const result = await clearAllTasks();
if (result) {
handleTaskCancelled();
}
return result;
}}
onTaskCompleted={handleTaskCompleted}
wsTasks={wsTasks}
wsConnected={wsConnected}
wsSubscribe={wsSubscribe}
/>
{/* Main content area */}

View file

@ -11,28 +11,29 @@ import { TaskIndicator } from './TaskIndicator';
interface HeaderProps {
user: AuthUser;
activeFilterCount?: number;
taskID?: string | null;
isLoading?: boolean;
onToggleFilters?: () => void;
showFilterToggle?: boolean;
onTaskCancelled?: () => void;
// Task progress (unified)
tasks: Record<string, TaskState>;
activeTaskId: string | null;
isConnected: boolean;
onCancelTask: (taskId: string) => Promise<boolean>;
onClearAllTasks: () => Promise<boolean>;
onTaskCompleted?: () => void;
wsTasks?: Record<string, TaskState>;
wsConnected?: boolean;
wsSubscribe?: (taskId: string) => void;
}
export function Header({
user,
activeFilterCount = 0,
taskID,
onToggleFilters,
showFilterToggle = false,
onTaskCancelled,
tasks,
activeTaskId,
isConnected,
onCancelTask,
onClearAllTasks,
onTaskCompleted,
wsTasks,
wsConnected,
wsSubscribe,
}: HeaderProps) {
const handleLogout = async () => {
if (user.provider === 'passkey') {
@ -58,11 +59,12 @@ export function Header({
{/* Task Indicator */}
<TaskIndicator
taskID={taskID ?? null}
onTaskCancelled={onTaskCancelled}
tasks={tasks}
activeTaskId={activeTaskId}
isConnected={isConnected}
onCancelTask={onCancelTask}
onClearAllTasks={onClearAllTasks}
onTaskCompleted={onTaskCompleted}
wsTasks={wsTasks}
wsConnected={wsConnected}
/>
{/* Filter Toggle (mobile) */}

View file

@ -1,8 +1,3 @@
import { getUser } from '@/auth/authService';
import { getStoredPasskeyUser } from '@/auth/passkeyService';
import { fromOidcUser, type AuthUser } from '@/auth/types';
import { POLLING_INTERVALS } from '@/constants';
import { fetchTaskStatus, cancelTask, clearAllTasks } from '@/services';
import { TaskStatus, type TaskResult, type TaskState } from '@/types';
import { useEffect, useState, useRef, useMemo } from 'react';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './ui/tooltip';
@ -11,14 +6,15 @@ import { Loader2, CheckCircle2, XCircle, X, Trash2 } from 'lucide-react';
import { TaskProgressDrawer } from './TaskProgressDrawer';
interface TaskIndicatorProps {
taskID: string | null;
onTaskCancelled?: () => void;
tasks: Record<string, TaskState>;
activeTaskId: string | null;
isConnected: boolean;
onCancelTask: (taskId: string) => Promise<boolean>;
onClearAllTasks: () => Promise<boolean>;
onTaskCompleted?: () => void;
wsTasks?: Record<string, TaskState>;
wsConnected?: boolean;
}
/** Convert a TaskState (from WS) into a TaskResult (for the drawer). */
/** Convert a TaskState into a TaskResult (for the drawer). */
function taskStateToResult(ts: TaskState): TaskResult {
return {
progress: ts.progress ?? 0,
@ -50,215 +46,91 @@ function isTerminalStatus(status: string): boolean {
}
export function TaskIndicator({
taskID,
onTaskCancelled,
tasks,
activeTaskId,
isConnected: _isConnected,
onCancelTask,
onClearAllTasks,
onTaskCompleted,
wsTasks,
wsConnected,
}: TaskIndicatorProps) {
const [user, setUser] = useState<AuthUser | null>(null);
const [progressPercentage, setProgressPercentage] = useState<number>(0);
const [processed, setProcessed] = useState<number | null>(null);
const [total, setTotal] = useState<number | null>(null);
const [taskStatus, setTaskStatus] = useState<TaskStatus | null>(null);
const [taskResult, setTaskResult] = useState<TaskResult | null>(null);
const [isCancelling, setIsCancelling] = useState(false);
const [isClearing, setIsClearing] = useState(false);
const [drawerOpen, setDrawerOpen] = useState(false);
const [selectedTaskId, setSelectedTaskId] = useState<string | null>(null);
// Prevents WS effect from overwriting local cancel/clear state before
// the parent's setTaskID(null) propagates. Reset when taskID changes.
const cancelledRef = useRef(false);
const onTaskCompletedRef = useRef(onTaskCompleted);
useEffect(() => {
onTaskCompletedRef.current = onTaskCompleted;
}, [onTaskCompleted]);
// Track the currently-viewed task in the drawer; default to the externally-provided activeTaskId
useEffect(() => {
const passkeyUser = getStoredPasskeyUser();
if (passkeyUser) {
setUser(passkeyUser);
} else {
getUser().then((oidcUser) => {
if (oidcUser) setUser(fromOidcUser(oidcUser));
});
if (activeTaskId) {
setSelectedTaskId(activeTaskId);
}
}, []);
}, [activeTaskId]);
// Track the currently-viewed task in the drawer; default to the externally-provided taskID
// Fire onTaskCompleted when the active task transitions to SUCCESS
const prevStatusRef = useRef<string | null>(null);
useEffect(() => {
if (taskID) {
setSelectedTaskId(taskID);
cancelledRef.current = false; // new task, reset cancelled guard
}
}, [taskID]);
// Count active (non-terminal) tasks from WS
const activeWsTaskCount = useMemo(() => {
if (!wsTasks) return 0;
return Object.values(wsTasks).filter(
(t) => !isTerminalStatus(t.status),
).length;
}, [wsTasks]);
// ----- WebSocket-driven state updates -----
// When wsConnected, derive taskStatus/taskResult/progress from wsTasks
useEffect(() => {
if (!wsConnected || !wsTasks || !taskID) return;
// Don't let WS overwrite local cancel/clear state
if (cancelledRef.current) return;
const wsTask = wsTasks[taskID];
if (!wsTask) return;
const status = wsTask.status as TaskStatus;
setTaskStatus(status);
if (wsTask.phase) {
setTaskResult(taskStateToResult(wsTask));
}
if (wsTask.progress !== undefined) {
setProgressPercentage(wsTask.progress * 100);
}
if (wsTask.processed !== undefined) {
setProcessed(wsTask.processed);
}
if (wsTask.total !== undefined) {
setTotal(wsTask.total);
}
if (status === TaskStatus.SUCCESS) {
setProgressPercentage(100);
onTaskCompletedRef.current?.();
}
}, [wsConnected, wsTasks, taskID]);
// ----- Polling (always active as baseline; WS provides faster updates on top) -----
useEffect(() => {
if (!user || !taskID) {
setTaskStatus(null);
setTaskResult(null);
if (!activeTaskId) {
prevStatusRef.current = null;
return;
}
const task = tasks[activeTaskId];
const currentStatus = task?.status ?? null;
if (
currentStatus === 'SUCCESS' &&
prevStatusRef.current !== null &&
prevStatusRef.current !== 'SUCCESS'
) {
onTaskCompletedRef.current?.();
}
prevStatusRef.current = currentStatus;
}, [activeTaskId, tasks]);
// Reset state for new task
setTaskStatus(TaskStatus.PENDING);
setProgressPercentage(0);
setProcessed(null);
setTotal(null);
setTaskResult(null);
// Derive display data from the active task
const activeTask = activeTaskId ? tasks[activeTaskId] : undefined;
const taskStatus = activeTask ? (activeTask.status as TaskStatus) : null;
const taskResult = activeTask?.phase ? taskStateToResult(activeTask) : null;
const progressPercentage = (activeTask?.progress ?? 0) * 100;
const processed = activeTask?.processed ?? null;
const total = activeTask?.total ?? null;
const pollTaskStatus = async () => {
// Skip this poll cycle if cancelled locally
if (cancelledRef.current) return true;
try {
const data = await fetchTaskStatus(user, taskID);
const status = data.status as TaskStatus;
setTaskStatus(status);
if (status === TaskStatus.SUCCESS) {
setProgressPercentage(100);
if (data.result) {
try {
const parsedResult: TaskResult = JSON.parse(data.result);
if (parsedResult.phase) {
setTaskResult(parsedResult);
}
} catch {
// Ignore parsing errors
}
}
onTaskCompletedRef.current?.();
return true;
}
if (status === TaskStatus.FAILURE || status === TaskStatus.REVOKED) {
return true;
}
if (data.result) {
try {
const parsedResult: TaskResult = JSON.parse(data.result);
if (parsedResult.phase) {
setTaskResult(parsedResult);
}
if (parsedResult.progress !== undefined) {
setProgressPercentage(parsedResult.progress * 100);
}
if (parsedResult.processed !== undefined) {
setProcessed(parsedResult.processed);
}
if (parsedResult.total !== undefined) {
setTotal(parsedResult.total);
}
} catch {
// Ignore parsing errors
}
}
return false;
} catch {
setTaskStatus(TaskStatus.FAILURE);
return true;
}
};
pollTaskStatus();
const interval = setInterval(async () => {
const shouldStop = await pollTaskStatus();
if (shouldStop) {
clearInterval(interval);
}
}, POLLING_INTERVALS.TASK_STATUS_MS);
return () => clearInterval(interval);
}, [taskID, user]);
// Count active (non-terminal) tasks
const activeTaskCount = useMemo(() => {
return Object.values(tasks).filter(
(t) => !isTerminalStatus(t.status),
).length;
}, [tasks]);
const handleCancel = async () => {
if (!user || !taskID || isCancelling) return;
if (!activeTaskId || isCancelling) return;
setIsCancelling(true);
try {
const result = await cancelTask(user, taskID);
if (result.success) {
cancelledRef.current = true;
setTaskStatus(TaskStatus.REVOKED);
onTaskCancelled?.();
}
} catch {
// Ignore cancel errors
await onCancelTask(activeTaskId);
} finally {
setIsCancelling(false);
}
};
const handleClearAll = async () => {
if (!user || isClearing) return;
if (isClearing) return;
setIsClearing(true);
try {
const result = await clearAllTasks(user);
if (result.success) {
cancelledRef.current = true;
setTaskStatus(null);
setTaskResult(null);
onTaskCancelled?.();
}
} catch {
// Ignore clear errors
await onClearAllTasks();
} finally {
setIsClearing(false);
}
};
if (!taskID || !taskStatus) {
if (!activeTaskId || !taskStatus) {
return null;
}
const isInProgress = taskStatus !== TaskStatus.SUCCESS &&
taskStatus !== TaskStatus.FAILURE &&
taskStatus !== TaskStatus.REVOKED;
const isInProgress = !isTerminalStatus(taskStatus);
const getStatusIcon = () => {
if (isInProgress) {
@ -329,16 +201,16 @@ export function TaskIndicator({
{taskStatus}
</span>
)}
{activeWsTaskCount > 1 && (
{activeTaskCount > 1 && (
<span className="inline-flex items-center justify-center h-4 min-w-[16px] rounded-full bg-blue-500 text-[10px] font-medium text-white px-1">
{activeWsTaskCount}
{activeTaskCount}
</span>
)}
</div>
</TooltipTrigger>
<TooltipContent side="bottom">
<p>{getTooltipContent()}</p>
<p className="text-xs text-muted-foreground mt-1">ID: {taskID.slice(0, 8)}...</p>
<p className="text-xs text-muted-foreground mt-1">ID: {activeTaskId.slice(0, 8)}...</p>
</TooltipContent>
</Tooltip>
{isInProgress && (
@ -381,10 +253,10 @@ export function TaskIndicator({
onOpenChange={setDrawerOpen}
taskResult={taskResult}
taskStatus={taskStatus}
taskID={selectedTaskId ?? taskID}
taskID={selectedTaskId ?? activeTaskId}
onCancel={handleCancel}
isCancelling={isCancelling}
wsTasks={wsTasks}
tasks={tasks}
selectedTaskId={selectedTaskId}
onSelectTask={setSelectedTaskId}
/>

View file

@ -19,7 +19,7 @@ interface TaskProgressDrawerProps {
taskID: string | null;
onCancel: () => void;
isCancelling: boolean;
wsTasks?: Record<string, TaskState>;
tasks?: Record<string, TaskState>;
selectedTaskId?: string | null;
onSelectTask?: (taskId: string) => void;
}
@ -415,16 +415,16 @@ export function TaskProgressDrawer({
taskID,
onCancel,
isCancelling,
wsTasks,
tasks,
selectedTaskId,
onSelectTask,
}: TaskProgressDrawerProps) {
// Determine which task's data to show
const hasMultipleTasks = wsTasks && Object.keys(wsTasks).length > 1;
const hasMultipleTasks = tasks && Object.keys(tasks).length > 1;
const effectiveTaskId = selectedTaskId ?? taskID;
// Derive the active task data from wsTasks if available, else fall back to props
const activeWsTask = effectiveTaskId && wsTasks ? wsTasks[effectiveTaskId] : undefined;
// Derive the active task data from tasks if available, else fall back to props
const activeWsTask = effectiveTaskId && tasks ? tasks[effectiveTaskId] : undefined;
const effectiveResult = activeWsTask ? taskStateToResult(activeWsTask) : taskResult;
const effectiveStatus = activeWsTask ? (activeWsTask.status as TaskStatus) : taskStatus;
const effectiveTaskType = activeWsTask
@ -462,7 +462,7 @@ export function TaskProgressDrawer({
{/* Multi-job tab bar */}
{hasMultipleTasks && onSelectTask && (
<TaskTabBar
tasks={wsTasks!}
tasks={tasks!}
selectedTaskId={effectiveTaskId}
onSelectTask={onSelectTask}
/>

View file

@ -1,129 +0,0 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import type { AuthUser } from '@/auth/types';
import type { TaskState, WSMessage } from '@/types';
import { WS_TASKS_PATH } from '@/constants';
const KEEPALIVE_MS = 30_000;
const MAX_RECONNECT_DELAY_MS = 30_000;
function wsUrl(token: string): string {
const proto = window.location.protocol === 'https:' ? 'wss' : 'ws';
return `${proto}://${window.location.host}${WS_TASKS_PATH}?token=${encodeURIComponent(token)}`;
}
export interface UseTaskWebSocketReturn {
tasks: Record<string, TaskState>;
isConnected: boolean;
subscribe: (taskId: string) => void;
}
export function useTaskWebSocket(user: AuthUser | null): UseTaskWebSocketReturn {
const [tasks, setTasks] = useState<Record<string, TaskState>>({});
const [isConnected, setIsConnected] = useState(false);
const wsRef = useRef<WebSocket | null>(null);
const reconnectAttempt = useRef(0);
const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
const keepaliveTimer = useRef<ReturnType<typeof setInterval> | null>(null);
const mountedRef = useRef(true);
const clearTimers = useCallback(() => {
if (reconnectTimer.current) {
clearTimeout(reconnectTimer.current);
reconnectTimer.current = null;
}
if (keepaliveTimer.current) {
clearInterval(keepaliveTimer.current);
keepaliveTimer.current = null;
}
}, []);
const connect = useCallback(() => {
if (!user) return;
const ws = new WebSocket(wsUrl(user.accessToken));
wsRef.current = ws;
ws.onopen = () => {
if (!mountedRef.current) return;
setIsConnected(true);
reconnectAttempt.current = 0;
// Start keepalive pings
keepaliveTimer.current = setInterval(() => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: 'ping' }));
}
}, KEEPALIVE_MS);
};
ws.onmessage = (event) => {
if (!mountedRef.current) return;
try {
const msg: WSMessage = JSON.parse(event.data);
if (msg.type === 'init') {
const initial: Record<string, TaskState> = {};
for (const t of msg.tasks) {
initial[t.task_id] = t;
}
setTasks(initial);
} else if (msg.type === 'task_update') {
const { type: _, ...taskData } = msg;
setTasks((prev) => ({
...prev,
[msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState,
}));
}
// pong messages are ignored
} catch {
// Ignore malformed messages
}
};
ws.onclose = () => {
if (!mountedRef.current) return;
setIsConnected(false);
if (keepaliveTimer.current) {
clearInterval(keepaliveTimer.current);
keepaliveTimer.current = null;
}
// Exponential backoff reconnect
const delay = Math.min(
1000 * 2 ** reconnectAttempt.current,
MAX_RECONNECT_DELAY_MS,
);
reconnectAttempt.current += 1;
reconnectTimer.current = setTimeout(() => {
if (mountedRef.current) connect();
}, delay);
};
ws.onerror = () => {
// onclose will fire after this, triggering reconnect
};
}, [user, clearTimers]);
useEffect(() => {
mountedRef.current = true;
connect();
return () => {
mountedRef.current = false;
clearTimers();
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
}
};
}, [connect, clearTimers]);
const subscribe = useCallback((taskId: string) => {
const ws = wsRef.current;
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
}
}, []);
return { tasks, isConnected, subscribe };
}

View file

@ -48,9 +48,15 @@ export enum TaskStatus {
}
export interface TaskStatusResponse {
status: TaskStatus;
result: string; // JSON string containing TaskResult
message?: string;
task_id: string;
status: string;
result: string | null; // JSON string containing TaskResult, or null
progress: number | null;
processed: number | null;
total: number | null;
message: string | null;
error: string | null;
traceback: string | null;
}
export type TaskPhase = 'splitting' | 'splitting_complete' | 'fetching' | 'processing' | 'completed';