wrongmove/tasks/poi_tasks.py
Viktor Barzin 8559c4b461
Add real-time WebSocket task progress with multi-job drawer
Replace 5s HTTP polling with WebSocket-based real-time updates for task
progress. Celery workers publish progress to Redis pub/sub channels;
a FastAPI WebSocket endpoint subscribes and forwards to the browser.
Polling is kept as a 30s fallback when WebSocket is unavailable.

The task progress drawer now supports multiple concurrent jobs with a
tab bar for switching between scrape and POI distance tasks.

Backend:
- Add services/task_progress_publisher.py (Redis pub/sub bridge)
- Add api/ws_routes.py (WebSocket endpoint with JWT auth)
- Publish progress from listing_tasks and poi_tasks
- Publish REVOKED via pub/sub on cancel/clear to fix stuck UI

Frontend:
- Add useTaskWebSocket hook with reconnection and keepalive
- Add TaskState and WS message types
- TaskIndicator: WS-driven updates with polling fallback
- TaskProgressDrawer: multi-job tabs, POI phase timeline
- Guard against WS overwriting local cancel state
2026-02-09 21:31:45 +00:00

117 lines
3.6 KiB
Python

"""Celery tasks for POI distance calculation."""
import asyncio
import logging
from typing import Any
from celery import Task
from celery_app import app
from database import engine
from models.listing import ListingType
from repositories.listing_repository import ListingRepository
from repositories.poi_repository import POIRepository
from services.poi_distance_calculator import calculate_poi_distances
from services.task_progress_publisher import publish_task_progress
def _attach_default_handler(target: logging.Logger) -> None:
    """Give *target* a timestamped stream handler and set it to INFO level."""
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )
    target.addHandler(stream_handler)
    target.setLevel(logging.INFO)


logger = logging.getLogger(__name__)

# Only configure the "celery.task" logger when nothing has attached a handler
# to it yet, so an already-configured logger is left untouched and handlers are
# not stacked on top of existing ones.
celery_logger = logging.getLogger("celery.task")
if not celery_logger.handlers:
    _attach_default_handler(celery_logger)
def _report_progress(task: Task, meta: dict[str, Any]) -> None:
    """Record *meta* as the task's PROGRESS state and publish it via pub/sub.

    Keeps the Celery result-backend state (``update_state``) and the
    ``publish_task_progress`` channel in sync by sending the same payload to
    both.
    """
    task.update_state(state="PROGRESS", meta=meta)
    publish_task_progress(task.request.id, "PROGRESS", meta)


@app.task(
    bind=True,
    autoretry_for=(Exception,),
    max_retries=3,
    retry_backoff=True,
    retry_backoff_max=300,
)
def calculate_poi_distances_task(
    self: Task,
    poi_id: int,
    travel_modes: list[str],
    listing_type: str,
    listing_ids: list[int] | None = None,
) -> dict[str, Any]:
    """Background task to calculate distances from listings to a POI.

    Progress is reported through both ``update_state`` and
    ``publish_task_progress``. The payloads carry a ``phase``
    ("starting" / "computing" / "completed"), a ``progress`` fraction in
    [0, 1], and a human-readable ``message``.

    Args:
        poi_id: ID of the PointOfInterest.
        travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT).
        listing_type: "BUY" or "RENT".
        listing_ids: Optional subset of listing IDs.

    Returns:
        On success, a dict with ``phase="completed"``, ``progress=1``,
        ``distances_computed`` and ``message``. If the POI does not exist,
        a dict with ``error`` and ``distances_computed=0``. In that case a
        FAILURE event is published, but the task returns normally and is
        not retried.

    Raises:
        Exception: Any error raised while loading the POI or computing
            distances is re-raised so Celery's autoretry can handle it
            (up to 3 retries, exponential backoff capped at 300 s). On the
            final attempt, a FAILURE event is published first so live
            subscribers are not left on a stale PROGRESS message.
    """
    celery_logger.info(
        "Starting POI distance calculation: poi_id=%s, modes=%s, type=%s",
        poi_id,
        travel_modes,
        listing_type,
    )
    _report_progress(self, {
        "phase": "starting",
        "progress": 0,
        "message": "Starting distance calculation...",
    })

    def on_progress(completed: int, total: int, message: str) -> None:
        # Fraction done, rounded to 2 decimals. A zero total reports 0
        # instead of dividing by zero.
        progress = round(completed / total, 2) if total > 0 else 0
        _report_progress(self, {
            "phase": "computing",
            "progress": progress,
            "processed": completed,
            "total": total,
            "message": message,
        })

    try:
        listing_repo = ListingRepository(engine)
        poi_repo = POIRepository(engine)
        poi = poi_repo.get_poi_by_id(poi_id)
        if poi is None:
            celery_logger.error("POI %s not found", poi_id)
            error_result = {"error": f"POI {poi_id} not found", "distances_computed": 0}
            publish_task_progress(self.request.id, "FAILURE", error_result)
            return error_result

        lt = ListingType(listing_type)
        distances_computed = asyncio.run(
            calculate_poi_distances(
                listing_repo=listing_repo,
                poi_repo=poi_repo,
                poi=poi,
                travel_modes=travel_modes,
                listing_type=lt,
                listing_ids=listing_ids,
                on_progress=on_progress,
            )
        )
    except Exception as exc:
        celery_logger.exception(
            "POI distance calculation failed: poi_id=%s", poi_id
        )
        # Autoretry gives up once request.retries reaches max_retries, and
        # nothing else in this task publishes a terminal event. Notify
        # pub/sub subscribers now rather than leaving them on the last
        # PROGRESS message.
        if self.max_retries is not None and self.request.retries >= self.max_retries:
            publish_task_progress(self.request.id, "FAILURE", {
                "error": f"POI distance calculation failed: {exc}",
                "distances_computed": 0,
            })
        raise  # Let Celery's autoretry handle it

    celery_logger.info(
        "POI distance calculation complete: %s distances computed",
        distances_computed,
    )
    result = {
        "phase": "completed",
        "progress": 1,
        "distances_computed": distances_computed,
        "message": f"Computed {distances_computed} distances for POI '{poi.name}'",
    }
    publish_task_progress(self.request.id, "SUCCESS", result)
    return result