import torch
import torchaudio
from torchaudio.functional import resample

from transformers import AutoProcessor, EncodecModel


# Target bandwidths that can be passed to EncodecModel.encode; indexed by bandwidth_id.
ALL_BANDWIDTHS = [1.1]

class VoilaTokenizer:
    """Wrapper around an EncodecModel that converts a waveform to audio codes and back."""

    def __init__(
        self,
        model_path="maitrix-org/Voila-Tokenizer",
        bandwidth_id=0,
        device="cpu",
    ):
        self.device = torch.device(device)
        self.bandwidth = ALL_BANDWIDTHS[bandwidth_id]
        self.bandwidth_id = torch.tensor([bandwidth_id], device=self.device)

        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = EncodecModel.from_pretrained(model_path).to(device)

        self.sampling_rate = self.processor.sampling_rate
        self.model_version = self.model.config.model_version


    @torch.no_grad()
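    def encode(self, wav, sr):
        """Encode a mono waveform sampled at `sr` Hz into a 2-D tensor of audio codes."""
        # as_tensor accepts arrays and tensors alike and avoids the UserWarning
        # that torch.tensor() emits when `wav` is already a tensor.
        wav = torch.as_tensor(wav, dtype=torch.float32, device=self.device)
        if sr != self.processor.sampling_rate:
            wav = resample(wav, sr, self.processor.sampling_rate)
            sr = self.processor.sampling_rate
        # Normalize the input to shape (batch=1, channels=1, time).
        if wav.dim() == 1:
            wav = wav[None, None, :]
        elif wav.dim() == 2:
            assert wav.shape[0] == 1, "expected a single (mono) channel"
            wav = wav[None, :]
        elif wav.dim() == 3:
            assert wav.shape[0] == 1 and wav.shape[1] == 1, "expected batch size 1 and one channel"
        else:
            raise ValueError(f"unsupported waveform shape: {tuple(wav.shape)}")

        # inputs = self.processor(raw_audio=wav, sampling_rate=sr, return_tensors="pt")
        encoder_outputs = self.model.encode(wav, bandwidth=self.bandwidth)
        # Drop the two leading singleton axes so that a 2-D tensor is returned;
        # decode() restores them.
        return encoder_outputs.audio_codes[0, 0]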

    @torch.no_grad()
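    def decode(self, audio_codes):
        """Decode a 2-D tensor of audio codes (as returned by `encode`) into a 1-D waveform."""
        assert audio_codes.dim() == 2, "expected 2-D audio codes"
        # Restore the two leading singleton axes dropped by encode(); the [None]
        # argument supplies one audio scale per chunk, with None meaning no rescaling.
        audio_values = self.model.decode(audio_codes[None, None, :, :], [None])[0]
        return audio_values[0, 0]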

if __name__ == '__main__':
    import argparse
    import soundfile as sf

    parser = argparse.ArgumentParser()
    parser.add_argument("--wav", type=str, required=True, help="Path to the input audio file")
    args = parser.parse_args()

    # torchaudio.load returns a (channels, time) tensor; keep only the first channel.
    wav, sr = torchaudio.load(args.wav)
    if wav.dim() > 1:
        wav = wav[0]

    # Fall back to the CPU when no CUDA device is available.
    model = VoilaTokenizer(device="cuda" if torch.cuda.is_available() else "cpu")

    audio_codes = model.encode(wav, sr)
    audio_values = model.decode(audio_codes).cpu().numpy()

    # Tokens per second: number of code frames divided by the decoded audio duration in seconds.
    tps = audio_codes.shape[-1] / (audio_values.shape[-1] / model.processor.sampling_rate)
    print(audio_codes.shape, audio_values.shape, tps)
    sf.write("audio_mt.wav", audio_values, model.processor.sampling_rate)