"""FastAPI application for the Real Estate Crawler API."""

from datetime import datetime, timedelta
import json
import logging
import logging.config
from typing import Annotated, AsyncGenerator, Optional
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS, APP_ENV
from api.passkey_routes import passkey_router
from api.poi_routes import poi_router
from api.ws_routes import ws_router
from api.rate_limit_config import RateLimitConfig
from api.rate_limiter import RateLimitMiddleware
from api.audit_middleware import AuditLogMiddleware
from api.metrics_guard import MetricsGuardMiddleware
from api.security_headers import SecurityHeadersMiddleware
from api.origin_validator import OriginValidatorMiddleware
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.requests import Request
from api.auth import User
from models.listing import QueryParameters, ListingType, FurnishType
from notifications import send_notification
from repositories.listing_repository import ListingRepository
from database import engine
from fastapi.middleware.cors import CORSMiddleware
from ui_exporter import convert_to_geojson_feature, convert_row_to_geojson
from services import listing_service, export_service, district_service, task_service
from services.listing_cache import (
    get_cached_count,
    get_cached_features,
    begin_cache_population,
    cache_features_batch_staged,
    finalize_cache_population,
    delete_staging_key,
)
from repositories.poi_repository import POIRepository
from repositories.user_repository import UserRepository
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from api.metrics import metrics_app
from opentelemetry.metrics import get_meter

load_dotenv()

logger = logging.getLogger("uvicorn")

DEFAULT_BATCH_SIZE = 50

_rate_limit_config = RateLimitConfig.from_env()


def get_query_parameters(
    listing_type: ListingType,
    min_bedrooms: int = 1,
    max_bedrooms: int = 999,
    min_price: int = 0,
    max_price: int = 10_000_000,
    min_sqm: Optional[int] = None,
    max_sqm: Optional[int] = None,
    min_price_per_sqm: Optional[float] = None,
    max_price_per_sqm: Optional[float] = None,
    last_seen_days: Optional[int] = None,
    let_date_available_from: Optional[datetime] = None,
    furnish_types: Optional[str] = None,  # comma-separated list
    district_names: Optional[str] = None,  # comma-separated list
) -> QueryParameters:
    """Parse raw query-string values into a ``QueryParameters`` model.

    ``furnish_types`` and ``district_names`` arrive as comma-separated strings.
    Entries are stripped and blank entries (e.g. from a trailing comma) are
    ignored for both.

    Raises:
        HTTPException: 422 when ``furnish_types`` contains a value that
            ``FurnishType`` rejects, instead of letting the ``ValueError``
            escape as a 500.
    """
    parsed_furnish_types: list[FurnishType] | None = None
    if furnish_types:
        try:
            # "or None": a string made only of separators means "no filter",
            # not "filter on an empty list".
            parsed_furnish_types = [
                FurnishType(token.strip())
                for token in furnish_types.split(",")
                if token.strip()
            ] or None
        except ValueError as exc:
            raise HTTPException(
                status_code=422,
                detail=f"Invalid furnish_types value: {furnish_types!r}",
            ) from exc

    parsed_district_names: set[str] = set()
    if district_names:
        parsed_district_names = {
            d.strip() for d in district_names.split(",") if d.strip()
        }

    return QueryParameters(
        listing_type=listing_type,
        min_bedrooms=min_bedrooms,
        max_bedrooms=max_bedrooms,
        min_price=min_price,
        max_price=max_price,
        min_sqm=min_sqm,
        max_sqm=max_sqm,
        min_price_per_sqm=min_price_per_sqm,
        max_price_per_sqm=max_price_per_sqm,
        last_seen_days=last_seen_days,
        let_date_available_from=let_date_available_from,
        furnish_types=parsed_furnish_types,
        district_names=parsed_district_names,
    )


# Interactive API docs and the OpenAPI schema are disabled in production.
app = FastAPI(
    docs_url=None if APP_ENV == "production" else "/docs",
    redoc_url=None if APP_ENV == "production" else "/redoc",
    openapi_url=None if APP_ENV == "production" else "/openapi.json",
)
app.include_router(passkey_router)
app.include_router(poi_router)
app.include_router(ws_router)
app.mount("/metrics", metrics_app)

meter = get_meter(__name__)
# NOTE(review): both descriptions mention /hello, but in this file the
# instruments are only recorded from the /api/status handler.
request_counter = meter.create_counter(
    name="custom_request_count",
    description="Number of times /hello was called",
)
hist = meter.create_histogram(
    name="custom_request_duration",
    description="Duration of /hello requests in seconds",
)

# Allow CORS (for React frontend)
app.add_middleware(
    CORSMiddleware,
    allow_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["Authorization", "Content-Type"],
)
app.add_middleware(
    OriginValidatorMiddleware,
    allowed_origins=[*DEV_TIER_ORIGINS, *PROD_TIER_ORIGINS],
)

# Security middleware (added bottom-to-top; last added = outermost)
# 3. Rate limiting — enforces per-user limits
app.add_middleware(RateLimitMiddleware, config=_rate_limit_config)
# 2. Metrics guard — blocks unauthorized /metrics access
app.add_middleware(MetricsGuardMiddleware, config=_rate_limit_config)
# 1. Audit logging — logs everything including 429s and 403s
app.add_middleware(AuditLogMiddleware)
# 0. Security headers — adds standard security headers to all responses
app.add_middleware(SecurityHeadersMiddleware)


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log any otherwise-unhandled exception and return a generic 500 body."""
    logger.exception("Unhandled exception")
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"},
    )


@app.get("/api/status")
async def get_status() -> dict[str, str]:
    """Return a static OK payload and record the custom request metrics."""
    request_counter.add(1, {"method": "GET", "path": "/status"})
    # The recorded duration is a fixed placeholder value, not a measurement.
    hist.record(1.5, {"method": "GET", "path": "/status"})
    return {"status": "OK"}


@app.get("/api/listing")
async def get_listing(
    user: Annotated[User, Depends(get_current_user)],
    limit: int = 5,
) -> dict[str, list]:
    """Get listings from the database.

    ``limit`` is capped at the configured ``listing_limit_cap``.
    """
    limit = min(limit, _rate_limit_config.listing_limit_cap)
    repository = ListingRepository(engine)
    result = await listing_service.get_listings(repository, limit=limit)
    logger.info(f"Fetched {result.total_count} listings for {user.email}")
    return {"listings": result.listings}


@app.get("/api/listing_geojson")
async def get_listing_geojson(
    user: Annotated[User, Depends(get_current_user)],
    query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
    limit: int | None = None,
) -> dict:
    """Get listings as GeoJSON for map display.

    A missing ``limit`` defaults to the configured ``geojson_limit_cap``; an
    explicit one is capped at it.
    """
    if limit is not None:
        limit = min(limit, _rate_limit_config.geojson_limit_cap)
    else:
        limit = _rate_limit_config.geojson_limit_cap

    repository = ListingRepository(engine)
    result = await export_service.export_to_geojson(
        repository,
        query_parameters=query_parameters,
        limit=limit,
    )
    return result.data


def _build_poi_distances_lookup(
    user_email: str,
    listing_type: ListingType,
) -> dict[int, list[dict[str, str | int]]] | None:
    """Build POI distance lookup for a user, or None if no POIs configured.

    Returns a mapping of listing id to a list of distance records (POI id and
    name, travel mode, duration, distance). Returns ``None`` when the user
    cannot be resolved, has no POIs, or there are no listings of the given
    type.
    """
    user_repo = UserRepository(engine)
    db_user = user_repo.get_user_by_email(user_email)
    if not db_user or db_user.id is None:
        return None

    poi_repo = POIRepository(engine)
    pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
    if not pois:
        return None

    listing_repo = ListingRepository(engine)
    all_ids = list(listing_repo.get_listing_ids(listing_type))
    if not all_ids:
        return None

    distances = poi_repo.get_distances_for_listings(all_ids, listing_type, db_user.id)

    lookup: dict[int, list[dict[str, str | int]]] = {}
    for d in distances:
        # A distance row may reference a POI that is not in the user's set.
        poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
        lookup.setdefault(d.listing_id, []).append({
            "poi_id": d.poi_id,
            "poi_name": poi_name,
            "travel_mode": d.travel_mode,
            "duration_seconds": d.duration_seconds,
            "distance_meters": d.distance_meters,
        })
    return lookup


async def _stream_from_cache(
    query_parameters: QueryParameters,
    batch_size: int,
    limit: int | None,
) -> AsyncGenerator[str, None]:
    """Stream GeoJSON features from the Redis cache (cache-hit path).

    Yields NDJSON lines: one ``metadata`` message, zero or more ``batch``
    messages, and a final ``complete`` message. A falsy ``limit`` (``None`` or
    0) means no truncation.
    """
    cached_count = get_cached_count(query_parameters)
    effective_total = min(limit, cached_count) if limit and cached_count else cached_count

    yield json.dumps({
        "type": "metadata",
        "batch_size": batch_size,
        "total_expected": effective_total,
        "cached": True,
    }) + "\n"

    count = 0
    for feature_batch in get_cached_features(query_parameters, batch_size=batch_size):
        # Trim the batch that would cross the limit so exactly `limit` go out.
        if limit and count + len(feature_batch) > limit:
            feature_batch = feature_batch[:limit - count]
        count += len(feature_batch)
        yield json.dumps({"type": "batch", "features": feature_batch}) + "\n"
        if limit and count >= limit:
            break

    yield json.dumps({"type": "complete", "total": count}) + "\n"


async def _stream_from_db(
    query_parameters: QueryParameters,
    batch_size: int,
    limit: int | None,
    poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None,
    skip_cache: bool = False,
) -> AsyncGenerator[str, None]:
    """Stream GeoJSON features from the database, populating the cache as we go.

    Yields the same NDJSON message sequence as the cache path. Unless
    ``skip_cache`` is set, every emitted batch is also written under a staging
    key that is promoted to the live cache only after the stream completes.
    If the generator exits early, the staging key is deleted.
    """
    repository = ListingRepository(engine)
    total = repository.count_listings(query_parameters)
    effective_total = min(limit, total) if limit else total

    yield json.dumps({
        "type": "metadata",
        "batch_size": batch_size,
        "total_expected": effective_total,
        "cached": False,
    }) + "\n"

    staging_key: str | None = None
    if not skip_cache:
        staging_key = begin_cache_population(query_parameters)

    try:
        count = 0
        batch: list[dict] = []
        for row in repository.stream_listings_optimized(
            query_parameters, limit=limit, page_size=batch_size
        ):
            feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
            # Inject POI distances if available
            if poi_distances_lookup and row['id'] in poi_distances_lookup:
                feature['properties']['poi_distances'] = poi_distances_lookup[row['id']]
            batch.append(feature)
            count += 1

            if len(batch) >= batch_size:
                if staging_key:
                    cache_features_batch_staged(staging_key, batch)
                yield json.dumps({"type": "batch", "features": batch}) + "\n"
                batch = []

        # Flush the final, partially filled batch.
        if batch:
            if staging_key:
                cache_features_batch_staged(staging_key, batch)
            yield json.dumps({"type": "batch", "features": batch}) + "\n"

        # Atomically promote staged data to live cache
        if staging_key:
            finalize_cache_population(staging_key, query_parameters)
            staging_key = None  # Mark as finalized

        yield json.dumps({"type": "complete", "total": count}) + "\n"
    finally:
        # Clean up orphaned staging key on failure
        if staging_key:
            delete_staging_key(staging_key)


@app.get("/api/listing_geojson/stream")
async def stream_listing_geojson(
    user: Annotated[User, Depends(get_current_user)],
    query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
    batch_size: int = DEFAULT_BATCH_SIZE,
    limit: int | None = None,
    include_poi_distances: bool = False,
) -> StreamingResponse:
    """Stream listings as NDJSON for progressive map loading.

    Returns newline-delimited JSON with three message types:
    - metadata: Initial message with batch_size and total_expected count
    - batch: Array of GeoJSON features
    - complete: Final message with total count

    ``batch_size`` is clamped to ``[1, geojson_stream_batch_size_cap]``.
    ``limit`` is capped at ``geojson_stream_limit_cap``; a missing or
    non-positive ``limit`` falls back to that cap.
    """
    stream_cap = _rate_limit_config.geojson_stream_limit_cap

    # A non-positive batch size would make the streaming helpers emit
    # single-feature batches, so force at least one feature per batch.
    batch_size = max(1, min(batch_size, _rate_limit_config.geojson_stream_batch_size_cap))

    # The streaming helpers treat a falsy limit as "unlimited", so limit=0
    # would bypass the cap, and a negative limit breaks their slicing.
    if limit is None or limit <= 0:
        limit = stream_cap
    else:
        limit = min(limit, stream_cap)

    # Build POI distances lookup if requested
    poi_distances_lookup = (
        _build_poi_distances_lookup(user.email, query_parameters.listing_type)
        if include_poi_distances
        else None
    )

    cached_count = get_cached_count(query_parameters)
    if cached_count is not None and cached_count > 0 and not include_poi_distances:
        generator = _stream_from_cache(query_parameters, batch_size, limit)
    else:
        # The cache calls take only query_parameters, not the limit. A stream
        # cut short by a caller limit below the cap must not become the cached
        # result for later, larger requests.
        skip_cache = include_poi_distances or limit < stream_cap
        generator = _stream_from_db(
            query_parameters,
            batch_size,
            limit,
            poi_distances_lookup,
            skip_cache=skip_cache,
        )

    return StreamingResponse(
        generator,
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        }
    )


@app.post("/api/refresh_listings")
async def refresh_listings(
    user: Annotated[User, Depends(get_current_user)],
    query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
) -> dict[str, str]:
    """Trigger a background task to refresh listings.

    Sends a notification, starts the refresh in async mode, and records the
    resulting task id against the requesting user when one is returned.
    """
    await send_notification(
        f"{user.email} refreshing listings with query parameters {query_parameters.model_dump_json()}"
    )
    repository = ListingRepository(engine)
    result = await listing_service.refresh_listings(
        repository,
        query_parameters,
        async_mode=True,
        user_email=user.email,
    )

    # Track task for user
    if result.task_id:
        task_service.add_task_for_user(user.email, result.task_id)

    return {"task_id": result.task_id or "", "message": result.message}
@app.get("/api/task_status")
async def get_task_status(
    user: Annotated[User, Depends(get_current_user)],
    task_id: str,
) -> dict[str, str | int | float | None]:
    """Get the status of a background task.

    Raises:
        HTTPException: 404 when ``task_id`` is not in the caller's task list.
    """
    user_tasks = task_service.get_user_tasks(user.email)
    if task_id not in user_tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    status = task_service.get_task_status(task_id)
    return {
        "task_id": status.task_id,
        "status": status.status,
        "result": json.dumps(status.result) if status.result else None,
        "progress": status.progress,
        "processed": status.processed,
        "total": status.total,
        "message": status.message,
        # Error details are withheld from clients in production.
        "error": status.error if APP_ENV != "production" else None,
        "traceback": status.traceback if APP_ENV != "production" else None,
    }


@app.get("/api/tasks_for_user")
async def get_tasks_for_user(
    user: Annotated[User, Depends(get_current_user)],
) -> list[str]:
    """Get all task IDs for the current user."""
    return task_service.get_user_tasks(user.email)


@app.post("/api/cancel_task")
async def cancel_task(
    user: Annotated[User, Depends(get_current_user)],
    task_id: str = Query(..., description="The task ID to cancel"),
) -> dict[str, str | bool]:
    """Cancel a running task and remove it from the user's task list.

    A failed cancellation is reported in the response body with
    ``success: False`` rather than raised. In production the message is
    generic, matching how ``get_task_status`` withholds error details.

    Raises:
        HTTPException: 404 when ``task_id`` is not in the caller's task list.
    """
    # Verify user owns this task
    user_tasks = task_service.get_user_tasks(user.email)
    if task_id not in user_tasks:
        raise HTTPException(status_code=404, detail="Task not found or not owned by user")

    try:
        task_service.cancel_task(task_id, user_email=user.email)
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception(f"Failed to cancel task {task_id}: {e}")
        message = str(e) if APP_ENV != "production" else "Failed to cancel task"
        return {"success": False, "message": message}

    logger.info(f"Task {task_id} cancelled by {user.email}")
    return {"success": True, "message": "Task cancelled"}


@app.post("/api/clear_all_tasks")
async def clear_all_tasks(
    user: Annotated[User, Depends(get_current_user)],
) -> dict[str, str | int | bool]:
    """Clear all tasks for the current user.

    A failure is reported in the response body with ``success: False`` and
    ``count: 0``. In production the message is generic, matching how
    ``get_task_status`` withholds error details.
    """
    try:
        count = task_service.clear_all_tasks(user.email)
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception(f"Failed to clear tasks for {user.email}: {e}")
        message = str(e) if APP_ENV != "production" else "Failed to clear tasks"
        return {"success": False, "count": 0, "message": message}

    logger.info(f"Cleared {count} tasks for {user.email}")
    return {"success": True, "count": count, "message": f"Cleared {count} tasks"}


@app.get("/api/get_districts")
async def get_districts(
    user: Annotated[User, Depends(get_current_user)],
) -> dict[str, str]:
    """Get all available districts."""
    return district_service.get_all_districts()


FastAPIInstrumentor.instrument_app(app)