wrongmove/services/task_progress_publisher.py

54 lines
1.8 KiB
Python
Raw Normal View History

"""Publishes task progress updates to Redis pub/sub channels.

Celery workers call publish_task_progress() alongside task.update_state() so
that the FastAPI WebSocket endpoint can forward real-time updates to connected
browsers without polling.

Channel naming: ``task_progress:{task_id}``. Each message is a JSON object
containing ``task_id``, ``status`` and the task's metadata fields.
"""
import json
import logging
import os
from typing import Any
import redis
# Shared Redis connection, created on first use by _get_redis_client() and
# discarded after a publish failure; ``None`` means "not connected yet".
_redis_client: redis.Redis | None = None  # type: ignore[type-arg]

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
def _get_redis_client() -> redis.Redis:  # type: ignore[type-arg]
    """Return the module-wide Redis client, building it on first use.

    The connection URL is read from ``CELERY_BROKER_URL`` (default
    ``redis://redis:6379/0``), so progress messages travel over the same
    Redis the Celery broker uses. Responses are decoded to ``str``.
    """
    global _redis_client
    # Fast path: reuse the client built by an earlier call.
    if _redis_client is not None:
        return _redis_client
    url = os.getenv("CELERY_BROKER_URL", "redis://redis:6379/0")
    _redis_client = redis.Redis.from_url(url, decode_responses=True)
    return _redis_client
def publish_task_progress(task_id: str, state: str, meta: dict[str, Any]) -> None:
"""Publish a task progress update to Redis pub/sub.
Args:
task_id: Celery task ID.
state: Celery state string (e.g. 'PROGRESS', 'SUCCESS').
meta: Metadata dict (progress, phase, logs, counters, etc.).
Failures are caught and logged at WARNING level so they never break the
critical task execution path. The Redis client is reset on failure so
subsequent calls can reconnect.
"""
try:
client = _get_redis_client()
payload = json.dumps({
"task_id": task_id,
"status": state,
**meta,
})
client.publish(f"task_progress:{task_id}", payload)
except Exception:
logger.warning("Failed to publish task progress for %s", task_id, exc_info=True)
# Reset client so next call creates a fresh connection
_redis_client = None