Hugging Face Spaces — Running on Zero
Commit: "switch to config arg" (Browse files)
models.py — CHANGED
@@ -102,27 +102,22 @@ class Model(
|
|
102 |
repo_url="https://github.com/SesameAILabs/csm",
|
103 |
pipeline_tag="text-to-speech",
|
104 |
license="apache-2.0",
|
105 |
-
coders={
|
106 |
-
# Tells the class how to serialize and deserialize config.json
|
107 |
-
ModelArgs : (
|
108 |
-
lambda x: asdict(x), # Encoder: how to convert a `ModelArgs` to a valid jsonable value?
|
109 |
-
lambda data: ModelArgs(**data), # Decoder: how to reconstruct a `ModelArgs` from a dictionary?
|
110 |
-
)
|
111 |
-
}
|
112 |
):
|
113 |
-
def __init__(self,
|
114 |
super().__init__()
|
115 |
-
self.
|
116 |
|
117 |
-
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[
|
118 |
-
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[
|
119 |
|
120 |
-
self.text_embeddings = nn.Embedding(
|
121 |
-
self.audio_embeddings = nn.Embedding(
|
122 |
|
123 |
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
|
124 |
-
self.codebook0_head = nn.Linear(backbone_dim,
|
125 |
-
self.audio_head = nn.Parameter(
|
|
|
|
|
126 |
|
127 |
def setup_caches(self, max_batch_size: int) -> None:
|
128 |
"""Setup KV caches and return a causal mask."""
|
@@ -131,10 +126,10 @@ class Model(
|
|
131 |
|
132 |
with device:
|
133 |
self.backbone.setup_caches(max_batch_size, dtype)
|
134 |
-
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.
|
135 |
|
136 |
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
|
137 |
-
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.
|
138 |
|
139 |
def generate_frame(
|
140 |
self,
|
@@ -175,7 +170,7 @@ class Model(
|
|
175 |
|
176 |
# Decoder caches must be reset every frame.
|
177 |
self.decoder.reset_caches()
|
178 |
-
for i in range(1, self.
|
179 |
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
|
180 |
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
|
181 |
dtype=dtype
|
@@ -195,16 +190,16 @@ class Model(
|
|
195 |
self.decoder.reset_caches()
|
196 |
|
197 |
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
|
198 |
-
return self.audio_embeddings(tokens + codebook * self.
|
199 |
|
200 |
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
201 |
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
|
202 |
|
203 |
audio_tokens = tokens[:, :, :-1] + (
|
204 |
-
self.
|
205 |
)
|
206 |
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
|
207 |
-
tokens.size(0), tokens.size(1), self.
|
208 |
)
|
209 |
|
210 |
return torch.cat([audio_embeds, text_embeds], dim=-2)
|
|
|
102 |
repo_url="https://github.com/SesameAILabs/csm",
|
103 |
pipeline_tag="text-to-speech",
|
104 |
license="apache-2.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
):
|
106 |
+
def __init__(self, config: ModelArgs):
    """Build the model from a single ``config`` object.

    Args:
        config: ``ModelArgs`` carrying the backbone/decoder flavor names,
            text/audio vocab sizes, and the number of audio codebooks
            (replaces the previous individual constructor arguments).
    """
    super().__init__()
    self.config = config

    self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
    self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

    self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
    # One shared table for all codebooks: codebook i occupies the id range
    # [i * audio_vocab_size, (i + 1) * audio_vocab_size) — see _embed_audio.
    self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)

    self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
    self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
    # Per-codebook output projections for codebooks 1..N-1 (codebook 0 is
    # predicted by codebook0_head from the backbone).
    self.audio_head = nn.Parameter(
        torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)
    )
|
121 |
|
122 |
def setup_caches(self, max_batch_size: int) -> None:
|
123 |
"""Setup KV caches and return a causal mask."""
|
|
|
126 |
|
127 |
with device:
|
128 |
self.backbone.setup_caches(max_batch_size, dtype)
|
129 |
+
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
|
130 |
|
131 |
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
|
132 |
+
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
|
133 |
|
134 |
def generate_frame(
|
135 |
self,
|
|
|
170 |
|
171 |
# Decoder caches must be reset every frame.
|
172 |
self.decoder.reset_caches()
|
173 |
+
for i in range(1, self.config.audio_num_codebooks):
|
174 |
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
|
175 |
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
|
176 |
dtype=dtype
|
|
|
190 |
self.decoder.reset_caches()
|
191 |
|
192 |
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
|
193 |
+
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
|
194 |
|
195 |
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
196 |
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
|
197 |
|
198 |
audio_tokens = tokens[:, :, :-1] + (
|
199 |
+
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
|
200 |
)
|
201 |
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
|
202 |
+
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
|
203 |
)
|
204 |
|
205 |
return torch.cat([audio_embeds, text_embeds], dim=-2)
|