# wrongmove/api/ws_routes.py

"""WebSocket endpoint for real-time task progress updates.
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
messages published by Celery workers via Redis pub/sub.
"""
import asyncio
import json
import logging
from typing import Any
import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that carries the WebSocket route declared in this module.
ws_router = APIRouter()
# Reuse the broker URL for the async Redis client
import os
# Redis URL used for pub/sub: read from the CELERY_BROKER_URL environment
# variable, falling back to the literal default below when it is unset.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
async def _authenticate_ws(token: str) -> User | None:
    """Resolve the user behind a JWT supplied by a WebSocket client.

    The token is first decoded *without* signature or expiry checks, purely
    to read its ``iss`` claim and pick a verifier. Tokens whose issuer equals
    ``JWT_ISSUER`` go to ``_verify_passkey_token``; everything else goes to
    ``_verify_authentik_token``. The actual validation is done by those
    verifiers (imported from ``api.auth``), never by the unverified peek.

    Args:
        token: Raw JWT string taken from the connection's query parameters.

    Returns:
        Whatever the selected verifier returns on success, or ``None`` when
        decoding or verification raises any exception.
    """
    try:
        # Unverified peek: used only to route to a verifier, never trusted.
        claims = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        if claims.get("iss", "") == JWT_ISSUER:
            return _verify_passkey_token(token)
        return await _verify_authentik_token(token)
    except Exception:
        # Best-effort by design: any failure means "not authenticated".
        # Log the cause (but never the token itself) so that verifier or
        # configuration faults are distinguishable from bad tokens.
        logger.debug("WebSocket token verification failed", exc_info=True)
        return None
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Return the current status snapshot for ``task_id``.

    The synchronous builder is dispatched through ``asyncio.to_thread`` so
    the event loop is not blocked while it runs.
    """
    snapshot = await asyncio.to_thread(_build_task_snapshot_sync, task_id)
    return snapshot
def _build_task_snapshot_sync(task_id: str) -> dict[str, Any]:
    """Collect a task's status fields into a plain dict (blocking).

    Intended to be run in a worker thread. The core status attributes are
    copied first; when the status carries a non-empty ``dict`` result, its
    entries are merged on top and therefore win on key collisions.
    """
    task_status = task_service.get_task_status(task_id)
    core_fields = (
        "task_id",
        "status",
        "progress",
        "processed",
        "total",
        "message",
    )
    # Each output key is named exactly like the attribute it is read from.
    snapshot: dict[str, Any] = {
        field: getattr(task_status, field) for field in core_fields
    }
    extra = task_status.result
    if extra and isinstance(extra, dict):
        snapshot.update(extra)
    return snapshot
async def _release_pubsub(redis_client: Any, pubsub: Any, channels: set[str]) -> None:
    """Best-effort teardown of one connection's Redis pub/sub resources.

    Every step runs even if an earlier one fails, so a dropped Redis
    connection cannot prevent the remaining objects from being closed.
    """
    if channels:
        try:
            await pubsub.unsubscribe(*channels)
        except Exception:
            logger.debug("Failed to unsubscribe from task channels", exc_info=True)
    try:
        await pubsub.close()
    except Exception:
        logger.debug("Failed to close Redis pub/sub", exc_info=True)
    try:
        await redis_client.aclose()
    except Exception:
        logger.debug("Failed to close Redis client", exc_info=True)


@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream live task progress to an authenticated WebSocket client.

    Handshake:
        * no ``token`` query parameter -> ``close(code=4001)``
        * token fails ``_authenticate_ws`` -> ``close(code=4003)``
        * otherwise the connection is accepted.
        NOTE(review): both rejections call ``close()`` before ``accept()``;
        confirm the custom close codes actually reach clients on the ASGI
        server in use.

    Server -> client messages:
        * ``{"type": "init", "tasks": [...]}`` once, with a snapshot of each
          task returned by ``task_service.get_user_tasks(user.email)``.
        * ``{"type": "task_update", ...}`` for every JSON object published on
          a subscribed ``task_progress:<task_id>`` channel, and as the reply
          to a client ``subscribe`` request.
        * ``{"type": "pong"}`` in reply to ``ping``.

    Client -> server messages (JSON objects):
        * ``{"type": "subscribe", "task_id": ...}``
        * ``{"type": "ping"}``
        Anything else is ignored.

    The Redis client and pub/sub objects are per connection and are always
    released when the handler exits, whatever the reason.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return
    user = await _authenticate_ws(token)
    if user is None:
        await websocket.close(code=4003, reason="Invalid token")
        return
    await websocket.accept()
    # NOTE(review): called directly on the event loop; if this performs
    # blocking I/O it may be worth moving to a thread - confirm.
    task_ids = task_service.get_user_tasks(user.email)

    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()
    ws_tasks: list[asyncio.Task[None]] = []
    # How long the forwarder waits before re-checking for subscriptions
    # while the connection has none.
    idle_poll_seconds = 0.25

    async def _forward_pubsub() -> None:
        """Relay Redis pub/sub messages to the WebSocket until a send fails."""
        while True:
            async for message in pubsub.listen():
                if message["type"] != "message":
                    continue
                try:
                    data = json.loads(message["data"])
                except (json.JSONDecodeError, ValueError):
                    logger.debug("Malformed pubsub message, skipping")
                    continue
                if not isinstance(data, dict):
                    # ``**data`` below needs a mapping; one odd payload
                    # must not end the whole connection.
                    logger.debug("Non-object pubsub message, skipping")
                    continue
                try:
                    await websocket.send_json({"type": "task_update", **data})
                except Exception:
                    return
            # listen() stops iterating while nothing is subscribed (e.g. the
            # user had no tasks at connect time). Wait for the client to
            # subscribe instead of finishing, which would close the socket.
            await asyncio.sleep(idle_poll_seconds)

    async def _handle_client_messages() -> None:
        """Serve client requests (``subscribe``, ``ping``) until disconnect."""
        while True:
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                raise
            except KeyError:
                # Presumably a non-text frame (no "text" entry) - ignore it.
                continue
            except Exception:
                # Any other receive failure is treated as fatal: retrying a
                # call that fails immediately would spin without yielding.
                logger.debug("WebSocket receive failed", exc_info=True)
                return
            try:
                msg = json.loads(raw)
            except (ValueError, TypeError):
                continue
            if not isinstance(msg, dict):
                continue
            msg_type = msg.get("type")
            if msg_type == "subscribe":
                new_task_id = msg.get("task_id")
                if not new_task_id:
                    continue
                # NOTE(review): no ownership check - any authenticated
                # client can subscribe to any task id it knows. Confirm
                # this is acceptable.
                channel = f"task_progress:{new_task_id}"
                if channel not in subscribed_channels:
                    await pubsub.subscribe(channel)
                    subscribed_channels.add(channel)
                # Send current snapshot for the new task
                try:
                    snapshot = await _build_task_snapshot(new_task_id)
                    await websocket.send_json({"type": "task_update", **snapshot})
                except Exception:
                    logger.debug(
                        "Failed to send snapshot for task %s",
                        new_task_id,
                        exc_info=True,
                    )
            elif msg_type == "ping":
                try:
                    await websocket.send_json({"type": "pong"})
                except Exception:
                    return

    try:
        # Subscribe FIRST so no updates are lost while building snapshots;
        # anything published in between is delivered once forwarding starts.
        for tid in task_ids:
            channel = f"task_progress:{tid}"
            await pubsub.subscribe(channel)
            subscribed_channels.add(channel)

        # Snapshots use synchronous task APIs, so each one is built in a
        # worker thread via _build_task_snapshot.
        snapshots = []
        for tid in task_ids:
            try:
                snapshots.append(await _build_task_snapshot(tid))
            except Exception:
                logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)

        try:
            await websocket.send_json({"type": "init", "tasks": snapshots})
        except Exception:
            # Client is already gone; ``finally`` releases Redis resources.
            return

        ws_tasks.append(asyncio.create_task(_forward_pubsub()))
        ws_tasks.append(asyncio.create_task(_handle_client_messages()))
        done, _pending = await asyncio.wait(
            ws_tasks, return_when=asyncio.FIRST_COMPLETED
        )
        # Log non-trivial errors from the completed task(s)
        for finished in done:
            if finished.cancelled():
                # .exception() would raise CancelledError here.
                continue
            exc = finished.exception()
            if exc and not isinstance(exc, WebSocketDisconnect):
                logger.debug("WS task ended with error: %s", exc)
    finally:
        # Stop both workers and wait for them before closing the pub/sub
        # they read from. This also covers cancellation of this handler.
        for task in ws_tasks:
            task.cancel()
        if ws_tasks:
            await asyncio.gather(*ws_tasks, return_exceptions=True)
        await _release_pubsub(redis_client, pubsub, subscribed_channels)