wrongmove/services/poi_distance_calculator.py
Viktor Barzin 8a5d1b3787
Fix POI distance calculation: OSRM index separator and error handling
- Fix OSRM client to use semicolons (not commas) for source/destination
  indices in /table API requests. Commas caused "Query string malformed"
  errors for any batch with more than one origin.
- Add error handling in poi_distance_calculator for unreachable routing
  engines (OSRM/OTP). Connection failures now log an error and skip the
  mode instead of crashing the entire Celery task.
2026-02-08 14:50:09 +00:00

232 lines
8.1 KiB
Python

"""POI distance calculator - orchestrates OSRM and OTP for batch distance computation."""
import asyncio
import logging
from datetime import datetime
from typing import Callable
import aiohttp
from config.routing_config import RoutingConfig
from models.listing import BuyListing, ListingType, RentListing
from models.poi import PointOfInterest
from models.poi_distance import POIDistance
from rec.osrm_client import osrm_table
from rec.otp_client import otp_transit_route
from repositories.listing_repository import ListingRepository
from repositories.poi_repository import POIRepository
logger = logging.getLogger("uvicorn.error")
# Map travel mode names to OSRM profiles
_OSRM_PROFILES = {
"WALK": "foot",
"BICYCLE": "bicycle",
}
async def calculate_poi_distances(
    listing_repo: ListingRepository,
    poi_repo: POIRepository,
    poi: PointOfInterest,
    travel_modes: list[str],
    listing_type: ListingType,
    listing_ids: list[int] | None = None,
    config: RoutingConfig | None = None,
    on_progress: Callable[[int, int, str], None] | None = None,
) -> int:
    """Calculate distances from listings to a POI for given travel modes.

    WALK and BICYCLE are computed with OSRM; TRANSIT is computed with OTP.
    For each mode, listings that already have a stored distance are skipped,
    so re-running only fills in what is missing. If a routing engine cannot
    be reached (connection error or timeout), the error is logged and that
    mode is skipped instead of aborting the whole calculation.

    Args:
        listing_repo: Repository for listing access.
        poi_repo: Repository for POI and distance storage.
        poi: The point of interest to calculate distances to.
        travel_modes: Travel modes to compute (WALK, BICYCLE, TRANSIT),
            matched case-insensitively. Unknown modes are logged and skipped.
        listing_type: BUY or RENT. Stored on each computed distance and used
            to look up distances that were already computed.
        listing_ids: Optional subset of listing IDs. If None, uses all listings.
        config: Routing engine configuration. Defaults to
            ``RoutingConfig.from_env()``.
        on_progress: Callback(completed, total, message) for progress updates,
            where total is ``len(listings) * len(travel_modes)``.

    Returns:
        Number of distances newly computed by this call. Distances that were
        already stored are not counted. Distances from a mode whose routing
        engine failed part-way through are not counted either.
    """
    if config is None:
        config = RoutingConfig.from_env()

    # NOTE(review): listing_type is not passed to get_listings, so this load
    # is not filtered by listing type -- confirm the repository scopes it.
    listings = await listing_repo.get_listings(
        only_ids=listing_ids,
        query_parameters=None,
    )
    if not listings:
        logger.info("No listings found for distance calculation")
        return 0

    total_computed = 0
    progress_total = len(listings) * len(travel_modes)

    for mode in travel_modes:
        mode_upper = mode.upper()

        # Reject unsupported modes before querying the database.
        if mode_upper not in _OSRM_PROFILES and mode_upper != "TRANSIT":
            logger.warning(f"Unknown travel mode: {mode_upper}")
            continue

        # Skip listings that already have a stored distance for this mode.
        # A set keeps the per-listing membership test O(1).
        existing = set(
            poi_repo.get_existing_distance_keys(
                poi.id, mode_upper, listing_type  # type: ignore[arg-type]
            )
        )
        pending_listings = [l for l in listings if l.id not in existing]
        if not pending_listings:
            logger.info(f"All listings already computed for {mode_upper}")
            if on_progress:
                on_progress(
                    total_computed, progress_total,
                    f"Skipped {mode_upper} (already computed)",
                )
            continue

        # The key lookup above is not restricted to listing_ids, so count the
        # skipped listings from this run rather than using len(existing).
        skipped = len(listings) - len(pending_listings)
        logger.info(
            f"Computing {mode_upper} distances for {len(pending_listings)} listings "
            f"(skipped {skipped} already computed)"
        )

        try:
            if mode_upper == "TRANSIT":
                computed = await _compute_transit(
                    pending_listings, poi, listing_type,
                    config, poi_repo, on_progress,
                    total_computed, progress_total,
                )
            else:
                computed = await _compute_osrm(
                    pending_listings, poi, mode_upper, listing_type,
                    config, poi_repo, on_progress,
                    total_computed, progress_total,
                )
        except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
            # asyncio.TimeoutError is listed explicitly: aiohttp raises it on
            # request timeouts, and before Python 3.11 it is not an OSError.
            logger.error(f"Routing engine unreachable for {mode_upper}: {e}")
            if on_progress:
                on_progress(
                    total_computed, progress_total,
                    f"{mode_upper}: routing engine unavailable",
                )
            continue
        total_computed += computed

    return total_computed
async def _compute_osrm(
    listings: list,
    poi: PointOfInterest,
    mode: str,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Compute distances using OSRM /table API in batches.

    Listings are sent to OSRM ``config.osrm_batch_size`` at a time as
    (longitude, latitude) origins, with the POI as the only destination.
    Each batch's routes are upserted before the next batch is requested.
    Listings with no route are left out. ``on_progress`` (if given) is
    called after every batch.

    Returns:
        Number of distances saved.
    """
    osrm_profile = _OSRM_PROFILES[mode]
    poi_point = [(poi.longitude, poi.latitude)]
    step = config.osrm_batch_size
    saved = 0

    async with aiohttp.ClientSession() as http:
        for start in range(0, len(listings), step):
            chunk = listings[start:start + step]
            table = await osrm_table(
                [(item.longitude, item.latitude) for item in chunk],
                poi_point,
                osrm_profile,
                config,
                http,
            )

            rows: list[POIDistance] = []
            for idx, item in enumerate(chunk):
                # As indexed here, the table is [origin][destination]; the
                # POI is destination 0. An empty row or None cell means no route.
                route = table[idx][0] if table[idx] else None
                if route is None:
                    continue
                rows.append(POIDistance(
                    listing_id=item.id,
                    listing_type=listing_type,
                    poi_id=poi.id,  # type: ignore[arg-type]
                    travel_mode=mode,
                    duration_seconds=route.duration_seconds,
                    distance_meters=route.distance_meters,
                    computed_at=datetime.utcnow(),
                ))  # type: ignore[call-arg]

            if rows:
                poi_repo.upsert_distances(rows)
                saved += len(rows)
            if on_progress:
                on_progress(
                    progress_offset + saved, progress_total,
                    f"{mode}: {saved}/{len(listings)}"
                )

    return saved
async def _compute_transit(
    listings: list,
    poi: PointOfInterest,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Compute transit distances using OTP with concurrency control.

    One OTP request is issued per listing, with at most
    ``config.otp_max_concurrent`` requests in flight at once. Results are
    upserted every ``save_interval`` routes as they arrive, and
    ``on_progress`` (if given) is called after each save and once at the end.

    If any request raises, the requests still running are cancelled. The
    routes already received are saved, and the exception then propagates
    to the caller.

    Returns:
        Number of listings for which OTP returned a route.
    """
    semaphore = asyncio.Semaphore(config.otp_max_concurrent)
    save_interval = 50  # Save every N results
    computed = 0
    batch_results: list[POIDistance] = []

    def report_progress() -> None:
        if on_progress:
            on_progress(
                progress_offset + computed, progress_total,
                f"TRANSIT: {computed}/{len(listings)}"
            )

    async with aiohttp.ClientSession() as session:

        async def compute_one(listing: object) -> POIDistance | None:
            async with semaphore:
                result = await otp_transit_route(
                    listing.latitude, listing.longitude,  # type: ignore[union-attr]
                    poi.latitude, poi.longitude,
                    config, session,
                )
            if result is None:
                return None
            return POIDistance(
                listing_id=listing.id,  # type: ignore[union-attr]
                listing_type=listing_type,
                poi_id=poi.id,  # type: ignore[arg-type]
                travel_mode="TRANSIT",
                duration_seconds=result.duration_seconds,
                distance_meters=result.distance_meters,
                computed_at=datetime.utcnow(),
            )  # type: ignore[call-arg]

        # Explicit tasks, so they can be cancelled if we leave the loop early.
        tasks = [asyncio.create_task(compute_one(listing)) for listing in listings]
        try:
            for next_done in asyncio.as_completed(tasks):
                result = await next_done
                if result is not None:
                    batch_results.append(result)
                    computed += 1
                # Periodically save results
                if len(batch_results) >= save_interval:
                    poi_repo.upsert_distances(batch_results)
                    batch_results = []
                    report_progress()
        finally:
            # On an early exit (e.g. an OTP request failed), stop the requests
            # still running so none outlives the session. gather() also
            # consumes their exceptions, so asyncio does not report them as
            # never retrieved. On the normal path every task is already done.
            for task in tasks:
                task.cancel()
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)
            # Save remaining results. This also runs on failure, so routes
            # already received are kept and skipped on a later run.
            if batch_results:
                poi_repo.upsert_distances(batch_results)

    report_progress()
    return computed