# Source: fengshen/models/GAVAE/gans_model.py
# (Hugging Face file-viewer header: uploaded by fclong, "Upload 396 files",
#  commit 8ebda9e, file size 14.9 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
class MyDataset(Dataset):
    """Map-style dataset that pairs row ``i`` of ``x`` with row ``i`` of ``y``."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        # The number of samples is the size of the leading dimension of ``x``.
        self.len = self.x.size(0)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample = self.x[index]
        target = self.y[index]
        return sample, target
class MyDataset_new(Dataset):
    """Map-style dataset yielding aligned triples ``(x[i], y[i], s[i])``."""

    def __init__(self, x, y, s):
        self.x, self.y, self.s = x, y, s
        # The number of samples is the size of the leading dimension of ``x``.
        self.len = self.x.size(0)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return tuple(source[index] for source in (self.x, self.y, self.s))
class CLS_Net(torch.nn.Module):
    """Three-layer MLP classifier: ``z_dim -> 256 -> 64 -> cls_num``.

    ``forward`` returns both the class logits and the first hidden
    activation, so callers can use the hidden features as well as the
    prediction.
    """

    def __init__(self, cls_num, z_dim, cls_batch_size):
        """Build the layers; every weight matrix is re-drawn from N(0, 0.1).

        Args:
            cls_num: number of output classes (width of the logits).
            z_dim: width of the input vectors.
            cls_batch_size: stored on the instance as ``cls_batch_size``.
        """
        super(CLS_Net, self).__init__()
        mini_dim = 256
        out_input_num = mini_dim
        base_dim = 64
        self.cls_batch_size = cls_batch_size
        # Order ``p`` of the norm used by ``self_dis``.
        self.jie = 1
        self.fc1 = nn.Linear(z_dim, mini_dim)
        self.fc1.weight.data.normal_(0, 0.1)
        self.fc2 = nn.Linear(out_input_num, base_dim)
        self.fc2.weight.data.normal_(0, 0.1)
        self.out = nn.Linear(base_dim, cls_num)
        self.out.weight.data.normal_(0, 0.1)

    def self_dis(self, a):
        """Return the matrix of pairwise p-norm distances between rows of ``a``.

        Entry ``[i, j]`` is ``F.pairwise_distance(a[j], a[i], p=self.jie)``.
        ``pairwise_distance`` adds a small eps to the difference before taking
        the norm, so the matrix is only approximately symmetric and its
        diagonal is only approximately zero.

        Args:
            a: 2-D tensor of shape ``(n, d)``.

        Returns:
            Tensor of shape ``(n, n)``; an empty ``(0, 0)`` tensor when
            ``a`` has no rows.
        """
        n = a.shape[0]
        if n == 0:
            return a.new_zeros((0, 0))
        # One batched call per row gives column j = distances from a[j] to
        # every row, replacing the O(n^2) Python loop of per-pair calls.
        columns = [
            F.pairwise_distance(a[j].unsqueeze(0).expand_as(a), a, p=self.jie)
            for j in range(n)
        ]
        return torch.stack(columns, dim=1)

    def forward(self, x):
        """Run the classifier.

        Args:
            x: tensor whose last dimension is ``z_dim``.

        Returns:
            ``(y, x1)`` where ``y`` holds the class logits and ``x1`` is the
            ReLU activation of the first layer.
        """
        x1 = F.relu(self.fc1(x))
        x2 = self.fc2(x1)
        # Functional dropout keyed on self.training: active under .train(),
        # a no-op under .eval(). A Dropout module constructed inside forward
        # would always be in training mode and make inference stochastic.
        x2 = F.dropout(x2, p=0.1, training=self.training)
        x2 = F.relu(x2)
        y = self.out(x2)
        return y, x1
class Gen_Net(torch.nn.Module):
    """MLP generator: ``input_x2_dim -> 60 -> 128 -> 256 -> 128 -> output_dim``."""

    def __init__(self, input_x2_dim, output_dim):
        super().__init__()
        # (attribute name, in_features, out_features), in creation order.
        layer_specs = (
            ("x2_input", input_x2_dim, 60),
            ("fc1", 60, 128),
            ("fc2", 128, 256),
            ("fc3", 256, 128),
            ("out", 128, output_dim),
        )
        for attr_name, fan_in, fan_out in layer_specs:
            layer = nn.Linear(fan_in, fan_out)
            # Overwrite the default weight init with N(0, 0.1); the bias
            # keeps nn.Linear's default initialisation.
            layer.weight.data.normal_(0, 0.1)
            setattr(self, attr_name, layer)

    def forward(self, x2):
        """Map a batch of noise vectors ``x2`` to generated vectors."""
        # The input projection has no activation; the three hidden layers
        # each use ReLU; the output layer is linear.
        hidden = self.x2_input(x2)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.out(hidden)
class gans_process():
    """Adversarial training helper pairing a classifier with a generator.

    ``cls_net`` (a ``CLS_Net``) is trained to separate supplied vectors
    (label 0) from vectors produced by ``gen_net`` (label 1, see
    ``ready_cls``). ``gen_net`` (a ``Gen_Net``) is trained by matching the
    classifier's first-hidden-layer features between generated and supplied
    vectors (see ``gen_train``).
    """

    def __init__(self, config):
        """Build both networks, their SGD optimizers and the loss functions.

        Args:
            config: object exposing ``device``, ``cls_num``, ``noise_dim``,
                ``z_dim``, ``cls_lr``, ``gen_lr``, ``cls_epoches``,
                ``gen_epoches``, ``cls_batch_size`` and ``gen_batch_size``.
        """
        # base parameters
        self.device = config.device
        self.cls_num = config.cls_num
        self.x2_dim = config.noise_dim
        self.z_dim = config.z_dim
        self.cls_lr = config.cls_lr
        self.gen_lr = config.gen_lr
        self.cls_epoches = config.cls_epoches
        self.gen_epoches = config.gen_epoches
        self.mse_weight = 1.0
        self.cls_batch_size = config.cls_batch_size
        self.eval_batch_size = config.cls_batch_size
        # NOTE(review): config.gen_batch_size is read but then overridden by
        # cls_batch_size, exactly as in the original code -- confirm which
        # value is intended before changing it.
        self.gen_batch_size = config.gen_batch_size
        self.gen_batch_size = self.cls_batch_size

        # classifier and its optimizer
        self.cls_net = CLS_Net(self.cls_num, self.z_dim, self.cls_batch_size).to(self.device)
        self.cls_optimizer = torch.optim.SGD(
            self.cls_net.parameters(), lr=self.cls_lr, weight_decay=1e-5)

        # generator and its optimizer
        self.gen_net = Gen_Net(self.x2_dim, self.z_dim).to(self.device)
        self.gen_optimizer = torch.optim.SGD(
            self.gen_net.parameters(), lr=self.gen_lr, weight_decay=0.01)

        # base losses
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.loss_mse = torch.nn.MSELoss()

    @staticmethod
    def _set_requires_grad(module, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = flag

    def _one_hot(self, y):
        """Turn an ``(N, 1)`` tensor of class indices into ``(N, cls_num)`` one-hot floats."""
        index = y.long().to(self.device)
        one_hot = torch.zeros(index.size(0), self.cls_num, device=self.device)
        return one_hot.scatter_(1, index, 1)

    def freeze_cls(self):
        """Disable gradients for all classifier parameters."""
        self._set_requires_grad(self.cls_net, False)

    def unfreeze_cls(self):
        """Enable gradients for all classifier parameters."""
        self._set_requires_grad(self.cls_net, True)

    def freeze_gen(self):
        """Disable gradients for all generator parameters."""
        self._set_requires_grad(self.gen_net, False)

    def unfreeze_gen(self):
        """Enable gradients for all generator parameters."""
        self._set_requires_grad(self.gen_net, True)

    def labels2genx(self, sample_num):
        """Return ``sample_num`` uniform-[0, 1) noise vectors of width ``x2_dim`` on ``device``."""
        # Sampled on CPU and then moved, so the CPU RNG stream is the one used.
        x = torch.rand(sample_num, self.x2_dim)
        return x.to(self.device)

    def pad_batch(self, x):
        """Pad ``x`` along dim 0 up to the next multiple of ``cls_batch_size``.

        Padding rows are copies of the leading rows of ``x``, cycling through
        ``x`` again when more padding is needed than ``x`` has rows. Slicing
        ``x[:pad_len]`` alone would come up short in that case.

        Returns:
            ``x`` itself when its length is already a multiple of the batch
            size (including length 0), otherwise a new padded tensor.
        """
        n = x.shape[0]
        remainder = n % self.cls_batch_size
        if remainder == 0:
            return x
        pad_len = self.cls_batch_size - remainder
        # remainder != 0 implies n > 0, so the modulo is well defined.
        idx = torch.arange(pad_len, device=x.device) % n
        return torch.cat((x, x[idx]), dim=0)

    def ready_cls(self, sent_output, perm=None):
        """Build a shuffled real-vs-generated dataset for classifier training.

        Args:
            sent_output: tensor of supplied vectors; labelled 0.
            perm: optional permutation to reuse; a random one is drawn when
                omitted.

        Returns:
            ``(x, y, None, None, perm)`` where ``x`` stacks ``sent_output``
            and as many generated vectors (labelled 1), both shuffled by
            ``perm``. The two ``None`` slots are kept for interface
            compatibility.
        """
        sample_num = len(sent_output)
        # --------------- make fake z ---------------
        sent_output = sent_output.to(self.device)
        sent_noise = torch.tensor(self.gen_test(sample_num)).to(self.device)
        # -------------- assemble data --------------
        x = torch.cat((sent_output, sent_noise), dim=0)
        if perm is None:
            perm = torch.randperm(len(x))
        x = x[perm]
        # Both label tensors are int64 so torch.cat needs no dtype promotion.
        real_y = torch.zeros(sample_num, 1, dtype=torch.long)
        fake_y = torch.ones(sent_noise.size(0), 1, dtype=torch.long)
        y = torch.cat((real_y, fake_y), dim=0).to(self.device)
        y = y[perm]
        return x, y, None, None, perm

    def ready_fake(self, sent_output, inputs_labels, inputs_indexs, label2id, perm=None):
        """Generate one fake vector per entry of ``inputs_labels`` and shuffle.

        ``inputs_indexs`` and ``label2id`` are accepted for interface
        compatibility but are not used; the generator takes only noise.

        NOTE(review): ``gen_test`` accepts a sample count only, so the count
        is taken from ``len(inputs_labels)`` -- presumably one sample per
        label is what callers expect; confirm.

        Returns:
            ``(x, y, perm)`` -- generated vectors on ``device``, the labels
            as an ``(N, 1)`` tensor, and the permutation applied to both.
        """
        # --------------- make fake z ---------------
        sent_output = sent_output.to(self.device)
        sent_noise = torch.tensor(self.gen_test(len(inputs_labels))).to(self.device)
        # -------------- assemble data --------------
        x = sent_noise
        y = torch.tensor(inputs_labels).unsqueeze(1)
        if perm is None:
            perm = torch.randperm(len(x))
        x = x[perm]
        y = y[perm]
        return x, y, perm

    def ready_gen(self, sent_output):
        """Prepare generator-training inputs.

        Returns:
            ``(x2, y, sent_output)`` -- fresh noise (one row per supplied
            vector), all-zero ``(N, 1)`` labels, and ``sent_output`` moved to
            ``device``.
        """
        sent_num = len(sent_output)
        sent_output = sent_output.to(self.device)
        x2 = self.labels2genx(sent_num)
        y = torch.zeros(sent_num, 1, dtype=torch.long, device=self.device)
        return x2, y, sent_output

    def cls_train(self, x, y, if_oneHot=True):
        """Train the classifier for ``cls_epoches`` epochs and return it.

        Args:
            x: input vectors, one row per sample.
            y: ``(N, 1)`` class indices when ``if_oneHot`` is true, otherwise
                targets already in the form ``CrossEntropyLoss`` accepts.
            if_oneHot: whether to one-hot encode ``y`` first.
        """
        self.cls_net.train()
        self.gen_net.eval()
        self.unfreeze_cls()
        self.freeze_gen()
        x = x.to(self.device)
        y = y.to(self.device)
        if if_oneHot:
            y = self._one_hot(y)
        train_loader = DataLoader(dataset=MyDataset(x, y),
                                  batch_size=self.cls_batch_size, shuffle=True)
        for _ in range(self.cls_epoches):
            for batch_x, batch_y in train_loader:
                self.cls_optimizer.zero_grad()
                out, _ = self.cls_net(batch_x)
                loss = self.loss_func(out, batch_y)
                # A one-sided label-smoothing factor used to be applied to
                # batch_y *after* the loss (so it never affected training)
                # with a hard-coded 2-class tensor that raised for
                # cls_num != 2; it and the discarded per-batch metrics
                # have been dropped.
                loss.backward()
                self.cls_optimizer.step()
        return self.cls_net

    def cls_eval(self, x, y, if_oneHot=True):
        """Return the mean of the per-batch classifier accuracies.

        Every batch counts equally, so a smaller final batch is weighted the
        same as a full one. The result is NaN (``np.mean`` of an empty list)
        when ``x`` has no rows.
        """
        self.cls_net.eval()
        x = x.to(self.device)
        y = y.to(self.device)
        if if_oneHot:
            y = self._one_hot(y)
        eval_loader = DataLoader(dataset=MyDataset(x, y),
                                 batch_size=self.eval_batch_size, shuffle=False)
        accuracy = []
        with torch.no_grad():  # evaluation needs no autograd graph
            for batch_x, batch_y in eval_loader:
                out, _ = self.cls_net(batch_x)
                _, predictions = out.max(1)
                _, real_y = batch_y.max(1)
                num_correct = (predictions == real_y).sum().item()
                accuracy.append(float(num_correct) / float(batch_x.shape[0]))
        return np.mean(accuracy)

    def cls_real_eval(self, x, y, if_oneHot=True):
        """Return accuracy restricted to samples whose true class is not the last one.

        Samples whose true label equals ``cls_num - 1`` are ignored in both
        the numerator and the denominator.

        Returns:
            The accuracy as a float, or NaN when no sample qualifies.
        """
        self.cls_net.eval()
        x = x.to(self.device)
        y = y.to(self.device)
        if if_oneHot:
            y = self._one_hot(y)
        eval_loader = DataLoader(dataset=MyDataset(x, y),
                                 batch_size=self.eval_batch_size, shuffle=False)
        rs = 0
        alls = 0
        last_class = self.cls_num - 1
        with torch.no_grad():
            for batch_x, batch_y in eval_loader:
                out, _ = self.cls_net(batch_x)
                _, predictions = out.max(1)
                _, real_y = batch_y.max(1)
                counted = real_y != last_class
                rs += ((predictions == real_y) & counted).sum().item()
                alls += counted.sum().item()
        if alls == 0:
            # Nothing to score: report NaN instead of dividing by zero.
            return float("nan")
        return rs / alls

    def cls_test(self, x, if_oneHot=True):
        """Return the predicted class index for every row of ``x`` as a list.

        ``if_oneHot`` is kept for signature compatibility; no labels are
        involved in prediction, so it has no effect.
        """
        self.cls_net.eval()
        x = x.to(self.device)
        # A tensor is itself a map-style dataset; the loader yields row batches.
        test_loader = DataLoader(dataset=x, batch_size=self.eval_batch_size,
                                 shuffle=False)
        preds = []
        with torch.no_grad():
            for batch_x in test_loader:
                out, _ = self.cls_net(batch_x)
                _, predictions = out.max(1)
                preds.extend(predictions.cpu().tolist())
        return preds

    def gen_train(self, x2, y, s, times):
        """Train the generator by feature matching and return it.

        The loss is the MSE between the classifier's first hidden activation
        for generated vectors and for the supplied vectors ``s``, scaled by
        ``0.9 ** times``.

        Args:
            x2: noise inputs for the generator.
            y: ``(N, 1)`` class indices; batched alongside the data but not
                part of the loss.
            s: supplied vectors to match, aligned with ``x2``.
            times: exponent of the 0.9 loss decay.
        """
        self.cls_net.eval()
        self.gen_net.train()
        self.freeze_cls()
        self.unfreeze_gen()
        y = self._one_hot(y)
        train_loader = DataLoader(dataset=MyDataset_new(x2, y, s),
                                  batch_size=self.gen_batch_size, shuffle=True)
        loss_scale = pow(0.9, times)
        for _ in range(self.gen_epoches):
            for step, (batch_x2, _batch_y, batch_s) in enumerate(train_loader):
                # NOTE(review): gradients are cleared only on every 6th step
                # while the optimizer steps on every batch, so each update
                # uses gradients accumulated since the last clear. Kept as
                # in the original; confirm this is the intended schedule.
                if step % 6 == 5:
                    self.gen_optimizer.zero_grad()
                fake = self.gen_net(batch_x2)
                # feature matching on the classifier's hidden features
                _, hds = self.cls_net(fake)
                _, hds2 = self.cls_net(batch_s.float())
                loss = self.loss_mse(hds, hds2) * loss_scale
                loss.backward()
                self.gen_optimizer.step()
        return self.gen_net

    def gen_test(self, sample_num):
        """Generate ``sample_num`` vectors from fresh noise.

        Returns:
            A list of ``sample_num`` lists of floats (one per generated
            vector); an empty list when ``sample_num`` is 0.
        """
        self.gen_net.eval()
        x2 = self.labels2genx(sample_num)
        # A tensor is itself a map-style dataset; the loader yields row batches.
        test_loader = DataLoader(dataset=x2, batch_size=self.eval_batch_size,
                                 shuffle=False)
        preds = []
        with torch.no_grad():
            for batch_x2 in test_loader:
                out = self.gen_net(batch_x2)
                preds.extend(out.cpu().tolist())
        return preds
# Library module: running it directly as a script does nothing.
if __name__ == "__main__":
    pass