# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2022 Microsoft
# Modified from
# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from mmengine.dist import all_reduce


def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
                decay: float) -> None:
    """Update the moving average in-place."""
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
                     decay: float) -> None:
    """Update the moving average in-place, then re-normalize it to unit L2
    norm along the last dimension."""
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
    moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1))
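
# Both helpers implement the standard EMA update
# ``avg <- decay * avg + (1 - decay) * new``; for example, with ``decay=0.99``
# each call moves the buffer 1% of the way towards ``new``.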


def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
    """Sample vectors according to the given number."""
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num, ), device=device)

    return samples[indices]


def kmeans(samples: torch.Tensor,
           num_clusters: int,
           num_iters: int = 10,
           use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run k-means algorithm."""
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            diffs = rearrange(samples, 'n d -> n () d') \
                    - rearrange(means, 'c d -> () c d')
            dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = F.normalize(new_means, p=2, dim=-1)

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
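

# A minimal usage sketch for ``kmeans`` (illustrative only; the shapes and
# hyper-parameters below are assumptions, not taken from a caller in this
# file):
#
#     samples = F.normalize(torch.randn(1024, 32), p=2, dim=-1)
#     means, bins = kmeans(samples, num_clusters=16, use_cosine_sim=True)
#     # means: (16, 32) unit-norm centroids;
#     # bins: (16,) cluster sizes from the last iteration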


class EmbeddingEMA(nn.Module):
    """The codebook of embedding vectors.

    Args:
        num_tokens (int): Number of embedding vectors in the codebook.
        codebook_dim (int): The dimension of embedding vectors in the
            codebook.
        kmeans_init (bool): Whether to use k-means to initialize the
            VectorQuantizer. Defaults to True.
        codebook_init_path (str): The initialization checkpoint for codebook.
            Defaults to None.
    """

    def __init__(self,
                 num_tokens: int,
                 codebook_dim: int,
                 kmeans_init: bool = True,
                 codebook_init_path: Optional[str] = None):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        if codebook_init_path is None:
            if not kmeans_init:
                weight = torch.randn(num_tokens, codebook_dim)
                weight = F.normalize(weight, p=2, dim=-1)
            else:
                weight = torch.zeros(num_tokens, codebook_dim)
            self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        else:
            print(f'load init codebook weight from {codebook_init_path}')
            codebook_ckpt_weight = torch.load(
                codebook_init_path, map_location='cpu')
            weight = codebook_ckpt_weight.clone()
            self.register_buffer('initted', torch.Tensor([True]))

        self.weight = nn.Parameter(weight, requires_grad=False)
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data: torch.Tensor) -> None:
        """Initialize embedding vectors of codebook."""
        if self.initted:
            return
        print('Performing K-means init for codebook')
        embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(embed)
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Get embedding vectors."""
        return F.embedding(embed_id, self.weight)
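

# Usage sketch for ``EmbeddingEMA`` (illustrative only; the sizes below are
# assumptions):
#
#     codebook = EmbeddingEMA(num_tokens=8192, codebook_dim=32,
#                             kmeans_init=False)
#     ids = torch.randint(0, 8192, (4, 196))
#     vectors = codebook(ids)  # (4, 196, 32); rows have unit L2 norm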


class NormEMAVectorQuantizer(nn.Module):
    """Normed EMA vector quantizer module.

    Args:
        num_embed (int): Number of embedding vectors in the codebook.
        embed_dims (int): The dimension of embedding vectors in the codebook.
        beta (float): The multiplier for the VectorQuantizer embedding loss.
        decay (float): The decay parameter of EMA. Defaults to 0.99.
        statistic_code_usage (bool): Whether to use cluster_size to record
            statistic. Defaults to True.
        kmeans_init (bool): Whether to use k-means to initialize the
            VectorQuantizer. Defaults to True.
        codebook_init_path (str): The initialization checkpoint for codebook.
            Defaults to None.
    """

    def __init__(self,
                 num_embed: int,
                 embed_dims: int,
                 beta: float,
                 decay: float = 0.99,
                 statistic_code_usage: bool = True,
                 kmeans_init: bool = True,
                 codebook_init_path: Optional[str] = None) -> None:
        super().__init__()
        self.codebook_dim = embed_dims
        self.num_tokens = num_embed
        self.beta = beta
        self.decay = decay

        self.embedding = EmbeddingEMA(
            num_tokens=self.num_tokens,
            codebook_dim=self.codebook_dim,
            kmeans_init=kmeans_init,
            codebook_init_path=codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(num_embed))

    def reset_cluster_size(self, device) -> None:
        """Reset the code-usage statistics and move them to ``device``."""
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        """Forward function."""
        # reshape z -> (batch, height, width, channel)
        z = rearrange(z, 'b c h w -> b h w c')
        z = F.normalize(z, p=2, dim=-1)
        z_flattened = z.reshape(-1, self.codebook_dim)

        # lazily initialize the codebook with k-means on the first batch
        # (a no-op once ``initted`` is set)
        self.embedding.init_embed_(z_flattened)

        # squared L2 distance from each z to every codebook vector:
        # ||z||^2 + ||e||^2 - 2 * z @ e^T, with shape (N, num_tokens)
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        # track code usage with EMA at evaluation time as well; guard the
        # statistics buffer, which only exists when statistic_code_usage=True
        if not self.training and self.statistic_code_usage:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                all_reduce(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # update cluster size with EMA
            bins = encodings.sum(0)
            all_reduce(bins)
            if self.statistic_code_usage:
                ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = z_flattened.t() @ encodings
            all_reduce(embed_sum)

            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = F.normalize(embed_normalized, p=2, dim=-1)
            embed_normalized = torch.where(zero_mask[..., None],
                                           self.embedding.weight,
                                           embed_normalized)

            # Update embedding vectors with EMA
            norm_ema_inplace(self.embedding.weight, embed_normalized,
                             self.decay)

        # commitment loss: pull the encoder output towards the (detached)
        # quantized vectors, scaled by beta
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients via the straight-through estimator: the decoder
        # gradient w.r.t. z_q flows to z unchanged
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, encoding_indices
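

# A minimal end-to-end sketch (illustrative only; the feature-map shape and
# hyper-parameters are assumptions, not defaults defined in this file):
#
#     quantizer = NormEMAVectorQuantizer(
#         num_embed=8192, embed_dims=32, beta=1.0, kmeans_init=False)
#     z = torch.randn(2, 32, 14, 14)  # (B, C, H, W) encoder features
#     z_q, loss, indices = quantizer(z)
#     # z_q: (2, 32, 14, 14); loss: scalar; indices: (2 * 14 * 14,)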