|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from pruning_utils import * |
|
from quant import * |
|
import math |
|
import transformers  # needed for the transformers.Conv1D checks in the Hessian refresh below
from transformers import OPTForCausalLM, LlamaForCausalLM
|
|
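# Model loaders. The init functions are replaced with no-ops because the randomly
# initialized weights are immediately overwritten by from_pretrained, so skipping
# them only speeds up model construction.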
|
def get_opt(args): |
|
def skip(*args, **kwargs): |
|
pass |
|
torch.nn.init.kaiming_uniform_ = skip |
|
torch.nn.init.uniform_ = skip |
|
torch.nn.init.normal_ = skip |
|
model = OPTForCausalLM.from_pretrained(args.model, torch_dtype='auto') |
|
model.seqlen = model.config.max_position_embeddings |
|
return model |
|
|
|
def get_llama(args): |
|
def skip(*args, **kwargs): |
|
pass |
|
torch.nn.init.kaiming_uniform_ = skip |
|
torch.nn.init.uniform_ = skip |
|
torch.nn.init.normal_ = skip |
|
model = LlamaForCausalLM.from_pretrained(args.model, torch_dtype='auto') |
|
model.seqlen = 2048 |
|
return model |
|
|
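# SparseLLM pruning for OPT. Calibration inputs are propagated through the decoder
# one layer at a time; attention projections are pruned directly with SparseGPT,
# while the FFN (fc1/fc2) is refined by an alternating optimization over auxiliary
# activation variables.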
|
@torch.no_grad() |
|
def opt_sparsellm(model, dataloader, dev, args): |
|
print('Starting ...') |
|
|
|
use_cache = model.config.use_cache |
|
model.config.use_cache = False |
|
layers = model.model.decoder.layers |
|
|
|
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) |
|
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) |
|
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: |
|
model.model.decoder.project_out = model.model.decoder.project_out.to(dev) |
|
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: |
|
model.model.decoder.project_in = model.model.decoder.project_in.to(dev) |
|
layers[0] = layers[0].to(dev) |
|
|
|
dtype = next(iter(model.parameters())).dtype |
|
inps = torch.zeros( |
|
(args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev |
|
) |
|
cache = {'i': 0, 'attention_mask': None} |
|
|
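# Catcher intercepts the first decoder layer: it records each calibration sample's
# hidden states and the attention mask, then raises ValueError to abort the forward
# pass once the inputs have been captured.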
|
class Catcher(nn.Module): |
|
def __init__(self, module): |
|
super().__init__() |
|
self.module = module |
|
def forward(self, inp, **kwargs): |
|
inps[cache['i']] = inp |
|
cache['i'] += 1 |
|
cache['attention_mask'] = kwargs['attention_mask'] |
|
raise ValueError |
|
layers[0] = Catcher(layers[0]) |
|
for batch in dataloader: |
|
try: |
|
model(batch[0].to(dev)) |
|
except ValueError: |
|
pass |
|
layers[0] = layers[0].module |
|
|
|
layers[0] = layers[0].cpu() |
|
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() |
|
model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() |
|
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: |
|
model.model.decoder.project_out = model.model.decoder.project_out.cpu() |
|
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: |
|
model.model.decoder.project_in = model.model.decoder.project_in.cpu() |
|
torch.cuda.empty_cache() |
|
|
|
outs = torch.zeros_like(inps) |
|
attention_mask = cache['attention_mask'] |
|
|
|
print('Ready.') |
|
|
|
for i in range(len(layers)): |
|
layer = layers[i].to(dev) |
|
|
|
subset = find_layers(layer) |
|
|
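# Wrap the selected linear modules in SparseGPT handlers, which accumulate the
# Hessian statistics used for pruning (plus an optional quantizer). Only layer
# indices in [minlayer, maxlayer) whose names contain --prune_only are kept;
# --invert flips the selection.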
|
gpts = {} |
|
for name in subset: |
|
if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert): |
|
continue |
|
gpts[name] = SparseGPT_OPT(subset[name]) |
|
if args.wbits < 16: |
|
gpts[name].quantizer = Quantizer() |
|
gpts[name].quantizer.configure( |
|
args.wbits, perchannel=True, sym=False, mse=False |
|
) |
|
|
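# Forward hooks feed each module's calibration inputs/outputs into its SparseGPT
# handler while the layer is run over all samples.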
|
def add_batch(name): |
|
def tmp(_, inp, out): |
|
gpts[name].add_batch(inp[0].data, out.data, name) |
|
return tmp |
|
handles = [] |
|
for name in gpts: |
|
handles.append(subset[name].register_forward_hook(add_batch(name))) |
|
for j in range(args.nsamples): |
|
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] |
|
for h in handles: |
|
h.remove() |
|
|
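# fc1/fc2 (the FFN) are deferred to the SparseLLM alternating optimization below;
# every other wrapped module is pruned right away.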
|
target_layer_names = ['fc1', 'fc2'] |
|
|
|
for name in gpts: |
|
if name not in target_layer_names: |
|
print(i, name) |
|
print('Pruning ...') |
|
|
|
sparsity = args.sparsity |
|
gpts[name].fasterprune( |
|
sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize |
|
) |
|
gpts[name].free() |
|
|
|
|
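# Penalty weights of the SparseLLM objective: alpha ties z to the fc1
# pre-activation, beta ties the fc2 output to Y, gamma ties p to ReLU(z).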
|
alpha = 5.0 |
|
beta = 5.0 |
|
gamma = 5.0 |
|
|
|
|
|
opt_epochs = 10 |
|
|
|
|
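# Calibration matrices for this layer's FFN: X = inputs to fc1, Y = outputs of fc2,
# reshaped to (features x tokens).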
|
X_list = gpts['fc1'].batch_inp |
|
Y_list = gpts['fc2'].batch_out |
|
X = torch.stack(X_list, dim=0) |
|
Y = torch.stack(Y_list, dim=0) |
|
|
|
X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T |
|
|
|
|
|
X_list, Y_list = None, None |
|
gpts['fc1'].batch_inp.clear() |
|
gpts['fc2'].batch_out.clear() |
|
|
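# Auxiliary variables: z = fc1 outputs (pre-activations), p = fc2 inputs
# (post-ReLU activations).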
|
hidden_z_list = gpts['fc1'].batch_out |
|
z = torch.stack(hidden_z_list, dim=0) |
|
hidden_z_list = None |
|
gpts['fc1'].batch_out.clear() |
|
hidden_p_list = gpts['fc2'].batch_inp |
|
p = torch.stack(hidden_p_list, dim=0) |
|
hidden_p_list = None |
|
gpts['fc2'].batch_inp.clear() |
|
|
|
|
|
z = z.reshape((-1, z.size(-1))).T.to(dev) |
|
p = p.reshape((-1, p.size(-1))).T.to(dev) |
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
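# X is fixed across epochs, so its pseudo-inverse is computed once and reused in
# every fc1 weight update.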
|
Xinv = torch.pinverse(X.to(dtype=torch.float32)).half() |
|
|
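# Alternating minimization: (1) closed-form weight update, (2) SparseGPT pruning,
# (3) p-update, (4) z-update.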
|
for opt_step in range(opt_epochs): |
|
|
|
|
|
|
|
|
|
|
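# (1) Weight update (skipped on the first epoch, which uses the weights pruned
# above): least-squares solutions W1 = (z - b1) X^+ and W2 = (Y - b2) p^+.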
|
if opt_step > 0: |
|
|
|
|
|
bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1)) |
|
|
|
weight_matrix_1 = torch.matmul(z - bias, Xinv) |
|
|
|
gpts['fc1'].layer.weight.copy_(weight_matrix_1) |
|
del bias, weight_matrix_1 |
|
|
|
|
|
pinv = torch.pinverse(p.to(dtype=torch.float32)).half() |
|
bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1)) |
|
|
|
weight_matrix_2 = torch.matmul(Y - bias, pinv) |
|
|
|
gpts['fc2'].layer.weight.copy_(weight_matrix_2) |
|
|
|
del bias, weight_matrix_2, pinv |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
|
|
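# Rebuild fc2's Hessian from the updated p (replaying SparseGPT's add_batch
# accumulation) so that re-pruning uses current input statistics.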
|
if opt_step > 0: |
|
|
|
tmp_H = torch.zeros_like(gpts['fc2'].H) |
|
tmp_p = p.T.reshape((args.nsamples, -1, p.size(0))) |
|
tmp_nsamples = 0 |
|
for j in range(args.nsamples): |
|
tmp_inp = tmp_p[j].unsqueeze(0) |
|
tmp = tmp_inp.shape[0] |
|
if isinstance(gpts['fc2'].layer, (nn.Linear, transformers.Conv1D)):
|
if len(tmp_inp.shape) == 3: |
|
tmp_inp = tmp_inp.reshape((-1, tmp_inp.shape[-1])) |
|
tmp_inp = tmp_inp.t() |
|
tmp_H *= tmp_nsamples / (tmp_nsamples + tmp) |
|
tmp_nsamples += tmp |
|
tmp_inp = math.sqrt(2 / tmp_nsamples) * tmp_inp.float() |
|
tmp_H += tmp_inp.matmul(tmp_inp.t()) |
|
gpts['fc2'].H.copy_(tmp_H) |
|
del tmp_H, tmp_p |
|
torch.cuda.empty_cache() |
|
|
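# (2) Re-prune fc1 and fc2 with SparseGPT at the target sparsity.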
|
for name in target_layer_names: |
|
print(i, name) |
|
print('Pruning ...') |
|
sparsity = args.sparsity |
|
gpts[name].fasterprune( |
|
sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
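# (3) p-update: argmin_p beta*||Y - W2 p - b2||^2 + gamma*||p - ReLU(z)||^2,
# i.e. p = (beta*W2^T W2 + gamma*I)^{-1} (beta*W2^T (Y - b2) + gamma*ReLU(z)).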
|
next_weight = subset['fc2'].weight |
|
m1 = beta * torch.matmul(next_weight.T, next_weight) |
|
m2 = gamma * torch.eye(m1.shape[0], device=m1.device) |
|
av = torch.inverse(m1 + m2).to(dtype=torch.float16) |
|
|
|
del m1, m2 |
|
torch.cuda.empty_cache() |
|
|
|
|
|
layer_nl_output = nn.functional.relu(z) |
|
|
|
|
|
bias = subset['fc2'].bias.unsqueeze(1).expand(-1, Y.size(-1)) |
|
m3 = beta * torch.matmul(next_weight.T, Y - bias) |
|
m4 = gamma * layer_nl_output |
|
af = m3 + m4 |
|
|
|
p = torch.matmul(av, af) |
|
|
|
del layer_nl_output, next_weight, av, m3, m4, af, bias |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
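# (4) z-update: compare the two ReLU branches elementwise. With m = W1 X + b1,
# sol1 = (gamma*p + alpha*m)/(gamma+alpha) is the minimizer on z >= 0 and sol2 = m
# on z <= 0; whichever branch gives the lower objective wins. The comparison is
# chunked to cap peak GPU memory.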
|
w = subset['fc1'].weight |
|
bias = subset['fc1'].bias.unsqueeze(1).expand(-1, z.size(-1)) |
|
m = torch.matmul(w, X) + bias |
|
sol1 = (gamma * p + alpha * m) / (gamma + alpha) |
|
sol2 = m |
|
del w, bias |
|
torch.cuda.empty_cache() |
|
|
|
z1 = torch.zeros_like(p) |
|
z2 = torch.zeros_like(p) |
|
|
|
chunk_size = 500 |
|
|
|
for k in range(0, sol1.size(0), chunk_size): |
|
chunk = slice(k, k + chunk_size) |
|
|
|
|
|
z1_chunk = z1[chunk] |
|
sol1_chunk = sol1[chunk] |
|
z1_chunk[sol1_chunk >= 0.] = sol1_chunk[sol1_chunk >= 0.] |
|
z1[chunk] = z1_chunk |
|
|
|
z2_chunk = z2[chunk] |
|
sol2_chunk = sol2[chunk] |
|
z2_chunk[sol2_chunk <= 0.] = sol2_chunk[sol2_chunk <= 0.] |
|
z2[chunk] = z2_chunk |
|
|
|
del z1_chunk, z2_chunk, sol1_chunk, sol2_chunk, sol1, sol2 |
|
torch.cuda.empty_cache() |
|
|
|
for k in range(0, z1.size(0), chunk_size): |
|
chunk = slice(k, k + chunk_size) |
|
|
|
|
|
fz_1_chunk = gamma * torch.square(p[chunk] - nn.functional.relu(z1[chunk])) + alpha * torch.square(z1[chunk] - m[chunk]) |
|
fz_2_chunk = gamma * torch.square(p[chunk] - nn.functional.relu(z2[chunk])) + alpha * torch.square(z2[chunk] - m[chunk]) |
|
|
|
|
|
index_z1_chunk = fz_1_chunk <= fz_2_chunk |
|
index_z2_chunk = fz_2_chunk < fz_1_chunk |
|
|
|
|
|
z[chunk][index_z1_chunk] = z1[chunk][index_z1_chunk] |
|
z[chunk][index_z2_chunk] = z2[chunk][index_z2_chunk] |
|
|
|
|
|
del fz_1_chunk, fz_2_chunk, index_z1_chunk, index_z2_chunk, z1, z2, m, chunk |
|
torch.cuda.empty_cache() |
|
|
|
for name in target_layer_names: |
|
gpts[name].free() |
|
|
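# Run the calibration samples through the pruned layer; its outputs become the
# next layer's inputs.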
|
for j in range(args.nsamples): |
|
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] |
|
|
|
layers[i] = layer.cpu() |
|
del layer |
|
torch.cuda.empty_cache() |
|
|
|
inps, outs = outs, inps |
|
|
|
model.config.use_cache = use_cache |
|
|
|
|
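# SparseLLM pruning for LLaMA. Same layer-by-layer scheme as opt_sparsellm, but the
# gated MLP (gate/up/down projections with SiLU) replaces the fc1/ReLU/fc2 block.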
|
@torch.no_grad() |
|
def llama_sparsellm(model, dataloader, dev, args): |
|
print("Starting...") |
|
|
|
use_cache = model.config.use_cache |
|
model.config.use_cache = False |
|
layers = model.model.layers |
|
|
|
model.model.embed_tokens = model.model.embed_tokens.to(dev) |
|
model.model.norm = model.model.norm.to(dev) |
|
layers[0] = layers[0].to(dev) |
|
|
|
dtype = next(iter(model.parameters())).dtype |
|
inps = torch.zeros( |
|
(args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev |
|
) |
|
cache = {"i": 0, "attention_mask": None} |
|
|
|
class Catcher(nn.Module): |
|
def __init__(self, module): |
|
super().__init__() |
|
self.module = module |
|
|
|
def forward(self, inp, **kwargs): |
|
inps[cache["i"]] = inp |
|
cache["i"] += 1 |
|
cache["attention_mask"] = kwargs["attention_mask"] |
|
raise ValueError |
|
|
|
layers[0] = Catcher(layers[0]) |
|
for batch in dataloader: |
|
try: |
|
model(batch[0].to(dev)) |
|
except ValueError: |
|
pass |
|
layers[0] = layers[0].module |
|
|
|
layers[0] = layers[0].cpu() |
|
model.model.embed_tokens = model.model.embed_tokens.cpu() |
|
model.model.norm = model.model.norm.cpu() |
|
torch.cuda.empty_cache() |
|
|
|
outs = torch.zeros_like(inps) |
|
attention_mask = cache["attention_mask"] |
|
|
|
print("Ready.") |
|
|
|
for i in range(len(layers)): |
|
layer = layers[i].to(dev) |
|
full = find_layers(layer) |
|
|
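# With --true_sequential, prune the sub-layers in execution order (attention first,
# then MLP); otherwise handle all of them as one group.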
|
if args.true_sequential: |
|
sequential = [ |
|
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], |
|
["self_attn.o_proj"], |
|
["mlp.up_proj", "mlp.gate_proj"], |
|
["mlp.down_proj"], |
|
] |
|
else: |
|
sequential = [list(full.keys())] |
|
|
|
for names in sequential: |
|
subset = {n: full[n] for n in names} |
|
|
|
gpts = {} |
|
for name in subset: |
|
if ( |
|
not (args.minlayer <= i < args.maxlayer and args.prune_only in name) |
|
) == (not args.invert): |
|
continue |
|
gpts[name] = SparseGPT_LlaMA(subset[name]) |
|
if args.wbits < 16: |
|
gpts[name].quantizer = Quantizer() |
|
gpts[name].quantizer.configure( |
|
args.wbits, perchannel=True, sym=False, mse=False |
|
) |
|
|
|
def add_batch(name): |
|
def tmp(_, inp, out): |
|
gpts[name].add_batch(inp[0].data, out.data, name) |
|
|
|
return tmp |
|
|
|
handles = [] |
|
for name in subset: |
|
handles.append(subset[name].register_forward_hook(add_batch(name))) |
|
for j in range(args.nsamples): |
|
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] |
|
for h in handles: |
|
h.remove() |
|
|
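# The gated MLP projections are refined by the alternating optimization below;
# attention projections are pruned immediately.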
|
target_layer_names = ["mlp.up_proj", "mlp.gate_proj", "mlp.down_proj"] |
|
|
|
for name in subset: |
|
if name not in target_layer_names: |
|
print(i, name) |
|
print("Pruning ...") |
|
sparsity = args.sparsity |
|
gpts[name].fasterprune( |
|
sparsity, |
|
prunen=args.prunen, |
|
prunem=args.prunem, |
|
percdamp=args.percdamp, |
|
blocksize=args.blocksize, |
|
) |
|
gpts[name].free() |
|
|
|
|
|
alpha = 5.0 |
|
beta = 5.0 |
|
gamma = 5.0 |
|
|
|
|
|
opt_epochs = 8 |
|
|
|
|
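# Calibration matrices: X = inputs to up_proj (shared with gate_proj),
# Y = outputs of down_proj.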
|
X_list = gpts['mlp.up_proj'].batch_inp |
|
Y_list = gpts['mlp.down_proj'].batch_out |
|
X = torch.stack(X_list, dim=0) |
|
Y = torch.stack(Y_list, dim=0) |
|
|
|
X, Y = X.reshape((-1, X.size(-1))).T, Y.reshape((-1, Y.size(-1))).T |
|
|
|
|
|
X_list, Y_list = None, None |
|
gpts['mlp.up_proj'].batch_inp.clear() |
|
gpts['mlp.down_proj'].batch_out.clear() |
|
|
|
|
|
|
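# Auxiliary variables: z = up_proj outputs, p = down_proj inputs (= SiLU(s) * z),
# s = gate_proj outputs.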
|
hidden_z_list = gpts['mlp.up_proj'].batch_out |
|
z = torch.stack(hidden_z_list, dim=0) |
|
hidden_z_list = None |
|
gpts['mlp.up_proj'].batch_out.clear() |
|
|
|
hidden_p_list = gpts['mlp.down_proj'].batch_inp |
|
p = torch.stack(hidden_p_list, dim=0) |
|
hidden_p_list = None |
|
gpts['mlp.down_proj'].batch_inp.clear() |
|
|
|
hidden_s_list = gpts['mlp.gate_proj'].batch_out |
|
s = torch.stack(hidden_s_list, dim=0) |
|
hidden_s_list = None |
|
gpts['mlp.gate_proj'].batch_out.clear() |
|
|
|
|
|
z = z.reshape((-1, z.size(-1))).T.to(dev) |
|
p = p.reshape((-1, p.size(-1))).T.to(dev) |
|
s = s.reshape((-1, s.size(-1))).T.to(dev) |
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
Xinv = torch.pinverse(X.to(dtype=torch.float32)).half() |
|
|
|
|
|
training_loss = {'Y_p_loss': [], 'p_z_loss': [], 'z_X_loss': [], 'train_loss': []} |
|
|
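# Alternating minimization over the weights, p, z, and the gate activations s.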
|
for opt_step in range(opt_epochs): |
|
|
|
|
|
|
|
|
|
|
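# (1) Weight update: W_up = z X^+, W_down = Y p^+, W_gate = s X^+
# (LLaMA's MLP projections have no bias).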
|
if opt_step > 0: |
|
|
|
|
|
|
|
weight_matrix_1 = torch.matmul(z, Xinv) |
|
|
|
gpts['mlp.up_proj'].layer.weight.copy_(weight_matrix_1) |
|
del weight_matrix_1 |
|
|
|
|
|
pinv = torch.pinverse(p.to(dtype=torch.float32)).half() |
|
|
|
weight_matrix_2 = torch.matmul(Y, pinv) |
|
|
|
gpts['mlp.down_proj'].layer.weight.copy_(weight_matrix_2) |
|
del weight_matrix_2, pinv |
|
|
|
|
|
|
|
weight_matrix_3 = torch.matmul(s, Xinv) |
|
|
|
gpts['mlp.gate_proj'].layer.weight.copy_(weight_matrix_3) |
|
del weight_matrix_3 |
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
|
|
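# Rebuild down_proj's Hessian from the updated p before re-pruning.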
|
if opt_step > 0: |
|
|
|
tmp_H = torch.zeros_like(gpts['mlp.down_proj'].H) |
|
tmp_p = p.T.reshape((args.nsamples, -1, p.size(0))) |
|
tmp_nsamples = 0 |
|
for j in range(args.nsamples): |
|
tmp_inp = tmp_p[j].unsqueeze(0) |
|
tmp = tmp_inp.shape[0] |
|
if isinstance(gpts['mlp.down_proj'].layer, (nn.Linear, transformers.Conv1D)):
|
if len(tmp_inp.shape) == 3: |
|
tmp_inp = tmp_inp.reshape((-1, tmp_inp.shape[-1])) |
|
tmp_inp = tmp_inp.t() |
|
tmp_H *= tmp_nsamples / (tmp_nsamples + tmp) |
|
tmp_nsamples += tmp |
|
tmp_inp = math.sqrt(2 / tmp_nsamples) * tmp_inp.float() |
|
tmp_H += tmp_inp.matmul(tmp_inp.t()) |
|
gpts['mlp.down_proj'].H.copy_(tmp_H) |
|
del tmp_H, tmp_p |
|
torch.cuda.empty_cache() |
|
|
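# (2) Re-prune the three MLP projections with SparseGPT.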
|
for name in target_layer_names: |
|
print(i, name) |
|
print('Pruning ...') |
|
sparsity = args.sparsity |
|
gpts[name].fasterprune( |
|
sparsity, |
|
prunen=args.prunen, |
|
prunem=args.prunem, |
|
percdamp=args.percdamp, |
|
blocksize=args.blocksize, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
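# (3) p-update: p = (beta*W_down^T W_down + gamma*I)^{-1}
#               (beta*W_down^T Y + gamma*SiLU(s)*z).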
|
next_weight = subset['mlp.down_proj'].weight |
|
m1 = beta * torch.matmul(next_weight.T, next_weight) |
|
m2 = gamma * torch.eye(m1.shape[0], device=m1.device) |
|
av = torch.inverse(m1 + m2).to(dtype=torch.float16) |
|
|
|
del m1, m2 |
|
torch.cuda.empty_cache() |
|
|
|
|
|
layer_nl_output = nn.functional.silu(s) * z |
|
|
|
|
|
m3 = beta * torch.matmul(next_weight.T, Y) |
|
m4 = gamma * layer_nl_output |
|
af = m3 + m4 |
|
|
|
p = torch.matmul(av, af) |
|
|
|
del layer_nl_output, next_weight, av, m3, m4, af |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
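# (4) z-update with the gate fixed: z = (W_up X + SiLU(s)*p) / (SiLU(s)^2 + 1);
# this closed form relies on alpha == gamma (both 5.0 here).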
|
w = subset['mlp.up_proj'].weight |
|
m = torch.matmul(w, X) |
|
swish = nn.functional.silu(s) |
|
z = (m + swish * p) / (swish ** 2 + 1) |
|
|
|
del w, m, swish |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
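# (5) s-update: no closed form because of the SiLU gate, so take a few chunked
# gradient-descent steps on alpha*||s - W_gate X||^2 + gamma*||p - SiLU(s)*z||^2.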
|
w = subset['mlp.gate_proj'].weight |
|
|
|
w = w.to(dtype=torch.float32).requires_grad_(True) |
|
|
|
s_update_epochs = 2 |
|
s_learning_rate = 0.01 |
|
for _ in range(s_update_epochs): |
|
|
|
batch_size = 1000 |
|
|
|
for k in range(0, s.size(-1), batch_size): |
|
chunk = slice(k, k + batch_size) |
|
|
|
|
|
X_batch = X[:,chunk].to(dtype=torch.float32).requires_grad_(True) |
|
z_batch = z[:,chunk].to(dtype=torch.float32).requires_grad_(True) |
|
p_batch = p[:,chunk].to(dtype=torch.float32).requires_grad_(True) |
|
s_batch = s[:,chunk].to(dtype=torch.float32).requires_grad_(True) |
|
|
|
with torch.enable_grad(): |
|
|
|
loss_s = alpha * torch.norm(s_batch - torch.matmul(w, X_batch))**2 |
|
loss_s += gamma * torch.norm(p_batch - nn.functional.silu(s_batch) * z_batch)**2 |
|
|
|
loss_s.backward() |
|
s_batch -= s_learning_rate * s_batch.grad |
|
s_batch.grad.zero_() |
|
s[:,chunk] = s_batch.detach().to(dtype=torch.float16) |
|
|
|
s_batch, X_batch, z_batch, p_batch, w = s_batch.detach(), X_batch.detach(), z_batch.detach(), p_batch.detach(), w.detach() |
|
del w, loss_s, s_batch, X_batch, z_batch, p_batch |
|
torch.cuda.empty_cache() |
|
|
|
|
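# Track the end-to-end MLP reconstruction MSE against Y after this epoch.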
|
tmp_training_loss = nn.functional.mse_loss(torch.matmul(subset['mlp.down_proj'].weight, |
|
nn.functional.silu(torch.matmul(subset['mlp.gate_proj'].weight, X)) |
|
* torch.matmul(subset['mlp.up_proj'].weight, X)), Y) |
|
training_loss['train_loss'].append(tmp_training_loss.item()) |
|
|
|
for j in range(args.nsamples): |
|
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] |
|
|
|
layers[i] = layer.cpu() |
|
del layer |
|
del gpts |
|
torch.cuda.empty_cache() |
|
|
|
inps, outs = outs, inps |
|
|
|
model.config.use_cache = use_cache |
|
|
|
|
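# Perplexity evaluation for OPT: run the test tokens through the decoder layer by
# layer, then score them with the LM head.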
|
@torch.no_grad() |
|
def opt_eval(model, testenc, dev, args, dataset: str): |
|
print('Evaluating ...') |
|
|
|
testenc = testenc.input_ids |
|
nsamples = testenc.numel() // model.seqlen |
|
|
|
use_cache = model.config.use_cache |
|
model.config.use_cache = False |
|
layers = model.model.decoder.layers |
|
|
|
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) |
|
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev) |
|
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: |
|
model.model.decoder.project_out = model.model.decoder.project_out.to(dev) |
|
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: |
|
model.model.decoder.project_in = model.model.decoder.project_in.to(dev) |
|
layers[0] = layers[0].to(dev) |
|
|
|
dtype = next(iter(model.parameters())).dtype |
|
inps = torch.zeros( |
|
(nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev |
|
) |
|
cache = {'i': 0, 'attention_mask': None} |
|
|
|
class Catcher(nn.Module): |
|
def __init__(self, module): |
|
super().__init__() |
|
self.module = module |
|
def forward(self, inp, **kwargs): |
|
inps[cache['i']] = inp |
|
cache['i'] += 1 |
|
cache['attention_mask'] = kwargs['attention_mask'] |
|
raise ValueError |
|
layers[0] = Catcher(layers[0]) |
|
for i in range(nsamples): |
|
batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) |
|
try: |
|
model(batch) |
|
except ValueError: |
|
pass |
|
layers[0] = layers[0].module |
|
|
|
layers[0] = layers[0].cpu() |
|
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu() |
|
model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu() |
|
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out: |
|
model.model.decoder.project_out = model.model.decoder.project_out.cpu() |
|
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in: |
|
model.model.decoder.project_in = model.model.decoder.project_in.cpu() |
|
torch.cuda.empty_cache() |
|
|
|
outs = torch.zeros_like(inps) |
|
attention_mask = cache['attention_mask'] |
|
|
|
for i in range(len(layers)): |
|
print(i) |
|
layer = layers[i].to(dev) |
|
|
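# --gmp baseline: plain magnitude pruning, zeroing the smallest args.sparsity
# fraction of each weight matrix.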
|
if args.gmp: |
|
subset = find_layers(layer) |
|
for name in subset: |
|
W = subset[name].weight.data |
|
thresh = torch.sort(torch.abs(W.flatten()))[0][int(W.numel() * args.sparsity)] |
|
W.data[torch.abs(W.data) <= thresh] = 0 |
|
|
|
for j in range(nsamples): |
|
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] |
|
layers[i] = layer.cpu() |
|
del layer |
|
torch.cuda.empty_cache() |
|
inps, outs = outs, inps |
|
|
|
if model.model.decoder.final_layer_norm is not None: |
|
model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev) |
|
if model.model.decoder.project_out is not None: |
|
model.model.decoder.project_out = model.model.decoder.project_out.to(dev) |
|
model.lm_head = model.lm_head.to(dev) |
|
|
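# Apply the final layer norm / output projection and the LM head, then accumulate
# per-sample negative log-likelihoods to compute perplexity.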
|
testenc = testenc.to(dev) |
|
nlls = [] |
|
for i in range(nsamples): |
|
hidden_states = inps[i].unsqueeze(0) |
|
if model.model.decoder.final_layer_norm is not None: |
|
hidden_states = model.model.decoder.final_layer_norm(hidden_states) |
|
if model.model.decoder.project_out is not None: |
|
hidden_states = model.model.decoder.project_out(hidden_states) |
|
lm_logits = model.lm_head(hidden_states) |
|
shift_logits = lm_logits[:, :-1, :].contiguous() |
|
shift_labels = testenc[ |
|
:, (i * model.seqlen):((i + 1) * model.seqlen) |
|
][:, 1:] |
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) |
|
neg_log_likelihood = loss.float() * model.seqlen |
|
nlls.append(neg_log_likelihood) |
|
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) |
|
print(f"Perplexity: {ppl.item():3f}") |
|
|
|
model.config.use_cache = use_cache |
|
|
|
|
|
|
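# Perplexity evaluation for LLaMA (same procedure; final RMSNorm + LM head).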
|
@torch.no_grad() |
|
def llama_eval(model, testenc, dev, args, dataset: str): |
|
print("Evaluating ...") |
|
|
|
testenc = testenc.input_ids |
|
nsamples = testenc.numel() // model.seqlen |
|
|
|
use_cache = model.config.use_cache |
|
model.config.use_cache = False |
|
layers = model.model.layers |
|
|
|
model.model.embed_tokens = model.model.embed_tokens.to(dev) |
|
layers[0] = layers[0].to(dev) |
|
|
|
dtype = next(iter(model.parameters())).dtype |
|
inps = torch.zeros( |
|
(nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev |
|
) |
|
cache = {"i": 0, "attention_mask": None} |
|
|
|
class Catcher(nn.Module): |
|
def __init__(self, module): |
|
super().__init__() |
|
self.module = module |
|
|
|
def forward(self, inp, **kwargs): |
|
inps[cache["i"]] = inp |
|
cache["i"] += 1 |
|
cache["attention_mask"] = kwargs["attention_mask"] |
|
raise ValueError |
|
|
|
layers[0] = Catcher(layers[0]) |
|
for i in range(nsamples): |
|
batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev) |
|
try: |
|
model(batch) |
|
except ValueError: |
|
pass |
|
layers[0] = layers[0].module |
|
|
|
layers[0] = layers[0].cpu() |
|
model.model.embed_tokens = model.model.embed_tokens.cpu() |
|
torch.cuda.empty_cache() |
|
|
|
outs = torch.zeros_like(inps) |
|
attention_mask = cache["attention_mask"] |
|
|
|
for i in range(len(layers)): |
|
print(i) |
|
layer = layers[i].to(dev) |
|
|
|
if args.gmp: |
|
subset = find_layers(layer) |
|
for name in subset: |
|
W = subset[name].weight.data |
|
thresh = torch.sort(torch.abs(W.flatten()))[0][ |
|
int(W.numel() * args.sparsity) |
|
] |
|
W.data[torch.abs(W.data) <= thresh] = 0 |
|
|
|
for j in range(nsamples): |
|
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] |
|
layers[i] = layer.cpu() |
|
del layer |
|
torch.cuda.empty_cache() |
|
inps, outs = outs, inps |
|
|
|
if model.model.norm is not None: |
|
model.model.norm = model.model.norm.to(dev) |
|
model.lm_head = model.lm_head.to(dev) |
|
|
|
testenc = testenc.to(dev) |
|
nlls = [] |
|
for i in range(nsamples): |
|
hidden_states = inps[i].unsqueeze(0) |
|
if model.model.norm is not None: |
|
hidden_states = model.model.norm(hidden_states) |
|
lm_logits = model.lm_head(hidden_states) |
|
shift_logits = lm_logits[:, :-1, :].contiguous() |
|
shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] |
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct( |
|
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) |
|
) |
|
neg_log_likelihood = loss.float() * model.seqlen |
|
nlls.append(neg_log_likelihood) |
|
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) |
|
print(f"Perplexity: {ppl.item():3f}") |
|
|
|
model.config.use_cache = use_cache |