diff --git a/api/app.py b/api/app.py index edc191b..449e7b2 100644 --- a/api/app.py +++ b/api/app.py @@ -7,6 +7,7 @@ from typing import Annotated, AsyncGenerator, Optional from api.auth import get_current_user from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS from api.passkey_routes import passkey_router +from api.poi_routes import poi_router from api.rate_limit_config import RateLimitConfig from api.rate_limiter import RateLimitMiddleware from api.audit_middleware import AuditLogMiddleware @@ -28,6 +29,8 @@ from services.listing_cache import ( get_cached_features, cache_features_batch, ) +from repositories.poi_repository import POIRepository +from repositories.user_repository import UserRepository from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from api.metrics import metrics_app from opentelemetry.metrics import get_meter @@ -71,6 +74,7 @@ def get_query_parameters( app = FastAPI() app.include_router(passkey_router) +app.include_router(poi_router) app.mount("/metrics", metrics_app) meter = get_meter(__name__) request_counter = meter.create_counter( @@ -173,6 +177,7 @@ async def _stream_from_db( query_parameters: QueryParameters, batch_size: int, limit: int | None, + poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None, ) -> AsyncGenerator[str, None]: """Stream GeoJSON features from the database, populating the cache as we go.""" repository = ListingRepository(engine) @@ -193,6 +198,9 @@ async def _stream_from_db( query_parameters, limit=limit, page_size=batch_size ): feature = convert_row_to_geojson(row, query_parameters.listing_type.value) + # Inject POI distances if available + if poi_distances_lookup and row['id'] in poi_distances_lookup: + feature['properties']['poi_distances'] = poi_distances_lookup[row['id']] batch.append(feature) count += 1 @@ -214,6 +222,7 @@ async def stream_listing_geojson( query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)], batch_size: int = DEFAULT_BATCH_SIZE, limit: int | None = None, + include_poi_distances: bool = False, ) -> StreamingResponse: """Stream listings as NDJSON for progressive map loading. @@ -228,11 +237,39 @@ async def stream_listing_geojson( else: limit = _rate_limit_config.geojson_stream_limit_cap + # Build POI distances lookup if requested + poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None + if include_poi_distances: + user_repo = UserRepository(engine) + db_user = user_repo.get_user_by_email(user.email) + if db_user and db_user.id is not None: + poi_repo = POIRepository(engine) + pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)} + if pois: + # Get all listing IDs first for the query + listing_repo = ListingRepository(engine) + all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type)) + if all_ids: + distances = poi_repo.get_distances_for_listings( + all_ids, query_parameters.listing_type, db_user.id + ) + poi_distances_lookup = {} + for d in distances: + poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown" + entry = { + "poi_id": d.poi_id, + "poi_name": poi_name, + "travel_mode": d.travel_mode, + "duration_seconds": d.duration_seconds, + "distance_meters": d.distance_meters, + } + poi_distances_lookup.setdefault(d.listing_id, []).append(entry) + cached_count = get_cached_count(query_parameters) - if cached_count is not None and cached_count > 0: + if cached_count is not None and cached_count > 0 and not include_poi_distances: generator = _stream_from_cache(query_parameters, batch_size, limit) else: - generator = _stream_from_db(query_parameters, batch_size, limit) + generator = _stream_from_db(query_parameters, batch_size, limit, poi_distances_lookup) return StreamingResponse( generator, diff --git a/api/poi_routes.py b/api/poi_routes.py new file mode 100644 index 0000000..07b57ac --- /dev/null +++ b/api/poi_routes.py @@ -0,0 +1,200 @@ +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + +from api.auth import User, get_current_user +from database import engine +from models.listing import ListingType +from repositories.poi_repository import POIRepository +from repositories.user_repository import UserRepository +from services import poi_service, task_service + +logger = logging.getLogger("uvicorn") + +poi_router = APIRouter(prefix="/api/poi", tags=["poi"]) + + +class CreatePOIRequest(BaseModel): + name: str + address: str + latitude: float + longitude: float + + +class UpdatePOIRequest(BaseModel): + name: str | None = None + address: str | None = None + latitude: float | None = None + longitude: float | None = None + + +class POIResponse(BaseModel): + id: int + name: str + address: str + latitude: float + longitude: float + created_at: str + + +class CalculateRequest(BaseModel): + travel_modes: list[str] # WALK, BICYCLE, TRANSIT + listing_type: ListingType = ListingType.RENT + listing_ids: list[int] | None = None + + +class POIDistanceResponse(BaseModel): + poi_id: int + poi_name: str + travel_mode: str + duration_seconds: int + distance_meters: int + + +def _get_user_id(user: User) -> int: + """Resolve auth User to database user ID.""" + user_repo = UserRepository(engine) + db_user = user_repo.get_user_by_email(user.email) + if db_user is None: + # Auto-create user on first POI interaction + db_user = user_repo.create_user(user.email) + assert db_user.id is not None + return db_user.id + + +def _poi_to_response(poi: "poi_service.PointOfInterest") -> POIResponse: + return POIResponse( + id=poi.id, # type: ignore[arg-type] + name=poi.name, + address=poi.address, + latitude=poi.latitude, + longitude=poi.longitude, + created_at=poi.created_at.isoformat(), + ) + + +@poi_router.get("", response_model=list[POIResponse]) +async def list_pois( + user: Annotated[User, Depends(get_current_user)], +) -> list[POIResponse]: + """List all POIs for the current user.""" + user_id = _get_user_id(user) + repo = POIRepository(engine) + pois = poi_service.get_user_pois(repo, user_id) + return [_poi_to_response(p) for p in pois] + + +@poi_router.post("", response_model=POIResponse) +async def create_poi( + user: Annotated[User, Depends(get_current_user)], + body: CreatePOIRequest, +) -> POIResponse: + """Create a new POI.""" + user_id = _get_user_id(user) + repo = POIRepository(engine) + result = poi_service.create_poi( + repo, + user_id=user_id, + name=body.name, + address=body.address, + latitude=body.latitude, + longitude=body.longitude, + ) + return _poi_to_response(result.poi) + + +@poi_router.put("/{poi_id}", response_model=POIResponse) +async def update_poi( + user: Annotated[User, Depends(get_current_user)], + poi_id: int, + body: UpdatePOIRequest, +) -> POIResponse: + """Update an existing POI.""" + user_id = _get_user_id(user) + repo = POIRepository(engine) + result = poi_service.update_poi( + repo, + poi_id=poi_id, + user_id=user_id, + name=body.name, + address=body.address, + latitude=body.latitude, + longitude=body.longitude, + ) + if result is None: + raise HTTPException(status_code=404, detail="POI not found") + return _poi_to_response(result.poi) + + +@poi_router.delete("/{poi_id}") +async def delete_poi( + user: Annotated[User, Depends(get_current_user)], + poi_id: int, +) -> dict[str, bool]: + """Delete a POI and its associated distances.""" + user_id = _get_user_id(user) + repo = POIRepository(engine) + deleted = poi_service.delete_poi(repo, poi_id, user_id) + if not deleted: + raise HTTPException(status_code=404, detail="POI not found") + return {"success": True} + + +@poi_router.post("/{poi_id}/calculate") +async def calculate_distances( + user: Annotated[User, Depends(get_current_user)], + poi_id: int, + body: CalculateRequest, +) -> dict[str, str]: + """Trigger distance calculation for a POI.""" + user_id = _get_user_id(user) + repo = POIRepository(engine) + + # Verify POI exists and belongs to user + poi = poi_service.get_poi(repo, poi_id) + if poi is None or poi.user_id != user_id: + raise HTTPException(status_code=404, detail="POI not found") + + result = poi_service.trigger_calculation( + poi_id=poi_id, + travel_modes=body.travel_modes, + listing_type=body.listing_type, + user_email=user.email, + listing_ids=body.listing_ids, + ) + + if result.task_id: + task_service.add_task_for_user(user.email, result.task_id) + + return {"task_id": result.task_id or "", "message": result.message} + + +@poi_router.get("/distances") +async def get_distances( + user: Annotated[User, Depends(get_current_user)], + listing_id: int, + listing_type: ListingType = ListingType.RENT, +) -> list[POIDistanceResponse]: + """Get POI distances for a specific listing.""" + user_id = _get_user_id(user) + repo = POIRepository(engine) + poi_repo_pois = { + p.id: p for p in poi_service.get_user_pois(repo, user_id) + } + + distances = poi_service.get_distances_for_listing( + repo, listing_id, listing_type, user_id + ) + + return [ + POIDistanceResponse( + poi_id=d.poi_id, + poi_name=poi_repo_pois[d.poi_id].name if d.poi_id in poi_repo_pois else "Unknown", + travel_mode=d.travel_mode, + duration_seconds=d.duration_seconds, + distance_meters=d.distance_meters, + ) + for d in distances + ] diff --git a/celery_app.py b/celery_app.py index 5560df5..933c3cb 100644 --- a/celery_app.py +++ b/celery_app.py @@ -9,7 +9,7 @@ app = Celery( "celery_app", broker=os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"), backend=os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"), - include=["tasks.listing_tasks"], + include=["tasks.listing_tasks", "tasks.poi_tasks"], ) app.conf.update( diff --git a/tasks/poi_tasks.py b/tasks/poi_tasks.py new file mode 100644 index 0000000..b3416a4 --- /dev/null +++ b/tasks/poi_tasks.py @@ -0,0 +1,92 @@ +"""Celery tasks for POI distance calculation.""" +import asyncio +import logging +from typing import Any + +from celery import Task +from celery_app import app +from database import engine +from models.listing import ListingType +from repositories.listing_repository import ListingRepository +from repositories.poi_repository import POIRepository +from services.poi_distance_calculator import calculate_poi_distances + +logger = logging.getLogger("uvicorn.error") + +celery_logger = logging.getLogger("celery.task") +if not celery_logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter( + "%(asctime)s [%(levelname)s] %(name)s: %(message)s" + )) + celery_logger.addHandler(handler) + celery_logger.setLevel(logging.INFO) + + +@app.task(bind=True) +def calculate_poi_distances_task( + self: Task, + poi_id: int, + travel_modes: list[str], + listing_type: str, + listing_ids: list[int] | None = None, +) -> dict[str, Any]: + """Background task to calculate distances from listings to a POI. + + Args: + poi_id: ID of the PointOfInterest. + travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT). + listing_type: "BUY" or "RENT". + listing_ids: Optional subset of listing IDs. + """ + celery_logger.info( + f"Starting POI distance calculation: poi_id={poi_id}, " + f"modes={travel_modes}, type={listing_type}" + ) + + self.update_state(state="PROGRESS", meta={ + "phase": "starting", + "progress": 0, + "message": "Starting distance calculation...", + }) + + listing_repo = ListingRepository(engine) + poi_repo = POIRepository(engine) + + poi = poi_repo.get_poi_by_id(poi_id) + if poi is None: + celery_logger.error(f"POI {poi_id} not found") + return {"error": f"POI {poi_id} not found", "distances_computed": 0} + + lt = ListingType(listing_type) + + def on_progress(completed: int, total: int, message: str) -> None: + progress = round(completed / total, 2) if total > 0 else 0 + self.update_state(state="PROGRESS", meta={ + "phase": "computing", + "progress": progress, + "processed": completed, + "total": total, + "message": message, + }) + + total = asyncio.run( + calculate_poi_distances( + listing_repo=listing_repo, + poi_repo=poi_repo, + poi=poi, + travel_modes=travel_modes, + listing_type=lt, + listing_ids=listing_ids, + on_progress=on_progress, + ) + ) + + celery_logger.info(f"POI distance calculation complete: {total} distances computed") + + return { + "phase": "completed", + "progress": 1, + "distances_computed": total, + "message": f"Computed {total} distances for POI '{poi.name}'", + }