Spaces:
Sleeping
Sleeping
Commit
·
079a08e
1
Parent(s):
63ce440
Introducing team-specific defaults, immediate about page loading, and validation display on predict page
Browse files- app.py +3 -1
- pages/input.py +15 -0
- pages/predict.py +44 -4
- pages/validate.py +10 -0
app.py
CHANGED
|
@@ -155,4 +155,6 @@ def check_password():
|
|
| 155 |
menu() # Render the dynamic menu!
|
| 156 |
|
| 157 |
if not check_password():
|
| 158 |
-
st.stop()
|
|
|
|
|
|
|
|
|
| 155 |
menu() # Render the dynamic menu!
|
| 156 |
|
| 157 |
if not check_password():
|
| 158 |
+
st.stop()
|
| 159 |
+
|
| 160 |
+
st.switch_page("pages/about.py")
|
pages/input.py
CHANGED
|
@@ -43,6 +43,17 @@ if "query" not in st.session_state:
|
|
| 43 |
source_node_index = 0
|
| 44 |
target_node_type_index = 0
|
| 45 |
relation_index = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
else:
|
| 47 |
source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
|
| 48 |
source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
|
|
@@ -88,6 +99,10 @@ if st.button("Submit Query"):
|
|
| 88 |
"relation": list(relation_options)
|
| 89 |
}
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# # Write query to console
|
| 92 |
# st.write("Current Query:")
|
| 93 |
# st.write(st.session_state.query)
|
|
|
|
| 43 |
source_node_index = 0
|
| 44 |
target_node_type_index = 0
|
| 45 |
relation_index = 0
|
| 46 |
+
|
| 47 |
+
if st.session_state.team == "Clalit":
|
| 48 |
+
source_node_type_index = 2
|
| 49 |
+
source_node_index = 0
|
| 50 |
+
target_node_type_index = 3
|
| 51 |
+
relation_index = 2
|
| 52 |
+
|
| 53 |
+
if st.session_state.team == "ASAP":
|
| 54 |
+
source_node_type_index = 2
|
| 55 |
+
source_node_index = 10255
|
| 56 |
+
|
| 57 |
else:
|
| 58 |
source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
|
| 59 |
source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
|
|
|
|
| 99 |
"relation": list(relation_options)
|
| 100 |
}
|
| 101 |
|
| 102 |
+
# Delete validation from session state
|
| 103 |
+
if "validation" in st.session_state:
|
| 104 |
+
del st.session_state.validation
|
| 105 |
+
|
| 106 |
# # Write query to console
|
| 107 |
# st.write("Current Query:")
|
| 108 |
# st.write(st.session_state.query)
|
pages/predict.py
CHANGED
|
@@ -143,6 +143,37 @@ with st.spinner('Computing predictions...'):
|
|
| 143 |
# Add URLs to database column
|
| 144 |
display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
# NODE SEARCH
|
| 148 |
|
|
@@ -152,11 +183,16 @@ with st.spinner('Computing predictions...'):
|
|
| 152 |
|
| 153 |
# Filter nodes
|
| 154 |
if len(selected_nodes) > 0:
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# Show filtered nodes
|
| 158 |
if target_node_type not in ['disease', 'anatomy']:
|
| 159 |
-
st.dataframe(selected_display_data, use_container_width = True,
|
| 160 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
| 161 |
help = "Click to visit external database.",
|
| 162 |
display_text = display_database)})
|
|
@@ -185,14 +221,18 @@ with st.spinner('Computing predictions...'):
|
|
| 185 |
# Show top ranked nodes
|
| 186 |
st.subheader("Model Predictions", divider = "blue")
|
| 187 |
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
if target_node_type not in ['disease', 'anatomy']:
|
| 190 |
-
st.dataframe(
|
| 191 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
| 192 |
help = "Click to visit external database.",
|
| 193 |
display_text = display_database)})
|
| 194 |
else:
|
| 195 |
-
st.dataframe(
|
| 196 |
|
| 197 |
# Save to session state
|
| 198 |
st.session_state.predictions = display_data
|
|
|
|
| 143 |
# Add URLs to database column
|
| 144 |
display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
|
| 145 |
|
| 146 |
+
# Check if validation data exists
|
| 147 |
+
if 'validation' in st.session_state:
|
| 148 |
+
|
| 149 |
+
# Checkbox to allow reverse edges
|
| 150 |
+
show_val = st.checkbox("Show Ground Truth Validation?", value = False)
|
| 151 |
+
|
| 152 |
+
if show_val:
|
| 153 |
+
|
| 154 |
+
# Get validation data
|
| 155 |
+
val_results = st.session_state.validation.copy()
|
| 156 |
+
|
| 157 |
+
# Merge with predictions
|
| 158 |
+
val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how='left')
|
| 159 |
+
val_display_data = val_display_data.fillna(0).drop(columns='y_id')
|
| 160 |
+
|
| 161 |
+
# Get new columns
|
| 162 |
+
val_relations = val_display_data.columns.difference(display_data.columns).tolist()
|
| 163 |
+
|
| 164 |
+
# Replace 0 with blank and 1 with check emoji in new columns
|
| 165 |
+
for col in val_relations:
|
| 166 |
+
val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})
|
| 167 |
+
|
| 168 |
+
# Define a function to apply styles
|
| 169 |
+
def style_val(val):
|
| 170 |
+
if val == '✅':
|
| 171 |
+
return 'background-color: #C2EABD;' # text-align: center;
|
| 172 |
+
return 'background-color: #F5F5F5;' # text-align: center;
|
| 173 |
+
|
| 174 |
+
else:
|
| 175 |
+
show_val = False
|
| 176 |
+
|
| 177 |
|
| 178 |
# NODE SEARCH
|
| 179 |
|
|
|
|
| 183 |
|
| 184 |
# Filter nodes
|
| 185 |
if len(selected_nodes) > 0:
|
| 186 |
+
|
| 187 |
+
if show_val:
|
| 188 |
+
# selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
|
| 189 |
+
selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].style.map(style_val, subset=val_relations)
|
| 190 |
+
else:
|
| 191 |
+
selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
|
| 192 |
|
| 193 |
# Show filtered nodes
|
| 194 |
if target_node_type not in ['disease', 'anatomy']:
|
| 195 |
+
st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
|
| 196 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
| 197 |
help = "Click to visit external database.",
|
| 198 |
display_text = display_database)})
|
|
|
|
| 221 |
# Show top ranked nodes
|
| 222 |
st.subheader("Model Predictions", divider = "blue")
|
| 223 |
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
|
| 224 |
+
|
| 225 |
+
# Show full results
|
| 226 |
+
# full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
|
| 227 |
+
full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
|
| 228 |
|
| 229 |
if target_node_type not in ['disease', 'anatomy']:
|
| 230 |
+
st.dataframe(full_results, use_container_width = True, hide_index = True,
|
| 231 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
| 232 |
help = "Click to visit external database.",
|
| 233 |
display_text = display_database)})
|
| 234 |
else:
|
| 235 |
+
st.dataframe(full_results, use_container_width = True, hide_index = True,)
|
| 236 |
|
| 237 |
# Save to session state
|
| 238 |
st.session_state.predictions = display_data
|
pages/validate.py
CHANGED
|
@@ -65,6 +65,16 @@ with st.spinner('Searching known relationships...'):
|
|
| 65 |
# If there exist edges in KG
|
| 66 |
if len(edges_in_kg) > 0:
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
with st.spinner('Plotting known relationships...'):
|
| 69 |
|
| 70 |
# Define a color map for different relations
|
|
|
|
| 65 |
# If there exist edges in KG
|
| 66 |
if len(edges_in_kg) > 0:
|
| 67 |
|
| 68 |
+
with st.spinner('Saving validation results...'):
|
| 69 |
+
|
| 70 |
+
# Cast long to wide
|
| 71 |
+
val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
|
| 72 |
+
val_results = (val_results > 0).astype(int).reset_index()
|
| 73 |
+
val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
|
| 74 |
+
|
| 75 |
+
# Save validation results to session state
|
| 76 |
+
st.session_state.validation = val_results
|
| 77 |
+
|
| 78 |
with st.spinner('Plotting known relationships...'):
|
| 79 |
|
| 80 |
# Define a color map for different relations
|