The distance calculator always queried the rentlisting table regardless of listing type because get_listings() defaulted to RentListing when called without query_parameters. Added a listing_type parameter to get_listings() and _get_model_for_query() so callers can select the correct table directly.
231 lines
8 KiB
Python
231 lines
8 KiB
Python
"""POI distance calculator - orchestrates OSRM and OTP for batch distance computation."""
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Callable
|
|
|
|
import aiohttp
|
|
|
|
from config.routing_config import RoutingConfig
|
|
from models.listing import ListingType
|
|
from models.poi import PointOfInterest
|
|
from models.poi_distance import POIDistance
|
|
from rec.osrm_client import osrm_table
|
|
from rec.otp_client import otp_transit_route
|
|
from repositories.listing_repository import ListingRepository
|
|
from repositories.poi_repository import POIRepository
|
|
|
|
# All messages from this module go through the logger named "uvicorn.error".
logger = logging.getLogger("uvicorn.error")

# OSRM routing profile to request for each travel mode that OSRM handles.
# Keys are the upper-cased travel mode names used throughout this module.
_OSRM_PROFILES = dict(WALK="foot", BICYCLE="bicycle")
|
|
|
|
|
|
async def calculate_poi_distances(
    listing_repo: ListingRepository,
    poi_repo: POIRepository,
    poi: PointOfInterest,
    travel_modes: list[str],
    listing_type: ListingType,
    listing_ids: list[int] | None = None,
    config: RoutingConfig | None = None,
    on_progress: Callable[[int, int, str], None] | None = None,
) -> int:
    """Calculate distances from listings to a POI for given travel modes.

    Travel modes are processed one after another. For each mode, listings
    that already have a stored distance for this POI are skipped; the rest
    are routed through OSRM (modes listed in ``_OSRM_PROFILES``) or OTP
    (``TRANSIT``). A mode whose routing engine cannot be reached is logged,
    reported through ``on_progress`` and skipped; the remaining modes are
    still attempted.

    Args:
        listing_repo: Repository for listing access.
        poi_repo: Repository for POI and distance storage.
        poi: The point of interest to calculate distances to.
        travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT).
            Matching is case-insensitive; unknown modes are logged and
            ignored.
        listing_type: BUY or RENT.
        listing_ids: Optional subset of listing IDs. If None, uses all listings.
        config: Routing engine configuration. Defaults to
            ``RoutingConfig.from_env()``.
        on_progress: Callback(completed, total, message) for progress updates.

    Returns:
        Total number of distances computed.
    """
    if config is None:
        config = RoutingConfig.from_env()

    # Load the listings to route from, restricted to listing_ids when given.
    listings = await listing_repo.get_listings(
        only_ids=listing_ids,
        listing_type=listing_type,
    )
    if not listings:
        logger.info("No listings found for distance calculation")
        return 0

    total_computed = 0
    # Progress is reported against "every listing for every requested mode".
    progress_total = len(listings) * len(travel_modes)

    for mode in travel_modes:
        mode_upper = mode.upper()

        # Reject unsupported modes up front, before spending a repository
        # query on them and before logging that work is about to start.
        if mode_upper not in _OSRM_PROFILES and mode_upper != "TRANSIT":
            logger.warning(f"Unknown travel mode: {mode_upper}")
            continue

        # Skip listings that already have computed distances
        existing = poi_repo.get_existing_distance_keys(
            poi.id, mode_upper, listing_type  # type: ignore[arg-type]
        )
        pending_listings = [
            listing for listing in listings if listing.id not in existing
        ]

        if not pending_listings:
            logger.info(f"All listings already computed for {mode_upper}")
            if on_progress:
                on_progress(
                    total_computed, progress_total,
                    f"Skipped {mode_upper} (already computed)"
                )
            continue

        logger.info(
            f"Computing {mode_upper} distances for {len(pending_listings)} listings "
            f"(skipped {len(existing)} already computed)"
        )

        try:
            if mode_upper == "TRANSIT":
                computed = await _compute_transit(
                    pending_listings, poi, listing_type,
                    config, poi_repo, on_progress,
                    total_computed, progress_total,
                )
            else:
                computed = await _compute_osrm(
                    pending_listings, poi, mode_upper, listing_type,
                    config, poi_repo, on_progress,
                    total_computed, progress_total,
                )
        # asyncio.TimeoutError is listed explicitly: before Python 3.11 it is
        # not an OSError subclass, so a request timeout would otherwise abort
        # the whole run instead of just skipping this mode.
        except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
            logger.error(f"Routing engine unreachable for {mode_upper}: {e}")
            if on_progress:
                on_progress(
                    total_computed, progress_total,
                    f"{mode_upper}: routing engine unavailable"
                )
            continue

        total_computed += computed

    return total_computed
