File size: 329 Bytes
8111d5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch
import torch.nn as nn
class BERT_for_MID(nn.Module):
    """Chain an encoder and a classifier into a single module.

    The wrapper holds two callables and applies them back to back:
    ``classifier(bert(input, attention_mask))``.

    NOTE(review): judging by the names, ``bert`` is presumably a BERT-style
    encoder and ``classifier`` a head consuming its output; the exact
    types and tensor shapes are not established here -- confirm with the
    caller.
    """

    def __init__(self, bert, classifier):
        """Store the encoder and the classifier on the module.

        Args:
            bert: Callable invoked as ``bert(input, attention_mask)``.
            classifier: Callable invoked on whatever ``bert`` returns.
        """
        super().__init__()
        # Keep this assignment order: for nn.Module arguments it determines
        # the order in which the submodules are registered.
        self.bert = bert
        self.classifier = classifier

    def forward(self, input, attention_mask):
        """Run the encoder, then the classifier, and return the result.

        Args:
            input: First positional argument forwarded to ``self.bert``.
            attention_mask: Second positional argument forwarded to
                ``self.bert``.

        Returns:
            The value returned by ``self.classifier`` for the encoder output.
        """
        # Both arguments are passed positionally, so ``bert`` need not use
        # the same parameter names as this method.
        encoded = self.bert(input, attention_mask)
        return self.classifier(encoded)
|