########################################################
## Attention-Diffusion model
########################################################
# based on: https://github.com/lucidrains/imagen-pytorch

import math
import copy
from random import random
from typing import List, Union
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch import nn, einsum
from torch.cuda.amp import autocast
from torch.special import expm1
import torchvision.transforms as T

import kornia.augmentation as K

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom

# from tensorflow.keras.preprocessing import text, sequence
from tensorflow.keras.preprocessing.text import Tokenizer

# # --
# from PD_SpLMxDiff.UtilityPack import prepare_UNet_keys, modify_keys, params
# ++
from PD_pLMProbXDiff.UtilityPack import (
    prepare_UNet_keys,
    modify_keys,
    params
)

# from torchinfo import summary
import json

# quick way to identify the device:
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)
print('identify the device independently', device)

# ==================================================
# helper functions
# ==================================================

def exists(val):
    """Return True if *val* is not None."""
    return val is not None

def identity(t, *args, **kwargs):
    """Return *t* unchanged; any extra arguments are ignored."""
    return t

def first(arr, d = None):
    """Return the first element of *arr*, or *d* when *arr* is empty."""
    if len(arr) == 0:
        return d
    return arr[0]

def maybe(fn):
    """Wrap single-argument *fn* so that a None input is returned as-is instead of calling *fn*."""
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

def once(fn):
    """Wrap single-argument *fn* so it only executes on the first call; later calls return None."""
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def default(val, d):
    """Return *val* if it is not None, otherwise *d* (called first when it is callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    """Coerce *val* into a tuple.

    Lists are converted and tuples kept; any other value is repeated
    ``length`` times (once when ``length`` is None). When ``length`` is
    given, the result is asserted to have exactly that many elements.
    """
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

def is_float_dtype(dtype):
    """Return True for torch.float64/float32/float16/bfloat16."""
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

def cast_uint8_images_to_float(images):
    """Scale uint8 tensors into [0, 1]; tensors of any other dtype are returned unchanged."""
    if not images.dtype == torch.uint8:
        return images
    return images / 255

def module_device(module):
    """Return the device of the first parameter of *module*."""
    return next(module.parameters()).device

def zero_init_(m):
    """Zero-initialise ``m.weight`` and, if present, ``m.bias`` in place."""
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

def eval_decorator(fn):
    """Run ``fn(model, ...)`` with *model* in eval mode, then restore its previous training flag.

    The flag is restored in a ``finally`` so an exception inside *fn* does not
    leave the model stuck in eval mode.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)
    return inner

def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple *t* with *fillvalue* up to *length* elements (never truncates)."""
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# ==================================================
# helper classes
# ==================================================

class Identity(nn.Module):
    """Module that returns its first input; constructor and extra forward args are ignored."""
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

# ==================================================
# tensor helpers
# ==================================================

def log(t, eps: float = 1e-12):
    """Numerically safe log: clamps *t* to at least *eps* before taking the log."""
    return torch.log(t.clamp(min = eps))

def l2norm(t):
    """L2-normalise *t* along its last dimension."""
    return F.normalize(t, dim = -1)

def right_pad_dims_to(x, t):
    """Append trailing singleton dims to *t* until it has as many dims as *x*."""
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

def masked_mean(t, *, dim, mask = None):
    """Mean of *t* over *dim*, counting only positions where boolean *mask* is True.

    The rearrange below requires *mask* to be 2-D (batch, seq); *t* is
    presumably (batch, seq, features) with ``dim=1`` — TODO confirm for
    other call sites. The denominator is clamped to avoid division by zero
    for fully-masked rows.
    """
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

def resize_image_to(
    image,
    target_image_size,
    clamp_range = None
):
    """Resize the last axis of a (batch, channels, length) tensor to *target_image_size*.

    Uses linear interpolation with ``align_corners=True``; the resized result
    is float. If the length already matches, *image* is returned untouched.
    When *clamp_range* ``(lo, hi)`` is given, the resized output is clamped
    to it (previously the argument was accepted but silently ignored).
    """
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    out = F.interpolate(image.float(), target_image_size, mode = 'linear', align_corners = True)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

def normalize_neg_one_to_one(img):
    """Map values from [0, 1] to [-1, 1]."""
    return img * 2 - 1

def unnormalize_zero_to_one(normed_img):
    """Map values from [-1, 1] back to [0, 1]."""
    return (normed_img + 1) * 0.5

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    """Boolean tensor of *shape* whose entries are True with probability *prob*.

    ``prob`` of exactly 1 or 0 short-circuits to all-True / all-False.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# ==================================================
# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
# ==================================================

@torch.jit.script
def beta_linear_log_snr(t):
    """log-SNR of the linear-beta schedule at continuous time *t*."""
    return -torch.log(expm1(1e-4 + 10 * (t ** 2)))

@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    """log-SNR of the cosine schedule at continuous time *t* with offset *s*."""
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version

def log_snr_to_alpha_sigma(log_snr):
    """Convert log-SNR into (alpha, sigma) with alpha^2 = sigmoid(log_snr), sigma^2 = sigmoid(-log_snr)."""
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

class GaussianDiffusionContinuousTimes(nn.Module):
    """Continuous-time Gaussian diffusion parameterised by a log-SNR schedule.

    ``noise_schedule`` selects ``"linear"`` or ``"cosine"``; ``timesteps``
    sets how many discrete steps sampling uses.
    """
    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        """Return a (batch_size,) float32 tensor filled with *noise_level*."""
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        """Sample (batch_size,) times uniformly from [0, max_thres)."""
        return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        """log-SNR of *times*, or None when *times* is None."""
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        """Return ``num_timesteps`` tensors of shape (2, batch) holding (t, t_next), going from 1 to 0."""
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        """Mean, variance and clipped log-variance of q(x_{t_next} | x_t, x_start).

        ``t_next`` defaults to one step (1 / num_timesteps) earlier, clamped at 0.
        """
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        # https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        """Noise *x_start* to time *t*; returns ``(x_t, log_snr)``.

        A Python float *t* is broadcast to the batch; *noise* defaults to a
        fresh standard normal sample.
        """
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        """Move a noised sample from time *from_t* to time *to_t* (floats are broadcast to the batch)."""
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and the predicted noise; alpha is clamped to avoid division by zero."""
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

# ==================================================
# norms and residuals
# ==================================================

class LayerNorm(nn.Module):
    """Bias-free layer norm over a single axis ``dim`` with a learned gain.

    The gain has shape ``(feats, 1, ..., 1)`` so it broadcasts when ``dim`` is
    not the last axis. ``stable=True`` first divides by the detached max along
    ``dim``. eps is 1e-5 for float32 inputs and 1e-3 otherwise.
    """
    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)

# channel-wise layer norm for (batch, channels, length) feature maps
ChanLayerNorm = partial(LayerNorm, dim = -2)

class Always():
    """Callable that ignores its arguments and always returns ``val``."""
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

class Residual(nn.Module):
    """Wrap ``fn`` as ``fn(x, **kwargs) + x``."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class Parallel(nn.Module):
    """Apply every module in ``fns`` to the same input and sum the outputs."""
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# ==================================================
# attention pooling
# ==================================================
# modified
class PerceiverAttention(nn.Module):
    """Latents attend to the sequence ``x`` concatenated with the latents themselves."""
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        """*mask* (batch, seq) marks valid positions of *x*; latent keys are always attended."""
        x = self.norm(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values
        # derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities and masking

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # latent keys are appended after x, so pad the mask with True on the right
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    """Compress a variable-length sequence into a fixed set of learned latents.

    Output has ``num_latents + num_latents_mean_pooled`` tokens of width ``dim``.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): the mean pool uses an all-True mask, so padded positions are
            # included here; *mask* only affects the attention layers below.
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents

# ====================================================
# attention
# ====================================================

class Attention(nn.Module):
    """Multi-query self-attention (single shared key/value head) with a learned null key/value.

    Optional ``context`` tokens are projected to extra keys/values and
    prepended; ``attn_bias`` is added to the similarity logits.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b = x.shape[0]

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # NOTE(review): the pad only accounts for the null key; when context keys are
            # also prepended, callers must supply a mask that already covers them.
            mask = F.pad(mask, (1, 0), value = True)
            # sim is (b, h, i, j): the mask must broadcast over heads and queries
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# ============================================================
# decoder
# ============================================================

def Upsample(dim, dim_out = None):
    """2x nearest-neighbour upsampling along length, followed by a 3-tap Conv1d."""
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, dim_out, 3, padding = 1)
    )

class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf

    1-D sub-pixel upsampling by a factor of 2: a 1x1 Conv1d produces
    ``dim_out * 2`` channels which are folded into the length axis. The
    previous version used the 2-D ``nn.PixelShuffle`` with a x4 channel
    expansion, which does not apply to (batch, channels, length) inputs.
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv1d(dim, dim_out * 2, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange('b (c r) n -> b c (n r)', r = 2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        """ICNR init: each group of 2 output channels shares weights, so the shuffle starts as nearest upsampling."""
        o, i, k = conv.weight.shape
        conv_weight = torch.empty(o // 2, i, k)
        nn.init.kaiming_uniform_(conv_weight)
        # '(o r)' matches the channel grouping '(c r)' used by the Rearrange above
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = 2)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

def Downsample(dim, dim_out = None):
    """Halve the length by folding pairs of positions into channels, then a 1x1 Conv1d."""
    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    # named SP-conv in the paper, but basically a pixel unshuffle
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c (h s1) -> b (c s1) h', s1 = 2),
        nn.Conv1d(dim * 2, dim_out, 1)
    )

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions/times into ``dim`` features."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

class LearnedSinusoidalPosEmb(nn.Module):
    """Learned-frequency sinusoidal embedding, following @crowsonkb's lead.

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Maps a (batch,) tensor to (batch, dim + 1): the raw input followed by
    sin and cos of ``dim // 2`` learned frequencies.
    """
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

class Block(nn.Module):
    """GroupNorm -> optional (scale, shift) modulation -> SiLU -> 3-tap Conv1d."""
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv1d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x)

class ResnetBlock(nn.Module):
    """Two Blocks with time-embedding modulation, optional cross-attention and global-context gating.

    ``linear_attn`` picks LinearCrossAttention over CrossAttention when
    ``cond_dim`` is set; ``squeeze_excite`` is accepted but unused.
    """
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            # attention works on (batch, length, channels), so swap axes around it
            self.cross_attn = EinopsToAndFrom(
                'b c h',
                'b h c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim,
                    **attn_kwargs
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(self, x, time_emb = None, cond = None):

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            h = self.cross_attn(h, context = cond) + h

        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    """Multi-head attention from ``x`` to ``context`` with a learned null key/value per batch."""
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """*mask* (batch, context_len) marks valid context tokens; the null key is always kept."""
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # sim is (b, h, i, j): the mask must broadcast over heads and queries
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of CrossAttention (softmax over features for q, over tokens for k)."""
    def forward(self, x, context, mask = None):
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # k / v are laid out as ((b h), n, d), so the (b, n) mask is repeated per head
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

class LinearAttention(nn.Module):
    """Linear attention over a (batch, channels, length) feature map.

    Queries/keys/values come from a 1x1 Conv1d followed by a depthwise 3-tap
    Conv1d. Optional ``context`` tokens are appended to the keys/values.
    The einops patterns are 1-D; the previous 2-D ``x y`` patterns could not
    be applied to the 3-D tensors produced by the Conv1d layers.
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv1d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h = self.heads

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = rearrange_many((q, k, v), 'b (h c) n -> (b h) n c', h = h)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b (h d) n', h = h)

        out = self.nonlin(out)
        return self.to_out(out)

class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque

    Pools the feature map with a learned softmax over positions, then maps
    the pooled vector through a small Conv1d MLP ending in a sigmoid gate.
    """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        self.to_k = nn.Conv1d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv1d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv1d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        x, context = rearrange_many((x, context), 'b n ... -> b n (...)')
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        return self.net(out)

def FeedForward(dim, mult = 2):
    """Token-wise MLP on the last axis: LayerNorm -> Linear -> GELU -> LayerNorm -> Linear."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )

def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
    """Channel-wise MLP on (batch, channels, length) maps, built from 1x1 Conv1d layers."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv1d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv1d(hidden_dim, dim, 1, bias = False)
    )

class TransformerBlock(nn.Module):
    """``depth`` residual (Attention, ChanFeedForward) pairs over a (batch, channels, length) map."""
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                EinopsToAndFrom('b c h', 'b h c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class LinearAttentionTransformerBlock(nn.Module):
    """``depth`` residual (LinearAttention, ChanFeedForward) pairs; extra kwargs are ignored."""
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CrossEmbedLayer(nn.Module):
    """Multi-kernel Conv1d stem whose outputs are concatenated along channels.

    Kernels are sorted ascending; smaller kernels get larger channel shares
    (dim_out/2, dim_out/4, ...), the largest takes the remainder. Every
    kernel size must have the same parity as ``stride``.
    """
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all((t % 2) == (stride % 2) for t in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv1d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

class UpsampleCombiner(nn.Module):
    """Optionally resize feature maps to ``x``'s length, project each with a Block, and concat onto ``x``.

    ``dim_out`` exposes the resulting channel count (``dim`` when disabled).
    """
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)

########################################################
## 1D Unet updated
########################################################

class OneD_Unet(nn.Module):
    def __init__(
        self,
        *,
        # +++++++++++++++++++++++++++
        CKeys=None,
        PKeys=None,
    ):
        super().__init__()

        # save locals to take care of some hyperparameters for cascading DDPM

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        # unload the parameters
        # print('I am here')
        if CKeys['Debug_ModelPack']==1:
            print(json.dumps(PKeys, indent=4))

        dim = PKeys['dim'] # ,
        # image_embed_dim = 1024, # NOT used so far
        text_embed_dim = default(PKeys['text_embed_dim'], 768) # 768, # get_encoded_dim(DEFAULT_T5_NAME)
        num_resnet_blocks = default(PKeys['num_resnet_blocks'], 1)# 1,
        cond_dim = default(PKeys['cond_dim'], None) # None,
        num_image_tokens = default(PKeys['num_image_tokens'], 4) # 4,
        num_time_tokens = default(PKeys['num_time_tokens'], 2) # 2,
        learned_sinu_pos_emb_dim = default(PKeys['learned_sinu_pos_emb_dim'], 16) # 16,
        out_dim = default(PKeys['out_dim'], None) # None,
        dim_mults = default(PKeys['dim_mults'], (1, 2, 4, 8)) # (1, 2, 4, 8),
        cond_images_channels = default(PKeys['cond_images_channels'], 0) # 0,
        channels = default(PKeys['channels'], 3) # 3,
        channels_out = default(PKeys['channels_out'], None) # None,
        attn_dim_head = default(PKeys['attn_dim_head'], 64) # 64,
        attn_heads = default(PKeys['attn_heads'], 8) # 8,
        ff_mult = default(PKeys['ff_mult'], 2.)
# 2., lowres_cond = default(PKeys['lowres_cond'], False) # False, # for cascading diffusion - https://cascaded-diffusion.github.io/ layer_attns = default(PKeys['layer_attns'], True) # True, layer_attns_depth = default(PKeys['layer_attns_depth'], 1) # 1, layer_attns_add_text_cond = default(PKeys['layer_attns_add_text_cond'], True) # True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 attend_at_middle = default(PKeys['attend_at_middle'], True) # True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) layer_cross_attns = default(PKeys['layer_cross_attns'], True) # True, use_linear_attn = default(PKeys['use_linear_attn'], False) # False, use_linear_cross_attn = default(PKeys['use_linear_cross_attn'], False) # False, cond_on_text = default(PKeys['cond_on_text'], True) # True, max_text_len = default(PKeys['max_text_len'], 256) # 256, init_dim = default(PKeys['init_dim'], None) # None, resnet_groups = default(PKeys['resnet_groups'], 8) # 8, init_conv_kernel_size = default(PKeys['init_conv_kernel_size'], 7) # 7, # kernel size of initial conv, if not using cross embed init_cross_embed = default(PKeys['init_cross_embed'], False) # False, init_cross_embed_kernel_sizes = default(PKeys['init_cross_embed_kernel_sizes'], (3, 7, 15)) # (3, 7, 15), cross_embed_downsample = default(PKeys['cross_embed_downsample'], False) # False, cross_embed_downsample_kernel_sizes = default(PKeys['cross_embed_downsample_kernel_sizes'], (2,4)) # (2, 4), attn_pool_text = default(PKeys['attn_pool_text'], True) # True, attn_pool_num_latents = default(PKeys['attn_pool_num_latents'], 32) # 32, dropout = default(PKeys['dropout'], 0.) 
# 0., memory_efficient = default(PKeys['memory_efficient'], False) # False, init_conv_to_final_conv_residual = default(PKeys['init_conv_to_final_conv_residual'], False) # False, use_global_context_attn = default(PKeys['use_global_context_attn'], True) # True, scale_skip_connection = default(PKeys['scale_skip_connection'], True) # True, final_resnet_block = default(PKeys['final_resnet_block'], True) # True, final_conv_kernel_size = default(PKeys['final_conv_kernel_size'], 3) # 3, cosine_sim_attn = default(PKeys['cosine_sim_attn'], False) # False, self_cond = default(PKeys['self_cond'], False) # False, combine_upsample_fmaps = default(PKeys['combine_upsample_fmaps'], False) # False, # combine feature maps from all upsample blocks, used in unet squared successfully pixel_shuffle_upsample = default(PKeys['pixel_shuffle_upsample'], False) # False , # may address checkboard artifacts beginning_and_final_conv_present = default(PKeys['beginning_and_final_conv_present'], True) # True , #TODO add cross-attn, doesnt work yet...whether or not to have final conv layer # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< self.CKeys=CKeys # + for debug if CKeys['Debug_ModelPack']==1: print("Check the inputs:") print(json.dumps(PKeys, indent=4)) assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' if dim < 128: print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') # determine dimensions self.channels = channels self.channels_out = default(channels_out, channels) # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis # (2) in self conditioning, one appends the predict x0 (x_start) init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) init_dim = default(init_dim, dim) self.self_cond = 
self_cond # optional image conditioning self.has_cond_image = cond_images_channels > 0 self.cond_images_channels = cond_images_channels init_channels += cond_images_channels self.beginning_and_final_conv_present=beginning_and_final_conv_present if self.beginning_and_final_conv_present: self.init_conv = CrossEmbedLayer( init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1 ) if init_cross_embed else nn.Conv1d( init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) if self.CKeys['Debug_ModelPack']==1 and self.beginning_and_final_conv_present: print("On self.init_conv:") print(f"init_channels: {str(init_channels)}\n init_dim: {str(init_dim)}") print("On self.init_conv, batch#=1, init_ch=2*#seq_channel, seq_len=128") summary(self.init_conv, (1, init_channels, 128), verbose=1) # # + for debug # if self.CKeys['Debug_ModelPack']==1: # print("On self.init_conv") # summary(self.init_conv, (1, init_channels, init_dim)) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning cond_dim = default(cond_dim, dim) time_cond_dim = dim * 4 * (2 if lowres_cond else 1) # embedding time for log(snr) noise from continuous version sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 self.to_time_hiddens = nn.Sequential( sinu_pos_emb, nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), nn.SiLU() ) # # + for debug # if self.CKeys['Debug_ModelPack']==1: # print("On self.to_time_hiddens") # print(" in dim: ", learned_sinu_pos_emb_dim) # print(" ou dim: ", time_cond_dim) # # summary(self.to_time_hiddens, (1,learned_sinu_pos_emb_dim), verbose=1) # # summary(sinu_pos_emb, (1,learned_sinu_pos_emb_dim), verbose=1) self.to_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) # project to time tokens as well as time hiddens # + for debug if self.CKeys['Debug_ModelPack']==1: print(" 
to_time_tokens") print(" in dim: ", time_cond_dim) # =4xdim cond_dim=dim print(" ou dim: ", num_time_tokens) self.to_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # low res aug noise conditioning self.lowres_cond = lowres_cond if lowres_cond: self.to_lowres_time_hiddens = nn.Sequential( LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), nn.SiLU() ) self.to_lowres_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) self.to_lowres_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # normalizations self.norm_cond = nn.LayerNorm(cond_dim) # text encoding conditioning (optional) self.text_to_cond = None if cond_on_text: #only add linear lear if cond dim is not text emnd dim assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' if text_embed_dim != cond_dim: self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) self.text_cond_linear=True else: print ("Text conditioning is equatl to cond_dim - no linear layer used") self.text_cond_linear=False # finer control over whether to condition on text encodings self.cond_on_text = cond_on_text # attention pooling self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None # for classifier free guidance self.max_text_len = max_text_len self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) # for non-attention based text conditioning at all points in the network where time is also conditioned self.to_text_non_attn_cond = None if cond_on_text: self.to_text_non_attn_cond = nn.Sequential( nn.LayerNorm(cond_dim), 
nn.Linear(cond_dim, time_cond_dim), nn.SiLU(), nn.Linear(time_cond_dim, time_cond_dim) ) # attention related params attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) num_layers = len(in_out) # resnet block klass num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) resnet_groups = cast_tuple(resnet_groups, num_layers) resnet_klass = partial(ResnetBlock, **attn_kwargs) layer_attns = cast_tuple(layer_attns, num_layers) layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) use_linear_attn = cast_tuple(use_linear_attn, num_layers) use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers) assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) # downsample klass downsample_klass = Downsample if cross_embed_downsample: downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) # initial resnet block (for memory efficient unet) self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None # scale for resnet skip connections self.skip_connect_scale = 1. 
if not scale_skip_connection else (2 ** -0.5) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] reversed_layer_params = list(map(reversed, layer_params)) # downsampling layers skip_connect_dims = [] # keep track of skip connection dimensions for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): is_last = ind >= (num_resolutions - 1) layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None if layer_attn: transformer_block_klass = TransformerBlock elif layer_use_linear_attn: transformer_block_klass = LinearAttentionTransformerBlock else: transformer_block_klass = Identity current_dim = dim_in # whether to pre-downsample, from memory efficient unet pre_downsample = None if memory_efficient: pre_downsample = downsample_klass(dim_in, dim_out) current_dim = dim_out skip_connect_dims.append(current_dim) # whether to do post-downsample, for non-memory efficient unet post_downsample = None if not memory_efficient: post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1)) self.downs.append( nn.ModuleList([ pre_downsample, resnet_klass( current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups ), nn.ModuleList([ ResnetBlock( current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn ) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass( dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), post_downsample ]) ) # 
middle layers mid_dim = dims[-1] self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) # upsample klass upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample # upsampling layers upsample_fmap_dims = [] for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): is_last = ind == (len(in_out) - 1) layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None if layer_attn: transformer_block_klass = TransformerBlock elif layer_use_linear_attn: transformer_block_klass = LinearAttentionTransformerBlock else: transformer_block_klass = Identity skip_connect_dim = skip_connect_dims.pop() upsample_fmap_dims.append(dim_out) self.ups.append(nn.ModuleList([ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() ])) # whether to combine feature maps from all upsample blocks before final resnet block out self.upsample_combiner = UpsampleCombiner( dim = dim, enabled = combine_upsample_fmaps, dim_ins = upsample_fmap_dims, dim_outs = dim ) # 
whether to do a final residual from initial conv to the final resnet block out self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) # final optional resnet block and convolution out self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None final_conv_dim_in = dim if final_resnet_block else final_conv_dim final_conv_dim_in += (channels if lowres_cond else 0) if self.beginning_and_final_conv_present: print (final_conv_dim_in, self.channels_out) self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) if self.beginning_and_final_conv_present: zero_init_(self.final_conv) # if the current settings for the unet are not correct # for cascading DDPM, then reinit the unet with the right settings def cast_model_parameters( self, *, lowres_cond, text_embed_dim, channels, channels_out, cond_on_text ): # # ---------------------------------------------- # if lowres_cond == self.lowres_cond and \ # channels == self.channels and \ # cond_on_text == self.cond_on_text and \ # text_embed_dim == self._locals['text_embed_dim'] and \ # channels_out == self.channels_out: # return self # +++++++++++++++++++++++++++++++++++++++++++++++ if lowres_cond == self.lowres_cond and \ channels == self.channels and \ cond_on_text == self.cond_on_text and \ text_embed_dim == self._locals['PKeys']['text_embed_dim'] and \ channels_out == self.channels_out: return self # # -------------------------------------- # updated_kwargs = dict( # lowres_cond = lowres_cond, # text_embed_dim = text_embed_dim, # channels = channels, # channels_out = channels_out, # cond_on_text = cond_on_text # ) # ++++++++++++++++++++++++++++++++++++++ write_key=dict( lowres_cond = lowres_cond, text_embed_dim = text_embed_dim, channels = 
channels, channels_out = channels_out, cond_on_text = cond_on_text ) # this_PKeys=prepare_UNet_keys(write_key) old_PKeys=self._locals['PKeys'] this_PKeys=modify_keys(old_PKeys, write_key) # write the new arguments updated_kwargs = dict( CKeys=self.CKeys, PKeys=this_PKeys, ) return self.__class__(**{**self._locals, **updated_kwargs}) # methods for returning the full unet config as well as its parameter state def to_config_and_state_dict(self): return self._locals, self.state_dict() # class method for rehydrating the unet from its config and state dict @classmethod def from_config_and_state_dict(klass, config, state_dict): unet = klass(**config) unet.load_state_dict(state_dict) return unet # methods for persisting unet to disk def persist_to_file(self, path): path = Path(path) path.parents[0].mkdir(exist_ok = True, parents = True) config, state_dict = self.to_config_and_state_dict() pkg = dict(config = config, state_dict = state_dict) torch.save(pkg, str(path)) # class method for rehydrating the unet from file saved with `persist_to_file` @classmethod def hydrate_from_file(klass, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) assert 'config' in pkg and 'state_dict' in pkg config, state_dict = pkg['config'], pkg['state_dict'] return Unet.from_config_and_state_dict(config, state_dict) # forward with classifier free guidance def forward_with_cond_scale( self, *args, cond_scale = 1., **kwargs ): logits = self.forward(*args, **kwargs) if cond_scale == 1: return logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale # forward fun returns the loss # Here, for Model B: # x : (batch, 1, seq_len) values: from noised image/Y # time : (batch) values: # text_embeds : None # text_mask: None # cond_images: input sequence/X # lowres_noise_times: None # lowres_cond_img: None # cond_drop_prob: 0.1 # self_cond: from noised image/Y def forward( self, x, time, *, lowres_cond_img = None, 
lowres_noise_times = None, text_embeds = None, text_mask = None, cond_images = None, self_cond = None, cond_drop_prob = 0. ): # + for debug if self.CKeys['Debug_ModelPack']==1: print("========================================") print("Here are OneD_Unet:forward") ii=0 batch_size, device = x.shape[0], x.device # condition on self # + for debug if self.CKeys['Debug_ModelPack']==1: print("Check inputs: ") print(" x-0 .dim: ", x.shape, f"[batch, {str(self.channels)}, seq_len]") print(" time .dim: ", time.shape, "batch") if lowres_cond_img==None: print(" lowres_cond_img: None") else: print(" lowres_cond_img.dim: ", lowres_cond_img.shape) if lowres_noise_times==None: print(" lowres_noise_times: None") else: print(" lowres_noise_times.dim: ", lowres_noise_times.shape) if cond_images==None: print(" cond_images dim: None") else: # Model B is used print(" cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, seq_len]") if text_embeds==None: print(" text_embeds.dim: None") else: # Model A is used print(" text_embeds.dim: ", text_embeds.shape) if self_cond==None: print(" self_cond: None") else: print(" self_cond.dim: ", self_cond.shape) print("\n\n") if self.self_cond: self_cond = default(self_cond, lambda: torch.zeros_like(x)) x = torch.cat((x, self_cond), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("self_cond dim: ", self_cond.shape) print("After cat(x, self_cond)-> x. 
dim: ", x.shape) # add low resolution conditioning, if present assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("lowres_cond_img dim: ", lowres_cond_img.shape) print("After cat(x, lowres_cond_img) dim: ", x.shape) # condition on input image assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' if exists(cond_images): assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' cond_images = resize_image_to(cond_images, x.shape[-1]) # ++++++++++++++++++++++++++++++ # add cond_images (from X) into generation process... 
x = torch.cat((cond_images.to(device), x.to(device)), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, max_seq_len]") print("After cat(cond_images, x), x dim: ", x.shape, "[batch, cond_images_channels+images_channels, max_seq_len]") # initial convolution if self.beginning_and_final_conv_present: x = self.init_conv(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("After self.init_conv(x)-> x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]") # init conv residual if self.init_conv_to_final_conv_residual: init_conv_residual = x.clone() # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("x.clone()->init_conv_resi, dim: ", init_conv_residual.shape) # time conditioning time_hiddens = self.to_time_hiddens(time) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("time dim: ", time.shape, "[batch]") print("self.to_time_hiddens(time)-> time_hiddens .dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]") # derive time tokens time_tokens = self.to_time_tokens(time_hiddens) t = self.to_time_cond(time_hiddens) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("self.to_time_tokens(time_hiddens)-> time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]") print("self.to_time_cond(time_hiddens)-> t dim: ", t.shape, "[batch, 4xUNet:dim]") # add lowres time conditioning to time hiddens # and add lowres time tokens along sequence dimension for attention if self.lowres_cond: lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("self.to_lowres_time_hiddens(lowres_noise_times)-> lowres_time_hiddens .dim: ", 
lowres_time_hiddens.shape) print("self.to_lowres_time_tokens(lowres_time_hiddens)-> lowres_time_tokens .dim: ", lowres_time_tokens.shape) print("self.to_lowres_time_cond(lowres_time_hiddens)-> lowres_t.dim: ", lowres_t.shape) t = t + lowres_t time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) # + for debug if self.CKeys['Debug_ModelPack']==1: ii +=1 print(ii) print("After cat(time_tokens, lowres_time_tokens)-> time_tokens dim: ", time_tokens.shape) # text conditioning text_tokens = None # + for debug if self.CKeys['Debug_ModelPack']==1: print("UNet.cond_on_text: ", self.cond_on_text) if exists(text_embeds) and self.cond_on_text: # conditional dropout text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') # calculate text embeds if self.text_cond_linear: text_tokens = self.text_to_cond(text_embeds) else: text_tokens=text_embeds if self.CKeys['Debug_ModelPack']==1: print("On text conditioning part...") ii +=1 print(ii) print("text_embeds->text_tokens.dim: ", text_tokens.shape) text_tokens = text_tokens[:, :self.max_text_len] if self.CKeys['Debug_ModelPack']==1: ii +=1 print(ii) print("text_tokens[:,:max_text_len]-> text_tokens.dim: ", text_tokens.shape) print("do the same for text_mask") if exists(text_mask): text_mask = text_mask[:, :self.max_text_len] text_tokens_len = text_tokens.shape[1] remainder = self.max_text_len - text_tokens_len if remainder > 0: text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) if exists(text_mask): if remainder > 0: text_mask = F.pad(text_mask, (0, remainder), value = False) text_mask = rearrange(text_mask, 'b n -> b n 1') text_keep_mask_embed = text_mask & text_keep_mask_embed null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working text_tokens = torch.where( text_keep_mask_embed, text_tokens, null_text_embed ) if 
exists(self.attn_pool): text_tokens = self.attn_pool(text_tokens) if self.CKeys['Debug_ModelPack']==1: ii+=1 print(ii) print("self.attn_pool(text_tokens)->text_tokens.dim: ", text_tokens.shape) # extra non-attention conditioning by projecting and then summing text embeddings to time # termed as text hiddens mean_pooled_text_tokens = text_tokens.mean(dim = -2) if self.CKeys['Debug_ModelPack']==1: ii +=1 print(ii) print("text_tokens.mean(dim=-2)->mean_pooled_text_tokens.dim: ", mean_pooled_text_tokens.shape) text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) if self.CKeys['Debug_ModelPack']==1: ii +=1 print(ii) print("self.to_text_non_attn_cond(mean_pooled_text_tokens)->text_hiddens.dim: ",text_hiddens.shape) null_text_hidden = self.null_text_hidden.to(t.dtype) text_hiddens = torch.where( text_keep_mask_hidden, text_hiddens, null_text_hidden ) t = t + text_hiddens if self.CKeys['Debug_ModelPack']==1: ii +=1 print(ii) print("t+text_hiddens.dim: ", t.shape) # main conditioning tokens (c) c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("cat(time_tokens, text_tokens)-> c dim: ", c.shape) # normalize conditioning tokens c = self.norm_cond(c) if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("self.norm_cond(c)->c dim:", c.shape) # initial resnet block (for memory efficient unet) if exists(self.init_resnet_block): x = self.init_resnet_block(x, t) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(ii) print("self.init_resnet_block(x,t)-> x dim: ", x.shape) # go through the layers of the unet, down and up # + for debug if self.CKeys['Debug_ModelPack']==1: print("Before unet, down and up, ") print(" x dim: ", x.shape) print(" t dim: ", t.shape) print(" c dim: ", c.shape) # ii=0 hiddens = [] for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: if exists(pre_downsample): x = 
pre_downsample(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, pre_downsample(x)=>x dim: ", x.shape) x = init_block(x, t, c) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, init_block(x,t,c)=>x dim: ", x.shape) for resnet_block in resnet_blocks: x = resnet_block(x, t) hiddens.append(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape) x = attn_block(x, c) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape) hiddens.append(x) if exists(post_downsample): x = post_downsample(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, post_downsample(x)=> x dim: ", x.shape) x = self.mid_block1(x, t, c) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, mid_block_1(x,t,c)=> x dim: ", x.shape) if exists(self.mid_attn): x = self.mid_attn(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, mid_attn(x)=> x dim: ", x.shape) x = self.mid_block2(x, t, c) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, mid_block_2(x,t,c)=> x dim: ", x.shape) add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) up_hiddens = [] for init_block, resnet_blocks, attn_block, upsample in self.ups: x = add_skip_connection(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape) # x = init_block(x, t, c) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, init_block(x,t,c)=> x dim: ", x.shape) for resnet_block in resnet_blocks: x = add_skip_connection(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape) x = resnet_block(x, t) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" 
{str(ii)}, resnet_block(x,t)=> x dim: ", x.shape) x = attn_block(x, c) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape) up_hiddens.append(x.contiguous()) x = upsample(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, upsample(x)=> x dim: ", x.shape) # whether to combine all feature maps from upsample blocks x = self.upsample_combiner(x, up_hiddens) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, upsample_combiner(x,up_hiddens)=> x dim: ", x.shape) # final top-most residual if needed if self.init_conv_to_final_conv_residual: x = torch.cat((x, init_conv_residual), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, cat(x,init_conv_residual)=> x dim: ", x.shape) if exists(self.final_res_block): x = self.final_res_block(x, t) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, final_res_block(x,t)=> x dim: ", x.shape) if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, cat(x,lowres_cond_img)=> x dim: ", x.shape) if self.beginning_and_final_conv_present: x=self.final_conv(x) # + for debug if self.CKeys['Debug_ModelPack']==1: ii += 1 print(F" {str(ii)}, final_conv(x)=> x dim: ", x.shape) return x ######################################################## ## 1D Unet ######################################################## class OneD_Unet_Old(nn.Module): def __init__( self, *, dim, # image_embed_dim = 1024, # NOT used so far text_embed_dim = 768, # get_encoded_dim(DEFAULT_T5_NAME) num_resnet_blocks = 1, cond_dim = None, num_image_tokens = 4, num_time_tokens = 2, learned_sinu_pos_emb_dim = 16, out_dim = None, dim_mults=(1, 2, 4, 8), cond_images_channels = 0, channels = 3, channels_out = None, attn_dim_head = 64, attn_heads = 8, ff_mult = 2., lowres_cond = False, # for cascading diffusion 
- https://cascaded-diffusion.github.io/ layer_attns = True, layer_attns_depth = 1, layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) layer_cross_attns = True, use_linear_attn = False, use_linear_cross_attn = False, cond_on_text = True, max_text_len = 256, init_dim = None, resnet_groups = 8, init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed init_cross_embed = False, init_cross_embed_kernel_sizes = (3, 7, 15), cross_embed_downsample = False, cross_embed_downsample_kernel_sizes = (2, 4), attn_pool_text = True, attn_pool_num_latents = 32, dropout = 0., memory_efficient = False, init_conv_to_final_conv_residual = False, use_global_context_attn = True, scale_skip_connection = True, final_resnet_block = True, final_conv_kernel_size = 3, cosine_sim_attn = False, self_cond = False, combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully pixel_shuffle_upsample = False , # may address checkboard artifacts beginning_and_final_conv_present = True , #TODO add cross-attn, doesnt work yet...whether or not to have final conv layer # +++++++++++++++++++++++++++ CKeys=None, ): super().__init__() # +++++++++++++++++++++++++++++++ self.CKeys=CKeys assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' if dim < 128: print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') # save locals to take care of some hyperparameters for cascading DDPM self._locals = locals() self._locals.pop('self', None) self._locals.pop('__class__', None) # + for 
debug if CKeys['Debug_ModelPack']==1: print("Showing the input:") print(json.dumps(self._locals, indent=4)) # determine dimensions self.channels = channels self.channels_out = default(channels_out, channels) # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis # (2) in self conditioning, one appends the predict x0 (x_start) init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) init_dim = default(init_dim, dim) self.self_cond = self_cond # optional image conditioning self.has_cond_image = cond_images_channels > 0 self.cond_images_channels = cond_images_channels init_channels += cond_images_channels self.beginning_and_final_conv_present=beginning_and_final_conv_present if self.beginning_and_final_conv_present: self.init_conv = CrossEmbedLayer( init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1 ) if init_cross_embed else nn.Conv1d( init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning cond_dim = default(cond_dim, dim) time_cond_dim = dim * 4 * (2 if lowres_cond else 1) # embedding time for log(snr) noise from continuous version sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 # + for debug if self.CKeys['Debug_ModelPack']==1: print(" to_time_hiddens") print(" ou dim: ", time_cond_dim) self.to_time_hiddens = nn.Sequential( sinu_pos_emb, nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), nn.SiLU() ) self.to_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) # project to time tokens as well as time hiddens # + for debug if self.CKeys['Debug_ModelPack']==1: print(" to_time_tokens") print(" in dim: ", time_cond_dim) # =4xdim cond_dim=dim print(" ou dim: ", num_time_tokens) self.to_time_tokens = nn.Sequential( 
nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # low res aug noise conditioning self.lowres_cond = lowres_cond if lowres_cond: self.to_lowres_time_hiddens = nn.Sequential( LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), nn.SiLU() ) self.to_lowres_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) self.to_lowres_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # normalizations self.norm_cond = nn.LayerNorm(cond_dim) # text encoding conditioning (optional) self.text_to_cond = None if cond_on_text: #only add linear lear if cond dim is not text emnd dim assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' if text_embed_dim != cond_dim: self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) self.text_cond_linear=True else: print ("Text conditioning is equatl to cond_dim - no linear layer used") self.text_cond_linear=False # finer control over whether to condition on text encodings self.cond_on_text = cond_on_text # attention pooling self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None # for classifier free guidance self.max_text_len = max_text_len self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) # for non-attention based text conditioning at all points in the network where time is also conditioned self.to_text_non_attn_cond = None if cond_on_text: self.to_text_non_attn_cond = nn.Sequential( nn.LayerNorm(cond_dim), nn.Linear(cond_dim, time_cond_dim), nn.SiLU(), nn.Linear(time_cond_dim, time_cond_dim) ) # attention related params attn_kwargs = dict(heads = 
attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) num_layers = len(in_out) # resnet block klass num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) resnet_groups = cast_tuple(resnet_groups, num_layers) resnet_klass = partial(ResnetBlock, **attn_kwargs) layer_attns = cast_tuple(layer_attns, num_layers) layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) use_linear_attn = cast_tuple(use_linear_attn, num_layers) use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers) assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) # downsample klass downsample_klass = Downsample if cross_embed_downsample: downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) # initial resnet block (for memory efficient unet) self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None # scale for resnet skip connections self.skip_connect_scale = 1. 
if not scale_skip_connection else (2 ** -0.5) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] reversed_layer_params = list(map(reversed, layer_params)) # downsampling layers skip_connect_dims = [] # keep track of skip connection dimensions for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): is_last = ind >= (num_resolutions - 1) layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None if layer_attn: transformer_block_klass = TransformerBlock elif layer_use_linear_attn: transformer_block_klass = LinearAttentionTransformerBlock else: transformer_block_klass = Identity current_dim = dim_in # whether to pre-downsample, from memory efficient unet pre_downsample = None if memory_efficient: pre_downsample = downsample_klass(dim_in, dim_out) current_dim = dim_out skip_connect_dims.append(current_dim) # whether to do post-downsample, for non-memory efficient unet post_downsample = None if not memory_efficient: post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1)) self.downs.append( nn.ModuleList([ pre_downsample, resnet_klass( current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups ), nn.ModuleList([ ResnetBlock( current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn ) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass( dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), post_downsample ]) ) # 
middle layers mid_dim = dims[-1] self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) # upsample klass upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample # upsampling layers upsample_fmap_dims = [] for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): is_last = ind == (len(in_out) - 1) layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None if layer_attn: transformer_block_klass = TransformerBlock elif layer_use_linear_attn: transformer_block_klass = LinearAttentionTransformerBlock else: transformer_block_klass = Identity skip_connect_dim = skip_connect_dims.pop() upsample_fmap_dims.append(dim_out) self.ups.append(nn.ModuleList([ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() ])) # whether to combine feature maps from all upsample blocks before final resnet block out self.upsample_combiner = UpsampleCombiner( dim = dim, enabled = combine_upsample_fmaps, dim_ins = upsample_fmap_dims, dim_outs = dim ) # 
whether to do a final residual from initial conv to the final resnet block out self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) # final optional resnet block and convolution out self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None final_conv_dim_in = dim if final_resnet_block else final_conv_dim final_conv_dim_in += (channels if lowres_cond else 0) if self.beginning_and_final_conv_present: print (final_conv_dim_in, self.channels_out) self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) if self.beginning_and_final_conv_present: zero_init_(self.final_conv) # if the current settings for the unet are not correct # for cascading DDPM, then reinit the unet with the right settings def cast_model_parameters( self, *, lowres_cond, text_embed_dim, channels, channels_out, cond_on_text ): if lowres_cond == self.lowres_cond and \ channels == self.channels and \ cond_on_text == self.cond_on_text and \ text_embed_dim == self._locals['text_embed_dim'] and \ channels_out == self.channels_out: return self updated_kwargs = dict( lowres_cond = lowres_cond, text_embed_dim = text_embed_dim, channels = channels, channels_out = channels_out, cond_on_text = cond_on_text ) return self.__class__(**{**self._locals, **updated_kwargs}) # methods for returning the full unet config as well as its parameter state def to_config_and_state_dict(self): return self._locals, self.state_dict() # class method for rehydrating the unet from its config and state dict @classmethod def from_config_and_state_dict(klass, config, state_dict): unet = klass(**config) unet.load_state_dict(state_dict) return unet # methods for persisting unet to disk def persist_to_file(self, path): path = Path(path) 
path.parents[0].mkdir(exist_ok = True, parents = True) config, state_dict = self.to_config_and_state_dict() pkg = dict(config = config, state_dict = state_dict) torch.save(pkg, str(path)) # class method for rehydrating the unet from file saved with `persist_to_file` @classmethod def hydrate_from_file(klass, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) assert 'config' in pkg and 'state_dict' in pkg config, state_dict = pkg['config'], pkg['state_dict'] return Unet.from_config_and_state_dict(config, state_dict) # forward with classifier free guidance def forward_with_cond_scale( self, *args, cond_scale = 1., **kwargs ): logits = self.forward(*args, **kwargs) if cond_scale == 1: return logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale # forward fun returns the loss # Here, for Model B: # x : (batch, 1, seq_len) values: from noised image/Y # time : (batch) values: # text_embeds : None # text_mask: None # cond_images: input sequence/X # lowres_noise_times: None # lowres_cond_img: None # cond_drop_prob: 0.1 # self_cond: from noised image/Y def forward( self, x, time, *, lowres_cond_img = None, lowres_noise_times = None, text_embeds = None, text_mask = None, cond_images = None, self_cond = None, cond_drop_prob = 0. 
): # + for debug if self.CKeys['Debug_ModelPack']==1: print("========================================") print("Here are OneD_Unet:forward") batch_size, device = x.shape[0], x.device # condition on self # + for debug if self.CKeys['Debug_ModelPack']==1: print("Check inputs: ") print(" x-0 dim: ", x.shape, "[batch, 1, seq_len]") print(" time dim: ", time.shape, "") print(" cond_images dim: ", cond_images.shape, "") if self.self_cond: self_cond = default(self_cond, lambda: torch.zeros_like(x)) x = torch.cat((x, self_cond), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: print("After self_cond x dim: ", x.shape) # add low resolution conditioning, if present assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) # + for debug if self.CKeys['Debug_ModelPack']==1: print("After lowres_cond_img x dim: ", x.shape) # condition on input image assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' if exists(cond_images): assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' cond_images = resize_image_to(cond_images, x.shape[-1]) # ++++++++++++++++++++++++++++++ # add cond_images (from X) into generation process... 
# --- OneD_Unet.forward (continued): still inside `if exists(cond_images):` ---
            # prepend the conditioning sequence (X) to x along dim 1
            x = torch.cat((cond_images.to(device), x.to(device)), dim = 1)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                print("cond_images dim: ", cond_images.shape, "[batch, 1, max_seq_len]")
                print("After cond_images, x dim: ", x.shape, "[batch, 2, max_seq_len]")

        # initial convolution

        if self.beginning_and_final_conv_present:
            x = self.init_conv(x)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                print("After init_conv, x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]")

        # init conv residual (kept for concatenation just before the final conv)

        if self.init_conv_to_final_conv_residual:
            init_conv_residual = x.clone()

        # time conditioning

        time_hiddens = self.to_time_hiddens(time)
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            print("time dim: ", time.shape, "[batch]")
            print("after, time_hiddens dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]")

        # derive time tokens (for attention) and the time vector t (for resnet blocks)

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            print("time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]")
            print("after to_time_cond t dim: ", t.shape, "[batch, 4xUNet:dim]")

        # add lowres time conditioning to time hiddens
        # and add lowres time tokens along sequence dimension for attention

        if self.lowres_cond:
            lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
            lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
            lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

            t = t + lowres_t
            time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                print("After lowres_cond, time_tokens dim: ", time_tokens.shape)

        # text conditioning

        text_tokens = None

        if exists(text_embeds) and self.cond_on_text:

            # conditional dropout: each batch element keeps its text with prob 1 - cond_drop_prob

            text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

            text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
            text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

            # calculate text embeds (projected only when text_embed_dim != cond_dim)

            if self.text_cond_linear:
                text_tokens = self.text_to_cond(text_embeds)
            else:
                text_tokens=text_embeds

            # truncate to max_text_len, then zero-pad up to it

            text_tokens = text_tokens[:, :self.max_text_len]

            if exists(text_mask):
                text_mask = text_mask[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                # padded / masked positions are treated like dropped text
                text_mask = rearrange(text_mask, 'b n -> b n 1')
                text_keep_mask_embed = text_mask & text_keep_mask_embed

            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

            # dropped samples / masked positions fall back to the learned null text embedding
            text_tokens = torch.where(
                text_keep_mask_embed,
                text_tokens,
                null_text_embed
            )

            if exists(self.attn_pool):
                text_tokens = self.attn_pool(text_tokens)

            # extra non-attention conditioning by projecting and then summing text embeddings to time
            # termed as text hiddens

            mean_pooled_text_tokens = text_tokens.mean(dim = -2)

            text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

            null_text_hidden = self.null_text_hidden.to(t.dtype)

            text_hiddens = torch.where(
                text_keep_mask_hidden,
                text_hiddens,
                null_text_hidden
            )

            t = t + text_hiddens

        # main conditioning tokens (c): time tokens, followed by text tokens when present

        c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            print("Merge time and text tokens, c dim: ", c.shape)

        # normalize conditioning tokens

        c = self.norm_cond(c)

        # initial resnet block (for memory efficient unet)

        if exists(self.init_resnet_block):
            x = self.init_resnet_block(x, t)

        # go through the layers of the unet, down and up
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            print("Before unet, down and up, ")
            print("x dim: ", x.shape)
            print("t dim: ", t.shape)
            print("c dim: ", c.shape)
        ii=0  # step counter; only read by the debug prints

        # skip-connection stack, popped in reverse order on the way up
        hiddens = []

        for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)
# --- OneD_Unet.forward (continued): inside the down-sampling loop ---
                # + for debug
                if self.CKeys['Debug_ModelPack']==1:
                    ii += 1
                    print(F" {str(ii)}, after pre_downsample x dim: ", x.shape)

            x = init_block(x, t, c)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after init_block(x,t,c) x dim: ", x.shape)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t)
                # every resnet output is stored for a skip connection
                hiddens.append(x)
                # + for debug
                if self.CKeys['Debug_ModelPack']==1:
                    ii += 1
                    print(F" {str(ii)}, after resnet_block x dim: ", x.shape)

            x = attn_block(x, c)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after attn_block x dim: ", x.shape)
            hiddens.append(x)

            if exists(post_downsample):
                x = post_downsample(x)
                # + for debug
                if self.CKeys['Debug_ModelPack']==1:
                    ii += 1
                    print(F" {str(ii)}, after post_downsample x dim: ", x.shape)

        # middle blocks

        x = self.mid_block1(x, t, c)
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            ii += 1
            print(F" {str(ii)}, after mid_block_1 x dim: ", x.shape)

        if exists(self.mid_attn):
            x = self.mid_attn(x)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after mid_attn x dim: ", x.shape)

        x = self.mid_block2(x, t, c)
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            ii += 1
            print(F" {str(ii)}, after mid_block_2 x dim: ", x.shape)

        # concatenate the most recent stored hidden (scaled by skip_connect_scale) on dim 1
        add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

        up_hiddens = []

        for init_block, resnet_blocks, attn_block, upsample in self.ups:
            x = add_skip_connection(x)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after add_skip_connection x dim: ", x.shape)
            # NOTE(review): the collapsed source read `# x = init_block(x, t, c)` here;
            # treated as a stray `#` followed by live code - the up-path init_block is
            # built for dim_out + skip_connect_dim input channels, and the skip stack only
            # balances (one pop here, one per resnet block) if this call runs
            x = init_block(x, t, c)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after init_block(x,t,c) x dim: ", x.shape)

            for resnet_block in resnet_blocks:
                x = add_skip_connection(x)
                # + for debug
                if self.CKeys['Debug_ModelPack']==1:
                    ii += 1
                    print(F" {str(ii)}, after add_skip_connection x dim: ", x.shape)
                x = resnet_block(x, t)
                # + for debug
                if self.CKeys['Debug_ModelPack']==1:
                    ii += 1
                    print(F" {str(ii)}, after resnet_block(x,t) x dim: ", x.shape)

            x = attn_block(x, c)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after attn_block(x,c) x dim: ", x.shape)
            up_hiddens.append(x.contiguous())
            x = upsample(x)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after upsample(x) x dim: ", x.shape)

        # whether to combine all feature maps from upsample blocks

        x = self.upsample_combiner(x, up_hiddens)
        # + for debug
        if self.CKeys['Debug_ModelPack']==1:
            ii += 1
            print(F" {str(ii)}, after upsample_combiner(x,..) x dim: ", x.shape)

        # final top-most residual if needed

        if self.init_conv_to_final_conv_residual:
            x = torch.cat((x, init_conv_residual), dim = 1)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after cat_init_conv_resi x dim: ", x.shape)

        if exists(self.final_res_block):
            x = self.final_res_block(x, t)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after final_res_block(x,t) x dim: ", x.shape)

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after cat_x_lowres_cond_img x dim: ", x.shape)

        if self.beginning_and_final_conv_present:
            x=self.final_conv(x)
            # + for debug
            if self.CKeys['Debug_ModelPack']==1:
                ii += 1
                print(F" {str(ii)}, after final_conv(x) x dim: ", x.shape)

        # network output (this method does not compute a loss)
        return x

########################################################
## null unets
########################################################

class NullUnet(nn.Module):
    """Placeholder unet: accepts any arguments, forward returns x unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        # NOTE(review): the attribute name looks like an accidental find/replace of
        # `dummy_parameter` (compare `Unet` below). Renaming it would change the
        # state_dict key, so it is left as is. The parameter presumably exists so
        # the module has at least one parameter (ElucidatedImagen.__init__ calls
        # next(self.unets.parameters())).
        self.dummy_paramcast_model_parameterseter = nn.Parameter(torch.tensor([0.]))

    def cast_model_parameters(self, *args, **kwargs):
        # nothing to reconfigure
        return self

    def forward(self, x, *args, **kwargs):
        return x

class Unet(nn.Module):
    """Second identity placeholder, behaving like NullUnet.

    NOTE(review): this binds the module-level name `Unet`, so any code in this
    module that resolves `Unet` at call time gets this placeholder, which has
    no `from_config_and_state_dict` or real layers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        self.dummy_parameter = nn.Parameter(torch.tensor([0.]))

    def cast_model_parameters(self, *args, **kwargs):
        return self

    def forward(self, x, *args, **kwargs):
        return x

########################################################
## Elucidated denoising model
## After: Tero Karras and Miika Aittala and Timo Aila and Samuli Laine,
## Elucidating the Design Space of Diffusion-Based Generative Models
## https://arxiv.org/abs/2206.00364, 2022
########################################################

from math import sqrt

# per-unet sampling / training hyperparameters, in the order ElucidatedImagen.__init__ packs them
Hparams_fields = [
    'num_sample_steps',
    'sigma_min',
    'sigma_max',
    'sigma_data',
    'rho',
    'P_mean',
    'P_std',
    'S_churn',
    'S_tmin',
    'S_tmax',
    'S_noise'
]

Hparams = namedtuple('Hparams', Hparams_fields)

# helper functions

def log(t, eps = 1e-20):
    """Natural log of tensor `t`, with `t` clamped to at least `eps` so 0 does not give -inf."""
    return torch.log(t.clamp(min = eps))

# main class

class ElucidatedImagen(nn.Module):
    """Cascade of diffusion unets using the Karras et al. (2022) "elucidated"
    preconditioning and stochastic sampler (see header above)."""

    def __init__(
        self,
        unets,
        *,
        image_sizes,                                # for cascading ddpm, image size at each stage
        text_encoder_name = '',                     # can be "DEFAULT_T5_NAME"
        text_embed_dim = None,
        channels = 3,
        channels_out=3,
        cond_drop_prob = 0.1,
        random_crop_sizes = None,
        # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
        lowres_sample_noise_level = 0.2,
        # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
        per_sample_random_aug_noise_level = False,
        condition_on_text = True,
        # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
        auto_normalize_img = True,
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.95,     # unsure what this was based on perusal of paper
        only_train_unet_number = None,
        lowres_noise_schedule = 'linear',
        num_sample_steps = 32,                      # number of sampling steps
        sigma_min = 0.002,                          # min noise level
        sigma_max = 80,                             # max noise level
        sigma_data = 0.5,                           # standard deviation of data distribution
        rho = 7,                                    # controls the sampling schedule
        P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
        P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
        S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
        loss_type=0,                                # 0 = MSE; > 0 enables the categorical (LogSoftmax) path
        categorical_loss_ignore=None,
        # +++++++++++++++++
        # device=None,
        CKeys=None,                                 # control keys; CKeys['Debug_ModelPack']==1 enables debug prints
        PKeys=None,
    ):
        super().__init__()

        # +++++++++++++++++++++++++++
        # self.device=device
        # (the device is exposed through the `device` property instead)
        self.CKeys=CKeys
        self.PKeys=PKeys

        self.only_train_unet_number = only_train_unet_number

        # conditioning

        self.condition_on_text = condition_on_text
        self.unconditional = not condition_on_text
        # loss type: when > 0, sampled outputs are passed through self.m (LogSoftmax over dim 1)
        self.loss_type=loss_type
        if self.loss_type>0:
            self.categorical_loss=True
            self.m = nn.LogSoftmax(dim=1) #used for some loss functions
        else:
            self.categorical_loss=False
        print("Loss type: ", self.loss_type)
        self.categorical_loss_ignore=categorical_loss_ignore

        # channels

        self.channels = channels
        self.channels_out = channels_out

        unets = cast_tuple(unets)
        num_unets = len(unets)

        # randomly cropping for upsampler training

        self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
        assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

        # lowres augmentation noise schedule

        self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

        # get text encoder
        # NOTE(review): no encoder is built here, only the embedding size is stored;
        # `text_encoder_name` is not read anywhere in this part of the file

        self.text_embed_dim =text_embed_dim

        # construct unets

        self.unets = nn.ModuleList([])

        self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
        print (f"Channels in={self.channels}, channels out={self.channels_out}")

        for ind, one_unet in enumerate(unets):
            # # ---------------------------
            # assert isinstance(one_unet, ( OneD_Unet, NullUnet))
            # ++++++++++++++++++++++++++++
            is_first = ind == 0

            # if the current settings for the unet are not correct
            # for cascading DDPM, then reinit the unet with the right settings
            #
            # + for debug
            # print('I am here')
            print("Test on cast_model_parameters...")
            print(not is_first)
            print(self.condition_on_text)
            print(self.text_embed_dim if self.condition_on_text else None)
            print(self.channels)
            print(self.channels_out)

            # only the first unet runs without low-resolution conditioning
            one_unet = one_unet.cast_model_parameters(
                lowres_cond = not is_first,
                cond_on_text = self.condition_on_text,
                text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
                channels = self.channels,
                #channels_out = self.channels
                channels_out = self.channels_out
            )

            self.unets.append(one_unet)

        # determine whether we are training on images or video

        is_video = False # only consider image case
        self.is_video = is_video

        # reshapes a per-sample (b,) tensor so it broadcasts against (b, c, n) data
        self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1' if not is_video else 'b -> b 1 1 1'))

        self.resize_to = resize_video_to if is_video else resize_image_to

        self.image_sizes = image_sizes
        assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'

        self.sample_channels = cast_tuple(self.channels, num_unets)

        # sanity check of the cascade: first unet unconditioned, all later ones lowres-conditioned
        lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
        assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

        self.lowres_sample_noise_level = lowres_sample_noise_level
        self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

        # classifier free guidance

        self.cond_drop_prob = cond_drop_prob
        self.can_classifier_guidance = cond_drop_prob > 0.
# normalize and unnormalize image functions self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity self.input_image_range = (0. if auto_normalize_img else -1., 1.) # dynamic thresholding self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets) self.dynamic_thresholding_percentile = dynamic_thresholding_percentile # # temporal interpolations # temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets) # self.temporal_downsample_factor = temporal_downsample_factor # self.resize_cond_video_frames = resize_cond_video_frames # self.temporal_downsample_divisor = temporal_downsample_factor[0] # assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1' # assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factor must be in order of descending' # elucidating parameters hparams = [ num_sample_steps, sigma_min, sigma_max, sigma_data, rho, P_mean, P_std, S_churn, S_tmin, S_tmax, S_noise, ] hparams = [cast_tuple(hp, num_unets) for hp in hparams] self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)] # one temp parameter for keeping track of device # self.register_buffer('_temp', torch.tensor([0.]).to(device), persistent = False) self.register_buffer('_temp', torch.tensor([0.]), persistent = False) # default to device of unets passed in self.to(next(self.unets.parameters()).device) # for debug print(next(self.unets.parameters()).device) # print ("Device used in ImagenEluc: ", self.device) def force_unconditional_(self): self.condition_on_text = False self.unconditional = True for unet in self.unets: unet.cond_on_text = False @property def device(self): return self._temp.device #return (device) def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 if isinstance(self.unets, nn.ModuleList): 
unets_list = [unet for unet in self.unets] delattr(self, 'unets') self.unets = unets_list if index != self.unet_being_trained_index: for unet_index, unet in enumerate(self.unets): unet.to(self.device if unet_index == index else 'cpu') self.unet_being_trained_index = index return self.unets[index] def reset_unets_all_one_device(self, device = None): device = default(device, self.device) self.unets = nn.ModuleList([*self.unets]) self.unets.to(device) self.unet_being_trained_index = -1 @contextmanager def one_unet_in_gpu(self, unet_number = None, unet = None): assert exists(unet_number) ^ exists(unet) if exists(unet_number): unet = self.unets[unet_number - 1] devices = [module_device(unet) for unet in self.unets] self.unets.cpu() unet.to(self.device) yield for unet, device in zip(self.unets, devices): unet.to(device) # overriding state dict functions def state_dict(self, *args, **kwargs): self.reset_unets_all_one_device() return super().state_dict(*args, **kwargs) def load_state_dict(self, *args, **kwargs): self.reset_unets_all_one_device() return super().load_state_dict(*args, **kwargs) # dynamic thresholding def threshold_x_start(self, x_start, dynamic_threshold = True): if not dynamic_threshold: return x_start.clamp(-1., 1.) s = torch.quantile( rearrange(x_start, 'b ... -> b (...)').abs(), self.dynamic_thresholding_percentile, dim = -1 ) s.clamp_(min = 1.) 
s = right_pad_dims_to(x_start, s) return x_start.clamp(-s, s) / s # derived preconditioning params - Table 1 def c_skip(self, sigma_data, sigma): return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2) def c_out(self, sigma_data, sigma): return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5 def c_in(self, sigma_data, sigma): return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5 def c_noise(self, sigma): return log(sigma) * 0.25 # preconditioned network output # equation (7) in the paper def preconditioned_network_forward( self, unet_forward, noised_images, sigma, *, sigma_data, clamp = False, dynamic_threshold = True, **kwargs ): # + for debug if self.CKeys['Debug_ModelPack']==1: print("========================================") print("ElucidatedImagen: preconditioned_network_forward") batch, device = noised_images.shape[0], noised_images.device if isinstance(sigma, float): sigma = torch.full((batch,), sigma, device = device) padded_sigma = self.right_pad_dims_to_datatype(sigma) # + for debug if self.CKeys['Debug_ModelPack']==1: print("unet_forward: ") print("arg_0: ", noised_images.shape) print("arg_1: ", sigma.shape) print("other arg: ", kwargs) net_out = unet_forward( self.c_in(sigma_data, padded_sigma) * noised_images, self.c_noise(sigma), # "{text_embeds,text_mask,cond_images,...} are in kwargs" **kwargs ) out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out if not clamp: return out return self.threshold_x_start(out, dynamic_threshold) # sampling # sample schedule # equation (5) in the paper def sample_schedule( self, num_sample_steps, rho, sigma_min, sigma_max ): N = num_sample_steps inv_rho = 1 / rho steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 
return sigmas @torch.no_grad() def one_unet_sample( self, unet, shape, *, unet_number, clamp = True, dynamic_threshold = True, cond_scale = 1., use_tqdm = True, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, skip_steps = None, sigma_min = None, sigma_max = None, **kwargs ): # get specific sampling hyperparameters for unet hp = self.hparams[unet_number - 1] sigma_min = default(sigma_min, hp.sigma_min) sigma_max = default(sigma_max, hp.sigma_max) # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max) gammas = torch.where( (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax), min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1), 0. ) sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) # images is noise at the beginning init_sigma = sigmas[0] images = init_sigma * torch.randn(shape, device = self.device) # initializing with an image if exists(init_images): images += init_images # keeping track of x0, for self conditioning if needed x_start = None # prepare inpainting images and mask has_inpainting = exists(inpaint_images) and exists(inpaint_masks) resample_times = inpaint_resample_times if has_inpainting else 1 if has_inpainting: inpaint_images = self.normalize_img(inpaint_images) inpaint_images = self.resize_to(inpaint_images, shape[-1]) inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... 
-> b 1 ...').float(), shape[-1]).bool() # unet kwargs unet_kwargs = dict( sigma_data = hp.sigma_data, clamp = clamp, dynamic_threshold = dynamic_threshold, cond_scale = cond_scale, **kwargs ) # gradually denoise initial_step = default(skip_steps, 0) sigmas_and_gammas = sigmas_and_gammas[initial_step:] total_steps = len(sigmas_and_gammas) for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm): is_last_timestep = ind == (total_steps - 1) sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) for r in reversed(range(resample_times)): is_last_resample_step = r == 0 eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling sigma_hat = sigma + gamma * sigma added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps images_hat = images + added_noise self_cond = x_start if unet.self_cond else None if has_inpainting: images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks model_output = self.preconditioned_network_forward( unet.forward_with_cond_scale, images_hat, sigma_hat, self_cond = self_cond, **unet_kwargs ) denoised_over_sigma = (images_hat - model_output) / sigma_hat images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma # second order correction, if not the last timestep if sigma_next != 0: self_cond = model_output if unet.self_cond else None model_output_next = self.preconditioned_network_forward( unet.forward_with_cond_scale, images_next, sigma_next, self_cond = self_cond, **unet_kwargs ) denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) images = images_next if has_inpainting and not (is_last_resample_step or is_last_timestep): # renoise in repaint and then resample repaint_noise = torch.randn(shape, device = self.device) images = images + (sigma 
- sigma_next) * repaint_noise x_start = model_output # save model output for self conditioning if has_inpainting: images = images * ~inpaint_masks + inpaint_images * inpaint_masks return images @torch.no_grad() @eval_decorator def sample( self, texts: List[str] = None, text_masks = None, text_embeds = None, cond_images = None, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, skip_steps = None, sigma_min = None, sigma_max = None, video_frames = None, batch_size = 1, cond_scale = 1., lowres_sample_noise_level = None, start_at_unet_number = 1, start_image_or_video = None, stop_at_unet_number = None, return_all_unet_outputs = False, return_pil_images = False, use_tqdm = True, device = None, ): # + device = default(device, self.device) self.reset_unets_all_one_device(device = device) cond_images = maybe(cast_uint8_images_to_float)(cond_images) if exists(texts) and not exists(text_embeds) and not self.unconditional: assert all([*map(len, texts)]), 'text cannot be empty' with autocast(enabled = False): text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks)) if not self.unconditional: assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training' text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) batch_size = text_embeds.shape[0] if exists(inpaint_images): if self.unconditional: if batch_size == 1: # assume researcher wants to broadcast along inpainted images batch_size = inpaint_images.shape[0] assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=)``' assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned 
on' assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting' outputs = [] is_cuda = next(self.parameters()).is_cuda device = next(self.parameters()).device lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level) num_unets = len(self.unets) cond_scale = cast_tuple(cond_scale, num_unets) # handle video and frame dimension assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video' frame_dims = (video_frames,) if self.is_video else tuple() # initializing with an image or video init_images = cast_tuple(init_images, num_unets) init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images] skip_steps = cast_tuple(skip_steps, num_unets) sigma_min = cast_tuple(sigma_min, num_unets) sigma_max = cast_tuple(sigma_max, num_unets) # handle starting at a unet greater than 1, for training only-upscaler training if start_at_unet_number > 1: assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets' assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling' prev_image_size = self.image_sizes[start_at_unet_number - 2] img = self.resize_to(start_image_or_video, prev_image_size) for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, 
unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm): if unet_number < start_at_unet_number: continue assert not isinstance(unet, NullUnet), 'cannot sample from null unet' context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext() with context: lowres_cond_img = lowres_noise_times = None shape = (batch_size, channel, *frame_dims, image_size ) # low resolution conditioning if unet.lowres_cond: lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device) lowres_cond_img = self.resize_to(img, image_size) lowres_cond_img = self.normalize_img(lowres_cond_img.float()) lowres_cond_img, _ = self.lowres_noise_schedule.q_sample( x_start = lowres_cond_img.float(), t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img.float()) ) if exists(unet_init_images): unet_init_images = self.resize_to(unet_init_images, image_size) shape = (batch_size, self.channels, *frame_dims, image_size) img = self.one_unet_sample( unet, shape, unet_number = unet_number, text_embeds = text_embeds, text_mask =text_masks, cond_images = cond_images, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = unet_init_images, skip_steps = unet_skip_steps, sigma_min = unet_sigma_min, sigma_max = unet_sigma_max, cond_scale = unet_cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_times = lowres_noise_times, dynamic_threshold = dynamic_threshold, use_tqdm = use_tqdm ) if self.categorical_loss: img=self.m(img) outputs.append(img) if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: break output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs if not 
return_all_unet_outputs: outputs = outputs[-1:] assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet' if self.categorical_loss: return torch.argmax(outputs[output_index], dim=1).unsqueeze (1) else: return outputs[output_index] # training def loss_weight(self, sigma_data, sigma): return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2 def noise_distribution(self, P_mean, P_std, batch_size): # torch.randn: normal distribution N(0, 1) return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp() def forward( self, images, unet: Union[ NullUnet, DistributedDataParallel] = None, texts: List[str] = None, text_embeds = None, text_masks = None, unet_number = None, cond_images = None, ): assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}' # + for debug if self.CKeys['Debug_ModelPack']==1: if cond_images!=None: print("cond_images type: ", cond_images.dtype) else: print("cond_images type: None") cond_images = maybe(cast_uint8_images_to_float)(cond_images) # + for debug if self.CKeys['Debug_ModelPack']==1: if cond_images!=None: print("cond_images type: ", cond_images.dtype) else: print("cond_images type: None") if self.categorical_loss==False: assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead' unet_index = unet_number - 1 unet = default(unet, lambda: self.get_unet(unet_number)) assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained' target_image_size = self.image_sizes[unet_index] random_crop_size = self.random_crop_sizes[unet_index] prev_image_size = 
self.image_sizes[unet_index - 1] if unet_index > 0 else None hp = self.hparams[unet_index] # + for debug if self.CKeys['Debug_ModelPack']==1: print("target_image_size: ", target_image_size) print("prev_image_size: ", prev_image_size) print("random_crop_size: ", random_crop_size) batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4) frames = images.shape[2] if is_video else None # + for debug if self.CKeys['Debug_ModelPack']==1: print("frames: ", frames) check_shape(images, 'b c ...', c = self.channels) assert h >= target_image_size if exists(texts) and not exists(text_embeds) and not self.unconditional: assert all([*map(len, texts)]), 'text cannot be empty' assert len(texts) == len(images), 'number of text captions does not match up with the number of images given' with autocast(enabled = False): text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks)) if not self.unconditional: text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' lowres_cond_img = lowres_aug_times = None if exists(prev_image_size): lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range) lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range) if self.per_sample_random_aug_noise_level: lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device) else: # return a random time: 
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device) # extend it to batch_size lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size) if exists(random_crop_size): aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) if is_video: images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h -> (b f) c h') images = aug(images) lowres_cond_img = aug(lowres_cond_img, params = aug._params) if is_video: images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h -> b c f h', f = frames) lowres_cond_img_noisy = None if exists(lowres_cond_img): lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample( x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img.float()) ) # get the sigmas sigmas = self.noise_distribution( hp.P_mean, hp.P_std, batch_size ).to(device) padded_sigmas = self.right_pad_dims_to_datatype(sigmas).to(device) # + for debug if self.CKeys['Debug_ModelPack']==1: print('sigmas dim: ', sigmas.shape, 'should = batch_size') print('sigmas[0..3]: ', sigmas[:3]) print('padded_sigmas dim: ', padded_sigmas.shape) # noise noise = torch.randn_like(images.float()).to(device) noised_images = images + padded_sigmas * noise # alphas are 1. 
in the paper # unet kwargs unet_kwargs = dict( sigma_data = hp.sigma_data, text_embeds = text_embeds, text_mask =text_masks, cond_images = cond_images, lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times), lowres_cond_img = lowres_cond_img_noisy, cond_drop_prob = self.cond_drop_prob, ) # + for debug if self.CKeys['Debug_ModelPack']==1: print("lowres_noise_times: ", unet_kwargs['lowres_noise_times']) print("lowres_cond_img_noisy: ", unet_kwargs['lowres_cond_img']) print("unet_kwargs: \n", unet_kwargs) # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower # Because 'unet' can be an instance of DistributedDataParallel coming from the # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to # access the member 'module' of the wrapped unet instance. self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet # # + for debug # if self.CKeys['Debug_ModelPack']==1: # # this will be the unet # print("self_cond: ", self_cond) if self_cond and random() < 0.5: # + for debug if self.CKeys['Debug_ModelPack']==1: print("=========added self_cond.============") with torch.no_grad(): pred_x0 = self.preconditioned_network_forward( unet.forward, noised_images, sigmas, **unet_kwargs ).detach() unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0} # + for debug if self.CKeys['Debug_ModelPack']==1: print("after self-condition, random < 0.5 .......") print("unet_kwargs: \n", unet_kwargs) # get prediction denoised_images = self.preconditioned_network_forward( unet.forward, noised_images, sigmas, **unet_kwargs ) # losses if self.loss_type==0: losses = F.mse_loss(denoised_images, images, reduction = 'none') losses = reduce(losses, 'b ... 
-> b', 'mean') # loss weighting losses = losses * self.loss_weight(hp.sigma_data, sigmas) losses=losses.mean() return losses # =========================================================== # final models # explict define unets class ProteinDesigner_B(nn.Module): def __init__(self, unet, # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx CKeys=None, PKeys=None, ): super(ProteinDesigner_B, self).__init__() # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # unload the parameters timesteps =default(PKeys['timesteps'], 10) # 10 , dim =default(PKeys['dim'], 32) # 32, pred_dim =default(PKeys['pred_dim'], 25) # 25, loss_type =default(PKeys['loss_type'], 0) # 0, # MSE elucidated =default(PKeys['elucidated'], True) # True, padding_idx =default(PKeys['padding_idx'], 0) # 0, cond_dim =default(PKeys['cond_dim'], 512) # 512, text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, input_tokens =default(PKeys['input_tokens'], 25) #for non-BERT sequence_embed =default(PKeys['sequence_embed'], False) embed_dim_position =default(PKeys['embed_dim_position'], 32) max_text_len =default(PKeys['max_text_len'], 16) cond_images_channels=default(PKeys['cond_images_channels'], 0) # ++++++++++++++++++++++ max_length =default(PKeys['max_length'], 64) device =default(PKeys['device'], None) # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< print ("Model B: Generative protein diffusion model, residue-based") print ("Using condition as the initial sequence") self.pred_dim=pred_dim self.loss_type=loss_type # +++++++++++++++++ for debug self.CKeys=CKeys self.PKeys=PKeys self.max_length = max_length assert loss_type == 0, "Losses other than MSE not implemented" self.fc_embed1 = nn.Linear( 8, max_length) # NOT used # INPUT DIM (last), OUTPUT DIM, last self.fc_embed2 = nn.Linear( 1, text_embed_dim) # USED # INPUT DIM (last), OUTPUT DIM, last self.max_text_len=max_text_len self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) text_embed_dim=text_embed_dim+embed_dim_position 
self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) for i in range (max_text_len): self.pos_matrix_i [i]=i +1 condition_on_text=True self.cond_images_channels=cond_images_channels if self.cond_images_channels>0: condition_on_text = False if self.cond_images_channels>0: print ("Use conditioning image during training....") assert elucidated , "Only elucidated model implemented...." self.is_elucidated=elucidated if elucidated: self.imagen = ElucidatedImagen( unets = (unet), channels=self.pred_dim, channels_out=self.pred_dim , loss_type=loss_type, condition_on_text = condition_on_text, text_embed_dim = text_embed_dim, image_sizes = ( [max_length ]), cond_drop_prob = 0.1, auto_normalize_img = False, num_sample_steps = timesteps, # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) sigma_min = 0.002, # min noise level sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler sigma_data = 0.5, # standard deviation of data distribution rho = 7, # controls the sampling schedule P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper S_tmin = 0.05, S_tmax = 50, S_noise = 1.003, # ++++++++++++++++ # device=device, CKeys=self.CKeys, PKeys=self.PKeys, ).to(device) else: print ("Not implemented.") def forward(self, output, x=None, cond_images = None, unet_number=1, ): # ddddddddddddddddddddddddddddddddddddddddddddd if self.CKeys['Debug_ModelPack']==1: print("on Model B:forward") print("inputs:") print(' output: ', output.shape) print(' x: ', x) print(' cond_img: ', cond_images.shape) print(' unet_num: ', unet_number) if x != None: x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device) x_in[:,:x.shape[1]]=x x=x_in 
x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) x= torch.cat( (x, pos_emb_x ), 2) if cond_images!=None: this_cond_images=cond_images.to(device) else: this_cond_images=cond_images if self.CKeys['Debug_ModelPack']==1: print('x with pos_emb_x: ', x) print("into self.imagen...") loss = self.imagen( output, text_embeds = x, # cond_images=cond_images.to(device), cond_images=this_cond_images, unet_number = unet_number, ) return loss def sample ( self, x=None, stop_at_unet_number=1 , cond_scale=7.5, x_data=None, skip_steps=None, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, x_data_tokenized=None, device=None, # +++++++++++++++++++++++++ tokenizer_X=None, Xnormfac=1., max_length=1., ): batch_size=1 if x_data != None: print ("Conditioning target sequence provided via ori x_data ...", x_data) # this is for Model B with SecStr as input # # # -- ver 0.0: channel = 1, all padding at the end: content+000 # x_data = tokenizer_X.texts_to_sequences(x_data) # x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') # # x_data= torch.from_numpy(x_data).float().to(device) # x_data = x_data/Xnormfac # x_data=x_data.unsqueeze (2) # x_data=torch.permute(x_data, (0,2,1) ) # # ++ ver 1.0: channel = self.pred_dim>1, esm padding, 0+content+000 x_data = tokenizer_X.texts_to_sequences(x_data) # padding: 0+content+00 x_data_0 = [] for this_x_data in x_data: x_data_0.append([0]+this_x_data) # to be checked whether this works x_data = sequence.pad_sequences( x_data_0, maxlen=max_length, padding='post', truncating='post', ) # normalization: dummy, for future x_data = torch.from_numpy(x_data).float().to(device) x_data = x_data/Xnormfac # adjust the channel:NOTE, here we use "copy" to fill-in the channels # May need to adjust this part for future models 
x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1) print ("After channel expansion, x_data from target sequence=", x_data, x_data.shape) batch_size=x_data.shape[0] if x_data_tokenized != None: # this is for Model B with ForcePath vector as input # everything is padded as what the dataloader has # # if x_data_tokenized.any() != None: print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) # # ++ ver 1.0 # if self.pred_dim==1: # # old case: channel=1 # x_data=x_data_tokenized.unsqueeze (2) # x_data=torch.permute(x_data, (0,2,1) ).to(device) # else: # # new case: include channels in x_data_tokenized # x_data=x_data_tokenized.to(device) # # # -- ver 0.0: original # x_data=x_data_tokenized.unsqueeze (2) # x_data=torch.permute(x_data, (0,2,1) ).to(device) # # ++ ver 2.0: input x_data_tokenized.dim= (batch, seq_len) x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) # print ("x_data.dim provided from x_data_tokenized: ", x_data.shape) batch_size=x_data.shape[0] if init_images != None: print ("Init sequence provided...", init_images) init_images = tokenizer_y.texts_to_sequences(init_images) init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') init_images= torch.from_numpy(init_images).float().to(device)/ynormfac print ("init_images=", init_images) if inpaint_images != None: print ("Inpaint sequence provided...", inpaint_images) print ("Mask: ", inpaint_masks) inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac print ("in_paint images=", inpaint_images) if x !=None: x_in=torch.zeros( (x.shape[0],max_length) ).to(device) x_in[:,:x.shape[1]]=x x=x_in x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 
1).to(device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) x= torch.cat( (x, pos_emb_x ), 2) batch_size=x.shape[0] output = self.imagen.sample( text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number, cond_images=x_data, skip_steps=skip_steps, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = init_images, batch_size=batch_size, device=device, ) return output # =========================================================== # final models # explict define unets class ProteinPredictor_B(nn.Module): def __init__(self, unet, # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx CKeys=None, PKeys=None, ): super(ProteinPredictor_B, self).__init__() # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # unload the parameters timesteps =default(PKeys['timesteps'], 10) # 10 , dim =default(PKeys['dim'], 32) # 32, pred_dim =default(PKeys['pred_dim'], 25) # 25, loss_type =default(PKeys['loss_type'], 0) # 0, # MSE elucidated =default(PKeys['elucidated'], True) # True, padding_idx =default(PKeys['padding_idx'], 0) # 0, cond_dim =default(PKeys['cond_dim'], 512) # 512, text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, input_tokens =default(PKeys['input_tokens'], 25) #for non-BERT sequence_embed =default(PKeys['sequence_embed'], False) embed_dim_position =default(PKeys['embed_dim_position'], 32) max_text_len =default(PKeys['max_text_len'], 16) cond_images_channels=default(PKeys['cond_images_channels'], 0) # ++++++++++++++++++++++ max_length =default(PKeys['max_length'], 64) device =default(PKeys['device'], None) # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< print ("Model B: Predictive protein diffusion model, residue-based") print ("Using condition as the initial sequence") self.pred_dim=pred_dim self.loss_type=loss_type # +++++++++++++++++ for debug self.CKeys=CKeys self.PKeys=PKeys self.max_length = max_length assert loss_type == 
0, "Losses other than MSE not implemented" self.fc_embed1 = nn.Linear( 8, max_length) # NOT used # INPUT DIM (last), OUTPUT DIM, last self.fc_embed2 = nn.Linear( 1, text_embed_dim) # USED # INPUT DIM (last), OUTPUT DIM, last self.max_text_len=max_text_len self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) text_embed_dim=text_embed_dim+embed_dim_position self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) for i in range (max_text_len): self.pos_matrix_i [i]=i +1 condition_on_text=True self.cond_images_channels=cond_images_channels if self.cond_images_channels>0: condition_on_text = False if self.cond_images_channels>0: print ("Use conditioning image during training....") assert elucidated , "Only elucidated model implemented...." self.is_elucidated=elucidated if elucidated: self.imagen = ElucidatedImagen( unets = (unet), channels=self.pred_dim, channels_out=self.pred_dim , loss_type=loss_type, condition_on_text = condition_on_text, text_embed_dim = text_embed_dim, image_sizes = ( [max_length ]), cond_drop_prob = 0.1, auto_normalize_img = False, num_sample_steps = timesteps, # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) sigma_min = 0.002, # min noise level sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler sigma_data = 0.5, # standard deviation of data distribution rho = 7, # controls the sampling schedule P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper S_tmin = 0.05, S_tmax = 50, S_noise = 1.003, # ++++++++++++++++ # device=device, CKeys=self.CKeys, PKeys=self.PKeys, ).to(device) else: print ("Not implemented.") def forward(self, output, x=None, cond_images = None, 
unet_number=1, ): # ddddddddddddddddddddddddddddddddddddddddddddd if self.CKeys['Debug_ModelPack']==1: print("on Model B:forward") print("inputs:") print(' output: ', output.shape) print(' x: ', x) print(' cond_img: ', cond_images.shape) print(' unet_num: ', unet_number) if x != None: x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device) x_in[:,:x.shape[1]]=x x=x_in x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) x= torch.cat( (x, pos_emb_x ), 2) if cond_images!=None: this_cond_images=cond_images.to(device) else: this_cond_images=cond_images if self.CKeys['Debug_ModelPack']==1: print('x with pos_emb_x: ', x) print("into self.imagen...") loss = self.imagen( output, text_embeds = x, # cond_images=cond_images.to(device), cond_images=this_cond_images, unet_number = unet_number, ) return loss def sample ( self, x=None, stop_at_unet_number=1 , cond_scale=7.5, x_data=None, skip_steps=None, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, x_data_tokenized=None, device=None, # +++++++++++++++++++++++++ tokenizer_X=None, Xnormfac=1., max_length=1., # ++ pLM_Model_Name=None, pLM_Model=None, pLM_alphabet=None, esm_layer=None, ): batch_size=1 if x_data != None: print ("Conditioning target sequence provided via ori x_data ...", x_data) print(f"use pLM model {pLM_Model_Name}") # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # prepare x_data into the cond_img channel (batch, channel, seq_len) # for ProteinPredictor: AA seq -> FOrcPath # need to distinguish trivial and ESM series # mimic the block in DataSetPack if pLM_Model_Name=='trivial': # no ESM but direct tokenizer is used here # NEED to pad the 0th position x_data = tokenizer_X.texts_to_sequences(x_data) # two-step padding x_data = sequence.pad_sequences( x_data, maxlen=max_length-1, padding='post', truncating='post', 
value=0.0, ) # add one 0 at the begining x_data = sequence.pad_sequences( x_data, maxlen=max_length, padding='pre', truncating='pre', value=0.0, ) # (batch, seq_len) x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) # (batch, 1, seq_len) x_data = x_data/Xnormfac else: # ++ for esm print("pLM Model: ", pLM_Model_Name) # 1. from AA string to token # need to save 2 positions for and esm_batch_converter = pLM_alphabet.get_batch_converter( truncation_seq_length=max_length-2 ) # prepare seqs for the "esm_batch_converter..." # add dummy labels seqs_ext=[] for i in range(len(x_data)): seqs_ext.append( (" ", x_data[i]) ) # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext) _, x_strs, x_data = esm_batch_converter(seqs_ext) x_strs_lens = (x_data != pLM_alphabet.padding_idx).sum(1) # # NEED to check the size of y_data # need to dealwith if y_data are only shorter sequences # need to add padding with a value, int (1) current_seq_len = x_data.shape[1] print("current seq batch len: ", current_seq_len) missing_num_pad = max_length-current_seq_len if missing_num_pad>0: print("extra padding is added to match the target seq input length...") # padding is needed x_data = F.pad( x_data, (0, missing_num_pad), "constant", pLM_alphabet.padding_idx ) else: print("No extra padding is needed") x_data = x_data.to(device) # # 2. 
from token to embedding with torch.no_grad(): results = pLM_Model( x_data, repr_layers=[esm_layer], return_contacts=False, ) x_data=results["representations"][esm_layer] # (batch, seq_len, channels) x_data=rearrange( x_data, 'b l c -> b c l' ) # print(batch_tokens.shape) print ("x_data.dim: ", x_data.shape) print ("x_data.type: ", x_data.type) batch_size=x_data.shape[0] # # --------------------------------------------------------- # # this is for Model B with SecStr as input # # # # # -- ver 0.0: channel = 1, all padding at the end: content+000 # # x_data = tokenizer_X.texts_to_sequences(x_data) # # x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') # # # # x_data= torch.from_numpy(x_data).float().to(device) # # x_data = x_data/Xnormfac # # x_data=x_data.unsqueeze (2) # # x_data=torch.permute(x_data, (0,2,1) ) # # # # ++ ver 1.0: channel = self.pred_dim>1, esm padding, 0+content+000 # x_data = tokenizer_X.texts_to_sequences(x_data) # # padding: 0+content+00 # x_data_0 = [] # for this_x_data in x_data: # x_data_0.append([0]+this_x_data) # to be checked whether this works # x_data = sequence.pad_sequences( # x_data_0, maxlen=max_length, # padding='post', truncating='post', # ) # # normalization: dummy, for future # x_data = torch.from_numpy(x_data).float().to(device) # x_data = x_data/Xnormfac # # adjust the channel:NOTE, here we use "copy" to fill-in the channels # # May need to adjust this part for future models # x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1) # print ("After channel expansion, x_data from target sequence=", x_data, x_data.shape) # batch_size=x_data.shape[0] if x_data_tokenized != None: # ++ # this is for Model B with AA tokens as input # task: x_data_tokenized (batch, seq_len) -> x_data (batch, channel, seq_len) print ( "Conditioning target output via provided AA tokens sequence...", x_data_tokenized, x_data_tokenized.shape, ) # transfer tokens into embedding and expand the channels if 
pLM_Model_Name=='trivial': # self.pred_dim should be 1 x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) else: # esm models with torch.no_grad(): results = pLM_Model( x_data_tokenized, repr_layers=[esm_layer], return_contacts=False, ) x_data=results["representations"][esm_layer] # (batch, seq_len, channels) x_data=rearrange( x_data, 'b l c -> b c l' ) x_data = x_data.to(device) batch_size=x_data.shape[0] print ("x_data.dim provided from x_data_tokenized: ", x_data.shape) # # -- # # this is for Model B with ForcePath vector as input # # everything is padded as what the dataloader has # # # # if x_data_tokenized.any() != None: # print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) # # # ++ ver 1.0 # # if self.pred_dim==1: # # # old case: channel=1 # # x_data=x_data_tokenized.unsqueeze (2) # # x_data=torch.permute(x_data, (0,2,1) ).to(device) # # else: # # # new case: include channels in x_data_tokenized # # x_data=x_data_tokenized.to(device) # # # # # -- ver 0.0: original # # x_data=x_data_tokenized.unsqueeze (2) # # x_data=torch.permute(x_data, (0,2,1) ).to(device) # # # # ++ ver 2.0: input x_data_tokenized.dim= (batch, seq_len) # x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device) # # # print ("x_data.dim provided from x_data_tokenized: ", x_data.shape) # batch_size=x_data.shape[0] if init_images != None: print ("Init sequence provided...", init_images) init_images = tokenizer_y.texts_to_sequences(init_images) init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') init_images= torch.from_numpy(init_images).float().to(device)/ynormfac print ("init_images=", init_images) if inpaint_images != None: print ("Inpaint sequence provided...", inpaint_images) print ("Mask: ", inpaint_masks) inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) inpaint_images= sequence.pad_sequences(inpaint_images, 
maxlen=max_length, padding='post', truncating='post') inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac print ("in_paint images=", inpaint_images) if x !=None: x_in=torch.zeros( (x.shape[0],max_length) ).to(device) x_in[:,:x.shape[1]]=x x=x_in x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) x= torch.cat( (x, pos_emb_x ), 2) batch_size=x.shape[0] output = self.imagen.sample( text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number, cond_images=x_data, skip_steps=skip_steps, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = init_images, batch_size=batch_size, device=device, ) return output # =========================================================== # final models # explict define unets class ProteinDesigner_B_Old(nn.Module): def __init__(self, unet, timesteps=10 , dim=32, pred_dim=25, loss_type=0, # MSE elucidated=True, padding_idx=0, cond_dim = 512, text_embed_dim = 512, input_tokens=25,#for non-BERT sequence_embed=False, embed_dim_position=32, max_text_len=16, cond_images_channels=0, # ++++++++++++++++++++++ max_length=1, device=None, CKeys=None, PKeys=None, ): super(ProteinDesigner_B_Old, self).__init__() print ("Model B: Generative protein diffusion model, residue-based") print ("Using condition as the initial sequence") self.pred_dim=pred_dim self.loss_type=loss_type # +++++++++++++++++ for debug self.CKeys=CKeys self.PKeys=PKeys self.max_length = max_length assert loss_type == 0, "Losses other than MSE not implemented" self.fc_embed1 = nn.Linear( 8, max_length) # NOT used # INPUT DIM (last), OUTPUT DIM, last self.fc_embed2 = nn.Linear( 1, text_embed_dim) # USED # INPUT DIM (last), OUTPUT DIM, last self.max_text_len=max_text_len self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) 
text_embed_dim=text_embed_dim+embed_dim_position self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) for i in range (max_text_len): self.pos_matrix_i [i]=i +1 condition_on_text=True self.cond_images_channels=cond_images_channels if self.cond_images_channels>0: condition_on_text = False if self.cond_images_channels>0: print ("Use conditioning image during training....") assert elucidated , "Only elucidated model implemented...." self.is_elucidated=elucidated if elucidated: self.imagen = ElucidatedImagen( unets = (unet), channels=self.pred_dim, channels_out=self.pred_dim , loss_type=loss_type, condition_on_text = condition_on_text, text_embed_dim = text_embed_dim, image_sizes = ( [max_length ]), cond_drop_prob = 0.1, auto_normalize_img = False, num_sample_steps = timesteps, # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) sigma_min = 0.002, # min noise level sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler sigma_data = 0.5, # standard deviation of data distribution rho = 7, # controls the sampling schedule P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper S_tmin = 0.05, S_tmax = 50, S_noise = 1.003, # ++++++++++++++++ # device=device, CKeys=self.CKeys, PKeys=self.PKeys, ).to(device) else: print ("Not implemented.") def forward(self, output, x=None, cond_images = None, unet_number=1, ): # ddddddddddddddddddddddddddddddddddddddddddddd if self.CKeys['Debug_ModelPack']==1: print('output: ', output.shape) print('x: ', x) print('cond_img: ', cond_images) print('unet_num: ', unet_number) if x != None: x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device) x_in[:,:x.shape[1]]=x x=x_in 
x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) x= torch.cat( (x, pos_emb_x ), 2) if cond_images!=None: this_cond_images=cond_images.to(device) else: this_cond_images=cond_images if self.CKeys['Debug_ModelPack']==1: print('x with pos_emb_x: ', x) loss = self.imagen( output, text_embeds = x, # cond_images=cond_images.to(device), cond_images=this_cond_images, unet_number = unet_number, ) return loss def sample ( self, x=None, stop_at_unet_number=1 , cond_scale=7.5, x_data=None, skip_steps=None, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, x_data_tokenized=None, device=None, # +++++++++++++++++++++++++ tokenizer_X=None, Xnormfac=1., ynormfac=1., max_length=1., ): batch_size=1 if x_data != None: print ("Conditioning target sequence provided via x_data ...", x_data) x_data = tokenizer_X.texts_to_sequences(x_data) x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') x_data= torch.from_numpy(x_data).float().to(device) x_data = x_data/Xnormfac x_data=x_data.unsqueeze (2) x_data=torch.permute(x_data, (0,2,1) ) print ("x_data from target sequence=", x_data, x_data.shape) batch_size=x_data.shape[0] if x_data_tokenized != None: print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) x_data=x_data_tokenized.unsqueeze (2) x_data=torch.permute(x_data, (0,2,1) ).to(device) print ("Data provided from x_data_tokenized: ", x_data.shape) batch_size=x_data.shape[0] if init_images != None: print ("Init sequence provided...", init_images) init_images = tokenizer_y.texts_to_sequences(init_images) init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') init_images= torch.from_numpy(init_images).float().to(device)/ynormfac print ("init_images=", init_images) if 
inpaint_images != None: print ("Inpaint sequence provided...", inpaint_images) print ("Mask: ", inpaint_masks) inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac print ("in_paint images=", inpaint_images) if x !=None: x_in=torch.zeros( (x.shape[0],max_length) ).to(device) x_in[:,:x.shape[1]]=x x=x_in x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) x= torch.cat( (x, pos_emb_x ), 2) batch_size=x.shape[0] output = self.imagen.sample( text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number, cond_images=x_data, skip_steps=skip_steps, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = init_images, batch_size=batch_size, device=device, ) return output # =========================================================== # final models: old model # Changes: 1. 
put the UNet part out # class ProteinDesigner_A_II(nn.Module): def __init__( self, unet1, # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx CKeys=None, PKeys=None, ): super(ProteinDesigner_A_II, self).__init__() # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # unload the arguments timesteps =default(PKeys['timesteps'], 10) # 10 dim =default(PKeys['dim'], 32) # 32, pred_dim =default(PKeys['pred_dim'], 25) # 25, loss_type =default(PKeys['loss_type'], 0) # 0, elucidated =default(PKeys['elucidated'], True) # False, padding_idx =default(PKeys['padding_idx'], 0) # 0, cond_dim =default(PKeys['cond_dim'], 512) # 512, text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, input_tokens =default(PKeys['input_tokens'], 25) # 25,#for non-BERT sequence_embed =default(PKeys['sequence_embed'], False) # False, embed_dim_position =default(PKeys['embed_dim_position'], 32) # 32, max_text_len =default(PKeys['max_text_len'], 16) # 16, cond_images_channels=default(PKeys['cond_images_channels'], 0) # ++ max_length =default(PKeys['max_length'], 64) # 64, device =default(PKeys['device'], 'cuda:0') # 'cuda:0', # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< self.pred_dim=pred_dim self.loss_type=loss_type # ++ self.CKeys=CKeys self.PKeys=PKeys self.device=device assert (loss_type==0), "Loss other than MSE not implemented" self.fc_embed1 = nn.Linear( 8, max_length) # NOT USED self.fc_embed2 = nn.Linear( 1, text_embed_dim) # project the text into embedding dimensions self.max_text_len=max_text_len self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) text_embed_dim=text_embed_dim+embed_dim_position self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) for i in range (max_text_len): self.pos_matrix_i [i]=i +1 if self.CKeys['Debug_ModelPack']==1: print("ModelA.pos_matrix_i: ", self.pos_matrix_i) # # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # # 
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # # prepare the Unet Key # write_PK_UNet=dict() # write_PK_UNet['dim']=dim # write_PK_UNet['text_embed_dim']=text_embed_dim # write_PK_UNet['cond_dim']=cond_dim # write_PK_UNet['dim_mults']=(1, 2, 4, 8) # write_PK_UNet['num_resnet_blocks']=1 # write_PK_UNet['layer_attns']=(False, True, True, False) # write_PK_UNet['layer_cross_attns']=(False, True, True, False) # write_PK_UNet['channels']=pred_dim # write_PK_UNet['channels_out']=pred_dim # write_PK_UNet['attn_dim_head']=64 # write_PK_UNet['attn_heads']=8 # write_PK_UNet['ff_mult']=2. # write_PK_UNet['lowres_cond']=False # for cascading diffusion - https://cascaded-diffusion.github.io/ # write_PK_UNet['layer_attns_depth']=1 # write_PK_UNet['layer_attns_add_text_cond']=True # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 # write_PK_UNet['attend_at_middle']=True # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) # write_PK_UNet['use_linear_attn']=False # write_PK_UNet['use_linear_cross_attn']=False # write_PK_UNet['cond_on_text'] = True # write_PK_UNet['max_text_len'] = max_length # write_PK_UNet['init_dim'] = None # write_PK_UNet['resnet_groups'] = 8 # write_PK_UNet['init_conv_kernel_size'] =7 # kernel size of initial conv, if not using cross embed # write_PK_UNet['init_cross_embed'] = False #TODO - fix ouput size calcs for conv1d # write_PK_UNet['init_cross_embed_kernel_sizes'] = (3, 7, 15) # write_PK_UNet['cross_embed_downsample'] = False # write_PK_UNet['cross_embed_downsample_kernel_sizes'] = (2, 4) # write_PK_UNet['attn_pool_text'] = True # write_PK_UNet['attn_pool_num_latents'] = 32 #32, #perceiver model latents # write_PK_UNet['dropout'] = 0. 
# write_PK_UNet['memory_efficient'] = False # write_PK_UNet['init_conv_to_final_conv_residual'] = False # write_PK_UNet['use_global_context_attn'] = True # write_PK_UNet['scale_skip_connection'] = True # write_PK_UNet['final_resnet_block'] = True # write_PK_UNet['final_conv_kernel_size'] = 3 # write_PK_UNet['cosine_sim_attn'] = True # write_PK_UNet['self_cond'] = False # write_PK_UNet['combine_upsample_fmaps'] = True # combine feature maps from all upsample blocks, used in unet squared successfully # write_PK_UNet['pixel_shuffle_upsample'] = False # may address checkboard artifacts # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< # # # Unet_PKeys=prepare_UNet_keys(write_PK_UNet) # unet1 = OneD_Unet( # CKeys=CKeys, # PKeys=Unet_PKeys, # ).to (self.device) # # ----------------------------------------------------------------------------- # unet1 = OneD_Unet_Old( # dim = dim, # text_embed_dim = text_embed_dim, # cond_dim = cond_dim, # dim_mults = (1, 2, 4, 8), # num_resnet_blocks = 1,#1, # layer_attns = (False, True, True, False), # layer_cross_attns = (False, True, True, False), # channels=self.pred_dim, # channels_out=self.pred_dim , # # # attn_dim_head = 64, # attn_heads = 8, # ff_mult = 2., # lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ # layer_attns_depth =1, # layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 # attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) # use_linear_attn = False, # use_linear_cross_attn = False, # cond_on_text = True, # max_text_len = max_length, # init_dim = None, # resnet_groups = 8, # init_conv_kernel_size =7, # kernel size of initial conv, if not using cross embed # init_cross_embed = False, #TODO - fix ouput size calcs for conv1d # init_cross_embed_kernel_sizes 
= (3, 7, 15), # cross_embed_downsample = False, # cross_embed_downsample_kernel_sizes = (2, 4), # attn_pool_text = True, # attn_pool_num_latents = 32,#32, #perceiver model latents # dropout = 0., # memory_efficient = False, # init_conv_to_final_conv_residual = False, # use_global_context_attn = True, # scale_skip_connection = True, # final_resnet_block = True, # final_conv_kernel_size = 3, # cosine_sim_attn = True, # self_cond = False, # combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully # pixel_shuffle_upsample = False , # may address checkboard artifacts # # ++ # CKeys=CKeys, # ).to (self.device) assert elucidated , "Only elucidated model implemented...." self.is_elucidated=elucidated if elucidated: self.imagen = ElucidatedImagen( unets = (unet1), channels=self.pred_dim, channels_out=self.pred_dim , loss_type=loss_type, text_embed_dim = text_embed_dim, image_sizes = [max_length], cond_drop_prob = 0.2, auto_normalize_img = False, num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) sigma_min = 0.002, # min noise level sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler sigma_data = 0.5, # standard deviation of data distribution rho = 7, # controls the sampling schedule P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper S_tmin = 0.05, S_tmax = 50, S_noise = 1.003, # ++ CKeys=self.CKeys, PKeys=self.PKeys, ).to (self.device) # + for debug: if CKeys['Debug_ModelPack']==1: print("Check on EImagen:") print("channels: ", self.pred_dim) print("loss_type: ", loss_type) print("text_embed_dim: ",text_embed_dim) 
print("image_sizes: ", max_length) print("num_sample_steps: ", timesteps) print("Measure imagen:") params( self.imagen) print("Measure fc_embed2") params( self.fc_embed2) print("Measure pos_emb_x") params( self.pos_emb_x) else: print ("Not implemented.") # need to merge this with ModelB:forward def forward( self, # # ------------------- # x, # output, # +++++++++++++++++++ output, x=None, # cond_images=None, unet_number=1, ): #sequences=conditioning, output=prediction # dddddddddddddddddddddddddddddddddddd if self.CKeys['Debug_ModelPack']==1: print("on Model A:forward") print("inputs:") print(' output.dim : ', output.shape) print(' x: ', x) print(' x.dim: ', x.shape) if cond_images==None: print(' cond_img: None ') else: print(' cond_img: ', cond_images.shape) print(' unet_num: ', unet_number) x=x.unsqueeze (2) if self.CKeys['Debug_ModelPack']==1: print("After x.unsqueeze(2), x.dim: ", x.shape) x= self.fc_embed2(x) if self.CKeys['Debug_ModelPack']==1: print("After fc_embed2(x), x.dim: ", x.shape) print() pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) if self.CKeys['Debug_ModelPack']==1: print("pos_matrix_i_.dim: ", pos_matrix_i_.shape) print("pos_matrix_i_: ", pos_matrix_i_) print() pos_emb_x = self.pos_emb_x( pos_matrix_i_) if self.CKeys['Debug_ModelPack']==1: print("After pos_emb_x(pos_matrix_i_), pos_emb_x.dim: ", pos_emb_x.shape) print("pos_emb_x: ", pos_emb_x) print() pos_emb_x = torch.squeeze(pos_emb_x, 1) pos_emb_x[:,x.shape[1]:,:]=0 #set all to zero that are not provided via x pos_emb_x=pos_emb_x[:,:x.shape[1],:] if self.CKeys['Debug_ModelPack']==1: print("after operations, pos_emb_x.dim: ", pos_emb_x.shape) print("pos_emb_x: ", pos_emb_x) print() x= torch.cat( (x, pos_emb_x ), 2) if self.CKeys['Debug_ModelPack']==1: print("after cat((x,pos_emb_x),2)=>x dim: ", x.shape, "Batch x max_text_len x (text_embed_dim+embed_dim_position)") print() print("Now, get into self.imagen part...") # ref: the full argument for imagen:forward() # 
_________________________________________________ # self, # images, # unet: Union[ NullUnet, DistributedDataParallel] = None, # texts: List[str] = None, # text_embeds = None, # text_masks = None, # unet_number = None, # cond_images = None, # _________________________________________________ loss = self.imagen( output, text_embeds = x, unet_number = unet_number, ) return loss def sample ( self, x=None, stop_at_unet_number=1, cond_scale=7.5, # ++ x_data=None, # image_condi data skip_steps=None, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, x_data_tokenized=None, tokenizer_X=None, Xnormfac=None, # -+ device=None, max_length=None, # for XandY data, in image/sequence format; NOT for text condition max_text_len=None, # for X data, in text format ): # + add for the uniform shape for model A and B if x_data != None: # condition via images print ("Conditioning target sequence provided via sequence/image in x_data ...", x_data) # need tokeni x_data = tokenizer_X.texts_to_sequences(x_data) x_data= sequence.pad_sequences(x_data, maxlen=max_length, padding='post', truncating='post') x_data= torch.from_numpy(x_data).float().to(self.device) x_data = x_data/Xnormfac x_data=x_data.unsqueeze (2) x_data=torch.permute(x_data, (0,2,1) ) print ("x_data from target sequence=", x_data, x_data.shape) batch_size=x_data.shape[0] # + add for tokenized sequence/image data if x_data_tokenized != None: print ("Conditioning target sequence provided via processed sequence/image in x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape) x_data=x_data_tokenized.unsqueeze (2) x_data=torch.permute(x_data, (0,2,1) ).to(self.device) print ("Data provided from x_data_tokenized: ", x_data.shape) batch_size=x_data.shape[0] # keep but not used if init_images != None: # initial sequence provided via init_images/sequences # need tokenizer_y, ynormfac, max_length print ("Init sequence provided...", init_images) init_images = 
tokenizer_y.texts_to_sequences(init_images) init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post') init_images= torch.from_numpy(init_images).float().to(self.device)/ynormfac print ("init_images=", init_images) # if inpaint_images != None: # inpainting model # need inpaint_images, inpaint_masks, tokenizer_y, ynormfac, max_length print ("Inpaint sequence provided...", inpaint_images) print ("Mask: ", inpaint_masks) inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images) inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post') inpaint_images= torch.from_numpy(inpaint_images).float().to(self.device)/ynormfac print ("in_paint images=", inpaint_images) # now, for Model A type: text conditioning if x !=None: print ("Conditioning target sequence via tokenized text in x ...", x) # for now, we only consider tokenized text in x # need now: max_text_len ; need for future: # + for debug if self.CKeys['Debug_ModelPack']==1: print("x.dim: ", x.shape) print("max_text_len: ", max_text_len) print("self.device: ",self.device) x_in=torch.zeros( (x.shape[0],max_text_len) ).to(self.device) x_in[:,:x.shape[1]]=x x=x_in x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x pos_emb_x=pos_emb_x[:,:x.shape[1],:] x= torch.cat( (x, pos_emb_x ), 2) batch_size=x.shape[0] output = self.imagen.sample( text_embeds= x, cond_scale= cond_scale, stop_at_unet_number=stop_at_unet_number, # ++ cond_images=x_data, skip_steps=skip_steps, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = init_images, batch_size=batch_size, device=self.device, ) # x=x.unsqueeze (2) # x= self.fc_embed2(x) # 
pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) # pos_emb_x = self.pos_emb_x( pos_matrix_i_) # pos_emb_x = torch.squeeze(pos_emb_x, 1) # pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x # pos_emb_x=pos_emb_x[:,:x.shape[1],:] # x= torch.cat( (x, pos_emb_x ), 2) # output = self.imagen.sample( # text_embeds= x, # cond_scale=cond_scale, # stop_at_unet_number=stop_at_unet_number # ) return output # =========================================================== # final models: old model # Changes: 1. replace the OneD_UNet part # class ProteinDesigner_A_I(nn.Module): def __init__( self, # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx CKeys=None, PKeys=None, ): super(ProteinDesigner_A_I, self).__init__() # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # unload the arguments timesteps =default(PKeys['timesteps'], 10) # 10 dim =default(PKeys['dim'], 32) # 32, pred_dim =default(PKeys['pred_dim'], 25) # 25, loss_type =default(PKeys['loss_type'], 0) # 0, elucidated =default(PKeys['elucidated'], True) # False, padding_idx =default(PKeys['padding_idx'], 0) # 0, cond_dim =default(PKeys['cond_dim'], 512) # 512, text_embed_dim =default(PKeys['text_embed_dim'], 512) # 512, input_tokens =default(PKeys['input_tokens'], 25) # 25,#for non-BERT sequence_embed =default(PKeys['sequence_embed'], False) # False, embed_dim_position =default(PKeys['embed_dim_position'], 32) # 32, max_text_len =default(PKeys['max_text_len'], 16) # 16, cond_images_channels=default(PKeys['cond_images_channels'], 0) # ++ max_length =default(PKeys['max_length'], 64) # 64, device =default(PKeys['device'], 'cuda:0') # 'cuda:0', # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< # ++ self.CKeys=CKeys self.PKeys=PKeys self.device=device self.pred_dim=pred_dim self.loss_type=loss_type self.fc_embed1 = nn.Linear( 8, max_length) # NOT USED self.fc_embed2 = nn.Linear( 1, text_embed_dim) # self.max_text_len=max_text_len self.pos_emb_x = 
nn.Embedding(max_text_len+1, embed_dim_position) text_embed_dim=text_embed_dim+embed_dim_position self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) for i in range (max_text_len): self.pos_matrix_i [i]=i +1 assert (loss_type==0), "Loss other than MSE not implemented" # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> # prepare the Unet Key write_PK_UNet=dict() write_PK_UNet['dim']=dim write_PK_UNet['text_embed_dim']=text_embed_dim write_PK_UNet['cond_dim']=cond_dim write_PK_UNet['dim_mults']=(1, 2, 4, 8) write_PK_UNet['num_resnet_blocks']=1 write_PK_UNet['layer_attns']=(False, True, True, False) write_PK_UNet['layer_cross_attns']=(False, True, True, False) write_PK_UNet['channels']=pred_dim write_PK_UNet['channels_out']=pred_dim write_PK_UNet['attn_dim_head']=64 write_PK_UNet['attn_heads']=8 write_PK_UNet['ff_mult']=2. write_PK_UNet['lowres_cond']=False # for cascading diffusion - https://cascaded-diffusion.github.io/ write_PK_UNet['layer_attns_depth']=1 write_PK_UNet['layer_attns_add_text_cond']=True # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 write_PK_UNet['attend_at_middle']=True # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) write_PK_UNet['use_linear_attn']=False write_PK_UNet['use_linear_cross_attn']=False write_PK_UNet['cond_on_text'] = True write_PK_UNet['max_text_len'] = max_length write_PK_UNet['init_dim'] = None write_PK_UNet['resnet_groups'] = 8 write_PK_UNet['init_conv_kernel_size'] =7 # kernel size of initial conv, if not using cross embed write_PK_UNet['init_cross_embed'] = False #TODO - fix ouput size calcs for conv1d write_PK_UNet['init_cross_embed_kernel_sizes'] = (3, 7, 15) write_PK_UNet['cross_embed_downsample'] = False write_PK_UNet['cross_embed_downsample_kernel_sizes'] = (2, 
4) write_PK_UNet['attn_pool_text'] = True write_PK_UNet['attn_pool_num_latents'] = 32 #32, #perceiver model latents write_PK_UNet['dropout'] = 0. write_PK_UNet['memory_efficient'] = False write_PK_UNet['init_conv_to_final_conv_residual'] = False write_PK_UNet['use_global_context_attn'] = True write_PK_UNet['scale_skip_connection'] = True write_PK_UNet['final_resnet_block'] = True write_PK_UNet['final_conv_kernel_size'] = 3 write_PK_UNet['cosine_sim_attn'] = True write_PK_UNet['self_cond'] = False write_PK_UNet['combine_upsample_fmaps'] = True # combine feature maps from all upsample blocks, used in unet squared successfully write_PK_UNet['pixel_shuffle_upsample'] = False # may address checkboard artifacts # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< # Unet_PKeys=prepare_UNet_keys(write_PK_UNet) unet1 = OneD_Unet( CKeys=CKeys, PKeys=Unet_PKeys, ).to (self.device) if CKeys['Debug_ModelPack']==1: print('Check unet generated...') params(unet1) # # ----------------------------------------------------------------------------- # unet1 = OneD_Unet_Old( # dim = dim, # text_embed_dim = text_embed_dim, # cond_dim = cond_dim, # dim_mults = (1, 2, 4, 8), # num_resnet_blocks = 1,#1, # layer_attns = (False, True, True, False), # layer_cross_attns = (False, True, True, False), # channels=self.pred_dim, # channels_out=self.pred_dim , # # # attn_dim_head = 64, # attn_heads = 8, # ff_mult = 2., # lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ # layer_attns_depth =1, # layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 # attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) # use_linear_attn = False, # use_linear_cross_attn = False, # cond_on_text = True, # max_text_len = max_length, # init_dim = None, # 
resnet_groups = 8, # init_conv_kernel_size =7, # kernel size of initial conv, if not using cross embed # init_cross_embed = False, #TODO - fix ouput size calcs for conv1d # init_cross_embed_kernel_sizes = (3, 7, 15), # cross_embed_downsample = False, # cross_embed_downsample_kernel_sizes = (2, 4), # attn_pool_text = True, # attn_pool_num_latents = 32,#32, #perceiver model latents # dropout = 0., # memory_efficient = False, # init_conv_to_final_conv_residual = False, # use_global_context_attn = True, # scale_skip_connection = True, # final_resnet_block = True, # final_conv_kernel_size = 3, # cosine_sim_attn = True, # self_cond = False, # combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully # pixel_shuffle_upsample = False , # may address checkboard artifacts # # ++ # CKeys=CKeys, # ).to (self.device) assert elucidated , "Only elucidated model implemented...." self.is_elucidated=elucidated if elucidated: self.imagen = ElucidatedImagen( unets = (unet1), channels=self.pred_dim, channels_out=self.pred_dim , loss_type=loss_type, text_embed_dim = text_embed_dim, image_sizes = [max_length], cond_drop_prob = 0.2, auto_normalize_img = False, num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) sigma_min = 0.002, # min noise level sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler sigma_data = 0.5, # standard deviation of data distribution rho = 7, # controls the sampling schedule P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper S_tmin = 0.05, S_tmax = 50, S_noise = 1.003, # ++ CKeys=self.CKeys, PKeys=self.PKeys, ).to 
(self.device) else: print ("Not implemented.") def forward(self, x, output, unet_number=1): #sequences=conditioning, output=prediction x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x pos_emb_x=pos_emb_x[:,:x.shape[1],:] x= torch.cat( (x, pos_emb_x ), 2) loss = self.imagen( output, text_embeds = x, unet_number = unet_number, ) return loss def sample (self, x, stop_at_unet_number=1 ,cond_scale=7.5,): x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x pos_emb_x=pos_emb_x[:,:x.shape[1],:] x= torch.cat( (x, pos_emb_x ), 2) output = self.imagen.sample(text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number) return output # =========================================================== # final models: old model # class ProteinDesigner_A_Old(nn.Module): def __init__( self, timesteps=10 , dim=32, pred_dim=25, loss_type=0, elucidated=False, padding_idx=0, cond_dim = 512, text_embed_dim = 512, input_tokens=25,#for non-BERT sequence_embed=False, embed_dim_position=32, max_text_len=16, device='cuda:0', # ++ max_length=64, CKeys=None, PKeys=None, ): super(ProteinDesigner_A_Old, self).__init__() # ++ self.CKeys=CKeys self.PKeys=PKeys self.device=device self.pred_dim=pred_dim self.loss_type=loss_type self.fc_embed1 = nn.Linear( 8, max_length) # NOT USED self.fc_embed2 = nn.Linear( 1, text_embed_dim) # self.max_text_len=max_text_len self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position) text_embed_dim=text_embed_dim+embed_dim_position self.pos_matrix_i = torch.zeros (max_text_len, dtype=torch.long) for i in range 
(max_text_len): self.pos_matrix_i [i]=i +1 assert (loss_type==0), "Loss other than MSE not implemented" unet1 = OneD_Unet_Old( dim = dim, text_embed_dim = text_embed_dim, cond_dim = cond_dim, dim_mults = (1, 2, 4, 8), num_resnet_blocks = 1,#1, layer_attns = (False, True, True, False), layer_cross_attns = (False, True, True, False), channels=self.pred_dim, channels_out=self.pred_dim , # attn_dim_head = 64, attn_heads = 8, ff_mult = 2., lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ layer_attns_depth =1, layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) use_linear_attn = False, use_linear_cross_attn = False, cond_on_text = True, max_text_len = max_length, # need to be checked init_dim = None, resnet_groups = 8, init_conv_kernel_size =7, # kernel size of initial conv, if not using cross embed init_cross_embed = False, #TODO - fix ouput size calcs for conv1d init_cross_embed_kernel_sizes = (3, 7, 15), cross_embed_downsample = False, cross_embed_downsample_kernel_sizes = (2, 4), attn_pool_text = True, attn_pool_num_latents = 32,#32, #perceiver model latents dropout = 0., memory_efficient = False, init_conv_to_final_conv_residual = False, use_global_context_attn = True, scale_skip_connection = True, final_resnet_block = True, final_conv_kernel_size = 3, cosine_sim_attn = True, self_cond = False, combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully pixel_shuffle_upsample = False , # may address checkboard artifacts # ++ CKeys=CKeys, ).to (self.device) # print the size if CKeys['Debug_ModelPack']==1: print("Check NUnet...") params( unet1) assert elucidated , "Only elucidated model implemented...." 
self.is_elucidated=elucidated if elucidated: self.imagen = ElucidatedImagen( unets = (unet1), channels=self.pred_dim, channels_out=self.pred_dim , loss_type=loss_type, text_embed_dim = text_embed_dim, image_sizes = [max_length], cond_drop_prob = 0.2, auto_normalize_img = False, num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) sigma_min = 0.002, # min noise level sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler sigma_data = 0.5, # standard deviation of data distribution rho = 7, # controls the sampling schedule P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper S_tmin = 0.05, S_tmax = 50, S_noise = 1.003, # ++ CKeys=self.CKeys, PKeys=self.PKeys, ).to (self.device) if CKeys['Debug_ModelPack']==1: print("Check on EImagen:") print("channels: ", self.pred_dim) print("loss_type: ", loss_type) print("text_embed_dim: ",text_embed_dim) print("image_sizes: ", max_length) print("num_sample_steps: ", timesteps) print("Measure imagen:") params( self.imagen) print("Measure fc_embed2") params( self.fc_embed2) print("Measure pos_emb_x") params( self.pos_emb_x) else: print ("Not implemented.") def forward(self, x, output, unet_number=1): #sequences=conditioning, output=prediction x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x pos_emb_x=pos_emb_x[:,:x.shape[1],:] x= torch.cat( (x, pos_emb_x ), 2) loss = self.imagen( output, text_embeds = x, unet_number = unet_number, ) 
return loss def sample (self, x, stop_at_unet_number=1 ,cond_scale=7.5,): x=x.unsqueeze (2) x= self.fc_embed2(x) pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device) pos_emb_x = self.pos_emb_x( pos_matrix_i_) pos_emb_x = torch.squeeze(pos_emb_x, 1) pos_emb_x[:,x.shape[1]:,:]=0#set all to zero that are not provided via x pos_emb_x=pos_emb_x[:,:x.shape[1],:] x= torch.cat( (x, pos_emb_x ), 2) output = self.imagen.sample(text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number) return output