LedZeppe1in
Added custom rutabert pipeline for column type annotation
1507360
raw
history blame contribute delete
942 Bytes
from abc import ABC
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
class BertForClassification(BertPreTrainedModel, ABC):
    """BERT encoder topped with a per-token linear classification head.

    Emits one logit vector per input token (token-level classification,
    e.g. annotating each serialized column with a type label).
    """

    def __init__(self, config):
        """Build the encoder, dropout, and classifier layers from *config*."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None) -> tuple:
        """Encode the batch and classify every token position.

        Returns a tuple whose first element is the per-token logits of
        shape (batch_size, seq_len, num_labels); any additional encoder
        outputs beyond index 1 (hidden states / attentions, when enabled
        via the config) are passed through unchanged.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # (batch_size, seq_len, hidden_size) -> regularize, then project to label space
        token_states = self.dropout(encoder_out[0])
        token_logits = self.classifier(token_states)  # (batch_size, seq_len, num_labels)
        return (token_logits,) + encoder_out[2:]  # logits, (hidden_states), (attentions)