import math

import torch
import torch.nn as nn


# https://github.com/facebookresearch/DiT
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into vector representations.

    Each timestep is multiplied by a fixed bank of frequencies
    ``10000 ** (-i / half_dim)`` for ``i`` in ``[0, half_dim)``, the sine and
    cosine of those products are concatenated, and the result is projected to
    ``hidden_dim`` by a two-layer MLP (Linear -> SiLU -> Linear).

    Args:
        hidden_dim: Size of the output embedding.
        frequency_embedding_size: Width of the sinusoidal feature vector fed
            to the MLP. Odd sizes are supported by zero-padding one column.
    """

    def __init__(self, hidden_dim: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

        half_dim = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(10000)
            * torch.arange(start=0, end=half_dim, dtype=torch.float32)
            / half_dim
        )
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state dict under the key 'freqs'.
        self.register_buffer('freqs', freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return the embedding of timesteps ``t``.

        Args:
            t: Tensor of timesteps; a trailing feature axis is appended, so
                the output has shape ``t.shape + (hidden_dim,)``. Integer and
                floating-point dtypes are both accepted.

        Returns:
            Tensor of embeddings produced by the MLP.
        """
        t_freq = t.unsqueeze(-1) * self.freqs.unsqueeze(0)
        t_embed = torch.cat([t_freq.sin(), t_freq.cos()], dim=-1)

        # sin/cos yield 2 * half_dim features, one short of an odd
        # frequency_embedding_size; pad a zero column so the width matches
        # the first Linear layer's in_features.
        if self.frequency_embedding_size % 2:
            pad = t_embed.new_zeros(*t_embed.shape[:-1], 1)
            t_embed = torch.cat([t_embed, pad], dim=-1)

        # Cast to the MLP's parameter dtype rather than t.dtype: casting to an
        # integer timestep dtype would truncate the features and make the
        # Linear layer reject the input.
        t_embed = self.mlp(t_embed.to(self.mlp[0].weight.dtype))
        return t_embed