Spaces:
Running
Running
import numpy as np | |
import torch | |
import torch.nn as nn | |
from asteroid_filterbanks import make_enc_dec | |
from asteroid.masknn import TDConvNet | |
import utils | |
from .base_models import ( | |
BaseEncoderMaskerDecoderWithConfigs, | |
BaseEncoderMaskerDecoderWithConfigsMaskOnOutput, | |
BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid, | |
) | |
def _build_conv_tasnet_parts(args, n_src):
    """Build the encoder, TDConvNet masker and decoder shared by all variants.

    Args:
        args: Config object exposing ``sample_rate``, ``data_params.nb_channels``
            and the ``conv_tasnet_params`` hyper-parameter namespace.
        n_src: Value forwarded to ``TDConvNet`` as ``n_src``.

    Returns:
        An ``(encoder, masker, decoder)`` tuple.
    """
    params = args.conv_tasnet_params

    encoder, decoder = make_enc_dec(
        "free",
        n_filters=params.n_filters,
        kernel_size=params.kernel_size,
        stride=params.stride,
        sample_rate=args.sample_rate,
    )

    # ``conv_kernel_size`` and ``causal`` are deliberately not forwarded, so
    # TDConvNet uses its own defaults for them.
    masker = TDConvNet(
        # The masker input width is the encoder feature count multiplied by
        # the number of audio channels (e.g. 2 for stereo).
        in_chan=encoder.n_feats_out * args.data_params.nb_channels,
        n_src=n_src,
        out_chan=encoder.n_feats_out,
        n_blocks=params.n_blocks,
        n_repeats=params.n_repeats,
        bn_chan=params.bn_chan,
        hid_chan=params.hid_chan,
        skip_chan=params.skip_chan,
        # Fall back to 'gLN' when the config leaves norm_type empty/None.
        norm_type=params.norm_type or "gLN",
        mask_act=params.mask_act,
    )
    return encoder, masker, decoder


def load_model_with_args(args):
    """Instantiate the model selected by ``args.model_loss_params.architecture``.

    Supported architectures:
        * ``"conv_tasnet_mask_on_output"`` -- TDConvNet with ``n_src=1`` wrapped
          in ``BaseEncoderMaskerDecoderWithConfigsMaskOnOutput``; the mask is
          always applied.
        * ``"conv_tasnet"`` -- TDConvNet with ``n_src`` taken from the config,
          wrapped in ``BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid``;
          the mask is applied unless ``conv_tasnet_params.synthesis`` is truthy.

    Args:
        args: Config object with ``model_loss_params``, ``conv_tasnet_params``,
            ``data_params`` and ``sample_rate`` attributes.

    Returns:
        The constructed model, with ``use_encoder_to_target`` set to ``False``.

    Raises:
        ValueError: If the requested architecture is not one of the supported
            names (previously this surfaced as an ``UnboundLocalError``).
    """
    architecture = args.model_loss_params.architecture
    supported = ("conv_tasnet_mask_on_output", "conv_tasnet")

    # Validate first so a typo in the config fails fast with a clear message
    # instead of an unbound ``model`` at the return statement.
    if architecture not in supported:
        raise ValueError(
            f"Unsupported architecture {architecture!r}; "
            f"expected one of {supported}."
        )

    params = args.conv_tasnet_params

    if architecture == "conv_tasnet_mask_on_output":
        # Single source for the de-limit task.
        encoder, masker, decoder = _build_conv_tasnet_parts(args, n_src=1)
        model = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(
            encoder,
            masker,
            decoder,
            encoder_activation=params.encoder_activation,
            use_encoder=True,
            apply_mask=True,
            use_decoder=True,
            decoder_activation=params.decoder_activation,
        )
    else:  # architecture == "conv_tasnet"
        # De-limit task with the standard Conv-TasNet setting.
        encoder, masker, decoder = _build_conv_tasnet_parts(
            args, n_src=params.n_src
        )
        model = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(
            encoder,
            masker,
            decoder,
            encoder_activation=params.encoder_activation,
            use_encoder=True,
            # In synthesis mode the masker output is used without masking.
            apply_mask=not params.synthesis,
            use_decoder=True,
            decoder_activation=params.decoder_activation,
        )

    model.use_encoder_to_target = False
    return model