import numpy as np
import torch
import torch.nn as nn


def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)
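

# Illustrative sketch (values made up, not part of the library's API docs):
# for a 1-D tensor of phases, get_emb interleaves sin and cos values, e.g.
#   get_emb(torch.tensor([0.0, 1.0]))
#   -> tensor([0.0000, 1.0000, 0.8415, 0.5403])  # [sin(0), cos(0), sin(1), cos(1)]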


class PositionalEncoding1D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc
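

# Usage sketch (shapes chosen arbitrarily for illustration):
#   p_enc_1d = PositionalEncoding1D(8)
#   out = p_enc_1d(torch.rand(2, 16, 8))  # -> (2, 16, 8), same encoding on every batch element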


class PositionalEncodingPermute1D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x) instead of (batchsize, x, ch)
        """
        super(PositionalEncodingPermute1D, self).__init__()
        self.penc = PositionalEncoding1D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 2, 1)

    @property
    def org_channels(self):
        return self.penc.org_channels
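

# Usage sketch for the channel-first variant (shapes are illustrative):
#   p_enc_1d_cf = PositionalEncodingPermute1D(8)
#   out = p_enc_1d_cf(torch.rand(2, 8, 16))  # -> (2, 8, 16)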


class PositionalEncoding2D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)
        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y

        self.cached_penc = emb[None, :, :, :orig_ch].repeat(batch_size, 1, 1, 1)
        return self.cached_penc
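

# Usage sketch (assumed shapes; channels need not be divisible by 4):
#   p_enc_2d = PositionalEncoding2D(10)
#   out = p_enc_2d(torch.rand(1, 8, 8, 10))  # -> (1, 8, 8, 10)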


class PositionalEncodingPermute2D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch)
        """
        super(PositionalEncodingPermute2D, self).__init__()
        self.penc = PositionalEncoding2D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 3, 1, 2)

    @property
    def org_channels(self):
        return self.penc.org_channels
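

# Channel-first 2D variant (illustrative shapes):
#   p_enc_2d_cf = PositionalEncodingPermute2D(10)
#   out = p_enc_2d_cf(torch.rand(1, 10, 8, 8))  # -> (1, 10, 8, 8)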


class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_z = get_emb(sin_inp_z)
        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, :, : self.channels] = emb_x
        emb[:, :, :, self.channels : 2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels :] = emb_z

        self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
        return self.cached_penc
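

# Usage sketch (assumed shapes):
#   p_enc_3d = PositionalEncoding3D(12)
#   out = p_enc_3d(torch.rand(1, 4, 4, 4, 12))  # -> (1, 4, 4, 4, 12)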


class PositionalEncodingPermute3D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 4, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 4, 1, 2, 3)

    @property
    def org_channels(self):
        return self.penc.org_channels
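

# Channel-first 3D variant (illustrative shapes):
#   p_enc_3d_cf = PositionalEncodingPermute3D(12)
#   out = p_enc_3d_cf(torch.rand(1, 12, 4, 4, 4))  # -> (1, 12, 4, 4, 4)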


class Summer(nn.Module):
    def __init__(self, penc):
        """
        :param penc: The positional encoding module to run the summer on.
        """
        super(Summer, self).__init__()
        self.penc = penc

    def forward(self, tensor):
        """
        :param tensor: A 3, 4 or 5d tensor that matches the model output size
        :return: The original tensor with its positional encoding added to it
        """
        penc = self.penc(tensor)
        assert (
            tensor.size() == penc.size()
        ), "The original tensor size {} and the positional encoding tensor size {} must match!".format(
            tensor.size(), penc.size()
        )
        return tensor + penc
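

# Usage sketch (wrapping the 2D encoder; shapes are illustrative):
#   summer = Summer(PositionalEncoding2D(10))
#   out = summer(torch.rand(1, 8, 8, 10))  # input plus its positional encoding, same shape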


class SparsePositionalEncoding2D(PositionalEncoding2D):
    def __init__(self, channels, x, y, device='cuda'):
        super(SparsePositionalEncoding2D, self).__init__(channels)
        self.y, self.x = y, x
        self.fake_tensor = torch.zeros((1, x, y, channels), device=device)

    def forward(self, coords):
        """
        :param coords: A list of lists of (x, y) coordinates, e.g. (((x1, y1), (x2, y2), ...), ...)
        :return: The positional encodings sampled at the given coordinates,
                 of size (max_seq_len, num_seqs, ch)
        """
        encodings = super().forward(self.fake_tensor)
        encodings = encodings.permute(0, 3, 1, 2)
        indices = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(c) for c in coords], batch_first=True, padding_value=-1)
        indices = indices.unsqueeze(0).to(self.fake_tensor.device)
        # Both coordinates are normalized by self.x, so the grid must be square.
        assert self.x == self.y
        indices = (indices + 0.5) / self.x * 2 - 1
        # grid_sample expects the last grid dimension ordered as (width, height), so flip (x, y).
        indices = torch.flip(indices, (-1, ))
        return torch.nn.functional.grid_sample(encodings, indices, align_corners=False).squeeze().permute(2, 1, 0)
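

# Usage sketch (assumes a square grid and CPU device; coordinates are made up):
#   sparse_enc = SparsePositionalEncoding2D(10, 20, 20, device='cpu')
#   out = sparse_enc([[(0, 0), (3, 7)], [(5, 5), (9, 15)]])  # -> (2, 2, 10)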


if __name__ == '__main__':
    # Simple smoke test on CPU. The grid must be square (x == y), and the
    # coordinates are passed as a batch of coordinate sequences.
    pos = SparsePositionalEncoding2D(10, 20, 20, device='cpu')
    print(pos([[(0, 0), (0, 9)], [(1, 0), (9, 15)]]).shape)