"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
https://arxiv.org/abs/2309.15505
"""

import torch
from einops import rearrange
from torch.nn import Module
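
# Each of the `dim` latent scalars is quantized to its sign, so the codebook is
# implicit: it has 2 ** dim entries and needs no embedding table ("lookup free").
# An index is the binary number whose bits are the signs, least-significant bit first.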


# entropy

def binary_entropy(prob):
    return -prob * log(prob) - (1 - prob) * log(1 - prob)
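
# e.g. binary_entropy(torch.tensor(0.5)) ≈ 0.6931 (= ln 2, maximally uncertain)
# and binary_entropy(torch.tensor(0.99)) ≈ 0.056 (nearly committed to one value)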


# tensor helpers

def log(t, eps=1e-20):
    return t.clamp(min=eps).log()


# convert to bit representations and back

def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
    # [b, ...] {0, 1, ..., 2 ** bits - 1} -> [b, ..., d] {-1, 1}
    mask = 2 ** torch.arange(bits).to(x)  # [d], powers of two, least-significant bit first
    bit_values = ((x.unsqueeze(-1) & mask) != 0).float()  # [b, ..., d] {0, 1}
    return bit_values * 2 - 1  # {0, 1} -> {-1, 1}


def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
    # [b, ..., d] {-1, 1} -> [b, ...] {0, 1, ..., max - 1}
    x = (x > 0).long()   # {-1, 1} -> {0, 1}, [b, ..., d]
    mask = 2 ** torch.arange(x.size(-1)).to(x)  # [d]
    dec = (x * mask).sum(-1)  # [b, ...]
    return dec
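
# round-trip example: decimal_to_bits(torch.tensor([5]), bits=4) gives
# [[+1., -1., +1., -1.]] (5 = 0b0101, least-significant bit first), and
# bits_to_decimal of that recovers tensor([5])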


# class

class LFQY(Module):
    def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
        super().__init__()
        self.dim = dim  # number of bits per code; codebook size is 2 ** dim
        self.diversity_gamma = diversity_gamma  # weight on the codebook (diversity) entropy term
        self.entropy_loss_weight = entropy_loss_weight  # overall scale of the entropy aux loss

    def indices_to_codes(self, indices):
        codes = decimal_to_bits(indices, self.dim)
        # codes = rearrange(codes, 'b ... d -> b d ...')
        return codes

    def forward(self, x, mask=None, inv_temperature=1.):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        """
        # x = rearrange(x, 'b d ... -> b ... d')

        assert x.shape[-1] == self.dim, f'expected last dim {self.dim}, got {x.shape[-1]}'
        z = torch.tanh(x * inv_temperature)  # (-1, 1); multiplied, since this is an *inverse* temperature

        # quantize by eq 3.
        quantized = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))  # {-1, 1}; avoids sign(0) == 0
        z = z + (quantized - z).detach()  # straight-through: forward uses `quantized`, gradients flow through tanh

        # calculate indices
        indices = bits_to_decimal(z)

        # entropy aux loss
        if self.training:
            prob = torch.sigmoid(x * inv_temperature)  # [b, ..., d], probability of each bit being +1

            bit_entropy = binary_entropy(prob).sum(-1).mean()
            # E[H(q)] = avg(sum(H(q_i)))

            avg_prob = prob.flatten(0, -2).mean(0)  # [b, ..., d] -> [n, d] -> [d]
            codebook_entropy = binary_entropy(avg_prob).sum()
            # H(E[q]) = sum(H(avg(q_i)))

            """
                1. entropy will be nudged to be low for each bit, 
                so each scalar commits to one latent binary bit or the other.
                2. codebook entropy will be nudged to be high,
                to encourage all codes to be uniformly used.
            """

            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = torch.zeros(1).to(z)

        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight

        # reconstitute image or video dimensions

        # z = rearrange(z, 'b ... d -> b d ...')

        # bits to decimal for the codebook indices
        return z, entropy_aux_loss, indices

    def get_codebook_entry(self, encoding_indices):
        return self.indices_to_codes(encoding_indices)
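

# usage sketch (not part of the module above); shapes and values are arbitrary
# and only meant to show the expected tensor layout

if __name__ == "__main__":
    torch.manual_seed(0)

    quantizer = LFQY(dim=8)    # implicit codebook of 2 ** 8 = 256 codes
    x = torch.randn(2, 16, 8)  # [b, n, d] with d == dim

    z, entropy_aux_loss, indices = quantizer(x)

    print(z.shape)        # torch.Size([2, 16, 8]), entries in {-1, 1}
    print(indices.shape)  # torch.Size([2, 16]), values in [0, 255]
    print(entropy_aux_loss.item())

    # codes recovered from the indices match the quantized output
    assert torch.allclose(quantizer.indices_to_codes(indices), z)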