import torch
import torch.nn as nn
|
|
class Seq2Seq(nn.Module):
    """
    Build a Sequence-to-Sequence model.

    Parameters:

    * `encoder`- encoder of the seq2seq model, e.g. RoBERTa.
    * `decoder`- decoder of the seq2seq model, e.g. a transformer decoder.
    * `config`- configuration of the encoder model.
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of the target for beam search.
    * `sos_id`- start-of-sequence token id of the target for beam search.
    * `eos_id`- end-of-sequence token id of the target for beam search.
    """
|
    def __init__(self, encoder, decoder, config, mse_loss_weight=0.95,
                 ce_loss_weight=0.05, beam_size=None, max_length=None,
                 sos_id=None, eos_id=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular (causal) attention mask, precomputed for up to
        # 1024 positions.
        self.register_buffer(
            "bias", torch.tril(torch.ones(
                (1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024)
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Language-model head, weight-tied to the encoder's input embeddings.
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.encoder.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)

        # Scalar prediction head applied to the first decoder position.
        self.pred_dense = nn.Linear(config.hidden_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

        self.mse_loss_weight = mse_loss_weight
        self.ce_loss_weight = ce_loss_weight

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id
|
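    # Masking scheme used below: source tokens attend bidirectionally via a
    # square padding mask (pad id is assumed to be 1 throughout this file),
    # while target tokens use rows [src_len, src_len + tgt_len) of the causal
    # `bias` buffer, so each target position sees the whole source plus
    # earlier target positions only.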
    def forward(self, source_ids, exist=None, target_ids=None):
        if target_ids is None or exist is None:
            return self.generate(source_ids)

        mask = source_ids.ne(1)[:, None, :] * source_ids.ne(1)[:, :, None]
        encoder_output = self.encoder(
            source_ids, attention_mask=mask, use_cache=True)
        ids = torch.cat((source_ids, target_ids), -1)

        mask = self.bias[:,
                         source_ids.size(-1):ids.size(-1), :ids.size(-1)].bool()
        mask = mask & ids[:, None, :].ne(1)

        out = self.decoder(target_ids, attention_mask=mask,
                           past_key_values=encoder_output.past_key_values).last_hidden_state

        lm_logits = self.lm_head(out[..., 1:, :])

        # Shift so that positions < n predict token n; ignore padding.
        active_loss = target_ids[..., 2:].ne(1).view(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 2:].contiguous()

        # Scalar prediction taken from the first decoder position.
        exist_labels = exist.contiguous()
        pred_out = out[..., 0, :]
        pred_sigmoid = self.sigmoid(self.pred_dense(pred_out))

        loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
        loss_fct_pred = nn.MSELoss(reduction="mean")
        loss_code = loss_fct_code(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                                  shift_labels.view(-1)[active_loss])

        loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
        # Weighted combination of the regression and generation losses.
        loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight

        outputs = loss, loss * active_loss.sum(), active_loss.sum(), loss_pred, loss_code
        return outputs
|
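    # Beam search decodes one example at a time: `preds` has shape
    # (batch, beam_size, max_length) with each hypothesis zero-padded, and
    # `predicates` holds the sigmoid score read at the first decoding step.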
|
    def generate(self, source_ids):
        device = source_ids.device
        mask = source_ids.ne(1)[:, None, :] * source_ids.ne(1)[:, :, None]
        encoder_output = self.encoder(
            source_ids, attention_mask=mask, use_cache=True)
        preds = []
        predicates = []
        zero = torch.zeros(1, dtype=torch.long, device=device)
        source_len = list(source_ids.ne(1).sum(-1).cpu().numpy())
        for i in range(source_ids.shape[0]):
            # Expand the cached source keys/values to one copy per beam.
            context = [[x[i:i+1, :, :source_len[i]].repeat(self.beam_size, 1, 1, 1) for x in y]
                       for y in encoder_output.past_key_values]
            beam = Beam(self.beam_size, self.sos_id, self.eos_id, device=device)
            input_ids = beam.getCurrentState()
            context_ids = source_ids[i:i+1,
                                     :source_len[i]].repeat(self.beam_size, 1)
            predicate = []
            for _ in range(self.max_length):
                if beam.done():
                    break

                ids = torch.cat((context_ids, input_ids), -1)
                mask = self.bias[:,
                                 context_ids.size(-1):ids.size(-1), :ids.size(-1)].bool()
                mask = mask & ids[:, None, :].ne(1)
                out = self.decoder(input_ids, attention_mask=mask,
                                   past_key_values=context).last_hidden_state
                hidden_states = out[:, -1, :]
                # The scalar prediction is read once, at the first decoding step.
                if out.size(1) == 1:
                    pred_sigmoid = self.sigmoid(self.pred_dense(
                        hidden_states.view(-1, 1, hidden_states.size(-1))))
                    predicate.append(
                        pred_sigmoid.view(-1, pred_sigmoid.size(-1)))

                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                input_ids.data.copy_(input_ids.data.index_select(
                    0, beam.getCurrentOrigin()))
                input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            # Zero-pad every hypothesis to max_length.
            pred = [torch.cat([x.view(-1) for x in p] + [zero] *
                              (self.max_length-len(p))).view(1, -1) for p in pred]
            predicates.append(predicate[0][0])
            preds.append(torch.cat(pred, 0).unsqueeze(0))
        preds = torch.cat(preds, 0)
        predicates = torch.stack(predicates)
        return preds, predicates
|
|
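# Beam keeps, per timestep, backpointers (`prevKs`), chosen token ids
# (`nextYs`), and cumulative log-probabilities (`scores`); finished
# hypotheses are (score, timestep, beam index) tuples.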
class Beam(object):
    def __init__(self, size, sos, eos, device=None):
        self.size = size
        # Default to CUDA when available, matching the original CUDA-only code.
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu")
        # The score for each hypothesis on the beam.
        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)
        # The backpointers at each timestep.
        self.prevKs = []
        # The outputs at each timestep.
        self.nextYs = [torch.zeros(size, dtype=torch.long, device=self.device)]
        self.nextYs[0][0] = sos

        self._eos = eos
        self.eosTop = False

        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        # Clone so that in-place reordering of input_ids cannot corrupt nextYs.
        batch = self.nextYs[-1].clone().view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]
|
    def advance(self, wordLk):
        """
        Given `wordLk`, the log-probabilities of advancing from the last
        step for every beam hypothesis (K x words), compute and update
        the beam search state. Completion is checked via `done()`.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let finished (EOS) hypotheses have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened over a (K x words) array, so recover the
        # originating beam and the chosen word from it.
        prevK = torch.div(bestScoresId, numWords, rounding_mode="floor")
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition: EOS is at the top of the beam.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size
|
    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            # Top up with the best unfinished hypotheses.
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]
|
    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps
|
    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
|
|
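if __name__ == "__main__":
    # Minimal CPU smoke-test sketch, not part of the original pipeline. It
    # assumes a Hugging Face RoBERTa model used as both encoder and decoder
    # (`is_decoder=True`, as in UniXcoder-style setups) and the legacy tuple
    # `past_key_values` cache format; all sizes are illustrative, and the
    # token ids follow RoBERTa's convention (bos 0, pad 1, eos 2).
    from transformers import RobertaConfig, RobertaModel

    config = RobertaConfig(vocab_size=1000, hidden_size=64,
                           num_hidden_layers=2, num_attention_heads=4,
                           intermediate_size=128, is_decoder=True)
    encoder = RobertaModel(config)
    model = Seq2Seq(encoder=encoder, decoder=encoder, config=config,
                    beam_size=2, max_length=8, sos_id=0, eos_id=2)

    source_ids = torch.randint(2, 1000, (2, 16))  # ids >= 2, so never pad
    target_ids = torch.randint(2, 1000, (2, 8))
    exist = torch.rand(2, 1)

    # Training-style call: returns (loss, loss*n, n, loss_pred, loss_code).
    loss, _, _, loss_pred, loss_code = model(source_ids, exist, target_ids)
    print(loss.item(), loss_pred.item(), loss_code.item())

    # Inference-style call: beam search plus the scalar predictions.
    with torch.no_grad():
        preds, predicates = model(source_ids)
    print(preds.shape, predicates.shape)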