File size: 329 Bytes
8111d5c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13

import torch
import torch.nn as nn

class BERT_for_MID(nn.Module):
    """Chain an encoder and a classification head into a single module.

    The wrapped ``bert`` is called with the model input and the attention
    mask, and whatever it returns is passed unchanged to ``classifier``.
    When ``bert`` and ``classifier`` are ``nn.Module`` instances they are
    registered as child modules. Their parameters therefore show up in this
    wrapper's ``parameters()`` and ``state_dict()``.

    NOTE(review): the meaning of "MID" and the concrete return type of
    ``bert`` (plain tensor vs. tuple/ModelOutput) are not visible here.
    Confirm that ``classifier`` accepts what ``bert`` actually produces.
    """

    def __init__(self, bert, classifier):
        """Store the encoder and the head.

        Args:
            bert: Encoder, invoked as ``bert(input, attention_mask)``.
            classifier: Head, applied to the encoder's output.
        """
        # nn.Module's internal registries must exist before submodules are assigned.
        super().__init__()
        self.bert = bert
        self.classifier = classifier

    def forward(self, input, attention_mask):
        """Run the encoder on ``input`` and feed its output to the head.

        Args:
            input: First positional argument forwarded to ``bert``. The name
                is kept for caller compatibility, even though it shadows the
                builtin ``input``.
            attention_mask: Second positional argument forwarded to ``bert``.

        Returns:
            Whatever ``classifier`` returns for the encoder output.
        """
        encoded = self.bert(input, attention_mask)
        head_output = self.classifier(encoded)
        return head_output