adding ruff auto check for pull requests as well as fixing all ruff errors (#1)
Co-authored-by: Kadir <git@k8n.dev>
This commit is contained in:
parent
b1e0a414cf
commit
4c23acdb55
5 changed files with 60 additions and 10 deletions
|
|
@ -24,7 +24,6 @@
|
|||
"source": [
|
||||
"from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration\n",
|
||||
"from PIL import Image\n",
|
||||
"import pandas as pd\n",
|
||||
"import re"
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from vqa import Blip, MicrosoftGIT, PixStructDocVA, Vilt, Deplot, VQA
|
||||
from vqa import MicrosoftGIT, VQA
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
from questions import load_questions
|
||||
|
|
|
|||
19
vqa/vqa.py
19
vqa/vqa.py
|
|
@ -1,18 +1,24 @@
|
|||
from transformers import BlipProcessor, BlipForQuestionAnswering
|
||||
from transformers import ViltProcessor, ViltForQuestionAnswering
|
||||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||
from transformers import GitVisionConfig, GitVisionModel, AutoProcessor, GitProcessor
|
||||
from transformers import GitVisionModel, GitProcessor
|
||||
from abc import ABC, abstractmethod
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
class VQA:
|
||||
|
||||
class VQA(ABC):
|
||||
name = "Not defined"
|
||||
def query(image, question: str) -> str:
|
||||
pass
|
||||
@abstractmethod
|
||||
def query(self, image, question: str) -> str:
|
||||
return "Not implemented"
|
||||
|
||||
class Blip(VQA):
|
||||
name = "Blip"
|
||||
def query(self, image, question):
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
||||
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
||||
|
||||
assert processor is ProcessorMixin
|
||||
inputs = processor(image, question, return_tensors="pt")
|
||||
out = model.generate(max_new_tokens=50000, **inputs)
|
||||
return processor.decode(out[0], skip_special_tokens=True)
|
||||
|
|
@ -25,6 +31,7 @@ class Vilt(VQA):
|
|||
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
||||
|
||||
# prepare inputs
|
||||
assert processor is ProcessorMixin
|
||||
encoding = processor(image, question, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
|
|
@ -41,6 +48,7 @@ class Deplot(VQA):
|
|||
processor = Pix2StructProcessor.from_pretrained('google/deplot')
|
||||
model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
|
||||
|
||||
assert processor is ProcessorMixin
|
||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||
predictions = model.generate(**inputs, max_new_tokens=512)
|
||||
return processor.decode(predictions[0], skip_special_tokens=True)
|
||||
|
|
@ -53,6 +61,7 @@ class PixStructDocVA(VQA):
|
|||
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
|
||||
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
|
||||
|
||||
assert processor is ProcessorMixin
|
||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||
predictions = model.generate(**inputs, max_new_tokens=10000)
|
||||
answer = processor.decode(predictions[0], skip_special_tokens=True)
|
||||
|
|
@ -64,6 +73,8 @@ class MicrosoftGIT(VQA):
|
|||
def query(self, image, question):
|
||||
processor = GitProcessor.from_pretrained("microsoft/git-base")
|
||||
model = GitVisionModel.from_pretrained("microsoft/git-base")
|
||||
|
||||
assert processor is ProcessorMixin
|
||||
inputs = processor(images=image, text=question, return_tensors="pt")
|
||||
predictions = model.generate(**inputs, max_new_tokens=10000)
|
||||
answer = processor.decode(predictions[0], skip_special_tokens=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue