Spaces:
Sleeping
Sleeping
Commit
·
e972659
1
Parent(s):
6fa0eea
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
import argparse
|
5 |
+
import streamlit as st
|
6 |
+
import nltk
|
7 |
+
import evaluate
|
8 |
+
from PIL import Image
|
9 |
+
from transformers import AutoProcessor
|
10 |
+
from transformers import VisionEncoderDecoderModel
|
11 |
+
from src.utils import common_utils
|
12 |
+
from nltk import edit_distance as compute_edit_distance
|
13 |
+
from src.utils.common_utils import compute_exprate
|
14 |
+
|
15 |
+
# Hugging Face `evaluate` metrics, loaded once at import time in the order
# BLEU, WER, exact match.
bleu_func, wer_func, exact_match_func = (
    evaluate.load(metric_name)
    for metric_name in ("bleu", "wer", "exact_match")
)

# Timestamped, level-aligned log lines at INFO verbosity.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
24 |
+
|
25 |
+
|
26 |
+
def main(args):
    """Run the Streamlit demo that turns a formula image into LaTeX.

    Loads the model and processor from ``args.ckpt``, renders the upload and
    control widgets and, when "Run" is pressed with an uploaded image, shows
    the predicted LaTeX. When "Enable testing with label" is on and a label
    is entered, benchmark metrics for the prediction are shown as well.

    Args:
        args: Parsed command-line namespace; only ``args.ckpt`` is read.
    """

    @st.cache_resource
    def init_model():
        """Load the model, processor and decoder start ids.

        Decorated with ``st.cache_resource`` so Streamlit can reuse the
        loaded objects instead of reloading them on every script rerun.

        Returns:
            Tuple ``(model, processor, decoder_input_ids, device)``.
        """
        # Get the device
        device = common_utils.check_device(logger)
        # Init model
        logger.info("Load model & processor from: %s", args.ckpt)
        model = VisionEncoderDecoderModel.from_pretrained(args.ckpt).to(device)
        # Load processor
        processor = AutoProcessor.from_pretrained(args.ckpt)
        # Generation is primed with the tokenizer's BOS token only.
        task_prompt = processor.tokenizer.bos_token
        decoder_input_ids = processor.tokenizer(
            task_prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids
        return model, processor, decoder_input_ids, device

    model, processor, decoder_input_ids, device = init_model()

    @st.cache_data
    def inference(input_image):
        """Predict the LaTeX string for one image.

        Args:
            input_image: Anything ``PIL.Image.open`` accepts; the UI passes
                the uploaded file object.

        Returns:
            The decoded sequence with the EOS, PAD and BOS token strings
            removed.
        """
        # Load image
        logger.info("\nLoad image from: %s", input_image)
        image = Image.open(input_image)
        if image.mode != "RGB":
            image = image.convert('RGB')
        pixel_values = processor.image_processor(
            image,
            return_tensors="pt",
            data_format="channels_first",
        ).pixel_values
        # Generate LaTeX expression
        with torch.no_grad():
            outputs = model.generate(
                pixel_values.to(device),
                decoder_input_ids=decoder_input_ids.to(device),
                max_length=model.decoder.config.max_length,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=4,
                # Forbid the unknown token in the generated output.
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
        sequence = processor.tokenizer.batch_decode(outputs.sequences)[0]
        # Strip special tokens in the same order as before: EOS, PAD, BOS.
        for special_token in (
            processor.tokenizer.eos_token,
            processor.tokenizer.pad_token,
            processor.tokenizer.bos_token,
        ):
            sequence = sequence.replace(special_token, "")
        logger.info("Output: %s", sequence)
        return sequence

    @st.cache_data
    def compute_crohme_metrics(label_str, pred_str):
        """Score one prediction with the CROHME-style metrics.

        Args:
            label_str: Reference LaTeX string.
            pred_str: Predicted LaTeX string.

        Returns:
            Tuple ``(wer, exprate, error_1, error_2, error_3)``, each scaled
            to a percentage and rounded to two decimals.
        """
        wer = wer_func.compute(predictions=[pred_str], references=[label_str])
        # Compute expression rate score
        # NOTE(review): error_1..3 are presumably the expression rates that
        # tolerate up to 1/2/3 errors; confirm in src.utils.common_utils.
        exprate, error_1, error_2, error_3 = compute_exprate(
            predictions=[pred_str],
            references=[label_str],
        )
        return tuple(
            round(score * 100, 2)
            for score in (wer, exprate, error_1, error_2, error_3)
        )

    @st.cache_data
    def compute_img2latex100k_metrics(label_str, pred_str):
        """Score one prediction with the Im2latex-100k-style metrics.

        Args:
            label_str: Reference LaTeX string.
            pred_str: Predicted LaTeX string.

        Returns:
            Tuple ``(bleu, edit_distance, exact_match)`` as percentages
            rounded to two decimals. ``edit_distance`` is a similarity:
            100 means the strings are identical.
        """
        # Compute edit distance score, normalised by the longer string.
        longest = max(len(pred_str), len(label_str))
        if longest:
            distance_ratio = compute_edit_distance(pred_str, label_str) / longest
        else:
            # Both strings are empty: they are identical, and dividing by
            # the zero length would raise ZeroDivisionError.
            distance_ratio = 0.0
        # Convert the minimum edit distance ratio into a similarity score
        edit_distance = round((1 - distance_ratio) * 100, 2)
        # Compute bleu score
        bleu = bleu_func.compute(
            predictions=[pred_str],
            references=[label_str],
            max_order=4,  # Maximum n-gram order to use when computing BLEU score
        )
        bleu = round(bleu['bleu'] * 100, 2)
        exact_match = exact_match_func.compute(
            predictions=[pred_str],
            references=[label_str],
        )
        exact_match = round(exact_match['exact_match'] * 100, 2)
        return bleu, edit_distance, exact_match

    # --------------------------------- Streamlit code ---------------------------------

    st.markdown("<h1 style='text-align: center; color: LightSkyBlue;'>Math Formula Images To LaTeX Code Based On End-to-End Approach With Attention Mechanism</h1>", unsafe_allow_html=True)
    # Three empty writes act as vertical spacing under the title.
    for _ in range(3):
        st.write('')
    st.header('Input', divider='blue')
    uploaded_file = st.file_uploader(
        "Upload an image",
        type=['png', 'jpg'],
    )
    if uploaded_file is not None:
        # Preview the uploaded image.
        bytes_data = uploaded_file.read()
        st.image(
            bytes_data,
            width=700,
            channels='RGB',
            output_format='PNG',
        )
    on = st.toggle('Enable testing with label')

    def show_prediction(spacer=False):
        """Run inference on the uploaded image, render it, and return it.

        Args:
            spacer: When true, emit one empty write before the output header
                (the unlabeled flow does this; the labeled flows do not).
        """
        pred_str = inference(uploaded_file)
        if spacer:
            st.write('')
        st.header('Output', divider='blue')
        st.latex(pred_str)
        st.write(':orange[Latex sequences:]', pred_str)
        return pred_str

    if on:
        with st.container(border=True):
            option = st.selectbox(
                'Benchmark ?',
                ('Im2latex-100k', 'CROHME'))
            label = st.text_input('Label', None)
        # NOTE(review): the Run button is placed outside the bordered
        # container; confirm this matches the intended layout.
        run = st.button("Run")

        # An empty label is treated like a missing one: metrics against an
        # empty reference are not meaningful.
        if run and uploaded_file is not None and label:
            if option == 'Im2latex-100k':
                pred_str = show_prediction()
                bleu, edit_distance, exact_match = compute_img2latex100k_metrics(label, pred_str)
                with st.container(border=True):
                    col1, col2, col3 = st.columns(3)
                    col1.metric("Bleu", bleu)
                    col2.metric("Edit Distance", edit_distance)
                    col3.metric("Exact Match", exact_match)
            elif option == 'CROHME':
                pred_str = show_prediction()
                wer, exprate, error_1, error_2, error_3 = compute_crohme_metrics(label, pred_str)
                with st.container(border=True):
                    col1, col2, col3, col4, col5 = st.columns(5)
                    col1.metric("ExpRate", exprate)
                    col2.metric("ExpRate 1", error_1)
                    col3.metric("ExpRate 2", error_2)
                    col4.metric("ExpRate 3", error_3)
                    col5.metric("WER", wer)
    else:
        run = st.button("Run")
        if run and uploaded_file is not None:
            show_prediction(spacer=True)
|
179 |
+
|
180 |
+
|
181 |
+
if __name__ == "__main__":
    # Command-line interface: the only option is the checkpoint location,
    # which is handed to main() on the parsed namespace.
    cli = argparse.ArgumentParser(description="Sumen Latex OCR")
    cli.add_argument(
        "--ckpt",
        default="checkpoints",
        type=str,
        help="Path to the checkpoint",
    )
    args = cli.parse_args()
    main(args)
|