File size: 31,972 Bytes
6815477 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 |
# Model utilities: loading OPT/LLaMA models, SparseLLM pruning, and perplexity evaluation.
import torch
import torch.nn as nn
from pruning_utils import *
from quant import *
import math
from transformers import OPTForCausalLM, LlamaForCausalLM
def get_opt(args):
    """Load an OPT causal LM from ``args.model`` with ``torch_dtype='auto'``.

    ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
    replaced by no-ops while the checkpoint is loaded (the checkpoint supplies
    the weights). The original functions are restored afterwards, even if
    loading fails, so modules built later in the process are initialized
    normally.

    Returns the model with an added ``seqlen`` attribute equal to
    ``config.max_position_embeddings``.
    """
    def skip(*args, **kwargs):
        pass

    patched = ("kaiming_uniform_", "uniform_", "normal_")
    saved = {fn_name: getattr(torch.nn.init, fn_name) for fn_name in patched}
    for fn_name in patched:
        setattr(torch.nn.init, fn_name, skip)
    try:
        model = OPTForCausalLM.from_pretrained(args.model, torch_dtype='auto')
    finally:
        # Undo the global monkeypatch so it does not leak beyond model loading.
        for fn_name, fn in saved.items():
            setattr(torch.nn.init, fn_name, fn)
    model.seqlen = model.config.max_position_embeddings
    return model
def get_llama(args):
    """Load a LLaMA causal LM from ``args.model`` with ``torch_dtype='auto'``.

    ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
    replaced by no-ops while the checkpoint is loaded (the checkpoint supplies
    the weights). The original functions are restored afterwards, even if
    loading fails.

    Returns the model with an added ``seqlen`` attribute fixed at 2048.
    """
    def skip(*args, **kwargs):
        pass

    patched = ("kaiming_uniform_", "uniform_", "normal_")
    saved = {fn_name: getattr(torch.nn.init, fn_name) for fn_name in patched}
    for fn_name in patched:
        setattr(torch.nn.init, fn_name, skip)
    try:
        model = LlamaForCausalLM.from_pretrained(args.model, torch_dtype='auto')
    finally:
        # Undo the global monkeypatch so it does not leak beyond model loading.
        for fn_name, fn in saved.items():
            setattr(torch.nn.init, fn_name, fn)
    model.seqlen = 2048
    return model
@torch.no_grad()
def _opt_sparsellm_mlp(gpts, subset, layer_idx, dev, dtype, args, alpha, beta, gamma, opt_epochs):
    """Jointly prune ``fc1``/``fc2`` of one OPT decoder layer with SparseLLM.

    ``gpts['fc1']`` and ``gpts['fc2']`` must hold the calibration activations
    recorded by their forward hooks (``batch_inp`` / ``batch_out``). With the
    MLP input ``X`` and output ``Y`` fixed, the auxiliary variables ``z``
    (fc1 output, before ReLU) and ``p`` (fc2 input) are optimized alternately
    with the weights for ``opt_epochs`` epochs. Each epoch re-prunes both
    layers. Both pruners are freed on return.
    """
    fc1, fc2 = gpts['fc1'], gpts['fc2']

    # X (fc1 input) and Y (fc2 output) are constants; flatten to (features, tokens).
    X = torch.stack(fc1.batch_inp, dim=0)
    Y = torch.stack(fc2.batch_out, dim=0)
    X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T
    fc1.batch_inp.clear()
    fc2.batch_out.clear()

    # Initialize the auxiliary variables from the recorded activations.
    z = torch.stack(fc1.batch_out, dim=0)
    fc1.batch_out.clear()
    p = torch.stack(fc2.batch_inp, dim=0)
    fc2.batch_inp.clear()
    z = z.reshape((-1, z.size(-1))).T.to(dev)
    p = p.reshape((-1, p.size(-1))).T.to(dev)
    torch.cuda.empty_cache()

    # X never changes, so its pseudo-inverse is computed once in float32.
    # The result is cast back to the model dtype: a fixed fp16 cast would make
    # the matmuls below fail on bf16 / fp32 models.
    Xinv = torch.pinverse(X.to(dtype=torch.float32)).to(dtype)

    chunk_size = 500  # rows per chunk in the z update; bounds temporary memory

    for opt_step in range(opt_epochs):
        # ---- optimize W: least-squares fit of the weights to (X, z) and (p, Y) ----
        # Skipped on the first step, where the weights are still the originals.
        if opt_step > 0:
            bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1))
            fc1.layer.weight.copy_(torch.matmul(z - bias, Xinv))
            del bias
            pinv = torch.pinverse(p.to(dtype=torch.float32)).to(dtype)
            bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1))
            fc2.layer.weight.copy_(torch.matmul(Y - bias, pinv))
            del bias, pinv
            torch.cuda.empty_cache()

        # ---- prune W ----
        if opt_step > 0:
            # fc2's input is now p, so its Hessian is rebuilt from p one
            # calibration sample at a time: H = (2 / n) * sum_j p_j p_j^T.
            H = torch.zeros_like(fc2.H)
            p_samples = p.T.reshape((args.nsamples, -1, p.size(0)))
            seen = 0
            for j in range(args.nsamples):
                inp = p_samples[j].unsqueeze(0)
                batch = inp.shape[0]
                if isinstance(fc2.layer, nn.Linear) or isinstance(fc2.layer, transformers.Conv1D):
                    if len(inp.shape) == 3:
                        inp = inp.reshape((-1, inp.shape[-1]))
                    inp = inp.t()
                H *= seen / (seen + batch)
                seen += batch
                inp = math.sqrt(2 / seen) * inp.float()
                H += inp.matmul(inp.t())
            fc2.H.copy_(H)
            del H, p_samples
            torch.cuda.empty_cache()

        for name in ('fc1', 'fc2'):
            print(layer_idx, name)
            print('Pruning ...')
            gpts[name].fasterprune(
                args.sparsity,
                prunen=args.prunen,
                prunem=args.prunem,
                percdamp=args.percdamp,
                blocksize=args.blocksize,
            )

        # ---- optimize p (closed form) ----
        # p = (beta W2^T W2 + gamma I)^-1 (beta W2^T (Y - b2) + gamma relu(z))
        next_weight = subset['fc2'].weight
        m1 = beta * torch.matmul(next_weight.T, next_weight)
        m2 = gamma * torch.eye(m1.shape[0], device=m1.device)
        av = torch.inverse(m1 + m2).to(dtype=dtype)
        del m1, m2
        torch.cuda.empty_cache()
        bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1))
        af = beta * torch.matmul(next_weight.T, Y - bias) + gamma * nn.functional.relu(z)
        p = torch.matmul(av, af)
        del next_weight, av, af, bias
        torch.cuda.empty_cache()

        # ---- optimize z (closed form per ReLU branch) ----
        # Elementwise minimize gamma*(p - relu(z))^2 + alpha*(z - m), with m = W1 X + b1.
        # z1 is the best candidate with z >= 0 and z2 the best with z <= 0;
        # each entry takes whichever candidate scores lower.
        bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1))
        m = torch.matmul(subset['fc1'].weight, X) + bias
        del bias
        sol1 = (gamma * p + alpha * m) / (gamma + alpha)
        sol2 = m
        torch.cuda.empty_cache()
        z1 = torch.zeros_like(p)
        z2 = torch.zeros_like(p)
        for k in range(0, sol1.size(0), chunk_size):
            chunk = slice(k, k + chunk_size)
            s1, s2 = sol1[chunk], sol2[chunk]
            nonneg = s1 >= 0.
            nonpos = s2 <= 0.
            # z1[chunk] / z2[chunk] are views, so masked assignment writes through.
            z1[chunk][nonneg] = s1[nonneg]
            z2[chunk][nonpos] = s2[nonpos]
        del sol1, sol2
        torch.cuda.empty_cache()
        for k in range(0, z1.size(0), chunk_size):
            chunk = slice(k, k + chunk_size)
            fz_1 = gamma * torch.square(p[chunk] - nn.functional.relu(z1[chunk])) + alpha * torch.square(z1[chunk] - m[chunk])
            fz_2 = gamma * torch.square(p[chunk] - nn.functional.relu(z2[chunk])) + alpha * torch.square(z2[chunk] - m[chunk])
            use_z1 = fz_1 <= fz_2
            use_z2 = fz_2 < fz_1
            z[chunk][use_z1] = z1[chunk][use_z1]
            z[chunk][use_z2] = z2[chunk][use_z2]
        del z1, z2, m
        torch.cuda.empty_cache()

    for name in ('fc1', 'fc2'):
        gpts[name].free()


@torch.no_grad()
def opt_sparsellm(model, dataloader, dev, args):
    """Prune an OPT model in place, one decoder layer at a time, with SparseLLM.

    The inputs to the first decoder layer are captured for every calibration
    batch. Each layer is then moved to ``dev`` in turn. Its selected
    sub-layers are calibrated with forward hooks and pruned:

    - Sub-layers other than ``fc1``/``fc2`` are pruned independently with
      SparseGPT.
    - ``fc1``/``fc2`` are pruned jointly by ``_opt_sparsellm_mlp``.

    If the layer filters select only one of ``fc1``/``fc2``, that sub-layer
    falls back to plain SparseGPT pruning. The pruned layer's outputs become
    the next layer's inputs.

    Args:
        model: OPT causal LM with a ``seqlen`` attribute.
        dataloader: iterable of calibration batches; ``batch[0]`` holds token ids.
        dev: device used for computation.
        args: reads ``nsamples``, ``minlayer``, ``maxlayer``, ``prune_only``,
            ``invert``, ``wbits``, ``sparsity``, ``prunen``, ``prunem``,
            ``percdamp`` and ``blocksize``. Optionally reads ``alpha``, ``beta``
            and ``gamma`` (default 5.0 each) and ``opt_epochs`` (default 10).
    """
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers

    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        # Records the first decoder layer's input and aborts the forward pass.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass  # raised on purpose by Catcher once the input is recorded
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    print('Ready.')

    # SparseLLM penalty weights and epoch count (overridable through args).
    alpha = getattr(args, 'alpha', 5.0)
    beta = getattr(args, 'beta', 5.0)
    gamma = getattr(args, 'gamma', 5.0)
    opt_epochs = getattr(args, 'opt_epochs', 10)
    target_layer_names = ['fc1', 'fc2']

    for i in range(len(layers)):
        layer = layers[i].to(dev)
        subset = find_layers(layer)

        gpts = {}
        for name in subset:
            if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert):
                continue
            gpts[name] = SparseGPT_OPT(subset[name])
            if args.wbits < 16:
                gpts[name].quantizer = Quantizer()
                gpts[name].quantizer.configure(
                    args.wbits, perchannel=True, sym=False, mse=False
                )

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data, name)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        for h in handles:
            h.remove()

        # The joint MLP optimization needs both fc1 and fc2. If the layer
        # filters excluded either one, prune whatever was selected with plain
        # SparseGPT instead of failing with a KeyError.
        joint_mlp = all(name in gpts for name in target_layer_names)

        for name in gpts:
            if joint_mlp and name in target_layer_names:
                continue  # pruned jointly below
            print(i, name)
            print('Pruning ...')
            gpts[name].fasterprune(
                args.sparsity,
                prunen=args.prunen,
                prunem=args.prunem,
                percdamp=args.percdamp,
                blocksize=args.blocksize,
            )
            gpts[name].free()

        if joint_mlp:
            _opt_sparsellm_mlp(gpts, subset, i, dev, dtype, args, alpha, beta, gamma, opt_epochs)

        # Outputs of the pruned layer feed the next layer.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
