# Original Source:
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_linear.py
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_quadratic.py
# Modifications made to jacobian computation by Yurong You and Kevin Shih
# Original License Text:
#########################################################################
# The MIT License (MIT)
# Copyright (c) 2020, nicolas deutschmann
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
import torch.nn.functional as F

third_dimension_softmax = torch.nn.Softmax(dim=2)


def piecewise_linear_transform(
    x, q_tilde, compute_jacobian=True, outlier_passthru=True
):
    """Apply an element-wise piecewise-linear transformation to some variables.

    Parameters
    ----------
    x : torch.Tensor
        a tensor with shape (N,k), where N is the batch dimension and k is the
        dimension of the variable space. This variable spans the k-dimensional
        unit hypercube.
    q_tilde : torch.Tensor
        a tensor with shape (N,k,b), where b is the number of bins. This
        contains the un-normalized heights of the bins of the
        piecewise-constant PDF for each dimension, i.e. q_tilde lives in all
        of R and no constraint is imposed on its sum yet. Normalization is
        imposed in this function using softmax.
    compute_jacobian : bool, optional
        determines whether the jacobian should be computed; if False, None is
        returned in its place.

    Returns
    -------
    tuple of torch.Tensor
        pair `(y, logj)`.
        - `y` is a tensor with shape (N,k) living in the k-dimensional unit
          hypercube.
        - `logj` is the log of the jacobian determinant of the transformation,
          with shape (N,), if compute_jacobian == True, else None.
    """
    logj = None

    # TODO bottom-up assessment of handling the differentiability of variables

    # Compute the bin width w
    N, k, b = q_tilde.shape
    Nx, kx = x.shape
    assert N == Nx and k == kx, "Shape mismatch"
    w = 1.0 / b

    # Compute normalized bin heights with softmax function on the bin dimension
    q = 1.0 / w * third_dimension_softmax(q_tilde)
    # x is in the mx-th bin: x \in [0,1],
    # mx \in [[0,b-1]], so we clamp away the case x == 1
    mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long)
    # Need special error handling because trying to index with mx
    # if it contains nans will lock the GPU (device-side assert triggered).
    if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b):
        raise Exception("NaN detected in PWLinear bin indexing")

    # We compute the output variable in-place
    out = x - mx * w  # alpha, in [0., w]: the position of x within its bin

    # Multiply by the slope
    # q has shape (N,k,b); mxu = mx.unsqueeze(-1) has shape (N,k,1) with
    # entries that are a b-index.
    # gather defines slopes[i, j, l] = q[i, j, mxu[i, j, l]] with l taking
    # only 0 as a value, i.e. slopes[i, j] = q[i, j, mx[i, j]]
    slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1)
    out = out * slopes
    # The jacobian is the product of the slopes in all dimensions

    # Compute the integral over the left-bins.
    # 1. Compute all integrals: cumulative sum of bin height * bin width.
    # We want index i to contain the cumsum *strictly to the left*, so we
    # shift by 1, leaving the first entry null, which is achieved with a roll
    # and assignment.
    q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2)
    q_left_integrals[:, :, 0] = 0

    # 2. Access the correct index to get the left integral of each point and
    # add it to our transformation.
    out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1)

    # Regularization: points must be strictly within the unit hypercube.
    # Use the dtype information from pytorch.
    eps = torch.finfo(out.dtype).eps
    out = out.clamp(min=eps, max=1.0 - eps)

    # Points outside the unit hypercube are passed through unchanged, with a
    # unit slope so they do not contribute to the jacobian.
    oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float()
    if outlier_passthru:
        out = out * (1 - oob_mask) + x * oob_mask
        slopes = slopes * (1 - oob_mask) + oob_mask

    if compute_jacobian:
        # logj = torch.log(torch.prod(slopes.float(), 1))
        logj = torch.sum(torch.log(slopes), 1)
    del slopes

    return out, logj
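
# Illustrative usage sketch for piecewise_linear_transform; the shapes
# (N=8 samples, k=3 dimensions, b=10 bins) are arbitrary toy choices:
#
#   x = torch.rand(8, 3)             # points in the 3-dim unit hypercube
#   q_tilde = torch.randn(8, 3, 10)  # un-normalized bin heights
#   y, logj = piecewise_linear_transform(x, q_tilde)
#   assert y.shape == (8, 3) and logj.shape == (8,)
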
def piecewise_linear_inverse_transform(
    y, q_tilde, compute_jacobian=True, outlier_passthru=True
):
    """Apply the inverse of an element-wise piecewise-linear transformation to
    some variables.

    Parameters
    ----------
    y : torch.Tensor
        a tensor with shape (N,k), where N is the batch dimension and k is the
        dimension of the variable space. This variable spans the k-dimensional
        unit hypercube.
    q_tilde : torch.Tensor
        a tensor with shape (N,k,b), where b is the number of bins. This
        contains the un-normalized heights of the bins of the
        piecewise-constant PDF for each dimension, i.e. q_tilde lives in all
        of R and no constraint is imposed on its sum yet. Normalization is
        imposed in this function using softmax.
    compute_jacobian : bool, optional
        determines whether the jacobian should be computed; if False, None is
        returned in its place.

    Returns
    -------
    tuple of torch.Tensor
        pair `(x, logj)`.
        - `x` is a tensor with shape (N,k) living in the k-dimensional unit
          hypercube.
        - `logj` is the log of the jacobian determinant of the transformation,
          with shape (N,), if compute_jacobian == True, else None.
    """

    # TODO bottom-up assessment of handling the differentiability of variables

    # Compute the bin width w
    N, k, b = q_tilde.shape
    Ny, ky = y.shape
    assert N == Ny and k == ky, "Shape mismatch"
    w = 1.0 / b

    # Compute normalized bin heights with softmax function on the bin dimension
    q = 1.0 / w * third_dimension_softmax(q_tilde)

    # Compute the integral over the left-bins in the forward transform.
    # 1. Compute all integrals: cumulative sum of bin height * bin width.
    # We want index i to contain the cumsum *strictly to the left*,
    # so we shift by 1, leaving the first entry null,
    # which is achieved with a roll and assignment.
    q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2)
    q_left_integrals[:, :, 0] = 0

    # Find which bin each y belongs to by finding the bin with the smallest
    # positive difference y - q_left_integral.
    edges = (y.unsqueeze(-1) - q_left_integrals).detach()
    # y and q_left_integrals are between 0 and 1,
    # so their difference is at most 1.
    # By setting the negative values to 2., we know that the
    # smallest value left is the smallest positive one.
    edges[edges < 0] = 2.0
    edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long)

    # Need special error handling because trying to index with edges
    # if it contains nans will lock the GPU (device-side assert triggered).
    if (
        torch.any(torch.isnan(edges)).item()
        or torch.any(edges < 0)
        or torch.any(edges >= b)
    ):
        raise Exception("NaN detected in PWLinear bin indexing")

    # Gather the left integrals at each edge. See the comment about gathering
    # in piecewise_linear_transform for the unsqueeze.
    q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1)

    # Gather the slope at each edge.
    q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1)

    # Build the output
    x = (y - q_left_integrals) / q + edges * w

    # Regularization: points must be strictly within the unit hypercube.
    # Use the dtype information from pytorch.
    eps = torch.finfo(x.dtype).eps
    x = x.clamp(min=eps, max=1.0 - eps)

    # Points outside the unit hypercube are passed through unchanged, with a
    # unit slope so they do not contribute to the jacobian.
    oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float()
    if outlier_passthru:
        x = x * (1 - oob_mask) + y * oob_mask
        q = q * (1 - oob_mask) + oob_mask

    # Prepare the jacobian
    logj = None
    if compute_jacobian:
        # logj = - torch.log(torch.prod(q, 1))
        logj = -torch.sum(torch.log(q.float()), 1)
    return x.detach(), logj
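
# Illustrative round-trip sketch (arbitrary toy shapes): running the inverse
# with the same q_tilde should recover x up to clamping and float precision,
# and the two log-jacobians should cancel:
#
#   x = torch.rand(8, 3)
#   q_tilde = torch.randn(8, 3, 10)
#   y, logj_fwd = piecewise_linear_transform(x, q_tilde)
#   x_rec, logj_inv = piecewise_linear_inverse_transform(y, q_tilde)
#   assert torch.allclose(x_rec, x, atol=1e-5)
#   assert torch.allclose(logj_fwd + logj_inv, torch.zeros(8), atol=1e-4)
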
def unbounded_piecewise_quadratic_transform(
    x, w_tilde, v_tilde, upper=1, lower=0, inverse=False
):
    """Apply the piecewise-quadratic transform on [lower, upper); points
    outside the interval are passed through unchanged."""
    assert upper > lower
    _range = upper - lower
    inside_interval_mask = (x >= lower) & (x < upper)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(x)
    log_j = torch.zeros_like(x)

    outputs[outside_interval_mask] = x[outside_interval_mask]
    log_j[outside_interval_mask] = 0

    output, _log_j = piecewise_quadratic_transform(
        (x[inside_interval_mask] - lower) / _range,
        w_tilde[inside_interval_mask, :],
        v_tilde[inside_interval_mask, :],
        inverse=inverse,
    )
    outputs[inside_interval_mask] = output * _range + lower
    if not inverse:
        # the rescaling into [0,1) and back out cancels in the jacobian,
        # so _log_j can be used unchanged
        log_j[inside_interval_mask] = _log_j
    else:
        log_j = None
    return outputs, log_j


def weighted_softmax(v, w):
    # subtract the per-row max for numerical stability (avoids overflow in exp)
    v = v - torch.max(v, dim=-1, keepdim=True)[0]
    v = torch.exp(v) + 1e-8  # epsilon keeps the heights strictly positive
    # normalize so the trapezoidal integral of the piecewise-linear PDF is one
    v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True)
    return v / v_sum
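
# Sanity-check sketch for weighted_softmax (arbitrary toy shapes): the
# returned vertex heights define a valid piecewise-linear PDF, i.e. their
# trapezoidal integral over the bins is one:
#
#   w = torch.softmax(torch.randn(4, 5), dim=-1)  # 5 bin widths per row
#   v = weighted_softmax(torch.randn(4, 6), w)    # 6 vertex heights per row
#   integral = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1)
#   assert torch.allclose(integral, torch.ones(4))
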
def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False):
    """Element-wise piecewise-quadratic transformation.

    Parameters
    ----------
    x : torch.Tensor
        shape (*,); the variable spans the D-dim unit hypercube ([0,1))
    w_tilde : torch.Tensor
        shape (*, K); defined in the paper
    v_tilde : torch.Tensor
        shape (*, K+1); defined in the paper
    inverse : bool
        forward or inverse

    Returns
    -------
    c : torch.Tensor
        shape (*,); transformed value
    log_j : torch.Tensor
        shape (*,); log determinant of the Jacobian matrix (None in the
        inverse direction)
    """
    w = torch.softmax(w_tilde, dim=-1)
    v = weighted_softmax(v_tilde, w)
    w_cumsum = torch.cumsum(w, dim=-1)
    # force sum = 1
    w_cumsum[..., -1] = 1.0
    w_cumsum_shift = F.pad(w_cumsum, (1, 0), "constant", 0)
    cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1)
    # force sum = 1
    cdf[..., -1] = 1.0
    cdf_shift = F.pad(cdf, (1, 0), "constant", 0)

    if not inverse:
        # * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx])
        bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1))
    else:
        # * x D x 1, (cdf[idx-1] < x <= cdf[idx])
        bin_index = torch.searchsorted(cdf, x.unsqueeze(-1))

    w_b = torch.gather(w, -1, bin_index).squeeze(-1)
    w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1)
    v_b = torch.gather(v, -1, bin_index).squeeze(-1)
    v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1)
    cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1)

    if not inverse:
        alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps)
        c = (alpha**2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1

        # log of the linearly-interpolated pdf at x; the caller sums over
        # dimensions to get the full log-jacobian
        log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log()

        # make sure it falls into [0,1)
        c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps)
        return c, log_j
    else:
        # quadratic equation a*alpha^2 + b*alpha + c = 0 for alpha, which
        # should fall into (0, 1]. Since b > 0 and c <= 0, the '+' branch
        # picks the root on the increasing part of the parabola, which is
        # the valid one.
        # skip calculating the log_j in the inverse since we don't need it
        a = (v_bp1 - v_b) * w_b / 2
        b = v_b * w_b
        c = cdf_bn1 - x
        alpha = (-b + torch.sqrt((b**2) - 4 * a * c)) / (2 * a)
        inv = alpha * w_b + w_bn1

        # make sure it falls into [0,1)
        inv = inv.clamp(
            min=torch.finfo(inv.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps
        )
        return inv, None
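

# Minimal self-check, a sketch run only when the module is executed directly;
# the seed, shapes (K=8 bins) and batch size are arbitrary illustration:
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(16, 4)
    w_tilde = torch.randn(16, 4, 8)  # K bin-width logits
    v_tilde = torch.randn(16, 4, 9)  # K+1 vertex-height logits
    y, log_j = piecewise_quadratic_transform(x, w_tilde, v_tilde)
    x_rec, _ = piecewise_quadratic_transform(y, w_tilde, v_tilde, inverse=True)
    print("max quadratic round-trip error:", (x - x_rec).abs().max().item())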