kyujinpy committed
Commit 31b1ccb · verified
1 Parent(s): 8b45c88

Upload modeling_ovis.py

Files changed (1)
  1. modeling_ovis.py +620 -0
modeling_ovis.py ADDED
@@ -0,0 +1,620 @@
import logging
import os
from packaging import version
from importlib import import_module
from typing import List, Callable, Union, Optional, Dict

import PIL.Image
import torch
import transformers
from torch import Tensor
from torch.nn import init
from torch.nn.functional import softmax, gumbel_softmax, pad
from transformers import PreTrainedModel, AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoImageProcessor
from transformers import SiglipImageProcessor, SiglipVisionModel
from transformers.cache_utils import HybridCache
from transformers.generation.utils import GenerateOutput

from .configuration_ovis import BaseVisualTokenizerConfig, SiglipVisualTokenizerConfig
from .configuration_ovis import OvisConfig, ConversationFormatter
from .configuration_ovis import IGNORE_ID, IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS, IMAGE_TOKEN_ID


# ----------------------------------------------------------------------
# Visual Tokenizer
# ----------------------------------------------------------------------
class BaseVisualTokenizer(PreTrainedModel):
    base_model_prefix = "backbone"
    main_input_name = None
    _image_processor_class = None
    _image_processor_kwargs = {}
    _backbone_class = None
    _backbone_name_or_path = None

    def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
        self.backbone = AutoModel.from_config(self.config.backbone_config)
        head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS)  # reserved tokens for IMAGE_INDICATORS
        self.head = torch.nn.Sequential(
            torch.nn.Linear(
                self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride, head_dim,
                bias=False
            ),
            torch.nn.LayerNorm(head_dim)
        )

        assert all((self.image_processor.do_resize,
                    not getattr(self.image_processor, 'do_center_crop', False),
                    self.image_processor.do_rescale,
                    self.image_processor.do_normalize
                    )), f"image_processor `{self.image_processor}` is not supported currently"

    def get_backbone(self):
        return self.backbone

    def get_image_processor(self):
        return self.image_processor

    def mock_input(self):
        height, width = self.get_image_size()
        return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))

    def get_head(self):
        return self.head

    def get_image_size(self):
        raise NotImplementedError

    @staticmethod
    def construct_image_placeholders(grid):
        image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
        if grid[0] * grid[1] > 1:
            for r in range(grid[0]):
                for c in range(grid[1]):
                    image_placeholders.append(IMAGE_ATOM_ID)
                    if c < grid[1] - 1:
                        image_placeholders.append(IMAGE_INDICATOR_IDS[2])
                if r < grid[0] - 1:
                    image_placeholders.append(IMAGE_INDICATOR_IDS[3])
        image_placeholders.append(IMAGE_INDICATOR_IDS[4])
        return image_placeholders

    def preprocess_image(self, image: PIL.Image.Image, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
        def _preprocess(img: PIL.Image.Image, side):
            # first resize and preprocess
            w, h = img.size
            if w == h:
                new_width = new_height = side
            elif w > h:
                new_width = side
                new_height = int(h / w * new_width)
            else:
                new_height = side
                new_width = int(w / h * new_height)
            new_size = dict(height=new_height, width=new_width)
            pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']

            # then pad to square
            square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
            new_height, new_width = pixel_values.shape[2:]
            if new_height == new_width:
                square_values[:, :, :, :] = pixel_values
            elif new_height > new_width:
                from_index = (side - new_width) // 2
                square_values[:, :, :, from_index:from_index + new_width] = pixel_values
            else:
                from_index = (side - new_height) // 2
                square_values[:, :, from_index:from_index + new_height, :] = pixel_values

            return square_values

        def _partition(img, grid):
            w, h = img.size
            row_height = h // grid[0]
            col_width = w // grid[1]

            partition = []
            for row in range(grid[0]):
                for col in range(grid[1]):
                    left = col * col_width
                    upper = row * row_height
                    right = w if col == grid[1] - 1 else (col + 1) * col_width
                    lower = h if row == grid[0] - 1 else (row + 1) * row_height
                    partition.append((left, upper, right, lower))

            return partition

        def _covering_area(left, upper, right, lower, side):
            w = right - left
            h = lower - upper
            w, h = max(w, h), min(w, h)
            if w > side:
                h = h / w * side
                w = side
            return w * h

        def _get_best_grid(img, side):
            img_area = img.size[0] * img.size[1]

            candidate_grids = []
            for i in range(1, max_partition + 1):
                for j in range(1, max_partition + 1):
                    if i * j <= max_partition:
                        candidate_grids.append((i, j))

            all_grids = []
            good_grids = []
            for grid in candidate_grids:
                partition = _partition(img, grid)
                covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
                assert covering_ratio <= 1.0
                all_grids.append((grid, covering_ratio))
                if covering_ratio > covering_threshold:
                    good_grids.append((grid, covering_ratio))

            if len(good_grids) > 0:
                # pick the good partition with minimum #sub_images and break the tie using covering_ratio
                return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
            else:
                # pick the partition with maximum covering_ratio and break the tie using #sub_images
                return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]

        if convert_to_rgb and image.mode != 'RGB':
            image = image.convert('RGB')

        sides = self.get_image_size()
        if sides[0] != sides[1]:
            raise ValueError('get_image_size() returns non-square size')
        side = sides[0]
        grid = _get_best_grid(image, side)
        partition = _partition(image, grid)
        crops = [image.crop(p) for p in partition]
        if len(crops) > 1:
            crops.insert(0, image)
        pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
        image_placeholders = self.construct_image_placeholders(grid)
        return pixel_values, image_placeholders

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through softmax
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.config.tokenize_function == 'softmax':
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == 'gumbel_argmax':
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax or gumbel_argmax or st_argmax, but got {self.config.tokenize_function}')
        return tokens

    def encode(self, pixel_values):
        output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
        features = output.hidden_states[-1]
        if self.config.drop_cls_token:
            features = features[:, 1:, :]

        # merge `hidden_stride * hidden_stride` hidden states together to reduce token sequence length
        # e.g., for hidden_stride=3, this leads to a token length reduction: 729 -> 81 for siglip
        if self.config.hidden_stride > 1:
            n, l, d = features.shape  # this `d` may be different from the above `d`
            sqrt_l = int(l ** 0.5)
            assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
            sqrt_l += pl
            features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
                                        sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
            features = features.permute(0, 1, 3, 2, 4, 5)  # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
            features = features.flatten(3)  # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
            features = features.reshape(
                n, -1, self.config.hidden_stride * self.config.hidden_stride * d)

        return features

    def forward(self, pixel_values) -> torch.Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
        features = self.encode(pixel_values)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with [BatchSize, #Token, 5], after
        # which, tokens' shape should become [BatchSize, #Token, VocabSize]
        batch_size, token_len, _ = tokens.shape
        padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
                                     dtype=tokens.dtype,
                                     device=tokens.device,
                                     layout=tokens.layout,
                                     requires_grad=False)
        tokens = torch.cat((tokens, padding_tensor), dim=2)
        return tokens


class SiglipVisualTokenizer(BaseVisualTokenizer):
    config_class = SiglipVisualTokenizerConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["SiglipVisionTransformer"]
    _image_processor_class = SiglipImageProcessor
    _image_processor_kwargs = {}
    _backbone_class = SiglipVisionModel
    _backbone_name_or_path = "google/siglip-so400m-patch14-384"

    def get_image_size(self):
        height = self.image_processor.size["height"]
        width = self.image_processor.size["width"]
        return height, width


AutoModel.register(SiglipVisualTokenizerConfig, SiglipVisualTokenizer)


# ----------------------------------------------------------------------
# Ovis
# ----------------------------------------------------------------------
class VisualEmbedding(torch.nn.Embedding):
    def forward(self, visual_tokens: Tensor) -> Tensor:
        if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
            return super().forward(visual_tokens)
        return torch.matmul(visual_tokens, self.weight)

    def reset_parameters(self, mean=0., std=1.) -> None:
        init.normal_(self.weight, mean=mean, std=std)
        self._fill_padding_idx_with_zero()


class OvisPreTrainedModel(PreTrainedModel):
    config_class = OvisConfig
    base_model_prefix = "ovis"


class Ovis(OvisPreTrainedModel):

    def __init__(self, config: OvisConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        attn_kwargs = dict()
        if self.config.llm_attn_implementation:
            attn_kwargs['attn_implementation'] = self.config.llm_attn_implementation
        self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
        assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
        self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
        self.visual_tokenizer = AutoModel.from_config(self.config.visual_tokenizer_config,
                                                      image_processor_name_or_path=self.config.name_or_path)
        self.vte = VisualEmbedding(
            self.config.visual_tokenizer_config.vocab_size,
            self.config.hidden_size,
            device=self.visual_tokenizer.device,
            dtype=self.visual_tokenizer.dtype
        )

        def _merge_modules(modules_list: tuple):
            merged_modules = []
            for modules in modules_list:
                merged_modules.extend(modules if modules else [])
            return merged_modules

        self._no_split_modules = _merge_modules((self.llm._no_split_modules, self.visual_tokenizer._no_split_modules))
        self._skip_keys_device_placement = self.llm._skip_keys_device_placement
        self._keep_in_fp32_modules = _merge_modules(
            (self.llm._keep_in_fp32_modules, self.visual_tokenizer._keep_in_fp32_modules))
        self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.is_parallelizable))
        self.supports_gradient_checkpointing = all(
            (self.llm.supports_gradient_checkpointing, self.visual_tokenizer.supports_gradient_checkpointing))
        self._supports_flash_attn_2 = all(
            (self.llm._supports_flash_attn_2, self.visual_tokenizer._supports_flash_attn_2))
        self._supports_sdpa = all((self.llm._supports_sdpa, self.visual_tokenizer._supports_sdpa))

    def get_text_tokenizer(self):
        return self.text_tokenizer

    def get_visual_tokenizer(self):
        return self.visual_tokenizer

    def tie_weights(self):
        if not self.config.disable_tie_weight:
            self.get_llm().tie_weights()

    def get_llm(self):
        return self.llm

    def get_vte(self):
        return self.vte

    def get_wte(self):
        return self.llm.get_input_embeddings()

    def get_conversation_formatter(self) -> ConversationFormatter:
        if getattr(self, 'conversation_formatter', None) is None:
            self.conversation_formatter = getattr(import_module(".configuration_ovis", __package__),
                                                  self.config.conversation_formatter_class)(self.text_tokenizer)
        return self.conversation_formatter

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            labels: Optional[torch.Tensor],
            pixel_values: List[Optional[torch.Tensor]],
            **kwargs
    ):
        assert self.training, "`forward` can only be used in training. For inference, use `generate`."
        _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
            text_input_ids=input_ids,
            text_attention_masks=attention_mask,
            text_labels=labels,
            pixel_values=pixel_values
        )
        return self.llm(inputs_embeds=inputs_embeds, labels=labels, attention_mask=attention_mask, **kwargs)

    def merge_multimodal(
            self,
            text_input_ids: torch.Tensor,
            text_attention_masks: torch.Tensor,
            text_labels: Optional[torch.Tensor],
            pixel_values: List[Optional[torch.Tensor]],
            left_padding: bool = False
    ):
        input_device = text_input_ids.device
        visual_vocab_size = self.get_visual_tokenizer().config.vocab_size
        visual_indicator_embeds = self.get_vte()(
            torch.tensor(
                list(range(visual_vocab_size - 5, visual_vocab_size)),
                dtype=torch.long,
                device=self.get_visual_tokenizer().device
            )
        ).to(device=input_device)

        if self.training:
            # When training, to be compatible with deepspeed zero, each sample has to include pixel_value tensor.
            # For text-only sample, one can simply use a full zero tensor as pixel_value, which will be ignored
            # (see below in this function); so, the gradient will not be affected.
            num_images = [x.shape[0] for x in pixel_values]
            visual_tokens = self.visual_tokenizer(torch.cat([x for x in pixel_values], dim=0))
            visual_embeds = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
                                        split_size_or_sections=num_images, dim=0)
            visual_input_ids = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
                                           split_size_or_sections=num_images, dim=0)
            visual_labels = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device) for x in
                             visual_input_ids]
        else:
            # When inference, sample can include only text with `None` pixel_value
            num_images = [x.shape[0] if x is not None else 0 for x in pixel_values]
            if sum(num_images) > 0:
                visual_tokens = self.visual_tokenizer(torch.cat([x for x in pixel_values if x is not None], dim=0))
                visual_embeds = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
                                            split_size_or_sections=num_images, dim=0)
                visual_input_ids = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
                                               split_size_or_sections=num_images, dim=0)
                visual_labels = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device) for x in
                                 visual_input_ids]
            else:
                # just placeholders
                visual_embeds = [None] * len(num_images)
                visual_input_ids = [None] * len(num_images)
                visual_labels = [None] * len(num_images)
            if text_labels is None:
                text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)

        input_embeds = []
        attention_masks = []
        labels = []
        for text_input_id, text_label, text_attention_mask, visual_embed, visual_input_id, visual_label in zip(
                text_input_ids, text_labels, text_attention_masks, visual_embeds, visual_input_ids, visual_labels
        ):
            placeholder_token_mask = torch.lt(text_input_id, 0)
            text_embed = self.get_wte()(torch.masked_fill(text_input_id, placeholder_token_mask, 0))
            for i, indicator_id in enumerate(IMAGE_INDICATOR_IDS):
                text_embed[text_input_id == indicator_id] = visual_indicator_embeds[i]
            image_atom_positions = torch.where(torch.eq(text_input_id, IMAGE_ATOM_ID))[0].tolist()
            if len(image_atom_positions) > 0:
                input_embed_parts = []
                attention_mask_parts = []
                label_parts = []
                prev_image_atom_position = -1
                for index, image_atom_position in enumerate(image_atom_positions):
                    input_embed_parts.append(
                        text_embed[prev_image_atom_position + 1:image_atom_position, :])
                    label_parts.append(
                        text_label[prev_image_atom_position + 1:image_atom_position])
                    attention_mask_parts.append(
                        text_attention_mask[prev_image_atom_position + 1:image_atom_position])
                    input_embed_parts.append(visual_embed[index])
                    attention_mask_parts.append(
                        torch.ones_like(visual_label[index], dtype=torch.bool))
                    label_parts.append(visual_label[index])
                    prev_image_atom_position = image_atom_position
                if prev_image_atom_position + 1 < text_input_id.shape[0]:
                    input_embed_parts.append(
                        text_embed[prev_image_atom_position + 1:, :])
                    attention_mask_parts.append(
                        text_attention_mask[prev_image_atom_position + 1:])
                    label_parts.append(
                        text_label[prev_image_atom_position + 1:])
                input_embed = torch.cat(input_embed_parts, dim=0)
                attention_mask = torch.cat(attention_mask_parts, dim=0)
                label = torch.cat(label_parts, dim=0)
            else:
                input_embed = text_embed
                attention_mask = text_attention_mask
                label = text_label
            if self.training:
                # Make visual_embed & visual_indicator_embeds involved in the backward graph,
                # to be compatible with deepspeed zero and ddp.
                input_embed += torch.sum(visual_embed * 0.0) + torch.sum(visual_indicator_embeds * 0.0)
            input_embeds.append(input_embed)
            attention_masks.append(attention_mask)
            labels.append(label)

        if self.training:  # padding to self.config.multimodal_max_length for increased training speed
            padding_size = max(0, self.config.multimodal_max_length - len(input_embeds[0]))
            input_embeds[0] = torch.nn.ConstantPad2d((0, 0, 0, padding_size), 0.0)(input_embeds[0])
            attention_masks[0] = torch.nn.ConstantPad1d((0, padding_size), False)(attention_masks[0])
            labels[0] = torch.nn.ConstantPad1d((0, padding_size), IGNORE_ID)(labels[0])
        batch_input_embeds = self.pad_truncate_sequence(input_embeds, batch_first=True, padding_value=0.0, left_padding=left_padding)
        batch_attention_mask = self.pad_truncate_sequence(attention_masks, batch_first=True, padding_value=False, left_padding=left_padding)
        batch_labels = self.pad_truncate_sequence(labels, batch_first=True, padding_value=IGNORE_ID, left_padding=left_padding)

        return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask

    def pad_truncate_sequence(self, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
        if not left_padding:
            pad_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
            return pad_sequence[:, :self.config.multimodal_max_length]
        else:
            pad_sequence = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in sequences], batch_first=True, padding_value=padding_value).flip(dims=[1])
            return pad_sequence[:, -self.config.multimodal_max_length:]

    def preprocess_inputs(
            self,
            text_or_conversations: Union[List[Dict], str],
            images: Optional[List[PIL.Image.Image]],
            max_partition=9,
            generation_preface='',
            return_labels=False,
            propagate_exception=True
    ):
        # convert text to conversations
        if isinstance(text_or_conversations, str):
            conversations = [{
                "from": "human",
                "value": text_or_conversations
            }]
        elif isinstance(text_or_conversations, list):
            conversations = text_or_conversations
        else:
            raise ValueError(f'Invalid type of `text_or_conversations`, expected `List[Dict]` or `str`,'
                             f' but got {type(text_or_conversations)}')

        # format conversations
        prompt, raw_input_ids, raw_labels = self.get_conversation_formatter().format(
            conversations, generation_preface=generation_preface)

        # place image placeholders
        input_ids = []
        labels = []
        pixel_values = []
        invalidate_label = False
        image_token_indices = [i for i, v in enumerate(raw_input_ids) if v == IMAGE_TOKEN_ID]
        last_image_token_index = -1
        for i in range(len(image_token_indices)):
            head = 0 if i == 0 else image_token_indices[i - 1] + 1
            tail = image_token_indices[i]
            last_image_token_index = tail
            input_ids.extend(raw_input_ids[head:tail])
            labels.extend(raw_labels[head:tail])
            try:
                image = images[i]
                raw_pixel_values, image_placeholders = self.visual_tokenizer.preprocess_image(
                    image, max_partition=max_partition)
            except Exception as e:
                if propagate_exception:
                    raise e
                logging.exception(e)
                invalidate_label = True
                raw_pixel_values, image_placeholders = self.visual_tokenizer.mock_input()
            input_ids.extend(image_placeholders)
            labels.extend([IGNORE_ID] * len(image_placeholders))
            pixel_values.append(raw_pixel_values)
        input_ids.extend(raw_input_ids[last_image_token_index + 1:])
        labels.extend(raw_labels[last_image_token_index + 1:])

        # return tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = torch.tensor([IGNORE_ID] * len(labels) if invalidate_label else labels, dtype=torch.long)
        pixel_values = torch.cat(pixel_values, dim=0) if len(pixel_values) > 0 else None

        if return_labels:
            return prompt, input_ids, pixel_values, labels
        else:
            return prompt, input_ids, pixel_values

    def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            is_main_process: bool = True,
            state_dict: Optional[dict] = None,
            save_function: Callable = torch.save,
            push_to_hub: bool = False,
            max_shard_size: Union[int, str] = "5GB",
            safe_serialization: bool = True,
            variant: Optional[str] = None,
            token: Optional[Union[str, bool]] = None,
            save_peft_format: bool = True,
            **kwargs
    ):
        super().save_pretrained(save_directory,
                                is_main_process=is_main_process,
                                state_dict=state_dict,
                                save_function=save_function,
                                safe_serialization=safe_serialization)
        self.get_text_tokenizer().save_pretrained(save_directory)
        self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)

    def _get_hybrid_cache_for_llm(self, batch_size: int, max_cache_len: int):
        cache_cls = HybridCache
        llm = self.get_llm()

        if version.parse(transformers.__version__) >= version.parse("4.46.0"):
            need_new_cache = (
                not hasattr(llm, "_cache")
                or (not isinstance(llm._cache, cache_cls))
                or llm._cache.batch_size != batch_size
                or llm._cache.max_cache_len < max_cache_len
            )
        else:
            need_new_cache = (
                not hasattr(llm, "_cache")
                or (not isinstance(llm._cache, cache_cls))
                or llm._cache.max_batch_size != batch_size
                or llm._cache.max_cache_len < max_cache_len
            )

        if need_new_cache:
            if hasattr(llm.config, "_pre_quantization_dtype"):
                cache_dtype = llm.config._pre_quantization_dtype
            else:
                cache_dtype = llm.dtype
            if version.parse(transformers.__version__) >= version.parse("4.46.0"):
                llm._cache = cache_cls(
                    config=llm.config,
                    batch_size=batch_size,
                    max_cache_len=max_cache_len,
                    device=llm.device,
                    dtype=cache_dtype,
                )
            else:
                llm._cache = cache_cls(
                    config=llm.config,
                    max_batch_size=batch_size,
                    max_cache_len=max_cache_len,
                    device=llm.device,
                    dtype=cache_dtype,
                )
        else:
            llm._cache.reset()
        return llm._cache

    # TODO: support batch generation
    def generate(
            self,
            inputs: Optional[torch.Tensor] = None,
            **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:
        _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
            text_input_ids=inputs,
            text_attention_masks=kwargs.pop('attention_mask'),
            text_labels=None,
            pixel_values=kwargs.pop('pixel_values'),
            left_padding=True
        )
        if getattr(self.generation_config, 'cache_implementation', None) == 'hybrid':  # mainly for Gemma2
            kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
                kwargs.get("num_beams", inputs_embeds.shape[0]), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
            self.get_llm()._supports_cache_class = True
            kwargs['cache_implementation'] = None

        return self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
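
For reference, a minimal usage sketch of the uploaded module, assuming the checkpoint repository bundles this file with configuration_ovis.py and is loaded with trust_remote_code=True. The repository id, image path, and generation settings below are illustrative placeholders, and the `<image>` placeholder handling comes from the ConversationFormatter defined in configuration_ovis.py rather than from this file:

# Usage sketch (illustrative; repo id and file names are placeholders, not part of this commit).
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "<repo-id-hosting-this-file>",  # hypothetical checkpoint id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).cuda()
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

# Single-image prompt; `<image>` marks where the visual tokens are spliced in.
image = Image.open("example.jpg")  # placeholder image path
query = "<image>\nDescribe this image."
prompt, input_ids, pixel_values = model.preprocess_inputs(query, [image], max_partition=9)
attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
input_ids = input_ids.unsqueeze(0).to(device=model.device)
attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        pixel_values=pixel_values,  # one entry per batch sample, consumed by merge_multimodal
        attention_mask=attention_mask,
        max_new_tokens=512,
        do_sample=False,
        eos_token_id=model.generation_config.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
    )[0]
print(text_tokenizer.decode(output_ids, skip_special_tokens=True))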