import torch.nn.functional as F
from torch import nn


class PreactResBlock(nn.Sequential):
    """Pre-activation residual block: (GroupNorm -> GELU -> Conv) x 2 plus an identity skip."""

    def __init__(self, dim):
        super().__init__(
            nn.GroupNorm(dim // 16, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(dim // 16, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        return x + super().forward(x)


class UNetBlock(nn.Module):
    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim
        self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
        self.res_block1 = PreactResBlock(output_dim)
        self.res_block2 = PreactResBlock(output_dim)
        self.downsample = self.upsample = nn.Identity()
        if scale_factor > 1:
            self.upsample = nn.Upsample(scale_factor=scale_factor)
        elif scale_factor < 1:
            # nn.Upsample with a fractional scale_factor interpolates downward,
            # so the same module doubles as a downsampling layer here.
            self.downsample = nn.Upsample(scale_factor=scale_factor)
    def forward(self, x, h=None):
        """
        Args:
            x: (b c h w), last output
            h: (b c h w), skip output
        Returns:
            o: (b c h w), output
            s: (b c h w), skip output
        """
        x = self.upsample(x)
        if h is not None:
            assert x.shape == h.shape, f"{x.shape} != {h.shape}"
            x = x + h
        x = self.pre_conv(x)
        x = self.res_block1(x)
        x = self.res_block2(x)
        return self.downsample(x), x


class UNet(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        # Each encoder block doubles the channels and halves the spatial resolution.
        self.encoder_blocks = nn.ModuleList(
            [
                UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
                for i in range(num_blocks)
            ]
        )
        self.middle_blocks = nn.ModuleList(
            [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
        )
        # The decoder mirrors the encoder: halves the channels, doubles the resolution.
        self.decoder_blocks = nn.ModuleList(
            [
                UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
                for i in reversed(range(num_blocks))
            ]
        )
        self.head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, output_dim, 1),
        )
    @property  # pad_to_fit and forward read this as an attribute, so it must be a property
    def scale_factor(self):
        return 2 ** len(self.encoder_blocks)

    def pad_to_fit(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            x: (b c h' w'), padded input
        """
        # Pad height and width up to the next multiple of scale_factor so that
        # repeated halving and doubling restores the padded shape exactly.
        hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
        wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
        return F.pad(x, (0, wpad, 0, hpad))
    def forward(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            o: (b c h w), output
        """
        shape = x.shape
        x = self.pad_to_fit(x)
        x = self.input_proj(x)

        # Encoder: collect one skip tensor per block.
        s_list = []
        for block in self.encoder_blocks:
            x, s = block(x)
            s_list.append(s)

        for block in self.middle_blocks:
            x, _ = block(x)

        # Decoder: consume the skips in reverse order.
        for block, s in zip(self.decoder_blocks, reversed(s_list)):
            x, _ = block(x, s)

        x = self.head(x)

        # Crop off the padding added by pad_to_fit.
        x = x[..., : shape[2], : shape[3]]
        return x
    def test(self, shape=(3, 512, 256)):
        import ptflops

        macs, params = ptflops.get_model_complexity_info(
            self,
            shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )
        print(f"macs: {macs}")
        print(f"params: {params}")


def main():
    model = UNet(3, 3)
    model.test()


if __name__ == "__main__":
    main()