# LLaSE-G1 / inference.py
import os
import sys
import time
import librosa
import yaml
import joblib
import argparse
import soundfile as sf
import numpy as np
from pathlib import Path
from collections import defaultdict
from typing import Optional
from tqdm import tqdm
sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
# WavLM
from nnet.WavLM import WavLM, WavLMConfig
# Xcodec2
from vq.codec_encoder import CodecEncoder_Transformer
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
from collections import OrderedDict
# Dataloader
from loader.datareader import DataReader
from loader.datareader_aec import DataReaderAEC
from loader.datareader_tse import DataReaderTSE
# LLaSE
from nnet.llase import LLM_AR as model
class Encodec():
    '''
    Load the XCodec2 codec: acoustic encoder/decoder, semantic encoder and projection layers.
    '''
    def __init__(self, device="cpu") -> None:
        self.device = device
        ckpt = "./ckpt/codec_ckpt/epoch=4-step=1400000.ckpt"
        # ckpt = '/home/bykang/codec_ckpt/epoch=4-step=1400000.ckpt'
        ckpt = torch.load(ckpt, map_location='cpu')
state_dict = ckpt['state_dict']
filtered_state_dict_codec = OrderedDict()
filtered_state_dict_semantic_encoder = OrderedDict()
filtered_state_dict_gen = OrderedDict()
filtered_state_dict_fc_post_a = OrderedDict()
filtered_state_dict_fc_prior = OrderedDict()
for key, value in state_dict.items():
if key.startswith('CodecEnc.'):
new_key = key[len('CodecEnc.'):]
filtered_state_dict_codec[new_key] = value
elif key.startswith('generator.'):
new_key = key[len('generator.'):]
filtered_state_dict_gen[new_key] = value
elif key.startswith('fc_post_a.'):
new_key = key[len('fc_post_a.'):]
filtered_state_dict_fc_post_a[new_key] = value
elif key.startswith('SemanticEncoder_module.'):
new_key = key[len('SemanticEncoder_module.'):]
filtered_state_dict_semantic_encoder[new_key] = value
elif key.startswith('fc_prior.'):
new_key = key[len('fc_prior.'):]
filtered_state_dict_fc_prior[new_key] = value
self.semantic_model = Wav2Vec2BertModel.from_pretrained(
"./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0",
# "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b",
output_hidden_states=True)
self.semantic_model=self.semantic_model.eval().to(self.device)
self.SemanticEncoder_module = SemanticEncoder(1024,1024,1024)
self.SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder)
self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device)
self.encoder = CodecEncoder_Transformer()
self.encoder.load_state_dict(filtered_state_dict_codec)
self.encoder = self.encoder.eval().to(self.device)
self.decoder = CodecDecoderVocos()
self.decoder.load_state_dict(filtered_state_dict_gen)
self.decoder = self.decoder.eval().to(self.device)
        self.fc_post_a = nn.Linear(2048, 1024)
        self.fc_post_a.load_state_dict(filtered_state_dict_fc_post_a)
        self.fc_post_a = self.fc_post_a.eval().to(self.device)
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_prior.load_state_dict(filtered_state_dict_fc_prior)
        self.fc_prior = self.fc_prior.eval().to(self.device)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
"./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0")
# "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b")
    def get_feat(self, wav_batch, pad=None):
        if len(wav_batch.shape) != 2:
            return self.feature_extractor(F.pad(wav_batch, pad), sampling_rate=16000, return_tensors="pt").data['input_features']
padded_wavs = torch.stack([F.pad(wav, pad) for wav in wav_batch])
batch_feats = []
for wav in padded_wavs:
feat = self.feature_extractor(
wav,
sampling_rate=16000,
return_tensors="pt"
).data['input_features']
batch_feats.append(feat)
feat_batch = torch.concat(batch_feats, dim=0).to(self.device)
return feat_batch
def get_embedding(self, wav_cpu):
wav_cpu = wav_cpu.cpu()
        feat = self.get_feat(wav_cpu, pad=(160, 160))
        feat = feat.to(self.device)
        if len(wav_cpu.shape) == 1:
            wav = wav_cpu.unsqueeze(0).to(self.device)
        else:
            wav = wav_cpu.to(self.device)
wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200))))
with torch.no_grad():
vq_emb = self.encoder(wav.unsqueeze(1))
vq_emb = vq_emb.transpose(1, 2)
            if vq_emb.shape[2] != feat.shape[1]:
                feat = self.get_feat(wav_cpu)
                feat = feat.to(self.device)
            semantic_target = self.semantic_model(feat[:, :, :])
semantic_target = semantic_target.hidden_states[16]
semantic_target = semantic_target.transpose(1, 2)
semantic_target = self.SemanticEncoder_module(semantic_target)
vq_emb = torch.cat([semantic_target, vq_emb], dim=1)
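            # A hedged note on the layout (inferred from the layer shapes above, not stated in
            # the source): the concatenated embedding stacks the 1024-dim semantic branch on top
            # of the codec-encoder branch along the channel axis, giving the 2048-dim input that
            # fc_prior expects in emb2token below.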
return vq_emb
    def emb2token(self, emb):
        emb = emb.to(self.device)
emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2)
_, vq_code, _ = self.decoder(emb, vq=True)
return vq_code
    def token2wav(self, vq_code):
        vq_code = vq_code.to(self.device)
vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
vq_post_emb = vq_post_emb.transpose(1, 2)
vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1,2)).transpose(1,2)
recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze()
        # To write the wav to disk, append .squeeze().detach().cpu().numpy() to the output;
        # keep it as a tensor (as returned here) if gradients are needed.
return recon
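    # Illustrative round-trip sketch (not executed by this script; assumes `wav` is a 16 kHz
    # mono waveform tensor):
    #   codec = Encodec(device="cuda")
    #   emb = codec.get_embedding(wav)    # concatenated semantic + acoustic embedding
    #   tokens = codec.emb2token(emb)     # discrete codec indices
    #   recon = codec.token2wav(tokens)   # reconstructed waveform tensor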
class WavLM_feat(object):
    '''
    Load WavLM and extract intermediate-layer features.
    '''
def __init__(self, device):
self.wavlm = self._reload_wavLM_large(device=device)
def __call__(self, wav):
T = wav.shape[-1]
wav = wav.reshape(-1, T)
with torch.no_grad():
feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0]
B, T, D = feat.shape
feat = torch.reshape(feat, (-1, D))
return feat
def _reload_wavLM_large(self, path="/home/bykang/WavLM-Large.pt", device: Optional[torch.device] = None):
cpt = torch.load(path, map_location="cpu")
cfg = WavLMConfig(cpt['cfg'])
wavLM = WavLM(cfg)
wavLM.load_state_dict(cpt['model'])
wavLM.eval()
        if device is not None:
            wavLM = wavLM.to(device)
        for p in wavLM.parameters():
            p.requires_grad = False
        print('Successfully reloaded WavLM from', path)
return wavLM
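    # Usage sketch (illustrative): `WavLM_feat(device)(wav)` returns layer-6 WavLM features
    # flattened over the batch dimension. Note the default `path` above is the author's local
    # copy of WavLM-Large.pt; point it at your own checkpoint location.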
def get_firstchannel_read(path, fs=16000):
    '''
    Read a wav file, resample to fs if needed, and return the first channel.
    '''
wave_data, sr = sf.read(path)
if sr != fs:
if len(wave_data.shape) != 1:
wave_data = wave_data.transpose((1, 0))
wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs)
if len(wave_data.shape) != 1:
wave_data = wave_data.transpose((1, 0))
if len(wave_data.shape) > 1:
wave_data = wave_data[:, 0]
return wave_data
def load_obj(obj, device):
    '''
    Recursively move every tensor inside obj (dict / list / tensor) to the given device.
    '''
def cuda(obj):
return obj.to(device) if isinstance(obj, torch.Tensor) else obj
if isinstance(obj, dict):
return {key: load_obj(obj[key], device) for key in obj}
elif isinstance(obj, list):
return [load_obj(val, device) for val in obj]
else:
return cuda(obj)
def run(args):
LOCAL_RANK = int(os.environ['LOCAL_RANK'])
WORLD_SIZE = int(os.environ['WORLD_SIZE'])
WORLD_RANK = int(os.environ['RANK'])
dist.init_process_group(args.backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", LOCAL_RANK)
with open(args.conf, "r") as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
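    # Expected config layout, sketched from the keys read below (illustrative values; the exact
    # datareader / nnet_conf fields depend on the loader and model definitions):
    #   task: AEC              # AEC / TSE / PLC / SS; any other value runs the default branch
    #   datareader: {...}      # kwargs for DataReader / DataReaderAEC / DataReaderTSE
    #   nnet_conf: {...}       # kwargs for nnet.llase.LLM_AR
    #   test:
    #     checkpoint: /path/to/llase_checkpoint.pt
    #     infer_feat_too: true
    #     inference_time: 1
    #   save:
    #     feat_dir: ./decode/feat
    #     wav_dir: ./decode/wav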
# Dataloader
if conf["task"]=="AEC":
data_reader = DataReaderAEC(**conf["datareader"])
elif conf["task"]=="TSE":
data_reader = DataReaderTSE(**conf["datareader"])
else:
data_reader = DataReader(**conf["datareader"])
# Load WavLM and XCodec2
codec = Encodec(device)
wavlm_feat = WavLM_feat(device)
# Load LLaSE
nnet = model(**conf["nnet_conf"])
cpt_fname = Path(conf["test"]["checkpoint"])
cpt = torch.load(cpt_fname, map_location="cpu")
nnet = nnet.to(device)
nnet = DistributedDataParallel(nnet, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, find_unused_parameters=True)
nnet.load_state_dict(cpt["model_state_dict"])
nnet.eval()
    # Make sure the output dirs exist
    if conf["task"] == "AEC" or conf["task"] == "TSE":
        os.makedirs(conf["save"]["feat_dir"] + "/mic", exist_ok=True)
        os.makedirs(conf["save"]["feat_dir"] + "/ref", exist_ok=True)
    else:
        os.makedirs(conf["save"]["feat_dir"], exist_ok=True)
    os.makedirs(conf["save"]["wav_dir"], exist_ok=True)
# Main of inference
if_feat_too = conf["test"]["infer_feat_too"]
origin_feat_dir = conf["save"]["feat_dir"]
origin_wav_dir = conf["save"]["wav_dir"]
last_feat_dir = origin_feat_dir
last_wav_dir = origin_wav_dir
for inference_time in range(conf["test"]["inference_time"]):
# For multi-inference
if inference_time > 0:
feat_dir = origin_feat_dir + "inference" + str(inference_time)
wav_dir = origin_wav_dir + "inference" + str(inference_time)
else:
feat_dir = origin_feat_dir
wav_dir = origin_wav_dir
if not os.path.exists(feat_dir):
os.makedirs(feat_dir)
if not os.path.exists(wav_dir):
os.makedirs(wav_dir)
with torch.no_grad():
# Extract WavLM features
            if if_feat_too or inference_time > 0:
for egs in tqdm(data_reader):
egs = load_obj(egs, device)
if conf["task"]=="AEC" or conf["task"]=="TSE":
if inference_time > 0:
mic_path = last_wav_dir + '/' + egs["mic_name"] + ".wav"
egs["mic"] = torch.from_numpy(get_firstchannel_read(mic_path).astype(np.float32)).unsqueeze(0).to(device)
else:
egs["mic"]=egs["mic"].contiguous()
egs["ref"]=egs["ref"].contiguous()
feat_mic = wavlm_feat(egs["mic"])
out_mic = feat_mic.detach().squeeze(0).cpu().numpy()
if not os.path.exists(os.path.join(feat_dir, "mic")):
os.makedirs(os.path.join(feat_dir, "mic"))
np.save(os.path.join(feat_dir, "mic", egs["mic_name"]), out_mic)
                        # For AEC and TSE, reference features only need to be extracted on the first pass
if inference_time == 0:
feat_ref = wavlm_feat(egs["ref"])
out_ref = feat_ref.detach().squeeze(0).cpu().numpy()
np.save(os.path.join(origin_feat_dir, "ref", egs["ref_name"]), out_ref)
torch.cuda.empty_cache()
else:
if inference_time > 0:
mix_path = last_wav_dir + '/' + egs["name"] + ".wav"
egs["mix"] = torch.from_numpy(get_firstchannel_read(mix_path).astype(np.float32)).unsqueeze(0).to(device)
else:
egs["mix"]=egs["mix"].contiguous()
feat = wavlm_feat(egs["mix"])
out = feat.detach().squeeze(0).cpu().numpy()
np.save(os.path.join(feat_dir, egs["name"]), out)
# Predict the clean tokens and token2wav
for egs in tqdm(data_reader):
egs = load_obj(egs, device)
sr = 16000
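                # The model returns several heads of token logits; each task branch below takes
                # the per-frame argmax of one head (est[0] for TSE, est[1] otherwise) and decodes
                # it back to audio with the codec.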
if conf["task"] == "AEC":
# Get feat
feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy"
feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy"
feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0)
feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0)
# For multi-inference
if inference_time > 0:
est = nnet(feat_mic)
else:
est = nnet(feat_mic, feat_ref)
# Get tokens and token2wav
max, max_indices_1 = torch.max(est[1], dim=1)
recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
# Save the wav
target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
print(target_path)
sf.write(target_path , recon_1, sr)
elif conf["task"] == "TSE" :
# Get feat
feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy"
feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy"
feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0)
feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0)
# Choose if keep the enroallment audio while multi-inference
if_keep_ref = True
if inference_time>0 and if_keep_ref== False:
est = nnet(feat_mic)
else:
est = nnet(feat_mic, feat_ref)
# Get tokens and token2wav
max, max_indices_1 = torch.max(est[0], dim=1)
recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
# Save the wav
target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
print(target_path)
sf.write(target_path , recon_1, sr)
elif conf["task"] == "PLC":
# Get feat
feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)
# Get tokens and token2wav
est = nnet(feat)
max, max_indices_1 = torch.max(est[1], dim=1)
recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
# Save the wav
target_path = os.path.join(wav_dir, egs["name"] + ".wav")
print(target_path)
sf.write(target_path , recon_1, sr)
elif conf["task"] == "SS":
# Get feat
feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)
# Separate the first speaker
est = nnet(feat)
max, max_indices_1 = torch.max(est[1], dim=1)
recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
target_path_1 = os.path.join(wav_dir, egs["name"] + ".wav")
sf.write(target_path_1 , recon_1, sr)
# Separate the second speaker, SS need at least 2 inference time in config
if inference_time > 0:
origin_feat_path = os.path.join(origin_feat_dir, egs["name"]) + ".npy"
origin_feat = torch.from_numpy(np.load(origin_feat_path)).unsqueeze(0)
est2 = nnet(origin_feat, feat)
max, max_indices_2 = torch.max(est2[1], dim=1)
recon_2 = codec.token2wav(max_indices_2.unsqueeze(0)).squeeze().detach().cpu().numpy()
if not os.path.exists(last_wav_dir + "s2"):
os.makedirs(last_wav_dir + "s2")
target_path_2 = os.path.join(last_wav_dir + "s2", egs["name"] + ".wav")
sf.write(target_path_2 , recon_2, sr)
                else:
                    # Get feat
                    feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
                    feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)
                    # Get tokens and token2wav
                    est = nnet(feat)
                    _, max_indices_1 = torch.max(est[1], dim=1)
                    recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
                    # Save the wav
                    target_path = os.path.join(wav_dir, egs["name"] + ".wav")
                    print(target_path)
                    sf.write(target_path, recon_1, sr)
# For next inference
last_feat_dir = feat_dir
last_wav_dir = wav_dir
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Command to run LLaSE-G1 inference in PyTorch",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-conf",
                        type=str,
                        required=True,
                        help="Yaml configuration file for inference")
parser.add_argument("--backend",
type=str,
default="nccl",
choices=["nccl", "gloo"])
args = parser.parse_args()
# for nccl debug
os.environ["NCCL_DEBUG"] = "INFO"
run(args)
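    # Example launch (illustrative; single process per GPU, config path is an assumption):
    #   torchrun --nproc_per_node=1 inference.py -conf ./config/test.yaml --backend nccl
    # torchrun sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables that run() reads.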