import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
def load_image_from_URL(url):
    # Fetch the image once, streaming the body straight into PIL
    # (the original fetched the URL twice: once for the status check,
    # once for the bytes).
    res = requests.get(url, stream=True)
    if res.status_code == 200:
        img = Image.open(res.raw)
        # Donut expects 3-channel input, so drop any alpha channel.
        if img.mode == "RGBA":
            img = img.convert("RGB")
        return img
    return None
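
# Usage sketch (the URL below is a placeholder for illustration, not part of
# the original app):
#
#   cover = load_image_from_URL("https://example.com/book-cover.jpg")
#   if cover is None:
#       print("Download failed")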
class OCRVQAModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model_name_or_path = config['donut']
        self.processor_name_or_path = config['processor']
        self.config_name_or_path = config['config']
        # Start from the base DocVQA config, then shrink the input resolution
        # and cap the decoder length for this task.
        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]
        self.donut_config.decoder.max_length = 64
        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
        self.added_tokens = set()
        self.setup()

    def add_tokens(self, list_of_tokens):
        # Register new tokens and grow the decoder embedding matrix to match
        # the enlarged vocabulary.
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def setup(self):
        self.add_tokens(["<yes/>", "<no/>"])
        # The config stores (height, width); the feature extractor wants (width, height).
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False
    def inference(self, image, prompt, device):
        self.donut.eval()
        try:
            with torch.no_grad():
                image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)
                # Donut is prompted with the DocVQA task token plus the question;
                # generation continues after <s_answer>.
                question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
                embedded_question = self.processor.tokenizer(
                    question,
                    add_special_tokens=False,
                    return_tensors="pt"
                )["input_ids"].to(device)
                # Greedy decoding (num_beams=1) with the <unk> token banned.
                outputs = self.donut.generate(
                    image_ids,
                    decoder_input_ids=embedded_question,
                    max_length=self.donut.decoder.config.max_position_embeddings,
                    early_stopping=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=1,
                    bad_words_ids=[
                        [self.processor.tokenizer.unk_token_id]
                    ],
                    return_dict_in_generate=True
                )
                # token2json turns the tagged output sequence into a dict,
                # e.g. {'question': ..., 'answer': ...}.
                return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
        except Exception:
            return {
                'question': prompt,
                'answer': 'Some error occurred during inference time.'
            }
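
# Checkpoints used below: "donut" holds the fine-tuned OCR-VQA weights,
# "processor" the matching tokenizer/feature extractor, and "config" the base
# DocVQA architecture config whose image size and max length __init__ overrides.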
model = OCRVQAModel({
    "donut": "ndtran/donut_ocr-vqa-200k",
    "processor": "ndtran/donut_ocr-vqa-200k",
    "config": "naver-clova-ix/donut-base-finetuned-docvqa"
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
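
# Direct (non-UI) usage sketch; "cover.jpg" is a placeholder filename, not a
# file shipped with this Space:
#
#   img = Image.open("cover.jpg").convert("RGB")
#   print(model.inference(img, "What is the title of the book?", device))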
def get_answer(image, question) -> str:
    # model and device are module-level globals; reading them needs no
    # `global` declaration.
    result = model.inference(image, question, device)
    return result.get('answer', "I don't know :<")
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_url = gr.Textbox(lines=1, label="Image URL", placeholder="Paste image URL here")
            # Or upload from your computer.
            image = gr.Image(shape=(224, 224), type="pil")
            # A component's .value is only a build-time default, so the URL
            # cannot be handled with an if/else when the UI is constructed.
            # Instead, pressing Enter in the textbox loads the image into the
            # image component via an event.
            image_url.submit(load_image_from_URL, inputs=[image_url], outputs=[image])
        with gr.Column():
            gr.Markdown(
                """
                # OCR-VQA-Donut
                This demo uses a fine-tuned OCR-VQA Donut model to answer questions about images,
                e.g. "Who is the author of this book?".
                Feel free to try it out!
                """)
            question = gr.Textbox(lines=5, label="Question")
            answer = gr.Label(label="Answer")
            # gr.Button takes its text as the positional value argument, not label.
            ask = gr.Button("Get the answer")
            ask.click(get_answer, inputs=[image, question], outputs=[answer])
demo.launch()