Refactor codebase following Clean Code principles and add 229 tests
- Extract helpers to reduce function sizes (listing_tasks, app.py, query.py, listing_fetcher) - Replace nonlocal mutations with _PipelineState dataclass in listing_tasks - Fix bugs: isinstance→equality check in repository, verify_exp for OIDC tokens - Consolidate duplicate filter methods in listing_repository - Move hardcoded config to env vars with backward-compatible defaults - Simplify CLI decorator to auto-build QueryParameters - Add deprecation docstring to data_access.py - Test count: 158 → 387 (all passing)
This commit is contained in:
parent
7e05b3c971
commit
150342bb9e
48 changed files with 5029 additions and 990 deletions
151
crawler/tests/unit/test_auth.py
Normal file
151
crawler/tests/unit/test_auth.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Unit tests for api/auth.py."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
from api.auth import (
|
||||
User,
|
||||
_verify_passkey_token,
|
||||
_verify_authentik_token,
|
||||
get_current_user,
|
||||
)
|
||||
from api.config import JWT_SECRET, JWT_ALGORITHM, JWT_ISSUER
|
||||
|
||||
|
||||
def _make_passkey_token(
|
||||
sub: str = "user-123",
|
||||
email: str = "test@example.com",
|
||||
name: str = "Test User",
|
||||
issuer: str = JWT_ISSUER,
|
||||
secret: str = JWT_SECRET,
|
||||
algorithm: str = JWT_ALGORITHM,
|
||||
expires_delta: timedelta | None = timedelta(hours=1),
|
||||
) -> str:
|
||||
"""Helper to mint a passkey-style HS256 JWT."""
|
||||
payload: dict = {"sub": sub, "email": email, "name": name, "iss": issuer}
|
||||
if expires_delta is not None:
|
||||
payload["exp"] = datetime.now(timezone.utc) + expires_delta
|
||||
return pyjwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
class TestVerifyPasskeyToken:
|
||||
"""Tests for _verify_passkey_token()."""
|
||||
|
||||
def test_valid_token_returns_user(self) -> None:
|
||||
token = _make_passkey_token()
|
||||
user = _verify_passkey_token(token)
|
||||
assert isinstance(user, User)
|
||||
assert user.sub == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
assert user.name == "Test User"
|
||||
|
||||
def test_valid_token_without_name_uses_email(self) -> None:
|
||||
payload = {
|
||||
"sub": "user-456",
|
||||
"email": "noname@example.com",
|
||||
"iss": JWT_ISSUER,
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
user = _verify_passkey_token(token)
|
||||
assert user.name == "noname@example.com"
|
||||
|
||||
def test_rejects_expired_token(self) -> None:
|
||||
token = _make_passkey_token(expires_delta=timedelta(hours=-1))
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
_verify_passkey_token(token)
|
||||
|
||||
def test_rejects_wrong_secret(self) -> None:
|
||||
token = _make_passkey_token(secret="wrong-secret")
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
_verify_passkey_token(token)
|
||||
|
||||
def test_rejects_wrong_issuer(self) -> None:
|
||||
token = _make_passkey_token(issuer="some-other-issuer")
|
||||
with pytest.raises(pyjwt.InvalidIssuerError):
|
||||
_verify_passkey_token(token)
|
||||
|
||||
|
||||
class TestVerifyAuthentikToken:
|
||||
"""Tests for _verify_authentik_token() — specifically that expiration is verified."""
|
||||
|
||||
async def test_verifies_expiration_after_fix(self) -> None:
|
||||
"""After removing verify_exp: False, expired Authentik tokens should be rejected."""
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
public_key = private_key.public_key()
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
|
||||
payload = {
|
||||
"sub": "authentik-user",
|
||||
"email": "auth@example.com",
|
||||
"name": "Auth User",
|
||||
"iss": issuer,
|
||||
"aud": "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe",
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1), # expired
|
||||
}
|
||||
token = pyjwt.encode(payload, private_key, algorithm="RS256")
|
||||
|
||||
# Build a real PyJWK-compatible signing key mock so jwt.decode
|
||||
# takes the PyJWK code path (uses key.key directly, skips prepare_key)
|
||||
mock_signing_key = MagicMock(spec=pyjwt.PyJWK)
|
||||
mock_signing_key.key = public_key
|
||||
mock_signing_key.algorithm_name = "RS256"
|
||||
mock_signing_key.Algorithm = pyjwt.get_algorithm_by_name("RS256")
|
||||
|
||||
mock_jwks_client = MagicMock()
|
||||
mock_jwks_client.get_signing_key_from_jwt.return_value = mock_signing_key
|
||||
|
||||
mock_metadata = {
|
||||
"issuer": issuer,
|
||||
"jwks_uri": f"{issuer}jwks/",
|
||||
}
|
||||
|
||||
with patch("api.auth.get_oidc_metadata", new_callable=AsyncMock, return_value=mock_metadata), \
|
||||
patch("api.auth.get_cached_jwks_client", new_callable=AsyncMock, return_value=mock_jwks_client):
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
await _verify_authentik_token(token)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for get_current_user()."""
|
||||
|
||||
async def test_routes_to_passkey_verifier_for_matching_issuer(self) -> None:
|
||||
token = _make_passkey_token()
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
user = await get_current_user(credentials)
|
||||
assert user.sub == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
async def test_routes_to_authentik_for_other_issuer(self) -> None:
|
||||
"""When issuer != JWT_ISSUER, should route to Authentik verifier."""
|
||||
token = _make_passkey_token(issuer="https://authentik.viktorbarzin.me/application/o/wrongmove/")
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
mock_user = User(sub="authentik-user", email="auth@example.com", name="Auth User")
|
||||
with patch("api.auth._verify_authentik_token", new_callable=AsyncMock, return_value=mock_user):
|
||||
user = await get_current_user(credentials)
|
||||
assert user.email == "auth@example.com"
|
||||
|
||||
async def test_raises_http_exception_for_invalid_token(self) -> None:
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="not.a.valid.token")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
async def test_raises_http_exception_for_garbage_token(self) -> None:
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="totalgarbage")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
388
crawler/tests/unit/test_cli.py
Normal file
388
crawler/tests/unit/test_cli.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
"""Characterization and unit tests for the CLI (main.py)."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from models.listing import FurnishType, ListingType, QueryParameters
|
||||
from main import build_query_parameters, cli, listing_filter_options
|
||||
|
||||
|
||||
class TestBuildQueryParameters:
|
||||
"""Tests for build_query_parameters()."""
|
||||
|
||||
def test_typical_rent_inputs(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London", "Camden"],
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=4,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
furnish_types=["FURNISHED"],
|
||||
available_from=datetime(2025, 6, 1),
|
||||
last_seen_days=7,
|
||||
min_sqm=50,
|
||||
)
|
||||
assert qp.listing_type == ListingType.RENT
|
||||
assert qp.district_names == {"London", "Camden"}
|
||||
assert qp.min_bedrooms == 2
|
||||
assert qp.max_bedrooms == 4
|
||||
assert qp.min_price == 1000
|
||||
assert qp.max_price == 3000
|
||||
assert qp.furnish_types == [FurnishType.FURNISHED]
|
||||
assert qp.let_date_available_from == datetime(2025, 6, 1)
|
||||
assert qp.last_seen_days == 7
|
||||
assert qp.min_sqm == 50
|
||||
|
||||
def test_typical_buy_inputs(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="BUY",
|
||||
district=["Barnet"],
|
||||
min_bedrooms=3,
|
||||
max_bedrooms=5,
|
||||
min_price=200000,
|
||||
max_price=500000,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.listing_type == ListingType.BUY
|
||||
assert qp.district_names == {"Barnet"}
|
||||
assert qp.furnish_types is None
|
||||
assert qp.let_date_available_from is None
|
||||
assert qp.min_sqm is None
|
||||
|
||||
def test_empty_districts_yields_empty_set(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=[],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.district_names == set()
|
||||
|
||||
def test_none_districts_yields_empty_set(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=None,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.district_names == set()
|
||||
|
||||
def test_furnish_types_conversion(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London"],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=["FURNISHED", "UNFURNISHED"],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.furnish_types == [FurnishType.FURNISHED, FurnishType.UNFURNISHED]
|
||||
|
||||
def test_empty_furnish_types_yields_none(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London"],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.furnish_types is None
|
||||
|
||||
def test_default_optional_parameters(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London"],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.radius == 0
|
||||
assert qp.page_size == 500
|
||||
assert qp.max_days_since_added == 14
|
||||
|
||||
|
||||
class TestListingFilterOptionsDecorator:
|
||||
"""Tests for the listing_filter_options decorator."""
|
||||
|
||||
def test_applies_all_expected_options(self) -> None:
|
||||
@click.command()
|
||||
@listing_filter_options
|
||||
def dummy_cmd(**kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
expected_option_names = {
|
||||
"type",
|
||||
"min_bedrooms",
|
||||
"max_bedrooms",
|
||||
"min_price",
|
||||
"max_price",
|
||||
"district",
|
||||
"furnish_types",
|
||||
"available_from",
|
||||
"last_seen_days",
|
||||
"min_sqm",
|
||||
}
|
||||
param_names = {p.name for p in dummy_cmd.params}
|
||||
assert expected_option_names.issubset(param_names), (
|
||||
f"Missing options: {expected_option_names - param_names}"
|
||||
)
|
||||
|
||||
def test_type_option_is_required(self) -> None:
|
||||
@click.command()
|
||||
@listing_filter_options
|
||||
def dummy_cmd(**kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
type_param = next(p for p in dummy_cmd.params if p.name == "type")
|
||||
assert type_param.required is True
|
||||
|
||||
def test_produces_query_parameters_kwarg(self) -> None:
|
||||
"""After refactoring, the decorator should produce a query_parameters kwarg."""
|
||||
captured: dict = {}
|
||||
|
||||
@click.command()
|
||||
@listing_filter_options
|
||||
def dummy_cmd(query_parameters: QueryParameters) -> None:
|
||||
captured["qp"] = query_parameters
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(dummy_cmd, ["--type", "RENT"])
|
||||
assert result.exit_code == 0, f"Command failed: {result.output}"
|
||||
assert isinstance(captured["qp"], QueryParameters)
|
||||
assert captured["qp"].listing_type == ListingType.RENT
|
||||
|
||||
|
||||
class TestDumpListingsCommand:
|
||||
"""Tests for the dump-listings CLI command."""
|
||||
|
||||
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_calls_refresh_listings_with_correct_params(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_refresh: AsyncMock,
|
||||
) -> None:
|
||||
from services.listing_service import RefreshResult
|
||||
|
||||
mock_refresh.return_value = RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=5,
|
||||
message="Fetched 5 new listings",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"dump-listings",
|
||||
"--type", "RENT",
|
||||
"--min-bedrooms", "2",
|
||||
"--max-bedrooms", "4",
|
||||
"--min-price", "1000",
|
||||
"--max-price", "3000",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_refresh.assert_called_once()
|
||||
call_args = mock_refresh.call_args
|
||||
qp: QueryParameters = call_args.args[1]
|
||||
assert qp.listing_type == ListingType.RENT
|
||||
assert qp.min_bedrooms == 2
|
||||
assert qp.max_bedrooms == 4
|
||||
assert qp.min_price == 1000
|
||||
assert qp.max_price == 3000
|
||||
assert call_args.kwargs.get("full") is not True
|
||||
|
||||
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_include_processing_flag_passes_full_true(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_refresh: AsyncMock,
|
||||
) -> None:
|
||||
from services.listing_service import RefreshResult
|
||||
|
||||
mock_refresh.return_value = RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=0,
|
||||
message="Fetched 0 new listings",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"dump-listings",
|
||||
"--type", "RENT",
|
||||
"--include-processing",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_refresh.assert_called_once()
|
||||
call_kwargs = mock_refresh.call_args.kwargs
|
||||
assert call_kwargs.get("full") is True
|
||||
|
||||
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_include_processing_short_flag(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_refresh: AsyncMock,
|
||||
) -> None:
|
||||
from services.listing_service import RefreshResult
|
||||
|
||||
mock_refresh.return_value = RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=0,
|
||||
message="Fetched 0 new listings",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"dump-listings",
|
||||
"--type", "RENT",
|
||||
"-p",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_refresh.assert_called_once()
|
||||
call_kwargs = mock_refresh.call_args.kwargs
|
||||
assert call_kwargs.get("full") is True
|
||||
|
||||
|
||||
class TestExportCsvCommand:
|
||||
"""Tests for the export-csv CLI command."""
|
||||
|
||||
@patch("main.export_service.export_to_csv", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_calls_export_to_csv(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_export: AsyncMock,
|
||||
) -> None:
|
||||
from services.export_service import ExportResult
|
||||
|
||||
mock_export.return_value = ExportResult(
|
||||
success=True,
|
||||
output_path="/tmp/test.csv",
|
||||
data=None,
|
||||
record_count=10,
|
||||
message="Exported 10 listings to /tmp/test.csv",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"export-csv",
|
||||
"--output-file", "/tmp/test.csv",
|
||||
"--type", "RENT",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_export.assert_called_once()
|
||||
call_args = mock_export.call_args
|
||||
qp = call_args[0][2]
|
||||
assert qp.listing_type == ListingType.RENT
|
||||
|
||||
|
||||
class TestExportImmowebCommand:
|
||||
"""Tests for the export-immoweb CLI command."""
|
||||
|
||||
@patch("main.export_service.export_to_geojson", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_calls_export_to_geojson(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_export: AsyncMock,
|
||||
) -> None:
|
||||
from services.export_service import ExportResult
|
||||
|
||||
mock_export.return_value = ExportResult(
|
||||
success=True,
|
||||
output_path="/tmp/test.geojson",
|
||||
data=None,
|
||||
record_count=5,
|
||||
message="Exported 5 listings to /tmp/test.geojson",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"export-immoweb",
|
||||
"--output-file", "/tmp/test.geojson",
|
||||
"--type", "RENT",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_export.assert_called_once()
|
||||
|
||||
|
||||
class TestListDistrictsCommand:
|
||||
"""Tests for the list-districts CLI command."""
|
||||
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_outputs_district_names(self, mock_engine: MagicMock) -> None:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["list-districts"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "London" in result.output
|
||||
assert "Camden" in result.output
|
||||
assert "Available districts" in result.output
|
||||
|
||||
|
||||
class TestRoutingCommand:
|
||||
"""Tests for the routing CLI command."""
|
||||
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_requires_api_key_env_var(self, mock_engine: MagicMock) -> None:
|
||||
runner = CliRunner(env={"ROUTING_API_KEY": None})
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"routing",
|
||||
"--destination-address", "London Bridge",
|
||||
"--travel-mode", "transit",
|
||||
"--limit", "1",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "ROUTING_API_KEY" in result.output
|
||||
62
crawler/tests/unit/test_districts.py
Normal file
62
crawler/tests/unit/test_districts.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Unit tests for rec/districts.py and services/district_service.py."""
|
||||
from rec.districts import get_districts, get_district_by_name
|
||||
from services.district_service import get_all_districts, get_district_names, validate_districts
|
||||
|
||||
|
||||
class TestGetDistricts:
|
||||
def test_returns_non_empty_dict(self) -> None:
|
||||
districts = get_districts()
|
||||
assert isinstance(districts, dict)
|
||||
assert len(districts) > 0
|
||||
|
||||
def test_values_start_with_region_prefix(self) -> None:
|
||||
for name, region_id in get_districts().items():
|
||||
assert region_id.startswith("REGION^"), (
|
||||
f"District '{name}' has value '{region_id}' that doesn't start with REGION^"
|
||||
)
|
||||
|
||||
def test_contains_expected_london_boroughs(self) -> None:
|
||||
districts = get_districts()
|
||||
for borough in ("Camden", "Westminster", "Hackney"):
|
||||
assert borough in districts, f"Expected borough '{borough}' not found"
|
||||
|
||||
|
||||
class TestGetDistrictByName:
|
||||
def test_valid_name_returns_region_id(self) -> None:
|
||||
result = get_district_by_name("Camden")
|
||||
assert result == "REGION^93941"
|
||||
|
||||
def test_invalid_name_returns_none(self) -> None:
|
||||
result = get_district_by_name("Nonexistent District")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetDistrictNames:
|
||||
def test_returns_list_matching_dict_keys(self) -> None:
|
||||
names = get_district_names()
|
||||
assert isinstance(names, list)
|
||||
assert names == list(get_districts().keys())
|
||||
|
||||
|
||||
class TestGetAllDistricts:
|
||||
def test_returns_same_as_get_districts(self) -> None:
|
||||
assert get_all_districts() == get_districts()
|
||||
|
||||
|
||||
class TestValidateDistricts:
|
||||
def test_all_valid_returns_empty_list(self) -> None:
|
||||
result = validate_districts(["Camden", "Westminster", "Hackney"])
|
||||
assert result == []
|
||||
|
||||
def test_some_invalid_returns_invalid_ones(self) -> None:
|
||||
result = validate_districts(["Camden", "Faketown", "Westminster", "Nowhere"])
|
||||
assert result == ["Faketown", "Nowhere"]
|
||||
|
||||
def test_all_invalid_returns_all(self) -> None:
|
||||
invalid = ["Faketown", "Nowhere", "Neverland"]
|
||||
result = validate_districts(invalid)
|
||||
assert result == invalid
|
||||
|
||||
def test_empty_list_returns_empty_list(self) -> None:
|
||||
result = validate_districts([])
|
||||
assert result == []
|
||||
104
crawler/tests/unit/test_floorplan.py
Normal file
104
crawler/tests/unit/test_floorplan.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Unit tests for rec/floorplan.py."""
|
||||
from unittest.mock import patch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pytest
|
||||
|
||||
from rec.floorplan import extract_total_sqm, improve_img_for_ocr, calculate_ocr
|
||||
|
||||
|
||||
class TestExtractTotalSqm:
|
||||
|
||||
def test_normal_value(self) -> None:
|
||||
assert extract_total_sqm("Total area: 75.5 sq m") == 75.5
|
||||
|
||||
def test_multiple_values_returns_max_in_range(self) -> None:
|
||||
assert extract_total_sqm("Room 1: 20 sqm, Total: 65 sq m") == 65.0
|
||||
|
||||
def test_no_match_returns_none(self) -> None:
|
||||
assert extract_total_sqm("No area info") is None
|
||||
|
||||
def test_below_minimum_returns_none(self) -> None:
|
||||
assert extract_total_sqm("Area: 15 sq m") is None
|
||||
|
||||
def test_above_maximum_returns_none(self) -> None:
|
||||
assert extract_total_sqm("Area: 200 sq m") is None
|
||||
|
||||
def test_edge_just_above_min(self) -> None:
|
||||
assert extract_total_sqm("Area: 30.1 sq m") == 30.1
|
||||
|
||||
def test_edge_just_below_max(self) -> None:
|
||||
assert extract_total_sqm("Area: 159.9 sq m") == 159.9
|
||||
|
||||
def test_exactly_at_min_boundary_returns_none(self) -> None:
|
||||
# MIN_SQM < sqm, so 30 is not strictly greater than 30
|
||||
assert extract_total_sqm("Area: 30 sq m") is None
|
||||
|
||||
def test_exactly_at_max_boundary_returns_none(self) -> None:
|
||||
# sqm < MAX_SQM, so 160 is not strictly less than 160
|
||||
assert extract_total_sqm("Area: 160 sq m") is None
|
||||
|
||||
def test_format_sq_dot_m(self) -> None:
|
||||
assert extract_total_sqm("Area: 80 sq. m") == 80.0
|
||||
|
||||
def test_format_sqm_no_space(self) -> None:
|
||||
assert extract_total_sqm("Area: 80sqm") == 80.0
|
||||
|
||||
def test_format_sq_m_with_space(self) -> None:
|
||||
assert extract_total_sqm("Area: 80 sq m") == 80.0
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert extract_total_sqm("") is None
|
||||
|
||||
def test_multiple_valid_values_returns_max(self) -> None:
|
||||
assert extract_total_sqm("Living: 40 sq m, Total: 100 sq m") == 100.0
|
||||
|
||||
|
||||
class TestImproveImgForOcr:
|
||||
|
||||
def test_produces_valid_pil_image(self) -> None:
|
||||
# Create a small test image (50x50 white image)
|
||||
img = Image.fromarray(np.ones((50, 50, 3), dtype=np.uint8) * 200)
|
||||
result = improve_img_for_ocr(img)
|
||||
assert isinstance(result, Image.Image)
|
||||
# Result should be a grayscale (thresholded) image
|
||||
assert result.mode == "L"
|
||||
|
||||
def test_output_dimensions_scaled(self) -> None:
|
||||
img = Image.fromarray(np.ones((100, 100, 3), dtype=np.uint8) * 128)
|
||||
result = improve_img_for_ocr(img)
|
||||
# After 1.2x resize, 100 -> 120
|
||||
assert result.size[0] == 120
|
||||
assert result.size[1] == 120
|
||||
|
||||
|
||||
class TestCalculateOcr:
|
||||
|
||||
def test_invalid_path_raises_file_not_found(self) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
calculate_ocr("/nonexistent/path/to/image.png")
|
||||
|
||||
def test_returns_sqm_from_first_pass(self, tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
# Create a real image file so the path check passes
|
||||
image_file = tmp_path / "test.png"
|
||||
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
|
||||
|
||||
with patch("pytesseract.image_to_string", return_value="Total: 85 sq m"):
|
||||
result_sqm, result_text = calculate_ocr(str(image_file))
|
||||
|
||||
assert result_sqm == 85.0
|
||||
assert result_text == "Total: 85 sq m"
|
||||
|
||||
def test_falls_back_to_improved_image(self, tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
image_file = tmp_path / "test.png"
|
||||
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
|
||||
|
||||
# First call returns no sqm data, second (on improved image) returns valid data
|
||||
with patch("pytesseract.image_to_string", side_effect=[
|
||||
"No area info here",
|
||||
"Total: 72 sq m",
|
||||
]):
|
||||
result_sqm, result_text = calculate_ocr(str(image_file))
|
||||
|
||||
assert result_sqm == 72.0
|
||||
assert result_text == "Total: 72 sq m"
|
||||
110
crawler/tests/unit/test_floorplan_detector.py
Normal file
110
crawler/tests/unit/test_floorplan_detector.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Unit tests for services/floorplan_detector.py."""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from models.listing import RentListing, ListingSite, FurnishType
|
||||
from services.floorplan_detector import _calculate_sqm_ocr, detect_floorplan
|
||||
|
||||
|
||||
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
|
||||
defaults = dict(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=None,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return RentListing(**defaults)
|
||||
|
||||
|
||||
class TestCalculateSqmOcr:
|
||||
|
||||
async def test_skips_listing_with_existing_square_meters(self) -> None:
|
||||
listing = _make_listing(square_meters=50.0)
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is None
|
||||
|
||||
async def test_empty_floorplan_paths_returns_listing_with_zero(self) -> None:
|
||||
listing = _make_listing(floorplan_image_paths=[])
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 0
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_with_mocked_ocr_returning_value(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.return_value = (85.0, "Total: 85 sq m")
|
||||
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 85.0
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_with_mocked_ocr_returning_none(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.return_value = (None, "no data")
|
||||
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 0
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_picks_max_from_multiple_floorplans(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.side_effect = [
|
||||
(50.0, "50 sq m"),
|
||||
(90.0, "90 sq m"),
|
||||
]
|
||||
listing = _make_listing(floorplan_image_paths=["/fake/a.png", "/fake/b.png"])
|
||||
semaphore = asyncio.Semaphore(2)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 90.0
|
||||
|
||||
|
||||
class TestDetectFloorplan:
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_detect_floorplan_with_mocked_repository(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.return_value = (75.0, "75 sq m")
|
||||
|
||||
listing = _make_listing(
|
||||
floorplan_image_paths=["/fake/path.png"],
|
||||
)
|
||||
repository = MagicMock()
|
||||
repository.get_listings = AsyncMock(return_value=[listing])
|
||||
repository.upsert_listings = AsyncMock(return_value=[])
|
||||
|
||||
await detect_floorplan(repository)
|
||||
|
||||
repository.upsert_listings.assert_called_once()
|
||||
upserted = repository.upsert_listings.call_args[0][0]
|
||||
assert len(upserted) == 1
|
||||
assert upserted[0].square_meters == 75.0
|
||||
|
||||
async def test_detect_floorplan_skips_already_processed(self) -> None:
|
||||
listing = _make_listing(square_meters=50.0)
|
||||
repository = MagicMock()
|
||||
repository.get_listings = AsyncMock(return_value=[listing])
|
||||
repository.upsert_listings = AsyncMock(return_value=[])
|
||||
|
||||
await detect_floorplan(repository)
|
||||
|
||||
repository.upsert_listings.assert_called_once()
|
||||
upserted = repository.upsert_listings.call_args[0][0]
|
||||
assert len(upserted) == 0
|
||||
215
crawler/tests/unit/test_image_fetcher.py
Normal file
215
crawler/tests/unit/test_image_fetcher.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"""Unit tests for the image fetcher service."""
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from tenacity import stop_after_attempt
|
||||
|
||||
from models.listing import RentListing, ListingSite, FurnishType
|
||||
from services.image_fetcher import dump_images_for_listing, MAX_CONCURRENT_DOWNLOADS
|
||||
|
||||
|
||||
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
|
||||
"""Create a RentListing with sensible defaults for testing."""
|
||||
defaults = dict(
|
||||
id=12345,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=None,
|
||||
agency="Test Agency",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={
|
||||
"property": {
|
||||
"visible": True,
|
||||
"floorplans": [
|
||||
{"url": "https://media.rightmove.co.uk/imgs/floorplan_1.jpg"}
|
||||
],
|
||||
}
|
||||
},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return RentListing(**defaults)
|
||||
|
||||
|
||||
class TestDumpImagesForListing:
|
||||
"""Tests for dump_images_for_listing function."""
|
||||
|
||||
async def test_downloads_floorplan_image(self, tmp_path: Path) -> None:
|
||||
"""Test successful floorplan image download."""
|
||||
listing = _make_listing()
|
||||
image_bytes = b"\x89PNG fake image data"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == 12345
|
||||
assert len(result.floorplan_image_paths) == 1
|
||||
# Verify the image was written
|
||||
written_path = Path(result.floorplan_image_paths[0])
|
||||
assert written_path.exists()
|
||||
assert written_path.read_bytes() == image_bytes
|
||||
|
||||
async def test_skips_existing_images(self, tmp_path: Path) -> None:
|
||||
"""Test that existing images are not re-downloaded."""
|
||||
listing = _make_listing()
|
||||
# Pre-create the floorplan file
|
||||
floorplan_dir = tmp_path / str(listing.id) / "floorplans"
|
||||
floorplan_dir.mkdir(parents=True)
|
||||
existing_file = floorplan_dir / "floorplan_1.jpg"
|
||||
existing_file.write_bytes(b"existing image")
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
# Should return None because the only floorplan was skipped (continue)
|
||||
assert result is None
|
||||
# Session.get should NOT have been called
|
||||
mock_session.get.assert_not_called()
|
||||
|
||||
async def test_returns_none_on_404(self, tmp_path: Path) -> None:
|
||||
"""Test that 404 responses return None (image not found)."""
|
||||
listing = _make_listing()
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 404
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_raises_on_non_200_status(self, tmp_path: Path) -> None:
|
||||
"""Test that non-200/404 status raises exception."""
|
||||
listing = _make_listing()
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 500
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
# Disable tenacity retry for testing: stop after 1 attempt and reraise
|
||||
await dump_images_for_listing.retry_with(
|
||||
stop=stop_after_attempt(1),
|
||||
reraise=True,
|
||||
)(listing, tmp_path, session=mock_session)
|
||||
|
||||
async def test_returns_none_when_no_floorplans(self, tmp_path: Path) -> None:
|
||||
"""Test listing with no floorplans returns None."""
|
||||
listing = _make_listing(
|
||||
additional_info={"property": {"visible": True, "floorplans": []}}
|
||||
)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_url_filename_extraction(self, tmp_path: Path) -> None:
|
||||
"""Test that filenames are correctly extracted from URLs."""
|
||||
listing = _make_listing(
|
||||
additional_info={
|
||||
"property": {
|
||||
"visible": True,
|
||||
"floorplans": [
|
||||
{
|
||||
"url": "https://media.rightmove.co.uk/dir/sub/my_floorplan.png"
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
image_bytes = b"fake png"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
written_path = Path(result.floorplan_image_paths[0])
|
||||
assert written_path.name == "my_floorplan.png"
|
||||
|
||||
async def test_creates_session_when_none_provided(self, tmp_path: Path) -> None:
|
||||
"""Test that a session is created and closed when none is provided."""
|
||||
listing = _make_listing()
|
||||
image_bytes = b"fake image"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session_instance = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session_instance.get = MagicMock(return_value=mock_cm)
|
||||
mock_session_instance.close = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"services.image_fetcher.aiohttp.ClientSession",
|
||||
return_value=mock_session_instance,
|
||||
):
|
||||
result = await dump_images_for_listing(listing, tmp_path, session=None)
|
||||
|
||||
assert result is not None
|
||||
mock_session_instance.close.assert_awaited_once()
|
||||
|
||||
|
||||
class TestImageFetcherConfig:
|
||||
"""Tests for image fetcher configuration."""
|
||||
|
||||
def test_max_concurrent_downloads_constant(self) -> None:
|
||||
"""Test that MAX_CONCURRENT_DOWNLOADS is defined and reasonable."""
|
||||
assert MAX_CONCURRENT_DOWNLOADS > 0
|
||||
assert MAX_CONCURRENT_DOWNLOADS <= 20
|
||||
225
crawler/tests/unit/test_listing_cache.py
Normal file
225
crawler/tests/unit/test_listing_cache.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Unit tests for services/listing_cache.py."""
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from services.listing_cache import (
|
||||
CACHE_PREFIX,
|
||||
_get_redis_client,
|
||||
cache_features_batch,
|
||||
get_cached_count,
|
||||
get_cached_features,
|
||||
invalidate_cache,
|
||||
make_cache_key,
|
||||
)
|
||||
|
||||
|
||||
def _make_query(**overrides) -> QueryParameters:
|
||||
"""Create a QueryParameters with defaults for testing."""
|
||||
defaults = {"listing_type": ListingType.RENT, "min_price": 1000, "max_price": 3000}
|
||||
defaults.update(overrides)
|
||||
return QueryParameters(**defaults)
|
||||
|
||||
|
||||
class TestMakeCacheKey:
|
||||
"""Tests for make_cache_key()."""
|
||||
|
||||
def test_deterministic_for_same_params(self):
|
||||
"""Same parameters produce the same cache key."""
|
||||
qp = _make_query()
|
||||
assert make_cache_key(qp) == make_cache_key(qp)
|
||||
|
||||
def test_different_for_different_params(self):
|
||||
"""Different parameters produce different cache keys."""
|
||||
qp1 = _make_query(min_price=1000)
|
||||
qp2 = _make_query(min_price=2000)
|
||||
assert make_cache_key(qp1) != make_cache_key(qp2)
|
||||
|
||||
def test_key_starts_with_prefix(self):
|
||||
"""Cache key starts with CACHE_PREFIX."""
|
||||
qp = _make_query()
|
||||
assert make_cache_key(qp).startswith(CACHE_PREFIX)
|
||||
|
||||
|
||||
class TestGetRedisClient:
|
||||
"""Tests for _get_redis_client() URL parsing."""
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_default_broker_url(self, mock_redis):
|
||||
"""Uses default localhost URL when env var is not set."""
|
||||
with mock.patch.dict("os.environ", {}, clear=True):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://localhost:6379/2", decode_responses=True
|
||||
)
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_custom_broker_url(self, mock_redis):
|
||||
"""Replaces db number from custom broker URL."""
|
||||
with mock.patch.dict(
|
||||
"os.environ", {"CELERY_BROKER_URL": "redis://myhost:1234/5"}
|
||||
):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://myhost:1234/2", decode_responses=True
|
||||
)
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_broker_url_with_password(self, mock_redis):
|
||||
"""Preserves auth info in broker URL."""
|
||||
with mock.patch.dict(
|
||||
"os.environ",
|
||||
{"CELERY_BROKER_URL": "redis://:secret@myhost:6379/0"},
|
||||
):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://:secret@myhost:6379/2", decode_responses=True
|
||||
)
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_broker_url_with_query_params(self, mock_redis):
|
||||
"""Preserves query parameters in broker URL."""
|
||||
with mock.patch.dict(
|
||||
"os.environ",
|
||||
{"CELERY_BROKER_URL": "redis://myhost:6379/0?timeout=5"},
|
||||
):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://myhost:6379/2?timeout=5", decode_responses=True
|
||||
)
|
||||
|
||||
|
||||
class TestGetCachedCount:
|
||||
"""Tests for get_cached_count()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_returns_none_on_redis_error(self, mock_get_client):
|
||||
"""Returns None when Redis raises an error."""
|
||||
mock_get_client.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
result = get_cached_count(_make_query())
|
||||
assert result is None
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_returns_none_when_key_not_exists(self, mock_get_client):
|
||||
"""Returns None when the cache key does not exist."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.exists.return_value = False
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = get_cached_count(_make_query())
|
||||
assert result is None
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_returns_count_when_key_exists(self, mock_get_client):
|
||||
"""Returns list length when key exists."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.exists.return_value = True
|
||||
mock_client.llen.return_value = 42
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = get_cached_count(_make_query())
|
||||
assert result == 42
|
||||
|
||||
|
||||
class TestGetCachedFeatures:
|
||||
"""Tests for get_cached_features()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_yields_empty_on_redis_error(self, mock_get_client):
|
||||
"""Yields nothing when Redis raises an error."""
|
||||
mock_get_client.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
batches = list(get_cached_features(_make_query()))
|
||||
assert batches == []
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_yields_batches(self, mock_get_client):
|
||||
"""Yields features in batches."""
|
||||
features = [{"type": "Feature", "id": i} for i in range(3)]
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.llen.return_value = 3
|
||||
mock_client.lrange.return_value = [json.dumps(f) for f in features]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
batches = list(get_cached_features(_make_query(), batch_size=50))
|
||||
assert len(batches) == 1
|
||||
assert batches[0] == features
|
||||
|
||||
|
||||
class TestCacheFeaturesBatch:
|
||||
"""Tests for cache_features_batch()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_empty_features_returns_early(self, mock_get_client):
|
||||
"""Does not call Redis when features list is empty."""
|
||||
cache_features_batch(_make_query(), [])
|
||||
mock_get_client.assert_not_called()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_writes_features_via_pipeline(self, mock_get_client):
|
||||
"""Writes features and sets TTL through pipeline."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_pipeline = mock.MagicMock()
|
||||
mock_client.pipeline.return_value = mock_pipeline
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
features = [{"type": "Feature", "id": 1}]
|
||||
cache_features_batch(_make_query(), features)
|
||||
|
||||
mock_pipeline.rpush.assert_called_once()
|
||||
mock_pipeline.expire.assert_called_once()
|
||||
mock_pipeline.execute.assert_called_once()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_handles_redis_error(self, mock_get_client):
|
||||
"""Handles Redis error gracefully during write."""
|
||||
mock_get_client.side_effect = redis.RedisError("write error")
|
||||
|
||||
# Should not raise
|
||||
cache_features_batch(_make_query(), [{"id": 1}])
|
||||
|
||||
|
||||
class TestInvalidateCache:
|
||||
"""Tests for invalidate_cache()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_handles_redis_error(self, mock_get_client):
|
||||
"""Handles Redis error gracefully during invalidation."""
|
||||
mock_get_client.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
# Should not raise
|
||||
invalidate_cache()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_deletes_matching_keys_via_pipeline(self, mock_get_client):
|
||||
"""Deletes keys matching the cache prefix using pipeline."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_pipeline = mock.MagicMock()
|
||||
mock_client.pipeline.return_value = mock_pipeline
|
||||
# Simulate one scan iteration that returns keys, then done
|
||||
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
invalidate_cache()
|
||||
|
||||
assert mock_pipeline.delete.call_count == 2
|
||||
mock_pipeline.execute.assert_called_once()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_no_keys_to_delete(self, mock_get_client):
|
||||
"""Does nothing when no cache keys exist."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.scan.return_value = (0, [])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
invalidate_cache()
|
||||
|
||||
mock_client.pipeline.assert_not_called()
|
||||
372
crawler/tests/unit/test_listing_fetcher.py
Normal file
372
crawler/tests/unit/test_listing_fetcher.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
"""Unit tests for the listing fetcher service."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from services.listing_fetcher import (
|
||||
NUM_WORKERS,
|
||||
_fetch_subquery,
|
||||
dump_listings,
|
||||
dump_listings_full,
|
||||
)
|
||||
from services.query_splitter import SubQuery
|
||||
|
||||
|
||||
def _make_subquery(**kwargs) -> SubQuery:
|
||||
"""Create a SubQuery with sensible defaults for testing."""
|
||||
defaults = dict(
|
||||
district="REGION^123",
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=3,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
estimated_results=50,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return SubQuery(**defaults)
|
||||
|
||||
|
||||
class TestDumpListingsFull:
|
||||
"""Tests for dump_listings_full."""
|
||||
|
||||
async def test_returns_empty_list_when_no_new_listings(self) -> None:
|
||||
"""Test that empty results from dump_listings returns empty list."""
|
||||
with patch(
|
||||
"services.listing_fetcher.dump_listings",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listings = AsyncMock(return_value=[])
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
result = await dump_listings_full(params, mock_repo)
|
||||
assert result == []
|
||||
|
||||
async def test_returns_only_new_listings_from_db(self) -> None:
|
||||
"""Test that dump_listings_full fetches new listings by ID from the repository."""
|
||||
mock_listing_1 = MagicMock()
|
||||
mock_listing_1.id = 100
|
||||
mock_listing_2 = MagicMock()
|
||||
mock_listing_2.id = 200
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.dump_listings",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_listing_1, mock_listing_2],
|
||||
):
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listings = AsyncMock(
|
||||
return_value=[mock_listing_1, mock_listing_2]
|
||||
)
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
result = await dump_listings_full(params, mock_repo)
|
||||
|
||||
# Verify get_listings was called with the correct IDs
|
||||
mock_repo.get_listings.assert_awaited_once_with(
|
||||
only_ids=[100, 200]
|
||||
)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestFetchSubquery:
|
||||
"""Tests for _fetch_subquery helper."""
|
||||
|
||||
async def test_skips_subquery_with_zero_estimated_results(self) -> None:
|
||||
"""Test that subqueries with 0 estimated results are skipped."""
|
||||
sq = _make_subquery(estimated_results=0)
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=MagicMock(),
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_skips_subquery_with_none_estimated_results(self) -> None:
|
||||
"""Test that subqueries with None estimated results are skipped."""
|
||||
sq = _make_subquery(estimated_results=None)
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=MagicMock(),
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_enqueues_new_ids_only(self) -> None:
|
||||
"""Test that only new (not existing) IDs are enqueued."""
|
||||
sq = _make_subquery(estimated_results=10)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
existing_ids: set[int] = {101, 103}
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
mock_config.max_concurrent_requests = 5
|
||||
|
||||
api_result = {
|
||||
"properties": [
|
||||
{"identifier": 101}, # existing
|
||||
{"identifier": 102}, # new
|
||||
{"identifier": 103}, # existing
|
||||
{"identifier": 104}, # new
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
return_value=api_result,
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=existing_ids,
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 2
|
||||
# Verify that queued IDs are the new ones
|
||||
queued = []
|
||||
while not queue.empty():
|
||||
queued.append(queue.get_nowait())
|
||||
assert 102 in queued
|
||||
assert 104 in queued
|
||||
assert 101 not in queued
|
||||
assert 103 not in queued
|
||||
|
||||
async def test_stops_on_circuit_breaker_error(self) -> None:
|
||||
"""Test that CircuitBreakerOpenError breaks the page loop."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=CircuitBreakerOpenError("open"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_on_throttling_error(self) -> None:
|
||||
"""Test that ThrottlingError breaks the page loop."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ThrottlingError("throttled"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_on_generic_error(self) -> None:
|
||||
"""Test that GENERIC_ERROR (past last page) stops pagination."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("GENERIC_ERROR: no more results"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_on_unexpected_error(self) -> None:
|
||||
"""Test that unexpected errors also stop pagination."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("some network error"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_when_fewer_results_than_page_size(self) -> None:
|
||||
"""Test that pagination stops when a page has fewer results than page_size."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
# Return fewer properties than page_size
|
||||
api_result = {
|
||||
"properties": [
|
||||
{"identifier": 1},
|
||||
{"identifier": 2},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
return_value=api_result,
|
||||
) as mock_query:
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
# Should have called listing_query exactly once (then stopped)
|
||||
assert mock_query.await_count == 1
|
||||
assert ids_found == 2
|
||||
|
||||
|
||||
class TestDumpListings:
|
||||
"""Tests for dump_listings."""
|
||||
|
||||
async def test_circuit_breaker_returns_empty_list(self) -> None:
|
||||
"""Test that CircuitBreakerOpenError returns empty list."""
|
||||
mock_repo = AsyncMock()
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
|
||||
with patch("services.listing_fetcher.create_session") as mock_cs:
|
||||
mock_cs.side_effect = CircuitBreakerOpenError("open")
|
||||
result = await dump_listings(params, mock_repo)
|
||||
assert result == []
|
||||
|
||||
async def test_returns_processed_listings(self) -> None:
|
||||
"""Test that dump_listings returns processed listings from the pipeline."""
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listing_ids = MagicMock(return_value=set())
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = 42
|
||||
|
||||
mock_session_cm = AsyncMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.listing_fetcher.create_session",
|
||||
return_value=mock_session_cm,
|
||||
),
|
||||
patch(
|
||||
"services.listing_fetcher.QuerySplitter"
|
||||
) as mock_splitter_cls,
|
||||
patch(
|
||||
"services.listing_fetcher._fetch_subquery",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
),
|
||||
):
|
||||
mock_splitter = mock_splitter_cls.return_value
|
||||
mock_splitter.split = AsyncMock(return_value=[])
|
||||
mock_splitter.calculate_total_estimated_results = MagicMock(
|
||||
return_value=0
|
||||
)
|
||||
|
||||
result = await dump_listings(params, mock_repo)
|
||||
# With no subqueries, no listings are processed
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestNumWorkers:
|
||||
"""Tests for NUM_WORKERS constant."""
|
||||
|
||||
def test_num_workers_is_positive(self) -> None:
|
||||
"""Test that NUM_WORKERS is a positive integer."""
|
||||
assert NUM_WORKERS > 0
|
||||
|
||||
def test_num_workers_value(self) -> None:
|
||||
"""Test that NUM_WORKERS has the expected value."""
|
||||
assert NUM_WORKERS == 20
|
||||
87
crawler/tests/unit/test_listing_processor.py
Normal file
87
crawler/tests/unit/test_listing_processor.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Unit tests for the listing processor."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from models.listing import FurnishType
|
||||
from listing_processor import (
|
||||
_parse_furnish_type,
|
||||
_parse_available_from,
|
||||
ListingProcessor,
|
||||
FetchListingDetailsStep,
|
||||
MAX_OCR_WORKERS,
|
||||
)
|
||||
|
||||
|
||||
class TestParseFurnishType:
|
||||
"""Tests for _parse_furnish_type helper."""
|
||||
|
||||
def test_none_returns_unknown(self):
|
||||
assert _parse_furnish_type(None) == FurnishType.UNKNOWN
|
||||
|
||||
def test_ask_landlord_variant(self):
|
||||
assert _parse_furnish_type("Ask landlord") == FurnishType.ASK_LANDLORD
|
||||
|
||||
def test_furnished_lowercased(self):
|
||||
assert _parse_furnish_type("Furnished") == FurnishType.FURNISHED
|
||||
|
||||
def test_unfurnished(self):
|
||||
assert _parse_furnish_type("Unfurnished") == FurnishType.UNFURNISHED
|
||||
|
||||
def test_part_furnished(self):
|
||||
assert _parse_furnish_type("Part Furnished") == FurnishType.PART_FURNISHED
|
||||
|
||||
def test_unknown_string_returns_unknown(self):
|
||||
assert _parse_furnish_type("unknown") == FurnishType.UNKNOWN
|
||||
|
||||
def test_garbage_string_returns_unknown(self):
|
||||
assert _parse_furnish_type("xyzzy") == FurnishType.UNKNOWN
|
||||
|
||||
|
||||
class TestParseAvailableFrom:
|
||||
"""Tests for _parse_available_from helper."""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert _parse_available_from(None) is None
|
||||
|
||||
def test_now_returns_datetime(self):
|
||||
result = _parse_available_from("Now")
|
||||
assert isinstance(result, datetime)
|
||||
|
||||
def test_valid_date_string(self):
|
||||
result = _parse_available_from("15/03/2024")
|
||||
assert result is not None
|
||||
assert result.day == 15
|
||||
assert result.month == 3
|
||||
|
||||
def test_invalid_date_returns_none(self):
|
||||
assert _parse_available_from("invalid") is None
|
||||
|
||||
|
||||
class TestListingProcessor:
|
||||
"""Tests for ListingProcessor."""
|
||||
|
||||
async def test_process_listing_marks_seen(self):
|
||||
"""Test that process_listing calls mark_seen."""
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listings = AsyncMock(return_value=[MagicMock()])
|
||||
processor = ListingProcessor(mock_repo)
|
||||
# Mock all steps to not need processing
|
||||
for step in processor.process_steps:
|
||||
step.needs_processing = AsyncMock(return_value=False)
|
||||
await processor.process_listing(123)
|
||||
mock_repo.mark_seen.assert_awaited_once_with(123)
|
||||
|
||||
async def test_process_listing_returns_none_on_step_failure(self):
|
||||
"""Test that a step failure returns None."""
|
||||
mock_repo = AsyncMock()
|
||||
processor = ListingProcessor(mock_repo)
|
||||
for step in processor.process_steps:
|
||||
step.needs_processing = AsyncMock(return_value=True)
|
||||
step.process = AsyncMock(side_effect=Exception("fail"))
|
||||
result = await processor.process_listing(123)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestOcrWorkersConfig:
|
||||
def test_max_ocr_workers_positive(self):
|
||||
assert MAX_OCR_WORKERS >= 1
|
||||
295
crawler/tests/unit/test_listing_tasks.py
Normal file
295
crawler/tests/unit/test_listing_tasks.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""Unit tests for tasks/listing_tasks.py."""
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
import tasks.listing_tasks as module
|
||||
from tasks.listing_tasks import (
|
||||
_update_task_state,
|
||||
_PipelineState,
|
||||
TaskLogHandler,
|
||||
SCRAPE_LOCK_NAME,
|
||||
LOG_BUFFER_MAX_LINES,
|
||||
NUM_WORKERS,
|
||||
PHASE_SPLITTING,
|
||||
PHASE_FETCHING,
|
||||
PHASE_PROCESSING,
|
||||
PHASE_COMPLETED,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateTaskState:
|
||||
"""Tests for _update_task_state."""
|
||||
|
||||
def test_injects_logs_from_active_buffer(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = deque(["log line 1", "log line 2"])
|
||||
_update_task_state(task, "test_state", {"key": "value"})
|
||||
task.update_state.assert_called_once()
|
||||
call_meta = task.update_state.call_args[1]["meta"]
|
||||
assert call_meta["logs"] == ["log line 1", "log line 2"]
|
||||
assert call_meta["key"] == "value"
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
def test_works_when_buffer_is_none(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = None
|
||||
_update_task_state(task, "some_state", {"phase": "testing"})
|
||||
task.update_state.assert_called_once_with(
|
||||
state="some_state", meta={"phase": "testing"}
|
||||
)
|
||||
# No "logs" key should be injected
|
||||
call_meta = task.update_state.call_args[1]["meta"]
|
||||
assert "logs" not in call_meta
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
def test_state_string_is_passed_through(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = None
|
||||
_update_task_state(task, "PROGRESS", {})
|
||||
task.update_state.assert_called_once_with(state="PROGRESS", meta={})
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
def test_empty_buffer_injects_empty_list(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = deque()
|
||||
_update_task_state(task, "state", {"a": 1})
|
||||
call_meta = task.update_state.call_args[1]["meta"]
|
||||
assert call_meta["logs"] == []
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
|
||||
class TestTaskLogHandler:
|
||||
"""Tests for the TaskLogHandler."""
|
||||
|
||||
def test_emit_appends_to_buffer(self):
|
||||
buf = deque(maxlen=10)
|
||||
handler = TaskLogHandler(buf)
|
||||
handler.setFormatter(
|
||||
__import__("logging").Formatter("%(message)s")
|
||||
)
|
||||
record = __import__("logging").LogRecord(
|
||||
name="test", level=20, pathname="", lineno=0,
|
||||
msg="hello", args=(), exc_info=None,
|
||||
)
|
||||
handler.emit(record)
|
||||
assert "hello" in buf
|
||||
|
||||
def test_buffer_respects_maxlen(self):
|
||||
buf = deque(maxlen=2)
|
||||
handler = TaskLogHandler(buf)
|
||||
handler.setFormatter(
|
||||
__import__("logging").Formatter("%(message)s")
|
||||
)
|
||||
for i in range(5):
|
||||
record = __import__("logging").LogRecord(
|
||||
name="test", level=20, pathname="", lineno=0,
|
||||
msg=f"msg{i}", args=(), exc_info=None,
|
||||
)
|
||||
handler.emit(record)
|
||||
assert len(buf) == 2
|
||||
assert list(buf) == ["msg3", "msg4"]
|
||||
|
||||
|
||||
class TestDumpListingsTask:
|
||||
"""Tests for dump_listings_task Celery task."""
|
||||
|
||||
@patch("tasks.listing_tasks.redis_lock")
|
||||
def test_skips_when_lock_not_acquired(self, mock_redis_lock):
|
||||
"""Task should skip when another scrape is running."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=False)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
mock_redis_lock.return_value = mock_cm
|
||||
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
|
||||
# Use run() which handles bind=True properly
|
||||
task_instance = dump_listings_task
|
||||
task_instance.update_state = MagicMock()
|
||||
|
||||
result = dump_listings_task.run('{"listing_type": "RENT"}')
|
||||
|
||||
assert result["status"] == "skipped"
|
||||
assert result["reason"] == "another_job_running"
|
||||
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
|
||||
|
||||
@patch("tasks.listing_tasks.asyncio.run")
|
||||
@patch("tasks.listing_tasks.redis_lock")
|
||||
def test_calls_dump_listings_full_when_lock_acquired(
|
||||
self, mock_redis_lock, mock_asyncio_run
|
||||
):
|
||||
"""Task should call dump_listings_full when lock is acquired."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=True)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
mock_redis_lock.return_value = mock_cm
|
||||
|
||||
mock_asyncio_run.return_value = []
|
||||
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
|
||||
task_instance = dump_listings_task
|
||||
task_instance.update_state = MagicMock()
|
||||
|
||||
params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
|
||||
result = dump_listings_task.run(params_json)
|
||||
|
||||
assert result["phase"] == "completed"
|
||||
assert result["progress"] == 1
|
||||
mock_asyncio_run.assert_called_once()
|
||||
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
|
||||
|
||||
|
||||
class TestSetupPeriodicTasks:
|
||||
"""Tests for setup_periodic_tasks."""
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_enabled_schedules(self, mock_from_env):
|
||||
from config.schedule_config import ScheduleConfig
|
||||
from models.listing import ListingType
|
||||
|
||||
schedule = ScheduleConfig(
|
||||
name="Test Schedule",
|
||||
listing_type=ListingType.RENT,
|
||||
hour="3",
|
||||
minute="30",
|
||||
)
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = [schedule]
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_called_once()
|
||||
call_args = sender.add_periodic_task.call_args
|
||||
assert call_args[1]["name"] == "Test Schedule"
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_handles_config_error_gracefully(self, mock_from_env):
|
||||
mock_from_env.side_effect = ValueError("bad config")
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_not_called()
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_nothing_when_no_schedules(self, mock_from_env):
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = []
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_not_called()
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_multiple_schedules(self, mock_from_env):
|
||||
from config.schedule_config import ScheduleConfig
|
||||
from models.listing import ListingType
|
||||
|
||||
schedules = [
|
||||
ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"),
|
||||
ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"),
|
||||
]
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = schedules
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 2
|
||||
|
||||
|
||||
class TestPipelineState:
|
||||
"""Tests for _PipelineState dataclass."""
|
||||
|
||||
def test_default_initialization(self):
|
||||
state = _PipelineState()
|
||||
assert state.ids_collected == 0
|
||||
assert state.completed_subqueries == 0
|
||||
assert state.total_pages_fetched == 0
|
||||
assert state.fetching_done is False
|
||||
assert state.processed_count == 0
|
||||
assert state.failed_count == 0
|
||||
assert state.details_fetched == 0
|
||||
assert state.images_downloaded == 0
|
||||
assert state.ocr_completed == 0
|
||||
assert state.processed_listings == []
|
||||
|
||||
def test_incrementing_counters(self):
|
||||
state = _PipelineState()
|
||||
state.ids_collected += 5
|
||||
state.completed_subqueries += 3
|
||||
state.total_pages_fetched += 10
|
||||
state.processed_count += 4
|
||||
state.failed_count += 1
|
||||
state.details_fetched += 4
|
||||
state.images_downloaded += 3
|
||||
state.ocr_completed += 2
|
||||
|
||||
assert state.ids_collected == 5
|
||||
assert state.completed_subqueries == 3
|
||||
assert state.total_pages_fetched == 10
|
||||
assert state.processed_count == 4
|
||||
assert state.failed_count == 1
|
||||
assert state.details_fetched == 4
|
||||
assert state.images_downloaded == 3
|
||||
assert state.ocr_completed == 2
|
||||
|
||||
def test_appending_to_processed_listings(self):
|
||||
state = _PipelineState()
|
||||
state.processed_listings.append("listing_a")
|
||||
state.processed_listings.append("listing_b")
|
||||
assert len(state.processed_listings) == 2
|
||||
assert state.processed_listings == ["listing_a", "listing_b"]
|
||||
|
||||
def test_separate_instances_have_independent_lists(self):
|
||||
state_a = _PipelineState()
|
||||
state_b = _PipelineState()
|
||||
state_a.processed_listings.append("only_a")
|
||||
assert state_b.processed_listings == []
|
||||
|
||||
def test_fetching_done_toggle(self):
|
||||
state = _PipelineState()
|
||||
assert state.fetching_done is False
|
||||
state.fetching_done = True
|
||||
assert state.fetching_done is True
|
||||
|
||||
|
||||
class TestPhaseConstants:
|
||||
"""Tests for phase constant values."""
|
||||
|
||||
def test_phase_splitting(self):
|
||||
assert PHASE_SPLITTING == "splitting"
|
||||
|
||||
def test_phase_fetching(self):
|
||||
assert PHASE_FETCHING == "fetching"
|
||||
|
||||
def test_phase_processing(self):
|
||||
assert PHASE_PROCESSING == "processing"
|
||||
|
||||
def test_phase_completed(self):
|
||||
assert PHASE_COMPLETED == "completed"
|
||||
|
||||
def test_num_workers(self):
|
||||
assert NUM_WORKERS == 20
|
||||
|
|
@ -1,16 +1,24 @@
|
|||
"""Unit tests for Listing models."""
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
import json
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from models.listing import (
|
||||
BuyListing,
|
||||
DestinationMode,
|
||||
FurnishType,
|
||||
ListingSite,
|
||||
ListingType,
|
||||
PriceHistoryItem,
|
||||
QueryParameters,
|
||||
RentListing,
|
||||
Listing,
|
||||
Route,
|
||||
RouteLegStep,
|
||||
)
|
||||
from rec.routing import TravelMode
|
||||
|
||||
|
||||
class TestListing:
|
||||
|
|
@ -341,3 +349,190 @@ class TestBuyListing:
|
|||
lease_left=120,
|
||||
)
|
||||
assert listing.lease_left == 120
|
||||
|
||||
|
||||
def _make_listing_with_routing(routing_info_json: str | None) -> RentListing:
|
||||
"""Helper to create a RentListing with given routing_info_json."""
|
||||
return RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=routing_info_json,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
|
||||
"""Helper to create sample routing info for tests."""
|
||||
destination_mode = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
routes = [
|
||||
Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=500,
|
||||
duration_s=120,
|
||||
travel_mode=TravelMode.WALK,
|
||||
),
|
||||
RouteLegStep(
|
||||
distance_meters=4000,
|
||||
duration_s=480,
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
),
|
||||
],
|
||||
distance_meters=4500,
|
||||
duration_s=600,
|
||||
)
|
||||
]
|
||||
return {destination_mode: routes}
|
||||
|
||||
|
||||
class TestQueryParametersValidation:
|
||||
"""Tests for QueryParameters validation."""
|
||||
|
||||
def test_valid_parameters(self) -> None:
|
||||
"""Basic valid QueryParameters creation."""
|
||||
params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=3,
|
||||
)
|
||||
assert params.min_price == 1000
|
||||
assert params.max_price == 3000
|
||||
assert params.min_bedrooms == 1
|
||||
assert params.max_bedrooms == 3
|
||||
|
||||
def test_invalid_price_range_raises(self) -> None:
|
||||
"""min_price > max_price should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_price.*must be <= max_price"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_price=5000,
|
||||
max_price=1000,
|
||||
)
|
||||
|
||||
def test_invalid_bedroom_range_raises(self) -> None:
|
||||
"""min_bedrooms > max_bedrooms should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_bedrooms.*must be <= max_bedrooms"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=5,
|
||||
max_bedrooms=2,
|
||||
)
|
||||
|
||||
def test_negative_bedrooms_raises(self) -> None:
|
||||
"""Negative bedroom counts should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_bedrooms.*must be non-negative"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=-1,
|
||||
max_bedrooms=3,
|
||||
)
|
||||
|
||||
|
||||
class TestDestinationMode:
|
||||
"""Tests for DestinationMode."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test to_dict returns correct dict."""
|
||||
dm = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
result = dm.to_dict()
|
||||
assert result == {
|
||||
"destination_address": "London Bridge",
|
||||
"travel_mode": TravelMode.TRANSIT,
|
||||
}
|
||||
|
||||
def test_hash(self) -> None:
|
||||
"""Test hashing works correctly."""
|
||||
dm1 = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
dm2 = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
dm3 = DestinationMode(
|
||||
destination_address="King's Cross",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
assert hash(dm1) == hash(dm2)
|
||||
assert dm1 == dm2
|
||||
assert hash(dm1) != hash(dm3)
|
||||
# Can be used as dict key
|
||||
d = {dm1: "route1"}
|
||||
assert d[dm2] == "route1"
|
||||
|
||||
|
||||
class TestRoutingInfoSerialization:
|
||||
"""Tests for routing info via RouteSerializer."""
|
||||
|
||||
def test_routing_info_property_returns_parsed_routes(self) -> None:
|
||||
"""Test routing_info property deserializes correctly."""
|
||||
routing_info = _make_sample_routing_info()
|
||||
listing = _make_listing_with_routing(None)
|
||||
serialized = listing.serialize_routing_info(routing_info)
|
||||
listing.routing_info_json = serialized
|
||||
|
||||
result = listing.routing_info
|
||||
assert len(result) == 1
|
||||
dest_mode = list(result.keys())[0]
|
||||
assert dest_mode.destination_address == "London Bridge"
|
||||
assert dest_mode.travel_mode == TravelMode.TRANSIT
|
||||
|
||||
routes = result[dest_mode]
|
||||
assert len(routes) == 1
|
||||
assert routes[0].distance_meters == 4500
|
||||
assert routes[0].duration_s == 600
|
||||
assert len(routes[0].legs) == 2
|
||||
assert routes[0].legs[0].distance_meters == 500
|
||||
assert routes[0].legs[0].travel_mode == TravelMode.WALK
|
||||
|
||||
def test_routing_info_empty_json(self) -> None:
|
||||
"""Test routing_info with no routing data."""
|
||||
listing = _make_listing_with_routing(None)
|
||||
assert listing.routing_info == {}
|
||||
|
||||
def test_serialize_routing_info_roundtrip(self) -> None:
|
||||
"""Test serialize then deserialize via routing_info property."""
|
||||
routing_info = _make_sample_routing_info()
|
||||
listing = _make_listing_with_routing(None)
|
||||
|
||||
# Serialize
|
||||
serialized = listing.serialize_routing_info(routing_info)
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
# Assign and deserialize via property
|
||||
listing.routing_info_json = serialized
|
||||
deserialized = listing.routing_info
|
||||
|
||||
# Compare
|
||||
orig_dm = list(routing_info.keys())[0]
|
||||
result_dm = list(deserialized.keys())[0]
|
||||
assert orig_dm.destination_address == result_dm.destination_address
|
||||
assert orig_dm.travel_mode == result_dm.travel_mode
|
||||
|
||||
orig_route = routing_info[orig_dm][0]
|
||||
result_route = deserialized[result_dm][0]
|
||||
assert orig_route.distance_meters == result_route.distance_meters
|
||||
assert orig_route.duration_s == result_route.duration_s
|
||||
assert len(orig_route.legs) == len(result_route.legs)
|
||||
|
|
|
|||
385
crawler/tests/unit/test_query.py
Normal file
385
crawler/tests/unit/test_query.py
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
"""Unit tests for rec/query.py."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import aiohttp
|
||||
|
||||
from rec.query import (
|
||||
detail_query,
|
||||
listing_query,
|
||||
probe_query,
|
||||
PropertyType,
|
||||
create_session,
|
||||
_build_base_params,
|
||||
_build_listing_params,
|
||||
_build_probe_params,
|
||||
ANDROID_APP_VERSION,
|
||||
ANDROID_APP_VERSION_LISTING,
|
||||
RIGHTMOVE_API_BASE,
|
||||
PROPERTY_LISTING_ENDPOINT,
|
||||
DEFAULT_HEADERS,
|
||||
LISTING_HEADERS,
|
||||
check_circuit_breaker,
|
||||
reset_circuit_breaker,
|
||||
get_circuit_breaker,
|
||||
)
|
||||
from models.listing import ListingType, FurnishType
|
||||
from config.scraper_config import ScraperConfig
|
||||
from rec.exceptions import CircuitBreakerOpenError
|
||||
from rec.throttle_detector import reset_throttle_metrics
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config() -> ScraperConfig:
|
||||
return ScraperConfig(
|
||||
max_concurrent_requests=5,
|
||||
request_delay_ms=10,
|
||||
slow_response_threshold=10.0,
|
||||
enable_circuit_breaker=True,
|
||||
circuit_breaker_failure_threshold=3,
|
||||
circuit_breaker_recovery_timeout=0.5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_no_cb() -> ScraperConfig:
|
||||
return ScraperConfig(enable_circuit_breaker=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals() -> None:
|
||||
reset_throttle_metrics()
|
||||
reset_circuit_breaker()
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(
|
||||
self,
|
||||
status: int = 200,
|
||||
json_data: dict | None = None,
|
||||
text: str = "",
|
||||
):
|
||||
self.status = status
|
||||
self._json_data = json_data or {}
|
||||
self._text = text
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._json_data
|
||||
|
||||
async def text(self) -> str:
|
||||
return self._text
|
||||
|
||||
async def __aenter__(self) -> "MockResponse":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def make_mock_session(response: MockResponse) -> MagicMock:
|
||||
"""Create a mock session whose .get() returns an async context manager."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(return_value=response)
|
||||
return mock_session
|
||||
|
||||
|
||||
def make_mock_session_fn(get_fn: object) -> MagicMock:
|
||||
"""Create a mock session whose .get() calls a function to produce responses."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(side_effect=get_fn)
|
||||
return mock_session
|
||||
|
||||
|
||||
class TestBuildBaseParams:
|
||||
def test_constructs_correct_params(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"TestDistrict": "REGION^123"}):
|
||||
params = _build_base_params(
|
||||
channel=ListingType.RENT,
|
||||
page=2,
|
||||
page_size=25,
|
||||
radius=1.5,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=3,
|
||||
district="TestDistrict",
|
||||
)
|
||||
|
||||
assert params["locationIdentifier"] == "REGION^123"
|
||||
assert params["channel"] == "RENT"
|
||||
assert params["page"] == "2"
|
||||
assert params["numberOfPropertiesPerPage"] == "25"
|
||||
assert params["radius"] == "1.5"
|
||||
assert params["sortBy"] == "distance"
|
||||
assert params["includeUnavailableProperties"] == "false"
|
||||
assert params["minPrice"] == "1000"
|
||||
assert params["maxPrice"] == "3000"
|
||||
assert params["minBedrooms"] == "1"
|
||||
assert params["maxBedrooms"] == "3"
|
||||
assert params["apiApplication"] == "ANDROID"
|
||||
assert params["appVersion"] == ANDROID_APP_VERSION_LISTING
|
||||
|
||||
def test_buy_channel_includes_dont_show_and_max_days(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_listing_params(
|
||||
page=1,
|
||||
channel=ListingType.BUY,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=100000,
|
||||
max_price=500000,
|
||||
district="D",
|
||||
mustNewHome=False,
|
||||
max_days_since_added=7,
|
||||
property_type=[],
|
||||
page_size=25,
|
||||
furnish_types=[],
|
||||
)
|
||||
|
||||
assert params["dontShow"] == "sharedOwnership,retirement"
|
||||
assert params["maxDaysSinceAdded"] == "7"
|
||||
|
||||
def test_rent_channel_includes_furnish_types(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_listing_params(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
district="D",
|
||||
mustNewHome=False,
|
||||
max_days_since_added=30,
|
||||
property_type=[],
|
||||
page_size=25,
|
||||
furnish_types=[FurnishType.FURNISHED, FurnishType.UNFURNISHED],
|
||||
)
|
||||
|
||||
assert params["furnishTypes"] == "furnished,unfurnished"
|
||||
assert "dontShow" not in params
|
||||
assert "maxDaysSinceAdded" not in params
|
||||
|
||||
def test_buy_channel_probe_includes_dont_show(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_probe_params(
|
||||
channel=ListingType.BUY,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=100000,
|
||||
max_price=500000,
|
||||
district="D",
|
||||
max_days_since_added=7,
|
||||
furnish_types=[],
|
||||
)
|
||||
|
||||
assert params["dontShow"] == "sharedOwnership,retirement"
|
||||
assert params["maxDaysSinceAdded"] == "7"
|
||||
assert params["numberOfPropertiesPerPage"] == "1"
|
||||
|
||||
def test_probe_buy_skips_max_days_if_not_valid(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_probe_params(
|
||||
channel=ListingType.BUY,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=100000,
|
||||
max_price=500000,
|
||||
district="D",
|
||||
max_days_since_added=30,
|
||||
furnish_types=[],
|
||||
)
|
||||
|
||||
# 30 is not in [1, 3, 7, 14], so maxDaysSinceAdded is not added for probe
|
||||
assert "maxDaysSinceAdded" not in params
|
||||
|
||||
|
||||
class TestMutableDefaultArgFix:
|
||||
@pytest.mark.asyncio
|
||||
async def test_property_type_default_not_shared(self, config: ScraperConfig) -> None:
|
||||
"""Calling listing_query with no property_type should not share state between calls."""
|
||||
response = MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 0, "properties": []},
|
||||
)
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
# Call twice without explicit property_type
|
||||
await listing_query(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
session=mock_session,
|
||||
config=config,
|
||||
)
|
||||
await listing_query(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
session=mock_session,
|
||||
config=config,
|
||||
)
|
||||
# If mutable default was shared, this test would detect mutations.
|
||||
# The fact that it completes without error proves defaults are independent.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_furnish_types_default_not_shared(self, config: ScraperConfig) -> None:
|
||||
"""Calling probe_query with no furnish_types should not share state between calls."""
|
||||
response = MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 0, "properties": []},
|
||||
)
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
config=config,
|
||||
)
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
class TestPropertyTypeEnum:
|
||||
def test_enum_values(self) -> None:
|
||||
assert PropertyType.BUNGALOW == "bungalow"
|
||||
assert PropertyType.DETACHED == "detached"
|
||||
assert PropertyType.FLAT == "flat"
|
||||
assert PropertyType.LAND == "land"
|
||||
assert PropertyType.PARK_HOME == "park-home"
|
||||
assert PropertyType.SEMI_DETACHED == "semi-detached"
|
||||
assert PropertyType.TERRACED == "terraced"
|
||||
|
||||
def test_enum_is_str(self) -> None:
|
||||
assert isinstance(PropertyType.FLAT, str)
|
||||
assert ",".join([PropertyType.FLAT, PropertyType.DETACHED]) == "flat,detached"
|
||||
|
||||
|
||||
class TestDetailQuery:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_200(self, config: ScraperConfig) -> None:
|
||||
expected_body = {"id": 12345, "address": "123 Test St"}
|
||||
response = MockResponse(status=200, json_data=expected_body)
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
result = await detail_query(12345, session=mock_session, config=config)
|
||||
assert result == expected_body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_non_200(self, config: ScraperConfig) -> None:
|
||||
response = MockResponse(status=404, text="Not Found")
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
with pytest.raises(Exception, match="Failed due to"):
|
||||
await detail_query(99999, session=mock_session, config=config)
|
||||
|
||||
|
||||
class TestCircuitBreakerBlocksRequests:
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_when_open(self, config: ScraperConfig) -> None:
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await detail_query(1, session=mock_session, config=config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_listing_query(self, config: ScraperConfig) -> None:
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await listing_query(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
session=mock_session,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_probe_query(self, config: ScraperConfig) -> None:
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
class TestConstants:
|
||||
def test_android_app_version(self) -> None:
|
||||
assert ANDROID_APP_VERSION == "3.70.0"
|
||||
|
||||
def test_android_app_version_listing(self) -> None:
|
||||
assert ANDROID_APP_VERSION_LISTING == "4.28.0"
|
||||
|
||||
def test_rightmove_api_base(self) -> None:
|
||||
assert RIGHTMOVE_API_BASE == "https://api.rightmove.co.uk/api"
|
||||
|
||||
def test_property_listing_endpoint(self) -> None:
|
||||
assert PROPERTY_LISTING_ENDPOINT == "https://api.rightmove.co.uk/api/property-listing"
|
||||
|
||||
def test_listing_headers_extends_default(self) -> None:
|
||||
for key, value in DEFAULT_HEADERS.items():
|
||||
assert LISTING_HEADERS[key] == value
|
||||
assert LISTING_HEADERS["Accept-Encoding"] == "gzip, deflate, br"
|
||||
|
|
@ -161,7 +161,7 @@ class TestQuerySplitter:
|
|||
mock_session = AsyncMock()
|
||||
|
||||
# Mock the probe_query function
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.return_value = {"totalAvailableResults": 800}
|
||||
|
||||
count = await splitter.probe_result_count(sq, mock_session, parameters)
|
||||
|
|
@ -184,7 +184,7 @@ class TestQuerySplitter:
|
|||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.side_effect = Exception("API error")
|
||||
|
||||
count = await splitter.probe_result_count(sq, mock_session, parameters)
|
||||
|
|
@ -208,7 +208,7 @@ class TestQuerySplitter:
|
|||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# First half has 600 results, second half has 500
|
||||
mock_probe.side_effect = [
|
||||
{"totalAvailableResults": 600},
|
||||
|
|
@ -240,7 +240,7 @@ class TestQuerySplitter:
|
|||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# First split: 1000-3000 has 1300 (over threshold), 3000-5000 has 800
|
||||
# Second split of 1000-3000: 1000-2000 has 700, 2000-3000 has 600
|
||||
mock_probe.side_effect = [
|
||||
|
|
@ -326,7 +326,7 @@ class TestQuerySplitter:
|
|||
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
|
||||
with patch("services.query_splitter.get_districts", return_value=mock_districts):
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# Mock probe results for each initial subquery
|
||||
# 2 districts × 2 bedroom counts = 4 initial subqueries
|
||||
mock_probe.side_effect = [
|
||||
|
|
@ -358,11 +358,11 @@ class TestQuerySplitter:
|
|||
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
progress_calls = []
|
||||
|
||||
def on_progress(phase: str, message: str) -> None:
|
||||
def on_progress(phase: str, message: str, **kwargs: object) -> None:
|
||||
progress_calls.append((phase, message))
|
||||
|
||||
with patch("services.query_splitter.get_districts", return_value=mock_districts):
|
||||
with patch("services.query_splitter.probe_query") as mock_probe:
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.return_value = {"totalAvailableResults": 500}
|
||||
|
||||
await splitter.split(parameters, mock_session, on_progress)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Unit tests for ListingRepository."""
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
|
|
@ -225,3 +226,156 @@ class TestListingRepositoryFilters:
|
|||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
# Should match listings with 1-2 bedrooms in price range
|
||||
assert len(listings) == 2
|
||||
|
||||
|
||||
class TestListingRepositoryStreaming:
|
||||
"""Tests for streaming and optimized query methods."""
|
||||
|
||||
async def test_count_listings_empty_db(
|
||||
self, listing_repository: ListingRepository
|
||||
) -> None:
|
||||
"""Test count returns 0 for empty database."""
|
||||
count = listing_repository.count_listings()
|
||||
assert count == 0
|
||||
|
||||
async def test_count_listings_with_data(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test count returns correct number."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
count = listing_repository.count_listings()
|
||||
assert count == 3
|
||||
|
||||
async def test_count_listings_with_filters(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test count respects query parameters."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=3,
|
||||
)
|
||||
count = listing_repository.count_listings(query_parameters=query_params)
|
||||
assert count == 2
|
||||
|
||||
async def test_stream_listings_optimized_returns_dicts(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test optimized streaming returns dict rows."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
results = list(listing_repository.stream_listings_optimized())
|
||||
assert len(results) == 3
|
||||
# Each result should be a dict
|
||||
for row in results:
|
||||
assert isinstance(row, dict)
|
||||
assert "id" in row
|
||||
assert "price" in row
|
||||
assert "number_of_bedrooms" in row
|
||||
|
||||
async def test_stream_listings_optimized_respects_limit(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test streaming limit parameter."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
results = list(listing_repository.stream_listings_optimized(limit=2))
|
||||
assert len(results) == 2
|
||||
|
||||
async def test_get_listing_ids(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test get_listing_ids returns set of IDs."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
ids = listing_repository.get_listing_ids()
|
||||
assert isinstance(ids, set)
|
||||
assert ids == {1, 2, 3}
|
||||
|
||||
async def test_get_listing_ids_empty_db(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
) -> None:
|
||||
"""Test get_listing_ids returns empty set for empty database."""
|
||||
ids = listing_repository.get_listing_ids()
|
||||
assert isinstance(ids, set)
|
||||
assert len(ids) == 0
|
||||
|
||||
|
||||
class TestFurnishTypeParsing:
|
||||
"""Tests for _parse_furnish_type helper."""
|
||||
|
||||
def test_parse_furnish_type_none_detailobject(self) -> None:
|
||||
"""Test that None detailobject returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = None
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_missing_property_key(self) -> None:
|
||||
"""Test that missing 'property' key returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_missing_let_furnish_type(self) -> None:
|
||||
"""Test that missing 'letFurnishType' key returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_null_value(self) -> None:
|
||||
"""Test that null letFurnishType value returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": None}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_furnished(self) -> None:
|
||||
"""Test that 'Furnished' is parsed correctly."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Furnished"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.FURNISHED
|
||||
|
||||
def test_parse_furnish_type_unfurnished(self) -> None:
|
||||
"""Test that 'Unfurnished' is parsed correctly."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Unfurnished"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNFURNISHED
|
||||
|
||||
def test_parse_furnish_type_part_furnished(self) -> None:
|
||||
"""Test that 'Part Furnished' is parsed correctly."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Part Furnished"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.PART_FURNISHED
|
||||
|
||||
def test_parse_furnish_type_landlord_variant(self) -> None:
|
||||
"""Test that landlord variants map to ASK_LANDLORD."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Ask Landlord Please"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.ASK_LANDLORD
|
||||
|
||||
def test_parse_furnish_type_landlord_case_insensitive(self) -> None:
|
||||
"""Test that landlord check is case-insensitive."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "LANDLORD decides"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.ASK_LANDLORD
|
||||
|
|
|
|||
10
crawler/tests/unit/test_route_calculator.py
Normal file
10
crawler/tests/unit/test_route_calculator.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Unit tests for services/route_calculator.py."""
|
||||
from services.route_calculator import _parse_duration
|
||||
|
||||
|
||||
class TestParseDuration:
|
||||
def test_parse_normal_duration(self) -> None:
|
||||
assert _parse_duration("123s") == 123
|
||||
|
||||
def test_parse_zero_duration(self) -> None:
|
||||
assert _parse_duration("0s") == 0
|
||||
72
crawler/tests/unit/test_route_serializer.py
Normal file
72
crawler/tests/unit/test_route_serializer.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Unit tests for rec/route_serializer.py."""
|
||||
from models.listing import DestinationMode, Route, RouteLegStep
|
||||
from rec.route_serializer import RouteSerializer
|
||||
from rec.routing import TravelMode
|
||||
|
||||
|
||||
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
|
||||
destination_mode = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
routes = [
|
||||
Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=500,
|
||||
duration_s=120,
|
||||
travel_mode=TravelMode.WALK,
|
||||
),
|
||||
RouteLegStep(
|
||||
distance_meters=4000,
|
||||
duration_s=480,
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
),
|
||||
],
|
||||
distance_meters=4500,
|
||||
duration_s=600,
|
||||
)
|
||||
]
|
||||
return {destination_mode: routes}
|
||||
|
||||
|
||||
class TestRouteSerializer:
|
||||
def test_serialize_then_deserialize_roundtrip(self) -> None:
|
||||
routing_info = _make_sample_routing_info()
|
||||
serialized = RouteSerializer.serialize(routing_info)
|
||||
deserialized = RouteSerializer.deserialize(serialized)
|
||||
|
||||
assert len(deserialized) == 1
|
||||
dest_mode = list(deserialized.keys())[0]
|
||||
assert dest_mode.destination_address == "London Bridge"
|
||||
assert dest_mode.travel_mode == TravelMode.TRANSIT
|
||||
|
||||
routes = deserialized[dest_mode]
|
||||
assert len(routes) == 1
|
||||
assert routes[0].distance_meters == 4500
|
||||
assert routes[0].duration_s == 600
|
||||
assert len(routes[0].legs) == 2
|
||||
assert routes[0].legs[0].distance_meters == 500
|
||||
assert routes[0].legs[0].travel_mode == TravelMode.WALK
|
||||
assert routes[0].legs[1].travel_mode == TravelMode.TRANSIT
|
||||
|
||||
def test_deserialize_sample_json(self) -> None:
|
||||
import json
|
||||
import dataclasses
|
||||
|
||||
routing_info = _make_sample_routing_info()
|
||||
# Build the JSON manually to test deserialize independently
|
||||
json_str = json.dumps(
|
||||
{
|
||||
json.dumps(dataclasses.asdict(dm)): [
|
||||
json.dumps(dataclasses.asdict(r)) for r in routes
|
||||
]
|
||||
for dm, routes in routing_info.items()
|
||||
}
|
||||
)
|
||||
|
||||
result = RouteSerializer.deserialize(json_str)
|
||||
assert len(result) == 1
|
||||
dest_mode = list(result.keys())[0]
|
||||
assert dest_mode.destination_address == "London Bridge"
|
||||
assert result[dest_mode][0].duration_s == 600
|
||||
67
crawler/tests/unit/test_routing.py
Normal file
67
crawler/tests/unit/test_routing.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""Unit tests for rec/routing.py."""
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from rec.routing import TravelMode, transit_route, ROUTES_API_URL, ROUTES_FIELD_MASK
|
||||
from rec.exceptions import RoutingApiError
|
||||
|
||||
|
||||
class TestTravelMode:
|
||||
def test_enum_values(self) -> None:
|
||||
assert TravelMode.TRANSIT == "TRANSIT"
|
||||
assert TravelMode.BICYCLE == "BICYCLE"
|
||||
assert TravelMode.WALK == "WALK"
|
||||
assert TravelMode.DRIVE == "DRIVE"
|
||||
|
||||
def test_enum_has_four_members(self) -> None:
|
||||
assert len(TravelMode) == 4
|
||||
|
||||
|
||||
class TestTransitRoute:
|
||||
@patch("rec.routing.requests.post")
|
||||
@patch("rec.routing.nextMonday")
|
||||
def test_success_response(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
|
||||
mock_monday.return_value = MagicMock(
|
||||
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
|
||||
)
|
||||
expected = {"routes": [{"duration": "600s", "distanceMeters": 5000}]}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = expected
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with patch.dict(os.environ, {"ROUTING_API_KEY": "test-key"}):
|
||||
result = transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
|
||||
|
||||
assert result == expected
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
assert call_kwargs.kwargs["headers"]["X-Goog-Api-Key"] == "test-key"
|
||||
|
||||
@patch("rec.routing.requests.post")
|
||||
@patch("rec.routing.nextMonday")
|
||||
def test_raises_routing_api_error_on_non_200(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
|
||||
mock_monday.return_value = MagicMock(
|
||||
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
|
||||
)
|
||||
error_body = {"error": {"message": "Invalid API key", "status": "PERMISSION_DENIED"}}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 403
|
||||
mock_response.json.return_value = error_body
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with patch.dict(os.environ, {"ROUTING_API_KEY": "bad-key"}):
|
||||
with pytest.raises(RoutingApiError) as exc_info:
|
||||
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.response_body == error_body
|
||||
|
||||
def test_raises_key_error_when_api_key_not_set(self) -> None:
|
||||
env = os.environ.copy()
|
||||
env.pop("ROUTING_API_KEY", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
with pytest.raises(KeyError):
|
||||
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
|
||||
306
crawler/tests/unit/test_task_service.py
Normal file
306
crawler/tests/unit/test_task_service.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
"""Unit tests for services/task_service.py."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from services.task_service import (
|
||||
TaskStatus,
|
||||
_extract_progress_info,
|
||||
_extract_result,
|
||||
_make_system_user,
|
||||
get_task_status,
|
||||
)
|
||||
|
||||
|
||||
class TestMakeSystemUser:
|
||||
"""Tests for _make_system_user helper."""
|
||||
|
||||
def test_creates_user_with_email(self) -> None:
|
||||
user = _make_system_user("test@example.com")
|
||||
assert user.email == "test@example.com"
|
||||
assert user.sub == ""
|
||||
assert user.name == ""
|
||||
|
||||
def test_different_emails_create_different_users(self) -> None:
|
||||
u1 = _make_system_user("a@b.com")
|
||||
u2 = _make_system_user("c@d.com")
|
||||
assert u1.email != u2.email
|
||||
|
||||
|
||||
class TestExtractResult:
|
||||
"""Tests for _extract_result helper."""
|
||||
|
||||
def test_failed_task_returns_error(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = True
|
||||
mock_result.result = Exception("something broke")
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "something broke" in error
|
||||
|
||||
def test_failed_task_with_no_result(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = True
|
||||
mock_result.result = None
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result is None
|
||||
assert error is None
|
||||
|
||||
def test_successful_json_serializable_result(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = {"count": 42, "status": "done"}
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result == {"count": 42, "status": "done"}
|
||||
assert error is None
|
||||
|
||||
def test_non_serializable_result_falls_back_to_str(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = object() # not JSON-serializable
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert isinstance(result, str)
|
||||
assert error is None
|
||||
|
||||
def test_none_result_stays_none(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = None
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result is None
|
||||
assert error is None
|
||||
|
||||
|
||||
class TestExtractProgressInfo:
|
||||
"""Tests for _extract_progress_info helper."""
|
||||
|
||||
def test_extracts_progress_fields(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
|
||||
mock_result.status = "STARTED"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["progress"] == 0.5
|
||||
assert info["processed"] == 50
|
||||
assert info["total"] == 100
|
||||
assert info["message"] is None
|
||||
|
||||
def test_extracts_message_from_info(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {"message": "Processing page 3"}
|
||||
mock_result.status = "STARTED"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] == "Processing page 3"
|
||||
|
||||
def test_falls_back_to_reason_for_skipped(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {"reason": "Already running"}
|
||||
mock_result.status = "SKIPPED"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] == "Already running"
|
||||
|
||||
def test_custom_state_used_as_message(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {}
|
||||
mock_result.status = "Fetching listings"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] == "Fetching listings"
|
||||
|
||||
def test_standard_state_not_used_as_message(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {}
|
||||
mock_result.status = "PENDING"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] is None
|
||||
|
||||
def test_none_info_returns_all_none(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = None
|
||||
mock_result.status = "PENDING"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info == {"progress": None, "processed": None, "total": None, "message": None}
|
||||
|
||||
|
||||
class TestGetTaskStatus:
|
||||
"""Tests for get_task_status."""
|
||||
|
||||
def test_pending_task(self) -> None:
|
||||
"""Test status for a pending task."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "PENDING"
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = None
|
||||
mock_result.info = None
|
||||
mock_result.traceback = None
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
# Patch the lazy import
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.task_id == "test-id"
|
||||
assert status.status == "PENDING"
|
||||
assert status.error is None
|
||||
|
||||
def test_failed_task(self) -> None:
|
||||
"""Test status for a failed task."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "FAILURE"
|
||||
mock_result.failed.return_value = True
|
||||
mock_result.result = Exception("something broke")
|
||||
mock_result.info = None
|
||||
mock_result.traceback = "Traceback..."
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.status == "FAILURE"
|
||||
assert status.error is not None
|
||||
assert status.traceback == "Traceback..."
|
||||
|
||||
def test_custom_state_with_progress(self) -> None:
|
||||
"""Test that custom states with progress info are extracted correctly."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "Fetching listings"
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = None
|
||||
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
|
||||
mock_result.traceback = None
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.progress == 0.5
|
||||
assert status.processed == 50
|
||||
assert status.total == 100
|
||||
|
||||
def test_successful_task(self) -> None:
|
||||
"""Test status for a successful task."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "SUCCESS"
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = {"listings_count": 42}
|
||||
mock_result.info = None
|
||||
mock_result.traceback = None
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.status == "SUCCESS"
|
||||
assert status.result == {"listings_count": 42}
|
||||
assert status.error is None
|
||||
|
||||
|
||||
class TestGetUserTasks:
|
||||
"""Tests for get_user_tasks."""
|
||||
|
||||
def test_returns_task_list(self) -> None:
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = ["task-1", "task-2"]
|
||||
|
||||
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
|
||||
MockRedisRepo.instance.return_value = mock_redis
|
||||
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
|
||||
from services.task_service import get_user_tasks
|
||||
result = get_user_tasks("test@example.com")
|
||||
assert result == ["task-1", "task-2"]
|
||||
|
||||
def test_returns_empty_list_for_unknown_user(self) -> None:
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = []
|
||||
|
||||
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
|
||||
MockRedisRepo.instance.return_value = mock_redis
|
||||
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
|
||||
from services.task_service import get_user_tasks
|
||||
result = get_user_tasks("nobody@example.com")
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestCancelTask:
|
||||
"""Tests for cancel_task."""
|
||||
|
||||
def test_cancel_revokes_and_removes(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.remove_task_for_user.return_value = True
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import cancel_task
|
||||
result = cancel_task("task-123", user_email="test@example.com")
|
||||
assert result is True
|
||||
mock_celery.control.revoke.assert_called_once_with("task-123", terminate=True)
|
||||
|
||||
def test_cancel_without_user_email(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"celery_app": MagicMock(app=mock_celery)}):
|
||||
from services.task_service import cancel_task
|
||||
result = cancel_task("task-456")
|
||||
assert result is True
|
||||
mock_celery.control.revoke.assert_called_once_with("task-456", terminate=True)
|
||||
|
||||
|
||||
class TestClearAllTasks:
|
||||
"""Tests for clear_all_tasks."""
|
||||
|
||||
def test_clear_with_revoke(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = ["t1", "t2"]
|
||||
mock_redis.clear_tasks_for_user.return_value = 2
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import clear_all_tasks
|
||||
count = clear_all_tasks("test@example.com", revoke=True)
|
||||
assert count == 2
|
||||
assert mock_celery.control.revoke.call_count == 2
|
||||
|
||||
def test_clear_without_revoke(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.clear_tasks_for_user.return_value = 3
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import clear_all_tasks
|
||||
count = clear_all_tasks("test@example.com", revoke=False)
|
||||
assert count == 3
|
||||
mock_celery.control.revoke.assert_not_called()
|
||||
|
||||
def test_revoke_failure_logs_warning_and_continues(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_celery.control.revoke.side_effect = Exception("connection lost")
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = ["t1"]
|
||||
mock_redis.clear_tasks_for_user.return_value = 1
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import clear_all_tasks
|
||||
# Should not raise despite revoke failure
|
||||
count = clear_all_tasks("test@example.com", revoke=True)
|
||||
assert count == 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue