# De-limiter / eval_delimit / calc_flops.py
# Author: jeonchangbin49
# Commit: a00b67a ("first commit")
import os
import argparse
import random
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
from utils import get_config
from models import load_model_with_args
# def main():
parser = argparse.ArgumentParser(description="FLOPs calculation")
parser.add_argument(
"-c", "--config", default="delimit_6_s", type=str, help="Name of the setting file."
)
config_args = parser.parse_args()
args = get_config(config_args.config)
print(args)
with torch.cuda.device(0):
model = load_model_with_args(args)
batch_size = 1
flops, macs, params = get_model_profile(
model=model, # model
input_shape=(batch_size, 2, 44100 * 60), # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
args=[], # list of positional arguments to the model.
kwargs={}, # dictionary of keyword arguments to the model.
print_profile=True, # prints the model graph with the measured profile attached to each module
detailed=True, # print the detailed profile
module_depth=-1, # depth into the nested modules, with -1 being the inner most modules
top_modules=1, # the number of top modules to print aggregated profile
warm_up=1, # the number of warm-ups before measuring the time of each module
as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
output_file=None, # path to the output file. If None, the profiler prints to stdout.
ignore_modules=None,
) # the list of modules to ignore in the profiling
print(args.dir_params.exp_name)
print('flops: ', flops)
print('macs: ', macs)
print('params: ', params)