Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| import logging | |
| from os import path | |
| import os | |
| from threading import Thread | |
| from time import sleep, time | |
| from typing import Callable, Optional | |
| import uuid | |
| import torch.multiprocessing as mp | |
| import torch | |
| from ding.data.storage.file import FileModelStorage | |
| from ding.data.storage.storage import Storage | |
| from ding.framework import Supervisor | |
| from ding.framework.supervisor import ChildType, SendPayload | |
class ModelWorker():
    """
    Overview:
        Lightweight worker that holds a reference to a torch module and
        persists its parameters into a given storage object on request.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model

    def save(self, storage: Storage) -> Storage:
        """
        Overview:
            Write the wrapped model's state dict into ``storage`` and return it.
        """
        state_dict = self._model.state_dict()
        storage.save(state_dict)
        return storage
class ModelLoader(Supervisor, ABC):

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        # A CUDA-resident model requires the "spawn" start method for the
        # child process; CPU models use the default multiprocessing context.
        ctx_kwargs = {"mp_ctx": mp.get_context("spawn")} if next(model.parameters()).is_cuda else {}
        super().__init__(type_=ChildType.PROCESS, **ctx_kwargs)
        self._model = model
        self._send_callback_loop = None
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        """
        Overview:
            Move the model into shared memory, spawn the worker process and
            start the background thread that dispatches save callbacks.
        """
        if self._running:
            return
        self._model.share_memory()
        self.register(self._model_worker)
        self.start_link()
        self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
        self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor and drop all pending callbacks.
        """
        super().shutdown(timeout)
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        """
        Overview:
            Endless loop that receives worker replies and invokes the callback
            registered for each request id.
        """
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                # Discard the callback for the failed request, if any.
                self._send_callbacks.pop(payload.req_id, None)
            elif payload.req_id in self._send_callbacks:
                callback = self._send_callbacks.pop(payload.req_id)
                callback(payload.data)

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:): The loaded model.
        """
        return storage.load()

    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError
class FileModelLoader(ModelLoader):

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned after ttl. Note that \
                files that do not time out when the process is stopped are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        # Entries of [create_time, file_path], oldest first.
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread to clean up files that are taking up too much time on the disk.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor; the cleanup thread handle is reset so a
            restart can spawn a fresh one.
        """
        super().shutdown(timeout)
        # NOTE(review): the daemon cleanup thread itself is not stopped here; it
        # keeps running until process exit. Resetting the handle means a later
        # _start_cleanup() will spawn a second loop — confirm this is intended.
        self._cleanup_thread = None

    def _loop_cleanup(self):
        """
        Overview:
            Endless loop removing files from disk once they are older than ttl seconds.
        """
        while True:
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)

    def save(self, callback: Callable) -> FileModelStorage:
        """
        Overview:
            Save model asynchronously into a new file and register a callback
            invoked once the worker process has written it.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`FileModelStorage`): The storage object pointing at the new \
                file, or ``None`` if the loader has not been started.
        """
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return
        # makedirs with exist_ok avoids the check-then-create race of the previous
        # path.exists()/os.mkdir pair and also supports nested directories
        # (os.mkdir raises FileNotFoundError when the parent does not exist).
        os.makedirs(self._dirname, exist_ok=True)
        file_path = path.join(self._dirname, "model_{}.pth.tar".format(uuid.uuid1()))
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            # Record the file for ttl-based cleanup before notifying the caller.
            self._files.append([time(), file_path])
            callback(storage)

        self._send_callbacks[payload.req_id] = clean_callback
        self._start_cleanup()
        return model_storage