migrate background tasks to celery

This commit is contained in:
Viktor Barzin 2025-06-22 21:18:52 +00:00
parent efe3248c07
commit 93129333e6
No known key found for this signature in database
GPG key ID: 4056458DBDBF8863
7 changed files with 106 additions and 101 deletions

View file

@@ -1,4 +1,6 @@
import asyncio
import dataclasses
import json
import logging
import logging.config
from pathlib import Path
@ -15,6 +17,7 @@ from api.worker import (
task_queue,
task_results,
)
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query
from api.auth import User
from models.listing import QueryParameters
@ -22,11 +25,14 @@ from repositories.listing_repository import ListingRepository
from repositories.listing_repository import ListingRepository
from database import engine
from fastapi.middleware.cors import CORSMiddleware
from tasks import listing_tasks
from ui_exporter import export_immoweb
from alembic import command
from alembic.config import Config
from contextlib import asynccontextmanager
load_dotenv()
logger = logging.getLogger("uvicorn")
@ -43,9 +49,6 @@ logger = logging.getLogger("uvicorn")
# app = FastAPI(lifespan=lifespan)
app = FastAPI()
# Start worker thread
WorkerManager(DumpListingsWorker()).start()
# Allow CORS (for React frontend)
app.add_middleware(
@ -81,19 +84,9 @@ async def refresh_listings(
user: Annotated[User, Depends(get_current_user)],
query_parameters: Annotated[QueryParameters, Query()],
) -> dict[str, str]:
# Submit processing task
task_id = str(uuid.uuid4())
task_results[task_id] = {"status": TaskStatus.QUEUED}
try:
task_queue.put_nowait(
(task_id, query_parameters),
)
except queue.Full:
raise HTTPException(
status_code=429,
detail="Already processing at maximum capacity. Please try again later",
)
return {"task_id": task_id}
# TODO: rate limit
task = listing_tasks.dump_listings_task.delay(query_parameters.json())
return {"task_id": task.id}
@app.get("/api/task_status")
@ -101,6 +94,9 @@ async def get_task_status(
user: Annotated[User, Depends(get_current_user)],
task_id: str,
) -> dict[str, str]:
if task_id not in task_results:
return {"status": "not_found"}
return task_results[task_id]
task_result = listing_tasks.dump_listings_task.AsyncResult(task_id)
return {
"task_id": task_id,
"status": task_result.status,
"result": json.dumps(task_result.result),
}

View file

@@ -1,79 +0,0 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
import atexit
import enum
import importlib
from pathlib import Path
from queue import Queue
import queue
from threading import Thread
from database import engine
from repositories.listing_repository import ListingRepository
# Loaded via importlib because the module filename starts with a digit and
# therefore cannot appear in a plain ``import`` statement.
dump_listings_module = importlib.import_module("1_dump_listings")

# In-memory task queue and results store
# NOTE(review): both are process-local, so results are lost on restart and
# this cannot scale past one API process.
task_queue = Queue(maxsize=1)  # Disallow multiple in flight requests for now
task_results = {}
class WorkerManager:
def __init__(self, worker: WorkerThread):
super().__init__()
self._worker = worker
atexit.register(asyncio.run, self.stop())
async def stop(self) -> None:
await self._worker.stop()
self._worker_thread.join()
def start(self):
self._worker_thread = Thread(
target=asyncio.run, args=[self._worker.run()], daemon=True
)
self._worker_thread.start()
class WorkerThread(ABC):
    """Interface for cooperative background workers driven by WorkerManager.

    Subclassing ABC makes the @abstractmethod markers actually enforced;
    without it (as originally written) the decorators had no effect and an
    incomplete worker could be instantiated silently.
    """

    @abstractmethod
    async def stop(self) -> None:
        """Signal the worker to exit its run() loop as soon as practical."""

    @abstractmethod
    async def run(self) -> None:
        """Worker main loop; executed on a dedicated thread by the manager."""
class DumpListingsWorker(WorkerThread):
    """Background worker that drains task_queue and runs the full listing
    dump pipeline, publishing outcomes into the module-level task_results."""

    # Cooperative shutdown flag, re-checked once per loop iteration.
    should_stop = False

    async def stop(self) -> None:
        """Request the run() loop to exit after its current iteration."""
        self.should_stop = True

    async def run(self) -> None:  # global results is updated
        """Background worker that processes tasks.

        Pulls ``(task_id, QueryParameters)`` tuples off ``task_queue`` and
        records per-task status in ``task_results``.
        """
        repository = ListingRepository(engine)
        data_dir_path = Path("data/rs")
        while not self.should_stop:
            try:
                # Bounded wait so the stop flag is re-checked periodically.
                # The original blocking get() meant stop() never took effect
                # while the queue was empty.
                task_id, task_data = task_queue.get(timeout=1)
            except queue.Empty:
                continue
            task_results[task_id] = {"status": TaskStatus.PROCESSING}
            query_parameters = task_data
            try:
                new_listings = await dump_listings_module.dump_listings_full(
                    query_parameters, repository, data_dir_path
                )
                task_results[task_id] = {
                    # Use the enum for consistency with the other states
                    # (was the bare string "completed"); StrEnum keeps the
                    # serialized value identical.
                    "status": TaskStatus.COMPLETED,
                    "result": f"Fetched {len(new_listings)} new listings for query {task_data}",
                }
            except Exception as e:
                task_results[task_id] = {"status": TaskStatus.FAILED, "error": str(e)}
            finally:
                task_queue.task_done()
class TaskStatus(enum.StrEnum):
QUEUED = "queued"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"

31
crawler/celery_app.py Normal file
View file

@@ -0,0 +1,31 @@
"""Celery application shared by the API and the worker processes.

Run as a script to health-check the broker connection.
"""
import os
import sys

from celery import Celery
from dotenv import load_dotenv

load_dotenv()

# Broker/result-backend URLs come from the environment, falling back to a
# local Redis for development.
app = Celery(
    "celery_app",
    broker=os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0"),
    backend=os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"),
    include=["tasks.listing_tasks"],
)

app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="UTC",
    enable_utc=True,
)


def _broker_healthcheck() -> int:
    """Return 0 when the broker is reachable, 1 otherwise."""
    try:
        with app.connection() as conn:
            conn.ensure_connection(max_retries=0)
    except Exception as e:
        print(f"Broker connection failed: {e}")
        return 1
    print("Broker connection OK")
    return 0


if __name__ == "__main__":
    sys.exit(_broker_healthcheck())

View file

@ -8,7 +8,7 @@ from pathlib import Path
from typing import Any, Dict, List
from pydantic import BaseModel
from rec import routing
from sqlmodel import JSON, SQLModel, Field, String
from sqlmodel import JSON, TEXT, SQLModel, Field, String
@dataclass(frozen=True)
@ -63,7 +63,7 @@ class Listing(SQLModel, table=False):
longtitude: float = Field(nullable=False)
latitude: float = Field(nullable=False)
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
price_history_json: str = Field(sa_type=String)
price_history_json: str = Field(sa_type=TEXT)
listing_site: ListingSite = Field(nullable=False)
last_seen: datetime = Field(default_factory=datetime.now, nullable=False)
photo_thumbnail: str | None = Field(default=None, nullable=True)
@ -74,7 +74,7 @@ class Listing(SQLModel, table=False):
default_factory=dict, sa_type=JSON, nullable=False
)
routing_info_json: str = Field(
sa_type=String, nullable=True, default=None
sa_type=TEXT, nullable=True, default=None
) # Store as JSON string for simplicity
@property

View file

@@ -0,0 +1,45 @@
import asyncio
import importlib
import logging
from pathlib import Path
import time
from typing import Any
from celery import Celery, Task
from celery_app import app
from models.listing import Listing, QueryParameters
from repositories.listing_repository import ListingRepository
from database import engine
from tasks.task_state import TaskStatus
# These pipeline scripts have digit-prefixed filenames, so they must be
# loaded via importlib rather than a plain ``import`` statement.
dump_listings_module = importlib.import_module("1_dump_listings")
dump_images_module = importlib.import_module("3_dump_images")
detect_floorplan_module = importlib.import_module("4_detect_floorplan")

# NOTE(review): "uvicorn.error" may be unconfigured inside the Celery worker
# process (it is uvicorn's logger, not celery's) — verify logs actually show.
logger = logging.getLogger("uvicorn.error")
@app.task(bind=True)
def dump_listings_task(self: Task, parameters_json: str) -> dict[str, Any]:
    """Celery entry point: run the full listing-dump pipeline.

    Args:
        parameters_json: JSON-serialized QueryParameters — Celery payloads
            are JSON, so the model is passed as a string and re-validated.

    Returns:
        A JSON-serializable result stored in the Celery backend.
    """
    parsed_parameters = QueryParameters.model_validate_json(parameters_json)
    # The pipeline is async; drive it to completion on a private event loop.
    new_listings = asyncio.run(dump_listings_full(self, parsed_parameters))
    # Surface how much work was done; the original discarded the return
    # value of dump_listings_full entirely. Extra key is backward-compatible
    # for consumers that json.dumps() the result.
    return {"progress": 1, "new_listings": len(new_listings)}
async def dump_listings_full(self: Task, parameters: QueryParameters) -> list[Listing]:
    """Fetches all listings, images as well as detects floorplans"""
    # Progress is reported via custom Celery task states so the API's
    # task-status endpoint can surface the current pipeline stage.
    self.update_state(state="FETCHING_LISTINGS", meta={"progress": 0.1})
    repository = ListingRepository(engine)
    new_listings = await dump_listings_module.dump_listings(parameters, repository)
    self.update_state(state="FETCHING_FLOORPLANS", meta={"progress": 0.3})
    logger.debug(f"Upserted {len(new_listings)} new listings")
    logger.debug("Starting to fetch floorplans")
    await dump_images_module.dump_images(repository)
    self.update_state(state="RUNNING_OCR_ON_FLOORPLANS", meta={"progress": 0.6})
    logger.debug("Completed fetching floorplans")
    logger.debug("Starting floorplan detection")
    await detect_floorplan_module.detect_floorplan(repository)
    logger.debug("Completed floorplan detection")
    # refresh listings
    listings = await repository.get_listings(parameters)  # this can be better
    # NOTE(review): this assumes dump_listings() returned values comparable
    # to Listing.id (e.g. a list of ids) — confirm against its definition.
    # Membership in a list is also O(n*m); a set would help if sizes grow.
    new_listings = [l for l in listings if l.id in new_listings]
    return new_listings

View file

@@ -0,0 +1,8 @@
import enum
class TaskStatus(enum.StrEnum):
QUEUED = "queued"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"

View file

@ -1,9 +1,12 @@
import json
import logging
import pathlib
from rec.query import QueryParameters
from repositories.listing_repository import ListingRepository
logger = logging.getLogger("uvicorn.error")
async def export_immoweb(
repository: ListingRepository,
@ -15,6 +18,7 @@ async def export_immoweb(
query_parameters=query_parameters,
limit=limit,
)
logger.info(f"Fetched {len(listings)} listings")
# Convert listings to immoweb format
immoweb_listings = []