wrongmove/vqa/vqa.py
2025-09-14 19:40:18 +01:00

83 lines
No EOL
3.2 KiB
Python

from abc import ABC, abstractmethod

from transformers import BlipForQuestionAnswering, BlipProcessor
from transformers import GitForCausalLM, GitProcessor, GitVisionModel
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from transformers import ViltForQuestionAnswering, ViltProcessor
from transformers.processing_utils import ProcessorMixin
class VQA(ABC):
    """Common interface for visual-question-answering backends.

    Concrete subclasses override :attr:`name` with their model identifier
    and implement :meth:`query`.
    """

    # Human-readable identifier of the backend; overridden by subclasses.
    name = "Not defined"

    @abstractmethod
    def query(self, image, question: str) -> str:
        """Return the model's answer to *question* about *image*."""
        return "Not implemented"
class Blip(VQA):
    """VQA backend using Salesforce's BLIP (vqa-capfilt-large) model."""

    name = "Blip"

    def query(self, image, question):
        """Answer *question* about *image* with BLIP's generative QA head.

        NOTE: the processor and model are (re)loaded from the Hub on every
        call; cache them on the instance if throughput matters.
        """
        processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        # Bug fix: `processor is ProcessorMixin` tested identity against the
        # class object and was always False, so every call raised
        # AssertionError. isinstance() is the intended sanity check.
        assert isinstance(processor, ProcessorMixin)
        inputs = processor(image, question, return_tensors="pt")
        # max_new_tokens is an upper bound only; generation stops at EOS.
        out = model.generate(max_new_tokens=50000, **inputs)
        return processor.decode(out[0], skip_special_tokens=True)
class Vilt(VQA):
    """VQA backend using ViLT fine-tuned on VQAv2 (classification-style QA)."""

    name = "Vilt"

    def query(self, image, question):
        """Answer *question* about *image* by picking the top VQAv2 label.

        Unlike the generative backends, ViLT treats VQA as classification
        over a fixed answer vocabulary (model.config.id2label).
        """
        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        # Bug fix: `processor is ProcessorMixin` compared the instance to the
        # class itself and always failed; use isinstance for the type check.
        assert isinstance(processor, ProcessorMixin)
        # prepare inputs
        encoding = processor(image, question, return_tensors="pt")
        # forward pass; highest logit indexes the answer vocabulary
        outputs = model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        return model.config.id2label[idx]
class Deplot(VQA):
    """Backend using Google DePlot (Pix2Struct) for chart/plot understanding."""

    name = "Deplot"

    def query(self, image, question):
        """Answer *question* about a chart *image* via DePlot generation."""
        processor = Pix2StructProcessor.from_pretrained('google/deplot')
        model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
        # Bug fix: `processor is ProcessorMixin` was an identity comparison
        # against the class object and always raised; isinstance is correct.
        assert isinstance(processor, ProcessorMixin)
        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=512)
        return processor.decode(predictions[0], skip_special_tokens=True)
class PixStructDocVA(VQA):
    """Backend using Pix2Struct fine-tuned for document VQA."""

    name = "google/pix2struct-docvqa-large"

    def query(self, image, question):
        """Answer *question* about a document *image*.

        Echoes the question and answer to stdout (debug trace kept from the
        original implementation).
        """
        print(question)
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
        # Bug fix: `processor is ProcessorMixin` checked identity against the
        # class object itself and always failed; isinstance is the real check.
        assert isinstance(processor, ProcessorMixin)
        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=10000)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        print(answer)
        return answer
class MicrosoftGIT(VQA):
    """Backend using Microsoft GIT fine-tuned on TextVQA."""

    name = "microsoft/git-base-textvqa"

    def query(self, image, question):
        """Answer *question* about *image* with generative GIT decoding."""
        # Bug fix: the original loaded "microsoft/git-base" (the raw,
        # non-finetuned base model) although `name` advertises the
        # TextVQA-finetuned checkpoint — load the checkpoint the class claims.
        processor = GitProcessor.from_pretrained("microsoft/git-base-textvqa")
        # Bug fix: GitVisionModel is only the vision encoder and has no
        # .generate(); GitForCausalLM is the generative model needed for VQA.
        model = GitForCausalLM.from_pretrained("microsoft/git-base-textvqa")
        # Bug fix: `processor is ProcessorMixin` was an always-false identity
        # test against the class object; use isinstance.
        assert isinstance(processor, ProcessorMixin)
        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=10000)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        return answer