#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Waveformer-main 
@File    :CLAPSep.py
@IDE     :PyCharm 
@Author  :Aisaka/Hao Ma @SDU
@Date    :2024/2/28 1:12 PM
'''

import torch
from torch import nn
import torchaudio
import laion_clap
from .CLAPSep_decoder import HTSAT_Decoder
import copy
import loralib as lora
from torchlibrosa import ISTFT, STFT
from torchlibrosa.stft import magphase
import librosa

def set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)
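# Example (hypothetical key): set_module(model, 'blocks.0.attn.qkv', new_layer) walks
# getattr() through 'blocks' -> '0' -> 'attn' and then setattr()s 'qkv' on that module.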


def process_model(model, rank):
    for n, module in model.named_modules():
        if 'WindowAttention' in str(type(module)):
            for n_, layer in module.named_modules():
                if isinstance(layer, torch.nn.Linear):
                    # nn.Linear always defines a `bias` attribute (possibly None), so check
                    # its value instead of hasattr(), which is always True.
                    lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                             bias=layer.bias is not None, merge_weights=True)
                    lora_layer.weight = layer.weight
                    if layer.bias is not None:
                        lora_layer.bias = layer.bias
                    set_module(model, n + '.' + n_, lora_layer)
    return model
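# Note: with merge_weights=True, loralib folds the low-rank update into the frozen
# weight when the wrapped module is switched to eval() mode, so inference costs no
# extra matmul. A minimal training-side sketch (illustrative, not part of this file):
#
#   lora.mark_only_lora_as_trainable(model.audio_branch)  # freeze everything but the LoRA factors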


class CLAPSep(nn.Module):
    def __init__(self, model_config, CLAP_path, use_lora=True, rank=16, nfft=1024):
        super().__init__()
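        # The mixture is processed at 32 kHz, but CLAP's HTSAT audio encoder expects 48 kHz input.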
        self.resampler = torchaudio.transforms.Resample(32000, 48000)
        self.clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
        self.clap_model.load_ckpt(CLAP_path)
        for p in self.clap_model.parameters():
            p.requires_grad = False
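        # A separate, trainable copy of the CLAP audio branch serves as the separation
        # encoder; the CLAP model itself stays frozen for embedding extraction.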
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
        self.decoder_model = HTSAT_Decoder(**model_config)
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
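        # List shared with the forward hooks below; it collects encoder activations
        # that the decoder consumes as skip features.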
        self.features = self.install_forward_hooks()

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
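        # Apply the predicted magnitude mask and reuse the mixture phase (cos/sin) for ISTFT resynthesis.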
        mag_y = torch.nn.functional.relu_(mag_x * mask)
        cos_y = cos_x
        sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred

    def inference_from_data(self, mixed, embed_pos, embed_neg):
        self.eval()
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
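        # The mixture magnitude becomes the first entry of the feature list (used as a decoder skip input).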
        self.features.append(mag)
        with torch.no_grad():
            embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
            self.audio_branch({"waveform": self.resampler(mixed)})
            mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
            pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
        del self.features[:]
        return pred

    def install_forward_hooks(self):
        features = []

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            features.append(output[0])

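        # HTSAT operates on a fixed-size spectrogram, so the extractor output is
        # zero-padded along the time axis to 1024 frames.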
        def spectrogram_padding(_, __, out):
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

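        # Hook the spectrogram extractor (to pad its output) and cache the patch-embed
        # and transformer-layer outputs as encoder features.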
        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features

if __name__ == '__main__':
    model_config = {"lan_embed_dim": 1024,
                    "depths": [1, 1, 1, 1],
                    "embed_dim": 128,
                    "encoder_embed_dim": 128,
                    "phase": False,
                    "spec_factor": 8,
                    "d_attn": 640,
                    "n_masker_layer": 3,
                    "conv": False}
    CLAP_path = "./music_audioset_epoch_15_esc_90.14.pt"

    model = CLAPSep(model_config, CLAP_path)
    ckpt = torch.load('best_model.ckpt', map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    audio, fs = librosa.load("./510_25.221254348754883_mixture.wav", sr=32000)
    # inference_from_data expects CLAP text embeddings, not raw prompts, so encode the
    # positive and negative queries with the frozen CLAP text branch first.
    embed_pos = model.clap_model.get_text_embedding([''], use_tensor=True)
    embed_neg = model.clap_model.get_text_embedding(['A vehicle engine revving then powering down.'], use_tensor=True)
    pred = model.inference_from_data(torch.tensor(audio).unsqueeze(0), embed_pos, embed_neg)
    import soundfile as sf
    sf.write('./pred.wav', pred.squeeze().numpy(), 32000)
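    # The query embeddings can also come from reference audio rather than text, e.g. via
    # laion_clap's get_audio_embedding_from_data (a sketch; the reference file name is hypothetical):
    #
    #   ref, _ = librosa.load('./reference.wav', sr=48000)
    #   embed_pos = model.clap_model.get_audio_embedding_from_data(
    #       torch.from_numpy(ref).unsqueeze(0), use_tensor=True)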