"""Redis-based caching for listing GeoJSON query results."""

import hashlib
import json
import logging
import os
import uuid
from typing import Generator
from urllib.parse import urlparse, urlunparse

import redis

from models.listing import QueryParameters

logger = logging.getLogger(__name__)

CACHE_PREFIX = "listings:geojson:"
STAGING_PREFIX = "listings:geojson:staging:"
CACHE_TTL_SECONDS = 24 * 60 * 60  # 24 hours
STALE_AFTER_SECONDS = 4 * 60 * 60  # 4 hours — serve stale, revalidate in background
REPOPULATING_PREFIX = "listings:geojson:repopulating:"
STAGING_TTL_SECONDS = 5 * 60  # 5 minutes safety net for orphaned staging keys
REPOPULATION_LOCK_TTL_SECONDS = 60  # lock auto-expires if the holder never finishes
CACHE_DB = 2


def _get_redis_client() -> redis.Redis:
    """Build a Redis client from the Celery broker URL, pointed at db ``CACHE_DB``.

    The URL is read from ``CELERY_BROKER_URL`` (default
    ``redis://localhost:6379/0``) and only its path component is replaced,
    so host, port and credentials are inherited from the broker settings.

    Returns:
        A client configured with ``decode_responses=True`` (values come back
        as ``str``).
    """
    # NOTE(review): a new client is constructed on every call; consider
    # sharing one per process if connection churn shows up in profiling.
    broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
    parsed = urlparse(broker_url)
    cache_url = urlunparse(parsed._replace(path=f"/{CACHE_DB}"))
    return redis.from_url(cache_url, decode_responses=True)


def make_cache_key(query_params: QueryParameters) -> str:
    """Generate the live cache key for a set of query parameters.

    The key is ``CACHE_PREFIX`` followed by the first 16 hex characters of
    the SHA-256 digest of ``query_params.model_dump_json()``.
    """
    params_json = query_params.model_dump_json()
    hash_suffix = hashlib.sha256(params_json.encode()).hexdigest()[:16]
    return f"{CACHE_PREFIX}{hash_suffix}"


def get_cached_count(query_params: QueryParameters) -> int | None:
    """Return the number of cached features for a query, or None if not cached.

    Redis errors are logged and reported as "not cached" (None).
    """
    try:
        client = _get_redis_client()
        # Redis never stores an empty list, so LLEN == 0 means the key is
        # absent. A single LLEN replaces the former EXISTS + LLEN pair, which
        # cost two round trips and could return 0 if the key disappeared
        # between the two calls.
        count = client.llen(make_cache_key(query_params))
        return count if count else None
    except redis.RedisError as e:
        logger.warning(f"Redis cache read error: {e}")
        return None


def get_cached_features(
    query_params: QueryParameters, batch_size: int = 50
) -> Generator[list[dict], None, None]:
    """Yield batches of cached GeoJSON features.

    The list length is read once up front and the list is then paged through
    with LRANGE in windows of ``batch_size`` items.

    A Redis error at any point is logged and ends the generator, so a consumer
    may receive only part of the cached features. JSON decoding errors are not
    caught here and propagate to the consumer.

    Args:
        query_params: Query whose cached features should be streamed.
        batch_size: Maximum number of features per yielded batch.

    Yields:
        Non-empty lists of decoded feature dicts.
    """
    try:
        client = _get_redis_client()
        key = make_cache_key(query_params)
        total = client.llen(key)
        for start in range(0, total, batch_size):
            # LRANGE bounds are inclusive on both ends.
            end = start + batch_size - 1
            items = client.lrange(key, start, end)
            batch = [json.loads(item) for item in items]
            if batch:
                yield batch
    except redis.RedisError as e:
        logger.warning(f"Redis cache read error during streaming: {e}")


def _append_features(key: str, features: list[dict], ttl_seconds: int) -> None:
    """Append JSON-encoded features to the list at ``key`` and (re)set its TTL.

    Both commands are queued on one pipeline so they reach Redis in a single
    round trip. Redis errors are left for the caller to handle.
    """
    client = _get_redis_client()
    pipeline = client.pipeline()
    # One variadic RPUSH instead of one command per feature.
    pipeline.rpush(key, *(json.dumps(feature) for feature in features))
    pipeline.expire(key, ttl_seconds)
    pipeline.execute()


def cache_features_batch(query_params: QueryParameters, features: list[dict]) -> None:
    """Append a batch of features to the live cache list.

    The key's TTL is reset to ``CACHE_TTL_SECONDS`` on every call. An empty
    batch is a no-op. Redis errors are logged and swallowed.
    """
    if not features:
        return
    try:
        _append_features(make_cache_key(query_params), features, CACHE_TTL_SECONDS)
    except redis.RedisError as e:
        logger.warning(f"Redis cache write error: {e}")


def begin_cache_population(query_params: QueryParameters) -> str:
    """Begin staged cache population. Returns a unique staging key.

    The staging key gets its TTL set by cache_features_batch_staged on the
    first rpush, so no pre-creation is needed here. ``query_params`` is not
    used to derive the key; the key is random (uuid4).
    """
    return f"{STAGING_PREFIX}{uuid.uuid4().hex}"


