Flatten repo structure: move crawler/ to root, remove vqa/ and immoweb/
The crawler subdirectory was the only active project. Moving it to the repo root simplifies paths and removes the unnecessary nesting. The vqa/ and immoweb/ directories were legacy/unused and have been removed. Updated .drone.yml, .gitignore, .claude/ docs, and skills to reflect the new flat structure.
This commit is contained in:
parent
e2247be700
commit
eafbc1ac52
221 changed files with 70 additions and 146140 deletions
41
services/__init__.py
Normal file
41
services/__init__.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Services package for real estate crawler.
|
||||
|
||||
This package contains two layers of services:
|
||||
|
||||
## Low-level services (internal implementation):
|
||||
- listing_fetcher: Fetches listing data from Rightmove API
|
||||
- image_fetcher: Downloads floorplan images
|
||||
- floorplan_detector: OCR-based square meter detection from floorplans
|
||||
- route_calculator: Calculates transit routes using Google Maps API
|
||||
|
||||
## High-level services (use these in CLI and API):
|
||||
- listing_service: Unified listing operations (get, refresh, download images, etc.)
|
||||
- export_service: Export listings to CSV, GeoJSON
|
||||
- district_service: District lookup and validation
|
||||
- task_service: Background task management
|
||||
"""
|
||||
# Low-level services (internal)
|
||||
from services.listing_fetcher import dump_listings, dump_listings_full
|
||||
from services.image_fetcher import dump_images
|
||||
from services.floorplan_detector import detect_floorplan
|
||||
from services.route_calculator import calculate_route
|
||||
|
||||
# High-level services (CLI and API should use these)
|
||||
from services import listing_service
|
||||
from services import export_service
|
||||
from services import district_service
|
||||
from services import task_service
|
||||
|
||||
__all__ = [
|
||||
# Low-level
|
||||
"dump_listings",
|
||||
"dump_listings_full",
|
||||
"dump_images",
|
||||
"detect_floorplan",
|
||||
"calculate_route",
|
||||
# High-level
|
||||
"listing_service",
|
||||
"export_service",
|
||||
"district_service",
|
||||
"task_service",
|
||||
]
|
||||
37
services/district_service.py
Normal file
37
services/district_service.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Unified district service - shared between CLI and HTTP API."""
|
||||
from rec.districts import get_districts as _get_districts
|
||||
|
||||
|
||||
def get_all_districts() -> dict[str, str]:
|
||||
"""Get all available districts with their region IDs.
|
||||
|
||||
Used by:
|
||||
- CLI: --district option choices
|
||||
- API: GET /api/get_districts
|
||||
|
||||
Returns:
|
||||
Dictionary mapping district names to region IDs
|
||||
"""
|
||||
return _get_districts()
|
||||
|
||||
|
||||
def get_district_names() -> list[str]:
|
||||
"""Get list of all district names.
|
||||
|
||||
Returns:
|
||||
List of district names
|
||||
"""
|
||||
return list(_get_districts().keys())
|
||||
|
||||
|
||||
def validate_districts(district_names: list[str]) -> list[str]:
|
||||
"""Validate that district names exist.
|
||||
|
||||
Args:
|
||||
district_names: List of district names to validate
|
||||
|
||||
Returns:
|
||||
List of invalid district names (empty if all valid)
|
||||
"""
|
||||
valid_districts = set(_get_districts().keys())
|
||||
return [d for d in district_names if d not in valid_districts]
|
||||
92
services/export_service.py
Normal file
92
services/export_service.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
"""Unified export service - shared between CLI and HTTP API.
|
||||
|
||||
This module provides export functionality for listings in various formats.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from models.listing import QueryParameters
|
||||
from repositories.listing_repository import ListingRepository
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportResult:
|
||||
"""Result of an export operation."""
|
||||
success: bool
|
||||
output_path: str | None # For file exports
|
||||
data: Any | None # For in-memory exports (GeoJSON)
|
||||
record_count: int
|
||||
message: str
|
||||
|
||||
|
||||
async def export_to_csv(
|
||||
repository: ListingRepository,
|
||||
output_path: Path,
|
||||
query_parameters: QueryParameters | None = None,
|
||||
) -> ExportResult:
|
||||
"""Export listings to CSV file.
|
||||
|
||||
Used by:
|
||||
- CLI: export-csv
|
||||
- API: (could be added as download endpoint)
|
||||
"""
|
||||
from csv_exporter import export_to_csv as _export_csv
|
||||
|
||||
await _export_csv(repository, output_path, query_parameters)
|
||||
|
||||
listings = await repository.get_listings(query_parameters=query_parameters)
|
||||
return ExportResult(
|
||||
success=True,
|
||||
output_path=str(output_path),
|
||||
data=None,
|
||||
record_count=len(listings),
|
||||
message=f"Exported {len(listings)} listings to {output_path}",
|
||||
)
|
||||
|
||||
|
||||
async def export_to_geojson(
|
||||
repository: ListingRepository,
|
||||
query_parameters: QueryParameters | None = None,
|
||||
output_path: Path | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ExportResult:
|
||||
"""Export listings to GeoJSON format.
|
||||
|
||||
Args:
|
||||
repository: Database repository
|
||||
query_parameters: Filtering parameters
|
||||
output_path: If provided, write to file. Otherwise return data.
|
||||
limit: Maximum number of listings to export
|
||||
|
||||
Used by:
|
||||
- CLI: export-immoweb
|
||||
- API: GET /api/listing_geojson
|
||||
"""
|
||||
from ui_exporter import export_immoweb
|
||||
|
||||
geojson_data = await export_immoweb(
|
||||
repository,
|
||||
output_file=str(output_path) if output_path else None,
|
||||
query_parameters=query_parameters,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
feature_count = len(geojson_data.get("features", [])) if geojson_data else 0
|
||||
|
||||
if output_path:
|
||||
return ExportResult(
|
||||
success=True,
|
||||
output_path=str(output_path),
|
||||
data=None,
|
||||
record_count=feature_count,
|
||||
message=f"Exported {feature_count} listings to {output_path}",
|
||||
)
|
||||
|
||||
return ExportResult(
|
||||
success=True,
|
||||
output_path=None,
|
||||
data=geojson_data,
|
||||
record_count=feature_count,
|
||||
message=f"Generated GeoJSON with {feature_count} features",
|
||||
)
|
||||
47
services/floorplan_detector.py
Normal file
47
services/floorplan_detector.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""Floorplan detector service - OCR-based square meter detection."""
|
||||
import asyncio
|
||||
from models import Listing
|
||||
from rec import floorplan
|
||||
from repositories.listing_repository import ListingRepository
|
||||
from tqdm.asyncio import tqdm
|
||||
import multiprocessing
|
||||
|
||||
# Use a quarter of available CPUs to avoid starving other processes
|
||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||
|
||||
|
||||
async def detect_floorplan(repository: ListingRepository) -> None:
|
||||
"""Detect square meters from floorplan images for all listings."""
|
||||
listings = await repository.get_listings()
|
||||
semaphore = asyncio.Semaphore(MAX_OCR_WORKERS)
|
||||
|
||||
updated_listings = [
|
||||
listing
|
||||
for listing in await tqdm.gather(
|
||||
*[_calculate_sqm_ocr(listing, semaphore) for listing in listings]
|
||||
)
|
||||
if listing is not None
|
||||
]
|
||||
await repository.upsert_listings(updated_listings)
|
||||
|
||||
|
||||
async def _calculate_sqm_ocr(
|
||||
listing: Listing, semaphore: asyncio.Semaphore
|
||||
) -> Listing | None:
|
||||
"""Calculate square meters from floorplan images using OCR."""
|
||||
if listing.square_meters is not None:
|
||||
return None
|
||||
if not listing.floorplan_image_paths:
|
||||
listing.square_meters = 0
|
||||
return listing
|
||||
sqms: list[float] = []
|
||||
for floorplan_path in listing.floorplan_image_paths:
|
||||
async with semaphore:
|
||||
estimated_sqm, _ = await asyncio.to_thread(
|
||||
floorplan.calculate_ocr, floorplan_path
|
||||
)
|
||||
if estimated_sqm is not None:
|
||||
sqms.append(estimated_sqm)
|
||||
max_sqm = max(sqms, default=0) # try once, if we fail, keep as 0
|
||||
listing.square_meters = max_sqm
|
||||
return listing
|
||||
88
services/image_fetcher.py
Normal file
88
services/image_fetcher.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""Image fetcher service - downloads floorplan images for listings."""
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from repositories import ListingRepository
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from models import Listing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum number of concurrent image downloads.
|
||||
# Setting this too high either crashes Rightmove or gets us blocked.
|
||||
MAX_CONCURRENT_DOWNLOADS = 5
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)
|
||||
|
||||
|
||||
async def dump_images(
|
||||
repository: ListingRepository,
|
||||
image_base_path: Path = Path("data/rs/"),
|
||||
) -> None:
|
||||
"""Download floorplan images for all listings."""
|
||||
listings = await repository.get_listings()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
updated_listings = await tqdm.gather(
|
||||
*[
|
||||
dump_images_for_listing(listing, image_base_path, session=session)
|
||||
for listing in listings
|
||||
]
|
||||
)
|
||||
await repository.upsert_listings(
|
||||
[listing for listing in updated_listings if listing is not None]
|
||||
)
|
||||
|
||||
|
||||
@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3))
|
||||
async def dump_images_for_listing(
|
||||
listing: Listing,
|
||||
base_path: Path,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> Listing | None:
|
||||
"""Download floorplan images for a single listing."""
|
||||
all_floorplans = listing.additional_info.get("property", {}).get("floorplans", [])
|
||||
for floorplan in all_floorplans:
|
||||
url = floorplan["url"]
|
||||
picname = Path(urlparse(url).path).name
|
||||
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
|
||||
if floorplan_path.exists():
|
||||
continue
|
||||
try:
|
||||
owns_session = session is None
|
||||
active_session = session or aiohttp.ClientSession()
|
||||
try:
|
||||
async with semaphore:
|
||||
async with active_session.get(url) as response:
|
||||
if response.status == 404:
|
||||
logger.warning(
|
||||
"Listing %s: floorplan not found (404) at %s",
|
||||
listing.id,
|
||||
url,
|
||||
)
|
||||
return None
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Error downloading floorplan for listing {listing.id} "
|
||||
f"from {url}: HTTP {response.status}"
|
||||
)
|
||||
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(floorplan_path, "wb") as f:
|
||||
f.write(await response.read())
|
||||
listing.floorplan_image_paths.append(str(floorplan_path))
|
||||
return listing
|
||||
finally:
|
||||
if owns_session:
|
||||
await active_session.close()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Listing %s: error downloading floorplan from %s: %s",
|
||||
listing.id,
|
||||
url,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
return None
|
||||
103
services/listing_cache.py
Normal file
103
services/listing_cache.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Redis-based caching for listing GeoJSON query results."""
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Generator
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import redis
|
||||
|
||||
from models.listing import QueryParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_PREFIX = "listings:geojson:"
|
||||
CACHE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
CACHE_DB = 2
|
||||
|
||||
|
||||
def _get_redis_client() -> redis.Redis:
|
||||
"""Get Redis client using Celery broker URL but overriding to db=2."""
|
||||
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
parsed = urlparse(broker_url)
|
||||
cache_url = urlunparse(parsed._replace(path=f"/{CACHE_DB}"))
|
||||
return redis.from_url(cache_url, decode_responses=True)
|
||||
|
||||
|
||||
def make_cache_key(query_params: QueryParameters) -> str:
|
||||
"""Generate a cache key from query parameters."""
|
||||
params_json = query_params.model_dump_json()
|
||||
hash_suffix = hashlib.sha256(params_json.encode()).hexdigest()[:16]
|
||||
return f"{CACHE_PREFIX}{hash_suffix}"
|
||||
|
||||
|
||||
def get_cached_count(query_params: QueryParameters) -> int | None:
|
||||
"""Return the number of cached features for a query, or None if not cached."""
|
||||
try:
|
||||
client = _get_redis_client()
|
||||
key = make_cache_key(query_params)
|
||||
if not client.exists(key):
|
||||
return None
|
||||
return client.llen(key)
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Redis cache read error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_cached_features(
|
||||
query_params: QueryParameters, batch_size: int = 50
|
||||
) -> Generator[list[dict], None, None]:
|
||||
"""Yield batches of cached GeoJSON features."""
|
||||
try:
|
||||
client = _get_redis_client()
|
||||
key = make_cache_key(query_params)
|
||||
total = client.llen(key)
|
||||
|
||||
for start in range(0, total, batch_size):
|
||||
end = start + batch_size - 1
|
||||
items = client.lrange(key, start, end)
|
||||
batch = [json.loads(item) for item in items]
|
||||
if batch:
|
||||
yield batch
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Redis cache read error during streaming: {e}")
|
||||
|
||||
|
||||
def cache_features_batch(query_params: QueryParameters, features: list[dict]) -> None:
|
||||
"""Append a batch of features to the cache list."""
|
||||
if not features:
|
||||
return
|
||||
try:
|
||||
client = _get_redis_client()
|
||||
key = make_cache_key(query_params)
|
||||
pipeline = client.pipeline()
|
||||
for feature in features:
|
||||
pipeline.rpush(key, json.dumps(feature))
|
||||
# Set/refresh TTL
|
||||
pipeline.expire(key, CACHE_TTL_SECONDS)
|
||||
pipeline.execute()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Redis cache write error: {e}")
|
||||
|
||||
|
||||
def invalidate_cache() -> None:
|
||||
"""Delete all listing GeoJSON cache entries."""
|
||||
try:
|
||||
client = _get_redis_client()
|
||||
cursor = 0
|
||||
deleted = 0
|
||||
while True:
|
||||
cursor, keys = client.scan(cursor, match=f"{CACHE_PREFIX}*", count=100)
|
||||
if keys:
|
||||
pipeline = client.pipeline()
|
||||
for key in keys:
|
||||
pipeline.delete(key)
|
||||
pipeline.execute()
|
||||
deleted += len(keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
if deleted:
|
||||
logger.info(f"Invalidated {deleted} listing cache entries")
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Redis cache invalidation error: {e}")
|
||||
211
services/listing_fetcher.py
Normal file
211
services/listing_fetcher.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""Listing fetcher service - fetches listing data from Rightmove API."""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from config.scraper_config import ScraperConfig
|
||||
from listing_processor import ListingProcessor
|
||||
from rec.query import create_session, listing_query
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
||||
from models.listing import Listing, QueryParameters
|
||||
from repositories import ListingRepository
|
||||
from services.query_splitter import QuerySplitter, SubQuery
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
# Number of concurrent workers that process listing details (fetch details,
|
||||
# download images, run OCR) from the streaming queue in parallel.
|
||||
NUM_WORKERS = 20
|
||||
|
||||
|
||||
async def dump_listings_full(
|
||||
parameters: QueryParameters,
|
||||
repository: ListingRepository,
|
||||
) -> list[Listing]:
|
||||
"""Fetches all listings, images as well as detects floorplans."""
|
||||
new_listings = await dump_listings(parameters, repository)
|
||||
logger.debug(f"Upserted {len(new_listings)} new listings")
|
||||
new_listing_ids = [listing.id for listing in new_listings]
|
||||
return await repository.get_listings(only_ids=new_listing_ids)
|
||||
|
||||
|
||||
async def _fetch_subquery(
|
||||
sq: SubQuery,
|
||||
parameters: QueryParameters,
|
||||
session: object,
|
||||
config: ScraperConfig,
|
||||
semaphore: asyncio.Semaphore,
|
||||
existing_ids: set[int],
|
||||
queue: asyncio.Queue[int | None],
|
||||
) -> int:
|
||||
"""Fetch listing IDs for a single subquery and enqueue new ones.
|
||||
|
||||
Iterates through pages of results for the given subquery, adding any
|
||||
newly discovered listing IDs to the processing queue.
|
||||
|
||||
Args:
|
||||
sq: The subquery to fetch results for.
|
||||
parameters: The original query parameters (for page_size, etc.).
|
||||
session: The aiohttp session for making requests.
|
||||
config: Scraper configuration.
|
||||
semaphore: Concurrency limiter for HTTP requests.
|
||||
existing_ids: Set of already-known listing IDs (mutated in place).
|
||||
queue: Queue to push new listing IDs onto for processing.
|
||||
|
||||
Returns:
|
||||
The number of new IDs discovered and enqueued.
|
||||
"""
|
||||
estimated = sq.estimated_results or 0
|
||||
if estimated == 0:
|
||||
return 0
|
||||
|
||||
ids_found = 0
|
||||
page_size = parameters.page_size
|
||||
max_pages = min(
|
||||
config.max_pages_per_query,
|
||||
(estimated // page_size) + 1,
|
||||
)
|
||||
|
||||
for page_id in range(1, max_pages + 1):
|
||||
async with semaphore:
|
||||
await asyncio.sleep(config.request_delay_ms / 1000)
|
||||
try:
|
||||
result = await listing_query(
|
||||
page=page_id,
|
||||
channel=parameters.listing_type,
|
||||
min_bedrooms=sq.min_bedrooms,
|
||||
max_bedrooms=sq.max_bedrooms,
|
||||
radius=parameters.radius,
|
||||
min_price=sq.min_price,
|
||||
max_price=sq.max_price,
|
||||
district=sq.district,
|
||||
page_size=page_size,
|
||||
max_days_since_added=parameters.max_days_since_added,
|
||||
furnish_types=parameters.furnish_types or [],
|
||||
session=session,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Extract and enqueue new IDs inline
|
||||
properties = result.get("properties", [])
|
||||
for prop in properties:
|
||||
identifier = prop.get("identifier")
|
||||
if identifier and identifier not in existing_ids:
|
||||
existing_ids.add(identifier)
|
||||
ids_found += 1
|
||||
await queue.put(identifier)
|
||||
|
||||
if len(properties) < page_size:
|
||||
break
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
logger.error(f"Circuit breaker open: {e}")
|
||||
break
|
||||
except ThrottlingError as e:
|
||||
logger.warning(
|
||||
f"Throttling error on page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
# Rightmove returns GENERIC_ERROR when requesting pages
|
||||
# past the last page of results. This is expected behavior
|
||||
# and signals we've exhausted this subquery's results.
|
||||
if "GENERIC_ERROR" in str(e):
|
||||
logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
logger.warning(
|
||||
f"Error fetching page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
)
|
||||
break
|
||||
|
||||
return ids_found
|
||||
|
||||
|
||||
async def dump_listings(
|
||||
parameters: QueryParameters,
|
||||
repository: ListingRepository,
|
||||
) -> list[Listing]:
|
||||
"""Fetch listings from Rightmove API and process them.
|
||||
|
||||
Uses intelligent query splitting and a streaming pipeline so that
|
||||
listing processing starts as soon as IDs become available.
|
||||
"""
|
||||
config = ScraperConfig.from_env()
|
||||
splitter = QuerySplitter(config)
|
||||
|
||||
# Reset throttle metrics at start
|
||||
reset_throttle_metrics()
|
||||
|
||||
try:
|
||||
async with create_session(config) as session:
|
||||
# Phase 1: Split and probe queries
|
||||
logger.info("Splitting query and probing result counts...")
|
||||
subqueries = await splitter.split(parameters, session)
|
||||
|
||||
total_estimated = splitter.calculate_total_estimated_results(subqueries)
|
||||
logger.info(
|
||||
f"Split into {len(subqueries)} subqueries, "
|
||||
f"estimated {total_estimated} total results"
|
||||
)
|
||||
|
||||
# Load existing IDs (fast, ID-only projection)
|
||||
existing_ids = repository.get_listing_ids(parameters.listing_type)
|
||||
logger.info(f"Found {len(existing_ids)} existing listings in DB")
|
||||
|
||||
# Phase 2: Streaming fetch & process
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
|
||||
processed_listings: list[Listing] = []
|
||||
|
||||
async def producer() -> int:
|
||||
"""Fetch all subqueries and send sentinel values to workers."""
|
||||
tasks = [
|
||||
_fetch_subquery(
|
||||
sq, parameters, session, config,
|
||||
semaphore, existing_ids, queue,
|
||||
)
|
||||
for sq in subqueries
|
||||
]
|
||||
counts = await asyncio.gather(*tasks)
|
||||
ids_collected = sum(counts)
|
||||
logger.info(f"Fetch complete: {ids_collected} new IDs found")
|
||||
for _ in range(NUM_WORKERS):
|
||||
await queue.put(None)
|
||||
return ids_collected
|
||||
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
listing_id = await queue.get()
|
||||
if listing_id is None:
|
||||
break
|
||||
listing_processor = ListingProcessor(repository)
|
||||
listing = await listing_processor.process_listing(listing_id)
|
||||
if listing is not None:
|
||||
processed_listings.append(listing)
|
||||
|
||||
results = await asyncio.gather(
|
||||
producer(),
|
||||
*[worker() for _ in range(NUM_WORKERS)],
|
||||
)
|
||||
ids_collected = results[0]
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
logger.error(f"Circuit breaker prevented listing fetch: {e}")
|
||||
logger.info(get_throttle_metrics().summary())
|
||||
return []
|
||||
finally:
|
||||
# Log throttle metrics at end
|
||||
metrics = get_throttle_metrics()
|
||||
if metrics.total_requests > 0:
|
||||
logger.info("\n" + metrics.summary())
|
||||
|
||||
logger.info(
|
||||
f"Processed {len(processed_listings)} new listings "
|
||||
f"({ids_collected} total found)"
|
||||
)
|
||||
|
||||
return processed_listings
|
||||
168
services/listing_service.py
Normal file
168
services/listing_service.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""Unified listing service - shared between CLI and HTTP API.
|
||||
|
||||
This module provides the core business logic for listing operations.
|
||||
Both the CLI (main.py) and HTTP API (api/app.py) should use these functions.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from models.listing import Listing, QueryParameters
|
||||
from repositories.listing_repository import ListingRepository
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListingResult:
|
||||
"""Result of a listing operation."""
|
||||
listings: list[Listing]
|
||||
total_count: int
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RefreshResult:
|
||||
"""Result of a refresh operation."""
|
||||
task_id: str | None # None if run synchronously
|
||||
new_listings_count: int
|
||||
message: str
|
||||
|
||||
|
||||
async def get_listings(
|
||||
repository: ListingRepository,
|
||||
query_parameters: QueryParameters | None = None,
|
||||
limit: int | None = None,
|
||||
only_ids: list[int] | None = None,
|
||||
) -> ListingResult:
|
||||
"""Get listings from the database with optional filtering.
|
||||
|
||||
Used by:
|
||||
- CLI: export-csv, export-immoweb
|
||||
- API: GET /api/listing, GET /api/listing_geojson
|
||||
"""
|
||||
listings = await repository.get_listings(
|
||||
query_parameters=query_parameters,
|
||||
limit=limit,
|
||||
only_ids=only_ids,
|
||||
)
|
||||
return ListingResult(
|
||||
listings=listings,
|
||||
total_count=len(listings),
|
||||
)
|
||||
|
||||
|
||||
async def refresh_listings(
|
||||
repository: ListingRepository,
|
||||
query_parameters: QueryParameters,
|
||||
full: bool = False,
|
||||
async_mode: bool = False,
|
||||
user_email: str | None = None,
|
||||
) -> RefreshResult:
|
||||
"""Refresh listings by fetching from external API.
|
||||
|
||||
Args:
|
||||
repository: Database repository
|
||||
query_parameters: Filtering parameters
|
||||
full: If True, also fetch images and run OCR
|
||||
async_mode: If True, run as background task and return task_id
|
||||
user_email: User email for tracking (API mode)
|
||||
|
||||
Used by:
|
||||
- CLI: dump-listings
|
||||
- API: POST /api/refresh_listings
|
||||
"""
|
||||
if async_mode:
|
||||
# Import here to avoid circular imports
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
from datetime import timedelta
|
||||
|
||||
expiry_time = datetime.now() + timedelta(minutes=10)
|
||||
task = dump_listings_task.apply_async(
|
||||
args=(query_parameters.model_dump_json(),),
|
||||
expires=expiry_time,
|
||||
)
|
||||
return RefreshResult(
|
||||
task_id=task.id,
|
||||
new_listings_count=0,
|
||||
message=f"Task {task.id} started",
|
||||
)
|
||||
|
||||
# Synchronous mode - run directly
|
||||
from services.listing_fetcher import dump_listings, dump_listings_full
|
||||
|
||||
if full:
|
||||
new_listings = await dump_listings_full(query_parameters, repository)
|
||||
else:
|
||||
new_listings = await dump_listings(query_parameters, repository)
|
||||
|
||||
return RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=len(new_listings),
|
||||
message=f"Fetched {len(new_listings)} new listings",
|
||||
)
|
||||
|
||||
|
||||
async def download_images(
|
||||
repository: ListingRepository,
|
||||
data_dir: Path = Path("data/rs/"),
|
||||
) -> int:
|
||||
"""Download floorplan images for all listings.
|
||||
|
||||
Used by:
|
||||
- CLI: dump-images
|
||||
- API: (could be added)
|
||||
|
||||
Returns:
|
||||
Number of listings processed
|
||||
"""
|
||||
from services.image_fetcher import dump_images
|
||||
|
||||
await dump_images(repository, image_base_path=data_dir)
|
||||
listings = await repository.get_listings()
|
||||
return len(listings)
|
||||
|
||||
|
||||
async def detect_floorplans(
|
||||
repository: ListingRepository,
|
||||
) -> int:
|
||||
"""Run OCR on floorplan images to detect square meters.
|
||||
|
||||
Used by:
|
||||
- CLI: detect-floorplan
|
||||
- API: (could be added)
|
||||
|
||||
Returns:
|
||||
Number of listings processed
|
||||
"""
|
||||
from services.floorplan_detector import detect_floorplan
|
||||
|
||||
await detect_floorplan(repository)
|
||||
listings = await repository.get_listings()
|
||||
return len(listings)
|
||||
|
||||
|
||||
async def calculate_routes(
|
||||
repository: ListingRepository,
|
||||
destination_address: str,
|
||||
travel_mode: str,
|
||||
limit: int | None = None,
|
||||
) -> int:
|
||||
"""Calculate transit routes for listings.
|
||||
|
||||
Used by:
|
||||
- CLI: routing
|
||||
- API: (could be added)
|
||||
|
||||
Returns:
|
||||
Number of listings processed
|
||||
"""
|
||||
from services.route_calculator import calculate_route
|
||||
from rec.routing import TravelMode
|
||||
|
||||
await calculate_route(
|
||||
repository,
|
||||
destination_address,
|
||||
TravelMode[travel_mode],
|
||||
limit=limit,
|
||||
)
|
||||
return limit or 0
|
||||
248
services/passkey_service.py
Normal file
248
services/passkey_service.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
from webauthn import (
|
||||
generate_registration_options,
|
||||
verify_registration_response,
|
||||
generate_authentication_options,
|
||||
verify_authentication_response,
|
||||
)
|
||||
from webauthn.helpers import (
|
||||
options_to_json,
|
||||
parse_registration_credential_json,
|
||||
parse_authentication_credential_json,
|
||||
)
|
||||
from webauthn.helpers.structs import (
|
||||
AuthenticatorSelectionCriteria,
|
||||
PublicKeyCredentialDescriptor,
|
||||
AuthenticatorTransport,
|
||||
ResidentKeyRequirement,
|
||||
UserVerificationRequirement,
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
|
||||
from api.config import (
|
||||
WEBAUTHN_RP_ID,
|
||||
WEBAUTHN_RP_NAME,
|
||||
WEBAUTHN_ORIGIN,
|
||||
JWT_SECRET,
|
||||
JWT_ALGORITHM,
|
||||
JWT_EXPIRATION_HOURS,
|
||||
JWT_ISSUER,
|
||||
)
|
||||
from models.passkey_credential import PasskeyCredential
|
||||
from repositories.user_repository import UserRepository
|
||||
from redis_repository import RedisRepository
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
CHALLENGE_TTL = timedelta(minutes=5)
|
||||
CHALLENGE_KEY_PREFIX = "webauthn:challenge:"
|
||||
|
||||
|
||||
def _store_challenge(session_id: str, data: dict) -> None: # type: ignore[type-arg]
|
||||
redis = RedisRepository.instance()
|
||||
redis.set_key(f"{CHALLENGE_KEY_PREFIX}{session_id}", data, ttl=CHALLENGE_TTL)
|
||||
|
||||
|
||||
def _get_challenge(session_id: str) -> dict | None: # type: ignore[type-arg]
|
||||
redis = RedisRepository.instance()
|
||||
return redis.get_key(f"{CHALLENGE_KEY_PREFIX}{session_id}") # type: ignore[return-value]
|
||||
|
||||
|
||||
def _issue_jwt(user_id: int, email: str) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"email": email,
|
||||
"name": email,
|
||||
"iss": JWT_ISSUER,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(hours=JWT_EXPIRATION_HOURS),
|
||||
}
|
||||
return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
|
||||
|
||||
def begin_registration(
|
||||
email: str, user_repo: UserRepository
|
||||
) -> tuple[dict, str]: # type: ignore[type-arg]
|
||||
"""Start WebAuthn registration ceremony.
|
||||
|
||||
Returns (options_dict, session_id).
|
||||
"""
|
||||
user = user_repo.get_user_by_email(email)
|
||||
if user is None:
|
||||
user = user_repo.create_user(email)
|
||||
|
||||
existing_credentials = user_repo.get_credentials_for_user(user.id)
|
||||
exclude_credentials = []
|
||||
for cred in existing_credentials:
|
||||
transports = []
|
||||
if cred.transports:
|
||||
transports = [
|
||||
AuthenticatorTransport(t) for t in json.loads(cred.transports)
|
||||
]
|
||||
exclude_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred.credential_id + "=="),
|
||||
transports=transports,
|
||||
)
|
||||
)
|
||||
|
||||
options = generate_registration_options(
|
||||
rp_id=WEBAUTHN_RP_ID,
|
||||
rp_name=WEBAUTHN_RP_NAME,
|
||||
user_id=str(user.id).encode(),
|
||||
user_name=email,
|
||||
user_display_name=email,
|
||||
exclude_credentials=exclude_credentials,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
),
|
||||
supported_pub_key_algs=[
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256,
|
||||
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
|
||||
],
|
||||
)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
_store_challenge(session_id, {
|
||||
"challenge": base64.urlsafe_b64encode(options.challenge).decode(),
|
||||
"user_id": user.id,
|
||||
"email": email,
|
||||
"type": "registration",
|
||||
})
|
||||
|
||||
options_json = json.loads(options_to_json(options))
|
||||
return options_json, session_id
|
||||
|
||||
|
||||
def complete_registration(
|
||||
session_id: str,
|
||||
credential: dict, # type: ignore[type-arg]
|
||||
user_repo: UserRepository,
|
||||
) -> str:
|
||||
"""Complete WebAuthn registration ceremony.
|
||||
|
||||
Returns a JWT string.
|
||||
"""
|
||||
challenge_data = _get_challenge(session_id)
|
||||
if not challenge_data or challenge_data.get("type") != "registration":
|
||||
raise ValueError("Invalid or expired registration session")
|
||||
|
||||
expected_challenge = base64.urlsafe_b64decode(
|
||||
challenge_data["challenge"] + "=="
|
||||
)
|
||||
|
||||
registration_credential = parse_registration_credential_json(
|
||||
json.dumps(credential)
|
||||
)
|
||||
|
||||
verification = verify_registration_response(
|
||||
credential=registration_credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=WEBAUTHN_RP_ID,
|
||||
expected_origin=WEBAUTHN_ORIGIN,
|
||||
)
|
||||
|
||||
credential_id_b64 = base64.urlsafe_b64encode(
|
||||
verification.credential_id
|
||||
).decode().rstrip("=")
|
||||
public_key_b64 = base64.urlsafe_b64encode(
|
||||
verification.credential_public_key
|
||||
).decode().rstrip("=")
|
||||
|
||||
transports_json = None
|
||||
if credential.get("response", {}).get("transports"):
|
||||
transports_json = json.dumps(
|
||||
credential["response"]["transports"]
|
||||
)
|
||||
|
||||
passkey_cred = PasskeyCredential(
|
||||
credential_id=credential_id_b64,
|
||||
public_key=public_key_b64,
|
||||
sign_count=verification.sign_count,
|
||||
transports=transports_json,
|
||||
user_id=challenge_data["user_id"],
|
||||
)
|
||||
user_repo.save_credential(passkey_cred)
|
||||
|
||||
return _issue_jwt(challenge_data["user_id"], challenge_data["email"])
|
||||
|
||||
|
||||
def begin_authentication(
|
||||
user_repo: UserRepository,
|
||||
) -> tuple[dict, str]: # type: ignore[type-arg]
|
||||
"""Start WebAuthn authentication ceremony (discoverable credentials).
|
||||
|
||||
Returns (options_dict, session_id).
|
||||
"""
|
||||
options = generate_authentication_options(
|
||||
rp_id=WEBAUTHN_RP_ID,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
_store_challenge(session_id, {
|
||||
"challenge": base64.urlsafe_b64encode(options.challenge).decode(),
|
||||
"type": "authentication",
|
||||
})
|
||||
|
||||
options_json = json.loads(options_to_json(options))
|
||||
return options_json, session_id
|
||||
|
||||
|
||||
def complete_authentication(
|
||||
session_id: str,
|
||||
credential: dict, # type: ignore[type-arg]
|
||||
user_repo: UserRepository,
|
||||
) -> str:
|
||||
"""Complete WebAuthn authentication ceremony.
|
||||
|
||||
Returns a JWT string.
|
||||
"""
|
||||
challenge_data = _get_challenge(session_id)
|
||||
if not challenge_data or challenge_data.get("type") != "authentication":
|
||||
raise ValueError("Invalid or expired authentication session")
|
||||
|
||||
expected_challenge = base64.urlsafe_b64decode(
|
||||
challenge_data["challenge"] + "=="
|
||||
)
|
||||
|
||||
# Look up the credential in the database
|
||||
raw_id = credential.get("rawId") or credential.get("id", "")
|
||||
stored_cred = user_repo.get_credential_by_id(raw_id)
|
||||
if not stored_cred:
|
||||
raise ValueError("Credential not found")
|
||||
|
||||
stored_public_key = base64.urlsafe_b64decode(
|
||||
stored_cred.public_key + "=="
|
||||
)
|
||||
|
||||
auth_credential = parse_authentication_credential_json(
|
||||
json.dumps(credential)
|
||||
)
|
||||
|
||||
verification = verify_authentication_response(
|
||||
credential=auth_credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_rp_id=WEBAUTHN_RP_ID,
|
||||
expected_origin=WEBAUTHN_ORIGIN,
|
||||
credential_public_key=stored_public_key,
|
||||
credential_current_sign_count=stored_cred.sign_count,
|
||||
)
|
||||
|
||||
user_repo.update_credential_sign_count(
|
||||
stored_cred.credential_id, verification.new_sign_count
|
||||
)
|
||||
|
||||
user = user_repo.get_user_by_id(stored_cred.user_id)
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
return _issue_jwt(user.id, user.email)
|
||||
335
services/query_splitter.py
Normal file
335
services/query_splitter.py
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
"""Query splitting service for handling Rightmove's result cap.
|
||||
|
||||
This module provides intelligent query splitting to work around Rightmove's
|
||||
~1,500 listing cap per search. It adaptively splits queries by price bands
|
||||
based on actual result counts.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from config.scraper_config import ScraperConfig
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from rec.districts import get_districts
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubQuery:
|
||||
"""Represents a single query subdivision.
|
||||
|
||||
Attributes:
|
||||
district: District identifier string.
|
||||
min_bedrooms: Minimum number of bedrooms.
|
||||
max_bedrooms: Maximum number of bedrooms.
|
||||
min_price: Minimum price in currency units.
|
||||
max_price: Maximum price in currency units.
|
||||
estimated_results: Cached result count from probing (None if not probed).
|
||||
"""
|
||||
|
||||
district: str
|
||||
min_bedrooms: int
|
||||
max_bedrooms: int
|
||||
min_price: int
|
||||
max_price: int
|
||||
estimated_results: int | None = None
|
||||
|
||||
@property
|
||||
def price_range(self) -> int:
|
||||
"""Returns the width of the price band."""
|
||||
return self.max_price - self.min_price
|
||||
|
||||
|
||||
class QuerySplitter:
|
||||
"""Splits large queries into smaller subqueries to avoid result caps.
|
||||
|
||||
Uses adaptive binary search on price ranges to find optimal subdivisions
|
||||
that keep each subquery under the result threshold.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ScraperConfig | None = None) -> None:
|
||||
"""Initialize the splitter with configuration.
|
||||
|
||||
Args:
|
||||
config: Scraper configuration. Loads from environment if not provided.
|
||||
"""
|
||||
self.config = config or ScraperConfig.from_env()
|
||||
|
||||
def create_initial_subqueries(
|
||||
self,
|
||||
parameters: QueryParameters,
|
||||
districts: dict[str, str],
|
||||
) -> list[SubQuery]:
|
||||
"""Create initial subqueries by splitting on district and bedrooms.
|
||||
|
||||
This creates the initial split before probing for result counts.
|
||||
Each bedroom count gets its own subquery to enable finer-grained splitting.
|
||||
|
||||
Args:
|
||||
parameters: Original query parameters.
|
||||
districts: Dictionary of district name to location ID.
|
||||
|
||||
Returns:
|
||||
List of initial SubQuery objects.
|
||||
"""
|
||||
subqueries: list[SubQuery] = []
|
||||
|
||||
for district in districts.keys():
|
||||
for num_bedrooms in range(
|
||||
parameters.min_bedrooms, parameters.max_bedrooms + 1
|
||||
):
|
||||
subqueries.append(
|
||||
SubQuery(
|
||||
district=district,
|
||||
min_bedrooms=num_bedrooms,
|
||||
max_bedrooms=num_bedrooms,
|
||||
min_price=parameters.min_price,
|
||||
max_price=parameters.max_price,
|
||||
)
|
||||
)
|
||||
|
||||
return subqueries
|
||||
|
||||
async def probe_result_count(
|
||||
self,
|
||||
subquery: SubQuery,
|
||||
session: aiohttp.ClientSession,
|
||||
parameters: QueryParameters,
|
||||
) -> int:
|
||||
"""Probe the API to get the total result count for a subquery.
|
||||
|
||||
Makes a minimal request (page_size=1) to get totalAvailableResults.
|
||||
|
||||
Args:
|
||||
subquery: The subquery to probe.
|
||||
session: aiohttp session for making requests.
|
||||
parameters: Original query parameters for additional settings.
|
||||
|
||||
Returns:
|
||||
Total available results for this subquery.
|
||||
|
||||
Raises:
|
||||
CircuitBreakerOpenError: If the circuit breaker is open.
|
||||
"""
|
||||
from rec.query import probe_query
|
||||
|
||||
try:
|
||||
result = await probe_query(
|
||||
session=session,
|
||||
channel=parameters.listing_type,
|
||||
min_bedrooms=subquery.min_bedrooms,
|
||||
max_bedrooms=subquery.max_bedrooms,
|
||||
radius=parameters.radius,
|
||||
min_price=subquery.min_price,
|
||||
max_price=subquery.max_price,
|
||||
district=subquery.district,
|
||||
max_days_since_added=parameters.max_days_since_added,
|
||||
furnish_types=parameters.furnish_types or [],
|
||||
config=self.config,
|
||||
)
|
||||
return result.get("totalAvailableResults", 0)
|
||||
except CircuitBreakerOpenError:
|
||||
logger.error("Circuit breaker is open, stopping probe operations")
|
||||
raise
|
||||
except ThrottlingError as e:
|
||||
logger.warning(
|
||||
f"Throttling detected during probe for {subquery.district}: {e}"
|
||||
)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to probe subquery {subquery}: {e}")
|
||||
return 0
|
||||
|
||||
def split_by_price(self, subquery: SubQuery) -> list[SubQuery]:
|
||||
"""Split a subquery into two by halving the price range.
|
||||
|
||||
Args:
|
||||
subquery: The subquery to split.
|
||||
|
||||
Returns:
|
||||
List of two subqueries covering the same price range.
|
||||
"""
|
||||
mid_price = (subquery.min_price + subquery.max_price) // 2
|
||||
|
||||
return [
|
||||
replace(
|
||||
subquery,
|
||||
max_price=mid_price,
|
||||
estimated_results=None,
|
||||
),
|
||||
replace(
|
||||
subquery,
|
||||
min_price=mid_price,
|
||||
estimated_results=None,
|
||||
),
|
||||
]
|
||||
|
||||
async def adaptive_split(
|
||||
self,
|
||||
subquery: SubQuery,
|
||||
session: aiohttp.ClientSession,
|
||||
parameters: QueryParameters,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> list[SubQuery]:
|
||||
"""Recursively split a subquery until all parts are under threshold.
|
||||
|
||||
Uses binary search on price range to find optimal splits.
|
||||
|
||||
Args:
|
||||
subquery: The subquery to split.
|
||||
session: aiohttp session for making requests.
|
||||
parameters: Original query parameters.
|
||||
semaphore: Semaphore for rate limiting.
|
||||
|
||||
Returns:
|
||||
List of subqueries that are all under the split threshold.
|
||||
"""
|
||||
# Check if we can split further
|
||||
if subquery.price_range <= self.config.min_price_band:
|
||||
logger.warning(
|
||||
f"Cannot split further, price band at minimum: {subquery}"
|
||||
)
|
||||
return [subquery]
|
||||
|
||||
# Split into two halves
|
||||
halves = self.split_by_price(subquery)
|
||||
result: list[SubQuery] = []
|
||||
|
||||
for half in halves:
|
||||
async with semaphore:
|
||||
await asyncio.sleep(self.config.request_delay_ms / 1000)
|
||||
count = await self.probe_result_count(half, session, parameters)
|
||||
|
||||
half = replace(half, estimated_results=count)
|
||||
|
||||
if count > self.config.split_threshold:
|
||||
# Need to split further
|
||||
result.extend(
|
||||
await self.adaptive_split(
|
||||
half, session, parameters, semaphore
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(half)
|
||||
|
||||
return result
|
||||
|
||||
async def split(
|
||||
self,
|
||||
parameters: QueryParameters,
|
||||
session: aiohttp.ClientSession,
|
||||
on_progress: Any = None,
|
||||
) -> list[SubQuery]:
|
||||
"""Split query parameters into optimized subqueries.
|
||||
|
||||
Performs the full splitting algorithm:
|
||||
1. Create initial splits by district and bedroom count
|
||||
2. Probe each to get result counts
|
||||
3. Adaptively split any that exceed the threshold
|
||||
|
||||
Args:
|
||||
parameters: Original query parameters to split.
|
||||
session: aiohttp session for making requests.
|
||||
on_progress: Optional callback for progress updates.
|
||||
Called as on_progress(phase, message, **kwargs) where kwargs
|
||||
contains structured data like subqueries_probed, etc.
|
||||
|
||||
Returns:
|
||||
List of SubQuery objects, each under the result threshold.
|
||||
"""
|
||||
# Get valid districts
|
||||
if parameters.district_names:
|
||||
districts = {
|
||||
district: locid
|
||||
for district, locid in get_districts().items()
|
||||
if district in parameters.district_names
|
||||
}
|
||||
else:
|
||||
districts = get_districts()
|
||||
|
||||
# Phase 1: Create initial subqueries
|
||||
initial_subqueries = self.create_initial_subqueries(parameters, districts)
|
||||
logger.info(f"Created {len(initial_subqueries)} initial subqueries")
|
||||
|
||||
if on_progress:
|
||||
on_progress(
|
||||
phase="splitting",
|
||||
message=f"Created {len(initial_subqueries)} initial subqueries",
|
||||
subqueries_initial=len(initial_subqueries),
|
||||
subqueries_probed=0,
|
||||
)
|
||||
|
||||
# Phase 2: Probe and adaptively split
|
||||
semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)
|
||||
refined_subqueries: list[SubQuery] = []
|
||||
probed_count = 0
|
||||
|
||||
# Probe all initial subqueries in parallel
|
||||
async def probe_and_split(sq: SubQuery) -> list[SubQuery]:
|
||||
nonlocal probed_count
|
||||
async with semaphore:
|
||||
await asyncio.sleep(self.config.request_delay_ms / 1000)
|
||||
count = await self.probe_result_count(sq, session, parameters)
|
||||
|
||||
sq = replace(sq, estimated_results=count)
|
||||
probed_count += 1
|
||||
|
||||
if on_progress:
|
||||
on_progress(
|
||||
phase="splitting",
|
||||
message=f"Probed {probed_count}/{len(initial_subqueries)} subqueries",
|
||||
subqueries_initial=len(initial_subqueries),
|
||||
subqueries_probed=probed_count,
|
||||
)
|
||||
|
||||
if count > self.config.split_threshold:
|
||||
logger.info(
|
||||
f"Subquery {sq.district}/{sq.min_bedrooms}BR "
|
||||
f"has {count} results, splitting..."
|
||||
)
|
||||
return await self.adaptive_split(
|
||||
sq, session, parameters, semaphore
|
||||
)
|
||||
return [sq]
|
||||
|
||||
tasks = [probe_and_split(sq) for sq in initial_subqueries]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
for subquery_list in results:
|
||||
refined_subqueries.extend(subquery_list)
|
||||
|
||||
logger.info(
|
||||
f"Refined to {len(refined_subqueries)} subqueries after splitting"
|
||||
)
|
||||
|
||||
total_estimated = self.calculate_total_estimated_results(refined_subqueries)
|
||||
|
||||
if on_progress:
|
||||
on_progress(
|
||||
phase="splitting_complete",
|
||||
message=f"Refined to {len(refined_subqueries)} subqueries",
|
||||
subqueries_total=len(refined_subqueries),
|
||||
estimated_results=total_estimated,
|
||||
)
|
||||
|
||||
return refined_subqueries
|
||||
|
||||
def calculate_total_estimated_results(
|
||||
self, subqueries: list[SubQuery]
|
||||
) -> int:
|
||||
"""Calculate total estimated results across all subqueries.
|
||||
|
||||
Args:
|
||||
subqueries: List of subqueries with estimated_results set.
|
||||
|
||||
Returns:
|
||||
Sum of all estimated results.
|
||||
"""
|
||||
return sum(sq.estimated_results or 0 for sq in subqueries)
|
||||
71
services/route_calculator.py
Normal file
71
services/route_calculator.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Route calculator service - calculates transit routes using Google Maps API."""
|
||||
from models.listing import DestinationMode, Route, RouteLegStep
|
||||
from repositories.listing_repository import ListingRepository
|
||||
from tqdm.asyncio import tqdm
|
||||
from rec import routing
|
||||
from models import Listing
|
||||
|
||||
|
||||
def _parse_duration(duration_str: str) -> int:
|
||||
"""Parse a duration string like '123s' to integer seconds."""
|
||||
return int(duration_str.rstrip("s"))
|
||||
|
||||
|
||||
async def calculate_route(
|
||||
repository: ListingRepository,
|
||||
destination_address: str,
|
||||
travel_mode: routing.TravelMode,
|
||||
limit: int | None = None,
|
||||
) -> None:
|
||||
"""Calculate transit routes for listings to a destination."""
|
||||
listings = await repository.get_listings()
|
||||
|
||||
if limit is not None:
|
||||
listings = listings[:limit]
|
||||
|
||||
destination_mode = DestinationMode(destination_address, travel_mode)
|
||||
updated_listings = await tqdm.gather(
|
||||
*[update_routing_info(listing, destination_mode) for listing in listings],
|
||||
total=len(listings),
|
||||
desc="Updating routing info",
|
||||
)
|
||||
await repository.upsert_listings(
|
||||
[listing for listing in updated_listings if listing is not None]
|
||||
)
|
||||
|
||||
|
||||
async def update_routing_info(
|
||||
listing: Listing, destination_mode: DestinationMode
|
||||
) -> Listing | None:
|
||||
"""Update routing information for a single listing."""
|
||||
if listing.routing_info.get(destination_mode) is not None:
|
||||
# already calculated, do not recompute to save API calls
|
||||
return None
|
||||
|
||||
routes_data = routing.transit_route(
|
||||
listing.latitude,
|
||||
listing.longitude,
|
||||
destination_mode.destination_address,
|
||||
destination_mode.travel_mode,
|
||||
)
|
||||
|
||||
routes: list[Route] = []
|
||||
for route_data in routes_data["routes"]:
|
||||
duration_s = _parse_duration(route_data["duration"])
|
||||
route = Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=step_data["distanceMeters"],
|
||||
duration_s=_parse_duration(step_data["staticDuration"]),
|
||||
travel_mode=routing.TravelMode(step_data["travelMode"]),
|
||||
)
|
||||
for step_data in route_data["legs"][0]["steps"]
|
||||
],
|
||||
distance_meters=route_data["distanceMeters"],
|
||||
duration_s=duration_s,
|
||||
)
|
||||
routes.append(route)
|
||||
listing.routing_info_json = listing.serialize_routing_info(
|
||||
{**listing.routing_info, **{destination_mode: routes}}
|
||||
)
|
||||
return listing
|
||||
242
services/task_service.py
Normal file
242
services/task_service.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Unified task service - shared between CLI and HTTP API.
|
||||
|
||||
Manages background task operations using Celery.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Standard Celery states; anything else is treated as a custom state
|
||||
# whose name is used as the human-readable status message.
|
||||
_CELERY_STANDARD_STATES = frozenset(
|
||||
{"PENDING", "STARTED", "SUCCESS", "FAILURE", "REVOKED", "RETRY"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskStatus:
|
||||
"""Status of a background task."""
|
||||
task_id: str
|
||||
status: str # PENDING, STARTED, SUCCESS, FAILURE, REVOKED, SKIPPED
|
||||
result: Any | None
|
||||
progress: float | None # 0.0 to 1.0
|
||||
processed: int | None # Number of items processed
|
||||
total: int | None # Total number of items
|
||||
message: str | None # Human-readable status message (e.g., "Fetching listings")
|
||||
error: str | None # Error message if failed
|
||||
traceback: str | None # Full traceback if failed
|
||||
|
||||
|
||||
def _make_system_user(email: str) -> Any:
|
||||
"""Create a minimal User object used only for Redis key generation.
|
||||
|
||||
These are *not* real authenticated users -- they exist solely so that
|
||||
RedisRepository can derive the per-user storage key from the email.
|
||||
"""
|
||||
# Lazy import: api.auth imports from api.app which eventually imports
|
||||
# services, so importing at module level would create a circular dependency.
|
||||
from api.auth import User
|
||||
|
||||
return User(sub="", email=email, name="")
|
||||
|
||||
|
||||
def _extract_result(task_result: Any) -> tuple[Any, str | None]:
|
||||
"""Extract a serialisable result and an error string from a Celery AsyncResult.
|
||||
|
||||
Returns:
|
||||
(result, error) -- exactly one of the two will be non-None (or both None
|
||||
for tasks that haven't produced output yet).
|
||||
"""
|
||||
if task_result.failed():
|
||||
error = str(task_result.result) if task_result.result else None
|
||||
return None, error
|
||||
|
||||
try:
|
||||
result = json.loads(json.dumps(task_result.result))
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
result = str(task_result.result) if task_result.result else None
|
||||
return result, None
|
||||
|
||||
|
||||
def _extract_progress_info(task_result: Any) -> dict[str, Any]:
|
||||
"""Extract progress metadata from a Celery AsyncResult's ``info`` dict.
|
||||
|
||||
Returns a dict with keys ``progress``, ``processed``, ``total``, and
|
||||
``message`` (any of which may be None).
|
||||
"""
|
||||
progress: float | None = None
|
||||
processed: int | None = None
|
||||
total: int | None = None
|
||||
message: str | None = None
|
||||
|
||||
if task_result.info and isinstance(task_result.info, dict):
|
||||
progress = task_result.info.get("progress")
|
||||
processed = task_result.info.get("processed")
|
||||
total = task_result.info.get("total")
|
||||
# Use 'message' if available, fall back to 'reason' for SKIPPED tasks
|
||||
message = task_result.info.get("message") or task_result.info.get("reason")
|
||||
|
||||
# For custom states (like "Fetching listings"), use the state as message
|
||||
# if no message was provided in info
|
||||
if not message and task_result.status not in _CELERY_STANDARD_STATES:
|
||||
message = task_result.status
|
||||
|
||||
return {
|
||||
"progress": progress,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def get_task_status(task_id: str) -> TaskStatus:
|
||||
"""Get the status of a background task.
|
||||
|
||||
Used by:
|
||||
- API: GET /api/task_status
|
||||
|
||||
Args:
|
||||
task_id: The Celery task ID
|
||||
|
||||
Returns:
|
||||
TaskStatus with current state
|
||||
"""
|
||||
# Lazy import: listing_tasks imports the Celery app which in turn
|
||||
# pulls in broker configuration; importing at module level would
|
||||
# create a circular dependency chain.
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
|
||||
task_result = dump_listings_task.AsyncResult(task_id)
|
||||
|
||||
result, error = _extract_result(task_result)
|
||||
task_traceback = task_result.traceback if task_result.failed() else None
|
||||
progress_info = _extract_progress_info(task_result)
|
||||
|
||||
return TaskStatus(
|
||||
task_id=task_id,
|
||||
status=task_result.status,
|
||||
result=result,
|
||||
error=error,
|
||||
traceback=task_traceback,
|
||||
**progress_info,
|
||||
)
|
||||
|
||||
|
||||
def get_user_tasks(user_email: str) -> list[str]:
|
||||
"""Get all task IDs for a user.
|
||||
|
||||
Used by:
|
||||
- API: GET /api/tasks_for_user
|
||||
|
||||
Args:
|
||||
user_email: The user's email address
|
||||
|
||||
Returns:
|
||||
List of task IDs
|
||||
"""
|
||||
# Lazy import: RedisRepository depends on redis which may not be
|
||||
# available at import time in all contexts (CLI, tests).
|
||||
from redis_repository import RedisRepository
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = _make_system_user(user_email)
|
||||
return redis_repo.get_tasks_for_user(user)
|
||||
|
||||
|
||||
def add_task_for_user(user_email: str, task_id: str) -> None:
|
||||
"""Associate a task with a user.
|
||||
|
||||
Used by:
|
||||
- API: POST /api/refresh_listings
|
||||
|
||||
Args:
|
||||
user_email: The user's email address
|
||||
task_id: The Celery task ID
|
||||
"""
|
||||
# Lazy import: see get_user_tasks for rationale.
|
||||
from redis_repository import RedisRepository
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = _make_system_user(user_email)
|
||||
redis_repo.add_task_for_user(user, task_id)
|
||||
|
||||
|
||||
def cancel_task(task_id: str, user_email: str | None = None) -> bool:
|
||||
"""Cancel a running task and remove it from the user's task list.
|
||||
|
||||
Args:
|
||||
task_id: The Celery task ID
|
||||
user_email: Optional user email to remove task from their list
|
||||
|
||||
Returns:
|
||||
True if task was cancelled successfully
|
||||
"""
|
||||
# Lazy import: celery_app bootstraps the broker connection.
|
||||
from celery_app import app as celery_app
|
||||
|
||||
logger.info("Cancelling task %s (user=%s)", task_id, user_email)
|
||||
# Revoke the task in Celery
|
||||
celery_app.control.revoke(task_id, terminate=True)
|
||||
|
||||
# Also remove from user's task list if user_email provided
|
||||
if user_email:
|
||||
remove_task_from_user(user_email, task_id)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def remove_task_from_user(user_email: str, task_id: str) -> bool:
|
||||
"""Remove a task from a user's task list without cancelling it.
|
||||
|
||||
Use this to clean up stuck tasks that can't be cancelled via Celery.
|
||||
|
||||
Args:
|
||||
user_email: The user's email address
|
||||
task_id: The Celery task ID
|
||||
|
||||
Returns:
|
||||
True if task was removed, False if not found
|
||||
"""
|
||||
# Lazy import: see get_user_tasks for rationale.
|
||||
from redis_repository import RedisRepository
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = _make_system_user(user_email)
|
||||
return redis_repo.remove_task_for_user(user, task_id)
|
||||
|
||||
|
||||
def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
|
||||
"""Clear all tasks for a user.
|
||||
|
||||
Args:
|
||||
user_email: The user's email address
|
||||
revoke: If True, also attempt to revoke tasks in Celery
|
||||
|
||||
Returns:
|
||||
Number of tasks cleared
|
||||
"""
|
||||
# Lazy imports: see get_user_tasks and cancel_task for rationale.
|
||||
from redis_repository import RedisRepository
|
||||
from celery_app import app as celery_app
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = _make_system_user(user_email)
|
||||
|
||||
logger.info("Clearing all tasks for user %s (revoke=%s)", user_email, revoke)
|
||||
|
||||
# Get tasks before clearing to revoke them
|
||||
if revoke:
|
||||
tasks = redis_repo.get_tasks_for_user(user)
|
||||
for task_id in tasks:
|
||||
try:
|
||||
celery_app.control.revoke(task_id, terminate=True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to revoke task %s: %s", task_id, e
|
||||
)
|
||||
|
||||
return redis_repo.clear_tasks_for_user(user)
|
||||
Loading…
Add table
Add a link
Reference in a new issue