WarmMolGenDemo / app.py
cankoban's picture
Upload app.py
37d70c1
raw
history blame
7.65 kB
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import mols2grid
from ipywidgets import interact, widgets
import textwrap
# import numpy as np
from transformers import EncoderDecoderModel, RobertaTokenizer
from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings
from moses.utils import mapper, get_mol
# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
from typing import List
from util import filter_dataframe
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_models():
    """Load both WarmMolGen encoder-decoder checkpoints, cached by ``st.cache``.

    Returns:
        tuple: ``(model1, model2)`` -- the ``gokceuludogan/WarmMolGenOne`` and
        ``gokceuludogan/WarmMolGenTwo`` models, in that order.
    """
    # allow_output_mutation=True: otherwise st.cache hashes the returned
    # objects on every call to detect mutation, which is costly for large
    # torch models. Nothing in this app mutates them (only .generate is called).
    model1 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenOne")
    model2 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenTwo")
    return model1, model2
def count(smiles_list: List[str]):
    """Return the character length of each SMILES string.

    Args:
        smiles_list: SMILES strings (any iterable of ``str``).

    Returns:
        list[int]: ``len(s)`` for every ``s`` in ``smiles_list``, in input order.
    """
    return [len(smiles) for smiles in smiles_list]
def remove_none_elements(mol_list, smiles_list):
    """Drop ``None`` molecules together with the SMILES at the same positions.

    Args:
        mol_list: parsed molecules; ``None`` marks an entry to discard.
        smiles_list: SMILES strings aligned index-by-index with ``mol_list``.
            Entries past the end of ``mol_list`` are kept unchanged.

    Returns:
        tuple: ``(filtered_mol_list, filtered_smiles_list, removed_len)`` where
        ``removed_len`` is the number of ``None`` entries found in ``mol_list``.
    """
    filtered_mol_list = []
    # A set keeps the per-SMILES membership test O(1); the previous list made
    # the second pass quadratic in the number of generated molecules.
    invalid_indices = set()
    for i, mol in enumerate(mol_list):
        if mol is None:
            invalid_indices.add(i)
        else:
            filtered_mol_list.append(mol)
    filtered_smiles_list = [
        smiles for i, smiles in enumerate(smiles_list) if i not in invalid_indices
    ]
    return filtered_mol_list, filtered_smiles_list, len(invalid_indices)
def format_list_numbers(lst):
    """Round every number in ``lst`` to three decimal places, in place.

    Each value is rendered with a ``.3f`` format and parsed back into a
    ``float``. The same list object is mutated; the function returns ``None``.
    """
    lst[:] = [float(f"{number:.3f}") for number in lst]
def _load_tokenizers():
    """Return ``(protein_tokenizer, mol_tokenizer)``, created at most once per Streamlit session."""
    key = "_warmmolgen_tokenizers"
    if key not in st.session_state:
        st.session_state[key] = (
            RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo"),
            RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k"),
        )
    return st.session_state[key]


