diff --git a/repositories/poi_repository.py b/repositories/poi_repository.py new file mode 100644 index 0000000..97187c2 --- /dev/null +++ b/repositories/poi_repository.py @@ -0,0 +1,105 @@ +from models.listing import ListingType +from models.poi import PointOfInterest +from models.poi_distance import POIDistance +from sqlalchemy import Engine, delete +from sqlmodel import Session, select + + +class POIRepository: + engine: Engine + + def __init__(self, engine: Engine) -> None: + self.engine = engine + + def get_pois_for_user(self, user_id: int) -> list[PointOfInterest]: + with Session(self.engine) as session: + statement = select(PointOfInterest).where( + PointOfInterest.user_id == user_id + ) + return list(session.exec(statement).all()) + + def get_poi_by_id(self, poi_id: int) -> PointOfInterest | None: + with Session(self.engine) as session: + return session.get(PointOfInterest, poi_id) + + def create_poi(self, poi: PointOfInterest) -> PointOfInterest: + with Session(self.engine) as session: + session.add(poi) + session.commit() + session.refresh(poi) + return poi + + def update_poi(self, poi: PointOfInterest) -> PointOfInterest: + with Session(self.engine) as session: + session.merge(poi) + session.commit() + # Re-fetch to get the refreshed state + updated = session.get(PointOfInterest, poi.id) + assert updated is not None + return updated + + def delete_poi(self, poi_id: int) -> bool: + with Session(self.engine) as session: + poi = session.get(PointOfInterest, poi_id) + if poi is None: + return False + # Cascade: delete associated distances + session.exec( + delete(POIDistance).where(POIDistance.poi_id == poi_id) # type: ignore[arg-type] + ) + session.delete(poi) + session.commit() + return True + + def upsert_distances(self, distances: list[POIDistance]) -> None: + with Session(self.engine) as session: + for dist in distances: + session.merge(dist) + session.commit() + + def get_distances_for_listings( + self, + listing_ids: list[int], + listing_type: ListingType, + user_id: int, + ) -> list[POIDistance]: + with Session(self.engine) as session: + # Join with POI to filter by user + statement = ( + select(POIDistance) + .join(PointOfInterest, POIDistance.poi_id == PointOfInterest.id) + .where( + POIDistance.listing_id.in_(listing_ids), # type: ignore[union-attr] + POIDistance.listing_type == listing_type, + PointOfInterest.user_id == user_id, + ) + ) + return list(session.exec(statement).all()) + + def get_distances_for_poi(self, poi_id: int) -> list[POIDistance]: + with Session(self.engine) as session: + statement = select(POIDistance).where(POIDistance.poi_id == poi_id) + return list(session.exec(statement).all()) + + def delete_distances_for_poi(self, poi_id: int) -> int: + with Session(self.engine) as session: + result = session.exec( + delete(POIDistance).where(POIDistance.poi_id == poi_id) # type: ignore[arg-type] + ) + session.commit() + return result.rowcount # type: ignore[union-attr] + + def get_existing_distance_keys( + self, poi_id: int, travel_mode: str, listing_type: ListingType + ) -> set[int]: + """Get listing IDs that already have computed distances for a POI+mode.""" + with Session(self.engine) as session: + statement = ( + select(POIDistance.listing_id) + .where( + POIDistance.poi_id == poi_id, + POIDistance.travel_mode == travel_mode, + POIDistance.listing_type == listing_type, + ) + ) + return {row for row in session.exec(statement).all()} diff --git a/services/poi_service.py b/services/poi_service.py new file mode 100644 index 0000000..39c133b --- /dev/null +++ b/services/poi_service.py @@ -0,0 +1,126 @@ +"""Unified POI service - shared between CLI and HTTP API. + +This module provides the core business logic for POI operations. +Both the CLI (main.py) and HTTP API (api/poi_routes.py) should use these functions. +""" +from dataclasses import dataclass + +from models.listing import ListingType +from models.poi import PointOfInterest +from models.poi_distance import POIDistance +from repositories.poi_repository import POIRepository + + +@dataclass +class POIResult: + """Result of a POI operation.""" + poi: PointOfInterest + message: str | None = None + + +@dataclass +class CalculateResult: + """Result of a distance calculation.""" + task_id: str | None # None if run synchronously + distances_computed: int + message: str + + +def get_user_pois(repository: POIRepository, user_id: int) -> list[PointOfInterest]: + """Get all POIs for a user.""" + return repository.get_pois_for_user(user_id) + + +def get_poi(repository: POIRepository, poi_id: int) -> PointOfInterest | None: + """Get a single POI by ID.""" + return repository.get_poi_by_id(poi_id) + + +def create_poi( + repository: POIRepository, + user_id: int, + name: str, + address: str, + latitude: float, + longitude: float, +) -> POIResult: + """Create a new POI for a user.""" + poi = PointOfInterest( + user_id=user_id, + name=name, + address=address, + latitude=latitude, + longitude=longitude, + ) # type: ignore[call-arg] + created = repository.create_poi(poi) + return POIResult(poi=created, message=f"Created POI '{name}'") + + +def update_poi( + repository: POIRepository, + poi_id: int, + user_id: int, + name: str | None = None, + address: str | None = None, + latitude: float | None = None, + longitude: float | None = None, +) -> POIResult | None: + """Update an existing POI. Returns None if not found or not owned by user.""" + poi = repository.get_poi_by_id(poi_id) + if poi is None or poi.user_id != user_id: + return None + + if name is not None: + poi.name = name + if address is not None: + poi.address = address + if latitude is not None: + poi.latitude = latitude + if longitude is not None: + poi.longitude = longitude + + updated = repository.update_poi(poi) + return POIResult(poi=updated, message=f"Updated POI '{updated.name}'") + + +def delete_poi(repository: POIRepository, poi_id: int, user_id: int) -> bool: + """Delete a POI. Returns False if not found or not owned by user.""" + poi = repository.get_poi_by_id(poi_id) + if poi is None or poi.user_id != user_id: + return False + return repository.delete_poi(poi_id) + + +def get_distances_for_listing( + repository: POIRepository, + listing_id: int, + listing_type: ListingType, + user_id: int, +) -> list[POIDistance]: + """Get all POI distances for a specific listing.""" + return repository.get_distances_for_listings( + [listing_id], listing_type, user_id + ) + + +def trigger_calculation( + poi_id: int, + travel_modes: list[str], + listing_type: ListingType, + user_email: str, + listing_ids: list[int] | None = None, +) -> CalculateResult: + """Trigger a background distance calculation task.""" + from tasks.poi_tasks import calculate_poi_distances_task + + task = calculate_poi_distances_task.delay( + poi_id=poi_id, + travel_modes=travel_modes, + listing_type=listing_type.value, + listing_ids=listing_ids, + ) + return CalculateResult( + task_id=task.id, + distances_computed=0, + message=f"Task {task.id} started for POI {poi_id}", + )