Dionyssos's picture
instantiate audiogen in demo
d912185
raw
history blame
2.41 kB
import audiofile
import numpy as np
import typing as tp
import torch
from audiocraft.loaders import load_compression_model, load_lm_model
from audiocraft.lm import LMModel
from audiocraft.conditioners import ConditioningAttributes
class AudioGen():
    """Text-to-audio wrapper pairing a language model with a compression model.

    The language model turns text descriptions into discrete tokens and the
    compression model decodes those tokens into audio.

    Args:
        compression_model: Object providing ``eval()``, ``frame_rate``,
            ``sample_rate`` and ``decode(tokens, scale)``. Required.
        lm: Object providing ``eval()``, ``parameters()`` and
            ``generate(conditions=..., max_gen_len=...)``. Required.
        duration: Length of audio to generate, in seconds (multiplied by
            ``frame_rate`` to obtain the number of generation steps).
        top_k: Stored on the instance. NOTE(review): it is not forwarded to
            ``lm.generate`` anywhere in this class - confirm whether the
            language model is expected to read it some other way.

    Raises:
        ValueError: If ``compression_model`` or ``lm`` is missing.
    """

    def __init__(self,
                 compression_model=None,
                 lm=None,
                 duration=.04,
                 top_k=249):
        # Both models default to None only to keep the keyword-style call
        # signature; fail loudly instead of crashing on ``None.eval()``.
        if compression_model is None or lm is None:
            raise ValueError(
                'AudioGen requires both `compression_model` and `lm`.')
        self.compression_model = compression_model
        self.lm = lm
        self.top_k = top_k
        # Inference only: switch both models to evaluation mode.
        self.compression_model.eval()
        self.lm.eval()
        self.duration = duration
        # The device of the first LM parameter is taken as the model device.
        self.device = next(iter(lm.parameters())).device

    @property
    def frame_rate(self) -> float:
        """Roughly the number of autoregressive steps per second."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    def generate(self, descriptions):
        """Generate audio for a list of text descriptions.

        Args:
            descriptions: Iterable of text prompts; each one becomes a
                ``ConditioningAttributes`` with a ``'description'`` entry.

        Returns:
            Whatever ``compression_model.decode`` returns for the generated
            tokens.
        """
        attributes = [
            ConditioningAttributes(text={'description': d})
            for d in descriptions]
        tokens = self._generate_tokens(attributes)
        # Debug trace of the generated tokens (kept from the original demo).
        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
        return self.generate_audio(tokens)

    def _generate_tokens(self, attributes):
        """Run the language model and lay the tokens out for decoding.

        Args:
            attributes: Conditioning attributes passed to ``lm.generate``.

        Returns:
            A tensor of shape ``(1, 4, N)``.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        # Generation is inference only, so skip autograd bookkeeping here just
        # as ``generate_audio`` does for decoding.
        with torch.no_grad():
            gen_tokens = self.lm.generate(conditions=attributes,
                                          max_gen_len=total_gen_len)
        # Swap the first two axes, flatten everything into 4 rows, then add a
        # leading axis of size 1. NOTE(review): the 4 is presumably the number
        # of codebooks, with batch items concatenated along the last axis -
        # confirm against the output layout of ``lm.generate``.
        gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
        return gen_tokens

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Decode a 3-dimensional token tensor into audio.

        Args:
            gen_tokens: Token tensor; must have exactly three dimensions.

        Returns:
            The output of ``compression_model.decode(gen_tokens, None)``.
        """
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio
# Demo: build an AudioGen from the pretrained checkpoint, synthesize a short
# clip for one text prompt and write it to a WAV file.
device = 'cuda:0'
# https://huggingface.co/facebook/audiogen-medium
sound_generator = AudioGen(
    compression_model=load_compression_model('facebook/audiogen-medium', device=device),
    lm=load_lm_model('facebook/audiogen-medium', device=device).to(torch.float),
    duration=.04,
    top_k=1)
print('\n\n\n\n___________________')
# NOTE(review): 'barging' is probably a typo for 'barking'; left unchanged
# because editing the prompt changes the generated audio.
txt = 'dogs barging in the street'
# Take index 0 along the first axis, move to host memory as a NumPy array,
# then take index 0 along the next axis (presumably batch, then channel).
x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
# Peak-normalize in place; the epsilon avoids division by zero on silence.
x /= np.abs(x).max() + 1e-7
# Use the model's own sample rate rather than a hard-coded 16000 so the file
# header stays correct if a different checkpoint is loaded.
audiofile.write('del_seane.wav', x, sound_generator.sample_rate)