import torch
import torch.nn as nn


class BERT_for_MID(nn.Module):
    """Chain an encoder and a classifier head into a single module.

    The forward pass feeds ``(input, attention_mask)`` to ``bert`` and hands
    whatever ``bert`` returns straight to ``classifier``.

    Args:
        bert: Callable invoked as ``bert(input, attention_mask)``.
            Presumably a BERT-style encoder ``nn.Module`` -- TODO confirm
            against the caller.
        classifier: Callable invoked with the single value returned by
            ``bert``. Presumably a classification head ``nn.Module`` --
            TODO confirm against the caller.
    """

    def __init__(self, bert, classifier):
        super(BERT_for_MID, self).__init__()
        # When these are nn.Module instances, attribute assignment registers
        # them as submodules, so their parameters are tracked by this module.
        self.bert = bert
        self.classifier = classifier

    def forward(self, input, attention_mask):
        """Return ``classifier(bert(input, attention_mask))``.

        Args:
            input: First positional argument forwarded to ``bert``
                (presumably token ids -- TODO confirm). The name shadows the
                builtin ``input`` but is kept for caller compatibility.
            attention_mask: Second positional argument forwarded to ``bert``.

        Returns:
            The output of ``classifier`` applied to the output of ``bert``.
        """
        # NOTE(review): the encoder output is passed on unmodified; if `bert`
        # returns a tuple or structured output, `classifier` must accept it.
        encoded = self.bert(input, attention_mask)
        return self.classifier(encoded)