from datetime import timedelta
from functools import lru_cache
import json
from string import Template
from typing import Any, TypeVar

import redis

from api.auth import User
from celery_app import app

T = TypeVar("T")


class RedisRepository:
    """Small JSON key/value store on top of the Redis instance used as Celery broker.

    Values are serialized with ``json.dumps`` and stored under plain string
    keys. Per-user task IDs are kept as a JSON-encoded list (newest first)
    under the key ``user:<email>/tasks``.
    """

    redis_client: redis.Redis  # type: ignore[type-arg]

    # string.Template placeholders need a leading "$"; a bare "{user_id}" is
    # left untouched by substitute(), which would make every user share the
    # same literal key.
    tasks_key_template: Template = Template("user:${user_id}/tasks")

    # Expiry applied by set_key() when the caller does not provide one.
    default_ttl: timedelta = timedelta(hours=3)

    def __init__(self) -> None:
        """Connect to db 0 of the host/port reported by the Celery broker connection."""
        broker_info = app.broker_connection().info()
        redis_hostname: str = broker_info["hostname"]
        redis_port: int = broker_info["port"]
        self.redis_client = redis.Redis(
            host=redis_hostname, port=redis_port, db=0, decode_responses=True
        )  # decode_responses=True returns str, not bytes

    @staticmethod
    @lru_cache(maxsize=None)
    def instance() -> "RedisRepository":
        """Return the shared repository, created on first call and cached afterwards."""
        return RedisRepository()

    def set_key(self, key: str, value: Any, ttl: timedelta | None = None) -> None:
        """Store *value* (JSON-encoded) under *key* with an expiry.

        Args:
            key: Redis key to write.
            value: Any JSON-serializable object.
            ttl: Time to live; a missing or zero value falls back to
                ``default_ttl`` (3 hours).

        Raises:
            TypeError: If *value* is not JSON-serializable.
        """
        serialized_value = self.__serialize_value(value)
        # Write value and expiry in one SET command so a failure between two
        # separate calls can never leave a key without a TTL.
        self.redis_client.set(key, serialized_value, ex=ttl or self.default_ttl)

    def get_key(self, key: str) -> Any | None:
        """Return the decoded value stored under *key*, or ``None`` if the key is absent."""
        serialized_value = self.redis_client.get(key)
        if serialized_value is None:
            return None
        return self.__deserialize_value(serialized_value)

    def add_task_for_user(self, user: User, task_id: str) -> None:
        """Prepend *task_id* to the user's stored task list.

        The list is read, modified and written back, which also renews the
        key's expiry to the default TTL.
        """
        current_tasks: list[str] = self.get_tasks_for_user(user)
        self.set_key(self._tasks_key(user), [task_id] + current_tasks)

    def get_tasks_for_user(self, user: User) -> list[str]:
        """Return the user's stored task IDs, or an empty list if none are stored."""
        return self.get_key(self._tasks_key(user)) or []

    def remove_task_for_user(self, user: User, task_id: str) -> bool:
        """Remove a specific task from the user's task list.

        Every occurrence of *task_id* is dropped.

        Returns:
            ``True`` if the task was present and removed, ``False`` otherwise
            (in which case nothing is written).
        """
        current_tasks: list[str] = self.get_tasks_for_user(user)
        if task_id not in current_tasks:
            return False
        updated_tasks = [t for t in current_tasks if t != task_id]
        self.set_key(self._tasks_key(user), updated_tasks)
        return True

    def clear_tasks_for_user(self, user: User) -> int:
        """Clear all tasks for a user. Returns the number of tasks cleared."""
        # Count is taken from the list as read just before the delete.
        count = len(self.get_tasks_for_user(user))
        self.redis_client.delete(self._tasks_key(user))
        return count

    def _tasks_key(self, user: User) -> str:
        """Build the Redis key holding *user*'s task list (keyed by e-mail)."""
        return self.tasks_key_template.substitute(user_id=user.email)

    def __serialize_value(self, value: Any) -> str:
        """Encode *value* as a JSON string."""
        return json.dumps(value)

    def __deserialize_value(self, value_str: str) -> Any:
        """Decode a JSON string previously produced by ``__serialize_value``."""
        return json.loads(value_str)