SVECTOR-OFFICIAL committed
Commit a643437 · verified · 1 parent: 7682f1f

Upload processing_spec_vision.py with huggingface_hub

Files changed (1):
  1. processing_spec_vision.py +367 -0
processing_spec_vision.py ADDED
@@ -0,0 +1,367 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processor class for Spec-Vision.
"""

import re
from typing import List, Optional, Union

import numpy as np
import torch
import torchvision
from PIL import Image
from transformers import AutoImageProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import convert_to_rgb
from transformers.image_utils import (OPENAI_CLIP_MEAN, OPENAI_CLIP_STD,
                                      ImageInput, make_list_of_images,
                                      valid_images)
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import (PaddingStrategy, TextInput,
                                                  TruncationStrategy)
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)

def padding_336(image):
    """Pad the image vertically (white fill) so its height is a multiple of 336."""
    width, height = image.size
    target_height = int(np.ceil(height / 336) * 336)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_image = torchvision.transforms.functional.pad(
        image,
        [0, top_padding, 0, bottom_padding],  # [left, top, right, bottom]
        fill=[255, 255, 255],
    )
    return padded_image


def calc_padded_size(width, height, padding_unit=336):
    """Calculate the dimensions an image will have after padding_336."""
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    padded_width = width
    padded_height = target_height
    return padded_width, padded_height

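# Worked example: a 640x500 image needs target_height = ceil(500 / 336) * 336
# = 672, so padding_336 adds 86 white rows above and 86 below, and
# calc_padded_size(640, 500) returns (640, 672).
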
def hd_transform(img, hd_num=16):
    """Apply the HD transform: resize onto a multiple-of-336 grid of at most hd_num crops."""
    width, height = img.size
    transposed = False

    # Handle portrait images by transposing so that width >= height
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        width, height = img.size
        transposed = True

    ratio = width / height
    # Find the largest horizontal crop count whose implied grid fits within hd_num crops
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    # Resize, then pad the height up to the next multiple of 336
    img = torchvision.transforms.functional.resize(img, [new_height, new_width])
    img = padding_336(img)

    # Restore original orientation if needed
    if transposed:
        img = img.transpose(Image.TRANSPOSE)

    return img

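# Worked example with the default hd_num=16: a 1024x768 image has ratio 4/3, so
# the scale loop stops at scale=4 (4 * ceil(4 / (4/3)) = 12 <= 16, while
# scale=5 would give 20 > 16). The image is resized to 1344x1008 and, since
# 1008 is already a multiple of 336, needs no extra padding: a 4x3 crop grid.
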
def pad_to_max_crops(images, max_crops=5):
    """Pad batch of images to have a consistent number of crops."""
    B, _, H, W = images.shape
    if B < max_crops:
        padding = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
        images = torch.cat([images, padding], dim=0)
    return images

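# Example: a (13, 3, 336, 336) stack (12 local crops plus one global view)
# passed through pad_to_max_crops(images, max_crops=17) comes back as
# (17, 3, 336, 336), with the last four crops all zeros.
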
class SpecVisionImageProcessor(BaseImageProcessor):
    """
    Image processor for the Spec-Vision model.

    This processor handles the preparation of images for the Spec-Vision model, including:
    - HD transformation for high-resolution image processing
    - Multi-crop processing with a configurable number of crops
    - Normalization and padding
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        num_crops: int = 1,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        hd_transform_order: str = "sub_glb",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.num_crops = num_crops
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb
        self.hd_transform_order = hd_transform_order

    def calc_num_image_tokens(self, images: ImageInput) -> List[int]:
        """Calculate the number of image tokens needed for each image."""
        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type provided")

        images = [image.convert('RGB') for image in images]
        transformed_images = [hd_transform(im, hd_num=self.num_crops) for im in images]
        shapes = [[im.size[1], im.size[0]] for im in transformed_images]  # (height, width)

        # Token count per image, derived from the 336x336 crop grid of Spec-Vision's architecture
        num_img_tokens = [
            int(((h // 336) * (w // 336) + 1) * 144 + 1 + (h // 336 + 1) * 12)
            for h, w in shapes
        ]
        return num_img_tokens

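    # Worked example: a 1344x1008 (width x height) HD-transformed image gives
    # h // 336 == 3 and w // 336 == 4, so it costs
    # (3 * 4 + 1) * 144 + 1 + (3 + 1) * 12 = 13 * 144 + 1 + 48 = 1921 tokens:
    # evidently 144 tokens for each of the 12 local crops plus the global view,
    # plus per-row separator tokens.
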
    def preprocess(
        self,
        images: ImageInput,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Preprocess images for the Spec-Vision model.

        Handles HD transformation, normalization, and proper formatting of images
        according to Spec-Vision's requirements.
        """
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        # Validate and prepare images
        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type provided")

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # Tensor conversion and CLIP-style normalization
        img_processor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(image_mean, image_std),
        ])

        # Apply Spec-Vision's HD transform
        images = [image.convert('RGB') for image in images]
        transformed_images = [hd_transform(im, hd_num=self.num_crops) for im in images]

        # Convert to tensors and normalize
        hd_images = [img_processor(im) for im in transformed_images]

        # Create a 336x336 global view of each image
        global_images = [
            torch.nn.functional.interpolate(
                im.unsqueeze(0).float(),
                size=(336, 336),
                mode='bicubic',
            ).to(im.dtype)
            for im in hd_images
        ]

        # Record (height, width) shapes and the resulting token counts
        shapes = [[im.size(1), im.size(2)] for im in hd_images]
        num_img_tokens = [
            int(((h // 336) * (w // 336) + 1) * 144 + 1 + (h // 336 + 1) * 12)
            for h, w in shapes
        ]

        # Split each HD image into a batch of 336x336 local crops
        hd_images_reshaped = [
            im.reshape(1, 3, h // 336, 336, w // 336, 336)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(-1, 3, 336, 336)
            .contiguous()
            for im, (h, w) in zip(hd_images, shapes)
        ]

        # Combine global and local views based on the configured transform order
        if self.hd_transform_order == "sub_glb":
            processed_images = [
                torch.cat([_im, _global_image], dim=0)
                for _global_image, _im in zip(global_images, hd_images_reshaped)
            ]
        else:  # glb_sub
            processed_images = [
                torch.cat([_global_image, _im], dim=0)
                for _global_image, _im in zip(global_images, hd_images_reshaped)
            ]

        # Pad every image to a consistent number of crops and stack into a batch
        image_batch = [
            pad_to_max_crops(im, self.num_crops + 1)
            for im in processed_images
        ]
        image_batch = torch.stack(image_batch, dim=0)

        return BatchFeature(
            data={
                "pixel_values": image_batch,
                "image_sizes": shapes,
                "num_img_tokens": num_img_tokens,
            },
            tensor_type=return_tensors,
        )

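# A minimal usage sketch for the image processor on its own (the file name is
# illustrative):
#
#     proc = SpecVisionImageProcessor(num_crops=16)
#     feats = proc.preprocess(Image.open("photo.jpg"), return_tensors="pt")
#     feats["pixel_values"].shape  # (1, 17, 3, 336, 336): local crops, then the
#                                  # global view, zero-padded to num_crops + 1
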
class SpecVisionProcessor(ProcessorMixin):
    """
    Combined processor for the Spec-Vision model, handling both image and text inputs.

    Combines SpecVisionImageProcessor for images and a tokenizer for text,
    coordinating their interaction for multi-modal inputs.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SpecVisionImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    special_image_token = "<|image|>"

    def __init__(self, image_processor, tokenizer):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # Fallback number of placeholder tokens per image, used only when the
        # image features do not carry an explicit num_img_tokens entry
        self.num_img_tokens = image_processor.num_crops
        self.img_tokens = [f"<|image_{i+1}|>" for i in range(1000000)]

    def __call__(
        self,
        text: Union[TextInput, List[TextInput]],
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """Process both text and image inputs for the model."""
        if images is not None:
            image_features = self.image_processor(images, return_tensors=return_tensors)
        else:
            image_features = {}

        # Merge the tokenized text with the image features
        inputs = self._process_multimodal_inputs(
            image_features,
            text,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
        )

        return inputs

    def _process_multimodal_inputs(self, images, texts, **kwargs):
        """Process and combine image and text inputs."""
        # Text-only path: defer entirely to the tokenizer
        if not images:
            return BatchFeature(data=self.tokenizer(
                texts,
                return_tensors=kwargs.get('return_tensors'),
                padding=kwargs.get('padding'),
                truncation=kwargs.get('truncation'),
                max_length=kwargs.get('max_length'),
            ))

        # Split the prompt on <|image_N|> tags and tokenize each text chunk
        pattern = r"<\|image_\d+\|>"
        text_chunks = [
            self.tokenizer(chunk).input_ids
            for chunk in re.split(pattern, texts)
        ]

        # Number of placeholder tokens to reserve for each image
        num_img_tokens = (
            images['num_img_tokens']
            if 'num_img_tokens' in images
            else [self.num_img_tokens] * len(images['pixel_values'])
        )

        image_tags = re.findall(pattern, texts)
        image_ids = [int(tag.split("|")[1].split("_")[-1]) for tag in image_tags]

        # Validate image IDs
        unique_ids = sorted(set(image_ids))
        if unique_ids != list(range(1, len(unique_ids) + 1)):
            raise ValueError(
                f"Image IDs must be consecutive integers starting from 1, got {unique_ids}"
            )
        if len(unique_ids) != len(images['pixel_values']):
            raise ValueError(
                f"Number of image tags ({len(unique_ids)}) doesn't match "
                f"number of images ({len(images['pixel_values'])})"
            )

        # Each image becomes a run of negative IDs (-N for image N), which the
        # model later replaces with image embeddings
        image_ids_padded = [
            [-iid] * num_img_tokens[iid - 1]
            for iid in image_ids
        ]

        # Interleave text chunks with image placeholder runs
        input_ids = []
        for x in self._interleave_sequences(text_chunks, image_ids_padded):
            input_ids.extend(x)

        input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
        attention_mask = (input_ids > -1000000).to(torch.long)

        return BatchFeature(data={
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": images['pixel_values'],
            "image_sizes": images['image_sizes'],
        })

    def _interleave_sequences(self, seq1, seq2):
        """Interleave two sequences, padding second sequence if needed."""
        if len(seq1) > len(seq2):
            seq2.append([])
        return [item for pair in zip(seq1, seq2) for item in pair]

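    # Example: for text chunks [T0, T1, T2] split around two image tags,
    # _interleave_sequences([T0, T1, T2], [img1, img2]) pads the second
    # sequence with an empty list and returns [T0, img1, T1, img2, T2, []].
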
    def batch_decode(self, *args, **kwargs):
        """Decode a batch of token IDs to text."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode token IDs to text."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Get combined input names from both processors, de-duplicated in order."""
        return list(dict.fromkeys(
            self.tokenizer.model_input_names +
            self.image_processor.model_input_names
        ))

# Register the processor with AutoImageProcessor
AutoImageProcessor.register("SpecVisionImageProcessor", SpecVisionImageProcessor)
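
# ---------------------------------------------------------------------------
# Minimal end-to-end sketch. The checkpoint name is a placeholder, not a
# confirmed repository id; substitute the actual Spec-Vision repo that ships
# this file together with a Llama tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("SVECTOR-OFFICIAL/Spec-Vision")  # placeholder repo id
    image_processor = SpecVisionImageProcessor(num_crops=16)
    processor = SpecVisionProcessor(image_processor, tokenizer)

    image = Image.new("RGB", (1024, 768), "white")  # stand-in for a real photo
    prompt = "<|image_1|>\nDescribe this image."

    inputs = processor(text=prompt, images=image)
    print(inputs["input_ids"].shape)     # (1, sequence_length), image run encoded as negative IDs
    print(inputs["pixel_values"].shape)  # (1, 17, 3, 336, 336) with num_crops=16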