import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Parameter

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

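# Minimal shape-check sketch (assumed usage, not part of the original module):
# SinusoidalPosEmb maps a batch of scalar timesteps to a `dim`-dimensional
# embedding, concatenating sine and cosine halves. Wrapped in a helper so that
# importing this file stays side-effect free.
def _demo_sinusoidal_pos_emb():
    t = torch.arange(16, dtype=torch.float32)  # 16 example diffusion timesteps
    emb = SinusoidalPosEmb(128)(t)             # -> [16, 128]
    assert emb.shape == (16, 128)
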
class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g

# building block modules
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        if groups == 0:
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim_out, 3),
                Mish()
            )
        else:
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim_out, 3),
                nn.GroupNorm(groups, dim_out),
                Mish()
            )

    def forward(self, x):
        return self.block(x)

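# Minimal shape-check sketch (assumed usage, not part of the original module):
# Block keeps the spatial size (reflection padding + 3x3 conv) and changes the
# channel count; with groups > 0, dim_out must be divisible by `groups` for
# nn.GroupNorm to accept it.
def _demo_block():
    x = torch.randn(2, 32, 16, 16)
    y = Block(32, 64, groups=8)(x)  # 64 % 8 == 0, so GroupNorm is valid
    assert y.shape == (2, 64, 16, 16)
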
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=0, groups=8):
        super().__init__()
        if time_emb_dim > 0:
            self.mlp = nn.Sequential(
                Mish(),
                nn.Linear(time_emb_dim, dim_out)
            )
        else:
            # no time conditioning; keep the attribute defined so forward()
            # fails clearly if a time embedding is passed anyway
            self.mlp = None
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, cond=None):
        h = self.block1(x)
        if time_emb is not None:
            # project the time embedding to dim_out and broadcast over H, W
            h += self.mlp(time_emb)[:, :, None, None]
        if cond is not None:
            h += cond
        h = self.block2(h)
        return h + self.res_conv(x)

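# Minimal shape-check sketch (assumed usage, not part of the original module):
# a ResnetBlock conditioned on a time embedding; the embedding is projected to
# dim_out and added channel-wise after the first Block.
def _demo_resnet_block():
    block = ResnetBlock(32, 64, time_emb_dim=128, groups=8)
    x = torch.randn(2, 32, 16, 16)
    t_emb = torch.randn(2, 128)
    y = block(x, time_emb=t_emb)  # -> [2, 64, 16, 16]
    assert y.shape == (2, 64, 16, 16)
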
class Upsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
        )

    def forward(self, x):
        return self.conv(x)

class Downsample(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, 2),
        )

    def forward(self, x):
        return self.conv(x)

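# Minimal shape-check sketch (assumed usage, not part of the original module):
# Upsample doubles the spatial resolution with a stride-2 transposed conv,
# Downsample halves it with a reflection-padded stride-2 conv.
def _demo_resample():
    x = torch.randn(2, 64, 16, 16)
    assert Upsample(64)(x).shape == (2, 64, 32, 32)
    assert Downsample(64)(x).shape == (2, 64, 8, 8)
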
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

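# Minimal shape-check sketch (assumed usage, not part of the original module):
# LinearAttention aggregates a (dim_head x dim_head) context per head instead
# of a full (h*w x h*w) attention map, so the cost grows linearly in h*w; the
# output keeps the input shape.
def _demo_linear_attention():
    x = torch.randn(2, 32, 16, 16)
    y = LinearAttention(32, heads=4, dim_head=32)(x)
    assert y.shape == (2, 32, 16, 16)
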
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        if self.qkv_same_dim:
            # single packed projection for q, k and v
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self.reset_parameters()
        # prefer the fused PyTorch kernel when it is available
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query, key, value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        """Input shape: [B, T, C]

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True
        # F.multi_head_attention_forward expects [T, B, C]
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v,
            self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,
            self.training, key_padding_mask, need_weights, attn_mask)
        attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_output_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

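# Minimal shape-check sketch (assumed usage, not part of the original module):
# self-attention over a [B, T, C] sequence. The forward pass delegates to
# F.multi_head_attention_forward, so it requires kdim == vdim == embed_dim
# (the packed in_proj_weight path).
def _demo_multihead_attention():
    attn = MultiheadAttention(embed_dim=64, num_heads=4)
    x = torch.randn(2, 10, 64)    # [B, T, C]
    out, weights = attn(x, x, x)  # out: [2, 10, 64]; weights averaged over heads
    assert out.shape == (2, 10, 64)
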
class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x

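# Minimal shape-check sketch (assumed usage, not part of the original module):
# RRDB chains three residual dense blocks and scales their output by 0.2 before
# the outer residual connection, so both the spatial size and the channel
# count (nf) are preserved.
def _demo_rrdb():
    x = torch.randn(2, 64, 32, 32)
    y = RRDB(nf=64, gc=32)(x)
    assert y.shape == (2, 64, 32, 32)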