---
license: apache-2.0
pipeline_tag: image-to-text
datasets:
- hoang-quoc-trung/fusion-image-to-latex-datasets
tags:
- img2latex
- latex ocr
- Printed Mathematical Expression Recognition
- Handwritten Mathematical Expression Recognition
---

# <font color="turquoise"> <p style="text-align:center"> Translating Math Formula Images To LaTeX Sequences </p> </font>

Scaling Up Image-to-LaTeX Performance: Sumen, An End-to-End Transformer Model With a Large Dataset

![image/png](https://cdn-uploads.huggingface.co/production/uploads/63d5f1d316cb9f98f4cfd5d1/fz6CWQtJyhbkH-VpBM8_k.png)
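
The diagram above shows Sumen's end-to-end encoder-decoder pipeline. Because the checkpoint loads as a standard Hugging Face `VisionEncoderDecoderModel`, the concrete encoder and decoder backing the pipeline can be read straight off the config; a minimal sketch:

```python
from transformers import VisionEncoderDecoderModel

# Load the checkpoint and report the underlying encoder/decoder model types
model = VisionEncoderDecoderModel.from_pretrained('hoang-quoc-trung/sumen-base')
print("Encoder:", model.config.encoder.model_type)
print("Decoder:", model.config.decoder.model_type)
```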

## Performance

![image/png](https://cdn-uploads.huggingface.co/production/uploads/63d5f1d316cb9f98f4cfd5d1/9RbWAjkO8mCq9HvbyLXoa.png)

![image/png](https://cdn-uploads.huggingface.co/production/uploads/63d5f1d316cb9f98f4cfd5d1/cTSXT9-ZFN3XB6miuO69e.png)

## Uses

#### Source code: https://github.com/hoang-quoc-trung/sumen

#### Inference

```python
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

# Load model & processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionEncoderDecoderModel.from_pretrained('hoang-quoc-trung/sumen-base').to(device)
processor = AutoProcessor.from_pretrained('hoang-quoc-trung/sumen-base')

# The decoder is primed with the BOS token as the task prompt
task_prompt = processor.tokenizer.bos_token
decoder_input_ids = processor.tokenizer(
    task_prompt,
    add_special_tokens=False,
    return_tensors="pt",
).input_ids

# Load image
img_url = 'https://raw.githubusercontent.com/hoang-quoc-trung/sumen/main/assets/example_1.png'
image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
pixel_values = processor.image_processor(
    image,
    return_tensors="pt",
    data_format="channels_first",
).pixel_values

# Generate LaTeX expression with beam search, blocking <unk> from the output
with torch.no_grad():
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_length,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

# Decode, then strip the BOS/EOS/PAD tokens wrapping the prediction
sequence = processor.tokenizer.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(
    processor.tokenizer.eos_token, ""
).replace(
    processor.tokenizer.pad_token, ""
).replace(processor.tokenizer.bos_token, "")
print(sequence)
```
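
To sanity-check a prediction, the decoded string can be rendered back to math and compared against the input image. A minimal sketch using matplotlib's built-in mathtext renderer (`render_latex` is an illustrative helper, not part of the repository; mathtext covers only a subset of LaTeX, so complex expressions may require a full LaTeX toolchain):

```python
import matplotlib.pyplot as plt

def render_latex(latex: str, out_path: str = "prediction.png") -> None:
    """Render a LaTeX math string to an image for visual comparison.

    Uses matplotlib's mathtext, which supports only a subset of LaTeX;
    unsupported commands will raise a ValueError.
    """
    fig = plt.figure(figsize=(6, 1))
    fig.text(0.5, 0.5, f"${latex}$", ha="center", va="center", fontsize=18)
    fig.savefig(out_path, bbox_inches="tight", dpi=200)
    plt.close(fig)

render_latex(sequence)  # `sequence` comes from the inference snippet above
```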