"""Unit tests for tasks/listing_tasks.py.""" import asyncio import json import os from collections import deque from unittest.mock import MagicMock, patch, AsyncMock, call import pytest import tasks.listing_tasks as module from tasks.listing_tasks import ( _update_task_state, _PipelineState, _process_worker, TaskLogHandler, SCRAPE_LOCK_NAME, LOG_BUFFER_MAX_LINES, NUM_WORKERS, PHASE_SPLITTING, PHASE_FETCHING, PHASE_PROCESSING, PHASE_COMPLETED, dump_listings_task, ) class TestUpdateTaskState: """Tests for _update_task_state.""" def test_injects_logs_from_active_buffer(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = deque(["log line 1", "log line 2"]) _update_task_state(task, "test_state", {"key": "value"}) task.update_state.assert_called_once() call_meta = task.update_state.call_args[1]["meta"] assert call_meta["logs"] == ["log line 1", "log line 2"] assert call_meta["key"] == "value" finally: module._active_log_buffer = original def test_works_when_buffer_is_none(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = None _update_task_state(task, "some_state", {"phase": "testing"}) task.update_state.assert_called_once_with( state="some_state", meta={"phase": "testing"} ) # No "logs" key should be injected call_meta = task.update_state.call_args[1]["meta"] assert "logs" not in call_meta finally: module._active_log_buffer = original def test_state_string_is_passed_through(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = None _update_task_state(task, "PROGRESS", {}) task.update_state.assert_called_once_with(state="PROGRESS", meta={}) finally: module._active_log_buffer = original def test_empty_buffer_injects_empty_list(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = deque() _update_task_state(task, "state", {"a": 1}) call_meta = task.update_state.call_args[1]["meta"] assert call_meta["logs"] == [] finally: module._active_log_buffer = original class TestTaskLogHandler: """Tests for the TaskLogHandler.""" def test_emit_appends_to_buffer(self): buf = deque(maxlen=10) handler = TaskLogHandler(buf) handler.setFormatter( __import__("logging").Formatter("%(message)s") ) record = __import__("logging").LogRecord( name="test", level=20, pathname="", lineno=0, msg="hello", args=(), exc_info=None, ) handler.emit(record) assert "hello" in buf def test_buffer_respects_maxlen(self): buf = deque(maxlen=2) handler = TaskLogHandler(buf) handler.setFormatter( __import__("logging").Formatter("%(message)s") ) for i in range(5): record = __import__("logging").LogRecord( name="test", level=20, pathname="", lineno=0, msg=f"msg{i}", args=(), exc_info=None, ) handler.emit(record) assert len(buf) == 2 assert list(buf) == ["msg3", "msg4"] class TestDumpListingsTask: """Tests for dump_listings_task Celery task.""" @patch("tasks.listing_tasks.redis_lock") def test_skips_when_lock_not_acquired(self, mock_redis_lock): """Task should skip when another scrape is running.""" mock_cm = MagicMock() mock_cm.__enter__ = MagicMock(return_value=False) mock_cm.__exit__ = MagicMock(return_value=False) mock_redis_lock.return_value = mock_cm from tasks.listing_tasks import dump_listings_task # Use run() which handles bind=True properly task_instance = dump_listings_task task_instance.update_state = MagicMock() result = dump_listings_task.run('{"listing_type": "RENT"}') assert result["status"] == "skipped" assert result["reason"] == "another_job_running" mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME) @patch("tasks.listing_tasks.asyncio.run") @patch("tasks.listing_tasks.redis_lock") def test_calls_dump_listings_full_when_lock_acquired( self, mock_redis_lock, mock_asyncio_run ): """Task should call dump_listings_full when lock is acquired.""" mock_cm = MagicMock() mock_cm.__enter__ = MagicMock(return_value=True) mock_cm.__exit__ = MagicMock(return_value=False) mock_redis_lock.return_value = mock_cm mock_asyncio_run.return_value = [] from tasks.listing_tasks import dump_listings_task task_instance = dump_listings_task task_instance.update_state = MagicMock() params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}' result = dump_listings_task.run(params_json) assert result["phase"] == "completed" assert result["progress"] == 1 mock_asyncio_run.assert_called_once() mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME) class TestSetupPeriodicTasks: """Tests for setup_periodic_tasks.""" # NOTE: every call to setup_periodic_tasks also registers the unconditional # `daily-market-aggregator` task (one extra call per invocation), so # call_count assertions below account for that +1. @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_registers_enabled_schedules(self, mock_from_env): from config.schedule_config import ScheduleConfig from models.listing import ListingType schedule = ScheduleConfig( name="Test Schedule", listing_type=ListingType.RENT, hour="3", minute="30", ) mock_config = MagicMock() mock_config.get_enabled_schedules.return_value = [schedule] mock_from_env.return_value = mock_config sender = MagicMock() module.setup_periodic_tasks(sender) # 1 schedule + 1 market aggregator. assert sender.add_periodic_task.call_count == 2 names = [c.kwargs["name"] for c in sender.add_periodic_task.call_args_list] assert "Test Schedule" in names assert "daily-market-aggregator" in names @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_handles_config_error_gracefully(self, mock_from_env): """A malformed SCRAPE_SCHEDULES must not block the market aggregator.""" mock_from_env.side_effect = ValueError("bad config") sender = MagicMock() module.setup_periodic_tasks(sender) # Aggregator still registers (the two systems are independent). assert sender.add_periodic_task.call_count == 1 assert sender.add_periodic_task.call_args.kwargs["name"] == "daily-market-aggregator" @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_registers_nothing_when_no_schedules(self, mock_from_env): mock_config = MagicMock() mock_config.get_enabled_schedules.return_value = [] mock_from_env.return_value = mock_config sender = MagicMock() module.setup_periodic_tasks(sender) # Only the market aggregator registered — no user schedules. assert sender.add_periodic_task.call_count == 1 assert sender.add_periodic_task.call_args.kwargs["name"] == "daily-market-aggregator" @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_registers_multiple_schedules(self, mock_from_env): from config.schedule_config import ScheduleConfig from models.listing import ListingType schedules = [ ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"), ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"), ] mock_config = MagicMock() mock_config.get_enabled_schedules.return_value = schedules mock_from_env.return_value = mock_config sender = MagicMock() module.setup_periodic_tasks(sender) # 2 schedules + 1 market aggregator. assert sender.add_periodic_task.call_count == 3 class TestPipelineState: """Tests for _PipelineState dataclass.""" def test_default_initialization(self): state = _PipelineState() assert state.ids_collected == 0 assert state.completed_subqueries == 0 assert state.total_pages_fetched == 0 assert state.fetching_done is False assert state.processed_count == 0 assert state.failed_count == 0 assert state.details_fetched == 0 assert state.images_downloaded == 0 assert state.ocr_completed == 0 assert state.processed_listings == [] def test_incrementing_counters(self): state = _PipelineState() state.ids_collected += 5 state.completed_subqueries += 3 state.total_pages_fetched += 10 state.processed_count += 4 state.failed_count += 1 state.details_fetched += 4 state.images_downloaded += 3 state.ocr_completed += 2 assert state.ids_collected == 5 assert state.completed_subqueries == 3 assert state.total_pages_fetched == 10 assert state.processed_count == 4 assert state.failed_count == 1 assert state.details_fetched == 4 assert state.images_downloaded == 3 assert state.ocr_completed == 2 def test_appending_to_processed_listings(self): state = _PipelineState() state.processed_listings.append("listing_a") state.processed_listings.append("listing_b") assert len(state.processed_listings) == 2 assert state.processed_listings == ["listing_a", "listing_b"] def test_separate_instances_have_independent_lists(self): state_a = _PipelineState() state_b = _PipelineState() state_a.processed_listings.append("only_a") assert state_b.processed_listings == [] def test_fetching_done_toggle(self): state = _PipelineState() assert state.fetching_done is False state.fetching_done = True assert state.fetching_done is True class TestPhaseConstants: """Tests for phase constant values.""" def test_phase_splitting(self): assert PHASE_SPLITTING == "splitting" def test_phase_fetching(self): assert PHASE_FETCHING == "fetching" def test_phase_processing(self): assert PHASE_PROCESSING == "processing" def test_phase_completed(self): assert PHASE_COMPLETED == "completed" def test_num_workers(self): assert NUM_WORKERS == 20 # --------------------------------------------------------------------------- # Regression tests for QA-round-3 backend bugs (B5, B6, B20) # --------------------------------------------------------------------------- class TestProcessWorkerExceptionHandling: """B6 regression: _process_worker must keep draining the queue when a single listing raises an unhandled exception (e.g. PermissionError). Previously one bad listing aborted the entire scrape.""" async def test_continues_after_per_listing_exception(self): """A PermissionError from one listing should not stop sibling listings.""" # Three listings in the queue followed by a None sentinel. queue: asyncio.Queue[int | None] = asyncio.Queue() for listing_id in [1, 2, 3]: await queue.put(listing_id) await queue.put(None) # Processor: listing 1 succeeds, listing 2 raises, listing 3 succeeds. good_listing = MagicMock() async def fake_process_listing(listing_id, on_step_complete=None): if listing_id == 2: raise PermissionError("Permission denied: data/rs/2") return good_listing processor = MagicMock() processor.process_listing = AsyncMock(side_effect=fake_process_listing) state = _PipelineState() reporter = MagicMock() await _process_worker(queue, processor, state, reporter) # All three IDs were attempted (queue drained before exit). assert processor.process_listing.call_count == 3 # Two succeeded, one failed. assert state.processed_count == 2 assert state.failed_count == 1 assert len(state.processed_listings) == 2 async def test_cancelled_error_propagates(self): """CancelledError must NOT be swallowed — cooperative cancellation relies on it propagating up through asyncio.gather().""" queue: asyncio.Queue[int | None] = asyncio.Queue() await queue.put(99) # No sentinel — the worker should bubble the CancelledError before # ever getting a chance to drain further. processor = MagicMock() processor.process_listing = AsyncMock(side_effect=asyncio.CancelledError()) state = _PipelineState() reporter = MagicMock() with pytest.raises(asyncio.CancelledError): await _process_worker(queue, processor, state, reporter) class TestDumpListingsTaskFailurePublish: """B5 regression: dump_listings_task must publish a terminal FAILURE event to the task_progress: pub/sub channel when the underlying scrape raises an exception. Previously only the happy-path SUCCESS was published, leaving WebSocket subscribers stuck on the last progress packet.""" @patch("tasks.listing_tasks.publish_task_progress") @patch("tasks.listing_tasks.asyncio.run") @patch("tasks.listing_tasks.redis_lock") def test_publishes_failure_event_on_exception( self, mock_redis_lock, mock_asyncio_run, mock_publish ): """When dump_listings_full raises, a FAILURE event is published.""" mock_cm = MagicMock() mock_cm.__enter__ = MagicMock(return_value=True) mock_cm.__exit__ = MagicMock(return_value=False) mock_redis_lock.return_value = mock_cm mock_asyncio_run.side_effect = PermissionError( "[Errno 13] Permission denied: 'data/rs/12345'" ) dump_listings_task.update_state = MagicMock() # Force a deterministic task_id so we can assert on it. with patch.object( type(dump_listings_task), "request", new=MagicMock(id="fake-task-id"), create=True, ): with pytest.raises(PermissionError): dump_listings_task.run( '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}' ) # Inspect publish_task_progress calls for a FAILURE event. failure_calls = [ c for c in mock_publish.call_args_list if len(c.args) >= 2 and c.args[1] == "FAILURE" ] assert failure_calls, ( f"Expected a FAILURE publish, got: " f"{[c.args[1] for c in mock_publish.call_args_list if len(c.args) >= 2]}" ) # The meta payload must include an error message. meta = failure_calls[0].args[2] assert "error" in meta assert "Permission denied" in meta["error"] assert meta["exc_type"] == "PermissionError" class TestDumpListingsTaskDecoratorConfig: """B20 regression: dump_listings_task must have time_limit / soft_time_limit / acks_late configured so dead tasks reap themselves even after pickup.""" def test_task_has_time_limits(self): # Celery exposes these via the task attributes once decorated. assert dump_listings_task.time_limit == 3600 assert dump_listings_task.soft_time_limit == 3500 def test_task_acks_late(self): assert dump_listings_task.acks_late is True class TestCeleryAppKeepaliveOptions: """B4 regression: broker / result-backend transport options must enable TCP keepalive and a Celery-level health check so the Redis HAProxy in front of the in-cluster Sentinel doesn't reap idle connections every 30s.""" def test_broker_transport_options_present(self): from celery_app import app as celery_app opts = celery_app.conf.get("broker_transport_options") or {} assert opts.get("socket_keepalive") is True assert opts.get("health_check_interval") == 25 def test_result_backend_transport_options_present(self): from celery_app import app as celery_app opts = celery_app.conf.get("result_backend_transport_options") or {} assert opts.get("socket_keepalive") is True assert opts.get("health_check_interval") == 25 class TestRedisClientKeepalive: """B4 regression: every helper that creates a Redis client must pass socket_keepalive=True and health_check_interval=25.""" @patch("services.task_progress_publisher.redis") def test_task_progress_publisher_uses_keepalive(self, mock_redis): # Reset the cached client so the patch takes effect. import services.task_progress_publisher as m m._redis_client = None m._get_redis_client() mock_redis.Redis.from_url.assert_called_once() kwargs = mock_redis.Redis.from_url.call_args.kwargs assert kwargs["socket_keepalive"] is True assert kwargs["health_check_interval"] == 25 m._redis_client = None # leave the module clean for other tests @patch("utils.redis_lock.redis") def test_redis_lock_uses_keepalive(self, mock_redis): from utils.redis_lock import get_redis_client get_redis_client() mock_redis.from_url.assert_called_once() kwargs = mock_redis.from_url.call_args.kwargs assert kwargs["socket_keepalive"] is True assert kwargs["health_check_interval"] == 25