codys12 committed
Commit 28effb4 · verified · 1 parent: 202bc81

Create hunyuan.py

Files changed (1)
  1. hunyuan.py +851 -0
hunyuan.py ADDED
@@ -0,0 +1,851 @@
# coding=utf-8
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
""" PyTorch HunYuan model."""

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from transformers.generation.utils import GenerateOutput
from .configuration_hunyuan import HunYuanConfig
from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "HunYuanConfig"


HUNYUAN_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`HunYuanConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
    HUNYUAN_START_DOCSTRING,
)
class HunYuanPreTrainedModel(PreTrainedModel):
    config_class = HunYuanConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HunYuanDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


HUNYUAN_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
              cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

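# Note: both `past_key_values` formats described above are accepted; `HunYuanModel.forward` converts the
# legacy tuple format with `DynamicCache.from_legacy_cache(...)` and converts back via `to_legacy_cache()`
# before returning.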

@add_start_docstrings(
    "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
    HUNYUAN_START_DOCSTRING,
)
class HunYuanModel(HunYuanPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]

    Args:
        config: HunYuanConfig
    """

    def __init__(self, config: HunYuanConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.add_classification_head = config.add_classification_head
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        if not config.add_classification_head:
            self.norm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.cla = config.use_cla
        self.cla_share_factor = config.cla_share_factor

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        # if input_ids is not None and inputs_embeds is not None:
        #     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Fix lora with gradient checkpointing training
        if self.training and inputs_embeds.is_leaf:
            inputs_embeds.requires_grad = True

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        prev_kv_states = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    prev_kv_states,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    kv_states=prev_kv_states
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            kv_states = layer_outputs[-1]

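            # Cross-layer attention (CLA) sharing: layers whose index is a multiple of
            # `cla_share_factor` refresh `prev_kv_states`; the layers in between are handed these
            # states via `kv_states` so they can reuse them instead of computing their own.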
            if self.cla and layer_idx % self.cla_share_factor == 0:
                prev_kv_states = kv_states

        if not self.add_classification_head:
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: HunYuanConfig):
        super().__init__(config)

        self.config = config
        self.model = HunYuanModel(config)
        self.add_classification_head = config.add_classification_head
        self.pad_id = config.pad_id
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.add_classification_head:
            self.pool_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            self.pool_head2 = nn.Linear(config.hidden_size, config.class_num, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if not self.add_classification_head:
            if self.config.pretraining_tp > 1:
                lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
                logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
                logits = torch.cat(logits, dim=-1)
            else:
                logits = self.lm_head(hidden_states)
            logits = logits.float()
        else:
            logits = hidden_states
            logits = logits.float()
            pooled_output = self.pool_head(logits)
            pooled_output = torch.tanh(pooled_output)
            pooled_output = self.pool_head2(pooled_output).contiguous()  # bs * class_num
            if len(pooled_output.shape) < 2:
                raise ValueError("pooled_output does not have enough dimensions for transpose")

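            # Reduce per-token class scores to a single reward per sequence: "mean" averages over
            # the sequence, "last" picks the last non-padding token, anything else uses the first token.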
            if self.config.pool_type == "mean":
                reward = pooled_output.mean(dim=1).squeeze(-1)
            elif self.config.pool_type == "last":
                # bs * hidden_size
                seq_length = (input_ids != self.pad_id).long().sum(dim=1) - 1
                batch_size = input_ids.size(0)
                reward = pooled_output[torch.arange(batch_size, device=pooled_output.device), seq_length].squeeze(-1)
            else:
                reward = pooled_output[:, 0].squeeze(-1)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
            shift_labels = shift_labels.reshape(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        if self.add_classification_head:
            output['reward'] = reward

        return output

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_cache_shape()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


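# MultimodelHunYuanForCausalLM keeps the text-only behaviour of HunYuanMoEV1ForCausalLM and adds the
# multimodal plumbing: `imgs` / `imgs_pos` are accepted by `forward` and threaded through
# `prepare_inputs_for_generation` / `generate`, and its training loss masks out image/video marker
# tokens and padding.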
class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: HunYuanConfig):
        super().__init__(config)

    @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        imgs: Optional[List[torch.FloatTensor]] = None,
        imgs_pos: Optional[List[int]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        mask_init_id = self.config.mask_init_id
        pad_id = self.config.pad_token_id
        eod_id = self.config.eod_token_id
        image_token_id = self.config.image_token_id
        im_start_id = self.config.im_start_id
        im_end_id = self.config.im_end_id
        video_start_id = self.config.video_start_id
        video_end_id = self.config.video_end_id

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # NOTE: unlike the parent class, logits and labels are used as-is here (no one-token shift).
            shift_logits = logits
            shift_labels = labels
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
            shift_labels = shift_labels.reshape(-1)
            shift_tokens = input_ids.reshape(-1)
            # compute loss
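            # Keep only ordinary text positions: drop labels that are special or multimodal marker ids
            # (>= mask_init_id, pad, image token, im_start/im_end, video_start/video_end) and positions
            # whose input token is padding or EOD.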
            mask = (shift_labels < mask_init_id) & (shift_labels != pad_id) & (shift_labels != image_token_id) & (shift_labels != im_start_id) \
                & (shift_labels != im_end_id) & (shift_labels != video_start_id) & (shift_labels != video_end_id) & (shift_tokens != pad_id) & (shift_tokens != eod_id)
            shift_logits = shift_logits[mask, :]
            shift_labels = shift_labels[mask]
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        imgs = kwargs.pop("imgs", None)
        imgs_pos = kwargs.pop("imgs_pos", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
        )

        if imgs is not None:
            inputs['imgs'] = imgs
        if imgs_pos is not None:
            inputs['imgs_pos'] = imgs_pos
        return inputs

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        imgs: Optional[List[torch.FloatTensor]] = None,
        imgs_pos: Optional[List[int]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        # `inputs` (token ids) is forwarded as-is; images and their positions are passed through the
        # generation kwargs so that `prepare_inputs_for_generation` can attach them at each step.
        return super().generate(
            inputs=inputs,
            position_ids=position_ids,
            attention_mask=attention_mask,
            imgs=imgs,
            imgs_pos=imgs_pos,
            eos_token_id=self.config.eod_token_id,
            **kwargs
        )


@add_start_docstrings(
    """
    The HunYuan Model transformer with a sequence classification head on top (linear layer).

    [`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    HUNYUAN_START_DOCSTRING,
)
class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = HunYuanModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.reshape(-1, self.num_labels), labels.reshape(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
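
A minimal usage sketch, assuming this file ships in a model repo together with configuration_hunyuan.py and modeling_hunyuan.py and is registered through auto_map in the repo's config.json; the repo id and dtype below are placeholders, not taken from this commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/hunyuan-model"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,  # assumption: a dtype the checkpoint supports
    device_map="auto",
    trust_remote_code=True,      # required so this hunyuan.py is imported from the repo
)

inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt").to(model.device)
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```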