from torch.nn.functional import pad


def pad_cut_batch_audio(wavs, new_shape):
    """Truncate or right-pad ``wavs`` along its last dimension to a target length.

    Args:
        wavs: Tensor whose last dimension is the one to resize.
            Callers presumably pass batched audio such as ``(batch, channels, time)``
            (TODO confirm). Any rank is accepted.
        new_shape: Sequence (e.g. a ``torch.Size`` or tuple) whose last
            element is the desired length of the last dimension.

    Returns:
        The tensor with last dimension equal to ``new_shape[-1]``:
        - a slice of ``wavs`` if it was longer;
        - a zero-padded copy (padding appended at the end) if it was shorter;
        - ``wavs`` itself, unchanged, if the lengths already match.
    """
    wav_length = wavs.shape[-1]
    new_length = new_shape[-1]
    if wav_length > new_length:
        # Ellipsis slicing trims the last dimension regardless of rank,
        # matching `pad`, which also operates on the last dimension only.
        wavs = wavs[..., :new_length]
    elif wav_length < new_length:
        # (left, right) padding for the last dim; default mode is constant 0.
        wavs = pad(wavs, (0, new_length - wav_length))
    return wavs