"""
att_uncontrol9_adam以及之前的都是用这个
"""
import numpy as np
import torch
import torch.nn as nn
import math
class SubPixelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, scale_factor=2):
super(SubPixelConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels * scale_factor ** 2, kernel_size, stride,
padding=kernel_size // 2)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
return x
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
# swish
return x * torch.sigmoid(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(self, channels, num_heads=-1, use_checkpoint=False):
super().__init__()
self.channels = channels
self.num_heads = num_heads if num_heads != -1 else min(channels // 32, 8)
self.use_checkpoint = use_checkpoint
self.norm = nn.GroupNorm(16, channels, eps=1e-6)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.attention = QKVAttention()
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
def forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
h = self.attention(qkv)
h = h.reshape(b, -1, h.shape[-1])
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention.
"""
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
:return: an [N x C x T] tensor after attention.
"""
ch = qkv.shape[1] // 3
q, k, v = torch.split(qkv, ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
return torch.einsum("bts,bcs->bct", weight, v)
@staticmethod
def count_flops(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += torch.DoubleTensor([matmul_ops])
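# Hedged usage sketch (not from the original code): profiling one AttentionBlock with the
# `thop` package and the custom FLOP counter above; the channel count and input shape are
# illustrative assumptions. Kept commented out so the module has no import-time side effects.
# import thop
# blk = AttentionBlock(channels=64)
# macs, params = thop.profile(
#     blk,
#     inputs=(torch.randn(1, 64, 16, 16),),
#     custom_ops={QKVAttention: QKVAttention.count_flops},
# )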
# ====================================================================
class TEncoder(nn.Module):
def __init__(self, out_c=256, scale=30.):
super(TEncoder, self).__init__()
        # random Fourier projection; only used by the commented-out variant in forward
self.out_c = out_c
self.W = nn.Parameter(torch.randn(out_c // 2) * scale, requires_grad=False)
self.linear = nn.Sequential(nn.Linear(out_c, out_c),
Swish(),
nn.Linear(out_c, out_c),
)
def timestep_embedding(self, timesteps, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = self.out_c // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if self.out_c % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
return embedding
def forward(self, t):
# t_proj = t * self.W[None, :] * 2 * np.pi
# t_proj = torch.cat((torch.sin(t_proj), torch.cos(t_proj)), dim=-1)
        t_proj = self.timestep_embedding(t)[:, 0, :]  # t is expected as [N, 1]; drop the singleton axis -> [N, out_c]
encoded_t = self.linear(t_proj)
return encoded_t
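# Hedged shape check for TEncoder (an illustrative assumption, kept as a comment so nothing
# runs at import time); it assumes timesteps arrive as an [N, 1] tensor as in UNet.forward:
# t_enc = TEncoder(out_c=256)
# print(t_enc(torch.rand(4, 1)).shape)  # -> torch.Size([4, 256])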
class EncoderBlock(nn.Module):
def __init__(self, in_c, out_c, kernel_size, stride, t_in_c, att_num_head=-1, block_deep=4):
super(EncoderBlock, self).__init__()
self.in_c = in_c
self.out_c = out_c
self.stride = stride
        self.model_list_len = block_deep  # number of conv layers in this block
padding = kernel_size // 2
self.model_list = nn.ModuleList()
self.model_list.append(nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding),
nn.GroupNorm(16, out_c, eps=1e-6),
Swish()))
        if att_num_head != 0:  # 0 disables attention; -1 = automatic head count (see AttentionBlock)
self.att_block = AttentionBlock(out_c, num_heads=att_num_head)
else:
self.att_block = nn.Identity()
        for _ in range(self.model_list_len - 2):  # -2: the first and last layers are added separately
self.model_list.append(
nn.Sequential(
nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
padding=padding),
nn.GroupNorm(16, out_c, eps=1e-6),
Swish(),
))
self.model_list.append(
nn.Sequential(
nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
padding=padding),
nn.GroupNorm(16, out_c, eps=1e-6),
))
        # per-layer projections of the time embedding t
self.encode_t = nn.ModuleList(
[nn.Linear(t_in_c, out_c) for _ in range(len(self.model_list) - 1)])
if self.in_c != self.out_c or self.stride != 1:
self.conv_skip = nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0)
else:
self.conv_skip = nn.Identity()
self.act_skip = Swish()
def forward(self, x, t):
skip = self.conv_skip(x)
for i, layer in enumerate(self.model_list):
x = layer(x)
if i == 0:
x = self.att_block(x)
if i < self.model_list_len - 1:
t_ = self.encode_t[i](t)
# t_ = torch.tile(t[:, :, None, None], dims=[1, 1, x.shape[2], x.shape[3]])
t_ = t_[:, :, None, None]
x = x + t_
return self.act_skip(x + skip)
class DecoderBlock(nn.Module):
def __init__(self, in_c, out_c, kernel_size, upsample="none", t_in_c=256, att_num_head=-1, block_deep=4):
super(DecoderBlock, self).__init__()
self.in_c = in_c
self.out_c = out_c
        self.model_list_len = block_deep  # number of conv layers in this block
self.model_list = nn.ModuleList()
if upsample == "subpix":
self.model_list.append(nn.Sequential(
SubPixelConv(in_c, out_c, kernel_size=3),
nn.GroupNorm(16, out_c, eps=1e-6),
Swish()
))
self.upsample = SubPixelConv(in_c, in_c, kernel_size=3)
elif upsample == "convt":
self.model_list.append(nn.Sequential(
nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
nn.GroupNorm(16, out_c, eps=1e-6),
Swish()
))
self.upsample = nn.ConvTranspose2d(in_c, in_c, kernel_size=4, stride=2, padding=1)
else:
self.model_list.append(nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=1,
padding=kernel_size // 2),
nn.GroupNorm(16, out_c, eps=1e-6),
Swish()
))
self.upsample = nn.Identity()
        if att_num_head != 0:  # 0 disables attention; -1 = automatic head count (see AttentionBlock)
self.att_block = AttentionBlock(out_c, num_heads=att_num_head)
else:
self.att_block = nn.Identity()
for _ in range(self.model_list_len - 2):
self.model_list.append(nn.Sequential(nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
padding=kernel_size // 2),
nn.GroupNorm(16, out_c, eps=1e-6),
Swish()))
self.model_list.append(nn.Sequential(nn.Conv2d(out_c, out_c, kernel_size=kernel_size, stride=1,
padding=kernel_size // 2),
nn.GroupNorm(16, out_c, eps=1e-6)))
        # per-layer projections of the time embedding t
self.encode_t = nn.ModuleList([nn.Linear(t_in_c, out_c) for _ in range(len(self.model_list) - 1)])
self.conv_skip = nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0)
self.act_skip = Swish()
def forward(self, x, t):
skip = self.upsample(x)
skip = self.conv_skip(skip)
for i, layer in enumerate(self.model_list):
x = layer(x)
if i == 0:
x = self.att_block(x)
if i < self.model_list_len - 1:
t_ = self.encode_t[i](t)
# t_ = torch.tile(t[:, :, None, None], dims=[1, 1, x.shape[2], x.shape[3]])
t_ = t_[:, :, None, None]
x = x + t_
return self.act_skip(x + skip)
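# Note: with upsample="convt" or "subpix" the first layer of a DecoderBlock doubles the spatial
# resolution and the residual branch is upsampled with a matching operator before its 1x1 conv,
# so the skip connection stays shape-compatible; with "none" both branches keep the input size.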
class Encoder(nn.Module):
def __init__(self,
model_in_c=8,
out_cs=(64, 64, 128, 128, 256, 256, 512, 512),
down_sample=(0, 0, 1, 0, 1, 0, 1, 0),
skip_out=(0, 1, 0, 1, 0, 1, 0, 1),
att_num_heads=(-1, -1, -1, -1, -1, -1, -1, -1),
t_in_c=256,
block_deep=4):
"""
        :param out_cs: output channels of each block
        :param down_sample: whether each block downsamples (1 = stride 2, 0 = stride 1)
        :param skip_out: which block outputs are kept as U-Net skip connections
"""
super(Encoder, self).__init__()
self.skip_out = skip_out
self.model_list = nn.ModuleList()
for i, (out_c, down, att_num_head) in enumerate(zip(out_cs, down_sample, att_num_heads)):
in_c = model_in_c if i == 0 else out_cs[i - 1]
self.model_list.append(
EncoderBlock(in_c, out_cs[i], kernel_size=3, stride=down + 1, t_in_c=t_in_c,
att_num_head=att_num_head, block_deep=block_deep))
def forward(self, x, t):
res_x = []
for i, layer in enumerate(self.model_list):
x = layer(x, t)
if self.skip_out[i] == 1:
res_x.append(x)
return res_x
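# With the Encoder defaults above (out_cs=(64, 64, 128, 128, 256, 256, 512, 512),
# down_sample=(0, 0, 1, 0, 1, 0, 1, 0), skip_out=(0, 1, 0, 1, 0, 1, 0, 1)) the encoder returns
# four skip tensors with 64, 128, 256 and 512 channels at 1x, 1/2, 1/4 and 1/8 of the input
# resolution; the Decoder below pops them in reverse order wherever its skip_out entry is 1.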
class Decoder(nn.Module):
def __init__(self,
in_c,
model_out_c=8,
out_cs=(512, 256, 256, 128, 128, 64, 64, 32),
up_sample=("none", "convt", "none", "subpix", "none", "subpix", "none", "none"),
skip_out=(1, 0, 1, 0, 1, 0, 1, 0),
att_num_heads=(-1, -1, -1, -1, -1, -1, -1, -1),
t_in_c=256,
block_deep=4):
"""
        :param out_cs: output channels of each block
        :param up_sample: upsampling method for each block; "none" means no upsampling
        :param skip_out: which blocks receive (concatenate) a U-Net skip connection
"""
super(Decoder, self).__init__()
self.skip_out = skip_out
self.model_list = nn.ModuleList()
for i, (out_c, up, att_num_head) in enumerate(zip(out_cs, up_sample, att_num_heads)):
if self.skip_out[i] == 1 and i > 0:
in_c *= 2
self.model_list.append(
DecoderBlock(in_c, out_cs[i], kernel_size=3, upsample=up, t_in_c=t_in_c,
att_num_head=att_num_head, block_deep=block_deep))
in_c = out_cs[i]
self.Conv1 = nn.Conv2d(out_cs[-1], model_out_c, kernel_size=1, stride=1, padding=0)
def forward(self, x, t):
x_list = x
# print([xx.shape for xx in x_list])
x = None
for i, layer in enumerate(self.model_list):
if self.skip_out[i] == 1:
# print("skip_x:", x_list[-1].shape)
if i == 0:
x = x_list.pop()
else:
x = torch.cat([x, x_list.pop()], dim=1)
# print("x:", x.shape)
x = layer(x, t)
x = self.Conv1(x)
return x
class UNet(nn.Module):
def __init__(self,
en_out_c,
en_down,
en_skip,
en_att_heads,
de_out_c,
de_up,
de_skip,
de_att_heads,
t_out_c,
vae_c=8,
block_deep=4):
"""
        :param en_out_c: encoder block output channels
        :param en_down: encoder downsampling flags (1 = stride 2)
        :param en_skip: encoder skip-connection flags
        :param en_att_heads: attention heads per encoder block (0 disables, -1 = automatic)
        :param de_out_c: decoder block output channels
        :param de_up: decoder upsampling methods ("none", "convt" or "subpix")
        :param de_skip: decoder skip-connection flags
        :param de_att_heads: attention heads per decoder block (0 disables, -1 = automatic)
        :param t_out_c: dimension of the time embedding
        :param vae_c: number of input/output (latent) channels
        :param block_deep: number of conv layers per block
"""
super(UNet, self).__init__()
self.encoder = Encoder(model_in_c=vae_c,
out_cs=en_out_c,
down_sample=en_down,
skip_out=en_skip,
att_num_heads=en_att_heads,
t_in_c=t_out_c,
block_deep=block_deep)
self.decoder = Decoder(in_c=en_out_c[-1],
model_out_c=vae_c,
out_cs=de_out_c,
up_sample=de_up,
skip_out=de_skip,
att_num_heads=de_att_heads,
t_in_c=t_out_c,
block_deep=block_deep)
self.t_encoder = TEncoder(t_out_c)
def forward(self, x, t):
t = self.t_encoder(t)
# print("encoded_t:", torch.mean(t), torch.std(t))
# print("t:", t.shape)
encoder_out = self.encoder(x, t)
# print("encode:")
# for e in encoder_out:
# print(e.shape)
decoder_out = self.decoder(encoder_out, t)
# print("decoder:")
# print(decoder_out.shape)
return decoder_out
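# Hedged end-to-end smoke test: builds the UNet with an illustrative, self-consistent
# configuration (an assumption, not the att_uncontrol9_adam setup) and runs one forward pass.
if __name__ == "__main__":
    net = UNet(en_out_c=(64, 64, 128, 128, 256, 256, 512, 512),
               en_down=(0, 0, 1, 0, 1, 0, 1, 0),
               en_skip=(0, 1, 0, 1, 0, 1, 0, 1),
               en_att_heads=(0, 0, 0, 0, 0, 0, -1, -1),
               de_out_c=(512, 256, 256, 128, 128, 64, 64, 32),
               de_up=("none", "convt", "none", "subpix", "none", "subpix", "none", "none"),
               de_skip=(1, 0, 1, 0, 1, 0, 1, 0),
               de_att_heads=(-1, 0, 0, 0, 0, 0, 0, 0),
               t_out_c=256,
               vae_c=8,
               block_deep=4)
    x = torch.randn(2, 8, 32, 32)                # latent input: [batch, vae_c, H, W]
    t = torch.randint(0, 1000, (2, 1)).float()   # timesteps as an [N, 1] tensor
    print(net(x, t).shape)                       # expected: torch.Size([2, 8, 32, 32])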