diff --git a/rec/osrm_client.py b/rec/osrm_client.py
index d8e9801..a89d6a7 100644
--- a/rec/osrm_client.py
+++ b/rec/osrm_client.py
@@ -10,7 +10,7 @@ import aiohttp
 
 from config.routing_config import RoutingConfig
 
-logger = logging.getLogger("uvicorn.error")
+logger = logging.getLogger(__name__)
 
 
 @dataclass(frozen=True)
diff --git a/rec/otp_client.py b/rec/otp_client.py
index 1b9c165..1655541 100644
--- a/rec/otp_client.py
+++ b/rec/otp_client.py
@@ -12,7 +12,7 @@ import aiohttp
 from config.routing_config import RoutingConfig
 from rec.utils import nextMonday
 
-logger = logging.getLogger("uvicorn.error")
+logger = logging.getLogger(__name__)
 
 # OTP 2.x GraphQL query for transit plan
 _PLAN_QUERY = """
diff --git a/repositories/poi_repository.py b/repositories/poi_repository.py
index 97187c2..f6e4d29 100644
--- a/repositories/poi_repository.py
+++ b/repositories/poi_repository.py
@@ -52,9 +52,48 @@ class POIRepository:
         return True
 
     def upsert_distances(self, distances: list[POIDistance]) -> None:
+        """Insert or update POI distances, handling duplicate unique constraints.
+
+        Rows that collide on an existing key only have ``duration_seconds``,
+        ``distance_meters`` and ``computed_at`` overwritten. All rows are
+        written in one session and committed together.
+        """
+        if not distances:
+            return
+        # Resolve the dialect-specific INSERT once, not once per row. Engines
+        # created for MariaDB may report "mariadb"; they use MySQL's upsert syntax.
+        use_mysql = self.engine.dialect.name in ("mysql", "mariadb")
+        if use_mysql:
+            from sqlalchemy.dialects.mysql import insert as dialect_insert
+        else:
+            from sqlalchemy.dialects.sqlite import insert as dialect_insert
         with Session(self.engine) as session:
             for dist in distances:
-                session.merge(dist)
+                stmt = dialect_insert(POIDistance).values(
+                    listing_id=dist.listing_id,
+                    listing_type=dist.listing_type,
+                    poi_id=dist.poi_id,
+                    travel_mode=dist.travel_mode,
+                    duration_seconds=dist.duration_seconds,
+                    distance_meters=dist.distance_meters,
+                    computed_at=dist.computed_at,
+                )
+                if use_mysql:
+                    stmt = stmt.on_duplicate_key_update(
+                        duration_seconds=stmt.inserted.duration_seconds,
+                        distance_meters=stmt.inserted.distance_meters,
+                        computed_at=stmt.inserted.computed_at,
+                    )
+                else:
+                    stmt = stmt.on_conflict_do_update(
+                        index_elements=["listing_id", "listing_type", "poi_id", "travel_mode"],
+                        set_={
+                            "duration_seconds": stmt.excluded.duration_seconds,
+                            "distance_meters": stmt.excluded.distance_meters,
+                            "computed_at": stmt.excluded.computed_at,
+                        },
+                    )
+                session.execute(stmt)
             session.commit()
 
     def get_distances_for_listings(
diff --git a/services/poi_distance_calculator.py b/services/poi_distance_calculator.py
index e38fe5a..90bc3ac 100644
--- a/services/poi_distance_calculator.py
+++ b/services/poi_distance_calculator.py
@@ -15,7 +15,7 @@ from rec.otp_client import otp_transit_route
 from repositories.listing_repository import ListingRepository
 from repositories.poi_repository import POIRepository
 
-logger = logging.getLogger("uvicorn.error")
+logger = logging.getLogger(__name__)
 
 # Map travel mode names to OSRM profiles
 _OSRM_PROFILES = {
@@ -61,8 +61,25 @@ async def calculate_poi_distances(
         logger.info("No listings found for distance calculation")
         return 0
 
+    # Drop listings whose latitude or longitude is None or 0
+    valid_listings = [
+        listing for listing in listings
+        if listing.latitude is not None and listing.longitude is not None
+        and listing.latitude != 0 and listing.longitude != 0
+    ]
+    if len(valid_listings) < len(listings):
+        logger.warning(
+            f"Skipped {len(listings) - len(valid_listings)} listings "
+            f"with missing coordinates"
+        )
+    listings = valid_listings
+    if not listings:
+        logger.warning("No listings with valid coordinates")
+        return 0
+
     total_computed = 0
     total_modes = len(travel_modes)
+    modes_failed: list[str] = []
 
     for mode_idx, mode in enumerate(travel_modes):
         mode_upper = mode.upper()
@@ -105,6 +122,7 @@ async def calculate_poi_distances(
                 continue
         except (aiohttp.ClientError, OSError) as e:
             logger.error(f"Routing engine unreachable for {mode_upper}: {e}")
+            modes_failed.append(mode_upper)
             if on_progress:
                 on_progress(
                     total_computed, len(listings) * total_modes,
@@ -114,6 +132,13 @@ async def calculate_poi_distances(
 
         total_computed += computed
 
+    if modes_failed and total_computed == 0:
+        failed_str = ", ".join(modes_failed)
+        raise OSError(
+            f"All routing engines failed ({failed_str}). "
+            f"No distances computed for {len(listings)} listings."
+        )
+
     return total_computed
 
 
diff --git a/tasks/poi_tasks.py b/tasks/poi_tasks.py
index b3416a4..b822c94 100644
--- a/tasks/poi_tasks.py
+++ b/tasks/poi_tasks.py
@@ -11,7 +11,7 @@ from repositories.listing_repository import ListingRepository
 from repositories.poi_repository import POIRepository
 from services.poi_distance_calculator import calculate_poi_distances
 
-logger = logging.getLogger("uvicorn.error")
+logger = logging.getLogger(__name__)
 celery_logger = logging.getLogger("celery.task")
 
 if not celery_logger.handlers:
@@ -23,7 +23,13 @@ if not celery_logger.handlers:
     celery_logger.setLevel(logging.INFO)
 
 
-@app.task(bind=True)
+@app.task(
+    bind=True,
+    autoretry_for=(Exception,),
+    max_retries=3,
+    retry_backoff=True,
+    retry_backoff_max=300,
+)
 def calculate_poi_distances_task(
     self: Task,
     poi_id: int,
@@ -70,17 +76,23 @@ def calculate_poi_distances_task(
             "message": message,
         })
 
-    total = asyncio.run(
-        calculate_poi_distances(
-            listing_repo=listing_repo,
-            poi_repo=poi_repo,
-            poi=poi,
-            travel_modes=travel_modes,
-            listing_type=lt,
-            listing_ids=listing_ids,
-            on_progress=on_progress,
+    try:
+        total = asyncio.run(
+            calculate_poi_distances(
+                listing_repo=listing_repo,
+                poi_repo=poi_repo,
+                poi=poi,
+                travel_modes=travel_modes,
+                listing_type=lt,
+                listing_ids=listing_ids,
+                on_progress=on_progress,
+            )
         )
-    )
+    except Exception:
+        celery_logger.exception(
+            f"POI distance calculation failed: poi_id={poi_id}"
+        )
+        raise  # Let Celery's autoretry handle it
 
     celery_logger.info(f"POI distance calculation complete: {total} distances computed")
 