| import gradio as gr | |
| import os | |
| import glob | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from rxnscribe import RxnScribe | |
| from huggingface_hub import hf_hub_download | |
# One-time model setup at import time, shared by every request:
# fetch the pretrained RxnScribe checkpoint from the Hugging Face Hub and
# load it for CPU inference.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
device = torch.device("cpu")
model = RxnScribe(ckpt_path, device)
def get_markdown(reaction):
    """Render one predicted reaction as three Markdown table cells.

    Args:
        reaction: a single prediction dict; it is indexed by the keys
            'reactants', 'conditions' and 'products', each holding a list of
            entity dicts.

    Returns:
        A list of three Markdown strings, in the order reactants, conditions,
        products. Each entity contributes:
          * its 'smiles' string in a fenced code block, if present;
          * otherwise its 'text' tokens joined by spaces, followed by '<br>';
          * otherwise its bare 'category' label, followed by '<br>'.
    """
    cells = []
    for role in ('reactants', 'conditions', 'products'):
        parts = []
        for ent in reaction[role]:
            if 'smiles' in ent:
                parts.append("\n```\n" + ent['smiles'] + "\n```\n")
            elif 'text' in ent:
                parts.append(' '.join(ent['text']) + '<br>')
            else:
                # Category-only entities (presumably when neither MolScribe nor
                # OCR was run) need a separator just like text entities,
                # otherwise adjacent labels run together, e.g. "[Mol][Mol]".
                parts.append(ent['category'] + '<br>')
        cells.append(''.join(parts))
    return cells
def predict(image, molscribe=False, ocr=False):
    """Run RxnScribe on one reaction diagram and format the results for the UI.

    Args:
        image: the uploaded diagram from the image input; None when the user
            submits without uploading anything.
        molscribe: whether to ask the model to also run MolScribe. Defaults to
            False so the function can be called with the image alone, as
            gr.Examples does (it registers only the image as input).
        ocr: whether to ask the model to also run OCR. Defaults to False for
            the same reason.

    Returns:
        A (pred_image, rows) pair:
          * pred_image is the result of model.draw_predictions_combined on the
            predictions and the input image;
          * rows holds one [index, reactants, conditions, products] list per
            predicted reaction. The index starts at 0, and the other three
            cells are Markdown produced by get_markdown.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a traceback from inside the model.
    """
    if image is None:
        raise gr.Error("Please upload a reaction diagram first.")
    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_image = model.draw_predictions_combined(predictions, image=image)
    rows = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
    return pred_image, rows
with gr.Blocks() as demo:
    # Header and usage notes. The text is kept flush-left inside the string so
    # no line can be mistaken for an indented Markdown code block.
    # NOTE(review): the author mailto target reads "[email protected]", which
    # looks like an email-obfuscation placeholder, not a real address; restore it.
    gr.Markdown("""
<center> <h1>RxnScribe</h1> </center>

Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram.

All predicted reactions are drawn together on a single image.
<b style="color:red">Red boxes are <i><u style="color:red">reactants</u></i>.</b>
<b style="color:green">Green boxes are <i><u style="color:green">reaction conditions</u></i>.</b>
<b style="color:blue">Blue boxes are <i><u style="color:blue">products</u></i>.</b>

It usually takes 5-10 seconds to process a diagram with this demo.
Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).

Paper: [RxnScribe: A Sequence Generation Model for Reaction Diagram Parsing](https://pubs.acs.org/doi/10.1021/acs.jcim.3c00439)

Code: [https://github.com/thomas0809/RxnScribe](https://github.com/thomas0809/RxnScribe)

Authors: [Yujie Qian](mailto:[email protected]), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
""")
    with gr.Column():
        with gr.Row():
            # type='pil' hands predict() a PIL image.
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
            btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            # Single output image: predict() returns one combined drawing.
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
        # One row per predicted reaction; the last three columns hold the
        # Markdown produced by get_markdown().
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False
        )
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])
    # Only the image is registered as an example input. `fn`/`outputs` are
    # presumably used only if example caching is enabled; in that case predict
    # is called with the image alone, so molscribe/ocr must have defaults there.
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=predict,
    )
# Start the Gradio server only when this file is executed directly as a script,
# so importing the module builds `demo` without blocking on launch().
if __name__ == "__main__":
    demo.launch()