# pylint: disable=C0301
'''
This module contains the AudioProcessor class and related functions for processing audio data.
It uses various libraries and models to perform tasks such as preprocessing, feature
extraction, and vocal separation. The class is initialized with configuration parameters
and can process audio files using the provided models.
'''
import os
import subprocess

import librosa
import torch
import torch.nn.functional as F
from audio_separator.separator import Separator
from transformers import AutoFeatureExtractor, WhisperModel


def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
    """Linearly resample features along the time axis from input_fps to output_fps."""
    features = features.transpose(1, 2)  # [1, T, C] -> [1, C, T]
    seq_len = features.shape[2] / float(input_fps)  # duration in seconds
    if output_len is None:
        output_len = int(seq_len * output_fps)
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)  # [1, C, T] -> [1, T, C]
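
# A minimal usage sketch for linear_interpolation_fps; the tensor shapes are
# illustrative assumptions, not values taken from this pipeline:
#
#     feats = torch.randn(1, 100, 1280)   # 100 feature frames at 50 fps
#     halved = linear_interpolation_fps(feats, input_fps=50, output_fps=25)
#     print(halved.shape)                 # torch.Size([1, 50, 1280])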


def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """Resample input_audio_file to sample_rate with ffmpeg, writing output_audio_file."""
    p = subprocess.Popen([
        "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
    ])
    ret = p.wait()
    if ret != 0:
        raise RuntimeError(f"Resampling {input_audio_file} failed with exit code {ret}.")
    return output_audio_file
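
# Example usage (hypothetical file names; ffmpeg must be on PATH):
#     resample_audio("speech_44k.wav", "speech_16k.wav", 16000)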


class AudioProcessor:
    """
    AudioProcessor handles the processing of audio files.
    It preprocesses audio files, extracts features with a Whisper encoder,
    and separates the vocal stem from the input signal if needed.

    :param sample_rate: sampling rate of the audio file
    :param fps: frames per second of the extracted features
    :param wav2vec_model_path: path to the Whisper model checkpoint
        (the parameter name is historical; a Whisper model is loaded)
    :param wav2vec_feature_type: type of feature to extract (currently unused)
    :param audio_separator_model_path: path to the audio separator model
    :param audio_separator_model_name: name of the audio separator model
    :param cache_dir: directory for caching intermediate results
    :param device: device on which to run the processing

    See the usage sketch at the end of this module.
    """

    def __init__(
        self,
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_feature_type,  # accepted but not used in this implementation
        audio_separator_model_path: str = None,
        audio_separator_model_name: str = None,
        cache_dir: str = '',
        device="cuda:0",
    ) -> None:
        self.sample_rate = sample_rate
        self.fps = fps
        self.device = device
        self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
        self.whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)

        if audio_separator_model_name is not None:
            try:
                os.makedirs(cache_dir, exist_ok=True)
            except OSError:
                print("Failed to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=cache_dir,
                output_single_stem="vocals",
                model_file_dir=audio_separator_model_path,
            )
            self.audio_separator.load_model(audio_separator_model_name)
            assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
        else:
            self.audio_separator = None
            print("Using audio directly, without vocal separation.")

    def get_audio_feature(self, audio_path):
        """Load audio at 16 kHz and extract Whisper log-mel input features in 30 s chunks."""
        audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
        assert sampling_rate == 16000

        audio_features = []
        window = 750 * 640  # 480000 samples = 30 s at 16 kHz, Whisper's input window
        for i in range(0, len(audio_input), window):
            audio_feature = self.feature_extractor(
                audio_input[i:i + window],
                sampling_rate=sampling_rate,
                return_tensors="pt",
            ).input_features
            audio_features.append(audio_feature)
        audio_features = torch.cat(audio_features, dim=-1)
        # 640 samples per frame at 16 kHz corresponds to 25 fps
        return audio_features, len(audio_input) // 640
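
    # Expected shapes for a hypothetical 10 s clip with an 80-mel Whisper
    # checkpoint (the mel-bin count depends on the checkpoint):
    #     features, audio_len = processor.get_audio_feature("speech.wav")
    #     # features: [1, 80, 3000] -- one zero-padded 30 s mel window
    #     # audio_len: 160000 // 640 = 250 frames at 25 fps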

    def preprocess(self, audio_path: str):
        """Run the Whisper encoder over the extracted features and merge its hidden states."""
        audio_input, audio_len = self.get_audio_feature(audio_path)
        audio_feature = audio_input.to(self.whisper.device).float()

        window = 3000  # encoder window: 3000 mel frames = 30 s
        audio_prompts = []
        for i in range(0, audio_feature.shape[-1], window):
            audio_prompt = self.whisper.encoder(
                audio_feature[:, :, i:i + window], output_hidden_states=True
            ).hidden_states
            audio_prompt = torch.stack(audio_prompt, dim=2)  # [1, T, num_layers + 1, C]
            audio_prompts.append(audio_prompt)
        audio_prompts = torch.cat(audio_prompts, dim=1)
        # The encoder output runs at 50 fps, i.e. twice the 25 fps frame count.
        audio_prompts = audio_prompts[:, :audio_len * 2]

        audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
        return audio_emb, audio_emb.shape[0]

    def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
        """Merge per-layer hidden states into a compact [T, 5, C] embedding."""
        if wav_enc_type == "wav2vec":
            feat_merge = audio_emb
        elif wav_enc_type == "whisper":
            # audio_emb: [1, T, 33, 1280] -- 33 hidden states of a 32-layer encoder
            feat0 = linear_interpolation_fps(audio_emb[:, :, 0:8].mean(dim=2), 50, 25)
            feat1 = linear_interpolation_fps(audio_emb[:, :, 8:16].mean(dim=2), 50, 25)
            feat2 = linear_interpolation_fps(audio_emb[:, :, 16:24].mean(dim=2), 50, 25)
            feat3 = linear_interpolation_fps(audio_emb[:, :, 24:32].mean(dim=2), 50, 25)
            feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
            feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]  # [T, 5, 1280]
        else:
            raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
        return feat_merge
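
    # Shape walkthrough for the "whisper" branch, assuming a 1280-dim encoder
    # with 32 layers (33 hidden states including the embedding output):
    #     input  audio_emb:   [1, T, 33, 1280] at 50 fps
    #     feat0..feat3:       [1, T/2, 1280] (mean of 8 hidden states each, resampled to 25 fps)
    #     feat4:              [1, T/2, 1280] (final hidden state, resampled to 25 fps)
    #     output feat_merge:  [T/2, 5, 1280]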

    def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
        """Slice audio_emb into one window per latent step, zero-padding out-of-range frames."""
        zero_audio_embed = torch.zeros(
            (audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
        zero_audio_embed_3 = torch.zeros(
            (3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
        iter_ = 1 + (frame_num - 1) // 4
        audio_emb_wind = []
        for lt_i in range(iter_):
            if lt_i == 0:
                # The first latent frame covers a single video frame: take a
                # 5-frame window around it and left-pad with 3 zero embeddings
                # to mark it.
                st = frame0_idx + lt_i - 2
                ed = frame0_idx + lt_i + 3
                wind_feat = torch.stack([
                    audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
                    for i in range(st, ed)
                ], dim=0)  # [5, 13, 768]
                wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)  # [8, 13, 768]
            else:
                # Each later latent frame covers 4 video frames, extended by
                # audio_shift frames of context on both sides.
                st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
                ed = frame0_idx + 1 + 4 * lt_i + audio_shift
                wind_feat = torch.stack([
                    audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
                    for i in range(st, ed)
                ], dim=0)  # [8, 13, 768]
            audio_emb_wind.append(wind_feat)
        audio_emb_wind = torch.stack(audio_emb_wind, dim=0)  # [iter_, 8, 13, 768]
        return audio_emb_wind, ed - audio_shift
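
    # A minimal sketch of the windowing above, using a dummy embedding tensor
    # (all dimensions here are illustrative assumptions):
    #     emb = torch.randn(100, 5, 1280)  # 100 frames of [5, 1280] features
    #     wind, next_idx = processor.get_audio_emb_window(emb, frame_num=16, frame0_idx=0)
    #     # wind: [4, 8, 5, 1280] -- one 8-frame audio window per latent step
    #     # next_idx: frame index where the next segment would start (here 13)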

    def close(self):
        """
        TODO: release model resources; not yet implemented.
        """
        return self

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.close()
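

if __name__ == "__main__":
    # A minimal end-to-end usage sketch. The checkpoint path and audio file are
    # hypothetical placeholders; substitute a real Whisper checkpoint and input.
    WHISPER_PATH = "./pretrained_models/whisper"  # hypothetical local checkpoint
    AUDIO_PATH = "./examples/speech.wav"          # hypothetical input file

    with AudioProcessor(
        sample_rate=16000,
        fps=25,
        wav2vec_model_path=WHISPER_PATH,
        wav2vec_feature_type="whisper",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    ) as processor:
        audio_emb, emb_len = processor.preprocess(AUDIO_PATH)
        print(audio_emb.shape, emb_len)  # [T, 5, C] at 25 fps
        windows, next_idx = processor.get_audio_emb_window(audio_emb, frame_num=16, frame0_idx=0)
        print(windows.shape, next_idx)   # [4, 8, 5, C]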