Switch task progress to throttled event-driven updates

Replace timer-based _monitor_progress (1s sleep loop) with a
ProgressReporter class that publishes on actual state changes,
throttled to at most 1 publish per 250ms. A background flush
every 2s keeps ETA/elapsed current during quiet periods.

Switch WebSocket forwarder from get_message() polling (1s timeout)
to async pubsub.listen() for instant Redis-to-WebSocket delivery.

Combined latency improvement: ~1.5s average → ~250ms.
This commit is contained in:
Viktor Barzin 2026-02-10 21:24:33 +00:00
parent b816f695f0
commit 902f1b0852
No known key found for this signature in database
GPG key ID: 0EB088298288D958
2 changed files with 196 additions and 81 deletions

View file

@ -109,20 +109,18 @@ async def ws_task_progress(websocket: WebSocket) -> None:
async def _forward_pubsub() -> None:
"""Read from Redis pub/sub and forward to the WebSocket."""
while True:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if message and message["type"] == "message":
try:
data = json.loads(message["data"])
except (json.JSONDecodeError, ValueError):
logger.debug("Malformed pubsub message, skipping")
continue
try:
await websocket.send_json({"type": "task_update", **data})
except Exception:
break
async for message in pubsub.listen():
if message["type"] != "message":
continue
try:
data = json.loads(message["data"])
except (json.JSONDecodeError, ValueError):
logger.debug("Malformed pubsub message, skipping")
continue
try:
await websocket.send_json({"type": "task_update", **data})
except Exception:
break
async def _handle_client_messages() -> None:
"""Read messages from the client (subscribe, ping)."""

View file

