Upload blender_model.py
blender_model.py +207 -0
ADDED
@@ -0,0 +1,207 @@
"""BlenderBot Small (90M) backed by quantized ONNX Runtime sessions, wrapped so
that transformers' generate() can drive the encoder/decoder as usual."""

from transformers import (
    AutoConfig,
    BlenderbotSmallForConditionalGeneration,
    logging,
)
from transformers.modeling_outputs import (
    Seq2SeqLMOutput,
    BaseModelOutput,
)
from huggingface_hub import hf_hub_url, cached_download
from onnxruntime import (GraphOptimizationLevel,
                         InferenceSession,
                         SessionOptions)

from torch import from_numpy
from torch.nn import Module
from functools import reduce
from operator import iconcat

# suppress huggingface warnings
logging.set_verbosity_error()

model_vocab_size = 30000
model_card = "remzicam/xs_blenderbot_onnx"
model_file_names = ["blenderbot_small-90M-encoder-quantized.onnx",
                    "blenderbot_small-90M-decoder-quantized.onnx",
                    "blenderbot_small-90M-init-decoder-quantized.onnx"]


class BlenderEncoder(Module):
    """Wraps the ONNX encoder session behind the interface generate() expects."""

    def __init__(self, encoder_sess):
        super().__init__()
        self.encoder = encoder_sess

    def forward(
        self,
        input_ids,
        attention_mask,
        inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        encoder_hidden_state = from_numpy(
            self.encoder.run(
                None,
                {
                    "input_ids": input_ids.cpu().numpy(),
                    "attention_mask": attention_mask.cpu().numpy(),
                },
            )[0]
        )

        return BaseModelOutput(encoder_hidden_state)


class BlenderDecoderInit(Module):
    """Decoder pass for the first generation step, before any past key values exist."""

    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
        decoder_outputs = self.decoder.run(
            None,
            {
                "input_ids": input_ids.cpu().numpy(),
                "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
                "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
            },
        )

        # converts each value of the list to tensor from numpy
        list_pkv = tuple(from_numpy(x) for x in decoder_outputs[1:])

        # groups the flat outputs into per-layer tuples of 4
        # (self-attention and cross-attention keys and values)
        out_past_key_values = tuple(
            list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4)
        )

        return from_numpy(decoder_outputs[0]), out_past_key_values


class BlenderDecoder(Module):
    """Decoder pass for subsequent steps, feeding the cached past key values back in."""

    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
        decoder_inputs = {
            "input_ids": input_ids.cpu().numpy(),
            "encoder_attention_mask": attention_mask.cpu().numpy(),
        }

        # flattens the nested past_key_values tuples into a single list
        flat_past_key_values = reduce(iconcat, past_key_values, [])

        past_key_values = {
            f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
        }

        decoder_outputs = self.decoder.run(None, {**decoder_inputs, **past_key_values})
        # converts each value of the list to tensor from numpy
        list_pkv = tuple(from_numpy(x) for x in decoder_outputs[1:])

        # creates a tuple of per-layer tuples of 4 from the above tuple
        out_past_key_values = tuple(
            list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4)
        )

        return from_numpy(decoder_outputs[0]), out_past_key_values


class OnnxBlender(BlenderbotSmallForConditionalGeneration):
    """Creates a Blender model using ONNX sessions (encoder, decoder & init_decoder)."""

    def __init__(self, onnx_model_sessions):
        config = AutoConfig.from_pretrained("facebook/blenderbot_small-90M")
        config.vocab_size = model_vocab_size
        super().__init__(config)

        assert len(onnx_model_sessions) == 3, "all three models should be given"

        encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions

        self.encoder = BlenderEncoder(encoder_sess)
        self.decoder = BlenderDecoder(decoder_sess)
        self.decoder_init = BlenderDecoderInit(decoder_sess_init)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        encoder_hidden_states = encoder_outputs[0]

        if past_key_values is not None:
            # with a cache, only the last generated token is fed to the decoder
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if past_key_values is None:
            # runs only for the first time:
            init_onnx_outputs = self.decoder_init(
                decoder_input_ids, attention_mask, encoder_hidden_states
            )

            logits, past_key_values = init_onnx_outputs

        else:
            onnx_outputs = self.decoder(
                decoder_input_ids,
                attention_mask,
                encoder_hidden_states,
                past_key_values,
            )

            logits, past_key_values = onnx_outputs

        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)


class ModelLoad:
    """Downloads the ONNX files from the Hub and builds the onnxruntime sessions."""

    def __init__(self, model_card, file_names):
        self.model_card = model_card
        self.file_names = file_names

    def model_file_downloader(self, model_card, filename):
        config_file_url = hf_hub_url(model_card, filename)
        model_file = cached_download(config_file_url)
        return model_file

    def inference_session(self, file_name):
        model_file = self.model_file_downloader(self.model_card, file_name)
        options = SessionOptions()
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        return InferenceSession(model_file, options=options)

    def __call__(self, model_config):
        # model_config is the model class to instantiate, e.g. OnnxBlender
        model = model_config([*map(self.inference_session,
                                   self.file_names)])
        return model


model_loader = ModelLoad(model_card, model_file_names)
blender_onnx_model = model_loader(OnnxBlender)
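A minimal usage sketch, not part of the uploaded file: it assumes the BlenderbotSmallTokenizer of facebook/blenderbot_small-90M is the right companion for this quantized export (the model card should be checked to confirm). Since OnnxBlender subclasses BlenderbotSmallForConditionalGeneration, generate() works unchanged while the forward passes run through onnxruntime.

# Usage sketch; tokenizer choice is an assumption, not stated in this file.
from transformers import BlenderbotSmallTokenizer

tokenizer = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot_small-90M")
inputs = tokenizer("Hello, how are you today?", return_tensors="pt")

# beam search / sampling behave as with the regular torch model
reply_ids = blender_onnx_model.generate(**inputs, max_length=60)
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0])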