# Epik/Model/COSMIC/erc_training/commonsense_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence
import numpy as np, itertools, random, copy, math
from model import SimpleAttention, MatchingAttention, Attention
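
# COSMIC-style commonsense-GRU model for emotion recognition in conversation.
# CommonsenseRNNCell updates, for one utterance, a global context state (g),
# per-party internal (q) and external (r) states, an intent state (i) and an
# emotion state (e); CommonsenseRNN unrolls that cell over a dialogue; and
# CommonsenseGRUModel runs it over fused utterance features in both dialogue
# directions and classifies each utterance's emotion.
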
class CommonsenseRNNCell(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                 context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNNCell, self).__init__()

        self.D_m = D_m
        self.D_s = D_s
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e

        # GRU cells for the global (g), internal/party (p), external (r),
        # intent (i) and emotion (e) states; each consumes the concatenation
        # of the inputs listed in its input size.
        self.g_cell = nn.GRUCell(D_m+D_p+D_r, D_g)
        self.p_cell = nn.GRUCell(D_s+D_g, D_p)
        self.r_cell = nn.GRUCell(D_m+D_s+D_g, D_r)
        self.i_cell = nn.GRUCell(D_s+D_p, D_i)
        self.e_cell = nn.GRUCell(D_m+D_p+D_r+D_i, D_e)

        self.emo_gru = emo_gru
        self.listener_state = listener_state
        if listener_state:
            self.pl_cell = nn.GRUCell(D_s+D_g, D_p)
            self.rl_cell = nn.GRUCell(D_m+D_s+D_g, D_r)

        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)
        self.dropout5 = nn.Dropout(dropout)

        if context_attention == 'simple':
            self.attention = SimpleAttention(D_g)
        else:
            self.attention = MatchingAttention(D_g, D_m, D_a, context_attention)
    def _select_parties(self, X, indices):
        # Pick, for every batch element, the state slot of the currently
        # active party: X is (batch, party, D), indices is (batch,),
        # and the result is (batch, D).
        q0_sel = []
        for idx, j in zip(indices, X):
            q0_sel.append(j[idx].unsqueeze(0))
        q0_sel = torch.cat(q0_sel, 0)
        return q0_sel
    def forward(self, U, x1, x2, x3, o1, o2, qmask, g_hist, q0, r0, i0, e0):
        """
        U -> batch, D_m
        x1, x2, x3, o1, o2 -> batch, D_s
        x1 -> effect on self; x2 -> reaction of self; x3 -> intent of self
        o1 -> effect on others; o2 -> reaction of others
        qmask -> batch, party
        g_hist -> t-1, batch, D_g
        q0 -> batch, party, D_p
        r0 -> batch, party, D_r
        i0 -> batch, party, D_i
        e0 -> batch, D_e
        """
        qm_idx = torch.argmax(qmask, 1)
        q0_sel = self._select_parties(q0, qm_idx)
        r0_sel = self._select_parties(r0, qm_idx)

        ## global state ##
        g_ = self.g_cell(torch.cat([U, q0_sel, r0_sel], dim=1),
                         torch.zeros(U.size()[0], self.D_g).type(U.type()) if g_hist.size()[0] == 0 else
                         g_hist[-1])
        # g_ = self.dropout(g_)

        ## context ##
        if g_hist.size()[0] == 0:
            c_ = torch.zeros(U.size()[0], self.D_g).type(U.type())
            alpha = None
        else:
            c_, alpha = self.attention(g_hist, U)

        ## external state ##
        U_r_c_ = torch.cat([U, x2, c_], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        rs_ = self.r_cell(U_r_c_.contiguous().view(-1, self.D_m+self.D_s+self.D_g),
                          r0.view(-1, self.D_r)).view(U.size()[0], -1, self.D_r)
        # rs_ = self.dropout(rs_)

        ## internal state ##
        es_c_ = torch.cat([x1, c_], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        qs_ = self.p_cell(es_c_.contiguous().view(-1, self.D_s+self.D_g),
                          q0.view(-1, self.D_p)).view(U.size()[0], -1, self.D_p)
        # qs_ = self.dropout(qs_)

        if self.listener_state:
            ## listener external state ##
            U_ = U.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_m)
            er_ = o2.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_s)
            ss_ = self._select_parties(rs_, qm_idx).unsqueeze(1).\
                expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_r)
            U_er_ss_ = torch.cat([U_, er_, ss_], 1)
            rl_ = self.rl_cell(U_er_ss_, r0.view(-1, self.D_r)).view(U.size()[0], -1, self.D_r)
            # rl_ = self.dropout(rl_)

            ## listener internal state ##
            es_ = o1.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_s)
            ss_ = self._select_parties(qs_, qm_idx).unsqueeze(1).\
                expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_p)
            es_ss_ = torch.cat([es_, ss_], 1)
            ql_ = self.pl_cell(es_ss_, q0.view(-1, self.D_p)).view(U.size()[0], -1, self.D_p)
            # ql_ = self.dropout(ql_)
        else:
            rl_ = r0
            ql_ = q0

        # Gate between listener (l) and speaker (s) updates with the one-hot
        # speaker mask: only the active speaker's slot takes the speaker update.
        qmask_ = qmask.unsqueeze(2)
        q_ = ql_*(1-qmask_) + qs_*qmask_
        r_ = rl_*(1-qmask_) + rs_*qmask_

        ## intent ##
        i_q_ = torch.cat([x3, self._select_parties(q_, qm_idx)], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        is_ = self.i_cell(i_q_.contiguous().view(-1, self.D_s+self.D_p),
                          i0.view(-1, self.D_i)).view(U.size()[0], -1, self.D_i)
        # is_ = self.dropout(is_)
        il_ = i0
        i_ = il_*(1-qmask_) + is_*qmask_

        ## emotion ##
        es_ = torch.cat([U, self._select_parties(q_, qm_idx), self._select_parties(r_, qm_idx),
                         self._select_parties(i_, qm_idx)], dim=1)
        e0 = torch.zeros(qmask.size()[0], self.D_e).type(U.type()) if e0.size()[0] == 0 \
            else e0

        if self.emo_gru:
            e_ = self.e_cell(es_, e0)
        else:
            e_ = es_
        # e_ = self.dropout(e_)

        g_ = self.dropout1(g_)
        q_ = self.dropout2(q_)
        r_ = self.dropout3(r_)
        i_ = self.dropout4(i_)
        e_ = self.dropout5(e_)

        return g_, q_, r_, i_, e_, alpha
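
# CommonsenseRNN unrolls CommonsenseRNNCell over a dialogue, carrying the
# growing global-state history g_hist forward and stacking the per-step
# emotion states into a (seq_len, batch, D_e) tensor.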
class CommonsenseRNN(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                 context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNN, self).__init__()

        self.D_m = D_m
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e
        self.dropout = nn.Dropout(dropout)

        self.dialogue_cell = CommonsenseRNNCell(D_m, D_s, D_g, D_p, D_r, D_i, D_e,
                                                listener_state, context_attention, D_a, dropout, emo_gru)

    def forward(self, U, x1, x2, x3, o1, o2, qmask):
        """
        U -> seq_len, batch, D_m
        x1, x2, x3, o1, o2 -> seq_len, batch, D_s
        qmask -> seq_len, batch, party
        """
        g_hist = torch.zeros(0).type(U.type())  # empty tensor; filled with per-step global states
        q_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_p).type(U.type())  # batch, party, D_p
        r_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_r).type(U.type())  # batch, party, D_r
        i_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_i).type(U.type())  # batch, party, D_i
        e_ = torch.zeros(0).type(U.type())  # batch, D_e
        e = e_

        alpha = []
        for u_, x1_, x2_, x3_, o1_, o2_, qmask_ in zip(U, x1, x2, x3, o1, o2, qmask):
            g_, q_, r_, i_, e_, alpha_ = self.dialogue_cell(u_, x1_, x2_, x3_, o1_, o2_,
                                                            qmask_, g_hist, q_, r_, i_, e_)
            g_hist = torch.cat([g_hist, g_.unsqueeze(0)], 0)
            e = torch.cat([e, e_.unsqueeze(0)], 0)
            if alpha_ is not None:
                alpha.append(alpha_[:, 0, :])

        return e, alpha  # seq_len, batch, D_e
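
# CommonsenseGRUModel wraps two CommonsenseRNNs (forward and reversed dialogue
# order), fuses up to four utterance-feature variants (r1..r4) into the model
# input, optionally applies matching attention over the bidirectional emotion
# states, and emits per-utterance log-probabilities over n_classes emotions.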
class CommonsenseGRUModel(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, D_a=100, n_classes=7, listener_state=False,
                 context_attention='simple', dropout_rec=0.5, dropout=0.1, emo_gru=True, mode1=0, norm=0, residual=False):
        super(CommonsenseGRUModel, self).__init__()

        # mode1 controls how the four feature variants r1..r4 are fused in
        # forward(): 0 concatenates all four, 1 concatenates the first two,
        # and every other mode keeps the D_m dimensionality.
        if mode1 == 0:
            D_x = 4 * D_m
        elif mode1 == 1:
            D_x = 2 * D_m
        else:
            D_x = D_m

        self.mode1 = mode1
        self.norm_strategy = norm
        self.linear_in = nn.Linear(D_x, D_h)
        self.residual = residual

        self.r_weights = nn.Parameter(torch.tensor([0.25, 0.25, 0.25, 0.25]))

        norm_train = True
        self.norm1a = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1b = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1c = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1d = nn.LayerNorm(D_m, elementwise_affine=norm_train)

        self.norm3a = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3b = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3c = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3d = nn.BatchNorm1d(D_m, affine=norm_train)

        self.dropout = nn.Dropout(dropout)
        self.dropout_rec = nn.Dropout(dropout_rec)

        self.cs_rnn_f = CommonsenseRNN(D_h, D_s, D_g, D_p, D_r, D_i, D_e, listener_state,
                                       context_attention, D_a, dropout_rec, emo_gru)
        self.cs_rnn_r = CommonsenseRNN(D_h, D_s, D_g, D_p, D_r, D_i, D_e, listener_state,
                                       context_attention, D_a, dropout_rec, emo_gru)
        self.sense_gru = nn.GRU(input_size=D_s, hidden_size=D_s//2, num_layers=1, bidirectional=True)
        self.matchatt = MatchingAttention(2*D_e, 2*D_e, att_type='general2')
        self.linear = nn.Linear(2*D_e, D_h)
        self.smax_fc = nn.Linear(D_h, n_classes)
    def _reverse_seq(self, X, mask):
        """
        X -> seq_len, batch, dim
        mask -> batch, seq_len
        """
        X_ = X.transpose(0, 1)
        mask_sum = torch.sum(mask, 1).int()

        xfs = []
        for x, c in zip(X_, mask_sum):
            xf = torch.flip(x[:c], [0])
            xfs.append(xf)

        return pad_sequence(xfs)
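
    # _reverse_seq example (illustrative): with a mask summing to 3 for a
    # dialogue padded to length 4, [u1, u2, u3, pad] becomes [u3, u2, u1, 0];
    # only the valid prefix is flipped and the tail is re-padded with zeros.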
    def forward(self, r1, r2, r3, r4, x1, x2, x3, o1, o2, qmask, umask, att2=False, return_hidden=False):
        """
        r1, r2, r3, r4 -> seq_len, batch, D_m   (utterance feature variants)
        x1, x2, x3, o1, o2 -> seq_len, batch, D_s   (commonsense features)
        qmask -> seq_len, batch, party
        umask -> batch, seq_len
        """
        seq_len, batch, feature_dim = r1.size()

        if self.norm_strategy == 1:
            r1 = self.norm1a(r1.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r2 = self.norm1b(r2.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r3 = self.norm1c(r3.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r4 = self.norm1d(r4.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)

        elif self.norm_strategy == 2:
            norm2 = nn.LayerNorm((seq_len, feature_dim), elementwise_affine=False)
            r1 = norm2(r1.transpose(0, 1)).transpose(0, 1)
            r2 = norm2(r2.transpose(0, 1)).transpose(0, 1)
            r3 = norm2(r3.transpose(0, 1)).transpose(0, 1)
            r4 = norm2(r4.transpose(0, 1)).transpose(0, 1)

        elif self.norm_strategy == 3:
            r1 = self.norm3a(r1.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r2 = self.norm3b(r2.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r3 = self.norm3c(r3.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r4 = self.norm3d(r4.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)

        # Fuse the feature variants according to mode1.
        if self.mode1 == 0:
            r = torch.cat([r1, r2, r3, r4], axis=-1)
        elif self.mode1 == 1:
            r = torch.cat([r1, r2], axis=-1)
        elif self.mode1 == 2:
            r = (r1 + r2 + r3 + r4) / 4
        elif self.mode1 == 3:
            r = r1
        elif self.mode1 == 4:
            r = r2
        elif self.mode1 == 5:
            r = r3
        elif self.mode1 == 6:
            r = r4
        elif self.mode1 == 7:
            r = self.r_weights[0]*r1 + self.r_weights[1]*r2 + self.r_weights[2]*r3 + self.r_weights[3]*r4

        r = self.linear_in(r)

        emotions_f, alpha_f = self.cs_rnn_f(r, x1, x2, x3, o1, o2, qmask)

        out_sense, _ = self.sense_gru(x1)

        rev_r = self._reverse_seq(r, umask)
        rev_x1 = self._reverse_seq(x1, umask)
        rev_x2 = self._reverse_seq(x2, umask)
        rev_x3 = self._reverse_seq(x3, umask)
        rev_o1 = self._reverse_seq(o1, umask)
        rev_o2 = self._reverse_seq(o2, umask)
        rev_qmask = self._reverse_seq(qmask, umask)

        emotions_b, alpha_b = self.cs_rnn_r(rev_r, rev_x1, rev_x2, rev_x3, rev_o1, rev_o2, rev_qmask)
        emotions_b = self._reverse_seq(emotions_b, umask)

        emotions = torch.cat([emotions_f, emotions_b], dim=-1)
        emotions = self.dropout_rec(emotions)

        alpha, alpha_f, alpha_b = [], [], []

        if att2:
            att_emotions = []
            alpha = []
            for t in emotions:
                att_em, alpha_ = self.matchatt(emotions, t, mask=umask)
                att_emotions.append(att_em.unsqueeze(0))
                alpha.append(alpha_[:, 0, :])
            att_emotions = torch.cat(att_emotions, dim=0)
            hidden = F.relu(self.linear(att_emotions))
        else:
            hidden = F.relu(self.linear(emotions))

        hidden = self.dropout(hidden)

        if self.residual:
            hidden = hidden + r

        log_prob = F.log_softmax(self.smax_fc(hidden), 2)

        if return_hidden:
            return hidden, alpha, alpha_f, alpha_b, emotions
        return log_prob, out_sense, alpha, alpha_f, alpha_b, emotions
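

# ---------------------------------------------------------------------------
# Minimal shape smoke test (an illustrative sketch, not part of the original
# training pipeline).  The dimensions below are assumptions chosen only to
# exercise the forward pass; the actual feature sizes are set by the training
# script.  Requires the repo's model.py (SimpleAttention, MatchingAttention)
# to be importable.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    seq_len, batch, n_party = 4, 2, 2
    D_m, D_s = 1024, 768                # assumed utterance / commonsense feature sizes
    D_g = D_p = D_r = D_i = 150         # assumed state sizes
    D_e = D_h = 100

    model = CommonsenseGRUModel(D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, n_classes=7)

    r1, r2, r3, r4 = (torch.randn(seq_len, batch, D_m) for _ in range(4))
    x1, x2, x3, o1, o2 = (torch.randn(seq_len, batch, D_s) for _ in range(5))

    # one-hot speaker mask (seq_len, batch, party) and utterance mask (batch, seq_len)
    qmask = F.one_hot(torch.randint(0, n_party, (seq_len, batch)), num_classes=n_party).float()
    umask = torch.ones(batch, seq_len)

    log_prob, out_sense, alpha, alpha_f, alpha_b, emotions = model(
        r1, r2, r3, r4, x1, x2, x3, o1, o2, qmask, umask)
    print(log_prob.shape)   # expected: (seq_len, batch, n_classes)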