import streamlit as st
import networkx as nx
from pyvis.network import Network
import json
from streamlit.components.v1 import html

# JSON file that backs the graph shown by this app.
default_file = "test.json"

# Page header: app title followed by a one-line usage hint.
st.title("Interactive Graph Visualization and Editor")
st.write("Edit the graph's JSON data and visualize updates dynamically.")

def load_json_file(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is decoded explicitly as UTF-8 so that non-ASCII text (for
    example node labels) is read the same way on every platform instead of
    depending on the locale's default encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON value. If the file cannot be opened, decoded or
        parsed, the error is reported through ``st.error`` and an empty
        graph ``{"nodes": [], "edges": []}`` is returned instead.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except Exception as e:
        # Deliberately broad: whatever goes wrong while loading, the app
        # shows the error and falls back to an empty graph.
        st.error(f"Error loading the file: {e}")
        return {"nodes": [], "edges": []}

# Load the graph data from the default JSON file
graph_data = load_json_file(default_file)

# Sidebar editor, pre-filled with the current graph serialized as JSON
st.sidebar.header("Edit Graph JSON")
json_editor = st.sidebar.text_area(
    "Graph JSON",
    value=json.dumps(graph_data, indent=4),
    height=300,
)

if st.sidebar.button("Update Graph"):
    # Parsing and saving are handled separately so that a failed write is
    # not misreported as "Invalid JSON format".
    try:
        parsed_data = json.loads(json_editor)
    except (ValueError, RecursionError) as e:
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically deep nesting in the edited text.
        st.sidebar.error(f"Invalid JSON format: {e}")
    else:
        # Use the edited data for this run even if saving fails below.
        graph_data = parsed_data
        try:
            # Save back to the file
            with open(default_file, "w", encoding="utf-8") as file:
                json.dump(graph_data, file, indent=4)
        except OSError as e:
            st.sidebar.error(f"Error saving the file: {e}")
        else:
            st.success("Graph JSON updated successfully!")

def create_graph(data):
    """Build a directed NetworkX graph from the parsed graph JSON.

    Args:
        data: Mapping with a "nodes" list and an "edges" list. Each node
            needs an "id" and may carry a "label" (defaults to the id);
            each edge needs "source" and "target" and may carry a "label"
            (defaults to an empty string).

    Returns:
        An ``nx.DiGraph`` whose nodes and edges carry a ``label`` attribute.

    Raises:
        KeyError: If a node lacks "id" or an edge lacks "source"/"target".
    """
    G = nx.DiGraph()
    for node in data["nodes"]:
        G.add_node(node["id"], label=node.get("label", node["id"]))
    for edge in data["edges"]:
        G.add_edge(edge["source"], edge["target"], label=edge.get("label", ""))
    return G


def create_pyvis_graph(G):
    """Convert a NetworkX graph into a directed pyvis ``Network``.

    Args:
        G: Graph whose nodes and edges may carry a ``label`` attribute.

    Returns:
        A 600px-high, full-width pyvis ``Network`` mirroring *G*.
    """
    net = Network(height="600px", width="100%", directed=True)
    for node, data in G.nodes(data=True):
        # Nodes without a "label" attribute (e.g. ones NetworkX added
        # implicitly as an edge endpoint) fall back to their id.
        net.add_node(node, label=data.get("label", node))
    for source, target, data in G.edges(data=True):
        # The edge label is passed as pyvis' `title` option
        # (presumably rendered as a hover tooltip; confirm in pyvis docs).
        net.add_edge(source, target, title=data.get("label", ""))
    return net


# Validate JSON structure. The isinstance check comes first because the
# JSON root may be a number, boolean or null, and `in` raises TypeError
# on those values.
if (
    not isinstance(graph_data, dict)
    or "nodes" not in graph_data
    or "edges" not in graph_data
):
    st.error("The JSON file must contain 'nodes' and 'edges' keys.")
else:
    try:
        # Generate the graph
        graph = create_graph(graph_data)

        # Create the Pyvis graph
        pyvis_graph = create_pyvis_graph(graph)

        # Generate the HTML representation of the graph
        pyvis_graph_html = pyvis_graph.generate_html()

        # Display the graph in Streamlit
        html(pyvis_graph_html, height=600)

    except Exception as e:
        # UI boundary: malformed entries or rendering failures are shown
        # to the user instead of aborting the script.
        st.error(f"Error processing the graph: {e}")