# NCERL-Diverse-PCG/src/rlkit/data_management/obs_dict_replay_buffer.py
import numpy as np
from gym.spaces import Dict, Discrete
from rlkit.data_management.replay_buffer import ReplayBuffer
class ObsDictRelabelingBuffer(ReplayBuffer):
"""
Replay buffer for environments whose observations are dictionaries, such as
- OpenAI Gym GoalEnv environments. https://blog.openai.com/ingredients-for-robotics-research/
- multiworld MultitaskEnv. https://github.com/vitchyr/multiworld/
Implementation details:
- Only add_path is implemented.
    - Image observations are presumed to start with the 'image' prefix and
      are stored internally as uint8.
    - Every index in [0, self._size) is valid to sample.
    - Observations and next observations are saved separately. It is
      memory inefficient to save the observations twice, but it makes the
      code *much* easier since you no longer have to worry about
      termination conditions.
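
    Example (a sketch; make_env and path are placeholders for a
    multiworld-style environment with a Dict observation space and a
    rollout collected from it):

        env = make_env()
        buffer = ObsDictRelabelingBuffer(
            max_size=int(1e6),
            env=env,
            fraction_goals_rollout_goals=0.2,
            fraction_goals_env_goals=0.0,
        )
        buffer.add_path(path)  # path: dict of per-step sequences, see add_path
        batch = buffer.random_batch(128)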
"""
def __init__(
self,
max_size,
env,
fraction_goals_rollout_goals=1.0,
fraction_goals_env_goals=0.0,
internal_keys=None,
goal_keys=None,
observation_key='observation',
desired_goal_key='desired_goal',
achieved_goal_key='achieved_goal',
):
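        """
        :param max_size: Maximum number of transitions stored in the buffer.
        :param env: Environment with a Dict observation space. It is also
            used to sample new goals and to recompute rewards when
            relabeling.
        :param fraction_goals_rollout_goals: Fraction of sampled goals that
            keep the goal observed during the rollout (no relabeling).
        :param fraction_goals_env_goals: Fraction of sampled goals replaced
            by goals drawn from env.sample_goals. The remainder is relabeled
            with "future" achieved goals from the same episode (hindsight
            experience replay).
        :param internal_keys: Extra observation keys to store in addition to
            the observation/goal keys.
        :param goal_keys: Observation keys to overwrite when a goal is
            relabeled; desired_goal_key is always included.
        :param observation_key: Key for the flat observation.
        :param desired_goal_key: Key for the goal the agent is asked to reach.
        :param achieved_goal_key: Key for the goal actually achieved.
        """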
if internal_keys is None:
internal_keys = []
self.internal_keys = internal_keys
if goal_keys is None:
goal_keys = []
if desired_goal_key not in goal_keys:
goal_keys.append(desired_goal_key)
self.goal_keys = goal_keys
assert isinstance(env.observation_space, Dict)
assert 0 <= fraction_goals_rollout_goals
assert 0 <= fraction_goals_env_goals
assert 0 <= fraction_goals_rollout_goals + fraction_goals_env_goals
assert fraction_goals_rollout_goals + fraction_goals_env_goals <= 1
self.max_size = max_size
self.env = env
self.fraction_goals_rollout_goals = fraction_goals_rollout_goals
self.fraction_goals_env_goals = fraction_goals_env_goals
self.ob_keys_to_save = [
observation_key,
desired_goal_key,
achieved_goal_key,
]
self.observation_key = observation_key
self.desired_goal_key = desired_goal_key
self.achieved_goal_key = achieved_goal_key
if isinstance(self.env.action_space, Discrete):
self._action_dim = env.action_space.n
else:
self._action_dim = env.action_space.low.size
self._actions = np.zeros((max_size, self._action_dim))
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_size, 1), dtype='uint8')
# self._obs[key][i] is the value of observation[key] at time i
self._obs = {}
self._next_obs = {}
self.ob_spaces = self.env.observation_space.spaces
for key in self.ob_keys_to_save + internal_keys:
assert key in self.ob_spaces, \
"Key not found in the observation space: %s" % key
            dtype = np.float64
            if key.startswith('image'):
                dtype = np.uint8
            self._obs[key] = np.zeros(
                (max_size, self.ob_spaces[key].low.size), dtype=dtype)
            self._next_obs[key] = np.zeros(
                (max_size, self.ob_spaces[key].low.size), dtype=dtype)
self._top = 0
self._size = 0
# Let j be any index in self._idx_to_future_obs_idx[i]
# Then self._next_obs[j] is a valid next observation for observation i
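        # e.g. for a 5-step episode stored without wrapping at indices 3..7,
        # self._idx_to_future_obs_idx[5] == np.arange(5, 8)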
self._idx_to_future_obs_idx = [None] * max_size
def add_sample(self, observation, action, reward, terminal,
next_observation, **kwargs):
raise NotImplementedError("Only use add_path")
def terminate_episode(self):
pass
def num_steps_can_sample(self):
return self._size
def add_path(self, path):
obs = path["observations"]
actions = path["actions"]
rewards = path["rewards"]
next_obs = path["next_observations"]
terminals = path["terminals"]
path_len = len(rewards)
actions = flatten_n(actions)
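        # Discrete actions are stored one-hot, e.g. action 2 in a 4-action
        # space becomes [0, 0, 1, 0].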
if isinstance(self.env.action_space, Discrete):
actions = np.eye(self._action_dim)[actions].reshape((-1, self._action_dim))
obs = flatten_dict(obs, self.ob_keys_to_save + self.internal_keys)
next_obs = flatten_dict(next_obs, self.ob_keys_to_save + self.internal_keys)
obs = preprocess_obs_dict(obs)
next_obs = preprocess_obs_dict(next_obs)
if self._top + path_len >= self.max_size:
"""
All of this logic is to handle wrapping the pointer when the
replay buffer gets full.
"""
num_pre_wrap_steps = self.max_size - self._top
# numpy slice
pre_wrap_buffer_slice = np.s_[
self._top:self._top + num_pre_wrap_steps, :
]
pre_wrap_path_slice = np.s_[0:num_pre_wrap_steps, :]
num_post_wrap_steps = path_len - num_pre_wrap_steps
post_wrap_buffer_slice = slice(0, num_post_wrap_steps)
post_wrap_path_slice = slice(num_pre_wrap_steps, path_len)
for buffer_slice, path_slice in [
(pre_wrap_buffer_slice, pre_wrap_path_slice),
(post_wrap_buffer_slice, post_wrap_path_slice),
]:
self._actions[buffer_slice] = actions[path_slice]
self._terminals[buffer_slice] = terminals[path_slice]
for key in self.ob_keys_to_save + self.internal_keys:
self._obs[key][buffer_slice] = obs[key][path_slice]
self._next_obs[key][buffer_slice] = next_obs[key][path_slice]
# Pointers from before the wrap
for i in range(self._top, self.max_size):
self._idx_to_future_obs_idx[i] = np.hstack((
# Pre-wrap indices
np.arange(i, self.max_size),
# Post-wrap indices
np.arange(0, num_post_wrap_steps)
))
# Pointers after the wrap
for i in range(0, num_post_wrap_steps):
self._idx_to_future_obs_idx[i] = np.arange(
i,
num_post_wrap_steps,
)
else:
slc = np.s_[self._top:self._top + path_len, :]
self._actions[slc] = actions
self._terminals[slc] = terminals
for key in self.ob_keys_to_save + self.internal_keys:
self._obs[key][slc] = obs[key]
self._next_obs[key][slc] = next_obs[key]
for i in range(self._top, self._top + path_len):
self._idx_to_future_obs_idx[i] = np.arange(
i, self._top + path_len
)
self._top = (self._top + path_len) % self.max_size
self._size = min(self._size + path_len, self.max_size)
def _sample_indices(self, batch_size):
return np.random.randint(0, self._size, batch_size)
def random_batch(self, batch_size):
indices = self._sample_indices(batch_size)
resampled_goals = self._next_obs[self.desired_goal_key][indices]
num_env_goals = int(batch_size * self.fraction_goals_env_goals)
num_rollout_goals = int(batch_size * self.fraction_goals_rollout_goals)
num_future_goals = batch_size - (num_env_goals + num_rollout_goals)
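        # e.g. batch_size = 128, fraction_goals_rollout_goals = 0.2 and
        # fraction_goals_env_goals = 0.5 give 25 unchanged rollout goals,
        # 64 env-sampled goals, and 39 "future" (HER) goals.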
new_obs_dict = self._batch_obs_dict(indices)
new_next_obs_dict = self._batch_next_obs_dict(indices)
if num_env_goals > 0:
env_goals = self.env.sample_goals(num_env_goals)
env_goals = preprocess_obs_dict(env_goals)
last_env_goal_idx = num_rollout_goals + num_env_goals
resampled_goals[num_rollout_goals:last_env_goal_idx] = (
env_goals[self.desired_goal_key]
)
for goal_key in self.goal_keys:
new_obs_dict[goal_key][num_rollout_goals:last_env_goal_idx] = \
env_goals[goal_key]
new_next_obs_dict[goal_key][
num_rollout_goals:last_env_goal_idx] = \
env_goals[goal_key]
if num_future_goals > 0:
future_obs_idxs = []
for i in indices[-num_future_goals:]:
possible_future_obs_idxs = self._idx_to_future_obs_idx[i]
                # np.random.randint is generally faster here than
                # np.random.choice.
num_options = len(possible_future_obs_idxs)
next_obs_i = int(np.random.randint(0, num_options))
future_obs_idxs.append(possible_future_obs_idxs[next_obs_i])
future_obs_idxs = np.array(future_obs_idxs)
resampled_goals[-num_future_goals:] = self._next_obs[
self.achieved_goal_key
][future_obs_idxs]
for goal_key in self.goal_keys:
new_obs_dict[goal_key][-num_future_goals:] = \
self._next_obs[goal_key][future_obs_idxs]
new_next_obs_dict[goal_key][-num_future_goals:] = \
self._next_obs[goal_key][future_obs_idxs]
new_obs_dict[self.desired_goal_key] = resampled_goals
new_next_obs_dict[self.desired_goal_key] = resampled_goals
new_obs_dict = postprocess_obs_dict(new_obs_dict)
new_next_obs_dict = postprocess_obs_dict(new_next_obs_dict)
# resampled_goals must be postprocessed as well
resampled_goals = new_next_obs_dict[self.desired_goal_key]
new_actions = self._actions[indices]
"""
For example, the environments in this repo have batch-wise
implementations of computing rewards:
https://github.com/vitchyr/multiworld
"""
if hasattr(self.env, 'compute_rewards'):
new_rewards = self.env.compute_rewards(
new_actions,
new_next_obs_dict,
)
else: # Assuming it's a (possibly wrapped) gym GoalEnv
new_rewards = np.ones((batch_size, 1))
for i in range(batch_size):
new_rewards[i] = self.env.compute_reward(
new_next_obs_dict[self.achieved_goal_key][i],
new_next_obs_dict[self.desired_goal_key][i],
None
)
new_rewards = new_rewards.reshape(-1, 1)
new_obs = new_obs_dict[self.observation_key]
new_next_obs = new_next_obs_dict[self.observation_key]
batch = {
'observations': new_obs,
'actions': new_actions,
'rewards': new_rewards,
'terminals': self._terminals[indices],
'next_observations': new_next_obs,
'resampled_goals': resampled_goals,
'indices': np.array(indices).reshape(-1, 1),
}
return batch
def _batch_obs_dict(self, indices):
return {
key: self._obs[key][indices]
for key in self.ob_keys_to_save
}
def _batch_next_obs_dict(self, indices):
return {
key: self._next_obs[key][indices]
for key in self.ob_keys_to_save
}
def flatten_n(xs):
xs = np.asarray(xs)
return xs.reshape((xs.shape[0], -1))
def flatten_dict(dicts, keys):
"""
Turns list of dicts into dict of np arrays
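
    e.g. flatten_dict([{'x': np.zeros(3)}, {'x': np.ones(3)}], ['x'])['x']
    has shape (2, 3).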
"""
return {
key: flatten_n([d[key] for d in dicts])
for key in keys
}
def preprocess_obs_dict(obs_dict):
"""
Apply internal replay buffer representation changes: save images as bytes
"""
for obs_key, obs in obs_dict.items():
if 'image' in obs_key and obs is not None:
obs_dict[obs_key] = unnormalize_image(obs)
return obs_dict
def postprocess_obs_dict(obs_dict):
"""
    Undo internal replay buffer representation changes: convert images back
    to floats in [0, 1]
"""
for obs_key, obs in obs_dict.items():
if 'image' in obs_key and obs is not None:
obs_dict[obs_key] = normalize_image(obs)
return obs_dict
def normalize_image(image):
assert image.dtype == np.uint8
return np.float64(image) / 255.0
def unnormalize_image(image):
assert image.dtype != np.uint8
return np.uint8(image * 255.0)