|
|
|
|
|
""" |
|
Reference:

- [graphrag](https://github.com/microsoft/graphrag)
|
""" |
|
|
|
from dataclasses import dataclass
from typing import Any

import networkx as nx
import numpy as np

from graphrag.leiden import stable_largest_connected_component
|
|
|
|
|
@dataclass
class NodeEmbeddings:
    """Pair a list of graph nodes with the array of their embeddings.

    Instances are built by ``embed_nod2vec`` from the two-element result of
    the Node2Vec embedding call.
    """

    # Node identifiers, as returned by the embedding call.
    nodes: list[str]
    # Embedding array. ``run`` zips ``embeddings.tolist()`` with ``nodes``
    # using strict=True, so entry i is presumably the vector for nodes[i]
    # -- TODO confirm the (n_nodes, dimensions) shape against the embedder.
    embeddings: np.ndarray
|
|
|
|
|
def embed_nod2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings:
    """Embed the nodes of ``graph`` with Node2Vec.

    Every tuning parameter is forwarded unchanged, by keyword, to
    ``gc.embed.node2vec_embed``.

    Returns:
        A ``NodeEmbeddings`` whose ``embeddings`` is element 0 and whose
        ``nodes`` is element 1 of the value returned by that call.
    """
    # NOTE(review): `gc` is not imported in the visible part of this file;
    # presumably it is graspologic (`import graspologic as gc`) -- confirm.
    node2vec_options = {
        "dimensions": dimensions,
        "window_size": window_size,
        "iterations": iterations,
        "num_walks": num_walks,
        "walk_length": walk_length,
        "random_seed": random_seed,
    }
    result = gc.embed.node2vec_embed(graph=graph, **node2vec_options)

    vectors = result[0]
    labels = result[1]
    return NodeEmbeddings(nodes=labels, embeddings=vectors)
|
|
|
|
|
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[str, list[float]]:
    """Embed ``graph`` with Node2Vec and return the vectors keyed by node.

    Args:
        graph: Graph to embed.
        args: Options. Recognised keys and the fallback used when a key is
            absent: ``use_lcc`` (True), ``dimensions`` (1536),
            ``num_walks`` (10), ``walk_length`` (40), ``window_size`` (2),
            ``iterations`` (3), ``random_seed`` (86).

    Returns:
        A plain ``dict`` (not a ``NodeEmbeddings`` instance) mapping each node
        to its embedding converted with ``ndarray.tolist()``. Entries are
        inserted in ascending node order.

    Raises:
        ValueError: If the number of nodes and the number of embedding
            entries differ (``zip(..., strict=True)``).
    """
    # Unless disabled, embed only what stable_largest_connected_component
    # returns for the graph.
    if args.get("use_lcc", True):
        graph = stable_largest_connected_component(graph)

    embeddings = embed_nod2vec(
        graph=graph,
        dimensions=args.get("dimensions", 1536),
        num_walks=args.get("num_walks", 10),
        walk_length=args.get("walk_length", 40),
        window_size=args.get("window_size", 2),
        iterations=args.get("iterations", 3),
        random_seed=args.get("random_seed", 86),
    )

    # strict=True turns a node/vector count mismatch into an error instead of
    # silently truncating to the shorter sequence.
    pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
    # Sort by node only, so the resulting dict has a deterministic key order.
    sorted_pairs = sorted(pairs, key=lambda pair: pair[0])

    return dict(sorted_pairs)