import torch
import torch.nn as nn
import math

# Adapted from https://github.com/facebookresearch/DiT


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_dim: int, frequency_embedding_size: int = 256):
        super().__init__()

        # Two-layer MLP projecting the sinusoidal features to the model width.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )

        # Geometrically spaced frequencies f_i = 10000^(-i / half_dim), as in
        # the Transformer positional encoding. Sin and cos of each frequency
        # are concatenated in forward(), so the embedding size must be even.
        assert frequency_embedding_size % 2 == 0, \
            "frequency_embedding_size must be even"
        self.frequency_embedding_size = frequency_embedding_size
        half_dim = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32) /
            half_dim
        )
        self.register_buffer('freqs', freqs)

    def forward(self, t):
        # Outer product of timesteps and frequencies: (B,) -> (B, half_dim).
        # t is cast to float so integer timestep tensors work correctly.
        t_freq = t.float().unsqueeze(-1) * self.freqs.unsqueeze(0)
        # Concatenated sin/cos features: (B, frequency_embedding_size)
        t_embed = torch.cat([t_freq.sin(), t_freq.cos()], dim=-1)
        # Cast to the MLP's parameter dtype (not t's dtype, which may be an
        # integer type) before projecting to hidden_dim.
        t_embed = self.mlp(t_embed.to(self.mlp[0].weight.dtype))
        return t_embed
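

# Minimal usage sketch (illustrative; hidden_dim=384 is an arbitrary choice):
# builds the embedder and checks output shapes, assuming timesteps arrive as
# a 1-D integer tensor, as is typical in diffusion training loops.
if __name__ == "__main__":
    embedder = TimestepEmbedder(hidden_dim=384)
    t = torch.randint(0, 1000, (8,))  # batch of 8 integer timesteps
    emb = embedder(t)
    print(emb.shape)  # torch.Size([8, 384])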