streaming generate

#2
by weege007 - opened
from queue import Queue
from threading import Thread

import torch
import soundfile as sf
from transformers.generation.streamers import BaseStreamer

class TokenStreamer(BaseStreamer):
    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value

# TTS start
with torch.no_grad():

    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, 
        tokenize=True, 
        return_tensors='pt', 
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
    streamer = TokenStreamer(skip_prompt=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,           # adjusts the diversity of generated content
        temperature=0.8,   # controls randomness in the output
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    i = 0
    batch_size = 60
    generated_ids = []
    j = 0
    for token_id in streamer:
        print(token_id, end=',', flush=True)
        generated_ids.append(token_id)
        if i > 0 and i % batch_size == 0:
            speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
            # Convert tokens like <|s_23456|> to the int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to a speech waveform and write this chunk to disk
            gen_wav = Codec_model.decode_code(speech_tokens)
            sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
            generated_ids = []
            j += 1
        i += 1
        # yield token_id

    # Flush any tokens left over after generation finishes
    if len(generated_ids) > 0:
        speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
        # Convert tokens like <|s_23456|> to the int 23456
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to a speech waveform and write the final chunk
        gen_wav = Codec_model.decode_code(speech_tokens)
        sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)

Colab notebook: https://github.com/weedge/doraemon-nb/blob/main/LLaSA.ipynb

HKUST Audio org

Thank you for sharing the Colab notebook! It's very well written and helpful!

Hi! Thank you for this code! I've looked over your Colab, but there is something I can't figure out about using this with a prompt wav. What should I do with the generated_ids? When a synchronous call to generate gives you an output, you would normally do:

# from repo code example
# Extract the speech tokens
generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)   

How does this integrate into your code, since the batch size messes with the ids?
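
My rough understanding (please correct me if I'm wrong) is that with skip_prompt=True the streamer never emits the prompt, so only the newly generated ids arrive and the voice-prompt tokens have to be added back on the caller side. Something like the sketch below might mirror the synchronous slice (unverified; it assumes the speech_ids_prefix list of <|s_N|> strings, extract_speech_ids, Codec_model and speech_end_id from the repo's voice-prompt example are in scope):

# Unverified sketch. Assumes speech_ids_prefix (list of '<|s_23456|>' strings built
# from the prompt wav, as in the repo's voice-prompt example), extract_speech_ids,
# Codec_model and speech_end_id are in scope, and skip_prompt=True on the streamer.
prefix_ids = extract_speech_ids(speech_ids_prefix)      # prompt wav as int codes

new_ids = []
for token_id in streamer:
    if token_id == speech_end_id:    # plays the role of the [:-1] in the sync slice
        break
    new_ids.append(token_id)

speech_tokens = tokenizer.batch_decode(torch.tensor(new_ids), skip_special_tokens=True)
# prefix + new ids, mirroring outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
speech_ids = prefix_ids + extract_speech_ids(speech_tokens)
codes = torch.tensor(speech_ids).cuda().unsqueeze(0).unsqueeze(0)
gen_wav = Codec_model.decode_code(codes)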

Also, I've managed to expose the TTS engine through a Flask endpoint using synchronous generation (from the repo's code example) with both text and voice prompts, but how would one use this streaming approach to serve chunks of audio/x-wav from Flask as they are generated? I've experimented a bit with the code and the input prompt, but I can't seem to write the files into io.BytesIO() buffers and return them in a Response via Flask's stream_with_context while the batches are being generated... Do you have an idea of how this could be achieved?
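
For reference, instead of writing complete WAV files into io.BytesIO() per batch, the direction I've been experimenting with is to send one hand-built WAV header with a placeholder length and then yield raw 16-bit PCM for each decoded batch, roughly like the sketch below (unverified; start_generation is a hypothetical helper standing in for the input_ids / TokenStreamer / generate-thread setup above, and tokenizer, extract_speech_ids and Codec_model are assumed to be in scope):

# Unverified sketch: stream one WAV header, then raw PCM per decoded batch.
import struct
import numpy as np
import torch
from flask import Flask, Response, request, stream_with_context

app = Flask(__name__)

def wav_header(sample_rate=16000, bits=16, channels=1):
    # WAV header with a maximal data size, so players accept a stream of unknown length
    data_size = 0xFFFFFFFF - 36
    return (b'RIFF' + struct.pack('<I', data_size + 36) + b'WAVEfmt '
            + struct.pack('<IHHIIHH', 16, 1, channels, sample_rate,
                          sample_rate * channels * bits // 8,
                          channels * bits // 8, bits)
            + b'data' + struct.pack('<I', data_size))

def decode_ids_to_pcm(ids):
    # Same decoding path as the loop above, but returning int16 PCM bytes
    # (assumes the codec outputs float samples in [-1, 1] at 16 kHz).
    speech_tokens = tokenizer.batch_decode(torch.tensor(ids), skip_special_tokens=True)
    codes = torch.tensor(extract_speech_ids(speech_tokens)).cuda().unsqueeze(0).unsqueeze(0)
    gen_wav = Codec_model.decode_code(codes)
    return (gen_wav[0, 0, :].cpu().numpy() * 32767).astype(np.int16).tobytes()

def pcm_chunks(streamer, batch_size=60):
    yield wav_header()
    generated_ids = []
    for token_id in streamer:
        generated_ids.append(token_id)
        if len(generated_ids) >= batch_size:
            yield decode_ids_to_pcm(generated_ids)
            generated_ids = []
    if generated_ids:
        yield decode_ids_to_pcm(generated_ids)

@app.route('/tts')
def tts():
    text = request.args.get('text', '')
    # start_generation is hypothetical: build input_ids, create the TokenStreamer
    # and start the model.generate thread exactly as in the code above.
    streamer = start_generation(text)
    return Response(stream_with_context(pcm_chunks(streamer)), mimetype='audio/wav')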

Thanks a lot
