import copy
import math

import torch

class L0_Regularizer(torch.nn.Module):

    def __init__(self, original_module: torch.nn.Module, lam: float, weight_decay: float = 0,
                 temperature: float = 2 / 3, droprate_init: float = 0.2,
                 limit_a: float = -0.1, limit_b: float = 1.1, epsilon: float = 1e-6):
        super(L0_Regularizer, self).__init__()
        self.module = copy.deepcopy(original_module)
        self.pre_parameters = torch.nn.ParameterDict(
            {name.replace(".", "#") + "_p": param for name, param in self.module.named_parameters()}
        )
        self.param_names = [name.replace(".", "#") for name, param in self.module.named_parameters()]
        self.mask_parameters = torch.nn.ParameterDict()
        self.lam = lam
        self.weight_decay = weight_decay
        self.temperature = temperature
        self.droprate_init = droprate_init
        self.limit_a = limit_a
        self.limit_b = limit_b
        self.epsilon = epsilon
        for name, param in self.module.named_parameters():
            mask = torch.nn.Parameter(torch.Tensor(param.size()))
            self.mask_parameters.update({name.replace(".", "#") + "_m": mask})
        # The code below guts the module of its previous parameters,
        # allowing them to be replaced by non-leaf tensors.
        self.reset_parameters()
        self.to("cpu")
        for name in self.param_names:
            L0_Regularizer.recursive_del(self.module, name)
            L0_Regularizer.recursive_set(self.module, name, self.sample_weights(name, sample=True))
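
    # Note on the naming scheme: ParameterDict keys may not contain ".", so a nested
    # parameter name such as "fc.weight" (hypothetical submodule name) is stored as
    # "fc#weight", with "fc#weight_p" holding the underlying weights and "fc#weight_m"
    # the gate (log-alpha) parameters. recursive_del / recursive_set then walk the "#"
    # path back down to module.fc.weight when removing or re-attaching the tensor.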

    '''
    The code below is a direct copy, with adaptations, from the codebase for:
    Louizos, C., Welling, M., & Kingma, D. P. (2017).
    Learning sparse neural networks through L_0 regularization.
    arXiv preprint arXiv:1712.01312.
    '''

    def reset_parameters(self):
        for name, weight in self.pre_parameters.items():
            if "bias" in name:
                torch.nn.init.constant_(weight, 0.0)
            else:
                torch.nn.init.xavier_uniform_(weight)
        for name, weight in self.mask_parameters.items():
            torch.nn.init.normal_(weight, math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2)

    def constrain_parameters(self):
        for name, weight in self.mask_parameters.items():
            weight.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x, param):
        """Implements the CDF of the 'stretched' concrete distribution"""
        # looks up the gate (log-alpha) parameter for this weight by name
        xn = (x - self.limit_a) / (self.limit_b - self.limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(
            logits * self.temperature - self.mask_parameters[param + "_m"]
        ).clamp(min=self.epsilon, max=1 - self.epsilon)
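
    # In closed form, for x strictly between limit_a and limit_b:
    #   x_n = (x - limit_a) / (limit_b - limit_a)
    #   cdf_qz(x, param) = sigmoid(temperature * log(x_n / (1 - x_n)) - log_alpha)
    # so 1 - cdf_qz(0, param) is the probability that a gate ends up non-zero, which is
    # the quantity the expected-L0 terms below sum over.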

    def quantile_concrete(self, x, param):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        # looks up the gate (log-alpha) parameter for this weight by name
        y = torch.sigmoid(
            (torch.log(x) - torch.log(1 - x) + self.mask_parameters[param + "_m"]) / self.temperature)
        return y * (self.limit_b - self.limit_a) + self.limit_a

    def _reg_w(self, param):
        """Expected L0 norm under the stochastic gates; a potential L2 penalty is also
        taken into account, re-weighted by the probability that each gate is active.

        (An older per-neuron variant computed this per column, weighted by the expected
        gate values; this version computes the penalty over all parameters directly.)
        """
        # The inner term is the negative per-weight cost, so the sign is flipped on return:
        # the returned value grows with the expected description length.
        logpw_l2 = -(.5 * self.weight_decay * self.pre_parameters[param + "_p"].pow(2)) - self.lam
        logpw = torch.sum((1 - self.cdf_qz(0, param)) * logpw_l2)
        return -logpw
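
    # Equivalent closed form of the value returned above:
    #   sum_j (1 - P(z_j <= 0)) * (lam + 0.5 * weight_decay * theta_j ** 2)
    # i.e. an expected-L0 penalty plus an L2 penalty, each weighted by the probability
    # that the corresponding gate is active.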

    def regularization(self):
        r_total = torch.Tensor([])
        for param in self.param_names:
            device = self.mask_parameters[param + "_m"].device
            r_total = torch.cat([r_total.to(device), self._reg_w(param).unsqueeze(dim=0)])
        return r_total.sum()

    def count_l0(self):
        total = []
        for param in self.param_names:
            total.append(torch.sum(1 - self.cdf_qz(0, param)).unsqueeze(dim=0))
        return torch.cat(total).sum()

    def count_l2(self):
        total = []
        for param in self.param_names:
            total.append(self._l2_helper(param).unsqueeze(dim=0))
        return torch.cat(total).sum()

    def _l2_helper(self, param):
        return (self.sample_weights(param, sample=False) ** 2).sum()

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        # torch.autograd.Variable is deprecated and no longer needed here
        eps = torch.rand(size) * (1 - 2 * self.epsilon) + self.epsilon
        return eps

    def sample_z(self, param, sample=True):
        """Sample the hard-concrete gates for training and use a deterministic value for testing"""
        size = self.mask_parameters[param + "_m"].size()
        if sample:
            device = self.mask_parameters[param + "_m"].device
            eps = self.get_eps(size).to(device)
            z = self.quantile_concrete(eps, param)
            return torch.nn.functional.hardtanh(z, min_val=0, max_val=1)
        else:  # deterministic mode of the gate distribution
            pi = torch.sigmoid(self.mask_parameters[param + "_m"])
            return torch.nn.functional.hardtanh(pi * (self.limit_b - self.limit_a) + self.limit_a,
                                                min_val=0, max_val=1)
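
    # The sampling recipe above, written out:
    #   training: u ~ Uniform(epsilon, 1 - epsilon)
    #             s = sigmoid((log(u) - log(1 - u) + log_alpha) / temperature)
    #             z = clamp(s * (limit_b - limit_a) + limit_a, 0, 1)
    #   testing:  s is replaced by its deterministic value sigmoid(log_alpha).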

    def sample_weights(self, param, sample=True):
        mask = self.sample_z(param, sample)
        return mask * self.pre_parameters[param + "_p"]

    def forward(self, x):
        """Rewrite the parameters (tensors) of the core module, then feed forward"""
        for param in self.param_names:
            L0_Regularizer.recursive_set(self.module, param, self.sample_weights(param, sample=self.training))
        return self.module(x)

    @staticmethod
    def recursive_get(obj, att_name):
        if "#" in att_name:
            first, last = att_name.split("#", 1)
            # descend into the submodule and return the nested attribute
            return L0_Regularizer.recursive_get(getattr(obj, first), last)
        else:
            return getattr(obj, att_name)

    @staticmethod
    def recursive_set(obj, att_name, val):
        if "#" in att_name:
            first, last = att_name.split("#", 1)
            L0_Regularizer.recursive_set(getattr(obj, first), last, val)
        else:
            setattr(obj, att_name, val)

    @staticmethod
    def recursive_del(obj, att_name):
        if "#" in att_name:
            first, last = att_name.split("#", 1)
            L0_Regularizer.recursive_del(getattr(obj, first), last)
        else:
            delattr(obj, att_name)
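

# Illustrative usage sketch: how one might wrap a small network and train it with the
# expected-L0 penalty. The architecture, data, and hyper-parameters below are
# placeholders chosen for the example, not values prescribed by the module above.
if __name__ == "__main__":
    base = torch.nn.Sequential(
        torch.nn.Linear(20, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 2),
    )
    model = L0_Regularizer(base, lam=1e-2, weight_decay=1e-4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, y = torch.randn(32, 20), torch.randint(0, 2, (32,))

    model.train()
    for _ in range(10):
        optimizer.zero_grad()
        # regularization() already folds in lam and weight_decay, so it is added
        # directly to the task loss.
        loss = torch.nn.functional.cross_entropy(model(x), y) + model.regularization()
        loss.backward()
        optimizer.step()
        model.constrain_parameters()  # keep log-alpha in a numerically safe range

    model.eval()
    print("expected number of non-zero weights:", model.count_l0().item())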