import numpy as np import time import torch import torch.nn as nn def move_data_to_device(x, device): if 'float' in str(x.dtype): x = torch.Tensor(x) elif 'int' in str(x.dtype): x = torch.LongTensor(x) else: return x return x.to(device) def do_mixup(x, mixup_lambda): """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes (1, 3, 5, ...). Args: x: (batch_size * 2, ...) mixup_lambda: (batch_size * 2,) Returns: out: (batch_size, ...) """ out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) return out def append_to_dict(dict, key, value): if key in dict.keys(): dict[key].append(value) else: dict[key] = [value] def interpolate(x, ratio): """Interpolate data in time domain. This is used to compensate the resolution reduction in downsampling of a CNN. Args: x: (batch_size, time_steps, classes_num) ratio: int, ratio to interpolate Returns: upsampled: (batch_size, time_steps * ratio, classes_num) """ (batch_size, time_steps, classes_num) = x.shape upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) return upsampled def pad_framewise_output(framewise_output, frames_num): """Pad framewise_output to the same length as input frames. The pad value is the same as the value of the last frame. Args: framewise_output: (batch_size, frames_num, classes_num) frames_num: int, number of frames to pad Outputs: output: (batch_size, frames_num, classes_num) """ pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) """tensor for padding""" output = torch.cat((framewise_output, pad), dim=1) """(batch_size, frames_num, classes_num)""" return output def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def count_flops(model, audio_length): """Count flops. Code modified from others' implementation. """ multiply_adds = True list_conv2d=[] def conv2d_hook(self, input, output): batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) bias_ops = 1 if self.bias is not None else 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_height * output_width list_conv2d.append(flops) list_conv1d=[] def conv1d_hook(self, input, output): batch_size, input_channels, input_length = input[0].size() output_channels, output_length = output[0].size() kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) bias_ops = 1 if self.bias is not None else 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_length list_conv1d.append(flops) list_linear=[] def linear_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) bias_ops = self.bias.nelement() flops = batch_size * (weight_ops + bias_ops) list_linear.append(flops) list_bn=[] def bn_hook(self, input, output): list_bn.append(input[0].nelement() * 2) list_relu=[] def relu_hook(self, input, output): list_relu.append(input[0].nelement() * 2) list_pooling2d=[] def pooling2d_hook(self, input, output): batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_ops = self.kernel_size * self.kernel_size bias_ops = 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_height * output_width list_pooling2d.append(flops) list_pooling1d=[] def pooling1d_hook(self, input, output): batch_size, input_channels, input_length = input[0].size() output_channels, output_length = output[0].size() kernel_ops = self.kernel_size[0] bias_ops = 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_length list_pooling2d.append(flops) def foo(net): childrens = list(net.children()) if not childrens: if isinstance(net, nn.Conv2d): net.register_forward_hook(conv2d_hook) elif isinstance(net, nn.Conv1d): net.register_forward_hook(conv1d_hook) elif isinstance(net, nn.Linear): net.register_forward_hook(linear_hook) elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): net.register_forward_hook(bn_hook) elif isinstance(net, nn.ReLU): net.register_forward_hook(relu_hook) elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): net.register_forward_hook(pooling2d_hook) elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): net.register_forward_hook(pooling1d_hook) else: print('Warning: flop of module {} is not counted!'.format(net)) return for c in childrens: foo(c) # Register hook foo(model) device = device = next(model.parameters()).device input = torch.rand(1, audio_length).to(device) out = model(input) total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \ sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d) return total_flops