Refactor backend for cleaner error handling, DRY, and type safety

- Extract rate limiter DRY: consolidate 3 duplicated check/respond paths
  into _check_counter and _enforce_limit helpers, add proper type annotations
- Replace bare Exception raises with FloorplanDownloadError and
  RightmoveApiError; narrow catch clauses to specific exception types;
  fix Step base class to inherit from ABC
- Consolidate MAX_OCR_WORKERS into config/scraper_config.py; extract
  _find_tenure_value helper to deduplicate tenure parsing
- Extract _build_poi_distances_lookup from stream endpoint to reduce nesting
- Fix csv_exporter: optional decisions.json, NaN instead of -1 sentinels,
  guard against division by zero on missing square meters
- Fix notifications.py broken list[Surface]() constructor, database.py
  stale comments and missing type annotation, auth.py type:ignore,
  ui_exporter.py stale TODO
- Fix 3 pre-existing test failures: mock cache layer in streaming tests,
  bypass rate limiter for test isolation, fix cache invalidation test to
  account for two-pattern scan loop
This commit is contained in:
Viktor Barzin 2026-02-10 22:19:24 +00:00
parent 6897820cc7
commit f833309297
No known key found for this signature in database
GPG key ID: 0EB088298288D958
20 changed files with 199 additions and 178 deletions

View file

@ -185,6 +185,40 @@ async def get_listing_geojson(
def _build_poi_distances_lookup(
user_email: str,
listing_type: ListingType,
) -> dict[int, list[dict[str, str | int]]] | None:
"""Build POI distance lookup for a user, or None if no POIs configured."""
user_repo = UserRepository(engine)
db_user = user_repo.get_user_by_email(user_email)
if not db_user or db_user.id is None:
return None
poi_repo = POIRepository(engine)
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
if not pois:
return None
listing_repo = ListingRepository(engine)
all_ids = list(listing_repo.get_listing_ids(listing_type))
if not all_ids:
return None
distances = poi_repo.get_distances_for_listings(all_ids, listing_type, db_user.id)
lookup: dict[int, list[dict[str, str | int]]] = {}
for d in distances:
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
lookup.setdefault(d.listing_id, []).append({
"poi_id": d.poi_id,
"poi_name": poi_name,
"travel_mode": d.travel_mode,
"duration_seconds": d.duration_seconds,
"distance_meters": d.distance_meters,
})
return lookup
async def _stream_from_cache( async def _stream_from_cache(
query_parameters: QueryParameters, query_parameters: QueryParameters,
batch_size: int, batch_size: int,
@ -295,32 +329,7 @@ async def stream_listing_geojson(
limit = _rate_limit_config.geojson_stream_limit_cap limit = _rate_limit_config.geojson_stream_limit_cap
# Build POI distances lookup if requested # Build POI distances lookup if requested
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None poi_distances_lookup = _build_poi_distances_lookup(user.email, query_parameters.listing_type) if include_poi_distances else None
if include_poi_distances:
user_repo = UserRepository(engine)
db_user = user_repo.get_user_by_email(user.email)
if db_user and db_user.id is not None:
poi_repo = POIRepository(engine)
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
if pois:
# Get all listing IDs first for the query
listing_repo = ListingRepository(engine)
all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type))
if all_ids:
distances = poi_repo.get_distances_for_listings(
all_ids, query_parameters.listing_type, db_user.id
)
poi_distances_lookup = {}
for d in distances:
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
entry = {
"poi_id": d.poi_id,
"poi_name": poi_name,
"travel_mode": d.travel_mode,
"duration_seconds": d.duration_seconds,
"distance_meters": d.distance_meters,
}
poi_distances_lookup.setdefault(d.listing_id, []).append(entry)
cached_count = get_cached_count(query_parameters) cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0 and not include_poi_distances: if cached_count is not None and cached_count > 0 and not include_poi_distances:

View file

@ -13,6 +13,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from httpx import AsyncClient from httpx import AsyncClient
import jwt import jwt
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any
# HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs) # HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs)
@ -28,7 +29,7 @@ class User(BaseModel):
name: str name: str
async def get_oidc_metadata() -> dict: # type: ignore[type-arg] async def get_oidc_metadata() -> dict[str, Any]:
if "oidc_metadata" not in OIDC_METADATA_CACHE: if "oidc_metadata" not in OIDC_METADATA_CACHE:
async with AsyncClient() as client: async with AsyncClient() as client:
resp = await client.get(OIDC_METADATA_URL, follow_redirects=True) resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import logging import logging
import os import os
import time import time
from collections.abc import Awaitable, Callable
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
import jwt import jwt
@ -11,6 +12,7 @@ import redis
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp
from api.rate_limit_config import EndpointLimit, RateLimitConfig from api.rate_limit_config import EndpointLimit, RateLimitConfig
@ -87,21 +89,77 @@ class _InMemoryCounter:
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis.""" """Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None:
super().__init__(app) super().__init__(app)
self.config = config or RateLimitConfig.from_env() self.config = config or RateLimitConfig.from_env()
self._fallback = _InMemoryCounter() self._fallback = _InMemoryCounter()
try: try:
self._redis = _get_rate_limit_redis(self.config) self._redis: redis.Redis | None = _get_rate_limit_redis(self.config) # type: ignore[type-arg]
self._redis.ping() self._redis.ping()
except redis.RedisError: except redis.RedisError:
logger.warning("Rate limiter: Redis unavailable at startup, will fail open") logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
self._redis = None self._redis = None
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
"""Check rate limit counter, returning (allowed, remaining, retry_after).
Tries Redis first; falls back to in-memory counter on Redis errors.
retry_after is None for in-memory counters (no TTL available).
"""
if self._redis is None:
allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
return allowed, remaining, None
try:
pipe = self._redis.pipeline(transaction=True)
pipe.incr(key)
pipe.ttl(key)
result = pipe.execute()
current_count: int = result[0]
ttl: int = result[1]
# Set expiry on first request in window
if ttl == -1:
self._redis.expire(key, limit.window_seconds)
ttl = limit.window_seconds
remaining = max(0, limit.max_requests - current_count)
allowed = current_count <= limit.max_requests
retry_after = max(1, ttl) if not allowed else None
return allowed, remaining, retry_after
except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
return allowed, remaining, None
async def _enforce_limit(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
limit: EndpointLimit,
key: str,
) -> Response:
"""Check the rate limit and either reject with 429 or forward with headers."""
allowed, remaining, retry_after = self._check_counter(key, limit)
if not allowed:
headers: dict[str, str] = {
"X-RateLimit-Limit": str(limit.max_requests),
"X-RateLimit-Remaining": "0",
}
if retry_after is not None:
headers["Retry-After"] = str(retry_after)
return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"}, headers=headers)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
path = request.url.path path = request.url.path
# Skip exempt paths
if path in EXEMPT_PATHS: if path in EXEMPT_PATHS:
return await call_next(request) return await call_next(request)
@ -109,68 +167,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
if limit is None: if limit is None:
return await call_next(request) return await call_next(request)
# Determine identity for the counter key
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth) identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
key = f"ratelimit:{identity}:{path}"
# If Redis is unavailable, use in-memory fallback return await self._enforce_limit(request, call_next, limit, key)
if self._redis is None:
fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
redis_key = f"ratelimit:{identity}:{path}"
try:
pipe = self._redis.pipeline(transaction=True)
pipe.incr(redis_key)
pipe.ttl(redis_key)
result = pipe.execute()
current_count: int = result[0]
ttl: int = result[1]
# Set expiry on first request in window
if ttl == -1:
self._redis.expire(redis_key, limit.window_seconds)
ttl = limit.window_seconds
remaining = max(0, limit.max_requests - current_count)
if current_count > limit.max_requests:
retry_after = max(1, ttl)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": str(limit.max_requests),
"X-RateLimit-Remaining": "0",
},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response

View file

@ -1,10 +1,14 @@
"""Scraper configuration with environment variable loading.""" """Scraper configuration with environment variable loading."""
from __future__ import annotations from __future__ import annotations
import multiprocessing
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Self from typing import Self
# Limit OCR threads to 25% of available cores to avoid starving other work.
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
@dataclass(frozen=True) @dataclass(frozen=True)
class ScraperConfig: class ScraperConfig:

View file

@ -14,27 +14,30 @@ async def export_to_csv(
df = pd.DataFrame(ds) df = pd.DataFrame(ds)
# read decisions on file # read decisions on file
decisions_path = "data/decisions.json" decisions_path = Path("data/decisions.json")
decisions = pd.read_json(decisions_path) if decisions_path.exists():
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x)) decisions = pd.read_json(decisions_path)
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
# remove _sa_instance_state column # remove _sa_instance_state column
drop_columns = ["_sa_instance_state", "additional_info"] drop_columns = ["_sa_instance_state", "additional_info"]
df = df.drop(columns=drop_columns) df = df.drop(columns=drop_columns)
# fill in gap values for service charge and lease left for Excel filters # Ensure columns exist with NaN defaults for clean CSV output
if "service_charge" not in df.columns: for col in ("service_charge", "lease_left", "square_meters"):
df.loc[:, "service_charge"] = -1 if col not in df.columns:
df.loc[:, "service_charge"] = df.service_charge.fillna(-1) df.loc[:, col] = float("nan")
if "lease_left" not in df.columns:
df.loc[:, "lease_left"] = -1
df.loc[:, "lease_left"] = df.lease_left.fillna(-1)
if "square_meters" not in df.columns:
df.loc[:, "square_meters"] = -1
df.loc[:, "square_meters"] = df.square_meters.fillna(-1)
# Add price per sqm column # Replace -1 sentinel values with NaN
df.loc[:, "price_per_sqm"] = df.price / df.square_meters df.loc[:, "square_meters"] = df.square_meters.replace({-1: float("nan")})
# Add price per sqm column (guard against zero/missing square_meters)
df.loc[:, "price_per_sqm"] = df.apply(
lambda row: round(row.price / row.square_meters, 2)
if row.square_meters and row.square_meters > 0
else None,
axis=1,
)
df = df.sort_values(by=["price_per_sqm"], ascending=True) df = df.sort_values(by=["price_per_sqm"], ascending=True)
df.to_csv(str(output_file), index=False) df.to_csv(str(output_file), index=False)

View file

@ -6,10 +6,6 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
# PostgreSQL example (or use "sqlite:///database.db" for SQLite)
# DATABASE_URL = "postgresql://user:password@localhost/db_name"
# DATABASE_URL = "sqlite:///data/wrongmove.db"
# DATABASE_URL = "mysql://wrongmove:wrongmove@localhost:3306/wrongmove"
DATABASE_URL = os.environ["DB_CONNECTION_STRING"] DATABASE_URL = os.environ["DB_CONNECTION_STRING"]
@ -18,6 +14,6 @@ engine = create_engine(DATABASE_URL, echo=debug) # `echo=True` for debug logs
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db(): def init_db() -> None:
"""Create all tables (only for development; use migrations in production).""" """Create all tables (only for development; use migrations in production)."""
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View file

@ -1,15 +1,15 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime from datetime import datetime
import logging import logging
import multiprocessing
from pathlib import Path from pathlib import Path
import re import re
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
from config.scraper_config import MAX_OCR_WORKERS
from models.listing import ( from models.listing import (
BuyListing, BuyListing,
FurnishType, FurnishType,
@ -20,14 +20,12 @@ from models.listing import (
RentListing, RentListing,
) )
from rec import floorplan from rec import floorplan
from rec.exceptions import FloorplanDownloadError
from rec.query import detail_query from rec.query import detail_query
from repositories.listing_repository import ListingRepository from repositories.listing_repository import ListingRepository
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.error")
# Limit OCR threads to 25% of available cores to avoid starving other work.
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
def _parse_furnish_type(raw: str | None) -> FurnishType: def _parse_furnish_type(raw: str | None) -> FurnishType:
"""Normalise the raw furnish-type string from the API into a FurnishType enum.""" """Normalise the raw furnish-type string from the API into a FurnishType enum."""
@ -97,13 +95,13 @@ class ListingProcessor:
step_class_name, step_class_name step_class_name, step_class_name
) )
on_step_complete(short_name) on_step_complete(short_name)
except Exception as e: except (ValueError, KeyError, aiohttp.ClientError, FloorplanDownloadError) as e:
logger.error(f"[{listing_id}] {step_class_name} failed: {e}") logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
return None return None
return listing return listing
class Step: class Step(ABC):
listing_repository: ListingRepository listing_repository: ListingRepository
listing_type: ListingType listing_type: ListingType
@ -123,29 +121,32 @@ class Step:
return True return True
def _find_tenure_value(details: dict[str, Any], tenure_type: str) -> str | None:
"""Find a value in the tenure info content by type key."""
tenure_content = details.get("property", {}).get("tenureInfo", {}).get("content", [])
for item in tenure_content:
if item.get("type") == tenure_type:
return item.get("value")
return None
def _parse_service_charge(details: dict[str, Any]) -> float | None: def _parse_service_charge(details: dict[str, Any]) -> float | None:
"""Parse annual service charge from the tenure info in API response.""" """Parse annual service charge from the tenure info in API response."""
tenure_content = ( value = _find_tenure_value(details, "annualServiceCharge")
details.get("property", {}).get("tenureInfo", {}).get("content", []) if value is not None:
) matches = re.findall(r"([\d,.]+)", str(value))
for item in tenure_content: if matches:
if item.get("type") == "annualServiceCharge": return float(matches[0].replace(",", ""))
matches = re.findall(r"([\d,.]+)", str(item.get("value", "")))
if matches:
return float(matches[0].replace(",", ""))
return None return None
def _parse_lease_left(details: dict[str, Any]) -> int | None: def _parse_lease_left(details: dict[str, Any]) -> int | None:
"""Parse remaining lease years from the tenure info in API response.""" """Parse remaining lease years from the tenure info in API response."""
tenure_content = ( value = _find_tenure_value(details, "lengthOfLease")
details.get("property", {}).get("tenureInfo", {}).get("content", []) if value is not None:
) matches = re.findall(r"(\d+)", str(value))
for item in tenure_content: if matches:
if item.get("type") == "lengthOfLease": return int(matches[0])
matches = re.findall(r"(\d+)", str(item.get("value", "")))
if matches:
return int(matches[0])
return None return None
@ -265,7 +266,7 @@ class FetchImagesStep(Step):
if response.status == 404: if response.status == 404:
return listing return listing
if response.status != 200: if response.status != 200:
raise Exception(f"Error for {url}: {response.status}") raise FloorplanDownloadError(url, response.status)
floorplan_path.parent.mkdir(parents=True, exist_ok=True) floorplan_path.parent.mkdir(parents=True, exist_ok=True)
with open(floorplan_path, "wb") as f: with open(floorplan_path, "wb") as f:
f.write(await response.read()) f.write(await response.read())

View file

@ -15,8 +15,8 @@ class Slack(Surface):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_notifier(surfaces: list[Surface] | None = None) -> apprise.Apprise: def get_notifier() -> apprise.Apprise:
surfaces = surfaces or list[Surface]([Slack()]) surfaces = [Slack()]
obj = apprise.Apprise() obj = apprise.Apprise()
for surface in surfaces: for surface in surfaces:
if conn := surface.connection_string(): if conn := surface.connection_string():

View file

@ -74,6 +74,15 @@ class CircuitBreakerOpenError(RightmoveAPIError):
pass pass
class FloorplanDownloadError(Exception):
"""Raised when a floorplan image download fails."""
def __init__(self, url: str, status_code: int) -> None:
self.url = url
self.status_code = status_code
super().__init__(f"HTTP {status_code} downloading floorplan from {url}")
class RoutingApiError(Exception): class RoutingApiError(Exception):
"""Error from the Google Routes API.""" """Error from the Google Routes API."""

View file

@ -10,6 +10,7 @@ from models.listing import FurnishType, ListingType
from rec import districts from rec import districts
from rec.exceptions import ( from rec.exceptions import (
CircuitBreakerOpenError, CircuitBreakerOpenError,
RightmoveAPIError,
ThrottlingError, ThrottlingError,
) )
from rec.throttle_detector import get_throttle_metrics, validate_response from rec.throttle_detector import get_throttle_metrics, validate_response
@ -205,9 +206,9 @@ def _build_listing_params(
7, 7,
14, 14,
]: ]:
raise Exception( raise ValueError(
f"Invalid max days - {max_days_since_added} Can only be got", f"Invalid max_days_since_added={max_days_since_added}, "
[1, 3, 7, 14], f"must be one of [1, 3, 7, 14]"
) )
params["maxDaysSinceAdded"] = str(max_days_since_added) params["maxDaysSinceAdded"] = str(max_days_since_added)
@ -287,7 +288,7 @@ async def _execute_api_request(
) )
if response.status != 200: if response.status != 200:
raise Exception( raise RightmoveAPIError(
f"{error_context}Failed due to: {await response.text()}" f"{error_context}Failed due to: {await response.text()}"
) )

View file

