"""POI distance calculator - orchestrates OSRM and OTP for batch distance computation."""

import asyncio
import logging
from datetime import datetime
from typing import Callable

import aiohttp

from config.routing_config import RoutingConfig
from models.listing import BuyListing, ListingType, RentListing
from models.poi import PointOfInterest
from models.poi_distance import POIDistance
from rec.osrm_client import osrm_table
from rec.otp_client import otp_transit_route
from repositories.listing_repository import ListingRepository
from repositories.poi_repository import POIRepository

logger = logging.getLogger("uvicorn.error")

# Map travel mode names to OSRM profiles
_OSRM_PROFILES = {
    "WALK": "foot",
    "BICYCLE": "bicycle",
}


async def calculate_poi_distances(
    listing_repo: ListingRepository,
    poi_repo: POIRepository,
    poi: PointOfInterest,
    travel_modes: list[str],
    listing_type: ListingType,
    listing_ids: list[int] | None = None,
    config: RoutingConfig | None = None,
    on_progress: Callable[[int, int, str], None] | None = None,
) -> int:
    """Calculate distances from listings to a POI for given travel modes.

    WALK and BICYCLE are routed through OSRM; TRANSIT through OTP. Listings
    that already have a stored distance for a (POI, mode, listing type) are
    skipped. A routing-engine failure (``aiohttp.ClientError`` / ``OSError``)
    for one mode is logged and the remaining modes are still processed.

    Args:
        listing_repo: Repository for listing access.
        poi_repo: Repository for POI and distance storage.
        poi: The point of interest to calculate distances to.
        travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT);
            matched case-insensitively. Unknown modes are logged and skipped.
        listing_type: BUY or RENT.
        listing_ids: Optional subset of listing IDs. If None, uses all listings.
        config: Routing engine configuration. Defaults to
            ``RoutingConfig.from_env()``.
        on_progress: Callback(completed, total, message) for progress updates.
            ``total`` is ``len(listings) * len(travel_modes)``; every listing
            counts as one unit per mode whether it was computed, already
            stored, unroutable, or its mode failed.

    Returns:
        Total number of distances computed and stored by this call. Listings
        skipped because a distance already existed are not counted.
        NOTE(review): if a routing engine fails part-way through a mode,
        batches saved before the failure are persisted but not counted here.
    """
    if config is None:
        config = RoutingConfig.from_env()

    # NOTE(review): listing_type is not passed to get_listings, so listings are
    # not filtered by BUY/RENT here (the previous RentListing/BuyListing model
    # selection was computed but never used) - confirm get_listings scopes this.
    listings = await listing_repo.get_listings(
        only_ids=listing_ids,
        query_parameters=None,
    )
    if not listings:
        logger.info("No listings found for distance calculation")
        return 0

    per_mode = len(listings)
    progress_total = per_mode * len(travel_modes)
    total_computed = 0

    for mode_idx, mode in enumerate(travel_modes):
        mode_upper = mode.upper()
        # Progress units covered by earlier modes, and the end of this mode's slot.
        mode_start = mode_idx * per_mode
        mode_end = mode_start + per_mode

        # Validate the mode before touching the database.
        if mode_upper not in _OSRM_PROFILES and mode_upper != "TRANSIT":
            logger.warning("Unknown travel mode: %s", mode_upper)
            continue

        # Skip listings that already have computed distances
        existing = poi_repo.get_existing_distance_keys(
            poi.id, mode_upper, listing_type  # type: ignore[arg-type]
        )
        pending_listings = [
            listing for listing in listings if listing.id not in existing
        ]
        # Count skips within *this* listing set; `existing` may also contain
        # ids outside the requested listing_ids subset.
        skipped = per_mode - len(pending_listings)

        if not pending_listings:
            logger.info("All listings already computed for %s", mode_upper)
            if on_progress:
                on_progress(
                    mode_end,
                    progress_total,
                    f"Skipped {mode_upper} (already computed)",
                )
            continue

        logger.info(
            "Computing %s distances for %d listings (skipped %d already computed)",
            mode_upper,
            len(pending_listings),
            skipped,
        )
        # Already-stored listings count as done for progress purposes.
        progress_offset = mode_start + skipped

        try:
            if mode_upper in _OSRM_PROFILES:
                computed = await _compute_osrm(
                    pending_listings,
                    poi,
                    mode_upper,
                    listing_type,
                    config,
                    poi_repo,
                    on_progress,
                    progress_offset,
                    progress_total,
                )
            else:  # TRANSIT (validated above)
                computed = await _compute_transit(
                    pending_listings,
                    poi,
                    listing_type,
                    config,
                    poi_repo,
                    on_progress,
                    progress_offset,
                    progress_total,
                )
        except (aiohttp.ClientError, OSError) as e:
            logger.error("Routing engine unreachable for %s: %s", mode_upper, e)
            if on_progress:
                # The mode is finished (as failed); advance to the end of its slot.
                on_progress(
                    mode_end,
                    progress_total,
                    f"{mode_upper}: routing engine unavailable",
                )
            continue

        total_computed += computed

    return total_computed


