Compare commits
10 commits
e5c68f6bb7
...
ced9a153bd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ced9a153bd | ||
|
|
0801aaf200 | ||
|
|
4c23acdb55 | ||
|
|
b1e0a414cf | ||
|
|
8d11e4a81c | ||
|
|
520286aaee | ||
|
|
62329a2eb4 | ||
|
|
ff57117054 | ||
|
|
526f4fc0c3 | ||
|
|
480957dc72 |
29 changed files with 776 additions and 532 deletions
39
.github/workflows/ruff.yml
vendored
Normal file
39
.github/workflows/ruff.yml
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
name: Run Ruff and Auto-merge
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ruff-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
# Fetch all history for diffing
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11' # Or your desired Python version
|
||||||
|
|
||||||
|
- name: Install Ruff
|
||||||
|
run: pip install ruff
|
||||||
|
|
||||||
|
- name: Get changed files
|
||||||
|
id: changed_files
|
||||||
|
run: |
|
||||||
|
# Get a list of changed files between the base and head commits of the PR
|
||||||
|
git diff --name-only --diff-filter=d ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} > changed_files.txt
|
||||||
|
# Filter for Python files
|
||||||
|
grep -E '\.py$' changed_files.txt > python_files.txt
|
||||||
|
# Remove newlines and join with spaces
|
||||||
|
echo "files=$(tr '\n' ' ' < python_files.txt)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Run Ruff on changed files
|
||||||
|
if: steps.changed_files.outputs.files != ''
|
||||||
|
run: |
|
||||||
|
# The ruff command will only run if there are Python files to check
|
||||||
|
ruff check ${{ steps.changed_files.outputs.files }}
|
||||||
|
|
@ -5,10 +5,11 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from rec.query import detail_query, listing_query, QueryParameters
|
from listing_processor import ListingProcessor
|
||||||
|
from rec.query import listing_query
|
||||||
|
from models.listing import QueryParameters
|
||||||
from rec.districts import get_districts
|
from rec.districts import get_districts
|
||||||
from repositories import ListingRepository
|
from repositories import ListingRepository
|
||||||
import requests
|
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
from data_access import Listing
|
from data_access import Listing
|
||||||
from models import Listing as modelListing
|
from models import Listing as modelListing
|
||||||
|
|
@ -27,15 +28,15 @@ async def dump_listings_full(
|
||||||
"""Fetches all listings, images as well as detects floorplans"""
|
"""Fetches all listings, images as well as detects floorplans"""
|
||||||
new_listings = await dump_listings(parameters, repository, data_dir)
|
new_listings = await dump_listings(parameters, repository, data_dir)
|
||||||
logger.debug(f"Upserted {len(new_listings)} new listings")
|
logger.debug(f"Upserted {len(new_listings)} new listings")
|
||||||
logger.debug("Starting to fetch floorplans")
|
# logger.debug("Starting to fetch floorplans")
|
||||||
await dump_images_module.dump_images(repository, image_base_path=data_dir)
|
# await dump_images_module.dump_images(repository, image_base_path=data_dir)
|
||||||
logger.debug("Completed fetching floorplans")
|
# logger.debug("Completed fetching floorplans")
|
||||||
logger.debug("Starting floorplan detection")
|
# logger.debug("Starting floorplan detection")
|
||||||
await detect_floorplan_module.detect_floorplan(repository)
|
# await detect_floorplan_module.detect_floorplan(repository)
|
||||||
logger.debug("Completed floorplan detection")
|
# logger.debug("Completed floorplan detection")
|
||||||
# refresh listings
|
# refresh listings
|
||||||
listings = await repository.get_listings(parameters) # this can be better
|
listings = await repository.get_listings(parameters) # this can be better
|
||||||
new_listings = [l for l in listings if l.id in new_listings]
|
new_listings = [x for x in listings if x.id in new_listings]
|
||||||
return new_listings
|
return new_listings
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -77,29 +78,20 @@ async def dump_listings(
|
||||||
listings.append(listing)
|
listings.append(listing)
|
||||||
|
|
||||||
# if listing is already in db, do not fetch details again
|
# if listing is already in db, do not fetch details again
|
||||||
all_listing_ids = [l.id for l in await repository.get_listings()]
|
all_listing_ids = [x.id for x in await repository.get_listings()]
|
||||||
missing_listing = [
|
missing_listing = [
|
||||||
listing for listing in listings if listing.identifier not in all_listing_ids
|
listing for listing in listings if listing.identifier not in all_listing_ids
|
||||||
]
|
]
|
||||||
logger.debug(f"Fetching details for {len(missing_listing)} missing listings")
|
missing_ids = [listing.identifier for listing in missing_listing]
|
||||||
listing_details = await tqdm.gather(
|
missing_ids = [missing_ids[0]]
|
||||||
*[
|
listing_processor = ListingProcessor(repository)
|
||||||
_fetch_detail_with_semaphore(semaphore, listing.identifier)
|
logger.info(f"Starting processing {len(missing_listing)} new listings")
|
||||||
for listing in missing_listing
|
processed_listings = await tqdm.gather(
|
||||||
],
|
*[listing_processor.process_listing(id) for id in missing_ids]
|
||||||
desc="Fetching details (only missing)",
|
|
||||||
)
|
)
|
||||||
for listing, detail in zip(missing_listing, listing_details):
|
filtered_listings = [x for x in processed_listings if x is not None]
|
||||||
listing._details_object = detail
|
|
||||||
|
|
||||||
logger.debug("Dumping listings to fs")
|
return filtered_listings
|
||||||
await dump_listings_to_fs(missing_listing)
|
|
||||||
logger.debug("Upserting listings in db")
|
|
||||||
model_listings = await repository.upsert_listings_legacy(
|
|
||||||
missing_listing
|
|
||||||
) # upsert in db
|
|
||||||
|
|
||||||
return model_listings
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_listings_with_semaphore(
|
async def _fetch_listings_with_semaphore(
|
||||||
|
|
@ -113,7 +105,7 @@ async def _fetch_listings_with_semaphore(
|
||||||
# we do 10 queries each with an increment in price range so we send more queries but each
|
# we do 10 queries each with an increment in price range so we send more queries but each
|
||||||
# has a smaller chance of returning more than 1.5k results
|
# has a smaller chance of returning more than 1.5k results
|
||||||
|
|
||||||
number_of_steps = 10
|
number_of_steps = 1
|
||||||
price_step = parameters.max_price // number_of_steps
|
price_step = parameters.max_price // number_of_steps
|
||||||
|
|
||||||
for step in range(number_of_steps):
|
for step in range(number_of_steps):
|
||||||
|
|
@ -157,14 +149,6 @@ async def _fetch_listings_with_semaphore(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_detail_with_semaphore(
|
|
||||||
semaphore: asyncio.Semaphore, listing_id: int
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
async with semaphore:
|
|
||||||
d = await detail_query(listing_id)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
async def dump_listings_to_fs(listings: list[Listing]) -> None:
|
async def dump_listings_to_fs(listings: list[Listing]) -> None:
|
||||||
for listing in tqdm(listings, desc="Dumping listings to FS"):
|
for listing in tqdm(listings, desc="Dumping listings to FS"):
|
||||||
listing.dump_listing()
|
listing.dump_listing()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from repositories import ListingRepository
|
from repositories import ListingRepository
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ async def update_routing_info(
|
||||||
|
|
||||||
routes_data = routing.transit_route(
|
routes_data = routing.transit_route(
|
||||||
listing.latitude,
|
listing.latitude,
|
||||||
listing.longtitude,
|
listing.longitude,
|
||||||
destination_mode.destination_address,
|
destination_mode.destination_address,
|
||||||
destination_mode.travel_mode,
|
destination_mode.travel_mode,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,31 @@
|
||||||
# Setup
|
# Setup
|
||||||
|
|
||||||
|
1. Install deps:
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
poetry install && cp .env.sample .env
|
||||||
```
|
```
|
||||||
|
2. Check `.env` if you want to customize settings for broker and db
|
||||||
|
3. run `./start.sh`
|
||||||
|
|
||||||
|
This starts the backend
|
||||||
|
|
||||||
|
To start the frontend:
|
||||||
|
|
||||||
|
```
|
||||||
|
cd frontend && cp .env.sample .env
|
||||||
|
```
|
||||||
|
Change the `DEV_HOST` to any name you want to use to access the web interface.
|
||||||
|
|
||||||
|
Next, set up the DNS record (e.g. in your /etc/hosts file).
|
||||||
|
This is important, as auth is done via an external [authentik] service that needs to redirect to a name.
|
||||||
|
|
||||||
|
Run `./start.sh`
|
||||||
|
|
||||||
|
This starts a Caddy proxy with correct certificates, and npm dev server.
|
||||||
|
All requests going to the frontend are forwarded to the npm server, and the ones for the backend (that go to `/api/*`) are forwarded to the backend service.
|
||||||
|
|
||||||
|
Lastly, reach out to Viktor to allowlist your `DEV_HOST` so that authentik can authorize callbacks to your host.
|
||||||
|
|
||||||
|
|
||||||
# Formatting
|
# Formatting
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,8 @@
|
||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
|
||||||
from sqlalchemy import pool
|
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from models import Listing, User # Import all models here
|
|
||||||
from database import engine
|
from database import engine
|
||||||
import sqlmodel
|
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ def upgrade() -> None:
|
||||||
sa.Column('square_meters', sa.Float(), nullable=True),
|
sa.Column('square_meters', sa.Float(), nullable=True),
|
||||||
sa.Column('agency', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
sa.Column('agency', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
sa.Column('council_tax_band', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
sa.Column('council_tax_band', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
sa.Column('longtitude', sa.Float(), nullable=False),
|
sa.Column('longitude', sa.Float(), nullable=False),
|
||||||
sa.Column('latitude', sa.Float(), nullable=False),
|
sa.Column('latitude', sa.Float(), nullable=False),
|
||||||
sa.Column('price_history_json', sa.TEXT(), nullable=False),
|
sa.Column('price_history_json', sa.TEXT(), nullable=False),
|
||||||
sa.Column('listing_site', sa.Enum('RIGHTMOVE', name='listingsite'), nullable=False),
|
sa.Column('listing_site', sa.Enum('RIGHTMOVE', name='listingsite'), nullable=False),
|
||||||
|
|
@ -49,7 +49,7 @@ def upgrade() -> None:
|
||||||
sa.Column('square_meters', sa.Float(), nullable=True),
|
sa.Column('square_meters', sa.Float(), nullable=True),
|
||||||
sa.Column('agency', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
sa.Column('agency', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
sa.Column('council_tax_band', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
sa.Column('council_tax_band', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
sa.Column('longtitude', sa.Float(), nullable=False),
|
sa.Column('longitude', sa.Float(), nullable=False),
|
||||||
sa.Column('latitude', sa.Float(), nullable=False),
|
sa.Column('latitude', sa.Float(), nullable=False),
|
||||||
sa.Column('price_history_json', sa.TEXT(), nullable=False),
|
sa.Column('price_history_json', sa.TEXT(), nullable=False),
|
||||||
sa.Column('listing_site', sa.Enum('RIGHTMOVE', name='listingsite'), nullable=False),
|
sa.Column('listing_site', sa.Enum('RIGHTMOVE', name='listingsite'), nullable=False),
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ Create Date: 2025-06-30 22:54:11.706618
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,18 @@
|
||||||
import asyncio
|
|
||||||
import dataclasses
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
from pathlib import Path
|
|
||||||
import queue
|
|
||||||
from threading import Thread
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
import uuid
|
|
||||||
from api.auth import get_current_user
|
from api.auth import get_current_user
|
||||||
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
from api.config import DEV_TIER_ORIGINS, PROD_TIER_ORIGINS
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Query
|
from fastapi import Depends, FastAPI, Query
|
||||||
from api.auth import User
|
from api.auth import User
|
||||||
from models.listing import QueryParameters
|
from models.listing import QueryParameters
|
||||||
from notifications import send_notification
|
from notifications import send_notification
|
||||||
from rec import districts
|
from rec import districts
|
||||||
from redis_repository import RedisRepository
|
from redis_repository import RedisRepository
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
from repositories.listing_repository import ListingRepository
|
|
||||||
from database import engine
|
from database import engine
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
|
@ -30,6 +23,10 @@ from alembic.config import Config
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from celery.exceptions import TaskRevokedError
|
from celery.exceptions import TaskRevokedError
|
||||||
from celery_app import app as celery_app
|
from celery_app import app as celery_app
|
||||||
|
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||||
|
from api.metrics import metrics_app # Import the Prometheus ASGI app
|
||||||
|
from opentelemetry.metrics import get_meter
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
@ -47,6 +44,16 @@ logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
# app = FastAPI(lifespan=lifespan)
|
# app = FastAPI(lifespan=lifespan)
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
app.mount("/metrics", metrics_app)
|
||||||
|
meter = get_meter(__name__)
|
||||||
|
request_counter = meter.create_counter(
|
||||||
|
name="custom_request_count",
|
||||||
|
description="Number of times /hello was called",
|
||||||
|
)
|
||||||
|
hist = meter.create_histogram(
|
||||||
|
name="custom_request_duration",
|
||||||
|
description="Duration of /hello requests in seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Allow CORS (for React frontend)
|
# Allow CORS (for React frontend)
|
||||||
|
|
@ -60,6 +67,8 @@ app.add_middleware(
|
||||||
|
|
||||||
@app.get("/api/status")
|
@app.get("/api/status")
|
||||||
async def get_status():
|
async def get_status():
|
||||||
|
request_counter.add(1, {"method": "GET", "path": "/status"})
|
||||||
|
hist.record(1.5, {"method": "GET", "path": "/status"})
|
||||||
return {"status": "OK"}
|
return {"status": "OK"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,7 +122,7 @@ async def get_task_status(
|
||||||
task_result = listing_tasks.dump_listings_task.AsyncResult(task_id)
|
task_result = listing_tasks.dump_listings_task.AsyncResult(task_id)
|
||||||
try:
|
try:
|
||||||
result = json.dumps(task_result.result)
|
result = json.dumps(task_result.result)
|
||||||
except:
|
except Exception:
|
||||||
result = str(task_result.result)
|
result = str(task_result.result)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -137,3 +146,6 @@ async def get_districts(
|
||||||
user: Annotated[User, Depends(get_current_user)],
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
return districts.get_districts()
|
return districts.get_districts()
|
||||||
|
|
||||||
|
|
||||||
|
FastAPIInstrumentor.instrument_app(app)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from datetime import timedelta
|
|
||||||
from api.config import AUTHENTIK_URL, OIDC_CACHE_TTL, OIDC_CLIENT_ID, OIDC_METADATA_URL
|
from api.config import AUTHENTIK_URL, OIDC_CACHE_TTL, OIDC_CLIENT_ID, OIDC_METADATA_URL
|
||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Any, List, Dict
|
from typing import Any, List
|
||||||
from models.listing import ListingSite, PriceHistoryItem
|
from models.listing import ListingSite, PriceHistoryItem
|
||||||
from rec import floorplan, routing
|
from rec import floorplan, routing
|
||||||
import re
|
import re
|
||||||
|
|
@ -399,13 +399,7 @@ class Listing:
|
||||||
for item in data
|
for item in data
|
||||||
]
|
]
|
||||||
|
|
||||||
@property
|
|
||||||
def longtitude(self) -> float:
|
|
||||||
return self.detailobject["property"]["longitude"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def latitude(self) -> float:
|
|
||||||
return self.detailobject["property"]["latitude"]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def listing_site(self) -> ListingSite:
|
def listing_site(self) -> ListingSite:
|
||||||
|
|
|
||||||
|
|
@ -195,18 +195,18 @@ export function Map(
|
||||||
.call(xAxis);
|
.call(xAxis);
|
||||||
}
|
}
|
||||||
|
|
||||||
function openListingsDialog(longtitude: number, latitude: number) {
|
function openListingsDialog(longitude: number, latitude: number) {
|
||||||
const searchBuffer = 0.001 // ~100m
|
const searchBuffer = 0.001 // ~100m
|
||||||
const properties = heatmap._tree.search({
|
const properties = heatmap._tree.search({
|
||||||
minX: longtitude - searchBuffer,
|
minX: longitude - searchBuffer,
|
||||||
maxX: longtitude + searchBuffer,
|
maxX: longitude + searchBuffer,
|
||||||
minY: latitude - searchBuffer,
|
minY: latitude - searchBuffer,
|
||||||
maxY: latitude + searchBuffer
|
maxY: latitude + searchBuffer
|
||||||
})
|
})
|
||||||
if (properties.length > 0) {
|
if (properties.length > 0) {
|
||||||
const listingDialogPopup = getListingDialog(properties);
|
const listingDialogPopup = getListingDialog(properties);
|
||||||
new mapboxgl.Popup()
|
new mapboxgl.Popup()
|
||||||
.setLngLat([longtitude, latitude])
|
.setLngLat([longitude, latitude])
|
||||||
.setHTML(renderToString(listingDialogPopup))
|
.setHTML(renderToString(listingDialogPopup))
|
||||||
.setMaxWidth("500px")
|
.setMaxWidth("500px")
|
||||||
.addTo(mapRef.current);
|
.addTo(mapRef.current);
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,11 @@ logger = logging.getLogger("uvicorn.error")
|
||||||
class ListingProcessor:
|
class ListingProcessor:
|
||||||
semaphore: asyncio.Semaphore
|
semaphore: asyncio.Semaphore
|
||||||
process_steps: list[Step]
|
process_steps: list[Step]
|
||||||
|
listing_repository: ListingRepository
|
||||||
|
|
||||||
def __init__(self, listing_repository: ListingRepository):
|
def __init__(self, listing_repository: ListingRepository):
|
||||||
self.semaphore = asyncio.Semaphore(20)
|
self.semaphore = asyncio.Semaphore(20)
|
||||||
|
self.listing_repository = listing_repository
|
||||||
# Register new processing steps here
|
# Register new processing steps here
|
||||||
# Order is important
|
# Order is important
|
||||||
self.process_steps = [
|
self.process_steps = [
|
||||||
|
|
@ -29,11 +31,16 @@ class ListingProcessor:
|
||||||
]
|
]
|
||||||
|
|
||||||
async def process_listing(self, listing_id: int) -> Listing | None:
|
async def process_listing(self, listing_id: int) -> Listing | None:
|
||||||
|
await self.listing_repository.mark_seen(listing_id)
|
||||||
listing = None
|
listing = None
|
||||||
for step in self.process_steps:
|
for step in self.process_steps:
|
||||||
if await step.needs_processing(listing_id):
|
if await step.needs_processing(listing_id):
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
listing = await step.process(listing_id)
|
try:
|
||||||
|
listing = await step.process(listing_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process {listing_id=}: {e}")
|
||||||
|
return None
|
||||||
return listing
|
return listing
|
||||||
|
|
||||||
async def listing_exists(self, listing_id: int) -> bool: ...
|
async def listing_exists(self, listing_id: int) -> bool: ...
|
||||||
|
|
@ -106,7 +113,7 @@ class FetchListingDetailsStep(Step):
|
||||||
council_tax_band=listing_details["property"]["councilTaxInfo"]["content"][
|
council_tax_band=listing_details["property"]["councilTaxInfo"]["content"][
|
||||||
0
|
0
|
||||||
]["value"],
|
]["value"],
|
||||||
longtitude=listing_details["property"]["longitude"],
|
longitude=listing_details["property"]["longitude"],
|
||||||
latitude=listing_details["property"]["latitude"],
|
latitude=listing_details["property"]["latitude"],
|
||||||
price_history_json="{}", # TODO: should upsert from existing
|
price_history_json="{}", # TODO: should upsert from existing
|
||||||
listing_site=ListingSite.RIGHTMOVE,
|
listing_site=ListingSite.RIGHTMOVE,
|
||||||
|
|
@ -145,14 +152,15 @@ class FetchImagesStep(Step):
|
||||||
all_floorplans = listing.additional_info.get("property", {}).get(
|
all_floorplans = listing.additional_info.get("property", {}).get(
|
||||||
"floorplans", []
|
"floorplans", []
|
||||||
)
|
)
|
||||||
for floorplan in all_floorplans:
|
client_timeout = aiohttp.ClientTimeout(total=30)
|
||||||
url = floorplan["url"]
|
for floorplan_obj in all_floorplans:
|
||||||
|
url = floorplan_obj["url"]
|
||||||
picname = url.split("/")[-1]
|
picname = url.split("/")[-1]
|
||||||
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
|
floorplan_path = Path(base_path, str(listing.id), "floorplans", picname)
|
||||||
if floorplan_path.exists():
|
if floorplan_path.exists():
|
||||||
continue
|
continue
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(url) as response:
|
async with session.get(url, timeout=client_timeout) as response:
|
||||||
if response.status == 404:
|
if response.status == 404:
|
||||||
return listing
|
return listing
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import click
|
import click
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import listing_processor
|
|
||||||
from models.listing import FurnishType, ListingType, QueryParameters
|
from models.listing import FurnishType, ListingType, QueryParameters
|
||||||
from rec.districts import get_districts
|
from rec.districts import get_districts
|
||||||
from data_access import Listing
|
from data_access import Listing
|
||||||
|
|
@ -187,7 +185,6 @@ def dump_images(ctx: click.core.Context):
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def detect_floorplan(ctx: click.core.Context):
|
def detect_floorplan(ctx: click.core.Context):
|
||||||
data_dir = ctx.obj["data_dir"]
|
|
||||||
click.echo(f"Running detect_floorplan for listings stored in {engine.url}")
|
click.echo(f"Running detect_floorplan for listings stored in {engine.url}")
|
||||||
repository = ListingRepository(engine=engine)
|
repository = ListingRepository(engine=engine)
|
||||||
asyncio.run(detect_floorplan_module.detect_floorplan(repository))
|
asyncio.run(detect_floorplan_module.detect_floorplan(repository))
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,10 @@ import dataclasses
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from rec import routing
|
from rec import routing
|
||||||
from sqlmodel import JSON, TEXT, SQLModel, Field, String
|
from sqlmodel import JSON, TEXT, SQLModel, Field
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -60,7 +59,7 @@ class Listing(SQLModel, table=False):
|
||||||
square_meters: float | None = Field(default=None, nullable=True, index=True)
|
square_meters: float | None = Field(default=None, nullable=True, index=True)
|
||||||
agency: str | None = Field(default=None, nullable=True)
|
agency: str | None = Field(default=None, nullable=True)
|
||||||
council_tax_band: str | None = Field(default=None, nullable=True)
|
council_tax_band: str | None = Field(default=None, nullable=True)
|
||||||
longtitude: float = Field(nullable=False)
|
longitude: float = Field(nullable=False)
|
||||||
latitude: float = Field(nullable=False)
|
latitude: float = Field(nullable=False)
|
||||||
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
|
# price_history: List[Dict[str, Any]] = Field(default_factory=list, sa_type=JSON)
|
||||||
price_history_json: str = Field(sa_type=TEXT)
|
price_history_json: str = Field(sa_type=TEXT)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import StrEnum
|
|
||||||
import apprise
|
import apprise
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
24
crawler/podman-compose.yml
Normal file
24
crawler/podman-compose.yml
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
version: "3.8"
|
||||||
|
|
||||||
|
services:
|
||||||
|
redis:
|
||||||
|
image: redis:8
|
||||||
|
container_name: redis-container
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
|
volumes:
|
||||||
|
- ./data/redis:/data
|
||||||
|
command: ["redis-server", "--appendonly", "yes"]
|
||||||
|
|
||||||
|
mysql:
|
||||||
|
image: mysql:9
|
||||||
|
container_name: mysql-container
|
||||||
|
ports:
|
||||||
|
- "3306:3306"
|
||||||
|
environment:
|
||||||
|
MYSQL_ROOT_PASSWORD: wtfviktordidyoubuildsomuch
|
||||||
|
MYSQL_DATABASE: wrongmove
|
||||||
|
MYSQL_USER: wrongmoveuser
|
||||||
|
MYSQL_PASSWORD: wrongmovepass
|
||||||
|
volumes:
|
||||||
|
- ./data/mysql:/var/lib/mysql
|
||||||
997
crawler/poetry.lock
generated
997
crawler/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -33,7 +33,6 @@ response = requests.get(
|
||||||
verify=False,
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Host": "api.rightmove.co.uk",
|
"Host": "api.rightmove.co.uk",
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ def extract_time(d):
|
||||||
distance_per_transit[step["travelMode"]] += step.get("distanceMeters", 0)
|
distance_per_transit[step["travelMode"]] += step.get("distanceMeters", 0)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"dis {distance}, dur {duration}, duration per transit {dict(duration_per_transit)}, distance per transit {dict(distance_per_transit)}"
|
f"dis {distance}, dur {duration}, duration per transit {dict(duration_per_transit)}, distance per transit {dict(distance_per_transit)}, duration_static {duration_static}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ pillow = "^10.2.0"
|
||||||
numpy = "^1.26.4"
|
numpy = "^1.26.4"
|
||||||
transformers = "^4.38.2"
|
transformers = "^4.38.2"
|
||||||
pytesseract = "^0.3.10"
|
pytesseract = "^0.3.10"
|
||||||
jupyterlab = "^4.1.4"
|
|
||||||
pandas = "^2.2.1"
|
pandas = "^2.2.1"
|
||||||
geopy = "^2.4.1"
|
geopy = "^2.4.1"
|
||||||
matplotlib = "^3.10.0"
|
matplotlib = "^3.10.0"
|
||||||
|
|
@ -28,7 +27,6 @@ tenacity = "^9.1.2"
|
||||||
fastapi = {extras = ["standard"], version = "^0.115.12"}
|
fastapi = {extras = ["standard"], version = "^0.115.12"}
|
||||||
pyjwt = "^2.10.1"
|
pyjwt = "^2.10.1"
|
||||||
cryptography = "^45.0.4"
|
cryptography = "^45.0.4"
|
||||||
mysqlclient = "^2.2.7"
|
|
||||||
celery = "^5.5.3"
|
celery = "^5.5.3"
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
watchdog = "^6.0.0"
|
watchdog = "^6.0.0"
|
||||||
|
|
@ -38,10 +36,20 @@ opentelemetry-sdk = "^1.36.0"
|
||||||
opentelemetry-exporter-prometheus = "^0.57b0"
|
opentelemetry-exporter-prometheus = "^0.57b0"
|
||||||
opentelemetry-instrumentation-fastapi = "^0.57b0"
|
opentelemetry-instrumentation-fastapi = "^0.57b0"
|
||||||
opentelemetry-instrumentation-sqlalchemy = "^0.57b0"
|
opentelemetry-instrumentation-sqlalchemy = "^0.57b0"
|
||||||
|
mysqlclient = "^2.2.7"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
ipdb = "^0.13.13"
|
ipdb = "^0.13.13"
|
||||||
|
jupyterlab = "^4.4.7"
|
||||||
|
podman-compose = "^1.5.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
# List of rules (error codes) to ignore
|
||||||
|
lint.ignore = [
|
||||||
|
"E741", # Ambiguous name
|
||||||
|
]
|
||||||
|
exclude = ["*.ipynb"]
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
import asyncio
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
import enum
|
import enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from models.listing import FurnishType, ListingType, QueryParameters
|
from models.listing import FurnishType, ListingType
|
||||||
from rec import districts
|
from rec import districts
|
||||||
from tenacity import retry, stop_after_attempt, wait_random
|
from tenacity import retry, stop_after_attempt, wait_random
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
def parse_listing_json_entry(d):
|
|
||||||
id = d["identifier"]
|
|
||||||
# address = d['address']
|
|
||||||
propertyType = d["propertyType"]
|
|
||||||
price = d["price"]
|
|
||||||
latitude = d["latitude"]
|
|
||||||
longitude = d["longitude"]
|
|
||||||
updated_date = d["updateDate"]
|
|
||||||
|
|
@ -10,7 +10,7 @@ from models.listing import (
|
||||||
RentListing,
|
RentListing,
|
||||||
)
|
)
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import Engine
|
||||||
from sqlmodel import Sequence, Session, and_, col, select
|
from sqlmodel import Session, select
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
@ -160,7 +160,7 @@ class ListingRepository:
|
||||||
square_meters=await listing.sqm_ocr(),
|
square_meters=await listing.sqm_ocr(),
|
||||||
agency=listing.agency,
|
agency=listing.agency,
|
||||||
council_tax_band=listing.councilTaxBand,
|
council_tax_band=listing.councilTaxBand,
|
||||||
longtitude=listing.longtitude,
|
longitude=listing.longitude,
|
||||||
latitude=listing.latitude,
|
latitude=listing.latitude,
|
||||||
price_history_json=modelListing.serialize_price_history(
|
price_history_json=modelListing.serialize_price_history(
|
||||||
listing.priceHistory
|
listing.priceHistory
|
||||||
|
|
@ -180,7 +180,7 @@ class ListingRepository:
|
||||||
square_meters=await listing.sqm_ocr(),
|
square_meters=await listing.sqm_ocr(),
|
||||||
agency=listing.agency,
|
agency=listing.agency,
|
||||||
council_tax_band=listing.councilTaxBand,
|
council_tax_band=listing.councilTaxBand,
|
||||||
longtitude=listing.longtitude,
|
longitude=listing.longitude,
|
||||||
latitude=listing.latitude,
|
latitude=listing.latitude,
|
||||||
price_history_json=modelListing.serialize_price_history(
|
price_history_json=modelListing.serialize_price_history(
|
||||||
listing.priceHistory
|
listing.priceHistory
|
||||||
|
|
@ -193,3 +193,12 @@ class ListingRepository:
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_listing
|
return model_listing
|
||||||
|
|
||||||
|
async def mark_seen(self, listing_id: int) -> None:
|
||||||
|
listings = await self.get_listings(only_ids=[listing_id])
|
||||||
|
if len(listings) == 0:
|
||||||
|
return
|
||||||
|
listing = listings[0]
|
||||||
|
now = datetime.now()
|
||||||
|
listing.last_seen = now
|
||||||
|
await self.upsert_listings([listing])
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,6 @@ set -eux
|
||||||
|
|
||||||
ENV_MODE=${ENV:-"dev"} # Defaults to "dev" if ENV_MODE is unset
|
ENV_MODE=${ENV:-"dev"} # Defaults to "dev" if ENV_MODE is unset
|
||||||
|
|
||||||
echo "Checking connection to redis is successful..."
|
|
||||||
python celery_app.py
|
|
||||||
|
|
||||||
case "$ENV_MODE" in
|
case "$ENV_MODE" in
|
||||||
dev)
|
dev)
|
||||||
|
|
@ -16,12 +14,20 @@ case "$ENV_MODE" in
|
||||||
pkill -f celery
|
pkill -f celery
|
||||||
pkill watchmedo
|
pkill watchmedo
|
||||||
set -e
|
set -e
|
||||||
|
if ! netstat -tlnp |grep 6379; then
|
||||||
|
echo "Did not find a running redis on 6379. Starting a new instance..."
|
||||||
|
docker run -d --rm --name redis-server -p 6379:6379 redis:latest
|
||||||
|
fi
|
||||||
|
echo "Checking connection to redis is successful..."
|
||||||
|
python celery_app.py
|
||||||
|
|
||||||
watchmedo auto-restart --directory=./ --pattern='*.py' --recursive -- celery -A celery_app worker & # DEV to autoreload on changes
|
watchmedo auto-restart --directory=./ --pattern='*.py' --recursive -- celery -A celery_app worker & # DEV to autoreload on changes
|
||||||
CELERY_PID=$!
|
CELERY_PID=$!
|
||||||
;;
|
;;
|
||||||
prod)
|
prod)
|
||||||
echo "🚀 Running in PRODUCTION mode"
|
echo "🚀 Running in PRODUCTION mode"
|
||||||
|
echo "Checking connection to redis is successful..."
|
||||||
|
python celery_app.py
|
||||||
alembic upgrade head
|
alembic upgrade head
|
||||||
celery -A celery_app worker --beat &
|
celery -A celery_app worker --beat &
|
||||||
CELERY_PID=$!
|
CELERY_PID=$!
|
||||||
|
|
@ -42,7 +48,7 @@ cleanup() {
|
||||||
trap cleanup EXIT SIGINT SIGTERM
|
trap cleanup EXIT SIGINT SIGTERM
|
||||||
|
|
||||||
# celery -A celery_app worker -D # PROD
|
# celery -A celery_app worker -D # PROD
|
||||||
uvicorn api.app:app --host 0.0.0.0 --port 5001 --reload --reload-exclude "data" --log-level debug
|
uvicorn api.app:app --host 0.0.0.0 --port 5001 --log-level debug
|
||||||
# UVICORN_PID=$!
|
# UVICORN_PID=$!
|
||||||
|
|
||||||
# wait for
|
# wait for
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from rec.query import QueryParameters
|
from models.listing import QueryParameters
|
||||||
from repositories.listing_repository import ListingRepository
|
from repositories.listing_repository import ListingRepository
|
||||||
|
|
||||||
logger = logging.getLogger("uvicorn.error")
|
logger = logging.getLogger("uvicorn.error")
|
||||||
|
|
@ -46,7 +46,7 @@ async def export_immoweb(
|
||||||
},
|
},
|
||||||
"geometry": {
|
"geometry": {
|
||||||
"coordinates": [
|
"coordinates": [
|
||||||
listing.longtitude,
|
listing.longitude,
|
||||||
listing.latitude,
|
listing.latitude,
|
||||||
],
|
],
|
||||||
"type": "Point",
|
"type": "Point",
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration\n",
|
"from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration\n",
|
||||||
"from PIL import Image\n",
|
"from PIL import Image\n",
|
||||||
"import pandas as pd\n",
|
|
||||||
"import re"
|
"import re"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from vqa import Blip, MicrosoftGIT, PixStructDocVA, Vilt, Deplot, VQA
|
from vqa import MicrosoftGIT, VQA
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import List
|
from typing import List
|
||||||
from questions import load_questions
|
from questions import load_questions
|
||||||
|
|
|
||||||
19
vqa/vqa.py
19
vqa/vqa.py
|
|
@ -1,18 +1,24 @@
|
||||||
from transformers import BlipProcessor, BlipForQuestionAnswering
|
from transformers import BlipProcessor, BlipForQuestionAnswering
|
||||||
from transformers import ViltProcessor, ViltForQuestionAnswering
|
from transformers import ViltProcessor, ViltForQuestionAnswering
|
||||||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||||
from transformers import GitVisionConfig, GitVisionModel, AutoProcessor, GitProcessor
|
from transformers import GitVisionModel, GitProcessor
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from transformers.processing_utils import ProcessorMixin
|
||||||
|
|
||||||
class VQA:
|
|
||||||
|
class VQA(ABC):
|
||||||
name = "Not defined"
|
name = "Not defined"
|
||||||
def query(image, question: str) -> str:
|
@abstractmethod
|
||||||
pass
|
def query(self, image, question: str) -> str:
|
||||||
|
return "Not implemented"
|
||||||
|
|
||||||
class Blip(VQA):
|
class Blip(VQA):
|
||||||
name = "Blip"
|
name = "Blip"
|
||||||
def query(self, image, question):
|
def query(self, image, question):
|
||||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
||||||
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
||||||
|
|
||||||
|
assert processor is ProcessorMixin
|
||||||
inputs = processor(image, question, return_tensors="pt")
|
inputs = processor(image, question, return_tensors="pt")
|
||||||
out = model.generate(max_new_tokens=50000, **inputs)
|
out = model.generate(max_new_tokens=50000, **inputs)
|
||||||
return processor.decode(out[0], skip_special_tokens=True)
|
return processor.decode(out[0], skip_special_tokens=True)
|
||||||
|
|
@ -25,6 +31,7 @@ class Vilt(VQA):
|
||||||
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
||||||
|
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
|
assert processor is ProcessorMixin
|
||||||
encoding = processor(image, question, return_tensors="pt")
|
encoding = processor(image, question, return_tensors="pt")
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
|
|
@ -41,6 +48,7 @@ class Deplot(VQA):
|
||||||
processor = Pix2StructProcessor.from_pretrained('google/deplot')
|
processor = Pix2StructProcessor.from_pretrained('google/deplot')
|
||||||
model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
|
model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
|
||||||
|
|
||||||
|
assert processor is ProcessorMixin
|
||||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||||
predictions = model.generate(**inputs, max_new_tokens=512)
|
predictions = model.generate(**inputs, max_new_tokens=512)
|
||||||
return processor.decode(predictions[0], skip_special_tokens=True)
|
return processor.decode(predictions[0], skip_special_tokens=True)
|
||||||
|
|
@ -53,6 +61,7 @@ class PixStructDocVA(VQA):
|
||||||
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
|
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
|
||||||
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
|
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
|
||||||
|
|
||||||
|
assert processor is ProcessorMixin
|
||||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||||
predictions = model.generate(**inputs, max_new_tokens=10000)
|
predictions = model.generate(**inputs, max_new_tokens=10000)
|
||||||
answer = processor.decode(predictions[0], skip_special_tokens=True)
|
answer = processor.decode(predictions[0], skip_special_tokens=True)
|
||||||
|
|
@ -64,6 +73,8 @@ class MicrosoftGIT(VQA):
|
||||||
def query(self, image, question):
|
def query(self, image, question):
|
||||||
processor = GitProcessor.from_pretrained("microsoft/git-base")
|
processor = GitProcessor.from_pretrained("microsoft/git-base")
|
||||||
model = GitVisionModel.from_pretrained("microsoft/git-base")
|
model = GitVisionModel.from_pretrained("microsoft/git-base")
|
||||||
|
|
||||||
|
assert processor is ProcessorMixin
|
||||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||||
predictions = model.generate(**inputs, max_new_tokens=10000)
|
predictions = model.generate(**inputs, max_new_tokens=10000)
|
||||||
answer = processor.decode(predictions[0], skip_special_tokens=True)
|
answer = processor.decode(predictions[0], skip_special_tokens=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue