LuyangZ's picture
Upload 34 files
e760df8 verified
raw
history blame contribute delete
690 Bytes
"""Miscellaneous functions.
"""
import torch
def log_sum_exp(x, axis=None):
    """Numerically stable log-sum-exp.

    Computes ``log(sum(exp(x)))`` with ``torch.logsumexp``, which shifts by
    the maximum internally so large inputs do not overflow.

    Args:
        x: Input tensor (floating point).
        axis: Dimension to reduce over. ``None`` reduces over every element
            and returns a 0-d tensor.

    Returns:
        torch.Tensor: log sum exp of ``x`` with ``axis`` removed.
    """
    if axis is None:
        # torch.logsumexp needs an explicit dim; flatten to reduce over all
        # elements (the previous torch.max(x, None) call raised here).
        return torch.logsumexp(x.reshape(-1), dim=0)
    # The earlier hand-rolled version subtracted a max taken without
    # keepdim, which only broadcast correctly for axis=0; logsumexp reduces
    # along any (including negative) dim and keeps all -inf slices at -inf.
    return torch.logsumexp(x, dim=axis)
def random_permute(X):
    """Randomly permute a tensor along its first dimension.

    An independent random permutation of dim 0 is drawn for every index of
    dim 2. For a given (dim-0, dim-2) position, all entries along dim 1 (and
    any trailing dims) move together. Requires ``X`` to have at least 3 dims
    (``transpose(1, 2)``).

    Args:
        X: Input tensor, e.g. of shape (B, C, N) -- TODO confirm the
            intended layout with callers.

    Returns:
        torch.Tensor: Tensor of the same shape and device as ``X``.
    """
    X = X.transpose(1, 2)
    batch, positions = X.size(0), X.size(1)
    # Create everything on X's device instead of hard-coding .cuda(), so the
    # function works on CPU and index tensors never straddle devices.
    noise = torch.rand((batch, positions), device=X.device)
    # Arg-sorting i.i.d. uniform noise along dim 0 yields one random
    # permutation of the batch indices per column.
    idx = noise.sort(0)[1]
    # torch.arange replaces the deprecated torch.range(0, n - 1).long().
    cols = torch.arange(positions, device=X.device)
    # Advanced indexing: out[i, j] = X[idx[i, j], j]; trailing dims ride along.
    X = X[idx, cols[None, :]].transpose(1, 2)
    return X