import torch
import torch.nn as nn


def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and, if present, the bias with ``bias``."""
    nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Xavier/Glorot initialization with a uniform or normal distribution."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from N(mean, std^2); set the bias to a constant."""
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def uniform_init(module, a=0, b=1, bias=0):
    """Draw ``module.weight`` from U(a, b); set the bias to a constant."""
    nn.init.uniform_(module.weight, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

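
# Usage sketch (kept as a comment so the module stays side-effect free; the
# nn.Linear shapes below are arbitrary placeholders, not from this codebase).
# The four helpers above share the same calling convention:
#
#   fc = nn.Linear(16, 8)
#   constant_init(fc, val=0.0)               # zero weights and bias
#   normal_init(fc, mean=0, std=0.01)        # N(0, 0.01^2) weights
#   uniform_init(fc, a=-0.1, b=0.1)          # U(-0.1, 0.1) weights
#   xavier_init(fc, distribution='uniform')  # Xavier/Glorot uniform
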
def kaiming_init(module,
                 a=0,
                 is_rnn=False,
                 mode='fan_in',
                 nonlinearity='leaky_relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming (He) initialization.

    With ``is_rnn=True`` every parameter is visited by name, so the stacked
    weight/bias tensors of recurrent modules (e.g. ``nn.LSTM``) are covered;
    otherwise only ``module.weight`` and ``module.bias`` are touched.
    """
    assert distribution in ['uniform', 'normal']
    init_fn = (nn.init.kaiming_uniform_ if distribution == 'uniform'
               else nn.init.kaiming_normal_)
    if is_rnn:
        for name, param in module.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, bias)
            elif 'weight' in name:
                init_fn(param, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        init_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

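
# Usage sketch (comment only; the LSTM sizes are illustrative): with
# is_rnn=True the loop above reaches the stacked parameters that nn.LSTM
# registers by name (weight_ih_l0, weight_hh_l0, bias_ih_l0, ...):
#
#   lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2)
#   kaiming_init(lstm, is_rnn=True)  # Kaiming-normal weights, zero biases
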
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build an (in_channels, out_channels, k, k) bilinear-upsampling kernel.

    The diagonal channel assignment below assumes
    ``in_channels == out_channels``.
    """
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = (torch.arange(kernel_size).reshape(-1, 1),
          torch.arange(kernel_size).reshape(1, -1))
    # Separable tent filter: outer product of two 1-D triangular ramps.
    filt = (1 - torch.abs(og[0] - center) / factor) * \
           (1 - torch.abs(og[1] - center) / factor)
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    # Place the filter on the (i, i) channel pairs only.
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

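
# Usage sketch (comment only; the channel count 3 is an arbitrary
# placeholder): loading this kernel into a ConvTranspose2d with
# kernel_size=4, stride=2, padding=1 yields a fixed 2x bilinear upsampler:
#
#   up = nn.ConvTranspose2d(3, 3, kernel_size=4, stride=2, padding=1,
#                           bias=False)
#   with torch.no_grad():
#       up.weight.copy_(bilinear_kernel(3, 3, 4))
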
def init_weights(m):
    """Per-module initializer meant to be dispatched via ``model.apply``."""
    if isinstance(m, nn.Conv2d):
        kaiming_init(m)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        constant_init(m, 1)
    elif isinstance(m, nn.Linear):
        xavier_init(m)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
        kaiming_init(m, is_rnn=True)
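
# Usage sketch (comment only; the model is a made-up example):
# nn.Module.apply calls init_weights on every submodule recursively:
#
#   model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
#   model.apply(init_weights)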