@torch.no_grad()
def _llama_sparsellm_mlp(gpts, full, layer_idx, dev, dtype, args, alpha, beta, gamma, opt_epochs):
    """Jointly prune the MLP (up/gate/down projections) of one LLaMA layer with SparseLLM.

    Auxiliary variables:

    - ``z``: up_proj output.
    - ``s``: gate_proj output.
    - ``p``: down_proj input, modelled as ``silu(s) * z``.

    With the MLP input ``X`` and output ``Y`` fixed, each epoch does:

    1. least-squares weight updates;
    2. SparseGPT pruning of the three projections, with down_proj's Hessian
       rebuilt from ``p``;
    3. a closed-form update of ``p``;
    4. a closed-form update of ``z``;
    5. a few gradient steps on ``s``.

    ``gpts`` must hold the three pruners with their recorded
    ``batch_inp``/``batch_out``. The pruners are freed on return. If
    ``args.log_train_loss`` is truthy, the MLP reconstruction MSE is printed
    after every epoch.
    """
    up = gpts["mlp.up_proj"]
    gate = gpts["mlp.gate_proj"]
    down = gpts["mlp.down_proj"]

    # X (MLP input) and Y (down_proj output) are constants; flatten to (features, tokens).
    X = torch.stack(up.batch_inp, dim=0)
    Y = torch.stack(down.batch_out, dim=0)
    X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T
    up.batch_inp.clear()
    down.batch_out.clear()

    # Initialize the auxiliary variables from the recorded activations.
    z = torch.stack(up.batch_out, dim=0)
    up.batch_out.clear()
    p = torch.stack(down.batch_inp, dim=0)
    down.batch_inp.clear()
    s = torch.stack(gate.batch_out, dim=0)
    gate.batch_out.clear()
    z = z.reshape((-1, z.size(-1))).T.to(dev)
    p = p.reshape((-1, p.size(-1))).T.to(dev)
    s = s.reshape((-1, s.size(-1))).T.to(dev)
    torch.cuda.empty_cache()

    # X never changes, so its pseudo-inverse is computed once in float32 and
    # cast back to the model dtype (fp16 was hard-coded, breaking bf16 models).
    Xinv = torch.pinverse(X.to(dtype=torch.float32)).to(dtype)

    ratio = alpha / gamma  # weight of the z ~ W_up X term in the z update
    s_update_epochs = 2
    s_learning_rate = 0.01
    s_batch_size = 1000  # tokens per gradient step; bounds float32 temporaries
    log_train_loss = getattr(args, "log_train_loss", False)

    for opt_step in range(opt_epochs):
        # ---- optimize W: least-squares fit to the auxiliary variables ----
        # Skipped on the first step, where the weights are still the originals.
        if opt_step > 0:
            up.layer.weight.copy_(torch.matmul(z, Xinv))
            pinv = torch.pinverse(p.to(dtype=torch.float32)).to(dtype)
            down.layer.weight.copy_(torch.matmul(Y, pinv))
            del pinv
            gate.layer.weight.copy_(torch.matmul(s, Xinv))
            torch.cuda.empty_cache()

        # ---- prune W ----
        if opt_step > 0:
            # down_proj's input is now p, so its Hessian is rebuilt from p one
            # calibration sample at a time: H = (2 / n) * sum_j p_j p_j^T.
            H = torch.zeros_like(down.H)
            p_samples = p.T.reshape((args.nsamples, -1, p.size(0)))
            seen = 0
            for j in range(args.nsamples):
                inp = p_samples[j].unsqueeze(0)
                batch = inp.shape[0]
                if isinstance(down.layer, nn.Linear) or isinstance(down.layer, transformers.Conv1D):
                    if len(inp.shape) == 3:
                        inp = inp.reshape((-1, inp.shape[-1]))
                    inp = inp.t()
                H *= seen / (seen + batch)
                seen += batch
                inp = math.sqrt(2 / seen) * inp.float()
                H += inp.matmul(inp.t())
            down.H.copy_(H)
            del H, p_samples
            torch.cuda.empty_cache()

        for name in ("mlp.up_proj", "mlp.gate_proj", "mlp.down_proj"):
            print(layer_idx, name)
            print("Pruning ...")
            gpts[name].fasterprune(
                args.sparsity,
                prunen=args.prunen,
                prunem=args.prunem,
                percdamp=args.percdamp,
                blocksize=args.blocksize,
            )

        # ---- optimize p (closed form) ----
        # p = (beta Wd^T Wd + gamma I)^-1 (beta Wd^T Y + gamma silu(s) * z)
        next_weight = full["mlp.down_proj"].weight
        m1 = beta * torch.matmul(next_weight.T, next_weight)
        m2 = gamma * torch.eye(m1.shape[0], device=m1.device)
        av = torch.inverse(m1 + m2).to(dtype=dtype)
        del m1, m2
        torch.cuda.empty_cache()
        af = beta * torch.matmul(next_weight.T, Y) + gamma * (nn.functional.silu(s) * z)
        p = torch.matmul(av, af)
        del next_weight, av, af
        torch.cuda.empty_cache()

        # ---- optimize z (closed form) ----
        # Elementwise minimize alpha*(z - m)^2 + gamma*(p - swish*z)^2, with m = W_up X:
        #   z = (alpha m + gamma swish p) / (alpha + gamma swish^2).
        # Written with ratio = alpha / gamma so that, for the default
        # alpha == gamma, this is exactly the previous (m + swish p) / (swish^2 + 1).
        m = torch.matmul(full["mlp.up_proj"].weight, X)
        swish = nn.functional.silu(s)
        z = (ratio * m + swish * p) / (swish ** 2 + ratio)
        del m, swish
        torch.cuda.empty_cache()

        # ---- optimize s (gradient descent) ----
        # Minimize alpha*||s - W_gate X||^2 + gamma*||p - silu(s) * z||^2 over s.
        # Only s needs a gradient; the weight and the other activations are
        # constants, so no autograd buffers are spent on them.
        w_gate = full["mlp.gate_proj"].weight.detach().to(dtype=torch.float32)
        for _ in range(s_update_epochs):
            for k in range(0, s.size(-1), s_batch_size):
                chunk = slice(k, k + s_batch_size)
                X_batch = X[:, chunk].to(dtype=torch.float32)
                z_batch = z[:, chunk].to(dtype=torch.float32)
                p_batch = p[:, chunk].to(dtype=torch.float32)
                s_batch = s[:, chunk].to(dtype=torch.float32, copy=True).requires_grad_(True)
                with torch.enable_grad():  # the enclosing function runs under no_grad
                    loss_s = alpha * torch.norm(s_batch - torch.matmul(w_gate, X_batch)) ** 2
                    loss_s += gamma * torch.norm(p_batch - nn.functional.silu(s_batch) * z_batch) ** 2
                    loss_s.backward()
                s[:, chunk] = (s_batch - s_learning_rate * s_batch.grad).to(dtype=s.dtype)
        del w_gate
        torch.cuda.empty_cache()

        if log_train_loss:
            # Full-data reconstruction error of the pruned MLP; opt-in because
            # it materializes several (intermediate_size x tokens) tensors.
            pred = torch.matmul(
                full["mlp.down_proj"].weight,
                nn.functional.silu(torch.matmul(full["mlp.gate_proj"].weight, X))
                * torch.matmul(full["mlp.up_proj"].weight, X),
            )
            print(layer_idx, "train_loss", nn.functional.mse_loss(pred, Y).item())
            del pred
            torch.cuda.empty_cache()

    for name in ("mlp.up_proj", "mlp.gate_proj", "mlp.down_proj"):
        gpts[name].free()


@torch.no_grad()
def llama_sparsellm(model, dataloader, dev, args):
    """Prune a LLaMA model in place, one decoder layer at a time, with SparseLLM.

    The inputs to the first decoder layer are captured for every calibration
    batch. Each layer is then calibrated with forward hooks, either all
    sub-layers at once or group by group when ``args.true_sequential`` is set,
    and pruned:

    - Attention projections are pruned independently with SparseGPT.
    - The MLP projections are pruned jointly by ``_llama_sparsellm_mlp`` once
      every group has been calibrated.

    If the layer filters do not select all three MLP projections, the
    selected ones are pruned with plain SparseGPT. The pruned layer's outputs
    become the next layer's inputs.

    Args:
        model: LLaMA causal LM with a ``seqlen`` attribute.
        dataloader: iterable of calibration batches; ``batch[0]`` holds token ids.
        dev: device used for computation.
        args: reads ``nsamples``, ``true_sequential``, ``minlayer``,
            ``maxlayer``, ``prune_only``, ``invert``, ``wbits``, ``sparsity``,
            ``prunen``, ``prunem``, ``percdamp`` and ``blocksize``.
            Optionally reads ``alpha``, ``beta`` and ``gamma`` (default 5.0
            each), ``opt_epochs`` (default 8) and ``log_train_loss``
            (default False).
    """
    print("Starting...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        # Records the first decoder layer's input and aborts the forward pass.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass  # raised on purpose by Catcher once the input is recorded
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]

    print("Ready.")

    # SparseLLM penalty weights and epoch count (overridable through args).
    alpha = getattr(args, "alpha", 5.0)
    beta = getattr(args, "beta", 5.0)
    gamma = getattr(args, "gamma", 5.0)
    opt_epochs = getattr(args, "opt_epochs", 8)
    target_layer_names = ["mlp.up_proj", "mlp.gate_proj", "mlp.down_proj"]

    def add_batch(gpt, name):
        # Forward hook feeding a sub-layer's input/output to its pruner.
        def tmp(_, inp, out):
            gpt.add_batch(inp[0].data, out.data, name)
        return tmp

    for i in range(len(layers)):
        layer = layers[i].to(dev)
        full = find_layers(layer)

        if args.true_sequential:
            sequential = [
                ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
                ["self_attn.o_proj"],
                ["mlp.up_proj", "mlp.gate_proj"],
                ["mlp.down_proj"],
            ]
        else:
            sequential = [list(full.keys())]

        # Sub-layers chosen by the layer-range / name filters.
        selected = {
            name
            for name in full
            if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) != (not args.invert)
        }
        joint_mlp = all(name in selected for name in target_layer_names)

        # Pruners of this layer, accumulated over all sequential groups so the
        # joint MLP step sees up/gate/down even when they were calibrated in
        # different groups.
        gpts = {}
        for names in sequential:
            subset = {n: full[n] for n in names}

            group = {}
            for name in subset:
                if name not in selected:
                    continue
                group[name] = SparseGPT_LlaMA(subset[name])
                if args.wbits < 16:
                    group[name].quantizer = Quantizer()
                    group[name].quantizer.configure(
                        args.wbits, perchannel=True, sym=False, mse=False
                    )

            # Hook only sub-layers that have a pruner; hooking skipped names
            # would raise a KeyError inside the forward pass.
            handles = [
                subset[name].register_forward_hook(add_batch(group[name], name))
                for name in group
            ]
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            for h in handles:
                h.remove()

            for name in group:
                if joint_mlp and name in target_layer_names:
                    continue  # pruned jointly once all groups are calibrated
                print(i, name)
                print("Pruning ...")
                group[name].fasterprune(
                    args.sparsity,
                    prunen=args.prunen,
                    prunem=args.prunem,
                    percdamp=args.percdamp,
                    blocksize=args.blocksize,
                )
                group[name].free()

            gpts.update(group)

        if joint_mlp:
            _llama_sparsellm_mlp(gpts, full, i, dev, dtype, args, alpha, beta, gamma, opt_epochs)

        # Outputs of the pruned layer feed the next layer.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

        layers[i] = layer.cpu()
        del layer
        del gpts
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
@torch.no_grad()
def opt_eval(model, testenc, dev, args, dataset: str):
    """Compute and print the perplexity of an OPT model on ``testenc``.

    ``testenc.input_ids`` is cut into ``numel() // model.seqlen``
    non-overlapping windows of ``model.seqlen`` tokens; trailing tokens are
    dropped. Decoder layers are moved to ``dev`` and run one at a time.

    If ``args.gmp`` is set, every sub-layer found by ``find_layers`` is
    magnitude-pruned in place before use: its ``int(numel * args.sparsity)``
    smallest-magnitude weights are zeroed (ties at the threshold are zeroed
    too).

    ``dataset`` is accepted for interface compatibility and is not used.
    """
    print('Evaluating ...')

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers

    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        # Records the first decoder layer's input and aborts the forward pass.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            pass  # raised on purpose by Catcher once the input is recorded
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.gmp:
            subset = find_layers(layer)
            for name in subset:
                W = subset[name].weight.data
                # Zero the k smallest-magnitude weights. Indexing the sorted
                # magnitudes at k (not k - 1) would zero one weight too many,
                # prune a weight at sparsity 0, and overflow at sparsity 1.0.
                k = int(W.numel() * args.sparsity)
                if k > 0:
                    thresh = torch.sort(torch.abs(W.flatten()))[0][k - 1]
                    W[torch.abs(W) <= thresh] = 0

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps

    if model.model.decoder.final_layer_norm is not None:
        model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
    if model.model.decoder.project_out is not None:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.decoder.final_layer_norm is not None:
            hidden_states = model.model.decoder.final_layer_norm(hidden_states)
        if model.model.decoder.project_out is not None:
            hidden_states = model.model.decoder.project_out(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Predict token t+1 from positions <= t within the window.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(f"Perplexity: {ppl.item():3f}")

    model.config.use_cache = use_cache
@torch.no_grad()
def llama_eval(model, testenc, dev, args, dataset: str):
    """Compute and print the perplexity of a LLaMA model on ``testenc``.

    ``testenc.input_ids`` is cut into ``numel() // model.seqlen``
    non-overlapping windows of ``model.seqlen`` tokens; trailing tokens are
    dropped. Decoder layers are moved to ``dev`` and run one at a time.

    If ``args.gmp`` is set, every sub-layer found by ``find_layers`` is
    magnitude-pruned in place before use: its ``int(numel * args.sparsity)``
    smallest-magnitude weights are zeroed (ties at the threshold are zeroed
    too).

    ``dataset`` is accepted for interface compatibility and is not used.
    """
    print("Evaluating ...")

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        # Records the first decoder layer's input and aborts the forward pass.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            pass  # raised on purpose by Catcher once the input is recorded
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.gmp:
            subset = find_layers(layer)
            for name in subset:
                W = subset[name].weight.data
                # Zero the k smallest-magnitude weights. Indexing the sorted
                # magnitudes at k (not k - 1) would zero one weight too many,
                # prune a weight at sparsity 0, and overflow at sparsity 1.0.
                k = int(W.numel() * args.sparsity)
                if k > 0:
                    thresh = torch.sort(torch.abs(W.flatten()))[0][k - 1]
                    W[torch.abs(W) <= thresh] = 0

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps

    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Predict token t+1 from positions <= t within the window.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(f"Perplexity: {ppl.item():3f}")

    model.config.use_cache = use_cache