from typing import Optional

import torch
import torch.nn as nn
import transformers
class BertClassificationModel(nn.Module):
    """Sequence classifier: a pretrained BERT encoder followed by one linear layer.

    The linear head is applied to BERT's pooled output and returns
    unnormalized logits. The pooled output is the ``[CLS]`` hidden state after
    BERT's own dense + tanh pooler layer. No softmax or dropout is applied here.

    Args:
        pretrained_weights: Model name or path passed to
            ``transformers.BertModel.from_pretrained``.
        num_labels: Number of output classes (width of the logit vector).
        trainable_bert: If True (default), the BERT encoder is fine-tuned
            together with the head. If False, its weights are frozen and only
            the linear head receives gradients.
    """

    def __init__(
        self,
        pretrained_weights: str = "bert-base-chinese",
        num_labels: int = 3,
        trainable_bert: bool = True,
    ) -> None:
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(pretrained_weights)
        # Make the fine-tune / freeze decision explicit for the whole encoder.
        self.bert.requires_grad_(trainable_bert)
        # Read the encoder width from the checkpoint config instead of
        # hard-coding 768, so other BERT sizes work without edits.
        self.dense = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch with BERT and return classification logits.

        Args:
            input_ids: Token ids for the batch. Presumably (batch, seq_len),
                as produced by a BERT tokenizer; confirm against the caller.
            token_type_ids: Optional segment ids. ``None`` lets ``BertModel``
                apply its own default.
            attention_mask: Optional padding mask. ``None`` lets ``BertModel``
                apply its own default.

        Returns:
            Logits produced by ``self.dense`` from the pooled output. This is
            typically of shape (batch, num_labels).
        """
        bert_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Index 1 is the pooler output (CLS state passed through dense + tanh).
        # Indexing works for both tuple and ModelOutput return types.
        pooled_output = bert_output[1]
        return self.dense(pooled_output)