Spaces:
Runtime error
Runtime error
| import asyncio | |
| import traceback | |
class Job:
    """A unit of work passed between pipeline nodes.

    Attributes:
        id: Sequence number, ``None`` until ``Pipeline.enqueue_job`` assigns
            one. Sequential nodes order their processing by this value.
        data: Arbitrary payload carried by the job.
    """

    def __init__(self, data):
        # Nodes and the pipeline read/write ``job.id``; initialise it here so a
        # job that has not been enqueued yet still has the attribute.
        self.id = None
        self.data = data

    @property
    def _id(self):
        """Backward-compatible alias for ``id`` (the attribute's former name)."""
        return self.id

    @_id.setter
    def _id(self, value):
        self.id = value
class Node:
    """Base class for one worker of a pipeline stage.

    Subclasses implement :meth:`process_job` as an async generator. Every item
    it yields is put on ``output_queue`` (when set) and appended to
    ``job_sync`` (when set).

    Args:
        worker_id: Index of this worker among the node's workers.
        input_queue: Queue jobs are awaited from via ``get()``.
        output_queue: Optional queue that results are ``put()`` onto.
        job_sync: Optional list collecting every result; only allowed on
            sequential nodes.
        sequential_node: When true, jobs are processed strictly in ascending
            ``job.id`` order starting at 0; jobs arriving early are buffered.

    Raises:
        ValueError: If ``job_sync`` is given for a non-sequential node.
    """

    def __init__(
        self,
        worker_id: int,
        input_queue,
        output_queue=None,
        job_sync=None,
        sequential_node=False,
    ):
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.buffer = {}  # job.id -> job, for jobs that arrived before their turn
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        self.next_i = 0  # id of the next job a sequential node may process
        self._jobs_dequeued = 0  # inputs taken from input_queue
        self._jobs_processed = 0  # results yielded by process_job
        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')

    async def run(self):
        """Consume jobs from ``input_queue`` until the task is cancelled.

        Any ``Exception`` is printed with its traceback and re-raised so the
        owning task fails visibly. ``asyncio.CancelledError`` derives from
        ``BaseException`` on Python 3.8+ and is therefore not caught here.
        """
        try:
            while True:
                job: "Job" = await self.input_queue.get()
                self._jobs_dequeued += 1
                if not self.sequential_node:
                    await self._process_and_emit(job)
                    continue
                # Park the job, then drain every consecutive id starting at
                # next_i so results leave in id order even if inputs arrive
                # shuffled.
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    await self._process_and_emit(self.buffer.pop(self.next_i))
                    self.next_i += 1
        except Exception as e:
            print(f"An error occurred in node: {self.__class__.__name__} worker: {self.worker_id}: {e}")
            traceback.print_exc()
            raise  # Re-raises the last exception.

    async def _process_and_emit(self, job):
        """Run ``process_job`` on one job and forward every result it yields."""
        async for result in self.process_job(job):
            if self.output_queue is not None:
                await self.output_queue.put(result)
            if self.job_sync is not None:
                self.job_sync.append(result)
            self._jobs_processed += 1

    async def process_job(self, job: "Job"):
        """Process one job, yielding zero or more result jobs.

        Subclasses must override this as an async generator (an ``async def``
        containing ``yield``), because :meth:`run` consumes it with
        ``async for``.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement process_job")
        # Unreachable, but its presence makes this an async generator, so the
        # NotImplementedError surfaces from `async for` instead of a TypeError.
        yield
class Pipeline:
    """Wire node classes together with queues and run their workers as tasks.

    The input queue of the first node added becomes the root queue that
    :meth:`enqueue_job` feeds.
    """

    def __init__(self):
        self.input_queues = []  # every input_queue passed to add_node, in call order
        self.root_queue = None  # input queue of the first node added
        self.nodes = []  # distinct node class names, in first-added order
        self.node_workers = {}  # node class name -> list of worker instances
        self.tasks = []  # asyncio tasks running each worker's run()
        self._job_id = 0  # next id handed out by enqueue_job

    async def add_node(
        self,
        node: "type[Node]",
        num_workers=1,
        input_queue=None,
        output_queue=None,
        job_sync=None,
        sequential_node=False,
    ):
        """Instantiate ``num_workers`` workers of ``node`` and start their ``run()`` tasks.

        Args:
            node: The node *class* (not an instance); it is called as
                ``node(worker_id, input_queue, output_queue, job_sync, sequential_node)``.
            num_workers: Number of workers to start; must be 1 for sequential nodes.
            input_queue: Queue the workers read from (required).
            output_queue: Optional queue the workers write to; must differ from
                ``input_queue``.
            job_sync: Optional result list; requires ``sequential_node``.
            sequential_node: Process jobs in ``job.id`` order.

        Raises:
            ValueError: On any of the configuration errors checked below.
        """
        if input_queue is None:
            raise ValueError('input_queue is None')
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # Ordering is only guaranteed within a single worker's buffer.
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # A node feeding its own input would loop its outputs back forever.
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')

        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)
        # The first node registered defines the pipeline's entry point.
        if not self.input_queues:
            self.root_queue = input_queue
        self.input_queues.append(input_queue)

        for worker_id in range(num_workers):
            worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            self.node_workers.setdefault(node_name, []).append(worker)
            self.tasks.append(asyncio.create_task(worker.run()))

    async def enqueue_job(self, job: "Job"):
        """Assign the next sequential id to ``job`` and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet, so there is no root queue.
        """
        # Check before assigning an id so a failed call does not consume one.
        if self.root_queue is None:
            raise RuntimeError('no root queue: call add_node() before enqueue_job()')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel every worker task and wait for all of them to finish.

        ``gather(return_exceptions=True)`` collects each task's outcome,
        including its cancellation, so none is propagated to the caller.
        """
        for task in self.tasks:
            task.cancel()
        await asyncio.gather(*self.tasks, return_exceptions=True)