import requests
import networkx as nx
import matplotlib.pyplot as plt

# API Base URL
base_url = "http://localhost:5000"

def fetch_relationships(node_id, direction="down", timeout=30):
    """Fetch the traversal hierarchy for a node from the graph API.

    Issues ``GET {base_url}/traverse_node`` with ``node_id`` and ``direction``
    sent as query parameters, then returns the ``traversal_path`` entry of the
    JSON response.

    Args:
        node_id: Identifier of the node to traverse from.
        direction: Traversal direction forwarded to the API ("up" or "down").
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``traversal_path`` value of the response, or an empty dict when
        that key is missing or null.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: The request could not be completed
            (connection failure, timeout, ...).
    """
    response = requests.get(
        f"{base_url}/traverse_node",
        # Let requests build the query string so characters such as spaces,
        # "&" or "#" inside node_id are percent-encoded instead of corrupting
        # the URL.
        params={"node_id": node_id, "direction": direction},
        # Without a timeout requests waits indefinitely on a stalled server.
        timeout=timeout,
    )
    # Surface HTTP errors here rather than letting an error payload turn into
    # a silently empty traversal.
    response.raise_for_status()
    # "or {}" also covers an explicit JSON null for the key.
    return response.json().get("traversal_path") or {}

def build_graph_from_relationships(node_id):
    """Assemble a directed NetworkX graph around ``node_id``.

    The traversal hierarchy is fetched in both directions ("down" first, then
    "up") and each result is merged into a single ``nx.DiGraph`` via
    ``add_nodes_and_edges``.

    Args:
        node_id: Identifier of the node whose relationships are fetched.

    Returns:
        The populated ``nx.DiGraph``.
    """
    graph = nx.DiGraph()

    # Retrieve both hierarchies before touching the graph.
    hierarchies = [
        fetch_relationships(node_id, direction=way) for way in ("down", "up")
    ]

    # Each merge starts with its own fresh visited set, so a node that appears
    # in both hierarchies is processed once per hierarchy.
    for hierarchy in hierarchies:
        add_nodes_and_edges(graph, hierarchy)

    return graph

def add_nodes_and_edges(G, node, visited=None):
    """Recursively copy a traversal hierarchy into a graph.

    ``node`` is a dict that may carry a ``"node_id"``, a ``"descendants"``
    list and an ``"ancestors"`` list; entries of those lists have the same
    shape plus an optional ``"relationship"`` used as the edge label
    (default ``"related_to"``). Edges point from parent to child: the current
    node to each descendant, and each ancestor to the current node.

    Args:
        G: Graph object providing ``add_node`` and ``add_edge``.
        node: Hierarchy dict to process; ``None`` or an empty dict is a no-op.
        visited: Set of node ids already expanded during this walk. Callers
            normally omit it; it is threaded through the recursion so cyclic
            hierarchies terminate.

    Returns:
        None. ``G`` and ``visited`` are updated in place.
    """
    if visited is None:
        visited = set()

    # A missing/null hierarchy (e.g. an empty API result) contributes nothing.
    if not node:
        return

    node_id = node.get("node_id")
    if not node_id or node_id in visited:
        return
    visited.add(node_id)

    G.add_node(node_id, label=node_id)

    # "or ()" tolerates an explicit null list in the payload.
    for child in node.get("descendants") or ():
        child_id = child.get("node_id")
        if not child_id:
            # An entry without an id cannot be an edge endpoint (NetworkX
            # rejects None as a node), so skip it like the guard above does.
            continue
        relationship = child.get("relationship", "related_to")
        G.add_edge(node_id, child_id, label=relationship)
        add_nodes_and_edges(G, child, visited)

    for ancestor in node.get("ancestors") or ():
        ancestor_id = ancestor.get("node_id")
        if not ancestor_id:
            continue
        relationship = ancestor.get("relationship", "related_to")
        G.add_edge(ancestor_id, node_id, label=relationship)
        add_nodes_and_edges(G, ancestor, visited)

def visualize_graph(G, title="Graph Structure and Relationships"):
    """Draw ``G`` with matplotlib and show the figure.

    Nodes are drawn with their ids as labels; edges are drawn with arrows and,
    where an edge has a ``"label"`` attribute, that value is rendered on it.

    Args:
        G: NetworkX graph to draw.
        title: Text placed above the plot.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)

    # Draw nodes and labels
    nx.draw_networkx_nodes(G, pos, node_size=3000, node_color="skyblue", alpha=0.8)
    nx.draw_networkx_labels(G, pos, font_size=10, font_color="black")

    # Draw edges with labels
    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True)
    # Only edges that actually carry a "label" attribute get text; indexing
    # d["label"] unconditionally would raise KeyError on unlabeled edges.
    edge_labels = {
        (u, v): d["label"] for u, v, d in G.edges(data=True) if "label" in d
    }
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red")

    # Title and display options
    plt.title(title)
    plt.axis("off")
    plt.show()

# Step 1: Load Graph (Specify the graph to load, e.g., PHSA/340B section)
# NOTE: these statements sit at module top level with no `__main__` guard, so
# they run (HTTP calls and plot window included) whenever this file is
# executed or imported.
print("\n--- Loading Graph ---")
# Payload naming the graph file; the path is presumably resolved by the API
# server rather than by this script -- confirm.
graph_data = {"graph_file": "graphs/PHSA/phsa_sec_340b.json"}
# POST the payload as a JSON body to the /load_graph endpoint.
response = requests.post(f"{base_url}/load_graph", json=graph_data)
# Print the decoded JSON reply. The HTTP status is not checked, so step 2
# still runs if the load request was rejected.
print("Load Graph Response:", response.json())

# Step 2: Build and visualize the graph for 340B Program
print("\n--- Building Graph for Visualization ---")
# Fetch the "down" and "up" traversals for the node and merge them into one
# directed graph.
G = build_graph_from_relationships("340B Program")
# Render the graph and display it via matplotlib.
visualize_graph(G, title="340B Program - Inferred Contextual Relationships")