File size: 5,277 Bytes
7169e4d
2ef204f
7169e4d
2ef204f
7169e4d
2ef204f
7169e4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef204f
 
 
 
 
 
 
 
 
7169e4d
 
2ef204f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d42f96
2ef204f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d42f96
2ef204f
 
 
6d42f96
2ef204f
 
 
 
 
6d42f96
2ef204f
 
 
 
 
 
 
 
 
 
 
 
 
 
7169e4d
2ef204f
 
7169e4d
2ef204f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7169e4d
2ef204f
 
7169e4d
2ef204f
 
6d42f96
 
7169e4d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
import os
import pandas as pd
import plotly.express as px
import numpy as np
from predict import process_target_data, get_average_embedding  # Import your function


# Configure the Streamlit page title and icon for the CoV-SNN app.
st.set_page_config(
    page_title="CoV-SNN",
    page_icon="🧬",
)

# Markdown file rendered at the bottom of the page.
_INSTRUCTIONS_PATH = "INSTRUCTIONS.md"
# Pre-generated reference embedding used when "Omicron" is selected.
_OMICRON_EMBEDDING_PATH = "average_omicron_embedding.npy"
# Columns each uploaded CSV must provide (as stated in the uploader labels).
_REFERENCE_COLUMNS = ("sequence",)
_TARGET_COLUMNS = ("accession_id", "sequence")
# Result columns shown in the results table below the plot.
_TABLE_COLUMNS = [
    "accession_id", "log10(sc)", "log10(sp)", "log10(ip)",
    "rank_by_sc", "rank_by_sp", "rank_by_ip", "rank_by_scsp", "rank_by_scip",
]


def _load_instructions():
    """Return the instructions markdown text, or a placeholder if the file is missing."""
    try:
        with open(_INSTRUCTIONS_PATH, "r", encoding="utf-8") as instructions_file:
            return instructions_file.read()
    except FileNotFoundError:
        return "INSTRUCTIONS.md file not found."


def _read_uploaded_csv(uploaded_file, label, required_columns):
    """Parse an uploaded CSV and verify it has data and the required columns.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.
        label: Human-readable name of the upload, used in error messages.
        required_columns: Column names that must be present.

    Returns:
        The parsed DataFrame, or ``None`` after reporting the problem with
        ``st.error`` (so the rest of the page can still render).
    """
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the {label} CSV file: {err}")
        return None

    missing = [column for column in required_columns if column not in df.columns]
    if missing:
        st.error(f"The {label} CSV file is missing required column(s): {', '.join(missing)}.")
        return None
    if df.empty:
        st.error(f"The {label} CSV file contains no rows.")
        return None
    return df


def _load_reference_embedding(option, reference_file):
    """Return the reference embedding for the selected option, or ``None`` on failure.

    "Omicron" loads the pre-generated ``.npy`` file; any other option averages
    the embeddings of the uploaded reference sequences via ``get_average_embedding``.
    """
    if option == "Omicron":
        try:
            embedding = np.load(_OMICRON_EMBEDDING_PATH)
        except FileNotFoundError:
            st.error(f"Pre-generated embedding file {_OMICRON_EMBEDDING_PATH} not found.")
            return None
        print(f"Average Omicron embedding loaded from file with shape {embedding.shape}")
        return embedding

    ref_df = _read_uploaded_csv(reference_file, "reference", _REFERENCE_COLUMNS)
    if ref_df is None:
        return None
    with st.spinner('Calculating average embedding...'):
        return get_average_embedding(ref_df)


def _build_figure(results_df):
    """Build the log10(gr) vs log10(sc) scatter plot, colored by escape potential."""
    fig = px.scatter(
        # Round numeric columns so hover labels stay readable; text columns are untouched.
        results_df.round(6),
        x="log10(gr)",
        y="log10(sc)",
        labels={"log10(gr)": "log10(gr)", "log10(sc)": "log10(sc)"},
        title="CoV-SNN Results",
        hover_name="accession_id",
        color="Escape Potential",
        color_continuous_scale=["green", "yellow", "red"],
        hover_data={
            "log10(sp)": True,
            "log10(sc)": True,
            "log10(ip)": True,
            # Raw (non-log) scores are hidden from the hover label.
            "sp": False,
            "sc": False,
            "ip": False,
            "rank_by_sc": True,
            "rank_by_sp": True,
            "rank_by_ip": True,
            "rank_by_scsp": True,
            "rank_by_scip": True,
            # Already conveyed by the marker color.
            "Escape Potential": False,
        },
    )

    # Hide the colorbar ticks and labels; the scale is only a relative indicator.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    # Replace the colorbar title with a rotated annotation to the right of the plot;
    # the extra right margin leaves room for it.
    fig.update_layout(
        margin=dict(r=110),
        annotations=[
            dict(
                text="Escape Potential",
                font_size=14,
                textangle=270,
                showarrow=False,
                xref="paper",
                yref="paper",
                x=1.14,
                y=0.5,
            )
        ],
    )
    return fig


def _run_prediction(option, reference_file, target_file):
    """Validate the uploads, run the prediction and render the plot and results table."""
    # Validate the (cheap) target upload before any expensive embedding work.
    target_dataset = _read_uploaded_csv(target_file, "target", _TARGET_COLUMNS)
    if target_dataset is None:
        return

    average_embedding = _load_reference_embedding(option, reference_file)
    if average_embedding is None:
        return

    with st.spinner('Predicting escape potentials...'):
        results_df = process_target_data(average_embedding, target_dataset)

    if results_df.empty:
        st.warning("No predictions were produced for the uploaded target sequences.")
        return

    # Invert rank_by_scip (max rank + 1 - rank) so the best-ranked sequence
    # gets the highest "Escape Potential" and is colored red.
    results_df['Escape Potential'] = results_df['rank_by_scip'].max() + 1 - results_df['rank_by_scip']

    st.plotly_chart(_build_figure(results_df), theme="streamlit", use_container_width=True)
    st.dataframe(results_df[_TABLE_COLUMNS], hide_index=True)


def main():
    """Render the CoV-SNN Streamlit page.

    Lets the user pick a reference embedding (pre-generated Omicron or one
    averaged from uploaded sequences), upload target sequences, and view the
    predicted escape potentials as a scatter plot and a table. The
    instructions markdown is always rendered at the bottom.
    """
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    readme_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=[
            "Use average embedding of Omicron sequences (Pre-generated)",
            "Generate average embedding of your own sequences (Takes longer)",
        ],
    )

    # The reference upload is only relevant when the user supplies their own sequences.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )

    # Targets can only be uploaded once a reference source is available.
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _run_prediction(option, reference_file, target_file)

    st.markdown(readme_text)


if __name__ == "__main__":
    main()