import torch
import torch.nn as nn

from typing import Any, Tuple, Union

from utils import (
    ImageType,
    crop_image_part,
)

from layers import (
    SpectralConv2d,
    InitLayer,
    SLEBlock,
    UpsampleBlockT1,
    UpsampleBlockT2,
    DownsampleBlockT1,
    DownsampleBlockT2,
    Decoder,
)

from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
class Generator(nn.Module, HugGANModelHubMixin):

    def __init__(self, in_channels: int,
                       out_channels: int):
        super().__init__()

        # Number of feature channels at each spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }

        self._init = InitLayer(
            in_channels=in_channels,
            out_channels=self._channels[4],
        )

        # Alternating upsample block types carry the 4x4 seed up to 1024x1024.
        self._upsample_8    = UpsampleBlockT2(in_channels=self._channels[4],   out_channels=self._channels[8])
        self._upsample_16   = UpsampleBlockT1(in_channels=self._channels[8],   out_channels=self._channels[16])
        self._upsample_32   = UpsampleBlockT2(in_channels=self._channels[16],  out_channels=self._channels[32])
        self._upsample_64   = UpsampleBlockT1(in_channels=self._channels[32],  out_channels=self._channels[64])
        self._upsample_128  = UpsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[128])
        self._upsample_256  = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256])
        self._upsample_512  = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512])
        self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])

        # Skip-layer excitation connections between low- and high-resolution maps.
        self._sle_64  = SLEBlock(in_channels=self._channels[4],  out_channels=self._channels[64])
        self._sle_128 = SLEBlock(in_channels=self._channels[8],  out_channels=self._channels[128])
        self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
        self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])

        self._out_128 = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[128],
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

        self._out_1024 = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        size_4  = self._init(input)
        size_8  = self._upsample_8(size_4)
        size_16 = self._upsample_16(size_8)
        size_32 = self._upsample_32(size_16)

        size_64   = self._sle_64(size_4,   self._upsample_64(size_32))
        size_128  = self._sle_128(size_8,  self._upsample_128(size_64))
        size_256  = self._sle_256(size_16, self._upsample_256(size_128))
        size_512  = self._sle_512(size_32, self._upsample_512(size_256))
        size_1024 = self._upsample_1024(size_512)

        out_128  = self._out_128(size_128)
        out_1024 = self._out_1024(size_1024)

        return out_1024, out_128
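
# A minimal usage sketch (illustrative, not from the original file): InitLayer is
# assumed to consume a latent vector shaped (batch, in_channels, 1, 1), as in
# FastGAN-style generators, and the latent width 256 is an arbitrary choice.
# Exact shapes depend on the block definitions in `layers.py`.
#
#   generator = Generator(in_channels=256, out_channels=3)
#   noise = torch.randn(4, 256, 1, 1)
#   img_1024, img_128 = generator(noise)  # (4, 3, 1024, 1024), (4, 3, 128, 128)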
class Discriminator(nn.Module, HugGANModelHubMixin):

    def __init__(self, in_channels: int):
        super().__init__()

        # Number of feature channels at each spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }
        self._init = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[1024],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=self._channels[512],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[512]),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
        self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
        self._downsample_64  = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64])
        self._downsample_32  = DownsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[32])
        self._downsample_16  = DownsampleBlockT2(in_channels=self._channels[32],  out_channels=self._channels[16])

        self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
        self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
        self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])

        # Separate track for the low-resolution (128x128) input.
        self._small_track = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[256],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
            DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64]),
            DownsampleBlockT1(in_channels=self._channels[64],  out_channels=self._channels[32]),
        )

        self._features_large = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[16],
                out_channels=self._channels[8],
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[8]),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[8],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        self._features_small = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[32],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        # Decoders that reconstruct real inputs for the self-supervised loss.
        self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
        self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
        self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)
    def forward(self, images_1024: torch.Tensor,
                      images_128: torch.Tensor,
                      image_type: ImageType) -> Union[torch.Tensor,
                                                      Tuple[torch.Tensor, Tuple[Any, Any, Any]]]:
        # Large track.
        down_512 = self._init(images_1024)
        down_256 = self._downsample_256(down_512)
        down_128 = self._downsample_128(down_256)

        down_64 = self._downsample_64(down_128)
        down_64 = self._sle_64(down_512, down_64)

        down_32 = self._downsample_32(down_64)
        down_32 = self._sle_32(down_256, down_32)

        down_16 = self._downsample_16(down_32)
        down_16 = self._sle_16(down_128, down_16)

        # Small track.
        down_small = self._small_track(images_128)

        # Patch logits from both tracks, flattened and concatenated.
        features_large = self._features_large(down_16).view(-1)
        features_small = self._features_small(down_small).view(-1)
        features = torch.cat([features_large, features_small], dim=0)

        # For real images, also return the decoder reconstructions used by the
        # self-supervised consistency loss.
        if image_type != ImageType.FAKE:
            dec_large = self._decoder_large(down_16)
            dec_small = self._decoder_small(down_small)
            dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
            return features, (dec_large, dec_small, dec_piece)

        return features
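

# Minimal smoke test (not part of the original module). The latent width 256 and
# the (batch, latent, 1, 1) noise shape are assumptions about InitLayer, and the
# expected output resolutions follow from the block names above; adjust to match
# the actual definitions in `layers.py` and `utils.py`.
if __name__ == '__main__':
    generator = Generator(in_channels=256, out_channels=3)
    discriminator = Discriminator(in_channels=3)

    noise = torch.randn(2, 256, 1, 1)
    fake_1024, fake_128 = generator(noise)

    # For fake images the discriminator returns only the concatenated patch
    # logits; for real images it additionally returns decoder reconstructions.
    logits = discriminator(fake_1024, fake_128, ImageType.FAKE)
    print(fake_1024.shape, fake_128.shape, logits.shape)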