Refactor backend for cleaner error handling, DRY, and type safety
- Rate limiter: remove duplication by consolidating the 3 duplicated check/respond paths into the _check_counter and _enforce_limit helpers, and add proper type annotations
- Replace bare Exception raises with FloorplanDownloadError and RightmoveAPIError; narrow catch clauses to specific exception types; fix the Step base class to inherit from ABC
- Consolidate MAX_OCR_WORKERS into config/scraper_config.py; extract the _find_tenure_value helper to deduplicate tenure parsing
- Extract _build_poi_distances_lookup from the stream endpoint to reduce nesting
- Fix csv_exporter: make decisions.json optional, use NaN instead of -1 sentinels, and guard against division by zero when square meters are missing
- Fix the broken list[Surface]() constructor in notifications.py, the stale comments and missing type annotation in database.py, the type: ignore in auth.py, and the stale TODO in ui_exporter.py
- Fix 3 pre-existing test failures: mock the cache layer in the streaming tests, bypass the rate limiter for test isolation, and fix the cache invalidation test to account for the two-pattern scan loop
This commit is contained in:
parent
6897820cc7
commit
f833309297
20 changed files with 199 additions and 178 deletions
61
api/app.py
61
api/app.py
|
|
@ -185,6 +185,40 @@ async def get_listing_geojson(
|
|||
|
||||
|
||||
|
||||
def _build_poi_distances_lookup(
    user_email: str,
    listing_type: ListingType,
) -> dict[int, list[dict[str, str | int]]] | None:
    """Group a user's stored POI distances by listing ID.

    Returns ``None`` when the user cannot be resolved to a database ID,
    when the user has no POIs, or when there are no listing IDs for
    ``listing_type``. Otherwise returns a dict keyed by listing ID whose
    values are lists of distance entries (one per stored distance row).
    """
    account = UserRepository(engine).get_user_by_email(user_email)
    if not account or account.id is None:
        return None

    poi_repo = POIRepository(engine)
    pois_by_id = {poi.id: poi for poi in poi_repo.get_pois_for_user(account.id)}
    if not pois_by_id:
        return None

    listing_ids = list(ListingRepository(engine).get_listing_ids(listing_type))
    if not listing_ids:
        return None

    by_listing: dict[int, list[dict[str, str | int]]] = {}
    rows = poi_repo.get_distances_for_listings(listing_ids, listing_type, account.id)
    for row in rows:
        # A distance row may reference a POI that is not in the user's
        # current POI set; label it "Unknown" instead of failing.
        try:
            poi_name = pois_by_id[row.poi_id].name
        except KeyError:
            poi_name = "Unknown"

        entry = {
            "poi_id": row.poi_id,
            "poi_name": poi_name,
            "travel_mode": row.travel_mode,
            "duration_seconds": row.duration_seconds,
            "distance_meters": row.distance_meters,
        }
        if row.listing_id not in by_listing:
            by_listing[row.listing_id] = []
        by_listing[row.listing_id].append(entry)

    return by_listing
|
||||
|
||||
|
||||
async def _stream_from_cache(
|
||||
query_parameters: QueryParameters,
|
||||
batch_size: int,
|
||||
|
|
@ -295,32 +329,7 @@ async def stream_listing_geojson(
|
|||
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||
|
||||
# Build POI distances lookup if requested
|
||||
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None
|
||||
if include_poi_distances:
|
||||
user_repo = UserRepository(engine)
|
||||
db_user = user_repo.get_user_by_email(user.email)
|
||||
if db_user and db_user.id is not None:
|
||||
poi_repo = POIRepository(engine)
|
||||
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
|
||||
if pois:
|
||||
# Get all listing IDs first for the query
|
||||
listing_repo = ListingRepository(engine)
|
||||
all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type))
|
||||
if all_ids:
|
||||
distances = poi_repo.get_distances_for_listings(
|
||||
all_ids, query_parameters.listing_type, db_user.id
|
||||
)
|
||||
poi_distances_lookup = {}
|
||||
for d in distances:
|
||||
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
|
||||
entry = {
|
||||
"poi_id": d.poi_id,
|
||||
"poi_name": poi_name,
|
||||
"travel_mode": d.travel_mode,
|
||||
"duration_seconds": d.duration_seconds,
|
||||
"distance_meters": d.distance_meters,
|
||||
}
|
||||
poi_distances_lookup.setdefault(d.listing_id, []).append(entry)
|
||||
poi_distances_lookup = _build_poi_distances_lookup(user.email, query_parameters.listing_type) if include_poi_distances else None
|
||||
|
||||
cached_count = get_cached_count(query_parameters)
|
||||
if cached_count is not None and cached_count > 0 and not include_poi_distances:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|||
from httpx import AsyncClient
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
|
||||
# HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs)
|
||||
|
|
@ -28,7 +29,7 @@ class User(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
async def get_oidc_metadata() -> dict: # type: ignore[type-arg]
|
||||
async def get_oidc_metadata() -> dict[str, Any]:
|
||||
if "oidc_metadata" not in OIDC_METADATA_CACHE:
|
||||
async with AsyncClient() as client:
|
||||
resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
import jwt
|
||||
|
|
@ -11,6 +12,7 @@ import redis
|
|||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||
|
||||
|
|
@ -87,21 +89,77 @@ class _InMemoryCounter:
|
|||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||
|
||||
def __init__(self, app, config: RateLimitConfig | None = None) -> None: # type: ignore[no-untyped-def]
|
||||
def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None:
|
||||
super().__init__(app)
|
||||
self.config = config or RateLimitConfig.from_env()
|
||||
self._fallback = _InMemoryCounter()
|
||||
try:
|
||||
self._redis = _get_rate_limit_redis(self.config)
|
||||
self._redis: redis.Redis | None = _get_rate_limit_redis(self.config) # type: ignore[type-arg]
|
||||
self._redis.ping()
|
||||
except redis.RedisError:
|
||||
logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
|
||||
self._redis = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
||||
def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
    """Count one request against ``key`` and report the outcome.

    Returns ``(allowed, remaining, retry_after)``. The Redis counter is
    used when a client is configured; if no client is configured, or the
    Redis call raises ``redis.RedisError``, the in-memory fallback counter
    is used instead. ``retry_after`` is only set when the Redis path
    rejects the request; the in-memory path always reports ``None``.
    """
    client = self._redis
    if client is not None:
        try:
            # INCR and TTL are queued together and sent as one transaction.
            pipe = client.pipeline(transaction=True)
            pipe.incr(key)
            pipe.ttl(key)
            replies = pipe.execute()
            hits: int = replies[0]
            seconds_left: int = replies[1]

            # A TTL of -1 means the key has no expiry yet, so this request
            # starts the window.
            if seconds_left == -1:
                client.expire(key, limit.window_seconds)
                seconds_left = limit.window_seconds

            remaining = max(0, limit.max_requests - hits)
            if hits > limit.max_requests:
                return False, remaining, max(1, seconds_left)
            return True, remaining, None
        except redis.RedisError as e:
            logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")

    allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
    return allowed, remaining, None
|
||||
|
||||
async def _enforce_limit(
    self,
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
    limit: EndpointLimit,
    key: str,
) -> Response:
    """Apply ``limit`` to this request under the counter ``key``.

    An allowed request is forwarded to ``call_next`` and its response is
    annotated with ``X-RateLimit-*`` headers. A rejected request gets a
    429 JSON response, including ``Retry-After`` when the counter
    supplied one.
    """
    allowed, remaining, retry_after = self._check_counter(key, limit)

    if allowed:
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response

    rejection_headers: dict[str, str] = {
        "X-RateLimit-Limit": str(limit.max_requests),
        "X-RateLimit-Remaining": "0",
    }
    if retry_after is not None:
        rejection_headers["Retry-After"] = str(retry_after)
    return JSONResponse(
        status_code=429,
        content={"detail": "Rate limit exceeded"},
        headers=rejection_headers,
    )
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
path = request.url.path
|
||||
|
||||
# Skip exempt paths
|
||||
if path in EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
|
|
@ -109,68 +167,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||
if limit is None:
|
||||
return await call_next(request)
|
||||
|
||||
# Determine identity for the counter key
|
||||
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
|
||||
|
||||
# If Redis is unavailable, use in-memory fallback
|
||||
if self._redis is None:
|
||||
fallback_key = f"ratelimit:{identity}:{path}"
|
||||
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
redis_key = f"ratelimit:{identity}:{path}"
|
||||
try:
|
||||
pipe = self._redis.pipeline(transaction=True)
|
||||
pipe.incr(redis_key)
|
||||
pipe.ttl(redis_key)
|
||||
result = pipe.execute()
|
||||
current_count: int = result[0]
|
||||
ttl: int = result[1]
|
||||
|
||||
# Set expiry on first request in window
|
||||
if ttl == -1:
|
||||
self._redis.expire(redis_key, limit.window_seconds)
|
||||
ttl = limit.window_seconds
|
||||
|
||||
remaining = max(0, limit.max_requests - current_count)
|
||||
|
||||
if current_count > limit.max_requests:
|
||||
retry_after = max(1, ttl)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": str(limit.max_requests),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
|
||||
fallback_key = f"ratelimit:{identity}:{path}"
|
||||
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
key = f"ratelimit:{identity}:{path}"
|
||||
return await self._enforce_limit(request, call_next, limit, key)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
"""Scraper configuration with environment variable loading."""
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
# Limit OCR threads to 25% of available cores to avoid starving other work.
|
||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScraperConfig:
|
||||
|
|
|
|||
|
|
@ -14,27 +14,30 @@ async def export_to_csv(
|
|||
df = pd.DataFrame(ds)
|
||||
|
||||
# read decisions on file
|
||||
decisions_path = "data/decisions.json"
|
||||
decisions = pd.read_json(decisions_path)
|
||||
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
|
||||
decisions_path = Path("data/decisions.json")
|
||||
if decisions_path.exists():
|
||||
decisions = pd.read_json(decisions_path)
|
||||
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
|
||||
|
||||
# remove _sa_instance_state column
|
||||
drop_columns = ["_sa_instance_state", "additional_info"]
|
||||
df = df.drop(columns=drop_columns)
|
||||
|
||||
# fill in gap values for service charge and lease left for Excel filters
|
||||
if "service_charge" not in df.columns:
|
||||
df.loc[:, "service_charge"] = -1
|
||||
df.loc[:, "service_charge"] = df.service_charge.fillna(-1)
|
||||
if "lease_left" not in df.columns:
|
||||
df.loc[:, "lease_left"] = -1
|
||||
df.loc[:, "lease_left"] = df.lease_left.fillna(-1)
|
||||
if "square_meters" not in df.columns:
|
||||
df.loc[:, "square_meters"] = -1
|
||||
df.loc[:, "square_meters"] = df.square_meters.fillna(-1)
|
||||
# Ensure columns exist with NaN defaults for clean CSV output
|
||||
for col in ("service_charge", "lease_left", "square_meters"):
|
||||
if col not in df.columns:
|
||||
df.loc[:, col] = float("nan")
|
||||
|
||||
# Add price per sqm column
|
||||
df.loc[:, "price_per_sqm"] = df.price / df.square_meters
|
||||
# Replace -1 sentinel values with NaN
|
||||
df.loc[:, "square_meters"] = df.square_meters.replace({-1: float("nan")})
|
||||
|
||||
# Add price per sqm column (guard against zero/missing square_meters)
|
||||
df.loc[:, "price_per_sqm"] = df.apply(
|
||||
lambda row: round(row.price / row.square_meters, 2)
|
||||
if row.square_meters and row.square_meters > 0
|
||||
else None,
|
||||
axis=1,
|
||||
)
|
||||
|
||||
df = df.sort_values(by=["price_per_sqm"], ascending=True)
|
||||
df.to_csv(str(output_file), index=False)
|
||||
|
|
|
|||
|
|
@ -6,10 +6,6 @@ from dotenv import load_dotenv
|
|||
load_dotenv()
|
||||
|
||||
|
||||
# PostgreSQL example (or use "sqlite:///database.db" for SQLite)
|
||||
# DATABASE_URL = "postgresql://user:password@localhost/db_name"
|
||||
# DATABASE_URL = "sqlite:///data/wrongmove.db"
|
||||
# DATABASE_URL = "mysql://wrongmove:wrongmove@localhost:3306/wrongmove"
|
||||
DATABASE_URL = os.environ["DB_CONNECTION_STRING"]
|
||||
|
||||
|
||||
|
|
@ -18,6 +14,6 @@ engine = create_engine(DATABASE_URL, echo=debug) # `echo=True` for debug logs
|
|||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def init_db():
|
||||
def init_db() -> None:
|
||||
"""Create all tables (only for development; use migrations in production)."""
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
from config.scraper_config import MAX_OCR_WORKERS
|
||||
from models.listing import (
|
||||
BuyListing,
|
||||
FurnishType,
|
||||
|
|
@ -20,14 +20,12 @@ from models.listing import (
|
|||
RentListing,
|
||||
)
|
||||
from rec import floorplan
|
||||
from rec.exceptions import FloorplanDownloadError
|
||||
from rec.query import detail_query
|
||||
from repositories.listing_repository import ListingRepository
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
# Limit OCR threads to 25% of available cores to avoid starving other work.
|
||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||
|
||||
|
||||
def _parse_furnish_type(raw: str | None) -> FurnishType:
|
||||
"""Normalise the raw furnish-type string from the API into a FurnishType enum."""
|
||||
|
|
@ -97,13 +95,13 @@ class ListingProcessor:
|
|||
step_class_name, step_class_name
|
||||
)
|
||||
on_step_complete(short_name)
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, aiohttp.ClientError, FloorplanDownloadError) as e:
|
||||
logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
|
||||
return None
|
||||
return listing
|
||||
|
||||
|
||||
class Step:
|
||||
class Step(ABC):
|
||||
listing_repository: ListingRepository
|
||||
listing_type: ListingType
|
||||
|
||||
|
|
@ -123,29 +121,32 @@ class Step:
|
|||
return True
|
||||
|
||||
|
||||
def _find_tenure_value(details: dict[str, Any], tenure_type: str) -> str | None:
|
||||
"""Find a value in the tenure info content by type key."""
|
||||
tenure_content = details.get("property", {}).get("tenureInfo", {}).get("content", [])
|
||||
for item in tenure_content:
|
||||
if item.get("type") == tenure_type:
|
||||
return item.get("value")
|
||||
return None
|
||||
|
||||
|
||||
def _parse_service_charge(details: dict[str, Any]) -> float | None:
|
||||
"""Parse annual service charge from the tenure info in API response."""
|
||||
tenure_content = (
|
||||
details.get("property", {}).get("tenureInfo", {}).get("content", [])
|
||||
)
|
||||
for item in tenure_content:
|
||||
if item.get("type") == "annualServiceCharge":
|
||||
matches = re.findall(r"([\d,.]+)", str(item.get("value", "")))
|
||||
if matches:
|
||||
return float(matches[0].replace(",", ""))
|
||||
value = _find_tenure_value(details, "annualServiceCharge")
|
||||
if value is not None:
|
||||
matches = re.findall(r"([\d,.]+)", str(value))
|
||||
if matches:
|
||||
return float(matches[0].replace(",", ""))
|
||||
return None
|
||||
|
||||
|
||||
def _parse_lease_left(details: dict[str, Any]) -> int | None:
|
||||
"""Parse remaining lease years from the tenure info in API response."""
|
||||
tenure_content = (
|
||||
details.get("property", {}).get("tenureInfo", {}).get("content", [])
|
||||
)
|
||||
for item in tenure_content:
|
||||
if item.get("type") == "lengthOfLease":
|
||||
matches = re.findall(r"(\d+)", str(item.get("value", "")))
|
||||
if matches:
|
||||
return int(matches[0])
|
||||
value = _find_tenure_value(details, "lengthOfLease")
|
||||
if value is not None:
|
||||
matches = re.findall(r"(\d+)", str(value))
|
||||
if matches:
|
||||
return int(matches[0])
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -265,7 +266,7 @@ class FetchImagesStep(Step):
|
|||
if response.status == 404:
|
||||
return listing
|
||||
if response.status != 200:
|
||||
raise Exception(f"Error for {url}: {response.status}")
|
||||
raise FloorplanDownloadError(url, response.status)
|
||||
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(floorplan_path, "wb") as f:
|
||||
f.write(await response.read())
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ class Slack(Surface):
|
|||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_notifier(surfaces: list[Surface] | None = None) -> apprise.Apprise:
|
||||
surfaces = surfaces or list[Surface]([Slack()])
|
||||
def get_notifier() -> apprise.Apprise:
|
||||
surfaces = [Slack()]
|
||||
obj = apprise.Apprise()
|
||||
for surface in surfaces:
|
||||
if conn := surface.connection_string():
|
||||
|
|
|
|||
|
|
@ -74,6 +74,15 @@ class CircuitBreakerOpenError(RightmoveAPIError):
|
|||
pass
|
||||
|
||||
|
||||
class FloorplanDownloadError(Exception):
    """Raised when a floorplan image download fails.

    Carries the requested URL and the HTTP status code so callers can
    inspect them without parsing the message.
    """

    def __init__(self, url: str, status_code: int) -> None:
        # URL whose download failed.
        self.url = url
        # HTTP status code returned for that URL.
        self.status_code = status_code
        super().__init__(f"HTTP {status_code} downloading floorplan from {url}")

    def __reduce__(self) -> tuple[type, tuple[str, int]]:
        """Support pickling and copying.

        The default exception reduction rebuilds the instance from
        ``self.args``, which holds only the formatted message. That would
        call ``__init__`` with one argument and raise ``TypeError`` when the
        exception crosses a process boundary. Rebuild from the original
        constructor arguments instead.
        """
        return (type(self), (self.url, self.status_code))
|
||||
|
||||
|
||||
class RoutingApiError(Exception):
|
||||
"""Error from the Google Routes API."""
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from models.listing import FurnishType, ListingType
|
|||
from rec import districts
|
||||
from rec.exceptions import (
|
||||
CircuitBreakerOpenError,
|
||||
RightmoveAPIError,
|
||||
ThrottlingError,
|
||||
)
|
||||
from rec.throttle_detector import get_throttle_metrics, validate_response
|
||||
|
|
@ -205,9 +206,9 @@ def _build_listing_params(
|
|||
7,
|
||||
14,
|
||||
]:
|
||||
raise Exception(
|
||||
f"Invalid max days - {max_days_since_added} Can only be got",
|
||||
[1, 3, 7, 14],
|
||||
raise ValueError(
|
||||
f"Invalid max_days_since_added={max_days_since_added}, "
|
||||
f"must be one of [1, 3, 7, 14]"
|
||||
)
|
||||
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
||||
|
||||
|
|
@ -287,7 +288,7 @@ async def _execute_api_request(
|
|||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
raise RightmoveAPIError(
|
||||
f"{error_context}Failed due to: {await response.text()}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,10 @@
|
|||
"""Floorplan detector service - OCR-based square meter detection."""
|
||||
import asyncio
|
||||
from config.scraper_config import MAX_OCR_WORKERS
|
||||
from models import Listing
|
||||
from rec import floorplan
|
||||
from repositories.listing_repository import ListingRepository
|
||||
from tqdm.asyncio import tqdm
|
||||
import multiprocessing
|
||||
|
||||
# Use a quarter of available CPUs to avoid starving other processes
|
||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||
|
||||
|
||||
async def detect_floorplan(repository: ListingRepository) -> None:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from rec.exceptions import FloorplanDownloadError
|
||||
from repositories import ListingRepository
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
from tqdm.asyncio import tqdm
|
||||
|
|
@ -65,10 +66,7 @@ async def dump_images_for_listing(
|
|||
)
|
||||
return None
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Error downloading floorplan for listing {listing.id} "
|
||||
f"from {url}: HTTP {response.status}"
|
||||
)
|
||||
raise FloorplanDownloadError(url, response.status)
|
||||
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(floorplan_path, "wb") as f:
|
||||
f.write(await response.read())
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import logging
|
|||
from config.scraper_config import ScraperConfig
|
||||
from listing_processor import ListingProcessor
|
||||
from rec.query import create_session, listing_query
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
|
||||
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
||||
from models.listing import Listing, QueryParameters
|
||||
from repositories import ListingRepository
|
||||
|
|
@ -107,15 +107,15 @@ async def _fetch_subquery(
|
|||
f"{sq.district}: {e}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
except InvalidResponseError:
|
||||
# Rightmove returns GENERIC_ERROR when requesting pages
|
||||
# past the last page of results. This is expected behavior
|
||||
# and signals we've exhausted this subquery's results.
|
||||
if "GENERIC_ERROR" in str(e):
|
||||
logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error fetching page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from config.scraper_config import ScraperConfig
|
|||
from listing_processor import ListingProcessor
|
||||
from models.listing import Listing, QueryParameters
|
||||
from rec.query import create_session, listing_query
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
|
||||
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
||||
from repositories.listing_repository import ListingRepository
|
||||
from database import engine
|
||||
|
|
@ -324,12 +324,12 @@ async def _fetch_subquery(
|
|||
f"Throttling on {sq.district} page {page_id}: {e}"
|
||||
)
|
||||
break
|
||||
except InvalidResponseError:
|
||||
celery_logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
if "GENERIC_ERROR" in str(e):
|
||||
celery_logger.debug(
|
||||
f"Max page for {sq.district}: {page_id - 1}"
|
||||
)
|
||||
break
|
||||
celery_logger.warning(
|
||||
f"Error fetching page {page_id} for "
|
||||
f"{sq.district}: {e}"
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ async def test_step_failure_stops_pipeline(
|
|||
processor = ListingProcessor(listing_repository)
|
||||
|
||||
processor.process_steps[0].needs_processing = AsyncMock(return_value=True)
|
||||
processor.process_steps[0].process = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
processor.process_steps[0].process = AsyncMock(side_effect=ValueError("boom"))
|
||||
processor.process_steps[1].needs_processing = AsyncMock(return_value=True)
|
||||
processor.process_steps[1].process = AsyncMock()
|
||||
processor.process_steps[2].needs_processing = AsyncMock(return_value=True)
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ class TestStreamingEndpoint:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create test client with mocked auth."""
|
||||
"""Create test client with mocked auth and rate limiting bypassed."""
|
||||
from fastapi.testclient import TestClient
|
||||
from api.app import app
|
||||
from api.auth import get_current_user, User
|
||||
|
|
@ -165,13 +165,15 @@ class TestStreamingEndpoint:
|
|||
return User(sub="test-id", email="test@example.com", name="Test User")
|
||||
|
||||
app.dependency_overrides[get_current_user] = mock_auth
|
||||
yield TestClient(app)
|
||||
with patch("api.rate_limiter._match_endpoint", return_value=None):
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repository(self):
|
||||
"""Mock the repository methods."""
|
||||
with patch("api.app.ListingRepository") as MockRepo:
|
||||
"""Mock the repository methods and bypass cache."""
|
||||
with patch("api.app.get_cached_count", return_value=None), \
|
||||
patch("api.app.ListingRepository") as MockRepo:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.count_listings.return_value = 3
|
||||
mock_instance.stream_listings_optimized.return_value = iter([
|
||||
|
|
|
|||
|
|
@ -204,8 +204,12 @@ class TestInvalidateCache:
|
|||
mock_client = mock.MagicMock()
|
||||
mock_pipeline = mock.MagicMock()
|
||||
mock_client.pipeline.return_value = mock_pipeline
|
||||
# Simulate one scan iteration that returns keys, then done
|
||||
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
|
||||
# invalidate_cache scans two patterns (CACHE_PREFIX*, STAGING_PREFIX*)
|
||||
# First scan returns matching keys, second returns none
|
||||
mock_client.scan.side_effect = [
|
||||
(0, ["listings:geojson:abc", "listings:geojson:def"]),
|
||||
(0, []),
|
||||
]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
invalidate_cache()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
|
||||
from services.listing_fetcher import (
|
||||
NUM_WORKERS,
|
||||
_fetch_subquery,
|
||||
|
|
@ -227,7 +227,7 @@ class TestFetchSubquery:
|
|||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("GENERIC_ERROR: no more results"),
|
||||
side_effect=InvalidResponseError("GENERIC_ERROR: no more results"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ from datetime import datetime
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from models.listing import FurnishType, ListingType
|
||||
from config.scraper_config import MAX_OCR_WORKERS
|
||||
from listing_processor import (
|
||||
_parse_furnish_type,
|
||||
_parse_available_from,
|
||||
ListingProcessor,
|
||||
FetchListingDetailsStep,
|
||||
MAX_OCR_WORKERS,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ class TestListingProcessor:
|
|||
processor = ListingProcessor(mock_repo)
|
||||
for step in processor.process_steps:
|
||||
step.needs_processing = AsyncMock(return_value=True)
|
||||
step.process = AsyncMock(side_effect=Exception("fail"))
|
||||
step.process = AsyncMock(side_effect=ValueError("fail"))
|
||||
result = await processor.process_listing(123)
|
||||
assert result is None
|
||||
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ def convert_to_geojson_feature(listing: RentListing | BuyListing) -> dict[str, A
|
|||
|
||||
properties: dict[str, Any] = {
|
||||
"listing_type": listing_type,
|
||||
"city": "London", # change me
|
||||
"city": "London",
|
||||
"country": "United Kingdom",
|
||||
"qm": listing.square_meters,
|
||||
"qmprice": listing.price_per_square_meter,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue