Commit 75fa479
Parent(s): b2a1f5e

Add tricksy

Files changed:
- app.py +47 -2
- configuration_tricksy.py +18 -0
- modeling_tricksy.py +618 -0
- util.py +83 -0
app.py
CHANGED
@@ -1,4 +1,49 @@
+from threading import Thread
+
 import streamlit as st
 
-
-
+import torch
+from transformers import AutoTokenizer, TextIteratorStreamer, set_seed
+from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
+from configuration_tricksy import TricksyConfig
+
+def generate():
+    set_seed(42)
+
+    # 13.4 GB (16 bit)
+    model_name = 'facebook/opt-6.7b'
+    disk_weights = OPTDiskWeights(model_name)
+    tricksy_model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)), disk_weights)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+    inputs = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
+
+    print()
+    generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p)
+    thread = Thread(target=tricksy_model.generate, kwargs=generation_kwargs)
+    thread.start()
+    generated_text = ''
+    with st.chat_message("user"):
+        t = st.empty()
+        for new_text in streamer:
+            generated_text += new_text.replace('\n', ' \n')
+            t.write(generated_text)
+
+    stats_text = f'Decoding tok/s: {1 / (sum(tricksy_model.tricksy_context.forward_times[1:]) / (len(tricksy_model.tricksy_context.forward_times) - 1))}'
+    stats_text += f' \nCurrent GPU mem usage: {torch.cuda.memory_allocated("cuda") / 1024 ** 3} GB'
+    stats_text += f' \nMax GPU mem usage: {torch.cuda.max_memory_allocated("cuda") / 1024 ** 3} GB'
+    st.write(stats_text)
+
+prompt = st.text_area('Prompt', 'Making pesto from scratch can be done with these ingredients in 4 simple steps:\nStep 1')
+
+col1, col2 = st.columns(2)
+with col1:
+    submit = st.button('Submit', on_click=generate)
+with col2:
+    use_tricksy = st.toggle('Use Tricksy', True, help='If true, only send sparse MLP weight diffs to GPU. If false, send all weights to GPU.')
+
+with st.expander('Additional options'):
+    max_new_tokens = st.slider('Max new tokens', 1, 500, 100)
+    top_k = st.slider('Top-k sampling', 1, 500, 50)
+    top_p = st.slider('Top-p (nucleus sampling)', 0.0, 1.0, .9)
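
app.py runs tricksy_model.generate on a background thread and drains the TextIteratorStreamer from the Streamlit loop, so tokens render as they are produced. As a rough sketch only (not part of this commit, assuming the two modules added in this commit are importable and a CUDA device is available), the same streaming pattern can be driven from a plain script:

# Illustration only: command-line version of the streaming pattern used by app.py.
from threading import Thread

from transformers import AutoTokenizer, TextIteratorStreamer, set_seed

from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
from configuration_tricksy import TricksyConfig

def main():
    set_seed(42)
    model_name = 'facebook/opt-6.7b'
    disk_weights = OPTDiskWeights(model_name)
    model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config), disk_weights)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    inputs = tokenizer('Making pesto from scratch:\nStep 1', return_tensors='pt').input_ids.to('cuda')
    # generate() blocks, so it runs on a thread while the main thread drains the streamer
    thread = Thread(target=model.generate, kwargs=dict(inputs=inputs, streamer=streamer, max_new_tokens=50, do_sample=True))
    thread.start()
    for new_text in streamer:
        print(new_text, end='', flush=True)
    thread.join()

if __name__ == '__main__':
    main()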
configuration_tricksy.py
ADDED
@@ -0,0 +1,18 @@
import dataclasses
import torch
from transformers.models.opt.configuration_opt import OPTConfig

@dataclasses.dataclass(frozen=True)
class TricksyConfig:
    opt_config: OPTConfig

    # Percentage of weights to keep on each device
    # e.g. 30% of each MLP layer on GPU
    min_mlp_sparsity_gpu: float = .3
    # e.g. 100% of each MLP layer on CPU
    min_mlp_sparsity_cpu: float = 1

    # If true, cleans up layer's weights after computing forward pass
    full_offload: bool = False

    dtype: torch.dtype = torch.float16
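
TricksyConfig is a frozen dataclass whose only required field is the underlying OPTConfig; app.py builds it as TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)). A small sketch of constructing it directly (illustration only, not part of the commit):

# Illustration only: constructing TricksyConfig with and without full offload.
from transformers.models.opt.configuration_opt import OPTConfig
from configuration_tricksy import TricksyConfig

opt_config = OPTConfig.from_pretrained('facebook/opt-6.7b')

# Defaults: keep 30% of each MLP layer on GPU, 100% on CPU, float16 weights.
sparse_cfg = TricksyConfig(opt_config)

# full_offload=True drops a layer's weights from GPU after its forward pass and
# reloads them on the next pass; app.py sets this when the 'Use Tricksy' toggle is off.
dense_cfg = TricksyConfig(opt_config, full_offload=True)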
modeling_tricksy.py
ADDED
@@ -0,0 +1,618 @@
from typing import Any, Optional, Callable, List, Tuple
import os
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from accelerate import init_empty_weights
from transformers.activations import ACT2FN
from transformers.generation import GenerationConfig
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoder,
    OPTDecoderLayer,
    OPTForCausalLM,
    OPTModel,
)
from transformers.models.opt.configuration_opt import OPTConfig
from huggingface_hub import snapshot_download

from configuration_tricksy import TricksyConfig
from util import batch_copy, compute_index_diffs, load_mlp_sparsity_predictor, mmap_to_tensor, topk_and_threshold

TRICKSY_WEIGHTS_PATH = 'tricksy-weights/'

class SparseMLPCache:
    def __init__(
        self,
        indexed_fc1_weight: Optional[torch.Tensor] = None,
        indexed_fc1_bias: Optional[torch.Tensor] = None,
        indexed_fc2_weight: Optional[torch.Tensor] = None,
        gpu_cached_mlp_indices: Optional[torch.Tensor] = None,
    ):
        # [ffn_embed_dim * min_mlp_sparsity, hidden_size]
        self.indexed_fc1_weight = indexed_fc1_weight
        # [ffn_embed_dim * min_mlp_sparsity]
        self.indexed_fc1_bias = indexed_fc1_bias
        # [ffn_embed_dim * min_mlp_sparsity, hidden_size] (stored in transpose for efficient indexing)
        self.indexed_fc2_weight = indexed_fc2_weight

        # Indices that are already on GPU (this tensor is stored on the CPU)
        # [ffn_embed_dim * min_mlp_sparsity]
        self.gpu_cached_mlp_indices = gpu_cached_mlp_indices

class SparseIndices:
    def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
        self.mlp_indices_buffer_gpu = torch.empty(
            (int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu),),
            dtype=torch.int32,
            device='cuda'
        )
        self.mlp_indices_buffer_cpu = torch.empty(
            (int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu),),
            dtype=torch.int32,
            device='cpu',
            pin_memory=True,
        )

        # Default stream blocks until indices are copied to CPU
        self.index_copy_stream = torch.cuda.default_stream()

    def copy_mlp_indices_to_cpu(self):
        self.mlp_indices_buffer_cpu = batch_copy([self.mlp_indices_buffer_gpu], self.index_copy_stream, device='cpu')[0]

class OPTDiskWeights:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model_suffix = model_name.split('/')[-1]
        self.config = OPTConfig.from_pretrained(model_name)

        try:
            print(f'downloading from austinsilveria/tricksy-{self.model_suffix}')
            self.weight_path = snapshot_download(repo_id=f'austinsilveria/tricksy-{self.model_suffix}') + '/'
        except:
            print(f'failed to download from austinsilveria/tricksy-{self.model_suffix}')
            self.weight_path = f'{TRICKSY_WEIGHTS_PATH}{self.model_suffix}/'

        with init_empty_weights():
            model = OPTModel(self.config)
        self.state_dict = model.state_dict()

        if not os.path.exists(f'{self.weight_path}decoder.embed_tokens.weight'):
            # Download original weights and write memmap files
            print(f'downloading and preprocessing original weights')
            self.cache_weights()

        head_dim = self.config.hidden_size // self.config.num_attention_heads
        for i in range(self.config.num_hidden_layers):
            layer_prefix = f'decoder.layers.{i}.'
            self.delete_weights([
                f'{layer_prefix}self_attn.q_proj.weight',
                f'{layer_prefix}self_attn.k_proj.weight',
                f'{layer_prefix}self_attn.v_proj.weight',
                f'{layer_prefix}self_attn.out_proj.weight',
                f'{layer_prefix}self_attn.q_proj.bias',
                f'{layer_prefix}self_attn.k_proj.bias',
                f'{layer_prefix}self_attn.v_proj.bias'
            ])
            self.add_weights([
                (f'{layer_prefix}fc2.weight', (self.config.ffn_dim, self.config.hidden_size)),
                (f'{layer_prefix}self_attn.catted_head_weights', (self.config.num_attention_heads, head_dim * 4, self.config.hidden_size)),
                (f'{layer_prefix}self_attn.catted_head_biases', (self.config.num_attention_heads, 3, head_dim)),
            ])

        self.memmap_weights = { key: self.load_memmap_weight(key) for key in self.state_dict.keys() }

    def load_memmap_weight(self, key: str):
        return torch.from_numpy(np.memmap(f'{self.weight_path}{key}', dtype='float16', mode='r', shape=(self.state_dict[key].shape)))

    def add_weights(self, weights: List[Tuple[str, torch.Size]]):
        for key, shape in weights:
            self.state_dict[key] = torch.empty(shape, dtype=torch.float16, device='meta')

    def delete_weights(self, keys: List[str]):
        for key in keys:
            if key in self.state_dict:
                del self.state_dict[key]
            path = f'{self.weight_path}{key}'
            if os.path.exists(path):
                os.remove(path)

    def cache_weights(self):
        os.makedirs(self.weight_path, exist_ok=True)
        weights_location = snapshot_download(repo_id=self.model_name, ignore_patterns=['flax*', 'tf*'])
        shards = [file for file in os.listdir(weights_location) if file.startswith("pytorch_model") and file.endswith(".bin")]
        for shard in shards:
            print(f'caching {shard}')
            shard_path = os.path.join(weights_location, shard)
            shard_state_dict = torch.load(shard_path)
            for key in shard_state_dict.keys():
                path = f'{self.weight_path}{key.replace("model.", "")}'
                memmap = np.memmap(path, dtype='float16', mode='w+', shape=(shard_state_dict[key].shape))
                memmap[:] = shard_state_dict[key].cpu().numpy()

        # Store weights in shape for efficient indexing
        for i in range(self.config.num_hidden_layers):
            layer_prefix = f'decoder.layers.{i}.'
            # FC2 in transpose
            fc2t = torch.from_numpy(np.array(self.load_memmap_weight(f'{layer_prefix}fc2.weight')[:])).t().contiguous().clone()
            np.memmap(f'{self.weight_path}decoder.layers.{i}.fc2.weight', dtype='float16', mode='w+', shape=fc2t.shape)[:] = fc2t.numpy()

            # Attention weights by head
            head_dim = self.config.hidden_size // self.config.num_attention_heads
            qw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.weight')[:])
            kw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.weight')[:])
            vw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.weight')[:])
            ow = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.out_proj.weight')[:])
            pre_cat_shape = (self.config.num_attention_heads, head_dim, self.config.hidden_size)
            # [head, head_dim * 4, hidden_size]
            catted_head_weights = torch.cat(
                [qw.view(pre_cat_shape).clone(), kw.view(pre_cat_shape).clone(), vw.view(pre_cat_shape).clone(), ow.T.view(pre_cat_shape).clone(),],
                dim=1,
            ).contiguous().clone()
            np.memmap(f'{self.weight_path}{layer_prefix}self_attn.catted_head_weights', dtype='float16', mode='w+', shape=catted_head_weights.shape)[:] =\
                catted_head_weights.numpy()

            # Attention biases by head
            qb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.bias')[:])
            kb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.bias')[:])
            vb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.bias')[:])
            pre_cat_shape = (self.config.num_attention_heads, 1, head_dim)
            # [head, 3, head_dim]
            catted_head_biases = torch.cat(
                # Don't index out bias since we need all dims after projecting back up to hidden size
                [qb.view(pre_cat_shape).clone(), kb.view(pre_cat_shape).clone(), vb.view(pre_cat_shape).clone()],
                dim=1,
            ).contiguous().clone()
            np.memmap(f'{self.weight_path}{layer_prefix}self_attn.catted_head_biases', dtype='float16', mode='w+', shape=catted_head_biases.shape)[:] =\
                catted_head_biases.numpy()

            self.delete_weights([
                f'{layer_prefix}self_attn.q_proj.weight',
                f'{layer_prefix}self_attn.k_proj.weight',
                f'{layer_prefix}self_attn.v_proj.weight',
                f'{layer_prefix}self_attn.out_proj.weight',
                f'{layer_prefix}self_attn.q_proj.bias',
                f'{layer_prefix}self_attn.k_proj.bias',
                f'{layer_prefix}self_attn.v_proj.bias'
            ])
            self.add_weights([
                (f'{layer_prefix}self_attn.catted_head_weights', catted_head_weights.shape),
                (f'{layer_prefix}self_attn.catted_head_biases', catted_head_biases.shape),
            ])

class TricksyContext:
    def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
        self.indices = SparseIndices(tricksy_config, opt_config)
        self.load_weight_stream = torch.cuda.Stream()
        self.layer = 0
        self.is_prompt_phase = True
        self.forward_times = []

class TricksyLayer:
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.forward(*args, **kwds)

    def load_weights(self, tricksy_context: TricksyContext):
        pass

class TricksyLayerInputs:
    def __init__(
        self,
        disk_weights: OPTDiskWeights,
        layer_key_prefix: str = None,
        next_layer: TricksyLayer = None,
        sparsity_predictors: List[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        self.disk_weights = disk_weights
        # self.get_weight = lambda key: self.disk_weights.load_memmap_weight(f'{layer_key_prefix}{key}')
        self.get_weight = lambda key: self.disk_weights.memmap_weights[(f'{layer_key_prefix}{key}')]
        self.layer_key_prefix = layer_key_prefix
        self.next_layer = next_layer
        self.sparsity_predictors = sparsity_predictors

class TricksyOPTLearnedPositionalEmbedding(TricksyLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, tricksy_context):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        self.tricksy_context = tricksy_context
        self.weight = None

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.forward(*args, **kwds)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()
        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        out = F.embedding(positions + self.offset, self.weight)
        return out

class TricksyOPTAttention(OPTAttention, TricksyLayer):
    def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext, is_decoder: bool = False, **kwargs):
        nn.Module.__init__(self)
        self.tricksy_config = tricksy_config
        self.config = tricksy_config.opt_config

        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
            warning and return that value, otherwise take the equivalent config.config_arg_name
            """
            val = None
            if fn_arg_name in kwargs:
                print(
                    "Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
                    " Please set it in the config instead"
                )
                val = kwargs.pop(fn_arg_name)
            else:
                val = getattr(config, config_arg_name)
            return val

        self.embed_dim = _handle_deprecated_argument("hidden_size", self.config, "embed_dim", kwargs)
        self.num_heads = _handle_deprecated_argument("num_attention_heads", self.config, "num_heads", kwargs)
        self.dropout = _handle_deprecated_argument("attention_dropout", self.config, "dropout", kwargs)
        self.enable_bias = _handle_deprecated_argument("enable_bias", self.config, "bias", kwargs)

        self.head_dim = self.embed_dim // self.num_heads
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # [Tricksy]
        self.tricksy_context = tricksy_context
        self.inputs = inputs
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads

        self.qw = self.kw = self.vw = self.ow = self.qb = self.kb = self.vb = self.out_proj_bias = self.layer_norm_weight = self.layer_norm_bias = None
        self.q_proj = lambda x: F.linear(x, self.qw, self.qb)
        self.k_proj = lambda x: F.linear(x, self.kw, self.kb)
        self.v_proj = lambda x: F.linear(x, self.vw, self.vb)
        self.out_proj = lambda x: F.linear(x, self.ow, self.out_proj_bias)
        self.layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.layer_norm_weight, self.layer_norm_bias)

    def load_weights(self, tricksy_context: TricksyContext):
        if self.tricksy_context.is_prompt_phase:
            # Full weights for prompt phase
            self.catted_weights, self.catted_biases, self.out_proj_bias, self.layer_norm_weight, self.layer_norm_bias = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_weights')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_biases')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn.out_proj.bias')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.bias')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )
            torch.cuda.synchronize()
            # Weights stored in shape for efficient indexing to support offloading attention heads (not currently being done)
            self.qw = self.catted_weights[:, :self.head_dim, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
            self.kw = self.catted_weights[:, self.head_dim:self.head_dim * 2, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
            self.vw = self.catted_weights[:, self.head_dim * 2:self.head_dim * 3, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
            self.ow = self.catted_weights[:, self.head_dim * 3:, :].reshape(self.config.hidden_size, self.config.hidden_size).t().contiguous()
            self.catted_weights = None

            self.qb = self.catted_biases[:, 0, :].reshape(self.config.hidden_size).contiguous()
            self.kb = self.catted_biases[:, 1, :].reshape(self.config.hidden_size).contiguous()
            self.vb = self.catted_biases[:, 2, :].reshape(self.config.hidden_size).contiguous()
            self.catted_biases = None

    def forward(self, hidden_states, **kwargs):
        # Wait for attention weights to get to GPU
        torch.cuda.synchronize()

        # Predict MLP sparsity based on attention input
        self.tricksy_context.indices.mlp_indices_buffer_gpu = topk_and_threshold(
            self.inputs.sparsity_predictors[0](hidden_states)[0, -1, :],
            int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu),
        )
        self.tricksy_context.indices.copy_mlp_indices_to_cpu()
        torch.cuda.synchronize()

        # Load MLP weights while computing attention
        self.inputs.next_layer.load_weights(self.tricksy_context)

        out = super().forward(self.layer_norm(hidden_states), **kwargs)

        # Wait for MLP weights to get to GPU
        torch.cuda.synchronize()

        return out

class TricksyOPTDecoderLayer(OPTDecoderLayer):
    def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.tricksy_config = tricksy_config
        self.config = tricksy_config.opt_config
        self.embed_dim = self.config.hidden_size

        self.tricksy_context = tricksy_context
        self.self_attn_layer_inputs = TricksyLayerInputs(
            disk_weights=inputs.disk_weights,
            layer_key_prefix=inputs.layer_key_prefix,
            # While computing attention, load MLP
            next_layer=self,
            sparsity_predictors=inputs.sparsity_predictors,
        )
        self.self_attn = TricksyOPTAttention(tricksy_config, self.self_attn_layer_inputs, tricksy_context, is_decoder=True)

        self.do_layer_norm_before = self.config.do_layer_norm_before
        self.dropout = self.config.dropout
        self.activation_fn = ACT2FN[self.config.activation_function]

        self.inputs = inputs
        random_mlp_indices_gpu =\
            torch.randperm(self.config.ffn_dim, device='cpu', dtype=torch.int32)[:int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu)]
        self.index_cache = SparseMLPCache(gpu_cached_mlp_indices=random_mlp_indices_gpu)

        # identity since we move this to attention layer
        # extreme tricksy
        self.self_attn_layer_norm = lambda x: x

        self.fc1_weight = self.fc2_weight = self.final_layer_norm_weight = self.fc1_bias = self.fc2_bias = self.final_layer_norm_bias = None
        self.ring_idx = 0
        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None
        self.fc1 = lambda x: F.linear(x, torch.cat([self.fc1_weight, self.fc1_weight_diff]), torch.cat([self.fc1_bias, self.fc1_bias_diff]))
        self.fc2 = lambda x: F.linear(x, torch.cat([self.fc2_weight, self.fc2_weight_diff]).T, self.fc2_bias)
        self.final_layer_norm = lambda x: F.layer_norm(x, (self.embed_dim,), self.final_layer_norm_weight, self.final_layer_norm_bias)

    def load_weights(self, tricksy_context: TricksyContext):
        if self.tricksy_context.is_prompt_phase:
            # Full weights for prompt phase
            fc1w = mmap_to_tensor(self.inputs.get_weight('fc1.weight')[:], pin_memory=True)
            fc1b = mmap_to_tensor(self.inputs.get_weight('fc1.bias')[:], pin_memory=True)
            fc2w = mmap_to_tensor(self.inputs.get_weight('fc2.weight')[:], pin_memory=True)
            fc2b = mmap_to_tensor(self.inputs.get_weight('fc2.bias')[:], pin_memory=True)
            lnw = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True)
            lnb = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True)

            self.fc1_weight, self.fc1_bias, self.fc2_weight, self.fc2_bias, self.final_layer_norm_weight, self.final_layer_norm_bias =\
                batch_copy([fc1w, fc1b, fc2w, fc2b, lnw, lnb], tricksy_context.load_weight_stream)
            self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')

            index_diffs = compute_index_diffs(tricksy_context.indices.mlp_indices_buffer_cpu, [self.index_cache.gpu_cached_mlp_indices])
            if len(index_diffs) > 0:
                gpu_index_diff = index_diffs[0]
                self.index_cache.gpu_cached_mlp_indices[gpu_index_diff.off_positions] = gpu_index_diff.off_elements

            self.index_cache.indexed_fc1_weight = fc1w.contiguous().pin_memory()
            self.index_cache.indexed_fc1_bias = fc1b.contiguous().pin_memory()
            self.index_cache.indexed_fc2_weight = fc2w.contiguous().pin_memory()
            return
        elif self.fc1_weight is None:
            # Full weights if full offload
            self.fc1_weight, self.fc1_bias, self.fc2_weight = batch_copy(
                [self.index_cache.indexed_fc1_weight, self.index_cache.indexed_fc1_bias, self.index_cache.indexed_fc2_weight],
                tricksy_context.load_weight_stream
            )
            self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')

        off_elements = torch.tensor(
            list(set(tricksy_context.indices.mlp_indices_buffer_cpu.tolist()).difference(set(self.index_cache.gpu_cached_mlp_indices.tolist()))),
            device='cpu',
            dtype=torch.int32,
            pin_memory=True
        )
        if off_elements.size(0) == 0:
            self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            return

        new_ring_idx = (self.ring_idx + off_elements.size(0)) % self.index_cache.gpu_cached_mlp_indices.size(0)
        if new_ring_idx > self.ring_idx:
            # single contiguous update
            self.index_cache.gpu_cached_mlp_indices[self.ring_idx:new_ring_idx] = off_elements
        elif off_elements.size(0) > 0:
            split = self.index_cache.gpu_cached_mlp_indices.size(0) - self.ring_idx
            # end of ring
            self.index_cache.gpu_cached_mlp_indices[self.ring_idx:] = off_elements[:split]
            # beginning of ring
            self.index_cache.gpu_cached_mlp_indices[:new_ring_idx] = off_elements[split:]

        # Allocate
        self.fc1_weight_diff = torch.empty((off_elements.size(0), self.config.hidden_size), dtype=self.tricksy_config.dtype, device='cuda')
        self.fc1_bias_diff = torch.empty((off_elements.size(0)), dtype=self.tricksy_config.dtype, device='cuda')
        self.fc2_weight_diff = torch.empty((off_elements.size(0), self.config.hidden_size), dtype=self.tricksy_config.dtype, device='cuda')
        # Index
        fc1wd = self.index_cache.indexed_fc1_weight[off_elements].pin_memory()
        fc1bd = self.index_cache.indexed_fc1_bias[off_elements].pin_memory()
        fc2wd = self.index_cache.indexed_fc2_weight[off_elements].pin_memory()
        # Copy
        self.fc1_weight_diff, self.fc1_bias_diff, self.fc2_weight_diff = batch_copy([fc1wd, fc1bd, fc2wd], tricksy_context.load_weight_stream)

    def forward(self, *args, **kwargs):
        # Wait for attention weights to get to GPU
        torch.cuda.synchronize()

        # Load next layer's attention weights
        self.inputs.next_layer.load_weights(self.tricksy_context)

        out = super().forward(*args, **kwargs)

        if self.tricksy_config.full_offload:
            self.fc1_weight = self.fc1_bias = self.fc2_weight = None
        elif self.tricksy_context.is_prompt_phase:
            # Only keep sparse MLP weights on GPU after prompt phase
            self.fc1_weight = self.fc1_weight[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
            self.fc1_bias = self.fc1_bias[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
            self.fc2_weight = self.fc2_weight[self.index_cache.gpu_cached_mlp_indices.to('cuda')]

        # Update ring buffers
        if not self.tricksy_config.full_offload:
            prev_ring_idx = self.ring_idx
            self.ring_idx = (self.ring_idx + self.fc1_weight_diff.size(0)) % self.fc1_weight.size(0)
            if self.ring_idx > prev_ring_idx:
                # does not wrap around ring
                self.fc1_weight[prev_ring_idx:self.ring_idx] = self.fc1_weight_diff
                self.fc1_bias[prev_ring_idx:self.ring_idx] = self.fc1_bias_diff
                self.fc2_weight[prev_ring_idx:self.ring_idx] = self.fc2_weight_diff
            elif self.fc1_weight_diff.size(0) > 0:
                # wraps around ring
                split = self.fc1_weight_diff.size(0) - self.ring_idx
                self.fc1_weight[prev_ring_idx:] = self.fc1_weight_diff[:split]
                self.fc1_weight[:self.ring_idx] = self.fc1_weight_diff[split:]
                self.fc1_bias[prev_ring_idx:] = self.fc1_bias_diff[:split]
                self.fc1_bias[:self.ring_idx] = self.fc1_bias_diff[split:]
                self.fc2_weight[prev_ring_idx:] = self.fc2_weight_diff[:split]
                self.fc2_weight[:self.ring_idx] = self.fc2_weight_diff[split:]
            self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None

        self.tricksy_context.layer += 1
        return out

class TricksyOPTDecoder(OPTDecoder, TricksyLayer):
    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.config = tricksy_config.opt_config
        self.dropout = self.config.dropout
        self.layerdrop = self.config.layerdrop
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.vocab_size = self.config.vocab_size
        self._use_flash_attention_2 = False
        self.gradient_checkpointing = False
        self.project_out = None
        self.project_in = None

        self.embed_tokens_weight = None
        self.embed_positions = TricksyOPTLearnedPositionalEmbedding(tricksy_context)

        self.tricksy_context = tricksy_context
        self.layers: List[TricksyOPTDecoderLayer] = []
        for i in range(self.config.num_hidden_layers):
            pretrained_layer_num = self.config.num_hidden_layers - i - 1
            sparsity_predictors = [load_mlp_sparsity_predictor(disk_weights.weight_path, pretrained_layer_num, tricksy_config.dtype)]
            if sparsity_predictors[0] is None:
                sparsity_predictors[0] = lambda x: F.linear(x, torch.rand((self.config.ffn_dim, self.config.hidden_size), device='cuda', dtype=tricksy_config.dtype))
            self.layers.append(TricksyOPTDecoderLayer(
                tricksy_config,
                TricksyLayerInputs(
                    disk_weights=disk_weights,
                    layer_key_prefix=f'decoder.layers.{pretrained_layer_num}.',
                    # While computing MLP, load next attention
                    # While computing last MLP, load output embeddings (stored in TricksyOPTForCausalLM)
                    next_layer=self.layers[i - 1].self_attn if i > 0 else tricksy_opt_for_causal_lm,
                    sparsity_predictors=sparsity_predictors,
                ),
                tricksy_context,
            ))
        self.layers.reverse()

        self.final_layer_norm = lambda x: x
        self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.')

    def embed_tokens(self, x):
        return F.embedding(x, self.embed_tokens_weight, self.padding_idx)

    def load_weights(self, tricksy_context: TricksyContext):
        if self.embed_tokens_weight is None:
            self.embed_tokens_weight, self.embed_positions.weight = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('embed_positions.weight')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )

    def forward(self, *args, **kwargs):
        # Wait for input embedding weights to get to GPU
        torch.cuda.synchronize()

        # While computing input embeddings, load first attention
        self.layers[0].self_attn.load_weights(self.tricksy_context)

        out = super().forward(*args, **kwargs)

        # Wait for output embedding weights to get to GPU
        torch.cuda.synchronize()

        # No longer prompt phase after first full pass
        self.tricksy_context.is_prompt_phase = False
        # Load input embeddings while computing output
        self.load_weights(self.tricksy_context)

        return out

class TricksyOPTModel(OPTModel):
    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.config = tricksy_config.opt_config
        self.tricksy_context = tricksy_context
        self.decoder = TricksyOPTDecoder(tricksy_config, disk_weights, tricksy_opt_for_causal_lm, tricksy_context)

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        return out

# who's got the weights?
# [InputEmbedding, Attention.0, MLP.0, Attention.1, MLP.1, ..., OutputEmbedding]
# [TricksyOPTDecoder, TricksyOPTAttention.0, TricksyOPTDecoderLayer.0, TricksyOPTAttention.1, TricksyDecoderLayer.1, ..., TricksyOPTForCausalLM]
#
# 1. Prompt pass: Before computing layer, send full dense weights to GPU. After computing layer, only keep sparse weights on GPU.
# 2. Generation passes: Before computing layer, compute and send sparse weight diff to GPU.
class TricksyOPTForCausalLM(OPTForCausalLM, TricksyLayer):
    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights):
        nn.Module.__init__(self)
        self.config = disk_weights.config
        self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None

        self.tricksy_context = TricksyContext(tricksy_config, self.config)
        self.model = TricksyOPTModel(tricksy_config, disk_weights, self, self.tricksy_context)

        self.final_layer_norm_weight = self.lm_head_weight = self.final_layer_norm_bias = None
        # double stacking tricksy!
        self.final_layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.final_layer_norm_weight, self.final_layer_norm_bias)
        self.lm_head = lambda x: F.linear(self.final_layer_norm(x), self.lm_head_weight)

        self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.', next_layer=self.model.decoder)

    def load_weights(self, tricksy_context: TricksyContext):
        if self.final_layer_norm_weight is None:
            self.final_layer_norm_weight, self.lm_head_weight, self.final_layer_norm_bias = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )

    def forward(self, *args, **kwargs):
        torch.cuda.synchronize()
        start = time.time()
        out = super().forward(*args, **kwargs)
        torch.cuda.synchronize()
        self.tricksy_context.forward_times.append(time.time() - start)
        self.tricksy_context.layer = 0
        return out

    def generate(self, *args, **kwargs):
        # Load input embeddings for first token
        self.model.decoder.load_weights(self.tricksy_context)
        torch.cuda.synchronize()
        out = super().generate(*args, **kwargs)
        return out
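
The comment block above TricksyOPTForCausalLM summarizes the schedule: during the prompt pass each layer receives its full dense weights and then keeps only the currently predicted sparse MLP rows on the GPU; during generation, attention for layer i overlaps with loading layer i's MLP diff, and the MLP overlaps with loading layer i+1's attention. The newly predicted rows are spliced into a ring buffer of cached rows in TricksyOPTDecoderLayer.forward. The standalone sketch below (illustration only, not part of the commit) replays that ring-buffer splice with tiny CPU tensors:

# Illustration only: the ring-buffer update TricksyOPTDecoderLayer uses to splice
# newly loaded sparse rows ("diffs") into the rows already cached on the GPU.
import torch

cached_rows = torch.arange(8, dtype=torch.float32).reshape(8, 1)  # stands in for fc1_weight on GPU
new_rows = torch.tensor([[100.], [101.], [102.]])                 # stands in for fc1_weight_diff
ring_idx = 6                                                      # next slot to overwrite

prev_ring_idx = ring_idx
ring_idx = (ring_idx + new_rows.size(0)) % cached_rows.size(0)
if ring_idx > prev_ring_idx:
    # does not wrap around the ring
    cached_rows[prev_ring_idx:ring_idx] = new_rows
elif new_rows.size(0) > 0:
    # wraps around: fill the tail of the ring, then the head
    split = new_rows.size(0) - ring_idx
    cached_rows[prev_ring_idx:] = new_rows[:split]
    cached_rows[:ring_idx] = new_rows[split:]

print(cached_rows.flatten().tolist())  # -> [102, 1, 2, 3, 4, 5, 100, 101]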
util.py
ADDED
@@ -0,0 +1,83 @@
import os
from typing import List, Callable

import numpy as np
import torch
from torch.nn import functional as F

np_dtype_to_torch_dtype = {
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int32: torch.int32,
    np.int64: torch.int64,
    bool: torch.bool,
}

class IndexDiff:
    def __init__(self, off_elements: torch.Tensor=None, off_positions: torch.Tensor=None, on_positions: torch.Tensor=None):
        self.off_elements = off_elements
        self.off_positions = off_positions
        self.on_positions = on_positions

def batch_copy(sources: List[torch.Tensor], copy_stream, indices=None, device='cuda'):
    with torch.cuda.stream(copy_stream):
        out = ()
        for src in sources:
            indexed = src[indices] if indices is not None else src
            dst = torch.empty(indexed.shape, device=device, dtype=src.dtype)
            dst.copy_(indexed, non_blocking=True)
            out += (dst,)
    return out

def mmap_to_tensor(torch_wrapped_mmap, pin_memory=False) -> torch.Tensor:
    out = torch.empty(torch_wrapped_mmap.shape, dtype=torch_wrapped_mmap.dtype, device='cpu', pin_memory=pin_memory)
    out.copy_(torch_wrapped_mmap)
    return out

# Assuming that each entry of cached_indices is a step down the memory hierarchy,
# compute the diff at each level of the hierarchy.
# e.g. the first loop computes the indices that the GPU does not have,
# and the second loop computes the indices *of that diff* that the CPU does not have.
def compute_index_diffs(new_indices: torch.Tensor, cached_indices_list: List[torch.Tensor], pin_memory=True):
    diffs = []
    current_diff = new_indices
    for cached_indices in cached_indices_list:
        if current_diff.size(0) == 0:
            # No need to go further down the hierarchy
            break

        # Compute elements of new indices not contained current indices
        off_elements = torch.tensor(
            list(set(current_diff.tolist()).difference(set(cached_indices.tolist()))),
            device='cpu',
            dtype=torch.int32,
            pin_memory=pin_memory
        )
        # Compute mask of current indices where new indices does not contain the element
        on_position_mask = torch.isin(cached_indices, current_diff, assume_unique=True)
        on_positions = torch.nonzero(on_position_mask).flatten()
        off_positions = torch.nonzero(~on_position_mask).flatten()[:off_elements.size(0)]

        diffs.append(IndexDiff(off_elements, off_positions, on_positions))
        current_diff = off_elements
    return diffs

def topk_and_threshold(x, k, threshold=1):
    vals, indices = torch.topk(x, k, sorted=True)
    return indices[vals > threshold].int()

def load_mlp_sparsity_predictor(weight_path_prefix: str, layer_num: int, dtype: torch.dtype, device: str = 'cuda') -> Callable:
    path_prefix = f'{weight_path_prefix}decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.'
    return load_predictor(path_prefix, dtype, device=device)

def load_predictor(path_prefix: str, dtype: torch.dtype, device: str='cuda') -> Callable:
    path = lambda i: os.path.expanduser(f'{path_prefix}{i}.weight')
    if os.path.exists(path(1)):
        l1 = torch.load(path(1)).to(device).to(dtype)
        l2 = torch.load(path(2)).to(device).to(dtype)
        return lambda x: F.linear(F.linear(x, l1), l2)
    else:
        print(f'could not find predictor at {path(1)}')
        return None
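
compute_index_diffs and topk_and_threshold operate purely on index and score tensors, so they can be exercised on the CPU. A small worked example (illustration only, not part of the commit) showing what they return for the MLP-cache use case above:

# Illustration only: exercising compute_index_diffs and topk_and_threshold on tiny tensors.
import torch
from util import compute_index_diffs, topk_and_threshold

# Indices currently cached on the GPU vs. the newly predicted sparse MLP indices.
cached_on_gpu = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
predicted = torch.tensor([2, 3, 4, 5], dtype=torch.int32)

diff = compute_index_diffs(predicted, [cached_on_gpu], pin_memory=False)[0]
print(sorted(diff.off_elements.tolist()))   # [4, 5] -> rows that must be copied to the GPU
print(sorted(diff.off_positions.tolist()))  # [0, 1] -> cache slots (currently holding 0 and 1) they overwrite

# topk_and_threshold keeps the k largest predictor scores that clear the threshold.
scores = torch.tensor([0.2, 3.0, 1.5, 0.1])
print(topk_and_threshold(scores, k=3).tolist())  # [1, 2] (values 3.0 and 1.5 exceed the default threshold of 1)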