pcuenq and reach-vb committed
Commit 139ff37 · verified · 1 parent: d446e1a

Add a remote code file for transformers integration 🤗 (#2)


- Add a remote code file for transformers integration 🤗 (e5c6dc1a07b977d538d27829470a7a62e98f1e39)
- Update config.json (fbe8feb9086c566d4d867103a646a966f8922d22)
- Update README.md (8e093cf658178a113a2dc8396a5213a552d59475)
- Update README.md (69d7e3fdd67b10df146686be17bcf661bbc500e8)


Co-authored-by: Vaibhav Srivastav <[email protected]>

Files changed (3)
  1. README.md +58 -1
  2. config.json +4 -0
  3. llava_qwen.py +2195 -0
README.md CHANGED
@@ -3,6 +3,8 @@ license: apple-amlr
  license_name: apple-ascl
  license_link: https://github.com/apple/ml-fastvlm/blob/main/LICENSE_MODEL
  library_name: ml-fastvlm
+ tags:
+ - transformers
  ---
  # FastVLM: Efficient Vision Encoding for Vision Language Models
 
@@ -51,6 +53,61 @@ python predict.py --model-path /path/to/checkpoint-dir \
    --prompt "Describe the image."
  ```
 
+ ### Run inference with Transformers (Remote Code)
+ To run inference with Transformers, load the model with `trust_remote_code` enabled and use the following snippet:
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ MID = "apple/FastVLM-0.5B"
+ IMAGE_TOKEN_INDEX = -200  # special image token id the model code looks for
+
+ # Load
+ tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     MID,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+
+ # Build the chat and render it to a string (not tokens) so we can place <image> exactly
+ messages = [
+     {"role": "user", "content": "<image>\nDescribe this image in detail."}
+ ]
+ rendered = tok.apply_chat_template(
+     messages, add_generation_prompt=True, tokenize=False
+ )
+
+ pre, post = rendered.split("<image>", 1)
+
+ # Tokenize the text *around* the image token (no extra special tokens)
+ pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
+ post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
+
+ # Splice in the IMAGE token id (-200) at the placeholder position
+ img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
+ input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
+ attention_mask = torch.ones_like(input_ids, device=model.device)
+
+ # Preprocess the image with the model's own processor
+ img = Image.open("test-2.jpg").convert("RGB")
+ px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
+ px = px.to(model.device, dtype=model.dtype)
+
+ # Generate
+ with torch.no_grad():
+     out = model.generate(
+         inputs=input_ids,
+         attention_mask=attention_mask,
+         images=px,
+         max_new_tokens=128,
+     )
+
+ print(tok.decode(out[0], skip_special_tokens=True))
+ ```
 
  ## Citation
  If you found this model useful, please cite the following paper:
@@ -62,4 +119,4 @@ If you found this model useful, please cite the following paper:
    month = {June},
    year = {2025},
  }
- ```
+ ```
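A note on the preprocessing step in the snippet above: the remote code builds its own `CLIPImageProcessor` inside `MobileCLIPVisionTower.load_model` (see `llava_qwen.py` below), with an identity normalization (mean 0, std 1, i.e. pixels are only rescaled) and the vision tower's input resolution. Images should therefore go through `model.get_vision_tower().image_processor`, as the snippet does, rather than a generic CLIP processor. A quick way to inspect it, assuming the `model` object from the snippet above:

```python
# Sketch: inspect the image processor created by the remote code.
proc = model.get_vision_tower().image_processor

print(type(proc).__name__)              # CLIPImageProcessor
print(proc.image_mean, proc.image_std)  # [0.0, 0.0, 0.0] [1.0, 1.0, 1.0] -> no ImageNet normalization
print(proc.crop_size)                   # vision tower input resolution, e.g. {"height": 1024, "width": 1024}
```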
config.json CHANGED
@@ -3,6 +3,10 @@
    "architectures": [
      "LlavaQwen2ForCausalLM"
    ],
+   "auto_map": {
+     "AutoConfig": "llava_qwen.LlavaConfig",
+     "AutoModelForCausalLM": "llava_qwen.LlavaQwen2ForCausalLM"
+   },
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "eos_token_id": 151645,
llava_qwen.py ADDED
@@ -0,0 +1,2195 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import re
19
+ import copy
20
+ from timm.models import create_model
21
+ from abc import ABC, abstractmethod
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch import Tensor
26
+ import torch.nn.functional as F
27
+ from torch.nn.init import normal_
28
+
29
+ from transformers import CLIPImageProcessor
30
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
31
+
32
+ from transformers.modeling_outputs import CausalLMOutputWithPast
33
+ from transformers.generation.utils import GenerateOutput
34
+
35
+ from functools import partial
36
+ from typing import List, Tuple, Optional, Union, Dict, Any
37
+
38
+ from timm.models import register_model
39
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
40
+ from timm.layers import DropPath, SqueezeExcite
41
+
42
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
43
+ WORKER_HEART_BEAT_INTERVAL = 15
44
+ LOGDIR = "."
45
+ # Model Constants
46
+ IGNORE_INDEX = -100
47
+ IMAGE_TOKEN_INDEX = -200
48
+ DEFAULT_IMAGE_TOKEN = "<image>"
49
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
50
+ DEFAULT_IM_START_TOKEN = "<im_start>"
51
+ DEFAULT_IM_END_TOKEN = "<im_end>"
52
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
53
+
54
+ class LlavaConfig(Qwen2Config):
55
+ model_type = "llava_qwen2"
56
+
57
+ def _cfg(url="", **kwargs):
58
+ return {
59
+ "url": url,
60
+ "num_classes": 1000,
61
+ "input_size": (3, 256, 256),
62
+ "pool_size": None,
63
+ "crop_pct": 0.95,
64
+ "interpolation": "bicubic",
65
+ "mean": IMAGENET_DEFAULT_MEAN,
66
+ "std": IMAGENET_DEFAULT_STD,
67
+ "classifier": "head",
68
+ **kwargs,
69
+ }
70
+
71
+
72
+ default_cfgs = {
73
+ "fastvit_t": _cfg(crop_pct=0.9),
74
+ "fastvit_s": _cfg(crop_pct=0.9),
75
+ "fastvit_m": _cfg(crop_pct=0.95),
76
+ }
77
+
78
+
79
+ class SEBlock(nn.Module):
80
+ """Squeeze and Excite module.
81
+ Pytorch implementation of `Squeeze-and-Excitation Networks` -
82
+ https://arxiv.org/pdf/1709.01507.pdf
83
+ """
84
+
85
+ def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
86
+ """Construct a Squeeze and Excite Module.
87
+ Args:
88
+ in_channels: Number of input channels.
89
+ rd_ratio: Input channel reduction ratio.
90
+ """
91
+ super(SEBlock, self).__init__()
92
+ self.reduce = nn.Conv2d(
93
+ in_channels=in_channels,
94
+ out_channels=int(in_channels * rd_ratio),
95
+ kernel_size=1,
96
+ stride=1,
97
+ bias=True,
98
+ )
99
+ self.expand = nn.Conv2d(
100
+ in_channels=int(in_channels * rd_ratio),
101
+ out_channels=in_channels,
102
+ kernel_size=1,
103
+ stride=1,
104
+ bias=True,
105
+ )
106
+
107
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
108
+ """Apply forward pass."""
109
+ b, c, h, w = inputs.size()
110
+ # x = F.avg_pool2d(inputs, kernel_size=[h, w])
111
+ x = F.avg_pool2d(inputs, kernel_size=[16, 16])
112
+ x = self.reduce(x)
113
+ x = F.relu(x)
114
+ x = self.expand(x)
115
+ x = torch.sigmoid(x)
116
+ x = x.view(-1, c, 1, 1)
117
+ return inputs * x
118
+
119
+
120
+ class MobileOneBlock(nn.Module):
121
+ """MobileOne building block.
122
+ This block has a multi-branched architecture at train-time
123
+ and plain-CNN style architecture at inference time
124
+ For more details, please refer to our paper:
125
+ `An Improved One millisecond Mobile Backbone` -
126
+ https://arxiv.org/pdf/2206.04040.pdf
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ in_channels: int,
132
+ out_channels: int,
133
+ kernel_size: int,
134
+ stride: int = 1,
135
+ padding: int = 0,
136
+ dilation: int = 1,
137
+ groups: int = 1,
138
+ inference_mode: bool = False,
139
+ use_se: bool = False,
140
+ use_act: bool = True,
141
+ use_scale_branch: bool = True,
142
+ num_conv_branches: int = 1,
143
+ activation: nn.Module = nn.GELU(),
144
+ ) -> None:
145
+ """Construct a MobileOneBlock module.
146
+ Args:
147
+ in_channels: Number of channels in the input.
148
+ out_channels: Number of channels produced by the block.
149
+ kernel_size: Size of the convolution kernel.
150
+ stride: Stride size.
151
+ padding: Zero-padding size.
152
+ dilation: Kernel dilation factor.
153
+ groups: Group number.
154
+ inference_mode: If True, instantiates model in inference mode.
155
+ use_se: Whether to use SE-ReLU activations.
156
+ use_act: Whether to use activation. Default: ``True``
157
+ use_scale_branch: Whether to use scale branch. Default: ``True``
158
+ num_conv_branches: Number of linear conv branches.
159
+ """
160
+ super(MobileOneBlock, self).__init__()
161
+ self.inference_mode = inference_mode
162
+ self.groups = groups
163
+ self.stride = stride
164
+ self.padding = padding
165
+ self.dilation = dilation
166
+ self.kernel_size = kernel_size
167
+ self.in_channels = in_channels
168
+ self.out_channels = out_channels
169
+ self.num_conv_branches = num_conv_branches
170
+
171
+ # Check if SE-ReLU is requested
172
+ if use_se:
173
+ self.se = SEBlock(out_channels)
174
+ else:
175
+ self.se = nn.Identity()
176
+
177
+ if use_act:
178
+ self.activation = activation
179
+ else:
180
+ self.activation = nn.Identity()
181
+
182
+ if inference_mode:
183
+ self.reparam_conv = nn.Conv2d(
184
+ in_channels=in_channels,
185
+ out_channels=out_channels,
186
+ kernel_size=kernel_size,
187
+ stride=stride,
188
+ padding=padding,
189
+ dilation=dilation,
190
+ groups=groups,
191
+ bias=True,
192
+ )
193
+ else:
194
+ # Re-parameterizable skip connection
195
+ # Fallback, sometimes batchnorm tensors
196
+ # do not get instantiated correctly on some processes
197
+ # when using deepspeed + accelerate
198
+ norm_layer = nn.BatchNorm2d(num_features=in_channels)
199
+ if norm_layer.weight.shape[0] == 0:
200
+ norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
201
+ if norm_layer.bias.shape[0] == 0:
202
+ norm_layer.bias = nn.Parameter(torch.zeros(in_channels))
203
+
204
+ self.rbr_skip = (
205
+ norm_layer
206
+ if out_channels == in_channels and stride == 1
207
+ else None
208
+ )
209
+
210
+ # Re-parameterizable conv branches
211
+ if num_conv_branches > 0:
212
+ rbr_conv = list()
213
+ for _ in range(self.num_conv_branches):
214
+ rbr_conv.append(
215
+ self._conv_bn(kernel_size=kernel_size, padding=padding)
216
+ )
217
+ self.rbr_conv = nn.ModuleList(rbr_conv)
218
+ else:
219
+ self.rbr_conv = None
220
+
221
+ # Re-parameterizable scale branch
222
+ self.rbr_scale = None
223
+ if not isinstance(kernel_size, int):
224
+ kernel_size = kernel_size[0]
225
+ if (kernel_size > 1) and use_scale_branch:
226
+ self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
227
+
228
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
229
+ """Apply forward pass."""
230
+ # Inference mode forward pass.
231
+ if self.inference_mode:
232
+ return self.activation(self.se(self.reparam_conv(x)))
233
+
234
+ # Multi-branched train-time forward pass.
235
+ # Skip branch output
236
+ identity_out = 0
237
+ if self.rbr_skip is not None:
238
+ identity_out = self.rbr_skip(x)
239
+
240
+ # Scale branch output
241
+ scale_out = 0
242
+ if self.rbr_scale is not None:
243
+ scale_out = self.rbr_scale(x)
244
+
245
+ # Other branches
246
+ out = scale_out + identity_out
247
+ if self.rbr_conv is not None:
248
+ for ix in range(self.num_conv_branches):
249
+ out += self.rbr_conv[ix](x)
250
+
251
+ return self.activation(self.se(out))
252
+
253
+ def reparameterize(self):
254
+ """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
255
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
256
+ architecture used at training time to obtain a plain CNN-like structure
257
+ for inference.
258
+ """
259
+ if self.inference_mode:
260
+ return
261
+ kernel, bias = self._get_kernel_bias()
262
+ self.reparam_conv = nn.Conv2d(
263
+ in_channels=self.in_channels,
264
+ out_channels=self.out_channels,
265
+ kernel_size=self.kernel_size,
266
+ stride=self.stride,
267
+ padding=self.padding,
268
+ dilation=self.dilation,
269
+ groups=self.groups,
270
+ bias=True,
271
+ )
272
+ self.reparam_conv.weight.data = kernel
273
+ self.reparam_conv.bias.data = bias
274
+
275
+ # Delete un-used branches
276
+ self.__delattr__("rbr_conv")
277
+ self.__delattr__("rbr_scale")
278
+ if hasattr(self, "rbr_skip"):
279
+ self.__delattr__("rbr_skip")
280
+
281
+ self.inference_mode = True
282
+
283
+ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
284
+ """Method to obtain re-parameterized kernel and bias.
285
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
286
+ Returns:
287
+ Tuple of (kernel, bias) after fusing branches.
288
+ """
289
+ # get weights and bias of scale branch
290
+ kernel_scale = 0
291
+ bias_scale = 0
292
+ if self.rbr_scale is not None:
293
+ kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
294
+ # Pad scale branch kernel to match conv branch kernel size.
295
+ pad = self.kernel_size // 2
296
+ kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
297
+
298
+ # get weights and bias of skip branch
299
+ kernel_identity = 0
300
+ bias_identity = 0
301
+ if self.rbr_skip is not None:
302
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
303
+
304
+ # get weights and bias of conv branches
305
+ kernel_conv = 0
306
+ bias_conv = 0
307
+ if self.rbr_conv is not None:
308
+ for ix in range(self.num_conv_branches):
309
+ _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
310
+ kernel_conv += _kernel
311
+ bias_conv += _bias
312
+
313
+ kernel_final = kernel_conv + kernel_scale + kernel_identity
314
+ bias_final = bias_conv + bias_scale + bias_identity
315
+ return kernel_final, bias_final
316
+
317
+ def _fuse_bn_tensor(
318
+ self, branch: Union[nn.Sequential, nn.BatchNorm2d]
319
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
320
+ """Method to fuse batchnorm layer with preceeding conv layer.
321
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
322
+ Args:
323
+ branch: Sequence of ops to be fused.
324
+ Returns:
325
+ Tuple of (kernel, bias) after fusing batchnorm.
326
+ """
327
+ if isinstance(branch, nn.Sequential):
328
+ kernel = branch.conv.weight
329
+ running_mean = branch.bn.running_mean
330
+ running_var = branch.bn.running_var
331
+ gamma = branch.bn.weight
332
+ beta = branch.bn.bias
333
+ eps = branch.bn.eps
334
+ else:
335
+ assert isinstance(branch, nn.BatchNorm2d)
336
+ if not hasattr(self, "id_tensor"):
337
+ input_dim = self.in_channels // self.groups
338
+
339
+ kernel_size = self.kernel_size
340
+ if isinstance(self.kernel_size, int):
341
+ kernel_size = (self.kernel_size, self.kernel_size)
342
+
343
+ kernel_value = torch.zeros(
344
+ (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
345
+ dtype=branch.weight.dtype,
346
+ device=branch.weight.device,
347
+ )
348
+ for i in range(self.in_channels):
349
+ kernel_value[
350
+ i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
351
+ ] = 1
352
+ self.id_tensor = kernel_value
353
+ kernel = self.id_tensor
354
+ running_mean = branch.running_mean
355
+ running_var = branch.running_var
356
+ gamma = branch.weight
357
+ beta = branch.bias
358
+ eps = branch.eps
359
+ std = (running_var + eps).sqrt()
360
+ t = (gamma / std).reshape(-1, 1, 1, 1)
361
+ return kernel * t, beta - running_mean * gamma / std
362
+
363
+ def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
364
+ """Helper method to construct conv-batchnorm layers.
365
+ Args:
366
+ kernel_size: Size of the convolution kernel.
367
+ padding: Zero-padding size.
368
+ Returns:
369
+ Conv-BN module.
370
+ """
371
+ # Fallback, sometimes batchnorm tensors
372
+ # do not get instantiated correctly on some processes
373
+ # when using deepspeed + accelerate
374
+ norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
375
+ if norm_layer.weight.shape[0] == 0:
376
+ norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
377
+ if norm_layer.bias.shape[0] == 0:
378
+ norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
379
+
380
+ mod_list = nn.Sequential()
381
+ mod_list.add_module(
382
+ "conv",
383
+ nn.Conv2d(
384
+ in_channels=self.in_channels,
385
+ out_channels=self.out_channels,
386
+ kernel_size=kernel_size,
387
+ stride=self.stride,
388
+ padding=padding,
389
+ groups=self.groups,
390
+ bias=False,
391
+ ),
392
+ )
393
+ mod_list.add_module("bn", norm_layer)
394
+ return mod_list
395
+
396
+
397
+ class ReparamLargeKernelConv(nn.Module):
398
+ """Building Block of RepLKNet
399
+ This class defines overparameterized large kernel conv block
400
+ introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
401
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
402
+ """
403
+
404
+ def __init__(
405
+ self,
406
+ in_channels: int,
407
+ out_channels: int,
408
+ kernel_size: int,
409
+ stride: int,
410
+ groups: int,
411
+ small_kernel: int,
412
+ inference_mode: bool = False,
413
+ use_se: bool = False,
414
+ activation: nn.Module = nn.GELU(),
415
+ ) -> None:
416
+ """Construct a ReparamLargeKernelConv module.
417
+ Args:
418
+ in_channels: Number of input channels.
419
+ out_channels: Number of output channels.
420
+ kernel_size: Kernel size of the large kernel conv branch.
421
+ stride: Stride size. Default: 1
422
+ groups: Group number. Default: 1
423
+ small_kernel: Kernel size of small kernel conv branch.
424
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
425
+ activation: Activation module. Default: ``nn.GELU``
426
+ """
427
+ super(ReparamLargeKernelConv, self).__init__()
428
+
429
+ self.stride = stride
430
+ self.groups = groups
431
+ self.in_channels = in_channels
432
+ self.out_channels = out_channels
433
+ self.activation = activation
434
+
435
+ self.kernel_size = kernel_size
436
+ self.small_kernel = small_kernel
437
+ self.padding = kernel_size // 2
438
+
439
+ # Check if SE is requested
440
+ if use_se:
441
+ self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
442
+ else:
443
+ self.se = nn.Identity()
444
+
445
+ if inference_mode:
446
+ self.lkb_reparam = nn.Conv2d(
447
+ in_channels=in_channels,
448
+ out_channels=out_channels,
449
+ kernel_size=kernel_size,
450
+ stride=stride,
451
+ padding=self.padding,
452
+ dilation=1,
453
+ groups=groups,
454
+ bias=True,
455
+ )
456
+ else:
457
+ self.lkb_origin = self._conv_bn(
458
+ kernel_size=kernel_size, padding=self.padding
459
+ )
460
+ if small_kernel is not None:
461
+ assert (
462
+ small_kernel <= kernel_size
463
+ ), "The kernel size for re-param cannot be larger than the large kernel!"
464
+ self.small_conv = self._conv_bn(
465
+ kernel_size=small_kernel, padding=small_kernel // 2
466
+ )
467
+
468
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
469
+ """Apply forward pass."""
470
+ if hasattr(self, "lkb_reparam"):
471
+ out = self.lkb_reparam(x)
472
+ else:
473
+ out = self.lkb_origin(x)
474
+ if hasattr(self, "small_conv"):
475
+ out += self.small_conv(x)
476
+
477
+ return self.activation(self.se(out))
478
+
479
+ def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
480
+ """Method to obtain re-parameterized kernel and bias.
481
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
482
+ Returns:
483
+ Tuple of (kernel, bias) after fusing branches.
484
+ """
485
+ eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
486
+ if hasattr(self, "small_conv"):
487
+ small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
488
+ eq_b += small_b
489
+ eq_k += nn.functional.pad(
490
+ small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
491
+ )
492
+ return eq_k, eq_b
493
+
494
+ def reparameterize(self) -> None:
495
+ """
496
+ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
497
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
498
+ architecture used at training time to obtain a plain CNN-like structure
499
+ for inference.
500
+ """
501
+ eq_k, eq_b = self.get_kernel_bias()
502
+ self.lkb_reparam = nn.Conv2d(
503
+ in_channels=self.in_channels,
504
+ out_channels=self.out_channels,
505
+ kernel_size=self.kernel_size,
506
+ stride=self.stride,
507
+ padding=self.padding,
508
+ dilation=self.lkb_origin.conv.dilation,
509
+ groups=self.groups,
510
+ bias=True,
511
+ )
512
+
513
+ self.lkb_reparam.weight.data = eq_k
514
+ self.lkb_reparam.bias.data = eq_b
515
+ self.__delattr__("lkb_origin")
516
+ if hasattr(self, "small_conv"):
517
+ self.__delattr__("small_conv")
518
+
519
+ @staticmethod
520
+ def _fuse_bn(
521
+ conv: torch.Tensor, bn: nn.BatchNorm2d
522
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
523
+ """Method to fuse batchnorm layer with conv layer.
524
+ Args:
525
+ conv: Convolutional kernel weights.
526
+ bn: Batchnorm 2d layer.
527
+ Returns:
528
+ Tuple of (kernel, bias) after fusing batchnorm.
529
+ """
530
+ kernel = conv.weight
531
+ running_mean = bn.running_mean
532
+ running_var = bn.running_var
533
+ gamma = bn.weight
534
+ beta = bn.bias
535
+ eps = bn.eps
536
+ std = (running_var + eps).sqrt()
537
+ t = (gamma / std).reshape(-1, 1, 1, 1)
538
+ return kernel * t, beta - running_mean * gamma / std
539
+
540
+ def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
541
+ """Helper method to construct conv-batchnorm layers.
542
+ Args:
543
+ kernel_size: Size of the convolution kernel.
544
+ padding: Zero-padding size.
545
+ Returns:
546
+ A nn.Sequential Conv-BN module.
547
+ """
548
+ # Fallback, sometimes batchnorm tensors
549
+ # do not get instantiated correctly on some processes
550
+ # when using deepspeed + accelerate
551
+ norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
552
+ if norm_layer.weight.shape[0] == 0:
553
+ norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
554
+ if norm_layer.bias.shape[0] == 0:
555
+ norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
556
+
557
+ mod_list = nn.Sequential()
558
+ mod_list.add_module(
559
+ "conv",
560
+ nn.Conv2d(
561
+ in_channels=self.in_channels,
562
+ out_channels=self.out_channels,
563
+ kernel_size=kernel_size,
564
+ stride=self.stride,
565
+ padding=padding,
566
+ groups=self.groups,
567
+ bias=False,
568
+ ),
569
+ )
570
+ mod_list.add_module("bn", norm_layer)
571
+ return mod_list
572
+
573
+
574
+ def convolutional_stem(
575
+ in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True,
576
+ ) -> nn.Sequential:
577
+ """Build convolutional stem with MobileOne blocks.
578
+ Args:
579
+ in_channels: Number of input channels.
580
+ out_channels: Number of output channels.
581
+ inference_mode: Flag to instantiate model in inference mode. Default: ``False``
582
+ Returns:
583
+ nn.Sequential object with stem elements.
584
+ """
585
+ return nn.Sequential(
586
+ MobileOneBlock(
587
+ in_channels=in_channels,
588
+ out_channels=out_channels,
589
+ kernel_size=3,
590
+ stride=2,
591
+ padding=1,
592
+ groups=1,
593
+ inference_mode=inference_mode,
594
+ use_se=False,
595
+ num_conv_branches=1,
596
+ use_scale_branch=use_scale_branch
597
+ ),
598
+ MobileOneBlock(
599
+ in_channels=out_channels,
600
+ out_channels=out_channels,
601
+ kernel_size=3,
602
+ stride=2,
603
+ padding=1,
604
+ groups=out_channels,
605
+ inference_mode=inference_mode,
606
+ use_se=False,
607
+ num_conv_branches=1,
608
+ use_scale_branch=use_scale_branch
609
+ ),
610
+ MobileOneBlock(
611
+ in_channels=out_channels,
612
+ out_channels=out_channels,
613
+ kernel_size=1,
614
+ stride=1,
615
+ padding=0,
616
+ groups=1,
617
+ inference_mode=inference_mode,
618
+ use_se=False,
619
+ num_conv_branches=1,
620
+ use_scale_branch=use_scale_branch
621
+ ),
622
+ )
623
+
624
+
625
+ class LayerNormChannel(nn.Module):
626
+ """
627
+ LayerNorm only for Channel Dimension.
628
+ Input: tensor in shape [B, C, H, W]
629
+ """
630
+ def __init__(self, num_features, eps=1e-05) -> None:
631
+ super().__init__()
632
+ self.weight = nn.Parameter(torch.ones(num_features))
633
+ self.bias = nn.Parameter(torch.zeros(num_features))
634
+ self.eps = eps
635
+
636
+ def forward(self, x) -> torch.Tensor:
637
+ u = x.mean(1, keepdim=True)
638
+ s = (x - u).pow(2).mean(1, keepdim=True)
639
+ x = (x - u) / torch.sqrt(s + self.eps)
640
+ x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
641
+ + self.bias.unsqueeze(-1).unsqueeze(-1)
642
+ return x
643
+
644
+
645
+ class MHSA(nn.Module):
646
+ """Multi-headed Self Attention module.
647
+ Source modified from:
648
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
649
+ """
650
+
651
+ def __init__(
652
+ self,
653
+ dim: int,
654
+ head_dim: int = 32,
655
+ qkv_bias: bool = False,
656
+ attn_drop: float = 0.0,
657
+ proj_drop: float = 0.0,
658
+ ) -> None:
659
+ """Build MHSA module that can handle 3D or 4D input tensors.
660
+ Args:
661
+ dim: Number of embedding dimensions.
662
+ head_dim: Number of hidden dimensions per head. Default: ``32``
663
+ qkv_bias: Use bias or not. Default: ``False``
664
+ attn_drop: Dropout rate for attention tensor.
665
+ proj_drop: Dropout rate for projection tensor.
666
+ """
667
+ super().__init__()
668
+ assert dim % head_dim == 0, "dim should be divisible by head_dim"
669
+ self.head_dim = head_dim
670
+ self.num_heads = dim // head_dim
671
+ self.scale = head_dim**-0.5
672
+
673
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
674
+ self.attn_drop = nn.Dropout(attn_drop)
675
+ self.proj = nn.Linear(dim, dim)
676
+ self.proj_drop = nn.Dropout(proj_drop)
677
+
678
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
679
+ shape = x.shape
680
+ B, C, H, W = shape
681
+ N = H * W
682
+ if len(shape) == 4:
683
+ x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
684
+ qkv = (
685
+ self.qkv(x)
686
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
687
+ .permute(2, 0, 3, 1, 4)
688
+ )
689
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
690
+
691
+ # trick here to make q@k.t more stable
692
+ attn = (q * self.scale) @ k.transpose(-2, -1)
693
+ attn = attn.softmax(dim=-1)
694
+ attn = self.attn_drop(attn)
695
+
696
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
697
+ x = self.proj(x)
698
+ x = self.proj_drop(x)
699
+ if len(shape) == 4:
700
+ x = x.transpose(-2, -1).reshape(B, C, H, W)
701
+
702
+ return x
703
+
704
+
705
+ class PatchEmbed(nn.Module):
706
+ """Convolutional patch embedding layer."""
707
+
708
+ def __init__(
709
+ self,
710
+ patch_size: int,
711
+ stride: int,
712
+ in_channels: int,
713
+ embed_dim: int,
714
+ inference_mode: bool = False,
715
+ use_se: bool = False,
716
+ ) -> None:
717
+ """Build patch embedding layer.
718
+ Args:
719
+ patch_size: Patch size for embedding computation.
720
+ stride: Stride for convolutional embedding layer.
721
+ in_channels: Number of channels of input tensor.
722
+ embed_dim: Number of embedding dimensions.
723
+ inference_mode: Flag to instantiate model in inference mode. Default: ``False``
724
+ use_se: If ``True`` SE block will be used.
725
+ """
726
+ super().__init__()
727
+ block = list()
728
+ block.append(
729
+ ReparamLargeKernelConv(
730
+ in_channels=in_channels,
731
+ out_channels=embed_dim,
732
+ kernel_size=patch_size,
733
+ stride=stride,
734
+ groups=in_channels,
735
+ small_kernel=3,
736
+ inference_mode=inference_mode,
737
+ use_se=use_se,
738
+ )
739
+ )
740
+ block.append(
741
+ MobileOneBlock(
742
+ in_channels=embed_dim,
743
+ out_channels=embed_dim,
744
+ kernel_size=1,
745
+ stride=1,
746
+ padding=0,
747
+ groups=1,
748
+ inference_mode=inference_mode,
749
+ use_se=False,
750
+ num_conv_branches=1,
751
+ )
752
+ )
753
+ self.proj = nn.Sequential(*block)
754
+
755
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
756
+ x = self.proj(x)
757
+ return x
758
+
759
+
760
+ class RepMixer(nn.Module):
761
+ """Reparameterizable token mixer.
762
+ For more details, please refer to our paper:
763
+ `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
764
+ """
765
+
766
+ def __init__(
767
+ self,
768
+ dim,
769
+ kernel_size=3,
770
+ use_layer_scale=True,
771
+ layer_scale_init_value=1e-5,
772
+ inference_mode: bool = False,
773
+ ):
774
+ """Build RepMixer Module.
775
+ Args:
776
+ dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
777
+ kernel_size: Kernel size for spatial mixing. Default: 3
778
+ use_layer_scale: If True, learnable layer scale is used. Default: ``True``
779
+ layer_scale_init_value: Initial value for layer scale. Default: 1e-5
780
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
781
+ """
782
+ super().__init__()
783
+ self.dim = dim
784
+ self.kernel_size = kernel_size
785
+ self.inference_mode = inference_mode
786
+
787
+ if inference_mode:
788
+ self.reparam_conv = nn.Conv2d(
789
+ in_channels=self.dim,
790
+ out_channels=self.dim,
791
+ kernel_size=self.kernel_size,
792
+ stride=1,
793
+ padding=self.kernel_size // 2,
794
+ groups=self.dim,
795
+ bias=True,
796
+ )
797
+ else:
798
+ self.norm = MobileOneBlock(
799
+ dim,
800
+ dim,
801
+ kernel_size,
802
+ padding=kernel_size // 2,
803
+ groups=dim,
804
+ use_act=False,
805
+ use_scale_branch=False,
806
+ num_conv_branches=0,
807
+ )
808
+ self.mixer = MobileOneBlock(
809
+ dim,
810
+ dim,
811
+ kernel_size,
812
+ padding=kernel_size // 2,
813
+ groups=dim,
814
+ use_act=False,
815
+ )
816
+ self.use_layer_scale = use_layer_scale
817
+ if use_layer_scale:
818
+ self.layer_scale = nn.Parameter(
819
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
820
+ )
821
+
822
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
823
+ if hasattr(self, "reparam_conv"):
824
+ x = self.reparam_conv(x)
825
+ return x
826
+ else:
827
+ if self.use_layer_scale:
828
+ x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
829
+ else:
830
+ x = x + self.mixer(x) - self.norm(x)
831
+ return x
832
+
833
+ def reparameterize(self) -> None:
834
+ """Reparameterize mixer and norm into a single
835
+ convolutional layer for efficient inference.
836
+ """
837
+ if self.inference_mode:
838
+ return
839
+
840
+ self.mixer.reparameterize()
841
+ self.norm.reparameterize()
842
+
843
+ if self.use_layer_scale:
844
+ w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
845
+ self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
846
+ )
847
+ b = torch.squeeze(self.layer_scale) * (
848
+ self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
849
+ )
850
+ else:
851
+ w = (
852
+ self.mixer.id_tensor
853
+ + self.mixer.reparam_conv.weight
854
+ - self.norm.reparam_conv.weight
855
+ )
856
+ b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
857
+
858
+ self.reparam_conv = nn.Conv2d(
859
+ in_channels=self.dim,
860
+ out_channels=self.dim,
861
+ kernel_size=self.kernel_size,
862
+ stride=1,
863
+ padding=self.kernel_size // 2,
864
+ groups=self.dim,
865
+ bias=True,
866
+ )
867
+ self.reparam_conv.weight.data = w
868
+ self.reparam_conv.bias.data = b
869
+
870
+ self.__delattr__("mixer")
871
+ self.__delattr__("norm")
872
+ if self.use_layer_scale:
873
+ self.__delattr__("layer_scale")
874
+
875
+
876
+ class ConvFFN(nn.Module):
877
+ """Convolutional FFN Module."""
878
+
879
+ def __init__(
880
+ self,
881
+ in_channels: int,
882
+ hidden_channels: Optional[int] = None,
883
+ out_channels: Optional[int] = None,
884
+ act_layer: nn.Module = nn.GELU,
885
+ drop: float = 0.0,
886
+ ) -> None:
887
+ """Build convolutional FFN module.
888
+ Args:
889
+ in_channels: Number of input channels.
890
+ hidden_channels: Number of channels after expansion. Default: None
891
+ out_channels: Number of output channels. Default: None
892
+ act_layer: Activation layer. Default: ``GELU``
893
+ drop: Dropout rate. Default: ``0.0``.
894
+ """
895
+ super().__init__()
896
+ out_channels = out_channels or in_channels
897
+ hidden_channels = hidden_channels or in_channels
898
+ self.conv = nn.Sequential()
899
+ self.conv.add_module(
900
+ "conv",
901
+ nn.Conv2d(
902
+ in_channels=in_channels,
903
+ out_channels=out_channels,
904
+ kernel_size=7,
905
+ padding=3,
906
+ groups=in_channels,
907
+ bias=False,
908
+ ),
909
+ )
910
+
911
+ # Fallback, sometimes batchnorm tensors
912
+ # do not get instantiated correctly on some processes
913
+ # when using deepspeed + accelerate
914
+ norm_layer = nn.BatchNorm2d(num_features=out_channels)
915
+ if norm_layer.weight.shape[0] == 0:
916
+ norm_layer.weight = nn.Parameter(torch.zeros(out_channels))
917
+ if norm_layer.bias.shape[0] == 0:
918
+ norm_layer.bias = nn.Parameter(torch.zeros(out_channels))
919
+
920
+ self.conv.add_module("bn", norm_layer)
921
+ self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
922
+ self.act = act_layer()
923
+ self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
924
+ self.drop = nn.Dropout(drop)
925
+ self.apply(self._init_weights)
926
+
927
+ def _init_weights(self, m: nn.Module) -> None:
928
+ if isinstance(m, nn.Conv2d):
929
+ normal_(m.weight, std=0.02)
930
+ if m.bias is not None:
931
+ nn.init.constant_(m.bias, 0)
932
+
933
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
934
+ x = self.conv(x)
935
+ x = self.fc1(x)
936
+ x = self.act(x)
937
+ x = self.drop(x)
938
+ x = self.fc2(x)
939
+ x = self.drop(x)
940
+ return x
941
+
942
+
943
+ class RepCPE(nn.Module):
944
+ """Implementation of conditional positional encoding.
945
+ For more details refer to paper:
946
+ `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
947
+ In our implementation, we can reparameterize this module to eliminate a skip connection.
948
+ """
949
+
950
+ def __init__(
951
+ self,
952
+ in_channels: int,
953
+ embed_dim: int = 768,
954
+ spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
955
+ inference_mode=False,
956
+ ) -> None:
957
+ """Build reparameterizable conditional positional encoding
958
+ Args:
959
+ in_channels: Number of input channels.
960
+ embed_dim: Number of embedding dimensions. Default: 768
961
+ spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
962
+ inference_mode: Flag to instantiate block in inference mode. Default: ``False``
963
+ """
964
+ super(RepCPE, self).__init__()
965
+ if isinstance(spatial_shape, int):
966
+ spatial_shape = tuple([spatial_shape] * 2)
967
+ assert isinstance(spatial_shape, Tuple), (
968
+ f'"spatial_shape" must by a sequence or int, '
969
+ f"get {type(spatial_shape)} instead."
970
+ )
971
+ assert len(spatial_shape) == 2, (
972
+ f'Length of "spatial_shape" should be 2, '
973
+ f"got {len(spatial_shape)} instead."
974
+ )
975
+
976
+ self.spatial_shape = spatial_shape
977
+ self.embed_dim = embed_dim
978
+ self.in_channels = in_channels
979
+ self.groups = embed_dim
980
+
981
+ if inference_mode:
982
+ self.reparam_conv = nn.Conv2d(
983
+ in_channels=self.in_channels,
984
+ out_channels=self.embed_dim,
985
+ kernel_size=self.spatial_shape,
986
+ stride=1,
987
+ padding=int(self.spatial_shape[0] // 2),
988
+ groups=self.embed_dim,
989
+ bias=True,
990
+ )
991
+ else:
992
+ self.pe = nn.Conv2d(
993
+ in_channels,
994
+ embed_dim,
995
+ spatial_shape,
996
+ 1,
997
+ int(spatial_shape[0] // 2),
998
+ bias=True,
999
+ groups=embed_dim,
1000
+ )
1001
+
1002
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1003
+ if hasattr(self, "reparam_conv"):
1004
+ x = self.reparam_conv(x)
1005
+ return x
1006
+ else:
1007
+ x = self.pe(x) + x
1008
+ return x
1009
+
1010
+ def reparameterize(self) -> None:
1011
+ # Build equivalent Id tensor
1012
+ input_dim = self.in_channels // self.groups
1013
+ kernel_value = torch.zeros(
1014
+ (
1015
+ self.in_channels,
1016
+ input_dim,
1017
+ self.spatial_shape[0],
1018
+ self.spatial_shape[1],
1019
+ ),
1020
+ dtype=self.pe.weight.dtype,
1021
+ device=self.pe.weight.device,
1022
+ )
1023
+ for i in range(self.in_channels):
1024
+ kernel_value[
1025
+ i,
1026
+ i % input_dim,
1027
+ self.spatial_shape[0] // 2,
1028
+ self.spatial_shape[1] // 2,
1029
+ ] = 1
1030
+ id_tensor = kernel_value
1031
+
1032
+ # Reparameterize Id tensor and conv
1033
+ w_final = id_tensor + self.pe.weight
1034
+ b_final = self.pe.bias
1035
+
1036
+ # Introduce reparam conv
1037
+ self.reparam_conv = nn.Conv2d(
1038
+ in_channels=self.in_channels,
1039
+ out_channels=self.embed_dim,
1040
+ kernel_size=self.spatial_shape,
1041
+ stride=1,
1042
+ padding=int(self.spatial_shape[0] // 2),
1043
+ groups=self.embed_dim,
1044
+ bias=True,
1045
+ )
1046
+ self.reparam_conv.weight.data = w_final
1047
+ self.reparam_conv.bias.data = b_final
1048
+
1049
+ self.__delattr__("pe")
1050
+
1051
+
1052
+ class RepMixerBlock(nn.Module):
1053
+ """Implementation of Metaformer block with RepMixer as token mixer.
1054
+ For more details on Metaformer structure, please refer to:
1055
+ `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1056
+ """
1057
+
1058
+ def __init__(
1059
+ self,
1060
+ dim: int,
1061
+ kernel_size: int = 3,
1062
+ mlp_ratio: float = 4.0,
1063
+ act_layer: nn.Module = nn.GELU,
1064
+ drop: float = 0.0,
1065
+ drop_path: float = 0.0,
1066
+ use_layer_scale: bool = True,
1067
+ layer_scale_init_value: float = 1e-5,
1068
+ inference_mode: bool = False,
1069
+ ):
1070
+ """Build RepMixer Block.
1071
+ Args:
1072
+ dim: Number of embedding dimensions.
1073
+ kernel_size: Kernel size for repmixer. Default: 3
1074
+ mlp_ratio: MLP expansion ratio. Default: 4.0
1075
+ act_layer: Activation layer. Default: ``nn.GELU``
1076
+ drop: Dropout rate. Default: 0.0
1077
+ drop_path: Drop path rate. Default: 0.0
1078
+ use_layer_scale: Flag to turn on layer scale. Default: ``True``
1079
+ layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1080
+ inference_mode: Flag to instantiate block in inference mode. Default: ``False``
1081
+ """
1082
+
1083
+ super().__init__()
1084
+
1085
+ self.token_mixer = RepMixer(
1086
+ dim,
1087
+ kernel_size=kernel_size,
1088
+ use_layer_scale=use_layer_scale,
1089
+ layer_scale_init_value=layer_scale_init_value,
1090
+ inference_mode=inference_mode,
1091
+ )
1092
+
1093
+ assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1094
+ mlp_ratio
1095
+ )
1096
+ mlp_hidden_dim = int(dim * mlp_ratio)
1097
+ self.convffn = ConvFFN(
1098
+ in_channels=dim,
1099
+ hidden_channels=mlp_hidden_dim,
1100
+ act_layer=act_layer,
1101
+ drop=drop,
1102
+ )
1103
+
1104
+ # Drop Path
1105
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1106
+
1107
+ # Layer Scale
1108
+ self.use_layer_scale = use_layer_scale
1109
+ if use_layer_scale:
1110
+ self.layer_scale = nn.Parameter(
1111
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1112
+ )
1113
+
1114
+ def forward(self, x):
1115
+ if self.use_layer_scale:
1116
+ x = self.token_mixer(x)
1117
+ x = x + self.drop_path(self.layer_scale * self.convffn(x))
1118
+ else:
1119
+ x = self.token_mixer(x)
1120
+ x = x + self.drop_path(self.convffn(x))
1121
+ return x
1122
+
1123
+
1124
+ class AttentionBlock(nn.Module):
1125
+ """Implementation of metaformer block with MHSA as token mixer.
1126
+ For more details on Metaformer structure, please refer to:
1127
+ `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1128
+ """
1129
+
1130
+ def __init__(
1131
+ self,
1132
+ dim: int,
1133
+ mlp_ratio: float = 4.0,
1134
+ act_layer: nn.Module = nn.GELU,
1135
+ norm_layer: nn.Module = nn.BatchNorm2d,
1136
+ drop: float = 0.0,
1137
+ drop_path: float = 0.0,
1138
+ use_layer_scale: bool = True,
1139
+ layer_scale_init_value: float = 1e-5,
1140
+ ):
1141
+ """Build Attention Block.
1142
+ Args:
1143
+ dim: Number of embedding dimensions.
1144
+ mlp_ratio: MLP expansion ratio. Default: 4.0
1145
+ act_layer: Activation layer. Default: ``nn.GELU``
1146
+ norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
1147
+ drop: Dropout rate. Default: 0.0
1148
+ drop_path: Drop path rate. Default: 0.0
1149
+ use_layer_scale: Flag to turn on layer scale. Default: ``True``
1150
+ layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1151
+ """
1152
+
1153
+ super().__init__()
1154
+
1155
+ # Fallback, sometimes batchnorm tensors
1156
+ # do not get instantiated correctly on some processes
1157
+ # when using deepspeed + accelerate
1158
+ norm_layer_ = norm_layer(num_features=dim)
1159
+ if norm_layer_.weight.shape[0] == 0:
1160
+ norm_layer_.weight = nn.Parameter(torch.zeros(dim))
1161
+ if norm_layer_.bias.shape[0] == 0:
1162
+ norm_layer_.bias = nn.Parameter(torch.zeros(dim))
1163
+
1164
+ self.norm = norm_layer_
1165
+ self.token_mixer = MHSA(dim=dim)
1166
+
1167
+ assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1168
+ mlp_ratio
1169
+ )
1170
+ mlp_hidden_dim = int(dim * mlp_ratio)
1171
+ self.convffn = ConvFFN(
1172
+ in_channels=dim,
1173
+ hidden_channels=mlp_hidden_dim,
1174
+ act_layer=act_layer,
1175
+ drop=drop,
1176
+ )
1177
+
1178
+ # Drop path
1179
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1180
+
1181
+ # Layer Scale
1182
+ self.use_layer_scale = use_layer_scale
1183
+ if use_layer_scale:
1184
+ self.layer_scale_1 = nn.Parameter(
1185
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1186
+ )
1187
+ self.layer_scale_2 = nn.Parameter(
1188
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1189
+ )
1190
+
1191
+ def forward(self, x):
1192
+ if self.use_layer_scale:
1193
+ x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
1194
+ x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
1195
+ else:
1196
+ x = x + self.drop_path(self.token_mixer(self.norm(x)))
1197
+ x = x + self.drop_path(self.convffn(x))
1198
+ return x
1199
+
1200
+
1201
+ def basic_blocks(
1202
+ dim: int,
1203
+ block_index: int,
1204
+ num_blocks: List[int],
1205
+ token_mixer_type: str,
1206
+ kernel_size: int = 3,
1207
+ mlp_ratio: float = 4.0,
1208
+ act_layer: nn.Module = nn.GELU,
1209
+ norm_layer: nn.Module = nn.BatchNorm2d,
1210
+ drop_rate: float = 0.0,
1211
+ drop_path_rate: float = 0.0,
1212
+ use_layer_scale: bool = True,
1213
+ layer_scale_init_value: float = 1e-5,
1214
+ inference_mode=False,
1215
+ ) -> nn.Sequential:
1216
+ """Build FastViT blocks within a stage.
1217
+ Args:
1218
+ dim: Number of embedding dimensions.
1219
+ block_index: block index.
1220
+ num_blocks: List containing number of blocks per stage.
1221
+ token_mixer_type: Token mixer type.
1222
+ kernel_size: Kernel size for repmixer.
1223
+ mlp_ratio: MLP expansion ratio.
1224
+ act_layer: Activation layer.
1225
+ norm_layer: Normalization layer.
1226
+ drop_rate: Dropout rate.
1227
+ drop_path_rate: Drop path rate.
1228
+ use_layer_scale: Flag to turn on layer scale regularization.
1229
+ layer_scale_init_value: Layer scale value at initialization.
1230
+ inference_mode: Flag to instantiate block in inference mode.
1231
+ Returns:
1232
+ nn.Sequential object of all the blocks within the stage.
1233
+ """
1234
+ blocks = []
1235
+ for block_idx in range(num_blocks[block_index]):
1236
+ block_dpr = (
1237
+ drop_path_rate
1238
+ * (block_idx + sum(num_blocks[:block_index]))
1239
+ / (sum(num_blocks) - 1)
1240
+ )
1241
+ if token_mixer_type == "repmixer":
1242
+ blocks.append(
1243
+ RepMixerBlock(
1244
+ dim,
1245
+ kernel_size=kernel_size,
1246
+ mlp_ratio=mlp_ratio,
1247
+ act_layer=act_layer,
1248
+ drop=drop_rate,
1249
+ drop_path=block_dpr,
1250
+ use_layer_scale=use_layer_scale,
1251
+ layer_scale_init_value=layer_scale_init_value,
1252
+ inference_mode=inference_mode,
1253
+ )
1254
+ )
1255
+ elif token_mixer_type == "attention":
1256
+ blocks.append(
1257
+ AttentionBlock(
1258
+ dim,
1259
+ mlp_ratio=mlp_ratio,
1260
+ act_layer=act_layer,
1261
+ norm_layer=norm_layer,
1262
+ drop=drop_rate,
1263
+ drop_path=block_dpr,
1264
+ use_layer_scale=use_layer_scale,
1265
+ layer_scale_init_value=layer_scale_init_value,
1266
+ )
1267
+ )
1268
+ else:
1269
+ raise ValueError(
1270
+ "Token mixer type: {} not supported".format(token_mixer_type)
1271
+ )
1272
+ blocks = nn.Sequential(*blocks)
1273
+ return blocks
1274
+
1275
+
1276
+ class GlobalPool2D(nn.Module):
1277
+ """This class implements global pooling with linear projection."""
1278
+
1279
+ def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
1280
+ super().__init__()
1281
+ scale = in_dim**-0.5
1282
+ self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
1283
+ self.in_dim = in_dim
1284
+ self.out_dim = out_dim
1285
+
1286
+ def pool(self, x) -> Tensor:
1287
+ if x.dim() == 4:
1288
+ dims = [-2, -1]
1289
+ elif x.dim() == 5:
1290
+ dims = [-3, -2, -1]
1291
+ x = torch.mean(x, dim=dims, keepdim=False)
1292
+ return x
1293
+
1294
+ def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
1295
+ # x is of shape [batch, in_dim]
1296
+ assert (
1297
+ x.dim() == 4
1298
+ ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
1299
+ x.shape
1300
+ )
1301
+
1302
+ # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
1303
+ x = self.pool(x)
1304
+ # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
1305
+ x = x @ self.proj
1306
+ return x
1307
+
1308
+
1309
+ class FastViT(nn.Module):
1310
+ """
1311
+ This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
1312
+ """
1313
+
1314
+ def __init__(
1315
+ self,
1316
+ layers,
1317
+ token_mixers: Tuple[str, ...],
1318
+ embed_dims=None,
1319
+ mlp_ratios=None,
1320
+ downsamples=None,
1321
+ se_downsamples=None,
1322
+ repmixer_kernel_size=3,
1323
+ norm_layer: nn.Module = nn.BatchNorm2d,
1324
+ act_layer: nn.Module = nn.GELU,
1325
+ num_classes=1000,
1326
+ pos_embs=None,
1327
+ down_patch_size=7,
1328
+ down_stride=2,
1329
+ drop_rate=0.0,
1330
+ drop_path_rate=0.0,
1331
+ use_layer_scale=True,
1332
+ layer_scale_init_value=1e-5,
1333
+ init_cfg=None,
1334
+ pretrained=None,
1335
+ cls_ratio=2.0,
1336
+ inference_mode=False,
1337
+ stem_scale_branch=True,
1338
+ **kwargs,
1339
+ ) -> None:
1340
+
1341
+ super().__init__()
1342
+
1343
+ self.num_classes = num_classes
1344
+ if len(layers) == 4:
1345
+ self.out_indices = [0, 2, 4, 7]
1346
+ elif len(layers) == 5:
1347
+ self.out_indices = [0, 2, 4, 7, 10]
1348
+ else:
1349
+ raise NotImplementedError("FPN is not implemented for more than 5 stages.")
1350
+
1351
+ if pos_embs is None:
1352
+ pos_embs = [None] * len(layers)
1353
+
1354
+ if se_downsamples is None:
1355
+ se_downsamples = [False] * len(layers)
1356
+
1357
+ # Convolutional stem
1358
+ self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode,
1359
+ use_scale_branch=stem_scale_branch)
1360
+
1361
+ # Build the main stages of the network architecture
1362
+ network = []
1363
+ for i in range(len(layers)):
1364
+ # Add position embeddings if requested
1365
+ if pos_embs[i] is not None:
1366
+ network.append(
1367
+ pos_embs[i](
1368
+ embed_dims[i], embed_dims[i], inference_mode=inference_mode
1369
+ )
1370
+ )
1371
+ stage = basic_blocks(
1372
+ embed_dims[i],
1373
+ i,
1374
+ layers,
1375
+ token_mixer_type=token_mixers[i],
1376
+ kernel_size=repmixer_kernel_size,
1377
+ mlp_ratio=mlp_ratios[i],
1378
+ act_layer=act_layer,
1379
+ norm_layer=norm_layer,
1380
+ drop_rate=drop_rate,
1381
+ drop_path_rate=drop_path_rate,
1382
+ use_layer_scale=use_layer_scale,
1383
+ layer_scale_init_value=layer_scale_init_value,
1384
+ inference_mode=inference_mode,
1385
+ )
1386
+ network.append(stage)
1387
+ if i >= len(layers) - 1:
1388
+ break
1389
+
1390
+ # Patch merging/downsampling between stages.
1391
+ if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
1392
+ network.append(
1393
+ PatchEmbed(
1394
+ patch_size=down_patch_size,
1395
+ stride=down_stride,
1396
+ in_channels=embed_dims[i],
1397
+ embed_dim=embed_dims[i + 1],
1398
+ inference_mode=inference_mode,
1399
+ use_se=se_downsamples[i + 1],
1400
+ )
1401
+ )
1402
+ self.network = nn.ModuleList(network)
1403
+
1404
+ # Classifier head
1405
+ self.conv_exp = MobileOneBlock(
1406
+ in_channels=embed_dims[-1],
1407
+ out_channels=int(embed_dims[-1] * cls_ratio),
1408
+ kernel_size=3,
1409
+ stride=1,
1410
+ padding=1,
1411
+ groups=embed_dims[-1],
1412
+ inference_mode=inference_mode,
1413
+ use_se=True,
1414
+ num_conv_branches=1,
1415
+ )
1416
+ self.head = (
1417
+ nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
1418
+ if num_classes > 0
1419
+ else nn.Identity()
1420
+ )
1421
+ self.apply(self.cls_init_weights)
1422
+ self.init_cfg = copy.deepcopy(init_cfg)
1423
+
1424
+ def cls_init_weights(self, m: nn.Module) -> None:
1425
+ """Init. for classification"""
1426
+ if isinstance(m, nn.Linear):
1427
+ normal_(m.weight, std=0.02)
1428
+ if isinstance(m, nn.Linear) and m.bias is not None:
1429
+ nn.init.constant_(m.bias, 0)
1430
+
1431
+ def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
1432
+ x = self.patch_embed(x)
1433
+ return x
1434
+
1435
+ def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1436
+ for idx, block in enumerate(self.network):
1437
+ x = block(x)
1438
+ return x
1439
+
1440
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
1441
+ # input embedding
1442
+ x = self.forward_embeddings(x)
1443
+ # through backbone
1444
+ x = self.forward_tokens(x)
1445
+ # for image classification/embedding
1446
+ x = self.conv_exp(x)
1447
+ cls_out = self.head(x)
1448
+
1449
+ out_dict = dict()
1450
+ if kwargs.get("return_image_embeddings", False):
1451
+ out_dict.update({"logits": cls_out})
1452
+ out_dict.update({"image_embeddings": x})
1453
+ return out_dict
1454
+ else:
1455
+ return cls_out
1456
+
1457
+
1458
+ @register_model
1459
+ def fastvithd(pretrained=False, **kwargs):
1460
+ """Instantiate FastViTHD model variant."""
1461
+ layers = [2, 12, 24, 4, 2]
1462
+ embed_dims = [96, 192, 384, 768, 1536]
1463
+ mlp_ratios = [4, 4, 4, 4, 4]
1464
+ downsamples = [True, True, True, True, True]
1465
+ pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))]
1466
+ token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
1467
+ model = FastViT(
1468
+ layers,
1469
+ token_mixers=token_mixers,
1470
+ embed_dims=embed_dims,
1471
+ pos_embs=pos_embs,
1472
+ mlp_ratios=mlp_ratios,
1473
+ downsamples=downsamples,
1474
+ norm_layer=LayerNormChannel,
1475
+ stem_scale_branch=False,
1476
+ inference_mode=True,
1477
+ **kwargs,
1478
+ )
1479
+ model.default_cfg = default_cfgs["fastvit_m"]
1480
+ if pretrained:
1481
+ raise ValueError("Functionality not implemented.")
1482
+ return model
1483
+
1484
+ def load_model_config(
1485
+ model_name: str,
1486
+ ) -> Any:
1487
+ model_cfg = {
1488
+ "embed_dim": 768,
1489
+ "image_cfg": {
1490
+ "image_size": 1024,
1491
+ "model_name": "fastvithd",
1492
+ "embed_dim": 3072,
1493
+ "patch_size": 64
1494
+ },
1495
+ "text_cfg": {
1496
+ "context_length": 77,
1497
+ "vocab_size": 49408,
1498
+ "dim": 768,
1499
+ "ffn_multiplier_per_layer": 4.0,
1500
+ "n_heads_per_layer": 12,
1501
+ "n_transformer_layers": 12,
1502
+ "norm_layer": "layer_norm_fp32",
1503
+ "causal_masking": False,
1504
+ "model_name": "base"
1505
+ }
1506
+ }
1507
+ return model_cfg
1508
+
1509
+
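For orientation, here is a quick arithmetic check of what this configuration implies for the vision tokens, mirroring the `num_patches` properties defined further down: a 1024×1024 input yields a 16×16 grid of 3072-dimensional image embeddings.

```python
# Sanity check of the default FastVLM vision config above (pure arithmetic).
image_size = 1024   # model_cfg["image_cfg"]["image_size"]
patch_size = 64     # model_cfg["image_cfg"]["patch_size"] (effective output stride)
embed_dim  = 3072   # model_cfg["image_cfg"]["embed_dim"]

patches_per_side = image_size // patch_size       # 16
num_patches = patches_per_side ** 2               # 256 image tokens per image
print(patches_per_side, num_patches, embed_dim)   # 16 256 3072
```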
1510
+ class MCi(nn.Module):
1511
+ """
1512
+ This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_
1513
+ """
1514
+
1515
+ def __init__(self, model_name: str, *args, **kwargs) -> None:
1516
+ super().__init__()
1517
+ self.projection_dim = None
1518
+ if "projection_dim" in kwargs:
1519
+ self.projection_dim = kwargs.get("projection_dim")
1520
+
1521
+ # Create model
1522
+ self.model = create_model(model_name, projection_dim=self.projection_dim)
1523
+
1524
+ # Build out projection head.
1525
+ if self.projection_dim is not None:
1526
+ if hasattr(self.model, "head"):
1527
+ self.model.head = MCi._update_image_classifier(
1528
+ image_classifier=self.model.head, projection_dim=self.projection_dim
1529
+ )
1530
+
1531
+ def forward(self, x: Any, *args, **kwargs) -> Any:
1532
+ """A forward function of the model."""
1533
+ x = self.model(x, *args, **kwargs)
1534
+ return x
1535
+
1536
+ @staticmethod
1537
+ def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
1538
+ """Return the input feature dimension to the image classification head."""
1539
+ in_features = None
1540
+ if isinstance(image_classifier, nn.Sequential):
1541
+ # Classifier that uses nn.Sequential usually has global pooling and
1542
+ # multiple linear layers. Find the first linear layer and get its
1543
+ # in_features
1544
+ for layer in image_classifier:
1545
+ if isinstance(layer, nn.Linear):
1546
+ in_features = layer.in_features
1547
+ break
1548
+ elif isinstance(image_classifier, nn.Linear):
1549
+ in_features = image_classifier.in_features
1550
+
1551
+ if in_features is None:
1552
+ raise NotImplementedError(
1553
+ f"Cannot get input feature dimension of {image_classifier}."
1554
+ )
1555
+ return in_features
1556
+
1557
+ @staticmethod
1558
+ def _update_image_classifier(
1559
+ image_classifier: nn.Module, projection_dim: int, *args, **kwargs
1560
+ ) -> nn.Module:
1561
+ in_features = MCi._get_in_feature_dimension(image_classifier)
1562
+ new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
1563
+ return new_img_classifier
1564
+
1565
+
1566
+ class MobileCLIPVisionTower(nn.Module):
1567
+ def __init__(self, vision_tower, args, delay_load=False):
1568
+ super().__init__()
1569
+
1570
+ self.is_loaded = False
1571
+ self.vision_tower_name = vision_tower
1572
+ self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
1573
+ self.input_image_size = int(vision_tower.split("_")[-1])
1574
+
1575
+ # Delay load is disabled for now
1576
+ if not delay_load:
1577
+ self.load_model()
1578
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
1579
+ self.load_model()
1580
+ else:
1581
+ model_cfg = load_model_config(self.vision_tower_name)
1582
+ self.cfg_only = model_cfg
1583
+
1584
+ def load_model(self, device_map=None):
1585
+ if self.is_loaded:
1586
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
1587
+ return
1588
+
1589
+ # Load model config
1590
+ model_cfg = load_model_config(self.vision_tower_name)
1591
+
1592
+ # Override default image resolution
1593
+ model_cfg["image_cfg"]["image_size"] = self.input_image_size
1594
+
1595
+ self.cfg_only = model_cfg
1596
+
1597
+ # Build HF CLIPImageProcessor with MobileCLIP parameters
1598
+ self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
1599
+ "width": model_cfg["image_cfg"]["image_size"]},
1600
+ image_mean=[0.0, 0.0, 0.0],
1601
+ image_std=[1.0, 1.0, 1.0],
1602
+ size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})
1603
+
1604
+ # Instantiate the image encoder
1605
+ self.vision_tower = MCi(model_name=model_cfg["image_cfg"]["model_name"],
1606
+ projection_dim=model_cfg["embed_dim"])
1607
+
1608
+ if not self.tune_vision_tower:
1609
+ self.vision_tower.requires_grad_(False)
1610
+
1611
+ self.is_loaded = True
1612
+
1613
+ def feature_select(self, image_forward_outs):
1614
+ # Features from penultimate layer
1615
+ image_features = image_forward_outs["image_embeddings"]
1616
+
1617
+ # Reshape 4D tensor to 3D
1618
+ B, C, H, W = image_features.shape
1619
+ image_features = image_features.reshape(B, C, H*W)
1620
+ image_features = image_features.transpose(1, 2)
1621
+ return image_features
1622
+
1623
+ def forward(self, images):
1624
+ if self.tune_vision_tower:
1625
+ return self.forward_images(images)
1626
+ else:
1627
+ with torch.no_grad():
1628
+ return self.forward_images(images)
1629
+
1630
+ def forward_images(self, images):
1631
+ if type(images) is list:
1632
+ image_features = []
1633
+ for image in images:
1634
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
1635
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
1636
+ image_features.append(image_feature)
1637
+ else:
1638
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
1639
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
1640
+
1641
+ return image_features
1642
+
1643
+ @property
1644
+ def dummy_feature(self):
1645
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
1646
+
1647
+ @property
1648
+ def dtype(self):
1649
+ return next(self.vision_tower.parameters()).dtype
1650
+
1651
+ @property
1652
+ def device(self):
1653
+ return next(self.vision_tower.parameters()).device
1654
+
1655
+ @property
1656
+ def config(self):
1657
+ return self.cfg_only
1658
+
1659
+ @property
1660
+ def hidden_size(self):
1661
+ return self.config["image_cfg"]["embed_dim"]
1662
+
1663
+ @property
1664
+ def num_patches_per_side(self):
1665
+ return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]
1666
+
1667
+ @property
1668
+ def num_patches(self):
1669
+ return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2
1670
+
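As a minimal sketch of what `feature_select` does with the tower output (toy tensor, shapes assuming the default 1024 px / patch-64 configuration): the `(B, C, H, W)` feature map is flattened into a `(B, H*W, C)` token sequence for the projector.

```python
import torch

# Stand-in for the "image_embeddings" entry returned by the vision tower.
B, C, H, W = 1, 3072, 16, 16
image_features = torch.randn(B, C, H, W)

# Same reshape as feature_select: (B, C, H, W) -> (B, H*W, C).
tokens = image_features.reshape(B, C, H * W).transpose(1, 2)
print(tokens.shape)   # torch.Size([1, 256, 3072])
```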
1671
+ class IdentityMap(nn.Module):
1672
+ def __init__(self):
1673
+ super().__init__()
1674
+
1675
+ def forward(self, x, *args, **kwargs):
1676
+ return x
1677
+
1678
+ @property
1679
+ def config(self):
1680
+ return {"mm_projector_type": 'identity'}
1681
+
1682
+ def build_vision_projector(config, delay_load=False, **kwargs):
1683
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
1684
+
1685
+ if projector_type == 'linear':
1686
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
1687
+
1688
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
1689
+ if mlp_gelu_match:
1690
+ mlp_depth = int(mlp_gelu_match.group(1))
1691
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
1692
+ for _ in range(1, mlp_depth):
1693
+ modules.append(nn.GELU())
1694
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
1695
+ return nn.Sequential(*modules)
1696
+
1697
+ if projector_type == 'identity':
1698
+ return IdentityMap()
1699
+
1700
+ raise ValueError(f'Unknown projector type: {projector_type}')
1701
+
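For illustration, the `mlp2x_gelu` branch above builds the stack sketched below; `3072` matches the vision embed dim used in this file, while `896` is only an assumed LLM hidden size for the example.

```python
import torch
import torch.nn as nn

# Mirror of the modules built for mm_projector_type == "mlp2x_gelu".
mm_hidden_size, hidden_size = 3072, 896   # hidden_size is an illustrative value
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)

image_tokens = torch.randn(1, 256, mm_hidden_size)   # vision tower output tokens
print(projector(image_tokens).shape)                 # torch.Size([1, 256, 896])
```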
1702
+ def build_vision_tower(vision_tower_cfg, **kwargs):
1703
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
1704
+ return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
1705
+
1706
+ class LlavaMetaModel:
1707
+
1708
+ def __init__(self, config):
1709
+ super(LlavaMetaModel, self).__init__(config)
1710
+
1711
+ if hasattr(config, "mm_vision_tower"):
1712
+ self.vision_tower = build_vision_tower(config, delay_load=True)
1713
+ self.mm_projector = build_vision_projector(config)
1714
+
1715
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
1716
+ self.image_newline = nn.Parameter(
1717
+ torch.empty(config.hidden_size, dtype=self.dtype)
1718
+ )
1719
+
1720
+ def get_vision_tower(self):
1721
+ vision_tower = getattr(self, 'vision_tower', None)
1722
+ if type(vision_tower) is list:
1723
+ vision_tower = vision_tower[0]
1724
+ return vision_tower
1725
+
1726
+ def initialize_vision_modules(self, model_args, fsdp=None):
1727
+ vision_tower = model_args.vision_tower
1728
+ mm_vision_select_layer = model_args.mm_vision_select_layer
1729
+ mm_vision_select_feature = model_args.mm_vision_select_feature
1730
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
1731
+ mm_patch_merge_type = model_args.mm_patch_merge_type
1732
+
1733
+ self.config.mm_vision_tower = vision_tower
1734
+
1735
+ if self.get_vision_tower() is None:
1736
+ vision_tower = build_vision_tower(model_args)
1737
+
1738
+ if fsdp is not None and len(fsdp) > 0:
1739
+ self.vision_tower = [vision_tower]
1740
+ else:
1741
+ self.vision_tower = vision_tower
1742
+ else:
1743
+ if fsdp is not None and len(fsdp) > 0:
1744
+ vision_tower = self.vision_tower[0]
1745
+ else:
1746
+ vision_tower = self.vision_tower
1747
+ vision_tower.load_model()
1748
+
1749
+ self.config.use_mm_proj = True
1750
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
1751
+ self.config.mm_hidden_size = vision_tower.hidden_size
1752
+ self.config.mm_vision_select_layer = mm_vision_select_layer
1753
+ self.config.mm_vision_select_feature = mm_vision_select_feature
1754
+ self.config.mm_patch_merge_type = mm_patch_merge_type
1755
+
1756
+ if getattr(self, 'mm_projector', None) is None:
1757
+ self.mm_projector = build_vision_projector(self.config)
1758
+
1759
+ if 'unpad' in mm_patch_merge_type:
1760
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
1761
+ self.image_newline = nn.Parameter(
1762
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
1763
+ )
1764
+ else:
1765
+ # In case it is frozen by LoRA
1766
+ for p in self.mm_projector.parameters():
1767
+ p.requires_grad = True
1768
+
1769
+ if pretrain_mm_mlp_adapter is not None:
1770
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
1771
+
1772
+ def get_w(weights, keyword):
1773
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
1774
+
1775
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
1776
+
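A minimal, self-contained sketch of the `get_w` helper above: it keeps only the keys containing the keyword and strips the `"<keyword>."` prefix so the sub-module can load the weights directly. The key names and shapes here are illustrative, not taken from a real checkpoint.

```python
import torch

# Illustrative checkpoint dict; adapters typically store projector weights
# under keys like "model.mm_projector.<idx>.<param>".
mm_projector_weights = {
    "model.mm_projector.0.weight": torch.zeros(4, 8),
    "model.mm_projector.0.bias": torch.zeros(4),
    "model.embed_tokens.weight": torch.zeros(10, 8),   # unrelated key, filtered out
}

def get_w(weights, keyword):
    return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

print(sorted(get_w(mm_projector_weights, 'mm_projector')))   # ['0.bias', '0.weight']
```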
1777
+ def select_best_resolution(original_size, possible_resolutions):
1778
+ """
1779
+ Selects the best resolution from a list of possible resolutions based on the original size.
1780
+ Args:
1781
+ original_size (tuple): The original size of the image in the format (width, height).
1782
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
1783
+ Returns:
1784
+ tuple: The best fit resolution in the format (width, height).
1785
+ """
1786
+ original_width, original_height = original_size
1787
+ best_fit = None
1788
+ max_effective_resolution = 0
1789
+ min_wasted_resolution = float('inf')
1790
+
1791
+ for width, height in possible_resolutions:
1792
+ scale = min(width / original_width, height / original_height)
1793
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
1794
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
1795
+ wasted_resolution = (width * height) - effective_resolution
1796
+
1797
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
1798
+ max_effective_resolution = effective_resolution
1799
+ min_wasted_resolution = wasted_resolution
1800
+ best_fit = (width, height)
1801
+
1802
+ return best_fit
1803
+
1804
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
1805
+ """
1806
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
1807
+ Args:
1808
+ image_size (tuple): The size of the input image in the format (width, height).
1809
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
1810
+ patch_size (int): The size of each image patch.
1811
+ Returns:
1812
+ tuple: The shape of the image patch grid in the format (width, height).
1813
+ """
1814
+ import ast
1815
+ if type(grid_pinpoints) is list:
1816
+ possible_resolutions = grid_pinpoints
1817
+ else:
1818
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
1819
+ width, height = select_best_resolution(image_size, possible_resolutions)
1820
+ return width // patch_size, height // patch_size
1821
+
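A small worked example of the `anyres` helpers above, using assumed `grid_pinpoints` and an assumed 1024 px crop size: a tall 800×1600 image is matched to the 1024×2048 candidate, i.e. a 1×2 grid of crops.

```python
# Score each candidate the same way select_best_resolution does:
# maximise effective resolution, then minimise wasted area.
candidates = [(1024, 1024), (1024, 2048), (2048, 1024)]   # assumed grid_pinpoints
orig_w, orig_h = 800, 1600                                 # input image (width, height)
crop = 1024                                                # assumed vision-tower input size

def score(w, h):
    scale = min(w / orig_w, h / orig_h)
    effective = min(int(orig_w * scale) * int(orig_h * scale), orig_w * orig_h)
    wasted = w * h - effective
    return (effective, -wasted)

best_w, best_h = max(candidates, key=lambda wh: score(*wh))
print(best_w, best_h)                  # 1024 2048
print(best_w // crop, best_h // crop)  # 1 2  -> the grid shape returned above
```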
1822
+ class LlavaMetaForCausalLM(ABC):
1823
+
1824
+ @abstractmethod
1825
+ def get_model(self):
1826
+ pass
1827
+
1828
+ def get_vision_tower(self):
1829
+ return self.get_model().get_vision_tower()
1830
+
1831
+ def encode_images(self, images):
1832
+ image_features = self.get_model().get_vision_tower()(images)
1833
+ image_features = self.get_model().mm_projector(image_features)
1834
+ return image_features
1835
+
1836
+ def prepare_inputs_labels_for_multimodal(
1837
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
1838
+ images, image_sizes=None
1839
+ ):
1840
+ vision_tower = self.get_vision_tower()
1841
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
1842
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
1843
+
1844
+ if type(images) is list or images.ndim == 5:
1845
+ if type(images) is list:
1846
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
1847
+ concat_images = torch.cat([image for image in images], dim=0)
1848
+ image_features = self.encode_images(concat_images)
1849
+ split_sizes = [image.shape[0] for image in images]
1850
+ image_features = torch.split(image_features, split_sizes, dim=0)
1851
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
1852
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
1853
+ if mm_patch_merge_type == 'flat':
1854
+ image_features = [x.flatten(0, 1) for x in image_features]
1855
+ elif mm_patch_merge_type.startswith('spatial'):
1856
+ new_image_features = []
1857
+ for image_idx, image_feature in enumerate(image_features):
1858
+ if image_feature.shape[0] > 1:
1859
+ base_image_feature = image_feature[0]
1860
+ image_feature = image_feature[1:]
1861
+ height = width = self.get_vision_tower().num_patches_per_side
1862
+ assert height * width == base_image_feature.shape[0]
1863
+ if image_aspect_ratio == 'anyres':
1864
+ if hasattr(self.get_vision_tower(), 's2_image_size'):
1865
+ img_size = self.get_vision_tower().s2_image_size
1866
+ elif isinstance(self.get_vision_tower().config, dict):
1867
+ img_size = self.get_vision_tower().config["image_cfg"]["image_size"]
1868
+ else:
1869
+ img_size = self.get_vision_tower().config.image_size
1870
+
1871
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size)
1872
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
1873
+ else:
1874
+ raise NotImplementedError
1875
+ if 'unpad' in mm_patch_merge_type:
1876
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
1877
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
1878
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
1879
+ image_feature = torch.cat((
1880
+ image_feature,
1881
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
1882
+ ), dim=-1)
1883
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
1884
+ else:
1885
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
1886
+ image_feature = image_feature.flatten(0, 3)
1887
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
1888
+ else:
1889
+ image_feature = image_feature[0]
1890
+ if 'unpad' in mm_patch_merge_type:
1891
+ image_feature = torch.cat((
1892
+ image_feature,
1893
+ self.model.image_newline[None].to(image_feature.device)
1894
+ ), dim=0)
1895
+ new_image_features.append(image_feature)
1896
+ image_features = new_image_features
1897
+ else:
1898
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
1899
+ else:
1900
+ image_features = self.encode_images(images)
1901
+
1902
+ # TODO: image start / end is not implemented here to support pretraining.
1903
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
1904
+ raise NotImplementedError
1905
+
1906
+ # Let's just add dummy tensors if they do not exist,
1907
+ # it is a headache to deal with None all the time.
1908
+ # But it is not ideal, and if you have a better idea,
1909
+ # please open an issue / submit a PR, thanks.
1910
+ _labels = labels
1911
+ _position_ids = position_ids
1912
+ _attention_mask = attention_mask
1913
+ if attention_mask is None:
1914
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1915
+ else:
1916
+ attention_mask = attention_mask.bool()
1917
+ if position_ids is None:
1918
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
1919
+ if labels is None:
1920
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1921
+
1922
+ # remove the padding using attention_mask -- FIXME
1923
+ _input_ids = input_ids
1924
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
1925
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
1926
+
1927
+ new_input_embeds = []
1928
+ new_labels = []
1929
+ cur_image_idx = 0
1930
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1931
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1932
+ if num_images == 0:
1933
+ cur_image_features = image_features[cur_image_idx]
1934
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1935
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
1936
+ new_input_embeds.append(cur_input_embeds)
1937
+ new_labels.append(labels[batch_idx])
1938
+ cur_image_idx += 1
1939
+ continue
1940
+
1941
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
1942
+ cur_input_ids_noim = []
1943
+ cur_labels = labels[batch_idx]
1944
+ cur_labels_noim = []
1945
+ for i in range(len(image_token_indices) - 1):
1946
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
1947
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
1948
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
1949
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
1950
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1951
+ cur_new_input_embeds = []
1952
+ cur_new_labels = []
1953
+
1954
+ for i in range(num_images + 1):
1955
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1956
+ cur_new_labels.append(cur_labels_noim[i])
1957
+ if i < num_images:
1958
+ cur_image_features = image_features[cur_image_idx]
1959
+ cur_image_idx += 1
1960
+ cur_new_input_embeds.append(cur_image_features)
1961
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
1962
+
1963
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1964
+
1965
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1966
+ cur_new_labels = torch.cat(cur_new_labels)
1967
+
1968
+ new_input_embeds.append(cur_new_input_embeds)
1969
+ new_labels.append(cur_new_labels)
1970
+
1971
+ # Truncate sequences to max length as image embeddings can make the sequence longer
1972
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
1973
+ if tokenizer_model_max_length is not None:
1974
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
1975
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1976
+
1977
+ # Combine them
1978
+ max_len = max(x.shape[0] for x in new_input_embeds)
1979
+ batch_size = len(new_input_embeds)
1980
+
1981
+ new_input_embeds_padded = []
1982
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
1983
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
1984
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
1985
+
1986
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
1987
+ cur_len = cur_new_embed.shape[0]
1988
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
1989
+ new_input_embeds_padded.append(torch.cat((
1990
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
1991
+ cur_new_embed
1992
+ ), dim=0))
1993
+ if cur_len > 0:
1994
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1995
+ attention_mask[i, -cur_len:] = True
1996
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
1997
+ else:
1998
+ new_input_embeds_padded.append(torch.cat((
1999
+ cur_new_embed,
2000
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
2001
+ ), dim=0))
2002
+ if cur_len > 0:
2003
+ new_labels_padded[i, :cur_len] = cur_new_labels
2004
+ attention_mask[i, :cur_len] = True
2005
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
2006
+
2007
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
2008
+
2009
+ if _labels is None:
2010
+ new_labels = None
2011
+ else:
2012
+ new_labels = new_labels_padded
2013
+
2014
+ if _attention_mask is None:
2015
+ attention_mask = None
2016
+ else:
2017
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
2018
+
2019
+ if _position_ids is None:
2020
+ position_ids = None
2021
+
2022
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
2023
+
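As a toy illustration of the splice performed above (tiny sizes, random stand-in embeddings): each `IMAGE_TOKEN_INDEX` placeholder is dropped from the token sequence and the projected image features are inserted in its place, so with one image the final length is `num_text_tokens - 1 + num_image_tokens`.

```python
import torch

IMAGE_TOKEN_INDEX = -200
hidden_size = 8          # toy value; the real model uses the LLM hidden size
num_image_tokens = 4     # toy value; the default FastVLM config yields 256

input_ids   = torch.tensor([101, 5, IMAGE_TOKEN_INDEX, 7, 9, 102])   # toy prompt ids
text_embeds = torch.randn(input_ids.shape[0], hidden_size)           # stand-in for embed_tokens
image_feats = torch.randn(num_image_tokens, hidden_size)             # stand-in for mm_projector output

pos = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0].item()
spliced = torch.cat([text_embeds[:pos], image_feats, text_embeds[pos + 1:]], dim=0)

print(spliced.shape[0])  # 9 == 6 text tokens - 1 placeholder + 4 image tokens
```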
2024
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
2025
+ if model_args.mm_use_im_patch_token:
2026
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
2027
+ self.resize_token_embeddings(len(tokenizer))
2028
+
2029
+ if model_args.mm_use_im_start_end:
2030
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
2031
+ self.resize_token_embeddings(len(tokenizer))
2032
+
2033
+ if num_new_tokens > 0:
2034
+ input_embeddings = self.get_input_embeddings().weight.data
2035
+ output_embeddings = self.get_output_embeddings().weight.data
2036
+
2037
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
2038
+ dim=0, keepdim=True)
2039
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
2040
+ dim=0, keepdim=True)
2041
+
2042
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
2043
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
2044
+
2045
+ if model_args.tune_mm_mlp_adapter:
2046
+ for p in self.get_input_embeddings().parameters():
2047
+ p.requires_grad = True
2048
+ for p in self.get_output_embeddings().parameters():
2049
+ p.requires_grad = False
2050
+
2051
+ if model_args.pretrain_mm_mlp_adapter:
2052
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
2053
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
2054
+ assert num_new_tokens == 2
2055
+ if input_embeddings.shape == embed_tokens_weight.shape:
2056
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
2057
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
2058
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
2059
+ else:
2060
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
2061
+ elif model_args.mm_use_im_patch_token:
2062
+ if model_args.tune_mm_mlp_adapter:
2063
+ for p in self.get_input_embeddings().parameters():
2064
+ p.requires_grad = False
2065
+ for p in self.get_output_embeddings().parameters():
2066
+ p.requires_grad = False
2067
+
2068
+
2069
+ class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
2070
+ config_class = LlavaConfig
2071
+
2072
+ def __init__(self, config: Qwen2Config):
2073
+ super(LlavaQwen2Model, self).__init__(config)
2074
+
2075
+
2076
+ class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
2077
+ config_class = LlavaConfig
2078
+
2079
+ def __init__(self, config):
2080
+ super(Qwen2ForCausalLM, self).__init__(config)
2081
+ self.model = LlavaQwen2Model(config)
2082
+ # self.pretraining_tp = config.pretraining_tp
2083
+ self.vocab_size = config.vocab_size
2084
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2085
+
2086
+ # Initialize weights and apply final processing
2087
+ self.post_init()
2088
+
2089
+ def get_model(self):
2090
+ return self.model
2091
+
2092
+ def forward(
2093
+ self,
2094
+ input_ids: torch.LongTensor = None,
2095
+ attention_mask: Optional[torch.Tensor] = None,
2096
+ position_ids: Optional[torch.LongTensor] = None,
2097
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
2098
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2099
+ labels: Optional[torch.LongTensor] = None,
2100
+ use_cache: Optional[bool] = None,
2101
+ output_attentions: Optional[bool] = None,
2102
+ output_hidden_states: Optional[bool] = None,
2103
+ images: Optional[torch.FloatTensor] = None,
2104
+ image_sizes: Optional[List[List[int]]] = None,
2105
+ return_dict: Optional[bool] = None,
2106
+ cache_position=None,
2107
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
2108
+
2109
+ if inputs_embeds is None:
2110
+ (
2111
+ input_ids,
2112
+ position_ids,
2113
+ attention_mask,
2114
+ past_key_values,
2115
+ inputs_embeds,
2116
+ labels
2117
+ ) = self.prepare_inputs_labels_for_multimodal(
2118
+ input_ids,
2119
+ position_ids,
2120
+ attention_mask,
2121
+ past_key_values,
2122
+ labels,
2123
+ images,
2124
+ image_sizes
2125
+ )
2126
+
2127
+ return super().forward(
2128
+ input_ids=input_ids,
2129
+ attention_mask=attention_mask,
2130
+ position_ids=position_ids,
2131
+ past_key_values=past_key_values,
2132
+ inputs_embeds=inputs_embeds,
2133
+ labels=labels,
2134
+ use_cache=use_cache,
2135
+ output_attentions=output_attentions,
2136
+ output_hidden_states=output_hidden_states,
2137
+ return_dict=return_dict
2138
+ )
2139
+
2140
+ @torch.no_grad()
2141
+ def generate(
2142
+ self,
2143
+ inputs: Optional[torch.Tensor] = None,
2144
+ images: Optional[torch.Tensor] = None,
2145
+ image_sizes: Optional[torch.Tensor] = None,
2146
+ **kwargs,
2147
+ ) -> Union[GenerateOutput, torch.LongTensor]:
2148
+ position_ids = kwargs.pop("position_ids", None)
2149
+ attention_mask = kwargs.pop("attention_mask", None)
2150
+ if "inputs_embeds" in kwargs:
2151
+ raise NotImplementedError("`inputs_embeds` is not supported")
2152
+
2153
+ if images is not None:
2154
+ (
2155
+ inputs,
2156
+ position_ids,
2157
+ attention_mask,
2158
+ _,
2159
+ inputs_embeds,
2160
+ _
2161
+ ) = self.prepare_inputs_labels_for_multimodal(
2162
+ inputs,
2163
+ position_ids,
2164
+ attention_mask,
2165
+ None,
2166
+ None,
2167
+ images,
2168
+ image_sizes=image_sizes
2169
+ )
2170
+ else:
2171
+ inputs_embeds = self.get_model().embed_tokens(inputs)
2172
+
2173
+ return super().generate(
2174
+ position_ids=position_ids,
2175
+ attention_mask=attention_mask,
2176
+ inputs_embeds=inputs_embeds,
2177
+ **kwargs
2178
+ )
2179
+
2180
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
2181
+ inputs_embeds=None, **kwargs):
2182
+ images = kwargs.pop("images", None)
2183
+ image_sizes = kwargs.pop("image_sizes", None)
2184
+ inputs = super().prepare_inputs_for_generation(
2185
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
2186
+ )
2187
+ if images is not None:
2188
+ inputs['images'] = images
2189
+ if image_sizes is not None:
2190
+ inputs['image_sizes'] = image_sizes
2191
+ return inputs
2192
+
2193
+
2194
+ AutoConfig.register("llava_qwen2", LlavaConfig)
2195
+ AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)