from __future__ import annotations

import asyncio
import time
from queue import Empty, Queue
from typing import Any, Optional

import torch
from transformers.generation import BaseStreamer


class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in queues for each sample in the batch.
    This allows streaming audio generation for multiple samples simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size for generation.
        stop_signal (`Any`, *optional*):
            The signal to put in the queue when generation ends. Defaults to `None`.
        timeout (`float`, *optional*):
            The timeout for the audio queue. If `None`, the queue will block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout
        # Create a queue for each sample in the batch.
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}  # Maps from sample index to queue index

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Args:
            audio_chunks: Tensor of shape `(num_samples, ...)` containing audio chunks.
            sample_indices: Tensor indicating which samples these chunks belong to.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Detach and move to CPU so consumers never hold references to GPU memory.
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If `None`, ends all.
        """
        if sample_indices is None:
            # End all samples.
            for idx in range(self.batch_size):
                if not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True
        else:
            # End specific samples.
            for sample_idx in sample_indices:
                idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
                if idx < self.batch_size and not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True

    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
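

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the streamer API): drive the streamer
# from a background generation thread and consume one sample's chunks as they
# arrive. The `audio_streamer` keyword on `model.generate` below is an
# assumption about the surrounding generation loop; adapt it to whatever code
# actually calls `put()` / `end()` on the streamer.
def _example_stream_single_sample(model, inputs: dict):
    """Minimal sketch: consume sample 0 while generation runs in a thread."""
    from threading import Thread

    streamer = AudioStreamer(batch_size=1)
    # Assumed entry point: generation is expected to call streamer.put()/end().
    thread = Thread(target=model.generate, kwargs={**inputs, "audio_streamer": streamer})
    thread.start()

    chunks = []
    # Iteration stops once the stop signal for sample 0 is received.
    for chunk in streamer.get_stream(0):
        chunks.append(chunk)
    thread.join()
    # Concatenating on the last dim assumes chunks are shaped (..., num_frames).
    return torch.cat(chunks, dim=-1) if chunks else None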
""" if sample_indices is None: # End all samples for idx in range(self.batch_size): if not self.finished_flags[idx]: self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) self.finished_flags[idx] = True else: # End specific samples for sample_idx in sample_indices: idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx if idx < self.batch_size and not self.finished_flags[idx]: self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) self.finished_flags[idx] = True def __iter__(self): """Returns an iterator over the batch of audio streams.""" return AudioBatchIterator(self) def get_stream(self, sample_idx: int): """Get the audio stream for a specific sample.""" if sample_idx >= self.batch_size: raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") return AudioSampleIterator(self, sample_idx) class AudioSampleIterator: """Iterator for a single audio stream from the batch.""" def __init__(self, streamer: AudioStreamer, sample_idx: int): self.streamer = streamer self.sample_idx = sample_idx def __iter__(self): return self def __next__(self): value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout) if value == self.streamer.stop_signal: raise StopIteration() return value class AudioBatchIterator: """Iterator that yields audio chunks for all samples in the batch.""" def __init__(self, streamer: AudioStreamer): self.streamer = streamer self.active_samples = set(range(streamer.batch_size)) def __iter__(self): return self def __next__(self): if not self.active_samples: raise StopIteration() batch_chunks = {} samples_to_remove = set() # Try to get chunks from all active samples for idx in self.active_samples: try: value = self.streamer.audio_queues[idx].get(block=False) if value == self.streamer.stop_signal: samples_to_remove.add(idx) else: batch_chunks[idx] = value except: # Queue is empty for this sample, skip it this iteration pass # Remove finished samples self.active_samples -= samples_to_remove if batch_chunks: return batch_chunks elif self.active_samples: # If no chunks were ready but we still have active samples, # wait a bit and try again import time time.sleep(0.01) return self.__next__() else: raise StopIteration() class AsyncAudioStreamer(AudioStreamer): """ Async version of AudioStreamer for use in async contexts. 
""" def __init__( self, batch_size: int, stop_signal: Optional[any] = None, timeout: Optional[float] = None, ): super().__init__(batch_size, stop_signal, timeout) # Replace regular queues with async queues self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] self.loop = asyncio.get_running_loop() def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): """Put audio chunks in the appropriate async queues.""" for i, sample_idx in enumerate(sample_indices): idx = sample_idx.item() if idx < self.batch_size and not self.finished_flags[idx]: audio_chunk = audio_chunks[i].detach().cpu() self.loop.call_soon_threadsafe( self.audio_queues[idx].put_nowait, audio_chunk ) def end(self, sample_indices: Optional[torch.Tensor] = None): """Signal the end of generation for specified samples.""" if sample_indices is None: indices_to_end = range(self.batch_size) else: indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] for idx in indices_to_end: if idx < self.batch_size and not self.finished_flags[idx]: self.loop.call_soon_threadsafe( self.audio_queues[idx].put_nowait, self.stop_signal ) self.finished_flags[idx] = True async def get_stream(self, sample_idx: int): """Get async iterator for a specific sample's audio stream.""" if sample_idx >= self.batch_size: raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") while True: value = await self.audio_queues[sample_idx].get() if value == self.stop_signal: break yield value def __aiter__(self): """Returns an async iterator over all audio streams.""" return AsyncAudioBatchIterator(self) class AsyncAudioBatchIterator: """Async iterator for batch audio streaming.""" def __init__(self, streamer: AsyncAudioStreamer): self.streamer = streamer self.active_samples = set(range(streamer.batch_size)) def __aiter__(self): return self async def __anext__(self): if not self.active_samples: raise StopAsyncIteration() batch_chunks = {} samples_to_remove = set() # Create tasks for all active samples tasks = { idx: asyncio.create_task(self._get_chunk(idx)) for idx in self.active_samples } # Wait for at least one chunk to be ready done, pending = await asyncio.wait( tasks.values(), return_when=asyncio.FIRST_COMPLETED, timeout=self.streamer.timeout ) # Cancel pending tasks for task in pending: task.cancel() # Process completed tasks for idx, task in tasks.items(): if task in done: try: value = await task if value == self.streamer.stop_signal: samples_to_remove.add(idx) else: batch_chunks[idx] = value except asyncio.CancelledError: pass self.active_samples -= samples_to_remove if batch_chunks: return batch_chunks elif self.active_samples: # Try again if we still have active samples return await self.__anext__() else: raise StopAsyncIteration() async def _get_chunk(self, idx): """Helper to get a chunk from a specific queue.""" return await self.streamer.audio_queues[idx].get()