"""Streamlit app that predicts reaction products with a pretrained seq2seq model.

The script reads a reaction string from a text area, runs beam search with the
configured model, and renders one table row containing the input, every
returned candidate, the first candidate RDKit can parse, and the matching
scores.
"""
import os
import gc
import random
import warnings

# Silence all warnings; kept ahead of the third-party imports, as in the
# original statement order.
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import torch
import tokenizers
import transformers
from transformers import AutoTokenizer, EncoderDecoderModel, AutoModelForSeq2SeqLM
import sentencepiece
from rdkit import Chem
import rdkit
import streamlit as st


class CFG():
    """Run configuration, including the reaction string typed by the user."""

    # Evaluated once per script run: whatever is currently in the text area.
    input_data = st.text_area('enter chemical reaction (e.g. REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O)')
    model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
    # Selects the loader class below: 't5' or 'deberta'.
    model = 't5'
    num_beams = 5
    num_return_sequences = 5
    seed = 42


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def _first_valid_smiles(candidates, candidate_scores):
    """Return the first candidate RDKit can parse, together with its score.

    Args:
        candidates: Decoded SMILES strings, in the order they were returned.
        candidate_scores: Scores aligned index-for-index with ``candidates``.

    Returns:
        A ``(smiles, score)`` tuple, or ``(None, None)`` when no candidate
        parses, so the caller always gets exactly one pair to put in the table.
    """
    for smiles, score in zip(candidates, candidate_scores):
        # Chem.MolFromSmiles returns None for strings it cannot parse.
        if Chem.MolFromSmiles(smiles.rstrip('.')) is not None:
            return smiles.rstrip('.'), score
    return None, None


seed_everything(seed=CFG.seed)

tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')

if CFG.model == 't5':
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
elif CFG.model == 'deberta':
    model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
else:
    # Fail here with a clear message instead of a NameError on `model` later.
    raise ValueError(f"unsupported CFG.model: {CFG.model!r} (expected 't5' or 'deberta')")

input_compound = CFG.input_data

# NOTE(review): min(..., 0) can never be positive, so this value is 0 or
# negative and only shortens max_length below 50 for short or marker-less
# input. max() may have been intended, but that would force outputs about as
# long as the reactant text and change predictions -- confirm before changing.
min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)

inp = tokenizer(input_compound, return_tensors='pt').to(device)
generated = model.generate(
    **inp,
    min_length=min_length,
    max_length=min_length + 50,
    num_beams=CFG.num_beams,
    num_return_sequences=CFG.num_return_sequences,
    return_dict_in_generate=True,
    output_scores=True,
)
scores = generated['sequences_scores'].tolist()

# Collapse '. ' to '.' and drop trailing dots from each decoded sequence;
# presumably the tokenizer decodes the '.' fragment separator with a trailing
# space -- TODO confirm against the tokenizer.
candidates = [
    tokenizer.decode(sequence, skip_special_tokens=True).replace('. ', '.').rstrip('.')
    for sequence in generated['sequences']
]

valid_compound, valid_score = _first_valid_smiles(candidates, scores)

# One flat row: input, candidates, first valid candidate, then the scores in
# the same order.
row = [input_compound] + candidates + [valid_compound] + scores + [valid_score]

# Label by the number of candidates actually returned so the header always
# matches the row width.
n_candidates = len(candidates)
columns = (
    ['input']
    + [f'{i}th' for i in range(n_candidates)]
    + ['valid compound']
    + [f'{i}th score' for i in range(n_candidates)]
    + ['valid compound score']
)

output_df = pd.DataFrame(np.array(row).reshape(1, -1), columns=columns)
st.table(output_df)