def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate molecules for a protein ``target`` and score them with MOSES metrics.

    Args:
        model_name: ``'WarmMolGenOne'`` selects the first model; any other
            value selects ``WarmMolGenTwo``.
        num_mols: number of sequences to return (``num_return_sequences``).
        max_new_tokens: passed to ``model.generate`` as ``max_length``.
        do_sample: whether ``generate`` samples instead of beam decoding.
        num_beams: beam width forwarded to ``generate`` (``None`` when sampling).
        target: protein sequence fed to the protein tokenizer.
        pool: argument forwarded to ``moses.utils.mapper``.

    Returns:
        pandas.DataFrame with columns ``SMILES, QED, SA, logP, NP, Weight``.
        It has one row per generated SMILES whose ``get_mol`` result is not
        ``None``. Property values are rounded to 3 decimals.

    Side effects:
        Writes a "Generated Molecules" header to the page. If any SMILES were
        discarded, it also writes a note saying how many.
    """
    # Tokenizers are reused within a session instead of being re-created by
    # from_pretrained on every click; the models come from the st.cache'd loader.
    protein_tokenizer, mol_tokenizer = _load_tokenizers()
    model1, model2 = load_models()
    model = model1 if model_name == 'WarmMolGenOne' else model2

    inputs = protein_tokenizer(target, return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
                             max_length=int(max_new_tokens), num_return_sequences=int(num_mols),
                             do_sample=do_sample, num_beams=num_beams)
    output_smiles = mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    st.write("### Generated Molecules")

    mol_list = mapper(pool)(get_mol, output_smiles)
    mol_list, output_smiles, removed_len = remove_none_elements(mol_list, output_smiles)
    if removed_len != 0:
        st.write(f"#### Note that: {removed_len} numbers of generated invalid molecules are discarded.")

    # Column name -> scoring function; dict insertion order fixes the column order.
    columns = {"SMILES": output_smiles}
    for column_name, scorer in (("QED", QED), ("SA", SA), ("logP", logP), ("NP", NP), ("Weight", weight)):
        # A fresh mapper(pool) per property, as before.
        # NOTE(review): moses' mapper appears to hand out a single-use Pool
        # when pool > 1, so confirm before hoisting it out of the loop.
        scores = mapper(pool)(scorer, mol_list)
        format_list_numbers(scores)
        columns[column_name] = scores
    return pd.DataFrame(columns)
def warm_molgen_demo():
    """Render the demo: parameter widgets, generation, filtered results grid and code snippet.

    Generation runs when "Generate Molecules" is pressed, and also when
    ``st.session_state`` has no ``df`` yet (first run of a session). The
    latest result is kept in ``st.session_state.df``, so reruns triggered by
    ``filter_dataframe`` reuse it instead of generating again.
    """
    import json  # renders ``target`` as a valid, escaped string literal in the snippet

    with st.form("my_form"):
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")
            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=[
                    "WarmMolGenOne",
                    "WarmMolGenTwo",
                ],
                index=0,
            )
            # min_value=1 (was 0). Zero would mean num_return_sequences=0 and,
            # without sampling, num_beams=0 -- not a meaningful generate() request.
            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                min_value=1,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )
            # min_value=1 (was 0): max_length=0 leaves no room for any output token.
            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=1,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )
        target = st.text_area(
            "Target Sequence",
            "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
        )
        generate_new_molecules = st.form_submit_button("Generate Molecules")

    # With beam decoding, use one beam per requested molecule (transformers
    # does not allow num_return_sequences to exceed num_beams).
    num_beams = None if do_sample is True else int(num_mols)
    pool = 1  # forwarded to moses.utils.mapper; presumably 1 = no multiprocessing

    if generate_new_molecules or 'df' not in st.session_state:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                 target, pool)

    df = st.session_state.df
    filtered_df = filter_dataframe(df)
    if filtered_df.empty:
        st.markdown(
            """
<span style='color: blue; font-size: 30px;'>No molecules were found with specified properties.</span>
""",
            unsafe_allow_html=True
        )
    else:
        raw_html = mols2grid.display(filtered_df, height=1000)._repr_html_()
        components.html(raw_html, width=900, height=450, scrolling=True)

    st.markdown("## How to Generate")
    # json.dumps yields a double-quoted literal with quotes, backslashes and
    # newlines escaped. Pasting the raw text produced invalid Python for
    # multi-line input; plain sequences render exactly as before.
    target_literal = json.dumps(target)
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
    inputs = protein_tokenizer({target_literal}, return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
                             max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    st.code(textwrap.dedent(generation_code))
# Introductory text shown under the page title.
_INTRO_MARKDOWN = """
This demo illustrates WarmMolGen models' generation capabilities.
Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
Please enter an input sequence below 👇 and configure parameters from the sidebar 👈 to generate molecules!
See below for saving the output molecules and the code snippet generating them!
"""

# Page setup: set_page_config stays the first Streamlit command of the run.
st.set_page_config(page_title="WarmMolGen Demo", page_icon="🔥", layout='wide')
st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")
st.markdown(_INTRO_MARKDOWN)

# Render the interactive part of the app.
warm_molgen_demo()