'''Module for making resnet encoders.

'''

import torch
import torch.nn as nn

from cortex_DIM.nn_modules.convnet import Convnet
from cortex_DIM.nn_modules.misc import Fold, Unfold, View

# Index of the nonlinearity entry inside a conv layer-args tuple.
_nonlin_idx = 6


class ResBlock(Convnet):
    '''Residual block for ResNet

    '''

    def create_layers(self, shape, conv_args=None):
        '''Creates layers

        Args:
            shape: Shape of input.
            conv_args: Layer arguments for block.
        '''
        # Work on a copy: callers (create_res_layers) reuse the same
        # conv_args list for every block, and the in-place append below
        # would otherwise accumulate an extra nonlinearity entry per block.
        conv_args = list(conv_args)

        # Move nonlinearity to a separate step for residual.
        final_nonlin = conv_args[-1][_nonlin_idx]
        conv_args[-1] = list(conv_args[-1])
        conv_args[-1][_nonlin_idx] = None
        conv_args.append((None, 0, 0, 0, False, False, final_nonlin, None))

        super().create_layers(shape, conv_args=conv_args)

        if self.conv_shape != shape:
            # Input and output shapes differ: project the input with a
            # strided 1x1 conv + batchnorm so the residual add is valid.
            dim_x, dim_y, dim_in = shape
            dim_x_, dim_y_, dim_out = self.conv_shape
            stride = dim_x // dim_x_

            next_x, _ = self.next_size(dim_x, dim_y, 1, stride, 0)
            assert next_x == dim_x_, (self.conv_shape, shape)

            self.downsample = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=stride,
                          padding=0, bias=False),
                nn.BatchNorm2d(dim_out),
            )
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        '''Forward pass

        Args:
            x: Input.

        Returns:
            torch.Tensor or list of torch.Tensor.
        '''
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        # The last module is the detached nonlinearity; apply it after
        # adding the residual to the output of the preceding layers.
        x = self.conv_layers[-1](self.conv_layers[:-1](x) + residual)

        return x


class ResNet(Convnet):
    '''Resnet encoder: conv layers, then residual blocks, then conv layers,
    then fully-connected layers.

    '''

    def create_layers(self, shape, conv_before_args=None, res_args=None,
                      conv_after_args=None, fc_args=None):
        '''Creates layers

        Args:
            shape: Shape of the input.
            conv_before_args: Arguments for convolutional layers before residuals.
            res_args: Residual args.
            conv_after_args: Arguments for convolutional layers after residuals.
            fc_args: Fully-connected arguments.

        '''
        dim_x, dim_y, dim_in = shape
        shape = (dim_x, dim_y, dim_in)

        self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(
            shape, conv_before_args)
        self.res_layers, self.res_shape = self.create_res_layers(
            self.conv_before_shape, res_args)
        self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(
            self.res_shape, conv_after_args)

        # Flatten the final conv feature map before the fully-connected layers.
        dim_x, dim_y, dim_out = self.conv_after_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_res_layers(self, shape, block_args=None):
        '''Creates a set of residual blocks.

        Args:
            shape: input shape.
            block_args: Arguments for blocks.

        Returns:
            nn.Sequential: sequence of residual blocks.

        '''
        res_layers = nn.Sequential()
        block_args = block_args or []

        for i, (conv_args, n_blocks) in enumerate(block_args):
            block = ResBlock(shape, conv_args=conv_args)
            res_layers.add_module('block_{}_0'.format(i), block)

            # Subsequent blocks in the group chain from the previous
            # block's output shape.
            for j in range(1, n_blocks):
                shape = block.conv_shape
                block = ResBlock(shape, conv_args=conv_args)
                res_layers.add_module('block_{}_{}'.format(i, j), block)
            shape = block.conv_shape

        return res_layers, shape

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            torch.Tensor or list of torch.Tensor.

        '''
        if return_full_list:
            conv_before_out = []
            for conv_layer in self.conv_before_layers:
                x = conv_layer(x)
                conv_before_out.append(x)
        else:
            # Fixed: was `self.conv_layers(x)`, an attribute this class
            # never creates (create_layers sets conv_before_layers).
            conv_before_out = self.conv_before_layers(x)
        x = conv_before_out

        if return_full_list:
            res_out = []
            for res_layer in self.res_layers:
                x = res_layer(x)
                res_out.append(x)
        else:
            res_out = self.res_layers(x)
        x = res_out

        if return_full_list:
            conv_after_out = []
            for conv_layer in self.conv_after_layers:
                x = conv_layer(x)
                conv_after_out.append(x)
        else:
            conv_after_out = self.conv_after_layers(x)
        x = conv_after_out

        x = self.reshape(x)

        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)

        return conv_before_out, res_out, conv_after_out, fc_out


class FoldedResNet(ResNet):
    '''Resnet with strided crop input.

    '''

    def create_layers(self, shape, crop_size=8, conv_before_args=None,
                      res_args=None, conv_after_args=None, fc_args=None):
        '''Creates layers

        Args:
            shape: Shape of the input.
            crop_size: Size of the crops.
            conv_before_args: Arguments for convolutional layers before residuals.
            res_args: Residual args.
            conv_after_args: Arguments for convolutional layers after residuals.
            fc_args: Fully-connected arguments.

        '''
        self.crop_size = crop_size
        dim_x, dim_y, dim_in = shape

        # Number of overlapping crops per spatial dimension (stride is
        # half the crop size, hence 2 * k - 1).
        self.final_size = 2 * (dim_x // self.crop_size) - 1

        self.unfold = Unfold(dim_x, self.crop_size)
        self.refold = Fold(dim_x, self.crop_size)

        # Layers are built for a single crop, not the full image.
        shape = (self.crop_size, self.crop_size, dim_in)
        self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(
            shape, conv_before_args)
        self.res_layers, self.res_shape = self.create_res_layers(
            self.conv_before_shape, res_args)
        self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(
            self.res_shape, conv_after_args)
        # NOTE(review): this overwrite discards the shape returned by
        # create_conv_layers above; it looks intentional only if the
        # conv_after layers are shape-preserving — confirm before changing.
        self.conv_after_shape = self.res_shape

        dim_x, dim_y, dim_out = self.conv_after_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_res_layers(self, shape, block_args=None):
        '''Creates a set of residual blocks.

        Args:
            shape: input shape.
            block_args: Arguments for blocks.

        Returns:
            nn.Sequential: sequence of residual blocks.

        '''
        res_layers = nn.Sequential()
        block_args = block_args or []

        for i, (conv_args, n_blocks) in enumerate(block_args):
            block = ResBlock(shape, conv_args=conv_args)
            res_layers.add_module('block_{}_0'.format(i), block)

            for j in range(1, n_blocks):
                shape = block.conv_shape
                block = ResBlock(shape, conv_args=conv_args)
                res_layers.add_module('block_{}_{}'.format(i, j), block)
            shape = block.conv_shape

        dim_x, dim_y = shape[:2]
        if dim_x != dim_y:
            raise ValueError('dim_x and dim_y do not match.')

        # Once a crop is reduced to 1x1 it gets refolded back onto the
        # crop grid, so the effective output shape is final_size^2.
        if dim_x == 1:
            shape = (self.final_size, self.final_size, shape[2])

        return res_layers, shape

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            torch.Tensor or list of torch.Tensor.

        '''
        x = self.unfold(x)

        conv_before_out = []
        for conv_layer in self.conv_before_layers:
            x = conv_layer(x)
            # A crop reduced to 1x1 is refolded back onto the crop grid.
            if x.size(2) == 1:
                x = self.refold(x)
            conv_before_out.append(x)

        res_out = []
        for res_layer in self.res_layers:
            x = res_layer(x)
            res_out.append(x)
            if x.size(2) == 1:
                x = self.refold(x)
                res_out[-1] = x

        conv_after_out = []
        for conv_layer in self.conv_after_layers:
            x = conv_layer(x)
            if x.size(2) == 1:
                x = self.refold(x)
            conv_after_out.append(x)

        x = self.reshape(x)

        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)

        if not return_full_list:
            conv_before_out = conv_before_out[-1]
            res_out = res_out[-1]
            conv_after_out = conv_after_out[-1]

        return conv_before_out, res_out, conv_after_out, fc_out