|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
import numpy as np |
|
|
|
|
|
class MyDataset(Dataset):
    """Pairs two aligned tensors (features, labels) for use with a DataLoader."""

    def __init__(self, x, y):
        # Keep references to the backing tensors; no copies are made.
        self.x = x
        self.y = y
        # Sample count is the leading dimension of the feature tensor.
        self.len = self.x.size(0)

    def __getitem__(self, idx):
        """Return the (feature, label) pair stored at position *idx*."""
        return self.x[idx], self.y[idx]

    def __len__(self):
        """Number of samples in the dataset."""
        return self.len
|
|
|
|
|
class MyDataset_new(Dataset):
    """Triple-tensor dataset: yields an aligned (x, y, s) triple per index."""

    def __init__(self, x, y, s):
        # References to the three aligned tensors.
        self.x = x
        self.y = y
        self.s = s
        # Sample count comes from the leading dimension of x.
        self.len = self.x.size(0)

    def __getitem__(self, idx):
        """Return the (x, y, s) triple stored at position *idx*."""
        return self.x[idx], self.y[idx], self.s[idx]

    def __len__(self):
        """Number of samples in the dataset."""
        return self.len
|
|
|
|
|
class CLS_Net(torch.nn.Module):
    """Classifier / discriminator MLP: z_dim -> 256 -> 64 -> cls_num logits.

    forward() returns both the logits and the first hidden layer's ReLU
    activations; the activations serve as a feature-matching target for
    generator training elsewhere in this file.
    """

    def __init__(self, cls_num, z_dim, cls_batch_size):
        """
        Args:
            cls_num: number of classes (size of the logit vector).
            z_dim: dimensionality of the input features.
            cls_batch_size: stored for self_dis(); not used by forward().
        """
        super(CLS_Net, self).__init__()

        mini_dim = 256           # width of the first hidden layer
        out_input_num = mini_dim
        base_dim = 64            # width of the second hidden layer

        self.cls_batch_size = cls_batch_size
        self.jie = 1             # p-norm order used by self_dis (L1 distance)

        self.fc1 = nn.Linear(z_dim, mini_dim)
        self.fc1.weight.data.normal_(0, 0.1)

        self.fc2 = nn.Linear(out_input_num, base_dim)
        self.fc2.weight.data.normal_(0, 0.1)

        # BUGFIX: the original constructed nn.Dropout(0.1) inside forward();
        # a freshly built module is always in training mode, so dropout was
        # applied even after net.eval(), making evaluation stochastic.
        # Registering the module here ties it to this network's
        # train()/eval() state, as intended.
        self.drop = nn.Dropout(0.1)

        self.out = nn.Linear(base_dim, cls_num)
        self.out.weight.data.normal_(0, 0.1)

    def self_dis(self, a):
        """Return the (n, n) matrix of pairwise p=self.jie distances between
        rows of `a` (n = a.shape[0]).

        Entry [i, j] is the distance between a[j] and a[i]; the matrix is
        symmetric, so orientation does not matter.  Built row-by-row with
        F.pairwise_distance (which adds a small eps to each difference).
        O(n^2) forward passes; only suitable for small batches.
        """
        jie = self.jie

        all_tag = False
        for j in range(a.shape[0]):
            col_tag = False
            for i in range(a.shape[0]):
                tmp = F.pairwise_distance(a[j, :], a[i, :], p=jie).view(-1, 1)
                if col_tag == False:
                    col_dis = tmp
                    col_tag = True
                else:
                    col_dis = torch.cat((col_dis, tmp), dim=0)
            if all_tag == False:
                all_dis = col_dis
                all_tag = True
            else:
                all_dis = torch.cat((all_dis, col_dis), dim=1)
        return all_dis

    def forward(self, x):
        """Return (logits, first-hidden-layer ReLU activations)."""
        x = self.fc1(x)
        x1 = F.relu(x)

        x2 = self.fc2(x1)
        x2 = self.drop(x2)   # no-op in eval mode (see __init__)
        x2 = F.relu(x2)

        y = self.out(x2)

        return y, x1
