PolygonGNN / util.py
Dzy6's picture
init
e551dda
raw
history blame
4.46 kB
import torch
import pandas as pd
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import scipy
import torch.nn.functional as F
import torchvision
from sklearn.metrics import explained_variance_score,mean_squared_error,mean_absolute_error,r2_score,precision_score,recall_score,f1_score,roc_auc_score,roc_curve, auc,confusion_matrix
from sklearn.feature_selection import r_regression
from torch_sparse import SparseTensor
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from math import pi as PI
def scipy_spanning_tree(edge_index, edge_weight,num_nodes ):
row, col = edge_index.cpu()
edge_weight=edge_weight.cpu()
cgraph = csr_matrix((edge_weight, (row, col)), shape=(num_nodes, num_nodes))
Tcsr = minimum_spanning_tree(cgraph)
tree_row, tree_col = Tcsr.nonzero()
spanning_edges = np.stack([tree_row,tree_col],0)
return spanning_edges
def build_spanning_tree_edge(edge_index,edge_weight, num_nodes):
spanning_edges = scipy_spanning_tree(edge_index, edge_weight,num_nodes,)
spanning_edges = torch.tensor(spanning_edges, dtype=torch.long, device=edge_index.device)
spanning_edges_undirected = torch.cat([spanning_edges,torch.stack([spanning_edges[1],spanning_edges[0]])],1)
return spanning_edges_undirected
def record(values,epoch,writer,phase="Train"):
""" tfboard write """
for key,value in values.items():
writer.add_scalar(key+"/"+phase,value,epoch)
def calculate(y_hat,y_true,y_hat_logit):
""" calculate five metrics using y_hat, y_true, y_hat_logit """
train_acc=(np.array(y_hat) == np.array(y_true)).sum()/len(y_true)
# recall=recall_score(y_true, y_hat,zero_division=0,average='micro')
# precision=precision_score(y_true, y_hat,zero_division=0,average='micro')
# Fscore=f1_score(y_true, y_hat,zero_division=0,average='micro')
# roc=roc_auc_score(y_true, scipy.special.softmax(np.array(y_hat_logit),axis=1)[:,1],average='micro',multi_class='ovr')
# one_hot_encoded_labels = np.zeros((len(y_true), 100))
# one_hot_encoded_labels[np.arange(len(y_true)), y_true] = 1
# roc=roc_auc_score(one_hot_encoded_labels, scipy.special.softmax(np.array(y_hat_logit),axis=1),average='micro',multi_class='ovr')
return train_acc
def print_1(epoch,phase,values,color=None):
""" print epoch info"""
if color is not None:
print(color( f"epoch[{epoch:d}] {phase}"+ " ".join([f"{key}={value:.3f}" for key, value in values.items()]) ))
else:
print(( f"epoch[{epoch:d}] {phase}"+ " ".join([f"{key}={value:.3f}" for key, value in values.items()]) ))
def get_angle(v1, v2):
if v1.shape[1]==2:
v1=F.pad(v1, (0, 1),value=0)
if v2.shape[1]==2:
v2= F.pad(v2, (0, 1),value=0)
return torch.atan2( torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))
def get_theta(v1, v2):
# v1 is starting line, right-hand rule to v2, if thumb is up, +, else -
angle=get_angle(v1, v2)
if v1.shape[1]==2:
v1=F.pad(v1, (0, 1),value=0)
if v2.shape[1]==2:
v2= F.pad(v2, (0, 1),value=0)
v = torch.cross(v1, v2, dim=1)[...,2]
flag = torch.sign((v))
flag[flag==0]=-1
return angle*flag
def triplets(edge_index, num_nodes):
row, col = edge_index
value = torch.arange(row.size(0), device=row.device)
adj_t = SparseTensor(row=row, col=col, value=value,
sparse_sizes=(num_nodes, num_nodes))
adj_t_col = adj_t[:,row]
num_triplets = adj_t_col.set_value(None).sum(dim=0).to(torch.long)
idx_j = row.repeat_interleave(num_triplets)
idx_i = col.repeat_interleave(num_triplets)
edx_2nd = value.repeat_interleave(num_triplets)
idx_k = adj_t_col.t().storage.col()
edx_1st = adj_t_col.t().storage.value()
mask1 = (idx_i == idx_k) & (idx_j != idx_i) # Remove go back triplets.
mask2 = (idx_i == idx_j) & (idx_j != idx_k) # Remove repeat self loop triplets
mask3 = (idx_j == idx_k) & (idx_i != idx_k) # Remove self-loop neighbors
mask = ~(mask1 | mask2 | mask3)
idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask]
num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0)-1]
return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd