Refactor backend for cleaner error handling, DRY, and type safety

- Extract rate limiter DRY: consolidate 3 duplicated check/respond paths
  into _check_counter and _enforce_limit helpers, add proper type annotations
- Replace bare Exception raises with FloorplanDownloadError and
  RightmoveApiError; narrow catch clauses to specific exception types;
  fix Step base class to inherit from ABC
- Consolidate MAX_OCR_WORKERS into config/scraper_config.py; extract
  _find_tenure_value helper to deduplicate tenure parsing
- Extract _build_poi_distances_lookup from stream endpoint to reduce nesting
- Fix csv_exporter: optional decisions.json, NaN instead of -1 sentinels,
  guard against division by zero on missing square meters
- Fix notifications.py broken list[Surface]() constructor, database.py
  stale comments and missing type annotation, auth.py type:ignore,
  ui_exporter.py stale TODO
- Fix 3 pre-existing test failures: mock cache layer in streaming tests,
  bypass rate limiter for test isolation, fix cache invalidation test to
  account for two-pattern scan loop
This commit is contained in:
Viktor Barzin 2026-02-10 22:19:24 +00:00
parent 6897820cc7
commit f833309297
No known key found for this signature in database
GPG key ID: 0EB088298288D958
20 changed files with 199 additions and 178 deletions

View file

