File size: 3,593 Bytes
a00b67a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import torch
import torch.nn as nn
from asteroid_filterbanks import make_enc_dec

from asteroid.masknn import TDConvNet

import utils
from .base_models import (
    BaseEncoderMaskerDecoderWithConfigs,
    BaseEncoderMaskerDecoderWithConfigsMaskOnOutput,
    BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid,
)


def _build_conv_tasnet_parts(args, n_src):
    """Build the encoder, masker and decoder shared by the Conv-TasNet variants.

    Args:
        args: Configuration object. Reads ``args.sample_rate``,
            ``args.data_params.nb_channels`` and, from
            ``args.conv_tasnet_params``: ``n_filters``, ``kernel_size``,
            ``stride``, ``n_blocks``, ``n_repeats``, ``bn_chan``, ``hid_chan``,
            ``skip_chan``, ``norm_type`` and ``mask_act``.
        n_src: Value forwarded to ``TDConvNet`` as ``n_src``.

    Returns:
        A ``(encoder, masker, decoder)`` tuple.
    """
    params = args.conv_tasnet_params

    encoder, decoder = make_enc_dec(
        "free",
        n_filters=params.n_filters,
        kernel_size=params.kernel_size,
        stride=params.stride,
        sample_rate=args.sample_rate,
    )

    # `conv_kernel_size` and `causal` are not forwarded, so TDConvNet's own
    # defaults apply for them.
    masker = TDConvNet(
        # The encoder features of every input channel (e.g. stereo) are fed
        # to the masker together, hence the multiplication.
        in_chan=encoder.n_feats_out * args.data_params.nb_channels,
        n_src=n_src,
        out_chan=encoder.n_feats_out,
        n_blocks=params.n_blocks,
        n_repeats=params.n_repeats,
        bn_chan=params.bn_chan,
        hid_chan=params.hid_chan,
        skip_chan=params.skip_chan,
        # Fall back to 'gLN' when no (or an empty) norm type is configured.
        norm_type=params.norm_type or "gLN",
        mask_act=params.mask_act,
    )
    return encoder, masker, decoder


def load_model_with_args(args):
    """Instantiate the model selected by ``args.model_loss_params.architecture``.

    Supported architectures:

    * ``"conv_tasnet_mask_on_output"``: a single-source masker (``n_src=1``,
      for the de-limit task) wrapped in
      ``BaseEncoderMaskerDecoderWithConfigsMaskOnOutput`` with masking
      always enabled.
    * ``"conv_tasnet"``: a masker with ``args.conv_tasnet_params.n_src``
      sources wrapped in
      ``BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid``; masking is
      disabled when ``args.conv_tasnet_params.synthesis`` is truthy.

    Args:
        args: Configuration object exposing ``model_loss_params``,
            ``conv_tasnet_params``, ``data_params`` and ``sample_rate``.
            ``conv_tasnet_params.n_src`` and ``conv_tasnet_params.synthesis``
            are only read for the ``"conv_tasnet"`` architecture.

    Returns:
        The constructed model, with ``use_encoder_to_target`` set to ``False``.

    Raises:
        ValueError: If the requested architecture is not one of the
            supported names.
    """
    architecture = args.model_loss_params.architecture

    if architecture == "conv_tasnet_mask_on_output":
        model_cls = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput
        n_src = 1  # for de-limit task.
        apply_mask = True
    elif architecture == "conv_tasnet":
        model_cls = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid
        # for de-limit task with the standard conv-tasnet setting.
        n_src = args.conv_tasnet_params.n_src
        apply_mask = not args.conv_tasnet_params.synthesis
    else:
        # Fail early with an explicit message instead of reaching the final
        # return with `model` never assigned.
        raise ValueError(
            f"Unsupported architecture {architecture!r}; expected "
            "'conv_tasnet_mask_on_output' or 'conv_tasnet'."
        )

    encoder, masker, decoder = _build_conv_tasnet_parts(args, n_src)

    model = model_cls(
        encoder,
        masker,
        decoder,
        encoder_activation=args.conv_tasnet_params.encoder_activation,
        use_encoder=True,
        apply_mask=apply_mask,
        use_decoder=True,
        decoder_activation=args.conv_tasnet_params.decoder_activation,
    )
    model.use_encoder_to_target = False

    return model