def cache_features_batch_staged(staging_key: str, features: list[dict]) -> None:
    """Append a batch of features to a staging key.

    The staging key's TTL is reset to ``STAGING_TTL_SECONDS`` on every call.
    An empty batch is a no-op. Redis errors are logged and swallowed.
    """
    if not features:
        return
    try:
        _append_features(staging_key, features, STAGING_TTL_SECONDS)
    except redis.RedisError as e:
        logger.warning(f"Redis staged cache write error: {e}")


def finalize_cache_population(staging_key: str, query_params: QueryParameters) -> None:
    """Atomically promote the staging key to the live cache key with a full TTL.

    The TTL is applied to the staging key *before* the rename, and both
    commands run in one MULTI/EXEC transaction. RENAME carries a key's TTL
    over to the new name, so the live key never carries the short staging TTL
    (which would make it expire within minutes and make get_cache_age report
    it as nearly a day old).

    If the staging key does not exist (e.g. nothing was ever staged), RENAME
    fails, the existing live key is left untouched, and the error is logged.
    """
    try:
        client = _get_redis_client()
        live_key = make_cache_key(query_params)
        pipeline = client.pipeline()  # transactional (MULTI/EXEC) by default
        pipeline.expire(staging_key, CACHE_TTL_SECONDS)
        # RENAME is atomic — replaces the live key in one operation
        pipeline.rename(staging_key, live_key)
        pipeline.execute()
        logger.debug(f"Finalized cache population for {live_key}")
    except redis.RedisError as e:
        logger.warning(f"Redis cache finalize error: {e}")


def delete_staging_key(staging_key: str) -> None:
    """Delete an orphaned staging key (used in error cleanup).

    Redis errors are logged and swallowed.
    """
    try:
        client = _get_redis_client()
        client.delete(staging_key)
    except redis.RedisError as e:
        logger.warning(f"Redis staging key cleanup error: {e}")


def invalidate_cache() -> None:
    """Delete all listing GeoJSON cache entries, including staging keys.

    Every key under ``CACHE_PREFIX`` is removed. Because the staging and
    repopulation-lock prefixes both start with ``CACHE_PREFIX``, staging keys
    and repopulation locks are removed by the same scan. Redis errors are
    logged and swallowed.
    """
    patterns = [f"{CACHE_PREFIX}*"]
    # Staging keys currently live inside the cache namespace, so one scan
    # covers them; only scan separately if the prefixes ever diverge.
    if not STAGING_PREFIX.startswith(CACHE_PREFIX):
        patterns.append(f"{STAGING_PREFIX}*")

    try:
        client = _get_redis_client()
        deleted = 0
        for pattern in patterns:
            cursor = 0
            while True:
                cursor, keys = client.scan(cursor, match=pattern, count=100)
                if keys:
                    # DEL returns how many keys were actually removed.
                    deleted += client.delete(*keys)
                if cursor == 0:  # SCAN signals completion with cursor 0
                    break
        if deleted:
            logger.info(f"Invalidated {deleted} listing cache entries")
    except redis.RedisError as e:
        logger.warning(f"Redis cache invalidation error: {e}")


def get_cache_age(query_params: QueryParameters) -> int | None:
    """Return the age in seconds of a cache entry, or None if not cached.

    The age is derived from the key's remaining TTL as
    ``CACHE_TTL_SECONDS - ttl``, so it measures time since the TTL was last
    set to ``CACHE_TTL_SECONDS``. A key without an expiry is reported as
    None, as are Redis errors (after logging).
    """
    try:
        client = _get_redis_client()
        key = make_cache_key(query_params)
        ttl = client.ttl(key)
        if ttl < 0:
            # -2 = key doesn't exist, -1 = no expiry
            return None
        return CACHE_TTL_SECONDS - ttl
    except redis.RedisError as e:
        logger.warning(f"Redis cache age check error: {e}")
        return None


def is_cache_stale(query_params: QueryParameters) -> bool:
    """Return True if the cache entry exists but is older than STALE_AFTER_SECONDS.

    Returns False when the entry's age cannot be determined (not cached, no
    expiry, or Redis error).
    """
    age = get_cache_age(query_params)
    if age is None:
        return False
    return age > STALE_AFTER_SECONDS


def acquire_repopulation_lock(query_params: QueryParameters) -> bool:
    """Try to acquire a lock to prevent concurrent repopulations.

    The lock key shares the hash suffix of the query's cache key under
    ``REPOPULATING_PREFIX`` and expires on its own after
    ``REPOPULATION_LOCK_TTL_SECONDS``.

    Returns True if the lock was acquired, False if another repopulation
    is already in progress for the same query (or on a Redis error).
    """
    try:
        client = _get_redis_client()
        key = make_cache_key(query_params)
        hash_suffix = key.removeprefix(CACHE_PREFIX)
        lock_key = f"{REPOPULATING_PREFIX}{hash_suffix}"
        # SET NX EX: create only if absent, with an expiry, in one command.
        acquired: bool = bool(
            client.set(lock_key, "1", nx=True, ex=REPOPULATION_LOCK_TTL_SECONDS)
        )
        return acquired
    except redis.RedisError as e:
        logger.warning(f"Redis repopulation lock error: {e}")
        return False