import math

import torch
import torch.nn as nn


class PositionalEncoding_RGB(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout=0.1, max_len=50):
        super(PositionalEncoding_RGB, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # The encoding is a registered buffer, not a parameter, so it never
        # needs gradients; detach() replaces the deprecated
        # torch.autograd.Variable wrapper used in older PyTorch versions.
        x = x + self.pe[:, :x.size(1)].detach()
        return self.dropout(x)
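

# Illustrative usage sketch (not part of the original file; the tensor sizes
# below are assumptions for demonstration only): the encoding is added along
# the sequence dimension of a (batch, seq_len, d_model) feature tensor.
#
#   pos_enc = PositionalEncoding_RGB(d_model=512, dropout=0.1, max_len=50)
#   feats = torch.randn(8, 30, 512)   # 8 clips, 30 frames, 512-dim features
#   feats = pos_enc(feats)            # same shape, PE added, dropout applied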
def calc_receptive_field(layers, imsize, layer_names=None):
    # Walk through a list of layer specifications and track
    # [output size, jump (effective stride), receptive field, start offset].
    if layer_names is not None:
        print("-------Net summary------")
    currentLayer = [imsize, 1, 1, 0.5]
    for l_id, layer in enumerate(layers):
        # When a spec entry is a list/tuple (e.g. 3-D kernels), use its last
        # entry; otherwise take the scalar value directly.
        conv = [
            layer[key][-1] if type(layer[key]) in [list, tuple] else layer[key]
            for key in ['kernel_size', 'stride', 'padding']
        ]
        currentLayer = outFromIn(conv, currentLayer)
        if 'maxpool' in layer:
            # Max-pool padding defaults to 0 when the spec omits it.
            conv = [
                (layer['maxpool'][key][-1] if type(layer['maxpool'][key])
                 in [list, tuple] else layer['maxpool'][key]) if
                (not key == 'padding' or 'padding' in layer['maxpool']) else 0
                for key in ['kernel_size', 'stride', 'padding']
            ]
            currentLayer = outFromIn(conv, currentLayer, ceil_mode=False)
    return currentLayer
def outFromIn(conv, layerIn, ceil_mode=True):
    # Receptive-field bookkeeping for a single conv/pool layer:
    #   n: number of output features, j: jump (cumulative stride),
    #   r: receptive field size, start: center of the first output feature.
    n_in = layerIn[0]
    j_in = layerIn[1]
    r_in = layerIn[2]
    start_in = layerIn[3]
    k = conv[0]
    s = conv[1]
    p = conv[2]

    # NOTE: ceil_mode is accepted but not used; the floor formula below
    # matches PyTorch's default output-size computation.
    n_out = math.floor((n_in - k + 2 * p) / s) + 1
    actualP = (n_out - 1) * s - n_in + k
    pR = math.ceil(actualP / 2)
    pL = math.floor(actualP / 2)
    j_out = j_in * s
    r_out = r_in + (k - 1) * j_in
    start_out = start_in + ((k - 1) / 2 - pL) * j_in
    return n_out, j_out, r_out, start_out
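

# Illustrative sketch (the layer values below are assumptions, not from the
# original file): compute the receptive field of two conv+pool stages on a
# 112-pixel input. Each entry mirrors the spec format consumed by VGGNet below.
#
#   layers = [
#       {'type': 'conv2d', 'n_channels': 64, 'kernel_size': 3, 'stride': 1,
#        'padding': 1, 'maxpool': {'kernel_size': 2, 'stride': 2}},
#       {'type': 'conv2d', 'n_channels': 128, 'kernel_size': 3, 'stride': 1,
#        'padding': 1, 'maxpool': {'kernel_size': 2, 'stride': 2}},
#   ]
#   n, j, r, start = calc_receptive_field(layers, imsize=112)
#   # n: output size, j: effective stride, r: receptive field in input pixels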
class DebugModule(nn.Module):
    """
    Wrapper class for printing the activation dimensions
    """

    def __init__(self, name=None):
        super().__init__()
        self.name = name
        self.debug_log = True

    def debug_line(self, layer_str, output, memuse=1, final_call=False):
        if self.debug_log:
            namestr = '{}: '.format(self.name) if self.name is not None else ''
            # print('{}{:80s}: dims {}'.format(namestr, repr(layer_str),
            #                                  output.shape))
            if final_call:
                self.debug_log = False
                # print()
class VGGNet(DebugModule):

    conv_dict = {
        'conv1d': nn.Conv1d,
        'conv2d': nn.Conv2d,
        'conv3d': nn.Conv3d,
        'fc1d': nn.Conv1d,
        'fc2d': nn.Conv2d,
        'fc3d': nn.Conv3d,
    }
    pool_dict = {
        'conv1d': nn.MaxPool1d,
        'conv2d': nn.MaxPool2d,
        'conv3d': nn.MaxPool3d,
    }
    norm_dict = {
        'conv1d': nn.BatchNorm1d,
        'conv2d': nn.BatchNorm2d,
        'conv3d': nn.BatchNorm3d,
        'fc1d': nn.BatchNorm1d,
        'fc2d': nn.BatchNorm2d,
        'fc3d': nn.BatchNorm3d,
    }
    def __init__(self, n_channels_in, layers):
        super(VGGNet, self).__init__()
        self.layers = layers

        n_channels_prev = n_channels_in
        for l_id, lr in enumerate(self.layers):
            l_id += 1  # layer names are 1-indexed
            name = 'fc' if 'fc' in lr['type'] else 'conv'
            conv_type = self.conv_dict[lr['type']]
            norm_type = self.norm_dict[lr['type']]
            # Convolution (or 1x1 "fully connected" convolution) for this stage.
            self.__setattr__(
                '{:s}{:d}'.format(name, l_id),
                conv_type(n_channels_prev,
                          lr['n_channels'],
                          kernel_size=lr['kernel_size'],
                          stride=lr['stride'],
                          padding=lr['padding']))
            n_channels_prev = lr['n_channels']
            self.__setattr__('bn{:d}'.format(l_id), norm_type(lr['n_channels']))
            if 'maxpool' in lr:
                pool_type = self.pool_dict[lr['type']]
                padding = lr['maxpool']['padding'] if 'padding' in lr[
                    'maxpool'] else 0
                self.__setattr__(
                    'mp{:d}'.format(l_id),
                    pool_type(kernel_size=lr['maxpool']['kernel_size'],
                              stride=lr['maxpool']['stride'],
                              padding=padding),
                )
    def forward(self, inp):
        self.debug_line('Input', inp)
        out = inp
        for l_id, lr in enumerate(self.layers):
            l_id += 1
            name = 'fc' if 'fc' in lr['type'] else 'conv'
            # conv/fc -> batch norm -> ReLU, then optional max-pooling.
            out = self.__getattr__('{:s}{:d}'.format(name, l_id))(out)
            out = self.__getattr__('bn{:d}'.format(l_id))(out)
            out = nn.ReLU(inplace=True)(out)
            self.debug_line(self.__getattr__('{:s}{:d}'.format(name, l_id)),
                            out)
            if 'maxpool' in lr:
                out = self.__getattr__('mp{:d}'.format(l_id))(out)
                self.debug_line(self.__getattr__('mp{:d}'.format(l_id)), out)

        self.debug_line('Output', out, final_call=True)
        return out
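

# Illustrative sketch (the layer values are assumptions, not from the original
# file): build a small 2D VGG-style stack from the same spec format used with
# calc_receptive_field above.
#
#   spec = [
#       {'type': 'conv2d', 'n_channels': 64, 'kernel_size': 3, 'stride': 1,
#        'padding': 1, 'maxpool': {'kernel_size': 2, 'stride': 2}},
#       {'type': 'conv2d', 'n_channels': 128, 'kernel_size': 3, 'stride': 1,
#        'padding': 1, 'maxpool': {'kernel_size': 2, 'stride': 2}},
#   ]
#   net = VGGNet(n_channels_in=3, layers=spec)
#   out = net(torch.randn(1, 3, 112, 112))   # -> (1, 128, 28, 28)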
class NetFC(DebugModule):

    def __init__(self, input_dim, hidden_dim, embed_dim):
        super(NetFC, self).__init__()
        self.fc7 = nn.Conv3d(input_dim, hidden_dim, kernel_size=(1, 1, 1))
        self.bn7 = nn.BatchNorm3d(hidden_dim)
        self.fc8 = nn.Conv3d(hidden_dim, embed_dim, kernel_size=(1, 1, 1))

    def forward(self, inp):
        out = self.fc7(inp)
        self.debug_line(self.fc7, out)
        out = self.bn7(out)
        out = nn.ReLU(inplace=True)(out)
        out = self.fc8(out)
        self.debug_line(self.fc8, out, final_call=True)
        return out
class NetFC_2D(DebugModule):

    def __init__(self, input_dim, hidden_dim, embed_dim):
        super(NetFC_2D, self).__init__()
        self.fc7 = nn.Conv2d(input_dim, hidden_dim, kernel_size=(1, 1))
        self.bn7 = nn.BatchNorm2d(hidden_dim)
        self.fc8 = nn.Conv2d(hidden_dim, embed_dim, kernel_size=(1, 1))

    def forward(self, inp):
        out = self.fc7(inp)
        self.debug_line(self.fc7, out)
        out = self.bn7(out)
        out = nn.ReLU(inplace=True)(out)
        out = self.fc8(out)
        self.debug_line(self.fc8, out, final_call=True)
        return out
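

# Illustrative sketch (dimensions are assumptions): both NetFC heads act as
# 1x1 convolutional "fully connected" projections over feature maps, keeping
# the spatial/temporal layout while remapping channels to the embedding size.
#
#   head = NetFC_2D(input_dim=512, hidden_dim=512, embed_dim=256)
#   emb = head(torch.randn(1, 512, 7, 7))   # -> (1, 256, 7, 7)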