Spaces: Running on Zero

Commit 1d1d4f3 · Parent(s): 1e001e8 · update

Files changed:
- .gitignore +1 -0
- app.py +62 -83
- llama_customized_models.py +154 -0
- metric_calculator.py +213 -0
- utils.py +269 -190
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/*
app.py CHANGED
@@ -1,31 +1,11 @@
 import gradio as gr
 from huggingface_hub import HfApi, get_collection, list_collections, list_models
 #from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
-from utils import
+from utils import MolecularGenerationModel
 import pandas as pd
 import os
 import spaces
 
-def get_models():
-    # we only support two models
-    # 1. ChemFM/uspto_mit_synthesis
-    # 2. ChemFM/uspto_full_retro
-
-    models = dict()
-    models['mit_synthesis'] = 'ChemFM/uspto_mit_synthesis'
-    models['full_retro'] = 'ChemFM/uspto_full_retro'
-
-    #for item in collection.items:
-    #    if item.item_type == "model":
-    #        item_name = item.item_id.split("/")[-1]
-    #        models[item_name] = item.item_id
-    #        assert item_name in dataset_task_types, f"{item_name} is not in the task_types"
-    #        assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
-
-    return models
-
 #candidate_models = get_models()
 #task_names = {
 #    'mit_synthesis': 'Reaction Synthesis',
@@ -46,16 +26,30 @@ def get_models():
 #}
 
 #property_names = list(candidate_models.keys())
-
-def predict_single_label(value_1, value_2, value_3, value_4):
+model = MolecularGenerationModel()
+
+def predict_single_label(logp, tpsa, sas, qed, logp_choose, tpsa_choose, sas_choose, qed_choose):
+    input_dict = dict()
+    if logp_choose:
+        input_dict['logP'] = logp
+    if tpsa_choose:
+        input_dict['TPSA'] = tpsa
+    if sas_choose:
+        input_dict['SAS'] = sas
+    if qed_choose:
+        input_dict['QED'] = qed
+
+    if len(input_dict) == 0:
+        return "NA", "No input is selected"
 
-    print(value_1, value_2, value_3, value_4)
+    print(input_dict)
 
     try:
 
        running_status = None
        prediction = None
+
+        prediction = model.predict_single_smiles(input_dict)
 
        #prediction = model.predict(smiles, property_name, adapter_id)
        #prediction = model.predict_single_smiles(smiles, task)
@@ -65,10 +59,10 @@ def predict_single_label(value_1, value_2, value_3, value_4):
     except Exception as e:
         # no matter what the error is, we should return
         print(e)
-        return "NA", "
+        return "NA", "Generation failed"
 
-    prediction = "\n".join([f"{idx+1}. {item}" for idx, item in enumerate(prediction)])
-    return prediction, "
+    #prediction = "\n".join([f"{idx+1}. {item}" for idx, item in enumerate(prediction)])
+    return prediction, "Generation is done"
 
 """
 def get_description(task_name):
@@ -177,6 +171,13 @@ def clear_file(download_button):
     return gr.update(visible=False), gr.update(visible=False), None
 """
 
+def toggle_slider(checked):
+    return gr.update(interactive=checked)
+
+def toggle_sliders_based_on_checkboxes(checked_values):
+    """Enable or disable sliders based on the corresponding checkbox values."""
+    return [gr.update(interactive=checked_values[i]) for i in range(4)]
+
 def build_inference():
 
     with gr.Blocks() as demo:
@@ -184,7 +185,11 @@ def build_inference():
         #with gr.Row():
         #    gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
         #dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
-        description = f"
+        description = f"This space allows you to generate ten possible molecules based on given conditions. \n" \
+                      f"1. You can enable or disable specific properties using checkboxes and adjust their values with sliders. \n" \
+                      f"2. The generated SMILES strings and their corresponding predicted properties will be displayed in the generations section. \n" \
+                      f"3. The properties include logP, TPSA, SAS, and QED. \n" \
+                      f"4. Model trained on the GuacaMol dataset for molecular design. "
 
         description_box = gr.Textbox(label="Task description", lines=5,
                                      interactive=False,
@@ -192,80 +197,54 @@
         # third row - Textbox input and prediction label
         with gr.Row(equal_height=True):
             with gr.Column():
-                checkbox_1 = gr.Checkbox(label="
-                slider_1 = gr.Slider(
+                checkbox_1 = gr.Checkbox(label="logP", value=True)
+                slider_1 = gr.Slider(1, 7, value=4, label="logP", info="Choose between 1 and 7")
+                checkbox_1.change(toggle_slider, checkbox_1, slider_1)
             with gr.Column():
-                checkbox_2 = gr.Checkbox(label="
-                slider_2 = gr.Slider(
+                checkbox_2 = gr.Checkbox(label="TPSA", value=True)
+                slider_2 = gr.Slider(20, 140, value=80, label="TPSA", info="Choose between 20 and 140")
+                checkbox_2.change(toggle_slider, checkbox_2, slider_2)
             with gr.Column():
-                checkbox_3 = gr.Checkbox(label="
-                slider_3 = gr.Slider(
+                checkbox_3 = gr.Checkbox(label="SAS", value=True)
+                slider_3 = gr.Slider(1, 5, value=3, label="SAS", info="Choose between 1 and 5")
+                checkbox_3.change(toggle_slider, checkbox_3, slider_3)
            with gr.Column():
-                checkbox_4 = gr.Checkbox(label="
-                slider_4 = gr.Slider(
+                checkbox_4 = gr.Checkbox(label="QED", value=True)
+                slider_4 = gr.Slider(0.1, 0.9, value=0.5, label="QED", info="Choose between 0.1 and 0.9")
+                checkbox_4.change(toggle_slider, checkbox_4, slider_4)
 
            predict_single_smiles_button = gr.Button("Generate", size='sm')
            #prediction = gr.Label("Prediction will appear here")
-            prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)
+            #prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)
+            prediction = gr.Dataframe(label="Generations", type="pandas", interactive=False)
 
            running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
 
-            #input_file = gr.File(label="Molecule file",
-            #                     file_count='single',
-            #                     file_types=[".smi", ".csv"], height=300)
-            #predict_file_button = gr.Button("Predict", size='sm', visible=False)
-            #download_button = gr.DownloadButton("Download", size='sm', visible=False)
-            #stop_button = gr.Button("Stop", size='sm', visible=False)
 
        # dropdown change event
        # predict single button click event
        predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
+                                                   gr.update(interactive=False),
+                                                   gr.update(interactive=False),
+                                                   gr.update(interactive=False),
+                                                   gr.update(interactive=False),
                                                    gr.update(interactive=False),
                                                    gr.update(interactive=False),
                                                    gr.update(interactive=False),
                                                    gr.update(interactive=False),
                                                    gr.update(interactive=False),
                                            ) , outputs=[slider_1, slider_2, slider_3, slider_4,
+                                                        checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                                                         predict_single_smiles_button, running_terminal_label])\
-            .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4
-        """
-        # input file upload event
-        file_status = gr.State()
-        input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
-        # input file clear event
-        input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
-        # predict file button click event
-        predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False),
-                                                               gr.update(interactive=False),
-                                                               gr.update(interactive=False),
-                                                               gr.update(interactive=False, visible=True),
-                                                               gr.update(interactive=False),
-                                                               gr.update(interactive=True, visible=False),
-                                                               gr.update(interactive=False),
-                                                               gr.update(interactive=False),
-                                                       ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
-            .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
-            .then(lambda:(gr.update(interactive=True),
-                          gr.update(interactive=True),
-                          gr.update(interactive=True),
-                          gr.update(interactive=True),
-                          gr.update(interactive=True),
-                          gr.update(interactive=True),
-                          gr.update(interactive=True),
-                          gr.update(interactive=True),
-                  ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
-
-        # stop button click event
-        #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
-        """
+            .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4,
+                                                checkbox_1, checkbox_2, checkbox_3, checkbox_4
+                                                ], outputs=[prediction, running_terminal_label])\
+            .then(lambda a, b, c, d: toggle_sliders_based_on_checkboxes([a, b, c, d]) +
+                                     [gr.update(interactive=True)] * 6,
+                  inputs=[checkbox_1, checkbox_2, checkbox_3, checkbox_4],
+                  outputs=[slider_1, slider_2, slider_3, slider_4,
+                           checkbox_1, checkbox_2, checkbox_3, checkbox_4,
+                           predict_single_smiles_button, running_terminal_label])
 
     return demo
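For readers unfamiliar with the Gradio pattern used above, here is a minimal, self-contained sketch (illustrative only, not part of the commit; component names are hypothetical) of the checkbox-gates-slider wiring that `toggle_slider` implements:

    import gradio as gr

    def toggle_slider(checked):
        # gr.update(...) patches component properties; here it greys out
        # the slider whenever its checkbox is unticked.
        return gr.update(interactive=checked)

    with gr.Blocks() as demo:
        use_logp = gr.Checkbox(label="logP", value=True)  # hypothetical single-property demo
        logp = gr.Slider(1, 7, value=4, label="logP")
        use_logp.change(toggle_slider, use_logp, logp)

    if __name__ == "__main__":
        demo.launch()

The same idea scales to the four property columns in `build_inference()`, with the button's `.click(...).then(...)` chain disabling every control during generation and re-enabling it afterwards.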
llama_customized_models.py ADDED
@@ -0,0 +1,154 @@
+from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.cache_utils import Cache
+
+from transformers.modeling_outputs import (
+    CausalLMOutputWithPast,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from dataclasses import dataclass
+
+from transformers.utils import ModelOutput
+
+import torch
+from typing import List, Optional, Tuple, Union
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+LLAMA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance;
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+              cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+class LlamaForCausalLMWithNumericalEmbedding(LlamaForCausalLM):
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.numerical_embedding = torch.nn.Linear(1, config.hidden_size, bias=True)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        properties: List = None,
+        properties_index: List = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position=None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        b, l = input_ids.size()
+        assert len(properties) == b, "The number of properties should be equal to the batch size."
+        assert len(properties_index) == b, "The number of properties_index should be equal to the batch size."
+
+        embeddings = self.model.embed_tokens(input_ids)
+
+        for i, (props, props_index, embeds) in enumerate(zip(properties, properties_index, embeddings)):
+            assert len(props) == len(props_index), "The number of properties should be equal to the number of properties_index."
+            props = torch.tensor(props, device=embeds.device, dtype=torch.float32).unsqueeze(1)
+            num_embeds = self.numerical_embedding(props)
+            if len(props_index) > 0:
+                assert embeddings[i, props_index, :].shape == num_embeds.shape, "The shape of the embeddings and the numerical embeddings should be the same."
+                embeddings[i, props_index, :] = num_embeds
+
+        return super().forward(
+            input_ids=None,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=embeddings,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
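To make the substitution in `forward` concrete: each scalar property value is projected with a `Linear(1, hidden_size)` layer and written over the token embedding at its placeholder position, after which generation proceeds through `inputs_embeds` instead of `input_ids`. A toy sketch of just that step (illustrative shapes only, not from the repo):

    import torch

    hidden_size = 8
    # one sequence of 6 token embeddings, all zeros for illustration
    embeddings = torch.zeros(1, 6, hidden_size)
    numerical_embedding = torch.nn.Linear(1, hidden_size, bias=True)

    props = torch.tensor([[0.3], [1.2]])  # two normalized property values, shape (2, 1)
    props_index = [3, 5]                  # placeholder positions inside the prompt

    with torch.no_grad():
        # (2, 1) -> (2, hidden_size), then overwrite positions 3 and 5
        embeddings[0, props_index, :] = numerical_embedding(props)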
metric_calculator.py ADDED
@@ -0,0 +1,213 @@
+from sklearn.metrics import mean_squared_error, roc_auc_score, r2_score
+from rdkit.Chem import QED, Crippen, MolFromSmiles, rdmolops, rdMolDescriptors, AllChem
+from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
+import networkx as nx
+import os.path as op
+import math
+#from rdkit.six.moves import cPickle
+import _pickle as cPickle
+#from rdkit.six import iteritems
+from rdkit import Chem
+import pickle
+import numpy as np
+
+import sys
+import os
+from rdkit.Chem import RDConfig
+sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
+import sascorer
+
+from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
+from rdkit.Chem.Fingerprints import FingerprintMols
+
+def compute_rmse(gt, pred):
+    return mean_squared_error(gt, pred, squared=False)
+
+def compute_r2score(gt, pred):
+    return r2_score(gt, pred)
+
+def compute_roc_auc(gt, pred):
+    return roc_auc_score(gt, pred)
+
+def check_valid(smiles_list):
+    total_num = len(smiles_list)
+    empty_num = smiles_list.count("")
+    return 1 - empty_num / float(total_num)
+
+def check_unique(smiles_list):
+    total_num = len(smiles_list)
+    smiles_set = set(smiles_list)
+    if "" in smiles_set:
+        smiles_set.remove("")
+    return len(smiles_set) / float(total_num)
+
+def check_novelty(gen_smiles, train_smiles):
+    if len(gen_smiles) == 0:
+        novel_ratio = 0.
+    else:
+        duplicates = [1 for mol in gen_smiles if mol in train_smiles]
+        novel = len(gen_smiles) - sum(duplicates)
+        novel_ratio = novel * 100. / len(gen_smiles)
+    return novel_ratio
+
+_fscores = None
+def readFragmentScores(name='fpscores'):
+    import gzip
+    global _fscores
+    # generate the full path filename:
+    if name == "fpscores":
+        name = op.join(op.dirname(__file__), name)
+    _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name))
+    outDict = {}
+    for i in _fscores:
+        for j in range(1, len(i)):
+            outDict[i[j]] = float(i[0])
+    _fscores = outDict
+
+def numBridgeheadsAndSpiro(mol, ri=None):
+    nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
+    nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
+    return nBridgehead, nSpiro
+
+def calculateScore(m):
+    if _fscores is None:
+        readFragmentScores()
+
+    # fragment score
+    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # <- 2 is the *radius* of the circular fingerprint
+    fps = fp.GetNonzeroElements()
+    score1 = 0.
+    nf = 0
+    for bitId, v in fps.items():  # iteritems is Python 2 only; the rdkit.six import above is commented out
+        nf += v
+        sfp = bitId
+        score1 += _fscores.get(sfp, -4) * v
+    score1 /= nf
+
+    # features score
+    nAtoms = m.GetNumAtoms()
+    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
+    ri = m.GetRingInfo()
+    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
+    nMacrocycles = 0
+    for x in ri.AtomRings():
+        if len(x) > 8:
+            nMacrocycles += 1
+
+    sizePenalty = nAtoms**1.005 - nAtoms
+    stereoPenalty = math.log10(nChiralCenters + 1)
+    spiroPenalty = math.log10(nSpiro + 1)
+    bridgePenalty = math.log10(nBridgeheads + 1)
+    macrocyclePenalty = 0.
+    # ---------------------------------------
+    # This differs from the paper, which defines:
+    # macrocyclePenalty = math.log10(nMacrocycles+1)
+    # This form generates better results when 2 or more macrocycles are present
+    if nMacrocycles > 0:
+        macrocyclePenalty = math.log10(2)
+
+    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
+
+    # correction for the fingerprint density
+    # not in the original publication, added in version 1.1
+    # to make highly symmetrical molecules easier to synthetise
+    score3 = 0.
+    if nAtoms > len(fps):
+        score3 = math.log(float(nAtoms) / len(fps)) * .5
+
+    sascore = score1 + score2 + score3
+
+    # need to transform "raw" value into scale between 1 and 10
+    min = -4.0
+    max = 2.5
+    sascore = 11. - (sascore - min + 1) / (max - min) * 9.
+    # smooth the 10-end
+    if sascore > 8.:
+        sascore = 8. + math.log(sascore + 1. - 9.)
+    if sascore > 10.:
+        sascore = 10.0
+    elif sascore < 1.:
+        sascore = 1.0
+
+    return sascore
+
+def compute_plogp(mol):
+    #mol = MolFromSmiles(smiles_string)
+    #logp = (Crippen.MolLogP(mol) - np.mean(logP_values)) / np.std(logP_values)
+    logp = Crippen.MolLogP(mol)
+    #SA_score = (-sascorer.calculateScore(mol) - np.mean(SA_scores)) / np.std(SA_scores)
+    SA_score = -calculateScore(mol)
+    cycle_list = nx.cycle_basis(nx.Graph(rdmolops.GetAdjacencyMatrix(mol)))
+    if len(cycle_list) == 0:
+        cycle_length = 0
+    else:
+        cycle_length = max([len(j) for j in cycle_list])
+    if cycle_length <= 6:
+        cycle_length = 0
+    else:
+        cycle_length = cycle_length - 6
+
+    #cycle_score = (-cycle_length - np.mean(cycle_scores)) / np.std(cycle_scores)
+    cycle_score = -cycle_length
+    #plogp = -(logp + SA_score + cycle_score)
+    plogp = (logp + SA_score + cycle_score)
+    return plogp
+
+clf_model = None
+def load_model():
+    global clf_model
+    #name = op.join(op.dirname(__file__), 'clf_py36.pkl')
+    name = op.join(op.dirname(__file__), 'drd2_current.pkl')
+    with open(name, "rb") as f:
+        clf_model = pickle.load(f)
+
+def fingerprints_from_mol(mol):
+    fp = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)
+    size = 2048
+    nfp = np.zeros((1, size), np.int32)
+    for idx, v in fp.GetNonzeroElements().items():
+        nidx = idx % size
+        nfp[0, nidx] += int(v)
+    return nfp
+
+def compute_drd2(mol):
+    if clf_model is None:
+        load_model()
+
+    #print(smile)
+    #mol = Chem.MolFromSmiles(smile)
+    if mol:
+        fp = fingerprints_from_mol(mol)
+        score = clf_model.predict_proba(fp)[:, 1]
+        return float(score)
+    return 0.0
+
+def compute_qed(mol):
+    return QED.qed(mol)
+
+def compute_logp(mol):
+    return Crippen.MolLogP(mol)
+
+def compute_tpsa(mol):
+    return rdMolDescriptors.CalcTPSA(mol)
+
+def compute_sas(mol):
+    return sascorer.calculateScore(mol)
+
+
+def check_valid_unique(smiles_list):
+    total_num = len(smiles_list)
+    empty_num = smiles_list.count("")
+
+    smiles_set = set(smiles_list)
+    if "" in smiles_set:
+        smiles_set.remove("")
+    return 1 - empty_num / float(total_num), \
+           len(smiles_set) / float(total_num - empty_num)
+
+def get_similarity(smiles1, smiles2):
+    if smiles1 == "" or smiles2 == "":
+        return np.nan
+    sim = TanimotoSimilarity(FingerprintMols.FingerprintMol(Chem.MolFromSmiles(smiles1)),
+                             FingerprintMols.FingerprintMol(Chem.MolFromSmiles(smiles2)))
+
+    return sim
+
+def get_scaffold(smiles):
+    scaffold = MurckoScaffoldSmiles(smiles)
+    return scaffold
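The final block of `calculateScore` rescales the raw fragment-plus-penalty score (empirically in about [-4.0, 2.5]) onto the usual 1-10 synthetic-accessibility scale, with a logarithmic smoothing near the top. A quick standalone check of that mapping (illustrative only):

    import math

    lo, hi = -4.0, 2.5  # the `min`/`max` constants used in calculateScore

    def rescale(raw):
        s = 11. - (raw - lo + 1) / (hi - lo) * 9.
        if s > 8.:  # smooth the 10-end
            s = 8. + math.log(s + 1. - 9.)
        return min(max(s, 1.0), 10.0)

    print(rescale(-4.0))  # ~8.48 after log smoothing: hard-to-synthesize end
    print(rescale(2.5))   # 1.0: easy end, clamped at the lower bound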
utils.py CHANGED
@@ -12,37 +12,62 @@ from datasets import Dataset
 from tqdm import tqdm
 import spaces
 
+from llama_customized_models import LlamaForCausalLMWithNumericalEmbedding
+from torch.nn.utils.rnn import pad_sequence
+import numpy as np
+from torch.utils.data.dataloader import DataLoader
+from torch.nn import functional as F
+import importlib
+
 from rdkit import RDLogger, Chem
 # Suppress RDKit INFO messages
 RDLogger.DisableLog('rdApp.*')
 
 DEFAULT_PAD_TOKEN = "[PAD]"
-device_map = "
+device_map = "cuda"
+
+means = {"qed": 0.5559003125710424, "logp": 3.497542110420217, "sas": 2.889429694406497, "tpsa": 80.19717097706841}
+stds = {"qed": 0.21339854620824716, "logp": 1.7923582437824368, "sas": 0.8081188219568571, "tpsa": 38.212259443049554}
+
+def parse_df(df):
+    metric_calculator = importlib.import_module("metric_calculator")
+
+    new_df = []
+    # iterate over the dataframe
+    for i in range(len(df)):
+        sub_df = dict()
+
+        # get the SMILES
+        smiles = df.iloc[i]['SMILES']
+        # get the property names
+        property_names = df.iloc[i]['property_names']
+        # get the non normalized properties
+        non_normalized_properties = df.iloc[i]['non_normalized_properties']
+
+        sub_df['SMILES'] = smiles
+
+        # compute the similarity between the scaffold and the SMILES
+
+        for j in range(len(property_names)):
+            # get the property name
+            property_name = property_names[j]
+            # get the non normalized property
+            non_normalized_property = non_normalized_properties[j]
+
+            sub_df[f'{property_name}_condition'] = non_normalized_property
+
+            if smiles == "":
+                sub_df[f'{property_name}_measured'] = np.nan
             else:
+                property_eval_func_name = f"compute_{property_name}"
+                property_eval_func = getattr(metric_calculator, property_eval_func_name)
+                sub_df[f'{property_name}_measured'] = property_eval_func(Chem.MolFromSmiles(smiles))
+
+        new_df.append(sub_df)
+
+    new_df = pd.DataFrame(new_df)
+    return new_df
 
 
 @dataclass
@@ -50,36 +75,98 @@ class DataCollatorForCausalLMEval(object):
     tokenizer: transformers.PreTrainedTokenizer
     source_max_len: int
     target_max_len: int
+    molecule_target_aug_prob: float
+    molecule_start_str: str
+    scaffold_aug_prob: float
+    scaffold_start_str: str
+    property_start_str: str
+    property_inner_sep: str
+    property_inter_sep: str
     end_str: str
+    ignore_index: int
+    has_scaffold: bool
-
-        return self.sme.augment([molecule])[0]
 
     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        # Extract elements
+        prop_token_map = {
+            'qed': '<qed>',
+            'logp': '<logp>',
+            'sas': '<SAS>',
+            'tpsa': '<TPSA>'
+        }
 
+        sources = []
+        props_list = []
+        non_normalized_props_list = []
+        prop_names_list = []
+        props_index_list = []
+        temperature_list = []
+        scaffold_list = []
+        for example in instances:
+            prop_names = example['property_name']
+            prop_values = example['property_value']
+            non_normalized_prop_values = example['non_normalized_property_value']
+            temperature = example['temperature']
+            # we need to convert the string to a list
+
+            # randomly choose the property and the scaffold combinations:
+            props_str = ""
+            scaffold_str = ""
+            props = []
+            non_normalized_props = []
+            props_index = []
+
+            if self.has_scaffold:
+                scaffold = example['scaffold_smiles'].strip()
+                scaffold_str = f"{self.scaffold_start_str}{scaffold}{self.end_str}"
+
+            props_str = f"{self.property_start_str}"
+            for i, prop in enumerate(prop_names):
+                prop = prop.lower()
+                props_str += f"{prop_token_map[prop]}{self.property_inner_sep}{self.molecule_start_str}{self.property_inter_sep}"
+                props.append(prop_values[i])
+                non_normalized_props.append(non_normalized_prop_values[i])
+                props_index.append(3 + 4 * i)  # this is hard coded for the current template
+            props_str += f"{self.end_str}"
+
+            source = props_str + scaffold_str + "<->>" + self.molecule_start_str
+
+            sources.append(source)
+            props_list.append(props)
+            non_normalized_props_list.append(non_normalized_props)
+            props_index_list.append(props_index)
+            prop_names_list.append(prop_names)
+            temperature_list.append(temperature)
+
+        # Tokenize
+        tokenized_sources_with_prompt = self.tokenizer(
+            sources,
+            max_length=self.source_max_len,
+            truncation=True,
+            add_special_tokens=False,
+        )
 
+        # Build the input and labels for causal LM
+        input_ids = []
+        for tokenized_source in tokenized_sources_with_prompt['input_ids']:
+            input_ids.append(torch.tensor(tokenized_source))
+        # Apply padding
+        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
 
         data_dict = {
-            '
+            'input_ids': input_ids,
+            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
+            'properties': props_list,
+            'non_normalized_properties': non_normalized_props_list,
+            'property_names': prop_names_list,
+            'properties_index': props_index_list,
+            'temperature': temperature_list,
         }
 
         return data_dict
 
+
 def smart_tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -106,176 +193,168 @@ def smart_tokenizer_and_embedding_resize(
     input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
     print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
 
-class
-    def __init__(self
-                use_fast=True,
-                trust_remote_code=True,
-                token = os.environ.get("TOKEN")
-            )
-            self.load_retro_model(candidate_models[model])
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                candidate_models[list(candidate_models.keys())[0]],
-                padding_side="right",
-                use_fast=True,
-                trust_remote_code=True,
-                token = os.environ.get("TOKEN")
-            )
-            self.load_forward_model(candidate_models[model])
-
-        string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
-        string_template = json.load(open(string_template_path, 'r'))
-        reactant_start_str = string_template['REACTANTS_START_STRING']
-        product_start_str = string_template['PRODUCTS_START_STRING']
-        end_str = string_template['END_STRING']
-        self.data_collator = DataCollatorForCausalLMEval(
-            tokenizer=self.tokenizer,
-            source_max_len=512,
-            target_max_len=512,
-            reactant_start_str=reactant_start_str,
-            product_start_str=product_start_str,
-            end_str=end_str,
-        )
-
-    def load_retro_model(self, model_path):
-        # our retro model is lora model
-        config = AutoConfig.from_pretrained(
-            trust_remote_code=True,
-            token=os.environ.get("TOKEN")
-        )
-
-            config=config,
-            trust_remote_code=True,
-            device_map=device_map,
-            token = os.environ.get("TOKEN")
-        )
-
-        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
-        smart_tokenizer_and_embedding_resize(
-            special_tokens_dict=special_tokens_dict,
-            tokenizer=self.tokenizer,
-            model=
-        )
-        base_model.config.pad_token_id = self.tokenizer.pad_token_id
-
-        # load the adapter model
-        self.retro_model = PeftModel.from_pretrained(
-            base_model,
-            model_path,
-            token = os.environ.get("TOKEN")
-        )
-
-        # the finetune tokenizer could be in different size with pretrain tokenizer, and also, we need to add PAD_TOKEN
-        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
-        smart_tokenizer_and_embedding_resize(
-            special_tokens_dict=special_tokens_dict,
-            tokenizer=self.tokenizer,
-            model=self.forward_model
-        )
-        self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
-        #self.forward_model.to("cuda")
-
-        if "." in smiles:
-            return None
-
-        predictions = []
-        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
-            with torch.no_grad():
-                generation_prompts = batch['generation_prompts'][0]
-                inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True).to(self.retro_model.device)
-                print(inputs)
-                del inputs['token_type_ids']
-                """
-                if task_type == "retrosynthesis":
-                    outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
-                                                        do_sample=False, num_beams=10,
-                                                        eos_token_id=self.tokenizer.eos_token_id,
-                                                        early_stopping='never',
-                                                        pad_token_id=self.tokenizer.pad_token_id,
-                                                        length_penalty=0.0,
-                                                        )
-                else:
-                    outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
-                                                          do_sample=False, num_beams=10,
-                                                          eos_token_id=self.tokenizer.eos_token_id,
-                                                          early_stopping='never',
-                                                          pad_token_id=self.tokenizer.pad_token_id,
-                                                          length_penalty=0.0,
-                                                          )
-
-                original_smiles_list = self.tokenizer.batch_decode(outputs[:, len(inputs['input_ids'][0]):],
-                                                                   skip_special_tokens=True)
-                original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
-                # canonize the SMILES
-                canonized_smiles_list = []
-                temp = []
-                for original_smiles in original_smiles_list:
-                    temp.append(original_smiles)
-                    try:
-                        canonized_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(original_smiles)))
-                    except:
-                        canonized_smiles_list.append("")
-                """
-                canonized_smiles_list = \
-                ['N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1cc(F)c([N+](=O)[O-])cc1F', 'N#Cc1ccsc1Nc1cc(Cl)c(F)cc1[N+](=O)[O-]', 'N#Cc1cnsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1cc(F)c(F)cc1Nc1sccc1C#N', 'N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=N)[O-]', 'N#Cc1cc(C#N)c(Nc2cc(F)c(F)cc2[N+](=O)[O-])s1', 'N#Cc1ccsc1Nc1c(F)c(F)cc(F)c1[N+](=O)[O-]', 'Nc1sccc1CNc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1ccc(F)cc1[N+](=O)[O-]']
-                predictions.append(canonized_smiles_list)
-
-        return rank
+class MolecularGenerationModel():
+    def __init__(self):
+        model_id = "ChemFM/molecular_cond_generation_guacamol"
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            padding_side="right",
+            use_fast=True,
+            trust_remote_code=True,
+            token = os.environ.get("TOKEN")
+        )
+
+        # load model
+        config = AutoConfig.from_pretrained(
+            model_id,
+            device_map=device_map,
+            trust_remote_code=True,
+            token = os.environ.get("TOKEN")
+        )
+
+        self.model = LlamaForCausalLMWithNumericalEmbedding.from_pretrained(
+            model_id,
+            config=config,
+            device_map=device_map,
+            trust_remote_code=True,
+            token = os.environ.get("TOKEN")
+        )
+
+        # the finetuned tokenizer can differ in size from the pretrained tokenizer, and we also need to add PAD_TOKEN
+        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
+        smart_tokenizer_and_embedding_resize(
+            special_tokens_dict=special_tokens_dict,
+            tokenizer=self.tokenizer,
+            model=self.model
+        )
+        self.model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        self.model.eval()
+
+        string_template_path = hf_hub_download(model_id, filename="string_template.json", token = os.environ.get("TOKEN"))
+        string_template = json.load(open(string_template_path, 'r'))
+        molecule_start_str = string_template['MOLECULE_START_STRING']
+        scaffold_start_str = string_template['SCAFFOLD_MOLECULE_START_STRING']
+        property_start_str = string_template['PROPERTY_START_STRING']
+        property_inner_sep = string_template['PROPERTY_INNER_SEP']
+        property_inter_sep = string_template['PROPERTY_INTER_SEP']
+        end_str = string_template['END_STRING']
+
+        self.data_collator = DataCollatorForCausalLMEval(
+            tokenizer=self.tokenizer,
+            source_max_len=512,
+            target_max_len=512,
+            molecule_target_aug_prob=1.0,
+            scaffold_aug_prob=0.0,
+            molecule_start_str=molecule_start_str,
+            scaffold_start_str=scaffold_start_str,
+            property_start_str=property_start_str,
+            property_inner_sep=property_inner_sep,
+            property_inter_sep=property_inter_sep,
+            end_str=end_str,
+            ignore_index=-100,
+            has_scaffold=False
+        )
+
+    @spaces.GPU(duration=60)
+    def generate(self, loader):
+
+        df = []
+        pbar = tqdm(loader, desc="Evaluating...", leave=False)
+        for it, batch in enumerate(pbar):
+            sub_df = dict()
+
+            batch_size = batch['input_ids'].shape[0]
+            assert batch_size == 1, "The batch size should be 1"
+
+            temperature = batch['temperature'][0]
+            property_names = batch['property_names'][0]
+            non_normalized_properties = batch['non_normalized_properties'][0]
+
+            num_generations = 1
+            del batch['temperature']
+            del batch['property_names']
+            del batch['non_normalized_properties']
+
+            input_length = batch['input_ids'].shape[1]
+            steps = 1024 - input_length
+
+            with torch.set_grad_enabled(False):
+                early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
+                for k in range(steps):
+                    logits = self.model(**batch)['logits']
+                    logits = logits[:, -1, :] / temperature
+                    probs = F.softmax(logits, dim=-1)
+                    ix = torch.multinomial(probs, num_samples=num_generations)
+
+                    ix[early_stop_flags] = self.tokenizer.eos_token_id
+
+                    batch['input_ids'] = torch.cat([batch['input_ids'], ix], dim=-1)
+                    early_stop_flags |= (ix.squeeze() == self.tokenizer.eos_token_id)
+
+                    if torch.all(early_stop_flags):
+                        break
+
+            generations = self.tokenizer.batch_decode(batch['input_ids'][:, input_length:], skip_special_tokens=True)
+            generations = map(lambda x: x.replace(" ", ""), generations)
+
+            predictions = []
+            for generation in generations:
+                try:
+                    predictions.append(Chem.MolToSmiles(Chem.MolFromSmiles(generation)))
+                except:
+                    predictions.append("")
+
+            sub_df['SMILES'] = predictions[0]
+            sub_df['property_names'] = property_names
+            sub_df['property'] = batch['properties'][0]
+            sub_df['non_normalized_properties'] = non_normalized_properties
+
+            df.append(sub_df)
+
+        df = pd.DataFrame(df)
+        return df
+
+    def predict_single_smiles(self, input_dict: Dict):
+        # convert the keys to lower case
+        input_dict = {key.lower(): value for key, value in input_dict.items()}
+
+        properties = [key.lower() for key in input_dict.keys()]
+        property_means = [means[prop] for prop in properties]
+        property_stds = [stds[prop] for prop in properties]
+
+        sample_point = [input_dict[prop] for prop in properties]
+        non_normalized_sample_point = np.array(sample_point).reshape(-1)
+        sample_point = (np.array(sample_point) - np.array(property_means)) / np.array(property_stds)
+        sub_df = {
+            "property_name": properties,
+            "property_value": sample_point.tolist(),
+            "temperature": 1.0,
+            "non_normalized_property_value": non_normalized_sample_point.tolist()
+        }
+
+        test_dataset = [sub_df] * 10
+        test_dataset = pd.DataFrame(test_dataset)
+        test_dataset = Dataset.from_pandas(test_dataset)
+
+        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=self.data_collator)
+        df = self.generate(test_loader)
+        new_df = parse_df(df)
+        # delete the condition columns
+        new_df = new_df.drop(columns=[col for col in new_df.columns if "condition" in col])
+
+        # drop the empty smiles rows
+        new_df = new_df.dropna(subset=['SMILES'])
+
+        # round the measured values to 2 decimal places
+        new_df = new_df.round(2)
+
+        return new_df
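End-to-end, the pieces above compose as follows. A hedged usage sketch (it needs the gated ChemFM/molecular_cond_generation_guacamol checkpoint, a TOKEN environment variable, and a GPU, so it will not run as-is):

    from utils import MolecularGenerationModel

    model = MolecularGenerationModel()

    # keys are case-insensitive; predict_single_smiles lower-cases them and
    # z-normalizes the values with the `means`/`stds` tables in utils.py
    conditions = {"logP": 4.0, "TPSA": 80.0, "SAS": 3.0, "QED": 0.5}
    df = model.predict_single_smiles(conditions)

    # one row per valid generation: a SMILES column plus <prop>_measured
    # columns, rounded to two decimals
    print(df.head())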