Spaces:
Running
Running
import os | |
import gc | |
import random | |
import warnings | |
warnings.filterwarnings('ignore') | |
import numpy as np | |
import pandas as pd | |
import torch | |
import tokenizers | |
import transformers | |
from transformers import AutoTokenizer, EncoderDecoderModel, AutoModelForSeq2SeqLM | |
import sentencepiece | |
from rdkit import Chem | |
import rdkit | |
import streamlit as st | |
class CFG:
    """Settings for the reaction product-prediction demo, kept as class attributes."""

    # Checkpoint identifier handed to `from_pretrained` for both tokenizer and model.
    model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
    # Architecture switch: 't5' loads AutoModelForSeq2SeqLM, 'deberta' loads EncoderDecoderModel.
    model = 't5'

    # Beam-search settings forwarded to `model.generate`.
    num_beams = 5
    num_return_sequences = 5

    # Seed passed to `seed_everything`.
    seed = 42

    # The Streamlit text area is created while this class body executes;
    # whatever it returns is stored as the reaction string to predict on.
    input_data = st.text_area('enter chemical reaction (e.g. REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O)')
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators with ``seed``, stores the seed in the
    ``PYTHONHASHSEED`` environment variable, and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    # Written to the environment only; the interpreter reads this variable at
    # start-up, so the running process's hashing is not affected by it.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
# ---------------------------------------------------------------------------
# Script body: load the model, generate product candidates for the reaction
# typed into the Streamlit text area, pick the best-ranked candidate that
# RDKit can parse, and render everything as a one-row table.
# ---------------------------------------------------------------------------
seed_everything(seed=CFG.seed)

tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')

if CFG.model == 't5':
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
elif CFG.model == 'deberta':
    model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
else:
    # Fail with a clear message instead of a NameError on `model` further down.
    raise ValueError(f"Unsupported CFG.model {CFG.model!r}; expected 't5' or 'deberta'.")

input_compound = CFG.input_data

# The original code used
#     min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
# which can never be positive: it is 0 when the estimate is non-negative and
# negative when the 'CATALYST' / ':' markers are missing (str.find returns -1)
# or the reactant part is short. A negative value also shrank max_length below
# 50 and could make it non-positive. Clamp to 0 so generate() always receives a
# non-negative minimum and a 50-token budget.
# NOTE(review): if max(...) was the intent, be aware the estimate counts
# characters whereas generate() counts tokens -- confirm before enabling it.
min_length = 0
max_length = min_length + 50

inp = tokenizer(input_compound, return_tensors='pt').to(device)
generated = model.generate(
    **inp,
    min_length=min_length,
    max_length=max_length,
    num_beams=CFG.num_beams,
    num_return_sequences=CFG.num_return_sequences,
    return_dict_in_generate=True,
    output_scores=True,
)
scores = generated['sequences_scores'].tolist()
candidates = [
    tokenizer.decode(seq, skip_special_tokens=True).replace('. ', '.').rstrip('.')
    for seq in generated['sequences']
]

# Pick the first candidate (in the order generate() returned them) that RDKit
# can parse. MolFromSmiles returns None for an unparsable string, so when no
# candidate parses both "valid compound" cells stay None. The original check
# `type(mol) == None` was never true, leaving the row two cells short of the
# column list in that case.
valid_compound, valid_score = None, None
for candidate, score in zip(candidates, scores):
    if Chem.MolFromSmiles(candidate) is not None:
        valid_compound, valid_score = candidate, score
        break

# Row layout: input | candidates | valid compound | scores | valid compound score
output = [input_compound] + candidates + [valid_compound] + scores + [valid_score]

# Size the column labels from the number of returned candidates (which is
# num_return_sequences, not num_beams) so the labels always match the row.
n_candidates = len(candidates)
columns = (
    ['input']
    + [f'{i}th' for i in range(n_candidates)]
    + ['valid compound']
    + [f'{i}th score' for i in range(n_candidates)]
    + ['valid compound score']
)
output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=columns)
st.table(output_df)