UNCANNY69 committed on
Commit
f50e876
·
verified ·
1 Parent(s): eea959d

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +45 -0
model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta
2
+ import torch
3
+ from transformers.pytorch_utils import nn
4
+ import torch.nn.functional as F
5
+ from transformers import AlbertModel, AlbertForSequenceClassification, PreTrainedModel
6
+ from transformers.modeling_outputs import SequenceClassifierOutput
7
+ from transformers import AlbertConfig
8
+
9
class AlbertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """Sequence classifier: ALBERT encoder + LSTM over per-layer [CLS] states.

    The [CLS] (position-0) hidden state is taken from each of the first
    ``config.num_layers`` ALBERT hidden-state tensors, the resulting
    ``(batch, num_layers, embed_dim)`` sequence is run through an LSTM, and
    the final LSTM step is projected to ``config.num_classes`` logits.

    Expected config attributes (read in ``__init__``): ``num_classes``,
    ``embed_dim``, ``num_layers``, ``hidden_dim_lstm``, ``dropout_rate``.
    """

    # FIX: the original assigned the undefined name `AlbertLSTMConfig`,
    # which raises NameError as soon as the class body executes. This file
    # only imports `AlbertConfig`; if a custom `AlbertLSTMConfig` exists
    # elsewhere in the project, import it and use it here instead.
    config_class = AlbertConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.dropout_rate)
        # NOTE(review): downloads hub weights inside __init__ with a
        # hard-coded checkpoint. output_hidden_states=True is required
        # because forward() reads albert_output["hidden_states"].
        self.albert = AlbertModel.from_pretrained(
            'albert-base-v2',
            output_hidden_states=True,
            output_attentions=False,
        )
        print("ALBERT Model Loaded")
        # LSTM depth is fixed at 3 as in the original; config.num_layers
        # instead controls how many ALBERT hidden states feed the LSTM.
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm,
                            batch_first=True, num_layers=3)
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Run ALBERT + LSTM head; returns a SequenceClassifierOutput.

        Args:
            input_ids / attention_mask / token_type_ids: standard ALBERT
                encoder inputs (batch-first tensors).
            labels: optional class-index tensor; when given, cross-entropy
                loss is computed and included in the output.
        """
        albert_output = self.albert(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)
        hidden_states = albert_output["hidden_states"]

        # FIX: the original used `[:, 0].squeeze()` + stack(dim=-1) + view():
        #  * .squeeze() dropped the batch dimension whenever batch_size == 1,
        #    so the subsequent view() produced a mis-shaped tensor;
        #  * stacking on dim=-1 yields (batch, embed_dim, num_layers) and the
        #    view(-1, num_layers, embed_dim) then silently scrambled values
        #    rather than transposing them.
        # Stacking the [CLS] vectors on dim=1 builds the intended
        # (batch, num_layers, embed_dim) sequence directly.
        cls_states = torch.stack(
            [hidden_states[layer_i][:, 0] for layer_i in range(self.num_layers)],
            dim=1,
        )
        out, _ = self.lstm(cls_states, None)
        # Keep only the last LSTM time step, regularize, then classify.
        out = self.dropout(out[:, -1, :])
        logits = self.fc(out)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=albert_output.hidden_states,
            attentions=albert_output.attentions,
        )