# This file contains helper functions related to the pruning process,
# including any specialized pruning functions and the SparseGPT functionality.

# DISCLAIMER: The SparseGPT classes below are modified versions of the
# original SparseGPT class. The original can be found in [SparseGPT: Massive
# Language Models Can Be Accurately Pruned in One-Shot].

import math
import time

import torch
import torch.nn as nn
import transformers

from quant import *  # provides quantize(), used when a quantizer is attached

# Turned this flag to True: add_batch() then keeps the most recent
# input/output pair around for debugging.
DEBUG = True

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    # Recursively collect all submodules of the requested types, keyed by
    # their dotted path relative to `module`.
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res
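
# A minimal usage sketch for find_layers (illustrative only: the toy module
# below is an assumption for demonstration, not something this file defines
# elsewhere). Keys in the returned dict are dotted module paths.
def _example_find_layers():
    """Illustrative sketch: call manually to inspect find_layers' output."""
    block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    layers = find_layers(block)
    # -> {'0': Linear(in_features=8, out_features=8, ...),
    #     '2': Linear(in_features=8, out_features=4, ...)}
    return layers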
class SparseGPT_OPT:

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.batch_inp = []
        self.batch_out = []

    def add_batch(self, inp, out, name, blocksize=1024):  # blocksize is unused here
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        ###### added code
        # For the OPT MLP layers we also store the raw inputs/outputs.
        if name == 'fc1' or name == 'fc2':
            self.batch_inp.append(inp[0].clone().detach())
            if len(out.shape) == 3:
                out = out.squeeze(0)
            self.batch_out.append(out.clone().detach())
        ######
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Streaming update of the scaled Hessian H = (2 / nsamples) * X X^T.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterprune(
        self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        if hasattr(self, 'quantizer'):
            if not self.quantizer.ready():
                self.quantizer.find_params(W, weight=True)

        tick = time.time()

        H = self.H
        # del self.H
        # Columns whose Hessian diagonal is zero never received input; zero
        # the corresponding weights and keep H invertible.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen H, then compute the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prunen == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # Unstructured case: prune the weights with the smallest
                    # saliency w^2 / [Hinv]_ii^2 within this block.
                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prunen != 0 and i % prunem == 0:
                    # n:m pattern: in each group of prunem columns, mark the
                    # prunen least salient weights per row.
                    tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
                    mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                if hasattr(self, 'quantizer'):
                    q = quantize(
                        q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                    ).flatten()

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # OBS-style update: spread this column's error over the
                # remaining columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the accumulated block error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            # if DEBUG:
            #     self.layer.weight.data[:, :i2] = W[:, :i2]
            #     self.layer.weight.data[:, i2:] = W[:, i2:]
            #     print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
            #     print(torch.sum(Losses))

        torch.cuda.synchronize()
        print('time %.2f' % (time.time() - tick))
        print('error', torch.sum(Losses).item())

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        # if DEBUG:
        #     print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

    def free(self):
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        torch.cuda.empty_cache()
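
# A hedged numerical sketch of what add_batch accumulates: the streaming
# update above is equivalent to H = (2 / nsamples) * X X^T, where X
# column-stacks every calibration row and nsamples accumulates the leading
# batch dimension. Layer sizes, batch shapes, and the 'fc1' name used below
# are illustrative assumptions.
def _example_hessian_accumulation():
    """Illustrative sketch: compare the streamed H against a direct computation."""
    torch.manual_seed(0)
    layer = nn.Linear(16, 4)
    gpt = SparseGPT_OPT(layer)
    batches = [torch.randn(1, 7, 16) for _ in range(3)]  # (batch, seq, features)
    for x in batches:
        gpt.add_batch(x, layer(x), name='fc1')
    X = torch.cat([x.reshape(-1, 16) for x in batches]).t().float()  # (16, 21)
    H_direct = 2 / gpt.nsamples * X.matmul(X.t())
    return torch.allclose(gpt.H, H_direct, atol=1e-4)  # expected: True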
class SparseGPT_LlaMA:

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.batch_inp = []
        self.batch_out = []

    def add_batch(self, inp, out, name, blocksize=1024):  # blocksize is unused here
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        ###### added code
        # For the LLaMA MLP layers we also store the raw inputs/outputs.
        if name == 'mlp.up_proj' or name == 'mlp.down_proj':
            self.batch_inp.append(inp[0].clone().detach())
            if len(out.shape) == 3:
                out = out.squeeze(0)
            self.batch_out.append(out.clone().detach())
        if name == 'mlp.gate_proj':
            # For this layer we only store the outputs; its inputs are shared
            # with 'mlp.up_proj'.
            if len(out.shape) == 3:
                out = out.squeeze(0)
            self.batch_out.append(out.clone().detach())
        ######
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Streaming update of the scaled Hessian H = (2 / nsamples) * X X^T.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterprune(
        self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        if hasattr(self, 'quantizer'):
            if not self.quantizer.ready():
                self.quantizer.find_params(W, weight=True)

        tick = time.time()

        H = self.H
        # del self.H
        # Columns whose Hessian diagonal is zero never received input; zero
        # the corresponding weights and keep H invertible.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen H, then compute the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prunen == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # Unstructured case: prune the weights with the smallest
                    # saliency w^2 / [Hinv]_ii^2 within this block.
                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prunen != 0 and i % prunem == 0:
                    # n:m pattern: in each group of prunem columns, mark the
                    # prunen least salient weights per row.
                    tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
                    mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                if hasattr(self, 'quantizer'):
                    q = quantize(
                        q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                    ).flatten()

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # OBS-style update: spread this column's error over the
                # remaining columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the accumulated block error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            # if DEBUG:
            #     self.layer.weight.data[:, :i2] = W[:, :i2]
            #     self.layer.weight.data[:, i2:] = W[:, i2:]
            #     print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
            #     print(torch.sum(Losses))

        torch.cuda.synchronize()
        print('time %.2f' % (time.time() - tick))
        print('error', torch.sum(Losses).item())

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        # if DEBUG:
        #     print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

    def free(self):
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        torch.cuda.empty_cache()
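
# Hedged end-to-end sketch: prune one toy Linear layer to 50% unstructured
# sparsity with random calibration data. The shapes, the number of batches,
# and the sparsity target are arbitrary assumptions; a CUDA device is required
# because fasterprune() calls torch.cuda.synchronize().
if __name__ == '__main__':
    dev = 'cuda'
    torch.manual_seed(0)
    layer = nn.Linear(64, 64).to(dev)
    gpt = SparseGPT_OPT(layer)
    for _ in range(8):
        x = torch.randn(1, 32, 64, device=dev)  # (batch, seq, features)
        gpt.add_batch(x, layer(x), name='fc1')  # 'fc1' also stores activations
    gpt.fasterprune(sparsity=0.5)
    density = (layer.weight.data != 0).float().mean().item()
    print('remaining density %.3f' % density)  # expect roughly 0.5
    gpt.free()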