# ocr-vqa / app.py
import gradio as gr
import torch, os, json, requests
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
from torchvision import transforms
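
# Gradio demo that answers free-form questions about an image with a Donut
# (Document Understanding Transformer) model fine-tuned for OCR-VQA.
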
def load_image_from_URL(url):
    # Fetch the image once (streaming) and normalize RGBA to RGB.
    res = requests.get(url, stream=True)
    if res.status_code == 200:
        img = Image.open(res.raw)
        if img.mode == "RGBA":
            img = img.convert("RGB")
        return img
    return None

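# Wraps the Donut VisionEncoderDecoderModel and its processor so preprocessing,
# prompt construction, and decoding live in one place.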
class OCRVQAModel(torch.nn.Module):
    def add_tokens(self, list_of_tokens):
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def __init__(self, config):
        super().__init__()
        self.model_name_or_path = config['donut']
        self.processor_name_or_path = config['processor']
        self.config_name_or_path = config['config']
        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]
        self.donut_config.decoder.max_length = 64
        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
        self.added_tokens = set()
        self.setup()

    def setup(self):
        self.add_tokens(["<yes/>", "<no/>"])
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False
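
    # Run one image/question pair through the model and return the dict parsed
    # by token2json (e.g. {'question': ..., 'answer': ...}).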
    def inference(self, image, prompt, device):
        self.donut.eval()
        try:
            with torch.no_grad():
                image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)
                question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
                embedded_question = self.processor.tokenizer(
                    question,
                    add_special_tokens=False,
                    return_tensors="pt"
                )["input_ids"].to(device)
                outputs = self.donut.generate(
                    image_ids,
                    decoder_input_ids=embedded_question,
                    max_length=self.donut.decoder.config.max_position_embeddings,
                    early_stopping=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=1,
                    bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                    return_dict_in_generate=True
                )
            return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
        except Exception:
            return {
                'question': prompt,
                'answer': 'Some error occurred during inference time.'
            }

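# Weights and processor come from the ndtran/donut_ocr-vqa-200k checkpoint; the
# model config is taken from the DocVQA fine-tune, which defines the <s_docvqa>
# prompt format used above.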
model = OCRVQAModel({
    "donut": "ndtran/donut_ocr-vqa-200k",
    "processor": "ndtran/donut_ocr-vqa-200k",
    "config": "naver-clova-ix/donut-base-finetuned-docvqa"
})

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

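# Gradio callback: prefer the pasted URL when one is given, otherwise use the uploaded image.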
def get_answer(image, url, question) -> str:
    global model, device
    if url is not None and url.startswith('http'):
        result = model.inference(load_image_from_URL(url), question, device)
        return result.get('answer', "I don't know :<")
    result = model.inference(image, question, device)
    return result.get('answer', "I don't know :<")

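# UI: one column for the inputs (image, URL, question), one for the button, one for the answer.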
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # OCR-VQA-Donut
                This demo uses a fine-tuned OCR-VQA Donut model to answer questions about images.
                Feel free to try it out!
                """
            )
            image = gr.Image(shape=(224, 224), type="pil")
            image_url = gr.Textbox(lines=1, label="Image URL", placeholder="Or, paste the image URL here")
            question = gr.Textbox(lines=5, label="Question")

        with gr.Column():
            ask = gr.Button("Get the answer")

        with gr.Column():
            answer = gr.Label(label="Answer")

    ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])

demo.launch(share=True)
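
# Example of querying the model without the UI (the URL below is only a placeholder):
#   img = load_image_from_URL("https://example.com/book-cover.jpg")
#   print(model.inference(img, "What is the title of this book?", device))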