ndtran committed
Commit 53594dd · Parent: d1f4161

Update app.py

Files changed (1): app.py (+135, -0)
app.py CHANGED
@@ -0,0 +1,135 @@
+ import gradio as gr
+ import torch
+ import requests
+ from PIL import Image
+ from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
+
+ def load_image_from_URL(url):
+     # Fetch the image once, streaming so PIL can decode straight from the response.
+     res = requests.get(url, stream=True)
+
+     if res.status_code == 200:
+         img = Image.open(res.raw)
+
+         # Donut expects 3-channel input, so drop any alpha channel.
+         if img.mode == "RGBA":
+             img = img.convert("RGB")
+
+         return img
+
+     return None
+
+ class OCRVQAModel(torch.nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         self.model_name_or_path = config['donut']
+         self.processor_name_or_path = config['processor']
+         self.config_name_or_path = config['config']
+
+         # Resize the encoder input and cap the decoder length for this task.
+         self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
+         self.donut_config.encoder.image_size = [800, 600]
+         self.donut_config.decoder.max_length = 64
+
+         self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
+         self.donut = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path, config=self.donut_config)
+
+         self.added_tokens = set()
+         self.setup()
+
+     def add_tokens(self, list_of_tokens):
+         self.added_tokens.update(list_of_tokens)
+         newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
+
+         # Grow the decoder embedding table to cover any newly added tokens.
+         if newly_added_num > 0:
+             self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))
+
+     def setup(self):
+         self.add_tokens(["<yes/>", "<no/>"])
+         # The feature extractor takes (width, height), the reverse of image_size.
+         self.processor.feature_extractor.size = self.donut_config.encoder.image_size[::-1]
+         self.processor.feature_extractor.do_align_long_axis = False
+
+     def inference(self, image, prompt, device):
+         self.donut.eval()
+         with torch.no_grad():
+             image_ids = self.processor(image, return_tensors="pt").pixel_values.to(device)
+
+             # Donut conditions generation on a task prompt that embeds the question.
+             question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
+             embedded_question = self.processor.tokenizer(
+                 question,
+                 add_special_tokens=False,
+                 return_tensors="pt"
+             )["input_ids"].to(device)
+
+             outputs = self.donut.generate(
+                 image_ids,
+                 decoder_input_ids=embedded_question,
+                 max_length=self.donut.decoder.config.max_position_embeddings,
+                 early_stopping=True,
+                 pad_token_id=self.processor.tokenizer.pad_token_id,
+                 eos_token_id=self.processor.tokenizer.eos_token_id,
+                 use_cache=True,
+                 num_beams=1,
+                 bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+                 return_dict_in_generate=True
+             )
+
+             # Decode the generated sequence and convert Donut's tag format to JSON.
+             return self.processor.token2json(self.processor.batch_decode(outputs.sequences)[0])
+
+ model = OCRVQAModel({
+     "donut": "ndtran/donut_ocr-vqa-200k",
+     "processor": "ndtran/donut_ocr-vqa-200k",
+     "config": "naver-clova-ix/donut-base-finetuned-docvqa"
+ })
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = model.to(device)
+
+ def get_answer(image, image_url, question) -> str:
+     # Prefer the uploaded image; fall back to the URL when nothing was uploaded.
+     if image is None and image_url:
+         image = load_image_from_URL(image_url)
+     if image is None:
+         return "Please upload an image or paste a valid image URL."
+
+     result = model.inference(image, question, device)
+     return result.get('answer', "I don't know :<")
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             image_url = gr.Textbox(lines=1, label="Image URL", placeholder="Paste image URL here")
+             # Or upload an image from your computer.
+             image = gr.Image(shape=(224, 224), type="pil")
+
+         with gr.Column():
+             gr.Markdown(
+                 """
+                 # OCR-VQA-Donut
+                 This demo uses a fine-tuned OCR-VQA Donut model to answer questions about images.
+
+                 Feel free to try it out!
+                 """)
+             question = gr.Textbox(lines=5, label="Question")
+             answer = gr.Label(label="Answer")
+             ask = gr.Button("Get the answer")
+
+     # Components cannot be branched on at build time, so both inputs are passed
+     # to the handler, which chooses between the upload and the URL.
+     ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])
+
+ demo.launch()
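
For a quick sanity check outside the Gradio UI, the pieces above can be exercised directly. This is only a sketch: it reuses load_image_from_URL, model, and device exactly as defined in app.py, and the URL below is a placeholder, not an asset from this repo.

    # Hypothetical smoke test; replace the placeholder URL with any reachable image.
    img = load_image_from_URL("https://example.com/book-cover.jpg")
    if img is not None:
        print(model.inference(img, "What is the title of the book?", device))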