@ -1,13 +1,10 @@
"""Floorplan detector service - OCR-based square meter detection.""" """Floorplan detector service - OCR-based square meter detection."""
import asyncio import asyncio
from config.scraper_config import MAX_OCR_WORKERS
from models import Listing from models import Listing
from rec import floorplan from rec import floorplan
from repositories.listing_repository import ListingRepository from repositories.listing_repository import ListingRepository
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
import multiprocessing
# Use a quarter of available CPUs to avoid starving other processes
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
async def detect_floorplan(repository: ListingRepository) -> None: async def detect_floorplan(repository: ListingRepository) -> None:

View file

@ -5,6 +5,7 @@ from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
from rec.exceptions import FloorplanDownloadError
from repositories import ListingRepository from repositories import ListingRepository
from tenacity import retry, stop_after_attempt, wait_random from tenacity import retry, stop_after_attempt, wait_random
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
@ -65,10 +66,7 @@ async def dump_images_for_listing(
) )
return None return None
if response.status != 200: if response.status != 200:
raise Exception( raise FloorplanDownloadError(url, response.status)
f"Error downloading floorplan for listing {listing.id} "
f"from {url}: HTTP {response.status}"
)
floorplan_path.parent.mkdir(parents=True, exist_ok=True) floorplan_path.parent.mkdir(parents=True, exist_ok=True)
with open(floorplan_path, "wb") as f: with open(floorplan_path, "wb") as f:
f.write(await response.read()) f.write(await response.read())

View file

@ -5,7 +5,7 @@ import logging
from config.scraper_config import ScraperConfig from config.scraper_config import ScraperConfig
from listing_processor import ListingProcessor from listing_processor import ListingProcessor
from rec.query import create_session, listing_query from rec.query import create_session, listing_query
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
from models.listing import Listing, QueryParameters from models.listing import Listing, QueryParameters
from repositories import ListingRepository from repositories import ListingRepository
@ -107,15 +107,15 @@ async def _fetch_subquery(
f"{sq.district}: {e}" f"{sq.district}: {e}"
) )
break break
except Exception as e: except InvalidResponseError:
# Rightmove returns GENERIC_ERROR when requesting pages # Rightmove returns GENERIC_ERROR when requesting pages
# past the last page of results. This is expected behavior # past the last page of results. This is expected behavior
# and signals we've exhausted this subquery's results. # and signals we've exhausted this subquery's results.
if "GENERIC_ERROR" in str(e): logger.debug(
logger.debug( f"Max page for {sq.district}: {page_id - 1}"
f"Max page for {sq.district}: {page_id - 1}" )
) break
break except Exception as e:
logger.warning( logger.warning(
f"Error fetching page {page_id} for " f"Error fetching page {page_id} for "
f"{sq.district}: {e}" f"{sq.district}: {e}"

View file

@ -12,7 +12,7 @@ from config.scraper_config import ScraperConfig
from listing_processor import ListingProcessor from listing_processor import ListingProcessor
from models.listing import Listing, QueryParameters from models.listing import Listing, QueryParameters
from rec.query import create_session, listing_query from rec.query import create_session, listing_query
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
from repositories.listing_repository import ListingRepository from repositories.listing_repository import ListingRepository
from database import engine from database import engine
@ -324,12 +324,12 @@ async def _fetch_subquery(
f"Throttling on {sq.district} page {page_id}: {e}" f"Throttling on {sq.district} page {page_id}: {e}"
) )
break break
except InvalidResponseError:
celery_logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
except Exception as e: except Exception as e:
if "GENERIC_ERROR" in str(e):
celery_logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
celery_logger.warning( celery_logger.warning(
f"Error fetching page {page_id} for " f"Error fetching page {page_id} for "
f"{sq.district}: {e}" f"{sq.district}: {e}"

View file

@ -70,7 +70,7 @@ async def test_step_failure_stops_pipeline(
processor = ListingProcessor(listing_repository) processor = ListingProcessor(listing_repository)
processor.process_steps[0].needs_processing = AsyncMock(return_value=True) processor.process_steps[0].needs_processing = AsyncMock(return_value=True)
processor.process_steps[0].process = AsyncMock(side_effect=RuntimeError("boom")) processor.process_steps[0].process = AsyncMock(side_effect=ValueError("boom"))
processor.process_steps[1].needs_processing = AsyncMock(return_value=True) processor.process_steps[1].needs_processing = AsyncMock(return_value=True)
processor.process_steps[1].process = AsyncMock() processor.process_steps[1].process = AsyncMock()
processor.process_steps[2].needs_processing = AsyncMock(return_value=True) processor.process_steps[2].needs_processing = AsyncMock(return_value=True)

View file

@ -156,7 +156,7 @@ class TestStreamingEndpoint:
@pytest.fixture @pytest.fixture
def client(self): def client(self):
"""Create test client with mocked auth.""" """Create test client with mocked auth and rate limiting bypassed."""
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from api.app import app from api.app import app
from api.auth import get_current_user, User from api.auth import get_current_user, User
@ -165,13 +165,15 @@ class TestStreamingEndpoint:
return User(sub="test-id", email="test@example.com", name="Test User") return User(sub="test-id", email="test@example.com", name="Test User")
app.dependency_overrides[get_current_user] = mock_auth app.dependency_overrides[get_current_user] = mock_auth
yield TestClient(app) with patch("api.rate_limiter._match_endpoint", return_value=None):
yield TestClient(app)
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture @pytest.fixture
def mock_repository(self): def mock_repository(self):
"""Mock the repository methods.""" """Mock the repository methods and bypass cache."""
with patch("api.app.ListingRepository") as MockRepo: with patch("api.app.get_cached_count", return_value=None), \
patch("api.app.ListingRepository") as MockRepo:
mock_instance = MagicMock() mock_instance = MagicMock()
mock_instance.count_listings.return_value = 3 mock_instance.count_listings.return_value = 3
mock_instance.stream_listings_optimized.return_value = iter([ mock_instance.stream_listings_optimized.return_value = iter([

View file

@ -204,8 +204,12 @@ class TestInvalidateCache:
mock_client = mock.MagicMock() mock_client = mock.MagicMock()
mock_pipeline = mock.MagicMock() mock_pipeline = mock.MagicMock()
mock_client.pipeline.return_value = mock_pipeline mock_client.pipeline.return_value = mock_pipeline
# Simulate one scan iteration that returns keys, then done # invalidate_cache scans two patterns (CACHE_PREFIX*, STAGING_PREFIX*)
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"]) # First scan returns matching keys, second returns none
mock_client.scan.side_effect = [
(0, ["listings:geojson:abc", "listings:geojson:def"]),
(0, []),
]
mock_get_client.return_value = mock_client mock_get_client.return_value = mock_client
invalidate_cache() invalidate_cache()

View file

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from models.listing import ListingType, QueryParameters from models.listing import ListingType, QueryParameters
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
from services.listing_fetcher import ( from services.listing_fetcher import (
NUM_WORKERS, NUM_WORKERS,
_fetch_subquery, _fetch_subquery,
@ -227,7 +227,7 @@ class TestFetchSubquery:
with patch( with patch(
"services.listing_fetcher.listing_query", "services.listing_fetcher.listing_query",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=Exception("GENERIC_ERROR: no more results"), side_effect=InvalidResponseError("GENERIC_ERROR: no more results"),
): ):
ids_found = await _fetch_subquery( ids_found = await _fetch_subquery(
sq=sq, sq=sq,

View file

@ -3,12 +3,12 @@ from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from models.listing import FurnishType, ListingType from models.listing import FurnishType, ListingType
from config.scraper_config import MAX_OCR_WORKERS
from listing_processor import ( from listing_processor import (
_parse_furnish_type, _parse_furnish_type,
_parse_available_from, _parse_available_from,
ListingProcessor, ListingProcessor,
FetchListingDetailsStep, FetchListingDetailsStep,
MAX_OCR_WORKERS,
) )
@ -77,7 +77,7 @@ class TestListingProcessor:
processor = ListingProcessor(mock_repo) processor = ListingProcessor(mock_repo)
for step in processor.process_steps: for step in processor.process_steps:
step.needs_processing = AsyncMock(return_value=True) step.needs_processing = AsyncMock(return_value=True)
step.process = AsyncMock(side_effect=Exception("fail")) step.process = AsyncMock(side_effect=ValueError("fail"))
result = await processor.process_listing(123) result = await processor.process_listing(123)
assert result is None assert result is None

View file

@ -99,7 +99,7 @@ def convert_to_geojson_feature(listing: RentListing | BuyListing) -> dict[str, A
properties: dict[str, Any] = { properties: dict[str, Any] = {
"listing_type": listing_type, "listing_type": listing_type,
"city": "London", # change me "city": "London",
"country": "United Kingdom", "country": "United Kingdom",
"qm": listing.square_meters, "qm": listing.square_meters,
"qmprice": listing.price_per_square_meter, "qmprice": listing.price_per_square_meter,