60 lines
2 KiB
Python
60 lines
2 KiB
Python
from datetime import timedelta
|
|
from functools import lru_cache
|
|
import json
|
|
from string import Template
|
|
from typing import Any
|
|
from api.auth import User
|
|
import redis
|
|
from celery_app import app
|
|
|
|
|
|
class RedisRepository:
    """JSON-over-Redis store that talks to the same Redis host as the Celery broker.

    Values are JSON-encoded before they are written and decoded when read.
    Per-user task IDs are stored as a JSON list (not a native Redis set/list)
    under a key built from ``tasks_key_template``.
    """

    redis_client: redis.Redis

    # string.Template placeholders must be written ``$name`` / ``${name}``; a bare
    # ``{user_id}`` is not substituted, which would give every user the same key.
    tasks_key_template = Template("user:${user_id}/tasks")

    # Expiry applied by ``set_key`` when the caller passes no (or a falsy) ttl.
    default_ttl = timedelta(hours=3)

    def __init__(self):
        """Create a Redis client pointing at the Celery broker's host and port.

        Only ``hostname`` and ``port`` are read from the broker connection info.
        The client always uses database 0 and ``decode_responses=True`` so that
        reads return ``str`` rather than ``bytes``.
        """
        # A single broker Connection object supplies both fields; the ``with``
        # block releases it instead of leaving two unreleased Connection objects.
        with app.broker_connection() as broker_connection:
            broker_info = broker_connection.info()
        self.redis_client = redis.Redis(
            host=broker_info["hostname"],
            port=broker_info["port"],
            db=0,
            decode_responses=True,  # return str, not bytes
        )

    @staticmethod
    @lru_cache(maxsize=None)
    def instance():
        """Return a cached ``RedisRepository``, constructing it on the first call.

        ``@staticmethod`` is the outermost decorator so ``lru_cache`` wraps the
        plain zero-argument function; that makes both ``RedisRepository.instance()``
        and ``repo.instance()`` valid calls.
        """
        return RedisRepository()

    def set_key(self, key: str, value, ttl: timedelta | None = None) -> None:
        """Store ``value`` as JSON under ``key`` with an expiry.

        Args:
            key: Redis key to write.
            value: Any JSON-serializable object.
            ttl: Time-to-live. ``None`` (or any falsy timedelta) falls back to
                ``default_ttl``. redis-py truncates a timedelta to whole seconds.

        Raises:
            TypeError: If ``value`` is not JSON-serializable.
        """
        serialized_value = self.__serialize_value(value)
        # Passing the expiry to SET writes value and TTL in one atomic command;
        # a separate EXPIRE call could be skipped, leaving a key with no TTL.
        self.redis_client.set(key, serialized_value, ex=ttl or self.default_ttl)

    def get_key(self, key: str) -> Any | None:
        """Return the JSON-decoded value stored at ``key``, or ``None`` if it is absent."""
        serialized_value = self.redis_client.get(key)
        if serialized_value is None:
            return None
        return self.__deserialize_value(serialized_value)

    def add_task_for_user(self, user: User, task_id: str):
        """Prepend ``task_id`` to the JSON list of task IDs kept for ``user``.

        The list is keyed by ``user.email``. Because the list is rewritten via
        ``set_key``, its TTL is reset to ``default_ttl`` on every call.
        """
        key = self.__tasks_key(user)
        current_tasks = self.get_key(key) or []
        # NOTE(review): plain read-modify-write (no WATCH/MULTI); two concurrent
        # calls for the same user can lose one of the task IDs.
        self.set_key(key, [task_id] + current_tasks)

    def get_tasks_for_user(self, user: User) -> list[str]:
        """Return the task IDs recorded for ``user``, newest first; ``[]`` if none."""
        return self.get_key(self.__tasks_key(user)) or []

    def __tasks_key(self, user: User) -> str:
        """Build the Redis key that holds ``user``'s task-ID list."""
        return self.tasks_key_template.substitute(user_id=user.email)

    def __serialize_value(self, value: Any) -> str:
        """Encode ``value`` as a JSON string."""
        return json.dumps(value)

    def __deserialize_value(self, value_str: str) -> Any:
        """Decode a JSON string produced by ``__serialize_value``."""
        return json.loads(value_str)
|