hoang-quoc-trung commited on
Commit
e972659
·
1 Parent(s): 6fa0eea

Add application file

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ import argparse
5
+ import streamlit as st
6
+ import nltk
7
+ import evaluate
8
+ from PIL import Image
9
+ from transformers import AutoProcessor
10
+ from transformers import VisionEncoderDecoderModel
11
+ from src.utils import common_utils
12
+ from nltk import edit_distance as compute_edit_distance
13
+ from src.utils.common_utils import compute_exprate
14
+
15
+ bleu_func = evaluate.load("bleu")
16
+ wer_func = evaluate.load("wer")
17
+ exact_match_func = evaluate.load("exact_match")
18
+
19
+ logging.basicConfig(
20
+ level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s"
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+ logger.setLevel(logging.INFO)
24
+
25
+
26
+ def main(args):
27
+ @st.cache_resource
28
+ def init_model():
29
+ # Get the device
30
+ device = common_utils.check_device(logger)
31
+ # Init model
32
+ logger.info("Load model & processor from: {}".format(args.ckpt))
33
+ model = VisionEncoderDecoderModel.from_pretrained(
34
+ args.ckpt
35
+ ).to(device)
36
+ # Load processor
37
+ processor = AutoProcessor.from_pretrained(args.ckpt)
38
+ task_prompt = processor.tokenizer.bos_token
39
+ decoder_input_ids = processor.tokenizer(
40
+ task_prompt,
41
+ add_special_tokens=False,
42
+ return_tensors="pt"
43
+ ).input_ids
44
+ return model, processor, decoder_input_ids, device
45
+
46
+ model, processor, decoder_input_ids, device = init_model()
47
+
48
+ @st.cache_data
49
+ def inference(input_image):
50
+ # Load image
51
+ logger.info("\nLoad image from: {}".format(input_image))
52
+ image = Image.open(input_image)
53
+ if not image.mode == "RGB":
54
+ image = image.convert('RGB')
55
+ pixel_values = processor.image_processor(
56
+ image,
57
+ return_tensors="pt",
58
+ data_format="channels_first",
59
+ ).pixel_values
60
+ # Generate LaTeX expression
61
+ with torch.no_grad():
62
+ outputs = model.generate(
63
+ pixel_values.to(device),
64
+ decoder_input_ids=decoder_input_ids.to(device),
65
+ max_length=model.decoder.config.max_length,
66
+ pad_token_id=processor.tokenizer.pad_token_id,
67
+ eos_token_id=processor.tokenizer.eos_token_id,
68
+ use_cache=True,
69
+ num_beams=4,
70
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
71
+ return_dict_in_generate=True,
72
+ )
73
+ sequence = processor.tokenizer.batch_decode(outputs.sequences)[0]
74
+ sequence = sequence.replace(
75
+ processor.tokenizer.eos_token, ""
76
+ ).replace(
77
+ processor.tokenizer.pad_token, ""
78
+ ).replace(processor.tokenizer.bos_token,"")
79
+ logger.info("Output: {}".format(sequence))
80
+ return sequence
81
+
82
+ @st.cache_data
83
+ def compute_crohme_metrics(label_str, pred_str):
84
+ wer = wer_func.compute(predictions=[pred_str], references=[label_str])
85
+ # Compute expression rate score
86
+ exprate, error_1, error_2, error_3 = compute_exprate(
87
+ predictions=[pred_str],
88
+ references=[label_str]
89
+ )
90
+ return round(wer*100, 2), round(exprate*100, 2), round(error_1*100, 2), round(error_2*100, 2), round(error_3*100, 2)
91
+
92
+
93
+ @st.cache_data
94
+ def compute_img2latex100k_metrics(label_str, pred_str):
95
+ # Compute edit distance score
96
+ edit_distance = compute_edit_distance(
97
+ pred_str,
98
+ label_str
99
+ )/max(len(pred_str),len(label_str))
100
+ # Convert minimun edit distance score to maximun edit distance score
101
+ edit_distance = round((1 - edit_distance)*100, 2)
102
+ # Compute bleu score
103
+ bleu = bleu_func.compute(
104
+ predictions=[pred_str],
105
+ references=[label_str],
106
+ max_order=4 # Maximum n-gram order to use when computing BLEU score
107
+ )
108
+ bleu = round(bleu['bleu']*100, 2)
109
+ exact_match = exact_match_func.compute(
110
+ predictions=[pred_str],
111
+ references=[label_str]
112
+ )
113
+ exact_match = round(exact_match['exact_match']*100, 2)
114
+ return bleu, edit_distance, exact_match
115
+
116
+ # --------------------------------- Sreamlit code ---------------------------------
117
+
118
+ st.markdown("<h1 style='text-align: center; color: LightSkyBlue;'>Math Formula Images To LaTeX Code Based On End-to-End Approach With Attention Mechanism</h1>", unsafe_allow_html=True)
119
+ st.write('')
120
+ st.write('')
121
+ st.write('')
122
+ st.header('Input', divider='blue')
123
+ uploaded_file = st.file_uploader(
124
+ "Upload an image",
125
+ type = ['png', 'jpg'],
126
+ )
127
+ if uploaded_file is not None:
128
+ bytes_data = uploaded_file.read()
129
+ st.image(
130
+ bytes_data,
131
+ width = 700,
132
+ channels = 'RGB',
133
+ output_format = 'PNG'
134
+ )
135
+ on = st.toggle('Enable testing with label')
136
+
137
+ if on:
138
+ with st.container(border=True):
139
+ option = st.selectbox(
140
+ 'Benchmark ?',
141
+ ('Im2latex-100k', 'CROHME'))
142
+ label = st.text_input('Label', None)
143
+ run = st.button("Run")
144
+
145
+ if run is True and uploaded_file is not None and label is not None and option == 'Im2latex-100k':
146
+ pred_str = inference(uploaded_file)
147
+ st.header('Output', divider='blue')
148
+ st.latex(pred_str)
149
+ st.write(':orange[Latex sequences:]', pred_str)
150
+ bleu, edit_distance, exact_match = compute_img2latex100k_metrics(label, pred_str)
151
+ with st.container(border=True):
152
+ col1, col2, col3 = st.columns(3)
153
+ col1.metric("Bleu", bleu)
154
+ col2.metric("Edit Distance", edit_distance)
155
+ col3.metric("Exact Match", exact_match)
156
+
157
+ if run is True and uploaded_file is not None and label is not None and option == 'CROHME':
158
+ pred_str = inference(uploaded_file)
159
+ st.header('Output', divider='blue')
160
+ st.latex(pred_str)
161
+ st.write(':orange[Latex sequences:]', pred_str)
162
+ wer, exprate, error_1, error_2, error_3 = compute_crohme_metrics(label, pred_str)
163
+ with st.container(border=True):
164
+ col1, col2, col3, col4, col5 = st.columns(5)
165
+ col1.metric("ExpRate", exprate)
166
+ col2.metric("ExpRate 1", error_1)
167
+ col3.metric("ExpRate 2", error_2)
168
+ col4.metric("ExpRate 3", error_3)
169
+ col5.metric("WER", wer)
170
+
171
+ else:
172
+ run = st.button("Run")
173
+ if run is True and uploaded_file is not None:
174
+ pred_str = inference(uploaded_file)
175
+ st.write('')
176
+ st.header('Output', divider='blue')
177
+ st.latex(pred_str)
178
+ st.write(':orange[Latex sequences:]', pred_str)
179
+
180
+
181
+ if __name__ == "__main__":
182
+ parser = argparse.ArgumentParser(description="Sumen Latex OCR")
183
+ parser.add_argument(
184
+ "--ckpt",
185
+ type=str,
186
+ default="checkpoints",
187
+ help="Path to the checkpoint",
188
+ )
189
+ args = parser.parse_args()
190
+ main(args)