Zackh committed on
Commit
1c1de51
·
1 Parent(s): 45e163c

switch to config arg

Browse files
Files changed (1) hide show
  1. models.py +16 -21
models.py CHANGED
@@ -102,27 +102,22 @@ class Model(
102
  repo_url="https://github.com/SesameAILabs/csm",
103
  pipeline_tag="text-to-speech",
104
  license="apache-2.0",
105
- coders={
106
- # Tells the class how to serialize and deserialize config.json
107
- ModelArgs : (
108
- lambda x: asdict(x), # Encoder: how to convert a `ModelArgs` to a valid jsonable value?
109
- lambda data: ModelArgs(**data), # Decoder: how to reconstruct a `ModelArgs` from a dictionary?
110
- )
111
- }
112
  ):
113
- def __init__(self, args: ModelArgs):
114
  super().__init__()
115
- self.args = args
116
 
117
- self.backbone, backbone_dim = _prepare_transformer(FLAVORS[args.backbone_flavor]())
118
- self.decoder, decoder_dim = _prepare_transformer(FLAVORS[args.decoder_flavor]())
119
 
120
- self.text_embeddings = nn.Embedding(args.text_vocab_size, backbone_dim)
121
- self.audio_embeddings = nn.Embedding(args.audio_vocab_size * args.audio_num_codebooks, backbone_dim)
122
 
123
  self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
124
- self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
125
- self.audio_head = nn.Parameter(torch.empty(args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size))
 
 
126
 
127
  def setup_caches(self, max_batch_size: int) -> None:
128
  """Setup KV caches and return a causal mask."""
@@ -131,10 +126,10 @@ class Model(
131
 
132
  with device:
133
  self.backbone.setup_caches(max_batch_size, dtype)
134
- self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.args.audio_num_codebooks)
135
 
136
  self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
137
- self.register_buffer("decoder_causal_mask", _create_causal_mask(self.args.audio_num_codebooks, device))
138
 
139
  def generate_frame(
140
  self,
@@ -175,7 +170,7 @@ class Model(
175
 
176
  # Decoder caches must be reset every frame.
177
  self.decoder.reset_caches()
178
- for i in range(1, self.args.audio_num_codebooks):
179
  curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
180
  decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
181
  dtype=dtype
@@ -195,16 +190,16 @@ class Model(
195
  self.decoder.reset_caches()
196
 
197
  def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
198
- return self.audio_embeddings(tokens + codebook * self.args.audio_vocab_size)
199
 
200
  def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
201
  text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
202
 
203
  audio_tokens = tokens[:, :, :-1] + (
204
- self.args.audio_vocab_size * torch.arange(self.args.audio_num_codebooks, device=tokens.device)
205
  )
206
  audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
207
- tokens.size(0), tokens.size(1), self.args.audio_num_codebooks, -1
208
  )
209
 
210
  return torch.cat([audio_embeds, text_embeds], dim=-2)
 
102
  repo_url="https://github.com/SesameAILabs/csm",
103
  pipeline_tag="text-to-speech",
104
  license="apache-2.0",
 
 
 
 
 
 
 
105
  ):
106
def __init__(self, config: ModelArgs):
    """Construct the two-stage TTS model described by *config*.

    NOTE(review): reconstructed from a diff rendering; `FLAVORS` and
    `_prepare_transformer` are defined elsewhere in models.py.
    """
    super().__init__()
    self.config = config

    # Two transformers: a backbone plus a smaller per-frame decoder,
    # each selected by name from the FLAVORS registry.
    self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
    self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

    self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
    # One flat table for every audio codebook: row = codebook * vocab_size + token
    # (see _embed_audio, which indexes with exactly that offset).
    self.audio_embeddings = nn.Embedding(
        config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
    )

    # Maps backbone hidden states into the decoder's width.
    self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
    # Codebook 0 is predicted from the backbone; the remaining
    # audio_num_codebooks - 1 codebooks each get a weight slice below.
    self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
    self.audio_head = nn.Parameter(
        torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)
    )
121
 
122
  def setup_caches(self, max_batch_size: int) -> None:
123
  """Setup KV caches and return a causal mask."""
 
126
 
127
  with device:
128
  self.backbone.setup_caches(max_batch_size, dtype)
129
+ self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
130
 
131
  self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
132
+ self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
133
 
134
  def generate_frame(
135
  self,
 
170
 
171
  # Decoder caches must be reset every frame.
172
  self.decoder.reset_caches()
173
+ for i in range(1, self.config.audio_num_codebooks):
174
  curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
175
  decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
176
  dtype=dtype
 
190
  self.decoder.reset_caches()
191
 
192
  def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
193
+ return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
194
 
195
  def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
196
  text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
197
 
198
  audio_tokens = tokens[:, :, :-1] + (
199
+ self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
200
  )
201
  audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
202
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
203
  )
204
 
205
  return torch.cat([audio_embeds, text_embeds], dim=-2)