# ocr-vqa / app.py
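"""Gradio demo for OCR-VQA with Donut.

Loads a Donut model fine-tuned on OCR-VQA-200k and answers questions about an
image (ideally a book cover) supplied either as an upload or as a URL.
"""
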
import gradio as gr
import torch, requests, sys
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig

def load_image_from_URL(url):
    """Fetch an image from `url` and return it as an RGB PIL image, or None on failure."""
    res = requests.get(url, stream=True)
    if res.status_code == 200:
        img = Image.open(res.raw)
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img
    return None

class OCRVQAModel(torch.nn.Module):
    def add_tokens(self, list_of_tokens):
        """Add new special tokens to the tokenizer and resize the decoder embeddings to match."""
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def __init__(self, config):
        super().__init__()
        self.model_name_or_path = config['donut']
        self.processor_name_or_path = config['processor']
        self.config_name_or_path = config['config']
        # Start from the DocVQA-finetuned config, then shrink the input resolution
        # and decoder length for the OCR-VQA setup.
        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]
        self.donut_config.decoder.max_length = 64
        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
        self.added_tokens = set()
        self.setup()

    def setup(self):
        # Register the yes/no answer tokens and reverse the encoder image size into
        # the (width, height) order expected by the feature extractor.
        self.add_tokens(["<yes/>", "<no/>"])
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False

    def inference(self, image, prompt, device):
        """Answer `prompt` about `image` and return the parsed {'question', 'answer'} dict."""
        try:
            self.donut.eval()
            with torch.no_grad():
                print(type(image), type(prompt), file=sys.stderr)
                # Preprocess the image into pixel values on the target device.
                pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(device)
                # Donut is prompted with the DocVQA task token followed by the question.
                question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
                decoder_input_ids = self.processor.tokenizer(
                    question,
                    add_special_tokens=False,
                    return_tensors="pt"
                )["input_ids"].to(device)
                outputs = self.donut.generate(
                    pixel_values,
                    decoder_input_ids=decoder_input_ids,
                    max_length=self.donut.decoder.config.max_position_embeddings,
                    early_stopping=True,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=1,
                    bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                    return_dict_in_generate=True
                )
                # Decode the generated sequence and convert Donut's tag format into JSON.
                return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
        except Exception:
            return {
                'question': prompt,
                'answer': 'Some error occurred during inference time.'
            }

model = OCRVQAModel({
    "donut": "ndtran/donut_ocr-vqa-200k",
    "processor": "ndtran/donut_ocr-vqa-200k",
    "config": "naver-clova-ix/donut-base-finetuned-docvqa"
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
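
# Example of querying the model directly, without the Gradio UI (hypothetical URL
# and question, for illustration only):
#
#     img = load_image_from_URL("https://example.com/book-cover.jpg")
#     result = model.inference(img, "Who is the author of this book?", device)
#     print(result.get("answer"))
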
def get_answer(image, url, question) -> str:
    """Answer the question using the URL image when a URL is given, otherwise the uploaded image."""
    global model, device
    if url is not None and url.startswith('http'):
        result = model.inference(load_image_from_URL(url), question, device)
        return result.get('answer', "I don't know :<")
    result = model.inference(image, question, device)
    return result.get('answer', "I don't know :<")

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            ## Donut-OCR-VQA
            - This demo uses a Donut model fine-tuned on the OCR-VQA-200k dataset to answer questions about images.
            ## Input / Output
            - Input: an image or an image URL (a book cover is recommended) and a question about information shown in the image.
            - Output: an answer to the question.
            """
        )
    with gr.Row():
        with gr.Column():
            image = gr.Image(shape=(224, 224), type="pil")
            image_url = gr.Textbox(lines=1, label="Image URL", placeholder="Or, paste the image URL here")
            question = gr.Textbox(lines=5, label="Question")
            ask = gr.Button("Get the answer")
        with gr.Column():
            answer = gr.Label(label="Answer")
    ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])

demo.launch()