CI test-unit failed on pipeline #49 because the three TestSetupPeriodicTasks cases asserted exact call counts on `sender.add_periodic_task` and the new unconditional `daily-market-aggregator` registration bumped each by one. Fix: - `tasks/listing_tasks.py`: lifted the market-aggregator registration out of the SchedulesConfig try-block — it's now independent of the user's SCRAPE_SCHEDULES (a malformed scrape config no longer takes the aggregator down with it). - `tests/unit/test_listing_tasks.py`: updated the four cases to account for the +1 unconditional aggregator call. `test_handles_config_error_gracefully` now asserts the aggregator still registers when SchedulesConfig.from_env raises (regression coverage for the independence guarantee). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
481 lines
18 KiB
Python
481 lines
18 KiB
Python
"""Unit tests for tasks/listing_tasks.py."""
|
|
import asyncio
import json
import logging
import os
from collections import deque
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest
|
|
|
|
import tasks.listing_tasks as module
|
|
from tasks.listing_tasks import (
|
|
_update_task_state,
|
|
_PipelineState,
|
|
_process_worker,
|
|
TaskLogHandler,
|
|
SCRAPE_LOCK_NAME,
|
|
LOG_BUFFER_MAX_LINES,
|
|
NUM_WORKERS,
|
|
PHASE_SPLITTING,
|
|
PHASE_FETCHING,
|
|
PHASE_PROCESSING,
|
|
PHASE_COMPLETED,
|
|
dump_listings_task,
|
|
)
|
|
|
|
|
|
class TestUpdateTaskState:
    """Tests for _update_task_state.

    Each test swaps ``module._active_log_buffer`` via ``patch.object``, which
    restores the previous value when the ``with`` block exits -- even if an
    assertion inside it fails -- so no test can leak its buffer into another.
    """

    def test_injects_logs_from_active_buffer(self):
        task = MagicMock()
        buffer = deque(["log line 1", "log line 2"])
        with patch.object(module, "_active_log_buffer", buffer):
            _update_task_state(task, "test_state", {"key": "value"})

            task.update_state.assert_called_once()
            call_meta = task.update_state.call_args[1]["meta"]
            assert call_meta["logs"] == ["log line 1", "log line 2"]
            assert call_meta["key"] == "value"

    def test_works_when_buffer_is_none(self):
        task = MagicMock()
        with patch.object(module, "_active_log_buffer", None):
            _update_task_state(task, "some_state", {"phase": "testing"})

            task.update_state.assert_called_once_with(
                state="some_state", meta={"phase": "testing"}
            )
            # No "logs" key should be injected
            call_meta = task.update_state.call_args[1]["meta"]
            assert "logs" not in call_meta

    def test_state_string_is_passed_through(self):
        task = MagicMock()
        with patch.object(module, "_active_log_buffer", None):
            _update_task_state(task, "PROGRESS", {})

            task.update_state.assert_called_once_with(state="PROGRESS", meta={})

    def test_empty_buffer_injects_empty_list(self):
        task = MagicMock()
        with patch.object(module, "_active_log_buffer", deque()):
            _update_task_state(task, "state", {"a": 1})

            task.update_state.assert_called_once()
            call_meta = task.update_state.call_args[1]["meta"]
            # An empty (but present) buffer still injects the key as an empty
            # list, alongside the caller's own meta entries.
            assert call_meta["logs"] == []
            assert call_meta["a"] == 1
|
|
|
|
|
|
class TestTaskLogHandler:
    """Tests for the TaskLogHandler."""

    @staticmethod
    def _make_handler(buf):
        """Return a TaskLogHandler that writes bare ``%(message)s`` lines into *buf*."""
        handler = TaskLogHandler(buf)
        handler.setFormatter(logging.Formatter("%(message)s"))
        return handler

    @staticmethod
    def _make_record(msg):
        """Build an INFO-level LogRecord carrying *msg* with no args or exc_info."""
        return logging.LogRecord(
            name="test", level=logging.INFO, pathname="", lineno=0,
            msg=msg, args=(), exc_info=None,
        )

    def test_emit_appends_to_buffer(self):
        buf = deque(maxlen=10)
        handler = self._make_handler(buf)

        handler.emit(self._make_record("hello"))

        assert "hello" in buf

    def test_buffer_respects_maxlen(self):
        buf = deque(maxlen=2)
        handler = self._make_handler(buf)

        for i in range(5):
            handler.emit(self._make_record(f"msg{i}"))

        # The bounded deque keeps only the newest entries.
        assert len(buf) == 2
        assert list(buf) == ["msg3", "msg4"]
|
|
|
|
|
|
class TestDumpListingsTask:
    """Tests for dump_listings_task Celery task.

    ``update_state`` is replaced with ``patch.object`` (restored on exit)
    rather than by plain assignment, so the shared task object is not left
    holding a MagicMock for every test that runs afterwards.
    """

    @staticmethod
    def _lock_cm(acquired):
        """Return a context-manager mock whose ``__enter__`` returns *acquired*."""
        cm = MagicMock()
        cm.__enter__ = MagicMock(return_value=acquired)
        cm.__exit__ = MagicMock(return_value=False)
        return cm

    @patch("tasks.listing_tasks.redis_lock")
    def test_skips_when_lock_not_acquired(self, mock_redis_lock):
        """Task should skip when another scrape is running."""
        mock_redis_lock.return_value = self._lock_cm(False)

        # Use run() which handles bind=True properly
        with patch.object(dump_listings_task, "update_state"):
            result = dump_listings_task.run('{"listing_type": "RENT"}')

        assert result["status"] == "skipped"
        assert result["reason"] == "another_job_running"
        mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)

    @patch("tasks.listing_tasks.asyncio.run")
    @patch("tasks.listing_tasks.redis_lock")
    def test_calls_dump_listings_full_when_lock_acquired(
        self, mock_redis_lock, mock_asyncio_run
    ):
        """Task should call dump_listings_full when lock is acquired."""
        mock_redis_lock.return_value = self._lock_cm(True)
        # asyncio.run is patched to return an empty list in place of the
        # real scrape result.
        mock_asyncio_run.return_value = []

        params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
        with patch.object(dump_listings_task, "update_state"):
            result = dump_listings_task.run(params_json)

        assert result["phase"] == "completed"
        assert result["progress"] == 1
        mock_asyncio_run.assert_called_once()
        mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME)
|
|
|
|
|
|
class TestSetupPeriodicTasks:
    """Tests for setup_periodic_tasks."""

    # NOTE: every call to setup_periodic_tasks also registers the unconditional
    # `daily-market-aggregator` task (one extra call per invocation), so
    # call_count assertions below account for that +1.

    AGGREGATOR_NAME = "daily-market-aggregator"

    @staticmethod
    def _registered_names(sender):
        """Return the ``name=`` kwarg of every add_periodic_task call, in call order."""
        return [c.kwargs["name"] for c in sender.add_periodic_task.call_args_list]

    @patch("tasks.listing_tasks.SchedulesConfig.from_env")
    def test_registers_enabled_schedules(self, mock_from_env):
        from config.schedule_config import ScheduleConfig
        from models.listing import ListingType

        schedule = ScheduleConfig(
            name="Test Schedule",
            listing_type=ListingType.RENT,
            hour="3",
            minute="30",
        )
        mock_config = MagicMock()
        mock_config.get_enabled_schedules.return_value = [schedule]
        mock_from_env.return_value = mock_config

        sender = MagicMock()
        module.setup_periodic_tasks(sender)

        # 1 schedule + 1 market aggregator.
        assert sender.add_periodic_task.call_count == 2
        names = self._registered_names(sender)
        assert "Test Schedule" in names
        assert self.AGGREGATOR_NAME in names

    @patch("tasks.listing_tasks.SchedulesConfig.from_env")
    def test_handles_config_error_gracefully(self, mock_from_env):
        """A malformed SCRAPE_SCHEDULES must not block the market aggregator."""
        mock_from_env.side_effect = ValueError("bad config")

        sender = MagicMock()
        module.setup_periodic_tasks(sender)

        # Aggregator still registers (the two systems are independent).
        assert sender.add_periodic_task.call_count == 1
        assert self._registered_names(sender) == [self.AGGREGATOR_NAME]

    @patch("tasks.listing_tasks.SchedulesConfig.from_env")
    def test_registers_nothing_when_no_schedules(self, mock_from_env):
        """No user schedules are registered -- only the unconditional aggregator."""
        mock_config = MagicMock()
        mock_config.get_enabled_schedules.return_value = []
        mock_from_env.return_value = mock_config

        sender = MagicMock()
        module.setup_periodic_tasks(sender)

        # Only the market aggregator registered — no user schedules.
        assert sender.add_periodic_task.call_count == 1
        assert self._registered_names(sender) == [self.AGGREGATOR_NAME]

    @patch("tasks.listing_tasks.SchedulesConfig.from_env")
    def test_registers_multiple_schedules(self, mock_from_env):
        from config.schedule_config import ScheduleConfig
        from models.listing import ListingType

        schedules = [
            ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"),
            ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"),
        ]
        mock_config = MagicMock()
        mock_config.get_enabled_schedules.return_value = schedules
        mock_from_env.return_value = mock_config

        sender = MagicMock()
        module.setup_periodic_tasks(sender)

        # 2 schedules + 1 market aggregator.
        assert sender.add_periodic_task.call_count == 3
        # The count alone cannot catch a duplicated or missing registration,
        # so pin exactly which tasks were registered (order-independent).
        assert sorted(self._registered_names(sender)) == sorted(
            ["Rent", "Buy", self.AGGREGATOR_NAME]
        )
|
|
|
|
|
|
class TestPipelineState:
    """Behaviour of the _PipelineState dataclass: counters, flag and list field."""

    # Integer counters expected to start at zero on a fresh instance.
    COUNTER_FIELDS = (
        "ids_collected",
        "completed_subqueries",
        "total_pages_fetched",
        "processed_count",
        "failed_count",
        "details_fetched",
        "images_downloaded",
        "ocr_completed",
    )

    def test_default_initialization(self):
        fresh = _PipelineState()
        for field_name in self.COUNTER_FIELDS:
            assert getattr(fresh, field_name) == 0
        assert fresh.fetching_done is False
        assert fresh.processed_listings == []

    def test_incrementing_counters(self):
        deltas = {
            "ids_collected": 5,
            "completed_subqueries": 3,
            "total_pages_fetched": 10,
            "processed_count": 4,
            "failed_count": 1,
            "details_fetched": 4,
            "images_downloaded": 3,
            "ocr_completed": 2,
        }
        tracker = _PipelineState()

        # Equivalent to ``tracker.<field> += step`` for each counter.
        for field_name, step in deltas.items():
            setattr(tracker, field_name, getattr(tracker, field_name) + step)

        for field_name, step in deltas.items():
            assert getattr(tracker, field_name) == step

    def test_appending_to_processed_listings(self):
        tracker = _PipelineState()
        for item in ("listing_a", "listing_b"):
            tracker.processed_listings.append(item)
        assert len(tracker.processed_listings) == 2
        assert tracker.processed_listings == ["listing_a", "listing_b"]

    def test_separate_instances_have_independent_lists(self):
        first, second = _PipelineState(), _PipelineState()
        first.processed_listings.append("only_a")
        # A shared default list would make this leak into ``second``.
        assert second.processed_listings == []

    def test_fetching_done_toggle(self):
        tracker = _PipelineState()
        assert tracker.fetching_done is False
        tracker.fetching_done = True
        assert tracker.fetching_done is True
