from abc import ABC | |
import torch.nn as nn | |
from transformers import BertModel, BertPreTrainedModel | |
class BertForClassification(BertPreTrainedModel, ABC):
    """BERT encoder followed by dropout and a linear head applied to every token.

    Despite the class name, the head is applied to each position of the
    encoder's first output (the per-token sequence representation), not to a
    pooled vector. The resulting logits are therefore per-token, shaped
    ``(batch_size, seq_len, num_labels)``.
    """

    def __init__(self, config):
        """Build the encoder, dropout layer, and classification head.

        Args:
            config: BERT model config. This constructor reads ``num_labels``,
                ``hidden_size`` and ``hidden_dropout_prob`` from it.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None) -> tuple:
        """Run the encoder and return per-token logits.

        Args:
            input_ids: Token ids, passed through to ``BertModel``.
            attention_mask: Attention mask, passed through to ``BertModel``.
            token_type_ids: Optional segment ids, passed through to
                ``BertModel``. Accepting this keyword lets callers unpack a
                BERT tokenizer's output (which typically includes it)
                directly into this method.

        Returns:
            A tuple whose first element is the logits tensor
            ``(batch_size, seq_len, num_labels)``. It is followed by any
            encoder outputs past index 1 (presumably hidden states and/or
            attentions when the config enables them).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (batch_size, seq_len, num_labels)
        # Index 1 of the encoder output is skipped; extras from index 2 are kept.
        # tuple() makes the concatenation work whether the encoder returned a
        # plain tuple or an indexable output object.
        return (logits,) + tuple(outputs[2:])