Refactor backend for cleaner error handling, DRY, and type safety
- Rate limiter: consolidate 3 duplicated check/respond paths into the _check_counter and _enforce_limit helpers, and add proper type annotations
- Replace bare Exception raises with FloorplanDownloadError and RightmoveAPIError; narrow catch clauses to specific exception types; fix Step base class to inherit from ABC
- Consolidate MAX_OCR_WORKERS into config/scraper_config.py; extract _find_tenure_value helper to deduplicate tenure parsing
- Extract _build_poi_distances_lookup from the stream endpoint to reduce nesting
- Fix csv_exporter: optional decisions.json, NaN instead of -1 sentinels, guard against division by zero on missing square meters
- Fix notifications.py broken list[Surface]() constructor, database.py stale comments and missing type annotation, auth.py type: ignore, ui_exporter.py stale TODO
- Fix 3 pre-existing test failures: mock the cache layer in streaming tests, bypass the rate limiter for test isolation, fix the cache invalidation test to account for the two-pattern scan loop
This commit is contained in:
parent
6897820cc7
commit
f833309297
20 changed files with 199 additions and 178 deletions
61
api/app.py
61
api/app.py
|
|
@ -185,6 +185,40 @@ async def get_listing_geojson(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _build_poi_distances_lookup(
    user_email: str,
    listing_type: ListingType,
) -> dict[int, list[dict[str, str | int]]] | None:
    """Map listing IDs to the given user's POI distance records.

    Returns ``None`` when there is nothing to look up: the email does not
    resolve to a stored user with an ID, that user has no POIs, or no listing
    IDs exist for ``listing_type``. Otherwise returns a dict keyed by listing
    ID; each value is a list with one record per stored distance row.
    """
    account = UserRepository(engine).get_user_by_email(user_email)
    if not account or account.id is None:
        return None

    poi_repository = POIRepository(engine)
    pois_by_id = {poi.id: poi for poi in poi_repository.get_pois_for_user(account.id)}
    if not pois_by_id:
        return None

    listing_ids = list(ListingRepository(engine).get_listing_ids(listing_type))
    if not listing_ids:
        return None

    by_listing: dict[int, list[dict[str, str | int]]] = {}
    rows = poi_repository.get_distances_for_listings(listing_ids, listing_type, account.id)
    for row in rows:
        # A distance row may reference a POI that is not in the user's
        # current POI set; label it rather than dropping the row.
        if row.poi_id in pois_by_id:
            label = pois_by_id[row.poi_id].name
        else:
            label = "Unknown"

        record = {
            "poi_id": row.poi_id,
            "poi_name": label,
            "travel_mode": row.travel_mode,
            "duration_seconds": row.duration_seconds,
            "distance_meters": row.distance_meters,
        }
        if row.listing_id not in by_listing:
            by_listing[row.listing_id] = []
        by_listing[row.listing_id].append(record)

    return by_listing
|
||||||
|
|
||||||
|
|
||||||
async def _stream_from_cache(
|
async def _stream_from_cache(
|
||||||
query_parameters: QueryParameters,
|
query_parameters: QueryParameters,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
|
@ -295,32 +329,7 @@ async def stream_listing_geojson(
|
||||||
limit = _rate_limit_config.geojson_stream_limit_cap
|
limit = _rate_limit_config.geojson_stream_limit_cap
|
||||||
|
|
||||||
# Build POI distances lookup if requested
|
# Build POI distances lookup if requested
|
||||||
poi_distances_lookup: dict[int, list[dict[str, str | int]]] | None = None
|
poi_distances_lookup = _build_poi_distances_lookup(user.email, query_parameters.listing_type) if include_poi_distances else None
|
||||||
if include_poi_distances:
|
|
||||||
user_repo = UserRepository(engine)
|
|
||||||
db_user = user_repo.get_user_by_email(user.email)
|
|
||||||
if db_user and db_user.id is not None:
|
|
||||||
poi_repo = POIRepository(engine)
|
|
||||||
pois = {p.id: p for p in poi_repo.get_pois_for_user(db_user.id)}
|
|
||||||
if pois:
|
|
||||||
# Get all listing IDs first for the query
|
|
||||||
listing_repo = ListingRepository(engine)
|
|
||||||
all_ids = list(listing_repo.get_listing_ids(query_parameters.listing_type))
|
|
||||||
if all_ids:
|
|
||||||
distances = poi_repo.get_distances_for_listings(
|
|
||||||
all_ids, query_parameters.listing_type, db_user.id
|
|
||||||
)
|
|
||||||
poi_distances_lookup = {}
|
|
||||||
for d in distances:
|
|
||||||
poi_name = pois[d.poi_id].name if d.poi_id in pois else "Unknown"
|
|
||||||
entry = {
|
|
||||||
"poi_id": d.poi_id,
|
|
||||||
"poi_name": poi_name,
|
|
||||||
"travel_mode": d.travel_mode,
|
|
||||||
"duration_seconds": d.duration_seconds,
|
|
||||||
"distance_meters": d.distance_meters,
|
|
||||||
}
|
|
||||||
poi_distances_lookup.setdefault(d.listing_id, []).append(entry)
|
|
||||||
|
|
||||||
cached_count = get_cached_count(query_parameters)
|
cached_count = get_cached_count(query_parameters)
|
||||||
if cached_count is not None and cached_count > 0 and not include_poi_distances:
|
if cached_count is not None and cached_count > 0 and not include_poi_distances:
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
import jwt
|
import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
# HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs)
|
# HTTPBearer scheme (provider-agnostic, works for both OIDC and passkey JWTs)
|
||||||
|
|
@ -28,7 +29,7 @@ class User(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
async def get_oidc_metadata() -> dict: # type: ignore[type-arg]
|
async def get_oidc_metadata() -> dict[str, Any]:
|
||||||
if "oidc_metadata" not in OIDC_METADATA_CACHE:
|
if "oidc_metadata" not in OIDC_METADATA_CACHE:
|
||||||
async with AsyncClient() as client:
|
async with AsyncClient() as client:
|
||||||
resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)
|
resp = await client.get(OIDC_METADATA_URL, follow_redirects=True)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
@ -11,6 +12,7 @@ import redis
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
from api.rate_limit_config import EndpointLimit, RateLimitConfig
|
||||||
|
|
||||||
|
|
@ -87,21 +89,77 @@ class _InMemoryCounter:
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
"""Starlette middleware enforcing per-user fixed-window rate limits via Redis."""
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp, config: RateLimitConfig | None = None) -> None:
    """Set up the middleware and probe Redis once.

    Args:
        app: The wrapped ASGI application.
        config: Rate limit settings; loaded via ``RateLimitConfig.from_env()``
            when omitted.

    If creating the Redis client or pinging it raises ``redis.RedisError``,
    ``self._redis`` is left as ``None`` and a warning is logged.
    """
    super().__init__(app)
    self.config = config or RateLimitConfig.from_env()
    self._fallback = _InMemoryCounter()
    self._redis: redis.Redis | None = None  # type: ignore[type-arg]
    try:
        client = _get_rate_limit_redis(self.config)
        client.ping()
    except redis.RedisError:
        logger.warning("Rate limiter: Redis unavailable at startup, will fail open")
    else:
        # Only keep the client once it has answered a ping.
        self._redis = client
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[no-untyped-def]
|
def _check_counter(self, key: str, limit: EndpointLimit) -> tuple[bool, int, int | None]:
    """Count one request against ``key`` and report the outcome.

    Returns:
        ``(allowed, remaining, retry_after)``. ``retry_after`` is a number of
        seconds (at least 1) only when Redis rejected the request; it is
        ``None`` when the request is allowed or when the in-memory counter
        was used.

    Redis is used when a client is configured; a ``redis.RedisError`` during
    the Redis round trip is logged and the in-memory counter is used instead.
    """
    client = self._redis
    if client is not None:
        try:
            pipe = client.pipeline(transaction=True)
            pipe.incr(key)
            pipe.ttl(key)
            outcome = pipe.execute()
            hits: int = outcome[0]
            seconds_left: int = outcome[1]

            # TTL of -1 means the key has no expiry yet, i.e. this is the
            # first hit of a new window: start the window now.
            if seconds_left == -1:
                client.expire(key, limit.window_seconds)
                seconds_left = limit.window_seconds

            if hits > limit.max_requests:
                return False, 0, max(1, seconds_left)
            return True, limit.max_requests - hits, None
        except redis.RedisError as e:
            logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")

    # No Redis client, or the Redis call above failed.
    allowed, remaining = self._fallback.check(key, limit.max_requests, limit.window_seconds)
    return allowed, remaining, None
|
||||||
|
|
||||||
|
async def _enforce_limit(
    self,
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
    limit: EndpointLimit,
    key: str,
) -> Response:
    """Apply ``limit`` to ``request`` under counter ``key``.

    Allowed requests are forwarded to ``call_next`` and the response gains
    ``X-RateLimit-Limit`` / ``X-RateLimit-Remaining`` headers. Rejected
    requests get a 429 JSON response with the same two headers (remaining
    fixed at "0") plus ``Retry-After`` when the counter supplied one.
    """
    allowed, remaining, retry_after = self._check_counter(key, limit)

    if allowed:
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response

    rejection_headers: dict[str, str] = {
        "X-RateLimit-Limit": str(limit.max_requests),
        "X-RateLimit-Remaining": "0",
    }
    if retry_after is not None:
        rejection_headers["Retry-After"] = str(retry_after)
    return JSONResponse(
        status_code=429,
        content={"detail": "Rate limit exceeded"},
        headers=rejection_headers,
    )
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
|
||||||
# Skip exempt paths
|
|
||||||
if path in EXEMPT_PATHS:
|
if path in EXEMPT_PATHS:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
@ -109,68 +167,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
if limit is None:
|
if limit is None:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Determine identity for the counter key
|
|
||||||
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
|
identity = _extract_user_email(request) or _client_ip(request, self.config.trusted_proxy_depth)
|
||||||
|
key = f"ratelimit:{identity}:{path}"
|
||||||
# If Redis is unavailable, use in-memory fallback
|
return await self._enforce_limit(request, call_next, limit, key)
|
||||||
if self._redis is None:
|
|
||||||
fallback_key = f"ratelimit:{identity}:{path}"
|
|
||||||
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
|
||||||
if not allowed:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=429,
|
|
||||||
content={"detail": "Rate limit exceeded"},
|
|
||||||
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
|
||||||
)
|
|
||||||
response = await call_next(request)
|
|
||||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
|
||||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
|
||||||
return response
|
|
||||||
|
|
||||||
redis_key = f"ratelimit:{identity}:{path}"
|
|
||||||
try:
|
|
||||||
pipe = self._redis.pipeline(transaction=True)
|
|
||||||
pipe.incr(redis_key)
|
|
||||||
pipe.ttl(redis_key)
|
|
||||||
result = pipe.execute()
|
|
||||||
current_count: int = result[0]
|
|
||||||
ttl: int = result[1]
|
|
||||||
|
|
||||||
# Set expiry on first request in window
|
|
||||||
if ttl == -1:
|
|
||||||
self._redis.expire(redis_key, limit.window_seconds)
|
|
||||||
ttl = limit.window_seconds
|
|
||||||
|
|
||||||
remaining = max(0, limit.max_requests - current_count)
|
|
||||||
|
|
||||||
if current_count > limit.max_requests:
|
|
||||||
retry_after = max(1, ttl)
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=429,
|
|
||||||
content={"detail": "Rate limit exceeded"},
|
|
||||||
headers={
|
|
||||||
"Retry-After": str(retry_after),
|
|
||||||
"X-RateLimit-Limit": str(limit.max_requests),
|
|
||||||
"X-RateLimit-Remaining": "0",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
|
||||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except redis.RedisError as e:
|
|
||||||
logger.warning(f"Rate limiter Redis error, using in-memory fallback: {e}")
|
|
||||||
fallback_key = f"ratelimit:{identity}:{path}"
|
|
||||||
allowed, remaining = self._fallback.check(fallback_key, limit.max_requests, limit.window_seconds)
|
|
||||||
if not allowed:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=429,
|
|
||||||
content={"detail": "Rate limit exceeded"},
|
|
||||||
headers={"X-RateLimit-Limit": str(limit.max_requests), "X-RateLimit-Remaining": "0"},
|
|
||||||
)
|
|
||||||
response = await call_next(request)
|
|
||||||
response.headers["X-RateLimit-Limit"] = str(limit.max_requests)
|
|
||||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
|
||||||
return response
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,14 @@
|
||||||
"""Scraper configuration with environment variable loading."""
|
"""Scraper configuration with environment variable loading."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
# Limit OCR threads to 25% of available cores to avoid starving other work.
|
||||||
|
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ScraperConfig:
|
class ScraperConfig:
|
||||||
|
|
|
||||||
|
|
@ -14,27 +14,30 @@ async def export_to_csv(
|
||||||
df = pd.DataFrame(ds)
|
df = pd.DataFrame(ds)
|
||||||
|
|
||||||
# read decisions on file
|
# read decisions on file
|
||||||
decisions_path = "data/decisions.json"
|
decisions_path = Path("data/decisions.json")
|
||||||
decisions = pd.read_json(decisions_path)
|
if decisions_path.exists():
|
||||||
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
|
decisions = pd.read_json(decisions_path)
|
||||||
|
df.loc[:, "decision"] = df.id.apply(lambda x: decisions.get(x))
|
||||||
|
|
||||||
# remove _sa_instance_state column
|
# remove _sa_instance_state column
|
||||||
drop_columns = ["_sa_instance_state", "additional_info"]
|
drop_columns = ["_sa_instance_state", "additional_info"]
|
||||||
df = df.drop(columns=drop_columns)
|
df = df.drop(columns=drop_columns)
|
||||||
|
|
||||||
# fill in gap values for service charge and lease left for Excel filters
|
# Ensure columns exist with NaN defaults for clean CSV output
|
||||||
if "service_charge" not in df.columns:
|
for col in ("service_charge", "lease_left", "square_meters"):
|
||||||
df.loc[:, "service_charge"] = -1
|
if col not in df.columns:
|
||||||
df.loc[:, "service_charge"] = df.service_charge.fillna(-1)
|
df.loc[:, col] = float("nan")
|
||||||
if "lease_left" not in df.columns:
|
|
||||||
df.loc[:, "lease_left"] = -1
|
|
||||||
df.loc[:, "lease_left"] = df.lease_left.fillna(-1)
|
|
||||||
if "square_meters" not in df.columns:
|
|
||||||
df.loc[:, "square_meters"] = -1
|
|
||||||
df.loc[:, "square_meters"] = df.square_meters.fillna(-1)
|
|
||||||
|
|
||||||
# Add price per sqm column
|
# Replace -1 sentinel values with NaN
|
||||||
df.loc[:, "price_per_sqm"] = df.price / df.square_meters
|
df.loc[:, "square_meters"] = df.square_meters.replace({-1: float("nan")})
|
||||||
|
|
||||||
|
# Add price per sqm column (guard against zero/missing square_meters)
|
||||||
|
df.loc[:, "price_per_sqm"] = df.apply(
|
||||||
|
lambda row: round(row.price / row.square_meters, 2)
|
||||||
|
if row.square_meters and row.square_meters > 0
|
||||||
|
else None,
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
|
||||||
df = df.sort_values(by=["price_per_sqm"], ascending=True)
|
df = df.sort_values(by=["price_per_sqm"], ascending=True)
|
||||||
df.to_csv(str(output_file), index=False)
|
df.to_csv(str(output_file), index=False)
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,6 @@ from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
# PostgreSQL example (or use "sqlite:///database.db" for SQLite)
|
|
||||||
# DATABASE_URL = "postgresql://user:password@localhost/db_name"
|
|
||||||
# DATABASE_URL = "sqlite:///data/wrongmove.db"
|
|
||||||
# DATABASE_URL = "mysql://wrongmove:wrongmove@localhost:3306/wrongmove"
|
|
||||||
DATABASE_URL = os.environ["DB_CONNECTION_STRING"]
|
DATABASE_URL = os.environ["DB_CONNECTION_STRING"]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,6 +14,6 @@ engine = create_engine(DATABASE_URL, echo=debug) # `echo=True` for debug logs
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
def init_db() -> None:
    """Create all tables registered on ``SQLModel.metadata`` via the module engine.

    Development convenience only; use migrations in production.
    """
    metadata = SQLModel.metadata
    metadata.create_all(engine)
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from abc import abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from config.scraper_config import MAX_OCR_WORKERS
|
||||||
from models.listing import (
|
from models.listing import (
|
||||||
BuyListing,
|
BuyListing,
|
||||||
FurnishType,
|
FurnishType,
|
||||||
|
|
@ -20,14 +20,12 @@ from models.listing import (
|
||||||
RentListing,
|
RentListing,
|
||||||
)
|
)
|
||||||
from rec import floorplan
|
from rec import floorplan
|
||||||
|
from rec.exceptions import FloorplanDownloadError
|
||||||
from rec.query import detail_query
|
from rec.query import detail_query
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
|
|
||||||
logger = logging.getLogger("uvicorn.error")
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
||||||
# Limit OCR threads to 25% of available cores to avoid starving other work.
|
|
||||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_furnish_type(raw: str | None) -> FurnishType:
|
def _parse_furnish_type(raw: str | None) -> FurnishType:
|
||||||
"""Normalise the raw furnish-type string from the API into a FurnishType enum."""
|
"""Normalise the raw furnish-type string from the API into a FurnishType enum."""
|
||||||
|
|
@ -97,13 +95,13 @@ class ListingProcessor:
|
||||||
step_class_name, step_class_name
|
step_class_name, step_class_name
|
||||||
)
|
)
|
||||||
on_step_complete(short_name)
|
on_step_complete(short_name)
|
||||||
except Exception as e:
|
except (ValueError, KeyError, aiohttp.ClientError, FloorplanDownloadError) as e:
|
||||||
logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
|
logger.error(f"[{listing_id}] {step_class_name} failed: {e}")
|
||||||
return None
|
return None
|
||||||
return listing
|
return listing
|
||||||
|
|
||||||
|
|
||||||
class Step:
|
class Step(ABC):
|
||||||
listing_repository: ListingRepository
|
listing_repository: ListingRepository
|
||||||
listing_type: ListingType
|
listing_type: ListingType
|
||||||
|
|
||||||
|
|
@ -123,29 +121,32 @@ class Step:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _find_tenure_value(details: dict[str, Any], tenure_type: str) -> str | None:
|
||||||
|
"""Find a value in the tenure info content by type key."""
|
||||||
|
tenure_content = details.get("property", {}).get("tenureInfo", {}).get("content", [])
|
||||||
|
for item in tenure_content:
|
||||||
|
if item.get("type") == tenure_type:
|
||||||
|
return item.get("value")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _parse_service_charge(details: dict[str, Any]) -> float | None:
|
def _parse_service_charge(details: dict[str, Any]) -> float | None:
|
||||||
"""Parse annual service charge from the tenure info in API response."""
|
"""Parse annual service charge from the tenure info in API response."""
|
||||||
tenure_content = (
|
value = _find_tenure_value(details, "annualServiceCharge")
|
||||||
details.get("property", {}).get("tenureInfo", {}).get("content", [])
|
if value is not None:
|
||||||
)
|
matches = re.findall(r"([\d,.]+)", str(value))
|
||||||
for item in tenure_content:
|
if matches:
|
||||||
if item.get("type") == "annualServiceCharge":
|
return float(matches[0].replace(",", ""))
|
||||||
matches = re.findall(r"([\d,.]+)", str(item.get("value", "")))
|
|
||||||
if matches:
|
|
||||||
return float(matches[0].replace(",", ""))
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _parse_lease_left(details: dict[str, Any]) -> int | None:
|
def _parse_lease_left(details: dict[str, Any]) -> int | None:
|
||||||
"""Parse remaining lease years from the tenure info in API response."""
|
"""Parse remaining lease years from the tenure info in API response."""
|
||||||
tenure_content = (
|
value = _find_tenure_value(details, "lengthOfLease")
|
||||||
details.get("property", {}).get("tenureInfo", {}).get("content", [])
|
if value is not None:
|
||||||
)
|
matches = re.findall(r"(\d+)", str(value))
|
||||||
for item in tenure_content:
|
if matches:
|
||||||
if item.get("type") == "lengthOfLease":
|
return int(matches[0])
|
||||||
matches = re.findall(r"(\d+)", str(item.get("value", "")))
|
|
||||||
if matches:
|
|
||||||
return int(matches[0])
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -265,7 +266,7 @@ class FetchImagesStep(Step):
|
||||||
if response.status == 404:
|
if response.status == 404:
|
||||||
return listing
|
return listing
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise Exception(f"Error for {url}: {response.status}")
|
raise FloorplanDownloadError(url, response.status)
|
||||||
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(floorplan_path, "wb") as f:
|
with open(floorplan_path, "wb") as f:
|
||||||
f.write(await response.read())
|
f.write(await response.read())
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@ class Slack(Surface):
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_notifier(surfaces: list[Surface] | None = None) -> apprise.Apprise:
|
def get_notifier() -> apprise.Apprise:
|
||||||
surfaces = surfaces or list[Surface]([Slack()])
|
surfaces = [Slack()]
|
||||||
obj = apprise.Apprise()
|
obj = apprise.Apprise()
|
||||||
for surface in surfaces:
|
for surface in surfaces:
|
||||||
if conn := surface.connection_string():
|
if conn := surface.connection_string():
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,15 @@ class CircuitBreakerOpenError(RightmoveAPIError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FloorplanDownloadError(Exception):
    """Raised when a floorplan image download fails.

    The exception message is ``"HTTP <status_code> downloading floorplan from
    <url>"``; the raw values stay available as attributes.
    """

    def __init__(self, url: str, status_code: int) -> None:
        """Record the failing URL and HTTP status and build the message.

        Args:
            url: The floorplan URL that was requested.
            status_code: The HTTP status returned for that request.
        """
        # url: the floorplan URL that was requested
        self.url = url
        # status_code: the HTTP status returned for that request
        self.status_code = status_code
        super().__init__(f"HTTP {status_code} downloading floorplan from {url}")

    def __reduce__(self) -> tuple[type, tuple[str, int]]:
        """Rebuild from ``(url, status_code)`` when pickled or copied.

        The default ``Exception`` reduction replays ``self.args``, which holds
        only the formatted message; that would call ``__init__`` with one
        argument and fail with ``TypeError``.
        """
        return (type(self), (self.url, self.status_code))
|
||||||
|
|
||||||
|
|
||||||
class RoutingApiError(Exception):
|
class RoutingApiError(Exception):
|
||||||
"""Error from the Google Routes API."""
|
"""Error from the Google Routes API."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from models.listing import FurnishType, ListingType
|
||||||
from rec import districts
|
from rec import districts
|
||||||
from rec.exceptions import (
|
from rec.exceptions import (
|
||||||
CircuitBreakerOpenError,
|
CircuitBreakerOpenError,
|
||||||
|
RightmoveAPIError,
|
||||||
ThrottlingError,
|
ThrottlingError,
|
||||||
)
|
)
|
||||||
from rec.throttle_detector import get_throttle_metrics, validate_response
|
from rec.throttle_detector import get_throttle_metrics, validate_response
|
||||||
|
|
@ -205,9 +206,9 @@ def _build_listing_params(
|
||||||
7,
|
7,
|
||||||
14,
|
14,
|
||||||
]:
|
]:
|
||||||
raise Exception(
|
raise ValueError(
|
||||||
f"Invalid max days - {max_days_since_added} Can only be got",
|
f"Invalid max_days_since_added={max_days_since_added}, "
|
||||||
[1, 3, 7, 14],
|
f"must be one of [1, 3, 7, 14]"
|
||||||
)
|
)
|
||||||
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
params["maxDaysSinceAdded"] = str(max_days_since_added)
|
||||||
|
|
||||||
|
|
@ -287,7 +288,7 @@ async def _execute_api_request(
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise Exception(
|
raise RightmoveAPIError(
|
||||||
f"{error_context}Failed due to: {await response.text()}"
|
f"{error_context}Failed due to: {await response.text()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,10 @@
|
||||||
"""Floorplan detector service - OCR-based square meter detection."""
|
"""Floorplan detector service - OCR-based square meter detection."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from config.scraper_config import MAX_OCR_WORKERS
|
||||||
from models import Listing
|
from models import Listing
|
||||||
from rec import floorplan
|
from rec import floorplan
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
import multiprocessing
|
|
||||||
|
|
||||||
# Use a quarter of available CPUs to avoid starving other processes
|
|
||||||
MAX_OCR_WORKERS = max(1, multiprocessing.cpu_count() // 4)
|
|
||||||
|
|
||||||
|
|
||||||
async def detect_floorplan(repository: ListingRepository) -> None:
|
async def detect_floorplan(repository: ListingRepository) -> None:
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from rec.exceptions import FloorplanDownloadError
|
||||||
from repositories import ListingRepository
|
from repositories import ListingRepository
|
||||||
from tenacity import retry, stop_after_attempt, wait_random
|
from tenacity import retry, stop_after_attempt, wait_random
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
|
|
@ -65,10 +66,7 @@ async def dump_images_for_listing(
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise Exception(
|
raise FloorplanDownloadError(url, response.status)
|
||||||
f"Error downloading floorplan for listing {listing.id} "
|
|
||||||
f"from {url}: HTTP {response.status}"
|
|
||||||
)
|
|
||||||
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
floorplan_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(floorplan_path, "wb") as f:
|
with open(floorplan_path, "wb") as f:
|
||||||
f.write(await response.read())
|
f.write(await response.read())
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import logging
|
||||||
from config.scraper_config import ScraperConfig
|
from config.scraper_config import ScraperConfig
|
||||||
from listing_processor import ListingProcessor
|
from listing_processor import ListingProcessor
|
||||||
from rec.query import create_session, listing_query
|
from rec.query import create_session, listing_query
|
||||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
|
||||||
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
||||||
from models.listing import Listing, QueryParameters
|
from models.listing import Listing, QueryParameters
|
||||||
from repositories import ListingRepository
|
from repositories import ListingRepository
|
||||||
|
|
@ -107,15 +107,15 @@ async def _fetch_subquery(
|
||||||
f"{sq.district}: {e}"
|
f"{sq.district}: {e}"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except InvalidResponseError:
|
||||||
# Rightmove returns GENERIC_ERROR when requesting pages
|
# Rightmove returns GENERIC_ERROR when requesting pages
|
||||||
# past the last page of results. This is expected behavior
|
# past the last page of results. This is expected behavior
|
||||||
# and signals we've exhausted this subquery's results.
|
# and signals we've exhausted this subquery's results.
|
||||||
if "GENERIC_ERROR" in str(e):
|
logger.debug(
|
||||||
logger.debug(
|
f"Max page for {sq.district}: {page_id - 1}"
|
||||||
f"Max page for {sq.district}: {page_id - 1}"
|
)
|
||||||
)
|
break
|
||||||
break
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Error fetching page {page_id} for "
|
f"Error fetching page {page_id} for "
|
||||||
f"{sq.district}: {e}"
|
f"{sq.district}: {e}"
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from config.scraper_config import ScraperConfig
|
||||||
from listing_processor import ListingProcessor
|
from listing_processor import ListingProcessor
|
||||||
from models.listing import Listing, QueryParameters
|
from models.listing import Listing, QueryParameters
|
||||||
from rec.query import create_session, listing_query
|
from rec.query import create_session, listing_query
|
||||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
|
||||||
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
from database import engine
|
from database import engine
|
||||||
|
|
@ -324,12 +324,12 @@ async def _fetch_subquery(
|
||||||
f"Throttling on {sq.district} page {page_id}: {e}"
|
f"Throttling on {sq.district} page {page_id}: {e}"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
except InvalidResponseError:
|
||||||
|
celery_logger.debug(
|
||||||
|
f"Max page for {sq.district}: {page_id - 1}"
|
||||||
|
)
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "GENERIC_ERROR" in str(e):
|
|
||||||
celery_logger.debug(
|
|
||||||
f"Max page for {sq.district}: {page_id - 1}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
celery_logger.warning(
|
celery_logger.warning(
|
||||||
f"Error fetching page {page_id} for "
|
f"Error fetching page {page_id} for "
|
||||||
f"{sq.district}: {e}"
|
f"{sq.district}: {e}"
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ async def test_step_failure_stops_pipeline(
|
||||||
processor = ListingProcessor(listing_repository)
|
processor = ListingProcessor(listing_repository)
|
||||||
|
|
||||||
processor.process_steps[0].needs_processing = AsyncMock(return_value=True)
|
processor.process_steps[0].needs_processing = AsyncMock(return_value=True)
|
||||||
processor.process_steps[0].process = AsyncMock(side_effect=RuntimeError("boom"))
|
processor.process_steps[0].process = AsyncMock(side_effect=ValueError("boom"))
|
||||||
processor.process_steps[1].needs_processing = AsyncMock(return_value=True)
|
processor.process_steps[1].needs_processing = AsyncMock(return_value=True)
|
||||||
processor.process_steps[1].process = AsyncMock()
|
processor.process_steps[1].process = AsyncMock()
|
||||||
processor.process_steps[2].needs_processing = AsyncMock(return_value=True)
|
processor.process_steps[2].needs_processing = AsyncMock(return_value=True)
|
||||||
|
|
|
||||||
|
|
@ -156,7 +156,7 @@ class TestStreamingEndpoint:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(self):
|
def client(self):
|
||||||
"""Create test client with mocked auth."""
|
"""Create test client with mocked auth and rate limiting bypassed."""
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from api.app import app
|
from api.app import app
|
||||||
from api.auth import get_current_user, User
|
from api.auth import get_current_user, User
|
||||||
|
|
@ -165,13 +165,15 @@ class TestStreamingEndpoint:
|
||||||
return User(sub="test-id", email="test@example.com", name="Test User")
|
return User(sub="test-id", email="test@example.com", name="Test User")
|
||||||
|
|
||||||
app.dependency_overrides[get_current_user] = mock_auth
|
app.dependency_overrides[get_current_user] = mock_auth
|
||||||
yield TestClient(app)
|
with patch("api.rate_limiter._match_endpoint", return_value=None):
|
||||||
|
yield TestClient(app)
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_repository(self):
|
def mock_repository(self):
|
||||||
"""Mock the repository methods."""
|
"""Mock the repository methods and bypass cache."""
|
||||||
with patch("api.app.ListingRepository") as MockRepo:
|
with patch("api.app.get_cached_count", return_value=None), \
|
||||||
|
patch("api.app.ListingRepository") as MockRepo:
|
||||||
mock_instance = MagicMock()
|
mock_instance = MagicMock()
|
||||||
mock_instance.count_listings.return_value = 3
|
mock_instance.count_listings.return_value = 3
|
||||||
mock_instance.stream_listings_optimized.return_value = iter([
|
mock_instance.stream_listings_optimized.return_value = iter([
|
||||||
|
|
|
||||||
|
|
@ -204,8 +204,12 @@ class TestInvalidateCache:
|
||||||
mock_client = mock.MagicMock()
|
mock_client = mock.MagicMock()
|
||||||
mock_pipeline = mock.MagicMock()
|
mock_pipeline = mock.MagicMock()
|
||||||
mock_client.pipeline.return_value = mock_pipeline
|
mock_client.pipeline.return_value = mock_pipeline
|
||||||
# Simulate one scan iteration that returns keys, then done
|
# invalidate_cache scans two patterns (CACHE_PREFIX*, STAGING_PREFIX*)
|
||||||
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
|
# First scan returns matching keys, second returns none
|
||||||
|
mock_client.scan.side_effect = [
|
||||||
|
(0, ["listings:geojson:abc", "listings:geojson:def"]),
|
||||||
|
(0, []),
|
||||||
|
]
|
||||||
mock_get_client.return_value = mock_client
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
invalidate_cache()
|
invalidate_cache()
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from models.listing import ListingType, QueryParameters
|
from models.listing import ListingType, QueryParameters
|
||||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
from rec.exceptions import CircuitBreakerOpenError, InvalidResponseError, ThrottlingError
|
||||||
from services.listing_fetcher import (
|
from services.listing_fetcher import (
|
||||||
NUM_WORKERS,
|
NUM_WORKERS,
|
||||||
_fetch_subquery,
|
_fetch_subquery,
|
||||||
|
|
@ -227,7 +227,7 @@ class TestFetchSubquery:
|
||||||
with patch(
|
with patch(
|
||||||
"services.listing_fetcher.listing_query",
|
"services.listing_fetcher.listing_query",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=Exception("GENERIC_ERROR: no more results"),
|
side_effect=InvalidResponseError("GENERIC_ERROR: no more results"),
|
||||||
):
|
):
|
||||||
ids_found = await _fetch_subquery(
|
ids_found = await _fetch_subquery(
|
||||||
sq=sq,
|
sq=sq,
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@ from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
import pytest
|
import pytest
|
||||||
from models.listing import FurnishType, ListingType
|
from models.listing import FurnishType, ListingType
|
||||||
|
from config.scraper_config import MAX_OCR_WORKERS
|
||||||
from listing_processor import (
|
from listing_processor import (
|
||||||
_parse_furnish_type,
|
_parse_furnish_type,
|
||||||
_parse_available_from,
|
_parse_available_from,
|
||||||
ListingProcessor,
|
ListingProcessor,
|
||||||
FetchListingDetailsStep,
|
FetchListingDetailsStep,
|
||||||
MAX_OCR_WORKERS,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -77,7 +77,7 @@ class TestListingProcessor:
|
||||||
processor = ListingProcessor(mock_repo)
|
processor = ListingProcessor(mock_repo)
|
||||||
for step in processor.process_steps:
|
for step in processor.process_steps:
|
||||||
step.needs_processing = AsyncMock(return_value=True)
|
step.needs_processing = AsyncMock(return_value=True)
|
||||||
step.process = AsyncMock(side_effect=Exception("fail"))
|
step.process = AsyncMock(side_effect=ValueError("fail"))
|
||||||
result = await processor.process_listing(123)
|
result = await processor.process_listing(123)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,7 @@ def convert_to_geojson_feature(listing: RentListing | BuyListing) -> dict[str, A
|
||||||
|
|
||||||
properties: dict[str, Any] = {
|
properties: dict[str, Any] = {
|
||||||
"listing_type": listing_type,
|
"listing_type": listing_type,
|
||||||
"city": "London", # change me
|
"city": "London",
|
||||||
"country": "United Kingdom",
|
"country": "United Kingdom",
|
||||||
"qm": listing.square_meters,
|
"qm": listing.square_meters,
|
||||||
"qmprice": listing.price_per_square_meter,
|
"qmprice": listing.price_per_square_meter,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue