import torch
from torch import nn, Tensor, einsum
from typing import Optional, Tuple, Union
import torch.nn.functional as F
from einops import rearrange
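

# small helper utilities for optional values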
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
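

# decorator that runs fn with the model temporarily in eval mode,
# then restores the previous training state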
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
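

# residual block: two 3x3 convs with ReLU plus a 1x1 projection, added to the input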
class ResBlock(nn.Module):
def __init__(self, chan_in, hidden_size, chan_out):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan_in, hidden_size, 3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_size, hidden_size, 3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_size, chan_out, 1),
)
def forward(self, x):
return self.net(x) + x
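

# minimal interface that token-consuming models (e.g. a DALL-E style transformer)
# program against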
class BasicVAE(nn.Module):
def get_codebook_indices(self, images):
raise NotImplementedError()
def decode(self, img_seq):
raise NotImplementedError()
def get_codebook_probs(self, img_seq):
raise NotImplementedError()
def get_image_tokens_size(self):
pass
def get_image_size(self):
pass
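

# discrete VAE: encodes images to a grid of codebook tokens via a Gumbel-softmax
# relaxation and decodes tokens back to pixels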
class DiscreteVAE(BasicVAE):
def __init__(
self,
        image_size: Tuple[int, int] = (256, 256),  # input image size (H, W)
codebook_tokens: int = 512, # codebook vocab size
codebook_dim: int = 512, # codebook embedding dimension
num_layers: int = 3, # layers of resnet blocks in encoder/decoder
hidden_dim: int = 64, # dimension in resnet blocks
channels: int = 3, # input channels
        smooth_l1_loss: bool = False,  # use smooth L1 instead of MSE for reconstruction (less sensitive to outliers)
temperature: float = 0.9, # tau in gumbel softmax
        straight_through: bool = False,  # if True, samples are discretized to one-hot in the forward pass but gradients flow through the soft sample (straight-through estimator)
kl_div_loss_weight: float = 0.0,
):
super().__init__()
assert num_layers >= 1, "number of layers must be greater than or equal to 1"
self.image_size = image_size
self.codebook_tokens = codebook_tokens
self.num_layers = num_layers
self.temperature = temperature
self.straight_through = straight_through
self.codebook = nn.Embedding(codebook_tokens, codebook_dim)
encoder_layers = list()
decoder_layers = list()
encoder_in = channels
decoder_in = codebook_dim
for _ in range(num_layers):
encoder_layers.append(
nn.Sequential(
nn.Conv2d(encoder_in, hidden_dim, 4, stride=2, padding=1), nn.ReLU()
)
)
encoder_layers.append(
ResBlock(
chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim
)
)
encoder_in = hidden_dim
decoder_layers.append(
nn.Sequential(
nn.ConvTranspose2d(decoder_in, hidden_dim, 4, stride=2, padding=1),
nn.ReLU(),
)
)
decoder_layers.append(
ResBlock(
chan_in=hidden_dim, hidden_size=hidden_dim, chan_out=hidden_dim
)
)
decoder_in = hidden_dim
encoder_layers.append(nn.Conv2d(hidden_dim, codebook_tokens, 1))
decoder_layers.append(nn.Conv2d(hidden_dim, channels, 1))
self.encoder = nn.Sequential(*encoder_layers)
self.decoder = nn.Sequential(*decoder_layers)
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
self.kl_div_loss_weight = kl_div_loss_weight
def get_image_size(self):
return self.image_size
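
    # each of the num_layers strided convs halves H and W, so the token grid has
    # (H // 2**num_layers) * (W // 2**num_layers) positions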
def get_image_tokens_size(self) -> int:
        ds_ratio = 2 ** self.num_layers
        return (self.image_size[0] // ds_ratio) * (self.image_size[1] // ds_ratio)
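
    # hard assignment: argmax over the codebook logits at each spatial position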
@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images: Tensor):
logits = self.forward(images, return_logits=True)
codebook_indices = logits.argmax(dim=1)
return codebook_indices
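
    # soft assignment: per-position softmax distribution over the codebook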
@torch.no_grad()
@eval_decorator
def get_codebook_probs(self, images: Tensor):
logits = self.forward(images, return_logits=True)
        return F.softmax(logits, dim=1)
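
    # look up codebook embeddings for a (b, h, w) grid of token indices
    # and decode them back to pixels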
def decode(self, img_seq: Tensor):
image_embeds = self.codebook(img_seq)
image_embeds = image_embeds.permute((0, 3, 1, 2)).contiguous()
images = self.decoder(image_embeds)
return images
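
    # full pass: encode -> Gumbel-softmax relaxed discrete sample -> decode;
    # flags select returning raw logits, the loss, and/or the reconstruction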
def forward(
self,
img: Tensor,
return_loss: bool = False,
return_recons: bool = False,
return_logits: bool = False,
        temp: Optional[float] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        assert (
            img.shape[-2] == self.image_size[0] and img.shape[-1] == self.image_size[1]
        ), f"input must have the correct image size {self.image_size}"
logits = self.encoder(img)
if return_logits:
return logits # return logits for getting hard image indices for DALL-E training
temp = default(temp, self.temperature)
soft_one_hot = F.gumbel_softmax(
logits, tau=temp, dim=1, hard=self.straight_through
)
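        # soft codebook lookup: mix codebook vectors with the relaxed one-hot weights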
sampled = einsum(
"b n h w, n d -> b d h w", soft_one_hot, self.codebook.weight
).contiguous()
out = self.decoder(sampled)
if not return_loss:
return out
# reconstruction loss
recon_loss = self.loss_fn(img, out)
        # KL divergence between the soft code distribution q(y|x) and a uniform prior
logits = rearrange(logits, "b n h w -> b (h w) n").contiguous()
qy = F.softmax(logits, dim=-1)
log_qy = torch.log(qy + 1e-10)
log_uniform = torch.log(
torch.tensor([1.0 / self.codebook_tokens], device=img.device)
)
        kl_div = F.kl_div(log_uniform, log_qy, reduction="batchmean", log_target=True)
loss = recon_loss + (kl_div * self.kl_div_loss_weight)
if not return_recons:
return loss
return loss, out
if __name__ == "__main__":
    # quick smoke test of the model under its default config
    x = torch.rand(1, 3, 256, 256)
    model = DiscreteVAE()
    loss, output = model(x, return_loss=True, return_recons=True)
    print(model)
    print(model.get_image_tokens_size())
    print(model.get_codebook_indices(x).shape)
    print(loss, output.shape, output.max(), output.min())
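
    # minimal round-trip sketch under the defaults above: encode to discrete tokens,
    # then decode; the reconstruction is approximate by construction
    tokens = model.get_codebook_indices(x)
    recon = model.decode(tokens)
    print(tokens.shape, recon.shape)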