Add throttling detection and circuit breaker for Rightmove scraper
This commit is contained in:
parent
e8293c6042
commit
f880664a98
10 changed files with 1428 additions and 86 deletions
311
crawler/tests/integration/test_throttle_integration.py
Normal file
311
crawler/tests/integration/test_throttle_integration.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""Integration tests for throttle detection and circuit breaker."""
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from config.scraper_config import ScraperConfig
|
||||
from rec.exceptions import (
|
||||
CircuitBreakerOpenError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
ThrottlingError,
|
||||
)
|
||||
from rec.query import (
|
||||
detail_query,
|
||||
listing_query,
|
||||
probe_query,
|
||||
get_circuit_breaker,
|
||||
reset_circuit_breaker,
|
||||
)
|
||||
from rec.throttle_detector import reset_throttle_metrics, get_throttle_metrics
|
||||
from rec.circuit_breaker import CircuitBreaker, CircuitState
|
||||
from models.listing import ListingType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config() -> ScraperConfig:
|
||||
"""Create a test configuration."""
|
||||
return ScraperConfig(
|
||||
max_concurrent_requests=5,
|
||||
request_delay_ms=10,
|
||||
slow_response_threshold=2.0,
|
||||
enable_circuit_breaker=True,
|
||||
circuit_breaker_failure_threshold=3,
|
||||
circuit_breaker_recovery_timeout=0.5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals() -> None:
|
||||
"""Reset global state before each test."""
|
||||
reset_throttle_metrics()
|
||||
reset_circuit_breaker()
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Mock aiohttp response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: int = 200,
|
||||
json_data: dict | None = None,
|
||||
text: str = "",
|
||||
):
|
||||
self.status = status
|
||||
self._json_data = json_data or {}
|
||||
self._text = text
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._json_data
|
||||
|
||||
async def text(self) -> str:
|
||||
return self._text
|
||||
|
||||
async def __aenter__(self) -> "MockResponse":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TestThrottlingRetryBehavior:
|
||||
"""Test retry behavior for throttling errors."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_triggers_retry(self, config: ScraperConfig) -> None:
|
||||
"""Test that 429 responses trigger retry with backoff."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
return MockResponse(status=429)
|
||||
return MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 10, "properties": []},
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = mock_get
|
||||
|
||||
# Mock district lookup
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
# The retry decorator will catch RateLimitError and retry
|
||||
# We need to patch the tenacity wait to speed up the test
|
||||
with patch("tenacity.wait_exponential.__call__", return_value=0):
|
||||
result = await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="Test",
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result["totalAvailableResults"] == 10
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_unavailable_triggers_retry(
|
||||
self, config: ScraperConfig
|
||||
) -> None:
|
||||
"""Test that 503 responses trigger retry."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 2:
|
||||
return MockResponse(status=503)
|
||||
return MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 5, "properties": []},
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = mock_get
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
with patch("tenacity.wait_exponential.__call__", return_value=0):
|
||||
result = await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="Test",
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
class TestCircuitBreakerIntegration:
|
||||
"""Test circuit breaker integration with queries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_opens_after_failures(
|
||||
self, config: ScraperConfig
|
||||
) -> None:
|
||||
"""Test that circuit breaker opens after consecutive failures."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockResponse(status=429)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = mock_get
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
# After enough failures, circuit should open
|
||||
with pytest.raises((RateLimitError, CircuitBreakerOpenError)):
|
||||
with patch("tenacity.wait_exponential.__call__", return_value=0):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="Test",
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Check circuit breaker state
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
# After many failures, the circuit should be open
|
||||
assert cb.failure_count >= config.circuit_breaker_failure_threshold
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_blocks_requests_when_open(
|
||||
self, config: ScraperConfig
|
||||
) -> None:
|
||||
"""Test that open circuit breaker blocks requests immediately."""
|
||||
# Force open the circuit breaker
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is not None
|
||||
for _ in range(config.circuit_breaker_failure_threshold):
|
||||
cb.record_failure()
|
||||
|
||||
assert cb.is_open
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="Test",
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
class TestMetricsTracking:
|
||||
"""Test throttle metrics are properly tracked."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_tracked_on_rate_limit(self, config: ScraperConfig) -> None:
|
||||
"""Test that rate limit errors are tracked in metrics."""
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
return MockResponse(status=429)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = mock_get
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
with pytest.raises(RateLimitError):
|
||||
with patch("tenacity.wait_exponential.__call__", return_value=0):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="Test",
|
||||
config=config,
|
||||
)
|
||||
|
||||
metrics = get_throttle_metrics()
|
||||
assert metrics.rate_limit_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_tracked_on_success(self, config: ScraperConfig) -> None:
|
||||
"""Test that successful requests are tracked in metrics."""
|
||||
async def mock_get(*args: object, **kwargs: object) -> MockResponse:
|
||||
return MockResponse(
|
||||
status=200,
|
||||
json_data={"totalAvailableResults": 10, "properties": []},
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = mock_get
|
||||
|
||||
with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
|
||||
await probe_query(
|
||||
session=mock_session,
|
||||
channel=ListingType.RENT,
|
||||
min_bedrooms=1,
|
||||
max_bedrooms=2,
|
||||
radius=1.0,
|
||||
min_price=1000,
|
||||
max_price=2000,
|
||||
district="Test",
|
||||
config=config,
|
||||
)
|
||||
|
||||
metrics = get_throttle_metrics()
|
||||
assert metrics.total_requests == 1
|
||||
assert metrics.total_throttling_events == 0
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
"""Test configuration integration."""
|
||||
|
||||
def test_config_from_env_includes_throttle_settings(self) -> None:
|
||||
"""Test that config loads throttle settings from environment."""
|
||||
import os
|
||||
|
||||
original_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["RIGHTMOVE_SLOW_RESPONSE_THRESHOLD"] = "5.0"
|
||||
os.environ["RIGHTMOVE_ENABLE_CIRCUIT_BREAKER"] = "false"
|
||||
os.environ["RIGHTMOVE_CIRCUIT_BREAKER_FAILURES"] = "10"
|
||||
os.environ["RIGHTMOVE_CIRCUIT_BREAKER_TIMEOUT"] = "120.0"
|
||||
|
||||
config = ScraperConfig.from_env()
|
||||
|
||||
assert config.slow_response_threshold == 5.0
|
||||
assert config.enable_circuit_breaker is False
|
||||
assert config.circuit_breaker_failure_threshold == 10
|
||||
assert config.circuit_breaker_recovery_timeout == 120.0
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
|
||||
def test_circuit_breaker_disabled_returns_none(self) -> None:
|
||||
"""Test that disabled circuit breaker returns None."""
|
||||
config = ScraperConfig(
|
||||
enable_circuit_breaker=False,
|
||||
)
|
||||
reset_circuit_breaker()
|
||||
cb = get_circuit_breaker(config)
|
||||
assert cb is None
|
||||
334
crawler/tests/unit/test_throttle_detection.py
Normal file
334
crawler/tests/unit/test_throttle_detection.py
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
"""Unit tests for throttle detection and circuit breaker."""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
import time
|
||||
|
||||
from rec.exceptions import (
|
||||
RightmoveAPIError,
|
||||
ThrottlingError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
IPBlockedError,
|
||||
SlowResponseError,
|
||||
UnexpectedEmptyResponseError,
|
||||
InvalidResponseError,
|
||||
CircuitBreakerOpenError,
|
||||
)
|
||||
from rec.throttle_detector import (
|
||||
ThrottleMetrics,
|
||||
validate_response,
|
||||
get_throttle_metrics,
|
||||
reset_throttle_metrics,
|
||||
)
|
||||
from rec.circuit_breaker import CircuitBreaker, CircuitState
|
||||
|
||||
|
||||
class TestExceptionHierarchy:
|
||||
"""Test custom exception hierarchy."""
|
||||
|
||||
def test_rightmove_api_error_is_exception(self) -> None:
|
||||
assert issubclass(RightmoveAPIError, Exception)
|
||||
|
||||
def test_throttling_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(ThrottlingError, RightmoveAPIError)
|
||||
|
||||
def test_rate_limit_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(RateLimitError, ThrottlingError)
|
||||
|
||||
def test_service_unavailable_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(ServiceUnavailableError, ThrottlingError)
|
||||
|
||||
def test_ip_blocked_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(IPBlockedError, ThrottlingError)
|
||||
|
||||
def test_slow_response_error_is_throttling_error(self) -> None:
|
||||
assert issubclass(SlowResponseError, ThrottlingError)
|
||||
|
||||
def test_unexpected_empty_response_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(UnexpectedEmptyResponseError, RightmoveAPIError)
|
||||
assert not issubclass(UnexpectedEmptyResponseError, ThrottlingError)
|
||||
|
||||
def test_invalid_response_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(InvalidResponseError, RightmoveAPIError)
|
||||
assert not issubclass(InvalidResponseError, ThrottlingError)
|
||||
|
||||
def test_circuit_breaker_open_error_is_rightmove_api_error(self) -> None:
|
||||
assert issubclass(CircuitBreakerOpenError, RightmoveAPIError)
|
||||
|
||||
def test_exception_messages(self) -> None:
|
||||
error = RateLimitError("Too many requests")
|
||||
assert str(error) == "Too many requests"
|
||||
|
||||
|
||||
class TestThrottleMetrics:
|
||||
"""Test ThrottleMetrics class."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
assert metrics.rate_limit_count == 0
|
||||
assert metrics.service_unavailable_count == 0
|
||||
assert metrics.ip_blocked_count == 0
|
||||
assert metrics.slow_response_count == 0
|
||||
assert metrics.empty_response_count == 0
|
||||
assert metrics.invalid_response_count == 0
|
||||
assert metrics.total_requests == 0
|
||||
assert metrics.total_response_time == 0.0
|
||||
|
||||
def test_record_rate_limit(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_rate_limit()
|
||||
assert metrics.rate_limit_count == 1
|
||||
metrics.record_rate_limit()
|
||||
assert metrics.rate_limit_count == 2
|
||||
|
||||
def test_record_service_unavailable(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_service_unavailable()
|
||||
assert metrics.service_unavailable_count == 1
|
||||
|
||||
def test_record_ip_blocked(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_ip_blocked()
|
||||
assert metrics.ip_blocked_count == 1
|
||||
|
||||
def test_record_slow_response(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_slow_response(15.0)
|
||||
assert metrics.slow_response_count == 1
|
||||
assert metrics.total_response_time == 15.0
|
||||
assert metrics.total_requests == 1
|
||||
|
||||
def test_record_empty_response(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_empty_response()
|
||||
assert metrics.empty_response_count == 1
|
||||
|
||||
def test_record_invalid_response(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_invalid_response()
|
||||
assert metrics.invalid_response_count == 1
|
||||
|
||||
def test_record_request(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(0.5)
|
||||
assert metrics.total_requests == 1
|
||||
assert metrics.total_response_time == 0.5
|
||||
|
||||
def test_average_response_time(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(1.0)
|
||||
metrics.record_request(2.0)
|
||||
metrics.record_request(3.0)
|
||||
assert metrics.average_response_time == 2.0
|
||||
|
||||
def test_average_response_time_zero_requests(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
assert metrics.average_response_time == 0.0
|
||||
|
||||
def test_total_throttling_events(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_rate_limit()
|
||||
metrics.record_service_unavailable()
|
||||
metrics.record_ip_blocked()
|
||||
metrics.record_slow_response(15.0)
|
||||
assert metrics.total_throttling_events == 4
|
||||
|
||||
def test_throttle_rate(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(0.5) # 1 normal request
|
||||
metrics.record_request(0.5) # 2 normal requests
|
||||
metrics.record_rate_limit()
|
||||
metrics.record_request(0.5) # 3 normal requests (rate limit doesn't count as request)
|
||||
# 1 throttling event, 3 requests = 33.33%
|
||||
assert metrics.throttle_rate == pytest.approx(33.33, rel=0.01)
|
||||
|
||||
def test_throttle_rate_zero_requests(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
assert metrics.throttle_rate == 0.0
|
||||
|
||||
def test_elapsed_time(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
time.sleep(0.1)
|
||||
assert metrics.elapsed_time >= 0.1
|
||||
|
||||
def test_summary(self) -> None:
|
||||
metrics = ThrottleMetrics()
|
||||
metrics.record_request(1.0)
|
||||
metrics.record_rate_limit()
|
||||
summary = metrics.summary()
|
||||
assert "Total Requests:" in summary
|
||||
assert "Rate Limit (429):" in summary
|
||||
assert "1" in summary
|
||||
|
||||
|
||||
class TestGlobalMetrics:
|
||||
"""Test global metrics accessor."""
|
||||
|
||||
def test_get_throttle_metrics_singleton(self) -> None:
|
||||
reset_throttle_metrics()
|
||||
m1 = get_throttle_metrics()
|
||||
m2 = get_throttle_metrics()
|
||||
assert m1 is m2
|
||||
|
||||
def test_reset_throttle_metrics(self) -> None:
|
||||
reset_throttle_metrics()
|
||||
metrics = get_throttle_metrics()
|
||||
metrics.record_rate_limit()
|
||||
assert metrics.rate_limit_count == 1
|
||||
reset_throttle_metrics()
|
||||
new_metrics = get_throttle_metrics()
|
||||
assert new_metrics.rate_limit_count == 0
|
||||
|
||||
|
||||
class TestValidateResponse:
|
||||
"""Test validate_response function."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
reset_throttle_metrics()
|
||||
|
||||
def create_mock_response(self, status: int) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status = status
|
||||
return response
|
||||
|
||||
def test_rate_limit_error(self) -> None:
|
||||
response = self.create_mock_response(429)
|
||||
with pytest.raises(RateLimitError):
|
||||
validate_response(response, 0.5, None, 10.0)
|
||||
assert get_throttle_metrics().rate_limit_count == 1
|
||||
|
||||
def test_service_unavailable_error(self) -> None:
|
||||
response = self.create_mock_response(503)
|
||||
with pytest.raises(ServiceUnavailableError):
|
||||
validate_response(response, 0.5, None, 10.0)
|
||||
assert get_throttle_metrics().service_unavailable_count == 1
|
||||
|
||||
def test_ip_blocked_error(self) -> None:
|
||||
response = self.create_mock_response(403)
|
||||
with pytest.raises(IPBlockedError):
|
||||
validate_response(response, 0.5, None, 10.0)
|
||||
assert get_throttle_metrics().ip_blocked_count == 1
|
||||
|
||||
def test_slow_response_error(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 0, "properties": []}
|
||||
with pytest.raises(SlowResponseError):
|
||||
validate_response(response, 15.0, body, 10.0)
|
||||
assert get_throttle_metrics().slow_response_count == 1
|
||||
|
||||
def test_slow_response_just_under_threshold(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 0, "properties": []}
|
||||
# Should not raise
|
||||
validate_response(response, 9.9, body, 10.0)
|
||||
assert get_throttle_metrics().slow_response_count == 0
|
||||
|
||||
def test_error_in_response_body(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"error": "Something went wrong"}
|
||||
with pytest.raises(InvalidResponseError):
|
||||
validate_response(response, 0.5, body, 10.0)
|
||||
assert get_throttle_metrics().invalid_response_count == 1
|
||||
|
||||
def test_generic_error_in_body(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"message": "GENERIC_ERROR occurred"}
|
||||
with pytest.raises(InvalidResponseError):
|
||||
validate_response(response, 0.5, body, 10.0)
|
||||
|
||||
def test_unexpected_empty_response(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 100, "properties": []}
|
||||
with pytest.raises(UnexpectedEmptyResponseError):
|
||||
validate_response(response, 0.5, body, 10.0, expect_data=True)
|
||||
assert get_throttle_metrics().empty_response_count == 1
|
||||
|
||||
def test_empty_response_when_not_expecting_data(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {"totalAvailableResults": 100, "properties": []}
|
||||
# Should not raise when expect_data=False
|
||||
validate_response(response, 0.5, body, 10.0, expect_data=False)
|
||||
assert get_throttle_metrics().empty_response_count == 0
|
||||
|
||||
def test_valid_response(self) -> None:
|
||||
response = self.create_mock_response(200)
|
||||
body = {
|
||||
"totalAvailableResults": 10,
|
||||
"properties": [{"id": 1}, {"id": 2}],
|
||||
}
|
||||
validate_response(response, 0.5, body, 10.0, expect_data=True)
|
||||
assert get_throttle_metrics().total_requests == 1
|
||||
assert get_throttle_metrics().total_throttling_events == 0
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""Test CircuitBreaker class."""
|
||||
|
||||
def test_initial_state_is_closed(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.is_closed
|
||||
assert not cb.is_open
|
||||
assert not cb.is_half_open
|
||||
|
||||
def test_allows_requests_when_closed(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
# Should not raise
|
||||
cb.call()
|
||||
|
||||
def test_opens_after_threshold_failures(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
assert cb.is_closed
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
|
||||
def test_blocks_requests_when_open(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=60.0)
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
with pytest.raises(CircuitBreakerOpenError):
|
||||
cb.call()
|
||||
|
||||
def test_success_resets_failure_count(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
assert cb.failure_count == 2
|
||||
cb.record_success()
|
||||
assert cb.failure_count == 0
|
||||
|
||||
def test_transitions_to_half_open_after_timeout(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
time.sleep(0.15)
|
||||
cb.call() # Should transition to half-open
|
||||
assert cb.is_half_open
|
||||
|
||||
def test_half_open_success_closes_circuit(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
|
||||
cb.record_failure()
|
||||
time.sleep(0.15)
|
||||
cb.call() # Transition to half-open
|
||||
assert cb.is_half_open
|
||||
cb.record_success()
|
||||
assert cb.is_closed
|
||||
|
||||
def test_half_open_failure_reopens_circuit(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
|
||||
cb.record_failure()
|
||||
time.sleep(0.15)
|
||||
cb.call() # Transition to half-open
|
||||
assert cb.is_half_open
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
|
||||
def test_reset(self) -> None:
|
||||
cb = CircuitBreaker(failure_threshold=1, recovery_timeout=60.0)
|
||||
cb.record_failure()
|
||||
assert cb.is_open
|
||||
cb.reset()
|
||||
assert cb.is_closed
|
||||
assert cb.failure_count == 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue