# De-limiter / models / load_models.py
# jeonchangbin49's picture
# first commit
# a00b67a
import numpy as np
import torch
import torch.nn as nn
from asteroid_filterbanks import make_enc_dec
from asteroid.masknn import TDConvNet
import utils
from .base_models import (
BaseEncoderMaskerDecoderWithConfigs,
BaseEncoderMaskerDecoderWithConfigsMaskOnOutput,
BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid,
)
def load_model_with_args(args):
    """Build the Conv-TasNet style model selected by the configuration.

    Reads ``args.model_loss_params.architecture`` and assembles a "free"
    encoder/decoder pair (``make_enc_dec``), a ``TDConvNet`` masker and one
    of the wrapper classes from ``.base_models``:

    * ``"conv_tasnet_mask_on_output"`` builds
      ``BaseEncoderMaskerDecoderWithConfigsMaskOnOutput`` with ``n_src=1``
      and ``apply_mask=True``.
    * ``"conv_tasnet"`` builds
      ``BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid`` with
      ``n_src=args.conv_tasnet_params.n_src``; ``apply_mask`` is ``False``
      when ``args.conv_tasnet_params.synthesis`` is truthy, else ``True``.

    Args:
        args: Configuration object exposing ``model_loss_params``,
            ``conv_tasnet_params``, ``data_params`` and ``sample_rate`` as
            attributes.

    Returns:
        The assembled model, with ``use_encoder_to_target`` set to ``False``.

    Raises:
        ValueError: If the configured architecture is not one of the two
            supported names. (Previously this case surfaced as an
            ``UnboundLocalError`` on ``return model``.)
    """
    architecture = args.model_loss_params.architecture
    supported = ("conv_tasnet_mask_on_output", "conv_tasnet")
    # Validate up front so a typo in the config fails with a clear message
    # before any network is constructed.
    if architecture not in supported:
        raise ValueError(
            f"Unknown architecture {architecture!r}; expected one of {supported}."
        )

    tasnet = args.conv_tasnet_params
    mask_on_output = architecture == "conv_tasnet_mask_on_output"

    encoder, decoder = make_enc_dec(
        "free",
        n_filters=tasnet.n_filters,
        kernel_size=tasnet.kernel_size,
        stride=tasnet.stride,
        sample_rate=args.sample_rate,
    )

    # The mask-on-output variant always uses a single source for the
    # de-limit task; the standard variant takes n_src from the config.
    # tasnet.n_src is only read on the standard path, so configs for the
    # mask-on-output variant do not need to define it.
    n_src = 1 if mask_on_output else tasnet.n_src

    # conv_kernel_size and causal are deliberately not forwarded (they were
    # commented out in the original), so TDConvNet's own defaults apply.
    masker = TDConvNet(
        # Encoder features are multiplied by the number of audio channels
        # (the original comment notes the stereo case).
        in_chan=encoder.n_feats_out * args.data_params.nb_channels,
        n_src=n_src,
        out_chan=encoder.n_feats_out,
        n_blocks=tasnet.n_blocks,
        n_repeats=tasnet.n_repeats,
        bn_chan=tasnet.bn_chan,
        hid_chan=tasnet.hid_chan,
        skip_chan=tasnet.skip_chan,
        # Fall back to 'gLN' when the config leaves norm_type empty/None.
        norm_type=tasnet.norm_type or "gLN",
        mask_act=tasnet.mask_act,
    )

    if mask_on_output:
        model_cls = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput
        apply_mask = True
    else:
        model_cls = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid
        # In synthesis mode the masker output is not applied as a mask.
        apply_mask = not tasnet.synthesis

    model = model_cls(
        encoder,
        masker,
        decoder,
        encoder_activation=tasnet.encoder_activation,
        use_encoder=True,
        apply_mask=apply_mask,
        use_decoder=True,
        decoder_activation=tasnet.decoder_activation,
    )
    model.use_encoder_to_target = False
    return model