wrongmove/repositories/poi_repository.py
Viktor Barzin 5b566bab4c
Fix POI distance calculation reliability for remote/Celery execution
- Fix silent log loss: replace hardcoded "uvicorn.error" logger with __name__
  in osrm_client, otp_client, poi_distance_calculator, and poi_tasks (uvicorn
  logger has no handlers in Celery worker, so all errors were silently dropped)
- Add Celery retry: autoretry_for=(Exception,), max_retries=3, retry_backoff
- Add top-level exception handling in task with full traceback logging
- Fix upsert_distances: replace session.merge() (PK-based) with proper
  dialect-aware INSERT ON DUPLICATE KEY UPDATE / ON CONFLICT DO UPDATE
- Filter out listings with null/zero coordinates before routing
- Raise OSError when all routing engines fail with 0 results computed,
  distinguishing "nothing to compute" from "all engines unreachable"
2026-02-08 20:11:12 +00:00

137 lines
5.6 KiB
Python

from models.listing import ListingType
from models.poi import PointOfInterest
from models.poi_distance import POIDistance
from sqlalchemy import Engine, delete
from sqlmodel import Session, select
class POIRepository:
    """Persistence operations for points of interest and their computed distances.

    Every method opens its own short-lived ``Session`` on ``self.engine`` and
    closes it before returning; write methods commit first. ORM instances
    returned by these methods are therefore detached from any session by the
    time the caller receives them.
    """

    engine: Engine

    # Columns that identify one distance row; the ON CONFLICT target used by
    # the PostgreSQL/SQLite upserts. NOTE(review): presumably backed by a
    # unique constraint on POIDistance — confirm against the model/schema.
    _DISTANCE_KEY_COLUMNS = ("listing_id", "listing_type", "poi_id", "travel_mode")

    def __init__(self, engine: Engine) -> None:
        self.engine = engine

    def get_pois_for_user(self, user_id: int) -> list[PointOfInterest]:
        """Return every point of interest whose ``user_id`` equals ``user_id``."""
        with Session(self.engine) as session:
            statement = select(PointOfInterest).where(
                PointOfInterest.user_id == user_id
            )
            return list(session.exec(statement).all())

    def get_poi_by_id(self, poi_id: int) -> PointOfInterest | None:
        """Return the point of interest with primary key ``poi_id``, or ``None``."""
        with Session(self.engine) as session:
            return session.get(PointOfInterest, poi_id)

    def create_poi(self, poi: PointOfInterest) -> PointOfInterest:
        """Insert ``poi``, commit, and return it with its stored state loaded."""
        with Session(self.engine) as session:
            session.add(poi)
            session.commit()
            # Commit expires the instance (default Session settings); refresh
            # reloads it, including values assigned by the database on insert,
            # before the session closes.
            session.refresh(poi)
            return poi

    def update_poi(self, poi: PointOfInterest) -> PointOfInterest:
        """Persist the state of ``poi`` and return the stored, refreshed copy.

        ``Session.merge`` copies the state of ``poi`` onto the instance tracked
        by this session and returns that tracked instance. That instance is
        refreshed after commit so the caller sees the persisted values. The
        caller's ``poi`` object itself is not attached to the session.
        """
        with Session(self.engine) as session:
            merged = session.merge(poi)
            session.commit()
            # Use the merged instance directly instead of re-fetching by
            # ``poi.id``, which is None when merge had to insert a new row.
            session.refresh(merged)
            return merged

    def delete_poi(self, poi_id: int) -> bool:
        """Delete a point of interest together with its computed distances.

        Returns:
            ``True`` if the POI existed and was deleted, ``False`` if no POI
            has primary key ``poi_id``.
        """
        with Session(self.engine) as session:
            poi = session.get(PointOfInterest, poi_id)
            if poi is None:
                return False
            # Distance rows for this POI are deleted explicitly, in the same
            # transaction as the POI itself.
            session.exec(
                delete(POIDistance).where(POIDistance.poi_id == poi_id)  # type: ignore[arg-type]
            )
            session.delete(poi)
            session.commit()
            return True

    def upsert_distances(self, distances: list[POIDistance]) -> None:
        """Insert or update POI distances in a single transaction.

        Rows are matched on ``(listing_id, listing_type, poi_id, travel_mode)``.
        When a matching row already exists, its ``duration_seconds``,
        ``distance_meters`` and ``computed_at`` are overwritten.

        The upsert uses the engine's native syntax:

        - ``ON DUPLICATE KEY UPDATE`` for MySQL and MariaDB.
        - ``ON CONFLICT DO UPDATE`` for PostgreSQL.
        - ``ON CONFLICT DO UPDATE`` for SQLite, which is also the fallback
          for any other dialect.

        An empty ``distances`` list is a no-op.
        """
        if not distances:
            return
        # The dialect is fixed for the engine, so resolve it once up front.
        dialect = self.engine.dialect.name
        with Session(self.engine) as session:
            for dist in distances:
                session.execute(
                    self._build_distance_upsert(dialect, self._distance_values(dist))
                )
            session.commit()

    @staticmethod
    def _distance_values(dist: POIDistance) -> dict[str, object]:
        """Return the column values of ``dist`` that an upsert writes."""
        return {
            "listing_id": dist.listing_id,
            "listing_type": dist.listing_type,
            "poi_id": dist.poi_id,
            "travel_mode": dist.travel_mode,
            "duration_seconds": dist.duration_seconds,
            "distance_meters": dist.distance_meters,
            "computed_at": dist.computed_at,
        }

    @classmethod
    def _build_distance_upsert(cls, dialect: str, values: dict[str, object]):
        """Build a single-row ``POIDistance`` upsert statement for ``dialect``."""
        if dialect in ("mysql", "mariadb"):
            # SQLAlchemy names MariaDB connections as their own "mariadb"
            # dialect, which still uses MySQL's ON DUPLICATE KEY UPDATE.
            from sqlalchemy.dialects.mysql import insert as mysql_insert

            mysql_stmt = mysql_insert(POIDistance).values(**values)
            return mysql_stmt.on_duplicate_key_update(
                duration_seconds=mysql_stmt.inserted.duration_seconds,
                distance_meters=mysql_stmt.inserted.distance_meters,
                computed_at=mysql_stmt.inserted.computed_at,
            )
        if dialect == "postgresql":
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            pg_stmt = pg_insert(POIDistance).values(**values)
            return pg_stmt.on_conflict_do_update(
                index_elements=list(cls._DISTANCE_KEY_COLUMNS),
                set_={
                    "duration_seconds": pg_stmt.excluded.duration_seconds,
                    "distance_meters": pg_stmt.excluded.distance_meters,
                    "computed_at": pg_stmt.excluded.computed_at,
                },
            )
        # SQLite, also used as the fallback for unrecognised dialects.
        from sqlalchemy.dialects.sqlite import insert as sqlite_insert

        sqlite_stmt = sqlite_insert(POIDistance).values(**values)
        return sqlite_stmt.on_conflict_do_update(
            index_elements=list(cls._DISTANCE_KEY_COLUMNS),
            set_={
                "duration_seconds": sqlite_stmt.excluded.duration_seconds,
                "distance_meters": sqlite_stmt.excluded.distance_meters,
                "computed_at": sqlite_stmt.excluded.computed_at,
            },
        )

    def get_distances_for_listings(
        self,
        listing_ids: list[int],
        listing_type: ListingType,
        user_id: int,
    ) -> list[POIDistance]:
        """Return distances from the given listings to POIs owned by ``user_id``.

        Only rows whose ``listing_type`` equals ``listing_type`` are included.
        An empty ``listing_ids`` returns ``[]`` without querying the database.
        """
        if not listing_ids:
            return []
        with Session(self.engine) as session:
            # Join with POI so ownership can be filtered on PointOfInterest.user_id.
            statement = (
                select(POIDistance)
                .join(PointOfInterest, POIDistance.poi_id == PointOfInterest.id)
                .where(
                    POIDistance.listing_id.in_(listing_ids),  # type: ignore[union-attr]
                    POIDistance.listing_type == listing_type,
                    PointOfInterest.user_id == user_id,
                )
            )
            return list(session.exec(statement).all())

    def get_distances_for_poi(self, poi_id: int) -> list[POIDistance]:
        """Return every computed distance row for the POI ``poi_id``."""
        with Session(self.engine) as session:
            statement = select(POIDistance).where(POIDistance.poi_id == poi_id)
            return list(session.exec(statement).all())

    def delete_distances_for_poi(self, poi_id: int) -> int:
        """Delete all distance rows for ``poi_id`` and return the driver's rowcount."""
        with Session(self.engine) as session:
            result = session.exec(
                delete(POIDistance).where(POIDistance.poi_id == poi_id)  # type: ignore[arg-type]
            )
            session.commit()
            return result.rowcount  # type: ignore[union-attr]

    def get_existing_distance_keys(
        self, poi_id: int, travel_mode: str, listing_type: ListingType
    ) -> set[int]:
        """Return the listing IDs that already have a distance for this POI and mode.

        Only rows matching ``poi_id``, ``travel_mode`` and ``listing_type`` are
        considered.
        """
        with Session(self.engine) as session:
            statement = select(POIDistance.listing_id).where(
                POIDistance.poi_id == poi_id,
                POIDistance.travel_mode == travel_mode,
                POIDistance.listing_type == listing_type,
            )
            return set(session.exec(statement).all())