import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import (
    BertForSequenceClassification as SeqClassification,
    BertPreTrainedModel,
    BertModel,
    BertConfig,
)

from .modeling_outputs import (
    QuestionAnsweringModelOutput,
    QuestionAnsweringNaModelOutput,
)


class BertForSequenceClassification(SeqClassification):
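    """Thin wrapper around the Hugging Face ``BertForSequenceClassification`` head,
    tagged with a ``model_type`` attribute like the QA model below."""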
    model_type = "bert"


class BertForQuestionAnsweringAVPool(BertPreTrainedModel):
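    """BERT for extractive question answering with an answer-verification head.

    In addition to the usual start/end span classifier (``qa_outputs``), a
    second head (``has_ans``) reads the [CLS] representation and predicts
    whether the question is answerable (Internal Front Verification, I-FV).
    """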
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    model_type = "bert"
    
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        
        # The QA heads below do not use the BERT pooler, so skip building it;
        # pooler weights in pretrained checkpoints are ignored (see above).
        self.bert = BertModel(config, add_pooling_layer=False)
        # Span head: start/end logits for every token.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        # Answer-verification head: answerable vs. unanswerable from the [CLS] token.
        self.has_ans = nn.Sequential(
            nn.Dropout(p=config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, 2),
        )
        
        # Initialize weights and apply final processing
        self.post_init()
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        is_impossibles=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
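        r"""
        start_positions / end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Token indices of the labeled answer span start/end. Positions outside
            the sequence are clamped to ``ignored_index`` and excluded from the loss.
        is_impossibles (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the answer-verification head; by the usual SQuAD 2.0
            convention this is assumed to be 1 for unanswerable questions and 0
            otherwise.
        """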
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        sequence_output = outputs[0]
        
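        # Project each token representation to a pair of (start, end) logits.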
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        
        # [CLS] token representation, used by the answer-verification head.
        first_word = sequence_output[:, 0, :]

        has_logits = self.has_ans(first_word)
        
        total_loss = None
        if (
            start_positions is not None and 
            end_positions is not None and 
            is_impossibles is not None
        ):
            # If we are on multi-GPU, split adds an extra dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            if len(is_impossibles.size()) > 1:
                is_impossibles = is_impossibles.squeeze(-1)
            # Start/end positions outside the model inputs are clamped to ignored_index so they do not contribute to the loss
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
            is_impossibles.clamp_(0, ignored_index)
            
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            span_loss = start_loss + end_loss
            
            # Internal Front Verification (I-FV): classify answerability from the
            # [CLS] representation and add it to the span loss with weights
            # alpha1 == 1.0 (span) and alpha2 == 0.5 (verification).
            choice_loss = loss_fct(has_logits, is_impossibles.long())
            total_loss = 1.0 * span_loss + 0.5 * choice_loss
        
        if not return_dict:
            output = (
                start_logits,
                end_logits,
                has_logits,
            ) + outputs[2:]  # hidden_states, attentions
            return ((total_loss,) + output) if total_loss is not None else output
        
        return QuestionAnsweringNaModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            has_logits=has_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
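

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model code. It assumes a generic
    # "bert-base-uncased" checkpoint and the standard Hugging Face tokenizer;
    # with a plain BERT checkpoint the QA and verification heads are freshly
    # initialized, so load a fine-tuned checkpoint for meaningful predictions.
    # Because this module uses a relative import, run it from inside its
    # package (e.g. `python -m <package>.<module>`).
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnsweringAVPool.from_pretrained("bert-base-uncased")
    model.eval()

    question = "Who wrote Hamlet?"
    context = "Hamlet is a tragedy written by William Shakespeare."
    inputs = tokenizer(question, context, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Decode the most likely answer span from the start/end logits.
    start = outputs.start_logits.argmax(dim=-1).item()
    end = outputs.end_logits.argmax(dim=-1).item()
    answer = tokenizer.decode(inputs["input_ids"][0][start : end + 1])

    # has_logits scores answerability; under the SQuAD 2.0 convention assumed
    # above, index 1 corresponds to "no answer".
    print(answer)
    print(outputs.has_logits.softmax(dim=-1))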