Spaces:
Sleeping
Sleeping
File size: 798 Bytes
44db2f9 c10a05f 44db2f9 c412087 c10a05f c412087 c10a05f 44db2f9 c412087 44db2f9 c412087 44db2f9 c412087 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
"""
Shared optimizer: the optimizer's state tensors are placed in shared memory
so that they can be shared across multiple processes.
"""
import torch
class SharedAdam(torch.optim.Adam):
    """Adam optimizer whose per-parameter state lives in shared memory.

    ``torch.optim.Adam`` normally creates its state lazily on the first
    ``step()``. This subclass creates it eagerly in ``__init__`` and calls
    ``share_memory_()`` on every state tensor. Processes that receive this
    optimizer after construction (e.g. via ``torch.multiprocessing``) therefore
    update the same moment estimates and step counter. No locking is done here.

    Args:
        params: iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Adam``.
        lr: learning rate.
        betas: coefficients for the running averages of the gradient and of
            its square.
        eps: term added to the denominator for numerical stability.
        weight_decay: L2 penalty coefficient.
        amsgrad: whether to use the AMSGrad variant. When a param group has it
            enabled, its ``max_exp_avg_sq`` buffer is also created and shared.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.99),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        # Eager state initialization, so the tensors exist before being shared.
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # Recent PyTorch versions require ``step`` to be a tensor.
                # A plain int makes Adam.step() raise "API has changed ...".
                # A tensor can also be shared, so all processes use the same
                # step count for bias correction.
                state["step"] = torch.tensor(0.0)
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Adam.step() looks up this buffer whenever amsgrad is on. The
                # state dict is already non-empty, so Adam will not create it.
                if group["amsgrad"]:
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                # Move every state tensor (including ``step``) into shared memory.
                for buf in state.values():
                    buf.share_memory_()
|