|
|
|
|
|
|
class TestPhaseConstants:
    """Pin the literal values of the exported phase constants and NUM_WORKERS."""

    def test_phase_splitting(self):
        expected = "splitting"
        assert PHASE_SPLITTING == expected

    def test_phase_fetching(self):
        expected = "fetching"
        assert PHASE_FETCHING == expected

    def test_phase_processing(self):
        expected = "processing"
        assert PHASE_PROCESSING == expected

    def test_phase_completed(self):
        expected = "completed"
        assert PHASE_COMPLETED == expected

    def test_num_workers(self):
        expected_workers = 20
        assert NUM_WORKERS == expected_workers
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Regression tests for QA-round-3 backend bugs (B4, B5, B6, B20)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestProcessWorkerExceptionHandling:
    """B6 regression: one listing raising an unhandled exception (e.g.
    PermissionError) must not stop _process_worker from draining the rest of
    the queue. Previously a single bad listing aborted the entire scrape."""

    async def test_continues_after_per_listing_exception(self):
        """A PermissionError from one listing should not stop sibling listings."""
        # Listings 1..3, then a None sentinel.
        work: asyncio.Queue[int | None] = asyncio.Queue()
        for item in (1, 2, 3, None):
            await work.put(item)

        ok_result = MagicMock()

        async def flaky_process(listing_id, on_step_complete=None):
            # Listing 2 fails; every other listing succeeds.
            if listing_id != 2:
                return ok_result
            raise PermissionError("Permission denied: data/rs/2")

        processor = MagicMock()
        processor.process_listing = AsyncMock(side_effect=flaky_process)
        tracker = _PipelineState()
        progress_reporter = MagicMock()

        await _process_worker(work, processor, tracker, progress_reporter)

        # Every queued id was attempted; two succeeded and one failed.
        assert processor.process_listing.call_count == 3
        assert tracker.processed_count == 2
        assert tracker.failed_count == 1
        assert len(tracker.processed_listings) == 2

    async def test_cancelled_error_propagates(self):
        """CancelledError must NOT be swallowed — cooperative cancellation
        relies on it propagating up through asyncio.gather()."""
        work: asyncio.Queue[int | None] = asyncio.Queue()
        await work.put(99)
        # Deliberately no sentinel: the CancelledError raised for listing 99
        # must escape before the worker would wait for more items.

        processor = MagicMock()
        processor.process_listing = AsyncMock(side_effect=asyncio.CancelledError())

        with pytest.raises(asyncio.CancelledError):
            await _process_worker(work, processor, _PipelineState(), MagicMock())
|
|
|
|
|
|
class TestDumpListingsTaskFailurePublish:
    """B5 regression: dump_listings_task must publish a terminal FAILURE
    event to the task_progress:<id> pub/sub channel when the underlying
    scrape raises an exception. Previously only the happy-path SUCCESS
    was published, leaving WebSocket subscribers stuck on the last
    progress packet."""

    @patch("tasks.listing_tasks.publish_task_progress")
    @patch("tasks.listing_tasks.asyncio.run")
    @patch("tasks.listing_tasks.redis_lock")
    def test_publishes_failure_event_on_exception(
        self, mock_redis_lock, mock_asyncio_run, mock_publish
    ):
        """When dump_listings_full raises, a FAILURE event is published."""
        mock_cm = MagicMock()
        mock_cm.__enter__ = MagicMock(return_value=True)
        mock_cm.__exit__ = MagicMock(return_value=False)
        mock_redis_lock.return_value = mock_cm

        mock_asyncio_run.side_effect = PermissionError(
            "[Errno 13] Permission denied: 'data/rs/12345'"
        )

        # Give the task a deterministic request id for the published events.
        request_patch = patch.object(
            type(dump_listings_task),
            "request",
            new=MagicMock(id="fake-task-id"),
            create=True,
        )
        # update_state is patched (restored on exit) rather than assigned, so
        # the shared task object does not keep a MagicMock after this test.
        with patch.object(dump_listings_task, "update_state"), request_patch:
            with pytest.raises(PermissionError):
                dump_listings_task.run(
                    '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}'
                )

        # Inspect publish_task_progress calls for a FAILURE event.
        failure_calls = [
            c for c in mock_publish.call_args_list
            if len(c.args) >= 2 and c.args[1] == "FAILURE"
        ]
        assert failure_calls, (
            f"Expected a FAILURE publish, got: "
            f"{[c.args[1] for c in mock_publish.call_args_list if len(c.args) >= 2]}"
        )
        # The meta payload must include an error message.
        meta = failure_calls[0].args[2]
        assert "error" in meta
        assert "Permission denied" in meta["error"]
        assert meta["exc_type"] == "PermissionError"
