from dataclasses import dataclass
import os
import torch
import numpy as np
import uuid
import treetensor.torch as ttorch
from abc import ABC, abstractmethod
from ditk import logging
from time import sleep, time
from threading import Lock, Thread
from typing import Any, Callable, Dict, List, Optional, Union
from ding.data import FileStorage, Storage
from os import path
from ding.data.shm_buffer import ShmBuffer
from ding.framework.supervisor import RecvPayload, Supervisor, ChildType, SendPayload


@dataclass
class ShmObject:
    id_: ShmBuffer
    buf: Any


class StorageWorker:

    def load(self, storage: Storage) -> Any:
        return storage.load()


class StorageLoader(Supervisor, ABC):

    def __init__(self, worker_num: int = 3) -> None:
        """
        Overview:
            Save and send data synchronously and load it asynchronously.
        Arguments:
            - worker_num (:obj:`int`): Number of subprocess workers.
        """
        super().__init__(type_=ChildType.PROCESS)
        self._load_lock = Lock()  # The first load (first meet) must run exclusively, one call at a time.
        self._callback_map: Dict[str, Callable] = {}
        self._shm_obj_map: Dict[int, ShmObject] = {}
        self._worker_num = worker_num
        self._req_count = 0
    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._recv_loop = None
        self._callback_map = {}
        self._shm_obj_map = {}
        self._req_count = 0

    def start_link(self) -> None:
        if not self._running:
            super().start_link()
            self._recv_loop = Thread(target=self._loop_recv, daemon=True)
            self._recv_loop.start()

    def _next_proc_id(self) -> int:
        # Dispatch load requests to workers in round-robin order.
        return self._req_count % self._worker_num
    @abstractmethod
    def save(self, obj: Union[Dict, List]) -> Storage:
        """
        Overview:
            Save data with a storage object synchronously.
        Arguments:
            - obj (:obj:`Union[Dict, List]`): The data (trajectories or episodes); values may be numpy arrays, \
                tensors or treetensors.
        Returns:
            - storage (:obj:`Storage`): The storage object.
        """
        raise NotImplementedError
    def load(self, storage: Storage, callback: Callable):
        """
        Overview:
            Load data from a storage object asynchronously. \
            The first time a new data structure is met, this function analyzes its layout \
            and allocates a shared-memory buffer for each subprocess; these buffers are then \
            responsible for asynchronously loading data into memory.
        Arguments:
            - storage (:obj:`Storage`): The storage object.
            - callback (:obj:`Callable`): Callback function invoked after the data is loaded.
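        Examples:
            A minimal usage sketch (the directory name is illustrative; ``FileStorageLoader`` is the \
            file-based subclass defined below, and after the first call the callback runs in the \
            loader's receive-loop thread):

            >>> loader = FileStorageLoader(dirname="./tmp_traj")
            >>> storage = loader.save({"obs": np.random.rand(4, 84, 84).astype(np.float32)})
            >>> loader.load(storage, callback=lambda obj: print(obj["obs"].shape))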
| """ | |
| with self._load_lock: | |
| if not self._running: | |
| self._first_meet(storage, callback) | |
| return | |
| payload = SendPayload(proc_id=self._next_proc_id, method="load", args=[storage]) | |
| self._callback_map[payload.req_id] = callback | |
| self.send(payload) | |
| self._req_count += 1 | |
    def _first_meet(self, storage: Storage, callback: Callable):
        """
        Overview:
            The first time an object type is met, load the object directly and analyze its structure, \
            in order to allocate the shared-memory objects and create the subprocess workers.
        Arguments:
            - storage (:obj:`Storage`): The storage object.
            - callback (:obj:`Callable`): Callback function invoked after the data is loaded.
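        .. note::
            This first load runs synchronously in the caller's thread, so the callback \
            fires before this method returns; only subsequent ``load`` calls are served \
            asynchronously by the workers.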
| """ | |
| obj = storage.load() | |
| # Create three workers for each usage type. | |
| for i in range(self._worker_num): | |
| shm_obj = self._create_shm_buffer(obj) | |
| self._shm_obj_map[i] = shm_obj | |
| self.register(StorageWorker, shm_buffer=shm_obj, shm_callback=self._shm_callback) | |
| self.start_link() | |
| callback(obj) | |
    def _loop_recv(self):
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                if payload.req_id in self._callback_map:
                    del self._callback_map[payload.req_id]
            else:
                self._shm_putback(payload, self._shm_obj_map[payload.proc_id])
                if payload.req_id in self._callback_map:
                    callback = self._callback_map.pop(payload.req_id)
                    callback(payload.data)
    def _create_shm_buffer(self, obj: Union[Dict, List]) -> Optional[ShmObject]:
        """
        Overview:
            Create a shared-memory object (buffer plus handshake id) by walking through the data structure.
        Arguments:
            - obj (:obj:`Union[Dict, List]`): The data (trajectories or episodes); values may be numpy arrays, \
                tensors or treetensors.
        Returns:
            - shm_buf (:obj:`Optional[ShmObject]`): The shared-memory object.
        """
        max_level = 2

        def to_shm(obj: Dict, level: int):
            if level > max_level:
                return
            shm_buf = None
            if isinstance(obj, Dict) or isinstance(obj, ttorch.Tensor):
                shm_buf = {}
                for key, val in obj.items():
                    # Only numpy arrays (and tensors, via their numpy views) can be written into a shm buffer.
                    if isinstance(val, np.ndarray):
                        shm_buf[key] = ShmBuffer(val.dtype, val.shape, copy_on_get=False)
                    elif isinstance(val, torch.Tensor):
                        shm_buf[key] = ShmBuffer(
                            val.numpy().dtype, val.numpy().shape, copy_on_get=False, ctype=torch.Tensor
                        )
                    # Recursively parse nested structures.
                    elif isinstance(val, Dict) or isinstance(val, ttorch.Tensor) or isinstance(val, List):
                        buf = to_shm(val, level=level + 1)
                        if buf:
                            shm_buf[key] = buf
            elif isinstance(obj, List):
                # Double the buffer length so later lists that are longer than this sample still fit.
                shm_buf = [to_shm(o, level=level) for o in obj] * 2
                if all(s is None for s in shm_buf):
                    shm_buf = []
            return shm_buf

        shm_buf = to_shm(obj, level=0)
        if shm_buf is not None:
            random_id = self._random_id()
            shm_buf = ShmObject(id_=ShmBuffer(random_id.dtype, random_id.shape, copy_on_get=False), buf=shm_buf)
        return shm_buf
    def _random_id(self) -> np.ndarray:
        return np.random.randint(1, 9e6, size=(1, ))
    def _shm_callback(self, payload: RecvPayload, shm_obj: ShmObject):
        """
        Overview:
            Called in the subprocess: put ``payload.data`` into the shared-memory buffer.
        Arguments:
            - payload (:obj:`RecvPayload`): The recv payload with meta info of the data.
            - shm_obj (:obj:`ShmObject`): The shared-memory object.
        """
        assert isinstance(payload.data, type(
            shm_obj.buf
        )), "Data type ({}) and buf type ({}) do not match!".format(type(payload.data), type(shm_obj.buf))
        # Wait until the shm object is free (id_ == 0 means the previous data has been taken back).
        while shm_obj.id_.get()[0] != 0:
            sleep(0.001)
        max_level = 2

        def shm_callback(data: Union[Dict, List, ttorch.Tensor], buf: Union[Dict, List], level: int):
            if level > max_level:
                return
            if isinstance(buf, List):
                assert isinstance(data, List), "Data ({}) and buf ({}) types do not match".format(
                    type(data), type(buf)
                )
            elif isinstance(buf, Dict):
                assert isinstance(data, ttorch.Tensor) or isinstance(
                    data, Dict
                ), "Data ({}) and buf ({}) types do not match".format(type(data), type(buf))
            if isinstance(data, Dict) or isinstance(data, ttorch.Tensor):
                for key, val in data.items():
                    if isinstance(val, torch.Tensor):
                        val = val.numpy()
                    buf_val = buf.get(key)
                    if buf_val is None:
                        continue
                    if isinstance(buf_val, ShmBuffer) and isinstance(val, np.ndarray):
                        buf_val.fill(val)
                        # Replace the array with None; the main process restores it from the buffer.
                        data[key] = None
                    else:
                        shm_callback(val, buf_val, level=level + 1)
            elif isinstance(data, List):
                for i, data_ in enumerate(data):
                    shm_callback(data_, buf[i], level=level)

        shm_callback(payload.data, buf=shm_obj.buf, level=0)
        # Stamp the shm object with a fresh id and send the same id back via payload.extra.
        id_ = self._random_id()
        shm_obj.id_.fill(id_)
        payload.extra = id_
    def _shm_putback(self, payload: RecvPayload, shm_obj: ShmObject):
        """
        Overview:
            Called in the main process: put the buffer contents back into ``payload.data``.
        Arguments:
            - payload (:obj:`RecvPayload`): The recv payload with meta info of the data.
            - shm_obj (:obj:`ShmObject`): The shared-memory object.
        """
        assert isinstance(payload.data, type(
            shm_obj.buf
        )), "Data type ({}) and buf type ({}) do not match!".format(type(payload.data), type(shm_obj.buf))
        assert shm_obj.id_.get()[0] == payload.extra[0], "Shm object and payload do not match ({} - {}).".format(
            shm_obj.id_.get()[0], payload.extra[0]
        )
        def shm_putback(data: Union[Dict, List], buf: Union[Dict, List]):
            if isinstance(data, Dict) or isinstance(data, ttorch.Tensor):
                for key, val in data.items():
                    buf_val = buf.get(key)
                    if buf_val is None:
                        continue
                    if val is None and isinstance(buf_val, ShmBuffer):
                        data[key] = buf[key].get()
                    else:
                        shm_putback(val, buf_val)
            elif isinstance(data, List):
                for i, data_ in enumerate(data):
                    shm_putback(data_, buf[i])

        shm_putback(payload.data, buf=shm_obj.buf)
        # Reset the handshake id to 0 to mark the buffer as free again.
        shm_obj.id_.fill(np.array([0]))


class FileStorageLoader(StorageLoader):

    def __init__(self, dirname: str, ttl: int = 20, worker_num: int = 3) -> None:
        """
        Overview:
            Dump and load objects with file storage.
        Arguments:
            - dirname (:obj:`str`): The directory to save files in.
            - ttl (:obj:`int`): Maximum time in seconds to keep a file, after which it will be deleted.
            - worker_num (:obj:`int`): Number of subprocess worker loaders.
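        Examples:
            A hedged usage sketch (the directory name and payload are illustrative):

            >>> loader = FileStorageLoader(dirname="./data_cache", ttl=60)
            >>> storage = loader.save({"obs": np.ones((4, 84, 84), dtype=np.float32)})
            >>> loader.load(storage, callback=lambda obj: print(obj["obs"].shape))
            >>> loader.shutdown()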
| """ | |
| super().__init__(worker_num) | |
| self._dirname = dirname | |
| self._files = [] | |
| self._cleanup_thread = None | |
| self._ttl = ttl # # Delete files created 10 minutes ago. | |
    def save(self, obj: Union[Dict, List]) -> FileStorage:
        if not path.exists(self._dirname):
            os.mkdir(self._dirname)
        filename = "{}.pkl".format(uuid.uuid1())
        full_path = path.join(self._dirname, filename)
        f = FileStorage(full_path)
        f.save(obj)
        self._files.append([time(), f.path])
        self._start_cleanup()
        return f
    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread that removes files which have exceeded their time-to-live on disk.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._cleanup_thread = None
    def _loop_cleanup(self):
        while True:
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)
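

if __name__ == "__main__":
    # A minimal, hedged end-to-end sketch. The directory name and array shape are
    # illustrative; actual behavior depends on ding's Supervisor/FileStorage internals.
    loader = FileStorageLoader(dirname="./tmp_storage_loader")
    try:
        storage = loader.save({"obs": np.random.rand(4, 84, 84).astype(np.float32)})
        # The first load runs synchronously (first meet); later loads go through the workers.
        loader.load(storage, callback=lambda obj: print("loaded obs:", obj["obs"].shape))
        loader.load(storage, callback=lambda obj: print("async loaded obs:", obj["obs"].shape))
        sleep(2)  # Give the async worker time to deliver before shutdown.
    finally:
        loader.shutdown(timeout=5)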