Add real-time WebSocket task progress with multi-job drawer
Replace 5s HTTP polling with WebSocket-based real-time updates for task progress. Celery workers publish progress to Redis pub/sub channels; a FastAPI WebSocket endpoint subscribes and forwards to the browser. Polling is kept as a 30s fallback when WebSocket is unavailable. The task progress drawer now supports multiple concurrent jobs with a tab bar for switching between scrape and POI distance tasks. Backend: - Add services/task_progress_publisher.py (Redis pub/sub bridge) - Add api/ws_routes.py (WebSocket endpoint with JWT auth) - Publish progress from listing_tasks and poi_tasks - Publish REVOKED via pub/sub on cancel/clear to fix stuck UI Frontend: - Add useTaskWebSocket hook with reconnection and keepalive - Add TaskState and WS message types - TaskIndicator: WS-driven updates with polling fallback - TaskProgressDrawer: multi-job tabs, POI phase timeline - Guard against WS overwriting local cancel state
This commit is contained in:
parent
73d19e29d5
commit
8559c4b461
11 changed files with 774 additions and 72 deletions
153
api/ws_routes.py
Normal file
153
api/ws_routes.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""WebSocket endpoint for real-time task progress updates.
|
||||
|
||||
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
|
||||
messages published by Celery workers via Redis pub/sub.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from api.auth import _verify_authentik_token, _verify_passkey_token, User
|
||||
from api.config import JWT_ISSUER
|
||||
from services import task_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ws_router = APIRouter()
|
||||
|
||||
# Reuse the broker URL for the async Redis client
|
||||
import os
|
||||
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
|
||||
|
||||
|
||||
async def _authenticate_ws(token: str) -> User | None:
|
||||
"""Verify a JWT token using the same logic as api/auth.py."""
|
||||
try:
|
||||
unverified = jwt.decode(
|
||||
token, options={"verify_signature": False, "verify_exp": False}
|
||||
)
|
||||
issuer = unverified.get("iss", "")
|
||||
if issuer == JWT_ISSUER:
|
||||
return _verify_passkey_token(token)
|
||||
else:
|
||||
return await _verify_authentik_token(token)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
|
||||
"""Build a snapshot of a task's current status for the init message."""
|
||||
status = task_service.get_task_status(task_id)
|
||||
result: dict[str, Any] = {
|
||||
"task_id": status.task_id,
|
||||
"status": status.status,
|
||||
"progress": status.progress,
|
||||
"processed": status.processed,
|
||||
"total": status.total,
|
||||
"message": status.message,
|
||||
}
|
||||
if status.result and isinstance(status.result, dict):
|
||||
result.update(status.result)
|
||||
return result
|
||||
|
||||
|
||||
@ws_router.websocket("/api/ws/tasks")
|
||||
async def ws_task_progress(websocket: WebSocket) -> None:
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Missing token")
|
||||
return
|
||||
|
||||
user = await _authenticate_ws(token)
|
||||
if user is None:
|
||||
await websocket.close(code=4003, reason="Invalid token")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# Get user's tasks and send initial snapshot
|
||||
task_ids = task_service.get_user_tasks(user.email)
|
||||
snapshots = []
|
||||
for tid in task_ids:
|
||||
try:
|
||||
snapshots.append(await _build_task_snapshot(tid))
|
||||
except Exception:
|
||||
logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "init", "tasks": snapshots})
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# Subscribe to Redis pub/sub channels for each task
|
||||
redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
|
||||
pubsub = redis_client.pubsub()
|
||||
|
||||
subscribed_channels: set[str] = set()
|
||||
for tid in task_ids:
|
||||
channel = f"task_progress:{tid}"
|
||||
await pubsub.subscribe(channel)
|
||||
subscribed_channels.add(channel)
|
||||
|
||||
async def _forward_pubsub() -> None:
|
||||
"""Read from Redis pub/sub and forward to the WebSocket."""
|
||||
while True:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if message and message["type"] == "message":
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
await websocket.send_json({"type": "task_update", **data})
|
||||
except Exception:
|
||||
break
|
||||
|
||||
async def _handle_client_messages() -> None:
|
||||
"""Read messages from the client (subscribe, ping)."""
|
||||
while True:
|
||||
try:
|
||||
raw = await websocket.receive_text()
|
||||
msg = json.loads(raw)
|
||||
except WebSocketDisconnect:
|
||||
raise
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
msg_type = msg.get("type")
|
||||
if msg_type == "subscribe":
|
||||
new_task_id = msg.get("task_id")
|
||||
if new_task_id:
|
||||
channel = f"task_progress:{new_task_id}"
|
||||
if channel not in subscribed_channels:
|
||||
await pubsub.subscribe(channel)
|
||||
subscribed_channels.add(channel)
|
||||
# Send current snapshot for the new task
|
||||
try:
|
||||
snapshot = await _build_task_snapshot(new_task_id)
|
||||
await websocket.send_json(
|
||||
{"type": "task_update", **snapshot}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
elif msg_type == "ping":
|
||||
try:
|
||||
await websocket.send_json({"type": "pong"})
|
||||
except Exception:
|
||||
break
|
||||
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_forward_pubsub(),
|
||||
_handle_client_messages(),
|
||||
)
|
||||
except (WebSocketDisconnect, Exception):
|
||||
pass
|
||||
finally:
|
||||
await pubsub.unsubscribe(*subscribed_channels)
|
||||
await pubsub.close()
|
||||
await redis_client.aclose()
|
||||
Loading…
Add table
Add a link
Reference in a new issue