|
import audiofile |
|
import numpy as np |
|
import typing as tp |
|
import torch |
|
|
|
from audiocraft.loaders import load_compression_model, load_lm_model |
|
from audiocraft.lm import LMModel |
|
from audiocraft.conditioners import ConditioningAttributes |
|
|
|
|
|
|
|
|
|
class AudioGen:
    """Text-to-audio wrapper pairing a token language model with a decoder.

    The language model (``lm``) generates discrete tokens from text
    descriptions, and the compression model decodes those tokens to audio.

    Args:
        compression_model: Object exposing ``eval()``, ``frame_rate``,
            ``sample_rate`` and ``decode(tokens, scale)``.
        lm: Object exposing ``eval()``, ``parameters()`` and
            ``generate(conditions=..., max_gen_len=...)``.
        duration: Requested length of the generation; it is multiplied by
            ``frame_rate`` to obtain the number of generation steps
            (presumably expressed in seconds -- TODO confirm).
        top_k: Stored on the instance as ``self.top_k``. No method of this
            class reads it or forwards it to ``lm.generate``.

    Raises:
        ValueError: If ``compression_model`` or ``lm`` is ``None``.
    """

    def __init__(self,
                 compression_model=None,
                 lm=None,
                 duration: float = .04,
                 top_k: int = 249):
        # Both models are dereferenced right below; fail early with a clear
        # message instead of an AttributeError on ``None.eval()``.
        if compression_model is None or lm is None:
            raise ValueError(
                'AudioGen requires both `compression_model` and `lm`.')

        self.compression_model = compression_model
        self.lm = lm
        self.top_k = top_k
        self.duration = duration

        # Switch both models to evaluation mode.
        self.compression_model.eval()
        self.lm.eval()

        # The device is taken from the first parameter of the language model.
        self.device = next(iter(lm.parameters())).device

    @property
    def frame_rate(self) -> float:
        """Roughly the number of autoregressive steps per second."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    def generate(self, descriptions: tp.Iterable[str]) -> torch.Tensor:
        """Generate audio conditioned on text descriptions.

        Args:
            descriptions: Iterable of text prompts; each one is wrapped in a
                ``ConditioningAttributes`` under the ``'description'`` key.

        Returns:
            Whatever ``compression_model.decode`` returns for the generated
            tokens (see :meth:`generate_audio`).
        """
        attributes = [
            ConditioningAttributes(text={'description': d})
            for d in descriptions
        ]
        tokens = self._generate_tokens(attributes)
        # Debug trace of the generated tokens, kept from the original code.
        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
        return self.generate_audio(tokens)

    def _generate_tokens(self, attributes) -> torch.Tensor:
        """Run the language model and reshape its output for decoding.

        Args:
            attributes: Conditions forwarded to ``lm.generate``.

        Returns:
            Tensor of shape ``(1, 4, N)``.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        gen_tokens = self.lm.generate(conditions=attributes,
                                      max_gen_len=total_gen_len)
        # Swap the first two axes, flatten everything behind a leading axis
        # of size 4, then prepend a batch axis of size 1.
        # NOTE(review): 4 is hard-coded -- presumably the number of
        # codebooks, with batch items folded into the last axis; confirm.
        return gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Decode generated tokens to audio with the compression model.

        Args:
            gen_tokens: 3-dimensional token tensor.

        Returns:
            The output of ``compression_model.decode(gen_tokens, None)``.

        Raises:
            ValueError: If ``gen_tokens`` is not 3-dimensional.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if gen_tokens.dim() != 3:
            raise ValueError(
                f'gen_tokens must be 3-dimensional, got {gen_tokens.dim()}')
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio
|
|
|
# Demo script: build the generator, synthesize one prompt and save it as WAV.

# Checkpoint identifier passed to both loaders.
MODEL_NAME = 'facebook/audiogen-medium'

# Prefer the first CUDA device; fall back to CPU so the script still runs on
# machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

sound_generator = AudioGen(
    compression_model=load_compression_model(MODEL_NAME, device=device),
    lm=load_lm_model(MODEL_NAME, device=device).to(torch.float),
    duration=.04,
    top_k=1)

print('\n\n\n\n___________________')

# NOTE(review): 'barging' may be a typo for 'barking'; left unchanged because
# editing the prompt changes what the model is asked to generate.
txt = 'dogs barging in the street'

# Take the first item of the decoded output, move it to host memory and keep
# its first row as a 1-D numpy signal.
x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]

# Peak-normalize in place; the epsilon avoids dividing by zero on silence.
x /= np.abs(x).max() + 1e-7

# Use the model's own sample rate rather than a hard-coded 16000 so the file
# header stays correct if the checkpoint changes.
audiofile.write('del_seane.wav', x, sound_generator.sample_rate)
|
|