cankoban committed on
Commit
37d70c1
·
1 Parent(s): bfc3f40

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ import pandas as pd
4
+ import mols2grid
5
+ from ipywidgets import interact, widgets
6
+ import textwrap
7
+ # import numpy as np
8
+ from transformers import EncoderDecoderModel, RobertaTokenizer
9
+
10
+ from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings
11
+ from moses.utils import mapper, get_mol
12
+
13
+ # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
14
+ from typing import List
15
+
16
+ from util import filter_dataframe
17
+
18
+
19
@st.cache(suppress_st_warning=True)
def load_models():
    """Load the two WarmMolGen encoder-decoder checkpoints.

    The function is wrapped in ``st.cache`` so Streamlit can reuse the
    loaded models instead of re-loading them on every script rerun.

    Returns:
        A ``(model1, model2)`` tuple holding the ``WarmMolGenOne`` and
        ``WarmMolGenTwo`` models, in that order.
    """
    checkpoints = (
        "gokceuludogan/WarmMolGenOne",
        "gokceuludogan/WarmMolGenTwo",
    )
    # Loaded in order: WarmMolGenOne first, then WarmMolGenTwo.
    first_model, second_model = [
        EncoderDecoderModel.from_pretrained(checkpoint) for checkpoint in checkpoints
    ]
    return first_model, second_model
26
+
27
+
28
def count(smiles_list: List[str]) -> List[int]:
    """Return the character length of every SMILES string.

    Args:
        smiles_list: SMILES strings to measure.

    Returns:
        A new list with ``len(smiles)`` for each entry, in input order.
        An empty input yields an empty list.
    """
    return [len(smiles) for smiles in smiles_list]
34
+
35
+
36
def remove_none_elements(mol_list, smiles_list):
    """Drop molecules that are ``None`` together with their SMILES strings.

    The two lists are matched by position: whenever ``mol_list[i]`` is
    ``None``, both that entry and ``smiles_list[i]`` are left out of the
    result.

    Args:
        mol_list: Molecule objects, with ``None`` marking an invalid entry.
        smiles_list: SMILES strings indexed in parallel with ``mol_list``.

    Returns:
        A ``(filtered_mol_list, filtered_smiles_list, removed_len)`` tuple,
        where ``removed_len`` is the number of ``None`` entries found in
        ``mol_list``.
    """
    # A set keeps the membership test below O(1) instead of scanning a list
    # once per SMILES string.
    removed_indices = {i for i, mol in enumerate(mol_list) if mol is None}

    filtered_mol_list = [mol for mol in mol_list if mol is not None]
    filtered_smiles_list = [
        smiles for i, smiles in enumerate(smiles_list) if i not in removed_indices
    ]

    return filtered_mol_list, filtered_smiles_list, len(removed_indices)
52
+
53
+
54
def format_list_numbers(lst):
    """Round every number in ``lst`` to three decimals, in place.

    Each value is rendered with the ``.3f`` format and parsed back into a
    ``float``. The list passed in is modified; nothing is returned.

    Args:
        lst: Mutable sequence of numbers to round.
    """
    # Slice assignment rewrites the caller's list rather than rebinding a local.
    lst[:] = [float(f"{number:.3f}") for number in lst]
