import sys

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig


def load_image_from_URL(url):
    # Fetch the image once (streaming) instead of issuing two separate requests.
    res = requests.get(url, stream=True)
    if res.status_code == 200:
        img = Image.open(res.raw)
        if img.mode == "RGBA":
            img = img.convert("RGB")
        return img
    return None


class OCRVQAModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model_name_or_path = config["donut"]
        self.processor_name_or_path = config["processor"]
        self.config_name_or_path = config["config"]

        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        self.donut_config.encoder.image_size = [800, 600]
        self.donut_config.decoder.max_length = 64

        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(
            self.model_name_or_path, config=self.donut_config
        )
        self.added_tokens = set()
        self.setup()

    def add_tokens(self, list_of_tokens):
        # Register extra tokens and resize the decoder embeddings to match the tokenizer.
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def setup(self):
        # NOTE: the original token strings were stripped from the source (empty strings remained);
        # "<yes/>" and "<no/>", the tokens used in the Donut VQA examples, are assumed here.
        self.add_tokens(["<yes/>", "<no/>"])
        # The feature extractor expects (width, height); the config stores (height, width).
        self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
        self.processor.feature_extractor.do_align_long_axis = False

    def inference(self, image, prompt, device):
        self.donut.eval()
        with torch.no_grad():
            print(type(image), type(prompt), file=sys.stderr)
            image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)

            # NOTE: the task-prompt tags were also stripped from the source; the standard
            # Donut DocVQA prompt format is assumed here.
            question = f"<s_docvqa><s_question>{prompt}</s_question><s_answer>"
            embedded_question = self.processor.tokenizer(
                question, add_special_tokens=False, return_tensors="pt"
            )["input_ids"].to(device)

            outputs = self.donut.generate(
                image_ids,
                decoder_input_ids=embedded_question,
                max_length=self.donut.decoder.config.max_position_embeddings,
                early_stopping=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

            return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])


model = OCRVQAModel({
    "donut": "ndtran/donut_ocr-vqa-200k",
    "processor": "ndtran/donut_ocr-vqa-200k",
    "config": "naver-clova-ix/donut-base-finetuned-docvqa",
})

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def get_answer(image, url, question) -> str:
    # Prefer the URL input when one is provided; otherwise use the uploaded image.
    if url and url.startswith(("http://", "https://")):
        result = model.inference(load_image_from_URL(url), question, device)
    else:
        result = model.inference(image, question, device)
    return result.get("answer", "I don't know :<")


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            ## Donut-OCR-VQA

            - This demo uses a Donut model fine-tuned on the OCR-VQA-200k dataset to answer questions about images.

            ## IO description

            - Input: an image of a book cover (recommended) or an image URL, plus a question about information shown on the image.
            - Output: an answer to the question.
""" ) with gr.Row(): with gr.Column(): image = gr.Image(shape=(224, 224), type="pil", label="Pick an image") image_url = gr.Textbox(lines=1, label="Or use this option!", placeholder="Enter the image URL here") question = gr.Textbox(lines=5, label="Question") ask = gr.Button(label="Get the answer") with gr.Column(): answer = gr.Label(label="Answer") ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer]) demo.launch()