KyanChen's picture
add
02c5426
raw
history blame
3.18 kB
import math
from argparse import Namespace
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register
class gen_basis(nn.Module):
    """Supply five basis weight tensors and five bias tensors.

    Calling the module returns ``(basis, bias)``, two lists of five tensors.
    When ``state == 'train'`` the tensors are freshly initialised
    ``nn.Parameter`` objects registered on this module as ``w0``..``w4`` and
    ``bias1``..``bias5``. For any other ``state`` they are read from the
    checkpoint file at ``path`` and nothing is registered on the module.

    Args:
        args: Object exposing the attributes ``basis_num`` (first dimension
            of every tensor), ``hidden``, ``state`` and ``path``.
    """

    # Multiplier of ``hidden`` giving the second dimension of ``w0``.
    # NOTE(review): presumably the input feature width of the first layer
    # of whatever consumes these bases -- confirm against the caller.
    _FIRST_WIDTH = 580
    # Multiplier of ``hidden`` for ``w4`` and second dimension of ``bias5``.
    # NOTE(review): looks like an RGB output width -- confirm.
    _LAST_WIDTH = 3
    # Number of weight/bias pairs produced and loaded.
    _NUM_LAYERS = 5

    def __init__(self, args):
        super().__init__()
        self.basis_num = args.basis_num
        self.hidden = args.hidden
        self.state = args.state
        self.path = args.path

    def init_basis_bias(self):
        """Create, initialise and register the weight and bias parameters.

        Weights ``w0``..``w4`` are filled with
        ``kaiming_uniform_(a=sqrt(5))``; biases ``bias1``..``bias5`` are drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where ``fan_in``
        is taken from the weight at the same position.

        Every call replaces any parameters registered by a previous call.

        Returns:
            tuple: ``(basis, bias)`` -- two lists of five ``nn.Parameter``.
        """
        hidden = self.hidden
        square = hidden * hidden
        weight_widths = [
            hidden * self._FIRST_WIDTH,
            square,
            square,
            square,
            self._LAST_WIDTH * hidden,
        ]
        bias_widths = [hidden, hidden, hidden, hidden, self._LAST_WIDTH]

        # All weights are initialised before any bias so the random stream
        # is consumed in a fixed order (weights first, then biases).
        basis = []
        for index, width in enumerate(weight_widths):
            weight = nn.Parameter(
                torch.empty(self.basis_num, width), requires_grad=True
            )
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            # Attribute names are w0..w4; the checkpoint loader reads the
            # matching 'basis.w<i>' keys.
            setattr(self, f"w{index}", weight)
            basis.append(weight)

        bias = []
        for index, (weight, width) in enumerate(
            zip(basis, bias_widths), start=1
        ):
            offset = nn.Parameter(
                torch.empty(self.basis_num, width), requires_grad=True
            )
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(offset, -bound, bound)
            # Bias attributes are numbered from 1 (bias1..bias5).
            setattr(self, f"bias{index}", offset)
            bias.append(offset)

        return basis, bias

    def load_basis_for_test_kaiming(self, path):
        """Read the ten basis tensors from a saved checkpoint.

        The checkpoint is expected to contain ``['model']['sd']`` with keys
        ``'basis.w0'``..``'basis.w4'`` and ``'basis.bias1'``..``'basis.bias5'``.

        Args:
            path: Checkpoint location handed to ``torch.load``.

        Returns:
            tuple: ``(basis, bias)`` -- two lists of five tensors exactly as
            stored in the checkpoint.

        Raises:
            ValueError: If ``path`` is ``None``.
            KeyError: If an expected key is missing from the checkpoint.
        """
        if path is None:
            raise ValueError(
                "gen_basis: a checkpoint 'path' is required when state is "
                "not 'train'"
            )
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(path)['model']['sd']
        basis = [state_dict[f'basis.w{i}'] for i in range(self._NUM_LAYERS)]
        bias = [
            state_dict[f'basis.bias{i}']
            for i in range(1, self._NUM_LAYERS + 1)
        ]
        # Ask the CUDA caching allocator to release unused cached memory.
        torch.cuda.empty_cache()
        return basis, bias

    def forward(self):
        """Return ``(basis, bias)`` for the configured ``state``.

        ``'train'`` initialises new parameters; every other value loads the
        tensors from ``self.path``.
        """
        if self.state == 'train':
            print('init_basis_use_kaiming')
            return self.init_basis_bias()
        print('load_basis_from_model')
        return self.load_basis_for_test_kaiming(self.path)
@register('basis')
def make_basis(basis_num=10, hidden=16, state=None, path=None):
    """Build a ``gen_basis`` module from keyword settings.

    The factory is passed through ``register('basis')``.

    Args:
        basis_num: Forwarded as ``args.basis_num``.
        hidden: Forwarded as ``args.hidden``.
        state: Forwarded as ``args.state``; ``gen_basis`` treats ``'train'``
            as "initialise" and anything else as "load from ``path``".
        path: Forwarded as ``args.path``.

    Returns:
        gen_basis: A module configured with the given settings.
    """
    # Pack the settings into the attribute-style object gen_basis reads.
    settings = Namespace(
        basis_num=basis_num,
        hidden=hidden,
        state=state,
        path=path,
    )
    return gen_basis(settings)