save user queries in redis so that user can refresh the page and still come back to their latest task
This commit is contained in:
parent
a055c92dea
commit
d4b22deda0
7 changed files with 114 additions and 4 deletions
|
|
@ -16,6 +16,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query
|
||||||
from api.auth import User
|
from api.auth import User
|
||||||
from models.listing import QueryParameters
|
from models.listing import QueryParameters
|
||||||
from rec import districts
|
from rec import districts
|
||||||
|
from redis_repository import RedisRepository
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
from database import engine
|
from database import engine
|
||||||
|
|
@ -27,6 +28,7 @@ from alembic import command
|
||||||
from alembic.config import Config
|
from alembic.config import Config
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from celery.exceptions import TaskRevokedError
|
from celery.exceptions import TaskRevokedError
|
||||||
|
from celery_app import app as celery_app
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
@ -86,6 +88,9 @@ async def refresh_listings(
|
||||||
args=(query_parameters.model_dump_json(),),
|
args=(query_parameters.model_dump_json(),),
|
||||||
expires=expiry_time,
|
expires=expiry_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
redis_repository = RedisRepository.instance()
|
||||||
|
redis_repository.add_task_for_user(user, task.id)
|
||||||
return {"task_id": task.id}
|
return {"task_id": task.id}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -107,6 +112,15 @@ async def get_task_status(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/tasks_for_user")
|
||||||
|
async def get_tasks_for_user(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> list[str]:
|
||||||
|
redis_repository = RedisRepository.instance()
|
||||||
|
user_tasks = redis_repository.get_tasks_for_user(user)
|
||||||
|
return user_tasks
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/get_districts")
|
@app.get("/api/get_districts")
|
||||||
async def get_districts(
|
async def get_districts(
|
||||||
user: Annotated[User, Depends(get_current_user)],
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,26 @@ const fetchData = async (user: User, baseQueyrUri: string, parameters: Parameter
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
const fetchActiveTasksForUser = async (user: User) => {
|
||||||
|
const accessToken = user?.access_token;
|
||||||
|
const response = await fetch(`/api/tasks_for_user`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: {
|
||||||
|
'Authorization': `Bearer ${accessToken}`, // Pass the token
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to fetch active tasks for user: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data =
|
||||||
|
await response.json();
|
||||||
|
return data;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
function App() {
|
function App() {
|
||||||
const [listingData, setListingData] = useState({});
|
const [listingData, setListingData] = useState({});
|
||||||
const [taskID, setTaskID] = useState<string | null>(null);
|
const [taskID, setTaskID] = useState<string | null>(null);
|
||||||
|
|
@ -83,6 +103,17 @@ function App() {
|
||||||
getUser().then(setUser);
|
getUser().then(setUser);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!user) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
fetchActiveTasksForUser(user).then((tasks) => {
|
||||||
|
if (tasks) {
|
||||||
|
setTaskID(tasks[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}, [user, taskID])
|
||||||
|
|
||||||
if (!user) {
|
if (!user) {
|
||||||
return <LoginModal isOpen={user === null} />
|
return <LoginModal isOpen={user === null} />
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ def listing_filter_options(func):
|
||||||
"--max-price",
|
"--max-price",
|
||||||
default=999_999,
|
default=999_999,
|
||||||
help="Maximum price",
|
help="Maximum price",
|
||||||
type=click.IntRange(min=0, max=40_000), # 40k for renting
|
type=click.IntRange(min=0), # 40k for renting
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--district",
|
"--district",
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ async def detail_query(detail_id: int) -> dict[str, Any]:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3))
|
@retry(wait=wait_random(min=1, max=60), stop=stop_after_attempt(3))
|
||||||
async def listing_query(
|
async def listing_query(
|
||||||
*,
|
*,
|
||||||
page: int,
|
page: int,
|
||||||
|
|
|
||||||
60
crawler/redis_repository.py
Normal file
60
crawler/redis_repository.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
from datetime import timedelta
|
||||||
|
from functools import lru_cache
|
||||||
|
import json
|
||||||
|
from string import Template
|
||||||
|
from typing import Any
|
||||||
|
from api.auth import User
|
||||||
|
import redis
|
||||||
|
from celery_app import app
|
||||||
|
|
||||||
|
|
||||||
|
class RedisRepository:
|
||||||
|
redis_client: redis.Redis
|
||||||
|
tasks_key_template = Template("user:{user_id}/tasks")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
redis_hostname = app.broker_connection().info()["hostname"]
|
||||||
|
redis_port = app.broker_connection().info()["port"]
|
||||||
|
self.redis_client = redis.Redis(
|
||||||
|
host=redis_hostname, port=redis_port, db=0, decode_responses=True
|
||||||
|
) # decode_responses=True returns str, not bytes
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
@staticmethod
|
||||||
|
def instance():
|
||||||
|
return RedisRepository()
|
||||||
|
|
||||||
|
def set_key(self, key: str, value, ttl: timedelta | None = None) -> None:
|
||||||
|
serialized_value = self.__serialize_value(value)
|
||||||
|
self.redis_client.set(key, serialized_value)
|
||||||
|
|
||||||
|
ttl = ttl or timedelta(hours=3)
|
||||||
|
self.redis_client.expire(key, ttl)
|
||||||
|
|
||||||
|
def get_key(self, key: str) -> Any | None:
|
||||||
|
serialized_value = self.redis_client.get(key)
|
||||||
|
if serialized_value is None:
|
||||||
|
return None
|
||||||
|
return self.__deserialize_value(serialized_value)
|
||||||
|
|
||||||
|
def add_task_for_user(self, user: User, task_id: str):
|
||||||
|
# Add the task ID to the Redis set for the user
|
||||||
|
current_tasks = (
|
||||||
|
self.get_key(self.tasks_key_template.substitute(user_id=user.email)) or []
|
||||||
|
)
|
||||||
|
self.set_key(
|
||||||
|
self.tasks_key_template.substitute(user_id=user.email),
|
||||||
|
[task_id] + current_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tasks_for_user(self, user: User) -> list[str]:
|
||||||
|
# Get the task IDs from the Redis set for the user
|
||||||
|
return (
|
||||||
|
self.get_key(self.tasks_key_template.substitute(user_id=user.email)) or []
|
||||||
|
)
|
||||||
|
|
||||||
|
def __serialize_value(self, value: Any) -> str:
|
||||||
|
return json.dumps(value)
|
||||||
|
|
||||||
|
def __deserialize_value(self, value_str) -> Any:
|
||||||
|
return json.loads(value_str)
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
import logging
|
||||||
from data_access import Listing
|
from data_access import Listing
|
||||||
from models.listing import (
|
from models.listing import (
|
||||||
BuyListing,
|
BuyListing,
|
||||||
|
|
@ -13,6 +14,8 @@ from sqlmodel import Sequence, Session, and_, col, select
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
||||||
|
|
||||||
class ListingRepository:
|
class ListingRepository:
|
||||||
engine: Engine
|
engine: Engine
|
||||||
|
|
@ -51,7 +54,9 @@ class ListingRepository:
|
||||||
|
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
# query = select(modelListing)
|
# query = select(modelListing)
|
||||||
return list(session.exec(query).all())
|
rows = list(session.exec(query).all())
|
||||||
|
logging.debug(f"Found {len(rows)} listings")
|
||||||
|
return rows
|
||||||
|
|
||||||
def _add_where_from_query_parameters(
|
def _add_where_from_query_parameters(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ detect_floorplan_module = importlib.import_module("4_detect_floorplan")
|
||||||
logger = logging.getLogger("uvicorn.error")
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
||||||
|
|
||||||
@app.task(bind=True)
|
@app.task(bind=True, pydantic=True)
|
||||||
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
|
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
|
||||||
parsed_parameters = QueryParameters.model_validate_json(parameters_json)
|
parsed_parameters = QueryParameters.model_validate_json(parameters_json)
|
||||||
asyncio.run(dump_listings_full(self, parsed_parameters))
|
asyncio.run(dump_listings_full(self, parsed_parameters))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue