Add throttling detection and circuit breaker for Rightmove scraper

This commit is contained in:
Viktor Barzin 2026-02-02 22:50:19 +00:00
parent e8293c6042
commit f880664a98
10 changed files with 1428 additions and 86 deletions

View file

@ -16,6 +16,12 @@ RIGHTMOVE_MIN_PRICE_BAND=100 # Minimum price band width (won't split below
RIGHTMOVE_MAX_PAGES=60 # Max pages per subquery (60 * 25 = 1500 max results)
RIGHTMOVE_PROXY_URL= # Optional SOCKS proxy URL (e.g., socks5://localhost:9050 for Tor)
# Throttling detection and circuit breaker
RIGHTMOVE_SLOW_RESPONSE_THRESHOLD=10.0 # Response time threshold in seconds
RIGHTMOVE_ENABLE_CIRCUIT_BREAKER=true # Enable circuit breaker protection
RIGHTMOVE_CIRCUIT_BREAKER_FAILURES=5 # Consecutive failures to open circuit
RIGHTMOVE_CIRCUIT_BREAKER_TIMEOUT=60.0 # Seconds to wait before recovery attempt
# Periodic scraping schedules (JSON array)
# Each schedule has: name, enabled, hour, minute, day_of_week, listing_type, min/max_bedrooms, min/max_price, district_names, furnish_types
# Cron fields: minute (0-59), hour (0-23), day_of_week (0-6, 0=Sunday)

View file

@ -18,6 +18,10 @@ class ScraperConfig:
min_price_band: Minimum width of a price band (won't split below this).
max_pages_per_query: Maximum pages to fetch per subquery (60 * 25 = 1500).
proxy_url: Optional SOCKS proxy URL (e.g., socks5://localhost:9050 for Tor).
slow_response_threshold: Response time threshold in seconds for throttle detection.
enable_circuit_breaker: Whether to enable circuit breaker protection.
circuit_breaker_failure_threshold: Number of consecutive failures to open circuit.
circuit_breaker_recovery_timeout: Seconds to wait before testing recovery.
"""
max_concurrent_requests: int = 5
@ -27,6 +31,10 @@ class ScraperConfig:
min_price_band: int = 100 # Minimum band width in currency units
max_pages_per_query: int = 60 # 60 * 25 = 1500 results max
proxy_url: str | None = None
slow_response_threshold: float = 10.0 # seconds
enable_circuit_breaker: bool = True
circuit_breaker_failure_threshold: int = 5
circuit_breaker_recovery_timeout: float = 60.0
@classmethod
def from_env(cls) -> Self:
@ -40,6 +48,10 @@ class ScraperConfig:
RIGHTMOVE_MIN_PRICE_BAND: Minimum price band width (default: 100)
RIGHTMOVE_MAX_PAGES: Max pages per query (default: 60)
RIGHTMOVE_PROXY_URL: SOCKS proxy URL (default: None)
RIGHTMOVE_SLOW_RESPONSE_THRESHOLD: Slow response threshold in seconds (default: 10.0)
RIGHTMOVE_ENABLE_CIRCUIT_BREAKER: Enable circuit breaker (default: True)
RIGHTMOVE_CIRCUIT_BREAKER_FAILURES: Failures to open circuit (default: 5)
RIGHTMOVE_CIRCUIT_BREAKER_TIMEOUT: Recovery timeout in seconds (default: 60.0)
Returns:
ScraperConfig instance with values from environment or defaults.
@ -62,4 +74,16 @@ class ScraperConfig:
os.environ.get("RIGHTMOVE_MAX_PAGES", "60")
),
proxy_url=os.environ.get("RIGHTMOVE_PROXY_URL") or None,
slow_response_threshold=float(
os.environ.get("RIGHTMOVE_SLOW_RESPONSE_THRESHOLD", "10.0")
),
enable_circuit_breaker=os.environ.get(
"RIGHTMOVE_ENABLE_CIRCUIT_BREAKER", "true"
).lower() in ("true", "1", "yes"),
circuit_breaker_failure_threshold=int(
os.environ.get("RIGHTMOVE_CIRCUIT_BREAKER_FAILURES", "5")
),
circuit_breaker_recovery_timeout=float(
os.environ.get("RIGHTMOVE_CIRCUIT_BREAKER_TIMEOUT", "60.0")
),
)

View file

@ -0,0 +1,137 @@
"""Circuit breaker pattern for protecting against cascading failures."""
from __future__ import annotations
import enum
import logging
import time
from dataclasses import dataclass
from rec.exceptions import CircuitBreakerOpenError
# Module-level logger, registered under the "uvicorn.error" logger name;
# the circuit breaker reports its state transitions through it.
logger = logging.getLogger("uvicorn.error")
class CircuitState(enum.Enum):
    """The three states a circuit breaker can be in."""

    # Requests flow normally.
    CLOSED = "closed"
    # Too many failures were seen; requests are rejected.
    OPEN = "open"
    # Recovery is being probed after the open period elapsed.
    HALF_OPEN = "half_open"
@dataclass
class CircuitBreaker:
    """Failure-counting gate placed in front of outgoing requests.

    State machine:
    - CLOSED: ``call()`` lets requests through; failures are counted and a
      success resets the count.
    - OPEN: entered once ``failure_count`` reaches ``failure_threshold``
      (or when a failure is recorded while HALF_OPEN); ``call()`` raises
      until ``recovery_timeout`` seconds have passed since the last
      recorded failure.
    - HALF_OPEN: entered by the first ``call()`` after that wait; requests
      are let through, the next recorded success closes the circuit and
      the next recorded failure re-opens it.

    Attributes:
        failure_threshold: Consecutive failures needed to open the circuit.
        recovery_timeout: Seconds to stay OPEN before probing recovery.
        state: Current circuit state.
        failure_count: Failures recorded since the last success/reset.
        last_failure_time: ``time.time()`` of the last recorded failure.
        last_state_change: ``time.time()`` of the last state change.
    """

    failure_threshold: int
    recovery_timeout: float
    state: CircuitState = CircuitState.CLOSED
    failure_count: int = 0
    last_failure_time: float = 0.0
    last_state_change: float = 0.0

    def __post_init__(self) -> None:
        """Stamp the creation time as the first state change."""
        self.last_state_change = time.time()

    def _set_state(self, new_state: CircuitState) -> None:
        """Store ``new_state`` and refresh the state-change timestamp."""
        self.state = new_state
        self.last_state_change = time.time()

    def call(self) -> None:
        """Gate a request: return to allow it, raise to block it.

        Raises:
            CircuitBreakerOpenError: If the circuit is OPEN and the recovery
                timeout has not yet elapsed since the last failure.
        """
        now = time.time()
        if self.state != CircuitState.OPEN:
            # CLOSED and HALF_OPEN both let the request proceed.
            return

        since_last_failure = now - self.last_failure_time
        if since_last_failure >= self.recovery_timeout:
            # Waited long enough: start probing and allow this request.
            self._transition_to_half_open()
            return

        remaining = self.recovery_timeout - since_last_failure
        raise CircuitBreakerOpenError(
            f"Circuit breaker is open. "
            f"Waiting {remaining:.1f}s "
            f"before retry."
        )

    def record_success(self) -> None:
        """Note a successful request; closes the circuit when HALF_OPEN."""
        if self.is_half_open:
            self._transition_to_closed()
        # Any success ends the current run of consecutive failures.
        self.failure_count = 0

    def record_failure(self) -> None:
        """Note a failed request and open the circuit when warranted."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        threshold_hit = (
            self.is_closed and self.failure_count >= self.failure_threshold
        )
        # A failed probe (HALF_OPEN) re-opens immediately; a CLOSED circuit
        # opens only once the threshold is reached. Failures recorded while
        # already OPEN just update the counters above.
        if self.is_half_open or threshold_hit:
            self._transition_to_open()

    def _transition_to_open(self) -> None:
        """Switch to OPEN and log a warning."""
        self._set_state(CircuitState.OPEN)
        logger.warning(
            f"Circuit breaker OPENED after {self.failure_count} consecutive failures. "
            f"Will retry in {self.recovery_timeout}s"
        )

    def _transition_to_half_open(self) -> None:
        """Switch to HALF_OPEN and log it."""
        self._set_state(CircuitState.HALF_OPEN)
        logger.info("Circuit breaker entering HALF_OPEN state, testing service recovery")

    def _transition_to_closed(self) -> None:
        """Switch to CLOSED, clear the failure count and log it."""
        self._set_state(CircuitState.CLOSED)
        self.failure_count = 0
        logger.info("Circuit breaker CLOSED, service recovered")

    def reset(self) -> None:
        """Force the breaker back to a pristine CLOSED state."""
        self._set_state(CircuitState.CLOSED)
        self.failure_count = 0
        self.last_failure_time = 0.0
        logger.info("Circuit breaker manually reset to CLOSED state")

    @property
    def is_open(self) -> bool:
        """True while the circuit is OPEN."""
        return self.state == CircuitState.OPEN

    @property
    def is_closed(self) -> bool:
        """True while the circuit is CLOSED."""
        return self.state == CircuitState.CLOSED

    @property
    def is_half_open(self) -> bool:
        """True while the circuit is HALF_OPEN."""
        return self.state == CircuitState.HALF_OPEN

74
crawler/rec/exceptions.py Normal file
View file

@ -0,0 +1,74 @@
"""Custom exceptions for Rightmove API errors."""
class RightmoveAPIError(Exception):
    """Root of the exception hierarchy for Rightmove API errors."""
class ThrottlingError(RightmoveAPIError):
    """Common parent of every throttling-related error.

    Signals that Rightmove is limiting our requests and that the caller
    should back off.
    """
class RateLimitError(ThrottlingError):
    """Rightmove answered with HTTP 429 (Too Many Requests).

    The API is explicitly rate limiting our requests.
    """
class ServiceUnavailableError(ThrottlingError):
    """Rightmove answered with HTTP 503 (Service Unavailable).

    The service is temporarily unavailable, possibly because it is
    overloaded.
    """
class IPBlockedError(ThrottlingError):
    """Rightmove answered with HTTP 403 (Forbidden).

    Our IP address may have been blocked or blacklisted.
    """
class SlowResponseError(ThrottlingError):
    """A response took longer than the configured threshold.

    Very slow answers are treated as a sign of potential throttling or
    overload on the API side.
    """
class UnexpectedEmptyResponseError(RightmoveAPIError):
    """The response carried no data although data was expected."""
class InvalidResponseError(RightmoveAPIError):
    """The response body holds an error message or otherwise invalid data."""
class CircuitBreakerOpenError(RightmoveAPIError):
    """The circuit breaker is open and is rejecting requests.

    Too many failures were detected, so further requests are blocked to
    give the service time to recover.
    """

View file

@ -1,4 +1,6 @@
import enum
import logging
import time
from typing import Any
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
@ -6,9 +8,26 @@ from collections.abc import AsyncIterator
import aiohttp
from models.listing import FurnishType, ListingType
from rec import districts
from tenacity import retry, stop_after_attempt, wait_random
from rec.exceptions import (
CircuitBreakerOpenError,
ThrottlingError,
)
from rec.throttle_detector import get_throttle_metrics, validate_response
from rec.circuit_breaker import CircuitBreaker
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_random,
)
from config.scraper_config import ScraperConfig
# Module-level logger, registered under the "uvicorn.error" logger name.
logger = logging.getLogger("uvicorn.error")
# Global circuit breaker instance.
# Starts as None and is created lazily by get_circuit_breaker(); once
# created, the same instance is handed out on every later call.
_circuit_breaker: CircuitBreaker | None = None
DEFAULT_HEADERS = {
"Host": "api.rightmove.co.uk",
@ -65,20 +84,81 @@ async def create_session(
await session.close()
@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3))
def get_circuit_breaker(config: ScraperConfig | None = None) -> CircuitBreaker | None:
    """Return the process-wide circuit breaker, creating it on first use.

    The breaker is built from the configuration seen on the first call that
    needs it; later calls return that same instance regardless of the
    thresholds in the configuration they pass.

    Args:
        config: Scraper configuration. Loaded from the environment when
            omitted.

    Returns:
        The shared CircuitBreaker, or None when the configuration disables
        the circuit breaker.
    """
    global _circuit_breaker

    settings = ScraperConfig.from_env() if config is None else config
    if settings.enable_circuit_breaker:
        if _circuit_breaker is None:
            _circuit_breaker = CircuitBreaker(
                failure_threshold=settings.circuit_breaker_failure_threshold,
                recovery_timeout=settings.circuit_breaker_recovery_timeout,
            )
        return _circuit_breaker
    return None
def reset_circuit_breaker() -> None:
    """Reset the shared circuit breaker, if one has been created.

    The instance itself is kept; only its ``reset()`` method is invoked.
    """
    breaker = _circuit_breaker
    if breaker is None:
        return
    breaker.reset()
def check_circuit_breaker(config: ScraperConfig | None = None) -> None:
    """Ask the shared circuit breaker whether a request may go out.

    Does nothing when the circuit breaker is disabled.

    Args:
        config: Scraper configuration used to look up the breaker.

    Raises:
        CircuitBreakerOpenError: If the circuit is open.
    """
    breaker = get_circuit_breaker(config)
    if breaker is None:
        return
    breaker.call()
@retry(
retry=retry_if_exception_type(ThrottlingError),
wait=wait_exponential(multiplier=2, min=2, max=120),
stop=stop_after_attempt(5),
)
async def detail_query(
detail_id: int,
session: aiohttp.ClientSession | None = None,
config: ScraperConfig | None = None,
) -> dict[str, Any]:
"""Fetch detailed property information.
Args:
detail_id: The property identifier.
session: Optional aiohttp session. Creates new one if not provided.
config: Scraper configuration. Loads from environment if not provided.
Returns:
Property details as a dictionary.
Raises:
CircuitBreakerOpenError: If the circuit breaker is open.
ThrottlingError: If the request is throttled.
"""
if config is None:
config = ScraperConfig.from_env()
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
params = {
"apiApplication": "ANDROID",
"appVersion": "3.70.0",
@ -86,13 +166,38 @@ async def detail_query(
url = f"https://api.rightmove.co.uk/api/property/{detail_id}"
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
async with s.get(url, params=params, headers=DEFAULT_HEADERS) as response:
if response.status != 200:
raise Exception(
f"""id: {detail_id}. Status Code: {response.status}."""
f"""Failed due to: {await response.text()}"""
start_time = time.time()
try:
async with s.get(url, params=params, headers=DEFAULT_HEADERS) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
# Validate response for throttling
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=True,
)
return await response.json()
if response.status != 200:
raise Exception(
f"""id: {detail_id}. Status Code: {response.status}."""
f"""Failed due to: {await response.text()}"""
)
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e
if session:
return await do_request(session)
@ -101,7 +206,11 @@ async def detail_query(
return await do_request(new_session)
@retry(wait=wait_random(min=1, max=60), stop=stop_after_attempt(3))
@retry(
retry=retry_if_exception_type(ThrottlingError),
wait=wait_exponential(multiplier=2, min=2, max=120),
stop=stop_after_attempt(5),
)
async def listing_query(
*,
page: int,
@ -118,6 +227,7 @@ async def listing_query(
page_size: int = 25,
furnish_types: list[FurnishType] = [],
session: aiohttp.ClientSession | None = None,
config: ScraperConfig | None = None,
) -> dict[str, Any]:
"""Execute a listing search query.
@ -136,10 +246,21 @@ async def listing_query(
page_size: Number of results per page (default 25).
furnish_types: List of furnish types to filter (RENT only).
session: Optional aiohttp session. Creates new one if not provided.
config: Scraper configuration. Loads from environment if not provided.
Returns:
API response as a dictionary.
Raises:
CircuitBreakerOpenError: If the circuit breaker is open.
ThrottlingError: If the request is throttled.
"""
if config is None:
config = ScraperConfig.from_env()
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
params: dict[str, str] = {
"locationIdentifier": districts.get_districts()[district],
"channel": str(channel).upper(),
@ -185,14 +306,39 @@ async def listing_query(
}
async def do_request(s: aiohttp.ClientSession) -> dict[str, Any]:
async with s.get(
"https://api.rightmove.co.uk/api/property-listing",
params=params,
headers=request_headers,
) as response:
if response.status != 200:
raise Exception(f"Failed due to: {await response.text()}")
return await response.json()
start_time = time.time()
try:
async with s.get(
"https://api.rightmove.co.uk/api/property-listing",
params=params,
headers=request_headers,
) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
# Validate response for throttling
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=(page == 1), # Only expect data on first page
)
if response.status != 200:
raise Exception(f"Failed due to: {await response.text()}")
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e
if session:
return await do_request(session)
@ -201,7 +347,11 @@ async def listing_query(
return await do_request(new_session)
@retry(wait=wait_random(min=1, max=10), stop=stop_after_attempt(3))
@retry(
retry=retry_if_exception_type(ThrottlingError),
wait=wait_exponential(multiplier=2, min=2, max=60),
stop=stop_after_attempt(5),
)
async def probe_query(
*,
session: aiohttp.ClientSession,
@ -214,6 +364,7 @@ async def probe_query(
district: str,
max_days_since_added: int = 30,
furnish_types: list[FurnishType] = [],
config: ScraperConfig | None = None,
) -> dict[str, Any]:
"""Probe the API to get result count without fetching full results.
@ -230,10 +381,21 @@ async def probe_query(
district: District identifier string.
max_days_since_added: Maximum days since listing was added (BUY only).
furnish_types: List of furnish types to filter (RENT only).
config: Scraper configuration. Loads from environment if not provided.
Returns:
API response containing totalAvailableResults.
Raises:
CircuitBreakerOpenError: If the circuit breaker is open.
ThrottlingError: If the request is throttled.
"""
if config is None:
config = ScraperConfig.from_env()
check_circuit_breaker(config)
cb = get_circuit_breaker(config)
params: dict[str, str] = {
"locationIdentifier": districts.get_districts()[district],
"channel": str(channel).upper(),
@ -271,11 +433,36 @@ async def probe_query(
"Connection": "keep-alive",
}
async with session.get(
"https://api.rightmove.co.uk/api/property-listing",
params=params,
headers=request_headers,
) as response:
if response.status != 200:
raise Exception(f"Probe failed: {await response.text()}")
return await response.json()
start_time = time.time()
try:
async with session.get(
"https://api.rightmove.co.uk/api/property-listing",
params=params,
headers=request_headers,
) as response:
response_time = time.time() - start_time
body = await response.json() if response.status == 200 else None
# Validate response for throttling
validate_response(
response,
response_time,
body,
config.slow_response_threshold,
expect_data=False, # Probe doesn't need data, just count
)
if response.status != 200:
raise Exception(f"Probe failed: {await response.text()}")
if cb is not None:
cb.record_success()
return body # type: ignore
except ThrottlingError:
if cb is not None:
cb.record_failure()
raise
except Exception as e:
if cb is not None:
cb.record_failure()
raise e

View file

@ -0,0 +1,232 @@
"""Throttling detection and metrics for Rightmove API."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any
import aiohttp
from rec.exceptions import (
InvalidResponseError,
IPBlockedError,
RateLimitError,
ServiceUnavailableError,
SlowResponseError,
UnexpectedEmptyResponseError,
)
@dataclass
class ThrottleMetrics:
    """Tracks throttling events and request metrics.

    Every ``record_*`` method counts exactly one request, so
    ``total_requests`` covers throttled and rejected requests as well as
    successful ones and ``throttle_rate`` is a true percentage (0-100).

    Attributes:
        rate_limit_count: Number of HTTP 429 errors.
        service_unavailable_count: Number of HTTP 503 errors.
        ip_blocked_count: Number of HTTP 403 errors.
        slow_response_count: Number of slow responses.
        empty_response_count: Number of unexpected empty responses.
        invalid_response_count: Number of invalid/error responses.
        total_requests: Total number of requests recorded.
        total_response_time: Cumulative response time in seconds of the
            requests that were recorded together with a response time.
    """

    rate_limit_count: int = 0
    service_unavailable_count: int = 0
    ip_blocked_count: int = 0
    slow_response_count: int = 0
    empty_response_count: int = 0
    invalid_response_count: int = 0
    total_requests: int = 0
    total_response_time: float = 0.0
    _start_time: float = field(default_factory=time.time)
    # Requests counted in total_requests without a response time; they are
    # left out of the average so it is not dragged towards zero.
    _untimed_requests: int = 0

    def _count_untimed_request(self) -> None:
        """Count one request for which no response time was recorded."""
        self.total_requests += 1
        self._untimed_requests += 1

    def record_rate_limit(self) -> None:
        """Record a rate limit error (HTTP 429)."""
        self.rate_limit_count += 1
        self._count_untimed_request()

    def record_service_unavailable(self) -> None:
        """Record a service unavailable error (HTTP 503)."""
        self.service_unavailable_count += 1
        self._count_untimed_request()

    def record_ip_blocked(self) -> None:
        """Record an IP blocked error (HTTP 403)."""
        self.ip_blocked_count += 1
        self._count_untimed_request()

    def record_slow_response(self, response_time: float) -> None:
        """Record a slow response.

        Args:
            response_time: Response time in seconds.
        """
        self.slow_response_count += 1
        self.total_response_time += response_time
        self.total_requests += 1

    def record_empty_response(self) -> None:
        """Record an unexpected empty response."""
        self.empty_response_count += 1
        self._count_untimed_request()

    def record_invalid_response(self) -> None:
        """Record an invalid or error response."""
        self.invalid_response_count += 1
        self._count_untimed_request()

    def record_request(self, response_time: float) -> None:
        """Record a successful request.

        Args:
            response_time: Response time in seconds.
        """
        self.total_requests += 1
        self.total_response_time += response_time

    @property
    def average_response_time(self) -> float:
        """Average response time in seconds over the timed requests."""
        timed_requests = self.total_requests - self._untimed_requests
        if timed_requests <= 0:
            return 0.0
        return self.total_response_time / timed_requests

    @property
    def total_throttling_events(self) -> int:
        """Total number of throttling events (429, 503, 403 and slow)."""
        return (
            self.rate_limit_count
            + self.service_unavailable_count
            + self.ip_blocked_count
            + self.slow_response_count
        )

    @property
    def throttle_rate(self) -> float:
        """Percentage of recorded requests that were throttled."""
        if self.total_requests == 0:
            return 0.0
        return (self.total_throttling_events / self.total_requests) * 100

    @property
    def elapsed_time(self) -> float:
        """Seconds elapsed since this metrics object was created."""
        return time.time() - self._start_time

    def summary(self) -> str:
        """Generate a multi-line, human-readable summary of the metrics."""
        return (
            f"Throttle Metrics Summary:\n"
            f" Total Requests: {self.total_requests}\n"
            f" Total Throttling Events: {self.total_throttling_events}\n"
            f" Throttle Rate: {self.throttle_rate:.2f}%\n"
            f" Rate Limit (429): {self.rate_limit_count}\n"
            f" Service Unavailable (503): {self.service_unavailable_count}\n"
            f" IP Blocked (403): {self.ip_blocked_count}\n"
            f" Slow Responses: {self.slow_response_count}\n"
            f" Empty Responses: {self.empty_response_count}\n"
            f" Invalid Responses: {self.invalid_response_count}\n"
            f" Average Response Time: {self.average_response_time:.2f}s\n"
            f" Elapsed Time: {self.elapsed_time:.2f}s"
        )
# Global metrics instance.
# Starts as None; get_throttle_metrics() creates it on first access and
# reset_throttle_metrics() replaces it with a fresh instance.
_global_metrics: ThrottleMetrics | None = None
def get_throttle_metrics() -> ThrottleMetrics:
    """Return the shared ThrottleMetrics object, creating it if needed.

    Returns:
        The module-wide ThrottleMetrics instance.
    """
    global _global_metrics
    if _global_metrics is not None:
        return _global_metrics
    # First access: create the shared instance lazily.
    _global_metrics = ThrottleMetrics()
    return _global_metrics
def reset_throttle_metrics() -> None:
    """Reset the global throttle metrics.

    Rebinds the module-level ``_global_metrics`` to a brand-new
    ``ThrottleMetrics`` instance, so all counters start from zero again.
    The previous object is not modified: references obtained earlier keep
    pointing at the old instance.
    """
    global _global_metrics
    _global_metrics = ThrottleMetrics()
def validate_response(
    response: aiohttp.ClientResponse,
    response_time: float,
    response_body: dict[str, Any] | None,
    slow_response_threshold: float,
    expect_data: bool = True,
) -> None:
    """Inspect an API response and raise if it looks throttled or broken.

    Checks run in this order: throttling status codes, response time, error
    markers in the body, and (optionally) an unexpectedly empty result. Each
    detected problem is counted on the global throttle metrics before the
    matching exception is raised; a response that passes every check is
    recorded as a normal request.

    Args:
        response: The aiohttp response object (only ``status`` is read).
        response_time: Time taken for the request in seconds.
        response_body: Parsed JSON body, or None when not available.
        slow_response_threshold: Threshold in seconds for slow responses.
        expect_data: Whether the body is expected to contain properties.

    Raises:
        RateLimitError: If HTTP 429 is returned.
        ServiceUnavailableError: If HTTP 503 is returned.
        IPBlockedError: If HTTP 403 is returned.
        SlowResponseError: If response time exceeds the threshold.
        InvalidResponseError: If the body contains an error marker.
        UnexpectedEmptyResponseError: If the body reports available results
            but lists no properties while data is expected.
    """
    metrics = get_throttle_metrics()

    # status code -> (metric recorder, exception type, message prefix)
    throttling_statuses = {
        429: (
            metrics.record_rate_limit,
            RateLimitError,
            "Rate limit exceeded (HTTP 429). ",
        ),
        503: (
            metrics.record_service_unavailable,
            ServiceUnavailableError,
            "Service unavailable (HTTP 503). ",
        ),
        403: (
            metrics.record_ip_blocked,
            IPBlockedError,
            "Access forbidden, possible IP block (HTTP 403). ",
        ),
    }
    matched = throttling_statuses.get(response.status)
    if matched is not None:
        record_event, error_type, prefix = matched
        record_event()
        raise error_type(f"{prefix}Response time: {response_time:.2f}s")

    # A slow answer is treated as throttling even if the status was fine.
    if response_time > slow_response_threshold:
        metrics.record_slow_response(response_time)
        raise SlowResponseError(
            f"Slow response detected: {response_time:.2f}s "
            f"(threshold: {slow_response_threshold}s)"
        )

    if response_body is not None:
        # Either a top-level "error" key or the GENERIC_ERROR marker anywhere
        # in the stringified body counts as an error payload.
        if "error" in response_body or "GENERIC_ERROR" in str(response_body):
            metrics.record_invalid_response()
            raise InvalidResponseError(
                f"Error in response body: {response_body}"
            )

        if expect_data:
            listed = response_body.get("properties", [])
            reported_total = response_body.get("totalAvailableResults", 0)
            # The API claims results exist yet returned none of them.
            if reported_total > 0 and len(listed) == 0:
                metrics.record_empty_response()
                raise UnexpectedEmptyResponseError(
                    f"Expected data but got empty response. "
                    f"Total available: {reported_total}"
                )

    # Nothing suspicious: count it as a regular request.
    metrics.record_request(response_time)

View file

@ -6,6 +6,8 @@ from typing import Any
from config.scraper_config import ScraperConfig
from listing_processor import ListingProcessor
from rec.query import create_session, listing_query
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
from rec.throttle_detector import get_throttle_metrics, reset_throttle_metrics
from models.listing import QueryParameters
from repositories import ListingRepository
from tqdm.asyncio import tqdm
@ -40,76 +42,98 @@ async def dump_listings(
config = ScraperConfig.from_env()
splitter = QuerySplitter(config)
async with create_session(config) as session:
# Phase 1 & 2: Split and probe queries
logger.info("Splitting query and probing result counts...")
subqueries = await splitter.split(parameters, session)
# Reset throttle metrics at start
reset_throttle_metrics()
total_estimated = splitter.calculate_total_estimated_results(subqueries)
logger.info(
f"Split into {len(subqueries)} subqueries, "
f"estimated {total_estimated} total results"
)
try:
async with create_session(config) as session:
# Phase 1 & 2: Split and probe queries
logger.info("Splitting query and probing result counts...")
subqueries = await splitter.split(parameters, session)
# Phase 3: Fetch all pages for each subquery
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
async def fetch_subquery(sq: SubQuery) -> list[dict[str, Any]]:
"""Fetch all pages for a single subquery."""
results: list[dict[str, Any]] = []
estimated = sq.estimated_results or 0
if estimated == 0:
return results
page_size = parameters.page_size
max_pages = min(
config.max_pages_per_query,
(estimated // page_size) + 1,
total_estimated = splitter.calculate_total_estimated_results(subqueries)
logger.info(
f"Split into {len(subqueries)} subqueries, "
f"estimated {total_estimated} total results"
)
for page_id in range(1, max_pages + 1):
async with semaphore:
await asyncio.sleep(config.request_delay_ms / 1000)
try:
result = await listing_query(
page=page_id,
channel=parameters.listing_type,
min_bedrooms=sq.min_bedrooms,
max_bedrooms=sq.max_bedrooms,
radius=parameters.radius,
min_price=sq.min_price,
max_price=sq.max_price,
district=sq.district,
page_size=page_size,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
session=session,
)
results.append(result)
# Phase 3: Fetch all pages for each subquery
semaphore = asyncio.Semaphore(config.max_concurrent_requests)
properties = result.get("properties", [])
if len(properties) < page_size:
async def fetch_subquery(sq: SubQuery) -> list[dict[str, Any]]:
"""Fetch all pages for a single subquery."""
results: list[dict[str, Any]] = []
estimated = sq.estimated_results or 0
if estimated == 0:
return results
page_size = parameters.page_size
max_pages = min(
config.max_pages_per_query,
(estimated // page_size) + 1,
)
for page_id in range(1, max_pages + 1):
async with semaphore:
await asyncio.sleep(config.request_delay_ms / 1000)
try:
result = await listing_query(
page=page_id,
channel=parameters.listing_type,
min_bedrooms=sq.min_bedrooms,
max_bedrooms=sq.max_bedrooms,
radius=parameters.radius,
min_price=sq.min_price,
max_price=sq.max_price,
district=sq.district,
page_size=page_size,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
session=session,
config=config,
)
results.append(result)
properties = result.get("properties", [])
if len(properties) < page_size:
break
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker open: {e}")
break
except Exception as e:
if "GENERIC_ERROR" in str(e):
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
except ThrottlingError as e:
logger.warning(
f"Throttling error on page {page_id} for {sq.district}: {e}"
)
break
except Exception as e:
if "GENERIC_ERROR" in str(e):
logger.debug(
f"Max page for {sq.district}: {page_id - 1}"
)
break
logger.warning(
f"Error fetching page {page_id} for {sq.district}: {e}"
)
break
logger.warning(
f"Error fetching page {page_id} for {sq.district}: {e}"
)
break
return results
return results
# Fetch all subqueries with progress bar
all_results = await tqdm.gather(
*[fetch_subquery(sq) for sq in subqueries],
desc="Fetching listings",
)
# Fetch all subqueries with progress bar
all_results = await tqdm.gather(
*[fetch_subquery(sq) for sq in subqueries],
desc="Fetching listings",
)
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker prevented listing fetch: {e}")
logger.info(get_throttle_metrics().summary())
return []
finally:
# Log throttle metrics at end
metrics = get_throttle_metrics()
if metrics.total_requests > 0:
logger.info("\n" + metrics.summary())
# Extract listing identifiers from results
listing_ids: list[int] = []

View file

@ -16,6 +16,7 @@ import aiohttp
from config.scraper_config import ScraperConfig
from models.listing import ListingType, QueryParameters
from rec.districts import get_districts
from rec.exceptions import CircuitBreakerOpenError, ThrottlingError
logger = logging.getLogger("uvicorn.error")
@ -113,6 +114,9 @@ class QuerySplitter:
Returns:
Total available results for this subquery.
Raises:
CircuitBreakerOpenError: If the circuit breaker is open.
"""
from rec.query import probe_query
@ -128,8 +132,17 @@ class QuerySplitter:
district=subquery.district,
max_days_since_added=parameters.max_days_since_added,
furnish_types=parameters.furnish_types or [],
config=self.config,
)
return result.get("totalAvailableResults", 0)
except CircuitBreakerOpenError:
logger.error("Circuit breaker is open, stopping probe operations")
raise
except ThrottlingError as e:
logger.warning(
f"Throttling detected during probe for {subquery.district}: {e}"
)
return 0
except Exception as e:
logger.warning(f"Failed to probe subquery {subquery}: {e}")
return 0

View file

@ -0,0 +1,311 @@
"""Integration tests for throttle detection and circuit breaker."""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import ClientResponse
from config.scraper_config import ScraperConfig
from rec.exceptions import (
CircuitBreakerOpenError,
RateLimitError,
ServiceUnavailableError,
ThrottlingError,
)
from rec.query import (
detail_query,
listing_query,
probe_query,
get_circuit_breaker,
reset_circuit_breaker,
)
from rec.throttle_detector import reset_throttle_metrics, get_throttle_metrics
from rec.circuit_breaker import CircuitBreaker, CircuitState
from models.listing import ListingType
@pytest.fixture
def config() -> ScraperConfig:
    """Provide a ScraperConfig tuned for fast tests.

    Uses a short request delay, a low failure threshold and a sub-second
    recovery timeout so circuit-breaker behaviour can be exercised quickly.
    """
    test_settings = {
        "max_concurrent_requests": 5,
        "request_delay_ms": 10,
        "slow_response_threshold": 2.0,
        "enable_circuit_breaker": True,
        "circuit_breaker_failure_threshold": 3,
        "circuit_breaker_recovery_timeout": 0.5,
    }
    return ScraperConfig(**test_settings)
@pytest.fixture(autouse=True)
def reset_globals() -> None:
    """Clear the shared metrics and circuit breaker before every test."""
    for reset_state in (reset_throttle_metrics, reset_circuit_breaker):
        reset_state()
class MockResponse:
"""Mock aiohttp response."""
def __init__(
self,
status: int = 200,
json_data: dict | None = None,
text: str = "",
):
self.status = status
self._json_data = json_data or {}
self._text = text
async def json(self) -> dict:
return self._json_data
async def text(self) -> str:
return self._text
async def __aenter__(self) -> "MockResponse":
return self
async def __aexit__(self, *args: object) -> None:
pass
class TestThrottlingRetryBehavior:
    """Test retry behavior for throttling errors."""

    @pytest.mark.asyncio
    async def test_rate_limit_triggers_retry(self, config: ScraperConfig) -> None:
        """Test that 429 responses trigger retry with backoff."""
        call_count = 0

        # Must be a plain function, not ``async def``: probe_query enters the
        # value returned by ``session.get(...)`` with ``async with``, so the
        # stub has to hand back an async context manager (MockResponse)
        # directly. A coroutine function would return a coroutine object,
        # which does not support the async context manager protocol.
        def mock_get(*args: object, **kwargs: object) -> MockResponse:
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return MockResponse(status=429)
            return MockResponse(
                status=200,
                json_data={"totalAvailableResults": 10, "properties": []},
            )

        mock_session = MagicMock()
        mock_session.get = mock_get

        # Mock district lookup
        with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
            # The retry decorator will catch RateLimitError and retry.
            # Patch the tenacity wait so the backoff does not slow the test.
            with patch("tenacity.wait_exponential.__call__", return_value=0):
                result = await probe_query(
                    session=mock_session,
                    channel=ListingType.RENT,
                    min_bedrooms=1,
                    max_bedrooms=2,
                    radius=1.0,
                    min_price=1000,
                    max_price=2000,
                    district="Test",
                    config=config,
                )

        # Two throttled attempts followed by the successful third one.
        assert result["totalAvailableResults"] == 10
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_service_unavailable_triggers_retry(
        self, config: ScraperConfig
    ) -> None:
        """Test that 503 responses trigger retry."""
        call_count = 0

        # Plain function for the same reason as above: the return value is
        # used directly as an async context manager.
        def mock_get(*args: object, **kwargs: object) -> MockResponse:
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                return MockResponse(status=503)
            return MockResponse(
                status=200,
                json_data={"totalAvailableResults": 5, "properties": []},
            )

        mock_session = MagicMock()
        mock_session.get = mock_get

        with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
            with patch("tenacity.wait_exponential.__call__", return_value=0):
                result = await probe_query(
                    session=mock_session,
                    channel=ListingType.RENT,
                    min_bedrooms=1,
                    max_bedrooms=2,
                    radius=1.0,
                    min_price=1000,
                    max_price=2000,
                    district="Test",
                    config=config,
                )

        # One 503 followed by the successful retry, whose payload is returned.
        assert result["totalAvailableResults"] == 5
        assert call_count == 2
class TestCircuitBreakerIntegration:
    """Exercise the shared circuit breaker through ``probe_query``."""

    @pytest.mark.asyncio
    async def test_circuit_breaker_opens_after_failures(
        self, config: ScraperConfig
    ) -> None:
        """A session that always answers 429 pushes the failure count to the threshold."""
        attempts = 0

        async def always_rate_limited(*args: object, **kwargs: object) -> MockResponse:
            nonlocal attempts
            attempts += 1
            return MockResponse(status=429)

        session = MagicMock()
        session.get = always_rate_limited

        with (
            patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}),
            # After enough failures the circuit should open, so either error
            # type is an acceptable outcome of the call.
            pytest.raises((RateLimitError, CircuitBreakerOpenError)),
            patch("tenacity.wait_exponential.__call__", return_value=0),
        ):
            await probe_query(
                session=session,
                channel=ListingType.RENT,
                min_bedrooms=1,
                max_bedrooms=2,
                radius=1.0,
                min_price=1000,
                max_price=2000,
                district="Test",
                config=config,
            )

        # Inspect the breaker that the query used.
        breaker = get_circuit_breaker(config)
        assert breaker is not None
        assert breaker.failure_count >= config.circuit_breaker_failure_threshold

    @pytest.mark.asyncio
    async def test_circuit_breaker_blocks_requests_when_open(
        self, config: ScraperConfig
    ) -> None:
        """Once the breaker is open, ``probe_query`` raises CircuitBreakerOpenError."""
        breaker = get_circuit_breaker(config)
        assert breaker is not None

        # Trip the breaker by recording exactly the threshold number of failures.
        for _ in range(config.circuit_breaker_failure_threshold):
            breaker.record_failure()
        assert breaker.is_open

        session = MagicMock()

        with (
            patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}),
            pytest.raises(CircuitBreakerOpenError),
        ):
            await probe_query(
                session=session,
                channel=ListingType.RENT,
                min_bedrooms=1,
                max_bedrooms=2,
                radius=1.0,
                min_price=1000,
                max_price=2000,
                district="Test",
                config=config,
            )
class TestMetricsTracking:
    """Check that ``probe_query`` outcomes show up in the global throttle metrics."""

    @pytest.mark.asyncio
    async def test_metrics_tracked_on_rate_limit(self, config: ScraperConfig) -> None:
        """A permanently rate-limited session leaves a non-zero rate-limit count."""

        async def always_rate_limited(*args: object, **kwargs: object) -> MockResponse:
            return MockResponse(status=429)

        session = MagicMock()
        session.get = always_rate_limited

        with (
            patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}),
            pytest.raises(RateLimitError),
            patch("tenacity.wait_exponential.__call__", return_value=0),
        ):
            await probe_query(
                session=session,
                channel=ListingType.RENT,
                min_bedrooms=1,
                max_bedrooms=2,
                radius=1.0,
                min_price=1000,
                max_price=2000,
                district="Test",
                config=config,
            )

        recorded = get_throttle_metrics()
        assert recorded.rate_limit_count > 0

    @pytest.mark.asyncio
    async def test_metrics_tracked_on_success(self, config: ScraperConfig) -> None:
        """A single successful probe records one request and no throttling events."""

        async def always_ok(*args: object, **kwargs: object) -> MockResponse:
            return MockResponse(
                status=200,
                json_data={"totalAvailableResults": 10, "properties": []},
            )

        session = MagicMock()
        session.get = always_ok

        with patch("rec.query.districts.get_districts", return_value={"Test": "LOC1"}):
            await probe_query(
                session=session,
                channel=ListingType.RENT,
                min_bedrooms=1,
                max_bedrooms=2,
                radius=1.0,
                min_price=1000,
                max_price=2000,
                district="Test",
                config=config,
            )

        recorded = get_throttle_metrics()
        assert recorded.total_requests == 1
        assert recorded.total_throttling_events == 0
class TestConfigIntegration:
    """Test how ScraperConfig settings reach the throttling components."""

    def test_config_from_env_includes_throttle_settings(self) -> None:
        """``ScraperConfig.from_env`` picks up the four throttle-related variables."""
        env_overrides = {
            "RIGHTMOVE_SLOW_RESPONSE_THRESHOLD": "5.0",
            "RIGHTMOVE_ENABLE_CIRCUIT_BREAKER": "false",
            "RIGHTMOVE_CIRCUIT_BREAKER_FAILURES": "10",
            "RIGHTMOVE_CIRCUIT_BREAKER_TIMEOUT": "120.0",
        }

        # patch.dict adds the overrides and restores os.environ on exit, even
        # if from_env raises, replacing the manual copy/clear/update dance.
        with patch.dict("os.environ", env_overrides):
            config = ScraperConfig.from_env()

        assert config.slow_response_threshold == 5.0
        assert config.enable_circuit_breaker is False
        assert config.circuit_breaker_failure_threshold == 10
        assert config.circuit_breaker_recovery_timeout == 120.0

    def test_circuit_breaker_disabled_returns_none(self) -> None:
        """``get_circuit_breaker`` returns None when the config disables the breaker."""
        config = ScraperConfig(
            enable_circuit_breaker=False,
        )
        reset_circuit_breaker()
        cb = get_circuit_breaker(config)
        assert cb is None

