# Source snapshot: commit 366727d by hmacdope ("fix double header").
import streamlit as st
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from streamlit_ketcher import st_ketcher
from io import StringIO
from openadmet_models.models.gradient_boosting.lgbm import LGBMRegressorModel
from openadmet_models.features.combine import FeatureConcatenator
from openadmet_models.features.molfeat_properties import DescriptorFeaturizer
from openadmet_models.features.molfeat_fingerprint import FingerprintFeaturizer
def _is_valid_smiles(smi):
    """Return True if *smi* is a non-empty string that RDKit parses into a molecule.

    Non-string values (``None``, a float NaN or ``pd.NA`` from a CSV cell, ...)
    and the empty string are rejected without calling RDKit. Any ordinary
    exception raised while parsing is treated as "invalid" rather than
    propagated to the caller.
    """
    # isinstance() first: comparing e.g. pd.NA with "" yields pd.NA, whose
    # truth value raises TypeError, which previously escaped this function.
    if not isinstance(smi, str) or not smi:
        return False
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # are no longer swallowed.
        return False
    # MolFromSmiles returns None (rather than raising) for unparsable input.
    return mol is not None
def sdf_str_to_rdkit_mol(sdf):
    """Parse an SDF document held in a string into a list of RDKit molecules.

    Records for which the supplier yields ``None`` (i.e. RDKit could not parse
    them) are silently dropped, so the result may contain fewer molecules than
    the document has records. Explicit hydrogens are kept (``removeHs=False``).
    """
    from io import BytesIO

    # The supplier is fed a bytes stream, hence the encode() of the text.
    byte_stream = BytesIO(sdf.encode())
    supplier = Chem.ForwardSDMolSupplier(byte_stream, removeHs=False)
    return [parsed for parsed in supplier if parsed is not None]
@st.cache_data
def convert_df(df):
    """Serialise *df* to UTF-8 encoded CSV bytes for ``st.download_button``.

    Wrapped in ``st.cache_data`` so the conversion is not recomputed on every
    Streamlit rerun. ``to_csv`` is called with its defaults, so the DataFrame
    index is written as the first column.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
def get_model(path, target, model_type):
    """Load a serialised LGBM regressor and build the featurizer used with it.

    Files are looked up as ``<path>/<model_type>/<target.lower()>_model.json``
    and ``<path>/<model_type>/<target.lower()>_model.pkl``; both must exist.

    Returns:
        ``(model, featurizer)`` on success, or ``(None, None)`` when either
        file is missing, so callers can always unpack two values.
    """
    model_dir = os.path.join(path, model_type)
    stem = f"{target.lower()}_model"
    model_path = os.path.join(model_dir, f"{stem}.json")
    model_file = os.path.join(model_dir, f"{stem}.pkl")
    print(model_path, model_file)
    if not (os.path.exists(model_path) and os.path.exists(model_file)):
        # Previously a bare None was returned, which made the caller's
        # `model, featurizer = get_model(...)` raise TypeError before its
        # `model is None` check could run.
        return None, None
    model = LGBMRegressorModel.deserialize(model_path, model_file)
    # ECFP4 fingerprints concatenated with Mordred descriptors; presumably this
    # must match the featurization the model was trained with — TODO confirm.
    featurizer = FeatureConcatenator(
        featurizers=[
            FingerprintFeaturizer(fp_type="ecfp:4"),
            DescriptorFeaturizer(descr_type="mordred"),
        ]
    )
    return model, featurizer
# ---------------------------------------------------------------------------
# Page header: title, short background blurb, and the input-mode picker.
# ---------------------------------------------------------------------------
st.title("OpenADMET Streamlit DEMO")
for _header_markdown in (
    "## Background",
    "**The [OpenADMET](https://openadmet.org) initiative provides a suite of open-source machine learning models to predict ADMET (Absorption, Distribution, Metabolism, Excretion, and Toxicity) properties, facilitating drug discovery and development.**",
    "This web app enables researchers and scientists to leverage OpenADMET’s models without needing to write or run code, making predictive analytics more accessible.",
    "---",
    "## Input :clipboard:",
):
    st.markdown(_header_markdown)

# NOTE: `input` shadows the builtin of the same name; the rest of the script
# branches on this module-level variable, so the name is kept.
_input_modes = [
    "Upload a CSV file",
    "Draw a molecule",
    "Enter SMILES",
    "Upload an SDF file",
]
input = st.selectbox(
    "How would you like to enter your input?",
    _input_modes,
    key="input",
)
# ---------------------------------------------------------------------------
# Collect the query molecules.
#
# Every branch leaves these module-level names defined for the rest of the
# script: `queried_df` (DataFrame of queries), `smiles_column_name`,
# `smiles_column` (Series of SMILES to featurize) and `multismiles` (True for
# the file-based inputs, which switches the output section to plots instead
# of a single sentence).
# ---------------------------------------------------------------------------
multismiles = False
if input == "Draw a molecule":
    smiles = st_ketcher(None)
    if _is_valid_smiles(smiles):
        st.success("Valid molecule", icon="✅")
    else:
        st.error("Invalid molecule", icon="🚨")
        st.stop()
    smiles = [smiles]
    queried_df = pd.DataFrame(smiles, columns=["SMILES"])
    smiles_column_name = "SMILES"
    smiles_column = queried_df[smiles_column_name]
elif input == "Enter SMILES":
    smiles = st.text_input("Enter a SMILES string", key="smiles_user_input")
    if _is_valid_smiles(smiles):
        st.success("Valid SMILES string", icon="✅")
    else:
        st.error("Invalid SMILES string", icon="🚨")
        st.stop()
    smiles = [smiles]
    queried_df = pd.DataFrame(smiles, columns=["SMILES"])
    smiles_column_name = "SMILES"
    smiles_column = queried_df[smiles_column_name]
elif input == "Upload a CSV file":
    # Create a file uploader for CSV files
    uploaded_file = st.file_uploader(
        "Choose a CSV file to upload your predictions to", type="csv", key="csv_file"
    )
    # Nothing to do until a file has been uploaded.
    if uploaded_file is None:
        st.stop()
    # Report empty / malformed / non-UTF-8 files as a message instead of an
    # uncaught traceback.
    try:
        queried_df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError):
        st.error("Error reading the CSV file, please check the input", icon="🚨")
        st.stop()
    # Select a column from the DataFrame
    smiles_column_name = st.selectbox(
        "Select a SMILES column", queried_df.columns, key="df_smiles_column"
    )
    multismiles = True
    smiles_column = queried_df[smiles_column_name]
    # check if the smiles are valid
    valid_smiles = [_is_valid_smiles(smi) for smi in smiles_column]
    # all([]) is True, so a header-only file would otherwise "pass" with n=0.
    if not valid_smiles:
        st.error(
            "The selected column contains no SMILES strings, please check the input",
            icon="🚨",
        )
        st.stop()
    if not all(valid_smiles):
        st.error(
            "Some of the SMILES strings are invalid, please check the input", icon="🚨"
        )
        st.stop()
    st.success(
        f"All SMILES strings are valid (n={len(valid_smiles)}), proceeding with prediction",
        icon="✅",
    )
elif input == "Upload an SDF file":
    # Create a file uploader for SDF files
    uploaded_file = st.file_uploader(
        "Choose a SDF file to upload your predictions to", type="sdf"
    )
    if uploaded_file is None:
        st.error("No file uploaded", icon="🚨")
        st.stop()
    # Decode and parse with RDKit. Narrowed from a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    try:
        string_data = uploaded_file.getvalue().decode("utf-8")
        mols = sdf_str_to_rdkit_mol(string_data)
        smiles = [Chem.MolToSmiles(m) for m in mols]
    except Exception:
        st.error("Error reading the SDF file, please check the input", icon="🚨")
        st.stop()
    # sdf_str_to_rdkit_mol drops records RDKit cannot parse; if none are left
    # there is nothing to featurize, so stop here rather than passing an empty
    # frame on to the model.
    # TODO: report how many records were dropped; the success message below
    # currently only counts the molecules that survived parsing.
    if not smiles:
        st.error(
            "No valid molecules found in the SDF file, please check the input",
            icon="🚨",
        )
        st.stop()
    queried_df = pd.DataFrame(smiles, columns=["SMILES"])
    st.success(
        f"All molecule entries are valid (n={len(queried_df)}), proceeding with prediction",
        icon="✅",
    )
    smiles_column_name = "SMILES"
    smiles_column = queried_df[smiles_column_name]
    multismiles = True
st.markdown("## Model parameters :nut_and_bolt:")
# Options offered in the UI; only a subset is currently backed by a model
# (see the guards below).
targets = ['CYP3A4', 'CYP2D6', 'CYP2C9']
# Display name -> model_type sub-directory passed to get_model("./models", ...).
models = {"ecfp:4 Mordred LGBM": "ecfp4_mordred_lgbm", "ChemProp": "chemprop"}
models_inversed = {v: k for k, v in models.items()}
model_names = list(models.keys())
endpoints = ["pIC50"]
# Select a target value from the preset list
target_value = st.selectbox("Select a biological target ", targets, key="target")
# Select an endpoint (property) from the preset list
endpoint_value = st.selectbox("Select a property ", endpoints, key="endpoint")
# Select a model type from the preset list
model_value = st.selectbox("Select a model type ", model_names, key="model")
if target_value != "CYP3A4":
    st.write("Only CYP3A4 is currently supported")
    st.stop()
if endpoint_value != "pIC50":
    st.write("Only pIC50 is currently supported")
    st.stop()
if model_value != "ecfp:4 Mordred LGBM":
    st.write("Only ecfp:4 Mordred LGBM is currently supported")
    st.stop()
# Unpacking get_model()'s result directly raised TypeError when it signalled
# a missing model with a bare None, so the check below never ran. Accept both
# a bare None and a (None, None) pair before unpacking.
loaded = get_model("./models", target_value, models[model_value])
if loaded is None or loaded[0] is None:
    # The user can retry with a different target or endpoint.
    st.write(f"No model found for {target_value} {endpoint_value}")
    st.stop()
model, featurizer = loaded
# Output column names, e.g. "CYP3A4_computed-pIC50" and its "_uncertainty" twin.
pred_column_name = f"{target_value}_computed-{endpoint_value}"
unc_column_name = f"{pred_column_name}_uncertainty"

st.markdown("## Prediction 🚀")
st.write(
    f"Predicting **{target_value} {endpoint_value}** using model:\n\n `{model_value}`"
)

# Featurize the query SMILES, then run the model on the feature matrix.
# NOTE(review): the second value returned by featurize() is discarded; if it
# identifies molecules that failed featurization, `preds` could be shorter
# than `queried_df` and the column assignment below would fail — TODO confirm.
feature_matrix, _ = featurizer.featurize(smiles_column)
preds = model.predict(feature_matrix)

# Uncertainty estimates are not implemented yet; the column is filled with None.
err = None
queried_df[pred_column_name] = preds
queried_df[unc_column_name] = err
st.markdown("---")
if multismiles:
    # Several molecules: summarise the predictions graphically.
    # Sorted once so the bar plot shows predictions in ascending order.
    sorted_df = queried_df.sort_values(by=pred_column_name)

    # Histogram first: roughly 10 compounds per bin, but never fewer than 5
    # bins so small inputs still give an interpretable histogram.
    n_bins = max(5, len(sorted_df) // 10)
    fig, ax = plt.subplots()
    ax.hist(sorted_df[pred_column_name], bins=n_bins)
    ax.set_ylabel("Count")
    ax.set_xlabel(f"Computed {endpoint_value}")
    ax.set_title(f"Histogram of computed {endpoint_value} for target: {target_value}")
    st.pyplot(fig)
    # pyplot keeps figures alive until closed and this script reruns on every
    # interaction, so close each figure once it has been rendered.
    plt.close(fig)

    # then a barplot
    fig, ax = plt.subplots()
    ax.bar(range(len(sorted_df)), sorted_df[pred_column_name])
    ax.set_xticks([])
    ax.set_xlabel("Query compounds")
    ax.set_ylabel(f"Computed {endpoint_value}")
    ax.set_title(f"Barplot of computed {endpoint_value} for target: {target_value}")
    st.pyplot(fig)
    plt.close(fig)

    # TODO: re-enable once uncertainty estimates (`err`) are implemented.
    # if endpoint_value == "pIC50":
    #     from rdkit.Chem.Descriptors import MolWt
    #     import seaborn as sns
    #     # then a scatterplot of uncertainty vs MW
    #     queried_df["MW"] = [
    #         MolWt(Chem.MolFromSmiles(smi)) for smi in sorted_df[smiles_column_name]
    #     ]
    #     fig, ax = plt.subplots()
    #     ax = sns.scatterplot(
    #         x="MW",
    #         y=pred_column_name,
    #         hue=unc_column_name,
    #         palette="coolwarm",
    #         data=queried_df,
    #     )
    #     norm = plt.Normalize(
    #         queried_df[unc_column_name].min(), queried_df[unc_column_name].max()
    #     )
    #     sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
    #     sm.set_array([])
    #     # Remove the legend and add a colorbar
    #     cbar = ax.figure.colorbar(sm, ax=ax)
    #     ax.annotate(
    #         f"Computed {endpoint_value} uncertainty",
    #         xy=(1.2, 0.3),
    #         xycoords="axes fraction",
    #         rotation=270,
    #     )
    #     ax.set_title(
    #         f"Scatterplot of predicted {endpoint_value} versus MW\ntarget: {target_value}"
    #     )
    #     ax.set_xlabel(f"Molecular weight (Da)")
    #     ax.set_ylabel(f"Computed {endpoint_value}")
    #     st.pyplot(fig)
else:
    # Single molecule: report the prediction as one sentence.
    single_pred = queried_df[pred_column_name].values[0]
    single_smiles = queried_df[smiles_column_name].values[0]
    # Compare with None explicitly: truth-testing an array-valued `err` would
    # raise, and a legitimate 0.0 uncertainty should still be shown.
    if err is not None:
        single_err = queried_df[unc_column_name].values[0]
        # Leading space lives here so the no-uncertainty case does not render
        # as "... is 5.23 ." with a stray space before the period.
        errstr = f" ± {single_err:.2f}"
    else:
        errstr = ""
    # Backticks stop Markdown from interpreting SMILES syntax, e.g. the
    # "[C@@H](O)" in "C[C@@H](O)c1ccccc1" would otherwise render as a link.
    st.markdown(
        f"Predicted {target_value} {endpoint_value} for `{single_smiles}` is {single_pred:.2f}{errstr}."
    )
# Allow the user to download the predictions (input columns plus prediction
# and uncertainty columns) as a CSV file.
csv = convert_df(queried_df)
st.download_button(
    label="Download data as CSV",
    data=csv,
    # Use the directory-style model key (e.g. "ecfp4_mordred_lgbm") rather than
    # the display name, whose ':' and spaces are invalid or awkward in file
    # names on some platforms; fall back to the display name if unmapped.
    file_name=f"predictions_{models.get(model_value, model_value)}.csv",
    mime="text/csv",
)