bglick13 commited on
Commit
c79138b
·
1 Parent(s): b4509d0

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +11 -2
pipeline.py CHANGED
@@ -14,8 +14,17 @@ class ValueGuidedDiffuserPipeline(DiffusionPipeline):
14
  self.scheduler = scheduler
15
  self.env = env
16
  self.data = env.get_dataset()
17
- self.means = dict((key, val.mean(axis=0)) for key, val in self.data.items())
18
- self.stds = dict((key, val.std(axis=0)) for key, val in self.data.items())
 
 
 
 
 
 
 
 
 
19
  self.device = self.unet.device
20
  self.state_dim = env.observation_space.shape[0]
21
  self.action_dim = env.action_space.shape[0]
 
14
  self.scheduler = scheduler
15
  self.env = env
16
  self.data = env.get_dataset()
17
+ for key in self.data.keys():
18
+ try:
19
+ self.means[key] = self.data[key].mean()
20
+ except AxisError:
21
+ pass
22
+ self.stds = dict()
23
+ for key in self.data.keys():
24
+ try:
25
+ self.stds[key] = self.data[key].std()
26
+ except AxisError:
27
+ pass
28
  self.device = self.unet.device
29
  self.state_dim = env.observation_space.shape[0]
30
  self.action_dim = env.action_space.shape[0]