# VibeVoice-Colab / modular / streamer.py
# (Hosting-page residue kept for provenance: "NeuralFalcon's picture",
# "Upload 34 files", "909e36b verified".)
from __future__ import annotations

import asyncio
import time
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional

import torch
from transformers.generation import BaseStreamer
class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in queues for each sample in the batch.
    This allows streaming audio generation for multiple samples simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size for generation. One queue is created per sample.
        stop_signal (`any`, *optional*):
            The signal to put in the queue when generation ends. Defaults to None.
        timeout (`float`, *optional*):
            The timeout for the audio queue. If `None`, the queue will block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout

        # Create a queue for each sample in the batch
        self.audio_queues = [Queue() for _ in range(batch_size)]
        # finished_flags[i] becomes True once the stop signal was queued for sample i
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}  # Maps from sample index to queue index

    @staticmethod
    def _to_index(sample_idx):
        """Unwrap a tensor element to a plain Python value; pass other values through."""
        return sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Chunks addressed to an index >= `batch_size`, or to a sample that has
        already been ended, are silently dropped.

        Args:
            audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
            sample_indices: Tensor (or iterable of ints) indicating which samples
                these chunks belong to; entry `i` addresses `audio_chunks[i]`
        """
        for i, sample_idx in enumerate(sample_indices):
            # Accept both tensor elements and plain ints, exactly like `end` does.
            idx = self._to_index(sample_idx)
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Detach from the autograd graph and move to CPU before queueing.
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        The stop signal is queued at most once per sample; samples that are
        already finished (or indices >= `batch_size`) are skipped.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = (self._to_index(s) for s in sample_indices)

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """
        Get the audio stream for a specific sample.

        Raises:
            ValueError: If `sample_idx` is not smaller than `batch_size`.
        """
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
class AudioSampleIterator:
    """Iterator for a single audio stream from the batch."""

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        """
        Block until the next chunk for this sample is available and return it.

        Waits at most `streamer.timeout` seconds (forever when it is `None`);
        the queue's own empty-queue exception propagates if the wait times out.

        Raises:
            StopIteration: When the streamer's stop signal is dequeued.
        """
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        # Identity, not equality: the producer enqueues the very same
        # `stop_signal` object, and `==` on tensor/array chunks is element-wise
        # (ambiguous truth value, or a false match on a single-element chunk).
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
class AudioBatchIterator:
    """
    Iterator that yields audio chunks for all samples in the batch.

    Each step returns a dict mapping sample index -> chunk for every sample that
    had a chunk ready. Iteration ends once every sample has delivered the
    streamer's stop signal. The streamer's `timeout` is not applied here: queues
    are polled without blocking until at least one chunk (or the end) arrives.
    """

    def __init__(self, streamer: AudioStreamer, poll_interval: float = 0.01):
        """
        Args:
            streamer: The streamer whose per-sample queues are drained.
            poll_interval: Seconds to sleep between polls when no queue had data.
        """
        self.streamer = streamer
        self.poll_interval = poll_interval
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next non-empty `{sample_idx: chunk}` dict.

        Raises:
            StopIteration: When no active samples remain.
        """
        # Loop instead of recursing: a recursive retry every 10 ms would hit
        # Python's recursion limit after roughly ten seconds without data.
        while self.active_samples:
            batch_chunks = {}
            samples_to_remove = set()

            # Try to get chunks from all active samples
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                except Empty:
                    # Queue is empty for this sample, skip it this iteration
                    continue
                # Identity check: `==` on tensor/array chunks is element-wise.
                if value is self.streamer.stop_signal:
                    samples_to_remove.add(idx)
                else:
                    batch_chunks[idx] = value

            # Remove finished samples
            self.active_samples -= samples_to_remove

            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                # No chunks were ready but samples are still active:
                # wait a bit and try again
                time.sleep(self.poll_interval)

        raise StopIteration()
class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of AudioStreamer for use in async contexts.

    Must be constructed while an event loop is running: the running loop is
    captured at construction time, and `put`/`end` hand items to it through
    `call_soon_threadsafe`.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace regular queues with async queues
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        # Raises RuntimeError when no event loop is running.
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Put audio chunks in the appropriate async queues.

        Chunks addressed to an index >= `batch_size`, or to a sample that has
        already been ended, are silently dropped.

        Args:
            audio_chunks: Tensor whose entry `i` is the chunk for `sample_indices[i]`
            sample_indices: Tensor (or iterable of ints) of target sample indices
        """
        for i, sample_idx in enumerate(sample_indices):
            # Accept both tensor elements and plain ints, exactly like `end` does.
            idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                # The enqueue itself runs on the captured loop.
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signal the end of generation for specified samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """
        Get async iterator for a specific sample's audio stream.

        Yields chunks until the stop signal is dequeued. No timeout is applied
        to the wait for the next chunk.

        Raises:
            ValueError: If `sample_idx` is not smaller than `batch_size`
                (raised when iteration starts, since this is an async generator).
        """
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")

        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity, not equality: `==` on tensor chunks is element-wise and
            # has no unambiguous truth value.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)
class AsyncAudioBatchIterator:
    """
    Async iterator for batch audio streaming.

    Each step returns a dict mapping sample index -> chunk for every sample
    whose queue produced a chunk. Iteration ends once every sample has
    delivered the streamer's stop signal.
    """

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        """
        Wait for and return the next non-empty `{sample_idx: chunk}` dict.

        `streamer.timeout` bounds each individual wait; when it expires with
        nothing ready, the wait is simply restarted.

        Raises:
            StopAsyncIteration: When no active samples remain.
        """
        # Loop instead of recursing so repeated timeouts or stop-only rounds
        # do not build up an ever deeper chain of awaits.
        while self.active_samples:
            # Create tasks for all active samples
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }

            try:
                # Wait for at least one chunk to be ready
                done, _pending = await asyncio.wait(
                    tasks.values(),
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=self.streamer.timeout,
                )
            finally:
                # Cancel every getter that has not finished. This also runs when
                # this coroutine itself is cancelled mid-wait; otherwise the
                # orphaned getters would later dequeue chunks nobody reads.
                for task in tasks.values():
                    if not task.done():
                        task.cancel()

            batch_chunks = {}
            samples_to_remove = set()

            # Process completed tasks
            for idx, task in tasks.items():
                if task not in done or task.cancelled():
                    continue
                value = task.result()
                # Identity check: `==` on tensor/array chunks is element-wise.
                if value is self.streamer.stop_signal:
                    samples_to_remove.add(idx)
                else:
                    batch_chunks[idx] = value

            self.active_samples -= samples_to_remove

            if batch_chunks:
                return batch_chunks

        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Helper to get a chunk from a specific queue."""
        return await self.streamer.audio_queues[idx].get()