"""Celery tasks for POI distance calculation."""

import asyncio
import logging
from typing import Any

from celery import Task

from celery_app import app
from database import engine
from models.listing import ListingType
from repositories.listing_repository import ListingRepository
from repositories.poi_repository import POIRepository
from services.poi_distance_calculator import calculate_poi_distances

logger = logging.getLogger(__name__)

# Attach a stream handler to the "celery.task" logger only if nothing has
# configured one yet, so task log lines are emitted at INFO level.
celery_logger = logging.getLogger("celery.task")
if not celery_logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    ))
    celery_logger.addHandler(handler)
    celery_logger.setLevel(logging.INFO)


@app.task(
    bind=True,
    autoretry_for=(Exception,),
    max_retries=3,
    retry_backoff=True,
    retry_backoff_max=300,
)
def calculate_poi_distances_task(
    self: Task,
    poi_id: int,
    travel_modes: list[str],
    listing_type: str,
    listing_ids: list[int] | None = None,
) -> dict[str, Any]:
    """Background task to calculate distances from listings to a POI.

    Progress is published via ``self.update_state`` with state
    ``"PROGRESS"``; the ``meta`` dict carries ``phase``, ``progress``
    (a fraction between 0 and 1) and a human-readable ``message``
    (plus ``processed``/``total`` while computing).

    Args:
        poi_id: ID of the PointOfInterest.
        travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT).
        listing_type: "BUY" or "RENT".
        listing_ids: Optional subset of listing IDs.

    Returns:
        On success: ``{"phase": "completed", "progress": 1,
        "distances_computed": <int>, "message": <str>}``.
        For inputs that cannot be processed (an unknown ``listing_type``
        or a POI that does not exist): ``{"error": <str>,
        "distances_computed": 0}``. These are returned rather than raised
        so the blanket ``autoretry_for=(Exception,)`` does not retry them.

    Raises:
        Exception: Any error from ``calculate_poi_distances`` is logged
            and re-raised so Celery's autoretry (``max_retries=3`` with
            exponential backoff capped at 300s) can retry the task.
    """
    celery_logger.info(
        "Starting POI distance calculation: poi_id=%s, modes=%s, type=%s",
        poi_id, travel_modes, listing_type,
    )

    # Validate before touching the database. Constructing ListingType from
    # an unknown value raises ValueError (assuming ListingType is an Enum —
    # confirm in models.listing); if it escaped, autoretry_for=(Exception,)
    # would retry a request that can never succeed.
    try:
        lt = ListingType(listing_type)
    except ValueError:
        celery_logger.error(
            "Invalid listing type %r for POI %s", listing_type, poi_id
        )
        return {
            "error": f"Invalid listing type {listing_type!r}",
            "distances_computed": 0,
        }

    self.update_state(state="PROGRESS", meta={
        "phase": "starting",
        "progress": 0,
        "message": "Starting distance calculation...",
    })

    listing_repo = ListingRepository(engine)
    poi_repo = POIRepository(engine)

    poi = poi_repo.get_poi_by_id(poi_id)
    if poi is None:
        # Returned (not raised) so the missing POI is not auto-retried.
        celery_logger.error("POI %s not found", poi_id)
        return {"error": f"POI {poi_id} not found", "distances_computed": 0}

    def on_progress(completed: int, total: int, message: str) -> None:
        """Publish computation progress as Celery task metadata."""
        # Avoid division by zero when there is nothing to compute.
        fraction = round(completed / total, 2) if total > 0 else 0
        self.update_state(state="PROGRESS", meta={
            "phase": "computing",
            "progress": fraction,
            "processed": completed,
            "total": total,
            "message": message,
        })

    try:
        # calculate_poi_distances returns a coroutine; asyncio.run drives it
        # to completion on a new event loop from this synchronous task.
        total = asyncio.run(
            calculate_poi_distances(
                listing_repo=listing_repo,
                poi_repo=poi_repo,
                poi=poi,
                travel_modes=travel_modes,
                listing_type=lt,
                listing_ids=listing_ids,
                on_progress=on_progress,
            )
        )
    except Exception:
        celery_logger.exception(
            "POI distance calculation failed: poi_id=%s", poi_id
        )
        raise  # Let Celery's autoretry handle it

    celery_logger.info(
        "POI distance calculation complete: %s distances computed", total
    )
    return {
        "phase": "completed",
        "progress": 1,
        "distances_computed": total,
        "message": f"Computed {total} distances for POI '{poi.name}'",
    }