|
|
|
|
|
class Gen_Net(torch.nn.Module):
    """Generator MLP: maps noise vectors (input_x2_dim) to output_dim features.

    Architecture: input -> 60 -> 128 -> 256 -> 128 -> output_dim, with ReLU
    after each hidden layer and a linear output.
    """

    def __init__(self, input_x2_dim, output_dim):
        super(Gen_Net, self).__init__()

        self.x2_input = nn.Linear(input_x2_dim, 60)
        self.fc1 = nn.Linear(60, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)
        self.out = nn.Linear(128, output_dim)

        # Same initialisation as before: every weight matrix drawn from
        # N(0, 0.1); biases keep PyTorch's default init.
        for layer in (self.x2_input, self.fc1, self.fc2, self.fc3, self.out):
            layer.weight.data.normal_(0, 0.1)

    def forward(self, x2):
        """Run the noise batch through the MLP and return the generated features."""
        h = self.x2_input(x2)
        for layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(layer(h))
        return self.out(h)
|
|
|
|
|
class gans_process():
    """Adversarial training harness pairing CLS_Net (classifier/discriminator)
    with Gen_Net (generator).

    The classifier learns to separate real sentence features (label 0) from
    generated samples (label 1); the generator is trained so the classifier's
    hidden features for its samples match those of real samples
    (feature-matching MSE).  All tensors are moved to `config.device`.
    """

    def __init__(self, config):
        """Build both networks and their SGD optimizers from `config`.

        Reads: device, cls_num, noise_dim, z_dim, cls_lr, gen_lr,
        cls_epoches, gen_epoches, cls_batch_size, gen_batch_size.
        """
        # --- dimensions / device ------------------------------------------
        self.device = config.device
        self.cls_num = config.cls_num
        self.x2_dim = config.noise_dim   # generator input (noise) width
        self.z_dim = config.z_dim        # feature width produced by the generator

        # --- learning rates / epochs --------------------------------------
        self.cls_lr = config.cls_lr
        self.gen_lr = config.gen_lr
        self.cls_epoches = config.cls_epoches
        self.gen_epoches = config.gen_epoches
        self.mse_weight = 1.0            # NOTE(review): never read afterwards

        # --- batch sizes ---------------------------------------------------
        self.cls_batch_size = config.cls_batch_size
        self.gen_batch_size = config.gen_batch_size
        self.eval_batch_size = config.cls_batch_size
        # NOTE(review): overwrites config.gen_batch_size assigned above --
        # generator batches always use the classifier batch size. Confirm intent.
        self.gen_batch_size = self.cls_batch_size

        # --- networks and their optimizers --------------------------------
        self.cls_net = CLS_Net(self.cls_num, self.z_dim, self.cls_batch_size).to(self.device)
        self.cls_optimizer = torch.optim.SGD(self.cls_net.parameters(),
                                lr = self.cls_lr , weight_decay= 1e-5)

        self.gen_net = Gen_Net(self.x2_dim, self.z_dim).to(self.device)

        self.gen_optimizer = torch.optim.SGD(self.gen_net.parameters(),
                                lr = self.gen_lr , weight_decay= 0.01)

        # Classification loss (one-hot targets) and the feature-matching
        # loss used for generator training.
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.loss_mse = torch.nn.MSELoss()

    def freeze_cls(self):
        """Stop gradients flowing into the classifier."""
        for param in self.cls_net.parameters():
            param.requires_grad = False

    def unfreeze_cls(self):
        """Re-enable classifier gradients."""
        for param in self.cls_net.parameters():
            param.requires_grad = True

    def freeze_gen(self):
        """Stop gradients flowing into the generator."""
        for param in self.gen_net.parameters():
            param.requires_grad = False

    def unfreeze_gen(self):
        """Re-enable generator gradients."""
        for param in self.gen_net.parameters():
            param.requires_grad = True

    def labels2genx(self, sample_num):
        """Draw `sample_num` uniform-[0, 1) noise vectors for the generator."""
        x = torch.rand(sample_num, self.x2_dim)
        return x.to(self.device)

    def pad_batch(self, x):
        """Pad x along dim 0 (repeating rows from its start) so its length is
        a multiple of cls_batch_size.

        NOTE(review): assumes the shortfall is at most x.shape[0] rows.
        """
        if int(x.shape[0] % self.cls_batch_size) == 0:
            return x
        pad_len = self.cls_batch_size - ( x.shape[0] % self.cls_batch_size)
        x = torch.cat((x, x[:pad_len]), dim = 0)
        return x

    def ready_cls(self, sent_output,perm=None):
        """Build a shuffled real-vs-generated training set for the classifier.

        Real rows get label 0, generated rows label 1.  `perm` (reused when
        given) keeps x and y aligned across calls.
        Returns (x, y, None, None, perm).
        """
        sample_num = len(sent_output)

        sent_output = sent_output.to(self.device)
        # gen_test returns a Python list of generated vectors -> tensor.
        sent_noise = torch.tensor(self.gen_test(sample_num)).to(self.device)

        x = torch.cat((sent_output, sent_noise), dim = 0 )
        if perm is None:
            perm = torch.randperm(len(x))
        x = x[perm]

        multi_label_num = 1      # label assigned to generated (fake) rows
        multi_output_y = torch.tensor([0]*sample_num).unsqueeze(1)
        multi_noise_y = torch.zeros([sent_noise.size(0),1], dtype = torch.int)
        multi_noise_y = multi_noise_y + multi_label_num

        y = torch.cat((multi_output_y, multi_noise_y), dim = 0).to(self.device)
        y = y[perm]

        return x,y,None,None,perm

    def ready_fake(self, sent_output, inputs_labels, inputs_indexs, label2id, perm = None):
        """Build a labelled batch consisting only of generated samples.

        NOTE(review): this calls self.gen_test(inputs_labels, inputs_indexs),
        but gen_test() as defined below accepts a single sample_num argument,
        so calling ready_fake as written raises TypeError.  Verify against
        the intended gen_test signature.  `label2id` is unused here.
        """
        sent_output = sent_output.to(self.device)
        sent_noise = torch.tensor(self.gen_test(inputs_labels, inputs_indexs)).to(self.device)

        x = sent_noise
        y = torch.tensor(inputs_labels).unsqueeze(1)
        if perm is None:
            perm = torch.randperm(len(x))
        x = x[perm]
        y = y[perm]

        return x,y,perm

    def ready_gen(self, sent_output):
        """Prepare generator-training inputs: noise x2, all-zero labels y,
        and the real features `sent_output` as feature-matching targets."""
        sent_num = len(sent_output)
        sent_output = sent_output.to(self.device)
        x2 = self.labels2genx(sent_num)
        y = torch.tensor([0]*sent_num).unsqueeze(1).to(self.device)

        return x2, y, sent_output

    def cls_train(self, x, y, if_oneHot = True):
        """Train the classifier on (x, y) for cls_epoches epochs.

        y arrives as an integer column vector; when `if_oneHot` it is
        scattered into one-hot rows first.  Returns the trained cls_net.
        """
        self.cls_net.train()
        self.gen_net.eval()

        self.unfreeze_cls()
        self.freeze_gen()

        x = x.to(self.device)
        y = y.to(self.device)

        if if_oneHot:
            y = torch.zeros(y.size(0), self.cls_num).to(self.device).scatter_(1, y.long(), 1)

        mydataset = MyDataset(x, y)
        train_loader = DataLoader(dataset=mydataset,
                    batch_size=self.cls_batch_size, shuffle=True)

        for epoch in range(self.cls_epoches):
            losses = []
            accuracy = []
            for step, (batch_x, batch_y) in enumerate(train_loader):
                self.cls_optimizer.zero_grad()

                out, _ = self.cls_net(batch_x)
                loss = self.loss_func(out, batch_y)

                # NOTE(review): this rescaling happens AFTER the loss is
                # computed and does not change the argmax used for accuracy
                # below, so it appears to have no effect.  It also hard-codes
                # two classes ([0.9, 1.0]).  Confirm the intent.
                batch_y = batch_y * torch.tensor([0.9, 1.0]).to(self.device)

                loss.backward()
                self.cls_optimizer.step()

                _, predictions = out.max(1)
                predictions = predictions.cpu().numpy().tolist()
                _,real_y = batch_y.max(1)
                real_y = real_y.cpu().numpy().tolist()

                num_correct = np.sum([int(x==y) for x,y in zip(predictions, real_y)])
                running_train_acc = float(num_correct) / float(batch_x.shape[0])
                # NOTE(review): appending the loss tensor (not loss.item())
                # keeps each step's autograd graph alive until the list is
                # discarded at the end of the epoch.
                losses.append(loss)
                accuracy.append(running_train_acc)

        return self.cls_net

    def cls_eval(self, x, y, if_oneHot = True):
        """Evaluate the classifier on (x, y); returns mean per-batch accuracy."""
        self.cls_net.eval()
        x = x.to(self.device)
        y = y.to(self.device)

        if if_oneHot:
            y = torch.zeros(y.size(0), self.cls_num).to(self.device).scatter_(1, y.long(), 1)

        mydataset = MyDataset(x, y)
        train_loader = DataLoader(dataset=mydataset,
                    batch_size=self.eval_batch_size, shuffle=False)

        losses = []
        accuracy = []

        for step, (batch_x, batch_y) in enumerate(train_loader):
            out,_ = self.cls_net(batch_x)
            loss = self.loss_func(out, batch_y)   # NOTE(review): computed but unused

            _, predictions = out.max(1)
            predictions = predictions.cpu().numpy().tolist()
            _,real_y = batch_y.max(1)
            real_y = real_y.cpu().numpy().tolist()

            num_correct = np.sum([int(x==y) for x,y in zip(predictions, real_y)])
            running_train_acc = float(num_correct) / float(batch_x.shape[0])
            accuracy.append(running_train_acc)

        mean_acc = np.mean(accuracy)
        return mean_acc

    def cls_real_eval(self, x, y, if_oneHot = True):
        """Accuracy restricted to 'real' rows: samples whose true label is
        not the last class (cls_num - 1, presumably the generated/fake
        class -- TODO confirm).  Returns right/total over those rows
        (raises ZeroDivisionError if no real rows are present)."""
        self.cls_net.eval()
        x = x.to(self.device)
        y = y.to(self.device)

        if if_oneHot:
            y = torch.zeros(y.size(0), self.cls_num).to(self.device).scatter_(1, y.long(), 1)

        mydataset = MyDataset(x, y)
        train_loader = DataLoader(dataset=mydataset,
                    batch_size=self.eval_batch_size, shuffle=False)

        rs = 0      # correctly classified real rows
        alls = 0    # total real rows

        for step, (batch_x, batch_y) in enumerate(train_loader):
            out, _ = self.cls_net(batch_x)
            loss = self.loss_func(out, batch_y)   # NOTE(review): computed but unused

            _, predictions = out.max(1)
            predictions = predictions.cpu().numpy().tolist()
            _,real_y = batch_y.max(1)
            real_y = real_y.cpu().numpy().tolist()

            right_num = np.sum([int( x==y and int(y) != int(self.cls_num-1) ) for x,y in zip(predictions, real_y)])
            all_num = np.sum([int(int(y) != int(self.cls_num-1) ) for x,y in zip(predictions, real_y)])

            rs = rs + right_num
            alls = alls + all_num

        return rs/alls

    def cls_test(self, x, if_oneHot = True):
        """Predict class indices for x.  Labels are dummy zeros (the loss is
        computed then discarded).  Returns a Python list of predictions."""
        self.cls_net.eval()
        x = x.to(self.device)
        y = torch.zeros([x.size(0),1], dtype = torch.float).to(self.device)

        if if_oneHot:
            y = torch.zeros(y.size(0), self.cls_num).to(self.device).scatter_(1, y.long(), 1)

        mydataset = MyDataset(x, y)
        train_loader = DataLoader(dataset=mydataset,
                    batch_size=self.eval_batch_size, shuffle=False)

        preds = []

        for step, (batch_x, batch_y) in enumerate(train_loader):
            out, _ = self.cls_net(batch_x)
            loss = self.loss_func(out, batch_y)   # NOTE(review): computed but unused

            _, predictions = out.max(1)
            predictions = predictions.cpu().numpy().tolist()
            preds.extend(predictions)

        return preds

    def gen_train(self, x2, y, s, times):
        """Train the generator for gen_epoches epochs.

        Args:
            x2: noise inputs for the generator.
            y: integer labels; one-hot encoded below, used only for the
               accuracy bookkeeping.
            s: real feature vectors acting as the feature-matching target.
            times: decay index -- the loss is scaled by 0.9 ** times.
        Returns the trained gen_net.
        """
        self.cls_net.eval()
        self.gen_net.train()

        self.freeze_cls()
        self.unfreeze_gen()

        y = torch.zeros(y.size(0), self.cls_num).to(self.device).scatter_(1, y.long(), 1)

        mydataset = MyDataset_new(x2, y, s)
        train_loader = DataLoader(dataset=mydataset,
                    batch_size=self.gen_batch_size, shuffle=True)

        for epoch in range(self.gen_epoches):
            losses = []
            accuracy = []
            for step, (batch_x2, batch_y, batch_s) in enumerate(train_loader):

                # NOTE(review): gradients are only zeroed on every 6th step
                # while backward()/step() run on every step, so updates use
                # gradients accumulated over up to six batches.  Confirm this
                # accumulation is intentional.
                if step % 6 == 5:
                    self.gen_optimizer.zero_grad()

                out = self.gen_net(batch_x2)

                # Feature matching: push the classifier's hidden features for
                # generated samples (hds) toward those of real samples (hds2).
                out, hds = self.cls_net(out)
                out2, hds2 = self.cls_net(batch_s.float())
                loss = self.loss_mse(hds, hds2)
                loss = loss * pow(0.9, times)   # decay generator loss across rounds
                loss.backward()
                self.gen_optimizer.step()

                _, predictions = out.max(1)
                predictions = predictions.cpu().numpy().tolist()
                _, real_y = batch_y.max(1)
                real_y = real_y.cpu().numpy().tolist()

                num_correct = np.sum([int(x==y) for x,y in zip(predictions, real_y)])
                running_train_acc = float(num_correct) / float(batch_x2.shape[0])
                losses.append(loss)
                accuracy.append(running_train_acc)

        return self.gen_net

    def gen_test(self, sample_num):
        """Generate `sample_num` feature vectors with the (frozen) generator.

        Returns a Python list of z_dim-long lists.  The one-hot y and
        all-ones s tensors only exist to satisfy MyDataset_new; the MSE
        computed below is discarded.
        """
        self.gen_net.eval()
        x2 = self.labels2genx(sample_num)

        # NOTE(review): this scatters column 0 of a (sample_num, z_dim) zero
        # matrix to 1 -- a one-hot over z_dim, not over cls_num; it is unused
        # beyond dataset construction.  Confirm intent.
        y = torch.zeros([sample_num,1], dtype = torch.float)
        y = torch.zeros(sample_num, self.z_dim).scatter_(1, y.long(), 1)
        y = y.to(self.device)
        s = torch.ones((sample_num, self.z_dim)).to(self.device)

        mydataset = MyDataset_new(x2, y, s)
        train_loader = DataLoader(dataset=mydataset,
                    batch_size=self.eval_batch_size, shuffle=False)

        preds = []

        for step, (batch_x2, batch_y, batch_s) in enumerate(train_loader):

            out = self.gen_net(batch_x2)

            loss = self.loss_mse(out.double(), batch_s.double())   # NOTE(review): computed but unused

            predictions = out.cpu().detach().numpy().tolist()
            preds.extend(predictions)

        return preds
|
|
|
|
|
# Library module: importing it is the intended use; direct execution is a no-op.
if __name__ == '__main__':

    pass
|
|
|
|