Refactor codebase following Clean Code principles and add 229 tests
- Extract helpers to reduce function sizes (listing_tasks, app.py, query.py, listing_fetcher)
- Replace nonlocal mutations with a _PipelineState dataclass in listing_tasks
- Fix bugs: isinstance → equality check in repository; verify_exp for OIDC tokens
- Consolidate duplicate filter methods in listing_repository
- Move hardcoded config to env vars with backward-compatible defaults
- Simplify CLI decorator to auto-build QueryParameters
- Add deprecation docstring to data_access.py
- Test count: 158 → 387 (all passing)
This commit is contained in:
parent
7e05b3c971
commit
150342bb9e
48 changed files with 5029 additions and 990 deletions
|
|
@ -3,7 +3,7 @@ from datetime import datetime, timedelta
|
|||
import json
|
||||
import logging
|
||||
import logging.config
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, AsyncGenerator, Optional
|
||||
from api.auth import get_current_user
|
||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||
from api.passkey_routes import passkey_router
|
||||
|
|
@ -32,6 +32,8 @@ from opentelemetry.metrics import get_meter
|
|||
load_dotenv()
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
DEFAULT_BATCH_SIZE = 50
|
||||
|
||||
|
||||
def get_query_parameters(
|
||||
listing_type: ListingType,
|
||||
|
|
@ -120,11 +122,79 @@ async def get_listing_geojson(
|
|||
return result.data
|
||||
|
||||
|
||||
|
||||
async def _stream_from_cache(
    query_parameters: QueryParameters,
    batch_size: int,
    limit: int | None,
) -> AsyncGenerator[str, None]:
    """Yield NDJSON lines for a query that is served from the cache.

    The stream consists of one ``metadata`` line, zero or more ``batch``
    lines and a final ``complete`` line; every line is a JSON object
    terminated by a newline.

    Args:
        query_parameters: Filter describing which cached features to stream.
        batch_size: Batch size requested from ``get_cached_features`` and
            echoed back in the metadata line.
        limit: Maximum number of features to emit. A falsy value (``None``
            or ``0``) means no truncation is applied.
    """
    cached_count = get_cached_count(query_parameters)
    if limit and cached_count:
        total_expected = min(limit, cached_count)
    else:
        # No usable limit (or no usable count): report the count as returned.
        total_expected = cached_count

    metadata = {
        "type": "metadata",
        "batch_size": batch_size,
        "total_expected": total_expected,
        "cached": True,
    }
    yield json.dumps(metadata) + "\n"

    sent = 0
    for features in get_cached_features(query_parameters, batch_size=batch_size):
        # Trim the batch that would push us past the limit.
        if limit and sent + len(features) > limit:
            features = features[: limit - sent]
        sent += len(features)
        yield json.dumps({"type": "batch", "features": features}) + "\n"
        if limit and sent >= limit:
            break

    yield json.dumps({"type": "complete", "total": sent}) + "\n"
|
||||
|
||||
|
||||
async def _stream_from_db(
    query_parameters: QueryParameters,
    batch_size: int,
    limit: int | None,
) -> AsyncGenerator[str, None]:
    """Yield NDJSON lines for a query served from the database.

    Emits one ``metadata`` line, then ``batch`` lines of at most
    ``batch_size`` GeoJSON features, then a ``complete`` line. Every batch is
    handed to ``cache_features_batch`` before it is yielded.

    Args:
        query_parameters: Filter describing which listings to stream.
        batch_size: Number of features per emitted batch; also passed to the
            repository as ``page_size``.
        limit: Maximum number of rows requested from the repository. A falsy
            value leaves ``total_expected`` equal to the full count.
    """
    repository = ListingRepository(engine)

    matching_rows = repository.count_listings(query_parameters)
    if limit:
        total_expected = min(limit, matching_rows)
    else:
        total_expected = matching_rows

    metadata = {
        "type": "metadata",
        "batch_size": batch_size,
        "total_expected": total_expected,
        "cached": False,
    }
    yield json.dumps(metadata) + "\n"

    emitted = 0
    pending: list[dict] = []
    rows = repository.stream_listings_optimized(
        query_parameters, limit=limit, page_size=batch_size
    )
    for row in rows:
        pending.append(
            convert_row_to_geojson(row, query_parameters.listing_type.value)
        )
        emitted += 1

        if len(pending) < batch_size:
            continue

        # Batch is full: cache it, send it, start a fresh one.
        cache_features_batch(query_parameters, pending)
        yield json.dumps({"type": "batch", "features": pending}) + "\n"
        pending = []

    # Flush the final, partially filled batch (if any).
    if pending:
        cache_features_batch(query_parameters, pending)
        yield json.dumps({"type": "batch", "features": pending}) + "\n"

    yield json.dumps({"type": "complete", "total": emitted}) + "\n"
|
||||
|
||||
|
||||
@app.get("/api/listing_geojson/stream")
|
||||
async def stream_listing_geojson(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
query_parameters: Annotated[QueryParameters, Depends(get_query_parameters)],
|
||||
batch_size: int = 50,
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
limit: int | None = None,
|
||||
) -> StreamingResponse:
|
||||
"""Stream listings as NDJSON for progressive map loading.
|
||||
|
|
@ -134,71 +204,14 @@ async def stream_listing_geojson(
|
|||
- batch: Array of GeoJSON features
|
||||
- complete: Final message with total count
|
||||
"""
|
||||
async def generate():
|
||||
# Check cache first
|
||||
cached_count = get_cached_count(query_parameters)
|
||||
|
||||
if cached_count is not None and cached_count > 0:
|
||||
# Cache HIT
|
||||
effective_total = min(limit, cached_count) if limit else cached_count
|
||||
|
||||
yield json.dumps({
|
||||
"type": "metadata",
|
||||
"batch_size": batch_size,
|
||||
"total_expected": effective_total,
|
||||
"cached": True,
|
||||
}) + "\n"
|
||||
|
||||
count = 0
|
||||
for feature_batch in get_cached_features(query_parameters, batch_size=batch_size):
|
||||
if limit and count + len(feature_batch) > limit:
|
||||
feature_batch = feature_batch[:limit - count]
|
||||
count += len(feature_batch)
|
||||
yield json.dumps({"type": "batch", "features": feature_batch}) + "\n"
|
||||
if limit and count >= limit:
|
||||
break
|
||||
|
||||
yield json.dumps({"type": "complete", "total": count}) + "\n"
|
||||
else:
|
||||
# Cache MISS - query DB and populate cache
|
||||
repository = ListingRepository(engine)
|
||||
|
||||
# Phase 1: Fast count for progress estimation
|
||||
total = repository.count_listings(query_parameters)
|
||||
effective_total = min(limit, total) if limit else total
|
||||
|
||||
yield json.dumps({
|
||||
"type": "metadata",
|
||||
"batch_size": batch_size,
|
||||
"total_expected": effective_total,
|
||||
"cached": False,
|
||||
}) + "\n"
|
||||
|
||||
# Phase 2: Stream with column projection and keyset pagination
|
||||
count = 0
|
||||
batch = []
|
||||
for row in repository.stream_listings_optimized(
|
||||
query_parameters, limit=limit, page_size=batch_size
|
||||
):
|
||||
feature = convert_row_to_geojson(row, query_parameters.listing_type.value)
|
||||
batch.append(feature)
|
||||
count += 1
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
cache_features_batch(query_parameters, batch)
|
||||
yield json.dumps({"type": "batch", "features": batch}) + "\n"
|
||||
batch = []
|
||||
|
||||
# Send remaining
|
||||
if batch:
|
||||
cache_features_batch(query_parameters, batch)
|
||||
yield json.dumps({"type": "batch", "features": batch}) + "\n"
|
||||
|
||||
# Final message
|
||||
yield json.dumps({"type": "complete", "total": count}) + "\n"
|
||||
cached_count = get_cached_count(query_parameters)
|
||||
if cached_count is not None and cached_count > 0:
|
||||
generator = _stream_from_cache(query_parameters, batch_size, limit)
|
||||
else:
|
||||
generator = _stream_from_db(query_parameters, batch_size, limit)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
generator,
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ async def _verify_authentik_token(token: str) -> User:
|
|||
algorithms=["RS256"],
|
||||
audience=OIDC_CLIENT_ID,
|
||||
issuer=metadata["issuer"],
|
||||
options={"verify_exp": False},
|
||||
)
|
||||
return User(**payload)
|
||||
|
||||
|
|
@ -84,7 +83,9 @@ async def get_current_user(
|
|||
) -> User:
|
||||
token = credentials.credentials
|
||||
try:
|
||||
# Peek at unverified issuer to route verification
|
||||
# Decode WITHOUT verification just to read the "iss" claim for routing.
|
||||
# This is safe: we only use the issuer to decide which verified decode
|
||||
# path to take next; the actual security check happens in the branch below.
|
||||
unverified = jwt.decode(
|
||||
token, options={"verify_signature": False, "verify_exp": False}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from datetime import timedelta
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Authentik OIDC Configuration
|
||||
AUTHENTIK_URL = "https://authentik.viktorbarzin.me"
|
||||
OIDC_CLIENT_ID = "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe"
|
||||
AUTHENTIK_URL = os.getenv("AUTHENTIK_URL", "https://authentik.viktorbarzin.me")
|
||||
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe")
|
||||
OIDC_METADATA_URL = (
|
||||
f"{AUTHENTIK_URL}/application/o/wrongmove/.well-known/openid-configuration"
|
||||
)
|
||||
|
|
@ -23,6 +26,8 @@ WEBAUTHN_ORIGIN = os.getenv("WEBAUTHN_ORIGIN", "https://localhost")
|
|||
|
||||
# JWT Configuration (for passkey-issued tokens)
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "change-me-in-production")
|
||||
if JWT_SECRET == "change-me-in-production":
|
||||
_logger.warning("JWT_SECRET is using the default value. Set JWT_SECRET env var in production.")
|
||||
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
JWT_EXPIRATION_HOURS = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
|
||||
JWT_ISSUER = os.getenv("JWT_ISSUER", "wrongmove")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,11 @@
|
|||
"""Legacy filesystem-based data access.
|
||||
|
||||
.. deprecated::
|
||||
This module is only used by the ``populate_db`` CLI command for migrating
|
||||
old filesystem data into the database. Do not import from this module in
|
||||
new code. Use ``models.listing.RentListing`` or ``models.listing.BuyListing``
|
||||
and ``repositories.listing_repository.ListingRepository`` instead.
|
||||
"""
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -381,8 +389,6 @@ class Listing:
|
|||
for item in data
|
||||
]
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def listing_site(self) -> ListingSite:
|
||||
return ListingSite.RIGHTMOVE # this class supports only right move
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
{"root":["./src/App.tsx","./src/AppSidebar.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/auth/authService.ts","./src/auth/config.ts","./src/auth/errors.ts","./src/components/ActiveQuery.tsx","./src/components/AlertError.tsx","./src/components/AuthCallback.tsx","./src/components/FilterPanel.tsx","./src/components/Header.tsx","./src/components/HealthIndicator.tsx","./src/components/ListView.tsx","./src/components/LoginModal.tsx","./src/components/Map.tsx","./src/components/Parameters.tsx","./src/components/PropertyCard.tsx","./src/components/Spinner.tsx","./src/components/StatsBar.tsx","./src/components/StreamingProgressBar.tsx","./src/components/TaskIndicator.tsx","./src/components/ui/DatePicker.tsx","./src/components/ui/accordion.tsx","./src/components/ui/alert-dialog.tsx","./src/components/ui/breadcrumb.tsx","./src/components/ui/button.tsx","./src/components/ui/calendar.tsx","./src/components/ui/checkbox.tsx","./src/components/ui/dialog.tsx","./src/components/ui/form.tsx","./src/components/ui/hover-card.tsx","./src/components/ui/input.tsx","./src/components/ui/label.tsx","./src/components/ui/popover.tsx","./src/components/ui/progress.tsx","./src/components/ui/scroll-area.tsx","./src/components/ui/select.tsx","./src/components/ui/separator.tsx","./src/components/ui/sheet.tsx","./src/components/ui/sidebar.tsx","./src/components/ui/skeleton.tsx","./src/components/ui/slider.tsx","./src/components/ui/tooltip.tsx","./src/constants/colorSchemes.ts","./src/constants/index.ts","./src/hooks/use-mobile.ts","./src/lib/utils.ts","./src/services/apiClient.ts","./src/services/healthService.ts","./src/services/index.ts","./src/services/listingService.ts","./src/services/streamingService.ts","./src/services/taskService.ts","./src/types/index.ts","./src/utils/mapUtils.ts"],"version":"5.8.3"}
|
||||
{"root":["./src/app.tsx","./src/appsidebar.tsx","./src/main.tsx","./src/vite-env.d.ts","./src/auth/authservice.ts","./src/auth/config.ts","./src/auth/errors.ts","./src/auth/passkeyservice.ts","./src/auth/types.ts","./src/components/activequery.tsx","./src/components/alerterror.tsx","./src/components/authcallback.tsx","./src/components/filterpanel.tsx","./src/components/header.tsx","./src/components/healthindicator.tsx","./src/components/listview.tsx","./src/components/loginmodal.tsx","./src/components/map.tsx","./src/components/parameters.tsx","./src/components/propertycard.tsx","./src/components/spinner.tsx","./src/components/statsbar.tsx","./src/components/streamingprogressbar.tsx","./src/components/taskindicator.tsx","./src/components/taskprogressdrawer.tsx","./src/components/ui/datepicker.tsx","./src/components/ui/accordion.tsx","./src/components/ui/alert-dialog.tsx","./src/components/ui/breadcrumb.tsx","./src/components/ui/button.tsx","./src/components/ui/calendar.tsx","./src/components/ui/checkbox.tsx","./src/components/ui/dialog.tsx","./src/components/ui/form.tsx","./src/components/ui/hover-card.tsx","./src/components/ui/input.tsx","./src/components/ui/label.tsx","./src/components/ui/popover.tsx","./src/components/ui/progress.tsx","./src/components/ui/scroll-area.tsx","./src/components/ui/select.tsx","./src/components/ui/separator.tsx","./src/components/ui/sheet.tsx","./src/components/ui/sidebar.tsx","./src/components/ui/skeleton.tsx","./src/components/ui/slider.tsx","./src/components/ui/tabs.tsx","./src/components/ui/tooltip.tsx","./src/constants/colorschemes.ts","./src/constants/index.ts","./src/hooks/use-mobile.ts","./src/lib/utils.ts","./src/services/apiclient.ts","./src/services/healthservice.ts","./src/services/index.ts","./src/services/listingservice.ts","./src/services/streamingservice.ts","./src/services/taskservice.ts","./src/types/index.ts","./src/utils/maputils.ts"],"version":"5.8.3"}
|
||||
|
|
@ -6,6 +6,7 @@ from datetime import datetime
|
|||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
from models.listing import FurnishType, Listing, ListingSite, RentListing
|
||||
from rec import floorplan
|
||||
|
|
@ -14,8 +15,33 @@ from repositories.listing_repository import ListingRepository
|
|||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
# Also use celery task logger for visibility in worker output
|
||||
celery_logger = logging.getLogger("celery.task")
|
||||
# Limit OCR threads to 25% of available cores to avoid starving other work.
|
||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||
|
||||
|
||||
def _parse_furnish_type(raw: str | None) -> FurnishType:
    """Convert the raw furnish-type string from the API into a FurnishType.

    Args:
        raw: Furnish-type text as received, or ``None`` when absent.

    Returns:
        ``FurnishType.UNKNOWN`` for ``None`` or an unrecognised value,
        ``FurnishType.ASK_LANDLORD`` when the text mentions "landlord"
        (case-insensitive), otherwise the enum member whose value equals the
        lower-cased text.
    """
    if raw is None:
        return FurnishType.UNKNOWN

    normalised = raw.lower()
    if "landlord" in normalised:
        return FurnishType.ASK_LANDLORD

    try:
        furnish_type = FurnishType(normalised)
    except ValueError:
        # Value is not one of the enum's known values.
        furnish_type = FurnishType.UNKNOWN
    return furnish_type
|
||||
|
||||
|
||||
def _parse_available_from(raw: str | None) -> datetime | None:
|
||||
"""Parse the available-from date string into a datetime, or None."""
|
||||
if raw is None:
|
||||
return None
|
||||
if raw.lower() == "now":
|
||||
return datetime.now()
|
||||
try:
|
||||
return datetime.strptime(raw, "%d/%m/%Y")
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class ListingProcessor:
|
||||
|
|
@ -62,7 +88,6 @@ class ListingProcessor:
|
|||
on_step_complete(short_name)
|
||||
except Exception as e:
|
||||
logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
|
||||
celery_logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
|
||||
return None
|
||||
return listing
|
||||
|
||||
|
|
@ -92,7 +117,7 @@ class FetchListingDetailsStep(Step):
|
|||
|
||||
async def process(self, listing_id: int) -> Listing:
|
||||
logger.debug(f"[{listing_id}] Fetching property details from API")
|
||||
celery_logger.info(f"[{listing_id}] Fetching details...")
|
||||
logger.info(f"[{listing_id}] Fetching details...")
|
||||
|
||||
existing_listings = await self.listing_repository.get_listings(
|
||||
only_ids=[listing_id]
|
||||
|
|
@ -105,30 +130,15 @@ class FetchListingDetailsStep(Step):
|
|||
|
||||
listing_details = await detail_query(listing_id)
|
||||
|
||||
furnish_type_str = listing_details["property"].get("letFurnishType", "unknown")
|
||||
if furnish_type_str is None:
|
||||
furnish_type_str = "unknown"
|
||||
elif "landlord" in furnish_type_str.lower():
|
||||
furnish_type_str = "ask landlord"
|
||||
else:
|
||||
furnish_type_str = furnish_type_str.lower()
|
||||
furnish_type = FurnishType(furnish_type_str)
|
||||
furnish_type = _parse_furnish_type(
|
||||
listing_details["property"].get("letFurnishType", "unknown")
|
||||
)
|
||||
|
||||
available_from: datetime | None = None
|
||||
available_from_str: str | None = listing_details["property"]["letDateAvailable"]
|
||||
if available_from_str is None:
|
||||
available_from = None
|
||||
elif available_from_str.lower() == "now":
|
||||
available_from = datetime.now()
|
||||
else:
|
||||
try:
|
||||
available_from = datetime.strptime(available_from_str, "%d/%m/%Y")
|
||||
except ValueError:
|
||||
# If the date format is not as expected, return None
|
||||
available_from = None
|
||||
available_from = _parse_available_from(
|
||||
listing_details["property"]["letDateAvailable"]
|
||||
)
|
||||
|
||||
photos = listing_details["property"]["photos"]
|
||||
# listing = Listing(
|
||||
listing = RentListing( # TODO: should pick based on price?
|
||||
id=listing_id,
|
||||
price=listing_details["property"]["price"],
|
||||
|
|
@ -150,7 +160,7 @@ class FetchListingDetailsStep(Step):
|
|||
)
|
||||
await self.listing_repository.upsert_listings([listing])
|
||||
|
||||
celery_logger.info(
|
||||
logger.info(
|
||||
f"[{listing_id}] Details fetched: £{listing.price}, "
|
||||
f"{listing.number_of_bedrooms}BR, {listing.agency}"
|
||||
)
|
||||
|
|
@ -190,13 +200,13 @@ class FetchImagesStep(Step):
|
|||
|
||||
downloaded = 0
|
||||
client_timeout = aiohttp.ClientTimeout(total=30)
|
||||
for floorplan_obj in all_floorplans:
|
||||
url = floorplan_obj["url"]
|
||||
picname = url.split("/")[-1]
|
||||
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
|
||||
if floorplan_path.exists():
|
||||
continue
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for floorplan_obj in all_floorplans:
|
||||
url = floorplan_obj["url"]
|
||||
picname = Path(urlparse(url).path).name
|
||||
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
|
||||
if floorplan_path.exists():
|
||||
continue
|
||||
async with session.get(url, timeout=client_timeout) as response:
|
||||
if response.status == 404:
|
||||
return listing
|
||||
|
|
@ -210,7 +220,7 @@ class FetchImagesStep(Step):
|
|||
|
||||
await self.listing_repository.upsert_listings([listing])
|
||||
|
||||
celery_logger.info(f"[{listing_id}] Downloaded {downloaded} floorplan images")
|
||||
logger.info(f"[{listing_id}] Downloaded {downloaded} floorplan images")
|
||||
logger.debug(f"[{listing_id}] Image fetch complete")
|
||||
return listing
|
||||
|
||||
|
|
@ -220,7 +230,7 @@ class DetectFloorplanStep(Step):
|
|||
|
||||
def __init__(self, listing_repository: ListingRepository):
|
||||
super().__init__(listing_repository)
|
||||
self.ocr_semaphore = asyncio.Semaphore(multiprocessing.cpu_count() // 4)
|
||||
self.ocr_semaphore = asyncio.Semaphore(MAX_OCR_WORKERS)
|
||||
|
||||
async def needs_processing(self, listing_id: int) -> bool:
|
||||
listings = await self.listing_repository.get_listings(only_ids=[listing_id])
|
||||
|
|
@ -256,7 +266,7 @@ class DetectFloorplanStep(Step):
|
|||
await self.listing_repository.upsert_listings([listing])
|
||||
|
||||
if max_sqm > 0:
|
||||
celery_logger.info(f"[{listing_id}] OCR detected {max_sqm} sqm")
|
||||
logger.info(f"[{listing_id}] OCR detected {max_sqm} sqm")
|
||||
else:
|
||||
logger.debug(f"[{listing_id}] OCR: no square meters detected")
|
||||
|
||||
|
|
|
|||
203
crawler/main.py
203
crawler/main.py
|
|
@ -22,13 +22,50 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
def build_query_parameters(
    type: str,
    district: list[str] | tuple[str, ...] | None,
    min_bedrooms: int,
    max_bedrooms: int,
    min_price: int,
    max_price: int,
    furnish_types: list[str] | tuple[str, ...],
    available_from: datetime | None,
    last_seen_days: int,
    min_sqm: int | None = None,
    radius: int = 0,
    page_size: int = 500,
    max_days_since_added: int = 14,
) -> QueryParameters:
    """Assemble a ``QueryParameters`` object from raw CLI option values.

    ``type`` and each entry of ``furnish_types`` are looked up by member name
    in ``ListingType`` and ``FurnishType`` respectively. An empty or missing
    ``district`` becomes an empty set; an empty ``furnish_types`` becomes
    ``None``. All remaining arguments are forwarded unchanged.
    """
    listing_type = ListingType[type]

    if district:
        district_names = set(district)
    else:
        district_names = set()

    if furnish_types:
        selected_furnish_types = [FurnishType[name] for name in furnish_types]
    else:
        selected_furnish_types = None

    return QueryParameters(
        listing_type=listing_type,
        district_names=district_names,
        min_bedrooms=min_bedrooms,
        max_bedrooms=max_bedrooms,
        min_price=min_price,
        max_price=max_price,
        furnish_types=selected_furnish_types,
        let_date_available_from=available_from,
        last_seen_days=last_seen_days,
        min_sqm=min_sqm,
        radius=radius,
        page_size=page_size,
        max_days_since_added=max_days_since_added,
    )
|
||||
|
||||
|
||||
def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Decorator to add common options for filtering listings."""
|
||||
"""Decorator that adds common listing filter options and builds QueryParameters.
|
||||
|
||||
The wrapped function receives a `query_parameters: QueryParameters` kwarg
|
||||
instead of individual filter values.
|
||||
"""
|
||||
|
||||
@click.option(
|
||||
"--type",
|
||||
"-t",
|
||||
help="Type of listing to scrape",
|
||||
help="Type of listing to scrape (BUY or RENT)",
|
||||
type=click.Choice(
|
||||
ListingType.__members__.keys(),
|
||||
case_sensitive=False,
|
||||
|
|
@ -50,26 +87,26 @@ def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
|
|||
@click.option(
|
||||
"--min-price",
|
||||
default=0,
|
||||
help="Minimum price",
|
||||
help="Minimum price in GBP",
|
||||
type=click.IntRange(min=0),
|
||||
)
|
||||
@click.option(
|
||||
"--max-price",
|
||||
default=999_999,
|
||||
help="Maximum price",
|
||||
help="Maximum price in GBP",
|
||||
type=click.IntRange(min=0),
|
||||
)
|
||||
@click.option(
|
||||
"--district",
|
||||
default=None,
|
||||
help="Districts to scrape",
|
||||
help="District to filter by (can be repeated for multiple districts)",
|
||||
type=click.Choice(district_service.get_district_names(), case_sensitive=False),
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
"--furnish-types",
|
||||
"-f",
|
||||
help="Furnish types for rented listings",
|
||||
help="Furnish type filter for rented listings (can be repeated)",
|
||||
type=click.Choice(
|
||||
[furnish_type.name for furnish_type in FurnishType.__members__.values()],
|
||||
case_sensitive=False,
|
||||
|
|
@ -78,13 +115,13 @@ def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
|
|||
)
|
||||
@click.option(
|
||||
"--available-from",
|
||||
help="Let date available from",
|
||||
help="Only include listings available from this date (format: YYYY-MM-DD)",
|
||||
default=None,
|
||||
type=click.DateTime(),
|
||||
)
|
||||
@click.option(
|
||||
"--last-seen-days",
|
||||
help="Last seen (days). If set, only listings that were seen in the last N days will be included.",
|
||||
help="Only include listings seen in the last N days",
|
||||
default=14,
|
||||
type=int,
|
||||
)
|
||||
|
|
@ -95,45 +132,37 @@ def listing_filter_options(func: Callable[P, R]) -> Callable[P, R]:
|
|||
type=int,
|
||||
)
|
||||
@wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
def wrapper(
|
||||
*args: P.args,
|
||||
type: str,
|
||||
district: tuple[str, ...],
|
||||
min_bedrooms: int,
|
||||
max_bedrooms: int,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
furnish_types: tuple[str, ...],
|
||||
available_from: datetime | None,
|
||||
last_seen_days: int,
|
||||
min_sqm: int | None,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
query_parameters = build_query_parameters(
|
||||
type=type,
|
||||
district=district,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
furnish_types=furnish_types,
|
||||
available_from=available_from,
|
||||
last_seen_days=last_seen_days,
|
||||
min_sqm=min_sqm,
|
||||
)
|
||||
return func(*args, query_parameters=query_parameters, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_query_parameters(
    type: str,
    district: list[str],
    min_bedrooms: int,
    max_bedrooms: int,
    min_price: int,
    max_price: int,
    furnish_types: list[str],
    available_from: datetime | None,
    last_seen_days: int,
    min_sqm: int | None = None,
    radius: int = 0,
    page_size: int = 500,
    max_days_since_added: int = 14,
) -> QueryParameters:
    """Assemble a ``QueryParameters`` object from raw CLI option values.

    ``type`` and each entry of ``furnish_types`` are looked up by member name
    in ``ListingType`` and ``FurnishType`` respectively. An empty ``district``
    or ``furnish_types`` is passed on as ``None``. All remaining arguments
    are forwarded unchanged.
    """
    listing_type = ListingType[type]

    district_names = None
    if district:
        district_names = set(district)

    selected_furnish_types = None
    if furnish_types:
        selected_furnish_types = [FurnishType[name] for name in furnish_types]

    return QueryParameters(
        listing_type=listing_type,
        district_names=district_names,
        min_bedrooms=min_bedrooms,
        max_bedrooms=max_bedrooms,
        min_price=min_price,
        max_price=max_price,
        furnish_types=selected_furnish_types,
        let_date_available_from=available_from,
        last_seen_days=last_seen_days,
        min_sqm=min_sqm,
        radius=radius,
        page_size=page_size,
        max_days_since_added=max_days_since_added,
    )
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
"--data-dir",
|
||||
|
|
@ -155,46 +184,28 @@ def cli(ctx: click.Context, data_dir: str) -> None:
|
|||
|
||||
@cli.command()
|
||||
@listing_filter_options
|
||||
@click.option("--full", is_flag=True, help="Include images and floorplan detection")
|
||||
@click.option(
|
||||
"--include-processing",
|
||||
"-p",
|
||||
is_flag=True,
|
||||
help="Also download images and run floorplan OCR detection",
|
||||
)
|
||||
@click.pass_context
|
||||
def dump_listings(
|
||||
ctx: click.Context,
|
||||
full: bool,
|
||||
district: list[str],
|
||||
min_bedrooms: int,
|
||||
max_bedrooms: int,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
type: str,
|
||||
furnish_types: list[str],
|
||||
available_from: datetime | None,
|
||||
last_seen_days: int,
|
||||
min_sqm: int | None = None,
|
||||
query_parameters: QueryParameters,
|
||||
include_processing: bool,
|
||||
) -> None:
|
||||
"""Fetch listings from Rightmove API."""
|
||||
data_dir: pathlib.Path = ctx.obj["data_dir"]
|
||||
repository: ListingRepository = ctx.obj["repository"]
|
||||
|
||||
query_parameters = build_query_parameters(
|
||||
type=type,
|
||||
district=district,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
furnish_types=furnish_types,
|
||||
available_from=available_from,
|
||||
last_seen_days=last_seen_days,
|
||||
min_sqm=min_sqm,
|
||||
)
|
||||
|
||||
click.echo(f"Fetching listings with parameters: {query_parameters}")
|
||||
|
||||
result = asyncio.run(
|
||||
listing_service.refresh_listings(
|
||||
repository,
|
||||
query_parameters,
|
||||
full=full,
|
||||
full=include_processing,
|
||||
async_mode=False,
|
||||
)
|
||||
)
|
||||
|
|
@ -240,14 +251,14 @@ def detect_floorplan(ctx: click.Context) -> None:
|
|||
@click.option(
|
||||
"--travel-mode",
|
||||
"-m",
|
||||
help="Travel mode for routing",
|
||||
help="Travel mode for routing (e.g. transit, driving, walking, bicycling)",
|
||||
type=click.Choice(TravelMode.__members__.keys(), case_sensitive=False),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--limit",
|
||||
"-l",
|
||||
help="Limit the number of listings to process",
|
||||
help="Maximum number of listings to calculate routes for",
|
||||
type=click.IntRange(min=1),
|
||||
default=1,
|
||||
)
|
||||
|
|
@ -293,33 +304,11 @@ def routing(
|
|||
def export_csv(
|
||||
ctx: click.Context,
|
||||
output_file: str,
|
||||
district: list[str],
|
||||
min_bedrooms: int,
|
||||
max_bedrooms: int,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
type: str,
|
||||
furnish_types: list[str],
|
||||
available_from: datetime | None,
|
||||
last_seen_days: int,
|
||||
min_sqm: int | None = None,
|
||||
query_parameters: QueryParameters,
|
||||
) -> None:
|
||||
"""Export listings to CSV file."""
|
||||
repository: ListingRepository = ctx.obj["repository"]
|
||||
|
||||
query_parameters = build_query_parameters(
|
||||
type=type,
|
||||
district=district,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
furnish_types=furnish_types,
|
||||
available_from=available_from,
|
||||
last_seen_days=last_seen_days,
|
||||
min_sqm=min_sqm,
|
||||
)
|
||||
|
||||
click.echo(f"Exporting to {output_file}")
|
||||
|
||||
result = asyncio.run(
|
||||
|
|
@ -346,33 +335,11 @@ def export_csv(
|
|||
def export_immoweb(
|
||||
ctx: click.Context,
|
||||
output_file: str,
|
||||
district: list[str],
|
||||
min_bedrooms: int,
|
||||
max_bedrooms: int,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
type: str,
|
||||
furnish_types: list[str],
|
||||
available_from: datetime | None,
|
||||
last_seen_days: int,
|
||||
min_sqm: int | None = None,
|
||||
query_parameters: QueryParameters,
|
||||
) -> None:
|
||||
"""Export listings to GeoJSON file for map visualization."""
|
||||
repository: ListingRepository = ctx.obj["repository"]
|
||||
|
||||
query_parameters = build_query_parameters(
|
||||
type=type,
|
||||
district=district,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
furnish_types=furnish_types,
|
||||
available_from=available_from,
|
||||
last_seen_days=last_seen_days,
|
||||
min_sqm=min_sqm,
|
||||
)
|
||||
|
||||
click.echo(f"Exporting to {output_file}")
|
||||
|
||||
result = asyncio.run(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from datetime import datetime, timedelta
|
|||
import enum
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from pydantic import BaseModel, Field as PydanticField, model_validator
|
||||
from rec import routing
|
||||
from sqlmodel import JSON, TEXT, SQLModel, Field
|
||||
|
||||
|
|
@ -52,6 +52,21 @@ class ListingSite(enum.StrEnum):
|
|||
# ... add more
|
||||
|
||||
|
||||
def _parse_price_history(price_history_json: str) -> list[PriceHistoryItem]:
    """Decode a JSON-encoded price history into ``PriceHistoryItem`` objects.

    Args:
        price_history_json: JSON text holding a sequence of objects with
            ``first_seen`` and ``last_seen`` (ISO-format date strings) and
            ``price`` keys. A falsy value yields an empty list.

    Returns:
        One ``PriceHistoryItem`` per decoded entry, in the stored order.
    """
    if not price_history_json:
        return []

    history: list[PriceHistoryItem] = []
    for entry in json.loads(str(price_history_json)):
        history.append(
            PriceHistoryItem(
                first_seen=datetime.fromisoformat(entry["first_seen"]),
                last_seen=datetime.fromisoformat(entry["last_seen"]),
                price=entry["price"],
            )
        )
    return history
|
||||
|
||||
|
||||
class Listing(SQLModel, table=False):
|
||||
id: int = Field(primary_key=True)
|
||||
price: float = Field(nullable=False, index=True)
|
||||
|
|
@ -61,7 +76,6 @@ class Listing(SQLModel, table=False):
|
|||
council_tax_band: str | None = Field(default=None, nullable=True)
|
||||
longitude: float = Field(nullable=False)
|
||||
latitude: float = Field(nullable=False)
|
||||
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
|
||||
price_history_json: str = Field(sa_type=TEXT)
|
||||
listing_site: ListingSite = Field(nullable=False)
|
||||
last_seen: datetime = Field(
|
||||
|
|
@ -103,20 +117,7 @@ class Listing(SQLModel, table=False):
|
|||
"""
|
||||
Returns a list of PriceHistoryItem objects from the price_history_json.
|
||||
"""
|
||||
if not self.price_history_json:
|
||||
return []
|
||||
parsed: list = json.loads(str(self.price_history_json))
|
||||
for item in parsed:
|
||||
item["first_seen"] = datetime.fromisoformat(item["first_seen"])
|
||||
item["last_seen"] = datetime.fromisoformat(item["last_seen"])
|
||||
return [
|
||||
PriceHistoryItem(
|
||||
first_seen=item["first_seen"],
|
||||
last_seen=item["last_seen"],
|
||||
price=item["price"],
|
||||
)
|
||||
for item in parsed
|
||||
]
|
||||
return _parse_price_history(self.price_history_json)
|
||||
|
||||
@staticmethod
|
||||
def serialize_price_history(price_history: List[PriceHistoryItem]) -> str:
|
||||
|
|
@ -142,36 +143,8 @@ class Listing(SQLModel, table=False):
|
|||
"""
|
||||
if not self.routing_info_json:
|
||||
return {}
|
||||
|
||||
# TODO: move to a separate serializer class
|
||||
json_data = json.loads(self.routing_info_json)
|
||||
destimation_routes = {}
|
||||
for destination_mode_str, routes_json in json_data.items():
|
||||
destination_mode = DestinationMode(
|
||||
destination_address=json.loads(destination_mode_str)[
|
||||
"destination_address"
|
||||
],
|
||||
travel_mode=routing.TravelMode(
|
||||
json.loads(destination_mode_str)["travel_mode"]
|
||||
),
|
||||
)
|
||||
parsed_route = json.loads(routes_json[0])
|
||||
routes = [
|
||||
Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=step["distance_meters"],
|
||||
duration_s=step["duration_s"],
|
||||
travel_mode=routing.TravelMode(step["travel_mode"]),
|
||||
)
|
||||
for step in parsed_route["legs"]
|
||||
],
|
||||
distance_meters=parsed_route["distance_meters"],
|
||||
duration_s=int(parsed_route["duration_s"]),
|
||||
)
|
||||
]
|
||||
destimation_routes[destination_mode] = routes
|
||||
return destimation_routes
|
||||
from rec.route_serializer import RouteSerializer
|
||||
return RouteSerializer.deserialize(self.routing_info_json)
|
||||
|
||||
def serialize_routing_info(
|
||||
self, routing_info: dict[DestinationMode, list[Route]]
|
||||
|
|
@ -179,17 +152,8 @@ class Listing(SQLModel, table=False):
|
|||
"""
|
||||
Serializes the routing_info to a JSON string.
|
||||
"""
|
||||
# TODO: move to a separate serializer class
|
||||
# for destination_mode, routes in routing_info.items():
|
||||
serialized = json.dumps(
|
||||
{
|
||||
json.dumps(dataclasses.asdict(destination_mode)): [
|
||||
json.dumps(dataclasses.asdict(route)) for route in routes
|
||||
]
|
||||
for destination_mode, routes in routing_info.items()
|
||||
}
|
||||
)
|
||||
return serialized
|
||||
from rec.route_serializer import RouteSerializer
|
||||
return RouteSerializer.serialize(routing_info)
|
||||
|
||||
|
||||
class FurnishType(enum.StrEnum):
|
||||
|
|
@ -224,9 +188,9 @@ class DestinationMode:
|
|||
# This allows serializers to pick up a dict representation
|
||||
return asdict(self)
|
||||
|
||||
def __iter__(self):
|
||||
# Makes it behave like a dict when expected
|
||||
return iter(asdict(self).items())
|
||||
def to_dict(self) -> dict[str, Any]:
    """Build a plain dict holding this DestinationMode's dataclass fields."""
    field_values = asdict(self)
    return field_values
|
||||
|
||||
|
||||
class ListingType(enum.StrEnum):
|
||||
|
|
@ -254,3 +218,23 @@ class QueryParameters(BaseModel):
|
|||
let_date_available_from: datetime | None = None
|
||||
last_seen_days: int | None = None
|
||||
min_sqm: int | None = None
|
||||
|
||||
@model_validator(mode="after")
def _validate_ranges(self) -> QueryParameters:
    """Reject inconsistent price and bedroom ranges.

    Runs as an "after" model validator, i.e. on the already-populated
    instance. Checks are performed in this order: price ordering, negative
    min_bedrooms, negative max_bedrooms, bedroom ordering.

    Returns:
        The instance itself when every check passes.

    Raises:
        ValueError: on the first check that fails.
    """
    lowest_price, highest_price = self.min_price, self.max_price
    if lowest_price > highest_price:
        raise ValueError(
            f"min_price ({lowest_price}) must be <= max_price ({highest_price})"
        )

    # Both bedroom bounds share the same non-negativity rule and message shape.
    for field_name in ("min_bedrooms", "max_bedrooms"):
        bedroom_count = getattr(self, field_name)
        if bedroom_count < 0:
            raise ValueError(f"{field_name} ({bedroom_count}) must be non-negative")

    fewest, most = self.min_bedrooms, self.max_bedrooms
    if fewest > most:
        raise ValueError(f"min_bedrooms ({fewest}) must be <= max_bedrooms ({most})")
    return self
|
||||
|
|
|
|||
|
|
@ -36,3 +36,8 @@ def get_districts() -> dict[str, str]:
|
|||
"Wandsworth": "REGION^93977",
|
||||
"Westminster": "REGION^93980",
|
||||
}
|
||||
|
||||
|
||||
def get_district_by_name(name: str) -> str | None:
    """Look up a district's region ID by name.

    Returns:
        The value stored under *name* in get_districts(), or None when the
        name is not a key of that mapping.
    """
    region_ids = get_districts()
    return region_ids.get(name)
|
||||
|
|
|
|||
|
|
@ -72,3 +72,14 @@ class CircuitBreakerOpenError(RightmoveAPIError):
|
|||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RoutingApiError(Exception):
    """Raised for an error response from the Google Routes API.

    The exception message combines the status code and the response body;
    both are also kept as attributes for programmatic inspection.
    """

    def __init__(self, status_code: int, response_body: dict):
        message = f"Routes API returned status {status_code}: {response_body}"
        super().__init__(message)
        # HTTP status code reported with the error.
        self.status_code = status_code
        # Decoded body of the error response.
        self.response_body = response_body
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -5,6 +6,11 @@ from PIL import Image
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_SQM = 30
|
||||
MAX_SQM = 160
|
||||
|
||||
|
||||
def inference(image_path: str | Path) -> tuple[str, Any]:
|
||||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||
|
|
@ -22,26 +28,21 @@ def inference(image_path: str | Path) -> tuple[str, Any]:
|
|||
|
||||
|
||||
def extract_total_sqm(input_str: str) -> float | None:
    """Return the largest plausible floor area, in square metres, found in text.

    Can be used on the output of inference() to extract sqm from model
    predictions.

    The text is lower-cased and searched for a number followed by "sqm",
    "sq m", "sq.m" or "sq. m". Only values strictly between MIN_SQM and
    MAX_SQM are considered.

    Args:
        input_str: Free text, e.g. OCR output of a floor plan.

    Returns:
        The largest value inside the accepted range, or None when no
        candidate qualifies.
    """
    # "[. ]" accepts only a literal dot or a space after "sq". The previous
    # pattern used a bare ".", which matched any character at all there.
    sqm_pattern = r"(\d+\.?\d*) ?sq[. ]? ?m"
    candidates = (
        float(number) for number in re.findall(sqm_pattern, input_str.lower())
    )
    plausible = [sqm for sqm in candidates if MIN_SQM < sqm < MAX_SQM]
    return max(plausible, default=None)
|
||||
|
||||
|
||||
def calculate_model(image_path: str | Path) -> tuple[float | None, str, Any]:
|
||||
output, predictions_tensor = inference(image_path)
|
||||
estimated_sqm = extract_total_sqm(output)
|
||||
return estimated_sqm, output, predictions_tensor
|
||||
|
||||
|
||||
def improve_img_for_ocr(img: Image.Image) -> Image.Image:
    """Produce a binarized, slightly enlarged copy of *img* for a second OCR pass.

    The image is converted to 8-bit grayscale ("L" mode), scaled by a factor
    of 1.2 on both axes with cubic interpolation, and then passed through a
    Gaussian adaptive threshold (block size 11, constant 2).

    Returns:
        A new PIL image built from the thresholded array.
    """
    enlarged = cv2.resize(
        np.array(img.convert("L")),
        None,
        fx=1.2,
        fy=1.2,
        interpolation=cv2.INTER_CUBIC,
    )
    binarized = cv2.adaptiveThreshold(
        enlarged, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
    )
    return Image.fromarray(binarized)
|
||||
|
||||
|
|
@ -49,15 +50,18 @@ def improve_img_for_ocr(img: Image.Image) -> Image.Image:
|
|||
def calculate_ocr(image_path: str | Path) -> tuple[float | None, str]:
    """Estimate a floor area by running OCR over a floor-plan image.

    The raw image is OCR'd first. If no plausible area is found in that text,
    the image is pre-processed with improve_img_for_ocr() and OCR'd again.

    Args:
        image_path: Location of the image file.

    Returns:
        A (estimated_sqm, text) pair: the area extracted by
        extract_total_sqm() (None when nothing qualified) and the OCR text
        it was extracted from. When the second pass runs, its result and
        text are returned even if it also found nothing.

    Raises:
        FileNotFoundError: if *image_path* does not exist.
    """
    import pytesseract

    path = Path(image_path)
    if not path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Image.open is lazy and holds the file open; the context manager
    # guarantees the handle is released on every exit path, including errors.
    with Image.open(path) as img:
        text = pytesseract.image_to_string(img)
        estimated_sqm = extract_total_sqm(text)
        if estimated_sqm is not None:
            return estimated_sqm, text

        # First pass found nothing usable: retry on a cleaned-up image.
        improved_img = improve_img_for_ocr(img)
        text2 = pytesseract.image_to_string(improved_img)
        estimated_sqm2 = extract_total_sqm(text2)
        logger.debug(
            "before: %s after: %s - %s", estimated_sqm, estimated_sqm2, image_path
        )
        return estimated_sqm2, text2
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@ logger = logging.getLogger("uvicorn.error")
|
|||
# Global circuit breaker instance
|
||||
_circuit_breaker: CircuitBreaker | None = None
|
||||
|
||||
# API constants
|
||||
ANDROID_APP_VERSION = "3.70.0"
|
||||
ANDROID_APP_VERSION_LISTING = "4.28.0"
|
||||
RIGHTMOVE_API_BASE = "https://api.rightmove.co.uk/api"
|
||||
PROPERTY_LISTING_ENDPOINT = f"{RIGHTMOVE_API_BASE}/property-listing"
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"Host": "api.rightmove.co.uk",
|
||||
|
|
@ -35,6 +40,11 @@ DEFAULT_HEADERS = {
|
|||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
LISTING_HEADERS = {
|
||||
**DEFAULT_HEADERS,
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
}
|
||||
|
||||
|
||||
class PropertyType(enum.StrEnum):
|
||||
BUNGALOW = "bungalow"
|
||||
|
|
@ -129,6 +139,177 @@ def check_circuit_breaker(config: ScraperConfig | None = None) -> None:
|
|||
cb.call()
|
||||
|
||||
|
||||
def _build_base_params(
    *,
    channel: ListingType,
    page: int,
    page_size: int,
    radius: float,
    min_price: int,
    max_price: int,
    min_bedrooms: int,
    max_bedrooms: int,
    district: str,
) -> dict[str, str]:
    """Assemble the query parameters common to every property-listing request.

    All values are converted to strings. *district* is translated to its
    location identifier through districts.get_districts(), so an unknown
    district name raises KeyError.

    Returns:
        A new dict with location, channel, paging, radius, sort order,
        price/bedroom bounds and the Android app identification.
    """
    location_identifier = districts.get_districts()[district]
    params: dict[str, str] = {
        "locationIdentifier": location_identifier,
        "channel": str(channel).upper(),
    }

    # Paging and search-area settings.
    params["page"] = str(page)
    params["numberOfPropertiesPerPage"] = str(page_size)
    params["radius"] = str(radius)
    params["sortBy"] = "distance"
    params["includeUnavailableProperties"] = "false"

    # Numeric bounds, stringified in a fixed order.
    bounds = (
        ("minPrice", min_price),
        ("maxPrice", max_price),
        ("minBedrooms", min_bedrooms),
        ("maxBedrooms", max_bedrooms),
    )
    for key, value in bounds:
        params[key] = str(value)

    # Client identification expected by the listing endpoint.
    params["apiApplication"] = "ANDROID"
    params["appVersion"] = ANDROID_APP_VERSION_LISTING
    return params
|
||||
|
||||
|
||||
def _build_listing_params(
|
||||
*,
|
||||
page: int,
|
||||
channel: ListingType,
|
||||
min_bedrooms: int,
|
||||
max_bedrooms: int,
|
||||
radius: float,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
district: str,
|
||||
mustNewHome: bool,
|
||||
max_days_since_added: int,
|
||||
property_type: list[PropertyType],
|
||||
page_size: int,
|
||||
furnish_types: list[FurnishType],
|
||||
) -> dict[str, str]:
|
||||
params = _build_base_params(
|
||||
channel=channel,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
radius=radius,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
district=district,
|
||||
)
|
||||
if channel is ListingType.BUY:
|
||||
params["dontShow"] = "sharedOwnership,retirement"
|
||||
if len(property_type) > 0:
|
||||
params["propertyTypes"] = ",".join(property_type)
|
||||
if max_days_since_added is not None and max_days_since_added not in [
|
||||
1,
|
||||
3,
|
||||
7,
|
||||
14,
|
||||
]:
|
||||
raise Exception(
|
||||
f"Invalid max days - {max_days_since_added} Can only be got",
|
||||
[1, 3, 7, 14],
|
||||
)
|
||||
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
||||
|
||||
if mustNewHome:
|
||||
params["mustHave"] = "newHome"
|
||||
if channel is ListingType.RENT:
|
||||
if furnish_types:
|
||||
params["furnishTypes"] = ",".join(furnish_types)
|
||||
return params
|
||||
|
||||
|
||||
def _build_probe_params(
|
||||
*,
|
||||
channel: ListingType,
|
||||
min_bedrooms: int,
|
||||
max_bedrooms: int,
|
||||
radius: float,
|
||||
min_price: int,
|
||||
max_price: int,
|
||||
district: str,
|
||||
max_days_since_added: int,
|
||||
furnish_types: list[FurnishType],
|
||||
) -> dict[str, str]:
|
||||
params = _build_base_params(
|
||||
channel=channel,
|
||||
page=1,
|
||||
page_size=1, # Minimal page size for probing
|
||||
radius=radius,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
district=district,
|
||||
)
|
||||
if channel is ListingType.BUY:
|
||||
params["dontShow"] = "sharedOwnership,retirement"
|
||||
if max_days_since_added is not None and max_days_since_added in [
|
||||
1,
|
||||
3,
|
||||
7,
|
||||
14,
|
||||
]:
|
||||
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
||||
|
||||
if channel is ListingType.RENT:
|
||||
if furnish_types:
|
||||
params["furnishTypes"] = ",".join(furnish_types)
|
||||
return params
|
||||
|
||||
|
||||
async def _execute_api_request(
    *,
    url: str,
    params: dict[str, str],
    headers: dict[str, str],
    session: aiohttp.ClientSession | None,
    config: ScraperConfig,
    expect_data: bool = True,
    error_context: str = "",
) -> dict[str, Any]:
    """Perform one GET request with circuit-breaker and throttling bookkeeping.

    Args:
        url: Endpoint to call.
        params: Query-string parameters.
        headers: HTTP headers to send.
        session: Session to reuse. When falsy, a temporary session
            (trust_env=True) is opened for this single call and closed
            afterwards.
        config: Scraper configuration; supplies the slow-response threshold
            and is handed to the circuit-breaker helpers.
        expect_data: Forwarded to validate_response().
        error_context: Prefix for the message of the exception raised on a
            non-200 response.

    Returns:
        The decoded JSON body of a 200 response.

    Raises:
        Exception: when the response status is not 200; the message contains
            *error_context*, the HTTP status and the response text.
        Any exception raised by check_circuit_breaker(), validate_response()
        or the HTTP layer is propagated; failures inside the request are
        recorded on the circuit breaker first.
    """
    check_circuit_breaker(config)
    cb = get_circuit_breaker(config)

    async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
        # Monotonic clock: the elapsed time must not be skewed by wall-clock
        # adjustments while the request is in flight.
        start_time = time.monotonic()
        try:
            async with s.get(url, params=params, headers=headers) as response:
                response_time = time.monotonic() - start_time
                # Only a 200 response is decoded; anything else leaves body empty.
                body = await response.json() if response.status == 200 else None

                validate_response(
                    response,
                    response_time,
                    body,
                    config.slow_response_threshold,
                    expect_data=expect_data,
                )

                if response.status != 200:
                    # Include the status explicitly: callers' contexts (e.g.
                    # "... Status Code: ") rely on it being part of the message.
                    raise Exception(
                        f"{error_context}HTTP {response.status}. "
                        f"Failed due to: {await response.text()}"
                    )

                if cb is not None:
                    cb.record_success()
                return body  # type: ignore
        except Exception:
            # Throttling and every other failure count against the breaker;
            # the original exception is re-raised unchanged.
            if cb is not None:
                cb.record_failure()
            raise

    if session:
        return await do_request(session)

    async with aiohttp.ClientSession(trust_env=True) as new_session:
        return await do_request(new_session)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(ThrottlingError),
|
||||
wait=wait_exponential(multiplier=2, min=2, max=120),
|
||||
|
|
@ -156,54 +337,21 @@ async def detail_query(
|
|||
if config is None:
|
||||
config = ScraperConfig.from_env()
|
||||
|
||||
check_circuit_breaker(config)
|
||||
cb = get_circuit_breaker(config)
|
||||
|
||||
params = {
|
||||
"apiApplication": "ANDROID",
|
||||
"appVersion": "3.70.0",
|
||||
"appVersion": ANDROID_APP_VERSION,
|
||||
}
|
||||
url = f"https://api.rightmove.co.uk/api/property/{detail_id}"
|
||||
url = f"{RIGHTMOVE_API_BASE}/property/{detail_id}"
|
||||
|
||||
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with s.get(url, params=params, headers=DEFAULT_HEADERS) as response:
|
||||
response_time = time.time() - start_time
|
||||
body = await response.json() if response.status == 200 else None
|
||||
|
||||
# Validate response for throttling
|
||||
validate_response(
|
||||
response,
|
||||
response_time,
|
||||
body,
|
||||
config.slow_response_threshold,
|
||||
expect_data=True,
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"""id: {detail_id}. Status Code: {response.status}."""
|
||||
f"""Failed due to: {await response.text()}"""
|
||||
)
|
||||
|
||||
if cb is not None:
|
||||
cb.record_success()
|
||||
return body # type: ignore
|
||||
except ThrottlingError:
|
||||
if cb is not None:
|
||||
cb.record_failure()
|
||||
raise
|
||||
except Exception as e:
|
||||
if cb is not None:
|
||||
cb.record_failure()
|
||||
raise e
|
||||
|
||||
if session:
|
||||
return await do_request(session)
|
||||
else:
|
||||
async with aiohttp.ClientSession(trust_env=True) as new_session:
|
||||
return await do_request(new_session)
|
||||
return await _execute_api_request(
|
||||
url=url,
|
||||
params=params,
|
||||
headers=DEFAULT_HEADERS,
|
||||
session=session,
|
||||
config=config,
|
||||
expect_data=True,
|
||||
error_context=f"id: {detail_id}. Status Code: ",
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
|
|
@ -223,9 +371,9 @@ async def listing_query(
|
|||
district: str, # = "STATION^5168", # kings cross station
|
||||
mustNewHome: bool = False,
|
||||
max_days_since_added: int = 30,
|
||||
property_type: list[PropertyType] = [],
|
||||
property_type: list[PropertyType] | None = None,
|
||||
page_size: int = 25,
|
||||
furnish_types: list[FurnishType] = [],
|
||||
furnish_types: list[FurnishType] | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
config: ScraperConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -257,94 +405,35 @@ async def listing_query(
|
|||
"""
|
||||
if config is None:
|
||||
config = ScraperConfig.from_env()
|
||||
if property_type is None:
|
||||
property_type = []
|
||||
if furnish_types is None:
|
||||
furnish_types = []
|
||||
|
||||
check_circuit_breaker(config)
|
||||
cb = get_circuit_breaker(config)
|
||||
params = _build_listing_params(
|
||||
page=page,
|
||||
channel=channel,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
radius=radius,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
district=district,
|
||||
mustNewHome=mustNewHome,
|
||||
max_days_since_added=max_days_since_added,
|
||||
property_type=property_type,
|
||||
page_size=page_size,
|
||||
furnish_types=furnish_types,
|
||||
)
|
||||
|
||||
params: dict[str, str] = {
|
||||
"locationIdentifier": districts.get_districts()[district],
|
||||
"channel": str(channel).upper(),
|
||||
"page": str(page),
|
||||
"numberOfPropertiesPerPage": str(page_size),
|
||||
"radius": str(radius),
|
||||
"sortBy": "distance",
|
||||
"includeUnavailableProperties": "false",
|
||||
"minPrice": str(min_price),
|
||||
"maxPrice": str(max_price),
|
||||
"minBedrooms": str(min_bedrooms),
|
||||
"maxBedrooms": str(max_bedrooms),
|
||||
"apiApplication": "ANDROID",
|
||||
"appVersion": "4.28.0",
|
||||
}
|
||||
if channel is ListingType.BUY:
|
||||
params["dontShow"] = "sharedOwnership,retirement"
|
||||
if len(property_type) > 0:
|
||||
params["propertyTypes"] = ",".join(property_type)
|
||||
if max_days_since_added is not None and max_days_since_added not in [
|
||||
1,
|
||||
3,
|
||||
7,
|
||||
14,
|
||||
]:
|
||||
raise Exception(
|
||||
f"Invalid max days - {max_days_since_added} Can only be got",
|
||||
[1, 3, 7, 14],
|
||||
)
|
||||
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
||||
|
||||
if mustNewHome:
|
||||
params["mustHave"] = "newHome"
|
||||
if channel is ListingType.RENT:
|
||||
if furnish_types:
|
||||
params["furnishTypes"] = ",".join(furnish_types)
|
||||
|
||||
request_headers = {
|
||||
"Host": "api.rightmove.co.uk",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"User-Agent": "okhttp/4.12.0",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with s.get(
|
||||
"https://api.rightmove.co.uk/api/property-listing",
|
||||
params=params,
|
||||
headers=request_headers,
|
||||
) as response:
|
||||
response_time = time.time() - start_time
|
||||
body = await response.json() if response.status == 200 else None
|
||||
|
||||
# Validate response for throttling
|
||||
validate_response(
|
||||
response,
|
||||
response_time,
|
||||
body,
|
||||
config.slow_response_threshold,
|
||||
expect_data=(page == 1), # Only expect data on first page
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed due to: {await response.text()}")
|
||||
|
||||
if cb is not None:
|
||||
cb.record_success()
|
||||
return body # type: ignore
|
||||
except ThrottlingError:
|
||||
if cb is not None:
|
||||
cb.record_failure()
|
||||
raise
|
||||
except Exception as e:
|
||||
if cb is not None:
|
||||
cb.record_failure()
|
||||
raise e
|
||||
|
||||
if session:
|
||||
return await do_request(session)
|
||||
else:
|
||||
async with aiohttp.ClientSession(trust_env=True) as new_session:
|
||||
return await do_request(new_session)
|
||||
return await _execute_api_request(
|
||||
url=PROPERTY_LISTING_ENDPOINT,
|
||||
params=params,
|
||||
headers=LISTING_HEADERS,
|
||||
session=session,
|
||||
config=config,
|
||||
expect_data=(page == 1),
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
|
|
@ -363,7 +452,7 @@ async def probe_query(
|
|||
max_price: int,
|
||||
district: str,
|
||||
max_days_since_added: int = 30,
|
||||
furnish_types: list[FurnishType] = [],
|
||||
furnish_types: list[FurnishType] | None = None,
|
||||
config: ScraperConfig | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Probe the API to get result count without fetching full results.
|
||||
|
|
@ -392,77 +481,27 @@ async def probe_query(
|
|||
"""
|
||||
if config is None:
|
||||
config = ScraperConfig.from_env()
|
||||
if furnish_types is None:
|
||||
furnish_types = []
|
||||
|
||||
check_circuit_breaker(config)
|
||||
cb = get_circuit_breaker(config)
|
||||
params = _build_probe_params(
|
||||
channel=channel,
|
||||
min_bedrooms=min_bedrooms,
|
||||
max_bedrooms=max_bedrooms,
|
||||
radius=radius,
|
||||
min_price=min_price,
|
||||
max_price=max_price,
|
||||
district=district,
|
||||
max_days_since_added=max_days_since_added,
|
||||
furnish_types=furnish_types,
|
||||
)
|
||||
|
||||
params: dict[str, str] = {
|
||||
"locationIdentifier": districts.get_districts()[district],
|
||||
"channel": str(channel).upper(),
|
||||
"page": "1",
|
||||
"numberOfPropertiesPerPage": "1", # Minimal page size for probing
|
||||
"radius": str(radius),
|
||||
"sortBy": "distance",
|
||||
"includeUnavailableProperties": "false",
|
||||
"minPrice": str(min_price),
|
||||
"maxPrice": str(max_price),
|
||||
"minBedrooms": str(min_bedrooms),
|
||||
"maxBedrooms": str(max_bedrooms),
|
||||
"apiApplication": "ANDROID",
|
||||
"appVersion": "4.28.0",
|
||||
}
|
||||
|
||||
if channel is ListingType.BUY:
|
||||
params["dontShow"] = "sharedOwnership,retirement"
|
||||
if max_days_since_added is not None and max_days_since_added in [
|
||||
1,
|
||||
3,
|
||||
7,
|
||||
14,
|
||||
]:
|
||||
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
||||
|
||||
if channel is ListingType.RENT:
|
||||
if furnish_types:
|
||||
params["furnishTypes"] = ",".join(furnish_types)
|
||||
|
||||
request_headers = {
|
||||
"Host": "api.rightmove.co.uk",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"User-Agent": "okhttp/4.12.0",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with session.get(
|
||||
"https://api.rightmove.co.uk/api/property-listing",
|
||||
params=params,
|
||||
headers=request_headers,
|
||||
) as response:
|
||||
response_time = time.time() - start_time
|
||||
body = await response.json() if response.status == 200 else None
|
||||
|
||||
# Validate response for throttling
|
||||
validate_response(
|
||||
response,
|
||||
response_time,
|
||||
body,
|
||||
config.slow_response_threshold,
|
||||
expect_data=False, # Probe doesn't need data, just count
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Probe failed: {await response.text()}")
|
||||
|
||||
if cb is not None:
|
||||
cb.record_success()
|
||||
return body # type: ignore
|
||||
except ThrottlingError:
|
||||
if cb is not None:
|
||||
cb.record_failure()
|
||||
raise
|
||||
except Exception as e:
|
||||
if cb is not None:
|
||||
cb.record_failure()
|
||||
raise e
|
||||
return await _execute_api_request(
|
||||
url=PROPERTY_LISTING_ENDPOINT,
|
||||
params=params,
|
||||
headers=LISTING_HEADERS,
|
||||
session=session,
|
||||
config=config,
|
||||
expect_data=False,
|
||||
error_context="Probe failed: ",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import dataclasses
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
|
|
@ -7,20 +8,25 @@ from rec import routing
|
|||
|
||||
class RouteSerializer:
|
||||
@staticmethod
|
||||
def serialize(route): ...
|
||||
def serialize(routing_info: dict[DestinationMode, list[Route]]) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
json.dumps(dataclasses.asdict(destination_mode)): [
|
||||
json.dumps(dataclasses.asdict(route)) for route in routes
|
||||
]
|
||||
for destination_mode, routes in routing_info.items()
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(route_data_json: str) -> dict[DestinationMode, List[Route]]:
|
||||
json_data = json.loads(route_data_json)
|
||||
destimation_routes = {}
|
||||
destination_routes = {}
|
||||
for destination_mode_str, routes_json in json_data.items():
|
||||
parsed_destination = json.loads(destination_mode_str)
|
||||
destination_mode = DestinationMode(
|
||||
destination_address=json.loads(destination_mode_str)[
|
||||
"destination_address"
|
||||
],
|
||||
travel_mode=routing.TravelMode(
|
||||
json.loads(destination_mode_str)["travel_mode"]
|
||||
),
|
||||
destination_address=parsed_destination["destination_address"],
|
||||
travel_mode=routing.TravelMode(parsed_destination["travel_mode"]),
|
||||
)
|
||||
parsed_route = json.loads(routes_json[0])
|
||||
routes = [
|
||||
|
|
@ -37,5 +43,5 @@ class RouteSerializer:
|
|||
duration_s=int(parsed_route["duration_s"]),
|
||||
)
|
||||
]
|
||||
destimation_routes[destination_mode] = routes
|
||||
return destimation_routes
|
||||
destination_routes[destination_mode] = routes
|
||||
return destination_routes
|
||||
|
|
|
|||
|
|
@ -3,9 +3,18 @@ import os
|
|||
from typing import Any
|
||||
import requests
|
||||
from rec.utils import nextMonday
|
||||
from rec.exceptions import RoutingApiError
|
||||
|
||||
url = "https://routes.googleapis.com/directions/v2:computeRoutes"
|
||||
ROUTES_API_URL = "https://routes.googleapis.com/directions/v2:computeRoutes"
|
||||
API_KEY_ENVIRONMENT_VARIABLE = "ROUTING_API_KEY"
|
||||
ROUTES_FIELD_MASK = (
|
||||
"routes.distanceMeters,"
|
||||
"routes.duration,"
|
||||
"routes.staticDuration,"
|
||||
"routes.legs.steps.distanceMeters,"
|
||||
"routes.legs.steps.staticDuration,"
|
||||
"routes.legs.steps.travelMode"
|
||||
)
|
||||
|
||||
|
||||
class TravelMode(enum.StrEnum):
|
||||
|
|
@ -20,7 +29,7 @@ def transit_route(
|
|||
origin_lon: float,
|
||||
dest_address: str,
|
||||
travel_mode: TravelMode,
|
||||
compute_alternative_routes=True,
|
||||
compute_alternative_routes: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
monday9am = nextMonday()
|
||||
|
||||
|
|
@ -30,38 +39,25 @@ def transit_route(
|
|||
header = {
|
||||
"X-Goog-Api-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
"X-Goog-FieldMask": "routes.distanceMeters,routes.duration,routes.staticDuration,routes.legs.steps.distanceMeters,routes.legs.steps.staticDuration,routes.legs.steps.travelMode", # "routes.*",
|
||||
"X-Goog-FieldMask": ROUTES_FIELD_MASK,
|
||||
}
|
||||
|
||||
body = {
|
||||
"origin": {
|
||||
# "address": origin_address
|
||||
"location": {"latLng": {"latitude": origin_lat, "longitude": origin_lon}}
|
||||
},
|
||||
"destination": {
|
||||
"address": dest_address
|
||||
# "location": {
|
||||
# "latLng": {
|
||||
# "latitude": dest_lat,
|
||||
# "longitude": dest_lon
|
||||
# }
|
||||
# }
|
||||
},
|
||||
"travelMode": travel_mode.value,
|
||||
# "2023-10-15T15:01:23.045123456Z"
|
||||
"departureTime": monday9am.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
|
||||
"computeAlternativeRoutes": compute_alternative_routes,
|
||||
# "routeModifiers": {
|
||||
# "avoidTolls": false,
|
||||
# "avoidHighways": false,
|
||||
# "avoidFerries": false
|
||||
# },
|
||||
"languageCode": "en-US",
|
||||
"units": "METRIC",
|
||||
}
|
||||
|
||||
r = requests.post(url, json=body, headers=header)
|
||||
r = requests.post(ROUTES_API_URL, json=body, headers=header)
|
||||
if r.status_code == 200:
|
||||
return r.json()
|
||||
|
||||
raise Exception(r.json())
|
||||
raise RoutingApiError(r.status_code, r.json())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from typing import Generator
|
||||
from typing import Any, Generator
|
||||
from data_access import Listing
|
||||
from models.listing import (
|
||||
BuyListing,
|
||||
|
|
@ -12,7 +12,6 @@ from models.listing import (
|
|||
)
|
||||
from sqlalchemy import Engine, func, select as sa_select
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
|
@ -27,8 +26,10 @@ STREAMING_COLUMNS = [
|
|||
|
||||
class ListingRepository:
|
||||
engine: Engine
|
||||
# anything more than 10k is considered buy type
|
||||
buy_listing_price_threshold: int = 20_000
|
||||
|
||||
# Monthly rent prices in the UK are always below 20,000 GBP.
|
||||
# Any listing priced at or above this threshold is treated as a purchase (buy) listing.
|
||||
BUY_LISTING_PRICE_THRESHOLD: int = 20_000
|
||||
|
||||
def __init__(self, engine: Engine):
|
||||
self.engine = engine
|
||||
|
|
@ -44,24 +45,16 @@ class ListingRepository:
|
|||
"""
|
||||
only_ids = only_ids or []
|
||||
|
||||
model = RentListing # if no query params, default to renting listings
|
||||
if query_parameters:
|
||||
model = (
|
||||
RentListing
|
||||
if query_parameters.listing_type == ListingType.RENT
|
||||
else BuyListing
|
||||
# else RentListing
|
||||
)
|
||||
model = self._get_model_for_query(query_parameters)
|
||||
|
||||
query = select(model)
|
||||
if only_ids:
|
||||
query = query.where(model.id.in_(only_ids)) # type: ignore
|
||||
query = self._add_where_from_query_parameters(query, model, query_parameters)
|
||||
query = self._apply_query_filters(query, model, query_parameters)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
with Session(self.engine) as session:
|
||||
# query = select(modelListing)
|
||||
rows = list(session.exec(query).all())
|
||||
logging.debug(f"Found {len(rows)} listings")
|
||||
return rows
|
||||
|
|
@ -81,16 +74,10 @@ class ListingRepository:
|
|||
limit: Maximum number of listings to yield
|
||||
chunk_size: Number of rows to fetch at a time from the database
|
||||
"""
|
||||
model = RentListing # if no query params, default to renting listings
|
||||
if query_parameters:
|
||||
model = (
|
||||
RentListing
|
||||
if query_parameters.listing_type == ListingType.RENT
|
||||
else BuyListing
|
||||
)
|
||||
model = self._get_model_for_query(query_parameters)
|
||||
|
||||
query = select(model)
|
||||
query = self._add_where_from_query_parameters(query, model, query_parameters)
|
||||
query = self._apply_query_filters(query, model, query_parameters)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
|
|
@ -111,7 +98,7 @@ class ListingRepository:
|
|||
model = self._get_model_for_query(query_parameters)
|
||||
|
||||
query = sa_select(func.count(model.id))
|
||||
query = self._add_where_from_query_parameters_raw(query, model, query_parameters)
|
||||
query = self._apply_query_filters(query, model, query_parameters)
|
||||
|
||||
with Session(self.engine) as session:
|
||||
return session.execute(query).scalar() or 0
|
||||
|
|
@ -147,7 +134,7 @@ class ListingRepository:
|
|||
break
|
||||
|
||||
query = sa_select(*columns)
|
||||
query = self._add_where_from_query_parameters_raw(
|
||||
query = self._apply_query_filters(
|
||||
query, model, query_parameters
|
||||
)
|
||||
|
||||
|
|
@ -174,13 +161,25 @@ class ListingRepository:
|
|||
if len(results) < page_size:
|
||||
break
|
||||
|
||||
def _add_where_from_query_parameters_raw(
|
||||
def _apply_query_filters(
|
||||
self,
|
||||
query,
|
||||
query: Any,
|
||||
model: type[RentListing] | type[BuyListing],
|
||||
query_parameters: QueryParameters | None = None,
|
||||
):
|
||||
"""Add WHERE clauses from query parameters (for raw SQLAlchemy selects)."""
|
||||
) -> Any:
|
||||
"""Apply WHERE clauses from query parameters to a query.
|
||||
|
||||
Works with both SQLModel select() and raw SQLAlchemy sa_select() queries,
|
||||
since both support the .where() interface.
|
||||
|
||||
Args:
|
||||
query: A SQLModel or SQLAlchemy select query
|
||||
model: The listing model class (RentListing or BuyListing)
|
||||
query_parameters: Optional filtering parameters
|
||||
|
||||
Returns:
|
||||
The query with WHERE clauses applied
|
||||
"""
|
||||
if query_parameters is None:
|
||||
return query
|
||||
query = query.where(
|
||||
|
|
@ -207,38 +206,6 @@ class ListingRepository:
|
|||
query = query.where(model.last_seen >= last_seen_threshold)
|
||||
return query
|
||||
|
||||
def _add_where_from_query_parameters(
|
||||
self,
|
||||
query: SelectOfScalar[Listing],
|
||||
model: type[Listing],
|
||||
query_parameters: QueryParameters | None = None,
|
||||
) -> SelectOfScalar[Listing]:
|
||||
if query_parameters is None:
|
||||
return query
|
||||
query = query.where(
|
||||
model.number_of_bedrooms.between(
|
||||
query_parameters.min_bedrooms, query_parameters.max_bedrooms
|
||||
),
|
||||
model.price.between(query_parameters.min_price, query_parameters.max_price),
|
||||
)
|
||||
if query_parameters.min_sqm is not None:
|
||||
query = query.where(model.square_meters >= query_parameters.min_sqm)
|
||||
if query_parameters.furnish_types and model == RentListing:
|
||||
query = query.where(model.furnish_type.in_(query_parameters.furnish_types))
|
||||
if (
|
||||
isinstance(model, RentListing)
|
||||
and query_parameters.let_date_available_from is not None
|
||||
):
|
||||
query = query.where(
|
||||
model.available_from >= query_parameters.let_date_available_from
|
||||
)
|
||||
if query_parameters.last_seen_days is not None:
|
||||
last_seen_threshold = datetime.now() - timedelta(
|
||||
days=query_parameters.last_seen_days
|
||||
)
|
||||
query = query.where(model.last_seen >= last_seen_threshold)
|
||||
return query
|
||||
|
||||
async def upsert_listings(
|
||||
self,
|
||||
listings: list[modelListing],
|
||||
|
|
@ -258,50 +225,74 @@ class ListingRepository:
|
|||
self,
|
||||
listings: list[Listing],
|
||||
) -> list[modelListing]:
|
||||
"""
|
||||
Upsert listings into the database.
|
||||
"""Upsert legacy Listing objects into the database.
|
||||
|
||||
.. deprecated::
|
||||
This method converts legacy data_access.Listing objects to SQLModel
|
||||
entities. Use upsert_listings() with RentListing/BuyListing directly.
|
||||
|
||||
Legacy Listing objects from filesystem-based data may contain malformed
|
||||
or incomplete data, so conversion errors are logged and skipped rather
|
||||
than aborting the entire batch.
|
||||
"""
|
||||
models = []
|
||||
failed_to_upsert = []
|
||||
with Session(self.engine) as session:
|
||||
for listing in tqdm(listings, desc="Upserting listings"):
|
||||
# Convert Listing to modelListing
|
||||
# Convert legacy Listing to the appropriate SQLModel entity
|
||||
try:
|
||||
model_listing = await self._get_concrete_listing(listing)
|
||||
except Exception as e: # WHY SO MANY ERORRS??
|
||||
# If for whatever reason we cannot add listing, ignore and retry
|
||||
print(f"Error converting listing {listing.identifier}: {e}")
|
||||
except Exception as e:
|
||||
# Legacy Listing -> model conversion may fail for malformed data
|
||||
# (e.g. missing required fields, invalid types). Log and skip.
|
||||
logger.error(f"Error converting listing {listing.identifier}: {e}")
|
||||
failed_to_upsert.append(listing)
|
||||
continue
|
||||
session.merge(model_listing)
|
||||
models.append(model_listing)
|
||||
session.commit()
|
||||
print(f"Failed to upsert {len(failed_to_upsert)} listings.")
|
||||
if failed_to_upsert:
|
||||
logger.warning(f"Failed to upsert {len(failed_to_upsert)} listings.")
|
||||
return models
|
||||
|
||||
@staticmethod
def _parse_furnish_type(listing: Listing) -> FurnishType:
    """Extract and normalize the furnish type from a legacy Listing's detail object.

    Returns ``FurnishType.UNKNOWN`` when the detail object is ``None``, when
    its ``"property"`` entry is missing/``None``, or when ``"letFurnishType"``
    is missing/``None``. Any value containing "landlord" (case-insensitive)
    is normalized to ``"ask landlord"``; every other value is lower-cased
    before being converted to the enum.

    Args:
        listing: A legacy data_access.Listing object.

    Returns:
        The parsed FurnishType enum value.

    Raises:
        ValueError: If the normalized string is not a valid FurnishType value
            (raised by the enum constructor).
    """
    detail = listing.detailobject
    if detail is None:
        return FurnishType.UNKNOWN

    property_info = detail.get("property")
    if property_info is None:
        return FurnishType.UNKNOWN

    raw_value = property_info.get("letFurnishType")
    if raw_value is None:
        return FurnishType.UNKNOWN

    # A single lower-casing serves both the "landlord" test and the
    # fallback normalization.
    normalized = raw_value.lower()
    if "landlord" in normalized:
        normalized = "ask landlord"

    return FurnishType(normalized)
|
||||
|
||||
async def _get_concrete_listing(
|
||||
self,
|
||||
listing: Listing,
|
||||
) -> modelListing:
|
||||
now = datetime.now()
|
||||
furnish_type = self._parse_furnish_type(listing)
|
||||
|
||||
if (
|
||||
listing.detailobject is None
|
||||
or listing.detailobject.get("property") is None
|
||||
or listing.detailobject["property"].get("letFurnishType") is None
|
||||
):
|
||||
furnish_type_str = "unknown"
|
||||
else:
|
||||
furnish_type_str = listing.detailobject["property"]["letFurnishType"]
|
||||
if furnish_type_str is None:
|
||||
furnish_type_str = "unknown"
|
||||
elif "landlord" in furnish_type_str.lower():
|
||||
furnish_type_str = "ask landlord"
|
||||
else:
|
||||
furnish_type_str = furnish_type_str.lower()
|
||||
furnish_type = FurnishType(furnish_type_str)
|
||||
|
||||
if listing.price < self.buy_listing_price_threshold:
|
||||
if listing.price < self.BUY_LISTING_PRICE_THRESHOLD:
|
||||
model_listing = RentListing(
|
||||
id=listing.identifier,
|
||||
price=listing.price,
|
||||
|
|
|
|||
|
|
@ -24,15 +24,14 @@ def get_district_names() -> list[str]:
|
|||
return list(_get_districts().keys())
|
||||
|
||||
|
||||
def validate_districts(district_names: list[str]) -> tuple[bool, list[str]]:
|
||||
def validate_districts(district_names: list[str]) -> list[str]:
|
||||
"""Validate that district names exist.
|
||||
|
||||
Args:
|
||||
district_names: List of district names to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (all_valid, invalid_names)
|
||||
List of invalid district names (empty if all valid)
|
||||
"""
|
||||
valid_districts = set(_get_districts().keys())
|
||||
invalid = [d for d in district_names if d not in valid_districts]
|
||||
return len(invalid) == 0, invalid
|
||||
return [d for d in district_names if d not in valid_districts]
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@ from repositories.listing_repository import ListingRepository
|
|||
from tqdm.asyncio import tqdm
|
||||
import multiprocessing
|
||||
|
||||
# Use a quarter of available CPUs to avoid starving other processes
|
||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||
|
||||
|
||||
async def detect_floorplan(repository: ListingRepository) -> None:
|
||||
"""Detect square meters from floorplan images for all listings."""
|
||||
listings = await repository.get_listings()
|
||||
cpu_count = multiprocessing.cpu_count() // 4
|
||||
semaphore = asyncio.Semaphore(cpu_count)
|
||||
semaphore = asyncio.Semaphore(MAX_OCR_WORKERS)
|
||||
|
||||
updated_listings = [
|
||||
listing
|
||||
|
|
@ -29,6 +31,9 @@ async def _calculate_sqm_ocr(
|
|||
"""Calculate square meters from floorplan images using OCR."""
|
||||
if listing.square_meters is not None:
|
||||
return None
|
||||
if not listing.floorplan_image_paths:
|
||||
listing.square_meters = 0
|
||||
return listing
|
||||
sqms: list[float] = []
|
||||
for floorplan_path in listing.floorplan_image_paths:
|
||||
async with semaphore:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
"""Image fetcher service - downloads floorplan images for listings."""
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from repositories import ListingRepository
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
|
|
@ -8,8 +11,12 @@ from tqdm.asyncio import tqdm
|
|||
|
||||
from models import Listing
|
||||
|
||||
# Setting this too high either crashes rightmove or gets us blocked
|
||||
semaphore = asyncio.Semaphore(5)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum number of concurrent image downloads.
|
||||
# Setting this too high either crashes Rightmove or gets us blocked.
|
||||
MAX_CONCURRENT_DOWNLOADS = 5
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)
|
||||
|
||||
|
||||
async def dump_images(
|
||||
|
|
@ -18,38 +25,64 @@ async def dump_images(
|
|||
) -> None:
|
||||
"""Download floorplan images for all listings."""
|
||||
listings = await repository.get_listings()
|
||||
updated_listings = await tqdm.gather(
|
||||
*[dump_images_for_listing(listing, image_base_path) for listing in listings]
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
updated_listings = await tqdm.gather(
|
||||
*[
|
||||
dump_images_for_listing(listing, image_base_path, session=session)
|
||||
for listing in listings
|
||||
]
|
||||
)
|
||||
await repository.upsert_listings(
|
||||
[listing for listing in updated_listings if listing is not None]
|
||||
)
|
||||
|
||||
|
||||
@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3))
|
||||
async def dump_images_for_listing(listing: Listing, base_path: Path) -> Listing | None:
|
||||
async def dump_images_for_listing(
|
||||
listing: Listing,
|
||||
base_path: Path,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> Listing | None:
|
||||
"""Download floorplan images for a single listing."""
|
||||
all_floorplans = listing.additional_info.get("property", {}).get("floorplans", [])
|
||||
for floorplan in all_floorplans:
|
||||
url = floorplan["url"]
|
||||
picname = url.split("/")[-1]
|
||||
picname = Path(urlparse(url).path).name
|
||||
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
|
||||
if floorplan_path.exists():
|
||||
continue
|
||||
try:
|
||||
async with semaphore:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
owns_session = session is None
|
||||
active_session = session or aiohttp.ClientSession()
|
||||
try:
|
||||
async with semaphore:
|
||||
async with active_session.get(url) as response:
|
||||
if response.status == 404:
|
||||
logger.warning(
|
||||
"Listing %s: floorplan not found (404) at %s",
|
||||
listing.id,
|
||||
url,
|
||||
)
|
||||
return None
|
||||
if response.status != 200:
|
||||
raise Exception(f"Error for {url}: {response.status}")
|
||||
raise Exception(
|
||||
f"Error downloading floorplan for listing {listing.id} "
|
||||
f"from {url}: HTTP {response.status}"
|
||||
)
|
||||
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(floorplan_path, "wb") as f:
|
||||
f.write(await response.read())
|
||||
listing.floorplan_image_paths.append(str(floorplan_path))
|
||||
return listing
|
||||
finally:
|
||||
if owns_session:
|
||||
await active_session.close()
|
||||
except Exception as e:
|
||||
tqdm.write(f"Error for {url}: {e}")
|
||||
raise e # raise so that we retry it
|
||||
logger.error(
|
||||
"Listing %s: error downloading floorplan from %s: %s",
|
||||
listing.id,
|
||||
url,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from typing import Generator
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import redis
|
||||
|
||||
from models.listing import QueryParameters
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_PREFIX = "listings:geojson:"
|
||||
CACHE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
|
@ -19,9 +20,9 @@ CACHE_DB = 2
|
|||
def _get_redis_client() -> redis.Redis:
|
||||
"""Get Redis client using Celery broker URL but overriding to db=2."""
|
||||
broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
# Replace the db number in the URL
|
||||
base_url = broker_url.rsplit("/", 1)[0]
|
||||
return redis.from_url(f"{base_url}/{CACHE_DB}", decode_responses=True)
|
||||
parsed = urlparse(broker_url)
|
||||
cache_url = urlunparse(parsed._replace(path=f"/{CACHE_DB}"))
|
||||
return redis.from_url(cache_url, decode_responses=True)
|
||||
|
||||
|
||||
def make_cache_key(query_params: QueryParameters) -> str:
|
||||
|
|
@ -89,7 +90,10 @@ def invalidate_cache() -> None:
|
|||
while True:
|
||||
cursor, keys = client.scan(cursor, match=f"{CACHE_PREFIX}*", count=100)
|
||||
if keys:
|
||||
client.delete(*keys)
|
||||
pipeline = client.pipeline()
|
||||
for key in keys:
|
||||
pipeline.delete(key)
|
||||
pipeline.execute()
|
||||
deleted += len(keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from services.query_splitter import QuerySplitter, SubQuery
|
|||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
# Number of concurrent workers that process listing details (fetch details,
|
||||
# download images, run OCR) from the streaming queue in parallel.
|
||||
NUM_WORKERS = 20
|
||||
|
||||
|
||||
|
|
@ -23,10 +25,104 @@ async def dump_listings_full(
|
|||
"""Fetches all listings, images as well as detects floorplans."""
|
||||
new_listings = await dump_listings(parameters, repository)
|
||||
logger.debug(f"Upserted {len(new_listings)} new listings")
|
||||
# refresh listings
|
||||
listings = await repository.get_listings(parameters) # this can be better
|
||||
new_listings = [x for x in listings if x.id in new_listings]
|
||||
return new_listings
|
||||
new_listing_ids = [listing.id for listing in new_listings]
|
||||
return await repository.get_listings(only_ids=new_listing_ids)
|
||||
|
||||
|
||||
async def _fetch_subquery(
    sq: SubQuery,
    parameters: QueryParameters,
    session: object,
    config: ScraperConfig,
    semaphore: asyncio.Semaphore,
    existing_ids: set[int],
    queue: asyncio.Queue[int | None],
) -> int:
    """Page through one subquery and push unseen listing IDs onto the queue.

    Pages are requested one at a time, starting at page 1, while holding
    ``semaphore``; each request is preceded by a sleep of
    ``config.request_delay_ms`` milliseconds. Paging stops at the first short
    page (fewer than ``page_size`` results), at the computed page cap, or on
    the first error of any kind (errors are logged, never re-raised).

    Args:
        sq: The subquery to fetch results for (bedroom/price bounds, district).
        parameters: The original query parameters (page size, radius, etc.).
        session: Session object forwarded to ``listing_query``.
        config: Scraper configuration (page cap, request delay).
        semaphore: Concurrency limiter for HTTP requests.
        existing_ids: Already-known listing IDs; newly seen IDs are added to
            this set in place.
        queue: Queue receiving each newly discovered listing ID.

    Returns:
        The number of new IDs discovered and enqueued.
    """
    expected_results = sq.estimated_results or 0
    if expected_results == 0:
        return 0

    per_page = parameters.page_size
    # One extra page beyond the estimate covers a partial trailing page.
    page_cap = min(config.max_pages_per_query, (expected_results // per_page) + 1)

    enqueued = 0
    for page_id in range(1, page_cap + 1):
        async with semaphore:
            await asyncio.sleep(config.request_delay_ms / 1000)
            try:
                response = await listing_query(
                    page=page_id,
                    channel=parameters.listing_type,
                    min_bedrooms=sq.min_bedrooms,
                    max_bedrooms=sq.max_bedrooms,
                    radius=parameters.radius,
                    min_price=sq.min_price,
                    max_price=sq.max_price,
                    district=sq.district,
                    page_size=per_page,
                    max_days_since_added=parameters.max_days_since_added,
                    furnish_types=parameters.furnish_types or [],
                    session=session,
                    config=config,
                )

                # Enqueue every identifier not seen before.
                page_items = response.get("properties", [])
                for item in page_items:
                    listing_id = item.get("identifier")
                    if not listing_id or listing_id in existing_ids:
                        continue
                    existing_ids.add(listing_id)
                    enqueued += 1
                    await queue.put(listing_id)

                # A short page means the subquery has no further results.
                if len(page_items) < per_page:
                    break

            except CircuitBreakerOpenError as e:
                logger.error(f"Circuit breaker open: {e}")
                break
            except ThrottlingError as e:
                logger.warning(
                    f"Throttling error on page {page_id} for "
                    f"{sq.district}: {e}"
                )
                break
            except Exception as e:
                # Rightmove returns GENERIC_ERROR when requesting pages
                # past the last page of results. This is expected behavior
                # and signals we've exhausted this subquery's results.
                if "GENERIC_ERROR" in str(e):
                    logger.debug(
                        f"Max page for {sq.district}: {page_id - 1}"
                    )
                else:
                    logger.warning(
                        f"Error fetching page {page_id} for "
                        f"{sq.district}: {e}"
                    )
                break

    return enqueued
|
||||
|
||||
|
||||
async def dump_listings(
|
||||
|
|
@ -63,82 +159,23 @@ async def dump_listings(
|
|||
# Phase 2: Streaming fetch & process
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
|
||||
ids_collected = 0
|
||||
processed_listings: list[Listing] = []
|
||||
|
||||
async def fetch_subquery(sq: SubQuery) -> None:
|
||||
nonlocal ids_collected
|
||||
|
||||
estimated = sq.estimated_results or 0
|
||||
if estimated == 0:
|
||||
return
|
||||
|
||||
page_size = parameters.page_size
|
||||
max_pages = min(
|
||||
config.max_pages_per_query,
|
||||
(estimated // page_size) + 1,
|
||||
)
|
||||
|
||||
for page_id in range(1, max_pages + 1):
|
||||
async with semaphore:
|
||||
await asyncio.sleep(config.request_delay_ms / 1000)
|
||||
try:
|
||||
result = await listing_query(
|
||||
page=page_id,
|
||||
channel=parameters.listing_type,
|
||||
min_bedrooms=sq.min_bedrooms,
|
||||
max_bedrooms=sq.max_bedrooms,
|
||||
radius=parameters.radius,
|
||||
min_price=sq.min_price,
|
||||
max_price=sq.max_price,
|
||||
district=sq.district,
|
||||
page_size=page_size,
|
||||
max_days_since_added=parameters.max_days_since_added,
|
||||
furnish_types=parameters.furnish_types or [],
|
||||
session=session,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Extract and enqueue new IDs inline
|
||||
properties = result.get("properties", [])
|
||||
for prop in properties:
|
||||
identifier = prop.get("identifier")
|
||||
if identifier and identifier not in existing_ids:
|
||||
existing_ids.add(identifier)
|
||||
ids_collected += 1
|
||||
await queue.put(identifier)
|
||||
|
||||
if len(properties) < page_size:
|
||||
break
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
logger.error(f"Circuit breaker open: {e}")
|
||||
break
|
||||
except ThrottlingError as e:
|
||||
logger.warning(
|
||||
f"Throttling error on page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
if "GENERIC_ERROR" in str(e):
|
||||
logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
logger.warning(
|
||||
f"Error fetching page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
)
|
||||
break
|
||||
|
||||
async def producer() -> None:
|
||||
await asyncio.gather(
|
||||
*[fetch_subquery(sq) for sq in subqueries]
|
||||
)
|
||||
async def producer() -> int:
|
||||
"""Fetch all subqueries and send sentinel values to workers."""
|
||||
tasks = [
|
||||
_fetch_subquery(
|
||||
sq, parameters, session, config,
|
||||
semaphore, existing_ids, queue,
|
||||
)
|
||||
for sq in subqueries
|
||||
]
|
||||
counts = await asyncio.gather(*tasks)
|
||||
ids_collected = sum(counts)
|
||||
logger.info(f"Fetch complete: {ids_collected} new IDs found")
|
||||
for _ in range(NUM_WORKERS):
|
||||
await queue.put(None)
|
||||
return ids_collected
|
||||
|
||||
async def worker() -> None:
|
||||
while True:
|
||||
|
|
@ -150,10 +187,11 @@ async def dump_listings(
|
|||
if listing is not None:
|
||||
processed_listings.append(listing)
|
||||
|
||||
await asyncio.gather(
|
||||
results = await asyncio.gather(
|
||||
producer(),
|
||||
*[worker() for _ in range(NUM_WORKERS)],
|
||||
)
|
||||
ids_collected = results[0]
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
logger.error(f"Circuit breaker prevented listing fetch: {e}")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,11 @@ from rec import routing
|
|||
from models import Listing
|
||||
|
||||
|
||||
def _parse_duration(duration_str: str) -> int:
|
||||
"""Parse a duration string like '123s' to integer seconds."""
|
||||
return int(duration_str.rstrip("s"))
|
||||
|
||||
|
||||
async def calculate_route(
|
||||
repository: ListingRepository,
|
||||
destination_address: str,
|
||||
|
|
@ -18,9 +23,9 @@ async def calculate_route(
|
|||
if limit is not None:
|
||||
listings = listings[:limit]
|
||||
|
||||
destimation_mode = DestinationMode(destination_address, travel_mode)
|
||||
destination_mode = DestinationMode(destination_address, travel_mode)
|
||||
updated_listings = await tqdm.gather(
|
||||
*[update_routing_info(listing, destimation_mode) for listing in listings],
|
||||
*[update_routing_info(listing, destination_mode) for listing in listings],
|
||||
total=len(listings),
|
||||
desc="Updating routing info",
|
||||
)
|
||||
|
|
@ -46,12 +51,12 @@ async def update_routing_info(
|
|||
|
||||
routes: list[Route] = []
|
||||
for route_data in routes_data["routes"]:
|
||||
duration_s = int(route_data["duration"].split("s")[0])
|
||||
duration_s = _parse_duration(route_data["duration"])
|
||||
route = Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=step_data["distanceMeters"],
|
||||
duration_s=int(step_data["staticDuration"].split("s")[0]),
|
||||
duration_s=_parse_duration(step_data["staticDuration"]),
|
||||
travel_mode=routing.TravelMode(step_data["travelMode"]),
|
||||
)
|
||||
for step_data in route_data["legs"][0]["steps"]
|
||||
|
|
@ -63,4 +68,4 @@ async def update_routing_info(
|
|||
listing.routing_info_json = listing.serialize_routing_info(
|
||||
{**listing.routing_info, **{destination_mode: routes}}
|
||||
)
|
||||
return listing
|
||||
return listing
|
||||
|
|
|
|||
|
|
@ -5,6 +5,16 @@ Manages background task operations using Celery.
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Standard Celery states; anything else is treated as a custom state
|
||||
# whose name is used as the human-readable status message.
|
||||
_CELERY_STANDARD_STATES = frozenset(
|
||||
{"PENDING", "STARTED", "SUCCESS", "FAILURE", "REVOKED", "RETRY"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -21,6 +31,68 @@ class TaskStatus:
|
|||
traceback: str | None # Full traceback if failed
|
||||
|
||||
|
||||
def _make_system_user(email: str) -> Any:
    """Build a placeholder User whose only meaningful field is the email.

    The returned object is *not* an authenticated user: ``sub`` and ``name``
    are empty strings. It exists solely so that RedisRepository can derive
    the per-user storage key from the email.

    Args:
        email: Email address to place on the placeholder user.

    Returns:
        A ``User`` instance with empty ``sub`` and ``name``.
    """
    # Deferred import: api.auth imports from api.app which eventually imports
    # services, so importing at module level would create a circular dependency.
    from api.auth import User

    placeholder_user = User(sub="", email=email, name="")
    return placeholder_user
|
||||
|
||||
|
||||
def _extract_result(task_result: Any) -> tuple[Any, str | None]:
|
||||
"""Extract a serialisable result and an error string from a Celery AsyncResult.
|
||||
|
||||
Returns:
|
||||
(result, error) -- exactly one of the two will be non-None (or both None
|
||||
for tasks that haven't produced output yet).
|
||||
"""
|
||||
if task_result.failed():
|
||||
error = str(task_result.result) if task_result.result else None
|
||||
return None, error
|
||||
|
||||
try:
|
||||
result = json.loads(json.dumps(task_result.result))
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
result = str(task_result.result) if task_result.result else None
|
||||
return result, None
|
||||
|
||||
|
||||
def _extract_progress_info(task_result: Any) -> dict[str, Any]:
|
||||
"""Extract progress metadata from a Celery AsyncResult's ``info`` dict.
|
||||
|
||||
Returns a dict with keys ``progress``, ``processed``, ``total``, and
|
||||
``message`` (any of which may be None).
|
||||
"""
|
||||
progress: float | None = None
|
||||
processed: int | None = None
|
||||
total: int | None = None
|
||||
message: str | None = None
|
||||
|
||||
if task_result.info and isinstance(task_result.info, dict):
|
||||
progress = task_result.info.get("progress")
|
||||
processed = task_result.info.get("processed")
|
||||
total = task_result.info.get("total")
|
||||
# Use 'message' if available, fall back to 'reason' for SKIPPED tasks
|
||||
message = task_result.info.get("message") or task_result.info.get("reason")
|
||||
|
||||
# For custom states (like "Fetching listings"), use the state as message
|
||||
# if no message was provided in info
|
||||
if not message and task_result.status not in _CELERY_STANDARD_STATES:
|
||||
message = task_result.status
|
||||
|
||||
return {
|
||||
"progress": progress,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def get_task_status(task_id: str) -> TaskStatus:
|
||||
"""Get the status of a background task.
|
||||
|
||||
|
|
@ -33,55 +105,24 @@ def get_task_status(task_id: str) -> TaskStatus:
|
|||
Returns:
|
||||
TaskStatus with current state
|
||||
"""
|
||||
# Lazy import: listing_tasks imports the Celery app which in turn
|
||||
# pulls in broker configuration; importing at module level would
|
||||
# create a circular dependency chain.
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
|
||||
task_result = dump_listings_task.AsyncResult(task_id)
|
||||
|
||||
# Try to serialize result
|
||||
result = None
|
||||
error = None
|
||||
if task_result.failed():
|
||||
# Extract error message from failed task
|
||||
error = str(task_result.result) if task_result.result else None
|
||||
else:
|
||||
try:
|
||||
result = json.loads(json.dumps(task_result.result))
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
result = str(task_result.result) if task_result.result else None
|
||||
|
||||
# Extract traceback if available
|
||||
result, error = _extract_result(task_result)
|
||||
task_traceback = task_result.traceback if task_result.failed() else None
|
||||
|
||||
# Extract progress, processed, total, and message from task meta
|
||||
progress = None
|
||||
processed = None
|
||||
total = None
|
||||
message = None
|
||||
|
||||
if task_result.info and isinstance(task_result.info, dict):
|
||||
progress = task_result.info.get("progress")
|
||||
processed = task_result.info.get("processed")
|
||||
total = task_result.info.get("total")
|
||||
# Use 'message' if available, fall back to 'reason' for SKIPPED tasks
|
||||
message = task_result.info.get("message") or task_result.info.get("reason")
|
||||
|
||||
# For custom states (like "Fetching listings"), use the state as message
|
||||
# if no message was provided in info
|
||||
if not message and task_result.status not in (
|
||||
"PENDING", "STARTED", "SUCCESS", "FAILURE", "REVOKED", "RETRY"
|
||||
):
|
||||
message = task_result.status
|
||||
progress_info = _extract_progress_info(task_result)
|
||||
|
||||
return TaskStatus(
|
||||
task_id=task_id,
|
||||
status=task_result.status,
|
||||
result=result,
|
||||
progress=progress,
|
||||
processed=processed,
|
||||
total=total,
|
||||
message=message,
|
||||
error=error,
|
||||
traceback=task_traceback,
|
||||
**progress_info,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -97,12 +138,12 @@ def get_user_tasks(user_email: str) -> list[str]:
|
|||
Returns:
|
||||
List of task IDs
|
||||
"""
|
||||
# Lazy import: RedisRepository depends on redis which may not be
|
||||
# available at import time in all contexts (CLI, tests).
|
||||
from redis_repository import RedisRepository
|
||||
from api.auth import User
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
# Create a minimal User object for the lookup
|
||||
user = User(sub="", email=user_email, name="")
|
||||
user = _make_system_user(user_email)
|
||||
return redis_repo.get_tasks_for_user(user)
|
||||
|
||||
|
||||
|
|
@ -116,11 +157,11 @@ def add_task_for_user(user_email: str, task_id: str) -> None:
|
|||
user_email: The user's email address
|
||||
task_id: The Celery task ID
|
||||
"""
|
||||
# Lazy import: see get_user_tasks for rationale.
|
||||
from redis_repository import RedisRepository
|
||||
from api.auth import User
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = User(sub="", email=user_email, name="")
|
||||
user = _make_system_user(user_email)
|
||||
redis_repo.add_task_for_user(user, task_id)
|
||||
|
||||
|
||||
|
|
@ -134,8 +175,10 @@ def cancel_task(task_id: str, user_email: str | None = None) -> bool:
|
|||
Returns:
|
||||
True if task was cancelled successfully
|
||||
"""
|
||||
# Lazy import: celery_app bootstraps the broker connection.
|
||||
from celery_app import app as celery_app
|
||||
|
||||
logger.info("Cancelling task %s (user=%s)", task_id, user_email)
|
||||
# Revoke the task in Celery
|
||||
celery_app.control.revoke(task_id, terminate=True)
|
||||
|
||||
|
|
@ -158,11 +201,11 @@ def remove_task_from_user(user_email: str, task_id: str) -> bool:
|
|||
Returns:
|
||||
True if task was removed, False if not found
|
||||
"""
|
||||
# Lazy import: see get_user_tasks for rationale.
|
||||
from redis_repository import RedisRepository
|
||||
from api.auth import User
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = User(sub="", email=user_email, name="")
|
||||
user = _make_system_user(user_email)
|
||||
return redis_repo.remove_task_for_user(user, task_id)
|
||||
|
||||
|
||||
|
|
@ -176,12 +219,14 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
|
|||
Returns:
|
||||
Number of tasks cleared
|
||||
"""
|
||||
# Lazy imports: see get_user_tasks and cancel_task for rationale.
|
||||
from redis_repository import RedisRepository
|
||||
from celery_app import app as celery_app
|
||||
from api.auth import User
|
||||
|
||||
redis_repo = RedisRepository.instance()
|
||||
user = User(sub="", email=user_email, name="")
|
||||
user = _make_system_user(user_email)
|
||||
|
||||
logger.info("Clearing all tasks for user %s (revoke=%s)", user_email, revoke)
|
||||
|
||||
# Get tasks before clearing to revoke them
|
||||
if revoke:
|
||||
|
|
@ -189,7 +234,9 @@ def clear_all_tasks(user_email: str, revoke: bool = True) -> int:
|
|||
for task_id in tasks:
|
||||
try:
|
||||
celery_app.control.revoke(task_id, terminate=True)
|
||||
except Exception:
|
||||
pass # Best effort, continue clearing
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to revoke task %s: %s", task_id, e
|
||||
)
|
||||
|
||||
return redis_repo.clear_tasks_for_user(user)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from celery import Task
|
||||
from celery.schedules import crontab
|
||||
|
|
@ -34,11 +35,38 @@ if not celery_logger.handlers:
|
|||
SCRAPE_LOCK_NAME = "scrape_listings"
|
||||
LOG_BUFFER_MAX_LINES = 200
|
||||
|
||||
# Number of concurrent consumer workers that process listings from the queue.
|
||||
NUM_WORKERS = 20
|
||||
|
||||
# Phase constants for task state reporting
|
||||
PHASE_SPLITTING = "splitting"
|
||||
PHASE_FETCHING = "fetching"
|
||||
PHASE_PROCESSING = "processing"
|
||||
PHASE_COMPLETED = "completed"
|
||||
|
||||
# Module-level log buffer — active only during task execution.
|
||||
# The TaskLogHandler appends here; _update_task_state reads from here.
|
||||
# This is safe as module-level mutable state because Celery workers use a
|
||||
# prefork pool: each worker process handles one task at a time, so there is
|
||||
# no concurrent access within a single process. The TaskLogHandler appends
|
||||
# here; _update_task_state reads from here.
|
||||
_active_log_buffer: deque[str] | None = None
|
||||
|
||||
|
||||
@dataclass
class _PipelineState:
    """Counters and results shared by the producer, workers and monitor.

    One instance is created per pipeline run and mutated in place by the
    coroutines that receive it, replacing the former ``nonlocal`` variables.
    Field order is kept stable so positional construction keeps working.
    """

    # --- Fetch side (written while subquery pages are being fetched) ---
    # Number of new listing IDs pushed onto the work queue.
    ids_collected: int = 0
    # Number of subqueries that have finished (including empty ones).
    completed_subqueries: int = 0
    # Number of result pages successfully fetched across all subqueries.
    total_pages_fetched: int = 0
    # Set to True once every subquery has been fetched.
    fetching_done: bool = False

    # --- Processing side (written by the consumer workers) ---
    # Listings for which processing returned a result.
    processed_count: int = 0
    # Listings for which processing returned None.
    failed_count: int = 0
    # Per-step completion counters, bumped by the step callback.
    details_fetched: int = 0
    images_downloaded: int = 0
    ocr_completed: int = 0
    # Successfully processed listings, in completion order.
    processed_listings: list[Listing] = field(default_factory=list)
|
||||
|
||||
|
||||
class TaskLogHandler(logging.Handler):
|
||||
"""Captures log records into a deque for inclusion in task state updates."""
|
||||
|
||||
|
|
@ -60,34 +88,204 @@ def _update_task_state(task: Task, state: str, meta: dict[str, Any]) -> None:
|
|||
task.update_state(state=state, meta=meta)
|
||||
|
||||
|
||||
async def _fetch_subquery(
    sq: SubQuery,
    parameters: QueryParameters,
    session: object,
    config: ScraperConfig,
    semaphore: asyncio.Semaphore,
    existing_ids: set[int],
    queue: asyncio.Queue[int | None],
    state: _PipelineState,
) -> None:
    """Fetch the result pages of one subquery and enqueue unseen listing IDs.

    Pages are requested one after another, starting at page 1 and going up to
    ``min(config.max_pages_per_query, estimated // page_size + 1)``. Each
    request is made while holding ``semaphore`` and is preceded by a sleep of
    ``config.request_delay_ms`` milliseconds. Paging stops early when a page
    comes back with fewer than ``page_size`` entries, or when a request fails
    (the failure is logged, not raised).

    Args:
        sq: The subquery (district, price and bedroom bounds, estimate).
        parameters: The user-level query parameters shared by all subqueries.
        session: HTTP session handed through to ``listing_query``.
        config: Scraper limits (page cap, request delay).
        semaphore: Bounds the number of concurrent page requests.
        existing_ids: IDs that must not be enqueued again; mutated in place
            so that IDs found here are also skipped by other subqueries.
        queue: Work queue receiving each new listing ID.
        state: Shared pipeline counters updated as pages and IDs arrive.
    """
    expected_results = sq.estimated_results or 0
    if not expected_results:
        # Nothing to fetch, but the subquery still counts as finished.
        state.completed_subqueries += 1
        return

    page_size = parameters.page_size
    # "+ 1" covers the final partial page of the estimate.
    page_limit = min(
        config.max_pages_per_query,
        (expected_results // page_size) + 1,
    )

    for page_id in range(1, page_limit + 1):
        async with semaphore:
            await asyncio.sleep(config.request_delay_ms / 1000)
            try:
                response = await listing_query(
                    page=page_id,
                    channel=parameters.listing_type,
                    min_bedrooms=sq.min_bedrooms,
                    max_bedrooms=sq.max_bedrooms,
                    radius=parameters.radius,
                    min_price=sq.min_price,
                    max_price=sq.max_price,
                    district=sq.district,
                    page_size=page_size,
                    max_days_since_added=parameters.max_days_since_added,
                    furnish_types=parameters.furnish_types or [],
                    session=session,
                    config=config,
                )
                state.total_pages_fetched += 1

                entries = response.get("properties", [])
                for entry in entries:
                    listing_id = entry.get("identifier")
                    if not listing_id or listing_id in existing_ids:
                        continue
                    # Record the ID before enqueueing so concurrent
                    # subqueries do not enqueue it a second time.
                    existing_ids.add(listing_id)
                    state.ids_collected += 1
                    await queue.put(listing_id)

                # A short page means this was the last one.
                if len(entries) < page_size:
                    break

            except CircuitBreakerOpenError as e:
                celery_logger.error(f"Circuit breaker open: {e}")
                break
            except ThrottlingError as e:
                celery_logger.warning(
                    f"Throttling on {sq.district} page {page_id}: {e}"
                )
                break
            except Exception as e:
                if "GENERIC_ERROR" in str(e):
                    # Treated as "requested past the last page", hence debug level.
                    celery_logger.debug(
                        f"Max page for {sq.district}: {page_id - 1}"
                    )
                else:
                    celery_logger.warning(
                        f"Error fetching page {page_id} for "
                        f"{sq.district}: {e}"
                    )
                break

    state.completed_subqueries += 1
|
||||
|
||||
|
||||
async def _process_worker(
    queue: asyncio.Queue[int | None],
    processor: ListingProcessor,
    state: _PipelineState,
) -> None:
    """Consume listing IDs from ``queue`` until a ``None`` sentinel arrives.

    Every ID is handed to ``processor.process_listing``. A returned listing
    is counted as processed and appended to ``state.processed_listings``; a
    ``None`` result is counted as failed. Step notifications from the
    processor increment the matching per-step counter on ``state``.

    Args:
        queue: Source of listing IDs; ``None`` tells the worker to stop.
        processor: Object exposing the async ``process_listing`` method.
        state: Shared pipeline counters and result list, mutated in place.
    """
    # Step name reported by the processor -> counter attribute on the state.
    counter_for_step = {
        "details": "details_fetched",
        "images": "images_downloaded",
        "ocr": "ocr_completed",
    }

    def step_callback(step_name: str) -> None:
        counter = counter_for_step.get(step_name)
        if counter is not None:  # unknown step names are ignored
            setattr(state, counter, getattr(state, counter) + 1)

    while (listing_id := await queue.get()) is not None:
        listing = await processor.process_listing(
            listing_id, on_step_complete=step_callback
        )
        if listing is None:
            state.failed_count += 1
            continue
        state.processed_count += 1
        state.processed_listings.append(listing)
|
||||
|
||||
|
||||
async def _monitor_progress(
    task: Task,
    state: _PipelineState,
    subqueries_total: int,
    start_time: float,
) -> None:
    """Publish pipeline progress roughly once per second until the run ends.

    Each iteration derives progress, rate and ETA from ``state`` and pushes
    them through ``_update_task_state``. A log line is also written when the
    progress ratio has advanced by at least 0.1 since the last logged value,
    or while exactly one item is done.

    The coroutine returns once fetching has finished and either no IDs were
    collected or every collected ID has been processed or has failed.

    Args:
        task: Task whose state is updated.
        state: Shared pipeline counters (read only here).
        subqueries_total: Total number of subqueries, reported as-is.
        start_time: ``time.time()`` value taken when the pipeline started.
    """
    last_progress = 0.0

    while True:
        total = state.ids_collected
        done = state.processed_count + state.failed_count
        fetching_done = state.fetching_done

        # Exit: nothing was collected, or everything collected is accounted for.
        if fetching_done and (total == 0 or (total > 0 and done >= total)):
            return

        phase = PHASE_PROCESSING if fetching_done else PHASE_FETCHING
        progress_ratio = round(done / total, 2) if total > 0 else 0.0

        elapsed = time.time() - start_time
        rate = done / elapsed if elapsed > 0 else 0
        remaining = (total - done) if total > 0 else 0
        eta = remaining / rate if rate > 0 else 0

        if progress_ratio >= last_progress + 0.1 or done == 1:
            celery_logger.info(
                f"Progress: {progress_ratio * 100:.0f}% "
                f"({done}/{total}) "
                f"| Elapsed: {elapsed:.0f}s "
                f"| Rate: {rate:.1f}/s "
                f"| ETA: {eta:.0f}s"
            )
            last_progress = progress_ratio

        label = "Processing" if fetching_done else "Fetching & processing"
        meta = {
            "phase": phase,
            "progress": progress_ratio,
            "processed": done,
            "total": total,
            "subqueries_completed": state.completed_subqueries,
            "subqueries_total": subqueries_total,
            "ids_collected": state.ids_collected,
            "pages_fetched": state.total_pages_fetched,
            "fetching_done": fetching_done,
            "details_fetched": state.details_fetched,
            "images_downloaded": state.images_downloaded,
            "ocr_completed": state.ocr_completed,
            "failed": state.failed_count,
            "elapsed_seconds": round(elapsed, 1),
            "rate_per_second": round(rate, 2),
            "eta_seconds": round(eta, 1),
        }
        _update_task_state(task, f"{label}: {done}/{total}", meta)
        await asyncio.sleep(1)
|
||||
|
||||
|
||||
@app.task(bind=True, pydantic=True)
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
    """Run a full listing scrape as a task, guarded by the scrape lock.

    The scrape only runs if the ``SCRAPE_LOCK_NAME`` lock can be acquired;
    otherwise the task reports ``SKIPPED`` and returns immediately. The lock
    is held for the whole duration of the scrape.

    Args:
        parameters_json: JSON-serialised ``QueryParameters``.

    Returns:
        ``{"status": "skipped", "reason": "another_job_running"}`` when the
        lock is taken, otherwise ``{"phase": PHASE_COMPLETED, "progress": 1}``.
    """
    with redis_lock(SCRAPE_LOCK_NAME) as acquired:
        if not acquired:
            msg = "Another scrape job is already running, skipping this execution"
            logger.warning(msg)
            celery_logger.warning(msg)
            self.update_state(state="SKIPPED", meta={"reason": "Another scrape job is running"})
            return {"status": "skipped", "reason": "another_job_running"}

        celery_logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")
        logger.info(f"Acquired lock: {SCRAPE_LOCK_NAME}")

        parsed_parameters = QueryParameters.model_validate_json(parameters_json)
        celery_logger.info(f"Starting scrape with parameters: {parsed_parameters}")

        # Report the initial state once, using the shared phase constants
        # rather than repeating the string literals.
        self.update_state(state="Starting...", meta={"phase": PHASE_SPLITTING, "progress": 0})
        # The task body is synchronous, so drive the async pipeline to completion here.
        asyncio.run(dump_listings_full(task=self, parameters=parsed_parameters))
        return {"phase": PHASE_COMPLETED, "progress": 1}
|
||||
|
||||
|
||||
async def async_dump_listings_task(parameters_json: str) -> dict[str, Any]:
    """Run a full listing scrape from async code, guarded by the scrape lock.

    Async counterpart of ``dump_listings_task``: it awaits the pipeline
    directly and passes a fresh ``Task()`` instead of a bound task.

    Args:
        parameters_json: JSON-serialised ``QueryParameters``.

    Returns:
        ``{"status": "skipped", "reason": "another_job_running"}`` when the
        lock is already held, otherwise ``{"progress": 0}``.
    """
    with redis_lock(SCRAPE_LOCK_NAME) as acquired:
        if acquired:
            lock_message = f"Acquired lock: {SCRAPE_LOCK_NAME}"
            logger.info(lock_message)
            celery_logger.info(lock_message)

            query = QueryParameters.model_validate_json(parameters_json)
            await dump_listings_full(task=Task(), parameters=query)
            # NOTE(review): reports progress 0 even after the scrape finished,
            # unlike dump_listings_task; confirm this is intended.
            return {"progress": 0}

        skip_message = "Another scrape job is already running, skipping this execution"
        logger.warning(skip_message)
        celery_logger.warning(skip_message)
        return {"status": "skipped", "reason": "another_job_running"}
|
||||
|
|
@ -141,17 +339,16 @@ async def _dump_listings_full_inner(
|
|||
soon as IDs become available from each subquery.
|
||||
"""
|
||||
start_time = time.time()
|
||||
NUM_WORKERS = 20
|
||||
state = _PipelineState()
|
||||
|
||||
celery_logger.info("=" * 60)
|
||||
celery_logger.info("PHASE 1: Splitting queries")
|
||||
celery_logger.info(f"PHASE 1: Splitting queries")
|
||||
celery_logger.info("=" * 60)
|
||||
|
||||
repository = ListingRepository(engine)
|
||||
config = ScraperConfig.from_env()
|
||||
splitter = QuerySplitter(config)
|
||||
|
||||
# Reset throttle metrics
|
||||
reset_throttle_metrics()
|
||||
|
||||
def on_progress(phase: str, message: str, **kwargs: Any) -> None:
|
||||
|
|
@ -161,7 +358,7 @@ async def _dump_listings_full_inner(
|
|||
celery_logger.info(f"[{phase}] {message}")
|
||||
|
||||
_update_task_state(task, "Analyzing query and splitting by price bands...", {
|
||||
"phase": "splitting", "progress": 0,
|
||||
"phase": PHASE_SPLITTING, "progress": 0,
|
||||
})
|
||||
celery_logger.info("Starting query splitting and probing...")
|
||||
|
||||
|
|
@ -175,34 +372,22 @@ async def _dump_listings_full_inner(
|
|||
f"~{total_estimated} estimated total results"
|
||||
)
|
||||
|
||||
# Load existing IDs (fast, ID-only projection)
|
||||
celery_logger.info("Loading existing listing IDs from database...")
|
||||
existing_ids = repository.get_listing_ids(parameters.listing_type)
|
||||
celery_logger.info(f"Found {len(existing_ids)} existing listings in DB")
|
||||
|
||||
celery_logger.info("=" * 60)
|
||||
celery_logger.info("PHASE 2: Streaming fetch & process")
|
||||
celery_logger.info(f"PHASE 2: Streaming fetch & process")
|
||||
celery_logger.info("=" * 60)
|
||||
|
||||
# Shared state for the streaming pipeline
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
ids_collected = 0
|
||||
completed_subqueries = 0
|
||||
total_pages_fetched = 0
|
||||
fetching_done = False
|
||||
processed_count = 0
|
||||
failed_count = 0
|
||||
details_fetched = 0
|
||||
images_downloaded = 0
|
||||
ocr_completed = 0
|
||||
processed_listings: list[Listing] = []
|
||||
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
|
||||
|
||||
_update_task_state(
|
||||
task,
|
||||
f"Fetching listings from {len(subqueries)} subqueries...",
|
||||
{
|
||||
"phase": "fetching",
|
||||
"phase": PHASE_FETCHING,
|
||||
"subqueries_completed": 0,
|
||||
"subqueries_total": len(subqueries),
|
||||
"ids_collected": 0,
|
||||
|
|
@ -214,190 +399,32 @@ async def _dump_listings_full_inner(
|
|||
|
||||
listing_processor = ListingProcessor(repository)
|
||||
|
||||
# --- Producer: fetch subquery pages and enqueue new IDs ---
|
||||
# Producer: fetch all subqueries concurrently, then signal workers to stop
|
||||
async def producer() -> None:
|
||||
nonlocal ids_collected, completed_subqueries, total_pages_fetched
|
||||
nonlocal fetching_done
|
||||
|
||||
async def fetch_subquery(sq: SubQuery) -> None:
|
||||
nonlocal ids_collected, completed_subqueries, total_pages_fetched
|
||||
|
||||
estimated = sq.estimated_results or 0
|
||||
if estimated == 0:
|
||||
completed_subqueries += 1
|
||||
return
|
||||
|
||||
page_size = parameters.page_size
|
||||
max_pages = min(
|
||||
config.max_pages_per_query,
|
||||
(estimated // page_size) + 1,
|
||||
)
|
||||
|
||||
for page_id in range(1, max_pages + 1):
|
||||
async with semaphore:
|
||||
await asyncio.sleep(config.request_delay_ms / 1000)
|
||||
try:
|
||||
result = await listing_query(
|
||||
page=page_id,
|
||||
channel=parameters.listing_type,
|
||||
min_bedrooms=sq.min_bedrooms,
|
||||
max_bedrooms=sq.max_bedrooms,
|
||||
radius=parameters.radius,
|
||||
min_price=sq.min_price,
|
||||
max_price=sq.max_price,
|
||||
district=sq.district,
|
||||
page_size=page_size,
|
||||
max_days_since_added=parameters.max_days_since_added,
|
||||
furnish_types=parameters.furnish_types or [],
|
||||
session=session,
|
||||
config=config,
|
||||
)
|
||||
total_pages_fetched += 1
|
||||
|
||||
# Extract and enqueue new IDs inline
|
||||
properties = result.get("properties", [])
|
||||
for prop in properties:
|
||||
identifier = prop.get("identifier")
|
||||
if identifier and identifier not in existing_ids:
|
||||
existing_ids.add(identifier)
|
||||
ids_collected += 1
|
||||
await queue.put(identifier)
|
||||
|
||||
if len(properties) < page_size:
|
||||
break
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
celery_logger.error(f"Circuit breaker open: {e}")
|
||||
break
|
||||
except ThrottlingError as e:
|
||||
celery_logger.warning(
|
||||
f"Throttling on {sq.district} page {page_id}: {e}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
if "GENERIC_ERROR" in str(e):
|
||||
logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
logger.warning(
|
||||
f"Error fetching page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
)
|
||||
break
|
||||
|
||||
completed_subqueries += 1
|
||||
|
||||
# Fetch all subqueries concurrently
|
||||
await asyncio.gather(
|
||||
*[fetch_subquery(sq) for sq in subqueries]
|
||||
*[
|
||||
_fetch_subquery(
|
||||
sq, parameters, session, config,
|
||||
semaphore, existing_ids, queue, state,
|
||||
)
|
||||
for sq in subqueries
|
||||
]
|
||||
)
|
||||
|
||||
celery_logger.info(
|
||||
f"Fetch complete: {total_pages_fetched} pages from "
|
||||
f"{completed_subqueries} subqueries, "
|
||||
f"{ids_collected} new IDs"
|
||||
f"Fetch complete: {state.total_pages_fetched} pages from "
|
||||
f"{state.completed_subqueries} subqueries, "
|
||||
f"{state.ids_collected} new IDs"
|
||||
)
|
||||
fetching_done = True
|
||||
state.fetching_done = True
|
||||
|
||||
# Send sentinel values to stop workers
|
||||
for _ in range(NUM_WORKERS):
|
||||
await queue.put(None)
|
||||
|
||||
# --- Consumer workers: process listings from queue ---
|
||||
async def worker() -> None:
|
||||
nonlocal processed_count, failed_count
|
||||
nonlocal details_fetched, images_downloaded, ocr_completed
|
||||
|
||||
while True:
|
||||
listing_id = await queue.get()
|
||||
if listing_id is None:
|
||||
break
|
||||
|
||||
def step_callback(step_name: str) -> None:
|
||||
nonlocal details_fetched, images_downloaded, ocr_completed
|
||||
if step_name == "details":
|
||||
details_fetched += 1
|
||||
elif step_name == "images":
|
||||
images_downloaded += 1
|
||||
elif step_name == "ocr":
|
||||
ocr_completed += 1
|
||||
|
||||
listing = await listing_processor.process_listing(
|
||||
listing_id, on_step_complete=step_callback
|
||||
)
|
||||
if listing is not None:
|
||||
processed_count += 1
|
||||
processed_listings.append(listing)
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
# --- Monitor: reports combined progress ---
|
||||
async def monitor() -> None:
|
||||
last_progress = 0.0
|
||||
|
||||
while True:
|
||||
total = ids_collected
|
||||
done = processed_count + failed_count
|
||||
|
||||
if fetching_done and done >= total and total > 0:
|
||||
break
|
||||
if fetching_done and total == 0:
|
||||
break
|
||||
|
||||
# Determine phase label
|
||||
phase = "processing" if fetching_done else "fetching"
|
||||
|
||||
if total > 0:
|
||||
progress_ratio = round(done / total, 2)
|
||||
else:
|
||||
progress_ratio = 0.0
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
rate = done / elapsed if elapsed > 0 else 0
|
||||
remaining = (total - done) if total > 0 else 0
|
||||
eta = remaining / rate if rate > 0 else 0
|
||||
|
||||
if progress_ratio >= last_progress + 0.1 or done == 1:
|
||||
celery_logger.info(
|
||||
f"Progress: {progress_ratio * 100:.0f}% "
|
||||
f"({done}/{total}) "
|
||||
f"| Elapsed: {elapsed:.0f}s "
|
||||
f"| Rate: {rate:.1f}/s "
|
||||
f"| ETA: {eta:.0f}s"
|
||||
)
|
||||
last_progress = progress_ratio
|
||||
|
||||
_update_task_state(
|
||||
task,
|
||||
f"{'Processing' if fetching_done else 'Fetching & processing'}: "
|
||||
f"{done}/{total}",
|
||||
{
|
||||
"phase": phase,
|
||||
"progress": progress_ratio,
|
||||
"processed": done,
|
||||
"total": total,
|
||||
"subqueries_completed": completed_subqueries,
|
||||
"subqueries_total": len(subqueries),
|
||||
"ids_collected": ids_collected,
|
||||
"pages_fetched": total_pages_fetched,
|
||||
"fetching_done": fetching_done,
|
||||
"details_fetched": details_fetched,
|
||||
"images_downloaded": images_downloaded,
|
||||
"ocr_completed": ocr_completed,
|
||||
"failed": failed_count,
|
||||
"elapsed_seconds": round(elapsed, 1),
|
||||
"rate_per_second": round(rate, 2),
|
||||
"eta_seconds": round(eta, 1),
|
||||
},
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Run producer, workers, and monitor concurrently
|
||||
await asyncio.gather(
|
||||
producer(),
|
||||
*[worker() for _ in range(NUM_WORKERS)],
|
||||
monitor(),
|
||||
*[_process_worker(queue, listing_processor, state) for _ in range(NUM_WORKERS)],
|
||||
_monitor_progress(task, state, len(subqueries), start_time),
|
||||
)
|
||||
|
||||
except CircuitBreakerOpenError as e:
|
||||
|
|
@ -418,19 +445,19 @@ async def _dump_listings_full_inner(
|
|||
elapsed = time.time() - start_time
|
||||
celery_logger.info("=" * 60)
|
||||
celery_logger.info(
|
||||
f"COMPLETED: Processed {len(processed_listings)} listings in {elapsed:.1f}s"
|
||||
f"COMPLETED: Processed {len(state.processed_listings)} listings in {elapsed:.1f}s"
|
||||
)
|
||||
celery_logger.info("=" * 60)
|
||||
|
||||
invalidate_cache()
|
||||
|
||||
_update_task_state(task, "Completed", {
|
||||
"phase": "completed", "progress": 1,
|
||||
"processed": len(processed_listings), "total": ids_collected,
|
||||
"message": f"Processed {len(processed_listings)} listings in {elapsed:.0f}s",
|
||||
"phase": PHASE_COMPLETED, "progress": 1,
|
||||
"processed": len(state.processed_listings), "total": state.ids_collected,
|
||||
"message": f"Processed {len(state.processed_listings)} listings in {elapsed:.0f}s",
|
||||
})
|
||||
|
||||
return processed_listings
|
||||
return state.processed_listings
|
||||
|
||||
|
||||
@app.on_after_finalize.connect
|
||||
|
|
@ -439,11 +466,11 @@ def setup_periodic_tasks(sender, **kwargs):
|
|||
try:
|
||||
config = SchedulesConfig.from_env()
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to load schedule configuration: {e}")
|
||||
celery_logger.error(f"Failed to load schedule configuration: {e}")
|
||||
return
|
||||
|
||||
for schedule in config.get_enabled_schedules():
|
||||
logger.info(
|
||||
celery_logger.info(
|
||||
f"Registering periodic task: {schedule.name} at {schedule.hour}:{schedule.minute}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Integration tests for API endpoints."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
|
@ -75,10 +76,12 @@ class TestListingGeoJsonEndpoint:
|
|||
self, async_client: AsyncClient
|
||||
) -> None:
|
||||
"""Test that listing_geojson accepts filter parameters."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"type": "FeatureCollection", "features": []}
|
||||
with patch(
|
||||
"api.app.export_immoweb",
|
||||
"api.app.export_service.export_to_geojson",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"type": "FeatureCollection", "features": []},
|
||||
return_value=mock_result,
|
||||
):
|
||||
response = await async_client.get(
|
||||
"/api/listing_geojson",
|
||||
|
|
@ -178,3 +181,135 @@ class TestTaskStatusEndpoint:
|
|||
)
|
||||
# Should return 401 or 403 without valid auth
|
||||
assert response.status_code in (401, 403)
|
||||
|
||||
|
||||
class TestStreamListingGeoJsonEndpoint:
    """Tests for the /api/listing_geojson/stream endpoint."""

    STREAM_URL = "/api/listing_geojson/stream"

    @staticmethod
    def _point_feature(feature_id: int, lon: int = 0, lat: int = 0) -> dict:
        """Build a minimal GeoJSON point feature with the given id."""
        return {
            "type": "Feature",
            "properties": {"id": feature_id},
            "geometry": {"type": "Point", "coordinates": [lon, lat]},
        }

    @staticmethod
    def _ndjson_lines(response) -> list[str]:
        """Return the non-empty lines of an NDJSON response body."""
        return [line for line in response.text.strip().split("\n") if line]

    async def _get_from_cache(self, client: AsyncClient, features: list, params: dict):
        """Request the stream with the cache patched to hold ``features`` as one batch."""
        with patch("api.app.get_cached_count", return_value=len(features)):
            with patch("api.app.get_cached_features", return_value=iter([features])):
                return await client.get(self.STREAM_URL, params=params)

    async def test_stream_returns_ndjson_with_metadata(
        self, async_client: AsyncClient
    ) -> None:
        """Test that the stream endpoint returns valid NDJSON starting with a metadata message."""
        fake_features = [self._point_feature(1), self._point_feature(2, 1, 1)]

        response = await self._get_from_cache(
            async_client, fake_features, {"listing_type": "RENT", "batch_size": 50}
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/x-ndjson"

        lines = self._ndjson_lines(response)
        assert len(lines) >= 2  # at least metadata + complete

        metadata = json.loads(lines[0])
        assert metadata["type"] == "metadata"
        assert "batch_size" in metadata
        assert "total_expected" in metadata

        complete = json.loads(lines[-1])
        assert complete["type"] == "complete"
        assert "total" in complete

    async def test_stream_cache_hit_path(
        self, async_client: AsyncClient
    ) -> None:
        """Test that cache-hit path returns cached: True in metadata."""
        response = await self._get_from_cache(
            async_client, [self._point_feature(1)], {"listing_type": "RENT"}
        )

        assert response.status_code == 200
        lines = self._ndjson_lines(response)

        metadata = json.loads(lines[0])
        assert metadata["cached"] is True
        assert metadata["total_expected"] == 1

        batch_msg = json.loads(lines[1])
        assert batch_msg["type"] == "batch"
        assert len(batch_msg["features"]) == 1

    async def test_stream_cache_miss_path(
        self, async_client: AsyncClient
    ) -> None:
        """Test that cache-miss path queries DB and returns cached: False."""
        from datetime import datetime

        db_row = {
            "id": 100,
            "price": 2000.0,
            "number_of_bedrooms": 2,
            "square_meters": 50.0,
            "longitude": -0.1,
            "latitude": 51.5,
            "photo_thumbnail": None,
            "last_seen": datetime(2024, 1, 1),
            "agency": "Test Agency",
            "price_history_json": "[]",
            "available_from": None,
        }

        # Repository double: one row in the DB, streamed back as-is.
        mock_repo = MagicMock()
        mock_repo.count_listings.return_value = 1
        mock_repo.stream_listings_optimized.return_value = iter([db_row])

        with patch("api.app.get_cached_count", return_value=None):
            with patch("api.app.ListingRepository", return_value=mock_repo):
                with patch("api.app.cache_features_batch"):
                    response = await async_client.get(
                        self.STREAM_URL,
                        params={"listing_type": "RENT"},
                    )

        assert response.status_code == 200
        lines = self._ndjson_lines(response)

        metadata = json.loads(lines[0])
        assert metadata["cached"] is False
        assert metadata["total_expected"] == 1

        batch_msg = json.loads(lines[1])
        assert batch_msg["type"] == "batch"
        assert len(batch_msg["features"]) == 1
        first_feature = batch_msg["features"][0]
        assert first_feature["type"] == "Feature"
        assert first_feature["properties"]["total_price"] == 2000.0

        complete = json.loads(lines[-1])
        assert complete["type"] == "complete"
        assert complete["total"] == 1

    async def test_stream_with_limit(
        self, async_client: AsyncClient
    ) -> None:
        """Test that the limit parameter caps the number of streamed features."""
        fake_features = [self._point_feature(i) for i in range(5)]

        response = await self._get_from_cache(
            async_client, fake_features, {"listing_type": "RENT", "limit": 3}
        )

        assert response.status_code == 200
        lines = self._ndjson_lines(response)

        metadata = json.loads(lines[0])
        assert metadata["total_expected"] == 3

        complete = json.loads(lines[-1])
        assert complete["type"] == "complete"
        assert complete["total"] == 3
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class TestThrottlingRetryBehavior:
|
|||
"""Test that 429 responses trigger retry with backoff."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
|
|
@ -117,7 +117,7 @@ class TestThrottlingRetryBehavior:
|
|||
"""Test that 503 responses trigger retry."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 2:
|
||||
|
|
@ -157,7 +157,7 @@ class TestCircuitBreakerIntegration:
|
|||
"""Test that circuit breaker opens after consecutive failures."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockResponse(status=429)
|
||||
|
|
@ -223,14 +223,14 @@ class TestMetricsTracking:
|
|||
@pytest.mark.asyncio
|
||||
async def test_metrics_tracked_on_rate_limit(self, config: ScraperConfig) -> None:
|
||||
"""Test that rate limit errors are tracked in metrics."""
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
return MockResponse(status=429)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = mock_get
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
with pytest.raises(RateLimitError):
|
||||
with pytest.raises((RateLimitError, CircuitBreakerOpenError)):
|
||||
with patch("tenacity.wait_exponential.__call__", return_value=0):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
|
|
@ -250,7 +250,7 @@ class TestMetricsTracking:
|
|||
@pytest.mark.asyncio
|
||||
async def test_metrics_tracked_on_success(self, config: ScraperConfig) -> None:
|
||||
"""Test that successful requests are tracked in metrics."""
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
return MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 10, "properties": []},
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ class TestListingGeoJsonEndpoint:
|
|||
|
||||
# Override auth dependency
|
||||
async def mock_auth():
|
||||
return User(email="test@example.com", name="Test User")
|
||||
return User(sub="test-id", email="test@example.com", name="Test User")
|
||||
|
||||
app.dependency_overrides[get_current_user] = mock_auth
|
||||
yield TestClient(app)
|
||||
|
|
|
|||
151
crawler/tests/unit/test_auth.py
Normal file
151
crawler/tests/unit/test_auth.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Unit tests for api/auth.py."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
from api.auth import (
|
||||
User,
|
||||
_verify_passkey_token,
|
||||
_verify_authentik_token,
|
||||
get_current_user,
|
||||
)
|
||||
from api.config import JWT_SECRET, JWT_ALGORITHM, JWT_ISSUER
|
||||
|
||||
|
||||
def _make_passkey_token(
    sub: str = "user-123",
    email: str = "test@example.com",
    name: str = "Test User",
    issuer: str = JWT_ISSUER,
    secret: str = JWT_SECRET,
    algorithm: str = JWT_ALGORITHM,
    expires_delta: timedelta | None = timedelta(hours=1),
) -> str:
    """Mint a passkey-style JWT for tests.

    Defaults produce a token signed with the configured secret and algorithm
    that expires one hour from now. Pass ``expires_delta=None`` to omit the
    ``exp`` claim, or a negative delta for an already-expired token.
    """
    # "exp" is only present when an expiry was requested.
    expiry_claim = (
        {}
        if expires_delta is None
        else {"exp": datetime.now(timezone.utc) + expires_delta}
    )
    claims: dict = {
        "sub": sub,
        "email": email,
        "name": name,
        "iss": issuer,
        **expiry_claim,
    }
    return pyjwt.encode(claims, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
class TestVerifyPasskeyToken:
    """Tests for _verify_passkey_token()."""

    def test_valid_token_returns_user(self) -> None:
        """A token minted with the defaults decodes to the matching User."""
        user = _verify_passkey_token(_make_passkey_token())

        assert isinstance(user, User)
        assert user.sub == "user-123"
        assert user.email == "test@example.com"
        assert user.name == "Test User"

    def test_valid_token_without_name_uses_email(self) -> None:
        """Without a "name" claim the user's name falls back to the email."""
        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
        claims = {
            "sub": "user-456",
            "email": "noname@example.com",
            "iss": JWT_ISSUER,
            "exp": expires_at,
        }
        token = pyjwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)

        assert _verify_passkey_token(token).name == "noname@example.com"

    def test_rejects_expired_token(self) -> None:
        """A token whose expiry lies in the past is rejected."""
        expired = _make_passkey_token(expires_delta=timedelta(hours=-1))
        with pytest.raises(pyjwt.ExpiredSignatureError):
            _verify_passkey_token(expired)

    def test_rejects_wrong_secret(self) -> None:
        """A token signed with a different secret is rejected."""
        forged = _make_passkey_token(secret="wrong-secret")
        with pytest.raises(pyjwt.InvalidSignatureError):
            _verify_passkey_token(forged)

    def test_rejects_wrong_issuer(self) -> None:
        """A token carrying an unexpected issuer is rejected."""
        foreign = _make_passkey_token(issuer="some-other-issuer")
        with pytest.raises(pyjwt.InvalidIssuerError):
            _verify_passkey_token(foreign)
|
||||
|
||||
|
||||
class TestVerifyAuthentikToken:
    """Tests for _verify_authentik_token() — specifically that expiration is verified."""

    async def test_verifies_expiration_after_fix(self) -> None:
        """After removing verify_exp: False, expired Authentik tokens should be rejected."""
        from cryptography.hazmat.primitives.asymmetric import rsa

        # Throwaway RSA key pair: the private half signs the token, the
        # public half is what the mocked JWKS client hands to the verifier.
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        public_key = private_key.public_key()

        issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
        payload = {
            "sub": "authentik-user",
            "email": "auth@example.com",
            "name": "Auth User",
            "iss": issuer,
            "aud": "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe",
            "exp": datetime.now(timezone.utc) - timedelta(hours=1),  # expired
        }
        token = pyjwt.encode(payload, private_key, algorithm="RS256")

        # Build a real PyJWK-compatible signing key mock so jwt.decode
        # takes the PyJWK code path (uses key.key directly, skips prepare_key)
        mock_signing_key = MagicMock(spec=pyjwt.PyJWK)
        mock_signing_key.key = public_key
        mock_signing_key.algorithm_name = "RS256"
        mock_signing_key.Algorithm = pyjwt.get_algorithm_by_name("RS256")

        mock_jwks_client = MagicMock()
        mock_jwks_client.get_signing_key_from_jwt.return_value = mock_signing_key

        mock_metadata = {
            "issuer": issuer,
            "jwks_uri": f"{issuer}jwks/",
        }

        # Both network-backed lookups are replaced so the test stays offline.
        with patch("api.auth.get_oidc_metadata", new_callable=AsyncMock, return_value=mock_metadata), \
             patch("api.auth.get_cached_jwks_client", new_callable=AsyncMock, return_value=mock_jwks_client):
            with pytest.raises(pyjwt.ExpiredSignatureError):
                await _verify_authentik_token(token)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
    """Tests for get_current_user()."""

    @staticmethod
    def _bearer(token: str) -> HTTPAuthorizationCredentials:
        """Wrap *token* in Bearer-scheme credentials."""
        return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

    async def test_routes_to_passkey_verifier_for_matching_issuer(self) -> None:
        """A token from _make_passkey_token() yields the expected sub and email."""
        credentials = self._bearer(_make_passkey_token())

        user = await get_current_user(credentials)

        assert user.sub == "user-123"
        assert user.email == "test@example.com"

    async def test_routes_to_authentik_for_other_issuer(self) -> None:
        """When issuer != JWT_ISSUER, should route to Authentik verifier."""
        authentik_issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
        credentials = self._bearer(_make_passkey_token(issuer=authentik_issuer))
        expected_user = User(sub="authentik-user", email="auth@example.com", name="Auth User")

        verifier_patch = patch(
            "api.auth._verify_authentik_token",
            new_callable=AsyncMock,
            return_value=expected_user,
        )
        with verifier_patch:
            user = await get_current_user(credentials)

        assert user.email == "auth@example.com"

    async def test_raises_http_exception_for_invalid_token(self) -> None:
        """A malformed dotted token produces a 401 whose detail mentions 'Invalid token'."""
        credentials = self._bearer("not.a.valid.token")

        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(credentials)

        error = exc_info.value
        assert error.status_code == 401
        assert "Invalid token" in error.detail

    async def test_raises_http_exception_for_garbage_token(self) -> None:
        """A token without any JWT structure produces a 401."""
        credentials = self._bearer("totalgarbage")

        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(credentials)

        assert exc_info.value.status_code == 401
|
||||
388
crawler/tests/unit/test_cli.py
Normal file
388
crawler/tests/unit/test_cli.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
"""Characterization and unit tests for the CLI (main.py)."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from models.listing import FurnishType, ListingType, QueryParameters
|
||||
from main import build_query_parameters, cli, listing_filter_options
|
||||
|
||||
|
||||
class TestBuildQueryParameters:
    """Tests for build_query_parameters()."""

    @staticmethod
    def _build(**overrides):  # type: ignore[no-untyped-def]
        """Call build_query_parameters() with baseline RENT arguments.

        The baseline is rebuilt on every call so no list is shared between
        tests; any keyword in *overrides* replaces or extends it.
        """
        arguments = {
            "type": "RENT",
            "district": ["London"],
            "min_bedrooms": 1,
            "max_bedrooms": 10,
            "min_price": 0,
            "max_price": 999999,
            "furnish_types": [],
            "available_from": None,
            "last_seen_days": 14,
        }
        arguments.update(overrides)
        return build_query_parameters(**arguments)

    def test_typical_rent_inputs(self) -> None:
        """Every supplied rent filter is carried over to the QueryParameters."""
        move_in = datetime(2025, 6, 1)

        qp = self._build(
            type="RENT",
            district=["London", "Camden"],
            min_bedrooms=2,
            max_bedrooms=4,
            min_price=1000,
            max_price=3000,
            furnish_types=["FURNISHED"],
            available_from=move_in,
            last_seen_days=7,
            min_sqm=50,
        )

        assert qp.listing_type == ListingType.RENT
        assert qp.district_names == {"London", "Camden"}
        assert qp.min_bedrooms == 2
        assert qp.max_bedrooms == 4
        assert qp.min_price == 1000
        assert qp.max_price == 3000
        assert qp.furnish_types == [FurnishType.FURNISHED]
        assert qp.let_date_available_from == datetime(2025, 6, 1)
        assert qp.last_seen_days == 7
        assert qp.min_sqm == 50

    def test_typical_buy_inputs(self) -> None:
        """A BUY query without furnish/date/sqm filters leaves those fields None."""
        qp = self._build(
            type="BUY",
            district=["Barnet"],
            min_bedrooms=3,
            max_bedrooms=5,
            min_price=200000,
            max_price=500000,
        )

        assert qp.listing_type == ListingType.BUY
        assert qp.district_names == {"Barnet"}
        assert qp.furnish_types is None
        assert qp.let_date_available_from is None
        assert qp.min_sqm is None

    def test_empty_districts_yields_empty_set(self) -> None:
        """An empty district list becomes an empty set."""
        qp = self._build(district=[])

        assert qp.district_names == set()

    def test_none_districts_yields_empty_set(self) -> None:
        """A None district argument becomes an empty set."""
        qp = self._build(district=None)

        assert qp.district_names == set()

    def test_furnish_types_conversion(self) -> None:
        """Furnish type strings are converted to FurnishType members in order."""
        qp = self._build(furnish_types=["FURNISHED", "UNFURNISHED"])

        assert qp.furnish_types == [FurnishType.FURNISHED, FurnishType.UNFURNISHED]

    def test_empty_furnish_types_yields_none(self) -> None:
        """An empty furnish type list is reported as None."""
        qp = self._build()

        assert qp.furnish_types is None

    def test_default_optional_parameters(self) -> None:
        """Arguments that are not passed fall back to the expected defaults."""
        qp = self._build()

        assert qp.radius == 0
        assert qp.page_size == 500
        assert qp.max_days_since_added == 14
|
||||
|
||||
|
||||
class TestListingFilterOptionsDecorator:
    """Tests for the listing_filter_options decorator."""

    def test_applies_all_expected_options(self) -> None:
        """The decorated command exposes every expected filter option."""

        @click.command()
        @listing_filter_options
        def dummy_cmd(**kwargs: object) -> None:
            pass

        expected_option_names = {
            "type",
            "min_bedrooms",
            "max_bedrooms",
            "min_price",
            "max_price",
            "district",
            "furnish_types",
            "available_from",
            "last_seen_days",
            "min_sqm",
        }
        param_names = {param.name for param in dummy_cmd.params}

        missing = expected_option_names - param_names
        assert not missing, f"Missing options: {missing}"

    def test_type_option_is_required(self) -> None:
        """The ``type`` option is declared as required."""

        @click.command()
        @listing_filter_options
        def dummy_cmd(**kwargs: object) -> None:
            pass

        type_param = next(param for param in dummy_cmd.params if param.name == "type")

        assert type_param.required is True

    def test_produces_query_parameters_kwarg(self) -> None:
        """After refactoring, the decorator should produce a query_parameters kwarg."""
        captured: dict = {}

        @click.command()
        @listing_filter_options
        def dummy_cmd(query_parameters: QueryParameters) -> None:
            captured["qp"] = query_parameters

        result = CliRunner().invoke(dummy_cmd, ["--type", "RENT"])

        assert result.exit_code == 0, f"Command failed: {result.output}"
        received = captured["qp"]
        assert isinstance(received, QueryParameters)
        assert received.listing_type == ListingType.RENT
|
||||
|
||||
|
||||
class TestDumpListingsCommand:
    """Tests for the dump-listings CLI command."""

    @staticmethod
    def _refresh_result(count: int):  # type: ignore[no-untyped-def]
        """Build the RefreshResult the mocked refresh_listings should return."""
        from services.listing_service import RefreshResult

        return RefreshResult(
            task_id=None,
            new_listings_count=count,
            message=f"Fetched {count} new listings",
        )

    @staticmethod
    def _invoke(arguments: list[str]):  # type: ignore[no-untyped-def]
        """Run the ``dump-listings`` subcommand with *arguments* through a CliRunner."""
        return CliRunner().invoke(cli, ["dump-listings", *arguments])

    @patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
    @patch("main.engine", new_callable=MagicMock)
    def test_calls_refresh_listings_with_correct_params(
        self,
        mock_engine: MagicMock,
        mock_refresh: AsyncMock,
    ) -> None:
        """CLI filter options reach refresh_listings as its second positional argument."""
        mock_refresh.return_value = self._refresh_result(5)

        result = self._invoke(
            [
                "--type", "RENT",
                "--min-bedrooms", "2",
                "--max-bedrooms", "4",
                "--min-price", "1000",
                "--max-price", "3000",
            ]
        )

        assert result.exit_code == 0, f"CLI failed: {result.output}"
        mock_refresh.assert_called_once()
        call_args = mock_refresh.call_args
        qp: QueryParameters = call_args.args[1]
        assert qp.listing_type == ListingType.RENT
        assert qp.min_bedrooms == 2
        assert qp.max_bedrooms == 4
        assert qp.min_price == 1000
        assert qp.max_price == 3000
        # Without --include-processing, ``full`` must not be passed as True.
        assert call_args.kwargs.get("full") is not True

    @patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
    @patch("main.engine", new_callable=MagicMock)
    def test_include_processing_flag_passes_full_true(
        self,
        mock_engine: MagicMock,
        mock_refresh: AsyncMock,
    ) -> None:
        """``--include-processing`` results in ``full=True`` being passed."""
        mock_refresh.return_value = self._refresh_result(0)

        result = self._invoke(["--type", "RENT", "--include-processing"])

        assert result.exit_code == 0, f"CLI failed: {result.output}"
        mock_refresh.assert_called_once()
        call_kwargs = mock_refresh.call_args.kwargs
        assert call_kwargs.get("full") is True

    @patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
    @patch("main.engine", new_callable=MagicMock)
    def test_include_processing_short_flag(
        self,
        mock_engine: MagicMock,
        mock_refresh: AsyncMock,
    ) -> None:
        """The short ``-p`` flag also results in ``full=True`` being passed."""
        mock_refresh.return_value = self._refresh_result(0)

        result = self._invoke(["--type", "RENT", "-p"])

        assert result.exit_code == 0, f"CLI failed: {result.output}"
        mock_refresh.assert_called_once()
        call_kwargs = mock_refresh.call_args.kwargs
        assert call_kwargs.get("full") is True
|
||||
|
||||
|
||||
class TestExportCsvCommand:
    """Tests for the export-csv CLI command."""

    @patch("main.export_service.export_to_csv", new_callable=AsyncMock)
    @patch("main.engine", new_callable=MagicMock)
    def test_calls_export_to_csv(
        self,
        mock_engine: MagicMock,
        mock_export: AsyncMock,
    ) -> None:
        """export-csv calls export_to_csv once, with the query as third positional argument."""
        from services.export_service import ExportResult

        mock_export.return_value = ExportResult(
            success=True,
            output_path="/tmp/test.csv",
            data=None,
            record_count=10,
            message="Exported 10 listings to /tmp/test.csv",
        )
        cli_arguments = [
            "export-csv",
            "--output-file", "/tmp/test.csv",
            "--type", "RENT",
        ]

        result = CliRunner().invoke(cli, cli_arguments)

        assert result.exit_code == 0, f"CLI failed: {result.output}"
        mock_export.assert_called_once()
        positional = mock_export.call_args[0]
        qp = positional[2]
        assert qp.listing_type == ListingType.RENT
|
||||
|
||||
|
||||
class TestExportImmowebCommand:
    """Tests for the export-immoweb CLI command."""

    @patch("main.export_service.export_to_geojson", new_callable=AsyncMock)
    @patch("main.engine", new_callable=MagicMock)
    def test_calls_export_to_geojson(
        self,
        mock_engine: MagicMock,
        mock_export: AsyncMock,
    ) -> None:
        """export-immoweb exits cleanly and calls export_to_geojson exactly once."""
        from services.export_service import ExportResult

        mock_export.return_value = ExportResult(
            success=True,
            output_path="/tmp/test.geojson",
            data=None,
            record_count=5,
            message="Exported 5 listings to /tmp/test.geojson",
        )
        cli_arguments = [
            "export-immoweb",
            "--output-file", "/tmp/test.geojson",
            "--type", "RENT",
        ]

        result = CliRunner().invoke(cli, cli_arguments)

        assert result.exit_code == 0, f"CLI failed: {result.output}"
        mock_export.assert_called_once()
|
||||
|
||||
|
||||
class TestListDistrictsCommand:
    """Tests for the list-districts CLI command."""

    @patch("main.engine", new_callable=MagicMock)
    def test_outputs_district_names(self, mock_engine: MagicMock) -> None:
        """The command succeeds and prints the header plus known district names."""
        result = CliRunner().invoke(cli, ["list-districts"])

        assert result.exit_code == 0
        output = result.output
        assert "London" in output
        assert "Camden" in output
        assert "Available districts" in output
|
||||
|
||||
|
||||
class TestRoutingCommand:
    """Tests for the routing CLI command."""

    @patch("main.engine", new_callable=MagicMock)
    def test_requires_api_key_env_var(self, mock_engine: MagicMock) -> None:
        """With ROUTING_API_KEY mapped to None in the runner env, the command fails and names it."""
        runner = CliRunner(env={"ROUTING_API_KEY": None})
        cli_arguments = [
            "routing",
            "--destination-address", "London Bridge",
            "--travel-mode", "transit",
            "--limit", "1",
        ]

        result = runner.invoke(cli, cli_arguments, catch_exceptions=False)

        assert result.exit_code != 0
        assert "ROUTING_API_KEY" in result.output
|
||||
62
crawler/tests/unit/test_districts.py
Normal file
62
crawler/tests/unit/test_districts.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Unit tests for rec/districts.py and services/district_service.py."""
|
||||
from rec.districts import get_districts, get_district_by_name
|
||||
from services.district_service import get_all_districts, get_district_names, validate_districts
|
||||
|
||||
|
||||
class TestGetDistricts:
    """Tests for get_districts()."""

    def test_returns_non_empty_dict(self) -> None:
        """get_districts() returns a dict with at least one entry."""
        districts = get_districts()

        assert isinstance(districts, dict)
        assert len(districts) > 0

    def test_values_start_with_region_prefix(self) -> None:
        """Every district value begins with the ``REGION^`` prefix."""
        districts = get_districts()

        for name, region_id in districts.items():
            has_prefix = region_id.startswith("REGION^")
            assert has_prefix, (
                f"District '{name}' has value '{region_id}' that doesn't start with REGION^"
            )

    def test_contains_expected_london_boroughs(self) -> None:
        """A few well-known boroughs are present as keys."""
        districts = get_districts()
        expected_boroughs = ("Camden", "Westminster", "Hackney")

        for borough in expected_boroughs:
            assert borough in districts, f"Expected borough '{borough}' not found"
|
||||
|
||||
|
||||
class TestGetDistrictByName:
    """Tests for get_district_by_name()."""

    def test_valid_name_returns_region_id(self) -> None:
        """A known district name maps to its region identifier."""
        region_id = get_district_by_name("Camden")

        assert region_id == "REGION^93941"

    def test_invalid_name_returns_none(self) -> None:
        """An unknown district name yields None."""
        region_id = get_district_by_name("Nonexistent District")

        assert region_id is None
|
||||
|
||||
|
||||
class TestGetDistrictNames:
    """Tests for get_district_names()."""

    def test_returns_list_matching_dict_keys(self) -> None:
        """The names are a list equal to the keys of get_districts(), in order."""
        expected_names = list(get_districts().keys())

        names = get_district_names()

        assert isinstance(names, list)
        assert names == expected_names
|
||||
|
||||
|
||||
class TestGetAllDistricts:
    """Tests for get_all_districts()."""

    def test_returns_same_as_get_districts(self) -> None:
        """The service function returns a mapping equal to get_districts()."""
        from_service = get_all_districts()
        from_source = get_districts()

        assert from_service == from_source
|
||||
|
||||
|
||||
class TestValidateDistricts:
    """Tests for validate_districts(); the asserts show it returns the invalid names."""

    def test_all_valid_returns_empty_list(self) -> None:
        """Only known districts: nothing is reported."""
        known = ["Camden", "Westminster", "Hackney"]

        assert validate_districts(known) == []

    def test_some_invalid_returns_invalid_ones(self) -> None:
        """Unknown names are reported in their input order."""
        mixed = ["Camden", "Faketown", "Westminster", "Nowhere"]

        assert validate_districts(mixed) == ["Faketown", "Nowhere"]

    def test_all_invalid_returns_all(self) -> None:
        """Only unknown districts: every name is reported."""
        invalid = ["Faketown", "Nowhere", "Neverland"]

        rejected = validate_districts(invalid)

        assert rejected == invalid

    def test_empty_list_returns_empty_list(self) -> None:
        """No input names: nothing is reported."""
        assert validate_districts([]) == []
|
||||
104
crawler/tests/unit/test_floorplan.py
Normal file
104
crawler/tests/unit/test_floorplan.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Unit tests for rec/floorplan.py."""
|
||||
from unittest.mock import patch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pytest
|
||||
|
||||
from rec.floorplan import extract_total_sqm, improve_img_for_ocr, calculate_ocr
|
||||
|
||||
|
||||
class TestExtractTotalSqm:
    """Tests for extract_total_sqm()."""

    def test_normal_value(self) -> None:
        """A single decimal area is extracted."""
        sqm = extract_total_sqm("Total area: 75.5 sq m")
        assert sqm == 75.5

    def test_multiple_values_returns_max_in_range(self) -> None:
        """With one value too small and one in range, the in-range value is returned."""
        sqm = extract_total_sqm("Room 1: 20 sqm, Total: 65 sq m")
        assert sqm == 65.0

    def test_no_match_returns_none(self) -> None:
        """Text without an area yields None."""
        sqm = extract_total_sqm("No area info")
        assert sqm is None

    def test_below_minimum_returns_none(self) -> None:
        """An area of 15 is rejected."""
        sqm = extract_total_sqm("Area: 15 sq m")
        assert sqm is None

    def test_above_maximum_returns_none(self) -> None:
        """An area of 200 is rejected."""
        sqm = extract_total_sqm("Area: 200 sq m")
        assert sqm is None

    def test_edge_just_above_min(self) -> None:
        """30.1 is accepted."""
        sqm = extract_total_sqm("Area: 30.1 sq m")
        assert sqm == 30.1

    def test_edge_just_below_max(self) -> None:
        """159.9 is accepted."""
        sqm = extract_total_sqm("Area: 159.9 sq m")
        assert sqm == 159.9

    def test_exactly_at_min_boundary_returns_none(self) -> None:
        """Exactly 30 is rejected."""
        # MIN_SQM < sqm, so 30 is not strictly greater than 30
        sqm = extract_total_sqm("Area: 30 sq m")
        assert sqm is None

    def test_exactly_at_max_boundary_returns_none(self) -> None:
        """Exactly 160 is rejected."""
        # sqm < MAX_SQM, so 160 is not strictly less than 160
        sqm = extract_total_sqm("Area: 160 sq m")
        assert sqm is None

    def test_format_sq_dot_m(self) -> None:
        """The ``sq. m`` spelling is recognised."""
        sqm = extract_total_sqm("Area: 80 sq. m")
        assert sqm == 80.0

    def test_format_sqm_no_space(self) -> None:
        """The ``sqm`` spelling directly after the number is recognised."""
        sqm = extract_total_sqm("Area: 80sqm")
        assert sqm == 80.0

    def test_format_sq_m_with_space(self) -> None:
        """The ``sq m`` spelling is recognised."""
        sqm = extract_total_sqm("Area: 80 sq m")
        assert sqm == 80.0

    def test_empty_string(self) -> None:
        """An empty string yields None."""
        sqm = extract_total_sqm("")
        assert sqm is None

    def test_multiple_valid_values_returns_max(self) -> None:
        """With two in-range values, the larger one is returned."""
        sqm = extract_total_sqm("Living: 40 sq m, Total: 100 sq m")
        assert sqm == 100.0
|
||||
|
||||
|
||||
class TestImproveImgForOcr:
    """Tests for improve_img_for_ocr()."""

    def test_produces_valid_pil_image(self) -> None:
        """The result is a PIL image in mode ``L``."""
        # Create a small test image (50x50, uniform value 200)
        pixels = np.ones((50, 50, 3), dtype=np.uint8) * 200
        source = Image.fromarray(pixels)

        result = improve_img_for_ocr(source)

        assert isinstance(result, Image.Image)
        # Result should be a grayscale (thresholded) image
        assert result.mode == "L"

    def test_output_dimensions_scaled(self) -> None:
        """A 100x100 input comes back as 120x120."""
        pixels = np.ones((100, 100, 3), dtype=np.uint8) * 128
        source = Image.fromarray(pixels)

        result = improve_img_for_ocr(source)

        # After 1.2x resize, 100 -> 120
        width, height = result.size[0], result.size[1]
        assert width == 120
        assert height == 120
|
||||
|
||||
|
||||
class TestCalculateOcr:
    """Tests for calculate_ocr()."""

    @staticmethod
    def _write_image(directory) -> str:  # type: ignore[no-untyped-def]
        """Save a 10x10 image as ``test.png`` in *directory* and return its path."""
        image_file = directory / "test.png"
        Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
        return str(image_file)

    def test_invalid_path_raises_file_not_found(self) -> None:
        """A path that does not exist raises FileNotFoundError."""
        missing_path = "/nonexistent/path/to/image.png"

        with pytest.raises(FileNotFoundError):
            calculate_ocr(missing_path)

    def test_returns_sqm_from_first_pass(self, tmp_path) -> None:  # type: ignore[no-untyped-def]
        """When the first OCR text contains an area, it is returned with that text."""
        # Create a real image file so the path check passes
        image_path = self._write_image(tmp_path)

        with patch("pytesseract.image_to_string", return_value="Total: 85 sq m"):
            result_sqm, result_text = calculate_ocr(image_path)

        assert result_sqm == 85.0
        assert result_text == "Total: 85 sq m"

    def test_falls_back_to_improved_image(self, tmp_path) -> None:  # type: ignore[no-untyped-def]
        """When the first OCR text has no area, the second OCR result is used."""
        image_path = self._write_image(tmp_path)

        # First call returns no sqm data, second (on improved image) returns valid data
        ocr_outputs = [
            "No area info here",
            "Total: 72 sq m",
        ]
        with patch("pytesseract.image_to_string", side_effect=ocr_outputs):
            result_sqm, result_text = calculate_ocr(image_path)

        assert result_sqm == 72.0
        assert result_text == "Total: 72 sq m"
|
||||
110
crawler/tests/unit/test_floorplan_detector.py
Normal file
110
crawler/tests/unit/test_floorplan_detector.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Unit tests for services/floorplan_detector.py."""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from models.listing import RentListing, ListingSite, FurnishType
|
||||
from services.floorplan_detector import _calculate_sqm_ocr, detect_floorplan
|
||||
|
||||
|
||||
def _make_listing(**kwargs) -> RentListing:  # type: ignore[no-untyped-def]
    """Build a RentListing from default field values, with *kwargs* taking precedence."""
    field_values = {
        "id": 1,
        "price": 2000.0,
        "number_of_bedrooms": 2,
        "square_meters": None,
        "agency": "Test",
        "council_tax_band": "C",
        "longitude": 0.0,
        "latitude": 0.0,
        "price_history_json": "[]",
        "listing_site": ListingSite.RIGHTMOVE,
        "last_seen": datetime.now(),
        "photo_thumbnail": None,
        "floorplan_image_paths": [],
        "additional_info": {"property": {"visible": True}},
        "routing_info_json": None,
        "furnish_type": FurnishType.FURNISHED,
        "available_from": None,
        **kwargs,
    }
    return RentListing(**field_values)
|
||||
|
||||
|
||||
class TestCalculateSqmOcr:
    """Tests for _calculate_sqm_ocr()."""

    async def test_skips_listing_with_existing_square_meters(self) -> None:
        """A listing that already has square_meters yields None."""
        listing = _make_listing(square_meters=50.0)

        outcome = await _calculate_sqm_ocr(listing, asyncio.Semaphore(1))

        assert outcome is None

    async def test_empty_floorplan_paths_returns_listing_with_zero(self) -> None:
        """With no floorplan paths, the returned listing has square_meters == 0."""
        listing = _make_listing(floorplan_image_paths=[])

        outcome = await _calculate_sqm_ocr(listing, asyncio.Semaphore(1))

        assert outcome is not None
        assert outcome.square_meters == 0

    @patch("services.floorplan_detector.floorplan")
    async def test_with_mocked_ocr_returning_value(self, mock_floorplan: MagicMock) -> None:
        """The area reported by the mocked OCR is stored on the listing."""
        mock_floorplan.calculate_ocr.return_value = (85.0, "Total: 85 sq m")
        listing = _make_listing(floorplan_image_paths=["/fake/path.png"])

        outcome = await _calculate_sqm_ocr(listing, asyncio.Semaphore(1))

        assert outcome is not None
        assert outcome.square_meters == 85.0

    @patch("services.floorplan_detector.floorplan")
    async def test_with_mocked_ocr_returning_none(self, mock_floorplan: MagicMock) -> None:
        """When the mocked OCR reports no area, square_meters ends up 0."""
        mock_floorplan.calculate_ocr.return_value = (None, "no data")
        listing = _make_listing(floorplan_image_paths=["/fake/path.png"])

        outcome = await _calculate_sqm_ocr(listing, asyncio.Semaphore(1))

        assert outcome is not None
        assert outcome.square_meters == 0

    @patch("services.floorplan_detector.floorplan")
    async def test_picks_max_from_multiple_floorplans(self, mock_floorplan: MagicMock) -> None:
        """With two floorplans, the larger OCR area is kept."""
        mock_floorplan.calculate_ocr.side_effect = [
            (50.0, "50 sq m"),
            (90.0, "90 sq m"),
        ]
        listing = _make_listing(floorplan_image_paths=["/fake/a.png", "/fake/b.png"])

        outcome = await _calculate_sqm_ocr(listing, asyncio.Semaphore(2))

        assert outcome is not None
        assert outcome.square_meters == 90.0
|
||||
|
||||
|
||||
class TestDetectFloorplan:
    """Tests for detect_floorplan()."""

    @staticmethod
    def _repository_returning(listings: list) -> MagicMock:
        """Build a repository mock whose get_listings() resolves to *listings*."""
        repository = MagicMock()
        repository.get_listings = AsyncMock(return_value=listings)
        repository.upsert_listings = AsyncMock(return_value=[])
        return repository

    @patch("services.floorplan_detector.floorplan")
    async def test_detect_floorplan_with_mocked_repository(self, mock_floorplan: MagicMock) -> None:
        """A listing with a floorplan is upserted carrying the OCR area."""
        mock_floorplan.calculate_ocr.return_value = (75.0, "75 sq m")
        listing = _make_listing(
            floorplan_image_paths=["/fake/path.png"],
        )
        repository = self._repository_returning([listing])

        await detect_floorplan(repository)

        repository.upsert_listings.assert_called_once()
        upserted = repository.upsert_listings.call_args[0][0]
        assert len(upserted) == 1
        assert upserted[0].square_meters == 75.0

    async def test_detect_floorplan_skips_already_processed(self) -> None:
        """A listing that already has square_meters leads to an empty upsert."""
        listing = _make_listing(square_meters=50.0)
        repository = self._repository_returning([listing])

        await detect_floorplan(repository)

        repository.upsert_listings.assert_called_once()
        upserted = repository.upsert_listings.call_args[0][0]
        assert len(upserted) == 0
|
||||
215
crawler/tests/unit/test_image_fetcher.py
Normal file
215
crawler/tests/unit/test_image_fetcher.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"""Unit tests for the image fetcher service."""
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from tenacity import stop_after_attempt
|
||||
|
||||
from models.listing import RentListing, ListingSite, FurnishType
|
||||
from services.image_fetcher import dump_images_for_listing, MAX_CONCURRENT_DOWNLOADS
|
||||
|
||||
|
||||
def _make_listing(**kwargs) -> RentListing:  # type: ignore[no-untyped-def]
    """Create a RentListing with sensible defaults for testing.

    The defaults include one floorplan URL under ``additional_info``;
    any keyword argument replaces the corresponding default.
    """
    default_additional_info = {
        "property": {
            "visible": True,
            "floorplans": [
                {"url": "https://media.rightmove.co.uk/imgs/floorplan_1.jpg"}
            ],
        }
    }
    field_values = {
        "id": 12345,
        "price": 2000.0,
        "number_of_bedrooms": 2,
        "square_meters": None,
        "agency": "Test Agency",
        "council_tax_band": "C",
        "longitude": 0.0,
        "latitude": 0.0,
        "price_history_json": "[]",
        "listing_site": ListingSite.RIGHTMOVE,
        "last_seen": datetime.now(),
        "photo_thumbnail": None,
        "floorplan_image_paths": [],
        "additional_info": default_additional_info,
        "routing_info_json": None,
        "furnish_type": FurnishType.FURNISHED,
        "available_from": None,
        **kwargs,
    }
    return RentListing(**field_values)
|
||||
|
||||
|
||||
class TestDumpImagesForListing:
|
||||
"""Tests for dump_images_for_listing function."""
|
||||
|
||||
async def test_downloads_floorplan_image(self, tmp_path: Path) -> None:
|
||||
"""Test successful floorplan image download."""
|
||||
listing = _make_listing()
|
||||
image_bytes = b"\x89PNG fake image data"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == 12345
|
||||
assert len(result.floorplan_image_paths) == 1
|
||||
# Verify the image was written
|
||||
written_path = Path(result.floorplan_image_paths[0])
|
||||
assert written_path.exists()
|
||||
assert written_path.read_bytes() == image_bytes
|
||||
|
||||
async def test_skips_existing_images(self, tmp_path: Path) -> None:
|
||||
"""Test that existing images are not re-downloaded."""
|
||||
listing = _make_listing()
|
||||
# Pre-create the floorplan file
|
||||
floorplan_dir = tmp_path / str(listing.id) / "floorplans"
|
||||
floorplan_dir.mkdir(parents=True)
|
||||
existing_file = floorplan_dir / "floorplan_1.jpg"
|
||||
existing_file.write_bytes(b"existing image")
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
# Should return None because the only floorplan was skipped (continue)
|
||||
assert result is None
|
||||
# Session.get should NOT have been called
|
||||
mock_session.get.assert_not_called()
|
||||
|
||||
async def test_returns_none_on_404(self, tmp_path: Path) -> None:
|
||||
"""Test that 404 responses return None (image not found)."""
|
||||
listing = _make_listing()
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 404
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_raises_on_non_200_status(self, tmp_path: Path) -> None:
|
||||
"""Test that non-200/404 status raises exception."""
|
||||
listing = _make_listing()
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 500
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
# Disable tenacity retry for testing: stop after 1 attempt and reraise
|
||||
await dump_images_for_listing.retry_with(
|
||||
stop=stop_after_attempt(1),
|
||||
reraise=True,
|
||||
)(listing, tmp_path, session=mock_session)
|
||||
|
||||
async def test_returns_none_when_no_floorplans(self, tmp_path: Path) -> None:
|
||||
"""Test listing with no floorplans returns None."""
|
||||
listing = _make_listing(
|
||||
additional_info={"property": {"visible": True, "floorplans": []}}
|
||||
)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_url_filename_extraction(self, tmp_path: Path) -> None:
|
||||
"""Test that filenames are correctly extracted from URLs."""
|
||||
listing = _make_listing(
|
||||
additional_info={
|
||||
"property": {
|
||||
"visible": True,
|
||||
"floorplans": [
|
||||
{
|
||||
"url": "https://media.rightmove.co.uk/dir/sub/my_floorplan.png"
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
image_bytes = b"fake png"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
written_path = Path(result.floorplan_image_paths[0])
|
||||
assert written_path.name == "my_floorplan.png"
|
||||
|
||||
async def test_creates_session_when_none_provided(self, tmp_path: Path) -> None:
    """With session=None the patched ClientSession is used and then closed."""
    listing = _make_listing()

    # Stand-in for a 200 response whose body is the image payload.
    response = AsyncMock()
    response.status = 200
    response.read = AsyncMock(return_value=b"fake image")

    get_context = AsyncMock()
    get_context.__aenter__ = AsyncMock(return_value=response)
    get_context.__aexit__ = AsyncMock(return_value=False)

    # The instance that the patched ClientSession constructor will return.
    created_session = MagicMock(spec=aiohttp.ClientSession)
    created_session.get = MagicMock(return_value=get_context)
    created_session.close = AsyncMock()

    session_ctor = patch(
        "services.image_fetcher.aiohttp.ClientSession",
        return_value=created_session,
    )
    with session_ctor:
        result = await dump_images_for_listing(listing, tmp_path, session=None)

    assert result is not None
    created_session.close.assert_awaited_once()
|
||||
|
||||
|
||||
class TestImageFetcherConfig:
    """Sanity checks on the image fetcher's module-level settings."""

    def test_max_concurrent_downloads_constant(self) -> None:
        """MAX_CONCURRENT_DOWNLOADS is positive and no larger than 20."""
        assert 0 < MAX_CONCURRENT_DOWNLOADS <= 20
|
||||
225
crawler/tests/unit/test_listing_cache.py
Normal file
225
crawler/tests/unit/test_listing_cache.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Unit tests for services/listing_cache.py."""
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from services.listing_cache import (
|
||||
CACHE_PREFIX,
|
||||
_get_redis_client,
|
||||
cache_features_batch,
|
||||
get_cached_count,
|
||||
get_cached_features,
|
||||
invalidate_cache,
|
||||
make_cache_key,
|
||||
)
|
||||
|
||||
|
||||
def _make_query(**overrides) -> QueryParameters:
    """Build a QueryParameters for tests; keyword overrides replace the defaults."""
    params = {
        "listing_type": ListingType.RENT,
        "min_price": 1000,
        "max_price": 3000,
        **overrides,
    }
    return QueryParameters(**params)
|
||||
|
||||
|
||||
class TestMakeCacheKey:
    """Behavior of make_cache_key()."""

    def test_deterministic_for_same_params(self):
        """Building the key twice from one query gives equal keys."""
        query = _make_query()
        first_key = make_cache_key(query)
        second_key = make_cache_key(query)
        assert first_key == second_key

    def test_different_for_different_params(self):
        """Queries that differ in min_price produce unequal keys."""
        low_query = _make_query(min_price=1000)
        high_query = _make_query(min_price=2000)
        low_key, high_key = make_cache_key(low_query), make_cache_key(high_query)
        assert low_key != high_key

    def test_key_starts_with_prefix(self):
        """The generated key begins with CACHE_PREFIX."""
        key = make_cache_key(_make_query())
        assert key.startswith(CACHE_PREFIX)
|
||||
|
||||
|
||||
class TestGetRedisClient:
    """URL handling in _get_redis_client()."""

    @staticmethod
    def _assert_connects_to(mock_redis, environ, expected_url, *, clear=False):
        """Call _get_redis_client() under *environ* and check the URL given to redis.

        ``clear=True`` empties os.environ for the duration of the call, so the
        function sees no CELERY_BROKER_URL at all.
        """
        with mock.patch.dict("os.environ", environ, clear=clear):
            _get_redis_client()

        mock_redis.from_url.assert_called_once_with(
            expected_url, decode_responses=True
        )

    @mock.patch("services.listing_cache.redis")
    def test_default_broker_url(self, mock_redis):
        """With an empty environment the localhost URL with db 2 is used."""
        self._assert_connects_to(
            mock_redis, {}, "redis://localhost:6379/2", clear=True
        )

    @mock.patch("services.listing_cache.redis")
    def test_custom_broker_url(self, mock_redis):
        """Host and port come from the broker URL; the db number becomes 2."""
        self._assert_connects_to(
            mock_redis,
            {"CELERY_BROKER_URL": "redis://myhost:1234/5"},
            "redis://myhost:1234/2",
        )

    @mock.patch("services.listing_cache.redis")
    def test_broker_url_with_password(self, mock_redis):
        """The password part of the broker URL is kept."""
        self._assert_connects_to(
            mock_redis,
            {"CELERY_BROKER_URL": "redis://:secret@myhost:6379/0"},
            "redis://:secret@myhost:6379/2",
        )

    @mock.patch("services.listing_cache.redis")
    def test_broker_url_with_query_params(self, mock_redis):
        """The query string of the broker URL is kept."""
        self._assert_connects_to(
            mock_redis,
            {"CELERY_BROKER_URL": "redis://myhost:6379/0?timeout=5"},
            "redis://myhost:6379/2?timeout=5",
        )
|
||||
|
||||
|
||||
class TestGetCachedCount:
    """Behavior of get_cached_count()."""

    @mock.patch("services.listing_cache._get_redis_client")
    def test_returns_none_on_redis_error(self, mock_get_client):
        """A RedisError raised while getting the client results in None."""
        mock_get_client.side_effect = redis.RedisError("connection refused")

        assert get_cached_count(_make_query()) is None

    @mock.patch("services.listing_cache._get_redis_client")
    def test_returns_none_when_key_not_exists(self, mock_get_client):
        """When the client reports the key as absent the result is None."""
        client = mock.MagicMock(**{"exists.return_value": False})
        mock_get_client.return_value = client

        assert get_cached_count(_make_query()) is None

    @mock.patch("services.listing_cache._get_redis_client")
    def test_returns_count_when_key_exists(self, mock_get_client):
        """When the key exists the value reported by llen is returned."""
        client = mock.MagicMock(
            **{"exists.return_value": True, "llen.return_value": 42}
        )
        mock_get_client.return_value = client

        assert get_cached_count(_make_query()) == 42
|
||||
|
||||
|
||||
class TestGetCachedFeatures:
    """Behavior of get_cached_features()."""

    @mock.patch("services.listing_cache._get_redis_client")
    def test_yields_empty_on_redis_error(self, mock_get_client):
        """A RedisError raised while getting the client produces no batches."""
        mock_get_client.side_effect = redis.RedisError("connection refused")

        assert list(get_cached_features(_make_query())) == []

    @mock.patch("services.listing_cache._get_redis_client")
    def test_yields_batches(self, mock_get_client):
        """Three cached features with batch_size=50 come back as one batch."""
        expected = [{"type": "Feature", "id": idx} for idx in range(3)]
        client = mock.MagicMock()
        client.llen.return_value = 3
        # The fake list content is the JSON-encoded form of each feature.
        client.lrange.return_value = list(map(json.dumps, expected))
        mock_get_client.return_value = client

        batches = list(get_cached_features(_make_query(), batch_size=50))

        assert len(batches) == 1
        assert batches[0] == expected
|
||||
|
||||
|
||||
class TestCacheFeaturesBatch:
    """Behavior of cache_features_batch()."""

    @mock.patch("services.listing_cache._get_redis_client")
    def test_empty_features_returns_early(self, mock_get_client):
        """An empty feature list never requests a Redis client."""
        no_features = []

        cache_features_batch(_make_query(), no_features)

        mock_get_client.assert_not_called()

    @mock.patch("services.listing_cache._get_redis_client")
    def test_writes_features_via_pipeline(self, mock_get_client):
        """One batch triggers exactly one rpush, expire and execute on the pipeline."""
        pipeline = mock.MagicMock()
        client = mock.MagicMock()
        client.pipeline.return_value = pipeline
        mock_get_client.return_value = client

        cache_features_batch(_make_query(), [{"type": "Feature", "id": 1}])

        for pipeline_call in (pipeline.rpush, pipeline.expire, pipeline.execute):
            pipeline_call.assert_called_once()

    @mock.patch("services.listing_cache._get_redis_client")
    def test_handles_redis_error(self, mock_get_client):
        """A RedisError raised while getting the client does not propagate."""
        mock_get_client.side_effect = redis.RedisError("write error")

        # The call is the assertion: an escaping exception fails the test.
        cache_features_batch(_make_query(), [{"id": 1}])
|
||||
|
||||
|
||||
class TestInvalidateCache:
    """Behavior of invalidate_cache()."""

    @mock.patch("services.listing_cache._get_redis_client")
    def test_handles_redis_error(self, mock_get_client):
        """A RedisError raised while getting the client does not propagate."""
        mock_get_client.side_effect = redis.RedisError("connection refused")

        # The call is the assertion: an escaping exception fails the test.
        invalidate_cache()

    @mock.patch("services.listing_cache._get_redis_client")
    def test_deletes_matching_keys_via_pipeline(self, mock_get_client):
        """Two scanned keys lead to two pipeline deletes and one execute."""
        scanned_keys = ["listings:geojson:abc", "listings:geojson:def"]
        pipeline = mock.MagicMock()
        client = mock.MagicMock()
        client.pipeline.return_value = pipeline
        # Cursor 0 in the reply: a single scan pass that returns both keys.
        client.scan.return_value = (0, scanned_keys)
        mock_get_client.return_value = client

        invalidate_cache()

        assert pipeline.delete.call_count == 2
        pipeline.execute.assert_called_once()

    @mock.patch("services.listing_cache._get_redis_client")
    def test_no_keys_to_delete(self, mock_get_client):
        """An empty scan result means no pipeline is ever created."""
        client = mock.MagicMock(**{"scan.return_value": (0, [])})
        mock_get_client.return_value = client

        invalidate_cache()

        client.pipeline.assert_not_called()
|
||||
372
crawler/tests/unit/test_listing_fetcher.py
Normal file
372
crawler/tests/unit/test_listing_fetcher.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
"""Unit tests for the listing fetcher service."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from services.listing_fetcher import (
|
||||
NUM_WORKERS,
|
||||
_fetch_subquery,
|
||||
dump_listings,
|
||||
dump_listings_full,
|
||||
)
|
||||
from services.query_splitter import SubQuery
|
||||
|
||||
|
||||
def _make_subquery(**kwargs) -> SubQuery:
    """Build a SubQuery for tests; keyword arguments replace the defaults."""
    fields = {
        "district": "REGION^123",
        "min_bedrooms": 1,
        "max_bedrooms": 3,
        "min_price": 1000,
        "max_price": 3000,
        "estimated_results": 50,
        **kwargs,
    }
    return SubQuery(**fields)
|
||||
|
||||
|
||||
class TestDumpListingsFull:
    """Behavior of dump_listings_full."""

    async def test_returns_empty_list_when_no_new_listings(self) -> None:
        """An empty result from dump_listings gives an empty result overall."""
        repo = AsyncMock()
        repo.get_listings = AsyncMock(return_value=[])
        params = QueryParameters(listing_type=ListingType.RENT)

        no_new_listings = patch(
            "services.listing_fetcher.dump_listings",
            new_callable=AsyncMock,
            return_value=[],
        )
        with no_new_listings:
            result = await dump_listings_full(params, repo)

        assert result == []

    async def test_returns_only_new_listings_from_db(self) -> None:
        """The ids of the dumped listings are passed to repo.get_listings as only_ids."""
        dumped = [MagicMock(id=100), MagicMock(id=200)]
        repo = AsyncMock()
        repo.get_listings = AsyncMock(return_value=list(dumped))
        params = QueryParameters(listing_type=ListingType.RENT)

        two_new_listings = patch(
            "services.listing_fetcher.dump_listings",
            new_callable=AsyncMock,
            return_value=list(dumped),
        )
        with two_new_listings:
            result = await dump_listings_full(params, repo)

        # The repository lookup must be keyed on exactly the dumped ids.
        repo.get_listings.assert_awaited_once_with(only_ids=[100, 200])
        assert len(result) == 2
|
||||
|
||||
|
||||
class TestFetchSubquery:
    """Behavior of the _fetch_subquery helper."""

    @staticmethod
    def _paging_config() -> MagicMock:
        """Return a mock config with a 60-page cap and no request delay."""
        config = MagicMock()
        config.max_pages_per_query = 60
        config.request_delay_ms = 0
        return config

    @staticmethod
    async def _fetch(sq, params, queue, *, config, existing_ids) -> int:
        """Call _fetch_subquery with a mock session and a 5-slot semaphore."""
        return await _fetch_subquery(
            sq=sq,
            parameters=params,
            session=MagicMock(),
            config=config,
            semaphore=asyncio.Semaphore(5),
            existing_ids=existing_ids,
            queue=queue,
        )

    async def _assert_skipped(self, estimated_results) -> None:
        """Check that a subquery with this estimate yields nothing."""
        sq = _make_subquery(estimated_results=estimated_results)
        params = QueryParameters(listing_type=ListingType.RENT)
        queue: asyncio.Queue[int | None] = asyncio.Queue()

        ids_found = await self._fetch(
            sq, params, queue, config=MagicMock(), existing_ids=set()
        )

        assert ids_found == 0
        assert queue.empty()

    async def _assert_error_yields_nothing(self, error: Exception) -> None:
        """Check that *error* raised by listing_query leaves count 0 and queue empty."""
        sq = _make_subquery(estimated_results=100)
        params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
        queue: asyncio.Queue[int | None] = asyncio.Queue()
        config = self._paging_config()

        failing_query = patch(
            "services.listing_fetcher.listing_query",
            new_callable=AsyncMock,
            side_effect=error,
        )
        with failing_query:
            ids_found = await self._fetch(
                sq, params, queue, config=config, existing_ids=set()
            )

        assert ids_found == 0
        assert queue.empty()

    async def test_skips_subquery_with_zero_estimated_results(self) -> None:
        """A subquery estimated at 0 results enqueues nothing."""
        await self._assert_skipped(0)

    async def test_skips_subquery_with_none_estimated_results(self) -> None:
        """A subquery whose estimate is None enqueues nothing."""
        await self._assert_skipped(None)

    async def test_enqueues_new_ids_only(self) -> None:
        """Identifiers already in existing_ids are not enqueued or counted."""
        sq = _make_subquery(estimated_results=10)
        params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
        queue: asyncio.Queue[int | None] = asyncio.Queue()
        known_ids: set[int] = {101, 103}

        config = self._paging_config()
        config.max_concurrent_requests = 5

        # 101 and 103 are already known; 102 and 104 are new.
        page = {"properties": [{"identifier": n} for n in (101, 102, 103, 104)]}

        one_page = patch(
            "services.listing_fetcher.listing_query",
            new_callable=AsyncMock,
            return_value=page,
        )
        with one_page:
            ids_found = await self._fetch(
                sq, params, queue, config=config, existing_ids=known_ids
            )

        assert ids_found == 2

        # Drain the queue and compare its content with the expected ids.
        queued = []
        while not queue.empty():
            queued.append(queue.get_nowait())
        for new_id in (102, 104):
            assert new_id in queued
        for known_id in (101, 103):
            assert known_id not in queued

    async def test_stops_on_circuit_breaker_error(self) -> None:
        """A CircuitBreakerOpenError from listing_query yields no ids."""
        await self._assert_error_yields_nothing(CircuitBreakerOpenError("open"))

    async def test_stops_on_throttling_error(self) -> None:
        """A ThrottlingError from listing_query yields no ids."""
        await self._assert_error_yields_nothing(ThrottlingError("throttled"))

    async def test_stops_on_generic_error(self) -> None:
        """An exception mentioning GENERIC_ERROR (past last page) yields no ids."""
        await self._assert_error_yields_nothing(
            Exception("GENERIC_ERROR: no more results")
        )

    async def test_stops_on_unexpected_error(self) -> None:
        """Any other exception from listing_query also yields no ids."""
        await self._assert_error_yields_nothing(Exception("some network error"))

    async def test_stops_when_fewer_results_than_page_size(self) -> None:
        """A page shorter than page_size ends pagination after one query."""
        sq = _make_subquery(estimated_results=100)
        params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
        queue: asyncio.Queue[int | None] = asyncio.Queue()
        config = self._paging_config()

        # Two properties on the page, well below the page size of 25.
        short_page = {"properties": [{"identifier": 1}, {"identifier": 2}]}

        with patch(
            "services.listing_fetcher.listing_query",
            new_callable=AsyncMock,
            return_value=short_page,
        ) as mock_query:
            ids_found = await self._fetch(
                sq, params, queue, config=config, existing_ids=set()
            )

        # Exactly one request was made before pagination stopped.
        assert mock_query.await_count == 1
        assert ids_found == 2
|
||||
|
||||
|
||||
class TestDumpListings:
    """Tests for dump_listings."""

    async def test_circuit_breaker_returns_empty_list(self) -> None:
        """A CircuitBreakerOpenError from create_session yields an empty list."""
        mock_repo = AsyncMock()
        params = QueryParameters(listing_type=ListingType.RENT)

        with patch("services.listing_fetcher.create_session") as mock_cs:
            mock_cs.side_effect = CircuitBreakerOpenError("open")
            result = await dump_listings(params, mock_repo)

        assert result == []

    async def test_returns_processed_listings(self) -> None:
        """With no subqueries from the splitter, dump_listings returns [].

        The session, the QuerySplitter and _fetch_subquery are all patched, so
        this only covers the path where the splitter produces nothing to
        fetch; it does not exercise listing processing itself.
        """
        mock_repo = AsyncMock()
        # get_listing_ids is replaced with a plain (non-async) mock.
        mock_repo.get_listing_ids = MagicMock(return_value=set())
        params = QueryParameters(listing_type=ListingType.RENT)

        # create_session() is mocked to return an async context manager
        # that hands out a mock session on entry.
        mock_session_cm = AsyncMock()
        mock_session_cm.__aenter__ = AsyncMock(return_value=MagicMock())
        mock_session_cm.__aexit__ = AsyncMock(return_value=False)

        with (
            patch(
                "services.listing_fetcher.create_session",
                return_value=mock_session_cm,
            ),
            patch("services.listing_fetcher.QuerySplitter") as mock_splitter_cls,
            patch(
                "services.listing_fetcher._fetch_subquery",
                new_callable=AsyncMock,
                return_value=0,
            ),
        ):
            mock_splitter = mock_splitter_cls.return_value
            mock_splitter.split = AsyncMock(return_value=[])
            mock_splitter.calculate_total_estimated_results = MagicMock(
                return_value=0
            )

            result = await dump_listings(params, mock_repo)

        # With no subqueries, no listings are processed.
        assert result == []
|
||||
|
||||
|
||||
class TestNumWorkers:
    """Checks on the NUM_WORKERS constant."""

    def test_num_workers_is_positive(self) -> None:
        """NUM_WORKERS is greater than zero."""
        assert 0 < NUM_WORKERS

    def test_num_workers_value(self) -> None:
        """NUM_WORKERS equals 20."""
        expected_workers = 20
        assert NUM_WORKERS == expected_workers
|
||||
87
crawler/tests/unit/test_listing_processor.py
Normal file
87
crawler/tests/unit/test_listing_processor.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Unit tests for the listing processor."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from models.listing import FurnishType
|
||||
from listing_processor import (
|
||||
_parse_furnish_type,
|
||||
_parse_available_from,
|
||||
ListingProcessor,
|
||||
FetchListingDetailsStep,
|
||||
MAX_OCR_WORKERS,
|
||||
)
|
||||
|
||||
|
||||
class TestParseFurnishType:
    """Mapping of raw furnish strings to FurnishType by _parse_furnish_type."""

    def test_none_returns_unknown(self):
        """None maps to UNKNOWN."""
        parsed = _parse_furnish_type(None)
        assert parsed == FurnishType.UNKNOWN

    def test_ask_landlord_variant(self):
        """'Ask landlord' maps to ASK_LANDLORD."""
        parsed = _parse_furnish_type("Ask landlord")
        assert parsed == FurnishType.ASK_LANDLORD

    def test_furnished_lowercased(self):
        """Capitalised 'Furnished' maps to FURNISHED."""
        parsed = _parse_furnish_type("Furnished")
        assert parsed == FurnishType.FURNISHED

    def test_unfurnished(self):
        """'Unfurnished' maps to UNFURNISHED."""
        parsed = _parse_furnish_type("Unfurnished")
        assert parsed == FurnishType.UNFURNISHED

    def test_part_furnished(self):
        """'Part Furnished' maps to PART_FURNISHED."""
        parsed = _parse_furnish_type("Part Furnished")
        assert parsed == FurnishType.PART_FURNISHED

    def test_unknown_string_returns_unknown(self):
        """The literal string 'unknown' maps to UNKNOWN."""
        parsed = _parse_furnish_type("unknown")
        assert parsed == FurnishType.UNKNOWN

    def test_garbage_string_returns_unknown(self):
        """An unrecognised string maps to UNKNOWN."""
        parsed = _parse_furnish_type("xyzzy")
        assert parsed == FurnishType.UNKNOWN
|
||||
|
||||
|
||||
class TestParseAvailableFrom:
    """Parsing of availability strings by _parse_available_from."""

    def test_none_returns_none(self):
        """None in, None out."""
        parsed = _parse_available_from(None)
        assert parsed is None

    def test_now_returns_datetime(self):
        """'Now' is turned into a datetime instance."""
        assert isinstance(_parse_available_from("Now"), datetime)

    def test_valid_date_string(self):
        """'15/03/2024' parses with day 15 and month 3."""
        parsed = _parse_available_from("15/03/2024")
        assert parsed is not None
        assert (parsed.day, parsed.month) == (15, 3)

    def test_invalid_date_returns_none(self):
        """A string that is not a date gives None."""
        parsed = _parse_available_from("invalid")
        assert parsed is None
|
||||
|
||||
|
||||
class TestListingProcessor:
    """Behavior of ListingProcessor.process_listing."""

    async def test_process_listing_marks_seen(self):
        """process_listing awaits repo.mark_seen with the listing id."""
        repo = AsyncMock()
        repo.get_listings = AsyncMock(return_value=[MagicMock()])
        processor = ListingProcessor(repo)
        # Every step reports that it has nothing to do.
        for pipeline_step in processor.process_steps:
            pipeline_step.needs_processing = AsyncMock(return_value=False)

        await processor.process_listing(123)

        repo.mark_seen.assert_awaited_once_with(123)

    async def test_process_listing_returns_none_on_step_failure(self):
        """When every step's process() raises, process_listing returns None."""
        repo = AsyncMock()
        processor = ListingProcessor(repo)
        # Every step claims work and then fails while doing it.
        for pipeline_step in processor.process_steps:
            pipeline_step.needs_processing = AsyncMock(return_value=True)
            pipeline_step.process = AsyncMock(side_effect=Exception("fail"))

        outcome = await processor.process_listing(123)

        assert outcome is None
|
||||
|
||||
|
||||
class TestOcrWorkersConfig:
    """Checks on the MAX_OCR_WORKERS setting."""

    def test_max_ocr_workers_positive(self):
        """MAX_OCR_WORKERS is at least 1."""
        assert not MAX_OCR_WORKERS < 1
|
||||
295
crawler/tests/unit/test_listing_tasks.py
Normal file
295
crawler/tests/unit/test_listing_tasks.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""Unit tests for tasks/listing_tasks.py."""
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
import tasks.listing_tasks as module
|
||||
from tasks.listing_tasks import (
|
||||
_update_task_state,
|
||||
_PipelineState,
|
||||
TaskLogHandler,
|
||||
SCRAPE_LOCK_NAME,
|
||||
LOG_BUFFER_MAX_LINES,
|
||||
NUM_WORKERS,
|
||||
PHASE_SPLITTING,
|
||||
PHASE_FETCHING,
|
||||
PHASE_PROCESSING,
|
||||
PHASE_COMPLETED,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateTaskState:
    """Behavior of _update_task_state.

    Each test swaps the module-level ``_active_log_buffer`` for the duration
    of the call and puts the previous value back afterwards.
    """

    def test_injects_logs_from_active_buffer(self):
        """Buffered log lines are added to meta under 'logs', next to the given keys."""
        task = MagicMock()
        buffered = deque(["log line 1", "log line 2"])

        with patch.object(module, "_active_log_buffer", buffered):
            _update_task_state(task, "test_state", {"key": "value"})
            task.update_state.assert_called_once()
            meta = task.update_state.call_args[1]["meta"]
            assert meta["logs"] == ["log line 1", "log line 2"]
            assert meta["key"] == "value"

    def test_works_when_buffer_is_none(self):
        """With no active buffer the meta dict is passed on without a 'logs' key."""
        task = MagicMock()

        with patch.object(module, "_active_log_buffer", None):
            _update_task_state(task, "some_state", {"phase": "testing"})
            task.update_state.assert_called_once_with(
                state="some_state", meta={"phase": "testing"}
            )
            # No "logs" key should be injected.
            meta = task.update_state.call_args[1]["meta"]
            assert "logs" not in meta

    def test_state_string_is_passed_through(self):
        """The state argument reaches task.update_state unchanged."""
        task = MagicMock()

        with patch.object(module, "_active_log_buffer", None):
            _update_task_state(task, "PROGRESS", {})
            task.update_state.assert_called_once_with(state="PROGRESS", meta={})

    def test_empty_buffer_injects_empty_list(self):
        """An active but empty buffer yields an empty 'logs' list."""
        task = MagicMock()

        with patch.object(module, "_active_log_buffer", deque()):
            _update_task_state(task, "state", {"a": 1})
            meta = task.update_state.call_args[1]["meta"]
            assert meta["logs"] == []
|
||||
|
||||
|
||||
class TestTaskLogHandler:
    """Tests for the TaskLogHandler."""

    @staticmethod
    def _make_handler(buffer):
        """Return a TaskLogHandler on *buffer* that formats records as the bare message."""
        # Imported locally: logging is not among this module's visible
        # top-level imports.
        import logging

        handler = TaskLogHandler(buffer)
        handler.setFormatter(logging.Formatter("%(message)s"))
        return handler

    @staticmethod
    def _make_record(message):
        """Return an INFO-level LogRecord carrying *message*."""
        import logging

        return logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg=message,
            args=(),
            exc_info=None,
        )

    def test_emit_appends_to_buffer(self):
        """An emitted record's formatted message ends up in the buffer."""
        buf = deque(maxlen=10)
        handler = self._make_handler(buf)

        handler.emit(self._make_record("hello"))

        assert "hello" in buf

    def test_buffer_respects_maxlen(self):
        """Only the newest maxlen messages remain after more are emitted."""
        buf = deque(maxlen=2)
        handler = self._make_handler(buf)

        for i in range(5):
            handler.emit(self._make_record(f"msg{i}"))

        assert len(buf) == 2
        assert list(buf) == ["msg3", "msg4"]
|
||||
|
||||
|
||||
class TestDumpListingsTask:
    """Tests for dump_listings_task Celery task."""

    @staticmethod
    def _lock_context(acquired: bool) -> MagicMock:
        """Return a mock context manager whose __enter__ yields *acquired*."""
        mock_cm = MagicMock()
        mock_cm.__enter__ = MagicMock(return_value=acquired)
        mock_cm.__exit__ = MagicMock(return_value=False)
        return mock_cm

    @patch("tasks.listing_tasks.redis_lock")
    def test_skips_when_lock_not_acquired(self, mock_redis_lock):
        """Task should skip when another scrape is running."""
        mock_redis_lock.return_value = self._lock_context(acquired=False)

        from tasks.listing_tasks import dump_listings_task

        # update_state is mocked only for the duration of the call so the
        # replacement does not leak into other tests via the shared task object.
        with patch.object(dump_listings_task, "update_state", MagicMock()):
            # Use run() which handles bind=True properly
            result = dump_listings_task.run('{"listing_type": "RENT"}')

        assert result["status"] == "skipped"
        assert result["reason"] == "another_job_running"
        mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)

    @patch("tasks.listing_tasks.asyncio.run")
    @patch("tasks.listing_tasks.redis_lock")
    def test_calls_dump_listings_full_when_lock_acquired(
        self, mock_redis_lock, mock_asyncio_run
    ):
        """Task should call dump_listings_full when lock is acquired."""
        mock_redis_lock.return_value = self._lock_context(acquired=True)
        # asyncio.run is mocked, so the async pipeline itself never executes.
        mock_asyncio_run.return_value = []

        from tasks.listing_tasks import dump_listings_task

        params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
        with patch.object(dump_listings_task, "update_state", MagicMock()):
            result = dump_listings_task.run(params_json)

        assert result["phase"] == "completed"
        assert result["progress"] == 1
        mock_asyncio_run.assert_called_once()
        mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
|
||||
|
||||
|
||||
class TestSetupPeriodicTasks:
|
||||
"""Tests for setup_periodic_tasks."""
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_enabled_schedules(self, mock_from_env):
    """One enabled schedule produces one add_periodic_task call carrying its name."""
    from config.schedule_config import ScheduleConfig
    from models.listing import ListingType

    enabled = [
        ScheduleConfig(
            name="Test Schedule",
            listing_type=ListingType.RENT,
            hour="3",
            minute="30",
        )
    ]
    mock_from_env.return_value = MagicMock(
        **{"get_enabled_schedules.return_value": enabled}
    )
    sender = MagicMock()

    module.setup_periodic_tasks(sender)

    sender.add_periodic_task.assert_called_once()
    registered_kwargs = sender.add_periodic_task.call_args[1]
    assert registered_kwargs["name"] == "Test Schedule"
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_handles_config_error_gracefully(self, mock_from_env):
|
||||
mock_from_env.side_effect = ValueError("bad config")
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_not_called()
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_nothing_when_no_schedules(self, mock_from_env):
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = []
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_not_called()
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_multiple_schedules(self, mock_from_env):
|
||||
from config.schedule_config import ScheduleConfig
|
||||
from models.listing import ListingType
|
||||
|
||||
schedules = [
|
||||
ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"),
|
||||
ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"),
|
||||
]
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = schedules
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 2
|
||||
|
||||
|
||||
class TestPipelineState:
    """Behavioural checks for the _PipelineState dataclass."""

    def test_default_initialization(self):
        """A fresh state has zeroed counters, an unset flag and an empty list."""
        state = _PipelineState()
        counters = (
            "ids_collected",
            "completed_subqueries",
            "total_pages_fetched",
            "processed_count",
            "failed_count",
            "details_fetched",
            "images_downloaded",
            "ocr_completed",
        )
        for counter in counters:
            assert getattr(state, counter) == 0
        assert state.fetching_done is False
        assert state.processed_listings == []

    def test_incrementing_counters(self):
        """Each counter can be incremented independently of the others."""
        state = _PipelineState()
        increments = {
            "ids_collected": 5,
            "completed_subqueries": 3,
            "total_pages_fetched": 10,
            "processed_count": 4,
            "failed_count": 1,
            "details_fetched": 4,
            "images_downloaded": 3,
            "ocr_completed": 2,
        }
        for counter, amount in increments.items():
            setattr(state, counter, getattr(state, counter) + amount)

        # Counters start at zero, so each now equals its increment.
        for counter, amount in increments.items():
            assert getattr(state, counter) == amount

    def test_appending_to_processed_listings(self):
        """Appended listings are kept in insertion order."""
        state = _PipelineState()
        for item in ("listing_a", "listing_b"):
            state.processed_listings.append(item)
        assert len(state.processed_listings) == 2
        assert state.processed_listings == ["listing_a", "listing_b"]

    def test_separate_instances_have_independent_lists(self):
        """Mutating one instance's list must not leak into another instance."""
        first, second = _PipelineState(), _PipelineState()
        first.processed_listings.append("only_a")
        assert second.processed_listings == []

    def test_fetching_done_toggle(self):
        """fetching_done starts False and can be switched to True."""
        state = _PipelineState()
        assert state.fetching_done is False
        state.fetching_done = True
        assert state.fetching_done is True
|
||||
|
||||
|
||||
class TestPhaseConstants:
    """Pin the literal values of the pipeline phase names and worker count."""

    def test_phase_splitting(self):
        """PHASE_SPLITTING is the string 'splitting'."""
        expected = "splitting"
        assert PHASE_SPLITTING == expected

    def test_phase_fetching(self):
        """PHASE_FETCHING is the string 'fetching'."""
        expected = "fetching"
        assert PHASE_FETCHING == expected

    def test_phase_processing(self):
        """PHASE_PROCESSING is the string 'processing'."""
        expected = "processing"
        assert PHASE_PROCESSING == expected

    def test_phase_completed(self):
        """PHASE_COMPLETED is the string 'completed'."""
        expected = "completed"
        assert PHASE_COMPLETED == expected

    def test_num_workers(self):
        """NUM_WORKERS is 20."""
        expected = 20
        assert NUM_WORKERS == expected
|
||||
|
|
@ -1,16 +1,24 @@
|
|||
"""Unit tests for Listing models."""
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
import json
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from models.listing import (
|
||||
BuyListing,
|
||||
DestinationMode,
|
||||
FurnishType,
|
||||
ListingSite,
|
||||
ListingType,
|
||||
PriceHistoryItem,
|
||||
QueryParameters,
|
||||
RentListing,
|
||||
Listing,
|
||||
Route,
|
||||
RouteLegStep,
|
||||
)
|
||||
from rec.routing import TravelMode
|
||||
|
||||
|
||||
class TestListing:
|
||||
|
|
@ -341,3 +349,190 @@ class TestBuyListing:
|
|||
lease_left=120,
|
||||
)
|
||||
assert listing.lease_left == 120
|
||||
|
||||
|
||||
def _make_listing_with_routing(routing_info_json: str | None) -> RentListing:
    """Build a RentListing fixture whose only varying field is routing_info_json."""
    fixed_fields = {
        "id": 1,
        "price": 2000.0,
        "number_of_bedrooms": 2,
        "square_meters": 50.0,
        "agency": "Test",
        "council_tax_band": "C",
        "longitude": 0.0,
        "latitude": 0.0,
        "price_history_json": "[]",
        "listing_site": ListingSite.RIGHTMOVE,
        "last_seen": datetime.now(),
        "photo_thumbnail": None,
        "floorplan_image_paths": [],
        "additional_info": {"property": {"visible": True}},
        "furnish_type": FurnishType.FURNISHED,
        "available_from": None,
    }
    return RentListing(routing_info_json=routing_info_json, **fixed_fields)
|
||||
|
||||
|
||||
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
    """Return one transit destination mapped to a single two-leg route."""
    walk_leg = RouteLegStep(
        distance_meters=500,
        duration_s=120,
        travel_mode=TravelMode.WALK,
    )
    transit_leg = RouteLegStep(
        distance_meters=4000,
        duration_s=480,
        travel_mode=TravelMode.TRANSIT,
    )
    route = Route(
        legs=[walk_leg, transit_leg],
        distance_meters=4500,
        duration_s=600,
    )
    destination = DestinationMode(
        destination_address="London Bridge",
        travel_mode=TravelMode.TRANSIT,
    )
    return {destination: [route]}
|
||||
|
||||
|
||||
class TestQueryParametersValidation:
    """Validation rules enforced when constructing QueryParameters."""

    def test_valid_parameters(self) -> None:
        """Consistent ranges are accepted and stored unchanged."""
        accepted = {
            "min_price": 1000,
            "max_price": 3000,
            "min_bedrooms": 1,
            "max_bedrooms": 3,
        }
        params = QueryParameters(listing_type=ListingType.RENT, **accepted)
        assert params.min_price == 1000
        assert params.max_price == 3000
        assert params.min_bedrooms == 1
        assert params.max_bedrooms == 3

    def test_invalid_price_range_raises(self) -> None:
        """A min_price above max_price is rejected with a ValidationError."""
        inverted_prices = {"min_price": 5000, "max_price": 1000}
        with pytest.raises(ValidationError, match="min_price.*must be <= max_price"):
            QueryParameters(listing_type=ListingType.RENT, **inverted_prices)

    def test_invalid_bedroom_range_raises(self) -> None:
        """A min_bedrooms above max_bedrooms is rejected with a ValidationError."""
        inverted_bedrooms = {"min_bedrooms": 5, "max_bedrooms": 2}
        with pytest.raises(ValidationError, match="min_bedrooms.*must be <= max_bedrooms"):
            QueryParameters(listing_type=ListingType.RENT, **inverted_bedrooms)

    def test_negative_bedrooms_raises(self) -> None:
        """A negative min_bedrooms is rejected with a ValidationError."""
        negative_bedrooms = {"min_bedrooms": -1, "max_bedrooms": 3}
        with pytest.raises(ValidationError, match="min_bedrooms.*must be non-negative"):
            QueryParameters(listing_type=ListingType.RENT, **negative_bedrooms)
|
||||
|
||||
|
||||
class TestDestinationMode:
    """Tests for DestinationMode."""

    def test_to_dict(self) -> None:
        """to_dict exposes the address and travel mode under their field names."""
        dm = DestinationMode(
            destination_address="London Bridge",
            travel_mode=TravelMode.TRANSIT,
        )
        result = dm.to_dict()
        assert result == {
            "destination_address": "London Bridge",
            "travel_mode": TravelMode.TRANSIT,
        }

    def test_hash(self) -> None:
        """Equal instances hash alike and are interchangeable as dict keys."""
        dm1 = DestinationMode(
            destination_address="London Bridge",
            travel_mode=TravelMode.TRANSIT,
        )
        dm2 = DestinationMode(
            destination_address="London Bridge",
            travel_mode=TravelMode.TRANSIT,
        )
        dm3 = DestinationMode(
            destination_address="King's Cross",
            travel_mode=TravelMode.TRANSIT,
        )
        assert hash(dm1) == hash(dm2)
        assert dm1 == dm2
        # Unequal objects are allowed to share a hash value, so asserting
        # hash(dm1) != hash(dm3) checks something the language never promises.
        # The guaranteed contract is inequality, and that a different key
        # does not resolve to the same dict entry.
        assert dm1 != dm3
        # Can be used as dict key
        d = {dm1: "route1"}
        assert d[dm2] == "route1"
        assert dm3 not in d
|
||||
|
||||
|
||||
class TestRoutingInfoSerialization:
    """Serialization and parsing of routing info on a listing."""

    def test_routing_info_property_returns_parsed_routes(self) -> None:
        """routing_info parses the stored JSON back into destinations and routes."""
        listing = _make_listing_with_routing(None)
        listing.routing_info_json = listing.serialize_routing_info(
            _make_sample_routing_info()
        )

        parsed = listing.routing_info
        assert len(parsed) == 1

        dest_mode = next(iter(parsed))
        assert dest_mode.destination_address == "London Bridge"
        assert dest_mode.travel_mode == TravelMode.TRANSIT

        routes = parsed[dest_mode]
        assert len(routes) == 1
        only_route = routes[0]
        assert only_route.distance_meters == 4500
        assert only_route.duration_s == 600
        assert len(only_route.legs) == 2
        first_leg = only_route.legs[0]
        assert first_leg.distance_meters == 500
        assert first_leg.travel_mode == TravelMode.WALK

    def test_routing_info_empty_json(self) -> None:
        """A listing without routing JSON exposes an empty mapping."""
        assert _make_listing_with_routing(None).routing_info == {}

    def test_serialize_routing_info_roundtrip(self) -> None:
        """Serializing and then reading routing_info preserves the key fields."""
        original = _make_sample_routing_info()
        listing = _make_listing_with_routing(None)

        payload = listing.serialize_routing_info(original)
        assert isinstance(payload, str)

        listing.routing_info_json = payload
        restored = listing.routing_info

        original_key = next(iter(original))
        restored_key = next(iter(restored))
        assert original_key.destination_address == restored_key.destination_address
        assert original_key.travel_mode == restored_key.travel_mode

        original_route = original[original_key][0]
        restored_route = restored[restored_key][0]
        assert original_route.distance_meters == restored_route.distance_meters
        assert original_route.duration_s == restored_route.duration_s
        assert len(original_route.legs) == len(restored_route.legs)
|
||||
|
|
|
|||
385
crawler/tests/unit/test_query.py
Normal file
385
crawler/tests/unit/test_query.py
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
"""Unit tests for rec/query.py."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import aiohttp
|
||||
|
||||
from rec.query import (
|
||||
detail_query,
|
||||
listing_query,
|
||||
probe_query,
|
||||
PropertyType,
|
||||
create_session,
|
||||
_build_base_params,
|
||||
_build_listing_params,
|
||||
_build_probe_params,
|
||||
ANDROID_APP_VERSION,
|
||||
ANDROID_APP_VERSION_LISTING,
|
||||
RIGHTMOVE_API_BASE,
|
||||
PROPERTY_LISTING_ENDPOINT,
|
||||
DEFAULT_HEADERS,
|
||||
LISTING_HEADERS,
|
||||
check_circuit_breaker,
|
||||
reset_circuit_breaker,
|
||||
get_circuit_breaker,
|
||||
)
|
||||
from models.listing import ListingType, FurnishType
|
||||
from config.scraper_config import ScraperConfig
|
||||
from rec.exceptions import CircuitBreakerOpenError
|
||||
from rec.throttle_detector import reset_throttle_metrics
|
||||
|
||||
|
||||
@pytest.fixture
def config() -> ScraperConfig:
    """Scraper settings with the circuit breaker enabled (threshold 3, recovery timeout 0.5)."""
    settings = {
        "max_concurrent_requests": 5,
        "request_delay_ms": 10,
        "slow_response_threshold": 10.0,
        "enable_circuit_breaker": True,
        "circuit_breaker_failure_threshold": 3,
        "circuit_breaker_recovery_timeout": 0.5,
    }
    return ScraperConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def config_no_cb() -> ScraperConfig:
    """Scraper settings identical to the defaults except the circuit breaker is disabled."""
    without_breaker = ScraperConfig(enable_circuit_breaker=False)
    return without_breaker
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_globals() -> None:
    """Call reset_throttle_metrics and reset_circuit_breaker before every test."""
    for reset in (reset_throttle_metrics, reset_circuit_breaker):
        reset()
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(
|
||||
self,
|
||||
status: int = 200,
|
||||
json_data: dict | None = None,
|
||||
text: str = "",
|
||||
):
|
||||
self.status = status
|
||||
self._json_data = json_data or {}
|
||||
self._text = text
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._json_data
|
||||
|
||||
async def text(self) -> str:
|
||||
return self._text
|
||||
|
||||
async def __aenter__(self) -> "MockResponse":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def make_mock_session(response: MockResponse) -> MagicMock:
    """Build a session double whose ``get`` always returns *response*."""
    return MagicMock(get=MagicMock(return_value=response))
|
||||
|
||||
|
||||
def make_mock_session_fn(get_fn: object) -> MagicMock:
    """Build a session double whose ``get`` delegates every call to *get_fn*."""
    return MagicMock(get=MagicMock(side_effect=get_fn))
|
||||
|
||||
|
||||
class TestBuildBaseParams:
    """Parameter dictionaries produced by the base, listing and probe builders."""

    def test_constructs_correct_params(self) -> None:
        """_build_base_params resolves the district and stringifies every value."""
        with patch("rec.query.districts.get_districts", return_value={"TestDistrict": "REGION^123"}):
            params = _build_base_params(
                channel=ListingType.RENT,
                page=2,
                page_size=25,
                radius=1.5,
                min_price=1000,
                max_price=3000,
                min_bedrooms=1,
                max_bedrooms=3,
                district="TestDistrict",
            )

        expected = {
            "locationIdentifier": "REGION^123",
            "channel": "RENT",
            "page": "2",
            "numberOfPropertiesPerPage": "25",
            "radius": "1.5",
            "sortBy": "distance",
            "includeUnavailableProperties": "false",
            "minPrice": "1000",
            "maxPrice": "3000",
            "minBedrooms": "1",
            "maxBedrooms": "3",
            "apiApplication": "ANDROID",
            "appVersion": ANDROID_APP_VERSION_LISTING,
        }
        for key, value in expected.items():
            assert params[key] == value

    def test_buy_channel_includes_dont_show_and_max_days(self) -> None:
        """BUY listing params carry dontShow and maxDaysSinceAdded."""
        buy_kwargs = dict(
            page=1,
            channel=ListingType.BUY,
            min_bedrooms=1,
            max_bedrooms=2,
            radius=1.0,
            min_price=100000,
            max_price=500000,
            district="D",
            mustNewHome=False,
            max_days_since_added=7,
            property_type=[],
            page_size=25,
            furnish_types=[],
        )
        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            params = _build_listing_params(**buy_kwargs)

        assert params["dontShow"] == "sharedOwnership,retirement"
        assert params["maxDaysSinceAdded"] == "7"

    def test_rent_channel_includes_furnish_types(self) -> None:
        """RENT listing params carry furnishTypes and omit the BUY-only keys."""
        rent_kwargs = dict(
            page=1,
            channel=ListingType.RENT,
            min_bedrooms=1,
            max_bedrooms=2,
            radius=1.0,
            min_price=1000,
            max_price=3000,
            district="D",
            mustNewHome=False,
            max_days_since_added=30,
            property_type=[],
            page_size=25,
            furnish_types=[FurnishType.FURNISHED, FurnishType.UNFURNISHED],
        )
        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            params = _build_listing_params(**rent_kwargs)

        assert params["furnishTypes"] == "furnished,unfurnished"
        for buy_only_key in ("dontShow", "maxDaysSinceAdded"):
            assert buy_only_key not in params

    def test_buy_channel_probe_includes_dont_show(self) -> None:
        """BUY probe params carry dontShow, maxDaysSinceAdded and a page size of 1."""
        probe_kwargs = dict(
            channel=ListingType.BUY,
            min_bedrooms=1,
            max_bedrooms=2,
            radius=1.0,
            min_price=100000,
            max_price=500000,
            district="D",
            max_days_since_added=7,
            furnish_types=[],
        )
        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            params = _build_probe_params(**probe_kwargs)

        assert params["dontShow"] == "sharedOwnership,retirement"
        assert params["maxDaysSinceAdded"] == "7"
        assert params["numberOfPropertiesPerPage"] == "1"

    def test_probe_buy_skips_max_days_if_not_valid(self) -> None:
        """A max_days_since_added of 30 is left out of BUY probe params."""
        probe_kwargs = dict(
            channel=ListingType.BUY,
            min_bedrooms=1,
            max_bedrooms=2,
            radius=1.0,
            min_price=100000,
            max_price=500000,
            district="D",
            max_days_since_added=30,
            furnish_types=[],
        )
        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            params = _build_probe_params(**probe_kwargs)

        # 30 is not in [1, 3, 7, 14], so maxDaysSinceAdded is not added for probe
        assert "maxDaysSinceAdded" not in params
|
||||
|
||||
|
||||
class TestMutableDefaultArgFix:
    """Regression tests for mutable default arguments on the query helpers."""

    @staticmethod
    def _assert_default_not_mutable(func: object, parameter_name: str) -> None:
        """Fail when *parameter_name* of *func* declares a list, dict or set default."""
        import inspect

        default = inspect.signature(func).parameters[parameter_name].default
        assert not isinstance(default, (list, dict, set)), (
            f"{parameter_name} default is a shared mutable {type(default).__name__}"
        )

    @pytest.mark.asyncio
    async def test_property_type_default_not_shared(self, config: ScraperConfig) -> None:
        """Calling listing_query with no property_type should not share state between calls."""
        response = MockResponse(
            status=200,
            json_data={"totalAvailableResults": 0, "properties": []},
        )
        mock_session = make_mock_session(response)
        query_kwargs = dict(
            page=1,
            channel=ListingType.RENT,
            min_bedrooms=1,
            max_bedrooms=2,
            radius=1.0,
            min_price=1000,
            max_price=2000,
            district="D",
            session=mock_session,
            config=config,
        )

        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            # Call twice without explicit property_type; both must complete.
            await listing_query(**query_kwargs)
            await listing_query(**query_kwargs)

        # Two clean calls alone cannot reveal a shared default, so check
        # the declared default directly.
        self._assert_default_not_mutable(listing_query, "property_type")

    @pytest.mark.asyncio
    async def test_furnish_types_default_not_shared(self, config: ScraperConfig) -> None:
        """Calling probe_query with no furnish_types should not share state between calls."""
        response = MockResponse(
            status=200,
            json_data={"totalAvailableResults": 0, "properties": []},
        )
        mock_session = make_mock_session(response)
        query_kwargs = dict(
            session=mock_session,
            channel=ListingType.RENT,
            min_bedrooms=1,
            max_bedrooms=2,
            radius=1.0,
            min_price=1000,
            max_price=2000,
            district="D",
            config=config,
        )

        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            # Call twice without explicit furnish_types; both must complete.
            await probe_query(**query_kwargs)
            await probe_query(**query_kwargs)

        # Same reasoning as above: verify the declared default itself.
        self._assert_default_not_mutable(probe_query, "furnish_types")
|
||||
|
||||
|
||||
class TestPropertyTypeEnum:
    """String values and str behaviour of the PropertyType enum."""

    def test_enum_values(self) -> None:
        """Every member compares equal to its expected string."""
        expected_values = [
            (PropertyType.BUNGALOW, "bungalow"),
            (PropertyType.DETACHED, "detached"),
            (PropertyType.FLAT, "flat"),
            (PropertyType.LAND, "land"),
            (PropertyType.PARK_HOME, "park-home"),
            (PropertyType.SEMI_DETACHED, "semi-detached"),
            (PropertyType.TERRACED, "terraced"),
        ]
        for member, value in expected_values:
            assert member == value

    def test_enum_is_str(self) -> None:
        """Members are str instances and can be joined directly."""
        assert isinstance(PropertyType.FLAT, str)
        joined = ",".join([PropertyType.FLAT, PropertyType.DETACHED])
        assert joined == "flat,detached"
|
||||
|
||||
|
||||
class TestDetailQuery:
    """detail_query behaviour for successful and failed responses."""

    @pytest.mark.asyncio
    async def test_success_200(self, config: ScraperConfig) -> None:
        """A 200 response yields the decoded JSON body."""
        body = {"id": 12345, "address": "123 Test St"}
        session = make_mock_session(MockResponse(status=200, json_data=body))

        fetched = await detail_query(12345, session=session, config=config)

        assert fetched == body

    @pytest.mark.asyncio
    async def test_raises_on_non_200(self, config: ScraperConfig) -> None:
        """A 404 response raises an exception whose message matches 'Failed due to'."""
        session = make_mock_session(MockResponse(status=404, text="Not Found"))

        with pytest.raises(Exception, match="Failed due to"):
            await detail_query(99999, session=session, config=config)
|
||||
|
||||
|
||||
class TestCircuitBreakerBlocksRequests:
    """Each query helper must raise CircuitBreakerOpenError while the breaker is open."""

    @staticmethod
    def _open_breaker(config: ScraperConfig) -> None:
        """Record failures up to the configured threshold and verify the breaker opened."""
        cb = get_circuit_breaker(config)
        assert cb is not None
        for _ in range(config.circuit_breaker_failure_threshold):
            cb.record_failure()
        # Checked for every test (originally only the detail_query one), so a
        # breaker that never opened fails here rather than in pytest.raises.
        assert cb.is_open

    @pytest.mark.asyncio
    async def test_circuit_breaker_blocks_when_open(self, config: ScraperConfig) -> None:
        """detail_query is rejected while the breaker is open."""
        self._open_breaker(config)
        mock_session = MagicMock()

        with pytest.raises(CircuitBreakerOpenError):
            await detail_query(1, session=mock_session, config=config)

    @pytest.mark.asyncio
    async def test_circuit_breaker_blocks_listing_query(self, config: ScraperConfig) -> None:
        """listing_query is rejected while the breaker is open."""
        self._open_breaker(config)
        mock_session = MagicMock()

        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            with pytest.raises(CircuitBreakerOpenError):
                await listing_query(
                    page=1,
                    channel=ListingType.RENT,
                    min_bedrooms=1,
                    max_bedrooms=2,
                    radius=1.0,
                    min_price=1000,
                    max_price=2000,
                    district="D",
                    session=mock_session,
                    config=config,
                )

    @pytest.mark.asyncio
    async def test_circuit_breaker_blocks_probe_query(self, config: ScraperConfig) -> None:
        """probe_query is rejected while the breaker is open."""
        self._open_breaker(config)
        mock_session = MagicMock()

        with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
            with pytest.raises(CircuitBreakerOpenError):
                await probe_query(
                    session=mock_session,
                    channel=ListingType.RENT,
                    min_bedrooms=1,
                    max_bedrooms=2,
                    radius=1.0,
                    min_price=1000,
                    max_price=2000,
                    district="D",
                    config=config,
                )
|
||||
|
||||
|
||||
class TestConstants:
    """Pin the literal values of the rec.query module constants."""

    def test_android_app_version(self) -> None:
        """ANDROID_APP_VERSION is '3.70.0'."""
        expected = "3.70.0"
        assert ANDROID_APP_VERSION == expected

    def test_android_app_version_listing(self) -> None:
        """ANDROID_APP_VERSION_LISTING is '4.28.0'."""
        expected = "4.28.0"
        assert ANDROID_APP_VERSION_LISTING == expected

    def test_rightmove_api_base(self) -> None:
        """RIGHTMOVE_API_BASE is the API root URL."""
        expected = "https://api.rightmove.co.uk/api"
        assert RIGHTMOVE_API_BASE == expected

    def test_property_listing_endpoint(self) -> None:
        """PROPERTY_LISTING_ENDPOINT is the property-listing URL."""
        expected = "https://api.rightmove.co.uk/api/property-listing"
        assert PROPERTY_LISTING_ENDPOINT == expected

    def test_listing_headers_extends_default(self) -> None:
        """LISTING_HEADERS repeats every default header and sets Accept-Encoding."""
        for header_name in DEFAULT_HEADERS:
            assert LISTING_HEADERS[header_name] == DEFAULT_HEADERS[header_name]
        assert LISTING_HEADERS["Accept-Encoding"] == "gzip, deflate, br"
|
||||
|
|
@ -161,7 +161,7 @@ class TestQuerySplitter:
|
|||
mock_session = AsyncMock()
|
||||
|
||||
# Mock the probe_query function
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.return_value = {"totalAvailableResults": 800}
|
||||
|
||||
count = await splitter.probe_result_count(sq, mock_session, parameters)
|
||||
|
|
@ -184,7 +184,7 @@ class TestQuerySplitter:
|
|||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.side_effect = Exception("API error")
|
||||
|
||||
count = await splitter.probe_result_count(sq, mock_session, parameters)
|
||||
|
|
@ -208,7 +208,7 @@ class TestQuerySplitter:
|
|||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# First half has 600 results, second half has 500
|
||||
mock_probe.side_effect = [
|
||||
{"totalAvailableResults": 600},
|
||||
|
|
@ -240,7 +240,7 @@ class TestQuerySplitter:
|
|||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# First split: 1000-3000 has 1300 (over threshold), 3000-5000 has 800
|
||||
# Second split of 1000-3000: 1000-2000 has 700, 2000-3000 has 600
|
||||
mock_probe.side_effect = [
|
||||
|
|
@ -326,7 +326,7 @@ class TestQuerySplitter:
|
|||
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
|
||||
with patch("services.query_splitter.get_districts", return_value=mock_districts):
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# Mock probe results for each initial subquery
|
||||
# 2 districts × 2 bedroom counts = 4 initial subqueries
|
||||
mock_probe.side_effect = [
|
||||
|
|
@ -358,11 +358,11 @@ class TestQuerySplitter:
|
|||
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
progress_calls = []
|
||||
|
||||
def on_progress(phase: str, message: str) -> None:
|
||||
def on_progress(phase: str, message: str, **kwargs: object) -> None:
|
||||
progress_calls.append((phase, message))
|
||||
|
||||
with patch("services.query_splitter.get_districts", return_value=mock_districts):
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.return_value = {"totalAvailableResults": 500}
|
||||
|
||||
await splitter.split(parameters, mock_session, on_progress)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Unit tests for ListingRepository."""
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
|
|
@ -225,3 +226,156 @@ class TestListingRepositoryFilters:
|
|||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
# Should match listings with 1-2 bedrooms in price range
|
||||
assert len(listings) == 2
|
||||
|
||||
|
||||
class TestListingRepositoryStreaming:
    """Counting, optimized streaming and ID lookup on ListingRepository."""

    async def test_count_listings_empty_db(
        self, listing_repository: ListingRepository
    ) -> None:
        """With nothing stored, count_listings reports zero."""
        assert listing_repository.count_listings() == 0

    async def test_count_listings_with_data(
        self,
        listing_repository: ListingRepository,
        sample_rent_listings: list[RentListing],
    ) -> None:
        """After upserting the sample listings, the count is three."""
        await listing_repository.upsert_listings(sample_rent_listings)
        assert listing_repository.count_listings() == 3

    async def test_count_listings_with_filters(
        self,
        listing_repository: ListingRepository,
        sample_rent_listings: list[RentListing],
    ) -> None:
        """A 2-3 bedroom RENT filter counts two of the sample listings."""
        await listing_repository.upsert_listings(sample_rent_listings)

        bedroom_filter = QueryParameters(
            listing_type=ListingType.RENT,
            min_bedrooms=2,
            max_bedrooms=3,
        )
        matching = listing_repository.count_listings(query_parameters=bedroom_filter)
        assert matching == 2

    async def test_stream_listings_optimized_returns_dicts(
        self,
        listing_repository: ListingRepository,
        sample_rent_listings: list[RentListing],
    ) -> None:
        """Optimized streaming yields one dict per listing with the core keys."""
        await listing_repository.upsert_listings(sample_rent_listings)

        rows = list(listing_repository.stream_listings_optimized())
        assert len(rows) == 3
        # Each result should be a dict
        for row in rows:
            assert isinstance(row, dict)
            for required_key in ("id", "price", "number_of_bedrooms"):
                assert required_key in row

    async def test_stream_listings_optimized_respects_limit(
        self,
        listing_repository: ListingRepository,
        sample_rent_listings: list[RentListing],
    ) -> None:
        """limit=2 caps the stream at two rows."""
        await listing_repository.upsert_listings(sample_rent_listings)

        limited_rows = list(listing_repository.stream_listings_optimized(limit=2))
        assert len(limited_rows) == 2

    async def test_get_listing_ids(
        self,
        listing_repository: ListingRepository,
        sample_rent_listings: list[RentListing],
    ) -> None:
        """get_listing_ids returns the stored IDs as a set."""
        await listing_repository.upsert_listings(sample_rent_listings)

        stored_ids = listing_repository.get_listing_ids()
        assert isinstance(stored_ids, set)
        assert stored_ids == {1, 2, 3}

    async def test_get_listing_ids_empty_db(
        self,
        listing_repository: ListingRepository,
    ) -> None:
        """With nothing stored, get_listing_ids returns an empty set."""
        stored_ids = listing_repository.get_listing_ids()
        assert isinstance(stored_ids, set)
        assert stored_ids == set()
|
||||
|
||||
|
||||
class TestFurnishTypeParsing:
    """Tests for _parse_furnish_type helper."""

    @staticmethod
    def _parse(detailobject: dict | None) -> "FurnishType":
        """Parse the furnish type of a mock listing carrying *detailobject*."""
        stub_listing = MagicMock()
        stub_listing.detailobject = detailobject
        return ListingRepository._parse_furnish_type(stub_listing)

    @classmethod
    def _parse_raw(cls, raw_value: str | None) -> "FurnishType":
        """Parse a listing whose ``property.letFurnishType`` is *raw_value*."""
        return cls._parse({"property": {"letFurnishType": raw_value}})

    def test_parse_furnish_type_none_detailobject(self) -> None:
        """Test that None detailobject returns UNKNOWN."""
        assert self._parse(None) == FurnishType.UNKNOWN

    def test_parse_furnish_type_missing_property_key(self) -> None:
        """Test that missing 'property' key returns UNKNOWN."""
        assert self._parse({}) == FurnishType.UNKNOWN

    def test_parse_furnish_type_missing_let_furnish_type(self) -> None:
        """Test that missing 'letFurnishType' key returns UNKNOWN."""
        assert self._parse({"property": {}}) == FurnishType.UNKNOWN

    def test_parse_furnish_type_null_value(self) -> None:
        """Test that null letFurnishType value returns UNKNOWN."""
        assert self._parse_raw(None) == FurnishType.UNKNOWN

    def test_parse_furnish_type_furnished(self) -> None:
        """Test that 'Furnished' is parsed correctly."""
        assert self._parse_raw("Furnished") == FurnishType.FURNISHED

    def test_parse_furnish_type_unfurnished(self) -> None:
        """Test that 'Unfurnished' is parsed correctly."""
        assert self._parse_raw("Unfurnished") == FurnishType.UNFURNISHED

    def test_parse_furnish_type_part_furnished(self) -> None:
        """Test that 'Part Furnished' is parsed correctly."""
        assert self._parse_raw("Part Furnished") == FurnishType.PART_FURNISHED

    def test_parse_furnish_type_landlord_variant(self) -> None:
        """Test that landlord variants map to ASK_LANDLORD."""
        assert self._parse_raw("Ask Landlord Please") == FurnishType.ASK_LANDLORD

    def test_parse_furnish_type_landlord_case_insensitive(self) -> None:
        """Test that landlord check is case-insensitive."""
        assert self._parse_raw("LANDLORD decides") == FurnishType.ASK_LANDLORD
|
||||
|
|
|
|||
10
crawler/tests/unit/test_route_calculator.py
Normal file
10
crawler/tests/unit/test_route_calculator.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Unit tests for services/route_calculator.py."""
|
||||
from services.route_calculator import _parse_duration
|
||||
|
||||
|
||||
class TestParseDuration:
    """Tests for _parse_duration on strings of the form '<seconds>s'."""

    def test_parse_normal_duration(self) -> None:
        """'123s' parses to 123."""
        parsed = _parse_duration("123s")
        assert parsed == 123

    def test_parse_zero_duration(self) -> None:
        """'0s' parses to 0."""
        parsed = _parse_duration("0s")
        assert parsed == 0
|
||||
72
crawler/tests/unit/test_route_serializer.py
Normal file
72
crawler/tests/unit/test_route_serializer.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Unit tests for rec/route_serializer.py."""
|
||||
from models.listing import DestinationMode, Route, RouteLegStep
|
||||
from rec.route_serializer import RouteSerializer
|
||||
from rec.routing import TravelMode
|
||||
|
||||
|
||||
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
    """Build a one-entry routing map for a transit trip to London Bridge.

    The single route has a 500 m / 120 s walking leg followed by a
    4000 m / 480 s transit leg, with route totals of 4500 m and 600 s.
    """
    walk_leg = RouteLegStep(
        distance_meters=500,
        duration_s=120,
        travel_mode=TravelMode.WALK,
    )
    transit_leg = RouteLegStep(
        distance_meters=4000,
        duration_s=480,
        travel_mode=TravelMode.TRANSIT,
    )
    sample_route = Route(
        legs=[walk_leg, transit_leg],
        distance_meters=4500,
        duration_s=600,
    )
    london_bridge_by_transit = DestinationMode(
        destination_address="London Bridge",
        travel_mode=TravelMode.TRANSIT,
    )
    return {london_bridge_by_transit: [sample_route]}
|
||||
|
||||
|
||||
class TestRouteSerializer:
    """Tests for RouteSerializer.serialize / RouteSerializer.deserialize."""

    def test_serialize_then_deserialize_roundtrip(self) -> None:
        """Serializing and then deserializing reproduces the sample routing info."""
        original = _make_sample_routing_info()
        restored = RouteSerializer.deserialize(RouteSerializer.serialize(original))

        assert len(restored) == 1
        dest_mode = next(iter(restored))
        assert dest_mode.destination_address == "London Bridge"
        assert dest_mode.travel_mode == TravelMode.TRANSIT

        restored_routes = restored[dest_mode]
        assert len(restored_routes) == 1
        route = restored_routes[0]
        assert route.distance_meters == 4500
        assert route.duration_s == 600
        assert len(route.legs) == 2
        first_leg, second_leg = route.legs[0], route.legs[1]
        assert first_leg.distance_meters == 500
        assert first_leg.travel_mode == TravelMode.WALK
        assert second_leg.travel_mode == TravelMode.TRANSIT

    def test_deserialize_sample_json(self) -> None:
        """deserialize accepts JSON assembled by hand, without going through serialize."""
        import dataclasses
        import json

        routing_info = _make_sample_routing_info()

        # Build the JSON manually to test deserialize independently:
        # each key is a JSON-encoded destination, each value a list of
        # JSON-encoded routes.
        payload = {}
        for dm, routes in routing_info.items():
            encoded_key = json.dumps(dataclasses.asdict(dm))
            payload[encoded_key] = [json.dumps(dataclasses.asdict(r)) for r in routes]
        json_str = json.dumps(payload)

        result = RouteSerializer.deserialize(json_str)

        assert len(result) == 1
        dest_mode = next(iter(result))
        assert dest_mode.destination_address == "London Bridge"
        assert result[dest_mode][0].duration_s == 600
|
||||
67
crawler/tests/unit/test_routing.py
Normal file
67
crawler/tests/unit/test_routing.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""Unit tests for rec/routing.py."""
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from rec.routing import TravelMode, transit_route, ROUTES_API_URL, ROUTES_FIELD_MASK
|
||||
from rec.exceptions import RoutingApiError
|
||||
|
||||
|
||||
class TestTravelMode:
    """Tests for the TravelMode enum."""

    def test_enum_values(self) -> None:
        """Each member compares equal to its upper-case name string."""
        expected_pairs = (
            (TravelMode.TRANSIT, "TRANSIT"),
            (TravelMode.BICYCLE, "BICYCLE"),
            (TravelMode.WALK, "WALK"),
            (TravelMode.DRIVE, "DRIVE"),
        )
        for member, text in expected_pairs:
            assert member == text

    def test_enum_has_four_members(self) -> None:
        """The enum defines exactly four members."""
        member_count = len(TravelMode)
        assert member_count == 4
|
||||
|
||||
|
||||
class TestTransitRoute:
    """Tests for transit_route with the HTTP call and departure time mocked."""

    @staticmethod
    def _stub_next_monday(mock_monday: MagicMock) -> None:
        """Make the patched nextMonday() return an object with a fixed strftime result."""
        mock_monday.return_value = MagicMock(
            strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
        )

    @staticmethod
    def _stub_post(mock_post: MagicMock, status_code: int, body: dict) -> None:
        """Make the patched requests.post return a response with the given status and JSON body."""
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = body
        mock_post.return_value = response

    @patch("rec.routing.requests.post")
    @patch("rec.routing.nextMonday")
    def test_success_response(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
        """A 200 response is returned as parsed JSON and the API key is sent in the headers."""
        self._stub_next_monday(mock_monday)
        expected = {"routes": [{"duration": "600s", "distanceMeters": 5000}]}
        self._stub_post(mock_post, 200, expected)

        with patch.dict(os.environ, {"ROUTING_API_KEY": "test-key"}):
            result = transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)

        assert result == expected
        mock_post.assert_called_once()
        post_call = mock_post.call_args
        assert post_call.kwargs["headers"]["X-Goog-Api-Key"] == "test-key"

    @patch("rec.routing.requests.post")
    @patch("rec.routing.nextMonday")
    def test_raises_routing_api_error_on_non_200(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
        """A non-200 response raises RoutingApiError carrying the status code and body."""
        self._stub_next_monday(mock_monday)
        error_body = {"error": {"message": "Invalid API key", "status": "PERMISSION_DENIED"}}
        self._stub_post(mock_post, 403, error_body)

        with patch.dict(os.environ, {"ROUTING_API_KEY": "bad-key"}):
            with pytest.raises(RoutingApiError) as exc_info:
                transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)

        assert exc_info.value.status_code == 403
        assert exc_info.value.response_body == error_body

    def test_raises_key_error_when_api_key_not_set(self) -> None:
        """Without ROUTING_API_KEY in the environment, transit_route raises KeyError.

        requests.post is patched so that a regression can never send a real
        HTTP request from this unit test; the test also checks that no
        request was attempted.
        """
        # patch.dict snapshots os.environ and restores it on exit, so the
        # key can simply be removed inside the block.
        with patch.dict(os.environ), patch("rec.routing.requests.post") as mock_post:
            os.environ.pop("ROUTING_API_KEY", None)
            with pytest.raises(KeyError):
                transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)

        mock_post.assert_not_called()
|
||||
306
crawler/tests/unit/test_task_service.py
Normal file
306
crawler/tests/unit/test_task_service.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
"""Unit tests for services/task_service.py."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from services.task_service import (
|
||||
TaskStatus,
|
||||
_extract_progress_info,
|
||||
_extract_result,
|
||||
_make_system_user,
|
||||
get_task_status,
|
||||
)
|
||||
|
||||
|
||||
class TestMakeSystemUser:
    """Tests for _make_system_user helper."""

    def test_creates_user_with_email(self) -> None:
        """The user carries the given email and empty ``sub`` / ``name`` fields."""
        email = "test@example.com"

        system_user = _make_system_user(email)

        assert system_user.email == "test@example.com"
        assert system_user.sub == ""
        assert system_user.name == ""

    def test_different_emails_create_different_users(self) -> None:
        """Users built from different emails report different emails."""
        first, second = _make_system_user("a@b.com"), _make_system_user("c@d.com")
        assert first.email != second.email
|
||||
|
||||
|
||||
class TestExtractResult:
    """Tests for _extract_result helper."""

    @staticmethod
    def _async_result(*, failed: bool, result: object) -> MagicMock:
        """Build a mock task result with the given failure flag and payload."""
        stub = MagicMock()
        stub.failed.return_value = failed
        stub.result = result
        return stub

    def test_failed_task_returns_error(self) -> None:
        """A failed task yields no result and an error containing the exception text."""
        stub = self._async_result(failed=True, result=Exception("something broke"))

        value, error = _extract_result(stub)

        assert value is None
        assert error is not None
        assert "something broke" in error

    def test_failed_task_with_no_result(self) -> None:
        """A failed task without a stored result yields neither a result nor an error."""
        stub = self._async_result(failed=True, result=None)

        value, error = _extract_result(stub)

        assert value is None
        assert error is None

    def test_successful_json_serializable_result(self) -> None:
        """A JSON-serializable result is passed through unchanged."""
        payload = {"count": 42, "status": "done"}
        stub = self._async_result(failed=False, result=payload)

        value, error = _extract_result(stub)

        assert value == {"count": 42, "status": "done"}
        assert error is None

    def test_non_serializable_result_falls_back_to_str(self) -> None:
        """A result that is not JSON-serializable comes back as a string."""
        stub = self._async_result(failed=False, result=object())  # not JSON-serializable

        value, error = _extract_result(stub)

        assert isinstance(value, str)
        assert error is None

    def test_none_result_stays_none(self) -> None:
        """A successful task with a None result yields None for both values."""
        stub = self._async_result(failed=False, result=None)

        value, error = _extract_result(stub)

        assert value is None
        assert error is None
|
||||
|
||||
|
||||
class TestExtractProgressInfo:
    """Tests for _extract_progress_info helper."""

    @staticmethod
    def _progress_of(info: dict | None, status: str) -> dict:
        """Run _extract_progress_info on a mock result with the given info and status."""
        stub = MagicMock()
        stub.info = info
        stub.status = status
        return _extract_progress_info(stub)

    def test_extracts_progress_fields(self) -> None:
        """progress / processed / total are copied from info; message is None."""
        extracted = self._progress_of(
            {"progress": 0.5, "processed": 50, "total": 100}, "STARTED"
        )

        assert extracted["progress"] == 0.5
        assert extracted["processed"] == 50
        assert extracted["total"] == 100
        assert extracted["message"] is None

    def test_extracts_message_from_info(self) -> None:
        """A 'message' entry in info becomes the message."""
        extracted = self._progress_of({"message": "Processing page 3"}, "STARTED")
        assert extracted["message"] == "Processing page 3"

    def test_falls_back_to_reason_for_skipped(self) -> None:
        """With status SKIPPED and only a 'reason' in info, the reason is the message."""
        extracted = self._progress_of({"reason": "Already running"}, "SKIPPED")
        assert extracted["message"] == "Already running"

    def test_custom_state_used_as_message(self) -> None:
        """With empty info, the custom status string itself is the message."""
        extracted = self._progress_of({}, "Fetching listings")
        assert extracted["message"] == "Fetching listings"

    def test_standard_state_not_used_as_message(self) -> None:
        """With empty info, the status PENDING does not become the message."""
        extracted = self._progress_of({}, "PENDING")
        assert extracted["message"] is None

    def test_none_info_returns_all_none(self) -> None:
        """With info None and status PENDING, every extracted field is None."""
        extracted = self._progress_of(None, "PENDING")

        all_none = {"progress": None, "processed": None, "total": None, "message": None}
        assert extracted == all_none
|
||||
|
||||
|
||||
class TestGetTaskStatus:
    """Tests for get_task_status."""

    @staticmethod
    def _async_result(
        status: str,
        *,
        failed: bool = False,
        result: object = None,
        info: dict | None = None,
        traceback: str | None = None,
    ) -> MagicMock:
        """Build a mock task result exposing the attributes these tests set."""
        stub = MagicMock()
        stub.status = status
        stub.failed.return_value = failed
        stub.result = result
        stub.info = info
        stub.traceback = traceback
        return stub

    @staticmethod
    def _status_for(async_result: MagicMock):
        """Call get_task_status("test-id") with dump_listings_task mocked.

        The task is patched both as an attribute of services.task_service
        and through a stub ``tasks.listing_tasks`` entry in sys.modules
        (the original test comments call the latter "the lazy import").
        """
        with patch("services.task_service.dump_listings_task", create=True) as task_stub:
            task_stub.AsyncResult.return_value = async_result
            # Patch the lazy import
            fake_module = MagicMock(dump_listings_task=task_stub)
            with patch.dict("sys.modules", {"tasks.listing_tasks": fake_module}):
                return get_task_status("test-id")

    def test_pending_task(self) -> None:
        """Test status for a pending task."""
        status = self._status_for(self._async_result("PENDING"))

        assert status.task_id == "test-id"
        assert status.status == "PENDING"
        assert status.error is None

    def test_failed_task(self) -> None:
        """Test status for a failed task."""
        failing = self._async_result(
            "FAILURE",
            failed=True,
            result=Exception("something broke"),
            traceback="Traceback...",
        )

        status = self._status_for(failing)

        assert status.status == "FAILURE"
        assert status.error is not None
        assert status.traceback == "Traceback..."

    def test_custom_state_with_progress(self) -> None:
        """Test that custom states with progress info are extracted correctly."""
        in_progress = self._async_result(
            "Fetching listings",
            info={"progress": 0.5, "processed": 50, "total": 100},
        )

        status = self._status_for(in_progress)

        assert status.progress == 0.5
        assert status.processed == 50
        assert status.total == 100

    def test_successful_task(self) -> None:
        """Test status for a successful task."""
        finished = self._async_result("SUCCESS", result={"listings_count": 42})

        status = self._status_for(finished)

        assert status.status == "SUCCESS"
        assert status.result == {"listings_count": 42}
        assert status.error is None
|
||||
|
||||
|
||||
class TestGetUserTasks:
    """Tests for get_user_tasks."""

    @staticmethod
    def _tasks_for(email: str, stored_tasks: list[str]):
        """Call get_user_tasks(email) against a mock repository holding *stored_tasks*.

        RedisRepository is patched on services.task_service and through a
        stub ``redis_repository`` entry in sys.modules; its ``instance()``
        returns a mock whose ``get_tasks_for_user`` returns *stored_tasks*.
        """
        redis_stub = MagicMock()
        redis_stub.get_tasks_for_user.return_value = stored_tasks

        with patch("services.task_service.RedisRepository", create=True) as repo_cls:
            repo_cls.instance.return_value = redis_stub
            fake_module = MagicMock(RedisRepository=repo_cls)
            with patch.dict("sys.modules", {"redis_repository": fake_module}):
                from services.task_service import get_user_tasks

                return get_user_tasks(email)

    def test_returns_task_list(self) -> None:
        """The task ids stored for the user are returned."""
        found = self._tasks_for("test@example.com", ["task-1", "task-2"])
        assert found == ["task-1", "task-2"]

    def test_returns_empty_list_for_unknown_user(self) -> None:
        """A user with no stored tasks gets an empty list."""
        found = self._tasks_for("nobody@example.com", [])
        assert found == []
|
||||
|
||||
|
||||
class TestCancelTask:
    """Tests for cancel_task."""

    def test_cancel_revokes_and_removes(self) -> None:
        """With a user email, cancel_task returns True and revokes with terminate=True."""
        celery_stub = MagicMock()
        redis_stub = MagicMock()
        redis_stub.remove_task_for_user.return_value = True
        redis_module = MagicMock(
            RedisRepository=MagicMock(instance=MagicMock(return_value=redis_stub))
        )
        fake_modules = {
            "celery_app": MagicMock(app=celery_stub),
            "redis_repository": redis_module,
        }

        with patch.dict("sys.modules", fake_modules):
            from services.task_service import cancel_task

            outcome = cancel_task("task-123", user_email="test@example.com")

        assert outcome is True
        celery_stub.control.revoke.assert_called_once_with("task-123", terminate=True)

    def test_cancel_without_user_email(self) -> None:
        """Without a user email, cancel_task still returns True and revokes the task."""
        celery_stub = MagicMock()
        fake_modules = {"celery_app": MagicMock(app=celery_stub)}

        with patch.dict("sys.modules", fake_modules):
            from services.task_service import cancel_task

            outcome = cancel_task("task-456")

        assert outcome is True
        celery_stub.control.revoke.assert_called_once_with("task-456", terminate=True)
|
||||
|
||||
|
||||
class TestClearAllTasks:
    """Tests for clear_all_tasks."""

    @staticmethod
    def _patched_backends(mock_celery: MagicMock, mock_redis: MagicMock):
        """Return a patch.dict context that stubs celery_app and redis_repository.

        Inside the context, ``celery_app.app`` is *mock_celery* and
        ``redis_repository.RedisRepository.instance()`` returns *mock_redis*.
        """
        redis_module = MagicMock(
            RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))
        )
        return patch.dict("sys.modules", {
            "celery_app": MagicMock(app=mock_celery),
            "redis_repository": redis_module,
        })

    def test_clear_with_revoke(self) -> None:
        """With revoke=True, both stored tasks are revoked and the cleared count is returned."""
        mock_celery = MagicMock()
        mock_redis = MagicMock()
        mock_redis.get_tasks_for_user.return_value = ["t1", "t2"]
        mock_redis.clear_tasks_for_user.return_value = 2

        with self._patched_backends(mock_celery, mock_redis):
            from services.task_service import clear_all_tasks

            count = clear_all_tasks("test@example.com", revoke=True)

        assert count == 2
        assert mock_celery.control.revoke.call_count == 2

    def test_clear_without_revoke(self) -> None:
        """With revoke=False, nothing is revoked and the cleared count is returned."""
        mock_celery = MagicMock()
        mock_redis = MagicMock()
        mock_redis.clear_tasks_for_user.return_value = 3

        with self._patched_backends(mock_celery, mock_redis):
            from services.task_service import clear_all_tasks

            count = clear_all_tasks("test@example.com", revoke=False)

        assert count == 3
        mock_celery.control.revoke.assert_not_called()

    def test_revoke_failure_logs_warning_and_continues(self) -> None:
        """A revoke call that raises does not stop clear_all_tasks from returning the count."""
        mock_celery = MagicMock()
        mock_celery.control.revoke.side_effect = Exception("connection lost")
        mock_redis = MagicMock()
        mock_redis.get_tasks_for_user.return_value = ["t1"]
        mock_redis.clear_tasks_for_user.return_value = 1

        with self._patched_backends(mock_celery, mock_redis):
            from services.task_service import clear_all_tasks

            # Should not raise despite revoke failure
            count = clear_all_tasks("test@example.com", revoke=True)

        assert count == 1
        # Without this check the test would pass even if revoke were never
        # attempted, i.e. without the failure path ever being exercised.
        assert mock_celery.control.revoke.called
|
||||
Loading…
Add table
Add a link
Reference in a new issue