wrongmove/api/ws_routes.py

154 lines
5 KiB
Python
Raw Normal View History

"""WebSocket endpoint for real-time task progress updates.
Clients connect to ``/api/ws/tasks?token=<jwt>`` and receive live progress
messages published by Celery workers via Redis pub/sub.
"""
import asyncio
import json
import logging
from typing import Any
import jwt
import redis.asyncio as aioredis
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from api.auth import _verify_authentik_token, _verify_passkey_token, User
from api.config import JWT_ISSUER
from services import task_service
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which the WebSocket route defined in this module is registered.
ws_router = APIRouter()
# Reuse the broker URL for the async Redis client
import os
# Read from the CELERY_BROKER_URL environment variable; when it is unset the
# "redis://redis:6379/0" default is used instead.
_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
async def _authenticate_ws(token: str) -> User | None:
    """Resolve a raw JWT string to a ``User``, or ``None`` when that fails.

    The token is first decoded *without* signature or expiry checks, purely
    to read its ``iss`` claim and pick a verifier. A token whose issuer equals
    ``JWT_ISSUER`` is handed to ``_verify_passkey_token``; every other token
    (including one with no ``iss`` claim) goes to ``_verify_authentik_token``.
    The unverified claims are never used for anything but this routing.

    Args:
        token: The encoded JWT supplied by the client.

    Returns:
        Whatever the selected verifier returns, or ``None`` if decoding or
        verification raised any exception.
    """
    try:
        claims = jwt.decode(
            token, options={"verify_signature": False, "verify_exp": False}
        )
        if claims.get("iss", "") != JWT_ISSUER:
            return await _verify_authentik_token(token)
        return _verify_passkey_token(token)
    except Exception:
        # Fail closed: any problem with the token means "not authenticated".
        return None
async def _build_task_snapshot(task_id: str) -> dict[str, Any]:
    """Return a flat dict describing the current state of one task.

    The dict always carries the ``task_id``, ``status``, ``progress``,
    ``processed``, ``total`` and ``message`` attributes of the object returned
    by ``task_service.get_task_status``. When that object's ``result`` is a
    non-empty dict, its entries are merged on top, so a result key with the
    same name as one of the core fields replaces it.

    Args:
        task_id: Identifier passed through to ``task_service.get_task_status``.

    Returns:
        The snapshot dict used in ``init`` and ``task_update`` messages.
    """
    state = task_service.get_task_status(task_id)
    core_fields = ("task_id", "status", "progress", "processed", "total", "message")
    snapshot: dict[str, Any] = {name: getattr(state, name) for name in core_fields}
    extra = state.result
    if extra and isinstance(extra, dict):
        snapshot.update(extra)
    return snapshot
