feiyang-cai committed
Commit 1e001e8 · 1 Parent(s): a35bbb5
Files changed (3)
  1. app.py +276 -0
  2. requirements.txt +9 -0
  3. utils.py +283 -0
app.py ADDED
@@ -0,0 +1,276 @@
+ import gradio as gr
+ from huggingface_hub import HfApi, get_collection, list_collections, list_models
+ #from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
+ from utils import ReactionPredictionModel
+ import pandas as pd
+ import os
+ import spaces
+ 
+ def get_models():
+     # we only support two models:
+     # 1. ChemFM/uspto_mit_synthesis
+     # 2. ChemFM/uspto_full_retro
+     models = dict()
+     models['mit_synthesis'] = 'ChemFM/uspto_mit_synthesis'
+     models['full_retro'] = 'ChemFM/uspto_full_retro'
+ 
+     #for item in collection.items:
+     #    if item.item_type == "model":
+     #        item_name = item.item_id.split("/")[-1]
+     #        models[item_name] = item.item_id
+     #        assert item_name in dataset_task_types, f"{item_name} is not in the task_types"
+     #        assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
+ 
+     return models
+ 
+ #candidate_models = get_models()
+ #task_names = {
+ #    'mit_synthesis': 'Reaction Synthesis',
+ #    'full_retro': 'Reaction Retro Synthesis'
+ #}
+ #task_names_to_tasks = {v: k for k, v in task_names.items()}
+ #tasks = list(candidate_models.keys())
+ #task_descriptions = {
+ #    'mit_synthesis': 'Predict the reaction products given the reactants and reagents. \n' + \
+ #                     '1. This model is trained on the USPTO MIT dataset. \n' + \
+ #                     '2. The reactants and reagents are mixed in the input SMILES string. \n' + \
+ #                     '3. Different compounds are separated by ".". \n' + \
+ #                     '4. Input SMILES string example: C1CCOC1.N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F.[H-].[Na+]',
+ #    'full_retro': 'Predict the reaction precursors given the reaction products. \n' + \
+ #                  '1. This model is trained on the USPTO Full dataset. \n' + \
+ #                  '2. In this dataset, we consider only a single product in the input SMILES string. \n' + \
+ #                  '3. Input SMILES string example: CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)C3=CC[C@@]21C'
+ #}
+ 
+ #property_names = list(candidate_models.keys())
+ #model = ReactionPredictionModel(candidate_models)
+ #model = MolecularPropertyPredictionModel(candidate_models)
+ 
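The commented-out task descriptions above define the expected inputs: a forward-synthesis query packs reactants and reagents into one SMILES string with "." separating the compounds, while the retro model takes a single product SMILES. A minimal sketch of validating and canonicalizing such an input with RDKit before prediction (the `sanitize_input` helper is illustrative, not part of this Space):

    from rdkit import Chem

    def sanitize_input(smiles: str, single_compound: bool = False):
        # the retro task expects exactly one product, so reject dotted inputs
        if single_compound and "." in smiles:
            return None
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None  # not a parseable SMILES
        return Chem.MolToSmiles(mol)  # canonical form

    # e.g. sanitize_input("C1CCOC1.N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F.[H-].[Na+]")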
+ def predict_single_label(value_1, value_2, value_3, value_4):
+     print(value_1, value_2, value_3, value_4)
+ 
+     try:
+         running_status = None
+         prediction = None
+ 
+         # NOTE: the actual model call is still commented out, so `prediction`
+         # stays None and this function currently always returns early.
+         #prediction = model.predict(smiles, property_name, adapter_id)
+         #prediction = model.predict_single_smiles(smiles, task)
+         if prediction is None:
+             return "NA", "Invalid SMILES string"
+ 
+     except Exception as e:
+         # no matter what the error is, we should return
+         print(e)
+         return "NA", "Prediction failed"
+ 
+     prediction = "\n".join([f"{idx+1}. {item}" for idx, item in enumerate(prediction)])
+     return prediction, "Prediction is done"
+ 
+ """
+ def get_description(task_name):
+     task = task_names_to_tasks[task_name]
+     return task_descriptions[task]
+ 
+ #@spaces.GPU(duration=10)
+ """
+ 
+ """
+ @spaces.GPU(duration=30)
+ def predict_file(file, property_name):
+     property_id = dataset_property_names_to_dataset[property_name]
+     try:
+         adapter_id = candidate_models[property_id]
+         info = model.swith_adapter(property_id, adapter_id)
+ 
+         running_status = None
+         if info == "keep":
+             running_status = "Adapter is the same as the current one"
+             #print("Adapter is the same as the current one")
+         elif info == "switched":
+             running_status = "Adapter is switched successfully"
+             #print("Adapter is switched successfully")
+         elif info == "error":
+             running_status = "Adapter is not found"
+             #print("Adapter is not found")
+             return None, None, file, running_status
+         else:
+             running_status = "Unknown error"
+             return None, None, file, running_status
+ 
+         df = pd.read_csv(file)
+         # we have already checked that the file contains the "smiles" column
+         df = model.predict_file(df, dataset_task_types[property_id])
+         # save this file to disk so it can be downloaded;
+         # rename the file to have a "_prediction" suffix
+         prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
+         print(file, prediction_file)
+         df.to_csv(prediction_file, index=False)
+     except Exception as e:
+         # no matter what the error is, we should return
+         print(e)
+         return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed"
+ 
+     return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
+ 
+ def validate_file(file):
+     try:
+         if file.endswith(".csv"):
+             df = pd.read_csv(file)
+             if "smiles" not in df.columns:
+                 # we should clear the file input
+                 return "Invalid file content. The csv file must contain a column named 'smiles'", \
+                     None, gr.update(visible=False), gr.update(visible=False)
+ 
+             # check the number of SMILES
+             length = len(df["smiles"])
+ 
+         elif file.endswith(".smi"):
+             return "Invalid file extension", \
+                 None, gr.update(visible=False), gr.update(visible=False)
+ 
+         else:
+             return "Invalid file extension", \
+                 None, gr.update(visible=False), gr.update(visible=False)
+     except Exception as e:
+         return "Invalid file content.", \
+             None, gr.update(visible=False), gr.update(visible=False)
+ 
+     if length > 100:
+         return "This space does not support files containing more than 100 SMILES", \
+             None, gr.update(visible=False), gr.update(visible=False)
+ 
+     return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
+ """
+ 
+ def raise_error(status):
+     if status != "Valid file":
+         raise gr.Error(status)
+     return None
+ 
+ """
+ def clear_file(download_button):
+     # we might need to delete the prediction file and the uploaded file
+     prediction_path = download_button
+     print(prediction_path)
+     if prediction_path and os.path.exists(prediction_path):
+         os.remove(prediction_path)
+         original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
+         original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
+         if os.path.exists(original_data_file_0):
+             os.remove(original_data_file_0)
+         if os.path.exists(original_data_file_1):
+             os.remove(original_data_file_1)
+     #if os.path.exists(file):
+     #    os.remove(file)
+     #prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
+     #if os.path.exists(prediction_file):
+     #    os.remove(prediction_file)
+ 
+     return gr.update(visible=False), gr.update(visible=False), None
+ """
+ 
+ def build_inference():
+ 
+     with gr.Blocks() as demo:
+         # first row - Dropdown input
+         #with gr.Row():
+         #    gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
+         #dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
+         description = "Generate 10 possible molecules based on the given conditions. \n"
+ 
+         description_box = gr.Textbox(label="Task description", lines=5,
+                                      interactive=False,
+                                      value=description)
+         # second row - condition checkboxes and sliders
+         with gr.Row(equal_height=True):
+             with gr.Column():
+                 checkbox_1 = gr.Checkbox(label="qed")
+                 slider_1 = gr.Slider(2, 20, value=4, label="qed", info="Choose between 2 and 20")
+             with gr.Column():
+                 checkbox_2 = gr.Checkbox(label="logp")
+                 slider_2 = gr.Slider(2, 20, value=4, label="logp", info="Choose between 2 and 20")
+             with gr.Column():
+                 checkbox_3 = gr.Checkbox(label="sas")
+                 slider_3 = gr.Slider(2, 20, value=4, label="sas", info="Choose between 2 and 20")
+             with gr.Column():
+                 checkbox_4 = gr.Checkbox(label="weight")
+                 slider_4 = gr.Slider(2, 20, value=4, label="weight", info="Choose between 2 and 20")
+ 
+         predict_single_smiles_button = gr.Button("Generate", size='sm')
+         #prediction = gr.Label("Prediction will appear here")
+         prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)
+ 
+         running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
+ 
+         #input_file = gr.File(label="Molecule file",
+         #                     file_count='single',
+         #                     file_types=[".smi", ".csv"], height=300)
+         #predict_file_button = gr.Button("Predict", size='sm', visible=False)
+         #download_button = gr.DownloadButton("Download", size='sm', visible=False)
+         #stop_button = gr.Button("Stop", size='sm', visible=False)
+ 
+         # predict single button click event: disable the controls, run the
+         # prediction, then re-enable the controls
+         predict_single_smiles_button.click(lambda: (gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     ), outputs=[slider_1, slider_2, slider_3, slider_4,
+                                                                 predict_single_smiles_button, running_terminal_label])\
+             .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4], outputs=[prediction, running_terminal_label])\
+             .then(lambda: (gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            ), outputs=[slider_1, slider_2, slider_3, slider_4,
+                                        predict_single_smiles_button, running_terminal_label])
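The click handler above chains three steps with `.then`: lock the sliders and button, run the prediction, then unlock everything, so a second job cannot be fired mid-run. The same lock/run/unlock pattern in a minimal, self-contained sketch (component names here are illustrative):

    import time
    import gradio as gr

    def slow_job(x):
        time.sleep(2)  # stand-in for model inference
        return f"result for {x}"

    with gr.Blocks() as sketch:
        box = gr.Textbox(label="input")
        out = gr.Textbox(label="output")
        btn = gr.Button("Run")
        btn.click(lambda: gr.update(interactive=False), outputs=btn)\
           .then(slow_job, inputs=box, outputs=out)\
           .then(lambda: gr.update(interactive=True), outputs=btn)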
+         """
+         # input file upload event
+         file_status = gr.State()
+         input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
+         # input file clear event
+         input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
+         # predict file button click event
+         predict_file_event = predict_file_button.click(lambda: (gr.update(interactive=False),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=False, visible=True),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=True, visible=False),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=False),
+                                                                 ), outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
+             .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
+             .then(lambda: (gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            ), outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
+ 
+         # stop button click event
+         #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
+         """
+ 
+     return demo
+ 
+ 
+ demo = build_inference()
+ 
+ if __name__ == '__main__':
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers
+ torch
+ huggingface_hub
+ pandas
+ peft
+ tqdm
+ datasets
+ rdkit
+ scikit-learn
utils.py ADDED
@@ -0,0 +1,283 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+ import os
+ from typing import Optional, Dict, Sequence
+ import transformers
+ from peft import PeftModel
+ import torch
+ from dataclasses import dataclass, field
+ from huggingface_hub import hf_hub_download
+ import json
+ import pandas as pd
+ from datasets import Dataset
+ from tqdm import tqdm
+ import spaces
+ 
+ from rdkit import RDLogger, Chem
+ # suppress RDKit INFO messages
+ RDLogger.DisableLog('rdApp.*')
+ 
+ DEFAULT_PAD_TOKEN = "[PAD]"
+ device_map = "cpu"
+ 
+ def compute_rank(prediction, raw=False, alpha=1.0):
+     # `prediction` is a list of beam-output lists; empty strings mark invalid SMILES
+     valid_score = [[k for k in range(len(prediction[j]))] for j in range(len(prediction))]
+     invalid_rates = [0 for k in range(len(prediction[0]))]
+     rank = {}
+     highest = {}
+ 
+     for j in range(len(prediction)):
+         # push invalid (empty) candidates to the end, then deduplicate
+         for k in range(len(prediction[j])):
+             if prediction[j][k] == "":
+                 valid_score[j][k] = 10 + 1
+                 invalid_rates[k] += 1
+         de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0] != ""]
+         prediction[j] = list(set(de_error))
+         prediction[j].sort(key=de_error.index)
+         # reciprocal-rank style scoring: position k earns 1 / (alpha * k + 1)
+         for k, data in enumerate(prediction[j]):
+             if data in rank:
+                 rank[data] += 1 / (alpha * k + 1)
+             else:
+                 rank[data] = 1 / (alpha * k + 1)
+             if data in highest:
+                 highest[data] = min(k, highest[data])
+             else:
+                 highest[data] = k
+     return rank, invalid_rates
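compute_rank merges candidates from several prediction lists with a reciprocal-rank score: position k contributes 1 / (alpha * k + 1), and empty strings (invalid SMILES) are dropped after being counted in invalid_rates. A small worked example with toy SMILES:

    preds = [["CCO", "CC", ""],      # beam list 1; "" marks an invalid candidate
             ["CC", "CCO", "CCN"]]   # beam list 2
    rank, invalid_rates = compute_rank(preds)
    # rank["CCO"] = 1/1 + 1/2 = 1.5
    # rank["CC"]  = 1/2 + 1/1 = 1.5
    # rank["CCN"] = 1/3
    # invalid_rates == [0, 0, 1]: one invalid candidate at beam position 2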
+ 
+ 
+ @dataclass
+ class DataCollatorForCausalLMEval(object):
+     tokenizer: transformers.PreTrainedTokenizer
+     source_max_len: int
+     target_max_len: int
+     reactant_start_str: str
+     product_start_str: str
+     end_str: str
+ 
+     def augment_molecule(self, molecule: str) -> str:
+         # NOTE: `self.sme` (a SMILES enumerator) is never initialized in this class,
+         # so this helper would raise AttributeError if called; it is unused here.
+         return self.sme.augment([molecule])[0]
+ 
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         print(instances)
+         srcs = instances[0]['src']
+         task_type = instances[0]['task_type'][0]
+ 
+         # retrosynthesis prompts start from the product; synthesis from the reactants
+         if task_type == 'retrosynthesis':
+             src_start_str = self.product_start_str
+             tgt_start_str = self.reactant_start_str
+         else:
+             src_start_str = self.reactant_start_str
+             tgt_start_str = self.product_start_str
+ 
+         generation_prompts = []
+         generation_prompt = f"{src_start_str}{srcs}{self.end_str}{tgt_start_str}"
+         generation_prompts.append(generation_prompt)
+ 
+         data_dict = {
+             'generation_prompts': generation_prompts
+         }
+ 
+         return data_dict
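The collator emits a single generation prompt of the form <src_start>SRC<end><tgt_start>, with the marker strings read from the model repo's string_template.json. Assuming, purely for illustration, markers of [REACTANTS], [PRODUCTS], and [END], a retrosynthesis prompt would look like this:

    # hypothetical marker strings; the real ones come from string_template.json
    reactant_start_str, product_start_str, end_str = "[REACTANTS]", "[PRODUCTS]", "[END]"
    product = "CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)C3=CC[C@@]21C"
    prompt = f"{product_start_str}{product}{end_str}{reactant_start_str}"
    # the model is then asked to continue the prompt with reactant SMILES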
+ 
+ def smart_tokenizer_and_embedding_resize(
+     special_tokens_dict: Dict,
+     tokenizer: transformers.PreTrainedTokenizer,
+     model: transformers.PreTrainedModel,
+     non_special_tokens = None,
+ ):
+     """Resize tokenizer and embedding.
+ 
+     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+     """
+     # the first assignment is overwritten below; the effective count is the
+     # difference between the new tokenizer length and the old embedding size
+     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(non_special_tokens)
+     num_old_tokens = model.get_input_embeddings().weight.shape[0]
+     num_new_tokens = len(tokenizer) - num_old_tokens
+     if num_new_tokens == 0:
+         return
+ 
+     model.resize_token_embeddings(len(tokenizer))
+ 
+     if num_new_tokens > 0:
+         input_embeddings_data = model.get_input_embeddings().weight.data
+         # initialize the new rows with the mean of the existing embeddings
+         input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
+         input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
+     print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
+ 
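Mean-initializing the new rows keeps added tokens near the model's existing embedding distribution rather than starting from random noise. The same idea on a bare tensor (sizes are arbitrary):

    import torch

    old = torch.randn(100, 16)                       # 100 existing embeddings, dim 16
    avg = old.mean(dim=0, keepdim=True)              # (1, 16) mean embedding
    new = torch.cat([old, avg.repeat(2, 1)], dim=0)  # two new tokens start at the mean
    assert new.shape == (102, 16)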
+ class ReactionPredictionModel():
+     def __init__(self, candidate_models):
+         # both branches load the tokenizer from the first candidate model;
+         # the branch only decides which weights (retro vs. forward) to load
+         for model in candidate_models:
+             if "retro" in model:
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     candidate_models[list(candidate_models.keys())[0]],
+                     padding_side="right",
+                     use_fast=True,
+                     trust_remote_code=True,
+                     token=os.environ.get("TOKEN")
+                 )
+                 self.load_retro_model(candidate_models[model])
+             else:
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     candidate_models[list(candidate_models.keys())[0]],
+                     padding_side="right",
+                     use_fast=True,
+                     trust_remote_code=True,
+                     token=os.environ.get("TOKEN")
+                 )
+                 self.load_forward_model(candidate_models[model])
+ 
+         string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token=os.environ.get("TOKEN"))
+         string_template = json.load(open(string_template_path, 'r'))
+         reactant_start_str = string_template['REACTANTS_START_STRING']
+         product_start_str = string_template['PRODUCTS_START_STRING']
+         end_str = string_template['END_STRING']
+         self.data_collator = DataCollatorForCausalLMEval(
+             tokenizer=self.tokenizer,
+             source_max_len=512,
+             target_max_len=512,
+             reactant_start_str=reactant_start_str,
+             product_start_str=product_start_str,
+             end_str=end_str,
+         )
+ 
+     def load_retro_model(self, model_path):
+         # the retro model is a LoRA adapter on top of the ChemFM-3B base model
+         config = AutoConfig.from_pretrained(
+             "ChemFM/ChemFM-3B",
+             trust_remote_code=True,
+             token=os.environ.get("TOKEN")
+         )
+ 
+         base_model = AutoModelForCausalLM.from_pretrained(
+             "ChemFM/ChemFM-3B",
+             config=config,
+             trust_remote_code=True,
+             device_map=device_map,
+             token=os.environ.get("TOKEN")
+         )
+ 
+         # resize the embedding layer of the base model to match the adapter's tokenizer
+         special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=special_tokens_dict,
+             tokenizer=self.tokenizer,
+             model=base_model
+         )
+         base_model.config.pad_token_id = self.tokenizer.pad_token_id
+ 
+         # load the adapter weights
+         self.retro_model = PeftModel.from_pretrained(
+             base_model,
+             model_path,
+             token=os.environ.get("TOKEN")
+         )
+ 
+         #self.retro_model.to("cuda")
+ 
+     def load_forward_model(self, model_path):
+         config = AutoConfig.from_pretrained(
+             model_path,
+             device_map=device_map,
+             trust_remote_code=True,
+             token=os.environ.get("TOKEN")
+         )
+ 
+         self.forward_model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             config=config,
+             device_map=device_map,
+             trust_remote_code=True,
+             token=os.environ.get("TOKEN")
+         )
+ 
+         # the fine-tuned tokenizer may differ in size from the pretraining tokenizer,
+         # and we also need to add the PAD token
+         special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=special_tokens_dict,
+             tokenizer=self.tokenizer,
+             model=self.forward_model
+         )
+         self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
+         #self.forward_model.to("cuda")
+ 
+     @spaces.GPU(duration=20)
+     def predict_single_smiles(self, smiles, task_type):
+         # the retro model expects a single product, so reject multi-compound input
+         if task_type == "full_retro":
+             if "." in smiles:
+                 return None
+ 
+         task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
+         # canonicalize the smiles
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is None:
+             return None
+         smiles = Chem.MolToSmiles(mol)
+ 
+         smiles_list = [smiles]
+         task_type_list = [task_type]
+ 
+         df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
+         test_dataset = Dataset.from_pandas(df)
+         # construct the dataloader
+         test_loader = torch.utils.data.DataLoader(
+             test_dataset,
+             batch_size=1,
+             collate_fn=self.data_collator,
+         )
+ 
+         predictions = []
+         for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
+             with torch.no_grad():
+                 generation_prompts = batch['generation_prompts'][0]
+                 inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True).to(self.retro_model.device)
+                 print(inputs)
+                 # not all tokenizers emit token_type_ids; drop them if present
+                 inputs.pop('token_type_ids', None)
+                 """
+                 if task_type == "retrosynthesis":
+                     outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
+                                                         do_sample=False, num_beams=10,
+                                                         eos_token_id=self.tokenizer.eos_token_id,
+                                                         early_stopping='never',
+                                                         pad_token_id=self.tokenizer.pad_token_id,
+                                                         length_penalty=0.0,
+                                                         )
+                 else:
+                     outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
+                                                           do_sample=False, num_beams=10,
+                                                           eos_token_id=self.tokenizer.eos_token_id,
+                                                           early_stopping='never',
+                                                           pad_token_id=self.tokenizer.pad_token_id,
+                                                           length_penalty=0.0,
+                                                           )
+ 
+                 original_smiles_list = self.tokenizer.batch_decode(outputs[:, len(inputs['input_ids'][0]):],
+                                                                    skip_special_tokens=True)
+                 original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
+                 # canonicalize the SMILES
+                 canonized_smiles_list = []
+                 temp = []
+                 for original_smiles in original_smiles_list:
+                     temp.append(original_smiles)
+                     try:
+                         canonized_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(original_smiles)))
+                     except:
+                         canonized_smiles_list.append("")
+                 """
+                 # NOTE: beam-search generation above is disabled for now; the list below
+                 # is a hard-coded placeholder result for the example synthesis input
+                 canonized_smiles_list = \
+                     ['N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1cc(F)c([N+](=O)[O-])cc1F', 'N#Cc1ccsc1Nc1cc(Cl)c(F)cc1[N+](=O)[O-]', 'N#Cc1cnsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1cc(F)c(F)cc1Nc1sccc1C#N', 'N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=N)[O-]', 'N#Cc1cc(C#N)c(Nc2cc(F)c(F)cc2[N+](=O)[O-])s1', 'N#Cc1ccsc1Nc1c(F)c(F)cc(F)c1[N+](=O)[O-]', 'Nc1sccc1CNc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1ccc(F)cc1[N+](=O)[O-]']
+                 predictions.append(canonized_smiles_list)
+ 
+         rank, invalid_rate = compute_rank(predictions)
+         return rank
+ 
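For reference, a sketch of how these pieces fit together once the generation call is re-enabled in app.py (assumes access to the ChemFM checkpoints and a Hugging Face token in the TOKEN environment variable):

    from utils import ReactionPredictionModel

    candidate_models = {
        "mit_synthesis": "ChemFM/uspto_mit_synthesis",
        "full_retro": "ChemFM/uspto_full_retro",
    }
    model = ReactionPredictionModel(candidate_models)
    # retrosynthesis: one product SMILES in, a dict of scored precursor candidates out
    rank = model.predict_single_smiles(
        "CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)C3=CC[C@@]21C",
        "full_retro",
    )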