import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections, list_models
from utils import MolecularGenerationModel
import pandas as pd
import os
import spaces
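# Gradio demo for conditional molecular generation: the user picks target values
# for logP, TPSA, SAS, and QED, and the model proposes candidate SMILES strings.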
model = MolecularGenerationModel()
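# A single model instance is created at startup and shared across requests.
# The ZeroGPU decorator below requests a GPU for up to 60 seconds per call.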
@spaces.GPU(duration=60)
def predict_single_label(logp, tpsa, sas, qed, logp_choose, tpsa_choose, sas_choose, qed_choose):
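    """Build a conditioning dict from the enabled properties and generate molecules.

    Each slider value is included only if its checkbox is ticked. Returns a
    (generations, status_message) pair for the Dataframe and status outputs;
    the generations are expected to be a pandas DataFrame of SMILES strings
    with their predicted properties (see utils.MolecularGenerationModel).
    """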
input_dict = dict()
if logp_choose:
input_dict['logP'] = logp
if tpsa_choose:
        input_dict['TPSA'] = tpsa
if sas_choose:
input_dict['SAS'] = sas
if qed_choose:
input_dict['QED'] = qed
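    # At this point input_dict maps the selected property names to their target
    # values, e.g. {"logP": 4.0, "TPSA": 80.0, "SAS": 3.0, "QED": 0.5}.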
if len(input_dict) == 0:
return "NA", "No input is selected"
print(input_dict)
    try:
        prediction = model.predict_single_smiles(input_dict)
if prediction is None:
return "NA", "Invalid SMILES string"
    except Exception as e:
        # Report any failure during generation to the UI instead of raising.
        print(e)
        return "NA", "Generation failed"
return prediction, "Generation is done"
"""
def get_description(task_name):
task = task_names_to_tasks[task_name]
return task_descriptions[task]
#@spaces.GPU(duration=10)
"""
"""
@spaces.GPU(duration=30)
def predict_file(file, property_name):
property_id = dataset_property_names_to_dataset[property_name]
try:
adapter_id = candidate_models[property_id]
info = model.swith_adapter(property_id, adapter_id)
running_status = None
if info == "keep":
running_status = "Adapter is the same as the current one"
#print("Adapter is the same as the current one")
elif info == "switched":
running_status = "Adapter is switched successfully"
#print("Adapter is switched successfully")
elif info == "error":
running_status = "Adapter is not found"
#print("Adapter is not found")
return None, None, file, running_status
else:
running_status = "Unknown error"
return None, None, file, running_status
df = pd.read_csv(file)
# we have already checked the file contains the "smiles" column
df = model.predict_file(df, dataset_task_types[property_id])
# we should save this file to the disk to be downloaded
# rename the file to have "_prediction" suffix
prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
print(file, prediction_file)
# save the file to the disk
df.to_csv(prediction_file, index=False)
except Exception as e:
# no matter what the error is, we should return
print(e)
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed"
return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
def validate_file(file):
try:
if file.endswith(".csv"):
df = pd.read_csv(file)
if "smiles" not in df.columns:
# we should clear the file input
return "Invalid file content. The csv file must contain column named 'smiles'", \
None, gr.update(visible=False), gr.update(visible=False)
# check the length of the smiles
length = len(df["smiles"])
elif file.endswith(".smi"):
return "Invalid file extension", \
None, gr.update(visible=False), gr.update(visible=False)
else:
return "Invalid file extension", \
None, gr.update(visible=False), gr.update(visible=False)
except Exception as e:
return "Invalid file content.", \
None, gr.update(visible=False), gr.update(visible=False)
if length > 100:
return "The space does not support the file containing more than 100 SMILES", \
None, gr.update(visible=False), gr.update(visible=False)
return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
"""
def toggle_slider(checked):
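    """Enable or disable a single slider to match its checkbox state."""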
return gr.update(interactive=checked)
def toggle_sliders_based_on_checkboxes(checked_values):
"""Enable or disable sliders based on the corresponding checkbox values."""
    return [gr.update(interactive=checked) for checked in checked_values]
def build_inference():
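    """Assemble the Gradio Blocks UI: task description, per-property controls, and the generation outputs."""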
with gr.Blocks() as demo:
        # Task description shown at the top of the page
        description = "This space generates ten candidate molecules conditioned on the selected property values. \n" \
                      "1. Enable or disable each property with its checkbox and set its target value with the slider. \n" \
                      "2. The generated SMILES strings and their predicted properties are shown in the Generations table. \n" \
                      "3. The supported properties are logP, TPSA, SAS, and QED. \n" \
                      "4. The model is trained on the GuacaMol dataset for molecular design."
        description_box = gr.Textbox(label="Task description", lines=5,
                                     interactive=False,
                                     value=description)
        # Property checkboxes and sliders, one column per property
with gr.Row(equal_height=True):
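            # Each column pairs a property checkbox with its value slider;
            # unchecking a box disables the corresponding slider.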
with gr.Column():
checkbox_1 = gr.Checkbox(label="logP", value=True)
slider_1 = gr.Slider(1, 7, value=4, label="logP", info="Choose between 1 and 7")
checkbox_1.change(toggle_slider, checkbox_1, slider_1)
with gr.Column():
checkbox_2 = gr.Checkbox(label="TPSA", value=True)
slider_2 = gr.Slider(20, 140, value=80, label="TPSA", info="Choose between 20 and 140")
checkbox_2.change(toggle_slider, checkbox_2, slider_2)
with gr.Column():
checkbox_3 = gr.Checkbox(label="SAS", value=True)
slider_3 = gr.Slider(1, 5, value=3, label="SAS", info="Choose between 1 and 5")
checkbox_3.change(toggle_slider, checkbox_3, slider_3)
with gr.Column():
checkbox_4 = gr.Checkbox(label="QED", value=True)
slider_4 = gr.Slider(0.1, 0.9, value=0.5, label="QED", info="Choose between 0.1 and 0.9")
checkbox_4.change(toggle_slider, checkbox_4, slider_4)
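        # Trigger generation and display the results and running status below.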
predict_single_smiles_button = gr.Button("Generate", size='sm')
prediction = gr.Dataframe(label="Generations", type="pandas", interactive=False)
running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
        # Generate button click event: disable all controls, run the generation,
        # then restore interactivity (sliders follow their checkbox states).
        predict_single_label.zerogpu = True
        predict_single_smiles_button.click(
            lambda: [gr.update(interactive=False)] * 10,
            outputs=[slider_1, slider_2, slider_3, slider_4,
                     checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                     predict_single_smiles_button, running_terminal_label])\
            .then(predict_single_label,
                  inputs=[slider_1, slider_2, slider_3, slider_4,
                          checkbox_1, checkbox_2, checkbox_3, checkbox_4],
                  outputs=[prediction, running_terminal_label])\
            .then(lambda a, b, c, d: toggle_sliders_based_on_checkboxes([a, b, c, d]) +
                                     [gr.update(interactive=True)] * 6,
                  inputs=[checkbox_1, checkbox_2, checkbox_3, checkbox_4],
                  outputs=[slider_1, slider_2, slider_3, slider_4,
                           checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                           predict_single_smiles_button, running_terminal_label])
return demo
demo = build_inference()
if __name__ == '__main__':
demo.launch()