baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
1.95 kB
import torch
import random
from src.utils.datastruct import RingQueue
class ReplayMem:
    """Replay memory holding ``(obs, action, reward, next_obs)`` transitions.

    Transitions live in a ``RingQueue`` built with the requested capacity.
    Observations, actions and next observations are converted to float
    tensors on ``device`` when they are added; rewards are kept as given and
    only turned into a tensor when a batch is sampled.
    """

    def __init__(self, capacity=int(2e5), device='cpu'):
        """Create an empty memory backed by ``RingQueue(capacity)``.

        Args:
            capacity: Size passed to the underlying ``RingQueue``.
            device: Torch device on which stored and sampled tensors are
                created.
        """
        self.device = device
        self.queue = RingQueue(capacity)

    def add(self, o, a, r, op):
        """Store a single transition.

        Args:
            o: Observation; copied into a float tensor on ``self.device``.
            a: Action; copied into a float tensor on ``self.device``.
            r: Reward; stored unchanged.
            op: Next observation; copied into a float tensor on
                ``self.device``.
        """
        obs_t, act_t, next_t = (
            torch.tensor(value, device=self.device, dtype=torch.float)
            for value in (o, a, op)
        )
        self.queue.push((obs_t, act_t, r, next_t))

    def add_transitions(self, transitions):
        """Store every ``(o, a, r, op)`` item of ``transitions`` via ``add``."""
        for item in transitions:
            self.add(*item)

    def add_batched(self, obs, actions, rewards, next_obs):
        """Store transitions given as four parallel iterables.

        The iterables are zipped together, so iteration stops at the
        shortest one.
        """
        for step in zip(obs, actions, rewards, next_obs):
            self.add(*step)

    def sample(self, n):
        """Draw ``n`` stored transitions at random, without replacement.

        The draw is made with ``random.sample`` over ``self.queue.main``.
        NOTE(review): assumes ``queue.main`` is a sequence containing only
        stored transitions -- confirm against ``RingQueue``.

        Returns:
            A tuple ``(obs, acts, rews, ops)``. ``obs``, ``acts`` and ``ops``
            are the sampled tensors stacked along a new leading dimension;
            ``rews`` is a float tensor on ``self.device`` built from the
            sampled rewards.
        """
        batch = random.sample(self.queue.main, n)
        # Each comprehension unpacks exactly four fields per transition.
        obs = torch.stack([o for o, _, _, _ in batch])
        acts = torch.stack([a for _, a, _, _ in batch])
        rews = torch.tensor(
            [r for _, _, r, _ in batch],
            device=self.device,
            dtype=torch.float,
        )
        ops = torch.stack([op for _, _, _, op in batch])
        return obs, acts, rews, ops

    def __len__(self):
        """Return the length reported by the underlying queue."""
        return len(self.queue)

    def clear(self):
        """Clear the underlying queue."""
        self.queue.clear()
if __name__ == '__main__':
    # Script entry point: currently a no-op. The block below is a disabled
    # manual smoke test kept for reference.
    # NOTE(review): the demo is stale as written -- it uses `np`, which is not
    # imported in the visible part of this file, and it calls `sample` with a
    # second argument although `ReplayMem.sample` only accepts `n`.
    # mem = ReplayMem()
    # mem.add_batched(
    # [np.zeros([5]), np.zeros([5]), np.zeros([5]), np.zeros([5])],
    # [np.zeros([5]), np.zeros([5]), np.zeros([5]), np.zeros([5])],
    # [np.zeros([5]), np.zeros([5]), np.zeros([5]), np.zeros([5])],
    # [np.zeros([5]), np.zeros([5]), np.zeros([5]), np.zeros([5])]
    # )
    # mem.sample(2, 'cuda:0')
    # mem.sample(2, 'cuda:0')
    # mem.sample(2, 'cuda:0')
    # mem.sample(2, 'cuda:0')
    # mem.sample(2, 'cuda:0')
    # mem.sample(2, 'cuda:0')
    pass