57
+
58
+
59
def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate molecules for a protein target and score them.

    Args:
        model_name: ``'WarmMolGenOne'`` selects the first loaded model; any
            other value selects the second one.
        num_mols: Number of sequences to return (cast to ``int``).
        max_new_tokens: Forwarded to ``generate`` as ``max_length`` (cast to
            ``int``).
        do_sample: Forwarded to ``generate`` as ``do_sample``.
        num_beams: Forwarded to ``generate`` as ``num_beams``.
        target: Protein sequence fed to the protein tokenizer.
        pool: Argument handed to ``moses.utils.mapper`` for every mapping step.

    Returns:
        A ``pandas.DataFrame`` with the columns ``SMILES``, ``QED``, ``SA``,
        ``logP``, ``NP`` and ``Weight``; the metric columns are rounded to
        three decimals. Rows whose SMILES could not be turned into a molecule
        are left out.

    Side effects:
        Writes a "Generated Molecules" heading to the Streamlit page, plus a
        note when invalid molecules were discarded.
    """
    # NOTE(review): the protein tokenizer always comes from the WarmMolGenTwo
    # checkpoint, whichever model is selected - confirm this is intended.
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    first_model, second_model = load_models()
    encoded_target = protein_tokenizer(target, return_tensors="pt")

    if model_name == 'WarmMolGenOne':
        generator = first_model
    else:
        generator = second_model

    generation_options = dict(
        decoder_start_token_id=mol_tokenizer.bos_token_id,
        eos_token_id=mol_tokenizer.eos_token_id,
        pad_token_id=mol_tokenizer.eos_token_id,
        max_length=int(max_new_tokens),
        num_return_sequences=int(num_mols),
        do_sample=do_sample,
        num_beams=num_beams,
    )
    generated_ids = generator.generate(**encoded_target, **generation_options)
    smiles = mol_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    st.write("### Generated Molecules")

    molecules = mapper(pool)(get_mol, smiles)
    molecules, smiles, discarded = remove_none_elements(molecules, smiles)
    if discarded != 0:
        st.write(f"#### Note that: {discarded} numbers of generated invalid molecules are discarded.")

    # Column label -> metric function; dict order fixes the column order.
    metric_functions = {
        "QED": QED,
        "SA": SA,
        "logP": logP,
        "NP": NP,
        "Weight": weight,
    }

    columns = {'SMILES': smiles}
    for label, metric in metric_functions.items():
        # A fresh mapper is requested for every metric.
        columns[label] = mapper(pool)(metric, molecules)

    for label in metric_functions:
        format_list_numbers(columns[label])

    return pd.DataFrame(columns)
100
+
101
+
102
def warm_molgen_demo():
    """Render the interactive part of the WarmMolGen demo page.

    The page consists of:

    * a form with the generation parameters (sidebar) and the target
      sequence (main area), submitted with "Generate Molecules";
    * a filterable grid of the generated molecules, or a notice when the
      filters leave nothing to show;
    * a code snippet reproducing the generation call with the current
      parameters.

    Generated molecules are kept in ``st.session_state.df``. They are
    regenerated when the form is submitted, and generated once up front when
    the session has no ``df`` yet.
    """
    with st.form("my_form"):
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")

            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=[
                    "WarmMolGenOne",
                    "WarmMolGenTwo",
                ],
                index=0,
            )

            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                # This value becomes num_return_sequences (and num_beams when
                # sampling is off), so zero must not be selectable.
                min_value=1,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )

            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=0,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )
        target = st.text_area(
            "Target Sequence",
            "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
        )
        generate_new_molecules = st.form_submit_button("Generate Molecules")

    # Without sampling, use one beam per requested molecule.
    num_beams = None if do_sample is True else int(num_mols)

    pool = 1

    if generate_new_molecules:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                 target, pool)
    # First run of a session: nothing stored yet, so generate with the defaults.
    if 'df' not in st.session_state:
        st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                 target, pool)
    df = st.session_state.df

    filtered_df = filter_dataframe(df)
    if filtered_df.empty:
        st.markdown(
            """
            <span style='color: blue; font-size: 30px;'>No molecules were found with specified properties.</span>
            """,
            unsafe_allow_html=True
        )
    else:
        raw_html = mols2grid.display(filtered_df, height=1000)._repr_html_()
        components.html(raw_html, width=900, height=450, scrolling=True)

    st.markdown("## How to Generate")
    # The target is embedded with !r so that quotes or line breaks typed into
    # the text area still yield a valid Python literal in the snippet and do
    # not disturb textwrap.dedent.
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
    inputs = protein_tokenizer({target!r}, return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
                             eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
                             max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    st.code(textwrap.dedent(generation_code))
179
+
180
+
181
# Page scaffolding: configure the page, print the title and introduction,
# then hand over to the interactive demo.
_PAGE_TITLE = "WarmMolGen Demo"
_INTRO_MARKDOWN = """
This demo illustrates WarmMolGen models' generation capabilities.
Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
Please enter an input sequence below 👇 and configure parameters from the sidebar 👈 to generate molecules!
See below for saving the output molecules and the code snippet generating them!
"""

# set_page_config stays the first Streamlit call issued by the script.
st.set_page_config(page_title=_PAGE_TITLE, page_icon="🔥", layout='wide')
st.markdown(f"# {_PAGE_TITLE}")
st.sidebar.header(_PAGE_TITLE)
st.markdown(_INTRO_MARKDOWN)

warm_molgen_demo()