import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from transformers import GitVisionConfig, GitVisionModel, AutoProcessor, GitProcessor
from transformers import GitForCausalLM


class VQA:
    """Abstract base class for visual question answering (VQA) backends.

    Subclasses set ``name`` to a human-readable/model identifier and
    implement ``query``.
    """

    name = "Not defined"

    # Bug fix: the original signature was ``def query(image, question)`` —
    # an instance method missing ``self``, so ``instance.query(img, q)``
    # would have bound the image to ``question``.
    def query(self, image, question: str) -> str:
        """Answer *question* about *image*. Implemented by subclasses."""
        pass


class Blip(VQA):
    """VQA via Salesforce BLIP (blip-vqa-capfilt-large)."""

    name = "Blip"

    def query(self, image, question):
        """Return BLIP's free-form answer for *question* about *image*.

        NOTE(review): the model and processor are re-downloaded/re-loaded on
        every call; consider caching them in ``__init__`` if queries are
        issued repeatedly.
        """
        processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        inputs = processor(image, question, return_tensors="pt")
        # max_new_tokens=50000 is far beyond any realistic answer length;
        # generation stops at EOS long before, so it acts as "no limit".
        out = model.generate(max_new_tokens=50000, **inputs)
        return processor.decode(out[0], skip_special_tokens=True)


class Vilt(VQA):
    """VQA via ViLT fine-tuned on VQAv2 (classification over answer vocab)."""

    name = "Vilt"

    def query(self, image, question):
        """Return the highest-scoring answer label from ViLT's classifier."""
        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        # prepare inputs
        encoding = processor(image, question, return_tensors="pt")
        # forward pass — ViLT is a classifier, not a generator: pick the
        # argmax over the fixed answer vocabulary.
        outputs = model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        return model.config.id2label[idx]


class Deplot(VQA):
    """Chart/plot question answering via Google DePlot (Pix2Struct-based)."""

    name = "Deplot"

    def query(self, image, question):
        """Return DePlot's generated answer for *question* about a chart image."""
        processor = Pix2StructProcessor.from_pretrained('google/deplot')
        model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=512)
        return processor.decode(predictions[0], skip_special_tokens=True)


class PixStructDocVA(VQA):
    """Document VQA via Pix2Struct fine-tuned on DocVQA."""

    name = "google/pix2struct-docvqa-large"

    def query(self, image, question):
        """Return the DocVQA model's answer; echoes question/answer to stdout."""
        print(question)
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=10000)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        print(answer)
        return answer


class MicrosoftGIT(VQA):
    """VQA via Microsoft GIT fine-tuned on TextVQA."""

    name = "microsoft/git-base-textvqa"

    def query(self, image, question):
        """Return GIT's generated answer for *question* about *image*.

        Bug fixes vs. the original:
        - loaded ``GitVisionModel`` (the vision encoder only, which cannot
          ``generate()``) — replaced with ``GitForCausalLM``;
        - loaded the base ``microsoft/git-base`` checkpoint although the
          class advertises the TextVQA fine-tune — now loads
          ``microsoft/git-base-textvqa``;
        - GIT expects the question prepended as CLS-prefixed input_ids with
          pixel_values, per the Hugging Face GIT documentation.
        """
        processor = GitProcessor.from_pretrained("microsoft/git-base-textvqa")
        model = GitForCausalLM.from_pretrained("microsoft/git-base-textvqa")
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        # Build the question prompt: [CLS] + question tokens (no specials),
        # as in the documented GIT VQA inference recipe.
        input_ids = processor(text=question, add_special_tokens=False).input_ids
        input_ids = [processor.tokenizer.cls_token_id] + input_ids
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        predictions = model.generate(
            pixel_values=pixel_values, input_ids=input_ids, max_new_tokens=50
        )
        # The generated sequence includes the question prompt; decode the
        # full sequence as before (callers may strip the prompt if needed).
        answer = processor.batch_decode(predictions, skip_special_tokens=True)[0]
        return answer