import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections, list_models
#from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
from utils import MolecularGenerationModel
import pandas as pd
import os
import spaces
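
# Gradio Space for conditional molecular generation: the user selects target
# property values (logP, TPSA, SAS, QED) via checkboxes and sliders, and the
# model generates candidate SMILES strings together with their predicted
# properties.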

#candidate_models = get_models()
#task_names = {
#    'mit_synthesis': 'Reaction Synthesis',
#    'full_retro': 'Reaction Retro Synthesis'
#}
#task_names_to_tasks = {v: k for k, v in task_names.items()}
#tasks = list(candidate_models.keys())
#task_descriptions = {
#    'mit_synthesis': 'Predict the reaction products given the reactants and reagents. \n' + \
#                     '1. This model is trained on the USPTO MIT dataset. \n' + \
#                     '2. The reactants and reagents are mixed in the input SMILES string. \n' + \
#                     '3. Different compounds are separated by ".". \n' + \
#                     '4. Input SMILES string example: C1CCOC1.N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F.[H-].[Na+]',
#    'full_retro': 'Predict the reaction precursors given the reaction products. \n' + \
#                    '1. This model is trained on the USPTO Full dataset. \n' + \
#                    '2. In this dataset, we consider only a single product in the input SMILES string. \n' + \
#                    '3. Input SMILES string example: CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)C3=CC[C@@]21C'
#}

#property_names = list(candidate_models.keys())
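# The generation model is instantiated once at import time and reused for every request.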
model = MolecularGenerationModel()

def predict_single_label(logp, tpsa, sas, qed, logp_choose, tpsa_choose, sas_choose, qed_choose):
    """Build the property-target dict from the enabled checkboxes and run generation."""
    input_dict = dict()
    if logp_choose:
        input_dict['logP'] = logp
    if tpsa_choose:
        input_dict['TPSA'] = tpsa
    if sas_choose:
        input_dict['SAS'] = sas
    if qed_choose:
        input_dict['QED'] = qed
    
    if len(input_dict) == 0:
        return None, "No property is selected"

    print(input_dict)

    try:
        prediction = model.predict_single_smiles(input_dict)
        if prediction is None:
            return None, "Generation returned no result"
    except Exception as e:
        # no matter what the error is, return gracefully instead of crashing the UI
        print(e)
        return None, "Generation failed"

    return prediction, "Generation is done"
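
# A minimal usage sketch of the call made above (an assumption based on how the
# result is displayed in a gr.Dataframe below: predict_single_smiles is expected
# to take a dict of property targets and return a pandas DataFrame of generated
# SMILES with their predicted properties):
#
#   df = model.predict_single_smiles({'logP': 4.0, 'TPSA': 80.0, 'SAS': 3.0, 'QED': 0.5})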

"""
def get_description(task_name):
    task = task_names_to_tasks[task_name]
    return task_descriptions[task]

#@spaces.GPU(duration=10)
"""

"""
@spaces.GPU(duration=30)
def predict_file(file, property_name):
    property_id = dataset_property_names_to_dataset[property_name]
    try:
        adapter_id = candidate_models[property_id]
        info = model.swith_adapter(property_id, adapter_id)

        running_status = None
        if info == "keep":
            running_status = "Adapter is the same as the current one"
            #print("Adapter is the same as the current one")
        elif info == "switched":
            running_status = "Adapter is switched successfully"
            #print("Adapter is switched successfully")
        elif info == "error":
            running_status = "Adapter is not found"
            #print("Adapter is not found")
            return None, None, file, running_status
        else:
            running_status = "Unknown error"
            return None, None, file, running_status
    
        df = pd.read_csv(file)
        # we have already checked the file contains the "smiles" column
        df = model.predict_file(df, dataset_task_types[property_id])
        # we should save this file to the disk to be downloaded
        # rename the file to have "_prediction" suffix
        prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
        print(file, prediction_file)
        # save the file to the disk
        df.to_csv(prediction_file, index=False)
    except Exception as e:
        # no matter what the error is, we should return
        print(e)
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed"
    
    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"

def validate_file(file):
    try:
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if "smiles" not in df.columns:
                # we should clear the file input
                return "Invalid file content. The csv file must contain column named 'smiles'", \
                         None, gr.update(visible=False), gr.update(visible=False)
            
            # check the length of the smiles
            length = len(df["smiles"])

        elif file.endswith(".smi"):
            return "Invalid file extension", \
                    None, gr.update(visible=False), gr.update(visible=False)

        else:
            return "Invalid file extension", \
                    None, gr.update(visible=False), gr.update(visible=False)
    except Exception as e:
        return "Invalid file content.", \
                None, gr.update(visible=False), gr.update(visible=False)
    
    if length > 100: 
        return "The space does not support the file containing more than 100 SMILES", \
                None, gr.update(visible=False), gr.update(visible=False)

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
"""
    

def raise_error(status):
    # Only referenced by the (currently commented-out) file-upload flow above.
    if status != "Valid file":
        raise gr.Error(status)
    return None


"""
def clear_file(download_button):
    # we might need to delete the prediction file and uploaded file
    prediction_path = download_button
    print(prediction_path)
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
        original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
        if os.path.exists(original_data_file_0):
            os.remove(original_data_file_0)
        if os.path.exists(original_data_file_1):
            os.remove(original_data_file_1)
    #if os.path.exists(file):
    #    os.remove(file)
    #prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
    #if os.path.exists(prediction_file):
    #    os.remove(prediction_file)
    

    return gr.update(visible=False), gr.update(visible=False), None
"""

def toggle_slider(checked):
    """Enable or disable a single slider based on its checkbox value."""
    return gr.update(interactive=checked)

def toggle_sliders_based_on_checkboxes(checked_values):
    """Enable or disable the sliders based on the corresponding checkbox values (slider_1..slider_4 order)."""
    return [gr.update(interactive=checked) for checked in checked_values]

def build_inference():

    with gr.Blocks() as demo:
        # first row - Dropdown input
        #with gr.Row():
        #gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
        #dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
        description = f"This space allows you to generate ten possible molecules based on given conditions. \n" \
                      f"1. You can enable or disable specific properties using checkboxes and adjust their values with sliders. \n" \
                      f"2. The generated SMILES strings and their corresponding predicted properties will be displayed in the generations section. \n" \
                      f"3. The properties include logP, TPSA, SAS, and QED. \n" \
                      f"4. Model trained on the GuacaMol dataset for molecular design. "

        description_box = gr.Textbox(label="Task description", lines=5,
                                     interactive=False,
                                     value=description)
        # third row - Textbox input and prediction label
        with gr.Row(equal_height=True):
            with gr.Column():
                checkbox_1 = gr.Checkbox(label="logP", value=True)   
                slider_1 = gr.Slider(1, 7, value=4, label="logP", info="Choose between 1 and 7")
                checkbox_1.change(toggle_slider, checkbox_1, slider_1)
            with gr.Column():
                checkbox_2 = gr.Checkbox(label="TPSA", value=True)
                slider_2 = gr.Slider(20, 140, value=80, label="TPSA", info="Choose between 20 and 140")
                checkbox_2.change(toggle_slider, checkbox_2, slider_2)
            with gr.Column():
                checkbox_3 = gr.Checkbox(label="SAS", value=True)
                slider_3 = gr.Slider(1, 5, value=3, label="SAS", info="Choose between 1 and 5")
                checkbox_3.change(toggle_slider, checkbox_3, slider_3)
            with gr.Column():
                checkbox_4 = gr.Checkbox(label="QED", value=True)
                slider_4 = gr.Slider(0.1, 0.9, value=0.5, label="QED", info="Choose between 0.1 and 0.9")
                checkbox_4.change(toggle_slider, checkbox_4, slider_4)

        predict_single_smiles_button = gr.Button("Generate", size='sm')
        #prediction = gr.Label("Prediction will appear here")
        #prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)
        prediction = gr.Dataframe(label="Generations", type="pandas", interactive=False)

        running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
        

        # Generate button click: disable all controls, run generation, then re-enable
        # them (each slider follows its checkbox state).
        predict_single_smiles_button.click(lambda:(gr.update(interactive=False), 
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   ) , outputs=[slider_1, slider_2, slider_3, slider_4, 
                                                                checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                                                                predict_single_smiles_button, running_terminal_label])\
                                                   .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4, 
                                                                                        checkbox_1, checkbox_2, checkbox_3, checkbox_4
                                                                                       ], outputs=[prediction, running_terminal_label])\
                                                   .then(lambda a, b, c, d: toggle_sliders_based_on_checkboxes([a, b, c, d]) + 
                                                                            [gr.update(interactive=True)] * 6,
                                                         inputs=[checkbox_1, checkbox_2, checkbox_3, checkbox_4],
                                                         outputs=[slider_1, slider_2, slider_3, slider_4,
                                                                  checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                                                                  predict_single_smiles_button, running_terminal_label])
        
    return demo


demo = build_inference() 

if __name__ == '__main__':
    demo.launch()