Three interconnected bugs prevented progress updates from reaching the frontend: 1. _forward_pubsub could exit silently while _handle_client_messages kept the WebSocket alive (responding to pings), so the client never detected that the forwarding path was broken. Replace asyncio.gather with asyncio.wait (FIRST_COMPLETED) so both coroutines are cancelled together. 2. Polling was stopped on WebSocket connect, leaving no fallback if forwarding broke. Polling now always runs alongside the WebSocket as a safety net. 3. Redis publish failures in task_progress_publisher were logged at DEBUG, and the broken client was reused forever. Log at WARNING and reset the client so the next call reconnects.
53 lines
1.8 KiB
Python
53 lines
1.8 KiB
Python
"""Publishes task progress updates to Redis pub/sub channels.
|
|
|
|
Celery workers call publish_task_progress() alongside task.update_state() so
|
|
that the FastAPI WebSocket endpoint can forward real-time updates to connected
|
|
browsers without polling.
|
|
|
|
Channel naming: ``task_progress:{task_id}``
|
|
"""
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
import redis
|
|
|
|
logger = logging.getLogger(name=__name__)

# Module-level cache for the Redis client; stays ``None`` until
# ``_get_redis_client`` builds it on first use.
_redis_client: redis.Redis | None = None  # type: ignore[type-arg]
|
|
|
|
|
|
def _get_redis_client() -> redis.Redis:  # type: ignore[type-arg]
    """Return the cached Redis client, creating it on first use.

    The connection URL is read from the ``CELERY_BROKER_URL`` environment
    variable (default ``redis://redis:6379/0``), and the client is built with
    ``decode_responses=True``.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    url = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
    _redis_client = redis.Redis.from_url(url, decode_responses=True)
    return _redis_client
|
|
|
|
|
|
def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> None:
    """Publish a task progress update to Redis pub/sub.

    The message is a JSON object containing ``task_id``, ``status`` and every
    key of ``meta``, published on channel ``task_progress:{task_id}``.

    Args:
        task_id: Celery task ID.
        state: Celery state string (e.g. 'PROGRESS', 'SUCCESS').
        meta: Metadata dict (progress, phase, logs, counters, etc.). Its keys
            are unpacked after ``task_id``/``status``, so a ``task_id`` or
            ``status`` key in ``meta`` overrides those values.

    Failures are caught and logged at WARNING level so they never break the
    critical task execution path. If publishing fails, the cached Redis
    client is discarded so the next call reconnects. If serialising ``meta``
    fails, the client is left intact because the connection is not at fault.
    """
    # Required so the reset below rebinds the module-level cache. Without it,
    # ``_redis_client = None`` would only create a function-local name and the
    # broken client would keep being reused.
    global _redis_client

    try:
        payload = json.dumps({
            "task_id": task_id,
            "status": state,
            **meta,
        })
    except Exception:
        # Bad payload (e.g. non-JSON-serialisable values): nothing is wrong
        # with the connection, so keep the cached client.
        logger.warning(
            "Failed to serialise task progress for %s", task_id, exc_info=True
        )
        return

    try:
        _get_redis_client().publish(f"task_progress:{task_id}", payload)
    except Exception:
        logger.warning("Failed to publish task progress for %s", task_id, exc_info=True)
        # Reset client so next call creates a fresh connection.
        _redis_client = None
|