Spaces:
Running
Running
File size: 690 Bytes
e760df8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
"""Miscilaneous functions.
"""
import torch
def log_sum_exp(x, axis=None):
    """Compute ``log(sum(exp(x)))`` in a numerically stable way.

    The maximum along the reduced axis is subtracted before exponentiating
    and added back afterwards, so large inputs do not overflow.

    Args:
        x: Input tensor.
        axis: Axis over which to perform the sum. If ``None``, the
            reduction runs over all elements of ``x``.

    Returns:
        torch.Tensor: log sum exp, with the reduced axis removed (a 0-dim
        tensor when ``axis`` is ``None``).
    """
    if axis is None:
        # Reduce over every element by flattening to a single axis.
        x = x.reshape(-1)
        axis = 0

    # keepdim=True so the shift broadcasts along the axis being reduced,
    # whichever axis that is.
    x_max = torch.max(x, axis, keepdim=True)[0]

    # A non-finite maximum (e.g. a slice made entirely of -inf) would turn
    # ``x - x_max`` into NaN; shifting by zero in that case is harmless.
    shift = torch.where(
        torch.isfinite(x_max), x_max, torch.zeros_like(x_max)
    )

    y = torch.log(torch.exp(x - shift).sum(axis, keepdim=True)) + shift
    return y.squeeze(axis)
def random_permute(X):
    """Randomly permute a tensor along dim 0, independently per dim-2 index.

    For every index ``l`` along dim 2, an independent random permutation of
    the dim-0 indices is drawn and applied to the slice ``X[:, :, l]``. All
    entries along dim 1 at a given ``(i, l)`` move together.

    Args:
        X: Input tensor with at least 3 dimensions (presumably
            ``(batch, features, positions)`` -- confirm against callers).

    Returns:
        torch.Tensor: Tensor with the same shape and device as ``X``.
    """
    # Put the permuted axes first: (dim0, dim2, dim1, ...).
    X = X.transpose(1, 2)

    # Sorting random noise along dim 0 gives, for each column, a uniformly
    # random permutation of the dim-0 indices. The noise is created on X's
    # own device so the function works on both CPU and GPU tensors.
    noise = torch.rand((X.size(0), X.size(1)), device=X.device)
    row_idx = noise.sort(0)[1]

    # Column indices 0..n-1, broadcast against row_idx during indexing.
    col_idx = torch.arange(X.size(1), device=X.device)

    X = X[row_idx, col_idx[None, :]].transpose(1, 2)
    return X
|