Refactor codebase following Clean Code principles and add 229 tests

- Extract helpers to reduce function sizes (listing_tasks, app.py, query.py, listing_fetcher)
  - Replace nonlocal mutations with _PipelineState dataclass in listing_tasks
  - Fix bugs: isinstance→equality check in repository, verify_exp for OIDC tokens
  - Consolidate duplicate filter methods in listing_repository
  - Move hardcoded config to env vars with backward-compatible defaults
  - Simplify CLI decorator to auto-build QueryParameters
  - Add deprecation docstring to data_access.py
  - Test count: 158 → 387 (all passing)
This commit is contained in:
Viktor Barzin 2026-02-07 20:19:57 +00:00
parent 7e05b3c971
commit 150342bb9e
No known key found for this signature in database
GPG key ID: 0EB088298288D958
48 changed files with 5029 additions and 990 deletions

View file

@ -0,0 +1,151 @@
"""Unit tests for api/auth.py."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch, MagicMock
import jwt as pyjwt
import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from api.auth import (
User,
_verify_passkey_token,
_verify_authentik_token,
get_current_user,
)
from api.config import JWT_SECRET, JWT_ALGORITHM, JWT_ISSUER
def _make_passkey_token(
sub: str = "user-123",
email: str = "test@example.com",
name: str = "Test User",
issuer: str = JWT_ISSUER,
secret: str = JWT_SECRET,
algorithm: str = JWT_ALGORITHM,
expires_delta: timedelta | None = timedelta(hours=1),
) -> str:
"""Helper to mint a passkey-style HS256 JWT."""
payload: dict = {"sub": sub, "email": email, "name": name, "iss": issuer}
if expires_delta is not None:
payload["exp"] = datetime.now(timezone.utc) + expires_delta
return pyjwt.encode(payload, secret, algorithm=algorithm)
class TestVerifyPasskeyToken:
"""Tests for _verify_passkey_token()."""
def test_valid_token_returns_user(self) -> None:
token = _make_passkey_token()
user = _verify_passkey_token(token)
assert isinstance(user, User)
assert user.sub == "user-123"
assert user.email == "test@example.com"
assert user.name == "Test User"
def test_valid_token_without_name_uses_email(self) -> None:
payload = {
"sub": "user-456",
"email": "noname@example.com",
"iss": JWT_ISSUER,
"exp": datetime.now(timezone.utc) + timedelta(hours=1),
}
token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
user = _verify_passkey_token(token)
assert user.name == "noname@example.com"
def test_rejects_expired_token(self) -> None:
token = _make_passkey_token(expires_delta=timedelta(hours=-1))
with pytest.raises(pyjwt.ExpiredSignatureError):
_verify_passkey_token(token)
def test_rejects_wrong_secret(self) -> None:
token = _make_passkey_token(secret="wrong-secret")
with pytest.raises(pyjwt.InvalidSignatureError):
_verify_passkey_token(token)
def test_rejects_wrong_issuer(self) -> None:
token = _make_passkey_token(issuer="some-other-issuer")
with pytest.raises(pyjwt.InvalidIssuerError):
_verify_passkey_token(token)
class TestVerifyAuthentikToken:
"""Tests for _verify_authentik_token() — specifically that expiration is verified."""
async def test_verifies_expiration_after_fix(self) -> None:
"""After removing verify_exp: False, expired Authentik tokens should be rejected."""
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
payload = {
"sub": "authentik-user",
"email": "auth@example.com",
"name": "Auth User",
"iss": issuer,
"aud": "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe",
"exp": datetime.now(timezone.utc) - timedelta(hours=1), # expired
}
token = pyjwt.encode(payload, private_key, algorithm="RS256")
# Build a real PyJWK-compatible signing key mock so jwt.decode
# takes the PyJWK code path (uses key.key directly, skips prepare_key)
mock_signing_key = MagicMock(spec=pyjwt.PyJWK)
mock_signing_key.key = public_key
mock_signing_key.algorithm_name = "RS256"
mock_signing_key.Algorithm = pyjwt.get_algorithm_by_name("RS256")
mock_jwks_client = MagicMock()
mock_jwks_client.get_signing_key_from_jwt.return_value = mock_signing_key
mock_metadata = {
"issuer": issuer,
"jwks_uri": f"{issuer}jwks/",
}
with patch("api.auth.get_oidc_metadata", new_callable=AsyncMock, return_value=mock_metadata), \
patch("api.auth.get_cached_jwks_client", new_callable=AsyncMock, return_value=mock_jwks_client):
with pytest.raises(pyjwt.ExpiredSignatureError):
await _verify_authentik_token(token)
class TestGetCurrentUser:
"""Tests for get_current_user()."""
async def test_routes_to_passkey_verifier_for_matching_issuer(self) -> None:
token = _make_passkey_token()
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
user = await get_current_user(credentials)
assert user.sub == "user-123"
assert user.email == "test@example.com"
async def test_routes_to_authentik_for_other_issuer(self) -> None:
"""When issuer != JWT_ISSUER, should route to Authentik verifier."""
token = _make_passkey_token(issuer="https://authentik.viktorbarzin.me/application/o/wrongmove/")
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
mock_user = User(sub="authentik-user", email="auth@example.com", name="Auth User")
with patch("api.auth._verify_authentik_token", new_callable=AsyncMock, return_value=mock_user):
user = await get_current_user(credentials)
assert user.email == "auth@example.com"
async def test_raises_http_exception_for_invalid_token(self) -> None:
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="not.a.valid.token")
with pytest.raises(HTTPException) as exc_info:
await get_current_user(credentials)
assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail
async def test_raises_http_exception_for_garbage_token(self) -> None:
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="totalgarbage")
with pytest.raises(HTTPException) as exc_info:
await get_current_user(credentials)
assert exc_info.value.status_code == 401

View file

@ -0,0 +1,388 @@
"""Characterization and unit tests for the CLI (main.py)."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import click
import pytest
from click.testing import CliRunner
from models.listing import FurnishType, ListingType, QueryParameters
from main import build_query_parameters, cli, listing_filter_options
class TestBuildQueryParameters:
"""Tests for build_query_parameters()."""
def test_typical_rent_inputs(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London", "Camden"],
min_bedrooms=2,
max_bedrooms=4,
min_price=1000,
max_price=3000,
furnish_types=["FURNISHED"],
available_from=datetime(2025, 6, 1),
last_seen_days=7,
min_sqm=50,
)
assert qp.listing_type == ListingType.RENT
assert qp.district_names == {"London", "Camden"}
assert qp.min_bedrooms == 2
assert qp.max_bedrooms == 4
assert qp.min_price == 1000
assert qp.max_price == 3000
assert qp.furnish_types == [FurnishType.FURNISHED]
assert qp.let_date_available_from == datetime(2025, 6, 1)
assert qp.last_seen_days == 7
assert qp.min_sqm == 50
def test_typical_buy_inputs(self) -> None:
qp = build_query_parameters(
type="BUY",
district=["Barnet"],
min_bedrooms=3,
max_bedrooms=5,
min_price=200000,
max_price=500000,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.listing_type == ListingType.BUY
assert qp.district_names == {"Barnet"}
assert qp.furnish_types is None
assert qp.let_date_available_from is None
assert qp.min_sqm is None
def test_empty_districts_yields_empty_set(self) -> None:
qp = build_query_parameters(
type="RENT",
district=[],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.district_names == set()
def test_none_districts_yields_empty_set(self) -> None:
qp = build_query_parameters(
type="RENT",
district=None,
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.district_names == set()
def test_furnish_types_conversion(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London"],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=["FURNISHED", "UNFURNISHED"],
available_from=None,
last_seen_days=14,
)
assert qp.furnish_types == [FurnishType.FURNISHED, FurnishType.UNFURNISHED]
def test_empty_furnish_types_yields_none(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London"],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.furnish_types is None
def test_default_optional_parameters(self) -> None:
qp = build_query_parameters(
type="RENT",
district=["London"],
min_bedrooms=1,
max_bedrooms=10,
min_price=0,
max_price=999999,
furnish_types=[],
available_from=None,
last_seen_days=14,
)
assert qp.radius == 0
assert qp.page_size == 500
assert qp.max_days_since_added == 14
class TestListingFilterOptionsDecorator:
"""Tests for the listing_filter_options decorator."""
def test_applies_all_expected_options(self) -> None:
@click.command()
@listing_filter_options
def dummy_cmd(**kwargs: object) -> None:
pass
expected_option_names = {
"type",
"min_bedrooms",
"max_bedrooms",
"min_price",
"max_price",
"district",
"furnish_types",
"available_from",
"last_seen_days",
"min_sqm",
}
param_names = {p.name for p in dummy_cmd.params}
assert expected_option_names.issubset(param_names), (
f"Missing options: {expected_option_names - param_names}"
)
def test_type_option_is_required(self) -> None:
@click.command()
@listing_filter_options
def dummy_cmd(**kwargs: object) -> None:
pass
type_param = next(p for p in dummy_cmd.params if p.name == "type")
assert type_param.required is True
def test_produces_query_parameters_kwarg(self) -> None:
"""After refactoring, the decorator should produce a query_parameters kwarg."""
captured: dict = {}
@click.command()
@listing_filter_options
def dummy_cmd(query_parameters: QueryParameters) -> None:
captured["qp"] = query_parameters
runner = CliRunner()
result = runner.invoke(dummy_cmd, ["--type", "RENT"])
assert result.exit_code == 0, f"Command failed: {result.output}"
assert isinstance(captured["qp"], QueryParameters)
assert captured["qp"].listing_type == ListingType.RENT
class TestDumpListingsCommand:
"""Tests for the dump-listings CLI command."""
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_calls_refresh_listings_with_correct_params(
self,
mock_engine: MagicMock,
mock_refresh: AsyncMock,
) -> None:
from services.listing_service import RefreshResult
mock_refresh.return_value = RefreshResult(
task_id=None,
new_listings_count=5,
message="Fetched 5 new listings",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"dump-listings",
"--type", "RENT",
"--min-bedrooms", "2",
"--max-bedrooms", "4",
"--min-price", "1000",
"--max-price", "3000",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_refresh.assert_called_once()
call_args = mock_refresh.call_args
qp: QueryParameters = call_args.args[1]
assert qp.listing_type == ListingType.RENT
assert qp.min_bedrooms == 2
assert qp.max_bedrooms == 4
assert qp.min_price == 1000
assert qp.max_price == 3000
assert call_args.kwargs.get("full") is not True
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_include_processing_flag_passes_full_true(
self,
mock_engine: MagicMock,
mock_refresh: AsyncMock,
) -> None:
from services.listing_service import RefreshResult
mock_refresh.return_value = RefreshResult(
task_id=None,
new_listings_count=0,
message="Fetched 0 new listings",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"dump-listings",
"--type", "RENT",
"--include-processing",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_refresh.assert_called_once()
call_kwargs = mock_refresh.call_args.kwargs
assert call_kwargs.get("full") is True
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_include_processing_short_flag(
self,
mock_engine: MagicMock,
mock_refresh: AsyncMock,
) -> None:
from services.listing_service import RefreshResult
mock_refresh.return_value = RefreshResult(
task_id=None,
new_listings_count=0,
message="Fetched 0 new listings",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"dump-listings",
"--type", "RENT",
"-p",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_refresh.assert_called_once()
call_kwargs = mock_refresh.call_args.kwargs
assert call_kwargs.get("full") is True
class TestExportCsvCommand:
"""Tests for the export-csv CLI command."""
@patch("main.export_service.export_to_csv", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_calls_export_to_csv(
self,
mock_engine: MagicMock,
mock_export: AsyncMock,
) -> None:
from services.export_service import ExportResult
mock_export.return_value = ExportResult(
success=True,
output_path="/tmp/test.csv",
data=None,
record_count=10,
message="Exported 10 listings to /tmp/test.csv",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"export-csv",
"--output-file", "/tmp/test.csv",
"--type", "RENT",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_export.assert_called_once()
call_args = mock_export.call_args
qp = call_args[0][2]
assert qp.listing_type == ListingType.RENT
class TestExportImmowebCommand:
"""Tests for the export-immoweb CLI command."""
@patch("main.export_service.export_to_geojson", new_callable=AsyncMock)
@patch("main.engine", new_callable=MagicMock)
def test_calls_export_to_geojson(
self,
mock_engine: MagicMock,
mock_export: AsyncMock,
) -> None:
from services.export_service import ExportResult
mock_export.return_value = ExportResult(
success=True,
output_path="/tmp/test.geojson",
data=None,
record_count=5,
message="Exported 5 listings to /tmp/test.geojson",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"export-immoweb",
"--output-file", "/tmp/test.geojson",
"--type", "RENT",
],
)
assert result.exit_code == 0, f"CLI failed: {result.output}"
mock_export.assert_called_once()
class TestListDistrictsCommand:
"""Tests for the list-districts CLI command."""
@patch("main.engine", new_callable=MagicMock)
def test_outputs_district_names(self, mock_engine: MagicMock) -> None:
runner = CliRunner()
result = runner.invoke(cli, ["list-districts"])
assert result.exit_code == 0
assert "London" in result.output
assert "Camden" in result.output
assert "Available districts" in result.output
class TestRoutingCommand:
"""Tests for the routing CLI command."""
@patch("main.engine", new_callable=MagicMock)
def test_requires_api_key_env_var(self, mock_engine: MagicMock) -> None:
runner = CliRunner(env={"ROUTING_API_KEY": None})
result = runner.invoke(
cli,
[
"routing",
"--destination-address", "London Bridge",
"--travel-mode", "transit",
"--limit", "1",
],
catch_exceptions=False,
)
assert result.exit_code != 0
assert "ROUTING_API_KEY" in result.output

View file

@ -0,0 +1,62 @@
"""Unit tests for rec/districts.py and services/district_service.py."""
from rec.districts import get_districts, get_district_by_name
from services.district_service import get_all_districts, get_district_names, validate_districts
class TestGetDistricts:
def test_returns_non_empty_dict(self) -> None:
districts = get_districts()
assert isinstance(districts, dict)
assert len(districts) > 0
def test_values_start_with_region_prefix(self) -> None:
for name, region_id in get_districts().items():
assert region_id.startswith("REGION^"), (
f"District '{name}' has value '{region_id}' that doesn't start with REGION^"
)
def test_contains_expected_london_boroughs(self) -> None:
districts = get_districts()
for borough in ("Camden", "Westminster", "Hackney"):
assert borough in districts, f"Expected borough '{borough}' not found"
class TestGetDistrictByName:
def test_valid_name_returns_region_id(self) -> None:
result = get_district_by_name("Camden")
assert result == "REGION^93941"
def test_invalid_name_returns_none(self) -> None:
result = get_district_by_name("Nonexistent District")
assert result is None
class TestGetDistrictNames:
def test_returns_list_matching_dict_keys(self) -> None:
names = get_district_names()
assert isinstance(names, list)
assert names == list(get_districts().keys())
class TestGetAllDistricts:
def test_returns_same_as_get_districts(self) -> None:
assert get_all_districts() == get_districts()
class TestValidateDistricts:
def test_all_valid_returns_empty_list(self) -> None:
result = validate_districts(["Camden", "Westminster", "Hackney"])
assert result == []
def test_some_invalid_returns_invalid_ones(self) -> None:
result = validate_districts(["Camden", "Faketown", "Westminster", "Nowhere"])
assert result == ["Faketown", "Nowhere"]
def test_all_invalid_returns_all(self) -> None:
invalid = ["Faketown", "Nowhere", "Neverland"]
result = validate_districts(invalid)
assert result == invalid
def test_empty_list_returns_empty_list(self) -> None:
result = validate_districts([])
assert result == []

View file

@ -0,0 +1,104 @@
"""Unit tests for rec/floorplan.py."""
from unittest.mock import patch
import numpy as np
from PIL import Image
import pytest
from rec.floorplan import extract_total_sqm, improve_img_for_ocr, calculate_ocr
class TestExtractTotalSqm:
def test_normal_value(self) -> None:
assert extract_total_sqm("Total area: 75.5 sq m") == 75.5
def test_multiple_values_returns_max_in_range(self) -> None:
assert extract_total_sqm("Room 1: 20 sqm, Total: 65 sq m") == 65.0
def test_no_match_returns_none(self) -> None:
assert extract_total_sqm("No area info") is None
def test_below_minimum_returns_none(self) -> None:
assert extract_total_sqm("Area: 15 sq m") is None
def test_above_maximum_returns_none(self) -> None:
assert extract_total_sqm("Area: 200 sq m") is None
def test_edge_just_above_min(self) -> None:
assert extract_total_sqm("Area: 30.1 sq m") == 30.1
def test_edge_just_below_max(self) -> None:
assert extract_total_sqm("Area: 159.9 sq m") == 159.9
def test_exactly_at_min_boundary_returns_none(self) -> None:
# MIN_SQM < sqm, so 30 is not strictly greater than 30
assert extract_total_sqm("Area: 30 sq m") is None
def test_exactly_at_max_boundary_returns_none(self) -> None:
# sqm < MAX_SQM, so 160 is not strictly less than 160
assert extract_total_sqm("Area: 160 sq m") is None
def test_format_sq_dot_m(self) -> None:
assert extract_total_sqm("Area: 80 sq. m") == 80.0
def test_format_sqm_no_space(self) -> None:
assert extract_total_sqm("Area: 80sqm") == 80.0
def test_format_sq_m_with_space(self) -> None:
assert extract_total_sqm("Area: 80 sq m") == 80.0
def test_empty_string(self) -> None:
assert extract_total_sqm("") is None
def test_multiple_valid_values_returns_max(self) -> None:
assert extract_total_sqm("Living: 40 sq m, Total: 100 sq m") == 100.0
class TestImproveImgForOcr:
def test_produces_valid_pil_image(self) -> None:
# Create a small test image (50x50 white image)
img = Image.fromarray(np.ones((50, 50, 3), dtype=np.uint8) * 200)
result = improve_img_for_ocr(img)
assert isinstance(result, Image.Image)
# Result should be a grayscale (thresholded) image
assert result.mode == "L"
def test_output_dimensions_scaled(self) -> None:
img = Image.fromarray(np.ones((100, 100, 3), dtype=np.uint8) * 128)
result = improve_img_for_ocr(img)
# After 1.2x resize, 100 -> 120
assert result.size[0] == 120
assert result.size[1] == 120
class TestCalculateOcr:
def test_invalid_path_raises_file_not_found(self) -> None:
with pytest.raises(FileNotFoundError):
calculate_ocr("/nonexistent/path/to/image.png")
def test_returns_sqm_from_first_pass(self, tmp_path) -> None: # type: ignore[no-untyped-def]
# Create a real image file so the path check passes
image_file = tmp_path / "test.png"
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
with patch("pytesseract.image_to_string", return_value="Total: 85 sq m"):
result_sqm, result_text = calculate_ocr(str(image_file))
assert result_sqm == 85.0
assert result_text == "Total: 85 sq m"
def test_falls_back_to_improved_image(self, tmp_path) -> None: # type: ignore[no-untyped-def]
image_file = tmp_path / "test.png"
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
# First call returns no sqm data, second (on improved image) returns valid data
with patch("pytesseract.image_to_string", side_effect=[
"No area info here",
"Total: 72 sq m",
]):
result_sqm, result_text = calculate_ocr(str(image_file))
assert result_sqm == 72.0
assert result_text == "Total: 72 sq m"

View file

@ -0,0 +1,110 @@
"""Unit tests for services/floorplan_detector.py."""
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from models.listing import RentListing, ListingSite, FurnishType
from services.floorplan_detector import _calculate_sqm_ocr, detect_floorplan
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
defaults = dict(
id=1,
price=2000.0,
number_of_bedrooms=2,
square_meters=None,
agency="Test",
council_tax_band="C",
longitude=0.0,
latitude=0.0,
price_history_json="[]",
listing_site=ListingSite.RIGHTMOVE,
last_seen=datetime.now(),
photo_thumbnail=None,
floorplan_image_paths=[],
additional_info={"property": {"visible": True}},
routing_info_json=None,
furnish_type=FurnishType.FURNISHED,
available_from=None,
)
defaults.update(kwargs)
return RentListing(**defaults)
class TestCalculateSqmOcr:
async def test_skips_listing_with_existing_square_meters(self) -> None:
listing = _make_listing(square_meters=50.0)
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is None
async def test_empty_floorplan_paths_returns_listing_with_zero(self) -> None:
listing = _make_listing(floorplan_image_paths=[])
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 0
@patch("services.floorplan_detector.floorplan")
async def test_with_mocked_ocr_returning_value(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.return_value = (85.0, "Total: 85 sq m")
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 85.0
@patch("services.floorplan_detector.floorplan")
async def test_with_mocked_ocr_returning_none(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.return_value = (None, "no data")
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
semaphore = asyncio.Semaphore(1)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 0
@patch("services.floorplan_detector.floorplan")
async def test_picks_max_from_multiple_floorplans(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.side_effect = [
(50.0, "50 sq m"),
(90.0, "90 sq m"),
]
listing = _make_listing(floorplan_image_paths=["/fake/a.png", "/fake/b.png"])
semaphore = asyncio.Semaphore(2)
result = await _calculate_sqm_ocr(listing, semaphore)
assert result is not None
assert result.square_meters == 90.0
class TestDetectFloorplan:
@patch("services.floorplan_detector.floorplan")
async def test_detect_floorplan_with_mocked_repository(self, mock_floorplan: MagicMock) -> None:
mock_floorplan.calculate_ocr.return_value = (75.0, "75 sq m")
listing = _make_listing(
floorplan_image_paths=["/fake/path.png"],
)
repository = MagicMock()
repository.get_listings = AsyncMock(return_value=[listing])
repository.upsert_listings = AsyncMock(return_value=[])
await detect_floorplan(repository)
repository.upsert_listings.assert_called_once()
upserted = repository.upsert_listings.call_args[0][0]
assert len(upserted) == 1
assert upserted[0].square_meters == 75.0
async def test_detect_floorplan_skips_already_processed(self) -> None:
listing = _make_listing(square_meters=50.0)
repository = MagicMock()
repository.get_listings = AsyncMock(return_value=[listing])
repository.upsert_listings = AsyncMock(return_value=[])
await detect_floorplan(repository)
repository.upsert_listings.assert_called_once()
upserted = repository.upsert_listings.call_args[0][0]
assert len(upserted) == 0

View file

@ -0,0 +1,215 @@
"""Unit tests for the image fetcher service."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
import aiohttp
import pytest
from tenacity import stop_after_attempt
from models.listing import RentListing, ListingSite, FurnishType
from services.image_fetcher import dump_images_for_listing, MAX_CONCURRENT_DOWNLOADS
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
"""Create a RentListing with sensible defaults for testing."""
defaults = dict(
id=12345,
price=2000.0,
number_of_bedrooms=2,
square_meters=None,
agency="Test Agency",
council_tax_band="C",
longitude=0.0,
latitude=0.0,
price_history_json="[]",
listing_site=ListingSite.RIGHTMOVE,
last_seen=datetime.now(),
photo_thumbnail=None,
floorplan_image_paths=[],
additional_info={
"property": {
"visible": True,
"floorplans": [
{"url": "https://media.rightmove.co.uk/imgs/floorplan_1.jpg"}
],
}
},
routing_info_json=None,
furnish_type=FurnishType.FURNISHED,
available_from=None,
)
defaults.update(kwargs)
return RentListing(**defaults)
class TestDumpImagesForListing:
"""Tests for dump_images_for_listing function."""
async def test_downloads_floorplan_image(self, tmp_path: Path) -> None:
"""Test successful floorplan image download."""
listing = _make_listing()
image_bytes = b"\x89PNG fake image data"
mock_response = AsyncMock()
mock_response.status = 200
mock_response.read = AsyncMock(return_value=image_bytes)
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is not None
assert result.id == 12345
assert len(result.floorplan_image_paths) == 1
# Verify the image was written
written_path = Path(result.floorplan_image_paths[0])
assert written_path.exists()
assert written_path.read_bytes() == image_bytes
async def test_skips_existing_images(self, tmp_path: Path) -> None:
"""Test that existing images are not re-downloaded."""
listing = _make_listing()
# Pre-create the floorplan file
floorplan_dir = tmp_path / str(listing.id) / "floorplans"
floorplan_dir.mkdir(parents=True)
existing_file = floorplan_dir / "floorplan_1.jpg"
existing_file.write_bytes(b"existing image")
mock_session = MagicMock(spec=aiohttp.ClientSession)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
# Should return None because the only floorplan was skipped (continue)
assert result is None
# Session.get should NOT have been called
mock_session.get.assert_not_called()
async def test_returns_none_on_404(self, tmp_path: Path) -> None:
"""Test that 404 responses return None (image not found)."""
listing = _make_listing()
mock_response = AsyncMock()
mock_response.status = 404
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is None
async def test_raises_on_non_200_status(self, tmp_path: Path) -> None:
"""Test that non-200/404 status raises exception."""
listing = _make_listing()
mock_response = AsyncMock()
mock_response.status = 500
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
with pytest.raises(Exception, match="HTTP 500"):
# Disable tenacity retry for testing: stop after 1 attempt and reraise
await dump_images_for_listing.retry_with(
stop=stop_after_attempt(1),
reraise=True,
)(listing, tmp_path, session=mock_session)
async def test_returns_none_when_no_floorplans(self, tmp_path: Path) -> None:
"""Test listing with no floorplans returns None."""
listing = _make_listing(
additional_info={"property": {"visible": True, "floorplans": []}}
)
mock_session = MagicMock(spec=aiohttp.ClientSession)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is None
async def test_url_filename_extraction(self, tmp_path: Path) -> None:
"""Test that filenames are correctly extracted from URLs."""
listing = _make_listing(
additional_info={
"property": {
"visible": True,
"floorplans": [
{
"url": "https://media.rightmove.co.uk/dir/sub/my_floorplan.png"
}
],
}
}
)
image_bytes = b"fake png"
mock_response = AsyncMock()
mock_response.status = 200
mock_response.read = AsyncMock(return_value=image_bytes)
mock_session = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session.get = MagicMock(return_value=mock_cm)
result = await dump_images_for_listing(
listing, tmp_path, session=mock_session
)
assert result is not None
written_path = Path(result.floorplan_image_paths[0])
assert written_path.name == "my_floorplan.png"
async def test_creates_session_when_none_provided(self, tmp_path: Path) -> None:
"""Test that a session is created and closed when none is provided."""
listing = _make_listing()
image_bytes = b"fake image"
mock_response = AsyncMock()
mock_response.status = 200
mock_response.read = AsyncMock(return_value=image_bytes)
mock_session_instance = MagicMock(spec=aiohttp.ClientSession)
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session_instance.get = MagicMock(return_value=mock_cm)
mock_session_instance.close = AsyncMock()
with patch(
"services.image_fetcher.aiohttp.ClientSession",
return_value=mock_session_instance,
):
result = await dump_images_for_listing(listing, tmp_path, session=None)
assert result is not None
mock_session_instance.close.assert_awaited_once()
class TestImageFetcherConfig:
"""Tests for image fetcher configuration."""
def test_max_concurrent_downloads_constant(self) -> None:
"""Test that MAX_CONCURRENT_DOWNLOADS is defined and reasonable."""
assert MAX_CONCURRENT_DOWNLOADS > 0
assert MAX_CONCURRENT_DOWNLOADS <= 20

View file

@ -0,0 +1,225 @@
"""Unit tests for services/listing_cache.py."""
import json
from unittest import mock
import pytest
import redis
from models.listing import ListingType, QueryParameters
from services.listing_cache import (
CACHE_PREFIX,
_get_redis_client,
cache_features_batch,
get_cached_count,
get_cached_features,
invalidate_cache,
make_cache_key,
)
def _make_query(**overrides) -> QueryParameters:
"""Create a QueryParameters with defaults for testing."""
defaults = {"listing_type": ListingType.RENT, "min_price": 1000, "max_price": 3000}
defaults.update(overrides)
return QueryParameters(**defaults)
class TestMakeCacheKey:
"""Tests for make_cache_key()."""
def test_deterministic_for_same_params(self):
"""Same parameters produce the same cache key."""
qp = _make_query()
assert make_cache_key(qp) == make_cache_key(qp)
def test_different_for_different_params(self):
"""Different parameters produce different cache keys."""
qp1 = _make_query(min_price=1000)
qp2 = _make_query(min_price=2000)
assert make_cache_key(qp1) != make_cache_key(qp2)
def test_key_starts_with_prefix(self):
"""Cache key starts with CACHE_PREFIX."""
qp = _make_query()
assert make_cache_key(qp).startswith(CACHE_PREFIX)
class TestGetRedisClient:
"""Tests for _get_redis_client() URL parsing."""
@mock.patch("services.listing_cache.redis")
def test_default_broker_url(self, mock_redis):
"""Uses default localhost URL when env var is not set."""
with mock.patch.dict("os.environ", {}, clear=True):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://localhost:6379/2", decode_responses=True
)
@mock.patch("services.listing_cache.redis")
def test_custom_broker_url(self, mock_redis):
"""Replaces db number from custom broker URL."""
with mock.patch.dict(
"os.environ", {"CELERY_BROKER_URL": "redis://myhost:1234/5"}
):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://myhost:1234/2", decode_responses=True
)
@mock.patch("services.listing_cache.redis")
def test_broker_url_with_password(self, mock_redis):
"""Preserves auth info in broker URL."""
with mock.patch.dict(
"os.environ",
{"CELERY_BROKER_URL": "redis://:secret@myhost:6379/0"},
):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://:secret@myhost:6379/2", decode_responses=True
)
@mock.patch("services.listing_cache.redis")
def test_broker_url_with_query_params(self, mock_redis):
"""Preserves query parameters in broker URL."""
with mock.patch.dict(
"os.environ",
{"CELERY_BROKER_URL": "redis://myhost:6379/0?timeout=5"},
):
_get_redis_client()
mock_redis.from_url.assert_called_once_with(
"redis://myhost:6379/2?timeout=5", decode_responses=True
)
class TestGetCachedCount:
"""Tests for get_cached_count()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_returns_none_on_redis_error(self, mock_get_client):
"""Returns None when Redis raises an error."""
mock_get_client.side_effect = redis.RedisError("connection refused")
result = get_cached_count(_make_query())
assert result is None
@mock.patch("services.listing_cache._get_redis_client")
def test_returns_none_when_key_not_exists(self, mock_get_client):
"""Returns None when the cache key does not exist."""
mock_client = mock.MagicMock()
mock_client.exists.return_value = False
mock_get_client.return_value = mock_client
result = get_cached_count(_make_query())
assert result is None
@mock.patch("services.listing_cache._get_redis_client")
def test_returns_count_when_key_exists(self, mock_get_client):
"""Returns list length when key exists."""
mock_client = mock.MagicMock()
mock_client.exists.return_value = True
mock_client.llen.return_value = 42
mock_get_client.return_value = mock_client
result = get_cached_count(_make_query())
assert result == 42
class TestGetCachedFeatures:
"""Tests for get_cached_features()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_yields_empty_on_redis_error(self, mock_get_client):
"""Yields nothing when Redis raises an error."""
mock_get_client.side_effect = redis.RedisError("connection refused")
batches = list(get_cached_features(_make_query()))
assert batches == []
@mock.patch("services.listing_cache._get_redis_client")
def test_yields_batches(self, mock_get_client):
"""Yields features in batches."""
features = [{"type": "Feature", "id": i} for i in range(3)]
mock_client = mock.MagicMock()
mock_client.llen.return_value = 3
mock_client.lrange.return_value = [json.dumps(f) for f in features]
mock_get_client.return_value = mock_client
batches = list(get_cached_features(_make_query(), batch_size=50))
assert len(batches) == 1
assert batches[0] == features
class TestCacheFeaturesBatch:
"""Tests for cache_features_batch()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_empty_features_returns_early(self, mock_get_client):
"""Does not call Redis when features list is empty."""
cache_features_batch(_make_query(), [])
mock_get_client.assert_not_called()
@mock.patch("services.listing_cache._get_redis_client")
def test_writes_features_via_pipeline(self, mock_get_client):
"""Writes features and sets TTL through pipeline."""
mock_client = mock.MagicMock()
mock_pipeline = mock.MagicMock()
mock_client.pipeline.return_value = mock_pipeline
mock_get_client.return_value = mock_client
features = [{"type": "Feature", "id": 1}]
cache_features_batch(_make_query(), features)
mock_pipeline.rpush.assert_called_once()
mock_pipeline.expire.assert_called_once()
mock_pipeline.execute.assert_called_once()
@mock.patch("services.listing_cache._get_redis_client")
def test_handles_redis_error(self, mock_get_client):
"""Handles Redis error gracefully during write."""
mock_get_client.side_effect = redis.RedisError("write error")
# Should not raise
cache_features_batch(_make_query(), [{"id": 1}])
class TestInvalidateCache:
"""Tests for invalidate_cache()."""
@mock.patch("services.listing_cache._get_redis_client")
def test_handles_redis_error(self, mock_get_client):
"""Handles Redis error gracefully during invalidation."""
mock_get_client.side_effect = redis.RedisError("connection refused")
# Should not raise
invalidate_cache()
@mock.patch("services.listing_cache._get_redis_client")
def test_deletes_matching_keys_via_pipeline(self, mock_get_client):
"""Deletes keys matching the cache prefix using pipeline."""
mock_client = mock.MagicMock()
mock_pipeline = mock.MagicMock()
mock_client.pipeline.return_value = mock_pipeline
# Simulate one scan iteration that returns keys, then done
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
mock_get_client.return_value = mock_client
invalidate_cache()
assert mock_pipeline.delete.call_count == 2
mock_pipeline.execute.assert_called_once()
@mock.patch("services.listing_cache._get_redis_client")
def test_no_keys_to_delete(self, mock_get_client):
"""Does nothing when no cache keys exist."""
mock_client = mock.MagicMock()
mock_client.scan.return_value = (0, [])
mock_get_client.return_value = mock_client
invalidate_cache()
mock_client.pipeline.assert_not_called()

View file

@ -0,0 +1,372 @@
"""Unit tests for the listing fetcher service."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from models.listing import ListingType, QueryParameters
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from services.listing_fetcher import (
NUM_WORKERS,
_fetch_subquery,
dump_listings,
dump_listings_full,
)
from services.query_splitter import SubQuery
def _make_subquery(**kwargs) -> SubQuery:
"""Create a SubQuery with sensible defaults for testing."""
defaults = dict(
district="REGION^123",
min_bedrooms=1,
max_bedrooms=3,
min_price=1000,
max_price=3000,
estimated_results=50,
)
defaults.update(kwargs)
return SubQuery(**defaults)
class TestDumpListingsFull:
"""Tests for dump_listings_full."""
async def test_returns_empty_list_when_no_new_listings(self) -> None:
"""Test that empty results from dump_listings returns empty list."""
with patch(
"services.listing_fetcher.dump_listings",
new_callable=AsyncMock,
return_value=[],
):
mock_repo = AsyncMock()
mock_repo.get_listings = AsyncMock(return_value=[])
params = QueryParameters(listing_type=ListingType.RENT)
result = await dump_listings_full(params, mock_repo)
assert result == []
async def test_returns_only_new_listings_from_db(self) -> None:
"""Test that dump_listings_full fetches new listings by ID from the repository."""
mock_listing_1 = MagicMock()
mock_listing_1.id = 100
mock_listing_2 = MagicMock()
mock_listing_2.id = 200
with patch(
"services.listing_fetcher.dump_listings",
new_callable=AsyncMock,
return_value=[mock_listing_1, mock_listing_2],
):
mock_repo = AsyncMock()
mock_repo.get_listings = AsyncMock(
return_value=[mock_listing_1, mock_listing_2]
)
params = QueryParameters(listing_type=ListingType.RENT)
result = await dump_listings_full(params, mock_repo)
# Verify get_listings was called with the correct IDs
mock_repo.get_listings.assert_awaited_once_with(
only_ids=[100, 200]
)
assert len(result) == 2
class TestFetchSubquery:
"""Tests for _fetch_subquery helper."""
async def test_skips_subquery_with_zero_estimated_results(self) -> None:
"""Test that subqueries with 0 estimated results are skipped."""
sq = _make_subquery(estimated_results=0)
params = QueryParameters(listing_type=ListingType.RENT)
queue: asyncio.Queue[int | None] = asyncio.Queue()
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=MagicMock(),
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_skips_subquery_with_none_estimated_results(self) -> None:
"""Test that subqueries with None estimated results are skipped."""
sq = _make_subquery(estimated_results=None)
params = QueryParameters(listing_type=ListingType.RENT)
queue: asyncio.Queue[int | None] = asyncio.Queue()
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=MagicMock(),
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_enqueues_new_ids_only(self) -> None:
"""Test that only new (not existing) IDs are enqueued."""
sq = _make_subquery(estimated_results=10)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
existing_ids: set[int] = {101, 103}
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
mock_config.max_concurrent_requests = 5
api_result = {
"properties": [
{"identifier": 101}, # existing
{"identifier": 102}, # new
{"identifier": 103}, # existing
{"identifier": 104}, # new
]
}
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
return_value=api_result,
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=existing_ids,
queue=queue,
)
assert ids_found == 2
# Verify that queued IDs are the new ones
queued = []
while not queue.empty():
queued.append(queue.get_nowait())
assert 102 in queued
assert 104 in queued
assert 101 not in queued
assert 103 not in queued
async def test_stops_on_circuit_breaker_error(self) -> None:
"""Test that CircuitBreakerOpenError breaks the page loop."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=CircuitBreakerOpenError("open"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_on_throttling_error(self) -> None:
"""Test that ThrottlingError breaks the page loop."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=ThrottlingError("throttled"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_on_generic_error(self) -> None:
"""Test that GENERIC_ERROR (past last page) stops pagination."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=Exception("GENERIC_ERROR: no more results"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_on_unexpected_error(self) -> None:
"""Test that unexpected errors also stop pagination."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
side_effect=Exception("some network error"),
):
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
assert ids_found == 0
assert queue.empty()
async def test_stops_when_fewer_results_than_page_size(self) -> None:
"""Test that pagination stops when a page has fewer results than page_size."""
sq = _make_subquery(estimated_results=100)
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
queue: asyncio.Queue[int | None] = asyncio.Queue()
mock_config = MagicMock()
mock_config.max_pages_per_query = 60
mock_config.request_delay_ms = 0
# Return fewer properties than page_size
api_result = {
"properties": [
{"identifier": 1},
{"identifier": 2},
]
}
with patch(
"services.listing_fetcher.listing_query",
new_callable=AsyncMock,
return_value=api_result,
) as mock_query:
ids_found = await _fetch_subquery(
sq=sq,
parameters=params,
session=MagicMock(),
config=mock_config,
semaphore=asyncio.Semaphore(5),
existing_ids=set(),
queue=queue,
)
# Should have called listing_query exactly once (then stopped)
assert mock_query.await_count == 1
assert ids_found == 2
class TestDumpListings:
"""Tests for dump_listings."""
async def test_circuit_breaker_returns_empty_list(self) -> None:
"""Test that CircuitBreakerOpenError returns empty list."""
mock_repo = AsyncMock()
params = QueryParameters(listing_type=ListingType.RENT)
with patch("services.listing_fetcher.create_session") as mock_cs:
mock_cs.side_effect = CircuitBreakerOpenError("open")
result = await dump_listings(params, mock_repo)
assert result == []
async def test_returns_processed_listings(self) -> None:
"""Test that dump_listings returns processed listings from the pipeline."""
mock_repo = AsyncMock()
mock_repo.get_listing_ids = MagicMock(return_value=set())
params = QueryParameters(listing_type=ListingType.RENT)
mock_listing = MagicMock()
mock_listing.id = 42
mock_session_cm = AsyncMock()
mock_session = MagicMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=False)
with (
patch(
"services.listing_fetcher.create_session",
return_value=mock_session_cm,
),
patch(
"services.listing_fetcher.QuerySplitter"
) as mock_splitter_cls,
patch(
"services.listing_fetcher._fetch_subquery",
new_callable=AsyncMock,
return_value=0,
),
):
mock_splitter = mock_splitter_cls.return_value
mock_splitter.split = AsyncMock(return_value=[])
mock_splitter.calculate_total_estimated_results = MagicMock(
return_value=0
)
result = await dump_listings(params, mock_repo)
# With no subqueries, no listings are processed
assert result == []
class TestNumWorkers:
"""Tests for NUM_WORKERS constant."""
def test_num_workers_is_positive(self) -> None:
"""Test that NUM_WORKERS is a positive integer."""
assert NUM_WORKERS > 0
def test_num_workers_value(self) -> None:
"""Test that NUM_WORKERS has the expected value."""
assert NUM_WORKERS == 20

View file

@ -0,0 +1,87 @@
"""Unit tests for the listing processor."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from models.listing import FurnishType
from listing_processor import (
_parse_furnish_type,
_parse_available_from,
ListingProcessor,
FetchListingDetailsStep,
MAX_OCR_WORKERS,
)
class TestParseFurnishType:
"""Tests for _parse_furnish_type helper."""
def test_none_returns_unknown(self):
assert _parse_furnish_type(None) == FurnishType.UNKNOWN
def test_ask_landlord_variant(self):
assert _parse_furnish_type("Ask landlord") == FurnishType.ASK_LANDLORD
def test_furnished_lowercased(self):
assert _parse_furnish_type("Furnished") == FurnishType.FURNISHED
def test_unfurnished(self):
assert _parse_furnish_type("Unfurnished") == FurnishType.UNFURNISHED
def test_part_furnished(self):
assert _parse_furnish_type("Part Furnished") == FurnishType.PART_FURNISHED
def test_unknown_string_returns_unknown(self):
assert _parse_furnish_type("unknown") == FurnishType.UNKNOWN
def test_garbage_string_returns_unknown(self):
assert _parse_furnish_type("xyzzy") == FurnishType.UNKNOWN
class TestParseAvailableFrom:
"""Tests for _parse_available_from helper."""
def test_none_returns_none(self):
assert _parse_available_from(None) is None
def test_now_returns_datetime(self):
result = _parse_available_from("Now")
assert isinstance(result, datetime)
def test_valid_date_string(self):
result = _parse_available_from("15/03/2024")
assert result is not None
assert result.day == 15
assert result.month == 3
def test_invalid_date_returns_none(self):
assert _parse_available_from("invalid") is None
class TestListingProcessor:
"""Tests for ListingProcessor."""
async def test_process_listing_marks_seen(self):
"""Test that process_listing calls mark_seen."""
mock_repo = AsyncMock()
mock_repo.get_listings = AsyncMock(return_value=[MagicMock()])
processor = ListingProcessor(mock_repo)
# Mock all steps to not need processing
for step in processor.process_steps:
step.needs_processing = AsyncMock(return_value=False)
await processor.process_listing(123)
mock_repo.mark_seen.assert_awaited_once_with(123)
async def test_process_listing_returns_none_on_step_failure(self):
"""Test that a step failure returns None."""
mock_repo = AsyncMock()
processor = ListingProcessor(mock_repo)
for step in processor.process_steps:
step.needs_processing = AsyncMock(return_value=True)
step.process = AsyncMock(side_effect=Exception("fail"))
result = await processor.process_listing(123)
assert result is None
class TestOcrWorkersConfig:
def test_max_ocr_workers_positive(self):
assert MAX_OCR_WORKERS >= 1

View file

@ -0,0 +1,295 @@
"""Unit tests for tasks/listing_tasks.py."""
import json
import os
from collections import deque
from unittest.mock import MagicMock, patch, AsyncMock, call
import pytest
import tasks.listing_tasks as module
from tasks.listing_tasks import (
_update_task_state,
_PipelineState,
TaskLogHandler,
SCRAPE_LOCK_NAME,
LOG_BUFFER_MAX_LINES,
NUM_WORKERS,
PHASE_SPLITTING,
PHASE_FETCHING,
PHASE_PROCESSING,
PHASE_COMPLETED,
)
class TestUpdateTaskState:
"""Tests for _update_task_state."""
def test_injects_logs_from_active_buffer(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = deque(["log line 1", "log line 2"])
_update_task_state(task, "test_state", {"key": "value"})
task.update_state.assert_called_once()
call_meta = task.update_state.call_args[1]["meta"]
assert call_meta["logs"] == ["log line 1", "log line 2"]
assert call_meta["key"] == "value"
finally:
module._active_log_buffer = original
def test_works_when_buffer_is_none(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = None
_update_task_state(task, "some_state", {"phase": "testing"})
task.update_state.assert_called_once_with(
state="some_state", meta={"phase": "testing"}
)
# No "logs" key should be injected
call_meta = task.update_state.call_args[1]["meta"]
assert "logs" not in call_meta
finally:
module._active_log_buffer = original
def test_state_string_is_passed_through(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = None
_update_task_state(task, "PROGRESS", {})
task.update_state.assert_called_once_with(state="PROGRESS", meta={})
finally:
module._active_log_buffer = original
def test_empty_buffer_injects_empty_list(self):
task = MagicMock()
original = module._active_log_buffer
try:
module._active_log_buffer = deque()
_update_task_state(task, "state", {"a": 1})
call_meta = task.update_state.call_args[1]["meta"]
assert call_meta["logs"] == []
finally:
module._active_log_buffer = original
class TestTaskLogHandler:
"""Tests for the TaskLogHandler."""
def test_emit_appends_to_buffer(self):
buf = deque(maxlen=10)
handler = TaskLogHandler(buf)
handler.setFormatter(
__import__("logging").Formatter("%(message)s")
)
record = __import__("logging").LogRecord(
name="test", level=20, pathname="", lineno=0,
msg="hello", args=(), exc_info=None,
)
handler.emit(record)
assert "hello" in buf
def test_buffer_respects_maxlen(self):
buf = deque(maxlen=2)
handler = TaskLogHandler(buf)
handler.setFormatter(
__import__("logging").Formatter("%(message)s")
)
for i in range(5):
record = __import__("logging").LogRecord(
name="test", level=20, pathname="", lineno=0,
msg=f"msg{i}", args=(), exc_info=None,
)
handler.emit(record)
assert len(buf) == 2
assert list(buf) == ["msg3", "msg4"]
class TestDumpListingsTask:
"""Tests for dump_listings_task Celery task."""
@patch("tasks.listing_tasks.redis_lock")
def test_skips_when_lock_not_acquired(self, mock_redis_lock):
"""Task should skip when another scrape is running."""
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=False)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_redis_lock.return_value = mock_cm
from tasks.listing_tasks import dump_listings_task
# Use run() which handles bind=True properly
task_instance = dump_listings_task
task_instance.update_state = MagicMock()
result = dump_listings_task.run('{"listing_type": "RENT"}')
assert result["status"] == "skipped"
assert result["reason"] == "another_job_running"
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
@patch("tasks.listing_tasks.asyncio.run")
@patch("tasks.listing_tasks.redis_lock")
def test_calls_dump_listings_full_when_lock_acquired(
self, mock_redis_lock, mock_asyncio_run
):
"""Task should call dump_listings_full when lock is acquired."""
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=True)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_redis_lock.return_value = mock_cm
mock_asyncio_run.return_value = []
from tasks.listing_tasks import dump_listings_task
task_instance = dump_listings_task
task_instance.update_state = MagicMock()
params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
result = dump_listings_task.run(params_json)
assert result["phase"] == "completed"
assert result["progress"] == 1
mock_asyncio_run.assert_called_once()
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
class TestSetupPeriodicTasks:
"""Tests for setup_periodic_tasks."""
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_enabled_schedules(self, mock_from_env):
from config.schedule_config import ScheduleConfig
from models.listing import ListingType
schedule = ScheduleConfig(
name="Test Schedule",
listing_type=ListingType.RENT,
hour="3",
minute="30",
)
mock_config = MagicMock()
mock_config.get_enabled_schedules.return_value = [schedule]
mock_from_env.return_value = mock_config
sender = MagicMock()
module.setup_periodic_tasks(sender)
sender.add_periodic_task.assert_called_once()
call_args = sender.add_periodic_task.call_args
assert call_args[1]["name"] == "Test Schedule"
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_handles_config_error_gracefully(self, mock_from_env):
mock_from_env.side_effect = ValueError("bad config")
sender = MagicMock()
module.setup_periodic_tasks(sender)
sender.add_periodic_task.assert_not_called()
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_nothing_when_no_schedules(self, mock_from_env):
mock_config = MagicMock()
mock_config.get_enabled_schedules.return_value = []
mock_from_env.return_value = mock_config
sender = MagicMock()
module.setup_periodic_tasks(sender)
sender.add_periodic_task.assert_not_called()
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
def test_registers_multiple_schedules(self, mock_from_env):
from config.schedule_config import ScheduleConfig
from models.listing import ListingType
schedules = [
ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"),
ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"),
]
mock_config = MagicMock()
mock_config.get_enabled_schedules.return_value = schedules
mock_from_env.return_value = mock_config
sender = MagicMock()
module.setup_periodic_tasks(sender)
assert sender.add_periodic_task.call_count == 2
class TestPipelineState:
"""Tests for _PipelineState dataclass."""
def test_default_initialization(self):
state = _PipelineState()
assert state.ids_collected == 0
assert state.completed_subqueries == 0
assert state.total_pages_fetched == 0
assert state.fetching_done is False
assert state.processed_count == 0
assert state.failed_count == 0
assert state.details_fetched == 0
assert state.images_downloaded == 0
assert state.ocr_completed == 0
assert state.processed_listings == []
def test_incrementing_counters(self):
state = _PipelineState()
state.ids_collected += 5
state.completed_subqueries += 3
state.total_pages_fetched += 10
state.processed_count += 4
state.failed_count += 1
state.details_fetched += 4
state.images_downloaded += 3
state.ocr_completed += 2
assert state.ids_collected == 5
assert state.completed_subqueries == 3
assert state.total_pages_fetched == 10
assert state.processed_count == 4
assert state.failed_count == 1
assert state.details_fetched == 4
assert state.images_downloaded == 3
assert state.ocr_completed == 2
def test_appending_to_processed_listings(self):
state = _PipelineState()
state.processed_listings.append("listing_a")
state.processed_listings.append("listing_b")
assert len(state.processed_listings) == 2
assert state.processed_listings == ["listing_a", "listing_b"]
def test_separate_instances_have_independent_lists(self):
state_a = _PipelineState()
state_b = _PipelineState()
state_a.processed_listings.append("only_a")
assert state_b.processed_listings == []
def test_fetching_done_toggle(self):
state = _PipelineState()
assert state.fetching_done is False
state.fetching_done = True
assert state.fetching_done is True
class TestPhaseConstants:
"""Tests for phase constant values."""
def test_phase_splitting(self):
assert PHASE_SPLITTING == "splitting"
def test_phase_fetching(self):
assert PHASE_FETCHING == "fetching"
def test_phase_processing(self):
assert PHASE_PROCESSING == "processing"
def test_phase_completed(self):
assert PHASE_COMPLETED == "completed"
def test_num_workers(self):
assert NUM_WORKERS == 20

View file

@ -1,16 +1,24 @@
"""Unit tests for Listing models."""
import dataclasses
from datetime import datetime
import json
import pytest
from pydantic import ValidationError
from models.listing import (
BuyListing,
DestinationMode,
FurnishType,
ListingSite,
ListingType,
PriceHistoryItem,
QueryParameters,
RentListing,
Listing,
Route,
RouteLegStep,
)
from rec.routing import TravelMode
class TestListing:
@ -341,3 +349,190 @@ class TestBuyListing:
lease_left=120,
)
assert listing.lease_left == 120
def _make_listing_with_routing(routing_info_json: str | None) -> RentListing:
"""Helper to create a RentListing with given routing_info_json."""
return RentListing(
id=1,
price=2000.0,
number_of_bedrooms=2,
square_meters=50.0,
agency="Test",
council_tax_band="C",
longitude=0.0,
latitude=0.0,
price_history_json="[]",
listing_site=ListingSite.RIGHTMOVE,
last_seen=datetime.now(),
photo_thumbnail=None,
floorplan_image_paths=[],
additional_info={"property": {"visible": True}},
routing_info_json=routing_info_json,
furnish_type=FurnishType.FURNISHED,
available_from=None,
)
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
"""Helper to create sample routing info for tests."""
destination_mode = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
routes = [
Route(
legs=[
RouteLegStep(
distance_meters=500,
duration_s=120,
travel_mode=TravelMode.WALK,
),
RouteLegStep(
distance_meters=4000,
duration_s=480,
travel_mode=TravelMode.TRANSIT,
),
],
distance_meters=4500,
duration_s=600,
)
]
return {destination_mode: routes}
class TestQueryParametersValidation:
"""Tests for QueryParameters validation."""
def test_valid_parameters(self) -> None:
"""Basic valid QueryParameters creation."""
params = QueryParameters(
listing_type=ListingType.RENT,
min_price=1000,
max_price=3000,
min_bedrooms=1,
max_bedrooms=3,
)
assert params.min_price == 1000
assert params.max_price == 3000
assert params.min_bedrooms == 1
assert params.max_bedrooms == 3
def test_invalid_price_range_raises(self) -> None:
"""min_price > max_price should raise ValidationError."""
with pytest.raises(ValidationError, match="min_price.*must be <= max_price"):
QueryParameters(
listing_type=ListingType.RENT,
min_price=5000,
max_price=1000,
)
def test_invalid_bedroom_range_raises(self) -> None:
"""min_bedrooms > max_bedrooms should raise ValidationError."""
with pytest.raises(ValidationError, match="min_bedrooms.*must be <= max_bedrooms"):
QueryParameters(
listing_type=ListingType.RENT,
min_bedrooms=5,
max_bedrooms=2,
)
def test_negative_bedrooms_raises(self) -> None:
"""Negative bedroom counts should raise ValidationError."""
with pytest.raises(ValidationError, match="min_bedrooms.*must be non-negative"):
QueryParameters(
listing_type=ListingType.RENT,
min_bedrooms=-1,
max_bedrooms=3,
)
class TestDestinationMode:
"""Tests for DestinationMode."""
def test_to_dict(self) -> None:
"""Test to_dict returns correct dict."""
dm = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
result = dm.to_dict()
assert result == {
"destination_address": "London Bridge",
"travel_mode": TravelMode.TRANSIT,
}
def test_hash(self) -> None:
"""Test hashing works correctly."""
dm1 = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
dm2 = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
dm3 = DestinationMode(
destination_address="King's Cross",
travel_mode=TravelMode.TRANSIT,
)
assert hash(dm1) == hash(dm2)
assert dm1 == dm2
assert hash(dm1) != hash(dm3)
# Can be used as dict key
d = {dm1: "route1"}
assert d[dm2] == "route1"
class TestRoutingInfoSerialization:
"""Tests for routing info via RouteSerializer."""
def test_routing_info_property_returns_parsed_routes(self) -> None:
"""Test routing_info property deserializes correctly."""
routing_info = _make_sample_routing_info()
listing = _make_listing_with_routing(None)
serialized = listing.serialize_routing_info(routing_info)
listing.routing_info_json = serialized
result = listing.routing_info
assert len(result) == 1
dest_mode = list(result.keys())[0]
assert dest_mode.destination_address == "London Bridge"
assert dest_mode.travel_mode == TravelMode.TRANSIT
routes = result[dest_mode]
assert len(routes) == 1
assert routes[0].distance_meters == 4500
assert routes[0].duration_s == 600
assert len(routes[0].legs) == 2
assert routes[0].legs[0].distance_meters == 500
assert routes[0].legs[0].travel_mode == TravelMode.WALK
def test_routing_info_empty_json(self) -> None:
"""Test routing_info with no routing data."""
listing = _make_listing_with_routing(None)
assert listing.routing_info == {}
def test_serialize_routing_info_roundtrip(self) -> None:
"""Test serialize then deserialize via routing_info property."""
routing_info = _make_sample_routing_info()
listing = _make_listing_with_routing(None)
# Serialize
serialized = listing.serialize_routing_info(routing_info)
assert isinstance(serialized, str)
# Assign and deserialize via property
listing.routing_info_json = serialized
deserialized = listing.routing_info
# Compare
orig_dm = list(routing_info.keys())[0]
result_dm = list(deserialized.keys())[0]
assert orig_dm.destination_address == result_dm.destination_address
assert orig_dm.travel_mode == result_dm.travel_mode
orig_route = routing_info[orig_dm][0]
result_route = deserialized[result_dm][0]
assert orig_route.distance_meters == result_route.distance_meters
assert orig_route.duration_s == result_route.duration_s
assert len(orig_route.legs) == len(result_route.legs)

View file

@ -0,0 +1,385 @@
"""Unit tests for rec/query.py."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
from rec.query import (
detail_query,
listing_query,
probe_query,
PropertyType,
create_session,
_build_base_params,
_build_listing_params,
_build_probe_params,
ANDROID_APP_VERSION,
ANDROID_APP_VERSION_LISTING,
RIGHTMOVE_API_BASE,
PROPERTY_LISTING_ENDPOINT,
DEFAULT_HEADERS,
LISTING_HEADERS,
check_circuit_breaker,
reset_circuit_breaker,
get_circuit_breaker,
)
from models.listing import ListingType, FurnishType
from config.scraper_config import ScraperConfig
from rec.exceptions import CircuitBreakerOpenError
from rec.throttle_detector import reset_throttle_metrics
@pytest.fixture
def config() -> ScraperConfig:
return ScraperConfig(
max_concurrent_requests=5,
request_delay_ms=10,
slow_response_threshold=10.0,
enable_circuit_breaker=True,
circuit_breaker_failure_threshold=3,
circuit_breaker_recovery_timeout=0.5,
)
@pytest.fixture
def config_no_cb() -> ScraperConfig:
return ScraperConfig(enable_circuit_breaker=False)
@pytest.fixture(autouse=True)
def reset_globals() -> None:
reset_throttle_metrics()
reset_circuit_breaker()
class MockResponse:
def __init__(
self,
status: int = 200,
json_data: dict | None = None,
text: str = "",
):
self.status = status
self._json_data = json_data or {}
self._text = text
async def json(self) -> dict:
return self._json_data
async def text(self) -> str:
return self._text
async def __aenter__(self) -> "MockResponse":
return self
async def __aexit__(self, *args: object) -> None:
pass
def make_mock_session(response: MockResponse) -> MagicMock:
"""Create a mock session whose .get() returns an async context manager."""
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=response)
return mock_session
def make_mock_session_fn(get_fn: object) -> MagicMock:
"""Create a mock session whose .get() calls a function to produce responses."""
mock_session = MagicMock()
mock_session.get = MagicMock(side_effect=get_fn)
return mock_session
class TestBuildBaseParams:
def test_constructs_correct_params(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"TestDistrict": "REGION^123"}):
params = _build_base_params(
channel=ListingType.RENT,
page=2,
page_size=25,
radius=1.5,
min_price=1000,
max_price=3000,
min_bedrooms=1,
max_bedrooms=3,
district="TestDistrict",
)
assert params["locationIdentifier"] == "REGION^123"
assert params["channel"] == "RENT"
assert params["page"] == "2"
assert params["numberOfPropertiesPerPage"] == "25"
assert params["radius"] == "1.5"
assert params["sortBy"] == "distance"
assert params["includeUnavailableProperties"] == "false"
assert params["minPrice"] == "1000"
assert params["maxPrice"] == "3000"
assert params["minBedrooms"] == "1"
assert params["maxBedrooms"] == "3"
assert params["apiApplication"] == "ANDROID"
assert params["appVersion"] == ANDROID_APP_VERSION_LISTING
def test_buy_channel_includes_dont_show_and_max_days(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_listing_params(
page=1,
channel=ListingType.BUY,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=100000,
max_price=500000,
district="D",
mustNewHome=False,
max_days_since_added=7,
property_type=[],
page_size=25,
furnish_types=[],
)
assert params["dontShow"] == "sharedOwnership,retirement"
assert params["maxDaysSinceAdded"] == "7"
def test_rent_channel_includes_furnish_types(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_listing_params(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=3000,
district="D",
mustNewHome=False,
max_days_since_added=30,
property_type=[],
page_size=25,
furnish_types=[FurnishType.FURNISHED, FurnishType.UNFURNISHED],
)
assert params["furnishTypes"] == "furnished,unfurnished"
assert "dontShow" not in params
assert "maxDaysSinceAdded" not in params
def test_buy_channel_probe_includes_dont_show(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_probe_params(
channel=ListingType.BUY,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=100000,
max_price=500000,
district="D",
max_days_since_added=7,
furnish_types=[],
)
assert params["dontShow"] == "sharedOwnership,retirement"
assert params["maxDaysSinceAdded"] == "7"
assert params["numberOfPropertiesPerPage"] == "1"
def test_probe_buy_skips_max_days_if_not_valid(self) -> None:
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
params = _build_probe_params(
channel=ListingType.BUY,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=100000,
max_price=500000,
district="D",
max_days_since_added=30,
furnish_types=[],
)
# 30 is not in [1, 3, 7, 14], so maxDaysSinceAdded is not added for probe
assert "maxDaysSinceAdded" not in params
class TestMutableDefaultArgFix:
@pytest.mark.asyncio
async def test_property_type_default_not_shared(self, config: ScraperConfig) -> None:
"""Calling listing_query with no property_type should not share state between calls."""
response = MockResponse(
status=200,
json_data={"totalAvailableResults": 0, "properties": []},
)
mock_session = make_mock_session(response)
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
# Call twice without explicit property_type
await listing_query(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
session=mock_session,
config=config,
)
await listing_query(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
session=mock_session,
config=config,
)
# If mutable default was shared, this test would detect mutations.
# The fact that it completes without error proves defaults are independent.
@pytest.mark.asyncio
async def test_furnish_types_default_not_shared(self, config: ScraperConfig) -> None:
"""Calling probe_query with no furnish_types should not share state between calls."""
response = MockResponse(
status=200,
json_data={"totalAvailableResults": 0, "properties": []},
)
mock_session = make_mock_session(response)
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
await probe_query(
session=mock_session,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
config=config,
)
await probe_query(
session=mock_session,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
config=config,
)
class TestPropertyTypeEnum:
def test_enum_values(self) -> None:
assert PropertyType.BUNGALOW == "bungalow"
assert PropertyType.DETACHED == "detached"
assert PropertyType.FLAT == "flat"
assert PropertyType.LAND == "land"
assert PropertyType.PARK_HOME == "park-home"
assert PropertyType.SEMI_DETACHED == "semi-detached"
assert PropertyType.TERRACED == "terraced"
def test_enum_is_str(self) -> None:
assert isinstance(PropertyType.FLAT, str)
assert ",".join([PropertyType.FLAT, PropertyType.DETACHED]) == "flat,detached"
class TestDetailQuery:
@pytest.mark.asyncio
async def test_success_200(self, config: ScraperConfig) -> None:
expected_body = {"id": 12345, "address": "123 Test St"}
response = MockResponse(status=200, json_data=expected_body)
mock_session = make_mock_session(response)
result = await detail_query(12345, session=mock_session, config=config)
assert result == expected_body
@pytest.mark.asyncio
async def test_raises_on_non_200(self, config: ScraperConfig) -> None:
response = MockResponse(status=404, text="Not Found")
mock_session = make_mock_session(response)
with pytest.raises(Exception, match="Failed due to"):
await detail_query(99999, session=mock_session, config=config)
class TestCircuitBreakerBlocksRequests:
@pytest.mark.asyncio
async def test_circuit_breaker_blocks_when_open(self, config: ScraperConfig) -> None:
cb = get_circuit_breaker(config)
assert cb is not None
for _ in range(config.circuit_breaker_failure_threshold):
cb.record_failure()
assert cb.is_open
mock_session = MagicMock()
with pytest.raises(CircuitBreakerOpenError):
await detail_query(1, session=mock_session, config=config)
@pytest.mark.asyncio
async def test_circuit_breaker_blocks_listing_query(self, config: ScraperConfig) -> None:
cb = get_circuit_breaker(config)
assert cb is not None
for _ in range(config.circuit_breaker_failure_threshold):
cb.record_failure()
mock_session = MagicMock()
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
with pytest.raises(CircuitBreakerOpenError):
await listing_query(
page=1,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
session=mock_session,
config=config,
)
@pytest.mark.asyncio
async def test_circuit_breaker_blocks_probe_query(self, config: ScraperConfig) -> None:
cb = get_circuit_breaker(config)
assert cb is not None
for _ in range(config.circuit_breaker_failure_threshold):
cb.record_failure()
mock_session = MagicMock()
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
with pytest.raises(CircuitBreakerOpenError):
await probe_query(
session=mock_session,
channel=ListingType.RENT,
min_bedrooms=1,
max_bedrooms=2,
radius=1.0,
min_price=1000,
max_price=2000,
district="D",
config=config,
)
class TestConstants:
def test_android_app_version(self) -> None:
assert ANDROID_APP_VERSION == "3.70.0"
def test_android_app_version_listing(self) -> None:
assert ANDROID_APP_VERSION_LISTING == "4.28.0"
def test_rightmove_api_base(self) -> None:
assert RIGHTMOVE_API_BASE == "https://api.rightmove.co.uk/api"
def test_property_listing_endpoint(self) -> None:
assert PROPERTY_LISTING_ENDPOINT == "https://api.rightmove.co.uk/api/property-listing"
def test_listing_headers_extends_default(self) -> None:
for key, value in DEFAULT_HEADERS.items():
assert LISTING_HEADERS[key] == value
assert LISTING_HEADERS["Accept-Encoding"] == "gzip, deflate, br"

View file

@ -161,7 +161,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
# Mock the probe_query function
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
mock_probe.return_value = {"totalAvailableResults": 800}
count = await splitter.probe_result_count(sq, mock_session, parameters)
@ -184,7 +184,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
mock_probe.side_effect = Exception("API error")
count = await splitter.probe_result_count(sq, mock_session, parameters)
@ -208,7 +208,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
mock_semaphore = AsyncMock()
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
# First half has 600 results, second half has 500
mock_probe.side_effect = [
{"totalAvailableResults": 600},
@ -240,7 +240,7 @@ class TestQuerySplitter:
mock_session = AsyncMock()
mock_semaphore = AsyncMock()
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
# First split: 1000-3000 has 1300 (over threshold), 3000-5000 has 800
# Second split of 1000-3000: 1000-2000 has 700, 2000-3000 has 600
mock_probe.side_effect = [
@ -326,7 +326,7 @@ class TestQuerySplitter:
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
with patch("services.query_splitter.get_districts", return_value=mock_districts):
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
# Mock probe results for each initial subquery
# 2 districts × 2 bedroom counts = 4 initial subqueries
mock_probe.side_effect = [
@ -358,11 +358,11 @@ class TestQuerySplitter:
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
progress_calls = []
def on_progress(phase: str, message: str) -> None:
def on_progress(phase: str, message: str, **kwargs: object) -> None:
progress_calls.append((phase, message))
with patch("services.query_splitter.get_districts", return_value=mock_districts):
with patch("services.query_splitter.probe_query") as mock_probe:
with patch("rec.query.probe_query") as mock_probe:
mock_probe.return_value = {"totalAvailableResults": 500}
await splitter.split(parameters, mock_session, on_progress)

View file

@ -1,5 +1,6 @@
"""Unit tests for ListingRepository."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy import Engine
@ -225,3 +226,156 @@ class TestListingRepositoryFilters:
listings = await listing_repository.get_listings(query_parameters=query_params)
# Should match listings with 1-2 bedrooms in price range
assert len(listings) == 2
class TestListingRepositoryStreaming:
"""Tests for streaming and optimized query methods."""
async def test_count_listings_empty_db(
self, listing_repository: ListingRepository
) -> None:
"""Test count returns 0 for empty database."""
count = listing_repository.count_listings()
assert count == 0
async def test_count_listings_with_data(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test count returns correct number."""
await listing_repository.upsert_listings(sample_rent_listings)
count = listing_repository.count_listings()
assert count == 3
async def test_count_listings_with_filters(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test count respects query parameters."""
await listing_repository.upsert_listings(sample_rent_listings)
query_params = QueryParameters(
listing_type=ListingType.RENT,
min_bedrooms=2,
max_bedrooms=3,
)
count = listing_repository.count_listings(query_parameters=query_params)
assert count == 2
async def test_stream_listings_optimized_returns_dicts(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test optimized streaming returns dict rows."""
await listing_repository.upsert_listings(sample_rent_listings)
results = list(listing_repository.stream_listings_optimized())
assert len(results) == 3
# Each result should be a dict
for row in results:
assert isinstance(row, dict)
assert "id" in row
assert "price" in row
assert "number_of_bedrooms" in row
async def test_stream_listings_optimized_respects_limit(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test streaming limit parameter."""
await listing_repository.upsert_listings(sample_rent_listings)
results = list(listing_repository.stream_listings_optimized(limit=2))
assert len(results) == 2
async def test_get_listing_ids(
self,
listing_repository: ListingRepository,
sample_rent_listings: list[RentListing],
) -> None:
"""Test get_listing_ids returns set of IDs."""
await listing_repository.upsert_listings(sample_rent_listings)
ids = listing_repository.get_listing_ids()
assert isinstance(ids, set)
assert ids == {1, 2, 3}
async def test_get_listing_ids_empty_db(
self,
listing_repository: ListingRepository,
) -> None:
"""Test get_listing_ids returns empty set for empty database."""
ids = listing_repository.get_listing_ids()
assert isinstance(ids, set)
assert len(ids) == 0
class TestFurnishTypeParsing:
"""Tests for _parse_furnish_type helper."""
def test_parse_furnish_type_none_detailobject(self) -> None:
"""Test that None detailobject returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = None
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_missing_property_key(self) -> None:
"""Test that missing 'property' key returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = {}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_missing_let_furnish_type(self) -> None:
"""Test that missing 'letFurnishType' key returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = {"property": {}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_null_value(self) -> None:
"""Test that null letFurnishType value returns UNKNOWN."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": None}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNKNOWN
def test_parse_furnish_type_furnished(self) -> None:
"""Test that 'Furnished' is parsed correctly."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Furnished"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.FURNISHED
def test_parse_furnish_type_unfurnished(self) -> None:
"""Test that 'Unfurnished' is parsed correctly."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Unfurnished"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.UNFURNISHED
def test_parse_furnish_type_part_furnished(self) -> None:
"""Test that 'Part Furnished' is parsed correctly."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Part Furnished"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.PART_FURNISHED
def test_parse_furnish_type_landlord_variant(self) -> None:
"""Test that landlord variants map to ASK_LANDLORD."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "Ask Landlord Please"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.ASK_LANDLORD
def test_parse_furnish_type_landlord_case_insensitive(self) -> None:
"""Test that landlord check is case-insensitive."""
listing = MagicMock()
listing.detailobject = {"property": {"letFurnishType": "LANDLORD decides"}}
result = ListingRepository._parse_furnish_type(listing)
assert result == FurnishType.ASK_LANDLORD

View file

@ -0,0 +1,10 @@
"""Unit tests for services/route_calculator.py."""
from services.route_calculator import _parse_duration
class TestParseDuration:
def test_parse_normal_duration(self) -> None:
assert _parse_duration("123s") == 123
def test_parse_zero_duration(self) -> None:
assert _parse_duration("0s") == 0

View file

@ -0,0 +1,72 @@
"""Unit tests for rec/route_serializer.py."""
from models.listing import DestinationMode, Route, RouteLegStep
from rec.route_serializer import RouteSerializer
from rec.routing import TravelMode
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
destination_mode = DestinationMode(
destination_address="London Bridge",
travel_mode=TravelMode.TRANSIT,
)
routes = [
Route(
legs=[
RouteLegStep(
distance_meters=500,
duration_s=120,
travel_mode=TravelMode.WALK,
),
RouteLegStep(
distance_meters=4000,
duration_s=480,
travel_mode=TravelMode.TRANSIT,
),
],
distance_meters=4500,
duration_s=600,
)
]
return {destination_mode: routes}
class TestRouteSerializer:
def test_serialize_then_deserialize_roundtrip(self) -> None:
routing_info = _make_sample_routing_info()
serialized = RouteSerializer.serialize(routing_info)
deserialized = RouteSerializer.deserialize(serialized)
assert len(deserialized) == 1
dest_mode = list(deserialized.keys())[0]
assert dest_mode.destination_address == "London Bridge"
assert dest_mode.travel_mode == TravelMode.TRANSIT
routes = deserialized[dest_mode]
assert len(routes) == 1
assert routes[0].distance_meters == 4500
assert routes[0].duration_s == 600
assert len(routes[0].legs) == 2
assert routes[0].legs[0].distance_meters == 500
assert routes[0].legs[0].travel_mode == TravelMode.WALK
assert routes[0].legs[1].travel_mode == TravelMode.TRANSIT
def test_deserialize_sample_json(self) -> None:
import json
import dataclasses
routing_info = _make_sample_routing_info()
# Build the JSON manually to test deserialize independently
json_str = json.dumps(
{
json.dumps(dataclasses.asdict(dm)): [
json.dumps(dataclasses.asdict(r)) for r in routes
]
for dm, routes in routing_info.items()
}
)
result = RouteSerializer.deserialize(json_str)
assert len(result) == 1
dest_mode = list(result.keys())[0]
assert dest_mode.destination_address == "London Bridge"
assert result[dest_mode][0].duration_s == 600

View file

@ -0,0 +1,67 @@
"""Unit tests for rec/routing.py."""
import os
from unittest.mock import patch, MagicMock
import pytest
from rec.routing import TravelMode, transit_route, ROUTES_API_URL, ROUTES_FIELD_MASK
from rec.exceptions import RoutingApiError
class TestTravelMode:
def test_enum_values(self) -> None:
assert TravelMode.TRANSIT == "TRANSIT"
assert TravelMode.BICYCLE == "BICYCLE"
assert TravelMode.WALK == "WALK"
assert TravelMode.DRIVE == "DRIVE"
def test_enum_has_four_members(self) -> None:
assert len(TravelMode) == 4
class TestTransitRoute:
@patch("rec.routing.requests.post")
@patch("rec.routing.nextMonday")
def test_success_response(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
mock_monday.return_value = MagicMock(
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
)
expected = {"routes": [{"duration": "600s", "distanceMeters": 5000}]}
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = expected
mock_post.return_value = mock_response
with patch.dict(os.environ, {"ROUTING_API_KEY": "test-key"}):
result = transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
assert result == expected
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
assert call_kwargs.kwargs["headers"]["X-Goog-Api-Key"] == "test-key"
@patch("rec.routing.requests.post")
@patch("rec.routing.nextMonday")
def test_raises_routing_api_error_on_non_200(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
mock_monday.return_value = MagicMock(
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
)
error_body = {"error": {"message": "Invalid API key", "status": "PERMISSION_DENIED"}}
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.json.return_value = error_body
mock_post.return_value = mock_response
with patch.dict(os.environ, {"ROUTING_API_KEY": "bad-key"}):
with pytest.raises(RoutingApiError) as exc_info:
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
assert exc_info.value.status_code == 403
assert exc_info.value.response_body == error_body
def test_raises_key_error_when_api_key_not_set(self) -> None:
env = os.environ.copy()
env.pop("ROUTING_API_KEY", None)
with patch.dict(os.environ, env, clear=True):
with pytest.raises(KeyError):
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)

View file

@ -0,0 +1,306 @@
"""Unit tests for services/task_service.py."""
from unittest.mock import MagicMock, patch
import pytest
from services.task_service import (
TaskStatus,
_extract_progress_info,
_extract_result,
_make_system_user,
get_task_status,
)
class TestMakeSystemUser:
"""Tests for _make_system_user helper."""
def test_creates_user_with_email(self) -> None:
user = _make_system_user("test@example.com")
assert user.email == "test@example.com"
assert user.sub == ""
assert user.name == ""
def test_different_emails_create_different_users(self) -> None:
u1 = _make_system_user("a@b.com")
u2 = _make_system_user("c@d.com")
assert u1.email != u2.email
class TestExtractResult:
"""Tests for _extract_result helper."""
def test_failed_task_returns_error(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = True
mock_result.result = Exception("something broke")
result, error = _extract_result(mock_result)
assert result is None
assert error is not None
assert "something broke" in error
def test_failed_task_with_no_result(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = True
mock_result.result = None
result, error = _extract_result(mock_result)
assert result is None
assert error is None
def test_successful_json_serializable_result(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = False
mock_result.result = {"count": 42, "status": "done"}
result, error = _extract_result(mock_result)
assert result == {"count": 42, "status": "done"}
assert error is None
def test_non_serializable_result_falls_back_to_str(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = False
mock_result.result = object() # not JSON-serializable
result, error = _extract_result(mock_result)
assert isinstance(result, str)
assert error is None
def test_none_result_stays_none(self) -> None:
mock_result = MagicMock()
mock_result.failed.return_value = False
mock_result.result = None
result, error = _extract_result(mock_result)
assert result is None
assert error is None
class TestExtractProgressInfo:
"""Tests for _extract_progress_info helper."""
def test_extracts_progress_fields(self) -> None:
mock_result = MagicMock()
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
mock_result.status = "STARTED"
info = _extract_progress_info(mock_result)
assert info["progress"] == 0.5
assert info["processed"] == 50
assert info["total"] == 100
assert info["message"] is None
def test_extracts_message_from_info(self) -> None:
mock_result = MagicMock()
mock_result.info = {"message": "Processing page 3"}
mock_result.status = "STARTED"
info = _extract_progress_info(mock_result)
assert info["message"] == "Processing page 3"
def test_falls_back_to_reason_for_skipped(self) -> None:
mock_result = MagicMock()
mock_result.info = {"reason": "Already running"}
mock_result.status = "SKIPPED"
info = _extract_progress_info(mock_result)
assert info["message"] == "Already running"
def test_custom_state_used_as_message(self) -> None:
mock_result = MagicMock()
mock_result.info = {}
mock_result.status = "Fetching listings"
info = _extract_progress_info(mock_result)
assert info["message"] == "Fetching listings"
def test_standard_state_not_used_as_message(self) -> None:
mock_result = MagicMock()
mock_result.info = {}
mock_result.status = "PENDING"
info = _extract_progress_info(mock_result)
assert info["message"] is None
def test_none_info_returns_all_none(self) -> None:
mock_result = MagicMock()
mock_result.info = None
mock_result.status = "PENDING"
info = _extract_progress_info(mock_result)
assert info == {"progress": None, "processed": None, "total": None, "message": None}
class TestGetTaskStatus:
"""Tests for get_task_status."""
def test_pending_task(self) -> None:
"""Test status for a pending task."""
mock_result = MagicMock()
mock_result.status = "PENDING"
mock_result.failed.return_value = False
mock_result.result = None
mock_result.info = None
mock_result.traceback = None
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
# Patch the lazy import
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.task_id == "test-id"
assert status.status == "PENDING"
assert status.error is None
def test_failed_task(self) -> None:
"""Test status for a failed task."""
mock_result = MagicMock()
mock_result.status = "FAILURE"
mock_result.failed.return_value = True
mock_result.result = Exception("something broke")
mock_result.info = None
mock_result.traceback = "Traceback..."
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.status == "FAILURE"
assert status.error is not None
assert status.traceback == "Traceback..."
def test_custom_state_with_progress(self) -> None:
"""Test that custom states with progress info are extracted correctly."""
mock_result = MagicMock()
mock_result.status = "Fetching listings"
mock_result.failed.return_value = False
mock_result.result = None
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
mock_result.traceback = None
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.progress == 0.5
assert status.processed == 50
assert status.total == 100
def test_successful_task(self) -> None:
"""Test status for a successful task."""
mock_result = MagicMock()
mock_result.status = "SUCCESS"
mock_result.failed.return_value = False
mock_result.result = {"listings_count": 42}
mock_result.info = None
mock_result.traceback = None
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
mock_task.AsyncResult.return_value = mock_result
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
status = get_task_status("test-id")
assert status.status == "SUCCESS"
assert status.result == {"listings_count": 42}
assert status.error is None
class TestGetUserTasks:
"""Tests for get_user_tasks."""
def test_returns_task_list(self) -> None:
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = ["task-1", "task-2"]
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
MockRedisRepo.instance.return_value = mock_redis
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
from services.task_service import get_user_tasks
result = get_user_tasks("test@example.com")
assert result == ["task-1", "task-2"]
def test_returns_empty_list_for_unknown_user(self) -> None:
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = []
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
MockRedisRepo.instance.return_value = mock_redis
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
from services.task_service import get_user_tasks
result = get_user_tasks("nobody@example.com")
assert result == []
class TestCancelTask:
"""Tests for cancel_task."""
def test_cancel_revokes_and_removes(self) -> None:
mock_celery = MagicMock()
mock_redis = MagicMock()
mock_redis.remove_task_for_user.return_value = True
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import cancel_task
result = cancel_task("task-123", user_email="test@example.com")
assert result is True
mock_celery.control.revoke.assert_called_once_with("task-123", terminate=True)
def test_cancel_without_user_email(self) -> None:
mock_celery = MagicMock()
with patch.dict("sys.modules", {"celery_app": MagicMock(app=mock_celery)}):
from services.task_service import cancel_task
result = cancel_task("task-456")
assert result is True
mock_celery.control.revoke.assert_called_once_with("task-456", terminate=True)
class TestClearAllTasks:
"""Tests for clear_all_tasks."""
def test_clear_with_revoke(self) -> None:
mock_celery = MagicMock()
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = ["t1", "t2"]
mock_redis.clear_tasks_for_user.return_value = 2
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import clear_all_tasks
count = clear_all_tasks("test@example.com", revoke=True)
assert count == 2
assert mock_celery.control.revoke.call_count == 2
def test_clear_without_revoke(self) -> None:
mock_celery = MagicMock()
mock_redis = MagicMock()
mock_redis.clear_tasks_for_user.return_value = 3
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import clear_all_tasks
count = clear_all_tasks("test@example.com", revoke=False)
assert count == 3
mock_celery.control.revoke.assert_not_called()
def test_revoke_failure_logs_warning_and_continues(self) -> None:
mock_celery = MagicMock()
mock_celery.control.revoke.side_effect = Exception("connection lost")
mock_redis = MagicMock()
mock_redis.get_tasks_for_user.return_value = ["t1"]
mock_redis.clear_tasks_for_user.return_value = 1
with patch.dict("sys.modules", {
"celery_app": MagicMock(app=mock_celery),
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
}):
from services.task_service import clear_all_tasks
# Should not raise despite revoke failure
count = clear_all_tasks("test@example.com", revoke=True)
assert count == 1