Wendy-Fly committed on
Commit af80e72 · verified · 1 Parent(s): 7d96ba0

Upload mgie_llava.py with huggingface_hub

Files changed (1)
mgie_llava.py +447 -0
mgie_llava.py ADDED
@@ -0,0 +1,447 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
# modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM, \
                         CLIPVisionModel, CLIPImageProcessor

from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

import os, diffusers

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class LlavaConfig(LlamaConfig):
    model_type = "llavaa"


class LlavaLlamaModel_(LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel_, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # HACK: for FSDP
            self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
            # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)

        if hasattr(config, "use_mm_proj"):
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
                                  pretrain_mm_mlp_adapter=None, fsdp=None):
        self.config.mm_vision_tower = vision_tower

        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

        if not hasattr(self, 'vision_tower'):
            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        else:
            vision_tower = self.vision_tower[0]
        vision_tower.requires_grad_(False)

        if fsdp is not None and len(fsdp) > 0:
            self.vision_tower = [vision_tower]
        else:
            self.vision_tower = vision_tower

        vision_config = vision_tower.config
        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

        return dict(
            image_processor=image_processor,
            image_token_len=num_patches,
            vision_config=vision_config
        )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        # if orig_embeds_params is not None:
        #     orig_embeds_params = orig_embeds_params[0]
        #     with torch.no_grad():
        #         self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data

        if inputs_embeds is None:
            # input_ids here hold the tokenized text plus the image placeholder tokens
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = self.get_vision_tower()  # a CLIPVisionModel
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
            # TODO: this is a modified multimodal LLM -- Haotian Liu
            with torch.no_grad():
                if type(images) is list:
                    # variable length images
                    image_features = []
                    for image in images:
                        # encode the image into visual features with CLIP
                        image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
                        # take the hidden states of the layer given by mm_vision_select_layer
                        # (falls back to the last layer if the config does not set it)
                        select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
                        select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
                        # drop the CLS token and keep only the patch features
                        image_feature = select_hidden_state[:, 1:]
                        image_features.append(image_feature)
                else:
                    image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True)
                    select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
                    select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
                    image_features = select_hidden_state[:, 1:].to(images.dtype)

            # project the visual features into the LLM embedding space
            if type(images) is list:
                image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
            else:
                image_features = self.mm_projector(image_features)

            # zero (256, 1024) dummy features, used to keep text-only samples in the graph
            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = self.mm_projector(dummy_image_features)

            new_input_embeds = []
            cur_image_idx = 0
            # merge the modalities: splice the image features into the text embeddings
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                # check whether this sample contains any image patch tokens
                if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is text-only
                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    cur_image_idx += 1
                    continue
                if vision_tower.config.use_im_start_end:
                    # fetch the image features for this sample
                    cur_image_features = image_features[cur_image_idx]
                    # number of patches the image was split into
                    num_patches = cur_image_features.shape[0]
                    # the numbers of image start and end tokens must match
                    if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")
                    # locate every image start token in the sequence
                    image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
                    for image_start_token_pos in image_start_tokens:
                        cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
                            raise ValueError("The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(),  # text embeddings before the image start token
                                                              cur_input_embeds[image_start_token_pos:image_start_token_pos+1],  # embedding of the image start token
                                                              cur_image_features,  # image feature embeddings
                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:
                                                                               image_start_token_pos + num_patches + 2],  # embedding of the image end token
                                                              cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()),  # text embeddings after the image end token
                                                             dim=0)
                        else:
                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1],  # text embeddings up to and including the image start token
                                                              cur_image_features,  # image feature embeddings
                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:]),  # remaining text embeddings from the image end token onward
                                                             dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
                        raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
                    masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
                    mask_index_start = masked_indices[0]
                    if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
                        raise ValueError("The image patch tokens should be consecutive.")
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
                    else:
                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_image_idx += 1
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(LlavaLlamaModel_, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

class EditMapper(nn.Module):
    '''
    Maps the LLM hidden states of the edit tokens into conditioning features
    for the diffusion UNet, using a learnable query (self.query) and a small
    Transformer, followed by a projection to the UNet cross-attention width.
    '''
    def __init__(self):
        super().__init__()

        self.llm2hid = nn.Linear(4096, 512)
        self.query = nn.Parameter(torch.randn(1, 77, 512))
        self.mapper = nn.Transformer(batch_first=True, norm_first=True,
                                     d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
                                     dim_feedforward=2048, dropout=0.0)
        self.hid2feat = nn.Linear(512, 768)

    def forward(self, llm, emb):
        hid = self.llm2hid(llm+emb)
        hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
        feat = self.hid2feat(hid)

        return feat

class LlavaLlamaForCausalLM_(LlamaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel_(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.edit_head = EditMapper()

        self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('/root/autodl-tmp/_ckpt/stable_diffusion', subfolder='scheduler'),
                                               diffusers.AutoencoderKL.from_pretrained('/root/autodl-tmp/_ckpt/stable_diffusion', subfolder='vae'),
                                               diffusers.UNet2DConditionModel.from_pretrained('/root/autodl-tmp/_ckpt/stable_diffusion', subfolder='unet')]
        self.vae.requires_grad_(False)
        # the UNet takes the noisy target latent concatenated with the input-image latent (4+4 channels)
        self.unet.register_to_config(in_channels=8)
        with torch.no_grad():
            conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
            conv.weight.zero_()
            conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            self.unet.conv_in = conv

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def get_vision_tower(self):
        model = self.get_model()
        vision_tower = model.vision_tower
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        p2p_inp=None, p2p_ans=None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            images=images
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n:
            # autoregressive training shifts predictions by one position,
            # so the last logit has no target and is dropped
            shift_logits = logits[..., :-1, :].contiguous()
            # the first label (e.g. the BOS token added at tokenization) has no
            # preceding prediction, so it is dropped as well
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens and compute the cross-entropy loss
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if labels is not None:
            llm = []
            # for each sample, locate the 8 edit tokens in the labels and collect their hidden states
            for i in range(labels.shape[0]):
                try: p = labels[i].data.cpu().tolist().index(32003)-1
                except: p = len(labels[i])-9
                p = min(len(hidden_states[i])-9, p)
                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
            llm = torch.cat(llm, dim=0)
            hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

            B, DROP = labels.shape[0], 0.05

            hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
                                      self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

            with torch.no_grad():
                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, \
                                   self.vae.encode(p2p_inp).latent_dist.mode()
                lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
                                    torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]

            noise = torch.randn_like(lat_ans)
            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
            lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)

            ## classifier-free guidance dropout for the edit condition: randomly replace hid_edit with hid_null
            prob = torch.rand(B, device=lat_ans.device)
            mask = (prob<(DROP*2)).reshape(B, 1, 1)
            hid_edit = torch.where(mask, hid_null, hid_edit)
            ## classifier-free guidance dropout for the image condition: randomly zero out lat_inp
            mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
            lat_inp *= mask

            out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample

            loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
            if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
            loss = loss_ce+loss_edit*0.5

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
                                    tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
        vision_config = self.get_vision_tower().config
        vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # initialize the new token embeddings with the mean of the existing ones
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_mm_mlp_adapter:
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")

        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]

AutoConfig.register("llavaa", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM_)
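
Because the file registers LlavaConfig under the "llavaa" model type with AutoConfig and AutoModelForCausalLM, the class can be loaded through the standard transformers auto-classes once the module has been imported. A minimal sketch of that loading path, assuming a hypothetical local checkpoint directory './mgie_ckpt' whose config declares model_type "llavaa" (the path and dtype choice are placeholders, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import mgie_llava  # noqa: F401 -- importing runs the AutoConfig/AutoModelForCausalLM registration above

# './mgie_ckpt' is a hypothetical checkpoint directory; substitute the actual MGIE weights.
ckpt = './mgie_ckpt'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.eval()

# The forward pass expects CLIP-preprocessed pixel values via `images` alongside `input_ids`;
# see LlavaLlamaForCausalLM_.forward and prepare_inputs_for_generation above.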