File size: 15,935 Bytes
4c9e6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39488b0
 
 
4c9e6d9
 
39488b0
 
 
 
 
4c9e6d9
 
39488b0
 
 
 
 
 
 
4c9e6d9
 
39488b0
 
4c9e6d9
39488b0
3cce1b1
4c9e6d9
 
39488b0
 
 
 
4c9e6d9
 
 
39488b0
 
 
 
4c9e6d9
 
 
39488b0
 
4c9e6d9
 
 
 
 
 
 
 
 
39488b0
 
 
1c867f8
39488b0
 
 
4c9e6d9
1c867f8
39488b0
 
 
 
 
 
078fcbb
 
39488b0
 
 
 
 
 
 
 
 
 
 
1c867f8
 
 
 
 
 
 
969a6ef
39488b0
 
 
 
 
75a75eb
4c9e6d9
 
e3d7930
4c9e6d9
 
25a5f8a
 
 
 
 
39488b0
 
25a5f8a
 
 
 
 
 
 
 
 
39488b0
 
25a5f8a
e95a2e3
39488b0
 
 
4c9e6d9
 
 
 
 
 
 
bb27c74
4c9e6d9
bb27c74
 
4c9e6d9
39488b0
 
 
4c9e6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3d7930
4c9e6d9
 
1d105c9
5ba6435
21ae81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b92a23b
21ae81c
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba6435
21ae81c
 
 
 
 
 
 
 
 
 
 
 
 
b92a23b
 
21ae81c
 
 
540e177
 
21ae81c
 
8dde4ec
540e177
21ae81c
 
540e177
 
 
2fb78bb
8dde4ec
21ae81c
 
 
 
 
 
 
 
 
 
 
 
8dde4ec
 
 
 
 
 
21ae81c
 
2fb78bb
540e177
21ae81c
 
 
 
 
 
 
540e177
 
21ae81c
 
23efae1
 
 
 
 
 
b92a23b
21ae81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625ae1f
b83af40
21ae81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b92a23b
21ae81c
39488b0
21ae81c
 
25a5f8a
5ba6435
21ae81c
e95a2e3
39488b0
 
 
 
e95a2e3
 
 
e3d7930
 
39488b0
 
 
 
 
 
ecaee5b
39488b0
ecaee5b
 
 
 
39488b0
 
969a6ef
39488b0
4c9e6d9
39488b0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import gradio as gr
from inference import Inference
import PIL
from PIL import Image
import pandas as pd
import random
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import shutil
import os
import time

class DrugGENConfig:
    """Default settings passed to ``Inference`` by this demo.

    Every setting is a plain class attribute. Subclasses override only the
    values that differ, and instances start out with no per-instance state.
    """

    # --- Inference ---
    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN/"
    sample_num: int = 100  # how many molecules to generate per run

    # --- Data ---
    inf_smiles: str = '/home/user/app/data/chembl_test.smi'
    train_smiles: str = '/home/user/app/data/chembl_train.smi'
    inf_batch_size: int = 1
    mol_data_dir: str = '/home/user/app/data'
    features: bool = False

    # --- Model architecture ---
    act: str = 'relu'
    max_atom: int = 45  # maximum number of atoms per molecule
    dim: int = 128
    depth: int = 1
    heads: int = 8
    mlp_ratio: int = 3
    dropout: float = 0.0

    # --- Seeding ---
    set_seed: bool = True
    seed: int = 10

    disable_correction: bool = False


class DrugGENAKT1Config(DrugGENConfig):
    """Settings for the DrugGEN checkpoint targeting AKT1."""

    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-akt1/"
    # Presumably the known AKT1 inhibitors behind the "Real Inhibitors"
    # metrics — confirm in Inference.
    train_drug_smiles: str = '/home/user/app/data/akt_train.smi'
    max_atom: int = 45


class DrugGENCDK2Config(DrugGENConfig):
    """Settings for the DrugGEN checkpoint targeting CDK2."""

    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-cdk2/"
    # Presumably the known CDK2 inhibitors behind the "Real Inhibitors"
    # metrics — confirm in Inference.
    train_drug_smiles: str = '/home/user/app/data/cdk2_train.smi'
    max_atom: int = 38  # smaller atom cap than the AKT1 / NoTarget models


class NoTargetConfig(DrugGENConfig):
    """Settings for the untargeted ("NoTarget") model; other values are inherited."""

    # NOTE(review): unlike the target-specific configs, no train_drug_smiles
    # is defined here — confirm Inference does not need it for NoTarget.
    submodel: str = "NoTarget"
    inference_model: str = "/home/user/app/experiments/models/NoTarget/"


# One config instance per model choice shown in the UI radio button, keyed by
# the label the UI passes to run_inference. These instances are module-level
# and therefore shared by all requests.
model_configs = {
    label: config_cls()
    for label, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}


def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
    """
    Depending on the selected mode, either generate new molecules or evaluate provided SMILES.

    Args:
        mode: "Custom Input SMILES" feeds ``custom_smiles`` to the model as its
            inference input; any other value runs classical generation.
        model_name: Key into ``model_configs`` (e.g. "DrugGEN-AKT1").
        num_molecules: Number of molecules to generate in classical mode
            (1 to 200).
        seed_num: Optional integer seed, as text, for classical mode. ``None``
            or blank picks a random seed.
        custom_smiles: Newline-separated SMILES for custom mode (at most 100).

    Returns:
        image, file_path, basic_metrics, advanced_metrics:
        - image: a grid of up to nine distinct, randomly chosen valid molecules
          from the output file, or ``None`` when there are none.
        - file_path: the SMILES file holding every generated molecule.
        - basic_metrics, advanced_metrics: one-row DataFrames built from the
          scores returned by ``Inference.inference()``.

    Raises:
        gr.Error: if the custom input is empty or has more than 100 lines, if
            the requested molecule count is out of range, or if the seed is not
            an integer.
    """
    import copy

    # Work on a per-request copy. ``model_configs`` holds shared module-level
    # instances, so mutating them directly would leak per-request settings
    # (notably a custom ``inf_smiles`` file) into every later request.
    config = copy.copy(model_configs[model_name])
    temp_input_file = None

    if mode == "Custom Input SMILES":
        smiles_list = [s.strip() for s in (custom_smiles or "").splitlines() if s.strip()]
        if not smiles_list:
            raise gr.Error("Please provide at least one SMILES string.")
        if len(smiles_list) > 100:
            raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")
        # Always use a random seed for custom mode.
        config.seed = random.randint(0, 10000)
        # Write the custom SMILES to an input file and point this request's config at it.
        # NOTE(review): the file name depends only on the seed. If Inference
        # caches processed data keyed by this name, a repeated seed could reuse
        # stale data — confirm.
        temp_input_file = f"custom_input{config.seed}.smi"
        with open(temp_input_file, "w") as f:
            f.writelines(s + "\n" for s in smiles_list)
        config.inf_smiles = temp_input_file
        config.sample_num = len(smiles_list)
    else:
        # Classical Generation mode
        config.sample_num = int(num_molecules)
        if config.sample_num < 1:
            raise gr.Error("Please request at least one molecule.")
        if config.sample_num > 200:
            raise gr.Error("You have requested to generate more than the allowed limit of 200 molecules. Please reduce your request to 200 or fewer.")
        if seed_num is None or not seed_num.strip():
            config.seed = random.randint(0, 10000)
        else:
            try:
                config.seed = int(seed_num)
            except ValueError:
                raise gr.Error("The seed must be an integer value!") from None

    # The output directory is named after the submodel: "DrugGEN" for both
    # target models, "NoTarget" for the untargeted one.
    target_model_name = config.submodel

    try:
        inferer = Inference(config)
        start_time = time.perf_counter()
        scores = inferer.inference()  # DataFrame of scores; its first row is read below
        elapsed = time.perf_counter() - start_time
    finally:
        # The custom input file is only needed for this inference run.
        if temp_input_file is not None and os.path.exists(temp_input_file):
            os.remove(temp_input_file)

    # Create basic metrics dataframe
    basic_metrics = pd.DataFrame({
        "Validity": [scores["validity"].iloc[0]],
        "Uniqueness": [scores["uniqueness"].iloc[0]],
        "Novelty (Train)": [scores["novelty"].iloc[0]],
        "Novelty (Inference)": [scores["novelty_test"].iloc[0]],
        "Novelty (Real Inhibitors)": [scores["drug_novelty"].iloc[0]],
        "Runtime (s)": [round(elapsed, 2)]
    })

    # Create advanced metrics dataframe
    advanced_metrics = pd.DataFrame({
        "QED": [scores["qed"].iloc[0]],
        "SA Score": [scores["sa"].iloc[0]],
        "Internal Diversity": [scores["IntDiv"].iloc[0]],
        "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
        "SNN Real Inhibitors": [scores["snn_drug"].iloc[0]],
        "Average Length": [scores["max_len"].iloc[0]]
    })

    # Move the inference output next to the app so it can be offered for
    # download. shutil.move also works across filesystems, unlike os.rename.
    output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
    new_path = f'{target_model_name}_denovo_mols.smi'
    shutil.move(output_file_path, new_path)

    # One SMILES per line; skip blank lines (including a missing/extra trailing newline).
    with open(new_path) as f:
        generated_molecule_list = [line.strip() for line in f if line.strip()]

    # Parse each SMILES once and keep only the valid molecules, so the grid
    # shows up to nine molecules whenever that many valid ones exist.
    valid_molecules = [mol for mol in map(Chem.MolFromSmiles, generated_molecule_list) if mol is not None]

    # Randomly select up to 9 distinct molecules (sampling without replacement).
    rng = random.Random(config.seed)
    if len(valid_molecules) > 9:
        selected_molecules = rng.sample(valid_molecules, k=9)
    else:
        selected_molecules = valid_molecules

    if not selected_molecules:
        # Nothing drawable; the gr.Image output shows an empty image.
        return None, new_path, basic_metrics, advanced_metrics

    draw_options = Draw.rdMolDraw2D.MolDrawOptions()
    draw_options.prepareMolsBeforeDrawing = False
    draw_options.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        returnPNG=False,
        drawOptions=draw_options,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, new_path, basic_metrics, advanced_metrics


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Layout: the left column holds the description, the model picker and the
    # two input tabs; the right column shows the metrics tables, the download
    # link and the molecule grid.

    # Add custom CSS for styling
    gr.HTML("""
    <style>
    #metrics-container {
        border: 1px solid rgba(128, 128, 128, 0.3);
        border-radius: 8px;
        padding: 15px;
        margin-top: 15px;
        margin-bottom: 15px;
        background-color: rgba(255, 255, 255, 0.05);
    }
    </style>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")

            gr.HTML("""
            <div style="display: flex; gap: 10px; margin-bottom: 15px;">
                <!-- arXiv badge -->
                <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
                    <div style="
                        display: inline-block; 
                        background-color: #b31b1b; 
                        color: #ffffff !important; 
                        padding: 5px 10px; 
                        border-radius: 5px; 
                        font-size: 14px;">
                        <span style="font-weight: bold;">arXiv</span> 2302.07868
                    </div>
                </a>
                
                <!-- GitHub badge -->
                <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
                    <div style="
                        display: inline-block; 
                        background-color: #24292e; 
                        color: #ffffff !important; 
                        padding: 5px 10px; 
                        border-radius: 5px; 
                        font-size: 14px;">
                        <span style="font-weight: bold;">GitHub</span> Repository
                    </div>
                </a>
            </div>
            """)

            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
                    ### DrugGEN-AKT1
                    This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749). Trained with [2,607 bioactive compounds](https://drive.google.com/file/d/1B2OOim5wrUJalixeBTDKXLHY8BAIvNh-/view?usp=drive_link).
                    Molecules larger than 45 heavy atoms were excluded.
                    
                    ### DrugGEN-CDK2
                    This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941). Trained with [1,817 bioactive compounds](https://drive.google.com/file/d/1C0CGFKx0I2gdSfbIEgUO7q3K2S1P9ksT/view?usp=drive_link).
                    Molecules larger than 38 heavy atoms were excluded.
                    
                    ### DrugGEN-NoTarget
                    This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. Trained with a general [ChEMBL dataset](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link).
                    Molecules larger than 45 heavy atoms were excluded.

                    - Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.
                    
                    
                    For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
                """)

            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
                    ### Basic Metrics
                    - **Validity**: Percentage of generated molecules that are chemically valid
                    - **Uniqueness**: Percentage of unique molecules among valid ones
                    - **Runtime**: Time taken to generate or evaluate the molecules
                    
                    ### Novelty Metrics
                    - **Novelty (Train)**: Percentage of molecules not found in the [training set](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link). These molecules are used as inputs to
                    the generator during training.
                    - **Novelty (Inference)**: Percentage of molecules not found in the [inference set](https://drive.google.com/file/d/1vMGXqK1SQXB3Od3l80gMWvTEOjJ5MFXP/view?usp=share_link). These molecules are used as inputs
                    to the generator during inference.
                    - **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein (look at About DrugGEN Models for details). These molecules are used as inputs to the
                    discriminator during training.
                    
                    ### Structural Metrics
                    - **Average Length**: Normalized average number of atoms in the generated molecules, normalized by the maximum number of atoms (e.g., 45 for AKT1/NoTarget, 38 for CDK2)
                    - **Mean Atom Type**: Average number of distinct atom types in the generated molecules
                    - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
                    
                    ### Drug-likeness Metrics
                    - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
                    - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)
                    
                    ### Similarity Metrics
                    - **SNN ChEMBL**: Similarity to [ChEMBL molecules](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link) (higher means more similar to known drug-like compounds)
                    - **SNN Real Inhibitors**: Similarity to the real inhibitors of the selected target (higher means more similar to the real inhibitors)
                """)

            # Choices must match the keys of model_configs.
            model_name = gr.Radio(
                choices=["DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"],
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )

            with gr.Tabs():
                with gr.TabItem("Classical Generation"):
                    num_molecules = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Number of Molecules to Generate",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 200-molecule cap."
                    )

                    seed_num = gr.Textbox(
                        label="Random Seed (Optional)",
                        value="",
                        info="Set a specific seed for reproducible results, or leave empty for random generation"
                    )

                    classical_submit = gr.Button(
                        value="Generate Molecules",
                        variant="primary",
                        size="lg"
                    )

                with gr.TabItem("Custom Input SMILES"):
                    custom_smiles = gr.Textbox(
                        label="Input SMILES (one per line, maximum 100 molecules)",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 100-molecule cap.\n\n Molecules larger than allowed maximum length (45 for AKT1/NoTarget and 38 for CDK2) and allowed atom types are going to be filtered.\n\n Novelty (Inference) metric is going to be calculated using these input smiles.",
                        placeholder="Nc1ccccc1-c1nc(N)c2ccccc2n1\nO=C(O)c1ccccc1C(=O)c1cccc(Cl)c1\n...",
                        lines=10
                    )

                    custom_submit = gr.Button(
                        value="Generate Molecules using Custom SMILES",
                        variant="primary",
                        size="lg"
                    )

        with gr.Column(scale=2):
            # Headers match the column names of the DataFrames returned by
            # run_inference, so the empty table and the filled one agree.
            basic_metrics_df = gr.Dataframe(
                headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
                elem_id="basic-metrics"
            )

            advanced_metrics_df = gr.Dataframe(
                headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Real Inhibitors", "Average Length"],
                elem_id="advanced-metrics"
            )

            file_download = gr.File(label="Download All Generated Molecules (SMILES format)")

            image_output = gr.Image(
                label="Structures of Randomly Selected Generated Molecules",
                elem_id="molecule_display"
            )

    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    # Set up the click actions for each tab. run_inference treats any mode
    # other than "Custom Input SMILES" as classical generation; the gr.State
    # values fill the arguments the other tab does not use.
    classical_submit.click(
        run_inference,
        inputs=[gr.State("Generate Molecules"), model_name, num_molecules, seed_num, gr.State("")],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference_classical"
    )

    custom_submit.click(
        run_inference,
        inputs=[gr.State("Custom Input SMILES"), model_name, gr.State(0), gr.State(""), custom_smiles],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference_custom"
    )

# Turn on Gradio's request queue for the Blocks app, then start serving it.
# Blocks.queue() returns the same Blocks object, so the calls can be chained.
demo.queue().launch()