# wrongmove/vqa/vqa.py
import torch
from transformers import (
    AutoProcessor,
    BlipForQuestionAnswering,
    BlipProcessor,
    GitForCausalLM,
    GitProcessor,
    GitVisionConfig,
    GitVisionModel,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    ViltProcessor,
    ViltForQuestionAnswering,
)
class VQA:
    """Abstract interface for visual-question-answering backends.

    Subclasses set ``name`` to a human-readable model identifier and
    override ``query`` to answer a natural-language question about an image.
    """

    # Human-readable model identifier; every concrete subclass overrides this.
    name = "Not defined"

    def query(self, image, question: str) -> str:
        """Answer *question* about *image*; must be overridden.

        Raises:
            NotImplementedError: always, in this abstract base class.
        """
        # BUG FIX: the original signature omitted ``self``, so an instance
        # call ``vqa.query(img, q)`` silently bound the instance to ``image``.
        # Also raise instead of returning None so a missing override fails
        # loudly instead of propagating None into callers.
        raise NotImplementedError(f"{type(self).__name__} must implement query()")
class Blip(VQA):
    """VQA backend built on Salesforce's BLIP ``vqa-capfilt-large`` checkpoint."""

    name = "Blip"

    def query(self, image, question):
        """Answer *question* about *image* with BLIP; returns the decoded text."""
        checkpoint = "Salesforce/blip-vqa-capfilt-large"
        blip_processor = BlipProcessor.from_pretrained(checkpoint)
        blip_model = BlipForQuestionAnswering.from_pretrained(checkpoint)
        # Encode the (image, question) pair, then free-form generate an answer.
        encoded = blip_processor(image, question, return_tensors="pt")
        generated = blip_model.generate(max_new_tokens=50000, **encoded)
        return blip_processor.decode(generated[0], skip_special_tokens=True)
class Vilt(VQA):
    """VQA backend using ViLT fine-tuned on VQAv2 (answer-classification head)."""

    name = "Vilt"

    def query(self, image, question):
        """Pick the highest-scoring answer label for *question* about *image*."""
        checkpoint = "dandelin/vilt-b32-finetuned-vqa"
        vilt_processor = ViltProcessor.from_pretrained(checkpoint)
        vilt_model = ViltForQuestionAnswering.from_pretrained(checkpoint)
        # ViLT treats VQA as classification: one forward pass, argmax the logits.
        batch = vilt_processor(image, question, return_tensors="pt")
        best_idx = vilt_model(**batch).logits.argmax(-1).item()
        return vilt_model.config.id2label[best_idx]
class Deplot(VQA):
    """Chart-understanding backend using Google's DePlot (Pix2Struct) model."""

    name = "Deplot"

    def query(self, image, question):
        """Generate an answer about chart *image*, conditioned on *question*."""
        checkpoint = "google/deplot"
        deplot_processor = Pix2StructProcessor.from_pretrained(checkpoint)
        deplot_model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
        batch = deplot_processor(images=image, text=question, return_tensors="pt")
        generated = deplot_model.generate(**batch, max_new_tokens=512)
        return deplot_processor.decode(generated[0], skip_special_tokens=True)
class PixStructDocVA(VQA):
    """Document-VQA backend using Google's Pix2Struct DocVQA-large model."""

    name = "google/pix2struct-docvqa-large"

    def query(self, image, question):
        """Answer *question* about a document *image*; returns the decoded text.

        FIX: removed the stray debug ``print(question)`` / ``print(answer)``
        calls that polluted stdout on every query — no caller depends on them.
        """
        checkpoint = "google/pix2struct-docvqa-large"
        model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
        processor = Pix2StructProcessor.from_pretrained(checkpoint)
        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=10000)
        return processor.decode(predictions[0], skip_special_tokens=True)
class MicrosoftGIT(VQA):
    """VQA backend using Microsoft's GIT model fine-tuned on TextVQA."""

    name = "microsoft/git-base-textvqa"

    def query(self, image, question):
        """Answer *question* about *image* with GIT; returns the decoded text.

        FIXES vs. the original:
          * loads the TextVQA checkpoint that ``name`` advertises, instead of
            the un-finetuned "microsoft/git-base" weights;
          * uses ``GitForCausalLM`` — ``GitVisionModel`` is only the vision
            encoder and has no ``.generate()``, so the original raised
            AttributeError at runtime;
          * builds the prompt the way the GIT documentation specifies:
            ``[CLS]`` + question tokens fed as ``input_ids`` alongside
            ``pixel_values``.
        """
        checkpoint = "microsoft/git-base-textvqa"
        processor = GitProcessor.from_pretrained(checkpoint)
        model = GitForCausalLM.from_pretrained(checkpoint)
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        # GIT conditions generation on [CLS] + question tokens (no trailing SEP).
        question_ids = processor(text=question, add_special_tokens=False).input_ids
        input_ids = torch.tensor([[processor.tokenizer.cls_token_id] + question_ids])
        generated = model.generate(
            pixel_values=pixel_values, input_ids=input_ids, max_new_tokens=50
        )
        # NOTE(review): GIT echoes the question prompt in the generated ids;
        # the decoded string therefore starts with the question text.
        return processor.batch_decode(generated, skip_special_tokens=True)[0]