Bo-Ni's picture
Upload the lib
269fa8c
########################################################
## Attention-Diffusion model
########################################################
#based on: https://github.com/lucidrains/imagen-pytorch
import math
import copy
from random import random
from typing import List, Union
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch import nn, einsum
from torch.cuda.amp import autocast
from torch.special import expm1
import torchvision.transforms as T
import kornia.augmentation as K
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
#
from tensorflow.keras.preprocessing import text, sequence
from tensorflow.keras.preprocessing.text import Tokenizer
# # --
# from PD_SpLMxDiff.UtilityPack import prepare_UNet_keys, modify_keys, params
# ++
from PD_pLMProbXDiff.UtilityPack import (
prepare_UNet_keys, modify_keys, params
)
#
from torchinfo import summary
import json
# quick way to identify the device:
# Module-level device handle: the first CUDA GPU when one is available,
# otherwise the CPU. The choice is printed once at import time.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('identify the device independently', device)
# ==================================================
# helper functions
# ==================================================
def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing
def identity(t, *args, **kwargs):
    """Return `t` unchanged; extra positional and keyword arguments are ignored."""
    del args, kwargs  # accepted only so the function can stand in for other callables
    return t
def first(arr, d = None):
    """Return the first element of `arr`, or the fallback `d` when `arr` is empty."""
    return arr[0] if len(arr) > 0 else d
def maybe(fn):
    """Wrap a one-argument callable so that a None input is passed straight through.

    The returned function calls `fn(x)` only when `x` is not None; otherwise
    it returns `x` (i.e. None) without invoking `fn`.
    """
    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x

    return inner
def once(fn):
    """Wrap `fn` so that only the first call is executed.

    The first invocation forwards every positional and keyword argument to
    `fn` and returns its result. Every later invocation is a no-op that
    returns None.

    Args:
        fn: the callable to guard.

    Returns:
        A wrapper with the same metadata as `fn` (via `functools.wraps`).
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        # flip the flag before calling so a raising `fn` still counts as "called"
        called = True
        return fn(*args, **kwargs)

    return inner


# `print` that emits output only the first time it is invoked
print_once = once(print)
def default(val, d):
    """Return `val` unless it is None, in which case fall back to `d`.

    When the fallback `d` is callable it is invoked (with no arguments) and
    its result is returned, which allows lazily computed defaults.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
def cast_tuple(val, length = None):
    """Coerce `val` into a tuple, optionally checking its length.

    Lists are converted to tuples, tuples are kept as they are, and any other
    value is repeated `length` times (once when `length` is None). When
    `length` is given, the resulting tuple is asserted to have that length.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        output = (val,) * default(length, 1)

    if exists(length):
        assert len(output) == length

    return output
def is_float_dtype(dtype):
    """Return True if `dtype` is float64, float32, float16 or bfloat16."""
    float_dtypes = (torch.float64, torch.float32, torch.float16, torch.bfloat16)
    return any(dtype == candidate for candidate in float_dtypes)
def cast_uint8_images_to_float(images):
    """Divide uint8 tensors by 255; return tensors of any other dtype untouched."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
def module_device(module):
    """Return the device of the first parameter yielded by `module.parameters()`."""
    first_param = next(module.parameters())
    return first_param.device
def zero_init_(m):
    """Fill `m.weight` with zeros in place, and `m.bias` as well when it is not None."""
    nn.init.zeros_(m.weight)
    if not exists(m.bias):
        return
    nn.init.zeros_(m.bias)
def eval_decorator(fn):
    """Decorator that runs `fn(model, ...)` with the model in eval mode.

    The model's previous training flag is recorded, `model.eval()` is called,
    and the flag is restored afterwards. Restoration happens in a `finally`
    clause so the model does not get stuck in eval mode when `fn` raises.

    Args:
        fn: callable whose first argument is the model (an object exposing
            `training`, `eval()` and `train(mode)`).

    Returns:
        The wrapped callable, carrying `fn`'s metadata.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad `t` with `fillvalue` until it has `length` items.

    When `t` already has at least `length` items it is returned as is;
    otherwise a new tuple is built.
    """
    missing = length - len(t)
    if missing > 0:
        padding = (fillvalue,) * missing
        return (*t, *padding)
    return t
# ==================================================
# helper classes
# ==================================================
class Identity(nn.Module):
    """No-op module: constructor arguments are discarded and `forward` returns its input."""

    def __init__(self, *args, **kwargs):
        # positional / keyword arguments are accepted so this class can be
        # swapped in for other layers, but none of them are stored
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return `x` unchanged, ignoring any additional arguments."""
        return x
# ==================================================
# tensor helpers
# ==================================================
def log(t, eps: float = 1e-12):
    """Natural log of `t` after clamping it from below at `eps`."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)
def l2norm(t):
    """L2-normalise `t` along its last dimension."""
    return F.normalize(t, p = 2.0, dim = -1)
def right_pad_dims_to(x, t):
    """Append singleton dimensions to `t` until it has as many dims as `x`.

    `t` is returned untouched when it already has at least `x.ndim` dims.
    """
    missing = x.ndim - t.ndim
    if missing > 0:
        trailing_ones = (1,) * missing
        return t.view(*t.shape, *trailing_ones)
    return t
def masked_mean(t, *, dim, mask = None):
    """Mean of `t` over `dim`, counting only positions where `mask` is True.

    Without a mask this is a plain `t.mean(dim)`. With a mask (rearranged
    from 'b n' to 'b n 1' before use), masked-out entries are zeroed before
    summing and the sum is divided by the number of True entries, clamped
    from below at 1e-5.
    """
    if exists(mask):
        denom = mask.sum(dim = dim, keepdim = True)
        mask = rearrange(mask, 'b n -> b n 1')
        summed = t.masked_fill(~mask, 0.).sum(dim = dim)
        return summed / denom.clamp(min = 1e-5)

    return t.mean(dim = dim)
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None
):
    """Linearly resample the last dimension of `image` to `target_image_size`.

    The input is returned unchanged when its last dimension already matches.
    Otherwise it is cast to float and resized with
    `F.interpolate(mode='linear', align_corners=True)`.

    NOTE(review): `clamp_range` is accepted but never applied in this
    function - confirm whether callers rely on clamping.
    """
    if image.shape[-1] == target_image_size:
        return image

    return F.interpolate(
        image.float(),
        target_image_size,
        mode = 'linear',
        align_corners = True
    )
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
#
def normalize_neg_one_to_one(img):
    """Map values from the [0, 1] range to [-1, 1] via `img * 2 - 1`."""
    doubled = img * 2
    return doubled - 1
def unnormalize_zero_to_one(normed_img):
    """Map values from the [-1, 1] range back to [0, 1] via `(x + 1) * 0.5`."""
    shifted = normed_img + 1
    return shifted * 0.5
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
    """Build a boolean tensor of `shape` whose entries are True with probability `prob`.

    `prob == 1` yields all True and `prob == 0` all False without sampling;
    any other value compares uniform(0, 1) samples against `prob`.
    """
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    uniform_samples = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform_samples < prob
# ==================================================
# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
# ==================================================
@torch.jit.script
def beta_linear_log_snr(t):
    """Log-SNR of the linear-beta schedule: -log(expm1(1e-4 + 10 * t^2))."""
    exponent = 1e-4 + 10 * (t ** 2)
    return -torch.log(expm1(exponent))
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    """Log-SNR of the cosine schedule with offset `s`.

    Computes -log(cos((t + s) / (1 + s) * pi / 2) ** -2 - 1), where the
    argument of the log is clamped from below at 1e-5.
    """
    # not sure if this accounts for beta being clipped to 0.999 in discrete version
    angle = (t + s) / (1 + s) * math.pi * 0.5
    inv_cos_sq_minus_one = (torch.cos(angle) ** -2) - 1
    return -log(inv_cos_sq_minus_one, eps = 1e-5)
def log_snr_to_alpha_sigma(log_snr):
    """Convert a log-SNR tensor into (alpha, sigma).

    alpha = sqrt(sigmoid(log_snr)) and sigma = sqrt(sigmoid(-log_snr)).
    """
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma
class GaussianDiffusionContinuousTimes(nn.Module):
    """Gaussian diffusion helpers driven by a continuous-time log-SNR schedule.

    The schedule function (`linear` or `cosine`) maps a time tensor to a
    log-SNR tensor; alpha and sigma are derived from it through
    `log_snr_to_alpha_sigma`.
    """

    def __init__(self, *, noise_schedule, timesteps = 1000):
        """Pick the log-SNR schedule and remember the number of sampling steps.

        Raises:
            ValueError: when `noise_schedule` is neither "linear" nor "cosine".
        """
        super().__init__()

        if noise_schedule == "linear":
            schedule_fn = beta_linear_log_snr
        elif noise_schedule == "cosine":
            schedule_fn = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.log_snr = schedule_fn
        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        """Return a float32 tensor of shape (batch_size,) filled with `noise_level`."""
        return torch.full(
            (batch_size,),
            noise_level,
            device = device,
            dtype = torch.float32
        )

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        """Sample `batch_size` times uniformly from [0, max_thres)."""
        times = torch.zeros((batch_size,), device = device).float()
        return times.uniform_(0, max_thres)

    def get_condition(self, times):
        """Return the log-SNR for `times`, or None when `times` is None."""
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        """Return `num_timesteps` tensors of shape (2, batch) for sampling.

        Times run on a linear grid from 1 to 0; each returned tensor stacks
        the current time (row 0) and the next time (row 1) for every batch
        element.
        """
        grid = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        grid = repeat(grid, 't -> b t', b = batch)
        current, following = grid[:, :-1], grid[:, 1:]
        pairs = torch.stack((current, following), dim = 0)
        return pairs.unbind(dim = -1)

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        """Posterior mean, variance and clipped log-variance for the step t -> t_next.

        When `t_next` is not given it defaults to `t - 1 / num_timesteps`,
        clamped at zero.

        https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
        """
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        log_snr_next = right_pad_dims_to(x_t, self.log_snr(t_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        """Noise `x_start` to time `t`; return (noised sample, log-SNR).

        A Python float `t` is broadcast to every batch element. Fresh
        standard-normal noise is drawn when `noise` is None.
        """
        dtype = x_start.dtype

        if isinstance(t, float):
            t = torch.full(
                (x_start.shape[0],), t, device = x_start.device, dtype = dtype
            )

        noise = default(noise, lambda: torch.randn_like(x_start))

        log_snr = self.log_snr(t).type(dtype)
        alpha, sigma = log_snr_to_alpha_sigma(right_pad_dims_to(x_start, log_snr))
        noised = alpha * x_start + sigma * noise
        return noised, log_snr

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        """Move a sample from time `from_t` to time `to_t`.

        Float times are broadcast over the batch. Fresh standard-normal noise
        is drawn when `noise` is None.
        """
        device, dtype = x_from.device, x_from.dtype
        batch = x_from.shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        log_snr_from = right_pad_dims_to(x_from, self.log_snr(from_t))
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_from)

        log_snr_dest = right_pad_dims_to(x_from, self.log_snr(to_t))
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_dest)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from `x_t` and the noise: (x_t - sigma * noise) / alpha.

        alpha is clamped from below at 1e-8 before the division.
        """
        log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
# ==================================================
# norms and residuals
# ==================================================
class LayerNorm(nn.Module):
    """Bias-free layer norm with a learned gain, normalising over one dimension.

    `dim` is the (negative) index of the dimension that is normalised; the
    gain `g` has shape (feats, 1, ..., 1) with `-dim - 1` trailing singleton
    dims so that it broadcasts along that dimension.
    """

    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        gain_shape = (feats,) + (1,) * (-dim - 1)
        self.g = nn.Parameter(torch.ones(gain_shape))

    def forward(self, x):
        """Normalise `x` over `self.dim` and scale the result by the gain."""
        norm_dim = self.dim
        in_dtype = x.dtype

        if self.stable:
            # divide by the (detached) max along the dimension before normalising
            x = x / x.amax(dim = norm_dim, keepdim = True).detach()

        eps = 1e-3 if x.dtype != torch.float32 else 1e-5
        var = torch.var(x, dim = norm_dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = norm_dim, keepdim = True)

        inv_std = (var + eps).rsqrt().type(in_dtype)
        return (x - mean) * inv_std * self.g.type(in_dtype)


# layer norm over the second-to-last dimension
ChanLayerNorm = partial(LayerNorm, dim = -2)
class Always():
    """Callable that returns the value given at construction, whatever the call arguments."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        constant = self.val
        return constant
class Residual(nn.Module):
    """Wrap `fn` with a skip connection: forward returns `fn(x, **kwargs) + x`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x
class Parallel(nn.Module):
    """Apply several modules to the same input and add their outputs together."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        branch_outputs = []
        for branch in self.fns:
            branch_outputs.append(branch(x))
        return sum(branch_outputs)
# ==================================================
# attention pooling
# ==================================================
# modified
class PerceiverAttention(nn.Module):
    """Multi-head attention from a set of latents onto a sequence plus the latents.

    Queries come from `latents`; keys and values come from the concatenation
    of the (normalised) sequence `x` and the (normalised) latents.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        # with cosine-similarity attention the usual 1/sqrt(d) scale is
        # replaced by a fixed factor applied to the similarities
        self.scale = 1 if cosine_sim_attn else dim_head ** -0.5
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        """Attend from `latents` to `[x, latents]`.

        `mask`, when given, is a boolean 'b j' tensor over the positions of
        `x`; it is right-padded with True for the latent positions.
        """
        heads = self.heads

        x = self.norm(x)
        latents = self.norm_latents(latents)

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values
        # derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = heads)
        q = q * self.scale

        # cosine sim attention
        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities and masking
        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax computed in float32, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = heads)
        return self.to_out(out)
class PerceiverResampler(nn.Module):
    """Resample a sequence into a fixed set of latents with stacked PerceiverAttention.

    Learned latents (optionally preceded by latents projected from the
    mean-pooled sequence) repeatedly attend to the position-embedded input,
    each attention layer being followed by a feed-forward layer.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attn_layer = PerceiverAttention(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                cosine_sim_attn = cosine_sim_attn
            )
            ff_layer = FeedForward(dim = dim, mult = ff_mult)
            self.layers.append(nn.ModuleList([attn_layer, ff_layer]))

    def forward(self, x, mask = None):
        """Return the resampled latents for sequence `x`.

        `mask` is forwarded to every attention layer. The mean pooling uses an
        all-True mask, so `mask` does not affect the pooled latents.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        positions = torch.arange(seq_len, device = x.device)
        x_with_pos = x + self.pos_emb(positions)

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        if exists(self.to_latents_from_mean_pooled_seq):
            full_mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)
            meanpooled_seq = masked_mean(x, dim = 1, mask = full_mask)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
# ====================================================
# attention
# ====================================================
class Attention(nn.Module):
    """Multi-query self-attention with a learned null key/value and optional context.

    Queries are multi-headed while keys and values are shared across heads
    (a single `dim_head`-wide key/value per position). A learned null
    key/value pair is always prepended; when `context` is supplied, keys and
    values projected from it are prepended as well.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        """Run attention over `x`.

        Args:
            x: input sequence, consumed as 'b n d'.
            context: optional conditioning sequence; requires `context_dim`
                to have been given at construction.
            mask: optional boolean 'b j' mask; it is left-padded with one
                True for the null key before use.
                NOTE(review): only one position is padded, so the mask length
                does not cover prepended context keys - confirm intended use
                when both `mask` and `context` are passed.
            attn_bias: optional bias added to the similarity matrix.

        Returns:
            Tensor with the same leading shape as `x` and feature size `dim`.
        """
        b = x.shape[0]

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # add null key / value for classifier free guidance in prior net
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # cosine sim attention
        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # calculate query / key similarities
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        # relative positional encoding (T5 style)
        if exists(attn_bias):
            sim = sim + attn_bias

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # sim is 'b h i j': the mask needs singleton head AND query axes,
            # otherwise its batch axis would line up with the head axis
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax in float32 for stability, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# ============================================================
# decoder
# ============================================================
def Upsample(dim, dim_out = None):
    """Build a 2x nearest-neighbour upsample followed by a kernel-3 Conv1d.

    The conv maps `dim` channels to `dim_out` channels (`dim` when omitted).
    """
    out_channels = default(dim_out, dim)

    stages = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*stages)
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf

    Pipeline: kernel-1 Conv1d to `dim_out * 4` channels, SiLU, PixelShuffle(2).

    NOTE(review): nn.PixelShuffle is documented for inputs shaped
    (*, C * r^2, H, W); confirm it behaves as intended on the 3-D tensors a
    Conv1d produces before relying on this module.
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        expand_conv = nn.Conv1d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            expand_conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(expand_conv)

    def init_conv_(self, conv):
        """Initialise `conv` so each group of 4 consecutive output channels shares one kaiming-uniform kernel; bias is zeroed."""
        out_ch, in_ch, kernel = conv.weight.shape
        base_weight = torch.empty(out_ch // 4, in_ch, kernel)
        nn.init.kaiming_uniform_(base_weight)
        tiled_weight = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
def Downsample(dim, dim_out = None):
    """Build a 2x downsample: fold pairs of positions into channels, then a kernel-1 Conv1d.

    https://arxiv.org/abs/2208.03641 reports this as the best way to downsample;
    it is named SP-conv in the paper, but is basically a pixel unshuffle.
    """
    out_channels = default(dim_out, dim)

    fold_positions = Rearrange('b c (h s1) -> b (c s1) h', s1 = 2)
    project = nn.Conv1d(dim * 2, out_channels, 1)
    return nn.Sequential(fold_positions, project)
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions / times.

    Produces `dim // 2` sine features followed by `dim // 2` cosine features,
    with frequencies spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -log_step)
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
class LearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with learned sinusoidal pos emb

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    The output concatenates the raw input with `dim // 2` sine and
    `dim // 2` cosine features, so its last dimension is `dim + 1`.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        learned_freqs = rearrange(self.weights, 'd -> 1 d')
        freqs = x * learned_freqs * 2 * math.pi
        sin_cos = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        return torch.cat((x, sin_cos), dim = -1)
class Block(nn.Module):
    """GroupNorm -> optional scale/shift -> SiLU -> kernel-3 Conv1d."""

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        if norm:
            self.groupnorm = nn.GroupNorm(groups, dim)
        else:
            self.groupnorm = Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv1d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        """Apply the block; `scale_shift` is an optional (scale, shift) pair applied after the norm."""
        normed = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        activated = self.activation(normed)
        return self.project(activated)
class ResnetBlock(nn.Module):
    """Two `Block`s with optional time conditioning, cross-attention and global context.

    * `time_cond_dim` enables an MLP that turns the time embedding into a
      (scale, shift) pair for the second block.
    * `cond_dim` enables cross-attention onto `cond` between the two blocks.
    * `use_gca` gates the output with `GlobalContext`; otherwise the gate is 1.

    NOTE(review): `squeeze_excite` is accepted but not used in this class.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = LinearCrossAttention if linear_attn else CrossAttention

            # attention operates on 'b h c'; feature maps are 'b c h'
            self.cross_attn = EinopsToAndFrom(
                'b c h ',
                'b h c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim,
                    **attn_kwargs
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(self, x, time_emb = None, cond = None):
        """Run the residual block.

        `time_emb` is used only when the time MLP exists; `cond` is required
        whenever cross-attention was enabled at construction.
        """
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            projected = self.time_mlp(time_emb)
            projected = rearrange(projected, 'b c -> b c 1')
            scale_shift = projected.chunk(2, dim = 1)

        hidden = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            hidden = self.cross_attn(hidden, context = cond) + hidden

        hidden = self.block2(hidden, scale_shift = scale_shift)

        hidden = hidden * self.gca(hidden)

        return hidden + self.res_conv(x)
class CrossAttention(nn.Module):
    """Multi-head cross-attention from `x` onto `context`, with a learned null key/value.

    The null key/value pair is prepended to the context keys/values of every
    head (used for classifier-free guidance).
    """

    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """Attend from `x` to `context`.

        Args:
            x: query sequence, consumed as 'b n d'.
            context: key/value sequence, consumed as 'b m d_context'.
            mask: optional boolean 'b j' mask over context positions; it is
                left-padded with one True for the null key.

        Returns:
            Tensor with the same leading shape as `x` and feature size `dim`.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        # cosine sim attention
        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # sim is 'b h i j': the mask needs singleton head AND query axes,
            # otherwise its batch axis would line up with the head axis
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax in float32 for stability, then cast back
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    """Linear-attention variant of `CrossAttention` (same parameters, different forward).

    Queries are softmaxed over the feature axis and keys over the sequence
    axis, so the context is summarised into a d x e matrix per head before
    being read out by the queries.
    """

    def forward(self, x, context, mask = None):
        """Attend from `x` to `context` with linear attention.

        Args:
            x: query sequence, consumed as 'b n d'.
            context: key/value sequence.
            mask: optional boolean 'b n' mask over context positions; it is
                left-padded with one True for the null key.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # heads are folded into the batch axis
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking
        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # k / v carry (b * h) on the leading axis, so the per-sample mask
            # has to be repeated once per head to line up with them
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class LinearAttention(nn.Module):
    """Linear self-attention over 1-D feature maps, with optional context keys/values.

    Queries, keys and values are each produced by dropout, a pointwise Conv1d
    and a depthwise kernel-3 Conv1d. The forward pass therefore operates on
    the 3-D tensors that Conv1d produces.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv1d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        """Apply linear attention to `fmap`.

        Args:
            fmap: feature map consumed as 'b c n' (the layout Conv1d expects).
            context: optional conditioning sequence; requires `context_dim`
                to have been given at construction.

        Returns:
            Tensor laid out as 'b dim n'.
        """
        h = self.heads

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        # Conv1d outputs are 'b (h c) n'; fold heads into the batch axis
        q, k, v = rearrange_many((q, k, v), 'b (h c) n -> (b h) n c', h = h)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        # back to the channel-first 1-D layout expected by the output conv
        out = rearrange(out, '(b h) n d -> b (h d) n', h = h)

        out = self.nonlin(out)
        return self.to_out(out)
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque

    A kernel-1 conv yields one attention logit per position; the
    softmax-weighted sum of the features is passed through a small
    conv / SiLU / conv / sigmoid network to produce the gate.
    """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        self.to_k = nn.Conv1d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv1d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv1d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        logits = self.to_k(x)
        # flatten every dim after the first two into a single position axis
        x, logits = rearrange_many((x, logits), 'b n ... -> b n (...)')
        weights = logits.softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, x)
        return self.net(pooled)
def FeedForward(dim, mult = 2):
    """Build a token-wise feed-forward net: LayerNorm, Linear, GELU, LayerNorm, Linear.

    The hidden width is `int(dim * mult)`; both linear layers are bias-free.
    """
    hidden_dim = int(dim * mult)

    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False),
    ]
    return nn.Sequential(*layers)
def ChanFeedForward(dim, mult = 2):
    """Build a channel-first feed-forward net from kernel-1 Conv1d layers.

    In the paper, it seems for self attention layers they did feedforwards
    with twice channel width. The hidden width is `int(dim * mult)`; both
    convolutions are bias-free.
    """
    hidden_dim = int(dim * mult)

    layers = [
        ChanLayerNorm(dim),
        nn.Conv1d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv1d(hidden_dim, dim, 1, bias = False),
    ]
    return nn.Sequential(*layers)
class TransformerBlock(nn.Module):
    """Stack of `depth` (attention, channel feed-forward) pairs with residual connections.

    Feature maps are 'b c h'; each attention layer is wrapped so that it
    sees 'b h c' and its output is converted back.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = Attention(
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                context_dim = context_dim,
                cosine_sim_attn = cosine_sim_attn
            )
            wrapped_attention = EinopsToAndFrom('b c h', 'b h c', attention)
            feed_forward = ChanFeedForward(dim = dim, mult = ff_mult)
            self.layers.append(nn.ModuleList([wrapped_attention, feed_forward]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
class LinearAttentionTransformerBlock(nn.Module):
    """Stack of `depth` (LinearAttention, channel feed-forward) pairs with residual connections.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = LinearAttention(
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                context_dim = context_dim
            )
            feed_forward = ChanFeedForward(dim = dim, mult = ff_mult)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
class CrossEmbedLayer(nn.Module):
    """Multi-scale strided Conv1d embedding whose outputs are concatenated on channels.

    One convolution is created per kernel size (sorted ascending). Output
    channels are split so that scale i receives `dim_out / 2**(i+1)` channels
    and the last scale receives the remainder.
    """

    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # every kernel size must share the stride's parity
        assert all([(size % 2) == (stride % 2) for size in kernel_sizes])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales.append(dim_out - sum(dim_scales))

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            padding = (kernel - stride) // 2
            self.convs.append(
                nn.Conv1d(dim_in, dim_scale, kernel, stride = stride, padding = padding)
            )

    def forward(self, x):
        fmaps = tuple(conv(x) for conv in self.convs)
        return torch.cat(fmaps, dim = 1)
class UpsampleCombiner(nn.Module):
    """Optionally concatenate resized, re-projected feature maps onto `x`.

    When disabled the module is a pass-through and `dim_out == dim`. When
    enabled, each incoming feature map gets its own `Block` and `dim_out`
    grows by the sum of `dim_outs`.
    """

    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        projections = [Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]
        self.fmap_convs = nn.ModuleList(projections)
        self.dim_out = dim + sum(dim_outs)

    def forward(self, x, fmaps = None):
        """Return `x`, or `x` concatenated on channels with the processed `fmaps`."""
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # bring every feature map to the length of `x` before projecting it
        resized = [resize_image_to(fmap, target_size) for fmap in fmaps]
        projected = [conv(fmap) for fmap, conv in zip(resized, self.fmap_convs)]
        return torch.cat((x, *projected), dim = 1)
########################################################
## 1D Unet updated
########################################################
class OneD_Unet(nn.Module):
def __init__(
self,
*,
# +++++++++++++++++++++++++++
CKeys=None,
PKeys=None,
):
# Build the 1D U-Net from two dicts instead of individual keyword arguments:
#   CKeys - control flags; only CKeys['Debug_ModelPack'] is read in this method (== 1 enables debug printing)
#   PKeys - hyperparameters, unpacked one by one below
# NOTE(review): both default to None but are subscripted unconditionally below,
# so constructing the model without passing them raises a TypeError.
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
# (at this point the snapshot holds CKeys and PKeys; it is reused by cast_model_parameters
# and by the config/state-dict persistence helpers)
self._locals = locals()
self._locals.pop('self', None)
self._locals.pop('__class__', None)
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# unload the parameters
# print('I am here')
if CKeys['Debug_ModelPack']==1:
print(json.dumps(PKeys, indent=4))
# Each entry is looked up with PKeys[...], so a missing key raises KeyError.
# default(...) supplies the fallback repeated in the trailing comment -
# presumably only when the stored value is None (confirm against the helper).
dim = PKeys['dim'] # ,
# image_embed_dim = 1024, # NOT used so far
text_embed_dim = default(PKeys['text_embed_dim'], 768) # 768, # get_encoded_dim(DEFAULT_T5_NAME)
num_resnet_blocks = default(PKeys['num_resnet_blocks'], 1)# 1,
cond_dim = default(PKeys['cond_dim'], None) # None,
num_image_tokens = default(PKeys['num_image_tokens'], 4) # 4,
num_time_tokens = default(PKeys['num_time_tokens'], 2) # 2,
learned_sinu_pos_emb_dim = default(PKeys['learned_sinu_pos_emb_dim'], 16) # 16,
out_dim = default(PKeys['out_dim'], None) # None,
dim_mults = default(PKeys['dim_mults'], (1, 2, 4, 8)) # (1, 2, 4, 8),
cond_images_channels = default(PKeys['cond_images_channels'], 0) # 0,
channels = default(PKeys['channels'], 3) # 3,
channels_out = default(PKeys['channels_out'], None) # None,
attn_dim_head = default(PKeys['attn_dim_head'], 64) # 64,
attn_heads = default(PKeys['attn_heads'], 8) # 8,
ff_mult = default(PKeys['ff_mult'], 2.) # 2.,
lowres_cond = default(PKeys['lowres_cond'], False) # False, # for cascading diffusion - https://cascaded-diffusion.github.io/
layer_attns = default(PKeys['layer_attns'], True) # True,
layer_attns_depth = default(PKeys['layer_attns_depth'], 1) # 1,
layer_attns_add_text_cond = default(PKeys['layer_attns_add_text_cond'], True) # True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
attend_at_middle = default(PKeys['attend_at_middle'], True) # True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
layer_cross_attns = default(PKeys['layer_cross_attns'], True) # True,
use_linear_attn = default(PKeys['use_linear_attn'], False) # False,
use_linear_cross_attn = default(PKeys['use_linear_cross_attn'], False) # False,
cond_on_text = default(PKeys['cond_on_text'], True) # True,
max_text_len = default(PKeys['max_text_len'], 256) # 256,
init_dim = default(PKeys['init_dim'], None) # None,
resnet_groups = default(PKeys['resnet_groups'], 8) # 8,
init_conv_kernel_size = default(PKeys['init_conv_kernel_size'], 7) # 7, # kernel size of initial conv, if not using cross embed
init_cross_embed = default(PKeys['init_cross_embed'], False) # False,
init_cross_embed_kernel_sizes = default(PKeys['init_cross_embed_kernel_sizes'], (3, 7, 15)) # (3, 7, 15),
cross_embed_downsample = default(PKeys['cross_embed_downsample'], False) # False,
cross_embed_downsample_kernel_sizes = default(PKeys['cross_embed_downsample_kernel_sizes'], (2,4)) # (2, 4),
attn_pool_text = default(PKeys['attn_pool_text'], True) # True,
attn_pool_num_latents = default(PKeys['attn_pool_num_latents'], 32) # 32,
dropout = default(PKeys['dropout'], 0.) # 0.,
memory_efficient = default(PKeys['memory_efficient'], False) # False,
init_conv_to_final_conv_residual = default(PKeys['init_conv_to_final_conv_residual'], False) # False,
use_global_context_attn = default(PKeys['use_global_context_attn'], True) # True,
scale_skip_connection = default(PKeys['scale_skip_connection'], True) # True,
final_resnet_block = default(PKeys['final_resnet_block'], True) # True,
final_conv_kernel_size = default(PKeys['final_conv_kernel_size'], 3) # 3,
cosine_sim_attn = default(PKeys['cosine_sim_attn'], False) # False,
self_cond = default(PKeys['self_cond'], False) # False,
combine_upsample_fmaps = default(PKeys['combine_upsample_fmaps'], False) # False, # combine feature maps from all upsample blocks, used in unet squared successfully
pixel_shuffle_upsample = default(PKeys['pixel_shuffle_upsample'], False) # False , # may address checkboard artifacts
beginning_and_final_conv_present = default(PKeys['beginning_and_final_conv_present'], True) # True , #TODO add cross-attn, doesnt work yet...whether or not to have final conv layer
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
# NOTE(review): num_image_tokens, out_dim, layer_attns_add_text_cond and dropout are
# unpacked above but are not referenced again anywhere in this method.
self.CKeys=CKeys
# + for debug
# NOTE(review): this dumps PKeys a second time (it was already printed before unpacking).
if CKeys['Debug_ModelPack']==1:
print("Check the inputs:")
print(json.dumps(PKeys, indent=4))
assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
if dim < 128:
print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
# determine dimensions
self.channels = channels
self.channels_out = default(channels_out, channels)
# (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
# (2) in self conditioning, one appends the predict x0 (x_start)
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
init_dim = default(init_dim, dim)
self.self_cond = self_cond
# optional image conditioning
# (the conditioning channels are concatenated onto the input, hence added to init_channels)
self.has_cond_image = cond_images_channels > 0
self.cond_images_channels = cond_images_channels
init_channels += cond_images_channels
self.beginning_and_final_conv_present=beginning_and_final_conv_present
# initial projection: multi-kernel CrossEmbedLayer (stride 1) when init_cross_embed is set, otherwise a single Conv1d
if self.beginning_and_final_conv_present:
self.init_conv = CrossEmbedLayer(
init_channels, dim_out = init_dim,
kernel_sizes = init_cross_embed_kernel_sizes,
stride = 1
) if init_cross_embed else nn.Conv1d(
init_channels, init_dim,
init_conv_kernel_size,
padding = init_conv_kernel_size // 2)
if self.CKeys['Debug_ModelPack']==1 and self.beginning_and_final_conv_present:
print("On self.init_conv:")
print(f"init_channels: {str(init_channels)}\n init_dim: {str(init_dim)}")
print("On self.init_conv, batch#=1, init_ch=2*#seq_channel, seq_len=128")
summary(self.init_conv, (1, init_channels, 128), verbose=1)
# # + for debug
# if self.CKeys['Debug_ModelPack']==1:
# print("On self.init_conv")
# summary(self.init_conv, (1, init_channels, init_dim))
# channel width at every resolution, and the (in, out) pair for each stage
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time conditioning
cond_dim = default(cond_dim, dim)
time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
# embedding time for log(snr) noise from continuous version
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
self.to_time_hiddens = nn.Sequential(
sinu_pos_emb,
nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
nn.SiLU()
)
# # + for debug
# if self.CKeys['Debug_ModelPack']==1:
# print("On self.to_time_hiddens")
# print(" in dim: ", learned_sinu_pos_emb_dim)
# print(" ou dim: ", time_cond_dim)
# # summary(self.to_time_hiddens, (1,learned_sinu_pos_emb_dim), verbose=1)
# # summary(sinu_pos_emb, (1,learned_sinu_pos_emb_dim), verbose=1)
self.to_time_cond = nn.Sequential(
nn.Linear(time_cond_dim, time_cond_dim)
)
# project to time tokens as well as time hiddens
# + for debug
if self.CKeys['Debug_ModelPack']==1:
print(" to_time_tokens")
print(" in dim: ", time_cond_dim) # =4xdim cond_dim=dim
print(" ou dim: ", num_time_tokens)
# the linear output is split into num_time_tokens tokens of width cond_dim each
self.to_time_tokens = nn.Sequential(
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
# low res aug noise conditioning
# (a second, parallel set of time-embedding layers fed with the low-res noise times)
self.lowres_cond = lowres_cond
if lowres_cond:
self.to_lowres_time_hiddens = nn.Sequential(
LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
nn.SiLU()
)
self.to_lowres_time_cond = nn.Sequential(
nn.Linear(time_cond_dim, time_cond_dim)
)
self.to_lowres_time_tokens = nn.Sequential(
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
# normalizations
self.norm_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
# NOTE(review): self.text_cond_linear is only assigned inside the cond_on_text branch below
if cond_on_text: # only add a linear layer if cond_dim differs from text_embed_dim
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
if text_embed_dim != cond_dim:
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_cond_linear=True
else:
print ("Text conditioning is equatl to cond_dim - no linear layer used")
self.text_cond_linear=False
# finer control over whether to condition on text encodings
self.cond_on_text = cond_on_text
# attention pooling
self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2,
dim_head = attn_dim_head, heads = attn_heads,
num_latents = attn_pool_num_latents,
cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None
# for classifier free guidance
# (learned placeholders that forward() substitutes where the text conditioning is masked out)
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
# for non-attention based text conditioning at all points in the network where time is also conditioned
self.to_text_non_attn_cond = None
if cond_on_text:
self.to_text_non_attn_cond = nn.Sequential(
nn.LayerNorm(cond_dim),
nn.Linear(cond_dim, time_cond_dim),
nn.SiLU(),
nn.Linear(time_cond_dim, time_cond_dim)
)
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)
num_layers = len(in_out)
# resnet block klass
# per-layer settings: each value is passed through cast_tuple with num_layers so it can be zipped with in_out
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
resnet_groups = cast_tuple(resnet_groups, num_layers)
resnet_klass = partial(ResnetBlock, **attn_kwargs)
layer_attns = cast_tuple(layer_attns, num_layers)
layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
use_linear_attn = cast_tuple(use_linear_attn, num_layers)
use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
# downsample klass
downsample_klass = Downsample
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# initial resnet block (for memory efficient unet)
self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
# scale for resnet skip connections
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
reversed_layer_params = list(map(reversed, layer_params))
# downsampling layers
# each self.downs entry is [pre_downsample, init resnet block, list of resnet blocks, attention block, post_downsample]
skip_connect_dims = [] # keep track of skip connection dimensions
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
# full attention takes precedence over linear attention; Identity when neither is requested
if layer_attn:
transformer_block_klass = TransformerBlock
elif layer_use_linear_attn:
transformer_block_klass = LinearAttentionTransformerBlock
else:
transformer_block_klass = Identity
current_dim = dim_in
# whether to pre-downsample, from memory efficient unet
pre_downsample = None
if memory_efficient:
pre_downsample = downsample_klass(dim_in, dim_out)
current_dim = dim_out
skip_connect_dims.append(current_dim)
# whether to do post-downsample, for non-memory efficient unet
# (the last stage uses a Parallel pair of kernel-3 and kernel-1 convs instead of the downsample class)
post_downsample = None
if not memory_efficient:
post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1))
self.downs.append(
nn.ModuleList([
pre_downsample,
resnet_klass(
current_dim, current_dim,
cond_dim = layer_cond_dim,
linear_attn = layer_use_linear_cross_attn,
time_cond_dim = time_cond_dim, groups = groups
),
nn.ModuleList([
ResnetBlock(
current_dim, current_dim,
time_cond_dim = time_cond_dim,
groups = groups,
use_gca = use_global_context_attn
) for _ in range(layer_num_resnet_blocks)]),
transformer_block_klass(
dim = current_dim,
depth = layer_attn_depth,
ff_mult = ff_mult,
context_dim = cond_dim,
**attn_kwargs),
post_downsample
])
)
# middle layers
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
# upsample klass
upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# upsampling layers
# stages and per-layer settings are walked in reverse order;
# each self.ups entry is [init resnet block, list of resnet blocks, attention block, upsample]
upsample_fmap_dims = []
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
is_last = ind == (len(in_out) - 1)
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
if layer_attn:
transformer_block_klass = TransformerBlock
elif layer_use_linear_attn:
transformer_block_klass = LinearAttentionTransformerBlock
else:
transformer_block_klass = Identity
skip_connect_dim = skip_connect_dims.pop()
upsample_fmap_dims.append(dim_out)
# input width is dim_out + skip_connect_dim because forward() concatenates a skip connection before each resnet block
self.ups.append(nn.ModuleList([
resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
]))
# whether to combine feature maps from all upsample blocks before final resnet block out
self.upsample_combiner = UpsampleCombiner(
dim = dim,
enabled = combine_upsample_fmaps,
dim_ins = upsample_fmap_dims,
dim_outs = dim
)
# whether to do a final residual from initial conv to the final resnet block out
self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
# final optional resnet block and convolution out
self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
final_conv_dim_in = dim if final_resnet_block else final_conv_dim
final_conv_dim_in += (channels if lowres_cond else 0)
# NOTE(review): the print below is not under a Debug_ModelPack check, unlike the other diagnostics in this method
if self.beginning_and_final_conv_present:
print (final_conv_dim_in, self.channels_out)
self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
if self.beginning_and_final_conv_present:
zero_init_(self.final_conv)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return a unet whose settings match the requested ones.

    If the five requested values already agree with this instance it is
    returned as-is. Otherwise a new instance of the same class is built
    from the stored constructor arguments, with PKeys updated through
    modify_keys to carry the requested values.
    """
    settings_match = (
        lowres_cond == self.lowres_cond
        and channels == self.channels
        and cond_on_text == self.cond_on_text
        and text_embed_dim == self._locals['PKeys']['text_embed_dim']
        and channels_out == self.channels_out
    )
    if settings_match:
        return self

    # the values that have to be overwritten in the stored hyperparameters
    overrides = {
        'lowres_cond': lowres_cond,
        'text_embed_dim': text_embed_dim,
        'channels': channels,
        'channels_out': channels_out,
        'cond_on_text': cond_on_text,
    }
    stored_pkeys = self._locals['PKeys']
    refreshed_pkeys = modify_keys(stored_pkeys, overrides)

    # rebuild with the original constructor arguments, replacing CKeys / PKeys
    rebuild_kwargs = {**self._locals}
    rebuild_kwargs.update(CKeys = self.CKeys, PKeys = refreshed_pkeys)
    return self.__class__(**rebuild_kwargs)
# methods for returning the full unet config as well as its parameter state
def to_config_and_state_dict(self):
    """Return (config, state_dict): the stored constructor arguments and the module's state dict."""
    config = self._locals
    weights = self.state_dict()
    return config, weights
# class method for rehydrating the unet from its config and state dict
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Construct klass(**config), load state_dict into it and return the instance."""
    restored = klass(**config)
    restored.load_state_dict(state_dict)
    return restored
# methods for persisting unet to disk
def persist_to_file(self, path):
    """Write the unet's config and state dict to `path` with torch.save.

    The parent directory of `path` is created first if it does not exist.
    The saved package is a dict with the keys 'config' and 'state_dict'.
    """
    target = Path(path)
    target.parents[0].mkdir(parents = True, exist_ok = True)
    config, weights = self.to_config_and_state_dict()
    torch.save({'config': config, 'state_dict': weights}, str(target))
# class method for rehydrating the unet from file saved with `persist_to_file`
@classmethod
def hydrate_from_file(klass, path):
    """Rebuild a unet from a file written by `persist_to_file`.

    Args:
        path: location of the saved package (str or Path).

    Returns:
        An instance of `klass` constructed from the stored config with the
        stored state dict loaded.

    Raises:
        AssertionError: if `path` does not exist or the package lacks the
            'config' / 'state_dict' entries.
    """
    path = Path(path)
    assert path.exists(), f'no saved unet found at {path}'
    pkg = torch.load(str(path))
    assert 'config' in pkg and 'state_dict' in pkg, 'saved package must contain "config" and "state_dict"'
    config, state_dict = pkg['config'], pkg['state_dict']
    # dispatch through klass: the previous hard-coded `Unet` name is not this class,
    # so it either failed with NameError or rebuilt a different model type
    return klass.from_config_and_state_dict(config, state_dict)
# forward with classifier free guidance
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Forward pass with classifier-free guidance.

    With cond_scale == 1 this is a single plain forward pass. Otherwise a
    second pass is run with cond_drop_prob = 1. and the result is
    null + (conditioned - null) * cond_scale.
    """
    conditioned = self.forward(*args, **kwargs)
    if cond_scale != 1:
        unconditioned = self.forward(*args, cond_drop_prob = 1., **kwargs)
        guidance = (conditioned - unconditioned) * cond_scale
        return unconditioned + guidance
    return conditioned
# forward pass: returns the network output x (the tensor after the last stage), not a loss
# Here, for Model B:
# x : (batch, 1, seq_len) values: from noised image/Y
# time : (batch) values:
# text_embeds : None
# text_mask: None
# cond_images: input sequence/X
# lowres_noise_times: None
# lowres_cond_img: None
# cond_drop_prob: 0.1
# self_cond: from noised image/Y
def forward(
self,
x,
time,
*,
lowres_cond_img = None,
lowres_noise_times = None,
text_embeds = None,
text_mask = None,
cond_images = None,
self_cond = None,
cond_drop_prob = 0.
):
# Forward pass of the U-Net: concatenates the optional conditioning inputs onto x,
# builds the time / text conditioning (t and c), runs the down, middle and up stages
# and returns the resulting tensor x.
# Everything except x and time is keyword-only.
# The shape hints inside the debug strings (e.g. "[batch, channels, seq_len]") are the
# original author's annotations; they are printed, not checked.
# `ii` is only a running counter used to number the debug printouts.
# + for debug
if self.CKeys['Debug_ModelPack']==1:
print("========================================")
print("Here are OneD_Unet:forward")
ii=0
batch_size, device = x.shape[0], x.device
# condition on self
# + for debug
# NOTE(review): the `== None` comparisons in this debug block would be more robust written as `is None`
if self.CKeys['Debug_ModelPack']==1:
print("Check inputs: ")
print(" x-0 .dim: ", x.shape, f"[batch, {str(self.channels)}, seq_len]")
print(" time .dim: ", time.shape, "batch")
if lowres_cond_img==None:
print(" lowres_cond_img: None")
else:
print(" lowres_cond_img.dim: ", lowres_cond_img.shape)
if lowres_noise_times==None:
print(" lowres_noise_times: None")
else:
print(" lowres_noise_times.dim: ", lowres_noise_times.shape)
if cond_images==None:
print(" cond_images dim: None")
else:
# Model B is used
print(" cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, seq_len]")
if text_embeds==None:
print(" text_embeds.dim: None")
else:
# Model A is used
print(" text_embeds.dim: ", text_embeds.shape)
if self_cond==None:
print(" self_cond: None")
else:
print(" self_cond.dim: ", self_cond.shape)
print("\n\n")
# self conditioning: a missing self_cond falls back to zeros shaped like x, then it is concatenated on dim 1
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x, self_cond), dim = 1)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("self_cond dim: ", self_cond.shape)
print("After cat(x, self_cond)-> x. dim: ", x.shape)
# add low resolution conditioning, if present
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("lowres_cond_img dim: ", lowres_cond_img.shape)
print("After cat(x, lowres_cond_img) dim: ", x.shape)
# condition on input image
# (the XOR makes it an error to pass cond_images without configuring them, or the other way round)
assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'
if exists(cond_images):
assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
cond_images = resize_image_to(cond_images, x.shape[-1])
# ++++++++++++++++++++++++++++++
# add cond_images (from X) into generation process...
# note: the conditioning channels are placed in front of x
x = torch.cat((cond_images.to(device), x.to(device)), dim = 1)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, max_seq_len]")
print("After cat(cond_images, x), x dim: ", x.shape, "[batch, cond_images_channels+images_channels, max_seq_len]")
# initial convolution
if self.beginning_and_final_conv_present:
x = self.init_conv(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("After self.init_conv(x)-> x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]")
# init conv residual
# (kept aside here and concatenated back just before the final resnet block)
if self.init_conv_to_final_conv_residual:
init_conv_residual = x.clone()
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("x.clone()->init_conv_resi, dim: ", init_conv_residual.shape)
# time conditioning
time_hiddens = self.to_time_hiddens(time)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("time dim: ", time.shape, "[batch]")
print("self.to_time_hiddens(time)-> time_hiddens .dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]")
# derive time tokens
# time_tokens become part of the conditioning tokens c; t is the conditioning passed to the resnet blocks
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("self.to_time_tokens(time_hiddens)-> time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]")
print("self.to_time_cond(time_hiddens)-> t dim: ", t.shape, "[batch, 4xUNet:dim]")
# add lowres time conditioning to time hiddens
# and add lowres time tokens along sequence dimension for attention
if self.lowres_cond:
lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("self.to_lowres_time_hiddens(lowres_noise_times)-> lowres_time_hiddens .dim: ", lowres_time_hiddens.shape)
print("self.to_lowres_time_tokens(lowres_time_hiddens)-> lowres_time_tokens .dim: ", lowres_time_tokens.shape)
print("self.to_lowres_time_cond(lowres_time_hiddens)-> lowres_t.dim: ", lowres_t.shape)
t = t + lowres_t
time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii +=1
print(ii)
print("After cat(time_tokens, lowres_time_tokens)-> time_tokens dim: ", time_tokens.shape)
# text conditioning
text_tokens = None
# + for debug
if self.CKeys['Debug_ModelPack']==1:
print("UNet.cond_on_text: ", self.cond_on_text)
if exists(text_embeds) and self.cond_on_text:
# conditional dropout
# one keep/drop decision per batch element, drawn with keep probability 1 - cond_drop_prob
text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')
# calculate text embeds
if self.text_cond_linear:
text_tokens = self.text_to_cond(text_embeds)
else:
text_tokens=text_embeds
if self.CKeys['Debug_ModelPack']==1:
print("On text conditioning part...")
ii +=1
print(ii)
print("text_embeds->text_tokens.dim: ", text_tokens.shape)
# truncate to max_text_len along dim 1, then pad back up to exactly max_text_len below
text_tokens = text_tokens[:, :self.max_text_len]
if self.CKeys['Debug_ModelPack']==1:
ii +=1
print(ii)
print("text_tokens[:,:max_text_len]-> text_tokens.dim: ", text_tokens.shape)
print("do the same for text_mask")
if exists(text_mask):
text_mask = text_mask[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask_embed = text_mask & text_keep_mask_embed
# positions where the keep mask is False are replaced by the learned null embedding
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
text_tokens = torch.where(
text_keep_mask_embed,
text_tokens,
null_text_embed
)
if exists(self.attn_pool):
text_tokens = self.attn_pool(text_tokens)
if self.CKeys['Debug_ModelPack']==1:
ii+=1
print(ii)
print("self.attn_pool(text_tokens)->text_tokens.dim: ", text_tokens.shape)
# extra non-attention conditioning by projecting and then summing text embeddings to time
# termed as text hiddens
mean_pooled_text_tokens = text_tokens.mean(dim = -2)
if self.CKeys['Debug_ModelPack']==1:
ii +=1
print(ii)
print("text_tokens.mean(dim=-2)->mean_pooled_text_tokens.dim: ", mean_pooled_text_tokens.shape)
text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)
if self.CKeys['Debug_ModelPack']==1:
ii +=1
print(ii)
print("self.to_text_non_attn_cond(mean_pooled_text_tokens)->text_hiddens.dim: ",text_hiddens.shape)
null_text_hidden = self.null_text_hidden.to(t.dtype)
text_hiddens = torch.where(
text_keep_mask_hidden,
text_hiddens,
null_text_hidden
)
t = t + text_hiddens
if self.CKeys['Debug_ModelPack']==1:
ii +=1
print(ii)
print("t+text_hiddens.dim: ", t.shape)
# main conditioning tokens (c)
c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("cat(time_tokens, text_tokens)-> c dim: ", c.shape)
# normalize conditioning tokens
c = self.norm_cond(c)
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("self.norm_cond(c)->c dim:", c.shape)
# initial resnet block (for memory efficient unet)
if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(ii)
print("self.init_resnet_block(x,t)-> x dim: ", x.shape)
# go through the layers of the unet, down and up
# + for debug
if self.CKeys['Debug_ModelPack']==1:
print("Before unet, down and up, ")
print(" x dim: ", x.shape)
print(" t dim: ", t.shape)
print(" c dim: ", c.shape)
# ii=0
# hiddens stores the skip connections that the up path pops again;
# presumably one entry per resnet block plus one after the attention block of every stage
# (the nesting cannot be confirmed from this copy because its indentation is lost)
hiddens = []
for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, pre_downsample(x)=>x dim: ", x.shape)
x = init_block(x, t, c)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, init_block(x,t,c)=>x dim: ", x.shape)
for resnet_block in resnet_blocks:
x = resnet_block(x, t)
hiddens.append(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape)
x = attn_block(x, c)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape)
hiddens.append(x)
if exists(post_downsample):
x = post_downsample(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, post_downsample(x)=> x dim: ", x.shape)
# middle of the unet: resnet block, optional attention, resnet block
x = self.mid_block1(x, t, c)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, mid_block_1(x,t,c)=> x dim: ", x.shape)
if exists(self.mid_attn):
x = self.mid_attn(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, mid_attn(x)=> x dim: ", x.shape)
x = self.mid_block2(x, t, c)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, mid_block_2(x,t,c)=> x dim: ", x.shape)
# pops the most recent skip connection, scales it by skip_connect_scale and concatenates it on dim 1
add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
# up_hiddens collects the output of each up stage for the optional upsample combiner
up_hiddens = []
for init_block, resnet_blocks, attn_block, upsample in self.ups:
x = add_skip_connection(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape)
#
x = init_block(x, t, c)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, init_block(x,t,c)=> x dim: ", x.shape)
for resnet_block in resnet_blocks:
x = add_skip_connection(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape)
x = resnet_block(x, t)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape)
x = attn_block(x, c)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape)
up_hiddens.append(x.contiguous())
x = upsample(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, upsample(x)=> x dim: ", x.shape)
# whether to combine all feature maps from upsample blocks
x = self.upsample_combiner(x, up_hiddens)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, upsample_combiner(x,up_hiddens)=> x dim: ", x.shape)
# final top-most residual if needed
if self.init_conv_to_final_conv_residual:
x = torch.cat((x, init_conv_residual), dim = 1)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, cat(x,init_conv_residual)=> x dim: ", x.shape)
if exists(self.final_res_block):
x = self.final_res_block(x, t)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, final_res_block(x,t)=> x dim: ", x.shape)
# the low resolution conditioning is concatenated once more right before the output convolution
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, cat(x,lowres_cond_img)=> x dim: ", x.shape)
if self.beginning_and_final_conv_present:
x=self.final_conv(x)
# + for debug
if self.CKeys['Debug_ModelPack']==1:
ii += 1
print(F" {str(ii)}, final_conv(x)=> x dim: ", x.shape)
return x
########################################################
## 1D Unet
########################################################
class OneD_Unet_Old(nn.Module):
def __init__(
    self,
    *,
    dim,
    # image_embed_dim = 1024, # NOT used so far
    text_embed_dim = 768,  # get_encoded_dim(DEFAULT_T5_NAME)
    num_resnet_blocks = 1,
    cond_dim = None,
    num_image_tokens = 4,
    num_time_tokens = 2,
    learned_sinu_pos_emb_dim = 16,
    out_dim = None,
    dim_mults=(1, 2, 4, 8),
    cond_images_channels = 0,
    channels = 3,
    channels_out = None,
    attn_dim_head = 64,
    attn_heads = 8,
    ff_mult = 2.,
    lowres_cond = False,  # for cascading diffusion - https://cascaded-diffusion.github.io/
    layer_attns = True,
    layer_attns_depth = 1,
    layer_attns_add_text_cond = True,  # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
    attend_at_middle = True,  # whether to have a layer of attention at the bottleneck
    layer_cross_attns = True,
    use_linear_attn = False,
    use_linear_cross_attn = False,
    cond_on_text = True,
    max_text_len = 256,
    init_dim = None,
    resnet_groups = 8,
    init_conv_kernel_size = 7,  # kernel size of initial conv, if not using cross embed
    init_cross_embed = False,
    init_cross_embed_kernel_sizes = (3, 7, 15),
    cross_embed_downsample = False,
    cross_embed_downsample_kernel_sizes = (2, 4),
    attn_pool_text = True,
    attn_pool_num_latents = 32,
    dropout = 0.,
    memory_efficient = False,
    init_conv_to_final_conv_residual = False,
    use_global_context_attn = True,
    scale_skip_connection = True,
    final_resnet_block = True,
    final_conv_kernel_size = 3,
    cosine_sim_attn = False,
    self_cond = False,
    combine_upsample_fmaps = False,  # combine feature maps from all upsample blocks
    pixel_shuffle_upsample = False,  # may address checkboard artifacts
    beginning_and_final_conv_present = True,  # whether to build init_conv / final_conv
    CKeys=None,
):
    """Build the 1D U-Net: conditioning embedders, down path, bottleneck, up path.

    All arguments are keyword-only and are snapshotted into ``self._locals``
    so the network can be rebuilt by ``cast_model_parameters`` and
    ``from_config_and_state_dict``.

    Args (those whose role is visible in this constructor):
        dim: base channel width; per-level widths are ``dim * m`` for each
            ``m`` in ``dim_mults``.
        text_embed_dim: width of incoming text embeddings; a linear
            projection to ``cond_dim`` is added only when the two differ.
        cond_dim: width of conditioning tokens (defaults to ``dim``).
        cond_images_channels: extra input channels for a conditioning image.
        channels / channels_out: input and output channel counts
            (``channels_out`` defaults to ``channels``).
        lowres_cond / self_cond: each adds ``channels`` more input channels.
        memory_efficient: downsample before the blocks of each level instead
            of after them.
        beginning_and_final_conv_present: when False neither ``init_conv``
            nor ``final_conv`` is created.
        CKeys: control dictionary; ``CKeys['Debug_ModelPack'] == 1`` enables
            verbose prints. ``None`` is treated as debugging switched off.

    ``num_image_tokens``, ``out_dim``, ``layer_attns_add_text_cond`` and
    ``dropout`` are accepted and stored in the config but not otherwise used
    in this constructor.
    """
    super().__init__()

    # The original subscripted CKeys unconditionally, so the default of None
    # raised TypeError. Substitute a "debug off" dictionary instead; this is
    # also what `forward` reads through self.CKeys.
    if CKeys is None:
        CKeys = {'Debug_ModelPack': 0}
    self.CKeys = CKeys

    assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

    if dim < 128:
        print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

    # save locals to take care of some hyperparameters for cascading DDPM.
    # Must run before any helper local is defined, otherwise that local would
    # leak into the config and break `klass(**config)`.
    self._locals = dict(locals())
    self._locals.pop('self', None)
    self._locals.pop('__class__', None)

    debug = CKeys['Debug_ModelPack'] == 1
    if debug:
        print("Showing the input:")
        print(json.dumps(self._locals, indent=4))

    # determine dimensions
    self.channels = channels
    self.channels_out = default(channels_out, channels)

    # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
    # (2) in self conditioning, one appends the predict x0 (x_start)
    init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
    init_dim = default(init_dim, dim)
    self.self_cond = self_cond

    # optional image conditioning
    self.has_cond_image = cond_images_channels > 0
    self.cond_images_channels = cond_images_channels
    init_channels += cond_images_channels

    self.beginning_and_final_conv_present = beginning_and_final_conv_present
    if self.beginning_and_final_conv_present:
        if init_cross_embed:
            self.init_conv = CrossEmbedLayer(
                init_channels, dim_out = init_dim,
                kernel_sizes = init_cross_embed_kernel_sizes,
                stride = 1
            )
        else:
            self.init_conv = nn.Conv1d(
                init_channels, init_dim,
                init_conv_kernel_size,
                padding = init_conv_kernel_size // 2
            )

    dims = [init_dim, *(dim * m for m in dim_mults)]
    in_out = list(zip(dims[:-1], dims[1:]))

    # time conditioning
    cond_dim = default(cond_dim, dim)
    time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

    # embedding time for log(snr) noise from continuous version
    sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
    sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

    if debug:
        print(" to_time_hiddens")
        print(" ou dim: ", time_cond_dim)

    self.to_time_hiddens = nn.Sequential(
        sinu_pos_emb,
        nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
        nn.SiLU()
    )
    self.to_time_cond = nn.Sequential(
        nn.Linear(time_cond_dim, time_cond_dim)
    )

    # project to time tokens as well as time hiddens
    if debug:
        print(" to_time_tokens")
        print(" in dim: ", time_cond_dim)  # =4xdim cond_dim=dim
        print(" ou dim: ", num_time_tokens)

    self.to_time_tokens = nn.Sequential(
        nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
        Rearrange('b (r d) -> b r d', r = num_time_tokens)
    )

    # low res aug noise conditioning
    self.lowres_cond = lowres_cond
    if lowres_cond:
        self.to_lowres_time_hiddens = nn.Sequential(
            LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
            nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
            nn.SiLU()
        )
        self.to_lowres_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )
        self.to_lowres_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

    # normalizations
    self.norm_cond = nn.LayerNorm(cond_dim)

    # text encoding conditioning (optional); a linear projection is only
    # added when the text embedding width differs from cond_dim
    self.text_to_cond = None
    # always defined, so later reads cannot hit a missing attribute
    self.text_cond_linear = False
    if cond_on_text:
        assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
        if text_embed_dim != cond_dim:
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
            self.text_cond_linear = True
        else:
            print ("Text conditioning is equatl to cond_dim - no linear layer used")
            self.text_cond_linear = False

    # finer control over whether to condition on text encodings
    self.cond_on_text = cond_on_text

    # attention pooling
    self.attn_pool = PerceiverResampler(
        dim = cond_dim, depth = 2,
        dim_head = attn_dim_head, heads = attn_heads,
        num_latents = attn_pool_num_latents,
        cosine_sim_attn = cosine_sim_attn
    ) if attn_pool_text else None

    # for classifier free guidance
    self.max_text_len = max_text_len
    self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
    self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

    # for non-attention based text conditioning at all points in the network where time is also conditioned
    self.to_text_non_attn_cond = None
    if cond_on_text:
        self.to_text_non_attn_cond = nn.Sequential(
            nn.LayerNorm(cond_dim),
            nn.Linear(cond_dim, time_cond_dim),
            nn.SiLU(),
            nn.Linear(time_cond_dim, time_cond_dim)
        )

    # attention related params
    attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)
    num_layers = len(in_out)

    # per-level settings, one entry per resolution level
    num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
    resnet_groups = cast_tuple(resnet_groups, num_layers)
    resnet_klass = partial(ResnetBlock, **attn_kwargs)
    layer_attns = cast_tuple(layer_attns, num_layers)
    layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
    layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
    use_linear_attn = cast_tuple(use_linear_attn, num_layers)
    use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)

    assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

    # downsample klass
    downsample_klass = Downsample
    if cross_embed_downsample:
        downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

    # initial resnet block (for memory efficient unet)
    self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

    # scale for resnet skip connections
    self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

    # layers
    self.downs = nn.ModuleList([])
    self.ups = nn.ModuleList([])
    num_resolutions = len(in_out)

    layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
    reversed_layer_params = list(map(reversed, layer_params))

    # downsampling layers
    skip_connect_dims = []  # keep track of skip connection dimensions

    for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
        is_last = ind >= (num_resolutions - 1)
        layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

        if layer_attn:
            transformer_block_klass = TransformerBlock
        elif layer_use_linear_attn:
            transformer_block_klass = LinearAttentionTransformerBlock
        else:
            transformer_block_klass = Identity

        current_dim = dim_in

        # whether to pre-downsample, from memory efficient unet
        pre_downsample = None
        if memory_efficient:
            pre_downsample = downsample_klass(dim_in, dim_out)
            current_dim = dim_out

        skip_connect_dims.append(current_dim)

        # whether to do post-downsample, for non-memory efficient unet;
        # the last level keeps resolution and only changes channel count
        post_downsample = None
        if not memory_efficient:
            if is_last:
                post_downsample = Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1))
            else:
                post_downsample = downsample_klass(current_dim, dim_out)

        self.downs.append(
            nn.ModuleList([
                pre_downsample,
                resnet_klass(
                    current_dim, current_dim,
                    cond_dim = layer_cond_dim,
                    linear_attn = layer_use_linear_cross_attn,
                    time_cond_dim = time_cond_dim, groups = groups
                ),
                nn.ModuleList([
                    ResnetBlock(
                        current_dim, current_dim,
                        time_cond_dim = time_cond_dim,
                        groups = groups,
                        use_gca = use_global_context_attn
                    ) for _ in range(layer_num_resnet_blocks)
                ]),
                transformer_block_klass(
                    dim = current_dim,
                    depth = layer_attn_depth,
                    ff_mult = ff_mult,
                    context_dim = cond_dim,
                    **attn_kwargs
                ),
                post_downsample
            ])
        )

    # middle layers
    mid_dim = dims[-1]
    self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
    self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
    self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

    # upsample klass
    upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

    # upsampling layers (levels visited in reverse order)
    upsample_fmap_dims = []

    for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
        is_last = ind == (len(in_out) - 1)
        layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

        if layer_attn:
            transformer_block_klass = TransformerBlock
        elif layer_use_linear_attn:
            transformer_block_klass = LinearAttentionTransformerBlock
        else:
            transformer_block_klass = Identity

        skip_connect_dim = skip_connect_dims.pop()
        upsample_fmap_dims.append(dim_out)

        self.ups.append(nn.ModuleList([
            resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
            nn.ModuleList([
                ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn)
                for _ in range(layer_num_resnet_blocks)
            ]),
            transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
            upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
        ]))

    # whether to combine feature maps from all upsample blocks before final resnet block out
    self.upsample_combiner = UpsampleCombiner(
        dim = dim,
        enabled = combine_upsample_fmaps,
        dim_ins = upsample_fmap_dims,
        dim_outs = dim
    )

    # whether to do a final residual from initial conv to the final resnet block out
    self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
    final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)

    # final optional resnet block and convolution out
    self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

    final_conv_dim_in = dim if final_resnet_block else final_conv_dim
    final_conv_dim_in += (channels if lowres_cond else 0)

    if self.beginning_and_final_conv_present:
        print (final_conv_dim_in, self.channels_out)
        self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
        zero_init_(self.final_conv)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return a unet configured with the requested cascade settings.

    If the five requested settings already match this instance, the instance
    itself is returned. Otherwise a new instance of the same class is built
    from the saved constructor arguments with these five values overridden
    (the new instance does not inherit trained weights).
    """
    already_matching = (
        lowres_cond == self.lowres_cond
        and channels == self.channels
        and cond_on_text == self.cond_on_text
        and text_embed_dim == self._locals['text_embed_dim']
        and channels_out == self.channels_out
    )
    if not already_matching:
        overrides = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )
        rebuilt_kwargs = {**self._locals, **overrides}
        return self.__class__(**rebuilt_kwargs)
    return self
# methods for returning the full unet config as well as its parameter state
def to_config_and_state_dict(self):
    """Return ``(config, state_dict)``: saved constructor kwargs and current weights."""
    config = self._locals
    weights = self.state_dict()
    return config, weights
# class method for rehydrating the unet from its config and state dict
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Instantiate ``klass`` from ``config`` kwargs, load ``state_dict``, and return it."""
    rebuilt = klass(**config)
    rebuilt.load_state_dict(state_dict)
    return rebuilt
# methods for persisting unet to disk
def persist_to_file(self, path):
    """Write this unet's config and state dict to ``path`` with ``torch.save``.

    Missing parent directories are created. The saved package is a dict with
    the keys ``config`` and ``state_dict``.
    """
    target = Path(path)
    target.parents[0].mkdir(exist_ok = True, parents = True)

    config, state_dict = self.to_config_and_state_dict()
    package = dict(config = config, state_dict = state_dict)
    torch.save(package, str(target))
# class method for rehydrating the unet from file saved with `persist_to_file`
@classmethod
def hydrate_from_file(klass, path):
    """Rebuild a unet from a file written by ``persist_to_file``.

    Args:
        path: location of the saved package (a dict holding ``config`` and
            ``state_dict``).

    Returns:
        An instance of ``klass`` constructed from the saved config with the
        saved weights loaded.

    Raises:
        AssertionError: if the file is missing or the package lacks the
            expected keys.
    """
    path = Path(path)
    assert path.exists(), f'no saved unet found at {path}'

    pkg = torch.load(str(path))
    assert 'config' in pkg and 'state_dict' in pkg, 'saved package must contain `config` and `state_dict`'

    config, state_dict = pkg['config'], pkg['state_dict']

    # Dispatch on the receiving class. The original called
    # `Unet.from_config_and_state_dict`, but `Unet` in this file is a
    # placeholder module that has no such method.
    return klass.from_config_and_state_dict(config, state_dict)
