import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

class BartCaptionModel(nn.Module):
    """Audio captioning model: an ``AudioEncoder`` front end feeding a BART encoder-decoder.

    Audio is turned into a sequence of embeddings by ``AudioEncoder``; those
    embeddings are passed to the BART encoder as ``inputs_embeds`` and the BART
    decoder is trained (``forward``) or sampled (``generate``) to produce text.

    Args:
        n_mels: Forwarded to ``AudioEncoder`` as ``n_mels``.
        num_of_conv: ``num_of_conv - 1`` is forwarded to ``AudioEncoder`` as
            ``num_of_stride_conv`` and used to derive ``n_ctx``.
        sr: Sample rate; with ``duration`` gives the number of input samples.
        duration: Clip length (presumably seconds, given ``sr * duration``).
        max_length: Maximum caption length in tokens used when tokenizing targets.
        label_smoothing: Label smoothing for the cross-entropy loss.
        bart_type: Hugging Face identifier used for the BART config and tokenizer.
        audio_dim: Forwarded to ``AudioEncoder`` as ``audio_dim``.
    """

    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
        super().__init__()
        # Non-finetuning case: only the config and tokenizer come from
        # `bart_type`; the BART model itself is constructed from the config
        # (not via `from_pretrained`).
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hop size is hard-coded to 1% of the sample rate
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # Frame count divided by 2 per strided conv, plus one.
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )

        self.max_length = max_length
        # -100 marks padded target positions (see `forward_decoder`).
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        """Device of the module's first parameter."""
        # `next` avoids materialising the full parameter list on every access.
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """Shift input ids one token to the right.

        The first column becomes ``decoder_start_token_id``, the last column of
        ``input_ids`` is dropped, and any ``-100`` left in the result is
        replaced by ``pad_token_id``. ``input_ids`` itself is not modified.

        Raises:
            ValueError: If ``pad_token_id`` is ``None``.
        """
        # Validate before doing any work.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def _encode_audio(self, audio):
        """Run the audio encoder and the BART encoder; return (encoder output object, audio embeddings)."""
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True,
        )
        return encoder_outputs, audio_embs

    def forward_encoder(self, audio):
        """Encode audio.

        Returns:
            Tuple ``(encoder_hidden_states, audio_embs)`` where the first item
            is the BART encoder's ``last_hidden_state`` and the second is the
            raw ``AudioEncoder`` output.
        """
        encoder_outputs, audio_embs = self._encode_audio(audio)
        return encoder_outputs["last_hidden_state"], audio_embs

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss for ``text`` given encoder hidden states.

        Args:
            text: Caption(s) passed straight to the tokenizer
                (presumably a list of strings, one per batch item).
            encoder_outputs: Encoder ``last_hidden_state`` as returned by
                ``forward_encoder``.

        Returns:
            Scalar label-smoothed cross-entropy loss over non-padding tokens.
        """
        tokenized = self.tokenizer(
            text,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].to(self.device)
        attention_mask = tokenized["attention_mask"].to(self.device)

        # Padding positions become -100 so the loss ignores them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True,
        )
        lm_logits = decoder_outputs["logits"]
        # Flatten using the logits' own last dimension: the tokenizer's
        # `vocab_size` is not guaranteed to match the model's output width.
        loss = self.loss_fct(
            lm_logits.reshape(-1, lm_logits.size(-1)),
            decoder_targets.reshape(-1),
        )
        return loss

    def forward(self, audio, text):
        """Return the captioning loss for a batch of ``audio`` and target ``text``."""
        encoder_outputs, _ = self.forward_encoder(audio)
        return self.forward_decoder(text, encoder_outputs)

    @torch.no_grad()
    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio.

        Args:
            samples: Audio batch passed to the audio encoder.
            use_nucleus_sampling: If true, sample with top-p; otherwise use beam search.
            num_beams: Beam width (beam-search branch only).
            max_length: Maximum generated length passed to ``bart.generate``.
            min_length: Minimum generated length passed to ``bart.generate``.
            top_p: Nucleus probability mass (sampling branch only).
            repetition_penalty: Repetition penalty (beam-search branch only;
                see the note in the sampling branch).

        Returns:
            List of decoded caption strings with special tokens removed.
        """
        # Gradients are never needed here: the result is decoded text.
        encoder_outputs, _ = self._encode_audio(samples)

        batch_size = encoder_outputs['last_hidden_state'].size(0)
        # Every sequence starts from the decoder start token.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.bart.config.decoder_start_token_id,
            dtype=torch.long,
            device=self.device,
        )
        decoder_attention_mask = torch.ones(
            (batch_size, 1), dtype=torch.long, device=self.device
        )

        shared_kwargs = dict(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            max_length=max_length,
            min_length=min_length,
        )
        if use_nucleus_sampling:
            # NOTE(review): this branch has always used a fixed repetition
            # penalty of 1.1 and ignores the `repetition_penalty` argument.
            # Kept as-is so existing sampling behaviour does not change;
            # confirm whether the argument should be honoured here.
            outputs = self.bart.generate(
                **shared_kwargs,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1,
            )
        else:
            outputs = self.bart.generate(
                **shared_kwargs,
                head_mask=None,
                decoder_head_mask=None,
                inputs_embeds=None,
                decoder_inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
            )

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions