Spaces:
Running
Running
import os | |
import gc | |
import random | |
import warnings | |
warnings.filterwarnings('ignore') | |
import numpy as np | |
import pandas as pd | |
import torch | |
import tokenizers | |
import transformers | |
from transformers import AutoTokenizer, EncoderDecoderModel, AutoModelForSeq2SeqLM | |
import sentencepiece | |
from rdkit import Chem | |
import rdkit | |
import streamlit as st | |
# Page header: the title, then the usage notes rendered top to bottom.
st.title('predictproduct-t5')
for note in (
    '##### At this space, you can predict the products of reactions from their inputs.',
    '##### The code expects input_data as a string or CSV file that contains an "input" column. The format of the string or contents of the column are like "REACTANT:{reactants of the reaction}REAGENT:{reagents, catalysts, or solvents of the reaction}".',
    '##### If there is no reagent, fill the blank with a space. And if there are multiple compounds, concatenate them with "."',
    '##### The output contains smiles of predicted products and sum of log-likelihood for each prediction. Predictions are ordered by their log-likelihood.(0th is the most probable product.) "valid compound" is the most probable and valid(can be recognized by RDKit) prediction.',
):
    st.markdown(note)

# Label shown above the free-text reaction input.
display_text = 'input the reaction smiles (e.g. REACTANT:COC(=O)C1=CCCN(C)C1.O.[Al+3].[H-].[Li+].[Na+].[OH-]REAGENT:C1CCOC1'

# Offer the example CSV for download; it is read from the working directory
# and re-serialized without the index column.
demo_csv = pd.read_csv('demo_input.csv').to_csv(index=False)
st.download_button(
    label='Download demo_input.csv',
    data=demo_csv,
    file_name='demo_input.csv',
    mime='text/csv',
)
class CFG:
    """Settings for one prediction run.

    The class body runs once per script execution, so the widget-backed
    attributes hold whatever the Streamlit widgets return on that run.
    """

    # Fixed settings.
    seed = 42
    model = 't5'
    model_name_or_path = 'sagawa/ZINC-t5-productpredicition'

    # Widget-backed settings; the widgets are created in this order.
    num_beams = st.number_input('num beams', min_value=1, max_value=10, value=5, step=1)
    num_return_sequences = num_beams  # every beam is returned
    uploaded_file = st.file_uploader('Choose a CSV file')
    input_data = st.text_area(display_text)
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def convert_df(df):
    """Return *df* serialized as CSV text without the index column."""
    return df.to_csv(index=False)


def _decode_smiles(tokenizer, token_ids):
    """Decode generated token ids into a SMILES string.

    All spaces are removed and trailing '.' separators are stripped. A SMILES
    string never contains whitespace, so the same cleanup is applied whether
    one or several beams are used.
    """
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return text.replace(' ', '').rstrip('.')


def _is_valid_smiles(smiles):
    """Return True when RDKit parses *smiles* into a molecule."""
    return isinstance(Chem.MolFromSmiles(smiles), rdkit.Chem.rdchem.Mol)


def _result_columns(num_beams):
    """Return the output column names matching a row from `_predict_row`."""
    if num_beams > 1:
        ranks = range(num_beams)
        return (
            ['input']
            + [f'{i}th' for i in ranks]
            + ['valid compound']
            + [f'{i}th score' for i in ranks]
            + ['valid compound score']
        )
    return ['input', '0th', 'valid compound']


def _predict_row(input_compound, tokenizer, model, device):
    """Generate product predictions for one reaction string.

    Args:
        input_compound: Reaction string passed to the tokenizer.
        tokenizer: Tokenizer loaded for the model.
        model: Seq2seq model used for generation.
        device: Torch device the inputs are moved to.

    Returns:
        A list laid out like `_result_columns(CFG.num_beams)`. With several
        beams: the input, every decoded candidate, the first candidate RDKit
        accepts (or None), every candidate score, and the score of that
        valid candidate (or None). With one beam: the input, the decoded
        prediction, and the prediction again if RDKit accepts it, else None.
    """
    inp = tokenizer(input_compound, return_tensors='pt').to(device)
    generated = model.generate(
        **inp,
        min_length=2,
        max_length=181,
        num_beams=CFG.num_beams,
        num_return_sequences=CFG.num_return_sequences,
        return_dict_in_generate=True,
        output_scores=True,
    )

    if CFG.num_beams > 1:
        scores = generated['sequences_scores'].tolist()
        candidates = [_decode_smiles(tokenizer, seq) for seq in generated['sequences']]
        # Candidates arrive best-first, so the first parseable one is the
        # "valid compound". The two cells are always emitted, as None when
        # nothing parses, so the row keeps the width of the column list.
        best = next(
            (i for i, smiles in enumerate(candidates) if _is_valid_smiles(smiles)),
            None,
        )
        if best is None:
            valid_compound, valid_score = None, None
        else:
            valid_compound, valid_score = candidates[best], scores[best]
        return [input_compound] + candidates + [valid_compound] + scores + [valid_score]

    prediction = _decode_smiles(tokenizer, generated['sequences'][0])
    valid_compound = prediction if _is_valid_smiles(prediction) else None
    return [input_compound, prediction, valid_compound]


if st.button('predict'):
    with st.spinner('Now processing. If num beams=5, this process takes about 15 seconds per reaction.'):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        seed_everything(seed=CFG.seed)

        # NOTE(review): tokenizer and model are loaded on every click;
        # consider caching them if the Streamlit version in use supports it.
        tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')
        if CFG.model == 't5':
            model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
        elif CFG.model == 'deberta':
            model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
        else:
            raise ValueError(f'Unsupported CFG.model: {CFG.model!r}')

        columns = _result_columns(CFG.num_beams)
        if CFG.uploaded_file is not None:
            # Batch mode: one result row per row of the uploaded CSV, taken
            # from its "input" column. Results are offered as a download only.
            input_data = pd.read_csv(CFG.uploaded_file)
            outputs = [
                _predict_row(row['input'], tokenizer, model, device)
                for _, row in input_data.iterrows()
            ]
            output_df = pd.DataFrame(outputs, columns=columns)
        else:
            # Single mode: predict for the text-area content and show the row.
            # Building the frame from a list keeps the scores numeric.
            output = _predict_row(CFG.input_data, tokenizer, model, device)
            output_df = pd.DataFrame([output], columns=columns)
            st.table(output_df)

        csv = convert_df(output_df)
        st.download_button(
            label="Download data as CSV",
            data=csv,
            file_name='output.csv',
            mime='text/csv',
        )