# DIPO/agent/helpers.py
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_weights(m):
    """Initialize an ``nn.Linear`` module in place (for ``model.apply``).

    Weights are drawn from a truncated normal — entries are resampled until
    all lie within 2 standard deviations of the mean — with
    std = 1 / (2 * sqrt(in_features)); biases are zeroed.  Modules that are
    not ``nn.Linear`` are left untouched.
    """
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Sample N(mean, std) in place, then resample any entries outside
        # [mean - 2*std, mean + 2*std] until none remain.
        torch.nn.init.normal_(t, mean=mean, std=std)
        with torch.no_grad():  # in-place edits of a leaf parameter need no_grad
            while True:
                cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
                if not torch.sum(cond):
                    break
                # Bug fixes vs. the original:
                # 1) resample with empty_like(t) so the draw lives on t's
                #    device/dtype (torch.ones(t.shape) was CPU-only);
                # 2) write back with masked assignment — the original rebound
                #    a local (`t = torch.where(...)`), so resampled values
                #    never reached the actual parameter tensor.
                fresh = torch.nn.init.normal_(torch.empty_like(t), mean=mean, std=std)
                t[cond] = fresh[cond]
        return t

    # isinstance (rather than `type(m) == nn.Linear`) also covers subclasses.
    if isinstance(m, nn.Linear):
        input_dim = m.in_features
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(input_dim)))
        m.bias.data.fill_(0.0)
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding ("Attention Is All You Need" style).

    Maps a 1-D tensor of positions/timesteps of shape (batch,) to a tensor of
    shape (batch, dim): the first dim//2 channels are sines and the remaining
    channels cosines, over geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Frequencies 10000^(-k / (half - 1)) for k = 0 .. half-1.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
#-----------------------------------------------------------------------------#
#---------------------------------- sampling ---------------------------------#
#-----------------------------------------------------------------------------#
def extract(a, t, x_shape):
    """Gather ``a`` at integer indices ``t`` and reshape for broadcasting.

    Returns a tensor of shape (batch, 1, ..., 1) with ``len(x_shape)``
    dimensions, so the result broadcasts against tensors of shape ``x_shape``.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
    """Cosine variance schedule for diffusion models.

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    (Nichol & Dhariwal, "Improved Denoising Diffusion Probabilistic Models").
    Betas are clipped to [0, 0.999] for numerical stability near the end.
    """
    steps = timesteps + 1
    grid = np.linspace(0, steps, steps)
    # Squared-cosine cumulative alphas, normalized to start at 1.
    alphas_bar = np.cos(((grid / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return torch.tensor(np.clip(betas, a_min=0, a_max=0.999), dtype=dtype)
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2, dtype=torch.float32):
    """Linearly spaced betas from ``beta_start`` to ``beta_end`` (inclusive)."""
    schedule = np.linspace(beta_start, beta_end, timesteps)
    return torch.tensor(schedule, dtype=dtype)
def vp_beta_schedule(timesteps, dtype=torch.float32):
    """Discrete beta schedule for the VP SDE (variance-preserving diffusion)."""
    T = timesteps
    b_min, b_max = 0.1, 10.0
    t = np.arange(1, T + 1)
    # log alpha_t for the discretized VP SDE; beta_t = 1 - alpha_t.
    log_alpha = -b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2
    betas = 1.0 - np.exp(log_alpha)
    return torch.tensor(betas, dtype=dtype)
#-----------------------------------------------------------------------------#
#---------------------------------- losses -----------------------------------#
#-----------------------------------------------------------------------------#
class WeightedLoss(nn.Module):
    """Base class for weighted elementwise losses.

    Subclasses implement ``_loss(pred, targ)`` returning an unreduced
    elementwise loss; ``forward`` scales it by ``weights`` and takes the mean.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, targ, weights=1.0):
        """
        pred, targ : tensor [ batch_size x action_dim ]
        """
        elementwise = self._loss(pred, targ)
        return (elementwise * weights).mean()
class WeightedL1(WeightedLoss):
    # Elementwise absolute error (unreduced L1).
    def _loss(self, pred, targ):
        return (pred - targ).abs()
class WeightedL2(WeightedLoss):
    # Elementwise squared error; identical to F.mse_loss(..., reduction='none').
    def _loss(self, pred, targ):
        return (pred - targ) ** 2
# Registry mapping a loss-name string to its weighted-loss class.
Losses = {
    'l1': WeightedL1,
    'l2': WeightedL2,
}
class EMA():
    """Exponential moving average of model parameters.

    (The original docstring said "empirical"; the update rule below is the
    standard exponential moving average with decay ``beta``.)
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta  # decay factor: weight kept by the running average

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``ma_model`` toward ``current_model`` in place."""
        for src_p, avg_p in zip(current_model.parameters(), ma_model.parameters()):
            avg_p.data = self.update_average(avg_p.data, src_p.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``; pass ``new`` through if ``old`` is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new