Replace 5s HTTP polling with WebSocket-based real-time updates for task progress. Celery workers publish progress to Redis pub/sub channels; a FastAPI WebSocket endpoint subscribes to them and forwards the messages to the browser. Polling is kept as a 30s fallback when the WebSocket is unavailable. The task progress drawer now supports multiple concurrent jobs, with a tab bar for switching between scrape and POI distance tasks.

Backend:
- Add services/task_progress_publisher.py (Redis pub/sub bridge)
- Add api/ws_routes.py (WebSocket endpoint with JWT auth)
- Publish progress from listing_tasks and poi_tasks
- Publish REVOKED via pub/sub on cancel/clear to fix the UI getting stuck

Frontend:
- Add useTaskWebSocket hook with reconnection and keepalive
- Add TaskState and WS message types
- TaskIndicator: WS-driven updates with polling fallback
- TaskProgressDrawer: multi-job tabs, POI phase timeline
- Guard against WS messages overwriting local cancel state
153 lines
5 KiB
Python
153 lines
5 KiB
Python
"""WebSocket endpoint for real-time task progress updates.
|
|
|
|
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
|
|
messages published by Celery workers via Redis pub/sub.
|
|
"""
|
|
import asyncio
import json
import logging
import os
from typing import Any

import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service
|
|
|
|
logger = logging.getLogger(__name__)

ws_router = APIRouter()

# Reuse the Celery broker URL for the async Redis client: progress channels
# are read from the same Redis instance the broker points at.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
|
|
|
|
|
|
async def _authenticate_ws(token: str) -> User | None:
    """Verify a JWT passed on the WebSocket query string.

    Uses the same dispatch as ``api/auth.py``: the *unverified* ``iss`` claim
    only selects the verifier. Tokens whose issuer equals ``JWT_ISSUER`` go to
    the passkey verifier; everything else goes to Authentik. The selected
    verifier is expected to perform the real signature/expiry validation.

    Args:
        token: Raw JWT string taken from the ``token`` query parameter.

    Returns:
        The authenticated ``User``, or ``None`` if the token cannot be decoded
        or the verifier rejects it (any exception is treated as rejection).
    """
    try:
        # Signature/expiry are deliberately not checked here; this decode is
        # used only to route the token to the right verifier.
        unverified = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        if unverified.get("iss", "") == JWT_ISSUER:
            return _verify_passkey_token(token)
        return await _verify_authentik_token(token)
    except Exception:
        # Rejection is the expected outcome for bad tokens, so keep it quiet,
        # but leave a trace for debugging. Never log the token itself.
        logger.debug("WebSocket token verification failed", exc_info=True)
        return None
