Spaces:
Running
Running
File size: 5,277 Bytes
7169e4d 2ef204f 7169e4d 2ef204f 7169e4d 2ef204f 7169e4d 2ef204f 7169e4d 2ef204f 6d42f96 2ef204f 6d42f96 2ef204f 6d42f96 2ef204f 6d42f96 2ef204f 7169e4d 2ef204f 7169e4d 2ef204f 7169e4d 2ef204f 7169e4d 2ef204f 6d42f96 7169e4d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import streamlit as st
import os
import pandas as pd
import plotly.express as px
import numpy as np
from predict import process_target_data, get_average_embedding # Import your function
# Browser-tab metadata for the app. Kept at module level, ahead of main(),
# so it runs before any other st.* call in this script.
st.set_page_config(
    page_icon="🧬",
    page_title="CoV-SNN",
)
def _load_instructions(path="INSTRUCTIONS.md"):
    """Return the Markdown text shown under the app, or a placeholder message.

    Args:
        path: Markdown file to read, relative to the current working directory.

    Returns:
        The file contents, or ``"<path> file not found."`` when it is missing.
    """
    try:
        # Explicit UTF-8: the platform default encoding is locale-dependent,
        # so non-ASCII characters could otherwise fail to decode.
        with open(path, "r", encoding="utf-8") as readme_file:
            return readme_file.read()
    except FileNotFoundError:
        return f"{path} file not found."


def _read_uploaded_csv(uploaded_file, required_columns, label):
    """Parse an uploaded CSV and validate it before it reaches the model.

    Args:
        uploaded_file: File-like object returned by ``st.file_uploader``.
        required_columns: Column names that must be present in the CSV.
        label: Human-readable name of the upload, used in error messages.

    Returns:
        The parsed DataFrame, or ``None`` after showing ``st.error`` when the
        file cannot be parsed, lacks a required column, or has no rows.
    """
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the {label} CSV file: {err}")
        return None
    missing = [column for column in required_columns if column not in df.columns]
    if missing:
        st.error(
            f"The {label} CSV file is missing required column(s): {', '.join(missing)}"
        )
        return None
    if df.empty:
        st.error(f"The {label} CSV file contains no rows.")
        return None
    return df


def _load_reference_embedding(option, reference_file):
    """Return the reference (average) embedding for the selected option.

    ``"Omicron"`` loads the pre-generated ``average_omicron_embedding.npy``;
    any other option averages the uploaded reference sequences via
    ``get_average_embedding``.

    Args:
        option: Value chosen in the reference-embedding radio button.
        reference_file: Uploaded reference CSV (only read when option is not
            ``"Omicron"``).

    Returns:
        The embedding, or ``None`` after showing ``st.error`` if it could not
        be obtained.
    """
    if option == "Omicron":
        try:
            average_embedding = np.load("average_omicron_embedding.npy")
        except FileNotFoundError:
            st.error("Pre-generated Omicron embedding (average_omicron_embedding.npy) not found.")
            return None
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
        return average_embedding

    with st.spinner('Calculating average embedding...'):
        ref_df = _read_uploaded_csv(reference_file, ["sequence"], "reference")
        if ref_df is None:
            return None
        return get_average_embedding(ref_df)


def _build_scatter(results_df):
    """Return the log10(gr) vs log10(sc) scatter plot coloured by escape potential.

    Args:
        results_df: Output of ``process_target_data`` with an added
            ``"Escape Potential"`` column.
    """
    fig = px.scatter(
        # Round numeric columns so hover labels stay short. DataFrame.round
        # leaves non-numeric columns untouched and replaces the per-cell
        # applymap, which is deprecated since pandas 2.1.
        results_df.round(6),
        x="log10(gr)",
        y="log10(sc)",
        labels={"log10(gr)": "log10(gr)", "log10(sc)": "log10(sc)"},
        title="CoV-SNN Results",
        hover_name="accession_id",
        color="Escape Potential",
        color_continuous_scale=["green", "yellow", "red"],
        hover_data={
            "log10(sp)": True,
            "log10(sc)": True,
            "log10(ip)": True,
            # Raw (non-log) scores are explicitly hidden from the tooltip.
            "sp": False,
            "sc": False,
            "ip": False,
            "rank_by_sc": True,
            "rank_by_sp": True,
            "rank_by_ip": True,
            "rank_by_scsp": True,
            "rank_by_scip": True,
            # Hidden: the point colour already conveys this derived value.
            "Escape Potential": False,
        },
    )
    # The colour bar encodes an inverted rank, so its numeric ticks are hidden.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )
    # A rotated annotation just right of the plot area (x > 1 in paper
    # coordinates) acts as the colour-bar title; the right margin makes room.
    fig.update_layout(
        margin=dict(r=110),
        annotations=[
            dict(
                text="Escape Potential",
                font_size=14,
                textangle=270,
                showarrow=False,
                xref="paper",
                yref="paper",
                x=1.14,
                y=0.5,
            )
        ],
    )
    return fig


def _render_results(option, reference_file, target_file):
    """Score the uploaded target sequences and show the plot and results table.

    Returns early (after an ``st.error`` shown by a helper) when the reference
    embedding or the target CSV cannot be loaded.
    """
    average_embedding = _load_reference_embedding(option, reference_file)
    if average_embedding is None:
        return

    with st.spinner('Predicting escape potentials...'):
        target_dataset = _read_uploaded_csv(target_file, ["accession_id", "sequence"], "target")
        if target_dataset is None:
            return
        results_df = process_target_data(average_embedding, target_dataset)

    # Invert rank_by_scip so the smallest rank gets the largest value and
    # therefore the red end of the green-yellow-red colour scale.
    results_df['Escape Potential'] = results_df['rank_by_scip'].max() + 1 - results_df['rank_by_scip']

    # st.plotly_chart has no border options, so only supported arguments are passed.
    st.plotly_chart(_build_scatter(results_df), theme="streamlit", use_container_width=True)

    st.dataframe(
        results_df[[
            "accession_id", "log10(sc)", "log10(sp)", "log10(ip)",
            "rank_by_sc", "rank_by_sp", "rank_by_ip", "rank_by_scsp", "rank_by_scip",
        ]],
        hide_index=True,
    )


def main():
    """Render the CoV-SNN Streamlit page.

    Shows the reference-embedding choice and the two CSV uploaders. Once a
    target file (and, for the "Other" option, a reference file) is available,
    it runs the prediction and displays the results. The instructions from
    INSTRUCTIONS.md are always rendered at the bottom.
    """
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    readme_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=[
            "Use average embedding of Omicron sequences (Pre-generated)",
            "Generate average embedding of your own sequences (Takes longer)",
        ],
    )

    # The reference upload is only needed when the user supplies their own sequences.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )
    # Targets stay locked until a reference is available for the "Other" option.
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _render_results(option, reference_file, target_file)

    st.markdown(readme_text)
# Script entry point: `streamlit run` executes this file with
# __name__ == "__main__", so main() renders the page.
if __name__ == "__main__":
    main()
|