from typing import TypedDict

import torch
import torchaudio


class AudioDict(TypedDict):
    """Comfy's representation of AUDIO data."""

    sample_rate: int
    waveform: torch.Tensor


AudioData = AudioDict | list[AudioDict]

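# A minimal sketch of what an AUDIO value looks like at runtime. The
# (batch, channels, samples) layout is assumed from the indexing used by the
# nodes below (shape[1] for channels, shape[-1] for samples); the one-second
# 440 Hz tone is just illustrative data, not part of the node pack.
def _example_audio(sample_rate: int = 44100, seconds: float = 1.0) -> AudioDict:
    t = torch.arange(int(sample_rate * seconds)) / sample_rate
    mono = torch.sin(2 * torch.pi * 440.0 * t)  # 440 Hz sine
    return {
        "sample_rate": sample_rate,
        "waveform": mono.reshape(1, 1, -1),  # (batch=1, channels=1, samples)
    }
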
class MtbAudio:
    """Base class for audio processing."""

    @classmethod
    def is_stereo(
        cls,
        audios: AudioData,
    ) -> bool:
        if isinstance(audios, list):
            return any(cls.is_stereo(audio) for audio in audios)
        else:
            # Channel dimension is index 1 of the (batch, channels, samples) tensor.
            return audios["waveform"].shape[1] == 2

    @staticmethod
    def resample(audio: AudioDict, common_sample_rate: int) -> AudioDict:
        if audio["sample_rate"] != common_sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=audio["sample_rate"], new_freq=common_sample_rate
            )
            return {
                "sample_rate": common_sample_rate,
                "waveform": resampler(audio["waveform"]),
            }
        else:
            return audio

    @staticmethod
    def to_stereo(audio: AudioDict) -> AudioDict:
        if audio["waveform"].shape[1] == 1:
            # Duplicate the mono channel to get a stereo pair.
            return {
                "sample_rate": audio["sample_rate"],
                "waveform": torch.cat(
                    [audio["waveform"], audio["waveform"]], dim=1
                ),
            }
        else:
            return audio

    @classmethod
    def preprocess_audios(
        cls, audios: list[AudioDict]
    ) -> tuple[list[AudioDict], bool, int]:
        max_sample_rate = max([audio["sample_rate"] for audio in audios])
        resampled_audios = [
            cls.resample(audio, max_sample_rate) for audio in audios
        ]
        is_stereo = cls.is_stereo(audios)
        # Work from the resampled clips so an all-mono batch still gets the
        # common sample rate applied.
        audios = resampled_audios
        if is_stereo:
            audios = [cls.to_stereo(audio) for audio in audios]
        return (audios, is_stereo, max_sample_rate)

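# Hedged usage sketch for the helpers above, assuming the (batch, channels,
# samples) layout noted earlier. Mixing a 22.05 kHz mono clip with a 44.1 kHz
# stereo clip should yield two 44.1 kHz stereo clips and a common rate of
# 44100. The tensors are placeholder noise, not real audio.
def _demo_preprocess() -> None:
    mono_22k: AudioDict = {
        "sample_rate": 22050,
        "waveform": torch.randn(1, 1, 22050),
    }
    stereo_44k: AudioDict = {
        "sample_rate": 44100,
        "waveform": torch.randn(1, 2, 44100),
    }
    out, is_stereo, rate = MtbAudio.preprocess_audios([mono_22k, stereo_44k])
    assert is_stereo and rate == 44100
    assert all(a["waveform"].shape[1] == 2 for a in out)
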
class MTB_AudioCut(MtbAudio):
    """Basic audio cutter, values are in ms."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
                "length": (
                    "FLOAT",
                    {
                        "default": 1000.0,
                        "min": 0.0,
                        "max": 999999.0,
                        "step": 1,
                    },
                ),
                "offset": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 999999.0, "step": 1},
                ),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("cut_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "cut"

    def cut(self, audio: AudioDict, length: float, offset: float):
        sample_rate = audio["sample_rate"]
        # Convert milliseconds to sample indices, clamped to the clip length.
        start_idx = int(offset * sample_rate / 1000)
        end_idx = min(
            start_idx + int(length * sample_rate / 1000),
            audio["waveform"].shape[-1],
        )
        cut_waveform = audio["waveform"][:, :, start_idx:end_idx]
        return (
            {
                "sample_rate": sample_rate,
                "waveform": cut_waveform,
            },
        )

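# A small worked example of the millisecond-to-sample conversion in `cut`,
# assuming a one-second 44.1 kHz clip; the numbers are purely illustrative.
#   offset=250 ms -> start_idx = int(250 * 44100 / 1000) = 11025
#   length=500 ms -> end_idx   = 11025 + int(500 * 44100 / 1000) = 33075
def _demo_cut() -> None:
    clip: AudioDict = {"sample_rate": 44100, "waveform": torch.randn(1, 1, 44100)}
    (cut,) = MTB_AudioCut().cut(clip, length=500.0, offset=250.0)
    assert cut["waveform"].shape[-1] == 22050  # 500 ms at 44.1 kHz
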
class MTB_AudioStack(MtbAudio):
    """Stack/overlay audio inputs (dynamic inputs).

    - pad audios to the longest input.
    - resample audios to the highest sample rate in the inputs.
    - convert them all to stereo if one of the inputs is.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("stacked_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "stack"

    def stack(self, **kwargs: AudioDict) -> tuple[AudioDict]:
        audios, is_stereo, max_rate = self.preprocess_audios(
            list(kwargs.values())
        )
        max_length = max([audio["waveform"].shape[-1] for audio in audios])
        # Zero-pad every clip to the longest one so they can be summed.
        padded_audios: list[torch.Tensor] = []
        for audio in audios:
            padding = torch.zeros(
                (
                    1,
                    2 if is_stereo else 1,
                    max_length - audio["waveform"].shape[-1],
                )
            )
            padded_audio = torch.cat([audio["waveform"], padding], dim=-1)
            padded_audios.append(padded_audio)
        stacked_waveform = torch.stack(padded_audios, dim=0).sum(dim=0)
        return (
            {
                "sample_rate": max_rate,
                "waveform": stacked_waveform,
            },
        )

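# Hedged sketch of stacking: two clips of different lengths are zero-padded to
# the longer one and summed sample-wise, so the mix can exceed [-1, 1] and may
# need normalization downstream. The keyword names "audio_1"/"audio_2" are
# illustrative; the node accepts dynamic inputs.
def _demo_stack() -> None:
    a: AudioDict = {"sample_rate": 44100, "waveform": torch.randn(1, 1, 44100)}
    b: AudioDict = {"sample_rate": 44100, "waveform": torch.randn(1, 1, 22050)}
    (mix,) = MTB_AudioStack().stack(audio_1=a, audio_2=b)
    assert mix["waveform"].shape[-1] == 44100  # padded to the longest input
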
class MTB_AudioSequence(MtbAudio):
    """Sequence audio inputs (dynamic inputs).

    - add silence_duration seconds of silence between segments; it can
      also be negative to overlap the clips, safely bounded to the
      input lengths.
    - resample audios to the highest sample rate in the inputs.
    - convert them all to stereo if one of the inputs is.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "silence_duration": (
                    "FLOAT",
                    {"default": 0.0, "min": -999.0, "max": 999.0, "step": 0.01},
                )
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("sequenced_audio",)
    CATEGORY = "mtb/audio"
    FUNCTION = "sequence"

    def sequence(self, silence_duration: float, **kwargs: AudioDict):
        audios, is_stereo, max_rate = self.preprocess_audios(
            list(kwargs.values())
        )
        sequence: list[torch.Tensor] = []
        for i, audio in enumerate(audios):
            if i > 0:
                if silence_duration > 0:
                    # Insert a block of silence between consecutive clips.
                    silence = torch.zeros(
                        (
                            1,
                            2 if is_stereo else 1,
                            int(silence_duration * max_rate),
                        )
                    )
                    sequence.append(silence)
                elif silence_duration < 0:
                    # Negative duration: overlap the tail of the previous clip
                    # with the head of the current one, summing the overlap.
                    overlap = int(abs(silence_duration) * max_rate)
                    previous_audio = sequence[-1]
                    overlap = min(
                        overlap,
                        previous_audio.shape[-1],
                        audio["waveform"].shape[-1],
                    )
                    if overlap > 0:
                        overlap_part = (
                            previous_audio[:, :, -overlap:]
                            + audio["waveform"][:, :, :overlap]
                        )
                        sequence[-1] = previous_audio[:, :, :-overlap]
                        sequence.append(overlap_part)
                        audio["waveform"] = audio["waveform"][:, :, overlap:]
            sequence.append(audio["waveform"])
        sequenced_waveform = torch.cat(sequence, dim=-1)
        return (
            {
                "sample_rate": max_rate,
                "waveform": sequenced_waveform,
            },
        )

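# Hedged sketch of sequencing with a negative silence_duration: -0.5 s at
# 44.1 kHz overlaps consecutive clips by int(0.5 * 44100) = 22050 samples, so
# two one-second clips come out as 2 s - 0.5 s = 1.5 s of audio. Keyword names
# are illustrative; the node accepts dynamic inputs.
def _demo_sequence() -> None:
    a: AudioDict = {"sample_rate": 44100, "waveform": torch.randn(1, 1, 44100)}
    b: AudioDict = {"sample_rate": 44100, "waveform": torch.randn(1, 1, 44100)}
    (seq,) = MTB_AudioSequence().sequence(silence_duration=-0.5, audio_1=a, audio_2=b)
    assert seq["waveform"].shape[-1] == 66150  # 1.5 s at 44.1 kHz
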
__nodes__ = [MTB_AudioSequence, MTB_AudioStack, MTB_AudioCut]