Spaces:
Running
Running
import copy
import os
import random
import shutil
import time

import gradio as gr
import pandas as pd
import PIL
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole

from inference import Inference
class DrugGENConfig:
    """Baseline inference settings shared by the DrugGEN model variants.

    Every setting is a plain class attribute. Model-specific subclasses
    override individual values, and per-run values such as ``sample_num``
    and ``seed`` are expected to be overridden on an instance.
    """

    # -- Inference -----------------------------------------------------------
    submodel = 'DrugGEN'
    inference_model = "experiments/models/DrugGEN/"
    sample_num = 100
    # Same meaning as correct=True in the old config format.
    disable_correction = False

    # -- Data ----------------------------------------------------------------
    # Formerly called inf_raw_file in the old config format.
    inf_smiles = 'data/chembl_test.smi'
    train_smiles = 'data/chembl_train.smi'
    train_drug_smiles = 'data/akt1_train.smi'
    inf_batch_size = 1
    mol_data_dir = 'data'
    features = False

    # -- Model architecture --------------------------------------------------
    act = 'relu'
    max_atom = 45
    dim = 128
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.0

    # -- Reproducibility -----------------------------------------------------
    set_seed = True
    seed = 10
class DrugGENAKT1Config(DrugGENConfig):
    """Settings for the AKT1-targeted model under ``experiments/models/DrugGEN-AKT1/``."""

    inference_model = "experiments/models/DrugGEN-AKT1/"
    submodel = 'DrugGEN'
    train_drug_smiles = 'data/akt1_train.smi'
    max_atom = 45
class DrugGENCDK2Config(DrugGENConfig):
    """Settings for the CDK2-targeted model under ``experiments/models/DrugGEN-CDK2/``.

    Uses the CDK2 drug set and ``max_atom=38`` instead of the default 45.
    """

    inference_model = "experiments/models/DrugGEN-CDK2/"
    submodel = 'DrugGEN'
    train_drug_smiles = 'data/cdk2_train.smi'
    max_atom = 38
class NoTargetConfig(DrugGENConfig):
    """Settings for the general-purpose model that has no protein target.

    Without a target-specific drug set, the general ChEMBL training SMILES
    fill the ``train_drug_smiles`` slot. ``max_atom`` is inherited (45).
    """

    inference_model = "experiments/models/NoTarget/"
    submodel = "NoTarget"
    train_drug_smiles = 'data/chembl_train.smi'
# Model names offered by the UI radio button, mapped to one config instance each.
model_configs = {
    ui_name: config_cls()
    for ui_name, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}
def _resolve_seed(seed_num):
    """Return the integer random seed requested by the user.

    Args:
        seed_num: Raw seed from the UI textbox (normally a string). ``None`` or
            a blank/whitespace-only value means "pick a random seed".

    Returns:
        An int; a fresh ``random.randint(0, 10000)`` value when no seed was given.

    Raises:
        gr.Error: if a non-blank seed cannot be parsed as an integer.
    """
    if seed_num is None or str(seed_num).strip() == "":
        return random.randint(0, 10000)
    try:
        return int(seed_num)
    except (TypeError, ValueError):
        raise gr.Error("The seed must be an integer value!") from None


def _build_metric_tables(scores, runtime):
    """Build the basic and advanced one-row metric tables shown in the UI.

    Args:
        scores: DataFrame returned by ``Inference.inference()``. Only the first
            row of the columns referenced below is read.
        runtime: Wall-clock generation time in seconds.

    Returns:
        ``(basic_metrics, advanced_metrics)``. Their column names match the
        headers of the two ``gr.Dataframe`` outputs.
    """
    def first(column):
        # Read column-wise, not via scores.iloc[0], so each value keeps its
        # own column dtype instead of being upcast with the rest of the row.
        return [scores[column].iloc[0]]

    basic_metrics = pd.DataFrame({
        "Validity": first("validity"),
        "Uniqueness": first("uniqueness"),
        "Novelty (Train)": first("novelty"),
        "Novelty (Test)": first("novelty_test"),
        "Drug Novelty": first("drug_novelty"),
        "Runtime (s)": [round(runtime, 2)],
    })
    advanced_metrics = pd.DataFrame({
        "QED": first("qed"),
        "SA Score": first("sa"),
        "Internal Diversity": first("IntDiv"),
        "SNN ChEMBL": first("snn_chembl"),
        "SNN Drug": first("snn_drug"),
        "Max Length": first("max_len"),
    })
    return basic_metrics, advanced_metrics


def _render_sample(smiles_list, seed, max_shown=12):
    """Draw a reproducible sample of the generated molecules as a grid image.

    Args:
        smiles_list: Generated SMILES strings (blank entries already removed).
        seed: Seed for the sampling RNG, so the same seed shows the same sample.
        max_shown: Maximum number of molecules to draw.

    Returns:
        A PIL image with up to ``max_shown`` distinct entries of
        ``smiles_list``, skipping any that RDKit cannot parse. Returns ``None``
        when nothing parses.
    """
    rng = random.Random(seed)
    if len(smiles_list) > max_shown:
        # Sample without replacement so the grid never repeats a molecule.
        chosen = rng.sample(smiles_list, max_shown)
    else:
        chosen = smiles_list

    # Parse each SMILES once and drop the invalid ones (MolFromSmiles -> None).
    mols = [mol for mol in (Chem.MolFromSmiles(s) for s in chosen) if mol is not None]
    if not mols:
        return None

    draw_options = Draw.rdMolDraw2D.MolDrawOptions()
    draw_options.prepareMolsBeforeDrawing = False
    draw_options.bondLineWidth = 0.5
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(mols),
        returnPNG=False,
        drawOptions=draw_options,
        highlightAtomLists=None,
        highlightBondLists=None,
    )


def function(model_name: str, num_molecules: int, seed_num: str):
    """Generate molecules with the selected DrugGEN model and summarise them.

    Args:
        model_name: A key of ``model_configs`` (the UI radio choices).
        num_molecules: Number of molecules to generate; at most 250.
        seed_num: Optional seed as text. Blank or ``None`` picks a random seed.

    Returns:
        ``(image, file_path, basic_metrics, advanced_metrics)``:
        - a grid image of up to 12 sampled valid molecules (``None`` if none
          parse);
        - the path of the SMILES file holding every generated molecule;
        - two one-row metric DataFrames.

    Raises:
        gr.Error: for an unknown model, a request above 250 molecules, or a
            non-integer seed.
    """
    # Look up the config with the UI key itself. Renaming the key first (as
    # done before) made "DrugGEN-NoTarget" raise KeyError.
    try:
        base_config = model_configs[model_name]
    except KeyError:
        raise gr.Error(f"Unknown model: {model_name}") from None

    if num_molecules > 250:
        raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")

    # Per-request copy, so sample_num/seed never leak between requests through
    # the shared instances stored in model_configs.
    config = copy.copy(base_config)
    config.sample_num = num_molecules
    config.seed = _resolve_seed(seed_num)

    # The no-target model's output directory and download name use the short
    # "NoTarget" name; the others use their UI name.
    # NOTE(review): presumably this mirrors where Inference writes; confirm.
    output_name = "NoTarget" if model_name == "DrugGEN-NoTarget" else model_name

    inferer = Inference(config)
    start_time = time.perf_counter()
    scores = inferer.inference()  # DataFrame of evaluation metrics
    elapsed = time.perf_counter() - start_time

    basic_metrics, advanced_metrics = _build_metric_tables(scores, elapsed)

    output_file_path = f'experiments/inference/{output_name}/inference_drugs.txt'
    new_path = f'{output_name}_denovo_mols.smi'
    # os.replace (unlike os.rename) also overwrites an existing target on Windows.
    os.replace(output_file_path, new_path)
    with open(new_path, encoding="utf-8") as f:
        # Keep every non-blank line. The old split("\n")[:-1] dropped the last
        # molecule whenever the file lacked a trailing newline.
        generated_smiles = [line.strip() for line in f if line.strip()]

    molecule_image = _render_sample(generated_smiles, config.seed)
    return molecule_image, new_path, basic_metrics, advanced_metrics
# Styling for the box around the two metric tables. It is passed through
# gr.Blocks(css=...) and applied to a container that actually carries the id.
# Previously it was an inline <style> whose selector matched no element.
_CUSTOM_CSS = """
#metrics-container {
    border: 1px solid rgba(128, 128, 128, 0.3);
    border-radius: 8px;
    padding: 15px;
    margin-top: 15px;
    margin-bottom: 15px;
    background-color: rgba(255, 255, 255, 0.05);
}
"""

# arXiv / GitHub badge links shown under the title.
_BADGES_HTML = """
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
    <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
        <div style="display: inline-block; background-color: #b31b1b; color: white; padding: 5px 10px; border-radius: 5px; font-size: 14px;">
            <span style="font-weight: bold;">arXiv</span> 2302.07868
        </div>
    </a>
    <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
        <div style="display: inline-block; background-color: #24292e; color: white; padding: 5px 10px; border-radius: 5px; font-size: 14px;">
            <span style="font-weight: bold;">GitHub</span> Repository
        </div>
    </a>
</div>
"""

# Markdown is kept at column 0 so indentation can never turn it into code
# blocks. Blank lines separate lists from the following paragraphs.
_ABOUT_MODELS_MD = """
## Model Variations

### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749).

### DrugGEN-CDK2
This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).

### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:

- Exploring chemical space
- Generating diverse scaffolds
- Creating molecules with drug-like properties

For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
"""

_METRICS_HELP_MD = """
## Evaluation Metrics

### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate the requested molecules

### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the training set
- **Novelty (Test)**: Percentage of molecules not found in the test set
- **Drug Novelty**: Percentage of molecules not found in known inhibitors of the target protein

### Structural Metrics
- **Max Length**: Maximum component length in the generated molecules
- **Mean Atom Type**: Average distribution of atom types
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)

### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)

### Similarity Metrics
- **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
- **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
"""

with gr.Blocks(theme=gr.themes.Ocean(), css=_CUSTOM_CSS) as demo:
    with gr.Row():
        # Left column: title, links, documentation and generation controls.
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            gr.HTML(_BADGES_HTML)
            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown(_ABOUT_MODELS_MD)
            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown(_METRICS_HELP_MD)

            model_name = gr.Radio(
                choices=["DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"],
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )
            # The 250 maximum matches the server-side limit enforced in `function`.
            num_molecules = gr.Slider(
                minimum=10,
                maximum=250,
                value=100,
                step=10,
                label="Number of Molecules to Generate",
                info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, we set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
            )
            seed_num = gr.Textbox(
                label="Random Seed (Optional)",
                value="",
                info="Set a specific seed for reproducible results, or leave empty for random generation"
            )
            submit_button = gr.Button(
                value="Generate Molecules",
                variant="primary",
                size="lg"
            )

        # Right column: metric tables, sample image and download link.
        with gr.Column(scale=2):
            with gr.Column(elem_id="metrics-container"):
                basic_metrics_df = gr.Dataframe(
                    headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Drug Novelty", "Runtime (s)"],
                    elem_id="basic-metrics"
                )
                advanced_metrics_df = gr.Dataframe(
                    headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Drug", "Max Length"],
                    elem_id="advanced-metrics"
                )
            image_output = gr.Image(
                label="Sample of Generated Molecules",
                elem_id="molecule_display"
            )
            file_download = gr.File(
                label="Download All Generated Molecules (SMILES format)",
            )

    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    # Output order must match the tuple returned by `function`:
    # (image, file_path, basic_metrics, advanced_metrics).
    submit_button.click(
        function,
        inputs=[model_name, num_molecules, seed_num],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df,
        ],
        api_name="inference"
    )

demo.queue()
demo.launch()