Switch task progress to throttled event-driven updates
Replace timer-based _monitor_progress (1s sleep loop) with a ProgressReporter class that publishes on actual state changes, throttled to at most 1 publish per 250ms. A background flush every 2s keeps ETA/elapsed current during quiet periods. Switch WebSocket forwarder from get_message() polling (1s timeout) to async pubsub.listen() for instant Redis-to-WebSocket delivery. Combined latency improvement: ~1.5s average → ~250ms.
This commit is contained in:
parent
b816f695f0
commit
902f1b0852
2 changed files with 196 additions and 81 deletions
|
|
@ -109,20 +109,18 @@ async def ws_task_progress(websocket: WebSocket) -> None:
|
|||
|
||||
async def _forward_pubsub() -> None:
|
||||
"""Read from Redis pub/sub and forward to the WebSocket."""
|
||||
while True:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
if message and message["type"] == "message":
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed pubsub message, skipping")
|
||||
continue
|
||||
try:
|
||||
await websocket.send_json({"type": "task_update", **data})
|
||||
except Exception:
|
||||
break
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(message["data"])
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed pubsub message, skipping")
|
||||
continue
|
||||
try:
|
||||
await websocket.send_json({"type": "task_update", **data})
|
||||
except Exception:
|
||||
break
|
||||
|
||||
async def _handle_client_messages() -> None:
|
||||
"""Read messages from the client (subscribe, ping)."""
|
||||
|
|
|
|||
|
|
@ -68,6 +68,174 @@ class _PipelineState:
|
|||
processed_listings: list[Listing] = field(default_factory=list)
|
||||
|
||||
|
||||
class ProgressReporter:
|
||||
"""Event-driven, throttled progress reporter.
|
||||
|
||||
Call ``notify()`` whenever pipeline state changes (page fetched, item
|
||||
processed, phase change). Publishes are throttled to at most one every
|
||||
``min_interval`` seconds to avoid flooding Redis / WebSocket. Use
|
||||
``force=True`` for phase changes, completion, and errors so they are
|
||||
published immediately.
|
||||
|
||||
``run_background_flush`` is a long-running coroutine that publishes every
|
||||
``background_interval`` seconds during quiet periods (keeps ETA / elapsed
|
||||
fields current even when no items are being processed).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: Task,
|
||||
state: _PipelineState,
|
||||
subqueries_total: int,
|
||||
start_time: float,
|
||||
min_interval: float = 0.25,
|
||||
background_interval: float = 2.0,
|
||||
) -> None:
|
||||
self._task = task
|
||||
self._state = state
|
||||
self._subqueries_total = subqueries_total
|
||||
self._start_time = start_time
|
||||
self._min_interval = min_interval
|
||||
self._background_interval = background_interval
|
||||
self._last_publish: float = 0.0
|
||||
self._last_log_progress: float = 0.0
|
||||
self._dirty = False
|
||||
self._deferred: asyncio.TimerHandle | None = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._stopped = False
|
||||
|
||||
# -- internal helpers ---------------------------------------------------
|
||||
|
||||
def _build_meta(self) -> tuple[str, dict[str, Any]]:
|
||||
state = self._state
|
||||
total = state.ids_collected
|
||||
done = state.processed_count + state.failed_count
|
||||
|
||||
phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING
|
||||
progress_ratio = round(done / total, 2) if total > 0 else 0.0
|
||||
|
||||
elapsed = time.time() - self._start_time
|
||||
rate = done / elapsed if elapsed > 0 else 0
|
||||
remaining = (total - done) if total > 0 else 0
|
||||
eta = remaining / rate if rate > 0 else 0
|
||||
|
||||
label = (
|
||||
f"{'Processing' if state.fetching_done else 'Fetching & processing'}: "
|
||||
f"{done}/{total}"
|
||||
)
|
||||
meta: dict[str, Any] = {
|
||||
"phase": phase,
|
||||
"progress": progress_ratio,
|
||||
"processed": done,
|
||||
"total": total,
|
||||
"subqueries_completed": state.completed_subqueries,
|
||||
"subqueries_total": self._subqueries_total,
|
||||
"ids_collected": state.ids_collected,
|
||||
"pages_fetched": state.total_pages_fetched,
|
||||
"fetching_done": state.fetching_done,
|
||||
"details_fetched": state.details_fetched,
|
||||
"images_downloaded": state.images_downloaded,
|
||||
"ocr_completed": state.ocr_completed,
|
||||
"failed": state.failed_count,
|
||||
"elapsed_seconds": round(elapsed, 1),
|
||||
"rate_per_second": round(rate, 2),
|
||||
"eta_seconds": round(eta, 1),
|
||||
}
|
||||
return label, meta
|
||||
|
||||
def _publish(self) -> None:
|
||||
label, meta = self._build_meta()
|
||||
|
||||
# Log milestones at every 10% progress
|
||||
progress_ratio = meta["progress"]
|
||||
if progress_ratio >= self._last_log_progress + 0.1 or meta["processed"] == 1:
|
||||
done = meta["processed"]
|
||||
total = meta["total"]
|
||||
elapsed = meta["elapsed_seconds"]
|
||||
rate = meta["rate_per_second"]
|
||||
eta = meta["eta_seconds"]
|
||||
celery_logger.info(
|
||||
f"Progress: {progress_ratio * 100:.0f}% "
|
||||
f"({done}/{total}) "
|
||||
f"| Elapsed: {elapsed:.0f}s "
|
||||
f"| Rate: {rate:.1f}/s "
|
||||
f"| ETA: {eta:.0f}s"
|
||||
)
|
||||
self._last_log_progress = progress_ratio
|
||||
|
||||
_update_task_state(self._task, label, meta)
|
||||
self._last_publish = time.time()
|
||||
self._dirty = False
|
||||
|
||||
def _cancel_deferred(self) -> None:
|
||||
if self._deferred is not None:
|
||||
self._deferred.cancel()
|
||||
self._deferred = None
|
||||
|
||||
def _deferred_callback(self) -> None:
|
||||
"""Called by asyncio.call_later when the throttle window expires."""
|
||||
self._deferred = None
|
||||
if self._dirty and not self._stopped:
|
||||
self._publish()
|
||||
|
||||
# -- public API ---------------------------------------------------------
|
||||
|
||||
def notify(self, force: bool = False) -> None:
|
||||
"""Signal that pipeline state has changed.
|
||||
|
||||
If ``force`` is True the publish happens immediately (use for phase
|
||||
changes and completion). Otherwise the publish is throttled.
|
||||
"""
|
||||
if self._stopped:
|
||||
return
|
||||
|
||||
self._dirty = True
|
||||
now = time.time()
|
||||
|
||||
if force or (now - self._last_publish) >= self._min_interval:
|
||||
self._cancel_deferred()
|
||||
self._publish()
|
||||
elif self._deferred is None:
|
||||
# Schedule a deferred flush so trailing updates aren't lost
|
||||
loop = self._loop or asyncio.get_event_loop()
|
||||
delay = self._min_interval - (now - self._last_publish)
|
||||
self._deferred = loop.call_later(delay, self._deferred_callback)
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Force a final publish (call at pipeline end)."""
|
||||
self._cancel_deferred()
|
||||
self._publish()
|
||||
|
||||
async def run_background_flush(self) -> None:
|
||||
"""Background keepalive: publish every ``background_interval`` seconds.
|
||||
|
||||
This keeps ETA / elapsed fields current during quiet periods.
|
||||
Exits when the pipeline is done (fetching complete and all items
|
||||
processed).
|
||||
"""
|
||||
self._loop = asyncio.get_running_loop()
|
||||
while True:
|
||||
state = self._state
|
||||
total = state.ids_collected
|
||||
done = state.processed_count + state.failed_count
|
||||
|
||||
if state.fetching_done and done >= total and total > 0:
|
||||
break
|
||||
if state.fetching_done and total == 0:
|
||||
break
|
||||
|
||||
await asyncio.sleep(self._background_interval)
|
||||
|
||||
# Publish if nothing else has published recently
|
||||
if not self._stopped:
|
||||
now = time.time()
|
||||
if (now - self._last_publish) >= self._background_interval:
|
||||
self._publish()
|
||||
|
||||
self._stopped = True
|
||||
self._cancel_deferred()
|
||||
|
||||
|
||||
class TaskLogHandler(logging.Handler):
|
||||
"""Captures log records into a deque for inclusion in task state updates."""
|
||||
|
||||
|
|
@ -100,11 +268,13 @@ async def _fetch_subquery(
|
|||
existing_ids: set[int],
|
||||
queue: asyncio.Queue[int | None],
|
||||
state: _PipelineState,
|
||||
reporter: ProgressReporter,
|
||||
) -> None:
|
||||
"""Fetch pages for a single subquery and enqueue new listing IDs."""
|
||||
estimated = sq.estimated_results or 0
|
||||
if estimated == 0:
|
||||
state.completed_subqueries += 1
|
||||
reporter.notify()
|
||||
return
|
||||
|
||||
page_size = parameters.page_size
|
||||
|
|
@ -133,6 +303,7 @@ async def _fetch_subquery(
|
|||
config=config,
|
||||
)
|
||||
state.total_pages_fetched += 1
|
||||
reporter.notify()
|
||||
|
||||
properties = result.get("properties", [])
|
||||
for prop in properties:
|
||||
|
|
@ -166,12 +337,14 @@ async def _fetch_subquery(
|
|||
break
|
||||
|
||||
state.completed_subqueries += 1
|
||||
reporter.notify()
|
||||
|
||||
|
||||
async def _process_worker(
|
||||
queue: asyncio.Queue[int | None],
|
||||
processor: ListingProcessor,
|
||||
state: _PipelineState,
|
||||
reporter: ProgressReporter,
|
||||
) -> None:
|
||||
"""Consumer worker: pull listing IDs from the queue and process them."""
|
||||
while True:
|
||||
|
|
@ -195,73 +368,9 @@ async def _process_worker(
|
|||
state.processed_listings.append(listing)
|
||||
else:
|
||||
state.failed_count += 1
|
||||
reporter.notify()
|
||||
|
||||
|
||||
async def _monitor_progress(
|
||||
task: Task,
|
||||
state: _PipelineState,
|
||||
subqueries_total: int,
|
||||
start_time: float,
|
||||
) -> None:
|
||||
"""Periodically report pipeline progress via task state updates."""
|
||||
last_progress = 0.0
|
||||
|
||||
while True:
|
||||
total = state.ids_collected
|
||||
done = state.processed_count + state.failed_count
|
||||
|
||||
if state.fetching_done and done >= total and total > 0:
|
||||
break
|
||||
if state.fetching_done and total == 0:
|
||||
break
|
||||
|
||||
phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING
|
||||
|
||||
if total > 0:
|
||||
progress_ratio = round(done / total, 2)
|
||||
else:
|
||||
progress_ratio = 0.0
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
rate = done / elapsed if elapsed > 0 else 0
|
||||
remaining = (total - done) if total > 0 else 0
|
||||
eta = remaining / rate if rate > 0 else 0
|
||||
|
||||
if progress_ratio >= last_progress + 0.1 or done == 1:
|
||||
celery_logger.info(
|
||||
f"Progress: {progress_ratio * 100:.0f}% "
|
||||
f"({done}/{total}) "
|
||||
f"| Elapsed: {elapsed:.0f}s "
|
||||
f"| Rate: {rate:.1f}/s "
|
||||
f"| ETA: {eta:.0f}s"
|
||||
)
|
||||
last_progress = progress_ratio
|
||||
|
||||
_update_task_state(
|
||||
task,
|
||||
f"{'Processing' if state.fetching_done else 'Fetching & processing'}: "
|
||||
f"{done}/{total}",
|
||||
{
|
||||
"phase": phase,
|
||||
"progress": progress_ratio,
|
||||
"processed": done,
|
||||
"total": total,
|
||||
"subqueries_completed": state.completed_subqueries,
|
||||
"subqueries_total": subqueries_total,
|
||||
"ids_collected": state.ids_collected,
|
||||
"pages_fetched": state.total_pages_fetched,
|
||||
"fetching_done": state.fetching_done,
|
||||
"details_fetched": state.details_fetched,
|
||||
"images_downloaded": state.images_downloaded,
|
||||
"ocr_completed": state.ocr_completed,
|
||||
"failed": state.failed_count,
|
||||
"elapsed_seconds": round(elapsed, 1),
|
||||
"rate_per_second": round(rate, 2),
|
||||
"eta_seconds": round(eta, 1),
|
||||
},
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@app.task(bind=True, pydantic=True)
|
||||
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
|
||||
|
|
@ -407,6 +516,10 @@ async def _dump_listings_full_inner(
|
|||
|
||||
listing_processor = ListingProcessor(repository, parameters.listing_type)
|
||||
|
||||
reporter = ProgressReporter(
|
||||
task, state, len(subqueries), start_time,
|
||||
)
|
||||
|
||||
# Producer: fetch all subqueries concurrently, then signal workers to stop
|
||||
async def producer() -> None:
|
||||
await asyncio.gather(
|
||||
|
|
@ -414,6 +527,7 @@ async def _dump_listings_full_inner(
|
|||
_fetch_subquery(
|
||||
sq, parameters, session, config,
|
||||
semaphore, existing_ids, queue, state,
|
||||
reporter,
|
||||
)
|
||||
for sq in subqueries
|
||||
]
|
||||
|
|
@ -425,16 +539,19 @@ async def _dump_listings_full_inner(
|
|||
f"{state.ids_collected} new IDs"
|
||||
)
|
||||
state.fetching_done = True
|
||||
reporter.notify(force=True)
|
||||
|
||||
for _ in range(NUM_WORKERS):
|
||||
await queue.put(None)
|
||||
|
||||
await asyncio.gather(
|
||||
producer(),
|
||||
*[_process_worker(queue, listing_processor, state) for _ in range(NUM_WORKERS)],
|
||||
_monitor_progress(task, state, len(subqueries), start_time),
|
||||
*[_process_worker(queue, listing_processor, state, reporter) for _ in range(NUM_WORKERS)],
|
||||
reporter.run_background_flush(),
|
||||
)
|
||||
|
||||
reporter.flush()
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
celery_logger.error(f"Circuit breaker prevented query: {e}")
|
||||
metrics = get_throttle_metrics()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue