lmzjms committed on
Commit 65ae2e3 · 1 Parent(s): e004ff0

Upload audio_foundation_models.py

Files changed (1)
  1. audio_foundation_models.py +939 -0
audio_foundation_models.py ADDED
@@ -0,0 +1,939 @@
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
6
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NeuralSeq'))
7
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_audio/Make_An_Audio'))
8
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'audio_detection'))
9
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mono2binaural'))
10
+ import matplotlib
11
+ import librosa
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
13
+ import torch
14
+ from diffusers import StableDiffusionPipeline
15
+ from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
16
+ import re
17
+ import uuid
18
+ import soundfile
19
+ from diffusers import StableDiffusionInpaintPipeline
20
+ from PIL import Image
21
+ import numpy as np
22
+ from omegaconf import OmegaConf
23
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
24
+ import cv2
25
+ import einops
26
+ from einops import repeat
27
+ from pytorch_lightning import seed_everything
28
+ import random
29
+ from ldm.util import instantiate_from_config
30
+ from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
31
+ from pathlib import Path
32
+ from vocoder.hifigan.modules import VocoderHifigan
33
+ from vocoder.bigvgan.models import VocoderBigVGAN
34
+ from ldm.models.diffusion.ddim import DDIMSampler
35
+ from wav_evaluation.models.CLAPWrapper import CLAPWrapper
36
+ from inference.svs.ds_e2e import DiffSingerE2EInfer
37
+ from audio_to_text.inference_waveform import AudioCapModel
38
+ import whisper
39
+ from text_to_speech.TTS_binding import TTSInference
40
+ from inference.svs.ds_e2e import DiffSingerE2EInfer
41
+ from inference.tts.GenerSpeech import GenerSpeechInfer
42
+ from utils.hparams import set_hparams
43
+ from utils.hparams import hparams as hp
44
+ from utils.os_utils import move_file
45
+ import scipy.io.wavfile as wavfile
46
+ from audio_infer.utils import config as detection_config
47
+ from audio_infer.pytorch.models import PVT
48
+ from src.models import BinauralNetwork
49
+ from sound_extraction.model.LASSNet import LASSNet
50
+ from sound_extraction.utils.stft import STFT
51
+ from sound_extraction.utils.wav_io import load_wav, save_wav
52
+ from target_sound_detection.src import models as tsd_models
53
+ from target_sound_detection.src.models import event_labels
54
+ from target_sound_detection.src.utils import median_filter, decode_with_timestamps
55
+ import clip
56
+
57
+
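+ # `prompts` attaches a human-readable tool name and description to an `inference` method;
+ # the agent framework that registers these classes can read the two attributes to decide
+ # which tool to invoke and how to format its input.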
58
+ def prompts(name, description):
59
+ def decorator(func):
60
+ func.name = name
61
+ func.description = description
62
+ return func
63
+
64
+ return decorator
65
+
66
+
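+ # Loads a latent-diffusion model from an OmegaConf config plus checkpoint, moves it to
+ # `device`, and wraps it in a DDIM sampler that the tools below use for generation.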
67
+ def initialize_model(config, ckpt, device):
68
+ config = OmegaConf.load(config)
69
+ model = instantiate_from_config(config.model)
70
+ model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
71
+
72
+ model = model.to(device)
73
+ model.cond_stage_model.to(model.device)
74
+ model.cond_stage_model.device = model.device
75
+ sampler = DDIMSampler(model)
76
+ return sampler
77
+
78
+
79
+ def initialize_model_inpaint(config, ckpt):
80
+ config = OmegaConf.load(config)
81
+ model = instantiate_from_config(config.model)
82
+ model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
83
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
84
+ model = model.to(device)
85
+ print(model.device, device, model.cond_stage_model.device)
86
+ sampler = DDIMSampler(model)
87
+ return sampler
88
+
89
+
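+ # Ranks candidate waveforms by CLAP text-audio similarity against `prompt` and returns
+ # the highest-scoring (sample_rate, wav) pair.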
90
+ def select_best_audio(prompt, wav_list):
91
+ clap_model = CLAPWrapper('text_to_audio/Make_An_Audio/useful_ckpts/CLAP/CLAP_weights_2022.pth',
92
+ 'text_to_audio/Make_An_Audio/useful_ckpts/CLAP/config.yml',
93
+ use_cuda=torch.cuda.is_available())
94
+ text_embeddings = clap_model.get_text_embeddings([prompt])
95
+ score_list = []
96
+ for data in wav_list:
97
+ sr, wav = data
98
+ audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav), sr)], resample=True)
99
+ score = clap_model.compute_similarity(audio_embeddings, text_embeddings,
100
+ use_logit_scale=False).squeeze().cpu().numpy()
101
+ score_list.append(score)
102
+ max_index = np.array(score_list).argmax()
103
+ print(score_list, max_index)
104
+ return wav_list[max_index]
105
+
106
+
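+ # Concatenates two (mono) wav files end-to-end and writes the result under audio/.
+ # Note: the output is written at the first file's sample rate, so both inputs are
+ # assumed to share a sample rate.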
107
+ def merge_audio(audio_path_1, audio_path_2):
108
+ merged_signal = []
109
+ sr_1, signal_1 = wavfile.read(audio_path_1)
110
+ sr_2, signal_2 = wavfile.read(audio_path_2)
111
+ merged_signal.append(signal_1)
112
+ merged_signal.append(signal_2)
113
+ merged_signal = np.hstack(merged_signal)
114
+ merged_signal = np.asarray(merged_signal, dtype=np.int16)
115
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
116
+ wavfile.write(audio_filename, sr_1, merged_signal)
117
+ return audio_filename
118
+
119
+
120
+ class T2I:
121
+ def __init__(self, device):
122
+ print("Initializing T2I to %s" % device)
123
+ self.device = device
124
+ self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
125
+ self.text_refine_tokenizer = AutoTokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
126
+ self.text_refine_model = AutoModelForCausalLM.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
127
+ self.text_refine_gpt2_pipe = pipeline("text-generation", model=self.text_refine_model,
128
+ tokenizer=self.text_refine_tokenizer, device=self.device)
129
+ self.pipe.to(device)
130
+
131
+ @prompts(name="Generate Image From User Input Text",
132
+ description="useful when you want to generate an image from a user input text and save it to a file. "
133
+ "like: generate an image of an object or something, or generate an image that includes some objects. "
134
+ "The input to this tool should be a string, representing the text used to generate image. ")
135
+ def inference(self, text):
136
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
137
+ refined_text = self.text_refine_gpt2_pipe(text)[0]["generated_text"]
138
+ print(f'{text} refined to {refined_text}')
139
+ image = self.pipe(refined_text).images[0]
140
+ image.save(image_filename)
141
+ print(f"Processed T2I.run, text: {text}, image_filename: {image_filename}")
142
+ return image_filename
143
+
144
+
145
+ class ImageCaptioning:
146
+ def __init__(self, device):
147
+ print("Initializing ImageCaptioning to %s" % device)
148
+ self.device = device
149
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
150
+ self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(
151
+ self.device)
152
+
153
+ @prompts(name="Remove Something From The Photo",
154
+ description="useful when you want to remove and object or something from the photo "
155
+ "from its description or location. "
156
+ "The input to this tool should be a comma separated string of two, "
157
+ "representing the image_path and the object need to be removed. ")
158
+ def inference(self, image_path):
159
+ inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device)
160
+ out = self.model.generate(**inputs)
161
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
162
+ return captions
163
+
164
+
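+ # T2A (text-to-audio) pipeline: DDIM-sample a latent mel spectrogram conditioned on the
+ # text prompt with classifier-free guidance, decode it with the first-stage model, vocode
+ # each candidate mel with BigVGAN, and keep the candidate that scores best under CLAP.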
165
+ class T2A:
166
+ def __init__(self, device):
167
+ print("Initializing Make-An-Audio to %s" % device)
168
+ self.device = device
169
+ self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/text-to-audio/txt2audio_args.yaml',
170
+ 'text_to_audio/Make_An_Audio/useful_ckpts/ta40multi_epoch=000085.ckpt',
171
+ device=device)
172
+ self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w', device=device)
173
+
174
+ def txt2audio(self, text, seed=55, scale=1.5, ddim_steps=100, n_samples=3, W=624, H=80):
175
+ SAMPLE_RATE = 16000
176
+ prng = np.random.RandomState(seed)
177
+ start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
178
+ start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
179
+ uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
180
+ c = self.sampler.model.get_learned_conditioning(n_samples * [text])
181
+ shape = [self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8] # (z_dim, 80//2^x, 848//2^x)
182
+ samples_ddim, _ = self.sampler.sample(S=ddim_steps,
183
+ conditioning=c,
184
+ batch_size=n_samples,
185
+ shape=shape,
186
+ verbose=False,
187
+ unconditional_guidance_scale=scale,
188
+ unconditional_conditioning=uc,
189
+ x_T=start_code)
190
+
191
+ x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
192
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) # [0, 1]
193
+
194
+ wav_list = []
195
+ for idx, spec in enumerate(x_samples_ddim):
196
+ wav = self.vocoder.vocode(spec)
197
+ wav_list.append((SAMPLE_RATE, wav))
198
+ best_wav = select_best_audio(text, wav_list)
199
+ return best_wav
200
+
201
+ @prompts(name="Generate Audio From User Input Text",
202
+ description="useful for when you want to generate an audio "
203
+ "from a user input text and it saved it to a file."
204
+ "The input to this tool should be a string, "
205
+ "representing the text used to generate audio.")
206
+ def inference(self, text, seed=55, scale=1.5, ddim_steps=100, n_samples=3, W=624, H=80):
207
+ melbins, mel_len = 80, 624
208
+ with torch.no_grad():
209
+ result = self.txt2audio(
210
+ text=text,
211
+ H=melbins,
212
+ W=mel_len
213
+ )
214
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
215
+ soundfile.write(audio_filename, result[1], samplerate=16000)
216
+ print(f"Processed T2I.run, text: {text}, audio_filename: {audio_filename}")
217
+ return audio_filename
218
+
219
+
220
+ class I2A:
221
+ def __init__(self, device):
222
+ print("Initializing Make-An-Audio-Image to %s" % device)
223
+ self.device = device
224
+ self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/img_to_audio/img2audio_args.yaml',
225
+ 'text_to_audio/Make_An_Audio/useful_ckpts/ta54_epoch=000216.ckpt',
226
+ device=device)
227
+ self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w', device=device)
228
+
229
+ def img2audio(self, image, seed=55, scale=3, ddim_steps=100, W=624, H=80):
230
+ SAMPLE_RATE = 16000
231
+ n_samples = 1 # only support 1 sample
232
+ prng = np.random.RandomState(seed)
233
+ start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
234
+ start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
235
+ uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
236
+ # image = Image.fromarray(image)
237
+ image = Image.open(image)
238
+ image = self.sampler.model.cond_stage_model.preprocess(image).unsqueeze(0)
239
+ image_embedding = self.sampler.model.cond_stage_model.forward_img(image)
240
+ c = image_embedding.repeat(n_samples, 1, 1)
241
+ shape = [self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8] # (z_dim, 80//2^x, 848//2^x)
242
+ samples_ddim, _ = self.sampler.sample(S=ddim_steps,
243
+ conditioning=c,
244
+ batch_size=n_samples,
245
+ shape=shape,
246
+ verbose=False,
247
+ unconditional_guidance_scale=scale,
248
+ unconditional_conditioning=uc,
249
+ x_T=start_code)
250
+
251
+ x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
252
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) # [0, 1]
253
+ wav_list = []
254
+ for idx, spec in enumerate(x_samples_ddim):
255
+ wav = self.vocoder.vocode(spec)
256
+ wav_list.append((SAMPLE_RATE, wav))
257
+ best_wav = wav_list[0]
258
+ return best_wav
259
+
260
+ @prompts(name="Generate Audio From The Image",
261
+ description="useful for when you want to generate an audio "
262
+ "based on an image. "
263
+ "The input to this tool should be a string, "
264
+ "representing the image_path. ")
265
+ def inference(self, image, seed=55, scale=3, ddim_steps=100, W=624, H=80):
266
+ melbins, mel_len = 80, 624
267
+ with torch.no_grad():
268
+ result = self.img2audio(
269
+ image=image,
270
+ H=melbins,
271
+ W=mel_len
272
+ )
273
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
274
+ soundfile.write(audio_filename, result[1], samplerate=16000)
275
+ print(f"Processed I2a.run, image_filename: {image}, audio_filename: {audio_filename}")
276
+ return audio_filename
277
+
278
+
279
+ class TTS:
280
+ def __init__(self, device=None):
281
+ self.model = TTSInference(device)
282
+
283
+ @prompts(name="Synthesize Speech Given the User Input Text",
284
+ description="useful for when you want to convert a user input text into speech audio it saved it to a file."
285
+ "The input to this tool should be a string, "
286
+ "representing the text used to be converted to speech.")
287
+ def inference(self, text):
288
+ inp = {"text": text}
289
+ out = self.model.infer_once(inp)
290
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
291
+ soundfile.write(audio_filename, out, samplerate=22050)
292
+ return audio_filename
293
+
294
+
295
+ class T2S:
296
+ def __init__(self, device=None):
297
+ if device is None:
298
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
299
+ print("Initializing DiffSinger to %s" % device)
300
+ self.device = device
301
+ self.exp_name = 'checkpoints/0831_opencpop_ds1000'
302
+ self.config = 'NeuralSeq/egs/egs_bases/svs/midi/e2e/opencpop/ds1000.yaml'
303
+ self.set_model_hparams()
304
+ self.pipe = DiffSingerE2EInfer(self.hp, device)
305
+ self.default_inp = {
306
+ 'text': '你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP',
307
+ 'notes': 'D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest',
308
+ 'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
309
+ }
310
+
311
+ def set_model_hparams(self):
312
+ set_hparams(config=self.config, exp_name=self.exp_name, print_hparams=False)
313
+ self.hp = hp
314
+
315
+ @prompts(name="Generate Singing Voice From User Input Text, Note and Duration Sequence",
316
+ description="useful for when you want to generate a piece of singing voice (Optional: from User Input Text, Note and Duration Sequence) "
317
+ "and save it to a file."
318
+ "If Like: Generate a piece of singing voice, the input to this tool should be \"\" since there is no User Input Text, Note and Duration Sequence. "
319
+ "If Like: Generate a piece of singing voice. Text: xxx, Note: xxx, Duration: xxx. "
320
+ "Or Like: Generate a piece of singing voice. Text is xxx, note is xxx, duration is xxx."
321
+ "The input to this tool should be a comma seperated string of three, "
322
+ "representing text, note and duration sequence since User Input Text, Note and Duration Sequence are all provided. ")
323
+ def inference(self, inputs):
324
+ self.set_model_hparams()
325
+ val = inputs.split(",")
326
+ key = ['text', 'notes', 'notes_duration']
327
+ try:
328
+ inp = {k: v for k, v in zip(key, val)}
329
+ wav = self.pipe.infer_once(inp)
330
+ except:
331
+ print('Error occurred. Generating the default audio sample.\n')
332
+ inp = self.default_inp
333
+ wav = self.pipe.infer_once(inp)
334
+ wav *= 32767
335
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
336
+ wavfile.write(audio_filename, self.hp['audio_sample_rate'], wav.astype(np.int16))
337
+ print(f"Processed T2S.run, audio_filename: {audio_filename}")
338
+ return audio_filename
339
+
340
+
341
+ class TTS_OOD:
342
+ def __init__(self, device):
343
+ if device is None:
344
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
345
+ print("Initializing GenerSpeech to %s" % device)
346
+ self.device = device
347
+ self.exp_name = 'checkpoints/GenerSpeech'
348
+ self.config = 'NeuralSeq/modules/GenerSpeech/config/generspeech.yaml'
349
+ self.set_model_hparams()
350
+ self.pipe = GenerSpeechInfer(self.hp, device)
351
+
352
+ def set_model_hparams(self):
353
+ set_hparams(config=self.config, exp_name=self.exp_name, print_hparams=False)
354
+ f0_stats_fn = f'{hp["binary_data_dir"]}/train_f0s_mean_std.npy'
355
+ if os.path.exists(f0_stats_fn):
356
+ hp['f0_mean'], hp['f0_std'] = np.load(f0_stats_fn)
357
+ hp['f0_mean'] = float(hp['f0_mean'])
358
+ hp['f0_std'] = float(hp['f0_std'])
359
+ hp['emotion_encoder_path'] = 'checkpoints/Emotion_encoder.pt'
360
+ self.hp = hp
361
+
362
+ @prompts(name="Style Transfer",
363
+ description="useful for when you want to generate speech samples with styles "
364
+ "(e.g., timbre, emotion, and prosody) derived from a reference custom voice. "
365
+ "Like: Generate a speech with style transferred from this voice. The text is xxx., or speak using the voice of this audio. The text is xxx."
366
+ "The input to this tool should be a comma seperated string of two, "
367
+ "representing reference audio path and input text. ")
368
+ def inference(self, inputs):
369
+ self.set_model_hparams()
370
+ key = ['ref_audio', 'text']
371
+ val = inputs.split(",")
372
+ inp = {k: v for k, v in zip(key, val)}
373
+ wav = self.pipe.infer_once(inp)
374
+ wav *= 32767
375
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
376
+ wavfile.write(audio_filename, self.hp['audio_sample_rate'], wav.astype(np.int16))
377
+ print(
378
+ f"Processed GenerSpeech.run. Input text:{val[1]}. Input reference audio: {val[0]}. Output Audio_filename: {audio_filename}")
379
+ return audio_filename
380
+
381
+
382
+ class Inpaint:
383
+ def __init__(self, device):
384
+ print("Initializing Make-An-Audio-inpaint to %s" % device)
385
+ self.device = device
386
+ self.sampler = initialize_model_inpaint('text_to_audio/Make_An_Audio/configs/inpaint/txt2audio_args.yaml',
387
+ 'text_to_audio/Make_An_Audio/useful_ckpts/inpaint7_epoch00047.ckpt')
388
+ self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w', device=device)
389
+ self.cmap_transform = matplotlib.cm.viridis
390
+
391
+ def make_batch_sd(self, mel, mask, num_samples=1):
392
+
393
+ mel = torch.from_numpy(mel)[None, None, ...].to(dtype=torch.float32)
394
+ mask = torch.from_numpy(mask)[None, None, ...].to(dtype=torch.float32)
395
+ masked_mel = (1 - mask) * mel
396
+
397
+ mel = mel * 2 - 1
398
+ mask = mask * 2 - 1
399
+ masked_mel = masked_mel * 2 - 1
400
+
401
+ batch = {
402
+ "mel": repeat(mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
403
+ "mask": repeat(mask.to(device=self.device), "1 ... -> n ...", n=num_samples),
404
+ "masked_mel": repeat(masked_mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
405
+ }
406
+ return batch
407
+
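+ # Converts an input wav to a fixed-size mel spectrogram: resample to 16 kHz, mono-ize,
+ # then pad or crop to mel_len * hop_size = 848 * 256 samples (~13.6 s) before applying
+ # TRANSFORMS_16000.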
408
+ def gen_mel(self, input_audio_path):
409
+ SAMPLE_RATE = 16000
410
+ sr, ori_wav = wavfile.read(input_audio_path)
411
+ print("gen_mel")
412
+ print(sr, ori_wav.shape, ori_wav)
413
+ ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
414
+ if len(ori_wav.shape) == 2: # stereo
415
+ ori_wav = librosa.to_mono(
416
+ ori_wav.T) # gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
417
+ print(sr, ori_wav.shape, ori_wav)
418
+ ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)
419
+
420
+ mel_len, hop_size = 848, 256
421
+ input_len = mel_len * hop_size
422
+ if len(ori_wav) < input_len:
423
+ input_wav = np.pad(ori_wav, (0, mel_len * hop_size), constant_values=0)
424
+ else:
425
+ input_wav = ori_wav[:input_len]
426
+
427
+ mel = TRANSFORMS_16000(input_wav)
428
+ return mel
429
+
430
+ def gen_mel_audio(self, input_audio):
431
+ SAMPLE_RATE = 16000
432
+ sr, ori_wav = input_audio
433
+ print("gen_mel_audio")
434
+ print(sr, ori_wav.shape, ori_wav)
435
+
436
+ ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
437
+ if len(ori_wav.shape) == 2: # stereo
438
+ ori_wav = librosa.to_mono(
439
+ ori_wav.T) # gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
440
+ print(sr, ori_wav.shape, ori_wav)
441
+ ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)
442
+
443
+ mel_len, hop_size = 848, 256
444
+ input_len = mel_len * hop_size
445
+ if len(ori_wav) < input_len:
446
+ input_wav = np.pad(ori_wav, (0, mel_len * hop_size), constant_values=0)
447
+ else:
448
+ input_wav = ori_wav[:input_len]
449
+ mel = TRANSFORMS_16000(input_wav)
450
+ return mel
451
+
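+ # Mel inpainting: encode the masked mel to a latent, concatenate a downsampled copy of
+ # the mask as an extra conditioning channel, DDIM-sample, then blend the prediction back
+ # into the original mel only inside the masked region before vocoding.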
452
+ def inpaint(self, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
453
+ model = self.sampler.model
454
+
455
+ prng = np.random.RandomState(seed)
456
+ start_code = prng.randn(num_samples, model.first_stage_model.embed_dim, H // 8, W // 8)
457
+ start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
458
+
459
+ c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
460
+ cc = torch.nn.functional.interpolate(batch["mask"],
461
+ size=c.shape[-2:])
462
+ c = torch.cat((c, cc), dim=1) # (b,c+1,h,w) 1 is mask
463
+
464
+ shape = (c.shape[1] - 1,) + c.shape[2:]
465
+ samples_ddim, _ = self.sampler.sample(S=ddim_steps,
466
+ conditioning=c,
467
+ batch_size=c.shape[0],
468
+ shape=shape,
469
+ verbose=False)
470
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
471
+
472
+ mask = batch["mask"] # [-1,1]
473
+ mel = torch.clamp((batch["mel"] + 1.0) / 2.0, min=0.0, max=1.0)
474
+ mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
475
+ predicted_mel = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
476
+ inpainted = (1 - mask) * mel + mask * predicted_mel
477
+ inpainted = inpainted.cpu().numpy().squeeze()
478
+ inpaint_wav = self.vocoder.vocode(inpainted)
479
+
480
+ return inpainted, inpaint_wav
481
+
482
+ def predict(self, input_audio, mel_and_mask, seed=55, ddim_steps=100):
483
+ SAMPLE_RATE = 16000
484
+ torch.set_grad_enabled(False)
485
+ mel_img = Image.open(mel_and_mask['image'])
486
+ mask_img = Image.open(mel_and_mask["mask"])
487
+ show_mel = np.array(mel_img.convert("L")) / 255
488
+ mask = np.array(mask_img.convert("L")) / 255
489
+ mel_bins, mel_len = 80, 848
490
+ input_mel = self.gen_mel_audio(input_audio)[:, :mel_len]
491
+ mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])), mode='constant', constant_values=0)
492
+ print(mask.shape, input_mel.shape)
493
+ with torch.no_grad():
494
+ batch = self.make_batch_sd(input_mel, mask, num_samples=1)
495
+ inpainted, gen_wav = self.inpaint(
496
+ batch=batch,
497
+ seed=seed,
498
+ ddim_steps=ddim_steps,
499
+ num_samples=1,
500
+ H=mel_bins, W=mel_len
501
+ )
502
+ inpainted = inpainted[:, :show_mel.shape[1]]
503
+ color_mel = self.cmap_transform(inpainted)
504
+ input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
505
+ gen_wav = (gen_wav * 32768).astype(np.int16)[:input_len]
506
+ image = Image.fromarray((color_mel * 255).astype(np.uint8))
507
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
508
+ image.save(image_filename)
509
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
510
+ soundfile.write(audio_filename, gen_wav, samplerate=16000)
511
+ return image_filename, audio_filename
512
+
513
+ @prompts(name="Audio Inpainting",
514
+ description="useful for when you want to inpaint a mel spectrum of an audio and predict this audio, "
515
+ "this tool will generate a mel spectrum and you can inpaint it, receives audio_path as input. "
516
+ "The input to this tool should be a string, "
517
+ "representing the audio_path. ")
518
+ def inference(self, input_audio_path):
519
+ crop_len = 500
520
+ crop_mel = self.gen_mel(input_audio_path)[:, :crop_len]
521
+ color_mel = self.cmap_transform(crop_mel)
522
+ image = Image.fromarray((color_mel * 255).astype(np.uint8))
523
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
524
+ image.save(image_filename)
525
+ return image_filename
526
+
527
+
528
+ class ASR:
529
+ def __init__(self, device):
530
+ print("Initializing Whisper to %s" % device)
531
+ self.device = device
532
+ self.model = whisper.load_model("base", device=device)
533
+
534
+ @prompts(name="Transcribe speech",
535
+ description="useful for when you want to know the text corresponding to a human speech, "
536
+ "receives audio_path as input. "
537
+ "The input to this tool should be a string, "
538
+ "representing the audio_path. ")
539
+ def inference(self, audio_path):
540
+ audio = whisper.load_audio(audio_path)
541
+ audio = whisper.pad_or_trim(audio)
542
+ mel = whisper.log_mel_spectrogram(audio).to(self.device)
543
+ _, probs = self.model.detect_language(mel)
544
+ options = whisper.DecodingOptions()
545
+ result = whisper.decode(self.model, mel, options)
546
+ return result.text
547
+
548
+ def translate_english(self, audio_path):
549
+ audio = self.model.transcribe(audio_path, language='English')
550
+ return audio['text']
551
+
552
+
553
+ class A2T:
554
+ def __init__(self, device):
555
+ print("Initializing Audio-To-Text Model to %s" % device)
556
+ self.device = device
557
+ self.model = AudioCapModel("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm")
558
+
559
+ @prompts(name="Generate Text From The Audio",
560
+ description="useful for when you want to describe an audio in text, "
561
+ "receives audio_path as input. "
562
+ "The input to this tool should be a string, "
563
+ "representing the audio_path. ")
564
+ def inference(self, audio_path):
565
+ audio = whisper.load_audio(audio_path)
566
+ caption_text = self.model(audio)
567
+ return caption_text[0]
568
+
569
+
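+ # Sound event detection with a PVT tagging model; framewise predictions arrive at
+ # sample_rate // hop_size = 32000 // 320 = 100 frames per second, and the rendered
+ # plot converts frames to seconds using this rate.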
570
+ class SoundDetection:
571
+ def __init__(self, device):
572
+ self.device = device
573
+ self.sample_rate = 32000
574
+ self.window_size = 1024
575
+ self.hop_size = 320
576
+ self.mel_bins = 64
577
+ self.fmin = 50
578
+ self.fmax = 14000
579
+ self.model_type = 'PVT'
580
+ self.checkpoint_path = 'audio_detection/audio_infer/useful_ckpts/audio_detection.pth'
581
+ self.classes_num = detection_config.classes_num
582
+ self.labels = detection_config.labels
583
+ self.frames_per_second = self.sample_rate // self.hop_size
584
+ # Model = eval(self.model_type)
585
+ self.model = PVT(sample_rate=self.sample_rate, window_size=self.window_size,
586
+ hop_size=self.hop_size, mel_bins=self.mel_bins, fmin=self.fmin, fmax=self.fmax,
587
+ classes_num=self.classes_num)
588
+ checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
589
+ self.model.load_state_dict(checkpoint['model'])
590
+ self.model.to(device)
591
+
592
+ @prompts(name="Detect The Sound Event From The Audio",
593
+ description="useful for when you want to know what event in the audio and the sound event start or end time, it will return an image "
594
+ "receives audio_path as input. "
595
+ "The input to this tool should be a string, "
596
+ "representing the audio_path. ")
597
+ def inference(self, audio_path):
598
+ # Forward
599
+ (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True)
600
+ waveform = waveform[None, :] # (1, audio_length)
601
+ waveform = torch.from_numpy(waveform)
602
+ waveform = waveform.to(self.device)
603
+ # Forward
604
+ with torch.no_grad():
605
+ self.model.eval()
606
+ batch_output_dict = self.model(waveform, None)
607
+ framewise_output = batch_output_dict['framewise_output'].data.cpu().numpy()[0]
608
+ """(time_steps, classes_num)"""
609
+ # print('Sound event detection result (time_steps x classes_num): {}'.format(
610
+ # framewise_output.shape))
611
+ import numpy as np
612
+ import matplotlib.pyplot as plt
613
+ sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]
614
+ top_k = 10 # Show top results
615
+ top_result_mat = framewise_output[:, sorted_indexes[0: top_k]]
616
+ """(time_steps, top_k)"""
617
+ # Plot result
618
+ stft = librosa.core.stft(y=waveform[0].data.cpu().numpy(), n_fft=self.window_size,
619
+ hop_length=self.hop_size, window='hann', center=True)
620
+ frames_num = stft.shape[-1]
621
+ fig, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 4))
622
+ axs[0].matshow(np.log(np.abs(stft)), origin='lower', aspect='auto', cmap='jet')
623
+ axs[0].set_ylabel('Frequency bins')
624
+ axs[0].set_title('Log spectrogram')
625
+ axs[1].matshow(top_result_mat.T, origin='upper', aspect='auto', cmap='jet', vmin=0, vmax=1)
626
+ axs[1].xaxis.set_ticks(np.arange(0, frames_num, self.frames_per_second))
627
+ axs[1].xaxis.set_ticklabels(np.arange(0, frames_num / self.frames_per_second))
628
+ axs[1].yaxis.set_ticks(np.arange(0, top_k))
629
+ axs[1].yaxis.set_ticklabels(np.array(self.labels)[sorted_indexes[0: top_k]])
630
+ axs[1].yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3)
631
+ axs[1].set_xlabel('Seconds')
632
+ axs[1].xaxis.set_ticks_position('bottom')
633
+ plt.tight_layout()
634
+ image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
635
+ plt.savefig(image_filename)
636
+ return image_filename
637
+
638
+
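+ # Language-queried source separation: LASSNet predicts a magnitude mask for the mixture
+ # spectrogram conditioned on the text query; the mixture phase is reused when inverting
+ # the masked magnitude back to a waveform.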
639
+ class SoundExtraction:
640
+ def __init__(self, device):
641
+ self.device = device
642
+ self.model_file = 'sound_extraction/useful_ckpts/LASSNet.pt'
643
+ self.stft = STFT()
644
+ import torch.nn as nn
645
+ self.model = nn.DataParallel(LASSNet(device)).to(device)
646
+ checkpoint = torch.load(self.model_file)
647
+ self.model.load_state_dict(checkpoint['model'])
648
+ self.model.eval()
649
+
650
+ @prompts(name="Extract Sound Event From Mixture Audio Based On Language Description",
651
+ description="useful for when you extract target sound from a mixture audio, you can describe the target sound by text, "
652
+ "receives audio_path and text as input. "
653
+ "The input to this tool should be a comma seperated string of two, "
654
+ "representing mixture audio path and input text.")
655
+ def inference(self, inputs):
656
+ # key = ['ref_audio', 'text']
657
+ val = inputs.split(",")
658
+ audio_path = val[0] # audio_path, text
659
+ text = val[1]
660
+ waveform = load_wav(audio_path)
661
+ waveform = torch.tensor(waveform).transpose(1, 0)
662
+ mixed_mag, mixed_phase = self.stft.transform(waveform)
663
+ text_query = ['[CLS] ' + text]
664
+ mixed_mag = mixed_mag.transpose(2, 1).unsqueeze(0).to(self.device)
665
+ est_mask = self.model(mixed_mag, text_query)
666
+ est_mag = est_mask * mixed_mag
667
+ est_mag = est_mag.squeeze(1)
668
+ est_mag = est_mag.permute(0, 2, 1)
669
+ est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
670
+ est_wav = est_wav.squeeze(0).squeeze(0).numpy()
671
+ # est_path = f'output/est{i}.wav'
672
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
673
+ print('audio_filename ', audio_filename)
674
+ save_wav(est_wav, audio_filename)
675
+ return audio_filename
676
+
677
+
678
+ class Binaural:
679
+ def __init__(self, device):
680
+ self.device = device
681
+ self.model_file = 'mono2binaural/useful_ckpts/m2b/binaural_network.net'
682
+ self.position_file = ['mono2binaural/useful_ckpts/m2b/tx_positions.txt',
683
+ 'mono2binaural/useful_ckpts/m2b/tx_positions2.txt',
684
+ 'mono2binaural/useful_ckpts/m2b/tx_positions3.txt',
685
+ 'mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
686
+ 'mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
687
+ self.net = BinauralNetwork(view_dim=7,
688
+ warpnet_layers=4,
689
+ warpnet_channels=64,
690
+ )
691
+ self.net.load_from_file(self.model_file)
692
+ self.sr = 48000
693
+
694
+ @prompts(name="Sythesize Binaural Audio From A Mono Audio Input",
695
+ description="useful for when you want to transfer your mono audio into binaural audio, "
696
+ "receives audio_path as input. "
697
+ "The input to this tool should be a string, "
698
+ "representing the audio_path. ")
699
+ def inference(self, audio_path):
700
+ mono, sr = librosa.load(path=audio_path, sr=self.sr, mono=True)
701
+ mono = torch.from_numpy(mono)
702
+ mono = mono.unsqueeze(0)
703
+ import numpy as np
704
+ import random
705
+ rand_int = random.randint(0, 4)
706
+ view = np.loadtxt(self.position_file[rand_int]).transpose().astype(np.float32)
707
+ view = torch.from_numpy(view)
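+ # The position files appear to store one listener/source pose per 400 audio samples
+ # (48000 / 400 = 120 poses per second), so the mono signal is trimmed to a multiple of
+ # 400 and the pose track is cropped to match.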
708
+ if not view.shape[-1] * 400 == mono.shape[-1]:
709
+ mono = mono[:, :(mono.shape[-1] // 400) * 400] #
710
+ if view.shape[1] * 400 > mono.shape[1]:
711
+ m_a = view.shape[1] - mono.shape[-1] // 400
712
+ rand_st = random.randint(0, m_a)
713
+ view = view[:, m_a:m_a + (mono.shape[-1] // 400)] #
714
+ # binauralize and save output
715
+ self.net.eval().to(self.device)
716
+ mono, view = mono.to(self.device), view.to(self.device)
717
+ chunk_size = 48000 # forward in chunks of 1s
718
+ rec_field = 1000 # add 1000 samples as "safe bet" since warping has undefined rec. field
719
+ rec_field -= rec_field % 400 # make sure rec_field is a multiple of 400 to match audio and view frequencies
720
+ chunks = [
721
+ {
722
+ "mono": mono[:, max(0, i - rec_field):i + chunk_size],
723
+ "view": view[:, max(0, i - rec_field) // 400:(i + chunk_size) // 400]
724
+ }
725
+ for i in range(0, mono.shape[-1], chunk_size)
726
+ ]
727
+ for i, chunk in enumerate(chunks):
728
+ with torch.no_grad():
729
+ mono = chunk["mono"].unsqueeze(0)
730
+ view = chunk["view"].unsqueeze(0)
731
+ binaural = self.net(mono, view).squeeze(0)
732
+ if i > 0:
733
+ binaural = binaural[:, -(mono.shape[-1] - rec_field):]
734
+ chunk["binaural"] = binaural
735
+ binaural = torch.cat([chunk["binaural"] for chunk in chunks], dim=-1)
736
+ binaural = torch.clamp(binaural, min=-1, max=1).cpu()
737
+ # binaural = chunked_forwarding(net, mono, view)
738
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
739
+ import torchaudio
740
+ torchaudio.save(audio_filename, binaural, sr)
741
+ # soundfile.write(audio_filename, binaural, samplerate = 48000)
742
+ print(f"Processed Binaural.run, audio_filename: {audio_filename}")
743
+ return audio_filename
744
+
745
+
746
+ class TargetSoundDetection:
747
+ def __init__(self, device):
748
+ self.device = device
749
+ self.MEL_ARGS = {
750
+ 'n_mels': 64,
751
+ 'n_fft': 2048,
752
+ 'hop_length': int(22050 * 20 / 1000),
753
+ 'win_length': int(22050 * 40 / 1000)
754
+ }
755
+ self.EPS = np.spacing(1)
756
+ self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
757
+ self.event_labels = event_labels
758
+ self.id_to_event = {i: label for i, label in enumerate(self.event_labels)}
759
+ config = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth',
760
+ map_location='cpu')
761
+ config_parameters = dict(config)
762
+ config_parameters['tao'] = 0.6
763
+ if 'thres' not in config_parameters.keys():
764
+ config_parameters['thres'] = 0.5
765
+ if 'time_resolution' not in config_parameters.keys():
766
+ config_parameters['time_resolution'] = 125
767
+ model_parameters = torch.load(
768
+ 'audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt'
769
+ , map_location=lambda storage, loc: storage) # load parameter
770
+ self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
771
+ inputdim=64, outputdim=2,
772
+ time_resolution=config_parameters[
773
+ 'time_resolution'],
774
+ **config_parameters['model_args'])
775
+ self.model.load_state_dict(model_parameters)
776
+ self.model = self.model.to(self.device).eval()
777
+ self.re_embeds = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
778
+ self.ref_mel = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/ref_mel.pth')
779
+
780
+ def extract_feature(self, fname):
781
+ import soundfile as sf
782
+ y, sr = sf.read(fname, dtype='float32')
783
+ print('y ', y.shape)
784
+ ti = y.shape[0] / sr
785
+ if y.ndim > 1:
786
+ y = y.mean(1)
787
+ y = librosa.resample(y, orig_sr=sr, target_sr=22050)
788
+ lms_feature = np.log(librosa.feature.melspectrogram(y=y, **self.MEL_ARGS) + self.EPS).T
789
+ return lms_feature, ti
790
+
791
+ def build_clip(self, text):
792
+ text = clip.tokenize(text).to(self.device) # ["a diagram with dog", "a dog", "a cat"]
793
+ text_features = self.clip_model.encode_text(text)
794
+ return text_features
795
+
796
+ def cal_similarity(self, target, retrievals):
797
+ ans = []
798
+ for name in retrievals.keys():
799
+ tmp = retrievals[name]
800
+ s = torch.cosine_similarity(target.squeeze(), tmp.squeeze(), dim=0)
801
+ ans.append(s.item())
802
+ return ans.index(max(ans))
803
+
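+ # Inference flow: CLIP-embed the text query, pick the closest precomputed event embedding,
+ # look up that event's reference mel, run the detector on (audio features, reference),
+ # then median-filter the framewise probabilities and decode them into onset/offset times.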
804
+ @prompts(name="Target Sound Detection",
805
+ description="useful for when you want to know when the target sound event in the audio happens. You can use language descriptions to instruct the model, "
806
+ "receives text description and audio_path as input. "
807
+ "The input to this tool should be a comma seperated string of two, "
808
+ "representing audio path and the text description. ")
809
+ def inference(self, inputs):
810
+ audio_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
811
+ target_emb = self.build_clip(text) # torch type
812
+ idx = self.cal_similarity(target_emb, self.re_embeds)
813
+ target_event = self.id_to_event[idx]
814
+ embedding = self.ref_mel[target_event]
815
+ embedding = torch.from_numpy(embedding)
816
+ embedding = embedding.unsqueeze(0).to(self.device).float()
817
+ inputs, ti = self.extract_feature(audio_path)
818
+ inputs = torch.from_numpy(inputs)
819
+ inputs = inputs.unsqueeze(0).to(self.device).float()
820
+ decision, decision_up, logit = self.model(inputs, embedding)
821
+ pred = decision_up.detach().cpu().numpy()
822
+ pred = pred[:, :, 0]
823
+ frame_num = decision_up.shape[1]
824
+ time_ratio = ti / frame_num
825
+ filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
826
+ time_predictions = []
827
+ for index_k in range(filtered_pred.shape[0]):
828
+ decoded_pred = []
829
+ decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k, :])
830
+ if len(decoded_pred_) == 0: # neg deal
831
+ decoded_pred_.append((target_event, 0, 0))
832
+ decoded_pred.append(decoded_pred_)
833
+ for num_batch in range(len(decoded_pred)): # when we test our model,the batch_size is 1
834
+ cur_pred = pred[num_batch]
835
+ # Save each frame output, for later visualization
836
+ label_prediction = decoded_pred[num_batch] # frame predict
837
+ for event_label, onset, offset in label_prediction:
838
+ time_predictions.append({
839
+ 'onset': onset * time_ratio,
840
+ 'offset': offset * time_ratio, })
841
+ ans = ''
842
+ for i, item in enumerate(time_predictions):
843
+ ans = ans + 'segment' + str(i + 1) + ' start_time: ' + str(item['onset']) + ' end_time: ' + str(
844
+ item['offset']) + '\t'
845
+ return ans
846
+
847
+
848
+ class Speech_Enh_SC:
849
+ """Speech Enhancement or Separation in single-channel
850
+ Example usage:
851
+ enh_model = Speech_Enh_SC("cuda")
852
+ enh_wav = enh_model.inference("./test_chime4_audio_M05_440C0213_PED_REAL.wav")
853
+ """
854
+
855
+ def __init__(self, device="cuda", model_name="espnet/Wangyou_Zhang_chime4_enh_train_enh_conv_tasnet_raw"):
856
+ self.model_name = model_name
857
+ self.device = device
858
+ print("Initializing ESPnet Enh to %s" % device)
859
+ self._initialize_model()
860
+
861
+ def _initialize_model(self):
862
+ from espnet_model_zoo.downloader import ModelDownloader
863
+ from espnet2.bin.enh_inference import SeparateSpeech
864
+
865
+ d = ModelDownloader()
866
+
867
+ cfg = d.download_and_unpack(self.model_name)
868
+ self.separate_speech = SeparateSpeech(
869
+ train_config=cfg["train_config"],
870
+ model_file=cfg["model_file"],
871
+ # for segment-wise process on long speech
872
+ segment_size=2.4,
873
+ hop_size=0.8,
874
+ normalize_segment_scale=False,
875
+ show_progressbar=True,
876
+ ref_channel=None,
877
+ normalize_output_wav=True,
878
+ device=self.device,
879
+ )
880
+
881
+ @prompts(name="Speech Enhancement In Single-Channel",
882
+ description="useful for when you want to enhance the quality of the speech signal by reducing background noise (single-channel), "
883
+ "receives audio_path as input."
884
+ "The input to this tool should be a string, "
885
+ "representing the audio_path. ")
886
+ def inference(self, speech_path, ref_channel=0):
887
+ speech, sr = soundfile.read(speech_path)
888
+ speech = speech[:, ref_channel] if speech.ndim > 1 else speech  # mono input has no channel axis
889
+ enh_speech = self.separate_speech(speech[None, ...], fs=sr)
890
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
891
+ soundfile.write(audio_filename, enh_speech[0].squeeze(), samplerate=sr)
892
+ return audio_filename
893
+
894
+
895
+ class Speech_SS:
896
+ def __init__(self, device="cuda", model_name="lichenda/wsj0_2mix_skim_noncausal"):
897
+ self.model_name = model_name
898
+ self.device = device
899
+ print("Initializing ESPnet SS to %s" % device)
900
+ self._initialize_model()
901
+
902
+ def _initialize_model(self):
903
+ from espnet_model_zoo.downloader import ModelDownloader
904
+ from espnet2.bin.enh_inference import SeparateSpeech
905
+
906
+ d = ModelDownloader()
907
+
908
+ cfg = d.download_and_unpack(self.model_name)
909
+ self.separate_speech = SeparateSpeech(
910
+ train_config=cfg["train_config"],
911
+ model_file=cfg["model_file"],
912
+ # for segment-wise process on long speech
913
+ segment_size=2.4,
914
+ hop_size=0.8,
915
+ normalize_segment_scale=False,
916
+ show_progressbar=True,
917
+ ref_channel=None,
918
+ normalize_output_wav=True,
919
+ device=self.device,
920
+ )
921
+
922
+ @prompts(name="Speech Separation",
923
+ description="useful for when you want to separate each speech from the speech mixture, "
924
+ "receives audio_path as input."
925
+ "The input to this tool should be a string, "
926
+ "representing the audio_path. ")
927
+ def inference(self, speech_path):
928
+ speech, sr = soundfile.read(speech_path)
929
+ enh_speech = self.separate_speech(speech[None, ...], fs=sr)
930
+ audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
931
+ if len(enh_speech) == 1:
932
+ soundfile.write(audio_filename, enh_speech[0].squeeze(), samplerate=sr)
933
+ else:
934
+ audio_filename_1 = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
935
+ soundfile.write(audio_filename_1, enh_speech[0].squeeze(), samplerate=sr)
936
+ audio_filename_2 = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
937
+ soundfile.write(audio_filename_2, enh_speech[1].squeeze(), samplerate=sr)
938
+ audio_filename = merge_audio(audio_filename_1, audio_filename_2)
939
+ return audio_filename
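+ 
+ 
+ # Minimal usage sketch (illustrative only, not part of the original pipeline): assumes the
+ # checkpoints and config paths referenced above have been downloaded and that the
+ # 'audio/' and 'image/' output directories exist.
+ #
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+ # t2a = T2A(device)
+ # wav_path = t2a.inference("a dog barking in the rain")
+ # a2t = A2T(device)
+ # print(a2t.inference(wav_path))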