# gzzyyxy's picture
# Upload folder using huggingface_hub
# c1a7f73 verified
import torch
from typing import List, Optional, Tuple, Union
from torch_geometric.utils import coalesce
from torch_geometric.utils import degree
def add_edges(
        from_edge_index: torch.Tensor,
        to_edge_index: torch.Tensor,
        from_edge_attr: Optional[torch.Tensor] = None,
        to_edge_attr: Optional[torch.Tensor] = None,
        replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Insert the edges of one edge set into another, resolving duplicates.

    Both edge indices are indexed as ``[0]`` (first endpoint) and ``[1]``
    (second endpoint) and concatenated along dim 1, i.e. one column per edge.
    Two edges are duplicates when both endpoints match.

    Args:
        from_edge_index: Edges to insert. Cast to the device and dtype of
            ``to_edge_index`` before comparison.
        to_edge_index: Edges that receive the insertion.
        from_edge_attr: Optional attributes of the inserted edges, one entry
            per edge along dim 0.
        to_edge_attr: Optional attributes of the receiving edges, one entry
            per edge along dim 0.
        replace: If ``True``, receiving edges that also occur in
            ``from_edge_index`` are dropped and every inserted edge is
            appended. If ``False``, all receiving edges are kept and only the
            inserted edges that are not already present are appended.

    Returns:
        The merged edge index and the merged attributes. Attributes are only
        merged when *both* attribute tensors are given; otherwise
        ``to_edge_attr`` is returned exactly as passed in.
    """
    incoming = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)

    # Pairwise comparison: entry (i, j) is True when receiving edge i and
    # inserted edge j share both endpoints.
    first_equal = to_edge_index[0].unsqueeze(-1) == incoming[0].unsqueeze(0)
    second_equal = to_edge_index[1].unsqueeze(-1) == incoming[1].unsqueeze(0)
    duplicate = first_equal & second_equal

    merge_attr = from_edge_attr is not None and to_edge_attr is not None
    if merge_attr:
        from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)

    if replace:
        # Inserted edges win: discard the receiving edges they collide with.
        keep_existing = ~duplicate.any(dim=1)
        if merge_attr:
            to_edge_attr = torch.cat([to_edge_attr[keep_existing], from_edge_attr], dim=0)
        merged_index = torch.cat([to_edge_index[:, keep_existing], incoming], dim=1)
    else:
        # Receiving edges win: append only the inserted edges that are new.
        keep_incoming = ~duplicate.any(dim=0)
        if merge_attr:
            to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[keep_incoming]], dim=0)
        merged_index = torch.cat([to_edge_index, incoming[:, keep_incoming]], dim=1)

    return merged_index, to_edge_attr
def merge_edges(
        edge_indices: List[torch.Tensor],
        edge_attrs: Optional[List[torch.Tensor]] = None,
        reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Concatenate several edge sets and coalesce the result.

    Args:
        edge_indices: Edge index tensors, concatenated along dim 1.
        edge_attrs: Optional edge attribute tensors, concatenated along
            dim 0. When ``None`` no attributes are passed on.
        reduce: Reduction mode forwarded unchanged to
            ``torch_geometric.utils.coalesce``.

    Returns:
        Whatever ``coalesce`` returns for the concatenated edge index and
        attributes (attributes are ``None`` when ``edge_attrs`` is ``None``).
    """
    combined_index = torch.cat(edge_indices, dim=1)
    combined_attr = None if edge_attrs is None else torch.cat(edge_attrs, dim=0)
    return coalesce(edge_index=combined_index, edge_attr=combined_attr, reduce=reduce)
def complete_graph(
        num_nodes: Union[int, Tuple[int, int]],
        ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        loop: bool = False,
        device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
    """Build the edge index of a fully connected (optionally batched) graph.

    Args:
        num_nodes: Either a single node count (source and destination share
            the same node set) or a ``(num_src, num_dst)`` pair. The counts
            are only used when ``ptr`` is ``None``; in both modes the *type*
            of this argument decides whether self-loops may be removed.
        ptr: Optional cumulative offsets delimiting the graphs of a batch:
            graph ``i`` spans ``ptr[i]`` to ``ptr[i + 1]``. A single tensor is
            used for both sides; a ``(ptr_src, ptr_dst)`` pair gives separate
            offsets for each side. When given, edges are only created
            within each graph.
        loop: Keep edges whose two endpoints carry the same index. Such edges
            are removed only when ``num_nodes`` is an ``int`` and ``loop`` is
            ``False``.
        device: Device of the result. When ``ptr`` is given and ``device`` is
            ``None``, the device of ``ptr`` is used so that the generated
            indices and the offsets live on the same device.

    Returns:
        A contiguous tensor with two rows (first and second endpoint) and one
        column per edge. A batch without any graph yields shape ``(2, 0)``.
    """
    if ptr is None:
        if isinstance(num_nodes, int):
            num_src, num_dst = num_nodes, num_nodes
        else:
            num_src, num_dst = num_nodes
        edge_index = torch.cartesian_prod(
            torch.arange(num_src, dtype=torch.long, device=device),
            torch.arange(num_dst, dtype=torch.long, device=device)).t()
    else:
        if isinstance(ptr, torch.Tensor):
            ptr_src, ptr_dst = ptr, ptr
        else:
            ptr_src, ptr_dst = ptr
        if device is None:
            # Follow ptr's device; otherwise the ranges below would be created
            # on the default device and could not be added to the offsets.
            device = ptr_src.device

        # Per-graph node counts as plain Python numbers.
        num_src_batch = (ptr_src[1:] - ptr_src[:-1]).tolist()
        num_dst_batch = (ptr_dst[1:] - ptr_dst[:-1]).tolist()
        # Row i holds the (source, destination) start offset of graph i.
        offsets = torch.stack([ptr_src.to(device), ptr_dst.to(device)], dim=1)

        # zip stops at the shortest input, so the trailing offset row (the end
        # of the last graph) is never used as a start offset.
        blocks = [
            torch.cartesian_prod(
                torch.arange(n_src, dtype=torch.long, device=device),
                torch.arange(n_dst, dtype=torch.long, device=device)) + offset
            for n_src, n_dst, offset in zip(num_src_batch, num_dst_batch, offsets)
        ]
        if blocks:
            edge_index = torch.cat(blocks, dim=0)
        else:
            # torch.cat rejects an empty list; an empty batch has no edges.
            edge_index = torch.empty((0, 2), dtype=torch.long, device=device)
        edge_index = edge_index.t()

    if isinstance(num_nodes, int) and not loop:
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    return edge_index.contiguous()
def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
    """Convert a dense adjacency tensor into a stacked index tensor.

    The coordinates of all non-zero entries of ``adj`` are stacked along
    dim 0. When ``adj`` has three dimensions, the leading dimension is
    treated as a batch index and folded into the other two: row indices are
    shifted by ``batch * adj.size(1)`` and column indices by
    ``batch * adj.size(2)``, giving a two-row result.

    Args:
        adj: Dense adjacency tensor; non-zero entries mark edges.

    Returns:
        The stacked coordinates of the non-zero entries, one column per entry.
    """
    coords = adj.nonzero(as_tuple=True)
    if len(coords) != 3:
        # Not batched: the coordinates are used as they are.
        return torch.stack(coords, dim=0)

    graph_id, row, col = coords
    # Each side has its own size, so the two offsets differ per graph.
    src = graph_id * adj.size(1) + row
    dst = graph_id * adj.size(2) + col
    return torch.stack((src, dst), dim=0)
def unbatch(
        src: torch.Tensor,
        batch: torch.Tensor,
        dim: int = 0) -> List[torch.Tensor]:
    """Split a batched tensor into one chunk per batch index.

    Chunk sizes come from ``torch_geometric.utils.degree`` applied to
    ``batch``, and ``src`` is cut into consecutive pieces of those sizes.
    NOTE(review): because the pieces are consecutive, this presumably expects
    the entries of ``batch`` to be grouped in ascending order - confirm
    against callers.

    Args:
        src: Tensor to split.
        batch: Batch index of every entry of ``src`` along ``dim``.
        dim: Dimension of ``src`` along which to split.

    Returns:
        The consecutive chunks of ``src``, one per batch index.
    """
    chunk_sizes = degree(batch, dtype=torch.long)
    return torch.split(src, chunk_sizes.tolist(), dim)