import torch
import torch.nn as nn
import torch.nn.functional as F

from atten import Atten


class FGA(nn.Module):
    def __init__(self, vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim,
                 hidden_hist_dim, hidden_cap_dim, hidden_img_dim):
        '''
        Factor Graph Attention
        :param vocab_size: vocabulary size
        :param word_embed_dim: word embedding dimension
        :param hidden_ques_dim: question LSTM hidden dimension
        :param hidden_ans_dim: answer LSTM hidden dimension
        :param hidden_hist_dim: history question/answer LSTM hidden dimension
        :param hidden_cap_dim: caption LSTM hidden dimension
        :param hidden_img_dim: image feature dimension
        '''
        super(FGA, self).__init__()

        print("Init FGA with vocab size %s, word embed %s, hidden ques %s, hidden ans %s,"
              " hidden hist %s, hidden cap %s, hidden img %s" % (vocab_size, word_embed_dim,
                                                                 hidden_ques_dim,
                                                                 hidden_ans_dim,
                                                                 hidden_hist_dim,
                                                                 hidden_cap_dim,
                                                                 hidden_img_dim))

        self.hidden_ques_dim = hidden_ques_dim
        self.hidden_ans_dim = hidden_ans_dim
        self.hidden_cap_dim = hidden_cap_dim
        self.hidden_img_dim = hidden_img_dim
        self.hidden_hist_dim = hidden_hist_dim

        # The vocabulary of the history LSTMs is one larger because we keep a stop id (the last id)
        self.word_embedddings = nn.Embedding(vocab_size+1+1, word_embed_dim, padding_idx=0)

        self.lstm_ques = nn.LSTM(word_embed_dim, self.hidden_ques_dim, batch_first=True)
        self.lstm_ans = nn.LSTM(word_embed_dim, self.hidden_ans_dim, batch_first=True)

        self.lstm_hist_ques = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
        self.lstm_hist_ans = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
        self.lstm_hist_cap = nn.LSTM(word_embed_dim, self.hidden_cap_dim, batch_first=True)

        self.qahistnet = nn.Sequential(
            nn.Linear(self.hidden_hist_dim*2, self.hidden_hist_dim),
            nn.ReLU(inplace=True)
        )

        self.concat_dim = self.hidden_ques_dim + self.hidden_ans_dim + \
                          self.hidden_ans_dim + self.hidden_img_dim + \
                          self.hidden_cap_dim + self.hidden_hist_dim*9
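        # concat_dim is the size of the per-answer feature concatenated in forward():
        # ans_seq + ques_atten + ans_atten + img_atten + the fused history of the
        # (assumed) 9 dialog rounds (nqa_per_dial * hidden_hist_dim) + cap_atten.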

        self.simnet = nn.Sequential(
            nn.Linear(self.concat_dim, self.concat_dim//2, bias=False),
            nn.BatchNorm1d(self.concat_dim//2),
            nn.ReLU(inplace=True),
            nn.Linear(self.concat_dim//2, self.concat_dim//4, bias=False),
            nn.BatchNorm1d(self.concat_dim//4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(self.concat_dim//4, 1)
        )
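        # simnet maps each fused per-answer feature vector (size concat_dim) to a single
        # score; forward() reshapes its output back to (batch x n_options).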

        # To share weights, provide a dict: {util idx: (num shared copies, list of connected utils)}.
        # Note: for efficiency, the shared utils (i.e., history) are connected to the answer and question only;
        # connecting shared factors to each other is not supported (!)
        sharing_factor_weights = {4: (9, [0, 1]),
                                  5: (9, [0, 1])}
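        # Here, utils 4 and 5 (question-history and answer-history) each have 9 shared copies,
        # presumably one per dialog round (matching hidden_hist_dim*9 in concat_dim), and are
        # connected only to util 0 (answer) and util 1 (question).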

        self.mul_atten = Atten(util_e=[self.hidden_ans_dim,   # Answer modal
                                       self.hidden_ques_dim,  # Question modal
                                       self.hidden_cap_dim,   # Caption modal
                                       self.hidden_img_dim,   # Image modal
                                       self.hidden_hist_dim,  # Question-history modal
                                       self.hidden_hist_dim   # Answer-history modal
                                       ],
                               sharing_factor_weights=sharing_factor_weights,
                               sizes=[100,  # 100 Answers
                                      21,   # Question length
                                      41,   # Caption length
                                      37,   # 36 Image regions
                                      21,   # History-Question length
                                      21    # History-Answer length
                                      ],    # The spatial dim used for pairwise normalization (use force for adaptive)
                               prior_flag=True,
                               pairwise_flag=True)

    def forward(self, input_ques, input_ans, input_hist_ques, input_hist_ans, input_hist_cap,
                input_ques_length, input_ans_length, input_cap_length, i_e):
        """
        :param input_ques: question token ids, (batch x ques_len)
        :param input_ans: candidate-answer token ids, (batch x n_options x ans_len)
        :param input_hist_ques: history-question token ids, (batch x nqa_per_dial x nwords_per_qa)
        :param input_hist_ans: history-answer token ids, (batch x nqa_per_dial x nwords_per_qa)
        :param input_hist_cap: caption token ids, (batch x nwords_per_cap)
        :param input_ques_length: question lengths, (batch)
        :param input_ans_length: candidate-answer lengths, (batch x n_options)
        :param input_cap_length: caption lengths, (batch)
        :param i_e: image region features, (batch x num_regions x hidden_img_dim)
        :return: a score for each candidate answer, (batch x n_options)
        """
        n_options = input_ans.size()[1]
        batch_size = input_ques.size()[0]

        nqa_per_dial, nwords_per_qa = input_hist_ques.size()[1], input_hist_ques.size()[2]
        nwords_per_cap = input_hist_cap.size()[1]
        max_length_input_ans = input_ans.size()[-1]

        assert batch_size == input_hist_ques.size()[0] == input_hist_ans.size()[0] == input_ques.size()[0] == \
               input_ans.size()[0] == input_hist_cap.size()[0]
        assert nqa_per_dial == input_hist_ques.size()[1] == input_hist_ans.size()[1]
        assert nwords_per_qa == input_hist_ques.size()[2] == input_hist_ans.size()[2]

        q_we = self.word_embedddings(input_ques)
        a_we = self.word_embedddings(input_ans.view(-1, max_length_input_ans))
        hq_we = self.word_embedddings(input_hist_ques.view(-1, nwords_per_qa))
        ha_we = self.word_embedddings(input_hist_ans.view(-1, nwords_per_qa))
        c_we = self.word_embedddings(input_hist_cap.view(-1, nwords_per_cap))
        '''
        q_we  = batch x 20 x embed_ques_dim
        a_we  = 100*batch x 20 x embed_ans_dim
        hq_we = batch*nqa_per_dial x nwords_per_qa x embed_hist_dim
        ha_we = batch*nqa_per_dial x nwords_per_qa x embed_hist_dim
        c_we  = batch*ncap_per_dial x nwords_per_cap x embed_hist_dim
        '''
        self.lstm_ques.flatten_parameters()
        self.lstm_ans.flatten_parameters()
        self.lstm_hist_ques.flatten_parameters()
        self.lstm_hist_ans.flatten_parameters()
        self.lstm_hist_cap.flatten_parameters()

        i_feat = i_e

        q_seq, self.hidden_ques = self.lstm_ques(q_we)
        a_seq, self.hidden_ans = self.lstm_ans(a_we)
        hq_seq, self.hidden_hist_ques = self.lstm_hist_ques(hq_we)
        ha_seq, self.hidden_hist_ans = self.lstm_hist_ans(ha_we)
        cap_seq, self.hidden_cap = self.lstm_hist_cap(c_we)
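
        # Each modality is encoded by its own LSTM (full output sequences are kept);
        # the precomputed image features i_feat are used as-is.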

        '''
        Sequence lengths are used to build the attention priors below.
        '''
        q_len = input_ques_length.data - 1
        c_len = input_cap_length.data.view(-1) - 1

        ans_index = torch.arange(0, n_options * batch_size).long().cuda()
        ans_len = input_ans_length.data.view(-1) - 1
        ans_seq = a_seq[ans_index, ans_len, :]
        ans_seq = ans_seq.view(batch_size, n_options, self.hidden_ans_dim)

        batch_index = torch.arange(0, batch_size).long().cuda()
        q_prior = torch.zeros(batch_size, q_seq.size(1)).cuda()
        q_prior[batch_index, q_len] = 100
        c_prior = torch.zeros(batch_size, cap_seq.size(1)).cuda()
        c_prior[batch_index, c_len] = 100
        ans_prior = torch.ones(batch_size, ans_seq.size(1)).cuda()
        img_prior = torch.ones(batch_size, i_feat.size(1)).cuda()

        (ans_atten, ques_atten, cap_atten, img_atten, hq_atten, ha_atten) = \
            self.mul_atten([ans_seq, q_seq, cap_seq, i_feat, hq_seq, ha_seq],
                           priors=[ans_prior, q_prior, c_prior, img_prior, None, None])
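
        # Per the reshapes below, ans_atten/ques_atten/cap_atten/img_atten are attended summaries
        # of shape (batch x hidden_dim), while hq_atten and ha_atten keep one summary per dialog
        # round, (nqa_per_dial*batch x hidden_hist_dim), and are fused by qahistnet.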

        '''
        Expand the attended features to the per-answer (n_options) dimension.
        '''
        ques_atten = torch.unsqueeze(ques_atten, 1).expand(batch_size,
                                                           n_options,
                                                           self.hidden_ques_dim)
        cap_atten = torch.unsqueeze(cap_atten, 1).expand(batch_size,
                                                         n_options,
                                                         self.hidden_cap_dim)
        img_atten = torch.unsqueeze(img_atten, 1).expand(batch_size, n_options,
                                                         self.hidden_img_dim)
        ans_atten = torch.unsqueeze(ans_atten, 1).expand(batch_size, n_options,
                                                         self.hidden_ans_dim)

        '''
        Combine the question-history and answer-history attentions.
        '''
        input_qahistnet = torch.cat((hq_atten, ha_atten), 1)
        # input_qahistnet: (nqa_per_dial*batch x 2*hidden_hist_dim)
        output_qahistnet = self.qahistnet(input_qahistnet)
        # output_qahistnet: (nqa_per_dial*batch x hidden_hist_dim)
        output_qahistnet = output_qahistnet.view(batch_size,
                                                 nqa_per_dial * self.hidden_hist_dim)
        # output_qahistnet: (batch x nqa_per_dial*hidden_hist_dim)
        output_qahistnet = torch.unsqueeze(output_qahistnet, 1)\
            .expand(batch_size,
                    n_options,
                    nqa_per_dial * self.hidden_hist_dim)

        input_qa = torch.cat((ans_seq, ques_atten, ans_atten, img_atten,
                              output_qahistnet, cap_atten), 2)  # Concatenate last dimension
        input_qa = input_qa.view(batch_size * n_options, self.concat_dim)

        out_scores = self.simnet(input_qa)
        out_scores = out_scores.squeeze(dim=1)
        out_scores = out_scores.view(batch_size, n_options)

        return out_scores
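

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). It assumes the
# `Atten` module from this repository is importable, a CUDA device is available
# (forward() calls .cuda() internally), and input sizes consistent with the
# `sizes` list given to Atten above: 100 answer candidates, question/history
# length 21, caption length 41, 9 dialog rounds and 37 image regions. The
# dimensions and tensor names below are illustrative, not the repo's API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    vocab_size = 1000
    model = FGA(vocab_size=vocab_size, word_embed_dim=128, hidden_ques_dim=256,
                hidden_ans_dim=256, hidden_hist_dim=256, hidden_cap_dim=256,
                hidden_img_dim=512).cuda()

    batch = 2
    ques = torch.randint(1, vocab_size, (batch, 21)).cuda()
    ans = torch.randint(1, vocab_size, (batch, 100, 20)).cuda()
    hist_ques = torch.randint(1, vocab_size, (batch, 9, 21)).cuda()
    hist_ans = torch.randint(1, vocab_size, (batch, 9, 21)).cuda()
    cap = torch.randint(1, vocab_size, (batch, 41)).cuda()
    ques_len = torch.full((batch,), 21, dtype=torch.long).cuda()
    ans_len = torch.full((batch, 100), 20, dtype=torch.long).cuda()
    cap_len = torch.full((batch,), 41, dtype=torch.long).cuda()
    img_feat = torch.rand(batch, 37, 512).cuda()

    with torch.no_grad():
        scores = model.eval()(ques, ans, hist_ques, hist_ans, cap,
                              ques_len, ans_len, cap_len, img_feat)
    print(scores.shape)  # expected: torch.Size([2, 100])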