Spaces:
Running
Running
import gradio as gr | |
from inference import Inference | |
import PIL | |
from PIL import Image | |
import pandas as pd | |
import random | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from rdkit.Chem.Draw import IPythonConsole | |
import shutil | |
import os | |
import time | |
class DrugGENConfig:
    """Default settings for an inference run.

    Every option is a plain class attribute, so a subclass only has to
    redefine the values that differ from these defaults.
    """

    # --- Inference ---------------------------------------------------------
    submodel: str = "DrugGEN"
    inference_model: str = "/home/user/app/experiments/models/DrugGEN/"
    sample_num: int = 100

    # --- Data --------------------------------------------------------------
    mol_data_dir: str = "/home/user/app/data"
    inf_smiles: str = "/home/user/app/data/chembl_test.smi"
    train_smiles: str = "/home/user/app/data/chembl_train.smi"
    inf_batch_size: int = 1
    features: bool = False

    # --- Model -------------------------------------------------------------
    # NOTE(review): presumably network hyper-parameters consumed by
    # `Inference`; their exact meaning is not visible in this file.
    act: str = "relu"
    max_atom: int = 45
    dim: int = 128
    depth: int = 1
    heads: int = 8
    mlp_ratio: int = 3
    dropout: float = 0.0

    # --- Seed --------------------------------------------------------------
    set_seed: bool = True
    seed: int = 10
    disable_correction: bool = False
class DrugGENAKT1Config(DrugGENConfig):
    """Settings for the "DrugGEN-AKT1" model choice.

    Points at the ``DrugGEN-akt1`` checkpoint directory and the AKT training
    SMILES file; everything else is inherited from ``DrugGENConfig``.
    """

    submodel: str = "DrugGEN"
    max_atom: int = 45
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-akt1/"
    train_drug_smiles: str = "/home/user/app/data/akt_train.smi"
class DrugGENCDK2Config(DrugGENConfig):
    """Settings for the "DrugGEN-CDK2" model choice.

    Points at the ``DrugGEN-cdk2`` checkpoint directory and the CDK2 training
    SMILES file, and lowers ``max_atom`` to 38; everything else is inherited
    from ``DrugGENConfig``.
    """

    submodel = "DrugGEN"
    inference_model = "/home/user/app/experiments/models/DrugGEN-cdk2/"
    # The original path contained a stray double slash ("app//data"); it is
    # normalised here so it matches the other data paths in this file.
    train_drug_smiles = "/home/user/app/data/cdk2_train.smi"
    max_atom = 38
class NoTargetConfig(DrugGENConfig):
    """Settings for the "DrugGEN-NoTarget" model choice.

    Only the sub-model name and checkpoint directory differ from
    ``DrugGENConfig``.
    """

    inference_model: str = "/home/user/app/experiments/models/NoTarget/"
    submodel: str = "NoTarget"
# Maps each model label offered in the UI to a configuration instance.
# The instances are created once, when this module is loaded.
model_configs = dict(
    (label, config_class())
    for label, config_class in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
)
def _resolve_seed(seed_num):
    """Return an integer seed parsed from ``seed_num``, or a random one if blank."""
    seed_text = "" if seed_num is None else str(seed_num).strip()
    if not seed_text:
        return random.randint(0, 10000)
    try:
        return int(seed_text)
    except ValueError:
        raise gr.Error("The seed must be an integer value!") from None


def _read_smiles(path):
    """Return the non-blank lines of the text file at ``path``."""
    with open(path, encoding="utf-8") as handle:
        return [line.strip() for line in handle if line.strip()]


def function(model_name: str, num_molecules: int, seed_num: int):
    """Generate molecules with the selected model and summarise the run.

    Args:
        model_name: A key of ``model_configs`` (the labels of the UI radio
            group), e.g. ``"DrugGEN-AKT1"``.
        num_molecules: Number of molecules to generate; more than 250 is
            rejected.
        seed_num: Seed for the run. The UI passes the Textbox content, so a
            string is the usual input despite the annotation; integers are
            accepted as well. ``None`` or a blank string selects a random
            seed between 0 and 10000.

    Returns:
        A 4-tuple ``(molecule_image, file_path, basic_metrics,
        advanced_metrics)``:

        * ``molecule_image`` - grid image of at most 12 of the generated
          molecules, or ``None`` when none of them could be parsed.
        * ``file_path`` - path of the ``.smi`` file with all generated
          molecules.
        * ``basic_metrics`` - one-row DataFrame (validity, uniqueness,
          novelty values, runtime).
        * ``advanced_metrics`` - one-row DataFrame (QED, SA score, internal
          diversity, SNN values, max length).

    Raises:
        gr.Error: For an unknown model name, a request for more than 250
            molecules, or a seed that is not an integer.
    """
    max_displayed = 12

    # Look the configuration up with the UI label. The original code shortened
    # "DrugGEN-NoTarget" to "NoTarget" first, which is not a key of
    # ``model_configs`` and therefore raised KeyError for that model.
    try:
        config = model_configs[model_name]
    except KeyError:
        raise gr.Error(f"Unknown model: {model_name!r}") from None

    # Validate everything before touching the shared configuration object.
    if num_molecules > 250:
        raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
    seed = _resolve_seed(seed_num)

    config.sample_num = num_molecules
    config.seed = seed

    # Name of the directory under experiments/inference/ that the result file
    # is read from (same mapping as the original code).
    output_dir = "NoTarget" if model_name == "DrugGEN-NoTarget" else "DrugGEN"

    inferer = Inference(config)
    start_time = time.time()
    scores = inferer.inference()  # DataFrame; the first row holds the metrics
    elapsed = time.time() - start_time

    def first(column):
        return scores[column].iloc[0]

    basic_metrics = pd.DataFrame({
        "Validity": [first("validity")],
        "Uniqueness": [first("uniqueness")],
        "Novelty (Train)": [first("novelty")],
        "Novelty (Test)": [first("novelty_test")],
        "Drug Novelty": [first("drug_novelty")],
        "Runtime (s)": [round(elapsed, 2)],
    })

    advanced_metrics = pd.DataFrame({
        "QED": [first("qed")],
        "SA Score": [first("sa")],
        "Internal Diversity": [first("IntDiv")],
        "SNN ChEMBL": [first("snn_chembl")],
        "SNN Drug": [first("snn_drug")],
        "Max Length": [first("max_len")],
    })

    output_file_path = f'/home/user/app/experiments/inference/{output_dir}/inference_drugs.txt'
    new_path = f'{output_dir}_denovo_mols.smi'
    # shutil.move also works when source and destination are on different
    # file systems, where a bare os.rename fails.
    shutil.move(output_file_path, new_path)

    # Reading line by line keeps the last molecule even when the file does not
    # end with a newline (the original dropped the final split element).
    generated_molecule_list = _read_smiles(new_path)

    # Sample without replacement so the preview never repeats a molecule, and
    # base the decision on how many molecules were actually read.
    rng = random.Random(config.seed)
    if len(generated_molecule_list) > max_displayed:
        selected_smiles = rng.sample(generated_molecule_list, k=max_displayed)
    else:
        selected_smiles = generated_molecule_list

    # Parse each SMILES once and drop the ones RDKit cannot read.
    selected_molecules = [
        mol for mol in map(Chem.MolFromSmiles, selected_smiles) if mol is not None
    ]

    if not selected_molecules:
        # Nothing drawable: skip the grid instead of asking RDKit to render
        # an empty one.
        return None, new_path, basic_metrics, advanced_metrics

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        returnPNG=False,
        drawOptions=drawOptions,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, new_path, basic_metrics, advanced_metrics
# Layout of the web UI: inputs and documentation in the left column, metrics,
# download and molecule preview in the right column.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation - confirm it against the deployed layout.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Add custom CSS for styling
    gr.HTML("""
<style>
#metrics-container {
border: 1px solid rgba(128, 128, 128, 0.3);
border-radius: 8px;
padding: 15px;
margin-top: 15px;
margin-bottom: 15px;
background-color: rgba(255, 255, 255, 0.05);
}
</style>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            gr.HTML("""
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
<!-- arXiv badge -->
<a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
<div style="
display: inline-block;
background-color: #b31b1b;
color: #ffffff !important; /* Force white text */
padding: 5px 10px;
border-radius: 5px;
font-size: 14px;"
>
<span style="font-weight: bold;">arXiv</span> 2302.07868
</div>
</a>
<!-- GitHub badge -->
<a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
<div style="
display: inline-block;
background-color: #24292e;
color: #ffffff !important; /* Force white text */
padding: 5px 10px;
border-radius: 5px;
font-size: 14px;"
>
<span style="font-weight: bold;">GitHub</span> Repository
</div>
</a>
</div>
""")
            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
## Model Variations
### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749).
### DrugGEN-CDK2
This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
- Exploring chemical space
- Generating diverse scaffolds
- Creating molecules with drug-like properties
For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
""")
            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
## Evaluation Metrics
### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate the requested molecules
### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the training set
- **Novelty (Test)**: Percentage of molecules not found in the test set
- **Drug Novelty**: Percentage of molecules not found in known inhibitors of the target protein
### Structural Metrics
- **Max Length**: Maximum component length in the generated molecules
- **Mean Atom Type**: Average distribution of atom types
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)
### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)
### Similarity Metrics
- **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
- **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
""")
            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )
            num_molecules = gr.Slider(
                minimum=10,
                maximum=250,
                value=100,
                step=10,
                label="Number of Molecules to Generate",
                info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
            )
            seed_num = gr.Textbox(
                label="Random Seed (Optional)",
                value="",
                info="Set a specific seed for reproducible results, or leave empty for random generation"
            )
            submit_button = gr.Button(
                value="Generate Molecules",
                variant="primary",
                size="lg"
            )
        with gr.Column(scale=2):
            # Placeholder headers use the same column names as the DataFrames
            # returned by `function`, so the table header does not change
            # wording once results arrive.
            basic_metrics_df = gr.Dataframe(
                headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Drug Novelty", "Runtime (s)"],
                elem_id="basic-metrics"
            )
            advanced_metrics_df = gr.Dataframe(
                headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Drug", "Max Length"],
                elem_id="advanced-metrics"
            )
            file_download = gr.File(
                label="Download All Generated Molecules (SMILES format)",
            )
            image_output = gr.Image(
                label="Structures of Randomly Selected Generated Molecules",
                elem_id="molecule_display"
            )
    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
    # The order of `outputs` must match the tuple returned by `function`:
    # image, file path, basic metrics, advanced metrics.
    submit_button.click(
        function,
        inputs=[model_name, num_molecules, seed_num],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference"
    )
# Turn on Gradio's request queue with its default settings.
demo.queue()

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse `demo` or `function`) has no side effects.
if __name__ == "__main__":
    demo.launch()