View file

@ -0,0 +1,334 @@
"""Unit tests for throttle detection and circuit breaker."""
import pytest
from unittest.mock import MagicMock, AsyncMock
import time
from rec.exceptions import (
RightmoveAPIError,
ThrottlingError,
RateLimitError,
ServiceUnavailableError,
IPBlockedError,
SlowResponseError,
UnexpectedEmptyResponseError,
InvalidResponseError,
CircuitBreakerOpenError,
)
from rec.throttle_detector import (
ThrottleMetrics,
validate_response,
get_throttle_metrics,
reset_throttle_metrics,
)
from rec.circuit_breaker import CircuitBreaker, CircuitState
class TestExceptionHierarchy:
    """Pin down the inheritance relationships between the custom exceptions."""

    def test_rightmove_api_error_is_exception(self) -> None:
        """The root error derives from the built-in Exception."""
        assert issubclass(RightmoveAPIError, Exception)

    def test_throttling_error_is_rightmove_api_error(self) -> None:
        """ThrottlingError sits directly under the root error."""
        assert issubclass(ThrottlingError, RightmoveAPIError)

    def test_rate_limit_error_is_throttling_error(self) -> None:
        """RateLimitError is a kind of throttling."""
        assert issubclass(RateLimitError, ThrottlingError)

    def test_service_unavailable_error_is_throttling_error(self) -> None:
        """ServiceUnavailableError is a kind of throttling."""
        assert issubclass(ServiceUnavailableError, ThrottlingError)

    def test_ip_blocked_error_is_throttling_error(self) -> None:
        """IPBlockedError is a kind of throttling."""
        assert issubclass(IPBlockedError, ThrottlingError)

    def test_slow_response_error_is_throttling_error(self) -> None:
        """SlowResponseError is a kind of throttling."""
        assert issubclass(SlowResponseError, ThrottlingError)

    def test_unexpected_empty_response_error_is_rightmove_api_error(self) -> None:
        """An unexpectedly empty response is an API error but not throttling."""
        assert issubclass(UnexpectedEmptyResponseError, RightmoveAPIError)
        assert not issubclass(UnexpectedEmptyResponseError, ThrottlingError)

    def test_invalid_response_error_is_rightmove_api_error(self) -> None:
        """An invalid response is an API error but not throttling."""
        assert issubclass(InvalidResponseError, RightmoveAPIError)
        assert not issubclass(InvalidResponseError, ThrottlingError)

    def test_circuit_breaker_open_error_is_rightmove_api_error(self) -> None:
        """CircuitBreakerOpenError derives from the root error."""
        assert issubclass(CircuitBreakerOpenError, RightmoveAPIError)

    def test_exception_messages(self) -> None:
        """The message passed to the constructor is what ``str()`` returns."""
        assert str(RateLimitError("Too many requests")) == "Too many requests"
