lin5547 committed (verified)
Commit dd7417a · Parent(s): 20ce51c

Upload folder using huggingface_hub
added_tokens.json ADDED
@@ -0,0 +1,61 @@
+ {
+   "</tool_call>": 151658,
+   "<B_APE>": 151671,
+   "<B_CODE>": 151670,
+   "<B_FUNC>": 151669,
+   "<B_SYS>": 151665,
+   "<B_USYS>": 151666,
+   "<C_A>": 151668,
+   "<C_Q>": 151667,
+   "<audio_delim_baichuan>": 151693,
+   "<audio_end_baichuan>": 151677,
+   "<audio_pad_baichuan>": 151678,
+   "<audio_start_baichuan>": 151676,
+   "<audiogen_end_baichuan>": 151701,
+   "<audiogen_start_baichuan>": 151700,
+   "<audiotext_end_baichuan>": 151698,
+   "<audiotext_pad_baichuan>": 151699,
+   "<audiotext_start_baichuan>": 151697,
+   "<baichuan_pad_token>": 151691,
+   "<box_delim_baichuan>": 151685,
+   "<box_end_baichuan>": 151684,
+   "<box_start_baichuan>": 151683,
+   "<calc_end>": 151674,
+   "<calc_start>": 151673,
+   "<function_calling>": 151672,
+   "<img_delim_baichuan>": 151688,
+   "<img_end_baichuan>": 151680,
+   "<img_newline_baichuan>": 151682,
+   "<img_pad_baichuan>": 151681,
+   "<img_start_baichuan>": 151679,
+   "<inner_think>": 151675,
+   "<polygon_end_baichuan>": 151690,
+   "<polygon_start_baichuan>": 151689,
+   "<ref_end_baichuan>": 151687,
+   "<ref_start_baichuan>": 151686,
+   "<reserved_113>": 151692,
+   "<tool_call>": 151657,
+   "<video_end_baichuan>": 151696,
+   "<video_palce_baichuan>": 151694,
+   "<video_start_baichuan>": 151695,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
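
For reference, a minimal sketch (not part of this commit) of checking the token-to-ID mapping above once the repo's tokenizer files are available locally; the `./` path and `trust_remote_code=True` are assumptions:

```python
from transformers import AutoTokenizer

# Hypothetical local path to this repo folder; adjust as needed.
tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)

# IDs should match added_tokens.json, e.g. 151676 for <audio_start_baichuan>.
for tok in ["<audio_start_baichuan>", "<audio_end_baichuan>", "<|im_start|>", "<|im_end|>"]:
    print(tok, tokenizer.convert_tokens_to_ids(tok))
```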
audio_modeling_omni.py ADDED
@@ -0,0 +1,658 @@
1
+ import torch, fire
2
+ from typing import Optional
3
+ import torch.distributed
4
+ from torch.nn import functional as F
5
+ from flash_attn import flash_attn_varlen_func
6
+ from torch import nn
7
+ import numpy as np
8
+ import deepspeed
9
+ from transformers.activations import ACT2FN
10
+ from dataclasses import dataclass
11
+ from transformers.modeling_outputs import ModelOutput
12
+ try:
13
+ from .vector_quantize import VectorQuantize
14
+ except ImportError:
15
+ from vector_quantize import VectorQuantize
16
+
17
+ from .flow_matching import (
18
+ ConditionalDecoder,
19
+ ConditionalCFM,
20
+ )
21
+
22
+ import math
23
+ import copy
24
+
25
+ def sinusoids(length, channels, max_timescale=10000):
26
+ """Returns sinusoids for positional embedding"""
27
+ assert channels % 2 == 0
28
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
29
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
30
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
31
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
32
+
33
+ def get_sequence_mask(inputs, inputs_length):
34
+ if inputs.dim() == 3:
35
+ bsz, tgt_len, _ = inputs.size()
36
+ else:
37
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
38
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
39
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
40
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1  # convert mask to flat indices
41
+ return sequence_mask, unpacking_index
42
+
43
+ def unpack_hidden_states(hidden_states, lengths):
44
+ bsz = lengths.shape[0]
45
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
46
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
47
+ bsz, torch.max(lengths), hidden_states.shape[-1]
48
+ )
49
+ hidden_states = torch.where(
50
+ sequence_mask, hidden_states, 0
51
+ ) # 3d (bsz, max_input_len, d)
52
+ return hidden_states
53
+
54
+
55
+ class RMSNorm(nn.Module):
56
+ def __init__(self, hidden_size, eps=1e-6):
57
+ """
58
+ RMSNorm is equivalent to T5LayerNorm
59
+ """
60
+ super().__init__()
61
+ self.weight = nn.Parameter(torch.ones(hidden_size))
62
+ self.variance_epsilon = eps
63
+
64
+ def forward(self, hidden_states):
65
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
66
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
67
+
68
+ # convert into half-precision if necessary
69
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
70
+ hidden_states = hidden_states.to(self.weight.dtype)
71
+
72
+ return self.weight * hidden_states
73
+
74
+
75
+ class OmniWhisperAttention(nn.Module):
76
+ def __init__(self, embed_dim, num_heads, causal=False):
77
+ super().__init__()
78
+ self.embed_dim = embed_dim
79
+ self.num_heads = num_heads
80
+ self.head_dim = embed_dim // num_heads
81
+
82
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
83
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
84
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
85
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
86
+
87
+ self.causal = causal
88
+
89
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
90
+ bsz, _ = hidden_states.size()
91
+
92
+ query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
93
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
94
+ value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
95
+
96
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
97
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
98
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
99
+ max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
100
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
101
+ attn_output = self.out_proj(attn_output)
102
+ return attn_output
103
+
104
+
105
+ class OmniWhisperTransformerLayer(nn.Module):
106
+ def __init__(
107
+ self,
108
+ act,
109
+ d_model,
110
+ encoder_attention_heads,
111
+ encoder_ffn_dim,
112
+ causal,
113
+ ln_type="LayerNorm",
114
+ ):
115
+ super().__init__()
116
+ self.embed_dim = d_model
117
+ self.self_attn = OmniWhisperAttention(
118
+ self.embed_dim, encoder_attention_heads, causal
119
+ )
120
+
121
+ if ln_type == "LayerNorm":
122
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
123
+ elif ln_type == "RMSNorm":
124
+ self.self_attn_layer_norm = RMSNorm(self.embed_dim)
125
+ else:
126
+ raise ValueError(f"Unknown ln_type: {ln_type}")
127
+
128
+ self.activation_fn = act
129
+ self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
130
+ self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
131
+
132
+ if ln_type == "LayerNorm":
133
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
134
+ elif ln_type == "RMSNorm":
135
+ self.final_layer_norm = RMSNorm(self.embed_dim)
136
+ else:
137
+ raise ValueError(f"Unknown ln_type: {ln_type}")
138
+
139
+ def forward(
140
+ self, hidden_states: torch.Tensor, seq_len: torch.Tensor
141
+ ) -> torch.Tensor:
142
+ residual = hidden_states
143
+ hidden_states = self.self_attn_layer_norm(hidden_states)
144
+ hidden_states = self.self_attn(hidden_states, seq_len)
145
+ hidden_states = residual + hidden_states
146
+ residual = hidden_states
147
+ hidden_states = self.final_layer_norm(hidden_states)
148
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
149
+ hidden_states = self.fc2(hidden_states)
150
+ hidden_states = residual + hidden_states
151
+
152
+ if (
153
+ hidden_states.dtype == torch.float16
154
+ or hidden_states.dtype == torch.bfloat16
155
+ ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
156
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
157
+ hidden_states = torch.clamp(
158
+ hidden_states, min=-clamp_value, max=clamp_value
159
+ )
160
+ return hidden_states
161
+
162
+
163
+ class OmniAudioEncoder(nn.Module):
164
+ def __init__(self, config):
165
+ super().__init__()
166
+ config._attn_implementation = 'flash_attention_2'
167
+ self.config = config
168
+ self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
169
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
170
+
171
+ self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
172
+ self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
173
+ stride=config.stride_size, padding=1)
174
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
175
+
176
+ self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
177
+ ACT2FN[config.activation_function],
178
+ config.d_model,
179
+ config.encoder_attention_heads,
180
+ config.encoder_ffn_dim,
181
+ False) for _ in range(config.encoder_layers)])
182
+ self.layer_norm = nn.LayerNorm(config.d_model)
183
+
184
+ @torch.no_grad()
185
+ def fake_input(self, device):
186
+ input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
187
+ encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
188
+ bridge_length = torch.ones([2], dtype=torch.int32, device=device)
189
+ return input_features, encoder_length, bridge_length
190
+
191
+ def forward(
192
+ self,
193
+ input_features,
194
+ output_length,
195
+ ):
196
+ input_features = input_features.to(self.conv1.weight.dtype)
197
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
198
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
199
+ inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frames, channels)
200
+ bsz, tgt_len, _ = inputs_embeds.size()
201
+ if tgt_len < self.positional_embedding.shape[0]:
202
+ current_positional_embedding = self.positional_embedding[:tgt_len]
203
+ else:
204
+ current_positional_embedding = self.positional_embedding
205
+ hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
206
+
207
+ # packing hidden states
208
+ attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
209
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
210
+ self.config.d_model)
211
+
212
+ for idx, encoder_layer in enumerate(self.layers):
213
+ hidden_states = encoder_layer(hidden_states, output_length)
214
+ hidden_states = self.layer_norm(hidden_states)
215
+ # unpacking
216
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
217
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
218
+ return hidden_states
219
+
220
+
221
+ class CasualConvTranspose1d(nn.Module): # transposed convolution
222
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
223
+ super().__init__()
224
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
225
+ self.norm = nn.GroupNorm(1, out_channels)
226
+ self.in_channels = in_channels
227
+ self.out_channels = out_channels
228
+
229
+ def forward(self, hidden_states, input_length, output_dim=None):
230
+ kernel_size = self.conv.kernel_size[0]
231
+ stride = self.conv.stride[0]
232
+ bsz = input_length.shape[0]
233
+
234
+ if output_dim is None:
235
+ output_dim = hidden_states.dim()
236
+ if hidden_states.dim() <= 2: # unpack sequence to 3d
237
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
238
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
239
+ self.in_channels)
240
+ hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
241
+
242
+ hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
243
+ hidden_states = self.conv(hidden_states)
244
+ hidden_states = self.norm(hidden_states)
245
+ hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
246
+
247
+ casual_padding_right = max(0, kernel_size - stride)
248
+ hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
249
+ :]
250
+ output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
251
+ sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
252
+ if output_dim <= 2:
253
+ hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
254
+ else:
255
+ hidden_states = torch.where(sequence_mask, hidden_states, 0)
256
+ hidden_states = hidden_states[:, :torch.max(output_length), :] # truncate to the maximum valid length
257
+ return hidden_states, output_length
258
+
259
+
260
+ class MelSpecRefineNet(nn.Module):
261
+ """
262
+ # post net, coarse to refined mel-spectrogram frames
263
+ # ref1: Autoregressive Speech Synthesis without Vector Quantization
264
+ # ref2: CosyVoice length_regulator.py
265
+ # ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
266
+ """
267
+
268
+ def __init__(self, encoder_config, vocoder_config):
269
+ super().__init__()
270
+ self.encoder_config = encoder_config
271
+ self.vocoder_config = vocoder_config
272
+
273
+ layers = nn.ModuleList([])
274
+ in_channels = self.vocoder_config.num_mel_bins
275
+ for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
276
+ module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
277
+ in_channels = out_channels
278
+ norm = nn.GroupNorm(1, out_channels)
279
+ act = nn.Mish()
280
+ layers.extend([module, norm, act])
281
+ layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
282
+ self.layers = nn.Sequential(*layers)
283
+
284
+ def compute_output_length(self, input_length):
285
+ output_length = input_length.to(
286
+ torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
287
+ output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
288
+ return output_length.to(torch.int64)
289
+
290
+ def forward(self, coarse_mel, input_length, output_length=None):
291
+ bsz, _, d = coarse_mel.shape
292
+ assert (d == self.vocoder_config.num_mel_bins)
293
+ if output_length is None or not self.training:
294
+ output_length = self.compute_output_length(input_length)
295
+ coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
296
+ coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
297
+ mode='nearest').to(default_dtype)
298
+ refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
299
+ coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
300
+ refined_mel += coarse_mel # residual connection
301
+ sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
302
+ coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
303
+ refined_mel = torch.where(sequence_mask, refined_mel, 0)
304
+ return refined_mel, coarse_mel, output_length
305
+
306
+
307
+ @dataclass
308
+ class OmniAudioDecoderOutput(ModelOutput):
309
+ refined_mel: Optional[torch.FloatTensor] = None
310
+ coarse_mel: Optional[torch.FloatTensor] = None
311
+ mel_length: Optional[torch.Tensor] = None
312
+ hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
313
+ output_length_before_dconv2: Optional[torch.Tensor] = None
314
+
315
+
316
+ class OmniAudioDecoder(nn.Module):
317
+ def __init__(self, config):
318
+ super().__init__()
319
+ self.config = config.audio_config
320
+ self.vocoder_config = config.vocoder_config
321
+ self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
322
+
323
+ self.dconv1 = CasualConvTranspose1d(
324
+ self.config.d_model,
325
+ self.config.d_model,
326
+ self.config.decoder_kernel_size,
327
+ self.config.avg_pooler,
328
+ )
329
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
330
+ # causal transformer layers
331
+ self.layers = nn.ModuleList(
332
+ [OmniWhisperTransformerLayer(
333
+ ACT2FN[self.config.activation_function],
334
+ self.config.d_model,
335
+ self.config.decoder_attention_heads,
336
+ self.config.decoder_ffn_dim,
337
+ True # causal
338
+ ) for _ in range(self.config.decoder_layers)
339
+ ])
340
+ self.layer_norm = nn.LayerNorm(self.config.d_model)
341
+ self.dconv2 = CasualConvTranspose1d(
342
+ self.config.d_model,
343
+ self.vocoder_config.num_mel_bins,
344
+ self.config.decoder_kernel_size,
345
+ self.config.decoder_stride_size
346
+ )
347
+ self.post_net = MelSpecRefineNet(config.audio_config, config.vocoder_config)
348
+ self.gradient_checkpointing = True
349
+
350
+ @torch.no_grad()
351
+ def fake_input(self, device):
352
+ audio_embed = torch.rand([1, 10, self.config.d_model], dtype=torch.float32, device=device)
353
+ input_length = torch.ones([1], dtype=torch.int32, device=device) * 10
354
+ mel_labels_length = self.post_net.compute_output_length(input_length)
355
+ return audio_embed, input_length, None, mel_labels_length
356
+
357
+ def forward(self,
358
+ audio_embed,
359
+ input_length,
360
+ mel_labels=None,
361
+ mel_labels_length=None,
362
+ fake_input=False,
363
+ ):
364
+ if fake_input:
365
+ audio_embed, input_length, mel_labels, mel_labels_length = self.fake_input(self.layer_norm.weight.device)
366
+
367
+ assert (audio_embed.shape[-1] == self.config.d_model)
368
+ audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
369
+ audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
370
+ _, tgt_len, _ = audio_embed.size()
371
+ if tgt_len < self.positional_embedding.shape[0]:
372
+ current_positional_embedding = self.positional_embedding[:tgt_len]
373
+ else:
374
+ current_positional_embedding = self.positional_embedding
375
+ hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
376
+
377
+ # packing hidden states
378
+ attention_mask, _ = get_sequence_mask(hidden_states, output_length)
379
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
380
+
381
+ for idx, encoder_layer in enumerate(self.layers):
382
+ hidden_states = encoder_layer(hidden_states, output_length)
383
+
384
+ hidden_states = self.layer_norm(hidden_states)
385
+ hidden_states_before_dconv2 = hidden_states
386
+ output_length_before_dconv2 = output_length
387
+
388
+ coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
389
+ refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
390
+
391
+ return OmniAudioDecoderOutput(
392
+ refined_mel=refined_mel,
393
+ coarse_mel=coarse_mel,
394
+ mel_length=mel_labels_length,
395
+ hidden_states_before_dconv2=hidden_states_before_dconv2,
396
+ output_length_before_dconv2=output_length_before_dconv2,
397
+ )
398
+
399
+
400
+ class OmniAudioVQBridgeTokenizer(nn.Module):
401
+ def __init__(self, config):
402
+ super().__init__()
403
+ self.config = config.audio_config
404
+ self.gradient_checkpointing = False
405
+ self.intermediate_dim = self.config.d_model * self.config.avg_pooler
406
+ self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
407
+ self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
408
+
409
+ self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
410
+ self.act_fn = ACT2FN['silu']
411
+ self.layer_norm = nn.LayerNorm(self.intermediate_dim)
412
+ self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
413
+
414
+ self.vq_list = nn.ModuleList([])
415
+ for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
416
+ vq_config = copy.deepcopy(self.config.vq_config)
417
+ vq_config.dim = self.intermediate_dim
418
+ vq_config.codebook_size = codebook_size
419
+ self.vq_list.append(VectorQuantize(vq_config))
420
+ for vq_layer in self.vq_list:
421
+ deepspeed.zero.register_external_parameter(self, vq_layer.codebook.embed)
422
+
423
+ def rvq_op(self, inputs, output_length):
424
+ def rvq_layer_op(vq_layer, residual_encoding, output_length):
425
+ q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
426
+ residual_encoding = residual_encoding.float() - q_v_i.float()
427
+ residual_encoding = residual_encoding.to(inputs.dtype)
428
+ return residual_encoding, code_ids_i
429
+
430
+ cmt_loss, residual_encoding = 0, inputs
431
+ code_ids_list = []
432
+ for i, vq_layer in enumerate(self.vq_list):
433
+ residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
434
+ code_ids_list.append(code_ids_i)
435
+ return torch.stack(code_ids_list, -1)
436
+
437
+ def forward(self, x, output_length):
438
+ batch_size, _, _ = x.shape
439
+ output_length = output_length.to(x.device)
440
+
441
+ if x.shape[1] % self.config.avg_pooler != 0:
442
+ x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
443
+ xt = x.permute(0, 2, 1)
444
+ g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl//pooler_size+1, d*2)
445
+ u = self.up_proj(xt).permute(0, 2, 1)
446
+ x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl//pooler_size+1, d*2)
447
+
448
+ c = self.down_proj(self.act_fn(g) * u)
449
+ res = self.layer_norm(c + x)
450
+ valid_mask, _ = get_sequence_mask(res, output_length)
451
+ code_ids = self.rvq_op(res, output_length)
452
+ code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
453
+ return code_ids
454
+
455
+ @torch.no_grad()
456
+ def decode(self, code_ids):
457
+ vq_num = code_ids.shape[-1]
458
+ res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
459
+ decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
460
+ return decoder_emb
461
+
462
+ @torch.no_grad()
463
+ def recover(self, code_ids):
464
+ vq_num = code_ids.shape[-1]
465
+ res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
466
+ return res
467
+
468
+
469
+ class FlowmatchingPrenet(nn.Module):
470
+ def __init__(
471
+ self,
472
+ input_feat_dim,
473
+ out_feat_dim,
474
+ d_model,
475
+ attention_heads,
476
+ ffn_dim,
477
+ nlayers,
478
+ activation_function,
479
+ max_source_positions,
480
+ target_mel_length_scale_ratio,
481
+ ):
482
+ super().__init__()
483
+
484
+ self.d_model = d_model
485
+ self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
486
+ self.gradient_checkpointing = False
487
+
488
+ self.register_buffer(
489
+ "positional_embedding", sinusoids(max_source_positions, d_model)
490
+ )
491
+
492
+ self.in_mlp = nn.Sequential(
493
+ nn.Linear(input_feat_dim, d_model * 4),
494
+ nn.SiLU(),
495
+ nn.Linear(d_model * 4, d_model),
496
+ )
497
+
498
+ self.transformer_layers = nn.ModuleList(
499
+ [
500
+ OmniWhisperTransformerLayer(
501
+ act=ACT2FN[activation_function],
502
+ d_model=d_model,
503
+ encoder_attention_heads=attention_heads,
504
+ encoder_ffn_dim=ffn_dim,
505
+ causal=True, # causal
506
+ ln_type="RMSNorm",
507
+ )
508
+ for _ in range(nlayers)
509
+ ]
510
+ )
511
+
512
+ self.final_norm = RMSNorm(self.d_model)
513
+ self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
514
+
515
+ def compute_output_length(self, input_length):
516
+ output_length = input_length.float() * self.target_mel_length_scale_ratio
517
+ return output_length.to(torch.int64)
518
+
519
+ def forward(self, input_feat, input_length, output_length=None):
520
+ """
521
+ Args:
522
+ input_feat: [B, T, input_feat_dim]
523
+ input_length: [B]
524
+ output_length: [B]
525
+
526
+ """
527
+ if output_length is None or not self.training:
528
+ output_length = self.compute_output_length(input_length)
529
+
530
+ input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
531
+ orig_dtype = input_feat.dtype
532
+
533
+ input_feat = F.interpolate(
534
+ input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
535
+ size=output_length.max(),
536
+ mode="nearest",
537
+ ).to(orig_dtype)
538
+ input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
539
+ hidden_states = self.in_mlp(input_feat)
540
+
541
+ # packing hidden states
542
+ bsz, tgt_len, d_model = hidden_states.shape
543
+ attention_mask, unpacking_index = get_sequence_mask(
544
+ hidden_states, output_length
545
+ )
546
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
547
+ torch.sum(output_length), self.d_model
548
+ )
549
+
550
+ for idx, encoder_layer in enumerate(self.transformer_layers):
551
+ hidden_states = encoder_layer(hidden_states, output_length)
552
+
553
+ # unpacking
554
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
555
+ bsz, tgt_len, d_model
556
+ )
557
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
558
+
559
+ hidden_states = self.final_norm(hidden_states)
560
+ output = self.out_proj(hidden_states)
561
+ return output, output_length
562
+
563
+
564
+ @dataclass
565
+ class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
566
+ flow_matching_mel: Optional[torch.FloatTensor] = None
567
+ flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
568
+
569
+
570
+ class OmniAudioFlowMatchingDecoder(nn.Module):
571
+ def __init__(self, config):
572
+ super().__init__()
573
+ self.config = config.flow_matching_config
574
+ self.in_channels = self.config.in_channels
575
+ self.spk_emb_dim = self.config.spk_emb_dim
576
+ self.diffusion_steps = self.config.diffusion_steps
577
+ self.cal_mel_mae = self.config.cal_mel_mae
578
+ self.forward_step = -1
579
+
580
+ self.prenet = FlowmatchingPrenet(
581
+ input_feat_dim=self.config.prenet_in_dim,
582
+ out_feat_dim=self.config.prenet_out_dim,
583
+ d_model=self.config.prenet_d_model,
584
+ attention_heads=self.config.prenet_attention_heads,
585
+ ffn_dim=self.config.prenet_ffn_dim,
586
+ nlayers=self.config.prenet_nlayers,
587
+ activation_function=self.config.prenet_activation_function,
588
+ max_source_positions=self.config.prenet_max_source_positions,
589
+ target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
590
+ )
591
+
592
+ self.conditional_decoder = ConditionalDecoder(
593
+ in_channels=self.in_channels * 2 + self.spk_emb_dim,
594
+ out_channels=self.in_channels,
595
+ causal=True,
596
+ channels=self.config.channels,
597
+ dropout=self.config.dropout,
598
+ attention_head_dim=self.config.attention_head_dim,
599
+ n_blocks=self.config.n_blocks,
600
+ num_mid_blocks=self.config.num_mid_blocks,
601
+ num_heads=self.config.num_heads,
602
+ act_fn=self.config.act_fn,
603
+ )
604
+
605
+ self.cfm = ConditionalCFM(
606
+ in_channels=self.in_channels,
607
+ cfm_params=self.config.cfm_params,
608
+ n_spks=0,
609
+ spk_emb_dim=self.spk_emb_dim,
610
+ )
611
+
612
+
613
+ def unpack_hidden_states(self, hidden_states, output_length):
614
+ unpacked = unpack_hidden_states(hidden_states, output_length)
615
+ return unpacked, output_length
616
+
617
+ def forward(
618
+ self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
619
+ ):
620
+ """
621
+ :param refined_mel: [bs, max_input_len, mel_bin]
622
+ :param input_length: [batch_size]
623
+ :param refined_mel: [bs, mel_bin, max_input_len]
624
+ :return:
625
+ """
626
+ self.forward_step += 1
627
+
628
+ orig_dtype = refined_mel.dtype
629
+ prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
630
+ prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
631
+
632
+ if self.prenet is not None:
633
+ refined_mel = refined_mel[:, : torch.max(input_length), :]
634
+ if mel_labels_length is None:
635
+ mel_labels_length = self.prenet.compute_output_length(input_length)
636
+ refined_mel, input_length = self.prenet(
637
+ refined_mel, input_length, mel_labels_length
638
+ )
639
+
640
+ float_dtype = refined_mel.dtype
641
+ refined_mel = refined_mel.float()
642
+ input_length = input_length.long()
643
+
644
+ refined_mel = refined_mel[:, : torch.max(input_length), :]
645
+ sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
646
+ refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
647
+ sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
648
+
649
+ fm_mel = self.cfm.forward(
650
+ estimator=self.conditional_decoder,
651
+ mu=refined_mel.to(float_dtype),
652
+ mask=sequence_mask.float(),
653
+ n_timesteps=self.diffusion_steps,
654
+ )
655
+ return OmniAudioFlowMatchingDecoderOutput(
656
+ flow_matching_mel=fm_mel.transpose(1, 2),
657
+ flow_matching_mel_lengths=mel_labels_length,
658
+ )
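
audio_modeling_omni.py repeatedly packs variable-length batches into a flat `(sum(lengths), d)` tensor before the flash-attention layers and unpacks them afterwards via `get_sequence_mask` / `unpack_hidden_states`. A self-contained sketch of that pack/unpack pattern (plain PyTorch only, since importing the module itself pulls in flash_attn and deepspeed):

```python
import torch

def sequence_mask_and_index(lengths, max_len):
    # Boolean mask (bsz, max_len, 1) over valid positions plus flat gather indices,
    # mirroring get_sequence_mask above.
    pos = torch.arange(max_len)
    mask = (pos[None, :] < lengths[:, None]).unsqueeze(-1)
    unpack_idx = torch.cumsum(mask.to(torch.int64).view(-1), dim=0) - 1
    return mask, unpack_idx

lengths = torch.tensor([3, 2])
hidden = torch.arange(8, dtype=torch.float32).view(2, 4, 1)  # (bsz=2, max_len=4, d=1)

mask, unpack_idx = sequence_mask_and_index(lengths, hidden.shape[1])
packed = torch.masked_select(hidden, mask).view(int(lengths.sum()), 1)  # (5, 1): valid frames only
unpacked = torch.index_select(packed, 0, unpack_idx).view(2, 4, 1)      # back to (bsz, max_len, d)
unpacked = torch.where(mask, unpacked, torch.zeros_like(unpacked))      # re-zero the padding
print(packed.shape, unpacked.shape)
```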
config.json ADDED
@@ -0,0 +1,255 @@
+ {
+   "_name_or_path": "_",
+   "architectures": [
+     "OmniForCausalLM"
+   ],
+   "attention_qkv_bias": true,
+   "attention_qkv_pack": true,
+   "audio_config": {
+     "audio_head_transformer_layers": 3,
+     "audio_delim_token_id": 151693,
+     "audio_end_token_id": 151677,
+     "audio_pad_token_id": 151678,
+     "audio_start_token_id": 151676,
+     "audiogen_end_token_id": 151701,
+     "audiogen_start_token_id": 151700,
+     "audiotext_end_token_id": 151698,
+     "audiotext_pad_token_id": 151699,
+     "audiotext_start_token_id": 151697,
+     "avg_pooler": 4,
+     "d_model": 1280,
+     "decoder_attention_heads": 20,
+     "decoder_ffn_dim": 5120,
+     "decoder_kernel_size": 3,
+     "decoder_layers": 8,
+     "decoder_stride_size": 2,
+     "enable": true,
+     "encoder_attention_heads": 20,
+     "encoder_ffn_dim": 5120,
+     "encoder_layers": 32,
+     "hop_length": 160,
+     "kernel_size": 3,
+     "max_audio_seconds": 30,
+     "n_fft": 400,
+     "num_mel_bins": 128,
+     "sampling_rate": 16000,
+     "stride_size": 2,
+     "split_overlap": 0.0,
+     "vq_config": {
+       "enable": true,
+       "codebook_sizes": [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]
+     }
+   },
+   "auto_map": {
+     "AutoConfig": "configuration_omni.OmniConfig",
+     "AutoModelForCausalLM": "modeling_omni.OmniForCausalLM"
+   },
+   "omni_tokenizer_type": "auto",
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "flow_matching_config": {
+     "enable": true,
+     "use_hires_mel": true,
+     "sampling_rate": 24000,
+     "hop_length": 480,
+     "max_audio_seconds": 30,
+     "split_overlap": 0.1,
+     "use_hidden_states_before_dconv2": true,
+     "prenet_in_dim": 1280,
+     "prenet_out_dim": 80,
+     "prenet_d_model": 512,
+     "prenet_attention_heads": 8,
+     "prenet_ffn_dim": 2048,
+     "prenet_nlayers": 12,
+     "prenet_activation_function": "gelu",
+     "prenet_max_source_positions": 5000,
+     "prenet_target_mel_length_scale_ratio": 1.0,
+     "prenet_loss_weight": 1.0,
+     "unet_use_omni_attn": false,
+     "loss_weight": 1.0,
+     "in_channels": 80,
+     "spk_emb_dim": 0,
+     "diffusion_steps": 10,
+     "channels": [256],
+     "dropout": 0.0,
+     "attention_head_dim": 64,
+     "n_blocks": 4,
+     "num_mid_blocks": 12,
+     "num_heads": 8,
+     "act_fn": "gelu",
+     "cal_mel_mae": true,
+     "cfm_params": {
+       "sigma_min": 1e-6,
+       "solver": "euler",
+       "t_scheduler": "cosine",
+       "training_cfg_rate": 0.2,
+       "inference_cfg_rate": 0.7,
+       "reg_loss_type": "l1"
+     }
+   },
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "initializer_range": 0.02,
+   "intermediate_size": 18944,
+   "max_position_embeddings": 8192,
+   "max_window_layers": 28,
+   "model_type": "omni",
+   "multimodal": [
+     "audio",
+     "image",
+     "video",
+     "audiogen"
+   ],
+   "multimodal_special_token_list": [
+     151676,
+     151677,
+     151678,
+     151679,
+     151680,
+     151681,
+     151682,
+     151683,
+     151684,
+     151685,
+     151686,
+     151687,
+     151688,
+     151693,
+     151694,
+     151695,
+     151696,
+     151697,
+     151698,
+     151699,
+     151700,
+     151701
+   ],
+   "num_attention_heads": 28,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "pad_token_id": 0,
+   "position_embedding_type": "rope",
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 131072,
+   "sparse_attention_heads": null,
+   "sparse_attention_layers": [],
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "train_multimodal_special_tokens_only": false,
+   "transformers_version": "4.45.0.dev0",
+   "use_cache": false,
+   "use_norm_head": false,
+   "use_sliding_window": false,
+   "video_config": {
+     "_name_or_path": "",
+     "_attn_implementation": "flash_attention_2",
+     "decode_way": "1fps",
+     "depth": 32,
+     "embed_dim": 1280,
+     "enable": true,
+     "hidden_act": "quick_gelu",
+     "hidden_size": 3584,
+     "image_delimiter_token_id": 151688,
+     "image_end_token_id": 151680,
+     "image_line_token_id": 151682,
+     "image_mean": [
+       0.48145466,
+       0.4578275,
+       0.40821073
+     ],
+     "image_pad_token_id": 151681,
+     "image_size": 224,
+     "image_start_token_id": 151679,
+     "image_std": [
+       0.26862954,
+       0.26130258,
+       0.27577711
+     ],
+     "in_channels": 3,
+     "in_chans": 3,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-05,
+     "max_frame_num": 32,
+     "max_length": 20,
+     "max_pixels": 602112,
+     "merge_size": 2,
+     "min_length": 0,
+     "min_pixels": 3136,
+     "mlp_ratio": 4,
+     "model_type": "clip_vision_model",
+     "num_attention_heads": 12,
+     "num_channels": 3,
+     "num_heads": 16,
+     "num_hidden_layers": 12,
+     "patch_size": 14,
+     "spatial_merge_size": 2,
+     "spatial_patch_size": 14,
+     "temporal_patch_size": 2,
+     "video_end_token_id": 151696,
+     "video_place_token_id": 151694,
+     "video_start_token_id": 151695
+   },
+   "visual_config": {
+     "_name_or_path": "",
+     "_attn_implementation": "flash_attention_2",
+     "depth": 32,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "embed_dim": 1280,
+     "enable": true,
+     "hidden_act": "quick_gelu",
+     "hidden_size": 3584,
+     "image_delimiter_token_id": 151688,
+     "image_end_token_id": 151680,
+     "image_line_token_id": 151682,
+     "image_mean": [
+       0.48145466,
+       0.4578275,
+       0.40821073
+     ],
+     "image_pad_token_id": 151681,
+     "image_size": 224,
+     "image_start_token_id": 151679,
+     "image_std": [
+       0.26862954,
+       0.26130258,
+       0.27577711
+     ],
+     "in_channels": 3,
+     "in_chans": 3,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_pixels": 3211264,
+     "merge_size": 2,
+     "min_length": 0,
+     "min_pixels": 3136,
+     "mlp_ratio": 4,
+     "model_type": "clip_vision_model",
+     "num_attention_heads": 12,
+     "num_channels": 3,
+     "num_heads": 16,
+     "num_hidden_layers": 12,
+     "patch_size": 14,
+     "projection_dim": 512,
+     "spatial_merge_size": 2,
+     "spatial_patch_size": 14,
+     "temporal_patch_size": 2
+   },
+   "vocab_size": 152064,
+   "vocoder_config": {
+     "enable": true,
+     "enable_multi_scale": true,
+     "max_audio_seconds": 30,
+     "sampling_rate": 16000,
+     "hop_length": 256,
+     "split_overlap": 0.0,
+     "n_fft": 1024,
+     "num_mel_bins": 80,
+     "channels": [256, 256, 256, 256, 256]
+   }
+ }
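
As a quick sanity check, the configuration above can be loaded through its `auto_map` entry, mirroring the `__main__` block in configuration_omni.py (the local `./` path is an assumption):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./", trust_remote_code=True)
print(config.model_type)                            # "omni"
print(config.audio_config.d_model)                  # 1280
print(config.flow_matching_config.diffusion_steps)  # 10
```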
configuration_omni.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright 2023 Baichuan Inc. All Rights Reserved.
+
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from transformers import WhisperConfig
+ from transformers import CLIPVisionConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ class OmniConfig(PretrainedConfig):
+     model_type = "omni"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=125696,
+         hidden_size=4096,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=None,
+         sparse_attention_heads=None,
+         sparse_attention_layers=[],
+         head_dim=None,
+         attention_qkv_pack=True,
+         attention_qkv_bias=False,
+         use_norm_head=True,
+         hidden_act="silu",
+         max_position_embeddings=4096,
+         position_embedding_type="rope",
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         audio_config=None,
+         visual_config=None,
+         video_config=None,
+         vocoder_config=None,
+         flow_matching_config=None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
+         self.sparse_attention_heads = sparse_attention_heads
+         self.sparse_attention_layers = sparse_attention_layers
+         self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
+         self.attention_qkv_pack = attention_qkv_pack
+         self.attention_qkv_bias = attention_qkv_bias
+         self.use_norm_head = use_norm_head
+         self.hidden_act = hidden_act
+         self.position_embedding_type = position_embedding_type
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         assert self.position_embedding_type.lower() in ("rope", "alibi")
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+         if audio_config is not None:
+             self.audio_config = WhisperConfig(**audio_config)
+             if self.audio_config.vq_config is not None:
+                 self.audio_config.vq_config = PretrainedConfig(**self.audio_config.vq_config)
+         if vocoder_config is not None:
+             self.vocoder_config = WhisperConfig(**vocoder_config)
+         if flow_matching_config is not None:
+             self.flow_matching_config = PretrainedConfig(**flow_matching_config)
+             self.flow_matching_config.cfm_params = PretrainedConfig(**self.flow_matching_config.cfm_params)
+         if visual_config is not None:
+             self.visual_config = CLIPVisionConfig(**visual_config)
+         if video_config is not None:
+             self.video_config = CLIPVisionConfig(**video_config)
+
+     def to_diff_dict(self):
+         data = super().to_diff_dict()
+         data["model_type"] = self.model_type
+         return data
+
+     def get_rotary_base(self):
+         if hasattr(self, "rotary_emb_base"):
+             return self.rotary_emb_base
+         else:
+             return self.rope_theta
+
+ if __name__ == '__main__':
+     from transformers import AutoConfig
+     config = AutoConfig.from_pretrained("./", trust_remote_code=True)
+     print(config)
flow_matching.py ADDED
@@ -0,0 +1,791 @@
1
+ # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main
2
+ """
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+ from abc import ABC
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Optional
22
+
23
+ import torch.nn as nn
24
+ from einops import pack, rearrange, repeat
25
+ from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
26
+ from .matcha_transformer import BasicTransformerBlock
27
+ from omegaconf import DictConfig
28
+
29
+
30
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
31
+ assert mask.dtype == torch.bool
32
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
33
+ mask = mask.to(dtype)
34
+ # attention mask bias
35
+ # NOTE(Mddct): torch.finfo jit issues
36
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
37
+ mask = (1.0 - mask) * torch.finfo(dtype).min
38
+ return mask
39
+
40
+
41
+ def subsequent_chunk_mask(
42
+ size: int,
43
+ chunk_size: int,
44
+ num_left_chunks: int = -1,
45
+ device: torch.device = torch.device("cpu"),
46
+ ) -> torch.Tensor:
47
+ """Create mask for subsequent steps (size, size) with chunk size,
48
+ this is for streaming encoder
49
+
50
+ Args:
51
+ size (int): size of mask
52
+ chunk_size (int): size of chunk
53
+ num_left_chunks (int): number of left chunks
54
+ <0: use full chunk
55
+ >=0: use num_left_chunks
56
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
57
+
58
+ Returns:
59
+ torch.Tensor: mask
60
+
61
+ Examples:
62
+ >>> subsequent_chunk_mask(4, 2)
63
+ [[1, 1, 0, 0],
64
+ [1, 1, 0, 0],
65
+ [1, 1, 1, 1],
66
+ [1, 1, 1, 1]]
67
+ """
68
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
69
+ # actually this is not needed after we have inference cache implemented, will remove it later
70
+ pos_idx = torch.arange(size, device=device)
71
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
72
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
73
+ return ret
74
+
75
+ def subsequent_mask(
76
+ size: int,
77
+ device: torch.device = torch.device("cpu"),
78
+ ) -> torch.Tensor:
79
+ """Create mask for subsequent steps (size, size).
80
+
81
+ This mask is used only in decoder which works in an auto-regressive mode.
82
+ This means the current step could only do attention with its left steps.
83
+
84
+ In encoder, fully attention is used when streaming is not necessary and
85
+ the sequence is not long. In this case, no attention mask is needed.
86
+
87
+ When streaming is need, chunk-based attention is used in encoder. See
88
+ subsequent_chunk_mask for the chunk-based attention mask.
89
+
90
+ Args:
91
+ size (int): size of mask
92
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
93
+ dtype (torch.device): result dtype
94
+
95
+ Returns:
96
+ torch.Tensor: mask
97
+
98
+ Examples:
99
+ >>> subsequent_mask(3)
100
+ [[1, 0, 0],
101
+ [1, 1, 0],
102
+ [1, 1, 1]]
103
+ """
104
+ arange = torch.arange(size, device=device)
105
+ mask = arange.expand(size, size)
106
+ arange = arange.unsqueeze(-1)
107
+ mask = mask <= arange
108
+ return mask
109
+
110
+
111
+ def add_optional_chunk_mask(xs: torch.Tensor,
112
+ masks: torch.Tensor,
113
+ use_dynamic_chunk: bool,
114
+ use_dynamic_left_chunk: bool,
115
+ decoding_chunk_size: int,
116
+ static_chunk_size: int,
117
+ num_decoding_left_chunks: int,
118
+ enable_full_context: bool = True):
119
+ """ Apply optional mask for encoder.
120
+
121
+ Args:
122
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
123
+ mask (torch.Tensor): mask for xs, (B, 1, L)
124
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
125
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
126
+ training.
127
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
128
+ 0: default for training, use random dynamic chunk.
129
+ <0: for decoding, use full chunk.
130
+ >0: for decoding, use fixed chunk size as set.
131
+ static_chunk_size (int): chunk size for static chunk training/decoding
132
+ if it's greater than 0, if use_dynamic_chunk is true,
133
+ this parameter will be ignored
134
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
135
+ the chunk size is decoding_chunk_size.
136
+ >=0: use num_decoding_left_chunks
137
+ <0: use all left chunks
138
+ enable_full_context (bool):
139
+ True: chunk size is either [1, 25] or full context(max_len)
140
+ False: chunk size ~ U[1, 25]
141
+
142
+ Returns:
143
+ torch.Tensor: chunk mask of the input xs.
144
+ """
145
+ # Whether to use chunk mask or not
146
+ if use_dynamic_chunk:
147
+ max_len = xs.size(1)
148
+ if decoding_chunk_size < 0:
149
+ chunk_size = max_len
150
+ num_left_chunks = -1
151
+ elif decoding_chunk_size > 0:
152
+ chunk_size = decoding_chunk_size
153
+ num_left_chunks = num_decoding_left_chunks
154
+ else:
155
+ # chunk size is either [1, 25] or full context(max_len).
156
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
157
+ # delay, the maximum frame is 100 / 4 = 25.
158
+ chunk_size = torch.randint(1, max_len, (1, )).item()
159
+ num_left_chunks = -1
160
+ if chunk_size > max_len // 2 and enable_full_context:
161
+ chunk_size = max_len
162
+ else:
163
+ chunk_size = chunk_size % 25 + 1
164
+ if use_dynamic_left_chunk:
165
+ max_left_chunks = (max_len - 1) // chunk_size
166
+ num_left_chunks = torch.randint(0, max_left_chunks,
167
+ (1, )).item()
168
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
169
+ num_left_chunks,
170
+ xs.device) # (L, L)
171
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
172
+ chunk_masks = masks & chunk_masks # (B, L, L)
173
+ elif static_chunk_size > 0:
174
+ num_left_chunks = num_decoding_left_chunks
175
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
176
+ num_left_chunks,
177
+ xs.device) # (L, L)
178
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
179
+ chunk_masks = masks & chunk_masks # (B, L, L)
180
+ else:
181
+ chunk_masks = masks
182
+ return chunk_masks
183
+
184
+
185
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
186
+ """Make mask tensor containing indices of padded part.
187
+
188
+ See description of make_non_pad_mask.
189
+
190
+ Args:
191
+ lengths (torch.Tensor): Batch of lengths (B,).
192
+ Returns:
193
+ torch.Tensor: Mask tensor containing indices of padded part.
194
+
195
+ Examples:
196
+ >>> lengths = [5, 3, 2]
197
+ >>> make_pad_mask(lengths)
198
+ masks = [[0, 0, 0, 0 ,0],
199
+ [0, 0, 0, 1, 1],
200
+ [0, 0, 1, 1, 1]]
201
+ """
202
+ batch_size = lengths.size(0)
203
+ max_len = max_len if max_len > 0 else lengths.max().item()
204
+ seq_range = torch.arange(0,
205
+ max_len,
206
+ dtype=torch.int64,
207
+ device=lengths.device)
208
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
209
+ seq_length_expand = lengths.unsqueeze(-1)
210
+ mask = seq_range_expand >= seq_length_expand
211
+ return mask
212
+
213
+ # Causal
214
+ class Transpose(torch.nn.Module):
215
+ def __init__(self, dim0: int, dim1: int):
216
+ super().__init__()
217
+ self.dim0 = dim0
218
+ self.dim1 = dim1
219
+
220
+ def forward(self, x: torch.Tensor):
221
+ x = torch.transpose(x, self.dim0, self.dim1)
222
+ return x
223
+
224
+ class CausalBlock1D(Block1D):
225
+ def __init__(self, dim: int, dim_out: int):
226
+ super(CausalBlock1D, self).__init__(dim, dim_out)
227
+ self.block = torch.nn.Sequential(
228
+ CausalConv1d(dim, dim_out, 3),
229
+ Transpose(1, 2),
230
+ nn.LayerNorm(dim_out),
231
+ Transpose(1, 2),
232
+ nn.Mish(),
233
+ )
234
+
235
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
236
+ output = self.block(x * mask)
237
+ return output * mask
238
+
239
+ class CausalResnetBlock1D(ResnetBlock1D):
240
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
241
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
242
+ self.block1 = CausalBlock1D(dim, dim_out)
243
+ self.block2 = CausalBlock1D(dim_out, dim_out)
244
+
245
+ class CausalConv1d(torch.nn.Conv1d):
246
+ def __init__(
247
+ self,
248
+ in_channels: int,
249
+ out_channels: int,
250
+ kernel_size: int,
251
+ stride: int = 1,
252
+ dilation: int = 1,
253
+ groups: int = 1,
254
+ bias: bool = True,
255
+ padding_mode: str = 'zeros',
256
+ device=None,
257
+ dtype=None
258
+ ) -> None:
259
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
260
+ kernel_size, stride,
261
+ padding=0, dilation=dilation,
262
+ groups=groups, bias=bias,
263
+ padding_mode=padding_mode,
264
+ device=device, dtype=dtype)
265
+ assert stride == 1
266
+ self.causal_padding = (kernel_size - 1, 0)
267
+
268
+ def forward(self, x: torch.Tensor):
269
+ x = F.pad(x, self.causal_padding)
270
+ x = super(CausalConv1d, self).forward(x)
271
+ return x
272
+
273
+
274
+ class BASECFM(torch.nn.Module, ABC):
275
+ def __init__(
276
+ self,
277
+ n_feats,
278
+ cfm_params,
279
+ n_spks=1,
280
+ spk_emb_dim=128,
281
+ ):
282
+ super().__init__()
283
+ self.n_feats = n_feats
284
+ self.n_spks = n_spks
285
+ self.spk_emb_dim = spk_emb_dim
286
+ self.solver = cfm_params.solver
287
+ if hasattr(cfm_params, "sigma_min"):
288
+ self.sigma_min = cfm_params.sigma_min
289
+ else:
290
+ self.sigma_min = 1e-4
291
+
292
+ self.estimator = None
293
+
294
+ @torch.inference_mode()
295
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
296
+ """Forward diffusion
297
+
298
+ Args:
299
+ mu (torch.Tensor): output of encoder
300
+ shape: (batch_size, n_feats, mel_timesteps)
301
+ mask (torch.Tensor): output_mask
302
+ shape: (batch_size, 1, mel_timesteps)
303
+ n_timesteps (int): number of diffusion steps
304
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
305
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
306
+ shape: (batch_size, spk_emb_dim)
307
+ cond: Not used but kept for future purposes
308
+
309
+ Returns:
310
+ sample: generated mel-spectrogram
311
+ shape: (batch_size, n_feats, mel_timesteps)
312
+ """
313
+ z = torch.randn_like(mu) * temperature
314
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
315
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
316
+
317
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
318
+ """
319
+ Fixed euler solver for ODEs.
320
+ Args:
321
+ x (torch.Tensor): random noise
322
+ t_span (torch.Tensor): n_timesteps interpolated
323
+ shape: (n_timesteps + 1,)
324
+ mu (torch.Tensor): output of encoder
325
+ shape: (batch_size, n_feats, mel_timesteps)
326
+ mask (torch.Tensor): output_mask
327
+ shape: (batch_size, 1, mel_timesteps)
328
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
329
+ shape: (batch_size, spk_emb_dim)
330
+ cond: Not used but kept for future purposes
331
+ """
332
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
333
+
334
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
335
+ # Or in future might add like a return_all_steps flag
336
+ sol = []
337
+
338
+ for step in range(1, len(t_span)):
339
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
340
+
341
+ x = x + dt * dphi_dt
342
+ t = t + dt
343
+ sol.append(x)
344
+ if step < len(t_span) - 1:
345
+ dt = t_span[step + 1] - t
346
+
347
+ return sol[-1]
348
+
349
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
350
+ """Computes diffusion loss
351
+
352
+ Args:
353
+ x1 (torch.Tensor): Target
354
+ shape: (batch_size, n_feats, mel_timesteps)
355
+ mask (torch.Tensor): target mask
356
+ shape: (batch_size, 1, mel_timesteps)
357
+ mu (torch.Tensor): output of encoder
358
+ shape: (batch_size, n_feats, mel_timesteps)
359
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
360
+ shape: (batch_size, spk_emb_dim)
361
+
362
+ Returns:
363
+ loss: conditional flow matching loss
364
+ y: conditional flow
365
+ shape: (batch_size, n_feats, mel_timesteps)
366
+ """
367
+ b, _, t = mu.shape
368
+
369
+ # random timestep
370
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
371
+ # sample noise p(x_0)
372
+ z = torch.randn_like(x1)
373
+
374
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
375
+ u = x1 - (1 - self.sigma_min) * z
376
+
377
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
378
+ torch.sum(mask) * u.shape[1]
379
+ )
380
+ return loss, y
381
+
382
+
383
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
384
+ """Make mask tensor containing indices of padded part.
385
+
386
+ See description of make_non_pad_mask.
387
+
388
+ Args:
389
+ lengths (torch.Tensor): Batch of lengths (B,).
390
+ Returns:
391
+ torch.Tensor: Mask tensor containing indices of padded part.
392
+
393
+ Examples:
394
+ >>> lengths = [5, 3, 2]
395
+ >>> make_pad_mask(lengths)
396
+ masks = [[0, 0, 0, 0 ,0],
397
+ [0, 0, 0, 1, 1],
398
+ [0, 0, 1, 1, 1]]
399
+ """
400
+ batch_size = lengths.size(0)
401
+ max_len = max_len if max_len > 0 else lengths.max().item()
402
+ seq_range = torch.arange(0,
403
+ max_len,
404
+ dtype=torch.int64,
405
+ device=lengths.device)
406
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
407
+ seq_length_expand = lengths.unsqueeze(-1)
408
+ mask = seq_range_expand >= seq_length_expand
409
+ return mask
410
+
411
+
412
+ class ConditionalDecoder(nn.Module):
413
+ def __init__(
414
+ self,
415
+ in_channels,
416
+ out_channels,
417
+ causal=False,
418
+ channels=(256, 256),
419
+ dropout=0.05,
420
+ attention_head_dim=64,
421
+ n_blocks=1,
422
+ num_mid_blocks=2,
423
+ num_heads=4,
424
+ act_fn="snake",
425
+ gradient_checkpointing=True,
426
+ ):
427
+ """
428
+ This decoder requires an input with the same shape as the target. So, if your text content
429
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
430
+ """
431
+ super().__init__()
432
+ channels = tuple(channels)
433
+ self.in_channels = in_channels
434
+ self.out_channels = out_channels
435
+ self.causal = causal
436
+ self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
437
+ self.gradient_checkpointing = gradient_checkpointing
438
+
439
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
440
+ time_embed_dim = channels[0] * 4
441
+ self.time_mlp = TimestepEmbedding(
442
+ in_channels=in_channels,
443
+ time_embed_dim=time_embed_dim,
444
+ act_fn="silu",
445
+ )
446
+ self.down_blocks = nn.ModuleList([])
447
+ self.mid_blocks = nn.ModuleList([])
448
+ self.up_blocks = nn.ModuleList([])
449
+
450
+ output_channel = in_channels
451
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
452
+ input_channel = output_channel
453
+ output_channel = channels[i]
454
+ is_last = i == len(channels) - 1
455
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
456
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
457
+ transformer_blocks = nn.ModuleList(
458
+ [
459
+ BasicTransformerBlock(
460
+ dim=output_channel,
461
+ num_attention_heads=num_heads,
462
+ attention_head_dim=attention_head_dim,
463
+ dropout=dropout,
464
+ activation_fn=act_fn,
465
+ )
466
+ for _ in range(n_blocks)
467
+ ]
468
+ )
469
+ downsample = (
470
+ Downsample1D(output_channel) if not is_last else
471
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
472
+ )
473
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
474
+
475
+ for _ in range(num_mid_blocks):
476
+ input_channel = channels[-1]
477
+ out_channels = channels[-1]
478
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
479
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
480
+
481
+ transformer_blocks = nn.ModuleList(
482
+ [
483
+ BasicTransformerBlock(
484
+ dim=output_channel,
485
+ num_attention_heads=num_heads,
486
+ attention_head_dim=attention_head_dim,
487
+ dropout=dropout,
488
+ activation_fn=act_fn,
489
+ )
490
+ for _ in range(n_blocks)
491
+ ]
492
+ )
493
+
494
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
495
+
496
+ channels = channels[::-1] + (channels[0],)
497
+ for i in range(len(channels) - 1):
498
+ input_channel = channels[i] * 2
499
+ output_channel = channels[i + 1]
500
+ is_last = i == len(channels) - 2
501
+ resnet = CausalResnetBlock1D(
502
+ dim=input_channel,
503
+ dim_out=output_channel,
504
+ time_emb_dim=time_embed_dim,
505
+ ) if self.causal else ResnetBlock1D(
506
+ dim=input_channel,
507
+ dim_out=output_channel,
508
+ time_emb_dim=time_embed_dim,
509
+ )
510
+ transformer_blocks = nn.ModuleList(
511
+ [
512
+ BasicTransformerBlock(
513
+ dim=output_channel,
514
+ num_attention_heads=num_heads,
515
+ attention_head_dim=attention_head_dim,
516
+ dropout=dropout,
517
+ activation_fn=act_fn,
518
+ )
519
+ for _ in range(n_blocks)
520
+ ]
521
+ )
522
+ upsample = (
523
+ Upsample1D(output_channel, use_conv_transpose=True)
524
+ if not is_last
525
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
526
+ )
527
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
528
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
529
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
530
+ self.initialize_weights()
531
+
532
+ def initialize_weights(self):
533
+ for m in self.modules():
534
+ if isinstance(m, nn.Conv1d):
535
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
536
+ if m.bias is not None:
537
+ nn.init.constant_(m.bias, 0)
538
+ elif isinstance(m, nn.GroupNorm):
539
+ nn.init.constant_(m.weight, 1)
540
+ nn.init.constant_(m.bias, 0)
541
+ elif isinstance(m, nn.Linear):
542
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
543
+ if m.bias is not None:
544
+ nn.init.constant_(m.bias, 0)
545
+
546
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
547
+ """Forward pass of the UNet1DConditional model.
548
+
549
+ Args:
550
+ x (torch.Tensor): shape (batch_size, in_channels, time)
551
+ mask (_type_): shape (batch_size, 1, time)
552
+ t (_type_): shape (batch_size)
553
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
554
+ cond (_type_, optional): placeholder for future use. Defaults to None.
555
+
556
+ Returns:
557
+ torch.Tensor: decoded features of shape (batch_size, out_channels, time),
558
+ zeroed outside `mask`.
559
+
560
+ """
563
+ t = self.time_embeddings(t)
564
+ t = t.to(x.dtype)
565
+ t = self.time_mlp(t)
566
+ x = pack([x, mu], "b * t")[0]
567
+ mask = mask.to(x.dtype)
568
+ if spks is not None:
569
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
570
+ x = pack([x, spks], "b * t")[0]
571
+ if cond is not None:
572
+ x = pack([x, cond], "b * t")[0]
573
+
574
+ hiddens = []
575
+ masks = [mask]
576
+ for resnet, transformer_blocks, downsample in self.down_blocks:
577
+ mask_down = masks[-1]
578
+ x = resnet(x, mask_down, t)
579
+ x = rearrange(x, "b c t -> b t c").contiguous()
580
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
581
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
582
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
583
+ for transformer_block in transformer_blocks:
584
+ if self.gradient_checkpointing and self.training:
585
+ def create_custom_forward(module):
586
+ def custom_forward(*inputs):
587
+ return module(*inputs)
588
+ return custom_forward
589
+ x = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(transformer_block),
591
+ x,
592
+ attn_mask,
593
+ t,
594
+ )
595
+ else:
596
+ x = transformer_block(
597
+ hidden_states=x,
598
+ attention_mask=attn_mask,
599
+ timestep=t,
600
+ )
601
+ x = rearrange(x, "b t c -> b c t").contiguous()
602
+ hiddens.append(x) # Save hidden states for skip connections
603
+ x = downsample(x * mask_down)
604
+ masks.append(mask_down[:, :, ::2])
605
+ masks = masks[:-1]
606
+ mask_mid = masks[-1]
607
+
608
+ for resnet, transformer_blocks in self.mid_blocks:
609
+ x = resnet(x, mask_mid, t)
610
+ x = rearrange(x, "b c t -> b t c").contiguous()
611
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
612
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
613
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
614
+ for transformer_block in transformer_blocks:
615
+ if self.gradient_checkpointing and self.training:
616
+ def create_custom_forward(module):
617
+ def custom_forward(*inputs):
618
+ return module(*inputs)
619
+ return custom_forward
620
+ x = torch.utils.checkpoint.checkpoint(
621
+ create_custom_forward(transformer_block),
622
+ x,
623
+ attn_mask,
624
+ t,
625
+ )
626
+ else:
627
+ x = transformer_block(
628
+ hidden_states=x,
629
+ attention_mask=attn_mask,
630
+ timestep=t,
631
+ )
632
+ x = rearrange(x, "b t c -> b c t").contiguous()
633
+
634
+ for resnet, transformer_blocks, upsample in self.up_blocks:
635
+ mask_up = masks.pop()
636
+ skip = hiddens.pop()
637
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
638
+ x = resnet(x, mask_up, t)
639
+ x = rearrange(x, "b c t -> b t c").contiguous()
640
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
641
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
642
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
643
+ for transformer_block in transformer_blocks:
644
+ if self.gradient_checkpointing and self.training:
645
+ def create_custom_forward(module):
646
+ def custom_forward(*inputs):
647
+ return module(*inputs)
648
+ return custom_forward
649
+ x = torch.utils.checkpoint.checkpoint(
650
+ create_custom_forward(transformer_block),
651
+ x,
652
+ attn_mask,
653
+ t,
654
+ )
655
+ else:
656
+ x = transformer_block(
657
+ hidden_states=x,
658
+ attention_mask=attn_mask,
659
+ timestep=t,
660
+ )
661
+ x = rearrange(x, "b t c -> b c t").contiguous()
662
+ x = upsample(x * mask_up)
663
+ x = self.final_block(x, mask_up)
664
+ output = self.final_proj(x * mask_up)
665
+ return output * mask
666
+
667
+
668
+ class ConditionalCFM(BASECFM):
669
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
670
+ super().__init__(
671
+ n_feats=in_channels,
672
+ cfm_params=cfm_params,
673
+ n_spks=n_spks,
674
+ spk_emb_dim=spk_emb_dim,
675
+ )
676
+ self.t_scheduler = cfm_params.t_scheduler
677
+ self.training_cfg_rate = cfm_params.training_cfg_rate
678
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
679
+
680
+ @torch.inference_mode()
681
+ def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
682
+ """Forward diffusion
683
+
684
+ Args:
685
+ mu (torch.Tensor): output of encoder
686
+ shape: (batch_size, n_feats, mel_timesteps)
687
+ mask (torch.Tensor): output_mask
688
+ shape: (batch_size, 1, mel_timesteps)
689
+ n_timesteps (int): number of diffusion steps
690
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
691
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
692
+ shape: (batch_size, spk_emb_dim)
693
+ cond: Not used but kept for future purposes
694
+
695
+ Returns:
696
+ sample: generated mel-spectrogram
697
+ shape: (batch_size, n_feats, mel_timesteps)
698
+ """
699
+ z = torch.randn_like(mu) * temperature
700
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
701
+ if self.t_scheduler == 'cosine':
702
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
703
+ return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
704
+
705
+ def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
706
+ """
707
+ Fixed euler solver for ODEs.
708
+ Args:
709
+ x (torch.Tensor): random noise
710
+ t_span (torch.Tensor): n_timesteps interpolated
711
+ shape: (n_timesteps + 1,)
712
+ mu (torch.Tensor): output of encoder
713
+ shape: (batch_size, n_feats, mel_timesteps)
714
+ mask (torch.Tensor): output_mask
715
+ shape: (batch_size, 1, mel_timesteps)
716
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
717
+ shape: (batch_size, spk_emb_dim)
718
+ cond: Not used but kept for future purposes
719
+ """
720
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
721
+
722
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
723
+ # Or, in the future, we might add a return_all_steps flag
724
+ sol = []
725
+
726
+ for step in range(1, len(t_span)):
727
+ dphi_dt = estimator(x, mask, mu, t, spks, cond)
728
+ # Classifier-Free Guidance inference introduced in VoiceBox
729
+ if self.inference_cfg_rate > 0:
730
+ cfg_dphi_dt = estimator(
731
+ x, mask,
732
+ torch.zeros_like(mu), t,
733
+ torch.zeros_like(spks) if spks is not None else None,
734
+ cond=cond
735
+ )
736
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
737
+ self.inference_cfg_rate * cfg_dphi_dt)
738
+ x = x + dt * dphi_dt
739
+ t = t + dt
740
+ sol.append(x)
741
+ if step < len(t_span) - 1:
742
+ dt = t_span[step + 1] - t
743
+
744
+ return sol[-1]
745
+
746
+ def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
747
+ """Computes diffusion loss
748
+
749
+ Args:
750
+ x1 (torch.Tensor): Target
751
+ shape: (batch_size, n_feats, mel_timesteps)
752
+ mask (torch.Tensor): target mask
753
+ shape: (batch_size, 1, mel_timesteps)
754
+ mu (torch.Tensor): output of encoder
755
+ shape: (batch_size, n_feats, mel_timesteps)
756
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
757
+ shape: (batch_size, spk_emb_dim)
758
+
759
+ Returns:
760
+ loss: conditional flow matching loss
761
+ y: conditional flow
762
+ shape: (batch_size, n_feats, mel_timesteps)
763
+ """
764
+ org_dtype = x1.dtype
765
+
766
+ b, _, t = mu.shape
767
+ # random timestep
768
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
769
+ if self.t_scheduler == 'cosine':
770
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
771
+ # sample noise p(x_0)
772
+ z = torch.randn_like(x1)
773
+
774
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
775
+ u = x1 - (1 - self.sigma_min) * z
776
+
777
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
778
+ if self.training_cfg_rate > 0:
779
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
780
+ mu = mu * cfg_mask.view(-1, 1, 1)
781
+ if spks is not None:
782
+ spks = spks * cfg_mask.view(-1, 1)
783
+ if cond is not None:
784
+ cond = cond * cfg_mask.view(-1, 1, 1)
785
+
786
+ pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
787
+ pred = pred.float()
788
+ u = u.float()
789
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
790
+ loss = loss.to(org_dtype)
791
+ return loss, y
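
A minimal usage sketch of the flow-matching pieces above, assuming ConditionalDecoder and ConditionalCFM are importable from this module. The cfm_params values, the 80-bin mel size, the sequence length, and the 10 Euler steps are illustrative assumptions; only the field names come from the code.

import torch
from easydict import EasyDict

# Illustrative hyperparameters; the shipped config may differ.
cfm_params = EasyDict(
    solver="euler",           # read by BASECFM.__init__
    sigma_min=1e-4,           # noise floor used in compute_loss
    t_scheduler="cosine",     # enables the cosine warp of t_span
    training_cfg_rate=0.2,    # condition-dropout probability during training
    inference_cfg_rate=0.7,   # classifier-free guidance strength at inference
)

n_feats = 80                  # assumed mel-bin count
# act_fn="gelu": the default "snake" is not handled by this FeedForward implementation.
estimator = ConditionalDecoder(in_channels=2 * n_feats, out_channels=n_feats,
                               causal=True, act_fn="gelu")
estimator.eval()              # skip the gradient-checkpointing branch at inference
cfm = ConditionalCFM(in_channels=n_feats, cfm_params=cfm_params)

mu = torch.randn(1, n_feats, 200)               # encoder output (B, n_feats, T)
mask = torch.ones(1, 1, 200)                    # all frames valid
mel = cfm(estimator, mu, mask, n_timesteps=10)  # generated mel, (1, 80, 200)
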
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": 151643,
4
+ "max_new_tokens": 2048,
5
+ "transformers_version": "4.45.0.dev0"
6
+ }
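
A quick sketch of how this file is typically consumed through transformers; the checkpoint path is a placeholder.

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("path/to/this-checkpoint")  # placeholder path
assert gen_cfg.bos_token_id == 151643 and gen_cfg.eos_token_id == 151643
print(gen_cfg.max_new_tokens)  # 2048
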
generation_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ from typing import List
2
+ from queue import Queue
3
+
4
+ import torch
5
+
6
+
7
+ def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
8
+ def _parse_messages(messages, split_role="user"):
9
+ system, rounds = "", []
10
+ round = []
11
+ for i, message in enumerate(messages):
12
+ if message["role"] == "system":
13
+ assert i == 0
14
+ system = message["content"]
15
+ continue
16
+ if message["role"] == split_role and round:
17
+ rounds.append(round)
18
+ round = []
19
+ round.append(message)
20
+ if round:
21
+ rounds.append(round)
22
+ return system, rounds
23
+
24
+ max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
25
+ max_input_tokens = model.config.model_max_length - max_new_tokens
26
+ system, rounds = _parse_messages(messages, split_role="user")
27
+ system_tokens = tokenizer.encode(system)
28
+ max_history_tokens = max_input_tokens - len(system_tokens)
29
+
30
+ history_tokens = []
31
+ for round in rounds[::-1]:
32
+ round_tokens = []
33
+ for message in round:
34
+ if message["role"] == "user":
35
+ round_tokens.append(model.generation_config.user_token_id)
36
+ else:
37
+ round_tokens.append(model.generation_config.assistant_token_id)
38
+ round_tokens.extend(tokenizer.encode(message["content"]))
39
+ if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
40
+ history_tokens = round_tokens + history_tokens # concat left
41
+ if len(history_tokens) < max_history_tokens:
42
+ continue
43
+ break
44
+
45
+ input_tokens = system_tokens + history_tokens
46
+ if messages[-1]["role"] != "assistant":
47
+ input_tokens.append(model.generation_config.assistant_token_id)
48
+ input_tokens = input_tokens[-max_input_tokens:] # truncate left
49
+ return torch.LongTensor([input_tokens]).to(model.device)
50
+
51
+
52
+ class TextIterStreamer:
53
+ def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
54
+ self.tokenizer = tokenizer
55
+ self.skip_prompt = skip_prompt
56
+ self.skip_special_tokens = skip_special_tokens
57
+ self.tokens = []
58
+ self.text_queue = Queue()
59
+ self.next_tokens_are_prompt = True
60
+
61
+ def put(self, value):
62
+ if self.skip_prompt and self.next_tokens_are_prompt:
63
+ self.next_tokens_are_prompt = False
64
+ else:
65
+ if len(value.shape) > 1:
66
+ value = value[0]
67
+ self.tokens.extend(value.tolist())
68
+ self.text_queue.put(
69
+ self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
70
+
71
+ def end(self):
72
+ self.text_queue.put(None)
73
+
74
+ def __iter__(self):
75
+ return self
76
+
77
+ def __next__(self):
78
+ value = self.text_queue.get()
79
+ if value is None:
80
+ raise StopIteration()
81
+ else:
82
+ return value
83
+
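
A hedged sketch of how build_chat_input and TextIterStreamer fit together for streaming chat. It assumes `model` and `tokenizer` are already loaded from this repo and that the model's generation_config carries the user/assistant token ids the helper reads; the 512-token budget is illustrative.

from threading import Thread

messages = [{"role": "user", "content": "Hello"}]
input_ids = build_chat_input(model, tokenizer, messages, max_new_tokens=512)
streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

thread = Thread(
    target=model.generate,
    kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=512),
)
thread.start()
for text_so_far in streamer:   # each item is the full decoded reply so far
    print(text_so_far, end="\r")
thread.join()
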
matcha_components.py ADDED
@@ -0,0 +1,189 @@
1
+ # Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
2
+ """
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Shivam Mehta
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import math
27
+ from typing import Optional
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+ from diffusers.models.activations import get_activation
34
+
35
+
36
+ class SinusoidalPosEmb(torch.nn.Module):
37
+ def __init__(self, dim):
38
+ super().__init__()
39
+ self.dim = dim
40
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
41
+
42
+ def forward(self, x, scale=1000):
43
+ if x.ndim < 1:
44
+ x = x.unsqueeze(0)
45
+ device = x.device
46
+ half_dim = self.dim // 2
47
+ emb = math.log(10000) / (half_dim - 1)
48
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
49
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
50
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
51
+ return emb
52
+
53
+
54
+ class Block1D(torch.nn.Module):
55
+ def __init__(self, dim, dim_out, groups=8):
56
+ super().__init__()
57
+ self.block = torch.nn.Sequential(
58
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
59
+ torch.nn.GroupNorm(groups, dim_out),
60
+ nn.Mish(),
61
+ )
62
+
63
+ def forward(self, x, mask):
64
+ output = self.block(x * mask)
65
+ return output * mask
66
+
67
+
68
+ class ResnetBlock1D(torch.nn.Module):
69
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
70
+ super().__init__()
71
+ self.mlp = torch.nn.Sequential(
72
+ nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
73
+ )
74
+
75
+ self.block1 = Block1D(dim, dim_out, groups=groups)
76
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
77
+
78
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
79
+
80
+ def forward(self, x, mask, time_emb):
81
+ h = self.block1(x, mask)
82
+ h += self.mlp(time_emb).unsqueeze(-1)
83
+ h = self.block2(h, mask)
84
+ output = h + self.res_conv(x * mask)
85
+ return output
86
+
87
+
88
+ class Downsample1D(nn.Module):
89
+ def __init__(self, dim):
90
+ super().__init__()
91
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
92
+
93
+ def forward(self, x):
94
+ return self.conv(x)
95
+
96
+
97
+ class TimestepEmbedding(nn.Module):
98
+ def __init__(
99
+ self,
100
+ in_channels: int,
101
+ time_embed_dim: int,
102
+ act_fn: str = "silu",
103
+ out_dim: int = None,
104
+ post_act_fn: Optional[str] = None,
105
+ cond_proj_dim=None,
106
+ ):
107
+ super().__init__()
108
+
109
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
110
+
111
+ if cond_proj_dim is not None:
112
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
113
+ else:
114
+ self.cond_proj = None
115
+
116
+ self.act = get_activation(act_fn)
117
+
118
+ if out_dim is not None:
119
+ time_embed_dim_out = out_dim
120
+ else:
121
+ time_embed_dim_out = time_embed_dim
122
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
123
+
124
+ if post_act_fn is None:
125
+ self.post_act = None
126
+ else:
127
+ self.post_act = get_activation(post_act_fn)
128
+
129
+ def forward(self, sample, condition=None):
130
+ if condition is not None:
131
+ sample = sample + self.cond_proj(condition)
132
+ sample = self.linear_1(sample)
133
+
134
+ if self.act is not None:
135
+ sample = self.act(sample)
136
+
137
+ sample = self.linear_2(sample)
138
+
139
+ if self.post_act is not None:
140
+ sample = self.post_act(sample)
141
+ return sample
142
+
143
+
144
+ class Upsample1D(nn.Module):
145
+ """A 1D upsampling layer with an optional convolution.
146
+
147
+ Parameters:
148
+ channels (`int`):
149
+ number of channels in the inputs and outputs.
150
+ use_conv (`bool`, default `False`):
151
+ option to use a convolution.
152
+ use_conv_transpose (`bool`, default `False`):
153
+ option to use a convolution transpose.
154
+ out_channels (`int`, optional):
155
+ number of output channels. Defaults to `channels`.
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ channels,
161
+ use_conv=False,
162
+ use_conv_transpose=True,
163
+ out_channels=None,
164
+ name="conv",
165
+ ):
166
+ super().__init__()
167
+ self.channels = channels
168
+ self.out_channels = out_channels or channels
169
+ self.use_conv = use_conv
170
+ self.use_conv_transpose = use_conv_transpose
171
+ self.name = name
172
+
173
+ self.conv = None
174
+ if use_conv_transpose:
175
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
176
+ elif use_conv:
177
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
178
+
179
+ def forward(self, inputs):
180
+ assert inputs.shape[1] == self.channels
181
+ if self.use_conv_transpose:
182
+ return self.conv(inputs)
183
+
184
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
185
+
186
+ if self.use_conv:
187
+ outputs = self.conv(outputs)
188
+
189
+ return outputs
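
A small sketch of how the blocks above compose into the timestep-conditioning path used by the decoder; the channel widths and sequence length are illustrative.

import torch

dim = 256
pos_emb = SinusoidalPosEmb(dim)                    # scalar timestep -> (B, dim)
time_mlp = TimestepEmbedding(in_channels=dim, time_embed_dim=4 * dim, act_fn="silu")

t = torch.rand(2)                 # one diffusion timestep per batch element
t_emb = time_mlp(pos_emb(t))      # (2, 1024) conditioning vector

block = ResnetBlock1D(dim=80, dim_out=128, time_emb_dim=4 * dim)
x = torch.randn(2, 80, 100)       # (B, C, T) features
mask = torch.ones(2, 1, 100)
y = block(x, mask, t_emb)         # (2, 128, 100)
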
matcha_feat.py ADDED
@@ -0,0 +1,107 @@
1
+ # Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
2
+ """
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Shivam Mehta
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.utils.data
29
+ from librosa.filters import mel as librosa_mel_fn
30
+ from scipy.io.wavfile import read
31
+
32
+ MAX_WAV_VALUE = 32768.0
33
+
34
+
35
+ def load_wav(full_path):
36
+ sampling_rate, data = read(full_path)
37
+ return data, sampling_rate
38
+
39
+
40
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
41
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
42
+
43
+
44
+ def dynamic_range_decompression(x, C=1):
45
+ return np.exp(x) / C
46
+
47
+
48
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
49
+ return torch.log(torch.clamp(x, min=clip_val) * C)
50
+
51
+
52
+ def dynamic_range_decompression_torch(x, C=1):
53
+ return torch.exp(x) / C
54
+
55
+
56
+ def spectral_normalize_torch(magnitudes):
57
+ output = dynamic_range_compression_torch(magnitudes)
58
+ return output
59
+
60
+
61
+ def spectral_de_normalize_torch(magnitudes):
62
+ output = dynamic_range_decompression_torch(magnitudes)
63
+ return output
64
+
65
+
66
+ mel_basis = {}
67
+ hann_window = {}
68
+
69
+
70
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
71
+ if torch.min(y) < -1.0:
72
+ print("min value is ", torch.min(y))
73
+ if torch.max(y) > 1.0:
74
+ print("max value is ", torch.max(y))
75
+
76
+ global mel_basis, hann_window # pylint: disable=global-statement
77
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
80
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
81
+
82
+ y = torch.nn.functional.pad(
83
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
84
+ )
85
+ y = y.squeeze(1)
86
+
87
+ spec = torch.view_as_real(
88
+ torch.stft(
89
+ y,
90
+ n_fft,
91
+ hop_length=hop_size,
92
+ win_length=win_size,
93
+ window=hann_window[str(y.device)],
94
+ center=center,
95
+ pad_mode="reflect",
96
+ normalized=False,
97
+ onesided=True,
98
+ return_complex=True,
99
+ )
100
+ )
101
+
102
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
103
+
104
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
105
+ spec = spectral_normalize_torch(spec)
106
+
107
+ return spec
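
An illustrative call of the helper above; the STFT/mel settings here are assumptions, not necessarily the values this checkpoint was trained with.

import torch

wav = torch.rand(1, 16000) * 2 - 1        # 1 s of fake audio in [-1, 1], shape (B, T)
mel = mel_spectrogram(
    wav, n_fft=1024, num_mels=80, sampling_rate=16000,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)                          # (1, 80, n_frames)
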
matcha_transformer.py ADDED
@@ -0,0 +1,480 @@
1
+ # Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
2
+ """
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Shivam Mehta
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ from typing import Any, Dict, Optional
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ from diffusers.models.attention import (
31
+ GEGLU,
32
+ GELU,
33
+ AdaLayerNorm,
34
+ AdaLayerNormZero,
35
+ ApproximateGELU,
36
+ )
37
+ from diffusers.models.attention_processor import Attention
38
+ from diffusers.models.lora import LoRACompatibleLinear
39
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
40
+
41
+ import torch.nn.functional as F
42
+ from flash_attn import flash_attn_varlen_func
43
+
44
+
45
+ def get_sequence_mask(inputs, inputs_length):
46
+ if inputs.dim() == 3:
47
+ bsz, tgt_len, _ = inputs.size()
48
+ else:
49
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
50
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
51
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
52
+ bsz, tgt_len, 1
53
+ )
54
+ unpacking_index = (
55
+ torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
56
+ ) # convert the flattened mask into gather indices for unpacking
57
+ return sequence_mask, unpacking_index
58
+
59
+
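
A worked example of the packing helper above, mirroring how BasicTransformerBlock.forward below packs variable-length sequences for flash attention and scatters them back.

import torch

x = torch.randn(2, 3, 4)                   # (B, T=max_len, C)
lengths = torch.tensor([3, 2])
mask, unpack_idx = get_sequence_mask(x, lengths)

packed = torch.masked_select(x, mask).view(-1, 4)          # (5, 4): valid rows only
restored = torch.index_select(packed, 0, unpack_idx).view(2, 3, 4)
restored = torch.where(mask, restored, 0)                  # zero out the padded slot
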
60
+ class OmniWhisperAttention(nn.Module):
61
+ def __init__(self, embed_dim, num_heads, causal=False):
62
+ super().__init__()
63
+ self.embed_dim = embed_dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = embed_dim // num_heads
66
+
67
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
68
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
69
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
70
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
71
+
72
+ self.causal = causal
73
+
74
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
75
+ bsz, _ = hidden_states.size()
76
+
77
+ query_states = self.q_proj(hidden_states).view(
78
+ bsz, self.num_heads, self.head_dim
79
+ )
80
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
81
+ value_states = self.v_proj(hidden_states).view(
82
+ bsz, self.num_heads, self.head_dim
83
+ )
84
+
85
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
86
+ torch.int32
87
+ )
88
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
89
+ attn_output = flash_attn_varlen_func(
90
+ query_states,
91
+ key_states,
92
+ value_states,
93
+ cu_len,
94
+ cu_len,
95
+ max_seqlen,
96
+ max_seqlen,
97
+ causal=self.causal,
98
+ ) # (bsz * qlen, nheads, headdim)
99
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
100
+ attn_output = self.out_proj(attn_output)
101
+ return attn_output
102
+
103
+
104
+ class SnakeBeta(nn.Module):
105
+ """
106
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
107
+ Shape:
108
+ - Input: (B, C, T)
109
+ - Output: (B, C, T), same shape as the input
110
+ Parameters:
111
+ - alpha - trainable parameter that controls frequency
112
+ - beta - trainable parameter that controls magnitude
113
+ References:
114
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
115
+ https://arxiv.org/abs/2006.08195
116
+ Examples:
117
+ >>> a1 = snakebeta(256)
118
+ >>> x = torch.randn(256)
119
+ >>> x = a1(x)
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ in_features,
125
+ out_features,
126
+ alpha=1.0,
127
+ alpha_trainable=True,
128
+ alpha_logscale=True,
129
+ ):
130
+ """
131
+ Initialization.
132
+ INPUT:
133
+ - in_features: shape of the input
134
+ - alpha - trainable parameter that controls frequency
135
+ - beta - trainable parameter that controls magnitude
136
+ alpha is initialized to 1 by default, higher values = higher-frequency.
137
+ beta is initialized to 1 by default, higher values = higher-magnitude.
138
+ alpha will be trained along with the rest of your model.
139
+ """
140
+ super().__init__()
141
+ self.in_features = (
142
+ out_features if isinstance(out_features, list) else [out_features]
143
+ )
144
+ self.proj = LoRACompatibleLinear(in_features, out_features)
145
+
146
+ # initialize alpha
147
+ self.alpha_logscale = alpha_logscale
148
+ if self.alpha_logscale: # log scale alphas initialized to zeros
149
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
150
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
151
+ else: # linear scale alphas initialized to ones
152
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
153
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
154
+
155
+ self.alpha.requires_grad = alpha_trainable
156
+ self.beta.requires_grad = alpha_trainable
157
+
158
+ self.no_div_by_zero = 0.000000001
159
+
160
+ def forward(self, x):
161
+ """
162
+ Forward pass of the function.
163
+ Applies the function to the input elementwise.
164
+ SnakeBeta := x + 1/beta * sin^2(alpha * x)
165
+ """
166
+ x = self.proj(x)
167
+ if self.alpha_logscale:
168
+ alpha = torch.exp(self.alpha)
169
+ beta = torch.exp(self.beta)
170
+ else:
171
+ alpha = self.alpha
172
+ beta = self.beta
173
+
174
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
175
+ torch.sin(x * alpha), 2
176
+ )
177
+
178
+ return x
179
+
180
+
181
+ class FeedForward(nn.Module):
182
+ r"""
183
+ A feed-forward layer.
184
+
185
+ Parameters:
186
+ dim (`int`): The number of channels in the input.
187
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
188
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
189
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
190
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
191
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ dim: int,
197
+ dim_out: Optional[int] = None,
198
+ mult: int = 4,
199
+ dropout: float = 0.0,
200
+ activation_fn: str = "geglu",
201
+ final_dropout: bool = False,
202
+ ):
203
+ super().__init__()
204
+ inner_dim = int(dim * mult)
205
+ dim_out = dim_out if dim_out is not None else dim
206
+
207
+ if activation_fn == "gelu":
208
+ act_fn = GELU(dim, inner_dim)
209
+ if activation_fn == "gelu-approximate":
210
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
211
+ elif activation_fn == "geglu":
212
+ act_fn = GEGLU(dim, inner_dim)
213
+ elif activation_fn == "geglu-approximate":
214
+ act_fn = ApproximateGELU(dim, inner_dim)
215
+ elif activation_fn == "snakebeta":
216
+ act_fn = SnakeBeta(dim, inner_dim)
217
+
218
+ self.net = nn.ModuleList([])
219
+ # project in
220
+ self.net.append(act_fn)
221
+ # project dropout
222
+ self.net.append(nn.Dropout(dropout))
223
+ # project out
224
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
225
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
226
+ if final_dropout:
227
+ self.net.append(nn.Dropout(dropout))
228
+
229
+ def forward(self, hidden_states):
230
+ for module in self.net:
231
+ hidden_states = module(hidden_states)
232
+ return hidden_states
233
+
234
+
235
+ @maybe_allow_in_graph
236
+ class BasicTransformerBlock(nn.Module):
237
+ r"""
238
+ A basic Transformer block.
239
+
240
+ Parameters:
241
+ dim (`int`): The number of channels in the input and output.
242
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
243
+ attention_head_dim (`int`): The number of channels in each head.
244
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
245
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
246
+ only_cross_attention (`bool`, *optional*):
247
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
248
+ double_self_attention (`bool`, *optional*):
249
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
250
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
251
+ num_embeds_ada_norm (:
252
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
253
+ attention_bias (:
254
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ dim: int,
260
+ num_attention_heads: int,
261
+ attention_head_dim: int,
262
+ dropout=0.0,
263
+ cross_attention_dim: Optional[int] = None,
264
+ activation_fn: str = "geglu",
265
+ num_embeds_ada_norm: Optional[int] = None,
266
+ attention_bias: bool = False,
267
+ only_cross_attention: bool = False,
268
+ double_self_attention: bool = False,
269
+ upcast_attention: bool = False,
270
+ norm_elementwise_affine: bool = True,
271
+ norm_type: str = "layer_norm",
272
+ final_dropout: bool = False,
273
+ use_omni_attn: bool = False,
274
+ ):
275
+ super().__init__()
276
+
277
+ self.use_omni_attn = use_omni_attn
278
+ self.dim = dim
279
+
280
+ self.only_cross_attention = only_cross_attention
281
+
282
+ self.use_ada_layer_norm_zero = (
283
+ num_embeds_ada_norm is not None
284
+ ) and norm_type == "ada_norm_zero"
285
+ self.use_ada_layer_norm = (
286
+ num_embeds_ada_norm is not None
287
+ ) and norm_type == "ada_norm"
288
+
289
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
290
+ raise ValueError(
291
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
292
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
293
+ )
294
+
295
+ # Define 3 blocks. Each block has its own normalization layer.
296
+ # 1. Self-Attn
297
+ if self.use_ada_layer_norm:
298
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
299
+ elif self.use_ada_layer_norm_zero:
300
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
301
+ else:
302
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
303
+
304
+ if self.use_omni_attn:
305
+ if only_cross_attention:
306
+ raise NotImplementedError
307
+ print(
308
+ "Use OmniWhisperAttention with flash attention. Dropout is ignored."
309
+ )
310
+ self.attn1 = OmniWhisperAttention(
311
+ embed_dim=dim, num_heads=num_attention_heads, causal=False
312
+ )
313
+ else:
314
+ self.attn1 = Attention(
315
+ query_dim=dim,
316
+ heads=num_attention_heads,
317
+ dim_head=attention_head_dim,
318
+ dropout=dropout,
319
+ bias=attention_bias,
320
+ cross_attention_dim=(
321
+ cross_attention_dim if only_cross_attention else None
322
+ ),
323
+ upcast_attention=upcast_attention,
324
+ )
325
+
326
+ # 2. Cross-Attn
327
+ if cross_attention_dim is not None or double_self_attention:
328
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
329
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
330
+ # the second cross attention block.
331
+ self.norm2 = (
332
+ AdaLayerNorm(dim, num_embeds_ada_norm)
333
+ if self.use_ada_layer_norm
334
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
335
+ )
336
+ self.attn2 = Attention(
337
+ query_dim=dim,
338
+ cross_attention_dim=(
339
+ cross_attention_dim if not double_self_attention else None
340
+ ),
341
+ heads=num_attention_heads,
342
+ dim_head=attention_head_dim,
343
+ dropout=dropout,
344
+ bias=attention_bias,
345
+ upcast_attention=upcast_attention,
346
+ # scale_qk=False, # uncomment this to not to use flash attention
347
+ ) # is self-attn if encoder_hidden_states is none
348
+ else:
349
+ self.norm2 = None
350
+ self.attn2 = None
351
+
352
+ # 3. Feed-forward
353
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
354
+ self.ff = FeedForward(
355
+ dim,
356
+ dropout=dropout,
357
+ activation_fn=activation_fn,
358
+ final_dropout=final_dropout,
359
+ )
360
+
361
+ # let chunk size default to None
362
+ self._chunk_size = None
363
+ self._chunk_dim = 0
364
+
365
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
366
+ # Sets chunk feed-forward
367
+ self._chunk_size = chunk_size
368
+ self._chunk_dim = dim
369
+
370
+ def forward(
371
+ self,
372
+ hidden_states: torch.FloatTensor,
373
+ attention_mask: Optional[torch.FloatTensor] = None,
374
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
375
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
376
+ timestep: Optional[torch.LongTensor] = None,
377
+ cross_attention_kwargs: Dict[str, Any] = None,
378
+ class_labels: Optional[torch.LongTensor] = None,
379
+ ):
380
+
381
+ bsz, tgt_len, d_model = hidden_states.shape
382
+
383
+ # Notice that normalization is always applied before the real computation in the following blocks.
384
+ # 1. Self-Attention
385
+ if self.use_ada_layer_norm:
386
+ norm_hidden_states = self.norm1(hidden_states, timestep)
387
+ elif self.use_ada_layer_norm_zero:
388
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
389
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
390
+ )
391
+ else:
392
+ norm_hidden_states = self.norm1(hidden_states)
393
+
394
+ cross_attention_kwargs = (
395
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
396
+ )
397
+
398
+ if self.use_omni_attn:
399
+ seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
400
+ var_len_attention_mask, unpacking_index = get_sequence_mask(
401
+ norm_hidden_states, seq_len
402
+ )
403
+ norm_hidden_states = torch.masked_select(
404
+ norm_hidden_states, var_len_attention_mask
405
+ )
406
+ norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
407
+ attn_output = self.attn1(norm_hidden_states, seq_len)
408
+ # unpacking
409
+ attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
410
+ bsz, tgt_len, d_model
411
+ )
412
+ attn_output = torch.where(var_len_attention_mask, attn_output, 0)
413
+ else:
414
+ attn_output = self.attn1(
415
+ norm_hidden_states,
416
+ encoder_hidden_states=(
417
+ encoder_hidden_states if self.only_cross_attention else None
418
+ ),
419
+ attention_mask=(
420
+ encoder_attention_mask
421
+ if self.only_cross_attention
422
+ else attention_mask
423
+ ),
424
+ **cross_attention_kwargs,
425
+ )
426
+
427
+ if self.use_ada_layer_norm_zero:
428
+ attn_output = gate_msa.unsqueeze(1) * attn_output
429
+ hidden_states = attn_output + hidden_states
430
+
431
+ # 2. Cross-Attention
432
+ if self.attn2 is not None:
433
+ norm_hidden_states = (
434
+ self.norm2(hidden_states, timestep)
435
+ if self.use_ada_layer_norm
436
+ else self.norm2(hidden_states)
437
+ )
438
+
439
+ attn_output = self.attn2(
440
+ norm_hidden_states,
441
+ encoder_hidden_states=encoder_hidden_states,
442
+ attention_mask=encoder_attention_mask,
443
+ **cross_attention_kwargs,
444
+ )
445
+ hidden_states = attn_output + hidden_states
446
+
447
+ # 3. Feed-forward
448
+ norm_hidden_states = self.norm3(hidden_states)
449
+
450
+ if self.use_ada_layer_norm_zero:
451
+ norm_hidden_states = (
452
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
453
+ )
454
+
455
+ if self._chunk_size is not None:
456
+ # "feed_forward_chunk_size" can be used to save memory
457
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
458
+ raise ValueError(
459
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
460
+ )
461
+
462
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
463
+ ff_output = torch.cat(
464
+ [
465
+ self.ff(hid_slice)
466
+ for hid_slice in norm_hidden_states.chunk(
467
+ num_chunks, dim=self._chunk_dim
468
+ )
469
+ ],
470
+ dim=self._chunk_dim,
471
+ )
472
+ else:
473
+ ff_output = self.ff(norm_hidden_states)
474
+
475
+ if self.use_ada_layer_norm_zero:
476
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
477
+
478
+ hidden_states = ff_output + hidden_states
479
+
480
+ return hidden_states
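
A standalone sketch of one block as the flow-matching decoder drives it; the sizes are illustrative, and the all-zero additive mask simply lets every position attend everywhere.

import torch

block = BasicTransformerBlock(
    dim=256, num_attention_heads=4, attention_head_dim=64,
    dropout=0.05, activation_fn="snakebeta",
)

x = torch.randn(2, 50, 256)           # (B, T, C) hidden states
attn_mask = torch.zeros(2, 50, 50)    # additive attention bias, 0 = no masking
out = block(hidden_states=x, attention_mask=attn_mask, timestep=None)
print(out.shape)                      # (2, 50, 256)
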
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db38ae1d3c918180550c64e1105decf6cd9dfe71e92c800d153788d689393035
3
+ size 4966661936
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f324edf3f066581664bceeb992dccb9e1b8134ad16d79e9df8383cd865d9a680
3
+ size 4991490856
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3fa8dbe0b7d76565c529f6f15ac0965ab1bd9c07653b3c77f868ac160d40e7d
3
+ size 4932746512
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c4212ed32a9bb3ad957a4aaea33e8f8a840ea762c99876984297499e42505e9
3
+ size 4987094096
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fc6ca933cdfb81e1e8ee1cc5dfc1f3885243041faf3a11e3146bdae04ef93e1
3
+ size 2565519592
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
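
A hypothetical loading sketch: the index JSON maps each weight name to one of the five shards above, so a single from_pretrained call is enough. The local path is a placeholder, and trust_remote_code is assumed to be required because the repo ships custom modeling code.

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "path/to/this-checkpoint"   # placeholder; point at the downloaded repo directory
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
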
 
modeling_omni.py ADDED
@@ -0,0 +1,1011 @@
1
+ # Copyright 2023 Baichuan Inc. All Rights Reserved.
2
+ #
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """ PyTorch omni model."""
22
+ import os
23
+ import time
24
+ import json
25
+ import math
26
+ import numpy as np
27
+ from typing import List, Optional, Tuple, Union, Any
28
+ from threading import Thread
29
+ from easydict import EasyDict
30
+
31
+ import torch
32
+ import torch.distributed
33
+ import torch.utils.checkpoint
34
+ from torch import nn
35
+ from torch.nn import CrossEntropyLoss
36
+ from torch.nn import functional as F
37
+ import torch.distributed as dist
38
+ from transformers import PreTrainedModel
39
+ from transformers.activations import ACT2FN
40
+ from dataclasses import dataclass
41
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
42
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
43
+ from transformers.generation.utils import GenerationConfig
44
+ from transformers.utils import logging
45
+ # imported so that dynamic module loading picks these up; not used directly in this file
46
+ from .vector_quantize import VectorQuantize, EuclideanCodebook
47
+ from .matcha_components import (
48
+ SinusoidalPosEmb,
49
+ Block1D,
50
+ ResnetBlock1D,
51
+ Downsample1D,
52
+ TimestepEmbedding,
53
+ Upsample1D,
54
+ )
55
+ from .matcha_transformer import BasicTransformerBlock
56
+ from .flow_matching import ConditionalDecoder, ConditionalCFM
57
+
58
+ from .configuration_omni import OmniConfig
59
+ from .audio_modeling_omni import (RMSNorm,
60
+ OmniAudioEncoder,
61
+ OmniAudioDecoder,
62
+ OmniAudioVQBridgeTokenizer,
63
+ OmniAudioFlowMatchingDecoder)
64
+ from .visual_modeling_omni import OmniVisualEncoder, OmniVisualBridge
65
+ from .processor_omni import OmniMMProcessor
66
+
67
+ # support model path contain point(.)
68
+ try:
69
+ # step1: copy relative imports to transformers_modules
70
+ from .generation_utils import build_chat_input, TextIterStreamer
71
+ from .sequence_parallel_utils import (
72
+ create_attention_layer,
73
+ get_sequence_parallel_size,
74
+ get_sequence_parallel_chunk,
75
+ )
76
+ except ModuleNotFoundError:
77
+ # step2: direct import from transformers_modules
78
+ try: # bypass check_imports failure
79
+ import sys
80
+ sys.path.append(os.path.dirname(__file__))
81
+ from generation_utils import build_chat_input, TextIterStreamer
82
+ from sequence_parallel_utils import (
83
+ create_attention_layer,
84
+ get_sequence_parallel_size,
85
+ get_sequence_parallel_chunk,
86
+ )
87
+ except Exception:
88
+ raise
89
+
90
+ logger = logging.get_logger(__name__)
91
+
92
+ def get_slopes(n):
93
+ def get_slopes_power_of_2(n):
94
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
95
+ ratio = start
96
+ return [start * ratio ** i for i in range(n)]
97
+
98
+ if math.log2(n).is_integer():
99
+ return get_slopes_power_of_2(
100
+ n) # In the paper, we only train models that have 2^a heads for some a. This function has
101
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
102
+ closest_power_of_2 = 2 ** math.floor(
103
+ math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
104
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
105
+ :n - closest_power_of_2]
106
+
107
+
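
A quick check of the ALiBi slope helper above: for 8 heads it yields the geometric sequence 1/2, 1/4, ..., 1/256.

slopes = get_slopes(8)
print(slopes)  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
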
108
+ class RotaryEmbedding(torch.nn.Module):
109
+ def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
110
+ super().__init__()
111
+ # Fix the RoPE initialization precision issue: https://zhuanlan.zhihu.com/p/678963442
112
+ # DeepSpeed hacks torch.arange and forces it onto the GPU, so use the native torch.arange here
113
+ try:
114
+ import deepspeed
115
+ self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
116
+ except Exception:
117
+ self.arange = torch.arange
118
+
119
+ self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
120
+ self.max_seq_len_cached = max_position_embeddings
121
+ t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
122
+ freqs = torch.outer(t, self.inv_freq)
123
+ emb = torch.cat((freqs, freqs), dim=-1)
124
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
125
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
126
+
127
+ def forward(self, x, seq_len=None):
128
+ # x: [bs, num_attention_heads, seq_len, head_size]
129
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
130
+ if seq_len > self.max_seq_len_cached:
131
+ self.max_seq_len_cached = seq_len
132
+ t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
133
+ freqs = torch.outer(t, self.inv_freq)
134
+ emb = torch.cat((freqs, freqs), dim=-1)
135
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
136
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
137
+ return (
138
+ self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
139
+ self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
140
+ )
141
+
142
+
143
+ def rotate_half(x):
144
+ """Rotates half the hidden dims of the input."""
145
+ x1 = x[..., : x.shape[-1] // 2]
146
+ x2 = x[..., x.shape[-1] // 2:]
147
+ return torch.cat((-x2, x1), dim=-1)
148
+
149
+
150
+ def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
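+ # standard RoPE application: gather cos/sin at position_ids and rotate q/k in float32 before casting back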
151
+ cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
152
+ sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
153
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
154
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
155
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
156
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
157
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
158
+
159
+
160
+ class MLP(nn.Module):
161
+ def __init__(
162
+ self,
163
+ hidden_size: int,
164
+ intermediate_size: int,
165
+ hidden_act: str,
166
+ ):
167
+ super().__init__()
168
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
169
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
170
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
171
+ self.act_fn = ACT2FN[hidden_act]
172
+
173
+ def forward(self, x):
174
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
175
+
176
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
177
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
178
+ """
179
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
180
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
181
+ """
182
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
183
+ if n_rep == 1:
184
+ return hidden_states
185
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
186
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
187
+
188
+
189
+ class Attention(nn.Module):
190
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
191
+ def __init__(self, config: OmniConfig, is_sparse=False):
192
+ super().__init__()
193
+ self.config = config
194
+ self.position_embedding_type = config.position_embedding_type.lower()
195
+ self.num_kv_heads = config.num_key_value_heads
196
+ self.head_dim = config.head_dim
197
+ self.hidden_size = config.num_attention_heads * self.head_dim
198
+ self.hidden_kv_size = self.num_kv_heads * self.head_dim
199
+
200
+ if is_sparse:
201
+ self.num_heads = config.sparse_attention_heads
202
+ assert self.num_kv_heads == config.num_attention_heads
203
+ self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
204
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
205
+ else:
206
+ self.num_heads = config.num_attention_heads
207
+ if self.config.attention_qkv_pack:
208
+ self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
209
+ else:
210
+ self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
211
+ self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
212
+ self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
213
+
214
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
215
+
216
+ if self.position_embedding_type == 'rope':
217
+ self.rotary_emb = RotaryEmbedding(
218
+ dim=self.head_dim,
219
+ max_position_embeddings=config.max_position_embeddings,
220
+ base=config.get_rotary_base()
221
+ )
222
+ elif self.position_embedding_type == 'alibi':
223
+ self.alibi_slopes = get_slopes(self.num_heads)
224
+ self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)
225
+
226
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
227
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
228
+
229
+ def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
230
+ assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
231
+ return repeat_kv(hidden_states, num_heads // hidden_states.size(1))
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ seqlens: Optional[torch.IntTensor] = None,
239
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
240
+ output_attentions: bool = False,
241
+ use_cache: bool = False,
242
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
243
+ bsz, q_len = hidden_states.shape[:2]
244
+
245
+ if self.config.attention_qkv_pack:
246
+ proj = self.W_pack(hidden_states)
247
+ query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
248
+ else:
249
+ query_states = self.q_proj(hidden_states)
250
+ key_states = self.k_proj(hidden_states)
251
+ value_states = self.v_proj(hidden_states)
252
+
253
+ # (B, S, hidden_size) -> (B, num_heads, S, head_size)
254
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
255
+ # (B, S, hidden_size) -> (B, num_kv_heads, S, head_size)
256
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
257
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
258
+
259
+ kv_seq_len = key_states.shape[-2]
260
+ if past_key_value is not None:
261
+ kv_seq_len += past_key_value[0].shape[-2]
262
+ if self.position_embedding_type == 'rope':
263
+ max_position = position_ids.max().item()+1 if position_ids is not None else kv_seq_len * get_sequence_parallel_size()
264
+ cos, sin = self.rotary_emb(value_states, seq_len=max_position)
265
+ query_states, key_states = apply_rotary_pos_emb(
266
+ query_states, key_states, cos, sin,
267
+ get_sequence_parallel_chunk(position_ids)
268
+ )
269
+
270
+ if past_key_value is not None:
271
+ # reuse k, v, self_attention
272
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
273
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
274
+ past_key_value = (key_states, value_states) if use_cache else None
275
+
276
+ # repeat k/v heads if n_kv_heads < n_heads
277
+ key_states = self._repeat_kv(key_states, query_states.size(1))
278
+ value_states = self._repeat_kv(value_states, query_states.size(1))
279
+
280
+ if seqlens is not None:
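+ # packed (varlen) path: `seqlens` holds cumulative sequence lengths so flash attention can run over concatenated samples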
281
+ seqlens = seqlens.to(dtype=torch.int32)
282
+ max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
283
+ if self.position_embedding_type == 'alibi':
284
+ alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
285
+ else:
286
+ alibi_slopes = None
287
+ attn_output = self.attention(
288
+ query_states, key_states, value_states, seqlens, seqlens,
289
+ max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
290
+ else:
291
+ attn_output = self.attention(
292
+ query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)
293
+
294
+ attn_output = attn_output.reshape(bsz, q_len, -1)
295
+ attn_output = self.o_proj(attn_output)
296
+
297
+ return attn_output, None, past_key_value
298
+
299
+
300
+ class DecoderLayer(nn.Module):
301
+ def __init__(self, config: OmniConfig, is_sparse=False):
302
+ super().__init__()
303
+ self.hidden_size = config.hidden_size
304
+ self.self_attn = Attention(config=config, is_sparse=is_sparse)
305
+ self.mlp = MLP(
306
+ hidden_size=self.hidden_size,
307
+ intermediate_size=config.intermediate_size,
308
+ hidden_act=config.hidden_act,
309
+ )
310
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
311
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
312
+
313
+ def forward(
314
+ self,
315
+ hidden_states: torch.Tensor,
316
+ attention_mask: Optional[torch.Tensor] = None,
317
+ position_ids: Optional[torch.LongTensor] = None,
318
+ seqlens: Optional[torch.IntTensor] = None,
319
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
320
+ output_attentions: Optional[bool] = False,
321
+ use_cache: Optional[bool] = False,
322
+ group_index=None,
323
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
324
+
325
+ residual = hidden_states
326
+
327
+ hidden_states = self.input_layernorm(hidden_states)
328
+
329
+ # Self Attention
330
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
331
+ hidden_states=hidden_states,
332
+ attention_mask=attention_mask,
333
+ position_ids=position_ids,
334
+ seqlens=seqlens,
335
+ past_key_value=past_key_value,
336
+ output_attentions=output_attentions,
337
+ use_cache=use_cache,
338
+ )
339
+ hidden_states = residual + hidden_states
340
+
341
+ # Fully Connected
342
+ residual = hidden_states
343
+ hidden_states = self.post_attention_layernorm(hidden_states)
344
+ hidden_states = self.mlp(hidden_states)
345
+ hidden_states = residual + hidden_states
346
+
347
+ outputs = (hidden_states,)
348
+
349
+ if output_attentions:
350
+ outputs += (self_attn_weights,)
351
+
352
+ if use_cache:
353
+ outputs += (present_key_value,)
354
+
355
+ return outputs
356
+
357
+
358
+ class OmniPreTrainedModel(PreTrainedModel):
359
+ config_class = OmniConfig
360
+ base_model_prefix = "model"
361
+ supports_gradient_checkpointing = True
362
+ _no_split_modules = ["DecoderLayer"]
363
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
364
+
365
+ def _init_weights(self, module):
366
+ std = self.config.initializer_range
367
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
368
+ module.weight.data.normal_(mean=0.0, std=std)
369
+ if module.bias is not None:
370
+ module.bias.data.zero_()
371
+ elif isinstance(module, nn.Embedding):
372
+ module.weight.data.normal_(mean=0.0, std=std)
373
+ if module.padding_idx is not None:
374
+ module.weight.data[module.padding_idx].zero_()
375
+ elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.GroupNorm):
376
+ module.weight.data.fill_(1.0)
377
+ module.bias.data.zero_()
378
+ elif isinstance(module, RMSNorm):
379
+ module.weight.data.fill_(1.0)
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, OmniModel):
383
+ module.gradient_checkpointing = value
384
+
385
+ @dataclass
386
+ class OmniModelOutputWithPast(BaseModelOutputWithPast):
387
+ audio_encoder_ret: Optional[Any] = None
388
+ audio_decoder_ret: Optional[Any] = None
389
+
390
+ class OmniModel(OmniPreTrainedModel):
391
+ def __init__(self, config: OmniConfig):
392
+ super().__init__(config)
393
+ self.padding_idx = config.pad_token_id
394
+ self.vocab_size = config.vocab_size
395
+
396
+ if config.visual_config.enable:
397
+ self.visual_model = OmniVisualEncoder(config.visual_config)
398
+ self.visual_bridge_model = OmniVisualBridge(config.visual_config)
399
+ if config.video_config.enable and not config.visual_config.enable: # in case only video_config is present without visual_config
400
+ self.visual_model = OmniVisualEncoder(config.video_config)
401
+ self.visual_bridge_model = OmniVisualBridge(config.video_config)
402
+
403
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
404
+ self.layers = nn.ModuleList([
405
+ DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
406
+ for layer_idx in range(config.num_hidden_layers)
407
+ ])
408
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
409
+
410
+ self.audio_embed_layers = nn.ModuleList([
411
+ nn.Embedding(codedim + 1, config.hidden_size)
412
+ for i, codedim in enumerate(config.audio_config.vq_config.codebook_sizes)
413
+ ])
414
+
415
+ self.gradient_checkpointing = True
416
+ # Initialize weights and apply final processing
417
+ self.post_init()
418
+
419
+ def get_input_embeddings(self):
420
+ return self.embed_tokens
421
+
422
+ def set_input_embeddings(self, value):
423
+ self.embed_tokens = value
424
+
425
+ @torch.no_grad()
426
+ def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
427
+ '''
428
+ Build the special masks for any modality, consisting of:
429
+ 1. pad mask: the token positions reserved in the text for the image/audio/video modalities
430
+ 2. special token mask: special tokens such as <start>/<end> that the understanding model should not learn via next-token prediction
431
+ 3. embedding mask / lm_head mask: marks the special tokens inside the embedding / lm_head weights
432
+ '''
433
+ pad_mask = torch.eq(input_ids, pad_token_id)
434
+ sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
435
+ lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
436
+ for sp_id in special_token_list:
437
+ sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
438
+ lm_head_mask[sp_id, 0] = True
439
+ return pad_mask, sp_mask, lm_head_mask
440
+
441
+ def get_multimodal_embed(
442
+ self,
443
+ input_ids,
444
+ text_embedding, # either self.embed_tokens(input_ids) or the output of a previous modality merge
445
+ multimodal_embed,
446
+ pad_token_id,
447
+ fake_input,
448
+ group_index=None, # index identifying each modality group
449
+ ):
450
+ pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, pad_token_id, self.config.multimodal_special_token_list)
451
+ if not self.training: # inference supports auto device map: move the multimodal output onto the same device as input_ids
452
+ multimodal_embed = multimodal_embed.to(input_ids.device)
453
+ if not fake_input: # check that the number of multimodal tokens matches the pad mask (incorrect truncation breaks this)
454
+ assert pad_mask.sum() == multimodal_embed.shape[0]
455
+ else:
456
+ assert pad_mask.sum() <= 0
457
+
458
+ # merge the current modality embeddings with the text embeddings
459
+ input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
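+ # the cumulative sum over the pad mask rewrites the k-th pad position to index k-1, so the gather below picks the k-th multimodal embedding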
460
+ text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # zero out the pad-token positions
461
+ multimodal_embedding = torch.embedding(multimodal_embed, input_ids * pad_mask) # non-pad positions get the embedding at index 0
462
+ multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # zero out the non-pad positions
463
+ final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
464
+
465
+ if group_index is None:
466
+ group_index = pad_mask.to(torch.int32)
467
+ else:
468
+ current_index = torch.max(group_index) + 1
469
+ group_index += pad_mask.to(torch.int32) * current_index # assumes modalities never overlap
470
+
471
+ return final_embedding, group_index
472
+
473
+ def get_visual_embed(
474
+ self,
475
+ input_ids,
476
+ text_embedding, # either self.embed_tokens(input_ids) or the output of a previous modality merge
477
+ images = None,
478
+ patch_nums = None,
479
+ images_grid = None,
480
+ videos = None,
481
+ videos_patch_nums = None,
482
+ videos_grid = None,
483
+ group_index = None, # index identifying each modality group
484
+ ):
485
+ if images is None or len(images) <= 0:
486
+ images, images_grid, patch_nums = self.visual_model.fake_input(input_ids.device)
487
+ image_fake_input = True
488
+ else:
489
+ image_fake_input = False
490
+
491
+ if videos is None or len(videos) <= 0 :
492
+ videos, videos_grid, videos_patch_nums = self.visual_model.fake_input(input_ids.device)
493
+ video_fake_input = True
494
+ else:
495
+ video_fake_input = False
496
+
497
+ visual_input = images + videos
498
+ visual_grid = images_grid + videos_grid
499
+
500
+ visual_input = torch.cat(visual_input, dim=0)
501
+ visual_grid = torch.tensor(np.array(visual_grid))
502
+
503
+ visual_embed = self.visual_model(visual_input, grid_thw=visual_grid)
504
+ visual_embed = self.visual_bridge_model(visual_embed)
505
+
506
+ assert sum(patch_nums) + sum(videos_patch_nums) == visual_embed.shape[0]
507
+ images_embed = visual_embed[:sum(patch_nums)]
508
+ videos_embed = visual_embed[sum(patch_nums):]
509
+
510
+ final_embedding, group_index = self.get_multimodal_embed(input_ids, text_embedding, images_embed, self.config.visual_config.image_pad_token_id, image_fake_input, group_index=group_index)
511
+ final_embedding, group_index = self.get_multimodal_embed(input_ids, final_embedding, videos_embed, self.config.video_config.video_place_token_id, video_fake_input, group_index=group_index)
512
+ return final_embedding, group_index
513
+
514
+
515
+ @torch.no_grad()
516
+ def audio_fake_input(self, device):
517
+ return torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=device)
518
+
519
+ def forward(
520
+ self,
521
+ input_ids: torch.LongTensor = None,
522
+ attention_mask: Optional[torch.Tensor] = None,
523
+ position_ids: Optional[torch.LongTensor] = None,
524
+ seqlens: Optional[torch.IntTensor] = None,
525
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
526
+ inputs_embeds: Optional[torch.FloatTensor] = None,
527
+ audios_tokens: Optional[List|torch.Tensor] = None, # audio tokens, shape bs*seqlen*vq_num
528
+ images: Optional[List|torch.Tensor] = None,
529
+ patch_nums: Optional[torch.Tensor] = None,
530
+ images_grid: Optional[List|torch.Tensor] = None,
531
+ videos: Optional[List|torch.Tensor] = None,
532
+ videos_patch_nums: Optional[torch.Tensor] = None,
533
+ videos_grid: Optional[List|torch.Tensor] = None,
534
+ use_cache: Optional[bool] = None,
535
+ output_attentions: Optional[bool] = None,
536
+ output_hidden_states: Optional[bool] = None,
537
+ return_dict: Optional[bool] = None,
538
+ ) -> Union[Tuple, OmniModelOutputWithPast]:
539
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
540
+ output_hidden_states = (
541
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
542
+ )
543
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
544
+ return_dict = True if (return_dict is not None or self.training) else self.config.use_return_dict
545
+
546
+ # retrieve input_ids and inputs_embeds
547
+ if input_ids is not None and inputs_embeds is not None:
548
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
549
+ elif input_ids is not None:
550
+ batch_size, seq_length = input_ids.shape
551
+ elif inputs_embeds is not None:
552
+ batch_size, seq_length, _ = inputs_embeds.shape
553
+ else:
554
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
555
+
556
+ seq_length_with_past = seq_length
557
+ past_key_values_length = 0
558
+
559
+ if past_key_values is not None:
560
+ past_key_values_length = past_key_values[0][0].shape[2]
561
+ seq_length_with_past = seq_length_with_past + past_key_values_length
562
+
563
+ if position_ids is None:
564
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
565
+ position_ids = torch.arange(
566
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
567
+ )
568
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
569
+ else:
570
+ position_ids = position_ids.view(-1, seq_length).long()
571
+
572
+ group_index, audio_decoder_ret = None, None
573
+ if inputs_embeds is None:
574
+ sp_input_ids = get_sequence_parallel_chunk(input_ids)
575
+ inputs_embeds = self.embed_tokens(sp_input_ids)
576
+ if audios_tokens is None or len(audios_tokens) <= 0 :
577
+ audios_tokens = torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=input_ids.device) # a fake input
578
+ fake_input = True
579
+ else:
580
+ fake_input = False
581
+ for i, audio_emb_layer in enumerate(self.audio_embed_layers):
582
+ if i==0:
583
+ audio_embs = audio_emb_layer(audios_tokens[..., i])
584
+ else:
585
+ audio_embs += audio_emb_layer(audios_tokens[..., i])
586
+ inputs_embeds, group_index = self.get_multimodal_embed(sp_input_ids, inputs_embeds, audio_embs, self.config.audio_config.audio_pad_token_id, fake_input, group_index=group_index)
587
+
588
+ if self.config.visual_config.enable or self.config.video_config.enable:
589
+ inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, patch_nums, images_grid, videos, videos_patch_nums, videos_grid, group_index=group_index) # note: group_index is updated here
590
+
591
+ if seqlens is not None and seqlens.ndim == 2:
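+ # turn per-sample length rows into a flat vector of cumulative offsets (cu_seqlens) for the varlen attention path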
592
+ cu_seqlens = []
593
+ offset, seqlen = 0, seqlens.size(1)
594
+ for lens in seqlens:
595
+ cu_seqlens.append(offset)
596
+ cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
597
+ offset += seqlen
598
+ cu_seqlens.append(offset)
599
+ seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
600
+ elif seqlens is None and self.training:
601
+ seqlens = torch.arange(
602
+ end=input_ids.size(0) + 1,
603
+ dtype=torch.int32,
604
+ device=input_ids.device
605
+ ) * input_ids.size(1)
606
+ if seqlens is not None:
607
+ attention_mask = None # unset attention_mask to save memory
608
+
609
+ if seqlens is None and attention_mask is None:
610
+ attention_mask = torch.ones(
611
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
612
+ )
613
+ if attention_mask is not None:
614
+ attention_mask = _prepare_4d_causal_attention_mask(
615
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
616
+ )
617
+
618
+ # embed positions
619
+ hidden_states = inputs_embeds
620
+
621
+ if self.gradient_checkpointing and self.training:
622
+ if use_cache:
623
+ logger.warning_once(
624
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
625
+ )
626
+ use_cache = False
627
+
628
+ # decoder layers
629
+ all_hidden_states = () if output_hidden_states else None
630
+ all_self_attns = () if output_attentions else None
631
+ next_decoder_cache = () if use_cache else None
632
+
633
+ for idx, decoder_layer in enumerate(self.layers):
634
+ if output_hidden_states:
635
+ all_hidden_states += (hidden_states,)
636
+
637
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
638
+
639
+ if self.gradient_checkpointing and self.training:
640
+
641
+ def create_custom_forward(module):
642
+ def custom_forward(*inputs):
643
+ # None for past_key_value
644
+ return module(*inputs, output_attentions, False, group_index)
645
+
646
+ return custom_forward
647
+
648
+ layer_outputs = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(decoder_layer),
650
+ hidden_states,
651
+ attention_mask,
652
+ position_ids,
653
+ seqlens,
654
+ None,
655
+ )
656
+ else:
657
+ layer_outputs = decoder_layer(
658
+ hidden_states,
659
+ attention_mask=attention_mask,
660
+ position_ids=position_ids,
661
+ seqlens=seqlens,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
665
+ group_index=group_index,
666
+ )
667
+
668
+ hidden_states = layer_outputs[0]
669
+
670
+ if use_cache:
671
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
672
+
673
+ if output_attentions:
674
+ all_self_attns += (layer_outputs[1],)
675
+
676
+ hidden_states = self.norm(hidden_states)
677
+
678
+ # add hidden states from the last decoder layer
679
+ if output_hidden_states:
680
+ all_hidden_states += (hidden_states,)
681
+
682
+ next_cache = next_decoder_cache if use_cache else None
683
+ if not return_dict:
684
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
685
+ return BaseModelOutputWithPast(
686
+ last_hidden_state=hidden_states,
687
+ past_key_values=next_cache,
688
+ hidden_states=all_hidden_states,
689
+ attentions=all_self_attns,
690
+ )
691
+
692
+
693
+ class NormHead(nn.Module):
694
+ def __init__(self, hidden_size, vocab_size, bias=False):
695
+ super().__init__()
696
+ self.hidden_size = hidden_size
697
+ self.vocab_size = vocab_size
698
+ self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
699
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
700
+
701
+ def forward(self, hidden_states, mask=None):
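+ # L2-normalize each row of the output embedding before the projection; `mask` restricts gradient flow to the selected token rows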
702
+ norm_weight = nn.functional.normalize(self.weight)
703
+ if mask is not None:
704
+ mask = mask.to(norm_weight)
705
+ norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
706
+ return nn.functional.linear(hidden_states, norm_weight)
707
+
708
+
709
+ def extra_repr(self) -> str:
710
+ return f'in_features={self.hidden_size}, out_features={self.vocab_size}'
711
+
712
+ @dataclass
713
+ class OmniMMCausalLMOutputWithPast(ModelOutput):
714
+ loss: Optional[torch.FloatTensor] = None
715
+ logits: Optional[torch.FloatTensor] = None
716
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
717
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
718
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
719
+ audios_emb_for_infer: Optional[torch.FloatTensor] = None # embeddings used for audio-head inference
720
+
721
+
722
+ class CasualDepthTransformerLayer(nn.Module):
723
+ def __init__(self, config, depth):
724
+ super().__init__()
725
+ self.config = config
726
+ embed_size = config.hidden_size
727
+ assert embed_size % 128 == 0
728
+ num_heads = embed_size // 128
729
+ self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True)
730
+ self.layernorm1 = RMSNorm(embed_size)
731
+ self.layernorm2 = RMSNorm(embed_size)
732
+ self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size)
733
+ self.linear2 = nn.Linear(2 * embed_size * depth, embed_size)
734
+
735
+ def forward(self, x):
736
+ seq_len = x.size(1)
737
+ res = x
738
+ x = self.layernorm1(x)
739
+ src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
740
+ _x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
741
+ res = _x + res # (bs, sl, d)
742
+ res = self.layernorm2(res)
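+ # depth-wise feed-forward: the linear1/linear2 weights are reshaped so every codebook depth position gets its own projection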
743
+ x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.hidden_size, -1, self.config.hidden_size)))
744
+ x = torch.nn.functional.gelu(x)
745
+ x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.hidden_size, -1, 2 * self.config.hidden_size)))
746
+ return res + x
747
+
748
+ class OmniAudioHead(nn.Module):
749
+ def __init__(self, config):
750
+ super().__init__()
751
+ self.config = config
752
+ hidden_size = config.hidden_size
753
+ self.transformer_layers = nn.ModuleList([
754
+ CasualDepthTransformerLayer(config, len(config.audio_config.vq_config.codebook_sizes))
755
+ for _ in range(config.audio_config.audio_head_transformer_layers)
756
+ ])
757
+ self.headnorm = RMSNorm(hidden_size)
758
+ self.heads = nn.ModuleList([
759
+ nn.Linear(hidden_size, vq_size+1)
760
+ for vq_size in config.audio_config.vq_config.codebook_sizes
761
+ ])
762
+ self.gradient_checkpointing = True
763
+
764
+ def forward(self, x, audios_tokens, audio_emb_layers):
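+ # teacher-forced depth decoding: prepend the LM hidden state, then feed cumulative sums of the previous codebooks' embeddings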
765
+ cumsum_audio_embed = torch.stack([
766
+ audio_emb_layers[i](audios_tokens[..., i])
767
+ for i, vq_size in enumerate(self.config.audio_config.vq_config.codebook_sizes[:-1])
768
+ ], dim=1)
769
+ cumsum_audio_embed = torch.cumsum(cumsum_audio_embed, dim=1) # (bs, depth-1, d)
770
+ hidden_states = torch.concat([x.reshape(-1, 1, self.config.hidden_size), cumsum_audio_embed], dim=1) # (bs, depth, d)
771
+ assert hidden_states.size(1) == len(self.config.audio_config.vq_config.codebook_sizes)
772
+ for i, tlayer in enumerate(self.transformer_layers):
773
+ hidden_states = tlayer(hidden_states,)
774
+ hidden_states = self.headnorm(hidden_states)
775
+ logits = [head(hidden_states[:,i]) for i, head in enumerate(self.heads)]
776
+ return logits
777
+
778
+
779
+ class OmniForCausalLM(OmniPreTrainedModel):
780
+ def __init__(self, config):
781
+ super().__init__(config)
782
+ self.config = config
783
+ self.model = OmniModel(config)
784
+ self.audio_tokenizer = OmniAudioTokenizer(config)
785
+ self.audio_head = OmniAudioHead(config)
786
+ if config.use_norm_head:
787
+ self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
788
+ else:
789
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
790
+ # Initialize weights and apply final processing
791
+ self.post_init()
792
+
793
+ @property
794
+ def main_device(self):
795
+ return self.lm_head.weight.device
796
+
797
+ def bind_processor(self, tokenizer, **kwargs):
798
+ self.processor = OmniMMProcessor(
799
+ tokenizer=tokenizer,
800
+ config=self.config,
801
+ **kwargs,
802
+ )
803
+ return self.processor
804
+
805
+ def get_input_embeddings(self):
806
+ return self.model.embed_tokens
807
+
808
+ def set_input_embeddings(self, value):
809
+ self.model.embed_tokens = value
810
+
811
+ def get_output_embeddings(self):
812
+ return self.lm_head
813
+
814
+ def set_output_embeddings(self, new_embeddings):
815
+ self.lm_head = new_embeddings
816
+
817
+ def set_decoder(self, decoder):
818
+ self.model = decoder
819
+
820
+ def get_decoder(self):
821
+ return self.model
822
+
823
+ def forward(
824
+ self,
825
+ input_ids: torch.LongTensor = None,
826
+ attention_mask: Optional[torch.Tensor] = None,
827
+ position_ids: Optional[torch.LongTensor] = None,
828
+ seqlens: Optional[torch.IntTensor] = None,
829
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
830
+ inputs_embeds: Optional[torch.FloatTensor] = None,
831
+ labels: Optional[torch.LongTensor] = None,
832
+ audios: Optional[List|torch.Tensor] = None,
833
+ audios_tokens: Optional[List|torch.Tensor] = None,
834
+ encoder_length: Optional[torch.Tensor] = None,
835
+ bridge_length: Optional[torch.Tensor] = None,
836
+ images: Optional[torch.Tensor] = None,
837
+ patch_nums: Optional[torch.Tensor] = None,
838
+ images_grid: Optional[torch.Tensor] = None,
839
+ videos: Optional[torch.Tensor] = None,
840
+ videos_patch_nums: Optional[torch.Tensor] = None,
841
+ videos_grid: Optional[torch.Tensor] = None,
842
+ use_cache: Optional[bool] = None,
843
+ output_attentions: Optional[bool] = None,
844
+ output_hidden_states: Optional[bool] = None,
845
+ return_dict: Optional[bool] = None,
846
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
847
+
848
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
849
+ output_hidden_states = (
850
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
851
+ )
852
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
853
+
854
+ if audios_tokens is not None:
855
+ assert isinstance(audios_tokens, torch.Tensor)
856
+ else:
857
+ if audios is None or len(audios) == 0:
858
+ audios_tokens = None
859
+ else:
860
+ audios_tokens = self.audio_tokenizer(audios,encoder_length,bridge_length)
861
+
862
+ outputs = self.model(
863
+ input_ids=input_ids,
864
+ attention_mask=attention_mask,
865
+ position_ids=position_ids,
866
+ seqlens=seqlens,
867
+ past_key_values=past_key_values,
868
+ inputs_embeds=inputs_embeds,
869
+ audios_tokens=audios_tokens,
870
+ images=images,
871
+ patch_nums=patch_nums,
872
+ images_grid=images_grid,
873
+ videos=videos,
874
+ videos_patch_nums=videos_patch_nums,
875
+ videos_grid=videos_grid,
876
+ use_cache=use_cache,
877
+ output_attentions=output_attentions,
878
+ output_hidden_states=output_hidden_states,
879
+ return_dict=return_dict,
880
+ )
881
+ hidden_states = outputs.last_hidden_state
882
+ audios_emb_for_infer = hidden_states[:,-1,:]
883
+ logits = self.lm_head(hidden_states)
884
+
885
+ return OmniMMCausalLMOutputWithPast(
886
+ logits=logits,
887
+ past_key_values=outputs.past_key_values,
888
+ hidden_states=outputs.hidden_states,
889
+ attentions=outputs.attentions,
890
+ audios_emb_for_infer=audios_emb_for_infer
891
+ )
892
+
893
+ def prepare_inputs_for_generation(
894
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
895
+ ):
896
+ if past_key_values:
897
+ input_ids = input_ids[:, past_key_values[0][0].shape[-2]:]
898
+
899
+ position_ids = kwargs.get("position_ids", None)
900
+ if attention_mask is not None and position_ids is None:
901
+ # create position_ids on the fly for batch generation
902
+ position_ids = attention_mask.long().cumsum(-1)
903
+ # position_ids.masked_fill_(attention_mask == 0, 1)
904
+ if past_key_values:
905
+ position_ids = position_ids[:, past_key_values[0][0].shape[-2]:]
906
+
907
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
908
+ if inputs_embeds is not None and past_key_values is None:
909
+ model_inputs = {"inputs_embeds": inputs_embeds}
910
+ elif past_key_values is not None:
911
+ model_inputs = {"input_ids": input_ids}
912
+ else:
913
+ model_inputs = {"input_ids": input_ids,
914
+ "audios": kwargs.get("audios", None), "encoder_length": kwargs.get("encoder_length", None), "bridge_length": kwargs.get("bridge_length", None),
915
+ "audios_tokens": kwargs.get("audios_tokens", None),
916
+ "images": kwargs.get("images", None),
917
+ "videos": kwargs.get("videos", None)
918
+ }
919
+
920
+ model_inputs.update(
921
+ {
922
+ "position_ids": position_ids,
923
+ "past_key_values": past_key_values,
924
+ "use_cache": kwargs.get("use_cache"),
925
+ "attention_mask": attention_mask,
926
+ "images_grid": kwargs.get("images_grid"),
927
+ "videos_grid": kwargs.get("videos_grid"),
928
+ "patch_nums": kwargs.get("patch_nums"),
929
+ "videos_patch_nums": kwargs.get("videos_patch_nums"),
930
+ }
931
+ )
932
+ return model_inputs
933
+
934
+ @staticmethod
935
+ def _reorder_cache(past_key_values, beam_idx):
936
+ reordered_past = ()
937
+ for layer_past in past_key_values:
938
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
939
+ return reordered_past
940
+
941
+ def chat(self, tokenizer, messages: List[dict], stream=False,
942
+ generation_config: Optional[GenerationConfig]=None):
943
+ generation_config = generation_config or self.generation_config
944
+ input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
945
+ if stream:
946
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
947
+ Thread(target=self.generate, kwargs=dict(
948
+ inputs=input_ids, streamer=streamer,
949
+ generation_config=generation_config,
950
+ )).start()
951
+ return streamer
952
+ else:
953
+ outputs = self.generate(input_ids, generation_config=generation_config)
954
+ response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
955
+ return response
956
+
957
+
958
+ class OmniAudioTokenizer(OmniPreTrainedModel):
959
+ """
960
+ Construct an audio tokenizer and decoder.
961
+ """
962
+ def __init__(self, config: OmniConfig):
963
+ super().__init__(config)
964
+ self.padding_idx = None
965
+ self.vocab_size = config.vocab_size
966
+ self.training = False
967
+ self.eval()
968
+ self.audio_model = OmniAudioEncoder(config.audio_config)
969
+ self.audio_bridge_model = OmniAudioVQBridgeTokenizer(config)
970
+ if config.vocoder_config.enable:
971
+ self.audio_decoder = OmniAudioDecoder(config)
972
+ if config.flow_matching_config.enable:
973
+ self.audio_flow_matching_decoder = OmniAudioFlowMatchingDecoder(config)
974
+
975
+ def encode(self, x, encoder_length: Optional[torch.Tensor] = None,
976
+ bridge_length: Optional[torch.Tensor] = None):
977
+ audio_emb = self.audio_model(x, encoder_length)
978
+ audios_tokens = self.audio_bridge_model(audio_emb, bridge_length)
979
+ return audios_tokens
980
+
981
+ def decode(self, audio_code_ids, bridge_length: Optional[torch.Tensor] = None):
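+ # audio tokens -> bridge embeddings -> vocoder; optionally refine the result with the flow-matching decoder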
982
+ assert self.config.vocoder_config.enable, "Vocoder is not enabled in config."
983
+ audio_emb = self.audio_bridge_model.decode(audio_code_ids)
984
+ audio_dec = self.audio_decoder(
985
+ audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
986
+ )
987
+ if self.config.flow_matching_config.enable:
988
+ if self.config.flow_matching_config.use_hidden_states_before_dconv2:
989
+ hidden_states, hidden_states_length = (
990
+ self.audio_flow_matching_decoder.unpack_hidden_states(
991
+ audio_dec.hidden_states_before_dconv2,
992
+ audio_dec.output_length_before_dconv2,
993
+ )
994
+ )
995
+ audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
996
+ hidden_states, hidden_states_length
997
+ )
998
+
999
+ else:
1000
+ audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
1001
+ audio_dec.refined_mel, audio_dec.mel_length
1002
+ )
1003
+ return audio_flow_matching_decoder_ret
1004
+ else:
1005
+ return audio_dec
1006
+
1007
+ @torch.no_grad()
1008
+ def forward(self, audios, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
1009
+ self.eval()
1010
+ audios_tokens = self.encode(audios, encoder_length, bridge_length)
1011
+ return audios_tokens
processor_omni.py ADDED
@@ -0,0 +1,865 @@
1
+ import requests
2
+ import re, ujson, os, sys, fire, glob, random, time, json
3
+ import numpy as np
4
+ import io
5
+ import torch
6
+ from torch.utils.data import default_collate
7
+ import torchaudio
8
+ from typing import *
9
+ from dataclasses import dataclass, field
10
+ import transformers
11
+ from transformers.modeling_outputs import ModelOutput
12
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
13
+ from functools import lru_cache
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ import concurrent.futures as cf
17
+ from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
18
+ from transformers.image_utils import PILImageResampling
19
+ from PIL import Image, ImageOps
20
+ from PIL import ImageFile
21
+ torch.set_num_threads(1) # limit torch's thread count, otherwise it may hang
22
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
23
+ import base64
24
+ from decord import VideoReader, cpu
25
+ import cv2
26
+ import av
27
+ import imagesize
28
+ import tempfile
29
+ import math
30
+ from multiprocessing import Pool
31
+ from cairosvg import svg2png
32
+ import hashlib
33
+
34
+ IMAGE_FACTOR = 28
35
+ MIN_PIXELS = 4 * 28 * 28
36
+ MAX_PIXELS = 16384 * 28 * 28
37
+ MAX_RATIO = 200
38
+
39
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
40
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
41
+ VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
42
+ FRAME_FACTOR = 2
43
+ FPS = 2.0
44
+ FPS_MIN_FRAMES = 4
45
+ FPS_MAX_FRAMES = 768
46
+
47
+ def round_by_factor(number: int, factor: int) -> int:
48
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
49
+ return round(number / factor) * factor
50
+
51
+
52
+ def ceil_by_factor(number: int, factor: int) -> int:
53
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
54
+ return math.ceil(number / factor) * factor
55
+
56
+
57
+ def floor_by_factor(number: int, factor: int) -> int:
58
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
59
+ return math.floor(number / factor) * factor
60
+
61
+
62
+ def smart_resize(
63
+ height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
64
+ ) -> tuple[int, int]:
65
+ """
66
+ Rescales the image so that the following conditions are met:
67
+
68
+ 1. Both dimensions (height and width) are divisible by 'factor'.
69
+
70
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
71
+
72
+ 3. The aspect ratio of the image is maintained as closely as possible.
73
+ """
74
+ if max(height, width) / min(height, width) > MAX_RATIO:
75
+ raise ValueError(
76
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
77
+ )
78
+ h_bar = max(factor, round_by_factor(height, factor))
79
+ w_bar = max(factor, round_by_factor(width, factor))
80
+ if h_bar * w_bar > max_pixels:
81
+ beta = math.sqrt((height * width) / max_pixels)
82
+ h_bar = floor_by_factor(height / beta, factor)
83
+ w_bar = floor_by_factor(width / beta, factor)
84
+ elif h_bar * w_bar < min_pixels:
85
+ beta = math.sqrt(min_pixels / (height * width))
86
+ h_bar = ceil_by_factor(height * beta, factor)
87
+ w_bar = ceil_by_factor(width * beta, factor)
88
+ return h_bar, w_bar
89
+
90
+
91
+ def split_text(text, match_regex):
92
+ matches = list(re.finditer(match_regex, text))
93
+ # initialize the result lists
94
+ result = []
95
+ match_flag_list = []
96
+ # end position of the previous match
97
+ last_end = 0
98
+ # iterate over all matches
99
+ for match in matches:
100
+ # append the text before this match
101
+ if text[last_end:match.start()]:
102
+ result.append(text[last_end:match.start()])
103
+ match_flag_list.append(False)
104
+ # append the match itself
105
+ result.append(match.group(0))
106
+ match_flag_list.append(True)
107
+ # update the end position of the previous match
108
+ last_end = match.end()
109
+ # append the text after the last match
110
+ if text[last_end:]:
111
+ result.append(text[last_end:])
112
+ match_flag_list.append(False)
113
+ return result, match_flag_list
114
+
115
+
116
+ def read_video(image_path, max_frame_number, decode_way):
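+ # decode_way='1fps' samples roughly one frame per second via decord; decode_way='key' keeps only key frames via PyAV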
117
+ if decode_way=='1fps':
118
+ try:
119
+ # print(image_path)
120
+ vr = VideoReader(image_path, ctx=cpu(0))
121
+ total_frame_num = len(vr)
122
+ fps = round(vr.get_avg_fps())
123
+ frame_idx = [i for i in range(0, len(vr), fps)]
124
+ frames = vr.get_batch(frame_idx).asnumpy()
125
+ cnt = len(frames)
126
+ frame_times = range(cnt)
127
+ except Exception as e:
128
+ print(image_path)
129
+ print('error is', e)
130
+ return None
131
+ elif decode_way=='key':
132
+ try:
133
+ with av.open(image_path) as container:
134
+ stream = container.streams.video[0]
135
+ stream.codec_context.skip_frame = 'NONKEY'
136
+ frames = []
137
+ frame_times = []
138
+ fps = int(stream.average_rate)
139
+ cnt = 0
140
+ for frame in container.decode(stream): # store key frames as image patches
141
+ image = np.array(frame.to_image())
142
+ frames.append(image)
143
+ frame_time = int(frame.time)
144
+ frame_times.append(frame_time)
145
+ cnt += 1
146
+ except Exception as e:
147
+ print('error is', e)
148
+ return None
149
+ if frames is None or len(frames)==0:
150
+ return None
151
+ if len(frames)>max_frame_number and max_frame_number>0:
152
+ # generate max_frame_number evenly spaced indices
153
+ indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
154
+ # select the corresponding frames by index
155
+ frames = frames[indices]
156
+ frame_times = frame_times[indices]
157
+ return frames, frame_times
158
+
159
+
160
+ class OmniImageProcessor:
161
+ def __init__(self, config, **kwargs):
162
+ self.config = config # visual_config
163
+ self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
164
+ self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
165
+ self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
166
+ self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
167
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
168
+ self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
169
+
170
+ def image_transform(self, strseq, return_mm_data = True):
171
+ image = None
172
+ if isinstance(strseq, str):
173
+ if return_mm_data:
174
+ image = Image.open(strseq).convert("RGB")
175
+ else:
176
+ try:
177
+ image = Image.open(BytesIO(strseq)).convert("RGB")
178
+ except:
179
+ image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB") # some interleaved data are vector (SVG) images and need conversion
180
+
181
+ image = np.array(image.convert("RGB")) # convert to RGB (three channels) and then to a NumPy array for later processing
182
+ image_org_size = image.shape[:2] # keep the original (height, width) for later comparison or processing
183
+
184
+ # resize, crop, scale, normalize
185
+ # compute the new target size (height, width) used by the following resize/crop operations
186
+ resized_height, resized_width = smart_resize(
187
+ image_org_size[0], image_org_size[1],
188
+ factor=self.patch_size * self.spatial_merge_size,
189
+ min_pixels=self.min_pixels,
190
+ max_pixels=self.max_pixels,
191
+ )
192
+ output_size = (resized_height, resized_width)
193
+
194
+ # resize the image to output_size with resize(); PILImageResampling.BICUBIC selects bicubic interpolation, which usually preserves image quality well
195
+ # image: the input image, a NumPy array or PIL image; output_size: the target size as a (height, width) tuple
196
+ # resample: the resampling filter used for interpolation, e.g. PILImageResampling.BICUBIC for smooth bicubic resizing
197
+ image = resize(image, output_size, PILImageResampling.BICUBIC)
198
+ img = image.transpose(2, 0, 1)
199
+ # scale to [0, 1] and normalize with the configured image mean/std
200
+ image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
201
+ # split into patches
202
+ patches = image[np.newaxis, :]
203
+ if patches.shape[0] == 1:
204
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
205
+ channel = patches.shape[1]
206
+ grid_t = patches.shape[0] // self.temporal_patch_size
207
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
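+ # rearrange pixels into (grid_t * grid_h * grid_w) flattened patches of length channel * temporal_patch_size * patch_size**2 for the visual encoder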
208
+ patches = patches.reshape(
209
+ grid_t,
210
+ self.temporal_patch_size,
211
+ channel,
212
+ grid_h // self.spatial_merge_size,
213
+ self.spatial_merge_size,
214
+ self.patch_size,
215
+ grid_w // self.spatial_merge_size,
216
+ self.spatial_merge_size,
217
+ self.patch_size,
218
+ )
219
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
220
+ flatten_patches = patches.reshape(
221
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
222
+ )
223
+
224
+ return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
225
+
226
+
227
+ class OmniAudioProcessor:
228
+ # basic audio feature extraction plus input-data parsing
229
+ def __init__(
230
+ self,
231
+ config, # audio processor config
232
+ **kwargs
233
+ ):
234
+ # make sure ffmpeg is installed for torchaudio, e.g. conda install -c conda-forge 'ffmpeg<7'
235
+ assert(len(torchaudio.list_audio_backends()) > 0)
236
+ self.config = config
237
+ self.mel_filters = mel_filter_bank(
238
+ num_frequency_bins=1 + self.config.n_fft // 2,
239
+ num_mel_filters=self.config.num_mel_bins,
240
+ min_frequency=0.0,
241
+ max_frequency=self.config.sampling_rate / 2.0,
242
+ sampling_rate=self.config.sampling_rate,
243
+ norm="slaney",
244
+ mel_scale="slaney",
245
+ )
246
+ self.window = torch.hann_window(self.config.n_fft)
247
+
248
+ @staticmethod
249
+ def dynamic_range_compression(x, C=1, clip_val=1e-6):
250
+ return torch.log(torch.clamp(x, min=clip_val) * C)
251
+
252
+ @staticmethod
253
+ def zero_mean_unit_var_norm(x):
254
+ return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
255
+
256
+ def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
257
+ metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
258
+ assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
259
+ waveform_tensor, _ = torchaudio.load(uri, normalize=True)
260
+ if self.config.sampling_rate != metadata.sample_rate:
261
+ waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)
262
+
263
+ # downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
264
+ if metadata.num_channels > 1:
265
+ waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
266
+
267
+ # normalized to zero mean
268
+ if do_normalize:
269
+ waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
270
+
271
+ if return_tensors: # (channels, samples)
272
+ return waveform_tensor
273
+ else:
274
+ return waveform_tensor.numpy()
275
+
276
+ def split_with_overlap(self, waveform): # if the waveform exceeds the maximum length, split it into overlapping segments
277
+ channels, wave_samples = waveform.shape
278
+ max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
279
+ if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
280
+ return [waveform] # not over the max length or no truncation logic; always return a list
281
+
282
+ split_waveform, start = [], 0
283
+ while start < wave_samples: # align the overlap on whole seconds
284
+ if start > int(self.config.sampling_rate * self.config.split_overlap):
285
+ start -= int(self.config.sampling_rate * self.config.split_overlap) # 0 means no overlap; >0 is the overlap in seconds
286
+ end = min(start + max_audio_samples, wave_samples)
287
+ if end - start >= self.config.n_fft: # keep at least one frame of data
288
+ split_waveform.append(waveform[:, start:end]) # note: this may produce very short segments that preprocessing should detect and drop
289
+ start = end
290
+ return split_waveform
291
+
292
+ @classmethod
293
+ def inference_output_length(cls, config, input_length):
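+ # mirror the encoder's two conv layers (stride 1, then stride_size) and the bridge's average pooling to predict output lengths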
294
+ # for whisper + bridge
295
+ kernel_size = config.kernel_size
296
+ stride_size = config.stride_size
297
+ avg_pooler = config.avg_pooler
298
+ encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
299
+ encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
300
+ if avg_pooler > 1:
301
+ bridge_length = encoder_length // avg_pooler
302
+ return encoder_length, bridge_length
303
+
304
+ def extract_fbank_features(self, waveform):
305
+ # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
306
+ channels, wave_samples = waveform.shape
307
+ assert(wave_samples >= self.config.n_fft)
308
+ valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
309
+ if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
310
+ waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
311
+ else:
312
+ waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
313
+
314
+ # window = torch.hann_window(self.config.n_fft)
315
+ stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
316
+ magnitudes = stft[..., :-1].abs() ** 2
317
+
318
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
319
+ mel_spec = mel_filters.T @ magnitudes
320
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
321
+ if waveform.dim() == 2:
322
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
323
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
324
+ else:
325
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
326
+ log_spec = (log_spec + 4.0) / 4.0
327
+
328
+ log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
329
+ log_spec[:, valid_frame_nums:] = 0.0 # pad0
330
+
331
+ return log_spec, valid_frame_nums
332
+
333
+ def data_augment(self, feature: np.array, input_length, training=True):
334
+ # reference https://arxiv.org/pdf/1904.08779
335
+ def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
336
+ num_masked_span = int(mask_prob * input_length / mask_length + random.random())
337
+ num_masked_span = max(num_masked_span, min_masks)
338
+ start_indices = list(range(input_length - mask_length))
339
+ random.shuffle(start_indices)
340
+ start_indices = start_indices[:num_masked_span]
341
+ return start_indices
342
+
343
+ if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
344
+ return feature
345
+ if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
346
+ return feature
347
+ if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
348
+ return feature
349
+
350
+ if self.config.mask_time_prob > 0:
351
+ start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
352
+ for start_idx in start_indices:
353
+ feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
354
+ if self.config.mask_feature_prob > 0:
355
+ start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
356
+ for start_idx in start_indices:
357
+ feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
358
+
359
+ return feature
360
+
361
+ @dataclass
362
+ class OmniProcessorOutput(ModelOutput):
363
+ input_ids: Optional["List|torch.Tensor"] = None
364
+ labels: Optional["List|torch.Tensor"] = None
365
+ attention_mask: Optional["List|torch.Tensor"] = None
366
+ position_ids: Optional["List|torch.Tensor"] = None
367
+ seqlens: Optional["List|torch.Tensor"] = None # must be used together with the Omni modeling code
368
+ # audio fields
369
+ audios: Optional["List|torch.Tensor"] = None
370
+ encoder_length: Optional["List|torch.Tensor"] = None
371
+ bridge_length: Optional["List|torch.Tensor"] = None
372
+ # image fields
373
+ images: Optional["List|torch.Tensor"] = None
374
+ patch_nums: Optional["List|torch.Tensor"] = None
375
+ images_size: Optional["List|torch.Tensor"] = None
376
+ crop_size: Optional["List|torch.Tensor"] = None
377
+ images_grid: Optional["List|torch.Tensor"] = None
378
+ # video fields
379
+ videos: Optional["List|torch.Tensor"] = None
380
+ videos_patch_nums: Optional["List|torch.Tensor"] = None
381
+ videos_size: Optional["List|torch.Tensor"] = None
382
+ videos_crop_size: Optional["List|torch.Tensor"] = None
383
+ videos_grid: Optional["List|torch.Tensor"] = None
384
+ # processor fields
385
+ raw_text: Optional[str] = None
386
+ index: Optional[int] = None
387
+
388
+ def concatenate(self, other): # only valid when the fields are lists
389
+ def concat_one(a, b):
390
+ if a is None and b is None:
391
+ return None
392
+ elif a is None and b is not None:
393
+ return b
394
+ elif a is not None and b is None:
395
+ return a
396
+ else:
397
+ return a + b
398
+ return OmniProcessorOutput(
399
+ input_ids=concat_one(self.input_ids, other.input_ids),
400
+ labels=concat_one(self.labels, other.labels),
401
+ audios=concat_one(self.audios, other.audios),
402
+ encoder_length=concat_one(self.encoder_length, other.encoder_length),
403
+ bridge_length=concat_one(self.bridge_length, other.bridge_length),
404
+ images=concat_one(self.images, other.images),
405
+ images_grid=concat_one(self.images_grid, other.images_grid),
406
+ patch_nums=concat_one(self.patch_nums, other.patch_nums),
407
+
408
+ videos=concat_one(self.videos, other.videos),
409
+ videos_grid=concat_one(self.videos_grid, other.videos_grid),
410
+ videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
411
+
412
+ position_ids=concat_one(self.position_ids, other.position_ids),
413
+ seqlens=concat_one(self.seqlens, other.seqlens),
414
+ images_size=concat_one(self.images_size, other.images_size),
415
+ videos_size=concat_one(self.videos_size, other.videos_size),
416
+ index = self.index # concatenation keeps index unchanged
417
+ )
418
+
419
+ class OmniMMProcessor(object):
420
+ def __init__(self,
421
+ tokenizer: transformers.PreTrainedTokenizer,
422
+ config,
423
+ training,
424
+ relative_path=None,
425
+ parallel=None,
426
+ **kwargs,
427
+ ):
428
+ self.tokenizer = tokenizer
429
+ self.config = config
430
+ self.audio_processor = OmniAudioProcessor(config.audio_config)
431
+ self.visual_processor = None
432
+ if hasattr(config, "visual_config"):
433
+ self.visual_processor = OmniImageProcessor(config.visual_config)
434
+ self.video_processor = None
435
+ if hasattr(config, "video_config"):
436
+ self.video_processor = OmniImageProcessor(config.video_config)
437
+ self.training = training
438
+ self.relative_path = relative_path
439
+ self.parallel = parallel
440
+ # audio tag
441
+ self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
442
+ self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
443
+ self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
444
+ self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
445
+ self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
446
+ self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
447
+ # image tag
448
+ self.image_start_tag = None
449
+ self.image_end_tag = None
450
+ self.image_pad_tag = None
451
+ self.video_start_tag = None
452
+ self.video_end_tag = None
453
+ # The videoframe tags only exist for compatibility when a video is given as image frames; they have no token ids and are replaced with the image start/end tags when frames are extracted
454
+ self.videoframe_start_tag = '<videoframe_start_omni>'
455
+ self.videoframe_end_tag = '<videoframe_end_omni>'
456
+ if hasattr(self.config, "visual_config"):
457
+ # special token for start_tag
458
+ self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
459
+ # special token for end_tag
460
+ self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
461
+ # special token for pad_tag
462
+ self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
463
+ self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
464
+ self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
465
+ if hasattr(self.config, "video_config"):
466
+ self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
467
+ self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
468
+ self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
469
+ self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
470
+ self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
471
+ self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
472
+
473
+ self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')
474
+
475
+
476
+ # @lru_cache(maxsize=1024)
477
+ def _get_audio(self, audio_info):
478
+ try:
479
+ audio_info = ujson.loads(audio_info)
480
+ if 'path' in audio_info.keys():
481
+ audio_uri = None
482
+ if os.path.exists(audio_info['path']):
483
+ audio_uri = audio_info['path']
484
+ elif self.relative_path is not None:
485
+ audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
486
+ if not os.path.exists(audio_uri):
487
+ audio_uri = None
488
+ if audio_uri is not None:
489
+ waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
490
+ waveforms = self.audio_processor.split_with_overlap(waveform)
491
+
492
+ ret = OmniProcessorOutput() # default initialization: the audios field is None
493
+ for i, waveform in enumerate(waveforms): #(zip(waveforms,vocoder_waveforms)):
494
+ audio, input_length = self.audio_processor.extract_fbank_features(waveform)
495
+ audio = self.audio_processor.data_augment(audio, input_length, self.training)
496
+ encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
497
+ if bridge_length <= 0:
498
+ continue
499
+ current_ret = OmniProcessorOutput(
500
+ audios=[audio[:,:input_length]],
501
+ encoder_length=[encoder_length],
502
+ bridge_length=[bridge_length],
503
+ )
504
+ if ret.audios is None:
505
+ ret = current_ret
506
+ else:
507
+ ret = ret.concatenate(current_ret) # concatenate multiple audio slices
508
+ return ret
509
+ else:
510
+ raise ValueError("can not find path in audio_info")
511
+ except Exception as e:
512
+ print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
513
+ return OmniProcessorOutput()
514
+
515
+ # @lru_cache(maxsize=1024)
516
+ def _get_image(self, image_info):
517
+ try:
518
+ try:
519
+ image_info = ujson.loads(image_info)
520
+ except Exception:
521
+ image_info = re.sub(r"(?<!\\)'", '"', image_info)
522
+ image_info = ujson.loads(image_info)
523
+ if 'base64' in image_info.keys():
524
+ image_data = base64.b64decode(image_info['base64'])
525
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
526
+ elif 'local' in image_info.keys():
527
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
528
+ elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
529
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
530
+ elif 'url' in image_info.keys():
531
+ image_bytes = self._get_vision_obj_byte('url', image_info['url'])
532
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
533
+ else:
534
+ raise ValueError("can not find any path in image_info")
535
+
536
+ merge_length = self.visual_processor.merge_size**2
537
+ patch_nums = np.array(image_list).prod() // merge_length
538
+
539
+ if org_size[0] * org_size[1] > 16**2: # filter out extremely small images
540
+ return OmniProcessorOutput(
541
+ images=[image_feat],
542
+ patch_nums=[patch_nums],
543
+ crop_size=[image_list],
544
+ images_size= [org_size],
545
+ images_grid=[image_list]
546
+ )
547
+ else:
548
+ print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
549
+ return OmniProcessorOutput()
550
+
551
+ except Exception as e:
552
+ print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
553
+ return OmniProcessorOutput()
554
+
555
+ # @lru_cache(maxsize=1024)
556
+ def _get_video_frame(self, video_frame_infos):
557
+ try:
558
+ pattern = r'\{.*?\}'
559
+ matches = re.findall(pattern, video_frame_infos)
560
+ ret = OmniProcessorOutput()
561
+ # parse each frame info one by one
562
+ for match in matches:
563
+ video_frame_info = ujson.loads(match)
564
+ # video_frame_info = ujson.loads(video_frame_info)
565
+ if 'local' in video_frame_info.keys():
566
+ image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
567
+ elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
568
+ image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
569
+ else:
570
+ raise ValueError("can not find any path in video_info")
571
+
572
+ merge_length = self.video_processor.merge_size**2
573
+ patch_nums = np.array(image_list).prod() // merge_length
574
+
575
+ if org_size[0] * org_size[1] > 16**2: # filter out extremely small frames
576
+ ret = ret.concatenate(
577
+ OmniProcessorOutput(
578
+ videos=[image_feat],
579
+ videos_patch_nums=[patch_nums],
580
+ videos_crop_size=[image_list],
581
+ videos_size= [org_size],
582
+ videos_grid=[image_list]
583
+ )
584
+ )
585
+ else:
586
+ print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
587
+ return ret
588
+
589
+ except Exception as e:
590
+ print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_infos)))  # use the full input string; video_frame_info may be unbound here
591
+ return OmniProcessorOutput()
592
+
593
+ # read raw video/image bytes from the given source
594
+ def _get_vision_obj_byte(self, source, path):
595
+ vision_obj_byte = None
596
+ if source == "local":
597
+ if os.path.exists(path):
598
+ vision_obj_byte = open(path, "rb").read()
599
+ else:
600
+ vision_obj_byte = None
601
+ if source == "base64":
602
+ vision_obj_byte = base64.b64decode(path)
603
+ if source == "url":
604
+ vision_obj_byte = requests.get(url=path).content
605
+ return vision_obj_byte
606
+
607
+ # split the video into frames and save them in a subdirectory
608
+ def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
609
+ if decode_way=='1fps':
610
+ frame_suffix = f'_frames'
611
+ elif decode_way=='key':
612
+ frame_suffix = f'_keyframes'
613
+ else:
614
+ raise ValueError('invalid decode way!!!')
615
+
616
+ server = "local"
617
+ if 'local' in video_info.keys():
618
+ # local video path
619
+ video_path = video_info['local']
620
+ # local directory where the frames are saved
621
+ frame_path = video_path.split('.')[0] + frame_suffix
622
+ mm_obj_byte = self._get_vision_obj_byte('local', video_path)
623
+ elif 'base64' in video_info.keys():
624
+ md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
625
+ if self.relative_path is not None:
626
+ video_path = os.path.join(self.relative_path, md5)
627
+ else:
628
+ video_path = os.path.join(os.getcwd(), md5)
629
+ frame_path = md5 + frame_suffix
630
+ mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
631
+ elif 'url' in video_info.keys():
632
+ md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
633
+ if self.relative_path is not None:
634
+ video_path = os.path.join(self.relative_path, md5)
635
+ else:
636
+ video_path = os.path.join(os.getcwd(), md5)
637
+ frame_path = md5 + frame_suffix
638
+ mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
639
+ else:
640
+ raise ValueError('invalid video source!!!')
641
+ return ""
642
+
643
+ if mm_obj_byte is None: # failed to read the video file
644
+ return ""
645
+ if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
646
+ # save the decoded frames to disk
647
+ os.makedirs(frame_path, exist_ok=True)
648
+ frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way)  # read all frames
649
+ for frame_idx, frame in enumerate(frames):
650
+ output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
651
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
652
+ cv2.imwrite(output_filename, frame)
653
+ frame_paths = os.listdir(frame_path)
654
+
655
+ # select frames
656
+ frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')] # the file name is the second offset of the frame
657
+ frame_times.sort() # sort in ascending order
658
+ frame_number = len(frame_times)
659
+ if frame_number > max_frame_number:
660
+ indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
661
+ else:
662
+ indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
663
+ # build the replacement string for the selected frames
664
+ replace_str = ""
665
+ for frame_idx, idx in enumerate(indices):
666
+ frame_time = frame_times[idx] # frame_time is the frame timestamp in seconds and also the stored file name
667
+ frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
668
+ frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern # {} is the frame index
669
+ frame_str = frame_str.replace('<TIMEIDX>', str(frame_time)) # <TIMEIDX> is the second offset
670
+ frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time))) # <TIMESTAMP> is the formatted timestamp
671
+ frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
672
+ replace_str += frame_str
673
+
674
+ return replace_str
675
+
676
+ def sample_frame(self,frames_str,max_frame = 32):
677
+ def uniform_sample(lst, num_samples):
678
+ if num_samples > len(lst):
679
+ return lst
680
+ interval = len(lst) / num_samples
681
+ samples = [lst[int(i * interval)] for i in range(num_samples)]
682
+ return samples
683
+ p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
684
+ frames_str_split = re.split(p,frames_str)
685
+ frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
686
+ sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
687
+ return ''.join([item for idx,item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])
688
+
689
+ def _get_video_frame_str(self, video_info):
690
+ try:
691
+ if self.videoframe_start_tag in video_info: # if the video is given as individual frames, replace the frame tags with image tags
692
+ frames_str = video_info
693
+ frames_str = frames_str.replace(self.videoframe_start_tag,self.image_start_tag).replace(self.videoframe_end_tag,self.image_end_tag)
694
+ return self.sample_frame(frames_str, max_frame = self.config.video_config.max_frame_num)
695
+ video_info = ujson.loads(video_info)
696
+ # get a string containing the frame image paths, with at most max_frame_number frames
697
+ frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
698
+ return frames_str
699
+ except Exception as e:
700
+ print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
701
+ return ""
702
+
703
+ def _replace_image(self, image_text):
704
+ image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
705
+ ret = self._get_image(image_info) # repeated lookups reuse the cached result
706
+ if ret.patch_nums is None:
707
+ return ''
708
+ return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
709
+
710
+ def _replace_video_frame(self, video_frame_text):
711
+ video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
712
+ ret = self._get_video_frame(video_frame_info) # repeated lookups reuse the cached result
713
+ if ret.videos_patch_nums is None:
714
+ return ''
715
+ video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
716
+ return ret, ''.join(video_frame_str)
717
+
718
+
719
+ def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
720
+ # extract the JSON-formatted audio/image info from the text, load it and convert it to features, estimate the number of encoder tokens, and fill in the matching number of pad tokens
721
+ if (self.audio_start_tag != None) and (mtype == 'audio'):
722
+ match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag,re.S)
723
+ drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag,re.S)
724
+ elif (self.image_start_tag != None) and (mtype == 'image'):
725
+ match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag,re.S)
726
+ drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag,re.S)
727
+ elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
728
+ match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag,re.S)
729
+ drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag,re.S)
730
+ elif (self.video_start_tag != None) and (mtype == 'video'):
731
+ match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag,re.S)
732
+ drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag,re.S)
733
+ else:
734
+ raise ValueError("mtype not supported!")
735
+ new_text_list = []
736
+ new_mm_label_list = []
737
+ new_trainable_flag_list = []
738
+ for text,mm_label,trainable in zip(text_list,mm_label_list,trainable_list):
739
+ for t,m in zip(*split_text(text, match_regex)):
740
+ new_trainable_flag_list.append(trainable)
741
+ if m:
742
+ new_text_list.append(re.sub(drop_regex, '', t))
743
+ new_mm_label_list.append(mtype)
744
+ else:
745
+ new_text_list.append(t)
746
+ new_mm_label_list.append(mm_label)
747
+ return new_text_list, new_mm_label_list, new_trainable_flag_list
748
+
749
+ def process_multimodal_chunk(self, text, mm_label, trainable):
750
+ ret = OmniProcessorOutput()
751
+ if mm_label == 'audio':
752
+ ret = self._get_audio(text)
753
+ if ret.bridge_length is not None:
754
+ ret.input_ids = self.tokenizer.encode(self.audio_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag,add_special_tokens=False)
755
+ else:
756
+ raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
757
+ elif mm_label == 'audiogen':
758
+ ret = self._get_audio(text)
759
+ if ret.bridge_length is not None:
760
+ ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag,add_special_tokens=False)
761
+ else:
762
+ raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
763
+ elif mm_label == 'image':
764
+ ret, input_str = self._replace_image(text)
765
+ if input_str:
766
+ ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
767
+ else:
768
+ raise ValueError("Get image data Failed at Process image chunk")
769
+ elif mm_label == 'video':
770
+ frame_str = self.video_start_tag+self._get_video_frame_str(text)+self.video_end_tag
771
+ ret, input_str = self._replace_video_frame(frame_str)
772
+ if input_str:
773
+ ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
774
+ else:
775
+ raise ValueError("Get video data Failed at Process video chunk")
776
+ elif mm_label == 'text':
777
+ ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
778
+ if len(ret.input_ids) > self.tokenizer.model_max_length-1: # filter overly long text
779
+ raise ValueError(f"Text too long, please check text length! 【{text[:5]+'...'*6+text[-5:]}】")
780
+ else:
781
+ raise ValueError(f"mm_label not supported! must be one of ['audio', 'audiogen', 'image', 'video', 'text'] but got {mm_label}")
782
+ return ret
783
+
784
+ def process_one(self, text, index=0, raw_only=False):
785
+ ret = OmniProcessorOutput(index=index)
786
+ all_text_list = []
787
+ all_mm_label_list = []
788
+ all_trainable_flag_list = []
789
+ text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>",re.S))
790
+ if len(text_list) == 1:
791
+ text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text_list[0])
792
+ all_text_list.append(text)
793
+ all_mm_label_list.append('text')
794
+ all_trainable_flag_list.append(True)
795
+ else:
796
+ for text, match in zip(text_list, match_flag):
797
+ text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text)
798
+ if text.strip() == '':
799
+ continue # drop whitespace-only chunks
800
+ all_text_list.append(text)
801
+ all_mm_label_list.append('text')
802
+ all_trainable_flag_list.append(match)
803
+ # process the multimodal chunks
804
+ for mtype in self.config.multimodal: # iterate over modalities to collect audio / image results
805
+ all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
806
+ if len(all_text_list) == 0:
807
+ print(f"Process {text} chunk error: No valid Text data!!!!!")
808
+ return OmniProcessorOutput(index=index)
809
+
810
+ for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
811
+ try:
812
+ mret = self.process_multimodal_chunk(text, mm_label, trainable)
813
+ ret = ret.concatenate(mret)
814
+ except ValueError as e:
815
+ tt = text[:24].replace('\n','<LF>')
816
+ print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
817
+ return OmniProcessorOutput(index=index)
818
+
819
+ if raw_only:
820
+ ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
821
+ return ret
822
+ return ret
823
+
824
+ @torch.no_grad()
825
+ def __call__(self, example, parallel=128):
826
+ if isinstance(example, Dict):
827
+ pass
828
+ elif isinstance(example, str):
829
+ return self.process_one(example)
830
+ elif isinstance(example, List): # batch inference with asynchronous multithreaded processing
831
+ with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
832
+ future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
833
+ batch_data = [key.result() for key in cf.as_completed(future_list)]
834
+ valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
835
+ assert(valid_num == len(batch_data)) # inference strictly requires every sample to be valid
836
+ batch_data = sorted(batch_data, key=lambda x: x.index) # restore the original input order
837
+
838
+ ret = OmniProcessorOutput()
839
+ for i in range(len(batch_data)):
840
+ ret = ret.concatenate(batch_data[i])
841
+ self.tokenizer.padding_side = "left"
842
+ max_len = min(max([len(x.input_ids) for x in batch_data]),self.tokenizer.model_max_length)
843
+ padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
844
+ ret.input_ids = padding_result["input_ids"]
845
+ ret.attention_mask = padding_result["attention_mask"] # batch inference is not packed, so seqlens is not needed
846
+
847
+ if ret.audios is not None:
848
+ max_audios_len = max([x.shape[-1] for x in ret.audios])
849
+ ret.audios = default_collate([np.pad(x, ((0,0),(0,max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
850
+
851
+ ret.encoder_length = default_collate(ret.encoder_length)
852
+ ret.bridge_length = default_collate(ret.bridge_length)
853
+
854
+ if ret.images is not None:
855
+ ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
856
+ ret.patch_nums = default_collate(ret.patch_nums)
857
+
858
+ if ret.videos is not None:
859
+ ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
860
+ ret.videos_patch_nums = default_collate(ret.videos_patch_nums)
861
+
862
+ return ret
863
+
864
+ else:
865
+ raise ValueError("example format not supported yet")
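
A minimal usage sketch for OmniMMProcessor, assuming this repo's config and tokenizer load with trust_remote_code; the repo path, prompt text, and audio file are hypothetical, and the exact audio tag strings depend on the token ids set in audio_config:

    from transformers import AutoConfig, AutoTokenizer
    # from audio_modeling_omni import OmniMMProcessor  # the class defined above

    repo = "path/to/this/repo"  # hypothetical local path
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    processor = OmniMMProcessor(tokenizer=tokenizer, config=config, training=False)

    # an audio segment is passed inline as a JSON blob between the audio start/end tags
    prompt = ('<C_Q>Please transcribe this clip: '
              '<audio_start_baichuan>{"path": "sample.wav"}<audio_end_baichuan><C_A>')
    out = processor(prompt)  # an OmniProcessorOutput with input_ids, audios, bridge_length, ...
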
sequence_parallel_utils.py ADDED
@@ -0,0 +1,186 @@
1
+ from typing import Any, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+ from flash_attn import flash_attn_varlen_func
8
+ try:
9
+ import deepspeed.comm as dist
10
+ except Exception: # deepspeed is optional
11
+ dist = None
12
+
13
+
14
+ try:
15
+ from utils import (
16
+ get_sequence_parallel_group,
17
+ get_sequence_parallel_size,
18
+ get_sequence_parallel_rank
19
+ )
20
+ except (ModuleNotFoundError, ImportError):
21
+ # fetch the sequence-parallel settings from utils; if the import fails, default to disabled
22
+ get_sequence_parallel_group = lambda : None
23
+ get_sequence_parallel_size = lambda : 1
24
+ get_sequence_parallel_rank = lambda : 0
25
+
26
+
27
+ def single_all_to_all(input, scatter_idx, gather_idx, group):
28
+ seq_world_size = dist.get_world_size(group)
29
+ inp_shape = list(input.shape)
30
+ inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
31
+ if scatter_idx < 2:
32
+ input_t = input.reshape(
33
+ [seq_world_size, inp_shape[scatter_idx]] + \
34
+ inp_shape[scatter_idx + 1:]
35
+ ).contiguous()
36
+ else:
37
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
38
+ input_t = input.reshape(
39
+ [-1, seq_world_size, inp_shape[scatter_idx]] + \
40
+ inp_shape[scatter_idx + 1:]
41
+ ).transpose(0, 1).contiguous()
42
+
43
+ output = torch.empty_like(input_t)
44
+ dist.all_to_all_single(output, input_t, group=group)
45
+
46
+ # if scattering the seq-dim, transpose the heads back to the original dimension
47
+ # [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
48
+ # [seq_len//sp_size,batch_size, sp_size, head_num // sp_size, head_dim]
49
+ if scatter_idx < 2:
50
+ output = output.transpose(0, 1).transpose(1, 2).contiguous()
51
+
52
+ return output.reshape(
53
+ inp_shape[: gather_idx] + \
54
+ [inp_shape[gather_idx] * seq_world_size,] + \
55
+ inp_shape[gather_idx + 1:]).contiguous()
56
+
57
+
58
+ class _SeqAllToAll(torch.autograd.Function):
59
+
60
+ @staticmethod
61
+ def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
62
+ ctx.group = group
63
+ ctx.scatter_idx = scatter_idx
64
+ ctx.gather_idx = gather_idx
65
+
66
+ return single_all_to_all(input, scatter_idx, gather_idx, group)
67
+
68
+ @staticmethod
69
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
70
+ return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
71
+
72
+
73
+ # import from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
74
+ # but with fixes so the tensor layout matches this training setup
75
+ class DistributedAttention(nn.Module):
76
+ """Initialization.
77
+
78
+ Arguments:
79
+ local_attention (Module): local attention with q,k,v
80
+ sequence_process_group (ProcessGroup): sequence parallel process group
81
+ scatter_idx (int): scatter_idx for all2all comm
82
+ gather_idx (int): gather_idx for all2all comm
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ local_attention: nn.Module,
88
+ sequence_process_group: 'dist.ProcessGroup',
89
+ scatter_idx: int = 2,
90
+ gather_idx: int = 0,
91
+ ) -> None:
92
+
93
+ super(DistributedAttention, self).__init__()
94
+ self.local_attn = local_attention
95
+ self.spg = sequence_process_group
96
+ self.scatter_idx = scatter_idx
97
+ self.gather_idx = gather_idx
98
+
99
+ def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
100
+ # pad the head dimension of the inputs to a multiple of sp_size
101
+ sp_size = torch.distributed.get_world_size(self.spg)
102
+ pad_size = (sp_size - query.size(1) % sp_size) % sp_size
103
+ if pad_size > 0:
104
+ # [bs, num_head, seq_len, head_dim] -> [bs, num_head+pad_size, seq_len, head_dim]
105
+ query = torch.nn.functional.pad(query, (0,0,0,0,0,pad_size), value = 0.01)
106
+ key = torch.nn.functional.pad(key, (0,0,0,0,0,pad_size), value = 0.01)
107
+ value = torch.nn.functional.pad(value, (0,0,0,0,0,pad_size),value=0.0)
108
+ return query, key, value
109
+
110
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
111
+ """ forward
112
+
113
+ Arguments:
114
+ query (Tensor): query input to the layer [batch_size, num_head, seq_len, head_dim]
115
+ key (Tensor): key input to the layer
116
+ value (Tensor): value input to the layer
117
+ args: other args
118
+
119
+ Returns:
120
+ * output (Tensor): context output
121
+ """
122
+ # TODO Merge three alltoall calls into one
123
+ # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
124
+ # [batch_size, num_head, seq_len, head_dim] -> [seq_len, batch_size, num_head, head_dim]
125
+ origin_num_head = query.size(1)
126
+ query, key, value = self.pad_attention_head(query,key,value)
127
+
128
+ query = query.transpose(1,2).transpose(0,1)
129
+ key = key.transpose(1,2).transpose(0,1)
130
+ value = value.transpose(1,2).transpose(0,1)
131
+ #in shape : e.g., [s/p,bs,h,head_dim]
132
+ query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
133
+ key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
134
+ value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
135
+
136
+ context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
137
+ context_layer = context_layer.transpose(0,1).contiguous()
138
+ # [seq_len, batch_size, num_head, head_dim]
139
+ output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
140
+ return output.transpose(0,1)[:,:,:origin_num_head,:]
141
+
142
+
143
+ class LocalAttention(nn.Module):
144
+ def __init__(self, hidden_size, num_heads, head_dim):
145
+ super().__init__()
146
+ self.hidden_size = hidden_size
147
+ self.num_heads = num_heads
148
+ self.head_dim = head_dim
149
+
150
+ def forward(self, q, k, v, *args, use_flash=True, **kwargs):
151
+ # input q,k,v [batch_size, num_head, seq_len, head_dim]
152
+ # output [batch_size, seq_len, num_head, head_dim]
153
+ if use_flash:
154
+ q_len, num_heads = q.shape[2], q.shape[1]
155
+ q = q.transpose(1,2).reshape(-1, num_heads, self.head_dim)
156
+ k = k.transpose(1,2).reshape(-1, num_heads, self.head_dim)
157
+ v = v.transpose(1,2).reshape(-1, num_heads, self.head_dim)
158
+ return flash_attn_varlen_func(q,k,v,*args, **kwargs).reshape(-1,q_len, num_heads, self.head_dim)
159
+ else:
160
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
161
+ attn_output = F.scaled_dot_product_attention(
162
+ q,k,v, *args, **kwargs)
163
+ attn_output = attn_output.transpose(1, 2)
164
+ return attn_output
165
+
166
+
167
+ def create_attention_layer(hidden_size, num_heads, head_dim):
168
+ if get_sequence_parallel_group() is None:
169
+ return LocalAttention(hidden_size, num_heads, head_dim)
170
+ else:
171
+ return DistributedAttention(
172
+ local_attention=LocalAttention(hidden_size, num_heads, head_dim),
173
+ sequence_process_group=get_sequence_parallel_group()
174
+ )
175
+
176
+
177
+ def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
178
+ assert tensor.size(dim) % get_sequence_parallel_size() == 0
179
+ original_size = tensor.size(dim)
180
+ if shift:
181
+ tensor = tensor.split([shift, tensor.size(dim) - shift], dim=dim)[1]
182
+ if get_sequence_parallel_group() is None:
183
+ return tensor
184
+ else:
185
+ chunk_size = original_size // get_sequence_parallel_size()
186
+ return tensor.split(chunk_size, dim=dim)[get_sequence_parallel_rank()]
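
A minimal sketch of the fallback path above, assuming no sequence-parallel group is initialized (so create_attention_layer returns LocalAttention) and that flash_attn is installed, since this module imports it at load time; the sizes are made up:

    import torch
    from sequence_parallel_utils import create_attention_layer

    bs, num_heads, seq_len, head_dim = 2, 8, 16, 64  # hypothetical sizes
    attn = create_attention_layer(hidden_size=num_heads * head_dim,
                                  num_heads=num_heads, head_dim=head_dim)

    q = torch.randn(bs, num_heads, seq_len, head_dim)
    k = torch.randn(bs, num_heads, seq_len, head_dim)
    v = torch.randn(bs, num_heads, seq_len, head_dim)

    out = attn(q, k, v, use_flash=False)  # SDPA math path; no varlen metadata needed
    print(out.shape)  # [bs, seq_len, num_heads, head_dim]
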
special_tokens_map.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "<B_SYS>",
17
+ "<B_USYS>",
18
+ "<C_Q>",
19
+ "<C_A>",
20
+ "<B_FUNC>",
21
+ "<B_CODE>",
22
+ "<B_APE>",
23
+ "<function_calling>",
24
+ "<calc_start>",
25
+ "<calc_end>",
26
+ "<inner_think>",
27
+ "<audio_start_baichuan>",
28
+ "<audio_end_baichuan>",
29
+ "<audio_pad_baichuan>",
30
+ "<img_start_baichuan>",
31
+ "<img_end_baichuan>",
32
+ "<img_pad_baichuan>",
33
+ "<img_newline_baichuan>",
34
+ "<box_start_baichuan>",
35
+ "<box_end_baichuan>",
36
+ "<box_delim_baichuan>",
37
+ "<ref_start_baichuan>",
38
+ "<ref_end_baichuan>",
39
+ "<img_delim_baichuan>",
40
+ "<polygon_start_baichuan>",
41
+ "<polygon_end_baichuan>",
42
+ "<baichuan_pad_token>",
43
+ "<reserved_113>",
44
+ "<audio_delim_baichuan>",
45
+ "<video_start_baichuan>",
46
+ "<video_end_baichuan>",
47
+ "<video_palce_baichuan>",
48
+ "<audiotext_start_baichuan>",
49
+ "<audiotext_end_baichuan>",
50
+ "<audiotext_pad_baichuan>",
51
+ "<audiogen_start_baichuan>",
52
+ "<audiogen_end_baichuan>"
53
+ ],
54
+ "eos_token": {
55
+ "content": "<|endoftext|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false
60
+ },
61
+ "pad_token": {
62
+ "content": "<|endoftext|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false
67
+ }
68
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,540 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<B_SYS>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<B_USYS>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<C_Q>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "151668": {
206
+ "content": "<C_A>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": true
212
+ },
213
+ "151669": {
214
+ "content": "<B_FUNC>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "<B_CODE>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<B_APE>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": true,
235
+ "special": true
236
+ },
237
+ "151672": {
238
+ "content": "<function_calling>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": true,
243
+ "special": true
244
+ },
245
+ "151673": {
246
+ "content": "<calc_start>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": true,
251
+ "special": true
252
+ },
253
+ "151674": {
254
+ "content": "<calc_end>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": true,
259
+ "special": true
260
+ },
261
+ "151675": {
262
+ "content": "<inner_think>",
263
+ "lstrip": false,
264
+ "normalized": false,
265
+ "rstrip": false,
266
+ "single_word": true,
267
+ "special": true
268
+ },
269
+ "151676": {
270
+ "content": "<audio_start_baichuan>",
271
+ "lstrip": false,
272
+ "normalized": false,
273
+ "rstrip": false,
274
+ "single_word": false,
275
+ "special": true
276
+ },
277
+ "151677": {
278
+ "content": "<audio_end_baichuan>",
279
+ "lstrip": false,
280
+ "normalized": false,
281
+ "rstrip": false,
282
+ "single_word": false,
283
+ "special": true
284
+ },
285
+ "151678": {
286
+ "content": "<audio_pad_baichuan>",
287
+ "lstrip": false,
288
+ "normalized": false,
289
+ "rstrip": false,
290
+ "single_word": false,
291
+ "special": true
292
+ },
293
+ "151679": {
294
+ "content": "<img_start_baichuan>",
295
+ "lstrip": false,
296
+ "normalized": false,
297
+ "rstrip": false,
298
+ "single_word": false,
299
+ "special": true
300
+ },
301
+ "151680": {
302
+ "content": "<img_end_baichuan>",
303
+ "lstrip": false,
304
+ "normalized": false,
305
+ "rstrip": false,
306
+ "single_word": false,
307
+ "special": true
308
+ },
309
+ "151681": {
310
+ "content": "<img_pad_baichuan>",
311
+ "lstrip": false,
312
+ "normalized": false,
313
+ "rstrip": false,
314
+ "single_word": false,
315
+ "special": true
316
+ },
317
+ "151682": {
318
+ "content": "<img_newline_baichuan>",
319
+ "lstrip": false,
320
+ "normalized": false,
321
+ "rstrip": false,
322
+ "single_word": false,
323
+ "special": true
324
+ },
325
+ "151683": {
326
+ "content": "<box_start_baichuan>",
327
+ "lstrip": false,
328
+ "normalized": false,
329
+ "rstrip": false,
330
+ "single_word": false,
331
+ "special": true
332
+ },
333
+ "151684": {
334
+ "content": "<box_end_baichuan>",
335
+ "lstrip": false,
336
+ "normalized": false,
337
+ "rstrip": false,
338
+ "single_word": false,
339
+ "special": true
340
+ },
341
+ "151685": {
342
+ "content": "<box_delim_baichuan>",
343
+ "lstrip": false,
344
+ "normalized": false,
345
+ "rstrip": false,
346
+ "single_word": false,
347
+ "special": true
348
+ },
349
+ "151686": {
350
+ "content": "<ref_start_baichuan>",
351
+ "lstrip": false,
352
+ "normalized": false,
353
+ "rstrip": false,
354
+ "single_word": false,
355
+ "special": true
356
+ },
357
+ "151687": {
358
+ "content": "<ref_end_baichuan>",
359
+ "lstrip": false,
360
+ "normalized": false,
361
+ "rstrip": false,
362
+ "single_word": false,
363
+ "special": true
364
+ },
365
+ "151688": {
366
+ "content": "<img_delim_baichuan>",
367
+ "lstrip": false,
368
+ "normalized": false,
369
+ "rstrip": false,
370
+ "single_word": false,
371
+ "special": true
372
+ },
373
+ "151689": {
374
+ "content": "<polygon_start_baichuan>",
375
+ "lstrip": false,
376
+ "normalized": false,
377
+ "rstrip": false,
378
+ "single_word": false,
379
+ "special": true
380
+ },
381
+ "151690": {
382
+ "content": "<polygon_end_baichuan>",
383
+ "lstrip": false,
384
+ "normalized": false,
385
+ "rstrip": false,
386
+ "single_word": false,
387
+ "special": true
388
+ },
389
+ "151691": {
390
+ "content": "<baichuan_pad_token>",
391
+ "lstrip": false,
392
+ "normalized": false,
393
+ "rstrip": false,
394
+ "single_word": false,
395
+ "special": true
396
+ },
397
+ "151692": {
398
+ "content": "<reserved_113>",
399
+ "lstrip": false,
400
+ "normalized": false,
401
+ "rstrip": false,
402
+ "single_word": false,
403
+ "special": true
404
+ },
405
+ "151693": {
406
+ "content": "<audio_delim_baichuan>",
407
+ "lstrip": false,
408
+ "normalized": false,
409
+ "rstrip": false,
410
+ "single_word": false,
411
+ "special": true
412
+ },
413
+ "151694": {
414
+ "content": "<video_palce_baichuan>",
415
+ "lstrip": false,
416
+ "normalized": false,
417
+ "rstrip": false,
418
+ "single_word": false,
419
+ "special": true
420
+ },
421
+ "151695": {
422
+ "content": "<video_start_baichuan>",
423
+ "lstrip": false,
424
+ "normalized": false,
425
+ "rstrip": false,
426
+ "single_word": false,
427
+ "special": true
428
+ },
429
+ "151696": {
430
+ "content": "<video_end_baichuan>",
431
+ "lstrip": false,
432
+ "normalized": false,
433
+ "rstrip": false,
434
+ "single_word": false,
435
+ "special": true
436
+ },
437
+ "151697": {
438
+ "content": "<audiotext_start_baichuan>",
439
+ "lstrip": false,
440
+ "normalized": false,
441
+ "rstrip": false,
442
+ "single_word": false,
443
+ "special": true
444
+ },
445
+ "151698": {
446
+ "content": "<audiotext_end_baichuan>",
447
+ "lstrip": false,
448
+ "normalized": false,
449
+ "rstrip": false,
450
+ "single_word": false,
451
+ "special": true
452
+ },
453
+ "151699": {
454
+ "content": "<audiotext_pad_baichuan>",
455
+ "lstrip": false,
456
+ "normalized": false,
457
+ "rstrip": false,
458
+ "single_word": false,
459
+ "special": true
460
+ },
461
+ "151700": {
462
+ "content": "<audiogen_start_baichuan>",
463
+ "lstrip": false,
464
+ "normalized": false,
465
+ "rstrip": false,
466
+ "single_word": false,
467
+ "special": true
468
+ },
469
+ "151701": {
470
+ "content": "<audiogen_end_baichuan>",
471
+ "lstrip": false,
472
+ "normalized": false,
473
+ "rstrip": false,
474
+ "single_word": false,
475
+ "special": true
476
+ }
477
+ },
478
+ "additional_special_tokens": [
479
+ "<|im_start|>",
480
+ "<|im_end|>",
481
+ "<|object_ref_start|>",
482
+ "<|object_ref_end|>",
483
+ "<|box_start|>",
484
+ "<|box_end|>",
485
+ "<|quad_start|>",
486
+ "<|quad_end|>",
487
+ "<|vision_start|>",
488
+ "<|vision_end|>",
489
+ "<|vision_pad|>",
490
+ "<|image_pad|>",
491
+ "<|video_pad|>",
492
+ "<B_SYS>",
493
+ "<B_USYS>",
494
+ "<C_Q>",
495
+ "<C_A>",
496
+ "<B_FUNC>",
497
+ "<B_CODE>",
498
+ "<B_APE>",
499
+ "<function_calling>",
500
+ "<calc_start>",
501
+ "<calc_end>",
502
+ "<inner_think>",
503
+ "<audio_start_baichuan>",
504
+ "<audio_end_baichuan>",
505
+ "<audio_pad_baichuan>",
506
+ "<img_start_baichuan>",
507
+ "<img_end_baichuan>",
508
+ "<img_pad_baichuan>",
509
+ "<img_newline_baichuan>",
510
+ "<box_start_baichuan>",
511
+ "<box_end_baichuan>",
512
+ "<box_delim_baichuan>",
513
+ "<ref_start_baichuan>",
514
+ "<ref_end_baichuan>",
515
+ "<img_delim_baichuan>",
516
+ "<polygon_start_baichuan>",
517
+ "<polygon_end_baichuan>",
518
+ "<baichuan_pad_token>",
519
+ "<reserved_113>",
520
+ "<audio_delim_baichuan>",
521
+ "<video_start_baichuan>",
522
+ "<video_end_baichuan>",
523
+ "<video_palce_baichuan>",
524
+ "<audiotext_start_baichuan>",
525
+ "<audiotext_end_baichuan>",
526
+ "<audiotext_pad_baichuan>",
527
+ "<audiogen_start_baichuan>",
528
+ "<audiogen_end_baichuan>"
529
+ ],
530
+ "bos_token": null,
531
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
532
+ "clean_up_tokenization_spaces": false,
533
+ "eos_token": "<|endoftext|>",
534
+ "errors": "replace",
535
+ "model_max_length": 8192,
536
+ "pad_token": "<|endoftext|>",
537
+ "split_special_tokens": false,
538
+ "tokenizer_class": "Qwen2Tokenizer",
539
+ "unk_token": null
540
+ }
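
A short sketch of how the chat_template above renders a conversation; the repo path is hypothetical:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/this/repo")  # hypothetical path
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # -> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
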
vector_quantize.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch, random
2
+ from torch.nn import functional as F
3
+ from torch import nn
4
+ import numpy as np
5
+ from torch.cuda.amp import autocast
6
+
7
+ def uniform_init(*shape):
8
+ t = torch.zeros(shape)
9
+ nn.init.kaiming_uniform_(t)
10
+ return t
11
+
12
+ def cdist(x, y):
13
+ x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
14
+ y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
15
+ xy = torch.einsum('bd,cd->bc', x, y) * -2
16
+ return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
17
+
18
+ def get_sequence_mask(inputs, inputs_length):
19
+ if inputs.dim() == 3:
20
+ bsz, tgt_len, _ = inputs.size()
21
+ else:
22
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
23
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
24
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
25
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert the mask to flat indices
26
+ return sequence_mask, unpacking_index
27
+
28
+
29
+ class EuclideanCodebook(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim,
33
+ codebook_size,
34
+ init_std=0.02,
35
+ ):
36
+ super().__init__()
37
+ self.init_std = init_std
38
+ self.dim = dim
39
+ self.codebook_size = codebook_size
40
+
41
+ embed = uniform_init(codebook_size, dim).to(torch.float32)
42
+ self.cluster_size = nn.Parameter(torch.ones(codebook_size))
43
+ self.embed_avg = nn.Parameter(embed.clone())
44
+ self.embed = nn.Parameter(embed)
45
+ del embed
46
+
47
+ @autocast(enabled=True, dtype=torch.float32)
48
+ @torch.no_grad()
49
+ def forward(self, x):
50
+ assert(len(x.shape) == 2)
51
+ assert(x.dtype == torch.float32)
52
+ embed = self.embed.detach().to(x.device)
53
+ dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
54
+ embed_ind = dist.argmax(dim=-1)
55
+ quantize = embed[embed_ind] # (bs*sl, d)
56
+ return quantize, embed_ind, dist
57
+
58
+ class VectorQuantize(nn.Module):
59
+ def __init__(self, config, *args, **kwargs):
60
+ super().__init__(*args, **kwargs)
61
+ self.config = config
62
+ self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
63
+
64
+ def forward(self, x, input_length):
65
+ batch_size, seq_len, _ = x.shape
66
+ mask, unpacking_index = get_sequence_mask(x, input_length)
67
+ if x.dtype != torch.float32:
68
+ x = x.to(torch.float32)
69
+ x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
70
+ quantize, embed_ind, _ = self.codebook(x)
71
+ quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
72
+ quantize = torch.where(mask, quantize, 0)
73
+ embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
74
+ embed_ind = torch.where(mask, embed_ind, -1).squeeze()
75
+ return quantize, embed_ind
76
+
77
+ def get_output_from_indices(self, indices):
78
+ return self.codebook.embed[indices]
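
A small sketch of the quantizer above with a stand-in config object; the dimensions are made up:

    import torch
    from types import SimpleNamespace
    from vector_quantize import VectorQuantize

    cfg = SimpleNamespace(dim=128, codebook_size=512)  # hypothetical sizes
    vq = VectorQuantize(cfg)

    x = torch.randn(2, 10, cfg.dim)        # (batch, seq_len, dim)
    lengths = torch.tensor([10, 7])        # valid length of each sequence
    quantized, codes = vq(x, lengths)
    print(quantized.shape, codes.shape)    # (2, 10, 128) and (2, 10); padded positions are 0 / -1
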
visual_modeling_omni.py ADDED
@@ -0,0 +1,87 @@
1
+
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch, math
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ import transformers
7
+ from flash_attn import flash_attn_varlen_func
8
+ from transformers.activations import ACT2FN
9
+ from PIL import Image
10
+ import io, fire
11
+ from torch.nn import functional as F
12
+
13
+ class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.config_attn_implementation = 'flash_attention_2'
17
+ self.gradient_checkpointing = True # always enabled
18
+ self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
19
+ self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
20
+ del self.merger
21
+
22
+ def forward(
23
+ self,
24
+ pixel_values: torch.Tensor,
25
+ grid_thw: torch.Tensor,
26
+ ):
27
+ hidden_states = pixel_values.to(self.get_dtype())
28
+ grid_thw = grid_thw.to(pixel_values.device)
29
+
30
+ hidden_states = self.patch_embed(hidden_states)
31
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
32
+
33
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
34
+ dim=0, dtype=torch.int32
35
+ )
36
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
37
+
38
+ for blk in self.blocks:
39
+ if self.gradient_checkpointing and self.training:
40
+ hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
41
+ else:
42
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
43
+
44
+ return hidden_states
45
+
46
+ @torch.no_grad()
47
+ def fake_input(self, device):
48
+ merge_size = max(self.merge_size, self.config.spatial_merge_size)
49
+ fake_image = torch.zeros([
50
+ 1,
51
+ self.config.temporal_patch_size,
52
+ 3,
53
+ merge_size // self.config.spatial_merge_size,
54
+ self.config.spatial_merge_size,
55
+ self.config.patch_size,
56
+ merge_size // self.config.spatial_merge_size,
57
+ self.config.spatial_merge_size,
58
+ self.config.patch_size,
59
+ ], dtype=torch.float32, device=device)
60
+ patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
61
+ flatten_patches = patches.reshape(
62
+ merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
63
+ )
64
+ return [flatten_patches], [(1, merge_size, merge_size)], [1]
65
+
66
+
67
+ class OmniVisualBridge(nn.Module):
68
+ def __init__(self, config):
69
+ super().__init__()
70
+ self.config = config
71
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
72
+ self.hidden_size = config.embed_dim * (self.merge_size**2)
73
+ self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
74
+ self.mlp = nn.Sequential(
75
+ nn.Linear(self.hidden_size, self.hidden_size),
76
+ nn.GELU(),
77
+ nn.Linear(self.hidden_size, config.hidden_size),
78
+ )
79
+
80
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
81
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
82
+ return x
83
+
84
+
85
+ if __name__ == '__main__':
86
+ fire.Fire()
87
+
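A small sketch of the visual bridge above, which merges merge_size**2 adjacent ViT patch embeddings into one LLM-width token; the widths are made up, and importing the module assumes flash_attn is installed since it is imported at load time:

    import torch
    from types import SimpleNamespace
    from visual_modeling_omni import OmniVisualBridge

    cfg = SimpleNamespace(embed_dim=1280, hidden_size=3584, merge_size=2)  # hypothetical widths
    bridge = OmniVisualBridge(cfg)

    vit_tokens = torch.randn(16, cfg.embed_dim)  # 16 patch embeddings from the encoder
    llm_tokens = bridge(vit_tokens)              # 16 / merge_size**2 = 4 tokens of width hidden_size
    print(llm_tokens.shape)                      # torch.Size([4, 3584])
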
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
zero_to_fp32.py ADDED
@@ -0,0 +1,604 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example: python zero_to_fp32.py . pytorch_model.bin
14
+
15
+ import argparse
16
+ import torch
17
+ import glob
18
+ import math
19
+ import os
20
+ import re
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass
23
+
24
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
25
+ # DeepSpeed data structures it has to be available in the current python environment.
26
+ from deepspeed.utils import logger
27
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
28
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
29
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
30
+
31
+
32
+ @dataclass
33
+ class zero_model_state:
34
+ buffers: dict()
35
+ param_shapes: dict()
36
+ shared_params: list
37
+ ds_version: int
38
+ frozen_param_shapes: dict()
39
+ frozen_param_fragments: dict()
40
+
41
+
42
+ debug = 0
43
+
44
+ # load to cpu
45
+ device = torch.device('cpu')
46
+
47
+
48
+ def atoi(text):
49
+ return int(text) if text.isdigit() else text
50
+
51
+
52
+ def natural_keys(text):
53
+ '''
54
+ alist.sort(key=natural_keys) sorts in human order
55
+ http://nedbatchelder.com/blog/200712/human_sorting.html
56
+ (See Toothy's implementation in the comments)
57
+ '''
58
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
59
+
60
+
61
+ def get_model_state_file(checkpoint_dir, zero_stage):
62
+ if not os.path.isdir(checkpoint_dir):
63
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
64
+
65
+ # there should be only one file
66
+ if zero_stage <= 2:
67
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
68
+ elif zero_stage == 3:
69
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
70
+
71
+ if not os.path.exists(file):
72
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
73
+
74
+ return file
75
+
76
+
77
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
78
+ # XXX: need to test that this simple glob rule works for multi-node setup too
79
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
80
+
81
+ if len(ckpt_files) == 0:
82
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
83
+
84
+ return ckpt_files
85
+
86
+
87
+ def get_optim_files(checkpoint_dir):
88
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
89
+
90
+
91
+ def get_model_state_files(checkpoint_dir):
92
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
93
+
94
+
95
+ def parse_model_states(files):
96
+ zero_model_states = []
97
+ for file in files:
98
+ state_dict = torch.load(file, map_location=device)
99
+
100
+ if BUFFER_NAMES not in state_dict:
101
+ raise ValueError(f"{file} is not a model state checkpoint")
102
+ buffer_names = state_dict[BUFFER_NAMES]
103
+ if debug:
104
+ print("Found buffers:", buffer_names)
105
+
106
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
107
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
108
+ param_shapes = state_dict[PARAM_SHAPES]
109
+
110
+ # collect parameters that are included in param_shapes
111
+ param_names = []
112
+ for s in param_shapes:
113
+ for name in s.keys():
114
+ param_names.append(name)
115
+
116
+ # update with frozen parameters
117
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
118
+ if frozen_param_shapes is not None:
119
+ if debug:
120
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
121
+ param_names += list(frozen_param_shapes.keys())
122
+
123
+ # handle shared params
124
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
125
+
126
+ ds_version = state_dict.get(DS_VERSION, None)
127
+
128
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
129
+
130
+ z_model_state = zero_model_state(buffers=buffers,
131
+ param_shapes=param_shapes,
132
+ shared_params=shared_params,
133
+ ds_version=ds_version,
134
+ frozen_param_shapes=frozen_param_shapes,
135
+ frozen_param_fragments=frozen_param_fragments)
136
+ zero_model_states.append(z_model_state)
137
+
138
+ return zero_model_states
139
+
140
+
141
+ def parse_optim_states(files, ds_checkpoint_dir):
142
+
143
+ total_files = len(files)
144
+ state_dicts = []
145
+ for f in files:
146
+ state_dict = torch.load(f, map_location=device)
147
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
148
+ # and also handle the case where it was already removed by another helper script
149
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
150
+ state_dicts.append(state_dict)
151
+
152
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
153
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
154
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
155
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
156
+
157
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
158
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
159
+ # use the max of the partition_count to get the dp world_size.
160
+
161
+ if type(world_size) is list:
162
+ world_size = max(world_size)
163
+
164
+ if world_size != total_files:
165
+ raise ValueError(
166
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
167
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
168
+ )
169
+
170
+ # the groups are named differently in each stage
171
+ if zero_stage <= 2:
172
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
173
+ elif zero_stage == 3:
174
+ fp32_groups_key = FP32_FLAT_GROUPS
175
+ else:
176
+ raise ValueError(f"unknown zero stage {zero_stage}")
177
+
178
+ if zero_stage <= 2:
179
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
180
+ elif zero_stage == 3:
181
+ # if there is more than one param group, there will be multiple flattened tensors - one
182
+ # flattened tensor per group - for simplicity merge them into a single tensor
183
+ #
184
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
185
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
186
+
187
+ fp32_flat_groups = [
188
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
189
+ ]
190
+
191
+ return zero_stage, world_size, fp32_flat_groups
192
+
193
+
194
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
195
+ """
196
+ Returns fp32 state_dict reconstructed from ds checkpoint
197
+
198
+ Args:
199
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
200
+
201
+ """
202
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
203
+
204
+ optim_files = get_optim_files(ds_checkpoint_dir)
205
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
206
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
207
+
208
+ model_files = get_model_state_files(ds_checkpoint_dir)
209
+
210
+ zero_model_states = parse_model_states(model_files)
211
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
212
+
213
+ if zero_stage <= 2:
214
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
215
+ exclude_frozen_parameters)
216
+ elif zero_stage == 3:
217
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
218
+ exclude_frozen_parameters)
219
+
220
+
221
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
222
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
223
+ return
224
+
225
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
226
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
227
+
228
+ if debug:
229
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
230
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
231
+
232
+ wanted_params = len(frozen_param_shapes)
233
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
234
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
235
+ print(f'Frozen params: Have {avail_numel} numels to process.')
236
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
237
+
238
+ total_params = 0
239
+ total_numel = 0
240
+ for name, shape in frozen_param_shapes.items():
241
+ total_params += 1
242
+ unpartitioned_numel = shape.numel()
243
+ total_numel += unpartitioned_numel
244
+
245
+ state_dict[name] = frozen_param_fragments[name]
246
+
247
+ if debug:
248
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
249
+
250
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
251
+
252
+
253
+ def _has_callable(obj, fn):
254
+ attr = getattr(obj, fn, None)
255
+ return callable(attr)
256
+
257
+
258
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
259
+ param_shapes = zero_model_states[0].param_shapes
260
+
261
+ # Reconstruction protocol:
262
+ #
263
+ # XXX: document this
264
+
265
+ if debug:
266
+ for i in range(world_size):
267
+ for j in range(len(fp32_flat_groups[0])):
268
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
269
+
270
+ # XXX: memory usage doubles here (zero2)
271
+ num_param_groups = len(fp32_flat_groups[0])
272
+ merged_single_partition_of_fp32_groups = []
273
+ for i in range(num_param_groups):
274
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
275
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
276
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
277
+ avail_numel = sum(
278
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
279
+
280
+ if debug:
281
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
282
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
283
+ # not asserting if there is a mismatch due to possible padding
284
+ print(f"Have {avail_numel} numels to process.")
285
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
286
+
287
+ # params
288
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
289
+ # out-of-core computing solution
290
+ total_numel = 0
291
+ total_params = 0
292
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
293
+ offset = 0
294
+ avail_numel = full_single_fp32_vector.numel()
295
+ for name, shape in shapes.items():
296
+
297
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
298
+ total_numel += unpartitioned_numel
299
+ total_params += 1
300
+
301
+ if debug:
302
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
303
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
304
+ offset += unpartitioned_numel
305
+
306
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
307
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
308
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
309
+ # live optimizer object, so we are checking that the numbers are within the right range
310
+ align_to = 2 * world_size
311
+
312
+ def zero2_align(x):
313
+ return align_to * math.ceil(x / align_to)
314
+
315
+ if debug:
316
+ print(f"original offset={offset}, avail_numel={avail_numel}")
317
+
318
+ offset = zero2_align(offset)
319
+ avail_numel = zero2_align(avail_numel)
320
+
321
+ if debug:
322
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
323
+
324
+ # Sanity check
325
+ if offset != avail_numel:
326
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
327
+
328
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
329
+
330
+
331
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
332
+ exclude_frozen_parameters):
333
+ state_dict = OrderedDict()
334
+
335
+ # buffers
336
+ buffers = zero_model_states[0].buffers
337
+ state_dict.update(buffers)
338
+ if debug:
339
+ print(f"added {len(buffers)} buffers")
340
+
341
+ if not exclude_frozen_parameters:
342
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
343
+
344
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
345
+
346
+ # recover shared parameters
347
+ for pair in zero_model_states[0].shared_params:
348
+ if pair[1] in state_dict:
349
+ state_dict[pair[0]] = state_dict[pair[1]]
350
+
351
+ return state_dict
352
+
353
+
354
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
355
+ remainder = unpartitioned_numel % world_size
356
+ padding_numel = (world_size - remainder) if remainder else 0
357
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
358
+ return partitioned_numel, padding_numel
359
+
360
+
361
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
362
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
363
+ return
364
+
365
+ if debug:
366
+ for i in range(world_size):
367
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
368
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
369
+
370
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
371
+ wanted_params = len(frozen_param_shapes)
372
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
373
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
374
+ print(f'Frozen params: Have {avail_numel} numels to process.')
375
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
376
+
377
+ total_params = 0
378
+ total_numel = 0
379
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
380
+ total_params += 1
381
+ unpartitioned_numel = shape.numel()
382
+ total_numel += unpartitioned_numel
383
+
384
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
385
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
386
+
387
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
388
+
389
+ if debug:
390
+ print(
391
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
392
+ )
393
+
394
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
395
+
396
+
397
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
398
+ param_shapes = zero_model_states[0].param_shapes
399
+ avail_numel = fp32_flat_groups[0].numel() * world_size
400
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
401
+ # param, re-consolidating each param, while dealing with padding if any
402
+
403
+ # merge list of dicts, preserving order
404
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
405
+
406
+ if debug:
407
+ for i in range(world_size):
408
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
409
+
410
+ wanted_params = len(param_shapes)
411
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
412
+ # not asserting if there is a mismatch due to possible padding
413
+ avail_numel = fp32_flat_groups[0].numel() * world_size
414
+ print(f"Trainable params: Have {avail_numel} numels to process.")
415
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
416
+
417
+ # params
418
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
419
+ # out-of-core computing solution
420
+ offset = 0
421
+ total_numel = 0
422
+ total_params = 0
423
+ for name, shape in param_shapes.items():
424
+
425
+ unpartitioned_numel = shape.numel()
426
+ total_numel += unpartitioned_numel
427
+ total_params += 1
428
+
429
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
430
+
431
+ if debug:
432
+ print(
433
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
434
+ )
435
+
436
+ # XXX: memory usage doubles here
437
+ state_dict[name] = torch.cat(
438
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
439
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
440
+ offset += partitioned_numel
441
+
442
+ offset *= world_size
443
+
444
+ # Sanity check
445
+ if offset != avail_numel:
446
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
447
+
448
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
449
+
450
+
451
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
452
+ exclude_frozen_parameters):
453
+ state_dict = OrderedDict()
454
+
455
+ # buffers
456
+ buffers = zero_model_states[0].buffers
457
+ state_dict.update(buffers)
458
+ if debug:
459
+ print(f"added {len(buffers)} buffers")
460
+
461
+ if not exclude_frozen_parameters:
462
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
463
+
464
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
465
+
466
+ # recover shared parameters
467
+ for pair in zero_model_states[0].shared_params:
468
+ if pair[1] in state_dict:
469
+ state_dict[pair[0]] = state_dict[pair[1]]
470
+
471
+ return state_dict
472
+
473
+
474
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
475
+ """
476
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
477
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
478
+ via a model hub.
479
+
480
+ Args:
481
+ - ``checkpoint_dir``: path to the desired checkpoint folder
482
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
483
+ - ``exclude_frozen_parameters``: exclude frozen parameters
484
+
485
+ Returns:
486
+ - pytorch ``state_dict``
487
+
488
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
489
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
490
+ the checkpoint.
491
+
492
+ A typical usage might be ::
493
+
494
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
495
+ # do the training and checkpoint saving
496
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
497
+ model = model.cpu() # move to cpu
498
+ model.load_state_dict(state_dict)
499
+ # submit to model hub or save the model to share with others
500
+
501
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
502
+ application. i.e. you will need to re-initialize the deepspeed engine, since
503
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
504
+
505
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
506
+
507
+ """
508
+ if tag is None:
509
+ latest_path = os.path.join(checkpoint_dir, 'latest')
510
+ if os.path.isfile(latest_path):
511
+ with open(latest_path, 'r') as fd:
512
+ tag = fd.read().strip()
513
+ else:
514
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
515
+
516
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
517
+
518
+ if not os.path.isdir(ds_checkpoint_dir):
519
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
520
+
521
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
522
+
523
+
524
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
525
+ """
526
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
527
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
528
+
529
+ Args:
530
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
531
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
532
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
533
+ - ``exclude_frozen_parameters``: exclude frozen parameters
534
+ """
535
+
536
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
537
+ print(f"Saving fp32 state dict to {output_file}")
538
+ torch.save(state_dict, output_file)
539
+
540
+
541
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
542
+ """
543
+ 1. Put the provided model to cpu
544
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
545
+ 3. Load it into the provided model
546
+
547
+ Args:
548
+ - ``model``: the model object to update
549
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
550
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
551
+
552
+ Returns:
553
+ - ``model``: modified model
554
+
555
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
556
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
557
+ conveniently placed for you in the checkpoint folder.
558
+
559
+ A typical usage might be ::
560
+
561
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
562
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
563
+ # submit to model hub or save the model to share with others
564
+
565
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
566
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
567
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
568
+
569
+ """
570
+ logger.info(f"Extracting fp32 weights")
571
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
572
+
573
+ logger.info(f"Overwriting model with fp32 weights")
574
+ model = model.cpu()
575
+ model.load_state_dict(state_dict, strict=False)
576
+
577
+ return model
578
+
579
+
580
+ if __name__ == "__main__":
581
+
582
+ parser = argparse.ArgumentParser()
583
+ parser.add_argument("checkpoint_dir",
584
+ type=str,
585
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
586
+ parser.add_argument(
587
+ "output_file",
588
+ type=str,
589
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
590
+ parser.add_argument("-t",
591
+ "--tag",
592
+ type=str,
593
+ default=None,
594
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
595
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
596
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
597
+ args = parser.parse_args()
598
+
599
+ debug = args.debug
600
+
601
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
602
+ args.output_file,
603
+ tag=args.tag,
604
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
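
Once the upload is pulled locally together with a DeepSpeed checkpoint, the consolidation can be run exactly as the header comment shows (python zero_to_fp32.py <checkpoint_dir> <output_file>) or programmatically. The sketch below is illustrative only: the checkpoint path is a placeholder, and it assumes the checkpoint folder contains either a 'latest' file or the tag directory itself.

# Hypothetical paths; adjust to the actual checkpoint layout.
from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    "path/to/deepspeed-checkpoint",    # folder that holds the global_step*/ tag dir
    "pytorch_model.bin",               # consolidated fp32 output file
    tag=None,                          # None -> read the tag from the 'latest' file
    exclude_frozen_parameters=False,   # keep frozen parameters in the output
)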