"""Utilities for counting multiply-accumulate operations (MACs) of PyTorch models."""
import torch
import torch.nn as nn


def count_macs(model, spec_size):
    """Count MACs of a convolutional model by registering forward hooks on Conv2d and Linear layers."""
    list_conv2d = []

    def conv2d_hook(self, input, output):
        # Forward hook: `self` is the hooked module.
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()

        # MACs per output position: kernel area times input channels per group, plus bias.
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)

        # One MAC per parameter per spatial output location.
        macs = batch_size * params * output_height * output_width
        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        assert batch_size == 1
        weight_ops = self.weight.nelement()
        bias_ops = self.bias.nelement() if self.bias is not None else 0

        # One MAC per weight (and bias) element for a single input vector.
        macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    def foo(net):
        # Conv2dStaticSamePadding contains child modules (its padding layer), so it
        # would be missed by the leaf check below and is hooked explicitly here.
        if net.__class__.__name__ == 'Conv2dStaticSamePadding':
            net.register_forward_hook(conv2d_hook)
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            foo(c)

    foo(model)

    # Run one dummy forward pass so that the registered hooks fire.
    device = next(model.parameters()).device
    input = torch.rand(spec_size).to(device)
    with torch.no_grad():
        model(input)

    total_macs = sum(list_conv2d) + sum(list_linear)

print("*************Computational Complexity (multiply-adds) **************") |
|
print("Number of Convolutional Layers: ", len(list_conv2d)) |
|
print("Number of Linear Layers: ", len(list_linear)) |
|
print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs))) |
|
print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs)) |
|
print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9)) |
|
print("********************************************************************") |
|
return total_macs |
|
|
|
|
|
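# Illustrative usage sketch for count_macs (the model and spec_size below are
# hypothetical, not part of the original module). spec_size must describe a
# single sample, e.g. (1, channels, height, width), because the hooks assert
# batch_size == 1:
#
#     cnn = nn.Sequential(
#         nn.Conv2d(1, 8, kernel_size=3, padding=1),
#         nn.Flatten(),
#         nn.Linear(8 * 64 * 64, 10),
#     )
#     count_macs(cnn, spec_size=(1, 1, 64, 64))
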

def count_macs_transformer(model, spec_size):
    """Count MACs for transformer-style models (convolutional, linear, and attention layers).

    Code adapted from an external implementation.
    """
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()

        # MACs per output position: kernel area times input channels per group, plus bias.
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)

        # One MAC per parameter per spatial output location.
        macs = batch_size * params * output_height * output_width
        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() >= 2 else 1
        assert batch_size == 1
        if input[0].dim() == 3:
            # 3-D input (batch, seq_len, embed): the linear layer is applied per token.
            batch_size, seq_len, embed_size = input[0].size()
            weight_ops = self.weight.nelement()
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            macs = batch_size * (weight_ops + bias_ops) * seq_len
        else:
            # 2-D input (batch, embed): a single feature vector per sample.
            batch_size, embed_size = input[0].size()
            weight_ops = self.weight.nelement()
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    list_att = []

    def attention_hook(self, input, output):
        batch_size, seq_len, embed_size = input[0].size()

        # Two matrix products per attention layer: Q @ K^T and attention @ V,
        # each costing roughly seq_len * seq_len * embed_size MACs.
        macs = batch_size * embed_size * seq_len * seq_len * 2
        list_att.append(macs)

    def foo(net):
        childrens = list(net.children())
        # Attention MACs are only counted for modules whose class is named "MultiHeadAttention".
        if net.__class__.__name__ == "MultiHeadAttention":
            net.register_forward_hook(attention_hook)
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            foo(c)

    foo(model)

    # Run one dummy forward pass so that the registered hooks fire.
    device = next(model.parameters()).device
    input = torch.rand(spec_size).to(device)
    with torch.no_grad():
        model(input)

    total_macs = sum(list_conv2d) + sum(list_linear) + sum(list_att)

print("*************Computational Complexity (multiply-adds) **************") |
|
print("Number of Convolutional Layers: ", len(list_conv2d)) |
|
print("Number of Linear Layers: ", len(list_linear)) |
|
print("Number of Attention Layers: ", len(list_att)) |
|
print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs))) |
|
print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs)) |
|
print("Relative Share of Attention Layers: {:.2f}".format(sum(list_att) / total_macs)) |
|
print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9)) |
|
print("********************************************************************") |
|
return total_macs |
|
|
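
if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the toy model below is hypothetical).
    # Attention MACs are only collected for modules whose class is literally named
    # "MultiHeadAttention", so this stack of nn.Linear layers exercises the linear
    # hook only; unsupported leaves such as nn.ReLU just trigger the "not counted" warning.
    toy_encoder = nn.Sequential(
        nn.Linear(128, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
    )
    count_macs_transformer(toy_encoder, spec_size=(1, 32, 128))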