From 8559c4b461dc89a47b3d77e81bf13f44ec09953e Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Mon, 9 Feb 2026 21:31:45 +0000 Subject: [PATCH] Add real-time WebSocket task progress with multi-job drawer Replace 5s HTTP polling with WebSocket-based real-time updates for task progress. Celery workers publish progress to Redis pub/sub channels; a FastAPI WebSocket endpoint subscribes and forwards to the browser. Polling is kept as a 30s fallback when WebSocket is unavailable. The task progress drawer now supports multiple concurrent jobs with a tab bar for switching between scrape and POI distance tasks. Backend: - Add services/task_progress_publisher.py (Redis pub/sub bridge) - Add api/ws_routes.py (WebSocket endpoint with JWT auth) - Publish progress from listing_tasks and poi_tasks - Publish REVOKED via pub/sub on cancel/clear to fix stuck UI Frontend: - Add useTaskWebSocket hook with reconnection and keepalive - Add TaskState and WS message types - TaskIndicator: WS-driven updates with polling fallback - TaskProgressDrawer: multi-job tabs, POI phase timeline - Guard against WS overwriting local cancel state --- api/ws_routes.py | 153 +++++++++++ frontend/src/components/Header.tsx | 15 +- frontend/src/components/TaskIndicator.tsx | 135 +++++++-- .../src/components/TaskProgressDrawer.tsx | 259 +++++++++++++++--- frontend/src/constants/index.ts | 5 +- frontend/src/hooks/useTaskWebSocket.ts | 129 +++++++++ frontend/src/types/index.ts | 52 ++++ services/task_progress_publisher.py | 50 ++++ services/task_service.py | 15 + tasks/listing_tasks.py | 12 +- tasks/poi_tasks.py | 21 +- 11 files changed, 774 insertions(+), 72 deletions(-) create mode 100644 api/ws_routes.py create mode 100644 frontend/src/hooks/useTaskWebSocket.ts create mode 100644 services/task_progress_publisher.py diff --git a/api/ws_routes.py b/api/ws_routes.py new file mode 100644 index 0000000..1e90741 --- /dev/null +++ b/api/ws_routes.py @@ -0,0 +1,153 @@ +"""WebSocket endpoint for real-time task progress updates. + +Clients connect to ``/api/ws/tasks?token=`` and receive live progress +messages published by Celery workers via Redis pub/sub. +""" +import asyncio +import json +import logging +from typing import Any + +import jwt +import redis.asyncio as aioredis +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +from api.auth import _verify_authentik_token, _verify_passkey_token, User +from api.config import JWT_ISSUER +from services import task_service + +logger = logging.getLogger(__name__) + +ws_router = APIRouter() + +# Reuse the broker URL for the async Redis client +import os +_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0") + + +async def _authenticate_ws(token: str) -> User | None: + """Verify a JWT token using the same logic as api/auth.py.""" + try: + unverified = jwt.decode( + token, options={"verify_signature": False, "verify_exp": False} + ) + issuer = unverified.get("iss", "") + if issuer == JWT_ISSUER: + return _verify_passkey_token(token) + else: + return await _verify_authentik_token(token) + except Exception: + return None + + +async def _build_task_snapshot(task_id: str) -> dict[str, Any]: + """Build a snapshot of a task's current status for the init message.""" + status = task_service.get_task_status(task_id) + result: dict[str, Any] = { + "task_id": status.task_id, + "status": status.status, + "progress": status.progress, + "processed": status.processed, + "total": status.total, + "message": status.message, + } + if status.result and isinstance(status.result, dict): + result.update(status.result) + return result + + +@ws_router.websocket("/api/ws/tasks") +async def ws_task_progress(websocket: WebSocket) -> None: + token = websocket.query_params.get("token") + if not token: + await websocket.close(code=4001, reason="Missing token") + return + + user = await _authenticate_ws(token) + if user is None: + await websocket.close(code=4003, reason="Invalid token") + return + + await websocket.accept() + + # Get user's tasks and send initial snapshot + task_ids = task_service.get_user_tasks(user.email) + snapshots = [] + for tid in task_ids: + try: + snapshots.append(await _build_task_snapshot(tid)) + except Exception: + logger.debug("Failed to build snapshot for task %s", tid, exc_info=True) + + try: + await websocket.send_json({"type": "init", "tasks": snapshots}) + except Exception: + return + + # Subscribe to Redis pub/sub channels for each task + redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True) + pubsub = redis_client.pubsub() + + subscribed_channels: set[str] = set() + for tid in task_ids: + channel = f"task_progress:{tid}" + await pubsub.subscribe(channel) + subscribed_channels.add(channel) + + async def _forward_pubsub() -> None: + """Read from Redis pub/sub and forward to the WebSocket.""" + while True: + message = await pubsub.get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) + if message and message["type"] == "message": + try: + data = json.loads(message["data"]) + await websocket.send_json({"type": "task_update", **data}) + except Exception: + break + + async def _handle_client_messages() -> None: + """Read messages from the client (subscribe, ping).""" + while True: + try: + raw = await websocket.receive_text() + msg = json.loads(raw) + except WebSocketDisconnect: + raise + except Exception: + continue + + msg_type = msg.get("type") + if msg_type == "subscribe": + new_task_id = msg.get("task_id") + if new_task_id: + channel = f"task_progress:{new_task_id}" + if channel not in subscribed_channels: + await pubsub.subscribe(channel) + subscribed_channels.add(channel) + # Send current snapshot for the new task + try: + snapshot = await _build_task_snapshot(new_task_id) + await websocket.send_json( + {"type": "task_update", **snapshot} + ) + except Exception: + pass + elif msg_type == "ping": + try: + await websocket.send_json({"type": "pong"}) + except Exception: + break + + try: + await asyncio.gather( + _forward_pubsub(), + _handle_client_messages(), + ) + except (WebSocketDisconnect, Exception): + pass + finally: + await pubsub.unsubscribe(*subscribed_channels) + await pubsub.close() + await redis_client.aclose() diff --git a/frontend/src/components/Header.tsx b/frontend/src/components/Header.tsx index c49f377..f2e5c42 100644 --- a/frontend/src/components/Header.tsx +++ b/frontend/src/components/Header.tsx @@ -1,4 +1,5 @@ import type { AuthUser } from '@/auth/types'; +import type { TaskState } from '@/types'; import { Button } from './ui/button'; import { Separator } from './ui/separator'; import { LogOut, Home, Filter } from 'lucide-react'; @@ -16,6 +17,9 @@ interface HeaderProps { showFilterToggle?: boolean; onTaskCancelled?: () => void; onTaskCompleted?: () => void; + wsTasks?: Record; + wsConnected?: boolean; + wsSubscribe?: (taskId: string) => void; } export function Header({ @@ -26,6 +30,9 @@ export function Header({ showFilterToggle = false, onTaskCancelled, onTaskCompleted, + wsTasks, + wsConnected, + wsSubscribe, }: HeaderProps) { const handleLogout = async () => { if (user.provider === 'passkey') { @@ -50,7 +57,13 @@ export function Header({ {/* Task Indicator */} - + {/* Filter Toggle (mobile) */} {showFilterToggle && ( diff --git a/frontend/src/components/TaskIndicator.tsx b/frontend/src/components/TaskIndicator.tsx index a297696..140e45c 100644 --- a/frontend/src/components/TaskIndicator.tsx +++ b/frontend/src/components/TaskIndicator.tsx @@ -3,8 +3,8 @@ import { getStoredPasskeyUser } from '@/auth/passkeyService'; import { fromOidcUser, type AuthUser } from '@/auth/types'; import { POLLING_INTERVALS } from '@/constants'; import { fetchTaskStatus, cancelTask, clearAllTasks } from '@/services'; -import { TaskStatus, type TaskResult } from '@/types'; -import { useEffect, useState, useRef } from 'react'; +import { TaskStatus, type TaskResult, type TaskState } from '@/types'; +import { useEffect, useState, useRef, useMemo } from 'react'; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './ui/tooltip'; import { Button } from './ui/button'; import { Loader2, CheckCircle2, XCircle, X, Trash2 } from 'lucide-react'; @@ -14,9 +14,48 @@ interface TaskIndicatorProps { taskID: string | null; onTaskCancelled?: () => void; onTaskCompleted?: () => void; + wsTasks?: Record; + wsConnected?: boolean; } -export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: TaskIndicatorProps) { +/** Convert a TaskState (from WS) into a TaskResult (for the drawer). */ +function taskStateToResult(ts: TaskState): TaskResult { + return { + progress: ts.progress ?? 0, + processed: ts.processed, + total: ts.total, + phase: ts.phase, + message: ts.message, + subqueries_probed: ts.subqueries_probed, + subqueries_initial: ts.subqueries_initial, + estimated_results: ts.estimated_results, + subqueries_total: ts.subqueries_total, + subqueries_completed: ts.subqueries_completed, + ids_collected: ts.ids_collected, + pages_fetched: ts.pages_fetched, + fetching_done: ts.fetching_done, + details_fetched: ts.details_fetched, + images_downloaded: ts.images_downloaded, + ocr_completed: ts.ocr_completed, + failed: ts.failed, + elapsed_seconds: ts.elapsed_seconds, + rate_per_second: ts.rate_per_second, + eta_seconds: ts.eta_seconds, + logs: ts.logs, + }; +} + +function isTerminalStatus(status: string): boolean { + return status === 'SUCCESS' || status === 'FAILURE' || status === 'REVOKED'; +} + +export function TaskIndicator({ + taskID, + onTaskCancelled, + onTaskCompleted, + wsTasks, + wsConnected, +}: TaskIndicatorProps) { const [user, setUser] = useState(null); const [progressPercentage, setProgressPercentage] = useState(0); const [processed, setProcessed] = useState(null); @@ -26,6 +65,11 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task const [isCancelling, setIsCancelling] = useState(false); const [isClearing, setIsClearing] = useState(false); const [drawerOpen, setDrawerOpen] = useState(false); + const [selectedTaskId, setSelectedTaskId] = useState(null); + + // Prevents WS effect from overwriting local cancel/clear state before + // the parent's setTaskID(null) propagates. Reset when taskID changes. + const cancelledRef = useRef(false); const onTaskCompletedRef = useRef(onTaskCompleted); useEffect(() => { @@ -43,7 +87,58 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task } }, []); + // Track the currently-viewed task in the drawer; default to the externally-provided taskID useEffect(() => { + if (taskID) { + setSelectedTaskId(taskID); + cancelledRef.current = false; // new task, reset cancelled guard + } + }, [taskID]); + + // Count active (non-terminal) tasks from WS + const activeWsTaskCount = useMemo(() => { + if (!wsTasks) return 0; + return Object.values(wsTasks).filter( + (t) => !isTerminalStatus(t.status), + ).length; + }, [wsTasks]); + + // ----- WebSocket-driven state updates ----- + // When wsConnected, derive taskStatus/taskResult/progress from wsTasks + useEffect(() => { + if (!wsConnected || !wsTasks || !taskID) return; + // Don't let WS overwrite local cancel/clear state + if (cancelledRef.current) return; + const wsTask = wsTasks[taskID]; + if (!wsTask) return; + + const status = wsTask.status as TaskStatus; + setTaskStatus(status); + + if (wsTask.phase) { + setTaskResult(taskStateToResult(wsTask)); + } + if (wsTask.progress !== undefined) { + setProgressPercentage(wsTask.progress * 100); + } + if (wsTask.processed !== undefined) { + setProcessed(wsTask.processed); + } + if (wsTask.total !== undefined) { + setTotal(wsTask.total); + } + + if (status === TaskStatus.SUCCESS) { + setProgressPercentage(100); + onTaskCompletedRef.current?.(); + } + }, [wsConnected, wsTasks, taskID]); + + // ----- Polling fallback (only when WS is not connected) ----- + useEffect(() => { + // If WS is connected, skip polling + if (wsConnected) return; + if (!user || !taskID) { setTaskStatus(null); setTaskResult(null); @@ -65,10 +160,6 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task if (status === TaskStatus.SUCCESS) { setProgressPercentage(100); - // Parse final result for the drawer to show completed state. - // Only update taskResult if the new result has phase info; - // otherwise keep the last in-progress result which has richer data - // than the bare SUCCESS return value. if (data.result) { try { const parsedResult: TaskResult = JSON.parse(data.result); @@ -80,26 +171,19 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task } } onTaskCompletedRef.current?.(); - return true; // Stop polling + return true; } if (status === TaskStatus.FAILURE || status === TaskStatus.REVOKED) { - return true; // Stop polling + return true; } - // Parse progress for in-progress tasks if (data.result) { try { const parsedResult: TaskResult = JSON.parse(data.result); - // Only update taskResult if the parsed data has a phase field. - // This prevents blanking the drawer when the backend sends a - // state update without phase info (e.g. during brief transitions). if (parsedResult.phase) { setTaskResult(parsedResult); } - // Only update progress/processed/total when the fields are - // actually present — otherwise keep the previous values so - // the UI doesn't flash back to 0 during phase transitions. if (parsedResult.progress !== undefined) { setProgressPercentage(parsedResult.progress * 100); } @@ -113,14 +197,13 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task // Ignore parsing errors } } - return false; // Continue polling + return false; } catch { setTaskStatus(TaskStatus.FAILURE); - return true; // Stop polling on error + return true; } }; - // Initial poll pollTaskStatus(); const interval = setInterval(async () => { @@ -131,7 +214,7 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task }, POLLING_INTERVALS.TASK_STATUS_MS); return () => clearInterval(interval); - }, [taskID, user]); + }, [taskID, user, wsConnected]); const handleCancel = async () => { if (!user || !taskID || isCancelling) return; @@ -140,6 +223,7 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task try { const result = await cancelTask(user, taskID); if (result.success) { + cancelledRef.current = true; setTaskStatus(TaskStatus.REVOKED); onTaskCancelled?.(); } @@ -157,6 +241,7 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task try { const result = await clearAllTasks(user); if (result.success) { + cancelledRef.current = true; setTaskStatus(null); setTaskResult(null); onTaskCancelled?.(); @@ -245,6 +330,11 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task {taskStatus} )} + {activeWsTaskCount > 1 && ( + + {activeWsTaskCount} + + )} @@ -292,9 +382,12 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task onOpenChange={setDrawerOpen} taskResult={taskResult} taskStatus={taskStatus} - taskID={taskID} + taskID={selectedTaskId ?? taskID} onCancel={handleCancel} isCancelling={isCancelling} + wsTasks={wsTasks} + selectedTaskId={selectedTaskId} + onSelectTask={setSelectedTaskId} /> ); diff --git a/frontend/src/components/TaskProgressDrawer.tsx b/frontend/src/components/TaskProgressDrawer.tsx index c1ac3c3..a7df9ef 100644 --- a/frontend/src/components/TaskProgressDrawer.tsx +++ b/frontend/src/components/TaskProgressDrawer.tsx @@ -1,4 +1,4 @@ -import { TaskStatus, type TaskPhase, type TaskResult } from '@/types'; +import { TaskStatus, type TaskPhase, type TaskResult, type TaskState } from '@/types'; import { Sheet, SheetContent, @@ -8,8 +8,8 @@ import { SheetFooter, } from './ui/sheet'; import { Button } from './ui/button'; -import { CheckCircle2, Circle, Loader2, XCircle } from 'lucide-react'; -import { useEffect, useRef } from 'react'; +import { CheckCircle2, Circle, Loader2, XCircle, MapPin, Search } from 'lucide-react'; +import { useEffect, useRef, useMemo } from 'react'; interface TaskProgressDrawerProps { open: boolean; @@ -19,19 +19,62 @@ interface TaskProgressDrawerProps { taskID: string | null; onCancel: () => void; isCancelling: boolean; + wsTasks?: Record; + selectedTaskId?: string | null; + onSelectTask?: (taskId: string) => void; } -const PHASES: { key: TaskPhase; label: string }[] = [ +const SCRAPE_PHASES: { key: TaskPhase; label: string }[] = [ { key: 'splitting', label: 'Splitting queries' }, { key: 'fetching', label: 'Fetching & processing' }, { key: 'processing', label: 'Processing remaining' }, ]; +const POI_PHASES: { key: string; label: string }[] = [ + { key: 'starting', label: 'Starting' }, + { key: 'computing', label: 'Computing distances' }, + { key: 'completed', label: 'Completed' }, +]; + +function inferTaskType(task: TaskState | TaskResult): 'scrape' | 'poi' | 'task' { + if ('distances_computed' in task && task.distances_computed !== undefined) return 'poi'; + if ('subqueries_completed' in task || 'ids_collected' in task || 'pages_fetched' in task) return 'scrape'; + const phase = task.phase; + if (phase === 'starting' || phase === 'computing') return 'poi'; + if (phase === 'splitting' || phase === 'splitting_complete' || phase === 'fetching' || phase === 'processing') return 'scrape'; + return 'task'; +} + +function taskTypeLabel(type: 'scrape' | 'poi' | 'task'): string { + switch (type) { + case 'scrape': return 'Scrape'; + case 'poi': return 'POI Distances'; + default: return 'Task'; + } +} + +function taskTypeIcon(type: 'scrape' | 'poi' | 'task') { + switch (type) { + case 'scrape': return ; + case 'poi': return ; + default: return ; + } +} + +function isTerminalStatus(status: string): boolean { + return status === 'SUCCESS' || status === 'FAILURE' || status === 'REVOKED'; +} + function getPhaseIndex(phase: TaskPhase | undefined): number { if (!phase) return -1; - if (phase === 'splitting_complete') return 1; // splitting done, fetching is next - if (phase === 'completed') return PHASES.length; - return PHASES.findIndex((p) => p.key === phase); + if (phase === 'splitting_complete') return 1; + if (phase === 'completed') return SCRAPE_PHASES.length; + return SCRAPE_PHASES.findIndex((p) => p.key === phase); +} + +function getPoiPhaseIndex(phase: string | undefined): number { + if (!phase) return -1; + return POI_PHASES.findIndex((p) => p.key === phase); } function formatEta(seconds: number | undefined): string { @@ -44,13 +87,10 @@ function formatEta(seconds: number | undefined): string { return `~${secs}s remaining`; } -function StatusBadge({ status }: { status: TaskStatus | null }) { +function StatusBadge({ status }: { status: TaskStatus | string | null }) { if (!status) return null; - const isInProgress = - status !== TaskStatus.SUCCESS && - status !== TaskStatus.FAILURE && - status !== TaskStatus.REVOKED; + const isInProgress = !isTerminalStatus(status); if (isInProgress) { return ( @@ -60,7 +100,7 @@ function StatusBadge({ status }: { status: TaskStatus | null }) { ); } - if (status === TaskStatus.SUCCESS) { + if (status === 'SUCCESS') { return ( @@ -68,7 +108,7 @@ function StatusBadge({ status }: { status: TaskStatus | null }) { ); } - if (status === TaskStatus.REVOKED) { + if (status === 'REVOKED') { return ( @@ -87,21 +127,23 @@ function StatusBadge({ status }: { status: TaskStatus | null }) { function PhaseTimeline({ currentPhase, taskStatus, + taskType, }: { - currentPhase: TaskPhase | undefined; - taskStatus: TaskStatus | null; + currentPhase: TaskPhase | string | undefined; + taskStatus: TaskStatus | string | null; + taskType: 'scrape' | 'poi' | 'task'; }) { - const isTerminal = - taskStatus === TaskStatus.SUCCESS || - taskStatus === TaskStatus.FAILURE || - taskStatus === TaskStatus.REVOKED; - const activeIdx = isTerminal ? PHASES.length : getPhaseIndex(currentPhase); + const phases = taskType === 'poi' ? POI_PHASES : SCRAPE_PHASES; + const terminal = taskStatus !== null && isTerminalStatus(taskStatus); + const activeIdx = taskType === 'poi' + ? (terminal ? phases.length : getPoiPhaseIndex(currentPhase)) + : (terminal ? phases.length : getPhaseIndex(currentPhase as TaskPhase | undefined)); return (
- {PHASES.map((phase, idx) => { + {phases.map((phase, idx) => { const isCompleted = idx < activeIdx; - const isActive = idx === activeIdx && !isTerminal; + const isActive = idx === activeIdx && !terminal; const isFuture = idx > activeIdx; return ( @@ -231,6 +273,20 @@ function PhaseDetails({ result }: { result: TaskResult }) { return null; } +function POIPhaseDetails({ task }: { task: TaskState }) { + return ( +
+

+ POI Distances +

+ + {task.distances_computed !== undefined && ( + + )} +
+ ); +} + function LogViewer({ logs }: { logs: string[] }) { const scrollRef = useRef(null); const isAutoScrolling = useRef(true); @@ -267,6 +323,90 @@ function LogViewer({ logs }: { logs: string[] }) { ); } +/** Convert TaskState → TaskResult for existing phase detail components. */ +function taskStateToResult(ts: TaskState): TaskResult { + return { + progress: ts.progress ?? 0, + processed: ts.processed, + total: ts.total, + phase: ts.phase, + message: ts.message, + subqueries_probed: ts.subqueries_probed, + subqueries_initial: ts.subqueries_initial, + estimated_results: ts.estimated_results, + subqueries_total: ts.subqueries_total, + subqueries_completed: ts.subqueries_completed, + ids_collected: ts.ids_collected, + pages_fetched: ts.pages_fetched, + fetching_done: ts.fetching_done, + details_fetched: ts.details_fetched, + images_downloaded: ts.images_downloaded, + ocr_completed: ts.ocr_completed, + failed: ts.failed, + elapsed_seconds: ts.elapsed_seconds, + rate_per_second: ts.rate_per_second, + eta_seconds: ts.eta_seconds, + logs: ts.logs, + }; +} + +function TaskTabBar({ + tasks, + selectedTaskId, + onSelectTask, +}: { + tasks: Record; + selectedTaskId: string | null; + onSelectTask: (taskId: string) => void; +}) { + // Sort: active first, then completed, then failed + const sortedEntries = useMemo(() => { + return Object.entries(tasks).sort(([, a], [, b]) => { + const aTerminal = isTerminalStatus(a.status); + const bTerminal = isTerminalStatus(b.status); + if (aTerminal !== bTerminal) return aTerminal ? 1 : -1; + if (aTerminal && bTerminal) { + if (a.status === 'SUCCESS' && b.status !== 'SUCCESS') return -1; + if (b.status === 'SUCCESS' && a.status !== 'SUCCESS') return 1; + } + return 0; + }); + }, [tasks]); + + if (sortedEntries.length <= 1) return null; + + return ( +
+ {sortedEntries.map(([tid, task]) => { + const type = inferTaskType(task); + const isSelected = tid === selectedTaskId; + const terminal = isTerminalStatus(task.status); + + return ( + + ); + })} +
+ ); +} + export function TaskProgressDrawer({ open, onOpenChange, @@ -275,42 +415,75 @@ export function TaskProgressDrawer({ taskID, onCancel, isCancelling, + wsTasks, + selectedTaskId, + onSelectTask, }: TaskProgressDrawerProps) { - const isInProgress = - taskStatus !== null && - taskStatus !== TaskStatus.SUCCESS && - taskStatus !== TaskStatus.FAILURE && - taskStatus !== TaskStatus.REVOKED; + // Determine which task's data to show + const hasMultipleTasks = wsTasks && Object.keys(wsTasks).length > 1; + const effectiveTaskId = selectedTaskId ?? taskID; - const progressPercent = taskResult - ? Math.min((taskResult.progress ?? 0) * 100, 100) + // Derive the active task data from wsTasks if available, else fall back to props + const activeWsTask = effectiveTaskId && wsTasks ? wsTasks[effectiveTaskId] : undefined; + const effectiveResult = activeWsTask ? taskStateToResult(activeWsTask) : taskResult; + const effectiveStatus = activeWsTask ? (activeWsTask.status as TaskStatus) : taskStatus; + const effectiveTaskType = activeWsTask + ? inferTaskType(activeWsTask) + : (taskResult ? inferTaskType(taskResult) : 'scrape'); + + const isInProgress = + effectiveStatus !== null && + effectiveStatus !== undefined && + !isTerminalStatus(effectiveStatus); + + const progressPercent = effectiveResult + ? Math.min((effectiveResult.progress ?? 0) * 100, 100) : 0; + const drawerTitle = hasMultipleTasks + ? 'Job Progress' + : `${taskTypeLabel(effectiveTaskType)} Job Progress`; + return (
- Crawl Job Progress - + {drawerTitle} +
- {taskID && ( + {effectiveTaskId && ( - Task ID: {taskID.slice(0, 8)}... + Task ID: {effectiveTaskId.slice(0, 8)}... )}
+ {/* Multi-job tab bar */} + {hasMultipleTasks && onSelectTask && ( + + )} + {/* Fixed top section: timeline + counters + progress */}
- {taskResult && } + {effectiveTaskType === 'poi' && activeWsTask && ( + + )} + {effectiveTaskType !== 'poi' && effectiveResult && ( + + )} - {taskResult && (taskResult.phase === 'processing' || taskResult.phase === 'fetching') && (taskResult.total ?? 0) > 0 && ( + {effectiveResult && (effectiveResult.phase === 'processing' || effectiveResult.phase === 'fetching' || effectiveResult.phase === 'computing') && (effectiveResult.total ?? 0) > 0 && (
- {taskResult.processed ?? 0} / {taskResult.total ?? '?'} + {effectiveResult.processed ?? 0} / {effectiveResult.total ?? '?'} - {formatEta(taskResult.eta_seconds)} + {formatEta(effectiveResult.eta_seconds)}
)} - {taskResult?.message && ( -

{taskResult.message}

+ {effectiveResult?.message && ( +

{effectiveResult.message}

)}
@@ -338,7 +511,7 @@ export function TaskProgressDrawer({ Worker Logs

- +
diff --git a/frontend/src/constants/index.ts b/frontend/src/constants/index.ts index 52ea8c5..dc503aa 100644 --- a/frontend/src/constants/index.ts +++ b/frontend/src/constants/index.ts @@ -59,5 +59,8 @@ export const DEFAULT_FORM_VALUES = { // Polling intervals export const POLLING_INTERVALS = { - TASK_STATUS_MS: 5000, // 5 seconds + TASK_STATUS_MS: 30000, // 30 seconds (fallback when WebSocket is unavailable) } as const; + +// WebSocket paths +export const WS_TASKS_PATH = '/api/ws/tasks'; diff --git a/frontend/src/hooks/useTaskWebSocket.ts b/frontend/src/hooks/useTaskWebSocket.ts new file mode 100644 index 0000000..f7c491b --- /dev/null +++ b/frontend/src/hooks/useTaskWebSocket.ts @@ -0,0 +1,129 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import type { AuthUser } from '@/auth/types'; +import type { TaskState, WSMessage } from '@/types'; +import { WS_TASKS_PATH } from '@/constants'; + +const KEEPALIVE_MS = 30_000; +const MAX_RECONNECT_DELAY_MS = 30_000; + +function wsUrl(token: string): string { + const proto = window.location.protocol === 'https:' ? 'wss' : 'ws'; + return `${proto}://${window.location.host}${WS_TASKS_PATH}?token=${encodeURIComponent(token)}`; +} + +export interface UseTaskWebSocketReturn { + tasks: Record; + isConnected: boolean; + subscribe: (taskId: string) => void; +} + +export function useTaskWebSocket(user: AuthUser | null): UseTaskWebSocketReturn { + const [tasks, setTasks] = useState>({}); + const [isConnected, setIsConnected] = useState(false); + const wsRef = useRef(null); + const reconnectAttempt = useRef(0); + const reconnectTimer = useRef | null>(null); + const keepaliveTimer = useRef | null>(null); + const mountedRef = useRef(true); + + const clearTimers = useCallback(() => { + if (reconnectTimer.current) { + clearTimeout(reconnectTimer.current); + reconnectTimer.current = null; + } + if (keepaliveTimer.current) { + clearInterval(keepaliveTimer.current); + keepaliveTimer.current = null; + } + }, []); + + const connect = useCallback(() => { + if (!user) return; + + const ws = new WebSocket(wsUrl(user.accessToken)); + wsRef.current = ws; + + ws.onopen = () => { + if (!mountedRef.current) return; + setIsConnected(true); + reconnectAttempt.current = 0; + + // Start keepalive pings + keepaliveTimer.current = setInterval(() => { + if (ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: 'ping' })); + } + }, KEEPALIVE_MS); + }; + + ws.onmessage = (event) => { + if (!mountedRef.current) return; + try { + const msg: WSMessage = JSON.parse(event.data); + + if (msg.type === 'init') { + const initial: Record = {}; + for (const t of msg.tasks) { + initial[t.task_id] = t; + } + setTasks(initial); + } else if (msg.type === 'task_update') { + const { type: _, ...taskData } = msg; + setTasks((prev) => ({ + ...prev, + [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState, + })); + } + // pong messages are ignored + } catch { + // Ignore malformed messages + } + }; + + ws.onclose = () => { + if (!mountedRef.current) return; + setIsConnected(false); + if (keepaliveTimer.current) { + clearInterval(keepaliveTimer.current); + keepaliveTimer.current = null; + } + + // Exponential backoff reconnect + const delay = Math.min( + 1000 * 2 ** reconnectAttempt.current, + MAX_RECONNECT_DELAY_MS, + ); + reconnectAttempt.current += 1; + reconnectTimer.current = setTimeout(() => { + if (mountedRef.current) connect(); + }, delay); + }; + + ws.onerror = () => { + // onclose will fire after this, triggering reconnect + }; + }, [user, clearTimers]); + + useEffect(() => { + mountedRef.current = true; + connect(); + + return () => { + mountedRef.current = false; + clearTimers(); + if (wsRef.current) { + wsRef.current.close(); + wsRef.current = null; + } + }; + }, [connect, clearTimers]); + + const subscribe = useCallback((taskId: string) => { + const ws = wsRef.current; + if (ws && ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId })); + } + }, []); + + return { tasks, isConnected, subscribe }; +} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index d5e5e7a..025e80e 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -120,3 +120,55 @@ export interface POITravelFilter { travelMode: 'WALK' | 'BICYCLE' | 'TRANSIT'; maxMinutes: number | undefined; } + +// WebSocket task state (combines status + result fields) +export interface TaskState { + task_id: string; + status: string; + progress?: number; + processed?: number; + total?: number; + phase?: TaskPhase; + message?: string; + // Splitting phase + subqueries_probed?: number; + subqueries_initial?: number; + estimated_results?: number; + subqueries_total?: number; + // Fetching phase + subqueries_completed?: number; + ids_collected?: number; + pages_fetched?: number; + fetching_done?: boolean; + // Processing phase + details_fetched?: number; + images_downloaded?: number; + ocr_completed?: number; + failed?: number; + elapsed_seconds?: number; + rate_per_second?: number; + eta_seconds?: number; + // POI-specific + distances_computed?: number; + // Live logs + logs?: string[]; +} + +// WebSocket message types +export interface WSInitMessage { + type: 'init'; + tasks: TaskState[]; +} + +export interface WSTaskUpdateMessage { + type: 'task_update'; + task_id: string; + status: string; + [key: string]: unknown; +} + +export interface WSPongMessage { + type: 'pong'; +} + +export type WSMessage = WSInitMessage | WSTaskUpdateMessage | WSPongMessage; diff --git a/services/task_progress_publisher.py b/services/task_progress_publisher.py new file mode 100644 index 0000000..52a25a3 --- /dev/null +++ b/services/task_progress_publisher.py @@ -0,0 +1,50 @@ +"""Publishes task progress updates to Redis pub/sub channels. + +Celery workers call publish_task_progress() alongside task.update_state() so +that the FastAPI WebSocket endpoint can forward real-time updates to connected +browsers without polling. + +Channel naming: ``task_progress:{task_id}`` +""" +import json +import logging +import os +from typing import Any + +import redis + +logger = logging.getLogger(__name__) + +_redis_client: redis.Redis | None = None # type: ignore[type-arg] + + +def _get_redis_client() -> redis.Redis: # type: ignore[type-arg] + """Lazy-initialised Redis client derived from CELERY_BROKER_URL.""" + global _redis_client + if _redis_client is None: + broker_url = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0") + _redis_client = redis.Redis.from_url(broker_url, decode_responses=True) + return _redis_client + + +def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> None: + """Publish a task progress update to Redis pub/sub. + + Args: + task_id: Celery task ID. + state: Celery state string (e.g. 'PROGRESS', 'SUCCESS'). + meta: Metadata dict (progress, phase, logs, counters, etc.). + + Failures are caught and logged at DEBUG level so they never break the + critical task execution path. + """ + try: + client = _get_redis_client() + payload = json.dumps({ + "task_id": task_id, + "status": state, + **meta, + }) + client.publish(f"task_progress:{task_id}", payload) + except Exception: + logger.debug("Failed to publish task progress for %s", task_id, exc_info=True) diff --git a/services/task_service.py b/services/task_service.py index bb8df71..e4b5369 100644 --- a/services/task_service.py +++ b/services/task_service.py @@ -177,11 +177,19 @@ def cancel_task(task_id: str, user_email: str | None = None) -> bool: """ # Lazy import: celery_app bootstraps the broker connection. from celery_app import app as celery_app + from services.task_progress_publisher import publish_task_progress logger.info("Cancelling task %s (user=%s)", task_id, user_email) # Revoke the task in Celery celery_app.control.revoke(task_id, terminate=True) + # Publish REVOKED status via pub/sub so WebSocket clients learn immediately + publish_task_progress(task_id, "REVOKED", { + "phase": "completed", + "progress": 0, + "message": "Task cancelled", + }) + # Also remove from user's task list if user_email provided if user_email: remove_task_from_user(user_email, task_id) @@ -222,6 +230,7 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int: # Lazy imports: see get_user_tasks and cancel_task for rationale. from redis_repository import RedisRepository from celery_app import app as celery_app + from services.task_progress_publisher import publish_task_progress redis_repo = RedisRepository.instance() user = _make_system_user(user_email) @@ -238,5 +247,11 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int: logger.warning( "Failed to revoke task %s: %s", task_id, e ) + # Publish REVOKED via pub/sub so WebSocket clients learn immediately + publish_task_progress(task_id, "REVOKED", { + "phase": "completed", + "progress": 0, + "message": "Task cancelled", + }) return redis_repo.clear_tasks_for_user(user) diff --git a/tasks/listing_tasks.py b/tasks/listing_tasks.py index 33c0628..8258165 100644 --- a/tasks/listing_tasks.py +++ b/tasks/listing_tasks.py @@ -19,6 +19,7 @@ from database import engine from services.query_splitter import QuerySplitter, SubQuery from utils.redis_lock import redis_lock from services.listing_cache import invalidate_cache +from services.task_progress_publisher import publish_task_progress logger = logging.getLogger("uvicorn.error") @@ -86,6 +87,8 @@ def _update_task_state(task: Task, state: str, meta: dict[str, Any]) -> None: if _active_log_buffer is not None: meta["logs"] = list(_active_log_buffer) task.update_state(state=state, meta=meta) + if hasattr(task, 'request') and task.request.id: + publish_task_progress(task.request.id, state, meta) async def _fetch_subquery( @@ -266,7 +269,9 @@ def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]: if not acquired: msg = "Another scrape job is already running, skipping this execution" celery_logger.warning(msg) - self.update_state(state="SKIPPED", meta={"reason": "Another scrape job is running"}) + meta = {"reason": "Another scrape job is running"} + self.update_state(state="SKIPPED", meta=meta) + publish_task_progress(self.request.id, "SKIPPED", meta) return {"status": "skipped", "reason": "another_job_running"} celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}") @@ -275,8 +280,11 @@ def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]: celery_logger.info(f"Starting scrape with parameters: {parsed_parameters}") self.update_state(state="Starting...", meta={"phase": PHASE_SPLITTING, "progress": 0}) + publish_task_progress(self.request.id, "Starting...", {"phase": PHASE_SPLITTING, "progress": 0}) asyncio.run(dump_listings_full(task=self, parameters=parsed_parameters)) - return {"phase": PHASE_COMPLETED, "progress": 1} + result = {"phase": PHASE_COMPLETED, "progress": 1} + publish_task_progress(self.request.id, "SUCCESS", result) + return result async def async_dump_listings_task(parameters_json: str) -> dict[str, Any]: diff --git a/tasks/poi_tasks.py b/tasks/poi_tasks.py index b822c94..635f31b 100644 --- a/tasks/poi_tasks.py +++ b/tasks/poi_tasks.py @@ -10,6 +10,7 @@ from models.listing import ListingType from repositories.listing_repository import ListingRepository from repositories.poi_repository import POIRepository from services.poi_distance_calculator import calculate_poi_distances +from services.task_progress_publisher import publish_task_progress logger = logging.getLogger(__name__) @@ -55,6 +56,11 @@ def calculate_poi_distances_task( "progress": 0, "message": "Starting distance calculation...", }) + publish_task_progress(self.request.id, "PROGRESS", { + "phase": "starting", + "progress": 0, + "message": "Starting distance calculation...", + }) listing_repo = ListingRepository(engine) poi_repo = POIRepository(engine) @@ -62,19 +68,23 @@ def calculate_poi_distances_task( poi = poi_repo.get_poi_by_id(poi_id) if poi is None: celery_logger.error(f"POI {poi_id} not found") - return {"error": f"POI {poi_id} not found", "distances_computed": 0} + error_result = {"error": f"POI {poi_id} not found", "distances_computed": 0} + publish_task_progress(self.request.id, "FAILURE", error_result) + return error_result lt = ListingType(listing_type) def on_progress(completed: int, total: int, message: str) -> None: progress = round(completed / total, 2) if total > 0 else 0 - self.update_state(state="PROGRESS", meta={ + meta = { "phase": "computing", "progress": progress, "processed": completed, "total": total, "message": message, - }) + } + self.update_state(state="PROGRESS", meta=meta) + publish_task_progress(self.request.id, "PROGRESS", meta) try: total = asyncio.run( @@ -96,9 +106,12 @@ def calculate_poi_distances_task( celery_logger.info(f"POI distance calculation complete: {total} distances computed") - return { + result = { "phase": "completed", "progress": 1, "distances_computed": total, "message": f"Computed {total} distances for POI '{poi.name}'", } + publish_task_progress(self.request.id, "SUCCESS", result) + + return result