diff --git a/api/ws_routes.py b/api/ws_routes.py index c279858..2b2d44d 100644 --- a/api/ws_routes.py +++ b/api/ws_routes.py @@ -109,20 +109,18 @@ async def ws_task_progress(websocket: WebSocket) -> None: async def _forward_pubsub() -> None: """Read from Redis pub/sub and forward to the WebSocket.""" - while True: - message = await pubsub.get_message( - ignore_subscribe_messages=True, timeout=1.0 - ) - if message and message["type"] == "message": - try: - data = json.loads(message["data"]) - except (json.JSONDecodeError, ValueError): - logger.debug("Malformed pubsub message, skipping") - continue - try: - await websocket.send_json({"type": "task_update", **data}) - except Exception: - break + async for message in pubsub.listen(): + if message["type"] != "message": + continue + try: + data = json.loads(message["data"]) + except (json.JSONDecodeError, ValueError): + logger.debug("Malformed pubsub message, skipping") + continue + try: + await websocket.send_json({"type": "task_update", **data}) + except Exception: + break async def _handle_client_messages() -> None: """Read messages from the client (subscribe, ping).""" diff --git a/tasks/listing_tasks.py b/tasks/listing_tasks.py index 8258165..fb71821 100644 --- a/tasks/listing_tasks.py +++ b/tasks/listing_tasks.py @@ -68,6 +68,174 @@ class _PipelineState: processed_listings: list[Listing] = field(default_factory=list) +class ProgressReporter: + """Event-driven, throttled progress reporter. + + Call ``notify()`` whenever pipeline state changes (page fetched, item + processed, phase change). Publishes are throttled to at most one every + ``min_interval`` seconds to avoid flooding Redis / WebSocket. Use + ``force=True`` for phase changes, completion, and errors so they are + published immediately. + + ``run_background_flush`` is a long-running coroutine that publishes every + ``background_interval`` seconds during quiet periods (keeps ETA / elapsed + fields current even when no items are being processed). + """ + + def __init__( + self, + task: Task, + state: _PipelineState, + subqueries_total: int, + start_time: float, + min_interval: float = 0.25, + background_interval: float = 2.0, + ) -> None: + self._task = task + self._state = state + self._subqueries_total = subqueries_total + self._start_time = start_time + self._min_interval = min_interval + self._background_interval = background_interval + self._last_publish: float = 0.0 + self._last_log_progress: float = 0.0 + self._dirty = False + self._deferred: asyncio.TimerHandle | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._stopped = False + + # -- internal helpers --------------------------------------------------- + + def _build_meta(self) -> tuple[str, dict[str, Any]]: + state = self._state + total = state.ids_collected + done = state.processed_count + state.failed_count + + phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING + progress_ratio = round(done / total, 2) if total > 0 else 0.0 + + elapsed = time.time() - self._start_time + rate = done / elapsed if elapsed > 0 else 0 + remaining = (total - done) if total > 0 else 0 + eta = remaining / rate if rate > 0 else 0 + + label = ( + f"{'Processing' if state.fetching_done else 'Fetching & processing'}: " + f"{done}/{total}" + ) + meta: dict[str, Any] = { + "phase": phase, + "progress": progress_ratio, + "processed": done, + "total": total, + "subqueries_completed": state.completed_subqueries, + "subqueries_total": self._subqueries_total, + "ids_collected": state.ids_collected, + "pages_fetched": state.total_pages_fetched, + "fetching_done": state.fetching_done, + "details_fetched": state.details_fetched, + "images_downloaded": state.images_downloaded, + "ocr_completed": state.ocr_completed, + "failed": state.failed_count, + "elapsed_seconds": round(elapsed, 1), + "rate_per_second": round(rate, 2), + "eta_seconds": round(eta, 1), + } + return label, meta + + def _publish(self) -> None: + label, meta = self._build_meta() + + # Log milestones at every 10% progress + progress_ratio = meta["progress"] + if progress_ratio >= self._last_log_progress + 0.1 or meta["processed"] == 1: + done = meta["processed"] + total = meta["total"] + elapsed = meta["elapsed_seconds"] + rate = meta["rate_per_second"] + eta = meta["eta_seconds"] + celery_logger.info( + f"Progress: {progress_ratio * 100:.0f}% " + f"({done}/{total}) " + f"| Elapsed: {elapsed:.0f}s " + f"| Rate: {rate:.1f}/s " + f"| ETA: {eta:.0f}s" + ) + self._last_log_progress = progress_ratio + + _update_task_state(self._task, label, meta) + self._last_publish = time.time() + self._dirty = False + + def _cancel_deferred(self) -> None: + if self._deferred is not None: + self._deferred.cancel() + self._deferred = None + + def _deferred_callback(self) -> None: + """Called by asyncio.call_later when the throttle window expires.""" + self._deferred = None + if self._dirty and not self._stopped: + self._publish() + + # -- public API --------------------------------------------------------- + + def notify(self, force: bool = False) -> None: + """Signal that pipeline state has changed. + + If ``force`` is True the publish happens immediately (use for phase + changes and completion). Otherwise the publish is throttled. + """ + if self._stopped: + return + + self._dirty = True + now = time.time() + + if force or (now - self._last_publish) >= self._min_interval: + self._cancel_deferred() + self._publish() + elif self._deferred is None: + # Schedule a deferred flush so trailing updates aren't lost + loop = self._loop or asyncio.get_event_loop() + delay = self._min_interval - (now - self._last_publish) + self._deferred = loop.call_later(delay, self._deferred_callback) + + def flush(self) -> None: + """Force a final publish (call at pipeline end).""" + self._cancel_deferred() + self._publish() + + async def run_background_flush(self) -> None: + """Background keepalive: publish every ``background_interval`` seconds. + + This keeps ETA / elapsed fields current during quiet periods. + Exits when the pipeline is done (fetching complete and all items + processed). + """ + self._loop = asyncio.get_running_loop() + while True: + state = self._state + total = state.ids_collected + done = state.processed_count + state.failed_count + + if state.fetching_done and done >= total and total > 0: + break + if state.fetching_done and total == 0: + break + + await asyncio.sleep(self._background_interval) + + # Publish if nothing else has published recently + if not self._stopped: + now = time.time() + if (now - self._last_publish) >= self._background_interval: + self._publish() + + self._stopped = True + self._cancel_deferred() + + class TaskLogHandler(logging.Handler): """Captures log records into a deque for inclusion in task state updates.""" @@ -100,11 +268,13 @@ async def _fetch_subquery( existing_ids: set[int], queue: asyncio.Queue[int | None], state: _PipelineState, + reporter: ProgressReporter, ) -> None: """Fetch pages for a single subquery and enqueue new listing IDs.""" estimated = sq.estimated_results or 0 if estimated == 0: state.completed_subqueries += 1 + reporter.notify() return page_size = parameters.page_size @@ -133,6 +303,7 @@ async def _fetch_subquery( config=config, ) state.total_pages_fetched += 1 + reporter.notify() properties = result.get("properties", []) for prop in properties: @@ -166,12 +337,14 @@ async def _fetch_subquery( break state.completed_subqueries += 1 + reporter.notify() async def _process_worker( queue: asyncio.Queue[int | None], processor: ListingProcessor, state: _PipelineState, + reporter: ProgressReporter, ) -> None: """Consumer worker: pull listing IDs from the queue and process them.""" while True: @@ -195,73 +368,9 @@ async def _process_worker( state.processed_listings.append(listing) else: state.failed_count += 1 + reporter.notify() -async def _monitor_progress( - task: Task, - state: _PipelineState, - subqueries_total: int, - start_time: float, -) -> None: - """Periodically report pipeline progress via task state updates.""" - last_progress = 0.0 - - while True: - total = state.ids_collected - done = state.processed_count + state.failed_count - - if state.fetching_done and done >= total and total > 0: - break - if state.fetching_done and total == 0: - break - - phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING - - if total > 0: - progress_ratio = round(done / total, 2) - else: - progress_ratio = 0.0 - - elapsed = time.time() - start_time - rate = done / elapsed if elapsed > 0 else 0 - remaining = (total - done) if total > 0 else 0 - eta = remaining / rate if rate > 0 else 0 - - if progress_ratio >= last_progress + 0.1 or done == 1: - celery_logger.info( - f"Progress: {progress_ratio * 100:.0f}% " - f"({done}/{total}) " - f"| Elapsed: {elapsed:.0f}s " - f"| Rate: {rate:.1f}/s " - f"| ETA: {eta:.0f}s" - ) - last_progress = progress_ratio - - _update_task_state( - task, - f"{'Processing' if state.fetching_done else 'Fetching & processing'}: " - f"{done}/{total}", - { - "phase": phase, - "progress": progress_ratio, - "processed": done, - "total": total, - "subqueries_completed": state.completed_subqueries, - "subqueries_total": subqueries_total, - "ids_collected": state.ids_collected, - "pages_fetched": state.total_pages_fetched, - "fetching_done": state.fetching_done, - "details_fetched": state.details_fetched, - "images_downloaded": state.images_downloaded, - "ocr_completed": state.ocr_completed, - "failed": state.failed_count, - "elapsed_seconds": round(elapsed, 1), - "rate_per_second": round(rate, 2), - "eta_seconds": round(eta, 1), - }, - ) - await asyncio.sleep(1) - @app.task(bind=True, pydantic=True) def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]: @@ -407,6 +516,10 @@ async def _dump_listings_full_inner( listing_processor = ListingProcessor(repository, parameters.listing_type) + reporter = ProgressReporter( + task, state, len(subqueries), start_time, + ) + # Producer: fetch all subqueries concurrently, then signal workers to stop async def producer() -> None: await asyncio.gather( @@ -414,6 +527,7 @@ async def _dump_listings_full_inner( _fetch_subquery( sq, parameters, session, config, semaphore, existing_ids, queue, state, + reporter, ) for sq in subqueries ] @@ -425,16 +539,19 @@ async def _dump_listings_full_inner( f"{state.ids_collected} new IDs" ) state.fetching_done = True + reporter.notify(force=True) for _ in range(NUM_WORKERS): await queue.put(None) await asyncio.gather( producer(), - *[_process_worker(queue, listing_processor, state) for _ in range(NUM_WORKERS)], - _monitor_progress(task, state, len(subqueries), start_time), + *[_process_worker(queue, listing_processor, state, reporter) for _ in range(NUM_WORKERS)], + reporter.run_background_flush(), ) + reporter.flush() + except CircuitBreakerOpenError as e: celery_logger.error(f"Circuit breaker prevented query: {e}") metrics = get_throttle_metrics()