wrongmove/services/poi_distance_calculator.py
Viktor Barzin 5b566bab4c
Fix POI distance calculation reliability for remote/Celery execution
- Fix silent log loss: replace hardcoded "uvicorn.error" logger with __name__
  in osrm_client, otp_client, poi_distance_calculator, and poi_tasks (uvicorn
  logger has no handlers in Celery worker, so all errors were silently dropped)
- Add Celery retry: autoretry_for=(Exception,), max_retries=3, retry_backoff
- Add top-level exception handling in task with full traceback logging
- Fix upsert_distances: replace session.merge() (PK-based) with proper
  dialect-aware INSERT ON DUPLICATE KEY UPDATE / ON CONFLICT DO UPDATE
- Filter out listings with null/zero coordinates before routing
- Raise OSError when all routing engines fail with 0 results computed,
  distinguishing "nothing to compute" from "all engines unreachable"
2026-02-08 20:11:12 +00:00

256 lines
8.9 KiB
Python

"""POI distance calculator - orchestrates OSRM and OTP for batch distance computation."""
import asyncio
import logging
from datetime import datetime
from typing import Callable
import aiohttp
from config.routing_config import RoutingConfig
from models.listing import ListingType
from models.poi import PointOfInterest
from models.poi_distance import POIDistance
from rec.osrm_client import osrm_table
from rec.otp_client import otp_transit_route
from repositories.listing_repository import ListingRepository
from repositories.poi_repository import POIRepository
logger = logging.getLogger(__name__)
# Map travel mode names to OSRM profiles
_OSRM_PROFILES = {
"WALK": "foot",
"BICYCLE": "bicycle",
}
async def calculate_poi_distances(
    listing_repo: ListingRepository,
    poi_repo: POIRepository,
    poi: PointOfInterest,
    travel_modes: list[str],
    listing_type: ListingType,
    listing_ids: list[int] | None = None,
    config: RoutingConfig | None = None,
    on_progress: Callable[[int, int, str], None] | None = None,
) -> int:
    """Calculate distances from listings to a POI for given travel modes.

    Listings whose latitude or longitude is ``None`` or ``0`` are skipped.
    For each mode, only listings that do not already have a stored distance
    are routed. Modes present in ``_OSRM_PROFILES`` are computed with OSRM,
    ``TRANSIT`` is computed with OTP, and any other mode is logged and
    ignored.

    Args:
        listing_repo: Repository for listing access.
        poi_repo: Repository for POI and distance storage.
        poi: The point of interest to calculate distances to.
        travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT).
            Matching is case-insensitive.
        listing_type: BUY or RENT.
        listing_ids: Optional subset of listing IDs. If None, uses all listings.
        config: Routing engine configuration. If None, it is loaded with
            ``RoutingConfig.from_env()``.
        on_progress: Callback(completed, total, message) for progress updates.

    Returns:
        Total number of distances computed.

    Raises:
        OSError: If at least one mode failed with a connection or timeout
            error and no distance was computed for any mode.
    """
    if config is None:
        config = RoutingConfig.from_env()

    # Load listings with coordinates
    listings = await listing_repo.get_listings(
        only_ids=listing_ids,
        listing_type=listing_type,
    )
    if not listings:
        logger.info("No listings found for distance calculation")
        return 0

    # Filter out listings without valid coordinates
    valid_listings = [
        listing
        for listing in listings
        if listing.latitude is not None
        and listing.longitude is not None
        and listing.latitude != 0
        and listing.longitude != 0
    ]
    skipped = len(listings) - len(valid_listings)
    if skipped > 0:
        logger.warning("Skipped %d listings with missing coordinates", skipped)
    listings = valid_listings
    if not listings:
        logger.warning("No listings with valid coordinates")
        return 0

    # Progress is expressed against every (listing, mode) pair.
    progress_total = len(listings) * len(travel_modes)
    total_computed = 0
    modes_failed: list[str] = []

    def report(message: str) -> None:
        """Forward the current overall progress to the optional callback."""
        if on_progress:
            on_progress(total_computed, progress_total, message)

    for mode in travel_modes:
        mode_upper = mode.upper()

        # Skip listings that already have computed distances
        existing = poi_repo.get_existing_distance_keys(
            poi.id, mode_upper, listing_type  # type: ignore[arg-type]
        )
        pending_listings = [
            listing for listing in listings if listing.id not in existing
        ]
        if not pending_listings:
            logger.info("All listings already computed for %s", mode_upper)
            report(f"Skipped {mode_upper} (already computed)")
            continue

        logger.info(
            "Computing %s distances for %d listings "
            "(skipped %d already computed)",
            mode_upper, len(pending_listings), len(existing),
        )
        try:
            if mode_upper in _OSRM_PROFILES:
                computed = await _compute_osrm(
                    pending_listings, poi, mode_upper, listing_type,
                    config, poi_repo, on_progress,
                    total_computed, progress_total,
                )
            elif mode_upper == "TRANSIT":
                computed = await _compute_transit(
                    pending_listings, poi, listing_type,
                    config, poi_repo, on_progress,
                    total_computed, progress_total,
                )
            else:
                logger.warning("Unknown travel mode: %s", mode_upper)
                continue
        # asyncio.TimeoutError only became an OSError subclass in Python 3.11;
        # naming it explicitly makes request timeouts count as an unreachable
        # engine on older interpreters too.
        except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as exc:
            logger.error(
                "Routing engine unreachable for %s: %s", mode_upper, exc
            )
            modes_failed.append(mode_upper)
            report(f"{mode_upper}: routing engine unavailable")
            continue
        total_computed += computed

    # Distinguish "nothing to compute" (return 0) from "engines unreachable".
    if modes_failed and total_computed == 0:
        failed_str = ", ".join(modes_failed)
        raise OSError(
            f"All routing engines failed ({failed_str}). "
            f"No distances computed for {len(listings)} listings."
        )
    return total_computed
async def _compute_osrm(
    listings: list,
    poi: PointOfInterest,
    mode: str,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Route listings to the POI through the OSRM /table API, batch by batch.

    Listings are sent as origins in slices of ``config.osrm_batch_size``
    with the POI as the only destination. Every origin that yields a result
    is stored through ``poi_repo.upsert_distances`` before the next slice is
    requested.

    Args:
        listings: Listings to route; each must expose ``id``, ``latitude``
            and ``longitude``.
        poi: Destination point of interest.
        mode: Travel mode name; must be a key of ``_OSRM_PROFILES``.
        listing_type: Listing type recorded on each stored distance.
        config: Routing engine configuration.
        poi_repo: Repository used to store the computed distances.
        on_progress: Optional callback(completed, total, message), invoked
            after each batch.
        progress_offset: Count added to this call's own running total when
            reporting progress.
        progress_total: Overall total passed through to ``on_progress``.

    Returns:
        Number of distances stored by this call.
    """
    osrm_profile = _OSRM_PROFILES[mode]
    # Coordinates are handed to OSRM as (longitude, latitude) pairs.
    poi_coords = [(poi.longitude, poi.latitude)]
    step = config.osrm_batch_size
    listing_count = len(listings)
    saved = 0

    async with aiohttp.ClientSession() as http:
        for offset in range(0, listing_count, step):
            chunk = listings[offset:offset + step]
            table = await osrm_table(
                [(item.longitude, item.latitude) for item in chunk],
                poi_coords, osrm_profile, config, http,
            )

            rows: list[POIDistance] = []
            for row_idx, item in enumerate(chunk):
                # One table row per origin; the single column is the POI.
                cell = table[row_idx]
                route = cell[0] if cell else None
                if route is None:
                    continue
                rows.append(POIDistance(
                    listing_id=item.id,
                    listing_type=listing_type,
                    poi_id=poi.id,  # type: ignore[arg-type]
                    travel_mode=mode,
                    duration_seconds=route.duration_seconds,
                    distance_meters=route.distance_meters,
                    computed_at=datetime.utcnow(),
                ))  # type: ignore[call-arg]

            if rows:
                poi_repo.upsert_distances(rows)
            saved += len(rows)

            if on_progress:
                on_progress(
                    progress_offset + saved, progress_total,
                    f"{mode}: {saved}/{listing_count}"
                )
    return saved
async def _compute_transit(
    listings: list,
    poi: PointOfInterest,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Compute transit distances using OTP with concurrency control.

    One OTP request is issued per listing, with at most
    ``config.otp_max_concurrent`` requests in flight at a time. Results are
    stored through ``poi_repo.upsert_distances`` in groups of 50 as they
    arrive, and any remainder is stored at the end.

    If a request raises, the results buffered so far are stored first, the
    requests still outstanding are cancelled and awaited before the HTTP
    session closes, and the error is re-raised to the caller.

    Args:
        listings: Listings to route; each must expose ``id``, ``latitude``
            and ``longitude``.
        poi: Destination point of interest.
        listing_type: Listing type recorded on each stored distance.
        config: Routing engine configuration.
        poi_repo: Repository used to store the computed distances.
        on_progress: Optional callback(completed, total, message), invoked
            after each finished request.
        progress_offset: Count added to this call's own running total when
            reporting progress.
        progress_total: Overall total passed through to ``on_progress``.

    Returns:
        Number of listings for which OTP returned a result.

    Raises:
        Exception: Anything raised by ``otp_transit_route`` or
            ``poi_repo.upsert_distances`` is propagated unchanged.
    """
    semaphore = asyncio.Semaphore(config.otp_max_concurrent)
    save_interval = 50  # Save every N results
    computed = 0
    unsaved: list[POIDistance] = []

    async with aiohttp.ClientSession() as session:

        async def compute_one(listing: object) -> POIDistance | None:
            """Route a single listing; returns None when OTP has no result."""
            async with semaphore:
                result = await otp_transit_route(
                    listing.latitude, listing.longitude,  # type: ignore[union-attr]
                    poi.latitude, poi.longitude,
                    config, session,
                )
                if result is None:
                    return None
                return POIDistance(
                    listing_id=listing.id,  # type: ignore[union-attr]
                    listing_type=listing_type,
                    poi_id=poi.id,  # type: ignore[arg-type]
                    travel_mode="TRANSIT",
                    duration_seconds=result.duration_seconds,
                    distance_meters=result.distance_meters,
                    computed_at=datetime.utcnow(),
                )  # type: ignore[call-arg]

        # Explicit tasks (rather than bare coroutines) so that the ones still
        # outstanding can be cancelled if one of them fails.
        tasks = [
            asyncio.create_task(compute_one(listing)) for listing in listings
        ]
        try:
            for finished in asyncio.as_completed(tasks):
                try:
                    distance = await finished
                except Exception:
                    # Keep the work already done: store buffered results
                    # before the failure propagates to the caller.
                    if unsaved:
                        poi_repo.upsert_distances(unsaved)
                        unsaved = []
                    raise

                if distance is not None:
                    unsaved.append(distance)
                    computed += 1

                # Periodically save results
                if len(unsaved) >= save_interval:
                    poi_repo.upsert_distances(unsaved)
                    unsaved = []

                if on_progress:
                    on_progress(
                        progress_offset + computed, progress_total,
                        f"TRANSIT: {computed}/{len(listings)}"
                    )
        finally:
            # On the normal path every task is already done and this is a
            # no-op. After a failure it stops the remaining requests from
            # running against a session that is about to be closed, and
            # retrieves their outcomes so none are left unobserved.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

    # Save remaining results
    if unsaved:
        poi_repo.upsert_distances(unsaved)
    return computed