# forward with classifier free guidance
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Run ``forward`` with classifier-free guidance.

    With ``cond_scale == 1`` the plain forward output is returned. Otherwise
    a second pass is run with ``cond_drop_prob = 1.`` and the result is
    ``null + (cond - null) * cond_scale``.
    """
    conditioned = self.forward(*args, **kwargs)

    if cond_scale != 1:
        unconditioned = self.forward(*args, cond_drop_prob = 1., **kwargs)
        guidance = (conditioned - unconditioned) * cond_scale
        return unconditioned + guidance

    return conditioned
# forward pass: returns the network output x (not a loss)
# Here, for Model B:
# x : (batch, 1, seq_len) values: from noised image/Y
# time : (batch) values:
# text_embeds : None
# text_mask: None
# cond_images: input sequence/X
# lowres_noise_times: None
# lowres_cond_img: None
# cond_drop_prob: 0.1
# self_cond: from noised image/Y
def forward(
    self,
    x,
    time,
    *,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    self_cond = None,
    cond_drop_prob = 0.
):
    """Run the U-Net on a noised input and return the network output.

    Args:
        x: noised input; channels are on dim 1 (the file's own notes describe
            it as ``(batch, 1, seq_len)`` for "Model B").
        time: per-sample noise/time conditioning values.
        lowres_cond_img, lowres_noise_times: required when the unet was built
            with ``lowres_cond``; the image is concatenated on dim 1.
        text_embeds, text_mask: optional text conditioning, used only when
            ``self.cond_on_text`` is set.
        cond_images: conditioning image; must be supplied if and only if the
            unet was built with ``cond_images_channels > 0``.
        self_cond: self-conditioning input; zeros are used when omitted and
            the unet was built with ``self_cond``.
        cond_drop_prob: probability of dropping the text conditioning for a
            sample (classifier-free guidance).

    Returns:
        The tensor produced by the final block (``final_conv`` when present).
    """
    # Guard against a missing control dict and against printing the shape of
    # an absent conditioning image; the original debug code crashed on both.
    debug = exists(self.CKeys) and self.CKeys['Debug_ModelPack'] == 1
    step = 0

    def trace(label, tensor):
        # numbered shape print, emitted only in debug mode
        nonlocal step
        if debug:
            step += 1
            print(F" {step}, {label} x dim: ", tensor.shape)

    if debug:
        print("========================================")
        print("Here are OneD_Unet:forward")

    batch_size, device = x.shape[0], x.device

    if debug:
        print("Check inputs: ")
        print(" x-0 dim: ", x.shape, "[batch, 1, seq_len]")
        print(" time dim: ", time.shape, "")
        print(" cond_images dim: ", cond_images.shape if exists(cond_images) else None, "")

    # condition on self
    if self.self_cond:
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x, self_cond), dim = 1)
        if debug:
            print("After self_cond x dim: ", x.shape)

    # add low resolution conditioning, if present
    assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)
        if debug:
            print("After lowres_cond_img x dim: ", x.shape)

    # condition on input image
    assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
        cond_images = resize_image_to(cond_images, x.shape[-1])
        # add cond_images (from X) into generation process...
        x = torch.cat((cond_images.to(device), x.to(device)), dim = 1)
        if debug:
            print("cond_images dim: ", cond_images.shape, "[batch, 1, max_seq_len]")
            print("After cond_images, x dim: ", x.shape, "[batch, 2, max_seq_len]")

    # initial convolution
    if self.beginning_and_final_conv_present:
        x = self.init_conv(x)
        if debug:
            print("After init_conv, x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]")

    # init conv residual
    if self.init_conv_to_final_conv_residual:
        init_conv_residual = x.clone()

    # time conditioning
    time_hiddens = self.to_time_hiddens(time)
    if debug:
        print("time dim: ", time.shape, "[batch]")
        print("after, time_hiddens dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]")

    # derive time tokens
    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)
    if debug:
        print("time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]")
        print("after to_time_cond t dim: ", t.shape, "[batch, 4xUNet:dim]")

    # add lowres time conditioning to time hiddens
    # and add lowres time tokens along sequence dimension for attention
    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t
        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)
        if debug:
            print("After lowres_cond, time_tokens dim: ", time_tokens.shape)

    # text conditioning
    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:
        # conditional dropout
        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # calculate text embeds (projected only when widths differ)
        if self.text_cond_linear:
            text_tokens = self.text_to_cond(text_embeds)
        else:
            text_tokens = text_embeds

        # truncate, then right-pad, to exactly max_text_len tokens
        text_tokens = text_tokens[:, :self.max_text_len]
        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value = False)

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        null_text_embed = self.null_text_embed.to(text_tokens.dtype)  # for some reason pytorch AMP not working

        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # extra non-attention conditioning by projecting and then summing text embeddings to time
        # termed as text hiddens
        mean_pooled_text_tokens = text_tokens.mean(dim = -2)
        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # main conditioning tokens (c)
    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)
    if debug:
        print("Merge time and text tokens, c dim: ", c.shape)

    # normalize conditioning tokens
    c = self.norm_cond(c)

    # initial resnet block (for memory efficient unet)
    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t)

    # go through the layers of the unet, down and up
    if debug:
        print("Before unet, down and up, ")
        print("x dim: ", x.shape)
        print("t dim: ", t.shape)
        print("c dim: ", c.shape)

    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)
            trace("after pre_downsample", x)

        x = init_block(x, t, c)
        trace("after init_block(x,t,c)", x)

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)
            trace("after resnet_block", x)

        x = attn_block(x, c)
        trace("after attn_block", x)
        hiddens.append(x)

        if exists(post_downsample):
            x = post_downsample(x)
            trace("after post_downsample", x)

    x = self.mid_block1(x, t, c)
    trace("after mid_block_1", x)

    if exists(self.mid_attn):
        x = self.mid_attn(x)
        trace("after mid_attn", x)

    x = self.mid_block2(x, t, c)
    trace("after mid_block_2", x)

    def add_skip_connection(fmap):
        # concatenate the most recent stored hidden state (scaled) on the channel dim
        return torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)

    up_hiddens = []

    for init_block, resnet_blocks, attn_block, upsample in self.ups:
        x = add_skip_connection(x)
        trace("after add_skip_connection", x)

        x = init_block(x, t, c)
        trace("after init_block(x,t,c)", x)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            trace("after add_skip_connection", x)
            x = resnet_block(x, t)
            trace("after resnet_block(x,t)", x)

        x = attn_block(x, c)
        trace("after attn_block(x,c)", x)

        up_hiddens.append(x.contiguous())
        x = upsample(x)
        trace("after upsample(x)", x)

    # whether to combine all feature maps from upsample blocks
    x = self.upsample_combiner(x, up_hiddens)
    trace("after upsample_combiner(x,..)", x)

    # final top-most residual if needed
    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim = 1)
        trace("after cat_init_conv_resi", x)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t)
        trace("after final_res_block(x,t)", x)

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)
        trace("after cat_x_lowres_cond_img", x)

    if self.beginning_and_final_conv_present:
        x = self.final_conv(x)
        trace("after final_conv(x)", x)

    return x
########################################################
## null unets
########################################################
class NullUnet(nn.Module):
    """Placeholder unet: ignores all configuration and returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        # Single zero-valued parameter so the module is not parameter-free.
        # NOTE(review): the attribute name looks like a botched search/replace
        # of "dummy_parameter"; it is kept as-is because it is a state_dict key.
        self.dummy_paramcast_model_parameterseter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *args, **kwargs):
        """Nothing to reconfigure; always returns this instance."""
        return self

    def forward(self, x, *args, **kwargs):
        """Identity: return ``x`` and ignore every other argument."""
        return x
class Unet(nn.Module):
    """Placeholder module named ``Unet``: identity forward, no real layers."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        # single zero-valued parameter so the module is not parameter-free
        self.dummy_parameter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *args, **kwargs):
        """Nothing to reconfigure; always returns this instance."""
        return self

    def forward(self, x, *args, **kwargs):
        """Identity: return ``x`` and ignore every other argument."""
        return x
########################################################
## Elucidated denoising model
## After: Tero Karras and Miika Aittala and Timo Aila and Samuli Laine,
## Elucidating the Design Space of Diffusion-Based Generative Models
## https://arxiv.org/abs/2206.00364, 2022
########################################################
from math import sqrt
# Names of the per-unet sampler/training hyperparameters, in the order in
# which ElucidatedImagen packs them into an Hparams record.
Hparams_fields = [
    'num_sample_steps',  # number of sampling steps
    'sigma_min',         # min noise level
    'sigma_max',         # max noise level
    'sigma_data',        # standard deviation of data distribution
    'rho',               # controls the sampling schedule
    'P_mean',            # mean of the log-normal training-noise distribution
    'P_std',             # std of the log-normal training-noise distribution
    'S_churn',           # stochastic sampling parameters
    'S_tmin',
    'S_tmax',
    'S_noise',
]

# immutable record type holding one value per field above
Hparams = namedtuple(typename = 'Hparams', field_names = Hparams_fields)
# helper functions
def log(t, eps = 1e-20):
    """Natural log of ``t`` after clamping its values to at least ``eps``."""
    floored = t.clamp(min = eps)
    return torch.log(floored)
# main class
class ElucidatedImagen(nn.Module):
# Constructor: stores loss / conditioning options, adapts every supplied unet
# through its `cast_model_parameters`, and records one Hparams tuple per unet.
# NOTE(review): the source indentation of this file has been lost, so the
# nesting of a few statements below cannot be confirmed from the text alone.
def __init__(
self,
unets,
*,
image_sizes, # for cascading ddpm, image size at each stage
text_encoder_name = '', # can be "DEFAULT_T5_NAME"; not read anywhere in this constructor
text_embed_dim = None,
channels = 3,
channels_out=3,
cond_drop_prob = 0.1,
random_crop_sizes = None,
lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
condition_on_text = True,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
only_train_unet_number = None,
lowres_noise_schedule = 'linear',
num_sample_steps = 32, # number of sampling steps
sigma_min = 0.002, # min noise level
sigma_max = 80, # max noise level
sigma_data = 0.5, # standard deviation of data distribution
rho = 7, # controls the sampling schedule
P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in paper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
loss_type=0, #0=MSE,
categorical_loss_ignore=None,
# +++++++++++++++++
# device=None,
CKeys=None,
PKeys=None,
):
super().__init__()
# +++++++++++++++++++++++++++
# self.device=device
# control / parameter dictionaries are stored as passed; they are not read in this constructor
self.CKeys=CKeys
self.PKeys=PKeys
self.only_train_unet_number = only_train_unet_number
self.condition_on_text = condition_on_text
self.unconditional = not condition_on_text
# any loss_type > 0 enables the categorical-loss flag and creates a LogSoftmax over dim 1
self.loss_type=loss_type
if self.loss_type>0:
self.categorical_loss=True
self.m = nn.LogSoftmax(dim=1) # used for some loss functions
else:
self.categorical_loss=False
# NOTE(review): unclear whether this print sits inside the else branch or after the if/else
print("Loss type: ", self.loss_type)
self.categorical_loss_ignore=categorical_loss_ignore
# channels
self.channels = channels
self.channels_out = channels_out
unets = cast_tuple(unets)
num_unets = len(unets)
# randomly cropping for upsampler training
self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
# lowres augmentation noise schedule
self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)
# no text encoder is built here; only the embedding width is stored
self.text_embed_dim =text_embed_dim
# construct unets
self.unets = nn.ModuleList([])
self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
print (f"Channels in={self.channels}, channels out={self.channels_out}")
for ind, one_unet in enumerate(unets):
# # ---------------------------
# assert isinstance(one_unet, ( OneD_Unet, NullUnet))
# ++++++++++++++++++++++++++++
is_first = ind == 0
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
#
# + for debug
# print('I am here')
# the prints below are unconditional: they echo the arguments passed to cast_model_parameters
print("Test on cast_model_parameters...")
print(not is_first)
print(self.condition_on_text)
print(self.text_embed_dim if self.condition_on_text else None)
print(self.channels)
print(self.channels_out)
# every unet except the first is requested with lowres_cond = True
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_text = self.condition_on_text,
text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
channels = self.channels,
#channels_out = self.channels
channels_out = self.channels_out
)
self.unets.append(one_unet)
# determine whether we are training on images or video
is_video = False # only consider image case
self.is_video = is_video
self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1' if not is_video else 'b -> b 1 1 1'))
self.resize_to = resize_video_to if is_video else resize_image_to
# NOTE(review): image_sizes is stored as passed and must support len(); a bare int would fail the check below
self.image_sizes = image_sizes
assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'
self.sample_channels = cast_tuple(self.channels, num_unets)
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.lowres_sample_noise_level = lowres_sample_noise_level
self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
self.input_image_range = (0. if auto_normalize_img else -1., 1.)
# dynamic thresholding
self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
# # temporal interpolations
# temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
# self.temporal_downsample_factor = temporal_downsample_factor
# self.resize_cond_video_frames = resize_cond_video_frames
# self.temporal_downsample_divisor = temporal_downsample_factor[0]
# assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
# assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factor must be in order of descending'
# elucidating parameters, listed in the same order as Hparams_fields
hparams = [
num_sample_steps,
sigma_min,
sigma_max,
sigma_data,
rho,
P_mean,
P_std,
S_churn,
S_tmin,
S_tmax,
S_noise,
]
# presumably cast_tuple broadcasts a scalar to one entry per unet - TODO confirm against cast_tuple
hparams = [cast_tuple(hp, num_unets) for hp in hparams]
# regroup per unet: one Hparams record for each position
self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)]
# one temp parameter for keeping track of device (read back by the `device` property)
# self.register_buffer('_temp', torch.tensor([0.]).to(device), persistent = False)
self.register_buffer('_temp', torch.tensor([0.]), persistent = False)
# default to device of unets passed in
self.to(next(self.unets.parameters()).device)
# for debug
print(next(self.unets.parameters()).device)
# print ("Device used in ImagenEluc: ", self.device)
def force_unconditional_(self):
    """Switch this model, in place, to unconditional operation.

    Sets ``unconditional`` to True, clears ``condition_on_text`` and turns
    off ``cond_on_text`` on every unet held in ``self.unets``.
    """
    self.unconditional = True
    self.condition_on_text = False

    for member in self.unets:
        member.cond_on_text = False
@property
def device(self):
    """Device on which the ``_temp`` tracking buffer currently lives."""
    tracker = self._temp
    return tracker.device
def get_unet(self, unet_number):
    """Return the unet with 1-based number ``unet_number`` and mark it as active.

    If ``self.unets`` is still an ``nn.ModuleList`` it is replaced by a plain
    python list holding the same unets. When the requested unet differs from
    ``self.unet_being_trained_index`` the requested unet is moved to
    ``self.device`` and every other unet to the CPU.
    """
    total = len(self.unets)
    assert 0 < unet_number <= total
    index = unet_number - 1

    if isinstance(self.unets, nn.ModuleList):
        plain_unets = list(self.unets)
        # the attribute is deleted first and then re-assigned as a python list
        delattr(self, 'unets')
        self.unets = plain_unets

    if self.unet_being_trained_index != index:
        for position, candidate in enumerate(self.unets):
            destination = self.device if position == index else 'cpu'
            candidate.to(destination)

    self.unet_being_trained_index = index
    return self.unets[index]
def reset_unets_all_one_device(self, device = None):
    """Gather all unets back into one ``nn.ModuleList`` on a single device.

    ``device`` falls back to ``self.device`` when not given. Afterwards
    ``unet_being_trained_index`` is reset to ``-1``.
    """
    target_device = default(device, self.device)

    gathered = nn.ModuleList([*self.unets])
    self.unets = gathered
    self.unets.to(target_device)

    self.unet_being_trained_index = -1
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    """Context manager that keeps only one unet on ``self.device``.

    Exactly one of ``unet_number`` (1-based index into ``self.unets``) or
    ``unet`` must be supplied. On entry every unet is moved to the CPU and
    the selected one to ``self.device``. On exit each unet is moved back to
    the device it was on before, including when the body of the ``with``
    block raises.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    # remember where each unet lives so the placement can be restored
    original_devices = [module_device(member) for member in self.unets]

    try:
        self.unets.cpu()
        unet.to(self.device)
        yield
    finally:
        # restore placement even if sampling inside the context failed
        for member, original_device in zip(self.unets, original_devices):
            member.to(original_device)
