import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig


def load_image_from_URL(url):
    """Download an image from `url`; return it as an RGB PIL image, or None on failure."""
    res = requests.get(url, stream=True)
    if res.status_code == 200:
        img = Image.open(res.raw)
        if img.mode == "RGBA":
            img = img.convert("RGB")
        return img
    return None
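
# A quick sanity check (hypothetical URL, not part of the demo flow): the helper
# yields either an RGB image or None, so callers should handle both cases:
#   cover = load_image_from_URL("https://example.com/book_cover.jpg")
#   assert cover is None or cover.mode == "RGB"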


class OCRVQAModel(torch.nn.Module):
    """Donut VisionEncoderDecoderModel fine-tuned for OCR-VQA, plus its processor."""

    def __init__(self, config):
        super().__init__()
        self.model_name_or_path = config['donut']
        self.processor_name_or_path = config['processor']
        self.config_name_or_path = config['config']
        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]  # (height, width)
        self.donut_config.decoder.max_length = 64
        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
        self.added_tokens = set()
        self.setup()

    def add_tokens(self, list_of_tokens):
        """Register new special tokens and resize the decoder embeddings to match."""
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def setup(self):
        self.add_tokens(["<yes/>", "<no/>"])
        # The feature extractor expects (width, height), hence the reversed image size.
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False
    def inference(self, image, prompt, device):
        """Answer `prompt` about `image`; returns a dict parsed from the decoded tokens."""
        self.donut.eval()
        try:
            with torch.no_grad():
                pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(device)
                # Donut conditions generation on the DocVQA task prompt wrapped around the question.
                question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
                decoder_input_ids = self.processor.tokenizer(
                    question,
                    add_special_tokens=False,
                    return_tensors="pt"
                )["input_ids"].to(device)
                outputs = self.donut.generate(
                    pixel_values,
                    decoder_input_ids=decoder_input_ids,
                    max_length=self.donut.decoder.config.max_position_embeddings,
                    early_stopping=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=1,
                    bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                    return_dict_in_generate=True
                )
            return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
        except Exception:
            return {
                'question': prompt,
                'answer': 'Some error occurred during inference time.'
            }


model = OCRVQAModel({
    "donut": "ndtran/donut_ocr-vqa-200k",
    "processor": "ndtran/donut_ocr-vqa-200k",
    "config": "naver-clova-ix/donut-base-finetuned-docvqa"
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


def get_answer(image, url, question) -> str:
    # Prefer the URL when one is given; fall back to the uploaded image if the download fails.
    if url and url.startswith('http'):
        image_from_url = load_image_from_URL(url)
        if image_from_url is not None:
            image = image_from_url
    result = model.inference(image, question, device)
    return result.get('answer', "I don't know :<")


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            ## Donut-OCR-VQA
            - This demo uses a Donut model fine-tuned on the OCR-VQA-200k dataset to answer questions about images.
            ## IO description
            - Input: an image or an image URL (a book cover is recommended) and a question asking about information on the image.
            - Output: an answer to the question.
            """
        )
    with gr.Row():
        with gr.Column():
            image = gr.Image(shape=(224, 224), type="pil", label="Pick an image")
            image_url = gr.Textbox(lines=1, label="Or use this option!", placeholder="Enter the image URL here")
            question = gr.Textbox(lines=5, label="Question")
            ask = gr.Button("Get the answer")
        with gr.Column():
            answer = gr.Label(label="Answer")
    ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])

demo.launch()
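
# A minimal sketch of querying the model without the UI (hypothetical URL;
# running the Space executes `demo.launch()` above instead):
#   img = load_image_from_URL("https://example.com/book_cover.jpg")
#   print(get_answer(img, None, "Who is the author of this book?"))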