ohgnues committed
Commit 4c6f0a3 (verified) · Parent(s): 4f02e25

Upload model.py

Files changed (1):
  model.py (+143, -0)
model.py ADDED
@@ -0,0 +1,143 @@
import copy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, L1Loss

from transformers import T5ForConditionalGeneration
from transformers.utils import ModelOutput


@dataclass
class AnalystOutput(ModelOutput):
    """Output of Analyst.forward: the usual LM fields plus one logits tensor per auxiliary head."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    regression_logits: torch.FloatTensor = None
    classification_logits: torch.FloatTensor = None
    tagging_logits: torch.FloatTensor = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None


class ClassificationHead(nn.Module):
    """Head for sentence-level tasks: dropout -> dense -> tanh -> dropout -> projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class Analyst(T5ForConditionalGeneration):
    """T5 extended with three auxiliary heads over the encoder output: a one-unit
    regression head, a two-label per-token tagging head, and a sequence
    classification head with config.num_labels outputs. All supplied task losses
    are summed into a single scalar."""

    def __init__(self, config):
        super().__init__(config)

        regression_config = copy.deepcopy(config)
        regression_config.num_labels = 1  # single regression output
        self.regression_head = ClassificationHead(regression_config)

        tagging_config = copy.deepcopy(config)
        tagging_config.num_labels = 2  # binary tag per token
        self.tagging_head = ClassificationHead(tagging_config)

        self.classification_head = ClassificationHead(config)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_regression: Optional[torch.FloatTensor] = None,
        labels_tagging: Optional[torch.LongTensor] = None,
        labels_classification: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], AnalystOutput]:
        # Standard seq2seq forward pass; `labels` yields the LM loss.
        # return_dict is forced to True internally so the attribute accesses
        # below work even when the caller requested a tuple.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            labels=labels,
            return_dict=True,
        )
        encoder_hidden_state = output.encoder_last_hidden_state
        lm_logits = output.logits

        loss = output.loss
        regression_logits = None
        classification_logits = None
        tagging_logits = None

        if input_ids is not None:
            # Pool the encoder state at the last <eos> position of each sequence;
            # every sequence is expected to contain the eos token.
            eos_mask = input_ids.eq(self.config.eos_token_id).to(encoder_hidden_state.device)
            batch_size, _, hidden_size = encoder_hidden_state.shape
            sentence_representation = encoder_hidden_state[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]

            regression_logits = self.regression_head(sentence_representation)
            classification_logits = self.classification_head(sentence_representation)
            # The tagging head runs on every token position, not on the pooled state.
            tagging_logits = self.tagging_head(encoder_hidden_state)

        # Each auxiliary loss is added to the running total; the guard against
        # `loss is None` covers calls where no LM labels were supplied.
        if labels_regression is not None:
            labels_regression = labels_regression.to(lm_logits.device)
            loss_fct = L1Loss()
            regression_loss = loss_fct(regression_logits.squeeze(), labels_regression.squeeze())
            loss = regression_loss if loss is None else loss + regression_loss

        if labels_classification is not None:
            labels_classification = labels_classification.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            classification_loss = loss_fct(
                classification_logits.view(-1, self.config.num_labels),
                labels_classification.view(-1),
            )
            loss = classification_loss if loss is None else loss + classification_loss

        if labels_tagging is not None:
            labels_tagging = labels_tagging.to(lm_logits.device)
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            tagging_loss = loss_fct(
                tagging_logits.view(-1, tagging_logits.size(-1)),
                labels_tagging.view(-1),
            )
            loss = tagging_loss if loss is None else loss + tagging_loss

        if not return_dict:
            return (loss, lm_logits, regression_logits, classification_logits, tagging_logits)

        return AnalystOutput(
            loss=loss,
            logits=lm_logits,
            regression_logits=regression_logits,
            classification_logits=classification_logits,
            tagging_logits=tagging_logits,
            encoder_last_hidden_state=encoder_hidden_state,
        )
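
For reference, a minimal smoke-test sketch (not part of the commit). It builds an untrained model from a t5-small config rather than loading this repo's checkpoint, and the num_labels and classifier_dropout values, the "analyze:" prompt format, and the label tensors are all illustrative assumptions, not documented training settings:

    import torch
    from transformers import AutoTokenizer, T5Config

    # Extra kwargs passed to from_pretrained are set as attributes on the
    # config; classifier_dropout and num_labels are assumed values here.
    config = T5Config.from_pretrained("t5-small", num_labels=3, classifier_dropout=0.1)
    model = Analyst(config)  # randomly initialized, for a shape/loss check only
    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    texts = ["analyze: the product works well", "analyze: the product broke"]
    enc = tokenizer(texts, return_tensors="pt", padding=True)  # appends </s> for eos pooling
    dec = tokenizer(["positive", "negative"], return_tensors="pt", padding=True)

    out = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        labels=dec.input_ids,                            # seq2seq LM loss
        labels_regression=torch.tensor([4.5, 1.0]),      # L1 loss on pooled <eos> state
        labels_classification=torch.tensor([2, 0]),      # CE over num_labels classes
        labels_tagging=torch.zeros_like(enc.input_ids),  # per-token CE, 2 classes, -100 ignored
        return_dict=True,
    )
    print(out.loss)                         # sum of the four task losses
    print(out.regression_logits.shape)      # (2, 1)
    print(out.classification_logits.shape)  # (2, 3)
    print(out.tagging_logits.shape)         # (2, seq_len, 2)

Any subset of the four label arguments can be omitted; the corresponding loss term is simply skipped, and heads are still evaluated whenever input_ids is given.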