"""Unit tests for tasks/listing_tasks.py.""" import json import os from collections import deque from unittest.mock import MagicMock, patch, AsyncMock, call import pytest import tasks.listing_tasks as module from tasks.listing_tasks import ( _update_task_state, _PipelineState, TaskLogHandler, SCRAPE_LOCK_NAME, LOG_BUFFER_MAX_LINES, NUM_WORKERS, PHASE_SPLITTING, PHASE_FETCHING, PHASE_PROCESSING, PHASE_COMPLETED, ) class TestUpdateTaskState: """Tests for _update_task_state.""" def test_injects_logs_from_active_buffer(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = deque(["log line 1", "log line 2"]) _update_task_state(task, "test_state", {"key": "value"}) task.update_state.assert_called_once() call_meta = task.update_state.call_args[1]["meta"] assert call_meta["logs"] == ["log line 1", "log line 2"] assert call_meta["key"] == "value" finally: module._active_log_buffer = original def test_works_when_buffer_is_none(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = None _update_task_state(task, "some_state", {"phase": "testing"}) task.update_state.assert_called_once_with( state="some_state", meta={"phase": "testing"} ) # No "logs" key should be injected call_meta = task.update_state.call_args[1]["meta"] assert "logs" not in call_meta finally: module._active_log_buffer = original def test_state_string_is_passed_through(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = None _update_task_state(task, "PROGRESS", {}) task.update_state.assert_called_once_with(state="PROGRESS", meta={}) finally: module._active_log_buffer = original def test_empty_buffer_injects_empty_list(self): task = MagicMock() original = module._active_log_buffer try: module._active_log_buffer = deque() _update_task_state(task, "state", {"a": 1}) call_meta = task.update_state.call_args[1]["meta"] assert call_meta["logs"] == [] finally: module._active_log_buffer = original class TestTaskLogHandler: """Tests for the TaskLogHandler.""" def test_emit_appends_to_buffer(self): buf = deque(maxlen=10) handler = TaskLogHandler(buf) handler.setFormatter( __import__("logging").Formatter("%(message)s") ) record = __import__("logging").LogRecord( name="test", level=20, pathname="", lineno=0, msg="hello", args=(), exc_info=None, ) handler.emit(record) assert "hello" in buf def test_buffer_respects_maxlen(self): buf = deque(maxlen=2) handler = TaskLogHandler(buf) handler.setFormatter( __import__("logging").Formatter("%(message)s") ) for i in range(5): record = __import__("logging").LogRecord( name="test", level=20, pathname="", lineno=0, msg=f"msg{i}", args=(), exc_info=None, ) handler.emit(record) assert len(buf) == 2 assert list(buf) == ["msg3", "msg4"] class TestDumpListingsTask: """Tests for dump_listings_task Celery task.""" @patch("tasks.listing_tasks.redis_lock") def test_skips_when_lock_not_acquired(self, mock_redis_lock): """Task should skip when another scrape is running.""" mock_cm = MagicMock() mock_cm.__enter__ = MagicMock(return_value=False) mock_cm.__exit__ = MagicMock(return_value=False) mock_redis_lock.return_value = mock_cm from tasks.listing_tasks import dump_listings_task # Use run() which handles bind=True properly task_instance = dump_listings_task task_instance.update_state = MagicMock() result = dump_listings_task.run('{"listing_type": "RENT"}') assert result["status"] == "skipped" assert result["reason"] == "another_job_running" mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME) @patch("tasks.listing_tasks.asyncio.run") @patch("tasks.listing_tasks.redis_lock") def test_calls_dump_listings_full_when_lock_acquired( self, mock_redis_lock, mock_asyncio_run ): """Task should call dump_listings_full when lock is acquired.""" mock_cm = MagicMock() mock_cm.__enter__ = MagicMock(return_value=True) mock_cm.__exit__ = MagicMock(return_value=False) mock_redis_lock.return_value = mock_cm mock_asyncio_run.return_value = [] from tasks.listing_tasks import dump_listings_task task_instance = dump_listings_task task_instance.update_state = MagicMock() params_json = '{"listing_type": "RENT", "min_price": 1000, "max_price": 5000}' result = dump_listings_task.run(params_json) assert result["phase"] == "completed" assert result["progress"] == 1 mock_asyncio_run.assert_called_once() mock_redis_lock.assert_called_once_with(SCRAPE_LOCK_NAME) class TestSetupPeriodicTasks: """Tests for setup_periodic_tasks.""" @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_registers_enabled_schedules(self, mock_from_env): from config.schedule_config import ScheduleConfig from models.listing import ListingType schedule = ScheduleConfig( name="Test Schedule", listing_type=ListingType.RENT, hour="3", minute="30", ) mock_config = MagicMock() mock_config.get_enabled_schedules.return_value = [schedule] mock_from_env.return_value = mock_config sender = MagicMock() module.setup_periodic_tasks(sender) sender.add_periodic_task.assert_called_once() call_args = sender.add_periodic_task.call_args assert call_args[1]["name"] == "Test Schedule" @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_handles_config_error_gracefully(self, mock_from_env): mock_from_env.side_effect = ValueError("bad config") sender = MagicMock() module.setup_periodic_tasks(sender) sender.add_periodic_task.assert_not_called() @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_registers_nothing_when_no_schedules(self, mock_from_env): mock_config = MagicMock() mock_config.get_enabled_schedules.return_value = [] mock_from_env.return_value = mock_config sender = MagicMock() module.setup_periodic_tasks(sender) sender.add_periodic_task.assert_not_called() @patch("tasks.listing_tasks.SchedulesConfig.from_env") def test_registers_multiple_schedules(self, mock_from_env): from config.schedule_config import ScheduleConfig from models.listing import ListingType schedules = [ ScheduleConfig(name="Rent", listing_type=ListingType.RENT, hour="2"), ScheduleConfig(name="Buy", listing_type=ListingType.BUY, hour="4"), ] mock_config = MagicMock() mock_config.get_enabled_schedules.return_value = schedules mock_from_env.return_value = mock_config sender = MagicMock() module.setup_periodic_tasks(sender) assert sender.add_periodic_task.call_count == 2 class TestPipelineState: """Tests for _PipelineState dataclass.""" def test_default_initialization(self): state = _PipelineState() assert state.ids_collected == 0 assert state.completed_subqueries == 0 assert state.total_pages_fetched == 0 assert state.fetching_done is False assert state.processed_count == 0 assert state.failed_count == 0 assert state.details_fetched == 0 assert state.images_downloaded == 0 assert state.ocr_completed == 0 assert state.processed_listings == [] def test_incrementing_counters(self): state = _PipelineState() state.ids_collected += 5 state.completed_subqueries += 3 state.total_pages_fetched += 10 state.processed_count += 4 state.failed_count += 1 state.details_fetched += 4 state.images_downloaded += 3 state.ocr_completed += 2 assert state.ids_collected == 5 assert state.completed_subqueries == 3 assert state.total_pages_fetched == 10 assert state.processed_count == 4 assert state.failed_count == 1 assert state.details_fetched == 4 assert state.images_downloaded == 3 assert state.ocr_completed == 2 def test_appending_to_processed_listings(self): state = _PipelineState() state.processed_listings.append("listing_a") state.processed_listings.append("listing_b") assert len(state.processed_listings) == 2 assert state.processed_listings == ["listing_a", "listing_b"] def test_separate_instances_have_independent_lists(self): state_a = _PipelineState() state_b = _PipelineState() state_a.processed_listings.append("only_a") assert state_b.processed_listings == [] def test_fetching_done_toggle(self): state = _PipelineState() assert state.fetching_done is False state.fetching_done = True assert state.fetching_done is True class TestPhaseConstants: """Tests for phase constant values.""" def test_phase_splitting(self): assert PHASE_SPLITTING == "splitting" def test_phase_fetching(self): assert PHASE_FETCHING == "fetching" def test_phase_processing(self): assert PHASE_PROCESSING == "processing" def test_phase_completed(self): assert PHASE_COMPLETED == "completed" def test_num_workers(self): assert NUM_WORKERS == 20