|
|
|
|
|
|
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Build a snapshot of a task's current status for ``init``/``task_update``.

    Args:
        task_id: Celery task ID to look up.

    Returns:
        A dict with ``task_id``, ``status``, ``progress``, ``processed``,
        ``total`` and ``message``. When the task's ``result`` is a non-empty
        dict, its keys are merged on top, so a result key with the same name
        overrides the corresponding status field.

    Raises:
        Whatever ``task_service.get_task_status`` raises; callers catch it.
    """
    # task_service is a synchronous API (it is called without ``await``) and
    # presumably queries the Celery result backend -- run it in a worker
    # thread so it does not block the event loop shared by every socket.
    status = await asyncio.to_thread(task_service.get_task_status, task_id)

    snapshot: dict[str, Any] = {
        "task_id": status.task_id,
        "status": status.status,
        "progress": status.progress,
        "processed": status.processed,
        "total": status.total,
        "message": status.message,
    }
    if isinstance(status.result, dict) and status.result:
        snapshot.update(status.result)
    return snapshot
|
|
|
|
|
|
async def _reject(websocket: WebSocket, code: int, reason: str) -> None:
    """Complete the handshake, then close with an application-level code.

    Per the ASGI spec, a ``websocket.close`` sent before ``accept()`` is turned
    into an HTTP 403 and the handshake never completes, so the browser would
    see a generic failure (typically 1006) instead of ``code``/``reason``.
    """
    await websocket.accept()
    await websocket.close(code=code, reason=reason)


@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream live progress for the authenticated user's tasks.

    Authentication uses the ``token`` query parameter. A missing token closes
    the socket with 4001; an invalid token closes it with 4003.

    Server -> client:
      * ``{"type": "init", "tasks": [...]}`` once, with a snapshot of each
        task returned by ``task_service.get_user_tasks`` at connect time.
      * ``{"type": "task_update", ...}`` for every message published on a
        subscribed ``task_progress:<task_id>`` channel, and as the snapshot
        reply to a ``subscribe`` request.
      * ``{"type": "pong"}`` in reply to ``ping``.

    Client -> server:
      * ``{"type": "subscribe", "task_id": "..."}`` -- follow a task started
        after connecting. Only tasks listed for the user are accepted.
      * ``{"type": "ping"}`` -- keepalive.
    """
    token = websocket.query_params.get("token")
    if not token:
        await _reject(websocket, 4001, "Missing token")
        return

    user = await _authenticate_ws(token)
    if user is None:
        await _reject(websocket, 4003, "Invalid token")
        return

    await websocket.accept()

    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()
    workers: list[asyncio.Task[None]] = []

    async def _subscribe(task_id: str) -> None:
        """Subscribe to a task's progress channel at most once per socket."""
        channel = f"task_progress:{task_id}"
        if channel not in subscribed_channels:
            await pubsub.subscribe(channel)
            subscribed_channels.add(channel)

    async def _handle_subscribe(task_id: object) -> None:
        """Follow a task requested by the client, if the user owns it."""
        if not isinstance(task_id, str) or not task_id:
            return
        # Ownership is re-read on every request because the task was most
        # likely created after this socket opened. Without this check any
        # client could read another user's task progress/results by ID.
        try:
            owned = await asyncio.to_thread(
                task_service.get_user_tasks, user.email
            )
        except Exception:
            logger.warning(
                "Could not load tasks for %s; ignoring subscribe to %s",
                user.email,
                task_id,
                exc_info=True,
            )
            return
        if task_id not in owned:
            logger.warning(
                "Rejected WebSocket subscribe to task %s not owned by %s",
                task_id,
                user.email,
            )
            return

        await _subscribe(task_id)
        # Send the current state right away rather than waiting for the next
        # published progress event.
        try:
            snapshot = await _build_task_snapshot(task_id)
            await websocket.send_json({"type": "task_update", **snapshot})
        except Exception:
            logger.debug(
                "Failed to send snapshot for task %s", task_id, exc_info=True
            )

    async def _forward_pubsub() -> None:
        """Relay Redis pub/sub messages to the client until a send fails."""
        while True:
            if not subscribed_channels:
                # redis.asyncio's PubSub only opens its connection on the first
                # subscribe(); reading before that raises RuntimeError.
                await asyncio.sleep(1.0)
                continue
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if not message or message["type"] != "message":
                continue
            try:
                data = json.loads(message["data"])
            except (TypeError, ValueError):
                # One bad payload must not stop forwarding for this socket.
                logger.warning(
                    "Ignoring malformed progress payload on %s",
                    message.get("channel"),
                )
                continue
            if not isinstance(data, dict):
                logger.warning(
                    "Ignoring non-object progress payload on %s",
                    message.get("channel"),
                )
                continue
            try:
                await websocket.send_json({"type": "task_update", **data})
            except Exception:
                # Client is gone; ending this task tears the session down.
                return

    async def _handle_client_messages() -> None:
        """Serve ``subscribe``/``ping`` requests until the client disconnects.

        ``WebSocketDisconnect`` raised by ``receive_text()`` ends the task.
        """
        while True:
            try:
                raw = await websocket.receive_text()
            except KeyError:
                # Binary frame without a "text" field; this protocol is JSON
                # text only.
                continue
            try:
                msg = json.loads(raw)
            except (TypeError, ValueError):
                continue
            if not isinstance(msg, dict):
                continue

            msg_type = msg.get("type")
            if msg_type == "subscribe":
                await _handle_subscribe(msg.get("task_id"))
            elif msg_type == "ping":
                try:
                    await websocket.send_json({"type": "pong"})
                except Exception:
                    return

    try:
        task_ids = list(
            await asyncio.to_thread(task_service.get_user_tasks, user.email)
        )

        # Subscribe *before* taking snapshots. Otherwise an update published
        # between snapshot and subscription (e.g. a terminal status) would be
        # lost, leaving the UI stale until the polling fallback runs.
        for tid in task_ids:
            await _subscribe(tid)

        snapshots = []
        for tid in task_ids:
            try:
                snapshots.append(await _build_task_snapshot(tid))
            except Exception:
                logger.debug(
                    "Failed to build snapshot for task %s", tid, exc_info=True
                )

        try:
            await websocket.send_json({"type": "init", "tasks": snapshots})
        except Exception:
            return

        workers = [
            asyncio.create_task(_forward_pubsub()),
            asyncio.create_task(_handle_client_messages()),
        ]
        # End the session as soon as either side finishes. (asyncio.gather
        # would leave the other coroutine running after the first one raised.)
        done, _ = await asyncio.wait(
            workers, return_when=asyncio.FIRST_COMPLETED
        )
        for worker in done:
            exc = None if worker.cancelled() else worker.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.warning(
                    "Task progress WebSocket for %s failed",
                    user.email,
                    exc_info=exc,
                )
    finally:
        for worker in workers:
            worker.cancel()
        # Await the cancelled tasks so none outlives the socket and no
        # "exception was never retrieved" warnings are emitted.
        await asyncio.gather(*workers, return_exceptions=True)
        try:
            if subscribed_channels:
                await pubsub.unsubscribe(*subscribed_channels)
        except Exception:
            logger.debug(
                "Failed to unsubscribe task progress channels", exc_info=True
            )
        finally:
            await pubsub.aclose()
            await redis_client.aclose()
|