Add POI API routes and Celery task
FastAPI router with CRUD endpoints for POIs, distance calculation trigger, and distance queries. Streaming GeoJSON endpoint now accepts include_poi_distances=true to inject travel times into features. Celery task wraps the distance calculator with progress reporting.
This commit is contained in:
parent
da0a56895d
commit
bd788df9aa
4 changed files with 332 additions and 3 deletions
41
api/app.py
41
api/app.py
|
|
@ -7,6 +7,7 @@ from typing import Annotated, AsyncGenerator, Optional
|
||||||
from api.auth import get_current_user
|
from api.auth import get_current_user
|
||||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||||
from api.passkey_routes import passkey_router
|
from api.passkey_routes import passkey_router
|
||||||
|
from api.poi_routes import poi_router
|
||||||
from api.rate_limit_config import RateLimitConfig
|
from api.rate_limit_config import RateLimitConfig
|
||||||
from api.rate_limiter import RateLimitMiddleware
|
from api.rate_limiter import RateLimitMiddleware
|
||||||
from api.audit_middleware import AuditLogMiddleware
|
from api.audit_middleware import AuditLogMiddleware
|
||||||
|
|
@ -28,6 +29,8 @@ from services.listing_cache import (
|
||||||
get_cached_features,
|
get_cached_features,
|
||||||
cache_features_batch,
|
cache_features_batch,
|
||||||
)
|
)
|
||||||
|
from repositories.poi_repository import POIRepository
|
||||||
|
from repositories.user_repository import UserRepository
|
||||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||||
from api.metrics import metrics_app
|
from api.metrics import metrics_app
|
||||||
from opentelemetry.metrics import get_meter
|
from opentelemetry.metrics import get_meter
|
||||||
|
|
@ -71,6 +74,7 @@ def get_query_parameters(
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(passkey_router)
|
app.include_router(passkey_router)
|
||||||
|
app.include_router(poi_router)
|
||||||
app.mount("/metrics", metrics_app)
|
app.mount("/metrics", metrics_app)
|
||||||
meter = get_meter(__name__)
|
meter = get_meter(__name__)
|
||||||
request_counter = meter.create_counter(
|
request_counter = meter.create_counter(
|
||||||
|
|
@ -173,6 +177,7 @@ async def _stream_from_db(
|
||||||
query_parameters: QueryParameters,
|
query_parameters: QueryParameters,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
limit: int | None,
|
limit: int | None,
|
||||||
|
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Stream GeoJSON features from the database, populating the cache as we go."""
|
"""Stream GeoJSON features from the database, populating the cache as we go."""
|
||||||
repository = ListingRepository(engine)
|
repository = ListingRepository(engine)
|
||||||
|
|
@ -193,6 +198,9 @@ async def _stream_from_db(
|
||||||
query_parameters, limit=limit, page_size=batch_size
|
query_parameters, limit=limit, page_size=batch_size
|
||||||
):
|
):
|
||||||
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
|
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
|
||||||
|
# Inject POI distances if available
|
||||||
|
if poi_distances_lookup and row['id'] in poi_distances_lookup:
|
||||||
|
feature['properties']['poi_distances'] = poi_distances_lookup[row['id']]
|
||||||
batch.append(feature)
|
batch.append(feature)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
|
@ -214,6 +222,7 @@ async def stream_listing_geojson(
|
||||||
query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
|
query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
|
||||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
|
include_poi_distances: bool = False,
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Stream listings as NDJSON for progressive map loading.
|
"""Stream listings as NDJSON for progressive map loading.
|
||||||
|
|
||||||
|
|
@ -228,11 +237,39 @@ async def stream_listing_geojson(
|
||||||
else:
|
else:
|
||||||
limit = _rate_limit_config.geojson_stream_limit_cap
|
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||||
|
|
||||||
|
# Build POI distances lookup if requested
|
||||||
|
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None
|
||||||
|
if include_poi_distances:
|
||||||
|
user_repo = UserRepository(engine)
|
||||||
|
db_user = user_repo.get_user_by_email(user.email)
|
||||||
|
if db_user and db_user.id is not None:
|
||||||
|
poi_repo = POIRepository(engine)
|
||||||
|
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
|
||||||
|
if pois:
|
||||||
|
# Get all listing IDs first for the query
|
||||||
|
listing_repo = ListingRepository(engine)
|
||||||
|
all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type))
|
||||||
|
if all_ids:
|
||||||
|
distances = poi_repo.get_distances_for_listings(
|
||||||
|
all_ids, query_parameters.listing_type, db_user.id
|
||||||
|
)
|
||||||
|
poi_distances_lookup = {}
|
||||||
|
for d in distances:
|
||||||
|
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
|
||||||
|
entry = {
|
||||||
|
"poi_id": d.poi_id,
|
||||||
|
"poi_name": poi_name,
|
||||||
|
"travel_mode": d.travel_mode,
|
||||||
|
"duration_seconds": d.duration_seconds,
|
||||||
|
"distance_meters": d.distance_meters,
|
||||||
|
}
|
||||||
|
poi_distances_lookup.setdefault(d.listing_id, []).append(entry)
|
||||||
|
|
||||||
cached_count = get_cached_count(query_parameters)
|
cached_count = get_cached_count(query_parameters)
|
||||||
if cached_count is not None and cached_count > 0:
|
if cached_count is not None and cached_count > 0 and not include_poi_distances:
|
||||||
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
||||||
else:
|
else:
|
||||||
generator = _stream_from_db(query_parameters, batch_size, limit)
|
generator = _stream_from_db(query_parameters, batch_size, limit, poi_distances_lookup)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generator,
|
generator,
|
||||||
|
|
|
||||||
200
api/poi_routes.py
Normal file
200
api/poi_routes.py
Normal file
|
|
@ -0,0 +1,200 @@
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from api.auth import User, get_current_user
|
||||||
|
from database import engine
|
||||||
|
from models.listing import ListingType
|
||||||
|
from repositories.poi_repository import POIRepository
|
||||||
|
from repositories.user_repository import UserRepository
|
||||||
|
from services import poi_service, task_service
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
poi_router = APIRouter(prefix="/api/poi", tags=["poi"])
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePOIRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
address: str
|
||||||
|
latitude: float
|
||||||
|
longitude: float
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatePOIRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
address: str | None = None
|
||||||
|
latitude: float | None = None
|
||||||
|
longitude: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class POIResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
address: str
|
||||||
|
latitude: float
|
||||||
|
longitude: float
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class CalculateRequest(BaseModel):
|
||||||
|
travel_modes: list[str] # WALK, BICYCLE, TRANSIT
|
||||||
|
listing_type: ListingType = ListingType.RENT
|
||||||
|
listing_ids: list[int] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class POIDistanceResponse(BaseModel):
|
||||||
|
poi_id: int
|
||||||
|
poi_name: str
|
||||||
|
travel_mode: str
|
||||||
|
duration_seconds: int
|
||||||
|
distance_meters: int
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id(user: User) -> int:
|
||||||
|
"""Resolve auth User to database user ID."""
|
||||||
|
user_repo = UserRepository(engine)
|
||||||
|
db_user = user_repo.get_user_by_email(user.email)
|
||||||
|
if db_user is None:
|
||||||
|
# Auto-create user on first POI interaction
|
||||||
|
db_user = user_repo.create_user(user.email)
|
||||||
|
assert db_user.id is not None
|
||||||
|
return db_user.id
|
||||||
|
|
||||||
|
|
||||||
|
def _poi_to_response(poi: "poi_service.PointOfInterest") -> POIResponse:
|
||||||
|
return POIResponse(
|
||||||
|
id=poi.id, # type: ignore[arg-type]
|
||||||
|
name=poi.name,
|
||||||
|
address=poi.address,
|
||||||
|
latitude=poi.latitude,
|
||||||
|
longitude=poi.longitude,
|
||||||
|
created_at=poi.created_at.isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@poi_router.get("", response_model=list[POIResponse])
|
||||||
|
async def list_pois(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> list[POIResponse]:
|
||||||
|
"""List all POIs for the current user."""
|
||||||
|
user_id = _get_user_id(user)
|
||||||
|
repo = POIRepository(engine)
|
||||||
|
pois = poi_service.get_user_pois(repo, user_id)
|
||||||
|
return [_poi_to_response(p) for p in pois]
|
||||||
|
|
||||||
|
|
||||||
|
@poi_router.post("", response_model=POIResponse)
|
||||||
|
async def create_poi(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
body: CreatePOIRequest,
|
||||||
|
) -> POIResponse:
|
||||||
|
"""Create a new POI."""
|
||||||
|
user_id = _get_user_id(user)
|
||||||
|
repo = POIRepository(engine)
|
||||||
|
result = poi_service.create_poi(
|
||||||
|
repo,
|
||||||
|
user_id=user_id,
|
||||||
|
name=body.name,
|
||||||
|
address=body.address,
|
||||||
|
latitude=body.latitude,
|
||||||
|
longitude=body.longitude,
|
||||||
|
)
|
||||||
|
return _poi_to_response(result.poi)
|
||||||
|
|
||||||
|
|
||||||
|
@poi_router.put("/{poi_id}", response_model=POIResponse)
|
||||||
|
async def update_poi(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
poi_id: int,
|
||||||
|
body: UpdatePOIRequest,
|
||||||
|
) -> POIResponse:
|
||||||
|
"""Update an existing POI."""
|
||||||
|
user_id = _get_user_id(user)
|
||||||
|
repo = POIRepository(engine)
|
||||||
|
result = poi_service.update_poi(
|
||||||
|
repo,
|
||||||
|
poi_id=poi_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name=body.name,
|
||||||
|
address=body.address,
|
||||||
|
latitude=body.latitude,
|
||||||
|
longitude=body.longitude,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=404, detail="POI not found")
|
||||||
|
return _poi_to_response(result.poi)
|
||||||
|
|
||||||
|
|
||||||
|
@poi_router.delete("/{poi_id}")
|
||||||
|
async def delete_poi(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
poi_id: int,
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a POI and its associated distances."""
|
||||||
|
user_id = _get_user_id(user)
|
||||||
|
repo = POIRepository(engine)
|
||||||
|
deleted = poi_service.delete_poi(repo, poi_id, user_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="POI not found")
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
|
@poi_router.post("/{poi_id}/calculate")
|
||||||
|
async def calculate_distances(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
poi_id: int,
|
||||||
|
body: CalculateRequest,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Trigger distance calculation for a POI."""
|
||||||
|
user_id = _get_user_id(user)
|
||||||
|
repo = POIRepository(engine)
|
||||||
|
|
||||||
|
# Verify POI exists and belongs to user
|
||||||
|
poi = poi_service.get_poi(repo, poi_id)
|
||||||
|
if poi is None or poi.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=404, detail="POI not found")
|
||||||
|
|
||||||
|
result = poi_service.trigger_calculation(
|
||||||
|
poi_id=poi_id,
|
||||||
|
travel_modes=body.travel_modes,
|
||||||
|
listing_type=body.listing_type,
|
||||||
|
user_email=user.email,
|
||||||
|
listing_ids=body.listing_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.task_id:
|
||||||
|
task_service.add_task_for_user(user.email, result.task_id)
|
||||||
|
|
||||||
|
return {"task_id": result.task_id or "", "message": result.message}
|
||||||
|
|
||||||
|
|
||||||
|
@poi_router.get("/distances")
|
||||||
|
async def get_distances(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
listing_id: int,
|
||||||
|
listing_type: ListingType = ListingType.RENT,
|
||||||
|
) -> list[POIDistanceResponse]:
|
||||||
|
"""Get POI distances for a specific listing."""
|
||||||
|
user_id = _get_user_id(user)
|
||||||
|
repo = POIRepository(engine)
|
||||||
|
poi_repo_pois = {
|
||||||
|
p.id: p for p in poi_service.get_user_pois(repo, user_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
distances = poi_service.get_distances_for_listing(
|
||||||
|
repo, listing_id, listing_type, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
POIDistanceResponse(
|
||||||
|
poi_id=d.poi_id,
|
||||||
|
poi_name=poi_repo_pois[d.poi_id].name if d.poi_id in poi_repo_pois else "Unknown",
|
||||||
|
travel_mode=d.travel_mode,
|
||||||
|
duration_seconds=d.duration_seconds,
|
||||||
|
distance_meters=d.distance_meters,
|
||||||
|
)
|
||||||
|
for d in distances
|
||||||
|
]
|
||||||
|
|
@ -9,7 +9,7 @@ app = Celery(
|
||||||
"celery_app",
|
"celery_app",
|
||||||
broker=os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
broker=os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||||
backend=os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"),
|
backend=os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"),
|
||||||
include=["tasks.listing_tasks"],
|
include=["tasks.listing_tasks", "tasks.poi_tasks"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.conf.update(
|
app.conf.update(
|
||||||
|
|
|
||||||
92
tasks/poi_tasks.py
Normal file
92
tasks/poi_tasks.py
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
"""Celery tasks for POI distance calculation."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from celery import Task
|
||||||
|
from celery_app import app
|
||||||
|
from database import engine
|
||||||
|
from models.listing import ListingType
|
||||||
|
from repositories.listing_repository import ListingRepository
|
||||||
|
from repositories.poi_repository import POIRepository
|
||||||
|
from services.poi_distance_calculator import calculate_poi_distances
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
||||||
|
celery_logger = logging.getLogger("celery.task")
|
||||||
|
if not celery_logger.handlers:
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(logging.Formatter(
|
||||||
|
"%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||||
|
))
|
||||||
|
celery_logger.addHandler(handler)
|
||||||
|
celery_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True)
|
||||||
|
def calculate_poi_distances_task(
|
||||||
|
self: Task,
|
||||||
|
poi_id: int,
|
||||||
|
travel_modes: list[str],
|
||||||
|
listing_type: str,
|
||||||
|
listing_ids: list[int] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Background task to calculate distances from listings to a POI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
poi_id: ID of the PointOfInterest.
|
||||||
|
travel_modes: List of travel modes (WALK, BICYCLE, TRANSIT).
|
||||||
|
listing_type: "BUY" or "RENT".
|
||||||
|
listing_ids: Optional subset of listing IDs.
|
||||||
|
"""
|
||||||
|
celery_logger.info(
|
||||||
|
f"Starting POI distance calculation: poi_id={poi_id}, "
|
||||||
|
f"modes={travel_modes}, type={listing_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_state(state="PROGRESS", meta={
|
||||||
|
"phase": "starting",
|
||||||
|
"progress": 0,
|
||||||
|
"message": "Starting distance calculation...",
|
||||||
|
})
|
||||||
|
|
||||||
|
listing_repo = ListingRepository(engine)
|
||||||
|
poi_repo = POIRepository(engine)
|
||||||
|
|
||||||
|
poi = poi_repo.get_poi_by_id(poi_id)
|
||||||
|
if poi is None:
|
||||||
|
celery_logger.error(f"POI {poi_id} not found")
|
||||||
|
return {"error": f"POI {poi_id} not found", "distances_computed": 0}
|
||||||
|
|
||||||
|
lt = ListingType(listing_type)
|
||||||
|
|
||||||
|
def on_progress(completed: int, total: int, message: str) -> None:
|
||||||
|
progress = round(completed / total, 2) if total > 0 else 0
|
||||||
|
self.update_state(state="PROGRESS", meta={
|
||||||
|
"phase": "computing",
|
||||||
|
"progress": progress,
|
||||||
|
"processed": completed,
|
||||||
|
"total": total,
|
||||||
|
"message": message,
|
||||||
|
})
|
||||||
|
|
||||||
|
total = asyncio.run(
|
||||||
|
calculate_poi_distances(
|
||||||
|
listing_repo=listing_repo,
|
||||||
|
poi_repo=poi_repo,
|
||||||
|
poi=poi,
|
||||||
|
travel_modes=travel_modes,
|
||||||
|
listing_type=lt,
|
||||||
|
listing_ids=listing_ids,
|
||||||
|
on_progress=on_progress,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
celery_logger.info(f"POI distance calculation complete: {total} distances computed")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"phase": "completed",
|
||||||
|
"progress": 1,
|
||||||
|
"distances_computed": total,
|
||||||
|
"message": f"Computed {total} distances for POI '{poi.name}'",
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue