Add POI repository and service layer
POIRepository handles all database operations for POIs and distances including upsert, cascading delete, and skip-on-recompute via get_existing_distance_keys(). POI service provides unified high-level functions shared by both CLI and API.
This commit is contained in:
parent
5783d8fae9
commit
8a31e5449c
2 changed files with 231 additions and 0 deletions
105
repositories/poi_repository.py
Normal file
105
repositories/poi_repository.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
from models.listing import ListingType
|
||||
from models.poi import PointOfInterest
|
||||
from models.poi_distance import POIDistance
|
||||
from sqlalchemy import Engine, delete
|
||||
from sqlmodel import Session, select
|
||||
|
||||
|
||||
class POIRepository:
|
||||
engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine) -> None:
|
||||
self.engine = engine
|
||||
|
||||
def get_pois_for_user(self, user_id: int) -> list[PointOfInterest]:
|
||||
with Session(self.engine) as session:
|
||||
statement = select(PointOfInterest).where(
|
||||
PointOfInterest.user_id == user_id
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
def get_poi_by_id(self, poi_id: int) -> PointOfInterest | None:
|
||||
with Session(self.engine) as session:
|
||||
return session.get(PointOfInterest, poi_id)
|
||||
|
||||
def create_poi(self, poi: PointOfInterest) -> PointOfInterest:
|
||||
with Session(self.engine) as session:
|
||||
session.add(poi)
|
||||
session.commit()
|
||||
session.refresh(poi)
|
||||
return poi
|
||||
|
||||
def update_poi(self, poi: PointOfInterest) -> PointOfInterest:
|
||||
with Session(self.engine) as session:
|
||||
session.merge(poi)
|
||||
session.commit()
|
||||
# Re-fetch to get the refreshed state
|
||||
updated = session.get(PointOfInterest, poi.id)
|
||||
assert updated is not None
|
||||
return updated
|
||||
|
||||
def delete_poi(self, poi_id: int) -> bool:
|
||||
with Session(self.engine) as session:
|
||||
poi = session.get(PointOfInterest, poi_id)
|
||||
if poi is None:
|
||||
return False
|
||||
# Cascade: delete associated distances
|
||||
session.exec(
|
||||
delete(POIDistance).where(POIDistance.poi_id == poi_id) # type: ignore[arg-type]
|
||||
)
|
||||
session.delete(poi)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def upsert_distances(self, distances: list[POIDistance]) -> None:
|
||||
with Session(self.engine) as session:
|
||||
for dist in distances:
|
||||
session.merge(dist)
|
||||
session.commit()
|
||||
|
||||
def get_distances_for_listings(
|
||||
self,
|
||||
listing_ids: list[int],
|
||||
listing_type: ListingType,
|
||||
user_id: int,
|
||||
) -> list[POIDistance]:
|
||||
with Session(self.engine) as session:
|
||||
# Join with POI to filter by user
|
||||
statement = (
|
||||
select(POIDistance)
|
||||
.join(PointOfInterest, POIDistance.poi_id == PointOfInterest.id)
|
||||
.where(
|
||||
POIDistance.listing_id.in_(listing_ids), # type: ignore[union-attr]
|
||||
POIDistance.listing_type == listing_type,
|
||||
PointOfInterest.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
def get_distances_for_poi(self, poi_id: int) -> list[POIDistance]:
|
||||
with Session(self.engine) as session:
|
||||
statement = select(POIDistance).where(POIDistance.poi_id == poi_id)
|
||||
return list(session.exec(statement).all())
|
||||
|
||||
def delete_distances_for_poi(self, poi_id: int) -> int:
|
||||
with Session(self.engine) as session:
|
||||
result = session.exec(
|
||||
delete(POIDistance).where(POIDistance.poi_id == poi_id) # type: ignore[arg-type]
|
||||
)
|
||||
session.commit()
|
||||
return result.rowcount # type: ignore[union-attr]
|
||||
|
||||
def get_existing_distance_keys(
|
||||
self, poi_id: int, travel_mode: str, listing_type: ListingType
|
||||
) -> set[int]:
|
||||
"""Get listing IDs that already have computed distances for a POI+mode."""
|
||||
with Session(self.engine) as session:
|
||||
statement = (
|
||||
select(POIDistance.listing_id)
|
||||
.where(
|
||||
POIDistance.poi_id == poi_id,
|
||||
POIDistance.travel_mode == travel_mode,
|
||||
POIDistance.listing_type == listing_type,
|
||||
)
|
||||
)
|
||||
return {row for row in session.exec(statement).all()}
|
||||
126
services/poi_service.py
Normal file
126
services/poi_service.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""Unified POI service - shared between CLI and HTTP API.
|
||||
|
||||
This module provides the core business logic for POI operations.
|
||||
Both the CLI (main.py) and HTTP API (api/poi_routes.py) should use these functions.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from models.listing import ListingType
|
||||
from models.poi import PointOfInterest
|
||||
from models.poi_distance import POIDistance
|
||||
from repositories.poi_repository import POIRepository
|
||||
|
||||
|
||||
@dataclass
|
||||
class POIResult:
|
||||
"""Result of a POI operation."""
|
||||
poi: PointOfInterest
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CalculateResult:
|
||||
"""Result of a distance calculation."""
|
||||
task_id: str | None # None if run synchronously
|
||||
distances_computed: int
|
||||
message: str
|
||||
|
||||
|
||||
def get_user_pois(repository: POIRepository, user_id: int) -> list[PointOfInterest]:
|
||||
"""Get all POIs for a user."""
|
||||
return repository.get_pois_for_user(user_id)
|
||||
|
||||
|
||||
def get_poi(repository: POIRepository, poi_id: int) -> PointOfInterest | None:
|
||||
"""Get a single POI by ID."""
|
||||
return repository.get_poi_by_id(poi_id)
|
||||
|
||||
|
||||
def create_poi(
|
||||
repository: POIRepository,
|
||||
user_id: int,
|
||||
name: str,
|
||||
address: str,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
) -> POIResult:
|
||||
"""Create a new POI for a user."""
|
||||
poi = PointOfInterest(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
address=address,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
) # type: ignore[call-arg]
|
||||
created = repository.create_poi(poi)
|
||||
return POIResult(poi=created, message=f"Created POI '{name}'")
|
||||
|
||||
|
||||
def update_poi(
|
||||
repository: POIRepository,
|
||||
poi_id: int,
|
||||
user_id: int,
|
||||
name: str | None = None,
|
||||
address: str | None = None,
|
||||
latitude: float | None = None,
|
||||
longitude: float | None = None,
|
||||
) -> POIResult | None:
|
||||
"""Update an existing POI. Returns None if not found or not owned by user."""
|
||||
poi = repository.get_poi_by_id(poi_id)
|
||||
if poi is None or poi.user_id != user_id:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
poi.name = name
|
||||
if address is not None:
|
||||
poi.address = address
|
||||
if latitude is not None:
|
||||
poi.latitude = latitude
|
||||
if longitude is not None:
|
||||
poi.longitude = longitude
|
||||
|
||||
updated = repository.update_poi(poi)
|
||||
return POIResult(poi=updated, message=f"Updated POI '{updated.name}'")
|
||||
|
||||
|
||||
def delete_poi(repository: POIRepository, poi_id: int, user_id: int) -> bool:
|
||||
"""Delete a POI. Returns False if not found or not owned by user."""
|
||||
poi = repository.get_poi_by_id(poi_id)
|
||||
if poi is None or poi.user_id != user_id:
|
||||
return False
|
||||
return repository.delete_poi(poi_id)
|
||||
|
||||
|
||||
def get_distances_for_listing(
|
||||
repository: POIRepository,
|
||||
listing_id: int,
|
||||
listing_type: ListingType,
|
||||
user_id: int,
|
||||
) -> list[POIDistance]:
|
||||
"""Get all POI distances for a specific listing."""
|
||||
return repository.get_distances_for_listings(
|
||||
[listing_id], listing_type, user_id
|
||||
)
|
||||
|
||||
|
||||
def trigger_calculation(
|
||||
poi_id: int,
|
||||
travel_modes: list[str],
|
||||
listing_type: ListingType,
|
||||
user_email: str,
|
||||
listing_ids: list[int] | None = None,
|
||||
) -> CalculateResult:
|
||||
"""Trigger a background distance calculation task."""
|
||||
from tasks.poi_tasks import calculate_poi_distances_task
|
||||
|
||||
task = calculate_poi_distances_task.delay(
|
||||
poi_id=poi_id,
|
||||
travel_modes=travel_modes,
|
||||
listing_type=listing_type.value,
|
||||
listing_ids=listing_ids,
|
||||
)
|
||||
return CalculateResult(
|
||||
task_id=task.id,
|
||||
distances_computed=0,
|
||||
message=f"Task {task.id} started for POI {poi_id}",
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue