# -*- coding: utf-8 -*-
# @Time    : 2022/3/15 21:26
# @Author  : ruihan.wjn
# @File    : pk-plm.py

"""
This code is implemented for the paper "Knowledge Prompting in Pre-trained Language Models for Natural Language Understanding"
"""

from time import time
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from collections import OrderedDict
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.models.roberta import RobertaModel, RobertaPreTrainedModel, RobertaTokenizer, RobertaForMaskedLM
from transformers.models.deberta import DebertaModel, DebertaPreTrainedModel, DebertaTokenizer, DebertaForMaskedLM
from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertPreTrainingHeads
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead
from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaLMPredictionHead

"""
kg enhanced corpus structure example:
{
    "token_ids": [20, 46098, 3277, 680, 10, 4066, 278, 9, 11129, 4063, 877, 579, 8, 8750, 14720, 8, 22498, 548,
    19231, 46098, 3277, 6, 25, 157, 25, 130, 3753, 46098, 3277, 4, 3684, 19809, 10960, 9, 5, 30731, 2788, 914, 5,
    1675, 8151, 35], "entity_pos": [[8, 11], [13, 15], [26, 27]],
    "entity_qid": ["Q17582", "Q231978", "Q427013"],
    "relation_pos": null,
    "relation_pid": null
}
"""


from enum import Enum
class SiameseDistanceMetric(Enum):
    """
    Namespace of distance functions usable by the contrastive loss.

    Every entry is a plain two-argument callable. Functions are descriptors,
    so ``Enum`` treats them as methods rather than members: accessing e.g.
    ``SiameseDistanceMetric.EUCLIDEAN`` yields the function itself, which can
    be called directly, and the enumeration has no members to iterate over.
    """

    def EUCLIDEAN(x, y):
        """Pairwise distance between ``x`` and ``y`` with norm degree p=2."""
        return F.pairwise_distance(x, y, p=2)

    def MANHATTAN(x, y):
        """Pairwise distance between ``x`` and ``y`` with norm degree p=1."""
        return F.pairwise_distance(x, y, p=1)

    def COSINE_DISTANCE(x, y):
        """One minus the cosine similarity of ``x`` and ``y``."""
        return 1 - F.cosine_similarity(x, y)


class ContrastiveLoss(nn.Module):
    """
    Pairwise contrastive loss over two batches of embeddings.

    For a pair with label 1 the squared distance is penalised, pulling the two
    embeddings together. For a pair with label 0 the squared hinge
    ``relu(margin - distance) ** 2`` is penalised, pushing the embeddings apart
    until their distance reaches the margin.
    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param distance_metric: Callable taking two embedding batches and returning their distances.
        The class SiameseDistanceMetric provides ready-made choices.
    :param margin: Distance below which pairs labelled 0 still contribute to the loss.
    :param size_average: If True the per-pair losses are averaged, otherwise they are summed.
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
        super().__init__()
        self.size_average = size_average
        self.margin = margin
        self.distance_metric = distance_metric

    def forward(self, sent_embs1, sent_embs2, labels: torch.Tensor):
        """
        Compute the loss for paired embeddings.

        :param sent_embs1: First batch of embeddings (the anchors).
        :param sent_embs2: Second batch of embeddings, paired row by row with the first.
        :param labels: Per-pair labels, 1 for similar pairs and 0 for dissimilar ones.
        :return: The mean (or sum, see ``size_average``) of the per-pair losses.
        """
        dist = self.distance_metric(sent_embs1, sent_embs2)
        # Similar pairs are charged for being apart ...
        pull_term = labels.float() * dist.pow(2)
        # ... dissimilar pairs for being closer than the margin.
        push_term = (1 - labels).float() * F.relu(self.margin - dist).pow(2)
        per_pair = 0.5 * (pull_term + push_term)
        if self.size_average:
            return per_pair.mean()
        return per_pair.sum()



class NSPHead(nn.Module):
    """Binary classification head: one linear layer mapping a hidden vector to 2 scores."""

    def __init__(self, config):
        """Build the linear layer from ``config.hidden_size`` to 2 outputs."""
        super().__init__()
        hidden_dim = config.hidden_size
        self.seq_relationship = nn.Linear(hidden_dim, 2)

    def forward(self, pooled_output):
        """Return the two un-normalised scores for every row of ``pooled_output``."""
        return self.seq_relationship(pooled_output)



class RoBertaKPPLMForProcessedWikiKGPLM(RobertaForMaskedLM):
    """
    RoBERTa masked-LM model with additional contrastive entity / relation objectives.

    On top of the inherited ``roberta`` encoder and ``lm_head`` the model owns:
      * ``entity_mlp``   - projection used by the entity contrastive loss
                           (examples whose ``task_id`` equals 2),
      * ``relation_mlp`` - projection used by the relation contrastive loss
                           (examples whose ``task_id`` equals 3),
      * ``detector``     - a 2-way head that is constructed but not used by
                           ``forward`` (the noise-detection loss is disabled),
      * ``dropout``      - constructed but not used by ``forward``.
    """

    # Values of ``task_id`` that route an example to a contrastive objective.
    _ENTITY_TASK_ID = 2
    _RELATION_TASK_ID = 3

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # ``classifier_dropout`` may be missing from the config or set to None;
        # in both cases fall back to the hidden dropout probability.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)

        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def _candidate_contrastive_loss(self, sequence_output, input_ids, mask_id, candidates,
                                    task_id, task_value, projection):
        """
        Contrastive loss between masked-span representations and candidate embeddings.

        Only the examples with ``task_id == task_value`` take part. For each of
        them the encoder outputs at the positions where ``input_ids`` equals the
        example's ``mask_id`` are averaged into a query vector. Every candidate is
        embedded with the encoder's embedding layer and mean-pooled over its
        tokens. Query and candidates go through the same ``projection``. The last
        candidate of every example is labelled positive (1), all others
        negative (0).

        :param sequence_output: Encoder output for the whole batch.
        :param input_ids: Token ids fed to the encoder.
        :param mask_id: Per-example token id marking the masked span.
        :param candidates: Candidate token ids, or None.
            Presumably shaped [batch, candidate_num, candidate_len] - TODO confirm against the collator.
        :param task_id: Per-example task identifier, or None.
        :param task_value: The task identifier handled by this call.
        :param projection: Module applied to both queries and candidates.
        :return: The scalar loss, or None when there is nothing to compute
            (missing inputs or no example of this task in the batch).
        """
        if candidates is None or task_id is None:
            return None
        selected = task_id == task_value
        candidates = candidates[selected]
        if len(candidates) == 0:
            return None
        batch_size, candidate_num = candidates.shape[0], candidates.shape[1]

        # Select the task's rows once instead of re-indexing inside the loop.
        hidden_states = sequence_output[selected]
        token_ids = input_ids[selected]
        mask_token_ids = mask_id[selected]

        # One query per example: mean of the encoder outputs at the masked
        # positions. An example without any matching position yields a NaN
        # query (mean over an empty selection).
        queries = torch.stack([
            hidden[tokens == mask_token].mean(0)
            for hidden, tokens, mask_token in zip(hidden_states, token_ids, mask_token_ids)
        ])  # [bz, dim]
        queries = projection(queries)
        # Repeat every query once per candidate so rows pair up with the flattened candidates.
        queries = queries.repeat_interleave(candidate_num, dim=0)  # [bz * candidate_num, dim]

        flat_candidates = candidates.view(-1, candidates.shape[-1])  # [bz * candidate_num, len]
        candidate_embedding = self.roberta.embeddings(input_ids=flat_candidates)  # [bz * candidate_num, len, dim]
        candidate_embedding = projection(candidate_embedding.mean(1))  # [bz * candidate_num, dim]

        # Labels follow the encoder output's device rather than assuming CUDA.
        pair_labels = torch.zeros(candidate_num, device=queries.device)
        pair_labels[-1] = 1.0  # the positive candidate is stored last
        pair_labels = pair_labels.repeat(batch_size)  # [bz * candidate_num]

        return self.contrastive_loss_fn(queries, candidate_embedding, pair_labels)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            labels=None,
            # entity_label=None,
            entity_candidate=None,
            # relation_label=None,
            relation_candidate=None,
            noise_detect_label=None,
            task_id=None,
            mask_id=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        """
        Run the encoder and compute the masked-LM and contrastive losses.

        ``input_ids`` through ``encoder_attention_mask`` as well as
        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
        forwarded unchanged to the RoBERTa encoder.

        :param labels: Masked-LM targets; when None no MLM loss is computed.
        :param entity_candidate: Candidate token ids for the entity objective
            (used for examples with ``task_id == 2``), or None to skip it.
        :param relation_candidate: Candidate token ids for the relation objective
            (used for examples with ``task_id == 3``), or None to skip it.
        :param noise_detect_label: Accepted for interface compatibility; currently ignored.
        :param task_id: Per-example task identifier.
        :param mask_id: Per-example token id marking the masked span in ``input_ids``.
        :return: OrderedDict with
            ``loss`` - sum of all computed loss terms (None if there is none),
            ``mlm_loss`` - the MLM loss with a leading dimension of size 1 (None without ``labels``),
            ``logits`` - arg-max token id per position.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)  # mlm head

        loss_terms = []
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # targets equal to -100 are ignored
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            loss_terms.append(masked_lm_loss)

        entity_loss = self._candidate_contrastive_loss(
            sequence_output, input_ids, mask_id, entity_candidate,
            task_id, self._ENTITY_TASK_ID, self.entity_mlp,
        )
        if entity_loss is not None:
            loss_terms.append(entity_loss)

        relation_loss = self._candidate_contrastive_loss(
            sequence_output, input_ids, mask_id, relation_candidate,
            task_id, self._RELATION_TASK_ID, self.relation_mlp,
        )
        if relation_loss is not None:
            loss_terms.append(relation_loss)

        # Without any label or candidate there is nothing to sum.
        total_loss = torch.sum(torch.stack(loss_terms), -1) if loss_terms else None

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
        ])


class DeBertaKPPLMForProcessedWikiKGPLM(DebertaForMaskedLM):
    """
    DeBERTa masked-LM model with additional contrastive entity / relation objectives.

    On top of the inherited ``deberta`` encoder and ``cls`` MLM head the model owns:
      * ``entity_mlp``   - projection used by the entity contrastive loss
                           (examples whose ``task_id`` equals 2),
      * ``relation_mlp`` - projection used by the relation contrastive loss
                           (examples whose ``task_id`` equals 3),
      * ``detector``     - a 2-way head that is constructed but not used by
                           ``forward`` (the noise-detection loss is disabled),
      * ``dropout``      - constructed but not used by ``forward``.
    """

    # Values of ``task_id`` that route an example to a contrastive objective.
    _ENTITY_TASK_ID = 2
    _RELATION_TASK_ID = 3

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # ``classifier_dropout`` may be missing from the config or set to None;
        # in both cases fall back to the hidden dropout probability.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)

        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def _candidate_contrastive_loss(self, sequence_output, input_ids, mask_id, candidates,
                                    task_id, task_value, projection):
        """
        Contrastive loss between masked-span representations and candidate embeddings.

        Only the examples with ``task_id == task_value`` take part. For each of
        them the encoder outputs at the positions where ``input_ids`` equals the
        example's ``mask_id`` are averaged into a query vector. Every candidate is
        embedded with the encoder's embedding layer and mean-pooled over its
        tokens. Query and candidates go through the same ``projection``. The last
        candidate of every example is labelled positive (1), all others
        negative (0).

        :param sequence_output: Encoder output for the whole batch.
        :param input_ids: Token ids fed to the encoder.
        :param mask_id: Per-example token id marking the masked span.
        :param candidates: Candidate token ids, or None.
            Presumably shaped [batch, candidate_num, candidate_len] - TODO confirm against the collator.
        :param task_id: Per-example task identifier, or None.
        :param task_value: The task identifier handled by this call.
        :param projection: Module applied to both queries and candidates.
        :return: The scalar loss, or None when there is nothing to compute
            (missing inputs or no example of this task in the batch).
        """
        if candidates is None or task_id is None:
            return None
        selected = task_id == task_value
        candidates = candidates[selected]
        if len(candidates) == 0:
            return None
        batch_size, candidate_num = candidates.shape[0], candidates.shape[1]

        # Select the task's rows once instead of re-indexing inside the loop.
        hidden_states = sequence_output[selected]
        token_ids = input_ids[selected]
        mask_token_ids = mask_id[selected]

        # One query per example: mean of the encoder outputs at the masked
        # positions. An example without any matching position yields a NaN
        # query (mean over an empty selection).
        queries = torch.stack([
            hidden[tokens == mask_token].mean(0)
            for hidden, tokens, mask_token in zip(hidden_states, token_ids, mask_token_ids)
        ])  # [bz, dim]
        queries = projection(queries)
        # Repeat every query once per candidate so rows pair up with the flattened candidates.
        queries = queries.repeat_interleave(candidate_num, dim=0)  # [bz * candidate_num, dim]

        flat_candidates = candidates.view(-1, candidates.shape[-1])  # [bz * candidate_num, len]
        candidate_embedding = self.deberta.embeddings(input_ids=flat_candidates)  # [bz * candidate_num, len, dim]
        candidate_embedding = projection(candidate_embedding.mean(1))  # [bz * candidate_num, dim]

        # Labels follow the encoder output's device rather than assuming CUDA.
        pair_labels = torch.zeros(candidate_num, device=queries.device)
        pair_labels[-1] = 1.0  # the positive candidate is stored last
        pair_labels = pair_labels.repeat(batch_size)  # [bz * candidate_num]

        return self.contrastive_loss_fn(queries, candidate_embedding, pair_labels)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            labels=None,
            # entity_label=None,
            entity_candidate=None,
            # relation_label=None,
            relation_candidate=None,
            noise_detect_label=None,
            task_id=None,
            mask_id=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        """
        Run the encoder and compute the masked-LM and contrastive losses.

        ``input_ids``, ``token_type_ids``, ``position_ids``, ``inputs_embeds``,
        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
        forwarded to the DeBERTa encoder. ``attention_mask``, ``head_mask``,
        ``encoder_hidden_states`` and ``encoder_attention_mask`` are accepted for
        interface compatibility but are not passed on.

        :param labels: Masked-LM targets; when None no MLM loss is computed.
        :param entity_candidate: Candidate token ids for the entity objective
            (used for examples with ``task_id == 2``), or None to skip it.
        :param relation_candidate: Candidate token ids for the relation objective
            (used for examples with ``task_id == 3``), or None to skip it.
        :param noise_detect_label: Accepted for interface compatibility; currently ignored.
        :param task_id: Per-example task identifier.
        :param mask_id: Per-example token id marking the masked span in ``input_ids``.
        :return: OrderedDict with
            ``loss`` - sum of all computed loss terms (None if there is none),
            ``mlm_loss`` - the MLM loss with a leading dimension of size 1 (None without ``labels``),
            ``logits`` - arg-max token id per position.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            # NOTE(review): the caller's attention_mask is intentionally not
            # forwarded (the original code had it disabled), so the encoder
            # receives no mask at all - confirm this is still wanted.
            attention_mask=None,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)  # mlm head

        loss_terms = []
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # targets equal to -100 are ignored
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            loss_terms.append(masked_lm_loss)

        entity_loss = self._candidate_contrastive_loss(
            sequence_output, input_ids, mask_id, entity_candidate,
            task_id, self._ENTITY_TASK_ID, self.entity_mlp,
        )
        if entity_loss is not None:
            loss_terms.append(entity_loss)

        relation_loss = self._candidate_contrastive_loss(
            sequence_output, input_ids, mask_id, relation_candidate,
            task_id, self._RELATION_TASK_ID, self.relation_mlp,
        )
        if relation_loss is not None:
            loss_terms.append(relation_loss)

        # Without any label or candidate there is nothing to sum.
        total_loss = torch.sum(torch.stack(loss_terms), -1) if loss_terms else None

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
        ])


class RoBertaForWikiKGPLM(RobertaPreTrainedModel):
    """
    RoBERTa encoder with a masked-LM head, a noise-detection head and an
    entity contrastive objective.

    Sub-modules:
      * ``roberta``      - the encoder,
      * ``lm_head``      - masked language modeling head,
      * ``detector``     - 2-way knowledge-noise detection head fed with the pooled output,
      * ``entity_mlp``   - projection used by the entity contrastive loss,
      * ``relation_mlp`` - constructed but not used by ``forward``,
      * ``dropout``      - constructed but not used by ``forward``.
    A ``RobertaTokenizer`` loaded from ``config.name_or_path`` is kept in ``tokenizer``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        # ``classifier_dropout`` may be missing from the config or set to None;
        # in both cases fall back to the hidden dropout probability.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.lm_head = RobertaLMHead(config)  # Masked Language Modeling head
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)

        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

        self.tokenizer = RobertaTokenizer.from_pretrained(config.name_or_path)

    def _entity_contrastive_loss(self, sequence_output, input_ids, mask_id, entity_label, entity_negative):
        """
        Contrastive loss between the masked-entity representation and its candidates.

        For every example the encoder outputs at the positions where
        ``input_ids`` equals the example's ``mask_id`` are averaged into a query.
        The gold entity (``entity_label``) and the negatives (``entity_negative``)
        are embedded with the encoder's embedding layer, mean-pooled over their
        tokens and projected with ``entity_mlp`` like the query. Candidates are
        ordered negatives first, gold entity last; only the last one is labelled 1.

        :param sequence_output: Encoder output for the batch.
        :param input_ids: Token ids fed to the encoder.
        :param mask_id: Per-example token id marking the masked entity.
        :param entity_label: Token ids of the gold entity.
            Presumably [batch, len] - TODO confirm against the collator.
        :param entity_negative: Token ids of the negative entities.
            Presumably [batch, negative_num, len] - TODO confirm against the collator.
        :return: The scalar contrastive loss.
        """
        batch_size = input_ids.shape[0]
        negative_num = entity_negative.shape[1]
        candidate_num = negative_num + 1

        # One query per example: mean of the encoder outputs at the masked
        # positions. An example without any matching position yields a NaN
        # query (mean over an empty selection).
        queries = torch.stack([
            hidden[tokens == mask_token].mean(0)
            for hidden, tokens, mask_token in zip(sequence_output, input_ids, mask_id)
        ])  # [bz, dim]
        queries = self.entity_mlp(queries)
        # Repeat every query once per candidate so rows pair up with the flattened candidates.
        queries = queries.repeat_interleave(candidate_num, dim=0)  # [bz * candidate_num, dim]

        positive_embedding = self.roberta.embeddings(input_ids=entity_label)  # [bz, len, dim]
        positive_embedding = self.entity_mlp(positive_embedding.mean(1)).unsqueeze(1)  # [bz, 1, dim]

        flat_negative = entity_negative.view(-1, entity_negative.shape[-1])  # [bz * negative_num, len]
        negative_embedding = self.roberta.embeddings(input_ids=flat_negative)  # [bz * negative_num, len, dim]
        negative_embedding = self.entity_mlp(negative_embedding.mean(1))  # [bz * negative_num, dim]
        negative_embedding = negative_embedding.view(batch_size, -1, negative_embedding.shape[-1])  # [bz, negative_num, dim]

        candidates = torch.cat([negative_embedding, positive_embedding], 1)  # [bz, candidate_num, dim]
        candidates = candidates.view(-1, candidates.shape[-1])  # [bz * candidate_num, dim]

        # Labels follow the encoder output's device rather than assuming CUDA.
        pair_labels = torch.zeros(candidate_num, device=queries.device)
        pair_labels[-1] = 1.0  # the gold entity is concatenated last
        pair_labels = pair_labels.repeat(batch_size)  # [bz * candidate_num]

        return self.contrastive_loss_fn(queries, candidates, pair_labels)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            mlm_labels=None,
            entity_label=None,
            entity_negative=None,
            relation_label=None,
            relation_negative=None,
            noise_detect_label=None,
            task_id=None,
            mask_id=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        """
        Run the encoder and compute every loss for which targets were supplied.

        ``input_ids`` through ``encoder_attention_mask`` as well as
        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
        forwarded unchanged to the RoBERTa encoder. ``relation_label``,
        ``relation_negative`` and ``task_id`` are accepted but not used.

        :param mlm_labels: Masked-LM targets; when None no MLM loss is computed.
        :param entity_label: Token ids of the gold entity.
        :param entity_negative: Token ids of the negative entities; the entity
            loss is computed only when both entity inputs are given.
        :param noise_detect_label: Targets for the 2-way noise detector, or None.
        :param mask_id: Per-example token id marking the masked entity in ``input_ids``.
        :return: OrderedDict with
            ``loss`` - sum of all computed loss terms (None if there is none),
            ``mlm_loss`` / ``noise_detect_loss`` / ``entity_loss`` - the individual
            terms with a leading dimension of size 1, or None when not computed,
            ``logits`` - arg-max token id per position,
            ``noise_detect_logits`` - arg-max class of the noise detector.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.lm_head(sequence_output)  # mlm head
        noise_detect_scores = self.detector(pooled_output)  # knowledge noise detector

        masked_lm_loss, noise_detect_loss, entity_loss = None, None, None
        loss_terms = []

        if mlm_labels is not None:
            loss_fct = CrossEntropyLoss()  # targets equal to -100 are ignored
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
            loss_terms.append(masked_lm_loss)

        if noise_detect_label is not None:
            loss_fct = CrossEntropyLoss()
            noise_detect_loss = loss_fct(noise_detect_scores.view(-1, 2), noise_detect_label.view(-1))
            loss_terms.append(noise_detect_loss)

        if entity_label is not None and entity_negative is not None:
            entity_loss = self._entity_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_label, entity_negative
            )
            loss_terms.append(entity_loss)

        # Accumulate every available term so that no objective is silently
        # dropped and a missing one does not break the sum.
        total_loss = None
        for term in loss_terms:
            total_loss = term if total_loss is None else total_loss + term

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
            ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
            ("noise_detect_logits", noise_detect_scores.argmax(-1) if noise_detect_scores is not None else None),
        ])




class BertForWikiKGPLM(BertPreTrainedModel):
    """BERT encoder with knowledge-enhanced pre-training objectives.

    ``forward`` can compute up to three losses, depending on which labels
    are supplied:

    * masked language modelling (``mlm_labels``),
    * a two-way "noise detection" classification on the pooled output
      (``noise_detect_label``),
    * a contrastive entity objective that compares the representation of the
      masked positions with one positive and several negative entity
      candidates (``entity_label`` / ``entity_negative``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # forward() calls self.cls(sequence_output, pooled_output) and expects
        # (token prediction scores, sequence-level scores) back, which is the
        # contract of BertPreTrainingHeads. The abstract BertPreTrainedModel
        # has no forward() and cannot be used as a head.
        self.cls = BertPreTrainingHeads(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)

        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    @staticmethod
    def _sum_losses(*losses):
        """Add up the losses that are not None; return None if all are None."""
        present = [loss for loss in losses if loss is not None]
        if not present:
            return None
        total = present[0]
        for loss in present[1:]:
            total = total + loss
        return total

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            mlm_labels=None,
            entity_label=None,
            entity_negative=None,
            relation_label=None,
            relation_negative=None,
            noise_detect_label=None,
            task_id=None,
            mask_id=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        """Run the encoder and compute whichever losses have labels.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, encoder_hidden_states, encoder_attention_mask,
            output_attentions, output_hidden_states, return_dict:
                forwarded unchanged to the underlying ``BertModel``.
            mlm_labels: target token ids for masked language modelling;
                positions set to ``CrossEntropyLoss``'s default ignore index
                (-100) do not contribute.
            entity_label: token ids of the positive entity, presumably
                ``[bz, len]`` (per the original shape notes; confirm against
                the data collator).
            entity_negative: token ids of the negative entities, presumably
                ``[bz, num_negatives, len]``.
            noise_detect_label: targets for the two-way sequence classifier.
            mask_id: per-example id of the mask token; indexed as
                ``mask_id[i]`` and compared with ``input_ids[i]``. Required
                when the entity objective is active.
            relation_label, relation_negative, task_id: accepted for interface
                compatibility; not used by this implementation.

        Returns:
            OrderedDict with keys ``loss``, ``mlm_loss``,
            ``noise_detect_loss``, ``entity_loss`` (each loss is ``None`` when
            its labels were not supplied; individual losses are unsqueezed to
            shape ``[1]``), ``logits`` (argmax token ids) and
            ``noise_detect_logits`` (argmax class ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        masked_lm_loss, noise_detect_loss, entity_loss = None, None, None

        # Masked language modelling.
        if mlm_labels is not None:
            loss_fct = CrossEntropyLoss()  # ignores label -100 by default
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
        total_loss = masked_lm_loss

        # Noise detection (two-way classification on the pooled output).
        if noise_detect_label is not None:
            loss_fct = CrossEntropyLoss()
            noise_detect_loss = loss_fct(seq_relationship_score.view(-1, 2), noise_detect_label.view(-1))
            total_loss = self._sum_losses(masked_lm_loss, noise_detect_loss)

        # Contrastive entity objective.
        if entity_label is not None and entity_negative is not None:
            batch_size = input_ids.shape[0]
            negative_num = entity_negative.shape[1]
            candidate_num = negative_num + 1

            # Query: mean hidden state over the mask-token positions of each example.
            batch_query_embedding = torch.stack([
                torch.mean(sequence_output[ei][input_id == mask_id[ei]], 0)
                for ei, input_id in enumerate(input_ids)
            ])  # [bz, dim]
            batch_query_embedding = self.entity_mlp(batch_query_embedding)  # [bz, dim]
            # One copy of the query per candidate. unsqueeze(1) is required:
            # repeating a 2-D tensor with a 3-element spec would tile along
            # the batch axis instead of creating a candidate axis.
            batch_query_embedding = batch_query_embedding.unsqueeze(1).repeat(1, candidate_num, 1)
            batch_query_embedding = batch_query_embedding.reshape(-1, batch_query_embedding.shape[-1])

            # Positive / negative candidates are encoded with the (cheap)
            # embedding layer only, mean-pooled over tokens, then projected.
            entity_negative = entity_negative.view(-1, entity_negative.shape[-1])  # [bz * n, len]
            entity_label_embedding = self.bert.embeddings(input_ids=entity_label)  # [bz, len, dim]
            entity_label_embedding = self.entity_mlp(torch.mean(entity_label_embedding, 1))  # [bz, dim]
            entity_label_embedding = entity_label_embedding.unsqueeze(1)  # [bz, 1, dim]

            entity_negative_embedding = self.bert.embeddings(input_ids=entity_negative)  # [bz * n, len, dim]
            entity_negative_embedding = self.entity_mlp(torch.mean(entity_negative_embedding, 1))  # [bz * n, dim]
            entity_negative_embedding = entity_negative_embedding.view(
                batch_size, -1, entity_negative_embedding.shape[-1]
            )  # [bz, n, dim]

            # Negatives first, positive last; flattened so every row is one
            # (query, candidate) pair.
            candidate_embedding = torch.cat([entity_negative_embedding, entity_label_embedding], 1)
            candidate_embedding = candidate_embedding.reshape(-1, candidate_embedding.shape[-1])

            # Per-pair labels in the same order as the flattened candidates,
            # created on the model's device rather than forcing CUDA.
            contrastive_label = torch.tensor(
                [0.0] * negative_num + [1.0],
                dtype=torch.float,
                device=candidate_embedding.device,
            ).repeat(batch_size)  # [bz * (n + 1)]

            entity_loss = self.contrastive_loss_fn(batch_query_embedding, candidate_embedding, contrastive_label)
            # NOTE(review): as in the original, this replaces (rather than
            # extends) the noise-detection total when both objectives are
            # active in the same batch -- confirm that batches are single-task.
            total_loss = self._sum_losses(masked_lm_loss, entity_loss)

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
            ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
            ("noise_detect_logits", seq_relationship_score.argmax(-1)),
        ])