Yin Fang commited on
Commit
227b864
Β·
1 Parent(s): 9b4a51c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -12
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- #from src.utils import plogp, sf_decode, sim
4
  import pandas as pd
5
  from rdkit import Chem
6
  from rdkit.Chem import AllChem
@@ -59,12 +58,57 @@ def sim(input_smile, output_smile):
59
  else: return None
60
 
61
 
62
- def greet(name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen-large-opt")
65
  model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen-large-opt")
66
 
67
- input = name
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  sf_input = tokenizer(input, return_tensors="pt")
70
  molecules = model.generate(
@@ -83,7 +127,6 @@ def greet(name):
83
  sm_output = [sf_decode(sf) for sf in sf_output]
84
 
85
 
86
-
87
  input_plogp = plogp(input_sm)
88
  plogp_improve = [plogp(i)-input_plogp for i in sm_output]
89
 
@@ -93,20 +136,86 @@ def greet(name):
93
  candidate_selfies = {"candidates": sf_output, "improvement": plogp_improve, "sim": simm}
94
  data = pd.DataFrame(candidate_selfies)
95
 
96
- return data[(data['improvement']> 0) & (data['sim']>0.4)]
 
 
 
97
 
98
-
 
 
 
 
99
 
 
 
 
 
 
 
100
 
 
 
 
 
101
 
 
102
 
103
- examples = [
104
-
105
- ['[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]'],['[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]']
106
 
107
- ]
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
 
 
110
 
111
- iface = gr.Interface(fn=greet, inputs="text", outputs="numpy", title="Molecular Language Model as Multi-task Generator",examples=examples)
112
- iface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
  import pandas as pd
4
  from rdkit import Chem
5
  from rdkit.Chem import AllChem
 
58
  else: return None
59
 
60
 
61
+ def gen_opt(gen_input):
62
+ tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen")
63
+ model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen")
64
+
65
+ sf_input = tokenizer(gen_input, return_tensors="pt")
66
+
67
+ # beam search
68
+ molecules = model.generate(input_ids=sf_input["input_ids"],
69
+ attention_mask=sf_input["attention_mask"],
70
+ max_length=15,
71
+ min_length=5,
72
+ num_return_sequences=4,
73
+ num_beams=5)
74
+
75
+ gen_output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ","") for g in molecules]
76
+ smis = [sf.decoder(i) for i in gen_output]
77
+ mols = []
78
+ for smi in smis:
79
+ mol = Chem.MolFromSmiles(smi)
80
+ mols.append(mol)
81
+
82
+ gen_output_image = Draw.MolsToGridImage(
83
+ mols,
84
+ molsPerRow=4,
85
+ subImgSize=(200,200),
86
+ legends=['' for x in mols]
87
+ )
88
+
89
+ return "\n".join(gen_output), gen_output_image
90
+
91
+
92
+
93
+ def opt_process(opt_input):
94
 
95
  tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen-large-opt")
96
  model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen-large-opt")
97
 
98
+ input = opt_input
99
+
100
+ smis_input = [sf.decoder(i) for i in input]
101
+ mols_input = []
102
+ for smi in smis_input:
103
+ mol = Chem.MolFromSmiles(smi)
104
+ mols_input.append(mol)
105
+
106
+ opt_input_img = Draw.MolsToGridImage(
107
+ mols_input,
108
+ molsPerRow=4,
109
+ subImgSize=(200,200),
110
+ legends=['' for x in mols]
111
+ )
112
 
113
  sf_input = tokenizer(input, return_tensors="pt")
114
  molecules = model.generate(
 
127
  sm_output = [sf_decode(sf) for sf in sf_output]
128
 
129
 
 
130
  input_plogp = plogp(input_sm)
131
  plogp_improve = [plogp(i)-input_plogp for i in sm_output]
132
 
 
136
  candidate_selfies = {"candidates": sf_output, "improvement": plogp_improve, "sim": simm}
137
  data = pd.DataFrame(candidate_selfies)
138
 
139
+ results = data[(data['improvement']> 0) & (data['sim']>0.4)]
140
+ opt_output = results["candidates"].tolist()
141
+ opt_output_imp = results["improvement"].tolist()
142
+ opt_output_sim = results["sim"].tolist()
143
 
144
+ smis = [sf.decoder(i) for i in opt_output]
145
+ mols = []
146
+ for smi in smis:
147
+ mol = Chem.MolFromSmiles(smi)
148
+ mols.append(mol)
149
 
150
+ opt_output_img = Draw.MolsToGridImage(
151
+ mols,
152
+ molsPerRow=4,
153
+ subImgSize=(200,200),
154
+ legends=['' for x in mols]
155
+ )
156
 
157
+ return opt_input_img, "\n".join(opt_output), "\n".join(opt_output_imp), "\n".join(opt_output_sim), opt_output_img
158
+ # examples = [
159
+
160
+ # ['[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]'],['[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]']
161
 
162
+ # ]
163
 
 
 
 
164
 
 
165
 
166
+ # iface = gr.Interface(fn=greet, inputs="text", outputs="numpy", title="Molecular Language Model as Multi-task Generator",examples=examples)
167
+ # iface.launch()
168
+
169
+ with gr.Blocks() as demo:
170
+ init_triple_input()
171
+ gr.Markdown("# MolGen: Molecular Language Model as Multi-task Generator")
172
+
173
+ with gr.Tabs():
174
+ with gr.TabItem("Molecular Generation"):
175
+ with gr.Row():
176
+ with gr.Column():
177
+ gen_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
178
+ gen_button = gr.Button("Generate")
179
+
180
+ with gr.Column():
181
+ gen_output = gr.Textbox(label="Generation Results", lines=5, placeholder="")
182
+ gen_output_image = gr.Textbox(label="Visualization", lines=3, placeholder="")
183
+
184
+ gr.Examples(
185
+ examples=[["[C][=C][C][=C][C][=C][Ring1][=Branch1]"],
186
+ ["[C]"]
187
+ ],
188
+ inputs=[gen_input],
189
+ outputs=[gen_output, gen_output_image],
190
+ fn=gen_process,
191
+ cache_examples=True,
192
+ )
193
+
194
+ with gr.TabItem("Constrained Molecular Property Optimization"):
195
+ with gr.Row():
196
+ with gr.Column():
197
+ opt_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
198
+ opt_button = gr.Button("Optimize")
199
+
200
+ with gr.Column():
201
+ opt_input_img = gr.Textbox(label="Input Visualization", lines=3, placeholder="")
202
+ opt_output = gr.Textbox(label="Optimization Results", lines=3, placeholder="")
203
+ opt_output_imp = gr.Textbox(label="Optimization Property Improvements", lines=3, placeholder="")
204
+ opt_output_sim = gr.Textbox(label="Similarity", lines=3, placeholder="")
205
+ opt_output_img = gr.Textbox(label="Output Visualization", lines=3, placeholder="")
206
+
207
+ gr.Examples(
208
+ examples=[["[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]"],
209
+ ["[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]"],
210
+ ["[N][#C][C][C][C@@H1][C][C][C][C][C][C][C][C][C][C][C][Ring1][N][=O]"]
211
+ ],
212
+ inputs=[opt_input],
213
+ outputs=[opt_input_img, opt_output, opt_output_imp, opt_output_sim, opt_output_img],
214
+ fn=opt_process,
215
+ cache_examples=True,
216
+ )
217
 
218
+ gen_button.click(fn=gen_process, inputs=[gen_input], outputs=[gen_output, gen_output_image])
219
+ opt_button.click(fn=opt_process, inputs=[opt_input], outputs=[opt_input_img, opt_output, opt_output_imp, opt_output_sim, opt_output_img])
220
 
221
+ demo.launch()