from datetime import timedelta
from functools import lru_cache
import json
import os
from string import Template
from typing import Any, TypeVar

import redis

from api.auth import User
from celery_app import app

T = TypeVar("T")

# Default Redis logical DB for per-user state (task lists, WebAuthn
# challenges). Previously this lived on db0 alongside the Celery broker
# AND another app's kombu bindings (paperless-ngx). Moving to db3 isolates
# user state from broker traffic and prevents key collisions.
_DEFAULT_USER_DB = 3

# Namespace for every key written by this class so any other process
# sharing the Redis instance can't collide.
_KEY_PREFIX = "wrongmove:user:"

# Expiry applied by ``set_key`` when the caller passes no TTL (or a falsy,
# i.e. zero, timedelta).
_DEFAULT_TTL = timedelta(hours=3)


class RedisRepository:
    """Per-user state (e.g. submitted task IDs) stored in Redis.

    Values are stored as JSON strings, every key is namespaced under
    ``_KEY_PREFIX``, and every write carries a TTL.
    """

    redis_client: redis.Redis  # type: ignore[type-arg]

    # tasks_key_template is the *suffix* portion; set_key / get_key uniformly
    # prepend ``_KEY_PREFIX`` so every key this class writes is namespaced.
    tasks_key_template: Template = Template("${user_id}/tasks")

    def __init__(self) -> None:
        # Reuse the Celery broker's host/port, but on a separate logical DB.
        # Read the connection info once (the original built two throwaway
        # broker connections) and let ``with`` release the connection object.
        with app.broker_connection() as broker_conn:
            broker_info = broker_conn.info()
        redis_hostname: str = broker_info["hostname"]
        redis_port: int = broker_info["port"]
        db = int(os.getenv("REDIS_USER_DB", str(_DEFAULT_USER_DB)))
        # socket_keepalive + health_check_interval keep the connection
        # alive across the Redis HAProxy 30s idle timeout (see celery_app.py).
        self.redis_client = redis.Redis(
            host=redis_hostname,
            port=redis_port,
            db=db,
            decode_responses=True,  # return str, not bytes
            socket_keepalive=True,
            health_check_interval=25,
        )

    @staticmethod
    @lru_cache(maxsize=None)
    def instance() -> "RedisRepository":
        """Return the shared repository, constructing it on the first call."""
        return RedisRepository()

    @staticmethod
    def _prefixed(key: str) -> str:
        """Prepend the wrongmove user-namespace prefix if not already present."""
        if key.startswith(_KEY_PREFIX):
            return key
        return f"{_KEY_PREFIX}{key}"

    def _tasks_key(self, user: User) -> str:
        """Return the (unprefixed) task-list key for *user*."""
        return self.tasks_key_template.substitute(user_id=user.email)

    def set_key(self, key: str, value: Any, ttl: timedelta | None = None) -> None:
        """JSON-encode *value* and store it under the namespaced *key*.

        Args:
            key: Key name; ``_KEY_PREFIX`` is prepended if missing.
            value: Any JSON-serializable value.
            ttl: Expiry for the key. ``None`` or a zero timedelta means
                ``_DEFAULT_TTL`` (3 hours). Each write resets the expiry.
                Should be at least one second; Redis rejects a non-positive
                EX value.
        """
        full_key = self._prefixed(key)
        serialized_value = self.__serialize_value(value)
        # SET with EX is a single atomic command; a separate SET + EXPIRE
        # leaves the key with no TTL at all if the process dies (or the
        # connection drops) between the two calls.
        self.redis_client.set(
            full_key,
            serialized_value,
            ex=ttl or _DEFAULT_TTL,
        )

    def get_key(self, key: str) -> Any | None:
        """Return the JSON-decoded value under *key*, or ``None`` if absent."""
        serialized_value = self.redis_client.get(self._prefixed(key))
        if serialized_value is None:
            return None
        return self.__deserialize_value(serialized_value)

    def add_task_for_user(self, user: User, task_id: str) -> None:
        """Prepend *task_id* to the user's task list (newest first).

        The list is a JSON array stored as one string value, so this is a
        read-modify-write; the write also resets the list's TTL.
        """
        # NOTE(review): not atomic. Two concurrent calls for the same user
        # can both read the old list, and one task ID is lost. Consider a
        # WATCH/MULTI transaction or a native Redis list (LPUSH) if that
        # matters.
        current_tasks: list[str] = self.get_tasks_for_user(user)
        self.set_key(self._tasks_key(user), [task_id] + current_tasks)

    def get_tasks_for_user(self, user: User) -> list[str]:
        """Return the user's task IDs, newest first (empty if none stored)."""
        return self.get_key(self._tasks_key(user)) or []

    def remove_task_for_user(self, user: User, task_id: str) -> bool:
        """Remove every occurrence of *task_id* from the user's task list.

        Returns:
            ``True`` if the task was present (the list is rewritten, which
            also resets its TTL), ``False`` if it was not found.
        """
        current_tasks: list[str] = self.get_tasks_for_user(user)
        if task_id not in current_tasks:
            return False
        # Same read-modify-write caveat as add_task_for_user.
        self.set_key(
            self._tasks_key(user),
            [t for t in current_tasks if t != task_id],
        )
        return True

    def clear_tasks_for_user(self, user: User) -> int:
        """Delete the user's task list.

        Returns the number of tasks cleared.
        """
        count = len(self.get_tasks_for_user(user))
        self.redis_client.delete(self._prefixed(self._tasks_key(user)))
        return count

    def __serialize_value(self, value: Any) -> str:
        """Encode *value* as a JSON string for storage."""
        return json.dumps(value)

    def __deserialize_value(self, value_str: str) -> Any:
        """Decode a JSON string previously written by ``__serialize_value``."""
        return json.loads(value_str)