#!/usr/bin/python
# coding: utf-8

# Author: LE YUAN
# Date: 2021-03-23

import pickle
import sys
import timeit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score


class KcatPrediction(nn.Module):
    def __init__(self, device, n_fingerprint, n_word, dim, layer_gnn, window,
                 layer_cnn, layer_output):
        super(KcatPrediction, self).__init__()
        self.embed_fingerprint = nn.Embedding(n_fingerprint, dim)
        self.embed_word = nn.Embedding(n_word, dim)
        self.W_gnn = nn.ModuleList([nn.Linear(dim, dim)
                                    for _ in range(layer_gnn)])
        self.W_cnn = nn.ModuleList([nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=2*window+1,
            stride=1, padding=window) for _ in range(layer_cnn)])
        self.W_attention = nn.Linear(dim, dim)
        self.W_out = nn.ModuleList([nn.Linear(2*dim, 2*dim)
                                    for _ in range(layer_output)])
        # self.W_interaction = nn.Linear(2*dim, 2)
        self.W_interaction = nn.Linear(2*dim, 1)
        self.device = device
        self.dim = dim
        self.layer_gnn = layer_gnn
        self.window = window
        self.layer_cnn = layer_cnn
        self.layer_output = layer_output

    def gnn(self, xs, A, layer):
        """Graph neural network over the compound graph: message passing
        through the adjacency matrix, then mean pooling over atoms."""
        for i in range(layer):
            hs = torch.relu(self.W_gnn[i](xs))
            xs = xs + torch.matmul(A, hs)
        # return torch.unsqueeze(torch.sum(xs, 0), 0)
        return torch.unsqueeze(torch.mean(xs, 0), 0)

    def attention_cnn(self, x, xs, layer):
        """The attention mechanism is applied to the last layer of CNN."""
        xs = torch.unsqueeze(torch.unsqueeze(xs, 0), 0)
        for i in range(layer):
            xs = torch.relu(self.W_cnn[i](xs))
        xs = torch.squeeze(torch.squeeze(xs, 0), 0)
        h = torch.relu(self.W_attention(x))
        hs = torch.relu(self.W_attention(xs))
        weights = torch.tanh(F.linear(h, hs))
        ys = torch.t(weights) * hs
        # Attention weights over sequence positions, normalized by the
        # maximum weight and reported alongside the prediction.
        attention_weights = F.linear(h, hs)[0].tolist()
        max_attention = max([float(attention) for attention in attention_weights])
        # print(max_attention)
        attention_profiles = ['%.4f' % (float(attention)/max_attention)
                              for attention in attention_weights]
        return torch.unsqueeze(torch.mean(ys, 0), 0), attention_profiles

    def forward(self, inputs):
        fingerprints, adjacency, words = inputs
        # Use the layer depths configured in __init__ rather than
        # hard-coded values.

        """Compound vector with GNN."""
        fingerprint_vectors = self.embed_fingerprint(fingerprints)
        compound_vector = self.gnn(fingerprint_vectors, adjacency,
                                   self.layer_gnn)

        """Protein vector with attention-CNN."""
        word_vectors = self.embed_word(words)
        protein_vector, attention_profiles = self.attention_cnn(
            compound_vector, word_vectors, self.layer_cnn)
        # print(protein_vector)
        # print('The length of protein vectors is:', len(protein_vector[0]))

        """Concatenate the above two vectors and output the interaction."""
        cat_vector = torch.cat((compound_vector, protein_vector), 1)
        for j in range(self.layer_output):
            cat_vector = torch.relu(self.W_out[j](cat_vector))
        # print(cat_vector)
        interaction = self.W_interaction(cat_vector)
        # print(interaction)

        return interaction, attention_profiles

    def __call__(self, data, train=True):
        inputs, correct_interaction = data[:-1], data[-1]
        # forward() returns (prediction, attention_profiles); only the
        # prediction is needed for the loss and the evaluation metrics.
        predicted_interaction, _ = self.forward(inputs)

        if train:
            loss = F.mse_loss(predicted_interaction, correct_interaction)
            return loss
        else:
            correct_values = correct_interaction.to('cpu').data.numpy()
            predicted_values = predicted_interaction.to('cpu').data.numpy()[0]
            return correct_values, predicted_values
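# A minimal smoke test of KcatPrediction on random inputs can be handy when
# modifying the architecture. The sizes below are made-up assumptions for
# illustration only, not the vocabulary sizes or dimensions of the real
# preprocessed data:
#
#     model = KcatPrediction(torch.device('cpu'), n_fingerprint=100, n_word=50,
#                            dim=10, layer_gnn=3, window=5, layer_cnn=3,
#                            layer_output=3)
#     fingerprints = torch.randint(0, 100, (6,))    # 6 atoms/substructures
#     adjacency = torch.eye(6)                      # 6 x 6 adjacency matrix
#     words = torch.randint(0, 50, (30,))           # 30 sequence n-grams
#     kcat, attention = model.forward((fingerprints, adjacency, words))
#     print(kcat.shape)                             # expected: torch.Size([1, 1])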
class Trainer(object):
    def __init__(self, model, lr, weight_decay):
        self.model = model
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=lr, weight_decay=weight_decay)

    def train(self, dataset):
        # One epoch over the (reshuffled) training set, one sample at a time.
        np.random.shuffle(dataset)
        N = len(dataset)
        loss_total = 0
        for data in dataset:
            loss = self.model(data)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_total += loss.to('cpu').data.numpy()
        return loss_total


class Tester(object):
    def __init__(self, model):
        self.model = model

    def test(self, dataset):
        N = len(dataset)
        SAE = 0  # Sum of absolute errors.
        testY, testPredict = [], []
        for data in dataset:
            (correct_values, predicted_values) = self.model(data, train=False)
            SAE += sum(np.abs(predicted_values - correct_values))
            testY.append(correct_values)
            testPredict.append(predicted_values)
        MAE = SAE / N  # Mean absolute error.
        rmse = np.sqrt(mean_squared_error(testY, testPredict))
        r2 = r2_score(testY, testPredict)
        return MAE, rmse, r2

    def save_MAEs(self, MAEs, filename):
        with open(filename, 'a') as f:
            f.write('\t'.join(map(str, MAEs)) + '\n')

    def save_model(self, model, filename):
        torch.save(model.state_dict(), filename)


def load_tensor(file_name, dtype):
    return [dtype(d).to(device) for d in
            np.load(file_name + '.npy', allow_pickle=True)]


def load_pickle(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)


def shuffle_dataset(dataset, seed):
    np.random.seed(seed)
    np.random.shuffle(dataset)
    return dataset


def split_dataset(dataset, ratio):
    n = int(ratio * len(dataset))
    dataset_1, dataset_2 = dataset[:n], dataset[n:]
    return dataset_1, dataset_2


if __name__ == "__main__":

    """Hyperparameters."""
    (DATASET, radius, ngram, dim, layer_gnn, window, layer_cnn, layer_output,
     lr, lr_decay, decay_interval, weight_decay, iteration,
     setting) = sys.argv[1:]
    (dim, layer_gnn, window, layer_cnn, layer_output, decay_interval,
     iteration) = map(int, [dim, layer_gnn, window, layer_cnn, layer_output,
                            decay_interval, iteration])
    lr, lr_decay, weight_decay = map(float, [lr, lr_decay, weight_decay])

    """CPU or GPU."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('The code uses a GPU...')
    else:
        device = torch.device('cpu')
        print('The code uses a CPU...')

    """Load preprocessed data."""
    dir_input = '../../Data/input/'
    compounds = load_tensor(dir_input + 'compounds', torch.LongTensor)
    adjacencies = load_tensor(dir_input + 'adjacencies', torch.FloatTensor)
    proteins = load_tensor(dir_input + 'proteins', torch.LongTensor)
    interactions = load_tensor(dir_input + 'regression', torch.FloatTensor)
    fingerprint_dict = load_pickle(dir_input + 'fingerprint_dict.pickle')
    word_dict = load_pickle(dir_input + 'sequence_dict.pickle')
    n_fingerprint = len(fingerprint_dict)
    n_word = len(word_dict)

    """Create a dataset and split it into train/dev/test."""
    dataset = list(zip(compounds, adjacencies, proteins, interactions))
    dataset = shuffle_dataset(dataset, 1234)
    print(len(dataset))
    dataset_train, dataset_ = split_dataset(dataset, 0.8)
    dataset_dev, dataset_test = split_dataset(dataset_, 0.5)

    """Set a model."""
    torch.manual_seed(1234)
    model = KcatPrediction(device, n_fingerprint, n_word, dim, layer_gnn,
                           window, layer_cnn, layer_output).to(device)
    trainer = Trainer(model, lr, weight_decay)
    tester = Tester(model)

    """Output files."""
    file_MAEs = '../../Results/output/MAEs--' + setting + '.txt'
    file_model = '../../Results/output/' + setting
    # MAEs = ('Epoch\tTime(sec)\tLoss_train\tMAE_dev\t'
    #         'MAE_test\tPrecision_test\tRecall_test')
    MAEs = ('Epoch\tTime(sec)\tLoss_train\tMAE_dev\tMAE_test\t'
            'RMSE_dev\tRMSE_test\tR2_dev\tR2_test')
    with open(file_MAEs, 'w') as f:
        f.write(MAEs + '\n')

    """Start training."""
    print('Training...')
    print(MAEs)
    start = timeit.default_timer()

    for epoch in range(1, iteration):

        if epoch % decay_interval == 0:
            trainer.optimizer.param_groups[0]['lr'] *= lr_decay

        loss_train = trainer.train(dataset_train)
        MAE_dev, RMSE_dev, R2_dev = tester.test(dataset_dev)
        MAE_test, RMSE_test, R2_test = tester.test(dataset_test)

        end = timeit.default_timer()
        time = end - start

        MAEs = [epoch, time, loss_train, MAE_dev, MAE_test,
                RMSE_dev, RMSE_test, R2_dev, R2_test]
        tester.save_MAEs(MAEs, file_MAEs)
        tester.save_model(model, file_model)

        print('\t'.join(map(str, MAEs)))
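# Example invocation, assuming this script is saved as run_model.py and run
# from its own directory. The hyperparameter values below are illustrative
# placeholders, not the settings used for any published model; the positional
# arguments map to DATASET, radius, ngram, dim, layer_gnn, window, layer_cnn,
# layer_output, lr, lr_decay, decay_interval, weight_decay, iteration and
# setting, exactly as parsed in the __main__ block above:
#
#     python run_model.py Kcat 2 3 20 3 11 3 3 1e-3 0.5 10 1e-6 50 \
#         dim20--layer_gnn3--window11--layer_cnn3--layer_output3
#
# `setting` is only used to name the output files written to
# ../../Results/output/, so any descriptive label works.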