2026-02-09 21:31:45 +00:00
|
|
|
"""WebSocket endpoint for real-time task progress updates.
|
|
|
|
|
|
|
|
|
|
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
|
|
|
|
|
messages published by Celery workers via Redis pub/sub.
|
|
|
|
|
"""
|
|
|
|
|
import asyncio
import json
import logging
import os
from typing import Any

import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Router that carries this module's WebSocket endpoint(s).
ws_router = APIRouter()

# Reuse the Celery broker URL for the per-connection async Redis client that
# listens for task progress messages.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _authenticate_ws(token: str) -> User | None:
    """Resolve a WebSocket ``token`` query parameter to a ``User``.

    Uses the same dispatch as ``api/auth.py``. The token's ``iss`` claim is
    read *without* verification, only to choose a verifier. Tokens issued by
    this service (``JWT_ISSUER``) go to ``_verify_passkey_token``; all others
    go to ``_verify_authentik_token``. The chosen verifier is expected to do
    the real signature/expiry checks.

    Returns:
        The authenticated user, or ``None`` if the token cannot be decoded or
        the verifier rejects it (any exception counts as a failed login).
    """
    try:
        unverified = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        issuer = unverified.get("iss", "")
        if issuer == JWT_ISSUER:
            return _verify_passkey_token(token)
        return await _verify_authentik_token(token)
    except Exception as exc:
        # Deliberately best-effort: any failure means "not authenticated".
        # Record why, but never log the token itself.
        logger.debug(
            "WebSocket authentication failed: %s: %s", type(exc).__name__, exc
        )
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Return the current status snapshot for *task_id* without blocking the loop.

    The blocking work is delegated to ``_build_task_snapshot_sync``, which
    runs on a worker thread via ``asyncio.to_thread``.
    """
    snapshot = await asyncio.to_thread(_build_task_snapshot_sync, task_id)
    return snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
    """Collect a task's status fields into a plain dict (blocking call).

    Copies the progress fields from ``task_service.get_task_status``. When the
    task's ``result`` is a non-empty dict, its entries are merged on top, so a
    result key overwrites a status field of the same name. Intended to run on
    a worker thread (see ``_build_task_snapshot``).
    """
    info = task_service.get_task_status(task_id)
    snapshot: dict[str, Any] = dict(
        task_id=info.task_id,
        status=info.status,
        progress=info.progress,
        processed=info.processed,
        total=info.total,
        message=info.message,
    )
    extra = info.result
    if extra and isinstance(extra, dict):
        snapshot.update(extra)
    return snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _forward_pubsub_messages(
    websocket: WebSocket, pubsub: Any, subscribed_channels: set[str]
) -> None:
    """Relay Redis pub/sub progress messages to the client as ``task_update`` frames.

    Runs until sending to the WebSocket fails (the client went away).
    Payloads that are not JSON objects are logged and skipped.
    """
    while True:
        if not subscribed_channels:
            # No channels yet (the user had no tasks at connect time). Some
            # redis-py versions raise when a PubSub that never subscribed is
            # read, so idle until the client handler subscribes to something.
            await asyncio.sleep(1.0)
            continue
        message = await pubsub.get_message(
            ignore_subscribe_messages=True, timeout=1.0
        )
        if not message or message["type"] != "message":
            continue
        try:
            data = json.loads(message["data"])
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            logger.debug("Malformed pubsub message, skipping")
            continue
        if not isinstance(data, dict):
            # ``**data`` below needs a mapping; a bare list/str payload would
            # otherwise raise and tear down the whole connection.
            logger.debug("Malformed pubsub message, skipping")
            continue
        try:
            await websocket.send_json({"type": "task_update", **data})
        except Exception:
            return


async def _subscribe_to_task(
    websocket: WebSocket,
    pubsub: Any,
    subscribed_channels: set[str],
    user_email: str,
    task_id: Any,
) -> None:
    """Handle a client ``subscribe`` request for *task_id*.

    Subscribes to ``task_progress:<task_id>`` only if the task belongs to the
    connected user, then sends the task's current snapshot. Missing,
    non-string, or foreign task ids are ignored.
    """
    if not task_id or not isinstance(task_id, str):
        return
    channel = f"task_progress:{task_id}"
    if channel not in subscribed_channels:
        # Authorization: without this check any authenticated user could
        # stream another user's progress updates and snapshots (which merge
        # the task result) just by sending its id. Re-read the list because
        # the task may have been started after this socket connected.
        owned = await asyncio.to_thread(task_service.get_user_tasks, user_email)
        if task_id not in owned:
            logger.debug("Ignoring subscribe to task %s not owned by user", task_id)
            return
        await pubsub.subscribe(channel)
        subscribed_channels.add(channel)
    # Send the current snapshot so the client need not wait for the next publish.
    try:
        snapshot = await _build_task_snapshot(task_id)
        await websocket.send_json({"type": "task_update", **snapshot})
    except Exception:
        logger.debug("Failed to send snapshot for task %s", task_id, exc_info=True)


async def _handle_client_messages(
    websocket: WebSocket,
    pubsub: Any,
    subscribed_channels: set[str],
    user_email: str,
) -> None:
    """Process client frames until the socket closes.

    Understands ``{"type": "subscribe", "task_id": ...}`` and
    ``{"type": "ping"}`` (answered with ``{"type": "pong"}``). Anything else,
    including non-JSON text and non-object JSON, is ignored.

    Raises:
        WebSocketDisconnect: when the client disconnects.
    """
    while True:
        try:
            raw = await websocket.receive_text()
        except KeyError:
            # Starlette's receive_text() indexes message["text"], which is
            # absent for binary frames; skip those. Any other receive error
            # (disconnect, closed socket) must end the loop. Retrying would
            # spin forever on a dead connection.
            continue
        try:
            msg = json.loads(raw)
        except (TypeError, ValueError):
            continue
        if not isinstance(msg, dict):
            # e.g. a bare JSON string or array; msg.get() would raise.
            continue

        msg_type = msg.get("type")
        if msg_type == "subscribe":
            await _subscribe_to_task(
                websocket, pubsub, subscribed_channels, user_email, msg.get("task_id")
            )
        elif msg_type == "ping":
            try:
                await websocket.send_json({"type": "pong"})
            except Exception:
                return


async def _run_session(
    websocket: WebSocket,
    pubsub: Any,
    subscribed_channels: set[str],
    user_email: str,
) -> None:
    """Run the pub/sub forwarder and the client reader until either finishes.

    The survivor is cancelled and awaited, also when this coroutine is itself
    cancelled, so that neither keeps using *pubsub* after the caller closes it.
    """
    ws_tasks = [
        asyncio.create_task(
            _forward_pubsub_messages(websocket, pubsub, subscribed_channels)
        ),
        asyncio.create_task(
            _handle_client_messages(websocket, pubsub, subscribed_channels, user_email)
        ),
    ]
    try:
        done, _pending = await asyncio.wait(
            ws_tasks, return_when=asyncio.FIRST_COMPLETED
        )
        # Log non-trivial errors from the completed task(s).
        for t in done:
            if t.cancelled():
                continue
            exc = t.exception()
            if exc is not None and not isinstance(exc, WebSocketDisconnect):
                logger.debug("WS task ended with error: %s", exc, exc_info=exc)
    finally:
        for t in ws_tasks:
            t.cancel()  # no-op for tasks that already finished
        await asyncio.gather(*ws_tasks, return_exceptions=True)


async def _close_redis(
    pubsub: Any, redis_client: Any, subscribed_channels: set[str]
) -> None:
    """Best-effort teardown of the per-connection pub/sub and Redis client.

    Each step runs even if an earlier one fails (e.g. the Redis connection has
    already dropped), so the client is always closed.
    """
    if subscribed_channels:
        try:
            await pubsub.unsubscribe(*subscribed_channels)
        except Exception:
            logger.debug("Failed to unsubscribe task channels", exc_info=True)
    try:
        await pubsub.aclose()
    except Exception:
        logger.debug("Failed to close Redis pubsub", exc_info=True)
    try:
        await redis_client.aclose()
    except Exception:
        logger.debug("Failed to close Redis client", exc_info=True)


@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream live progress updates for the authenticated user's tasks.

    Protocol:
      * Auth via ``?token=<jwt>``: close code 4001 if missing, 4003 if invalid.
      * After accept, a single ``{"type": "init", "tasks": [...]}`` frame with
        a snapshot of each of the user's tasks.
      * Then ``{"type": "task_update", ...}`` frames relayed from the
        ``task_progress:<task_id>`` Redis channels.
      * The client may send ``subscribe`` (own tasks only) and ``ping``.

    A dedicated Redis client and pub/sub are opened per connection and are
    always closed when the socket ends.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return

    user = await _authenticate_ws(token)
    if user is None:
        await websocket.close(code=4003, reason="Invalid token")
        return

    await websocket.accept()

    # get_user_tasks is a synchronous service call; keep it off the event
    # loop, just like the snapshot builder.
    task_ids = await asyncio.to_thread(task_service.get_user_tasks, user.email)

    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()
    try:
        # Subscribe FIRST so no updates are lost while building snapshots.
        # Messages published in the meantime wait on the pub/sub connection
        # and are forwarded after the init frame.
        initial_channels = {f"task_progress:{tid}" for tid in task_ids}
        if initial_channels:
            await pubsub.subscribe(*initial_channels)
            subscribed_channels.update(initial_channels)

        snapshots = []
        for tid in task_ids:
            try:
                snapshots.append(await _build_task_snapshot(tid))
            except Exception:
                logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)

        try:
            await websocket.send_json({"type": "init", "tasks": snapshots})
        except Exception:
            return  # client already gone; the finally below releases Redis

        await _run_session(websocket, pubsub, subscribed_channels, user.email)
    finally:
        await _close_redis(pubsub, redis_client, subscribed_channels)
|