# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    """
        Build Seqence-to-Sequence.

        Parameters:

        * `encoder`- encoder of seq2seq model. e.g. roberta
        * `decoder`- decoder of seq2seq model. e.g. transformer
        * `config`- configuration of encoder model. 
        * `beam_size`- beam size for beam search. 
        * `max_length`- max length of target for beam search. 
        * `sos_id`- start of symbol ids in target for beam search.
        * `eos_id`- end of symbol ids in target for beam search. 
    """

    def __init__(self, encoder, decoder, config, mse_loss_weight=0.95, ce_loss_weight=0.05, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        self.register_buffer(
            "bias", torch.tril(torch.ones(
                (1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024)
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.encoder.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)

        self.pred_dense = nn.Linear(config.hidden_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

        self.mse_loss_weight = mse_loss_weight
        self.ce_loss_weight = ce_loss_weight

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def forward(self, source_ids, exist=None, target_ids=None):
        if target_ids is None or exist is None:
            return self.generate(source_ids)

        # Source self-attention mask: non-padding tokens (pad id = 1) attend to each other.
        mask = source_ids.ne(1)[:, None, :] * source_ids.ne(1)[:, :, None]
        encoder_output = self.encoder(
            source_ids, attention_mask=mask, use_cache=True)
        ids = torch.cat((source_ids, target_ids), -1)

        # Decoder attention mask: each target position attends to the full source
        # and to earlier target positions (causal); padding (id 1) is masked out.
        mask = self.bias[:,
                         source_ids.size(-1):ids.size(-1), :ids.size(-1)].bool()
        mask = mask & ids[:, None, :].ne(1)

        out = self.decoder(target_ids, attention_mask=mask,
                           past_key_values=encoder_output.past_key_values).last_hidden_state

        # Position 0 of the decoder output feeds the existence head;
        # language-model logits are computed from positions 1 onwards.
        lm_logits = self.lm_head(out[..., 1:, :])
        # Shift so that tokens < n predict n.
        active_loss = target_ids[..., 2:].ne(1).view(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 2:].contiguous()

        # Existence prediction from the first decoder position.
        exist_labels = exist.contiguous()
        pred_out = out[..., 0, :]
        pred_sigmoid = self.sigmoid(self.pred_dense(pred_out))
        # Flatten the tokens
        loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
        loss_fct_pred = nn.MSELoss(reduction="mean")
        loss_code = loss_fct_code(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                                  shift_labels.view(-1)[active_loss])

        loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
        # Weighted mixture of the existence (MSE) and token (CE) losses.
        loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight

        outputs = loss, loss*active_loss.sum(), active_loss.sum(), loss_pred, loss_code
        return outputs

    def generate(self, source_ids):
        mask = source_ids.ne(1)[:, None, :] * source_ids.ne(1)[:, :, None]
        encoder_output = self.encoder(
            source_ids, attention_mask=mask, use_cache=True)
        preds = []
        predicates = []
        # Zero token used to right-pad hypotheses to max_length (CUDA assumed).
        zero = torch.cuda.LongTensor(1).fill_(0)
        source_len = list(source_ids.ne(1).sum(-1).cpu().numpy())
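        # Run beam search independently for each example in the batch, reusing
        # the cached encoder key/values expanded to beam_size.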
        for i in range(source_ids.shape[0]):
            context = [[x[i:i+1, :, :source_len[i]].repeat(self.beam_size, 1, 1, 1) for x in y]
                       for y in encoder_output.past_key_values]
            beam = Beam(self.beam_size, self.sos_id, self.eos_id)
            input_ids = beam.getCurrentState()
            context_ids = source_ids[i:i+1,
                                     :source_len[i]].repeat(self.beam_size, 1)
            predicate = []
            for _ in range(self.max_length):
                if beam.done():
                    break

                ids = torch.cat((context_ids, input_ids), -1)
                mask = self.bias[:,
                                 context_ids.size(-1):ids.size(-1), :ids.size(-1)].bool()
                mask = mask & ids[:, None, :].ne(1)
                out = self.decoder(input_ids, attention_mask=mask,
                                   past_key_values=context).last_hidden_state
                hidden_states = out[:, -1, :]
                # At the first decoding step, also read the existence prediction
                # from the decoder output.
                if out.size(1) == 1:
                    pred_sigmoid = self.sigmoid(self.pred_dense(
                        hidden_states.view(-1, 1, hidden_states.size(-1))))
                    predicate.append(
                        pred_sigmoid.view(-1, pred_sigmoid.size(-1)))
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                input_ids.data.copy_(input_ids.data.index_select(
                    0, beam.getCurrentOrigin()))
                input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero] *
                              (self.max_length-len(p))).view(1, -1) for p in pred]
            predicates.append(predicate[0][0])  # existence prediction from the first decoding step
            preds.append(torch.cat(pred, 0).unsqueeze(0))
        preds = torch.cat(preds, 0)
        predicates = torch.tensor(predicates, device="cuda")
        return preds, predicates


class Beam(object):
    """
        Beam-search helper (CUDA tensors assumed).

        Parameters:

        * `size`- beam size.
        * `sos`- start-of-sequence token id.
        * `eos`- end-of-sequence token id.
    """

    def __init__(self, size, sos, eos):
        self.size = size
        self.tt = torch.cuda
        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [self.tt.LongTensor(size)
                       .fill_(0)]
        self.nextYs[0][0] = sos
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given prob over words for every last beam `wordLk` and attention
        `attnOut`: Compute and update the beam search.

        Parameters:

        * `wordLk`- probs of advancing from the last step (K x words)
        * `attnOut`- attention at the last step

        Returns: True if beam search is complete.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size-len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j+1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
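

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training pipeline).
# Assumptions: the encoder/decoder is a RoBERTa-style HuggingFace model built
# with `is_decoder=True` so that `use_cache`/`past_key_values` work, and the
# same module is reused as encoder and decoder; the real setup is defined in
# the accompanying training script. Padding id is assumed to be 1, matching
# the masks above. Token ids, loss weights, and sizes below are toy values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import RobertaConfig, RobertaModel

    config = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                           num_attention_heads=4, intermediate_size=64,
                           is_decoder=True)
    shared = RobertaModel(config)
    model = Seq2Seq(encoder=shared, decoder=shared, config=config,
                    beam_size=2, max_length=8, sos_id=0, eos_id=2)

    source_ids = torch.randint(3, 100, (2, 16))   # toy source tokens (no padding)
    target_ids = torch.randint(3, 100, (2, 8))    # toy target tokens
    exist = torch.rand(2, 1)                      # toy existence labels in [0, 1]

    loss, _, _, loss_pred, loss_code = model(source_ids, exist=exist, target_ids=target_ids)
    print(loss.item(), loss_pred.item(), loss_code.item())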