|
import streamlit as st |
|
import os |
|
import pandas as pd |
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
from rdkit import Chem |
|
from streamlit_ketcher import st_ketcher |
|
from io import StringIO |
|
|
|
|
|
from openadmet_models.models.gradient_boosting.lgbm import LGBMRegressorModel |
|
from openadmet_models.features.combine import FeatureConcatenator |
|
from openadmet_models.features.molfeat_properties import DescriptorFeaturizer |
|
from openadmet_models.features.molfeat_fingerprint import FingerprintFeaturizer |
|
|
|
|
|
|
|
def _is_valid_smiles(smi): |
|
if smi is None or smi == "": |
|
return False |
|
try: |
|
m = Chem.MolFromSmiles(smi) |
|
if m is None: |
|
return False |
|
else: |
|
return True |
|
except: |
|
return False |
|
|
|
|
|
def sdf_str_to_rdkit_mol(sdf):
    """Parse the text of an SDF file into RDKit molecules.

    Hydrogens present in the file are kept (``removeHs=False``). Records the
    supplier fails to parse are yielded as ``None`` and skipped here, so the
    result may contain fewer molecules than the file has records.

    Args:
        sdf: Text content of an SDF file.

    Returns:
        list: The successfully parsed molecules, in file order.
    """
    from io import BytesIO

    # The supplier is fed a binary file-like object, so wrap the encoded text.
    stream = BytesIO(sdf.encode())
    supplier = Chem.ForwardSDMolSupplier(stream, removeHs=False)

    parsed = []
    for candidate in supplier:
        if candidate is not None:
            parsed.append(candidate)
    return parsed
|
|
|
|
|
@st.cache_data
def convert_df(df):
    """Serialise *df* (index included) to UTF-8 encoded CSV bytes.

    The bytes are meant for ``st.download_button``; ``st.cache_data`` lets
    reruns with an unchanged frame reuse the previously encoded result.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
|
|
|
|
|
|
|
def get_model(path, target, model_type):
    """Load a trained model and its featurizer from disk.

    Looks for ``<path>/<model_type>/<target lower-cased>_model.json`` and the
    sibling ``<target lower-cased>_model.pkl``; both files must exist.

    Args:
        path: Root directory containing one sub-directory per model type.
        target: Target name, e.g. ``"CYP3A4"``; lower-cased to build file names.
        model_type: Sub-directory of the model family, e.g. ``"ecfp4_mordred_lgbm"``.

    Returns:
        ``(model, featurizer)`` on success, or ``(None, None)`` when either
        file is missing. Always returning a pair lets callers unpack the
        result unconditionally and then test ``model is None``.
    """
    model_dir = os.path.join(path, model_type)
    stem = f"{target.lower()}_model"
    json_path = os.path.join(model_dir, f"{stem}.json")
    pkl_path = os.path.join(model_dir, f"{stem}.pkl")

    if not (os.path.exists(json_path) and os.path.exists(pkl_path)):
        return None, None

    # NOTE(review): the .pkl file is presumably unpickled by deserialize();
    # only load model files from a trusted location.
    model = LGBMRegressorModel.deserialize(json_path, pkl_path)

    # The featurizer is fixed (ECFP4 fingerprints + Mordred descriptors)
    # regardless of model_type. It presumably mirrors the featurization used
    # to train the LGBM models — confirm before adding other model families.
    featurizer = FeatureConcatenator(
        featurizers=[
            FingerprintFeaturizer(fp_type="ecfp:4"),
            DescriptorFeaturizer(descr_type="mordred"),
        ]
    )
    return model, featurizer
|
|
|
|
|
st.title("OpenADMET Streamlit DEMO")

# Static introduction rendered in order, ending with the heading of the
# input section that follows.
for _intro_md in (
    "## Background",
    "**The [OpenADMET](https://openadmet.org) initiative provides a suite of open-source machine learning models to predict ADMET (Absorption, Distribution, Metabolism, Excretion, and Toxicity) properties, facilitating drug discovery and development.**",
    "This web app enables researchers and scientists to leverage OpenADMET’s models without needing to write or run code, making predictive analytics more accessible.",
    "---",
    "## Input :clipboard:",
):
    st.markdown(_intro_md)
|
|
|
# How the user supplies molecules. Every branch either calls ``st.stop()`` or
# leaves these names defined for the rest of the script:
#   queried_df         -- DataFrame of query molecules (predictions get appended)
#   smiles_column_name -- name of the column in queried_df holding SMILES
#   smiles_column      -- that column as a Series
#   multismiles        -- True when the input may contain several molecules
# Named ``input_mode`` (not ``input``) to avoid shadowing the builtin.
input_mode = st.selectbox(
    "How would you like to enter your input?",
    ["Upload a CSV file", "Draw a molecule", "Enter SMILES", "Upload an SDF file"],
    key="input",
)

multismiles = False

if input_mode == "Draw a molecule":
    smiles = st_ketcher(None)
    if not smiles:
        # Nothing drawn yet: wait for input instead of flagging an error.
        st.stop()
    if _is_valid_smiles(smiles):
        st.success("Valid molecule", icon="✅")
    else:
        st.error("Invalid molecule", icon="🚨")
        st.stop()
    smiles = [smiles]
    queried_df = pd.DataFrame(smiles, columns=["SMILES"])
    smiles_column_name = "SMILES"
    smiles_column = queried_df[smiles_column_name]

elif input_mode == "Enter SMILES":
    # Surrounding whitespace is not part of the SMILES; drop it before parsing.
    smiles = st.text_input("Enter a SMILES string", key="smiles_user_input").strip()
    if not smiles:
        # Nothing entered yet: wait for input instead of flagging an error.
        st.stop()
    if _is_valid_smiles(smiles):
        st.success("Valid SMILES string", icon="✅")
    else:
        st.error("Invalid SMILES string", icon="🚨")
        st.stop()
    smiles = [smiles]
    queried_df = pd.DataFrame(smiles, columns=["SMILES"])
    smiles_column_name = "SMILES"
    smiles_column = queried_df[smiles_column_name]

elif input_mode == "Upload a CSV file":
    uploaded_file = st.file_uploader(
        "Choose a CSV file to upload your predictions to", type="csv", key="csv_file"
    )
    if uploaded_file is None:
        st.stop()

    try:
        queried_df = pd.read_csv(uploaded_file)
    except ValueError:
        # pandas' EmptyDataError / ParserError and UnicodeDecodeError all
        # derive from ValueError.
        st.error("Error reading the CSV file, please check the input", icon="🚨")
        st.stop()
    if queried_df.empty:
        st.error("The CSV file contains no rows, please check the input", icon="🚨")
        st.stop()

    smiles_column_name = st.selectbox(
        "Select a SMILES column", queried_df.columns, key="df_smiles_column"
    )
    multismiles = True
    smiles_column = queried_df[smiles_column_name]

    valid_smiles = [_is_valid_smiles(smi) for smi in smiles_column]
    if not all(valid_smiles):
        st.error(
            "Some of the SMILES strings are invalid, please check the input", icon="🚨"
        )
        st.stop()
    st.success(
        f"All SMILES strings are valid (n={len(valid_smiles)}), proceeding with prediction",
        icon="✅",
    )

elif input_mode == "Upload an SDF file":
    uploaded_file = st.file_uploader(
        "Choose a SDF file to upload your predictions to", type="sdf"
    )
    if uploaded_file is None:
        # Same as the CSV branch: wait quietly until a file is provided.
        st.stop()

    try:
        sdf_text = uploaded_file.getvalue().decode("utf-8")
        mols = sdf_str_to_rdkit_mol(sdf_text)
        smiles = [Chem.MolToSmiles(m) for m in mols]
    except Exception:
        # Boundary with user-supplied data: report any decode/parse failure.
        st.error("Error reading the SDF file, please check the input", icon="🚨")
        st.stop()

    # sdf_str_to_rdkit_mol skips records RDKit cannot parse, so the file may
    # yield fewer molecules than records, possibly none at all.
    if not smiles:
        st.error(
            "No valid molecules found in the SDF file, please check the input",
            icon="🚨",
        )
        st.stop()
    queried_df = pd.DataFrame(smiles, columns=["SMILES"])

    st.success(
        f"Parsed {len(queried_df)} valid molecule(s) from the SDF file, proceeding with prediction",
        icon="✅",
    )
    smiles_column_name = "SMILES"
    smiles_column = queried_df[smiles_column_name]
    multismiles = True
|
|
|
st.markdown("## Model parameters :nut_and_bolt:")

# Options offered in the UI. Only a subset is currently backed by a trained
# model; the guards below stop the script for anything else.
targets = ['CYP3A4', 'CYP2D6', 'CYP2C9']
# Display name -> model sub-directory name under ./models.
models = {"ecfp:4 Mordred LGBM": "ecfp4_mordred_lgbm", "ChemProp": "chemprop"}
models_inversed = {v: k for k, v in models.items()}
model_names = list(models.keys())
endpoints = ["pIC50"]

target_value = st.selectbox("Select a biological target ", targets, key="target")
endpoint_value = st.selectbox("Select a property ", endpoints, key="endpoint")
model_value = st.selectbox("Select a model type ", model_names, key="model")

if target_value != "CYP3A4":
    st.write("Only CYP3A4 is currently supported")
    st.stop()

if endpoint_value != "pIC50":
    st.write("Only pIC50 is currently supported")
    st.stop()

if model_value != "ecfp:4 Mordred LGBM":
    st.write("Only ecfp:4 Mordred LGBM is currently supported")
    st.stop()

# NOTE(review): "./models" is resolved against the working directory of the
# ``streamlit run`` process, not this file's location.
loaded = get_model("./models", target_value, models[model_value])
# Treat both a bare None and a (None, None) pair as "no model on disk", so a
# missing model shows a message instead of failing the tuple unpacking.
model, featurizer = loaded if loaded is not None else (None, None)

if model is None:
    st.write(f"No model found for {target_value} {endpoint_value}")
    st.stop()
|
|
|
|
|
st.markdown("## Prediction 🚀")
st.write(
    f"Predicting **{target_value} {endpoint_value}** using model:\n\n `{model_value}`"
)

# Featurise the query SMILES; the second value returned by featurize() is
# discarded. NOTE(review): if the featurizer drops molecules it cannot
# featurise, X (and so the predictions) could be shorter than queried_df and
# the column assignment below would fail — confirm against
# FeatureConcatenator.featurize.
X, _ = featurizer.featurize(smiles_column)
preds = model.predict(X)

# This model provides no uncertainty estimate, so that column stays empty.
err = None

pred_column_name = f"{target_value}_computed-{endpoint_value}"
unc_column_name = f"{pred_column_name}_uncertainty"

for column_name, column_values in ((pred_column_name, preds), (unc_column_name, err)):
    queried_df[column_name] = column_values
|
|
|
st.markdown("---")

if multismiles:
    # Several molecules: summarise the predictions as a histogram plus a
    # bar chart of the predictions sorted in ascending order.
    sorted_preds = queried_df[pred_column_name].sort_values()
    # About ten molecules per bin, but never fewer than five bins.
    n_bins = max(5, len(sorted_preds) // 10)

    hist_fig, hist_ax = plt.subplots()
    hist_ax.hist(sorted_preds, bins=n_bins)
    hist_ax.set_ylabel("Count")
    hist_ax.set_xlabel(f"Computed {endpoint_value}")
    hist_ax.set_title(f"Histogram of computed {endpoint_value} for target: {target_value}")
    st.pyplot(hist_fig)
    # pyplot keeps every figure alive until closed; close after rendering so
    # figures do not accumulate across Streamlit reruns.
    plt.close(hist_fig)

    bar_fig, bar_ax = plt.subplots()
    bar_ax.bar(range(len(sorted_preds)), sorted_preds)
    bar_ax.set_xticks([])
    bar_ax.set_xlabel("Query compounds")
    bar_ax.set_ylabel(f"Computed {endpoint_value}")
    bar_ax.set_title(f"Barplot of computed {endpoint_value} for target: {target_value}")
    st.pyplot(bar_fig)
    plt.close(bar_fig)

else:
    # A single molecule: report its prediction inline.
    pred_value = queried_df[pred_column_name].values[0]
    query_smiles = queried_df[smiles_column_name].values[0]
    unc_value = queried_df[unc_column_name].values[0]
    # The uncertainty column holds None/NaN when the model gives no estimate.
    errstr = f" ± {unc_value:.2f}" if pd.notna(unc_value) else ""

    # The SMILES goes in backticks so markdown does not interpret characters
    # such as "[C@H](N)" as link syntax.
    st.markdown(
        f"Predicted {target_value} {endpoint_value} for `{query_smiles}` is {pred_value:.2f}{errstr}."
    )
|
|
|
|
|
# Offer the query table, including the prediction and uncertainty columns,
# as a CSV download.
csv = convert_df(queried_df)
st.download_button(
    label="Download data as CSV",
    data=csv,
    # Use the model's identifier (e.g. "ecfp4_mordred_lgbm") rather than its
    # display name: ":" and spaces in "ecfp:4 Mordred LGBM" are not valid or
    # portable in file names. Include the target because the prediction
    # columns are target-specific.
    file_name=f"predictions_{target_value}_{models[model_value]}.csv",
    mime="text/csv",
)