class TestThrottleMetrics:
    """Unit tests for the ThrottleMetrics counters and derived statistics."""

    def test_initial_state(self) -> None:
        """A fresh instance has every counter and the total response time at zero."""
        tracker = ThrottleMetrics()
        for counter in (
            "rate_limit_count",
            "service_unavailable_count",
            "ip_blocked_count",
            "slow_response_count",
            "empty_response_count",
            "invalid_response_count",
            "total_requests",
        ):
            assert getattr(tracker, counter) == 0
        assert tracker.total_response_time == 0.0

    def test_record_rate_limit(self) -> None:
        """Each call increments the rate-limit counter by one."""
        tracker = ThrottleMetrics()
        for expected in (1, 2):
            tracker.record_rate_limit()
            assert tracker.rate_limit_count == expected

    def test_record_service_unavailable(self) -> None:
        """Recording a 503 bumps the service-unavailable counter."""
        tracker = ThrottleMetrics()
        tracker.record_service_unavailable()
        assert tracker.service_unavailable_count == 1

    def test_record_ip_blocked(self) -> None:
        """Recording a block bumps the ip-blocked counter."""
        tracker = ThrottleMetrics()
        tracker.record_ip_blocked()
        assert tracker.ip_blocked_count == 1

    def test_record_slow_response(self) -> None:
        """A slow response counts as a request and adds its duration to the total."""
        tracker = ThrottleMetrics()
        tracker.record_slow_response(15.0)
        assert tracker.slow_response_count == 1
        assert tracker.total_response_time == 15.0
        assert tracker.total_requests == 1

    def test_record_empty_response(self) -> None:
        """Recording an empty response bumps its counter."""
        tracker = ThrottleMetrics()
        tracker.record_empty_response()
        assert tracker.empty_response_count == 1

    def test_record_invalid_response(self) -> None:
        """Recording an invalid response bumps its counter."""
        tracker = ThrottleMetrics()
        tracker.record_invalid_response()
        assert tracker.invalid_response_count == 1

    def test_record_request(self) -> None:
        """A normal request increments the request count and the total time."""
        tracker = ThrottleMetrics()
        tracker.record_request(0.5)
        assert tracker.total_requests == 1
        assert tracker.total_response_time == 0.5

    def test_average_response_time(self) -> None:
        """The average is the mean of the recorded durations."""
        tracker = ThrottleMetrics()
        for duration in (1.0, 2.0, 3.0):
            tracker.record_request(duration)
        assert tracker.average_response_time == 2.0

    def test_average_response_time_zero_requests(self) -> None:
        """With no requests the average is reported as 0.0."""
        tracker = ThrottleMetrics()
        assert tracker.average_response_time == 0.0

    def test_total_throttling_events(self) -> None:
        """Rate limits, 503s, IP blocks and slow responses all count as throttling."""
        tracker = ThrottleMetrics()
        tracker.record_rate_limit()
        tracker.record_service_unavailable()
        tracker.record_ip_blocked()
        tracker.record_slow_response(15.0)
        assert tracker.total_throttling_events == 4

    def test_throttle_rate(self) -> None:
        """The throttle rate is throttling events as a percentage of requests."""
        tracker = ThrottleMetrics()
        tracker.record_request(0.5)  # 1 normal request
        tracker.record_request(0.5)  # 2 normal requests
        tracker.record_rate_limit()  # not counted as a request
        tracker.record_request(0.5)  # 3 normal requests
        # 1 throttling event over 3 requests = 33.33%
        assert tracker.throttle_rate == pytest.approx(33.33, rel=0.01)

    def test_throttle_rate_zero_requests(self) -> None:
        """With no requests the throttle rate is reported as 0.0."""
        tracker = ThrottleMetrics()
        assert tracker.throttle_rate == 0.0

    def test_elapsed_time(self) -> None:
        """Elapsed time grows with wall-clock time since construction."""
        tracker = ThrottleMetrics()
        time.sleep(0.1)
        assert tracker.elapsed_time >= 0.1

    def test_summary(self) -> None:
        """The summary text contains the expected labels and a count."""
        tracker = ThrottleMetrics()
        tracker.record_request(1.0)
        tracker.record_rate_limit()
        report = tracker.summary()
        for fragment in ("Total Requests:", "Rate Limit (429):", "1"):
            assert fragment in report
