Compare commits


10 commits

Author SHA1 Message Date
Viktor Barzin
ced9a153bd
replace pymysql with mysqlclient as it is "better"; also fix an issue in the ui exporter that had wrong imports 2025-10-18 09:58:55 +00:00
Kadir
0801aaf200
More ruff fixes (#2)
* adding ruff auto check for pull requests as well as fixing all ruff errors

* More ruff fixes: forgot half of the ruff checks

Forgot to do a git add all :D

---------

Co-authored-by: Kadir <git@k8n.dev>
2025-09-14 19:44:03 +01:00
Kadir
4c23acdb55
adding ruff auto check for pull requests as well as fixing all ruff errors (#1)
Co-authored-by: Kadir <git@k8n.dev>
2025-09-14 19:40:18 +01:00
Kadir
b1e0a414cf Used ruff to cleanup
I hope it just works, as I cannot test whether things work
2025-09-14 19:02:30 +01:00
Kadir
8d11e4a81c Fix deps and move to a better local environment
- Cleaned up some deps and moved them to the dev section
- Moved from mysqlclient to pymysql, which is pure Python and does not require the OS to have the correct MySQL client library
- Added a podman compose file so we can have all dependencies in one place easily without the need to install redis or a database locally

For podman:
- install podman
- install podman-compose (a `poetry sync` should pull it in, I think)
- run `podman-compose up` to start the containers
2025-09-14 19:00:26 +01:00
Viktor Barzin
520286aaee
update readme with instructions on how to run everything 2025-08-28 21:39:20 +00:00
Viktor Barzin
62329a2eb4
add redis container in start.sh in case it is not running in dev mode 2025-08-28 20:48:42 +00:00
Viktor Barzin
ff57117054
do not watch files with uvicorn as the datadir is quite big and monitor is very very slow 2025-08-23 22:36:57 +00:00
Viktor Barzin
526f4fc0c3
update last seen property when processing listings to refresh data 2025-08-23 22:36:37 +00:00
Viktor Barzin
480957dc72
add timeout when fetching details and use new entrypoint for task processing 2025-08-23 22:20:42 +00:00
29 changed files with 776 additions and 532 deletions

.github/workflows/ruff.yml (new file, 39 lines)

@@ -0,0 +1,39 @@
name: Run Ruff and Auto-merge
on:
pull_request:
types: [opened, synchronize, reopened]
jobs:
ruff-check:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
# Fetch all history for diffing
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11' # Or your desired Python version
- name: Install Ruff
run: pip install ruff
- name: Get changed files
id: changed_files
run: |
# Get a list of changed files between the base and head commits of the PR
git diff --name-only --diff-filter=d ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} > changed_files.txt
# Filter for Python files
grep -E '\.py$' changed_files.txt > python_files.txt || true  # grep exits non-zero when nothing matches; don't fail the step
# Remove newlines and join with spaces
echo "files=$(tr '\n' ' ' < python_files.txt)" >> $GITHUB_OUTPUT
- name: Run Ruff on changed files
if: steps.changed_files.outputs.files != ''
run: |
# The ruff command will only run if there are Python files to check
ruff check ${{ steps.changed_files.outputs.files }}
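The same check can be reproduced locally before pushing; a minimal sketch, assuming the PR base is `origin/main`:
```bash
# Diff against the assumed base, keep only Python files, and lint just those;
# `xargs -r` skips ruff entirely when no Python files changed.
git diff --name-only --diff-filter=d origin/main...HEAD | grep -E '\.py$' | xargs -r ruff check
```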

@@ -5,10 +5,11 @@ import json
import logging
import pathlib
from typing import Any
from rec.query import detail_query, listing_query, QueryParameters
from listing_processor import ListingProcessor
from rec.query import listing_query
from models.listing import QueryParameters
from rec.districts import get_districts
from repositories import ListingRepository
import requests
from tqdm.asyncio import tqdm
from data_access import Listing
from models import Listing as modelListing
@@ -27,15 +28,15 @@ async def dump_listings_full(
"""Fetches all listings, images as well as detects floorplans"""
new_listings = await dump_listings(parameters, repository, data_dir)
logger.debug(f"Upserted {len(new_listings)} new listings")
logger.debug("Starting to fetch floorplans")
await dump_images_module.dump_images(repository, image_base_path=data_dir)
logger.debug("Completed fetching floorplans")
logger.debug("Starting floorplan detection")
await detect_floorplan_module.detect_floorplan(repository)
logger.debug("Completed floorplan detection")
# logger.debug("Starting to fetch floorplans")
# await dump_images_module.dump_images(repository, image_base_path=data_dir)
# logger.debug("Completed fetching floorplans")
# logger.debug("Starting floorplan detection")
# await detect_floorplan_module.detect_floorplan(repository)
# logger.debug("Completed floorplan detection")
# refresh listings
listings = await repository.get_listings(parameters) # this can be better
new_listings = [l for l in listings if l.id in new_listings]
new_listings = [x for x in listings if x.id in new_listings]
return new_listings
@@ -77,29 +78,20 @@ async def dump_listings(
listings.append(listing)
# if listing is already in db, do not fetch details again
all_listing_ids = [l.id for l in await repository.get_listings()]
all_listing_ids = [x.id for x in await repository.get_listings()]
missing_listing = [
listing for listing in listings if listing.identifier not in all_listing_ids
]
logger.debug(f"Fetching details for {len(missing_listing)} missing listings")
listing_details = await tqdm.gather(
*[
_fetch_detail_with_semaphore(semaphore, listing.identifier)
for listing in missing_listing
],
desc="Fetching details (only missing)",
missing_ids = [listing.identifier for listing in missing_listing]
missing_ids = [missing_ids[0]]
listing_processor = ListingProcessor(repository)
logger.info(f"Starting processing {len(missing_listing)} new listings")
processed_listings = await tqdm.gather(
*[listing_processor.process_listing(id) for id in missing_ids]
)
for listing, detail in zip(missing_listing, listing_details):
listing._details_object = detail
filtered_listings = [x for x in processed_listings if x is not None]
logger.debug("Dumping listings to fs")
await dump_listings_to_fs(missing_listing)
logger.debug("Upserting listings in db")
model_listings = await repository.upsert_listings_legacy(
missing_listing
) # upsert in db
return model_listings
return filtered_listings
async def _fetch_listings_with_semaphore(
@@ -113,7 +105,7 @@ async def _fetch_listings_with_semaphore(
# we do 10 queries each with an increment in price range so we send more queries but each
# has a smaller chance of returning more than 1.5k results
number_of_steps = 10
number_of_steps = 1
price_step = parameters.max_price // number_of_steps
for step in range(number_of_steps):
@@ -157,14 +149,6 @@ async def _fetch_listings_with_semaphore(
return result
async def _fetch_detail_with_semaphore(
semaphore: asyncio.Semaphore, listing_id: int
) -> dict[str, Any]:
async with semaphore:
d = await detail_query(listing_id)
return d
async def dump_listings_to_fs(listings: list[Listing]) -> None:
for listing in tqdm(listings, desc="Dumping listings to FS"):
listing.dump_listing()

@@ -1,5 +1,4 @@
import asyncio
import json
from pathlib import Path
import aiohttp
from repositories import ListingRepository

@@ -36,7 +36,7 @@ async def update_routing_info(
routes_data = routing.transit_route(
listing.latitude,
listing.longtitude,
listing.longitude,
destination_mode.destination_address,
destination_mode.travel_mode,
)

@@ -1,8 +1,31 @@
# Setup
1. Install deps:
```bash
pip install -r requirements.txt
poetry install && cp .env.sample .env
```
2. Check `.env` if you want to customize settings for broker and db
3. Run `./start.sh`.
This starts the backend.
To start the frontend:
```
cd frontend && cp .env.sample .env
```
Change `DEV_HOST` to any name you want to use to access the web interface.
Next, set up the DNS record (e.g. in your `/etc/hosts` file); see the sketch at the end of this section.
This is important because auth is done via an external [authentik] service that needs to redirect to a hostname.
Run `./start.sh`.
This starts a Caddy proxy with the correct certificates and an npm dev server.
All requests going to the frontend are forwarded to the npm dev server, and the ones for the backend (under `/api/*`) are forwarded to the backend service.
Lastly, reach out to Viktor to allowlist your `DEV_HOST` so that authentik can authorize callbacks to your host.
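A minimal sketch of the `/etc/hosts` entry (the hostname is a placeholder; use whatever you set as `DEV_HOST`):
```
127.0.0.1   wrongmove.dev.local
```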
# Formatting

@@ -1,12 +1,8 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from models import Listing, User # Import all models here
from database import engine
import sqlmodel
from sqlmodel import SQLModel
# this is the Alembic Config object, which provides

@@ -29,7 +29,7 @@ def upgrade() -> None:
sa.Column('square_meters', sa.Float(), nullable=True),
sa.Column('agency', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('council_tax_band', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('longtitude', sa.Float(), nullable=False),
sa.Column('longitude', sa.Float(), nullable=False),
sa.Column('latitude', sa.Float(), nullable=False),
sa.Column('price_history_json', sa.TEXT(), nullable=False),
sa.Column('listing_site', sa.Enum('RIGHTMOVE', name='listingsite'), nullable=False),
@@ -49,7 +49,7 @@ def upgrade() -> None:
sa.Column('square_meters', sa.Float(), nullable=True),
sa.Column('agency', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('council_tax_band', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('longtitude', sa.Float(), nullable=False),
sa.Column('longitude', sa.Float(), nullable=False),
sa.Column('latitude', sa.Float(), nullable=False),
sa.Column('price_history_json', sa.TEXT(), nullable=False),
sa.Column('listing_site', sa.Enum('RIGHTMOVE', name='listingsite'), nullable=False),

@@ -8,7 +8,6 @@ Create Date: 2025-06-30 22:54:11.706618
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.

@@ -1,25 +1,18 @@
import asyncio
import dataclasses
from datetime import datetime, timedelta
import json
import logging
import logging.config
from pathlib import Path
import queue
from threading import Thread
from typing import Annotated
import uuid
from api.auth import get_current_user
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi import Depends, FastAPI, Query
from api.auth import User
from models.listing import QueryParameters
from notifications import send_notification
from rec import districts
from redis_repository import RedisRepository
from repositories.listing_repository import ListingRepository
from database import engine
from fastapi.middleware.cors import CORSMiddleware
@@ -30,6 +23,10 @@ from alembic.config import Config
from contextlib import asynccontextmanager
from celery.exceptions import TaskRevokedError
from celery_app import app as celery_app
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from api.metrics import metrics_app # Import the Prometheus ASGI app
from opentelemetry.metrics import get_meter
load_dotenv()
logger = logging.getLogger("uvicorn")
@@ -47,6 +44,16 @@ logger = logging.getLogger("uvicorn")
# app = FastAPI(lifespan=lifespan)
app = FastAPI()
app.mount("/metrics", metrics_app)
meter = get_meter(__name__)
request_counter = meter.create_counter(
name="custom_request_count",
description="Number of times /hello was called",
)
hist = meter.create_histogram(
name="custom_request_duration",
description="Duration of /hello requests in seconds",
)
# Allow CORS (for React frontend)
@@ -60,6 +67,8 @@ app.add_middleware(
@app.get("/api/status")
async def get_status():
request_counter.add(1, {"method": "GET", "path": "/status"})
hist.record(1.5, {"method": "GET", "path": "/status"})
return {"status": "OK"}
@@ -113,7 +122,7 @@ async def get_task_status(
task_result = listing_tasks.dump_listings_task.AsyncResult(task_id)
try:
result = json.dumps(task_result.result)
except:
except Exception:
result = str(task_result.result)
return {
@@ -137,3 +146,6 @@ async def get_districts(
user: Annotated[User, Depends(get_current_user)],
) -> dict[str, str]:
return districts.get_districts()
FastAPIInstrumentor.instrument_app(app)
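Since the Prometheus ASGI app is mounted at `/metrics`, the new counter and histogram can be spot-checked once the API is up; a sketch, assuming the port 5001 used in `start.sh` (the exporter may prefix the metric names):
```bash
curl -s localhost:5001/metrics | grep custom_request
```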

@@ -1,4 +1,3 @@
from datetime import timedelta
from api.config import AUTHENTIK_URL, OIDC_CACHE_TTL, OIDC_CLIENT_ID, OIDC_METADATA_URL
from cachetools import TTLCache
from fastapi import Depends, HTTPException

@@ -3,7 +3,7 @@ from collections import defaultdict
from dataclasses import dataclass
import json
import pathlib
from typing import Any, List, Dict
from typing import Any, List
from models.listing import ListingSite, PriceHistoryItem
from rec import floorplan, routing
import re
@@ -399,13 +399,7 @@ class Listing:
for item in data
]
@property
def longtitude(self) -> float:
return self.detailobject["property"]["longitude"]
@property
def latitude(self) -> float:
return self.detailobject["property"]["latitude"]
@property
def listing_site(self) -> ListingSite:

@@ -195,18 +195,18 @@ export function Map(
.call(xAxis);
}
function openListingsDialog(longtitude: number, latitude: number) {
function openListingsDialog(longitude: number, latitude: number) {
const searchBuffer = 0.001 // ~100m
const properties = heatmap._tree.search({
minX: longtitude - searchBuffer,
maxX: longtitude + searchBuffer,
minX: longitude - searchBuffer,
maxX: longitude + searchBuffer,
minY: latitude - searchBuffer,
maxY: latitude + searchBuffer
})
if (properties.length > 0) {
const listingDialogPopup = getListingDialog(properties);
new mapboxgl.Popup()
.setLngLat([longtitude, latitude])
.setLngLat([longitude, latitude])
.setHTML(renderToString(listingDialogPopup))
.setMaxWidth("500px")
.addTo(mapRef.current);

@@ -17,9 +17,11 @@ logger = logging.getLogger("uvicorn.error")
class ListingProcessor:
semaphore: asyncio.Semaphore
process_steps: list[Step]
listing_repository: ListingRepository
def __init__(self, listing_repository: ListingRepository):
self.semaphore = asyncio.Semaphore(20)
self.listing_repository = listing_repository
# Register new processing steps here
# Order is important
self.process_steps = [
@@ -29,11 +31,16 @@ class ListingProcessor:
]
async def process_listing(self, listing_id: int) -> Listing | None:
await self.listing_repository.mark_seen(listing_id)
listing = None
for step in self.process_steps:
if await step.needs_processing(listing_id):
async with self.semaphore:
listing = await step.process(listing_id)
try:
listing = await step.process(listing_id)
except Exception as e:
logger.error(f"Failed to process {listing_id=}: {e}")
return None
return listing
async def listing_exists(self, listing_id: int) -> bool: ...
@@ -106,7 +113,7 @@ class FetchListingDetailsStep(Step):
council_tax_band=listing_details["property"]["councilTaxInfo"]["content"][
0
]["value"],
longtitude=listing_details["property"]["longitude"],
longitude=listing_details["property"]["longitude"],
latitude=listing_details["property"]["latitude"],
price_history_json="{}", # TODO: should upsert from existing
listing_site=ListingSite.RIGHTMOVE,
@@ -145,14 +152,15 @@ class FetchImagesStep(Step):
all_floorplans = listing.additional_info.get("property", {}).get(
"floorplans", []
)
for floorplan in all_floorplans:
url = floorplan["url"]
client_timeout = aiohttp.ClientTimeout(total=30)
for floorplan_obj in all_floorplans:
url = floorplan_obj["url"]
picname = url.split("/")[-1]
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
if floorplan_path.exists():
continue
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
async with session.get(url, timeout=client_timeout) as response:
if response.status == 404:
return listing
if response.status != 200:
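A minimal sketch of how the processor is driven, mirroring the call site in the dump module above (assumes an initialized `ListingRepository` and a list of `missing_ids`):
```python
# Each listing runs through the registered steps in order; a step failure is
# logged and yields None, so callers filter those results out.
processor = ListingProcessor(repository)
results = await tqdm.gather(*[processor.process_listing(i) for i in missing_ids])
processed = [x for x in results if x is not None]
```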

@@ -1,12 +1,10 @@
import asyncio
from datetime import datetime
import json
import os
import pathlib
import click
import importlib
import listing_processor
from models.listing import FurnishType, ListingType, QueryParameters
from rec.districts import get_districts
from data_access import Listing
@@ -187,7 +185,6 @@ def dump_images(ctx: click.core.Context):
@cli.command()
@click.pass_context
def detect_floorplan(ctx: click.core.Context):
data_dir = ctx.obj["data_dir"]
click.echo(f"Running detect_floorplan for listings stored in {engine.url}")
repository = ListingRepository(engine=engine)
asyncio.run(detect_floorplan_module.detect_floorplan(repository))

@@ -4,11 +4,10 @@ import dataclasses
from datetime import datetime, timedelta
import enum
import json
from pathlib import Path
from typing import Any, Dict, List
from pydantic import BaseModel
from rec import routing
from sqlmodel import JSON, TEXT, SQLModel, Field, String
from sqlmodel import JSON, TEXT, SQLModel, Field
@dataclass(frozen=True)
@@ -60,7 +59,7 @@ class Listing(SQLModel, table=False):
square_meters: float | None = Field(default=None, nullable=True, index=True)
agency: str | None = Field(default=None, nullable=True)
council_tax_band: str | None = Field(default=None, nullable=True)
longtitude: float = Field(nullable=False)
longitude: float = Field(nullable=False)
latitude: float = Field(nullable=False)
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
price_history_json: str = Field(sa_type=TEXT)

@@ -1,5 +1,4 @@
from abc import abstractmethod
from enum import StrEnum
import apprise
from functools import lru_cache
import os

@@ -0,0 +1,24 @@
version: "3.8"
services:
redis:
image: redis:8
container_name: redis-container
ports:
- "6379:6379"
volumes:
- ./data/redis:/data
command: ["redis-server", "--appendonly", "yes"]
mysql:
image: mysql:9
container_name: mysql-container
ports:
- "3306:3306"
environment:
MYSQL_ROOT_PASSWORD: wtfviktordidyoubuildsomuch
MYSQL_DATABASE: wrongmove
MYSQL_USER: wrongmoveuser
MYSQL_PASSWORD: wrongmovepass
volumes:
- ./data/mysql:/var/lib/mysql
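A hedged sketch of driving this compose file (assumes podman and podman-compose are installed; the dev-group `podman-compose` should arrive via `poetry sync`):
```bash
podman-compose up -d    # start the redis and mysql containers in the background
podman-compose down     # stop them; state persists under ./data/
```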

crawler/poetry.lock (generated; 997 lines changed; diff suppressed because it is too large)

@@ -33,7 +33,6 @@ response = requests.get(
verify=False,
)
import requests
headers = {
"Host": "api.rightmove.co.uk",

@@ -61,7 +61,7 @@ def extract_time(d):
distance_per_transit[step["travelMode"]] += step.get("distanceMeters", 0)
print(
f"dis {distance}, dur {duration}, duration per transit {dict(duration_per_transit)}, distance per transit {dict(distance_per_transit)}"
f"dis {distance}, dur {duration}, duration per transit {dict(duration_per_transit)}, distance per transit {dict(distance_per_transit)}, duration_static {duration_static}"
)

@@ -14,7 +14,6 @@ pillow = "^10.2.0"
numpy = "^1.26.4"
transformers = "^4.38.2"
pytesseract = "^0.3.10"
jupyterlab = "^4.1.4"
pandas = "^2.2.1"
geopy = "^2.4.1"
matplotlib = "^3.10.0"
@@ -28,7 +27,6 @@ tenacity = "^9.1.2"
fastapi = {extras = ["standard"], version = "^0.115.12"}
pyjwt = "^2.10.1"
cryptography = "^45.0.4"
mysqlclient = "^2.2.7"
celery = "^5.5.3"
redis = "^6.2.0"
watchdog = "^6.0.0"
@@ -38,10 +36,20 @@ opentelemetry-sdk = "^1.36.0"
opentelemetry-exporter-prometheus = "^0.57b0"
opentelemetry-instrumentation-fastapi = "^0.57b0"
opentelemetry-instrumentation-sqlalchemy = "^0.57b0"
mysqlclient = "^2.2.7"
[tool.poetry.group.dev.dependencies]
ipdb = "^0.13.13"
jupyterlab = "^4.4.7"
podman-compose = "^1.5.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
# List of rules (error codes) to ignore
lint.ignore = [
"E741", # Ambigious name
]
exclude = ["*.ipynb"]

@@ -1,10 +1,7 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime
import enum
from typing import Any
import aiohttp
from models.listing import FurnishType, ListingType, QueryParameters
from models.listing import FurnishType, ListingType
from rec import districts
from tenacity import retry, stop_after_attempt, wait_random

@@ -1,8 +0,0 @@
def parse_listing_json_entry(d):
id = d["identifier"]
# address = d['address']
propertyType = d["propertyType"]
price = d["price"]
latitude = d["latitude"]
longitude = d["longitude"]
updated_date = d["updateDate"]

@@ -10,7 +10,7 @@ from models.listing import (
RentListing,
)
from sqlalchemy import Engine
from sqlmodel import Sequence, Session, and_, col, select
from sqlmodel import Session, select
from sqlmodel.sql.expression import SelectOfScalar
from tqdm import tqdm
@@ -160,7 +160,7 @@ class ListingRepository:
square_meters=await listing.sqm_ocr(),
agency=listing.agency,
council_tax_band=listing.councilTaxBand,
longtitude=listing.longtitude,
longitude=listing.longitude,
latitude=listing.latitude,
price_history_json=modelListing.serialize_price_history(
listing.priceHistory
@@ -180,7 +180,7 @@
square_meters=await listing.sqm_ocr(),
agency=listing.agency,
council_tax_band=listing.councilTaxBand,
longtitude=listing.longtitude,
longitude=listing.longitude,
latitude=listing.latitude,
price_history_json=modelListing.serialize_price_history(
listing.priceHistory
@@ -193,3 +193,12 @@
)
return model_listing
async def mark_seen(self, listing_id: int) -> None:
listings = await self.get_listings(only_ids=[listing_id])
if len(listings) == 0:
return
listing = listings[0]
now = datetime.now()
listing.last_seen = now
await self.upsert_listings([listing])

@@ -6,8 +6,6 @@ set -eux
ENV_MODE=${ENV:-"dev"} # Defaults to "dev" if ENV_MODE is unset
echo "Checking connection to redis is successful..."
python celery_app.py
case "$ENV_MODE" in
dev)
@@ -16,12 +14,20 @@ case "$ENV_MODE" in
pkill -f celery
pkill watchmedo
set -e
if ! netstat -tlnp |grep 6379; then
echo "Did not find a running redis on 6379. Starting a new instance..."
docker run -d --rm --name redis-server -p 6379:6379 redis:latest
fi
echo "Checking connection to redis is successful..."
python celery_app.py
watchmedo auto-restart --directory=./ --pattern='*.py' --recursive -- celery -A celery_app worker & # DEV to autoreload on changes
CELERY_PID=$!
;;
prod)
echo "🚀 Running in PRODUCTION mode"
echo "Checking connection to redis is successful..."
python celery_app.py
alembic upgrade head
celery -A celery_app worker --beat &
CELERY_PID=$!
@@ -42,7 +48,7 @@ cleanup() {
trap cleanup EXIT SIGINT SIGTERM
# celery -A celery_app worker -D # PROD
uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-exclude "data" --log-level debug
uvicorn api.app:app --host 0.0.0.0 --port 5001 --log-level debug
# UVICORN_PID=$!
# wait for

@@ -2,7 +2,7 @@ import json
import logging
import pathlib
from rec.query import QueryParameters
from models.listing import QueryParameters
from repositories.listing_repository import ListingRepository
logger = logging.getLogger("uvicorn.error")
@@ -46,7 +46,7 @@ async def export_immoweb(
},
"geometry": {
"coordinates": [
listing.longtitude,
listing.longitude,
listing.latitude,
],
"type": "Point",

@@ -24,7 +24,6 @@
"source": [
"from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration\n",
"from PIL import Image\n",
"import pandas as pd\n",
"import re"
]
},

@@ -1,4 +1,4 @@
from vqa import Blip, MicrosoftGIT, PixStructDocVA, Vilt, Deplot, VQA
from vqa import MicrosoftGIT, VQA
from PIL import Image
from typing import List
from questions import load_questions

@@ -1,18 +1,24 @@
from transformers import BlipProcessor, BlipForQuestionAnswering
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from transformers import GitVisionConfig, GitVisionModel, AutoProcessor, GitProcessor
from transformers import GitVisionModel, GitProcessor
from abc import ABC, abstractmethod
from transformers.processing_utils import ProcessorMixin
class VQA:
class VQA(ABC):
name = "Not defined"
def query(image, question: str) -> str:
pass
@abstractmethod
def query(self, image, question: str) -> str:
return "Not implemented"
class Blip(VQA):
name = "Blip"
def query(self, image, question):
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
assert isinstance(processor, ProcessorMixin)  # "processor is ProcessorMixin" compared identity with the class itself and always failed
inputs = processor(image, question, return_tensors="pt")
out = model.generate(max_new_tokens=50000, **inputs)
return processor.decode(out[0], skip_special_tokens=True)
@@ -25,6 +31,7 @@ class Vilt(VQA):
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# prepare inputs
assert isinstance(processor, ProcessorMixin)
encoding = processor(image, question, return_tensors="pt")
# forward pass
@@ -41,6 +48,7 @@ class Deplot(VQA):
processor = Pix2StructProcessor.from_pretrained('google/deplot')
model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
assert isinstance(processor, ProcessorMixin)
inputs = processor(images=image, text=question, return_tensors="pt")
predictions = model.generate(**inputs, max_new_tokens=512)
return processor.decode(predictions[0], skip_special_tokens=True)
@@ -53,6 +61,7 @@ class PixStructDocVA(VQA):
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
assert isinstance(processor, ProcessorMixin)
inputs = processor(images=image, text=question, return_tensors="pt")
predictions = model.generate(**inputs, max_new_tokens=10000)
answer = processor.decode(predictions[0], skip_special_tokens=True)
@@ -64,6 +73,8 @@ class MicrosoftGIT(VQA):
def query(self, image, question):
processor = GitProcessor.from_pretrained("microsoft/git-base")
model = GitVisionModel.from_pretrained("microsoft/git-base")
assert isinstance(processor, ProcessorMixin)
inputs = processor(images=image, text=question, return_tensors="pt")
predictions = model.generate(**inputs, max_new_tokens=10000)
answer = processor.decode(predictions[0], skip_special_tokens=True)