import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionPool(nn.Module):
    """Learns a per-token score and returns the attention-weighted sum of token embeddings."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state):
        # (batch, seq_len, hidden) -> (batch, seq_len) attention scores
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        # Note: padding positions are not masked here, so they also receive attention weight
        attention_weights = F.softmax(attention_scores, dim=1)
        # Weighted sum over the sequence dimension -> (batch, hidden)
        pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
        return pooled_output
class MultiSampleDropout(nn.Module):
    """Applies dropout several times to the same input and averages the results."""

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        return torch.mean(torch.stack([self.dropout(x) for _ in range(self.num_samples)]), dim=0)
class ImprovedBERTClass(nn.Module):
    """BERT encoder with attention pooling, multi-sample dropout, LayerNorm, and a linear classifier head."""

    def __init__(self, num_classes=13):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.attention_pool = AttentionPool(768)  # 768 = bert-base hidden size
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(768)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Pool over all token embeddings instead of using the [CLS] pooler output
        pooled_output = self.attention_pool(bert_output.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
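

# Minimal usage sketch (illustrative only, not part of the original listing): tokenize a
# small batch and run a forward pass. It assumes the standard 'bert-base-uncased' tokenizer;
# the example texts and the num_classes value are placeholders.
if __name__ == "__main__":
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    model = ImprovedBERTClass(num_classes=13)
    model.eval()

    batch = tokenizer(
        ["an example sentence", "another example"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        )
    # Apply sigmoid for a multi-label setup or softmax for single-label classification
    probs = torch.sigmoid(logits)
    print(probs.shape)  # (2, 13)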