Add real-time WebSocket task progress with multi-job drawer

Replace 5s HTTP polling with WebSocket-based real-time updates for task
progress. Celery workers publish progress to Redis pub/sub channels;
a FastAPI WebSocket endpoint subscribes and forwards to the browser.
Polling is kept as a 30s fallback when WebSocket is unavailable.

The task progress drawer now supports multiple concurrent jobs with a
tab bar for switching between scrape and POI distance tasks.

Backend:
- Add services/task_progress_publisher.py (Redis pub/sub bridge)
- Add api/ws_routes.py (WebSocket endpoint with JWT auth)
- Publish progress from listing_tasks and poi_tasks
- Publish REVOKED via pub/sub on cancel/clear to fix stuck UI

Frontend:
- Add useTaskWebSocket hook with reconnection and keepalive
- Add TaskState and WS message types
- TaskIndicator: WS-driven updates with polling fallback
- TaskProgressDrawer: multi-job tabs, POI phase timeline
- Guard against WS overwriting local cancel state
This commit is contained in:
Viktor Barzin 2026-02-09 21:31:45 +00:00
parent 73d19e29d5
commit 8559c4b461
No known key found for this signature in database
GPG key ID: 0EB088298288D958
11 changed files with 774 additions and 72 deletions

153
api/ws_routes.py Normal file
View file

@ -0,0 +1,153 @@
"""WebSocket endpoint for real-time task progress updates.
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
messages published by Celery workers via Redis pub/sub.
"""
import asyncio
import json
import logging
from typing import Any
import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that carries the WebSocket route declared in this module.
ws_router = APIRouter()
# Reuse the broker URL for the async Redis client
# NOTE(review): this import sits below executable code; by convention it
# belongs with the other imports at the top of the file.
import os
# Taken from the CELERY_BROKER_URL environment variable, with the literal
# default below used when the variable is unset.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
async def _authenticate_ws(token: str) -> User | None:
    """Resolve the user behind a WebSocket JWT, or return ``None`` if rejected.

    The ``iss`` claim is first read *without* signature or expiry checks,
    purely to choose a verifier: tokens issued by ``JWT_ISSUER`` go to
    ``_verify_passkey_token``, everything else to ``_verify_authentik_token``.
    The chosen verifier's result is returned unchanged.

    Args:
        token: Raw JWT taken from the connection's query string.

    Returns:
        The verified ``User``, or ``None`` when decoding or verification
        raises for any reason.
    """
    try:
        # Unverified decode: used only for routing, never trusted for identity.
        unverified = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        if unverified.get("iss", "") == JWT_ISSUER:
            return _verify_passkey_token(token)
        return await _verify_authentik_token(token)
    except Exception:
        # Every failure is treated as "not authenticated", but leave a trace
        # so a verifier/backend outage is distinguishable from a bad token.
        logger.debug("WebSocket token verification failed", exc_info=True)
        return None
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Return the current state of one task as a plain dict.

    The dict carries the core status fields read from
    ``task_service.get_task_status``. When the status object has a non-empty
    dict ``result``, its entries are merged on top (overriding core fields
    that share a key).
    """
    task = task_service.get_task_status(task_id)
    core_fields = ("task_id", "status", "progress", "processed", "total", "message")
    snapshot: dict[str, Any] = {name: getattr(task, name) for name in core_fields}
    extra = task.result
    if extra and isinstance(extra, dict):
        snapshot |= extra
    return snapshot
@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream live task progress to one authenticated WebSocket client.

    Protocol:
      * The client connects with ``?token=<jwt>``. A missing token closes the
        socket with code 4001, a rejected token with code 4003.
      * After the socket is accepted, one ``{"type": "init", "tasks": [...]}``
        message carries a snapshot of every task returned by
        ``task_service.get_user_tasks`` for the user.
      * Every Redis pub/sub message on ``task_progress:<task_id>`` is
        forwarded as ``{"type": "task_update", **payload}``.
      * The client may send ``{"type": "subscribe", "task_id": ...}`` to follow
        another of its tasks, and ``{"type": "ping"}`` which is answered with
        ``{"type": "pong"}``.

    The session ends as soon as either direction stops (client disconnect,
    send failure, Redis error); the other direction is cancelled and the
    Redis resources are released.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return
    user = await _authenticate_ws(token)
    if user is None:
        await websocket.close(code=4003, reason="Invalid token")
        return
    await websocket.accept()

    task_ids = task_service.get_user_tasks(user.email)
    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()

    async def _subscribe(task_id: str) -> None:
        """Subscribe to a task's progress channel once."""
        channel = f"task_progress:{task_id}"
        if channel not in subscribed_channels:
            await pubsub.subscribe(channel)
            subscribed_channels.add(channel)

    async def _forward_pubsub() -> None:
        """Read from Redis pub/sub and forward to the WebSocket."""
        while True:
            if not subscribed_channels:
                # NOTE(review): redis-py's asyncio PubSub appears to raise if
                # polled before the first subscribe(); idle until the client
                # subscribes to something instead of ending the session.
                await asyncio.sleep(1.0)
                continue
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if not message or message["type"] != "message":
                continue
            try:
                data = json.loads(message["data"])
            except (TypeError, ValueError):
                # One malformed payload must not stop all later updates.
                logger.debug("Dropping malformed progress payload", exc_info=True)
                continue
            if not isinstance(data, dict):
                continue
            # A send failure propagates and ends the session.
            await websocket.send_json({"type": "task_update", **data})

    async def _handle_subscribe(task_id: Any) -> None:
        """Subscribe the client to one of its own tasks and send its snapshot."""
        if not task_id or not isinstance(task_id, str):
            return
        # The task id comes from the client, so re-check ownership against the
        # same source used for the init snapshot before exposing its progress.
        if task_id not in task_service.get_user_tasks(user.email):
            logger.warning(
                "Rejected subscription to task %s: not in the user's task list",
                task_id,
            )
            return
        await _subscribe(task_id)
        try:
            snapshot = await _build_task_snapshot(task_id)
        except Exception:
            logger.debug(
                "Failed to build snapshot for task %s", task_id, exc_info=True
            )
            return
        await websocket.send_json({"type": "task_update", **snapshot})

    async def _handle_client_messages() -> None:
        """Read messages from the client (subscribe, ping)."""
        while True:
            try:
                msg = json.loads(await websocket.receive_text())
            except (KeyError, TypeError, ValueError):
                # Binary frame or invalid JSON: ignore this frame only.
                # WebSocketDisconnect and other receive errors propagate so
                # the loop cannot spin on a dead connection.
                continue
            if not isinstance(msg, dict):
                continue
            msg_type = msg.get("type")
            if msg_type == "subscribe":
                await _handle_subscribe(msg.get("task_id"))
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    workers: list[asyncio.Task[None]] = []
    try:
        # Subscribe before taking snapshots so no update published in between
        # is lost; Redis buffers messages until the forwarder starts.
        for tid in task_ids:
            await _subscribe(tid)
        snapshots = []
        for tid in task_ids:
            try:
                snapshots.append(await _build_task_snapshot(tid))
            except Exception:
                logger.debug(
                    "Failed to build snapshot for task %s", tid, exc_info=True
                )
        await websocket.send_json({"type": "init", "tasks": snapshots})

        workers = [
            asyncio.create_task(_forward_pubsub()),
            asyncio.create_task(_handle_client_messages()),
        ]
        # Stop when either direction finishes; gather() would leave the other
        # coroutine running against resources that are about to be closed.
        await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
    except Exception:
        logger.debug("WebSocket task stream ended with an error", exc_info=True)
    finally:
        for worker in workers:
            worker.cancel()
        # Await the workers so cancellation completes and their exceptions are
        # retrieved rather than reported as "never retrieved".
        for outcome in await asyncio.gather(*workers, return_exceptions=True):
            if isinstance(outcome, Exception) and not isinstance(
                outcome, WebSocketDisconnect
            ):
                logger.debug("WebSocket worker stopped on error", exc_info=outcome)
        try:
            if subscribed_channels:
                await pubsub.unsubscribe(*subscribed_channels)
            await pubsub.close()
        except Exception:
            logger.debug("Failed to release pub/sub cleanly", exc_info=True)
        finally:
            # Always release the client, even if unsubscribe/close failed.
            await redis_client.aclose()

View file

@ -1,4 +1,5 @@
import type { AuthUser } from '@/auth/types'; import type { AuthUser } from '@/auth/types';
import type { TaskState } from '@/types';
import { Button } from './ui/button'; import { Button } from './ui/button';
import { Separator } from './ui/separator'; import { Separator } from './ui/separator';
import { LogOut, Home, Filter } from 'lucide-react'; import { LogOut, Home, Filter } from 'lucide-react';
@ -16,6 +17,9 @@ interface HeaderProps {
showFilterToggle?: boolean; showFilterToggle?: boolean;
onTaskCancelled?: () => void; onTaskCancelled?: () => void;
onTaskCompleted?: () => void; onTaskCompleted?: () => void;
wsTasks?: Record<string, TaskState>;
wsConnected?: boolean;
wsSubscribe?: (taskId: string) => void;
} }
export function Header({ export function Header({
@ -26,6 +30,9 @@ export function Header({
showFilterToggle = false, showFilterToggle = false,
onTaskCancelled, onTaskCancelled,
onTaskCompleted, onTaskCompleted,
wsTasks,
wsConnected,
wsSubscribe,
}: HeaderProps) { }: HeaderProps) {
const handleLogout = async () => { const handleLogout = async () => {
if (user.provider === 'passkey') { if (user.provider === 'passkey') {
@ -50,7 +57,13 @@ export function Header({
<HealthIndicator /> <HealthIndicator />
{/* Task Indicator */} {/* Task Indicator */}
<TaskIndicator taskID={taskID ?? null} onTaskCancelled={onTaskCancelled} onTaskCompleted={onTaskCompleted} /> <TaskIndicator
taskID={taskID ?? null}
onTaskCancelled={onTaskCancelled}
onTaskCompleted={onTaskCompleted}
wsTasks={wsTasks}
wsConnected={wsConnected}
/>
{/* Filter Toggle (mobile) */} {/* Filter Toggle (mobile) */}
{showFilterToggle && ( {showFilterToggle && (

View file

@ -3,8 +3,8 @@ import { getStoredPasskeyUser } from '@/auth/passkeyService';
import { fromOidcUser, type AuthUser } from '@/auth/types'; import { fromOidcUser, type AuthUser } from '@/auth/types';
import { POLLING_INTERVALS } from '@/constants'; import { POLLING_INTERVALS } from '@/constants';
import { fetchTaskStatus, cancelTask, clearAllTasks } from '@/services'; import { fetchTaskStatus, cancelTask, clearAllTasks } from '@/services';
import { TaskStatus, type TaskResult } from '@/types'; import { TaskStatus, type TaskResult, type TaskState } from '@/types';
import { useEffect, useState, useRef } from 'react'; import { useEffect, useState, useRef, useMemo } from 'react';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './ui/tooltip'; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './ui/tooltip';
import { Button } from './ui/button'; import { Button } from './ui/button';
import { Loader2, CheckCircle2, XCircle, X, Trash2 } from 'lucide-react'; import { Loader2, CheckCircle2, XCircle, X, Trash2 } from 'lucide-react';
@ -14,9 +14,48 @@ interface TaskIndicatorProps {
taskID: string | null; taskID: string | null;
onTaskCancelled?: () => void; onTaskCancelled?: () => void;
onTaskCompleted?: () => void; onTaskCompleted?: () => void;
wsTasks?: Record<string, TaskState>;
wsConnected?: boolean;
} }
export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: TaskIndicatorProps) { /** Convert a TaskState (from WS) into a TaskResult (for the drawer). */
function taskStateToResult(ts: TaskState): TaskResult {
  // Field-by-field copy. Every listed key is always present on the result,
  // holding `undefined` when the source TaskState lacks it; only `progress`
  // gets a default (0 when null/undefined).
  return {
    progress: ts.progress ?? 0,
    processed: ts.processed,
    total: ts.total,
    phase: ts.phase,
    message: ts.message,
    subqueries_probed: ts.subqueries_probed,
    subqueries_initial: ts.subqueries_initial,
    estimated_results: ts.estimated_results,
    subqueries_total: ts.subqueries_total,
    subqueries_completed: ts.subqueries_completed,
    ids_collected: ts.ids_collected,
    pages_fetched: ts.pages_fetched,
    fetching_done: ts.fetching_done,
    details_fetched: ts.details_fetched,
    images_downloaded: ts.images_downloaded,
    ocr_completed: ts.ocr_completed,
    failed: ts.failed,
    elapsed_seconds: ts.elapsed_seconds,
    rate_per_second: ts.rate_per_second,
    eta_seconds: ts.eta_seconds,
    logs: ts.logs,
  };
}
/** True when a task status means the task will not progress any further. */
function isTerminalStatus(status: string): boolean {
  const finished = ['SUCCESS', 'FAILURE', 'REVOKED'];
  return finished.includes(status);
}
export function TaskIndicator({
taskID,
onTaskCancelled,
onTaskCompleted,
wsTasks,
wsConnected,
}: TaskIndicatorProps) {
const [user, setUser] = useState<AuthUser | null>(null); const [user, setUser] = useState<AuthUser | null>(null);
const [progressPercentage, setProgressPercentage] = useState<number>(0); const [progressPercentage, setProgressPercentage] = useState<number>(0);
const [processed, setProcessed] = useState<number | null>(null); const [processed, setProcessed] = useState<number | null>(null);
@ -26,6 +65,11 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
const [isCancelling, setIsCancelling] = useState(false); const [isCancelling, setIsCancelling] = useState(false);
const [isClearing, setIsClearing] = useState(false); const [isClearing, setIsClearing] = useState(false);
const [drawerOpen, setDrawerOpen] = useState(false); const [drawerOpen, setDrawerOpen] = useState(false);
const [selectedTaskId, setSelectedTaskId] = useState<string | null>(null);
// Prevents WS effect from overwriting local cancel/clear state before
// the parent's setTaskID(null) propagates. Reset when taskID changes.
const cancelledRef = useRef(false);
const onTaskCompletedRef = useRef(onTaskCompleted); const onTaskCompletedRef = useRef(onTaskCompleted);
useEffect(() => { useEffect(() => {
@ -43,7 +87,58 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
} }
}, []); }, []);
// Track the currently-viewed task in the drawer; default to the externally-provided taskID
useEffect(() => { useEffect(() => {
if (taskID) {
setSelectedTaskId(taskID);
cancelledRef.current = false; // new task, reset cancelled guard
}
}, [taskID]);
// Count active (non-terminal) tasks from WS
const activeWsTaskCount = useMemo(() => {
if (!wsTasks) return 0;
return Object.values(wsTasks).filter(
(t) => !isTerminalStatus(t.status),
).length;
}, [wsTasks]);
// ----- WebSocket-driven state updates -----
// When wsConnected, derive taskStatus/taskResult/progress from wsTasks
useEffect(() => {
if (!wsConnected || !wsTasks || !taskID) return;
// Don't let WS overwrite local cancel/clear state
if (cancelledRef.current) return;
const wsTask = wsTasks[taskID];
if (!wsTask) return;
const status = wsTask.status as TaskStatus;
setTaskStatus(status);
if (wsTask.phase) {
setTaskResult(taskStateToResult(wsTask));
}
if (wsTask.progress !== undefined) {
setProgressPercentage(wsTask.progress * 100);
}
if (wsTask.processed !== undefined) {
setProcessed(wsTask.processed);
}
if (wsTask.total !== undefined) {
setTotal(wsTask.total);
}
if (status === TaskStatus.SUCCESS) {
setProgressPercentage(100);
onTaskCompletedRef.current?.();
}
}, [wsConnected, wsTasks, taskID]);
// ----- Polling fallback (only when WS is not connected) -----
useEffect(() => {
// If WS is connected, skip polling
if (wsConnected) return;
if (!user || !taskID) { if (!user || !taskID) {
setTaskStatus(null); setTaskStatus(null);
setTaskResult(null); setTaskResult(null);
@ -65,10 +160,6 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
if (status === TaskStatus.SUCCESS) { if (status === TaskStatus.SUCCESS) {
setProgressPercentage(100); setProgressPercentage(100);
// Parse final result for the drawer to show completed state.
// Only update taskResult if the new result has phase info;
// otherwise keep the last in-progress result which has richer data
// than the bare SUCCESS return value.
if (data.result) { if (data.result) {
try { try {
const parsedResult: TaskResult = JSON.parse(data.result); const parsedResult: TaskResult = JSON.parse(data.result);
@ -80,26 +171,19 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
} }
} }
onTaskCompletedRef.current?.(); onTaskCompletedRef.current?.();
return true; // Stop polling return true;
} }
if (status === TaskStatus.FAILURE || status === TaskStatus.REVOKED) { if (status === TaskStatus.FAILURE || status === TaskStatus.REVOKED) {
return true; // Stop polling return true;
} }
// Parse progress for in-progress tasks
if (data.result) { if (data.result) {
try { try {
const parsedResult: TaskResult = JSON.parse(data.result); const parsedResult: TaskResult = JSON.parse(data.result);
// Only update taskResult if the parsed data has a phase field.
// This prevents blanking the drawer when the backend sends a
// state update without phase info (e.g. during brief transitions).
if (parsedResult.phase) { if (parsedResult.phase) {
setTaskResult(parsedResult); setTaskResult(parsedResult);
} }
// Only update progress/processed/total when the fields are
// actually present — otherwise keep the previous values so
// the UI doesn't flash back to 0 during phase transitions.
if (parsedResult.progress !== undefined) { if (parsedResult.progress !== undefined) {
setProgressPercentage(parsedResult.progress * 100); setProgressPercentage(parsedResult.progress * 100);
} }
@ -113,14 +197,13 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
// Ignore parsing errors // Ignore parsing errors
} }
} }
return false; // Continue polling return false;
} catch { } catch {
setTaskStatus(TaskStatus.FAILURE); setTaskStatus(TaskStatus.FAILURE);
return true; // Stop polling on error return true;
} }
}; };
// Initial poll
pollTaskStatus(); pollTaskStatus();
const interval = setInterval(async () => { const interval = setInterval(async () => {
@ -131,7 +214,7 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
}, POLLING_INTERVALS.TASK_STATUS_MS); }, POLLING_INTERVALS.TASK_STATUS_MS);
return () => clearInterval(interval); return () => clearInterval(interval);
}, [taskID, user]); }, [taskID, user, wsConnected]);
const handleCancel = async () => { const handleCancel = async () => {
if (!user || !taskID || isCancelling) return; if (!user || !taskID || isCancelling) return;
@ -140,6 +223,7 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
try { try {
const result = await cancelTask(user, taskID); const result = await cancelTask(user, taskID);
if (result.success) { if (result.success) {
cancelledRef.current = true;
setTaskStatus(TaskStatus.REVOKED); setTaskStatus(TaskStatus.REVOKED);
onTaskCancelled?.(); onTaskCancelled?.();
} }
@ -157,6 +241,7 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
try { try {
const result = await clearAllTasks(user); const result = await clearAllTasks(user);
if (result.success) { if (result.success) {
cancelledRef.current = true;
setTaskStatus(null); setTaskStatus(null);
setTaskResult(null); setTaskResult(null);
onTaskCancelled?.(); onTaskCancelled?.();
@ -245,6 +330,11 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
{taskStatus} {taskStatus}
</span> </span>
)} )}
{activeWsTaskCount > 1 && (
<span className="inline-flex items-center justify-center h-4 min-w-[16px] rounded-full bg-blue-500 text-[10px] font-medium text-white px-1">
{activeWsTaskCount}
</span>
)}
</div> </div>
</TooltipTrigger> </TooltipTrigger>
<TooltipContent side="bottom"> <TooltipContent side="bottom">
@ -292,9 +382,12 @@ export function TaskIndicator({ taskID, onTaskCancelled, onTaskCompleted }: Task
onOpenChange={setDrawerOpen} onOpenChange={setDrawerOpen}
taskResult={taskResult} taskResult={taskResult}
taskStatus={taskStatus} taskStatus={taskStatus}
taskID={taskID} taskID={selectedTaskId ?? taskID}
onCancel={handleCancel} onCancel={handleCancel}
isCancelling={isCancelling} isCancelling={isCancelling}
wsTasks={wsTasks}
selectedTaskId={selectedTaskId}
onSelectTask={setSelectedTaskId}
/> />
</TooltipProvider> </TooltipProvider>
); );

View file

@ -1,4 +1,4 @@
import { TaskStatus, type TaskPhase, type TaskResult } from '@/types'; import { TaskStatus, type TaskPhase, type TaskResult, type TaskState } from '@/types';
import { import {
Sheet, Sheet,
SheetContent, SheetContent,
@ -8,8 +8,8 @@ import {
SheetFooter, SheetFooter,
} from './ui/sheet'; } from './ui/sheet';
import { Button } from './ui/button'; import { Button } from './ui/button';
import { CheckCircle2, Circle, Loader2, XCircle } from 'lucide-react'; import { CheckCircle2, Circle, Loader2, XCircle, MapPin, Search } from 'lucide-react';
import { useEffect, useRef } from 'react'; import { useEffect, useRef, useMemo } from 'react';
interface TaskProgressDrawerProps { interface TaskProgressDrawerProps {
open: boolean; open: boolean;
@ -19,19 +19,62 @@ interface TaskProgressDrawerProps {
taskID: string | null; taskID: string | null;
onCancel: () => void; onCancel: () => void;
isCancelling: boolean; isCancelling: boolean;
wsTasks?: Record<string, TaskState>;
selectedTaskId?: string | null;
onSelectTask?: (taskId: string) => void;
} }
const PHASES: { key: TaskPhase; label: string }[] = [ const SCRAPE_PHASES: { key: TaskPhase; label: string }[] = [
{ key: 'splitting', label: 'Splitting queries' }, { key: 'splitting', label: 'Splitting queries' },
{ key: 'fetching', label: 'Fetching & processing' }, { key: 'fetching', label: 'Fetching & processing' },
{ key: 'processing', label: 'Processing remaining' }, { key: 'processing', label: 'Processing remaining' },
]; ];
const POI_PHASES: { key: string; label: string }[] = [
{ key: 'starting', label: 'Starting' },
{ key: 'computing', label: 'Computing distances' },
{ key: 'completed', label: 'Completed' },
];
/**
 * Guess which kind of job a task payload belongs to.
 *
 * Counter fields are checked first, then the phase name. A counter only
 * counts when it actually holds a value: objects produced by
 * `taskStateToResult` carry every key (set to `undefined` when absent), so a
 * bare `'key' in task` test would classify every converted result as a scrape.
 *
 * @returns 'poi', 'scrape', or 'task' when nothing identifies the job type.
 */
function inferTaskType(task: TaskState | TaskResult): 'scrape' | 'poi' | 'task' {
  if ('distances_computed' in task && task.distances_computed !== undefined) return 'poi';
  const hasScrapeCounter =
    ('subqueries_completed' in task && task.subqueries_completed !== undefined) ||
    ('ids_collected' in task && task.ids_collected !== undefined) ||
    ('pages_fetched' in task && task.pages_fetched !== undefined);
  if (hasScrapeCounter) return 'scrape';
  const phase = task.phase;
  if (phase === 'starting' || phase === 'computing') return 'poi';
  if (phase === 'splitting' || phase === 'splitting_complete' || phase === 'fetching' || phase === 'processing') return 'scrape';
  return 'task';
}
/** Human-readable name for a task type; anything unrecognised reads "Task". */
function taskTypeLabel(type: 'scrape' | 'poi' | 'task'): string {
  if (type === 'scrape') return 'Scrape';
  if (type === 'poi') return 'POI Distances';
  return 'Task';
}
/** Small icon for a task type; anything unrecognised gets a plain circle. */
function taskTypeIcon(type: 'scrape' | 'poi' | 'task') {
  if (type === 'scrape') {
    return <Search className="h-3 w-3" />;
  }
  if (type === 'poi') {
    return <MapPin className="h-3 w-3" />;
  }
  return <Circle className="h-3 w-3" />;
}
/** True for statuses after which a task no longer changes. */
function isTerminalStatus(status: string): boolean {
  switch (status) {
    case 'SUCCESS':
    case 'FAILURE':
    case 'REVOKED':
      return true;
    default:
      return false;
  }
}
function getPhaseIndex(phase: TaskPhase | undefined): number { function getPhaseIndex(phase: TaskPhase | undefined): number {
if (!phase) return -1; if (!phase) return -1;
if (phase === 'splitting_complete') return 1; // splitting done, fetching is next if (phase === 'splitting_complete') return 1;
if (phase === 'completed') return PHASES.length; if (phase === 'completed') return SCRAPE_PHASES.length;
return PHASES.findIndex((p) => p.key === phase); return SCRAPE_PHASES.findIndex((p) => p.key === phase);
}
/** Position of a phase within POI_PHASES, or -1 when missing or unknown. */
function getPoiPhaseIndex(phase: string | undefined): number {
  return phase ? POI_PHASES.findIndex(({ key }) => key === phase) : -1;
}
function formatEta(seconds: number | undefined): string { function formatEta(seconds: number | undefined): string {
@ -44,13 +87,10 @@ function formatEta(seconds: number | undefined): string {
return `~${secs}s remaining`; return `~${secs}s remaining`;
} }
function StatusBadge({ status }: { status: TaskStatus | null }) { function StatusBadge({ status }: { status: TaskStatus | string | null }) {
if (!status) return null; if (!status) return null;
const isInProgress = const isInProgress = !isTerminalStatus(status);
status !== TaskStatus.SUCCESS &&
status !== TaskStatus.FAILURE &&
status !== TaskStatus.REVOKED;
if (isInProgress) { if (isInProgress) {
return ( return (
@ -60,7 +100,7 @@ function StatusBadge({ status }: { status: TaskStatus | null }) {
</span> </span>
); );
} }
if (status === TaskStatus.SUCCESS) { if (status === 'SUCCESS') {
return ( return (
<span className="inline-flex items-center gap-1 rounded-full bg-green-100 px-2 py-0.5 text-xs font-medium text-green-700"> <span className="inline-flex items-center gap-1 rounded-full bg-green-100 px-2 py-0.5 text-xs font-medium text-green-700">
<CheckCircle2 className="h-3 w-3" /> <CheckCircle2 className="h-3 w-3" />
@ -68,7 +108,7 @@ function StatusBadge({ status }: { status: TaskStatus | null }) {
</span> </span>
); );
} }
if (status === TaskStatus.REVOKED) { if (status === 'REVOKED') {
return ( return (
<span className="inline-flex items-center gap-1 rounded-full bg-yellow-100 px-2 py-0.5 text-xs font-medium text-yellow-700"> <span className="inline-flex items-center gap-1 rounded-full bg-yellow-100 px-2 py-0.5 text-xs font-medium text-yellow-700">
<XCircle className="h-3 w-3" /> <XCircle className="h-3 w-3" />
@ -87,21 +127,23 @@ function StatusBadge({ status }: { status: TaskStatus | null }) {
function PhaseTimeline({ function PhaseTimeline({
currentPhase, currentPhase,
taskStatus, taskStatus,
taskType,
}: { }: {
currentPhase: TaskPhase | undefined; currentPhase: TaskPhase | string | undefined;
taskStatus: TaskStatus | null; taskStatus: TaskStatus | string | null;
taskType: 'scrape' | 'poi' | 'task';
}) { }) {
const isTerminal = const phases = taskType === 'poi' ? POI_PHASES : SCRAPE_PHASES;
taskStatus === TaskStatus.SUCCESS || const terminal = taskStatus !== null && isTerminalStatus(taskStatus);
taskStatus === TaskStatus.FAILURE || const activeIdx = taskType === 'poi'
taskStatus === TaskStatus.REVOKED; ? (terminal ? phases.length : getPoiPhaseIndex(currentPhase))
const activeIdx = isTerminal ? PHASES.length : getPhaseIndex(currentPhase); : (terminal ? phases.length : getPhaseIndex(currentPhase as TaskPhase | undefined));
return ( return (
<div className="flex flex-col gap-1"> <div className="flex flex-col gap-1">
{PHASES.map((phase, idx) => { {phases.map((phase, idx) => {
const isCompleted = idx < activeIdx; const isCompleted = idx < activeIdx;
const isActive = idx === activeIdx && !isTerminal; const isActive = idx === activeIdx && !terminal;
const isFuture = idx > activeIdx; const isFuture = idx > activeIdx;
return ( return (
@ -231,6 +273,20 @@ function PhaseDetails({ result }: { result: TaskResult }) {
return null; return null;
} }
/** Counter panel for a POI distance task: processed/total plus distances when reported. */
function POIPhaseDetails({ task }: { task: TaskState }) {
  const { processed, total, distances_computed: distancesComputed } = task;

  return (
    <div className="rounded-md border p-3 space-y-1">
      <p className="text-xs font-medium text-muted-foreground uppercase tracking-wide mb-2">
        POI Distances
      </p>
      <CounterRow label="Processed" value={processed} total={total} />
      {distancesComputed !== undefined && (
        <CounterRow label="Distances computed" value={distancesComputed} />
      )}
    </div>
  );
}
function LogViewer({ logs }: { logs: string[] }) { function LogViewer({ logs }: { logs: string[] }) {
const scrollRef = useRef<HTMLDivElement>(null); const scrollRef = useRef<HTMLDivElement>(null);
const isAutoScrolling = useRef(true); const isAutoScrolling = useRef(true);
@ -267,6 +323,90 @@ function LogViewer({ logs }: { logs: string[] }) {
); );
} }
/** Convert TaskState → TaskResult for existing phase detail components. */
function taskStateToResult(ts: TaskState): TaskResult {
  // One-to-one field copy; `progress` alone is defaulted (0 when
  // null/undefined). All other keys are present on the result even when the
  // source value is undefined.
  return {
    progress: ts.progress ?? 0,
    processed: ts.processed,
    total: ts.total,
    phase: ts.phase,
    message: ts.message,
    subqueries_probed: ts.subqueries_probed,
    subqueries_initial: ts.subqueries_initial,
    estimated_results: ts.estimated_results,
    subqueries_total: ts.subqueries_total,
    subqueries_completed: ts.subqueries_completed,
    ids_collected: ts.ids_collected,
    pages_fetched: ts.pages_fetched,
    fetching_done: ts.fetching_done,
    details_fetched: ts.details_fetched,
    images_downloaded: ts.images_downloaded,
    ocr_completed: ts.ocr_completed,
    failed: ts.failed,
    elapsed_seconds: ts.elapsed_seconds,
    rate_per_second: ts.rate_per_second,
    eta_seconds: ts.eta_seconds,
    logs: ts.logs,
  };
}
/**
 * Row of tabs, one per known task, for switching the job shown in the drawer.
 * Renders nothing when there is at most one task.
 */
function TaskTabBar({
  tasks,
  selectedTaskId,
  onSelectTask,
}: {
  tasks: Record<string, TaskState>;
  selectedTaskId: string | null;
  onSelectTask: (taskId: string) => void;
}) {
  // Order tabs: running jobs first, then successful ones, then failed/revoked.
  const orderedTabs = useMemo(() => {
    const rank = (state: TaskState): number => {
      if (!isTerminalStatus(state.status)) return 0;
      return state.status === 'SUCCESS' ? 1 : 2;
    };
    return Object.entries(tasks).sort(([, left], [, right]) => rank(left) - rank(right));
  }, [tasks]);

  if (orderedTabs.length <= 1) {
    return null;
  }

  return (
    <div className="flex gap-1 overflow-x-auto px-4 pb-2 scrollbar-thin">
      {orderedTabs.map(([taskId, state]) => {
        const kind = inferTaskType(state);
        const finished = isTerminalStatus(state.status);
        const tone =
          taskId === selectedTaskId
            ? 'bg-primary text-primary-foreground'
            : 'bg-muted text-muted-foreground hover:bg-muted/80';
        const endedBadly = state.status === 'FAILURE' || state.status === 'REVOKED';

        return (
          <button
            key={taskId}
            onClick={() => onSelectTask(taskId)}
            className={`flex items-center gap-1.5 px-2.5 py-1.5 rounded-md text-xs font-medium whitespace-nowrap transition-colors shrink-0 ${tone}`}
          >
            {!finished && <Loader2 className="h-3 w-3 animate-spin" />}
            {state.status === 'SUCCESS' && <CheckCircle2 className="h-3 w-3 text-green-500" />}
            {endedBadly && <XCircle className="h-3 w-3 text-red-500" />}
            {taskTypeIcon(kind)}
            <span>{taskTypeLabel(kind)}</span>
            {!finished && state.progress !== undefined && (
              <span className="opacity-70">{Math.round(state.progress * 100)}%</span>
            )}
          </button>
        );
      })}
    </div>
  );
}
export function TaskProgressDrawer({ export function TaskProgressDrawer({
open, open,
onOpenChange, onOpenChange,
@ -275,42 +415,75 @@ export function TaskProgressDrawer({
taskID, taskID,
onCancel, onCancel,
isCancelling, isCancelling,
wsTasks,
selectedTaskId,
onSelectTask,
}: TaskProgressDrawerProps) { }: TaskProgressDrawerProps) {
const isInProgress = // Determine which task's data to show
taskStatus !== null && const hasMultipleTasks = wsTasks && Object.keys(wsTasks).length > 1;
taskStatus !== TaskStatus.SUCCESS && const effectiveTaskId = selectedTaskId ?? taskID;
taskStatus !== TaskStatus.FAILURE &&
taskStatus !== TaskStatus.REVOKED;
const progressPercent = taskResult // Derive the active task data from wsTasks if available, else fall back to props
? Math.min((taskResult.progress ?? 0) * 100, 100) const activeWsTask = effectiveTaskId && wsTasks ? wsTasks[effectiveTaskId] : undefined;
const effectiveResult = activeWsTask ? taskStateToResult(activeWsTask) : taskResult;
const effectiveStatus = activeWsTask ? (activeWsTask.status as TaskStatus) : taskStatus;
const effectiveTaskType = activeWsTask
? inferTaskType(activeWsTask)
: (taskResult ? inferTaskType(taskResult) : 'scrape');
const isInProgress =
effectiveStatus !== null &&
effectiveStatus !== undefined &&
!isTerminalStatus(effectiveStatus);
const progressPercent = effectiveResult
? Math.min((effectiveResult.progress ?? 0) * 100, 100)
: 0; : 0;
const drawerTitle = hasMultipleTasks
? 'Job Progress'
: `${taskTypeLabel(effectiveTaskType)} Job Progress`;
return ( return (
<Sheet open={open} onOpenChange={onOpenChange}> <Sheet open={open} onOpenChange={onOpenChange}>
<SheetContent side="right" className="flex flex-col w-full sm:!max-w-lg"> <SheetContent side="right" className="flex flex-col w-full sm:!max-w-lg">
<SheetHeader> <SheetHeader>
<div className="flex items-center justify-between pr-6"> <div className="flex items-center justify-between pr-6">
<SheetTitle>Crawl Job Progress</SheetTitle> <SheetTitle>{drawerTitle}</SheetTitle>
<StatusBadge status={taskStatus} /> <StatusBadge status={effectiveStatus} />
</div> </div>
{taskID && ( {effectiveTaskId && (
<SheetDescription> <SheetDescription>
Task ID: {taskID.slice(0, 8)}... Task ID: {effectiveTaskId.slice(0, 8)}...
</SheetDescription> </SheetDescription>
)} )}
</SheetHeader> </SheetHeader>
{/* Multi-job tab bar */}
{hasMultipleTasks && onSelectTask && (
<TaskTabBar
tasks={wsTasks!}
selectedTaskId={effectiveTaskId}
onSelectTask={onSelectTask}
/>
)}
{/* Fixed top section: timeline + counters + progress */} {/* Fixed top section: timeline + counters + progress */}
<div className="space-y-3 px-4 shrink-0"> <div className="space-y-3 px-4 shrink-0">
<PhaseTimeline <PhaseTimeline
currentPhase={taskResult?.phase} currentPhase={effectiveResult?.phase ?? activeWsTask?.phase}
taskStatus={taskStatus} taskStatus={effectiveStatus}
taskType={effectiveTaskType}
/> />
{taskResult && <PhaseDetails result={taskResult} />} {effectiveTaskType === 'poi' && activeWsTask && (
<POIPhaseDetails task={activeWsTask} />
)}
{effectiveTaskType !== 'poi' && effectiveResult && (
<PhaseDetails result={effectiveResult} />
)}
{taskResult && (taskResult.phase === 'processing' || taskResult.phase === 'fetching') && (taskResult.total ?? 0) > 0 && ( {effectiveResult && (effectiveResult.phase === 'processing' || effectiveResult.phase === 'fetching' || effectiveResult.phase === 'computing') && (effectiveResult.total ?? 0) > 0 && (
<div className="space-y-1"> <div className="space-y-1">
<div className="w-full h-2 bg-primary/20 rounded-full overflow-hidden"> <div className="w-full h-2 bg-primary/20 rounded-full overflow-hidden">
<div <div
@ -320,15 +493,15 @@ export function TaskProgressDrawer({
</div> </div>
<div className="flex justify-between text-xs text-muted-foreground"> <div className="flex justify-between text-xs text-muted-foreground">
<span> <span>
{taskResult.processed ?? 0} / {taskResult.total ?? '?'} {effectiveResult.processed ?? 0} / {effectiveResult.total ?? '?'}
</span> </span>
<span>{formatEta(taskResult.eta_seconds)}</span> <span>{formatEta(effectiveResult.eta_seconds)}</span>
</div> </div>
</div> </div>
)} )}
{taskResult?.message && ( {effectiveResult?.message && (
<p className="text-sm text-muted-foreground">{taskResult.message}</p> <p className="text-sm text-muted-foreground">{effectiveResult.message}</p>
)} )}
</div> </div>
@ -338,7 +511,7 @@ export function TaskProgressDrawer({
Worker Logs Worker Logs
</p> </p>
<div className="flex-1 min-h-0"> <div className="flex-1 min-h-0">
<LogViewer logs={taskResult?.logs ?? []} /> <LogViewer logs={effectiveResult?.logs ?? []} />
</div> </div>
</div> </div>

View file

@ -59,5 +59,8 @@ export const DEFAULT_FORM_VALUES = {
// Polling intervals // Polling intervals
export const POLLING_INTERVALS = { export const POLLING_INTERVALS = {
TASK_STATUS_MS: 5000, // 5 seconds TASK_STATUS_MS: 30000, // 30 seconds (fallback when WebSocket is unavailable)
} as const; } as const;
// WebSocket paths
export const WS_TASKS_PATH = '/api/ws/tasks';

View file

@ -0,0 +1,129 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import type { AuthUser } from '@/auth/types';
import type { TaskState, WSMessage } from '@/types';
import { WS_TASKS_PATH } from '@/constants';
const KEEPALIVE_MS = 30_000;
const MAX_RECONNECT_DELAY_MS = 30_000;
/**
 * Build the absolute WebSocket URL for the task-progress endpoint.
 *
 * The scheme follows the page: `wss` when served over HTTPS, `ws` otherwise.
 * The token is URL-encoded and passed as the `token` query parameter.
 */
function wsUrl(token: string): string {
  const { protocol, host } = window.location;
  const scheme = protocol === 'https:' ? 'wss' : 'ws';
  const query = `token=${encodeURIComponent(token)}`;
  return `${scheme}://${host}${WS_TASKS_PATH}?${query}`;
}
/** Values exposed by the `useTaskWebSocket` hook. */
export interface UseTaskWebSocketReturn {
// Latest known state of each task, keyed by `task_id`.
tasks: Record<string, TaskState>;
// True between the socket's `open` and `close` events.
isConnected: boolean;
// Sends a `{ type: 'subscribe', task_id }` message; a no-op while the socket is not open.
subscribe: (taskId: string) => void;
}
/**
 * Keep a WebSocket open to the task-progress endpoint and expose the latest
 * state of every task it reports.
 *
 * Behaviour:
 * - Connects only while `user` is non-null, authenticating with
 *   `user.accessToken`.
 * - An `init` message replaces the whole task map; a `task_update` message is
 *   merged into the existing entry for its `task_id`.
 * - While open, sends a `{ type: 'ping' }` every `KEEPALIVE_MS`.
 * - After an unexpected close, reconnects with exponential backoff
 *   (1s, 2s, 4s, ... capped at `MAX_RECONNECT_DELAY_MS`).
 *
 * Events from a socket that is no longer the current one (closed by the
 * effect cleanup after a user change, logout or remount) are ignored, so a
 * late `close` event can never trigger a reconnect with a stale token or
 * tear down the timers of its replacement.
 *
 * @param user - The authenticated user, or `null` when logged out.
 * @returns The task map, the connection flag and a `subscribe` function.
 */
export function useTaskWebSocket(user: AuthUser | null): UseTaskWebSocketReturn {
  const [tasks, setTasks] = useState<Record<string, TaskState>>({});
  const [isConnected, setIsConnected] = useState(false);
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectAttempt = useRef(0);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const keepaliveTimer = useRef<ReturnType<typeof setInterval> | null>(null);
  const mountedRef = useRef(true);

  const clearTimers = useCallback(() => {
    if (reconnectTimer.current) {
      clearTimeout(reconnectTimer.current);
      reconnectTimer.current = null;
    }
    if (keepaliveTimer.current) {
      clearInterval(keepaliveTimer.current);
      keepaliveTimer.current = null;
    }
  }, []);

  const connect = useCallback(() => {
    // Whatever socket existed before is gone at this point (closed by the
    // effect cleanup or by the server), so never report a stale "connected".
    setIsConnected(false);
    if (!user) return;

    const ws = new WebSocket(wsUrl(user.accessToken));
    wsRef.current = ws;

    // A socket detached by the cleanup still delivers its `close` event
    // asynchronously; only the socket currently held in wsRef may touch state.
    const isCurrent = () => mountedRef.current && wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectAttempt.current = 0;
      // Start keepalive pings (dropping any leftover interval first).
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
      }
      keepaliveTimer.current = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
          ws.send(JSON.stringify({ type: 'ping' }));
        }
      }, KEEPALIVE_MS);
    };

    ws.onmessage = (event) => {
      if (!isCurrent()) return;
      try {
        const msg: WSMessage = JSON.parse(event.data);
        if (msg.type === 'init') {
          const initial: Record<string, TaskState> = {};
          for (const t of msg.tasks) {
            initial[t.task_id] = t;
          }
          setTasks(initial);
        } else if (msg.type === 'task_update') {
          // Strip the envelope `type` and merge the rest into the task entry.
          const { type: _, ...taskData } = msg;
          setTasks((prev) => ({
            ...prev,
            [msg.task_id]: { ...prev[msg.task_id], ...taskData } as TaskState,
          }));
        }
        // pong messages are ignored
      } catch {
        // Ignore malformed messages
      }
    };

    ws.onclose = () => {
      if (!isCurrent()) return;
      setIsConnected(false);
      if (keepaliveTimer.current) {
        clearInterval(keepaliveTimer.current);
        keepaliveTimer.current = null;
      }
      // Exponential backoff reconnect
      const delay = Math.min(
        1000 * 2 ** reconnectAttempt.current,
        MAX_RECONNECT_DELAY_MS,
      );
      reconnectAttempt.current += 1;
      reconnectTimer.current = setTimeout(() => {
        if (mountedRef.current) connect();
      }, delay);
    };

    ws.onerror = () => {
      // onclose will fire after this, triggering reconnect
    };
  }, [user]);

  useEffect(() => {
    mountedRef.current = true;
    connect();
    return () => {
      mountedRef.current = false;
      clearTimers();
      // Detach before closing so the socket's late `close` event is ignored.
      const ws = wsRef.current;
      wsRef.current = null;
      if (ws) {
        ws.close();
      }
    };
  }, [connect, clearTimers]);

  const subscribe = useCallback((taskId: string) => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: 'subscribe', task_id: taskId }));
    }
  }, []);

  return { tasks, isConnected, subscribe };
}

View file

@ -120,3 +120,55 @@ export interface POITravelFilter {
travelMode: 'WALK' | 'BICYCLE' | 'TRANSIT'; travelMode: 'WALK' | 'BICYCLE' | 'TRANSIT';
maxMinutes: number | undefined; maxMinutes: number | undefined;
} }
/**
 * WebSocket task state (combines status + result fields).
 *
 * Only `task_id` and `status` are always present; every other field is
 * optional and appears when the worker included it in its progress metadata.
 */
export interface TaskState {
// Celery task ID.
task_id: string;
// Celery state string (e.g. 'PROGRESS', 'SUCCESS', 'REVOKED').
status: string;
// Presumably a completion fraction in the range 0..1 - confirm against the workers.
progress?: number;
// Items handled so far and the total expected; shown as "processed / total".
processed?: number;
total?: number;
phase?: TaskPhase;
// Human-readable status line shown under the progress bar.
message?: string;
// Splitting phase
subqueries_probed?: number;
subqueries_initial?: number;
estimated_results?: number;
subqueries_total?: number;
// Fetching phase
subqueries_completed?: number;
ids_collected?: number;
pages_fetched?: number;
fetching_done?: boolean;
// Processing phase
details_fetched?: number;
images_downloaded?: number;
ocr_completed?: number;
failed?: number;
elapsed_seconds?: number;
rate_per_second?: number;
eta_seconds?: number;
// POI-specific
distances_computed?: number;
// Live logs
logs?: string[];
}
// WebSocket message types
/**
 * Snapshot message: the client replaces its whole task map with `tasks`.
 * Presumably sent once by the server right after the connection is accepted - confirm against the endpoint.
 */
export interface WSInitMessage {
type: 'init';
tasks: TaskState[];
}
/**
 * Incremental update for one task; the client merges every field except
 * `type` into the existing `TaskState` for `task_id`.
 */
export interface WSTaskUpdateMessage {
type: 'task_update';
task_id: string;
status: string;
// Remaining progress fields are passed through untyped (see TaskState for the known ones).
[key: string]: unknown;
}
/**
 * Pong message; the client ignores it.
 * Presumably the server's reply to the client's keepalive `ping` - confirm against the endpoint.
 */
export interface WSPongMessage {
type: 'pong';
}
/** Every message the client handles, discriminated by its `type` field. */
export type WSMessage = WSInitMessage | WSTaskUpdateMessage | WSPongMessage;

View file

@ -0,0 +1,50 @@
"""Publishes task progress updates to Redis pub/sub channels.
Celery workers call publish_task_progress() alongside task.update_state() so
that the FastAPI WebSocket endpoint can forward real-time updates to connected
browsers without polling.
Channel naming: ``task_progress:{task_id}``
"""
import json
import logging
import os
from typing import Any
import redis
logger = logging.getLogger(__name__)
# Module-level cache for the Redis client; stays None until the first call
# to _get_redis_client().
_redis_client: redis.Redis | None = None  # type: ignore[type-arg]


def _get_redis_client() -> redis.Redis:  # type: ignore[type-arg]
    """Return the shared Redis client, creating it on first use.

    The connection URL is read from the ``CELERY_BROKER_URL`` environment
    variable and falls back to ``redis://redis:6379/0`` when it is unset.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = redis.Redis.from_url(
        os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0"),
        decode_responses=True,
    )
    return _redis_client
def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> None:
    """Publish a task progress update to Redis pub/sub.

    The message is a JSON object containing every key of ``meta`` plus the
    envelope fields ``task_id`` and ``status``. It is published on the
    channel ``task_progress:{task_id}``.

    Args:
        task_id: Celery task ID.
        state: Celery state string (e.g. 'PROGRESS', 'SUCCESS').
        meta: Metadata dict (progress, phase, logs, counters, etc.). Keys
            named ``task_id`` or ``status`` are overridden by the envelope
            values.

    Failures are caught and logged at DEBUG level so they never break the
    critical task execution path.
    """
    try:
        # Envelope fields are applied last so a stray "task_id" or "status"
        # key inside ``meta`` can never replace the real task id or state.
        payload = json.dumps({**meta, "task_id": task_id, "status": state})
        _get_redis_client().publish(f"task_progress:{task_id}", payload)
    except Exception:
        # Deliberately broad: publishing is best-effort and must not fail
        # the task, whether the cause is Redis or an unserialisable value.
        logger.debug("Failed to publish task progress for %s", task_id, exc_info=True)

View file

@ -177,11 +177,19 @@ def cancel_task(task_id: str, user_email: str | None = None) -> bool:
""" """
# Lazy import: celery_app bootstraps the broker connection. # Lazy import: celery_app bootstraps the broker connection.
from celery_app import app as celery_app from celery_app import app as celery_app
from services.task_progress_publisher import publish_task_progress
logger.info("Cancelling task %s (user=%s)", task_id, user_email) logger.info("Cancelling task %s (user=%s)", task_id, user_email)
# Revoke the task in Celery # Revoke the task in Celery
celery_app.control.revoke(task_id, terminate=True) celery_app.control.revoke(task_id, terminate=True)
# Publish REVOKED status via pub/sub so WebSocket clients learn immediately
publish_task_progress(task_id, "REVOKED", {
"phase": "completed",
"progress": 0,
"message": "Task cancelled",
})
# Also remove from user's task list if user_email provided # Also remove from user's task list if user_email provided
if user_email: if user_email:
remove_task_from_user(user_email, task_id) remove_task_from_user(user_email, task_id)
@ -222,6 +230,7 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
# Lazy imports: see get_user_tasks and cancel_task for rationale. # Lazy imports: see get_user_tasks and cancel_task for rationale.
from redis_repository import RedisRepository from redis_repository import RedisRepository
from celery_app import app as celery_app from celery_app import app as celery_app
from services.task_progress_publisher import publish_task_progress
redis_repo = RedisRepository.instance() redis_repo = RedisRepository.instance()
user = _make_system_user(user_email) user = _make_system_user(user_email)
@ -238,5 +247,11 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
logger.warning( logger.warning(
"Failed to revoke task %s: %s", task_id, e "Failed to revoke task %s: %s", task_id, e
) )
# Publish REVOKED via pub/sub so WebSocket clients learn immediately
publish_task_progress(task_id, "REVOKED", {
"phase": "completed",
"progress": 0,
"message": "Task cancelled",
})
return redis_repo.clear_tasks_for_user(user) return redis_repo.clear_tasks_for_user(user)

View file

@ -19,6 +19,7 @@ from database import engine
from services.query_splitter import QuerySplitter, SubQuery from services.query_splitter import QuerySplitter, SubQuery
from utils.redis_lock import redis_lock from utils.redis_lock import redis_lock
from services.listing_cache import invalidate_cache from services.listing_cache import invalidate_cache
from services.task_progress_publisher import publish_task_progress
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.error")
@ -86,6 +87,8 @@ def _update_task_state(task: Task, state: str, meta: dict[str, Any]) -> None:
if _active_log_buffer is not None: if _active_log_buffer is not None:
meta["logs"] = list(_active_log_buffer) meta["logs"] = list(_active_log_buffer)
task.update_state(state=state, meta=meta) task.update_state(state=state, meta=meta)
if hasattr(task, 'request') and task.request.id:
publish_task_progress(task.request.id, state, meta)
async def _fetch_subquery( async def _fetch_subquery(
@ -266,7 +269,9 @@ def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
if not acquired: if not acquired:
msg = "Another scrape job is already running, skipping this execution" msg = "Another scrape job is already running, skipping this execution"
celery_logger.warning(msg) celery_logger.warning(msg)
self.update_state(state="SKIPPED", meta={"reason": "Another scrape job is running"}) meta = {"reason": "Another scrape job is running"}
self.update_state(state="SKIPPED", meta=meta)
publish_task_progress(self.request.id, "SKIPPED", meta)
return {"status": "skipped", "reason": "another_job_running"} return {"status": "skipped", "reason": "another_job_running"}
celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}") celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
@ -275,8 +280,11 @@ def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
celery_logger.info(f"Starting scrape with parameters: {parsed_parameters}") celery_logger.info(f"Starting scrape with parameters: {parsed_parameters}")
self.update_state(state="Starting...", meta={"phase": PHASE_SPLITTING, "progress": 0}) self.update_state(state="Starting...", meta={"phase": PHASE_SPLITTING, "progress": 0})
publish_task_progress(self.request.id, "Starting...", {"phase": PHASE_SPLITTING, "progress": 0})
asyncio.run(dump_listings_full(task=self, parameters=parsed_parameters)) asyncio.run(dump_listings_full(task=self, parameters=parsed_parameters))
return {"phase": PHASE_COMPLETED, "progress": 1} result = {"phase": PHASE_COMPLETED, "progress": 1}
publish_task_progress(self.request.id, "SUCCESS", result)
return result
async def async_dump_listings_task(parameters_json: str) -> dict[str, Any]: async def async_dump_listings_task(parameters_json: str) -> dict[str, Any]:

View file

@ -10,6 +10,7 @@ from models.listing import ListingType
from repositories.listing_repository import ListingRepository from repositories.listing_repository import ListingRepository
from repositories.poi_repository import POIRepository from repositories.poi_repository import POIRepository
from services.poi_distance_calculator import calculate_poi_distances from services.poi_distance_calculator import calculate_poi_distances
from services.task_progress_publisher import publish_task_progress
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,6 +56,11 @@ def calculate_poi_distances_task(
"progress": 0, "progress": 0,
"message": "Starting distance calculation...", "message": "Starting distance calculation...",
}) })
publish_task_progress(self.request.id, "PROGRESS", {
"phase": "starting",
"progress": 0,
"message": "Starting distance calculation...",
})
listing_repo = ListingRepository(engine) listing_repo = ListingRepository(engine)
poi_repo = POIRepository(engine) poi_repo = POIRepository(engine)
@ -62,19 +68,23 @@ def calculate_poi_distances_task(
poi = poi_repo.get_poi_by_id(poi_id) poi = poi_repo.get_poi_by_id(poi_id)
if poi is None: if poi is None:
celery_logger.error(f"POI {poi_id} not found") celery_logger.error(f"POI {poi_id} not found")
return {"error": f"POI {poi_id} not found", "distances_computed": 0} error_result = {"error": f"POI {poi_id} not found", "distances_computed": 0}
publish_task_progress(self.request.id, "FAILURE", error_result)
return error_result
lt = ListingType(listing_type) lt = ListingType(listing_type)
def on_progress(completed: int, total: int, message: str) -> None: def on_progress(completed: int, total: int, message: str) -> None:
progress = round(completed / total, 2) if total > 0 else 0 progress = round(completed / total, 2) if total > 0 else 0
self.update_state(state="PROGRESS", meta={ meta = {
"phase": "computing", "phase": "computing",
"progress": progress, "progress": progress,
"processed": completed, "processed": completed,
"total": total, "total": total,
"message": message, "message": message,
}) }
self.update_state(state="PROGRESS", meta=meta)
publish_task_progress(self.request.id, "PROGRESS", meta)
try: try:
total = asyncio.run( total = asyncio.run(
@ -96,9 +106,12 @@ def calculate_poi_distances_task(
celery_logger.info(f"POI distance calculation complete: {total} distances computed") celery_logger.info(f"POI distance calculation complete: {total} distances computed")
return { result = {
"phase": "completed", "phase": "completed",
"progress": 1, "progress": 1,
"distances_computed": total, "distances_computed": total,
"message": f"Computed {total} distances for POI '{poi.name}'", "message": f"Computed {total} distances for POI '{poi.name}'",
} }
publish_task_progress(self.request.id, "SUCCESS", result)
return result