save user queries in redis so that user can refresh the page and still come back to their latest task

This commit is contained in:
Viktor Barzin 2025-07-06 12:02:25 +00:00
parent a055c92dea
commit d4b22deda0
No known key found for this signature in database
GPG key ID: 4056458DBDBF8863
7 changed files with 114 additions and 4 deletions

View file

@ -16,6 +16,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query
from api.auth import User from api.auth import User
from models.listing import QueryParameters from models.listing import QueryParameters
from rec import districts from rec import districts
from redis_repository import RedisRepository
from repositories.listing_repository import ListingRepository from repositories.listing_repository import ListingRepository
from repositories.listing_repository import ListingRepository from repositories.listing_repository import ListingRepository
from database import engine from database import engine
@ -27,6 +28,7 @@ from alembic import command
from alembic.config import Config from alembic.config import Config
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from celery.exceptions import TaskRevokedError from celery.exceptions import TaskRevokedError
from celery_app import app as celery_app
load_dotenv() load_dotenv()
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
@ -86,6 +88,9 @@ async def refresh_listings(
args=(query_parameters.model_dump_json(),), args=(query_parameters.model_dump_json(),),
expires=expiry_time, expires=expiry_time,
) )
redis_repository = RedisRepository.instance()
redis_repository.add_task_for_user(user, task.id)
return {"task_id": task.id} return {"task_id": task.id}
@ -107,6 +112,15 @@ async def get_task_status(
} }
@app.get("/api/tasks_for_user")
async def get_tasks_for_user(
user: Annotated[User, Depends(get_current_user)],
) -> list[str]:
redis_repository = RedisRepository.instance()
user_tasks = redis_repository.get_tasks_for_user(user)
return user_tasks
@app.get("/api/get_districts") @app.get("/api/get_districts")
async def get_districts( async def get_districts(
user: Annotated[User, Depends(get_current_user)], user: Annotated[User, Depends(get_current_user)],

View file

@ -60,6 +60,26 @@ const fetchData = async (user: User, baseQueyrUri: string, parameters: Parameter
}; };
const fetchActiveTasksForUser = async (user: User) => {
const accessToken = user?.access_token;
const response = await fetch(`/api/tasks_for_user`, {
method: 'GET',
headers: {
'Authorization': `Bearer ${accessToken}`, // Pass the token
'Content-Type': 'application/json',
},
});
if (!response.ok) {
throw new Error(`Failed to fetch active tasks for user: ${response.status}`);
}
const data =
await response.json();
return data;
};
function App() { function App() {
const [listingData, setListingData] = useState({}); const [listingData, setListingData] = useState({});
const [taskID, setTaskID] = useState<string | null>(null); const [taskID, setTaskID] = useState<string | null>(null);
@ -83,6 +103,17 @@ function App() {
getUser().then(setUser); getUser().then(setUser);
}, []); }, []);
useEffect(() => {
if (!user) {
return;
}
fetchActiveTasksForUser(user).then((tasks) => {
if (tasks) {
setTaskID(tasks[0])
}
})
}, [user, taskID])
if (!user) { if (!user) {
return <LoginModal isOpen={user === null} /> return <LoginModal isOpen={user === null} />
} }

View file

@ -58,7 +58,7 @@ def listing_filter_options(func):
"--max-price", "--max-price",
default=999_999, default=999_999,
help="Maximum price", help="Maximum price",
type=click.IntRange(min=0, max=40_000), # 40k for renting type=click.IntRange(min=0), # 40k for renting
) )
@click.option( @click.option(
"--district", "--district",

View file

@ -44,7 +44,7 @@ async def detail_query(detail_id: int) -> dict[str, Any]:
return await response.json() return await response.json()
@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3)) @retry(wait=wait_random(min=1, max=60), stop=stop_after_attempt(3))
async def listing_query( async def listing_query(
*, *,
page: int, page: int,

View file

@ -0,0 +1,60 @@
from datetime import timedelta
from functools import lru_cache
import json
from string import Template
from typing import Any
from api.auth import User
import redis
from celery_app import app
class RedisRepository:
redis_client: redis.Redis
tasks_key_template = Template("user:{user_id}/tasks")
def __init__(self):
redis_hostname = app.broker_connection().info()["hostname"]
redis_port = app.broker_connection().info()["port"]
self.redis_client = redis.Redis(
host=redis_hostname, port=redis_port, db=0, decode_responses=True
) # decode_responses=True returns str, not bytes
@lru_cache(maxsize=None)
@staticmethod
def instance():
return RedisRepository()
def set_key(self, key: str, value, ttl: timedelta | None = None) -> None:
serialized_value = self.__serialize_value(value)
self.redis_client.set(key, serialized_value)
ttl = ttl or timedelta(hours=3)
self.redis_client.expire(key, ttl)
def get_key(self, key: str) -> Any | None:
serialized_value = self.redis_client.get(key)
if serialized_value is None:
return None
return self.__deserialize_value(serialized_value)
def add_task_for_user(self, user: User, task_id: str):
# Add the task ID to the Redis set for the user
current_tasks = (
self.get_key(self.tasks_key_template.substitute(user_id=user.email)) or []
)
self.set_key(
self.tasks_key_template.substitute(user_id=user.email),
[task_id] + current_tasks,
)
def get_tasks_for_user(self, user: User) -> list[str]:
# Get the task IDs from the Redis set for the user
return (
self.get_key(self.tasks_key_template.substitute(user_id=user.email)) or []
)
def __serialize_value(self, value: Any) -> str:
return json.dumps(value)
def __deserialize_value(self, value_str) -> Any:
return json.loads(value_str)

View file

@ -1,4 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging
from data_access import Listing from data_access import Listing
from models.listing import ( from models.listing import (
BuyListing, BuyListing,
@ -13,6 +14,8 @@ from sqlmodel import Sequence, Session, and_, col, select
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
from tqdm import tqdm from tqdm import tqdm
logger = logging.getLogger("uvicorn.error")
class ListingRepository: class ListingRepository:
engine: Engine engine: Engine
@ -51,7 +54,9 @@ class ListingRepository:
with Session(self.engine) as session: with Session(self.engine) as session:
# query = select(modelListing) # query = select(modelListing)
return list(session.exec(query).all()) rows = list(session.exec(query).all())
logging.debug(f"Found {len(rows)} listings")
return rows
def _add_where_from_query_parameters( def _add_where_from_query_parameters(
self, self,

View file

@ -18,7 +18,7 @@ detect_floorplan_module = importlib.import_module("4_detect_floorplan")
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.error")
@app.task(bind=True) @app.task(bind=True, pydantic=True)
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]: def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
parsed_parameters = QueryParameters.model_validate_json(parameters_json) parsed_parameters = QueryParameters.model_validate_json(parameters_json)
asyncio.run(dump_listings_full(self, parsed_parameters)) asyncio.run(dump_listings_full(self, parsed_parameters))