From f8333092976adced79418b2d0ce90d30a836e0b1 Mon Sep 17 00:00:00 2001 From: Viktor Barzin Date: Tue, 10 Feb 2026 22:19:24 +0000 Subject: [PATCH] Refactor backend for cleaner error handling, DRY, and type safety - Extract rate limiter DRY: consolidate 3 duplicated check/respond paths into _check_counter and _enforce_limit helpers, add proper type annotations - Replace bare Exception raises with FloorplanDownloadError and RightmoveApiError; narrow catch clauses to specific exception types; fix Step base class to inherit from ABC - Consolidate MAX_OCR_WORKERS into config/scraper_config.py; extract _find_tenure_value helper to deduplicate tenure parsing - Extract _build_poi_distances_lookup from stream endpoint to reduce nesting - Fix csv_exporter: optional decisions.json, NaN instead of -1 sentinels, guard against division by zero on missing square meters - Fix notifications.py broken list[Surface]() constructor, database.py stale comments and missing type annotation, auth.py type:ignore, ui_exporter.py stale TODO - Fix 3 pre-existing test failures: mock cache layer in streaming tests, bypass rate limiter for test isolation, fix cache invalidation test to account for two-pattern scan loop --- api/app.py | 61 +++++---- api/auth.py | 3 +- api/rate_limiter.py | 132 ++++++++++---------- config/scraper_config.py | 4 + csv_exporter.py | 33 ++--- database.py | 6 +- listing_processor.py | 49 ++++---- notifications.py | 4 +- rec/exceptions.py | 9 ++ rec/query.py | 9 +- services/floorplan_detector.py | 5 +- services/image_fetcher.py | 6 +- services/listing_fetcher.py | 14 +-- tasks/listing_tasks.py | 12 +- tests/integration/test_listing_processor.py | 2 +- tests/test_listing_geojson.py | 10 +- tests/unit/test_listing_cache.py | 8 +- tests/unit/test_listing_fetcher.py | 4 +- tests/unit/test_listing_processor.py | 4 +- ui_exporter.py | 2 +- 20 files changed, 199 insertions(+), 178 deletions(-) diff --git a/api/app.py b/api/app.py index 1d15184..9bbcbc0 100644 --- a/api/app.py +++ b/api/app.py @@ -185,6 +185,40 @@ async def get_listing_geojson( +def _build_poi_distances_lookup( + user_email: str, + listing_type: ListingType, +) -> dict[int, list[dict[str, str | int]]] | None: + """Build POI distance lookup for a user, or None if no POIs configured.""" + user_repo = UserRepository(engine) + db_user = user_repo.get_user_by_email(user_email) + if not db_user or db_user.id is None: + return None + + poi_repo = POIRepository(engine) + pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)} + if not pois: + return None + + listing_repo = ListingRepository(engine) + all_ids = list(listing_repo.get_listing_ids(listing_type)) + if not all_ids: + return None + + distances = poi_repo.get_distances_for_listings(all_ids, listing_type, db_user.id) + lookup: dict[int, list[dict[str, str | int]]] = {} + for d in distances: + poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown" + lookup.setdefault(d.listing_id, []).append({ + "poi_id": d.poi_id, + "poi_name": poi_name, + "travel_mode": d.travel_mode, + "duration_seconds": d.duration_seconds, + "distance_meters": d.distance_meters, + }) + return lookup + + async def _stream_from_cache( query_parameters: QueryParameters, batch_size: int, @@ -295,32 +329,7 @@ async def stream_listing_geojson( limit = _rate_limit_config.geojson_stream_limit_cap # Build POI distances lookup if requested - poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None - if include_poi_distances: - user_repo = UserRepository(engine) - db_user = user_repo.get_user_by_email(user.email) - if db_user and db_user.id is not None: - poi_repo = POIRepository(engine) - pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)} - if pois: - # Get all listing IDs first for the query - listing_repo = ListingRepository(engine) - all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type)) - if all_ids: - distances = poi_repo.get_distances_for_listings( - all_ids, query_parameters.listing_type, db_user.id - ) - poi_distances_lookup = {} - for d in distances: - poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown" - entry = { - "poi_id": d.poi_id, - "poi_name": poi_name, - "travel_mode": d.travel_mode, - "duration_seconds": d.duration_seconds, - "distance_meters": d.distance_meters, - } - poi_distances_lookup.setdefault(d.listing_id, []).append(entry) + poi_distances_lookup = _build_poi_distances_lookup(user.email, query_parameters.listing_type) if include_poi_distances else None cached_count = get_cached_count(query_parameters) if cached_count is not None and cached_count > 0 and not include_poi_distances: diff --git a/api/auth.py b/api/auth.py index ecc6825..8a5182d 100644 --- a/api/auth.py +++ b/api/auth.py @@ -13,6 +13,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from httpx import AsyncClient import jwt from pydantic import BaseModel +from typing import Any # HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs) @@ -28,7 +29,7 @@ class User(BaseModel): name: str -async def get_oidc_metadata() -> dict: # type: ignore[type-arg] +async def get_oidc_metadata() -> dict[str, Any]: if "oidc_metadata" not in OIDC_METADATA_CACHE: async with AsyncClient() as client: resp = await client.get(OIDC_METADATA_URL, follow_redirects=True) diff --git a/api/rate_limiter.py b/api/rate_limiter.py index 7f6e9da..a76be3a 100644 --- a/api/rate_limiter.py +++ b/api/rate_limiter.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging import os import time +from collections.abc import Awaitable, Callable from urllib.parse import urlparse, urlunparse import jwt @@ -11,6 +12,7 @@ import redis from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response +from starlette.types import ASGIApp from api.rate_limit_config import EndpointLimit, RateLimitConfig @@ -87,21 +89,77 @@ class _InMemoryCounter: class RateLimitMiddleware(BaseHTTPMiddleware): """Starlette middleware enforcing per-user fixed-window rate limits via Redis.""" - def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def] + def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None: super().__init__(app) self.config = config or RateLimitConfig.from_env() self._fallback = _InMemoryCounter() try: - self._redis = _get_rate_limit_redis(self.config) + self._redis: redis.Redis | None = _get_rate_limit_redis(self.config) # type: ignore[type-arg] self._redis.ping() except redis.RedisError: logger.warning("Rate limiter: Redis unavailable at startup, will fail open") self._redis = None - async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def] + def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]: + """Check rate limit counter, returning (allowed, remaining, retry_after). + + Tries Redis first; falls back to in-memory counter on Redis errors. + retry_after is None for in-memory counters (no TTL available). + """ + if self._redis is None: + allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds) + return allowed, remaining, None + + try: + pipe = self._redis.pipeline(transaction=True) + pipe.incr(key) + pipe.ttl(key) + result = pipe.execute() + current_count: int = result[0] + ttl: int = result[1] + + # Set expiry on first request in window + if ttl == -1: + self._redis.expire(key, limit.window_seconds) + ttl = limit.window_seconds + + remaining = max(0, limit.max_requests - current_count) + allowed = current_count <= limit.max_requests + retry_after = max(1, ttl) if not allowed else None + return allowed, remaining, retry_after + + except redis.RedisError as e: + logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}") + allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds) + return allowed, remaining, None + + async def _enforce_limit( + self, + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + limit: EndpointLimit, + key: str, + ) -> Response: + """Check the rate limit and either reject with 429 or forward with headers.""" + allowed, remaining, retry_after = self._check_counter(key, limit) + + if not allowed: + headers: dict[str, str] = { + "X-RateLimit-Limit": str(limit.max_requests), + "X-RateLimit-Remaining": "0", + } + if retry_after is not None: + headers["Retry-After"] = str(retry_after) + return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"}, headers=headers) + + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limit.max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + return response + + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: path = request.url.path - # Skip exempt paths if path in EXEMPT_PATHS: return await call_next(request) @@ -109,68 +167,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware): if limit is None: return await call_next(request) - # Determine identity for the counter key identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth) - - # If Redis is unavailable, use in-memory fallback - if self._redis is None: - fallback_key = f"ratelimit:{identity}:{path}" - allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds) - if not allowed: - return JSONResponse( - status_code=429, - content={"detail": "Rate limit exceeded"}, - headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"}, - ) - response = await call_next(request) - response.headers["X-RateLimit-Limit"] = str(limit.max_requests) - response.headers["X-RateLimit-Remaining"] = str(remaining) - return response - - redis_key = f"ratelimit:{identity}:{path}" - try: - pipe = self._redis.pipeline(transaction=True) - pipe.incr(redis_key) - pipe.ttl(redis_key) - result = pipe.execute() - current_count: int = result[0] - ttl: int = result[1] - - # Set expiry on first request in window - if ttl == -1: - self._redis.expire(redis_key, limit.window_seconds) - ttl = limit.window_seconds - - remaining = max(0, limit.max_requests - current_count) - - if current_count > limit.max_requests: - retry_after = max(1, ttl) - return JSONResponse( - status_code=429, - content={"detail": "Rate limit exceeded"}, - headers={ - "Retry-After": str(retry_after), - "X-RateLimit-Limit": str(limit.max_requests), - "X-RateLimit-Remaining": "0", - }, - ) - - response = await call_next(request) - response.headers["X-RateLimit-Limit"] = str(limit.max_requests) - response.headers["X-RateLimit-Remaining"] = str(remaining) - return response - - except redis.RedisError as e: - logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}") - fallback_key = f"ratelimit:{identity}:{path}" - allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds) - if not allowed: - return JSONResponse( - status_code=429, - content={"detail": "Rate limit exceeded"}, - headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"}, - ) - response = await call_next(request) - response.headers["X-RateLimit-Limit"] = str(limit.max_requests) - response.headers["X-RateLimit-Remaining"] = str(remaining) - return response + key = f"ratelimit:{identity}:{path}" + return await self._enforce_limit(request, call_next, limit, key) diff --git a/config/scraper_config.py b/config/scraper_config.py index 860d343..2b2232c 100644 --- a/config/scraper_config.py +++ b/config/scraper_config.py @@ -1,10 +1,14 @@ """Scraper configuration with environment variable loading.""" from __future__ import annotations +import multiprocessing import os from dataclasses import dataclass from typing import Self +# Limit OCR threads to 25% of available cores to avoid starving other work. +MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4) + @dataclass(frozen=True) class ScraperConfig: diff --git a/csv_exporter.py b/csv_exporter.py index 9bd286c..f3ca057 100644 --- a/csv_exporter.py +++ b/csv_exporter.py @@ -14,27 +14,30 @@ async def export_to_csv( df = pd.DataFrame(ds) # read decisions on file - decisions_path = "data/decisions.json" - decisions = pd.read_json(decisions_path) - df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x)) + decisions_path = Path("data/decisions.json") + if decisions_path.exists(): + decisions = pd.read_json(decisions_path) + df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x)) # remove _sa_instance_state column drop_columns = ["_sa_instance_state", "additional_info"] df = df.drop(columns=drop_columns) - # fill in gap values for service charge and lease left for Excel filters - if "service_charge" not in df.columns: - df.loc[:, "service_charge"] = -1 - df.loc[:, "service_charge"] = df.service_charge.fillna(-1) - if "lease_left" not in df.columns: - df.loc[:, "lease_left"] = -1 - df.loc[:, "lease_left"] = df.lease_left.fillna(-1) - if "square_meters" not in df.columns: - df.loc[:, "square_meters"] = -1 - df.loc[:, "square_meters"] = df.square_meters.fillna(-1) + # Ensure columns exist with NaN defaults for clean CSV output + for col in ("service_charge", "lease_left", "square_meters"): + if col not in df.columns: + df.loc[:, col] = float("nan") - # Add price per sqm column - df.loc[:, "price_per_sqm"] = df.price / df.square_meters + # Replace -1 sentinel values with NaN + df.loc[:, "square_meters"] = df.square_meters.replace({-1: float("nan")}) + + # Add price per sqm column (guard against zero/missing square_meters) + df.loc[:, "price_per_sqm"] = df.apply( + lambda row: round(row.price / row.square_meters, 2) + if row.square_meters and row.square_meters > 0 + else None, + axis=1, + ) df = df.sort_values(by=["price_per_sqm"], ascending=True) df.to_csv(str(output_file), index=False) diff --git a/database.py b/database.py index e893fa3..f3e8ea4 100644 --- a/database.py +++ b/database.py @@ -6,10 +6,6 @@ from dotenv import load_dotenv load_dotenv() -# PostgreSQL example (or use "sqlite:///database.db" for SQLite) -# DATABASE_URL = "postgresql://user:password@localhost/db_name" -# DATABASE_URL = "sqlite:///data/wrongmove.db" -# DATABASE_URL = "mysql://wrongmove:wrongmove@localhost:3306/wrongmove" DATABASE_URL = os.environ["DB_CONNECTION_STRING"] @@ -18,6 +14,6 @@ engine = create_engine(DATABASE_URL, echo=debug) # `echo=True` for debug logs SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -def init_db(): +def init_db() -> None: """Create all tables (only for development; use migrations in production).""" SQLModel.metadata.create_all(engine) diff --git a/listing_processor.py b/listing_processor.py index b2390fb..b37e81b 100644 --- a/listing_processor.py +++ b/listing_processor.py @@ -1,15 +1,15 @@ from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod import asyncio from collections.abc import Callable from datetime import datetime import logging -import multiprocessing from pathlib import Path import re from typing import Any from urllib.parse import urlparse import aiohttp +from config.scraper_config import MAX_OCR_WORKERS from models.listing import ( BuyListing, FurnishType, @@ -20,14 +20,12 @@ from models.listing import ( RentListing, ) from rec import floorplan +from rec.exceptions import FloorplanDownloadError from rec.query import detail_query from repositories.listing_repository import ListingRepository logger = logging.getLogger("uvicorn.error") -# Limit OCR threads to 25% of available cores to avoid starving other work. -MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4) - def _parse_furnish_type(raw: str | None) -> FurnishType: """Normalise the raw furnish-type string from the API into a FurnishType enum.""" @@ -97,13 +95,13 @@ class ListingProcessor: step_class_name, step_class_name ) on_step_complete(short_name) - except Exception as e: + except (ValueError, KeyError, aiohttp.ClientError, FloorplanDownloadError) as e: logger.error(f"[{listing_id}] {step_class_name} failed: {e}") return None return listing -class Step: +class Step(ABC): listing_repository: ListingRepository listing_type: ListingType @@ -123,29 +121,32 @@ class Step: return True +def _find_tenure_value(details: dict[str, Any], tenure_type: str) -> str | None: + """Find a value in the tenure info content by type key.""" + tenure_content = details.get("property", {}).get("tenureInfo", {}).get("content", []) + for item in tenure_content: + if item.get("type") == tenure_type: + return item.get("value") + return None + + def _parse_service_charge(details: dict[str, Any]) -> float | None: """Parse annual service charge from the tenure info in API response.""" - tenure_content = ( - details.get("property", {}).get("tenureInfo", {}).get("content", []) - ) - for item in tenure_content: - if item.get("type") == "annualServiceCharge": - matches = re.findall(r"([\d,.]+)", str(item.get("value", ""))) - if matches: - return float(matches[0].replace(",", "")) + value = _find_tenure_value(details, "annualServiceCharge") + if value is not None: + matches = re.findall(r"([\d,.]+)", str(value)) + if matches: + return float(matches[0].replace(",", "")) return None def _parse_lease_left(details: dict[str, Any]) -> int | None: """Parse remaining lease years from the tenure info in API response.""" - tenure_content = ( - details.get("property", {}).get("tenureInfo", {}).get("content", []) - ) - for item in tenure_content: - if item.get("type") == "lengthOfLease": - matches = re.findall(r"(\d+)", str(item.get("value", ""))) - if matches: - return int(matches[0]) + value = _find_tenure_value(details, "lengthOfLease") + if value is not None: + matches = re.findall(r"(\d+)", str(value)) + if matches: + return int(matches[0]) return None @@ -265,7 +266,7 @@ class FetchImagesStep(Step): if response.status == 404: return listing if response.status != 200: - raise Exception(f"Error for {url}: {response.status}") + raise FloorplanDownloadError(url, response.status) floorplan_path.parent.mkdir(parents=True, exist_ok=True) with open(floorplan_path, "wb") as f: f.write(await response.read()) diff --git a/notifications.py b/notifications.py index 2d40276..3fec5ef 100644 --- a/notifications.py +++ b/notifications.py @@ -15,8 +15,8 @@ class Slack(Surface): @lru_cache(maxsize=None) -def get_notifier(surfaces: list[Surface] | None = None) -> apprise.Apprise: - surfaces = surfaces or list[Surface]([Slack()]) +def get_notifier() -> apprise.Apprise: + surfaces = [Slack()] obj = apprise.Apprise() for surface in surfaces: if conn := surface.connection_string(): diff --git a/rec/exceptions.py b/rec/exceptions.py index 7c73996..9efb411 100644 --- a/rec/exceptions.py +++ b/rec/exceptions.py @@ -74,6 +74,15 @@ class CircuitBreakerOpenError(RightmoveAPIError): pass +class FloorplanDownloadError(Exception): + """Raised when a floorplan image download fails.""" + + def __init__(self, url: str, status_code: int) -> None: + self.url = url + self.status_code = status_code + super().__init__(f"HTTP {status_code} downloading floorplan from {url}") + + class RoutingApiError(Exception): """Error from the Google Routes API.""" diff --git a/rec/query.py b/rec/query.py index 442e0ef..805fdab 100644 --- a/rec/query.py +++ b/rec/query.py @@ -10,6 +10,7 @@ from models.listing import FurnishType, ListingType from rec import districts from rec.exceptions import ( CircuitBreakerOpenError, + RightmoveAPIError, ThrottlingError, ) from rec.throttle_detector import get_throttle_metrics, validate_response @@ -205,9 +206,9 @@ def _build_listing_params( 7, 14, ]: - raise Exception( - f"Invalid max days - {max_days_since_added} Can only be got", - [1, 3, 7, 14], + raise ValueError( + f"Invalid max_days_since_added={max_days_since_added}, " + f"must be one of [1, 3, 7, 14]" ) params["maxDaysSinceAdded"] = str(max_days_since_added) @@ -287,7 +288,7 @@ async def _execute_api_request( ) if response.status != 200: - raise Exception( + raise RightmoveAPIError( f"{error_context}Failed due to: {await response.text()}" ) diff --git a/services/floorplan_detector.py b/services/floorplan_detector.py index a09c5e4..490b88f 100644 --- a/services/floorplan_detector.py +++ b/services/floorplan_detector.py @@ -1,13 +1,10 @@ """Floorplan detector service - OCR-based square meter detection.""" import asyncio +from config.scraper_config import MAX_OCR_WORKERS from models import Listing from rec import floorplan from repositories.listing_repository import ListingRepository from tqdm.asyncio import tqdm -import multiprocessing - -# Use a quarter of available CPUs to avoid starving other processes -MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4) async def detect_floorplan(repository: ListingRepository) -> None: diff --git a/services/image_fetcher.py b/services/image_fetcher.py index c5f0369..085cba8 100644 --- a/services/image_fetcher.py +++ b/services/image_fetcher.py @@ -5,6 +5,7 @@ from pathlib import Path from urllib.parse import urlparse import aiohttp +from rec.exceptions import FloorplanDownloadError from repositories import ListingRepository from tenacity import retry, stop_after_attempt, wait_random from tqdm.asyncio import tqdm @@ -65,10 +66,7 @@ async def dump_images_for_listing( ) return None if response.status != 200: - raise Exception( - f"Error downloading floorplan for listing {listing.id} " - f"from {url}: HTTP {response.status}" - ) + raise FloorplanDownloadError(url, response.status) floorplan_path.parent.mkdir(parents=True, exist_ok=True) with open(floorplan_path, "wb") as f: f.write(await response.read()) diff --git a/services/listing_fetcher.py b/services/listing_fetcher.py index ee00aee..950740c 100644 --- a/services/listing_fetcher.py +++ b/services/listing_fetcher.py @@ -5,7 +5,7 @@ import logging from config.scraper_config import ScraperConfig from listing_processor import ListingProcessor from rec.query import create_session, listing_query -from rec.exceptions import CircuitBreakerOpenError, ThrottlingError +from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics from models.listing import Listing, QueryParameters from repositories import ListingRepository @@ -107,15 +107,15 @@ async def _fetch_subquery( f"{sq.district}: {e}" ) break - except Exception as e: + except InvalidResponseError: # Rightmove returns GENERIC_ERROR when requesting pages # past the last page of results. This is expected behavior # and signals we've exhausted this subquery's results. - if "GENERIC_ERROR" in str(e): - logger.debug( - f"Max page for {sq.district}: {page_id - 1}" - ) - break + logger.debug( + f"Max page for {sq.district}: {page_id - 1}" + ) + break + except Exception as e: logger.warning( f"Error fetching page {page_id} for " f"{sq.district}: {e}" diff --git a/tasks/listing_tasks.py b/tasks/listing_tasks.py index fb71821..42a5760 100644 --- a/tasks/listing_tasks.py +++ b/tasks/listing_tasks.py @@ -12,7 +12,7 @@ from config.scraper_config import ScraperConfig from listing_processor import ListingProcessor from models.listing import Listing, QueryParameters from rec.query import create_session, listing_query -from rec.exceptions import CircuitBreakerOpenError, ThrottlingError +from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics from repositories.listing_repository import ListingRepository from database import engine @@ -324,12 +324,12 @@ async def _fetch_subquery( f"Throttling on {sq.district} page {page_id}: {e}" ) break + except InvalidResponseError: + celery_logger.debug( + f"Max page for {sq.district}: {page_id - 1}" + ) + break except Exception as e: - if "GENERIC_ERROR" in str(e): - celery_logger.debug( - f"Max page for {sq.district}: {page_id - 1}" - ) - break celery_logger.warning( f"Error fetching page {page_id} for " f"{sq.district}: {e}" diff --git a/tests/integration/test_listing_processor.py b/tests/integration/test_listing_processor.py index ed3ab3a..81a89a0 100644 --- a/tests/integration/test_listing_processor.py +++ b/tests/integration/test_listing_processor.py @@ -70,7 +70,7 @@ async def test_step_failure_stops_pipeline( processor = ListingProcessor(listing_repository) processor.process_steps[0].needs_processing = AsyncMock(return_value=True) - processor.process_steps[0].process = AsyncMock(side_effect=RuntimeError("boom")) + processor.process_steps[0].process = AsyncMock(side_effect=ValueError("boom")) processor.process_steps[1].needs_processing = AsyncMock(return_value=True) processor.process_steps[1].process = AsyncMock() processor.process_steps[2].needs_processing = AsyncMock(return_value=True) diff --git a/tests/test_listing_geojson.py b/tests/test_listing_geojson.py index 3e0a152..cfff592 100644 --- a/tests/test_listing_geojson.py +++ b/tests/test_listing_geojson.py @@ -156,7 +156,7 @@ class TestStreamingEndpoint: @pytest.fixture def client(self): - """Create test client with mocked auth.""" + """Create test client with mocked auth and rate limiting bypassed.""" from fastapi.testclient import TestClient from api.app import app from api.auth import get_current_user, User @@ -165,13 +165,15 @@ class TestStreamingEndpoint: return User(sub="test-id", email="test@example.com", name="Test User") app.dependency_overrides[get_current_user] = mock_auth - yield TestClient(app) + with patch("api.rate_limiter._match_endpoint", return_value=None): + yield TestClient(app) app.dependency_overrides.clear() @pytest.fixture def mock_repository(self): - """Mock the repository methods.""" - with patch("api.app.ListingRepository") as MockRepo: + """Mock the repository methods and bypass cache.""" + with patch("api.app.get_cached_count", return_value=None), \ + patch("api.app.ListingRepository") as MockRepo: mock_instance = MagicMock() mock_instance.count_listings.return_value = 3 mock_instance.stream_listings_optimized.return_value = iter([ diff --git a/tests/unit/test_listing_cache.py b/tests/unit/test_listing_cache.py index f4686da..062c7c8 100644 --- a/tests/unit/test_listing_cache.py +++ b/tests/unit/test_listing_cache.py @@ -204,8 +204,12 @@ class TestInvalidateCache: mock_client = mock.MagicMock() mock_pipeline = mock.MagicMock() mock_client.pipeline.return_value = mock_pipeline - # Simulate one scan iteration that returns keys, then done - mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"]) + # invalidate_cache scans two patterns (CACHE_PREFIX*, STAGING_PREFIX*) + # First scan returns matching keys, second returns none + mock_client.scan.side_effect = [ + (0, ["listings:geojson:abc", "listings:geojson:def"]), + (0, []), + ] mock_get_client.return_value = mock_client invalidate_cache() diff --git a/tests/unit/test_listing_fetcher.py b/tests/unit/test_listing_fetcher.py index b9e4a81..dae680b 100644 --- a/tests/unit/test_listing_fetcher.py +++ b/tests/unit/test_listing_fetcher.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from models.listing import ListingType, QueryParameters -from rec.exceptions import CircuitBreakerOpenError, ThrottlingError +from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError from services.listing_fetcher import ( NUM_WORKERS, _fetch_subquery, @@ -227,7 +227,7 @@ class TestFetchSubquery: with patch( "services.listing_fetcher.listing_query", new_callable=AsyncMock, - side_effect=Exception("GENERIC_ERROR: no more results"), + side_effect=InvalidResponseError("GENERIC_ERROR: no more results"), ): ids_found = await _fetch_subquery( sq=sq, diff --git a/tests/unit/test_listing_processor.py b/tests/unit/test_listing_processor.py index a515638..4f49c4e 100644 --- a/tests/unit/test_listing_processor.py +++ b/tests/unit/test_listing_processor.py @@ -3,12 +3,12 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch import pytest from models.listing import FurnishType, ListingType +from config.scraper_config import MAX_OCR_WORKERS from listing_processor import ( _parse_furnish_type, _parse_available_from, ListingProcessor, FetchListingDetailsStep, - MAX_OCR_WORKERS, ) @@ -77,7 +77,7 @@ class TestListingProcessor: processor = ListingProcessor(mock_repo) for step in processor.process_steps: step.needs_processing = AsyncMock(return_value=True) - step.process = AsyncMock(side_effect=Exception("fail")) + step.process = AsyncMock(side_effect=ValueError("fail")) result = await processor.process_listing(123) assert result is None diff --git a/ui_exporter.py b/ui_exporter.py index fed36d6..b9491a2 100644 --- a/ui_exporter.py +++ b/ui_exporter.py @@ -99,7 +99,7 @@ def convert_to_geojson_feature(listing: RentListing | BuyListing) -> dict[str, A properties: dict[str, Any] = { "listing_type": listing_type, - "city": "London", # change me + "city": "London", "country": "United Kingdom", "qm": listing.square_meters, "qmprice": listing.price_per_square_meter,