|
import math |
|
from argparse import Namespace |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from models import register |
|
|
|
|
|
class gen_basis(nn.Module):
    """Provide five basis weight tensors and five matching bias tensors.

    Calling the module returns ``(basis, bias)``, two lists of five tensors
    each.  When ``args.state == 'train'`` the tensors are freshly initialised
    ``nn.Parameter`` objects registered on this module as ``w0``..``w4`` and
    ``bias1``..``bias5``.  For any other state they are read from the
    checkpoint file at ``args.path``.

    Args:
        args: object exposing ``basis_num`` (number of rows of every tensor),
            ``hidden`` (width used to size the columns), ``state`` (``'train'``
            selects initialisation, anything else selects loading) and
            ``path`` (checkpoint file, only used when loading).
    """

    # Column multipliers of the first weight and of the last weight/bias.
    # NOTE(review): presumably the input (580) and output (3) feature counts
    # of the network these bases are combined into - confirm with the caller.
    _IN_FEATURES = 580
    _OUT_FEATURES = 3

    def __init__(self, args):
        super().__init__()
        self.basis_num = args.basis_num
        self.hidden = args.hidden
        self.state = args.state
        self.path = args.path

    def _weight_columns(self):
        """Return the column count of w0..w4, in order."""
        hidden = self.hidden
        return [
            hidden * self._IN_FEATURES,
            hidden * hidden,
            hidden * hidden,
            hidden * hidden,
            self._OUT_FEATURES * hidden,
        ]

    def _bias_columns(self):
        """Return the column count of bias1..bias5, in order."""
        return [self.hidden] * 4 + [self._OUT_FEATURES]

    def _registered_basis_bias(self):
        """Return the currently registered parameters as ``(basis, bias)``."""
        basis = [getattr(self, f"w{idx}") for idx in range(5)]
        bias = [getattr(self, f"bias{idx}") for idx in range(1, 6)]
        return basis, bias

    def init_basis_bias(self):
        """Create, initialise and register all basis weights and biases.

        Every call builds brand-new parameters, replacing any registered
        earlier.  Weights are created and initialised first (w0..w4), then the
        biases (bias1..bias5), so both the registration order and the order
        in which random numbers are drawn are fixed.

        Returns:
            tuple: ``(basis, bias)`` - ``[w0..w4]`` and ``[bias1..bias5]``.
        """
        basis = []
        for idx, columns in enumerate(self._weight_columns()):
            weight = nn.Parameter(
                torch.empty(self.basis_num, columns), requires_grad=True
            )
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            setattr(self, f"w{idx}", weight)
            basis.append(weight)

        bias = []
        paired = zip(basis, self._bias_columns())
        for idx, (weight, columns) in enumerate(paired, start=1):
            # fan_in is taken from the 2-D weight as stored, i.e. its number
            # of columns (the flattened size), not from a reshaped layer.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            offset = nn.Parameter(
                torch.empty(self.basis_num, columns), requires_grad=True
            )
            nn.init.uniform_(offset, -bound, bound)
            setattr(self, f"bias{idx}", offset)
            bias.append(offset)

        return basis, bias

    def load_basis_for_test_kaiming(self, path):
        """Read the basis weights and biases from a saved checkpoint.

        The file is expected to contain ``{'model': {'sd': {...}}}`` where the
        inner mapping holds the keys ``basis.w0``..``basis.w4`` and
        ``basis.bias1``..``basis.bias5``.

        Args:
            path: checkpoint location passed to ``torch.load``.

        Returns:
            tuple: ``(basis, bias)`` - two lists of five tensors, exactly as
            stored in the checkpoint (they are not registered on this module).

        Raises:
            ValueError: if ``path`` is ``None``.
            KeyError: if the checkpoint lacks one of the expected keys.
        """
        if path is None:
            raise ValueError(
                "gen_basis needs a checkpoint path when state is not 'train'"
            )
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(path)['model']['sd']
        basis = [state_dict[f'basis.w{idx}'] for idx in range(5)]
        bias = [state_dict[f'basis.bias{idx}'] for idx in range(1, 6)]
        torch.cuda.empty_cache()
        return basis, bias

    def forward(self):
        """Return ``(basis, bias)`` for the configured state.

        In ``'train'`` state the parameters are initialised on the first call
        only; later calls hand back the same registered parameters so values
        learned in between are not overwritten.  Call ``init_basis_bias``
        directly to force a re-initialisation.  In any other state the tensors
        are loaded from ``self.path`` on every call.
        """
        if self.state == 'train':
            if hasattr(self, 'w0'):
                # Already initialised: re-creating the parameters here would
                # throw away training progress and orphan optimizer state.
                return self._registered_basis_bias()
            print('init_basis_use_kaiming')
            return self.init_basis_bias()

        print('load_basis_from_model')
        return self.load_basis_for_test_kaiming(self.path)
|
|
|
@register('basis')
def make_basis(basis_num=10, hidden=16, state=None, path=None):
    """Build a ``gen_basis`` module from keyword settings.

    The four keyword values are packed into an ``argparse.Namespace`` under
    the same names and handed to ``gen_basis``.

    Args:
        basis_num: stored as ``basis_num`` on the namespace.
        hidden: stored as ``hidden`` on the namespace.
        state: stored as ``state`` on the namespace.
        path: stored as ``path`` on the namespace.

    Returns:
        The constructed ``gen_basis`` instance.
    """
    settings = Namespace(
        basis_num=basis_num,
        hidden=hidden,
        state=state,
        path=path,
    )
    return gen_basis(settings)
|
|