from inspect import isfunction

import torch
from torch import nn
from torch.nn import init


def exists(x):
    """Return True when ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` if it is not None, otherwise the fallback ``d``.

    If ``d`` is a plain Python function (as detected by ``inspect.isfunction``),
    it is called with no arguments to produce the fallback lazily. Any other
    object, including classes and builtins, is returned as-is.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    """Yield items from ``dl`` endlessly, re-iterating it on every pass.

    Unlike ``itertools.cycle``, items are not cached. Each pass calls
    ``iter(dl)`` again, so an iterable that produces different items per pass
    is re-read every time.

    If a pass produces no items, the generator stops instead of looping forever.
    This happens when ``dl`` is empty or is an exhausted one-shot iterator.
    """
    while True:
        produced = False
        for data in dl:
            produced = True
            yield data
        if not produced:
            # Nothing to repeat: without this guard the outer loop would
            # busy-spin forever without ever yielding.
            return


def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus a smaller remainder.

    Example: ``num_to_groups(10, 4) == [4, 4, 2]``.
    A zero remainder is omitted, so ``num_to_groups(8, 4) == [4, 4]``.
    Raises ZeroDivisionError when ``divisor`` is 0.
    """
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def initialize_weights(net_l, scale=0.1):
    """Re-initialize the Conv2d, Linear and BatchNorm2d layers of one or more nets in place.

    Args:
        net_l: A single ``nn.Module``, or a list/tuple of modules.
        scale: Multiplier applied to Kaiming-normal weights of Conv2d/Linear
            layers. A small value is typically used for residual blocks.

    Behavior per layer:
        * Conv2d / Linear: Kaiming-normal weights (``a=0``, ``fan_in``) scaled
          by ``scale``; the bias, when present, is zeroed.
        * BatchNorm2d with affine parameters: weight set to 1, bias to 0.
          BatchNorm2d built with ``affine=False`` has no parameters and is
          skipped.
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                with torch.no_grad():
                    m.weight.mul_(scale)  # damp initial magnitude (e.g. residual blocks)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False leaves weight/bias as None; nothing to set.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)


def make_layer(block, n_layers, seq=False):
    """Build ``n_layers`` fresh modules by calling the factory ``block()``.

    Returns an ``nn.Sequential`` when ``seq`` is true, otherwise an
    ``nn.ModuleList``.
    """
    layers = [block() for _ in range(n_layers)]
    if seq:
        return nn.Sequential(*layers)
    return nn.ModuleList(layers)