Phil Sobrepena committed
Commit e87de0e · 1 Parent(s): d07a8ac

match embeddings and network to rc repo

mmaudio/model/embeddings.py CHANGED
@@ -4,31 +4,45 @@ import math
 
 # https://github.com/facebookresearch/DiT
 
-
 class TimestepEmbedder(nn.Module):
     """
     Embeds scalar timesteps into vector representations.
     """
 
-    def __init__(self, hidden_dim: int, frequency_embedding_size: int = 256):
+    def __init__(self, dim, frequency_embedding_size, max_period):
         super().__init__()
-
         self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_dim, bias=True),
+            nn.Linear(frequency_embedding_size, dim),
             nn.SiLU(),
-            nn.Linear(hidden_dim, hidden_dim, bias=True),
+            nn.Linear(dim, dim),
         )
+        self.dim = dim
+        self.max_period = max_period
+        assert dim % 2 == 0, 'dim must be even.'
 
-        self.frequency_embedding_size = frequency_embedding_size
-        half_dim = self.frequency_embedding_size // 2
-        freqs = torch.exp(
-            -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32) /
-            half_dim
-        )
-        self.register_buffer('freqs', freqs)
+        with torch.autocast('cuda', enabled=False):
+            self.freqs = (
+                1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
+                               frequency_embedding_size)))
+            freq_scale = 10000 / max_period
+            self.freqs = nn.Parameter(freq_scale * self.freqs)
+
+    def timestep_embedding(self, t):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                  These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+
+        args = t[:, None].float() * self.freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        return embedding
 
     def forward(self, t):
-        t_freq = t.unsqueeze(-1) * self.freqs.unsqueeze(0)
-        t_embed = torch.cat([t_freq.sin(), t_freq.cos()], dim=-1)
-        t_embed = self.mlp(t_embed.to(t.dtype))
-        return t_embed
+        t_freq = self.timestep_embedding(t).to(t.dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
 
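For reference, a minimal usage sketch of the rewritten embedder (not part of the commit; the argument values dim=256, frequency_embedding_size=256, max_period=10000 are illustrative assumptions, not values taken from the repo):

import torch

from mmaudio.model.embeddings import TimestepEmbedder

# Illustrative values only: dim must be even, and the MLP input width equals
# frequency_embedding_size (cos and sin halves concatenated on the last dim).
embedder = TimestepEmbedder(dim=256, frequency_embedding_size=256, max_period=10000)

t = torch.rand(8)    # one (possibly fractional) timestep per batch element
emb = embedder(t)    # sinusoidal features -> MLP
print(emb.shape)     # torch.Size([8, 256])

Note that self.freqs is now an nn.Parameter, so the sinusoid frequencies are learnable and serialized with the checkpoint, unlike the old non-trainable register_buffer('freqs', ...) version.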
mmaudio/model/networks.py CHANGED
@@ -166,8 +166,10 @@ class MMAudio(nn.Module):
                                            self._clip_seq_len,
                                            device=self.device)
 
-        self.latent_rot = nn.Buffer(latent_rot, persistent=False)
-        self.clip_rot = nn.Buffer(clip_rot, persistent=False)
+        # self.latent_rot = latent_rot.to(self.device)
+        # self.clip_rot = clip_rot.to(self.device)
+        self.register_buffer('latent_rot', latent_rot)
+        self.register_buffer('clip_rot', clip_rot)
 
     def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
         self._latent_seq_len = latent_seq_len
@@ -346,7 +348,7 @@ class MMAudio(nn.Module):
         if 'clip_rot' in src_dict:
             del src_dict['clip_rot']
 
-        self.load_state_dict(src_dict, strict=True)
+        self.load_state_dict(src_dict, strict=False)
 
     @property
     def device(self) -> torch.device:
@@ -466,4 +468,4 @@ if __name__ == '__main__':
 
     # print the number of parameters in terms of millions
     num_params = sum(p.numel() for p in network.parameters()) / 1e6
-    print(f'Number of parameters: {num_params:.2f}M')
+    print(f'Number of parameters: {num_params:.2f}M')
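A small self-contained sketch of why the register_buffer switch pairs with strict=False (a toy module with made-up shapes, not the repo's code): persistent buffers enter state_dict by default, and since 'latent_rot'/'clip_rot' are deleted from the source dict before loading, a strict load would fail on the now-missing keys.

import torch
import torch.nn as nn

class RotDemo(nn.Module):
    # Toy stand-in for MMAudio's rotary tables.
    def __init__(self):
        super().__init__()
        # register_buffer: moves with .to(device) and, by default, is
        # included in state_dict (unlike nn.Buffer(..., persistent=False)).
        self.register_buffer('latent_rot', torch.randn(4, 8))

demo = RotDemo()
print('latent_rot' in demo.state_dict())    # True

# Checkpoints have the rotation keys stripped before loading, so a strict
# load would raise on the missing buffer; strict=False tolerates it.
missing, unexpected = demo.load_state_dict({}, strict=False)
print(missing)                              # ['latent_rot']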