# overriding state dict functions
def state_dict(self, *args, **kwargs):
    """Return the parent class's state dict after regrouping the unets.

    ``reset_unets_all_one_device`` is called first; all arguments are passed
    through unchanged to the parent implementation.
    """
    self.reset_unets_all_one_device()
    collected = super().state_dict(*args, **kwargs)
    return collected
def load_state_dict(self, *args, **kwargs):
    """Load a state dict via the parent class after regrouping the unets.

    ``reset_unets_all_one_device`` is called first; all arguments are passed
    through unchanged and the parent's return value is returned.
    """
    self.reset_unets_all_one_device()
    load_result = super().load_state_dict(*args, **kwargs)
    return load_result
# dynamic thresholding
def threshold_x_start(self, x_start, dynamic_threshold = True):
    """Clamp a predicted ``x_start``, statically or with dynamic thresholding.

    With ``dynamic_threshold`` False the tensor is clamped to ``[-1, 1]``.
    Otherwise, per batch element, the ``self.dynamic_thresholding_percentile``
    quantile of the absolute values is taken (floored at 1), the tensor is
    clamped to ``[-s, s]`` and divided by ``s``.
    """
    if dynamic_threshold:
        flat_magnitudes = rearrange(x_start, 'b ... -> b (...)').abs()
        scale = torch.quantile(
            flat_magnitudes,
            self.dynamic_thresholding_percentile,
            dim = -1
        )
        scale.clamp_(min = 1.)
        scale = right_pad_dims_to(x_start, scale)
        return x_start.clamp(-scale, scale) / scale

    return x_start.clamp(-1., 1.)
# derived preconditioning params - Table 1
def c_skip(self, sigma_data, sigma):
    """Skip scaling: ``sigma_data^2 / (sigma^2 + sigma_data^2)``."""
    data_variance = sigma_data ** 2
    return data_variance / (sigma ** 2 + data_variance)
def c_out(self, sigma_data, sigma):
    """Output scaling: ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``."""
    total_variance = sigma_data ** 2 + sigma ** 2
    return sigma * sigma_data * total_variance ** -0.5
def c_in(self, sigma_data, sigma):
    """Input scaling: ``1 / sqrt(sigma^2 + sigma_data^2)``."""
    total_variance = sigma ** 2 + sigma_data ** 2
    return 1 * total_variance ** -0.5
def c_noise(self, sigma):
    """Noise-level conditioning value: a quarter of ``log(sigma)``."""
    log_sigma = log(sigma)
    return log_sigma * 0.25
# preconditioned network output
# equation (7) in the paper
def preconditioned_network_forward(
    self,
    unet_forward,
    noised_images,
    sigma,
    *,
    sigma_data,
    clamp = False,
    dynamic_threshold = True,
    **kwargs
):
    """Run the unet with input/output preconditioning (equation (7) in the paper).

    Args:
        unet_forward: callable invoked as ``unet_forward(scaled_input, c_noise, **kwargs)``.
        noised_images: noisy input tensor; its first dimension is the batch size.
        sigma: noise level, either a python float (expanded to one value per
            batch element) or a tensor.
        sigma_data: data standard deviation used by the ``c_*`` coefficients.
        clamp: when True the result is passed through ``threshold_x_start``.
        dynamic_threshold: forwarded to ``threshold_x_start``.
        **kwargs: extra keyword arguments handed to ``unet_forward``
            (text_embeds, text_mask, cond_images, ...).

    Returns:
        ``c_skip * noised_images + c_out * net_out``, optionally thresholded.
    """
    debug_on = self.CKeys['Debug_ModelPack']==1

    if debug_on:
        print("========================================")
        print("ElucidatedImagen: preconditioned_network_forward")

    batch = noised_images.shape[0]
    device = noised_images.device

    # a scalar noise level becomes one entry per batch element
    if isinstance(sigma, float):
        sigma = torch.full((batch,), sigma, device = device)

    padded_sigma = self.right_pad_dims_to_datatype(sigma)

    if debug_on:
        print("unet_forward: ")
        print("arg_0: ", noised_images.shape)
        print("arg_1: ", sigma.shape)
        print("other arg: ", kwargs)

    scaled_input = self.c_in(sigma_data, padded_sigma) * noised_images
    noise_conditioning = self.c_noise(sigma)
    net_out = unet_forward(scaled_input, noise_conditioning, **kwargs)

    skip_term = self.c_skip(sigma_data, padded_sigma) * noised_images
    out_term = self.c_out(sigma_data, padded_sigma) * net_out
    out = skip_term + out_term

    if clamp:
        return self.threshold_x_start(out, dynamic_threshold)
    return out
# sampling
# sample schedule
# equation (5) in the paper
def sample_schedule(
    self,
    num_sample_steps,
    rho,
    sigma_min,
    sigma_max
):
    """Build the sampling noise schedule (equation (5) in the paper).

    Returns a float32 tensor of ``num_sample_steps + 1`` values on
    ``self.device``: a rho-warped interpolation from ``sigma_max`` to
    ``sigma_min`` followed by a final 0.
    """
    inv_rho = 1 / rho
    step_index = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
    fraction = step_index / (num_sample_steps - 1)

    warped_max = sigma_max ** inv_rho
    warped_min = sigma_min ** inv_rho
    sigmas = (warped_max + fraction * (warped_min - warped_max)) ** rho

    # the schedule ends with a sigma of exactly 0
    return F.pad(sigmas, (0, 1), value = 0.)
# Draw one sample from a single unet with the stochastic sampler and its
# second-order correction. Runs under torch.no_grad().
#   unet            - object whose `forward_with_cond_scale` and `self_cond` are used
#   shape           - shape handed to torch.randn for the generated tensor
#   unet_number     - 1-based index used to pick the hyperparameters from self.hparams
#   clamp, dynamic_threshold, cond_scale, **kwargs
#                   - forwarded to preconditioned_network_forward via unet_kwargs
#   inpaint_images / inpaint_masks
#                   - inpainting is only active when BOTH are given
#   inpaint_resample_times
#                   - number of resampling passes per step when inpainting (else 1)
#   init_images     - if given, added to the initial noise
#   skip_steps      - number of leading schedule steps to drop (default 0)
#   sigma_min / sigma_max
#                   - override the per-unet hyperparameter values when given
# Returns the final `images` tensor.
@torch.no_grad()
def one_unet_sample(
self,
unet,
shape,
*,
unet_number,
clamp = True,
dynamic_threshold = True,
cond_scale = 1.,
use_tqdm = True,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
init_images = None,
skip_steps = None,
sigma_min = None,
sigma_max = None,
**kwargs
):
# get specific sampling hyperparameters for unet
hp = self.hparams[unet_number - 1]
sigma_min = default(sigma_min, hp.sigma_min)
sigma_max = default(sigma_max, hp.sigma_max)
# build the sigma schedule, derive a gamma per sigma, and pair every sigma with the next sigma and its gamma
sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max)
# gamma is non-zero only for sigmas inside [S_tmin, S_tmax]
gammas = torch.where(
(sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
0.
)
sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
# images is noise at the beginning
init_sigma = sigmas[0]
images = init_sigma * torch.randn(shape, device = self.device)
# initializing with an image
if exists(init_images):
images += init_images
# keeping track of x0, for self conditioning if needed
x_start = None
# prepare inpainting images and mask
has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
resample_times = inpaint_resample_times if has_inpainting else 1
if has_inpainting:
inpaint_images = self.normalize_img(inpaint_images)
# both the images and the masks are resized to the last entry of `shape`
inpaint_images = self.resize_to(inpaint_images, shape[-1])
inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()
# unet kwargs
unet_kwargs = dict(
sigma_data = hp.sigma_data,
clamp = clamp,
dynamic_threshold = dynamic_threshold,
cond_scale = cond_scale,
**kwargs
)
# gradually denoise
initial_step = default(skip_steps, 0)
sigmas_and_gammas = sigmas_and_gammas[initial_step:]
total_steps = len(sigmas_and_gammas)
for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm):
is_last_timestep = ind == (total_steps - 1)
# convert the 0-d tensors to python numbers
sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
for r in reversed(range(resample_times)):
is_last_resample_step = r == 0
eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
# raise the noise level from sigma to sigma_hat and add the matching amount of noise
sigma_hat = sigma + gamma * sigma
added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps
images_hat = images + added_noise
self_cond = x_start if unet.self_cond else None
if has_inpainting:
# masked positions are replaced by the (noised) inpainting target
images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks
model_output = self.preconditioned_network_forward(
unet.forward_with_cond_scale,
images_hat,
sigma_hat,
self_cond = self_cond,
**unet_kwargs
)
# first-order step from sigma_hat to sigma_next
denoised_over_sigma = (images_hat - model_output) / sigma_hat
images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma
# second order correction, if not the last timestep
if sigma_next != 0:
self_cond = model_output if unet.self_cond else None
model_output_next = self.preconditioned_network_forward(
unet.forward_with_cond_scale,
images_next,
sigma_next,
self_cond = self_cond,
**unet_kwargs
)
denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
# average the two slopes
images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
images = images_next
if has_inpainting and not (is_last_resample_step or is_last_timestep):
# renoise in repaint and then resample
repaint_noise = torch.randn(shape, device = self.device)
images = images + (sigma - sigma_next) * repaint_noise
x_start = model_output # save model output for self conditioning
if has_inpainting:
# final paste of the inpainting target into the masked positions
images = images * ~inpaint_masks + inpaint_images * inpaint_masks
return images
# Sample from the unets in order, feeding each unet's output to the next one as
# low-resolution conditioning. Runs under torch.no_grad() and eval_decorator.
#   texts / text_embeds / text_masks
#                   - text conditioning; `texts` is encoded with self.encode_text
#                     only when no `text_embeds` are given and the model is not unconditional
#   cond_images     - conditioning images, passed through cast_uint8_images_to_float when present
#   inpaint_images / inpaint_masks / inpaint_resample_times
#                   - inpainting inputs; images and masks must be given together
#   init_images, skip_steps, sigma_min, sigma_max
#                   - each is cast to one entry per unet and forwarded to one_unet_sample
#   video_frames    - required when self.is_video is set
#   batch_size      - overwritten by text_embeds.shape[0] when the model is not unconditional
#   cond_scale      - cast to one entry per unet
#   lowres_sample_noise_level
#                   - defaults to self.lowres_sample_noise_level
#   start_at_unet_number / start_image_or_video
#                   - start at a later unet; a starting image is then required
#   stop_at_unet_number
#                   - stop after this unet
#   return_all_unet_outputs
#                   - selects between the last output and all outputs (see notes below)
#   use_tqdm        - toggles the progress bars
#   device          - defaults to self.device
# NOTE(review): `return_pil_images` and the loop variable `unet_hparam` are not
# referenced anywhere in this body.
@torch.no_grad()
@eval_decorator
def sample(
self,
texts: List[str] = None,
text_masks = None,
text_embeds = None,
cond_images = None,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
init_images = None,
skip_steps = None,
sigma_min = None,
sigma_max = None,
video_frames = None,
batch_size = 1,
cond_scale = 1.,
lowres_sample_noise_level = None,
start_at_unet_number = 1,
start_image_or_video = None,
stop_at_unet_number = None,
return_all_unet_outputs = False,
return_pil_images = False,
use_tqdm = True,
device = None,
):
# +
device = default(device, self.device)
# gather all unets into one ModuleList on `device` before sampling
self.reset_unets_all_one_device(device = device)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)
if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'
with autocast(enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
if not self.unconditional:
assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'
# without an explicit mask, every position with a non-zero embedding counts as valid
text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
batch_size = text_embeds.shape[0]
if exists(inpaint_images):
if self.unconditional:
if batch_size == 1: # assume researcher wants to broadcast along inpainted images
batch_size = inpaint_images.shape[0]
assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'
assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'
outputs = []
is_cuda = next(self.parameters()).is_cuda
# NOTE(review): this rebinds `device` to the device of the first parameter,
# replacing the value resolved from the `device` argument above - confirm intended
device = next(self.parameters()).device
lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)
num_unets = len(self.unets)
cond_scale = cast_tuple(cond_scale, num_unets)
# handle video and frame dimension
assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
frame_dims = (video_frames,) if self.is_video else tuple()
# initializing with an image or video
init_images = cast_tuple(init_images, num_unets)
init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
skip_steps = cast_tuple(skip_steps, num_unets)
sigma_min = cast_tuple(sigma_min, num_unets)
sigma_max = cast_tuple(sigma_max, num_unets)
# handle starting at a unet greater than 1, for training only-upscaler training
if start_at_unet_number > 1:
assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
# the supplied start image is resized to the size of the unet preceding the first one sampled
prev_image_size = self.image_sizes[start_at_unet_number - 2]
img = self.resize_to(start_image_or_video, prev_image_size)
# one iteration per unet, zipped with that unet's per-stage settings
for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
if unet_number < start_at_unet_number:
continue
assert not isinstance(unet, NullUnet), 'cannot sample from null unet'
# when the model is on CUDA, only the unet being sampled is kept on the device
context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()
with context:
lowres_cond_img = lowres_noise_times = None
shape = (batch_size, channel, *frame_dims, image_size )
# low resolution conditioning
if unet.lowres_cond:
lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
# `img` is the previous unet's output (or the resized start image)
lowres_cond_img = self.resize_to(img, image_size)
lowres_cond_img = self.normalize_img(lowres_cond_img.float())
lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(
x_start = lowres_cond_img.float(),
t = lowres_noise_times,
noise = torch.randn_like(lowres_cond_img.float())
)
if exists(unet_init_images):
unet_init_images = self.resize_to(unet_init_images, image_size)
# NOTE(review): this replaces the `shape` built above from `channel`
# (self.sample_channels) with one built from self.channels
shape = (batch_size, self.channels, *frame_dims, image_size)
img = self.one_unet_sample(
unet,
shape,
unet_number = unet_number,
text_embeds = text_embeds,
text_mask =text_masks,
cond_images = cond_images,
inpaint_images = inpaint_images,
inpaint_masks = inpaint_masks,
inpaint_resample_times = inpaint_resample_times,
init_images = unet_init_images,
skip_steps = unet_skip_steps,
sigma_min = unet_sigma_min,
sigma_max = unet_sigma_max,
cond_scale = unet_cond_scale,
lowres_cond_img = lowres_cond_img,
lowres_noise_times = lowres_noise_times,
dynamic_threshold = dynamic_threshold,
use_tqdm = use_tqdm
)
if self.categorical_loss:
# self.m is defined outside this view - presumably a normalising layer; TODO confirm
img=self.m(img)
outputs.append(img)
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
break
output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs
if not return_all_unet_outputs:
outputs = outputs[-1:]
assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'
if self.categorical_loss:
# class index per position: argmax over dim 1, kept as a singleton dimension
# NOTE(review): with return_all_unet_outputs=True, outputs[output_index] is a list,
# which torch.argmax presumably does not accept - confirm
return torch.argmax(outputs[output_index], dim=1).unsqueeze (1)
else:
return outputs[output_index]
# training
def loss_weight(self, sigma_data, sigma):
    """Per-sample loss weight: ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
    variance_sum = sigma ** 2 + sigma_data ** 2
    inverse_square_product = (sigma * sigma_data) ** -2
    return variance_sum * inverse_square_product
def noise_distribution(self, P_mean, P_std, batch_size):
    """Draw ``batch_size`` training noise levels from a log-normal distribution.

    A standard normal sample on ``self.device`` is scaled by ``P_std``,
    shifted by ``P_mean`` and exponentiated.
    """
    standard_normal = torch.randn((batch_size,), device = self.device)
    log_sigma = P_mean + P_std * standard_normal
    return log_sigma.exp()
def forward(
    self,
    images,
    unet: Union[ NullUnet, DistributedDataParallel] = None,
    texts: List[str] = None,
    text_embeds = None,
    text_masks = None,
    unet_number = None,
    cond_images = None,
):
    """Compute the weighted denoising training loss for one unet.

    Args:
        images: training tensor; dimension 1 must equal ``self.channels`` and
            the last dimension must be at least the target size of the unet.
            Must be a float tensor unless ``self.categorical_loss`` is set.
        unet: unet to train; defaults to ``self.get_unet(unet_number)``. May be
            wrapped in ``DistributedDataParallel``.
        texts: raw texts, encoded with ``self.encode_text`` when no
            ``text_embeds`` are given and the model is not unconditional.
        text_embeds: precomputed conditioning embeddings.
        text_masks: mask for ``text_embeds``; derived from non-zero embeddings
            when omitted and the model is not unconditional.
        unet_number: 1-based unet number; required when there are several unets.
        cond_images: optional conditioning images handed to the unet.

    Returns:
        Scalar tensor: the batch mean of the per-sample MSE between the
        denoised prediction and ``images``, scaled by ``loss_weight``.

    Raises:
        NotImplementedError: if ``self.loss_type`` is not 0 (MSE).
    """
    assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
    unet_number = default(unet_number, 1)
    # f-prefix added: the message previously printed the placeholder literally
    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'

    debug = self.CKeys['Debug_ModelPack']==1

    # + for debug
    if debug:
        if cond_images is not None:
            print("cond_images type: ", cond_images.dtype)
        else:
            print("cond_images type: None")

    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    # + for debug
    if debug:
        if cond_images is not None:
            print("cond_images type: ", cond_images.dtype)
        else:
            print("cond_images type: None")

    if self.categorical_loss==False:
        assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'

    unet_index = unet_number - 1
    unet = default(unet, lambda: self.get_unet(unet_number))
    assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

    target_image_size = self.image_sizes[unet_index]
    random_crop_size = self.random_crop_sizes[unet_index]
    prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
    hp = self.hparams[unet_index]

    # + for debug
    if debug:
        print("target_image_size: ", target_image_size)
        print("prev_image_size: ", prev_image_size)
        print("random_crop_size: ", random_crop_size)

    batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4)
    frames = images.shape[2] if is_video else None

    # + for debug
    if debug:
        print("frames: ", frames)

    check_shape(images, 'b c ...', c = self.channels)
    assert h >= target_image_size

    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'
        assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

    if not self.unconditional:
        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    # low resolution conditioning, only for unets after the first one
    lowres_cond_img = lowres_aug_times = None
    if exists(prev_image_size):
        lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
        lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)

        if self.per_sample_random_aug_noise_level:
            lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
        else:
            # draw one random time and share it across the whole batch
            lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
            lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)

    if exists(random_crop_size):
        aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h -> (b f) c h')

        # reuse the crop parameters so both tensors are cropped identically
        images = aug(images)
        lowres_cond_img = aug(lowres_cond_img, params = aug._params)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h -> b c f h', f = frames)

    lowres_cond_img_noisy = None
    if exists(lowres_cond_img):
        lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
            x_start = lowres_cond_img,
            t = lowres_aug_times,
            noise = torch.randn_like(lowres_cond_img.float())
        )

    # get the sigmas
    sigmas = self.noise_distribution(
        hp.P_mean,
        hp.P_std,
        batch_size
    ).to(device)
    padded_sigmas = self.right_pad_dims_to_datatype(sigmas).to(device)

    # + for debug
    if debug:
        print('sigmas dim: ', sigmas.shape, 'should = batch_size')
        print('sigmas[0..3]: ', sigmas[:3])
        print('padded_sigmas dim: ', padded_sigmas.shape)

    # noise
    noise = torch.randn_like(images.float()).to(device)
    noised_images = images + padded_sigmas * noise # alphas are 1. in the paper

    # unet kwargs
    unet_kwargs = dict(
        sigma_data = hp.sigma_data,
        text_embeds = text_embeds,
        text_mask =text_masks,
        cond_images = cond_images,
        lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
        lowres_cond_img = lowres_cond_img_noisy,
        cond_drop_prob = self.cond_drop_prob,
    )

    # + for debug
    if debug:
        print("lowres_noise_times: ", unet_kwargs['lowres_noise_times'])
        print("lowres_cond_img_noisy: ", unet_kwargs['lowres_cond_img'])
        print("unet_kwargs: \n", unet_kwargs)

    # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower
    # 'unet' can be wrapped in DistributedDataParallel (coming from the trainer), in
    # which case the flag lives on the wrapped 'module'. Read the `self_cond` flag in
    # both cases: falling back to the unet object itself is always truthy and would
    # trigger the extra forward pass even for unets without self conditioning.
    inner_unet = unet.module if isinstance(unet, DistributedDataParallel) else unet
    self_cond = inner_unet.self_cond

    if self_cond and random() < 0.5:
        # + for debug
        if debug:
            print("=========added self_cond.============")

        with torch.no_grad():
            pred_x0 = self.preconditioned_network_forward(
                unet.forward,
                noised_images,
                sigmas,
                **unet_kwargs
            ).detach()

        unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}

        # + for debug
        if debug:
            print("after self-condition, random < 0.5 .......")
            print("unet_kwargs: \n", unet_kwargs)

    # get prediction
    denoised_images = self.preconditioned_network_forward(
        unet.forward,
        noised_images,
        sigmas,
        **unet_kwargs
    )

    # losses
    if self.loss_type != 0:
        # only MSE is implemented; fail with a clear message instead of an unbound local
        raise NotImplementedError(f'loss_type {self.loss_type} is not implemented, only 0 (MSE) is supported')

    losses = F.mse_loss(denoised_images, images, reduction = 'none')
    losses = reduce(losses, 'b ... -> b', 'mean')

    # loss weighting
    losses = losses * self.loss_weight(hp.sigma_data, sigmas)
    losses = losses.mean()
    return losses
# ===========================================================
# final models
# explicitly define unets
class ProteinDesigner_B(nn.Module):
    """Model B: residue-based generative protein diffusion model.

    Wraps a single unet in an ``ElucidatedImagen`` and adds a small embedding
    front end (``fc_embed2`` plus a learned positional embedding) for an
    optional per-position conditioning vector ``x``. When
    ``cond_images_channels > 0`` the model conditions on ``cond_images``
    instead of text embeddings.
    """

    def __init__(self,
                 unet,
                 CKeys=None,
                 PKeys=None,
                 ):
        """Build the embedding layers and the wrapped ``ElucidatedImagen``.

        Args:
            unet: the unet handed to ``ElucidatedImagen``.
            CKeys: control dictionary; ``CKeys['Debug_ModelPack']`` toggles
                debug printing.
            PKeys: parameter dictionary. The keys ``timesteps``, ``pred_dim``,
                ``loss_type``, ``elucidated``, ``text_embed_dim``,
                ``embed_dim_position``, ``max_text_len``,
                ``cond_images_channels``, ``max_length`` and ``device`` are
                read; a ``None`` value falls back to the default shown below.
        """
        super().__init__()

        # unload the parameters
        timesteps = default(PKeys['timesteps'], 10)
        pred_dim = default(PKeys['pred_dim'], 25)
        loss_type = default(PKeys['loss_type'], 0)  # 0: MSE
        elucidated = default(PKeys['elucidated'], True)
        text_embed_dim = default(PKeys['text_embed_dim'], 512)
        embed_dim_position = default(PKeys['embed_dim_position'], 32)
        max_text_len = default(PKeys['max_text_len'], 16)
        cond_images_channels = default(PKeys['cond_images_channels'], 0)
        max_length = default(PKeys['max_length'], 64)
        device = default(PKeys['device'], None)

        print ("Model B: Generative protein diffusion model, residue-based")
        print ("Using condition as the initial sequence")

        self.pred_dim = pred_dim
        self.loss_type = loss_type
        # kept for debug switches and for the wrapped imagen
        self.CKeys = CKeys
        self.PKeys = PKeys
        self.max_length = max_length

        assert loss_type == 0, "Losses other than MSE not implemented"

        # fc_embed1 is NOT used in forward/sample; it is kept so the registered
        # parameters (and therefore saved checkpoints) stay the same
        self.fc_embed1 = nn.Linear( 8, max_length)
        self.fc_embed2 = nn.Linear( 1, text_embed_dim)  # USED: embeds each scalar condition

        self.max_text_len = max_text_len
        self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position)
        # the conditioning passed to imagen is [value embedding | position embedding]
        text_embed_dim = text_embed_dim + embed_dim_position
        # position indices 1..max_text_len
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        condition_on_text = True
        self.cond_images_channels = cond_images_channels
        if self.cond_images_channels>0:
            condition_on_text = False
            print ("Use conditioning image during training....")

        assert elucidated , "Only elucidated model implemented...."
        self.is_elucidated = elucidated

        if elucidated:
            self.imagen = ElucidatedImagen(
                unets = (unet),
                channels=self.pred_dim,
                channels_out=self.pred_dim ,
                loss_type=loss_type,
                condition_on_text = condition_on_text,
                text_embed_dim = text_embed_dim,
                image_sizes = ( [max_length ]),
                cond_drop_prob = 0.1,
                auto_normalize_img = False,
                num_sample_steps = timesteps, # number of sample steps
                sigma_min = 0.002,            # min noise level
                sigma_max = 160,              # max noise level
                sigma_data = 0.5,             # standard deviation of data distribution
                rho = 7,                      # controls the sampling schedule
                P_mean = -1.2,                # mean of log-normal distribution from which noise is drawn for training
                P_std = 1.2,                  # standard deviation of log-normal distribution from which noise is drawn for training
                S_churn = 40,                 # parameters for stochastic sampling
                S_tmin = 0.05,
                S_tmax = 50,
                S_noise = 1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(device)
        else:
            print ("Not implemented.")

    def _embed_condition(self, x, max_length, device):
        """Zero-pad ``x`` to ``max_length`` and append the positional embedding."""
        x_in = torch.zeros( (x.shape[0], max_length) ).to(device)
        x_in[:,:x.shape[1]] = x
        x = x_in.unsqueeze (2)
        x = self.fc_embed2(x)

        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = self.pos_emb_x( pos_matrix_i_)
        pos_emb_x = torch.squeeze(pos_emb_x, 1)
        return torch.cat( (x, pos_emb_x ), 2)

    @staticmethod
    def _target_sequences_to_tensor(texts, tokenizer_y, ynormfac, max_length, device):
        """Tokenise, post-pad/truncate and normalise target sequences into a float tensor."""
        if tokenizer_y is None:
            raise ValueError("tokenizer_y must be provided when init_images or inpaint_images are given")
        tokens = tokenizer_y.texts_to_sequences(texts)
        tokens = sequence.pad_sequences(tokens, maxlen=max_length, padding='post', truncating='post')
        return torch.from_numpy(tokens).float().to(device)/ynormfac

    def forward(self,
                output,
                x=None,
                cond_images = None,
                unet_number=1,
                ):
        """Return the training loss of the wrapped imagen.

        Args:
            output: target tensor handed to the imagen as the images to denoise.
            x: optional conditioning values; zero-padded to ``self.max_length``
                and embedded before being passed as ``text_embeds``.
            cond_images: optional conditioning images.
            unet_number: 1-based unet number to train.
        """
        debug = self.CKeys['Debug_ModelPack']==1

        if debug:
            print("on Model B:forward")
            print("inputs:")
            print(' output: ', output.shape)
            print(' x: ', x)
            # cond_images is optional: do not dereference it when absent
            print(' cond_img: ', cond_images.shape if cond_images is not None else None)
            print(' unet_num: ', unet_number)

        # NOTE: `device` below is the module-level device of this file, not PKeys['device']
        if x is not None:
            x = self._embed_condition(x, self.max_length, device)

        if cond_images is not None:
            this_cond_images = cond_images.to(device)
        else:
            this_cond_images = cond_images

        if debug:
            print('x with pos_emb_x: ', x)
            print("into self.imagen...")

        loss = self.imagen(
            output,
            text_embeds = x,
            cond_images=this_cond_images,
            unet_number = unet_number,
        )
        return loss

    def sample (
        self,
        x=None,
        stop_at_unet_number=1 ,
        cond_scale=7.5,
        x_data=None,
        skip_steps=None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        x_data_tokenized=None,
        device=None,
        # +++++++++++++++++++++++++
        tokenizer_X=None,
        Xnormfac=1.,
        max_length=1.,
        # +++++++++++++++++++++++++
        tokenizer_y=None,
        ynormfac=1.,
    ):
        """Generate samples from the wrapped imagen.

        Args:
            x: optional conditioning values, embedded like in ``forward`` but
                padded to the ``max_length`` argument.
            stop_at_unet_number, cond_scale, skip_steps, inpaint_masks,
            inpaint_resample_times: forwarded to ``imagen.sample``.
            x_data: raw conditioning sequences; tokenised with ``tokenizer_X``,
                prefixed with a 0 token, post-padded to ``max_length``, divided
                by ``Xnormfac`` and copied across ``self.pred_dim`` channels.
            x_data_tokenized: already tokenised conditioning of shape
                (batch, seq_len); copied across ``self.pred_dim`` channels.
                Takes precedence over ``x_data`` when both are given.
            init_images / inpaint_images: raw target sequences; tokenised with
                ``tokenizer_y`` and divided by ``ynormfac``.
            device: device the prepared tensors are moved to.
            tokenizer_y, ynormfac: tokenizer and normalisation factor for the
                target sequences (previously undefined names in this method).

        Raises:
            ValueError: if ``init_images`` or ``inpaint_images`` is given
                without ``tokenizer_y``.
        """
        batch_size = 1

        if x_data is not None:
            print ("Conditioning target sequence provided via ori x_data ...", x_data)
            # ver 1.0: channel = self.pred_dim>1, esm padding, 0+content+000
            x_data = tokenizer_X.texts_to_sequences(x_data)
            # padding: 0+content+00
            x_data_0 = [[0] + this_x_data for this_x_data in x_data]
            x_data = sequence.pad_sequences(
                x_data_0, maxlen=max_length,
                padding='post', truncating='post',
            )
            # normalization: dummy, for future
            x_data = torch.from_numpy(x_data).float().to(device)
            x_data = x_data/Xnormfac
            # adjust the channel: the same values are copied into every channel
            x_data = x_data.unsqueeze(1).repeat(1,self.pred_dim,1)
            print ("After channel expansion, x_data from target sequence=", x_data, x_data.shape)
            batch_size = x_data.shape[0]

        if x_data_tokenized is not None:
            # everything is padded as what the dataloader has
            print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape)
            # ver 2.0: input x_data_tokenized.dim = (batch, seq_len)
            x_data = x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device)
            print ("x_data.dim provided from x_data_tokenized: ", x_data.shape)
            batch_size = x_data.shape[0]

        if init_images is not None:
            print ("Init sequence provided...", init_images)
            init_images = self._target_sequences_to_tensor(init_images, tokenizer_y, ynormfac, max_length, device)
            print ("init_images=", init_images)

        if inpaint_images is not None:
            print ("Inpaint sequence provided...", inpaint_images)
            print ("Mask: ", inpaint_masks)
            inpaint_images = self._target_sequences_to_tensor(inpaint_images, tokenizer_y, ynormfac, max_length, device)
            print ("in_paint images=", inpaint_images)

        if x is not None:
            x = self._embed_condition(x, max_length, device)
            batch_size = x.shape[0]

        output = self.imagen.sample(
            text_embeds= x,
            cond_scale = cond_scale,
            stop_at_unet_number=stop_at_unet_number,
            cond_images=x_data,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,
            batch_size=batch_size,
            device=device,
        )
        return output
# ===========================================================
# final models
# explicitly define unets
class ProteinPredictor_B(nn.Module):
def __init__(self,
             unet,
             CKeys=None,
             PKeys=None,
             ):
    """Build the embedding layers and the wrapped ``ElucidatedImagen``.

    Args:
        unet: the unet handed to ``ElucidatedImagen``.
        CKeys: control dictionary; ``CKeys['Debug_ModelPack']`` toggles debug
            printing.
        PKeys: parameter dictionary. The keys ``timesteps``, ``pred_dim``,
            ``loss_type``, ``elucidated``, ``text_embed_dim``,
            ``embed_dim_position``, ``max_text_len``, ``cond_images_channels``,
            ``max_length`` and ``device`` are read; a ``None`` value falls back
            to the default shown below.
    """
    super(ProteinPredictor_B, self).__init__()

    # unload the parameters
    timesteps = default(PKeys['timesteps'], 10)
    pred_dim = default(PKeys['pred_dim'], 25)
    loss_type = default(PKeys['loss_type'], 0)  # 0: MSE
    elucidated = default(PKeys['elucidated'], True)
    text_embed_dim = default(PKeys['text_embed_dim'], 512)
    embed_dim_position = default(PKeys['embed_dim_position'], 32)
    max_text_len = default(PKeys['max_text_len'], 16)
    cond_images_channels = default(PKeys['cond_images_channels'], 0)
    max_length = default(PKeys['max_length'], 64)
    device = default(PKeys['device'], None)

    print ("Model B: Predictive protein diffusion model, residue-based")
    print ("Using condition as the initial sequence")

    self.pred_dim = pred_dim
    self.loss_type = loss_type
    # kept for debug switches and for the wrapped imagen
    self.CKeys = CKeys
    self.PKeys = PKeys
    self.max_length = max_length

    assert loss_type == 0, "Losses other than MSE not implemented"

    # fc_embed1 is NOT used in forward; it is kept so the registered
    # parameters (and therefore saved checkpoints) stay the same
    self.fc_embed1 = nn.Linear( 8, max_length)
    self.fc_embed2 = nn.Linear( 1, text_embed_dim)  # USED: embeds each scalar condition

    self.max_text_len = max_text_len
    self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position)
    # the conditioning passed to imagen is [value embedding | position embedding]
    text_embed_dim = text_embed_dim + embed_dim_position
    # position indices 1..max_text_len
    self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

    condition_on_text = True
    self.cond_images_channels = cond_images_channels
    if self.cond_images_channels>0:
        condition_on_text = False
        print ("Use conditioning image during training....")

    assert elucidated , "Only elucidated model implemented...."
    self.is_elucidated = elucidated

    if elucidated:
        self.imagen = ElucidatedImagen(
            unets = (unet),
            channels=self.pred_dim,
            channels_out=self.pred_dim ,
            loss_type=loss_type,
            condition_on_text = condition_on_text,
            text_embed_dim = text_embed_dim,
            image_sizes = ( [max_length ]),
            cond_drop_prob = 0.1,
            auto_normalize_img = False,
            num_sample_steps = timesteps, # number of sample steps
            sigma_min = 0.002,            # min noise level
            sigma_max = 160,              # max noise level
            sigma_data = 0.5,             # standard deviation of data distribution
            rho = 7,                      # controls the sampling schedule
            P_mean = -1.2,                # mean of log-normal distribution from which noise is drawn for training
            P_std = 1.2,                  # standard deviation of log-normal distribution from which noise is drawn for training
            S_churn = 40,                 # parameters for stochastic sampling
            S_tmin = 0.05,
            S_tmax = 50,
            S_noise = 1.003,
            CKeys=self.CKeys,
            PKeys=self.PKeys,
        ).to(device)
    else:
        print ("Not implemented.")
def forward(self,
            output,
            x=None,
            cond_images=None,
            unet_number=1,
            ):
    """Run one training pass through ``self.imagen`` and return its result.

    Args:
        output: Tensor passed to ``self.imagen`` as its first positional
            argument (the data the diffusion model is trained on).
        x: Optional 2-D conditioning tensor, indexed as ``(batch, length)``.
            When given it is zero-padded to ``self.max_length``, projected
            by ``self.fc_embed2`` and concatenated with the learned
            positional embedding before being passed as ``text_embeds``.
        cond_images: Optional conditioning tensor; moved to the module-level
            ``device`` and passed to ``self.imagen`` as ``cond_images``.
        unet_number: Forwarded unchanged to ``self.imagen``.

    Returns:
        Whatever ``self.imagen(...)`` returns (used as the loss by callers).
    """
    debug = self.CKeys['Debug_ModelPack']==1
    if debug:
        print("on Model B:forward")
        print("inputs:")
        print(' output: ', output.shape)
        print(' x: ', x)
        # cond_images defaults to None, so only read .shape when it exists.
        print(' cond_img: ', None if cond_images is None else cond_images.shape)
        print(' unet_num: ', unet_number)

    if x is not None:
        # Right-pad the condition with zeros up to the model length.
        x_in = torch.zeros((x.shape[0], self.max_length)).to(device)
        x_in[:, :x.shape[1]] = x
        # (batch, max_length) -> (batch, max_length, 1) -> embedding width
        x = self.fc_embed2(x_in.unsqueeze(2))
        # One row of position indices per batch element.
        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = self.pos_emb_x(pos_matrix_i_)
        pos_emb_x = torch.squeeze(pos_emb_x, 1)
        x = torch.cat((x, pos_emb_x), 2)

    if cond_images is not None:
        this_cond_images = cond_images.to(device)
    else:
        this_cond_images = None

    if debug:
        print('x with pos_emb_x: ', x)
        print("into self.imagen...")

    loss = self.imagen(
        output,
        text_embeds=x,
        cond_images=this_cond_images,
        unet_number=unet_number,
    )
    return loss
def sample (
    self,
    x=None,
    stop_at_unet_number=1 ,
    cond_scale=7.5,
    x_data=None,
    skip_steps=None,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    x_data_tokenized=None,
    device=None,
    # +++++++++++++++++++++++++
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    # ++
    pLM_Model_Name=None,
    pLM_Model=None,
    pLM_alphabet=None,
    esm_layer=None,
    # ++ needed by the init_images / inpaint_images branches
    tokenizer_y=None,
    ynormfac=1.,
):
    """Draw samples from ``self.imagen`` under the supplied conditioning.

    Conditioning sources (all optional):
      * ``x_data``: raw sequences. With ``pLM_Model_Name == 'trivial'`` they
        are tokenized with ``tokenizer_X``, padded with one leading 0 and
        trailing 0s up to ``max_length``, repeated over ``self.pred_dim``
        channels and divided by ``Xnormfac``. Otherwise they are converted
        with ``pLM_alphabet``'s batch converter, padded with
        ``pLM_alphabet.padding_idx`` up to ``max_length`` and embedded with
        ``pLM_Model`` (representation layer ``esm_layer``).
      * ``x_data_tokenized``: already tokenized ``(batch, seq_len)`` input,
        processed like ``x_data`` but without tokenizing/padding. If both
        are given, this one overrides ``x_data``.
      * ``x``: 2-D text-style condition, embedded exactly as in ``forward``
        but padded to the ``max_length`` argument.

    Args:
        init_images, inpaint_images: Optional sequences tokenized with
            ``tokenizer_y``, post-padded to ``max_length`` and divided by
            ``ynormfac``.
        inpaint_masks, inpaint_resample_times, skip_steps, cond_scale,
        stop_at_unet_number: Forwarded unchanged to ``self.imagen.sample``.
        device: Device that tensors built here are moved to; also forwarded
            to ``self.imagen.sample``.
        tokenizer_y: Tokenizer for ``init_images`` / ``inpaint_images``.
        ynormfac: Normalization divisor for ``init_images`` /
            ``inpaint_images``.

    Returns:
        The result of ``self.imagen.sample(...)``.

    Raises:
        ValueError: If ``init_images`` or ``inpaint_images`` is given
            without ``tokenizer_y``.
    """
    batch_size = 1

    if x_data is not None:
        print ("Conditioning target sequence provided via ori x_data ...", x_data)
        print(f"use pLM model {pLM_Model_Name}")
        # Build the cond_images tensor laid out as (batch, channel, seq_len).
        if pLM_Model_Name=='trivial':
            # No language model: the plain tokenizer output is used directly.
            x_data = tokenizer_X.texts_to_sequences(x_data)
            # Two-step padding: trailing zeros up to max_length-1 ...
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length-1,
                padding='post', truncating='post',
                value=0.0,
            )
            # ... then one leading zero to reach max_length.
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length,
                padding='pre', truncating='pre',
                value=0.0,
            )  # (batch, seq_len), NumPy array
            # pad_sequences returns a NumPy array; convert it before using
            # tensor methods such as unsqueeze/repeat.
            x_data = torch.from_numpy(x_data).float()
            x_data = x_data.unsqueeze(1).repeat(1, self.pred_dim, 1).to(device)
            x_data = x_data/Xnormfac
        else:
            print("pLM Model: ", pLM_Model_Name)
            # 1. amino-acid strings -> tokens; two positions are reserved
            #    for the <cls> and <eos> tokens.
            esm_batch_converter = pLM_alphabet.get_batch_converter(
                truncation_seq_length=max_length-2
            )
            # The converter expects (label, sequence) pairs; labels are dummies.
            seqs_ext = [(" ", seq) for seq in x_data]
            _, _, x_data = esm_batch_converter(seqs_ext)
            # The converter pads only to the longest sequence in the batch,
            # so pad further when the batch is shorter than max_length.
            current_seq_len = x_data.shape[1]
            print("current seq batch len: ", current_seq_len)
            missing_num_pad = max_length-current_seq_len
            if missing_num_pad>0:
                print("extra padding is added to match the target seq input length...")
                x_data = F.pad(
                    x_data,
                    (0, missing_num_pad),
                    "constant", pLM_alphabet.padding_idx
                )
            else:
                print("No extra padding is needed")
            x_data = x_data.to(device)
            # 2. tokens -> embeddings
            with torch.no_grad():
                results = pLM_Model(
                    x_data,
                    repr_layers=[esm_layer],
                    return_contacts=False,
                )
            x_data = results["representations"][esm_layer]  # (batch, seq_len, channels)
            x_data = rearrange(x_data, 'b l c -> b c l')
        print ("x_data.dim: ", x_data.shape)
        print ("x_data.type: ", x_data.type())
        batch_size = x_data.shape[0]

    if x_data_tokenized is not None:
        # x_data_tokenized (batch, seq_len) -> x_data (batch, channel, seq_len)
        print (
            "Conditioning target output via provided AA tokens sequence...",
            x_data_tokenized,
            x_data_tokenized.shape,
        )
        if pLM_Model_Name=='trivial':
            # Tokens are copied across the self.pred_dim channels.
            x_data = x_data_tokenized.unsqueeze(1).repeat(1, self.pred_dim, 1).to(device)
        else:  # esm models
            with torch.no_grad():
                results = pLM_Model(
                    x_data_tokenized,
                    repr_layers=[esm_layer],
                    return_contacts=False,
                )
            x_data = results["representations"][esm_layer]  # (batch, seq_len, channels)
            x_data = rearrange(x_data, 'b l c -> b c l')
            x_data = x_data.to(device)
        batch_size = x_data.shape[0]
        print ("x_data.dim provided from x_data_tokenized: ", x_data.shape)

    if (init_images is not None or inpaint_images is not None) and tokenizer_y is None:
        raise ValueError(
            "tokenizer_y is required when init_images or inpaint_images is provided"
        )

    if init_images is not None:
        print ("Init sequence provided...", init_images)
        init_images = tokenizer_y.texts_to_sequences(init_images)
        init_images = sequence.pad_sequences(
            init_images, maxlen=max_length, padding='post', truncating='post'
        )
        init_images = torch.from_numpy(init_images).float().to(device)/ynormfac
        print ("init_images=", init_images)

    if inpaint_images is not None:
        print ("Inpaint sequence provided...", inpaint_images)
        print ("Mask: ", inpaint_masks)
        inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images)
        inpaint_images = sequence.pad_sequences(
            inpaint_images, maxlen=max_length, padding='post', truncating='post'
        )
        inpaint_images = torch.from_numpy(inpaint_images).float().to(device)/ynormfac
        print ("in_paint images=", inpaint_images)

    if x is not None:
        # Same text-condition embedding as in forward(), padded to max_length.
        x_in = torch.zeros((x.shape[0], max_length)).to(device)
        x_in[:, :x.shape[1]] = x
        x = self.fc_embed2(x_in.unsqueeze(2))
        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = self.pos_emb_x(pos_matrix_i_)
        pos_emb_x = torch.squeeze(pos_emb_x, 1)
        x = torch.cat((x, pos_emb_x), 2)
        batch_size = x.shape[0]

    output = self.imagen.sample(
        text_embeds=x,
        cond_scale=cond_scale,
        stop_at_unet_number=stop_at_unet_number,
        cond_images=x_data,
        skip_steps=skip_steps,
        inpaint_images=inpaint_images,
        inpaint_masks=inpaint_masks,
        inpaint_resample_times=inpaint_resample_times,
        init_images=init_images,
        batch_size=batch_size,
        device=device,
    )
    return output
# ===========================================================
# final models
# explicitly define unets
class ProteinDesigner_B_Old(nn.Module):
    """Older "Model B" wrapper: an ``ElucidatedImagen`` around a given UNet.

    The constructor builds ``self.imagen`` from the supplied ``unet`` plus
    a small text-condition embedding (``fc_embed2`` and a learned
    positional embedding). ``forward`` returns the result of
    ``self.imagen(...)``; ``sample`` prepares conditioning tensors and calls
    ``self.imagen.sample(...)``.

    Args (constructor):
        unet: UNet passed to ``ElucidatedImagen`` as ``unets``.
        timesteps: Passed as ``num_sample_steps``.
        pred_dim: Used for both ``channels`` and ``channels_out``.
        loss_type: Must be 0 (MSE); anything else fails the assertion.
        elucidated: Must be truthy; only the elucidated model is built.
        text_embed_dim: Output width of ``fc_embed2``; the imagen receives
            ``text_embed_dim + embed_dim_position``.
        embed_dim_position: Width of the positional embedding.
        max_text_len: Number of positions in the positional index vector.
        cond_images_channels: When > 0, text conditioning is switched off
            for the imagen (``condition_on_text=False``).
        max_length: Sequence length given to the imagen as ``image_sizes``
            and used by ``forward`` to pad ``x``.
        device: Device the imagen is moved to.
        CKeys, PKeys: Stored and forwarded to ``ElucidatedImagen``.
        dim, padding_idx, cond_dim, input_tokens, sequence_embed: Accepted
            but not used by this class.
    """

    def __init__(self,
                 unet,
                 timesteps=10 ,
                 dim=32,
                 pred_dim=25,
                 loss_type=0, # MSE
                 elucidated=True,
                 padding_idx=0,
                 cond_dim = 512,
                 text_embed_dim = 512,
                 input_tokens=25,#for non-BERT
                 sequence_embed=False,
                 embed_dim_position=32,
                 max_text_len=16,
                 cond_images_channels=0,
                 # ++++++++++++++++++++++
                 max_length=1,
                 device=None,
                 CKeys=None,
                 PKeys=None,
                 ):
        super(ProteinDesigner_B_Old, self).__init__()
        print ("Model B: Generative protein diffusion model, residue-based")
        print ("Using condition as the initial sequence")
        self.pred_dim = pred_dim
        self.loss_type = loss_type
        # kept for debug switches and forwarded to the imagen
        self.CKeys = CKeys
        self.PKeys = PKeys
        self.max_length = max_length
        assert loss_type == 0, "Losses other than MSE not implemented"

        self.fc_embed1 = nn.Linear(8, max_length)        # NOT used
        self.fc_embed2 = nn.Linear(1, text_embed_dim)    # USED: scalar -> embedding
        self.max_text_len = max_text_len
        self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position)
        # The imagen sees the projected value concatenated with the
        # positional embedding.
        text_embed_dim = text_embed_dim+embed_dim_position
        # Position indices 1..max_text_len (index 0 is never looked up here).
        self.pos_matrix_i = torch.arange(1, max_text_len+1, dtype=torch.long)

        self.cond_images_channels = cond_images_channels
        condition_on_text = True
        if self.cond_images_channels>0:
            condition_on_text = False
            print ("Use conditioning image during training....")

        assert elucidated , "Only elucidated model implemented...."
        self.is_elucidated = elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets = unet,
                channels=self.pred_dim,
                channels_out=self.pred_dim ,
                loss_type=loss_type,
                condition_on_text = condition_on_text,
                text_embed_dim = text_embed_dim,
                image_sizes = [max_length],
                cond_drop_prob = 0.1,
                auto_normalize_img = False,
                num_sample_steps = timesteps,  # number of sample steps
                sigma_min = 0.002,   # min noise level
                sigma_max = 160,     # max noise level
                sigma_data = 0.5,    # standard deviation of data distribution
                rho = 7,             # controls the sampling schedule
                P_mean = -1.2,       # mean of log-normal distribution from which noise is drawn for training
                P_std = 1.2,         # standard deviation of log-normal distribution from which noise is drawn for training
                S_churn = 40,        # parameters for stochastic sampling - depends on dataset, Table 5 in paper
                S_tmin = 0.05,
                S_tmax = 50,
                S_noise = 1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(device)
        else:
            print ("Not implemented.")

    def _debug_enabled(self):
        """Return True when ``CKeys`` is set and requests debug printing."""
        # CKeys defaults to None in the constructor; do not index into None.
        return self.CKeys is not None and self.CKeys['Debug_ModelPack']==1

    def _embed_text(self, x, length, device):
        """Pad ``x`` to ``length``, project it and append positional embedding.

        ``x`` is indexed as ``(batch, n)``; the result has the projected
        values and the positional embedding concatenated along dim 2.
        """
        x_in = torch.zeros((x.shape[0], length)).to(device)
        x_in[:, :x.shape[1]] = x
        x = self.fc_embed2(x_in.unsqueeze(2))
        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = self.pos_emb_x(pos_matrix_i_)
        pos_emb_x = torch.squeeze(pos_emb_x, 1)
        return torch.cat((x, pos_emb_x), 2)

    def forward(self,
                output,
                x=None,
                cond_images = None,
                unet_number=1,
                ):
        """Run ``self.imagen`` on ``output`` and return its result.

        Args:
            output: First positional argument of ``self.imagen``.
            x: Optional 2-D condition; padded to ``self.max_length`` and
                embedded, then passed as ``text_embeds``.
            cond_images: Optional conditioning tensor, moved to the
                module-level ``device`` before being passed on.
            unet_number: Forwarded unchanged.
        """
        debug = self._debug_enabled()
        if debug:
            print('output: ', output.shape)
            print('x: ', x)
            print('cond_img: ', cond_images)
            print('unet_num: ', unet_number)

        if x is not None:
            # ``device`` here is the module-level device; this class does
            # not keep its own.
            x = self._embed_text(x, self.max_length, device)

        this_cond_images = None
        if cond_images is not None:
            this_cond_images = cond_images.to(device)

        if debug:
            print('x with pos_emb_x: ', x)

        loss = self.imagen(
            output,
            text_embeds = x,
            cond_images=this_cond_images,
            unet_number = unet_number,
        )
        return loss

    def sample (
        self,
        x=None,
        stop_at_unet_number=1 ,
        cond_scale=7.5,
        x_data=None,
        skip_steps=None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        x_data_tokenized=None,
        device=None,
        # +++++++++++++++++++++++++
        tokenizer_X=None,
        Xnormfac=1.,
        ynormfac=1.,
        max_length=1.,
        # ++ needed by the init_images / inpaint_images branches
        tokenizer_y=None,
    ):
        """Sample from ``self.imagen`` with optional conditioning.

        Args:
            x: Optional 2-D text-style condition (padded to ``max_length``).
            x_data: Optional raw sequences; tokenized with ``tokenizer_X``,
                post-padded to ``max_length``, divided by ``Xnormfac`` and
                given a singleton channel axis at dim 1.
            x_data_tokenized: Optional pre-tokenized 2-D tensor; given a
                singleton channel axis. Overrides ``x_data`` if both are set.
            init_images, inpaint_images: Optional sequences tokenized with
                ``tokenizer_y``, post-padded and divided by ``ynormfac``.
            device: Device for tensors built here; forwarded to the imagen.
            Remaining arguments are forwarded to ``self.imagen.sample``.

        Raises:
            ValueError: If ``init_images`` or ``inpaint_images`` is given
                without ``tokenizer_y``.
        """
        batch_size = 1

        if x_data is not None:
            print ("Conditioning target sequence provided via x_data ...", x_data)
            x_data = tokenizer_X.texts_to_sequences(x_data)
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length, padding='post', truncating='post'
            )
            x_data = torch.from_numpy(x_data).float().to(device)
            x_data = x_data/Xnormfac
            # (batch, seq_len) -> (batch, 1, seq_len)
            x_data = x_data.unsqueeze(1)
            print ("x_data from target sequence=", x_data, x_data.shape)
            batch_size = x_data.shape[0]

        if x_data_tokenized is not None:
            print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape)
            # (batch, seq_len) -> (batch, 1, seq_len)
            x_data = x_data_tokenized.unsqueeze(1).to(device)
            print ("Data provided from x_data_tokenized: ", x_data.shape)
            batch_size = x_data.shape[0]

        if (init_images is not None or inpaint_images is not None) and tokenizer_y is None:
            raise ValueError(
                "tokenizer_y is required when init_images or inpaint_images is provided"
            )

        if init_images is not None:
            print ("Init sequence provided...", init_images)
            init_images = tokenizer_y.texts_to_sequences(init_images)
            init_images = sequence.pad_sequences(
                init_images, maxlen=max_length, padding='post', truncating='post'
            )
            init_images = torch.from_numpy(init_images).float().to(device)/ynormfac
            print ("init_images=", init_images)

        if inpaint_images is not None:
            print ("Inpaint sequence provided...", inpaint_images)
            print ("Mask: ", inpaint_masks)
            inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images)
            inpaint_images = sequence.pad_sequences(
                inpaint_images, maxlen=max_length, padding='post', truncating='post'
            )
            inpaint_images = torch.from_numpy(inpaint_images).float().to(device)/ynormfac
            print ("in_paint images=", inpaint_images)

        if x is not None:
            x = self._embed_text(x, max_length, device)
            batch_size = x.shape[0]

        output = self.imagen.sample(
            text_embeds= x,
            cond_scale = cond_scale,
            stop_at_unet_number=stop_at_unet_number,
            cond_images=x_data,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,
            batch_size=batch_size,
            device=device,
        )
        return output
# ===========================================================
# final models: old model
# Changes: 1. the UNet is built outside this class and passed in as unet1
#
class ProteinDesigner_A_II(nn.Module):
    """"Model A" wrapper: text-conditioned ``ElucidatedImagen`` around ``unet1``.

    The UNet is constructed by the caller and passed in as ``unet1``.
    Hyper-parameters are read from ``PKeys``; each key listed below must be
    present, and a ``None`` value is replaced by the default shown:

        timesteps (10), pred_dim (25), loss_type (0), elucidated (True),
        text_embed_dim (512), embed_dim_position (32), max_text_len (16),
        max_length (64), device ('cuda:0')

    ``CKeys['Debug_ModelPack'] == 1`` turns on verbose printing.

    Args (constructor):
        unet1: UNet passed to ``ElucidatedImagen`` as ``unets``.
        CKeys: Control-key mapping; stored and forwarded to the imagen.
        PKeys: Parameter mapping described above; stored and forwarded.
    """

    def __init__(
        self,
        unet1,
        # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
        CKeys=None,
        PKeys=None,
    ):
        super(ProteinDesigner_A_II, self).__init__()
        # unload the arguments that this class actually uses
        timesteps          = default(PKeys['timesteps'], 10)
        pred_dim           = default(PKeys['pred_dim'], 25)
        loss_type          = default(PKeys['loss_type'], 0)
        elucidated         = default(PKeys['elucidated'], True)
        text_embed_dim     = default(PKeys['text_embed_dim'], 512)
        embed_dim_position = default(PKeys['embed_dim_position'], 32)
        max_text_len       = default(PKeys['max_text_len'], 16)
        max_length         = default(PKeys['max_length'], 64)
        device             = default(PKeys['device'], 'cuda:0')

        self.pred_dim = pred_dim
        self.loss_type = loss_type
        self.CKeys = CKeys
        self.PKeys = PKeys
        self.device = device
        assert (loss_type==0), "Loss other than MSE not implemented"

        self.fc_embed1 = nn.Linear(8, max_length)       # NOT USED
        self.fc_embed2 = nn.Linear(1, text_embed_dim)   # project the text into embedding dimensions
        self.max_text_len = max_text_len
        self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position)
        # The imagen sees the projected value concatenated with the
        # positional embedding.
        text_embed_dim = text_embed_dim+embed_dim_position
        # Position indices 1..max_text_len.
        self.pos_matrix_i = torch.arange(1, max_text_len+1, dtype=torch.long)
        if self.CKeys['Debug_ModelPack']==1:
            print("ModelA.pos_matrix_i: ", self.pos_matrix_i)

        assert elucidated , "Only elucidated model implemented...."
        self.is_elucidated = elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets = unet1,
                channels=self.pred_dim,
                channels_out=self.pred_dim ,
                loss_type=loss_type,
                text_embed_dim = text_embed_dim,
                image_sizes = [max_length],
                cond_drop_prob = 0.2,
                auto_normalize_img = False,
                num_sample_steps = timesteps,  # number of sample steps
                sigma_min = 0.002,   # min noise level
                sigma_max = 160,     # max noise level
                sigma_data = 0.5,    # standard deviation of data distribution
                rho = 7,             # controls the sampling schedule
                P_mean = -1.2,       # mean of log-normal distribution from which noise is drawn for training
                P_std = 1.2,         # standard deviation of log-normal distribution from which noise is drawn for training
                S_churn = 40,        # parameters for stochastic sampling - depends on dataset, Table 5 in paper
                S_tmin = 0.05,
                S_tmax = 50,
                S_noise = 1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to (self.device)
            # + for debug:
            if CKeys['Debug_ModelPack']==1:
                print("Check on EImagen:")
                print("channels: ", self.pred_dim)
                print("loss_type: ", loss_type)
                print("text_embed_dim: ",text_embed_dim)
                print("image_sizes: ", max_length)
                print("num_sample_steps: ", timesteps)
                print("Measure imagen:")
                params( self.imagen)
                print("Measure fc_embed2")
                params( self.fc_embed2)
                print("Measure pos_emb_x")
                params( self.pos_emb_x)
        else:
            print ("Not implemented.")

    # need to merge this with ModelB:forward
    def forward(
        self,
        output,
        x=None,
        #
        cond_images=None,
        unet_number=1,
    ):
        """Run ``self.imagen`` on ``output`` conditioned on ``x``.

        Args:
            output: First positional argument of ``self.imagen`` (the
                prediction target).
            x: 2-D conditioning tensor indexed as ``(batch, n)``. Required
                despite the ``None`` default.
            cond_images: Only printed in debug mode; it is NOT forwarded to
                ``self.imagen`` by this model.
            unet_number: Forwarded unchanged.

        Returns:
            The result of ``self.imagen(...)``.

        Raises:
            ValueError: If ``x`` is ``None``.
        """
        if x is None:
            raise ValueError("ProteinDesigner_A_II.forward requires the condition x")
        debug = self.CKeys['Debug_ModelPack']==1
        if debug:
            print("on Model A:forward")
            print("inputs:")
            print(' output.dim : ', output.shape)
            print(' x: ', x)
            print(' x.dim: ', x.shape)
            if cond_images is None:
                print(' cond_img: None ')
            else:
                print(' cond_img: ', cond_images.shape)
            print(' unet_num: ', unet_number)

        x = x.unsqueeze(2)
        if debug:
            print("After x.unsqueeze(2), x.dim: ", x.shape)
        x = self.fc_embed2(x)
        if debug:
            print("After fc_embed2(x), x.dim: ", x.shape)
            print()

        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to (self.device)
        if debug:
            print("pos_matrix_i_.dim: ", pos_matrix_i_.shape)
            print("pos_matrix_i_: ", pos_matrix_i_)
            print()
        pos_emb_x = self.pos_emb_x(pos_matrix_i_)
        if debug:
            print("After pos_emb_x(pos_matrix_i_), pos_emb_x.dim: ", pos_emb_x.shape)
            print("pos_emb_x: ", pos_emb_x)
            print()
        pos_emb_x = torch.squeeze(pos_emb_x, 1)
        # Keep only the positions that x actually provides.
        pos_emb_x = pos_emb_x[:, :x.shape[1], :]
        if debug:
            print("after operations, pos_emb_x.dim: ", pos_emb_x.shape)
            print("pos_emb_x: ", pos_emb_x)
            print()

        x = torch.cat((x, pos_emb_x), 2)
        if debug:
            print("after cat((x,pos_emb_x),2)=>x dim: ", x.shape, "Batch x max_text_len x (text_embed_dim+embed_dim_position)")
            print()
            print("Now, get into self.imagen part...")

        loss = self.imagen(
            output,
            text_embeds = x,
            unet_number = unet_number,
        )
        return loss

    def sample (
        self,
        x=None,
        stop_at_unet_number=1,
        cond_scale=7.5,
        # ++
        x_data=None, # image_condi data
        skip_steps=None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        x_data_tokenized=None,
        tokenizer_X=None,
        Xnormfac=None,
        # -+
        device=None,
        max_length=None, # for XandY data, in image/sequence format; NOT for text condition
        max_text_len=None, # for X data, in text format
        # ++ needed by the init_images / inpaint_images branches
        tokenizer_y=None,
        ynormfac=1.,
    ):
        """Sample from ``self.imagen`` with optional conditioning.

        Args:
            x: Optional 2-D tokenized text condition; zero-padded to
                ``max_text_len`` and embedded as in ``forward``.
            x_data: Optional raw sequences; tokenized with ``tokenizer_X``,
                post-padded to ``max_length``, divided by ``Xnormfac`` and
                given a singleton channel axis at dim 1.
            x_data_tokenized: Optional pre-tokenized 2-D tensor; given a
                singleton channel axis. Overrides ``x_data`` if both are set.
            init_images, inpaint_images: Optional sequences tokenized with
                ``tokenizer_y``, post-padded and divided by ``ynormfac``.
            device: Accepted for signature parity with Model B; this method
                uses ``self.device`` instead.
            max_text_len: Padding length for ``x``; falls back to
                ``self.max_text_len`` when ``None``.
            Remaining arguments are forwarded to ``self.imagen.sample``.

        Returns:
            The result of ``self.imagen.sample(...)``.

        Raises:
            ValueError: If ``init_images`` or ``inpaint_images`` is given
                without ``tokenizer_y``.
        """
        # Default used when no conditioning input fixes the batch size.
        batch_size = 1

        if x_data is not None:  # condition via images
            print ("Conditioning target sequence provided via sequence/image in x_data ...", x_data)
            x_data = tokenizer_X.texts_to_sequences(x_data)
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length, padding='post', truncating='post'
            )
            x_data = torch.from_numpy(x_data).float().to(self.device)
            x_data = x_data/Xnormfac
            # (batch, seq_len) -> (batch, 1, seq_len)
            x_data = x_data.unsqueeze(1)
            print ("x_data from target sequence=", x_data, x_data.shape)
            batch_size = x_data.shape[0]

        if x_data_tokenized is not None:
            print ("Conditioning target sequence provided via processed sequence/image in x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape)
            # (batch, seq_len) -> (batch, 1, seq_len)
            x_data = x_data_tokenized.unsqueeze(1).to(self.device)
            print ("Data provided from x_data_tokenized: ", x_data.shape)
            batch_size = x_data.shape[0]

        if (init_images is not None or inpaint_images is not None) and tokenizer_y is None:
            raise ValueError(
                "tokenizer_y is required when init_images or inpaint_images is provided"
            )

        if init_images is not None:  # initial sequence provided via init_images
            print ("Init sequence provided...", init_images)
            init_images = tokenizer_y.texts_to_sequences(init_images)
            init_images = sequence.pad_sequences(
                init_images, maxlen=max_length, padding='post', truncating='post'
            )
            init_images = torch.from_numpy(init_images).float().to(self.device)/ynormfac
            print ("init_images=", init_images)

        if inpaint_images is not None:  # inpainting model
            print ("Inpaint sequence provided...", inpaint_images)
            print ("Mask: ", inpaint_masks)
            inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images)
            inpaint_images = sequence.pad_sequences(
                inpaint_images, maxlen=max_length, padding='post', truncating='post'
            )
            inpaint_images = torch.from_numpy(inpaint_images).float().to(self.device)/ynormfac
            print ("in_paint images=", inpaint_images)

        # Model A type: text conditioning
        if x is not None:
            print ("Conditioning target sequence via tokenized text in x ...", x)
            if max_text_len is None:
                max_text_len = self.max_text_len
            if self.CKeys['Debug_ModelPack']==1:
                print("x.dim: ", x.shape)
                print("max_text_len: ", max_text_len)
                print("self.device: ",self.device)
            x_in = torch.zeros((x.shape[0], max_text_len)).to(self.device)
            x_in[:, :x.shape[1]] = x
            x = self.fc_embed2(x_in.unsqueeze(2))
            pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
            pos_emb_x = self.pos_emb_x(pos_matrix_i_)
            pos_emb_x = torch.squeeze(pos_emb_x, 1)
            # Keep only the positions that x actually provides.
            pos_emb_x = pos_emb_x[:, :x.shape[1], :]
            x = torch.cat((x, pos_emb_x), 2)
            batch_size = x.shape[0]

        output = self.imagen.sample(
            text_embeds= x,
            cond_scale= cond_scale,
            stop_at_unet_number=stop_at_unet_number,
            cond_images=x_data,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,
            batch_size=batch_size,
            device=self.device,
        )
        return output
# ===========================================================
# final models: old model
# Changes: 1. replace the OneD_UNet part
#
class ProteinDesigner_A_I(nn.Module):
    """Conditioned 1D diffusion model: ``OneD_Unet`` wrapped in ``ElucidatedImagen``.

    The conditioning tensor ``x`` is embedded entry by entry with a linear
    layer (``fc_embed2``), concatenated with a learned positional embedding
    (``pos_emb_x``) along the feature axis, and passed to the diffusion model
    as ``text_embeds``.

    Args:
        CKeys: Control-key mapping. ``CKeys['Debug_ModelPack'] == 1`` enables
            a parameter report for the U-Net. The mapping is also forwarded
            unchanged to ``OneD_Unet`` and ``ElucidatedImagen``.
        PKeys: Hyper-parameter mapping, forwarded unchanged to
            ``ElucidatedImagen``. Keys read here, with the fallback used when
            a key is absent (present values go through the file's ``default``
            helper): ``timesteps`` (10), ``dim`` (32), ``pred_dim`` (25),
            ``loss_type`` (0), ``elucidated`` (True), ``cond_dim`` (512),
            ``text_embed_dim`` (512), ``embed_dim_position`` (32),
            ``max_text_len`` (16), ``max_length`` (64), ``device``
            (``'cuda:0'``). The keys ``padding_idx``, ``input_tokens``,
            ``sequence_embed`` and ``cond_images_channels`` are not used by
            this constructor itself.

    Raises:
        AssertionError: If ``loss_type`` is not 0 or ``elucidated`` is falsy.
    """

    def __init__(
        self,
        CKeys=None,
        PKeys=None,
    ):
        super(ProteinDesigner_A_I, self).__init__()

        # Use .get so that a missing key reaches the fallback instead of
        # raising KeyError before `default` is even called.
        pkeys = PKeys if PKeys is not None else {}

        def _opt(key, fallback):
            return default(pkeys.get(key), fallback)

        timesteps = _opt('timesteps', 10)
        dim = _opt('dim', 32)
        pred_dim = _opt('pred_dim', 25)
        loss_type = _opt('loss_type', 0)
        elucidated = _opt('elucidated', True)
        cond_dim = _opt('cond_dim', 512)
        text_embed_dim = _opt('text_embed_dim', 512)
        embed_dim_position = _opt('embed_dim_position', 32)
        max_text_len = _opt('max_text_len', 16)
        max_length = _opt('max_length', 64)
        device = _opt('device', 'cuda:0')

        # Raised explicitly so the checks survive `python -O`; the exception
        # type and messages match the former assert statements.
        if loss_type != 0:
            raise AssertionError("Loss other than MSE not implemented")
        if not elucidated:
            raise AssertionError("Only elucidated model implemented....")

        self.CKeys = CKeys
        self.PKeys = PKeys
        self.device = device
        self.pred_dim = pred_dim
        self.loss_type = loss_type

        # fc_embed1 is not used by forward/sample; it stays registered so the
        # module's parameter set is unchanged.
        self.fc_embed1 = nn.Linear(8, max_length)
        # Lifts each scalar conditioning entry to text_embed_dim features.
        self.fc_embed2 = nn.Linear(1, text_embed_dim)
        self.max_text_len = max_text_len
        # Position ids run from 1 to max_text_len, hence max_text_len + 1 rows.
        self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
        # The diffusion model sees value features and position features
        # concatenated along the last axis.
        text_embed_dim = text_embed_dim + embed_dim_position
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        # Settings handed to prepare_UNet_keys to build the U-Net key set.
        write_PK_UNet = {
            'dim': dim,
            'text_embed_dim': text_embed_dim,
            'cond_dim': cond_dim,
            'dim_mults': (1, 2, 4, 8),
            'num_resnet_blocks': 1,
            'layer_attns': (False, True, True, False),
            'layer_cross_attns': (False, True, True, False),
            'channels': pred_dim,
            'channels_out': pred_dim,
            'attn_dim_head': 64,
            'attn_heads': 8,
            'ff_mult': 2.,
            # for cascading diffusion - https://cascaded-diffusion.github.io/
            'lowres_cond': False,
            'layer_attns_depth': 1,
            # whether to condition the self-attention blocks with the text
            # embeddings, as described in Appendix D.3.1
            'layer_attns_add_text_cond': True,
            # whether to have a layer of attention at the bottleneck
            'attend_at_middle': True,
            'use_linear_attn': False,
            'use_linear_cross_attn': False,
            'cond_on_text': True,
            'max_text_len': max_length,
            'init_dim': None,
            'resnet_groups': 8,
            # kernel size of initial conv, if not using cross embed
            'init_conv_kernel_size': 7,
            # TODO - fix output size calcs for conv1d
            'init_cross_embed': False,
            'init_cross_embed_kernel_sizes': (3, 7, 15),
            'cross_embed_downsample': False,
            'cross_embed_downsample_kernel_sizes': (2, 4),
            'attn_pool_text': True,
            # perceiver model latents
            'attn_pool_num_latents': 32,
            'dropout': 0.,
            'memory_efficient': False,
            'init_conv_to_final_conv_residual': False,
            'use_global_context_attn': True,
            'scale_skip_connection': True,
            'final_resnet_block': True,
            'final_conv_kernel_size': 3,
            'cosine_sim_attn': True,
            'self_cond': False,
            # combine feature maps from all upsample blocks
            'combine_upsample_fmaps': True,
            # may address checkerboard artifacts
            'pixel_shuffle_upsample': False,
        }
        Unet_PKeys = prepare_UNet_keys(write_PK_UNet)

        unet1 = OneD_Unet(
            CKeys=CKeys,
            PKeys=Unet_PKeys,
        ).to(self.device)
        if CKeys['Debug_ModelPack'] == 1:
            print('Check unet generated...')
            params(unet1)

        self.is_elucidated = elucidated
        self.imagen = ElucidatedImagen(
            unets=unet1,  # a single U-Net object, as before (not a tuple)
            channels=self.pred_dim,
            channels_out=self.pred_dim,
            loss_type=loss_type,
            text_embed_dim=text_embed_dim,
            image_sizes=[max_length],
            cond_drop_prob=0.2,
            auto_normalize_img=False,
            num_sample_steps=timesteps,  # number of sample steps
            sigma_min=0.002,  # min noise level
            sigma_max=160,  # max noise level
            sigma_data=0.5,  # standard deviation of data distribution
            rho=7,  # controls the sampling schedule
            P_mean=-1.2,  # mean of log-normal distribution from which noise is drawn for training
            P_std=1.2,  # standard deviation of log-normal distribution from which noise is drawn for training
            S_churn=40,  # parameters for stochastic sampling - depends on dataset, Table 5 in paper
            S_tmin=0.05,
            S_tmax=50,
            S_noise=1.003,
            CKeys=self.CKeys,
            PKeys=self.PKeys,
        ).to(self.device)

    def _embed_conditioning(self, x):
        """Build the ``text_embeds`` tensor shared by ``forward`` and ``sample``.

        Args:
            x: Conditioning tensor; expected to be 2D ``(batch, n)`` with
                ``n`` no larger than the length of ``self.pos_matrix_i``,
                otherwise the concatenation below fails on mismatched sizes.

        Returns:
            Tensor ``(batch, n, F + P)``: ``fc_embed2`` features followed by
            positional-embedding features along the last axis.
        """
        x = self.fc_embed2(x.unsqueeze(2))
        n_positions = x.shape[1]
        pos_ids = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
        # Keep only the positions actually present in x. No squeeze is applied:
        # squeezing axis 1 would drop the sequence axis when it has length 1.
        pos_emb = self.pos_emb_x(pos_ids)[:, :n_positions, :]
        return torch.cat((x, pos_emb), 2)

    def forward(self, x, output, unet_number=1):
        """Run the training call of the diffusion model.

        Args:
            x: Conditioning tensor (see ``_embed_conditioning``).
            output: Prediction target, passed as first argument to
                ``self.imagen``.
            unet_number: Forwarded to ``self.imagen``.

        Returns:
            The value returned by ``self.imagen`` (treated as the loss).
        """
        loss = self.imagen(
            output,
            text_embeds=self._embed_conditioning(x),
            unet_number=unet_number,
        )
        return loss

    def sample(self, x, stop_at_unet_number=1, cond_scale=7.5,):
        """Sample from the diffusion model conditioned on ``x``.

        Args:
            x: Conditioning tensor (see ``_embed_conditioning``).
            stop_at_unet_number: Forwarded to ``self.imagen.sample``.
            cond_scale: Forwarded to ``self.imagen.sample``.

        Returns:
            The value returned by ``self.imagen.sample``.
        """
        output = self.imagen.sample(
            text_embeds=self._embed_conditioning(x),
            cond_scale=cond_scale,
            stop_at_unet_number=stop_at_unet_number,
        )
        return output
# ===========================================================
# final models: old model
#
class ProteinDesigner_A_Old(nn.Module):
    """Conditioned 1D diffusion model: ``OneD_Unet_Old`` wrapped in ``ElucidatedImagen``.

    The conditioning tensor ``x`` is embedded entry by entry with a linear
    layer (``fc_embed2``), concatenated with a learned positional embedding
    (``pos_emb_x``) along the feature axis, and passed to the diffusion model
    as ``text_embeds``.

    Args:
        timesteps: Passed to ``ElucidatedImagen`` as ``num_sample_steps``.
        dim: Base width passed to the U-Net.
        pred_dim: Used as both ``channels`` and ``channels_out``.
        loss_type: Must be 0.
        elucidated: Must be truthy. Note that the default is ``False``, so
            callers have to pass ``elucidated=True`` explicitly.
        padding_idx: Accepted but not used by this class.
        cond_dim: Passed to the U-Net.
        text_embed_dim: Output width of ``fc_embed2``; the U-Net and
            ``ElucidatedImagen`` receive ``text_embed_dim + embed_dim_position``.
        input_tokens: Accepted but not used by this class.
        sequence_embed: Accepted but not used by this class.
        embed_dim_position: Width of the positional embedding.
        max_text_len: Number of position ids (1..max_text_len).
        device: Device the sub-modules are moved to.
        max_length: Passed as ``image_sizes=[max_length]`` and as the U-Net's
            ``max_text_len``.
        CKeys: Control-key mapping; ``CKeys['Debug_ModelPack'] == 1`` enables
            debug printouts. Forwarded to the U-Net and ``ElucidatedImagen``.
        PKeys: Forwarded unchanged to ``ElucidatedImagen``.

    Raises:
        AssertionError: If ``loss_type`` is not 0 or ``elucidated`` is falsy.
    """

    def __init__(
        self,
        timesteps=10,
        dim=32,
        pred_dim=25,
        loss_type=0,
        elucidated=False,
        padding_idx=0,
        cond_dim=512,
        text_embed_dim=512,
        input_tokens=25,  # for non-BERT
        sequence_embed=False,
        embed_dim_position=32,
        max_text_len=16,
        device='cuda:0',
        # ++
        max_length=64,
        CKeys=None,
        PKeys=None,
    ):
        super(ProteinDesigner_A_Old, self).__init__()

        # Raised explicitly so the checks survive `python -O`; the exception
        # type and messages match the former assert statements.
        if loss_type != 0:
            raise AssertionError("Loss other than MSE not implemented")
        if not elucidated:
            raise AssertionError("Only elucidated model implemented....")

        self.CKeys = CKeys
        self.PKeys = PKeys
        self.device = device
        self.pred_dim = pred_dim
        self.loss_type = loss_type

        # fc_embed1 is not used by forward/sample; it stays registered so the
        # module's parameter set is unchanged.
        self.fc_embed1 = nn.Linear(8, max_length)
        # Lifts each scalar conditioning entry to text_embed_dim features.
        self.fc_embed2 = nn.Linear(1, text_embed_dim)
        self.max_text_len = max_text_len
        # Position ids run from 1 to max_text_len, hence max_text_len + 1 rows.
        self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
        # The diffusion model sees value features and position features
        # concatenated along the last axis.
        text_embed_dim = text_embed_dim + embed_dim_position
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        unet1 = OneD_Unet_Old(
            dim=dim,
            text_embed_dim=text_embed_dim,
            cond_dim=cond_dim,
            dim_mults=(1, 2, 4, 8),
            num_resnet_blocks=1,
            layer_attns=(False, True, True, False),
            layer_cross_attns=(False, True, True, False),
            channels=self.pred_dim,
            channels_out=self.pred_dim,
            #
            attn_dim_head=64,
            attn_heads=8,
            ff_mult=2.,
            lowres_cond=False,  # for cascading diffusion - https://cascaded-diffusion.github.io/
            layer_attns_depth=1,
            layer_attns_add_text_cond=True,  # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
            attend_at_middle=True,  # whether to have a layer of attention at the bottleneck
            use_linear_attn=False,
            use_linear_cross_attn=False,
            cond_on_text=True,
            max_text_len=max_length,  # NOTE: needs to be checked
            init_dim=None,
            resnet_groups=8,
            init_conv_kernel_size=7,  # kernel size of initial conv, if not using cross embed
            init_cross_embed=False,  # TODO - fix output size calcs for conv1d
            init_cross_embed_kernel_sizes=(3, 7, 15),
            cross_embed_downsample=False,
            cross_embed_downsample_kernel_sizes=(2, 4),
            attn_pool_text=True,
            attn_pool_num_latents=32,  # perceiver model latents
            dropout=0.,
            memory_efficient=False,
            init_conv_to_final_conv_residual=False,
            use_global_context_attn=True,
            scale_skip_connection=True,
            final_resnet_block=True,
            final_conv_kernel_size=3,
            cosine_sim_attn=True,
            self_cond=False,
            combine_upsample_fmaps=True,  # combine feature maps from all upsample blocks
            pixel_shuffle_upsample=False,  # may address checkerboard artifacts
            # ++
            CKeys=CKeys,
        ).to(self.device)

        # print the size
        if CKeys['Debug_ModelPack'] == 1:
            print("Check NUnet...")
            params(unet1)

        self.is_elucidated = elucidated
        self.imagen = ElucidatedImagen(
            unets=unet1,  # a single U-Net object, as before (not a tuple)
            channels=self.pred_dim,
            channels_out=self.pred_dim,
            loss_type=loss_type,
            text_embed_dim=text_embed_dim,
            image_sizes=[max_length],
            cond_drop_prob=0.2,
            auto_normalize_img=False,
            num_sample_steps=timesteps,  # number of sample steps
            sigma_min=0.002,  # min noise level
            sigma_max=160,  # max noise level
            sigma_data=0.5,  # standard deviation of data distribution
            rho=7,  # controls the sampling schedule
            P_mean=-1.2,  # mean of log-normal distribution from which noise is drawn for training
            P_std=1.2,  # standard deviation of log-normal distribution from which noise is drawn for training
            S_churn=40,  # parameters for stochastic sampling - depends on dataset, Table 5 in paper
            S_tmin=0.05,
            S_tmax=50,
            S_noise=1.003,
            # ++
            CKeys=self.CKeys,
            PKeys=self.PKeys,
        ).to(self.device)

        if CKeys['Debug_ModelPack'] == 1:
            print("Check on EImagen:")
            print("channels: ", self.pred_dim)
            print("loss_type: ", loss_type)
            print("text_embed_dim: ", text_embed_dim)
            print("image_sizes: ", max_length)
            print("num_sample_steps: ", timesteps)
            print("Measure imagen:")
            params(self.imagen)
            print("Measure fc_embed2")
            params(self.fc_embed2)
            print("Measure pos_emb_x")
            params(self.pos_emb_x)

    def _embed_conditioning(self, x):
        """Build the ``text_embeds`` tensor shared by ``forward`` and ``sample``.

        Args:
            x: Conditioning tensor; expected to be 2D ``(batch, n)`` with
                ``n`` no larger than the length of ``self.pos_matrix_i``,
                otherwise the concatenation below fails on mismatched sizes.

        Returns:
            Tensor ``(batch, n, F + P)``: ``fc_embed2`` features followed by
            positional-embedding features along the last axis.
        """
        x = self.fc_embed2(x.unsqueeze(2))
        n_positions = x.shape[1]
        pos_ids = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
        # Keep only the positions actually present in x. No squeeze is applied:
        # squeezing axis 1 would drop the sequence axis when it has length 1.
        pos_emb = self.pos_emb_x(pos_ids)[:, :n_positions, :]
        return torch.cat((x, pos_emb), 2)

    def forward(self, x, output, unet_number=1):
        """Run the training call of the diffusion model.

        Args:
            x: Conditioning tensor (see ``_embed_conditioning``).
            output: Prediction target, passed as first argument to
                ``self.imagen``.
            unet_number: Forwarded to ``self.imagen``.

        Returns:
            The value returned by ``self.imagen`` (treated as the loss).
        """
        loss = self.imagen(
            output,
            text_embeds=self._embed_conditioning(x),
            unet_number=unet_number,
        )
        return loss

    def sample(self, x, stop_at_unet_number=1, cond_scale=7.5,):
        """Sample from the diffusion model conditioned on ``x``.

        Args:
            x: Conditioning tensor (see ``_embed_conditioning``).
            stop_at_unet_number: Forwarded to ``self.imagen.sample``.
            cond_scale: Forwarded to ``self.imagen.sample``.

        Returns:
            The value returned by ``self.imagen.sample``.
        """
        output = self.imagen.sample(
            text_embeds=self._embed_conditioning(x),
            cond_scale=cond_scale,
            stop_at_unet_number=stop_at_unet_number,
        )
        return output