Spaces:
Runtime error
Runtime error
'''Module for making resnet encoders. | |
''' | |
import torch | |
import torch.nn as nn | |
from cortex_DIM.nn_modules.convnet import Convnet | |
from cortex_DIM.nn_modules.misc import Fold, Unfold, View | |
# Index of the nonlinearity field inside a per-layer conv argument tuple.
# ResBlock reads and clears this slot on its last layer, and the extra tuple it
# appends, (None, 0, 0, 0, False, False, final_nonlin, None), puts the
# nonlinearity at this same position. The full field layout is defined by
# Convnet.create_conv_layers (not visible here) -- confirm there.
_nonlin_idx: int = 6
class ResBlock(Convnet):
    '''Residual block for ResNet.

    The block's convolutional layers are built by ``Convnet`` from
    ``conv_args``. The nonlinearity of the last layer is moved into a separate,
    final layer so that it is applied after the residual addition.

    '''

    def create_layers(self, shape, conv_args=None):
        '''Creates layers.

        Args:
            shape: Shape of input as (dim_x, dim_y, dim_in).
            conv_args: Layer arguments for block (a sequence of per-layer
                argument tuples). It is not modified; a copy is used.

        Raises:
            ValueError: If ``conv_args`` is empty, or if the block's output
                size cannot be matched by a strided 1x1 projection of the input.

        '''
        if not conv_args:
            raise ValueError('ResBlock requires at least one conv layer in conv_args.')

        # Work on a copy. Callers commonly reuse one conv_args list for every
        # block in a group (and across builds from a shared config). Mutating it
        # in place made each subsequent block accumulate an extra layer.
        conv_args = list(conv_args)

        # Move nonlinearity to a separate step for residual.
        last_layer = list(conv_args[-1])
        final_nonlin = last_layer[_nonlin_idx]
        last_layer[_nonlin_idx] = None
        conv_args[-1] = last_layer
        conv_args.append((None, 0, 0, 0, False, False, final_nonlin, None))

        super().create_layers(shape, conv_args=conv_args)

        # Compare as tuples so a list-typed ``shape`` does not force a
        # needless projection.
        if tuple(self.conv_shape) != tuple(shape):
            # Output shape differs from input: project the input with a strided
            # 1x1 conv + batch norm so it can be added to the block output.
            dim_x, dim_y, dim_in = shape
            dim_x_, dim_y_, dim_out = self.conv_shape
            stride = dim_x // dim_x_
            next_x, _ = self.next_size(dim_x, dim_y, 1, stride, 0)
            if next_x != dim_x_:
                raise ValueError(
                    'Cannot match ResBlock output shape {} from input shape {} '
                    'with a strided 1x1 projection.'.format(self.conv_shape, shape))
            self.downsample = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(dim_out),
            )
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        '''Forward pass.

        Args:
            x: Input.

        Returns:
            torch.Tensor: ``final_layer(layers(x) + residual)``, where the
                residual is ``x`` or, when shapes differ, its projection.

        '''
        residual = x if self.downsample is None else self.downsample(x)
        # All layers but the last are the conv stack; the last one holds the
        # nonlinearity moved out in create_layers.
        body = self.conv_layers[:-1](x)
        return self.conv_layers[-1](body + residual)
class ResNet(Convnet):
    '''ResNet encoder.

    Layers are applied in four stages: convolutional layers (``conv_before``),
    residual blocks (``res``), convolutional layers (``conv_after``), then,
    after flattening, fully-connected layers (``fc``).

    '''

    def create_layers(self, shape, conv_before_args=None, res_args=None, conv_after_args=None, fc_args=None):
        '''Creates layers.

        Args:
            shape: Shape of the input as (dim_x, dim_y, dim_in).
            conv_before_args: Arguments for convolutional layers before residuals.
            res_args: Residual args: a sequence of ``(conv_args, n_blocks)`` pairs.
            conv_after_args: Arguments for convolutional layers after residuals.
            fc_args: Fully-connected arguments.

        '''
        # Normalise to a tuple so downstream shape comparisons are tuple-to-tuple.
        dim_x, dim_y, dim_in = shape
        shape = (dim_x, dim_y, dim_in)

        self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)
        self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
        self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)

        # Flatten the final feature map before the fully-connected stage.
        dim_x, dim_y, dim_out = self.conv_after_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_res_layers(self, shape, block_args=None):
        '''Creates a set of residual blocks.

        Args:
            shape: input shape.
            block_args: Sequence of ``(conv_args, n_blocks)`` pairs. Each pair
                yields a group of ResBlocks built from ``conv_args``; the first
                block of every group is always built, even if ``n_blocks < 1``.

        Returns:
            tuple: (nn.Sequential of residual blocks, output shape).

        '''
        res_layers = nn.Sequential()
        for i, (conv_args, n_blocks) in enumerate(block_args or []):
            for j in range(max(n_blocks, 1)):
                # Give every block its own copy so no block can alter the
                # shared per-group configuration seen by the next one.
                block = ResBlock(shape, conv_args=list(conv_args))
                res_layers.add_module('block_{}_{}'.format(i, j), block)
                shape = block.conv_shape
        return res_layers, shape

    @staticmethod
    def _run_stage(layers, x, return_full_list):
        '''Runs ``x`` through one stage of layers.

        Returns:
            tuple: (stage output, tensor to feed the next stage). The stage
                output is the list of per-layer outputs when
                ``return_full_list`` is set, otherwise the final tensor.

        '''
        if not return_full_list:
            x = layers(x)
            return x, x
        outputs = []
        for layer in layers:
            x = layer(x)
            outputs.append(x)
        return outputs, x

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            tuple: (conv_before_out, res_out, conv_after_out, fc_out). Each is a
                torch.Tensor, or a list of per-layer torch.Tensor outputs when
                ``return_full_list`` is set.

        '''
        conv_before_out, x = self._run_stage(self.conv_before_layers, x, return_full_list)
        res_out, x = self._run_stage(self.res_layers, x, return_full_list)
        conv_after_out, x = self._run_stage(self.conv_after_layers, x, return_full_list)
        x = self.reshape(x)
        fc_out, _ = self._run_stage(self.fc_layers, x, return_full_list)
        return conv_before_out, res_out, conv_after_out, fc_out
class FoldedResNet(ResNet):
    '''Resnet with strided crop input.

    ``forward`` first splits the input with ``Unfold``. Whenever a layer's
    output has spatial size 1, it is passed through ``Fold``; layer shapes are
    then tracked as a ``final_size`` x ``final_size`` map. The shape
    bookkeeping in ``create_layers`` mirrors that rule for each stage.

    '''

    def create_layers(self, shape, crop_size=8, conv_before_args=None, res_args=None,
                      conv_after_args=None, fc_args=None):
        '''Creates layers.

        Args:
            shape: Shape of the input as (dim_x, dim_y, dim_in).
            crop_size: Size of the crops.
            conv_before_args: Arguments for convolutional layers before residuals.
            res_args: Residual args: a sequence of ``(conv_args, n_blocks)`` pairs.
            conv_after_args: Arguments for convolutional layers after residuals.
            fc_args: Fully-connected arguments.

        '''
        self.crop_size = crop_size

        dim_x, dim_y, dim_in = shape
        # NOTE(review): presumably the number of (half-overlapping) crops per
        # side produced by Unfold -- confirm against Unfold/Fold.
        self.final_size = 2 * (dim_x // self.crop_size) - 1

        self.unfold = Unfold(dim_x, self.crop_size)
        self.refold = Fold(dim_x, self.crop_size)

        # Layers operate on individual crops until they are refolded.
        shape = (self.crop_size, self.crop_size, dim_in)

        # forward() only refolds after a layer runs, so an empty stage keeps
        # its input shape.
        self.conv_before_layers, conv_before_shape = self.create_conv_layers(shape, conv_before_args)
        if len(self.conv_before_layers):
            conv_before_shape = self._refolded_shape(conv_before_shape)
        self.conv_before_shape = conv_before_shape

        self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)

        # Use the conv_after stage's own output shape. Overwriting it with
        # res_shape gave a wrong flatten size whenever conv_after changed the
        # channel count.
        self.conv_after_layers, conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)
        if len(self.conv_after_layers):
            conv_after_shape = self._refolded_shape(conv_after_shape)
        self.conv_after_shape = conv_after_shape

        dim_x, dim_y, dim_out = self.conv_after_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def _refolded_shape(self, shape):
        '''Returns ``shape`` as it is after forward()'s refold step.

        A spatial size of 1 triggers ``self.refold`` in ``forward``; such a
        shape maps to ``(final_size, final_size, channels)``. Any other shape
        is returned unchanged.

        '''
        if shape[0] == 1:
            return (self.final_size, self.final_size, shape[2])
        return shape

    def create_res_layers(self, shape, block_args=None):
        '''Creates a set of residual blocks.

        Args:
            shape: input shape.
            block_args: Sequence of ``(conv_args, n_blocks)`` pairs.

        Returns:
            tuple: (nn.Sequential of residual blocks, output shape). A 1x1
                output shape is reported as ``final_size`` x ``final_size``,
                matching the refold in ``forward``.

        Raises:
            ValueError: If the resulting spatial dimensions are not square.

        '''
        res_layers, shape = super().create_res_layers(shape, block_args)

        dim_x, dim_y = shape[:2]
        if dim_x != dim_y:
            raise ValueError('dim_x and dim_y do not match.')

        return res_layers, self._refolded_shape(shape)

    def _run_folded_stage(self, layers, x):
        '''Runs ``x`` through ``layers``, refolding any output of spatial size 1.

        Returns:
            tuple: (list of per-layer outputs, tensor to feed the next stage).
                With no layers, the list is empty and ``x`` is returned as is.

        '''
        outputs = []
        for layer in layers:
            x = layer(x)
            if x.size(2) == 1:
                x = self.refold(x)
            outputs.append(x)
        return outputs, x

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            tuple: (conv_before_out, res_out, conv_after_out, fc_out). With
                ``return_full_list`` each entry is a list of per-layer outputs.
                Otherwise each is its stage's final output, and a stage with no
                layers passes its input through unchanged.

        '''
        x = self.unfold(x)

        conv_before_out, x = self._run_folded_stage(self.conv_before_layers, x)
        conv_before_last = x
        res_out, x = self._run_folded_stage(self.res_layers, x)
        res_last = x
        conv_after_out, x = self._run_folded_stage(self.conv_after_layers, x)
        conv_after_last = x

        x = self.reshape(x)

        if not return_full_list:
            return conv_before_last, res_last, conv_after_last, self.fc_layers(x)

        fc_out = []
        for fc_layer in self.fc_layers:
            x = fc_layer(x)
            fc_out.append(x)
        return conv_before_out, res_out, conv_after_out, fc_out