# Copyright 2025 Bytedance Ltd. and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import os, sys, importlib from typing import Optional, Dict, List import torch from functools import partial from diffusers import DiffusionPipeline from diffusers.utils import logging from accelerate import ( init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch, ) from huggingface_hub import snapshot_download from tqdm import tqdm from copy import deepcopy import random import cv2 import numpy as np from torchvision import transforms from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode from dataclasses import dataclass from types import SimpleNamespace from einops import rearrange from torch import Tensor, nn from safetensors.torch import load_file as load_sft import copy from typing import List, Tuple, Optional import torch.nn.functional as F from torch import nn from torch.nn.attention.flex_attention import create_block_mask from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel from dataclasses import asdict, fields from diffusers.models.modeling_utils import ModelMixin from diffusers.configuration_utils import ConfigMixin import math from transformers.activations import ACT2FN from torch import nn from torch.nn.attention import SDPBackend, sdpa_kernel from torch.nn.attention.flex_attention import flex_attention from torch.nn.functional import scaled_dot_product_attention from transformers.utils import ModelOutput from flash_attn import flash_attn_varlen_func torch._dynamo.config.cache_size_limit = 512 torch._dynamo.config.accumulated_cache_size_limit = 4096 # flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune' flex_attention = torch.compile(flex_attention) from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import 
rope_config_validation from transformers.utils import logging from typing import List, Optional, Tuple, Union import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from typing import Optional, Tuple from transformers.tokenization_utils import AddedToken from transformers.tokenization_utils_fast import PreTrainedTokenizerFast import json import unicodedata from functools import lru_cache import regex as re from transformers.tokenization_utils import PreTrainedTokenizer from typing import Union from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging import string import warnings from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import sentencepiece as spm from transformers.convert_slow_tokenizer import import_protobuf from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils_base import AddedToken if TYPE_CHECKING: from transformers.tokenization_utils_base import TextInput from transformers.utils import logging, requires_backends VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} SPIECE_UNDERLINE = "▁" from typing import Dict, List, Optional, Union from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from transformers.image_transforms import ( convert_to_rgb, resize, to_channel_dimension_format, ) from transformers.image_utils import ( IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, 
ChannelDimension, ImageInput, PILImageResampling, infer_channel_dimension_format, is_scaled_image, make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, ) from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging logger = logging.get_logger(__name__) import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple, Union from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn.init import _calculate_fan_in_and_fan_out from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput from transformers.utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, torch_int, ) from typing import List, Optional, Union from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy from transformers.utils import TensorType from PIL import Image from torch.nn.attention.flex_attention import or_masks, and_masks def create_sparse_mask(document_lens, split_lens, attn_modes, device): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def full_and_noise_mask(b, h, q_idx, kv_idx): return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) def remove_noise_mask(b, h, q_idx, kv_idx): return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))) def sample_mask(b, h, q_idx, kv_idx): return document_id[q_idx] == document_id[kv_idx] full_and_noise_tmp = [] noise_tmp = [] for i, (length, model) in 
def patchify(image, patch_size):
    """Cut a ``(c, h, w)`` image into flattened square patches.

    Patches are emitted in row-major grid order. Inside one patch the values
    are ordered pixel row, then pixel column, then channel.

    Args:
        image: Tensor of shape ``(c, h, w)``; ``h`` and ``w`` must be
            multiples of ``patch_size``.
        patch_size: Side length of a patch in pixels.

    Returns:
        Tensor of shape ``(num_patches, patch_size**2 * c)``.
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0
    grid_h = height // patch_size
    grid_w = width // patch_size
    tiles = image.reshape(channels, grid_h, patch_size, grid_w, patch_size)
    # (c, grid_h, ph, grid_w, pw) -> (grid_h, grid_w, ph, pw, c)
    tiles = tiles.permute(1, 3, 2, 4, 0)
    return tiles.reshape(-1, patch_size**2 * channels)


def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids on a fixed-width position grid.

    The patch at grid cell ``(row, col)`` gets the id
    ``row * max_num_patches_per_side + col``.

    Returns:
        1-D integer tensor with one id per patch.
    """
    rows = torch.arange(0, img_h // patch_size)
    cols = torch.arange(0, img_w // patch_size)
    grid_ids = rows.unsqueeze(1) * max_num_patches_per_side + cols
    return grid_ids.flatten()


def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return position ids obtained by bucketing fractional patch coordinates.

    Each axis is normalised to ``[0, 1)`` and split into
    ``max_num_patches_per_side`` equal buckets; the bucket indices of a patch
    are then combined as ``row_bucket * max_num_patches_per_side + col_bucket``.

    Returns:
        1-D integer tensor with one id per patch, in row-major patch order.
    """
    bucket_width = 1 / max_num_patches_per_side
    bucket_edges = torch.arange(bucket_width, 1.0, bucket_width)

    def _axis_buckets(num_patches):
        # Left edge of every patch as a fraction of the image side.
        fractions = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractions, bucket_edges, right=True)

    row_buckets = _axis_buckets(img_h // patch_size)
    col_buckets = _axis_buckets(img_w // patch_size)
    grid_ids = row_buckets[:, None] * max_num_patches_per_side + col_buckets
    return grid_ids.flatten()
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """Build the dense additive attention mask for one packed sample.

    Args:
        split_lens: Length of each split of the sample, in order.
        attn_modes: Attention mode of each split; every entry must be
            ``'causal'``, ``'full'`` or ``'noise'``.
        device: Device on which the mask is created.

    Returns:
        Float tensor of shape ``(sum(split_lens), sum(split_lens))`` holding
        ``0.0`` where attention is allowed and ``-inf`` where it is blocked.

    Every split sees all tokens of the preceding splits. Inside a split,
    ``'causal'`` is lower-triangular while ``'full'`` and ``'noise'`` are
    bidirectional. Afterwards ``'noise'`` splits are hidden from every token
    outside themselves.
    """
    sample_len = sum(split_lens)
    allowed = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)

    start = 0
    for length, attn_mode in zip(split_lens, attn_modes):
        assert attn_mode in ['causal', 'full', 'noise']
        end = start + length
        # Everything before this split is visible regardless of the mode.
        allowed[start:end, :start] = True
        if attn_mode == "causal":
            allowed[start:end, start:end] = torch.ones((length, length), device=device).tril()
        else:
            # Scalar fill: no temporary tensor, so nothing lands on the wrong device.
            allowed[start:end, start:end] = True
        start = end

    start = 0
    for length, attn_mode in zip(split_lens, attn_modes):
        end = start + length
        if attn_mode == "noise":
            # Hide the noise columns from everyone, then re-open the split to itself.
            allowed[:, start:end] = False
            allowed[start:end, start:end] = True
        start = end

    return torch.zeros_like(allowed, dtype=torch.float).masked_fill_(~allowed, float("-inf"))


def split_integer_exp_decay(S, ng_sample_decay=1.0):
    """Randomly split the integer ``S`` into ``N`` positive consecutive parts.

    ``N`` is drawn uniformly from ``1..S`` when ``ng_sample_decay == 1.0``;
    otherwise ``N = k`` is drawn with probability proportional to
    ``ng_sample_decay ** (k - 1)``. The cut points are sampled without
    replacement from ``1..S-1``.

    Args:
        S: Integer to split.
        ng_sample_decay: Decay factor of the distribution over ``N``.

    Returns:
        Tuple ``(result, cumsum)``: the part sizes (summing to ``S``) and the
        boundaries ``[0, ..., S]`` with ``len(cumsum) == len(result) + 1``.
    """
    if ng_sample_decay == 1.0:
        # Handled separately: the geometric normaliser below would divide by zero.
        N = random.randint(1, S)
    else:
        base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
        weights = [base * math.pow(ng_sample_decay, i) for i in range(S)]
        N = random.choices(range(1, S + 1), weights, k=1)[0]
    cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
    result = [upper - lower for lower, upper in zip(cumsum, cumsum[1:])]
    return result, cumsum


def pil_img2rgb(image):
    """Convert a PIL image to RGB, compositing transparency onto white.

    Images in ``RGBA`` mode, or carrying a ``transparency`` entry in
    ``image.info``, are converted to RGBA and pasted onto a white canvas using
    their alpha channel as the mask. All other images are converted directly.

    Args:
        image: A PIL image.

    Returns:
        The RGB image.
    """
    has_alpha = image.mode == "RGBA" or image.info.get("transparency", None) is not None
    if not has_alpha:
        return image.convert("RGB")

    image = image.convert("RGBA")
    white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
    # Band 3 of an RGBA image is its alpha channel.
    white.paste(image, mask=image.split()[3])
    return white
def add_special_tokens(tokenizer):
    """Make sure the chat/vision marker tokens exist in ``tokenizer``.

    The markers ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and
    ``<|vision_end|>`` are added through ``tokenizer.add_tokens`` when they are
    not already listed in ``tokenizer.special_tokens_map``.

    Args:
        tokenizer: Object exposing ``special_tokens_map``, ``add_tokens`` and
            ``convert_tokens_to_ids``.

    Returns:
        Tuple ``(tokenizer, new_token_ids, num_new_tokens)`` where
        ``new_token_ids`` maps ``bos_token_id``, ``eos_token_id``,
        ``start_of_image`` and ``end_of_image`` to their ids and
        ``num_new_tokens`` is the value returned by ``add_tokens``.
    """
    known_tokens = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            known_tokens.append(entry)
        elif isinstance(entry, list):
            known_tokens.extend(entry)

    markers = ('<|im_start|>', '<|im_end|>', '<|vision_start|>', '<|vision_end|>')
    missing = [marker for marker in markers if marker not in known_tokens]
    num_new_tokens = tokenizer.add_tokens(missing)

    marker_ids = [tokenizer.convert_tokens_to_ids(marker) for marker in markers]
    id_names = ('bos_token_id', 'eos_token_id', 'start_of_image', 'end_of_image')
    new_token_ids = dict(zip(id_names, marker_ids))
    return tokenizer, new_token_ids, num_new_tokens


def len2weight(x, loss_reduction='square'):
    """Return the loss weight for a sequence of length ``x``.

    Args:
        x: Sequence length. A value equal to 0 is returned unchanged.
        loss_reduction: ``'token'`` -> ``1``, ``'sample'`` -> ``1 / x``,
            ``'square'`` -> ``1 / sqrt(x)``.

    Raises:
        NotImplementedError: For any other ``loss_reduction`` (only reached
            when ``x`` is non-zero).
    """
    if x == 0:
        return x
    if loss_reduction == 'token':
        weight = 1
    elif loss_reduction == 'sample':
        weight = 1 / x
    elif loss_reduction == 'square':
        weight = 1 / (x ** 0.5)
    else:
        raise NotImplementedError(loss_reduction)
    return weight


class NaiveCache:
    """Minimal per-layer key/value store.

    ``key_cache`` and ``value_cache`` are dicts indexed by layer number whose
    entries start out as ``None``.
    """

    def __init__(self, num_layers):
        self.key_cache = dict.fromkeys(range(num_layers))
        self.value_cache = dict.fromkeys(range(num_layers))

    @property
    def num_layers(self):
        """Number of layer slots in the cache."""
        return len(self.key_cache)

    @property
    def seq_lens(self):
        """First dimension of layer 0's cached keys, or 0 when nothing is cached."""
        first_layer_keys = self.key_cache[0]
        return 0 if first_layer_keys is None else first_layer_keys.shape[0]
class _Qwen2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed, `num_attention_heads` is used.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly. A legacy `type` key is copied to `rope_type` in place before validation.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. Stored as `None` whenever `use_sliding_window` is false.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Stored as `config.is_causal`; the attention modules copy it to their own `is_causal` attribute.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Name of the attention implementation, stored as `config._attn_implementation`.

    ```python
    >>> # `_Qwen2Config` is defined in this module; it is not exported by `transformers`.
    >>> from transformers import Qwen2Model

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = _Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        **kwargs,
    ):
        """Store the hyper-parameters, validate the RoPE settings and forward the rest to the base class."""
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # The window size is only kept when sliding-window attention is switched on.
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility: a missing KV-head count means plain multi-head attention
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.is_causal = is_causal
        # NOTE(review): assigned before `super().__init__()` runs — confirm the installed
        # `PretrainedConfig.__init__` does not reset the attention implementation afterwards.
        self._attn_implementation = _attn_implementation
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # (This mutates the caller's `rope_scaling` dict in place.)
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the last (normalised) dimension.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension.

        The statistics are computed in float32; the normalised values are cast
        back to the input dtype before being multiplied by ``weight``.
        """
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalised = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Produce the rotary position embedding ``cos``/``sin`` tables for given position ids.

    The inverse frequencies come from the ``ROPE_INIT_FUNCTIONS`` entry selected by the RoPE type and are kept in
    the non-persistent buffer ``inv_freq``. For RoPE types whose name contains ``"dynamic"`` they are recomputed
    in ``forward`` depending on the largest requested position.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[_Qwen2Config] = None,
    ):
        """
        Args:
            dim, max_position_embeddings, base, scaling_factor, rope_type: Legacy parameterisation, only used
                when `config` is `None`; they are forwarded to the RoPE init function as keyword arguments.
            device: Passed to the RoPE init function.
            config: Model configuration. When given, the RoPE type is read from `config.rope_scaling`
                (key `rope_type`, falling back to `type`) and the cached length from
                `config.max_position_embeddings`.
        """
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so that the dynamic update below can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        # Largest requested position + 1; this is a 0-dim tensor, not a Python int.
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to the dtype of ``x``.

        ``x`` is only consulted for its device and dtype.
        Assumes `position_ids` is a 2-D `(batch, seq_len)` tensor — TODO confirm against callers.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        # "mps" (and any non-string device type) is mapped to "cpu" for the autocast context.
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # The frequency block is duplicated to span the full last dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        """Create the three bias-free projections from ``config.hidden_size`` / ``config.intermediate_size``."""
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Apply the gated MLP to ``hidden_state``."""
        gate = self.act_fn(self.gate_proj(hidden_state))
        up = self.up_proj(hidden_state)
        return self.down_proj(gate * up)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: hand back the very same tensor.
        return hidden_states
    # Add a repeat axis right after the KV-head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Qwen2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    This eager implementation materialises the full attention matrix with `torch.matmul` and supports grouped-query
    attention by repeating the key/value heads `num_key_value_groups` times.
    """

    def __init__(self, config: _Qwen2Config, layer_idx: Optional[int] = None):
        """
        Args:
            config: Model configuration; `hidden_size`, `num_attention_heads`, `num_key_value_heads`,
                `max_position_embeddings`, `rope_theta`, `is_causal` and `attention_dropout` are read.
            layer_idx: Index of this layer, handed to `past_key_value.update` when a cache is used.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = config.is_causal
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Q/K/V projections carry a bias, the output projection does not.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run eager scaled dot-product attention.

        Args:
            hidden_states: Input of shape `(batch, q_len, hidden_size)`.
            attention_mask: Optional additive 4-D mask; its last dimension is sliced to the key length and
                the result is added to the attention scores.
            position_ids: Only used by the legacy path that computes RoPE inside the layer.
            past_key_value: Optional cache; updated with the new key/value states for `self.layer_idx`.
            output_attentions: Whether to return the attention probabilities.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to the cache update.
            position_embeddings: `(cos, sin)` tuple applied to the query and key states.

        Returns:
            `(attn_output, attn_weights, past_key_value)`; `attn_weights` is `None` unless
            `output_attentions` is true.

        Raises:
            ValueError: If `position_embeddings` is missing and the module has no `rotary_emb`, or if the
                attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            # `__init__` does not create a rotary module, so the legacy fallback only works when one
            # has been attached from outside. Fail with an actionable message instead of an
            # AttributeError raised from `nn.Module.__getattr__`.
            rotary_emb = getattr(self, "rotary_emb", None)
            if rotary_emb is None:
                raise ValueError(
                    f"{self.__class__.__name__} requires `position_embeddings` (a tuple of cos and sin tensors): "
                    "the module has no `rotary_emb` to compute them from `position_ids`."
                )
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class Qwen2FlashAttention2(Qwen2Attention):
    """
    Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
    as the weights of the module stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        """Build the parent attention module and record the causal-mask alignment of the installed flash-attn."""
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ):
        """Run attention through `_flash_attention_forward`.

        Args:
            hidden_states: Input of shape `(batch, q_len, hidden_size)`.
            attention_mask: Forwarded unchanged to `_flash_attention_forward`.
            position_ids: Forwarded to `_flash_attention_forward`; also used by the legacy path that
                computes RoPE inside the layer.
            past_key_value: Optional cache; updated with the new key/value states for `self.layer_idx`.
            output_attentions: Accepted for interface compatibility. This path never materialises the
                attention probabilities, so `None` is returned for them regardless.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to the cache update.
            position_embeddings: `(cos, sin)` tuple applied to the query and key states.

        Returns:
            `(attn_output, None, past_key_value)`.

        Raises:
            ValueError: If `position_embeddings` is missing and the module has no `rotary_emb`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            # No rotary module is created in `__init__`; the legacy fallback only works when one has
            # been attached from outside. Fail with an actionable message instead of an
            # AttributeError raised from `nn.Module.__getattr__`.
            rotary_emb = getattr(self, "rotary_emb", None)
            if rotary_emb is None:
                raise ValueError(
                    f"{self.__class__.__name__} requires `position_embeddings` (a tuple of cos and sin tensors): "
                    "the module has no `rotary_emb` to compute them from `position_ids`."
                )
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention: (batch, seq, heads, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # NOTE(review): the class docstring says SWA is applied to the *bottom* layers, but this
        # condition enables it for `layer_idx >= max_window_layers` — confirm which is intended.
        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # No attention probabilities exist on this path. Binding the name unconditionally
        # prevents an UnboundLocalError when the caller passes `output_attentions=True`.
        attn_weights = None

        return attn_output, attn_weights, past_key_value
) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) # Reashape to the expected shape for Flash Attention query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) if ( self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers ): sliding_window = self.config.sliding_window else: sliding_window = None attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, position_ids=position_ids, dropout=dropout_rate, sliding_window=sliding_window, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value QWEN2_ATTENTION_CLASSES = { "eager": Qwen2Attention, "flash_attention_2": Qwen2FlashAttention2, } QWEN2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`_Qwen2Config`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" @add_start_docstrings( "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", QWEN2_START_DOCSTRING, ) class Qwen2PreTrainedModel(PreTrainedModel): config_class = _Qwen2Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() QWEN2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
@lru_cache()
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a distinct single-character string.

    Byte values in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character with the same
    code point. Every other byte (whitespace and control characters, which the BPE code cannot
    handle) maps to `chr(256 + n)`, where `n` counts those remaining bytes in increasing order.

    The reversible BPE codes work on unicode strings, so covering arbitrary text without UNKs would
    otherwise require a large share of the vocabulary; this byte <-> character table avoids that.

    The result is cached by `lru_cache`, so every call returns the same dict object.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    # Set lookup instead of scanning the (growing) list for each of the 256 byte values.
    self_mapped = set(printable)
    shift = 0
    for byte in range(2**8):
        if byte not in self_mapped:
            byte_values.append(byte)
            code_points.append(2**8 + shift)
            shift += 1
    return dict(zip(byte_values, (chr(point) for point in code_points)))


# Adapted from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Word is represented as a sequence of symbols (symbols being variable-length strings); it must
    support slicing. A word with fewer than two symbols -- including an empty one, which previously
    raised `IndexError` -- yields an empty set.
    """
    return set(zip(word, word[1:]))
bos_token (`str`, *optional*): The beginning of sequence token. Not applicable for this tokenizer. eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): The end of sequence token. pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): The token used for padding, for example when batching sequences of different lengths. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): Whether or not the model should cleanup the spaces that were added when splitting the input text during the tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces. split_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not the special tokens should be split during the tokenization process. The default behavior is to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") = ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<', '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment. 
""" vocab_files_names = VOCAB_FILES_NAMES model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file, merges_file, errors="replace", unk_token="<|endoftext|>", bos_token=None, eos_token="<|endoftext|>", pad_token="<|endoftext|>", clean_up_tokenization_spaces=False, split_special_tokens=False, **kwargs, ): # Qwen vocab does not contain control tokens; added tokens need to be special bos_token = ( AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) if isinstance(bos_token, str) else bos_token ) eos_token = ( AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) if isinstance(eos_token, str) else eos_token ) unk_token = ( AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) if isinstance(unk_token, str) else unk_token ) pad_token = ( AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) if isinstance(pad_token, str) else pad_token ) with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} bpe_merges = [] with open(merges_file, encoding="utf-8") as merges_handle: for i, line in enumerate(merges_handle): line = line.strip() if (i == 0 and line.startswith("#version:")) or not line: continue bpe_merges.append(tuple(line.split())) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) # NOTE: the cache can grow without bound and will get really large for long running processes # (esp. for texts of language that do not use space between word, e.g. Chinese); technically # not a memory leak but appears as one. # GPT2Tokenizer has the same problem, so let's be consistent. 
self.cache = {} self.pat = re.compile(PRETOKENIZE_REGEX) if kwargs.get("add_prefix_space", False): logger.warning_once( f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect." ) super().__init__( errors=errors, bos_token=bos_token, eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, clean_up_tokenization_spaces=clean_up_tokenization_spaces, split_special_tokens=split_special_tokens, **kwargs, ) @property def vocab_size(self) -> int: return len(self.encoder) # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab def get_vocab(self): return dict(self.encoder, **self.added_tokens_encoder) # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token) pairs = get_pairs(word) if not pairs: return token while True: bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) except ValueError: new_word.extend(word[i:]) break else: new_word.extend(word[i:j]) i = j if word[i] == first and i < len(word) - 1 and word[i + 1] == second: new_word.append(first + second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = " ".join(word) self.cache[token] = word return word # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize def _tokenize(self, text): """Tokenize a string.""" bpe_tokens = [] for token in re.findall(self.pat, text): token = "".join( self.byte_encoder[b] for b in token.encode("utf-8") ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) return bpe_tokens # Copied from 
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the vocabulary JSON and the BPE merges file into `save_directory`.

    Args:
        save_directory: Existing directory to write into. If it is not a directory, an error is
            logged and the method returns `None` without writing anything.
        filename_prefix: Optional prefix; when truthy, `"<prefix>-"` is prepended to both file names.

    Returns:
        `(vocab_file, merge_file)` -- the paths of the two files written.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return

    # NOTE(review): this reads the module-level `VOCAB_FILES_NAMES` at call time and needs its
    # "merges_file" key; the name is also assigned near the top of this file -- confirm that no
    # later reassignment replaces the mapping this method expects.
    name_prefix = filename_prefix + "-" if filename_prefix else ""
    vocab_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["vocab_file"])
    merge_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["merges_file"])

    with open(vocab_file, "w", encoding="utf-8") as vocab_out:
        vocab_out.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

    # Merges are written in rank order; a gap in the ranks only triggers a warning.
    expected_rank = 0
    with open(merge_file, "w", encoding="utf-8") as merges_out:
        merges_out.write("#version: 0.2\n")
        for merge_symbols, rank in sorted(self.bpe_ranks.items(), key=lambda item: item[1]):
            if rank != expected_rank:
                logger.warning(
                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!"
                )
            merges_out.write(" ".join(merge_symbols) + "\n")
            expected_rank = rank + 1

    return vocab_file, merge_file
num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. max_position_embeddings (`int`, *optional*, defaults to 64): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. pad_token_id (`int`, *optional*, defaults to 1): The id of the padding token in the vocabulary. bos_token_id (`int`, *optional*, defaults to 49406): The id of the beginning-of-sequence token in the vocabulary. eos_token_id (`int`, *optional*, defaults to 49407): The id of the end-of-sequence token in the vocabulary. 
Example: ```python >>> from transformers import SiglipTextConfig, SiglipTextModel >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration >>> configuration = SiglipTextConfig() >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration >>> model = SiglipTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "siglip_text_model" def __init__( self, vocab_size=32000, hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, max_position_embeddings=64, hidden_act="gelu_pytorch_tanh", layer_norm_eps=1e-6, attention_dropout=0.0, # This differs from `CLIPTokenizer`'s default and from openai/siglip # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 pad_token_id=1, bos_token_id=49406, eos_token_id=49407, **kwargs, ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.max_position_embeddings = max_position_embeddings self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.attention_dropout = attention_dropout @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the text config dict if we are loading from SiglipConfig if config_dict.get("model_type") == "siglip": config_dict = config_dict["text_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to 
class _SiglipVisionConfig(PretrainedConfig):
    r"""
    Configuration of a Siglip vision encoder.

    Stores the hyper-parameters that define the vision model architecture. Configuration objects
    inherit from [`PretrainedConfig`]; read the documentation of [`PretrainedConfig`] for the
    options handled by the base class.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> # Initializing a configuration with the default values
    >>> configuration = _SiglipVisionConfig()

    >>> # Reading a stored hyper-parameter
    >>> configuration.patch_size
    16
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size

        # Regularization / normalization / activation.
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """
        Build a vision config from a checkpoint name or path.

        When the loaded dict describes a composite `"siglip"` model, only its `"vision_config"`
        entry is used. A warning is logged if the resulting dict declares a `model_type` different
        from this class's.
        """
        cls._set_token_in_kwargs(kwargs)

        loaded, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from SiglipConfig
        if loaded.get("model_type") == "siglip":
            loaded = loaded["vision_config"]

        if "model_type" in loaded and hasattr(cls, "model_type"):
            found_type = loaded["model_type"]
            if found_type != cls.model_type:
                logger.warning(
                    f"You are using a model of type {found_type} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(loaded, **kwargs)
Initializing the `SiglipTextConfig` with default values.") if vision_config is None: vision_config = {} logger.info("`vision_config` is `None`. initializing the `_SiglipVisionConfig` with default values.") self.text_config = SiglipTextConfig(**text_config) self.vision_config = _SiglipVisionConfig(**vision_config) self.initializer_factor = 1.0 @classmethod def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: _SiglipVisionConfig, **kwargs): r""" Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision model configuration. Returns: [`SiglipConfig`]: An instance of a configuration object """ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) if is_vision_available(): import PIL class SiglipImageProcessor(BaseImageProcessor): r""" Constructs a SigLIP image processor. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by `do_resize` in the `preprocess` method. size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in the `preprocess` method. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` method. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image by the specified mean and standard deviation. 
Can be overridden by `do_normalize` in the `preprocess` method. image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. """ model_input_names = ["pixel_values"] def __init__( self, do_resize: bool = True, size: Dict[str, int] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = None, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"height": 224, "width": 224} image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.do_resize = do_resize self.size = size self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize self.image_mean = image_mean self.image_std = image_std self.do_convert_rgb = do_convert_rgb @filter_out_non_signature_kwargs() def preprocess( self, images: ImageInput, do_resize: bool = None, size: Dict[str, int] = None, resample: PILImageResampling = None, do_rescale: bool = None, rescale_factor: float = None, do_normalize: bool = None, 
@filter_out_non_signature_kwargs()
def preprocess(
    self,
    images: ImageInput,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional[PILImageResampling] = None,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    do_convert_rgb: Optional[bool] = None,
) -> BatchFeature:
    """
    Preprocess an image or batch of images.

    The steps run in a fixed order: optional RGB conversion, optional resize, optional rescale, optional
    normalization, and finally conversion to the requested channel layout.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Size of the image after resizing.
        resample (`int`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
            - Unset: Return a list of `np.ndarray`.
            - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
            - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
            - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - Unset: Use the channel dimension format of the input image.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.

    Returns:
        `BatchFeature`: a batch whose `"pixel_values"` entry holds the processed images, converted to
        `return_tensors` when that is given.

    Raises:
        ValueError: If `images` fails the `valid_images` check.
    """
    # Any argument left as None falls back to the value stored on the processor.
    do_resize = do_resize if do_resize is not None else self.do_resize
    size = size if size is not None else self.size
    size = get_size_dict(size, param_name="size", default_to_square=False)
    resample = resample if resample is not None else self.resample
    do_rescale = do_rescale if do_rescale is not None else self.do_rescale
    rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
    do_normalize = do_normalize if do_normalize is not None else self.do_normalize
    image_mean = image_mean if image_mean is not None else self.image_mean
    image_std = image_std if image_std is not None else self.image_std
    do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

    # Accept a single image or a batch; from here on `images` is always a list.
    images = make_list_of_images(images)

    if not valid_images(images):
        raise ValueError(
            "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
            "torch.Tensor, tf.Tensor or jax.ndarray."
        )
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_resize=do_resize,
        size=size,
        resample=resample,
    )
    # All transformations expect numpy arrays.
    images = [to_numpy_array(image) for image in images]

    if do_convert_rgb:
        images = [convert_to_rgb(image) for image in images]

    # Only the first image is inspected to decide whether to warn about double rescaling.
    if is_scaled_image(images[0]) and do_rescale:
        logger.warning_once(
            "It looks like you are trying to rescale already rescaled images. If the input"
            " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
        )

    if input_data_format is None:
        # We assume that all images have the same channel dimension format.
        input_data_format = infer_channel_dimension_format(images[0])

    if do_resize:
        height, width = size["height"], size["width"]
        images = [
            resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
            for image in images
        ]

    if do_rescale:
        images = [
            self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
            for image in images
        ]

    if do_normalize:
        images = [
            self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
            for image in images
        ]

    # Last step: move the channel axis to the requested output layout.
    images = [
        to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
    ]

    data = {"pixel_values": images}
    return BatchFeature(data=data, tensor_type=return_tensors)
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place with samples from a normal distribution truncated to `[a, b]`.

    Uniform samples are pushed through the inverse normal CDF, then scaled by
    `std`, shifted by `mean` and clamped to `[a, b]`.

    Args:
        tensor: Tensor to overwrite in place.
        mean: Mean of the untruncated normal distribution.
        std: Standard deviation of the untruncated normal distribution.
        a: Lower cutoff.
        b: Upper cutoff.

    Returns:
        None. The result is written into `tensor`.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    lower = norm_cdf((a - mean) / std)
    upper = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [lower, upper], then translate to
    # [2 * lower - 1, 2 * upper - 1].
    tensor.uniform_(2 * lower - 1, 2 * upper - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.
    The filled values therefore lie in `[mean + a * std, mean + b * std]`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same `tensor` object, filled in place.
    """
    # Run under no_grad so that parameters requiring grad can be overwritten in place.
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    # Return the tensor so the function honours its `-> torch.Tensor` annotation.
    return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Fill `tensor` in place with values whose variance is `scale / denom`.

    Args:
        tensor: Tensor to initialise; its fan-in and fan-out come from
            `_calculate_fan_in_and_fan_out`.
        scale: Numerator of the target variance.
        mode: Selects the denominator: `"fan_in"`, `"fan_out"` or `"fan_avg"`
            (the mean of the two).
        distribution: `"truncated_normal"`, `"normal"` or `"uniform"`.

    Raises:
        ValueError: If `mode` or `distribution` is not one of the supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and failed later with an
        # UnboundLocalError on `denom`; fail clearly, like the distribution check below.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform distribution on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    """Initialise `tensor` in place with fan-in scaled truncated-normal values."""
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    """Initialise `tensor` in place with fan-in scaled normal values."""
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to None, so all of them are annotated as Optional.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to None, so all of them are annotated as Optional.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """

    # Every field defaults to None, so all of them are annotated as Optional.
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: Optional[BaseModelOutputWithPooling] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None

    def to_tuple(self) -> Tuple[Any]:
        """Return the stored values as a tuple, in key order.

        The two nested model outputs are themselves converted with their own
        `to_tuple()`; every other entry is returned as stored.
        """
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
class SiglipTextEmbeddings(nn.Module):
    """Sum of a token embedding and a learned position embedding."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        hidden_dim = config.hidden_size
        max_positions = config.max_position_embeddings

        self.token_embedding = nn.Embedding(config.vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_positions, hidden_dim)

        # Default position indices 0..max_positions-1 with a leading axis of size 1.
        # Registered as a non-persistent buffer, so it is left out of the state dict.
        default_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", default_positions, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Embed the tokens (or take `inputs_embeds`) and add position embeddings.

        Args:
            input_ids: Token indices; looked up in `token_embedding` when
                `inputs_embeds` is not given.
            position_ids: Position indices; defaults to the first `seq_length`
                entries of the `position_ids` buffer.
            inputs_embeds: Precomputed token embeddings used instead of `input_ids`.

        Returns:
            `inputs_embeds + position_embedding(position_ids)`.
        """
        # The sequence length is the last axis of the ids, or the second-to-last of the embeddings.
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        # The hidden size must split evenly across the heads.
        self.head_dim, leftover = divmod(self.embed_dim, self.num_heads)
        if leftover != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Four square projections, created in the same order as before (k, v, q, out)
        # so parameter registration and random initialisation are unchanged.
        for proj_name in ("k_proj", "v_proj", "q_proj", "out_proj"):
            setattr(self, proj_name, nn.Linear(self.embed_dim, self.embed_dim))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Input of shape `(batch_size, q_len, embed_dim)`.
            attention_mask: Optional mask of shape `(batch_size, 1, q_len, q_len)` that is
                added to the scaled attention scores before the softmax.
            output_attentions: Accepted for interface compatibility; the attention
                probabilities are returned regardless of its value.

        Returns:
            `(attn_output, attn_weights)`: the projected output of shape
            `(batch_size, q_len, embed_dim)` and the post-dropout attention probabilities
            of shape `(batch_size, num_heads, q_len, q_len)`.

        Raises:
            ValueError: If the scores, the mask or the attention output have an unexpected size.
        """
        batch_size, q_len, _ = hidden_states.size()

        def split_heads(projected):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return projected.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        k_v_seq_len = key_states.shape[-2]
        scores_shape = (batch_size, self.num_heads, q_len, k_v_seq_len)
        mask_shape = (batch_size, 1, q_len, k_v_seq_len)
        context_shape = (batch_size, self.num_heads, q_len, self.head_dim)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        if attn_weights.size() != scores_shape:
            raise ValueError(
                f"Attention weights should be of size {scores_shape}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {mask_shape}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != context_shape:
            raise ValueError(
                f"`attn_output` should be of size {context_shape}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed), then the output projection.
        merged = attn_output.transpose(1, 2).contiguous().reshape(batch_size, q_len, self.embed_dim)
        return self.out_proj(merged), attn_weights
class SiglipSdpaAttention(SiglipAttention):
    """
    Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    is_causal = False

    # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention through `scaled_dot_product_attention`.

        When `output_attentions` is truthy a warning is logged once and the call is
        delegated to the parent implementation. Otherwise the attention probabilities
        are not available and `None` is returned in their place.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        batch_size, q_len, _ = hidden_states.size()
        per_head_shape = (batch_size, q_len, self.num_heads, self.head_dim)

        # Project, then reshape to (batch, heads, seq, head_dim).
        query_states = self.q_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(*per_head_shape).transpose(1, 2)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        needs_contiguous = query_states.device.type == "cuda" and attention_mask is not None
        if needs_contiguous:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = False
        if self.is_causal and q_len > 1:
            is_causal = True

        # Dropout is applied only in training mode.
        dropout_p = self.dropout if self.training else 0.0

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed), then the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)
        return self.out_proj(attn_output), None
class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    # NOTE(review): "SiglipEncoderLayer" is listed twice below; the second entry looks
    # redundant. Confirm before removing it.
    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights

        Applies an initialisation scheme chosen by the type of `module`. The branches are
        tested in order and only the first match runs, so the specific Siglip module types
        must stay ahead of the generic `nn.Linear` / `nn.Conv2d` / `nn.LayerNorm` branches.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # The width comes from the vision sub-config when the config is a combined SiglipConfig.
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            # Xavier-uniform weights and zero biases for all four projections.
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            # Xavier-uniform weights; biases drawn from a normal with std 1e-6.
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # log(1.0) is 0, so the logit scale starts at zero, as does the logit bias.
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, SiglipForImageClassification):
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # Zero bias and unit weight.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# Shared class-level documentation, passed to `add_start_docstrings` on the model classes.
SIGLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation for the text-only forward methods, passed to
# `add_start_docstrings_to_model_forward`.
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for the vision-only forward methods.
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for the combined text + vision forward method.
SIGLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class SiglipTextTransformer(nn.Module):
    # Text tower: embeddings -> encoder -> final layer norm -> linear head on the last token.

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        hidden_dim = config.hidden_size

        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
        self.head = nn.Linear(hidden_dim, hidden_dim)

        # With flash attention 2 the 2D padding mask is handed to the encoder unexpanded.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Flatten any leading axes so the ids are (N, seq_len).
        seq_len = input_ids.size()[-1]
        input_ids = input_ids.view(-1, seq_len)

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
        expand_mask = attention_mask is not None and not self._use_flash_attention_2
        if expand_mask:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # Assuming "sticky" EOS tokenization, last token is always EOS.
        pooled_output = self.head(last_hidden_state[:, -1, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
# expand attention_mask if attention_mask is not None and not self._use_flash_attention_2: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.final_layer_norm(last_hidden_state) # Assuming "sticky" EOS tokenization, last token is always EOS. pooled_output = last_hidden_state[:, -1, :] pooled_output = self.head(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @add_start_docstrings( """The text model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING, ) class SiglipTextModel(SiglipPreTrainedModel): config_class = SiglipTextConfig def __init__(self, config: SiglipTextConfig): super().__init__(config) self.text_model = SiglipTextTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.embeddings.token_embedding def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, 
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: _SiglipVisionConfig):
        super().__init__()
        hidden_dim = config.hidden_size

        # Learned query vector of shape (1, 1, hidden_dim), tiled over the batch in forward.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.attention = torch.nn.MultiheadAttention(hidden_dim, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        """Pool a sequence into a single vector per batch item.

        The probe attends over `hidden_state` (used as both keys and values), the result
        goes through a layer norm + MLP branch with a residual connection, and the first
        (only) query position is returned.
        """
        num_items = hidden_state.shape[0]
        queries = self.probe.repeat(num_items, 1, 1)

        # MultiheadAttention returns (output, weights); only the output is used.
        pooled = self.attention(queries, hidden_state, hidden_state)[0]

        # Residual connection around layer norm + MLP.
        pooled = pooled + self.mlp(self.layernorm(pooled))

        return pooled[:, 0]
def __init__(self, config: SiglipConfig):
    """Build the text and vision towers and the two learnable logit parameters.

    Raises:
        TypeError: If `config.text_config` is not a `SiglipTextConfig` or
            `config.vision_config` is not a `_SiglipVisionConfig`.
    """
    super().__init__(config)

    text_config = config.text_config
    vision_config = config.vision_config

    # Validate both sub-configs before building anything (text first, then vision).
    if not isinstance(text_config, SiglipTextConfig):
        raise TypeError(
            "config.text_config is expected to be of type SiglipTextConfig but is of type"
            f" {type(config.text_config)}."
        )
    if not isinstance(vision_config, _SiglipVisionConfig):
        raise TypeError(
            "config.vision_config is expected to be of type _SiglipVisionConfig but is of type"
            f" {type(config.vision_config)}."
        )

    # First, initialize the text and vision models with proper attention implementation
    text_wrapper = SiglipTextModel._from_config(text_config)
    vision_wrapper = SiglipVisionModel._from_config(vision_config)

    # Second, get the text and vision submodules (for backward compatibility)
    self.text_model = text_wrapper.text_model
    self.vision_model = vision_wrapper.vision_model

    # One-element learnable scale and bias, created in this order.
    self.logit_scale = nn.Parameter(torch.randn(1))
    self.logit_bias = nn.Parameter(torch.randn(1))

    # Initialize weights and apply final processing
    self.post_init()
text_features = model.get_text_features(**inputs) ```""" # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) pooled_output = text_outputs[1] return pooled_output @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" Returns: image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> with torch.no_grad(): ... image_features = model.get_image_features(**inputs) ```""" # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) pooled_output = vision_outputs[1] return pooled_output @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, ) -> Union[Tuple, SiglipOutput]: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] >>> # important: we pass `padding=max_length` since the model was trained with this >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") >>> with torch.no_grad(): ... 
outputs = model(**inputs) >>> logits_per_image = outputs.logits_per_image >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") 31.9% that image 0 is 'a photo of 2 cats' ```""" # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, interpolate_pos_encoding=interpolate_pos_encoding, ) text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) image_embeds = vision_outputs[1] text_embeds = text_outputs[1] # normalized features image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits logits_per_text = ( torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp() + self.logit_bias ) logits_per_image = logits_per_text.t() loss = None if return_loss: # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) nll = -torch.sum(loglik, dim=-1) loss = nll.mean() if not 
@add_start_docstrings(
    """
    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """,
    SIGLIP_START_DOCSTRING,
)
class SiglipForImageClassification(SiglipPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels

        # Build the vision model through `_from_config` and keep only its
        # inner `vision_model` submodule (for backward compatibility).
        self.vision_model = SiglipVisionModel._from_config(config.vision_config).vision_model

        # Classifier head; a pass-through when no labels are configured.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.vision_config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, SiglipForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `SiglipModel` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```"""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Average-pool the patch tokens (dim 1), then classify.
        pooled_tokens = torch.mean(outputs[0], dim=1)
        logits = self.classifier(pooled_tokens)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)

            # Infer the problem type once from label count/dtype and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            if loss is not None:
                return (loss,) + output
            return output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class SiglipProcessor(ProcessorMixin):
    r"""
    Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.

    [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
    [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.

    Args:
        image_processor ([`SiglipImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`SiglipTokenizer`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = "SiglipTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). `text`, `padding`,
        `truncation`, `max_length` and `return_tensors` are forwarded to the tokenizer when `text` is not `None`;
        `images` and `return_tensors` are forwarded to the image processor when `images` is not `None`. Please refer
        to the docstring of those two callables for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        # Image-only input: wrap the image processor output in a fresh BatchFeature.
        if text is None:
            image_features = self.image_processor(images, return_tensors=return_tensors)
            return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)

        encoding = self.tokenizer(
            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
        )

        # Text + images: attach the pixel values to the tokenizer output.
        if images is not None:
            image_features = self.image_processor(images, return_tensors=return_tensors)
            encoding["pixel_values"] = image_features.pixel_values

        return encoding

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
    def model_input_names(self):
        # Tokenizer names first, then image-processor names; dict.fromkeys
        # drops duplicates while keeping first-seen order.
        combined = self.tokenizer.model_input_names + self.image_processor.model_input_names
        return list(dict.fromkeys(combined))
class SiglipTokenizer(PreTrainedTokenizer):
    """
    Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `""`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `""`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `""`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        model_max_length (`int`, *optional*, defaults to 64):
            The maximum length (in number of tokens) for model inputs.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    # Translation table that deletes every ASCII punctuation character; built
    # once here instead of on every `remove_punctuation` call.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    # NOTE(review): the special-token defaults below are empty strings, which
    # looks like markup-stripped token literals — confirm against the checkpoint.
    def __init__(
        self,
        vocab_file,
        eos_token="",
        unk_token="",
        pad_token="",
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        model_max_length=64,
        do_lower_case=True,
        **kwargs,
    ) -> None:
        requires_backends(self, "protobuf")

        def _as_special(token):
            # Plain strings become special AddedTokens; anything else is passed through untouched.
            if isinstance(token, str):
                return AddedToken(token, rstrip=True, lstrip=True, normalized=False, special=True)
            return token

        pad_token = _as_special(pad_token)
        unk_token = _as_special(unk_token)
        eos_token = _as_special(eos_token)

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.do_lower_case = do_lower_case
        # `get_spm_processor` reads `self.vocab_file`, so it must be set first.
        self.vocab_file = vocab_file
        self.sp_model = self.get_spm_processor()

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            model_max_length=model_max_length,
            do_lower_case=do_lower_case,
            **kwargs,
        )

    def get_spm_processor(self):
        """Load `self.vocab_file` into a SentencePiece processor with `add_dummy_prefix` turned off."""
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()

        # Patch the serialized model proto so the normalizer does not add a dummy prefix.
        model_pb2 = import_protobuf()
        model = model_pb2.ModelProto.FromString(sp_model)
        normalizer_spec = model_pb2.NormalizerSpec()
        normalizer_spec.add_dummy_prefix = False
        model.normalizer_spec.MergeFrom(normalizer_spec)
        sp_model = model.SerializeToString()
        tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # normal case: one special (eos) slot after each sequence
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
            warnings.warn(
                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
                " eos tokens being added."
            )
            return token_ids
        else:
            return token_ids + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Token type ids
        are not used here, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros, one per token plus one per appended eos.
        """
        # One extra slot per sequence accounts for its trailing eos token.
        total = len(token_ids_0) + 1
        if token_ids_1 is not None:
            total += len(token_ids_1) + 1
        return total * [0]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A sequence has the following format (`eos` is the end-of-sequence token):

        - single sequence: `X eos`
        - pair of sequences: `A eos B eos`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
        if token_ids_1 is None:
            return token_ids_0
        else:
            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
            return token_ids_0 + token_ids_1

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
    def __getstate__(self):
        # The SentencePiece processor is dropped from the pickled state and
        # rebuilt from `vocab_file` in `__setstate__`.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Rebuild through `get_spm_processor` rather than a plain `Load()`, so the
        # restored processor also has `add_dummy_prefix` disabled, exactly like a
        # freshly constructed tokenizer. `_tokenize` relies on that setting.
        self.sp_model = self.get_spm_processor()

    def remove_punctuation(self, text: str) -> str:
        """Return `text` with every character of `string.punctuation` removed."""
        return text.translate(self._PUNCTUATION_TABLE)

    # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
    def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
        """Returns canonicalized `text` (punctuation removed, whitespace collapsed and stripped).

        Args:
            text (`str`):
                String to be canonicalized.
            keep_punctuation_exact_string (`str`, *optional*):
                If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
                (but will still remove '{' and '}' that appear separately).
        """
        if keep_punctuation_exact_string:
            # Strip punctuation from the pieces between occurrences, then re-join with the kept string.
            text = keep_punctuation_exact_string.join(
                self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
            )
        else:
            text = self.remove_punctuation(text)
        text = re.sub(r"\s+", " ", text)
        text = text.strip()

        return text

    def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
        """
        Converts a string to a list of tokens.
        """
        # Prefix the text with SPIECE_UNDERLINE and turn any SPIECE_UNDERLINE already in the text into a space.
        tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)

        # Drop the leading SPIECE_UNDERLINE token when it is directly followed by a special token.
        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
            tokens = tokens[1:]
        return tokens

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE.

        For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.

        Thus we always encode `f"{unk_token}text"` and strip the `unk_token` pieces, i.e. the first
        `self.unk_token_length` pieces of the result.
        """
        text = self.canonicalize_text(text, keep_punctuation_exact_string=None)

        # 1. Encode unk_token + text in a single pass (the text is not encoded separately first).
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove the pieces that belong to the unk_token prefix.
        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        # Flush whatever regular pieces remain after the last special token.
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the SentencePiece model into `save_directory` and return a 1-tuple with the written path.

        Logs an error and returns `None` when `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            # The source file exists at a different path: copy it as-is.
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # The source file is gone: dump the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)
class BagelConfig(PretrainedConfig):
    """Configuration for the `Bagel` model.

    Every argument is stored unchanged as an attribute of the same name; any
    extra keyword arguments are forwarded to `PretrainedConfig.__init__`.

    Args:
        visual_gen: Enables the visual-generation (latent) branch of `Bagel`.
        visual_und: Enables the visual-understanding (ViT) branch of `Bagel`.
        llm_config: Language-model config; `Bagel` converts a plain dict to `Qwen2Config`.
        vit_config: Vision-encoder config; `Bagel` converts a plain dict to `SiglipVisionConfig`.
        vae_config: VAE config; `Bagel` converts a plain dict to a `SimpleNamespace` and reads
            its `downsample` and `z_channels` fields.
        latent_patch_size: Side length of a latent patch.
        max_latent_size: Size passed to the latent position embedding.
        vit_max_num_patch_per_side: Size passed to the ViT position embedding.
        connector_act: Activation name passed to the ViT-to-LLM connector.
        interpolate_pos: Selects the interpolating (True) or extrapolating (False)
            flattened-position-id function.
        timestep_shift: Shift factor applied to the flow timesteps.
    """

    def __init__(
        self,
        visual_gen=True,
        visual_und=True,
        llm_config=None,
        vit_config=None,
        vae_config=None,
        latent_patch_size=2,
        max_latent_size=32,
        vit_max_num_patch_per_side=70,
        connector_act="gelu_pytorch_tanh",
        interpolate_pos=False,
        timestep_shift=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Branch switches.
        self.visual_gen = visual_gen
        self.visual_und = visual_und

        # Sub-model configurations.
        self.llm_config = llm_config
        self.vit_config = vit_config
        self.vae_config = vae_config

        # Latent / ViT geometry.
        self.latent_patch_size = latent_patch_size
        self.max_latent_size = max_latent_size
        self.vit_max_num_patch_per_side = vit_max_num_patch_per_side

        # Connector and positional / timestep behaviour.
        self.connector_act = connector_act
        self.interpolate_pos = interpolate_pos
        self.timestep_shift = timestep_shift
self.latent_patch_size ** 2 * self.latent_channel self.time_embedder = TimestepEmbedder(self.hidden_size) self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size) if config.visual_und: self.vit_model = vit_model self.vit_patch_size = config.vit_config.patch_size self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side self.vit_hidden_size = config.vit_config.hidden_size self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act) self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size) self.vit_model.vision_model.embeddings.convert_conv2d_to_linear(config.vit_config, meta=True) if config.interpolate_pos: self.get_flattened_position_ids = get_flattened_position_ids_interpolate else: self.get_flattened_position_ids = get_flattened_position_ids_extrapolate self.config = config self._init_weights() def _init_weights(self): if self.config.visual_gen: nn.init.constant_(self.llm2vae.weight, 0) nn.init.constant_(self.llm2vae.bias, 0) def forward( self, sequence_length: int, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, sample_lens: List[int], packed_position_ids: torch.LongTensor, nested_attention_masks: List[torch.Tensor] = None, split_lens: List[int] = None, attn_modes: List[str] = None, # for visual understanding ce_loss_indexes: Optional[torch.BoolTensor] = None, packed_label_ids: Optional[torch.LongTensor] = None, packed_vit_tokens: Optional[torch.Tensor] = None, packed_vit_token_indexes: Optional[torch.LongTensor] = None, packed_vit_position_ids: Optional[torch.LongTensor] = None, vit_token_seqlens: Optional[torch.IntTensor] = None, # for visual generation padded_latent: Optional[torch.Tensor] = None, patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None, packed_latent_position_ids: 
Optional[torch.LongTensor] = None, packed_vae_token_indexes: Optional[torch.LongTensor] = None, packed_timesteps: Optional[torch.LongTensor] = None, mse_loss_indexes: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: """ Args: sequence_length: length of sequence. packed_text_ids: 1-D int tensor, packed text token ids. packed_text_indexes: 1-D int tensor, packed text token indexes in sequence. sample_lens: A list of N ints, length of each sample in packed_sequence. nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and -inf means ignore. packed_position_ids: packed 1-D positions, an image has only one global position shared by all latent tokens. packed_vit_tokens: packed patchified image tokens for vit model. packed_vit_position_ids: 1-D int tensor, the position of each token for vit model. packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence. vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model. packed_label_ids: 1-D int tensor, packed label token ids. ce_loss_indexes: 1-D bool tensor, where to compute ce loss. padded_latent: padded latent from VAE encoder. patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image. packed_latent_position_ids: 1-D int tensor, the position of each token for latent. packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence. packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image. mse_loss_indexes: 1-D bool tensor, where to compute mse loss. 
""" packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding if nested_attention_masks is None: sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device) seqlen = sum(sample_lens) block_mask = create_block_mask( sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen, device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True ) attention_mask = block_mask else: attention_mask = nested_attention_masks if self.config.visual_und: cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) cu_seqlens = cu_seqlens.to(torch.int32) max_seqlen = torch.max(vit_token_seqlens).item() packed_vit_token_embed = self.vit_model( packed_pixel_values=packed_vit_tokens, packed_flattened_position_ids=packed_vit_position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) packed_vit_token_embed = self.connector(packed_vit_token_embed) vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids) packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed if self.config.visual_gen: p = self.latent_patch_size packed_latent = [] for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes): latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p) latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel) packed_latent.append(latent) packed_latent_clean = torch.cat(packed_latent, dim=0) noise = torch.randn_like(packed_latent_clean) packed_timesteps = torch.sigmoid(packed_timesteps) packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps) packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise 
packed_timestep_embeds = self.time_embedder(packed_timesteps) latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids) packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb packed_sequence[packed_vae_token_indexes] = packed_latent extra_inputs = {} if self.use_moe: packed_und_token_indexes = packed_text_indexes if packed_vit_token_indexes is not None: packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0) extra_inputs.update( packed_und_token_indexes=packed_und_token_indexes, packed_gen_token_indexes=packed_vae_token_indexes, ) last_hidden_state = self.language_model( packed_sequence=packed_sequence, sample_lens=sample_lens, attention_mask=attention_mask, packed_position_ids=packed_position_ids, **extra_inputs, ) mse = None if self.config.visual_gen: packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes]) target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise has_mse = packed_timesteps > 0 mse = (packed_mse_preds - target[has_mse]) ** 2 ce = None if ce_loss_indexes is not None: packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes]) ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none") return dict(mse=mse, ce=ce) def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids): packed_text_ids = list() packed_text_position_ids = list() text_token_lens = list() packed_text_indexes = list() packed_key_value_indexes = list() curr = 0 newlens, new_rope = list(), list() for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope): packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen text_ids = tokenizer.encode(prompt) text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']] text_token_lens.append(len(text_ids)) packed_text_ids.extend(text_ids) 
packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids))) packed_text_indexes.extend(range(curr, curr + len(text_ids))) newlens.append(curr_kvlen + len(text_ids)) new_rope.append(curr_position_id + len(text_ids)) curr += len(text_ids) generation_input = { "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int), "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long), "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), } return generation_input, newlens, new_rope @torch.no_grad def forward_cache_update_text( self, past_key_values: NaiveCache, packed_text_ids: torch.IntTensor, packed_text_position_ids: torch.LongTensor, text_token_lens: torch.LongTensor, packed_text_indexes: torch.LongTensor, packed_key_value_indexes: torch.LongTensor, key_values_lens: torch.IntTensor, ): packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) extra_inputs = {} if self.use_moe: extra_inputs = {"mode": "und"} output = self.language_model.forward_inference( packed_query_sequence=packed_text_embedding, query_lens=text_token_lens, packed_query_position_ids=packed_text_position_ids, packed_query_indexes=packed_text_indexes, past_key_values=past_key_values, packed_key_value_indexes=packed_key_value_indexes, key_values_lens=key_values_lens, update_past_key_values=True, is_causal=True, **extra_inputs, ) past_key_values = output.past_key_values return past_key_values def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids): packed_vit_token_indexes = list() vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list() packed_text_ids, packed_text_indexes = list(), list() packed_seqlens, packed_position_ids, 
packed_indexes = list(), list(), list() packed_key_value_indexes = list() _curr = curr = 0 newlens, new_rope = list(), list() for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen packed_text_ids.append(new_token_ids['start_of_image']) packed_text_indexes.append(_curr) packed_indexes.append(curr) curr += 1 _curr += 1 image_tensor = transforms(image) vit_position_ids = self.get_flattened_position_ids( image_tensor.size(1), image_tensor.size(2), self.vit_patch_size, max_num_patches_per_side=self.vit_max_num_patch_per_side ) vit_tokens = patchify(image_tensor, self.vit_patch_size) packed_vit_tokens.append(vit_tokens) num_img_tokens = vit_tokens.shape[0] packed_vit_position_ids.append(vit_position_ids) vit_token_seqlens.append(num_img_tokens) packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens)) packed_indexes.extend(range(curr, curr + num_img_tokens)) curr += num_img_tokens _curr += num_img_tokens packed_text_ids.append(new_token_ids['end_of_image']) packed_text_indexes.append(_curr) packed_indexes.append(curr) curr += 1 _curr += 1 packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) packed_seqlens.append(num_img_tokens + 2) newlens.append(curr_kvlen + num_img_tokens + 2) new_rope.append(curr_position_id + 1) generation_input = { "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int), "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0), "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0), "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long), "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), "packed_indexes": 
torch.tensor(packed_indexes, dtype=torch.long), "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), } return generation_input, newlens, new_rope @torch.no_grad def forward_cache_update_vit( self, past_key_values: NaiveCache, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, packed_vit_tokens: torch.Tensor, packed_vit_token_indexes: torch.LongTensor, packed_vit_position_ids: torch.LongTensor, vit_token_seqlens: torch.IntTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, packed_indexes: torch.LongTensor, packed_key_value_indexes: torch.LongTensor, key_values_lens: torch.IntTensor, ): packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) cu_seqlens = cu_seqlens.to(torch.int32) max_seqlen = torch.max(vit_token_seqlens).item() packed_vit_token_embed = self.vit_model( packed_pixel_values=packed_vit_tokens, packed_flattened_position_ids=packed_vit_position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) packed_vit_token_embed = self.connector(packed_vit_token_embed) pos_emb = self.vit_pos_embed(packed_vit_position_ids) packed_vit_token_embed = packed_vit_token_embed + pos_emb if packed_vit_token_embed.dtype != packed_sequence.dtype: packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype) packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed extra_inputs = {} if self.use_moe: extra_inputs = {"mode": "und"} output = self.language_model.forward_inference( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, packed_query_indexes=packed_indexes, past_key_values=past_key_values, 
packed_key_value_indexes=packed_key_value_indexes, key_values_lens=key_values_lens, update_past_key_values=True, is_causal=False, **extra_inputs, ) past_key_values = output.past_key_values return past_key_values def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0): patchified_vae_latent_shapes, packed_vae_position_ids = list(), list() packed_vae_token_indexes = list() packed_text_ids, packed_text_indexes = list(), list() packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() packed_key_value_indexes = list() _curr = curr = 0 vae_image_tensors = list() newlens, new_rope = list(), list() for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen packed_text_ids.append(new_token_ids['start_of_image']) packed_text_indexes.append(_curr) packed_indexes.append(curr) curr += 1 _curr += 1 image_tensor = transforms(image) vae_image_tensors.append(image_tensor) vae_posiiton_ids = self.get_flattened_position_ids( image_tensor.size(1), image_tensor.size(2), self.latent_downsample, max_num_patches_per_side=self.max_latent_size ) packed_vae_position_ids.append(vae_posiiton_ids) H, W = image_tensor.shape[1:] h = H // self.latent_downsample w = W // self.latent_downsample patchified_vae_latent_shapes.append((h, w)) num_img_tokens = w * h packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens)) packed_indexes.extend(range(curr, curr + num_img_tokens)) curr += num_img_tokens _curr += num_img_tokens packed_text_ids.append(new_token_ids['end_of_image']) packed_text_indexes.append(_curr) packed_indexes.append(curr) curr += 1 _curr += 1 packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) packed_seqlens.append(num_img_tokens + 2) newlens.append(curr_kvlen + num_img_tokens + 2) new_rope.append(curr_position_id + 1) image_sizes = [item.shape for item in vae_image_tensors] max_image_size = 
[max(item) for item in list(zip(*image_sizes))] padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size)) for i, image_tensor in enumerate(vae_image_tensors): padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor generation_input = { "padded_images": padded_images, "patchified_vae_latent_shapes": patchified_vae_latent_shapes, "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), "packed_timesteps": torch.tensor([timestep]), "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), } return generation_input, newlens, new_rope @torch.no_grad def forward_cache_update_vae( self, vae_model, past_key_values: NaiveCache, padded_images: torch.Tensor, patchified_vae_latent_shapes: List, packed_vae_position_ids: torch.LongTensor, packed_timesteps: torch.Tensor, packed_vae_token_indexes: torch.LongTensor, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, packed_indexes: torch.LongTensor, key_values_lens: torch.IntTensor, packed_key_value_indexes: torch.Tensor, ): packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding padded_latent = vae_model.encode(padded_images) p = self.latent_patch_size packed_latent = list() for latent, (h, 
w) in zip(padded_latent, patchified_vae_latent_shapes): latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p) latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel) packed_latent.append(latent) packed_latent = torch.cat(packed_latent, dim=0) packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) packed_timestep_embeds = self.time_embedder(packed_timesteps) packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed if packed_latent.dtype != packed_sequence.dtype: packed_latent = packed_latent.to(packed_sequence.dtype) packed_sequence[packed_vae_token_indexes] = packed_latent extra_inputs = {} if self.use_moe: extra_inputs = { "mode": "gen", "packed_vae_token_indexes": packed_vae_token_indexes, "packed_text_indexes": packed_text_indexes } output = self.language_model.forward_inference( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, packed_query_indexes=packed_indexes, past_key_values=past_key_values, key_values_lens=key_values_lens, packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=True, is_causal=False, **extra_inputs, ) past_key_values = output.past_key_values return past_key_values def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids): packed_text_ids, packed_text_indexes = list(), list() packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list() packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list() packed_key_value_indexes = list() query_curr = curr = 0 for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen packed_text_ids.append(new_token_ids['start_of_image']) packed_text_indexes.append(query_curr) packed_indexes.append(curr) curr += 1 query_curr += 1 vae_posiiton_ids = 
self.get_flattened_position_ids( H, W, self.latent_downsample, max_num_patches_per_side=self.max_latent_size ) packed_vae_position_ids.append(vae_posiiton_ids) h, w = H // self.latent_downsample, W // self.latent_downsample num_image_tokens = h * w packed_init_noises.append( torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2) ) packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens)) packed_indexes.extend(range(curr, curr + num_image_tokens)) curr += num_image_tokens query_curr += num_image_tokens packed_text_ids.append(new_token_ids['end_of_image']) packed_text_indexes.append(query_curr) packed_indexes.append(curr) curr += 1 query_curr += 1 packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) packed_seqlens.append(num_image_tokens + 2) generation_input = { "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), "packed_init_noises": torch.cat(packed_init_noises, dim=0), "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), } return generation_input def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes): packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list() query_curr = curr = 0 for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen packed_indexes.append(curr) curr += 1 query_curr += 1 h, w = H // 
self.latent_downsample, W // self.latent_downsample num_image_tokens = h * w packed_indexes.extend(range(curr, curr + num_image_tokens)) curr += num_image_tokens query_curr += num_image_tokens packed_indexes.append(curr) curr += 1 query_curr += 1 packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) generation_input = { "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long), "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), } return generation_input @torch.no_grad def generate_image( self, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, packed_init_noises: torch.Tensor, packed_vae_position_ids: torch.LongTensor, packed_vae_token_indexes: torch.LongTensor, packed_seqlens: torch.IntTensor, packed_position_ids: torch.LongTensor, packed_indexes: torch.LongTensor, past_key_values: NaiveCache, key_values_lens: torch.IntTensor, packed_key_value_indexes: torch.LongTensor, num_timesteps: int = 24, timestep_shift: float = 1.0, cfg_renorm_min: float = 0.0, cfg_renorm_type: str = "global", cfg_interval: Optional[Tuple[float, float]] = [0, 1], # cfg_text cfg_text_scale: float = 1.0, cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None, cfg_text_packed_position_ids: Optional[torch.LongTensor] = None, cfg_text_past_key_values: Optional[NaiveCache] = None, cfg_text_key_values_lens: Optional[torch.IntTensor] = None, cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None, # cfg_img cfg_img_scale: float = 1.0, cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None, cfg_img_packed_position_ids: Optional[torch.LongTensor] = None, cfg_img_past_key_values: Optional[NaiveCache] = None, cfg_img_key_values_lens: Optional[torch.IntTensor] = None, cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = 
None, cfg_type: str = "parallel", ): x_t = packed_init_noises timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device) timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) dts = timesteps[:-1] - timesteps[1:] timesteps = timesteps[:-1] for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) if t > cfg_interval[0] and t <= cfg_interval[1]: cfg_text_scale_ = cfg_text_scale cfg_img_scale_ = cfg_img_scale else: cfg_text_scale_ = 1.0 cfg_img_scale_ = 1.0 v_t = self._forward_flow( x_t=x_t, timestep=timestep, packed_vae_token_indexes=packed_vae_token_indexes, packed_vae_position_ids=packed_vae_position_ids, packed_text_ids=packed_text_ids, packed_text_indexes=packed_text_indexes, packed_position_ids=packed_position_ids, packed_indexes=packed_indexes, packed_seqlens=packed_seqlens, key_values_lens=key_values_lens, past_key_values=past_key_values, packed_key_value_indexes=packed_key_value_indexes, cfg_renorm_min=cfg_renorm_min, cfg_renorm_type=cfg_renorm_type, # cfg_text cfg_text_scale=cfg_text_scale_, cfg_text_packed_position_ids=cfg_text_packed_position_ids, cfg_text_packed_query_indexes=cfg_text_packed_query_indexes, cfg_text_key_values_lens=cfg_text_key_values_lens, cfg_text_past_key_values=cfg_text_past_key_values, cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes, # cfg_img cfg_img_scale=cfg_img_scale_, cfg_img_packed_position_ids=cfg_img_packed_position_ids, cfg_img_packed_query_indexes=cfg_img_packed_query_indexes, cfg_img_key_values_lens=cfg_img_key_values_lens, cfg_img_past_key_values=cfg_img_past_key_values, cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes, cfg_type=cfg_type, ) x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) return unpacked_latent @torch.no_grad def _forward_flow( self, x_t: torch.Tensor, timestep: torch.LongTensor, 
packed_vae_token_indexes: torch.LongTensor, packed_vae_position_ids: torch.LongTensor, packed_text_ids: torch.LongTensor, packed_text_indexes: torch.LongTensor, packed_indexes: torch.LongTensor, packed_position_ids: torch.LongTensor, packed_seqlens: torch.IntTensor, key_values_lens: torch.IntTensor, past_key_values: NaiveCache, packed_key_value_indexes: torch.LongTensor, cfg_renorm_min: float = 0.0, cfg_renorm_type: str = "global", # cfg_text cfg_text_scale: float = 1.0, cfg_text_packed_position_ids: Optional[torch.LongTensor] = None, cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None, cfg_text_key_values_lens: Optional[torch.Tensor] = None, cfg_text_past_key_values: Optional[NaiveCache] = None, cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None, # cfg_img cfg_img_scale: float = 1.0, cfg_img_packed_position_ids: Optional[torch.LongTensor] = None, cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None, cfg_img_key_values_lens: Optional[torch.Tensor] = None, cfg_img_past_key_values: Optional[NaiveCache] = None, cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None, cfg_type: str = "parallel", ): packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) packed_sequence[packed_text_indexes] = packed_text_embedding assert timestep.unique().shape[0] == 1 packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) packed_timestep_embeds = self.time_embedder(timestep) x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed if x_t.dtype != packed_sequence.dtype: x_t = x_t.to(packed_sequence.dtype) packed_sequence[packed_vae_token_indexes] = x_t extra_inputs = {} if self.use_moe: extra_inputs = { "mode": "gen", "packed_vae_token_indexes": packed_vae_token_indexes, "packed_text_indexes": packed_text_indexes } output = self.language_model.forward_inference( 
packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=packed_position_ids, packed_query_indexes=packed_indexes, past_key_values=past_key_values, key_values_lens=key_values_lens, packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, ) v_t = self.llm2vae(output.packed_query_sequence) v_t = v_t[packed_vae_token_indexes] if cfg_text_scale > 1.0: cfg_text_output = self.language_model.forward_inference( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=cfg_text_packed_position_ids, packed_query_indexes=cfg_text_packed_query_indexes, past_key_values=cfg_text_past_key_values, key_values_lens=cfg_text_key_values_lens, packed_key_value_indexes=cfg_text_packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, ) cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence) cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes] if cfg_img_scale > 1.0: cfg_img_output = self.language_model.forward_inference( packed_query_sequence=packed_sequence, query_lens=packed_seqlens, packed_query_position_ids=cfg_img_packed_position_ids, packed_query_indexes=cfg_img_packed_query_indexes, past_key_values=cfg_img_past_key_values, key_values_lens=cfg_img_key_values_lens, packed_key_value_indexes=cfg_img_packed_key_value_indexes, update_past_key_values=False, is_causal=False, **extra_inputs, ) cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence) cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes] if cfg_text_scale > 1.0: if cfg_renorm_type == "text_channel": v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) v_t_text = v_t_text_ * scale if cfg_img_scale > 1.0: v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - 
cfg_img_v_t) else: v_t = v_t_text else: v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) if cfg_img_scale > 1.0: v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) else: v_t_ = v_t_text_ # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit if cfg_renorm_type == "global": norm_v_t = torch.norm(v_t) norm_v_t_ = torch.norm(v_t_) elif cfg_renorm_type == "channel": norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) else: raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted") scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) v_t = v_t_ * scale else: # No CFG pass return v_t def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): packed_start_tokens, packed_key_value_indexes = list(), list() packed_query_position_ids = list() curr = 0 for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope): packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) packed_start_tokens.append(new_token_ids['bos_token_id']) packed_query_position_ids.append(curr_position_id) curr += curr_kvlen generation_input = { "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long), "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long), "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), } return generation_input @torch.no_grad def generate_text( self, past_key_values: NaiveCache, packed_key_value_indexes: torch.LongTensor, key_values_lens: torch.IntTensor, packed_start_tokens: torch.LongTensor, packed_query_position_ids: torch.LongTensor, max_length: int, do_sample: bool = False, temperature: float = 1.0, end_token_id: int = None, ): step = 0 generated_sequence = [] curr_tokens = packed_start_tokens while step < max_length: generated_sequence.append(curr_tokens) 
packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens) query_lens = torch.ones_like(curr_tokens) packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange( 0, len(key_values_lens), device=key_values_lens.device, dtype=key_values_lens.dtype ) uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) for i in range(len(uppacked)): uppacked[i] += i packed_key_value_indexes = torch.cat(uppacked, dim=0) extra_inputs = {} if self.use_moe: extra_inputs = {"mode": "und"} output = self.language_model.forward_inference( packed_query_sequence=packed_text_embedding, query_lens=query_lens, packed_query_position_ids=packed_query_position_ids, packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, key_values_lens=key_values_lens, packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=True, is_causal=True, **extra_inputs, ) past_key_values = output.past_key_values packed_query_sequence = output.packed_query_sequence pred_logits = self.language_model.lm_head(packed_query_sequence) if do_sample: probs = nn.functional.softmax(pred_logits / temperature, dim=-1) curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: curr_tokens = torch.argmax(pred_logits, dim=-1) uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) for i in range(len(uppacked)): uppacked[i] = torch.cat( [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0 ) packed_key_value_indexes = torch.cat(uppacked, dim=0) key_values_lens = key_values_lens + 1 packed_query_position_ids = packed_query_position_ids + 1 step += 1 if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1 break output_device = generated_sequence[0].device return torch.stack([i.to(output_device) for i in generated_sequence], dim=0) # for evaluation @torch.no_grad() def chat( self, tokenizer, new_token_ids, image_transform, images, prompt, max_length: 
int, do_sample: bool = False, temperature: float = 1.0, ): device = next(self.parameters()).device if isinstance(new_token_ids, dict): for k, v in new_token_ids.items(): if torch.is_tensor(v): new_token_ids[k] = v.to(device) elif torch.is_tensor(new_token_ids): new_token_ids = new_token_ids.to(device) # prefill past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers) newlens = [0] new_rope = [0] # add images for image in images: generation_input, newlens, new_rope = self.prepare_vit_images( curr_kvlens=newlens, curr_rope=new_rope, images=[image], transforms=image_transform, new_token_ids=new_token_ids, ) for k, v in generation_input.items(): if torch.is_tensor(v): generation_input[k] = v.to(device) with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input) # add text generation_input, newlens, new_rope = self.prepare_prompts( curr_kvlens=newlens, curr_rope=new_rope, prompts=[prompt], tokenizer=tokenizer, new_token_ids=new_token_ids, ) for k, v in generation_input.items(): if torch.is_tensor(v): generation_input[k] = v.to(device) with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): past_key_values = self.forward_cache_update_text(past_key_values, **generation_input) # decode generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids) for k, v in generation_input.items(): if torch.is_tensor(v): generation_input[k] = v.to(device) with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): unpacked_latent = self.generate_text( past_key_values=past_key_values, max_length=max_length, do_sample=do_sample, temperature=temperature, end_token_id=new_token_ids['eos_token_id'], **generation_input, ) output = tokenizer.decode(unpacked_latent[:,0]) output = output.split('<|im_end|>')[0].split('<|im_start|>')[1] return output def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): grid_h = 
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a fixed 2-D sine/cosine position table for a square grid.

    Returns an array of shape ``(grid_size * grid_size, embed_dim)``. When
    ``cls_token`` is true and ``extra_tokens > 0``, that many all-zero rows are
    prepended.
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)
    # The width axis is passed first, so plane 0 of the stacked mesh varies along columns.
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode the two planes of ``grid`` with ``embed_dim // 2`` channels each.

    ``grid[0]`` fills the first half of the channels and ``grid[1]`` the second
    half; the result has shape ``(H*W, embed_dim)``.
    """
    assert embed_dim % 2 == 0

    channels_per_plane = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(channels_per_plane, grid[plane])  # (H*W, D/2)
        for plane in (0, 1)
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine/cosine embedding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions, flattened to size (M,)
    out: (M, embed_dim) array, sine channels first, then cosine channels
    """
    assert embed_dim % 2 == 0

    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    inv_freq = 1. / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', flat_pos, inv_freq)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a 2-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings (cosines first,
                 then sines, plus one zero column when ``dim`` is odd).
        """
        half = dim // 2
        scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(scaled / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd output width: pad with a single zero column.
            zero_col = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_col], dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class MLPconnector(nn.Module):
    """Two linear layers with an ``ACT2FN`` activation in between."""

    def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))


class PositionEmbedding(nn.Module):
    """Frozen 2-D sin-cos position table looked up by flat position id."""

    def __init__(self, max_num_patch_per_side, hidden_size):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size
        blank = torch.zeros(max_num_patch_per_side ** 2, hidden_size)
        self.pos_embed = nn.Parameter(blank, requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        # Fill (and keep frozen) pos_embed with the sin-cos embedding.
        sincos = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        self.pos_embed.data.copy_(torch.from_numpy(sincos).float())

    def forward(self, position_ids):
        return self.pos_embed[position_ids]
class Qwen2Config(_Qwen2Config):
    r"""
    Configuration of the Qwen2 language model used in this file.

    All standard Qwen2 options are forwarded unchanged to the parent `_Qwen2Config`
    constructor; this subclass only changes defaults and stores three extra fields
    (`qk_norm`, `layer_module`, `freeze_und`).

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size, i.e. the number of distinct token ids.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads per attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key/value heads for Grouped Query Attention. Equal to
            `num_attention_heads` gives MHA, `1` gives MQA, anything else GQA.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            Maximum sequence length the model might be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output word embeddings are tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling configuration. Expected keys:
            `rope_type` (one of 'default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3');
            `factor` (all types except 'default');
            `original_max_position_embeddings` ('dynamic', 'longrope', 'llama3');
            `attention_factor` ('yarn', 'longrope');
            `beta_fast` / `beta_slow` ('yarn' only; default to 32 and 1 when unspecified);
            `short_factor` / `long_factor` ('longrope' only; lists whose length is the hidden
            size divided by the number of attention heads divided by 2);
            `low_freq_factor` / `high_freq_factor` ('llama3' only).
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            Number of layers that use SWA; the bottom layers use SWA, the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Forwarded to the parent constructor.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Forwarded to the parent constructor.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Stored on the config; the packed attention classes in this file add
            per-head RMSNorm on queries and keys when it is true.
        layer_module (`str`, *optional*, defaults to `"Qwen2DecoderLayer"`):
            Stored on the config; presumably the name of the decoder-layer class
            to build (confirm against the model constructor).
        freeze_und (`bool`, *optional*, defaults to `False`):
            Stored on the config; the MoT training path detaches the states of
            the understanding-token positions when it is true.

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        qk_norm=True,
        layer_module="Qwen2DecoderLayer",
        freeze_und=False,
        **kwargs,
    ):
        # Options understood by the parent configuration class.
        parent_options = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
        )
        super().__init__(**parent_options, **kwargs)

        # Fields added by this subclass.
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        self.freeze_und = freeze_und


@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    """Output container: the packed query sequence plus the key/value cache."""

    packed_query_sequence: torch.FloatTensor = None
    past_key_values: Optional[NaiveCache] = None
def pad_sequence(tensor, pad_size):
    """Append ``pad_size`` zero positions along dim 1 of a 3-D ``(H, L, D)`` tensor.

    The padding is created with ``tensor.new_zeros`` and so shares the input's
    dtype and device; the result has shape ``(H, L + pad_size, D)``.
    """
    num_heads, _, head_dim = tensor.shape
    filler = tensor.new_zeros((num_heads, pad_size, head_dim))
    return torch.cat((tensor, filler), dim=1)
class PackedAttention(Qwen2Attention):
    """Qwen2 attention operating on packed (concatenated) token sequences.

    Adds optional per-head RMSNorm on queries and keys (``config.qk_norm``) and
    splits the forward pass into a training path (SDPA or flex attention) and
    an inference path (flash-attn varlen with a key/value cache).
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask: List[torch.Tensor],
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ):
        """Attention over a packed training batch.

        Args:
            packed_sequence: Hidden states of all samples concatenated along dim 0.
            sample_lens: Per-sample lengths used to split the packed sequence.
            attention_mask: Either a list with one mask per sample (SDPA path)
                or a block mask handed to ``flex_attention``.
            packed_position_embeddings: ``(cos, sin)`` pair for rotary embedding.

        Returns:
            The output-projected attention result, one row per packed token.
        """
        packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, list):
            # SDPA path: expand K/V heads to the number of query heads, then
            # attend sample by sample with that sample's own mask.
            packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            upacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                upacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(upacked_attn_output, dim=1)
        else:
            # Flex-attention path: pad the packed sequence up to sum(sample_lens),
            # run one call with the block mask, then drop the padded positions.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
            packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states.unsqueeze(0),
                packed_key_states.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        return packed_attn_output

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ):
        """Attention for packed queries against cached plus new keys/values.

        Args:
            packed_query_sequence: Hidden states of the new tokens, packed along dim 0.
            query_lens: Tensor with the number of new tokens per sample.
            packed_query_position_embeddings: ``(cos, sin)`` pair for rotary embedding.
            packed_query_indexes: Rows of the merged K/V buffer that receive the new tokens.
            past_key_values: Cache exposing ``key_cache`` / ``value_cache`` lists
                indexed by layer, or ``None``.
            key_values_lens: Tensor with the cached length per sample.
            packed_key_value_indexes: Rows of the merged K/V buffer that receive the cached entries.
            update_past_key_values: Store the merged K/V back into the cache.
                Ignored when ``past_key_values`` is ``None`` (previously this
                combination raised ``AttributeError``).
            is_causal: Forwarded to ``flash_attn_varlen_func`` as ``causal``.

        Returns:
            ``(packed_attn_output, past_key_values)``.
        """
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            # Tensor reductions instead of Python sum(): one kernel rather than
            # an element-by-element loop over the length tensors.
            seqlens = int(query_lens.sum() + key_values_lens.sum())
            merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            # Scatter new and cached entries into their interleaved positions.
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=query_lens.max().item(),
            max_seqlen_k=key_values_lens.max().item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        # Only write back when there is a cache to write to.
        if update_past_key_values and past_key_values is not None:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
class PackedAttentionMoT(Qwen2Attention):
    """Packed attention with a second set of weights for generation tokens.

    The inherited ``q/k/v/o_proj`` and ``q_norm``/``k_norm`` handle the
    understanding/text token positions; the ``*_moe_gen`` twins handle the
    generation (VAE) token positions. Both kinds of tokens attend to each
    other in a single attention call.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
            self.q_norm_moe_gen = nn.Identity()
            self.k_norm_moe_gen = nn.Identity()

        self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ):
        """Attention over a packed training batch with per-token-type weights.

        Args:
            packed_sequence: Hidden states of all samples concatenated along dim 0.
            sample_lens: Per-sample lengths used to split the packed sequence.
            attention_mask: Either a list with one mask per sample (SDPA path)
                or a block mask handed to ``flex_attention``.
            packed_position_embeddings: ``(cos, sin)`` pair for rotary embedding.
            packed_und_token_indexes: Row indexes handled by the base weights.
            packed_gen_token_indexes: Row indexes handled by the ``*_moe_gen`` weights.

        Returns:
            The output-projected attention result, one row per packed token.
        """
        packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
        packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
        packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))

        packed_sequence_und = packed_sequence[packed_und_token_indexes]
        packed_sequence_gen = packed_sequence[packed_gen_token_indexes]

        packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
        packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)

        packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
        packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)

        packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
        packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)

        packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
        packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
        if self.config.freeze_und:
            # Cut the gradient path through the understanding-token values.
            packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()

        packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
        packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)

        packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
        packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])

        packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
        packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
            packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, list):
            # SDPA path: expand K/V heads to the number of query heads, then
            # attend sample by sample with that sample's own mask.
            packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            upacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                upacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(upacked_attn_output, dim=1)
        else:
            # Flex-attention path: pad up to sum(sample_lens), attend once with
            # the block mask, then drop the padded positions.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
            packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states_.unsqueeze(0),  # 1, num_head, L, head_dim
                packed_key_states_.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
        packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
        packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
        packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])

        return packed_attn_output_

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ):
        """Attention for packed queries against cached plus new keys/values.

        Args:
            packed_query_sequence: Hidden states of the new tokens, packed along dim 0.
            query_lens: Tensor with the number of new tokens per sample.
            packed_query_position_embeddings: ``(cos, sin)`` pair for rotary embedding.
            packed_query_indexes: Rows of the merged K/V buffer that receive the new tokens.
            past_key_values: Cache exposing ``key_cache`` / ``value_cache`` lists
                indexed by layer, or ``None``.
            key_values_lens: Tensor with the cached length per sample.
            packed_key_value_indexes: Rows of the merged K/V buffer that receive the cached entries.
            update_past_key_values: Store the merged K/V back into the cache.
                Ignored when ``past_key_values`` is ``None`` (previously this
                combination raised ``AttributeError``).
            is_causal: Forwarded to ``flash_attn_varlen_func`` as ``causal``.
            mode: ``'und'`` sends every token through the base weights;
                ``'gen'`` routes ``packed_text_indexes`` through the base
                weights and ``packed_vae_token_indexes`` through ``*_moe_gen``.
            packed_vae_token_indexes: Row indexes of VAE tokens (``'gen'`` mode only).
            packed_text_indexes: Row indexes of text tokens (``'gen'`` mode only).

        Returns:
            ``(packed_attn_output, past_key_values)``.

        Raises:
            ValueError: If ``mode`` is neither ``'und'`` nor ``'gen'``
                (previously this surfaced as ``UnboundLocalError``).
        """
        if mode not in ('und', 'gen'):
            raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")

        if mode == 'und':
            packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
            packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_query_states = self.q_norm(packed_query_states)
            packed_key_states = self.k_norm(packed_key_states)
        elif mode == 'gen':
            packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
            packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
            packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
            packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))

            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
            packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)

            packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
            packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)

            packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
            packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)

            packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
            packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)

            # Normalisation results are written back into float32 buffers.
            packed_query_states = packed_query_states.to(torch.float32)
            packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
            packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])

            packed_key_states = packed_key_states.to(torch.float32)
            packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
            packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            # Tensor reductions instead of Python sum(): one kernel rather than
            # an element-by-element loop over the length tensors.
            seqlens = int(query_lens.sum() + key_values_lens.sum())
            merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            # Scatter new and cached entries into their interleaved positions.
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=query_lens.max().item(),
            max_seqlen_k=key_values_lens.max().item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        if mode == 'und':
            packed_attn_output = self.o_proj(packed_attn_output)
        elif mode == 'gen':
            # Output projections are written back in place, per token type.
            packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
            packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])

        # Only write back when there is a cache to write to.
        if update_past_key_values and past_key_values is not None:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
def forward_train( self, packed_sequence: torch.Tensor, sample_lens: List[int], attention_mask, packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], ) -> torch.Tensor: residual = packed_sequence packed_sequence = self.input_layernorm(packed_sequence) # Self Attention packed_sequence = self.self_attn( packed_sequence=packed_sequence, sample_lens=sample_lens, attention_mask=attention_mask, packed_position_embeddings=packed_position_embeddings, ) packed_sequence = residual + packed_sequence # Fully Connected residual = packed_sequence packed_sequence = self.post_attention_layernorm(packed_sequence) packed_sequence = self.mlp(packed_sequence) packed_sequence = residual + packed_sequence return packed_sequence def forward_inference( self, packed_query_sequence: torch.Tensor, query_lens: torch.Tensor, packed_query_position_embeddings: torch.Tensor, packed_query_indexes: torch.Tensor, past_key_values: Optional[NaiveCache] = None, key_values_lens: Optional[torch.Tensor] = None, packed_key_value_indexes: Optional[torch.Tensor] = None, update_past_key_values=True, is_causal=True, ) -> BaseNavitOutputWithPast: residual = packed_query_sequence packed_query_sequence = self.input_layernorm(packed_query_sequence) # Self Attention packed_query_sequence, past_key_values = self.self_attn( packed_query_sequence=packed_query_sequence, query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, key_values_lens=key_values_lens, packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=update_past_key_values, is_causal=is_causal, ) packed_query_sequence = residual + packed_query_sequence # Fully Connected residual = packed_query_sequence packed_query_sequence = self.post_attention_layernorm(packed_query_sequence) packed_query_sequence = self.mlp(packed_query_sequence) packed_query_sequence = residual + packed_query_sequence return packed_query_sequence, 
class Qwen2MoTDecoderLayer(nn.Module):
    """Decoder layer holding two parallel sets of norm/MLP weights.

    The plain attributes (``mlp``, ``input_layernorm``,
    ``post_attention_layernorm``) process the "understanding" / text tokens,
    while the ``*_moe_gen`` twins process the "generation" / VAE tokens.
    Tokens are routed to one set or the other by the index tensors passed to
    the forward methods.
    """

    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
    ):
        """Build the attention module, both MLPs and the four RMS norms.

        Args:
            config: Model config; ``hidden_size``, ``freeze_und`` and
                ``rms_norm_eps`` are read here.
            layer_idx: Index of this layer, handed to the attention module.
            attn_module: Callable used to construct the attention module as
                ``attn_module(config, layer_idx)``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.freeze_und = config.freeze_und

        self.self_attn = attn_module(config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)

        width, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = Qwen2RMSNorm(width, eps=eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(width, eps=eps)
        self.post_attention_layernorm = Qwen2RMSNorm(width, eps=eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(width, eps=eps)

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        handler = self.forward_train if self.training else self.forward_inference
        return handler(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Training pass with per-token routing between the two weight sets.

        Args:
            packed_sequence: Hidden states of the packed samples.
            sample_lens: Per-sample lengths, forwarded to the attention module.
            attention_mask: Mask object forwarded to the attention module.
            packed_position_embeddings: ``(cos, sin)`` rotary tables.
            packed_und_token_indexes: Positions handled by the base weights.
            packed_gen_token_indexes: Positions handled by the ``*_moe_gen`` weights.

        Returns:
            The updated hidden states.
        """
        und_idx = packed_und_token_indexes
        gen_idx = packed_gen_token_indexes

        # Pre-attention norm, scattered back into a zero buffer by token kind.
        normed = packed_sequence.new_zeros(packed_sequence.shape)
        normed[und_idx] = self.input_layernorm(packed_sequence[und_idx])
        normed[gen_idx] = self.input_layernorm_moe_gen(packed_sequence[gen_idx])

        attn_out = self.self_attn(
            packed_sequence=normed,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
            packed_und_token_indexes=und_idx,
            packed_gen_token_indexes=gen_idx,
        )
        if self.freeze_und:
            # Stop gradients flowing through the understanding positions.
            attn_out[und_idx] = attn_out[und_idx].detach()
        hidden = packed_sequence + attn_out

        # Feed-forward, again routed by token kind.
        mlp_out = hidden.new_zeros(hidden.shape)
        mlp_out[und_idx] = self.mlp(self.post_attention_layernorm(hidden[und_idx]))
        if self.freeze_und:
            mlp_out[und_idx] = mlp_out[und_idx].detach()
        mlp_out[gen_idx] = self.mlp_moe_gen(
            self.post_attention_layernorm_moe_gen(hidden[gen_idx])
        )
        return hidden + mlp_out

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Inference pass with an optional key/value cache.

        With ``mode == "und"`` every token goes through the base weights.
        With ``mode == "gen"`` positions in ``packed_text_indexes`` use the
        base weights and positions in ``packed_vae_token_indexes`` use the
        ``*_moe_gen`` weights. Any other ``mode`` value skips both the norms
        and the MLP branch selection (the input is passed through as-is).

        Returns:
            A ``(hidden_states, past_key_values)`` tuple, where
            ``past_key_values`` is whatever the attention module returned.
        """
        hidden = packed_query_sequence
        text_idx = packed_text_indexes
        vae_idx = packed_vae_token_indexes

        attn_in = hidden
        if mode == "und":
            attn_in = self.input_layernorm(hidden)
        elif mode == "gen":
            attn_in = torch.zeros_like(hidden)
            attn_in[text_idx] = self.input_layernorm(hidden[text_idx])
            attn_in[vae_idx] = self.input_layernorm_moe_gen(hidden[vae_idx])

        attn_out, past_key_values = self.self_attn(
            packed_query_sequence=attn_in,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=vae_idx,
            packed_text_indexes=text_idx,
        )
        hidden = hidden + attn_out

        mlp_out = hidden
        if mode == "und":
            mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        elif mode == "gen":
            text_in = self.post_attention_layernorm(hidden[text_idx]).to(torch.bfloat16)
            vae_in = self.post_attention_layernorm_moe_gen(hidden[vae_idx]).to(torch.bfloat16)

            # The scatter buffer is explicitly bfloat16, matching the MLP inputs.
            mlp_out = torch.zeros_like(hidden).to(torch.bfloat16)
            mlp_out[text_idx] = self.mlp(text_in)
            mlp_out[vae_idx] = self.mlp_moe_gen(vae_in)

        return hidden + mlp_out, past_key_values
packed_sequence = self.input_layernorm(packed_sequence) # Self Attention packed_sequence = self.self_attn( packed_sequence=packed_sequence, sample_lens=sample_lens, attention_mask=attention_mask, packed_position_embeddings=packed_position_embeddings, ) packed_sequence = residual + packed_sequence # Fully Connected residual = packed_sequence packed_sequence = self.post_attention_layernorm(packed_sequence) packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape) packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes]) packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes]) packed_sequence_new[packed_und_token_indexes] = packed_sequence_und packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen packed_sequence = residual + packed_sequence_new return packed_sequence def forward_inference( self, packed_query_sequence: torch.Tensor, query_lens: torch.Tensor, packed_query_position_embeddings: torch.Tensor, packed_query_indexes: torch.Tensor, past_key_values: Optional[NaiveCache] = None, key_values_lens: Optional[torch.Tensor] = None, packed_key_value_indexes: Optional[torch.Tensor] = None, update_past_key_values=True, is_causal=True, mode="und", packed_vae_token_indexes=None, packed_text_indexes=None, ) -> BaseNavitOutputWithPast: residual = packed_query_sequence packed_query_sequence = self.input_layernorm(packed_query_sequence) # Self Attention packed_query_sequence, past_key_values = self.self_attn( packed_query_sequence=packed_query_sequence, query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, packed_query_indexes=packed_query_indexes, past_key_values=past_key_values, key_values_lens=key_values_lens, packed_key_value_indexes=packed_key_value_indexes, update_past_key_values=update_past_key_values, is_causal=is_causal, ) packed_query_sequence = residual + packed_query_sequence # Fully Connected residual = packed_query_sequence packed_query_sequence = 
def forward_train(
    self,
    packed_sequence: torch.Tensor,
    sample_lens: List[int],
    attention_mask,
    packed_position_ids: torch.Tensor,
    packed_und_token_indexes: Optional[torch.LongTensor] = None,
    packed_gen_token_indexes: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
    """Run every decoder layer over a packed training sequence and normalise.

    Args:
        packed_sequence: Input hidden states. When ``config.freeze_und`` is
            set, the understanding positions are detached in place.
        sample_lens: Per-sample lengths, forwarded to every layer.
        attention_mask: Mask object forwarded to every layer.
        packed_position_ids: Position ids used to build the rotary tables.
        packed_und_token_indexes: Positions normalised by ``self.norm``;
            required when the model uses MoE/MoT layers.
        packed_gen_token_indexes: Positions normalised by
            ``self.norm_moe_gen``; defaults to an empty index tensor.

    Returns:
        The final normalised hidden states.
    """
    freeze_und = self.config.freeze_und
    if freeze_und:
        packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()

    # Rotary tables are computed once and shared by all decoder layers.
    cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
    layer_kwargs = {
        "sample_lens": sample_lens,
        "attention_mask": attention_mask,
        "packed_position_embeddings": (cos.squeeze(0), sin.squeeze(0)),
    }

    if self.use_moe:
        assert packed_und_token_indexes is not None
        if packed_gen_token_indexes is None:
            # No generation tokens: use an empty index tensor of matching dtype/device.
            packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
        layer_kwargs["packed_und_token_indexes"] = packed_und_token_indexes
        layer_kwargs["packed_gen_token_indexes"] = packed_gen_token_indexes

    hidden = packed_sequence
    for decoder_layer in self.layers:
        hidden = decoder_layer(packed_sequence=hidden, **layer_kwargs)

    if not self.use_moe:
        return self.norm(hidden)

    # Final norm is routed per token kind, mirroring the decoder layers.
    normed = torch.zeros_like(hidden)
    normed[packed_und_token_indexes] = self.norm(hidden[packed_und_token_indexes])
    if freeze_und:
        normed[packed_und_token_indexes] = normed[packed_und_token_indexes].detach()
    normed[packed_gen_token_indexes] = self.norm_moe_gen(hidden[packed_gen_token_indexes])
    return normed
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    """Wrapper that owns the decoder (``self.model``) and the ``lm_head`` projection.

    Both forward paths return the decoder output directly; ``lm_head`` is
    created here but is not applied inside ``forward_train`` or
    ``forward_inference``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Create the decoder and the bias-free output projection from ``config``."""
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def init_moe(self):
        """Copy each base parameter into its ``*_moe_gen`` twin.

        For every parameter whose name contains ``"moe_gen"``, the value of
        the parameter with ``"_moe_gen"`` removed from the name is copied in
        place.

        Raises:
            KeyError: If a ``moe_gen`` parameter has no base counterpart.
        """
        # Build the lookup once: state_dict() walks the whole module tree, so
        # calling it inside the loop repeats that work for every parameter.
        # The returned tensors share storage with the live parameters.
        source_weights = self.state_dict()
        for name, param in self.named_parameters():
            if "moe_gen" not in name:
                continue
            original_name = name.replace("_moe_gen", "")
            param.data.copy_(source_weights[original_name].data)

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the output projection (``lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (``lm_head``)."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Forward all arguments to the decoder and return its output unchanged."""
        outputs = self.model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            packed_position_ids=packed_position_ids,
            attention_mask=attention_mask,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        return outputs

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Forward all arguments to the decoder and return its output unchanged."""
        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        return outputs
SiglipVisionConfig(_SiglipVisionConfig): r""" This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. intermediate_size (`int`, *optional*, defaults to 3072): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. num_channels (`int`, *optional*, defaults to 3): Number of channels in the input images. image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. 
class RotaryEmbedding2D(torch.nn.Module):
    """Precomputed cos/sin rotary tables for a ``max_h`` x ``max_w`` grid.

    Four buffers are registered, each of shape ``(max_h * max_w, dim)`` and
    indexed by the row-major flattened grid position:

    * ``cos_h`` / ``sin_h`` -- angles driven by the row index,
    * ``cos_w`` / ``sin_w`` -- angles driven by the column index.
    """

    def __init__(self, dim, max_h, max_w, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        inv_freq = 1.0 / (base ** exponents)

        # rows[i, j] == i and cols[i, j] == j, both in inv_freq's dtype.
        rows = torch.arange(0, max_h).to(inv_freq.dtype).unsqueeze(1).repeat(1, max_w)
        cols = torch.arange(0, max_w).to(inv_freq.dtype).unsqueeze(0).repeat(max_h, 1)

        # Registration order (cos_h, sin_h, cos_w, sin_w) fixes the state_dict key order.
        for (cos_name, sin_name), grid in (
            (("cos_h", "sin_h"), rows),
            (("cos_w", "sin_w"), cols),
        ):
            cos_table, sin_table = self._forward_one_side(grid, inv_freq)
            self.register_buffer(cos_name, cos_table)
            self.register_buffer(sin_name, sin_table)

    def _forward_one_side(self, grid, inv_freq):
        """Return ``(cos, sin)`` tables for one axis, flattened over the grid."""
        angles = grid.unsqueeze(-1) * inv_freq.view(1, 1, -1)
        # Duplicate the angles along the last dim, then merge the two grid dims.
        table = torch.cat((angles, angles), dim=-1).flatten(0, 1)
        return table.cos(), table.sin()
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    Each input ``t`` is mapped to ``t * cos + rotate_half(t) * sin`` after
    ``cos`` and ``sin`` have been given an extra axis at ``unsqueeze_dim`` so
    that they broadcast against ``q`` and ``k``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; accepted only for signature compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into ``cos`` and ``sin`` before broadcasting. Pick
            the axis of ``q`` / ``k`` that ``cos`` / ``sin`` are missing (for
            example the heads axis).

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(self, config):
        """Build the block from ``config``.

        ``config.hidden_act`` selects the activation from ``ACT2FN``;
        ``config.hidden_size`` and ``config.intermediate_size`` give the
        input/output width and the inner width.
        """
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]

        outer, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(outer, inner)
        self.fc2 = nn.Linear(inner, outer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activate -> project and return the result."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``SiglipEncoderLayer`` modules applied in order."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(SiglipEncoderLayer(config))

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        """Feed ``inputs_embeds`` through every layer and return the last output.

        ``cu_seqlens`` and ``max_seqlen`` are passed positionally to each
        layer; the four rotary tables are passed as keyword arguments.
        """
        rotary_tables = {"cos_h": cos_h, "sin_h": sin_h, "cos_w": cos_w, "sin_w": sin_w}
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, cu_seqlens, max_seqlen, **rotary_tables)
        return states
packed_flattened_position_ids: torch.LongTensor, cu_seqlens: torch.IntTensor, max_seqlen: int, ) -> torch.Tensor: hidden_states = self.embeddings( packed_pixel_values=packed_pixel_values, packed_flattened_position_ids=packed_flattened_position_ids ) extra_inputs = {} if self.config.rope: extra_inputs.update( cos_h = self.rope.cos_h[packed_flattened_position_ids], sin_h = self.rope.sin_h[packed_flattened_position_ids], cos_w = self.rope.cos_w[packed_flattened_position_ids], sin_w = self.rope.sin_w[packed_flattened_position_ids] ) last_hidden_state = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **extra_inputs ) last_hidden_state = self.post_layernorm(last_hidden_state) return last_hidden_state class SiglipVisionModel(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "packed_pixel_values" def __init__(self, config: SiglipVisionConfig): super().__init__(config) self.vision_model = SiglipVisionTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding def forward( self, packed_pixel_values: torch.Tensor, packed_flattened_position_ids: torch.LongTensor, cu_seqlens: torch.IntTensor, max_seqlen: int, ) -> torch.Tensor: return self.vision_model( packed_pixel_values=packed_pixel_values, packed_flattened_position_ids=packed_flattened_position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) class MaxLongEdgeMinShortEdgeResize(torch.nn.Module): """Resize the input image so that its longest side and shortest side are within a specified range, ensuring that both sides are divisible by a specified stride. Args: max_size (int): Maximum size for the longest edge of the image. min_size (int): Minimum size for the shortest edge of the image. stride (int): Value by which the height and width of the image must be divisible. max_pixels (int): Maximum pixels for the full image. 
interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported. The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted. antialias (bool, optional): Whether to apply antialiasing (default is True). """ def __init__( self, max_size: int, min_size: int, stride: int, max_pixels: int, interpolation=InterpolationMode.BICUBIC, antialias=True ): super().__init__() self.max_size = max_size self.min_size = min_size self.stride = stride self.max_pixels = max_pixels self.interpolation = interpolation self.antialias = antialias def _make_divisible(self, value, stride): """Ensure the value is divisible by the stride.""" return max(stride, int(round(value / stride) * stride)) def _apply_scale(self, width, height, scale): new_width = round(width * scale) new_height = round(height * scale) new_width = self._make_divisible(new_width, self.stride) new_height = self._make_divisible(new_height, self.stride) return new_width, new_height def forward(self, img, img_num=1): """ Args: img (PIL Image): Image to be resized. img_num (int): Number of images, used to change max_tokens. Returns: PIL Image or Tensor: Rescaled image with divisible dimensions. 
def decolorization(image):
    """Return a grayscale version of *image*.

    Args:
        image: PIL image.

    Returns:
        For an ``'RGB'`` input, an ``'RGB'`` image whose three bands are all
        the luminance channel, so the mode is preserved. For every other mode
        (including ``'L'``) the single-band ``'L'`` conversion is returned.
    """
    gray_image = image.convert('L')
    # Only RGB has three bands to fill. The previous code also tried to merge
    # three bands into mode 'L', which is a single-band mode and is rejected
    # by Image.merge.
    if image.mode == 'RGB':
        return Image.merge('RGB', [gray_image] * 3)
    return gray_image
def motion_blur_opencv(image, kernel_size=15, angle=0):
    """Apply a linear motion blur to *image* using OpenCV.

    Args:
        image: Input image; it is converted with ``np.array`` (2-D arrays are
            filtered directly, otherwise each channel of the last axis is
            filtered separately).
        kernel_size: Side length of the square blur kernel.
        angle: Rotation passed to ``cv2.getRotationMatrix2D`` for the kernel.

    Returns:
        A PIL image built from the blurred array cast to ``uint8``.
    """
    # Linear kernel: a single row of ones through the middle of the square.
    line_kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    line_kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)

    # Rotate the kernel about its geometric centre.
    mid = kernel_size / 2 - 0.5
    rotation = cv2.getRotationMatrix2D((mid, mid), angle, 1)
    line_kernel = cv2.warpAffine(line_kernel, rotation, (kernel_size, kernel_size))

    # Normalise the kernel so it sums to one (skipped if the sum is zero).
    total = line_kernel.sum()
    line_kernel /= total if total != 0 else 1

    def _filter(plane):
        return cv2.filter2D(plane, -1, line_kernel, borderType=cv2.BORDER_REFLECT)

    pixels = np.array(image)
    if pixels.ndim == 2:
        return Image.fromarray(_filter(pixels).astype(np.uint8))

    # Colour image: convolve every channel independently.
    result = np.zeros_like(pixels)
    for channel in range(pixels.shape[2]):
        result[..., channel] = _filter(pixels[..., channel])
    return Image.fromarray(result.astype(np.uint8))
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
    """Blank out a random subset of grid patches, for inpainting-style tasks.

    The image is divided into an ``h_splits x w_splits`` grid. Sizes need not
    divide evenly: the last row/column absorbs the remainder. A random
    selection of cells is replaced by a solid colour; the others are copied
    unchanged.

    Args:
        image: Input PIL image.
        num_splits: ``(h_splits, w_splits)`` pair giving the number of grid
            rows and grid columns.
        blank_ratio: Fraction of patches to blank. The resulting count is
            ``int(total_patches * blank_ratio)`` clamped to
            ``[0, total_patches]``.
        blank_color: RGB colour used for blanked patches.

    Returns:
        A new ``'RGB'`` PIL image with the same size as the input.
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    # Row heights / column widths; the last one takes whatever is left over.
    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))

    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    total_patches = h_splits * w_splits
    num_blank = int(total_patches * blank_ratio)
    num_blank = max(0, min(num_blank, total_patches))
    # A set gives O(1) membership tests in the loop below; the sampling call
    # itself is unchanged, so the random stream is consumed as before.
    blank_indices = set(random.sample(range(total_patches), num_blank))

    # Result image has the same size as the original.
    result_image = Image.new("RGB", (img_w, img_h))

    patch_idx = 0
    current_y = 0
    for patch_h in patch_heights:
        current_x = 0
        for patch_w in patch_widths:
            if patch_idx in blank_indices:
                patch = Image.new("RGB", (patch_w, patch_h), color=blank_color)
            else:
                patch = image.crop(
                    (current_x, current_y, current_x + patch_w, current_y + patch_h)
                )
            # Paste the (possibly blanked) patch back at its original position.
            result_image.paste(patch, (current_x, current_y))
            current_x += patch_w
            patch_idx += 1
        current_y += patch_h

    return result_image
class Downsample(nn.Module):
    """Stride-2 3x3 convolution preceded by one-sided zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # torch convs cannot pad asymmetrically, so the conv itself is
        # unpadded and the padding is applied by hand in forward().
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        """Zero-pad the right and bottom edges by one, then apply the conv."""
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Decoder(nn.Module):
    """Convolutional decoder: ``conv_in`` -> middle blocks -> upsampling stages -> ``conv_out``.

    Submodule names (``conv_in``, ``mid``, ``up``, ``norm_out``, ``conv_out``)
    and their construction order are kept stable so parameter names and
    initialisation order do not change.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        num_levels = len(ch_mult)
        self.ch = ch
        self.num_resolutions = num_levels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (num_levels - 1)

        # Channel count and spatial size at the lowest resolution.
        width = ch * ch_mult[num_levels - 1]
        lowest_res = resolution // 2 ** (num_levels - 1)
        self.z_shape = (1, z_channels, lowest_res, lowest_res)

        # Latent -> feature channels.
        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1)

        # Middle: resnet -> attention -> resnet.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width)

        # Upsampling stages are built from the lowest resolution upwards and
        # stored so that ``self.up[i]`` corresponds to ``ch_mult[i]``.
        stages = []
        for level in range(num_levels - 1, -1, -1):
            stage_width = ch * ch_mult[level]
            # Only the first block of a stage changes the channel count.
            block_inputs = [width] + [stage_width] * num_res_blocks
            stage = nn.Module()
            stage.block = nn.ModuleList(
                [ResnetBlock(in_channels=c_in, out_channels=stage_width) for c_in in block_inputs]
            )
            stage.attn = nn.ModuleList()
            width = stage_width
            if level != 0:
                stage.upsample = Upsample(width)
            stages.append(stage)
        self.up = nn.ModuleList(stages[::-1])

        # Output head.
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=width, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent *z* by running it through all stages in order."""
        h = self.conv_in(z)

        # Middle.
        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = layer(h)

        # Upsampling, from the last stage down to stage 0.
        blocks_per_stage = self.num_res_blocks + 1
        for level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[level]
            for idx in range(blocks_per_stage):
                h = stage.block[idx](h)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
            if level != 0:
                h = stage.upsample(h)

        # Output head.
        return self.conv_out(swish(self.norm_out(h)))
def load_ae(local_path: str | None) -> tuple[AutoEncoder, AutoEncoderParams]:
    """Build the autoencoder with fixed hyper-parameters and optionally load weights.

    Args:
        local_path: Path to a safetensors checkpoint. When ``None`` the
            autoencoder is returned with its freshly constructed weights.

    Returns:
        A ``(autoencoder, params)`` tuple: the ``AutoEncoder`` instance and
        the ``AutoEncoderParams`` it was built from.
    """
    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        downsample=8,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    # Loading the autoencoder
    ae = AutoEncoder(ae_params)

    if local_path is not None:
        sd = load_sft(local_path)
        # Non-strict load: mismatched keys are reported, not raised.
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae, ae_params
class InterleaveInferencer:
    """Run interleaved text/image inference on top of a multimodal model.

    A *generation context* is a plain dict with the keys ``'kv_lens'``,
    ``'ropes'`` and ``'past_key_values'``. The ``update_context_*`` methods
    extend such a context with a prompt or an image, and ``gen_text`` /
    ``gen_image`` generate from it. Only single-sample inference is supported:
    every prompt or image is wrapped in a one-element list.
    """

    def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
        self.model = model
        self.vae_model = vae_model
        self.tokenizer = tokenizer
        self.vae_transform = vae_transform
        self.vit_transform = vit_transform
        self.new_token_ids = new_token_ids

    def _to_device(self, d, device):
        """Move every top-level tensor value of dict *d* to *device*, in place.

        Non-tensor values, and tensors nested inside other containers, are
        left untouched. Returns *d* itself.
        """
        for k, v in d.items():
            if torch.is_tensor(v):
                d[k] = v.to(device)
        return d

    def to(self, device):
        """Move the model and the VAE to *device* and return ``self``."""
        self.model = self.model.to(device)
        self.vae_model = self.vae_model.to(device)
        return self

    def init_gen_context(self):
        """Return a fresh, empty generation context."""
        gen_context = {
            'kv_lens': [0],
            'ropes': [0],
            'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
        }
        return gen_context

    @torch.no_grad()
    def update_context_text(self, text, gen_context):
        """Append the prompt *text* to *gen_context*.

        *gen_context* is updated in place and also returned.
        """
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']

        generation_input, kv_lens, ropes = self.model.prepare_prompts(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            prompts=[text],
            tokenizer=self.tokenizer,
            new_token_ids=self.new_token_ids,
        )
        device = next(self.model.parameters()).device
        generation_input = self._to_device(generation_input, device)
        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)

        gen_context['kv_lens'] = kv_lens
        gen_context['ropes'] = ropes
        gen_context['past_key_values'] = past_key_values
        return gen_context

    @torch.no_grad()
    def update_context_image(self, image, gen_context, vae=True, vit=True):
        """Append *image* to *gen_context* through the VAE and/or ViT paths.

        At least one of *vae* / *vit* must be true. When both are set the VAE
        update runs first and the ViT update continues from its result.
        *gen_context* is updated in place and also returned.
        """
        assert vae or vit
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        device = next(self.model.parameters()).device

        if vae:
            # VAE path
            generation_input, kv_lens, ropes = self.model.prepare_vae_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vae_transform,
                new_token_ids=self.new_token_ids,
            )
            generation_input = self._to_device(generation_input, device)
            past_key_values = self.model.forward_cache_update_vae(
                self.vae_model, past_key_values, **generation_input
            )

        if vit:
            # ViT path
            generation_input, kv_lens, ropes = self.model.prepare_vit_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vit_transform,
                new_token_ids=self.new_token_ids,
            )
            generation_input = self._to_device(generation_input, device)
            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)

        gen_context['kv_lens'] = kv_lens
        gen_context['ropes'] = ropes
        gen_context['past_key_values'] = past_key_values
        return gen_context

    @torch.no_grad()
    def gen_image(
        self,
        image_shape,
        gen_context,
        cfg_text_scale=4.0,
        cfg_img_scale=1.5,
        cfg_text_precontext=None,
        cfg_img_precontext=None,
        cfg_interval=(0.4, 1.0),
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        num_timesteps=50,
        timestep_shift=3.0,
    ):
        """Generate one image of size *image_shape* ``(H, W)`` from *gen_context*.

        Args:
            image_shape: ``(H, W)`` of the image to generate.
            gen_context: Full conditioning context.
            cfg_text_precontext: Context used for the text classifier-free
                guidance branch. ``None`` falls back to an empty context from
                ``init_gen_context()``.
            cfg_img_precontext: Context used for the image classifier-free
                guidance branch. ``None`` falls back to an empty context.
            cfg_text_scale, cfg_img_scale, cfg_interval, cfg_renorm_min,
            cfg_renorm_type, num_timesteps, timestep_shift: Forwarded
                unchanged to ``self.model.generate_image``.

        Returns:
            The decoded PIL image.
        """
        # The None defaults used to be subscripted directly, which raised
        # TypeError. An empty context is the same baseline that
        # interleave_inference passes for a leading text prompt.
        if cfg_text_precontext is None:
            cfg_text_precontext = self.init_gen_context()
        if cfg_img_precontext is None:
            cfg_img_precontext = self.init_gen_context()

        device = next(self.model.parameters()).device

        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        generation_input = self.model.prepare_vae_latent(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            image_sizes=[image_shape],
            new_token_ids=self.new_token_ids,
        )
        generation_input = self._to_device(generation_input, device)

        # text cfg
        cfg_text_past_key_values = cfg_text_precontext['past_key_values']
        generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
            curr_kvlens=cfg_text_precontext['kv_lens'],
            curr_rope=cfg_text_precontext['ropes'],
            image_sizes=[image_shape],
        )
        generation_input_cfg_text = self._to_device(generation_input_cfg_text, device)

        # img cfg
        cfg_img_past_key_values = cfg_img_precontext['past_key_values']
        generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
            curr_kvlens=cfg_img_precontext['kv_lens'],
            curr_rope=cfg_img_precontext['ropes'],
            image_sizes=[image_shape],
        )
        generation_input_cfg_img = self._to_device(generation_input_cfg_img, device)

        unpacked_latent = self.model.generate_image(
            past_key_values=past_key_values,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_img_past_key_values=cfg_img_past_key_values,
            num_timesteps=num_timesteps,
            cfg_text_scale=cfg_text_scale,
            cfg_img_scale=cfg_img_scale,
            cfg_interval=cfg_interval,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            timestep_shift=timestep_shift,
            **generation_input,
            cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
            cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
            cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
            cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
            cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
            cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
            cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
            cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
        )

        image = self.decode_image(unpacked_latent[0], image_shape)
        return image

    def decode_image(self, latent, image_shape):
        """Un-patchify *latent*, decode it with the VAE and return a PIL image.

        The latent is reshaped to ``(1, h, w, p, p, c)`` with
        ``h, w = H // latent_downsample, W // latent_downsample``, rearranged
        to ``(1, c, h * p, w * p)`` and decoded. The decoder output is mapped
        with ``x * 0.5 + 0.5``, clamped to ``[0, 1]`` and scaled to ``uint8``.
        """
        H, W = image_shape
        h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
        patch = self.model.latent_patch_size
        channels = self.model.latent_channel

        latent = latent.reshape(1, h, w, patch, patch, channels)
        latent = torch.einsum("nhwpqc->nchpwq", latent)
        latent = latent.reshape(1, channels, h * patch, w * patch)
        image = self.vae_model.decode(latent)
        image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
        image = Image.fromarray((image).to(torch.uint8).cpu().numpy())

        return image

    @torch.no_grad()
    def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
        """Generate text from a deep copy of *gen_context* and return the decoded string.

        The caller's context is not modified. The decoded output is cut at the
        first ``<|im_end|>``. If a ``<|im_start|>`` marker is present, only the
        text following the first one is returned; otherwise the truncated text
        is returned as is.
        """
        gen_context = deepcopy(gen_context)
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']

        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
        # Match the other entry points: prepared tensors go to the model's device.
        device = next(self.model.parameters()).device
        generation_input = self._to_device(generation_input, device)
        unpacked_latent = self.model.generate_text(
            past_key_values=past_key_values,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=self.new_token_ids['eos_token_id'],
            **generation_input,
        )
        output = self.tokenizer.decode(unpacked_latent[:, 0])
        output = output.split('<|im_end|>')[0]
        # Indexing [1] unconditionally raised IndexError when the start marker
        # was missing from the decoded text.
        pieces = output.split('<|im_start|>')
        return pieces[1] if len(pieces) > 1 else output

    @torch.no_grad()
    def interleave_inference(
        self,
        input_lists: List[Union[str, Image.Image]],
        think=False,
        understanding_output=False,
        max_think_token_n=1000,
        do_sample=False,
        text_temperature=0.3,
        cfg_text_scale=3.0,
        cfg_img_scale=1.5,
        cfg_interval=(0.4, 1.0),
        timestep_shift=3.0,
        num_timesteps=50,
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        image_shapes=(1024, 1024),
    ) -> List[Union[str, Image.Image]]:
        """Consume a mixed list of prompts/images and generate text or an image.

        Three contexts are maintained while reading *input_lists*:

        * ``gen_context`` receives every input.
        * ``cfg_text_context`` is a snapshot of ``gen_context`` taken just
          before each text input and just after each image input.
        * ``cfg_img_context`` receives the text inputs only.

        Args:
            input_lists: Sequence of ``str`` and PIL images, processed in order.
            think: Prepend a "think" system prompt. In image mode, also
                generate a planning text before the image.
            understanding_output: If true, output text only and skip the VAE
                path for input images. Otherwise output an image.
            max_think_token_n, do_sample, text_temperature: Passed to
                ``gen_text``.
            image_shapes: ``(H, W)`` of the generated image. Overwritten by the
                resized size of the most recent input image, if any.
            cfg_text_scale, cfg_img_scale, cfg_interval, timestep_shift,
            num_timesteps, cfg_renorm_min, cfg_renorm_type: Passed to
                ``gen_image``.

        Returns:
            A list holding the generated text and/or image, in generation order.

        Raises:
            ValueError: If an element of *input_lists* is neither ``str`` nor
                a PIL image.
        """
        output_list = []
        gen_context = self.init_gen_context()
        cfg_text_context = deepcopy(gen_context)
        cfg_img_context = deepcopy(gen_context)

        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            if think:
                if understanding_output:
                    system_prompt = VLM_THINK_SYSTEM_PROMPT
                else:
                    system_prompt = GEN_THINK_SYSTEM_PROMPT
                gen_context = self.update_context_text(system_prompt, gen_context)
                cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)

            for input_term in input_lists:
                if isinstance(input_term, str):
                    # Snapshot before adding the text: this is the "text dropped" baseline.
                    cfg_text_context = deepcopy(gen_context)
                    gen_context = self.update_context_text(input_term, gen_context)
                    cfg_img_context = self.update_context_text(input_term, cfg_img_context)

                elif isinstance(input_term, Image.Image):
                    input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
                    gen_context = self.update_context_image(
                        input_term, gen_context, vae=not understanding_output
                    )
                    # PIL size is (W, H); generation expects (H, W).
                    image_shapes = input_term.size[::-1]
                    cfg_text_context = deepcopy(gen_context)

                else:
                    raise ValueError(f"Unsupported input type: {type(input_term)}")

            if understanding_output:
                gen_text = self.gen_text(
                    gen_context,
                    do_sample=do_sample,
                    temperature=text_temperature,
                    max_length=max_think_token_n,
                )
                output_list.append(gen_text)

            else:
                if think:
                    gen_text = self.gen_text(
                        gen_context,
                        do_sample=do_sample,
                        temperature=text_temperature,
                        max_length=max_think_token_n,
                    )
                    gen_context = self.update_context_text(gen_text, gen_context)
                    output_list.append(gen_text)

                img = self.gen_image(
                    image_shapes,
                    gen_context,
                    cfg_text_precontext=cfg_text_context,
                    cfg_img_precontext=cfg_img_context,
                    cfg_text_scale=cfg_text_scale,
                    cfg_img_scale=cfg_img_scale,
                    cfg_interval=cfg_interval,
                    timestep_shift=timestep_shift,
                    num_timesteps=num_timesteps,
                    cfg_renorm_min=cfg_renorm_min,
                    cfg_renorm_type=cfg_renorm_type,
                )

                output_list.append(img)

        return output_list

    def __call__(
        self,
        image: Optional[Image.Image] = None,
        text: Optional[str] = None,
        **kargs
    ) -> Dict[str, Any]:
        """Run ``interleave_inference`` on an optional image followed by an optional text.

        Returns a dict with keys ``'image'`` and ``'text'``, each holding the
        last generated item of that kind or ``None``. If neither input is
        given, a message is printed and both values are ``None``.
        """
        output_dict = {'image': None, 'text': None}

        if image is None and text is None:
            print('Please provide at least one input: either an image or text.')
            return output_dict

        input_list = []
        if image is not None:
            input_list.append(image)
        if text is not None:
            input_list.append(text)

        output_list = self.interleave_inference(input_list, **kargs)

        for i in output_list:
            if isinstance(i, Image.Image):
                output_dict['image'] = i
            elif isinstance(i, str):
                output_dict['text'] = i
        return output_dict
your notebook exactly. # """ # model_cpu_offload_seq = "bagel_model" # def __init__( # self, # torch_dtype: torch.dtype = torch.bfloat16, # ): # super().__init__() # self._dtype = torch_dtype # self._built = False # self._inferencer = None # self.new_token_ids: List[int] = [] # # Hard‐code default weights path; overridden by from_pretrained # self.weights_root: Optional[str] = None # self.register_to_config(weights_root=self.weights_root, torch_dtype=torch_dtype) # repo_id = "ByteDance-Seed/BAGEL-7B-MoT" # model_path = snapshot_download(repo_id=repo_id) # print("loaded from ", model_path) # # LLM config preparing # llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json")) # llm_config.qk_norm = True # llm_config.tie_word_embeddings = False # llm_config.layer_module = "Qwen2MoTDecoderLayer" # # ViT config preparing # vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json")) # vit_config.rope = False # vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1 # # VAE loading # vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors")) # # Bagel config preparing # config = BagelConfig( # visual_gen=True, # visual_und=True, # llm_config=llm_config, # vit_config=vit_config, # vae_config=vae_config, # vit_max_num_patch_per_side=70, # connector_act='gelu_pytorch_tanh', # latent_patch_size=2, # max_latent_size=64, # ) # with init_empty_weights(): # language_model = Qwen2ForCausalLM(llm_config) # vit_model = SiglipVisionModel(vit_config) # model = Bagel(language_model, vit_model, config) # model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True) # # Tokenizer Preparing # tokenizer = Qwen2Tokenizer.from_pretrained(model_path) # tokenizer, new_token_ids, _ = add_special_tokens(tokenizer) # # Image Transform Preparing # vae_transform = ImageTransform(1024, 512, 16) # vit_transform = ImageTransform(980, 224, 14) # # set cuda device to 4 # max_mem_per_gpu 
= "40GiB" # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU. # device_map = infer_auto_device_map( # model, # max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())}, # no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"], # ) # print(device_map) # same_device_modules = [ # 'language_model.model.embed_tokens', # 'time_embedder', # 'latent_pos_embed', # 'vae2llm', # 'llm2vae', # 'connector', # 'vit_pos_embed' # ] # if torch.cuda.device_count() == 1: # first_device = device_map.get(same_device_modules[0], "cuda:0") # for k in same_device_modules: # if k in device_map: # device_map[k] = first_device # else: # device_map[k] = "cuda:0" # else: # first_device = device_map.get(same_device_modules[0]) # for k in same_device_modules: # if k in device_map: # device_map[k] = first_device # # Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8 # model = load_checkpoint_and_dispatch( # model, # checkpoint=os.path.join(model_path, "ema.safetensors"), # device_map=device_map, # offload_buffers=True, # dtype=torch.bfloat16, # force_hooks=True, # offload_folder="/tmp/offload" # ) # model = model.eval() # print('Model loaded') # self._inferencer = InterleaveInferencer( # model=model, # vae_model=vae_model, # tokenizer=tokenizer, # vae_transform=vae_transform, # vit_transform=vit_transform, # new_token_ids=new_token_ids # ) # seed = 42 # random.seed(seed) # np.random.seed(seed) # torch.manual_seed(seed) # if torch.cuda.is_available(): # torch.cuda.manual_seed(seed) # torch.cuda.manual_seed_all(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False # @torch.no_grad() # def __call__( # self, # prompt: str, # think=False, # cfg_text_scale: float = 4.0, # cfg_img_scale: float = 1.0, # cfg_interval=(0.4, 1.0), # timestep_shift: float = 3.0, # num_timesteps: int = 50, # cfg_renorm_min: float = 0.0, # cfg_renorm_type: str = "global", # seed: Optional[int] = None, # 
class BagelPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapper around ``InterleaveInferencer``.

    The Bagel model, the VAE and the tokenizer are registered as pipeline
    modules. Inference is delegated to an internal ``InterleaveInferencer``
    that holds references to the same model and VAE objects.
    """

    model_cpu_offload_seq = "bagel_model"

    def __init__(self, bagel_model, vae, tokenizer):
        super().__init__()
        self.register_modules(
            bagel_model=bagel_model,
            vae=vae,
            tokenizer=tokenizer,
        )

        tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

        self._inferencer = InterleaveInferencer(
            model=bagel_model,
            vae_model=vae,
            tokenizer=tokenizer,
            vae_transform=ImageTransform(1024, 512, 16),
            vit_transform=ImageTransform(980, 224, 14),
            new_token_ids=new_token_ids,
        )

    def __call__(self, prompt: str, **infer_kwargs):
        """Generate from *prompt* and return ``{"images": [image]}``.

        Extra keyword arguments are forwarded to the inferencer.
        """
        result = self._inferencer(text=prompt, **infer_kwargs)
        img = result["image"] if isinstance(result, dict) else result
        return {"images": [img]}

    def to(self, device=None, *args, **kwargs):
        """Move the pipeline; accepts everything ``DiffusionPipeline.to`` accepts.

        The previous ``to(self, device)`` signature rejected calls that pass a
        dtype or keyword options such as ``silence_dtype_warnings=True``.
        *device* stays the first positional parameter, so existing
        ``pipe.to("cuda")`` and ``pipe.to(device="cuda")`` calls reach the
        parent exactly as before.
        """
        positional = args if device is None else (device, *args)
        super().to(*positional, **kwargs)  # moves registered modules
        if hasattr(self, "_inferencer"):
            # nn.Module.to() works in place, so the registered modules are the
            # same objects the inferencer uses; re-point to keep them in sync.
            self._inferencer.model = self.bagel_model
            self._inferencer.vae_model = self.vae
        return self