class TestGlobalMetrics:
    """Tests for the module-level metrics accessor and its reset function."""

    def test_get_throttle_metrics_singleton(self) -> None:
        """Two consecutive lookups return the very same object."""
        reset_throttle_metrics()
        first = get_throttle_metrics()
        second = get_throttle_metrics()
        assert first is second

    def test_reset_throttle_metrics(self) -> None:
        """After a reset the accessor reports a zeroed rate-limit counter."""
        reset_throttle_metrics()
        before = get_throttle_metrics()
        before.record_rate_limit()
        assert before.rate_limit_count == 1

        reset_throttle_metrics()
        after = get_throttle_metrics()
        assert after.rate_limit_count == 0
class TestValidateResponse:
    """Tests for ``validate_response`` and the metrics it records."""

    def setup_method(self) -> None:
        """Start every test from zeroed global metrics."""
        reset_throttle_metrics()

    def create_mock_response(self, status: int) -> MagicMock:
        """Build a mock response whose ``status`` attribute is *status*."""
        return MagicMock(status=status)

    def test_rate_limit_error(self) -> None:
        """HTTP 429 raises RateLimitError and is counted once."""
        resp = self.create_mock_response(429)
        with pytest.raises(RateLimitError):
            validate_response(resp, 0.5, None, 10.0)
        assert get_throttle_metrics().rate_limit_count == 1

    def test_service_unavailable_error(self) -> None:
        """HTTP 503 raises ServiceUnavailableError and is counted once."""
        resp = self.create_mock_response(503)
        with pytest.raises(ServiceUnavailableError):
            validate_response(resp, 0.5, None, 10.0)
        assert get_throttle_metrics().service_unavailable_count == 1

    def test_ip_blocked_error(self) -> None:
        """HTTP 403 raises IPBlockedError and is counted once."""
        resp = self.create_mock_response(403)
        with pytest.raises(IPBlockedError):
            validate_response(resp, 0.5, None, 10.0)
        assert get_throttle_metrics().ip_blocked_count == 1

    def test_slow_response_error(self) -> None:
        """A 15 s response against a 10 s threshold raises SlowResponseError."""
        resp = self.create_mock_response(200)
        payload = {"totalAvailableResults": 0, "properties": []}
        with pytest.raises(SlowResponseError):
            validate_response(resp, 15.0, payload, 10.0)
        assert get_throttle_metrics().slow_response_count == 1

    def test_slow_response_just_under_threshold(self) -> None:
        """A 9.9 s response against a 10 s threshold is accepted."""
        resp = self.create_mock_response(200)
        payload = {"totalAvailableResults": 0, "properties": []}
        # Must not raise.
        validate_response(resp, 9.9, payload, 10.0)
        assert get_throttle_metrics().slow_response_count == 0

    def test_error_in_response_body(self) -> None:
        """An ``error`` key in the body raises InvalidResponseError and is counted."""
        resp = self.create_mock_response(200)
        payload = {"error": "Something went wrong"}
        with pytest.raises(InvalidResponseError):
            validate_response(resp, 0.5, payload, 10.0)
        assert get_throttle_metrics().invalid_response_count == 1

    def test_generic_error_in_body(self) -> None:
        """A body mentioning GENERIC_ERROR raises InvalidResponseError."""
        resp = self.create_mock_response(200)
        payload = {"message": "GENERIC_ERROR occurred"}
        with pytest.raises(InvalidResponseError):
            validate_response(resp, 0.5, payload, 10.0)

    def test_unexpected_empty_response(self) -> None:
        """Results advertised but no properties returned is an error when data is expected."""
        resp = self.create_mock_response(200)
        payload = {"totalAvailableResults": 100, "properties": []}
        with pytest.raises(UnexpectedEmptyResponseError):
            validate_response(resp, 0.5, payload, 10.0, expect_data=True)
        assert get_throttle_metrics().empty_response_count == 1

    def test_empty_response_when_not_expecting_data(self) -> None:
        """The same empty payload is accepted when ``expect_data`` is False."""
        resp = self.create_mock_response(200)
        payload = {"totalAvailableResults": 100, "properties": []}
        # Must not raise when expect_data=False.
        validate_response(resp, 0.5, payload, 10.0, expect_data=False)
        assert get_throttle_metrics().empty_response_count == 0

    def test_valid_response(self) -> None:
        """A healthy response records one request and no throttling events."""
        resp = self.create_mock_response(200)
        payload = {
            "totalAvailableResults": 10,
            "properties": [{"id": 1}, {"id": 2}],
        }
        validate_response(resp, 0.5, payload, 10.0, expect_data=True)
        recorded = get_throttle_metrics()
        assert recorded.total_requests == 1
        assert get_throttle_metrics().total_throttling_events == 0
