import torch |
import torch.nn as nn |
from pruning_utils import * |
from quant import * |
import math |
from transformers import OPTForCausalLM, LlamaForCausalLM |
def get_opt(args):
    """Load the OPT causal language model named by ``args.model``.

    Before loading, ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
    ``normal_`` are replaced by a do-nothing function. The replacement is
    process-wide and is not undone by this function.

    Args:
        args: Object with a ``model`` attribute that is passed to
            ``OPTForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with an extra ``seqlen`` attribute set to
        ``model.config.max_position_embeddings``.
    """
    def _noop(*_args, **_kwargs):
        pass

    # Disable the three initialisers; presumably this avoids spending time on
    # random init for weights that the checkpoint overwrites anyway.
    for init_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_name, _noop)

    model = OPTForCausalLM.from_pretrained(args.model, torch_dtype='auto')
    model.seqlen = model.config.max_position_embeddings
    return model
def get_llama(args):
    """Load the LLaMA causal language model named by ``args.model``.

    Before loading, ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
    ``normal_`` are replaced by a do-nothing function. The replacement is
    process-wide and is not undone by this function.

    Args:
        args: Object with a ``model`` attribute that is passed to
            ``LlamaForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with an extra ``seqlen`` attribute fixed at 2048.
    """
    def _noop(*_args, **_kwargs):
        pass

    # Disable the three initialisers; presumably this avoids spending time on
    # random init for weights that the checkpoint overwrites anyway.
    for init_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_name, _noop)

    model = LlamaForCausalLM.from_pretrained(args.model, torch_dtype='auto')
    model.seqlen = 2048
    return model
@torch.no_grad()
def opt_sparsellm(model, dataloader, dev, args):
    """Prune an OPT-style model in place, one decoder layer at a time.

    The inputs to the first decoder layer are captured for every calibration
    batch. Each layer is then moved to ``dev``, its prunable sub-layers (as
    returned by ``find_layers``) are wrapped in ``SparseGPT_OPT`` objects and
    fed calibration data through forward hooks. Sub-layers other than ``fc1``
    and ``fc2`` are pruned once with ``fasterprune``. ``fc1`` and ``fc2`` are
    pruned inside a fixed-length alternating loop that also updates two
    auxiliary tensors (``z``: output of ``fc1``, ``p``: input of ``fc2``).

    Args:
        model: Model exposing ``model.decoder`` (``layers``, ``embed_tokens``,
            ``embed_positions`` and optionally ``project_in``/``project_out``),
            ``config`` and ``seqlen``.
        dataloader: Iterable of batches; element ``[0]`` of each batch is fed
            to the model. Presumably yields ``args.nsamples`` batches, since
            the capture buffer has that many slots -- verify against caller.
        dev: Device used for the active layer and the captured activations.
        args: Namespace read for ``nsamples``, ``minlayer``, ``maxlayer``,
            ``prune_only``, ``invert``, ``wbits``, ``sparsity``, ``prunen``,
            ``prunem``, ``percdamp`` and ``blocksize``.

    Returns:
        None. ``model`` is modified in place and ``model.config.use_cache`` is
        restored on normal completion.
    """
    print('Starting ...')
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers
    # Only the embedding-side modules and the first layer are needed on the
    # device to capture the inputs of layer 0.
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    layers[0] = layers[0].to(dev)
    dtype = next(iter(model.parameters())).dtype
    # One slot per calibration sample for the hidden states entering layer 0.
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        # Stand-in for layer 0: records its input and attention mask, then
        # aborts the forward pass by raising ValueError.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            # Expected: raised by Catcher once the input has been recorded.
            pass
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()
    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    print('Ready.')
    for i in range(len(layers)):
        layer = layers[i].to(dev)
        subset = find_layers(layer)
        gpts = {}
        for name in subset:
            # Skip sub-layers excluded by the minlayer/maxlayer/prune_only
            # filter (the sense of the filter is flipped by args.invert).
            if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert):
                continue
            gpts[name] = SparseGPT_OPT(subset[name])
            if args.wbits < 16:
                gpts[name].quantizer = Quantizer()
                gpts[name].quantizer.configure(
                    args.wbits, perchannel=True, sym=False, mse=False
                )

        def add_batch(name):
            # Build a forward hook that forwards a sub-layer's input/output to
            # the matching SparseGPT_OPT object.
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data, name)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        # Run the calibration samples through the dense layer so the hooks can
        # collect statistics.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        for h in handles:
            h.remove()
        target_layer_names = ['fc1', 'fc2']
        # Everything except fc1/fc2 is pruned once, directly.
        for name in gpts:
            if name not in target_layer_names:
                print(i, name)
                print('Pruning ...')
                sparsity = args.sparsity
                gpts[name].fasterprune(
                    sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize
                )
                gpts[name].free()
        # Weights of the penalty terms used in the p and z updates below
        # (alpha: z vs. fc1 output, beta: Y vs. fc2 output, gamma: p vs. relu(z)).
        alpha = 5.0
        beta = 5.0
        gamma = 5.0
        # Number of alternating update rounds for fc1/fc2.
        opt_epochs = 10
        # NOTE(review): batch_inp / batch_out are attributes of SparseGPT_OPT,
        # presumably lists of per-sample tensors recorded by add_batch --
        # confirm in pruning_utils. The lookups below also require both 'fc1'
        # and 'fc2' to have passed the filter above.
        X_list = gpts['fc1'].batch_inp
        Y_list = gpts['fc2'].batch_out
        X = torch.stack(X_list, dim=0)
        Y = torch.stack(Y_list, dim=0)
        # Flatten to 2-D and transpose so that columns index tokens.
        X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T
        # Drop the list references once stacked to release memory.
        X_list, Y_list = None, None
        gpts['fc1'].batch_inp.clear()
        gpts['fc2'].batch_out.clear()
        hidden_z_list = gpts['fc1'].batch_out
        z = torch.stack(hidden_z_list, dim=0)
        hidden_z_list = None
        gpts['fc1'].batch_out.clear()
        hidden_p_list = gpts['fc2'].batch_inp
        p = torch.stack(hidden_p_list, dim=0)
        hidden_p_list = None
        gpts['fc2'].batch_inp.clear()
        # Auxiliary variables, same (features, tokens) layout as X and Y.
        z = z.reshape((-1, z.size(-1))).T.to(dev)
        p = p.reshape((-1, p.size(-1))).T.to(dev)
        torch.cuda.empty_cache()
        # X never changes inside the loop, so its pseudo-inverse is computed
        # once (in float32, then cast to half).
        Xinv = torch.pinverse(X.to(dtype=torch.float32)).half()
        for opt_step in range(opt_epochs):
            # --- weight update (skipped on the first round) ---
            if opt_step > 0:
                # fc1: solve W1 @ X + b1 = z for W1 via the cached pseudo-inverse.
                bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1))
                weight_matrix_1 = torch.matmul(z - bias, Xinv)
                gpts['fc1'].layer.weight.copy_(weight_matrix_1)
                del bias, weight_matrix_1
                # fc2: solve W2 @ p + b2 = Y for W2; p changes every round so
                # its pseudo-inverse is recomputed.
                pinv = torch.pinverse(p.to(dtype=torch.float32)).half()
                bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1))
                weight_matrix_2 = torch.matmul(Y - bias, pinv)
                gpts['fc2'].layer.weight.copy_(weight_matrix_2)
                del bias, weight_matrix_2, pinv
                torch.cuda.empty_cache()
            # --- rebuild fc2's H from the current p (skipped on the first round) ---
            if opt_step > 0:
                tmp_H = torch.zeros_like(gpts['fc2'].H)
                tmp_p = p.T.reshape((args.nsamples, -1, p.size(0)))
                tmp_nsamples = 0
                for j in range(args.nsamples):
                    tmp_inp = tmp_p[j].unsqueeze(0)
                    tmp = tmp_inp.shape[0]
                    # NOTE(review): `transformers` is not imported by name in
                    # the visible imports; presumably it arrives through one of
                    # the star imports. It is only evaluated when the layer is
                    # not an nn.Linear -- verify.
                    if isinstance(gpts['fc2'].layer, nn.Linear) or isinstance(gpts['fc2'].layer, transformers.Conv1D):
                        if len(tmp_inp.shape) == 3:
                            tmp_inp = tmp_inp.reshape((-1, tmp_inp.shape[-1]))
                        tmp_inp = tmp_inp.t()
                    # Running average over samples of 2 * inp @ inp.T.
                    tmp_H *= tmp_nsamples / (tmp_nsamples + tmp)
                    tmp_nsamples += tmp
                    tmp_inp = math.sqrt(2 / tmp_nsamples) * tmp_inp.float()
                    tmp_H += tmp_inp.matmul(tmp_inp.t())
                gpts['fc2'].H.copy_(tmp_H)
                del tmp_H, tmp_p
                torch.cuda.empty_cache()
            # --- prune fc1 and fc2 ---
            for name in target_layer_names:
                print(i, name)
                print('Pruning ...')
                sparsity = args.sparsity
                gpts[name].fasterprune(
                    sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize
                )
            # --- p update: p = (beta*W2^T W2 + gamma*I)^-1 (beta*W2^T (Y - b2) + gamma*relu(z)) ---
            next_weight = subset['fc2'].weight
            m1 = beta * torch.matmul(next_weight.T, next_weight)
            m2 = gamma * torch.eye(m1.shape[0], device=m1.device)
            av = torch.inverse(m1 + m2).to(dtype=torch.float16)
            del m1, m2
            torch.cuda.empty_cache()
            layer_nl_output = nn.functional.relu(z)
            bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1))
            m3 = beta * torch.matmul(next_weight.T, Y - bias)
            m4 = gamma * layer_nl_output
            af = m3 + m4
            p = torch.matmul(av, af)
            del layer_nl_output, next_weight, av, m3, m4, af, bias
            torch.cuda.empty_cache()
            # --- z update: element-wise choice between two candidates ---
            w = subset['fc1'].weight
            bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1))
            m = torch.matmul(w, X) + bias
            # sol1: candidate used where it is non-negative; sol2: candidate
            # used where it is non-positive.
            sol1 = (gamma * p + alpha * m) / (gamma + alpha)
            sol2 = m
            del w, bias
            torch.cuda.empty_cache()
            z1 = torch.zeros_like(p)
            z2 = torch.zeros_like(p)
            # Work on row chunks to bound the size of the boolean masks.
            chunk_size = 500
            for k in range(0, sol1.size(0), chunk_size):
                chunk = slice(k, k + chunk_size)
                # z1 keeps the non-negative entries of sol1, zero elsewhere.
                z1_chunk = z1[chunk]
                sol1_chunk = sol1[chunk]
                z1_chunk[sol1_chunk >= 0.] = sol1_chunk[sol1_chunk >= 0.]
                z1[chunk] = z1_chunk
                # z2 keeps the non-positive entries of sol2, zero elsewhere.
                z2_chunk = z2[chunk]
                sol2_chunk = sol2[chunk]
                z2_chunk[sol2_chunk <= 0.] = sol2_chunk[sol2_chunk <= 0.]
                z2[chunk] = z2_chunk
            del z1_chunk, z2_chunk, sol1_chunk, sol2_chunk, sol1, sol2
            torch.cuda.empty_cache()
            for k in range(0, z1.size(0), chunk_size):
                chunk = slice(k, k + chunk_size)
                # Objective value of each candidate; the smaller one wins
                # (ties go to z1).
                fz_1_chunk = gamma * torch.square(p[chunk] - nn.functional.relu(z1[chunk])) + alpha * torch.square(z1[chunk] - m[chunk])
                fz_2_chunk = gamma * torch.square(p[chunk] - nn.functional.relu(z2[chunk])) + alpha * torch.square(z2[chunk] - m[chunk])
                index_z1_chunk = fz_1_chunk <= fz_2_chunk
                index_z2_chunk = fz_2_chunk < fz_1_chunk
                # z[chunk] is a basic slice, so the masked assignment writes
                # through to z itself.
                z[chunk][index_z1_chunk] = z1[chunk][index_z1_chunk]
                z[chunk][index_z2_chunk] = z2[chunk][index_z2_chunk]
            del fz_1_chunk, fz_2_chunk, index_z1_chunk, index_z2_chunk, z1, z2, m, chunk
            torch.cuda.empty_cache()
        for name in target_layer_names:
            gpts[name].free()
        # Recompute the layer outputs with the pruned weights; they become the
        # inputs of the next layer.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps
    model.config.use_cache = use_cache
@torch.no_grad()
def llama_sparsellm(model, dataloader, dev, args):
    """Prune a LLaMA-style model in place, one decoder layer at a time.

    The inputs to the first decoder layer are captured for every calibration
    batch. Each layer is then moved to ``dev``, its prunable sub-layers (as
    returned by ``find_layers``) are wrapped in ``SparseGPT_LlaMA`` objects and
    fed calibration data through forward hooks. Sub-layers other than the
    three MLP projections are pruned once with ``fasterprune``. The MLP
    projections (``mlp.up_proj``, ``mlp.gate_proj``, ``mlp.down_proj``) are
    pruned inside a fixed-length alternating loop that also updates three
    auxiliary tensors (``z``: up_proj output, ``s``: gate_proj output,
    ``p``: down_proj input).

    Args:
        model: Model exposing ``model.layers``, ``model.embed_tokens``,
            ``model.norm``, ``config`` and ``seqlen``.
        dataloader: Iterable of batches; element ``[0]`` of each batch is fed
            to the model. Presumably yields ``args.nsamples`` batches, since
            the capture buffer has that many slots -- verify against caller.
        dev: Device used for the active layer and the captured activations.
        args: Namespace read for ``nsamples``, ``true_sequential``,
            ``minlayer``, ``maxlayer``, ``prune_only``, ``invert``, ``wbits``,
            ``sparsity``, ``prunen``, ``prunem``, ``percdamp`` and
            ``blocksize``.

    Returns:
        None. ``model`` is modified in place and ``model.config.use_cache`` is
        restored on normal completion.
    """
    print("Starting...")
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)
    dtype = next(iter(model.parameters())).dtype
    # One slot per calibration sample for the hidden states entering layer 0.
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        # Stand-in for layer 0: records its input and attention mask, then
        # aborts the forward pass by raising ValueError.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            # Expected: raised by Catcher once the input has been recorded.
            pass
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()
    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]
    print("Ready.")
    for i in range(len(layers)):
        layer = layers[i].to(dev)
        full = find_layers(layer)
        if args.true_sequential:
            sequential = [
                ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
                ["self_attn.o_proj"],
                ["mlp.up_proj", "mlp.gate_proj"],
                ["mlp.down_proj"],
            ]
        else:
            sequential = [list(full.keys())]
        # NOTE(review): the code below indexes gpts with all three MLP names
        # for every group, so it appears to require the single-group
        # (non-true_sequential) case where every sub-layer is in `subset` --
        # confirm before enabling true_sequential.
        for names in sequential:
            subset = {n: full[n] for n in names}
            gpts = {}
            for name in subset:
                # Skip sub-layers excluded by the minlayer/maxlayer/prune_only
                # filter (the sense of the filter is flipped by args.invert).
                if (
                    not (args.minlayer <= i < args.maxlayer and args.prune_only in name)
                ) == (not args.invert):
                    continue
                gpts[name] = SparseGPT_LlaMA(subset[name])
                if args.wbits < 16:
                    gpts[name].quantizer = Quantizer()
                    gpts[name].quantizer.configure(
                        args.wbits, perchannel=True, sym=False, mse=False
                    )

            def add_batch(name):
                # Build a forward hook that forwards a sub-layer's input/output
                # to the matching SparseGPT_LlaMA object.
                def tmp(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data, name)
                return tmp

            handles = []
            # NOTE(review): hooks are registered for every name in `subset`,
            # not only those present in `gpts`; a filtered-out sub-layer would
            # make the hook's gpts[name] lookup fail -- verify.
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            # Run the calibration samples through the dense layer so the hooks
            # can collect statistics.
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            for h in handles:
                h.remove()
            target_layer_names = ["mlp.up_proj", "mlp.gate_proj", "mlp.down_proj"]
            # Everything except the three MLP projections is pruned once.
            for name in subset:
                if name not in target_layer_names:
                    print(i, name)
                    print("Pruning ...")
                    sparsity = args.sparsity
                    gpts[name].fasterprune(
                        sparsity,
                        prunen=args.prunen,
                        prunem=args.prunem,
                        percdamp=args.percdamp,
                        blocksize=args.blocksize,
                    )
                    gpts[name].free()
            # Weights of the penalty terms used in the updates below.
            alpha = 5.0
            beta = 5.0
            gamma = 5.0
            # Number of alternating update rounds for the MLP projections.
            opt_epochs = 8
            # NOTE(review): batch_inp / batch_out are attributes of
            # SparseGPT_LlaMA, presumably lists of per-sample tensors recorded
            # by add_batch -- confirm in pruning_utils.
            X_list = gpts['mlp.up_proj'].batch_inp
            Y_list = gpts['mlp.down_proj'].batch_out
            X = torch.stack(X_list, dim=0)
            Y = torch.stack(Y_list, dim=0)
            # Flatten to 2-D and transpose so that columns index tokens.
            X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T
            # Drop the list references once stacked to release memory.
            X_list, Y_list = None, None
            gpts['mlp.up_proj'].batch_inp.clear()
            gpts['mlp.down_proj'].batch_out.clear()
            hidden_z_list = gpts['mlp.up_proj'].batch_out
            z = torch.stack(hidden_z_list, dim=0)
            hidden_z_list = None
            gpts['mlp.up_proj'].batch_out.clear()
            hidden_p_list = gpts['mlp.down_proj'].batch_inp
            p = torch.stack(hidden_p_list, dim=0)
            hidden_p_list = None
            gpts['mlp.down_proj'].batch_inp.clear()
            hidden_s_list = gpts['mlp.gate_proj'].batch_out
            s = torch.stack(hidden_s_list, dim=0)
            hidden_s_list = None
            gpts['mlp.gate_proj'].batch_out.clear()
            # Auxiliary variables, same (features, tokens) layout as X and Y.
            z = z.reshape((-1, z.size(-1))).T.to(dev)
            p = p.reshape((-1, p.size(-1))).T.to(dev)
            s = s.reshape((-1, s.size(-1))).T.to(dev)
            torch.cuda.empty_cache()
            # X never changes inside the loop, so its pseudo-inverse is
            # computed once (in float32, then cast to half).
            Xinv = torch.pinverse(X.to(dtype=torch.float32)).half()
            # Only 'train_loss' is appended to below; the dict is not read or
            # returned anywhere in this function.
            training_loss = {'Y_p_loss': [], 'p_z_loss': [], 'z_X_loss': [], 'train_loss': []}
            for opt_step in range(opt_epochs):
                # --- weight update (skipped on the first round) ---
                if opt_step > 0:
                    # up_proj: solve W @ X = z via the cached pseudo-inverse.
                    weight_matrix_1 = torch.matmul(z, Xinv)
                    gpts['mlp.up_proj'].layer.weight.copy_(weight_matrix_1)
                    del weight_matrix_1
                    # down_proj: solve W @ p = Y; p changes every round so its
                    # pseudo-inverse is recomputed.
                    pinv = torch.pinverse(p.to(dtype=torch.float32)).half()
                    weight_matrix_2 = torch.matmul(Y, pinv)
                    gpts['mlp.down_proj'].layer.weight.copy_(weight_matrix_2)
                    del weight_matrix_2, pinv
                    # gate_proj: solve W @ X = s.
                    weight_matrix_3 = torch.matmul(s, Xinv)
                    gpts['mlp.gate_proj'].layer.weight.copy_(weight_matrix_3)
                    del weight_matrix_3
                    torch.cuda.empty_cache()
                # --- rebuild down_proj's H from the current p (skipped on the first round) ---
                if opt_step > 0:
                    tmp_H = torch.zeros_like(gpts['mlp.down_proj'].H)
                    tmp_p = p.T.reshape((args.nsamples, -1, p.size(0)))
                    tmp_nsamples = 0
                    for j in range(args.nsamples):
                        tmp_inp = tmp_p[j].unsqueeze(0)
                        tmp = tmp_inp.shape[0]
                        # NOTE(review): `transformers` is not imported by name
                        # in the visible imports; presumably it arrives through
                        # one of the star imports. It is only evaluated when
                        # the layer is not an nn.Linear -- verify.
                        if isinstance(gpts['mlp.down_proj'].layer, nn.Linear) or isinstance(gpts['mlp.down_proj'].layer, transformers.Conv1D):
                            if len(tmp_inp.shape) == 3:
                                tmp_inp = tmp_inp.reshape((-1, tmp_inp.shape[-1]))
                            tmp_inp = tmp_inp.t()
                        # Running average over samples of 2 * inp @ inp.T.
                        tmp_H *= tmp_nsamples / (tmp_nsamples + tmp)
                        tmp_nsamples += tmp
                        tmp_inp = math.sqrt(2 / tmp_nsamples) * tmp_inp.float()
                        tmp_H += tmp_inp.matmul(tmp_inp.t())
                    gpts['mlp.down_proj'].H.copy_(tmp_H)
                    del tmp_H, tmp_p
                    torch.cuda.empty_cache()
                # --- prune the three MLP projections ---
                for name in target_layer_names:
                    print(i, name)
                    print('Pruning ...')
                    sparsity = args.sparsity
                    gpts[name].fasterprune(
                        sparsity,
                        prunen=args.prunen,
                        prunem=args.prunem,
                        percdamp=args.percdamp,
                        blocksize=args.blocksize,
                    )
                # --- p update: p = (beta*W^T W + gamma*I)^-1 (beta*W^T Y + gamma*silu(s)*z) ---
                next_weight = subset['mlp.down_proj'].weight
                m1 = beta * torch.matmul(next_weight.T, next_weight)
                m2 = gamma * torch.eye(m1.shape[0], device=m1.device)
                av = torch.inverse(m1 + m2).to(dtype=torch.float16)
                del m1, m2
                torch.cuda.empty_cache()
                layer_nl_output = nn.functional.silu(s) * z
                m3 = beta * torch.matmul(next_weight.T, Y)
                m4 = gamma * layer_nl_output
                af = m3 + m4
                p = torch.matmul(av, af)
                del layer_nl_output, next_weight, av, m3, m4, af
                torch.cuda.empty_cache()
                # --- z update (closed form, element-wise) ---
                w = subset['mlp.up_proj'].weight
                m = torch.matmul(w, X)
                swish = nn.functional.silu(s)
                # NOTE(review): alpha and gamma do not appear in this formula;
                # it looks like the minimiser of
                # alpha*(z - m)^2 + gamma*(p - swish*z)^2 only when
                # alpha == gamma (both are 5.0 above) -- confirm before
                # changing either constant.
                z = (m + swish * p) / (swish ** 2 + 1)
                del w, m, swish
                torch.cuda.empty_cache()
                # --- s update: a few gradient-descent steps, done per column batch ---
                w = subset['mlp.gate_proj'].weight
                w = w.to(dtype=torch.float32).requires_grad_(True)
                s_update_epochs = 2
                s_learning_rate = 0.01
                for _ in range(s_update_epochs):
                    batch_size = 1000
                    for k in range(0, s.size(-1), batch_size):
                        chunk = slice(k, k + batch_size)
                        X_batch = X[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        z_batch = z[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        p_batch = p[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        s_batch = s[:,chunk].to(dtype=torch.float32).requires_grad_(True)
                        # The function runs under no_grad, so autograd is
                        # re-enabled just for building the loss.
                        with torch.enable_grad():
                            loss_s = alpha * torch.norm(s_batch - torch.matmul(w, X_batch))**2
                            loss_s += gamma * torch.norm(p_batch - nn.functional.silu(s_batch) * z_batch)**2
                        loss_s.backward()
                        # Only the gradient w.r.t. s_batch is used; the step
                        # is applied in place outside the enable_grad block.
                        s_batch -= s_learning_rate * s_batch.grad
                        s_batch.grad.zero_()
                        s[:,chunk] = s_batch.detach().to(dtype=torch.float16)
                # Detach and drop the float32 temporaries of the s update.
                s_batch, X_batch, z_batch, p_batch, w = s_batch.detach(), X_batch.detach(), z_batch.detach(), p_batch.detach(), w.detach()
                del w, loss_s, s_batch, X_batch, z_batch, p_batch
                torch.cuda.empty_cache()
                # Reconstruction error of the MLP with the current pruned
                # weights: down(silu(gate @ X) * (up @ X)) against Y.
                tmp_training_loss = nn.functional.mse_loss(torch.matmul(subset['mlp.down_proj'].weight,
                                                                        nn.functional.silu(torch.matmul(subset['mlp.gate_proj'].weight, X))
                                                                        * torch.matmul(subset['mlp.up_proj'].weight, X)), Y)
                training_loss['train_loss'].append(tmp_training_loss.item())
        # Recompute the layer outputs with the pruned weights; they become the
        # inputs of the next layer.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        del gpts
        torch.cuda.empty_cache()
        inps, outs = outs, inps
    model.config.use_cache = use_cache
@torch.no_grad()
def opt_eval(model, testenc, dev, args, dataset: str):
    """Print the perplexity of an OPT-style model on a tokenized corpus.

    The model is run one decoder layer at a time so that only a single layer
    (plus the embeddings or the output head) has to be on ``dev`` at once.

    Args:
        model: Model exposing ``model.decoder`` (``embed_tokens``,
            ``embed_positions``, ``layers``, ``final_layer_norm``,
            ``project_out`` and optionally ``project_in``), ``lm_head``,
            ``config`` and ``seqlen``.
        testenc: Object whose ``input_ids`` attribute holds the token ids. It
            is cut into ``numel() // model.seqlen`` consecutive windows of
            ``model.seqlen`` tokens along dimension 1; leftover tokens are
            ignored.
        dev: Device the active layer and the activations are moved to.
        args: Namespace; when ``args.gmp`` is truthy every sub-layer returned
            by ``find_layers`` is magnitude-pruned before the layer is run,
            with ``args.sparsity`` as the fraction of weights zeroed.
        dataset: Not used by this function.

    Returns:
        None. The perplexity is printed.

    ``model.config.use_cache`` and the temporarily wrapped first layer are
    restored even if evaluation raises.
    """
    print('Evaluating ...')

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        decoder = model.model.decoder
        layers = decoder.layers

        # Only the embedding-side modules and the first layer are needed on
        # the device to capture the inputs of layer 0.
        decoder.embed_tokens = decoder.embed_tokens.to(dev)
        decoder.embed_positions = decoder.embed_positions.to(dev)
        if hasattr(decoder, 'project_out') and decoder.project_out:
            decoder.project_out = decoder.project_out.to(dev)
        if hasattr(decoder, 'project_in') and decoder.project_in:
            decoder.project_in = decoder.project_in.to(dev)
        layers[0] = layers[0].to(dev)

        dtype = next(iter(model.parameters())).dtype
        # One slot per evaluation window for the hidden states entering layer 0.
        inps = torch.zeros(
            (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
        )
        cache = {'i': 0, 'attention_mask': None}

        class Catcher(nn.Module):
            # Stand-in for layer 0: records its input and attention mask, then
            # aborts the forward pass by raising ValueError.
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps[cache['i']] = inp
                cache['i'] += 1
                cache['attention_mask'] = kwargs['attention_mask']
                raise ValueError

        layers[0] = Catcher(layers[0])
        try:
            for i in range(nsamples):
                batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
                try:
                    model(batch)
                except ValueError:
                    # Expected: raised by Catcher once the input is recorded.
                    pass
        finally:
            # Always put the real layer back, even on an unexpected error.
            layers[0] = layers[0].module

        layers[0] = layers[0].cpu()
        decoder.embed_tokens = decoder.embed_tokens.cpu()
        decoder.embed_positions = decoder.embed_positions.cpu()
        if hasattr(decoder, 'project_out') and decoder.project_out:
            decoder.project_out = decoder.project_out.cpu()
        if hasattr(decoder, 'project_in') and decoder.project_in:
            decoder.project_in = decoder.project_in.cpu()
        torch.cuda.empty_cache()

        outs = torch.zeros_like(inps)
        attention_mask = cache['attention_mask']

        for i in range(len(layers)):
            print(i)
            layer = layers[i].to(dev)

            if args.gmp:
                # Global magnitude pruning baseline: zero every weight whose
                # magnitude is at or below the sparsity-quantile threshold.
                subset = find_layers(layer)
                for name in subset:
                    W = subset[name].weight.data
                    thresh = torch.sort(torch.abs(W.flatten()))[0][int(W.numel() * args.sparsity)]
                    W.data[torch.abs(W.data) <= thresh] = 0

            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()
            # This layer's outputs are the next layer's inputs.
            inps, outs = outs, inps

        if decoder.final_layer_norm is not None:
            decoder.final_layer_norm = decoder.final_layer_norm.to(dev)
        if decoder.project_out is not None:
            decoder.project_out = decoder.project_out.to(dev)
        model.lm_head = model.lm_head.to(dev)

        testenc = testenc.to(dev)
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in range(nsamples):
            hidden_states = inps[i].unsqueeze(0)
            if decoder.final_layer_norm is not None:
                hidden_states = decoder.final_layer_norm(hidden_states)
            if decoder.project_out is not None:
                hidden_states = decoder.project_out(hidden_states)
            lm_logits = model.lm_head(hidden_states)
            # Next-token prediction: logits at position t score token t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = testenc[
                :, (i * model.seqlen):((i + 1) * model.seqlen)
            ][:, 1:]
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            neg_log_likelihood = loss.float() * model.seqlen
            nlls.append(neg_log_likelihood)
        ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
        print(f"Perplexity: {ppl.item():3f}")
    finally:
        model.config.use_cache = use_cache
@torch.no_grad()
def llama_eval(model, testenc, dev, args, dataset: str):
    """Print the perplexity of a LLaMA-style model on a tokenized corpus.

    The model is run one decoder layer at a time so that only a single layer
    (plus the embeddings or the output head) has to be on ``dev`` at once.

    Args:
        model: Model exposing ``model.embed_tokens``, ``model.layers``,
            ``model.norm``, ``lm_head``, ``config`` and ``seqlen``.
        testenc: Object whose ``input_ids`` attribute holds the token ids. It
            is cut into ``numel() // model.seqlen`` consecutive windows of
            ``model.seqlen`` tokens along dimension 1; leftover tokens are
            ignored.
        dev: Device the active layer and the activations are moved to.
        args: Namespace; when ``args.gmp`` is truthy every sub-layer returned
            by ``find_layers`` is magnitude-pruned before the layer is run,
            with ``args.sparsity`` as the fraction of weights zeroed.
        dataset: Not used by this function.

    Returns:
        None. The perplexity is printed.

    ``model.config.use_cache`` and the temporarily wrapped first layer are
    restored even if evaluation raises.
    """
    print("Evaluating ...")

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        backbone = model.model
        layers = backbone.layers

        # Only the token embedding and the first layer are needed on the
        # device to capture the inputs of layer 0.
        backbone.embed_tokens = backbone.embed_tokens.to(dev)
        layers[0] = layers[0].to(dev)

        dtype = next(iter(model.parameters())).dtype
        # One slot per evaluation window for the hidden states entering layer 0.
        inps = torch.zeros(
            (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
        )
        cache = {"i": 0, "attention_mask": None}

        class Catcher(nn.Module):
            # Stand-in for layer 0: records its input and attention mask, then
            # aborts the forward pass by raising ValueError.
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps[cache["i"]] = inp
                cache["i"] += 1
                cache["attention_mask"] = kwargs["attention_mask"]
                raise ValueError

        layers[0] = Catcher(layers[0])
        try:
            for i in range(nsamples):
                batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev)
                try:
                    model(batch)
                except ValueError:
                    # Expected: raised by Catcher once the input is recorded.
                    pass
        finally:
            # Always put the real layer back, even on an unexpected error.
            layers[0] = layers[0].module

        layers[0] = layers[0].cpu()
        backbone.embed_tokens = backbone.embed_tokens.cpu()
        torch.cuda.empty_cache()

        outs = torch.zeros_like(inps)
        attention_mask = cache["attention_mask"]

        for i in range(len(layers)):
            print(i)
            layer = layers[i].to(dev)

            if args.gmp:
                # Global magnitude pruning baseline: zero every weight whose
                # magnitude is at or below the sparsity-quantile threshold.
                subset = find_layers(layer)
                for name in subset:
                    W = subset[name].weight.data
                    thresh = torch.sort(torch.abs(W.flatten()))[0][
                        int(W.numel() * args.sparsity)
                    ]
                    W.data[torch.abs(W.data) <= thresh] = 0

            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()
            # This layer's outputs are the next layer's inputs.
            inps, outs = outs, inps

        if backbone.norm is not None:
            backbone.norm = backbone.norm.to(dev)
        model.lm_head = model.lm_head.to(dev)

        testenc = testenc.to(dev)
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in range(nsamples):
            hidden_states = inps[i].unsqueeze(0)
            if backbone.norm is not None:
                hidden_states = backbone.norm(hidden_states)
            lm_logits = model.lm_head(hidden_states)
            # Next-token prediction: logits at position t score token t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * model.seqlen
            nlls.append(neg_log_likelihood)
        ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
        print(f"Perplexity: {ppl.item():3f}")
    finally:
        model.config.use_cache = use_cache