ruff format
This commit is contained in:
parent
37e3e8ad6f
commit
d777558b34
17 changed files with 411 additions and 368 deletions
|
|
@ -3,21 +3,22 @@ from PIL import Image
|
|||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||
import pytesseract
|
||||
|
||||
|
||||
def inference(image_path):
|
||||
image = Image.open(image_path)
|
||||
question = "How many living rooms are displayed on this floor plan?" # not sure if it even has an effect
|
||||
processor = Pix2StructProcessor.from_pretrained('google/deplot')
|
||||
model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
|
||||
question = "How many living rooms are displayed on this floor plan?" # not sure if it even has an effect
|
||||
processor = Pix2StructProcessor.from_pretrained("google/deplot")
|
||||
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
|
||||
|
||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||
predictions = model.generate(**inputs, max_new_tokens=512)
|
||||
output = processor.decode(predictions[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
return output, predictions
|
||||
|
||||
|
||||
|
||||
def extract_total_sqm(deplot_input_str):
|
||||
sqmregex = r'(\d+\.\d*) ?(sq ?m|sq. ?m)'
|
||||
sqmregex = r"(\d+\.\d*) ?(sq ?m|sq. ?m)"
|
||||
matches = re.findall(sqmregex, deplot_input_str.lower())
|
||||
if len(matches) == 0:
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue