import plotly.graph_objects as go
import networkx as nx
import numpy as np
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges, 
                          Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx

def create_graph(entities, relationships):
    """Build an undirected ``networkx.Graph`` from entities and relation triples.

    Parameters
    ----------
    entities : dict
        Maps each node id to a mapping read with ``.get``. Its ``'value'``
        and ``'type'`` entries (``'Unknown'`` when absent) form the node's
        ``label`` attribute as ``"<value> (<type>)"``.
    relationships : iterable of (source, relation, target)
        Each triple becomes an edge ``source``--``target`` whose ``label``
        attribute is ``relation``. Endpoints that are not keys of
        ``entities`` are added as nodes without a ``label`` attribute.
        Because the result is a simple ``Graph``, a later triple between the
        same pair of nodes overwrites the earlier edge's label.

    Returns
    -------
    networkx.Graph
    """
    graph = nx.Graph()
    for node_id, info in entities.items():
        value = info.get('value', 'Unknown')
        kind = info.get('type', 'Unknown')
        graph.add_node(node_id, label=f"{value} ({kind})")

    # (u, v, attr_dict) triples: add_edges_from attaches the dict as edge data.
    graph.add_edges_from(
        (src, dst, {'label': rel}) for src, rel, dst in relationships
    )
    return graph

def improved_spectral_layout(G, scale=1, noise=0.1, seed=None):
    """Spectral layout with Gaussian jitter so nodes do not sit on top of each other.

    Parameters
    ----------
    G : networkx graph
        Graph to lay out with ``nx.spectral_layout``.
    scale : float, optional
        Multiplier applied to every coordinate after the jitter is added.
    noise : float, optional
        Standard deviation of the Gaussian jitter added independently to the
        x and y coordinate of each node (before scaling). Defaults to 0.1.
        Pass 0 for the plain spectral positions.
    seed : None, int or numpy.random.Generator, optional
        ``None`` (default) draws from NumPy's global random state, so callers
        that seed ``np.random`` keep controlling the result. An int or a
        ``Generator`` gives a reproducible layout via
        ``np.random.default_rng(seed)``.

    Returns
    -------
    dict
        Maps each node to an ``(x, y)`` tuple.
    """
    pos = nx.spectral_layout(G)
    rng = np.random if seed is None else np.random.default_rng(seed)
    jittered = {}
    for node, (x, y) in pos.items():
        # Draw x then y per node, the same order as the original implementation,
        # so the global random stream is consumed identically when seed is None.
        dx = rng.normal(0, noise)
        dy = rng.normal(0, noise)
        jittered[node] = ((x + dx) * scale, (y + dy) * scale)
    return jittered

def create_bokeh_plot(G, layout_type='spring'):
    """Render ``G`` as an interactive Bokeh graph plot with node and edge labels.

    Parameters
    ----------
    G : networkx.Graph
        Nodes may carry a ``'label'`` attribute (``create_graph`` sets one).
        Nodes without it are labelled with ``str(node)``. Edges may carry a
        ``'label'`` attribute naming the relation; edges without one get an
        empty label.
    layout_type : str, optional
        ``'spring'`` (default), ``'fruchterman_reingold'``, ``'circular'``,
        ``'random'``, ``'spectral'`` or ``'shell'``. Unrecognised values fall
        back to the spring layout.

    Returns
    -------
    bokeh.models.Plot
        A 600x600 plot holding the graph renderer, a node label set and an
        edge label set (edge labels are placed at edge midpoints).
    """
    # Dispatch table instead of an if/elif chain; unknown names -> spring.
    layout_functions = {
        'spring': lambda g: nx.spring_layout(g, k=0.5, iterations=50),
        'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, k=0.5, iterations=50),
        'circular': nx.circular_layout,
        'random': nx.random_layout,
        'spectral': improved_spectral_layout,
        'shell': nx.shell_layout,
    }
    pos = layout_functions.get(layout_type, layout_functions['spring'])(G)

    # Look positions up per node rather than relying on the layout dict's
    # iteration order matching G.nodes(). This also works for an empty graph,
    # where unpacking zip(*{}.values()) used to raise ValueError.
    nodes = list(G.nodes())
    node_x = [float(pos[n][0]) for n in nodes]
    node_y = [float(pos[n][1]) for n in nodes]

    # Keep the original [-1.2, 1.2] view, but widen it (with the same 0.2
    # padding) when a node falls outside it. The jittered spectral layout
    # can push nodes past that box.
    pad = 0.2
    x_range = Range1d(min([-1.2, *(x - pad for x in node_x)]),
                      max([1.2, *(x + pad for x in node_x)]))
    y_range = Range1d(min([-1.2, *(y - pad for y in node_y)]),
                      max([1.2, *(y + pad for y in node_y)]))
    plot = Plot(width=600, height=600, x_range=x_range, y_range=y_range)
    plot.title.text = "Knowledge Graph Interaction"

    # The inspection policy below is NodesAndLinkedEdges, so hover hits are
    # reported for nodes and tooltips read the node data source. A second
    # "Relation: @label" hover tool would only repeat the node label there,
    # so only the node tooltip is registered. Relations are drawn as edge labels.
    node_hover = HoverTool(tooltips=[("Entity", "@label")])
    plot.add_tools(node_hover, TapTool(), BoxSelectTool())

    # ``pos`` is a precomputed dict, so from_networkx uses it verbatim. Its
    # extra keyword arguments (scale/center) only apply to callable layouts.
    graph_renderer = from_networkx(G, pos)

    graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
    graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()

    plot.renderers.append(graph_renderer)

    label_style = dict(background_fill_color='white', text_font_size='8pt',
                       background_fill_alpha=0.7)

    # Node labels: nodes added only through relationships have no 'label'
    # attribute, so fall back to the node id instead of raising KeyError.
    node_text = [G.nodes[n].get('label', str(n)) for n in nodes]
    node_source = ColumnDataSource({'x': node_x, 'y': node_y, 'label': node_text})
    plot.renderers.append(LabelSet(x='x', y='y', text='label', source=node_source,
                                   **label_style))

    # Edge labels at the midpoint of each edge.
    edge_x, edge_y, edge_text = [], [], []
    for u, v, relation in G.edges(data='label', default=''):
        (x0, y0), (x1, y1) = pos[u], pos[v]
        edge_x.append(float((x0 + x1) / 2))
        edge_y.append(float((y0 + y1) / 2))
        edge_text.append(relation)

    edge_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_text})
    plot.renderers.append(LabelSet(x='x', y='y', text='label', source=edge_source,
                                   **label_style))

    return plot

def create_plotly_plot(G, layout_type='spring'):
    """Render ``G`` as a Plotly figure with node and edge labels.

    Parameters
    ----------
    G : networkx.Graph
        Nodes may carry a ``'label'`` attribute (``create_graph`` sets one).
        Nodes without it are labelled with ``str(node)``. Edges may carry a
        ``'label'`` attribute naming the relation; edges without one get an
        empty label.
    layout_type : str, optional
        ``'spring'`` (default), ``'fruchterman_reingold'``, ``'circular'``,
        ``'random'``, ``'spectral'`` or ``'shell'``. Unrecognised values fall
        back to the spring layout.

    Returns
    -------
    plotly.graph_objects.Figure
        Trace 0 draws the edges, trace 1 the nodes (coloured by their number
        of neighbours, with a colour bar), and trace 2 the relation labels at
        edge midpoints.
    """
    # Dispatch table instead of an if/elif chain; unknown names -> spring.
    layout_functions = {
        'spring': lambda g: nx.spring_layout(g, k=0.5, iterations=50),
        'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, k=0.5, iterations=50),
        'circular': nx.circular_layout,
        'random': nx.random_layout,
        'spectral': improved_spectral_layout,
        'shell': nx.shell_layout,
    }
    pos = layout_functions.get(layout_type, layout_functions['spring'])(G)

    # Collect coordinates in plain lists first. Repeated `+=` on figure
    # properties rebuilt a tuple on every edge/node (quadratic).
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for u, v, relation in G.edges(data='label', default=''):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # A None between segments breaks the polyline so edges stay separate.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_links = [], [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # Nodes added only through relationships have no 'label' attribute.
        node_text.append(G.nodes[node].get('label', str(node)))
        # len(G[node]) is the neighbour count (same as len(list(G.neighbors(node)))).
        node_links.append(len(G[node]))

    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=1, color="#888"),
                            hoverinfo="text", mode="lines", text=[])
    node_trace = go.Scatter(
        x=node_x, y=node_y, mode="markers+text", hoverinfo="text",
        marker=dict(showscale=True, colorscale="Viridis", reversescale=True,
                    color=node_links, size=15,
                    # title.side replaces the deprecated `titleside`, which
                    # plotly 6 no longer accepts.
                    colorbar=dict(thickness=15, xanchor="left",
                                  title=dict(text="Node Connections", side="right")),
                    line_width=2),
        text=node_text, textposition="top center")
    # One text trace for all relation labels instead of one trace per edge.
    edge_label_trace = go.Scatter(x=label_x, y=label_y, mode="text", text=label_text,
                                  textposition="middle center", hoverinfo="none",
                                  showlegend=False, textfont=dict(size=8))

    fig = go.Figure(
        data=[edge_trace, node_trace, edge_label_trace],
        layout=go.Layout(
            # title.font replaces the deprecated `titlefont`, removed in plotly 6.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False, hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40), annotations=[],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # Lock a 1:1 aspect ratio. Anchoring y to x is sufficient; also
            # anchoring x to y formed a loop whose second constraint is ignored.
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                       scaleanchor="x", scaleratio=1),
            newshape=dict(line_color="#009900"),
            width=800, height=600,
        ),
    )
    return fig