feiyang-cai committed
Commit edaff0a · Parent: 5362c33

update the reaction

Files changed (3)
  1. app.py +249 -0
  2. requirements.txt +9 -0
  3. utils.py +280 -0
app.py ADDED
@@ -0,0 +1,249 @@
+ import gradio as gr
+ from huggingface_hub import HfApi, get_collection, list_collections, list_models
+ #from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
+ from utils import ReactionPredictionModel
+ import pandas as pd
+ import os
+ import spaces
+
+ def get_models():
+     # we only support two models:
+     # 1. ChemFM/uspto_mit_synthesis
+     # 2. ChemFM/uspto_full_retro
+     models = dict()
+     models['mit_synthesis'] = 'ChemFM/uspto_mit_synthesis'
+     models['full_retro'] = 'ChemFM/uspto_full_retro'
+
+     #for item in collection.items:
+     #    if item.item_type == "model":
+     #        item_name = item.item_id.split("/")[-1]
+     #        models[item_name] = item.item_id
+     #        assert item_name in dataset_task_types, f"{item_name} is not in the task_types"
+     #        assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
+
+     return models
+
+ candidate_models = get_models()
+ task_names = {
+     'mit_synthesis': 'Reaction Synthesis',
+     'full_retro': 'Reaction Retro Synthesis'
+ }
+ task_names_to_tasks = {v: k for k, v in task_names.items()}
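+ # e.g. (illustrative): task_names_to_tasks["Reaction Synthesis"] == "mit_synthesis"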
+ tasks = list(candidate_models.keys())
+ task_descriptions = {
+     'mit_synthesis': 'Predict the reaction products given the reactants and reagents (reactants and reagents are mixed; different compounds are separated by ".").' + \
+                      '\nExample: C1CCOC1.N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F.[H-].[Na+]',
+     'full_retro': 'Predict the reaction precursors given the reaction products (different compounds are separated by ".").'
+ }
+
+ #property_names = list(candidate_models.keys())
+ model = ReactionPredictionModel(candidate_models)
+ #model = MolecularPropertyPredictionModel(candidate_models)
+
+ def get_description(task_name):
+     task = task_names_to_tasks[task_name]
+     return task_descriptions[task]
+
+ #@spaces.GPU(duration=10)
+ def predict_single_label(smiles, task_name):
+     task = task_names_to_tasks[task_name]
+
+     try:
+         #prediction = model.predict(smiles, property_name, adapter_id)
+         prediction = model.predict_single_smiles(smiles, task)
+         if prediction is None:
+             return "NA", "Invalid SMILES string"
+
+     except Exception as e:
+         # no matter what the error is, we should return
+         print(e)
+         return "NA", "Prediction failed"
+
+     prediction = "\n".join([f"{idx+1}. {item}" for idx, item in enumerate(prediction)])
+     return prediction, "Prediction is done"
+
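+ # Illustrative output: if predict_single_smiles returned ["CCO", "CCN"], the
+ # Predictions textbox would show "1. CCO" and "2. CCN" on separate lines.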
+ """
+ @spaces.GPU(duration=30)
+ def predict_file(file, property_name):
+     property_id = dataset_property_names_to_dataset[property_name]
+     try:
+         adapter_id = candidate_models[property_id]
+         info = model.swith_adapter(property_id, adapter_id)
+
+         running_status = None
+         if info == "keep":
+             running_status = "Adapter is the same as the current one"
+             #print("Adapter is the same as the current one")
+         elif info == "switched":
+             running_status = "Adapter is switched successfully"
+             #print("Adapter is switched successfully")
+         elif info == "error":
+             running_status = "Adapter is not found"
+             #print("Adapter is not found")
+             return None, None, file, running_status
+         else:
+             running_status = "Unknown error"
+             return None, None, file, running_status
+
+         df = pd.read_csv(file)
+         # we have already checked that the file contains the "smiles" column
+         df = model.predict_file(df, dataset_task_types[property_id])
+         # save the predictions to disk for download,
+         # renaming the file to have a "_prediction" suffix
+         prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
+         print(file, prediction_file)
+         df.to_csv(prediction_file, index=False)
+     except Exception as e:
+         # no matter what the error is, we should return
+         print(e)
+         return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed"
+
+     return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
+
+ def validate_file(file):
+     try:
+         if file.endswith(".csv"):
+             df = pd.read_csv(file)
+             if "smiles" not in df.columns:
+                 # we should clear the file input
+                 return "Invalid file content. The csv file must contain a column named 'smiles'", \
+                        None, gr.update(visible=False), gr.update(visible=False)
+
+             # check the number of SMILES
+             length = len(df["smiles"])
+
+         elif file.endswith(".smi"):
+             return "Invalid file extension", \
+                    None, gr.update(visible=False), gr.update(visible=False)
+
+         else:
+             return "Invalid file extension", \
+                    None, gr.update(visible=False), gr.update(visible=False)
+     except Exception as e:
+         return "Invalid file content.", \
+                None, gr.update(visible=False), gr.update(visible=False)
+
+     if length > 100:
+         return "The space does not support files containing more than 100 SMILES", \
+                None, gr.update(visible=False), gr.update(visible=False)
+
+     return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
+ """
+
+
+ def raise_error(status):
+     if status != "Valid file":
+         raise gr.Error(status)
+     return None
+
+
+ """
+ def clear_file(download_button):
+     # we might need to delete the prediction file and the uploaded file
+     prediction_path = download_button
+     print(prediction_path)
+     if prediction_path and os.path.exists(prediction_path):
+         os.remove(prediction_path)
+         original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
+         original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
+         if os.path.exists(original_data_file_0):
+             os.remove(original_data_file_0)
+         if os.path.exists(original_data_file_1):
+             os.remove(original_data_file_1)
+     #if os.path.exists(file):
+     #    os.remove(file)
+     #prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
+     #if os.path.exists(prediction_file):
+     #    os.remove(prediction_file)
+
+     return gr.update(visible=False), gr.update(visible=False), None
+ """
+
+ def build_inference():
+
+     with gr.Blocks() as demo:
+         # first row - Dropdown input
+         #with gr.Row():
+         #gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
+         dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
+         description_box = gr.Textbox(label="Task description", lines=5,
+                                      interactive=False,
+                                      value=task_descriptions[tasks[0]])
+         # third row - Textbox input and prediction label
+         #with gr.Row(equal_height=True):
+         #    with gr.Column():
+         textbox = gr.Textbox(label="Reactants (Products) SMILES string", type="text", placeholder="Provide a SMILES string here",
+                              lines=1)
+         predict_single_smiles_button = gr.Button("Predict", size='sm')
+         #prediction = gr.Label("Prediction will appear here")
+         prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)
+
+         running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
+
+         #input_file = gr.File(label="Molecule file",
+         #                     file_count='single',
+         #                     file_types=[".smi", ".csv"], height=300)
+         #predict_file_button = gr.Button("Predict", size='sm', visible=False)
+         #download_button = gr.DownloadButton("Download", size='sm', visible=False)
+         #stop_button = gr.Button("Stop", size='sm', visible=False)
+
+         # dropdown change event
+         dropdown.change(get_description, inputs=dropdown, outputs=description_box)
+         # predict single button click event: lock the UI, run the prediction, then unlock
+         predict_single_smiles_button.click(lambda: (gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     gr.update(interactive=False),
+                                                     ), outputs=[dropdown, textbox, predict_single_smiles_button, running_terminal_label])\
+             .then(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])\
+             .then(lambda: (gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            ), outputs=[dropdown, textbox, predict_single_smiles_button, running_terminal_label])
+         """
+         # input file upload event
+         file_status = gr.State()
+         input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
+         # input file clear event
+         input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
+         # predict file button click event
+         predict_file_event = predict_file_button.click(lambda: (gr.update(interactive=False),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=False, visible=True),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=True, visible=False),
+                                                                 gr.update(interactive=False),
+                                                                 gr.update(interactive=False),
+                                                                 ), outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
+             .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
+             .then(lambda: (gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            gr.update(interactive=True),
+                            ), outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
+
+         # stop button click event
+         #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
+         """
+
+     return demo
+
+
+ demo = build_inference()
+
+ if __name__ == '__main__':
+     demo.launch()
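
A minimal sketch of exercising the single-prediction callback outside the UI (the SMILES is illustrative; it assumes the checkpoints load and that a valid TOKEN environment variable is set, since utils.py reads it):

    prediction, status = predict_single_label(
        "C1CCOC1.N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F.[H-].[Na+]",
        "Reaction Synthesis",
    )
    print(status)      # "Prediction is done" on success
    print(prediction)  # numbered candidate products, one per line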
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers
+ torch
+ huggingface_hub
+ pandas
+ peft
+ tqdm
+ datasets
+ rdkit
+ scikit-learn
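
Note: gradio and spaces are imported by app.py and utils.py but are not pinned here; the assumption is that both are provided by the Hugging Face Gradio Space runtime, so a local run would need to install them separately.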
utils.py ADDED
@@ -0,0 +1,280 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+ import os
+ from typing import Optional, Dict, Sequence
+ import transformers
+ from peft import PeftModel
+ import torch
+ from dataclasses import dataclass, field
+ from huggingface_hub import hf_hub_download
+ import json
+ import pandas as pd
+ from datasets import Dataset
+ from tqdm import tqdm
+ import spaces
+
+ from rdkit import RDLogger, Chem
+ # Suppress RDKit INFO messages
+ RDLogger.DisableLog('rdApp.*')
+
+ DEFAULT_PAD_TOKEN = "[PAD]"
+ device_map = "cpu"
+
+ def compute_rank(prediction, raw=False, alpha=1.0):
+     valid_score = [[k for k in range(len(prediction[j]))] for j in range(len(prediction))]
+     invalid_rates = [0 for k in range(len(prediction[0]))]
+     rank = {}
+     highest = {}
+
+     for j in range(len(prediction)):
+         for k in range(len(prediction[j])):
+             if prediction[j][k] == "":
+                 valid_score[j][k] = 10 + 1
+                 invalid_rates[k] += 1
+         de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0] != ""]
+         prediction[j] = list(set(de_error))
+         prediction[j].sort(key=de_error.index)
+         for k, data in enumerate(prediction[j]):
+             if data in rank:
+                 rank[data] += 1 / (alpha * k + 1)
+             else:
+                 rank[data] = 1 / (alpha * k + 1)
+             if data in highest:
+                 highest[data] = min(k, highest[data])
+             else:
+                 highest[data] = k
+     return rank, invalid_rates
+
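+ # Worked example (illustrative): for prediction = [["CCO", "CCN", "CCO", ""]],
+ # the empty candidate is penalized and counted in invalid_rates, duplicates are
+ # merged keeping first-occurrence order, and with alpha=1.0 the result is
+ # rank == {"CCO": 1.0, "CCN": 0.5} and invalid_rates == [0, 0, 0, 1].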
+
+
+ @dataclass
+ class DataCollatorForCausalLMEval(object):
+     tokenizer: transformers.PreTrainedTokenizer
+     source_max_len: int
+     target_max_len: int
+     reactant_start_str: str
+     product_start_str: str
+     end_str: str
+
+     def augment_molecule(self, molecule: str) -> str:
+         return self.sme.augment([molecule])[0]
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+
+         srcs = instances[0]['src']
+         task_type = instances[0]['task_type']  # each instance stores the task type as a plain string
+
+         if task_type == 'retrosynthesis':
+             src_start_str = self.product_start_str
+             tgt_start_str = self.reactant_start_str
+         else:
+             src_start_str = self.reactant_start_str
+             tgt_start_str = self.product_start_str
+
+         generation_prompts = []
+         generation_prompt = f"{src_start_str}{srcs}{self.end_str}{tgt_start_str}"
+         generation_prompts.append(generation_prompt)
+
+         data_dict = {
+             'generation_prompts': generation_prompts
+         }
+
+         return data_dict
+
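+ # Illustrative prompt layout (the real start/end strings come from the model
+ # repo's string_template.json): for a synthesis instance with src = "CCO.CC(=O)O",
+ # the collator yields {"generation_prompts": ["<REACTANTS>CCO.CC(=O)O<END><PRODUCTS>"]},
+ # where <REACTANTS>, <END>, and <PRODUCTS> stand in for the template strings.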
+ def smart_tokenizer_and_embedding_resize(
+     special_tokens_dict: Dict,
+     tokenizer: transformers.PreTrainedTokenizer,
+     model: transformers.PreTrainedModel,
+     non_special_tokens=None,
+ ):
+     """Resize tokenizer and embedding.
+
+     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+     """
+     tokenizer.add_special_tokens(special_tokens_dict)
+     tokenizer.add_tokens(non_special_tokens)
+     num_old_tokens = model.get_input_embeddings().weight.shape[0]
+     num_new_tokens = len(tokenizer) - num_old_tokens
+     if num_new_tokens == 0:
+         return
+
+     model.resize_token_embeddings(len(tokenizer))
+
+     if num_new_tokens > 0:
+         input_embeddings_data = model.get_input_embeddings().weight.data
+         # initialize the new embedding rows to the mean of the existing rows
+         input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
+         input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
+     print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
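+ # e.g. (illustrative): adding [PAD] as a new pad token grows the vocabulary by
+ # one, and the single new embedding row is initialized to the mean embedding.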
+ class ReactionPredictionModel():
+     def __init__(self, candidate_models):
+
+         # both branches load the tokenizer from the first candidate model repo
+         for model in candidate_models:
+             if "retro" in model:
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     candidate_models[list(candidate_models.keys())[0]],
+                     padding_side="right",
+                     use_fast=True,
+                     trust_remote_code=True,
+                     token=os.environ.get("TOKEN")
+                 )
+                 self.load_retro_model(candidate_models[model])
+             else:
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     candidate_models[list(candidate_models.keys())[0]],
+                     padding_side="right",
+                     use_fast=True,
+                     trust_remote_code=True,
+                     token=os.environ.get("TOKEN")
+                 )
+                 self.load_forward_model(candidate_models[model])
+
+         string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token=os.environ.get("TOKEN"))
+         string_template = json.load(open(string_template_path, 'r'))
+         reactant_start_str = string_template['REACTANTS_START_STRING']
+         product_start_str = string_template['PRODUCTS_START_STRING']
+         end_str = string_template['END_STRING']
+         self.data_collator = DataCollatorForCausalLMEval(
+             tokenizer=self.tokenizer,
+             source_max_len=512,
+             target_max_len=512,
+             reactant_start_str=reactant_start_str,
+             product_start_str=product_start_str,
+             end_str=end_str,
+         )
+
+     def load_retro_model(self, model_path):
+         # our retro model is a LoRA adapter on top of the ChemFM-3B base model
+         config = AutoConfig.from_pretrained(
+             "ChemFM/ChemFM-3B",
+             trust_remote_code=True,
+             token=os.environ.get("TOKEN")
+         )
+
+         base_model = AutoModelForCausalLM.from_pretrained(
+             "ChemFM/ChemFM-3B",
+             config=config,
+             trust_remote_code=True,
+             device_map=device_map,
+             token=os.environ.get("TOKEN")
+         )
+
+         # we should resize the embedding layer of the base model to match the adapter's tokenizer
+         special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=special_tokens_dict,
+             tokenizer=self.tokenizer,
+             model=base_model
+         )
+         base_model.config.pad_token_id = self.tokenizer.pad_token_id
+
+         # load the adapter model
+         self.retro_model = PeftModel.from_pretrained(
+             base_model,
+             model_path,
+             token=os.environ.get("TOKEN")
+         )
+
+         self.retro_model.to("cuda")
+
+     def load_forward_model(self, model_path):
+         config = AutoConfig.from_pretrained(
+             model_path,
+             device_map=device_map,
+             trust_remote_code=True,
+             token=os.environ.get("TOKEN")
+         )
+
+         self.forward_model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             config=config,
+             device_map=device_map,
+             trust_remote_code=True,
+             token=os.environ.get("TOKEN")
+         )
+
+         # the fine-tuned tokenizer can differ in size from the pretraining tokenizer, and we also need to add the PAD token
+         special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=special_tokens_dict,
+             tokenizer=self.tokenizer,
+             model=self.forward_model
+         )
+         self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
+         self.forward_model.to("cuda")
+
+     @spaces.GPU(duration=20)
+     def predict_single_smiles(self, smiles, task_type):
+         if task_type == "full_retro":
+             if "." in smiles:
+                 return None
+
+         task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"
+         # canonicalize the input SMILES
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is None:
+             return None
+         smiles = Chem.MolToSmiles(mol)
+
+         smiles_list = [smiles]
+         task_type_list = [task_type]
+
+         df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
+         test_dataset = Dataset.from_pandas(df)
+         # construct the dataloader
+         test_loader = torch.utils.data.DataLoader(
+             test_dataset,
+             batch_size=1,
+             collate_fn=self.data_collator,
+         )
+
+         predictions = []
+         for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
+             with torch.no_grad():
+                 generation_prompts = batch['generation_prompts'][0]
+                 inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True).to(self.retro_model.device)
+                 inputs.pop('token_type_ids', None)  # generate() does not accept token_type_ids
+                 if task_type == "retrosynthesis":
+                     outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
+                                                         do_sample=False, num_beams=10,
+                                                         eos_token_id=self.tokenizer.eos_token_id,
+                                                         early_stopping='never',
+                                                         pad_token_id=self.tokenizer.pad_token_id,
+                                                         length_penalty=0.0,
+                                                         )
+                 else:
+                     outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
+                                                           do_sample=False, num_beams=10,
+                                                           eos_token_id=self.tokenizer.eos_token_id,
+                                                           early_stopping='never',
+                                                           pad_token_id=self.tokenizer.pad_token_id,
+                                                           length_penalty=0.0,
+                                                           )
+
+                 original_smiles_list = self.tokenizer.batch_decode(outputs[:, len(inputs['input_ids'][0]):],
+                                                                    skip_special_tokens=True)
+                 original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
+                 # canonicalize the generated SMILES; invalid strings become ""
+                 canonized_smiles_list = []
+                 for original_smiles in original_smiles_list:
+                     try:
+                         canonized_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(original_smiles)))
+                     except Exception:
+                         canonized_smiles_list.append("")
+                 #canonized_smiles_list = \
+                 #['N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1cc(F)c([N+](=O)[O-])cc1F', 'N#Cc1ccsc1Nc1cc(Cl)c(F)cc1[N+](=O)[O-]', 'N#Cc1cnsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1cc(F)c(F)cc1Nc1sccc1C#N', 'N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=N)[O-]', 'N#Cc1cc(C#N)c(Nc2cc(F)c(F)cc2[N+](=O)[O-])s1', 'N#Cc1ccsc1Nc1c(F)c(F)cc(F)c1[N+](=O)[O-]', 'Nc1sccc1CNc1cc(F)c(F)cc1[N+](=O)[O-]', 'N#Cc1ccsc1Nc1ccc(F)cc1[N+](=O)[O-]']
+                 predictions.append(canonized_smiles_list)
+
+         rank, invalid_rate = compute_rank(predictions)
+         return rank
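
For reference, a minimal driver for this module (a sketch only; it assumes a valid TOKEN environment variable and access to the two ChemFM checkpoints listed in app.py):

    from utils import ReactionPredictionModel

    models = {"mit_synthesis": "ChemFM/uspto_mit_synthesis",
              "full_retro": "ChemFM/uspto_full_retro"}
    model = ReactionPredictionModel(models)
    # retrosynthesis on a single product; returns candidates keyed by ranking score
    rank = model.predict_single_smiles("N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]", "full_retro")
    print(list(rank))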