@ws_router.websocket("/api/ws/tasks")
async def ws_task_progress(websocket: WebSocket) -> None:
    """Stream live task progress to an authenticated WebSocket client.

    Protocol, as implemented here:

    * The client supplies its JWT in the ``token`` query parameter. A missing
      token closes the socket with code 4001; a token rejected by
      ``_authenticate_ws`` closes it with code 4003. Both happen before
      ``accept()``.
    * After accepting, a single ``{"type": "init", "tasks": [...]}`` message
      is sent containing a snapshot of every task id returned by
      ``task_service.get_user_tasks(user.email)``.
    * Each JSON object published on a subscribed ``task_progress:<task_id>``
      Redis channel is forwarded as ``{"type": "task_update", ...}``.
    * The client may send ``{"type": "subscribe", "task_id": "<id>"}`` to
      follow another task and ``{"type": "ping"}`` to receive
      ``{"type": "pong"}``. Subscribe requests are honoured only for task ids
      that ``task_service.get_user_tasks`` lists for the authenticated user.

    The connection ends as soon as either the Redis forwarder or the client
    reader stops; the other one is cancelled and all Redis resources are
    released.

    Args:
        websocket: The incoming WebSocket connection.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing token")
        return
    user = await _authenticate_ws(token)
    if user is None:
        await websocket.close(code=4003, reason="Invalid token")
        return
    await websocket.accept()

    # Materialise once: the ids are used for the snapshots, for the initial
    # subscriptions, and as the ownership allow-list for later requests.
    task_ids = list(task_service.get_user_tasks(user.email))
    owned_task_ids: set[str] = set(task_ids)

    snapshots = []
    for tid in task_ids:
        try:
            snapshots.append(await _build_task_snapshot(tid))
        except Exception:
            logger.debug("Failed to build snapshot for task %s", tid, exc_info=True)
    try:
        await websocket.send_json({"type": "init", "tasks": snapshots})
    except Exception:
        # The client went away before the first message; nothing to clean up.
        return

    redis_client = aioredis.from_url(_BROKER_URL, decode_responses=True)
    pubsub = redis_client.pubsub()
    subscribed_channels: set[str] = set()

    async def _subscribe(task_id: str) -> None:
        """Subscribe to a task's progress channel unless already subscribed."""
        channel = f"task_progress:{task_id}"
        if channel not in subscribed_channels:
            await pubsub.subscribe(channel)
            subscribed_channels.add(channel)

    def _owns_task(task_id: str) -> bool:
        """Return True if *task_id* is listed among the user's tasks."""
        if task_id in owned_task_ids:
            return True
        # The task may have been created after this socket connected, so
        # refresh the allow-list once before refusing.
        try:
            owned_task_ids.update(task_service.get_user_tasks(user.email))
        except Exception:
            logger.debug("Failed to refresh task list for ownership check", exc_info=True)
            return False
        return task_id in owned_task_ids

    async def _forward_pubsub() -> None:
        """Read from Redis pub/sub and forward to the WebSocket."""
        while True:
            if not subscribed_channels:
                # NOTE(review): redis-py's asyncio PubSub appears to raise
                # RuntimeError from get_message() until something has been
                # subscribed - confirm for the pinned version. Idling here
                # keeps a user with no tasks connected either way.
                await asyncio.sleep(1.0)
                continue
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if not message or message["type"] != "message":
                continue
            try:
                data = json.loads(message["data"])
            except (TypeError, ValueError):
                # One malformed payload must not end the whole stream.
                logger.debug("Ignoring non-JSON pub/sub payload", exc_info=True)
                continue
            if not isinstance(data, dict):
                logger.debug("Ignoring non-object pub/sub payload")
                continue
            try:
                await websocket.send_json({"type": "task_update", **data})
            except Exception:
                # Sending failed, so the socket is unusable: stop forwarding.
                return

    async def _handle_subscribe(new_task_id: Any) -> None:
        """Process a client ``subscribe`` request for *new_task_id*."""
        if not new_task_id or not isinstance(new_task_id, str):
            return
        if not _owns_task(new_task_id):
            # Client input is untrusted: never expose another user's task.
            logger.warning(
                "Rejected subscription to task %s: not owned by the connected user",
                new_task_id,
            )
            return
        await _subscribe(new_task_id)
        # Send current snapshot for the new task
        try:
            snapshot = await _build_task_snapshot(new_task_id)
            await websocket.send_json({"type": "task_update", **snapshot})
        except Exception:
            logger.debug(
                "Failed to send snapshot for task %s", new_task_id, exc_info=True
            )

    async def _handle_client_messages() -> None:
        """Read messages from the client (subscribe, ping)."""
        while True:
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                return
            except RuntimeError:
                # NOTE(review): Starlette appears to raise RuntimeError when
                # receiving on a socket that is no longer connected - confirm.
                # Retrying would spin, so treat it as terminal.
                return
            except Exception:
                # Best effort: skip frames that cannot be read as text.
                continue
            try:
                msg = json.loads(raw)
            except ValueError:
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or a number).
                continue
            msg_type = msg.get("type")
            if msg_type == "subscribe":
                await _handle_subscribe(msg.get("task_id"))
            elif msg_type == "ping":
                try:
                    await websocket.send_json({"type": "pong"})
                except Exception:
                    return

    try:
        for tid in task_ids:
            await _subscribe(tid)
        workers = [
            asyncio.create_task(_forward_pubsub()),
            asyncio.create_task(_handle_client_messages()),
        ]
        try:
            # Stop as soon as either side finishes; gather() alone would leave
            # the other coroutine running after the first one failed.
            await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for worker in workers:
                worker.cancel()
            # Awaiting the cancelled tasks retrieves their outcome so nothing
            # is left running against the pubsub that is closed below.
            outcomes = await asyncio.gather(*workers, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                logger.debug("WebSocket task stream worker failed", exc_info=outcome)
    except Exception:
        logger.debug("WebSocket task stream ended with an error", exc_info=True)
    finally:
        # Run every cleanup step even if an earlier one fails, so a broken
        # Redis connection cannot leak the client.
        cleanup_steps = []
        if subscribed_channels:
            cleanup_steps.append(
                ("unsubscribe", lambda: pubsub.unsubscribe(*subscribed_channels))
            )
        cleanup_steps.append(("close pubsub", pubsub.close))
        cleanup_steps.append(("close Redis client", redis_client.aclose))
        for label, step in cleanup_steps:
            try:
                await step()
            except Exception:
                logger.debug(
                    "Failed to %s during WebSocket cleanup", label, exc_info=True
                )