import os
import tiger
import cas9on
import cas9off
import cas12
import pandas as pd
import streamlit as st
import plotly.graph_objs as go
from pygenomeviz import Genbank, GenomeViz
import numpy as np
from pathlib import Path
import zipfile
import io


# Model choices offered in the top-level picker, and the weight files (.h5)
# handed to the Cas9 on-target and Cas12 predictors further down.
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
cas9on_path = 'cas9_model/on-cla.h5'
cas12_path = 'cas12_model/Seq_deepCpf1_weights.h5'

# Page header: the tool's markdown documentation followed by a divider.
st.markdown(Path('crisprTool.md').read_text(), unsafe_allow_html=True)
st.divider()

# The chosen model selects which branch of the page below is rendered.
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')

@st.cache_data
def convert_df(df):
    """Serialise ``df`` to UTF-8 encoded CSV bytes for ``st.download_button``.

    Wrapped in ``st.cache_data`` so the same DataFrame is not re-converted on
    every Streamlit rerun.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')


def mode_change_callback():
    """``on_change`` callback for the Cas13d run-mode radio.

    Off-target search is not offered for the 'all' and 'titration' run modes
    (titration support is still a TODO). For those modes it is switched off and
    its checkbox disabled; for any other mode the checkbox is re-enabled and the
    user's previous choice is left untouched.
    """
    off_target_unsupported = st.session_state.mode in (tiger.RUN_MODES['all'], tiger.RUN_MODES['titration'])
    st.session_state.disable_off_target_checkbox = off_target_unsupported
    if off_target_unsupported:
        st.session_state.check_off_targets = False


def progress_update(update_text, percent_complete):
    """Show a status line and progress bar inside the global ``progress`` placeholder.

    Passed as ``status_update_fn`` to the model runners. ``percent_complete``
    is divided by 100 before being handed to ``st.progress``, so it is presumably
    on a 0-100 scale.

    NOTE(review): ``progress`` is a module-level ``st.empty()`` that is only
    created inside the Cas9 off-target and Cas13d branches; calling this before
    that assignment raises NameError.
    """
    with progress.container():
        st.write(update_text)
        fraction_done = percent_complete / 100
        st.progress(fraction_done)


def initiate_run():
    """Validate the Cas13d transcript input and queue it for a TIGER run.

    ``on_click`` callback of the "Get predictions!" button. It does the following:

    * Resets all previous result state.
    * Builds a transcript DataFrame from the manual-entry text box or the
      uploaded fasta file.
    * Normalises sequences to upper-case DNA (U -> T) and validates them.

    On success the DataFrame is stored in ``st.session_state.transcripts``, which
    the main script later passes to ``tiger.tiger_exhibit``. On failure
    ``st.session_state.input_error`` holds a user-facing message instead.

    Relies on the module-level ``ENTRY_METHODS`` dict (defined in the Cas13d
    branch) having ``'manual'`` and ``'fasta'`` keys.
    """
    import tempfile  # only needed for the fasta-upload path

    # reset state from any previous run
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    # empty default: if no input is provided, nothing is queued below
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        uploaded = st.session_state.fasta_entry
        if uploaded is not None:
            # tiger.load_transcripts takes file paths, so the upload must be spilled to disk.
            # A private temp file is used rather than the browser-supplied name in the CWD:
            # concurrent sessions cannot clobber each other, and nothing is written
            # outside the temp directory. The original suffix is kept in case the loader
            # looks at the extension.
            fd, fasta_path = tempfile.mkstemp(suffix=Path(uploaded.name).suffix)
            try:
                with os.fdopen(fd, 'w', encoding='utf-8') as f:
                    f.write(uploaded.getvalue().decode('utf-8'))
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
            finally:
                # always remove the temp file, even if decoding or loading fails
                os.remove(fasta_path)

    # convert to upper case as used by tokenizer (RNA U is treated as DNA T)
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # queue the model run if we have any transcripts
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts

def parse_gene_annotations(file_path, symbol_col='Approved symbol', id_col='Ensembl gene ID',
                           delimiter='\t'):
    """Build a ``{gene symbol: Ensembl gene ID}`` mapping from a delimited annotation table.

    The first line must be a header that names the columns. Rows are handled as follows:

    * Rows too short to reach both columns are skipped.
    * Rows with an empty symbol are skipped.
    * A row with an empty ID maps its symbol to ``''``.
    * A later row with the same symbol overwrites an earlier one.

    Args:
        file_path: path to the annotation file (read as UTF-8; a BOM is tolerated).
        symbol_col: header of the gene-symbol column.
        id_col: header of the Ensembl gene ID column.
        delimiter: field separator (tab by default).

    Returns:
        dict mapping symbol -> Ensembl gene ID, in file order.

    Raises:
        ValueError: if either column header is missing from the first line.
    """
    gene_dict = {}
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        # Strip only the line terminator. str.strip() would also eat leading/trailing
        # delimiter characters (tabs), shifting the columns of any row whose first
        # field is empty.
        headers = [h.strip() for h in file.readline().rstrip('\r\n').split(delimiter)]
        try:
            symbol_idx = headers.index(symbol_col)
            id_idx = headers.index(id_col)
        except ValueError as err:
            raise ValueError(
                f"{file_path}: header must contain columns {symbol_col!r} and {id_col!r}"
            ) from err
        min_fields = max(symbol_idx, id_idx) + 1
        for line in file:
            values = line.rstrip('\r\n').split(delimiter)
            if len(values) < min_fields:
                continue  # blank or truncated row
            symbol = values[symbol_idx].strip()
            # An empty symbol would collide with the '' "nothing selected" selectbox entry.
            if symbol:
                gene_dict[symbol] = values[id_idx].strip()
    return gene_dict

# Gene symbol -> Ensembl gene ID lookup, loaded from the HUGO gene annotation export.
gene_annotations = parse_gene_annotations('Human_genes_HUGO_02242024_annotation.txt')
# Symbols in file order; used as the options of the gene-symbol selectboxes below.
gene_symbol_list = [*gene_annotations]
# ---- Cas9: on-target (cas9on) or off-target (cas9off) prediction ----
if selected_model == 'Cas9':
    # A radio button guarantees exactly one of the two Cas9 modes is active.
    target_selection = st.radio(
        "Select either on-target or off-target:",
        ('on-target', 'off-target'),
        key='target_selection'
    )
    # Gene whose generated output files may still be on disk (see clean-up below).
    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    def clean_up_old_files(gene_symbol):
        """Delete the .gb/.bed/.csv output files previously written for ``gene_symbol``, if present."""
        stale_files = (
            f"{gene_symbol}_crispr_targets.gb",
            f"{gene_symbol}_crispr_targets.bed",
            f"{gene_symbol}_crispr_predictions.csv",
        )
        for stale_path in stale_files:
            if os.path.exists(stale_path):
                os.remove(stale_path)

    # Searchable dropdown of gene symbols; the leading '' entry means "nothing chosen".
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    # When a different non-empty gene is picked, drop the previous gene's files
    # (if there was a previous gene) and remember the new one.
    previous_symbol = st.session_state['current_gene_symbol']
    if gene_symbol and gene_symbol != previous_symbol:
        if previous_symbol:
            clean_up_old_files(previous_symbol)
        st.session_state['current_gene_symbol'] = gene_symbol

    if target_selection == 'on-target':
        # Prediction button
        predict_button = st.button('Predict on-target')

        # Exon list drawn as bars under the guide markers; empty until a prediction runs.
        if 'exons' not in st.session_state:
            st.session_state['exons'] = []

        # Run the model only on the rerun triggered by the button, with a gene selected.
        if predict_button and gene_symbol:
            with st.spinner('Predicting... Please wait'):
                predictions, gene_sequence, exons = cas9on.process_gene(gene_symbol, cas9on_path)
                # Keep the 10 best guides, ranked by the score in each row's last field.
                sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
                st.session_state['on_target_results'] = sorted_predictions
                st.session_state['gene_sequence'] = gene_sequence  # Save gene sequence in session state
                st.session_state['exons'] = exons  # Store exon data

            # Notify the user once the process is completed successfully.
            st.success('Prediction completed!')
            st.session_state['prediction_made'] = True

            if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
                ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')  # Get Ensembl ID or default to 'Unknown'
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.markdown("**Genome**")
                    st.markdown("Homo sapiens")
                with col2:
                    st.markdown("**Gene**")
                    st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
                with col3:
                    st.markdown("**Nuclease**")
                    st.markdown("SpCas9")
                # Include "Target" in the DataFrame's columns
                try:
                    df = pd.DataFrame(st.session_state['on_target_results'],
                                      columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"])
                    st.dataframe(df)
                except ValueError as e:
                    st.error(f"DataFrame creation error: {e}")
                    # Optionally print or log the problematic data for debugging:
                    print(st.session_state['on_target_results'])

                # Initialize Plotly figure
                fig = go.Figure()

                EXON_BASE = 0  # Base position for exons and CDS on the Y axis
                EXON_HEIGHT = 0.02  # How 'tall' the exon markers should appear

                # Plot Exons as small markers on the X-axis
                for exon in st.session_state['exons']:
                    exon_start, exon_end = exon['start'], exon['end']
                    fig.add_trace(go.Bar(
                        x=[(exon_start + exon_end) / 2],
                        y=[EXON_HEIGHT],
                        width=[exon_end - exon_start],
                        base=EXON_BASE,
                        marker_color='rgba(128, 0, 128, 0.5)',
                        name='Exon'
                    ))

                VERTICAL_GAP = 0.2  # Gap between different ranks

                # Rank-1 positions for each strand; these are also where the '+'/'-' ticks sit.
                MAX_STRAND_Y = 0.1  # Plus-strand baseline
                MIN_STRAND_Y = -0.1  # Minus-strand baseline

                # Iterate over top 5 sorted predictions to create the plot
                for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1):  # Only top 5
                    chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
                    midpoint = (int(start) + int(end)) / 2

                    # '1' / '+' (or the int 1) is the plus strand; anything else is the minus strand.
                    is_plus_strand = str(strand) in ('1', '+')

                    # Each lower rank moves one VERTICAL_GAP further *away* from the axis, so
                    # plus-strand markers stay above zero and minus-strand markers below it.
                    rank_offset = (i - 1) * VERTICAL_GAP
                    y_value = MAX_STRAND_Y + rank_offset if is_plus_strand else MIN_STRAND_Y - rank_offset

                    fig.add_trace(go.Scatter(
                        x=[midpoint],
                        y=[y_value],
                        mode='markers+text',
                        marker=dict(symbol='triangle-up' if is_plus_strand else 'triangle-down',
                                    size=12),
                        text=f"Rank: {i}",  # Text label
                        hoverinfo='text',
                        hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if is_plus_strand else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
                    ))

                # Update layout for clarity and interaction
                fig.update_layout(
                    title='Top 5 gRNA Sequences by Prediction Score',
                    xaxis_title='Genomic Position',
                    yaxis_title='Strand',
                    yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
                    showlegend=False,
                    hovermode='x unified',
                )

                # Display the plot
                st.plotly_chart(fig)

                # NOTE(review): the GenBank/BED/CSV ZIP export below is disabled; the cas9on
                # helpers it calls are not exercised anywhere in this file.
                # if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
                #     gene_symbol = st.session_state['current_gene_symbol']
                #     gene_sequence = st.session_state['gene_sequence']
                #
                #     # Define file paths
                #     genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
                #     bed_file_path = f"{gene_symbol}_crispr_targets.bed"
                #     csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
                #
                #     # Generate files
                #     cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
                #     cas9on.create_bed_file_from_df(df, bed_file_path)
                #     cas9on.create_csv_from_df(df, csv_file_path)
                #
                #     # Prepare an in-memory buffer for the ZIP file
                #     zip_buffer = io.BytesIO()
                #     with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                #         # For each file, add it to the ZIP file
                #         zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
                #         zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
                #         zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])
                #
                #     # Important: move the cursor to the beginning of the BytesIO buffer before reading it
                #     zip_buffer.seek(0)
                #
                #     # Display the download button for the ZIP file
                #     st.download_button(
                #         label="Download genbank,.bed,csv files as ZIP",
                #         data=zip_buffer.getvalue(),
                #         file_name=f"{gene_symbol}_files.zip",
                #         mime="application/zip"
                #     )

    elif target_selection == 'off-target':
        ENTRY_METHODS = dict(
            manual='Manual entry of target sequence',
            txt="txt file upload"
        )
        if __name__ == '__main__':
            # app initialization for Cas9 off-target
            if 'target_sequence' not in st.session_state:
                st.session_state.target_sequence = None
            if 'input_error' not in st.session_state:
                st.session_state.input_error = None
            if 'off_target_results' not in st.session_state:
                st.session_state.off_target_results = None

            # target sequence entry
            st.selectbox(
                label='How would you like to provide target sequences?',
                options=ENTRY_METHODS.values(),
                key='entry_method',
                disabled=st.session_state.target_sequence is not None
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter on/off sequences:',
                    key='manual_entry',
                    placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG',
                    disabled=st.session_state.target_sequence is not None
                )
            elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                st.file_uploader(
                    label='Upload a txt file:',
                    key='txt_entry',
                    disabled=st.session_state.target_sequence is not None
                )

            # prediction button
            if st.button('Predict off-target'):
                if st.session_state.entry_method == ENTRY_METHODS['manual']:
                    user_input = st.session_state.manual_entry
                    if user_input:  # Check if user_input is not empty
                        predictions = cas9off.process_input_and_predict(user_input, input_type='manual')
                elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                    uploaded_file = st.session_state.txt_entry
                    if uploaded_file is not None:
                        # Read the uploaded file content
                        file_content = uploaded_file.getvalue().decode("utf-8")
                        predictions = cas9off.process_input_and_predict(file_content, input_type='manual')

                st.session_state.off_target_results = predictions
            else:
                predictions = None
            progress = st.empty()

            # input error display
            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            # off-target results display
            off_target_results = st.empty()
            if st.session_state.off_target_results is not None:
                with off_target_results.container():
                    if len(st.session_state.off_target_results) > 0:
                        st.write('Off-target predictions:', st.session_state.off_target_results)
                        st.download_button(
                            label='Download off-target predictions',
                            data=convert_df(st.session_state.off_target_results),
                            file_name='off_target_results.csv',
                            mime='text/csv'
                        )
                    else:
                        st.write('No significant off-target effects detected!')
            else:
                off_target_results.empty()

            # running the CRISPR-Net model for off-target predictions
            if st.session_state.target_sequence is not None:
                st.session_state.off_target_results = cas9off.predict_off_targets(
                    target_sequence=st.session_state.target_sequence,
                    status_update_fn=progress_update
                )
                st.session_state.target_sequence = None
                st.experimental_rerun()

elif selected_model == 'Cas12':
    # Gene symbol entry with autocomplete-like feature
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    # Initialize the current_gene_symbol in the session state if it doesn't exist
    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    # Prediction button
    predict_button = st.button('Predict on-target')

    # Function to clean up old files
    def clean_up_old_files(gene_symbol):
        genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
        bed_file_path = f"{gene_symbol}_crispr_targets.bed"
        csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
        for path in [genbank_file_path, bed_file_path, csv_file_path]:
            if os.path.exists(path):
                os.remove(path)

    # Clean up files if a new gene symbol is entered
    if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
        clean_up_old_files(st.session_state['current_gene_symbol'])

    # Process predictions
    if predict_button and gene_symbol:
        # Update the current gene symbol
        st.session_state['current_gene_symbol'] = gene_symbol

        # Run the prediction process
        with st.spinner('Predicting... Please wait'):
            predictions, gene_sequence = cas12.process_gene(gene_symbol,cas12_path)
            sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
            st.session_state['on_target_results'] = sorted_predictions
        st.success('Prediction completed!')

        # Visualization and file generation
        if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
            df = pd.DataFrame(st.session_state['on_target_results'],
                              columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
            st.dataframe(df)
            # Now create a Plotly plot with the sorted_predictions
            fig = go.Figure()

            # Initialize the y position for the positive and negative strands
            positive_strand_y = 0.1
            negative_strand_y = -0.1

            # Use an offset to spread gRNA sequences vertically
            offset = 0.05

            # Iterate over the sorted predictions to create the plot
            for i, prediction in enumerate(sorted_predictions, start=1):
                # Extract data for plotting and convert start and end to integers
                chrom, start, end, strand, target, gRNA, pred_score = prediction
                start, end = int(start), int(end)
                midpoint = (start + end) / 2

                # Set the y-value and arrow symbol based on the strand
                if strand == '1':
                    y_value = positive_strand_y
                    arrow_symbol = 'triangle-right'
                    # Increment the y-value for the next positive strand gRNA
                    positive_strand_y += offset
                else:
                    y_value = negative_strand_y
                    arrow_symbol = 'triangle-left'
                    # Decrement the y-value for the next negative strand gRNA
                    negative_strand_y -= offset

                fig.add_trace(go.Scatter(
                    x=[midpoint],
                    y=[y_value],  # Use the y_value set above for the strand
                    mode='markers+text',
                    marker=dict(symbol=arrow_symbol, size=10),
                    name=f"gRNA: {gRNA}",
                    text=f"Rank: {i}",  # Place text at the marker
                    hoverinfo='text',
                    hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
                ))

            # Update the layout of the plot
            fig.update_layout(
                title='Top 10 gRNA Sequences by Prediction Score',
                xaxis_title='Genomic Position',
                yaxis=dict(
                    title='Strand',
                    showgrid=True,  # Show horizontal gridlines for clarity
                    zeroline=True,  # Show a line at y=0 to represent the axis
                    zerolinecolor='Black',
                    zerolinewidth=2,
                    tickvals=[positive_strand_y, negative_strand_y],
                    ticktext=['+ Strand', '- Strand']
                ),
                showlegend=False  # Hide the legend if it's not necessary
            )

            # Display the plot
            st.plotly_chart(fig)

            # Ensure gene_sequence is not empty before generating files
            if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
                gene_symbol = st.session_state['current_gene_symbol']
                gene_sequence = st.session_state['gene_sequence']

                # Define file paths
                genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
                bed_file_path = f"{gene_symbol}_crispr_targets.bed"
                csv_file_path = f"{gene_symbol}_crispr_predictions.csv"

                # Generate files
                cas12.generate_genbank_file_from_data(df, gene_sequence, gene_symbol, genbank_file_path)
                cas12.generate_bed_file_from_data(df, bed_file_path)
                cas12.create_csv_from_df(df, csv_file_path)

                # Prepare an in-memory buffer for the ZIP file
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                    # For each file, add it to the ZIP file
                    zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
                    zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
                    zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])

                # Important: move the cursor to the beginning of the BytesIO buffer before reading it
                zip_buffer.seek(0)

                # Display the download button for the ZIP file
                st.download_button(
                    label="Download genbank,.bed,csv files as ZIP",
                    data=zip_buffer.getvalue(),
                    file_name=f"{gene_symbol}_files.zip",
                    mime="application/zip"
                )

elif selected_model == 'Cas13d':
        # Input options for Cas13d (TIGER). This assignment is module-level, and
        # initiate_run() reads the global ENTRY_METHODS and expects these
        # 'manual' / 'fasta' keys.
        ENTRY_METHODS = dict(
        manual='Manual entry of single transcript',
        fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
        )

        # NOTE(review): presumably always true when launched via `streamlit run`,
        # which executes the script as __main__ — confirm if imported elsewhere.
        if __name__ == '__main__':
            # app initialization
            # Defaults for every session-state key read further down. The off-target
            # checkbox starts disabled because the default mode is RUN_MODES['all'],
            # the same rule mode_change_callback applies.
            if 'mode' not in st.session_state:
                st.session_state.mode = tiger.RUN_MODES['all']
                st.session_state.disable_off_target_checkbox = True
            if 'entry_method' not in st.session_state:
                st.session_state.entry_method = ENTRY_METHODS['manual']
            if 'transcripts' not in st.session_state:
                st.session_state.transcripts = None
            if 'input_error' not in st.session_state:
                st.session_state.input_error = None
            if 'on_target' not in st.session_state:
                st.session_state.on_target = None
            if 'titration' not in st.session_state:
                st.session_state.titration = None
            if 'off_target' not in st.session_state:
                st.session_state.off_target = None

            # mode selection
            # Inputs are disabled while transcripts are queued (transcripts is not None),
            # i.e. while a model run is pending.
            col1, col2 = st.columns([0.65, 0.35])
            with col1:
                st.radio(
                    label='What do you want to predict?',
                    options=tuple(tiger.RUN_MODES.values()),
                    key='mode',
                    on_change=mode_change_callback,
                    disabled=st.session_state.transcripts is not None,
                )
            with col2:
                st.checkbox(
                    label='Find off-target effects (slow)',
                    key='check_off_targets',
                    disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
                )

            # transcript entry
            st.selectbox(
                label='How would you like to provide transcript(s) of interest?',
                options=ENTRY_METHODS.values(),
                key='entry_method',
                disabled=st.session_state.transcripts is not None
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter a target transcript:',
                    key='manual_entry',
                    placeholder='Upper or lower case',
                    disabled=st.session_state.transcripts is not None
                )
            elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
                st.file_uploader(
                    label='Upload a fasta file:',
                    key='fasta_entry',
                    disabled=st.session_state.transcripts is not None
                )

            # let's go!
            # initiate_run validates the input and either queues it in
            # st.session_state.transcripts or sets st.session_state.input_error.
            st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
            # placeholder that progress_update() renders into during a run
            progress = st.empty()

            # input error
            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            # on-target results
            on_target_results = st.empty()
            if st.session_state.on_target is not None:
                with on_target_results.container():
                    st.write('On-target predictions:', st.session_state.on_target)
                    st.download_button(
                        label='Download on-target predictions',
                        data=convert_df(st.session_state.on_target),
                        file_name='on_target.csv',
                        mime='text/csv'
                    )
            else:
                on_target_results.empty()

            # titration results
            titration_results = st.empty()
            if st.session_state.titration is not None:
                with titration_results.container():
                    st.write('Titration predictions:', st.session_state.titration)
                    st.download_button(
                        label='Download titration predictions',
                        data=convert_df(st.session_state.titration),
                        file_name='titration.csv',
                        mime='text/csv'
                    )
            else:
                titration_results.empty()

            # off-target results
            off_target_results = st.empty()
            if st.session_state.off_target is not None:
                with off_target_results.container():
                    if len(st.session_state.off_target) > 0:
                        st.write('Off-target predictions:', st.session_state.off_target)
                        st.download_button(
                            label='Download off-target predictions',
                            data=convert_df(st.session_state.off_target),
                            file_name='off_target.csv',
                            mime='text/csv'
                        )
                    else:
                        st.write('We did not find any off-target effects!')
            else:
                off_target_results.empty()

            # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
            # Runs TIGER on the queued transcripts, stores the three result tables,
            # clears the queue (re-enabling the inputs) and reruns so the results
            # render in the placeholders above.
            # The dict comprehension inverts RUN_MODES (key -> label) to recover the
            # mode key from the label selected in the radio.
            # NOTE(review): st.experimental_rerun is deprecated in newer Streamlit
            # releases in favour of st.rerun — check the pinned version.
            if st.session_state.transcripts is not None:
                st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
                    transcripts=st.session_state.transcripts,
                    mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
                    check_off_targets=st.session_state.check_off_targets,
                    status_update_fn=progress_update
                )
                st.session_state.transcripts = None
                st.experimental_rerun()