class TestCircuitBreaker:
    """Unit tests for the CircuitBreaker state machine."""

    def test_initial_state_is_closed(self) -> None:
        """A new breaker starts closed and reports no other state."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
        assert breaker.state == CircuitState.CLOSED
        assert breaker.is_closed
        assert not breaker.is_open
        assert not breaker.is_half_open

    def test_allows_requests_when_closed(self) -> None:
        """``call()`` on a closed breaker does not raise."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
        # Must not raise.
        breaker.call()

    def test_opens_after_threshold_failures(self) -> None:
        """The breaker stays closed below the threshold and opens on reaching it."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
        for _ in range(2):
            breaker.record_failure()
        assert breaker.is_closed

        breaker.record_failure()
        assert breaker.is_open

    def test_blocks_requests_when_open(self) -> None:
        """``call()`` on an open breaker raises CircuitBreakerOpenError."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=60.0)
        breaker.record_failure()
        assert breaker.is_open
        with pytest.raises(CircuitBreakerOpenError):
            breaker.call()

    def test_success_resets_failure_count(self) -> None:
        """A recorded success sets the failure counter back to zero."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10.0)
        for _ in range(2):
            breaker.record_failure()
        assert breaker.failure_count == 2

        breaker.record_success()
        assert breaker.failure_count == 0

    def test_transitions_to_half_open_after_timeout(self) -> None:
        """After the recovery timeout, the next ``call()`` moves the breaker to half-open."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
        breaker.record_failure()
        assert breaker.is_open

        time.sleep(0.15)  # wait past the 0.1 s recovery timeout
        breaker.call()  # should transition to half-open
        assert breaker.is_half_open

    def test_half_open_success_closes_circuit(self) -> None:
        """A success while half-open closes the breaker."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
        breaker.record_failure()

        time.sleep(0.15)  # wait past the 0.1 s recovery timeout
        breaker.call()  # transition to half-open
        assert breaker.is_half_open

        breaker.record_success()
        assert breaker.is_closed

    def test_half_open_failure_reopens_circuit(self) -> None:
        """A failure while half-open opens the breaker again."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
        breaker.record_failure()

        time.sleep(0.15)  # wait past the 0.1 s recovery timeout
        breaker.call()  # transition to half-open
        assert breaker.is_half_open

        breaker.record_failure()
        assert breaker.is_open

    def test_reset(self) -> None:
        """``reset()`` closes an open breaker and zeroes the failure counter."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=60.0)
        breaker.record_failure()
        assert breaker.is_open

        breaker.reset()
        assert breaker.is_closed
        assert breaker.failure_count == 0