|
|
|
|
|
|
async def _compute_osrm(
    listings: list,
    poi: PointOfInterest,
    mode: str,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Route a set of listings to one POI through the OSRM table service.

    The listings are sent in slices of ``config.osrm_batch_size``; each slice
    is one ``osrm_table`` call with the POI as the only destination. Every
    listing that gets a result is stored as a ``POIDistance`` via
    ``poi_repo.upsert_distances``, one upsert per slice.

    Args:
        listings: Listings to route; each needs ``id``, ``longitude`` and
            ``latitude``.
        poi: Destination point of interest.
        mode: Upper-cased travel mode; must be a key of ``_OSRM_PROFILES``.
        listing_type: Listing type recorded on every stored distance.
        config: Routing configuration (batch size, passed on to the client).
        poi_repo: Repository that persists the distances.
        on_progress: Optional callback(completed, total, message), invoked
            once per slice.
        progress_offset: Amount already completed before this call; added to
            the running count when reporting progress.
        progress_total: Overall total handed to ``on_progress``.

    Returns:
        Number of distances stored by this call.
    """
    osrm_profile = _OSRM_PROFILES[mode]
    # Coordinates are handed to the client as (longitude, latitude) pairs.
    target = [(poi.longitude, poi.latitude)]
    chunk_size = config.osrm_batch_size
    stored = 0

    async with aiohttp.ClientSession() as http:
        for start in range(0, len(listings), chunk_size):
            chunk = listings[start:start + chunk_size]
            sources = [(item.longitude, item.latitude) for item in chunk]

            table = await osrm_table(
                sources, target, osrm_profile, config, http
            )

            rows: list[POIDistance] = []
            for position, item in enumerate(chunk):
                row = table[position]
                # One destination only, so the first cell of the row is the
                # answer; an empty row or empty cell means no route.
                cell = row[0] if row else None
                if cell is None:
                    continue
                rows.append(POIDistance(
                    listing_id=item.id,
                    listing_type=listing_type,
                    poi_id=poi.id,  # type: ignore[arg-type]
                    travel_mode=mode,
                    duration_seconds=cell.duration_seconds,
                    distance_meters=cell.distance_meters,
                    computed_at=datetime.utcnow(),
                ))  # type: ignore[call-arg]

            if rows:
                poi_repo.upsert_distances(rows)
                stored += len(rows)

            if on_progress:
                on_progress(
                    progress_offset + stored, progress_total,
                    f"{mode}: {stored}/{len(listings)}"
                )

    return stored
|
|
|
|
|
|
async def _compute_transit(
    listings: list,
    poi: PointOfInterest,
    listing_type: ListingType,
    config: RoutingConfig,
    poi_repo: POIRepository,
    on_progress: Callable[[int, int, str], None] | None,
    progress_offset: int,
    progress_total: int,
) -> int:
    """Compute transit distances using OTP with concurrency control.

    One ``otp_transit_route`` request is issued per listing, with at most
    ``config.otp_max_concurrent`` requests in flight at a time. Results are
    handled in completion order and written through
    ``poi_repo.upsert_distances`` in groups of 50, with a final flush for
    the remainder.

    If a request (or the progress callback) raises, the exception still
    propagates to the caller, but first every outstanding request is
    cancelled and awaited so nothing keeps running against the closed HTTP
    session, and results gathered so far are flushed so they are not lost.

    Args:
        listings: Listings to route; each needs ``id``, ``latitude`` and
            ``longitude``.
        poi: Destination point of interest.
        listing_type: Listing type recorded on every stored distance.
        config: Routing configuration (concurrency limit, passed on to the
            client).
        poi_repo: Repository that persists the distances.
        on_progress: Optional callback(completed, total, message), invoked
            after each finished request.
        progress_offset: Amount already completed before this call; added to
            the running count when reporting progress.
        progress_total: Overall total handed to ``on_progress``.

    Returns:
        Number of listings for which a transit result was obtained.
    """
    semaphore = asyncio.Semaphore(config.otp_max_concurrent)
    computed = 0
    batch_results: list[POIDistance] = []
    save_interval = 50  # Save every N results

    async with aiohttp.ClientSession() as session:

        async def compute_one(listing: object) -> POIDistance | None:
            """Route a single listing; return None when OTP gives no result."""
            async with semaphore:
                result = await otp_transit_route(
                    listing.latitude, listing.longitude,  # type: ignore[union-attr]
                    poi.latitude, poi.longitude,
                    config, session,
                )
            if result is None:
                return None
            return POIDistance(
                listing_id=listing.id,  # type: ignore[union-attr]
                listing_type=listing_type,
                poi_id=poi.id,  # type: ignore[arg-type]
                travel_mode="TRANSIT",
                duration_seconds=result.duration_seconds,
                distance_meters=result.distance_meters,
                # NOTE(review): utcnow() yields a naive timestamp and is
                # deprecated from Python 3.12; kept to match stored data.
                computed_at=datetime.utcnow(),
            )  # type: ignore[call-arg]

        # Create the tasks explicitly so they can be cancelled on failure.
        tasks = [
            asyncio.create_task(compute_one(listing)) for listing in listings
        ]
        try:
            for finished in asyncio.as_completed(tasks):
                result = await finished
                if result is not None:
                    batch_results.append(result)
                    computed += 1

                # Periodically save results. The batch is detached before the
                # call so a failing upsert is not retried by the cleanup below.
                if len(batch_results) >= save_interval:
                    ready, batch_results = batch_results, []
                    poi_repo.upsert_distances(ready)

                if on_progress:
                    on_progress(
                        progress_offset + computed, progress_total,
                        f"TRANSIT: {computed}/{len(listings)}"
                    )
        finally:
            if tasks:
                # On the normal path every task is already done and this is a
                # no-op. On failure it stops the in-flight requests before
                # the session closes and retrieves their outcomes.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)

            # Save remaining results (also on the failure path).
            if batch_results:
                poi_repo.upsert_distances(batch_results)

    return computed
|