|
|
|
|
|
|
class TestDumpListingsTaskDecoratorConfig:
    """B20 regression: dump_listings_task must be declared with time_limit,
    soft_time_limit and acks_late so that dead tasks reap themselves even
    after a worker has picked them up."""

    def test_task_has_time_limits(self):
        # Options passed to the task decorator are read back as task attributes.
        hard_limit = dump_listings_task.time_limit
        soft_limit = dump_listings_task.soft_time_limit
        assert hard_limit == 3600
        assert soft_limit == 3500

    def test_task_acks_late(self):
        acks_late_flag = dump_listings_task.acks_late
        assert acks_late_flag is True
|
|
|
|
|
|
class TestCeleryAppKeepaliveOptions:
    """B4 regression: the broker and result-backend transport options must turn
    on TCP keepalive plus a Celery-level health check, so the Redis HAProxy in
    front of the in-cluster Sentinel doesn't reap idle connections every 30s."""

    @staticmethod
    def _transport_options(setting):
        """Return the Celery config dict named *setting*, or {} if unset/empty."""
        from celery_app import app as celery_app

        return celery_app.conf.get(setting) or {}

    @staticmethod
    def _assert_keepalive(opts):
        """Check the keepalive flag and the 25s health-check interval."""
        assert opts.get("socket_keepalive") is True
        assert opts.get("health_check_interval") == 25

    def test_broker_transport_options_present(self):
        self._assert_keepalive(self._transport_options("broker_transport_options"))

    def test_result_backend_transport_options_present(self):
        self._assert_keepalive(
            self._transport_options("result_backend_transport_options")
        )
|
|
|
|
|
|
class TestRedisClientKeepalive:
    """B4 regression: every helper that creates a Redis client must
    pass socket_keepalive=True and health_check_interval=25."""

    @patch("services.task_progress_publisher.redis")
    def test_task_progress_publisher_uses_keepalive(self, mock_redis):
        import services.task_progress_publisher as m

        # Clear the cached client so _get_redis_client() builds a new one
        # through the patched ``redis`` module. patch.object restores the
        # previously cached value on exit -- even if an assertion fails -- so
        # the mock-backed client created here never leaks into later tests.
        with patch.object(m, "_redis_client", None):
            m._get_redis_client()
            mock_redis.Redis.from_url.assert_called_once()
            kwargs = mock_redis.Redis.from_url.call_args.kwargs
            assert kwargs["socket_keepalive"] is True
            assert kwargs["health_check_interval"] == 25

    @patch("utils.redis_lock.redis")
    def test_redis_lock_uses_keepalive(self, mock_redis):
        from utils.redis_lock import get_redis_client

        get_redis_client()
        mock_redis.from_url.assert_called_once()
        kwargs = mock_redis.from_url.call_args.kwargs
        assert kwargs["socket_keepalive"] is True
        assert kwargs["health_check_interval"] == 25
|