from pyvis.network import Network
import os
import json
import gzip

# Colour of each node type. Used both for the pyvis node groups and for the
# legend. 'Pathway' and 'kegg_Pathway' share a colour; the legend skips
# 'kegg_Pathway' so that colour is listed only once.
NODE_TYPE_COLORS = {
    'Disease': '#079dbb',
    'HPO': '#58d0e8',
    'Drug': '#815ac0',
    'Compound': '#d2b7e5',
    'Domain': '#6bbf59',
    'GO_term_P': '#ff8800',
    'GO_term_F': '#ffaa00',
    'GO_term_C': '#ffc300',
    'Pathway': '#720026',
    'kegg_Pathway': '#720026',
    'EC_number': '#ce4257',
    'Protein': '#3aa6a4',
}

# Human-readable phrase shown in edge tooltips, keyed by relation name.
# GO relations are keyed by (relation, target node type) because the wording
# depends on which GO aspect (function / process / component) is targeted.
EDGE_LABEL_TRANSLATION = {
    'Orthology': 'is ortholog to',
    'Pathway': 'takes part in',
    'kegg_path_prot': 'takes part in',
    ('domain_function', 'GO_term_F'): 'enables',
    ('domain_function', 'GO_term_P'): 'is involved in',
    ('domain_function', 'GO_term_C'): 'localizes to',
    'function_function': 'ontological relationship',
    'protein_domain': 'has',
    'PPI': 'interacts with',
    'HPO': 'is associated with',
    'kegg_dis_prot': 'is related to',
    'Disease': 'is related to',
    'Drug': 'targets',
    'kegg_dis_path': 'modulates',
    'protein_ec': 'catalyzes',
    'hpodis': 'is associated with',
    'kegg_dis_drug': 'treats',
    'Chembl': 'targets',
    ('protein_function', 'GO_term_F'): 'enables',
    ('protein_function', 'GO_term_P'): 'is involved in',
    ('protein_function', 'GO_term_C'): 'localizes to',
}

# Display names for node types whose internal name is not reader-friendly.
# Types missing from this table are displayed under their raw name.
NODE_LABEL_TRANSLATION = {
    'HPO': 'Phenotype',
    'GO_term_P': 'Biological Process',
    'GO_term_F': 'Molecular Function',
    'GO_term_C': 'Cellular Component',
    'kegg_Pathway': 'Pathway',
    'EC_number': 'EC Number',
}

# Maps the values of the prediction DataFrame's 'GO_category' column to the
# graph's GO node types.
GO_CATEGORY_MAPPING = {
    'Biological Process': 'GO_term_P',
    'Molecular Function': 'GO_term_F',
    'Cellular Component': 'GO_term_C',
}

def get_node_url(node_type, node_id):
    """Return the external database page for a node, or None if there is none.

    Args:
        node_type: Graph node type, e.g. 'Protein', 'Disease' or 'GO_term_F'.
        node_id: Identifier of the node within its type.

    Returns:
        The URL string. None is returned for node types without a known
        database, and for 'Disease' IDs whose 'PREFIX:' is not one of EFO,
        MONDO or Orphanet.
    """
    # Every GO aspect (GO_term_P / GO_term_F / GO_term_C) resolves via QuickGO.
    if node_type.startswith('GO_term'):
        return f"https://www.ebi.ac.uk/QuickGO/term/{node_id}"

    if node_type == 'Disease':
        # Un-prefixed disease IDs go to the genome.jp entry page.
        if ':' not in node_id:
            return f"https://www.genome.jp/entry/{node_id}"
        parts = node_id.split(':')
        ontology_prefixes = {
            'EFO': "http://www.ebi.ac.uk/efo/EFO_",
            'MONDO': "http://purl.obolibrary.org/obo/MONDO_",
            'Orphanet': "http://www.orpha.net/ORDO/Orphanet_",
        }
        prefix = ontology_prefixes.get(parts[0])
        return prefix + parts[1] if prefix is not None else None

    # All remaining types map one-to-one onto a URL template.
    url_templates = {
        'Protein': "https://www.uniprot.org/uniprotkb/{}/entry",
        'HPO': "https://hpo.jax.org/browse/term/{}",
        'Drug': "https://go.drugbank.com/drugs/{}",
        'Compound': "https://www.ebi.ac.uk/chembl/explore/compound/{}",
        'Domain': "https://www.ebi.ac.uk/interpro/entry/InterPro/{}",
        'Pathway': "https://reactome.org/content/detail/{}",
        'kegg_Pathway': "https://www.genome.jp/pathway/{}",
        'EC_number': "https://enzyme.expasy.org/EC/{}",
    }
    template = url_templates.get(node_type)
    if template is None:
        return None
    return template.format(node_id)

def _gather_protein_edges(data, protein_id):
    """Collect every forward edge that touches the query protein, per edge type.

    Edge types whose relation name contains 'rev' are skipped. Every other
    edge type in ``data.edge_types`` gets an entry in the result:
    - a set of ``(source_id, target_id)`` tuples (original IDs, not indices)
      when the protein takes part in at least one edge of that type;
    - otherwise an empty list, including for edge types that do not involve
      'Protein' at all.

    NOTE(review): when both endpoints are 'Protein' (e.g. PPI), only edges
    with the query protein as *source* are collected. This is complete only
    if such relations are stored in both directions — confirm.

    Args:
        data: Heterogeneous graph exposing ``node_types``, ``edge_types``,
            ``data[node_type]['id_mapping']`` (ID -> index) and
            ``data[edge_type].edge_index`` (2 x E index tensor).
        protein_id: ID of the query protein; must be a key of the 'Protein'
            ``id_mapping`` (KeyError otherwise).

    Returns:
        dict mapping each non-'rev' edge type to a set of ID pairs or ``[]``.
    """
    query_idx = data['Protein']['id_mapping'][protein_id]
    # Invert each ID -> index mapping so edge indices can be turned back into IDs.
    idx_to_id = {
        ntype: {idx: ident for ident, idx in data[ntype]['id_mapping'].items()}
        for ntype in data.node_types
    }

    print(f'Gathering edges for {protein_id}...')

    collected = {}
    for etype in data.edge_types:
        if 'rev' in etype[1]:
            continue
        collected.setdefault(etype, [])
        # Pick the edge_index row (0 = sources, 1 = targets) that must hold the protein.
        if etype[0] == 'Protein':
            side = 0
        elif etype[2] == 'Protein':
            side = 1
        else:
            continue
        print(f'Gathering edges for {etype}...')
        index = data[etype].edge_index
        collected[etype].extend(index[:, index[side] == query_idx].T.tolist())

    # Replace non-empty index lists with de-duplicated sets of ID pairs.
    for etype, pairs in collected.items():
        if not pairs:
            continue
        collected[etype] = {
            (idx_to_id[etype[0]][src], idx_to_id[etype[2]][dst])
            for src, dst in pairs
        }
    return collected

def _filter_edges(protein_id, protein_edges, prediction_df, limit=10, is_second_degree=False, second_degree_limit=3):
    """
    Select which of a node's edges to draw and annotate each one for colouring.

    Edge types running from 'Protein' to a GO node type that has rows in
    ``prediction_df``: the node's top predictions for that GO category (by
    descending 'Probability') are shown. Each is flagged with whether it
    already exists in ``protein_edges`` (ground truth). This also happens
    when the node has no existing edge of that type, so novel predictions
    for un-annotated proteins are not dropped. If there are no predictions
    for the node, its existing edges are shown instead.

    All other GO edges are tagged ``'no_pred'``; non-GO edges are tagged
    ``None``. When existing edges are truncated to the limit, they are taken
    in sorted ID order. ``protein_edges`` values may be sets whose iteration
    order changes between interpreter runs, and sorting keeps the drawn
    subgraph reproducible.

    Args:
        protein_id: ID of the node whose edges are filtered (the query
            protein, or a first-degree neighbour for second-degree calls);
            matched against ``prediction_df['UniProt_ID']``.
        protein_edges: Mapping of edge type -> collection of
            ``(source_id, target_id)`` tuples; values may be empty or None.
        prediction_df: DataFrame with 'UniProt_ID', 'GO_category', 'GO_ID'
            and 'Probability' columns. Every 'GO_category' value must be a
            key of GO_CATEGORY_MAPPING (KeyError otherwise).
        limit: Maximum number of edges per edge type for first-degree calls.
        is_second_degree: Use ``second_degree_limit`` instead of ``limit``.
        second_degree_limit: Maximum number of edges per edge type for
            second-degree calls.

    Returns:
        dict mapping edge type -> list of ``(edge, probability, is_ground_truth)``
        where ``edge`` is a ``(source_id, target_id)`` tuple. ``probability``
        is the predicted probability, ``'no_pred'`` for GO edges without a
        prediction to show, or ``None`` for non-GO edges. Edge types with
        nothing to show are omitted.
    """
    current_limit = second_degree_limit if is_second_degree else limit

    # GO node types (e.g. 'GO_term_F') that have at least one prediction row.
    prediction_categories = {
        GO_CATEGORY_MAPPING[category] for category in prediction_df['GO_category'].unique()
    }
    go_category_reverse_mapping = {v: k for k, v in GO_CATEGORY_MAPPING.items()}
    # Restrict the table to this node once rather than once per edge type.
    node_predictions = prediction_df[prediction_df['UniProt_ID'] == protein_id]

    filtered_edges = {}
    for edge_type, edges in protein_edges.items():
        source_type, _, target_type = edge_type
        existing = set(edges) if edges else set()

        # Only Protein -> GO edge types receive predictions. The table is
        # keyed by 'UniProt_ID', and e.g. an (empty) Domain -> GO entry must
        # not be handed the protein's predictions.
        if source_type == 'Protein' and target_type in prediction_categories:
            category_predictions = node_predictions[
                node_predictions['GO_category'] == go_category_reverse_mapping[target_type]
            ]
            if len(category_predictions) > 0:
                top_predictions = category_predictions.sort_values(
                    by='Probability', ascending=False
                ).head(current_limit)
                valid_edges = []
                for _, row in top_predictions.iterrows():
                    edge = (protein_id, row['GO_ID'])
                    valid_edges.append((edge, row['Probability'], edge in existing))
                filtered_edges[edge_type] = valid_edges
                continue

        if not existing:
            continue

        # Sort before truncating so the kept subset does not depend on set order.
        kept = sorted(existing, key=lambda edge: tuple(map(str, edge)))[:current_limit]
        # GO edges without predictions are existing annotations ('no_pred');
        # other relationships carry no probability at all.
        marker = 'no_pred' if target_type.startswith('GO_term') else None
        filtered_edges[edge_type] = [(edge, marker, True) for edge in kept]

    return filtered_edges

def _gather_neighbor_edges(data, node_id, node_type, exclude_node_id):
    """Gather edges for a neighbor node, excluding edges back to the original query protein.

    Every edge type whose relation name does not contain 'rev' gets an entry
    in the result:
    - a set of ``(source_id, target_id)`` tuples when the node has at least
      one matching edge;
    - otherwise an empty list.

    An edge is dropped when the ID of its *other* endpoint equals
    ``exclude_node_id``. The comparison is by ID only, regardless of that
    endpoint's node type.

    NOTE(review): when both endpoints have ``node_type`` (e.g. PPI for a
    protein neighbour), only edges with the node as *source* are collected.
    This is complete only if such relations are stored in both directions.

    Args:
        data: Heterogeneous graph exposing ``node_types``, ``edge_types``,
            ``data[node_type]['id_mapping']`` (ID -> index) and
            ``data[edge_type].edge_index`` (2 x E index tensor).
        node_id: ID of the neighbour node.
        node_type: Node type of the neighbour.
        exclude_node_id: ID whose edges are left out (the query protein).

    Returns:
        dict mapping each non-'rev' edge type to a set of ID pairs or ``[]``.
    """
    node_idx = data[node_type]['id_mapping'][node_id]

    # Inverting an id_mapping costs O(#nodes of that type), and this function
    # runs once per first-degree neighbour. So invert lazily: only the node
    # types this neighbour's edges actually reach, each at most once per call.
    reverse_id_mapping = {}

    def index_to_id(ntype, idx):
        if ntype not in reverse_id_mapping:
            reverse_id_mapping[ntype] = {v: k for k, v in data[ntype]['id_mapping'].items()}
        return reverse_id_mapping[ntype][idx]

    neighbor_edges = {}
    for edge_type in data.edge_types:
        if 'rev' in edge_type[1]:
            continue
        neighbor_edges.setdefault(edge_type, [])
        source_type, _, target_type = edge_type
        if source_type == node_type:
            # Neighbour is the source; the target is the endpoint to check.
            match_row, other_type, other_pos = 0, target_type, 1
        elif target_type == node_type:
            # Neighbour is the target; the source is the endpoint to check.
            match_row, other_type, other_pos = 1, source_type, 0
        else:
            continue
        edge_index = data[edge_type].edge_index
        for edge in edge_index[:, edge_index[match_row] == node_idx].T.tolist():
            if index_to_id(other_type, edge[other_pos]) != exclude_node_id:
                neighbor_edges[edge_type].append(edge)

    # Map indices back to IDs (empty edge types stay as empty lists).
    for edge_type, edges in neighbor_edges.items():
        if edges:
            source_type, _, target_type = edge_type
            neighbor_edges[edge_type] = {
                (index_to_id(source_type, src), index_to_id(target_type, dst))
                for src, dst in edges
            }

    return neighbor_edges

def _node_display_name(name_info, node_type, node_id):
    """Return a node's display name from ``name_info``, falling back to its ID.

    All GO aspects (GO_term_P / _F / _C) share the single 'GO_term' table.
    """
    table = 'GO_term' if node_type.startswith('GO_term') else node_type
    return name_info.get(table, {}).get(node_id, node_id)


def _relation_label(relation_type, target_type):
    """Translate a relation name into the phrase used in edge tooltips.

    GO relations are keyed by (relation, target type) because the wording
    depends on the GO aspect. Unknown relations fall back to their raw name.
    """
    if relation_type in ('protein_function', 'domain_function'):
        return EDGE_LABEL_TRANSLATION.get((relation_type, target_type), relation_type)
    return EDGE_LABEL_TRANSLATION.get(relation_type, relation_type)


def _edge_style(relation_label, probability, is_ground_truth):
    """Return ``(colour, tooltip_html)`` for one edge.

    The colours match the legend:
    - ``None`` -> "Other Relationships";
    - ``'no_pred'`` -> "Existing GO Term Annotation";
    - a number -> "Confirmed Prediction" if ``is_ground_truth``, otherwise
      "Novel Prediction".
    """
    if probability is None:
        color, text = '#666666', relation_label
    elif probability == 'no_pred':
        color, text = '#219ebc', f"{relation_label} (P=Not generated)"
    else:
        color = '#8338ec' if is_ground_truth else '#c1121f'
        text = f"{relation_label} (P={probability:.2f})"
    return color, f"<div style='color: {color}'>{text}</div>"


def _add_entity_node(net, node_id, node_type, name_info):
    """Add a non-query node to ``net`` with a colour-coded, linked hover title."""
    node_name = _node_display_name(name_info, node_type, node_id)
    type_label = NODE_LABEL_TRANSLATION.get(node_type, node_type)
    text = f"{node_name} ({type_label})"
    url = get_node_url(node_type, node_id)
    if url:
        text = f"<a href='{url}' target='_blank'>{text}</a>"
    color = NODE_TYPE_COLORS.get(node_type, '#666666')
    net.add_node(node_id,
                 label=node_id,
                 shape="dot",
                 font={'color': '#000000', 'size': 12},
                 title=f"<div style='color: {color}'>{text}</div>",
                 group=node_type,
                 size=15,
                 mass=1.5)


def _merge_unique_edges(*edge_maps):
    """Merge ``{edge_type: [edge_info, ...]}`` maps, dropping repeated edges.

    Second-degree expansion can reach the same edge from both endpoints
    (e.g. a neighbour protein and a neighbour GO term), and without this it
    would be drawn twice. The first occurrence of each
    ``(edge_type, (source, target))`` wins. The input lists are not mutated.
    """
    merged = {}
    seen = set()
    for edge_map in edge_maps:
        for edge_type, edges in edge_map.items():
            bucket = merged.setdefault(edge_type, [])
            for edge_info in edges:
                key = (edge_type, tuple(edge_info[0]))
                if key not in seen:
                    seen.add(key)
                    bucket.append(edge_info)
    return merged


def _build_legend_html():
    """Return the HTML/CSS legend (node colours and edge colours) shown below the graph."""
    legend_html = """
    <style>
        .kg-legend {
            margin-top: 20px;
            padding: 20px;
            border: 1px solid #ddd;
            border-radius: 5px;
            font-family: Arial, sans-serif;
            display: flex;
            gap: 20px;
        }
        .legend-section-nodes {
            flex: 2;  /* Takes up 2/3 of the space */
        }
        .legend-section-edges {
            flex: 1;  /* Takes up 1/3 of the space */
        }
        .legend-title {
            margin-bottom: 15px;
            color: #333;
            font-size: 16px;
            font-weight: bold;
        }
        .nodes-grid {
            display: grid;
            grid-template-columns: repeat(2, 1fr);
            gap: 12px;
        }
        .edges-grid {
            display: grid;
            grid-template-columns: 1fr;
            gap: 12px;
        }
        .legend-item {
            display: flex;
            align-items: center;
            padding: 4px;
        }
        .node-indicator {
            width: 15px;
            height: 15px;
            border-radius: 50%;
            margin-right: 10px;
            flex-shrink: 0;
        }
        .edge-indicator {
            width: 40px;
            height: 3px;
            margin-right: 10px;
            flex-shrink: 0;
        }
        .legend-label {
            font-size: 14px;
        }
    </style>
    <div class="kg-legend">
        <div class="legend-section-nodes">
            <div class="legend-title">Node Types</div>
            <div class="nodes-grid">"""

    # Node types in 2 columns. 'kegg_Pathway' is skipped so that its colour,
    # shared with 'Pathway', is listed only once.
    for node_type, color in NODE_TYPE_COLORS.items():
        if node_type == 'kegg_Pathway':
            continue
        node_label = NODE_LABEL_TRANSLATION.get(node_type, node_type)
        legend_html += f"""
                <div class="legend-item">
                    <div class="node-indicator" style="background-color: {color};"></div>
                    <span class="legend-label">{node_label}</span>
                </div>"""

    # Edge colours in 1 column. The trailing </div>s close edges-grid,
    # legend-section-edges and kg-legend respectively.
    legend_html += """
            </div>
        </div>
        <div class="legend-section-edges">
            <div class="legend-title">Edge Colors</div>
            <div class="edges-grid">
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #8338ec;"></div>
                    <span class="legend-label">Confirmed Prediction (Found in Ground Truth)</span>
                </div>
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #c1121f;"></div>
                    <span class="legend-label">Novel Prediction (Not in Ground Truth)</span>
                </div>
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #219ebc;"></div>
                    <span class="legend-label">Existing GO Term Annotation</span>
                </div>
                <div class="legend-item">
                    <div class="edge-indicator" style="background-color: #666666;"></div>
                    <span class="legend-label">Other Relationships</span>
                </div>
            </div>
        </div>
    </div>
    """
    return legend_html


def _inject_custom_html(file_path, legend_html):
    """Rewrite the HTML saved by pyvis: add the clickable popup and the legend.

    Four string substitutions are applied to the file:
    1. the popup CSS is inserted before every ``</style>``;
    2. the popup JavaScript is inserted before ``return network;``;
    3. the original tooltip-hiding CSS rule is removed if present;
    4. the legend is inserted before ``</body>``.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    custom_popup_code = """
        // make a custom popup
        var popup = document.createElement("div");
        popup.className = 'popup';
        popupTimeout = null;
        popup.addEventListener('mouseover', function () {
            if (popupTimeout !== null) {
                clearTimeout(popupTimeout);
                popupTimeout = null;
            }
        });
        popup.addEventListener('mouseout', function () {
            if (popupTimeout === null) {
                hidePopup();
            }
        });
        container.appendChild(popup);

        // use the popup event to show
        network.on("showPopup", function (params) {
            showPopup(params);
        });

        // use the hide event to hide it
        network.on("hidePopup", function (params) {
            hidePopup();
        });

        // hiding the popup through css
        function hidePopup() {
            popupTimeout = setTimeout(function () { popup.style.display = 'none'; }, 500);
        }

        // showing the popup
        function showPopup(nodeId) {
            // get the data from the vis.DataSet
            var nodeData = nodes.get(nodeId);
            // get the position of the node
            var posCanvas = network.getPositions([nodeId])[nodeId];

            if (!nodeData) {
                var edgeData = edges.get(nodeId);
                var poses = network.getPositions([edgeData.from, edgeData.to]);
                var middle_x = (poses[edgeData.to].x - poses[edgeData.from].x) * 0.5;
                var middle_y = (poses[edgeData.to].y - poses[edgeData.from].y) * 0.5;
                posCanvas = poses[edgeData.from];
                posCanvas.x = posCanvas.x + middle_x;
                posCanvas.y = posCanvas.y + middle_y;

                popup.innerHTML = edgeData.title;
            } else {
                popup.innerHTML = nodeData.title;
                // get the bounding box of the node
                var boundingBox = network.getBoundingBox(nodeId);
                posCanvas.x = posCanvas.x + 0.5 * (boundingBox.right - boundingBox.left);
                posCanvas.y = posCanvas.y + 0.5 * (boundingBox.top - boundingBox.bottom);
            };

            //position tooltip:
            // convert coordinates to the DOM space
            var posDOM = network.canvasToDOM(posCanvas);

            // Give it an offset
            posDOM.x += 10;
            posDOM.y -= 20;

            // show and place the tooltip.
            popup.style.display = 'block';
            popup.style.top = posDOM.y + 'px';
            popup.style.left = posDOM.x + 'px';
        }
    """

    custom_popup_css = """
        /* position absolute is important and the container has to be relative or absolute as well. */
        div.popup {
            position: absolute;
            top: 0px;
            left: 0px;
            display: none;
            background-color: white;
            border-radius: 3px;
            border: 1px solid #ddd;
            box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2);
            padding: 5px;
            z-index: 1000;
        }
        div.popup a {
            color: inherit;
            text-decoration: underline;
        }
    """

    content = content.replace('</style>', f'{custom_popup_css}</style>')
    content = content.replace('return network;', f'{custom_popup_code}\nreturn network;')
    content = content.replace("""
        /* hide the original tooltip */
        .vis-network-tooltip {
          display:none;
        }""", "")
    content = content.replace('</body>', f'{legend_html}</body>')

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(content)


def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10, second_degree_limit=3, include_second_degree=False):
    """
    Render an interactive pyvis graph of a protein's neighbourhood to an HTML file.

    Args:
        data: Heterogeneous graph passed to ``_gather_protein_edges`` /
            ``_gather_neighbor_edges``.
        protein_id: ID of the query protein (key of ``name_info['Protein']``
            when a display name is available).
        prediction_df: GO-term predictions, as consumed by ``_filter_edges``.
        limit: Maximum edges per edge type around the query protein.
        second_degree_limit: Maximum edges per edge type around each
            first-degree neighbour.
        include_second_degree: Also draw the edges of first-degree
            neighbours (edges back to the query protein are excluded).

    Returns:
        ``(file_path, visualized_edges)``:
        - ``file_path`` is the HTML file written under 'temp_viz/';
        - ``visualized_edges`` is the
          ``{edge_type: [(edge, probability, is_ground_truth), ...]}`` map
          that was drawn, with duplicate edges removed.
    """
    with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
        name_info = json.load(file)

    # First-degree edges around the query protein.
    protein_edges = _gather_protein_edges(data, protein_id)
    first_degree_edges = _filter_edges(protein_id, protein_edges, prediction_df,
                                       limit=limit, is_second_degree=False)
    edge_maps = [first_degree_edges]

    if include_second_degree:
        # Every non-query endpoint of a first-degree edge, in discovery order
        # (a dict keeps the expansion order deterministic, unlike a set).
        neighbor_nodes = {}
        for (source_type, _, target_type), edges in first_degree_edges.items():
            for (source, target), _, _ in edges:
                if source != protein_id:
                    neighbor_nodes[(source, source_type)] = None
                if target != protein_id:
                    neighbor_nodes[(target, target_type)] = None

        for neighbor_id, neighbor_type in neighbor_nodes:
            neighbor_edges = _gather_neighbor_edges(data, neighbor_id, neighbor_type, protein_id)
            edge_maps.append(_filter_edges(neighbor_id, neighbor_edges, prediction_df,
                                           limit=limit,
                                           is_second_degree=True,
                                           second_degree_limit=second_degree_limit))

    visualized_edges = _merge_unique_edges(*edge_maps)

    print(f'Edges to be visualized: {visualized_edges}')

    net = Network(height="600px", width="100%", directed=True, notebook=False)

    # One vis.js group per node type so nodes pick up their type colour.
    groups_json = json.dumps({
        node_type: {"color": {"background": color, "border": color}}
        for node_type, color in NODE_TYPE_COLORS.items()
    })

    # Configure physics options with settings for better clustering
    net.set_options("""{
        "physics": {
            "enabled": true,
            "barnesHut": {
                "gravitationalConstant": -1000,
                "springLength": 250,
                "springConstant": 0.001,
                "damping": 0.09,
                "avoidOverlap": 0
            },
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 100,
                "springConstant": 0.08,
                "damping": 0.4,
                "avoidOverlap": 0
            },
            "solver": "barnesHut",
            "stabilization": {
                "enabled": true,
                "iterations": 1000,
                "updateInterval": 25
            }
        },
        "layout": {
            "improvedLayout": true,
            "hierarchical": {
                "enabled": false
            }
        },
        "interaction": {
            "hover": true,
            "navigationButtons": true,
            "multiselect": true
        },
        "configure": {
            "enabled": false,
            "filter": ["physics", "layout", "manipulation"],
            "showButton": true
        },
        "groups": """ + groups_json + "}")

    # The query protein: larger, white fill with a red border.
    query_node_url = get_node_url('Protein', protein_id)
    query_node_title = f"{_node_display_name(name_info, 'Protein', protein_id)} (Query Protein)"
    if query_node_url:
        query_node_title = f'<a href="{query_node_url}" target="_blank">{query_node_title}</a>'

    net.add_node(protein_id,
                 label=protein_id,
                 title=query_node_title,
                 color={'background': 'white', 'border': '#c1121f'},
                 borderWidth=4,
                 shape="dot",
                 font={'color': '#000000', 'size': 15},
                 group='Protein',
                 size=30,
                 mass=2.5)

    # Node IDs already in the network; pyvis requires both endpoints to exist
    # before an edge is added.
    added_nodes = {protein_id}

    for (source_type, relation_type, target_type), edges in visualized_edges.items():
        relation_label = _relation_label(relation_type, target_type)
        for (source, target), probability, is_ground_truth in edges:
            source_str, target_str = str(source), str(target)
            for node_str, node_type in ((source_str, source_type), (target_str, target_type)):
                if node_str not in added_nodes:
                    _add_entity_node(net, node_str, node_type, name_info)
                    added_nodes.add(node_str)

            # The relation is shown only in the hover tooltip, not as an on-canvas label.
            edge_color, title_text = _edge_style(relation_label, probability, is_ground_truth)
            net.add_edge(source_str, target_str,
                         label='',
                         font={'size': 0},
                         color=edge_color,
                         title=title_text,
                         length=200,
                         smooth={'type': 'curvedCW', 'roundness': 0.1})

    # Save graph to a protein-specific file in a temporary directory
    os.makedirs('temp_viz', exist_ok=True)
    suffix = "_with_2nd_degree" if include_second_degree else "_1st_degree"
    file_path = os.path.join('temp_viz', f'{protein_id}_graph{suffix}.html')

    net.save_graph(file_path)
    _inject_custom_html(file_path, _build_legend_html())

    return file_path, visualized_edges