import os
import io
import glob
import math
import tarfile

import torch
import torchaudio
import safetensors

from .configuration_whisper import WhisperVQConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast


def load_quantize_encoder(model_path):
    """Load only the encoder (up to the quantizer) from a WhisperVQ checkpoint."""
    config = WhisperVQConfig.from_pretrained(model_path)
    config.quantize_encoder_only = True
    model = WhisperVQEncoder(config)

    # Gather encoder weights from the sharded safetensors files, skipping the
    # final layer norm and any transformer layers past the quantization position.
    state_dict = {}
    for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key.startswith("model.encoder."):
                    new_key = key[len("model.encoder."):]
                    if new_key.startswith("layer_norm"):
                        continue
                    if new_key.startswith("layers"):
                        layer_id = int(new_key.split(".")[1])
                        if layer_id >= config.quantize_position:
                            continue
                    state_dict[new_key] = f.get_tensor(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()
    return model


# Cache one Resample transform per source sample rate so it is only built once.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}


def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts):
    """Convert utterances (paths or (waveform, sample_rate) tuples) into discrete speech token ids."""
    with torch.no_grad():
        # Load each utterance, resample to 16 kHz, keep the first channel, and split it
        # into 30-second segments, remembering which utterance each segment came from.
        audios, indices = [], []
        for idx, utt in enumerate(utts):
            if isinstance(utt, tuple):
                audio, sample_rate = utt
            else:
                audio, sample_rate = torchaudio.load(utt)
            audio = audio.cuda()
            if sample_rate != 16000:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate,
                        new_freq=16000
                    ).to('cuda')
                audio = _resample_buffer[sample_rate](audio)
            # if audio.shape[0] > 1:
            #     audio = audio[:1]
            audio = audio[0]
            audio = audio.cpu().numpy()
            time_step = 0
            while time_step * 16000 < audio.shape[0]:
                audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
                audios.append(audio_segment)
                indices.append(idx)
                time_step += 30

        # Overall downsampling factor from raw samples to quantized tokens:
        # the two conv strides, optional pooling, and the feature extractor hop length.
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length

        all_speech_tokens = [[] for _ in range(len(utts))]
        batch_size = 128
        for start in range(0, len(audios), batch_size):
            features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
                                         return_attention_mask=True, return_tensors="pt", device='cuda',
                                         padding="longest", pad_to_multiple_of=stride)
            features = features.to(device="cuda")
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids
            # Downsample the attention mask to the token time resolution so padded
            # positions can be filtered out; reuse the local pooling_kernel_size
            # (defaults to 1 when the config value is unset).
            attention_mask = features.attention_mask[:, ::model.conv1.stride[0] * model.conv2.stride[0]]
            attention_mask = attention_mask[:, ::pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape
            # Reassemble tokens per original utterance in segment order.
            for i in range(len(speech_tokens)):
                idx = indices[start + i]
                speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
                all_speech_tokens[idx].extend(speech_token)
        return all_speech_tokens
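

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The checkpoint
    # directory and audio file names below are hypothetical placeholders;
    # point them at a real WhisperVQ checkpoint and your own recordings.
    model_path = "path/to/whisper-vq-checkpoint"  # hypothetical path
    encoder = load_quantize_encoder(model_path)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)

    # Inputs may be file paths or (waveform, sample_rate) tuples.
    utterances = ["example_1.wav", "example_2.wav"]  # hypothetical files
    tokens = extract_speech_token(encoder, feature_extractor, utterances)
    print([len(t) for t in tokens])  # number of discrete speech tokens per utterance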