from datetime import timedelta
from functools import lru_cache
import json
from string import Template
from typing import Any
from api.auth import User
import redis
from celery_app import app


class RedisRepository:
    """JSON-over-Redis key/value store, plus a per-user list of task IDs.

    The Redis server address is taken from the Celery app's broker connection.
    NOTE(review): only hostname and port are read from the broker info; any
    password or database index in the broker URL is ignored and db 0 is always
    used — confirm this matches the deployment.
    """

    redis_client: redis.Redis

    # string.Template placeholders use ``$name`` / ``${name}`` syntax. A
    # ``{user_id}`` pattern is never substituted, which made every user share
    # the single literal key "user:{user_id}/tasks".
    tasks_key_template = Template("user:${user_id}/tasks")

    # Expiry applied by set_key when no (or a zero) ttl is given.
    DEFAULT_TTL = timedelta(hours=3)

    def __init__(self):
        # Open the broker connection once and release it as soon as the
        # address has been read (previously two connections were opened and
        # neither was released).
        with app.broker_connection() as connection:
            info = connection.info()
        self.redis_client = redis.Redis(
            host=info["hostname"],
            port=info["port"],
            db=0,
            # decode_responses=True makes reads return str rather than bytes.
            decode_responses=True,
        )

    @staticmethod
    @lru_cache(maxsize=None)
    def instance() -> "RedisRepository":
        """Return the shared repository, constructing it on first call.

        ``staticmethod`` must be the outermost decorator. If the order is
        reversed, the ``lru_cache`` wrapper binds like a plain function, so
        ``obj.instance()`` would pass ``obj`` through and fail. Before Python
        3.10, a bare staticmethod object is not callable at all.
        """
        return RedisRepository()

    def set_key(self, key: str, value: Any, ttl: timedelta | None = None) -> None:
        """Store ``value`` as JSON under ``key`` with an expiry.

        Args:
            key: Redis key to write; any existing value is overwritten.
            value: A JSON-serializable value.
            ttl: Time to live. ``None`` (or a zero timedelta) falls back to
                ``DEFAULT_TTL`` (3 hours).
        """
        serialized_value = self.__serialize_value(value)
        # Send the expiry with SET itself so the value and its TTL are written
        # atomically. A separate EXPIRE call could leave a key with no TTL if
        # it never ran.
        self.redis_client.set(key, serialized_value, ex=ttl or self.DEFAULT_TTL)

    def get_key(self, key: str) -> Any | None:
        """Return the JSON-decoded value stored at ``key``, or ``None`` if absent."""
        serialized_value = self.redis_client.get(key)
        if serialized_value is None:
            return None
        return self.__deserialize_value(serialized_value)

    def _tasks_key(self, user: User) -> str:
        """Build the Redis key that holds ``user``'s task IDs (keyed by e-mail)."""
        return self.tasks_key_template.substitute(user_id=user.email)

    def add_task_for_user(self, user: User, task_id: str):
        """Prepend ``task_id`` to the user's stored task-ID list.

        The list is stored as a JSON array in a plain string key (not a native
        Redis list or set). Its TTL is reset to the default on every call.

        NOTE(review): this is a read-modify-write without WATCH/MULTI, so two
        concurrent calls for the same user can lose one of the IDs.
        """
        key = self._tasks_key(user)
        current_tasks = self.get_key(key) or []
        self.set_key(key, [task_id] + current_tasks)

    def get_tasks_for_user(self, user: User) -> list[str]:
        """Return the user's task IDs, most recently added first ([] if none)."""
        return self.get_key(self._tasks_key(user)) or []

    def __serialize_value(self, value: Any) -> str:
        """Encode ``value`` as a JSON string for storage."""
        return json.dumps(value)

    def __deserialize_value(self, value_str: str) -> Any:
        """Decode a JSON string read back from Redis."""
        return json.loads(value_str)