import atexit
import os
import random
import time
import traceback
import pickle
import tempfile
import socket
from os import path
from typing import Callable, Dict, List, Optional, Tuple, Union, Set
from threading import Thread

from mpire.pool import WorkerPool
from ditk import logging

from ding.framework.event_loop import EventLoop
from ding.utils.design_helper import SingletonMetaclass
from ding.framework.message_queue import *
from ding.utils.registry_factory import MQ_REGISTRY

# Avoid ipc address conflicts: this module keeps its own independently seeded Random
# instance so that a globally seeded random state cannot make nodes pick the same address.
random = random.Random()


class Parallel(metaclass=SingletonMetaclass):
    """
    Overview:
        Per-process singleton that connects parallel workers through a message queue
        (nng or redis) and dispatches remote events to locally registered callbacks.
    """

    def __init__(self) -> None:
        # Init will only be called once in a process
        self._listener = None
        self.is_active = False
        self.node_id = None
        self.local_id = None
        self.labels = set()
        self._event_loop = EventLoop("parallel_{}".format(id(self)))
        self._retries = 0  # Retries in auto recovery

    def _run(
            self,
            node_id: int,
            local_id: int,
            n_parallel_workers: int,
            labels: Optional[Set[str]] = None,
            auto_recover: bool = False,
            max_retries: Union[int, float] = float("inf"),
            mq_type: str = "nng",
            startup_interval: int = 1,
            **kwargs
    ) -> None:
        self.node_id = node_id
        self.local_id = local_id
        self.startup_interval = startup_interval
        self.n_parallel_workers = n_parallel_workers
        self.labels = labels or set()
        self.auto_recover = auto_recover
        self.max_retries = max_retries
        self._mq = MQ_REGISTRY.get(mq_type)(**kwargs)
        # Stagger worker startup (local_id * interval) so that all workers do not try
        # to bind and pair their addresses at the same moment.
        time.sleep(self.local_id * self.startup_interval)
        self._listener = Thread(target=self.listen, name="mq_listener", daemon=True)
        self._listener.start()
        self.mq_type = mq_type
        self.barrier_runtime = Parallel.get_barrier_runtime()(self.node_id)

    @classmethod
    def runner(
            cls,
            n_parallel_workers: int,
            mq_type: str = "nng",
            attach_to: Optional[List[str]] = None,
            protocol: str = "ipc",
            address: Optional[str] = None,
            ports: Optional[Union[List[int], int]] = None,
            topology: str = "mesh",
            labels: Optional[Set[str]] = None,
            node_ids: Optional[Union[List[int], int]] = None,
            auto_recover: bool = False,
            max_retries: Union[int, float] = float("inf"),
            redis_host: Optional[str] = None,
            redis_port: Optional[int] = None,
            startup_interval: int = 1
    ) -> Callable:
| """ | |
| Overview: | |
| This method allows you to configure parallel parameters, and now you are still in the parent process. | |
| Arguments: | |
| - n_parallel_workers (:obj:`int`): Workers to spawn. | |
| - mq_type (:obj:`str`): Embedded message queue type, i.e. nng, redis. | |
| - attach_to (:obj:`Optional[List[str]]`): The node's addresses you want to attach to. | |
| - protocol (:obj:`str`): Network protocol. | |
| - address (:obj:`Optional[str]`): Bind address, ip or file path. | |
| - ports (:obj:`Optional[List[int]]`): Candidate ports. | |
| - topology (:obj:`str`): Network topology, includes: | |
| `mesh` (default): fully connected between each other; | |
| `star`: only connect to the first node; | |
| `alone`: do not connect to any node, except the node attached to; | |
| - labels (:obj:`Optional[Set[str]]`): Labels. | |
| - node_ids (:obj:`Optional[List[int]]`): Candidate node ids. | |
| - auto_recover (:obj:`bool`): Auto recover from uncaught exceptions from main. | |
| - max_retries (:obj:`int`): Max retries for auto recover. | |
| - redis_host (:obj:`str`): Redis server host. | |
| - redis_port (:obj:`int`): Redis server port. | |
| - startup_interval (:obj:`int`): Start up interval between each task. | |
| Returns: | |
| - _runner (:obj:`Callable`): The wrapper function for main. | |
| """ | |
        all_args = locals()
        del all_args["cls"]
        args_parsers = {"nng": cls._nng_args_parser, "redis": cls._redis_args_parser}
        assert n_parallel_workers > 0, "Parallel worker number should be greater than 0"

        def _runner(main_process: Callable, *args, **kwargs) -> None:
            """
            Overview:
                Prepare to run in subprocess.
            Arguments:
                - main_process (:obj:`Callable`): The main function, your program starts from here.
            """
            runner_params = args_parsers[mq_type](**all_args)
            params_group = []
            for i, runner_kwargs in enumerate(runner_params):
                runner_kwargs["local_id"] = i
                params_group.append([runner_kwargs, (main_process, args, kwargs)])

            if n_parallel_workers == 1:
                cls._subprocess_runner(*params_group[0])
            else:
                with WorkerPool(n_jobs=n_parallel_workers, start_method="spawn", daemon=False) as pool:
                    # Cleanup the pool just in case the program crashes.
                    atexit.register(pool.__exit__)
                    pool.map(cls._subprocess_runner, params_group)

        return _runner
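
    # A minimal usage sketch of `Parallel.runner` (illustrative only; `main` here is a
    # hypothetical user-defined entry function, not part of this module):
    #
    #     def main():
    #         router = Parallel()
    #         print("running on node", router.node_id)
    #
    #     if __name__ == "__main__":
    #         Parallel.runner(n_parallel_workers=2, protocol="tcp", topology="mesh")(main)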

    @classmethod
    def _nng_args_parser(
            cls,
            n_parallel_workers: int,
            attach_to: Optional[List[str]] = None,
            protocol: str = "ipc",
            address: Optional[str] = None,
            ports: Optional[Union[List[int], int]] = None,
            topology: str = "mesh",
            node_ids: Optional[Union[List[int], int]] = None,
            **kwargs
    ) -> List[dict]:
        attach_to = attach_to or []
        nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports)

        def cleanup_nodes():
            # Remove leftover ipc socket files on exit.
            for node in nodes:
                protocol, file_path = node.split("://")
                if protocol == "ipc" and path.exists(file_path):
                    os.remove(file_path)

        atexit.register(cleanup_nodes)

        def topology_network(i: int) -> List[str]:
            if topology == "mesh":
                return nodes[:i] + attach_to
            elif topology == "star":
                return nodes[:min(1, i)] + attach_to
            elif topology == "alone":
                return attach_to
            else:
                raise ValueError("Unknown topology: {}".format(topology))

        runner_params = []
        candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0)
        for i in range(n_parallel_workers):
            runner_kwargs = {
                **kwargs,
                "node_id": candidate_node_ids[i],
                "listen_to": nodes[i],
                "attach_to": topology_network(i),
                "n_parallel_workers": n_parallel_workers,
            }
            runner_params.append(runner_kwargs)
        return runner_params
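
    # For illustration (hypothetical values): with n_parallel_workers=3, protocol="tcp",
    # address="127.0.0.1", ports=50515 and topology="mesh", worker i listens on port
    # 50515 + i and attaches to every earlier worker, e.g. worker 2 gets
    # attach_to=["tcp://127.0.0.1:50515", "tcp://127.0.0.1:50516"], which yields a
    # fully connected mesh.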

    @classmethod
    def _redis_args_parser(
            cls,
            n_parallel_workers: int,
            node_ids: Optional[Union[List[int], int]] = None,
            **kwargs
    ) -> List[dict]:
        runner_params = []
        candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0)
        for i in range(n_parallel_workers):
            runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]}
            runner_params.append(runner_kwargs)
        return runner_params

    @classmethod
    def _subprocess_runner(cls, runner_kwargs: dict, main_params: Tuple[Union[List, Dict]]) -> None:
        """
        Overview:
            The actual entry point inside the subprocess.
        Arguments:
            - runner_kwargs (:obj:`dict`): Keyword arguments for the runner.
            - main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for the main function.
        """
        logging.getLogger().setLevel(logging.INFO)
        main_process, args, kwargs = main_params
        with Parallel() as router:
            router.is_active = True
            router._run(**runner_kwargs)
            time.sleep(0.3)  # Waiting for network pairing
            router._supervised_runner(main_process, *args, **kwargs)

    def _supervised_runner(self, main: Callable, *args, **kwargs) -> None:
        """
        Overview:
            Run the main function, auto-recovering from uncaught exceptions if enabled.
        Arguments:
            - main (:obj:`Callable`): Main function.
        """
        if self.auto_recover:
            while True:
                try:
                    main(*args, **kwargs)
                    break
                except Exception as e:
                    if self._retries < self.max_retries:
                        logging.warning(
                            "Auto recover from exception: {}, node: {}, retries: {}".format(
                                e, self.node_id, self._retries
                            )
                        )
                        logging.warning(traceback.format_exc())
                        self._retries += 1
                    else:
                        logging.warning(
                            "Exceeded the max retries, node: {}, retries: {}, max_retries: {}".format(
                                self.node_id, self._retries, self.max_retries
                            )
                        )
                        raise e
        else:
            main(*args, **kwargs)

    @classmethod
    def get_node_addrs(
            cls,
            n_workers: int,
            protocol: str = "ipc",
            address: Optional[str] = None,
            ports: Optional[Union[List[int], int]] = None
    ) -> List[str]:
        if protocol == "ipc":
            node_name = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=4))
            tmp_dir = tempfile.gettempdir()
            nodes = ["ipc://{}/ditask_{}_{}.ipc".format(tmp_dir, node_name, i) for i in range(n_workers)]
        elif protocol == "tcp":
            address = address or cls.get_ip()
            ports = cls.padding_param(ports, n_workers, 50515)
            assert len(ports) == n_workers, \
                "The number of ports must be the same as the number of workers, " \
                "now there are {} ports and {} workers".format(len(ports), n_workers)
            nodes = ["tcp://{}:{}".format(address, port) for port in ports]
        else:
            raise Exception("Unknown protocol {}".format(protocol))
        return nodes
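
    # For example (hypothetical values): get_node_addrs(2, protocol="tcp", address="127.0.0.1")
    # returns ["tcp://127.0.0.1:50515", "tcp://127.0.0.1:50516"], while the default ipc
    # protocol returns paths such as "ipc:///tmp/ditask_ab12_0.ipc", where "ab12" is a
    # random 4-character run name.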

    @classmethod
    def padding_param(cls, int_or_list: Optional[Union[List[int], int]], n_max: int, start_value: int) -> List[int]:
        """
        Overview:
            Pad an int or list param to a sequence of length n_max.
        Arguments:
            - int_or_list (:obj:`Optional[Union[List[int], int]]`): Int or list typed value.
            - n_max (:obj:`int`): Max length.
            - start_value (:obj:`int`): Start from value.
        Returns:
            - param (:obj:`List[int]`): The padded sequence.
        """
        param = int_or_list
        if isinstance(param, list) and len(param) == 1:
            param = param[0]  # A list with only one element is equivalent to an int

        if isinstance(param, int):
            param = range(param, param + n_max)
        else:
            param = param or range(start_value, start_value + n_max)
        return param
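
    # Behavior sketch (illustrative values): padding_param(50515, 3, 0) -> range(50515, 50518)
    # (an int expands to a continuous sequence); padding_param([7], 3, 0) -> range(7, 10)
    # (a single-element list is treated as an int); padding_param(None, 3, 50515)
    # -> range(50515, 50518) (falls back to start_value); a longer list such as [1, 2, 3]
    # is returned unchanged.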

    def listen(self):
        self._mq.listen()
        while True:
            if not self._mq:
                break
            msg = self._mq.recv()
            # A None msg means that the message queue is no longer being listened to,
            # especially if the message queue is already closed.
            if not msg:
                break
            topic, msg = msg
            self._handle_message(topic, msg)

    def on(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Register a remote event on the parallel instance. The function will be executed \
            when a remote process emits this event via the network.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): Function body.
        """
        if self.is_active:
            self._mq.subscribe(event)
        self._event_loop.on(event, fn)

    def once(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Register a remote event that will only be called once on the parallel instance.
            The function will be executed when a remote process emits this event via the network.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): Function body.
        """
        if self.is_active:
            self._mq.subscribe(event)
        self._event_loop.once(event, fn)

    def off(self, event: str) -> None:
        """
        Overview:
            Unregister an event.
        Arguments:
            - event (:obj:`str`): Event name.
        """
        if self.is_active:
            self._mq.unsubscribe(event)
        self._event_loop.off(event)

    def emit(self, event: str, *args, **kwargs) -> None:
        """
        Overview:
            Send a remote event via the network to subscribed processes.
        Arguments:
            - event (:obj:`str`): Event name.
        """
        if self.is_active:
            payload = {"a": args, "k": kwargs}
            try:
                data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
            except AttributeError as e:
                logging.error("Arguments are not picklable! Event: {}, Args: {}".format(event, args))
                raise e
            self._mq.publish(event, data)
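
    # Event round-trip sketch (illustrative only; "ping" is an arbitrary event name and
    # assumes at least two connected nodes):
    #
    #     router = Parallel()
    #     router.on("ping", lambda value: print("got", value))  # local subscription
    #     router.emit("ping", 42)  # published to subscribed peers over the network
    #
    # Whether the sender's own subscription also receives the message depends on the
    # underlying message queue implementation.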

    def _handle_message(self, topic: str, msg: bytes) -> None:
        """
        Overview:
            Receive and parse the payload from other processes, and call the local functions.
        Arguments:
            - topic (:obj:`str`): Received topic.
            - msg (:obj:`bytes`): Received message.
        """
        event = topic
        if not self._event_loop.listened(event):
            logging.debug("Event {} was not listened to in parallel {}".format(event, self.node_id))
            return
        try:
            payload = pickle.loads(msg)
        except Exception as e:
            logging.error("Error when unpacking message on node {}, msg: {}".format(self.node_id, e))
            return
        self._event_loop.emit(event, *payload["a"], **payload["k"])

    @classmethod
    def get_ip(cls):
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # The address doesn't even have to be reachable: connecting a UDP socket
            # only selects the outbound interface, no packet is actually sent.
            s.connect(('10.255.255.255', 1))
            ip = s.getsockname()[0]
        except Exception:
            ip = '127.0.0.1'
        finally:
            s.close()
        return ip

    def get_attch_to_len(self) -> int:
        """
        Overview:
            Get the length of the 'attach_to' list of the message queue.
        Returns:
            - attach_to_len (:obj:`int`): The length of self._mq.attach_to, \
              or 0 if self._mq is not initialized.
        """
        if self._mq:
            if hasattr(self._mq, 'attach_to'):
                return len(self._mq.attach_to)
        return 0

    def __enter__(self) -> "Parallel":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def stop(self):
        logging.info("Stopping parallel worker on node: {}".format(self.node_id))
        self.is_active = False
        time.sleep(0.03)  # Brief pause so in-flight messages can drain before teardown
        if self._mq:
            self._mq.stop()
            self._mq = None
        if self._listener:
            self._listener.join(timeout=1)
            self._listener = None
        self._event_loop.stop()

    @classmethod
    def get_barrier_runtime(cls):
        # Import BarrierRuntime inside the function to avoid a circular import.
        from ding.framework.middleware.barrier import BarrierRuntime
        return BarrierRuntime
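

# Auto-recovery sketch (illustrative flags; `main` is a hypothetical user function):
# a worker whose `main` raises an uncaught exception is re-run in place up to
# max_retries times when auto_recover is enabled:
#
#     Parallel.runner(n_parallel_workers=2, auto_recover=True, max_retries=3)(main)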