# Source upload metadata (was raw webpage residue, which broke Python syntax):
# uploaded by LuyangZ, "Upload 30 files", commit 01df1d6 (verified).
'''Convnet encoder module.
'''
import torch
import torch.nn as nn
#from cortex.built_ins.networks.utils import get_nonlinearity
from cortex_DIM.nn_modules.misc import Fold, Unfold, View
def infer_conv_size(w, k, s, p):
    '''Infers the output size of a convolution along one dimension.

    Standard convolution arithmetic: floor((w + 2p - k) / s) + 1.

    Args:
        w: Input size.
        k: Kernel size.
        s: Stride.
        p: Padding.

    Returns:
        int: Output size.
    '''
    return (w + 2 * p - k) // s + 1
class Convnet(nn.Module):
    '''Basic convnet convenience class.

    Builds a stack of convolutional blocks followed by a stack of
    fully-connected blocks from compact per-layer argument tuples.

    Attributes:
        conv_layers: nn.Sequential of nn.Conv2d layers with batch norm,
            dropout, nonlinearity.
        fc_layers: nn.Sequential of nn.Linear layers with batch norm,
            dropout, nonlinearity.
        reshape: Simple reshape layer.
        conv_shape: Shape of the convolutional output.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.create_layers(*args, **kwargs)

    @staticmethod
    def _resolve_nonlinearity(nonlinearity):
        '''Resolves a nonlinearity spec into an nn.Module instance.

        Stands in for the commented-out cortex ``get_nonlinearity`` helper;
        without this, any truthy nonlinearity spec raised a NameError.

        Args:
            nonlinearity: An nn.Module instance (returned unchanged), the
                name of a class in torch.nn (e.g. 'ReLU'), or a
                zero-argument callable producing a module.

        Returns:
            nn.Module: The resolved nonlinearity module.

        Raises:
            ValueError: If the spec cannot be resolved.
        '''
        if isinstance(nonlinearity, nn.Module):
            return nonlinearity
        if isinstance(nonlinearity, str):
            return getattr(nn, nonlinearity)()
        if callable(nonlinearity):
            return nonlinearity()
        raise ValueError('Cannot resolve nonlinearity: {}'.format(nonlinearity))

    def create_layers(self, shape, conv_args=None, fc_args=None):
        '''Creates layers

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Shape of input, as (dim_x, dim_y, channels).
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.
        '''
        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        # Flatten the conv output before the fully-connected stack.
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Each tuple in conv_args yields one block in the order:
        conv -> dropout -> batch norm -> nonlinearity -> pool
        (each component optional).

        Args:
            shape: Input shape as (dim_x, dim_y, channels).
            conv_args: List of tuple of convolutional arguments.

        Returns:
            (nn.Sequential, tuple): The conv layers and the output shape
                (dim_x, dim_y, channels).
        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []
        dim_x, dim_y, dim_in = shape
        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()
            if dim_out is not None:
                # Bias is redundant when the conv is followed by batch norm.
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm))
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                # No conv in this block: channels pass through unchanged.
                dim_out = dim_in
            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm2d(dim_out)
                conv_block.add_module(name + 'bn', bn)
            if nonlinearity:
                nonlinearity = self._resolve_nonlinearity(nonlinearity)
                conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
            if pool:
                (pool_type, kernel, stride) = pool
                Pool = getattr(nn, pool_type)
                conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
            conv_layers.add_module(name, conv_block)
            dim_in = dim_out
        dim_out = dim_in  # Covers empty conv_args: output channels = input channels.
        return conv_layers, (dim_x, dim_y, dim_out)

    def create_linear_layers(self, dim_in, fc_args):
        '''Creates a set of fully-connected layers.

        Args:
            dim_in: Number of input units.
            fc_args: List of tuple of fully-connected arguments.

        Returns:
            (nn.Sequential, int): The fc layers and the output dimension.
        '''
        fc_layers = nn.Sequential()
        fc_args = fc_args or []
        for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            fc_block = nn.Sequential()
            if dim_out is not None:
                fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out))
            else:
                # No linear layer: dimension passes through unchanged.
                dim_out = dim_in
            if dropout:
                fc_block.add_module(name + 'do', nn.Dropout(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm1d(dim_out)
                fc_block.add_module(name + 'bn', bn)
            if nonlinearity:
                nonlinearity = self._resolve_nonlinearity(nonlinearity)
                fc_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
            fc_layers.add_module(name, fc_block)
            dim_in = dim_out
        return fc_layers, dim_in

    def next_size(self, dim_x, dim_y, k, s, p):
        '''Infers the next size of a convolutional layer.

        Args:
            dim_x: First dimension.
            dim_y: Second dimension.
            k: Kernel size (int or (kx, ky)).
            s: Stride (int or (sx, sy)).
            p: Padding (int or (px, py)).

        Returns:
            (int, int): (First output dimension, Second output dimension)
        '''
        if isinstance(k, int):
            kx, ky = (k, k)
        else:
            kx, ky = k
        if isinstance(s, int):
            sx, sy = (s, s)
        else:
            sx, sy = s
        if isinstance(p, int):
            px, py = (p, p)
        else:
            px, py = p
        return (infer_conv_size(dim_x, kx, sx, px),
                infer_conv_size(dim_y, ky, sy, py))

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            (torch.Tensor or list, torch.Tensor or list): Conv and fc
                outputs (lists of per-layer outputs if return_full_list).
        '''
        if return_full_list:
            conv_out = []
            for conv_layer in self.conv_layers:
                x = conv_layer(x)
                conv_out.append(x)
        else:
            conv_out = self.conv_layers(x)
            x = conv_out
        x = self.reshape(x)
        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)
        return conv_out, fc_out
class FoldedConvnet(Convnet):
    '''Convnet with strided crop input.

    The input image is unfolded into overlapping crops, processed
    crop-wise by the conv stack, and refolded into a spatial grid once
    the per-crop spatial size collapses to 1.
    '''

    def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None):
        '''Creates layers

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Shape of input.
            crop_size: Size of crops
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.

        Raises:
            ValueError: If the input is not square.
        '''
        self.crop_size = crop_size
        dim_x, dim_y, dim_in = shape
        if dim_x != dim_y:
            raise ValueError('x and y dimensions must be the same to use Folded encoders.')
        # Grid side length after refolding the overlapping crops.
        self.final_size = 2 * (dim_x // self.crop_size) - 1
        self.unfold = Unfold(dim_x, self.crop_size)
        self.refold = Fold(dim_x, self.crop_size)
        # Conv layers operate on individual crops, not the full image.
        shape = (self.crop_size, self.crop_size, dim_in)
        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Same block structure as Convnet.create_conv_layers, but the
        reported output shape accounts for refolding: when the per-crop
        spatial size reaches 1, it becomes final_size x final_size.

        Args:
            shape: Input shape as (dim_x, dim_y, channels).
            conv_args: List of tuple of convolutional arguments.

        Returns:
            (nn.Sequential, tuple): The conv layers and the output shape
                (dim_x, dim_y, channels).

        Raises:
            ValueError: If the spatial output is not square.
        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []
        dim_x, dim_y, dim_in = shape
        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()
            if dim_out is not None:
                # Bias is redundant when the conv is followed by batch norm.
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm))
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                # No conv in this block: channels pass through unchanged.
                dim_out = dim_in
            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm2d(dim_out)
                conv_block.add_module(name + 'bn', bn)
            if nonlinearity:
                # Resolve string specs (e.g. 'ReLU') to torch.nn modules;
                # the original cortex get_nonlinearity import is commented
                # out at the top of this file and would raise a NameError.
                if isinstance(nonlinearity, str):
                    nonlinearity = getattr(nn, nonlinearity)()
                elif not isinstance(nonlinearity, nn.Module) and callable(nonlinearity):
                    nonlinearity = nonlinearity()
                conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
            if pool:
                (pool_type, kernel, stride) = pool
                Pool = getattr(nn, pool_type)
                # Prefixed with the block name for consistency with Convnet.
                conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
            conv_layers.add_module(name, conv_block)
            dim_in = dim_out
        if dim_x != dim_y:
            raise ValueError('dim_x and dim_y do not match.')
        if dim_x == 1:
            # Per-crop size collapsed to 1: forward() refolds crops into a
            # final_size x final_size grid, so report that shape instead.
            dim_x = self.final_size
            dim_y = self.final_size
        dim_out = dim_in  # Covers empty conv_args: output channels = input channels.
        return conv_layers, (dim_x, dim_y, dim_out)

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            (torch.Tensor or list, torch.Tensor or list): Conv and fc
                outputs (lists of per-layer outputs if return_full_list).
        '''
        x = self.unfold(x)
        conv_out = []
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
            if x.size(2) == 1:
                # Crops have collapsed spatially: refold into the grid.
                x = self.refold(x)
            conv_out.append(x)
        x = self.reshape(x)
        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)
        if not return_full_list:
            conv_out = conv_out[-1]
        return conv_out, fc_out