Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn.functional as F | |
def tile_features(features, num_pieces):
    """Split a feature map into a square grid of patches stacked along dim 0.

    The spatial plane (dims 2 and 3) is cut into an ``n x n`` grid, where
    ``n = sqrt(num_pieces)``. The patches are concatenated along dim 0 in
    row-major order::

        +-----+-----+
        |  1  |  2  |
        +-----+-----+          +-----+-----+-----+-----+
        |  3  |  4  |   --->   |  1  |  2  |  3  |  4  |
        +-----+-----+          +-----+-----+-----+-----+

    Each patch slot keeps the original dim-0 order, so the result has
    ``num_pieces * features.size(0)`` entries along dim 0.

    Args:
        features: 4-D tensor; dims 2 and 3 are split as height and width.
        num_pieces: Number of patches; must be a positive perfect square.

    Returns:
        Tensor of shape ``(num_pieces * B, C, H // n, W // n)`` for an input
        of shape ``(B, C, H, W)``.

    Raises:
        ValueError: If ``num_pieces`` is not a positive perfect square, or if
            the height/width is not divisible by ``n``.
    """
    _, _, h, w = features.size()
    if num_pieces < 1:
        raise ValueError(f"num_pieces must be a positive perfect square, got {num_pieces}")
    num_pieces_per_line = math.isqrt(int(num_pieces))
    if num_pieces_per_line ** 2 != num_pieces:
        raise ValueError(f"num_pieces must be a positive perfect square, got {num_pieces}")
    # Without this check, torch.split would emit more (or ragged) patches than
    # requested. E.g. 3x3 split into a 2x2 grid silently yields 9 patches.
    if h % num_pieces_per_line or w % num_pieces_per_line:
        raise ValueError(
            f"feature map of size {h}x{w} cannot be split evenly into a "
            f"{num_pieces_per_line}x{num_pieces_per_line} grid"
        )
    h_per_patch = h // num_pieces_per_line
    w_per_patch = w // num_pieces_per_line

    # Outer loop walks grid rows (dim 2), inner loop walks columns (dim 3),
    # giving row-major patch order.
    patches = [
        patch
        for row in torch.split(features, h_per_patch, dim=2)
        for patch in torch.split(row, w_per_patch, dim=3)
    ]
    return torch.cat(patches, dim=0)
def merge_features(features, num_pieces, batch_size):
    """Reassemble a dim-0 stack of grid patches into full feature maps.

    This is the inverse of the tiling layout. ``features`` holds ``num_pieces``
    consecutive groups of ``batch_size`` entries, one group per grid slot in
    row-major order::

        +-----+-----+-----+-----+          +-----+-----+
        |  1  |  2  |  3  |  4  |   --->   |  1  |  2  |
        +-----+-----+-----+-----+          +-----+-----+
                                           |  3  |  4  |
                                           +-----+-----+

    Patches in the same grid row are joined along dim 3. The rows are then
    stacked along dim 2.

    Args:
        features: Tensor whose dim 0 has ``num_pieces * batch_size`` entries.
        num_pieces: Number of patches; must be a positive perfect square.
        batch_size: Number of dim-0 entries belonging to each patch slot.

    Returns:
        Tensor with ``batch_size`` entries along dim 0. Dims 2 and 3 are
        ``sqrt(num_pieces)`` times the patch height and width.

    Raises:
        ValueError: If ``num_pieces`` is not a positive perfect square, or if
            ``features.size(0) != num_pieces * batch_size``.
    """
    if num_pieces < 1:
        raise ValueError(f"num_pieces must be a positive perfect square, got {num_pieces}")
    num_pieces_per_line = math.isqrt(int(num_pieces))
    if num_pieces_per_line ** 2 != num_pieces:
        raise ValueError(f"num_pieces must be a positive perfect square, got {num_pieces}")
    # Without this check, surplus chunks were silently ignored and missing
    # ones failed with an unrelated IndexError.
    if batch_size < 1 or features.size(0) != num_pieces * batch_size:
        raise ValueError(
            f"expected {num_pieces} x {batch_size} entries along dim 0, "
            f"got {features.size(0)}"
        )

    pieces = torch.split(features, batch_size)
    n = num_pieces_per_line
    rows = [torch.cat(pieces[r * n:(r + 1) * n], dim=3) for r in range(n)]
    return torch.cat(rows, dim=2)
def puzzle_module(x, func_list, num_pieces):
    """Apply a chain of callables to grid patches of ``x``, then stitch back.

    ``x`` is first tiled with ``tile_features``. Each callable in
    ``func_list`` is applied in sequence to the tiled tensor. The final output
    is reassembled with ``merge_features`` using the dim-0 size of ``x``.
    """
    out = tile_features(x, num_pieces)
    for fn in func_list:
        out = fn(out)
    original_batch = x.size(0)
    return merge_features(out, num_pieces, original_batch)