# Source: Sunday01, revision 9dce458.
import torch.nn as nn
import torch
def constant_init(module, val, bias=0):
    """Fill a module's weight with ``val`` and its bias with ``bias`` in place.

    Args:
        module: Module whose ``weight`` / ``bias`` parameters are overwritten.
        val: Constant written into ``module.weight``.
        bias: Constant written into ``module.bias``.

    Either parameter is skipped when the module has no such attribute or it is
    ``None``. For example, ``nn.BatchNorm2d(..., affine=False)`` and
    ``nn.GroupNorm(..., affine=False)`` have ``weight is None``. Previously
    that made ``nn.init.constant_`` crash.
    """
    if getattr(module, 'weight', None) is not None:
        nn.init.constant_(module.weight, val)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Apply Xavier (Glorot) initialization to ``module.weight`` in place.

    Args:
        module: Module to initialize; must expose ``weight``.
        gain: Scaling factor forwarded to the PyTorch Xavier initializer.
        bias: Constant written into ``module.bias`` when it exists and is not None.
        distribution: ``'uniform'`` selects ``nn.init.xavier_uniform_``;
            ``'normal'`` selects ``nn.init.xavier_normal_``.

    Raises:
        AssertionError: If ``distribution`` is neither ``'uniform'`` nor ``'normal'``.
    """
    assert distribution in ('uniform', 'normal')
    fill_weight = (nn.init.xavier_uniform_ if distribution == 'uniform'
                   else nn.init.xavier_normal_)
    fill_weight(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from N(mean, std**2) and fill the bias, in place.

    Args:
        module: Module to initialize; must expose ``weight``.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
        bias: Constant written into ``module.bias`` when it exists and is not None.
    """
    weight = module.weight
    nn.init.normal_(weight, mean=mean, std=std)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
def uniform_init(module, a=0, b=1, bias=0):
    """Draw ``module.weight`` from U(a, b) and fill the bias, in place.

    Args:
        module: Module to initialize; must expose ``weight``.
        a: Lower bound of the uniform distribution.
        b: Upper bound of the uniform distribution.
        bias: Constant written into ``module.bias`` when it exists and is not None.
    """
    weight = module.weight
    nn.init.uniform_(weight, a=a, b=b)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
def kaiming_init(module,
                 a=0,
                 is_rnn=False,
                 mode='fan_in',
                 nonlinearity='leaky_relu',
                 bias=0,
                 distribution='normal'):
    """Apply He (Kaiming) initialization to ``module`` in place.

    Args:
        module: Module to initialize.
        a: Forwarded as ``a`` to ``nn.init.kaiming_*_`` (the negative slope of
            the rectifier in PyTorch's formula).
        is_rnn: If true, iterate ``module.named_parameters()``. Parameters whose
            name contains ``'bias'`` are filled with ``bias``. Otherwise,
            parameters whose name contains ``'weight'`` get Kaiming init. Any
            other parameter is left untouched. If false, only ``module.weight``
            is Kaiming-initialized, and ``module.bias`` is filled with ``bias``
            when it exists and is not None.
        mode: Forwarded to PyTorch (``'fan_in'`` or ``'fan_out'``).
        nonlinearity: Forwarded to PyTorch to pick the gain.
        bias: Constant written into bias parameters.
        distribution: ``'uniform'`` selects ``nn.init.kaiming_uniform_``;
            ``'normal'`` selects ``nn.init.kaiming_normal_``.

    Raises:
        AssertionError: If ``distribution`` is neither ``'uniform'`` nor ``'normal'``.
    """
    assert distribution in ('uniform', 'normal')
    kaiming_fn = (nn.init.kaiming_uniform_ if distribution == 'uniform'
                  else nn.init.kaiming_normal_)

    if not is_rnn:
        kaiming_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)
        return

    # Recurrent modules hold several weight/bias tensors (for nn.LSTM, names
    # like weight_ih_l0 / bias_hh_l0), so classify each one by its name.
    for param_name, tensor in module.named_parameters():
        if 'bias' in param_name:
            nn.init.constant_(tensor, bias)
        elif 'weight' in param_name:
            kaiming_fn(tensor, a=a, mode=mode, nonlinearity=nonlinearity)
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return a 4-D tensor holding a 2-D bilinear interpolation filter.

    The result has shape ``(in_channels, out_channels, kernel_size, kernel_size)``
    and is zero everywhere except at the channel index pairs selected by
    ``(range(in_channels), range(out_channels))`` through advanced indexing.
    Each selected pair receives the same ``kernel_size x kernel_size`` filter.
    The two ranges must be broadcast-compatible (equal lengths, or one of
    them of length 1); otherwise PyTorch rejects the indexed assignment.
    """
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    # 1-D triangular profile peaking at `center`; the 2-D filter is the outer
    # product of this profile with itself.
    ramp = 1 - torch.abs(torch.arange(kernel_size) - center) / factor
    filt = ramp.reshape(-1, 1) * ramp.reshape(1, -1)
    weight = torch.zeros((in_channels, out_channels, kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels)] = filt
    return weight
def init_weights(m):
    """Initialize a single module ``m`` in place according to its type.

    Dispatch (first matching rule wins):
        * ``nn.Conv2d``                       -> ``kaiming_init`` with defaults
        * ``nn.BatchNorm2d`` / ``nn.GroupNorm`` -> ``constant_init(m, 1)``
        * ``nn.Linear``                       -> ``xavier_init`` with defaults
        * ``nn.LSTM`` / ``nn.LSTMCell``       -> ``kaiming_init(m, is_rnn=True)``

    Modules of any other type are left unchanged. This includes
    ``nn.ConvTranspose2d``; a bilinear init via ``bilinear_kernel`` was once
    sketched for it but is disabled. The one-module signature fits
    ``nn.Module.apply(init_weights)``. Presumably that is how callers use it;
    confirm.
    """
    rules = (
        (nn.Conv2d, kaiming_init),
        ((nn.BatchNorm2d, nn.GroupNorm), lambda mod: constant_init(mod, 1)),
        (nn.Linear, xavier_init),
        ((nn.LSTM, nn.LSTMCell), lambda mod: kaiming_init(mod, is_rnn=True)),
    )
    for module_types, initializer in rules:
        if isinstance(m, module_types):
            initializer(m)
            break