wrongmove/tasks/listing_tasks.py
Viktor Barzin 8559c4b461
Add real-time WebSocket task progress with multi-job drawer
Replace 5s HTTP polling with WebSocket-based real-time updates for task
progress. Celery workers publish progress to Redis pub/sub channels;
a FastAPI WebSocket endpoint subscribes and forwards to the browser.
Polling is kept as a 30s fallback when WebSocket is unavailable.

The task progress drawer now supports multiple concurrent jobs with a
tab bar for switching between scrape and POI distance tasks.

Backend:
- Add services/task_progress_publisher.py (Redis pub/sub bridge)
- Add api/ws_routes.py (WebSocket endpoint with JWT auth)
- Publish progress from listing_tasks and poi_tasks
- Publish REVOKED via pub/sub on cancel/clear to fix stuck UI

Frontend:
- Add useTaskWebSocket hook with reconnection and keepalive
- Add TaskState and WS message types
- TaskIndicator: WS-driven updates with polling fallback
- TaskProgressDrawer: multi-job tabs, POI phase timeline
- Guard against WS overwriting local cancel state
2026-02-09 21:31:45 +00:00

493 lines
18 KiB
Python

import asyncio
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any
from celery import Task
from celery.schedules import crontab
from celery_app import app
from config.schedule_config import SchedulesConfig
from config.scraper_config import ScraperConfig
from listing_processor import ListingProcessor
from models.listing import Listing, QueryParameters
from rec.query import create_session, listing_query
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
from repositories.listing_repository import ListingRepository
from database import engine
from services.query_splitter import QuerySplitter, SubQuery
from utils.redis_lock import redis_lock
from services.listing_cache import invalidate_cache
from services.task_progress_publisher import publish_task_progress
logger = logging.getLogger("uvicorn.error")

# Dedicated logger for task output. A stream handler is attached only when
# the logger has none yet, so it always has somewhere to write.
# NOTE: logging.StreamHandler() without an argument writes to sys.stderr.
celery_logger = logging.getLogger("celery.task")
if not celery_logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )
    celery_logger.addHandler(handler)
    celery_logger.setLevel(logging.INFO)
# Name passed to redis_lock() so that only one scrape runs at a time.
SCRAPE_LOCK_NAME = "scrape_listings"

# Maximum number of captured log lines kept for a running task.
LOG_BUFFER_MAX_LINES = 200

# How many consumer coroutines pull listing IDs off the queue concurrently.
NUM_WORKERS = 20

# Phase labels placed in the "phase" key of task state metadata.
PHASE_SPLITTING = "splitting"
PHASE_FETCHING = "fetching"
PHASE_PROCESSING = "processing"
PHASE_COMPLETED = "completed"

# Log buffer of the task currently executing in this process; None when no
# task is capturing logs. TaskLogHandler appends to it and _update_task_state
# reads from it. Keeping this as module-level mutable state relies on Celery's
# prefork pool running one task at a time per worker process, so the buffer
# is never accessed concurrently within a process.
_active_log_buffer: deque[str] | None = None
@dataclass
class _PipelineState:
    """Counters and results shared by the fetchers, workers and monitor."""

    # Fetch side: updated by _fetch_subquery and the producer.
    ids_collected: int = 0
    completed_subqueries: int = 0
    total_pages_fetched: int = 0
    fetching_done: bool = False

    # Processing side: updated by _process_worker.
    processed_count: int = 0
    failed_count: int = 0

    # Per-step counters, incremented from the processor's step callback.
    details_fetched: int = 0
    images_downloaded: int = 0
    ocr_completed: int = 0

    # Listings that were processed successfully, in completion order.
    processed_listings: list[Listing] = field(default_factory=list)
class TaskLogHandler(logging.Handler):
    """Logging handler that appends formatted records to a deque.

    The deque is supplied by the caller, which can then include its contents
    in task state updates.
    """

    def __init__(self, buffer: deque[str]) -> None:
        """Store *buffer* as the destination for formatted records."""
        super().__init__()
        self.buffer = buffer

    def emit(self, record: logging.LogRecord) -> None:
        """Format *record* and append it to the buffer.

        A formatting failure never propagates to the code that logged the
        message. It is passed to ``handleError`` (the standard
        ``logging.Handler`` convention) instead of being dropped silently.
        """
        try:
            self.buffer.append(self.format(record))
        except Exception:
            self.handleError(record)
def _update_task_state(task: Task, state: str, meta: dict[str, Any]) -> None:
    """Record a task state update and publish it as task progress.

    When a log buffer is active, a snapshot of it is stored under
    ``meta["logs"]`` (the caller's dict is modified in place). The state is
    then written via ``task.update_state`` and, if the task has a request
    with a truthy id, passed to ``publish_task_progress`` as well.
    """
    captured = _active_log_buffer
    if captured is not None:
        meta["logs"] = list(captured)

    task.update_state(state=state, meta=meta)

    # Nothing to publish when the task has no request or no id.
    if not hasattr(task, "request"):
        return
    task_id = task.request.id
    if task_id:
        publish_task_progress(task_id, state, meta)
async def _fetch_subquery(
    sq: SubQuery,
    parameters: QueryParameters,
    session: object,
    config: ScraperConfig,
    semaphore: asyncio.Semaphore,
    existing_ids: set[int],
    queue: asyncio.Queue[int | None],
    state: _PipelineState,
) -> None:
    """Page through one subquery and enqueue listing IDs not seen before.

    Each page request is made while holding *semaphore*, after sleeping for
    ``config.request_delay_ms``. Identifiers not already in *existing_ids*
    are added to it, counted in *state* and put on *queue*. Paging stops at
    the first short page or at the first error of any kind; errors are
    logged, never raised. ``state.completed_subqueries`` is incremented
    exactly once, whatever the outcome.
    """
    expected_total = sq.estimated_results or 0
    if expected_total == 0:
        # Nothing to fetch, but the subquery still counts as completed.
        state.completed_subqueries += 1
        return

    page_size = parameters.page_size
    last_page = min(
        config.max_pages_per_query,
        (expected_total // page_size) + 1,
    )

    for page_id in range(1, last_page + 1):
        async with semaphore:
            await asyncio.sleep(config.request_delay_ms / 1000)
            try:
                result = await listing_query(
                    page=page_id,
                    channel=parameters.listing_type,
                    min_bedrooms=sq.min_bedrooms,
                    max_bedrooms=sq.max_bedrooms,
                    radius=parameters.radius,
                    min_price=sq.min_price,
                    max_price=sq.max_price,
                    district=sq.district,
                    page_size=page_size,
                    max_days_since_added=parameters.max_days_since_added,
                    furnish_types=parameters.furnish_types or [],
                    session=session,
                    config=config,
                )
                state.total_pages_fetched += 1

                properties = result.get("properties", [])
                for prop in properties:
                    identifier = prop.get("identifier")
                    if not identifier or identifier in existing_ids:
                        continue
                    # Mark as seen right away so other subqueries skip it.
                    existing_ids.add(identifier)
                    state.ids_collected += 1
                    await queue.put(identifier)

                # A short page means there are no further pages.
                if len(properties) < page_size:
                    break
            except CircuitBreakerOpenError as exc:
                celery_logger.error(f"Circuit breaker open: {exc}")
                break
            except ThrottlingError as exc:
                celery_logger.warning(
                    f"Throttling on {sq.district} page {page_id}: {exc}"
                )
                break
            except Exception as exc:
                # GENERIC_ERROR is treated as "paged past the end" and only
                # logged at debug level; anything else is a warning.
                if "GENERIC_ERROR" in str(exc):
                    celery_logger.debug(
                        f"Max page for {sq.district}: {page_id - 1}"
                    )
                else:
                    celery_logger.warning(
                        f"Error fetching page {page_id} for "
                        f"{sq.district}: {exc}"
                    )
                break

    state.completed_subqueries += 1
async def _process_worker(
    queue: asyncio.Queue[int | None],
    processor: ListingProcessor,
    state: _PipelineState,
) -> None:
    """Consume listing IDs from *queue* until a ``None`` sentinel arrives.

    Every ID is handed to ``processor.process_listing``. A non-``None``
    result is appended to ``state.processed_listings`` and counted as
    processed; a ``None`` result is counted as failed. The step callback
    passed to the processor bumps the per-step counter for "details",
    "images" and "ocr"; other step names are ignored.
    """
    counter_for_step = {
        "details": "details_fetched",
        "images": "images_downloaded",
        "ocr": "ocr_completed",
    }

    def step_callback(step_name: str) -> None:
        counter = counter_for_step.get(step_name)
        if counter is not None:
            setattr(state, counter, getattr(state, counter) + 1)

    while (listing_id := await queue.get()) is not None:
        listing = await processor.process_listing(
            listing_id, on_step_complete=step_callback
        )
        if listing is None:
            state.failed_count += 1
            continue
        state.processed_count += 1
        state.processed_listings.append(listing)
async def _monitor_progress(
    task: Task,
    state: _PipelineState,
    subqueries_total: int,
    start_time: float,
) -> None:
    """Report pipeline progress once per second until the pipeline is done.

    Every tick publishes a task state update with counters read from
    *state*. The loop ends once fetching is done and either no IDs were
    collected or every collected ID has been processed or has failed.

    A "Progress" log line is written when the first listing finishes and
    then each time progress has advanced by at least 10 percentage points
    since the last logged value. Percentages are compared as integers so an
    exact 10-point step is never missed through float rounding, and the
    first-item line is written only once.

    Args:
        task: Task whose state is updated.
        state: Shared pipeline counters (read only here).
        subqueries_total: Number of subqueries, reported as-is in the meta.
        start_time: ``time.time()`` value used as the origin for elapsed
            time, rate and ETA.
    """
    last_logged_percent = 0
    first_item_logged = False
    while True:
        total = state.ids_collected
        done = state.processed_count + state.failed_count
        if state.fetching_done and done >= total and total > 0:
            break
        if state.fetching_done and total == 0:
            break

        phase = PHASE_PROCESSING if state.fetching_done else PHASE_FETCHING
        if total > 0:
            progress_ratio = round(done / total, 2)
        else:
            progress_ratio = 0.0

        elapsed = time.time() - start_time
        rate = done / elapsed if elapsed > 0 else 0
        remaining = (total - done) if total > 0 else 0
        eta = remaining / rate if rate > 0 else 0

        percent = round(progress_ratio * 100)
        # The total keeps growing while fetching, so the ratio can drop.
        # Re-baseline so a stale high-water mark cannot suppress later logs.
        if percent < last_logged_percent:
            last_logged_percent = percent

        first_item = done == 1 and not first_item_logged
        if percent >= last_logged_percent + 10 or first_item:
            celery_logger.info(
                f"Progress: {progress_ratio * 100:.0f}% "
                f"({done}/{total}) "
                f"| Elapsed: {elapsed:.0f}s "
                f"| Rate: {rate:.1f}/s "
                f"| ETA: {eta:.0f}s"
            )
            last_logged_percent = percent
            if done == 1:
                first_item_logged = True

        _update_task_state(
            task,
            f"{'Processing' if state.fetching_done else 'Fetching & processing'}: "
            f"{done}/{total}",
            {
                "phase": phase,
                "progress": progress_ratio,
                "processed": done,
                "total": total,
                "subqueries_completed": state.completed_subqueries,
                "subqueries_total": subqueries_total,
                "ids_collected": state.ids_collected,
                "pages_fetched": state.total_pages_fetched,
                "fetching_done": state.fetching_done,
                "details_fetched": state.details_fetched,
                "images_downloaded": state.images_downloaded,
                "ocr_completed": state.ocr_completed,
                "failed": state.failed_count,
                "elapsed_seconds": round(elapsed, 1),
                "rate_per_second": round(rate, 2),
                "eta_seconds": round(eta, 1),
            },
        )
        await asyncio.sleep(1)
@app.task(bind=True, pydantic=True)
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
    """Celery entry point: run a full scrape while holding the scrape lock.

    If the lock cannot be acquired, a "SKIPPED" state is recorded and
    published and a skipped-status dict is returned. Otherwise the JSON
    parameters are validated, a "Starting..." state is recorded and
    published, the scrape runs to completion via ``asyncio.run``, and the
    completed result is published as "SUCCESS" and returned.
    """
    with redis_lock(SCRAPE_LOCK_NAME) as acquired:
        if not acquired:
            celery_logger.warning(
                "Another scrape job is already running, skipping this execution"
            )
            skip_meta = {"reason": "Another scrape job is running"}
            self.update_state(state="SKIPPED", meta=skip_meta)
            publish_task_progress(self.request.id, "SKIPPED", skip_meta)
            return {"status": "skipped", "reason": "another_job_running"}

        celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
        parsed_parameters = QueryParameters.model_validate_json(parameters_json)
        celery_logger.info(f"Starting scrape with parameters: {parsed_parameters}")

        # Separate dicts for the backend and the publisher, as before.
        start_meta = {"phase": PHASE_SPLITTING, "progress": 0}
        self.update_state(state="Starting...", meta=start_meta)
        publish_task_progress(self.request.id, "Starting...", dict(start_meta))

        asyncio.run(dump_listings_full(task=self, parameters=parsed_parameters))

        result = {"phase": PHASE_COMPLETED, "progress": 1}
        publish_task_progress(self.request.id, "SUCCESS", result)
        return result
async def async_dump_listings_task(parameters_json: str) -> dict[str, Any]:
    """Run a full scrape from async code, guarded by the scrape lock.

    Returns a skipped-status dict when the lock is already held. Otherwise
    validates the JSON parameters, awaits the scrape with a fresh ``Task()``
    instance as the task, and returns ``{"progress": 0}``.
    """
    with redis_lock(SCRAPE_LOCK_NAME) as acquired:
        if acquired:
            celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
            query = QueryParameters.model_validate_json(parameters_json)
            await dump_listings_full(task=Task(), parameters=query)
            return {"progress": 0}

        celery_logger.warning(
            "Another scrape job is already running, skipping this execution"
        )
        return {"status": "skipped", "reason": "another_job_running"}
async def dump_listings_full(
    *, task: Task, parameters: QueryParameters
) -> list[Listing]:
    """Run the scrape pipeline with log capture enabled around it.

    A bounded deque is installed as the module-level active log buffer and a
    TaskLogHandler feeding it is attached to both module loggers, so that
    _update_task_state can attach recent log lines to each state update.
    The handler, the logger levels and the active buffer are restored when
    the pipeline finishes or raises.

    Returns whatever ``_dump_listings_full_inner`` returns.
    """
    global _active_log_buffer

    captured: deque[str] = deque(maxlen=LOG_BUFFER_MAX_LINES)
    capture_handler = TaskLogHandler(captured)
    capture_handler.setFormatter(
        logging.Formatter("%(asctime)s %(message)s", datefmt="%H:%M:%S")
    )

    # Both loggers used in this module get the handler. Their levels are
    # remembered first so they can be put back afterwards.
    watched = (celery_logger, logger)
    saved_levels = [lg.level for lg in watched]
    for lg in watched:
        lg.addHandler(capture_handler)
    for lg in watched:
        # Make sure INFO records get through (Celery's worker setup may
        # leave the celery.task logger at WARNING).
        if lg.level == logging.NOTSET or lg.level > logging.INFO:
            lg.setLevel(logging.INFO)

    _active_log_buffer = captured
    try:
        return await _dump_listings_full_inner(task=task, parameters=parameters)
    finally:
        _active_log_buffer = None
        for lg in watched:
            lg.removeHandler(capture_handler)
        for lg, level in zip(watched, saved_levels):
            lg.setLevel(level)
async def _dump_listings_full_inner(
    *, task: Task, parameters: QueryParameters
) -> list[Listing]:
    """Split the query, then fetch and process listings as a stream.

    An asyncio.Queue connects the producer (one fetch coroutine per
    subquery) to NUM_WORKERS processing coroutines, so processing starts as
    soon as the first IDs arrive. A monitor coroutine reports progress in
    parallel. When all fetchers have finished, the producer marks fetching
    as done and enqueues one ``None`` sentinel per worker.

    Returns the successfully processed listings, or an empty list when a
    CircuitBreakerOpenError escapes the pipeline. In that case the cache
    invalidation and the final "Completed" update are skipped.
    """
    start_time = time.time()
    state = _PipelineState()
    rule = "=" * 60

    def log_banner(title: str) -> None:
        # Title framed by two rule lines.
        celery_logger.info(rule)
        celery_logger.info(title)
        celery_logger.info(rule)

    log_banner("PHASE 1: Splitting queries")

    repository = ListingRepository(engine)
    config = ScraperConfig.from_env()
    splitter = QuerySplitter(config)
    reset_throttle_metrics()

    def on_progress(phase: str, message: str, **kwargs: Any) -> None:
        # Callback handed to the splitter: publish, then log.
        meta: dict[str, Any] = {"phase": phase, "message": message, **kwargs}
        _update_task_state(task, message, meta)
        celery_logger.info(f"[{phase}] {message}")

    _update_task_state(
        task,
        "Analyzing query and splitting by price bands...",
        {"phase": PHASE_SPLITTING, "progress": 0},
    )
    celery_logger.info("Starting query splitting and probing...")

    try:
        async with create_session(config) as session:
            subqueries = await splitter.split(parameters, session, on_progress)
            subquery_count = len(subqueries)
            total_estimated = splitter.calculate_total_estimated_results(subqueries)
            celery_logger.info(
                f"Query split complete: {subquery_count} subqueries, "
                f"~{total_estimated} estimated total results"
            )

            celery_logger.info("Loading existing listing IDs from database...")
            existing_ids = repository.get_listing_ids(parameters.listing_type)
            celery_logger.info(f"Found {len(existing_ids)} existing listings in DB")

            log_banner("PHASE 2: Streaming fetch & process")

            queue: asyncio.Queue[int | None] = asyncio.Queue()
            semaphore = asyncio.Semaphore(config.max_concurrent_requests)

            _update_task_state(
                task,
                f"Fetching listings from {subquery_count} subqueries...",
                {
                    "phase": PHASE_FETCHING,
                    "subqueries_completed": 0,
                    "subqueries_total": subquery_count,
                    "ids_collected": 0,
                    "pages_fetched": 0,
                    "estimated_results": total_estimated,
                    "fetching_done": False,
                },
            )

            listing_processor = ListingProcessor(repository, parameters.listing_type)

            async def producer() -> None:
                # Run every subquery fetch concurrently, then tell each
                # worker to stop by enqueueing one sentinel per worker.
                fetchers = [
                    _fetch_subquery(
                        sq, parameters, session, config,
                        semaphore, existing_ids, queue, state,
                    )
                    for sq in subqueries
                ]
                await asyncio.gather(*fetchers)
                celery_logger.info(
                    f"Fetch complete: {state.total_pages_fetched} pages from "
                    f"{state.completed_subqueries} subqueries, "
                    f"{state.ids_collected} new IDs"
                )
                state.fetching_done = True
                for _ in range(NUM_WORKERS):
                    await queue.put(None)

            workers = [
                _process_worker(queue, listing_processor, state)
                for _ in range(NUM_WORKERS)
            ]
            await asyncio.gather(
                producer(),
                *workers,
                _monitor_progress(task, state, subquery_count, start_time),
            )
    except CircuitBreakerOpenError as exc:
        celery_logger.error(f"Circuit breaker prevented query: {exc}")
        metrics = get_throttle_metrics()
        if metrics.total_requests > 0:
            celery_logger.info(metrics.summary())
        return []
    finally:
        # API statistics are logged on success and on failure alike.
        metrics = get_throttle_metrics()
        if metrics.total_requests > 0:
            celery_logger.info(
                f"API Stats: {metrics.total_requests} requests, "
                f"avg {metrics.average_response_time:.2f}s, "
                f"{metrics.total_throttling_events} throttled"
            )

    elapsed = time.time() - start_time
    processed_total = len(state.processed_listings)
    log_banner(
        f"COMPLETED: Processed {processed_total} listings in {elapsed:.1f}s"
    )

    invalidate_cache()

    _update_task_state(
        task,
        "Completed",
        {
            "phase": PHASE_COMPLETED,
            "progress": 1,
            "processed": processed_total,
            "total": state.ids_collected,
            "message": f"Processed {processed_total} listings in {elapsed:.0f}s",
        },
    )
    return state.processed_listings
@app.on_after_finalize.connect
def setup_periodic_tasks(sender, **kwargs):
    """Register one periodic scrape task per enabled schedule.

    Schedules are loaded with ``SchedulesConfig.from_env()``. If that raises
    ValueError the error is logged and nothing is registered. Each enabled
    schedule becomes a crontab entry that invokes ``dump_listings_task``
    with the schedule's query parameters serialized as JSON.
    """
    try:
        config = SchedulesConfig.from_env()
    except ValueError as exc:
        celery_logger.error(f"Failed to load schedule configuration: {exc}")
        return

    for schedule in config.get_enabled_schedules():
        celery_logger.info(
            f"Registering periodic task: {schedule.name} at {schedule.hour}:{schedule.minute}"
        )
        when = crontab(
            minute=schedule.minute,
            hour=schedule.hour,
            day_of_week=schedule.day_of_week,
        )
        parameters_json = schedule.to_query_parameters().model_dump_json()
        sender.add_periodic_task(
            when,
            dump_listings_task.s(parameters_json),
            name=schedule.name,
        )