VEGA_AE / Scripts /UnixCoder /model_gen.py
unknown
Initial
ac3312e
raw
history blame
9.87 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch
from torch.autograd import Variable
import copy
import numpy as np
class Seq2Seq(nn.Module):
    """
    Build Sequence-to-Sequence with an auxiliary confidence head.

    Target tokens are scored by `lm_head`, whose weight is tied to the encoder's
    word embeddings. A separate `pred_dense` + sigmoid head reads the decoder
    state at the first target position and is trained with an MSE loss against
    the `exist` labels. Token id 1 is treated as padding throughout.

    Parameters:

    * `encoder`- encoder of seq2seq model. e.g. roberta. Must expose
      `embeddings.word_embeddings` and return `past_key_values` when called
      with `use_cache=True`.
    * `decoder`- decoder of seq2seq model. e.g. transformer. Must return an
      object with `last_hidden_state`.
    * `config`- configuration of encoder model (`hidden_size`, `vocab_size`).
    * `mse_loss_weight`- weight of the confidence-head MSE loss in the total loss.
    * `ce_loss_weight`- weight of the token cross-entropy loss in the total loss.
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start of symbol ids in target for beam search.
    * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(self, encoder, decoder, config, mse_loss_weight=0.95, ce_loss_weight=0.05, beam_size=None, max_length=None, sos_id=None, eos_id=None, ):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular (causal) mask; the combined source + target length
        # is presumably limited to 1024 positions by this buffer.
        self.register_buffer(
            "bias", torch.tril(torch.ones(
                (1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024)
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        # Tie the output projection to the encoder's input embeddings.
        self.lm_head.weight = self.encoder.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)
        # Confidence head: hidden state -> one value squashed into (0, 1).
        self.pred_dense = nn.Linear(config.hidden_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.mse_loss_weight = mse_loss_weight
        self.ce_loss_weight = ce_loss_weight
        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def forward(self, source_ids, exist=None, target_ids=None):
        """
        Compute the joint training loss, or decode when labels are missing.

        Parameters:

        * `source_ids`- input token ids; id 1 is padding.
        * `exist`- regression target for the confidence head, one value per
          example (NOTE(review): exact shape depends on the caller).
        * `target_ids`- target token ids; id 1 is padding.

        Returns: `self.generate(source_ids)` when `exist` or `target_ids` is
        None; otherwise the tuple
        `(loss, loss * n_active, n_active, loss_pred, loss_code)` where
        `n_active` counts the non-padding target tokens that were scored.
        """
        if target_ids is None or exist is None:
            return self.generate(source_ids)
        # Source tokens attend to every non-padding source token.
        mask = source_ids.ne(1)[:, None, :] * source_ids.ne(1)[:, :, None]
        encoder_output = self.encoder(
            source_ids, attention_mask=mask, use_cache=True)
        # Target rows of the causal mask: each target position may attend to
        # the whole source prefix and to earlier target positions.
        ids = torch.cat((source_ids, target_ids), -1)
        mask = self.bias[:,
                         source_ids.size(-1):ids.size(-1), :ids.size(-1)].bool()
        mask = mask & ids[:, None, :].ne(1)
        out = self.decoder(target_ids, attention_mask=mask,
                           past_key_values=encoder_output.past_key_values).last_hidden_state
        # Token head skips target position 0 (used by the confidence head);
        # the hidden state at position t then predicts target token t + 1.
        lm_logits = self.lm_head(out[..., 1:, :])
        active_loss = target_ids[..., 2:].ne(1).view(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 2:].contiguous()
        # Confidence head on the first target position.
        pred_sigmoid = self.sigmoid(self.pred_dense(out[..., 0, :]))
        exist_labels = exist.contiguous().to(pred_sigmoid.dtype)
        # pred_sigmoid ends in a size-1 dim (pred_dense has one output); labels
        # without that dim would broadcast into an all-pairs MSE, so align
        # shapes whenever there is exactly one label per prediction.
        if exist_labels.numel() == pred_sigmoid.numel():
            exist_labels = exist_labels.view_as(pred_sigmoid)
        loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
        loss_fct_pred = nn.MSELoss(reduction="mean")
        # Padding targets are dropped via the boolean `active_loss` selection.
        loss_code = loss_fct_code(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                                  shift_labels.view(-1)[active_loss])
        loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
        loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight
        n_active = active_loss.sum()
        return loss, loss * n_active, n_active, loss_pred, loss_code

    def generate(self, source_ids):
        """
        Beam-search decode every source sequence independently.

        Parameters:

        * `source_ids`- input token ids; id 1 is padding.

        Returns: `(preds, predicates)` where `preds` stacks, per example, the
        n-best token sequences right-padded with 0 to `max_length`, and
        `predicates` holds, per example, the confidence-head output read from
        beam 0 at the first decoding step.
        """
        mask = source_ids.ne(1)[:, None, :] * source_ids.ne(1)[:, :, None]
        encoder_output = self.encoder(
            source_ids, attention_mask=mask, use_cache=True)
        preds = []
        predicates = []
        # Padding token for short hypotheses, allocated beside the inputs
        # rather than on the default CUDA device.
        zero = torch.zeros(1, dtype=torch.long, device=source_ids.device)
        source_len = source_ids.ne(1).sum(-1).tolist()
        for i in range(source_ids.shape[0]):
            # Cached keys/values of example i, trimmed to its real length and
            # repeated once per beam.
            context = [[x[i:i + 1, :, :source_len[i]].repeat(self.beam_size, 1, 1, 1) for x in y]
                       for y in encoder_output.past_key_values]
            beam = Beam(self.beam_size, self.sos_id, self.eos_id)
            input_ids = beam.getCurrentState()
            context_ids = source_ids[i:i + 1,
                                     :source_len[i]].repeat(self.beam_size, 1)
            predicate = []
            for _ in range(self.max_length):
                if beam.done():
                    break
                ids = torch.cat((context_ids, input_ids), -1)
                mask = self.bias[:,
                                 context_ids.size(-1):ids.size(-1), :ids.size(-1)].bool()
                mask = mask & ids[:, None, :].ne(1)
                out = self.decoder(input_ids, attention_mask=mask,
                                   past_key_values=context).last_hidden_state
                hidden_states = out[:, -1, :]
                # input_ids grows by one token per step, so this only fires on
                # the first step: score the confidence head there.
                if out.size(1) == 1:
                    pred_sigmoid = self.sigmoid(self.pred_dense(
                        hidden_states.view(-1, 1, hidden_states.size(-1))))
                    predicate.append(
                        pred_sigmoid.view(-1, pred_sigmoid.size(-1)))
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                # Reorder the running prefixes to follow the surviving beams.
                input_ids.data.copy_(input_ids.data.index_select(
                    0, beam.getCurrentOrigin()))
                input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero] *
                              (self.max_length - len(p))).view(1, -1) for p in pred]
            # Confidence of beam 0 at the first step.
            predicates.append(predicate[0][0])
            preds.append(torch.cat(pred, 0).unsqueeze(0))
        preds = torch.cat(preds, 0)
        # Concatenate on-device instead of copying each score through Python;
        # detach so the result carries no graph, as the former copy did.
        predicates = torch.cat(predicates).detach()
        return preds, predicates
class Beam(object):
    """
    Book-keeping for beam search over a single source sequence.

    Tracks the running score of every beam, the back-pointers and tokens
    chosen at each step, and the hypotheses that have emitted `eos`.

    Parameters:

    * `size`- number of beams.
    * `sos`- token id placed in beam 0 at step 0 (other beams start at 0).
    * `eos`- token id that terminates a hypothesis.
    * `device`- device holding the beam tensors; defaults to "cuda", the
      device previously hard-coded via `torch.cuda` tensor types.
    """

    def __init__(self, size, sos, eos, device="cuda"):
        self.size = size
        self.device = torch.device(device)
        # Legacy tensor-type namespace, kept for callers that read `beam.tt`.
        self.tt = torch.cuda if self.device.type == "cuda" else torch
        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step; only beam 0 starts from `sos`.
        self.nextYs = [torch.zeros(size, dtype=torch.long, device=self.device)]
        self.nextYs[0][0] = sos
        self._eos = eos
        # Set once beam 0 (the top-scoring beam) has emitted EOS.
        self.eosTop = False
        # (score, timestep, beam index) of every hypothesis that hit EOS.
        self.finished = []

    def getCurrentState(self):
        """Return the latest token of every beam as a (size, 1) tensor.

        The result is a copy, so in-place edits by the caller cannot reach
        the beam history.
        """
        return self.nextYs[-1].clone().view(-1, 1)

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Extend every beam by one token and keep the `size` best continuations.

        Parameters:

        * `wordLk`- per-beam scores for every next token (size x words);
          the caller passes log-probabilities, which are summed with the
          running beam scores.
        """
        numWords = wordLk.size(1)
        if len(self.prevKs) > 0:
            # Sum the previous scores.
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # Don't let EOS have children.
            beamLk[self.nextYs[-1] == self._eos] = -1e20
        else:
            # First step: only beam 0 holds `sos`, so only its row is expanded.
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.scores = bestScores
        # bestScoresId indexes the flattened beam x word array; recover which
        # beam and which word each score came from.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)
        # Record every beam that just produced EOS as a finished hypothesis.
        timestep = len(self.nextYs) - 1
        for i in (self.nextYs[-1] == self._eos).nonzero().view(-1).tolist():
            self.finished.append((self.scores[i], timestep, i))
        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """True once beam 0 has hit EOS and at least `size` hypotheses finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """
        Return up to `size` (score, timestep, beam) triples, best first.

        Finished hypotheses come first; if there are fewer than `size` of
        them, the best live (non-EOS) beams of the last step fill the rest.
        Each live beam is added at most once.
        """
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) < self.size:
            timestep = len(self.nextYs) - 1
            live = (self.nextYs[-1] != self._eos).nonzero().view(-1).tolist()
            unfinished = [(self.scores[i], timestep, i) for i in live]
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.

        Follows the back-pointers from each (score, timestep, k) triple and
        returns one list of token tensors per triple, oldest token first.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(min(timestep, len(self.prevKs)) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each hypothesis just before its first EOS token."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence