feiyang-cai committed on
Commit
1d1d4f3
·
1 Parent(s): 1e001e8
Files changed (5)
  1. .gitignore +1 -0
  2. app.py +62 -83
  3. llama_customized_models.py +154 -0
  4. metric_calculator.py +213 -0
  5. utils.py +269 -190
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/*
app.py CHANGED
@@ -1,31 +1,11 @@
  import gradio as gr
  from huggingface_hub import HfApi, get_collection, list_collections, list_models
  #from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
- from utils import ReactionPredictionModel
  import pandas as pd
  import os
  import spaces

- def get_models():
-     # we only support two models:
-     # 1. ChemFM/uspto_mit_synthesis
-     # 2. ChemFM/uspto_full_retro
-     models = dict()
-     models['mit_synthesis'] = 'ChemFM/uspto_mit_synthesis'
-     models['full_retro'] = 'ChemFM/uspto_full_retro'
-
-     #for item in collection.items:
-     #    if item.item_type == "model":
-     #        item_name = item.item_id.split("/")[-1]
-     #        models[item_name] = item.item_id
-     #        assert item_name in dataset_task_types, f"{item_name} is not in the task_types"
-     #        assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
-
-     return models
-
  #candidate_models = get_models()
  #task_names = {
  #    'mit_synthesis': 'Reaction Synthesis',
@@ -46,16 +26,30 @@ def get_models():
  #}

  #property_names = list(candidate_models.keys())
- #model = ReactionPredictionModel(candidate_models)
- #model = MolecularPropertyPredictionModel(candidate_models)

- def predict_single_label(value_1, value_2, value_3, value_4):
-     print(value_1, value_2, value_3, value_4)

      try:
          running_status = None
          prediction = None

          #prediction = model.predict(smiles, property_name, adapter_id)
          #prediction = model.predict_single_smiles(smiles, task)
@@ -65,10 +59,10 @@ def predict_single_label(value_1, value_2, value_3, value_4):
      except Exception as e:
          # no matter what the error is, we should return
          print(e)
-         return "NA", "Prediction failed"

-     prediction = "\n".join([f"{idx+1}. {item}" for idx, item in enumerate(prediction)])
-     return prediction, "Prediction is done"

  """
  def get_description(task_name):
@@ -177,6 +171,13 @@ def clear_file(download_button):
      return gr.update(visible=False), gr.update(visible=False), None
  """

  def build_inference():

      with gr.Blocks() as demo:
@@ -184,7 +185,11 @@ def build_inference():
          #with gr.Row():
              #gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
          #dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
-         description = f"Generate 10 possible molecules based on the given conditions. \n"

          description_box = gr.Textbox(label="Task description", lines=5,
                                       interactive=False,
@@ -192,80 +197,54 @@ def build_inference():
          # third row - Textbox input and prediction label
          with gr.Row(equal_height=True):
              with gr.Column():
-                 checkbox_1 = gr.Checkbox(label="qed")
-                 slider_1 = gr.Slider(2, 20, value=4, label="qed", info="Choose between 2 and 20")
              with gr.Column():
-                 checkbox_2 = gr.Checkbox(label="logp")
-                 slider_2 = gr.Slider(2, 20, value=4, label="logp", info="Choose between 2 and 20")
              with gr.Column():
-                 checkbox_3 = gr.Checkbox(label="sas")
-                 slider_3 = gr.Slider(2, 20, value=4, label="sas", info="Choose between 2 and 20")
              with gr.Column():
-                 checkbox_4 = gr.Checkbox(label="weight")
-                 slider_4 = gr.Slider(2, 20, value=4, label="weight", info="Choose between 2 and 20")

          predict_single_smiles_button = gr.Button("Generate", size='sm')
          #prediction = gr.Label("Prediction will appear here")
-         prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)

          running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

-         #input_file = gr.File(label="Molecule file",
-         #                     file_count='single',
-         #                     file_types=[".smi", ".csv"], height=300)
-         #predict_file_button = gr.Button("Predict", size='sm', visible=False)
-         #download_button = gr.DownloadButton("Download", size='sm', visible=False)
-         #stop_button = gr.Button("Stop", size='sm', visible=False)

          # dropdown change event
          # predict single button click event
          predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     ) , outputs=[slider_1, slider_2, slider_3, slider_4,
                                                                  predict_single_smiles_button, running_terminal_label])\
-         .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4], outputs=[prediction, running_terminal_label])\
-         .then(lambda:(gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       ) , outputs=[slider_1, slider_2, slider_3, slider_4,
-                                    predict_single_smiles_button, running_terminal_label])
-         """
-         # input file upload event
-         file_status = gr.State()
-         input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
-         # input file clear event
-         input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
-         # predict file button click event
-         predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False),
-                                                                gr.update(interactive=False),
-                                                                gr.update(interactive=False),
-                                                                gr.update(interactive=False, visible=True),
-                                                                gr.update(interactive=False),
-                                                                gr.update(interactive=True, visible=False),
-                                                                gr.update(interactive=False),
-                                                                gr.update(interactive=False),
-                                                                ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
-         .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
-         .then(lambda:(gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       gr.update(interactive=True),
-                       ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])

-         # stop button click event
-         #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
-         """

      return demo
 
 
  import gradio as gr
  from huggingface_hub import HfApi, get_collection, list_collections, list_models
  #from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
+ from utils import MolecularGenerationModel
  import pandas as pd
  import os
  import spaces

  #candidate_models = get_models()
  #task_names = {
  #    'mit_synthesis': 'Reaction Synthesis',

  #}

  #property_names = list(candidate_models.keys())
+ model = MolecularGenerationModel()
+
+ def predict_single_label(logp, tpsa, sas, qed, logp_choose, tpsa_choose, sas_choose, qed_choose):
+     input_dict = dict()
+     if logp_choose:
+         input_dict['logP'] = logp
+     if tpsa_choose:
+         input_dict['TPSA'] = tpsa
+     if sas_choose:
+         input_dict['SAS'] = sas
+     if qed_choose:
+         input_dict['QED'] = qed
+
+     if len(input_dict) == 0:
+         return "NA", "No input is selected"

+     print(input_dict)

      try:
          running_status = None
          prediction = None
+
+         prediction = model.predict_single_smiles(input_dict)

          #prediction = model.predict(smiles, property_name, adapter_id)
          #prediction = model.predict_single_smiles(smiles, task)

      except Exception as e:
          # no matter what the error is, we should return
          print(e)
+         return "NA", "Generation failed"

+     #prediction = "\n".join([f"{idx+1}. {item}" for idx, item in enumerate(prediction)])
+     return prediction, "Generation is done"

  """
  def get_description(task_name):
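The condition dictionary above is built from the checkbox flags, so only enabled properties reach the model. A hypothetical direct call, bypassing the Gradio UI (it assumes the module has been imported and `model` was constructed successfully):

```python
# Hypothetical direct invocation of predict_single_label (not part of the commit).
# With only the logP and TPSA flags enabled, input_dict == {'logP': 3.5, 'TPSA': 80.0}.
df, status = predict_single_label(3.5, 80.0, 3.0, 0.5,
                                  True, True, False, False)
print(status)  # "Generation is done" on success, "Generation failed" otherwise
```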
 
      return gr.update(visible=False), gr.update(visible=False), None
  """

+ def toggle_slider(checked):
+     return gr.update(interactive=checked)
+
+ def toggle_sliders_based_on_checkboxes(checked_values):
+     """Enable or disable sliders based on the corresponding checkbox values."""
+     return [gr.update(interactive=checked_values[i]) for i in range(4)]
+
  def build_inference():

      with gr.Blocks() as demo:
          #with gr.Row():
              #gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
          #dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
+         description = f"This space allows you to generate ten possible molecules based on given conditions. \n" \
+                       f"1. You can enable or disable specific properties using checkboxes and adjust their values with sliders. \n" \
+                       f"2. The generated SMILES strings and their corresponding predicted properties will be displayed in the generations section. \n" \
+                       f"3. The properties include logP, TPSA, SAS, and QED. \n" \
+                       f"4. The model was trained on the GuacaMol dataset for molecular design."

          description_box = gr.Textbox(label="Task description", lines=5,
                                       interactive=False,

          # third row - Textbox input and prediction label
          with gr.Row(equal_height=True):
              with gr.Column():
+                 checkbox_1 = gr.Checkbox(label="logP", value=True)
+                 slider_1 = gr.Slider(1, 7, value=4, label="logP", info="Choose between 1 and 7")
+                 checkbox_1.change(toggle_slider, checkbox_1, slider_1)
              with gr.Column():
+                 checkbox_2 = gr.Checkbox(label="TPSA", value=True)
+                 slider_2 = gr.Slider(20, 140, value=80, label="TPSA", info="Choose between 20 and 140")
+                 checkbox_2.change(toggle_slider, checkbox_2, slider_2)
              with gr.Column():
+                 checkbox_3 = gr.Checkbox(label="SAS", value=True)
+                 slider_3 = gr.Slider(1, 5, value=3, label="SAS", info="Choose between 1 and 5")
+                 checkbox_3.change(toggle_slider, checkbox_3, slider_3)
              with gr.Column():
+                 checkbox_4 = gr.Checkbox(label="QED", value=True)
+                 slider_4 = gr.Slider(0.1, 0.9, value=0.5, label="QED", info="Choose between 0.1 and 0.9")
+                 checkbox_4.change(toggle_slider, checkbox_4, slider_4)

          predict_single_smiles_button = gr.Button("Generate", size='sm')
          #prediction = gr.Label("Prediction will appear here")
+         #prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)
+         prediction = gr.Dataframe(label="Generations", type="pandas", interactive=False)

          running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

          # dropdown change event
          # predict single button click event
          predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
+                                                    gr.update(interactive=False),
+                                                    gr.update(interactive=False),
+                                                    gr.update(interactive=False),
+                                                    gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     gr.update(interactive=False),
                                                     ) , outputs=[slider_1, slider_2, slider_3, slider_4,
+                                                                 checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                                                                  predict_single_smiles_button, running_terminal_label])\
+         .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4,
+                                             checkbox_1, checkbox_2, checkbox_3, checkbox_4
+                                             ], outputs=[prediction, running_terminal_label])\
+         .then(lambda a, b, c, d: toggle_sliders_based_on_checkboxes([a, b, c, d]) +
+                                  [gr.update(interactive=True)] * 6,
+               inputs=[checkbox_1, checkbox_2, checkbox_3, checkbox_4],
+               outputs=[slider_1, slider_2, slider_3, slider_4,
+                        checkbox_1, checkbox_2, checkbox_3, checkbox_4,
+                        predict_single_smiles_button, running_terminal_label])

      return demo
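The checkbox/slider pairing relies on `gr.update` flipping a component's `interactive` flag from an event handler. A minimal self-contained sketch of that pattern (independent of the ChemFM model, so it runs on its own):

```python
# Minimal sketch of the checkbox-gates-slider pattern used in build_inference().
import gradio as gr

def toggle_slider(checked):
    # gr.update patches a component's properties from inside an event handler.
    return gr.update(interactive=checked)

with gr.Blocks() as demo:
    use_logp = gr.Checkbox(label="logP", value=True)
    logp = gr.Slider(1, 7, value=4, label="logP")
    # Disable the slider whenever its checkbox is unticked.
    use_logp.change(toggle_slider, use_logp, logp)

if __name__ == "__main__":
    demo.launch()
```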
 
llama_customized_models.py ADDED
@@ -0,0 +1,154 @@
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel
+ from transformers.models.llama.configuration_llama import LlamaConfig
+ import torch.nn as nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutputWithPast,
+ )
+ from transformers.cache_utils import Cache
+
+ from transformers.modeling_outputs import (
+     CausalLMOutputWithPast,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from dataclasses import dataclass
+
+ from transformers.utils import ModelOutput
+
+ import torch
+ from typing import List, Optional, Tuple, Union
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "LlamaConfig"
+
+ LLAMA_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+             it.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+             `past_key_values`).
+
+             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+             information on the default strategy.
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+             config.n_positions - 1]`.
+
+             [What are position IDs?](../glossary#position-ids)
+         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+             Two formats are allowed:
+             - a [`~cache_utils.Cache`] instance;
+             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+               shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+               cache format.
+
+             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+             legacy cache format will be returned.
+
+             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+             of shape `(batch_size, sequence_length)`.
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+             Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+             this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+             the complete sequence length.
+ """
+
+ class LlamaForCausalLMWithNumericalEmbedding(LlamaForCausalLM):
+
+     def __init__(self, config: LlamaConfig):
+         super().__init__(config)
+         self.numerical_embedding = torch.nn.Linear(1, config.hidden_size, bias=True)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         properties: List = None,
+         properties_index: List = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         cache_position=None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         b, l = input_ids.size()
+         assert len(properties) == b, "The number of properties should be equal to the batch size."
+         assert len(properties_index) == b, "The number of properties_index should be equal to the batch size."
+
+         embeddings = self.model.embed_tokens(input_ids)
+
+         for i, (props, props_index, embeds) in enumerate(zip(properties, properties_index, embeddings)):
+             assert len(props) == len(props_index), "The number of properties should be equal to the number of properties_index."
+             props = torch.tensor(props, device=embeds.device, dtype=torch.float32).unsqueeze(1)
+             num_embeds = self.numerical_embedding(props)
+             if len(props_index) > 0:
+                 assert embeddings[i, props_index, :].shape == num_embeds.shape, "The shape of the embeddings and the numerical embeddings should be the same."
+                 embeddings[i, props_index, :] = num_embeds
+
+         return super().forward(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=embeddings,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
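The subclass's only change to the base model happens at the embedding layer: scalar property values are projected to `hidden_size` and written over the placeholder-token embeddings before the usual forward pass. A toy sketch of that core operation, isolated from `transformers`:

```python
# Core idea of LlamaForCausalLMWithNumericalEmbedding, in isolation (toy sizes).
import torch

hidden_size = 8                                # the real model uses config.hidden_size
numerical_embedding = torch.nn.Linear(1, hidden_size, bias=True)

embeddings = torch.zeros(1, 10, hidden_size)   # (batch, seq_len, hidden) token embeddings
props = torch.tensor([[0.3], [-1.2]])          # two normalized property values, shape (2, 1)
props_index = [3, 7]                           # placeholder positions in the sequence

# Overwrite the placeholder-token embeddings with the numerical projections.
embeddings[0, props_index, :] = numerical_embedding(props)
print(embeddings[0, 3, :4])                    # no longer zeros
```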
metric_calculator.py ADDED
@@ -0,0 +1,213 @@
+ from sklearn.metrics import mean_squared_error, roc_auc_score, r2_score
+ from rdkit.Chem import QED, Crippen, MolFromSmiles, rdmolops, rdMolDescriptors, AllChem
+ from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
+ import networkx as nx
+ import os.path as op
+ import math
+ #from rdkit.six.moves import cPickle
+ import _pickle as cPickle
+ #from rdkit.six import iteritems
+ from rdkit import Chem
+ import pickle
+ import numpy as np
+
+ import sys
+ import os
+ from rdkit.Chem import RDConfig
+ sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
+ import sascorer
+
+ from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
+ from rdkit.Chem.Fingerprints import FingerprintMols
+
+ def compute_rmse(gt, pred):
+     return mean_squared_error(gt, pred, squared=False)
+
+ def compute_r2score(gt, pred):
+     return r2_score(gt, pred)
+
+ def compute_roc_auc(gt, pred):
+     return roc_auc_score(gt, pred)
+
+ def check_valid(smiles_list):
+     total_num = len(smiles_list)
+     empty_num = smiles_list.count("")
+     return 1 - empty_num / float(total_num)
+
+ def check_unique(smiles_list):
+     total_num = len(smiles_list)
+     smiles_set = set(smiles_list)
+     if "" in smiles_set:
+         smiles_set.remove("")
+     return len(smiles_set) / float(total_num)
+
+ def check_novelty(gen_smiles, train_smiles):
+     if len(gen_smiles) == 0:
+         novel_ratio = 0.
+     else:
+         duplicates = [1 for mol in gen_smiles if mol in train_smiles]
+         novel = len(gen_smiles) - sum(duplicates)
+         novel_ratio = novel * 100. / len(gen_smiles)
+     return novel_ratio
+
+ _fscores = None
+ def readFragmentScores(name='fpscores'):
+     import gzip
+     global _fscores
+     # generate the full path filename:
+     if name == "fpscores":
+         name = op.join(op.dirname(__file__), name)
+     _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name))
+     outDict = {}
+     for i in _fscores:
+         for j in range(1, len(i)):
+             outDict[i[j]] = float(i[0])
+     _fscores = outDict
+
+ def numBridgeheadsAndSpiro(mol, ri=None):
+     nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
+     nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
+     return nBridgehead, nSpiro
+
+ def calculateScore(m):
+     if _fscores is None:
+         readFragmentScores()
+
+     # fragment score
+     fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # <- 2 is the *radius* of the circular fingerprint
+     fps = fp.GetNonzeroElements()
+     score1 = 0.
+     nf = 0
+     for bitId, v in fps.items():  # dict.items(); the rdkit.six iteritems import is commented out above
+         nf += v
+         sfp = bitId
+         score1 += _fscores.get(sfp, -4) * v
+     score1 /= nf
+
+     # features score
+     nAtoms = m.GetNumAtoms()
+     nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
+     ri = m.GetRingInfo()
+     nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
+     nMacrocycles = 0
+     for x in ri.AtomRings():
+         if len(x) > 8:
+             nMacrocycles += 1
+
+     sizePenalty = nAtoms**1.005 - nAtoms
+     stereoPenalty = math.log10(nChiralCenters + 1)
+     spiroPenalty = math.log10(nSpiro + 1)
+     bridgePenalty = math.log10(nBridgeheads + 1)
+     macrocyclePenalty = 0.
+     # ---------------------------------------
+     # This differs from the paper, which defines:
+     # macrocyclePenalty = math.log10(nMacrocycles+1)
+     # This form generates better results when 2 or more macrocycles are present
+     if nMacrocycles > 0:
+         macrocyclePenalty = math.log10(2)
+
+     score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
+
+     # correction for the fingerprint density
+     # not in the original publication, added in version 1.1
+     # to make highly symmetrical molecules easier to synthetise
+     score3 = 0.
+     if nAtoms > len(fps):
+         score3 = math.log(float(nAtoms) / len(fps)) * .5
+
+     sascore = score1 + score2 + score3
+
+     # need to transform "raw" value into scale between 1 and 10
+     min = -4.0
+     max = 2.5
+     sascore = 11. - (sascore - min + 1) / (max - min) * 9.
+     # smooth the 10-end
+     if sascore > 8.:
+         sascore = 8. + math.log(sascore + 1. - 9.)
+     if sascore > 10.:
+         sascore = 10.0
+     elif sascore < 1.:
+         sascore = 1.0
+
+     return sascore
+
+ def compute_plogp(mol):
+     #mol = MolFromSmiles(smiles_string)
+     #logp = (Crippen.MolLogP(mol) - np.mean(logP_values)) / np.std(logP_values)
+     logp = Crippen.MolLogP(mol)
+     #SA_score = (-sascorer.calculateScore(mol) - np.mean(SA_scores)) / np.std(SA_scores)
+     SA_score = -calculateScore(mol)
+     cycle_list = nx.cycle_basis(nx.Graph(rdmolops.GetAdjacencyMatrix(mol)))
+     if len(cycle_list) == 0:
+         cycle_length = 0
+     else:
+         cycle_length = max([len(j) for j in cycle_list])
+     if cycle_length <= 6:
+         cycle_length = 0
+     else:
+         cycle_length = cycle_length - 6
+
+     #cycle_score = (-cycle_length - np.mean(cycle_scores)) / np.std(cycle_scores)
+     cycle_score = -cycle_length
+     #plogp = -(logp + SA_score + cycle_score)
+     plogp = (logp + SA_score + cycle_score)
+     return plogp
+
+ clf_model = None
+ def load_model():
+     global clf_model
+     #name = op.join(op.dirname(__file__), 'clf_py36.pkl')
+     name = op.join(op.dirname(__file__), 'drd2_current.pkl')
+     with open(name, "rb") as f:
+         clf_model = pickle.load(f)
+
+ def fingerprints_from_mol(mol):
+     fp = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)
+     size = 2048
+     nfp = np.zeros((1, size), np.int32)
+     for idx, v in fp.GetNonzeroElements().items():
+         nidx = idx % size
+         nfp[0, nidx] += int(v)
+     return nfp
+
+ def compute_drd2(mol):
+     if clf_model is None:
+         load_model()
+
+     #print(smile)
+     #mol = Chem.MolFromSmiles(smile)
+     if mol:
+         fp = fingerprints_from_mol(mol)
+         score = clf_model.predict_proba(fp)[:, 1]
+         return float(score)
+     return 0.0
+
+ def compute_qed(mol):
+     return QED.qed(mol)
+
+ def compute_logp(mol):
+     return Crippen.MolLogP(mol)
+
+ def compute_tpsa(mol):
+     return rdMolDescriptors.CalcTPSA(mol)
+
+ def compute_sas(mol):
+     return sascorer.calculateScore(mol)
+
+ def check_valid_unique(smiles_list):
+     total_num = len(smiles_list)
+     empty_num = smiles_list.count("")
+
+     smiles_set = set(smiles_list)
+     if "" in smiles_set:
+         smiles_set.remove("")
+     return 1 - empty_num / float(total_num), \
+            len(smiles_set) / float(total_num - empty_num)
+
+ def get_similarity(smiles1, smiles2):
+     if smiles1 == "" or smiles2 == "":
+         return np.nan
+     sim = TanimotoSimilarity(FingerprintMols.FingerprintMol(Chem.MolFromSmiles(smiles1)),
+                              FingerprintMols.FingerprintMol(Chem.MolFromSmiles(smiles2)))
+     return sim
+
+ def get_scaffold(smiles):
+     scaffold = MurckoScaffoldSmiles(smiles)
+     return scaffold
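The four `compute_*` helpers the app conditions on can be exercised directly; for example (aspirin as an arbitrary test molecule):

```python
# Example: scoring one molecule with the helpers above (aspirin SMILES).
from rdkit import Chem
import metric_calculator as mc

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")
print("QED :", mc.compute_qed(mol))   # drug-likeness, 0..1
print("logP:", mc.compute_logp(mol))  # Crippen octanol-water partition coefficient
print("TPSA:", mc.compute_tpsa(mol))  # topological polar surface area, in A^2
print("SAS :", mc.compute_sas(mol))   # synthetic accessibility, 1 (easy) to 10 (hard)
```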
utils.py CHANGED
@@ -12,37 +12,62 @@ from datasets import Dataset
  from tqdm import tqdm
  import spaces

  from rdkit import RDLogger, Chem
  # Suppress RDKit INFO messages
  RDLogger.DisableLog('rdApp.*')

  DEFAULT_PAD_TOKEN = "[PAD]"
- device_map = "cpu"
-
- def compute_rank(prediction,raw=False,alpha=1.0):
-     valid_score = [[k for k in range(len(prediction[j]))] for j in range(len(prediction))]
-     invalid_rates = [0 for k in range(len(prediction[0]))]
-     rank = {}
-     highest = {}
-
-     for j in range(len(prediction)):
-         for k in range(len(prediction[j])):
-             if prediction[j][k] == "":
-                 valid_score[j][k] = 10 + 1
-                 invalid_rates[k] += 1
-         de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0] != ""]
-         prediction[j] = list(set(de_error))
-         prediction[j].sort(key=de_error.index)
-         for k, data in enumerate(prediction[j]):
-             if data in rank:
-                 rank[data] += 1 / (alpha * k + 1)
-             else:
-                 rank[data] = 1 / (alpha * k + 1)
-             if data in highest:
-                 highest[data] = min(k,highest[data])
              else:
-                 highest[data] = k
-     return rank,invalid_rates

  @dataclass
@@ -50,36 +75,98 @@ class DataCollatorForCausalLMEval(object):
      tokenizer: transformers.PreTrainedTokenizer
      source_max_len: int
      target_max_len: int
-     reactant_start_str: str
-     product_start_str: str
      end_str: str
-
-     def augment_molecule(self, molecule: str) -> str:
-         return self.sme.augment([molecule])[0]

      def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-
-         print(instances)
-         srcs = instances[0]['src']
-         task_type = instances[0]['task_type'][0]

-         if task_type == 'retrosynthesis':
-             src_start_str = self.product_start_str
-             tgt_start_str = self.reactant_start_str
-         else:
-             src_start_str = self.reactant_start_str
-             tgt_start_str = self.product_start_str

-         generation_prompts = []
-         generation_prompt = f"{src_start_str}{srcs}{self.end_str}{tgt_start_str}"
-         generation_prompts.append(generation_prompt)

          data_dict = {
-             'generation_prompts': generation_prompts
          }

          return data_dict

  def smart_tokenizer_and_embedding_resize(
      special_tokens_dict: Dict,
      tokenizer: transformers.PreTrainedTokenizer,
@@ -106,176 +193,168 @@ def smart_tokenizer_and_embedding_resize(
      input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
      print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")

- class ReactionPredictionModel():
-     def __init__(self, candidate_models):
-
-         for model in candidate_models:
-             if "retro" in model:
-                 self.tokenizer = AutoTokenizer.from_pretrained(
-                     candidate_models[list(candidate_models.keys())[0]],
-                     padding_side="right",
-                     use_fast=True,
-                     trust_remote_code=True,
-                     token = os.environ.get("TOKEN")
-                 )
-                 self.load_retro_model(candidate_models[model])
-             else:
-                 self.tokenizer = AutoTokenizer.from_pretrained(
-                     candidate_models[list(candidate_models.keys())[0]],
-                     padding_side="right",
-                     use_fast=True,
-                     trust_remote_code=True,
-                     token = os.environ.get("TOKEN")
-                 )
-                 self.load_forward_model(candidate_models[model])
-
-         string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
-         string_template = json.load(open(string_template_path, 'r'))
-         reactant_start_str = string_template['REACTANTS_START_STRING']
-         product_start_str = string_template['PRODUCTS_START_STRING']
-         end_str = string_template['END_STRING']
-         self.data_collator = DataCollatorForCausalLMEval(
-             tokenizer=self.tokenizer,
-             source_max_len=512,
-             target_max_len=512,
-             reactant_start_str=reactant_start_str,
-             product_start_str=product_start_str,
-             end_str=end_str,
          )

-     def load_retro_model(self, model_path):
-         # our retro model is lora model
          config = AutoConfig.from_pretrained(
-             "ChemFM/ChemFM-3B",
              trust_remote_code=True,
-             token=os.environ.get("TOKEN")
          )

-         base_model = AutoModelForCausalLM.from_pretrained(
-             "ChemFM/ChemFM-3B",
              config=config,
-             trust_remote_code=True,
              device_map=device_map,
              token = os.environ.get("TOKEN")
          )
-
-         # we should resize the embedding layer of the base model to match the adapter's tokenizer
          special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
          smart_tokenizer_and_embedding_resize(
              special_tokens_dict=special_tokens_dict,
              tokenizer=self.tokenizer,
-             model=base_model
-         )
-         base_model.config.pad_token_id = self.tokenizer.pad_token_id
-
-         # load the adapter model
-         self.retro_model = PeftModel.from_pretrained(
-             base_model,
-             model_path,
-             token = os.environ.get("TOKEN")
          )

-         #self.retro_model.to("cuda")

-     def load_forward_model(self, model_path):
-         config = AutoConfig.from_pretrained(
-             model_path,
-             device_map=device_map,
-             trust_remote_code=True,
-             token = os.environ.get("TOKEN")
          )

-         self.forward_model = AutoModelForCausalLM.from_pretrained(
-             model_path,
-             config=config,
-             device_map=device_map,
-             trust_remote_code=True,
-             token = os.environ.get("TOKEN")
-         )

-         # the finetune tokenizer could be in different size with pretrain tokenizer, and also, we need to add PAD_TOKEN
-         special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
-         smart_tokenizer_and_embedding_resize(
-             special_tokens_dict=special_tokens_dict,
-             tokenizer=self.tokenizer,
-             model=self.forward_model
-         )
-         self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
-         #self.forward_model.to("cuda")

-     @spaces.GPU(duration=20)
-     def predict_single_smiles(self, smiles, task_type):
-         if task_type == "full_retro":
-             if "." in smiles:
-                 return None

-         task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
-         # canonicalize the smiles
-         mol = Chem.MolFromSmiles(smiles)
-         if mol is None:
-             return None
-         smiles = Chem.MolToSmiles(mol)
-
-         smiles_list = [smiles]
-         task_type_list = [task_type]
-
-         df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
-         test_dataset = Dataset.from_pandas(df)
-         # construct the dataloader
-         test_loader = torch.utils.data.DataLoader(
-             test_dataset,
-             batch_size=1,
-             collate_fn=self.data_collator,
-         )

-         predictions = []
-         for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
-             with torch.no_grad():
-                 generation_prompts = batch['generation_prompts'][0]
-                 inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True).to(self.retro_model.device)
-                 print(inputs)
-                 del inputs['token_type_ids']
-                 """
-                 if task_type == "retrosynthesis":
-                     outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
-                                                         do_sample=False, num_beams=10,
-                                                         eos_token_id=self.tokenizer.eos_token_id,
-                                                         early_stopping='never',
-                                                         pad_token_id=self.tokenizer.pad_token_id,
-                                                         length_penalty=0.0,
-                                                         )
-                 else:
-                     outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
-                                                           do_sample=False, num_beams=10,
-                                                           eos_token_id=self.tokenizer.eos_token_id,
-                                                           early_stopping='never',
-                                                           pad_token_id=self.tokenizer.pad_token_id,
-                                                           length_penalty=0.0,
-                                                           )
-
-                 original_smiles_list = self.tokenizer.batch_decode(outputs[:, len(inputs['input_ids'][0]):],
-                                                                    skip_special_tokens=True)
-                 original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
-                 # canonize the SMILES
-                 canonized_smiles_list = []
-                 temp = []
-                 for original_smiles in original_smiles_list:
-                     temp.append(original_smiles)
-                     try:
-                         canonized_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(original_smiles)))
-                     except:
-                         canonized_smiles_list.append("")
-                 """
-                 canonized_smiles_list = \
-                     ['N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1cc(F)c([N+](=O)[O-])cc1F', 'N#Cc1ccsc1Nc1cc(Cl)c(F)cc1[N+](=O)[O-]', 'N#Cc1cnsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1cc(F)c(F)cc1Nc1sccc1C#N', 'N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=N)[O-]', 'N#Cc1cc(C#N)c(Nc2cc(F)c(F)cc2[N+](=O)[O-])s1', 'N#Cc1ccsc1Nc1c(F)c(F)cc(F)c1[N+](=O)[O-]', 'Nc1sccc1CNc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1ccc(F)cc1[N+](=O)[O-]']
-                 predictions.append(canonized_smiles_list)

-         rank, invalid_rate = compute_rank(predictions)
-         return rank
  from tqdm import tqdm
  import spaces

+ from llama_customized_models import LlamaForCausalLMWithNumericalEmbedding
+ from torch.nn.utils.rnn import pad_sequence
+ import numpy as np
+ from torch.utils.data.dataloader import DataLoader
+ from torch.nn import functional as F
+ import importlib
+
  from rdkit import RDLogger, Chem
  # Suppress RDKit INFO messages
  RDLogger.DisableLog('rdApp.*')

  DEFAULT_PAD_TOKEN = "[PAD]"
+ device_map = "cuda"
+
+ means = {"qed": 0.5559003125710424, "logp": 3.497542110420217, "sas": 2.889429694406497, "tpsa": 80.19717097706841}
+ stds = {"qed": 0.21339854620824716, "logp": 1.7923582437824368, "sas": 0.8081188219568571, "tpsa": 38.212259443049554}
+
+ def phrase_df(df):
+     metric_calculator = importlib.import_module("metric_calculator")
+
+     new_df = []
+     # iterate over the dataframe
+     for i in range(len(df)):
+         sub_df = dict()
+
+         # get the SMILES
+         smiles = df.iloc[i]['SMILES']
+         # get the property names
+         property_names = df.iloc[i]['property_names']
+         # get the non normalized properties
+         non_normalized_properties = df.iloc[i]['non_normalized_properties']
+
+         sub_df['SMILES'] = smiles
+
+         # compute the similarity between the scaffold and the SMILES
+         for j in range(len(property_names)):
+             # get the property name
+             property_name = property_names[j]
+             # get the non normalized property
+             non_normalized_property = non_normalized_properties[j]
+
+             sub_df[f'{property_name}_condition'] = non_normalized_property
+
+             if smiles == "":
+                 sub_df[f'{property_name}_measured'] = np.nan
              else:
+                 property_eval_func_name = f"compute_{property_name}"
+                 property_eval_func = getattr(metric_calculator, property_eval_func_name)
+                 sub_df[f'{property_name}_measured'] = property_eval_func(Chem.MolFromSmiles(smiles))
+
+         new_df.append(sub_df)
+
+     new_df = pd.DataFrame(new_df)
+     return new_df

  @dataclass

      tokenizer: transformers.PreTrainedTokenizer
      source_max_len: int
      target_max_len: int
+     molecule_target_aug_prob: float
+     molecule_start_str: str
+     scaffold_aug_prob: float
+     scaffold_start_str: str
+     property_start_str: str
+     property_inner_sep: str
+     property_inter_sep: str
      end_str: str
+     ignore_index: int
+     has_scaffold: bool
      def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         # Extract elements
+         prop_token_map = {
+             'qed': '<qed>',
+             'logp': '<logp>',
+             'sas': '<SAS>',
+             'tpsa': '<TPSA>'
+         }

+         sources = []
+         props_list = []
+         non_normalized_props_list = []
+         prop_names_list = []
+         props_index_list = []
+         temperature_list = []
+         scaffold_list = []
+         for example in instances:
+             prop_names = example['property_name']
+             prop_values = example['property_value']
+             non_normalized_prop_values = example['non_normalized_property_value']
+             temperature = example['temperature']
+             # we need to convert the string to a list
+
+             # randomly choose the property and the scaffold combinations:
+             props_str = ""
+             scaffold_str = ""
+             props = []
+             non_normalized_props = []
+             props_index = []
+
+             if self.has_scaffold:
+                 scaffold = example['scaffold_smiles'].strip()
+                 scaffold_str = f"{self.scaffold_start_str}{scaffold}{self.end_str}"
+
+             props_str = f"{self.property_start_str}"
+             for i, prop in enumerate(prop_names):
+                 prop = prop.lower()
+                 props_str += f"{prop_token_map[prop]}{self.property_inner_sep}{self.molecule_start_str}{self.property_inter_sep}"
+                 props.append(prop_values[i])
+                 non_normalized_props.append(non_normalized_prop_values[i])
+                 props_index.append(3 + 4 * i)  # this is hard coded for the current template
+             props_str += f"{self.end_str}"
+
+             source = props_str + scaffold_str + "<->>" + self.molecule_start_str
+
+             sources.append(source)
+             props_list.append(props)
+             non_normalized_props_list.append(non_normalized_props)
+             props_index_list.append(props_index)
+             prop_names_list.append(prop_names)
+             temperature_list.append(temperature)
+
+         # Tokenize
+         tokenized_sources_with_prompt = self.tokenizer(
+             sources,
+             max_length=self.source_max_len,
+             truncation=True,
+             add_special_tokens=False,
+         )

+         # Build the input and labels for causal LM
+         input_ids = []
+         for tokenized_source in tokenized_sources_with_prompt['input_ids']:
+             input_ids.append(torch.tensor(tokenized_source))
+         # Apply padding
+         input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

          data_dict = {
+             'input_ids': input_ids,
+             'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
+             'properties': props_list,
+             'non_normalized_properties': non_normalized_props_list,
+             'property_names': prop_names_list,
+             'properties_index': props_index_list,
+             'temperature': temperature_list,
          }

          return data_dict

+
  def smart_tokenizer_and_embedding_resize(
      special_tokens_dict: Dict,
      tokenizer: transformers.PreTrainedTokenizer,
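The collator's prompt layout, and the hard-coded `3 + 4 * i` placeholder arithmetic, are easier to see with concrete separator strings. A sketch (the real separators come from `string_template.json`; the tokens below are illustrative stand-ins, assuming each template string maps to a single token):

```python
# Illustrative prompt assembly; separator strings are stand-ins, not the real template.
prop_token_map = {'qed': '<qed>', 'logp': '<logp>', 'sas': '<SAS>', 'tpsa': '<TPSA>'}
property_start, inner_sep, mol_start, inter_sep, end = '<prop>', '<is>', '<mol>', '<its>', '<end>'

props_str = property_start
props_index = []
for i, prop in enumerate(['logp', 'tpsa']):
    # Each property contributes 4 tokens; the value placeholder sits at offset 3 + 4*i,
    # which is where the numerical embedding is swapped in at forward time.
    props_str += f"{prop_token_map[prop]}{inner_sep}{mol_start}{inter_sep}"
    props_index.append(3 + 4 * i)
props_str += end

print(props_str)    # <prop><logp><is><mol><its><TPSA><is><mol><its><end>
print(props_index)  # [3, 7]
```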
 
      input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
      print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")

+ class MolecularGenerationModel():
+     def __init__(self):
+         model_id = "ChemFM/molecular_cond_generation_guacamol"
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             padding_side="right",
+             use_fast=True,
+             trust_remote_code=True,
+             token = os.environ.get("TOKEN")
          )

+         # load model
          config = AutoConfig.from_pretrained(
+             model_id,
+             device_map=device_map,
              trust_remote_code=True,
+             token = os.environ.get("TOKEN")
          )

+         self.model = LlamaForCausalLMWithNumericalEmbedding.from_pretrained(
+             model_id,
              config=config,
              device_map=device_map,
+             trust_remote_code=True,
              token = os.environ.get("TOKEN")
          )
+
+         # the finetune tokenizer could be a different size from the pretrain tokenizer, and we also need to add PAD_TOKEN
          special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
          smart_tokenizer_and_embedding_resize(
              special_tokens_dict=special_tokens_dict,
              tokenizer=self.tokenizer,
+             model=self.model
          )
+         self.model.config.pad_token_id = self.tokenizer.pad_token_id

+         self.model.eval()
+
+         string_template_path = hf_hub_download(model_id, filename="string_template.json", token = os.environ.get("TOKEN"))
+         string_template = json.load(open(string_template_path, 'r'))
+         molecule_start_str = string_template['MOLECULE_START_STRING']
+         scaffold_start_str = string_template['SCAFFOLD_MOLECULE_START_STRING']
+         property_start_str = string_template['PROPERTY_START_STRING']
+         property_inner_sep = string_template['PROPERTY_INNER_SEP']
+         property_inter_sep = string_template['PROPERTY_INTER_SEP']
+         end_str = string_template['END_STRING']

+         self.data_collator = DataCollatorForCausalLMEval(
+             tokenizer=self.tokenizer,
+             source_max_len=512,
+             target_max_len=512,
+             molecule_target_aug_prob=1.0,
+             scaffold_aug_prob=0.0,
+             molecule_start_str=molecule_start_str,
+             scaffold_start_str=scaffold_start_str,
+             property_start_str=property_start_str,
+             property_inner_sep=property_inner_sep,
+             property_inter_sep=property_inter_sep,
+             end_str=end_str,
+             ignore_index=-100,
+             has_scaffold=False
          )

+     @spaces.GPU(duration=60)
+     def generate(self, loader):
+
+         df = []
+         pbar = tqdm(loader, desc=f"Evaluating...", leave=False)
+         for it, batch in enumerate(pbar):
+             sub_df = dict()
+
+             batch_size = batch['input_ids'].shape[0]
+             assert batch_size == 1, "The batch size should be 1"
+
+             temperature = batch['temperature'][0]
+             property_names = batch['property_names'][0]
+             non_normalized_properties = batch['non_normalized_properties'][0]
+
+             num_generations = 1
+             del batch['temperature']
+             del batch['property_names']
+             del batch['non_normalized_properties']
+
+             input_length = batch['input_ids'].shape[1]
+             steps = 1024 - input_length
+
+             with torch.set_grad_enabled(False):
+                 early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
+                 for k in range(steps):
+                     logits = self.model(**batch)['logits']
+                     logits = logits[:, -1, :] / temperature
+                     probs = F.softmax(logits, dim=-1)
+                     ix = torch.multinomial(probs, num_samples=num_generations)
+
+                     ix[early_stop_flags] = self.tokenizer.eos_token_id
+
+                     batch['input_ids'] = torch.cat([batch['input_ids'], ix], dim=-1)
+                     early_stop_flags |= (ix.squeeze() == self.tokenizer.eos_token_id)
+
+                     if torch.all(early_stop_flags):
+                         break
+
+             generations = self.tokenizer.batch_decode(batch['input_ids'][:, input_length:], skip_special_tokens=True)
+             generations = map(lambda x: x.replace(" ", ""), generations)
+
+             predictions = []
+             for generation in generations:
+                 try:
+                     predictions.append(Chem.MolToSmiles(Chem.MolFromSmiles(generation)))
+                 except:
+                     predictions.append("")
+
+             sub_df['SMILES'] = predictions[0]
+             sub_df['property_names'] = property_names
+             sub_df['property'] = batch['properties'][0]
+             sub_df['non_normalized_properties'] = non_normalized_properties
+
+             df.append(sub_df)
+
+         df = pd.DataFrame(df)
+         return df

+     def predict_single_smiles(self, input_dict: Dict):
+         # convert the keys to lower case
+         input_dict = {key.lower(): value for key, value in input_dict.items()}

+         properties = [key.lower() for key in input_dict.keys()]
+         property_means = [means[prop] for prop in properties]
+         property_stds = [stds[prop] for prop in properties]
+
+         sample_point = [input_dict[prop] for prop in properties]
+         non_normalized_sample_point = np.array(sample_point).reshape(-1)
+         sample_point = (np.array(sample_point) - np.array(property_means)) / np.array(property_stds)
+         sub_df = {
+             "property_name": properties,
+             "property_value": sample_point.tolist(),
+             "temperature": 1.0,
+             "non_normalized_property_value": non_normalized_sample_point.tolist()
+         }
+
+         test_dataset = [sub_df] * 10
+         test_dataset = pd.DataFrame(test_dataset)
+         test_dataset = Dataset.from_pandas(test_dataset)
+
+         test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=self.data_collator)
+         df = self.generate(test_loader)
+         new_df = phrase_df(df)
+         # delete the condition columns
+         new_df = new_df.drop(columns=[col for col in new_df.columns if "condition" in col])
+
+         # drop the empty smiles rows
+         new_df = new_df.dropna(subset=['SMILES'])
+
+         # round the measured values to 2 decimal places
+         new_df = new_df.round(2)

+         return new_df
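Upstream of the collator, `predict_single_smiles` z-scores each condition with the training-set statistics defined at the top of the file. A standalone sketch of that step:

```python
# Standalone sketch of the condition normalization in predict_single_smiles,
# using the means/stds constants defined at the top of utils.py.
import numpy as np

means = {"qed": 0.5559003125710424, "logp": 3.497542110420217,
         "sas": 2.889429694406497, "tpsa": 80.19717097706841}
stds = {"qed": 0.21339854620824716, "logp": 1.7923582437824368,
        "sas": 0.8081188219568571, "tpsa": 38.212259443049554}

input_dict = {"logp": 3.5, "tpsa": 80.0}  # user-selected conditions, lower-cased keys
props = list(input_dict.keys())
raw = np.array([input_dict[p] for p in props])
normalized = (raw - np.array([means[p] for p in props])) / np.array([stds[p] for p in props])
print(dict(zip(props, normalized.round(3))))  # z-scored values fed to the collator
```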