# L0_Regularizer/L0_Regularizer.py
import copy
import math

import torch


class L0_Regularizer(torch.nn.Module):
def __init__(self, original_module: torch.nn.Module, lam: float, weight_decay: float = 0,
temperature: float = 2 / 3, droprate_init=0.2, limit_a=-.1, limit_b=1.1, epsilon=1e-6
):
super(L0_Regularizer, self).__init__()
self.module = copy.deepcopy(original_module)
self.pre_parameters = torch.nn.ParameterDict(
{name.replace(".", "#") + "_p": param for name, param in self.module.named_parameters()}
)
self.param_names = [name.replace(".", "#") for name, param in self.module.named_parameters()]
self.mask_parameters = torch.nn.ParameterDict()
self.lam = lam
self.weight_decay = weight_decay
self.temperature = temperature
self.droprate_init = droprate_init
self.limit_a = limit_a
self.limit_b = limit_b
self.epsilon = epsilon
for name, param in self.module.named_parameters():
mask = torch.nn.Parameter(torch.Tensor(param.size()))
self.mask_parameters.update({name.replace(".", "#") + "_m": mask})
        # The loop below strips the wrapped module of its registered parameters so that
        # the same attribute names can be re-assigned plain (non-leaf) tensors, i.e. the
        # gated weights produced by sample_weights(); gradients then flow back to
        # pre_parameters and mask_parameters instead.
self.reset_parameters()
self.to("cpu")
for name in self.param_names:
L0_Regularizer.recursive_del(self.module, name)
L0_Regularizer.recursive_set(self.module, name, self.sample_weights(name, 1))
    '''
    The methods below are copied, with adaptations, from the reference implementation of:
    Louizos, C., Welling, M., & Kingma, D. P. (2017).
    Learning sparse neural networks through L_0 regularization.
    arXiv preprint arXiv:1712.01312.
    '''
def reset_parameters(self):
for name, weight in self.pre_parameters.items():
if "bias" in name:
torch.nn.init.constant_(weight, 0.0)
else:
torch.nn.init.xavier_uniform_(weight)
for name, weight in self.mask_parameters.items():
torch.nn.init.normal_(weight, math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2)
def constrain_parameters(self):
for name, weight in self.mask_parameters.items():
weight.data.clamp_(min=math.log(1e-2), max=math.log(1e2))
def cdf_qz(self, x, param):
"""Implements the CDF of the 'stretched' concrete distribution"""
        # reads this weight's gate log-alpha values from self.mask_parameters
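        # A minimal restatement of what the code below computes (notation assumed from the
        # hard-concrete / stretched-concrete construction in Louizos et al., 2017):
        #   x_n    = (x - limit_a) / (limit_b - limit_a)
        #   logits = log(x_n) - log(1 - x_n)
        #   P(z <= x) = sigmoid(temperature * logits - log_alpha)
        # so 1 - cdf_qz(0, param) is the probability that a gate is non-zero, which is the
        # quantity summed in count_l0() and _reg_w().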
xn = (x - self.limit_a) / (self.limit_b - self.limit_a)
logits = math.log(xn) - math.log(1 - xn)
return torch.sigmoid(
logits * self.temperature - self.mask_parameters[param + "_m"]).clamp(min=self.epsilon,
max=1 - self.epsilon)
def quantile_concrete(self, x, param):
"""Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        # reads this weight's gate log-alpha values from self.mask_parameters
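        # Sketch of the reparameterized sampling path this method implements: for
        # u ~ Uniform(epsilon, 1 - epsilon),
        #   s = sigmoid((log u - log(1 - u) + log_alpha) / temperature)
        # is a binary-concrete sample in (0, 1); stretching it to (limit_a, limit_b) here
        # and hard-clipping to [0, 1] in sample_z() yields a hard-concrete gate.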
y = torch.sigmoid(
(torch.log(x) - torch.log(1 - x) + self.mask_parameters[param + "_m"]) / self.temperature)
return y * (self.limit_b - self.limit_a) + self.limit_a
def _reg_w(self, param):
"""Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty"""
"""is_neural is old method, calculates wrt columns first multiplied by expected values of gates
second method calculates wrt all parameters
"""
        # Sign convention: as in the reference code, logpw accumulates a log-prior (a negative
        # penalty) and is negated on return.  The returned value is
        #   sum over gates of P(gate != 0) * (lam + 0.5 * weight_decay * w^2),
        # which increases with the expected description length of the weights.
logpw_l2 = - (.5 * self.weight_decay * self.pre_parameters[param + "_p"].pow(2)) - self.lam
logpw = torch.sum((1 - self.cdf_qz(0, param)) * logpw_l2)
return -logpw
def regularization(self):
r_total = torch.Tensor([])
for param in self.param_names:
device = self.mask_parameters[param + "_m"].device
r_total = torch.cat([r_total.to(device), self._reg_w(param).unsqueeze(dim=0)])
return r_total.sum()
def count_l0(self):
total = []
for param in self.param_names:
total.append(torch.sum(1 - self.cdf_qz(0, param)).unsqueeze(dim=0))
return torch.cat(total).sum()
def count_l2(self):
total = []
for param in self.param_names:
total.append(self._l2_helper(param).unsqueeze(dim=0))
return torch.cat(total).sum()
def _l2_helper(self, param):
return (self.sample_weights(param, False) ** 2).sum()
def get_eps(self, size):
"""Uniform random numbers for the concrete distribution"""
        # torch.autograd.Variable is deprecated, so the wrapper used in the reference code is dropped
eps = torch.rand(size) * (1 - 2 * self.epsilon) + self.epsilon
return eps
def sample_z(self, param, sample=True):
"""Sample the hard-concrete gates for training and use a deterministic value for testing"""
size = self.mask_parameters[param + "_m"].size()
if sample:
device = self.mask_parameters[param + "_m"].device
eps = self.get_eps(size).to(device)
z = self.quantile_concrete(eps, param)
return torch.nn.functional.hardtanh(z, min_val=0, max_val=1)
else: # mode
pi = torch.sigmoid(self.mask_parameters[param + "_m"])
return torch.nn.functional.hardtanh(pi * (self.limit_b - self.limit_a) + self.limit_a, min_val=0, max_val=1)
def sample_weights(self, param, sample=True):
mask = self.sample_z(param, sample)
return mask * self.pre_parameters[param + "_p"]
def forward(self, x):
"""rewrite parameters (tensors) of core module and feedforward"""
for param in self.param_names:
L0_Regularizer.recursive_set(self.module, param, self.sample_weights(param, sample=self.training))
return self.module(x)
@staticmethod
def recursive_get(obj, att_name):
if "#" in att_name:
first, last = att_name.split("#", 1)
            return L0_Regularizer.recursive_get(getattr(obj, first), last)
else:
return getattr(obj, att_name)
@staticmethod
def recursive_set(obj, att_name, val):
if "#" in att_name:
first, last = att_name.split("#", 1)
L0_Regularizer.recursive_set(getattr(obj, first), last, val)
else:
setattr(obj, att_name, val)
@staticmethod
def recursive_del(obj, att_name):
if "#" in att_name:
first, last = att_name.split("#", 1)
L0_Regularizer.recursive_del(getattr(obj, first), last)
else:
delattr(obj, att_name)
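
# A minimal usage sketch, assuming a toy classifier, random data, and made-up hyperparameters:
# wrap a module, add the expected-L0 penalty returned by regularization() to the task loss,
# and take a few optimizer steps.
if __name__ == "__main__":
    torch.manual_seed(0)

    base = torch.nn.Sequential(
        torch.nn.Linear(20, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 2),
    )
    model = L0_Regularizer(base, lam=1e-2, weight_decay=1e-4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    inputs = torch.randn(64, 20)
    targets = torch.randint(0, 2, (64,))

    model.train()
    for _ in range(5):
        optimizer.zero_grad()
        logits = model(inputs)
        # total objective = task loss + expected L0/L2 penalty from the stochastic gates
        loss = torch.nn.functional.cross_entropy(logits, targets) + model.regularization()
        loss.backward()
        optimizer.step()
        model.constrain_parameters()

    model.eval()
    with torch.no_grad():
        print("expected number of active weights:", model.count_l0().item())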