streaming generate
from queue import Queue
from threading import Thread

import torch
import soundfile as sf
from transformers.generation.streamers import BaseStreamer

# model, tokenizer, Codec_model and extract_speech_ids are assumed to be set up
# as in the repo's inference example.

class TokenStreamer(BaseStreamer):
    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt
        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value
# TTS start!
with torch.no_grad():
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text with the chat template
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_tensors='pt',
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

    # Run generate in a background thread; tokens arrive through the streamer
    streamer = TokenStreamer(skip_prompt=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,          # Adjusts the diversity of generated content
        temperature=0.8,  # Controls randomness in output
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Decode the stream in chunks of batch_size tokens and write each chunk to a wav file
    i = 0
    batch_size = 60
    generated_ids = []
    j = 0
    for token_id in streamer:
        print(token_id, end=',', flush=True)
        generated_ids.append(token_id)
        if i > 0 and i % batch_size == 0:
            speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
            # Convert token <|s_23456|> to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to a speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
            generated_ids = []
            j += 1
        i += 1
        # yield token_id

    # Decode whatever is left after generation finishes
    if len(generated_ids) > 0:
        speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
        # Convert token <|s_23456|> to int 23456
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to a speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)
        sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
Colab notebook: https://github.com/weedge/doraemon-nb/blob/main/LLaSA.ipynb
Thank you for sharing on Colab! It’s very well-written and helpful!
Hi! Thank you for this code! I've looked over your Colab, but there is something I can't figure out about using this with a prompt wav. What should I do with the generated_ids? When a synchronous call to generate gives you an output, you would normally do:
# from repo code example
# Extract the speech tokens
generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
How does this integrate into your code, given that the chunked batch_size splits up the ids?
Also, I've managed to expose the TTS engine through a Flask endpoint using synchronous generation (from the repo's code example) with both text and voice prompts, but how would one use this streaming approach to start streaming chunks of x-wav audio with Flask? I've experimented a bit with the code and the input prompt, but I can't seem to write the files into io.BytesIO() buffers and return them in a Response via Flask's stream_with_context while the batches are being generated... Do you have an idea of how this could be achieved?
Thanks a lot
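One possible shape for the Flask part (a sketch I'm adding here, not tested against this repo): instead of writing gen_{j}.wav files, move the chunked decoding into a generator that yields each chunk as in-memory WAV bytes, and wrap that generator in a streaming Response. In the sketch below, start_generation is a hypothetical helper standing in for the input_ids/TokenStreamer/Thread setup above, and model, tokenizer, Codec_model and extract_speech_ids are the same objects used in the code above.

import io
import torch
import soundfile as sf
from flask import Flask, Response, request, stream_with_context

app = Flask(__name__)
CHUNK_TOKENS = 60  # same chunk size as the loop above

def decode_chunk(token_ids):
    # Same chunk decoding as above, but returning WAV bytes instead of writing a file
    speech_tokens = tokenizer.batch_decode(torch.tensor(token_ids).cuda(), skip_special_tokens=True)
    speech_ids = extract_speech_ids(speech_tokens)
    speech_ids = torch.tensor(speech_ids).cuda().unsqueeze(0).unsqueeze(0)
    gen_wav = Codec_model.decode_code(speech_ids)
    buf = io.BytesIO()
    sf.write(buf, gen_wav[0, 0, :].cpu().numpy(), 16000, format="WAV")
    return buf.getvalue()

@app.route("/tts")
def tts():
    text = request.args.get("text", "")
    # Hypothetical helper: builds input_ids, the TokenStreamer and the generate thread as above
    streamer, thread = start_generation(text)
    thread.start()

    def audio_stream():
        pending = []
        for token_id in streamer:
            pending.append(token_id)
            if len(pending) >= CHUNK_TOKENS:
                yield decode_chunk(pending)  # each yield is a small standalone WAV segment
                pending = []
        if pending:
            yield decode_chunk(pending)

    return Response(stream_with_context(audio_stream()), mimetype="audio/x-wav")

One caveat: each yielded segment is a complete little WAV file with its own header, so the client has to be able to play or concatenate a sequence of WAV segments (or you could strip the headers and stream raw PCM instead). That may be why writing the chunks into one continuous response didn't behave as expected in your experiments.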