Upload folder using huggingface_hub
- added_tokens.json +61 -0
- audio_modeling_omni.py +658 -0
- config.json +255 -0
- configuration_omni.py +120 -0
- flow_matching.py +791 -0
- generation_config.json +6 -0
- generation_utils.py +83 -0
- matcha_components.py +189 -0
- matcha_feat.py +107 -0
- matcha_transformer.py +480 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_omni.py +1011 -0
- processor_omni.py +865 -0
- sequence_parallel_utils.py +186 -0
- special_tokens_map.json +68 -0
- tokenizer.json +0 -0
- tokenizer_config.json +540 -0
- vector_quantize.py +78 -0
- visual_modeling_omni.py +87 -0
- vocab.json +0 -0
- zero_to_fp32.py +604 -0
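The commit message indicates the folder was pushed with the huggingface_hub client. Below is a minimal sketch of that workflow, not the exact command used here; the local path and repo id are hypothetical placeholders.

from huggingface_hub import HfApi

api = HfApi()
# Push every file in a local checkpoint directory as a single commit.
# "./checkpoint" and "org/omni-model" are hypothetical placeholders.
api.upload_folder(
    folder_path="./checkpoint",
    repo_id="org/omni-model",
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)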
added_tokens.json
ADDED
@@ -0,0 +1,61 @@
{
  "</tool_call>": 151658,
  "<B_APE>": 151671,
  "<B_CODE>": 151670,
  "<B_FUNC>": 151669,
  "<B_SYS>": 151665,
  "<B_USYS>": 151666,
  "<C_A>": 151668,
  "<C_Q>": 151667,
  "<audio_delim_baichuan>": 151693,
  "<audio_end_baichuan>": 151677,
  "<audio_pad_baichuan>": 151678,
  "<audio_start_baichuan>": 151676,
  "<audiogen_end_baichuan>": 151701,
  "<audiogen_start_baichuan>": 151700,
  "<audiotext_end_baichuan>": 151698,
  "<audiotext_pad_baichuan>": 151699,
  "<audiotext_start_baichuan>": 151697,
  "<baichuan_pad_token>": 151691,
  "<box_delim_baichuan>": 151685,
  "<box_end_baichuan>": 151684,
  "<box_start_baichuan>": 151683,
  "<calc_end>": 151674,
  "<calc_start>": 151673,
  "<function_calling>": 151672,
  "<img_delim_baichuan>": 151688,
  "<img_end_baichuan>": 151680,
  "<img_newline_baichuan>": 151682,
  "<img_pad_baichuan>": 151681,
  "<img_start_baichuan>": 151679,
  "<inner_think>": 151675,
  "<polygon_end_baichuan>": 151690,
  "<polygon_start_baichuan>": 151689,
  "<ref_end_baichuan>": 151687,
  "<ref_start_baichuan>": 151686,
  "<reserved_113>": 151692,
  "<tool_call>": 151657,
  "<video_end_baichuan>": 151696,
  "<video_palce_baichuan>": 151694,
  "<video_start_baichuan>": 151695,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
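These added tokens are the multimodal markers the processor inserts around audio, image, and video spans. A small sketch, assuming the repo has been downloaded locally (the "./" path is a placeholder), of how they resolve to the ids declared above:

from transformers import AutoTokenizer

# Load the tokenizer shipped in this repo; trust_remote_code pulls in the custom processor code.
tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)

# The added special tokens map to the ids listed in added_tokens.json.
print(tokenizer.convert_tokens_to_ids("<audio_start_baichuan>"))   # expected: 151676
print(tokenizer.convert_tokens_to_ids("<audiogen_end_baichuan>"))  # expected: 151701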
audio_modeling_omni.py
ADDED
@@ -0,0 +1,658 @@
import torch, fire
from typing import Optional
import torch.distributed
from torch.nn import functional as F
from flash_attn import flash_attn_varlen_func
from torch import nn
import numpy as np
import deepspeed
from transformers.activations import ACT2FN
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
try:
    from .vector_quantize import VectorQuantize
except:
    from vector_quantize import VectorQuantize

from .flow_matching import (
    ConditionalDecoder,
    ConditionalCFM,
)

import math
import copy


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


def get_sequence_mask(inputs, inputs_length):
    if inputs.dim() == 3:
        bsz, tgt_len, _ = inputs.size()
    else:
        bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
    sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
    sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
    unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1  # convert the mask into gather indices
    return sequence_mask, unpacking_index


def unpack_hidden_states(hidden_states, lengths):
    bsz = lengths.shape[0]
    sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
    hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
        bsz, torch.max(lengths), hidden_states.shape[-1]
    )
    hidden_states = torch.where(
        sequence_mask, hidden_states, 0
    )  # 3d (bsz, max_input_len, d)
    return hidden_states


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class OmniWhisperAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, causal=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.causal = causal

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        bsz, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)

        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
        max_seqlen = torch.max(seq_len).to(torch.int32).detach()
        attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
                                             max_seqlen, causal=self.causal)  # (bsz * qlen, nheads, headdim)
        attn_output = attn_output.reshape(bsz, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output


class OmniWhisperTransformerLayer(nn.Module):
    def __init__(
        self,
        act,
        d_model,
        encoder_attention_heads,
        encoder_ffn_dim,
        causal,
        ln_type="LayerNorm",
    ):
        super().__init__()
        self.embed_dim = d_model
        self.self_attn = OmniWhisperAttention(
            self.embed_dim, encoder_attention_heads, causal
        )

        if ln_type == "LayerNorm":
            self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        elif ln_type == "RMSNorm":
            self.self_attn_layer_norm = RMSNorm(self.embed_dim)
        else:
            raise ValueError(f"Unknown ln_type: {ln_type}")

        self.activation_fn = act
        self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
        self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)

        if ln_type == "LayerNorm":
            self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        elif ln_type == "RMSNorm":
            self.final_layer_norm = RMSNorm(self.embed_dim)
        else:
            raise ValueError(f"Unknown ln_type: {ln_type}")

    def forward(
        self, hidden_states: torch.Tensor, seq_len: torch.Tensor
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, seq_len)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        if (
            hidden_states.dtype == torch.float16
            or hidden_states.dtype == torch.bfloat16
        ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )
        return hidden_states


class OmniAudioEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        config._attn_implementation = 'flash_attention_2'  # force the flash-attention implementation
        self.config = config
        self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
        self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
                               stride=config.stride_size, padding=1)
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model))  # 1500 * d

        self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
            ACT2FN[config.activation_function],
            config.d_model,
            config.encoder_attention_heads,
            config.encoder_ffn_dim,
            False) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

    @torch.no_grad()
    def fake_input(self, device):
        input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
        encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
        bridge_length = torch.ones([2], dtype=torch.int32, device=device)
        return input_features, encoder_length, bridge_length

    def forward(
        self,
        input_features,
        output_length,
    ):
        input_features = input_features.to(self.conv1.weight.dtype)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))  # (bs, channels, frames)
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))  # (bs, channels, frames // 2)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)  # (bs, frames, channels)
        bsz, tgt_len, _ = inputs_embeds.size()
        if tgt_len < self.positional_embedding.shape[0]:
            current_positional_embedding = self.positional_embedding[:tgt_len]
        else:
            current_positional_embedding = self.positional_embedding
        hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)

        # packing hidden states
        attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
                                                                                self.config.d_model)

        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(hidden_states, output_length)
        hidden_states = self.layer_norm(hidden_states)
        # unpacking
        hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
        hidden_states = torch.where(attention_mask, hidden_states, 0)
        return hidden_states


class CasualConvTranspose1d(nn.Module):  # causal transposed convolution
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        self.norm = nn.GroupNorm(1, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, hidden_states, input_length, output_dim=None):
        kernel_size = self.conv.kernel_size[0]
        stride = self.conv.stride[0]
        bsz = input_length.shape[0]

        if output_dim is None:
            output_dim = hidden_states.dim()
        if hidden_states.dim() <= 2:  # unpack sequence to 3d
            sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
            hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
                                                                                       self.in_channels)
            hidden_states = torch.where(sequence_mask, hidden_states, 0)  # 3d (bsz, max_input_len, d)

        hidden_states = hidden_states.transpose(2, 1)  # (N, L, C) -> (N, C, L)
        hidden_states = self.conv(hidden_states)
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.transpose(2, 1)  # (N, C, L) -> (N, L, C)

        casual_padding_right = max(0, kernel_size - stride)
        hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right, :]
        output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
        sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
        if output_dim <= 2:
            hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
        else:
            hidden_states = torch.where(sequence_mask, hidden_states, 0)
            hidden_states = hidden_states[:, :torch.max(output_length), :]  # truncate to the maximum valid length
        return hidden_states, output_length


class MelSpecRefineNet(nn.Module):
    """
    # post net, coarse to refined mel-spectrogram frames
    # ref1: Autoregressive Speech Synthesis without Vector Quantization
    # ref2: CosyVoice length_regulator.py
    # ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
    """

    def __init__(self, encoder_config, vocoder_config):
        super().__init__()
        self.encoder_config = encoder_config
        self.vocoder_config = vocoder_config

        layers = nn.ModuleList([])
        in_channels = self.vocoder_config.num_mel_bins
        for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
            module = nn.Conv1d(in_channels, out_channels, 5, 1, 2)  # cosyvoice kernel=3, stride=1, pad=1
            in_channels = out_channels
            norm = nn.GroupNorm(1, out_channels)
            act = nn.Mish()
            layers.extend([module, norm, act])
        layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1))  # projector
        self.layers = nn.Sequential(*layers)

    def compute_output_length(self, input_length):
        output_length = input_length.to(
            torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
        output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
        return output_length.to(torch.int64)

    def forward(self, coarse_mel, input_length, output_length=None):
        bsz, _, d = coarse_mel.shape
        assert (d == self.vocoder_config.num_mel_bins)
        if output_length is None or not self.training:
            output_length = self.compute_output_length(input_length)
        coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
        coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
                                   mode='nearest').to(default_dtype)
        refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous()  # (bs, t, d)
        coarse_mel = coarse_mel.transpose(1, 2)  # (bs, max(output_length), d)
        refined_mel += coarse_mel  # residual connection
        sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
        coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
        refined_mel = torch.where(sequence_mask, refined_mel, 0)
        return refined_mel, coarse_mel, output_length


@dataclass
class OmniAudioDecoderOutput(ModelOutput):
    refined_mel: Optional[torch.FloatTensor] = None
    coarse_mel: Optional[torch.FloatTensor] = None
    mel_length: Optional[torch.Tensor] = None
    hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
    output_length_before_dconv2: Optional[torch.Tensor] = None


class OmniAudioDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config.audio_config
        self.vocoder_config = config.vocoder_config
        self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length

        self.dconv1 = CasualConvTranspose1d(
            self.config.d_model,
            self.config.d_model,
            self.config.decoder_kernel_size,
            self.config.avg_pooler,
        )
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
        # causal transformer layers
        self.layers = nn.ModuleList(
            [OmniWhisperTransformerLayer(
                ACT2FN[self.config.activation_function],
                self.config.d_model,
                self.config.decoder_attention_heads,
                self.config.decoder_ffn_dim,
                True  # causal
            ) for _ in range(self.config.decoder_layers)
            ])
        self.layer_norm = nn.LayerNorm(self.config.d_model)
        self.dconv2 = CasualConvTranspose1d(
            self.config.d_model,
            self.vocoder_config.num_mel_bins,
            self.config.decoder_kernel_size,
            self.config.decoder_stride_size
        )
        self.post_net = MelSpecRefineNet(config.audio_config, config.vocoder_config)
        self.gradient_checkpointing = True

    @torch.no_grad()
    def fake_input(self, device):
        audio_embed = torch.rand([1, 10, self.config.d_model], dtype=torch.float32, device=device)
        input_length = torch.ones([1], dtype=torch.int32, device=device) * 10
        mel_labels_length = self.post_net.compute_output_length(input_length)
        return audio_embed, input_length, None, mel_labels_length

    def forward(self,
                audio_embed,
                input_length,
                mel_labels=None,
                mel_labels_length=None,
                fake_input=False,
                ):
        if fake_input:
            audio_embed, input_length, mel_labels, mel_labels_length = self.fake_input(self.layer_norm.weight.device)

        assert (audio_embed.shape[-1] == self.config.d_model)
        audio_embed = audio_embed.to(self.layer_norm.weight)  # device and type
        audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3)  # (b, l*2, d_model)
        _, tgt_len, _ = audio_embed.size()
        if tgt_len < self.positional_embedding.shape[0]:
            current_positional_embedding = self.positional_embedding[:tgt_len]
        else:
            current_positional_embedding = self.positional_embedding
        hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)

        # packing hidden states
        attention_mask, _ = get_sequence_mask(hidden_states, output_length)
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)

        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(hidden_states, output_length)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states_before_dconv2 = hidden_states
        output_length_before_dconv2 = output_length

        coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
        refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)

        return OmniAudioDecoderOutput(
            refined_mel=refined_mel,
            coarse_mel=coarse_mel,
            mel_length=mel_labels_length,
            hidden_states_before_dconv2=hidden_states_before_dconv2,
            output_length_before_dconv2=output_length_before_dconv2,
        )


class OmniAudioVQBridgeTokenizer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config.audio_config
        self.gradient_checkpointing = False
        self.intermediate_dim = self.config.d_model * self.config.avg_pooler
        self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
        self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)

        self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
        self.act_fn = ACT2FN['silu']
        self.layer_norm = nn.LayerNorm(self.intermediate_dim)
        self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)

        self.vq_list = nn.ModuleList([])
        for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
            vq_config = copy.deepcopy(self.config.vq_config)
            vq_config.dim = self.intermediate_dim
            vq_config.codebook_size = codebook_size
            self.vq_list.append(VectorQuantize(vq_config))
        for vq_layer in self.vq_list:
            deepspeed.zero.register_external_parameter(self, vq_layer.codebook.embed)

    def rvq_op(self, inputs, output_length):
        def rvq_layer_op(vq_layer, residual_encoding, output_length):
            q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
            residual_encoding = residual_encoding.float() - q_v_i.float()
            residual_encoding = residual_encoding.to(inputs.dtype)
            return residual_encoding, code_ids_i

        cmt_loss, residual_encoding = 0, inputs
        code_ids_list = []
        for i, vq_layer in enumerate(self.vq_list):
            residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
            code_ids_list.append(code_ids_i)
        return torch.stack(code_ids_list, -1)

    def forward(self, x, output_length):
        batch_size, _, _ = x.shape
        output_length = output_length.to(x.device)

        if x.shape[1] % self.config.avg_pooler != 0:
            x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
        xt = x.permute(0, 2, 1)
        g = self.gate_proj(xt).permute(0, 2, 1)  # (bs, sl // pooler_size + 1, d*2)
        u = self.up_proj(xt).permute(0, 2, 1)
        x = x.reshape(batch_size, -1, self.intermediate_dim)  # (bs, sl // pooler_size + 1, d*2)

        c = self.down_proj(self.act_fn(g) * u)
        res = self.layer_norm(c + x)
        valid_mask, _ = get_sequence_mask(res, output_length)
        code_ids = self.rvq_op(res, output_length)
        code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list))  # (sum(valid_sequence_length), vq_num)
        return code_ids

    @torch.no_grad()
    def decode(self, code_ids):
        vq_num = code_ids.shape[-1]
        res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num - 1, -1, -1)).to(self.proj_decoder.weight)
        decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
        return decoder_emb

    @torch.no_grad()
    def recover(self, code_ids):
        vq_num = code_ids.shape[-1]
        res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num - 1, -1, -1)).to(self.proj_decoder.weight)
        return res


class FlowmatchingPrenet(nn.Module):
    def __init__(
        self,
        input_feat_dim,
        out_feat_dim,
        d_model,
        attention_heads,
        ffn_dim,
        nlayers,
        activation_function,
        max_source_positions,
        target_mel_length_scale_ratio,
    ):
        super().__init__()

        self.d_model = d_model
        self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
        self.gradient_checkpointing = False

        self.register_buffer(
            "positional_embedding", sinusoids(max_source_positions, d_model)
        )

        self.in_mlp = nn.Sequential(
            nn.Linear(input_feat_dim, d_model * 4),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model),
        )

        self.transformer_layers = nn.ModuleList(
            [
                OmniWhisperTransformerLayer(
                    act=ACT2FN[activation_function],
                    d_model=d_model,
                    encoder_attention_heads=attention_heads,
                    encoder_ffn_dim=ffn_dim,
                    causal=True,  # causal
                    ln_type="RMSNorm",
                )
                for _ in range(nlayers)
            ]
        )

        self.final_norm = RMSNorm(self.d_model)
        self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)

    def compute_output_length(self, input_length):
        output_length = input_length.float() * self.target_mel_length_scale_ratio
        return output_length.to(torch.int64)

    def forward(self, input_feat, input_length, output_length=None):
        """
        Args:
            input_feat: [B, T, input_feat_dim]
            input_length: [B]
            output_length: [B]

        """
        if output_length is None or not self.training:
            output_length = self.compute_output_length(input_length)

        input_feat = input_feat[:, : input_length.max(), :]  # [B, T, D]
        orig_dtype = input_feat.dtype

        input_feat = F.interpolate(
            input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
            size=output_length.max(),
            mode="nearest",
        ).to(orig_dtype)
        input_feat = input_feat.transpose(1, 2).contiguous()  # [B, T, D]
        hidden_states = self.in_mlp(input_feat)

        # packing hidden states
        bsz, tgt_len, d_model = hidden_states.shape
        attention_mask, unpacking_index = get_sequence_mask(
            hidden_states, output_length
        )
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(output_length), self.d_model
        )

        for idx, encoder_layer in enumerate(self.transformer_layers):
            hidden_states = encoder_layer(hidden_states, output_length)

        # unpacking
        hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
            bsz, tgt_len, d_model
        )
        hidden_states = torch.where(attention_mask, hidden_states, 0)

        hidden_states = self.final_norm(hidden_states)
        output = self.out_proj(hidden_states)
        return output, output_length


@dataclass
class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
    flow_matching_mel: Optional[torch.FloatTensor] = None
    flow_matching_mel_lengths: Optional[torch.FloatTensor] = None


class OmniAudioFlowMatchingDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config.flow_matching_config
        self.in_channels = self.config.in_channels
        self.spk_emb_dim = self.config.spk_emb_dim
        self.diffusion_steps = self.config.diffusion_steps
        self.cal_mel_mae = self.config.cal_mel_mae
        self.forward_step = -1

        self.prenet = FlowmatchingPrenet(
            input_feat_dim=self.config.prenet_in_dim,
            out_feat_dim=self.config.prenet_out_dim,
            d_model=self.config.prenet_d_model,
            attention_heads=self.config.prenet_attention_heads,
            ffn_dim=self.config.prenet_ffn_dim,
            nlayers=self.config.prenet_nlayers,
            activation_function=self.config.prenet_activation_function,
            max_source_positions=self.config.prenet_max_source_positions,
            target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
        )

        self.conditional_decoder = ConditionalDecoder(
            in_channels=self.in_channels * 2 + self.spk_emb_dim,
            out_channels=self.in_channels,
            causal=True,
            channels=self.config.channels,
            dropout=self.config.dropout,
            attention_head_dim=self.config.attention_head_dim,
            n_blocks=self.config.n_blocks,
            num_mid_blocks=self.config.num_mid_blocks,
            num_heads=self.config.num_heads,
            act_fn=self.config.act_fn,
        )

        self.cfm = ConditionalCFM(
            in_channels=self.in_channels,
            cfm_params=self.config.cfm_params,
            n_spks=0,
            spk_emb_dim=self.spk_emb_dim,
        )

    def unpack_hidden_states(self, hidden_states, output_length):
        unpacked = unpack_hidden_states(hidden_states, output_length)
        return unpacked, output_length

    def forward(
        self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
    ):
        """
        :param refined_mel: [bs, max_input_len, mel_bin]
        :param input_length: [batch_size]
        :param mel_labels: [bs, mel_bin, max_input_len]
        :return:
        """
        self.forward_step += 1

        orig_dtype = refined_mel.dtype
        prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
        prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)

        if self.prenet is not None:
            refined_mel = refined_mel[:, : torch.max(input_length), :]
            if mel_labels_length is None:
                mel_labels_length = self.prenet.compute_output_length(input_length)
            refined_mel, input_length = self.prenet(
                refined_mel, input_length, mel_labels_length
            )

        float_dtype = refined_mel.dtype
        refined_mel = refined_mel.float()
        input_length = input_length.long()

        refined_mel = refined_mel[:, : torch.max(input_length), :]
        sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
        refined_mel = refined_mel.transpose(1, 2)  # (bs, mel_bin, max_input_len)
        sequence_mask = sequence_mask.transpose(2, 1)  # (bs, 1, sl)

        fm_mel = self.cfm.forward(
            estimator=self.conditional_decoder,
            mu=refined_mel.to(float_dtype),
            mask=sequence_mask.float(),
            n_timesteps=self.diffusion_steps,
        )
        return OmniAudioFlowMatchingDecoderOutput(
            flow_matching_mel=fm_mel.transpose(1, 2),
            flow_matching_mel_lengths=mel_labels_length,
        )
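The encoder, decoder, and prenet above all share one pack/unpack pattern built on get_sequence_mask: variable-length sequences are flattened to a (sum(lengths), d) tensor for flash_attn_varlen_func and later scattered back into a padded batch. A self-contained sketch of that bookkeeping in plain PyTorch (no flash-attn dependency), using toy shapes:

import torch

def get_sequence_mask(inputs, inputs_length):
    # Same logic as in audio_modeling_omni.py: boolean validity mask plus gather indices.
    bsz, tgt_len = inputs_length.shape[0], int(torch.max(inputs_length))
    mask = torch.arange(tgt_len, device=inputs.device)
    mask = torch.lt(mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
    unpacking_index = torch.cumsum(mask.to(torch.int64).view(-1), dim=0) - 1
    return mask, unpacking_index

hidden = torch.randn(2, 5, 8)              # (bsz, max_len, d)
lengths = torch.tensor([5, 3])
mask, unpack_idx = get_sequence_mask(hidden, lengths)

# Pack: keep only valid timesteps, concatenated across the batch -> (sum(lengths), d).
packed = torch.masked_select(hidden, mask).view(int(lengths.sum()), 8)

# Unpack: gather back to (bsz, max_len, d) and zero the padded positions.
unpacked = torch.index_select(packed, 0, unpack_idx).view(2, 5, 8)
unpacked = torch.where(mask, unpacked, 0)
assert torch.allclose(unpacked, torch.where(mask, hidden, 0))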
config.json
ADDED
@@ -0,0 +1,255 @@
{
  "_name_or_path": "_",
  "architectures": [
    "OmniForCausalLM"
  ],
  "attention_qkv_bias": true,
  "attention_qkv_pack": true,
  "audio_config": {
    "audio_head_transformer_layers": 3,
    "audio_delim_token_id": 151693,
    "audio_end_token_id": 151677,
    "audio_pad_token_id": 151678,
    "audio_start_token_id": 151676,
    "audiogen_end_token_id": 151701,
    "audiogen_start_token_id": 151700,
    "audiotext_end_token_id": 151698,
    "audiotext_pad_token_id": 151699,
    "audiotext_start_token_id": 151697,
    "avg_pooler": 4,
    "d_model": 1280,
    "decoder_attention_heads": 20,
    "decoder_ffn_dim": 5120,
    "decoder_kernel_size": 3,
    "decoder_layers": 8,
    "decoder_stride_size": 2,
    "enable": true,
    "encoder_attention_heads": 20,
    "encoder_ffn_dim": 5120,
    "encoder_layers": 32,
    "hop_length": 160,
    "kernel_size": 3,
    "max_audio_seconds": 30,
    "n_fft": 400,
    "num_mel_bins": 128,
    "sampling_rate": 16000,
    "stride_size": 2,
    "split_overlap": 0.0,
    "vq_config": {
      "enable": true,
      "codebook_sizes": [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]
    }
  },
  "auto_map": {
    "AutoConfig": "configuration_omni.OmniConfig",
    "AutoModelForCausalLM": "modeling_omni.OmniForCausalLM"
  },
  "omni_tokenizer_type": "auto",
  "bos_token_id": 1,
  "eos_token_id": 2,
  "flow_matching_config": {
    "enable": true,
    "use_hires_mel": true,
    "sampling_rate": 24000,
    "hop_length": 480,
    "max_audio_seconds": 30,
    "split_overlap": 0.1,
    "use_hidden_states_before_dconv2": true,
    "prenet_in_dim": 1280,
    "prenet_out_dim": 80,
    "prenet_d_model": 512,
    "prenet_attention_heads": 8,
    "prenet_ffn_dim": 2048,
    "prenet_nlayers": 12,
    "prenet_activation_function": "gelu",
    "prenet_max_source_positions": 5000,
    "prenet_target_mel_length_scale_ratio": 1.0,
    "prenet_loss_weight": 1.0,
    "unet_use_omni_attn": false,
    "loss_weight": 1.0,
    "in_channels": 80,
    "spk_emb_dim": 0,
    "diffusion_steps": 10,
    "channels": [256],
    "dropout": 0.0,
    "attention_head_dim": 64,
    "n_blocks": 4,
    "num_mid_blocks": 12,
    "num_heads": 8,
    "act_fn": "gelu",
    "cal_mel_mae": true,
    "cfm_params": {
      "sigma_min": 1e-6,
      "solver": "euler",
      "t_scheduler": "cosine",
      "training_cfg_rate": 0.2,
      "inference_cfg_rate": 0.7,
      "reg_loss_type": "l1"
    }
  },
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 8192,
  "max_window_layers": 28,
  "model_type": "omni",
  "multimodal": [
    "audio",
    "image",
    "video",
    "audiogen"
  ],
  "multimodal_special_token_list": [
    151676,
    151677,
    151678,
    151679,
    151680,
    151681,
    151682,
    151683,
    151684,
    151685,
    151686,
    151687,
    151688,
    151693,
    151694,
    151695,
    151696,
    151697,
    151698,
    151699,
    151700,
    151701
  ],
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "pad_token_id": 0,
  "position_embedding_type": "rope",
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 131072,
  "sparse_attention_heads": null,
  "sparse_attention_layers": [],
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "train_multimodal_special_tokens_only": false,
  "transformers_version": "4.45.0.dev0",
  "use_cache": false,
  "use_norm_head": false,
  "use_sliding_window": false,
  "video_config": {
    "_name_or_path": "",
    "_attn_implementation": "flash_attention_2",
    "decode_way": "1fps",
    "depth": 32,
    "embed_dim": 1280,
    "enable": true,
    "hidden_act": "quick_gelu",
    "hidden_size": 3584,
    "image_delimiter_token_id": 151688,
    "image_end_token_id": 151680,
    "image_line_token_id": 151682,
    "image_mean": [
      0.48145466,
      0.4578275,
      0.40821073
    ],
    "image_pad_token_id": 151681,
    "image_size": 224,
    "image_start_token_id": 151679,
    "image_std": [
      0.26862954,
      0.26130258,
      0.27577711
    ],
    "in_channels": 3,
    "in_chans": 3,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "max_frame_num": 32,
    "max_length": 20,
    "max_pixels": 602112,
    "merge_size": 2,
    "min_length": 0,
    "min_pixels": 3136,
    "mlp_ratio": 4,
    "model_type": "clip_vision_model",
    "num_attention_heads": 12,
    "num_channels": 3,
    "num_heads": 16,
    "num_hidden_layers": 12,
    "patch_size": 14,
    "spatial_merge_size": 2,
    "spatial_patch_size": 14,
    "temporal_patch_size": 2,
    "video_end_token_id": 151696,
    "video_place_token_id": 151694,
    "video_start_token_id": 151695
  },
  "visual_config": {
    "_name_or_path": "",
    "_attn_implementation": "flash_attention_2",
    "depth": 32,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "embed_dim": 1280,
    "enable": true,
    "hidden_act": "quick_gelu",
    "hidden_size": 3584,
    "image_delimiter_token_id": 151688,
    "image_end_token_id": 151680,
    "image_line_token_id": 151682,
    "image_mean": [
      0.48145466,
      0.4578275,
      0.40821073
    ],
    "image_pad_token_id": 151681,
    "image_size": 224,
    "image_start_token_id": 151679,
    "image_std": [
      0.26862954,
      0.26130258,
      0.27577711
    ],
    "in_channels": 3,
    "in_chans": 3,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_pixels": 3211264,
    "merge_size": 2,
    "min_length": 0,
    "min_pixels": 3136,
    "mlp_ratio": 4,
    "model_type": "clip_vision_model",
    "num_attention_heads": 12,
    "num_channels": 3,
    "num_heads": 16,
    "num_hidden_layers": 12,
    "patch_size": 14,
    "projection_dim": 512,
    "spatial_merge_size": 2,
    "spatial_patch_size": 14,
    "temporal_patch_size": 2
  },
  "vocab_size": 152064,
  "vocoder_config": {
    "enable": true,
    "enable_multi_scale": true,
    "max_audio_seconds": 30,
    "sampling_rate": 16000,
    "hop_length": 256,
    "split_overlap": 0.0,
    "n_fft": 1024,
    "num_mel_bins": 80,
    "channels": [256, 256, 256, 256, 256]
  }
}
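Because auto_map points AutoConfig and AutoModelForCausalLM at configuration_omni.py and modeling_omni.py inside this repo, the checkpoint is meant to be loaded with trust_remote_code. A minimal loading sketch, assuming the repo has been downloaded locally (the "./" path is a placeholder) and that the extra dependencies imported by the modeling code (flash_attn, deepspeed) are installed:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("./", trust_remote_code=True)  # resolves to OmniConfig
print(config.model_type, config.audio_config.d_model)              # "omni", 1280

model = AutoModelForCausalLM.from_pretrained(
    "./",
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" in config.json
    trust_remote_code=True,      # required so OmniForCausalLM is imported from modeling_omni.py
)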
configuration_omni.py
ADDED
@@ -0,0 +1,120 @@
# Copyright 2023 Baichuan Inc. All Rights Reserved.

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import WhisperConfig
from transformers import CLIPVisionConfig

logger = logging.get_logger(__name__)


class OmniConfig(PretrainedConfig):
    model_type = "omni"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=125696,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        sparse_attention_heads=None,
        sparse_attention_layers=[],
        head_dim=None,
        attention_qkv_pack=True,
        attention_qkv_bias=False,
        use_norm_head=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        position_embedding_type="rope",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        audio_config=None,
        visual_config=None,
        video_config=None,
        vocoder_config=None,
        flow_matching_config=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
        self.sparse_attention_heads = sparse_attention_heads
        self.sparse_attention_layers = sparse_attention_layers
        self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
        self.attention_qkv_pack = attention_qkv_pack
        self.attention_qkv_bias = attention_qkv_bias
        self.use_norm_head = use_norm_head
        self.hidden_act = hidden_act
        self.position_embedding_type = position_embedding_type
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        assert self.position_embedding_type.lower() in ("rope", "alibi")
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        if audio_config is not None:
            self.audio_config = WhisperConfig(**audio_config)
            if self.audio_config.vq_config is not None:
                self.audio_config.vq_config = PretrainedConfig(**self.audio_config.vq_config)
        if vocoder_config is not None:
            self.vocoder_config = WhisperConfig(**vocoder_config)
        if flow_matching_config is not None:
            self.flow_matching_config = PretrainedConfig(**flow_matching_config)
            self.flow_matching_config.cfm_params = PretrainedConfig(**self.flow_matching_config.cfm_params)
        if visual_config is not None:
            self.visual_config = CLIPVisionConfig(**visual_config)
        if video_config is not None:
            self.video_config = CLIPVisionConfig(**video_config)

    def to_diff_dict(self):
        data = super().to_diff_dict()
        data["model_type"] = self.model_type
        return data

    def get_rotary_base(self):
        if hasattr(self, "rotary_emb_base"):
            return self.rotary_emb_base
        else:
            return self.rope_theta


if __name__ == '__main__':
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    print(config)
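OmniConfig wraps each enabled modality in its own sub-config object: WhisperConfig for the audio and vocoder branches, CLIPVisionConfig for image and video, and plain PretrainedConfig wrappers for the VQ and flow-matching blocks. A small sketch, assuming this repo's directory is on the Python path, of building the config directly from nested dicts rather than from config.json:

from configuration_omni import OmniConfig

config = OmniConfig(
    hidden_size=3584,
    num_hidden_layers=28,
    num_attention_heads=28,
    num_key_value_heads=4,
    audio_config={
        "d_model": 1280,
        "encoder_layers": 32,
        "num_mel_bins": 128,
        "vq_config": {"enable": True, "codebook_sizes": [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]},
    },
)
# The nested dicts are promoted to config objects in __init__ above.
print(type(config.audio_config).__name__)            # WhisperConfig
print(config.audio_config.vq_config.codebook_sizes)  # [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]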
flow_matching.py
ADDED
@@ -0,0 +1,791 @@
1 |
+
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main
|
2 |
+
"""
|
3 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
|
18 |
+
from abc import ABC
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from typing import Dict, Optional
|
22 |
+
|
23 |
+
import torch.nn as nn
|
24 |
+
from einops import pack, rearrange, repeat
|
25 |
+
from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
26 |
+
from .matcha_transformer import BasicTransformerBlock
|
27 |
+
from omegaconf import DictConfig
|
28 |
+
|
29 |
+
|
30 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
31 |
+
assert mask.dtype == torch.bool
|
32 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
33 |
+
mask = mask.to(dtype)
|
34 |
+
# attention mask bias
|
35 |
+
# NOTE(Mddct): torch.finfo jit issues
|
36 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
37 |
+
mask = (1.0 - mask) * torch.finfo(dtype).min
|
38 |
+
return mask
|
39 |
+
|
40 |
+
|
41 |
+
def subsequent_chunk_mask(
|
42 |
+
size: int,
|
43 |
+
chunk_size: int,
|
44 |
+
num_left_chunks: int = -1,
|
45 |
+
device: torch.device = torch.device("cpu"),
|
46 |
+
) -> torch.Tensor:
|
47 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
48 |
+
this is for streaming encoder
|
49 |
+
|
50 |
+
Args:
|
51 |
+
size (int): size of mask
|
52 |
+
chunk_size (int): size of chunk
|
53 |
+
num_left_chunks (int): number of left chunks
|
54 |
+
<0: use full chunk
|
55 |
+
>=0: use num_left_chunks
|
56 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: mask
|
60 |
+
|
61 |
+
Examples:
|
62 |
+
>>> subsequent_chunk_mask(4, 2)
|
63 |
+
[[1, 1, 0, 0],
|
64 |
+
[1, 1, 0, 0],
|
65 |
+
[1, 1, 1, 1],
|
66 |
+
[1, 1, 1, 1]]
|
67 |
+
"""
|
68 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
69 |
+
# actually this is not needed after we have inference cache implemented, will remove it later
|
70 |
+
pos_idx = torch.arange(size, device=device)
|
71 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
72 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
73 |
+
return ret
|
74 |
+
|
75 |
+
def subsequent_mask(
|
76 |
+
size: int,
|
77 |
+
device: torch.device = torch.device("cpu"),
|
78 |
+
) -> torch.Tensor:
|
79 |
+
"""Create mask for subsequent steps (size, size).
|
80 |
+
|
81 |
+
This mask is used only in decoder which works in an auto-regressive mode.
|
82 |
+
This means the current step could only do attention with its left steps.
|
83 |
+
|
84 |
+
In encoder, full attention is used when streaming is not necessary and
|
85 |
+
the sequence is not long. In this case, no attention mask is needed.
|
86 |
+
|
87 |
+
When streaming is needed, chunk-based attention is used in encoder. See
|
88 |
+
subsequent_chunk_mask for the chunk-based attention mask.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
size (int): size of mask
|
92 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
93 |
+
dtype (torch.dtype): result dtype
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: mask
|
97 |
+
|
98 |
+
Examples:
|
99 |
+
>>> subsequent_mask(3)
|
100 |
+
[[1, 0, 0],
|
101 |
+
[1, 1, 0],
|
102 |
+
[1, 1, 1]]
|
103 |
+
"""
|
104 |
+
arange = torch.arange(size, device=device)
|
105 |
+
mask = arange.expand(size, size)
|
106 |
+
arange = arange.unsqueeze(-1)
|
107 |
+
mask = mask <= arange
|
108 |
+
return mask
|
109 |
+
|
110 |
+
|
111 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
|
112 |
+
masks: torch.Tensor,
|
113 |
+
use_dynamic_chunk: bool,
|
114 |
+
use_dynamic_left_chunk: bool,
|
115 |
+
decoding_chunk_size: int,
|
116 |
+
static_chunk_size: int,
|
117 |
+
num_decoding_left_chunks: int,
|
118 |
+
enable_full_context: bool = True):
|
119 |
+
""" Apply optional mask for encoder.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
123 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
124 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
125 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
126 |
+
training.
|
127 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
128 |
+
0: default for training, use random dynamic chunk.
|
129 |
+
<0: for decoding, use full chunk.
|
130 |
+
>0: for decoding, use fixed chunk size as set.
|
131 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
132 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
133 |
+
this parameter will be ignored
|
134 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
135 |
+
the chunk size is decoding_chunk_size.
|
136 |
+
>=0: use num_decoding_left_chunks
|
137 |
+
<0: use all left chunks
|
138 |
+
enable_full_context (bool):
|
139 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
140 |
+
False: chunk size ~ U[1, 25]
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
torch.Tensor: chunk mask of the input xs.
|
144 |
+
"""
|
145 |
+
# Whether to use chunk mask or not
|
146 |
+
if use_dynamic_chunk:
|
147 |
+
max_len = xs.size(1)
|
148 |
+
if decoding_chunk_size < 0:
|
149 |
+
chunk_size = max_len
|
150 |
+
num_left_chunks = -1
|
151 |
+
elif decoding_chunk_size > 0:
|
152 |
+
chunk_size = decoding_chunk_size
|
153 |
+
num_left_chunks = num_decoding_left_chunks
|
154 |
+
else:
|
155 |
+
# chunk size is either [1, 25] or full context(max_len).
|
156 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
157 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
158 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
159 |
+
num_left_chunks = -1
|
160 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
161 |
+
chunk_size = max_len
|
162 |
+
else:
|
163 |
+
chunk_size = chunk_size % 25 + 1
|
164 |
+
if use_dynamic_left_chunk:
|
165 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
166 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
167 |
+
(1, )).item()
|
168 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
169 |
+
num_left_chunks,
|
170 |
+
xs.device) # (L, L)
|
171 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
172 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
173 |
+
elif static_chunk_size > 0:
|
174 |
+
num_left_chunks = num_decoding_left_chunks
|
175 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
176 |
+
num_left_chunks,
|
177 |
+
xs.device) # (L, L)
|
178 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
179 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
180 |
+
else:
|
181 |
+
chunk_masks = masks
|
182 |
+
return chunk_masks
|
183 |
+
|
184 |
+
|
185 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
186 |
+
"""Make mask tensor containing indices of padded part.
|
187 |
+
|
188 |
+
See description of make_non_pad_mask.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
192 |
+
Returns:
|
193 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
194 |
+
|
195 |
+
Examples:
|
196 |
+
>>> lengths = [5, 3, 2]
|
197 |
+
>>> make_pad_mask(lengths)
|
198 |
+
masks = [[0, 0, 0, 0 ,0],
|
199 |
+
[0, 0, 0, 1, 1],
|
200 |
+
[0, 0, 1, 1, 1]]
|
201 |
+
"""
|
202 |
+
batch_size = lengths.size(0)
|
203 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
204 |
+
seq_range = torch.arange(0,
|
205 |
+
max_len,
|
206 |
+
dtype=torch.int64,
|
207 |
+
device=lengths.device)
|
208 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
209 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
210 |
+
mask = seq_range_expand >= seq_length_expand
|
211 |
+
return mask
|
212 |
+
|
213 |
+
# Causal
|
214 |
+
class Transpose(torch.nn.Module):
|
215 |
+
def __init__(self, dim0: int, dim1: int):
|
216 |
+
super().__init__()
|
217 |
+
self.dim0 = dim0
|
218 |
+
self.dim1 = dim1
|
219 |
+
|
220 |
+
def forward(self, x: torch.Tensor):
|
221 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
222 |
+
return x
|
223 |
+
|
224 |
+
class CausalBlock1D(Block1D):
|
225 |
+
def __init__(self, dim: int, dim_out: int):
|
226 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
227 |
+
self.block = torch.nn.Sequential(
|
228 |
+
CausalConv1d(dim, dim_out, 3),
|
229 |
+
Transpose(1, 2),
|
230 |
+
nn.LayerNorm(dim_out),
|
231 |
+
Transpose(1, 2),
|
232 |
+
nn.Mish(),
|
233 |
+
)
|
234 |
+
|
235 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
236 |
+
output = self.block(x * mask)
|
237 |
+
return output * mask
|
238 |
+
|
239 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
240 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
241 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
242 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
243 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
244 |
+
|
245 |
+
class CausalConv1d(torch.nn.Conv1d):
|
246 |
+
def __init__(
|
247 |
+
self,
|
248 |
+
in_channels: int,
|
249 |
+
out_channels: int,
|
250 |
+
kernel_size: int,
|
251 |
+
stride: int = 1,
|
252 |
+
dilation: int = 1,
|
253 |
+
groups: int = 1,
|
254 |
+
bias: bool = True,
|
255 |
+
padding_mode: str = 'zeros',
|
256 |
+
device=None,
|
257 |
+
dtype=None
|
258 |
+
) -> None:
|
259 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
260 |
+
kernel_size, stride,
|
261 |
+
padding=0, dilation=dilation,
|
262 |
+
groups=groups, bias=bias,
|
263 |
+
padding_mode=padding_mode,
|
264 |
+
device=device, dtype=dtype)
|
265 |
+
assert stride == 1
|
266 |
+
self.causal_padding = (kernel_size - 1, 0)
|
267 |
+
|
268 |
+
def forward(self, x: torch.Tensor):
|
269 |
+
x = F.pad(x, self.causal_padding)
|
270 |
+
x = super(CausalConv1d, self).forward(x)
|
271 |
+
return x
|
272 |
+
|
273 |
+
|
274 |
+
class BASECFM(torch.nn.Module, ABC):
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
n_feats,
|
278 |
+
cfm_params,
|
279 |
+
n_spks=1,
|
280 |
+
spk_emb_dim=128,
|
281 |
+
):
|
282 |
+
super().__init__()
|
283 |
+
self.n_feats = n_feats
|
284 |
+
self.n_spks = n_spks
|
285 |
+
self.spk_emb_dim = spk_emb_dim
|
286 |
+
self.solver = cfm_params.solver
|
287 |
+
if hasattr(cfm_params, "sigma_min"):
|
288 |
+
self.sigma_min = cfm_params.sigma_min
|
289 |
+
else:
|
290 |
+
self.sigma_min = 1e-4
|
291 |
+
|
292 |
+
self.estimator = None
|
293 |
+
|
294 |
+
@torch.inference_mode()
|
295 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
296 |
+
"""Forward diffusion
|
297 |
+
|
298 |
+
Args:
|
299 |
+
mu (torch.Tensor): output of encoder
|
300 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
301 |
+
mask (torch.Tensor): output_mask
|
302 |
+
shape: (batch_size, 1, mel_timesteps)
|
303 |
+
n_timesteps (int): number of diffusion steps
|
304 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
305 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
306 |
+
shape: (batch_size, spk_emb_dim)
|
307 |
+
cond: Not used but kept for future purposes
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
sample: generated mel-spectrogram
|
311 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
312 |
+
"""
|
313 |
+
z = torch.randn_like(mu) * temperature
|
314 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
315 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
316 |
+
|
317 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
318 |
+
"""
|
319 |
+
Fixed-step Euler solver for ODEs.
|
320 |
+
Args:
|
321 |
+
x (torch.Tensor): random noise
|
322 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
323 |
+
shape: (n_timesteps + 1,)
|
324 |
+
mu (torch.Tensor): output of encoder
|
325 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
326 |
+
mask (torch.Tensor): output_mask
|
327 |
+
shape: (batch_size, 1, mel_timesteps)
|
328 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
329 |
+
shape: (batch_size, spk_emb_dim)
|
330 |
+
cond: Not used but kept for future purposes
|
331 |
+
"""
|
332 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
333 |
+
|
334 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
335 |
+
# Or in future might add like a return_all_steps flag
|
336 |
+
sol = []
|
337 |
+
|
338 |
+
for step in range(1, len(t_span)):
|
339 |
+
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
340 |
+
|
341 |
+
x = x + dt * dphi_dt
|
342 |
+
t = t + dt
|
343 |
+
sol.append(x)
|
344 |
+
if step < len(t_span) - 1:
|
345 |
+
dt = t_span[step + 1] - t
|
346 |
+
|
347 |
+
return sol[-1]
|
348 |
+
|
349 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
350 |
+
"""Computes diffusion loss
|
351 |
+
|
352 |
+
Args:
|
353 |
+
x1 (torch.Tensor): Target
|
354 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
355 |
+
mask (torch.Tensor): target mask
|
356 |
+
shape: (batch_size, 1, mel_timesteps)
|
357 |
+
mu (torch.Tensor): output of encoder
|
358 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
359 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
360 |
+
shape: (batch_size, spk_emb_dim)
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
loss: conditional flow matching loss
|
364 |
+
y: conditional flow
|
365 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
366 |
+
"""
|
367 |
+
b, _, t = mu.shape
|
368 |
+
|
369 |
+
# random timestep
|
370 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
371 |
+
# sample noise p(x_0)
|
372 |
+
z = torch.randn_like(x1)
|
373 |
+
|
374 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
375 |
+
u = x1 - (1 - self.sigma_min) * z
|
376 |
+
|
377 |
+
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
|
378 |
+
torch.sum(mask) * u.shape[1]
|
379 |
+
)
|
380 |
+
return loss, y
|
381 |
+
|
382 |
+
|
383 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
384 |
+
"""Make mask tensor containing indices of padded part.
|
385 |
+
|
386 |
+
See description of make_non_pad_mask.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
390 |
+
Returns:
|
391 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
392 |
+
|
393 |
+
Examples:
|
394 |
+
>>> lengths = [5, 3, 2]
|
395 |
+
>>> make_pad_mask(lengths)
|
396 |
+
masks = [[0, 0, 0, 0 ,0],
|
397 |
+
[0, 0, 0, 1, 1],
|
398 |
+
[0, 0, 1, 1, 1]]
|
399 |
+
"""
|
400 |
+
batch_size = lengths.size(0)
|
401 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
402 |
+
seq_range = torch.arange(0,
|
403 |
+
max_len,
|
404 |
+
dtype=torch.int64,
|
405 |
+
device=lengths.device)
|
406 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
407 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
408 |
+
mask = seq_range_expand >= seq_length_expand
|
409 |
+
return mask
|
410 |
+
|
411 |
+
|
412 |
+
class ConditionalDecoder(nn.Module):
|
413 |
+
def __init__(
|
414 |
+
self,
|
415 |
+
in_channels,
|
416 |
+
out_channels,
|
417 |
+
causal=False,
|
418 |
+
channels=(256, 256),
|
419 |
+
dropout=0.05,
|
420 |
+
attention_head_dim=64,
|
421 |
+
n_blocks=1,
|
422 |
+
num_mid_blocks=2,
|
423 |
+
num_heads=4,
|
424 |
+
act_fn="snake",
|
425 |
+
gradient_checkpointing=True,
|
426 |
+
):
|
427 |
+
"""
|
428 |
+
This decoder requires an input with the same shape as the target. So, if your text content
|
429 |
+
is shorter or longer than the outputs, please resample it before feeding it to the decoder.
|
430 |
+
"""
|
431 |
+
super().__init__()
|
432 |
+
channels = tuple(channels)
|
433 |
+
self.in_channels = in_channels
|
434 |
+
self.out_channels = out_channels
|
435 |
+
self.causal = causal
|
436 |
+
self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
|
437 |
+
self.gradient_checkpointing = gradient_checkpointing
|
438 |
+
|
439 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
440 |
+
time_embed_dim = channels[0] * 4
|
441 |
+
self.time_mlp = TimestepEmbedding(
|
442 |
+
in_channels=in_channels,
|
443 |
+
time_embed_dim=time_embed_dim,
|
444 |
+
act_fn="silu",
|
445 |
+
)
|
446 |
+
self.down_blocks = nn.ModuleList([])
|
447 |
+
self.mid_blocks = nn.ModuleList([])
|
448 |
+
self.up_blocks = nn.ModuleList([])
|
449 |
+
|
450 |
+
output_channel = in_channels
|
451 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
452 |
+
input_channel = output_channel
|
453 |
+
output_channel = channels[i]
|
454 |
+
is_last = i == len(channels) - 1
|
455 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
456 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
457 |
+
transformer_blocks = nn.ModuleList(
|
458 |
+
[
|
459 |
+
BasicTransformerBlock(
|
460 |
+
dim=output_channel,
|
461 |
+
num_attention_heads=num_heads,
|
462 |
+
attention_head_dim=attention_head_dim,
|
463 |
+
dropout=dropout,
|
464 |
+
activation_fn=act_fn,
|
465 |
+
)
|
466 |
+
for _ in range(n_blocks)
|
467 |
+
]
|
468 |
+
)
|
469 |
+
downsample = (
|
470 |
+
Downsample1D(output_channel) if not is_last else
|
471 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
472 |
+
)
|
473 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
474 |
+
|
475 |
+
for _ in range(num_mid_blocks):
|
476 |
+
input_channel = channels[-1]
|
477 |
+
out_channels = channels[-1]
|
478 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
479 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
480 |
+
|
481 |
+
transformer_blocks = nn.ModuleList(
|
482 |
+
[
|
483 |
+
BasicTransformerBlock(
|
484 |
+
dim=output_channel,
|
485 |
+
num_attention_heads=num_heads,
|
486 |
+
attention_head_dim=attention_head_dim,
|
487 |
+
dropout=dropout,
|
488 |
+
activation_fn=act_fn,
|
489 |
+
)
|
490 |
+
for _ in range(n_blocks)
|
491 |
+
]
|
492 |
+
)
|
493 |
+
|
494 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
495 |
+
|
496 |
+
channels = channels[::-1] + (channels[0],)
|
497 |
+
for i in range(len(channels) - 1):
|
498 |
+
input_channel = channels[i] * 2
|
499 |
+
output_channel = channels[i + 1]
|
500 |
+
is_last = i == len(channels) - 2
|
501 |
+
resnet = CausalResnetBlock1D(
|
502 |
+
dim=input_channel,
|
503 |
+
dim_out=output_channel,
|
504 |
+
time_emb_dim=time_embed_dim,
|
505 |
+
) if self.causal else ResnetBlock1D(
|
506 |
+
dim=input_channel,
|
507 |
+
dim_out=output_channel,
|
508 |
+
time_emb_dim=time_embed_dim,
|
509 |
+
)
|
510 |
+
transformer_blocks = nn.ModuleList(
|
511 |
+
[
|
512 |
+
BasicTransformerBlock(
|
513 |
+
dim=output_channel,
|
514 |
+
num_attention_heads=num_heads,
|
515 |
+
attention_head_dim=attention_head_dim,
|
516 |
+
dropout=dropout,
|
517 |
+
activation_fn=act_fn,
|
518 |
+
)
|
519 |
+
for _ in range(n_blocks)
|
520 |
+
]
|
521 |
+
)
|
522 |
+
upsample = (
|
523 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
524 |
+
if not is_last
|
525 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
526 |
+
)
|
527 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
528 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
529 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
530 |
+
self.initialize_weights()
|
531 |
+
|
532 |
+
def initialize_weights(self):
|
533 |
+
for m in self.modules():
|
534 |
+
if isinstance(m, nn.Conv1d):
|
535 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
536 |
+
if m.bias is not None:
|
537 |
+
nn.init.constant_(m.bias, 0)
|
538 |
+
elif isinstance(m, nn.GroupNorm):
|
539 |
+
nn.init.constant_(m.weight, 1)
|
540 |
+
nn.init.constant_(m.bias, 0)
|
541 |
+
elif isinstance(m, nn.Linear):
|
542 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
543 |
+
if m.bias is not None:
|
544 |
+
nn.init.constant_(m.bias, 0)
|
545 |
+
|
546 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
547 |
+
"""Forward pass of the UNet1DConditional model.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
551 |
+
mask (_type_): shape (batch_size, 1, time)
|
552 |
+
t (_type_): shape (batch_size)
|
553 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
554 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
555 |
+
|
556 |
+
Raises:
|
557 |
+
ValueError: _description_
|
558 |
+
ValueError: _description_
|
559 |
+
|
560 |
+
Returns:
|
561 |
+
_type_: _description_
|
562 |
+
"""
|
563 |
+
t = self.time_embeddings(t)
|
564 |
+
t = t.to(x.dtype)
|
565 |
+
t = self.time_mlp(t)
|
566 |
+
x = pack([x, mu], "b * t")[0]
|
567 |
+
mask = mask.to(x.dtype)
|
568 |
+
if spks is not None:
|
569 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
570 |
+
x = pack([x, spks], "b * t")[0]
|
571 |
+
if cond is not None:
|
572 |
+
x = pack([x, cond], "b * t")[0]
|
573 |
+
|
574 |
+
hiddens = []
|
575 |
+
masks = [mask]
|
576 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
577 |
+
mask_down = masks[-1]
|
578 |
+
x = resnet(x, mask_down, t)
|
579 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
580 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
581 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
582 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
583 |
+
for transformer_block in transformer_blocks:
|
584 |
+
if self.gradient_checkpointing and self.training:
|
585 |
+
def create_custom_forward(module):
|
586 |
+
def custom_forward(*inputs):
|
587 |
+
return module(*inputs)
|
588 |
+
return custom_forward
|
589 |
+
x = torch.utils.checkpoint.checkpoint(
|
590 |
+
create_custom_forward(transformer_block),
|
591 |
+
x,
|
592 |
+
attn_mask,
|
593 |
+
t,
|
594 |
+
)
|
595 |
+
else:
|
596 |
+
x = transformer_block(
|
597 |
+
hidden_states=x,
|
598 |
+
attention_mask=attn_mask,
|
599 |
+
timestep=t,
|
600 |
+
)
|
601 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
602 |
+
hiddens.append(x) # Save hidden states for skip connections
|
603 |
+
x = downsample(x * mask_down)
|
604 |
+
masks.append(mask_down[:, :, ::2])
|
605 |
+
masks = masks[:-1]
|
606 |
+
mask_mid = masks[-1]
|
607 |
+
|
608 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
609 |
+
x = resnet(x, mask_mid, t)
|
610 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
611 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
612 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
613 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
614 |
+
for transformer_block in transformer_blocks:
|
615 |
+
if self.gradient_checkpointing and self.training:
|
616 |
+
def create_custom_forward(module):
|
617 |
+
def custom_forward(*inputs):
|
618 |
+
return module(*inputs)
|
619 |
+
return custom_forward
|
620 |
+
x = torch.utils.checkpoint.checkpoint(
|
621 |
+
create_custom_forward(transformer_block),
|
622 |
+
x,
|
623 |
+
attn_mask,
|
624 |
+
t,
|
625 |
+
)
|
626 |
+
else:
|
627 |
+
x = transformer_block(
|
628 |
+
hidden_states=x,
|
629 |
+
attention_mask=attn_mask,
|
630 |
+
timestep=t,
|
631 |
+
)
|
632 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
633 |
+
|
634 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
635 |
+
mask_up = masks.pop()
|
636 |
+
skip = hiddens.pop()
|
637 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
638 |
+
x = resnet(x, mask_up, t)
|
639 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
640 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
641 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
642 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
643 |
+
for transformer_block in transformer_blocks:
|
644 |
+
if self.gradient_checkpointing and self.training:
|
645 |
+
def create_custom_forward(module):
|
646 |
+
def custom_forward(*inputs):
|
647 |
+
return module(*inputs)
|
648 |
+
return custom_forward
|
649 |
+
x = torch.utils.checkpoint.checkpoint(
|
650 |
+
create_custom_forward(transformer_block),
|
651 |
+
x,
|
652 |
+
attn_mask,
|
653 |
+
t,
|
654 |
+
)
|
655 |
+
else:
|
656 |
+
x = transformer_block(
|
657 |
+
hidden_states=x,
|
658 |
+
attention_mask=attn_mask,
|
659 |
+
timestep=t,
|
660 |
+
)
|
661 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
662 |
+
x = upsample(x * mask_up)
|
663 |
+
x = self.final_block(x, mask_up)
|
664 |
+
output = self.final_proj(x * mask_up)
|
665 |
+
return output * mask
|
666 |
+
|
667 |
+
|
668 |
+
class ConditionalCFM(BASECFM):
|
669 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
|
670 |
+
super().__init__(
|
671 |
+
n_feats=in_channels,
|
672 |
+
cfm_params=cfm_params,
|
673 |
+
n_spks=n_spks,
|
674 |
+
spk_emb_dim=spk_emb_dim,
|
675 |
+
)
|
676 |
+
self.t_scheduler = cfm_params.t_scheduler
|
677 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
678 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
679 |
+
|
680 |
+
@torch.inference_mode()
|
681 |
+
def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
682 |
+
"""Forward diffusion
|
683 |
+
|
684 |
+
Args:
|
685 |
+
mu (torch.Tensor): output of encoder
|
686 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
687 |
+
mask (torch.Tensor): output_mask
|
688 |
+
shape: (batch_size, 1, mel_timesteps)
|
689 |
+
n_timesteps (int): number of diffusion steps
|
690 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
691 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
692 |
+
shape: (batch_size, spk_emb_dim)
|
693 |
+
cond: Not used but kept for future purposes
|
694 |
+
|
695 |
+
Returns:
|
696 |
+
sample: generated mel-spectrogram
|
697 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
698 |
+
"""
|
699 |
+
z = torch.randn_like(mu) * temperature
|
700 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
701 |
+
if self.t_scheduler == 'cosine':
|
702 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
703 |
+
return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
|
704 |
+
|
705 |
+
def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
|
706 |
+
"""
|
707 |
+
Fixed-step Euler solver for ODEs.
|
708 |
+
Args:
|
709 |
+
x (torch.Tensor): random noise
|
710 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
711 |
+
shape: (n_timesteps + 1,)
|
712 |
+
mu (torch.Tensor): output of encoder
|
713 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
714 |
+
mask (torch.Tensor): output_mask
|
715 |
+
shape: (batch_size, 1, mel_timesteps)
|
716 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
717 |
+
shape: (batch_size, spk_emb_dim)
|
718 |
+
cond: Not used but kept for future purposes
|
719 |
+
"""
|
720 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
721 |
+
|
722 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
723 |
+
# Or in future might add like a return_all_steps flag
|
724 |
+
sol = []
|
725 |
+
|
726 |
+
for step in range(1, len(t_span)):
|
727 |
+
dphi_dt = estimator(x, mask, mu, t, spks, cond)
|
728 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
729 |
+
if self.inference_cfg_rate > 0:
|
730 |
+
cfg_dphi_dt = estimator(
|
731 |
+
x, mask,
|
732 |
+
torch.zeros_like(mu), t,
|
733 |
+
torch.zeros_like(spks) if spks is not None else None,
|
734 |
+
cond=cond
|
735 |
+
)
|
736 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
|
737 |
+
self.inference_cfg_rate * cfg_dphi_dt)
|
738 |
+
x = x + dt * dphi_dt
|
739 |
+
t = t + dt
|
740 |
+
sol.append(x)
|
741 |
+
if step < len(t_span) - 1:
|
742 |
+
dt = t_span[step + 1] - t
|
743 |
+
|
744 |
+
return sol[-1]
|
745 |
+
|
746 |
+
def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
|
747 |
+
"""Computes diffusion loss
|
748 |
+
|
749 |
+
Args:
|
750 |
+
x1 (torch.Tensor): Target
|
751 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
752 |
+
mask (torch.Tensor): target mask
|
753 |
+
shape: (batch_size, 1, mel_timesteps)
|
754 |
+
mu (torch.Tensor): output of encoder
|
755 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
756 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
757 |
+
shape: (batch_size, spk_emb_dim)
|
758 |
+
|
759 |
+
Returns:
|
760 |
+
loss: conditional flow matching loss
|
761 |
+
y: conditional flow
|
762 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
763 |
+
"""
|
764 |
+
org_dtype = x1.dtype
|
765 |
+
|
766 |
+
b, _, t = mu.shape
|
767 |
+
# random timestep
|
768 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
769 |
+
if self.t_scheduler == 'cosine':
|
770 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
771 |
+
# sample noise p(x_0)
|
772 |
+
z = torch.randn_like(x1)
|
773 |
+
|
774 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
775 |
+
u = x1 - (1 - self.sigma_min) * z
|
776 |
+
|
777 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
778 |
+
if self.training_cfg_rate > 0:
|
779 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
780 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
781 |
+
if spks is not None:
|
782 |
+
spks = spks * cfg_mask.view(-1, 1)
|
783 |
+
if cond is not None:
|
784 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
785 |
+
|
786 |
+
pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
|
787 |
+
pred = pred.float()
|
788 |
+
u = u.float()
|
789 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
790 |
+
loss = loss.to(org_dtype)
|
791 |
+
return loss, y
|
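
The OT-CFM objective and the fixed-step Euler sampler implemented above can be exercised on their own. Below is a minimal, hedged sketch: `toy_estimator` is a stand-in for the `estimator` / `ConditionalDecoder` argument and is not part of this repository; shapes follow the docstrings, (batch_size, n_feats, mel_timesteps).

    import torch

    sigma_min = 1e-4

    def toy_estimator(x, mask, mu, t, spks=None, cond=None):
        # Stand-in vector field: pull the sample toward the conditioning mu.
        return (mu - x) * mask

    def cfm_training_pair(x1, mask):
        # Noisy interpolant y_t and constant target velocity u, as in compute_loss above.
        b = x1.size(0)
        t = torch.rand(b, 1, 1, dtype=x1.dtype)
        z = torch.randn_like(x1)
        y = (1 - (1 - sigma_min) * t) * z + t * x1
        u = x1 - (1 - sigma_min) * z
        return y, u, t

    def euler_sample(mu, mask, n_timesteps=10, temperature=1.0):
        # Integrate dx/dt = estimator(x, t) from t=0 to t=1, mirroring solve_euler above.
        x = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1)
        for step in range(1, len(t_span)):
            dt = t_span[step] - t_span[step - 1]
            x = x + dt * toy_estimator(x, mask, mu, t_span[step - 1])
        return x

    mu = torch.randn(2, 80, 50)      # e.g. 80 mel bins, 50 frames
    mask = torch.ones(2, 1, 50)
    y, u, t = cfm_training_pair(mu, mask)
    sample = euler_sample(mu, mask)
    print(sample.shape)              # torch.Size([2, 80, 50])
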
generation_config.json
ADDED
@@ -0,0 +1,6 @@
1 |
+
{
|
2 |
+
"bos_token_id": 151643,
|
3 |
+
"eos_token_id": 151643,
|
4 |
+
"max_new_tokens": 2048,
|
5 |
+
"transformers_version": "4.45.0.dev0"
|
6 |
+
}
|
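
For reference, a generation_config.json like the one above maps onto a transformers `GenerationConfig`. A minimal sketch, with the values copied from the file and a placeholder checkpoint path:

    from transformers import GenerationConfig

    gen_config = GenerationConfig(
        bos_token_id=151643,
        eos_token_id=151643,
        max_new_tokens=2048,
    )
    # Equivalent to loading it from a checkpoint directory (path is a placeholder):
    # gen_config = GenerationConfig.from_pretrained("path/to/checkpoint")
    print(gen_config.max_new_tokens)  # 2048
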
generation_utils.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
from typing import List
|
2 |
+
from queue import Queue
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
|
8 |
+
def _parse_messages(messages, split_role="user"):
|
9 |
+
system, rounds = "", []
|
10 |
+
round = []
|
11 |
+
for i, message in enumerate(messages):
|
12 |
+
if message["role"] == "system":
|
13 |
+
assert i == 0
|
14 |
+
system = message["content"]
|
15 |
+
continue
|
16 |
+
if message["role"] == split_role and round:
|
17 |
+
rounds.append(round)
|
18 |
+
round = []
|
19 |
+
round.append(message)
|
20 |
+
if round:
|
21 |
+
rounds.append(round)
|
22 |
+
return system, rounds
|
23 |
+
|
24 |
+
max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
|
25 |
+
max_input_tokens = model.config.model_max_length - max_new_tokens
|
26 |
+
system, rounds = _parse_messages(messages, split_role="user")
|
27 |
+
system_tokens = tokenizer.encode(system)
|
28 |
+
max_history_tokens = max_input_tokens - len(system_tokens)
|
29 |
+
|
30 |
+
history_tokens = []
|
31 |
+
for round in rounds[::-1]:
|
32 |
+
round_tokens = []
|
33 |
+
for message in round:
|
34 |
+
if message["role"] == "user":
|
35 |
+
round_tokens.append(model.generation_config.user_token_id)
|
36 |
+
else:
|
37 |
+
round_tokens.append(model.generation_config.assistant_token_id)
|
38 |
+
round_tokens.extend(tokenizer.encode(message["content"]))
|
39 |
+
if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
|
40 |
+
history_tokens = round_tokens + history_tokens # concat left
|
41 |
+
if len(history_tokens) < max_history_tokens:
|
42 |
+
continue
|
43 |
+
break
|
44 |
+
|
45 |
+
input_tokens = system_tokens + history_tokens
|
46 |
+
if messages[-1]["role"] != "assistant":
|
47 |
+
input_tokens.append(model.generation_config.assistant_token_id)
|
48 |
+
input_tokens = input_tokens[-max_input_tokens:] # truncate left
|
49 |
+
return torch.LongTensor([input_tokens]).to(model.device)
|
50 |
+
|
51 |
+
|
52 |
+
class TextIterStreamer:
|
53 |
+
def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
|
54 |
+
self.tokenizer = tokenizer
|
55 |
+
self.skip_prompt = skip_prompt
|
56 |
+
self.skip_special_tokens = skip_special_tokens
|
57 |
+
self.tokens = []
|
58 |
+
self.text_queue = Queue()
|
59 |
+
self.next_tokens_are_prompt = True
|
60 |
+
|
61 |
+
def put(self, value):
|
62 |
+
if self.skip_prompt and self.next_tokens_are_prompt:
|
63 |
+
self.next_tokens_are_prompt = False
|
64 |
+
else:
|
65 |
+
if len(value.shape) > 1:
|
66 |
+
value = value[0]
|
67 |
+
self.tokens.extend(value.tolist())
|
68 |
+
self.text_queue.put(
|
69 |
+
self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
|
70 |
+
|
71 |
+
def end(self):
|
72 |
+
self.text_queue.put(None)
|
73 |
+
|
74 |
+
def __iter__(self):
|
75 |
+
return self
|
76 |
+
|
77 |
+
def __next__(self):
|
78 |
+
value = self.text_queue.get()
|
79 |
+
if value is None:
|
80 |
+
raise StopIteration()
|
81 |
+
else:
|
82 |
+
return value
|
83 |
+
|
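
A hedged usage sketch for build_chat_input and TextIterStreamer: the checkpoint path is a placeholder, and it assumes the model's generation config defines user_token_id and assistant_token_id, which build_chat_input relies on.

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from generation_utils import build_chat_input, TextIterStreamer

    model_path = "path/to/checkpoint"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

    messages = [{"role": "user", "content": "Hello!"}]
    input_ids = build_chat_input(model, tokenizer, messages)

    streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=64),
    )
    thread.start()
    for text_so_far in streamer:  # each item is the full decoded text generated so far
        print(text_so_far)
    thread.join()
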
matcha_components.py
ADDED
@@ -0,0 +1,189 @@
1 |
+
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
|
2 |
+
"""
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Shivam Mehta
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
|
26 |
+
import math
|
27 |
+
from typing import Optional
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import torch.nn as nn
|
31 |
+
import torch.nn.functional as F
|
32 |
+
|
33 |
+
from diffusers.models.activations import get_activation
|
34 |
+
|
35 |
+
|
36 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
37 |
+
def __init__(self, dim):
|
38 |
+
super().__init__()
|
39 |
+
self.dim = dim
|
40 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
41 |
+
|
42 |
+
def forward(self, x, scale=1000):
|
43 |
+
if x.ndim < 1:
|
44 |
+
x = x.unsqueeze(0)
|
45 |
+
device = x.device
|
46 |
+
half_dim = self.dim // 2
|
47 |
+
emb = math.log(10000) / (half_dim - 1)
|
48 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
49 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
50 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
51 |
+
return emb
|
52 |
+
|
53 |
+
|
54 |
+
class Block1D(torch.nn.Module):
|
55 |
+
def __init__(self, dim, dim_out, groups=8):
|
56 |
+
super().__init__()
|
57 |
+
self.block = torch.nn.Sequential(
|
58 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
59 |
+
torch.nn.GroupNorm(groups, dim_out),
|
60 |
+
nn.Mish(),
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, x, mask):
|
64 |
+
output = self.block(x * mask)
|
65 |
+
return output * mask
|
66 |
+
|
67 |
+
|
68 |
+
class ResnetBlock1D(torch.nn.Module):
|
69 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
70 |
+
super().__init__()
|
71 |
+
self.mlp = torch.nn.Sequential(
|
72 |
+
nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
|
73 |
+
)
|
74 |
+
|
75 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
76 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
77 |
+
|
78 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
79 |
+
|
80 |
+
def forward(self, x, mask, time_emb):
|
81 |
+
h = self.block1(x, mask)
|
82 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
83 |
+
h = self.block2(h, mask)
|
84 |
+
output = h + self.res_conv(x * mask)
|
85 |
+
return output
|
86 |
+
|
87 |
+
|
88 |
+
class Downsample1D(nn.Module):
|
89 |
+
def __init__(self, dim):
|
90 |
+
super().__init__()
|
91 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
return self.conv(x)
|
95 |
+
|
96 |
+
|
97 |
+
class TimestepEmbedding(nn.Module):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
in_channels: int,
|
101 |
+
time_embed_dim: int,
|
102 |
+
act_fn: str = "silu",
|
103 |
+
out_dim: int = None,
|
104 |
+
post_act_fn: Optional[str] = None,
|
105 |
+
cond_proj_dim=None,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
110 |
+
|
111 |
+
if cond_proj_dim is not None:
|
112 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
113 |
+
else:
|
114 |
+
self.cond_proj = None
|
115 |
+
|
116 |
+
self.act = get_activation(act_fn)
|
117 |
+
|
118 |
+
if out_dim is not None:
|
119 |
+
time_embed_dim_out = out_dim
|
120 |
+
else:
|
121 |
+
time_embed_dim_out = time_embed_dim
|
122 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
123 |
+
|
124 |
+
if post_act_fn is None:
|
125 |
+
self.post_act = None
|
126 |
+
else:
|
127 |
+
self.post_act = get_activation(post_act_fn)
|
128 |
+
|
129 |
+
def forward(self, sample, condition=None):
|
130 |
+
if condition is not None:
|
131 |
+
sample = sample + self.cond_proj(condition)
|
132 |
+
sample = self.linear_1(sample)
|
133 |
+
|
134 |
+
if self.act is not None:
|
135 |
+
sample = self.act(sample)
|
136 |
+
|
137 |
+
sample = self.linear_2(sample)
|
138 |
+
|
139 |
+
if self.post_act is not None:
|
140 |
+
sample = self.post_act(sample)
|
141 |
+
return sample
|
142 |
+
|
143 |
+
|
144 |
+
class Upsample1D(nn.Module):
|
145 |
+
"""A 1D upsampling layer with an optional convolution.
|
146 |
+
|
147 |
+
Parameters:
|
148 |
+
channels (`int`):
|
149 |
+
number of channels in the inputs and outputs.
|
150 |
+
use_conv (`bool`, default `False`):
|
151 |
+
option to use a convolution.
|
152 |
+
use_conv_transpose (`bool`, default `False`):
|
153 |
+
option to use a convolution transpose.
|
154 |
+
out_channels (`int`, optional):
|
155 |
+
number of output channels. Defaults to `channels`.
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
channels,
|
161 |
+
use_conv=False,
|
162 |
+
use_conv_transpose=True,
|
163 |
+
out_channels=None,
|
164 |
+
name="conv",
|
165 |
+
):
|
166 |
+
super().__init__()
|
167 |
+
self.channels = channels
|
168 |
+
self.out_channels = out_channels or channels
|
169 |
+
self.use_conv = use_conv
|
170 |
+
self.use_conv_transpose = use_conv_transpose
|
171 |
+
self.name = name
|
172 |
+
|
173 |
+
self.conv = None
|
174 |
+
if use_conv_transpose:
|
175 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
176 |
+
elif use_conv:
|
177 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
178 |
+
|
179 |
+
def forward(self, inputs):
|
180 |
+
assert inputs.shape[1] == self.channels
|
181 |
+
if self.use_conv_transpose:
|
182 |
+
return self.conv(inputs)
|
183 |
+
|
184 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
185 |
+
|
186 |
+
if self.use_conv:
|
187 |
+
outputs = self.conv(outputs)
|
188 |
+
|
189 |
+
return outputs
|
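
A minimal shape check for the building blocks above, showing the (B, C, T) conventions. All values are arbitrary, and the import assumes this file is available as `matcha_components` on the Python path.

    import torch

    from matcha_components import (
        SinusoidalPosEmb, TimestepEmbedding, ResnetBlock1D, Downsample1D, Upsample1D
    )

    B, C, T = 2, 256, 40
    x = torch.randn(B, C, T)
    mask = torch.ones(B, 1, T)

    t = torch.rand(B)                           # diffusion timesteps in [0, 1]
    t_emb = SinusoidalPosEmb(C)(t)              # (B, C)
    t_emb = TimestepEmbedding(C, C * 4)(t_emb)  # (B, 4 * C)

    block = ResnetBlock1D(dim=C, dim_out=C, time_emb_dim=C * 4)
    h = block(x, mask, t_emb)                   # (B, C, T)

    down = Downsample1D(C)(h)                   # (B, C, T // 2): stride-2 conv
    up = Upsample1D(C, use_conv_transpose=True)(down)  # (B, C, T): transposed conv doubles T
    print(h.shape, down.shape, up.shape)
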
matcha_feat.py
ADDED
@@ -0,0 +1,107 @@
1 |
+
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
|
2 |
+
"""
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Shivam Mehta
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
import torch.utils.data
|
29 |
+
from librosa.filters import mel as librosa_mel_fn
|
30 |
+
from scipy.io.wavfile import read
|
31 |
+
|
32 |
+
MAX_WAV_VALUE = 32768.0
|
33 |
+
|
34 |
+
|
35 |
+
def load_wav(full_path):
|
36 |
+
sampling_rate, data = read(full_path)
|
37 |
+
return data, sampling_rate
|
38 |
+
|
39 |
+
|
40 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
41 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
42 |
+
|
43 |
+
|
44 |
+
def dynamic_range_decompression(x, C=1):
|
45 |
+
return np.exp(x) / C
|
46 |
+
|
47 |
+
|
48 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
49 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
50 |
+
|
51 |
+
|
52 |
+
def dynamic_range_decompression_torch(x, C=1):
|
53 |
+
return torch.exp(x) / C
|
54 |
+
|
55 |
+
|
56 |
+
def spectral_normalize_torch(magnitudes):
|
57 |
+
output = dynamic_range_compression_torch(magnitudes)
|
58 |
+
return output
|
59 |
+
|
60 |
+
|
61 |
+
def spectral_de_normalize_torch(magnitudes):
|
62 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
63 |
+
return output
|
64 |
+
|
65 |
+
|
66 |
+
mel_basis = {}
|
67 |
+
hann_window = {}
|
68 |
+
|
69 |
+
|
70 |
+
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
71 |
+
if torch.min(y) < -1.0:
|
72 |
+
print("min value is ", torch.min(y))
|
73 |
+
if torch.max(y) > 1.0:
|
74 |
+
print("max value is ", torch.max(y))
|
75 |
+
|
76 |
+
global mel_basis, hann_window # pylint: disable=global-statement
|
77 |
+
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
78 |
+
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
79 |
+
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
80 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
81 |
+
|
82 |
+
y = torch.nn.functional.pad(
|
83 |
+
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
84 |
+
)
|
85 |
+
y = y.squeeze(1)
|
86 |
+
|
87 |
+
spec = torch.view_as_real(
|
88 |
+
torch.stft(
|
89 |
+
y,
|
90 |
+
n_fft,
|
91 |
+
hop_length=hop_size,
|
92 |
+
win_length=win_size,
|
93 |
+
window=hann_window[str(y.device)],
|
94 |
+
center=center,
|
95 |
+
pad_mode="reflect",
|
96 |
+
normalized=False,
|
97 |
+
onesided=True,
|
98 |
+
return_complex=True,
|
99 |
+
)
|
100 |
+
)
|
101 |
+
|
102 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
103 |
+
|
104 |
+
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
105 |
+
spec = spectral_normalize_torch(spec)
|
106 |
+
|
107 |
+
return spec
|
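
A hedged example of calling mel_spectrogram. The STFT/mel parameters below are illustrative placeholders, not necessarily the values this repository was trained with, and the import assumes the file is available as `matcha_feat`.

    import torch

    from matcha_feat import mel_spectrogram

    wav = torch.rand(1, 24000 * 2) * 2 - 1   # 2 seconds of fake audio in [-1, 1]
    mel = mel_spectrogram(
        wav,
        n_fft=1024,
        num_mels=80,
        sampling_rate=24000,
        hop_size=256,
        win_size=1024,
        fmin=0,
        fmax=8000,
        center=False,
    )
    print(mel.shape)  # (1, 80, n_frames)
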
matcha_transformer.py
ADDED
@@ -0,0 +1,480 @@
1 |
+
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
|
2 |
+
"""
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2023 Shivam Mehta
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
|
26 |
+
from typing import Any, Dict, Optional
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
from diffusers.models.attention import (
|
31 |
+
GEGLU,
|
32 |
+
GELU,
|
33 |
+
AdaLayerNorm,
|
34 |
+
AdaLayerNormZero,
|
35 |
+
ApproximateGELU,
|
36 |
+
)
|
37 |
+
from diffusers.models.attention_processor import Attention
|
38 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
39 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
40 |
+
|
41 |
+
import torch.nn.functional as F
|
42 |
+
from flash_attn import flash_attn_varlen_func
|
43 |
+
|
44 |
+
|
45 |
+
def get_sequence_mask(inputs, inputs_length):
|
46 |
+
if inputs.dim() == 3:
|
47 |
+
bsz, tgt_len, _ = inputs.size()
|
48 |
+
else:
|
49 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
50 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
51 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
|
52 |
+
bsz, tgt_len, 1
|
53 |
+
)
|
54 |
+
unpacking_index = (
|
55 |
+
torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
|
56 |
+
) # convert the cumulative mask into flat token indices
|
57 |
+
return sequence_mask, unpacking_index
|
58 |
+
|
59 |
+
|
60 |
+
class OmniWhisperAttention(nn.Module):
|
61 |
+
def __init__(self, embed_dim, num_heads, causal=False):
|
62 |
+
super().__init__()
|
63 |
+
self.embed_dim = embed_dim
|
64 |
+
self.num_heads = num_heads
|
65 |
+
self.head_dim = embed_dim // num_heads
|
66 |
+
|
67 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
68 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
69 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
70 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
71 |
+
|
72 |
+
self.causal = causal
|
73 |
+
|
74 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
75 |
+
bsz, _ = hidden_states.size()
|
76 |
+
|
77 |
+
query_states = self.q_proj(hidden_states).view(
|
78 |
+
bsz, self.num_heads, self.head_dim
|
79 |
+
)
|
80 |
+
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
81 |
+
value_states = self.v_proj(hidden_states).view(
|
82 |
+
bsz, self.num_heads, self.head_dim
|
83 |
+
)
|
84 |
+
|
85 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
|
86 |
+
torch.int32
|
87 |
+
)
|
88 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
89 |
+
attn_output = flash_attn_varlen_func(
|
90 |
+
query_states,
|
91 |
+
key_states,
|
92 |
+
value_states,
|
93 |
+
cu_len,
|
94 |
+
cu_len,
|
95 |
+
max_seqlen,
|
96 |
+
max_seqlen,
|
97 |
+
causal=self.causal,
|
98 |
+
) # (bsz * qlen, nheads, headdim)
|
99 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
100 |
+
attn_output = self.out_proj(attn_output)
|
101 |
+
return attn_output
|
102 |
+
|
103 |
+
|
104 |
+
class SnakeBeta(nn.Module):
|
105 |
+
"""
|
106 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
107 |
+
Shape:
|
108 |
+
- Input: (B, C, T)
|
109 |
+
- Output: (B, C, T), same shape as the input
|
110 |
+
Parameters:
|
111 |
+
- alpha - trainable parameter that controls frequency
|
112 |
+
- beta - trainable parameter that controls magnitude
|
113 |
+
References:
|
114 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
115 |
+
https://arxiv.org/abs/2006.08195
|
116 |
+
Examples:
|
117 |
+
>>> a1 = SnakeBeta(256, 256)
|
118 |
+
>>> x = torch.randn(256)
|
119 |
+
>>> x = a1(x)
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
in_features,
|
125 |
+
out_features,
|
126 |
+
alpha=1.0,
|
127 |
+
alpha_trainable=True,
|
128 |
+
alpha_logscale=True,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
Initialization.
|
132 |
+
INPUT:
|
133 |
+
- in_features: shape of the input
|
134 |
+
- alpha - trainable parameter that controls frequency
|
135 |
+
- beta - trainable parameter that controls magnitude
|
136 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
137 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
138 |
+
alpha will be trained along with the rest of your model.
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
self.in_features = (
|
142 |
+
out_features if isinstance(out_features, list) else [out_features]
|
143 |
+
)
|
144 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
145 |
+
|
146 |
+
# initialize alpha
|
147 |
+
self.alpha_logscale = alpha_logscale
|
148 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
149 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
150 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
151 |
+
else: # linear scale alphas initialized to ones
|
152 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
153 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
154 |
+
|
155 |
+
self.alpha.requires_grad = alpha_trainable
|
156 |
+
self.beta.requires_grad = alpha_trainable
|
157 |
+
|
158 |
+
self.no_div_by_zero = 0.000000001
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
"""
|
162 |
+
Forward pass of the function.
|
163 |
+
Applies the function to the input elementwise.
|
164 |
+
SnakeBeta := x + 1/b * sin^2(x * a)
|
165 |
+
"""
|
166 |
+
x = self.proj(x)
|
167 |
+
if self.alpha_logscale:
|
168 |
+
alpha = torch.exp(self.alpha)
|
169 |
+
beta = torch.exp(self.beta)
|
170 |
+
else:
|
171 |
+
alpha = self.alpha
|
172 |
+
beta = self.beta
|
173 |
+
|
174 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
|
175 |
+
torch.sin(x * alpha), 2
|
176 |
+
)
|
177 |
+
|
178 |
+
return x
|
179 |
+
|
180 |
+
|
181 |
+
class FeedForward(nn.Module):
|
182 |
+
r"""
|
183 |
+
A feed-forward layer.
|
184 |
+
|
185 |
+
Parameters:
|
186 |
+
dim (`int`): The number of channels in the input.
|
187 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
188 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
189 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
190 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
191 |
+
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        use_omni_attn: bool = False,
    ):
        super().__init__()

        self.use_omni_attn = use_omni_attn
        self.dim = dim

        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        if self.use_omni_attn:
            if only_cross_attention:
                raise NotImplementedError
            print(
                "Use OmniWhisperAttention with flash attention. Dropout is ignored."
            )
            self.attn1 = OmniWhisperAttention(
                embed_dim=dim, num_heads=num_attention_heads, causal=False
            )
        else:
            self.attn1 = Attention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=(
                    cross_attention_dim if only_cross_attention else None
                ),
                upcast_attention=upcast_attention,
            )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):

        bsz, tgt_len, d_model = hidden_states.shape

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = (
            cross_attention_kwargs if cross_attention_kwargs is not None else {}
        )

        if self.use_omni_attn:
            seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
            var_len_attention_mask, unpacking_index = get_sequence_mask(
                norm_hidden_states, seq_len
            )
            norm_hidden_states = torch.masked_select(
                norm_hidden_states, var_len_attention_mask
            )
            norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
            attn_output = self.attn1(norm_hidden_states, seq_len)
            # unpacking
            attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
                bsz, tgt_len, d_model
            )
            attn_output = torch.where(var_len_attention_mask, attn_output, 0)
        else:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=(
                    encoder_hidden_states if self.only_cross_attention else None
                ),
                attention_mask=(
                    encoder_attention_mask
                    if self.only_cross_attention
                    else attention_mask
                ),
                **cross_attention_kwargs,
            )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
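The chunked feed-forward path in BasicTransformerBlock.forward trades scheduling for peak activation memory by slicing the sequence axis before running the FF. Below is a self-contained sketch of that idea using a plain MLP stand-in rather than the repo's FeedForward class; the shapes and chunk size are illustrative assumptions.

import torch
from torch import nn

# Stand-in feed-forward; the real block builds FeedForward(dim, activation_fn=...) as above.
dim, mult = 256, 4
ff = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

def chunked_ff(x: torch.Tensor, chunk_size: int, chunk_dim: int = 1) -> torch.Tensor:
    # Mirrors the set_chunk_feed_forward() contract: the chunked dim must divide evenly.
    if x.shape[chunk_dim] % chunk_size != 0:
        raise ValueError("sequence length must be divisible by chunk_size")
    num_chunks = x.shape[chunk_dim] // chunk_size
    # Run the FF slice by slice and stitch the results back together along the same dim.
    return torch.cat([ff(s) for s in x.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim)

x = torch.randn(2, 64, dim)  # (batch, sequence, channels)
assert torch.allclose(chunked_ff(x, chunk_size=16), ff(x), atol=1e-5)

Because the feed-forward acts position-wise, chunking only bounds how much activation memory is live at once; it does not change the result.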
model-00001-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db38ae1d3c918180550c64e1105decf6cd9dfe71e92c800d153788d689393035
size 4966661936

model-00002-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f324edf3f066581664bceeb992dccb9e1b8134ad16d79e9df8383cd865d9a680
size 4991490856

model-00003-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3fa8dbe0b7d76565c529f6f15ac0965ab1bd9c07653b3c77f868ac160d40e7d
size 4932746512

model-00004-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c4212ed32a9bb3ad957a4aaea33e8f8a840ea762c99876984297499e42505e9
size 4987094096

model-00005-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0fc6ca933cdfb81e1e8ee1cc5dfc1f3885243041faf3a11e3146bdae04ef93e1
size 2565519592
model.safetensors.index.json
ADDED
The diff for this file is too large to render. See raw diff.
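The five LFS pointers above are shards of a single checkpoint, and model.safetensors.index.json maps each parameter name to the shard that stores it. A hedged sketch of stitching the shards together by hand, assuming the usual transformers index layout ({"metadata": ..., "weight_map": {param: shard_file}}) and a hypothetical local checkout path; from_pretrained() normally does this automatically.

import json
from pathlib import Path
from safetensors.torch import load_file

repo_dir = Path("./omni-checkpoint")  # hypothetical local path after `git lfs pull`
index = json.loads((repo_dir / "model.safetensors.index.json").read_text())

state_dict = {}
for shard in sorted(set(index["weight_map"].values())):
    # Each shard is a plain safetensors file; load on CPU and merge into one dict.
    state_dict.update(load_file(str(repo_dir / shard), device="cpu"))

print(f"{len(state_dict)} tensors, "
      f"{sum(t.numel() for t in state_dict.values()) / 1e9:.1f}B parameters")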
modeling_omni.py
ADDED
@@ -0,0 +1,1011 @@
1 |
+
# Copyright 2023 Baichuan Inc. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
""" PyTorch omni model."""
|
22 |
+
import os
|
23 |
+
import time
|
24 |
+
import json
|
25 |
+
import math
|
26 |
+
import numpy as np
|
27 |
+
from typing import List, Optional, Tuple, Union, Any
|
28 |
+
from threading import Thread
|
29 |
+
from easydict import EasyDict
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import torch.distributed
|
33 |
+
import torch.utils.checkpoint
|
34 |
+
from torch import nn
|
35 |
+
from torch.nn import CrossEntropyLoss
|
36 |
+
from torch.nn import functional as F
|
37 |
+
import torch.distributed as dist
|
38 |
+
from transformers import PreTrainedModel
|
39 |
+
from transformers.activations import ACT2FN
|
40 |
+
from dataclasses import dataclass
|
41 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
42 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
43 |
+
from transformers.generation.utils import GenerationConfig
|
44 |
+
from transformers.utils import logging
|
45 |
+
# import for dynamic import not used in this file
|
46 |
+
from .vector_quantize import VectorQuantize, EuclideanCodebook
|
47 |
+
from .matcha_components import (
|
48 |
+
SinusoidalPosEmb,
|
49 |
+
Block1D,
|
50 |
+
ResnetBlock1D,
|
51 |
+
Downsample1D,
|
52 |
+
TimestepEmbedding,
|
53 |
+
Upsample1D,
|
54 |
+
)
|
55 |
+
from .matcha_transformer import BasicTransformerBlock
|
56 |
+
from .flow_matching import ConditionalDecoder, ConditionalCFM
|
57 |
+
|
58 |
+
from .configuration_omni import OmniConfig
|
59 |
+
from .audio_modeling_omni import (RMSNorm,
|
60 |
+
OmniAudioEncoder,
|
61 |
+
OmniAudioDecoder,
|
62 |
+
OmniAudioVQBridgeTokenizer,
|
63 |
+
OmniAudioFlowMatchingDecoder)
|
64 |
+
from .visual_modeling_omni import OmniVisualEncoder, OmniVisualBridge
|
65 |
+
from .processor_omni import OmniMMProcessor
|
66 |
+
|
67 |
+
# support model path contain point(.)
|
68 |
+
try:
|
69 |
+
# step1: copy relative imports to transformers_modules
|
70 |
+
from .generation_utils import build_chat_input, TextIterStreamer
|
71 |
+
from .sequence_parallel_utils import (
|
72 |
+
create_attention_layer,
|
73 |
+
get_sequence_parallel_size,
|
74 |
+
get_sequence_parallel_chunk,
|
75 |
+
)
|
76 |
+
except ModuleNotFoundError:
|
77 |
+
# step2: direct import from transformers_modules
|
78 |
+
try: # bypass check_imports failure
|
79 |
+
import sys
|
80 |
+
sys.path.append(os.path.dirname(__file__))
|
81 |
+
from generation_utils import build_chat_input, TextIterStreamer
|
82 |
+
from sequence_parallel_utils import (
|
83 |
+
create_attention_layer,
|
84 |
+
get_sequence_parallel_size,
|
85 |
+
get_sequence_parallel_chunk,
|
86 |
+
)
|
87 |
+
except Exception:
|
88 |
+
raise
|
89 |
+
|
90 |
+
logger = logging.get_logger(__name__)
|
91 |
+
|
92 |
+
def get_slopes(n):
|
93 |
+
def get_slopes_power_of_2(n):
|
94 |
+
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
95 |
+
ratio = start
|
96 |
+
return [start * ratio ** i for i in range(n)]
|
97 |
+
|
98 |
+
if math.log2(n).is_integer():
|
99 |
+
return get_slopes_power_of_2(
|
100 |
+
n) # In the paper, we only train models that have 2^a heads for some a. This function has
|
101 |
+
else: # some good properties that only occur when the input is a power of 2. To maintain that even
|
102 |
+
closest_power_of_2 = 2 ** math.floor(
|
103 |
+
math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
|
104 |
+
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
|
105 |
+
:n - closest_power_of_2]
|
106 |
+
|
107 |
+
|
108 |
+
class RotaryEmbedding(torch.nn.Module):
|
109 |
+
def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
|
110 |
+
super().__init__()
|
111 |
+
# 修复RePE初始化精度问题 https://zhuanlan.zhihu.com/p/678963442
|
112 |
+
# DeepSpeed 会 Hack torch.arange 强制在 GPU 上运行,这里使用原生的 torch.arange
|
113 |
+
try:
|
114 |
+
import deepspeed
|
115 |
+
self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
|
116 |
+
except:
|
117 |
+
self.arange = torch.arange
|
118 |
+
|
119 |
+
self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
|
120 |
+
self.max_seq_len_cached = max_position_embeddings
|
121 |
+
t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
122 |
+
freqs = torch.outer(t, self.inv_freq)
|
123 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
124 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
|
125 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
|
126 |
+
|
127 |
+
def forward(self, x, seq_len=None):
|
128 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
129 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
130 |
+
if seq_len > self.max_seq_len_cached:
|
131 |
+
self.max_seq_len_cached = seq_len
|
132 |
+
t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
133 |
+
freqs = torch.outer(t, self.inv_freq)
|
134 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
135 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
|
136 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
|
137 |
+
return (
|
138 |
+
self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
|
139 |
+
self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def rotate_half(x):
|
144 |
+
"""Rotates half the hidden dims of the input."""
|
145 |
+
x1 = x[..., : x.shape[-1] // 2]
|
146 |
+
x2 = x[..., x.shape[-1] // 2:]
|
147 |
+
return torch.cat((-x2, x1), dim=-1)
|
148 |
+
|
149 |
+
|
150 |
+
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
|
151 |
+
cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
|
152 |
+
sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
|
153 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
154 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
155 |
+
q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
|
156 |
+
k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
|
157 |
+
return q_embed.to(q.dtype), k_embed.to(k.dtype)
|
158 |
+
|
159 |
+
|
160 |
+
class MLP(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
hidden_size: int,
|
164 |
+
intermediate_size: int,
|
165 |
+
hidden_act: str,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
169 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
170 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
171 |
+
self.act_fn = ACT2FN[hidden_act]
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
175 |
+
|
176 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
177 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
178 |
+
"""
|
179 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
180 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
181 |
+
"""
|
182 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
183 |
+
if n_rep == 1:
|
184 |
+
return hidden_states
|
185 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
186 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
187 |
+
|
188 |
+
|
189 |
+
class Attention(nn.Module):
|
190 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
191 |
+
def __init__(self, config: OmniConfig, is_sparse=False):
|
192 |
+
super().__init__()
|
193 |
+
self.config = config
|
194 |
+
self.position_embedding_type = config.position_embedding_type.lower()
|
195 |
+
self.num_kv_heads = config.num_key_value_heads
|
196 |
+
self.head_dim = config.head_dim
|
197 |
+
self.hidden_size = config.num_attention_heads * self.head_dim
|
198 |
+
self.hidden_kv_size = self.num_kv_heads * self.head_dim
|
199 |
+
|
200 |
+
if is_sparse:
|
201 |
+
self.num_heads = config.sparse_attention_heads
|
202 |
+
assert self.num_kv_heads == config.num_attention_heads
|
203 |
+
self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
|
204 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
205 |
+
else:
|
206 |
+
self.num_heads = config.num_attention_heads
|
207 |
+
if self.config.attention_qkv_pack:
|
208 |
+
self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
|
209 |
+
else:
|
210 |
+
self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
|
211 |
+
self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
|
212 |
+
self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
|
213 |
+
|
214 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
215 |
+
|
216 |
+
if self.position_embedding_type == 'rope':
|
217 |
+
self.rotary_emb = RotaryEmbedding(
|
218 |
+
dim=self.head_dim,
|
219 |
+
max_position_embeddings=config.max_position_embeddings,
|
220 |
+
base=config.get_rotary_base()
|
221 |
+
)
|
222 |
+
elif self.position_embedding_type == 'alibi':
|
223 |
+
self.alibi_slopes = get_slopes(self.num_heads)
|
224 |
+
self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)
|
225 |
+
|
226 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
227 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
228 |
+
|
229 |
+
def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
|
230 |
+
assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
|
231 |
+
return repeat_kv(hidden_states, num_heads // hidden_states.size(1))
|
232 |
+
|
233 |
+
def forward(
|
234 |
+
self,
|
235 |
+
hidden_states: torch.Tensor,
|
236 |
+
attention_mask: Optional[torch.Tensor] = None,
|
237 |
+
position_ids: Optional[torch.LongTensor] = None,
|
238 |
+
seqlens: Optional[torch.IntTensor] = None,
|
239 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
240 |
+
output_attentions: bool = False,
|
241 |
+
use_cache: bool = False,
|
242 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
243 |
+
bsz, q_len = hidden_states.shape[:2]
|
244 |
+
|
245 |
+
if self.config.attention_qkv_pack:
|
246 |
+
proj = self.W_pack(hidden_states)
|
247 |
+
query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
|
248 |
+
else:
|
249 |
+
query_states = self.q_proj(hidden_states)
|
250 |
+
key_states = self.k_proj(hidden_states)
|
251 |
+
value_states = self.v_proj(hidden_states)
|
252 |
+
|
253 |
+
# (B, S, hidden_size) -> (B, num_heads, S, head_size)
|
254 |
+
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
255 |
+
# (B, S, hidden_size) -> (B, num_kv_heads, S, head_size)
|
256 |
+
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
257 |
+
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
258 |
+
|
259 |
+
kv_seq_len = key_states.shape[-2]
|
260 |
+
if past_key_value is not None:
|
261 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
262 |
+
if self.position_embedding_type == 'rope':
|
263 |
+
max_position = position_ids.max().item()+1 if position_ids is not None else kv_seq_len * get_sequence_parallel_size()
|
264 |
+
cos, sin = self.rotary_emb(value_states, seq_len=max_position)
|
265 |
+
query_states, key_states = apply_rotary_pos_emb(
|
266 |
+
query_states, key_states, cos, sin,
|
267 |
+
get_sequence_parallel_chunk(position_ids)
|
268 |
+
)
|
269 |
+
|
270 |
+
if past_key_value is not None:
|
271 |
+
# reuse k, v, self_attention
|
272 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
273 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
274 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
275 |
+
|
276 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
277 |
+
key_states = self._repeat_kv(key_states, query_states.size(1))
|
278 |
+
value_states = self._repeat_kv(value_states, query_states.size(1))
|
279 |
+
|
280 |
+
if seqlens is not None:
|
281 |
+
seqlens = seqlens.to(dtype=torch.int32)
|
282 |
+
max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
|
283 |
+
if self.position_embedding_type == 'alibi':
|
284 |
+
alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
|
285 |
+
else:
|
286 |
+
alibi_slopes = None
|
287 |
+
attn_output = self.attention(
|
288 |
+
query_states, key_states, value_states, seqlens, seqlens,
|
289 |
+
max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
|
290 |
+
else:
|
291 |
+
attn_output = self.attention(
|
292 |
+
query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)
|
293 |
+
|
294 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
295 |
+
attn_output = self.o_proj(attn_output)
|
296 |
+
|
297 |
+
return attn_output, None, past_key_value
|
298 |
+
|
299 |
+
|
300 |
+
class DecoderLayer(nn.Module):
|
301 |
+
def __init__(self, config: OmniConfig, is_sparse=False):
|
302 |
+
super().__init__()
|
303 |
+
self.hidden_size = config.hidden_size
|
304 |
+
self.self_attn = Attention(config=config, is_sparse=is_sparse)
|
305 |
+
self.mlp = MLP(
|
306 |
+
hidden_size=self.hidden_size,
|
307 |
+
intermediate_size=config.intermediate_size,
|
308 |
+
hidden_act=config.hidden_act,
|
309 |
+
)
|
310 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
311 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self,
|
315 |
+
hidden_states: torch.Tensor,
|
316 |
+
attention_mask: Optional[torch.Tensor] = None,
|
317 |
+
position_ids: Optional[torch.LongTensor] = None,
|
318 |
+
seqlens: Optional[torch.IntTensor] = None,
|
319 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
320 |
+
output_attentions: Optional[bool] = False,
|
321 |
+
use_cache: Optional[bool] = False,
|
322 |
+
group_index=None,
|
323 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
324 |
+
|
325 |
+
residual = hidden_states
|
326 |
+
|
327 |
+
hidden_states = self.input_layernorm(hidden_states)
|
328 |
+
|
329 |
+
# Self Attention
|
330 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
331 |
+
hidden_states=hidden_states,
|
332 |
+
attention_mask=attention_mask,
|
333 |
+
position_ids=position_ids,
|
334 |
+
seqlens=seqlens,
|
335 |
+
past_key_value=past_key_value,
|
336 |
+
output_attentions=output_attentions,
|
337 |
+
use_cache=use_cache,
|
338 |
+
)
|
339 |
+
hidden_states = residual + hidden_states
|
340 |
+
|
341 |
+
# Fully Connected
|
342 |
+
residual = hidden_states
|
343 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
344 |
+
hidden_states = self.mlp(hidden_states)
|
345 |
+
hidden_states = residual + hidden_states
|
346 |
+
|
347 |
+
outputs = (hidden_states,)
|
348 |
+
|
349 |
+
if output_attentions:
|
350 |
+
outputs += (self_attn_weights,)
|
351 |
+
|
352 |
+
if use_cache:
|
353 |
+
outputs += (present_key_value,)
|
354 |
+
|
355 |
+
return outputs
|
356 |
+
|
357 |
+
|
358 |
+
class OmniPreTrainedModel(PreTrainedModel):
|
359 |
+
config_class = OmniConfig
|
360 |
+
base_model_prefix = "model"
|
361 |
+
supports_gradient_checkpointing = True
|
362 |
+
_no_split_modules = ["DecoderLayer"]
|
363 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
364 |
+
|
365 |
+
def _init_weights(self, module):
|
366 |
+
std = self.config.initializer_range
|
367 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
|
368 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
369 |
+
if module.bias is not None:
|
370 |
+
module.bias.data.zero_()
|
371 |
+
elif isinstance(module, nn.Embedding):
|
372 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
373 |
+
if module.padding_idx is not None:
|
374 |
+
module.weight.data[module.padding_idx].zero_()
|
375 |
+
elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.GroupNorm):
|
376 |
+
module.weight.data.fill_(1.0)
|
377 |
+
module.bias.data.zero_()
|
378 |
+
elif isinstance(module, RMSNorm):
|
379 |
+
module.weight.data.fill_(1.0)
|
380 |
+
|
381 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
382 |
+
if isinstance(module, OmniModel):
|
383 |
+
module.gradient_checkpointing = value
|
384 |
+
|
385 |
+
@dataclass
|
386 |
+
class OmniModelOutputWithPast(BaseModelOutputWithPast):
|
387 |
+
audio_encoder_ret: Optional[Any] = None
|
388 |
+
audio_decoder_ret: Optional[Any] = None
|
389 |
+
|
390 |
+
class OmniModel(OmniPreTrainedModel):
|
391 |
+
def __init__(self, config: OmniConfig):
|
392 |
+
super().__init__(config)
|
393 |
+
self.padding_idx = config.pad_token_id
|
394 |
+
self.vocab_size = config.vocab_size
|
395 |
+
|
396 |
+
if config.visual_config.enable:
|
397 |
+
self.visual_model = OmniVisualEncoder(config.visual_config)
|
398 |
+
self.visual_bridge_model = OmniVisualBridge(config.visual_config)
|
399 |
+
if config.video_config.enable and not config.visual_config.enable: # in case 没有visual_config而只有video_config
|
400 |
+
self.visual_model = OmniVisualEncoder(config.video_config)
|
401 |
+
self.visual_bridge_model = OmniVisualBridge(config.video_config)
|
402 |
+
|
403 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
404 |
+
self.layers = nn.ModuleList([
|
405 |
+
DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
|
406 |
+
for layer_idx in range(config.num_hidden_layers)
|
407 |
+
])
|
408 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
409 |
+
|
410 |
+
self.audio_embed_layers = nn.ModuleList([
|
411 |
+
nn.Embedding(codedim + 1, config.hidden_size)
|
412 |
+
for i, codedim in enumerate(config.audio_config.vq_config.codebook_sizes)
|
413 |
+
])
|
414 |
+
|
415 |
+
self.gradient_checkpointing = True
|
416 |
+
# Initialize weights and apply final processing
|
417 |
+
self.post_init()
|
418 |
+
|
419 |
+
def get_input_embeddings(self):
|
420 |
+
return self.embed_tokens
|
421 |
+
|
422 |
+
def set_input_embeddings(self, value):
|
423 |
+
self.embed_tokens = value
|
424 |
+
|
425 |
+
@torch.no_grad()
|
426 |
+
def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
|
427 |
+
'''
|
428 |
+
获取任意模态的特殊mask,包含以下
|
429 |
+
1. pad mask 表示文本中图像/语音/视频模态提前留出的token位置
|
430 |
+
2. special token mask 特殊token 例如对理解模型<start> <end> 不需要next token prediction
|
431 |
+
3. embedding mask / lm_head mask 标记出特殊token在embedding中的mask
|
432 |
+
'''
|
433 |
+
pad_mask = torch.eq(input_ids, pad_token_id)
|
434 |
+
sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
435 |
+
lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
|
436 |
+
for sp_id in special_token_list:
|
437 |
+
sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
|
438 |
+
lm_head_mask[sp_id, 0] = True
|
439 |
+
return pad_mask, sp_mask, lm_head_mask
|
440 |
+
|
441 |
+
def get_multimodal_embed(
|
442 |
+
self,
|
443 |
+
input_ids,
|
444 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
445 |
+
multimodal_embed,
|
446 |
+
pad_token_id,
|
447 |
+
fake_input,
|
448 |
+
group_index=None, # 某种模态的编号
|
449 |
+
):
|
450 |
+
pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, pad_token_id, self.config.multimodal_special_token_list)
|
451 |
+
if not self.training: # 推理支持auto map 把多模态模块输出和input_ids 统一到一个device
|
452 |
+
multimodal_embed = multimodal_embed.to(input_ids.device)
|
453 |
+
if not fake_input: # 检查多模态token 和 pad mask数量一致 (不正确的截断会导致该问题)
|
454 |
+
assert pad_mask.sum() == multimodal_embed.shape[0]
|
455 |
+
else:
|
456 |
+
assert pad_mask.sum() <= 0
|
457 |
+
|
458 |
+
# 合并 当前模态embeddings 和text embeddings
|
459 |
+
input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
|
460 |
+
text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # pad token位置填0
|
461 |
+
multimodal_embedding = torch.embedding(multimodal_embed, input_ids * pad_mask) # 非 pad token 位置填idx=0位置结果
|
462 |
+
multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # 非pad token 位置填0
|
463 |
+
final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
|
464 |
+
|
465 |
+
if group_index is None:
|
466 |
+
group_index = pad_mask.to(torch.int32)
|
467 |
+
else:
|
468 |
+
current_index = torch.max(group_index) + 1
|
469 |
+
group_index += pad_mask.to(torch.int32) * current_index # 假设模态无重叠
|
470 |
+
|
471 |
+
return final_embedding, group_index
|
472 |
+
|
473 |
+
def get_visual_embed(
|
474 |
+
self,
|
475 |
+
input_ids,
|
476 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
477 |
+
images = None,
|
478 |
+
patch_nums = None,
|
479 |
+
images_grid = None,
|
480 |
+
videos = None,
|
481 |
+
videos_patch_nums = None,
|
482 |
+
videos_grid = None,
|
483 |
+
group_index = None, # 某种模态的编号
|
484 |
+
):
|
485 |
+
if images is None or len(images) <= 0:
|
486 |
+
images, images_grid, patch_nums = self.visual_model.fake_input(input_ids.device)
|
487 |
+
image_fake_input = True
|
488 |
+
else:
|
489 |
+
image_fake_input = False
|
490 |
+
|
491 |
+
if videos is None or len(videos) <= 0 :
|
492 |
+
videos, videos_grid, videos_patch_nums = self.visual_model.fake_input(input_ids.device)
|
493 |
+
video_fake_input = True
|
494 |
+
else:
|
495 |
+
video_fake_input = False
|
496 |
+
|
497 |
+
visual_input = images + videos
|
498 |
+
visual_grid = images_grid + videos_grid
|
499 |
+
|
500 |
+
visual_input = torch.cat(visual_input, dim=0)
|
501 |
+
visual_grid = torch.tensor(np.array(visual_grid))
|
502 |
+
|
503 |
+
visual_embed = self.visual_model(visual_input, grid_thw=visual_grid)
|
504 |
+
visual_embed = self.visual_bridge_model(visual_embed)
|
505 |
+
|
506 |
+
assert sum(patch_nums) + sum(videos_patch_nums) == visual_embed.shape[0]
|
507 |
+
images_embed = visual_embed[:sum(patch_nums)]
|
508 |
+
videos_embed = visual_embed[sum(patch_nums):]
|
509 |
+
|
510 |
+
final_embedding, group_index = self.get_multimodal_embed(input_ids, text_embedding, images_embed, self.config.visual_config.image_pad_token_id, image_fake_input, group_index=group_index)
|
511 |
+
final_embedding, group_index = self.get_multimodal_embed(input_ids, final_embedding, videos_embed, self.config.video_config.video_place_token_id, video_fake_input, group_index=group_index)
|
512 |
+
return final_embedding, group_index
|
513 |
+
|
514 |
+
|
515 |
+
@torch.no_grad()
|
516 |
+
def audio_fake_input(self, device):
|
517 |
+
return torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=device)
|
518 |
+
|
519 |
+
def forward(
|
520 |
+
self,
|
521 |
+
input_ids: torch.LongTensor = None,
|
522 |
+
attention_mask: Optional[torch.Tensor] = None,
|
523 |
+
position_ids: Optional[torch.LongTensor] = None,
|
524 |
+
seqlens: Optional[torch.IntTensor] = None,
|
525 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
526 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
527 |
+
audios_tokens: Optional[List|torch.Tensor] = None, # 音频token bs*seqlen*vq_num
|
528 |
+
images: Optional[List|torch.Tensor] = None,
|
529 |
+
patch_nums: Optional[torch.Tensor] = None,
|
530 |
+
images_grid: Optional[List|torch.Tensor] = None,
|
531 |
+
videos: Optional[List|torch.Tensor] = None,
|
532 |
+
videos_patch_nums: Optional[torch.Tensor] = None,
|
533 |
+
videos_grid: Optional[List|torch.Tensor] = None,
|
534 |
+
use_cache: Optional[bool] = None,
|
535 |
+
output_attentions: Optional[bool] = None,
|
536 |
+
output_hidden_states: Optional[bool] = None,
|
537 |
+
return_dict: Optional[bool] = None,
|
538 |
+
) -> Union[Tuple, OmniModelOutputWithPast]:
|
539 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
540 |
+
output_hidden_states = (
|
541 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
542 |
+
)
|
543 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
544 |
+
return_dict = True if (return_dict is not None or self.training) else self.config.use_return_dict
|
545 |
+
|
546 |
+
# retrieve input_ids and inputs_embeds
|
547 |
+
if input_ids is not None and inputs_embeds is not None:
|
548 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
549 |
+
elif input_ids is not None:
|
550 |
+
batch_size, seq_length = input_ids.shape
|
551 |
+
elif inputs_embeds is not None:
|
552 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
553 |
+
else:
|
554 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
555 |
+
|
556 |
+
seq_length_with_past = seq_length
|
557 |
+
past_key_values_length = 0
|
558 |
+
|
559 |
+
if past_key_values is not None:
|
560 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
561 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
562 |
+
|
563 |
+
if position_ids is None:
|
564 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
565 |
+
position_ids = torch.arange(
|
566 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
567 |
+
)
|
568 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
569 |
+
else:
|
570 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
571 |
+
|
572 |
+
group_index, audio_decoder_ret = None, None
|
573 |
+
if inputs_embeds is None:
|
574 |
+
sp_input_ids = get_sequence_parallel_chunk(input_ids)
|
575 |
+
inputs_embeds = self.embed_tokens(sp_input_ids)
|
576 |
+
if audios_tokens is None or len(audios_tokens) <= 0 :
|
577 |
+
audios_tokens = torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=input_ids.device) # a fake input
|
578 |
+
fake_input = True
|
579 |
+
else:
|
580 |
+
fake_input = False
|
581 |
+
for i, audio_emb_layer in enumerate(self.audio_embed_layers):
|
582 |
+
if i==0:
|
583 |
+
audio_embs = audio_emb_layer(audios_tokens[..., i])
|
584 |
+
else:
|
585 |
+
audio_embs += audio_emb_layer(audios_tokens[..., i])
|
586 |
+
inputs_embeds, group_index = self.get_multimodal_embed(sp_input_ids, inputs_embeds, audio_embs, self.config.audio_config.audio_pad_token_id, fake_input, group_index=group_index)
|
587 |
+
|
588 |
+
if self.config.visual_config.enable or self.config.video_config.enable:
|
589 |
+
inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, patch_nums, images_grid, videos, videos_patch_nums, videos_grid, group_index=group_index) # 注意更新group index
|
590 |
+
|
591 |
+
if seqlens is not None and seqlens.ndim == 2:
|
592 |
+
cu_seqlens = []
|
593 |
+
offset, seqlen = 0, seqlens.size(1)
|
594 |
+
for lens in seqlens:
|
595 |
+
cu_seqlens.append(offset)
|
596 |
+
cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
|
597 |
+
offset += seqlen
|
598 |
+
cu_seqlens.append(offset)
|
599 |
+
seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
|
600 |
+
elif seqlens is None and self.training:
|
601 |
+
seqlens = torch.arange(
|
602 |
+
end=input_ids.size(0) + 1,
|
603 |
+
dtype=torch.int32,
|
604 |
+
device=input_ids.device
|
605 |
+
) * input_ids.size(1)
|
606 |
+
if seqlens is not None:
|
607 |
+
attention_mask = None # unset attention_mask to save memory
|
608 |
+
|
609 |
+
if seqlens is None and attention_mask is None:
|
610 |
+
attention_mask = torch.ones(
|
611 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
612 |
+
)
|
613 |
+
if attention_mask is not None:
|
614 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
615 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
616 |
+
)
|
617 |
+
|
618 |
+
# embed positions
|
619 |
+
hidden_states = inputs_embeds
|
620 |
+
|
621 |
+
if self.gradient_checkpointing and self.training:
|
622 |
+
if use_cache:
|
623 |
+
logger.warning_once(
|
624 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
625 |
+
)
|
626 |
+
use_cache = False
|
627 |
+
|
628 |
+
# decoder layers
|
629 |
+
all_hidden_states = () if output_hidden_states else None
|
630 |
+
all_self_attns = () if output_attentions else None
|
631 |
+
next_decoder_cache = () if use_cache else None
|
632 |
+
|
633 |
+
for idx, decoder_layer in enumerate(self.layers):
|
634 |
+
if output_hidden_states:
|
635 |
+
all_hidden_states += (hidden_states,)
|
636 |
+
|
637 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
638 |
+
|
639 |
+
if self.gradient_checkpointing and self.training:
|
640 |
+
|
641 |
+
def create_custom_forward(module):
|
642 |
+
def custom_forward(*inputs):
|
643 |
+
# None for past_key_value
|
644 |
+
return module(*inputs, output_attentions, False, group_index)
|
645 |
+
|
646 |
+
return custom_forward
|
647 |
+
|
648 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
649 |
+
create_custom_forward(decoder_layer),
|
650 |
+
hidden_states,
|
651 |
+
attention_mask,
|
652 |
+
position_ids,
|
653 |
+
seqlens,
|
654 |
+
None,
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
layer_outputs = decoder_layer(
|
658 |
+
hidden_states,
|
659 |
+
attention_mask=attention_mask,
|
660 |
+
position_ids=position_ids,
|
661 |
+
seqlens=seqlens,
|
662 |
+
past_key_value=past_key_value,
|
663 |
+
output_attentions=output_attentions,
|
664 |
+
use_cache=use_cache,
|
665 |
+
group_index=group_index,
|
666 |
+
)
|
667 |
+
|
668 |
+
hidden_states = layer_outputs[0]
|
669 |
+
|
670 |
+
if use_cache:
|
671 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
672 |
+
|
673 |
+
if output_attentions:
|
674 |
+
all_self_attns += (layer_outputs[1],)
|
675 |
+
|
676 |
+
hidden_states = self.norm(hidden_states)
|
677 |
+
|
678 |
+
# add hidden states from the last decoder layer
|
679 |
+
if output_hidden_states:
|
680 |
+
all_hidden_states += (hidden_states,)
|
681 |
+
|
682 |
+
next_cache = next_decoder_cache if use_cache else None
|
683 |
+
if not return_dict:
|
684 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
685 |
+
return BaseModelOutputWithPast(
|
686 |
+
last_hidden_state=hidden_states,
|
687 |
+
past_key_values=next_cache,
|
688 |
+
hidden_states=all_hidden_states,
|
689 |
+
attentions=all_self_attns,
|
690 |
+
)
|
691 |
+
|
692 |
+
|
693 |
+
class NormHead(nn.Module):
|
694 |
+
def __init__(self, hidden_size, vocab_size, bias=False):
|
695 |
+
super().__init__()
|
696 |
+
self.hidden_size = hidden_size
|
697 |
+
self.vocab_size = vocab_size
|
698 |
+
self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
|
699 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
700 |
+
|
701 |
+
def forward(self, hidden_states, mask=None):
|
702 |
+
norm_weight = nn.functional.normalize(self.weight)
|
703 |
+
if mask is not None:
|
704 |
+
mask = mask.to(norm_weight)
|
705 |
+
norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
|
706 |
+
return nn.functional.linear(hidden_states, norm_weight)
|
707 |
+
|
708 |
+
|
709 |
+
def extra_repr(self) -> str:
|
710 |
+
return f'in_features={self.hidden_size}, out_features={self.vocab_size}'
|
711 |
+
|
712 |
+
@dataclass
|
713 |
+
class OmniMMCausalLMOutputWithPast(ModelOutput):
|
714 |
+
loss: Optional[torch.FloatTensor] = None
|
715 |
+
logits: Optional[torch.FloatTensor] = None
|
716 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
717 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
718 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
719 |
+
audios_emb_for_infer: Optional[torch.FloatTensor] = None # 用于audio head 推理的 embeddings
|
720 |
+
|
721 |
+
|
722 |
+
class CasualDepthTransformerLayer(nn.Module):
|
723 |
+
def __init__(self, config, depth):
|
724 |
+
super().__init__()
|
725 |
+
self.config = config
|
726 |
+
embed_size = config.hidden_size
|
727 |
+
assert embed_size % 128 == 0
|
728 |
+
num_heads = embed_size // 128
|
729 |
+
self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True)
|
730 |
+
self.layernorm1 = RMSNorm(embed_size)
|
731 |
+
self.layernorm2 = RMSNorm(embed_size)
|
732 |
+
self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size)
|
733 |
+
self.linear2 = nn.Linear(2 * embed_size * depth, embed_size)
|
734 |
+
|
735 |
+
def forward(self, x):
|
736 |
+
seq_len = x.size(1)
|
737 |
+
res = x
|
738 |
+
x = self.layernorm1(x)
|
739 |
+
src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
|
740 |
+
_x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
|
741 |
+
res = _x + res # (bs, sl, d)
|
742 |
+
res = self.layernorm2(res)
|
743 |
+
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.hidden_size, -1, self.config.hidden_size)))
|
744 |
+
x = torch.nn.functional.gelu(x)
|
745 |
+
x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.hidden_size, -1, 2 * self.config.hidden_size)))
|
746 |
+
return res + x
|
747 |
+
|
748 |
+
class OmniAudioHead(nn.Module):
|
749 |
+
def __init__(self, config):
|
750 |
+
super().__init__()
|
751 |
+
self.config = config
|
752 |
+
hidden_size = config.hidden_size
|
753 |
+
self.transformer_layers = nn.ModuleList([
|
754 |
+
CasualDepthTransformerLayer(config, len(config.audio_config.vq_config.codebook_sizes))
|
755 |
+
for _ in range(config.audio_config.audio_head_transformer_layers)
|
756 |
+
])
|
757 |
+
self.headnorm = RMSNorm(hidden_size)
|
758 |
+
self.heads = nn.ModuleList([
|
759 |
+
nn.Linear(hidden_size, vq_size+1)
|
760 |
+
for vq_size in config.audio_config.vq_config.codebook_sizes
|
761 |
+
])
|
762 |
+
self.gradient_checkpointing = True
|
763 |
+
|
764 |
+
def forward(self, x, audios_tokens, audio_emb_layers):
|
765 |
+
cumsum_audio_embed = torch.stack([
|
766 |
+
audio_emb_layers[i](audios_tokens[..., i])
|
767 |
+
for i, vq_size in enumerate(self.config.audio_config.vq_config.codebook_sizes[:-1])
|
768 |
+
], dim=1)
|
769 |
+
cumsum_audio_embed = torch.cumsum(cumsum_audio_embed, dim=1) # (bs, depth-1, d)
|
770 |
+
hidden_states = torch.concat([x.reshape(-1, 1, self.config.hidden_size), cumsum_audio_embed], dim=1) # (bs, depth, d)
|
771 |
+
assert hidden_states.size(1) == len(self.config.audio_config.vq_config.codebook_sizes)
|
772 |
+
for i, tlayer in enumerate(self.transformer_layers):
|
773 |
+
hidden_states = tlayer(hidden_states,)
|
774 |
+
hidden_states = self.headnorm(hidden_states)
|
775 |
+
logits = [head(hidden_states[:,i]) for i, head in enumerate(self.heads)]
|
776 |
+
return logits
|
777 |
+
|
778 |
+
|
779 |
+
class OmniForCausalLM(OmniPreTrainedModel):
|
780 |
+
def __init__(self, config):
|
781 |
+
super().__init__(config)
|
782 |
+
self.config = config
|
783 |
+
self.model = OmniModel(config)
|
784 |
+
self.audio_tokenizer = OmniAudioTokenizer(config)
|
785 |
+
self.audio_head = OmniAudioHead(config)
|
786 |
+
if config.use_norm_head:
|
787 |
+
self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
|
788 |
+
else:
|
789 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
790 |
+
# Initialize weights and apply final processing
|
791 |
+
self.post_init()
|
792 |
+
|
793 |
+
@property
|
794 |
+
def main_device(self):
|
795 |
+
return self.lm_head.weight.device
|
796 |
+
|
797 |
+
def bind_processor(self, tokenizer, **kwargs):
|
798 |
+
self.processor = OmniMMProcessor(
|
799 |
+
tokenizer=tokenizer,
|
800 |
+
config=self.config,
|
801 |
+
**kwargs,
|
802 |
+
)
|
803 |
+
return self.processor
|
804 |
+
|
805 |
+
def get_input_embeddings(self):
|
806 |
+
return self.model.embed_tokens
|
807 |
+
|
808 |
+
def set_input_embeddings(self, value):
|
809 |
+
self.model.embed_tokens = value
|
810 |
+
|
811 |
+
def get_output_embeddings(self):
|
812 |
+
return self.lm_head
|
813 |
+
|
814 |
+
def set_output_embeddings(self, new_embeddings):
|
815 |
+
self.lm_head = new_embeddings
|
816 |
+
|
817 |
+
def set_decoder(self, decoder):
|
818 |
+
self.model = decoder
|
819 |
+
|
820 |
+
def get_decoder(self):
|
821 |
+
return self.model
|
822 |
+
|
823 |
+
def forward(
|
824 |
+
self,
|
825 |
+
input_ids: torch.LongTensor = None,
|
826 |
+
attention_mask: Optional[torch.Tensor] = None,
|
827 |
+
position_ids: Optional[torch.LongTensor] = None,
|
828 |
+
seqlens: Optional[torch.IntTensor] = None,
|
829 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
830 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
831 |
+
labels: Optional[torch.LongTensor] = None,
|
832 |
+
audios: Optional[List|torch.Tensor] = None,
|
833 |
+
audios_tokens: Optional[List|torch.Tensor] = None,
|
834 |
+
encoder_length: Optional[torch.Tensor] = None,
|
835 |
+
bridge_length: Optional[torch.Tensor] = None,
|
836 |
+
images: Optional[torch.Tensor] = None,
|
837 |
+
patch_nums: Optional[torch.Tensor] = None,
|
838 |
+
images_grid: Optional[torch.Tensor] = None,
|
839 |
+
videos: Optional[torch.Tensor] = None,
|
840 |
+
videos_patch_nums: Optional[torch.Tensor] = None,
|
841 |
+
videos_grid: Optional[torch.Tensor] = None,
|
842 |
+
use_cache: Optional[bool] = None,
|
843 |
+
output_attentions: Optional[bool] = None,
|
844 |
+
output_hidden_states: Optional[bool] = None,
|
845 |
+
return_dict: Optional[bool] = None,
|
846 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
847 |
+
|
848 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
849 |
+
output_hidden_states = (
|
850 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
851 |
+
)
|
852 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
853 |
+
|
854 |
+
if audios_tokens is not None:
|
855 |
+
assert isinstance(audios_tokens, torch.Tensor)
|
856 |
+
else:
|
857 |
+
        if audios is None or len(audios) == 0:
            audios_tokens = None
        else:
            audios_tokens = self.audio_tokenizer(audios, encoder_length, bridge_length)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            seqlens=seqlens,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            audios_tokens=audios_tokens,
            images=images,
            patch_nums=patch_nums,
            images_grid=images_grid,
            videos=videos,
            videos_patch_nums=videos_patch_nums,
            videos_grid=videos_grid,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        audios_emb_for_infer = hidden_states[:, -1, :]
        logits = self.lm_head(hidden_states)

        return OmniMMCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            audios_emb_for_infer=audios_emb_for_infer
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, past_key_values[0][0].shape[-2]:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1)
            # position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, past_key_values[0][0].shape[-2]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        elif past_key_values is not None:
            model_inputs = {"input_ids": input_ids}
        else:
            model_inputs = {"input_ids": input_ids,
                            "audios": kwargs.get("audios", None), "encoder_length": kwargs.get("encoder_length", None), "bridge_length": kwargs.get("bridge_length", None),
                            "audios_tokens": kwargs.get("audios_tokens", None),
                            "images": kwargs.get("images", None),
                            "videos": kwargs.get("videos", None)
                            }

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images_grid": kwargs.get("images_grid"),
                "videos_grid": kwargs.get("videos_grid"),
                "patch_nums": kwargs.get("patch_nums"),
                "videos_patch_nums": kwargs.get("videos_patch_nums"),
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

    def chat(self, tokenizer, messages: List[dict], stream=False,
             generation_config: Optional[GenerationConfig] = None):
        generation_config = generation_config or self.generation_config
        input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
        if stream:
            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            Thread(target=self.generate, kwargs=dict(
                inputs=input_ids, streamer=streamer,
                generation_config=generation_config,
            )).start()
            return streamer
        else:
            outputs = self.generate(input_ids, generation_config=generation_config)
            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
            return response


class OmniAudioTokenizer(OmniPreTrainedModel):
    """
    Construct an audio tokenizer and decoder.
    """
    def __init__(self, config: OmniConfig):
        super().__init__(config)
        self.padding_idx = None
        self.vocab_size = config.vocab_size
        self.training = False
        self.eval()
        self.audio_model = OmniAudioEncoder(config.audio_config)
        self.audio_bridge_model = OmniAudioVQBridgeTokenizer(config)
        if config.vocoder_config.enable:
            self.audio_decoder = OmniAudioDecoder(config)
        if config.flow_matching_config.enable:
            self.audio_flow_matching_decoder = OmniAudioFlowMatchingDecoder(config)

    def encode(self, x, encoder_length: Optional[torch.Tensor] = None,
               bridge_length: Optional[torch.Tensor] = None):
        audio_emb = self.audio_model(x, encoder_length)
        audios_tokens = self.audio_bridge_model(audio_emb, bridge_length)
        return audios_tokens

    def decode(self, audio_code_ids, bridge_length: Optional[torch.Tensor] = None):
        assert self.config.vocoder_config.enable, "Vocoder is not enabled in config."
        audio_emb = self.audio_bridge_model.decode(audio_code_ids)
        audio_dec = self.audio_decoder(
            audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
        )
        if self.config.flow_matching_config.enable:
            if self.config.flow_matching_config.use_hidden_states_before_dconv2:
                hidden_states, hidden_states_length = (
                    self.audio_flow_matching_decoder.unpack_hidden_states(
                        audio_dec.hidden_states_before_dconv2,
                        audio_dec.output_length_before_dconv2,
                    )
                )
                audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
                    hidden_states, hidden_states_length
                )
            else:
                audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
                    audio_dec.refined_mel, audio_dec.mel_length
                )
            return audio_flow_matching_decoder_ret
        else:
            return audio_dec

    @torch.no_grad()
    def forward(self, audios, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
        self.eval()
        audios_tokens = self.encode(audios, encoder_length, bridge_length)
        return audios_tokens
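
def _omni_audio_tokenizer_usage_sketch(config, audios, encoder_length, bridge_length):
    # Hedged usage sketch (not called anywhere in this repo): `config` is assumed to be an
    # already-loaded OmniConfig and the three inputs are assumed to come from OmniMMProcessor
    # (see processor_omni.py below). Weight loading via from_pretrained is omitted, so this
    # only illustrates the encode/decode call pattern of the class defined above.
    tokenizer = OmniAudioTokenizer(config)
    codes = tokenizer(audios, encoder_length, bridge_length)   # discrete audio codes via encode()
    if config.vocoder_config.enable:
        decoded = tokenizer.decode(codes, bridge_length)       # mel / flow-matching decoder output
        return codes, decoded
    return codes, None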
processor_omni.py
ADDED
@@ -0,0 +1,865 @@
import requests
import re, ujson, os, sys, fire, glob, random, time, json
import numpy as np
import io
import torch
from torch.utils.data import default_collate
import torchaudio
from typing import *
from dataclasses import dataclass, field
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from functools import lru_cache
from io import BytesIO
from PIL import Image
import concurrent.futures as cf
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
from transformers.image_utils import PILImageResampling
from PIL import Image, ImageOps
from PIL import ImageFile
torch.set_num_threads(1)  # limit torch's thread count, otherwise it may hang
ImageFile.LOAD_TRUNCATED_IMAGES = True
import base64
from decord import VideoReader, cpu
import cv2
import av
import imagesize
import tempfile
import math
from multiprocessing import Pool
from cairosvg import svg2png
import hashlib

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

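
def _smart_resize_example():
    # Hedged, illustrative helper (not used elsewhere in the file): with the default factor=28,
    # a 1080x1920 frame rounds each side to the nearest multiple of 28, giving (1092, 1932);
    # 1092 * 1932 ≈ 2.1M pixels lies inside [MIN_PIXELS, MAX_PIXELS], so neither rescaling
    # branch of smart_resize fires.
    assert smart_resize(1080, 1920) == (1092, 1932)
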
def split_text(text, match_regex):
    matches = list(re.finditer(match_regex, text))
    # initialize the result lists
    result = []
    match_flag_list = []
    # end position of the previous match
    last_end = 0
    # iterate over all matches
    for match in matches:
        # append the text before the match
        if text[last_end:match.start()]:
            result.append(text[last_end:match.start()])
            match_flag_list.append(False)
        # append the match itself
        result.append(match.group(0))
        match_flag_list.append(True)
        # update the end position of the previous match
        last_end = match.end()
    # append the text after the last match
    if text[last_end:]:
        result.append(text[last_end:])
        match_flag_list.append(False)
    return result, match_flag_list

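
def _split_text_example():
    # Hedged, illustrative helper (not referenced elsewhere): split_text keeps both the matched
    # spans and the surrounding text, with a parallel flag list marking which chunks matched.
    # The <tag> markers here are placeholders, not tokens used by the processor.
    chunks, flags = split_text("a<tag>b</tag>c", re.compile(r"<tag>.*?</tag>", re.S))
    assert chunks == ["a", "<tag>b</tag>", "c"]
    assert flags == [False, True, False]
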
def read_video(image_path, max_frame_number, decode_way):
    if decode_way == '1fps':
        try:
            # print(image_path)
            vr = VideoReader(image_path, ctx=cpu(0))
            total_frame_num = len(vr)
            fps = round(vr.get_avg_fps())
            frame_idx = [i for i in range(0, len(vr), fps)]
            frames = vr.get_batch(frame_idx).asnumpy()
            cnt = len(frames)
            frame_times = range(cnt)
        except Exception as e:
            print(image_path)
            print('error is', e)
            return None
    elif decode_way == 'key':
        try:
            with av.open(image_path) as container:
                stream = container.streams.video[0]
                stream.codec_context.skip_frame = 'NONKEY'
                frames = []
                frame_times = []
                fps = int(stream.average_rate)
                cnt = 0
                for frame in container.decode(stream):  # store the key frames as image patches
                    image = np.array(frame.to_image())
                    frames.append(image)
                    frame_time = int(frame.time)
                    frame_times.append(frame_time)
                    cnt += 1
        except Exception as e:
            print('error is', e)
            return None
    if frames is None or len(frames) == 0:
        return None
    if len(frames) > max_frame_number and max_frame_number > 0:
        # generate max_frame_number uniformly spaced indices
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        # select the corresponding elements
        frames = frames[indices]
        frame_times = frame_times[indices]
    return frames, frame_times


class OmniImageProcessor:
    def __init__(self, config, **kwargs):
        self.config = config  # visual_config
        self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
        self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
        self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
        self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2

    def image_transform(self, strseq, return_mm_data=True):
        image = None
        if isinstance(strseq, str):
            if return_mm_data:
                image = Image.open(strseq).convert("RGB")
        else:
            try:
                image = Image.open(BytesIO(strseq)).convert("RGB")
            except:
                image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB")  # some interleaved data are vector images and need conversion

        image = np.array(image.convert("RGB"))  # force RGB so the image has three channels, then convert to a NumPy array for later processing
        image_org_size = image.shape[:2]  # original (height, width), kept for later comparison or other processing

        # resize, crop, scale, normalize
        # compute the target (height, width) used by the following resize
        resized_height, resized_width = smart_resize(
            image_org_size[0], image_org_size[1],
            factor=self.patch_size * self.spatial_merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        output_size = (resized_height, resized_width)

        # resize to output_size using bicubic interpolation, which usually preserves quality well
        image = resize(image, output_size, PILImageResampling.BICUBIC)
        img = image.transpose(2, 0, 1)
        # normalize and standardize the image
        image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:, np.newaxis, np.newaxis]
        # split into patches
        patches = image[np.newaxis, :]
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
            grid_w // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )

        return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)

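
def _image_transform_shape_example():
    # Hedged, illustrative helper (not used by the processor): with patch_size=14 and
    # spatial_merge_size=2, smart_resize is called with factor=28, so a 1080x1920 input maps
    # to (1092, 1932) and a grid of (grid_t, grid_h, grid_w) = (1, 78, 138). The flattened
    # patch tensor has grid_t*grid_h*grid_w rows, and the downstream number of visual pad
    # tokens is that row count divided by merge_size**2 (see OmniMMProcessor._get_image below).
    grid_t, grid_h, grid_w = 1, 1092 // 14, 1932 // 14
    assert (grid_h, grid_w) == (78, 138)
    assert grid_t * grid_h * grid_w // (2 ** 2) == 2691
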
class OmniAudioProcessor:
    # basic audio feature extraction plus input-data parsing
    def __init__(
        self,
        config,  # audio processor config
        **kwargs
    ):
        # make sure you have installed 'conda install -c conda-forge "ffmpeg<7"' for torchaudio
        assert (len(torchaudio.list_audio_backends()) > 0)
        self.config = config
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.n_fft // 2,
            num_mel_filters=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=self.config.sampling_rate / 2.0,
            sampling_rate=self.config.sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.window = torch.hann_window(self.config.n_fft)

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-6):
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def zero_mean_unit_var_norm(x):
        return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)

    def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
        metadata = torchaudio.info(uri)  # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
        assert (metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels)  # whisper only accepts mono channel audio
        waveform_tensor, _ = torchaudio.load(uri, normalize=True)
        if self.config.sampling_rate != metadata.sample_rate:
            waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)

        # downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
        if metadata.num_channels > 1:
            waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)

        # normalize to zero mean
        if do_normalize:
            waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)

        if return_tensors:  # (channels, samples)
            return waveform_tensor
        else:
            return waveform_tensor.numpy()

    def split_with_overlap(self, waveform):
        # if the waveform exceeds the maximum length, split it into overlapping segments
        channels, wave_samples = waveform.shape
        max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
        if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]  # within the limit or truncation disabled; always return a list for a uniform interface

        split_waveform, start = [], 0
        while start < wave_samples:  # align the overlap on whole seconds
            if start > int(self.config.sampling_rate * self.config.split_overlap):
                start -= int(self.config.sampling_rate * self.config.split_overlap)  # 0 means no overlap; >0 is the overlap in seconds
            end = min(start + max_audio_samples, wave_samples)
            if end - start >= self.config.n_fft:  # keep at least one frame of data
                split_waveform.append(waveform[:, start:end])  # note: very short tail segments can appear here and should be dropped during preprocessing
            start = end
        return split_waveform

    @classmethod
    def inference_output_length(cls, config, input_length):
        # for whisper + bridge
        kernel_size = config.kernel_size
        stride_size = config.stride_size
        avg_pooler = config.avg_pooler
        encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1  # conv layer1 with pad=1
        encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1  # conv layer2 with pad=1
        if avg_pooler > 1:
            bridge_length = encoder_length // avg_pooler
        return encoder_length, bridge_length

    def extract_fbank_features(self, waveform):
        # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
        channels, wave_samples = waveform.shape
        assert (wave_samples >= self.config.n_fft)
        valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
        if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
            waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
        else:
            waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]

        # window = torch.hann_window(self.config.n_fft)
        stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True)  # fft, len(wave) // n_fft // 2 + 1
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        if waveform.dim() == 2:
            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
            log_spec = torch.maximum(log_spec, max_val - 8.0)
        else:
            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        log_spec = log_spec[0].numpy()  # (channel, filters, samples) -> (filters, samples)
        log_spec[:, valid_frame_nums:] = 0.0  # pad with zeros

        return log_spec, valid_frame_nums

    def data_augment(self, feature: np.array, input_length, training=True):
        # reference https://arxiv.org/pdf/1904.08779
        def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
            num_masked_span = int(mask_prob * input_length / mask_length + random.random())
            num_masked_span = max(num_masked_span, min_masks)
            start_indices = list(range(input_length - mask_length))
            random.shuffle(start_indices)
            start_indices = start_indices[:num_masked_span]
            return start_indices

        if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
            return feature
        if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
            return feature
        if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
            return feature

        if self.config.mask_time_prob > 0:
            start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
            for start_idx in start_indices:
                feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
        if self.config.mask_feature_prob > 0:
            start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
            for start_idx in start_indices:
                feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0

        return feature

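
def _inference_output_length_example():
    # Hedged, illustrative helper: with hypothetical config values kernel_size=3, stride_size=2
    # and avg_pooler=4 (placeholders, not asserted to match this repo's config.json), a
    # 3000-frame fbank segment passes through the two conv layers (3000 -> 3000 -> 1500) and the
    # pooling bridge (1500 // 4 = 375), so that segment needs 375 audio pad tokens in the prompt.
    class _Cfg:
        kernel_size, stride_size, avg_pooler = 3, 2, 4
    assert OmniAudioProcessor.inference_output_length(_Cfg, 3000) == (1500, 375)
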
@dataclass
class OmniProcessorOutput(ModelOutput):
    input_ids: Optional["List|torch.Tensor"] = None
    labels: Optional["List|torch.Tensor"] = None
    attention_mask: Optional["List|torch.Tensor"] = None
    position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None  # must be used together with the Omni modeling code
    # audio fields
    audios: Optional["List|torch.Tensor"] = None
    encoder_length: Optional["List|torch.Tensor"] = None
    bridge_length: Optional["List|torch.Tensor"] = None
    # image fields
    images: Optional["List|torch.Tensor"] = None
    patch_nums: Optional["List|torch.Tensor"] = None
    images_size: Optional["List|torch.Tensor"] = None
    crop_size: Optional["List|torch.Tensor"] = None
    images_grid: Optional["List|torch.Tensor"] = None
    # video fields
    videos: Optional["List|torch.Tensor"] = None
    videos_patch_nums: Optional["List|torch.Tensor"] = None
    videos_size: Optional["List|torch.Tensor"] = None
    videos_crop_size: Optional["List|torch.Tensor"] = None
    videos_grid: Optional["List|torch.Tensor"] = None
    # processor fields
    raw_text: Optional[str] = None
    index: Optional[int] = None

    def concatenate(self, other):  # only valid while the fields are still plain lists
        def concat_one(a, b):
            if a is None and b is None:
                return None
            elif a is None and b is not None:
                return b
            elif a is not None and b is None:
                return a
            else:
                return a + b
        return OmniProcessorOutput(
            input_ids=concat_one(self.input_ids, other.input_ids),
            labels=concat_one(self.labels, other.labels),
            audios=concat_one(self.audios, other.audios),
            encoder_length=concat_one(self.encoder_length, other.encoder_length),
            bridge_length=concat_one(self.bridge_length, other.bridge_length),
            images=concat_one(self.images, other.images),
            images_grid=concat_one(self.images_grid, other.images_grid),
            patch_nums=concat_one(self.patch_nums, other.patch_nums),

            videos=concat_one(self.videos, other.videos),
            videos_grid=concat_one(self.videos_grid, other.videos_grid),
            videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),

            position_ids=concat_one(self.position_ids, other.position_ids),
            seqlens=concat_one(self.seqlens, other.seqlens),
            images_size=concat_one(self.images_size, other.images_size),
            videos_size=concat_one(self.videos_size, other.videos_size),
            index=self.index  # concatenation keeps the index unchanged
        )

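
def _processor_output_concat_example():
    # Hedged, illustrative helper: while the fields are still plain Python lists, concatenate()
    # merges them chunk by chunk and keeps the index of the left operand; the values below are
    # arbitrary placeholders.
    a = OmniProcessorOutput(input_ids=[1, 2], bridge_length=[10], index=0)
    b = OmniProcessorOutput(input_ids=[3], bridge_length=[7], index=5)
    merged = a.concatenate(b)
    assert merged.input_ids == [1, 2, 3]
    assert merged.bridge_length == [10, 7]
    assert merged.index == 0
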
class OmniMMProcessor(object):
    def __init__(self,
                 tokenizer: transformers.PreTrainedTokenizer,
                 config,
                 training,
                 relative_path=None,
                 parallel=None,
                 **kwargs,
                 ):
        self.tokenizer = tokenizer
        self.config = config
        self.audio_processor = OmniAudioProcessor(config.audio_config)
        self.visual_processor = None
        if hasattr(config, "visual_config"):
            self.visual_processor = OmniImageProcessor(config.visual_config)
        self.video_processor = None
        if hasattr(config, "video_config"):
            self.video_processor = OmniImageProcessor(config.video_config)
        self.training = training
        self.relative_path = relative_path
        self.parallel = parallel
        # audio tags
        self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
        self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
        self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
        self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
        self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
        self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
        # image tags
        self.image_start_tag = None
        self.image_end_tag = None
        self.image_pad_tag = None
        self.video_start_tag = None
        self.video_end_tag = None
        # the videoframe tags only exist so that a video can be given as individual image frames; they have no
        # token ids and are replaced with the image start/end tags when the video frames are extracted
        self.videoframe_start_tag = '<videoframe_start_omni>'
        self.videoframe_end_tag = '<videoframe_end_omni>'
        if hasattr(self.config, "visual_config"):
            # special token for start_tag
            self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
            # special token for end_tag
            self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
            # special token for pad_tag
            self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
            self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
            self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
        if hasattr(self.config, "video_config"):
            self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
            self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
            self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
            self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
            self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
            self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)

            self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')

    # @lru_cache(maxsize=1024)
    def _get_audio(self, audio_info):
        try:
            audio_info = ujson.loads(audio_info)
            if 'path' in audio_info.keys():
                audio_uri = None
                if os.path.exists(audio_info['path']):
                    audio_uri = audio_info['path']
                elif self.relative_path is not None:
                    audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
                    if not os.path.exists(audio_uri):
                        audio_uri = None
                if audio_uri is not None:
                    waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
                    waveforms = self.audio_processor.split_with_overlap(waveform)

                    ret = OmniProcessorOutput()  # default init; the audios field starts as None
                    for i, waveform in enumerate(waveforms):  # (zip(waveforms, vocoder_waveforms)):
                        audio, input_length = self.audio_processor.extract_fbank_features(waveform)
                        audio = self.audio_processor.data_augment(audio, input_length, self.training)
                        encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
                        if bridge_length <= 0:
                            continue
                        current_ret = OmniProcessorOutput(
                            audios=[audio[:, :input_length]],
                            encoder_length=[encoder_length],
                            bridge_length=[bridge_length],
                        )
                        if ret.audios is None:
                            ret = current_ret
                        else:
                            ret = ret.concatenate(current_ret)  # concatenate the individual segments
                    return ret
            else:
                raise ValueError("can not find path in audio_info")
        except Exception as e:
            print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
            return OmniProcessorOutput()

    # @lru_cache(maxsize=1024)
    def _get_image(self, image_info):
        try:
            try:
                image_info = ujson.loads(image_info)
            except:
                image_info = re.sub(r"(?<!\\)'", '"', image_info)
                image_info = ujson.loads(image_info)
            if 'base64' in image_info.keys():
                image_data = base64.b64decode(image_info['base64'])
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
            elif 'local' in image_info.keys():
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
            elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
            elif 'url' in image_info.keys():
                image_bytes = self._get_vision_obj_byte('url', image_info['url'])
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
            else:
                raise ValueError("can not find any path in image_info")

            merge_length = self.visual_processor.merge_size**2
            patch_nums = np.array(image_list).prod() // merge_length

            if org_size[0] * org_size[1] > 16**2:  # filter out extremely small images
                return OmniProcessorOutput(
                    images=[image_feat],
                    patch_nums=[patch_nums],
                    crop_size=[image_list],
                    images_size=[org_size],
                    images_grid=[image_list]
                )
            else:
                print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
                return OmniProcessorOutput()

        except Exception as e:
            print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
            return OmniProcessorOutput()

    # @lru_cache(maxsize=1024)
    def _get_video_frame(self, video_frame_infos):
        try:
            pattern = r'\{.*?\}'
            matches = re.findall(pattern, video_frame_infos)
            ret = OmniProcessorOutput()
            # parse the frame entries one by one
            for match in matches:
                video_frame_info = ujson.loads(match)
                # video_frame_info = ujson.loads(video_frame_info)
                if 'local' in video_frame_info.keys():
                    image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
                elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
                    image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
                else:
                    raise ValueError("can not find any path in video_info")

                merge_length = self.video_processor.merge_size**2
                patch_nums = np.array(image_list).prod() // merge_length

                if org_size[0] * org_size[1] > 16**2:  # filter out extremely small frames
                    ret = ret.concatenate(
                        OmniProcessorOutput(
                            videos=[image_feat],
                            videos_patch_nums=[patch_nums],
                            videos_crop_size=[image_list],
                            videos_size=[org_size],
                            videos_grid=[image_list]
                        )
                    )
                else:
                    print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
            return ret

        except Exception as e:
            print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
            return OmniProcessorOutput()

    # read raw video bytes
    def _get_vision_obj_byte(self, source, path):
        vision_obj_byte = None
        if source == "local":
            if os.path.exists(path):
                vision_obj_byte = open(path, "rb").read()
            else:
                vision_obj_byte = None
        if source == "base64":
            vision_obj_byte = base64.b64decode(path)
        if source == "url":
            vision_obj_byte = requests.get(url=path).content
        return vision_obj_byte

    # split the video into frames and save them in a sub-directory
    def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
        if decode_way == '1fps':
            frame_suffix = f'_frames'
        elif decode_way == 'key':
            frame_suffix = f'_keyframes'
        else:
            raise ValueError('invalid decode way!!!')

        server = "local"
        if 'local' in video_info.keys():
            # local video path
            video_path = video_info['local']
            # local path where the frames are stored
            frame_path = video_path.split('.')[0] + frame_suffix
            mm_obj_byte = self._get_vision_obj_byte('local', video_path)
        elif 'base64' in video_info.keys():
            md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
            if self.relative_path is not None:
                video_path = os.path.join(self.relative_path, md5)
            else:
                video_path = os.path.join(os.getcwd(), md5)
            frame_path = md5 + frame_suffix
            mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
        elif 'url' in video_info.keys():
            md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
            if self.relative_path is not None:
                video_path = os.path.join(self.relative_path, md5)
            else:
                video_path = os.path.join(os.getcwd(), md5)
            frame_path = md5 + frame_suffix
            mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
        else:
            raise ValueError('invalid video source !!!')
            return ""

        if mm_obj_byte is None:  # failed to read the video file
            return ""
        if not os.path.exists(frame_path) or len(os.listdir(frame_path)) == 0:
            # save the frames
            os.makedirs(frame_path, exist_ok=True)
            frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way)  # read all frames
            for frame_idx, frame in enumerate(frames):
                output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                cv2.imwrite(output_filename, frame)
        frame_paths = os.listdir(frame_path)

        # select the frames
        frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')]  # each file name is the frame's offset in seconds
        frame_times.sort()  # sort in ascending order
        frame_number = len(frame_times)
        if frame_number > max_frame_number:
            indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
        else:
            indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
        # concatenation mode
        replace_str = ""
        for frame_idx, idx in enumerate(indices):
            frame_time = frame_times[idx]  # frame_time is the frame's offset in seconds and also its file name
            frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
            frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern  # {} is the frame index
            frame_str = frame_str.replace('<TIMEIDX>', str(frame_time))  # TIMEIDX is the offset in seconds
            frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time)))  # TIMESTAMP is the formatted timestamp
            frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
            replace_str += frame_str

        return replace_str

    def sample_frame(self, frames_str, max_frame=32):
        def uniform_sample(lst, num_samples):
            if num_samples > len(lst):
                return lst
            interval = len(lst) / num_samples
            samples = [lst[int(i * interval)] for i in range(num_samples)]
            return samples
        p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
        frames_str_split = re.split(p, frames_str)
        frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
        sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
        return ''.join([item for idx, item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])

    def _get_video_frame_str(self, video_info):
        try:
            if self.videoframe_start_tag in video_info:  # the video is given as individual frames: replace the frame tags with image tags
                frames_str = video_info
                frames_str = frames_str.replace(self.videoframe_start_tag, self.image_start_tag).replace(self.videoframe_end_tag, self.image_end_tag)
                return self.sample_frame(frames_str, max_frame=self.config.video_config.max_frame_num)
            video_info = ujson.loads(video_info)
            # build the string containing the frame image paths, with at most max_frame_number frames
            frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
            return frames_str
        except Exception as e:
            print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
            return ""

    def _replace_image(self, image_text):
        image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
        ret = self._get_image(image_info)  # cached result on repeated access
        if ret.patch_nums is None:
            return ''
        return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag

    def _replace_video_frame(self, video_frame_text):
        video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
        ret = self._get_video_frame(video_frame_info)  # cached result on repeated access
        if ret.videos_patch_nums is None:
            return ''
        video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
        return ret, ''.join(video_frame_str)

    def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
        # extract the JSON-formatted audio/image info from the text, load it and convert it into features,
        # estimate the number of encoder tokens, and fill in the corresponding number of pad tokens
        if (self.audio_start_tag != None) and (mtype == 'audio'):
            match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag, re.S)
            drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag, re.S)
        elif (self.image_start_tag != None) and (mtype == 'image'):
            match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag, re.S)
            drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag, re.S)
        elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
            match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag, re.S)
            drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag, re.S)
        elif (self.video_start_tag != None) and (mtype == 'video'):
            match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag, re.S)
            drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag, re.S)
        else:
            raise ValueError("mtype not supported!")
        new_text_list = []
        new_mm_label_list = []
        new_trainable_flag_list = []
        for text, mm_label, trainable in zip(text_list, mm_label_list, trainable_list):
            for t, m in zip(*split_text(text, match_regex)):
                new_trainable_flag_list.append(trainable)
                if m:
                    new_text_list.append(re.sub(drop_regex, '', t))
                    new_mm_label_list.append(mtype)
                else:
                    new_text_list.append(t)
                    new_mm_label_list.append(mm_label)
        return new_text_list, new_mm_label_list, new_trainable_flag_list

    def process_multimodal_chunk(self, text, mm_label, trainable):
        ret = OmniProcessorOutput()
        if mm_label == 'audio':
            ret = self._get_audio(text)
            if ret.bridge_length is not None:
                ret.input_ids = self.tokenizer.encode(self.audio_start_tag, add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag, add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag, add_special_tokens=False)
            else:
                raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
        elif mm_label == 'audiogen':
            ret = self._get_audio(text)
            if ret.bridge_length is not None:
                ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag, add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag, add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag, add_special_tokens=False)
            else:
                raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
        elif mm_label == 'image':
            ret, input_str = self._replace_image(text)
            if input_str:
                ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
            else:
                raise ValueError("Get image data Failed at Process image chunk")
        elif mm_label == 'video':
            frame_str = self.video_start_tag + self._get_video_frame_str(text) + self.video_end_tag
            ret, input_str = self._replace_video_frame(frame_str)
            if input_str:
                ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
            else:
                raise ValueError("Get video data Failed at Process video chunk")
        elif mm_label == 'text':
            ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if len(ret.input_ids) > self.tokenizer.model_max_length - 1:  # filter out overlong text
                raise ValueError(f"Text too long, please check text length! 【{text[:5]+'...'*6+text[-5:]}】")
        else:
            raise ValueError(f"mm_label not supported! must be in ['audio', 'image', 'text'] but got {mm_label}")
        return ret

    def process_one(self, text, index=0, raw_only=False):
        ret = OmniProcessorOutput(index=index)
        all_text_list = []
        all_mm_label_list = []
        all_trainable_flag_list = []
        text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>", re.S))
        if len(text_list) == 1:
            text = re.sub(re.compile("<trainable_start>|<trainable_end>", re.S), '', text_list[0])
            all_text_list.append(text)
            all_mm_label_list.append('text')
            all_trainable_flag_list.append(True)
        else:
            for text, match in zip(text_list, match_flag):
                text = re.sub(re.compile("<trainable_start>|<trainable_end>", re.S), '', text)
                if text.strip() == '':
                    continue  # drop leftover whitespace
                all_text_list.append(text)
                all_mm_label_list.append('text')
                all_trainable_flag_list.append(match)
        # process the multimodal information
        for mtype in self.config.multimodal:  # loop to fetch the audio and image results
            all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
        if len(all_text_list) == 0:
            print(f"Process {text} chunk error: No valid Text data!!!!!")
            return OmniProcessorOutput(index=index)

        for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
            try:
                mret = self.process_multimodal_chunk(text, mm_label, trainable)
                ret = ret.concatenate(mret)
            except ValueError as e:
                tt = text[:24].replace('\n', '<LF>')
                print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
                return OmniProcessorOutput(index=index)

        if raw_only:
            ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
            return ret
        return ret

    @torch.no_grad()
    def __call__(self, example, parallel=128):
        if isinstance(example, Dict):
            pass
        elif isinstance(example, str):
            return self.process_one(example)
        elif isinstance(example, List):  # batch inference: asynchronous multi-threaded processing
            with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
                future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
                batch_data = [key.result() for key in cf.as_completed(future_list)]
            valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
            assert (valid_num == len(batch_data))  # inference data must stay strictly aligned
            batch_data = sorted(batch_data, key=lambda x: x.index)  # preserve the original order

            ret = OmniProcessorOutput()
            for i in range(len(batch_data)):
                ret = ret.concatenate(batch_data[i])
            self.tokenizer.padding_side = "left"
            max_len = min(max([len(x.input_ids) for x in batch_data]), self.tokenizer.model_max_length)
            padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
            ret.input_ids = padding_result["input_ids"]
            ret.attention_mask = padding_result["attention_mask"]  # batch inference is not packed, so seqlens is not needed

            if ret.audios is not None:
                max_audios_len = max([x.shape[-1] for x in ret.audios])
                ret.audios = default_collate([np.pad(x, ((0, 0), (0, max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
                ret.encoder_length = default_collate(ret.encoder_length)
                ret.bridge_length = default_collate(ret.bridge_length)

            if ret.images is not None:
                ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
                ret.patch_nums = default_collate(ret.patch_nums)

            if ret.videos is not None:
                ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
                ret.videos_patch_nums = default_collate(ret.videos_patch_nums)

            return ret

        else:
            raise ValueError("example format not supported yet")
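
def _omni_mm_processor_usage_sketch(model_path: str, wav_path: str):
    # Hedged usage sketch (not called anywhere): model_path and wav_path are caller-supplied
    # assumptions, and the {"path": ...} JSON schema follows _get_audio above; the start/end
    # tags come from the loaded config's audio token ids. Only the processor call itself
    # mirrors the class defined above.
    from transformers import AutoConfig, AutoTokenizer
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    processor = OmniMMProcessor(tokenizer=tokenizer, config=config, training=False)
    prompt = f'{processor.audio_start_tag}{{"path": "{wav_path}"}}{processor.audio_end_tag}What is said in this audio?'
    batch = processor([prompt])  # OmniProcessorOutput with padded input_ids / attention_mask / audios
    return batch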
sequence_parallel_utils.py
ADDED
@@ -0,0 +1,186 @@
1 |
+
from typing import Any, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor
|
7 |
+
from flash_attn import flash_attn_varlen_func
|
8 |
+
try:
|
9 |
+
import deepspeed.comm as dist
|
10 |
+
except:
|
11 |
+
dist = None
|
12 |
+
|
13 |
+
|
14 |
+
try:
|
15 |
+
from utils import (
|
16 |
+
get_sequence_parallel_group,
|
17 |
+
get_sequence_parallel_size,
|
18 |
+
get_sequence_parallel_rank
|
19 |
+
)
|
20 |
+
except (ModuleNotFoundError, ImportError):
|
21 |
+
# 从 utils 获取seq parallel设置,import不成功默认为不开启
|
22 |
+
get_sequence_parallel_group = lambda : None
|
23 |
+
get_sequence_parallel_size = lambda : 1
|
24 |
+
get_sequence_parallel_rank = lambda : 0
|
25 |
+
|
26 |
+
|
27 |
+
def single_all_to_all(input, scatter_idx, gather_idx, group):
|
28 |
+
seq_world_size = dist.get_world_size(group)
|
29 |
+
inp_shape = list(input.shape)
|
30 |
+
inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
|
31 |
+
if scatter_idx < 2:
|
32 |
+
input_t = input.reshape(
|
33 |
+
[seq_world_size, inp_shape[scatter_idx]] + \
|
34 |
+
inp_shape[scatter_idx + 1:]
|
35 |
+
).contiguous()
|
36 |
+
else:
|
37 |
+
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
|
38 |
+
input_t = input.reshape(
|
39 |
+
[-1, seq_world_size, inp_shape[scatter_idx]] + \
|
40 |
+
inp_shape[scatter_idx + 1:]
|
41 |
+
).transpose(0, 1).contiguous()
|
42 |
+
|
43 |
+
output = torch.empty_like(input_t)
|
44 |
+
dist.all_to_all_single(output, input_t, group=group)
|
45 |
+
|
46 |
+
# if scattering the seq-dim, transpose the heads back to the original dimension
|
47 |
+
# [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
|
48 |
+
# [seq_len//sp_size,batch_size, sp_size, head_num // sp_size, head_dim]
|
49 |
+
if scatter_idx < 2:
|
50 |
+
output = output.transpose(0, 1).transpose(1, 2).contiguous()
|
51 |
+
|
52 |
+
return output.reshape(
|
53 |
+
inp_shape[: gather_idx] + \
|
54 |
+
[inp_shape[gather_idx] * seq_world_size,] + \
|
55 |
+
inp_shape[gather_idx + 1:]).contiguous()
|
56 |
+
|
57 |
+
|
58 |
+
class _SeqAllToAll(torch.autograd.Function):
|
59 |
+
|
    @staticmethod
    def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return single_all_to_all(input, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)


# imported from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
# with some fixes so the dimension layout matches the training setup used here
class DistributedAttention(nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: nn.Module,
        sequence_process_group: 'dist.ProcessGroup',
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:
        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
        # pad the head dimension of the inputs to a multiple of sp_size
        sp_size = torch.distributed.get_world_size(self.spg)
        pad_size = (sp_size - query.size(1) % sp_size) % sp_size
        if pad_size > 0:
            # [bs, num_head, seq_len, head_dim] -> [bs, num_head+pad_size, seq_len, head_dim]
            query = torch.nn.functional.pad(query, (0, 0, 0, 0, 0, pad_size), value=0.01)
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 0, pad_size), value=0.01)
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 0, pad_size), value=0.0)
        return query, key, value

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer [batch_size, num_head, seq_len, head_dim]
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # TODO: merge the three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q, k and v) together!
        # [batch_size, num_head, seq_len, head_dim] -> [seq_len, batch_size, num_head, head_dim]
        origin_num_head = query.size(1)
        query, key, value = self.pad_attention_head(query, key, value)

        query = query.transpose(1, 2).transpose(0, 1)
        key = key.transpose(1, 2).transpose(0, 1)
        value = value.transpose(1, 2).transpose(0, 1)
        # in shape: e.g., [s/p, bs, h, head_dim]
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx).transpose(0, 1).transpose(1, 2).contiguous()
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx).transpose(0, 1).transpose(1, 2).contiguous()
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx).transpose(0, 1).transpose(1, 2).contiguous()

        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
        context_layer = context_layer.transpose(0, 1).contiguous()
        # [seq_len, batch_size, num_head, head_dim]
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        return output.transpose(0, 1)[:, :, :origin_num_head, :]


class LocalAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, head_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(self, q, k, v, *args, use_flash=True, **kwargs):
        # input q, k, v: [batch_size, num_head, seq_len, head_dim]
        # output: [batch_size, seq_len, num_head, head_dim]
        if use_flash:
            q_len, num_heads = q.shape[2], q.shape[1]
            q = q.transpose(1, 2).reshape(-1, num_heads, self.head_dim)
            k = k.transpose(1, 2).reshape(-1, num_heads, self.head_dim)
            v = v.transpose(1, 2).reshape(-1, num_heads, self.head_dim)
            return flash_attn_varlen_func(q, k, v, *args, **kwargs).reshape(-1, q_len, num_heads, self.head_dim)
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                attn_output = F.scaled_dot_product_attention(q, k, v, *args, **kwargs)
            attn_output = attn_output.transpose(1, 2)
            return attn_output


def create_attention_layer(hidden_size, num_heads, head_dim):
    if get_sequence_parallel_group() is None:
        return LocalAttention(hidden_size, num_heads, head_dim)
    else:
        return DistributedAttention(
            local_attention=LocalAttention(hidden_size, num_heads, head_dim),
            sequence_process_group=get_sequence_parallel_group()
        )


def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
    assert tensor.size(dim) % get_sequence_parallel_size() == 0
    original_size = tensor.size(dim)
    if shift:
        tensor = tensor.split([shift, tensor.size(dim) - shift], dim=dim)[1]
    if get_sequence_parallel_group() is None:
        return tensor
    else:
        chunk_size = original_size // get_sequence_parallel_size()
        return tensor.split(chunk_size, dim=dim)[get_sequence_parallel_rank()]
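The q/k/v exchange above can be hard to picture, so here is a minimal single-process sketch of what the sequence-parallel all-to-all does to the tensor layout. This is not the repository's single_all_to_all; the function name and the per-rank list are illustrative assumptions (real ranks hold their shards on different devices), but the slicing mirrors the scatter-over-heads / gather-over-sequence step that DistributedAttention relies on.

import torch

def simulated_seq_all_to_all(chunks):
    # chunks[r] is rank r's shard, shape [seq_len/sp, batch, num_heads, head_dim].
    # After the exchange each "rank" holds the full sequence but only num_heads/sp heads,
    # which is the layout handed to the local attention module.
    sp = len(chunks)
    per_rank_heads = chunks[0].shape[2] // sp
    outputs = []
    for rank in range(sp):
        pieces = [c[:, :, rank * per_rank_heads:(rank + 1) * per_rank_heads, :] for c in chunks]
        outputs.append(torch.cat(pieces, dim=0))  # concatenate along the sequence dimension
    return outputs

shards = [torch.randn(16, 2, 8, 64) for _ in range(4)]  # 4 "ranks": seq/sp=16, bs=2, 8 heads, head_dim=64
full = simulated_seq_all_to_all(shards)
print(full[0].shape)  # torch.Size([64, 2, 2, 64]): full sequence, 8/4 = 2 heads per rank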
special_tokens_map.json
ADDED
@@ -0,0 +1,68 @@
{
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>",
    "<B_SYS>",
    "<B_USYS>",
    "<C_Q>",
    "<C_A>",
    "<B_FUNC>",
    "<B_CODE>",
    "<B_APE>",
    "<function_calling>",
    "<calc_start>",
    "<calc_end>",
    "<inner_think>",
    "<audio_start_baichuan>",
    "<audio_end_baichuan>",
    "<audio_pad_baichuan>",
    "<img_start_baichuan>",
    "<img_end_baichuan>",
    "<img_pad_baichuan>",
    "<img_newline_baichuan>",
    "<box_start_baichuan>",
    "<box_end_baichuan>",
    "<box_delim_baichuan>",
    "<ref_start_baichuan>",
    "<ref_end_baichuan>",
    "<img_delim_baichuan>",
    "<polygon_start_baichuan>",
    "<polygon_end_baichuan>",
    "<baichuan_pad_token>",
    "<reserved_113>",
    "<audio_delim_baichuan>",
    "<video_start_baichuan>",
    "<video_end_baichuan>",
    "<video_palce_baichuan>",
    "<audiotext_start_baichuan>",
    "<audiotext_end_baichuan>",
    "<audiotext_pad_baichuan>",
    "<audiogen_start_baichuan>",
    "<audiogen_end_baichuan>"
  ],
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1,540 @@
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "151643": {"content": "<|endoftext|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151644": {"content": "<|im_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151645": {"content": "<|im_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151646": {"content": "<|object_ref_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151647": {"content": "<|object_ref_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151648": {"content": "<|box_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151649": {"content": "<|box_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151650": {"content": "<|quad_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151651": {"content": "<|quad_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151652": {"content": "<|vision_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151653": {"content": "<|vision_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151654": {"content": "<|vision_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151655": {"content": "<|image_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151656": {"content": "<|video_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151657": {"content": "<tool_call>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151658": {"content": "</tool_call>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151659": {"content": "<|fim_prefix|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151660": {"content": "<|fim_middle|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151661": {"content": "<|fim_suffix|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151662": {"content": "<|fim_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151663": {"content": "<|repo_name|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151664": {"content": "<|file_sep|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151665": {"content": "<B_SYS>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151666": {"content": "<B_USYS>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151667": {"content": "<C_Q>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151668": {"content": "<C_A>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151669": {"content": "<B_FUNC>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151670": {"content": "<B_CODE>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151671": {"content": "<B_APE>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": true, "special": true},
    "151672": {"content": "<function_calling>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": true, "special": true},
    "151673": {"content": "<calc_start>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": true, "special": true},
    "151674": {"content": "<calc_end>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": true, "special": true},
    "151675": {"content": "<inner_think>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": true, "special": true},
    "151676": {"content": "<audio_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151677": {"content": "<audio_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151678": {"content": "<audio_pad_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151679": {"content": "<img_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151680": {"content": "<img_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151681": {"content": "<img_pad_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151682": {"content": "<img_newline_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151683": {"content": "<box_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151684": {"content": "<box_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151685": {"content": "<box_delim_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151686": {"content": "<ref_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151687": {"content": "<ref_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151688": {"content": "<img_delim_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151689": {"content": "<polygon_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151690": {"content": "<polygon_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151691": {"content": "<baichuan_pad_token>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151692": {"content": "<reserved_113>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151693": {"content": "<audio_delim_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151694": {"content": "<video_palce_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151695": {"content": "<video_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151696": {"content": "<video_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151697": {"content": "<audiotext_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151698": {"content": "<audiotext_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151699": {"content": "<audiotext_pad_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151700": {"content": "<audiogen_start_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151701": {"content": "<audiogen_end_baichuan>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}
  },
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>",
    "<B_SYS>",
    "<B_USYS>",
    "<C_Q>",
    "<C_A>",
    "<B_FUNC>",
    "<B_CODE>",
    "<B_APE>",
    "<function_calling>",
    "<calc_start>",
    "<calc_end>",
    "<inner_think>",
    "<audio_start_baichuan>",
    "<audio_end_baichuan>",
    "<audio_pad_baichuan>",
    "<img_start_baichuan>",
    "<img_end_baichuan>",
    "<img_pad_baichuan>",
    "<img_newline_baichuan>",
    "<box_start_baichuan>",
    "<box_end_baichuan>",
    "<box_delim_baichuan>",
    "<ref_start_baichuan>",
    "<ref_end_baichuan>",
    "<img_delim_baichuan>",
    "<polygon_start_baichuan>",
    "<polygon_end_baichuan>",
    "<baichuan_pad_token>",
    "<reserved_113>",
    "<audio_delim_baichuan>",
    "<video_start_baichuan>",
    "<video_end_baichuan>",
    "<video_palce_baichuan>",
    "<audiotext_start_baichuan>",
    "<audiotext_end_baichuan>",
    "<audiotext_pad_baichuan>",
    "<audiogen_start_baichuan>",
    "<audiogen_end_baichuan>"
  ],
  "bos_token": null,
  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "model_max_length": 8192,
  "pad_token": "<|endoftext|>",
  "split_special_tokens": false,
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": null
}
vector_quantize.py
ADDED
@@ -0,0 +1,78 @@
import torch, random
from torch.nn import functional as F
from torch import nn
import numpy as np
from torch.cuda.amp import autocast

def uniform_init(*shape):
    t = torch.zeros(shape)
    nn.init.kaiming_uniform_(t)
    return t

def cdist(x, y):
    x2 = torch.sum(x ** 2, dim=-1, keepdims=True)  # (b, 1)
    y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1)  # (1, c)
    xy = torch.einsum('bd,cd->bc', x, y) * -2
    return (x2 + y2 + xy).clamp(min=0).sqrt()  # (b, c)

def get_sequence_mask(inputs, inputs_length):
    if inputs.dim() == 3:
        bsz, tgt_len, _ = inputs.size()
    else:
        bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
    sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
    sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
    unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1  # convert to flat indices
    return sequence_mask, unpacking_index


class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        init_std=0.02,
    ):
        super().__init__()
        self.init_std = init_std
        self.dim = dim
        self.codebook_size = codebook_size

        embed = uniform_init(codebook_size, dim).to(torch.float32)
        self.cluster_size = nn.Parameter(torch.ones(codebook_size))
        self.embed_avg = nn.Parameter(embed.clone())
        self.embed = nn.Parameter(embed)
        del embed

    @autocast(enabled=True, dtype=torch.float32)
    @torch.no_grad()
    def forward(self, x):
        assert len(x.shape) == 2
        assert x.dtype == torch.float32
        embed = self.embed.detach().to(x.device)
        dist = -cdist(x, embed)  # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
        embed_ind = dist.argmax(dim=-1)
        quantize = embed[embed_ind]  # (bs*sl, d)
        return quantize, embed_ind, dist

class VectorQuantize(nn.Module):
    def __init__(self, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)

    def forward(self, x, input_length):
        batch_size, seq_len, _ = x.shape
        mask, unpacking_index = get_sequence_mask(x, input_length)
        if x.dtype != torch.float32:
            x = x.to(torch.float32)
        x = torch.masked_select(x, mask).reshape(-1, self.config.dim)  # (bs*sl?, d)
        quantize, embed_ind, _ = self.codebook(x)
        quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
        quantize = torch.where(mask, quantize, 0)
        embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
        embed_ind = torch.where(mask, embed_ind, -1).squeeze()
        return quantize, embed_ind

    def get_output_from_indices(self, indices):
        return self.codebook.embed[indices]
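To show how the quantizer above is typically driven, here is a hedged usage sketch. The config object is a stand-in namespace carrying only the two fields the module reads (dim, codebook_size); the repository's real configuration class may differ, and the import assumes vector_quantize.py is importable from the working directory.

from types import SimpleNamespace

import torch

from vector_quantize import VectorQuantize

config = SimpleNamespace(dim=64, codebook_size=512)  # assumed field names, matching the module above
vq = VectorQuantize(config)

x = torch.randn(2, 10, 64)       # [batch, seq_len, dim]
lengths = torch.tensor([10, 7])  # valid frames per sample; the padded tail is masked out
quantized, codes = vq(x, lengths)
print(quantized.shape, codes.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 10]); codes are -1 at padded positions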
visual_modeling_omni.py
ADDED
@@ -0,0 +1,87 @@
from typing import List, Optional, Tuple, Union
import torch, math
import torch.utils.checkpoint
from torch import nn
import transformers
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN
from PIL import Image
import io, fire
from torch.nn import functional as F

class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config_attn_implementation = 'flash_attention_2'
        self.gradient_checkpointing = True  # force-enabled
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
        del self.merger

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ):
        hidden_states = pixel_values.to(self.get_dtype())
        grid_thw = grid_thw.to(pixel_values.device)

        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        return hidden_states

    @torch.no_grad()
    def fake_input(self, device):
        merge_size = max(self.merge_size, self.config.spatial_merge_size)
        fake_image = torch.zeros([
            1,
            self.config.temporal_patch_size,
            3,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
        ], dtype=torch.float32, device=device)
        patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
        )
        return [flatten_patches], [(1, merge_size, merge_size)], [1]


class OmniVisualBridge(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.hidden_size = config.embed_dim * (self.merge_size**2)
        self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x


if __name__ == '__main__':
    fire.Fire()
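The cu_seqlens computed in OmniVisualEncoder.forward are the cumulative per-frame patch counts that flash-attention's variable-length kernels expect. A small standalone sketch with made-up grid sizes illustrates the arithmetic:

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6],   # input 1: one temporal step, a 4x6 patch grid
                         [2, 4, 4]])  # input 2: two temporal steps, 4x4 patches each
seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seq_lens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 40, 56], dtype=torch.int32): per-frame boundaries for varlen attention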
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
zero_to_fp32.py
ADDED
@@ -0,0 +1,604 @@
#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py . pytorch_model.bin

import argparse
import torch
import glob
import math
import os
import re
from collections import OrderedDict
from dataclasses import dataclass

# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
                                            FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
                                            FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)


@dataclass
class zero_model_state:
    buffers: dict()
    param_shapes: dict()
    shared_params: list
    ds_version: int
    frozen_param_shapes: dict()
    frozen_param_fragments: dict()


debug = 0

# load to cpu
device = torch.device('cpu')


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]


def get_model_state_file(checkpoint_dir, zero_stage):
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file


def get_checkpoint_files(checkpoint_dir, glob_pattern):
    # XXX: need to test that this simple glob rule works for multi-node setup too
    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)

    if len(ckpt_files) == 0:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")

    return ckpt_files


def get_optim_files(checkpoint_dir):
    return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")


def get_model_state_files(checkpoint_dir):
    return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")


def parse_model_states(files):
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        # collect parameters that are included in param_shapes
        param_names = []
        for s in param_shapes:
            for name in s.keys():
                param_names.append(name)

        # update with frozen parameters
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None:
            if debug:
                print(f"Found frozen_param_shapes: {frozen_param_shapes}")
            param_names += list(frozen_param_shapes.keys())

        # handle shared params
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        z_model_state = zero_model_state(buffers=buffers,
                                         param_shapes=param_shapes,
                                         shared_params=shared_params,
                                         ds_version=ds_version,
                                         frozen_param_shapes=frozen_param_shapes,
                                         frozen_param_fragments=frozen_param_fragments)
        zero_model_states.append(z_model_state)

    return zero_model_states


def parse_optim_states(files, ds_checkpoint_dir):

    total_files = len(files)
    state_dicts = []
    for f in files:
        state_dict = torch.load(f, map_location=device)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
        state_dicts.append(state_dict)

    if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.

    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    if zero_stage <= 2:
        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
        #
        # XXX: could make the script more memory efficient for when there are multiple groups - it
        # will require matching the sub-lists of param_shapes for each param group flattened tensor

        fp32_flat_groups = [
            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
        ]

    return zero_stage, world_size, fp32_flat_groups


def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)

    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    model_files = get_model_state_files(ds_checkpoint_dir)

    zero_model_states = parse_model_states(model_files)
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    if zero_stage <= 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)
    elif zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)


def _zero2_merge_frozen_params(state_dict, zero_model_states):
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


def _has_callable(obj, fn):
    attr = getattr(obj, fn, None)
    return callable(attr)


def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = []
    for i in range(num_param_groups):
        merged_partitions = [sd[i] for sd in fp32_flat_groups]
        full_single_fp32_vector = torch.cat(merged_partitions, 0)
        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
    avail_numel = sum(
        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])

    if debug:
        wanted_params = sum([len(shapes) for shapes in param_shapes])
        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():

            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")


def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in zero_model_states[0].frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = fp32_flat_groups[0].numel() * world_size
    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    for name, shape in param_shapes.items():

        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        state_dict[name] = torch.cat(
            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
            0).narrow(0, 0, unpartitioned_numel).view(shape)
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")


def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters

    Returns:
        - pytorch ``state_dict``

    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    """
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)


def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """

    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
    print(f"Saving fp32 state dict to {output_file}")
    torch.save(state_dict, output_file)


def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others
+
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
566 |
+
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
567 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
568 |
+
|
569 |
+
"""
|
570 |
+
logger.info(f"Extracting fp32 weights")
|
571 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
572 |
+
|
573 |
+
logger.info(f"Overwriting model with fp32 weights")
|
574 |
+
model = model.cpu()
|
575 |
+
model.load_state_dict(state_dict, strict=False)
|
576 |
+
|
577 |
+
return model
|
578 |
+
|
579 |
+
|
580 |
+
if __name__ == "__main__":
|
581 |
+
|
582 |
+
parser = argparse.ArgumentParser()
|
583 |
+
parser.add_argument("checkpoint_dir",
|
584 |
+
type=str,
|
585 |
+
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
586 |
+
parser.add_argument(
|
587 |
+
"output_file",
|
588 |
+
type=str,
|
589 |
+
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
|
590 |
+
parser.add_argument("-t",
|
591 |
+
"--tag",
|
592 |
+
type=str,
|
593 |
+
default=None,
|
594 |
+
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
|
595 |
+
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
|
596 |
+
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
597 |
+
args = parser.parse_args()
|
598 |
+
|
599 |
+
debug = args.debug
|
600 |
+
|
601 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
|
602 |
+
args.output_file,
|
603 |
+
tag=args.tag,
|
604 |
+
exclude_frozen_parameters=args.exclude_frozen_parameters)
|