import numpy as np
from gym.spaces import Dict, Discrete
from rlkit.data_management.replay_buffer import ReplayBuffer
class ObsDictRelabelingBuffer(ReplayBuffer):
"""
Replay buffer for environments whose observations are dictionaries, such as
- OpenAI Gym GoalEnv environments. https://blog.openai.com/ingredients-for-robotics-research/
- multiworld MultitaskEnv. https://github.com/vitchyr/multiworld/
Implementation details:
- Only add_path is implemented.
    - Image observations are presumed to have a key starting with the
      'image' prefix.
    - Every sample from [0, self._size) will be valid.
    - Observation and next observation are saved separately. It's
      memory-inefficient to save the observations twice, but it makes the
      code *much* simpler since you no longer have to worry about
      termination conditions.
"""
def __init__(
self,
max_size,
env,
fraction_goals_rollout_goals=1.0,
fraction_goals_env_goals=0.0,
internal_keys=None,
goal_keys=None,
observation_key='observation',
desired_goal_key='desired_goal',
achieved_goal_key='achieved_goal',
):
if internal_keys is None:
internal_keys = []
self.internal_keys = internal_keys
if goal_keys is None:
goal_keys = []
if desired_goal_key not in goal_keys:
goal_keys.append(desired_goal_key)
self.goal_keys = goal_keys
assert isinstance(env.observation_space, Dict)
assert 0 <= fraction_goals_rollout_goals
assert 0 <= fraction_goals_env_goals
assert 0 <= fraction_goals_rollout_goals + fraction_goals_env_goals
assert fraction_goals_rollout_goals + fraction_goals_env_goals <= 1
self.max_size = max_size
self.env = env
self.fraction_goals_rollout_goals = fraction_goals_rollout_goals
self.fraction_goals_env_goals = fraction_goals_env_goals
self.ob_keys_to_save = [
observation_key,
desired_goal_key,
achieved_goal_key,
]
self.observation_key = observation_key
self.desired_goal_key = desired_goal_key
self.achieved_goal_key = achieved_goal_key
if isinstance(self.env.action_space, Discrete):
self._action_dim = env.action_space.n
else:
self._action_dim = env.action_space.low.size
self._actions = np.zeros((max_size, self._action_dim))
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_size, 1), dtype='uint8')
# self._obs[key][i] is the value of observation[key] at time i
self._obs = {}
self._next_obs = {}
self.ob_spaces = self.env.observation_space.spaces
for key in self.ob_keys_to_save + internal_keys:
assert key in self.ob_spaces, \
"Key not found in the observation space: %s" % key
            # Images are stored as raw uint8 bytes to save memory; all other
            # observation keys are stored as float64.
            arr_dtype = np.float64
            if key.startswith('image'):
                arr_dtype = np.uint8
            self._obs[key] = np.zeros(
                (max_size, self.ob_spaces[key].low.size), dtype=arr_dtype)
            self._next_obs[key] = np.zeros(
                (max_size, self.ob_spaces[key].low.size), dtype=arr_dtype)
self._top = 0
self._size = 0
# Let j be any index in self._idx_to_future_obs_idx[i]
# Then self._next_obs[j] is a valid next observation for observation i
self._idx_to_future_obs_idx = [None] * max_size
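        # Hedged illustration: if a path of length 3 is stored in slots
        # 10..12, then self._idx_to_future_obs_idx[10] == array([10, 11, 12]),
        # i.e. any of those next observations can later serve as a relabeled
        # "future" goal for the transition stored in slot 10.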
def add_sample(self, observation, action, reward, terminal,
next_observation, **kwargs):
raise NotImplementedError("Only use add_path")
def terminate_episode(self):
pass
def num_steps_can_sample(self):
return self._size
def add_path(self, path):
obs = path["observations"]
actions = path["actions"]
rewards = path["rewards"]
next_obs = path["next_observations"]
terminals = path["terminals"]
path_len = len(rewards)
actions = flatten_n(actions)
if isinstance(self.env.action_space, Discrete):
actions = np.eye(self._action_dim)[actions].reshape((-1, self._action_dim))
obs = flatten_dict(obs, self.ob_keys_to_save + self.internal_keys)
next_obs = flatten_dict(next_obs, self.ob_keys_to_save + self.internal_keys)
obs = preprocess_obs_dict(obs)
next_obs = preprocess_obs_dict(next_obs)
if self._top + path_len >= self.max_size:
"""
All of this logic is to handle wrapping the pointer when the
replay buffer gets full.
"""
num_pre_wrap_steps = self.max_size - self._top
# numpy slice
pre_wrap_buffer_slice = np.s_[
self._top:self._top + num_pre_wrap_steps, :
]
pre_wrap_path_slice = np.s_[0:num_pre_wrap_steps, :]
num_post_wrap_steps = path_len - num_pre_wrap_steps
post_wrap_buffer_slice = slice(0, num_post_wrap_steps)
post_wrap_path_slice = slice(num_pre_wrap_steps, path_len)
for buffer_slice, path_slice in [
(pre_wrap_buffer_slice, pre_wrap_path_slice),
(post_wrap_buffer_slice, post_wrap_path_slice),
]:
self._actions[buffer_slice] = actions[path_slice]
self._terminals[buffer_slice] = terminals[path_slice]
for key in self.ob_keys_to_save + self.internal_keys:
self._obs[key][buffer_slice] = obs[key][path_slice]
self._next_obs[key][buffer_slice] = next_obs[key][path_slice]
# Pointers from before the wrap
for i in range(self._top, self.max_size):
self._idx_to_future_obs_idx[i] = np.hstack((
# Pre-wrap indices
np.arange(i, self.max_size),
# Post-wrap indices
np.arange(0, num_post_wrap_steps)
))
# Pointers after the wrap
for i in range(0, num_post_wrap_steps):
self._idx_to_future_obs_idx[i] = np.arange(
i,
num_post_wrap_steps,
)
else:
slc = np.s_[self._top:self._top + path_len, :]
self._actions[slc] = actions
self._terminals[slc] = terminals
for key in self.ob_keys_to_save + self.internal_keys:
self._obs[key][slc] = obs[key]
self._next_obs[key][slc] = next_obs[key]
for i in range(self._top, self._top + path_len):
self._idx_to_future_obs_idx[i] = np.arange(
i, self._top + path_len
)
self._top = (self._top + path_len) % self.max_size
self._size = min(self._size + path_len, self.max_size)
def _sample_indices(self, batch_size):
return np.random.randint(0, self._size, batch_size)
def random_batch(self, batch_size):
indices = self._sample_indices(batch_size)
resampled_goals = self._next_obs[self.desired_goal_key][indices]
num_env_goals = int(batch_size * self.fraction_goals_env_goals)
num_rollout_goals = int(batch_size * self.fraction_goals_rollout_goals)
num_future_goals = batch_size - (num_env_goals + num_rollout_goals)
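        # Hedged example: with batch_size=100, fraction_goals_rollout_goals=0.2
        # and fraction_goals_env_goals=0.5, the first 20 samples keep their
        # original rollout goals, the next 50 get freshly sampled env goals,
        # and the last 30 are relabeled with "future" achieved goals,
        # HER-style.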
new_obs_dict = self._batch_obs_dict(indices)
new_next_obs_dict = self._batch_next_obs_dict(indices)
if num_env_goals > 0:
env_goals = self.env.sample_goals(num_env_goals)
env_goals = preprocess_obs_dict(env_goals)
last_env_goal_idx = num_rollout_goals + num_env_goals
resampled_goals[num_rollout_goals:last_env_goal_idx] = (
env_goals[self.desired_goal_key]
)
for goal_key in self.goal_keys:
new_obs_dict[goal_key][num_rollout_goals:last_env_goal_idx] = \
env_goals[goal_key]
new_next_obs_dict[goal_key][
num_rollout_goals:last_env_goal_idx] = \
env_goals[goal_key]
if num_future_goals > 0:
future_obs_idxs = []
for i in indices[-num_future_goals:]:
possible_future_obs_idxs = self._idx_to_future_obs_idx[i]
# This is generally faster than random.choice. Makes you wonder what
# random.choice is doing
num_options = len(possible_future_obs_idxs)
next_obs_i = int(np.random.randint(0, num_options))
future_obs_idxs.append(possible_future_obs_idxs[next_obs_i])
future_obs_idxs = np.array(future_obs_idxs)
resampled_goals[-num_future_goals:] = self._next_obs[
self.achieved_goal_key
][future_obs_idxs]
for goal_key in self.goal_keys:
new_obs_dict[goal_key][-num_future_goals:] = \
self._next_obs[goal_key][future_obs_idxs]
new_next_obs_dict[goal_key][-num_future_goals:] = \
self._next_obs[goal_key][future_obs_idxs]
new_obs_dict[self.desired_goal_key] = resampled_goals
new_next_obs_dict[self.desired_goal_key] = resampled_goals
new_obs_dict = postprocess_obs_dict(new_obs_dict)
new_next_obs_dict = postprocess_obs_dict(new_next_obs_dict)
# resampled_goals must be postprocessed as well
resampled_goals = new_next_obs_dict[self.desired_goal_key]
new_actions = self._actions[indices]
"""
For example, the environments in this repo have batch-wise
implementations of computing rewards:
https://github.com/vitchyr/multiworld
"""
if hasattr(self.env, 'compute_rewards'):
new_rewards = self.env.compute_rewards(
new_actions,
new_next_obs_dict,
)
else: # Assuming it's a (possibly wrapped) gym GoalEnv
new_rewards = np.ones((batch_size, 1))
for i in range(batch_size):
new_rewards[i] = self.env.compute_reward(
new_next_obs_dict[self.achieved_goal_key][i],
new_next_obs_dict[self.desired_goal_key][i],
None
)
new_rewards = new_rewards.reshape(-1, 1)
new_obs = new_obs_dict[self.observation_key]
new_next_obs = new_next_obs_dict[self.observation_key]
batch = {
'observations': new_obs,
'actions': new_actions,
'rewards': new_rewards,
'terminals': self._terminals[indices],
'next_observations': new_next_obs,
'resampled_goals': resampled_goals,
'indices': np.array(indices).reshape(-1, 1),
}
return batch
def _batch_obs_dict(self, indices):
return {
key: self._obs[key][indices]
for key in self.ob_keys_to_save
}
def _batch_next_obs_dict(self, indices):
return {
key: self._next_obs[key][indices]
for key in self.ob_keys_to_save
}
def flatten_n(xs):
xs = np.asarray(xs)
return xs.reshape((xs.shape[0], -1))
def flatten_dict(dicts, keys):
"""
Turns list of dicts into dict of np arrays
"""
return {
key: flatten_n([d[key] for d in dicts])
for key in keys
}
def preprocess_obs_dict(obs_dict):
"""
Apply internal replay buffer representation changes: save images as bytes
"""
for obs_key, obs in obs_dict.items():
if 'image' in obs_key and obs is not None:
obs_dict[obs_key] = unnormalize_image(obs)
return obs_dict
def postprocess_obs_dict(obs_dict):
"""
    Undo internal replay buffer representation changes: convert image bytes
    back to normalized floats in [0, 1]
"""
for obs_key, obs in obs_dict.items():
if 'image' in obs_key and obs is not None:
obs_dict[obs_key] = normalize_image(obs)
return obs_dict
def normalize_image(image):
assert image.dtype == np.uint8
return np.float64(image) / 255.0
def unnormalize_image(image):
assert image.dtype != np.uint8
return np.uint8(image * 255.0)
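if __name__ == '__main__':
    # Hedged self-check sketch (not part of the original module): shows how
    # flatten_dict collapses a list of per-step observation dicts into a
    # dict of (path_len, dim) arrays, the layout add_path stores internally,
    # and checks the uint8 <-> float image round trip used by
    # preprocess_obs_dict / postprocess_obs_dict.
    path_obs = [
        {'observation': np.arange(3.), 'desired_goal': np.zeros(2)},
        {'observation': np.arange(3.) + 3, 'desired_goal': np.ones(2)},
    ]
    flat = flatten_dict(path_obs, ['observation', 'desired_goal'])
    print(flat['observation'].shape)   # (2, 3)
    print(flat['desired_goal'].shape)  # (2, 2)

    img = np.random.RandomState(0).randint(0, 256, size=12).astype(np.uint8)
    restored = unnormalize_image(normalize_image(img))
    # Float truncation can shift a pixel value by at most one count.
    assert np.max(np.abs(restored.astype(int) - img.astype(int))) <= 1
    print('image round trip ok:', restored.dtype)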