"""Wrappers exposing several Hugging Face visual-question-answering models
behind a single `VQA.query(image, question) -> str` interface."""
from abc import ABC, abstractmethod

from transformers import (
    BlipForQuestionAnswering,
    BlipProcessor,
    GitForCausalLM,
    GitProcessor,
    GitVisionModel,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    ViltForQuestionAnswering,
    ViltProcessor,
)
from transformers.processing_utils import ProcessorMixin
|
|
|
|
|
|
class VQA(ABC):
    """Abstract interface for visual-question-answering backends.

    Concrete subclasses override `name` with their model identifier and
    implement `query` to answer a free-form question about an image.
    """

    # Human-readable model identifier; subclasses override this.
    name = "Not defined"

    @abstractmethod
    def query(self, image, question: str) -> str:
        """Return a textual answer to `question` about `image`."""
        return "Not implemented"
|
|
|
|
class Blip(VQA):
    """VQA backed by the Salesforce BLIP VQA (capfilt-large) checkpoint."""

    name = "Blip"

    def query(self, image, question):
        """Answer `question` about `image` with BLIP and return the decoded text.

        NOTE(review): the processor and model are reloaded on every call;
        consider caching them if queries are frequent.
        """
        processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")

        # Bug fix: the original `assert processor is ProcessorMixin` compared
        # the instance to the class object by identity — always False, so the
        # method raised AssertionError on every call. Check the type instead.
        assert isinstance(processor, ProcessorMixin)

        inputs = processor(image, question, return_tensors="pt")
        # Generation stops at EOS well before this limit for short VQA answers.
        out = model.generate(max_new_tokens=50000, **inputs)
        return processor.decode(out[0], skip_special_tokens=True)
|
|
|
|
|
|
class Vilt(VQA):
    """VQA backed by ViLT fine-tuned on VQAv2 (classification over answers)."""

    name = "Vilt"

    def query(self, image, question):
        """Answer `question` about `image` by picking the argmax answer label.

        NOTE(review): the processor and model are reloaded on every call;
        consider caching them if queries are frequent.
        """
        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

        # Bug fix: the original `assert processor is ProcessorMixin` compared
        # the instance to the class object by identity — always False, so the
        # method raised AssertionError on every call. Check the type instead.
        assert isinstance(processor, ProcessorMixin)

        # prepare inputs
        encoding = processor(image, question, return_tensors="pt")

        # forward pass: ViLT treats VQA as classification over a fixed answer
        # vocabulary, so the answer is the highest-scoring label.
        outputs = model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        return model.config.id2label[idx]
|
|
|
|
|
|
|
|
class Deplot(VQA):
    """VQA backed by Google DePlot (Pix2Struct), aimed at chart/plot images."""

    name = "Deplot"

    def query(self, image, question):
        """Answer `question` about `image` with DePlot and return decoded text.

        NOTE(review): the processor and model are reloaded on every call;
        consider caching them if queries are frequent.
        """
        processor = Pix2StructProcessor.from_pretrained('google/deplot')
        model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')

        # Bug fix: the original `assert processor is ProcessorMixin` compared
        # the instance to the class object by identity — always False, so the
        # method raised AssertionError on every call. Check the type instead.
        assert isinstance(processor, ProcessorMixin)

        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=512)
        return processor.decode(predictions[0], skip_special_tokens=True)
|
|
|
|
|
|
class PixStructDocVA(VQA):
    """VQA backed by Pix2Struct fine-tuned for document VQA (DocVQA)."""

    name = "google/pix2struct-docvqa-large"

    def query(self, image, question):
        """Answer `question` about a document `image`, echoing Q and A to stdout.

        NOTE(review): the processor and model are reloaded on every call;
        consider caching them if queries are frequent.
        """
        print(question)

        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")

        # Bug fix: the original `assert processor is ProcessorMixin` compared
        # the instance to the class object by identity — always False, so the
        # method raised AssertionError on every call. Check the type instead.
        assert isinstance(processor, ProcessorMixin)

        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=10000)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        print(answer)
        return answer
|
|
|
|
class MicrosoftGIT(VQA):
    """VQA backed by Microsoft GIT fine-tuned on TextVQA."""

    name = "microsoft/git-base-textvqa"

    def query(self, image, question):
        """Answer `question` about `image` with GIT and return the decoded text.

        NOTE(review): the processor and model are reloaded on every call;
        consider caching them if queries are frequent.
        """
        # Bug fix: the original loaded the base (non-VQA) checkpoint
        # "microsoft/git-base" even though `name` advertises the TextVQA
        # checkpoint, and it used GitVisionModel — the vision encoder alone,
        # which has no `generate` method — so every call crashed. Load the
        # checkpoint named by `name` with the causal-LM head instead.
        processor = GitProcessor.from_pretrained("microsoft/git-base-textvqa")
        model = GitForCausalLM.from_pretrained("microsoft/git-base-textvqa")

        # Bug fix: was `assert processor is ProcessorMixin` — an identity
        # comparison against the class object, always False.
        assert isinstance(processor, ProcessorMixin)

        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=10000)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        return answer
|
|
|