from typing import Optional, Callable, List, Any

from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator


class AccMetric(IMetric):
    """Classification accuracy metric used together with ``MetricSerialEvaluator``."""

    def eval(self, inputs: Any, label: Any) -> dict:
        # The predicted class is the argmax over the logit dimension; accuracy is the
        # fraction of predictions in this batch that match the label.
        return {'Acc': (inputs['logit'].argmax(dim=1) == label).sum().item() / label.shape[0]}

    def reduce_mean(self, inputs: List[Any]) -> Any:
        # Average the per-batch accuracies into a single value.
        s = 0
        for item in inputs:
            s += item['Acc']
        return {'Acc': s / len(inputs)}

    def gt(self, metric1: Any, metric2: Any) -> bool:
        # Return True if metric1 is strictly better (higher accuracy) than metric2.
        if metric2 is None:
            return True
        if isinstance(metric2, dict):
            m2 = metric2['Acc']
        else:
            m2 = metric2
        return metric1['Acc'] > m2
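

# A minimal usage sketch of AccMetric: eval scores one batch, reduce_mean merges the
# per-batch results. The tensors and values below are hypothetical stand-ins for the
# batches a MetricSerialEvaluator would typically feed in.
def _acc_metric_example() -> dict:
    import torch
    metric = AccMetric()
    logit = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # two samples, two classes
    label = torch.tensor([1, 0])                    # both predictions are correct
    batch_acc = metric.eval({'logit': logit}, label)      # {'Acc': 1.0}
    return metric.reduce_mean([batch_acc, {'Acc': 0.5}])  # {'Acc': 0.75}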


def mark_not_expert(ori_data: List[dict]) -> List[dict]:
    """Flag every transition as agent-generated rather than expert data."""
    for i in range(len(ori_data)):
        # Set is_expert flag (expert: 1, agent: 0)
        ori_data[i]['is_expert'] = 0
    return ori_data


def mark_warm_up(ori_data: List[dict]) -> List[dict]:
    """Flag every transition as warm-up data (used by td3_vae)."""
    for i in range(len(ori_data)):
        ori_data[i]['warm_up'] = True
    return ori_data
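

# A minimal sketch of how these post-processors are meant to be applied: they mutate a
# list of transition dicts in place before the data is pushed into the replay buffer.
# The sample transition below is a hypothetical stand-in.
def _postprocess_example() -> List[dict]:
    transitions = [{'obs': 0, 'action': 1, 'reward': 0.0}]
    transitions = mark_not_expert(transitions)  # adds 'is_expert': 0
    transitions = mark_warm_up(transitions)     # adds 'warm_up': True
    return transitions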


def random_collect(
        policy_cfg: 'EasyDict',  # noqa
        policy: 'Policy',  # noqa
        collector: 'ISerialCollector',  # noqa
        collector_env: 'BaseEnvManager',  # noqa
        commander: 'BaseSerialCommander',  # noqa
        replay_buffer: 'IBuffer',  # noqa
        postprocess_data_fn: Optional[Callable] = None
) -> None:  # noqa
    """Collect ``policy_cfg.random_collect_size`` samples (or episodes, for episode-type
    collectors) before training starts and push them into the replay buffer."""
    assert policy_cfg.random_collect_size > 0
    if policy_cfg.get('transition_with_policy_data', False):
        # The transitions need to carry policy-produced data, so collect with the
        # policy's own collect mode instead of a purely random policy.
        collector.reset_policy(policy.collect_mode)
    else:
        action_space = collector_env.action_space
        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
        collector.reset_policy(random_policy)
    collect_kwargs = commander.step()
    if policy_cfg.collect.collector.type == 'episode':
        new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
    else:
        # record_random_collect=False suppresses log output for this random-collect phase.
        new_data = collector.collect(
            n_sample=policy_cfg.random_collect_size,
            random_collect=True,
            record_random_collect=False,
            policy_kwargs=collect_kwargs
        )
    if postprocess_data_fn is not None:
        new_data = postprocess_data_fn(new_data)
    replay_buffer.push(new_data, cur_collector_envstep=0)
    # Restore the policy's own collect mode for subsequent (non-random) collection.
    collector.reset_policy(policy.collect_mode)
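

# A minimal sketch of how random_collect is typically wired into a serial training entry.
# The surrounding objects (cfg, policy, collector, collector_env, commander, replay_buffer)
# are assumed to come from the usual DI-engine serial pipeline setup; passing mark_not_expert
# is just one possible post-processing choice.
def _random_collect_usage_example(cfg, policy, collector, collector_env, commander, replay_buffer) -> None:
    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(
            cfg.policy, policy, collector, collector_env, commander, replay_buffer,
            postprocess_data_fn=mark_not_expert
        )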