# borrowed from
import torch
import torch.nn as nn

def disable_running_stats(model):
    """Freeze BatchNorm running statistics by setting momentum to 0."""
    def _disable(module):
        if isinstance(module, nn.BatchNorm2d):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    """Restore the BatchNorm momentum saved by disable_running_stats."""
    def _enable(module):
        if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum

    model.apply(_enable)
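

# Note (added for illustration, not part of the borrowed code): PyTorch BatchNorm
# updates its running statistics as
#     running_stat <- (1 - momentum) * running_stat + momentum * batch_stat,
# so setting momentum to 0 makes that update a no-op. A minimal sketch of the effect:
def _batchnorm_freeze_example():
    """Illustration only: with momentum set to 0, running statistics stay fixed."""
    bn = nn.BatchNorm2d(3)            # fresh module, training mode by default
    disable_running_stats(bn)         # model.apply also visits the module itself
    before = bn.running_mean.clone()
    bn(torch.randn(4, 3, 8, 8))       # forward pass; the EMA update is a no-op
    assert torch.equal(bn.running_mean, before)
    enable_running_stats(bn)          # momentum restored to its default (0.1)
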

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # base_optimizer is the *class* of the wrapped optimizer (e.g. torch.optim.SGD);
        # the remaining kwargs (lr, momentum, ...) are forwarded to it.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm
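

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the borrowed code).
# It assumes a `model`, a data `loader`, and a `criterion` exist elsewhere in
# the Space; those names, and the rho/lr/momentum values, are placeholders.
# It follows the usual two-pass SAM recipe: perturb the weights with
# first_step, re-evaluate the gradient at the perturbed point, then update
# with second_step, while the BatchNorm helpers keep the running statistics
# from being updated twice per batch.
# ---------------------------------------------------------------------------
def _example_training_loop(model, loader, criterion, device="cpu"):
    model.to(device)
    model.train()
    # SAM wraps the *class* of the base optimizer; extra kwargs are forwarded to it.
    optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # first forward-backward pass: gradient at w, used to build the ascent step e(w)
        enable_running_stats(model)           # update BN statistics only on this pass
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass: gradient at the perturbed point w + e(w)
        disable_running_stats(model)          # freeze BN statistics for the perturbed pass
        criterion(model(inputs), targets).backward()
        optimizer.second_step(zero_grad=True)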