from models.listing import ListingType
from models.poi import PointOfInterest
from models.poi_distance import POIDistance
from sqlalchemy import Engine, delete
from sqlmodel import Session, select


class POIRepository:
    """Persistence layer for points of interest (POIs) and their distances.

    Every method opens its own short-lived ``Session`` on the engine given
    at construction time and closes it before returning.
    """

    engine: Engine

    def __init__(self, engine: Engine) -> None:
        """Store the engine that all sessions of this repository are bound to."""
        self.engine = engine

    def get_pois_for_user(self, user_id: int) -> list[PointOfInterest]:
        """Return all POIs whose ``user_id`` equals *user_id*."""
        with Session(self.engine) as session:
            statement = select(PointOfInterest).where(
                PointOfInterest.user_id == user_id
            )
            return list(session.exec(statement).all())

    def get_poi_by_id(self, poi_id: int) -> PointOfInterest | None:
        """Look up a single POI by primary key; ``None`` when it is not found."""
        with Session(self.engine) as session:
            return session.get(PointOfInterest, poi_id)

    def create_poi(self, poi: PointOfInterest) -> PointOfInterest:
        """Add *poi*, commit, and return the same instance after a refresh.

        The refresh reloads the instance from the database after the commit,
        before the session is closed.
        """
        with Session(self.engine) as session:
            session.add(poi)
            session.commit()
            session.refresh(poi)
            return poi

    def update_poi(self, poi: PointOfInterest) -> PointOfInterest:
        """Merge *poi* into a session, commit, and return the re-fetched row.

        Raises:
            AssertionError: if no row with ``poi.id`` can be fetched after
                the commit.
        """
        with Session(self.engine) as session:
            session.merge(poi)
            session.commit()
            # Re-fetch to get the refreshed state
            updated = session.get(PointOfInterest, poi.id)
            # Invariant check (also narrows the Optional for type checkers).
            assert updated is not None
            return updated

    def delete_poi(self, poi_id: int) -> bool:
        """Delete a POI together with its stored distances.

        Returns:
            ``False`` when no POI with *poi_id* exists (nothing is deleted),
            ``True`` once the POI and its distances have been deleted and
            the transaction committed.
        """
        with Session(self.engine) as session:
            poi = session.get(PointOfInterest, poi_id)
            if poi is None:
                return False
            # Cascade: delete associated distances
            session.exec(
                delete(POIDistance).where(POIDistance.poi_id == poi_id)  # type: ignore[arg-type]
            )
            session.delete(poi)
            session.commit()
            return True

    def upsert_distances(self, distances: list[POIDistance]) -> None:
        """Merge every distance record into one session and commit once.

        An empty *distances* list is a no-op: no session is opened.
        """
        if not distances:
            return
        with Session(self.engine) as session:
            for dist in distances:
                session.merge(dist)
            session.commit()

    def get_distances_for_listings(
        self,
        listing_ids: list[int],
        listing_type: ListingType,
        user_id: int,
    ) -> list[POIDistance]:
        """Return distances for the given listings, limited to one user's POIs.

        Args:
            listing_ids: Listing IDs to match against ``POIDistance.listing_id``.
            listing_type: Only distances with this listing type are returned.
            user_id: Only distances whose POI belongs to this user are returned.

        Returns:
            The matching ``POIDistance`` rows; an empty list when
            *listing_ids* is empty.
        """
        # An empty IN-list can never match, so skip the round trip entirely.
        if not listing_ids:
            return []
        with Session(self.engine) as session:
            # Join with POI to filter by user
            statement = (
                select(POIDistance)
                .join(PointOfInterest, POIDistance.poi_id == PointOfInterest.id)
                .where(
                    POIDistance.listing_id.in_(listing_ids),  # type: ignore[union-attr]
                    POIDistance.listing_type == listing_type,
                    PointOfInterest.user_id == user_id,
                )
            )
            return list(session.exec(statement).all())

    def get_distances_for_poi(self, poi_id: int) -> list[POIDistance]:
        """Return every stored distance whose ``poi_id`` equals *poi_id*."""
        with Session(self.engine) as session:
            statement = select(POIDistance).where(POIDistance.poi_id == poi_id)
            return list(session.exec(statement).all())

    def delete_distances_for_poi(self, poi_id: int) -> int:
        """Delete all distances of a POI and return the statement's row count."""
        with Session(self.engine) as session:
            result = session.exec(
                delete(POIDistance).where(POIDistance.poi_id == poi_id)  # type: ignore[arg-type]
            )
            session.commit()
            return result.rowcount  # type: ignore[union-attr]

    def get_existing_distance_keys(
        self, poi_id: int, travel_mode: str, listing_type: ListingType
    ) -> set[int]:
        """Get listing IDs that already have computed distances for a POI+mode."""
        with Session(self.engine) as session:
            statement = select(POIDistance.listing_id).where(
                POIDistance.poi_id == poi_id,
                POIDistance.travel_mode == travel_mode,
                POIDistance.listing_type == listing_type,
            )
            return set(session.exec(statement).all())