Spaces:
Sleeping
Sleeping
File size: 1,235 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import numpy as np
class PathBuilder(dict):
    """
    Accumulate per-step values into lists keyed by name, then stack them.

    Each call to ``add_all`` appends one value per keyword argument to the
    list stored under that keyword, and counts as one step of the path.

    Usage:
    ```
    path_builder = PathBuilder()
    path_builder.add_all(
        observations=1,
        actions=2,
        next_observations=3,
        ...
    )
    path_builder.add_all(
        observations=4,
        actions=5,
        next_observations=6,
        ...
    )
    path = path_builder.get_all_stacked()
    path['observations']
    # output: array([1, 4])
    path['actions']
    # output: array([2, 5])
    ```
    Note that the key should be "actions" and not "action" since the
    resulting dictionary will have those keys.

    ``len(path_builder)`` is the number of ``add_all`` calls made so far,
    not the number of keys stored in the dict.
    """

    def __init__(self):
        super().__init__()
        # Counts add_all calls; __len__ reports this instead of the key count.
        self._path_length = 0

    def add_all(self, **key_to_value):
        """Record one step by appending each value to the list for its key.

        A key seen for the first time starts a new list. The step counter is
        incremented once per call, so keys omitted from some calls end up
        with fewer entries than ``len(self)``.
        """
        for key, value in key_to_value.items():
            self.setdefault(key, []).append(value)
        self._path_length += 1

    def get_all_stacked(self):
        """Return a new plain dict mapping each key to ``stack_list`` of its values.

        Lists whose first element is a dict are kept as lists; all other
        lists are converted with ``np.array``.
        """
        return {key: stack_list(values) for key, values in self.items()}

    def __len__(self):
        """Return the number of ``add_all`` calls made so far."""
        return self._path_length
def stack_list(lst):
    """Combine a list of per-step values into a single object.

    Args:
        lst: Sequence of values collected over a path.

    Returns:
        ``lst`` itself, unchanged and not copied, when its first element is a
        dict; otherwise ``np.array(lst)``. An empty sequence yields an empty
        array instead of raising ``IndexError``.
    """
    # Check the length before peeking at lst[0]. Use len() rather than plain
    # truthiness so NumPy-array inputs don't raise "truth value is ambiguous".
    if len(lst) > 0 and isinstance(lst[0], dict):
        return lst
    return np.array(lst)
|