Spaces:
Sleeping
Sleeping
Phil Sobrepena
commited on
Commit
·
e87de0e
1
Parent(s):
d07a8ac
match embeddings and network to rc repo
Browse files- mmaudio/model/embeddings.py +30 -16
- mmaudio/model/networks.py +6 -4
mmaudio/model/embeddings.py
CHANGED
@@ -4,31 +4,45 @@ import math
|
|
4 |
|
5 |
# https://github.com/facebookresearch/DiT
|
6 |
|
7 |
-
|
8 |
class TimestepEmbedder(nn.Module):
|
9 |
"""
|
10 |
Embeds scalar timesteps into vector representations.
|
11 |
"""
|
12 |
|
13 |
-
def __init__(self,
|
14 |
super().__init__()
|
15 |
-
|
16 |
self.mlp = nn.Sequential(
|
17 |
-
nn.Linear(frequency_embedding_size,
|
18 |
nn.SiLU(),
|
19 |
-
nn.Linear(
|
20 |
)
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
def forward(self, t):
|
31 |
-
t_freq =
|
32 |
-
|
33 |
-
|
34 |
-
return t_embed
|
|
|
4 |
|
5 |
# https://github.com/facebookresearch/DiT
|
6 |
|
|
|
7 |
class TimestepEmbedder(nn.Module):
|
8 |
"""
|
9 |
Embeds scalar timesteps into vector representations.
|
10 |
"""
|
11 |
|
12 |
+
def __init__(self, dim, frequency_embedding_size, max_period):
|
13 |
super().__init__()
|
|
|
14 |
self.mlp = nn.Sequential(
|
15 |
+
nn.Linear(frequency_embedding_size, dim),
|
16 |
nn.SiLU(),
|
17 |
+
nn.Linear(dim, dim),
|
18 |
)
|
19 |
+
self.dim = dim
|
20 |
+
self.max_period = max_period
|
21 |
+
assert dim % 2 == 0, 'dim must be even.'
|
22 |
|
23 |
+
with torch.autocast('cuda', enabled=False):
|
24 |
+
self.freqs = (
|
25 |
+
1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
|
26 |
+
frequency_embedding_size)))
|
27 |
+
freq_scale = 10000 / max_period
|
28 |
+
self.freqs = nn.Parameter(freq_scale * self.freqs)
|
29 |
+
|
30 |
+
def timestep_embedding(self, t):
|
31 |
+
"""
|
32 |
+
Create sinusoidal timestep embeddings.
|
33 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
34 |
+
These may be fractional.
|
35 |
+
:param dim: the dimension of the output.
|
36 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
37 |
+
:return: an (N, D) Tensor of positional embeddings.
|
38 |
+
"""
|
39 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
40 |
+
|
41 |
+
args = t[:, None].float() * self.freqs[None]
|
42 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
43 |
+
return embedding
|
44 |
|
45 |
def forward(self, t):
|
46 |
+
t_freq = self.timestep_embedding(t).to(t.dtype)
|
47 |
+
t_emb = self.mlp(t_freq)
|
48 |
+
return t_emb
|
|
mmaudio/model/networks.py
CHANGED
@@ -166,8 +166,10 @@ class MMAudio(nn.Module):
|
|
166 |
self._clip_seq_len,
|
167 |
device=self.device)
|
168 |
|
169 |
-
self.latent_rot =
|
170 |
-
self.clip_rot =
|
|
|
|
|
171 |
|
172 |
def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
|
173 |
self._latent_seq_len = latent_seq_len
|
@@ -346,7 +348,7 @@ class MMAudio(nn.Module):
|
|
346 |
if 'clip_rot' in src_dict:
|
347 |
del src_dict['clip_rot']
|
348 |
|
349 |
-
self.load_state_dict(src_dict, strict=
|
350 |
|
351 |
@property
|
352 |
def device(self) -> torch.device:
|
@@ -466,4 +468,4 @@ if __name__ == '__main__':
|
|
466 |
|
467 |
# print the number of parameters in terms of millions
|
468 |
num_params = sum(p.numel() for p in network.parameters()) / 1e6
|
469 |
-
print(f'Number of parameters: {num_params:.2f}M')
|
|
|
166 |
self._clip_seq_len,
|
167 |
device=self.device)
|
168 |
|
169 |
+
# self.latent_rot = latent_rot.to(self.device)
|
170 |
+
# self.clip_rot = clip_rot.to(self.device)
|
171 |
+
self.register_buffer('latent_rot', latent_rot)
|
172 |
+
self.register_buffer('clip_rot', clip_rot)
|
173 |
|
174 |
def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
|
175 |
self._latent_seq_len = latent_seq_len
|
|
|
348 |
if 'clip_rot' in src_dict:
|
349 |
del src_dict['clip_rot']
|
350 |
|
351 |
+
self.load_state_dict(src_dict, strict=False)
|
352 |
|
353 |
@property
|
354 |
def device(self) -> torch.device:
|
|
|
468 |
|
469 |
# print the number of parameters in terms of millions
|
470 |
num_params = sum(p.numel() for p in network.parameters()) / 1e6
|
471 |
+
print(f'Number of parameters: {num_params:.2f}M')
|