diff --git a/crawler/api/app.py b/crawler/api/app.py
index 1a24d0a..73ba146 100644
--- a/crawler/api/app.py
+++ b/crawler/api/app.py
@@ -16,6 +16,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query
 from api.auth import User
 from models.listing import QueryParameters
 from rec import districts
+from redis_repository import RedisRepository
 from repositories.listing_repository import ListingRepository
 from repositories.listing_repository import ListingRepository
 from database import engine
@@ -27,6 +28,7 @@ from alembic import command
 from alembic.config import Config
 from contextlib import asynccontextmanager
 from celery.exceptions import TaskRevokedError
+from celery_app import app as celery_app
 
 load_dotenv()
 logger = logging.getLogger("uvicorn")
@@ -86,6 +88,9 @@ async def refresh_listings(
         args=(query_parameters.model_dump_json(),),
         expires=expiry_time,
     )
+
+    redis_repository = RedisRepository.instance()
+    redis_repository.add_task_for_user(user, task.id)
     return {"task_id": task.id}
 
 
@@ -107,6 +112,15 @@ async def get_task_status(
     }
 
 
+@app.get("/api/tasks_for_user")
+async def get_tasks_for_user(
+    user: Annotated[User, Depends(get_current_user)],
+) -> list[str]:
+    redis_repository = RedisRepository.instance()
+    user_tasks = redis_repository.get_tasks_for_user(user)
+    return user_tasks
+
+
 @app.get("/api/get_districts")
 async def get_districts(
     user: Annotated[User, Depends(get_current_user)],
diff --git a/crawler/frontend/src/App.tsx b/crawler/frontend/src/App.tsx
index d12753a..621947b 100644
--- a/crawler/frontend/src/App.tsx
+++ b/crawler/frontend/src/App.tsx
@@ -60,6 +60,26 @@ const fetchData = async (user: User, baseQueyrUri: string, parameters: Parameter
 };
 
 
+const fetchActiveTasksForUser = async (user: User) => {
+  const accessToken = user?.access_token;
+  const response = await fetch(`/api/tasks_for_user`, {
+    method: 'GET',
+    headers: {
+      'Authorization': `Bearer ${accessToken}`, // Pass the token
+      'Content-Type': 'application/json',
+    },
+  });
+  if (!response.ok) {
+    throw new Error(`Failed to fetch active tasks for user: ${response.status}`);
+  }
+
+  const data =
+    await response.json();
+  return data;
+};
+
+
+
 function App() {
   const [listingData, setListingData] = useState({});
   const [taskID, setTaskID] = useState(null);
@@ -83,6 +103,17 @@ function App() {
     getUser().then(setUser);
   }, []);
 
+  useEffect(() => {
+    if (!user) {
+      return;
+    }
+    fetchActiveTasksForUser(user).then((tasks) => {
+      if (tasks && tasks.length > 0) { // empty array is truthy; guard so taskID is never set to undefined
+        setTaskID(tasks[0])
+      }
+    })
+  }, [user, taskID])
+
   if (!user) {
     return
   }
diff --git a/crawler/main.py b/crawler/main.py
index 791c47c..1135278 100644
--- a/crawler/main.py
+++ b/crawler/main.py
@@ -58,7 +58,7 @@ def listing_filter_options(func):
         "--max-price",
         default=999_999,
         help="Maximum price",
-        type=click.IntRange(min=0, max=40_000),  # 40k for renting
+        type=click.IntRange(min=0),  # no upper bound: sale listings exceed the old 40k rent cap
     )
     @click.option(
         "--district",
diff --git a/crawler/rec/query.py b/crawler/rec/query.py
index d43cd0a..aa16c33 100644
--- a/crawler/rec/query.py
+++ b/crawler/rec/query.py
@@ -44,7 +44,7 @@ async def detail_query(detail_id: int) -> dict[str, Any]:
     return await response.json()
 
 
-@retry(wait=wait_random(min=1, max=2), stop=stop_after_attempt(3))
+@retry(wait=wait_random(min=1, max=60), stop=stop_after_attempt(3))
 async def listing_query(
     *,
     page: int,
diff --git a/crawler/redis_repository.py b/crawler/redis_repository.py
new file mode 100644
index 0000000..9ece35e
--- /dev/null
+++ b/crawler/redis_repository.py
@@ -0,0 +1,60 @@
+from datetime import timedelta
+from functools import lru_cache
+import json
+from string import Template
+from typing import Any
+from api.auth import User
+import redis
+from celery_app import app
+
+
+class RedisRepository:
+    redis_client: redis.Redis
+    tasks_key_template = Template("user:${user_id}/tasks")  # Template uses $-placeholders; "{user_id}" would never substitute
+
+    def __init__(self):
+        redis_hostname = app.broker_connection().info()["hostname"]
+        redis_port = app.broker_connection().info()["port"]
+        self.redis_client = redis.Redis(
+            host=redis_hostname, port=redis_port, db=0, decode_responses=True
+        )  # decode_responses=True returns str, not bytes
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def instance():
+        # staticmethod must be outermost: on Python < 3.10 a staticmethod object is not callable, so lru_cache cannot wrap it
+        return RedisRepository()
+
+    def set_key(self, key: str, value, ttl: timedelta | None = None) -> None:
+        serialized_value = self.__serialize_value(value)
+        ttl = ttl or timedelta(hours=3)
+        # Single SET with EX is atomic: the key can never be left without an
+        # expiry (a separate SET-then-EXPIRE pair could, if we died in between).
+        self.redis_client.set(key, serialized_value, ex=ttl)
+
+    def get_key(self, key: str) -> Any | None:
+        serialized_value = self.redis_client.get(key)
+        if serialized_value is None:
+            return None
+        return self.__deserialize_value(serialized_value)
+
+    def add_task_for_user(self, user: User, task_id: str):
+        # Prepend task ID to the user's JSON task list. NOTE(review): read-modify-write, racy under concurrent writers — consider LPUSH.
+        current_tasks = (
+            self.get_key(self.tasks_key_template.substitute(user_id=user.email)) or []
+        )
+        self.set_key(
+            self.tasks_key_template.substitute(user_id=user.email),
+            [task_id] + current_tasks,
+        )
+
+    def get_tasks_for_user(self, user: User) -> list[str]:
+        # Return the user's JSON task list (most recent first), or [] if none stored
+        return (
+            self.get_key(self.tasks_key_template.substitute(user_id=user.email)) or []
+        )
+
+    def __serialize_value(self, value: Any) -> str:
+        return json.dumps(value)
+
+    def __deserialize_value(self, value_str) -> Any:
+        return json.loads(value_str)
diff --git a/crawler/repositories/listing_repository.py b/crawler/repositories/listing_repository.py
index d8ffc41..ea4dfe7 100644
--- a/crawler/repositories/listing_repository.py
+++ b/crawler/repositories/listing_repository.py
@@ -1,4 +1,5 @@
 from datetime import datetime, timedelta
+import logging
 from data_access import Listing
 from models.listing import (
     BuyListing,
@@ -13,6 +14,8 @@ from sqlmodel import Sequence, Session, and_, col, select
 from sqlmodel.sql.expression import SelectOfScalar
 from tqdm import tqdm
 
+logger = logging.getLogger("uvicorn.error")
+
 
 class ListingRepository:
     engine: Engine
@@ -51,7 +54,9 @@ class ListingRepository:
 
         with Session(self.engine) as session:
             # query = select(modelListing)
-            return list(session.exec(query).all())
+            rows = list(session.exec(query).all())
+            logger.debug("Found %d listings", len(rows))
+            return rows
 
     def _add_where_from_query_parameters(
         self,
diff --git a/crawler/tasks/listing_tasks.py b/crawler/tasks/listing_tasks.py
index ffea36a..59e99f8 100644
--- a/crawler/tasks/listing_tasks.py
+++ b/crawler/tasks/listing_tasks.py
@@ -18,7 +18,7 @@ detect_floorplan_module = importlib.import_module("4_detect_floorplan")
 
 logger = logging.getLogger("uvicorn.error")
 
-@app.task(bind=True)
+@app.task(bind=True, pydantic=True)
 def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
     parsed_parameters = QueryParameters.model_validate_json(parameters_json)
     asyncio.run(dump_listings_full(self, parsed_parameters))