@ -185,6 +185,40 @@ async def get_listing_geojson(
def _build_poi_distances_lookup(
user_email: str,
listing_type: ListingType,
) -> dict[int, list[dict[str, str | int]]] | None:
"""Build POI distance lookup for a user, or None if no POIs configured."""
user_repo = UserRepository(engine)
db_user = user_repo.get_user_by_email(user_email)
if not db_user or db_user.id is None:
return None
poi_repo = POIRepository(engine)
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
if not pois:
return None
listing_repo = ListingRepository(engine)
all_ids = list(listing_repo.get_listing_ids(listing_type))
if not all_ids:
return None
distances = poi_repo.get_distances_for_listings(all_ids, listing_type, db_user.id)
lookup: dict[int, list[dict[str, str | int]]] = {}
for d in distances:
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
lookup.setdefault(d.listing_id, []).append({
"poi_id": d.poi_id,
"poi_name": poi_name,
"travel_mode": d.travel_mode,
"duration_seconds": d.duration_seconds,
"distance_meters": d.distance_meters,
})
return lookup
async def _stream_from_cache(
query_parameters: QueryParameters,
batch_size: int,
@ -295,32 +329,7 @@ async def stream_listing_geojson(
limit = _rate_limit_config.geojson_stream_limit_cap
# Build POI distances lookup if requested
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None
if include_poi_distances:
user_repo = UserRepository(engine)
db_user = user_repo.get_user_by_email(user.email)
if db_user and db_user.id is not None:
poi_repo = POIRepository(engine)
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
if pois:
# Get all listing IDs first for the query
listing_repo = ListingRepository(engine)
all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type))
if all_ids:
distances = poi_repo.get_distances_for_listings(
all_ids, query_parameters.listing_type, db_user.id
)
poi_distances_lookup = {}
for d in distances:
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
entry = {
"poi_id": d.poi_id,
"poi_name": poi_name,
"travel_mode": d.travel_mode,
"duration_seconds": d.duration_seconds,
"distance_meters": d.distance_meters,
}
poi_distances_lookup.setdefault(d.listing_id, []).append(entry)
poi_distances_lookup = _build_poi_distances_lookup(user.email, query_parameters.listing_type) if include_poi_distances else None
cached_count = get_cached_count(query_parameters)
if cached_count is not None and cached_count > 0 and not include_poi_distances:

View file

@ -13,6 +13,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from httpx import AsyncClient
import jwt
from pydantic import BaseModel
from typing import Any
# HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs)
@ -28,7 +29,7 @@ class User(BaseModel):
name: str
async def get_oidc_metadata() -> dict: # type: ignore[type-arg]
async def get_oidc_metadata() -> dict[str, Any]:
if "oidc_metadata" not in OIDC_METADATA_CACHE:
async with AsyncClient() as client:
resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import os
import time
from collections.abc import Awaitable, Callable
from urllib.parse import urlparse, urlunparse
import jwt
@ -11,6 +12,7 @@ import redis
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp
from api.rate_limit_config import EndpointLimit, RateLimitConfig
@ -87,21 +89,77 @@ class _InMemoryCounter:
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None:
super().__init__(app)
self.config = config or RateLimitConfig.from_env()
self._fallback = _InMemoryCounter()
try:
self._redis = _get_rate_limit_redis(self.config)
self._redis: redis.Redis | None = _get_rate_limit_redis(self.config) # type: ignore[type-arg]
self._redis.ping()
except redis.RedisError:
logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
self._redis = None
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
"""Check rate limit counter, returning (allowed, remaining, retry_after).
Tries Redis first; falls back to in-memory counter on Redis errors.
retry_after is None for in-memory counters (no TTL available).
"""
if self._redis is None:
allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
return allowed, remaining, None
try:
pipe = self._redis.pipeline(transaction=True)
pipe.incr(key)
pipe.ttl(key)
result = pipe.execute()
current_count: int = result[0]
ttl: int = result[1]
# Set expiry on first request in window
if ttl == -1:
self._redis.expire(key, limit.window_seconds)
ttl = limit.window_seconds
remaining = max(0, limit.max_requests - current_count)
allowed = current_count <= limit.max_requests
retry_after = max(1, ttl) if not allowed else None
return allowed, remaining, retry_after
except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
return allowed, remaining, None
async def _enforce_limit(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
limit: EndpointLimit,
key: str,
) -> Response:
"""Check the rate limit and either reject with 429 or forward with headers."""
allowed, remaining, retry_after = self._check_counter(key, limit)
if not allowed:
headers: dict[str, str] = {
"X-RateLimit-Limit": str(limit.max_requests),
"X-RateLimit-Remaining": "0",
}
if retry_after is not None:
headers["Retry-After"] = str(retry_after)
return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"}, headers=headers)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
path = request.url.path
# Skip exempt paths
if path in EXEMPT_PATHS:
return await call_next(request)
@ -109,68 +167,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
if limit is None:
return await call_next(request)
# Determine identity for the counter key
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
# If Redis is unavailable, use in-memory fallback
if self._redis is None:
fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
redis_key = f"ratelimit:{identity}:{path}"
try:
pipe = self._redis.pipeline(transaction=True)
pipe.incr(redis_key)
pipe.ttl(redis_key)
result = pipe.execute()
current_count: int = result[0]
ttl: int = result[1]
# Set expiry on first request in window
if ttl == -1:
self._redis.expire(redis_key, limit.window_seconds)
ttl = limit.window_seconds
remaining = max(0, limit.max_requests - current_count)
if current_count > limit.max_requests:
retry_after = max(1, ttl)
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": str(limit.max_requests),
"X-RateLimit-Remaining": "0",
},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
except redis.RedisError as e:
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
fallback_key = f"ratelimit:{identity}:{path}"
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
if not allowed:
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"},
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
return response
key = f"ratelimit:{identity}:{path}"
return await self._enforce_limit(request, call_next, limit, key)

View file

@ -1,10 +1,14 @@
"""Scraper configuration with environment variable loading."""
from __future__ import annotations
import multiprocessing
import os
from dataclasses import dataclass
from typing import Self
# Limit OCR threads to 25% of available cores to avoid starving other work.
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
@dataclass(frozen=True)
class ScraperConfig:

View file

@ -14,27 +14,30 @@ async def export_to_csv(
df = pd.DataFrame(ds)
# read decisions on file
decisions_path = "data/decisions.json"
decisions = pd.read_json(decisions_path)
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
decisions_path = Path("data/decisions.json")
if decisions_path.exists():
decisions = pd.read_json(decisions_path)
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
# remove _sa_instance_state column
drop_columns = ["_sa_instance_state", "additional_info"]
df = df.drop(columns=drop_columns)
# fill in gap values for service charge and lease left for Excel filters
if "service_charge" not in df.columns:
df.loc[:, "service_charge"] = -1
df.loc[:, "service_charge"] = df.service_charge.fillna(-1)
if "lease_left" not in df.columns:
df.loc[:, "lease_left"] = -1
df.loc[:, "lease_left"] = df.lease_left.fillna(-1)
if "square_meters" not in df.columns:
df.loc[:, "square_meters"] = -1
df.loc[:, "square_meters"] = df.square_meters.fillna(-1)
# Ensure columns exist with NaN defaults for clean CSV output
for col in ("service_charge", "lease_left", "square_meters"):
if col not in df.columns:
df.loc[:, col] = float("nan")
# Add price per sqm column
df.loc[:, "price_per_sqm"] = df.price / df.square_meters
# Replace -1 sentinel values with NaN
df.loc[:, "square_meters"] = df.square_meters.replace({-1: float("nan")})
# Add price per sqm column (guard against zero/missing square_meters)
df.loc[:, "price_per_sqm"] = df.apply(
lambda row: round(row.price / row.square_meters, 2)
if row.square_meters and row.square_meters > 0
else None,
axis=1,
)
df = df.sort_values(by=["price_per_sqm"], ascending=True)
df.to_csv(str(output_file), index=False)

View file

@ -6,10 +6,6 @@ from dotenv import load_dotenv
load_dotenv()
# PostgreSQL example (or use "sqlite:///database.db" for SQLite)
# DATABASE_URL = "postgresql://user:password@localhost/db_name"
# DATABASE_URL = "sqlite:///data/wrongmove.db"
# DATABASE_URL = "mysql://wrongmove:wrongmove@localhost:3306/wrongmove"
DATABASE_URL = os.environ["DB_CONNECTION_STRING"]
@ -18,6 +14,6 @@ engine = create_engine(DATABASE_URL, echo=debug) # `echo=True` for debug logs
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
def init_db() -> None:
"""Create all tables (only for development; use migrations in production)."""
SQLModel.metadata.create_all(engine)

View file

@ -1,15 +1,15 @@
from __future__ import annotations
from abc import abstractmethod
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from datetime import datetime
import logging
import multiprocessing
from pathlib import Path
import re
from typing import Any
from urllib.parse import urlparse
import aiohttp
from config.scraper_config import MAX_OCR_WORKERS
from models.listing import (
BuyListing,
FurnishType,
@ -20,14 +20,12 @@ from models.listing import (
RentListing,
)
from rec import floorplan
from rec.exceptions import FloorplanDownloadError
from rec.query import detail_query
from repositories.listing_repository import ListingRepository
logger = logging.getLogger("uvicorn.error")
# Limit OCR threads to 25% of available cores to avoid starving other work.
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
def _parse_furnish_type(raw: str | None) -> FurnishType:
"""Normalise the raw furnish-type string from the API into a FurnishType enum."""
@ -97,13 +95,13 @@ class ListingProcessor:
step_class_name, step_class_name
)
on_step_complete(short_name)
except Exception as e:
except (ValueError, KeyError, aiohttp.ClientError, FloorplanDownloadError) as e:
logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
return None
return listing
class Step:
class Step(ABC):
listing_repository: ListingRepository
listing_type: ListingType
@ -123,29 +121,32 @@ class Step:
return True
def _find_tenure_value(details: dict[str, Any], tenure_type: str) -> str | None:
"""Find a value in the tenure info content by type key."""
tenure_content = details.get("property", {}).get("tenureInfo", {}).get("content", [])
for item in tenure_content:
if item.get("type") == tenure_type:
return item.get("value")
return None
def _parse_service_charge(details: dict[str, Any]) -> float | None:
"""Parse annual service charge from the tenure info in API response."""
tenure_content = (
details.get("property", {}).get("tenureInfo", {}).get("content", [])
)
for item in tenure_content:
if item.get("type") == "annualServiceCharge":
matches = re.findall(r"([\d,.]+)", str(item.get("value", "")))
if matches:
return float(matches[0].replace(",", ""))
value = _find_tenure_value(details, "annualServiceCharge")
if value is not None:
matches = re.findall(r"([\d,.]+)", str(value))
if matches:
return float(matches[0].replace(",", ""))
return None
def _parse_lease_left(details: dict[str, Any]) -> int | None:
"""Parse remaining lease years from the tenure info in API response."""
tenure_content = (
details.get("property", {}).get("tenureInfo", {}).get("content", [])
)
for item in tenure_content:
if item.get("type") == "lengthOfLease":
matches = re.findall(r"(\d+)", str(item.get("value", "")))
if matches:
return int(matches[0])
value = _find_tenure_value(details, "lengthOfLease")
if value is not None:
matches = re.findall(r"(\d+)", str(value))
if matches:
return int(matches[0])
return None
@ -265,7 +266,7 @@ class FetchImagesStep(Step):
if response.status == 404:
return listing
if response.status != 200:
raise Exception(f"Error for {url}: {response.status}")
raise FloorplanDownloadError(url, response.status)
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
with open(floorplan_path, "wb") as f:
f.write(await response.read())

View file

@ -15,8 +15,8 @@ class Slack(Surface):
@lru_cache(maxsize=None)
def get_notifier(surfaces: list[Surface] | None = None) -> apprise.Apprise:
surfaces = surfaces or list[Surface]([Slack()])
def get_notifier() -> apprise.Apprise:
surfaces = [Slack()]
obj = apprise.Apprise()
for surface in surfaces:
if conn := surface.connection_string():

View file

@ -74,6 +74,15 @@ class CircuitBreakerOpenError(RightmoveAPIError):
pass
class FloorplanDownloadError(Exception):
"""Raised when a floorplan image download fails."""
def __init__(self, url: str, status_code: int) -> None:
self.url = url
self.status_code = status_code
super().__init__(f"HTTP {status_code} downloading floorplan from {url}")
class RoutingApiError(Exception):
"""Error from the Google Routes API."""

View file

@ -10,6 +10,7 @@ from models.listing import FurnishType, ListingType
from rec import districts
from rec.exceptions import (
CircuitBreakerOpenError,
RightmoveAPIError,
ThrottlingError,
)
from rec.throttle_detector import get_throttle_metrics, validate_response
@ -205,9 +206,9 @@ def _build_listing_params(
7,
14,
]:
raise Exception(
f"Invalid max days - {max_days_since_added} Can only be got",
[1, 3, 7, 14],
raise ValueError(
f"Invalid max_days_since_added={max_days_since_added}, "
f"must be one of [1, 3, 7, 14]"
)
params["maxDaysSinceAdded"] = str(max_days_since_added)
@ -287,7 +288,7 @@ async def _execute_api_request(
)
if response.status != 200:
raise Exception(
raise RightmoveAPIError(
f"{error_context}Failed due to: {await response.text()}"
)

View file

@ -1,13 +1,10 @@
"""Floorplan detector service - OCR-based square meter detection."""
import asyncio
from config.scraper_config import MAX_OCR_WORKERS
from models import Listing
from rec import floorplan
from repositories.listing_repository import ListingRepository
from tqdm.asyncio import tqdm
import multiprocessing
# Use a quarter of available CPUs to avoid starving other processes
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
async def detect_floorplan(repository: ListingRepository) -> None:

View file

@ -5,6 +5,7 @@ from pathlib import Path
from urllib.parse import urlparse
import aiohttp
from rec.exceptions import FloorplanDownloadError
from repositories import ListingRepository
from tenacity import retry, stop_after_attempt, wait_random
from tqdm.asyncio import tqdm
@ -65,10 +66,7 @@ async def dump_images_for_listing(
)
return None
if response.status != 200:
raise Exception(
f"Error downloading floorplan for listing {listing.id} "
f"from {url}: HTTP {response.status}"
)
raise FloorplanDownloadError(url, response.status)
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
with open(floorplan_path, "wb") as f:
f.write(await response.read())

View file

@ -5,7 +5,7 @@ import logging
from config.scraper_config import ScraperConfig
from listing_processor import ListingProcessor
from rec.query import create_session, listing_query
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
from models.listing import Listing, QueryParameters
from repositories import ListingRepository
@ -107,15 +107,15 @@ async def _fetch_subquery(
f"{sq.district}: {e}"
)
break
except Exception as e:
except InvalidResponseError:
# Rightmove returns GENERIC_ERROR when requesting pages
# past the last page of results. This is expected behavior
# and signals we've exhausted this subquery's results.
if "GENERIC_ERROR" in str(e):
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
except Exception as e:
logger.warning(
f"Error fetching page {page_id} for "
f"{sq.district}: {e}"

View file

@ -12,7 +12,7 @@ from config.scraper_config import ScraperConfig
from listing_processor import ListingProcessor
from models.listing import Listing, QueryParameters
from rec.query import create_session, listing_query
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
from repositories.listing_repository import ListingRepository
from database import engine
@ -324,12 +324,12 @@ async def _fetch_subquery(
f"Throttling on {sq.district} page {page_id}: {e}"
)
break
except InvalidResponseError:
celery_logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
except Exception as e:
if "GENERIC_ERROR" in str(e):
celery_logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
celery_logger.warning(
f"Error fetching page {page_id} for "
f"{sq.district}: {e}"

View file

@ -70,7 +70,7 @@ async def test_step_failure_stops_pipeline(
processor = ListingProcessor(listing_repository)
processor.process_steps[0].needs_processing = AsyncMock(return_value=True)
processor.process_steps[0].process = AsyncMock(side_effect=RuntimeError("boom"))
processor.process_steps[0].process = AsyncMock(side_effect=ValueError("boom"))
processor.process_steps[1].needs_processing = AsyncMock(return_value=True)
processor.process_steps[1].process = AsyncMock()
processor.process_steps[2].needs_processing = AsyncMock(return_value=True)

View file

@ -156,7 +156,7 @@ class TestStreamingEndpoint:
@pytest.fixture
def client(self):
"""Create test client with mocked auth."""
"""Create test client with mocked auth and rate limiting bypassed."""
from fastapi.testclient import TestClient
from api.app import app
from api.auth import get_current_user, User
@ -165,13 +165,15 @@ class TestStreamingEndpoint:
return User(sub="test-id", email="test@example.com", name="Test User")
app.dependency_overrides[get_current_user] = mock_auth
yield TestClient(app)
with patch("api.rate_limiter._match_endpoint", return_value=None):
yield TestClient(app)
app.dependency_overrides.clear()
@pytest.fixture
def mock_repository(self):
"""Mock the repository methods."""
with patch("api.app.ListingRepository") as MockRepo:
"""Mock the repository methods and bypass cache."""
with patch("api.app.get_cached_count", return_value=None), \
patch("api.app.ListingRepository") as MockRepo:
mock_instance = MagicMock()
mock_instance.count_listings.return_value = 3
mock_instance.stream_listings_optimized.return_value = iter([

View file

@ -204,8 +204,12 @@ class TestInvalidateCache:
mock_client = mock.MagicMock()
mock_pipeline = mock.MagicMock()
mock_client.pipeline.return_value = mock_pipeline
# Simulate one scan iteration that returns keys, then done
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
# invalidate_cache scans two patterns (CACHE_PREFIX*, STAGING_PREFIX*)
# First scan returns matching keys, second returns none
mock_client.scan.side_effect = [
(0, ["listings:geojson:abc", "listings:geojson:def"]),
(0, []),
]
mock_get_client.return_value = mock_client
invalidate_cache()

View file

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from models.listing import ListingType, QueryParameters
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
from services.listing_fetcher import (
NUM_WORKERS,
_fetch_subquery,
@ -227,7 +227,7 @@ class TestFetchSubquery:
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=Exception("GENERIC_ERROR: no more results"),
side_effect=InvalidResponseError("GENERIC_ERROR: no more results"),
):
ids_found = await _fetch_subquery(
sq=sq,

View file

@ -3,12 +3,12 @@ from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from models.listing import FurnishType, ListingType
from config.scraper_config import MAX_OCR_WORKERS
from listing_processor import (
_parse_furnish_type,
_parse_available_from,
ListingProcessor,
FetchListingDetailsStep,
MAX_OCR_WORKERS,
)
@ -77,7 +77,7 @@ class TestListingProcessor:
processor = ListingProcessor(mock_repo)
for step in processor.process_steps:
step.needs_processing = AsyncMock(return_value=True)
step.process = AsyncMock(side_effect=Exception("fail"))
step.process = AsyncMock(side_effect=ValueError("fail"))
result = await processor.process_listing(123)
assert result is None

View file

@ -99,7 +99,7 @@ def convert_to_geojson_feature(listing: RentListing | BuyListing) -> dict[str, A
properties: dict[str, Any] = {
"listing_type": listing_type,
"city": "London", # change me
"city": "London",
"country": "United Kingdom",
"qm": listing.square_meters,
"qmprice": listing.price_per_square_meter,