import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from model import SimpleAttention, MatchingAttention, Attention
class CommonsenseRNNCell(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                 context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNNCell, self).__init__()

        self.D_m = D_m
        self.D_s = D_s
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e

        # Five coupled GRU cells, one per latent state: global (g),
        # internal/party (p), external (r), intent (i), and emotion (e).
        self.g_cell = nn.GRUCell(D_m+D_p+D_r, D_g)
        self.p_cell = nn.GRUCell(D_s+D_g, D_p)
        self.r_cell = nn.GRUCell(D_m+D_s+D_g, D_r)
        self.i_cell = nn.GRUCell(D_s+D_p, D_i)
        self.e_cell = nn.GRUCell(D_m+D_p+D_r+D_i, D_e)

        self.emo_gru = emo_gru
        self.listener_state = listener_state
        if listener_state:
            # Separate cells update the listeners' internal/external states.
            self.pl_cell = nn.GRUCell(D_s+D_g, D_p)
            self.rl_cell = nn.GRUCell(D_m+D_s+D_g, D_r)

        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)
        self.dropout5 = nn.Dropout(dropout)

        if context_attention == 'simple':
            self.attention = SimpleAttention(D_g)
        else:
            self.attention = MatchingAttention(D_g, D_m, D_a, context_attention)

    def _select_parties(self, X, indices):
        # Pick, for every sample in the batch, the state row of the party
        # given by `indices` (the current speaker).
        q0_sel = []
        for idx, j in zip(indices, X):
            q0_sel.append(j[idx].unsqueeze(0))
        q0_sel = torch.cat(q0_sel, 0)
        return q0_sel
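    # Example (shapes illustrative): for q0 of shape (4, 2, D_p) and
    # qm_idx = tensor([0, 1, 1, 0]), _select_parties(q0, qm_idx) stacks
    # q0[0, 0], q0[1, 1], q0[2, 1], q0[3, 0] into a (4, D_p) tensor.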
    def forward(self, U, x1, x2, x3, o1, o2, qmask, g_hist, q0, r0, i0, e0):
        """
        U -> batch, D_m
        x1, x2, x3, o1, o2 -> batch, D_s
        x1 -> effect on self; x2 -> reaction of self; x3 -> intent of self
        o1 -> effect on others; o2 -> reaction of others
        qmask -> batch, party
        g_hist -> t-1, batch, D_g
        q0 -> batch, party, D_p
        r0 -> batch, party, D_r
        i0 -> batch, party, D_i
        e0 -> batch, D_e
        """
        qm_idx = torch.argmax(qmask, 1)
        q0_sel = self._select_parties(q0, qm_idx)
        r0_sel = self._select_parties(r0, qm_idx)

        ## global state ##
        g_ = self.g_cell(torch.cat([U, q0_sel, r0_sel], dim=1),
                         torch.zeros(U.size()[0], self.D_g).type(U.type()) if g_hist.size()[0] == 0
                         else g_hist[-1])

        ## context ##
        if g_hist.size()[0] == 0:
            c_ = torch.zeros(U.size()[0], self.D_g).type(U.type())
            alpha = None
        else:
            c_, alpha = self.attention(g_hist, U)

        ## external state ##
        U_r_c_ = torch.cat([U, x2, c_], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        rs_ = self.r_cell(U_r_c_.contiguous().view(-1, self.D_m+self.D_s+self.D_g),
                          r0.view(-1, self.D_r)).view(U.size()[0], -1, self.D_r)

        ## internal state ##
        es_c_ = torch.cat([x1, c_], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        qs_ = self.p_cell(es_c_.contiguous().view(-1, self.D_s+self.D_g),
                          q0.view(-1, self.D_p)).view(U.size()[0], -1, self.D_p)

        if self.listener_state:
            ## listener external state ##
            U_ = U.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_m)
            er_ = o2.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_s)
            ss_ = self._select_parties(rs_, qm_idx).unsqueeze(1).\
                expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_r)
            U_er_ss_ = torch.cat([U_, er_, ss_], 1)
            rl_ = self.rl_cell(U_er_ss_, r0.view(-1, self.D_r)).view(U.size()[0], -1, self.D_r)

            ## listener internal state ##
            es_ = o1.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_s)
            ss_ = self._select_parties(qs_, qm_idx).unsqueeze(1).\
                expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_p)
            es_ss_ = torch.cat([es_, ss_], 1)
            ql_ = self.pl_cell(es_ss_, q0.view(-1, self.D_p)).view(U.size()[0], -1, self.D_p)
        else:
            # Without listener_state, listeners simply keep their previous states.
            rl_ = r0
            ql_ = q0

        # Blend speaker and listener updates: qmask selects the speaker row.
        qmask_ = qmask.unsqueeze(2)
        q_ = ql_*(1-qmask_) + qs_*qmask_
        r_ = rl_*(1-qmask_) + rs_*qmask_

        ## intent ##
        i_q_ = torch.cat([x3, self._select_parties(q_, qm_idx)], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        is_ = self.i_cell(i_q_.contiguous().view(-1, self.D_s+self.D_p),
                          i0.view(-1, self.D_i)).view(U.size()[0], -1, self.D_i)
        il_ = i0
        i_ = il_*(1-qmask_) + is_*qmask_

        ## emotion ##
        es_ = torch.cat([U, self._select_parties(q_, qm_idx), self._select_parties(r_, qm_idx),
                         self._select_parties(i_, qm_idx)], dim=1)
        e0 = torch.zeros(qmask.size()[0], self.D_e).type(U.type()) if e0.size()[0] == 0 else e0

        if self.emo_gru:
            e_ = self.e_cell(es_, e0)
        else:
            e_ = es_

        g_ = self.dropout1(g_)
        q_ = self.dropout2(q_)
        r_ = self.dropout3(r_)
        i_ = self.dropout4(i_)
        e_ = self.dropout5(e_)

        return g_, q_, r_, i_, e_, alpha
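
# A minimal single-step sketch of the cell (every size below is an
# illustrative assumption, not a setting from the original experiments):
#
#   cell = CommonsenseRNNCell(D_m=100, D_s=10, D_g=150, D_p=150,
#                             D_r=150, D_i=150, D_e=100)
#   U = torch.randn(4, 100)                        # batch of 4 utterances
#   cs = [torch.randn(4, 10) for _ in range(5)]    # x1, x2, x3, o1, o2
#   qmask = F.one_hot(torch.randint(2, (4,)), 2).float()  # one-hot speaker
#   q0, r0, i0 = (torch.zeros(4, 2, 150) for _ in range(3))
#   g_, q_, r_, i_, e_, alpha = cell(U, *cs, qmask, torch.zeros(0),
#                                    q0, r0, i0, torch.zeros(0))
#
# At t=0 the empty g_hist makes the context zero and alpha is None.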
class CommonsenseRNN(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                 context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNN, self).__init__()

        self.D_m = D_m
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e
        self.dropout = nn.Dropout(dropout)

        self.dialogue_cell = CommonsenseRNNCell(D_m, D_s, D_g, D_p, D_r, D_i, D_e,
                                                listener_state, context_attention, D_a, dropout, emo_gru)

    def forward(self, U, x1, x2, x3, o1, o2, qmask):
        """
        U -> seq_len, batch, D_m
        x1, x2, x3, o1, o2 -> seq_len, batch, D_s
        qmask -> seq_len, batch, party
        """
        g_hist = torch.zeros(0).type(U.type())  # empty tensor; grows to (t, batch, D_g)
        q_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_p).type(U.type())  # batch, party, D_p
        r_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_r).type(U.type())  # batch, party, D_r
        i_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_i).type(U.type())  # batch, party, D_i
        e_ = torch.zeros(0).type(U.type())  # empty tensor; becomes (batch, D_e) after the first step
        e = e_

        alpha = []
        # Unroll the cell over the utterance sequence, threading the latent
        # states and the growing global-state history through every step.
        for u_, x1_, x2_, x3_, o1_, o2_, qmask_ in zip(U, x1, x2, x3, o1, o2, qmask):
            g_, q_, r_, i_, e_, alpha_ = self.dialogue_cell(u_, x1_, x2_, x3_, o1_, o2_,
                                                            qmask_, g_hist, q_, r_, i_, e_)
            g_hist = torch.cat([g_hist, g_.unsqueeze(0)], 0)
            e = torch.cat([e, e_.unsqueeze(0)], 0)
            if alpha_ is not None:
                alpha.append(alpha_[:, 0, :])

        return e, alpha  # e -> seq_len, batch, D_e
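
# A minimal dialogue-level sketch (sizes illustrative, matching the cell
# example above):
#
#   rnn = CommonsenseRNN(D_m=100, D_s=10, D_g=150, D_p=150, D_r=150,
#                        D_i=150, D_e=100)
#   U = torch.randn(6, 4, 100)                      # seq_len=6, batch=4
#   cs = [torch.randn(6, 4, 10) for _ in range(5)]  # x1, x2, x3, o1, o2
#   qmask = F.one_hot(torch.randint(2, (6, 4)), 2).float()
#   e, alpha = rnn(U, *cs, qmask)                   # e: (6, 4, 100)
#
# CommonsenseGRUModel below runs two such RNNs, one over the original order
# and one over a mask-aware reversed copy, for a bidirectional encoding.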
class CommonsenseGRUModel(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, D_a=100, n_classes=7, listener_state=False,
                 context_attention='simple', dropout_rec=0.5, dropout=0.1, emo_gru=True, mode1=0, norm=0, residual=False):
        super(CommonsenseGRUModel, self).__init__()

        # mode1 controls how the four input feature streams r1..r4 are
        # combined, which fixes the input width of linear_in.
        if mode1 == 0:
            D_x = 4 * D_m
        elif mode1 == 1:
            D_x = 2 * D_m
        else:
            D_x = D_m

        self.mode1 = mode1
        self.norm_strategy = norm
        self.linear_in = nn.Linear(D_x, D_h)
        self.residual = residual

        self.r_weights = nn.Parameter(torch.tensor([0.25, 0.25, 0.25, 0.25]))

        norm_train = True
        self.norm1a = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1b = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1c = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1d = nn.LayerNorm(D_m, elementwise_affine=norm_train)

        self.norm3a = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3b = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3c = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3d = nn.BatchNorm1d(D_m, affine=norm_train)

        self.dropout = nn.Dropout(dropout)
        self.dropout_rec = nn.Dropout(dropout_rec)

        # Forward and reverse dialogue RNNs for a bidirectional encoding.
        self.cs_rnn_f = CommonsenseRNN(D_h, D_s, D_g, D_p, D_r, D_i, D_e, listener_state,
                                       context_attention, D_a, dropout_rec, emo_gru)
        self.cs_rnn_r = CommonsenseRNN(D_h, D_s, D_g, D_p, D_r, D_i, D_e, listener_state,
                                       context_attention, D_a, dropout_rec, emo_gru)
        self.sense_gru = nn.GRU(input_size=D_s, hidden_size=D_s//2, num_layers=1, bidirectional=True)
        self.matchatt = MatchingAttention(2*D_e, 2*D_e, att_type='general2')
        self.linear = nn.Linear(2*D_e, D_h)
        self.smax_fc = nn.Linear(D_h, n_classes)

    def _reverse_seq(self, X, mask):
        """
        Reverse each sequence in the batch up to its true (unpadded) length.
        X -> seq_len, batch, dim
        mask -> batch, seq_len
        """
        X_ = X.transpose(0, 1)
        mask_sum = torch.sum(mask, 1).int()

        xfs = []
        for x, c in zip(X_, mask_sum):
            xf = torch.flip(x[:c], [0])
            xfs.append(xf)

        return pad_sequence(xfs)
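    # Example (illustrative): with umask rows [1, 1, 1] and [1, 1, 0], a
    # batch column [a, b, c] reverses to [c, b, a], while [d, e, pad]
    # reverses to [e, d, pad]; padding never leaks into the reversed prefix,
    # and pad_sequence restores the (seq_len, batch, dim) layout.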
    def forward(self, r1, r2, r3, r4, x1, x2, x3, o1, o2, qmask, umask, att2=False, return_hidden=False):
        """
        r1, r2, r3, r4 -> seq_len, batch, D_m
        x1, x2, x3, o1, o2 -> seq_len, batch, D_s
        qmask -> seq_len, batch, party
        umask -> batch, seq_len
        """
        seq_len, batch, feature_dim = r1.size()

        if self.norm_strategy == 1:
            r1 = self.norm1a(r1.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r2 = self.norm1b(r2.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r3 = self.norm1c(r3.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r4 = self.norm1d(r4.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
        elif self.norm_strategy == 2:
            norm2 = nn.LayerNorm((seq_len, feature_dim), elementwise_affine=False)
            r1 = norm2(r1.transpose(0, 1)).transpose(0, 1)
            r2 = norm2(r2.transpose(0, 1)).transpose(0, 1)
            r3 = norm2(r3.transpose(0, 1)).transpose(0, 1)
            r4 = norm2(r4.transpose(0, 1)).transpose(0, 1)
        elif self.norm_strategy == 3:
            r1 = self.norm3a(r1.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r2 = self.norm3b(r2.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r3 = self.norm3c(r3.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r4 = self.norm3d(r4.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)

        # Combine the four feature streams according to mode1.
        if self.mode1 == 0:
            r = torch.cat([r1, r2, r3, r4], axis=-1)
        elif self.mode1 == 1:
            r = torch.cat([r1, r2], axis=-1)
        elif self.mode1 == 2:
            r = (r1 + r2 + r3 + r4) / 4
        elif self.mode1 == 3:
            r = r1
        elif self.mode1 == 4:
            r = r2
        elif self.mode1 == 5:
            r = r3
        elif self.mode1 == 6:
            r = r4
        elif self.mode1 == 7:
            r = self.r_weights[0]*r1 + self.r_weights[1]*r2 + self.r_weights[2]*r3 + self.r_weights[3]*r4

        r = self.linear_in(r)

        # Forward pass over the dialogue, plus a bidirectional GRU over the
        # commonsense features themselves.
        emotions_f, alpha_f = self.cs_rnn_f(r, x1, x2, x3, o1, o2, qmask)
        out_sense, _ = self.sense_gru(x1)

        # Reverse pass: flip every stream up to its true length, run the
        # second dialogue RNN, then flip its output back.
        rev_r = self._reverse_seq(r, umask)
        rev_x1 = self._reverse_seq(x1, umask)
        rev_x2 = self._reverse_seq(x2, umask)
        rev_x3 = self._reverse_seq(x3, umask)
        rev_o1 = self._reverse_seq(o1, umask)
        rev_o2 = self._reverse_seq(o2, umask)
        rev_qmask = self._reverse_seq(qmask, umask)
        emotions_b, alpha_b = self.cs_rnn_r(rev_r, rev_x1, rev_x2, rev_x3, rev_o1, rev_o2, rev_qmask)
        emotions_b = self._reverse_seq(emotions_b, umask)

        emotions = torch.cat([emotions_f, emotions_b], dim=-1)
        emotions = self.dropout_rec(emotions)

        alpha, alpha_f, alpha_b = [], [], []

        if att2:
            att_emotions = []
            alpha = []
            for t in emotions:
                att_em, alpha_ = self.matchatt(emotions, t, mask=umask)
                att_emotions.append(att_em.unsqueeze(0))
                alpha.append(alpha_[:, 0, :])
            att_emotions = torch.cat(att_emotions, dim=0)
            hidden = F.relu(self.linear(att_emotions))
        else:
            hidden = F.relu(self.linear(emotions))

        hidden = self.dropout(hidden)

        if self.residual:
            hidden = hidden + r

        log_prob = F.log_softmax(self.smax_fc(hidden), 2)

        if return_hidden:
            return hidden, alpha, alpha_f, alpha_b, emotions
        return log_prob, out_sense, alpha, alpha_f, alpha_b, emotions
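
if __name__ == '__main__':
    # Minimal smoke test with random inputs. Every dimension below is an
    # illustrative assumption, not a setting from the original experiments;
    # it also assumes the companion model.py (SimpleAttention etc.) is on
    # the import path.
    seq_len, batch, n_parties = 5, 2, 2
    D_m, D_s = 100, 10
    D_g = D_p = D_r = D_i = 150
    D_e = D_h = 100

    model = CommonsenseGRUModel(D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, n_classes=7)

    r1, r2, r3, r4 = (torch.randn(seq_len, batch, D_m) for _ in range(4))
    x1, x2, x3, o1, o2 = (torch.randn(seq_len, batch, D_s) for _ in range(5))
    qmask = F.one_hot(torch.randint(n_parties, (seq_len, batch)), n_parties).float()
    umask = torch.ones(batch, seq_len)

    log_prob, out_sense, alpha, alpha_f, alpha_b, emotions = model(
        r1, r2, r3, r4, x1, x2, x3, o1, o2, qmask, umask)
    print(log_prob.size())   # expected: (seq_len, batch, 7)
    print(emotions.size())   # expected: (seq_len, batch, 2 * D_e)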