async def _compute_osrm(
    listings: list,
    poi: PointOfInterest,
    mode: str,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Compute distances using OSRM /table API in batches.

    Each batch of ``config.osrm_batch_size`` listings is sent as the origins
    of one ``osrm_table`` call with the POI as the single destination. The
    batch's successful results are upserted before the next batch is requested.

    Args:
        listings: Listings to route; each needs ``id``, ``latitude``, ``longitude``.
        poi: Destination point of interest.
        mode: Upper-case travel mode; must be a key of ``_OSRM_PROFILES``.
        listing_type: Stored on each ``POIDistance``.
        config: Routing configuration (batch size, engine settings).
        poi_repo: Repository used to upsert results.
        on_progress: Optional Callback(completed, total, message).
        progress_offset: Progress units already completed before this call.
        progress_total: Overall progress total passed through to the callback.

    Returns:
        Number of distances stored. Origins with no OSRM result are not counted.
    """
    profile = _OSRM_PROFILES[mode]
    destination = [(poi.longitude, poi.latitude)]  # OSRM expects (lon, lat)
    batch_size = config.osrm_batch_size
    computed = 0

    async with aiohttp.ClientSession() as session:
        for batch_start in range(0, len(listings), batch_size):
            batch = listings[batch_start:batch_start + batch_size]
            origins = [(listing.longitude, listing.latitude) for listing in batch]
            results = await osrm_table(
                origins, destination, profile, config, session
            )

            distances_to_save: list[POIDistance] = []
            for i, listing in enumerate(batch):
                # Row i holds origin i's results; column 0 is the only destination.
                result = results[i][0] if results[i] else None
                if result is None:
                    continue
                distances_to_save.append(POIDistance(
                    listing_id=listing.id,
                    listing_type=listing_type,
                    poi_id=poi.id,  # type: ignore[arg-type]
                    travel_mode=mode,
                    duration_seconds=result.duration_seconds,
                    distance_meters=result.distance_meters,
                    computed_at=datetime.utcnow(),
                ))  # type: ignore[call-arg]

            if distances_to_save:
                poi_repo.upsert_distances(distances_to_save)
                computed += len(distances_to_save)

            if on_progress:
                # Progress counts every processed listing, routable or not,
                # so the mode's slot is fully covered when the loop ends.
                processed = batch_start + len(batch)
                on_progress(
                    progress_offset + processed,
                    progress_total,
                    f"{mode}: {computed}/{len(listings)}",
                )

    return computed


async def _compute_transit(
    listings: list,
    poi: PointOfInterest,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Compute transit distances using OTP with concurrency control.

    One ``otp_transit_route`` request is issued per listing. At most
    ``config.otp_max_concurrent`` requests are in flight at a time, bounded by
    an ``asyncio.Semaphore``. Results are upserted every ``save_interval``
    successes, and any remainder is upserted at the end.
    If a request raises, outstanding requests are cancelled before the HTTP
    session closes, and results gathered so far are still persisted.

    Args:
        listings: Listings to route; each needs ``id``, ``latitude``, ``longitude``.
        poi: Destination point of interest.
        listing_type: Stored on each ``POIDistance``.
        config: Routing configuration (concurrency limit, engine settings).
        poi_repo: Repository used to upsert results.
        on_progress: Optional Callback(completed, total, message).
        progress_offset: Progress units already completed before this call.
        progress_total: Overall progress total passed through to the callback.

    Returns:
        Number of distances stored. Listings for which OTP returned no
        route are not counted.
    """
    semaphore = asyncio.Semaphore(config.otp_max_concurrent)
    save_interval = 50  # Save every N results
    computed = 0
    processed = 0
    batch_results: list[POIDistance] = []

    async with aiohttp.ClientSession() as session:

        async def compute_one(listing: object) -> POIDistance | None:
            # The semaphore caps concurrent OTP requests.
            async with semaphore:
                route = await otp_transit_route(
                    listing.latitude, listing.longitude,  # type: ignore[union-attr]
                    poi.latitude, poi.longitude,
                    config, session,
                )
            if route is None:
                return None
            return POIDistance(
                listing_id=listing.id,  # type: ignore[union-attr]
                listing_type=listing_type,
                poi_id=poi.id,  # type: ignore[arg-type]
                travel_mode="TRANSIT",
                duration_seconds=route.duration_seconds,
                distance_meters=route.distance_meters,
                computed_at=datetime.utcnow(),
            )  # type: ignore[call-arg]

        # Explicit tasks (rather than bare coroutines) so they can be
        # cancelled if one of them fails.
        tasks = [asyncio.create_task(compute_one(listing)) for listing in listings]
        try:
            for next_done in asyncio.as_completed(tasks):
                result = await next_done
                processed += 1
                if result is not None:
                    batch_results.append(result)
                    computed += 1
                    # Periodically save results
                    if len(batch_results) >= save_interval:
                        poi_repo.upsert_distances(batch_results)
                        batch_results = []
                if on_progress:
                    on_progress(
                        progress_offset + processed,
                        progress_total,
                        f"TRANSIT: {computed}/{len(listings)}",
                    )
        finally:
            # On failure, stop the remaining requests before the session closes
            # and retrieve their outcomes so no exception goes unobserved.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            # Save remaining results (also on failure, so finished work is kept).
            if batch_results:
                poi_repo.upsert_distances(batch_results)

    return computed