import os |
import sys |
import time |
import librosa |
import yaml |
import joblib |
import argparse |
import soundfile as sf |
import numpy as np |
from pathlib import Path |
from collections import defaultdict |
from typing import Optional |
from tqdm import tqdm |
sys.path.append(os.path.dirname(__file__)) |
sys.path.append(os.path.dirname(os.path.dirname(__file__))) |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
import torch.distributed as dist |
from torch.nn.parallel import DistributedDataParallel |
from nnet.WavLM import WavLM, WavLMConfig |
from vq.codec_encoder import CodecEncoder_Transformer |
from vq.codec_decoder_vocos import CodecDecoderVocos |
from vq.module import SemanticEncoder |
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
from collections import OrderedDict |
from loader.datareader import DataReader |
from loader.datareader_aec import DataReaderAEC |
from loader.datareader_tse import DataReaderTSE |
from nnet.llase import LLM_AR as model |
class Encodec():
    '''
    Load Xcodec2

    Wraps the codec encoder, the semantic (w2v-BERT) branch, the Vocos decoder
    and the two linear projections, all restored from one checkpoint whose
    ``state_dict`` keys are prefixed by sub-module name.

    Args:
        device: device every sub-module is moved to.
        ckpt_path: checkpoint file holding a ``state_dict`` entry.
        semantic_model_path: directory passed to ``from_pretrained`` for both
            the w2v-BERT model and its feature extractor.
    '''

    # Checkpoint key prefixes, in the order they are tested against each key.
    _STATE_PREFIXES = (
        'CodecEnc.',
        'generator.',
        'fc_post_a.',
        'SemanticEncoder_module.',
        'fc_prior.',
    )

    def __init__(self, device="cpu",
                 ckpt_path="./ckpt/codec_ckpt/epoch=4-step=1400000.ckpt",
                 semantic_model_path="./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0") -> None:
        self.device = device

        # The path must be a plain string: a trailing comma here would turn it
        # into a 1-tuple, which torch.load cannot open.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        parts = self._split_state_dict(ckpt['state_dict'])

        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            semantic_model_path,
            output_hidden_states=True)
        self.semantic_model = self.semantic_model.eval().to(self.device)

        self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
        self.SemanticEncoder_module.load_state_dict(parts['SemanticEncoder_module.'])
        self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device)

        self.encoder = CodecEncoder_Transformer()
        self.encoder.load_state_dict(parts['CodecEnc.'])
        self.encoder = self.encoder.eval().to(self.device)

        self.decoder = CodecDecoderVocos()
        self.decoder.load_state_dict(parts['generator.'])
        self.decoder = self.decoder.eval().to(self.device)

        self.fc_post_a = nn.Linear(2048, 1024)
        self.fc_post_a.load_state_dict(parts['fc_post_a.'])
        self.fc_post_a = self.fc_post_a.eval().to(self.device)

        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_prior.load_state_dict(parts['fc_prior.'])
        self.fc_prior = self.fc_prior.eval().to(self.device)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            semantic_model_path)

    @staticmethod
    def _split_state_dict(state_dict):
        '''
        Group checkpoint entries by sub-module prefix.

        Returns an ``OrderedDict`` mapping every prefix in ``_STATE_PREFIXES``
        to an ``OrderedDict`` of the matching entries with that prefix removed
        from the key. Keys matching no prefix are dropped.
        '''
        groups = OrderedDict(
            (prefix, OrderedDict()) for prefix in Encodec._STATE_PREFIXES)
        for key, value in state_dict.items():
            for prefix in Encodec._STATE_PREFIXES:
                if key.startswith(prefix):
                    groups[prefix][key[len(prefix):]] = value
                    break
        return groups

    def get_feat(self, wav_batch, pad=None):
        '''
        Run the feature extractor on one waveform or on each row of a batch.

        Args:
            wav_batch: tensor; a 2-D input is processed row by row, any other
                rank is passed to the extractor as a single item.
            pad: optional ``F.pad`` specification applied to each waveform
                first. ``None`` skips padding.

        Returns:
            The extractor's ``input_features``. For 2-D input the per-row
            results are concatenated on dim 0 and moved to ``self.device``;
            otherwise the tensor is returned as produced by the extractor.
        '''
        def _maybe_pad(wav):
            # F.pad rejects pad=None, so padding has to be skipped explicitly.
            return wav if pad is None else F.pad(wav, pad)

        if len(wav_batch.shape) != 2:
            return self.feature_extractor(
                _maybe_pad(wav_batch),
                sampling_rate=16000,
                return_tensors="pt").data['input_features']

        batch_feats = [
            self.feature_extractor(
                _maybe_pad(wav),
                sampling_rate=16000,
                return_tensors="pt"
            ).data['input_features']
            for wav in wav_batch
        ]
        return torch.concat(batch_feats, dim=0).to(self.device)

    def get_embedding(self, wav_cpu):
        '''
        Build the embedding: semantic-branch output concatenated with the
        codec-encoder output along dim 1.

        Args:
            wav_cpu: 1-D waveform or 2-D batch of waveforms (moved to CPU here
                for feature extraction).
        '''
        wav_cpu = wav_cpu.cpu()
        # First attempt pads 160 elements on both ends of the last axis.
        feat = self.get_feat(wav_cpu, pad=(160, 160))
        feat = feat.to(self.device)

        if len(wav_cpu.shape) == 1:
            wav = wav_cpu.unsqueeze(0).to(self.device)
        else:
            wav = wav_cpu.to(self.device)
        # Right-pad to a multiple of 200. An input that is already a multiple
        # of 200 still receives a full extra 200 elements (200 - 0).
        wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200))))

        with torch.no_grad():
            vq_emb = self.encoder(wav.unsqueeze(1))
            vq_emb = vq_emb.transpose(1, 2)

            # If the two branches disagree on length, redo the features
            # without the extra padding.
            if vq_emb.shape[2] != feat.shape[1]:
                feat = self.get_feat(wav_cpu)
                feat = feat.to(self.device)

            semantic_target = self.semantic_model(feat)
            semantic_target = semantic_target.hidden_states[16]
            semantic_target = semantic_target.transpose(1, 2)
            semantic_target = self.SemanticEncoder_module(semantic_target)

            vq_emb = torch.cat([semantic_target, vq_emb], dim=1)
        return vq_emb

    def emb2token(self, emb):
        '''
        Project an embedding with ``fc_prior`` and return the code indices the
        decoder produces when called with ``vq=True``.
        '''
        # Tensor.to is not in-place; the result has to be kept.
        emb = emb.to(self.device)
        emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2)
        _, vq_code, _ = self.decoder(emb, vq=True)
        return vq_code

    def token2wav(self, vq_code):
        '''
        Decode code indices back to a waveform (squeezed tensor).
        '''
        # Tensor.to is not in-place; the result has to be kept.
        vq_code = vq_code.to(self.device)
        vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
        vq_post_emb = vq_post_emb.transpose(1, 2)
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
        recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze()
        return recon
class WavLM_feat(object):
    '''
    Load WavLM

    Callable wrapper that restores a WavLM checkpoint once and then returns
    the layer-6 features for a waveform tensor.

    Args:
        device: device the model is moved to (``None`` leaves it on CPU).
        path: optional checkpoint path; when omitted the default of
            ``_reload_wavLM_large`` is used.
    '''
    def __init__(self, device, path=None):
        if path is None:
            self.wavlm = self._reload_wavLM_large(device=device)
        else:
            self.wavlm = self._reload_wavLM_large(path=path, device=device)

    def __call__(self, wav):
        '''
        Extract features for ``wav``.

        The input is flattened to 2-D keeping its last axis, and the output is
        flattened to 2-D keeping the feature (last) axis, so batch and frame
        axes are merged in the returned tensor.
        '''
        num_samples = wav.shape[-1]
        wav = wav.reshape(-1, num_samples)
        with torch.no_grad():
            feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0]
            feat = torch.reshape(feat, (-1, feat.shape[-1]))
        return feat

    def _reload_wavLM_large(self, path="/home/bykang/WavLM-Large.pt", device: Optional[torch.device] = None):
        '''
        Build a frozen, eval-mode WavLM from the checkpoint at ``path``.

        The checkpoint is expected to provide ``cfg`` and ``model`` entries.
        '''
        cpt = torch.load(path, map_location="cpu")
        cfg = WavLMConfig(cpt['cfg'])
        wavLM = WavLM(cfg)
        wavLM.load_state_dict(cpt['model'])
        wavLM.eval()
        if device is not None:
            wavLM = wavLM.to(device)
        # Inference only: no parameter should ever collect gradients.
        for p in wavLM.parameters():
            p.requires_grad = False
        print('successful to reload wavLM', path)
        return wavLM