@ -68,6 +68,174 @@ class _PipelineState:
processed_listings: list[Listing] = field(default_factory=list)
class ProgressReporter:
"""Event-driven, throttled progress reporter.
Call ``notify()`` whenever pipeline state changes (page fetched, item
processed, phase change). Publishes are throttled to at most one every
``min_interval`` seconds to avoid flooding Redis / WebSocket. Use
``force=True`` for phase changes, completion, and errors so they are
published immediately.
``run_background_flush`` is a long-running coroutine that publishes every
``background_interval`` seconds during quiet periods (keeps ETA / elapsed
fields current even when no items are being processed).
"""
def __init__(
self,
task: Task,
state: _PipelineState,
subqueries_total: int,
start_time: float,
min_interval: float = 0.25,
background_interval: float = 2.0,
) -> None:
self._task = task
self._state = state
self._subqueries_total = subqueries_total
self._start_time = start_time
self._min_interval = min_interval
self._background_interval = background_interval
self._last_publish: float = 0.0
self._last_log_progress: float = 0.0
self._dirty = False
self._deferred: asyncio.TimerHandle | None = None
self._loop: asyncio.AbstractEventLoop | None = None
self._stopped = False
# -- internal helpers ---------------------------------------------------
def _build_meta(self) -> tuple[str, dict[str, Any]]:
state = self._state
total = state.ids_collected
done = state.processed_count + state.failed_count
phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING
progress_ratio = round(done / total, 2) if total > 0 else 0.0
elapsed = time.time() - self._start_time
rate = done / elapsed if elapsed > 0 else 0
remaining = (total - done) if total > 0 else 0
eta = remaining / rate if rate > 0 else 0
label = (
f"{'Processing' if state.fetching_done else 'Fetching & processing'}: "
f"{done}/{total}"
)
meta: dict[str, Any] = {
"phase": phase,
"progress": progress_ratio,
"processed": done,
"total": total,
"subqueries_completed": state.completed_subqueries,
"subqueries_total": self._subqueries_total,
"ids_collected": state.ids_collected,
"pages_fetched": state.total_pages_fetched,
"fetching_done": state.fetching_done,
"details_fetched": state.details_fetched,
"images_downloaded": state.images_downloaded,
"ocr_completed": state.ocr_completed,
"failed": state.failed_count,
"elapsed_seconds": round(elapsed, 1),
"rate_per_second": round(rate, 2),
"eta_seconds": round(eta, 1),
}
return label, meta
def _publish(self) -> None:
label, meta = self._build_meta()
# Log milestones at every 10% progress
progress_ratio = meta["progress"]
if progress_ratio >= self._last_log_progress + 0.1 or meta["processed"] == 1:
done = meta["processed"]
total = meta["total"]
elapsed = meta["elapsed_seconds"]
rate = meta["rate_per_second"]
eta = meta["eta_seconds"]
celery_logger.info(
f"Progress: {progress_ratio * 100:.0f}% "
f"({done}/{total}) "
f"| Elapsed: {elapsed:.0f}s "
f"| Rate: {rate:.1f}/s "
f"| ETA: {eta:.0f}s"
)
self._last_log_progress = progress_ratio
_update_task_state(self._task, label, meta)
self._last_publish = time.time()
self._dirty = False
def _cancel_deferred(self) -> None:
if self._deferred is not None:
self._deferred.cancel()
self._deferred = None
def _deferred_callback(self) -> None:
"""Called by asyncio.call_later when the throttle window expires."""
self._deferred = None
if self._dirty and not self._stopped:
self._publish()
# -- public API ---------------------------------------------------------
def notify(self, force: bool = False) -> None:
"""Signal that pipeline state has changed.
If ``force`` is True the publish happens immediately (use for phase
changes and completion). Otherwise the publish is throttled.
"""
if self._stopped:
return
self._dirty = True
now = time.time()
if force or (now - self._last_publish) >= self._min_interval:
self._cancel_deferred()
self._publish()
elif self._deferred is None:
# Schedule a deferred flush so trailing updates aren't lost
loop = self._loop or asyncio.get_event_loop()
delay = self._min_interval - (now - self._last_publish)
self._deferred = loop.call_later(delay, self._deferred_callback)
def flush(self) -> None:
"""Force a final publish (call at pipeline end)."""
self._cancel_deferred()
self._publish()
async def run_background_flush(self) -> None:
"""Background keepalive: publish every ``background_interval`` seconds.
This keeps ETA / elapsed fields current during quiet periods.
Exits when the pipeline is done (fetching complete and all items
processed).
"""
self._loop = asyncio.get_running_loop()
while True:
state = self._state
total = state.ids_collected
done = state.processed_count + state.failed_count
if state.fetching_done and done >= total and total > 0:
break
if state.fetching_done and total == 0:
break
await asyncio.sleep(self._background_interval)
# Publish if nothing else has published recently
if not self._stopped:
now = time.time()
if (now - self._last_publish) >= self._background_interval:
self._publish()
self._stopped = True
self._cancel_deferred()
class TaskLogHandler(logging.Handler):
"""Captures log records into a deque for inclusion in task state updates."""
@ -100,11 +268,13 @@ async def _fetch_subquery(
existing_ids: set[int],
queue: asyncio.Queue[int | None],
state: _PipelineState,
reporter: ProgressReporter,
) -> None:
"""Fetch pages for a single subquery and enqueue new listing IDs."""
estimated = sq.estimated_results or 0
if estimated == 0:
state.completed_subqueries += 1
reporter.notify()
return
page_size = parameters.page_size
@ -133,6 +303,7 @@ async def _fetch_subquery(
config=config,
)
state.total_pages_fetched += 1
reporter.notify()
properties = result.get("properties", [])
for prop in properties:
@ -166,12 +337,14 @@ async def _fetch_subquery(
break
state.completed_subqueries += 1
reporter.notify()
async def _process_worker(
queue: asyncio.Queue[int | None],
processor: ListingProcessor,
state: _PipelineState,
reporter: ProgressReporter,
) -> None:
"""Consumer worker: pull listing IDs from the queue and process them."""
while True:
@ -195,73 +368,9 @@ async def _process_worker(
state.processed_listings.append(listing)
else:
state.failed_count += 1
reporter.notify()
async def _monitor_progress(
task: Task,
state: _PipelineState,
subqueries_total: int,
start_time: float,
) -> None:
"""Periodically report pipeline progress via task state updates."""
last_progress = 0.0
while True:
total = state.ids_collected
done = state.processed_count + state.failed_count
if state.fetching_done and done >= total and total > 0:
break
if state.fetching_done and total == 0:
break
phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING
if total > 0:
progress_ratio = round(done / total, 2)
else:
progress_ratio = 0.0
elapsed = time.time() - start_time
rate = done / elapsed if elapsed > 0 else 0
remaining = (total - done) if total > 0 else 0
eta = remaining / rate if rate > 0 else 0
if progress_ratio >= last_progress + 0.1 or done == 1:
celery_logger.info(
f"Progress: {progress_ratio * 100:.0f}% "
f"({done}/{total}) "
f"| Elapsed: {elapsed:.0f}s "
f"| Rate: {rate:.1f}/s "
f"| ETA: {eta:.0f}s"
)
last_progress = progress_ratio
_update_task_state(
task,
f"{'Processing' if state.fetching_done else 'Fetching & processing'}: "
f"{done}/{total}",
{
"phase": phase,
"progress": progress_ratio,
"processed": done,
"total": total,
"subqueries_completed": state.completed_subqueries,
"subqueries_total": subqueries_total,
"ids_collected": state.ids_collected,
"pages_fetched": state.total_pages_fetched,
"fetching_done": state.fetching_done,
"details_fetched": state.details_fetched,
"images_downloaded": state.images_downloaded,
"ocr_completed": state.ocr_completed,
"failed": state.failed_count,
"elapsed_seconds": round(elapsed, 1),
"rate_per_second": round(rate, 2),
"eta_seconds": round(eta, 1),
},
)
await asyncio.sleep(1)
@app.task(bind=True, pydantic=True)
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
@ -407,6 +516,10 @@ async def _dump_listings_full_inner(
listing_processor = ListingProcessor(repository, parameters.listing_type)
reporter = ProgressReporter(
task, state, len(subqueries), start_time,
)
# Producer: fetch all subqueries concurrently, then signal workers to stop
async def producer() -> None:
await asyncio.gather(
@ -414,6 +527,7 @@ async def _dump_listings_full_inner(
_fetch_subquery(
sq, parameters, session, config,
semaphore, existing_ids, queue, state,
reporter,
)
for sq in subqueries
]
@ -425,16 +539,19 @@ async def _dump_listings_full_inner(
f"{state.ids_collected} new IDs"
)
state.fetching_done = True
reporter.notify(force=True)
for _ in range(NUM_WORKERS):
await queue.put(None)
await asyncio.gather(
producer(),
*[_process_worker(queue, listing_processor, state) for _ in range(NUM_WORKERS)],
_monitor_progress(task, state, len(subqueries), start_time),
*[_process_worker(queue, listing_processor, state, reporter) for _ in range(NUM_WORKERS)],
reporter.run_background_flush(),
)
reporter.flush()
except CircuitBreakerOpenError as e:
celery_logger.error(f"Circuit breaker prevented query: {e}")
metrics = get_throttle_metrics()