from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from transformers import Idefics3ForConditionalGeneration, Idefics3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.idefics3.modeling_idefics3 import IDEFICS3_INPUTS_DOCSTRING, Idefics3BaseModelOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward, logging

logger = logging.get_logger(__name__)


class SmolVLMModel(Idefics3Model):
    """
    A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
    in forward. Instead, we override inputs_merger here with custom logic.
    """

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        image_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Merge text embeddings with image embeddings out-of-place (no in-place indexing).

        Expected shapes:
          - input_ids:           (B, T)
          - inputs_embeds:       (B, T, D)
          - image_hidden_states: (N, S, D), where N is the total number of images across
            the batch, S is the number of patches (or slots) per image, and D is the
            embedding dim.

        Logic:
          1) For each sample in the batch, find the <image> tokens in the text.
          2) If there are zero <image> tokens, the sample is text-only. Concatenate a
             zero-length slice of image_hidden_states but do NOT advance the offset.
             This keeps the model's image encoder in the computation graph while
             skipping "consuming" an image block for a text-only sample.
          3) Otherwise, the <image> tokens come in multiples of S (each image contributes
             S embeddings). We chunk their positions into groups of S. Each chunk consumes
             one block image_hidden_states[offset] of shape (S, D), and each row of that
             block is placed into the text in place of one <image> token.

        Returns:
            A tensor of shape (B, T, D).

        (See the illustrative sketch after this class for a toy walk-through of the
        splicing bookkeeping.)
        """
        ##############################################
        # 1) Basic shape checks
        ##############################################
        # old_merger_outputs = self.inputs_merger_old(input_ids, inputs_embeds, image_hidden_states)
        B, T, D_text = inputs_embeds.shape
        N, S, D_img = image_hidden_states.shape
        if D_text != D_img:
            raise ValueError(
                f"Text embedding dim {D_text} != image embedding dim {D_img}"
            )

        ##############################################
        # 2) We'll track how many images we've used so far across the entire batch
        ##############################################
        image_offset = 0
        # We'll store one merged tensor per batch sample
        merged_outputs: List[torch.Tensor] = []

        ##############################################
        # 3) Iterate through each sample
        ##############################################
        for b_idx, (cur_ids, cur_embeds) in enumerate(zip(input_ids, inputs_embeds)):
            # Find positions of <image> tokens in the text
            image_positions = (cur_ids == self.image_token_id).nonzero(as_tuple=True)[0]
            num_image_tokens = len(image_positions)

            # If there is no <image> token => text-only sample
            if num_image_tokens == 0:
                # We do not consume any row from image_hidden_states,
                # but we do take a zero-length slice so the image encoder stays in the graph.
                empty_slice = image_hidden_states[0][:0, :]  # shape (0, D)
                # Concatenate the text embeddings with that empty slice.
                # NOTE: this is important for DeepSpeed.
                merged_text_only = torch.cat([cur_embeds, empty_slice], dim=0)
                merged_outputs.append(merged_text_only)
                continue

            # Otherwise, we have at least one <image> token.
            # Since each image is S embeddings, the total number of <image> tokens in this
            # sample must be a multiple of S => each group of S tokens corresponds to 1 image.
            if num_image_tokens % S != 0:
                raise ValueError(
                    f"Sample {b_idx} has {num_image_tokens} <image> tokens, not a multiple of S={S}. "
                    "Cannot map them to blocks of shape (S, D)."
                )

            # Chunk image_positions into groups of size S.
            # Example: if num_image_tokens=162 and S=81 => 2 images => 2 chunks, each of length 81.
            positions_list = image_positions.tolist()
            chunks = [
                positions_list[i : i + S]
                for i in range(0, num_image_tokens, S)
            ]

            # Build a list of segments: text, then image row(s), text, etc.
            segments = []
            text_start = 0

            # Each chunk corresponds to one image
            for chunk in chunks:
                # image_hidden_states[image_offset] => shape (S, D)
                cur_block = image_hidden_states[image_offset]
                image_offset += 1
                # Iterate over the S positions in ascending order
                for i_s, pos in enumerate(chunk):
                    # Add the text from [text_start..pos)
                    if pos > text_start:
                        segments.append(cur_embeds[text_start:pos])
                    # Then add one row from cur_block => shape (1, D)
                    row_of_block = cur_block[i_s : i_s + 1, :]
                    segments.append(row_of_block)
                    # Skip the <image> token itself
                    text_start = pos + 1

            # Leftover text after the final <image> token
            if text_start < T:
                segments.append(cur_embeds[text_start:])

            # Concatenate into a single (T, D) tensor for this sample
            # (each <image> token is replaced by exactly one image row, so the length stays T)
            merged_sample = torch.cat(segments, dim=0)
            merged_outputs.append(merged_sample)

        merged_outputs = torch.stack(merged_outputs)
        # assert (old_merger_outputs == merged_outputs).all()
        return merged_outputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # Retrieve input_ids and inputs_embeds
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_seen_tokens = 0
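        # Explanatory note (added comment): with use_cache, we lazily create a DynamicCache
        # and read how many tokens it has already seen. past_seen_tokens == 0 therefore
        # means "prefill" (first forward pass), which is when image merging happens below.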
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            past_seen_tokens = past_key_values.get_seq_length()

        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
            raise ValueError("When first calling the model, if inputs_embeds are passed, input_ids should not be None.")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
            if not any(real_images_inds):
                # no images, leave one empty image.
                real_images_inds[0] = True
            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            patch_size = self.config.vision_config.patch_size
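            # Explanatory note (added comment): `unfold` below tiles the pixel-level mask
            # into non-overlapping patch_size x patch_size windows along height and width;
            # a patch is considered attended-to if at least one of its pixels is unmasked,
            # yielding one boolean per vision-transformer patch.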
            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state

            # Modality projection & resampling
            image_hidden_states = self.connector(image_hidden_states)

        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
            # Merge only on the first forward pass (prefill). During generation we do not
            # want to replace an image_token_id that the model happened to generate with
            # image features that simply don't exist.
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return Idefics3BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )
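

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original model code): a minimal,
# self-contained walk-through of the splicing bookkeeping performed by
# SmolVLMModel.inputs_merger, on toy tensors. The function name
# `_toy_inputs_merger_demo` and all toy values are assumptions made purely
# for illustration.
# ---------------------------------------------------------------------------
def _toy_inputs_merger_demo() -> None:
    torch.manual_seed(0)
    B, T, D, S, image_token_id = 1, 6, 4, 2, 99

    # One sample whose S=2 <image> tokens sit at positions 1 and 2.
    input_ids = torch.tensor([[5, image_token_id, image_token_id, 7, 8, 9]])
    inputs_embeds = torch.randn(B, T, D)
    image_hidden_states = torch.randn(1, S, D)  # one image block of shape (S, D)

    merged, offset = [], 0
    for ids, embeds in zip(input_ids, inputs_embeds):
        positions = (ids == image_token_id).nonzero(as_tuple=True)[0].tolist()
        segments, text_start = [], 0
        # Each group of S <image> positions consumes one (S, D) block.
        for start in range(0, len(positions), S):
            block = image_hidden_states[offset]
            offset += 1
            for i_s, pos in enumerate(positions[start : start + S]):
                if pos > text_start:
                    segments.append(embeds[text_start:pos])
                segments.append(block[i_s : i_s + 1])
                text_start = pos + 1
        if text_start < T:
            segments.append(embeds[text_start:])
        merged.append(torch.cat(segments, dim=0))
    merged = torch.stack(merged)

    # Each <image> token was replaced by exactly one image row, so the length is
    # unchanged and the image rows now sit where the <image> tokens used to be.
    assert merged.shape == (B, T, D)
    assert torch.equal(merged[0, 1], image_hidden_states[0, 0])
    assert torch.equal(merged[0, 2], image_hidden_states[0, 1])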


class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
    """
    A subclass of Idefics3ForConditionalGeneration that uses SmolVLMModel
    instead of the default Idefics3Model.
    """

    def __init__(self, config):
        super().__init__(config)
        # Instead of the original self.model = Idefics3Model(config),
        # point to our custom class.
        self.model = SmolVLMModel(config)
        # Re-create the lm_head with the same shape as the parent's
        # (you could also keep the parent's instance as-is).
        self.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )
        # Run any post_init() logic set up by the parent (weight init, tying, etc.)
        self.post_init()