sonisphere / mmaudio /model /embeddings.py
Phil Sobrepena
deps
31bd90e
raw
history blame
1.04 kB
import torch
import torch.nn as nn
import math
# https://github.com/facebookresearch/DiT
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Each timestep is multiplied by a bank of geometrically spaced
    frequencies, expanded into ``[sin, cos]`` features, and projected by a
    two-layer MLP (Linear -> SiLU -> Linear) to ``hidden_dim`` channels.

    Args:
        hidden_dim: Width of the MLP's hidden and output layers.
        frequency_embedding_size: Number of sinusoidal features fed to the
            MLP. Half are sines and half are cosines; when the size is odd,
            one trailing zero feature is appended so the width still
            matches the first linear layer.
        max_period: Base of the geometric frequency progression. The
            default reproduces the previously hard-coded value of 10000.
    """

    def __init__(
        self,
        hidden_dim: int,
        frequency_embedding_size: int = 256,
        max_period: float = 10000,
    ) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

        # freqs[i] = max_period ** (-i / half_dim) for i in [0, half_dim).
        half_dim = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half_dim, dtype=torch.float32)
            / half_dim
        )
        # Kept as a regular (persistent) buffer under the name 'freqs' so
        # that state_dict keys stay compatible with existing checkpoints.
        self.register_buffer('freqs', freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed a tensor of timesteps.

        Args:
            t: Timesteps of shape ``(...)``. A 0-d tensor is treated as a
                batch of one. Floating-point and integer dtypes are both
                accepted.

        Returns:
            Tensor of shape ``(..., hidden_dim)`` (``(1, hidden_dim)`` for a
            0-d input).
        """
        # Outer product of timesteps and frequencies: (..., half_dim).
        t_freq = t.unsqueeze(-1) * self.freqs.unsqueeze(0)
        t_embed = torch.cat([t_freq.sin(), t_freq.cos()], dim=-1)

        # With an odd feature count the sin/cos halves are one column short
        # of what the first linear layer expects; append a zero feature.
        if self.frequency_embedding_size % 2 == 1:
            t_embed = nn.functional.pad(t_embed, (0, 1))

        # Follow the caller's floating dtype (e.g. half precision), but never
        # cast to an integer dtype: that would truncate the sin/cos features
        # and the linear layers reject integer inputs.
        out_dtype = t.dtype if t.is_floating_point() else t_freq.dtype
        return self.mlp(t_embed.to(out_dtype))