Flatten repo structure: move crawler/ to root, remove vqa/ and immoweb/
The crawler subdirectory was the only active project. Moving it to the repo root simplifies paths and removes the unnecessary nesting. The vqa/ and immoweb/ directories were legacy/unused and have been removed. Updated .drone.yml, .gitignore, .claude/ docs, and skills to reflect the new flat structure.
This commit is contained in:
parent
e2247be700
commit
eafbc1ac52
221 changed files with 70 additions and 146140 deletions
1
tests/unit/__init__.py
Normal file
1
tests/unit/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Unit tests package
|
||||
151
tests/unit/test_auth.py
Normal file
151
tests/unit/test_auth.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Unit tests for api/auth.py."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
from api.auth import (
|
||||
User,
|
||||
_verify_passkey_token,
|
||||
_verify_authentik_token,
|
||||
get_current_user,
|
||||
)
|
||||
from api.config import JWT_SECRET, JWT_ALGORITHM, JWT_ISSUER
|
||||
|
||||
|
||||
def _make_passkey_token(
|
||||
sub: str = "user-123",
|
||||
email: str = "test@example.com",
|
||||
name: str = "Test User",
|
||||
issuer: str = JWT_ISSUER,
|
||||
secret: str = JWT_SECRET,
|
||||
algorithm: str = JWT_ALGORITHM,
|
||||
expires_delta: timedelta | None = timedelta(hours=1),
|
||||
) -> str:
|
||||
"""Helper to mint a passkey-style HS256 JWT."""
|
||||
payload: dict = {"sub": sub, "email": email, "name": name, "iss": issuer}
|
||||
if expires_delta is not None:
|
||||
payload["exp"] = datetime.now(timezone.utc) + expires_delta
|
||||
return pyjwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
class TestVerifyPasskeyToken:
|
||||
"""Tests for _verify_passkey_token()."""
|
||||
|
||||
def test_valid_token_returns_user(self) -> None:
|
||||
token = _make_passkey_token()
|
||||
user = _verify_passkey_token(token)
|
||||
assert isinstance(user, User)
|
||||
assert user.sub == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
assert user.name == "Test User"
|
||||
|
||||
def test_valid_token_without_name_uses_email(self) -> None:
|
||||
payload = {
|
||||
"sub": "user-456",
|
||||
"email": "noname@example.com",
|
||||
"iss": JWT_ISSUER,
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
token = pyjwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
user = _verify_passkey_token(token)
|
||||
assert user.name == "noname@example.com"
|
||||
|
||||
def test_rejects_expired_token(self) -> None:
|
||||
token = _make_passkey_token(expires_delta=timedelta(hours=-1))
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
_verify_passkey_token(token)
|
||||
|
||||
def test_rejects_wrong_secret(self) -> None:
|
||||
token = _make_passkey_token(secret="wrong-secret")
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
_verify_passkey_token(token)
|
||||
|
||||
def test_rejects_wrong_issuer(self) -> None:
|
||||
token = _make_passkey_token(issuer="some-other-issuer")
|
||||
with pytest.raises(pyjwt.InvalidIssuerError):
|
||||
_verify_passkey_token(token)
|
||||
|
||||
|
||||
class TestVerifyAuthentikToken:
|
||||
"""Tests for _verify_authentik_token() — specifically that expiration is verified."""
|
||||
|
||||
async def test_verifies_expiration_after_fix(self) -> None:
|
||||
"""After removing verify_exp: False, expired Authentik tokens should be rejected."""
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
public_key = private_key.public_key()
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
issuer = "https://authentik.viktorbarzin.me/application/o/wrongmove/"
|
||||
payload = {
|
||||
"sub": "authentik-user",
|
||||
"email": "auth@example.com",
|
||||
"name": "Auth User",
|
||||
"iss": issuer,
|
||||
"aud": "5AJKRgcdgVm1OyApBzFkadDFfStW9a555zwv2MOe",
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1), # expired
|
||||
}
|
||||
token = pyjwt.encode(payload, private_key, algorithm="RS256")
|
||||
|
||||
# Build a real PyJWK-compatible signing key mock so jwt.decode
|
||||
# takes the PyJWK code path (uses key.key directly, skips prepare_key)
|
||||
mock_signing_key = MagicMock(spec=pyjwt.PyJWK)
|
||||
mock_signing_key.key = public_key
|
||||
mock_signing_key.algorithm_name = "RS256"
|
||||
mock_signing_key.Algorithm = pyjwt.get_algorithm_by_name("RS256")
|
||||
|
||||
mock_jwks_client = MagicMock()
|
||||
mock_jwks_client.get_signing_key_from_jwt.return_value = mock_signing_key
|
||||
|
||||
mock_metadata = {
|
||||
"issuer": issuer,
|
||||
"jwks_uri": f"{issuer}jwks/",
|
||||
}
|
||||
|
||||
with patch("api.auth.get_oidc_metadata", new_callable=AsyncMock, return_value=mock_metadata), \
|
||||
patch("api.auth.get_cached_jwks_client", new_callable=AsyncMock, return_value=mock_jwks_client):
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
await _verify_authentik_token(token)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for get_current_user()."""
|
||||
|
||||
async def test_routes_to_passkey_verifier_for_matching_issuer(self) -> None:
|
||||
token = _make_passkey_token()
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
user = await get_current_user(credentials)
|
||||
assert user.sub == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
async def test_routes_to_authentik_for_other_issuer(self) -> None:
|
||||
"""When issuer != JWT_ISSUER, should route to Authentik verifier."""
|
||||
token = _make_passkey_token(issuer="https://authentik.viktorbarzin.me/application/o/wrongmove/")
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
mock_user = User(sub="authentik-user", email="auth@example.com", name="Auth User")
|
||||
with patch("api.auth._verify_authentik_token", new_callable=AsyncMock, return_value=mock_user):
|
||||
user = await get_current_user(credentials)
|
||||
assert user.email == "auth@example.com"
|
||||
|
||||
async def test_raises_http_exception_for_invalid_token(self) -> None:
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="not.a.valid.token")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
async def test_raises_http_exception_for_garbage_token(self) -> None:
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="totalgarbage")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
388
tests/unit/test_cli.py
Normal file
388
tests/unit/test_cli.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
"""Characterization and unit tests for the CLI (main.py)."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from models.listing import FurnishType, ListingType, QueryParameters
|
||||
from main import build_query_parameters, cli, listing_filter_options
|
||||
|
||||
|
||||
class TestBuildQueryParameters:
|
||||
"""Tests for build_query_parameters()."""
|
||||
|
||||
def test_typical_rent_inputs(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London", "Camden"],
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=4,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
furnish_types=["FURNISHED"],
|
||||
available_from=datetime(2025, 6, 1),
|
||||
last_seen_days=7,
|
||||
min_sqm=50,
|
||||
)
|
||||
assert qp.listing_type == ListingType.RENT
|
||||
assert qp.district_names == {"London", "Camden"}
|
||||
assert qp.min_bedrooms == 2
|
||||
assert qp.max_bedrooms == 4
|
||||
assert qp.min_price == 1000
|
||||
assert qp.max_price == 3000
|
||||
assert qp.furnish_types == [FurnishType.FURNISHED]
|
||||
assert qp.let_date_available_from == datetime(2025, 6, 1)
|
||||
assert qp.last_seen_days == 7
|
||||
assert qp.min_sqm == 50
|
||||
|
||||
def test_typical_buy_inputs(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="BUY",
|
||||
district=["Barnet"],
|
||||
min_bedrooms=3,
|
||||
max_bedrooms=5,
|
||||
min_price=200000,
|
||||
max_price=500000,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.listing_type == ListingType.BUY
|
||||
assert qp.district_names == {"Barnet"}
|
||||
assert qp.furnish_types is None
|
||||
assert qp.let_date_available_from is None
|
||||
assert qp.min_sqm is None
|
||||
|
||||
def test_empty_districts_yields_empty_set(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=[],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.district_names == set()
|
||||
|
||||
def test_none_districts_yields_empty_set(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=None,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.district_names == set()
|
||||
|
||||
def test_furnish_types_conversion(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London"],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=["FURNISHED", "UNFURNISHED"],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.furnish_types == [FurnishType.FURNISHED, FurnishType.UNFURNISHED]
|
||||
|
||||
def test_empty_furnish_types_yields_none(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London"],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.furnish_types is None
|
||||
|
||||
def test_default_optional_parameters(self) -> None:
|
||||
qp = build_query_parameters(
|
||||
type="RENT",
|
||||
district=["London"],
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=10,
|
||||
min_price=0,
|
||||
max_price=999999,
|
||||
furnish_types=[],
|
||||
available_from=None,
|
||||
last_seen_days=14,
|
||||
)
|
||||
assert qp.radius == 0
|
||||
assert qp.page_size == 500
|
||||
assert qp.max_days_since_added == 14
|
||||
|
||||
|
||||
class TestListingFilterOptionsDecorator:
|
||||
"""Tests for the listing_filter_options decorator."""
|
||||
|
||||
def test_applies_all_expected_options(self) -> None:
|
||||
@click.command()
|
||||
@listing_filter_options
|
||||
def dummy_cmd(**kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
expected_option_names = {
|
||||
"type",
|
||||
"min_bedrooms",
|
||||
"max_bedrooms",
|
||||
"min_price",
|
||||
"max_price",
|
||||
"district",
|
||||
"furnish_types",
|
||||
"available_from",
|
||||
"last_seen_days",
|
||||
"min_sqm",
|
||||
}
|
||||
param_names = {p.name for p in dummy_cmd.params}
|
||||
assert expected_option_names.issubset(param_names), (
|
||||
f"Missing options: {expected_option_names - param_names}"
|
||||
)
|
||||
|
||||
def test_type_option_is_required(self) -> None:
|
||||
@click.command()
|
||||
@listing_filter_options
|
||||
def dummy_cmd(**kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
type_param = next(p for p in dummy_cmd.params if p.name == "type")
|
||||
assert type_param.required is True
|
||||
|
||||
def test_produces_query_parameters_kwarg(self) -> None:
|
||||
"""After refactoring, the decorator should produce a query_parameters kwarg."""
|
||||
captured: dict = {}
|
||||
|
||||
@click.command()
|
||||
@listing_filter_options
|
||||
def dummy_cmd(query_parameters: QueryParameters) -> None:
|
||||
captured["qp"] = query_parameters
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(dummy_cmd, ["--type", "RENT"])
|
||||
assert result.exit_code == 0, f"Command failed: {result.output}"
|
||||
assert isinstance(captured["qp"], QueryParameters)
|
||||
assert captured["qp"].listing_type == ListingType.RENT
|
||||
|
||||
|
||||
class TestDumpListingsCommand:
|
||||
"""Tests for the dump-listings CLI command."""
|
||||
|
||||
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_calls_refresh_listings_with_correct_params(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_refresh: AsyncMock,
|
||||
) -> None:
|
||||
from services.listing_service import RefreshResult
|
||||
|
||||
mock_refresh.return_value = RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=5,
|
||||
message="Fetched 5 new listings",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"dump-listings",
|
||||
"--type", "RENT",
|
||||
"--min-bedrooms", "2",
|
||||
"--max-bedrooms", "4",
|
||||
"--min-price", "1000",
|
||||
"--max-price", "3000",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_refresh.assert_called_once()
|
||||
call_args = mock_refresh.call_args
|
||||
qp: QueryParameters = call_args.args[1]
|
||||
assert qp.listing_type == ListingType.RENT
|
||||
assert qp.min_bedrooms == 2
|
||||
assert qp.max_bedrooms == 4
|
||||
assert qp.min_price == 1000
|
||||
assert qp.max_price == 3000
|
||||
assert call_args.kwargs.get("full") is not True
|
||||
|
||||
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_include_processing_flag_passes_full_true(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_refresh: AsyncMock,
|
||||
) -> None:
|
||||
from services.listing_service import RefreshResult
|
||||
|
||||
mock_refresh.return_value = RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=0,
|
||||
message="Fetched 0 new listings",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"dump-listings",
|
||||
"--type", "RENT",
|
||||
"--include-processing",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_refresh.assert_called_once()
|
||||
call_kwargs = mock_refresh.call_args.kwargs
|
||||
assert call_kwargs.get("full") is True
|
||||
|
||||
@patch("main.listing_service.refresh_listings", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_include_processing_short_flag(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_refresh: AsyncMock,
|
||||
) -> None:
|
||||
from services.listing_service import RefreshResult
|
||||
|
||||
mock_refresh.return_value = RefreshResult(
|
||||
task_id=None,
|
||||
new_listings_count=0,
|
||||
message="Fetched 0 new listings",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"dump-listings",
|
||||
"--type", "RENT",
|
||||
"-p",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_refresh.assert_called_once()
|
||||
call_kwargs = mock_refresh.call_args.kwargs
|
||||
assert call_kwargs.get("full") is True
|
||||
|
||||
|
||||
class TestExportCsvCommand:
|
||||
"""Tests for the export-csv CLI command."""
|
||||
|
||||
@patch("main.export_service.export_to_csv", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_calls_export_to_csv(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_export: AsyncMock,
|
||||
) -> None:
|
||||
from services.export_service import ExportResult
|
||||
|
||||
mock_export.return_value = ExportResult(
|
||||
success=True,
|
||||
output_path="/tmp/test.csv",
|
||||
data=None,
|
||||
record_count=10,
|
||||
message="Exported 10 listings to /tmp/test.csv",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"export-csv",
|
||||
"--output-file", "/tmp/test.csv",
|
||||
"--type", "RENT",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_export.assert_called_once()
|
||||
call_args = mock_export.call_args
|
||||
qp = call_args[0][2]
|
||||
assert qp.listing_type == ListingType.RENT
|
||||
|
||||
|
||||
class TestExportImmowebCommand:
|
||||
"""Tests for the export-immoweb CLI command."""
|
||||
|
||||
@patch("main.export_service.export_to_geojson", new_callable=AsyncMock)
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_calls_export_to_geojson(
|
||||
self,
|
||||
mock_engine: MagicMock,
|
||||
mock_export: AsyncMock,
|
||||
) -> None:
|
||||
from services.export_service import ExportResult
|
||||
|
||||
mock_export.return_value = ExportResult(
|
||||
success=True,
|
||||
output_path="/tmp/test.geojson",
|
||||
data=None,
|
||||
record_count=5,
|
||||
message="Exported 5 listings to /tmp/test.geojson",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"export-immoweb",
|
||||
"--output-file", "/tmp/test.geojson",
|
||||
"--type", "RENT",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"CLI failed: {result.output}"
|
||||
mock_export.assert_called_once()
|
||||
|
||||
|
||||
class TestListDistrictsCommand:
|
||||
"""Tests for the list-districts CLI command."""
|
||||
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_outputs_district_names(self, mock_engine: MagicMock) -> None:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["list-districts"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "London" in result.output
|
||||
assert "Camden" in result.output
|
||||
assert "Available districts" in result.output
|
||||
|
||||
|
||||
class TestRoutingCommand:
|
||||
"""Tests for the routing CLI command."""
|
||||
|
||||
@patch("main.engine", new_callable=MagicMock)
|
||||
def test_requires_api_key_env_var(self, mock_engine: MagicMock) -> None:
|
||||
runner = CliRunner(env={"ROUTING_API_KEY": None})
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"routing",
|
||||
"--destination-address", "London Bridge",
|
||||
"--travel-mode", "transit",
|
||||
"--limit", "1",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "ROUTING_API_KEY" in result.output
|
||||
62
tests/unit/test_districts.py
Normal file
62
tests/unit/test_districts.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Unit tests for rec/districts.py and services/district_service.py."""
|
||||
from rec.districts import get_districts, get_district_by_name
|
||||
from services.district_service import get_all_districts, get_district_names, validate_districts
|
||||
|
||||
|
||||
class TestGetDistricts:
|
||||
def test_returns_non_empty_dict(self) -> None:
|
||||
districts = get_districts()
|
||||
assert isinstance(districts, dict)
|
||||
assert len(districts) > 0
|
||||
|
||||
def test_values_start_with_region_prefix(self) -> None:
|
||||
for name, region_id in get_districts().items():
|
||||
assert region_id.startswith("REGION^"), (
|
||||
f"District '{name}' has value '{region_id}' that doesn't start with REGION^"
|
||||
)
|
||||
|
||||
def test_contains_expected_london_boroughs(self) -> None:
|
||||
districts = get_districts()
|
||||
for borough in ("Camden", "Westminster", "Hackney"):
|
||||
assert borough in districts, f"Expected borough '{borough}' not found"
|
||||
|
||||
|
||||
class TestGetDistrictByName:
|
||||
def test_valid_name_returns_region_id(self) -> None:
|
||||
result = get_district_by_name("Camden")
|
||||
assert result == "REGION^93941"
|
||||
|
||||
def test_invalid_name_returns_none(self) -> None:
|
||||
result = get_district_by_name("Nonexistent District")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetDistrictNames:
|
||||
def test_returns_list_matching_dict_keys(self) -> None:
|
||||
names = get_district_names()
|
||||
assert isinstance(names, list)
|
||||
assert names == list(get_districts().keys())
|
||||
|
||||
|
||||
class TestGetAllDistricts:
|
||||
def test_returns_same_as_get_districts(self) -> None:
|
||||
assert get_all_districts() == get_districts()
|
||||
|
||||
|
||||
class TestValidateDistricts:
|
||||
def test_all_valid_returns_empty_list(self) -> None:
|
||||
result = validate_districts(["Camden", "Westminster", "Hackney"])
|
||||
assert result == []
|
||||
|
||||
def test_some_invalid_returns_invalid_ones(self) -> None:
|
||||
result = validate_districts(["Camden", "Faketown", "Westminster", "Nowhere"])
|
||||
assert result == ["Faketown", "Nowhere"]
|
||||
|
||||
def test_all_invalid_returns_all(self) -> None:
|
||||
invalid = ["Faketown", "Nowhere", "Neverland"]
|
||||
result = validate_districts(invalid)
|
||||
assert result == invalid
|
||||
|
||||
def test_empty_list_returns_empty_list(self) -> None:
|
||||
result = validate_districts([])
|
||||
assert result == []
|
||||
104
tests/unit/test_floorplan.py
Normal file
104
tests/unit/test_floorplan.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Unit tests for rec/floorplan.py."""
|
||||
from unittest.mock import patch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pytest
|
||||
|
||||
from rec.floorplan import extract_total_sqm, improve_img_for_ocr, calculate_ocr
|
||||
|
||||
|
||||
class TestExtractTotalSqm:
|
||||
|
||||
def test_normal_value(self) -> None:
|
||||
assert extract_total_sqm("Total area: 75.5 sq m") == 75.5
|
||||
|
||||
def test_multiple_values_returns_max_in_range(self) -> None:
|
||||
assert extract_total_sqm("Room 1: 20 sqm, Total: 65 sq m") == 65.0
|
||||
|
||||
def test_no_match_returns_none(self) -> None:
|
||||
assert extract_total_sqm("No area info") is None
|
||||
|
||||
def test_below_minimum_returns_none(self) -> None:
|
||||
assert extract_total_sqm("Area: 15 sq m") is None
|
||||
|
||||
def test_above_maximum_returns_none(self) -> None:
|
||||
assert extract_total_sqm("Area: 200 sq m") is None
|
||||
|
||||
def test_edge_just_above_min(self) -> None:
|
||||
assert extract_total_sqm("Area: 30.1 sq m") == 30.1
|
||||
|
||||
def test_edge_just_below_max(self) -> None:
|
||||
assert extract_total_sqm("Area: 159.9 sq m") == 159.9
|
||||
|
||||
def test_exactly_at_min_boundary_returns_none(self) -> None:
|
||||
# MIN_SQM < sqm, so 30 is not strictly greater than 30
|
||||
assert extract_total_sqm("Area: 30 sq m") is None
|
||||
|
||||
def test_exactly_at_max_boundary_returns_none(self) -> None:
|
||||
# sqm < MAX_SQM, so 160 is not strictly less than 160
|
||||
assert extract_total_sqm("Area: 160 sq m") is None
|
||||
|
||||
def test_format_sq_dot_m(self) -> None:
|
||||
assert extract_total_sqm("Area: 80 sq. m") == 80.0
|
||||
|
||||
def test_format_sqm_no_space(self) -> None:
|
||||
assert extract_total_sqm("Area: 80sqm") == 80.0
|
||||
|
||||
def test_format_sq_m_with_space(self) -> None:
|
||||
assert extract_total_sqm("Area: 80 sq m") == 80.0
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert extract_total_sqm("") is None
|
||||
|
||||
def test_multiple_valid_values_returns_max(self) -> None:
|
||||
assert extract_total_sqm("Living: 40 sq m, Total: 100 sq m") == 100.0
|
||||
|
||||
|
||||
class TestImproveImgForOcr:
|
||||
|
||||
def test_produces_valid_pil_image(self) -> None:
|
||||
# Create a small test image (50x50 white image)
|
||||
img = Image.fromarray(np.ones((50, 50, 3), dtype=np.uint8) * 200)
|
||||
result = improve_img_for_ocr(img)
|
||||
assert isinstance(result, Image.Image)
|
||||
# Result should be a grayscale (thresholded) image
|
||||
assert result.mode == "L"
|
||||
|
||||
def test_output_dimensions_scaled(self) -> None:
|
||||
img = Image.fromarray(np.ones((100, 100, 3), dtype=np.uint8) * 128)
|
||||
result = improve_img_for_ocr(img)
|
||||
# After 1.2x resize, 100 -> 120
|
||||
assert result.size[0] == 120
|
||||
assert result.size[1] == 120
|
||||
|
||||
|
||||
class TestCalculateOcr:
|
||||
|
||||
def test_invalid_path_raises_file_not_found(self) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
calculate_ocr("/nonexistent/path/to/image.png")
|
||||
|
||||
def test_returns_sqm_from_first_pass(self, tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
# Create a real image file so the path check passes
|
||||
image_file = tmp_path / "test.png"
|
||||
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
|
||||
|
||||
with patch("pytesseract.image_to_string", return_value="Total: 85 sq m"):
|
||||
result_sqm, result_text = calculate_ocr(str(image_file))
|
||||
|
||||
assert result_sqm == 85.0
|
||||
assert result_text == "Total: 85 sq m"
|
||||
|
||||
def test_falls_back_to_improved_image(self, tmp_path) -> None: # type: ignore[no-untyped-def]
|
||||
image_file = tmp_path / "test.png"
|
||||
Image.fromarray(np.ones((10, 10, 3), dtype=np.uint8)).save(str(image_file))
|
||||
|
||||
# First call returns no sqm data, second (on improved image) returns valid data
|
||||
with patch("pytesseract.image_to_string", side_effect=[
|
||||
"No area info here",
|
||||
"Total: 72 sq m",
|
||||
]):
|
||||
result_sqm, result_text = calculate_ocr(str(image_file))
|
||||
|
||||
assert result_sqm == 72.0
|
||||
assert result_text == "Total: 72 sq m"
|
||||
110
tests/unit/test_floorplan_detector.py
Normal file
110
tests/unit/test_floorplan_detector.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Unit tests for services/floorplan_detector.py."""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from models.listing import RentListing, ListingSite, FurnishType
|
||||
from services.floorplan_detector import _calculate_sqm_ocr, detect_floorplan
|
||||
|
||||
|
||||
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
|
||||
defaults = dict(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=None,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return RentListing(**defaults)
|
||||
|
||||
|
||||
class TestCalculateSqmOcr:
|
||||
|
||||
async def test_skips_listing_with_existing_square_meters(self) -> None:
|
||||
listing = _make_listing(square_meters=50.0)
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is None
|
||||
|
||||
async def test_empty_floorplan_paths_returns_listing_with_zero(self) -> None:
|
||||
listing = _make_listing(floorplan_image_paths=[])
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 0
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_with_mocked_ocr_returning_value(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.return_value = (85.0, "Total: 85 sq m")
|
||||
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 85.0
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_with_mocked_ocr_returning_none(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.return_value = (None, "no data")
|
||||
listing = _make_listing(floorplan_image_paths=["/fake/path.png"])
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 0
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_picks_max_from_multiple_floorplans(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.side_effect = [
|
||||
(50.0, "50 sq m"),
|
||||
(90.0, "90 sq m"),
|
||||
]
|
||||
listing = _make_listing(floorplan_image_paths=["/fake/a.png", "/fake/b.png"])
|
||||
semaphore = asyncio.Semaphore(2)
|
||||
result = await _calculate_sqm_ocr(listing, semaphore)
|
||||
assert result is not None
|
||||
assert result.square_meters == 90.0
|
||||
|
||||
|
||||
class TestDetectFloorplan:
|
||||
|
||||
@patch("services.floorplan_detector.floorplan")
|
||||
async def test_detect_floorplan_with_mocked_repository(self, mock_floorplan: MagicMock) -> None:
|
||||
mock_floorplan.calculate_ocr.return_value = (75.0, "75 sq m")
|
||||
|
||||
listing = _make_listing(
|
||||
floorplan_image_paths=["/fake/path.png"],
|
||||
)
|
||||
repository = MagicMock()
|
||||
repository.get_listings = AsyncMock(return_value=[listing])
|
||||
repository.upsert_listings = AsyncMock(return_value=[])
|
||||
|
||||
await detect_floorplan(repository)
|
||||
|
||||
repository.upsert_listings.assert_called_once()
|
||||
upserted = repository.upsert_listings.call_args[0][0]
|
||||
assert len(upserted) == 1
|
||||
assert upserted[0].square_meters == 75.0
|
||||
|
||||
async def test_detect_floorplan_skips_already_processed(self) -> None:
|
||||
listing = _make_listing(square_meters=50.0)
|
||||
repository = MagicMock()
|
||||
repository.get_listings = AsyncMock(return_value=[listing])
|
||||
repository.upsert_listings = AsyncMock(return_value=[])
|
||||
|
||||
await detect_floorplan(repository)
|
||||
|
||||
repository.upsert_listings.assert_called_once()
|
||||
upserted = repository.upsert_listings.call_args[0][0]
|
||||
assert len(upserted) == 0
|
||||
215
tests/unit/test_image_fetcher.py
Normal file
215
tests/unit/test_image_fetcher.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"""Unit tests for the image fetcher service."""
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from tenacity import stop_after_attempt
|
||||
|
||||
from models.listing import RentListing, ListingSite, FurnishType
|
||||
from services.image_fetcher import dump_images_for_listing, MAX_CONCURRENT_DOWNLOADS
|
||||
|
||||
|
||||
def _make_listing(**kwargs) -> RentListing: # type: ignore[no-untyped-def]
|
||||
"""Create a RentListing with sensible defaults for testing."""
|
||||
defaults = dict(
|
||||
id=12345,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=None,
|
||||
agency="Test Agency",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={
|
||||
"property": {
|
||||
"visible": True,
|
||||
"floorplans": [
|
||||
{"url": "https://media.rightmove.co.uk/imgs/floorplan_1.jpg"}
|
||||
],
|
||||
}
|
||||
},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return RentListing(**defaults)
|
||||
|
||||
|
||||
class TestDumpImagesForListing:
|
||||
"""Tests for dump_images_for_listing function."""
|
||||
|
||||
async def test_downloads_floorplan_image(self, tmp_path: Path) -> None:
|
||||
"""Test successful floorplan image download."""
|
||||
listing = _make_listing()
|
||||
image_bytes = b"\x89PNG fake image data"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == 12345
|
||||
assert len(result.floorplan_image_paths) == 1
|
||||
# Verify the image was written
|
||||
written_path = Path(result.floorplan_image_paths[0])
|
||||
assert written_path.exists()
|
||||
assert written_path.read_bytes() == image_bytes
|
||||
|
||||
async def test_skips_existing_images(self, tmp_path: Path) -> None:
|
||||
"""Test that existing images are not re-downloaded."""
|
||||
listing = _make_listing()
|
||||
# Pre-create the floorplan file
|
||||
floorplan_dir = tmp_path / str(listing.id) / "floorplans"
|
||||
floorplan_dir.mkdir(parents=True)
|
||||
existing_file = floorplan_dir / "floorplan_1.jpg"
|
||||
existing_file.write_bytes(b"existing image")
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
# Should return None because the only floorplan was skipped (continue)
|
||||
assert result is None
|
||||
# Session.get should NOT have been called
|
||||
mock_session.get.assert_not_called()
|
||||
|
||||
async def test_returns_none_on_404(self, tmp_path: Path) -> None:
|
||||
"""Test that 404 responses return None (image not found)."""
|
||||
listing = _make_listing()
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 404
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_raises_on_non_200_status(self, tmp_path: Path) -> None:
|
||||
"""Test that non-200/404 status raises exception."""
|
||||
listing = _make_listing()
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 500
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
# Disable tenacity retry for testing: stop after 1 attempt and reraise
|
||||
await dump_images_for_listing.retry_with(
|
||||
stop=stop_after_attempt(1),
|
||||
reraise=True,
|
||||
)(listing, tmp_path, session=mock_session)
|
||||
|
||||
async def test_returns_none_when_no_floorplans(self, tmp_path: Path) -> None:
|
||||
"""Test listing with no floorplans returns None."""
|
||||
listing = _make_listing(
|
||||
additional_info={"property": {"visible": True, "floorplans": []}}
|
||||
)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_url_filename_extraction(self, tmp_path: Path) -> None:
|
||||
"""Test that filenames are correctly extracted from URLs."""
|
||||
listing = _make_listing(
|
||||
additional_info={
|
||||
"property": {
|
||||
"visible": True,
|
||||
"floorplans": [
|
||||
{
|
||||
"url": "https://media.rightmove.co.uk/dir/sub/my_floorplan.png"
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
image_bytes = b"fake png"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session.get = MagicMock(return_value=mock_cm)
|
||||
|
||||
result = await dump_images_for_listing(
|
||||
listing, tmp_path, session=mock_session
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
written_path = Path(result.floorplan_image_paths[0])
|
||||
assert written_path.name == "my_floorplan.png"
|
||||
|
||||
async def test_creates_session_when_none_provided(self, tmp_path: Path) -> None:
|
||||
"""Test that a session is created and closed when none is provided."""
|
||||
listing = _make_listing()
|
||||
image_bytes = b"fake image"
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read = AsyncMock(return_value=image_bytes)
|
||||
|
||||
mock_session_instance = MagicMock(spec=aiohttp.ClientSession)
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session_instance.get = MagicMock(return_value=mock_cm)
|
||||
mock_session_instance.close = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"services.image_fetcher.aiohttp.ClientSession",
|
||||
return_value=mock_session_instance,
|
||||
):
|
||||
result = await dump_images_for_listing(listing, tmp_path, session=None)
|
||||
|
||||
assert result is not None
|
||||
mock_session_instance.close.assert_awaited_once()
|
||||
|
||||
|
||||
class TestImageFetcherConfig:
|
||||
"""Tests for image fetcher configuration."""
|
||||
|
||||
def test_max_concurrent_downloads_constant(self) -> None:
|
||||
"""Test that MAX_CONCURRENT_DOWNLOADS is defined and reasonable."""
|
||||
assert MAX_CONCURRENT_DOWNLOADS > 0
|
||||
assert MAX_CONCURRENT_DOWNLOADS <= 20
|
||||
225
tests/unit/test_listing_cache.py
Normal file
225
tests/unit/test_listing_cache.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Unit tests for services/listing_cache.py."""
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from services.listing_cache import (
|
||||
CACHE_PREFIX,
|
||||
_get_redis_client,
|
||||
cache_features_batch,
|
||||
get_cached_count,
|
||||
get_cached_features,
|
||||
invalidate_cache,
|
||||
make_cache_key,
|
||||
)
|
||||
|
||||
|
||||
def _make_query(**overrides) -> QueryParameters:
|
||||
"""Create a QueryParameters with defaults for testing."""
|
||||
defaults = {"listing_type": ListingType.RENT, "min_price": 1000, "max_price": 3000}
|
||||
defaults.update(overrides)
|
||||
return QueryParameters(**defaults)
|
||||
|
||||
|
||||
class TestMakeCacheKey:
|
||||
"""Tests for make_cache_key()."""
|
||||
|
||||
def test_deterministic_for_same_params(self):
|
||||
"""Same parameters produce the same cache key."""
|
||||
qp = _make_query()
|
||||
assert make_cache_key(qp) == make_cache_key(qp)
|
||||
|
||||
def test_different_for_different_params(self):
|
||||
"""Different parameters produce different cache keys."""
|
||||
qp1 = _make_query(min_price=1000)
|
||||
qp2 = _make_query(min_price=2000)
|
||||
assert make_cache_key(qp1) != make_cache_key(qp2)
|
||||
|
||||
def test_key_starts_with_prefix(self):
|
||||
"""Cache key starts with CACHE_PREFIX."""
|
||||
qp = _make_query()
|
||||
assert make_cache_key(qp).startswith(CACHE_PREFIX)
|
||||
|
||||
|
||||
class TestGetRedisClient:
|
||||
"""Tests for _get_redis_client() URL parsing."""
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_default_broker_url(self, mock_redis):
|
||||
"""Uses default localhost URL when env var is not set."""
|
||||
with mock.patch.dict("os.environ", {}, clear=True):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://localhost:6379/2", decode_responses=True
|
||||
)
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_custom_broker_url(self, mock_redis):
|
||||
"""Replaces db number from custom broker URL."""
|
||||
with mock.patch.dict(
|
||||
"os.environ", {"CELERY_BROKER_URL": "redis://myhost:1234/5"}
|
||||
):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://myhost:1234/2", decode_responses=True
|
||||
)
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_broker_url_with_password(self, mock_redis):
|
||||
"""Preserves auth info in broker URL."""
|
||||
with mock.patch.dict(
|
||||
"os.environ",
|
||||
{"CELERY_BROKER_URL": "redis://:secret@myhost:6379/0"},
|
||||
):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://:secret@myhost:6379/2", decode_responses=True
|
||||
)
|
||||
|
||||
@mock.patch("services.listing_cache.redis")
|
||||
def test_broker_url_with_query_params(self, mock_redis):
|
||||
"""Preserves query parameters in broker URL."""
|
||||
with mock.patch.dict(
|
||||
"os.environ",
|
||||
{"CELERY_BROKER_URL": "redis://myhost:6379/0?timeout=5"},
|
||||
):
|
||||
_get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://myhost:6379/2?timeout=5", decode_responses=True
|
||||
)
|
||||
|
||||
|
||||
class TestGetCachedCount:
|
||||
"""Tests for get_cached_count()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_returns_none_on_redis_error(self, mock_get_client):
|
||||
"""Returns None when Redis raises an error."""
|
||||
mock_get_client.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
result = get_cached_count(_make_query())
|
||||
assert result is None
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_returns_none_when_key_not_exists(self, mock_get_client):
|
||||
"""Returns None when the cache key does not exist."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.exists.return_value = False
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = get_cached_count(_make_query())
|
||||
assert result is None
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_returns_count_when_key_exists(self, mock_get_client):
|
||||
"""Returns list length when key exists."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.exists.return_value = True
|
||||
mock_client.llen.return_value = 42
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
result = get_cached_count(_make_query())
|
||||
assert result == 42
|
||||
|
||||
|
||||
class TestGetCachedFeatures:
|
||||
"""Tests for get_cached_features()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_yields_empty_on_redis_error(self, mock_get_client):
|
||||
"""Yields nothing when Redis raises an error."""
|
||||
mock_get_client.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
batches = list(get_cached_features(_make_query()))
|
||||
assert batches == []
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_yields_batches(self, mock_get_client):
|
||||
"""Yields features in batches."""
|
||||
features = [{"type": "Feature", "id": i} for i in range(3)]
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.llen.return_value = 3
|
||||
mock_client.lrange.return_value = [json.dumps(f) for f in features]
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
batches = list(get_cached_features(_make_query(), batch_size=50))
|
||||
assert len(batches) == 1
|
||||
assert batches[0] == features
|
||||
|
||||
|
||||
class TestCacheFeaturesBatch:
|
||||
"""Tests for cache_features_batch()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_empty_features_returns_early(self, mock_get_client):
|
||||
"""Does not call Redis when features list is empty."""
|
||||
cache_features_batch(_make_query(), [])
|
||||
mock_get_client.assert_not_called()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_writes_features_via_pipeline(self, mock_get_client):
|
||||
"""Writes features and sets TTL through pipeline."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_pipeline = mock.MagicMock()
|
||||
mock_client.pipeline.return_value = mock_pipeline
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
features = [{"type": "Feature", "id": 1}]
|
||||
cache_features_batch(_make_query(), features)
|
||||
|
||||
mock_pipeline.rpush.assert_called_once()
|
||||
mock_pipeline.expire.assert_called_once()
|
||||
mock_pipeline.execute.assert_called_once()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_handles_redis_error(self, mock_get_client):
|
||||
"""Handles Redis error gracefully during write."""
|
||||
mock_get_client.side_effect = redis.RedisError("write error")
|
||||
|
||||
# Should not raise
|
||||
cache_features_batch(_make_query(), [{"id": 1}])
|
||||
|
||||
|
||||
class TestInvalidateCache:
|
||||
"""Tests for invalidate_cache()."""
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_handles_redis_error(self, mock_get_client):
|
||||
"""Handles Redis error gracefully during invalidation."""
|
||||
mock_get_client.side_effect = redis.RedisError("connection refused")
|
||||
|
||||
# Should not raise
|
||||
invalidate_cache()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_deletes_matching_keys_via_pipeline(self, mock_get_client):
|
||||
"""Deletes keys matching the cache prefix using pipeline."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_pipeline = mock.MagicMock()
|
||||
mock_client.pipeline.return_value = mock_pipeline
|
||||
# Simulate one scan iteration that returns keys, then done
|
||||
mock_client.scan.return_value = (0, ["listings:geojson:abc", "listings:geojson:def"])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
invalidate_cache()
|
||||
|
||||
assert mock_pipeline.delete.call_count == 2
|
||||
mock_pipeline.execute.assert_called_once()
|
||||
|
||||
@mock.patch("services.listing_cache._get_redis_client")
|
||||
def test_no_keys_to_delete(self, mock_get_client):
|
||||
"""Does nothing when no cache keys exist."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.scan.return_value = (0, [])
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
invalidate_cache()
|
||||
|
||||
mock_client.pipeline.assert_not_called()
|
||||
372
tests/unit/test_listing_fetcher.py
Normal file
372
tests/unit/test_listing_fetcher.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
"""Unit tests for the listing fetcher service."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
|
||||
from services.listing_fetcher import (
|
||||
NUM_WORKERS,
|
||||
_fetch_subquery,
|
||||
dump_listings,
|
||||
dump_listings_full,
|
||||
)
|
||||
from services.query_splitter import SubQuery
|
||||
|
||||
|
||||
def _make_subquery(**kwargs) -> SubQuery:
|
||||
"""Create a SubQuery with sensible defaults for testing."""
|
||||
defaults = dict(
|
||||
district="REGION^123",
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=3,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
estimated_results=50,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return SubQuery(**defaults)
|
||||
|
||||
|
||||
class TestDumpListingsFull:
|
||||
"""Tests for dump_listings_full."""
|
||||
|
||||
async def test_returns_empty_list_when_no_new_listings(self) -> None:
|
||||
"""Test that empty results from dump_listings returns empty list."""
|
||||
with patch(
|
||||
"services.listing_fetcher.dump_listings",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listings = AsyncMock(return_value=[])
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
result = await dump_listings_full(params, mock_repo)
|
||||
assert result == []
|
||||
|
||||
async def test_returns_only_new_listings_from_db(self) -> None:
|
||||
"""Test that dump_listings_full fetches new listings by ID from the repository."""
|
||||
mock_listing_1 = MagicMock()
|
||||
mock_listing_1.id = 100
|
||||
mock_listing_2 = MagicMock()
|
||||
mock_listing_2.id = 200
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.dump_listings",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_listing_1, mock_listing_2],
|
||||
):
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listings = AsyncMock(
|
||||
return_value=[mock_listing_1, mock_listing_2]
|
||||
)
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
result = await dump_listings_full(params, mock_repo)
|
||||
|
||||
# Verify get_listings was called with the correct IDs
|
||||
mock_repo.get_listings.assert_awaited_once_with(
|
||||
only_ids=[100, 200]
|
||||
)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestFetchSubquery:
|
||||
"""Tests for _fetch_subquery helper."""
|
||||
|
||||
async def test_skips_subquery_with_zero_estimated_results(self) -> None:
|
||||
"""Test that subqueries with 0 estimated results are skipped."""
|
||||
sq = _make_subquery(estimated_results=0)
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=MagicMock(),
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_skips_subquery_with_none_estimated_results(self) -> None:
|
||||
"""Test that subqueries with None estimated results are skipped."""
|
||||
sq = _make_subquery(estimated_results=None)
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=MagicMock(),
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_enqueues_new_ids_only(self) -> None:
|
||||
"""Test that only new (not existing) IDs are enqueued."""
|
||||
sq = _make_subquery(estimated_results=10)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
existing_ids: set[int] = {101, 103}
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
mock_config.max_concurrent_requests = 5
|
||||
|
||||
api_result = {
|
||||
"properties": [
|
||||
{"identifier": 101}, # existing
|
||||
{"identifier": 102}, # new
|
||||
{"identifier": 103}, # existing
|
||||
{"identifier": 104}, # new
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
return_value=api_result,
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=existing_ids,
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 2
|
||||
# Verify that queued IDs are the new ones
|
||||
queued = []
|
||||
while not queue.empty():
|
||||
queued.append(queue.get_nowait())
|
||||
assert 102 in queued
|
||||
assert 104 in queued
|
||||
assert 101 not in queued
|
||||
assert 103 not in queued
|
||||
|
||||
async def test_stops_on_circuit_breaker_error(self) -> None:
|
||||
"""Test that CircuitBreakerOpenError breaks the page loop."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=CircuitBreakerOpenError("open"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_on_throttling_error(self) -> None:
|
||||
"""Test that ThrottlingError breaks the page loop."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ThrottlingError("throttled"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_on_generic_error(self) -> None:
|
||||
"""Test that GENERIC_ERROR (past last page) stops pagination."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("GENERIC_ERROR: no more results"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_on_unexpected_error(self) -> None:
|
||||
"""Test that unexpected errors also stop pagination."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("some network error"),
|
||||
):
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
assert ids_found == 0
|
||||
assert queue.empty()
|
||||
|
||||
async def test_stops_when_fewer_results_than_page_size(self) -> None:
|
||||
"""Test that pagination stops when a page has fewer results than page_size."""
|
||||
sq = _make_subquery(estimated_results=100)
|
||||
params = QueryParameters(listing_type=ListingType.RENT, page_size=25)
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.max_pages_per_query = 60
|
||||
mock_config.request_delay_ms = 0
|
||||
|
||||
# Return fewer properties than page_size
|
||||
api_result = {
|
||||
"properties": [
|
||||
{"identifier": 1},
|
||||
{"identifier": 2},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"services.listing_fetcher.listing_query",
|
||||
new_callable=AsyncMock,
|
||||
return_value=api_result,
|
||||
) as mock_query:
|
||||
ids_found = await _fetch_subquery(
|
||||
sq=sq,
|
||||
parameters=params,
|
||||
session=MagicMock(),
|
||||
config=mock_config,
|
||||
semaphore=asyncio.Semaphore(5),
|
||||
existing_ids=set(),
|
||||
queue=queue,
|
||||
)
|
||||
|
||||
# Should have called listing_query exactly once (then stopped)
|
||||
assert mock_query.await_count == 1
|
||||
assert ids_found == 2
|
||||
|
||||
|
||||
class TestDumpListings:
|
||||
"""Tests for dump_listings."""
|
||||
|
||||
async def test_circuit_breaker_returns_empty_list(self) -> None:
|
||||
"""Test that CircuitBreakerOpenError returns empty list."""
|
||||
mock_repo = AsyncMock()
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
|
||||
with patch("services.listing_fetcher.create_session") as mock_cs:
|
||||
mock_cs.side_effect = CircuitBreakerOpenError("open")
|
||||
result = await dump_listings(params, mock_repo)
|
||||
assert result == []
|
||||
|
||||
async def test_returns_processed_listings(self) -> None:
|
||||
"""Test that dump_listings returns processed listings from the pipeline."""
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listing_ids = MagicMock(return_value=set())
|
||||
params = QueryParameters(listing_type=ListingType.RENT)
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = 42
|
||||
|
||||
mock_session_cm = AsyncMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.listing_fetcher.create_session",
|
||||
return_value=mock_session_cm,
|
||||
),
|
||||
patch(
|
||||
"services.listing_fetcher.QuerySplitter"
|
||||
) as mock_splitter_cls,
|
||||
patch(
|
||||
"services.listing_fetcher._fetch_subquery",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
),
|
||||
):
|
||||
mock_splitter = mock_splitter_cls.return_value
|
||||
mock_splitter.split = AsyncMock(return_value=[])
|
||||
mock_splitter.calculate_total_estimated_results = MagicMock(
|
||||
return_value=0
|
||||
)
|
||||
|
||||
result = await dump_listings(params, mock_repo)
|
||||
# With no subqueries, no listings are processed
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestNumWorkers:
|
||||
"""Tests for NUM_WORKERS constant."""
|
||||
|
||||
def test_num_workers_is_positive(self) -> None:
|
||||
"""Test that NUM_WORKERS is a positive integer."""
|
||||
assert NUM_WORKERS > 0
|
||||
|
||||
def test_num_workers_value(self) -> None:
|
||||
"""Test that NUM_WORKERS has the expected value."""
|
||||
assert NUM_WORKERS == 20
|
||||
87
tests/unit/test_listing_processor.py
Normal file
87
tests/unit/test_listing_processor.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Unit tests for the listing processor."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from models.listing import FurnishType
|
||||
from listing_processor import (
|
||||
_parse_furnish_type,
|
||||
_parse_available_from,
|
||||
ListingProcessor,
|
||||
FetchListingDetailsStep,
|
||||
MAX_OCR_WORKERS,
|
||||
)
|
||||
|
||||
|
||||
class TestParseFurnishType:
|
||||
"""Tests for _parse_furnish_type helper."""
|
||||
|
||||
def test_none_returns_unknown(self):
|
||||
assert _parse_furnish_type(None) == FurnishType.UNKNOWN
|
||||
|
||||
def test_ask_landlord_variant(self):
|
||||
assert _parse_furnish_type("Ask landlord") == FurnishType.ASK_LANDLORD
|
||||
|
||||
def test_furnished_lowercased(self):
|
||||
assert _parse_furnish_type("Furnished") == FurnishType.FURNISHED
|
||||
|
||||
def test_unfurnished(self):
|
||||
assert _parse_furnish_type("Unfurnished") == FurnishType.UNFURNISHED
|
||||
|
||||
def test_part_furnished(self):
|
||||
assert _parse_furnish_type("Part Furnished") == FurnishType.PART_FURNISHED
|
||||
|
||||
def test_unknown_string_returns_unknown(self):
|
||||
assert _parse_furnish_type("unknown") == FurnishType.UNKNOWN
|
||||
|
||||
def test_garbage_string_returns_unknown(self):
|
||||
assert _parse_furnish_type("xyzzy") == FurnishType.UNKNOWN
|
||||
|
||||
|
||||
class TestParseAvailableFrom:
|
||||
"""Tests for _parse_available_from helper."""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert _parse_available_from(None) is None
|
||||
|
||||
def test_now_returns_datetime(self):
|
||||
result = _parse_available_from("Now")
|
||||
assert isinstance(result, datetime)
|
||||
|
||||
def test_valid_date_string(self):
|
||||
result = _parse_available_from("15/03/2024")
|
||||
assert result is not None
|
||||
assert result.day == 15
|
||||
assert result.month == 3
|
||||
|
||||
def test_invalid_date_returns_none(self):
|
||||
assert _parse_available_from("invalid") is None
|
||||
|
||||
|
||||
class TestListingProcessor:
|
||||
"""Tests for ListingProcessor."""
|
||||
|
||||
async def test_process_listing_marks_seen(self):
|
||||
"""Test that process_listing calls mark_seen."""
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_listings = AsyncMock(return_value=[MagicMock()])
|
||||
processor = ListingProcessor(mock_repo)
|
||||
# Mock all steps to not need processing
|
||||
for step in processor.process_steps:
|
||||
step.needs_processing = AsyncMock(return_value=False)
|
||||
await processor.process_listing(123)
|
||||
mock_repo.mark_seen.assert_awaited_once_with(123)
|
||||
|
||||
async def test_process_listing_returns_none_on_step_failure(self):
|
||||
"""Test that a step failure returns None."""
|
||||
mock_repo = AsyncMock()
|
||||
processor = ListingProcessor(mock_repo)
|
||||
for step in processor.process_steps:
|
||||
step.needs_processing = AsyncMock(return_value=True)
|
||||
step.process = AsyncMock(side_effect=Exception("fail"))
|
||||
result = await processor.process_listing(123)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestOcrWorkersConfig:
|
||||
def test_max_ocr_workers_positive(self):
|
||||
assert MAX_OCR_WORKERS >= 1
|
||||
295
tests/unit/test_listing_tasks.py
Normal file
295
tests/unit/test_listing_tasks.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""Unit tests for tasks/listing_tasks.py."""
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
import tasks.listing_tasks as module
|
||||
from tasks.listing_tasks import (
|
||||
_update_task_state,
|
||||
_PipelineState,
|
||||
TaskLogHandler,
|
||||
SCRAPE_LOCK_NAME,
|
||||
LOG_BUFFER_MAX_LINES,
|
||||
NUM_WORKERS,
|
||||
PHASE_SPLITTING,
|
||||
PHASE_FETCHING,
|
||||
PHASE_PROCESSING,
|
||||
PHASE_COMPLETED,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateTaskState:
|
||||
"""Tests for _update_task_state."""
|
||||
|
||||
def test_injects_logs_from_active_buffer(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = deque(["log line 1", "log line 2"])
|
||||
_update_task_state(task, "test_state", {"key": "value"})
|
||||
task.update_state.assert_called_once()
|
||||
call_meta = task.update_state.call_args[1]["meta"]
|
||||
assert call_meta["logs"] == ["log line 1", "log line 2"]
|
||||
assert call_meta["key"] == "value"
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
def test_works_when_buffer_is_none(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = None
|
||||
_update_task_state(task, "some_state", {"phase": "testing"})
|
||||
task.update_state.assert_called_once_with(
|
||||
state="some_state", meta={"phase": "testing"}
|
||||
)
|
||||
# No "logs" key should be injected
|
||||
call_meta = task.update_state.call_args[1]["meta"]
|
||||
assert "logs" not in call_meta
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
def test_state_string_is_passed_through(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = None
|
||||
_update_task_state(task, "PROGRESS", {})
|
||||
task.update_state.assert_called_once_with(state="PROGRESS", meta={})
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
def test_empty_buffer_injects_empty_list(self):
|
||||
task = MagicMock()
|
||||
original = module._active_log_buffer
|
||||
try:
|
||||
module._active_log_buffer = deque()
|
||||
_update_task_state(task, "state", {"a": 1})
|
||||
call_meta = task.update_state.call_args[1]["meta"]
|
||||
assert call_meta["logs"] == []
|
||||
finally:
|
||||
module._active_log_buffer = original
|
||||
|
||||
|
||||
class TestTaskLogHandler:
|
||||
"""Tests for the TaskLogHandler."""
|
||||
|
||||
def test_emit_appends_to_buffer(self):
|
||||
buf = deque(maxlen=10)
|
||||
handler = TaskLogHandler(buf)
|
||||
handler.setFormatter(
|
||||
__import__("logging").Formatter("%(message)s")
|
||||
)
|
||||
record = __import__("logging").LogRecord(
|
||||
name="test", level=20, pathname="", lineno=0,
|
||||
msg="hello", args=(), exc_info=None,
|
||||
)
|
||||
handler.emit(record)
|
||||
assert "hello" in buf
|
||||
|
||||
def test_buffer_respects_maxlen(self):
|
||||
buf = deque(maxlen=2)
|
||||
handler = TaskLogHandler(buf)
|
||||
handler.setFormatter(
|
||||
__import__("logging").Formatter("%(message)s")
|
||||
)
|
||||
for i in range(5):
|
||||
record = __import__("logging").LogRecord(
|
||||
name="test", level=20, pathname="", lineno=0,
|
||||
msg=f"msg{i}", args=(), exc_info=None,
|
||||
)
|
||||
handler.emit(record)
|
||||
assert len(buf) == 2
|
||||
assert list(buf) == ["msg3", "msg4"]
|
||||
|
||||
|
||||
class TestDumpListingsTask:
|
||||
"""Tests for dump_listings_task Celery task."""
|
||||
|
||||
@patch("tasks.listing_tasks.redis_lock")
|
||||
def test_skips_when_lock_not_acquired(self, mock_redis_lock):
|
||||
"""Task should skip when another scrape is running."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=False)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
mock_redis_lock.return_value = mock_cm
|
||||
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
|
||||
# Use run() which handles bind=True properly
|
||||
task_instance = dump_listings_task
|
||||
task_instance.update_state = MagicMock()
|
||||
|
||||
result = dump_listings_task.run('{"listing_type": "RENT"}')
|
||||
|
||||
assert result["status"] == "skipped"
|
||||
assert result["reason"] == "another_job_running"
|
||||
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
|
||||
|
||||
@patch("tasks.listing_tasks.asyncio.run")
|
||||
@patch("tasks.listing_tasks.redis_lock")
|
||||
def test_calls_dump_listings_full_when_lock_acquired(
|
||||
self, mock_redis_lock, mock_asyncio_run
|
||||
):
|
||||
"""Task should call dump_listings_full when lock is acquired."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=True)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
mock_redis_lock.return_value = mock_cm
|
||||
|
||||
mock_asyncio_run.return_value = []
|
||||
|
||||
from tasks.listing_tasks import dump_listings_task
|
||||
|
||||
task_instance = dump_listings_task
|
||||
task_instance.update_state = MagicMock()
|
||||
|
||||
params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
|
||||
result = dump_listings_task.run(params_json)
|
||||
|
||||
assert result["phase"] == "completed"
|
||||
assert result["progress"] == 1
|
||||
mock_asyncio_run.assert_called_once()
|
||||
mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
|
||||
|
||||
|
||||
class TestSetupPeriodicTasks:
|
||||
"""Tests for setup_periodic_tasks."""
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_enabled_schedules(self, mock_from_env):
|
||||
from config.schedule_config import ScheduleConfig
|
||||
from models.listing import ListingType
|
||||
|
||||
schedule = ScheduleConfig(
|
||||
name="Test Schedule",
|
||||
listing_type=ListingType.RENT,
|
||||
hour="3",
|
||||
minute="30",
|
||||
)
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = [schedule]
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_called_once()
|
||||
call_args = sender.add_periodic_task.call_args
|
||||
assert call_args[1]["name"] == "Test Schedule"
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_handles_config_error_gracefully(self, mock_from_env):
|
||||
mock_from_env.side_effect = ValueError("bad config")
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_not_called()
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_nothing_when_no_schedules(self, mock_from_env):
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = []
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
sender.add_periodic_task.assert_not_called()
|
||||
|
||||
@patch("tasks.listing_tasks.SchedulesConfig.from_env")
|
||||
def test_registers_multiple_schedules(self, mock_from_env):
|
||||
from config.schedule_config import ScheduleConfig
|
||||
from models.listing import ListingType
|
||||
|
||||
schedules = [
|
||||
ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"),
|
||||
ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"),
|
||||
]
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_enabled_schedules.return_value = schedules
|
||||
mock_from_env.return_value = mock_config
|
||||
|
||||
sender = MagicMock()
|
||||
module.setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 2
|
||||
|
||||
|
||||
class TestPipelineState:
|
||||
"""Tests for _PipelineState dataclass."""
|
||||
|
||||
def test_default_initialization(self):
|
||||
state = _PipelineState()
|
||||
assert state.ids_collected == 0
|
||||
assert state.completed_subqueries == 0
|
||||
assert state.total_pages_fetched == 0
|
||||
assert state.fetching_done is False
|
||||
assert state.processed_count == 0
|
||||
assert state.failed_count == 0
|
||||
assert state.details_fetched == 0
|
||||
assert state.images_downloaded == 0
|
||||
assert state.ocr_completed == 0
|
||||
assert state.processed_listings == []
|
||||
|
||||
def test_incrementing_counters(self):
|
||||
state = _PipelineState()
|
||||
state.ids_collected += 5
|
||||
state.completed_subqueries += 3
|
||||
state.total_pages_fetched += 10
|
||||
state.processed_count += 4
|
||||
state.failed_count += 1
|
||||
state.details_fetched += 4
|
||||
state.images_downloaded += 3
|
||||
state.ocr_completed += 2
|
||||
|
||||
assert state.ids_collected == 5
|
||||
assert state.completed_subqueries == 3
|
||||
assert state.total_pages_fetched == 10
|
||||
assert state.processed_count == 4
|
||||
assert state.failed_count == 1
|
||||
assert state.details_fetched == 4
|
||||
assert state.images_downloaded == 3
|
||||
assert state.ocr_completed == 2
|
||||
|
||||
def test_appending_to_processed_listings(self):
|
||||
state = _PipelineState()
|
||||
state.processed_listings.append("listing_a")
|
||||
state.processed_listings.append("listing_b")
|
||||
assert len(state.processed_listings) == 2
|
||||
assert state.processed_listings == ["listing_a", "listing_b"]
|
||||
|
||||
def test_separate_instances_have_independent_lists(self):
|
||||
state_a = _PipelineState()
|
||||
state_b = _PipelineState()
|
||||
state_a.processed_listings.append("only_a")
|
||||
assert state_b.processed_listings == []
|
||||
|
||||
def test_fetching_done_toggle(self):
|
||||
state = _PipelineState()
|
||||
assert state.fetching_done is False
|
||||
state.fetching_done = True
|
||||
assert state.fetching_done is True
|
||||
|
||||
|
||||
class TestPhaseConstants:
|
||||
"""Tests for phase constant values."""
|
||||
|
||||
def test_phase_splitting(self):
|
||||
assert PHASE_SPLITTING == "splitting"
|
||||
|
||||
def test_phase_fetching(self):
|
||||
assert PHASE_FETCHING == "fetching"
|
||||
|
||||
def test_phase_processing(self):
|
||||
assert PHASE_PROCESSING == "processing"
|
||||
|
||||
def test_phase_completed(self):
|
||||
assert PHASE_COMPLETED == "completed"
|
||||
|
||||
def test_num_workers(self):
|
||||
assert NUM_WORKERS == 20
|
||||
538
tests/unit/test_models.py
Normal file
538
tests/unit/test_models.py
Normal file
|
|
@ -0,0 +1,538 @@
|
|||
"""Unit tests for Listing models."""
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
import json
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from models.listing import (
|
||||
BuyListing,
|
||||
DestinationMode,
|
||||
FurnishType,
|
||||
ListingSite,
|
||||
ListingType,
|
||||
PriceHistoryItem,
|
||||
QueryParameters,
|
||||
RentListing,
|
||||
Listing,
|
||||
Route,
|
||||
RouteLegStep,
|
||||
)
|
||||
from rec.routing import TravelMode
|
||||
|
||||
|
||||
class TestListing:
|
||||
"""Tests for the base Listing model."""
|
||||
|
||||
def test_price_per_square_meter_calculation(self) -> None:
|
||||
"""Test that price_per_square_meter is calculated correctly."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.price_per_square_meter == 40.0
|
||||
|
||||
def test_price_per_square_meter_none_when_no_sqm(self) -> None:
|
||||
"""Test that price_per_square_meter is None when square_meters is None."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=None,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.price_per_square_meter is None
|
||||
|
||||
def test_price_per_square_meter_none_when_sqm_zero(self) -> None:
|
||||
"""Test that price_per_square_meter is None when square_meters is 0."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=0.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.price_per_square_meter is None
|
||||
|
||||
def test_url_property(self) -> None:
|
||||
"""Test that url property returns correct Rightmove URL."""
|
||||
listing = RentListing(
|
||||
id=123456789,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.url == "https://www.rightmove.co.uk/properties/123456789"
|
||||
|
||||
def test_is_removed_property_visible(self) -> None:
|
||||
"""Test that is_removed returns False when property is visible."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.is_removed is False
|
||||
|
||||
def test_is_removed_property_not_visible(self) -> None:
|
||||
"""Test that is_removed returns True when property is not visible."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": False}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.is_removed is True
|
||||
|
||||
|
||||
class TestPriceHistory:
|
||||
"""Tests for price history serialization/deserialization."""
|
||||
|
||||
def test_price_history_serialization_roundtrip(self) -> None:
|
||||
"""Test that price history can be serialized and deserialized."""
|
||||
now = datetime.now()
|
||||
price_history = [
|
||||
PriceHistoryItem(
|
||||
first_seen=now,
|
||||
last_seen=now,
|
||||
price=2000.0,
|
||||
),
|
||||
PriceHistoryItem(
|
||||
first_seen=now,
|
||||
last_seen=now,
|
||||
price=2100.0,
|
||||
),
|
||||
]
|
||||
|
||||
# Serialize
|
||||
serialized = Listing.serialize_price_history(price_history)
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
# Create listing with serialized history
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2100.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json=serialized,
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=now,
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
|
||||
# Deserialize
|
||||
deserialized = listing.price_history
|
||||
assert len(deserialized) == 2
|
||||
assert deserialized[0].price == 2000.0
|
||||
assert deserialized[1].price == 2100.0
|
||||
|
||||
def test_price_history_empty(self) -> None:
|
||||
"""Test that empty price history works correctly."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.price_history == []
|
||||
|
||||
def test_price_history_item_to_dict(self) -> None:
|
||||
"""Test PriceHistoryItem.to_dict() method."""
|
||||
now = datetime.now()
|
||||
item = PriceHistoryItem(
|
||||
first_seen=now,
|
||||
last_seen=now,
|
||||
price=2500.0,
|
||||
)
|
||||
result = item.to_dict()
|
||||
assert result["price"] == 2500.0
|
||||
assert result["first_seen"] == now.isoformat()
|
||||
assert result["last_seen"] == now.isoformat()
|
||||
|
||||
|
||||
class TestRentListing:
|
||||
"""Tests specific to RentListing model."""
|
||||
|
||||
def test_rent_listing_has_furnish_type(self) -> None:
|
||||
"""Test that RentListing has furnish_type field."""
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.PART_FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
assert listing.furnish_type == FurnishType.PART_FURNISHED
|
||||
|
||||
def test_rent_listing_has_available_from(self) -> None:
|
||||
"""Test that RentListing has available_from field."""
|
||||
now = datetime.now()
|
||||
listing = RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=now,
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=now,
|
||||
)
|
||||
assert listing.available_from == now
|
||||
|
||||
|
||||
class TestBuyListing:
|
||||
"""Tests specific to BuyListing model."""
|
||||
|
||||
def test_buy_listing_has_service_charge(self) -> None:
|
||||
"""Test that BuyListing has service_charge field."""
|
||||
listing = BuyListing(
|
||||
id=1,
|
||||
price=450000.0,
|
||||
number_of_bedrooms=3,
|
||||
square_meters=95.0,
|
||||
agency="Test",
|
||||
council_tax_band="D",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
service_charge=2500.0,
|
||||
lease_left=85,
|
||||
)
|
||||
assert listing.service_charge == 2500.0
|
||||
|
||||
def test_buy_listing_has_lease_left(self) -> None:
|
||||
"""Test that BuyListing has lease_left field."""
|
||||
listing = BuyListing(
|
||||
id=1,
|
||||
price=450000.0,
|
||||
number_of_bedrooms=3,
|
||||
square_meters=95.0,
|
||||
agency="Test",
|
||||
council_tax_band="D",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=None,
|
||||
service_charge=None,
|
||||
lease_left=120,
|
||||
)
|
||||
assert listing.lease_left == 120
|
||||
|
||||
|
||||
def _make_listing_with_routing(routing_info_json: str | None) -> RentListing:
|
||||
"""Helper to create a RentListing with given routing_info_json."""
|
||||
return RentListing(
|
||||
id=1,
|
||||
price=2000.0,
|
||||
number_of_bedrooms=2,
|
||||
square_meters=50.0,
|
||||
agency="Test",
|
||||
council_tax_band="C",
|
||||
longitude=0.0,
|
||||
latitude=0.0,
|
||||
price_history_json="[]",
|
||||
listing_site=ListingSite.RIGHTMOVE,
|
||||
last_seen=datetime.now(),
|
||||
photo_thumbnail=None,
|
||||
floorplan_image_paths=[],
|
||||
additional_info={"property": {"visible": True}},
|
||||
routing_info_json=routing_info_json,
|
||||
furnish_type=FurnishType.FURNISHED,
|
||||
available_from=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
|
||||
"""Helper to create sample routing info for tests."""
|
||||
destination_mode = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
routes = [
|
||||
Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=500,
|
||||
duration_s=120,
|
||||
travel_mode=TravelMode.WALK,
|
||||
),
|
||||
RouteLegStep(
|
||||
distance_meters=4000,
|
||||
duration_s=480,
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
),
|
||||
],
|
||||
distance_meters=4500,
|
||||
duration_s=600,
|
||||
)
|
||||
]
|
||||
return {destination_mode: routes}
|
||||
|
||||
|
||||
class TestQueryParametersValidation:
|
||||
"""Tests for QueryParameters validation."""
|
||||
|
||||
def test_valid_parameters(self) -> None:
|
||||
"""Basic valid QueryParameters creation."""
|
||||
params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=3,
|
||||
)
|
||||
assert params.min_price == 1000
|
||||
assert params.max_price == 3000
|
||||
assert params.min_bedrooms == 1
|
||||
assert params.max_bedrooms == 3
|
||||
|
||||
def test_invalid_price_range_raises(self) -> None:
|
||||
"""min_price > max_price should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_price.*must be <= max_price"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_price=5000,
|
||||
max_price=1000,
|
||||
)
|
||||
|
||||
def test_invalid_bedroom_range_raises(self) -> None:
|
||||
"""min_bedrooms > max_bedrooms should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_bedrooms.*must be <= max_bedrooms"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=5,
|
||||
max_bedrooms=2,
|
||||
)
|
||||
|
||||
def test_negative_bedrooms_raises(self) -> None:
|
||||
"""Negative bedroom counts should raise ValidationError."""
|
||||
with pytest.raises(ValidationError, match="min_bedrooms.*must be non-negative"):
|
||||
QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=-1,
|
||||
max_bedrooms=3,
|
||||
)
|
||||
|
||||
|
||||
class TestDestinationMode:
|
||||
"""Tests for DestinationMode."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test to_dict returns correct dict."""
|
||||
dm = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
result = dm.to_dict()
|
||||
assert result == {
|
||||
"destination_address": "London Bridge",
|
||||
"travel_mode": TravelMode.TRANSIT,
|
||||
}
|
||||
|
||||
def test_hash(self) -> None:
|
||||
"""Test hashing works correctly."""
|
||||
dm1 = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
dm2 = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
dm3 = DestinationMode(
|
||||
destination_address="King's Cross",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
assert hash(dm1) == hash(dm2)
|
||||
assert dm1 == dm2
|
||||
assert hash(dm1) != hash(dm3)
|
||||
# Can be used as dict key
|
||||
d = {dm1: "route1"}
|
||||
assert d[dm2] == "route1"
|
||||
|
||||
|
||||
class TestRoutingInfoSerialization:
|
||||
"""Tests for routing info via RouteSerializer."""
|
||||
|
||||
def test_routing_info_property_returns_parsed_routes(self) -> None:
|
||||
"""Test routing_info property deserializes correctly."""
|
||||
routing_info = _make_sample_routing_info()
|
||||
listing = _make_listing_with_routing(None)
|
||||
serialized = listing.serialize_routing_info(routing_info)
|
||||
listing.routing_info_json = serialized
|
||||
|
||||
result = listing.routing_info
|
||||
assert len(result) == 1
|
||||
dest_mode = list(result.keys())[0]
|
||||
assert dest_mode.destination_address == "London Bridge"
|
||||
assert dest_mode.travel_mode == TravelMode.TRANSIT
|
||||
|
||||
routes = result[dest_mode]
|
||||
assert len(routes) == 1
|
||||
assert routes[0].distance_meters == 4500
|
||||
assert routes[0].duration_s == 600
|
||||
assert len(routes[0].legs) == 2
|
||||
assert routes[0].legs[0].distance_meters == 500
|
||||
assert routes[0].legs[0].travel_mode == TravelMode.WALK
|
||||
|
||||
def test_routing_info_empty_json(self) -> None:
|
||||
"""Test routing_info with no routing data."""
|
||||
listing = _make_listing_with_routing(None)
|
||||
assert listing.routing_info == {}
|
||||
|
||||
def test_serialize_routing_info_roundtrip(self) -> None:
|
||||
"""Test serialize then deserialize via routing_info property."""
|
||||
routing_info = _make_sample_routing_info()
|
||||
listing = _make_listing_with_routing(None)
|
||||
|
||||
# Serialize
|
||||
serialized = listing.serialize_routing_info(routing_info)
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
# Assign and deserialize via property
|
||||
listing.routing_info_json = serialized
|
||||
deserialized = listing.routing_info
|
||||
|
||||
# Compare
|
||||
orig_dm = list(routing_info.keys())[0]
|
||||
result_dm = list(deserialized.keys())[0]
|
||||
assert orig_dm.destination_address == result_dm.destination_address
|
||||
assert orig_dm.travel_mode == result_dm.travel_mode
|
||||
|
||||
orig_route = routing_info[orig_dm][0]
|
||||
result_route = deserialized[result_dm][0]
|
||||
assert orig_route.distance_meters == result_route.distance_meters
|
||||
assert orig_route.duration_s == result_route.duration_s
|
||||
assert len(orig_route.legs) == len(result_route.legs)
|
||||
385
tests/unit/test_query.py
Normal file
385
tests/unit/test_query.py
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
"""Unit tests for rec/query.py."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import aiohttp
|
||||
|
||||
from rec.query import (
|
||||
detail_query,
|
||||
listing_query,
|
||||
probe_query,
|
||||
PropertyType,
|
||||
create_session,
|
||||
_build_base_params,
|
||||
_build_listing_params,
|
||||
_build_probe_params,
|
||||
ANDROID_APP_VERSION,
|
||||
ANDROID_APP_VERSION_LISTING,
|
||||
RIGHTMOVE_API_BASE,
|
||||
PROPERTY_LISTING_ENDPOINT,
|
||||
DEFAULT_HEADERS,
|
||||
LISTING_HEADERS,
|
||||
check_circuit_breaker,
|
||||
reset_circuit_breaker,
|
||||
get_circuit_breaker,
|
||||
)
|
||||
from models.listing import ListingType, FurnishType
|
||||
from config.scraper_config import ScraperConfig
|
||||
from rec.exceptions import CircuitBreakerOpenError
|
||||
from rec.throttle_detector import reset_throttle_metrics
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config() -> ScraperConfig:
|
||||
return ScraperConfig(
|
||||
max_concurrent_requests=5,
|
||||
request_delay_ms=10,
|
||||
slow_response_threshold=10.0,
|
||||
enable_circuit_breaker=True,
|
||||
circuit_breaker_failure_threshold=3,
|
||||
circuit_breaker_recovery_timeout=0.5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_no_cb() -> ScraperConfig:
|
||||
return ScraperConfig(enable_circuit_breaker=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals() -> None:
|
||||
reset_throttle_metrics()
|
||||
reset_circuit_breaker()
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(
|
||||
self,
|
||||
status: int = 200,
|
||||
json_data: dict | None = None,
|
||||
text: str = "",
|
||||
):
|
||||
self.status = status
|
||||
self._json_data = json_data or {}
|
||||
self._text = text
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._json_data
|
||||
|
||||
async def text(self) -> str:
|
||||
return self._text
|
||||
|
||||
async def __aenter__(self) -> "MockResponse":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def make_mock_session(response: MockResponse) -> MagicMock:
|
||||
"""Create a mock session whose .get() returns an async context manager."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(return_value=response)
|
||||
return mock_session
|
||||
|
||||
|
||||
def make_mock_session_fn(get_fn: object) -> MagicMock:
|
||||
"""Create a mock session whose .get() calls a function to produce responses."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(side_effect=get_fn)
|
||||
return mock_session
|
||||
|
||||
|
||||
class TestBuildBaseParams:
|
||||
def test_constructs_correct_params(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"TestDistrict": "REGION^123"}):
|
||||
params = _build_base_params(
|
||||
channel=ListingType.RENT,
|
||||
page=2,
|
||||
page_size=25,
|
||||
radius=1.5,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=3,
|
||||
district="TestDistrict",
|
||||
)
|
||||
|
||||
assert params["locationIdentifier"] == "REGION^123"
|
||||
assert params["channel"] == "RENT"
|
||||
assert params["page"] == "2"
|
||||
assert params["numberOfPropertiesPerPage"] == "25"
|
||||
assert params["radius"] == "1.5"
|
||||
assert params["sortBy"] == "distance"
|
||||
assert params["includeUnavailableProperties"] == "false"
|
||||
assert params["minPrice"] == "1000"
|
||||
assert params["maxPrice"] == "3000"
|
||||
assert params["minBedrooms"] == "1"
|
||||
assert params["maxBedrooms"] == "3"
|
||||
assert params["apiApplication"] == "ANDROID"
|
||||
assert params["appVersion"] == ANDROID_APP_VERSION_LISTING
|
||||
|
||||
def test_buy_channel_includes_dont_show_and_max_days(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_listing_params(
|
||||
page=1,
|
||||
channel=ListingType.BUY,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=100000,
|
||||
max_price=500000,
|
||||
district="D",
|
||||
mustNewHome=False,
|
||||
max_days_since_added=7,
|
||||
property_type=[],
|
||||
page_size=25,
|
||||
furnish_types=[],
|
||||
)
|
||||
|
||||
assert params["dontShow"] == "sharedOwnership,retirement"
|
||||
assert params["maxDaysSinceAdded"] == "7"
|
||||
|
||||
def test_rent_channel_includes_furnish_types(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_listing_params(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=3000,
|
||||
district="D",
|
||||
mustNewHome=False,
|
||||
max_days_since_added=30,
|
||||
property_type=[],
|
||||
page_size=25,
|
||||
furnish_types=[FurnishType.FURNISHED, FurnishType.UNFURNISHED],
|
||||
)
|
||||
|
||||
assert params["furnishTypes"] == "furnished,unfurnished"
|
||||
assert "dontShow" not in params
|
||||
assert "maxDaysSinceAdded" not in params
|
||||
|
||||
def test_buy_channel_probe_includes_dont_show(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_probe_params(
|
||||
channel=ListingType.BUY,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=100000,
|
||||
max_price=500000,
|
||||
district="D",
|
||||
max_days_since_added=7,
|
||||
furnish_types=[],
|
||||
)
|
||||
|
||||
assert params["dontShow"] == "sharedOwnership,retirement"
|
||||
assert params["maxDaysSinceAdded"] == "7"
|
||||
assert params["numberOfPropertiesPerPage"] == "1"
|
||||
|
||||
def test_probe_buy_skips_max_days_if_not_valid(self) -> None:
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
params = _build_probe_params(
|
||||
channel=ListingType.BUY,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=100000,
|
||||
max_price=500000,
|
||||
district="D",
|
||||
max_days_since_added=30,
|
||||
furnish_types=[],
|
||||
)
|
||||
|
||||
# 30 is not in [1, 3, 7, 14], so maxDaysSinceAdded is not added for probe
|
||||
assert "maxDaysSinceAdded" not in params
|
||||
|
||||
|
||||
class TestMutableDefaultArgFix:
|
||||
@pytest.mark.asyncio
|
||||
async def test_property_type_default_not_shared(self, config: ScraperConfig) -> None:
|
||||
"""Calling listing_query with no property_type should not share state between calls."""
|
||||
response = MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 0, "properties": []},
|
||||
)
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
# Call twice without explicit property_type
|
||||
await listing_query(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
session=mock_session,
|
||||
config=config,
|
||||
)
|
||||
await listing_query(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
session=mock_session,
|
||||
config=config,
|
||||
)
|
||||
# If mutable default was shared, this test would detect mutations.
|
||||
# The fact that it completes without error proves defaults are independent.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_furnish_types_default_not_shared(self, config: ScraperConfig) -> None:
|
||||
"""Calling probe_query with no furnish_types should not share state between calls."""
|
||||
response = MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 0, "properties": []},
|
||||
)
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
config=config,
|
||||
)
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
class TestPropertyTypeEnum:
|
||||
def test_enum_values(self) -> None:
|
||||
assert PropertyType.BUNGALOW == "bungalow"
|
||||
assert PropertyType.DETACHED == "detached"
|
||||
assert PropertyType.FLAT == "flat"
|
||||
assert PropertyType.LAND == "land"
|
||||
assert PropertyType.PARK_HOME == "park-home"
|
||||
assert PropertyType.SEMI_DETACHED == "semi-detached"
|
||||
assert PropertyType.TERRACED == "terraced"
|
||||
|
||||
def test_enum_is_str(self) -> None:
|
||||
assert isinstance(PropertyType.FLAT, str)
|
||||
assert ",".join([PropertyType.FLAT, PropertyType.DETACHED]) == "flat,detached"
|
||||
|
||||
|
||||
class TestDetailQuery:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_200(self, config: ScraperConfig) -> None:
|
||||
expected_body = {"id": 12345, "address": "123 Test St"}
|
||||
response = MockResponse(status=200, json_data=expected_body)
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
result = await detail_query(12345, session=mock_session, config=config)
|
||||
assert result == expected_body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_non_200(self, config: ScraperConfig) -> None:
|
||||
response = MockResponse(status=404, text="Not Found")
|
||||
mock_session = make_mock_session(response)
|
||||
|
||||
with pytest.raises(Exception, match="Failed due to"):
|
||||
await detail_query(99999, session=mock_session, config=config)
|
||||
|
||||
|
||||
class TestCircuitBreakerBlocksRequests:
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_when_open(self, config: ScraperConfig) -> None:
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await detail_query(1, session=mock_session, config=config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_listing_query(self, config: ScraperConfig) -> None:
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await listing_query(
|
||||
page=1,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
session=mock_session,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_probe_query(self, config: ScraperConfig) -> None:
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"D": "LOC1"}):
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="D",
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
class TestConstants:
|
||||
def test_android_app_version(self) -> None:
|
||||
assert ANDROID_APP_VERSION == "3.70.0"
|
||||
|
||||
def test_android_app_version_listing(self) -> None:
|
||||
assert ANDROID_APP_VERSION_LISTING == "4.28.0"
|
||||
|
||||
def test_rightmove_api_base(self) -> None:
|
||||
assert RIGHTMOVE_API_BASE == "https://api.rightmove.co.uk/api"
|
||||
|
||||
def test_property_listing_endpoint(self) -> None:
|
||||
assert PROPERTY_LISTING_ENDPOINT == "https://api.rightmove.co.uk/api/property-listing"
|
||||
|
||||
def test_listing_headers_extends_default(self) -> None:
|
||||
for key, value in DEFAULT_HEADERS.items():
|
||||
assert LISTING_HEADERS[key] == value
|
||||
assert LISTING_HEADERS["Accept-Encoding"] == "gzip, deflate, br"
|
||||
374
tests/unit/test_query_splitter.py
Normal file
374
tests/unit/test_query_splitter.py
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
"""Unit tests for QuerySplitter service."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from config.scraper_config import ScraperConfig
|
||||
from models.listing import ListingType, QueryParameters
|
||||
from services.query_splitter import QuerySplitter, SubQuery
|
||||
|
||||
|
||||
class TestScraperConfig:
|
||||
"""Tests for the ScraperConfig dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test that default values are set correctly."""
|
||||
config = ScraperConfig()
|
||||
assert config.max_concurrent_requests == 5
|
||||
assert config.request_delay_ms == 100
|
||||
assert config.result_cap == 1500
|
||||
assert config.split_threshold == 1200
|
||||
assert config.min_price_band == 100
|
||||
assert config.max_pages_per_query == 60
|
||||
assert config.proxy_url is None
|
||||
|
||||
def test_from_env(self) -> None:
|
||||
"""Test loading configuration from environment variables."""
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"RIGHTMOVE_MAX_CONCURRENT": "10",
|
||||
"RIGHTMOVE_REQUEST_DELAY_MS": "200",
|
||||
"RIGHTMOVE_SPLIT_THRESHOLD": "1000",
|
||||
"RIGHTMOVE_MIN_PRICE_BAND": "50",
|
||||
"RIGHTMOVE_MAX_PAGES": "30",
|
||||
"RIGHTMOVE_PROXY_URL": "socks5://localhost:9050",
|
||||
},
|
||||
):
|
||||
config = ScraperConfig.from_env()
|
||||
assert config.max_concurrent_requests == 10
|
||||
assert config.request_delay_ms == 200
|
||||
assert config.split_threshold == 1000
|
||||
assert config.min_price_band == 50
|
||||
assert config.max_pages_per_query == 30
|
||||
assert config.proxy_url == "socks5://localhost:9050"
|
||||
|
||||
def test_from_env_empty_proxy(self) -> None:
|
||||
"""Test that empty proxy URL is converted to None."""
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"RIGHTMOVE_PROXY_URL": "",
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
config = ScraperConfig.from_env()
|
||||
assert config.proxy_url is None
|
||||
|
||||
|
||||
class TestSubQuery:
|
||||
"""Tests for the SubQuery dataclass."""
|
||||
|
||||
def test_price_range_calculation(self) -> None:
|
||||
"""Test that price_range is calculated correctly."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
)
|
||||
assert sq.price_range == 1000
|
||||
|
||||
|
||||
class TestQuerySplitter:
|
||||
"""Tests for the QuerySplitter class."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> ScraperConfig:
|
||||
"""Create a test configuration."""
|
||||
return ScraperConfig(
|
||||
max_concurrent_requests=5,
|
||||
request_delay_ms=10, # Faster for testing
|
||||
result_cap=1500,
|
||||
split_threshold=1200,
|
||||
min_price_band=100,
|
||||
max_pages_per_query=60,
|
||||
proxy_url=None,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def splitter(self, config: ScraperConfig) -> QuerySplitter:
|
||||
"""Create a QuerySplitter instance."""
|
||||
return QuerySplitter(config)
|
||||
|
||||
@pytest.fixture
|
||||
def parameters(self) -> QueryParameters:
|
||||
"""Create test query parameters."""
|
||||
return QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=3,
|
||||
min_price=1000,
|
||||
max_price=5000,
|
||||
district_names={"Kings Cross", "Angel"},
|
||||
)
|
||||
|
||||
def test_create_initial_subqueries(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test that initial subqueries are created correctly."""
|
||||
districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
|
||||
subqueries = splitter.create_initial_subqueries(parameters, districts)
|
||||
|
||||
# 2 districts × 2 bedroom counts (2,3) = 4 subqueries
|
||||
assert len(subqueries) == 4
|
||||
|
||||
# Check first subquery
|
||||
assert subqueries[0].district == "Kings Cross"
|
||||
assert subqueries[0].min_bedrooms == 2
|
||||
assert subqueries[0].max_bedrooms == 2
|
||||
assert subqueries[0].min_price == 1000
|
||||
assert subqueries[0].max_price == 5000
|
||||
|
||||
def test_split_by_price(self, splitter: QuerySplitter) -> None:
|
||||
"""Test that price splitting works correctly."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=5000,
|
||||
)
|
||||
|
||||
halves = splitter.split_by_price(sq)
|
||||
|
||||
assert len(halves) == 2
|
||||
assert halves[0].min_price == 1000
|
||||
assert halves[0].max_price == 3000 # midpoint
|
||||
assert halves[1].min_price == 3000
|
||||
assert halves[1].max_price == 5000
|
||||
|
||||
# Both should have same bedroom range and district
|
||||
for half in halves:
|
||||
assert half.district == "Kings Cross"
|
||||
assert half.min_bedrooms == 2
|
||||
assert half.max_bedrooms == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_probe_result_count(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test probing API for result count."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=5000,
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
# Mock the probe_query function
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.return_value = {"totalAvailableResults": 800}
|
||||
|
||||
count = await splitter.probe_result_count(sq, mock_session, parameters)
|
||||
|
||||
assert count == 800
|
||||
mock_probe.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_probe_result_count_handles_error(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test that probe_result_count handles errors gracefully."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=5000,
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.side_effect = Exception("API error")
|
||||
|
||||
count = await splitter.probe_result_count(sq, mock_session, parameters)
|
||||
|
||||
# Should return 0 on error
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adaptive_split_no_split_needed(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test adaptive split when results are below threshold."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# First half has 600 results, second half has 500
|
||||
mock_probe.side_effect = [
|
||||
{"totalAvailableResults": 600},
|
||||
{"totalAvailableResults": 500},
|
||||
]
|
||||
|
||||
result = await splitter.adaptive_split(
|
||||
sq, mock_session, parameters, mock_semaphore
|
||||
)
|
||||
|
||||
# Both halves are under threshold (1200), so we get 2 subqueries back
|
||||
assert len(result) == 2
|
||||
assert result[0].estimated_results == 600
|
||||
assert result[1].estimated_results == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adaptive_split_recursive_splitting(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test adaptive split performs recursive splitting when needed."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=5000,
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# First split: 1000-3000 has 1300 (over threshold), 3000-5000 has 800
|
||||
# Second split of 1000-3000: 1000-2000 has 700, 2000-3000 has 600
|
||||
mock_probe.side_effect = [
|
||||
{"totalAvailableResults": 1300}, # First half - needs more splitting
|
||||
{"totalAvailableResults": 800}, # Second half - OK
|
||||
{"totalAvailableResults": 700}, # First quarter - OK
|
||||
{"totalAvailableResults": 600}, # Second quarter - OK
|
||||
]
|
||||
|
||||
result = await splitter.adaptive_split(
|
||||
sq, mock_session, parameters, mock_semaphore
|
||||
)
|
||||
|
||||
# Should get 3 subqueries: [1000-2000 (700), 2000-3000 (600), 3000-5000 (800)]
|
||||
assert len(result) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adaptive_split_respects_min_price_band(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test that adaptive split stops at min_price_band."""
|
||||
sq = SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=1050, # Only 50 range, below min_price_band of 100
|
||||
estimated_results=1500, # Over threshold but can't split
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_semaphore = AsyncMock()
|
||||
|
||||
result = await splitter.adaptive_split(
|
||||
sq, mock_session, parameters, mock_semaphore
|
||||
)
|
||||
|
||||
# Can't split below min_price_band, should return original
|
||||
assert len(result) == 1
|
||||
assert result[0].min_price == 1000
|
||||
assert result[0].max_price == 1050
|
||||
|
||||
def test_calculate_total_estimated_results(
|
||||
self, splitter: QuerySplitter
|
||||
) -> None:
|
||||
"""Test calculation of total estimated results."""
|
||||
subqueries = [
|
||||
SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
estimated_results=500,
|
||||
),
|
||||
SubQuery(
|
||||
district="Kings Cross",
|
||||
min_bedrooms=3,
|
||||
max_bedrooms=3,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
estimated_results=300,
|
||||
),
|
||||
SubQuery(
|
||||
district="Angel",
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
estimated_results=None, # Not probed
|
||||
),
|
||||
]
|
||||
|
||||
total = splitter.calculate_total_estimated_results(subqueries)
|
||||
assert total == 800 # 500 + 300 + 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_integration(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Integration test for the full split workflow."""
|
||||
mock_session = AsyncMock()
|
||||
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
|
||||
with patch("services.query_splitter.get_districts", return_value=mock_districts):
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
# Mock probe results for each initial subquery
|
||||
# 2 districts × 2 bedroom counts = 4 initial subqueries
|
||||
mock_probe.side_effect = [
|
||||
{"totalAvailableResults": 500}, # KC 2BR - OK
|
||||
{"totalAvailableResults": 1300}, # KC 3BR - needs split
|
||||
{"totalAvailableResults": 600}, # Angel 2BR - OK
|
||||
{"totalAvailableResults": 800}, # Angel 3BR - OK
|
||||
# Split KC 3BR
|
||||
{"totalAvailableResults": 700}, # KC 3BR first half
|
||||
{"totalAvailableResults": 600}, # KC 3BR second half
|
||||
]
|
||||
|
||||
result = await splitter.split(parameters, mock_session)
|
||||
|
||||
# Should have 5 subqueries total:
|
||||
# KC 2BR (500), KC 3BR split into 2 (700+600), Angel 2BR (600), Angel 3BR (800)
|
||||
assert len(result) == 5
|
||||
|
||||
# Verify total estimated results
|
||||
total = splitter.calculate_total_estimated_results(result)
|
||||
assert total == 3200 # 500 + 700 + 600 + 600 + 800
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_with_on_progress_callback(
|
||||
self, splitter: QuerySplitter, parameters: QueryParameters
|
||||
) -> None:
|
||||
"""Test that on_progress callback is called during split."""
|
||||
mock_session = AsyncMock()
|
||||
mock_districts = {"Kings Cross": "STATION^5168", "Angel": "STATION^1234"}
|
||||
progress_calls = []
|
||||
|
||||
def on_progress(phase: str, message: str, **kwargs: object) -> None:
|
||||
progress_calls.append((phase, message))
|
||||
|
||||
with patch("services.query_splitter.get_districts", return_value=mock_districts):
|
||||
with patch("rec.query.probe_query") as mock_probe:
|
||||
mock_probe.return_value = {"totalAvailableResults": 500}
|
||||
|
||||
await splitter.split(parameters, mock_session, on_progress)
|
||||
|
||||
# Should have received at least 2 progress updates
|
||||
assert len(progress_calls) >= 2
|
||||
phases = [call[0] for call in progress_calls]
|
||||
assert "splitting" in phases
|
||||
assert "splitting_complete" in phases
|
||||
74
tests/unit/test_redis_lock.py
Normal file
74
tests/unit/test_redis_lock.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""Unit tests for Redis distributed lock."""
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from utils.redis_lock import redis_lock, get_redis_client
|
||||
|
||||
|
||||
class TestRedisLock:
|
||||
"""Tests for redis_lock context manager."""
|
||||
|
||||
@mock.patch("utils.redis_lock.get_redis_client")
|
||||
def test_lock_acquired_successfully(self, mock_get_client):
|
||||
"""Test lock acquisition when no other lock exists."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.set.return_value = True
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with redis_lock("test_lock") as acquired:
|
||||
assert acquired is True
|
||||
|
||||
mock_client.set.assert_called_once_with("lock:test_lock", "1", nx=True, ex=3600 * 4)
|
||||
mock_client.delete.assert_called_once_with("lock:test_lock")
|
||||
|
||||
@mock.patch("utils.redis_lock.get_redis_client")
|
||||
def test_lock_not_acquired(self, mock_get_client):
|
||||
"""Test lock not acquired when another lock exists."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.set.return_value = None # Redis returns None when nx=True fails
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with redis_lock("test_lock") as acquired:
|
||||
assert acquired is False
|
||||
|
||||
mock_client.set.assert_called_once_with("lock:test_lock", "1", nx=True, ex=3600 * 4)
|
||||
# Should NOT call delete since we didn't acquire the lock
|
||||
mock_client.delete.assert_not_called()
|
||||
|
||||
@mock.patch("utils.redis_lock.get_redis_client")
|
||||
def test_lock_released_on_exception(self, mock_get_client):
|
||||
"""Test lock is released even when exception occurs."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.set.return_value = True
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with redis_lock("test_lock") as acquired:
|
||||
assert acquired is True
|
||||
raise ValueError("Test error")
|
||||
|
||||
# Lock should still be released
|
||||
mock_client.delete.assert_called_once_with("lock:test_lock")
|
||||
|
||||
@mock.patch("utils.redis_lock.get_redis_client")
|
||||
def test_custom_timeout(self, mock_get_client):
|
||||
"""Test lock with custom timeout."""
|
||||
mock_client = mock.MagicMock()
|
||||
mock_client.set.return_value = True
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with redis_lock("test_lock", timeout=300) as acquired:
|
||||
assert acquired is True
|
||||
|
||||
mock_client.set.assert_called_once_with("lock:test_lock", "1", nx=True, ex=300)
|
||||
|
||||
@mock.patch("utils.redis_lock.redis")
|
||||
def test_get_redis_client_uses_broker_url(self, mock_redis):
|
||||
"""Test Redis client is created from CELERY_BROKER_URL."""
|
||||
with mock.patch.dict("os.environ", {"CELERY_BROKER_URL": "redis://testhost:1234/5"}):
|
||||
get_redis_client()
|
||||
|
||||
mock_redis.from_url.assert_called_once_with(
|
||||
"redis://testhost:1234/5", decode_responses=True
|
||||
)
|
||||
381
tests/unit/test_repository.py
Normal file
381
tests/unit/test_repository.py
Normal file
|
|
@ -0,0 +1,381 @@
|
|||
"""Unit tests for ListingRepository."""
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from models.listing import (
|
||||
FurnishType,
|
||||
ListingType,
|
||||
QueryParameters,
|
||||
RentListing,
|
||||
)
|
||||
from repositories.listing_repository import ListingRepository
|
||||
|
||||
|
||||
class TestListingRepository:
|
||||
"""Tests for ListingRepository methods."""
|
||||
|
||||
async def test_get_listings_empty_db(
|
||||
self, listing_repository: ListingRepository
|
||||
) -> None:
|
||||
"""Test that get_listings returns empty list for empty database."""
|
||||
listings = await listing_repository.get_listings()
|
||||
assert listings == []
|
||||
|
||||
async def test_get_listings_returns_inserted_listings(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listing: RentListing,
|
||||
) -> None:
|
||||
"""Test that get_listings returns listings that were inserted."""
|
||||
await listing_repository.upsert_listings([sample_rent_listing])
|
||||
listings = await listing_repository.get_listings()
|
||||
assert len(listings) == 1
|
||||
assert listings[0].id == sample_rent_listing.id
|
||||
|
||||
async def test_upsert_listings_creates_new(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listing: RentListing,
|
||||
) -> None:
|
||||
"""Test that upsert_listings creates new listings."""
|
||||
result = await listing_repository.upsert_listings([sample_rent_listing])
|
||||
assert len(result) == 1
|
||||
assert result[0].id == sample_rent_listing.id
|
||||
|
||||
# Verify it's in the database
|
||||
listings = await listing_repository.get_listings()
|
||||
assert len(listings) == 1
|
||||
|
||||
async def test_upsert_listings_updates_existing(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listing: RentListing,
|
||||
) -> None:
|
||||
"""Test that upsert_listings updates existing listings."""
|
||||
# Insert initial listing
|
||||
await listing_repository.upsert_listings([sample_rent_listing])
|
||||
|
||||
# Update the listing
|
||||
sample_rent_listing.price = 3000.0
|
||||
await listing_repository.upsert_listings([sample_rent_listing])
|
||||
|
||||
# Verify update
|
||||
listings = await listing_repository.get_listings()
|
||||
assert len(listings) == 1
|
||||
assert listings[0].price == 3000.0
|
||||
|
||||
async def test_mark_seen_updates_timestamp(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listing: RentListing,
|
||||
) -> None:
|
||||
"""Test that mark_seen updates the last_seen timestamp."""
|
||||
# Set an old timestamp
|
||||
old_time = datetime.now() - timedelta(days=7)
|
||||
sample_rent_listing.last_seen = old_time
|
||||
await listing_repository.upsert_listings([sample_rent_listing])
|
||||
|
||||
# Mark as seen
|
||||
await listing_repository.mark_seen(sample_rent_listing.id)
|
||||
|
||||
# Verify timestamp was updated
|
||||
listings = await listing_repository.get_listings()
|
||||
assert len(listings) == 1
|
||||
assert listings[0].last_seen > old_time
|
||||
|
||||
async def test_mark_seen_nonexistent_listing(
|
||||
self, listing_repository: ListingRepository
|
||||
) -> None:
|
||||
"""Test that mark_seen handles nonexistent listings gracefully."""
|
||||
# Should not raise an exception
|
||||
await listing_repository.mark_seen(999999)
|
||||
|
||||
async def test_get_listings_with_only_ids(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test that get_listings filters by only_ids."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
# Request only specific IDs
|
||||
listings = await listing_repository.get_listings(only_ids=[1, 3])
|
||||
assert len(listings) == 2
|
||||
listing_ids = [l.id for l in listings]
|
||||
assert 1 in listing_ids
|
||||
assert 3 in listing_ids
|
||||
assert 2 not in listing_ids
|
||||
|
||||
async def test_get_listings_with_limit(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test that get_listings respects limit parameter."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
listings = await listing_repository.get_listings(limit=2)
|
||||
assert len(listings) == 2
|
||||
|
||||
|
||||
class TestListingRepositoryFilters:
|
||||
"""Tests for ListingRepository query parameter filtering."""
|
||||
|
||||
async def test_filter_by_bedrooms(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test filtering by bedroom count."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=2,
|
||||
)
|
||||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
assert len(listings) == 1
|
||||
assert listings[0].number_of_bedrooms == 2
|
||||
|
||||
async def test_filter_by_price_range(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test filtering by price range."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_price=1800,
|
||||
max_price=2500,
|
||||
)
|
||||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
assert len(listings) == 1
|
||||
assert listings[0].price == 2000.0
|
||||
|
||||
async def test_filter_by_min_sqm(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test filtering by minimum square meters."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_sqm=60,
|
||||
)
|
||||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
assert len(listings) == 1
|
||||
assert listings[0].square_meters == 80.0
|
||||
|
||||
async def test_filter_by_furnish_type(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test filtering by furnish type."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
furnish_types=[FurnishType.UNFURNISHED],
|
||||
)
|
||||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
assert len(listings) == 1
|
||||
assert listings[0].furnish_type == FurnishType.UNFURNISHED
|
||||
|
||||
async def test_filter_by_last_seen_days(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test filtering by last_seen_days."""
|
||||
# Make one listing old
|
||||
sample_rent_listings[0].last_seen = datetime.now() - timedelta(days=30)
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
last_seen_days=7,
|
||||
)
|
||||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
# Only 2 should be recent enough
|
||||
assert len(listings) == 2
|
||||
|
||||
async def test_combined_filters(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test combining multiple filters."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
min_price=1000,
|
||||
max_price=2500,
|
||||
furnish_types=[FurnishType.FURNISHED, FurnishType.UNFURNISHED],
|
||||
)
|
||||
listings = await listing_repository.get_listings(query_parameters=query_params)
|
||||
# Should match listings with 1-2 bedrooms in price range
|
||||
assert len(listings) == 2
|
||||
|
||||
|
||||
class TestListingRepositoryStreaming:
|
||||
"""Tests for streaming and optimized query methods."""
|
||||
|
||||
async def test_count_listings_empty_db(
|
||||
self, listing_repository: ListingRepository
|
||||
) -> None:
|
||||
"""Test count returns 0 for empty database."""
|
||||
count = listing_repository.count_listings()
|
||||
assert count == 0
|
||||
|
||||
async def test_count_listings_with_data(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test count returns correct number."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
count = listing_repository.count_listings()
|
||||
assert count == 3
|
||||
|
||||
async def test_count_listings_with_filters(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test count respects query parameters."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
query_params = QueryParameters(
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=3,
|
||||
)
|
||||
count = listing_repository.count_listings(query_parameters=query_params)
|
||||
assert count == 2
|
||||
|
||||
async def test_stream_listings_optimized_returns_dicts(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test optimized streaming returns dict rows."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
results = list(listing_repository.stream_listings_optimized())
|
||||
assert len(results) == 3
|
||||
# Each result should be a dict
|
||||
for row in results:
|
||||
assert isinstance(row, dict)
|
||||
assert "id" in row
|
||||
assert "price" in row
|
||||
assert "number_of_bedrooms" in row
|
||||
|
||||
async def test_stream_listings_optimized_respects_limit(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test streaming limit parameter."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
results = list(listing_repository.stream_listings_optimized(limit=2))
|
||||
assert len(results) == 2
|
||||
|
||||
async def test_get_listing_ids(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
sample_rent_listings: list[RentListing],
|
||||
) -> None:
|
||||
"""Test get_listing_ids returns set of IDs."""
|
||||
await listing_repository.upsert_listings(sample_rent_listings)
|
||||
|
||||
ids = listing_repository.get_listing_ids()
|
||||
assert isinstance(ids, set)
|
||||
assert ids == {1, 2, 3}
|
||||
|
||||
async def test_get_listing_ids_empty_db(
|
||||
self,
|
||||
listing_repository: ListingRepository,
|
||||
) -> None:
|
||||
"""Test get_listing_ids returns empty set for empty database."""
|
||||
ids = listing_repository.get_listing_ids()
|
||||
assert isinstance(ids, set)
|
||||
assert len(ids) == 0
|
||||
|
||||
|
||||
class TestFurnishTypeParsing:
|
||||
"""Tests for _parse_furnish_type helper."""
|
||||
|
||||
def test_parse_furnish_type_none_detailobject(self) -> None:
|
||||
"""Test that None detailobject returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = None
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_missing_property_key(self) -> None:
|
||||
"""Test that missing 'property' key returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_missing_let_furnish_type(self) -> None:
|
||||
"""Test that missing 'letFurnishType' key returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_null_value(self) -> None:
|
||||
"""Test that null letFurnishType value returns UNKNOWN."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": None}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNKNOWN
|
||||
|
||||
def test_parse_furnish_type_furnished(self) -> None:
|
||||
"""Test that 'Furnished' is parsed correctly."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Furnished"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.FURNISHED
|
||||
|
||||
def test_parse_furnish_type_unfurnished(self) -> None:
|
||||
"""Test that 'Unfurnished' is parsed correctly."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Unfurnished"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.UNFURNISHED
|
||||
|
||||
def test_parse_furnish_type_part_furnished(self) -> None:
|
||||
"""Test that 'Part Furnished' is parsed correctly."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Part Furnished"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.PART_FURNISHED
|
||||
|
||||
def test_parse_furnish_type_landlord_variant(self) -> None:
|
||||
"""Test that landlord variants map to ASK_LANDLORD."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "Ask Landlord Please"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.ASK_LANDLORD
|
||||
|
||||
def test_parse_furnish_type_landlord_case_insensitive(self) -> None:
|
||||
"""Test that landlord check is case-insensitive."""
|
||||
listing = MagicMock()
|
||||
listing.detailobject = {"property": {"letFurnishType": "LANDLORD decides"}}
|
||||
result = ListingRepository._parse_furnish_type(listing)
|
||||
assert result == FurnishType.ASK_LANDLORD
|
||||
10
tests/unit/test_route_calculator.py
Normal file
10
tests/unit/test_route_calculator.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Unit tests for services/route_calculator.py."""
|
||||
from services.route_calculator import _parse_duration
|
||||
|
||||
|
||||
class TestParseDuration:
|
||||
def test_parse_normal_duration(self) -> None:
|
||||
assert _parse_duration("123s") == 123
|
||||
|
||||
def test_parse_zero_duration(self) -> None:
|
||||
assert _parse_duration("0s") == 0
|
||||
72
tests/unit/test_route_serializer.py
Normal file
72
tests/unit/test_route_serializer.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Unit tests for rec/route_serializer.py."""
|
||||
from models.listing import DestinationMode, Route, RouteLegStep
|
||||
from rec.route_serializer import RouteSerializer
|
||||
from rec.routing import TravelMode
|
||||
|
||||
|
||||
def _make_sample_routing_info() -> dict[DestinationMode, list[Route]]:
|
||||
destination_mode = DestinationMode(
|
||||
destination_address="London Bridge",
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
)
|
||||
routes = [
|
||||
Route(
|
||||
legs=[
|
||||
RouteLegStep(
|
||||
distance_meters=500,
|
||||
duration_s=120,
|
||||
travel_mode=TravelMode.WALK,
|
||||
),
|
||||
RouteLegStep(
|
||||
distance_meters=4000,
|
||||
duration_s=480,
|
||||
travel_mode=TravelMode.TRANSIT,
|
||||
),
|
||||
],
|
||||
distance_meters=4500,
|
||||
duration_s=600,
|
||||
)
|
||||
]
|
||||
return {destination_mode: routes}
|
||||
|
||||
|
||||
class TestRouteSerializer:
|
||||
def test_serialize_then_deserialize_roundtrip(self) -> None:
|
||||
routing_info = _make_sample_routing_info()
|
||||
serialized = RouteSerializer.serialize(routing_info)
|
||||
deserialized = RouteSerializer.deserialize(serialized)
|
||||
|
||||
assert len(deserialized) == 1
|
||||
dest_mode = list(deserialized.keys())[0]
|
||||
assert dest_mode.destination_address == "London Bridge"
|
||||
assert dest_mode.travel_mode == TravelMode.TRANSIT
|
||||
|
||||
routes = deserialized[dest_mode]
|
||||
assert len(routes) == 1
|
||||
assert routes[0].distance_meters == 4500
|
||||
assert routes[0].duration_s == 600
|
||||
assert len(routes[0].legs) == 2
|
||||
assert routes[0].legs[0].distance_meters == 500
|
||||
assert routes[0].legs[0].travel_mode == TravelMode.WALK
|
||||
assert routes[0].legs[1].travel_mode == TravelMode.TRANSIT
|
||||
|
||||
def test_deserialize_sample_json(self) -> None:
|
||||
import json
|
||||
import dataclasses
|
||||
|
||||
routing_info = _make_sample_routing_info()
|
||||
# Build the JSON manually to test deserialize independently
|
||||
json_str = json.dumps(
|
||||
{
|
||||
json.dumps(dataclasses.asdict(dm)): [
|
||||
json.dumps(dataclasses.asdict(r)) for r in routes
|
||||
]
|
||||
for dm, routes in routing_info.items()
|
||||
}
|
||||
)
|
||||
|
||||
result = RouteSerializer.deserialize(json_str)
|
||||
assert len(result) == 1
|
||||
dest_mode = list(result.keys())[0]
|
||||
assert dest_mode.destination_address == "London Bridge"
|
||||
assert result[dest_mode][0].duration_s == 600
|
||||
67
tests/unit/test_routing.py
Normal file
67
tests/unit/test_routing.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""Unit tests for rec/routing.py."""
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from rec.routing import TravelMode, transit_route, ROUTES_API_URL, ROUTES_FIELD_MASK
|
||||
from rec.exceptions import RoutingApiError
|
||||
|
||||
|
||||
class TestTravelMode:
|
||||
def test_enum_values(self) -> None:
|
||||
assert TravelMode.TRANSIT == "TRANSIT"
|
||||
assert TravelMode.BICYCLE == "BICYCLE"
|
||||
assert TravelMode.WALK == "WALK"
|
||||
assert TravelMode.DRIVE == "DRIVE"
|
||||
|
||||
def test_enum_has_four_members(self) -> None:
|
||||
assert len(TravelMode) == 4
|
||||
|
||||
|
||||
class TestTransitRoute:
|
||||
@patch("rec.routing.requests.post")
|
||||
@patch("rec.routing.nextMonday")
|
||||
def test_success_response(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
|
||||
mock_monday.return_value = MagicMock(
|
||||
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
|
||||
)
|
||||
expected = {"routes": [{"duration": "600s", "distanceMeters": 5000}]}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = expected
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with patch.dict(os.environ, {"ROUTING_API_KEY": "test-key"}):
|
||||
result = transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
|
||||
|
||||
assert result == expected
|
||||
mock_post.assert_called_once()
|
||||
call_kwargs = mock_post.call_args
|
||||
assert call_kwargs.kwargs["headers"]["X-Goog-Api-Key"] == "test-key"
|
||||
|
||||
@patch("rec.routing.requests.post")
|
||||
@patch("rec.routing.nextMonday")
|
||||
def test_raises_routing_api_error_on_non_200(self, mock_monday: MagicMock, mock_post: MagicMock) -> None:
|
||||
mock_monday.return_value = MagicMock(
|
||||
strftime=MagicMock(return_value="2024-01-08T09:00:00.000000Z")
|
||||
)
|
||||
error_body = {"error": {"message": "Invalid API key", "status": "PERMISSION_DENIED"}}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 403
|
||||
mock_response.json.return_value = error_body
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with patch.dict(os.environ, {"ROUTING_API_KEY": "bad-key"}):
|
||||
with pytest.raises(RoutingApiError) as exc_info:
|
||||
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.response_body == error_body
|
||||
|
||||
def test_raises_key_error_when_api_key_not_set(self) -> None:
|
||||
env = os.environ.copy()
|
||||
env.pop("ROUTING_API_KEY", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
with pytest.raises(KeyError):
|
||||
transit_route(51.5, -0.1, "London Bridge", TravelMode.TRANSIT)
|
||||
293
tests/unit/test_schedule_config.py
Normal file
293
tests/unit/test_schedule_config.py
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
"""Unit tests for schedule configuration."""
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from config.schedule_config import ScheduleConfig, SchedulesConfig
|
||||
from models.listing import FurnishType, ListingType
|
||||
|
||||
|
||||
class TestScheduleConfig:
|
||||
"""Tests for ScheduleConfig model."""
|
||||
|
||||
def test_basic_creation_with_defaults(self):
|
||||
"""Test creating a schedule with minimal required fields."""
|
||||
schedule = ScheduleConfig(name="Test Schedule", listing_type=ListingType.RENT)
|
||||
|
||||
assert schedule.name == "Test Schedule"
|
||||
assert schedule.enabled is True
|
||||
assert schedule.minute == "0"
|
||||
assert schedule.hour == "2"
|
||||
assert schedule.day_of_week == "*"
|
||||
assert schedule.listing_type == ListingType.RENT
|
||||
assert schedule.min_bedrooms == 1
|
||||
assert schedule.max_bedrooms == 999
|
||||
assert schedule.min_price == 0
|
||||
assert schedule.max_price == 10_000_000
|
||||
assert schedule.district_names == []
|
||||
assert schedule.furnish_types is None
|
||||
|
||||
def test_full_creation(self):
|
||||
"""Test creating a schedule with all fields specified."""
|
||||
schedule = ScheduleConfig(
|
||||
name="Full Schedule",
|
||||
enabled=False,
|
||||
minute="30",
|
||||
hour="4",
|
||||
day_of_week="1,3,5",
|
||||
listing_type=ListingType.BUY,
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=3,
|
||||
min_price=400000,
|
||||
max_price=800000,
|
||||
district_names=["Westminster", "Camden"],
|
||||
furnish_types=["furnished", "unfurnished"],
|
||||
)
|
||||
|
||||
assert schedule.name == "Full Schedule"
|
||||
assert schedule.enabled is False
|
||||
assert schedule.minute == "30"
|
||||
assert schedule.hour == "4"
|
||||
assert schedule.day_of_week == "1,3,5"
|
||||
assert schedule.listing_type == ListingType.BUY
|
||||
assert schedule.min_bedrooms == 2
|
||||
assert schedule.max_bedrooms == 3
|
||||
assert schedule.min_price == 400000
|
||||
assert schedule.max_price == 800000
|
||||
assert schedule.district_names == ["Westminster", "Camden"]
|
||||
assert schedule.furnish_types == ["furnished", "unfurnished"]
|
||||
|
||||
def test_to_query_parameters(self):
|
||||
"""Test conversion to QueryParameters."""
|
||||
schedule = ScheduleConfig(
|
||||
name="Test",
|
||||
listing_type=ListingType.RENT,
|
||||
min_bedrooms=2,
|
||||
max_bedrooms=3,
|
||||
min_price=2000,
|
||||
max_price=4000,
|
||||
district_names=["Westminster"],
|
||||
furnish_types=["furnished"],
|
||||
)
|
||||
|
||||
params = schedule.to_query_parameters()
|
||||
|
||||
assert params.listing_type == ListingType.RENT
|
||||
assert params.min_bedrooms == 2
|
||||
assert params.max_bedrooms == 3
|
||||
assert params.min_price == 2000
|
||||
assert params.max_price == 4000
|
||||
assert params.district_names == {"Westminster"}
|
||||
assert params.furnish_types == [FurnishType.FURNISHED]
|
||||
|
||||
def test_to_query_parameters_no_furnish_types(self):
|
||||
"""Test conversion when furnish_types is None."""
|
||||
schedule = ScheduleConfig(
|
||||
name="Test",
|
||||
listing_type=ListingType.BUY,
|
||||
)
|
||||
|
||||
params = schedule.to_query_parameters()
|
||||
|
||||
assert params.furnish_types is None
|
||||
|
||||
|
||||
class TestCronValidation:
|
||||
"""Tests for cron field validation."""
|
||||
|
||||
# Valid minute values
|
||||
@pytest.mark.parametrize(
|
||||
"minute",
|
||||
[
|
||||
"0",
|
||||
"59",
|
||||
"*",
|
||||
"*/5",
|
||||
"*/15",
|
||||
"0,15,30,45",
|
||||
],
|
||||
)
|
||||
def test_valid_minute(self, minute: str):
|
||||
"""Test valid minute values are accepted."""
|
||||
schedule = ScheduleConfig(
|
||||
name="Test", listing_type=ListingType.RENT, minute=minute
|
||||
)
|
||||
assert schedule.minute == minute
|
||||
|
||||
# Invalid minute values
|
||||
@pytest.mark.parametrize(
|
||||
"minute",
|
||||
[
|
||||
"60",
|
||||
"-1",
|
||||
"abc",
|
||||
"*/0",
|
||||
],
|
||||
)
|
||||
def test_invalid_minute(self, minute: str):
|
||||
"""Test invalid minute values are rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleConfig(name="Test", listing_type=ListingType.RENT, minute=minute)
|
||||
|
||||
# Valid hour values
|
||||
@pytest.mark.parametrize(
|
||||
"hour",
|
||||
[
|
||||
"0",
|
||||
"23",
|
||||
"*",
|
||||
"*/6",
|
||||
"0,6,12,18",
|
||||
],
|
||||
)
|
||||
def test_valid_hour(self, hour: str):
|
||||
"""Test valid hour values are accepted."""
|
||||
schedule = ScheduleConfig(
|
||||
name="Test", listing_type=ListingType.RENT, hour=hour
|
||||
)
|
||||
assert schedule.hour == hour
|
||||
|
||||
# Invalid hour values
|
||||
@pytest.mark.parametrize(
|
||||
"hour",
|
||||
[
|
||||
"24",
|
||||
"-1",
|
||||
"abc",
|
||||
"*/0",
|
||||
],
|
||||
)
|
||||
def test_invalid_hour(self, hour: str):
|
||||
"""Test invalid hour values are rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleConfig(name="Test", listing_type=ListingType.RENT, hour=hour)
|
||||
|
||||
# Valid day_of_week values
|
||||
@pytest.mark.parametrize(
|
||||
"day_of_week",
|
||||
[
|
||||
"0",
|
||||
"6",
|
||||
"*",
|
||||
"1,3,5",
|
||||
"*/2",
|
||||
],
|
||||
)
|
||||
def test_valid_day_of_week(self, day_of_week: str):
|
||||
"""Test valid day_of_week values are accepted."""
|
||||
schedule = ScheduleConfig(
|
||||
name="Test", listing_type=ListingType.RENT, day_of_week=day_of_week
|
||||
)
|
||||
assert schedule.day_of_week == day_of_week
|
||||
|
||||
# Invalid day_of_week values
|
||||
@pytest.mark.parametrize(
|
||||
"day_of_week",
|
||||
[
|
||||
"7",
|
||||
"-1",
|
||||
"abc",
|
||||
"*/0",
|
||||
],
|
||||
)
|
||||
def test_invalid_day_of_week(self, day_of_week: str):
|
||||
"""Test invalid day_of_week values are rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
ScheduleConfig(
|
||||
name="Test", listing_type=ListingType.RENT, day_of_week=day_of_week
|
||||
)
|
||||
|
||||
|
||||
class TestSchedulesConfig:
|
||||
"""Tests for SchedulesConfig container."""
|
||||
|
||||
def test_from_env_empty(self):
|
||||
"""Test loading from empty environment variable."""
|
||||
with mock.patch.dict(os.environ, {"SCRAPE_SCHEDULES": ""}, clear=False):
|
||||
config = SchedulesConfig.from_env()
|
||||
assert config.schedules == []
|
||||
|
||||
def test_from_env_missing(self):
|
||||
"""Test loading when environment variable is not set."""
|
||||
with mock.patch.dict(os.environ, {}, clear=True):
|
||||
# Ensure SCRAPE_SCHEDULES is not set
|
||||
os.environ.pop("SCRAPE_SCHEDULES", None)
|
||||
config = SchedulesConfig.from_env()
|
||||
assert config.schedules == []
|
||||
|
||||
def test_from_env_valid_single(self):
|
||||
"""Test loading a single valid schedule."""
|
||||
json_config = '[{"name":"Daily RENT","listing_type":"RENT","hour":"2"}]'
|
||||
with mock.patch.dict(os.environ, {"SCRAPE_SCHEDULES": json_config}):
|
||||
config = SchedulesConfig.from_env()
|
||||
|
||||
assert len(config.schedules) == 1
|
||||
assert config.schedules[0].name == "Daily RENT"
|
||||
assert config.schedules[0].listing_type == ListingType.RENT
|
||||
assert config.schedules[0].hour == "2"
|
||||
|
||||
def test_from_env_valid_multiple(self):
|
||||
"""Test loading multiple valid schedules."""
|
||||
json_config = """[
|
||||
{"name":"Daily RENT","listing_type":"RENT","hour":"2"},
|
||||
{"name":"Daily BUY","listing_type":"BUY","hour":"4","enabled":false}
|
||||
]"""
|
||||
with mock.patch.dict(os.environ, {"SCRAPE_SCHEDULES": json_config}):
|
||||
config = SchedulesConfig.from_env()
|
||||
|
||||
assert len(config.schedules) == 2
|
||||
assert config.schedules[0].name == "Daily RENT"
|
||||
assert config.schedules[0].enabled is True
|
||||
assert config.schedules[1].name == "Daily BUY"
|
||||
assert config.schedules[1].enabled is False
|
||||
|
||||
def test_from_env_invalid_json(self):
|
||||
"""Test error on invalid JSON."""
|
||||
with mock.patch.dict(os.environ, {"SCRAPE_SCHEDULES": "not json"}):
|
||||
with pytest.raises(ValueError, match="Invalid JSON"):
|
||||
SchedulesConfig.from_env()
|
||||
|
||||
def test_from_env_not_array(self):
|
||||
"""Test error when JSON is not an array."""
|
||||
with mock.patch.dict(os.environ, {"SCRAPE_SCHEDULES": '{"name":"test"}'}):
|
||||
with pytest.raises(ValueError, match="must be a JSON array"):
|
||||
SchedulesConfig.from_env()
|
||||
|
||||
def test_from_env_invalid_schedule(self):
|
||||
"""Test error when schedule validation fails."""
|
||||
# Missing required listing_type
|
||||
json_config = '[{"name":"Invalid"}]'
|
||||
with mock.patch.dict(os.environ, {"SCRAPE_SCHEDULES": json_config}):
|
||||
with pytest.raises(ValidationError):
|
||||
SchedulesConfig.from_env()
|
||||
|
||||
def test_get_enabled_schedules(self):
|
||||
"""Test filtering to only enabled schedules."""
|
||||
config = SchedulesConfig(
|
||||
schedules=[
|
||||
ScheduleConfig(name="Enabled", listing_type=ListingType.RENT, enabled=True),
|
||||
ScheduleConfig(name="Disabled", listing_type=ListingType.BUY, enabled=False),
|
||||
ScheduleConfig(name="Also Enabled", listing_type=ListingType.RENT, enabled=True),
|
||||
]
|
||||
)
|
||||
|
||||
enabled = config.get_enabled_schedules()
|
||||
|
||||
assert len(enabled) == 2
|
||||
assert enabled[0].name == "Enabled"
|
||||
assert enabled[1].name == "Also Enabled"
|
||||
|
||||
def test_get_enabled_schedules_all_disabled(self):
|
||||
"""Test when all schedules are disabled."""
|
||||
config = SchedulesConfig(
|
||||
schedules=[
|
||||
ScheduleConfig(name="Disabled1", listing_type=ListingType.RENT, enabled=False),
|
||||
ScheduleConfig(name="Disabled2", listing_type=ListingType.BUY, enabled=False),
|
||||
]
|
||||
)
|
||||
|
||||
enabled = config.get_enabled_schedules()
|
||||
|
||||
assert len(enabled) == 0
|
||||
306
tests/unit/test_task_service.py
Normal file
306
tests/unit/test_task_service.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
"""Unit tests for services/task_service.py."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from services.task_service import (
|
||||
TaskStatus,
|
||||
_extract_progress_info,
|
||||
_extract_result,
|
||||
_make_system_user,
|
||||
get_task_status,
|
||||
)
|
||||
|
||||
|
||||
class TestMakeSystemUser:
|
||||
"""Tests for _make_system_user helper."""
|
||||
|
||||
def test_creates_user_with_email(self) -> None:
|
||||
user = _make_system_user("test@example.com")
|
||||
assert user.email == "test@example.com"
|
||||
assert user.sub == ""
|
||||
assert user.name == ""
|
||||
|
||||
def test_different_emails_create_different_users(self) -> None:
|
||||
u1 = _make_system_user("a@b.com")
|
||||
u2 = _make_system_user("c@d.com")
|
||||
assert u1.email != u2.email
|
||||
|
||||
|
||||
class TestExtractResult:
|
||||
"""Tests for _extract_result helper."""
|
||||
|
||||
def test_failed_task_returns_error(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = True
|
||||
mock_result.result = Exception("something broke")
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "something broke" in error
|
||||
|
||||
def test_failed_task_with_no_result(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = True
|
||||
mock_result.result = None
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result is None
|
||||
assert error is None
|
||||
|
||||
def test_successful_json_serializable_result(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = {"count": 42, "status": "done"}
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result == {"count": 42, "status": "done"}
|
||||
assert error is None
|
||||
|
||||
def test_non_serializable_result_falls_back_to_str(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = object() # not JSON-serializable
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert isinstance(result, str)
|
||||
assert error is None
|
||||
|
||||
def test_none_result_stays_none(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = None
|
||||
|
||||
result, error = _extract_result(mock_result)
|
||||
assert result is None
|
||||
assert error is None
|
||||
|
||||
|
||||
class TestExtractProgressInfo:
|
||||
"""Tests for _extract_progress_info helper."""
|
||||
|
||||
def test_extracts_progress_fields(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
|
||||
mock_result.status = "STARTED"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["progress"] == 0.5
|
||||
assert info["processed"] == 50
|
||||
assert info["total"] == 100
|
||||
assert info["message"] is None
|
||||
|
||||
def test_extracts_message_from_info(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {"message": "Processing page 3"}
|
||||
mock_result.status = "STARTED"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] == "Processing page 3"
|
||||
|
||||
def test_falls_back_to_reason_for_skipped(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {"reason": "Already running"}
|
||||
mock_result.status = "SKIPPED"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] == "Already running"
|
||||
|
||||
def test_custom_state_used_as_message(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {}
|
||||
mock_result.status = "Fetching listings"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] == "Fetching listings"
|
||||
|
||||
def test_standard_state_not_used_as_message(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = {}
|
||||
mock_result.status = "PENDING"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info["message"] is None
|
||||
|
||||
def test_none_info_returns_all_none(self) -> None:
|
||||
mock_result = MagicMock()
|
||||
mock_result.info = None
|
||||
mock_result.status = "PENDING"
|
||||
|
||||
info = _extract_progress_info(mock_result)
|
||||
assert info == {"progress": None, "processed": None, "total": None, "message": None}
|
||||
|
||||
|
||||
class TestGetTaskStatus:
|
||||
"""Tests for get_task_status."""
|
||||
|
||||
def test_pending_task(self) -> None:
|
||||
"""Test status for a pending task."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "PENDING"
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = None
|
||||
mock_result.info = None
|
||||
mock_result.traceback = None
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
# Patch the lazy import
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.task_id == "test-id"
|
||||
assert status.status == "PENDING"
|
||||
assert status.error is None
|
||||
|
||||
def test_failed_task(self) -> None:
|
||||
"""Test status for a failed task."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "FAILURE"
|
||||
mock_result.failed.return_value = True
|
||||
mock_result.result = Exception("something broke")
|
||||
mock_result.info = None
|
||||
mock_result.traceback = "Traceback..."
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.status == "FAILURE"
|
||||
assert status.error is not None
|
||||
assert status.traceback == "Traceback..."
|
||||
|
||||
def test_custom_state_with_progress(self) -> None:
|
||||
"""Test that custom states with progress info are extracted correctly."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "Fetching listings"
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = None
|
||||
mock_result.info = {"progress": 0.5, "processed": 50, "total": 100}
|
||||
mock_result.traceback = None
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.progress == 0.5
|
||||
assert status.processed == 50
|
||||
assert status.total == 100
|
||||
|
||||
def test_successful_task(self) -> None:
|
||||
"""Test status for a successful task."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "SUCCESS"
|
||||
mock_result.failed.return_value = False
|
||||
mock_result.result = {"listings_count": 42}
|
||||
mock_result.info = None
|
||||
mock_result.traceback = None
|
||||
|
||||
with patch("services.task_service.dump_listings_task", create=True) as mock_task:
|
||||
mock_task.AsyncResult.return_value = mock_result
|
||||
with patch.dict("sys.modules", {"tasks.listing_tasks": MagicMock(dump_listings_task=mock_task)}):
|
||||
status = get_task_status("test-id")
|
||||
assert status.status == "SUCCESS"
|
||||
assert status.result == {"listings_count": 42}
|
||||
assert status.error is None
|
||||
|
||||
|
||||
class TestGetUserTasks:
|
||||
"""Tests for get_user_tasks."""
|
||||
|
||||
def test_returns_task_list(self) -> None:
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = ["task-1", "task-2"]
|
||||
|
||||
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
|
||||
MockRedisRepo.instance.return_value = mock_redis
|
||||
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
|
||||
from services.task_service import get_user_tasks
|
||||
result = get_user_tasks("test@example.com")
|
||||
assert result == ["task-1", "task-2"]
|
||||
|
||||
def test_returns_empty_list_for_unknown_user(self) -> None:
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = []
|
||||
|
||||
with patch("services.task_service.RedisRepository", create=True) as MockRedisRepo:
|
||||
MockRedisRepo.instance.return_value = mock_redis
|
||||
with patch.dict("sys.modules", {"redis_repository": MagicMock(RedisRepository=MockRedisRepo)}):
|
||||
from services.task_service import get_user_tasks
|
||||
result = get_user_tasks("nobody@example.com")
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestCancelTask:
|
||||
"""Tests for cancel_task."""
|
||||
|
||||
def test_cancel_revokes_and_removes(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.remove_task_for_user.return_value = True
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import cancel_task
|
||||
result = cancel_task("task-123", user_email="test@example.com")
|
||||
assert result is True
|
||||
mock_celery.control.revoke.assert_called_once_with("task-123", terminate=True)
|
||||
|
||||
def test_cancel_without_user_email(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"celery_app": MagicMock(app=mock_celery)}):
|
||||
from services.task_service import cancel_task
|
||||
result = cancel_task("task-456")
|
||||
assert result is True
|
||||
mock_celery.control.revoke.assert_called_once_with("task-456", terminate=True)
|
||||
|
||||
|
||||
class TestClearAllTasks:
|
||||
"""Tests for clear_all_tasks."""
|
||||
|
||||
def test_clear_with_revoke(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = ["t1", "t2"]
|
||||
mock_redis.clear_tasks_for_user.return_value = 2
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import clear_all_tasks
|
||||
count = clear_all_tasks("test@example.com", revoke=True)
|
||||
assert count == 2
|
||||
assert mock_celery.control.revoke.call_count == 2
|
||||
|
||||
def test_clear_without_revoke(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.clear_tasks_for_user.return_value = 3
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import clear_all_tasks
|
||||
count = clear_all_tasks("test@example.com", revoke=False)
|
||||
assert count == 3
|
||||
mock_celery.control.revoke.assert_not_called()
|
||||
|
||||
def test_revoke_failure_logs_warning_and_continues(self) -> None:
|
||||
mock_celery = MagicMock()
|
||||
mock_celery.control.revoke.side_effect = Exception("connection lost")
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_tasks_for_user.return_value = ["t1"]
|
||||
mock_redis.clear_tasks_for_user.return_value = 1
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"celery_app": MagicMock(app=mock_celery),
|
||||
"redis_repository": MagicMock(RedisRepository=MagicMock(instance=MagicMock(return_value=mock_redis))),
|
||||
}):
|
||||
from services.task_service import clear_all_tasks
|
||||
# Should not raise despite revoke failure
|
||||
count = clear_all_tasks("test@example.com", revoke=True)
|
||||
assert count == 1
|
||||
334
tests/unit/test_throttle_detection.py
Normal file
334
tests/unit/test_throttle_detection.py
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
"""Unit tests for throttle detection and circuit breaker."""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
import time
|
||||
|
||||
from rec.exceptions import (
|
||||
RightmoveAPIError,
|
||||
ThrottlingError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
IPBlockedError,
|
||||
SlowResponseError,
|
||||
UnexpectedEmptyResponseError,
|
||||
InvalidResponseError,
|
||||
CircuitBreakerOpenError,
|
||||
)
|
||||
from rec.throttle_detector import (
|
||||
ThrottleMetrics,
|
||||
validate_response,
|
||||
get_throttle_metrics,
|
||||
reset_throttle_metrics,
|
||||
)
|
||||
from rec.circuit_breaker import CircuitBreaker, CircuitState
|
||||
|
||||
|
||||
class TestExceptionHierarchy:
|
||||
"""Test custom exception hierarchy."""
|
||||
|
||||
def test_rightmove_api_error_is_exception(self) -> None:
|
||||
assert issubclass(RightmoveAPIError, Exception)
|
||||
|
||||
def test_throttling_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(ThrottlingError, RightmoveAPIError)
|
||||
|
||||
def test_rate_limit_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(RateLimitError, ThrottlingError)
|
||||
|
||||
def test_service_unavailable_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(ServiceUnavailableError, ThrottlingError)
|
||||
|
||||
def test_ip_blocked_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(IPBlockedError, ThrottlingError)
|
||||
|
||||
def test_slow_response_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(SlowResponseError, ThrottlingError)
|
||||
|
||||
def test_unexpected_empty_response_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(UnexpectedEmptyResponseError, RightmoveAPIError)
|
||||
assert not issubclass(UnexpectedEmptyResponseError, ThrottlingError)
|
||||
|
||||
def test_invalid_response_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(InvalidResponseError, RightmoveAPIError)
|
||||
assert not issubclass(InvalidResponseError, ThrottlingError)
|
||||
|
||||
def test_circuit_breaker_open_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(CircuitBreakerOpenError, RightmoveAPIError)
|
||||
|
||||
def test_exception_messages(self) -> None:
|
||||
error = RateLimitError("Too many requests")
|
||||
assert str(error) == "Too many requests"
|
||||
|
||||
|
||||
class TestThrottleMetrics:
|
||||
"""Test ThrottleMetrics class."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
assert metrics.rate_limit_count == 0
|
||||
assert metrics.service_unavailable_count == 0
|
||||
assert metrics.ip_blocked_count == 0
|
||||
assert metrics.slow_response_count == 0
|
||||
assert metrics.empty_response_count == 0
|
||||
assert metrics.invalid_response_count == 0
|
||||
assert metrics.total_requests == 0
|
||||
assert metrics.total_response_time == 0.0
|
||||
|
||||
def test_record_rate_limit(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_rate_limit()
|
||||
assert metrics.rate_limit_count == 1
|
||||
metrics.record_rate_limit()
|
||||
assert metrics.rate_limit_count == 2
|
||||
|
||||
def test_record_service_unavailable(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_service_unavailable()
|
||||
assert metrics.service_unavailable_count == 1
|
||||
|
||||
def test_record_ip_blocked(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_ip_blocked()
|
||||
assert metrics.ip_blocked_count == 1
|
||||
|
||||
def test_record_slow_response(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_slow_response(15.0)
|
||||
assert metrics.slow_response_count == 1
|
||||
assert metrics.total_response_time == 15.0
|
||||
assert metrics.total_requests == 1
|
||||
|
||||
def test_record_empty_response(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_empty_response()
|
||||
assert metrics.empty_response_count == 1
|
||||
|
||||
def test_record_invalid_response(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_invalid_response()
|
||||
assert metrics.invalid_response_count == 1
|
||||
|
||||
def test_record_request(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(0.5)
|
||||
assert metrics.total_requests == 1
|
||||
assert metrics.total_response_time == 0.5
|
||||
|
||||
def test_average_response_time(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(1.0)
|
||||
metrics.record_request(2.0)
|
||||
metrics.record_request(3.0)
|
||||
assert metrics.average_response_time == 2.0
|
||||
|
||||
def test_average_response_time_zero_requests(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
assert metrics.average_response_time == 0.0
|
||||
|
||||
def test_total_throttling_events(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_rate_limit()
|
||||
metrics.record_service_unavailable()
|
||||
metrics.record_ip_blocked()
|
||||
metrics.record_slow_response(15.0)
|
||||
assert metrics.total_throttling_events == 4
|
||||
|
||||
def test_throttle_rate(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(0.5) # 1 normal request
|
||||
metrics.record_request(0.5) # 2 normal requests
|
||||
metrics.record_rate_limit()
|
||||
metrics.record_request(0.5) # 3 normal requests (rate limit doesn't count as request)
|
||||
# 1 throttling event, 3 requests = 33.33%
|
||||
assert metrics.throttle_rate == pytest.approx(33.33, rel=0.01)
|
||||
|
||||
def test_throttle_rate_zero_requests(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
assert metrics.throttle_rate == 0.0
|
||||
|
||||
def test_elapsed_time(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
time.sleep(0.1)
|
||||
assert metrics.elapsed_time >= 0.1
|
||||
|
||||
def test_summary(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(1.0)
|
||||
metrics.record_rate_limit()
|
||||
summary = metrics.summary()
|
||||
assert "Total Requests:" in summary
|
||||
assert "Rate Limit (429):" in summary
|
||||
assert "1" in summary
|
||||
|
||||
|
||||
class TestGlobalMetrics:
|
||||
"""Test global metrics accessor."""
|
||||
|
||||
def test_get_throttle_metrics_singleton(self) -> None:
|
||||
reset_throttle_metrics()
|
||||
m1 = get_throttle_metrics()
|
||||
m2 = get_throttle_metrics()
|
||||
assert m1 is m2
|
||||
|
||||
def test_reset_throttle_metrics(self) -> None:
|
||||
reset_throttle_metrics()
|
||||
metrics = get_throttle_metrics()
|
||||
metrics.record_rate_limit()
|
||||
assert metrics.rate_limit_count == 1
|
||||
reset_throttle_metrics()
|
||||
new_metrics = get_throttle_metrics()
|
||||
assert new_metrics.rate_limit_count == 0
|
||||
|
||||
|
||||
class TestValidateResponse:
|
||||
"""Test validate_response function."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
reset_throttle_metrics()
|
||||
|
||||
def create_mock_response(self, status: int) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status = status
|
||||
return response
|
||||
|
||||
def test_rate_limit_error(self) -> None:
|
||||
response = self.create_mock_response(429)
|
||||
with pytest.raises(RateLimitError):
|
||||
validate_response(response, 0.5, None, 10.0)
|
||||
assert get_throttle_metrics().rate_limit_count == 1
|
||||
|
||||
def test_service_unavailable_error(self) -> None:
|
||||
response = self.create_mock_response(503)
|
||||
with pytest.raises(ServiceUnavailableError):
|
||||
validate_response(response, 0.5, None, 10.0)
|
||||
assert get_throttle_metrics().service_unavailable_count == 1
|
||||
|
||||
def test_ip_blocked_error(self) -> None:
|
||||
response = self.create_mock_response(403)
|
||||
with pytest.raises(IPBlockedError):
|
||||
validate_response(response, 0.5, None, 10.0)
|
||||
assert get_throttle_metrics().ip_blocked_count == 1
|
||||
|
||||
def test_slow_response_error(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 0, "properties": []}
|
||||
with pytest.raises(SlowResponseError):
|
||||
validate_response(response, 15.0, body, 10.0)
|
||||
assert get_throttle_metrics().slow_response_count == 1
|
||||
|
||||
def test_slow_response_just_under_threshold(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 0, "properties": []}
|
||||
# Should not raise
|
||||
validate_response(response, 9.9, body, 10.0)
|
||||
assert get_throttle_metrics().slow_response_count == 0
|
||||
|
||||
def test_error_in_response_body(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"error": "Something went wrong"}
|
||||
with pytest.raises(InvalidResponseError):
|
||||
validate_response(response, 0.5, body, 10.0)
|
||||
assert get_throttle_metrics().invalid_response_count == 1
|
||||
|
||||
def test_generic_error_in_body(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"message": "GENERIC_ERROR occurred"}
|
||||
with pytest.raises(InvalidResponseError):
|
||||
validate_response(response, 0.5, body, 10.0)
|
||||
|
||||
def test_unexpected_empty_response(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 100, "properties": []}
|
||||
with pytest.raises(UnexpectedEmptyResponseError):
|
||||
validate_response(response, 0.5, body, 10.0, expect_data=True)
|
||||
assert get_throttle_metrics().empty_response_count == 1
|
||||
|
||||
def test_empty_response_when_not_expecting_data(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 100, "properties": []}
|
||||
# Should not raise when expect_data=False
|
||||
validate_response(response, 0.5, body, 10.0, expect_data=False)
|
||||
assert get_throttle_metrics().empty_response_count == 0
|
||||
|
||||
def test_valid_response(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {
|
||||
"totalAvailableResults": 10,
|
||||
"properties": [{"id": 1}, {"id": 2}],
|
||||
}
|
||||
validate_response(response, 0.5, body, 10.0, expect_data=True)
|
||||
assert get_throttle_metrics().total_requests == 1
|
||||
assert get_throttle_metrics().total_throttling_events == 0
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""Test CircuitBreaker class."""
|
||||
|
||||
def test_initial_state_is_closed(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.is_closed
|
||||
assert not cb.is_open
|
||||
assert not cb.is_half_open
|
||||
|
||||
def test_allows_requests_when_closed(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
# Should not raise
|
||||
cb.call()
|
||||
|
||||
def test_opens_after_threshold_failures(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
assert cb.is_closed
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
|
||||
def test_blocks_requests_when_open(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=60.0)
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
cb.call()
|
||||
|
||||
def test_success_resets_failure_count(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
assert cb.failure_count == 2
|
||||
cb.record_success()
|
||||
assert cb.failure_count == 0
|
||||
|
||||
def test_transitions_to_half_open_after_timeout(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
time.sleep(0.15)
|
||||
cb.call() # Should transition to half-open
|
||||
assert cb.is_half_open
|
||||
|
||||
def test_half_open_success_closes_circuit(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
|
||||
cb.record_failure()
|
||||
time.sleep(0.15)
|
||||
cb.call() # Transition to half-open
|
||||
assert cb.is_half_open
|
||||
cb.record_success()
|
||||
assert cb.is_closed
|
||||
|
||||
def test_half_open_failure_reopens_circuit(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
|
||||
cb.record_failure()
|
||||
time.sleep(0.15)
|
||||
cb.call() # Transition to half-open
|
||||
assert cb.is_half_open
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
|
||||
def test_reset(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=60.0)
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
cb.reset()
|
||||
assert cb.is_closed
|
||||
assert cb.failure_count == 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue