Spaces:
Sleeping
Sleeping
| import os | |
| import argparse | |
| import json | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from torch.utils.data import DataLoader | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping | |
| from look2hear.utils.parser_utils import prepare_parser_from_dict, parse_args_as_dict | |
| import look2hear.models | |
| import yaml | |
| from ptflops import get_model_complexity_info | |
| from rich import print | |
def check_parameters(net: nn.Module) -> float:
    """Return the size of *net* in millions of parameters.

    Every tensor yielded by ``net.parameters()`` is counted, and the total
    element count is divided by 10**6. The result is a parameter count in
    millions, not a memory footprint in megabytes.
    """
    parameters = sum(param.numel() for param in net.parameters())
    return parameters / 10 ** 6
# Command-line interface: --exp_dir plus whatever options
# prepare_parser_from_dict adds from the default YAML configuration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp_dir", default="exp/tmp", help="Full path to save best validation model"
)
with open("configs/tiger.yml") as f:
    def_conf = yaml.safe_load(f)
parser = prepare_parser_from_dict(def_conf, parser=parser)
arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)

# Look up the model class by name in look2hear.models and instantiate it with
# the configured sample rate and model-specific keyword arguments.
audiomodel = getattr(look2hear.models, arg_dic["audionet"]["audionet_name"])(
    sample_rate=arg_dic["datamodule"]["data_config"]["sample_rate"],
    **arg_dic["audionet"]["audionet_config"]
)

# Pick the compute device. MPS is preferred, but it is only used when the
# backend is actually available; otherwise fall back to CUDA and then CPU so
# the script does not crash on machines without Apple-silicon support.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Random input of shape (1, 1, 16000) used for the sanity forward pass below.
a = torch.randn(1, 1, 16000).to(device)
total_macs = 0
total_params = 0
model = audiomodel.to(device)

# Gradients are not needed for complexity counting or the shape check.
with torch.no_grad():
    # NOTE(review): ptflops is given the input resolution (16000,), while the
    # forward pass below feeds a (1, 1, 16000) tensor -- confirm the model
    # accepts both layouts.
    macs, params = get_model_complexity_info(
        model, (16000,), as_strings=False, print_per_layer_stat=True, verbose=False
    )
    print(model(a).shape)

total_macs += macs
total_params += params

# Report multiply-accumulate operations scaled by 1e9 and parameters by 1e6.
print("MACs: ", total_macs / 10.0 ** 9)
print("Params: ", total_params / 10.0 ** 6)