import textwrap
from typing import List

import mols2grid
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from ipywidgets import interact, widgets  # NOTE(review): unused in this file; kept to avoid breaking anything relying on it
from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings
from moses.utils import mapper, get_mol
from transformers import EncoderDecoderModel, RobertaTokenizer

from util import filter_dataframe


@st.cache(suppress_st_warning=True)
def load_models():
    """Load the two WarmMolGen encoder-decoder models from the Hugging Face hub.

    Wrapped in ``st.cache`` so the models are loaded once and reused across reruns.

    Returns:
        ``(model1, model2)``: the "gokceuludogan/WarmMolGenOne" and
        "gokceuludogan/WarmMolGenTwo" ``EncoderDecoderModel`` instances.
    """
    model1 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenOne")
    model2 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenTwo")
    return model1, model2


# allow_output_mutation=True tells st.cache not to hash the returned tokenizer
# objects on every rerun; they are only read from, never modified, by this app.
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_tokenizers():
    """Load the protein (input) and SMILES (output) tokenizers once and cache them.

    Returns:
        ``(protein_tokenizer, mol_tokenizer)``.
    """
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    return protein_tokenizer, mol_tokenizer


def count(smiles_list: List[str]):
    """Return the character length of every SMILES string in ``smiles_list``."""
    return [len(smiles) for smiles in smiles_list]


def remove_none_elements(mol_list, smiles_list):
    """Drop ``None`` molecules together with the SMILES at the same positions.

    Args:
        mol_list: parsed molecules; ``None`` marks an entry to discard.
        smiles_list: SMILES strings, index-aligned with ``mol_list``.

    Returns:
        ``(filtered_mol_list, filtered_smiles_list, removed_len)`` where
        ``removed_len`` is the number of ``None`` entries found in ``mol_list``.
    """
    # A set gives O(1) membership tests; the previous list made this O(n^2).
    none_indices = {i for i, mol in enumerate(mol_list) if mol is None}
    filtered_mol_list = [mol for mol in mol_list if mol is not None]
    filtered_smiles_list = [
        smiles for i, smiles in enumerate(smiles_list) if i not in none_indices
    ]
    return filtered_mol_list, filtered_smiles_list, len(none_indices)


def format_list_numbers(lst):
    """Round every number in ``lst`` to 3 decimal places, in place (returns None)."""
    for i, value in enumerate(lst):
        lst[i] = float("{:.3f}".format(value))


def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate molecules for a protein ``target`` and score the valid ones.

    Writes a "Generated Molecules" header (and a note when invalid outputs are
    dropped) to the Streamlit page as a side effect.

    Args:
        model_name: "WarmMolGenOne" selects the first model; any other value
            selects WarmMolGenTwo.
        num_mols: passed to ``generate`` as ``num_return_sequences``.
        max_new_tokens: passed to ``generate`` as ``max_length``.
        do_sample: whether ``generate`` samples instead of using beam search.
        num_beams: beam width for ``generate`` (``None`` keeps the library default).
        target: the protein sequence fed to the encoder.
        pool: worker spec handed to ``moses.utils.mapper`` (1 = run serially).

    Returns:
        A DataFrame with columns SMILES, QED, SA, logP, NP and Weight. Scores
        are rounded to 3 decimals; rows whose SMILES could not be parsed are removed.
    """
    protein_tokenizer, mol_tokenizer = load_tokenizers()
    model1, model2 = load_models()
    model = model1 if model_name == 'WarmMolGenOne' else model2

    inputs = protein_tokenizer(target, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        decoder_start_token_id=mol_tokenizer.bos_token_id,
        eos_token_id=mol_tokenizer.eos_token_id,
        pad_token_id=mol_tokenizer.eos_token_id,
        max_length=int(max_new_tokens),
        num_return_sequences=int(num_mols),
        do_sample=do_sample,
        num_beams=num_beams,
    )
    output_smiles = mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    st.write("### Generated Molecules")

    # get_mol yields None for SMILES it cannot parse; those are dropped below.
    # mapper(pool) is rebuilt for every call on purpose: a mapper backed by a
    # worker pool is presumably not reusable after one map (TODO confirm in moses).
    mol_list = mapper(pool)(get_mol, output_smiles)
    mol_list, output_smiles, removed_len = remove_none_elements(mol_list, output_smiles)
    if removed_len != 0:
        st.write(f"#### Note that: {removed_len} numbers of generated invalid molecules are discarded.")

    metrics = (("QED", QED), ("SA", SA), ("logP", logP), ("NP", NP), ("Weight", weight))
    scores = {}
    for column, metric in metrics:
        values = mapper(pool)(metric, mol_list)
        format_list_numbers(values)
        scores[column] = values

    return pd.DataFrame({'SMILES': output_smiles, **scores})


def warm_molgen_demo():
    """Render the WarmMolGen demo: parameter form, molecule grid and code snippet."""
    with st.form("my_form"):
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")

            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=[
                    "WarmMolGenOne",
                    "WarmMolGenTwo",
                ],
                index=0,
            )
            # min_value=1: asking generate() for 0 sequences / 0 beams / length 0
            # is not a meaningful request.
            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                min_value=1,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )
            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=1,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )

        target = st.text_area(
            "Target Sequence",
            "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
        )
        generate_new_molecules = st.form_submit_button("Generate Molecules")

    # Beam search keeps one beam per requested molecule; sampling uses the default.
    num_beams = None if do_sample is True else int(num_mols)
    pool = 1

    # Regenerate on an explicit submit, or on the first run when nothing is cached yet.
    if generate_new_molecules or 'df' not in st.session_state:
        st.session_state.df = generate_molecules(
            model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool
        )

    df = st.session_state.df
    filtered_df = filter_dataframe(df)
    if filtered_df.empty:
        st.markdown(
            """
            No molecules were found with specified properties.
            """,
            unsafe_allow_html=True,
        )
    else:
        raw_html = mols2grid.display(filtered_df, height=1000)._repr_html_()
        components.html(raw_html, width=900, height=450, scrolling=True)

    st.markdown("## How to Generate")
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
    inputs = protein_tokenizer("{target}", return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
                             max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    st.code(textwrap.dedent(generation_code))


st.set_page_config(page_title="WarmMolGen Demo", page_icon="🔥", layout='wide')
st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")
st.markdown(
    """
    This demo illustrates WarmMolGen models' generation capabilities.
    Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
    Please enter an input sequence below 👇 and configure parameters from the sidebar 👈 to generate molecules!
    See below for saving the output molecules and the code snippet generating them!
    """
)

warm_molgen_demo()