hmacdope commited on
Commit
b1ec78c
·
1 Parent(s): 80a757e
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [theme]
2
+ base="light"
app.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
+ import matplotlib.pyplot as plt
7
+ from rdkit import Chem
8
+ from streamlit_ketcher import st_ketcher
9
+ from io import StringIO
10
+
11
+
12
+ from openadmet_models.models.gradient_boosting.lgbm import LGBMRegressorModel
13
+ from openadmet_models.features.combine import FeatureConcatenator
14
+ from openadmet_models.features.molfeat_properties import DescriptorFeaturizer
15
+ from openadmet_models.features.molfeat_fingerprint import FingerprintFeaturizer
16
+
17
+
18
+
19
+ def _is_valid_smiles(smi):
20
+ if smi is None or smi == "":
21
+ return False
22
+ try:
23
+ m = Chem.MolFromSmiles(smi)
24
+ if m is None:
25
+ return False
26
+ else:
27
+ return True
28
+ except:
29
+ return False
30
+
31
+
32
+ def sdf_str_to_rdkit_mol(sdf):
33
+ from io import BytesIO
34
+
35
+ bio = BytesIO(sdf.encode())
36
+ suppl = Chem.ForwardSDMolSupplier(bio, removeHs=False)
37
+ mols = [mol for mol in suppl if mol is not None]
38
+ return mols
39
+
40
+
41
+ @st.cache_data
42
+ def convert_df(df):
43
+ # IMPORTANT: Cache the conversion to prevent computation on every rerun
44
+ return df.to_csv().encode("utf-8")
45
+
46
+
47
+
48
+ def get_model(path, target, model_type):
49
+
50
+
51
+ model_path = os.path.join(path, f"{model_type}/{target.lower()}_model.json")
52
+ model_file = os.path.join(path, f"{model_type}/{target.lower()}_model.pkl")
53
+
54
+ print(model_path, model_file)
55
+
56
+ if not os.path.exists(model_path) or not os.path.exists(model_file):
57
+ return None
58
+
59
+ model = LGBMRegressorModel.deserialize(model_path, model_file)
60
+ featurizer = FeatureConcatenator(featurizers=[FingerprintFeaturizer(fp_type="ecfp:4"), DescriptorFeaturizer(descr_type="mordred")])
61
+ return model, featurizer
62
+
63
+ # Set the title of the Streamlit app
64
+ st.title("OpenADMET Streamlit DEMO")
65
+
66
+ # Set the title of the Streamlit app
67
+ st.title("OpenADMET Streamlit DEMO")
68
+
69
+ st.markdown("## Background")
70
+
71
+ st.markdown(
72
+ "**The [OpenADMET](https://openadmet.org) initiative provides a suite of open-source machine learning models to predict ADMET (Absorption, Distribution, Metabolism, Excretion, and Toxicity) properties, facilitating drug discovery and development.**"
73
+ )
74
+
75
+ st.markdown(
76
+ "This web app enables researchers and scientists to leverage OpenADMET’s models without needing to write or run code, making predictive analytics more accessible."
77
+ )
78
+ st.markdown("---")
79
+ st.markdown("## Input :clipboard:")
80
+
81
+ input = st.selectbox(
82
+ "How would you like to enter your input?",
83
+ ["Upload a CSV file", "Draw a molecule", "Enter SMILES", "Upload an SDF file"],
84
+ key="input",
85
+ )
86
+
87
+ multismiles = False
88
+ if input == "Draw a molecule":
89
+ smiles = st_ketcher(None)
90
+ if _is_valid_smiles(smiles):
91
+ st.success("Valid molecule", icon="✅")
92
+ else:
93
+ st.error("Invalid molecule", icon="🚨")
94
+ st.stop()
95
+ smiles = [smiles]
96
+ queried_df = pd.DataFrame(smiles, columns=["SMILES"])
97
+ smiles_column_name = "SMILES"
98
+ smiles_column = queried_df[smiles_column_name]
99
+ elif input == "Enter SMILES":
100
+ smiles = st.text_input("Enter a SMILES string", key="smiles_user_input")
101
+ if _is_valid_smiles(smiles):
102
+ st.success("Valid SMILES string", icon="✅")
103
+ else:
104
+ st.error("Invalid SMILES string", icon="🚨")
105
+ st.stop()
106
+ smiles = [smiles]
107
+ queried_df = pd.DataFrame(smiles, columns=["SMILES"])
108
+ smiles_column_name = "SMILES"
109
+ smiles_column = queried_df[smiles_column_name]
110
+ elif input == "Upload a CSV file":
111
+ # Create a file uploader for CSV files
112
+ uploaded_file = st.file_uploader(
113
+ "Choose a CSV file to upload your predictions to", type="csv", key="csv_file"
114
+ )
115
+
116
+ # If a file is uploaded, parse it into a DataFrame
117
+ if uploaded_file is not None:
118
+ queried_df = pd.read_csv(uploaded_file)
119
+ else:
120
+ st.stop()
121
+ # Select a column from the DataFrame
122
+ smiles_column_name = st.selectbox("Select a SMILES column", queried_df.columns, key="df_smiles_column")
123
+ multismiles = True
124
+ smiles_column = queried_df[smiles_column_name]
125
+
126
+ # check if the smiles are valid
127
+ valid_smiles = [_is_valid_smiles(smi) for smi in smiles_column]
128
+ if not all(valid_smiles):
129
+ st.error(
130
+ "Some of the SMILES strings are invalid, please check the input", icon="🚨"
131
+ )
132
+ st.stop()
133
+ st.success(
134
+ f"All SMILES strings are valid (n={len(valid_smiles)}), proceeding with prediction",
135
+ icon="✅",
136
+ )
137
+
138
+ elif input == "Upload an SDF file":
139
+ # Create a file uploader for SDF files
140
+ uploaded_file = st.file_uploader(
141
+ "Choose a SDF file to upload your predictions to", type="sdf"
142
+ )
143
+ # read with rdkit
144
+ if uploaded_file is not None:
145
+ # To convert to a string based IO:
146
+ try:
147
+ stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
148
+ # To read file as string:
149
+ string_data = stringio.read()
150
+ mols = sdf_str_to_rdkit_mol(string_data)
151
+ smiles = [Chem.MolToSmiles(m) for m in mols]
152
+ queried_df = pd.DataFrame(smiles, columns=["SMILES"])
153
+ except:
154
+ st.error("Error reading the SDF file, please check the input", icon="🚨")
155
+ st.stop()
156
+ else:
157
+ st.error("No file uploaded", icon="🚨")
158
+ st.stop()
159
+
160
+ st.success(
161
+ f"All molecule entries are valid (n={len(queried_df)}), proceeding with prediction",
162
+ icon="✅",
163
+ )
164
+ smiles_column_name = "SMILES"
165
+ smiles_column = queried_df[smiles_column_name]
166
+ multismiles = True
167
+
168
+ st.markdown("## Model parameters :nut_and_bolt:")
169
+
170
+
171
+
172
+
173
+ targets = ['CYP3A4', 'CYP2D6', 'CYP2C9']
174
+ models = {"ecfp:4 Mordred LGBM":"ecfp4_mordred_lgbm", "ChemProp":"chemprop"}
175
+ models_inversed = {v: k for k, v in models.items()}
176
+
177
+ model_names = list(models.keys())
178
+
179
+ endpoints = ["pIC50"]
180
+
181
+ # Select a target value from the preset list
182
+ target_value = st.selectbox("Select a biological target ", targets, key="target")
183
+ # endpoints
184
+
185
+
186
+ # Select a target value from the preset list
187
+ endpoint_value = st.selectbox("Select a property ", endpoints, key="endpoint")
188
+
189
+ model_value = st.selectbox("Select a model type ", model_names, key="model")
190
+
191
+
192
+ if target_value != "CYP3A4":
193
+ st.write("Only CYP3A4 is currently supported")
194
+ st.stop()
195
+
196
+ if endpoint_value != "pIC50":
197
+ st.write("Only pIC50 is currently supported")
198
+ st.stop()
199
+
200
+ if model_value != "ecfp:4 Mordred LGBM":
201
+ st.write("Only ecfp:4 Mordred LGBM is currently supported")
202
+ st.stop()
203
+
204
+ model, featurizer = get_model("./models", target_value, models[model_value])
205
+
206
+
207
+
208
+
209
+ if model is None:
210
+ st.write(f"No model found for {target_value} {endpoint_value}")
211
+ st.stop()
212
+ # retry with a different target or endpoint
213
+
214
+ st.markdown("## Prediction 🚀")
215
+
216
+
217
+ st.write(
218
+ f"Predicting **{target_value} {endpoint_value}** using model:\n\n `{model_value}`"
219
+ )
220
+
221
+ # featurize the smiles
222
+ X, _ = featurizer.featurize(smiles_column)
223
+ # predict the properties
224
+ preds = model.predict(X)
225
+
226
+ # not implemented yet
227
+ err = None
228
+
229
+
230
+ pred_column_name = f"{target_value}_computed-{endpoint_value}"
231
+ unc_column_name = f"{target_value}_computed-{endpoint_value}_uncertainty"
232
+ queried_df[pred_column_name] = preds
233
+ queried_df[unc_column_name] = err
234
+
235
+ st.markdown("---")
236
+ if multismiles:
237
+ # plot the predictions and errors
238
+ # Histogram first
239
+ fig, ax = plt.subplots()
240
+
241
+ sorted_df = queried_df.sort_values(by=pred_column_name)
242
+ n_bins = int(len(sorted_df[pred_column_name]) / 10)
243
+ if n_bins < 5: # makes the histogram slightly more interpretable with low data
244
+ n_bins = 5
245
+
246
+ ax.hist(sorted_df[pred_column_name], bins=n_bins)
247
+
248
+ ax.set_ylabel("Count")
249
+ ax.set_xlabel(f"Computed {endpoint_value}")
250
+ ax.set_title(f"Histogram of computed {endpoint_value} for target: {target_value}")
251
+
252
+ st.pyplot(fig)
253
+
254
+ # then a barplot
255
+ fig, ax = plt.subplots()
256
+
257
+ ax.bar(range(len(sorted_df)), sorted_df[pred_column_name])
258
+
259
+ ax.set_xticks([])
260
+ ax.set_xlabel(f"Query compounds")
261
+ ax.set_ylabel(f"Computed {endpoint_value}")
262
+
263
+ ax.set_title(f"Barplot of computed {endpoint_value} for target: {target_value}")
264
+
265
+ st.pyplot(fig)
266
+
267
+ # if endpoint_value == "pIC50":
268
+ # from rdkit.Chem.Descriptors import MolWt
269
+ # import seaborn as sns
270
+
271
+ # # then a scatterplot of uncertainty vs MW
272
+ # queried_df["MW"] = [
273
+ # MolWt(Chem.MolFromSmiles(smi)) for smi in sorted_df[smiles_column_name]
274
+ # ]
275
+ # fig, ax = plt.subplots()
276
+
277
+ # ax = sns.scatterplot(
278
+ # x="MW",
279
+ # y=pred_column_name,
280
+ # hue=unc_column_name,
281
+ # palette="coolwarm",
282
+ # data=queried_df,
283
+ # )
284
+
285
+ # norm = plt.Normalize(
286
+ # queried_df[unc_column_name].min(), queried_df[unc_column_name].max()
287
+ # )
288
+ # sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
289
+ # sm.set_array([])
290
+
291
+ # # Remove the legend and add a colorbar
292
+ # cbar = ax.figure.colorbar(sm, ax=ax)
293
+ # ax.annotate(
294
+ # f"Computed {endpoint_value} uncertainty",
295
+ # xy=(1.2, 0.3),
296
+ # xycoords="axes fraction",
297
+ # rotation=270,
298
+ # )
299
+
300
+ # ax.set_title(
301
+ # f"Scatterplot of predicted {endpoint_value} versus MW\ntarget: {target_value}"
302
+ # )
303
+ # ax.set_xlabel(f"Molecular weight (Da)")
304
+ # ax.set_ylabel(f"Computed {endpoint_value}")
305
+ # st.pyplot(fig)
306
+
307
+ else:
308
+ # just print the prediction
309
+ preds = queried_df[pred_column_name].values[0]
310
+ smiles = queried_df["SMILES"].values[0]
311
+ if err:
312
+ err = queried_df[unc_column_name].values[0]
313
+ errstr = f"± {err:.2f}"
314
+ else:
315
+ errstr = ""
316
+
317
+ st.markdown(
318
+ f"Predicted {target_value} {endpoint_value} for {smiles} is {preds:.2f} {errstr}."
319
+ )
320
+
321
+ # allow the user to download the predictions
322
+ csv = convert_df(queried_df)
323
+ st.download_button(
324
+ label="Download data as CSV",
325
+ data=csv,
326
+ file_name=f"predictions_{model_value}.csv",
327
+ mime="text/csv",
328
+ )
models/ecfp4_mordred_lgbm/cyp3a4_model.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_params": {
3
+ "boosting_type": "gbdt",
4
+ "class_weight": null,
5
+ "colsample_bytree": 1.0,
6
+ "importance_type": "split",
7
+ "learning_rate": 0.05,
8
+ "max_depth": -1,
9
+ "min_child_samples": 20,
10
+ "min_child_weight": 0.001,
11
+ "min_split_gain": 0.0,
12
+ "n_estimators": 500,
13
+ "n_jobs": null,
14
+ "num_leaves": 31,
15
+ "objective": null,
16
+ "random_state": null,
17
+ "reg_alpha": 0.0,
18
+ "reg_lambda": 0.0,
19
+ "subsample": 1.0,
20
+ "subsample_for_bin": 200000,
21
+ "subsample_freq": 0,
22
+ "alpha": 0.005
23
+ }
24
+ }
models/ecfp4_mordred_lgbm/cyp3a4_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4c4dd3ba5e2deb461c7471ad8c6e957044e81225fe1f55faebe3ee5abc2e3e0
3
+ size 1605595