"""Discrete VAE with a gumbel-softmax bottleneck (DALL-E style building blocks)."""
import math
from collections import OrderedDict
from functools import partial, wraps
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn, Tensor, einsum
def exists(val):
    """Return True when *val* is not None, else False."""
    return val is not None
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
def eval_decorator(fn):
    """Decorator: run `fn` with the model in eval mode, restoring the model's
    previous train/eval state afterwards.

    `fn` must take the model as its first positional argument.  The original
    state is restored in a ``finally`` block so an exception inside `fn`
    cannot leave the model stuck in eval mode.
    """

    @wraps(fn)  # preserve the wrapped function's name/docstring for debugging
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            out = fn(model, *args, **kwargs)
        finally:
            model.train(was_training)
        return out

    return inner
class ResBlock(nn.Module):
    """Residual block: 3x3 conv -> ReLU -> 3x3 conv -> ReLU -> 1x1 conv, plus skip.

    The skip addition assumes `chan_out` is compatible with `chan_in`
    (they are equal everywhere this class is used in this file).
    """

    def __init__(self, chan_in, hidden_size, chan_out):
        super().__init__()
        layers = [
            nn.Conv2d(chan_in, hidden_size, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_size, hidden_size, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_size, chan_out, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = x
        return residual + self.net(x)
class BasicVAE(nn.Module):
    """Abstract interface for VAEs exposing a discrete codebook.

    Subclasses must implement tokenization, token probabilities and decoding;
    the size accessors return None unless overridden.
    """

    def get_codebook_indices(self, images):
        """Map a batch of images to discrete codebook token indices."""
        raise NotImplementedError()

    def decode(self, img_seq):
        """Map a grid of codebook token indices back to images."""
        raise NotImplementedError()

    def get_codebook_probs(self, img_seq):
        """Distribution over codebook tokens for the input."""
        raise NotImplementedError()

    def get_image_tokens_size(self):
        """Number of tokens per image; None unless a subclass overrides."""
        return None

    def get_image_size(self):
        """Expected input image size; None unless a subclass overrides."""
        return None
class DiscreteVAE(BasicVAE):
    """Discrete VAE with a gumbel-softmax bottleneck (as used for DALL-E).

    The encoder maps an image to a grid of logits over a learned codebook;
    tokens are sampled with gumbel-softmax and their embeddings are decoded
    back to pixel space.
    """

    def __init__(
        self,
        # NOTE: tuple, not list — a list default is mutable and shared across calls
        image_size: Tuple[int, int] = (256, 256),  # input image size (H, W)
        codebook_tokens: int = 512,  # codebook vocab size
        codebook_dim: int = 512,  # codebook embedding dimension
        num_layers: int = 3,  # layers of resnet blocks in encoder/decoder
        hidden_dim: int = 64,  # dimension in resnet blocks
        channels: int = 3,  # input channels
        smooth_l1_loss: bool = False,  # prevents exploding gradients
        temperature: float = 0.9,  # tau in gumbel softmax
        straight_through: bool = False,  # if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if it is the soft sample in autograd
        kl_div_loss_weight: float = 0.0,  # weight of the KL(q || uniform) regularizer
    ):
        super().__init__()
        assert num_layers >= 1, "number of layers must be greater than or equal to 1"
        self.image_size = image_size
        self.codebook_tokens = codebook_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook = nn.Embedding(codebook_tokens, codebook_dim)

        encoder_layers = []
        decoder_layers = []
        encoder_in = channels
        decoder_in = codebook_dim
        # Each level halves (encoder) / doubles (decoder) the spatial resolution
        # and is followed by a residual refinement block.
        for _ in range(num_layers):
            encoder_layers.append(
                nn.Sequential(
                    nn.Conv2d(encoder_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU()
                )
            )
            encoder_layers.append(
                ResBlock(chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim)
            )
            encoder_in = hidden_dim
            decoder_layers.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_in, hidden_dim, 4, stride=2, padding=1),
                    nn.ReLU(),
                )
            )
            decoder_layers.append(
                ResBlock(chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim)
            )
            decoder_in = hidden_dim
        # Project encoder features to per-position token logits, and decoder
        # features back to image channels.
        encoder_layers.append(nn.Conv2d(hidden_dim, codebook_tokens, 1))
        decoder_layers.append(nn.Conv2d(hidden_dim, channels, 1))
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

    def get_image_size(self) -> Tuple[int, int]:
        """Return the expected (H, W) of input images."""
        return self.image_size

    def get_image_tokens_size(self) -> int:
        """Number of discrete tokens per image: (H // 2^L) * (W // 2^L)."""
        ds_ratio = 2 ** self.num_layers  # integer; each layer downsamples by 2
        return (self.image_size[0] // ds_ratio) * (self.image_size[1] // ds_ratio)

    def get_codebook_indices(self, images: Tensor) -> Tensor:
        """Hard token ids for `images` via argmax over the logit channel."""
        logits = self.forward(images, return_logits=True)
        return logits.argmax(dim=1)

    def get_codebook_probs(self, images: Tensor) -> Tensor:
        """Softmax distribution over codebook tokens at every position."""
        logits = self.forward(images, return_logits=True)
        return F.softmax(logits, dim=1)

    def decode(self, img_seq: Tensor) -> Tensor:
        """Decode a (b, h', w') grid of token ids back to images."""
        image_embeds = self.codebook(img_seq)  # (b, h', w', d)
        # channels-first layout for the conv decoder
        image_embeds = image_embeds.permute((0, 3, 1, 2)).contiguous()
        return self.decoder(image_embeds)

    def forward(
        self,
        img: Tensor,
        return_loss: bool = False,
        return_recons: bool = False,
        return_logits: bool = False,
        temp=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Encode, sample tokens, decode; optionally compute training losses.

        Returns (in priority order): raw logits if `return_logits`; the
        reconstruction if not `return_loss`; the scalar loss; or
        (loss, reconstruction) if `return_recons`.
        """
        # H is dim -2, W is dim -1 (image_size is (H, W)); the original
        # compared them swapped, which only passed for square images.
        assert (
            img.shape[-2] == self.image_size[0] and img.shape[-1] == self.image_size[1]
        ), f"input must have the correct image size {self.image_size}"

        logits = self.encoder(img)
        if return_logits:
            return logits  # return logits for getting hard image indices for DALL-E training

        temp = default(temp, self.temperature)
        # soft (or straight-through hard) one-hot over the codebook per position
        soft_one_hot = F.gumbel_softmax(
            logits, tau=temp, dim=1, hard=self.straight_through
        )
        # weighted sum of codebook embeddings: (b n h w) x (n d) -> (b d h w)
        sampled = einsum(
            "b n h w, n d -> b d h w", soft_one_hot, self.codebook.weight
        ).contiguous()
        out = self.decoder(sampled)

        if not return_loss:
            return out

        # reconstruction loss
        recon_loss = self.loss_fn(img, out)

        # KL divergence between the token posterior and a uniform prior
        logits = rearrange(logits, "b n h w -> b (h w) n").contiguous()
        qy = F.softmax(logits, dim=-1)
        log_qy = torch.log(qy + 1e-10)  # epsilon guards log(0)
        log_uniform = torch.log(
            torch.tensor([1.0 / self.codebook_tokens], device=img.device)
        )
        kl_div = F.kl_div(log_uniform, log_qy, None, None, "batchmean", log_target=True)

        loss = recon_loss + (kl_div * self.kl_div_loss_weight)
        if not return_recons:
            return loss
        return loss, out
if __name__ == "__main__":
    # Smoke test: one full forward pass on a random 256x256 RGB image.
    # `sample` instead of `input` — avoids shadowing the builtin.
    sample = torch.rand(1, 3, 256, 256)
    model = DiscreteVAE()
    loss, output = model(sample, return_loss=True, return_recons=True)
    print(model)
    print(model.get_image_tokens_size())
    print(model.get_codebook_indices(sample).shape)
    print(loss, output.shape, output.max(), output.min())