import os
import copy
from typing import Optional

import numpy as np
import hickle
from easydict import EasyDict

from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import use_time_check, PriorityExperienceReplay
from ding.utils import BUFFER_REGISTRY
# Registered so the wrapper can be built from config through the buffer registry
# (the key 'deque' is an assumption based on the class name).
@BUFFER_REGISTRY.register('deque')
class DequeBufferWrapper(object):

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg
    config = dict(
        replay_buffer_size=10000,  # maximum number of transitions kept in the deque
        max_use=float("inf"),  # a sample is discarded after being sampled this many times
        train_iter_per_log=100,  # minimum train-iteration gap between two tensorboard logs
        priority=False,  # enable prioritized experience replay
        priority_IS_weight=False,  # attach importance-sampling weights to sampled data
        priority_power_factor=0.6,  # alpha: how strongly priority shapes the sampling distribution
        IS_weight_power_factor=0.4,  # beta: initial importance-sampling correction exponent
        IS_weight_anneal_train_iter=int(1e5),  # train iterations over which beta anneals towards 1
        priority_max_limit=1000,  # upper clip applied to priorities written back in update()
    )
    def __init__(
            self,
            cfg: EasyDict,
            tb_logger: Optional[object] = None,
            exp_name: str = 'default_experiment',
            instance_name: str = 'buffer'
    ) -> None:
        self.cfg = cfg
        self.priority_max_limit = cfg.priority_max_limit
        self.name = '{}_iter'.format(instance_name)
        self.tb_logger = tb_logger
        self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
        self.last_log_train_iter = -1
        # use_count middleware: drop samples that have been reused more than max_use times
        if self.cfg.max_use != float("inf"):
            self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use))
        # priority middleware: prioritized sampling with optional importance-sampling weights
        if self.cfg.priority:
            self.buffer.use(
                PriorityExperienceReplay(
                    self.buffer,
                    IS_weight=self.cfg.priority_IS_weight,
                    priority_power_factor=self.cfg.priority_power_factor,
                    IS_weight_power_factor=self.cfg.IS_weight_power_factor,
                    IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter
                )
            )
        # index/meta of the most recently sampled batch, used by update() to write back priorities
        self.last_sample_index = None
        self.last_sample_meta = None
    def sample(self, size: int, train_iter: int = 0):
        output = self.buffer.sample(size=size, ignore_insufficient=True)
        if len(output) > 0:
            # periodic tensorboard logging of buffer statistics
            if self.last_log_train_iter == -1 or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log:
                meta = [o.meta for o in output]
                if self.cfg.max_use != float("inf"):
                    use_count_avg = np.mean([m['use_count'] for m in meta])
                    self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter)
                if self.cfg.priority:
                    # remember the sampled indices/meta so that update() can write back new priorities
                    self.last_sample_index = [o.index for o in output]
                    self.last_sample_meta = meta
                    priority_list = [m['priority'] for m in meta]
                    priority_avg = np.mean(priority_list)
                    priority_max = np.max(priority_list)
                    self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), priority_avg, train_iter)
                    self.tb_logger.add_scalar('{}/priority_max'.format(self.name), priority_max, train_iter)
                self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter)
                self.last_log_train_iter = train_iter
            data = [o.data for o in output]
            if self.cfg.priority_IS_weight:
                # attach the importance-sampling weight of each transition to its data dict
                IS = [o.meta['priority_IS'] for o in output]
                for i in range(len(data)):
                    data[i]['IS'] = IS[i]
            return data
        else:
            return None
    def push(self, data, cur_collector_envstep: int = -1) -> None:
        for d in data:
            meta = {}
            if self.cfg.priority and 'priority' in d:
                init_priority = d.pop('priority')
                meta['priority'] = init_priority
            self.buffer.push(d, meta=meta)
    def update(self, meta: dict) -> None:
        if not self.cfg.priority:
            return
        if self.last_sample_index is None:
            return
        new_meta = self.last_sample_meta
        # clip the new priorities and write them back to the most recently sampled batch
        for m, p in zip(new_meta, meta['priority']):
            m['priority'] = min(self.priority_max_limit, p)
        for idx, m in zip(self.last_sample_index, new_meta):
            self.buffer.update(idx, data=None, meta=m)
        self.last_sample_index = None
        self.last_sample_meta = None
    def count(self) -> int:
        return self.buffer.count()

    def save_data(self, file_name: str):
        self.buffer.save_data(file_name)

    def load_data(self, file_name: str):
        self.buffer.load_data(file_name)