def get_firstchannel_read(path, fs=16000):
    '''
    Read an audio file, resample it to ``fs`` when its native rate differs,
    and return only the first channel as a 1-D array.
    '''
    samples, native_sr = sf.read(path)

    if native_sr != fs:
        # Multi-channel data is flipped so the resampler works along the last
        # axis, then flipped back afterwards.
        if samples.ndim != 1:
            samples = samples.T
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=fs)
        if samples.ndim != 1:
            samples = samples.T

    if samples.ndim > 1:
        return samples[:, 0]
    return samples
def load_obj(obj, device):
    '''
    Recursively move every tensor found in ``obj`` to ``device``.

    Only ``dict`` and ``list`` containers are descended into (and rebuilt);
    any other non-tensor value, including tuples, is returned unchanged.
    '''
    if isinstance(obj, dict):
        moved = {}
        for name in obj:
            moved[name] = load_obj(obj[name], device)
        return moved

    if isinstance(obj, list):
        return [load_obj(item, device) for item in obj]

    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    return obj
def _load_feat(path):
    '''Load a saved ``.npy`` feature array and prepend a batch axis.'''
    return torch.from_numpy(np.load(path)).unsqueeze(0)


def _read_wav_as_batch(path, device):
    '''Read the first channel of ``path`` as float32 with a leading batch axis on ``device``.'''
    return torch.from_numpy(get_firstchannel_read(path).astype(np.float32)).unsqueeze(0).to(device)


def _decode_and_save(codec, logits, target_path, sr=16000, verbose=True):
    '''Take the argmax of ``logits`` over dim 1, decode it with ``codec`` and write a wav file.'''
    _, token_ids = torch.max(logits, dim=1)
    recon = codec.token2wav(token_ids.unsqueeze(0)).squeeze().detach().cpu().numpy()
    if verbose:
        print(target_path)
    sf.write(target_path, recon, sr)


def _run_inference(args, local_rank):
    '''Load config and models, then run the configured number of inference rounds.'''
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)
    print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", local_rank)

    with open(args.conf, "r") as f:
        # NOTE(review): FullLoader can build more than plain data types; only
        # point -conf at trusted files.
        conf = yaml.load(f, Loader=yaml.FullLoader)

    task = conf["task"]
    # AEC and TSE work on a (mic, ref) pair; every other task on one signal.
    has_ref = task in ("AEC", "TSE")

    if task == "AEC":
        data_reader = DataReaderAEC(**conf["datareader"])
    elif task == "TSE":
        data_reader = DataReaderTSE(**conf["datareader"])
    else:
        data_reader = DataReader(**conf["datareader"])

    codec = Encodec(device)
    wavlm_feat = WavLM_feat(device)

    nnet = model(**conf["nnet_conf"])
    cpt = torch.load(Path(conf["test"]["checkpoint"]), map_location="cpu")
    nnet = nnet.to(device)
    nnet = DistributedDataParallel(nnet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
    # Loaded after wrapping, so the checkpoint keys must match the wrapped module.
    nnet.load_state_dict(cpt["model_state_dict"])
    nnet.eval()

    origin_feat_dir = conf["save"]["feat_dir"]
    origin_wav_dir = conf["save"]["wav_dir"]

    # exist_ok avoids the check-then-create race between ranks that all try
    # to create the same output directories.
    if has_ref:
        os.makedirs(origin_feat_dir + "/mic", exist_ok=True)
        os.makedirs(origin_feat_dir + "/ref", exist_ok=True)
    else:
        os.makedirs(origin_feat_dir, exist_ok=True)
    os.makedirs(origin_wav_dir, exist_ok=True)

    # Kept as an equality test so non-boolean config values behave as before.
    extract_on_first_round = conf["test"]["infer_feat_too"] == True  # noqa: E712

    sr = 16000
    last_wav_dir = origin_wav_dir
    for inference_time in range(conf["test"]["inference_time"]):
        # Round 0 uses the configured directories; later rounds get their own.
        suffix = "inference" + str(inference_time) if inference_time > 0 else ""
        feat_dir = origin_feat_dir + suffix
        wav_dir = origin_wav_dir + suffix
        os.makedirs(feat_dir, exist_ok=True)
        os.makedirs(wav_dir, exist_ok=True)

        with torch.no_grad():
            # ---- Stage 1: WavLM feature extraction --------------------------
            # Later rounds always re-extract, reading the previous round's wavs.
            if extract_on_first_round or inference_time > 0:
                if has_ref:
                    os.makedirs(os.path.join(feat_dir, "mic"), exist_ok=True)
                for egs in tqdm(data_reader):
                    egs = load_obj(egs, device)
                    if has_ref:
                        if inference_time > 0:
                            mic_path = last_wav_dir + '/' + egs["mic_name"] + ".wav"
                            egs["mic"] = _read_wav_as_batch(mic_path, device)
                        else:
                            egs["mic"] = egs["mic"].contiguous()
                            egs["ref"] = egs["ref"].contiguous()

                        out_mic = wavlm_feat(egs["mic"]).detach().squeeze(0).cpu().numpy()
                        np.save(os.path.join(feat_dir, "mic", egs["mic_name"]), out_mic)

                        # Reference features are extracted once, in round 0,
                        # and always read back from the original feat dir.
                        if inference_time == 0:
                            out_ref = wavlm_feat(egs["ref"]).detach().squeeze(0).cpu().numpy()
                            np.save(os.path.join(origin_feat_dir, "ref", egs["ref_name"]), out_ref)

                        torch.cuda.empty_cache()
                    else:
                        if inference_time > 0:
                            mix_path = last_wav_dir + '/' + egs["name"] + ".wav"
                            egs["mix"] = _read_wav_as_batch(mix_path, device)
                        else:
                            egs["mix"] = egs["mix"].contiguous()

                        out = wavlm_feat(egs["mix"]).detach().squeeze(0).cpu().numpy()
                        np.save(os.path.join(feat_dir, egs["name"]), out)

            # ---- Stage 2: model inference and waveform decoding -------------
            for egs in tqdm(data_reader):
                egs = load_obj(egs, device)

                if has_ref:
                    feat_mic = _load_feat(os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy")
                    feat_ref = _load_feat(os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy")

                    if task == "AEC":
                        # The reference is only fed in the first round.
                        use_ref = inference_time == 0
                        out_index = 1
                    else:
                        # TSE: the reference is kept in every round while
                        # if_keep_ref is True.
                        if_keep_ref = True
                        use_ref = not (inference_time > 0 and not if_keep_ref)
                        out_index = 0

                    est = nnet(feat_mic, feat_ref) if use_ref else nnet(feat_mic)
                    target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
                    _decode_and_save(codec, est[out_index], target_path, sr)

                elif task == "SS":
                    feat = _load_feat(os.path.join(feat_dir, egs["name"]) + ".npy")
                    est = nnet(feat)
                    target_path_1 = os.path.join(wav_dir, egs["name"] + ".wav")
                    _decode_and_save(codec, est[1], target_path_1, sr, verbose=False)

                    if inference_time > 0:
                        # Second output: the original features together with
                        # the current round's features, written next to the
                        # previous round's wav directory.
                        origin_feat = _load_feat(os.path.join(origin_feat_dir, egs["name"]) + ".npy")
                        est2 = nnet(origin_feat, feat)
                        second_dir = last_wav_dir + "s2"
                        os.makedirs(second_dir, exist_ok=True)
                        target_path_2 = os.path.join(second_dir, egs["name"] + ".wav")
                        _decode_and_save(codec, est2[1], target_path_2, sr, verbose=False)

                else:
                    # PLC and every remaining task share the same path.
                    feat = _load_feat(os.path.join(feat_dir, egs["name"]) + ".npy")
                    est = nnet(feat)
                    target_path = os.path.join(wav_dir, egs["name"] + ".wav")
                    _decode_and_save(codec, est[1], target_path, sr)

        last_wav_dir = wav_dir


def run(args):
    '''
    Entry point for distributed inference.

    Reads ``LOCAL_RANK``, ``WORLD_SIZE`` and ``RANK`` from the environment,
    initialises the process group with ``args.backend``, runs the inference
    described by the YAML file at ``args.conf`` and finally tears the process
    group down again.

    Raises:
        KeyError: if one of the rank environment variables is missing.
    '''
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    world_rank = int(os.environ['RANK'])
    dist.init_process_group(args.backend, rank=world_rank, world_size=world_size)
    try:
        _run_inference(args, local_rank)
    finally:
        # Release the communicator even when inference fails.
        dist.destroy_process_group()
if __name__ == "__main__":
    # Command-line interface: a required config path and the distributed backend.
    cli = argparse.ArgumentParser(
        description="Command to test separation model in Pytorch",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        "-conf",
        type=str,
        required=True,
        help="Yaml configuration file for training",
    )
    cli.add_argument(
        "--backend",
        type=str,
        default="nccl",
        choices=["nccl", "gloo"],
    )
    parsed_args = cli.parse_args()

    # Set before run() so it is in place when the process group is created.
    os.environ["NCCL_DEBUG"] = "INFO"
    run(parsed_args)