Merging the visual question answering (VQA) code with the crawler. Monorepo go!
This commit is contained in:
parent
85686a8b24
commit
e2f7998ee9
32 changed files with 3449 additions and 0 deletions
72
vqa/vqa.py
Normal file
72
vqa/vqa.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
from transformers import BlipProcessor, BlipForQuestionAnswering
|
||||
from transformers import ViltProcessor, ViltForQuestionAnswering
|
||||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||
from transformers import GitVisionConfig, GitVisionModel, AutoProcessor, GitProcessor
|
||||
|
||||
class VQA:
    """Common interface shared by the visual question answering backends.

    Concrete backends subclass this, set ``name`` and override ``query``.
    """

    # Human-readable identifier of the backend; subclasses override it.
    name = "Not defined"

    def query(self, image, question: str) -> str:
        """Answer ``question`` about ``image``.

        Args:
            image: The image to ask about. It is handed unchanged to the
                backend's processor (presumably a PIL image; nothing here
                validates it).
            question: The natural-language question.

        Returns:
            The backend's answer as text.

        Raises:
            NotImplementedError: Always on the base class; subclasses must
                override this method.
        """
        # Previously declared without ``self`` and with a bare ``pass``:
        # instance calls failed with a TypeError, and the only successful
        # call path silently returned None despite the ``-> str`` annotation.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement query()"
        )
|
||||
|
||||
class Blip(VQA):
    """VQA backend using the ``Salesforce/blip-vqa-capfilt-large`` checkpoint."""

    name = "Blip"

    # Checkpoint id passed to ``from_pretrained`` for both processor and model.
    _CHECKPOINT = "Salesforce/blip-vqa-capfilt-large"

    def _load(self):
        """Return ``(processor, model)``, loading them on first use only."""
        components = getattr(self, "_components", None)
        if components is None:
            # Loading is the expensive step; do it once per instance instead
            # of on every query.
            components = (
                BlipProcessor.from_pretrained(self._CHECKPOINT),
                BlipForQuestionAnswering.from_pretrained(self._CHECKPOINT),
            )
            self._components = components
        return components

    def query(self, image, question):
        """Generate a free-text answer to ``question`` about ``image``.

        Args:
            image: Image handed to the BLIP processor (presumably a PIL
                image; not validated here).
            question: The question text.

        Returns:
            The decoded first generated sequence, special tokens removed.
        """
        processor, model = self._load()
        inputs = processor(image, question, return_tensors="pt")
        # max_new_tokens is only an upper bound on the generated length;
        # the value is kept from the original code.
        out = model.generate(max_new_tokens=50000, **inputs)
        return processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
class Vilt(VQA):
    """VQA backend using the ``dandelin/vilt-b32-finetuned-vqa`` checkpoint."""

    name = "Vilt"

    # Checkpoint id passed to ``from_pretrained`` for both processor and model.
    _CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"

    def _load(self):
        """Return ``(processor, model)``, loading them on first use only."""
        components = getattr(self, "_components", None)
        if components is None:
            # Loading is the expensive step; do it once per instance instead
            # of on every query.
            components = (
                ViltProcessor.from_pretrained(self._CHECKPOINT),
                ViltForQuestionAnswering.from_pretrained(self._CHECKPOINT),
            )
            self._components = components
        return components

    def query(self, image, question):
        """Pick the highest-scoring answer label for ``question``.

        Args:
            image: Image handed to the ViLT processor (presumably a PIL
                image; not validated here).
            question: The question text.

        Returns:
            The label that ``model.config.id2label`` maps the arg-max logit
            index to.
        """
        processor, model = self._load()

        # prepare inputs
        encoding = processor(image, question, return_tensors="pt")

        # forward pass
        outputs = model(**encoding)
        logits = outputs.logits
        # .item() requires a single index, i.e. one image/question pair.
        idx = logits.argmax(-1).item()
        return model.config.id2label[idx]
|
||||
|
||||
|
||||
|
||||
class Deplot(VQA):
    """VQA backend using the ``google/deplot`` Pix2Struct checkpoint."""

    name = "Deplot"

    # Checkpoint id passed to ``from_pretrained`` for both processor and model.
    _CHECKPOINT = 'google/deplot'

    def _load(self):
        """Return ``(processor, model)``, loading them on first use only."""
        components = getattr(self, "_components", None)
        if components is None:
            # Loading is the expensive step; do it once per instance instead
            # of on every query.
            components = (
                Pix2StructProcessor.from_pretrained(self._CHECKPOINT),
                Pix2StructForConditionalGeneration.from_pretrained(self._CHECKPOINT),
            )
            self._components = components
        return components

    def query(self, image, question):
        """Generate text for ``image`` conditioned on ``question``.

        Args:
            image: Image handed to the Pix2Struct processor (presumably a
                PIL image; not validated here).
            question: Text passed to the processor as the prompt.

        Returns:
            The decoded first generated sequence, special tokens removed.
        """
        processor, model = self._load()

        inputs = processor(images=image, text=question, return_tensors="pt")
        predictions = model.generate(**inputs, max_new_tokens=512)
        return processor.decode(predictions[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
class PixStructDocVA(VQA):
    """VQA backend using the ``google/pix2struct-docvqa-large`` checkpoint."""

    name = "google/pix2struct-docvqa-large"

    # Checkpoint id passed to ``from_pretrained`` for both processor and model.
    _CHECKPOINT = "google/pix2struct-docvqa-large"

    def _load(self):
        """Return ``(processor, model)``, loading them on first use only."""
        components = getattr(self, "_components", None)
        if components is None:
            # Loading is the expensive step; do it once per instance instead
            # of on every query.
            model = Pix2StructForConditionalGeneration.from_pretrained(self._CHECKPOINT)
            processor = Pix2StructProcessor.from_pretrained(self._CHECKPOINT)
            components = (processor, model)
            self._components = components
        return components

    def query(self, image, question):
        """Generate an answer to ``question`` about ``image``.

        The question and the answer are both echoed to stdout, as in the
        original implementation.

        Args:
            image: Image handed to the Pix2Struct processor (presumably a
                PIL image; not validated here).
            question: The question text.

        Returns:
            The decoded first generated sequence, special tokens removed.
        """
        print(question)
        processor, model = self._load()

        inputs = processor(images=image, text=question, return_tensors="pt")
        # max_new_tokens is only an upper bound on the generated length;
        # the value is kept from the original code.
        predictions = model.generate(**inputs, max_new_tokens=10000)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        print(answer)
        return answer
|
||||
|
||||
class MicrosoftGIT(VQA):
    """VQA backend using the ``microsoft/git-base-textvqa`` checkpoint."""

    name = "microsoft/git-base-textvqa"

    # Checkpoint actually loaded. The original code loaded "microsoft/git-base"
    # although ``name`` advertises the TextVQA fine-tune.
    _CHECKPOINT = "microsoft/git-base-textvqa"

    def _load(self):
        """Return ``(processor, model)``, loading them on first use only."""
        components = getattr(self, "_components", None)
        if components is None:
            # GitVisionModel (used originally) is only the image encoder and
            # cannot generate text; the generative model is GitForCausalLM.
            # Imported locally because the module-level imports do not
            # provide it.
            from transformers import GitForCausalLM

            components = (
                GitProcessor.from_pretrained(self._CHECKPOINT),
                GitForCausalLM.from_pretrained(self._CHECKPOINT),
            )
            self._components = components
        return components

    def query(self, image, question):
        """Generate an answer to ``question`` about ``image``.

        The prompt is built as in the Hugging Face GIT visual question
        answering example: the question tokens without special tokens,
        prefixed with the tokenizer's CLS token.

        Args:
            image: Image handed to the GIT processor (presumably a PIL
                image; not validated here).
            question: The question text.

        Returns:
            The decoded text generated after the prompt, special tokens
            removed.
        """
        import torch

        processor, model = self._load()

        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        question_ids = processor(text=question, add_special_tokens=False).input_ids
        prompt_ids = torch.tensor(
            [[processor.tokenizer.cls_token_id] + question_ids]
        )

        # A small cap keeps the sequence well inside the model's position
        # limit; the original 10000 could never be reached safely.
        predictions = model.generate(
            pixel_values=pixel_values,
            input_ids=prompt_ids,
            max_new_tokens=50,
        )
        # The generated sequence starts with the prompt; drop it so only the
        # answer is returned.
        answer_ids = predictions[0][prompt_ids.shape[1]:]
        answer = processor.decode(answer_ids, skip_special_tokens=True)

        return answer
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue