diff --git "a/PD_pLMProbXDiff/ModelPack.py" "b/PD_pLMProbXDiff/ModelPack.py" new file mode 100644--- /dev/null +++ "b/PD_pLMProbXDiff/ModelPack.py" @@ -0,0 +1,5578 @@ +######################################################## +## Attention-Diffusion model +######################################################## + +#based on: https://github.com/lucidrains/imagen-pytorch + +import math +import copy +from random import random +from typing import List, Union +from tqdm.auto import tqdm +from functools import partial, wraps +from contextlib import contextmanager, nullcontext +from collections import namedtuple +from pathlib import Path + +import torch +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel +from torch import nn, einsum +from torch.cuda.amp import autocast +from torch.special import expm1 +import torchvision.transforms as T + +import kornia.augmentation as K + +from einops import rearrange, repeat, reduce +from einops.layers.torch import Rearrange, Reduce + +from einops_exts import rearrange_many, repeat_many, check_shape +from einops_exts.torch import EinopsToAndFrom + +# +from tensorflow.keras.preprocessing import text, sequence +from tensorflow.keras.preprocessing.text import Tokenizer +# # -- +# from PD_SpLMxDiff.UtilityPack import prepare_UNet_keys, modify_keys, params +# ++ +from PD_pLMProbXDiff.UtilityPack import ( + prepare_UNet_keys, modify_keys, params +) + +# +from torchinfo import summary +import json +# quick way to identify the device: +device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" +) +print('identify the device independently', device) + +# ================================================== +# helper functions +# ================================================== + +def exists(val): + return val is not None + +def identity(t, *args, **kwargs): + return t + +def first(arr, d = None): + if len(arr) == 0: + return d + return arr[0] + +def maybe(fn): + @wraps(fn) + def inner(x): + if not exists(x): + return x + return fn(x) + return inner + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def cast_tuple(val, length = None): + if isinstance(val, list): + val = tuple(val) + + output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) + + if exists(length): + assert len(output) == length + + return output + +def is_float_dtype(dtype): + return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)]) + +def cast_uint8_images_to_float(images): + if not images.dtype == torch.uint8: + return images + return images / 255 + +def module_device(module): + return next(module.parameters()).device + +def zero_init_(m): + nn.init.zeros_(m.weight) + if exists(m.bias): + nn.init.zeros_(m.bias) + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + return inner + +def pad_tuple_to_length(t, length, fillvalue = None): + remain_length = length - len(t) + if remain_length <= 0: + return t + return (*t, *((fillvalue,) * remain_length)) + +# ================================================== +# helper classes +# ================================================== + +class Identity(nn.Module): + def __init__(self, *args, 
**kwargs): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + +# ================================================== +# tensor helpers +# ================================================== + +def log(t, eps: float = 1e-12): + return torch.log(t.clamp(min = eps)) + +def l2norm(t): + return F.normalize(t, dim = -1) + +def right_pad_dims_to(x, t): + padding_dims = x.ndim - t.ndim + if padding_dims <= 0: + return t + return t.view(*t.shape, *((1,) * padding_dims)) + +def masked_mean(t, *, dim, mask = None): + if not exists(mask): + return t.mean(dim = dim) + + denom = mask.sum(dim = dim, keepdim = True) + mask = rearrange(mask, 'b n -> b n 1') + masked_t = t.masked_fill(~mask, 0.) + + return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) + +def resize_image_to( + image, + target_image_size, + clamp_range = None +): + orig_image_size = image.shape[-1] + + if orig_image_size == target_image_size: + return image + + out = F.interpolate(image.float(), target_image_size, mode = 'linear', align_corners = True) + + return out + +# image normalization functions +# ddpms expect images to be in the range of -1 to 1 +# +def normalize_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_zero_to_one(normed_img): + return (normed_img + 1) * 0.5 + +# classifier free guidance functions +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device = device, dtype = torch.bool) + elif prob == 0: + return torch.zeros(shape, device = device, dtype = torch.bool) + else: + return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob + +# ================================================== +# gaussian diffusion with continuous time helper functions and classes +# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py +# ================================================== + +@torch.jit.script +def beta_linear_log_snr(t): + return -torch.log(expm1(1e-4 + 10 * (t ** 2))) + +@torch.jit.script +def alpha_cosine_log_snr(t, s: float = 0.008): + return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version + +def log_snr_to_alpha_sigma(log_snr): + return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr)) + +class GaussianDiffusionContinuousTimes(nn.Module): + def __init__(self, *, noise_schedule, timesteps = 1000): + super().__init__() + + if noise_schedule == "linear": + self.log_snr = beta_linear_log_snr + elif noise_schedule == "cosine": + self.log_snr = alpha_cosine_log_snr + else: + raise ValueError(f'invalid noise schedule {noise_schedule}') + + self.num_timesteps = timesteps + + def get_times(self, batch_size, noise_level, *, device): + return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32) + + def sample_random_times(self, batch_size, max_thres = 0.999, *, device): + return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres) + + def get_condition(self, times): + return maybe(self.log_snr)(times) + + def get_sampling_timesteps(self, batch, *, device): + times = torch.linspace(1., 0., self.num_timesteps + 1, device = device) + times = repeat(times, 't -> b t', b = batch) + times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) + times = times.unbind(dim = -1) + return times + + def q_posterior(self, x_start, x_t, t, *, t_next = None): + t_next = default(t_next, lambda: (t - 1. 
/ self.num_timesteps).clamp(min = 0.)) + + """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """ + log_snr = self.log_snr(t) + log_snr_next = self.log_snr(t_next) + log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next)) + + alpha, sigma = log_snr_to_alpha_sigma(log_snr) + alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next) + + # c - as defined near eq 33 + c = -expm1(log_snr - log_snr_next) + posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start) + + # following (eq. 33) + posterior_variance = (sigma_next ** 2) * c + posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def q_sample(self, x_start, t, noise = None): + dtype = x_start.dtype + + if isinstance(t, float): + batch = x_start.shape[0] + t = torch.full((batch,), t, device = x_start.device, dtype = dtype) + + noise = default(noise, lambda: torch.randn_like(x_start)) + log_snr = self.log_snr(t).type(dtype) + log_snr_padded_dim = right_pad_dims_to(x_start, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) + + return alpha * x_start + sigma * noise, log_snr + + def q_sample_from_to(self, x_from, from_t, to_t, noise = None): + shape, device, dtype = x_from.shape, x_from.device, x_from.dtype + batch = shape[0] + + if isinstance(from_t, float): + from_t = torch.full((batch,), from_t, device = device, dtype = dtype) + + if isinstance(to_t, float): + to_t = torch.full((batch,), to_t, device = device, dtype = dtype) + + noise = default(noise, lambda: torch.randn_like(x_from)) + + log_snr = self.log_snr(from_t) + log_snr_padded_dim = right_pad_dims_to(x_from, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) + + log_snr_to = self.log_snr(to_t) + log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to) + alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to) + + return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha + + def predict_start_from_noise(self, x_t, t, noise): + log_snr = self.log_snr(t) + log_snr = right_pad_dims_to(x_t, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr) + return (x_t - sigma * noise) / alpha.clamp(min = 1e-8) + +# ================================================== +# norms and residuals +# ================================================== + +class LayerNorm(nn.Module): + def __init__(self, feats, stable = False, dim = -1): + super().__init__() + self.stable = stable + self.dim = dim + + self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1)))) + + def forward(self, x): + dtype, dim = x.dtype, self.dim + + if self.stable: + x = x / x.amax(dim = dim, keepdim = True).detach() + + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = dim, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = dim, keepdim = True) + + return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype) + +ChanLayerNorm = partial(LayerNorm, dim = -2) + +class Always(): + def __init__(self, val): + self.val = val + + def __call__(self, *args, **kwargs): + return self.val + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + +class Parallel(nn.Module): + def __init__(self, *fns): + super().__init__() + self.fns = nn.ModuleList(fns) + + def forward(self, x): + outputs = [fn(x) for fn in self.fns] + return sum(outputs) + 
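+# --------------------------------------------------------------
+# Minimal usage sketch (illustrative only; `scheduler` and `x0` are
+# placeholder names, not part of the module): how the continuous-time
+# helpers above fit together for a toy 1D batch (batch, channels, seq_len).
+#
+#   scheduler = GaussianDiffusionContinuousTimes(noise_schedule = 'cosine', timesteps = 1000)
+#   x0 = torch.randn(2, 1, 64)                                  # clean signal
+#   t  = scheduler.sample_random_times(2, device = x0.device)   # t ~ U(0, 0.999)
+#   noise = torch.randn_like(x0)
+#   x_t, log_snr = scheduler.q_sample(x0, t, noise = noise)     # alpha * x0 + sigma * noise
+#   x0_hat = scheduler.predict_start_from_noise(x_t, t, noise)  # recovers x0 (same noise)
+#   alpha, sigma = log_snr_to_alpha_sigma(log_snr)              # alpha**2 + sigma**2 == 1
+# --------------------------------------------------------------
+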
+# ================================================== +# attention pooling +# ================================================== +# modified +class PerceiverAttention(nn.Module): + def __init__( + self, + *, + dim, + dim_head = 64, + heads = 8, + cosine_sim_attn = False + ): + super().__init__() + self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1 + self.cosine_sim_attn = cosine_sim_attn + self.cosine_sim_scale = 16 if cosine_sim_attn else 1 + + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim, bias = False), + nn.LayerNorm(dim) + ) + + def forward(self, x, latents, mask = None): + x = self.norm(x) + latents = self.norm_latents(latents) + + b, h = x.shape[0], self.heads + + q = self.to_q(latents) + + # the paper differs from Perceiver in which they also concat the key / values + # derived from the latents to be attended to + kv_input = torch.cat((x, latents), dim = -2) + k, v = self.to_kv(kv_input).chunk(2, dim = -1) + + q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h) + + q = q * self.scale + + # cosine sim attention + + if self.cosine_sim_attn: + q, k = map(l2norm, (q, k)) + + # similarities and masking + + sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale + + if exists(mask): + max_neg_value = -torch.finfo(sim.dtype).max + mask = F.pad(mask, (0, latents.shape[-2]), value = True) + + mask = rearrange(mask, 'b j -> b 1 1 j') + sim = sim.masked_fill(~mask, max_neg_value) + + # attention + + attn = sim.softmax(dim = -1, dtype = torch.float32) + attn = attn.to(sim.dtype) + + out = einsum('... i j, ... j d -> ... 
i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)', h = h) + return self.to_out(out) + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth, + dim_head = 64, + heads = 8, + num_latents = 64, + num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence + max_seq_len = 512, + ff_mult = 4, + cosine_sim_attn = False + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, dim) + + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + + self.to_latents_from_mean_pooled_seq = None + + if num_latents_mean_pooled > 0: + self.to_latents_from_mean_pooled_seq = nn.Sequential( + LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled) + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn), + FeedForward(dim = dim, mult = ff_mult) + ])) + + def forward(self, x, mask = None): + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device = device)) + + x_with_pos = x + pos_emb + + latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0]) + + if exists(self.to_latents_from_mean_pooled_seq): + meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim = -2) + + for attn, ff in self.layers: + latents = attn(x_with_pos, latents, mask = mask) + latents + latents = ff(latents) + latents + + return latents + +# ==================================================== +# attention +# ==================================================== +class Attention(nn.Module): + def __init__( + self, + dim, + *, + dim_head = 64, + heads = 8, + context_dim = None, + cosine_sim_attn = False + ): + super().__init__() + self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. 
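+        # with `cosine_sim_attn` enabled the usual 1/sqrt(dim_head) scaling is
+        # dropped (scale = 1); q and k are l2-normalized in forward() and the
+        # similarities are instead multiplied by the fixed `cosine_sim_scale`
+        # of 16, i.e. sim = 16 * cos(q, k)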
+ self.cosine_sim_attn = cosine_sim_attn + self.cosine_sim_scale = 16 if cosine_sim_attn else 1 + + self.heads = heads + inner_dim = dim_head * heads + + self.norm = LayerNorm(dim) + + self.null_kv = nn.Parameter(torch.randn(2, dim_head)) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) + + self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim, bias = False), + LayerNorm(dim) + ) + + def forward(self, x, context = None, mask = None, attn_bias = None): + b, n, device = *x.shape[:2], x.device + + x = self.norm(x) + + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) + + q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) + q = q * self.scale + + # add null key / value for classifier free guidance in prior net + + nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b) + k = torch.cat((nk, k), dim = -2) + v = torch.cat((nv, v), dim = -2) + + # add text conditioning, if present + + if exists(context): + assert exists(self.to_context) + ck, cv = self.to_context(context).chunk(2, dim = -1) + k = torch.cat((ck, k), dim = -2) + v = torch.cat((cv, v), dim = -2) + + # cosine sim attention + + if self.cosine_sim_attn: + q, k = map(l2norm, (q, k)) + + # calculate query / key similarities + + sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale + + # relative positional encoding (T5 style) + + if exists(attn_bias): + sim = sim + attn_bias + + # masking + + max_neg_value = -torch.finfo(sim.dtype).max + + if exists(mask): + mask = F.pad(mask, (1, 0), value = True) + + mask = rearrange(mask, 'b j -> b 1 j') + sim = sim.masked_fill(~mask, max_neg_value) + + # attention + + attn = sim.softmax(dim = -1, dtype = torch.float32) + attn = attn.to(sim.dtype) + + # aggregate values + + out = einsum('b h i j, b j d -> b h i d', attn, v) + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +# ============================================================ +# decoder +# ============================================================ +def Upsample(dim, dim_out = None): + dim_out = default(dim_out, dim) + + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv1d(dim, dim_out, 3, padding = 1) + ) + +class PixelShuffleUpsample(nn.Module): + """ + code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts + https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf + """ + def __init__(self, dim, dim_out = None): + super().__init__() + dim_out = default(dim_out, dim) + conv = nn.Conv1d(dim, dim_out * 4, 1) + + self.net = nn.Sequential( + conv, + nn.SiLU(), + nn.PixelShuffle(2) + ) + + self.init_conv_(conv) + + def init_conv_(self, conv): + + o, i, h = conv.weight.shape + conv_weight = torch.empty(o // 4, i, h ) + nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, 'o ... 
-> (o 4) ...') + + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def forward(self, x): + return self.net(x) + +def Downsample(dim, dim_out = None): + # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample + # named SP-conv in the paper, but basically a pixel unshuffle + dim_out = default(dim_out, dim) + + return nn.Sequential( + + Rearrange('b c (h s1) -> b (c s1) h', s1 = 2), + nn.Conv1d(dim * 2, dim_out, 1) + ) + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb) + emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') + return torch.cat((emb.sin(), emb.cos()), dim = -1) + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +class Block(nn.Module): + def __init__( + self, + dim, + dim_out, + groups = 8, + norm = True + ): + super().__init__() + self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity() + self.activation = nn.SiLU() + self.project = nn.Conv1d(dim, dim_out, 3, padding = 1) + + def forward(self, x, scale_shift = None): + x = self.groupnorm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.activation(x) + return self.project(x) + +class ResnetBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + *, + cond_dim = None, + time_cond_dim = None, + groups = 8, + linear_attn = False, + use_gca = False, + squeeze_excite = False, + **attn_kwargs + ): + super().__init__() + + self.time_mlp = None + + if exists(time_cond_dim): + self.time_mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_cond_dim, dim_out * 2) + ) + + self.cross_attn = None + + if exists(cond_dim): + attn_klass = CrossAttention if not linear_attn else LinearCrossAttention + + self.cross_attn = EinopsToAndFrom( + + 'b c h ', + 'b h c', + attn_klass( + dim = dim_out, + context_dim = cond_dim, + **attn_kwargs + ) + ) + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + + self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1) + + self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else Identity() + + + def forward(self, x, time_emb = None, cond = None): + + scale_shift = None + if exists(self.time_mlp) and exists(time_emb): + time_emb = self.time_mlp(time_emb) + + time_emb = rearrange(time_emb, 'b c -> b c 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x) + + if exists(self.cross_attn): + assert exists(cond) + h = self.cross_attn(h, context = cond) + h + + h = self.block2(h, scale_shift = scale_shift) + + h = h * self.gca(h) + + return h + self.res_conv(x) + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + *, + context_dim = None, + dim_head = 64, + heads = 8, + norm_context = False, + 
cosine_sim_attn = False + ): + super().__init__() + self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. + self.cosine_sim_attn = cosine_sim_attn + self.cosine_sim_scale = 16 if cosine_sim_attn else 1 + + self.heads = heads + inner_dim = dim_head * heads + + context_dim = default(context_dim, dim) + + self.norm = LayerNorm(dim) + self.norm_context = LayerNorm(context_dim) if norm_context else Identity() + + self.null_kv = nn.Parameter(torch.randn(2, dim_head)) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim, bias = False), + LayerNorm(dim) + ) + + def forward(self, x, context, mask = None): + b, n, device = *x.shape[:2], x.device + + x = self.norm(x) + context = self.norm_context(context) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) + + q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads) + + # add null key / value for classifier free guidance in prior net + + nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b) + + k = torch.cat((nk, k), dim = -2) + v = torch.cat((nv, v), dim = -2) + + q = q * self.scale + + # cosine sim attention + + if self.cosine_sim_attn: + q, k = map(l2norm, (q, k)) + + # similarities + + sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale + + # masking + + max_neg_value = -torch.finfo(sim.dtype).max + + if exists(mask): + mask = F.pad(mask, (1, 0), value = True) + + mask = rearrange(mask, 'b j -> b 1 j') + sim = sim.masked_fill(~mask, max_neg_value) + + attn = sim.softmax(dim = -1, dtype = torch.float32) + attn = attn.to(sim.dtype) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class LinearCrossAttention(CrossAttention): + def forward(self, x, context, mask = None): + b, n, device = *x.shape[:2], x.device + + x = self.norm(x) + context = self.norm_context(context) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) + + q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads) + + # add null key / value for classifier free guidance in prior net + + nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b) + + k = torch.cat((nk, k), dim = -2) + v = torch.cat((nv, v), dim = -2) + + # masking + + max_neg_value = -torch.finfo(x.dtype).max + + if exists(mask): + mask = F.pad(mask, (1, 0), value = True) + mask = rearrange(mask, 'b n -> b n 1') + k = k.masked_fill(~mask, max_neg_value) + v = v.masked_fill(~mask, 0.) 
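+        # the steps below are the kernel trick of linear attention: q is
+        # softmaxed over the feature dimension and k over the sequence
+        # dimension, then a small (d x e) context matrix k^T v is formed once
+        # and applied to every query, giving O(n * d * e) cost instead of the
+        # O(n^2 * d) of materializing a full (n x n) similarity matrix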
+ + # linear attention + + q = q.softmax(dim = -1) + k = k.softmax(dim = -2) + + q = q * self.scale + + context = einsum('b n d, b n e -> b d e', k, v) + out = einsum('b n d, b d e -> b n e', q, context) + out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) + return self.to_out(out) + +class LinearAttention(nn.Module): + def __init__( + self, + dim, + dim_head = 32, + heads = 8, + dropout = 0.05, + context_dim = None, + **kwargs + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + inner_dim = dim_head * heads + self.norm = ChanLayerNorm(dim) + + self.nonlin = nn.SiLU() + + self.to_q = nn.Sequential( + nn.Dropout(dropout), + nn.Conv1d(dim, inner_dim, 1, bias = False), + nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) + ) + + self.to_k = nn.Sequential( + nn.Dropout(dropout), + nn.Conv1d(dim, inner_dim, 1, bias = False), + nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) + ) + + self.to_v = nn.Sequential( + nn.Dropout(dropout), + nn.Conv1d(dim, inner_dim, 1, bias = False), + nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) + ) + + self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None + + self.to_out = nn.Sequential( + nn.Conv1d(inner_dim, dim, 1, bias = False), + ChanLayerNorm(dim) + ) + + def forward(self, fmap, context = None): + h, x, y = self.heads, *fmap.shape[-2:] + + fmap = self.norm(fmap) + q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v)) + q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h) + + if exists(context): + assert exists(self.to_context) + ck, cv = self.to_context(context).chunk(2, dim = -1) + ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h) + k = torch.cat((k, ck), dim = -2) + v = torch.cat((v, cv), dim = -2) + + q = q.softmax(dim = -1) + k = k.softmax(dim = -2) + + q = q * self.scale + + context = einsum('b n d, b n e -> b d e', k, v) + out = einsum('b n d, b d e -> b n e', q, context) + out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y) + + out = self.nonlin(out) + return self.to_out(out) + +class GlobalContext(nn.Module): + """ basically a superior form of squeeze-excitation that is attention-esque """ + + def __init__( + self, + *, + dim_in, + dim_out + ): + super().__init__() + self.to_k = nn.Conv1d(dim_in, 1, 1) + hidden_dim = max(3, dim_out // 2) + + self.net = nn.Sequential( + nn.Conv1d(dim_in, hidden_dim, 1), + nn.SiLU(), + nn.Conv1d(hidden_dim, dim_out, 1), + nn.Sigmoid() + ) + + def forward(self, x): + context = self.to_k(x) + x, context = rearrange_many((x, context), 'b n ... 
-> b n (...)') + out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x) + + return self.net(out) + +def FeedForward(dim, mult = 2): + hidden_dim = int(dim * mult) + return nn.Sequential( + LayerNorm(dim), + nn.Linear(dim, hidden_dim, bias = False), + nn.GELU(), + LayerNorm(hidden_dim), + nn.Linear(hidden_dim, dim, bias = False) + ) + +def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width + hidden_dim = int(dim * mult) + return nn.Sequential( + ChanLayerNorm(dim), + nn.Conv1d(dim, hidden_dim, 1, bias = False), + nn.GELU(), + ChanLayerNorm(hidden_dim), + nn.Conv1d(hidden_dim, dim, 1, bias = False) + ) + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + *, + depth = 1, + heads = 8, + dim_head = 32, + ff_mult = 2, + context_dim = None, + cosine_sim_attn = False + ): + super().__init__() + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append(nn.ModuleList([ + EinopsToAndFrom('b c h', 'b h c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)), + ChanFeedForward(dim = dim, mult = ff_mult) + ])) + + def forward(self, x, context = None): + for attn, ff in self.layers: + x = attn(x, context = context) + x + x = ff(x) + x + return x + +class LinearAttentionTransformerBlock(nn.Module): + def __init__( + self, + dim, + *, + depth = 1, + heads = 8, + dim_head = 32, + ff_mult = 2, + context_dim = None, + **kwargs + ): + super().__init__() + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim), + ChanFeedForward(dim = dim, mult = ff_mult) + ])) + + def forward(self, x, context = None): + for attn, ff in self.layers: + x = attn(x, context = context) + x + x = ff(x) + x + return x + +class CrossEmbedLayer(nn.Module): + def __init__( + self, + dim_in, + kernel_sizes, + dim_out = None, + stride = 2 + ): + super().__init__() + assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) + dim_out = default(dim_out, dim_in) + + kernel_sizes = sorted(kernel_sizes) + num_scales = len(kernel_sizes) + + # calculate the dimension at each scale + dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] + dim_scales = [*dim_scales, dim_out - sum(dim_scales)] + + self.convs = nn.ModuleList([]) + for kernel, dim_scale in zip(kernel_sizes, dim_scales): + self.convs.append(nn.Conv1d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) + + def forward(self, x): + fmaps = tuple(map(lambda conv: conv(x), self.convs)) + return torch.cat(fmaps, dim = 1) + +class UpsampleCombiner(nn.Module): + def __init__( + self, + dim, + *, + enabled = False, + dim_ins = tuple(), + dim_outs = tuple() + ): + super().__init__() + dim_outs = cast_tuple(dim_outs, len(dim_ins)) + assert len(dim_ins) == len(dim_outs) + + self.enabled = enabled + + if not self.enabled: + self.dim_out = dim + return + + self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]) + self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0) + + def forward(self, x, fmaps = None): + target_size = x.shape[-1] + + fmaps = default(fmaps, tuple()) + + if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0: + return x + + fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps] + outs = [conv(fmap) for fmap, conv in zip(fmaps, 
self.fmap_convs)] + return torch.cat((x, *outs), dim = 1) + +######################################################## +## 1D Unet updated +######################################################## +class OneD_Unet(nn.Module): + def __init__( + self, + *, + # +++++++++++++++++++++++++++ + CKeys=None, + PKeys=None, + ): + super().__init__() + + # save locals to take care of some hyperparameters for cascading DDPM + + self._locals = locals() + self._locals.pop('self', None) + self._locals.pop('__class__', None) + + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # unload the parameters + # print('I am here') + if CKeys['Debug_ModelPack']==1: + print(json.dumps(PKeys, indent=4)) + + dim = PKeys['dim'] # , + # image_embed_dim = 1024, # NOT used so far + text_embed_dim = default(PKeys['text_embed_dim'], 768) # 768, # get_encoded_dim(DEFAULT_T5_NAME) + num_resnet_blocks = default(PKeys['num_resnet_blocks'], 1)# 1, + cond_dim = default(PKeys['cond_dim'], None) # None, + num_image_tokens = default(PKeys['num_image_tokens'], 4) # 4, + num_time_tokens = default(PKeys['num_time_tokens'], 2) # 2, + learned_sinu_pos_emb_dim = default(PKeys['learned_sinu_pos_emb_dim'], 16) # 16, + out_dim = default(PKeys['out_dim'], None) # None, + dim_mults = default(PKeys['dim_mults'], (1, 2, 4, 8)) # (1, 2, 4, 8), + + cond_images_channels = default(PKeys['cond_images_channels'], 0) # 0, + channels = default(PKeys['channels'], 3) # 3, + channels_out = default(PKeys['channels_out'], None) # None, + + attn_dim_head = default(PKeys['attn_dim_head'], 64) # 64, + attn_heads = default(PKeys['attn_heads'], 8) # 8, + ff_mult = default(PKeys['ff_mult'], 2.) # 2., + lowres_cond = default(PKeys['lowres_cond'], False) # False, # for cascading diffusion - https://cascaded-diffusion.github.io/ + layer_attns = default(PKeys['layer_attns'], True) # True, + layer_attns_depth = default(PKeys['layer_attns_depth'], 1) # 1, + layer_attns_add_text_cond = default(PKeys['layer_attns_add_text_cond'], True) # True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 + attend_at_middle = default(PKeys['attend_at_middle'], True) # True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) + layer_cross_attns = default(PKeys['layer_cross_attns'], True) # True, + use_linear_attn = default(PKeys['use_linear_attn'], False) # False, + use_linear_cross_attn = default(PKeys['use_linear_cross_attn'], False) # False, + + cond_on_text = default(PKeys['cond_on_text'], True) # True, + max_text_len = default(PKeys['max_text_len'], 256) # 256, + init_dim = default(PKeys['init_dim'], None) # None, + resnet_groups = default(PKeys['resnet_groups'], 8) # 8, + init_conv_kernel_size = default(PKeys['init_conv_kernel_size'], 7) # 7, # kernel size of initial conv, if not using cross embed + init_cross_embed = default(PKeys['init_cross_embed'], False) # False, + init_cross_embed_kernel_sizes = default(PKeys['init_cross_embed_kernel_sizes'], (3, 7, 15)) # (3, 7, 15), + cross_embed_downsample = default(PKeys['cross_embed_downsample'], False) # False, + cross_embed_downsample_kernel_sizes = default(PKeys['cross_embed_downsample_kernel_sizes'], (2,4)) # (2, 4), + + + attn_pool_text = default(PKeys['attn_pool_text'], True) # True, + attn_pool_num_latents = default(PKeys['attn_pool_num_latents'], 32) # 32, + dropout = default(PKeys['dropout'], 0.) 
# 0., + memory_efficient = default(PKeys['memory_efficient'], False) # False, + init_conv_to_final_conv_residual = default(PKeys['init_conv_to_final_conv_residual'], False) # False, + + + use_global_context_attn = default(PKeys['use_global_context_attn'], True) # True, + scale_skip_connection = default(PKeys['scale_skip_connection'], True) # True, + final_resnet_block = default(PKeys['final_resnet_block'], True) # True, + final_conv_kernel_size = default(PKeys['final_conv_kernel_size'], 3) # 3, + + + cosine_sim_attn = default(PKeys['cosine_sim_attn'], False) # False, + self_cond = default(PKeys['self_cond'], False) # False, + combine_upsample_fmaps = default(PKeys['combine_upsample_fmaps'], False) # False, # combine feature maps from all upsample blocks, used in unet squared successfully + pixel_shuffle_upsample = default(PKeys['pixel_shuffle_upsample'], False) # False , # may address checkboard artifacts + beginning_and_final_conv_present = default(PKeys['beginning_and_final_conv_present'], True) # True , #TODO add cross-attn, doesnt work yet...whether or not to have final conv layer + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + self.CKeys=CKeys + + # + for debug + if CKeys['Debug_ModelPack']==1: + print("Check the inputs:") + print(json.dumps(PKeys, indent=4)) + + + + assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' + + if dim < 128: + print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') + + + + + # determine dimensions + + self.channels = channels + self.channels_out = default(channels_out, channels) + + # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis + # (2) in self conditioning, one appends the predict x0 (x_start) + init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) + init_dim = default(init_dim, dim) + + self.self_cond = self_cond + + # optional image conditioning + + self.has_cond_image = cond_images_channels > 0 + self.cond_images_channels = cond_images_channels + + init_channels += cond_images_channels + + + self.beginning_and_final_conv_present=beginning_and_final_conv_present + + + if self.beginning_and_final_conv_present: + self.init_conv = CrossEmbedLayer( + init_channels, dim_out = init_dim, + kernel_sizes = init_cross_embed_kernel_sizes, + stride = 1 + ) if init_cross_embed else nn.Conv1d( + init_channels, init_dim, + init_conv_kernel_size, + padding = init_conv_kernel_size // 2) + + if self.CKeys['Debug_ModelPack']==1 and self.beginning_and_final_conv_present: + print("On self.init_conv:") + print(f"init_channels: {str(init_channels)}\n init_dim: {str(init_dim)}") + print("On self.init_conv, batch#=1, init_ch=2*#seq_channel, seq_len=128") + summary(self.init_conv, (1, init_channels, 128), verbose=1) + # # + for debug + # if self.CKeys['Debug_ModelPack']==1: + # print("On self.init_conv") + # summary(self.init_conv, (1, init_channels, init_dim)) + + + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + cond_dim = default(cond_dim, dim) + time_cond_dim = dim * 4 * (2 if lowres_cond else 1) + + # embedding time for log(snr) noise from continuous version + + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) + sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 + 
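+        # shape sketch for the time-conditioning pathway built below
+        # (illustrative values: batch = 8, dim = 128, cond_dim = dim,
+        # num_time_tokens = 2, learned_sinu_pos_emb_dim = 16, lowres_cond = False):
+        #   time                       : (8,)
+        #   sinu_pos_emb(time)         : (8, 17)      # 16 learned Fourier features + raw t
+        #   to_time_hiddens(...)       : (8, 512)     # time_cond_dim = 4 * dim
+        #   to_time_cond(...)     -> t : (8, 512)     # fed to each ResnetBlock's time_mlp
+        #   to_time_tokens(...)        : (8, 2, 128)  # attended to alongside the text tokens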
+ + self.to_time_hiddens = nn.Sequential( + sinu_pos_emb, + nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), + nn.SiLU() + ) + # # + for debug + # if self.CKeys['Debug_ModelPack']==1: + # print("On self.to_time_hiddens") + # print(" in dim: ", learned_sinu_pos_emb_dim) + # print(" ou dim: ", time_cond_dim) + # # summary(self.to_time_hiddens, (1,learned_sinu_pos_emb_dim), verbose=1) + # # summary(sinu_pos_emb, (1,learned_sinu_pos_emb_dim), verbose=1) + + + self.to_time_cond = nn.Sequential( + nn.Linear(time_cond_dim, time_cond_dim) + ) + + # project to time tokens as well as time hiddens + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print(" to_time_tokens") + print(" in dim: ", time_cond_dim) # =4xdim cond_dim=dim + print(" ou dim: ", num_time_tokens) + self.to_time_tokens = nn.Sequential( + nn.Linear(time_cond_dim, cond_dim * num_time_tokens), + Rearrange('b (r d) -> b r d', r = num_time_tokens) + ) + + # low res aug noise conditioning + + self.lowres_cond = lowres_cond + + if lowres_cond: + self.to_lowres_time_hiddens = nn.Sequential( + LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), + nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), + nn.SiLU() + ) + + self.to_lowres_time_cond = nn.Sequential( + nn.Linear(time_cond_dim, time_cond_dim) + ) + + self.to_lowres_time_tokens = nn.Sequential( + nn.Linear(time_cond_dim, cond_dim * num_time_tokens), + Rearrange('b (r d) -> b r d', r = num_time_tokens) + ) + + # normalizations + + self.norm_cond = nn.LayerNorm(cond_dim) + + # text encoding conditioning (optional) + + self.text_to_cond = None + + if cond_on_text: #only add linear lear if cond dim is not text emnd dim + assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' + if text_embed_dim != cond_dim: + self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) + self.text_cond_linear=True + + else: + print ("Text conditioning is equatl to cond_dim - no linear layer used") + self.text_cond_linear=False + + # finer control over whether to condition on text encodings + + self.cond_on_text = cond_on_text + + # attention pooling + + self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, + dim_head = attn_dim_head, heads = attn_heads, + num_latents = attn_pool_num_latents, + cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None + + # for classifier free guidance + + self.max_text_len = max_text_len + + self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) + self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) + + # for non-attention based text conditioning at all points in the network where time is also conditioned + + self.to_text_non_attn_cond = None + + if cond_on_text: + self.to_text_non_attn_cond = nn.Sequential( + nn.LayerNorm(cond_dim), + nn.Linear(cond_dim, time_cond_dim), + nn.SiLU(), + nn.Linear(time_cond_dim, time_cond_dim) + ) + + # attention related params + + attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) + + num_layers = len(in_out) + + # resnet block klass + + num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) + resnet_groups = cast_tuple(resnet_groups, num_layers) + + resnet_klass = partial(ResnetBlock, **attn_kwargs) + + layer_attns = cast_tuple(layer_attns, num_layers) + layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) + layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) + + use_linear_attn = cast_tuple(use_linear_attn, num_layers) + use_linear_cross_attn = 
cast_tuple(use_linear_cross_attn, num_layers) + + assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) + + # downsample klass + + downsample_klass = Downsample + + if cross_embed_downsample: + downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) + + # initial resnet block (for memory efficient unet) + + self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None + + # scale for resnet skip connections + + self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] + reversed_layer_params = list(map(reversed, layer_params)) + + # downsampling layers + + skip_connect_dims = [] # keep track of skip connection dimensions + + for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): + is_last = ind >= (num_resolutions - 1) + + layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None + + if layer_attn: + transformer_block_klass = TransformerBlock + elif layer_use_linear_attn: + transformer_block_klass = LinearAttentionTransformerBlock + else: + transformer_block_klass = Identity + + current_dim = dim_in + + # whether to pre-downsample, from memory efficient unet + + pre_downsample = None + + if memory_efficient: + pre_downsample = downsample_klass(dim_in, dim_out) + current_dim = dim_out + + skip_connect_dims.append(current_dim) + + # whether to do post-downsample, for non-memory efficient unet + + post_downsample = None + if not memory_efficient: + post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1)) + + self.downs.append( + nn.ModuleList([ + pre_downsample, + resnet_klass( + current_dim, current_dim, + cond_dim = layer_cond_dim, + linear_attn = layer_use_linear_cross_attn, + time_cond_dim = time_cond_dim, groups = groups + ), + nn.ModuleList([ + ResnetBlock( + current_dim, current_dim, + time_cond_dim = time_cond_dim, + groups = groups, + use_gca = use_global_context_attn + ) for _ in range(layer_num_resnet_blocks)]), + transformer_block_klass( + dim = current_dim, + depth = layer_attn_depth, + ff_mult = ff_mult, + context_dim = cond_dim, + **attn_kwargs), + post_downsample + ]) + ) + + # middle layers + + mid_dim = dims[-1] + + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) + self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) + + # upsample klass + + upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample + + # upsampling layers + + upsample_fmap_dims = [] + + for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in 
enumerate(zip(reversed(in_out), *reversed_layer_params)): + is_last = ind == (len(in_out) - 1) + + layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None + + if layer_attn: + transformer_block_klass = TransformerBlock + elif layer_use_linear_attn: + transformer_block_klass = LinearAttentionTransformerBlock + else: + transformer_block_klass = Identity + + skip_connect_dim = skip_connect_dims.pop() + + upsample_fmap_dims.append(dim_out) + + self.ups.append(nn.ModuleList([ + resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), + nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), + transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), + upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() + ])) + + # whether to combine feature maps from all upsample blocks before final resnet block out + + self.upsample_combiner = UpsampleCombiner( + dim = dim, + enabled = combine_upsample_fmaps, + dim_ins = upsample_fmap_dims, + dim_outs = dim + ) + + # whether to do a final residual from initial conv to the final resnet block out + + self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual + final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) + + # final optional resnet block and convolution out + + self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None + + final_conv_dim_in = dim if final_resnet_block else final_conv_dim + final_conv_dim_in += (channels if lowres_cond else 0) + + if self.beginning_and_final_conv_present: + print (final_conv_dim_in, self.channels_out) + self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) + + if self.beginning_and_final_conv_present: + zero_init_(self.final_conv) + + # if the current settings for the unet are not correct + # for cascading DDPM, then reinit the unet with the right settings + def cast_model_parameters( + self, + *, + lowres_cond, + text_embed_dim, + channels, + channels_out, + cond_on_text + ): + # # ---------------------------------------------- + # if lowres_cond == self.lowres_cond and \ + # channels == self.channels and \ + # cond_on_text == self.cond_on_text and \ + # text_embed_dim == self._locals['text_embed_dim'] and \ + # channels_out == self.channels_out: + # return self + # +++++++++++++++++++++++++++++++++++++++++++++++ + if lowres_cond == self.lowres_cond and \ + channels == self.channels and \ + cond_on_text == self.cond_on_text and \ + text_embed_dim == self._locals['PKeys']['text_embed_dim'] and \ + channels_out == self.channels_out: + return self + + # # -------------------------------------- + # updated_kwargs = dict( + # lowres_cond = lowres_cond, + # text_embed_dim = text_embed_dim, + # channels = channels, + # channels_out = channels_out, + # cond_on_text = cond_on_text + # ) + # ++++++++++++++++++++++++++++++++++++++ + write_key=dict( + lowres_cond = lowres_cond, + text_embed_dim = text_embed_dim, + channels = channels, + channels_out = channels_out, + cond_on_text = cond_on_text + ) + # 
this_PKeys=prepare_UNet_keys(write_key) + old_PKeys=self._locals['PKeys'] + this_PKeys=modify_keys(old_PKeys, write_key) + # write the new arguments + updated_kwargs = dict( + CKeys=self.CKeys, + PKeys=this_PKeys, + ) + + return self.__class__(**{**self._locals, **updated_kwargs}) + + # methods for returning the full unet config as well as its parameter state + + def to_config_and_state_dict(self): + return self._locals, self.state_dict() + + # class method for rehydrating the unet from its config and state dict + + @classmethod + def from_config_and_state_dict(klass, config, state_dict): + unet = klass(**config) + unet.load_state_dict(state_dict) + return unet + + # methods for persisting unet to disk + + def persist_to_file(self, path): + path = Path(path) + path.parents[0].mkdir(exist_ok = True, parents = True) + + config, state_dict = self.to_config_and_state_dict() + pkg = dict(config = config, state_dict = state_dict) + torch.save(pkg, str(path)) + + # class method for rehydrating the unet from file saved with `persist_to_file` + + @classmethod + def hydrate_from_file(klass, path): + path = Path(path) + assert path.exists() + pkg = torch.load(str(path)) + + assert 'config' in pkg and 'state_dict' in pkg + config, state_dict = pkg['config'], pkg['state_dict'] + + return Unet.from_config_and_state_dict(config, state_dict) + + # forward with classifier free guidance + + def forward_with_cond_scale( + self, + *args, + cond_scale = 1., + **kwargs + ): + logits = self.forward(*args, **kwargs) + + if cond_scale == 1: + return logits + + null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + # forward fun returns the loss + # Here, for Model B: + # x : (batch, 1, seq_len) values: from noised image/Y + # time : (batch) values: + # text_embeds : None + # text_mask: None + # cond_images: input sequence/X + # lowres_noise_times: None + # lowres_cond_img: None + # cond_drop_prob: 0.1 + # self_cond: from noised image/Y + + def forward( + self, + x, + time, + *, + lowres_cond_img = None, + lowres_noise_times = None, + text_embeds = None, + text_mask = None, + cond_images = None, + self_cond = None, + cond_drop_prob = 0. 
+ ): + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("========================================") + print("Here are OneD_Unet:forward") + ii=0 + + batch_size, device = x.shape[0], x.device + # condition on self + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("Check inputs: ") + print(" x-0 .dim: ", x.shape, f"[batch, {str(self.channels)}, seq_len]") + print(" time .dim: ", time.shape, "batch") + if lowres_cond_img==None: + print(" lowres_cond_img: None") + else: + print(" lowres_cond_img.dim: ", lowres_cond_img.shape) + if lowres_noise_times==None: + print(" lowres_noise_times: None") + else: + print(" lowres_noise_times.dim: ", lowres_noise_times.shape) + if cond_images==None: + print(" cond_images dim: None") + else: + # Model B is used + print(" cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, seq_len]") + if text_embeds==None: + print(" text_embeds.dim: None") + else: + # Model A is used + print(" text_embeds.dim: ", text_embeds.shape) + if self_cond==None: + print(" self_cond: None") + else: + print(" self_cond.dim: ", self_cond.shape) + print("\n\n") + + if self.self_cond: + self_cond = default(self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x, self_cond), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("self_cond dim: ", self_cond.shape) + print("After cat(x, self_cond)-> x. dim: ", x.shape) + + # add low resolution conditioning, if present + + assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' + assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' + + if exists(lowres_cond_img): + x = torch.cat((x, lowres_cond_img), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("lowres_cond_img dim: ", lowres_cond_img.shape) + print("After cat(x, lowres_cond_img) dim: ", x.shape) + + # condition on input image + + assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' + + if exists(cond_images): + assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' + cond_images = resize_image_to(cond_images, x.shape[-1]) + + # ++++++++++++++++++++++++++++++ + # add cond_images (from X) into generation process... 
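+            # e.g. (illustrative, with channels = 1, cond_images_channels = 1,
+            # and self-/lowres-conditioning disabled):
+            #   x           : (batch, 1, seq_len)  noised target sequence
+            #   cond_images : (batch, 1, seq_len)  conditioning sequence, resized above
+            #   cat(...)    : (batch, 2, seq_len)  channel-wise stack consumed by init_conv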
+ x = torch.cat((cond_images.to(device), x.to(device)), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, max_seq_len]") + print("After cat(cond_images, x), x dim: ", x.shape, "[batch, cond_images_channels+images_channels, max_seq_len]") + + # initial convolution + + if self.beginning_and_final_conv_present: + x = self.init_conv(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("After self.init_conv(x)-> x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]") + + # init conv residual + + if self.init_conv_to_final_conv_residual: + init_conv_residual = x.clone() + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("x.clone()->init_conv_resi, dim: ", init_conv_residual.shape) + + # time conditioning + + time_hiddens = self.to_time_hiddens(time) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("time dim: ", time.shape, "[batch]") + print("self.to_time_hiddens(time)-> time_hiddens .dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]") + + # derive time tokens + + time_tokens = self.to_time_tokens(time_hiddens) + t = self.to_time_cond(time_hiddens) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("self.to_time_tokens(time_hiddens)-> time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]") + print("self.to_time_cond(time_hiddens)-> t dim: ", t.shape, "[batch, 4xUNet:dim]") + + # add lowres time conditioning to time hiddens + # and add lowres time tokens along sequence dimension for attention + + if self.lowres_cond: + lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) + lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) + lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("self.to_lowres_time_hiddens(lowres_noise_times)-> lowres_time_hiddens .dim: ", lowres_time_hiddens.shape) + print("self.to_lowres_time_tokens(lowres_time_hiddens)-> lowres_time_tokens .dim: ", lowres_time_tokens.shape) + print("self.to_lowres_time_cond(lowres_time_hiddens)-> lowres_t.dim: ", lowres_t.shape) + + t = t + lowres_t + + time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii +=1 + print(ii) + print("After cat(time_tokens, lowres_time_tokens)-> time_tokens dim: ", time_tokens.shape) + + # text conditioning + + text_tokens = None + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("UNet.cond_on_text: ", self.cond_on_text) + + if exists(text_embeds) and self.cond_on_text: + + # conditional dropout + + text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) + + text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') + text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') + + # calculate text embeds + + if self.text_cond_linear: + text_tokens = self.text_to_cond(text_embeds) + else: + text_tokens=text_embeds + if self.CKeys['Debug_ModelPack']==1: + print("On text conditioning part...") + ii +=1 + print(ii) + print("text_embeds->text_tokens.dim: ", text_tokens.shape) + + text_tokens = text_tokens[:, :self.max_text_len] + if self.CKeys['Debug_ModelPack']==1: + ii +=1 + print(ii) + print("text_tokens[:,:max_text_len]-> text_tokens.dim: ", text_tokens.shape) + print("do the same for 
text_mask") + + if exists(text_mask): + text_mask = text_mask[:, :self.max_text_len] + + text_tokens_len = text_tokens.shape[1] + remainder = self.max_text_len - text_tokens_len + + if remainder > 0: + + text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) + + if exists(text_mask): + if remainder > 0: + text_mask = F.pad(text_mask, (0, remainder), value = False) + + + text_mask = rearrange(text_mask, 'b n -> b n 1') + text_keep_mask_embed = text_mask & text_keep_mask_embed + + null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working + text_tokens = torch.where( + text_keep_mask_embed, + text_tokens, + null_text_embed + ) + + if exists(self.attn_pool): + text_tokens = self.attn_pool(text_tokens) + if self.CKeys['Debug_ModelPack']==1: + ii+=1 + print(ii) + print("self.attn_pool(text_tokens)->text_tokens.dim: ", text_tokens.shape) + + # extra non-attention conditioning by projecting and then summing text embeddings to time + # termed as text hiddens + + mean_pooled_text_tokens = text_tokens.mean(dim = -2) + if self.CKeys['Debug_ModelPack']==1: + ii +=1 + print(ii) + print("text_tokens.mean(dim=-2)->mean_pooled_text_tokens.dim: ", mean_pooled_text_tokens.shape) + + text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) + if self.CKeys['Debug_ModelPack']==1: + ii +=1 + print(ii) + print("self.to_text_non_attn_cond(mean_pooled_text_tokens)->text_hiddens.dim: ",text_hiddens.shape) + + null_text_hidden = self.null_text_hidden.to(t.dtype) + + text_hiddens = torch.where( + text_keep_mask_hidden, + text_hiddens, + null_text_hidden + ) + + t = t + text_hiddens + if self.CKeys['Debug_ModelPack']==1: + ii +=1 + print(ii) + print("t+text_hiddens.dim: ", t.shape) + + + # main conditioning tokens (c) + + c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("cat(time_tokens, text_tokens)-> c dim: ", c.shape) + + # normalize conditioning tokens + + c = self.norm_cond(c) + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("self.norm_cond(c)->c dim:", c.shape) + + # initial resnet block (for memory efficient unet) + + if exists(self.init_resnet_block): + x = self.init_resnet_block(x, t) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(ii) + print("self.init_resnet_block(x,t)-> x dim: ", x.shape) + + # go through the layers of the unet, down and up + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("Before unet, down and up, ") + print(" x dim: ", x.shape) + print(" t dim: ", t.shape) + print(" c dim: ", c.shape) + # ii=0 + + hiddens = [] + + for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: + if exists(pre_downsample): + x = pre_downsample(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, pre_downsample(x)=>x dim: ", x.shape) + + x = init_block(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, init_block(x,t,c)=>x dim: ", x.shape) + + for resnet_block in resnet_blocks: + x = resnet_block(x, t) + hiddens.append(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape) + + + x = attn_block(x, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape) + + hiddens.append(x) + + if exists(post_downsample): + + x = 
post_downsample(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, post_downsample(x)=> x dim: ", x.shape) + + x = self.mid_block1(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, mid_block_1(x,t,c)=> x dim: ", x.shape) + + if exists(self.mid_attn): + x = self.mid_attn(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, mid_attn(x)=> x dim: ", x.shape) + + x = self.mid_block2(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, mid_block_2(x,t,c)=> x dim: ", x.shape) + + add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) + + up_hiddens = [] + + for init_block, resnet_blocks, attn_block, upsample in self.ups: + x = add_skip_connection(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape) + # + x = init_block(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, init_block(x,t,c)=> x dim: ", x.shape) + + for resnet_block in resnet_blocks: + x = add_skip_connection(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape) + x = resnet_block(x, t) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape) + + x = attn_block(x, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape) + + up_hiddens.append(x.contiguous()) + x = upsample(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, upsample(x)=> x dim: ", x.shape) + + # whether to combine all feature maps from upsample blocks + + x = self.upsample_combiner(x, up_hiddens) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, upsample_combiner(x,up_hiddens)=> x dim: ", x.shape) + + # final top-most residual if needed + + if self.init_conv_to_final_conv_residual: + x = torch.cat((x, init_conv_residual), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, cat(x,init_conv_residual)=> x dim: ", x.shape) + + if exists(self.final_res_block): + x = self.final_res_block(x, t) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, final_res_block(x,t)=> x dim: ", x.shape) + + if exists(lowres_cond_img): + x = torch.cat((x, lowres_cond_img), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, cat(x,lowres_cond_img)=> x dim: ", x.shape) + + if self.beginning_and_final_conv_present: + x=self.final_conv(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, final_conv(x)=> x dim: ", x.shape) + + return x + +######################################################## +## 1D Unet +######################################################## +class OneD_Unet_Old(nn.Module): + def __init__( + self, + *, + dim, + # image_embed_dim = 1024, # NOT used so far + text_embed_dim = 768, # get_encoded_dim(DEFAULT_T5_NAME) + num_resnet_blocks = 1, + cond_dim = None, + num_image_tokens = 4, + num_time_tokens = 2, + learned_sinu_pos_emb_dim = 16, + out_dim = None, + dim_mults=(1, 2, 4, 8), + cond_images_channels = 0, + channels = 3, + channels_out = None, + attn_dim_head = 64, + attn_heads = 8, + ff_mult = 2., + 
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ + layer_attns = True, + layer_attns_depth = 1, + layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 + attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) + layer_cross_attns = True, + use_linear_attn = False, + use_linear_cross_attn = False, + cond_on_text = True, + max_text_len = 256, + init_dim = None, + resnet_groups = 8, + init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed + init_cross_embed = False, + init_cross_embed_kernel_sizes = (3, 7, 15), + cross_embed_downsample = False, + cross_embed_downsample_kernel_sizes = (2, 4), + attn_pool_text = True, + attn_pool_num_latents = 32, + dropout = 0., + memory_efficient = False, + init_conv_to_final_conv_residual = False, + use_global_context_attn = True, + scale_skip_connection = True, + final_resnet_block = True, + final_conv_kernel_size = 3, + cosine_sim_attn = False, + self_cond = False, + combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully + pixel_shuffle_upsample = False , # may address checkboard artifacts + beginning_and_final_conv_present = True , #TODO add cross-attn, doesnt work yet...whether or not to have final conv layer + # +++++++++++++++++++++++++++ + CKeys=None, + + ): + super().__init__() + + # +++++++++++++++++++++++++++++++ + self.CKeys=CKeys + + assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' + + if dim < 128: + print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') + + + # save locals to take care of some hyperparameters for cascading DDPM + + self._locals = locals() + self._locals.pop('self', None) + self._locals.pop('__class__', None) + # + for debug + if CKeys['Debug_ModelPack']==1: + print("Showing the input:") + print(json.dumps(self._locals, indent=4)) + + # determine dimensions + + self.channels = channels + self.channels_out = default(channels_out, channels) + + # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis + # (2) in self conditioning, one appends the predict x0 (x_start) + init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) + init_dim = default(init_dim, dim) + + self.self_cond = self_cond + + # optional image conditioning + + self.has_cond_image = cond_images_channels > 0 + self.cond_images_channels = cond_images_channels + + init_channels += cond_images_channels + + + self.beginning_and_final_conv_present=beginning_and_final_conv_present + + + if self.beginning_and_final_conv_present: + self.init_conv = CrossEmbedLayer( + init_channels, dim_out = init_dim, + kernel_sizes = init_cross_embed_kernel_sizes, + stride = 1 + ) if init_cross_embed else nn.Conv1d( + init_channels, init_dim, + init_conv_kernel_size, + padding = init_conv_kernel_size // 2) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + cond_dim = default(cond_dim, dim) + time_cond_dim = dim * 4 * (2 if lowres_cond else 1) + + # embedding time for log(snr) noise from continuous version + 
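+ # Shape sketch for the time-conditioning path built below (illustrative only,
+ # assuming the default case lowres_cond = False, so time_cond_dim = 4 * dim):
+ #   time                           : (batch,)  continuous noise level
+ #   sinu_pos_emb(time)             : (batch, learned_sinu_pos_emb_dim + 1)
+ #   to_time_hiddens -> time_hiddens: (batch, time_cond_dim)
+ #   to_time_tokens  -> time_tokens : (batch, num_time_tokens, cond_dim)
+ #   to_time_cond    -> t           : (batch, time_cond_dim)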
+ sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) + sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print(" to_time_hiddens") + print(" ou dim: ", time_cond_dim) + self.to_time_hiddens = nn.Sequential( + sinu_pos_emb, + nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), + nn.SiLU() + ) + + self.to_time_cond = nn.Sequential( + nn.Linear(time_cond_dim, time_cond_dim) + ) + + # project to time tokens as well as time hiddens + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print(" to_time_tokens") + print(" in dim: ", time_cond_dim) # =4xdim cond_dim=dim + print(" ou dim: ", num_time_tokens) + self.to_time_tokens = nn.Sequential( + nn.Linear(time_cond_dim, cond_dim * num_time_tokens), + Rearrange('b (r d) -> b r d', r = num_time_tokens) + ) + + # low res aug noise conditioning + + self.lowres_cond = lowres_cond + + if lowres_cond: + self.to_lowres_time_hiddens = nn.Sequential( + LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), + nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), + nn.SiLU() + ) + + self.to_lowres_time_cond = nn.Sequential( + nn.Linear(time_cond_dim, time_cond_dim) + ) + + self.to_lowres_time_tokens = nn.Sequential( + nn.Linear(time_cond_dim, cond_dim * num_time_tokens), + Rearrange('b (r d) -> b r d', r = num_time_tokens) + ) + + # normalizations + + self.norm_cond = nn.LayerNorm(cond_dim) + + # text encoding conditioning (optional) + + self.text_to_cond = None + + if cond_on_text: #only add linear lear if cond dim is not text emnd dim + assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' + if text_embed_dim != cond_dim: + self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) + self.text_cond_linear=True + + else: + print ("Text conditioning is equatl to cond_dim - no linear layer used") + self.text_cond_linear=False + + # finer control over whether to condition on text encodings + + self.cond_on_text = cond_on_text + + # attention pooling + + self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, + dim_head = attn_dim_head, heads = attn_heads, + num_latents = attn_pool_num_latents, + cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None + + # for classifier free guidance + + self.max_text_len = max_text_len + + self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) + self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) + + # for non-attention based text conditioning at all points in the network where time is also conditioned + + self.to_text_non_attn_cond = None + + if cond_on_text: + self.to_text_non_attn_cond = nn.Sequential( + nn.LayerNorm(cond_dim), + nn.Linear(cond_dim, time_cond_dim), + nn.SiLU(), + nn.Linear(time_cond_dim, time_cond_dim) + ) + + # attention related params + + attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) + + num_layers = len(in_out) + + # resnet block klass + + num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) + resnet_groups = cast_tuple(resnet_groups, num_layers) + + resnet_klass = partial(ResnetBlock, **attn_kwargs) + + layer_attns = cast_tuple(layer_attns, num_layers) + layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) + layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) + + use_linear_attn = cast_tuple(use_linear_attn, num_layers) + use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers) + + assert all([layers == num_layers for layers in list(map(len, 
(resnet_groups, layer_attns, layer_cross_attns)))]) + + # downsample klass + + downsample_klass = Downsample + + if cross_embed_downsample: + downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) + + # initial resnet block (for memory efficient unet) + + self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None + + # scale for resnet skip connections + + self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] + reversed_layer_params = list(map(reversed, layer_params)) + + # downsampling layers + + skip_connect_dims = [] # keep track of skip connection dimensions + + for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): + is_last = ind >= (num_resolutions - 1) + + layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None + + if layer_attn: + transformer_block_klass = TransformerBlock + elif layer_use_linear_attn: + transformer_block_klass = LinearAttentionTransformerBlock + else: + transformer_block_klass = Identity + + current_dim = dim_in + + # whether to pre-downsample, from memory efficient unet + + pre_downsample = None + + if memory_efficient: + pre_downsample = downsample_klass(dim_in, dim_out) + current_dim = dim_out + + skip_connect_dims.append(current_dim) + + # whether to do post-downsample, for non-memory efficient unet + + post_downsample = None + if not memory_efficient: + post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1)) + + self.downs.append( + nn.ModuleList([ + pre_downsample, + resnet_klass( + current_dim, current_dim, + cond_dim = layer_cond_dim, + linear_attn = layer_use_linear_cross_attn, + time_cond_dim = time_cond_dim, groups = groups + ), + nn.ModuleList([ + ResnetBlock( + current_dim, current_dim, + time_cond_dim = time_cond_dim, + groups = groups, + use_gca = use_global_context_attn + ) for _ in range(layer_num_resnet_blocks)]), + transformer_block_klass( + dim = current_dim, + depth = layer_attn_depth, + ff_mult = ff_mult, + context_dim = cond_dim, + **attn_kwargs), + post_downsample + ]) + ) + + # middle layers + + mid_dim = dims[-1] + + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) + self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) + + # upsample klass + + upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample + + # upsampling layers + + upsample_fmap_dims = [] + + for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): + is_last = ind == (len(in_out) - 1) + + 
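+ # Note (illustrative): the upsampling path walks in_out in reverse and pops the
+ # skip_connect_dims recorded on the way down, so the first resnet of each up
+ # block receives dim_out + skip_connect_dim channels, matching the torch.cat
+ # skip connection applied in forward().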
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None + + if layer_attn: + transformer_block_klass = TransformerBlock + elif layer_use_linear_attn: + transformer_block_klass = LinearAttentionTransformerBlock + else: + transformer_block_klass = Identity + + skip_connect_dim = skip_connect_dims.pop() + + upsample_fmap_dims.append(dim_out) + + self.ups.append(nn.ModuleList([ + resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), + nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), + transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), + upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() + ])) + + # whether to combine feature maps from all upsample blocks before final resnet block out + + self.upsample_combiner = UpsampleCombiner( + dim = dim, + enabled = combine_upsample_fmaps, + dim_ins = upsample_fmap_dims, + dim_outs = dim + ) + + # whether to do a final residual from initial conv to the final resnet block out + + self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual + final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) + + # final optional resnet block and convolution out + + self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None + + final_conv_dim_in = dim if final_resnet_block else final_conv_dim + final_conv_dim_in += (channels if lowres_cond else 0) + + if self.beginning_and_final_conv_present: + print (final_conv_dim_in, self.channels_out) + self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) + + if self.beginning_and_final_conv_present: + zero_init_(self.final_conv) + + # if the current settings for the unet are not correct + # for cascading DDPM, then reinit the unet with the right settings + def cast_model_parameters( + self, + *, + lowres_cond, + text_embed_dim, + channels, + channels_out, + cond_on_text + ): + if lowres_cond == self.lowres_cond and \ + channels == self.channels and \ + cond_on_text == self.cond_on_text and \ + text_embed_dim == self._locals['text_embed_dim'] and \ + channels_out == self.channels_out: + return self + + updated_kwargs = dict( + lowres_cond = lowres_cond, + text_embed_dim = text_embed_dim, + channels = channels, + channels_out = channels_out, + cond_on_text = cond_on_text + ) + + return self.__class__(**{**self._locals, **updated_kwargs}) + + # methods for returning the full unet config as well as its parameter state + + def to_config_and_state_dict(self): + return self._locals, self.state_dict() + + # class method for rehydrating the unet from its config and state dict + + @classmethod + def from_config_and_state_dict(klass, config, state_dict): + unet = klass(**config) + unet.load_state_dict(state_dict) + return unet + + # methods for persisting unet to disk + + def persist_to_file(self, path): + path = Path(path) + path.parents[0].mkdir(exist_ok = True, parents = True) + + config, state_dict = self.to_config_and_state_dict() + pkg = dict(config = config, state_dict = state_dict) + torch.save(pkg, str(path)) + + # 
class method for rehydrating the unet from file saved with `persist_to_file` + + @classmethod + def hydrate_from_file(klass, path): + path = Path(path) + assert path.exists() + pkg = torch.load(str(path)) + + assert 'config' in pkg and 'state_dict' in pkg + config, state_dict = pkg['config'], pkg['state_dict'] + + return Unet.from_config_and_state_dict(config, state_dict) + + # forward with classifier free guidance + + def forward_with_cond_scale( + self, + *args, + cond_scale = 1., + **kwargs + ): + logits = self.forward(*args, **kwargs) + + if cond_scale == 1: + return logits + + null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + # forward fun returns the loss + # Here, for Model B: + # x : (batch, 1, seq_len) values: from noised image/Y + # time : (batch) values: + # text_embeds : None + # text_mask: None + # cond_images: input sequence/X + # lowres_noise_times: None + # lowres_cond_img: None + # cond_drop_prob: 0.1 + # self_cond: from noised image/Y + + def forward( + self, + x, + time, + *, + lowres_cond_img = None, + lowres_noise_times = None, + text_embeds = None, + text_mask = None, + cond_images = None, + self_cond = None, + cond_drop_prob = 0. + ): + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("========================================") + print("Here are OneD_Unet:forward") + + batch_size, device = x.shape[0], x.device + # condition on self + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("Check inputs: ") + print(" x-0 dim: ", x.shape, "[batch, 1, seq_len]") + print(" time dim: ", time.shape, "") + print(" cond_images dim: ", cond_images.shape, "") + + + if self.self_cond: + self_cond = default(self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x, self_cond), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("After self_cond x dim: ", x.shape) + + # add low resolution conditioning, if present + + assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' + assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' + + if exists(lowres_cond_img): + x = torch.cat((x, lowres_cond_img), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("After lowres_cond_img x dim: ", x.shape) + + # condition on input image + + assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' + + if exists(cond_images): + assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' + cond_images = resize_image_to(cond_images, x.shape[-1]) + + # ++++++++++++++++++++++++++++++ + # add cond_images (from X) into generation process... 
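+ # Channel-wise conditioning sketch (illustrative): the conditioning sequence is
+ # resized to the current length and stacked along the channel dimension, e.g.
+ #   cond_images : (batch, cond_images_channels, seq_len)
+ #   x           : (batch, channels,             seq_len)
+ #   cat         : (batch, cond_images_channels + channels, seq_len)
+ # together with any self-conditioning / lowres channels already concatenated
+ # above, this matches the init_channels expected by self.init_conv.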
+ x = torch.cat((cond_images.to(device), x.to(device)), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("cond_images dim: ", cond_images.shape, "[batch, 1, max_seq_len]") + print("After cond_images, x dim: ", x.shape, "[batch, 2, max_seq_len]") + + # initial convolution + + if self.beginning_and_final_conv_present: + x = self.init_conv(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("After init_conv, x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]") + + # init conv residual + + if self.init_conv_to_final_conv_residual: + init_conv_residual = x.clone() + + # time conditioning + + time_hiddens = self.to_time_hiddens(time) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("time dim: ", time.shape, "[batch]") + print("after, time_hiddens dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]") + + # derive time tokens + + time_tokens = self.to_time_tokens(time_hiddens) + t = self.to_time_cond(time_hiddens) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]") + print("after to_time_cond t dim: ", t.shape, "[batch, 4xUNet:dim]") + + # add lowres time conditioning to time hiddens + # and add lowres time tokens along sequence dimension for attention + + if self.lowres_cond: + lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) + lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) + lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) + + t = t + lowres_t + + time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("After lowres_cond, time_tokens dim: ", time_tokens.shape) + + # text conditioning + + text_tokens = None + + if exists(text_embeds) and self.cond_on_text: + + # conditional dropout + + text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) + + text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') + text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') + + # calculate text embeds + + if self.text_cond_linear: + text_tokens = self.text_to_cond(text_embeds) + else: + text_tokens=text_embeds + + text_tokens = text_tokens[:, :self.max_text_len] + + if exists(text_mask): + text_mask = text_mask[:, :self.max_text_len] + + text_tokens_len = text_tokens.shape[1] + remainder = self.max_text_len - text_tokens_len + + if remainder > 0: + + text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) + + if exists(text_mask): + if remainder > 0: + text_mask = F.pad(text_mask, (0, remainder), value = False) + + + text_mask = rearrange(text_mask, 'b n -> b n 1') + text_keep_mask_embed = text_mask & text_keep_mask_embed + + null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working + text_tokens = torch.where( + text_keep_mask_embed, + text_tokens, + null_text_embed + ) + + if exists(self.attn_pool): + text_tokens = self.attn_pool(text_tokens) + + # extra non-attention conditioning by projecting and then summing text embeddings to time + # termed as text hiddens + + mean_pooled_text_tokens = text_tokens.mean(dim = -2) + + text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) + + null_text_hidden = self.null_text_hidden.to(t.dtype) + + text_hiddens = torch.where( + text_keep_mask_hidden, + text_hiddens, + null_text_hidden + ) + + t = t + text_hiddens + + + # main conditioning tokens (c) + + c = time_tokens if not exists(text_tokens) else 
torch.cat((time_tokens, text_tokens), dim = -2) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("Merge time and text tokens, c dim: ", c.shape) + + # normalize conditioning tokens + + c = self.norm_cond(c) + + # initial resnet block (for memory efficient unet) + + if exists(self.init_resnet_block): + x = self.init_resnet_block(x, t) + + # go through the layers of the unet, down and up + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("Before unet, down and up, ") + print("x dim: ", x.shape) + print("t dim: ", t.shape) + print("c dim: ", c.shape) + ii=0 + + hiddens = [] + + for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: + if exists(pre_downsample): + x = pre_downsample(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after pre_downsample x dim: ", x.shape) + + x = init_block(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after init_block(x,t,c) x dim: ", x.shape) + + for resnet_block in resnet_blocks: + x = resnet_block(x, t) + hiddens.append(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after resnet_block x dim: ", x.shape) + + + x = attn_block(x, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after attn_block x dim: ", x.shape) + + hiddens.append(x) + + if exists(post_downsample): + + x = post_downsample(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after post_downsample x dim: ", x.shape) + + x = self.mid_block1(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after mid_block_1 x dim: ", x.shape) + + if exists(self.mid_attn): + x = self.mid_attn(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after mid_attn x dim: ", x.shape) + + x = self.mid_block2(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after mid_block_2 x dim: ", x.shape) + + add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) + + up_hiddens = [] + + for init_block, resnet_blocks, attn_block, upsample in self.ups: + x = add_skip_connection(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after add_skip_connection x dim: ", x.shape) + # + x = init_block(x, t, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after init_block(x,t,c) x dim: ", x.shape) + + for resnet_block in resnet_blocks: + x = add_skip_connection(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after add_skip_connection x dim: ", x.shape) + x = resnet_block(x, t) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after resnet_block(x,t) x dim: ", x.shape) + + x = attn_block(x, c) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after attn_block(x,c) x dim: ", x.shape) + up_hiddens.append(x.contiguous()) + x = upsample(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after upsample(x) x dim: ", x.shape) + + # whether to combine all feature maps from upsample blocks + + x = self.upsample_combiner(x, up_hiddens) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after upsample_combiner(x,..) 
x dim: ", x.shape) + + # final top-most residual if needed + + if self.init_conv_to_final_conv_residual: + x = torch.cat((x, init_conv_residual), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after cat_init_conv_resi x dim: ", x.shape) + + if exists(self.final_res_block): + x = self.final_res_block(x, t) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after final_res_block(x,t) x dim: ", x.shape) + + if exists(lowres_cond_img): + x = torch.cat((x, lowres_cond_img), dim = 1) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after cat_x_lowres_cond_img x dim: ", x.shape) + + if self.beginning_and_final_conv_present: + x=self.final_conv(x) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + ii += 1 + print(F" {str(ii)}, after final_conv(x) x dim: ", x.shape) + + return x + +######################################################## +## null unets +######################################################## + +class NullUnet(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.lowres_cond = False + self.dummy_paramcast_model_parameterseter = nn.Parameter(torch.tensor([0.])) + + def cast_model_parameters(self, *args, **kwargs): + return self + + def forward(self, x, *args, **kwargs): + return x + +class Unet(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.lowres_cond = False + self.dummy_parameter = nn.Parameter(torch.tensor([0.])) + + def cast_model_parameters(self, *args, **kwargs): + return self + + def forward(self, x, *args, **kwargs): + return x + + +######################################################## +## Elucidated denoising model +## After: Tero Karras and Miika Aittala and Timo Aila and Samuli Laine, +## Elucidating the Design Space of Diffusion-Based Generative Models +## https://arxiv.org/abs/2206.00364, 2022 +######################################################## + +from math import sqrt + +Hparams_fields = [ + 'num_sample_steps', + 'sigma_min', + 'sigma_max', + 'sigma_data', + 'rho', + 'P_mean', + 'P_std', + 'S_churn', + 'S_tmin', + 'S_tmax', + 'S_noise' +] + +Hparams = namedtuple('Hparams', Hparams_fields) + +# helper functions + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +# main class + +class ElucidatedImagen(nn.Module): + def __init__( + self, + unets, + *, + image_sizes, # for cascading ddpm, image size at each stage + text_encoder_name = '', # can be "DEFAULT_T5_NAME" + text_embed_dim = None, + channels = 3, + channels_out=3, + cond_drop_prob = 0.1, + random_crop_sizes = None, + lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level + per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find + condition_on_text = True, + auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader + dynamic_thresholding = True, + dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper + only_train_unet_number = None, + lowres_noise_schedule = 'linear', + 
num_sample_steps = 32, # number of sampling steps + sigma_min = 0.002, # min noise level + sigma_max = 80, # max noise level + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + + loss_type=0, #0=MSE, + categorical_loss_ignore=None, + # +++++++++++++++++ + # device=None, + CKeys=None, + PKeys=None, + ): + super().__init__() + + # +++++++++++++++++++++++++++ + # self.device=device + self.CKeys=CKeys + self.PKeys=PKeys + + self.only_train_unet_number = only_train_unet_number + + self.condition_on_text = condition_on_text + self.unconditional = not condition_on_text + self.loss_type=loss_type + if self.loss_type>0: + self.categorical_loss=True + self.m = nn.LogSoftmax(dim=1) #used for some loss functins + else: + self.categorical_loss=False + + print("Loss type: ", self.loss_type) + self.categorical_loss_ignore=categorical_loss_ignore + + # channels + + self.channels = channels + self.channels_out = channels_out + + unets = cast_tuple(unets) + num_unets = len(unets) + + # randomly cropping for upsampler training + + self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets) + assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example' + + # lowres augmentation noise schedule + + self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule) + + # get text encoder + + self.text_embed_dim =text_embed_dim + + # construct unets + + self.unets = nn.ModuleList([]) + self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment + + print (f"Channels in={self.channels}, channels out={self.channels_out}") + for ind, one_unet in enumerate(unets): + + # # --------------------------- + # assert isinstance(one_unet, ( OneD_Unet, NullUnet)) + # ++++++++++++++++++++++++++++ + + is_first = ind == 0 + + # if the current settings for the unet are not correct + # for cascading DDPM, then reinit the unet with the right settings + # + # + for debug + # print('I am here') + print("Test on cast_model_parameters...") + print(not is_first) + print(self.condition_on_text) + print(self.text_embed_dim if self.condition_on_text else None) + print(self.channels) + print(self.channels_out) + one_unet = one_unet.cast_model_parameters( + lowres_cond = not is_first, + cond_on_text = self.condition_on_text, + text_embed_dim = self.text_embed_dim if self.condition_on_text else None, + channels = self.channels, + #channels_out = self.channels + channels_out = self.channels_out + ) + + self.unets.append(one_unet) + + # determine whether we are training on images or video + + is_video = False # only consider image case + self.is_video = is_video + + self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1' if not is_video else 'b -> b 1 1 1')) + self.resize_to = resize_video_to if is_video else resize_image_to + + + self.image_sizes = image_sizes + assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}' + + 
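+ # Per-unet broadcasting sketch (illustrative): scalar settings are expanded with
+ # cast_tuple so that, e.g. with two unets,
+ #   cast_tuple(0.5, 2)    -> (0.5, 0.5)
+ #   cast_tuple((2, 4), 2) -> (2, 4)
+ # the elucidating hyperparameters below are expanded the same way and zipped
+ # into one Hparams namedtuple per unet (self.hparams[i]).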
self.sample_channels = cast_tuple(self.channels, num_unets) + + lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) + assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' + + self.lowres_sample_noise_level = lowres_sample_noise_level + self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level + + # classifier free guidance + + self.cond_drop_prob = cond_drop_prob + self.can_classifier_guidance = cond_drop_prob > 0. + + # normalize and unnormalize image functions + + self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity + self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity + self.input_image_range = (0. if auto_normalize_img else -1., 1.) + + # dynamic thresholding + + self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets) + self.dynamic_thresholding_percentile = dynamic_thresholding_percentile + +# # temporal interpolations + +# temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets) +# self.temporal_downsample_factor = temporal_downsample_factor + +# self.resize_cond_video_frames = resize_cond_video_frames +# self.temporal_downsample_divisor = temporal_downsample_factor[0] + +# assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1' +# assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factor must be in order of descending' + + + # elucidating parameters + + hparams = [ + num_sample_steps, + sigma_min, + sigma_max, + sigma_data, + rho, + P_mean, + P_std, + S_churn, + S_tmin, + S_tmax, + S_noise, + ] + + hparams = [cast_tuple(hp, num_unets) for hp in hparams] + self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)] + + # one temp parameter for keeping track of device + + # self.register_buffer('_temp', torch.tensor([0.]).to(device), persistent = False) + self.register_buffer('_temp', torch.tensor([0.]), persistent = False) + + # default to device of unets passed in + + self.to(next(self.unets.parameters()).device) + # for debug + print(next(self.unets.parameters()).device) + + # print ("Device used in ImagenEluc: ", self.device) + + def force_unconditional_(self): + self.condition_on_text = False + self.unconditional = True + + for unet in self.unets: + unet.cond_on_text = False + + @property + def device(self): + return self._temp.device + #return (device) + + def get_unet(self, unet_number): + assert 0 < unet_number <= len(self.unets) + index = unet_number - 1 + + if isinstance(self.unets, nn.ModuleList): + unets_list = [unet for unet in self.unets] + delattr(self, 'unets') + self.unets = unets_list + + if index != self.unet_being_trained_index: + for unet_index, unet in enumerate(self.unets): + unet.to(self.device if unet_index == index else 'cpu') + + self.unet_being_trained_index = index + return self.unets[index] + + def reset_unets_all_one_device(self, device = None): + device = default(device, self.device) + + self.unets = nn.ModuleList([*self.unets]) + self.unets.to(device) + + self.unet_being_trained_index = -1 + + @contextmanager + def one_unet_in_gpu(self, unet_number = None, unet = None): + assert exists(unet_number) ^ exists(unet) + + if exists(unet_number): + unet = self.unets[unet_number - 1] + + devices = [module_device(unet) for unet in self.unets] + self.unets.cpu() + unet.to(self.device) + + 
yield + + for unet, device in zip(self.unets, devices): + unet.to(device) + + # overriding state dict functions + + def state_dict(self, *args, **kwargs): + self.reset_unets_all_one_device() + return super().state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + self.reset_unets_all_one_device() + return super().load_state_dict(*args, **kwargs) + + # dynamic thresholding + + def threshold_x_start(self, x_start, dynamic_threshold = True): + if not dynamic_threshold: + return x_start.clamp(-1., 1.) + + s = torch.quantile( + rearrange(x_start, 'b ... -> b (...)').abs(), + self.dynamic_thresholding_percentile, + dim = -1 + ) + + s.clamp_(min = 1.) + s = right_pad_dims_to(x_start, s) + return x_start.clamp(-s, s) / s + + # derived preconditioning params - Table 1 + + def c_skip(self, sigma_data, sigma): + return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2) + + def c_out(self, sigma_data, sigma): + return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5 + + def c_in(self, sigma_data, sigma): + return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5 + + def c_noise(self, sigma): + return log(sigma) * 0.25 + + # preconditioned network output + # equation (7) in the paper + + def preconditioned_network_forward( + self, + unet_forward, + noised_images, + sigma, + *, + sigma_data, + clamp = False, + dynamic_threshold = True, + **kwargs + ): + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("========================================") + print("ElucidatedImagen: preconditioned_network_forward") + + batch, device = noised_images.shape[0], noised_images.device + + if isinstance(sigma, float): + sigma = torch.full((batch,), sigma, device = device) + + padded_sigma = self.right_pad_dims_to_datatype(sigma) + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("unet_forward: ") + print("arg_0: ", noised_images.shape) + print("arg_1: ", sigma.shape) + print("other arg: ", kwargs) + + net_out = unet_forward( + self.c_in(sigma_data, padded_sigma) * noised_images, + self.c_noise(sigma), + # "{text_embeds,text_mask,cond_images,...} are in kwargs" + **kwargs + ) + + out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out + + if not clamp: + return out + + return self.threshold_x_start(out, dynamic_threshold) + + # sampling + + # sample schedule + # equation (5) in the paper + + def sample_schedule( + self, + num_sample_steps, + rho, + sigma_min, + sigma_max + ): + N = num_sample_steps + inv_rho = 1 / rho + + steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) + sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho + + sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 
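+ # Schedule sketch (equation (5) of the elucidated-diffusion paper, as built above):
+ #   sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho,  i = 0..N-1
+ # so sigma_0 = sigma_max and sigma_{N-1} = sigma_min, and the F.pad above appends
+ # a final sigma_N = 0 so the trajectory ends at clean data.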
+ return sigmas + + @torch.no_grad() + def one_unet_sample( + self, + unet, + shape, + *, + unet_number, + clamp = True, + dynamic_threshold = True, + cond_scale = 1., + use_tqdm = True, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = 5, + init_images = None, + skip_steps = None, + sigma_min = None, + sigma_max = None, + **kwargs + ): + # get specific sampling hyperparameters for unet + + hp = self.hparams[unet_number - 1] + + sigma_min = default(sigma_min, hp.sigma_min) + sigma_max = default(sigma_max, hp.sigma_max) + + # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma + + sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max) + + gammas = torch.where( + (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax), + min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1), + 0. + ) + + sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) + + # images is noise at the beginning + + init_sigma = sigmas[0] + + images = init_sigma * torch.randn(shape, device = self.device) + + # initializing with an image + + if exists(init_images): + images += init_images + + # keeping track of x0, for self conditioning if needed + + x_start = None + + # prepare inpainting images and mask + + has_inpainting = exists(inpaint_images) and exists(inpaint_masks) + resample_times = inpaint_resample_times if has_inpainting else 1 + + if has_inpainting: + inpaint_images = self.normalize_img(inpaint_images) + inpaint_images = self.resize_to(inpaint_images, shape[-1]) + inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool() + + # unet kwargs + + unet_kwargs = dict( + sigma_data = hp.sigma_data, + clamp = clamp, + dynamic_threshold = dynamic_threshold, + cond_scale = cond_scale, + **kwargs + ) + + # gradually denoise + + initial_step = default(skip_steps, 0) + sigmas_and_gammas = sigmas_and_gammas[initial_step:] + + total_steps = len(sigmas_and_gammas) + + for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm): + is_last_timestep = ind == (total_steps - 1) + + sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) + + for r in reversed(range(resample_times)): + is_last_resample_step = r == 0 + + eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling + + sigma_hat = sigma + gamma * sigma + added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps + + images_hat = images + added_noise + + self_cond = x_start if unet.self_cond else None + + if has_inpainting: + images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks + + model_output = self.preconditioned_network_forward( + unet.forward_with_cond_scale, + images_hat, + sigma_hat, + self_cond = self_cond, + **unet_kwargs + ) + + denoised_over_sigma = (images_hat - model_output) / sigma_hat + + images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma + + # second order correction, if not the last timestep + + if sigma_next != 0: + self_cond = model_output if unet.self_cond else None + + model_output_next = self.preconditioned_network_forward( + unet.forward_with_cond_scale, + images_next, + sigma_next, + self_cond = self_cond, + **unet_kwargs + ) + + denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next + images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) + + 
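+ # Second-order correction sketch (Heun-style, illustrative): the first update
+ # above is an Euler step using the slope at sigma_hat; when sigma_next != 0 the
+ # denoiser is re-evaluated at (images_next, sigma_next) and the line above
+ # replaces the Euler estimate with the average of the two slopes,
+ #   images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (d + d')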
images = images_next + + if has_inpainting and not (is_last_resample_step or is_last_timestep): + # renoise in repaint and then resample + repaint_noise = torch.randn(shape, device = self.device) + images = images + (sigma - sigma_next) * repaint_noise + + x_start = model_output # save model output for self conditioning + + + if has_inpainting: + images = images * ~inpaint_masks + inpaint_images * inpaint_masks + + return images + + @torch.no_grad() + @eval_decorator + def sample( + self, + texts: List[str] = None, + text_masks = None, + text_embeds = None, + cond_images = None, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = 5, + init_images = None, + skip_steps = None, + sigma_min = None, + sigma_max = None, + video_frames = None, + batch_size = 1, + cond_scale = 1., + lowres_sample_noise_level = None, + start_at_unet_number = 1, + start_image_or_video = None, + stop_at_unet_number = None, + return_all_unet_outputs = False, + return_pil_images = False, + use_tqdm = True, + device = None, + + ): + + # + + device = default(device, self.device) + self.reset_unets_all_one_device(device = device) + + cond_images = maybe(cast_uint8_images_to_float)(cond_images) + + if exists(texts) and not exists(text_embeds) and not self.unconditional: + assert all([*map(len, texts)]), 'text cannot be empty' + + with autocast(enabled = False): + text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) + + text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks)) + + if not self.unconditional: + assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training' + + text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) + batch_size = text_embeds.shape[0] + + if exists(inpaint_images): + if self.unconditional: + if batch_size == 1: # assume researcher wants to broadcast along inpainted images + batch_size = inpaint_images.shape[0] + + assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=)``' + assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on' + + assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' + assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' + assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' + + assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting' + + outputs = [] + + is_cuda = next(self.parameters()).is_cuda + device = next(self.parameters()).device + + lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level) + + num_unets = len(self.unets) + cond_scale = cast_tuple(cond_scale, num_unets) + + # handle video and frame dimension + + assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video' + + frame_dims = (video_frames,) if self.is_video else tuple() + + # initializing with an image or video + + init_images = cast_tuple(init_images, num_unets) + 
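+ # Per-unet sampling options sketch (illustrative): init_images, skip_steps,
+ # sigma_min / sigma_max (just below) and cond_scale (above) are broadcast with
+ # cast_tuple, so a scalar applies to every unet while a tuple sets them per
+ # stage, e.g. (hypothetical usage)
+ #   model.sample(..., cond_scale = 7.5)        # same guidance scale everywhere
+ #   model.sample(..., cond_scale = (5., 7.5))  # per-unet scales in a cascade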
init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images] + + skip_steps = cast_tuple(skip_steps, num_unets) + + sigma_min = cast_tuple(sigma_min, num_unets) + sigma_max = cast_tuple(sigma_max, num_unets) + + # handle starting at a unet greater than 1, for training only-upscaler training + + if start_at_unet_number > 1: + assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets' + assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number + assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling' + + prev_image_size = self.image_sizes[start_at_unet_number - 2] + img = self.resize_to(start_image_or_video, prev_image_size) + + + for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm): + if unet_number < start_at_unet_number: + continue + + assert not isinstance(unet, NullUnet), 'cannot sample from null unet' + + context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext() + + with context: + lowres_cond_img = lowres_noise_times = None + + shape = (batch_size, channel, *frame_dims, image_size ) + + # low resolution conditioning + + if unet.lowres_cond: + lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device) + + lowres_cond_img = self.resize_to(img, image_size) + lowres_cond_img = self.normalize_img(lowres_cond_img.float()) + + lowres_cond_img, _ = self.lowres_noise_schedule.q_sample( + x_start = lowres_cond_img.float(), + t = lowres_noise_times, + noise = torch.randn_like(lowres_cond_img.float()) + ) + + if exists(unet_init_images): + unet_init_images = self.resize_to(unet_init_images, image_size) + + + shape = (batch_size, self.channels, *frame_dims, image_size) + + img = self.one_unet_sample( + unet, + shape, + unet_number = unet_number, + text_embeds = text_embeds, + text_mask =text_masks, + cond_images = cond_images, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = unet_init_images, + skip_steps = unet_skip_steps, + sigma_min = unet_sigma_min, + sigma_max = unet_sigma_max, + cond_scale = unet_cond_scale, + lowres_cond_img = lowres_cond_img, + lowres_noise_times = lowres_noise_times, + dynamic_threshold = dynamic_threshold, + use_tqdm = use_tqdm + ) + + if self.categorical_loss: + img=self.m(img) + + outputs.append(img) + + if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: + break + + output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs + + + if not return_all_unet_outputs: + outputs = outputs[-1:] + + assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet' + + if self.categorical_loss: + return torch.argmax(outputs[output_index], dim=1).unsqueeze (1) + else: + return outputs[output_index] + + # training + + def loss_weight(self, sigma_data, sigma): + return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2 + + def noise_distribution(self, P_mean, P_std, batch_size): + # torch.randn: normal distribution N(0, 1) + return (P_mean + 
P_std * torch.randn((batch_size,), device = self.device)).exp() + + def forward( + self, + images, + unet: Union[ NullUnet, DistributedDataParallel] = None, + texts: List[str] = None, + text_embeds = None, + text_masks = None, + unet_number = None, + cond_images = None, + + ): + assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' + unet_number = default(unet_number, 1) + assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}' + + # + for debug + if self.CKeys['Debug_ModelPack']==1: + if cond_images!=None: + print("cond_images type: ", cond_images.dtype) + else: + print("cond_images type: None") + + cond_images = maybe(cast_uint8_images_to_float)(cond_images) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + if cond_images!=None: + print("cond_images type: ", cond_images.dtype) + else: + print("cond_images type: None") + + + if self.categorical_loss==False: + assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead' + + unet_index = unet_number - 1 + + unet = default(unet, lambda: self.get_unet(unet_number)) + + assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained' + + target_image_size = self.image_sizes[unet_index] + random_crop_size = self.random_crop_sizes[unet_index] + prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None + hp = self.hparams[unet_index] + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("target_image_size: ", target_image_size) + print("prev_image_size: ", prev_image_size) + print("random_crop_size: ", random_crop_size) + + + batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4) + + frames = images.shape[2] if is_video else None + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("frames: ", frames) + + check_shape(images, 'b c ...', c = self.channels) + + + assert h >= target_image_size + + if exists(texts) and not exists(text_embeds) and not self.unconditional: + assert all([*map(len, texts)]), 'text cannot be empty' + assert len(texts) == len(images), 'number of text captions does not match up with the number of images given' + + with autocast(enabled = False): + text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) + + text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks)) + + if not self.unconditional: + text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) + + assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified' + assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented' + + assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' + + lowres_cond_img = lowres_aug_times = None + if exists(prev_image_size): + lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range) + lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range) + + if self.per_sample_random_aug_noise_level: + lowres_aug_times = 
self.lowres_noise_schedule.sample_random_times(batch_size, device = device) + else: + # return a random time: + lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device) + # extend it to batch_size + lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size) + + if exists(random_crop_size): + aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) + + if is_video: + images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h -> (b f) c h') + + images = aug(images) + lowres_cond_img = aug(lowres_cond_img, params = aug._params) + + if is_video: + images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h -> b c f h', f = frames) + + + lowres_cond_img_noisy = None + if exists(lowres_cond_img): + lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample( + x_start = lowres_cond_img, + t = lowres_aug_times, + noise = torch.randn_like(lowres_cond_img.float()) + ) + + # get the sigmas + + sigmas = self.noise_distribution( + hp.P_mean, + hp.P_std, + batch_size + ).to(device) + padded_sigmas = self.right_pad_dims_to_datatype(sigmas).to(device) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print('sigmas dim: ', sigmas.shape, 'should = batch_size') + print('sigmas[0..3]: ', sigmas[:3]) + print('padded_sigmas dim: ', padded_sigmas.shape) + + # noise + + noise = torch.randn_like(images.float()).to(device) + + + noised_images = images + padded_sigmas * noise # alphas are 1. in the paper + + # unet kwargs + + unet_kwargs = dict( + sigma_data = hp.sigma_data, + text_embeds = text_embeds, + text_mask =text_masks, + cond_images = cond_images, + lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times), + lowres_cond_img = lowres_cond_img_noisy, + cond_drop_prob = self.cond_drop_prob, + ) + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("lowres_noise_times: ", unet_kwargs['lowres_noise_times']) + print("lowres_cond_img_noisy: ", unet_kwargs['lowres_cond_img']) + print("unet_kwargs: \n", unet_kwargs) + + # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower + + # Because 'unet' can be an instance of DistributedDataParallel coming from the + # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to + # access the member 'module' of the wrapped unet instance. + + self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet + # # + for debug + # if self.CKeys['Debug_ModelPack']==1: + # # this will be the unet + # print("self_cond: ", self_cond) + + if self_cond and random() < 0.5: + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("=========added self_cond.============") + with torch.no_grad(): + pred_x0 = self.preconditioned_network_forward( + unet.forward, + noised_images, + sigmas, + **unet_kwargs + ).detach() + + unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0} + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("after self-condition, random < 0.5 .......") + print("unet_kwargs: \n", unet_kwargs) + + # get prediction + + denoised_images = self.preconditioned_network_forward( + unet.forward, + noised_images, + sigmas, + **unet_kwargs + ) + + # losses + + if self.loss_type==0: + + losses = F.mse_loss(denoised_images, images, reduction = 'none') + losses = reduce(losses, 'b ... 
-> b', 'mean') + + # loss weighting + + losses = losses * self.loss_weight(hp.sigma_data, sigmas) + losses=losses.mean() + + return losses + +# =========================================================== +# final models +# explict define unets +class ProteinDesigner_B(nn.Module): + def __init__(self, + unet, + # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + CKeys=None, + PKeys=None, + ): + super(ProteinDesigner_B, self).__init__() + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # unload the parameters + timesteps =default(PKeys['timesteps'], 10) # 10 , + dim =default(PKeys['dim'], 32) # 32, + pred_dim =default(PKeys['pred_dim'], 25) # 25, + loss_type =default(PKeys['loss_type'], 0) # 0, # MSE + elucidated =default(PKeys['elucidated'], True) # True, + padding_idx =default(PKeys['padding_idx'], 0) # 0, + cond_dim =default(PKeys['cond_dim'], 512) # 512, + text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, + input_tokens =default(PKeys['input_tokens'], 25) #for non-BERT + sequence_embed =default(PKeys['sequence_embed'], False) + embed_dim_position =default(PKeys['embed_dim_position'], 32) + max_text_len =default(PKeys['max_text_len'], 16) + cond_images_channels=default(PKeys['cond_images_channels'], 0) + # ++++++++++++++++++++++ + max_length =default(PKeys['max_length'], 64) + device =default(PKeys['device'], None) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + + print ("Model B: Generative protein diffusion model, residue-based") + print ("Using condition as the initial sequence") + self.pred_dim=pred_dim + self.loss_type=loss_type + # +++++++++++++++++ for debug + self.CKeys=CKeys + self.PKeys=PKeys + self.max_length = max_length + + assert loss_type == 0, "Losses other than MSE not implemented" + + self.fc_embed1 = nn.Linear( 8, max_length) # NOT used # INPUT DIM (last), OUTPUT DIM, last + self.fc_embed2 = nn.Linear( 1, text_embed_dim) # USED # INPUT DIM (last), OUTPUT DIM, last + self.max_text_len=max_text_len + + self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) + text_embed_dim=text_embed_dim+embed_dim_position + + self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) + for i in range (max_text_len): + self.pos_matrix_i [i]=i +1 + + condition_on_text=True + self.cond_images_channels=cond_images_channels + + if self.cond_images_channels>0: + condition_on_text = False + + if self.cond_images_channels>0: + print ("Use conditioning image during training....") + + assert elucidated , "Only elucidated model implemented...." 
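+        # A rough sketch of how the EDM hyperparameters handed to
+        # ElucidatedImagen below are used in training (following Karras et al.
+        # 2022, "Elucidating the Design Space of Diffusion-Based Generative
+        # Models"; the exact weighting lives inside ElucidatedImagen.forward):
+        #   sigma    = exp(P_mean + P_std * n),  n ~ N(0, 1)   # log-normal noise level
+        #   x_noisy  = x + sigma * eps,          eps ~ N(0, I)
+        #   loss     = w(sigma) * MSE(D(x_noisy, sigma), x)
+        #   w(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2
+        # sigma_min/sigma_max bound the sampling schedule, rho controls its
+        # spacing, and S_churn/S_tmin/S_tmax/S_noise set the amount of
+        # stochasticity injected at sampling time.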
+ self.is_elucidated=elucidated + if elucidated: + self.imagen = ElucidatedImagen( + unets = (unet), + channels=self.pred_dim, + channels_out=self.pred_dim , + loss_type=loss_type, + condition_on_text = condition_on_text, + text_embed_dim = text_embed_dim, + image_sizes = ( [max_length ]), + cond_drop_prob = 0.1, + auto_normalize_img = False, + num_sample_steps = timesteps, # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) + sigma_min = 0.002, # min noise level + sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + # ++++++++++++++++ + # device=device, + CKeys=self.CKeys, + PKeys=self.PKeys, + + ).to(device) + else: + print ("Not implemented.") + + def forward(self, + output, + x=None, + cond_images = None, + unet_number=1, + ): + # ddddddddddddddddddddddddddddddddddddddddddddd + if self.CKeys['Debug_ModelPack']==1: + print("on Model B:forward") + print("inputs:") + print(' output: ', output.shape) + print(' x: ', x) + print(' cond_img: ', cond_images.shape) + print(' unet_num: ', unet_number) + + + if x != None: + x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device) + x_in[:,:x.shape[1]]=x + x=x_in + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + + x= torch.cat( (x, pos_emb_x ), 2) + + if cond_images!=None: + this_cond_images=cond_images.to(device) + else: + this_cond_images=cond_images + + if self.CKeys['Debug_ModelPack']==1: + print('x with pos_emb_x: ', x) + print("into self.imagen...") + loss = self.imagen( + output, + text_embeds = x, + # cond_images=cond_images.to(device), + cond_images=this_cond_images, + unet_number = unet_number, + ) + + return loss + + def sample ( + self, + x=None, + stop_at_unet_number=1 , + cond_scale=7.5, + x_data=None, + skip_steps=None, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = 5, + init_images = None, + x_data_tokenized=None, + device=None, + # +++++++++++++++++++++++++ + tokenizer_X=None, + Xnormfac=1., + max_length=1., + ): + + batch_size=1 + + if x_data != None: + print ("Conditioning target sequence provided via ori x_data ...", x_data) + # this is for Model B with SecStr as input + # + # # -- ver 0.0: channel = 1, all padding at the end: content+000 + # x_data = tokenizer_X.texts_to_sequences(x_data) + # x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') + # + # x_data= torch.from_numpy(x_data).float().to(device) + # x_data = x_data/Xnormfac + # x_data=x_data.unsqueeze (2) + # x_data=torch.permute(x_data, (0,2,1) ) + # + + # ++ ver 1.0: channel = self.pred_dim>1, esm padding, 0+content+000 + x_data = tokenizer_X.texts_to_sequences(x_data) + # padding: 0+content+00 + x_data_0 = [] + for this_x_data in x_data: + x_data_0.append([0]+this_x_data) # to be checked whether this works + x_data = sequence.pad_sequences( + x_data_0, maxlen=max_length, + 
padding='post', truncating='post', + ) + # normalization: dummy, for future + x_data = torch.from_numpy(x_data).float().to(device) + x_data = x_data/Xnormfac + # adjust the channel:NOTE, here we use "copy" to fill-in the channels + # May need to adjust this part for future models + x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1) + + + print ("After channel expansion, x_data from target sequence=", x_data, x_data.shape) + batch_size=x_data.shape[0] + + if x_data_tokenized != None: + # this is for Model B with ForcePath vector as input + # everything is padded as what the dataloader has + # + # if x_data_tokenized.any() != None: + print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) + # # ++ ver 1.0 + # if self.pred_dim==1: + # # old case: channel=1 + # x_data=x_data_tokenized.unsqueeze (2) + # x_data=torch.permute(x_data, (0,2,1) ).to(device) + # else: + # # new case: include channels in x_data_tokenized + # x_data=x_data_tokenized.to(device) + # + # # -- ver 0.0: original + # x_data=x_data_tokenized.unsqueeze (2) + # x_data=torch.permute(x_data, (0,2,1) ).to(device) + # + # ++ ver 2.0: input x_data_tokenized.dim= (batch, seq_len) + x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) + # + print ("x_data.dim provided from x_data_tokenized: ", x_data.shape) + batch_size=x_data.shape[0] + + if init_images != None: + print ("Init sequence provided...", init_images) + init_images = tokenizer_y.texts_to_sequences(init_images) + init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') + init_images= torch.from_numpy(init_images).float().to(device)/ynormfac + print ("init_images=", init_images) + + if inpaint_images != None: + print ("Inpaint sequence provided...", inpaint_images) + print ("Mask: ", inpaint_masks) + inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) + inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') + inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac + print ("in_paint images=", inpaint_images) + + if x !=None: + x_in=torch.zeros( (x.shape[0],max_length) ).to(device) + x_in[:,:x.shape[1]]=x + x=x_in + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + + x= torch.cat( (x, pos_emb_x ), 2) + + batch_size=x.shape[0] + + output = self.imagen.sample( + text_embeds= x, + cond_scale = cond_scale, + stop_at_unet_number=stop_at_unet_number, + cond_images=x_data, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images, + batch_size=batch_size, + device=device, + ) + + return output + +# =========================================================== +# final models +# explict define unets +class ProteinPredictor_B(nn.Module): + def __init__(self, + unet, + # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + CKeys=None, + PKeys=None, + ): + super(ProteinPredictor_B, self).__init__() + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # unload the parameters + timesteps =default(PKeys['timesteps'], 10) # 10 , + dim =default(PKeys['dim'], 32) # 32, + pred_dim =default(PKeys['pred_dim'], 25) # 25, + loss_type =default(PKeys['loss_type'], 0) # 0, # MSE + elucidated =default(PKeys['elucidated'], True) # True, + 
padding_idx =default(PKeys['padding_idx'], 0) # 0, + cond_dim =default(PKeys['cond_dim'], 512) # 512, + text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, + input_tokens =default(PKeys['input_tokens'], 25) #for non-BERT + sequence_embed =default(PKeys['sequence_embed'], False) + embed_dim_position =default(PKeys['embed_dim_position'], 32) + max_text_len =default(PKeys['max_text_len'], 16) + cond_images_channels=default(PKeys['cond_images_channels'], 0) + # ++++++++++++++++++++++ + max_length =default(PKeys['max_length'], 64) + device =default(PKeys['device'], None) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + + print ("Model B: Predictive protein diffusion model, residue-based") + print ("Using condition as the initial sequence") + self.pred_dim=pred_dim + self.loss_type=loss_type + # +++++++++++++++++ for debug + self.CKeys=CKeys + self.PKeys=PKeys + self.max_length = max_length + + assert loss_type == 0, "Losses other than MSE not implemented" + + self.fc_embed1 = nn.Linear( 8, max_length) # NOT used # INPUT DIM (last), OUTPUT DIM, last + self.fc_embed2 = nn.Linear( 1, text_embed_dim) # USED # INPUT DIM (last), OUTPUT DIM, last + self.max_text_len=max_text_len + + self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) + text_embed_dim=text_embed_dim+embed_dim_position + + self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) + for i in range (max_text_len): + self.pos_matrix_i [i]=i +1 + + condition_on_text=True + self.cond_images_channels=cond_images_channels + + if self.cond_images_channels>0: + condition_on_text = False + + if self.cond_images_channels>0: + print ("Use conditioning image during training....") + + assert elucidated , "Only elucidated model implemented...." + self.is_elucidated=elucidated + if elucidated: + self.imagen = ElucidatedImagen( + unets = (unet), + channels=self.pred_dim, + channels_out=self.pred_dim , + loss_type=loss_type, + condition_on_text = condition_on_text, + text_embed_dim = text_embed_dim, + image_sizes = ( [max_length ]), + cond_drop_prob = 0.1, + auto_normalize_img = False, + num_sample_steps = timesteps, # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) + sigma_min = 0.002, # min noise level + sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + # ++++++++++++++++ + # device=device, + CKeys=self.CKeys, + PKeys=self.PKeys, + + ).to(device) + else: + print ("Not implemented.") + + def forward(self, + output, + x=None, + cond_images = None, + unet_number=1, + ): + # ddddddddddddddddddddddddddddddddddddddddddddd + if self.CKeys['Debug_ModelPack']==1: + print("on Model B:forward") + print("inputs:") + print(' output: ', output.shape) + print(' x: ', x) + print(' cond_img: ', cond_images.shape) + print(' unet_num: ', unet_number) + + + if x != None: + x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device) + x_in[:,:x.shape[1]]=x + x=x_in + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + 
pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + + x= torch.cat( (x, pos_emb_x ), 2) + + if cond_images!=None: + this_cond_images=cond_images.to(device) + else: + this_cond_images=cond_images + + if self.CKeys['Debug_ModelPack']==1: + print('x with pos_emb_x: ', x) + print("into self.imagen...") + loss = self.imagen( + output, + text_embeds = x, + # cond_images=cond_images.to(device), + cond_images=this_cond_images, + unet_number = unet_number, + ) + + return loss + + def sample ( + self, + x=None, + stop_at_unet_number=1 , + cond_scale=7.5, + x_data=None, + skip_steps=None, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = 5, + init_images = None, + x_data_tokenized=None, + device=None, + # +++++++++++++++++++++++++ + tokenizer_X=None, + Xnormfac=1., + max_length=1., + # ++ + pLM_Model_Name=None, + pLM_Model=None, + pLM_alphabet=None, + esm_layer=None, + ): + + batch_size=1 + + if x_data != None: + print ("Conditioning target sequence provided via ori x_data ...", x_data) + print(f"use pLM model {pLM_Model_Name}") + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # prepare x_data into the cond_img channel (batch, channel, seq_len) + # for ProteinPredictor: AA seq -> FOrcPath + # need to distinguish trivial and ESM series + # mimic the block in DataSetPack + if pLM_Model_Name=='trivial': + # no ESM but direct tokenizer is used here + # NEED to pad the 0th position + x_data = tokenizer_X.texts_to_sequences(x_data) + # two-step padding + x_data = sequence.pad_sequences( + x_data, maxlen=max_length-1, + padding='post', truncating='post', + value=0.0, + ) + # add one 0 at the begining + x_data = sequence.pad_sequences( + x_data, maxlen=max_length, + padding='pre', truncating='pre', + value=0.0, + ) # (batch, seq_len) + x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) + # (batch, 1, seq_len) + x_data = x_data/Xnormfac + + else: + # ++ for esm + print("pLM Model: ", pLM_Model_Name) + # 1. from AA string to token + # need to save 2 positions for and + esm_batch_converter = pLM_alphabet.get_batch_converter( + truncation_seq_length=max_length-2 + ) + # prepare seqs for the "esm_batch_converter..." + # add dummy labels + seqs_ext=[] + for i in range(len(x_data)): + seqs_ext.append( + (" ", x_data[i]) + ) + # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext) + _, x_strs, x_data = esm_batch_converter(seqs_ext) + x_strs_lens = (x_data != pLM_alphabet.padding_idx).sum(1) + # + # NEED to check the size of y_data + # need to dealwith if y_data are only shorter sequences + # need to add padding with a value, int (1) + current_seq_len = x_data.shape[1] + print("current seq batch len: ", current_seq_len) + missing_num_pad = max_length-current_seq_len + if missing_num_pad>0: + print("extra padding is added to match the target seq input length...") + # padding is needed + x_data = F.pad( + x_data, + (0, missing_num_pad), + "constant", pLM_alphabet.padding_idx + ) + else: + print("No extra padding is needed") + x_data = x_data.to(device) + # + # 2. 
from token to embedding + with torch.no_grad(): + results = pLM_Model( + x_data, + repr_layers=[esm_layer], + return_contacts=False, + ) + x_data=results["representations"][esm_layer] # (batch, seq_len, channels) + x_data=rearrange( + x_data, + 'b l c -> b c l' + ) + + + # print(batch_tokens.shape) + print ("x_data.dim: ", x_data.shape) + print ("x_data.type: ", x_data.type) + batch_size=x_data.shape[0] + + + +# # --------------------------------------------------------- +# # this is for Model B with SecStr as input +# # +# # # -- ver 0.0: channel = 1, all padding at the end: content+000 +# # x_data = tokenizer_X.texts_to_sequences(x_data) +# # x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') +# # +# # x_data= torch.from_numpy(x_data).float().to(device) +# # x_data = x_data/Xnormfac +# # x_data=x_data.unsqueeze (2) +# # x_data=torch.permute(x_data, (0,2,1) ) +# # + +# # ++ ver 1.0: channel = self.pred_dim>1, esm padding, 0+content+000 +# x_data = tokenizer_X.texts_to_sequences(x_data) +# # padding: 0+content+00 +# x_data_0 = [] +# for this_x_data in x_data: +# x_data_0.append([0]+this_x_data) # to be checked whether this works +# x_data = sequence.pad_sequences( +# x_data_0, maxlen=max_length, +# padding='post', truncating='post', +# ) +# # normalization: dummy, for future +# x_data = torch.from_numpy(x_data).float().to(device) +# x_data = x_data/Xnormfac +# # adjust the channel:NOTE, here we use "copy" to fill-in the channels +# # May need to adjust this part for future models +# x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1) + + +# print ("After channel expansion, x_data from target sequence=", x_data, x_data.shape) +# batch_size=x_data.shape[0] + + if x_data_tokenized != None: + # ++ + # this is for Model B with AA tokens as input + # task: x_data_tokenized (batch, seq_len) -> x_data (batch, channel, seq_len) + print ( + "Conditioning target output via provided AA tokens sequence...", + x_data_tokenized, + x_data_tokenized.shape, + ) + # transfer tokens into embedding and expand the channels + if pLM_Model_Name=='trivial': + # self.pred_dim should be 1 + x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) + else: # esm models + with torch.no_grad(): + results = pLM_Model( + x_data_tokenized, + repr_layers=[esm_layer], + return_contacts=False, + ) + x_data=results["representations"][esm_layer] # (batch, seq_len, channels) + x_data=rearrange( + x_data, + 'b l c -> b c l' + ) + x_data = x_data.to(device) + + batch_size=x_data.shape[0] + + + print ("x_data.dim provided from x_data_tokenized: ", x_data.shape) + + # # -- + # # this is for Model B with ForcePath vector as input + # # everything is padded as what the dataloader has + # # + # # if x_data_tokenized.any() != None: + # print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) + # # # ++ ver 1.0 + # # if self.pred_dim==1: + # # # old case: channel=1 + # # x_data=x_data_tokenized.unsqueeze (2) + # # x_data=torch.permute(x_data, (0,2,1) ).to(device) + # # else: + # # # new case: include channels in x_data_tokenized + # # x_data=x_data_tokenized.to(device) + # # + # # # -- ver 0.0: original + # # x_data=x_data_tokenized.unsqueeze (2) + # # x_data=torch.permute(x_data, (0,2,1) ).to(device) + # # + # # ++ ver 2.0: input x_data_tokenized.dim= (batch, seq_len) + # x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) + # # + # print ("x_data.dim provided from x_data_tokenized: ", 
x_data.shape) + # batch_size=x_data.shape[0] + + if init_images != None: + print ("Init sequence provided...", init_images) + init_images = tokenizer_y.texts_to_sequences(init_images) + init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') + init_images= torch.from_numpy(init_images).float().to(device)/ynormfac + print ("init_images=", init_images) + + if inpaint_images != None: + print ("Inpaint sequence provided...", inpaint_images) + print ("Mask: ", inpaint_masks) + inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) + inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') + inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac + print ("in_paint images=", inpaint_images) + + if x !=None: + x_in=torch.zeros( (x.shape[0],max_length) ).to(device) + x_in[:,:x.shape[1]]=x + x=x_in + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + + x= torch.cat( (x, pos_emb_x ), 2) + + batch_size=x.shape[0] + + output = self.imagen.sample( + text_embeds= x, + cond_scale = cond_scale, + stop_at_unet_number=stop_at_unet_number, + cond_images=x_data, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images, + batch_size=batch_size, + device=device, + ) + + return output + + +# =========================================================== +# final models +# explict define unets +class ProteinDesigner_B_Old(nn.Module): + def __init__(self, + unet, + timesteps=10 , + dim=32, + pred_dim=25, + loss_type=0, # MSE + elucidated=True, + padding_idx=0, + cond_dim = 512, + text_embed_dim = 512, + input_tokens=25,#for non-BERT + sequence_embed=False, + embed_dim_position=32, + max_text_len=16, + cond_images_channels=0, + # ++++++++++++++++++++++ + max_length=1, + device=None, + CKeys=None, + PKeys=None, + + ): + super(ProteinDesigner_B_Old, self).__init__() + + print ("Model B: Generative protein diffusion model, residue-based") + print ("Using condition as the initial sequence") + self.pred_dim=pred_dim + self.loss_type=loss_type + # +++++++++++++++++ for debug + self.CKeys=CKeys + self.PKeys=PKeys + self.max_length = max_length + + assert loss_type == 0, "Losses other than MSE not implemented" + + self.fc_embed1 = nn.Linear( 8, max_length) # NOT used # INPUT DIM (last), OUTPUT DIM, last + self.fc_embed2 = nn.Linear( 1, text_embed_dim) # USED # INPUT DIM (last), OUTPUT DIM, last + self.max_text_len=max_text_len + + self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) + text_embed_dim=text_embed_dim+embed_dim_position + + self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) + for i in range (max_text_len): + self.pos_matrix_i [i]=i +1 + + condition_on_text=True + self.cond_images_channels=cond_images_channels + + if self.cond_images_channels>0: + condition_on_text = False + + if self.cond_images_channels>0: + print ("Use conditioning image during training....") + + assert elucidated , "Only elucidated model implemented...." 
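+        # Conditioning layout shared by this and the newer Model B variant:
+        # the 1-D "images" are sequence tensors of shape
+        # (batch, pred_dim, max_length), i.e. channels = pred_dim and
+        # image_sizes = [max_length]. Text-style conditioning goes through
+        # forward(): x (batch, seq) -> unsqueeze(2) -> fc_embed2 ->
+        # (batch, max_length, text_embed_dim), then the learned positional
+        # embedding (embed_dim_position wide) is concatenated on the last dim;
+        # this implicitly assumes max_text_len == max_length so the two
+        # tensors line up. If cond_images_channels > 0, condition_on_text is
+        # switched off above and conditioning is supplied via cond_images
+        # instead of via text embeddings.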
+ self.is_elucidated=elucidated + if elucidated: + self.imagen = ElucidatedImagen( + unets = (unet), + channels=self.pred_dim, + channels_out=self.pred_dim , + loss_type=loss_type, + condition_on_text = condition_on_text, + text_embed_dim = text_embed_dim, + image_sizes = ( [max_length ]), + cond_drop_prob = 0.1, + auto_normalize_img = False, + num_sample_steps = timesteps, # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) + sigma_min = 0.002, # min noise level + sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + # ++++++++++++++++ + # device=device, + CKeys=self.CKeys, + PKeys=self.PKeys, + + ).to(device) + else: + print ("Not implemented.") + + def forward(self, + output, + x=None, + cond_images = None, + unet_number=1, + ): + # ddddddddddddddddddddddddddddddddddddddddddddd + if self.CKeys['Debug_ModelPack']==1: + print('output: ', output.shape) + print('x: ', x) + print('cond_img: ', cond_images) + print('unet_num: ', unet_number) + + + if x != None: + x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device) + x_in[:,:x.shape[1]]=x + x=x_in + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + + x= torch.cat( (x, pos_emb_x ), 2) + + if cond_images!=None: + this_cond_images=cond_images.to(device) + else: + this_cond_images=cond_images + + if self.CKeys['Debug_ModelPack']==1: + print('x with pos_emb_x: ', x) + loss = self.imagen( + output, + text_embeds = x, + # cond_images=cond_images.to(device), + cond_images=this_cond_images, + unet_number = unet_number, + ) + + return loss + + def sample ( + self, + x=None, + stop_at_unet_number=1 , + cond_scale=7.5, + x_data=None, + skip_steps=None, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = 5, + init_images = None, + x_data_tokenized=None, + device=None, + # +++++++++++++++++++++++++ + tokenizer_X=None, + Xnormfac=1., + ynormfac=1., + max_length=1., + ): + + batch_size=1 + + if x_data != None: + print ("Conditioning target sequence provided via x_data ...", x_data) + x_data = tokenizer_X.texts_to_sequences(x_data) + x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') + + x_data= torch.from_numpy(x_data).float().to(device) + x_data = x_data/Xnormfac + x_data=x_data.unsqueeze (2) + x_data=torch.permute(x_data, (0,2,1) ) + + print ("x_data from target sequence=", x_data, x_data.shape) + batch_size=x_data.shape[0] + + if x_data_tokenized != None: + print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) + + x_data=x_data_tokenized.unsqueeze (2) + x_data=torch.permute(x_data, (0,2,1) ).to(device) + print ("Data provided from x_data_tokenized: ", x_data.shape) + batch_size=x_data.shape[0] + + if init_images != None: + print ("Init sequence provided...", init_images) + init_images = 
tokenizer_y.texts_to_sequences(init_images) + init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') + init_images= torch.from_numpy(init_images).float().to(device)/ynormfac + print ("init_images=", init_images) + + if inpaint_images != None: + print ("Inpaint sequence provided...", inpaint_images) + print ("Mask: ", inpaint_masks) + inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) + inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') + inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac + print ("in_paint images=", inpaint_images) + + if x !=None: + x_in=torch.zeros( (x.shape[0],max_length) ).to(device) + x_in[:,:x.shape[1]]=x + x=x_in + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + + x= torch.cat( (x, pos_emb_x ), 2) + + batch_size=x.shape[0] + + output = self.imagen.sample( + text_embeds= x, + cond_scale = cond_scale, + stop_at_unet_number=stop_at_unet_number, + cond_images=x_data, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images, + batch_size=batch_size, + device=device, + ) + + return output + +# =========================================================== +# final models: old model +# Changes: 1. put the UNet part out +# +class ProteinDesigner_A_II(nn.Module): + def __init__( + self, + unet1, + # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + CKeys=None, + PKeys=None, +): + + super(ProteinDesigner_A_II, self).__init__() + + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # unload the arguments + timesteps =default(PKeys['timesteps'], 10) # 10 + dim =default(PKeys['dim'], 32) # 32, + pred_dim =default(PKeys['pred_dim'], 25) # 25, + loss_type =default(PKeys['loss_type'], 0) # 0, + elucidated =default(PKeys['elucidated'], True) # False, + padding_idx =default(PKeys['padding_idx'], 0) # 0, + cond_dim =default(PKeys['cond_dim'], 512) # 512, + text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, + input_tokens =default(PKeys['input_tokens'], 25) # 25,#for non-BERT + sequence_embed =default(PKeys['sequence_embed'], False) # False, + embed_dim_position =default(PKeys['embed_dim_position'], 32) # 32, + max_text_len =default(PKeys['max_text_len'], 16) # 16, + cond_images_channels=default(PKeys['cond_images_channels'], 0) + # ++ + max_length =default(PKeys['max_length'], 64) # 64, + device =default(PKeys['device'], 'cuda:0') # 'cuda:0', + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + self.pred_dim=pred_dim + self.loss_type=loss_type + # ++ + self.CKeys=CKeys + self.PKeys=PKeys + + self.device=device + + assert (loss_type==0), "Loss other than MSE not implemented" + + + self.fc_embed1 = nn.Linear( 8, max_length) # NOT USED + self.fc_embed2 = nn.Linear( 1, text_embed_dim) # project the text into embedding dimensions + self.max_text_len=max_text_len + + self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) + text_embed_dim=text_embed_dim+embed_dim_position + + self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) + for i in range (max_text_len): + self.pos_matrix_i [i]=i +1 + if self.CKeys['Debug_ModelPack']==1: + print("ModelA.pos_matrix_i: ", self.pos_matrix_i) + + + + # # 
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # # prepare the Unet Key + # write_PK_UNet=dict() + # write_PK_UNet['dim']=dim + # write_PK_UNet['text_embed_dim']=text_embed_dim + # write_PK_UNet['cond_dim']=cond_dim + # write_PK_UNet['dim_mults']=(1, 2, 4, 8) + # write_PK_UNet['num_resnet_blocks']=1 + # write_PK_UNet['layer_attns']=(False, True, True, False) + # write_PK_UNet['layer_cross_attns']=(False, True, True, False) + # write_PK_UNet['channels']=pred_dim + # write_PK_UNet['channels_out']=pred_dim + # write_PK_UNet['attn_dim_head']=64 + # write_PK_UNet['attn_heads']=8 + # write_PK_UNet['ff_mult']=2. + # write_PK_UNet['lowres_cond']=False # for cascading diffusion - https://cascaded-diffusion.github.io/ + # write_PK_UNet['layer_attns_depth']=1 + # write_PK_UNet['layer_attns_add_text_cond']=True # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 + # write_PK_UNet['attend_at_middle']=True # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) + # write_PK_UNet['use_linear_attn']=False + # write_PK_UNet['use_linear_cross_attn']=False + # write_PK_UNet['cond_on_text'] = True + # write_PK_UNet['max_text_len'] = max_length + # write_PK_UNet['init_dim'] = None + # write_PK_UNet['resnet_groups'] = 8 + # write_PK_UNet['init_conv_kernel_size'] =7 # kernel size of initial conv, if not using cross embed + # write_PK_UNet['init_cross_embed'] = False #TODO - fix ouput size calcs for conv1d + # write_PK_UNet['init_cross_embed_kernel_sizes'] = (3, 7, 15) + # write_PK_UNet['cross_embed_downsample'] = False + # write_PK_UNet['cross_embed_downsample_kernel_sizes'] = (2, 4) + # write_PK_UNet['attn_pool_text'] = True + # write_PK_UNet['attn_pool_num_latents'] = 32 #32, #perceiver model latents + # write_PK_UNet['dropout'] = 0. 
+ # write_PK_UNet['memory_efficient'] = False + # write_PK_UNet['init_conv_to_final_conv_residual'] = False + # write_PK_UNet['use_global_context_attn'] = True + # write_PK_UNet['scale_skip_connection'] = True + # write_PK_UNet['final_resnet_block'] = True + # write_PK_UNet['final_conv_kernel_size'] = 3 + # write_PK_UNet['cosine_sim_attn'] = True + # write_PK_UNet['self_cond'] = False + # write_PK_UNet['combine_upsample_fmaps'] = True # combine feature maps from all upsample blocks, used in unet squared successfully + # write_PK_UNet['pixel_shuffle_upsample'] = False # may address checkboard artifacts + # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + # # + # Unet_PKeys=prepare_UNet_keys(write_PK_UNet) + # unet1 = OneD_Unet( + # CKeys=CKeys, + # PKeys=Unet_PKeys, + # ).to (self.device) +# # ----------------------------------------------------------------------------- +# unet1 = OneD_Unet_Old( +# dim = dim, +# text_embed_dim = text_embed_dim, +# cond_dim = cond_dim, +# dim_mults = (1, 2, 4, 8), + +# num_resnet_blocks = 1,#1, +# layer_attns = (False, True, True, False), +# layer_cross_attns = (False, True, True, False), +# channels=self.pred_dim, +# channels_out=self.pred_dim , +# # +# attn_dim_head = 64, +# attn_heads = 8, +# ff_mult = 2., +# lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ + +# layer_attns_depth =1, +# layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 +# attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) +# use_linear_attn = False, +# use_linear_cross_attn = False, +# cond_on_text = True, +# max_text_len = max_length, +# init_dim = None, +# resnet_groups = 8, +# init_conv_kernel_size =7, # kernel size of initial conv, if not using cross embed +# init_cross_embed = False, #TODO - fix ouput size calcs for conv1d +# init_cross_embed_kernel_sizes = (3, 7, 15), +# cross_embed_downsample = False, +# cross_embed_downsample_kernel_sizes = (2, 4), +# attn_pool_text = True, +# attn_pool_num_latents = 32,#32, #perceiver model latents +# dropout = 0., +# memory_efficient = False, +# init_conv_to_final_conv_residual = False, +# use_global_context_attn = True, +# scale_skip_connection = True, +# final_resnet_block = True, +# final_conv_kernel_size = 3, +# cosine_sim_attn = True, +# self_cond = False, +# combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully +# pixel_shuffle_upsample = False , # may address checkboard artifacts +# # ++ +# CKeys=CKeys, + +# ).to (self.device) + + assert elucidated , "Only elucidated model implemented...." 
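+        # cond_drop_prob = 0.2 below enables classifier-free guidance: the
+        # text embeddings are randomly dropped during training with that
+        # probability, so that at sampling time cond_scale (7.5 by default in
+        # sample()) can blend conditional and unconditional predictions,
+        # roughly out = uncond + cond_scale * (cond - uncond).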
+ self.is_elucidated=elucidated + if elucidated: + self.imagen = ElucidatedImagen( + unets = (unet1), + channels=self.pred_dim, + channels_out=self.pred_dim , + loss_type=loss_type, + text_embed_dim = text_embed_dim, + image_sizes = [max_length], + cond_drop_prob = 0.2, + auto_normalize_img = False, + num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) + sigma_min = 0.002, # min noise level + sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + # ++ + CKeys=self.CKeys, + PKeys=self.PKeys, + + ).to (self.device) + # + for debug: + if CKeys['Debug_ModelPack']==1: + print("Check on EImagen:") + print("channels: ", self.pred_dim) + print("loss_type: ", loss_type) + print("text_embed_dim: ",text_embed_dim) + print("image_sizes: ", max_length) + print("num_sample_steps: ", timesteps) + print("Measure imagen:") + params( self.imagen) + print("Measure fc_embed2") + params( self.fc_embed2) + print("Measure pos_emb_x") + params( self.pos_emb_x) + + else: + print ("Not implemented.") + + # need to merge this with ModelB:forward + def forward( + self, + # # ------------------- + # x, + # output, + # +++++++++++++++++++ + output, + x=None, + # + cond_images=None, + unet_number=1, + ): #sequences=conditioning, output=prediction + # dddddddddddddddddddddddddddddddddddd + if self.CKeys['Debug_ModelPack']==1: + print("on Model A:forward") + print("inputs:") + print(' output.dim : ', output.shape) + print(' x: ', x) + print(' x.dim: ', x.shape) + if cond_images==None: + print(' cond_img: None ') + else: + print(' cond_img: ', cond_images.shape) + print(' unet_num: ', unet_number) + + x=x.unsqueeze (2) + if self.CKeys['Debug_ModelPack']==1: + print("After x.unsqueeze(2), x.dim: ", x.shape) + + x= self.fc_embed2(x) + if self.CKeys['Debug_ModelPack']==1: + print("After fc_embed2(x), x.dim: ", x.shape) + print() + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) + if self.CKeys['Debug_ModelPack']==1: + print("pos_matrix_i_.dim: ", pos_matrix_i_.shape) + print("pos_matrix_i_: ", pos_matrix_i_) + print() + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + if self.CKeys['Debug_ModelPack']==1: + print("After pos_emb_x(pos_matrix_i_), pos_emb_x.dim: ", pos_emb_x.shape) + print("pos_emb_x: ", pos_emb_x) + print() + pos_emb_x = torch.squeeze(pos_emb_x, 1) + pos_emb_x[:,x.shape[1]:,:]=0 #set all to zero that are not provided via x + pos_emb_x=pos_emb_x[:,:x.shape[1],:] + if self.CKeys['Debug_ModelPack']==1: + print("after operations, pos_emb_x.dim: ", pos_emb_x.shape) + print("pos_emb_x: ", pos_emb_x) + print() + x= torch.cat( (x, pos_emb_x ), 2) + if self.CKeys['Debug_ModelPack']==1: + print("after cat((x,pos_emb_x),2)=>x dim: ", x.shape, "Batch x max_text_len x (text_embed_dim+embed_dim_position)") + print() + print("Now, get into self.imagen part...") + + # ref: the full argument for imagen:forward() + # _________________________________________________ + # self, + # images, + # unet: Union[ 
NullUnet, DistributedDataParallel] = None, + # texts: List[str] = None, + # text_embeds = None, + # text_masks = None, + # unet_number = None, + # cond_images = None, + # _________________________________________________ + loss = self.imagen( + output, + text_embeds = x, + unet_number = unet_number, + ) + + return loss + + def sample ( + self, + x=None, + stop_at_unet_number=1, + cond_scale=7.5, + # ++ + x_data=None, # image_condi data + skip_steps=None, + inpaint_images = None, + inpaint_masks = None, + inpaint_resample_times = 5, + init_images = None, + x_data_tokenized=None, + tokenizer_X=None, + Xnormfac=None, + # -+ + device=None, + max_length=None, # for XandY data, in image/sequence format; NOT for text condition + max_text_len=None, # for X data, in text format + ): + # + add for the uniform shape for model A and B + if x_data != None: # condition via images + print ("Conditioning target sequence provided via sequence/image in x_data ...", x_data) + # need tokeni + x_data = tokenizer_X.texts_to_sequences(x_data) + x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') + + x_data= torch.from_numpy(x_data).float().to(self.device) + x_data = x_data/Xnormfac + x_data=x_data.unsqueeze (2) + x_data=torch.permute(x_data, (0,2,1) ) + + print ("x_data from target sequence=", x_data, x_data.shape) + batch_size=x_data.shape[0] + + # + add for tokenized sequence/image data + if x_data_tokenized != None: + print ("Conditioning target sequence provided via processed sequence/image in x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) + + x_data=x_data_tokenized.unsqueeze (2) + x_data=torch.permute(x_data, (0,2,1) ).to(self.device) + print ("Data provided from x_data_tokenized: ", x_data.shape) + batch_size=x_data.shape[0] + + # keep but not used + if init_images != None: # initial sequence provided via init_images/sequences + # need tokenizer_y, ynormfac, max_length + print ("Init sequence provided...", init_images) + init_images = tokenizer_y.texts_to_sequences(init_images) + init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') + init_images= torch.from_numpy(init_images).float().to(self.device)/ynormfac + print ("init_images=", init_images) + # + if inpaint_images != None: # inpainting model + # need inpaint_images, inpaint_masks, tokenizer_y, ynormfac, max_length + print ("Inpaint sequence provided...", inpaint_images) + print ("Mask: ", inpaint_masks) + inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) + inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') + inpaint_images= torch.from_numpy(inpaint_images).float().to(self.device)/ynormfac + print ("in_paint images=", inpaint_images) + + # now, for Model A type: text conditioning + if x !=None: + print ("Conditioning target sequence via tokenized text in x ...", x) + # for now, we only consider tokenized text in x + # need now: max_text_len ; need for future: + # + for debug + if self.CKeys['Debug_ModelPack']==1: + print("x.dim: ", x.shape) + print("max_text_len: ", max_text_len) + print("self.device: ",self.device) + + x_in=torch.zeros( (x.shape[0],max_text_len) ).to(self.device) + x_in[:,:x.shape[1]]=x + x=x_in + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not 
provided via x + pos_emb_x=pos_emb_x[:,:x.shape[1],:] + x= torch.cat( (x, pos_emb_x ), 2) + + batch_size=x.shape[0] + + output = self.imagen.sample( + text_embeds= x, + cond_scale= cond_scale, + stop_at_unet_number=stop_at_unet_number, + # ++ + cond_images=x_data, + skip_steps=skip_steps, + inpaint_images = inpaint_images, + inpaint_masks = inpaint_masks, + inpaint_resample_times = inpaint_resample_times, + init_images = init_images, + batch_size=batch_size, + device=self.device, + ) + + +# x=x.unsqueeze (2) + +# x= self.fc_embed2(x) + +# pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) +# pos_emb_x = self.pos_emb_x( pos_matrix_i_) +# pos_emb_x = torch.squeeze(pos_emb_x, 1) +# pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x +# pos_emb_x=pos_emb_x[:,:x.shape[1],:] +# x= torch.cat( (x, pos_emb_x ), 2) + + # output = self.imagen.sample( + # text_embeds= x, + # cond_scale=cond_scale, + # stop_at_unet_number=stop_at_unet_number + # ) + + return output + +# =========================================================== +# final models: old model +# Changes: 1. replace the OneD_UNet part +# +class ProteinDesigner_A_I(nn.Module): + def __init__( + self, + # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + CKeys=None, + PKeys=None, +): + + super(ProteinDesigner_A_I, self).__init__() + + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # unload the arguments + timesteps =default(PKeys['timesteps'], 10) # 10 + dim =default(PKeys['dim'], 32) # 32, + pred_dim =default(PKeys['pred_dim'], 25) # 25, + loss_type =default(PKeys['loss_type'], 0) # 0, + elucidated =default(PKeys['elucidated'], True) # False, + padding_idx =default(PKeys['padding_idx'], 0) # 0, + cond_dim =default(PKeys['cond_dim'], 512) # 512, + text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, + input_tokens =default(PKeys['input_tokens'], 25) # 25,#for non-BERT + sequence_embed =default(PKeys['sequence_embed'], False) # False, + embed_dim_position =default(PKeys['embed_dim_position'], 32) # 32, + max_text_len =default(PKeys['max_text_len'], 16) # 16, + cond_images_channels=default(PKeys['cond_images_channels'], 0) + # ++ + max_length =default(PKeys['max_length'], 64) # 64, + device =default(PKeys['device'], 'cuda:0') # 'cuda:0', + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + # ++ + self.CKeys=CKeys + self.PKeys=PKeys + + self.device=device + self.pred_dim=pred_dim + self.loss_type=loss_type + + self.fc_embed1 = nn.Linear( 8, max_length) # NOT USED + self.fc_embed2 = nn.Linear( 1, text_embed_dim) # + self.max_text_len=max_text_len + + self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) + text_embed_dim=text_embed_dim+embed_dim_position + self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) + for i in range (max_text_len): + self.pos_matrix_i [i]=i +1 + + assert (loss_type==0), "Loss other than MSE not implemented" + + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # prepare the Unet Key + write_PK_UNet=dict() + write_PK_UNet['dim']=dim + write_PK_UNet['text_embed_dim']=text_embed_dim + write_PK_UNet['cond_dim']=cond_dim + write_PK_UNet['dim_mults']=(1, 2, 4, 8) + write_PK_UNet['num_resnet_blocks']=1 + write_PK_UNet['layer_attns']=(False, True, True, False) + write_PK_UNet['layer_cross_attns']=(False, True, True, False) + write_PK_UNet['channels']=pred_dim + write_PK_UNet['channels_out']=pred_dim + 
write_PK_UNet['attn_dim_head']=64 + write_PK_UNet['attn_heads']=8 + write_PK_UNet['ff_mult']=2. + write_PK_UNet['lowres_cond']=False # for cascading diffusion - https://cascaded-diffusion.github.io/ + write_PK_UNet['layer_attns_depth']=1 + write_PK_UNet['layer_attns_add_text_cond']=True # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 + write_PK_UNet['attend_at_middle']=True # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) + write_PK_UNet['use_linear_attn']=False + write_PK_UNet['use_linear_cross_attn']=False + write_PK_UNet['cond_on_text'] = True + write_PK_UNet['max_text_len'] = max_length + write_PK_UNet['init_dim'] = None + write_PK_UNet['resnet_groups'] = 8 + write_PK_UNet['init_conv_kernel_size'] =7 # kernel size of initial conv, if not using cross embed + write_PK_UNet['init_cross_embed'] = False #TODO - fix ouput size calcs for conv1d + write_PK_UNet['init_cross_embed_kernel_sizes'] = (3, 7, 15) + write_PK_UNet['cross_embed_downsample'] = False + write_PK_UNet['cross_embed_downsample_kernel_sizes'] = (2, 4) + write_PK_UNet['attn_pool_text'] = True + write_PK_UNet['attn_pool_num_latents'] = 32 #32, #perceiver model latents + write_PK_UNet['dropout'] = 0. + write_PK_UNet['memory_efficient'] = False + write_PK_UNet['init_conv_to_final_conv_residual'] = False + write_PK_UNet['use_global_context_attn'] = True + write_PK_UNet['scale_skip_connection'] = True + write_PK_UNet['final_resnet_block'] = True + write_PK_UNet['final_conv_kernel_size'] = 3 + write_PK_UNet['cosine_sim_attn'] = True + write_PK_UNet['self_cond'] = False + write_PK_UNet['combine_upsample_fmaps'] = True # combine feature maps from all upsample blocks, used in unet squared successfully + write_PK_UNet['pixel_shuffle_upsample'] = False # may address checkboard artifacts + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + # + Unet_PKeys=prepare_UNet_keys(write_PK_UNet) + unet1 = OneD_Unet( + CKeys=CKeys, + PKeys=Unet_PKeys, + ).to (self.device) + if CKeys['Debug_ModelPack']==1: + print('Check unet generated...') + params(unet1) +# # ----------------------------------------------------------------------------- +# unet1 = OneD_Unet_Old( +# dim = dim, +# text_embed_dim = text_embed_dim, +# cond_dim = cond_dim, +# dim_mults = (1, 2, 4, 8), + +# num_resnet_blocks = 1,#1, +# layer_attns = (False, True, True, False), +# layer_cross_attns = (False, True, True, False), +# channels=self.pred_dim, +# channels_out=self.pred_dim , +# # +# attn_dim_head = 64, +# attn_heads = 8, +# ff_mult = 2., +# lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ + +# layer_attns_depth =1, +# layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 +# attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) +# use_linear_attn = False, +# use_linear_cross_attn = False, +# cond_on_text = True, +# max_text_len = max_length, +# init_dim = None, +# resnet_groups = 8, +# init_conv_kernel_size =7, # kernel size of initial conv, if not using cross embed +# init_cross_embed = False, #TODO - fix ouput size calcs for conv1d +# init_cross_embed_kernel_sizes = (3, 7, 15), +# cross_embed_downsample = False, +# cross_embed_downsample_kernel_sizes = (2, 4), 
+# attn_pool_text = True, +# attn_pool_num_latents = 32,#32, #perceiver model latents +# dropout = 0., +# memory_efficient = False, +# init_conv_to_final_conv_residual = False, +# use_global_context_attn = True, +# scale_skip_connection = True, +# final_resnet_block = True, +# final_conv_kernel_size = 3, +# cosine_sim_attn = True, +# self_cond = False, +# combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully +# pixel_shuffle_upsample = False , # may address checkboard artifacts +# # ++ +# CKeys=CKeys, + +# ).to (self.device) + + assert elucidated , "Only elucidated model implemented...." + self.is_elucidated=elucidated + if elucidated: + self.imagen = ElucidatedImagen( + unets = (unet1), + channels=self.pred_dim, + channels_out=self.pred_dim , + loss_type=loss_type, + text_embed_dim = text_embed_dim, + image_sizes = [max_length], + cond_drop_prob = 0.2, + auto_normalize_img = False, + num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) + sigma_min = 0.002, # min noise level + sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + # ++ + CKeys=self.CKeys, + PKeys=self.PKeys, + + ).to (self.device) + else: + print ("Not implemented.") + + def forward(self, x, output, unet_number=1): #sequences=conditioning, output=prediction + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x + pos_emb_x=pos_emb_x[:,:x.shape[1],:] + x= torch.cat( (x, pos_emb_x ), 2) + + loss = self.imagen( + output, + text_embeds = x, + unet_number = unet_number, + ) + + return loss + + def sample (self, x, stop_at_unet_number=1 ,cond_scale=7.5,): + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x + pos_emb_x=pos_emb_x[:,:x.shape[1],:] + x= torch.cat( (x, pos_emb_x ), 2) + + output = self.imagen.sample(text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number) + + return output + +# =========================================================== +# final models: old model +# +class ProteinDesigner_A_Old(nn.Module): + def __init__( + self, + timesteps=10 , + dim=32, + pred_dim=25, + loss_type=0, + elucidated=False, + padding_idx=0, + cond_dim = 512, + text_embed_dim = 512, + input_tokens=25,#for non-BERT + sequence_embed=False, + embed_dim_position=32, + max_text_len=16, + device='cuda:0', + # ++ + max_length=64, + CKeys=None, + PKeys=None, +): + + super(ProteinDesigner_A_Old, self).__init__() + + # ++ + self.CKeys=CKeys + self.PKeys=PKeys + + self.device=device + self.pred_dim=pred_dim + 
self.loss_type=loss_type + + self.fc_embed1 = nn.Linear( 8, max_length) # NOT USED + self.fc_embed2 = nn.Linear( 1, text_embed_dim) # + self.max_text_len=max_text_len + + self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) + text_embed_dim=text_embed_dim+embed_dim_position + self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) + for i in range (max_text_len): + self.pos_matrix_i [i]=i +1 + + assert (loss_type==0), "Loss other than MSE not implemented" + + unet1 = OneD_Unet_Old( + dim = dim, + text_embed_dim = text_embed_dim, + cond_dim = cond_dim, + dim_mults = (1, 2, 4, 8), + + num_resnet_blocks = 1,#1, + layer_attns = (False, True, True, False), + layer_cross_attns = (False, True, True, False), + channels=self.pred_dim, + channels_out=self.pred_dim , + # + attn_dim_head = 64, + attn_heads = 8, + ff_mult = 2., + lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ + + layer_attns_depth =1, + layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 + attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) + use_linear_attn = False, + use_linear_cross_attn = False, + cond_on_text = True, + max_text_len = max_length, # need to be checked + init_dim = None, + resnet_groups = 8, + init_conv_kernel_size =7, # kernel size of initial conv, if not using cross embed + init_cross_embed = False, #TODO - fix ouput size calcs for conv1d + init_cross_embed_kernel_sizes = (3, 7, 15), + cross_embed_downsample = False, + cross_embed_downsample_kernel_sizes = (2, 4), + attn_pool_text = True, + attn_pool_num_latents = 32,#32, #perceiver model latents + dropout = 0., + memory_efficient = False, + init_conv_to_final_conv_residual = False, + use_global_context_attn = True, + scale_skip_connection = True, + final_resnet_block = True, + final_conv_kernel_size = 3, + cosine_sim_attn = True, + self_cond = False, + combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully + pixel_shuffle_upsample = False , # may address checkboard artifacts + # ++ + CKeys=CKeys, + + ).to (self.device) + # print the size + if CKeys['Debug_ModelPack']==1: + print("Check NUnet...") + params( unet1) + + assert elucidated , "Only elucidated model implemented...." 
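+        # Legacy variant: unlike ProteinDesigner_A_II, which receives its unet
+        # as a constructor argument, this class builds OneD_Unet_Old inline
+        # above with hard-coded hyperparameters; the ElucidatedImagen
+        # configuration below otherwise mirrors the A_II variant.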
+ self.is_elucidated=elucidated + if elucidated: + self.imagen = ElucidatedImagen( + unets = (unet1), + channels=self.pred_dim, + channels_out=self.pred_dim , + loss_type=loss_type, + text_embed_dim = text_embed_dim, + image_sizes = [max_length], + cond_drop_prob = 0.2, + auto_normalize_img = False, + num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) + sigma_min = 0.002, # min noise level + sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler + sigma_data = 0.5, # standard deviation of data distribution + rho = 7, # controls the sampling schedule + P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training + P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training + S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper + S_tmin = 0.05, + S_tmax = 50, + S_noise = 1.003, + # ++ + CKeys=self.CKeys, + PKeys=self.PKeys, + + ).to (self.device) + if CKeys['Debug_ModelPack']==1: + print("Check on EImagen:") + print("channels: ", self.pred_dim) + print("loss_type: ", loss_type) + print("text_embed_dim: ",text_embed_dim) + print("image_sizes: ", max_length) + print("num_sample_steps: ", timesteps) + print("Measure imagen:") + params( self.imagen) + print("Measure fc_embed2") + params( self.fc_embed2) + print("Measure pos_emb_x") + params( self.pos_emb_x) + else: + print ("Not implemented.") + + def forward(self, x, output, unet_number=1): #sequences=conditioning, output=prediction + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x + pos_emb_x=pos_emb_x[:,:x.shape[1],:] + x= torch.cat( (x, pos_emb_x ), 2) + + loss = self.imagen( + output, + text_embeds = x, + unet_number = unet_number, + ) + + return loss + + def sample (self, x, stop_at_unet_number=1 ,cond_scale=7.5,): + + x=x.unsqueeze (2) + + x= self.fc_embed2(x) + + pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) + pos_emb_x = self.pos_emb_x( pos_matrix_i_) + pos_emb_x = torch.squeeze(pos_emb_x, 1) + pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x + pos_emb_x=pos_emb_x[:,:x.shape[1],:] + x= torch.cat( (x, pos_emb_x ), 2) + + output = self.imagen.sample(text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number) + + return output \ No newline at end of file
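+
+# ===========================================================
+# Illustrative usage sketch for Model B (kept as comments; the values below
+# are hypothetical and untested, and `unet` is assumed to be a OneD_Unet
+# configured elsewhere via prepare_UNet_keys):
+#
+# CKeys = {'Debug_ModelPack': 0}
+# PKeys = {
+#     'timesteps': 96, 'dim': 32, 'pred_dim': 32, 'loss_type': 0,
+#     'elucidated': True, 'padding_idx': 0, 'cond_dim': 512,
+#     'text_embed_dim': 512, 'input_tokens': 25, 'sequence_embed': False,
+#     'embed_dim_position': 32, 'max_text_len': 64,
+#     'cond_images_channels': 0, 'max_length': 64, 'device': device,
+# }
+# model = ProteinDesigner_B(unet, CKeys=CKeys, PKeys=PKeys).to(device)
+#
+# # training step: output is (batch, pred_dim, max_length), x is a
+# # (batch, <= max_text_len) conditioning vector
+# loss = model(output, x=x, unet_number=1)
+# loss.backward()
+#
+# # sampling from a conditioning vector
+# seqs = model.sample(x=x, stop_at_unet_number=1, cond_scale=7.5, device=device)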