migrate background tasks to celery
This commit is contained in:
parent
efe3248c07
commit
93129333e6
7 changed files with 106 additions and 101 deletions
|
|
@ -1,4 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -15,6 +17,7 @@ from api.worker import (
|
||||||
task_queue,
|
task_queue,
|
||||||
task_results,
|
task_results,
|
||||||
)
|
)
|
||||||
|
from dotenv import load_dotenv
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Query
|
from fastapi import Depends, FastAPI, HTTPException, Query
|
||||||
from api.auth import User
|
from api.auth import User
|
||||||
from models.listing import QueryParameters
|
from models.listing import QueryParameters
|
||||||
|
|
@ -22,11 +25,14 @@ from repositories.listing_repository import ListingRepository
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
from database import engine
|
from database import engine
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from tasks import listing_tasks
|
||||||
from ui_exporter import export_immoweb
|
from ui_exporter import export_immoweb
|
||||||
from alembic import command
|
from alembic import command
|
||||||
from alembic.config import Config
|
from alembic.config import Config
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -43,9 +49,6 @@ logger = logging.getLogger("uvicorn")
|
||||||
# app = FastAPI(lifespan=lifespan)
|
# app = FastAPI(lifespan=lifespan)
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Start worker thread
|
|
||||||
WorkerManager(DumpListingsWorker()).start()
|
|
||||||
|
|
||||||
|
|
||||||
# Allow CORS (for React frontend)
|
# Allow CORS (for React frontend)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
@ -81,19 +84,9 @@ async def refresh_listings(
|
||||||
user: Annotated[User, Depends(get_current_user)],
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
query_parameters: Annotated[QueryParameters, Query()],
|
query_parameters: Annotated[QueryParameters, Query()],
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
# Submit processing task
|
# TODO: rate limit
|
||||||
task_id = str(uuid.uuid4())
|
task = listing_tasks.dump_listings_task.delay(query_parameters.json())
|
||||||
task_results[task_id] = {"status": TaskStatus.QUEUED}
|
return {"task_id": task.id}
|
||||||
try:
|
|
||||||
task_queue.put_nowait(
|
|
||||||
(task_id, query_parameters),
|
|
||||||
)
|
|
||||||
except queue.Full:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429,
|
|
||||||
detail="Already processing at maximum capacity. Please try again later",
|
|
||||||
)
|
|
||||||
return {"task_id": task_id}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/task_status")
|
@app.get("/api/task_status")
|
||||||
|
|
@ -101,6 +94,9 @@ async def get_task_status(
|
||||||
user: Annotated[User, Depends(get_current_user)],
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
task_id: str,
|
task_id: str,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
if task_id not in task_results:
|
task_result = listing_tasks.dump_listings_task.AsyncResult(task_id)
|
||||||
return {"status": "not_found"}
|
return {
|
||||||
return task_results[task_id]
|
"task_id": task_id,
|
||||||
|
"status": task_result.status,
|
||||||
|
"result": json.dumps(task_result.result),
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,79 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from abc import abstractmethod
|
|
||||||
import asyncio
|
|
||||||
import atexit
|
|
||||||
import enum
|
|
||||||
import importlib
|
|
||||||
from pathlib import Path
|
|
||||||
from queue import Queue
|
|
||||||
import queue
|
|
||||||
from threading import Thread
|
|
||||||
from database import engine
|
|
||||||
|
|
||||||
from repositories.listing_repository import ListingRepository
|
|
||||||
|
|
||||||
dump_listings_module = importlib.import_module("1_dump_listings")
|
|
||||||
|
|
||||||
# In-memory task queue and results store
|
|
||||||
task_queue = Queue(maxsize=1) # Disallow multiple in flight requests for now
|
|
||||||
task_results = {}
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerManager:
|
|
||||||
def __init__(self, worker: WorkerThread):
|
|
||||||
super().__init__()
|
|
||||||
self._worker = worker
|
|
||||||
atexit.register(asyncio.run, self.stop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
await self._worker.stop()
|
|
||||||
self._worker_thread.join()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._worker_thread = Thread(
|
|
||||||
target=asyncio.run, args=[self._worker.run()], daemon=True
|
|
||||||
)
|
|
||||||
self._worker_thread.start()
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerThread:
|
|
||||||
@abstractmethod
|
|
||||||
async def stop(self) -> None: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run(self) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
class DumpListingsWorker(WorkerThread):
|
|
||||||
should_stop = False
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
self.should_stop = True
|
|
||||||
|
|
||||||
async def run(self) -> None: # global results is updated
|
|
||||||
"""Background worker that processes tasks"""
|
|
||||||
repository = ListingRepository(engine)
|
|
||||||
data_dir_path = Path("data/rs")
|
|
||||||
while not self.should_stop:
|
|
||||||
task_id, task_data = task_queue.get()
|
|
||||||
task_results[task_id] = {"status": TaskStatus.PROCESSING}
|
|
||||||
query_parameters = task_data
|
|
||||||
try:
|
|
||||||
new_listings = await dump_listings_module.dump_listings_full(
|
|
||||||
query_parameters, repository, data_dir_path
|
|
||||||
)
|
|
||||||
task_results[task_id] = {
|
|
||||||
"status": "completed",
|
|
||||||
"result": f"Fetched {len(new_listings)} new listings for query {task_data}",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
task_results[task_id] = {"status": TaskStatus.FAILED, "error": str(e)}
|
|
||||||
finally:
|
|
||||||
task_queue.task_done()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(enum.StrEnum):
|
|
||||||
QUEUED = "queued"
|
|
||||||
PROCESSING = "processing"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
31
crawler/celery_app.py
Normal file
31
crawler/celery_app.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
import sys
|
||||||
|
from celery import Celery
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
app = Celery(
|
||||||
|
"celery_app",
|
||||||
|
broker=os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
|
||||||
|
backend=os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"),
|
||||||
|
include=["tasks.listing_tasks"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.conf.update(
|
||||||
|
task_serializer="json",
|
||||||
|
result_serializer="json",
|
||||||
|
accept_content=["json"],
|
||||||
|
timezone="UTC",
|
||||||
|
enable_utc=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
with app.connection() as conn:
|
||||||
|
conn.ensure_connection(max_retries=0)
|
||||||
|
print("Broker connection OK")
|
||||||
|
sys.exit(0)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Broker connection failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from rec import routing
|
from rec import routing
|
||||||
from sqlmodel import JSON, SQLModel, Field, String
|
from sqlmodel import JSON, TEXT, SQLModel, Field, String
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -63,7 +63,7 @@ class Listing(SQLModel, table=False):
|
||||||
longtitude: float = Field(nullable=False)
|
longtitude: float = Field(nullable=False)
|
||||||
latitude: float = Field(nullable=False)
|
latitude: float = Field(nullable=False)
|
||||||
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
|
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
|
||||||
price_history_json: str = Field(sa_type=String)
|
price_history_json: str = Field(sa_type=TEXT)
|
||||||
listing_site: ListingSite = Field(nullable=False)
|
listing_site: ListingSite = Field(nullable=False)
|
||||||
last_seen: datetime = Field(default_factory=datetime.now, nullable=False)
|
last_seen: datetime = Field(default_factory=datetime.now, nullable=False)
|
||||||
photo_thumbnail: str | None = Field(default=None, nullable=True)
|
photo_thumbnail: str | None = Field(default=None, nullable=True)
|
||||||
|
|
@ -74,7 +74,7 @@ class Listing(SQLModel, table=False):
|
||||||
default_factory=dict, sa_type=JSON, nullable=False
|
default_factory=dict, sa_type=JSON, nullable=False
|
||||||
)
|
)
|
||||||
routing_info_json: str = Field(
|
routing_info_json: str = Field(
|
||||||
sa_type=String, nullable=True, default=None
|
sa_type=TEXT, nullable=True, default=None
|
||||||
) # Store as JSON string for simplicity
|
) # Store as JSON string for simplicity
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
45
crawler/tasks/listing_tasks.py
Normal file
45
crawler/tasks/listing_tasks.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from celery import Celery, Task
|
||||||
|
from celery_app import app
|
||||||
|
from models.listing import Listing, QueryParameters
|
||||||
|
from repositories.listing_repository import ListingRepository
|
||||||
|
from database import engine
|
||||||
|
from tasks.task_state import TaskStatus
|
||||||
|
|
||||||
|
dump_listings_module = importlib.import_module("1_dump_listings")
|
||||||
|
dump_images_module = importlib.import_module("3_dump_images")
|
||||||
|
detect_floorplan_module = importlib.import_module("4_detect_floorplan")
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True)
|
||||||
|
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
|
||||||
|
parsed_parameters = QueryParameters.model_validate_json(parameters_json)
|
||||||
|
asyncio.run(dump_listings_full(self, parsed_parameters))
|
||||||
|
return {"progress": 1}
|
||||||
|
|
||||||
|
|
||||||
|
async def dump_listings_full(self: Task, parameters: QueryParameters) -> list[Listing]:
|
||||||
|
"""Fetches all listings, images as well as detects floorplans"""
|
||||||
|
self.update_state(state="FETCHING_LISTINGS", meta={"progress": 0.1})
|
||||||
|
repository = ListingRepository(engine)
|
||||||
|
new_listings = await dump_listings_module.dump_listings(parameters, repository)
|
||||||
|
self.update_state(state="FETCHING_FLOORPLANS", meta={"progress": 0.3})
|
||||||
|
logger.debug(f"Upserted {len(new_listings)} new listings")
|
||||||
|
logger.debug("Starting to fetch floorplans")
|
||||||
|
await dump_images_module.dump_images(repository)
|
||||||
|
self.update_state(state="RUNNING_OCR_ON_FLOORPLANS", meta={"progress": 0.6})
|
||||||
|
logger.debug("Completed fetching floorplans")
|
||||||
|
logger.debug("Starting floorplan detection")
|
||||||
|
await detect_floorplan_module.detect_floorplan(repository)
|
||||||
|
logger.debug("Completed floorplan detection")
|
||||||
|
# refresh listings
|
||||||
|
listings = await repository.get_listings(parameters) # this can be better
|
||||||
|
new_listings = [l for l in listings if l.id in new_listings]
|
||||||
|
return new_listings
|
||||||
8
crawler/tasks/task_state.py
Normal file
8
crawler/tasks/task_state.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(enum.StrEnum):
|
||||||
|
QUEUED = "queued"
|
||||||
|
PROCESSING = "processing"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from rec.query import QueryParameters
|
from rec.query import QueryParameters
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
||||||
|
|
||||||
async def export_immoweb(
|
async def export_immoweb(
|
||||||
repository: ListingRepository,
|
repository: ListingRepository,
|
||||||
|
|
@ -15,6 +18,7 @@ async def export_immoweb(
|
||||||
query_parameters=query_parameters,
|
query_parameters=query_parameters,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
logger.info(f"Fetched {len(listings)} listings")
|
||||||
|
|
||||||
# Convert listings to immoweb format
|
# Convert listings to immoweb format
|
||||||
immoweb_listings = []
|
immoweb_listings = []
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue