# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Node2Vec-based node embedding for NetworkX graphs.

Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""

from dataclasses import dataclass
from typing import Any

import graspologic as gc
import networkx as nx
import numpy as np

from graphrag.leiden import stable_largest_connected_component


@dataclass
class NodeEmbeddings:
    """Node embeddings class definition: node labels plus their embedding matrix."""

    # Node labels in the order graspologic returned them; run() pairs
    # nodes[i] with embeddings[i] (assumed row-aligned — relies on graspologic's
    # (embeddings, labels) return contract).
    nodes: list[str]
    # Embedding matrix as returned by graspologic's node2vec_embed.
    embeddings: np.ndarray


def embed_nod2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings:
    """Generate node embeddings using Node2Vec.

    The misspelled name ("nod2vec") is kept for backward compatibility with
    existing callers.

    Args:
        graph: Graph to embed; passed unchanged to graspologic.
        dimensions: Embedding dimensionality.
        num_walks: Number of random walks per node.
        walk_length: Length of each random walk.
        window_size: Skip-gram context window size.
        iterations: Number of training iterations.
        random_seed: Seed forwarded to graspologic for reproducibility.

    Returns:
        A NodeEmbeddings holding graspologic's embedding matrix and the
        matching node labels.
    """
    # graspologic returns a 2-tuple: (embedding matrix, node labels).
    lcc_tensors = gc.embed.node2vec_embed(  # type: ignore
        graph=graph,
        dimensions=dimensions,
        window_size=window_size,
        iterations=iterations,
        num_walks=num_walks,
        walk_length=walk_length,
        random_seed=random_seed,
    )
    return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1])


def run(graph: nx.Graph, args: dict[str, Any]) -> dict[str, list[float]]:
    """Embed ``graph`` with Node2Vec and return a node -> vector mapping.

    Args:
        graph: Graph to embed.
        args: Optional settings. Recognised keys (with defaults):
            ``use_lcc`` (True), ``dimensions`` (1536), ``num_walks`` (10),
            ``walk_length`` (40), ``window_size`` (2), ``iterations`` (3),
            ``random_seed`` (86).

    Returns:
        A dict mapping each embedded node label to its embedding as a plain
        ``list[float]``, inserted in ascending node-label order. (The previous
        ``NodeEmbeddings`` annotation did not match what is actually returned.)
    """
    if args.get("use_lcc", True):
        # NOTE(review): presumably restricts the graph to its largest connected
        # component in a deterministic way — confirm in graphrag.leiden.
        graph = stable_largest_connected_component(graph)

    # create graph embedding using node2vec
    embeddings = embed_nod2vec(
        graph=graph,
        dimensions=args.get("dimensions", 1536),
        num_walks=args.get("num_walks", 10),
        walk_length=args.get("walk_length", 40),
        window_size=args.get("window_size", 2),
        iterations=args.get("iterations", 3),
        random_seed=args.get("random_seed", 86),
    )

    # strict=True raises if the label list and the embedding rows differ in length.
    pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
    # Sorting by label gives a deterministic key order; labels must be mutually comparable.
    sorted_pairs = sorted(pairs, key=lambda x: x[0])
    return dict(sorted_pairs)