import numpy as np
from gym.spaces import Dict, Discrete

from rlkit.data_management.replay_buffer import ReplayBuffer


class ObsDictRelabelingBuffer(ReplayBuffer):
    """
    Replay buffer for environments whose observations are dictionaries, such as
        - OpenAI Gym GoalEnv environments:
          https://blog.openai.com/ingredients-for-robotics-research/
        - multiworld MultitaskEnv environments:
          https://github.com/vitchyr/multiworld/

    Implementation details:
        - Only add_path is implemented; add_sample raises NotImplementedError.
        - Image observations are assumed to have keys starting with 'image'.
        - Every index in [0, self._size) is valid to sample.
        - Observations and next observations are saved separately. Storing
          the observations twice is memory-inefficient, but it makes the code
          *much* simpler, since you no longer have to worry about termination
          conditions.
    """

    def __init__(
            self,
            max_size,
            env,
            fraction_goals_rollout_goals=1.0,
            fraction_goals_env_goals=0.0,
            internal_keys=None,
            goal_keys=None,
            observation_key='observation',
            desired_goal_key='desired_goal',
            achieved_goal_key='achieved_goal',
    ):
        if internal_keys is None:
            internal_keys = []
        self.internal_keys = internal_keys
        if goal_keys is None:
            goal_keys = []
        if desired_goal_key not in goal_keys:
            goal_keys.append(desired_goal_key)
        self.goal_keys = goal_keys
        assert isinstance(env.observation_space, Dict)
        assert 0 <= fraction_goals_rollout_goals
        assert 0 <= fraction_goals_env_goals
        assert fraction_goals_rollout_goals + fraction_goals_env_goals <= 1
        self.max_size = max_size
        self.env = env
        self.fraction_goals_rollout_goals = fraction_goals_rollout_goals
        self.fraction_goals_env_goals = fraction_goals_env_goals
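        # Any remaining fraction, 1 - (rollout + env), is filled with
        # "future" goals: achieved goals from later in the same trajectory,
        # as in hindsight experience replay. See random_batch below.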
        self.ob_keys_to_save = [
            observation_key,
            desired_goal_key,
            achieved_goal_key,
        ]
        self.observation_key = observation_key
        self.desired_goal_key = desired_goal_key
        self.achieved_goal_key = achieved_goal_key
        if isinstance(self.env.action_space, Discrete):
            self._action_dim = env.action_space.n
        else:
            self._action_dim = env.action_space.low.size
        self._actions = np.zeros((max_size, self._action_dim))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_size, 1), dtype='uint8')
        # self._obs[key][i] is the value of observation[key] at time i
        self._obs = {}
        self._next_obs = {}
        self.ob_spaces = self.env.observation_space.spaces
        for key in self.ob_keys_to_save + internal_keys:
            assert key in self.ob_spaces, \
                "Key not found in the observation space: %s" % key
            # Images are stored as uint8 bytes to save memory; everything
            # else is stored as float64.
            dtype = np.uint8 if key.startswith('image') else np.float64
            self._obs[key] = np.zeros(
                (max_size, self.ob_spaces[key].low.size), dtype=dtype)
            self._next_obs[key] = np.zeros(
                (max_size, self.ob_spaces[key].low.size), dtype=dtype)
        self._top = 0
        self._size = 0
        # Let j be any index in self._idx_to_future_obs_idx[i].
        # Then self._next_obs[j] is a valid next observation for observation i.
        self._idx_to_future_obs_idx = [None] * max_size
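        # Illustration: if a 3-step path is stored at indices 4..6, then
        # self._idx_to_future_obs_idx[4] == np.arange(4, 7), i.e. any of
        # next_obs[4..6] may be relabeled as a future goal for step 4.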

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        raise NotImplementedError("Only use add_path")

    def terminate_episode(self):
        pass

    def num_steps_can_sample(self):
        return self._size

    def add_path(self, path):
        obs = path["observations"]
        actions = path["actions"]
        rewards = path["rewards"]
        next_obs = path["next_observations"]
        terminals = path["terminals"]
        path_len = len(rewards)

        actions = flatten_n(actions)
        if isinstance(self.env.action_space, Discrete):
            # One-hot encode discrete actions.
            actions = np.eye(self._action_dim)[actions].reshape(
                (-1, self._action_dim))
        obs = flatten_dict(obs, self.ob_keys_to_save + self.internal_keys)
        next_obs = flatten_dict(next_obs,
                                self.ob_keys_to_save + self.internal_keys)
        obs = preprocess_obs_dict(obs)
        next_obs = preprocess_obs_dict(next_obs)
        if self._top + path_len >= self.max_size:
            """
            All of this logic is to handle wrapping the pointer when the
            replay buffer gets full.
            """
            num_pre_wrap_steps = self.max_size - self._top
            # numpy slices
            pre_wrap_buffer_slice = np.s_[
                self._top:self._top + num_pre_wrap_steps, :
            ]
            pre_wrap_path_slice = np.s_[0:num_pre_wrap_steps, :]
            num_post_wrap_steps = path_len - num_pre_wrap_steps
            post_wrap_buffer_slice = slice(0, num_post_wrap_steps)
            post_wrap_path_slice = slice(num_pre_wrap_steps, path_len)
            for buffer_slice, path_slice in [
                (pre_wrap_buffer_slice, pre_wrap_path_slice),
                (post_wrap_buffer_slice, post_wrap_path_slice),
            ]:
                self._actions[buffer_slice] = actions[path_slice]
                self._terminals[buffer_slice] = terminals[path_slice]
                for key in self.ob_keys_to_save + self.internal_keys:
                    self._obs[key][buffer_slice] = obs[key][path_slice]
                    self._next_obs[key][buffer_slice] = next_obs[key][path_slice]
            # Pointers from before the wrap
            for i in range(self._top, self.max_size):
                self._idx_to_future_obs_idx[i] = np.hstack((
                    # Pre-wrap indices
                    np.arange(i, self.max_size),
                    # Post-wrap indices
                    np.arange(0, num_post_wrap_steps)
                ))
            # Pointers from after the wrap
            for i in range(0, num_post_wrap_steps):
                self._idx_to_future_obs_idx[i] = np.arange(
                    i,
                    num_post_wrap_steps,
                )
        else:
            slc = np.s_[self._top:self._top + path_len, :]
            self._actions[slc] = actions
            self._terminals[slc] = terminals
            for key in self.ob_keys_to_save + self.internal_keys:
                self._obs[key][slc] = obs[key]
                self._next_obs[key][slc] = next_obs[key]
            for i in range(self._top, self._top + path_len):
                self._idx_to_future_obs_idx[i] = np.arange(
                    i, self._top + path_len
                )
        self._top = (self._top + path_len) % self.max_size
        self._size = min(self._size + path_len, self.max_size)

    def _sample_indices(self, batch_size):
        return np.random.randint(0, self._size, batch_size)

    def random_batch(self, batch_size):
        indices = self._sample_indices(batch_size)
        resampled_goals = self._next_obs[self.desired_goal_key][indices]

        num_env_goals = int(batch_size * self.fraction_goals_env_goals)
        num_rollout_goals = int(batch_size * self.fraction_goals_rollout_goals)
        num_future_goals = batch_size - (num_env_goals + num_rollout_goals)
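        # Goal relabeling: the first num_rollout_goals samples keep their
        # original rollout goals, the next num_env_goals get goals sampled
        # from the environment, and the remaining num_future_goals get
        # "future" achieved goals from the same trajectory (HER-style).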
        new_obs_dict = self._batch_obs_dict(indices)
        new_next_obs_dict = self._batch_next_obs_dict(indices)

        if num_env_goals > 0:
            env_goals = self.env.sample_goals(num_env_goals)
            env_goals = preprocess_obs_dict(env_goals)
            last_env_goal_idx = num_rollout_goals + num_env_goals
            resampled_goals[num_rollout_goals:last_env_goal_idx] = (
                env_goals[self.desired_goal_key]
            )
            for goal_key in self.goal_keys:
                new_obs_dict[goal_key][num_rollout_goals:last_env_goal_idx] = \
                    env_goals[goal_key]
                new_next_obs_dict[goal_key][num_rollout_goals:last_env_goal_idx] = \
                    env_goals[goal_key]
        if num_future_goals > 0:
            future_obs_idxs = []
            for i in indices[-num_future_goals:]:
                possible_future_obs_idxs = self._idx_to_future_obs_idx[i]
                # This is generally faster than np.random.choice. Makes you
                # wonder what np.random.choice is doing.
                num_options = len(possible_future_obs_idxs)
                next_obs_i = int(np.random.randint(0, num_options))
                future_obs_idxs.append(possible_future_obs_idxs[next_obs_i])
            future_obs_idxs = np.array(future_obs_idxs)
            resampled_goals[-num_future_goals:] = self._next_obs[
                self.achieved_goal_key
            ][future_obs_idxs]
            for goal_key in self.goal_keys:
                new_obs_dict[goal_key][-num_future_goals:] = \
                    self._next_obs[goal_key][future_obs_idxs]
                new_next_obs_dict[goal_key][-num_future_goals:] = \
                    self._next_obs[goal_key][future_obs_idxs]

        new_obs_dict[self.desired_goal_key] = resampled_goals
        new_next_obs_dict[self.desired_goal_key] = resampled_goals
        new_obs_dict = postprocess_obs_dict(new_obs_dict)
        new_next_obs_dict = postprocess_obs_dict(new_next_obs_dict)
        # resampled_goals must be postprocessed as well
        resampled_goals = new_next_obs_dict[self.desired_goal_key]
        new_actions = self._actions[indices]
        """
        For example, the environments in the multiworld repo have batch-wise
        implementations of computing rewards:
        https://github.com/vitchyr/multiworld
        """
        if hasattr(self.env, 'compute_rewards'):
            new_rewards = self.env.compute_rewards(
                new_actions,
                new_next_obs_dict,
            )
        else:  # Assume it's a (possibly wrapped) gym GoalEnv.
            new_rewards = np.ones((batch_size, 1))
            for i in range(batch_size):
                new_rewards[i] = self.env.compute_reward(
                    new_next_obs_dict[self.achieved_goal_key][i],
                    new_next_obs_dict[self.desired_goal_key][i],
                    None
                )
        new_rewards = new_rewards.reshape(-1, 1)
        new_obs = new_obs_dict[self.observation_key]
        new_next_obs = new_next_obs_dict[self.observation_key]
        batch = {
            'observations': new_obs,
            'actions': new_actions,
            'rewards': new_rewards,
            'terminals': self._terminals[indices],
            'next_observations': new_next_obs,
            'resampled_goals': resampled_goals,
            'indices': np.array(indices).reshape(-1, 1),
        }
        return batch

    def _batch_obs_dict(self, indices):
        return {
            key: self._obs[key][indices]
            for key in self.ob_keys_to_save
        }

    def _batch_next_obs_dict(self, indices):
        return {
            key: self._next_obs[key][indices]
            for key in self.ob_keys_to_save
        }


def flatten_n(xs):
    xs = np.asarray(xs)
    return xs.reshape((xs.shape[0], -1))


def flatten_dict(dicts, keys):
    """
    Turns a list of dicts into a dict of np arrays.
    """
    return {
        key: flatten_n([d[key] for d in dicts])
        for key in keys
    }


def preprocess_obs_dict(obs_dict):
    """
    Apply internal replay buffer representation changes: store images as
    uint8 bytes.
    """
    for obs_key, obs in obs_dict.items():
        if 'image' in obs_key and obs is not None:
            obs_dict[obs_key] = unnormalize_image(obs)
    return obs_dict


def postprocess_obs_dict(obs_dict):
    """
    Undo internal replay buffer representation changes: convert images back
    to floats in [0, 1].
    """
    for obs_key, obs in obs_dict.items():
        if 'image' in obs_key and obs is not None:
            obs_dict[obs_key] = normalize_image(obs)
    return obs_dict


def normalize_image(image):
    assert image.dtype == np.uint8
    return np.float64(image) / 255.0


def unnormalize_image(image):
    assert image.dtype != np.uint8
    return np.uint8(image * 255.0)
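

# Usage sketch (illustrative only): how this buffer is typically wired up,
# assuming a multiworld-style env with a Dict observation space; the env
# constructor below is hypothetical. With these fractions, 20% of sampled
# goals are the original rollout goals and the remaining 80% are relabeled
# with future achieved goals from the same trajectory.
#
#     env = make_my_goal_env()  # hypothetical Dict-observation GoalEnv
#     buffer = ObsDictRelabelingBuffer(
#         max_size=int(1e6),
#         env=env,
#         fraction_goals_rollout_goals=0.2,
#         fraction_goals_env_goals=0.0,
#         observation_key='observation',
#         desired_goal_key='desired_goal',
#         achieved_goal_key='achieved_goal',
#     )
#     buffer.add_path(path)             # path: dict of per-step lists/arrays
#     batch = buffer.random_batch(128)  # relabeled transitions for training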