hyu-nlp-hw4 / modeling_bert_for_mid.py
Hyun9898's picture
Upload folder using huggingface_hub
8111d5c verified
raw
history blame
329 Bytes
import torch
import torch.nn as nn
class BERT_for_MID(nn.Module):
    """Chain an encoder module and a head module into one ``nn.Module``.

    ``forward`` runs ``bert`` on the inputs and hands its result straight to
    ``classifier``. Both are stored as attributes, so when they are
    ``nn.Module`` instances their parameters are registered on this wrapper
    (``bert`` first, then ``classifier``).

    NOTE(review): the names suggest a BERT encoder plus a classification head
    for a "MID" task; neither the task nor the module types are visible in
    this file. Confirm against the caller.
    """

    def __init__(self, bert, classifier):
        """Store the encoder and the head.

        Args:
            bert: Callable invoked as ``bert(input, attention_mask)``.
            classifier: Callable applied to whatever ``bert`` returns.
        """
        super().__init__()
        # Assignment order is kept so submodule / state_dict key order is stable.
        self.bert = bert
        self.classifier = classifier

    def forward(self, input, attention_mask):
        """Encode with ``bert``, then score the encoding with ``classifier``.

        ``input`` shadows the builtin of the same name; the name is kept so
        existing keyword callers (``input=...``) keep working.

        NOTE(review): if ``bert`` is a Hugging Face ``BertModel``, its second
        positional argument is ``attention_mask``, but it returns a
        ``ModelOutput`` object rather than a tensor. In that case
        ``classifier`` must accept that object — TODO confirm.

        Returns:
            Whatever ``classifier`` returns for the encoder output.
        """
        encoded = self.bert(input, attention_mask)
        return self.classifier(encoded)