from models.listing import ListingType
from models.poi import PointOfInterest
from models.poi_distance import POIDistance
from sqlalchemy import Engine, delete
from sqlmodel import Session, select


class POIRepository:
    """Persistence operations for points of interest and their distances.

    Each method opens its own ``Session`` on the injected engine and closes
    it before returning, so no session is shared between calls.
    """

    engine: Engine

    def __init__(self, engine: Engine) -> None:
        """Store the engine that every method uses to open its session."""
        self.engine = engine

    def get_pois_for_user(self, user_id: int) -> list[PointOfInterest]:
        """Return every POI whose ``user_id`` equals *user_id*."""
        with Session(self.engine) as session:
            statement = select(PointOfInterest).where(
                PointOfInterest.user_id == user_id
            )
            return list(session.exec(statement).all())

    def get_poi_by_id(self, poi_id: int) -> PointOfInterest | None:
        """Return the POI with primary key *poi_id*, or ``None`` if absent."""
        with Session(self.engine) as session:
            return session.get(PointOfInterest, poi_id)

    def create_poi(self, poi: PointOfInterest) -> PointOfInterest:
        """Insert *poi*, commit, and return it refreshed from the database."""
        with Session(self.engine) as session:
            session.add(poi)
            session.commit()
            # Reload the row so attributes populated on insert are present
            # on the object handed back to the caller.
            session.refresh(poi)
            return poi

    def update_poi(self, poi: PointOfInterest) -> PointOfInterest:
        """Merge *poi* into the database and return the persisted instance.

        The returned object is the session-owned copy produced by
        ``merge()``, refreshed after the commit; it is not necessarily the
        same object as *poi*.
        """
        with Session(self.engine) as session:
            # merge() returns the instance that is actually attached to the
            # session. Returning that one (instead of re-fetching by
            # ``poi.id``) also works when ``poi.id`` is unset, and avoids
            # relying on an ``assert`` that is stripped under ``-O``.
            merged = session.merge(poi)
            session.commit()
            # Reload after the commit so the returned object carries the
            # committed state.
            session.refresh(merged)
            return merged

    def delete_poi(self, poi_id: int) -> bool:
        """Delete a POI together with its distance rows.

        Returns:
            ``True`` if the POI existed and was deleted, ``False`` if no POI
            with *poi_id* was found (nothing is modified in that case).
        """
        with Session(self.engine) as session:
            poi = session.get(PointOfInterest, poi_id)
            if poi is None:
                return False

            # Cascade: delete associated distances
            session.exec(
                delete(POIDistance).where(POIDistance.poi_id == poi_id)  # type: ignore[arg-type]
            )
            session.delete(poi)
            session.commit()
            return True

    def upsert_distances(self, distances: list[POIDistance]) -> None:
        """Insert or update POI distances, handling duplicate unique constraints.

        On a duplicate key only ``duration_seconds``, ``distance_meters`` and
        ``computed_at`` are overwritten. All rows are written in a single
        transaction. An empty *distances* list is a no-op.
        """
        if not distances:
            return

        # Pick the dialect-specific INSERT construct once rather than
        # re-running the import statement for every row.
        # NOTE(review): every dialect other than "mysql" falls through to the
        # SQLite construct, exactly as before.
        use_mysql = self.engine.dialect.name == "mysql"
        if use_mysql:
            from sqlalchemy.dialects.mysql import insert as dialect_insert
        else:
            from sqlalchemy.dialects.sqlite import insert as dialect_insert

        with Session(self.engine) as session:
            for dist in distances:
                stmt = dialect_insert(POIDistance).values(
                    listing_id=dist.listing_id,
                    listing_type=dist.listing_type,
                    poi_id=dist.poi_id,
                    travel_mode=dist.travel_mode,
                    duration_seconds=dist.duration_seconds,
                    distance_meters=dist.distance_meters,
                    computed_at=dist.computed_at,
                )
                if use_mysql:
                    stmt = stmt.on_duplicate_key_update(
                        duration_seconds=stmt.inserted.duration_seconds,
                        distance_meters=stmt.inserted.distance_meters,
                        computed_at=stmt.inserted.computed_at,
                    )
                else:
                    stmt = stmt.on_conflict_do_update(
                        index_elements=[
                            "listing_id",
                            "listing_type",
                            "poi_id",
                            "travel_mode",
                        ],
                        set_={
                            "duration_seconds": stmt.excluded.duration_seconds,
                            "distance_meters": stmt.excluded.distance_meters,
                            "computed_at": stmt.excluded.computed_at,
                        },
                    )
                session.execute(stmt)
            session.commit()

    def get_distances_for_listings(
        self,
        listing_ids: list[int],
        listing_type: ListingType,
        user_id: int,
    ) -> list[POIDistance]:
        """Return distances for the given listings, limited to one user's POIs.

        Only rows whose ``listing_type`` matches and whose POI belongs to
        *user_id* are returned. An empty *listing_ids* yields an empty list.
        """
        if not listing_ids:
            # Nothing can match an empty IN list; skip the round trip.
            return []

        with Session(self.engine) as session:
            # Join with POI to filter by user
            statement = (
                select(POIDistance)
                .join(PointOfInterest, POIDistance.poi_id == PointOfInterest.id)
                .where(
                    POIDistance.listing_id.in_(listing_ids),  # type: ignore[union-attr]
                    POIDistance.listing_type == listing_type,
                    PointOfInterest.user_id == user_id,
                )
            )
            return list(session.exec(statement).all())

    def get_distances_for_poi(self, poi_id: int) -> list[POIDistance]:
        """Return every distance row recorded for *poi_id*."""
        with Session(self.engine) as session:
            statement = select(POIDistance).where(POIDistance.poi_id == poi_id)
            return list(session.exec(statement).all())

    def delete_distances_for_poi(self, poi_id: int) -> int:
        """Delete all distance rows for *poi_id* and return the row count."""
        with Session(self.engine) as session:
            result = session.exec(
                delete(POIDistance).where(POIDistance.poi_id == poi_id)  # type: ignore[arg-type]
            )
            # Capture the count before committing so it never depends on the
            # state of the result after the transaction ends.
            deleted = result.rowcount  # type: ignore[union-attr]
            session.commit()
            return deleted

    def get_existing_distance_keys(
        self, poi_id: int, travel_mode: str, listing_type: ListingType
    ) -> set[int]:
        """Get listing IDs that already have computed distances for a POI+mode."""
        with Session(self.engine) as session:
            statement = select(POIDistance.listing_id).where(
                POIDistance.poi_id == poi_id,
                POIDistance.travel_mode == travel_mode,
                POIDistance.listing_type == listing_type,
            )
            return set(session.exec(statement).all())