reach-vb (HF Staff) committed
Commit 7334d4e · verified · 1 Parent(s): 7b529ec

Make the repos compatible with transformers `trust_remote_code` 🤗


You can try it with the snippet here:

```python
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

MID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200  # what the model code looks for

# 1) Load
tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)

# 2) Build chat -> render to string (not tokens) so we can place <image> exactly
messages = [
    {"role": "user", "content": "<image>\nDescribe this image in detail."}
]
rendered = tok.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
assert "<image>" in rendered, "The chat template output must contain <image> once."

pre, post = rendered.split("<image>", 1)

# 3) Tokenize the text *around* the image token (no extra specials!)
pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

# 4) Splice in the IMAGE token id (-200) at the placeholder position
img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
attention_mask = torch.ones_like(input_ids, device=model.device)

# 5) Preprocess image via the model's own processor
img = Image.open("test-2.jpg").convert("RGB")
px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
px = px.to(model.device, dtype=model.dtype)

# 6) Generate
with torch.no_grad():
    out = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        max_new_tokens=128,
    )

print(tok.decode(out[0], skip_special_tokens=True))
```
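
If you want to reuse this for several images or prompts, the prompt-splicing steps 2 to 5 can be factored into a small helper. The sketch below is illustrative only: it assumes the `tok`, `model`, and `IMAGE_TOKEN_INDEX` objects defined in the snippet above, and the helper name `build_vlm_inputs` is not part of the repo.

```python
import torch
from PIL import Image


def build_vlm_inputs(tok, model, prompt: str, image_path: str):
    """Illustrative helper: tokenize the text around <image> and splice in IMAGE_TOKEN_INDEX (-200)."""
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split("<image>", 1)

    # Tokenize around the placeholder, then splice in the image token id
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids)

    # Preprocess the image with the vision tower's own processor
    img = Image.open(image_path).convert("RGB")
    px = model.get_vision_tower().image_processor(images=img, return_tensors="pt")["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)
    return input_ids, attention_mask, px
```

The returned tensors drop straight into the same `model.generate(inputs=..., attention_mask=..., images=...)` call shown above.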

Files changed (1)
  1. llava_qwen.py +2234 -0
llava_qwen.py ADDED
@@ -0,0 +1,2234 @@
1
+
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import re
20
+ import copy
21
+ from timm.models import create_model
22
+ from abc import ABC, abstractmethod
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch import Tensor
27
+ import torch.nn.functional as F
28
+ from torch.nn.init import normal_
29
+
30
+ from transformers import CLIPImageProcessor
31
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
32
+
33
+ from transformers.modeling_outputs import CausalLMOutputWithPast
34
+ from transformers.generation.utils import GenerateOutput
35
+
36
+ from functools import partial
37
+ from typing import List, Tuple, Optional, Union, Dict, Any
38
+
39
+ from timm.models import register_model
40
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
41
+ from timm.layers import DropPath, SqueezeExcite
42
+
43
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
44
+ WORKER_HEART_BEAT_INTERVAL = 15
45
+ LOGDIR = "."
46
+ # Model Constants
47
+ IGNORE_INDEX = -100
48
+ IMAGE_TOKEN_INDEX = -200
49
+ DEFAULT_IMAGE_TOKEN = "<image>"
50
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
51
+ DEFAULT_IM_START_TOKEN = "<im_start>"
52
+ DEFAULT_IM_END_TOKEN = "<im_end>"
53
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
54
+
55
+ class LlavaConfig(Qwen2Config):
56
+ model_type = "llava_qwen2"
57
+
58
+ def _cfg(url="", **kwargs):
59
+ return {
60
+ "url": url,
61
+ "num_classes": 1000,
62
+ "input_size": (3, 256, 256),
63
+ "pool_size": None,
64
+ "crop_pct": 0.95,
65
+ "interpolation": "bicubic",
66
+ "mean": IMAGENET_DEFAULT_MEAN,
67
+ "std": IMAGENET_DEFAULT_STD,
68
+ "classifier": "head",
69
+ **kwargs,
70
+ }
71
+
72
+
73
+ default_cfgs = {
74
+ "fastvit_t": _cfg(crop_pct=0.9),
75
+ "fastvit_s": _cfg(crop_pct=0.9),
76
+ "fastvit_m": _cfg(crop_pct=0.95),
77
+ }
78
+
79
+
80
+ class SEBlock(nn.Module):
81
+ """Squeeze and Excite module.
82
+
83
+ Pytorch implementation of `Squeeze-and-Excitation Networks` -
84
+ https://arxiv.org/pdf/1709.01507.pdf
85
+ """
86
+
87
+ def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
88
+ """Construct a Squeeze and Excite Module.
89
+
90
+ Args:
91
+ in_channels: Number of input channels.
92
+ rd_ratio: Input channel reduction ratio.
93
+ """
94
+ super(SEBlock, self).__init__()
95
+ self.reduce = nn.Conv2d(
96
+ in_channels=in_channels,
97
+ out_channels=int(in_channels * rd_ratio),
98
+ kernel_size=1,
99
+ stride=1,
100
+ bias=True,
101
+ )
102
+ self.expand = nn.Conv2d(
103
+ in_channels=int(in_channels * rd_ratio),
104
+ out_channels=in_channels,
105
+ kernel_size=1,
106
+ stride=1,
107
+ bias=True,
108
+ )
109
+
110
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
111
+ """Apply forward pass."""
112
+ b, c, h, w = inputs.size()
113
+ # x = F.avg_pool2d(inputs, kernel_size=[h, w])
114
+ x = F.avg_pool2d(inputs, kernel_size=[16, 16])
115
+ x = self.reduce(x)
116
+ x = F.relu(x)
117
+ x = self.expand(x)
118
+ x = torch.sigmoid(x)
119
+ x = x.view(-1, c, 1, 1)
120
+ return inputs * x
121
+
122
+
123
+ class MobileOneBlock(nn.Module):
124
+ """MobileOne building block.
125
+
126
+ This block has a multi-branched architecture at train-time
127
+ and plain-CNN style architecture at inference time
128
+ For more details, please refer to our paper:
129
+ `An Improved One millisecond Mobile Backbone` -
130
+ https://arxiv.org/pdf/2206.04040.pdf
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ in_channels: int,
136
+ out_channels: int,
137
+ kernel_size: int,
138
+ stride: int = 1,
139
+ padding: int = 0,
140
+ dilation: int = 1,
141
+ groups: int = 1,
142
+ inference_mode: bool = False,
143
+ use_se: bool = False,
144
+ use_act: bool = True,
145
+ use_scale_branch: bool = True,
146
+ num_conv_branches: int = 1,
147
+ activation: nn.Module = nn.GELU(),
148
+ ) -> None:
149
+ """Construct a MobileOneBlock module.
150
+
151
+ Args:
152
+ in_channels: Number of channels in the input.
153
+ out_channels: Number of channels produced by the block.
154
+ kernel_size: Size of the convolution kernel.
155
+ stride: Stride size.
156
+ padding: Zero-padding size.
157
+ dilation: Kernel dilation factor.
158
+ groups: Group number.
159
+ inference_mode: If True, instantiates model in inference mode.
160
+ use_se: Whether to use SE-ReLU activations.
161
+ use_act: Whether to use activation. Default: ``True``
162
+ use_scale_branch: Whether to use scale branch. Default: ``True``
163
+ num_conv_branches: Number of linear conv branches.
164
+ """
165
+ super(MobileOneBlock, self).__init__()
166
+ self.inference_mode = inference_mode
167
+ self.groups = groups
168
+ self.stride = stride
169
+ self.padding = padding
170
+ self.dilation = dilation
171
+ self.kernel_size = kernel_size
172
+ self.in_channels = in_channels
173
+ self.out_channels = out_channels
174
+ self.num_conv_branches = num_conv_branches
175
+
176
+ # Check if SE-ReLU is requested
177
+ if use_se:
178
+ self.se = SEBlock(out_channels)
179
+ else:
180
+ self.se = nn.Identity()
181
+
182
+ if use_act:
183
+ self.activation = activation
184
+ else:
185
+ self.activation = nn.Identity()
186
+
187
+ if inference_mode:
188
+ self.reparam_conv = nn.Conv2d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=kernel_size,
192
+ stride=stride,
193
+ padding=padding,
194
+ dilation=dilation,
195
+ groups=groups,
196
+ bias=True,
197
+ )
198
+ else:
199
+ # Re-parameterizable skip connection
200
+ # Fallback, sometimes batchnorm tensors
201
+ # do not get instantiated correctly on some processes
202
+ # when using deepspeed + accelerate
203
+ norm_layer = nn.BatchNorm2d(num_features=in_channels)
204
+ if norm_layer.weight.shape[0] == 0:
205
+ norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
206
+ if norm_layer.bias.shape[0] == 0:
207
+ norm_layer.bias = nn.Parameter(torch.zeros(in_channels))
208
+
209
+ self.rbr_skip = (
210
+ norm_layer
211
+ if out_channels == in_channels and stride == 1
212
+ else None
213
+ )
214
+
215
+ # Re-parameterizable conv branches
216
+ if num_conv_branches > 0:
217
+ rbr_conv = list()
218
+ for _ in range(self.num_conv_branches):
219
+ rbr_conv.append(
220
+ self._conv_bn(kernel_size=kernel_size, padding=padding)
221
+ )
222
+ self.rbr_conv = nn.ModuleList(rbr_conv)
223
+ else:
224
+ self.rbr_conv = None
225
+
226
+ # Re-parameterizable scale branch
227
+ self.rbr_scale = None
228
+ if not isinstance(kernel_size, int):
229
+ kernel_size = kernel_size[0]
230
+ if (kernel_size > 1) and use_scale_branch:
231
+ self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ """Apply forward pass."""
235
+ # Inference mode forward pass.
236
+ if self.inference_mode:
237
+ return self.activation(self.se(self.reparam_conv(x)))
238
+
239
+ # Multi-branched train-time forward pass.
240
+ # Skip branch output
241
+ identity_out = 0
242
+ if self.rbr_skip is not None:
243
+ identity_out = self.rbr_skip(x)
244
+
245
+ # Scale branch output
246
+ scale_out = 0
247
+ if self.rbr_scale is not None:
248
+ scale_out = self.rbr_scale(x)
249
+
250
+ # Other branches
251
+ out = scale_out + identity_out
252
+ if self.rbr_conv is not None:
253
+ for ix in range(self.num_conv_branches):
254
+ out += self.rbr_conv[ix](x)
255
+
256
+ return self.activation(self.se(out))
257
+
258
+ def reparameterize(self):
259
+ """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
260
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
261
+ architecture used at training time to obtain a plain CNN-like structure
262
+ for inference.
263
+ """
264
+ if self.inference_mode:
265
+ return
266
+ kernel, bias = self._get_kernel_bias()
267
+ self.reparam_conv = nn.Conv2d(
268
+ in_channels=self.in_channels,
269
+ out_channels=self.out_channels,
270
+ kernel_size=self.kernel_size,
271
+ stride=self.stride,
272
+ padding=self.padding,
273
+ dilation=self.dilation,
274
+ groups=self.groups,
275
+ bias=True,
276
+ )
277
+ self.reparam_conv.weight.data = kernel
278
+ self.reparam_conv.bias.data = bias
279
+
280
+ # Delete un-used branches
281
+ self.__delattr__("rbr_conv")
282
+ self.__delattr__("rbr_scale")
283
+ if hasattr(self, "rbr_skip"):
284
+ self.__delattr__("rbr_skip")
285
+
286
+ self.inference_mode = True
287
+
288
+ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
289
+ """Method to obtain re-parameterized kernel and bias.
290
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
291
+
292
+ Returns:
293
+ Tuple of (kernel, bias) after fusing branches.
294
+ """
295
+ # get weights and bias of scale branch
296
+ kernel_scale = 0
297
+ bias_scale = 0
298
+ if self.rbr_scale is not None:
299
+ kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
300
+ # Pad scale branch kernel to match conv branch kernel size.
301
+ pad = self.kernel_size // 2
302
+ kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
303
+
304
+ # get weights and bias of skip branch
305
+ kernel_identity = 0
306
+ bias_identity = 0
307
+ if self.rbr_skip is not None:
308
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
309
+
310
+ # get weights and bias of conv branches
311
+ kernel_conv = 0
312
+ bias_conv = 0
313
+ if self.rbr_conv is not None:
314
+ for ix in range(self.num_conv_branches):
315
+ _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
316
+ kernel_conv += _kernel
317
+ bias_conv += _bias
318
+
319
+ kernel_final = kernel_conv + kernel_scale + kernel_identity
320
+ bias_final = bias_conv + bias_scale + bias_identity
321
+ return kernel_final, bias_final
322
+
323
+ def _fuse_bn_tensor(
324
+ self, branch: Union[nn.Sequential, nn.BatchNorm2d]
325
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
326
+ """Method to fuse batchnorm layer with preceeding conv layer.
327
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
328
+
329
+ Args:
330
+ branch: Sequence of ops to be fused.
331
+
332
+ Returns:
333
+ Tuple of (kernel, bias) after fusing batchnorm.
334
+ """
335
+ if isinstance(branch, nn.Sequential):
336
+ kernel = branch.conv.weight
337
+ running_mean = branch.bn.running_mean
338
+ running_var = branch.bn.running_var
339
+ gamma = branch.bn.weight
340
+ beta = branch.bn.bias
341
+ eps = branch.bn.eps
342
+ else:
343
+ assert isinstance(branch, nn.BatchNorm2d)
344
+ if not hasattr(self, "id_tensor"):
345
+ input_dim = self.in_channels // self.groups
346
+
347
+ kernel_size = self.kernel_size
348
+ if isinstance(self.kernel_size, int):
349
+ kernel_size = (self.kernel_size, self.kernel_size)
350
+
351
+ kernel_value = torch.zeros(
352
+ (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
353
+ dtype=branch.weight.dtype,
354
+ device=branch.weight.device,
355
+ )
356
+ for i in range(self.in_channels):
357
+ kernel_value[
358
+ i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
359
+ ] = 1
360
+ self.id_tensor = kernel_value
361
+ kernel = self.id_tensor
362
+ running_mean = branch.running_mean
363
+ running_var = branch.running_var
364
+ gamma = branch.weight
365
+ beta = branch.bias
366
+ eps = branch.eps
367
+ std = (running_var + eps).sqrt()
368
+ t = (gamma / std).reshape(-1, 1, 1, 1)
369
+ return kernel * t, beta - running_mean * gamma / std
370
+
371
+ def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
372
+ """Helper method to construct conv-batchnorm layers.
373
+
374
+ Args:
375
+ kernel_size: Size of the convolution kernel.
376
+ padding: Zero-padding size.
377
+
378
+ Returns:
379
+ Conv-BN module.
380
+ """
381
+ # Fallback, sometimes batchnorm tensors
382
+ # do not get instantiated correctly on some processes
383
+ # when using deepspeed + accelerate
384
+ norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
385
+ if norm_layer.weight.shape[0] == 0:
386
+ norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
387
+ if norm_layer.bias.shape[0] == 0:
388
+ norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
389
+
390
+ mod_list = nn.Sequential()
391
+ mod_list.add_module(
392
+ "conv",
393
+ nn.Conv2d(
394
+ in_channels=self.in_channels,
395
+ out_channels=self.out_channels,
396
+ kernel_size=kernel_size,
397
+ stride=self.stride,
398
+ padding=padding,
399
+ groups=self.groups,
400
+ bias=False,
401
+ ),
402
+ )
403
+ mod_list.add_module("bn", norm_layer)
404
+ return mod_list
405
+
406
+
407
+ class ReparamLargeKernelConv(nn.Module):
408
+ """Building Block of RepLKNet
409
+
410
+ This class defines overparameterized large kernel conv block
411
+ introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
412
+
413
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
414
+ """
415
+
416
+ def __init__(
417
+ self,
418
+ in_channels: int,
419
+ out_channels: int,
420
+ kernel_size: int,
421
+ stride: int,
422
+ groups: int,
423
+ small_kernel: int,
424
+ inference_mode: bool = False,
425
+ use_se: bool = False,
426
+ activation: nn.Module = nn.GELU(),
427
+ ) -> None:
428
+ """Construct a ReparamLargeKernelConv module.
429
+
430
+ Args:
431
+ in_channels: Number of input channels.
432
+ out_channels: Number of output channels.
433
+ kernel_size: Kernel size of the large kernel conv branch.
434
+ stride: Stride size. Default: 1
435
+ groups: Group number. Default: 1
436
+ small_kernel: Kernel size of small kernel conv branch.
437
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
438
+ activation: Activation module. Default: ``nn.GELU``
439
+ """
440
+ super(ReparamLargeKernelConv, self).__init__()
441
+
442
+ self.stride = stride
443
+ self.groups = groups
444
+ self.in_channels = in_channels
445
+ self.out_channels = out_channels
446
+ self.activation = activation
447
+
448
+ self.kernel_size = kernel_size
449
+ self.small_kernel = small_kernel
450
+ self.padding = kernel_size // 2
451
+
452
+ # Check if SE is requested
453
+ if use_se:
454
+ self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
455
+ else:
456
+ self.se = nn.Identity()
457
+
458
+ if inference_mode:
459
+ self.lkb_reparam = nn.Conv2d(
460
+ in_channels=in_channels,
461
+ out_channels=out_channels,
462
+ kernel_size=kernel_size,
463
+ stride=stride,
464
+ padding=self.padding,
465
+ dilation=1,
466
+ groups=groups,
467
+ bias=True,
468
+ )
469
+ else:
470
+ self.lkb_origin = self._conv_bn(
471
+ kernel_size=kernel_size, padding=self.padding
472
+ )
473
+ if small_kernel is not None:
474
+ assert (
475
+ small_kernel <= kernel_size
476
+ ), "The kernel size for re-param cannot be larger than the large kernel!"
477
+ self.small_conv = self._conv_bn(
478
+ kernel_size=small_kernel, padding=small_kernel // 2
479
+ )
480
+
481
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
482
+ """Apply forward pass."""
483
+ if hasattr(self, "lkb_reparam"):
484
+ out = self.lkb_reparam(x)
485
+ else:
486
+ out = self.lkb_origin(x)
487
+ if hasattr(self, "small_conv"):
488
+ out += self.small_conv(x)
489
+
490
+ return self.activation(self.se(out))
491
+
492
+ def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
493
+ """Method to obtain re-parameterized kernel and bias.
494
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
495
+
496
+ Returns:
497
+ Tuple of (kernel, bias) after fusing branches.
498
+ """
499
+ eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
500
+ if hasattr(self, "small_conv"):
501
+ small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
502
+ eq_b += small_b
503
+ eq_k += nn.functional.pad(
504
+ small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
505
+ )
506
+ return eq_k, eq_b
507
+
508
+ def reparameterize(self) -> None:
509
+ """
510
+ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
511
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
512
+ architecture used at training time to obtain a plain CNN-like structure
513
+ for inference.
514
+ """
515
+ eq_k, eq_b = self.get_kernel_bias()
516
+ self.lkb_reparam = nn.Conv2d(
517
+ in_channels=self.in_channels,
518
+ out_channels=self.out_channels,
519
+ kernel_size=self.kernel_size,
520
+ stride=self.stride,
521
+ padding=self.padding,
522
+ dilation=self.lkb_origin.conv.dilation,
523
+ groups=self.groups,
524
+ bias=True,
525
+ )
526
+
527
+ self.lkb_reparam.weight.data = eq_k
528
+ self.lkb_reparam.bias.data = eq_b
529
+ self.__delattr__("lkb_origin")
530
+ if hasattr(self, "small_conv"):
531
+ self.__delattr__("small_conv")
532
+
533
+ @staticmethod
534
+ def _fuse_bn(
535
+ conv: torch.Tensor, bn: nn.BatchNorm2d
536
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
537
+ """Method to fuse batchnorm layer with conv layer.
538
+
539
+ Args:
540
+ conv: Convolutional kernel weights.
541
+ bn: Batchnorm 2d layer.
542
+
543
+ Returns:
544
+ Tuple of (kernel, bias) after fusing batchnorm.
545
+ """
546
+ kernel = conv.weight
547
+ running_mean = bn.running_mean
548
+ running_var = bn.running_var
549
+ gamma = bn.weight
550
+ beta = bn.bias
551
+ eps = bn.eps
552
+ std = (running_var + eps).sqrt()
553
+ t = (gamma / std).reshape(-1, 1, 1, 1)
554
+ return kernel * t, beta - running_mean * gamma / std
555
+
556
+ def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
557
+ """Helper method to construct conv-batchnorm layers.
558
+
559
+ Args:
560
+ kernel_size: Size of the convolution kernel.
561
+ padding: Zero-padding size.
562
+
563
+ Returns:
564
+ A nn.Sequential Conv-BN module.
565
+ """
566
+ # Fallback, sometimes batchnorm tensors
567
+ # do not get instantiated correctly on some processes
568
+ # when using deepspeed + accelerate
569
+ norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
570
+ if norm_layer.weight.shape[0] == 0:
571
+ norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
572
+ if norm_layer.bias.shape[0] == 0:
573
+ norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
574
+
575
+ mod_list = nn.Sequential()
576
+ mod_list.add_module(
577
+ "conv",
578
+ nn.Conv2d(
579
+ in_channels=self.in_channels,
580
+ out_channels=self.out_channels,
581
+ kernel_size=kernel_size,
582
+ stride=self.stride,
583
+ padding=padding,
584
+ groups=self.groups,
585
+ bias=False,
586
+ ),
587
+ )
588
+ mod_list.add_module("bn", norm_layer)
589
+ return mod_list
590
+
591
+
592
+ def convolutional_stem(
593
+ in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True,
594
+ ) -> nn.Sequential:
595
+ """Build convolutional stem with MobileOne blocks.
596
+
597
+ Args:
598
+ in_channels: Number of input channels.
599
+ out_channels: Number of output channels.
600
+ inference_mode: Flag to instantiate model in inference mode. Default: ``False``
601
+
602
+ Returns:
603
+ nn.Sequential object with stem elements.
604
+ """
605
+ return nn.Sequential(
606
+ MobileOneBlock(
607
+ in_channels=in_channels,
608
+ out_channels=out_channels,
609
+ kernel_size=3,
610
+ stride=2,
611
+ padding=1,
612
+ groups=1,
613
+ inference_mode=inference_mode,
614
+ use_se=False,
615
+ num_conv_branches=1,
616
+ use_scale_branch=use_scale_branch
617
+ ),
618
+ MobileOneBlock(
619
+ in_channels=out_channels,
620
+ out_channels=out_channels,
621
+ kernel_size=3,
622
+ stride=2,
623
+ padding=1,
624
+ groups=out_channels,
625
+ inference_mode=inference_mode,
626
+ use_se=False,
627
+ num_conv_branches=1,
628
+ use_scale_branch=use_scale_branch
629
+ ),
630
+ MobileOneBlock(
631
+ in_channels=out_channels,
632
+ out_channels=out_channels,
633
+ kernel_size=1,
634
+ stride=1,
635
+ padding=0,
636
+ groups=1,
637
+ inference_mode=inference_mode,
638
+ use_se=False,
639
+ num_conv_branches=1,
640
+ use_scale_branch=use_scale_branch
641
+ ),
642
+ )
643
+
644
+
645
+ class LayerNormChannel(nn.Module):
646
+ """
647
+ LayerNorm only for Channel Dimension.
648
+ Input: tensor in shape [B, C, H, W]
649
+ """
650
+ def __init__(self, num_features, eps=1e-05) -> None:
651
+ super().__init__()
652
+ self.weight = nn.Parameter(torch.ones(num_features))
653
+ self.bias = nn.Parameter(torch.zeros(num_features))
654
+ self.eps = eps
655
+
656
+ def forward(self, x) -> torch.Tensor:
657
+ u = x.mean(1, keepdim=True)
658
+ s = (x - u).pow(2).mean(1, keepdim=True)
659
+ x = (x - u) / torch.sqrt(s + self.eps)
660
+ x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
661
+ + self.bias.unsqueeze(-1).unsqueeze(-1)
662
+ return x
663
+
664
+
665
+ class MHSA(nn.Module):
666
+ """Multi-headed Self Attention module.
667
+
668
+ Source modified from:
669
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
670
+ """
671
+
672
+ def __init__(
673
+ self,
674
+ dim: int,
675
+ head_dim: int = 32,
676
+ qkv_bias: bool = False,
677
+ attn_drop: float = 0.0,
678
+ proj_drop: float = 0.0,
679
+ ) -> None:
680
+ """Build MHSA module that can handle 3D or 4D input tensors.
681
+
682
+ Args:
683
+ dim: Number of embedding dimensions.
684
+ head_dim: Number of hidden dimensions per head. Default: ``32``
685
+ qkv_bias: Use bias or not. Default: ``False``
686
+ attn_drop: Dropout rate for attention tensor.
687
+ proj_drop: Dropout rate for projection tensor.
688
+ """
689
+ super().__init__()
690
+ assert dim % head_dim == 0, "dim should be divisible by head_dim"
691
+ self.head_dim = head_dim
692
+ self.num_heads = dim // head_dim
693
+ self.scale = head_dim**-0.5
694
+
695
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
696
+ self.attn_drop = nn.Dropout(attn_drop)
697
+ self.proj = nn.Linear(dim, dim)
698
+ self.proj_drop = nn.Dropout(proj_drop)
699
+
700
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
701
+ shape = x.shape
702
+ B, C, H, W = shape
703
+ N = H * W
704
+ if len(shape) == 4:
705
+ x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
706
+ qkv = (
707
+ self.qkv(x)
708
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
709
+ .permute(2, 0, 3, 1, 4)
710
+ )
711
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
712
+
713
+ # trick here to make q@k.t more stable
714
+ attn = (q * self.scale) @ k.transpose(-2, -1)
715
+ attn = attn.softmax(dim=-1)
716
+ attn = self.attn_drop(attn)
717
+
718
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
719
+ x = self.proj(x)
720
+ x = self.proj_drop(x)
721
+ if len(shape) == 4:
722
+ x = x.transpose(-2, -1).reshape(B, C, H, W)
723
+
724
+ return x
725
+
726
+
727
+ class PatchEmbed(nn.Module):
728
+ """Convolutional patch embedding layer."""
729
+
730
+ def __init__(
731
+ self,
732
+ patch_size: int,
733
+ stride: int,
734
+ in_channels: int,
735
+ embed_dim: int,
736
+ inference_mode: bool = False,
737
+ use_se: bool = False,
738
+ ) -> None:
739
+ """Build patch embedding layer.
740
+
741
+ Args:
742
+ patch_size: Patch size for embedding computation.
743
+ stride: Stride for convolutional embedding layer.
744
+ in_channels: Number of channels of input tensor.
745
+ embed_dim: Number of embedding dimensions.
746
+ inference_mode: Flag to instantiate model in inference mode. Default: ``False``
747
+ use_se: If ``True`` SE block will be used.
748
+ """
749
+ super().__init__()
750
+ block = list()
751
+ block.append(
752
+ ReparamLargeKernelConv(
753
+ in_channels=in_channels,
754
+ out_channels=embed_dim,
755
+ kernel_size=patch_size,
756
+ stride=stride,
757
+ groups=in_channels,
758
+ small_kernel=3,
759
+ inference_mode=inference_mode,
760
+ use_se=use_se,
761
+ )
762
+ )
763
+ block.append(
764
+ MobileOneBlock(
765
+ in_channels=embed_dim,
766
+ out_channels=embed_dim,
767
+ kernel_size=1,
768
+ stride=1,
769
+ padding=0,
770
+ groups=1,
771
+ inference_mode=inference_mode,
772
+ use_se=False,
773
+ num_conv_branches=1,
774
+ )
775
+ )
776
+ self.proj = nn.Sequential(*block)
777
+
778
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
779
+ x = self.proj(x)
780
+ return x
781
+
782
+
783
+ class RepMixer(nn.Module):
784
+ """Reparameterizable token mixer.
785
+
786
+ For more details, please refer to our paper:
787
+ `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
788
+ """
789
+
790
+ def __init__(
791
+ self,
792
+ dim,
793
+ kernel_size=3,
794
+ use_layer_scale=True,
795
+ layer_scale_init_value=1e-5,
796
+ inference_mode: bool = False,
797
+ ):
798
+ """Build RepMixer Module.
799
+
800
+ Args:
801
+ dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
802
+ kernel_size: Kernel size for spatial mixing. Default: 3
803
+ use_layer_scale: If True, learnable layer scale is used. Default: ``True``
804
+ layer_scale_init_value: Initial value for layer scale. Default: 1e-5
805
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
806
+ """
807
+ super().__init__()
808
+ self.dim = dim
809
+ self.kernel_size = kernel_size
810
+ self.inference_mode = inference_mode
811
+
812
+ if inference_mode:
813
+ self.reparam_conv = nn.Conv2d(
814
+ in_channels=self.dim,
815
+ out_channels=self.dim,
816
+ kernel_size=self.kernel_size,
817
+ stride=1,
818
+ padding=self.kernel_size // 2,
819
+ groups=self.dim,
820
+ bias=True,
821
+ )
822
+ else:
823
+ self.norm = MobileOneBlock(
824
+ dim,
825
+ dim,
826
+ kernel_size,
827
+ padding=kernel_size // 2,
828
+ groups=dim,
829
+ use_act=False,
830
+ use_scale_branch=False,
831
+ num_conv_branches=0,
832
+ )
833
+ self.mixer = MobileOneBlock(
834
+ dim,
835
+ dim,
836
+ kernel_size,
837
+ padding=kernel_size // 2,
838
+ groups=dim,
839
+ use_act=False,
840
+ )
841
+ self.use_layer_scale = use_layer_scale
842
+ if use_layer_scale:
843
+ self.layer_scale = nn.Parameter(
844
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
845
+ )
846
+
847
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
848
+ if hasattr(self, "reparam_conv"):
849
+ x = self.reparam_conv(x)
850
+ return x
851
+ else:
852
+ if self.use_layer_scale:
853
+ x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
854
+ else:
855
+ x = x + self.mixer(x) - self.norm(x)
856
+ return x
857
+
858
+ def reparameterize(self) -> None:
859
+ """Reparameterize mixer and norm into a single
860
+ convolutional layer for efficient inference.
861
+ """
862
+ if self.inference_mode:
863
+ return
864
+
865
+ self.mixer.reparameterize()
866
+ self.norm.reparameterize()
867
+
868
+ if self.use_layer_scale:
869
+ w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
870
+ self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
871
+ )
872
+ b = torch.squeeze(self.layer_scale) * (
873
+ self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
874
+ )
875
+ else:
876
+ w = (
877
+ self.mixer.id_tensor
878
+ + self.mixer.reparam_conv.weight
879
+ - self.norm.reparam_conv.weight
880
+ )
881
+ b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
882
+
883
+ self.reparam_conv = nn.Conv2d(
884
+ in_channels=self.dim,
885
+ out_channels=self.dim,
886
+ kernel_size=self.kernel_size,
887
+ stride=1,
888
+ padding=self.kernel_size // 2,
889
+ groups=self.dim,
890
+ bias=True,
891
+ )
892
+ self.reparam_conv.weight.data = w
893
+ self.reparam_conv.bias.data = b
894
+
895
+ self.__delattr__("mixer")
896
+ self.__delattr__("norm")
897
+ if self.use_layer_scale:
898
+ self.__delattr__("layer_scale")
899
+
900
+
901
+ class ConvFFN(nn.Module):
902
+ """Convolutional FFN Module."""
903
+
904
+ def __init__(
905
+ self,
906
+ in_channels: int,
907
+ hidden_channels: Optional[int] = None,
908
+ out_channels: Optional[int] = None,
909
+ act_layer: nn.Module = nn.GELU,
910
+ drop: float = 0.0,
911
+ ) -> None:
912
+ """Build convolutional FFN module.
913
+
914
+ Args:
915
+ in_channels: Number of input channels.
916
+ hidden_channels: Number of channels after expansion. Default: None
917
+ out_channels: Number of output channels. Default: None
918
+ act_layer: Activation layer. Default: ``GELU``
919
+ drop: Dropout rate. Default: ``0.0``.
920
+ """
921
+ super().__init__()
922
+ out_channels = out_channels or in_channels
923
+ hidden_channels = hidden_channels or in_channels
924
+ self.conv = nn.Sequential()
925
+ self.conv.add_module(
926
+ "conv",
927
+ nn.Conv2d(
928
+ in_channels=in_channels,
929
+ out_channels=out_channels,
930
+ kernel_size=7,
931
+ padding=3,
932
+ groups=in_channels,
933
+ bias=False,
934
+ ),
935
+ )
936
+
937
+ # Fallback, sometimes batchnorm tensors
938
+ # do not get instantiated correctly on some processes
939
+ # when using deepspeed + accelerate
940
+ norm_layer = nn.BatchNorm2d(num_features=out_channels)
941
+ if norm_layer.weight.shape[0] == 0:
942
+ norm_layer.weight = nn.Parameter(torch.zeros(out_channels))
943
+ if norm_layer.bias.shape[0] == 0:
944
+ norm_layer.bias = nn.Parameter(torch.zeros(out_channels))
945
+
946
+ self.conv.add_module("bn", norm_layer)
947
+ self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
948
+ self.act = act_layer()
949
+ self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
950
+ self.drop = nn.Dropout(drop)
951
+ self.apply(self._init_weights)
952
+
953
+ def _init_weights(self, m: nn.Module) -> None:
954
+ if isinstance(m, nn.Conv2d):
955
+ normal_(m.weight, std=0.02)
956
+ if m.bias is not None:
957
+ nn.init.constant_(m.bias, 0)
958
+
959
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
960
+ x = self.conv(x)
961
+ x = self.fc1(x)
962
+ x = self.act(x)
963
+ x = self.drop(x)
964
+ x = self.fc2(x)
965
+ x = self.drop(x)
966
+ return x
967
+
968
+
969
+ class RepCPE(nn.Module):
970
+ """Implementation of conditional positional encoding.
971
+
972
+ For more details refer to paper:
973
+ `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
974
+
975
+ In our implementation, we can reparameterize this module to eliminate a skip connection.
976
+ """
977
+
978
+ def __init__(
979
+ self,
980
+ in_channels: int,
981
+ embed_dim: int = 768,
982
+ spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
983
+ inference_mode=False,
984
+ ) -> None:
985
+ """Build reparameterizable conditional positional encoding
986
+
987
+ Args:
988
+ in_channels: Number of input channels.
989
+ embed_dim: Number of embedding dimensions. Default: 768
990
+ spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
991
+ inference_mode: Flag to instantiate block in inference mode. Default: ``False``
992
+ """
993
+ super(RepCPE, self).__init__()
994
+ if isinstance(spatial_shape, int):
995
+ spatial_shape = tuple([spatial_shape] * 2)
996
+ assert isinstance(spatial_shape, Tuple), (
997
+ f'"spatial_shape" must by a sequence or int, '
998
+ f"get {type(spatial_shape)} instead."
999
+ )
1000
+ assert len(spatial_shape) == 2, (
1001
+ f'Length of "spatial_shape" should be 2, '
1002
+ f"got {len(spatial_shape)} instead."
1003
+ )
1004
+
1005
+ self.spatial_shape = spatial_shape
1006
+ self.embed_dim = embed_dim
1007
+ self.in_channels = in_channels
1008
+ self.groups = embed_dim
1009
+
1010
+ if inference_mode:
1011
+ self.reparam_conv = nn.Conv2d(
1012
+ in_channels=self.in_channels,
1013
+ out_channels=self.embed_dim,
1014
+ kernel_size=self.spatial_shape,
1015
+ stride=1,
1016
+ padding=int(self.spatial_shape[0] // 2),
1017
+ groups=self.embed_dim,
1018
+ bias=True,
1019
+ )
1020
+ else:
1021
+ self.pe = nn.Conv2d(
1022
+ in_channels,
1023
+ embed_dim,
1024
+ spatial_shape,
1025
+ 1,
1026
+ int(spatial_shape[0] // 2),
1027
+ bias=True,
1028
+ groups=embed_dim,
1029
+ )
1030
+
1031
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1032
+ if hasattr(self, "reparam_conv"):
1033
+ x = self.reparam_conv(x)
1034
+ return x
1035
+ else:
1036
+ x = self.pe(x) + x
1037
+ return x
1038
+
1039
+ def reparameterize(self) -> None:
1040
+ # Build equivalent Id tensor
1041
+ input_dim = self.in_channels // self.groups
1042
+ kernel_value = torch.zeros(
1043
+ (
1044
+ self.in_channels,
1045
+ input_dim,
1046
+ self.spatial_shape[0],
1047
+ self.spatial_shape[1],
1048
+ ),
1049
+ dtype=self.pe.weight.dtype,
1050
+ device=self.pe.weight.device,
1051
+ )
1052
+ for i in range(self.in_channels):
1053
+ kernel_value[
1054
+ i,
1055
+ i % input_dim,
1056
+ self.spatial_shape[0] // 2,
1057
+ self.spatial_shape[1] // 2,
1058
+ ] = 1
1059
+ id_tensor = kernel_value
1060
+
1061
+ # Reparameterize Id tensor and conv
1062
+ w_final = id_tensor + self.pe.weight
1063
+ b_final = self.pe.bias
1064
+
1065
+ # Introduce reparam conv
1066
+ self.reparam_conv = nn.Conv2d(
1067
+ in_channels=self.in_channels,
1068
+ out_channels=self.embed_dim,
1069
+ kernel_size=self.spatial_shape,
1070
+ stride=1,
1071
+ padding=int(self.spatial_shape[0] // 2),
1072
+ groups=self.embed_dim,
1073
+ bias=True,
1074
+ )
1075
+ self.reparam_conv.weight.data = w_final
1076
+ self.reparam_conv.bias.data = b_final
1077
+
1078
+ self.__delattr__("pe")
1079
+
1080
+
1081
+ class RepMixerBlock(nn.Module):
1082
+ """Implementation of Metaformer block with RepMixer as token mixer.
1083
+
1084
+ For more details on Metaformer structure, please refer to:
1085
+ `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1086
+ """
1087
+
1088
+ def __init__(
1089
+ self,
1090
+ dim: int,
1091
+ kernel_size: int = 3,
1092
+ mlp_ratio: float = 4.0,
1093
+ act_layer: nn.Module = nn.GELU,
1094
+ drop: float = 0.0,
1095
+ drop_path: float = 0.0,
1096
+ use_layer_scale: bool = True,
1097
+ layer_scale_init_value: float = 1e-5,
1098
+ inference_mode: bool = False,
1099
+ ):
1100
+ """Build RepMixer Block.
1101
+
1102
+ Args:
1103
+ dim: Number of embedding dimensions.
1104
+ kernel_size: Kernel size for repmixer. Default: 3
1105
+ mlp_ratio: MLP expansion ratio. Default: 4.0
1106
+ act_layer: Activation layer. Default: ``nn.GELU``
1107
+ drop: Dropout rate. Default: 0.0
1108
+ drop_path: Drop path rate. Default: 0.0
1109
+ use_layer_scale: Flag to turn on layer scale. Default: ``True``
1110
+ layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1111
+ inference_mode: Flag to instantiate block in inference mode. Default: ``False``
1112
+ """
1113
+
1114
+ super().__init__()
1115
+
1116
+ self.token_mixer = RepMixer(
1117
+ dim,
1118
+ kernel_size=kernel_size,
1119
+ use_layer_scale=use_layer_scale,
1120
+ layer_scale_init_value=layer_scale_init_value,
1121
+ inference_mode=inference_mode,
1122
+ )
1123
+
1124
+ assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1125
+ mlp_ratio
1126
+ )
1127
+ mlp_hidden_dim = int(dim * mlp_ratio)
1128
+ self.convffn = ConvFFN(
1129
+ in_channels=dim,
1130
+ hidden_channels=mlp_hidden_dim,
1131
+ act_layer=act_layer,
1132
+ drop=drop,
1133
+ )
1134
+
1135
+ # Drop Path
1136
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1137
+
1138
+ # Layer Scale
1139
+ self.use_layer_scale = use_layer_scale
1140
+ if use_layer_scale:
1141
+ self.layer_scale = nn.Parameter(
1142
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1143
+ )
1144
+
1145
+ def forward(self, x):
1146
+ if self.use_layer_scale:
1147
+ x = self.token_mixer(x)
1148
+ x = x + self.drop_path(self.layer_scale * self.convffn(x))
1149
+ else:
1150
+ x = self.token_mixer(x)
1151
+ x = x + self.drop_path(self.convffn(x))
1152
+ return x
1153
+
1154
+
1155
+ class AttentionBlock(nn.Module):
1156
+ """Implementation of metaformer block with MHSA as token mixer.
1157
+
1158
+ For more details on Metaformer structure, please refer to:
1159
+ `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1160
+ """
1161
+
1162
+ def __init__(
1163
+ self,
1164
+ dim: int,
1165
+ mlp_ratio: float = 4.0,
1166
+ act_layer: nn.Module = nn.GELU,
1167
+ norm_layer: nn.Module = nn.BatchNorm2d,
1168
+ drop: float = 0.0,
1169
+ drop_path: float = 0.0,
1170
+ use_layer_scale: bool = True,
1171
+ layer_scale_init_value: float = 1e-5,
1172
+ ):
1173
+ """Build Attention Block.
1174
+
1175
+ Args:
1176
+ dim: Number of embedding dimensions.
1177
+ mlp_ratio: MLP expansion ratio. Default: 4.0
1178
+ act_layer: Activation layer. Default: ``nn.GELU``
1179
+ norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
1180
+ drop: Dropout rate. Default: 0.0
1181
+ drop_path: Drop path rate. Default: 0.0
1182
+ use_layer_scale: Flag to turn on layer scale. Default: ``True``
1183
+ layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1184
+ """
1185
+
1186
+ super().__init__()
1187
+
1188
+ # Fallback, sometimes batchnorm tensors
1189
+ # do not get instantiated correctly on some processes
1190
+ # when using deepspeed + accelerate
1191
+ norm_layer_ = norm_layer(num_features=dim)
1192
+ if norm_layer_.weight.shape[0] == 0:
1193
+ norm_layer_.weight = nn.Parameter(torch.zeros(dim))
1194
+ if norm_layer_.bias.shape[0] == 0:
1195
+ norm_layer_.bias = nn.Parameter(torch.zeros(dim))
1196
+
1197
+ self.norm = norm_layer_
1198
+ self.token_mixer = MHSA(dim=dim)
1199
+
1200
+ assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1201
+ mlp_ratio
1202
+ )
1203
+ mlp_hidden_dim = int(dim * mlp_ratio)
1204
+ self.convffn = ConvFFN(
1205
+ in_channels=dim,
1206
+ hidden_channels=mlp_hidden_dim,
1207
+ act_layer=act_layer,
1208
+ drop=drop,
1209
+ )
1210
+
1211
+ # Drop path
1212
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1213
+
1214
+ # Layer Scale
1215
+ self.use_layer_scale = use_layer_scale
1216
+ if use_layer_scale:
1217
+ self.layer_scale_1 = nn.Parameter(
1218
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1219
+ )
1220
+ self.layer_scale_2 = nn.Parameter(
1221
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1222
+ )
1223
+
1224
+ def forward(self, x):
1225
+ if self.use_layer_scale:
1226
+ x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
1227
+ x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
1228
+ else:
1229
+ x = x + self.drop_path(self.token_mixer(self.norm(x)))
1230
+ x = x + self.drop_path(self.convffn(x))
1231
+ return x
1232
+
1233
+
1234
+ def basic_blocks(
1235
+ dim: int,
1236
+ block_index: int,
1237
+ num_blocks: List[int],
1238
+ token_mixer_type: str,
1239
+ kernel_size: int = 3,
1240
+ mlp_ratio: float = 4.0,
1241
+ act_layer: nn.Module = nn.GELU,
1242
+ norm_layer: nn.Module = nn.BatchNorm2d,
1243
+ drop_rate: float = 0.0,
1244
+ drop_path_rate: float = 0.0,
1245
+ use_layer_scale: bool = True,
1246
+ layer_scale_init_value: float = 1e-5,
1247
+ inference_mode=False,
1248
+ ) -> nn.Sequential:
1249
+ """Build FastViT blocks within a stage.
1250
+
1251
+ Args:
1252
+ dim: Number of embedding dimensions.
1253
+ block_index: block index.
1254
+ num_blocks: List containing number of blocks per stage.
1255
+ token_mixer_type: Token mixer type.
1256
+ kernel_size: Kernel size for repmixer.
1257
+ mlp_ratio: MLP expansion ratio.
1258
+ act_layer: Activation layer.
1259
+ norm_layer: Normalization layer.
1260
+ drop_rate: Dropout rate.
1261
+ drop_path_rate: Drop path rate.
1262
+ use_layer_scale: Flag to turn on layer scale regularization.
1263
+ layer_scale_init_value: Layer scale value at initialization.
1264
+ inference_mode: Flag to instantiate block in inference mode.
1265
+
1266
+ Returns:
1267
+ nn.Sequential object of all the blocks within the stage.
1268
+ """
1269
+ blocks = []
1270
+ for block_idx in range(num_blocks[block_index]):
1271
+ block_dpr = (
1272
+ drop_path_rate
1273
+ * (block_idx + sum(num_blocks[:block_index]))
1274
+ / (sum(num_blocks) - 1)
1275
+ )
1276
+ if token_mixer_type == "repmixer":
1277
+ blocks.append(
1278
+ RepMixerBlock(
1279
+ dim,
1280
+ kernel_size=kernel_size,
1281
+ mlp_ratio=mlp_ratio,
1282
+ act_layer=act_layer,
1283
+ drop=drop_rate,
1284
+ drop_path=block_dpr,
1285
+ use_layer_scale=use_layer_scale,
1286
+ layer_scale_init_value=layer_scale_init_value,
1287
+ inference_mode=inference_mode,
1288
+ )
1289
+ )
1290
+ elif token_mixer_type == "attention":
1291
+ blocks.append(
1292
+ AttentionBlock(
1293
+ dim,
1294
+ mlp_ratio=mlp_ratio,
1295
+ act_layer=act_layer,
1296
+ norm_layer=norm_layer,
1297
+ drop=drop_rate,
1298
+ drop_path=block_dpr,
1299
+ use_layer_scale=use_layer_scale,
1300
+ layer_scale_init_value=layer_scale_init_value,
1301
+ )
1302
+ )
1303
+ else:
1304
+ raise ValueError(
1305
+ "Token mixer type: {} not supported".format(token_mixer_type)
1306
+ )
1307
+ blocks = nn.Sequential(*blocks)
1308
+ return blocks
1309
+
1310
+
1311
+ class GlobalPool2D(nn.Module):
1312
+ """This class implements global pooling with linear projection."""
1313
+
1314
+ def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
1315
+ super().__init__()
1316
+ scale = in_dim**-0.5
1317
+ self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
1318
+ self.in_dim = in_dim
1319
+ self.out_dim = out_dim
1320
+
1321
+ def pool(self, x) -> Tensor:
1322
+ if x.dim() == 4:
1323
+ dims = [-2, -1]
1324
+ elif x.dim() == 5:
1325
+ dims = [-3, -2, -1]
1326
+ x = torch.mean(x, dim=dims, keepdim=False)
1327
+ return x
1328
+
1329
+ def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
1330
+ # x is of shape [batch, in_dim]
1331
+ assert (
1332
+ x.dim() == 4
1333
+ ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
1334
+ x.shape
1335
+ )
1336
+
1337
+ # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
1338
+ x = self.pool(x)
1339
+ # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
1340
+ x = x @ self.proj
1341
+ return x
1342
+
1343
+
1344
+ class FastViT(nn.Module):
1345
+ """
1346
+ This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
1347
+ """
1348
+
1349
+ def __init__(
1350
+ self,
1351
+ layers,
1352
+ token_mixers: Tuple[str, ...],
1353
+ embed_dims=None,
1354
+ mlp_ratios=None,
1355
+ downsamples=None,
1356
+ se_downsamples=None,
1357
+ repmixer_kernel_size=3,
1358
+ norm_layer: nn.Module = nn.BatchNorm2d,
1359
+ act_layer: nn.Module = nn.GELU,
1360
+ num_classes=1000,
1361
+ pos_embs=None,
1362
+ down_patch_size=7,
1363
+ down_stride=2,
1364
+ drop_rate=0.0,
1365
+ drop_path_rate=0.0,
1366
+ use_layer_scale=True,
1367
+ layer_scale_init_value=1e-5,
1368
+ init_cfg=None,
1369
+ pretrained=None,
1370
+ cls_ratio=2.0,
1371
+ inference_mode=False,
1372
+ stem_scale_branch=True,
1373
+ **kwargs,
1374
+ ) -> None:
1375
+
1376
+ super().__init__()
1377
+
1378
+ self.num_classes = num_classes
1379
+ if len(layers) == 4:
1380
+ self.out_indices = [0, 2, 4, 7]
1381
+ elif len(layers) == 5:
1382
+ self.out_indices = [0, 2, 4, 7, 10]
1383
+ else:
1384
+ raise NotImplementedError("FPN is not implemented for more than 5 stages.")
1385
+
1386
+ if pos_embs is None:
1387
+ pos_embs = [None] * len(layers)
1388
+
1389
+ if se_downsamples is None:
1390
+ se_downsamples = [False] * len(layers)
1391
+
1392
+ # Convolutional stem
1393
+ self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode,
1394
+ use_scale_branch=stem_scale_branch)
1395
+
1396
+ # Build the main stages of the network architecture
1397
+ network = []
1398
+ for i in range(len(layers)):
1399
+ # Add position embeddings if requested
1400
+ if pos_embs[i] is not None:
1401
+ network.append(
1402
+ pos_embs[i](
1403
+ embed_dims[i], embed_dims[i], inference_mode=inference_mode
1404
+ )
1405
+ )
1406
+ stage = basic_blocks(
1407
+ embed_dims[i],
1408
+ i,
1409
+ layers,
1410
+ token_mixer_type=token_mixers[i],
1411
+ kernel_size=repmixer_kernel_size,
1412
+ mlp_ratio=mlp_ratios[i],
1413
+ act_layer=act_layer,
1414
+ norm_layer=norm_layer,
1415
+ drop_rate=drop_rate,
1416
+ drop_path_rate=drop_path_rate,
1417
+ use_layer_scale=use_layer_scale,
1418
+ layer_scale_init_value=layer_scale_init_value,
1419
+ inference_mode=inference_mode,
1420
+ )
1421
+ network.append(stage)
1422
+ if i >= len(layers) - 1:
1423
+ break
1424
+
1425
+ # Patch merging/downsampling between stages.
1426
+ if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
1427
+ network.append(
1428
+ PatchEmbed(
1429
+ patch_size=down_patch_size,
1430
+ stride=down_stride,
1431
+ in_channels=embed_dims[i],
1432
+ embed_dim=embed_dims[i + 1],
1433
+ inference_mode=inference_mode,
1434
+ use_se=se_downsamples[i + 1],
1435
+ )
1436
+ )
1437
+ self.network = nn.ModuleList(network)
1438
+
1439
+ # Classifier head
1440
+ self.conv_exp = MobileOneBlock(
1441
+ in_channels=embed_dims[-1],
1442
+ out_channels=int(embed_dims[-1] * cls_ratio),
1443
+ kernel_size=3,
1444
+ stride=1,
1445
+ padding=1,
1446
+ groups=embed_dims[-1],
1447
+ inference_mode=inference_mode,
1448
+ use_se=True,
1449
+ num_conv_branches=1,
1450
+ )
1451
+ self.head = (
1452
+ nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
1453
+ if num_classes > 0
1454
+ else nn.Identity()
1455
+ )
1456
+ self.apply(self.cls_init_weights)
1457
+ self.init_cfg = copy.deepcopy(init_cfg)
1458
+
1459
+ def cls_init_weights(self, m: nn.Module) -> None:
1460
+ """Init. for classification"""
1461
+ if isinstance(m, nn.Linear):
1462
+ normal_(m.weight, std=0.02)
1463
+ if isinstance(m, nn.Linear) and m.bias is not None:
1464
+ nn.init.constant_(m.bias, 0)
1465
+
1466
+ def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
1467
+ x = self.patch_embed(x)
1468
+ return x
1469
+
1470
+ def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1471
+ for idx, block in enumerate(self.network):
1472
+ x = block(x)
1473
+ return x
1474
+
1475
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
1476
+ # input embedding
1477
+ x = self.forward_embeddings(x)
1478
+ # through backbone
1479
+ x = self.forward_tokens(x)
1480
+ # for image classification/embedding
1481
+ x = self.conv_exp(x)
1482
+ cls_out = self.head(x)
1483
+
1484
+ out_dict = dict()
1485
+ if kwargs.get("return_image_embeddings", False):
1486
+ out_dict.update({"logits": cls_out})
1487
+ out_dict.update({"image_embeddings": x})
1488
+ return out_dict
1489
+ else:
1490
+ return cls_out
1491
+
1492
+
1493
+ @register_model
1494
+ def fastvithd(pretrained=False, **kwargs):
1495
+ """Instantiate FastViTHD model variant."""
1496
+ layers = [2, 12, 24, 4, 2]
1497
+ embed_dims = [96, 192, 384, 768, 1536]
1498
+ mlp_ratios = [4, 4, 4, 4, 4]
1499
+ downsamples = [True, True, True, True, True]
1500
+ pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))]
1501
+ token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
1502
+ model = FastViT(
1503
+ layers,
1504
+ token_mixers=token_mixers,
1505
+ embed_dims=embed_dims,
1506
+ pos_embs=pos_embs,
1507
+ mlp_ratios=mlp_ratios,
1508
+ downsamples=downsamples,
1509
+ norm_layer=LayerNormChannel,
1510
+ stem_scale_branch=False,
1511
+ inference_mode=True,
1512
+ **kwargs,
1513
+ )
1514
+ model.default_cfg = default_cfgs["fastvit_m"]
1515
+ if pretrained:
1516
+ raise ValueError("Functionality not implemented.")
1517
+ return model
1518
+
1519
+ def load_model_config(
1520
+ model_name: str,
1521
+ ) -> Any:
1522
+ model_cfg = {
1523
+ "embed_dim": 768,
1524
+ "image_cfg": {
1525
+ "image_size": 1024,
1526
+ "model_name": "fastvithd",
1527
+ "embed_dim": 3072,
1528
+ "patch_size": 64
1529
+ },
1530
+ "text_cfg": {
1531
+ "context_length": 77,
1532
+ "vocab_size": 49408,
1533
+ "dim": 768,
1534
+ "ffn_multiplier_per_layer": 4.0,
1535
+ "n_heads_per_layer": 12,
1536
+ "n_transformer_layers": 12,
1537
+ "norm_layer": "layer_norm_fp32",
1538
+ "causal_masking": False,
1539
+ "model_name": "base"
1540
+ }
1541
+ }
1542
+ return model_cfg
1543
+
1544
+
+ class MCi(nn.Module):
+     """
+     This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_
+     """
+
+     def __init__(self, model_name: str, *args, **kwargs) -> None:
+         super().__init__()
+         self.projection_dim = None
+         if "projection_dim" in kwargs:
+             self.projection_dim = kwargs.get("projection_dim")
+
+         # Create model
+         self.model = create_model(model_name, projection_dim=self.projection_dim)
+
+         # Build out projection head.
+         if self.projection_dim is not None:
+             if hasattr(self.model, "head"):
+                 self.model.head = MCi._update_image_classifier(
+                     image_classifier=self.model.head, projection_dim=self.projection_dim
+                 )
+
+     def forward(self, x: Any, *args, **kwargs) -> Any:
+         """A forward function of the model."""
+         x = self.model(x, *args, **kwargs)
+         return x
+
+     @staticmethod
+     def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
+         """Return the input feature dimension to the image classification head."""
+         in_features = None
+         if isinstance(image_classifier, nn.Sequential):
+             # Classifier that uses nn.Sequential usually has global pooling and
+             # multiple linear layers. Find the first linear layer and get its
+             # in_features
+             for layer in image_classifier:
+                 if isinstance(layer, nn.Linear):
+                     in_features = layer.in_features
+                     break
+         elif isinstance(image_classifier, nn.Linear):
+             in_features = image_classifier.in_features
+
+         if in_features is None:
+             raise NotImplementedError(
+                 f"Cannot get input feature dimension of {image_classifier}."
+             )
+         return in_features
+
+     @staticmethod
+     def _update_image_classifier(
+         image_classifier: nn.Module, projection_dim: int, *args, **kwargs
+     ) -> nn.Module:
+         in_features = MCi._get_in_feature_dimension(image_classifier)
+         new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
+         return new_img_classifier
+
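+ # Illustrative sketch of how this encoder is built and queried (mirrors
+ # `MobileCLIPVisionTower.load_model` / `forward_images` below); `pixel_values` is
+ # assumed to be a (B, 3, H, W) float tensor:
+ #
+ #     encoder = MCi(model_name="fastvithd", projection_dim=768)
+ #     out = encoder(pixel_values, return_image_embeddings=True)
+ #     feats = out["image_embeddings"]  # 4D feature map (B, C, H', W')
+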
+
+ class MobileCLIPVisionTower(nn.Module):
+     def __init__(self, vision_tower, args, delay_load=False):
+         super().__init__()
+
+         self.is_loaded = False
+         self.vision_tower_name = vision_tower
+         self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
+         self.input_image_size = int(vision_tower.split("_")[-1])
+
+         # Delay load is disabled for now
+         if not delay_load:
+             self.load_model()
+         elif getattr(args, 'unfreeze_mm_vision_tower', False):
+             self.load_model()
+         else:
+             model_cfg = load_model_config(self.vision_tower_name)
+             self.cfg_only = model_cfg
+
+     def load_model(self, device_map=None):
+         if self.is_loaded:
+             print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+             return
+
+         # Load model config
+         model_cfg = load_model_config(self.vision_tower_name)
+
+         # Override default image resolution
+         model_cfg["image_cfg"]["image_size"] = self.input_image_size
+
+         self.cfg_only = model_cfg
+
+         # Build HF CLIPImageProcessor with MobileCLIP parameters
+         self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
+                                                              "width": model_cfg["image_cfg"]["image_size"]},
+                                                   image_mean=[0.0, 0.0, 0.0],
+                                                   image_std=[1.0, 1.0, 1.0],
+                                                   size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})
+
+         # Instantiate the image encoder
+         self.vision_tower = MCi(model_name=model_cfg["image_cfg"]["model_name"],
+                                 projection_dim=model_cfg["embed_dim"])
+
+         if not self.tune_vision_tower:
+             self.vision_tower.requires_grad_(False)
+
+         self.is_loaded = True
+
+     def feature_select(self, image_forward_outs):
+         # Features from penultimate layer
+         image_features = image_forward_outs["image_embeddings"]
+
+         # Reshape 4D tensor to 3D
+         B, C, H, W = image_features.shape
+         image_features = image_features.reshape(B, C, H*W)
+         image_features = image_features.transpose(1, 2)
+         return image_features
+
+     def forward(self, images):
+         if self.tune_vision_tower:
+             return self.forward_images(images)
+         else:
+             with torch.no_grad():
+                 return self.forward_images(images)
+
+     def forward_images(self, images):
+         if type(images) is list:
+             image_features = []
+             for image in images:
+                 image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
+                 image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                 image_features.append(image_feature)
+         else:
+             image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
+             image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+         return image_features
+
+     @property
+     def dummy_feature(self):
+         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+     @property
+     def dtype(self):
+         return next(self.vision_tower.parameters()).dtype
+
+     @property
+     def device(self):
+         return next(self.vision_tower.parameters()).device
+
+     @property
+     def config(self):
+         return self.cfg_only
+
+     @property
+     def hidden_size(self):
+         return self.config["image_cfg"]["embed_dim"]
+
+     @property
+     def num_patches_per_side(self):
+         return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]
+
+     @property
+     def num_patches(self):
+         return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2
+
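+ # Shape sketch for the tower above with the default 1024px config: `forward`
+ # returns (B, num_patches, hidden_size) = (B, 256, 3072), since `feature_select`
+ # flattens the backbone's (B, C, H, W) feature map to (B, H*W, C). Illustrative
+ # usage, where "fastvithd_1024" is a placeholder name whose suffix encodes the
+ # input resolution and `args` is any namespace with `unfreeze_mm_vision_tower`:
+ #
+ #     tower = MobileCLIPVisionTower("fastvithd_1024", args=args)
+ #     px = tower.image_processor(images=img, return_tensors="pt")["pixel_values"]
+ #     feats = tower(px.to(device=tower.device, dtype=tower.dtype))  # (1, 256, 3072)
+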
+ class IdentityMap(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, *args, **kwargs):
+         return x
+
+     @property
+     def config(self):
+         return {"mm_projector_type": 'identity'}
+
+ def build_vision_projector(config, delay_load=False, **kwargs):
+     projector_type = getattr(config, 'mm_projector_type', 'linear')
+
+     if projector_type == 'linear':
+         return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+     mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+     if mlp_gelu_match:
+         mlp_depth = int(mlp_gelu_match.group(1))
+         modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+         for _ in range(1, mlp_depth):
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+         return nn.Sequential(*modules)
+
+     if projector_type == 'identity':
+         return IdentityMap()
+
+     raise ValueError(f'Unknown projector type: {projector_type}')
+
+ def build_vision_tower(vision_tower_cfg, **kwargs):
+     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
+     return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+
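+ # Example: with `mm_hidden_size = 3072` (the tower's embed dim) and, say,
+ # `mm_projector_type = "mlp2x_gelu"`, the regex branch above builds
+ # Linear(3072 -> hidden_size) -> GELU -> Linear(hidden_size -> hidden_size);
+ # the default "linear" type is a single Linear(3072 -> hidden_size).
+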
+ class LlavaMetaModel:
+
+     def __init__(self, config):
+         super(LlavaMetaModel, self).__init__(config)
+
+         if hasattr(config, "mm_vision_tower"):
+             self.vision_tower = build_vision_tower(config, delay_load=True)
+             self.mm_projector = build_vision_projector(config)
+
+             if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
+                 self.image_newline = nn.Parameter(
+                     torch.empty(config.hidden_size, dtype=self.dtype)
+                 )
+
+     def get_vision_tower(self):
+         vision_tower = getattr(self, 'vision_tower', None)
+         if type(vision_tower) is list:
+             vision_tower = vision_tower[0]
+         return vision_tower
+
+     def initialize_vision_modules(self, model_args, fsdp=None):
+         vision_tower = model_args.vision_tower
+         mm_vision_select_layer = model_args.mm_vision_select_layer
+         mm_vision_select_feature = model_args.mm_vision_select_feature
+         pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+         mm_patch_merge_type = model_args.mm_patch_merge_type
+
+         self.config.mm_vision_tower = vision_tower
+
+         if self.get_vision_tower() is None:
+             vision_tower = build_vision_tower(model_args)
+
+             if fsdp is not None and len(fsdp) > 0:
+                 self.vision_tower = [vision_tower]
+             else:
+                 self.vision_tower = vision_tower
+         else:
+             if fsdp is not None and len(fsdp) > 0:
+                 vision_tower = self.vision_tower[0]
+             else:
+                 vision_tower = self.vision_tower
+             vision_tower.load_model()
+
+         self.config.use_mm_proj = True
+         self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
+         self.config.mm_hidden_size = vision_tower.hidden_size
+         self.config.mm_vision_select_layer = mm_vision_select_layer
+         self.config.mm_vision_select_feature = mm_vision_select_feature
+         self.config.mm_patch_merge_type = mm_patch_merge_type
+
+         if getattr(self, 'mm_projector', None) is None:
+             self.mm_projector = build_vision_projector(self.config)
+
+             if 'unpad' in mm_patch_merge_type:
+                 embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+                 self.image_newline = nn.Parameter(
+                     torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
+                 )
+         else:
+             # In case it is frozen by LoRA
+             for p in self.mm_projector.parameters():
+                 p.requires_grad = True
+
+         if pretrain_mm_mlp_adapter is not None:
+             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+
+             def get_w(weights, keyword):
+                 return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
+
+             self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
+
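+ # Note on `get_w` above: it strips everything up to and including "mm_projector.",
+ # e.g. {"model.mm_projector.0.weight": w} -> {"0.weight": w}, so a pretrained
+ # adapter checkpoint can be loaded straight into `self.mm_projector`.
+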
+ def select_best_resolution(original_size, possible_resolutions):
+     """
+     Selects the best resolution from a list of possible resolutions based on the original size.
+
+     Args:
+         original_size (tuple): The original size of the image in the format (width, height).
+         possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+     Returns:
+         tuple: The best fit resolution in the format (width, height).
+     """
+     original_width, original_height = original_size
+     best_fit = None
+     max_effective_resolution = 0
+     min_wasted_resolution = float('inf')
+
+     for width, height in possible_resolutions:
+         scale = min(width / original_width, height / original_height)
+         downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+         effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+         wasted_resolution = (width * height) - effective_resolution
+
+         if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+             max_effective_resolution = effective_resolution
+             min_wasted_resolution = wasted_resolution
+             best_fit = (width, height)
+
+     return best_fit
+
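+ # Worked example: select_best_resolution((800, 600), [(1024, 1024), (2048, 1024)])
+ # returns (1024, 1024). Both candidates cover the full image (effective resolution
+ # 800 * 600 = 480000), but the first one wastes fewer pixels
+ # (1024 * 1024 - 480000 < 2048 * 1024 - 480000).
+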
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+     """
+     Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+     Args:
+         image_size (tuple): The size of the input image in the format (width, height).
+         grid_pinpoints (str): A string representation of a list of possible resolutions.
+         patch_size (int): The size of each image patch.
+
+     Returns:
+         tuple: The shape of the image patch grid in the format (width, height).
+     """
+     import ast
+     if type(grid_pinpoints) is list:
+         possible_resolutions = grid_pinpoints
+     else:
+         possible_resolutions = ast.literal_eval(grid_pinpoints)
+     width, height = select_best_resolution(image_size, possible_resolutions)
+     return width // patch_size, height // patch_size
+
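+ # Worked example: get_anyres_image_grid_shape((2048, 1024), [(1024, 1024), (2048, 1024)], 1024)
+ # selects (2048, 1024) as the best resolution and returns (2, 1), i.e. a 2 x 1 grid
+ # of tiles, where `patch_size` here is the vision tower's input resolution.
+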
+ class LlavaMetaForCausalLM(ABC):
+
+     @abstractmethod
+     def get_model(self):
+         pass
+
+     def get_vision_tower(self):
+         return self.get_model().get_vision_tower()
+
+     def encode_images(self, images):
+         image_features = self.get_model().get_vision_tower()(images)
+         image_features = self.get_model().mm_projector(image_features)
+         return image_features
+
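+     # Shape sketch for `encode_images`, assuming the 1024px FastViTHD tower:
+     # images (B, 3, 1024, 1024) -> vision tower (B, 256, 3072) -> mm_projector
+     # (B, 256, hidden_size), i.e. each image contributes 256 embedding positions.
+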
+     def prepare_inputs_labels_for_multimodal(
+         self, input_ids, position_ids, attention_mask, past_key_values, labels,
+         images, image_sizes=None
+     ):
+         vision_tower = self.get_vision_tower()
+         if vision_tower is None or images is None or input_ids.shape[1] == 1:
+             return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+         if type(images) is list or images.ndim == 5:
+             if type(images) is list:
+                 images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+             concat_images = torch.cat([image for image in images], dim=0)
+             image_features = self.encode_images(concat_images)
+             split_sizes = [image.shape[0] for image in images]
+             image_features = torch.split(image_features, split_sizes, dim=0)
+             mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
+             image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
+             if mm_patch_merge_type == 'flat':
+                 image_features = [x.flatten(0, 1) for x in image_features]
+             elif mm_patch_merge_type.startswith('spatial'):
+                 new_image_features = []
+                 for image_idx, image_feature in enumerate(image_features):
+                     if image_feature.shape[0] > 1:
+                         base_image_feature = image_feature[0]
+                         image_feature = image_feature[1:]
+                         height = width = self.get_vision_tower().num_patches_per_side
+                         assert height * width == base_image_feature.shape[0]
+                         if image_aspect_ratio == 'anyres':
+                             if hasattr(self.get_vision_tower(), 's2_image_size'):
+                                 img_size = self.get_vision_tower().s2_image_size
+                             elif isinstance(self.get_vision_tower().config, dict):
+                                 img_size = self.get_vision_tower().config["image_cfg"]["image_size"]
+                             else:
+                                 img_size = self.get_vision_tower().config.image_size
+
+                             num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size)
+                             image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                         else:
+                             raise NotImplementedError
+                         if 'unpad' in mm_patch_merge_type:
+                             image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                             image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                             image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                             image_feature = torch.cat((
+                                 image_feature,
+                                 self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
+                             ), dim=-1)
+                             image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                         else:
+                             image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+                             image_feature = image_feature.flatten(0, 3)
+                         image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                     else:
+                         image_feature = image_feature[0]
+                         if 'unpad' in mm_patch_merge_type:
+                             image_feature = torch.cat((
+                                 image_feature,
+                                 self.model.image_newline[None].to(image_feature.device)
+                             ), dim=0)
+                     new_image_features.append(image_feature)
+                 image_features = new_image_features
+             else:
+                 raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+         else:
+             image_features = self.encode_images(images)
+
+         # TODO: image start / end is not implemented here to support pretraining.
+         if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+             raise NotImplementedError
+
+         # Let's just add dummy tensors if they do not exist,
+         # it is a headache to deal with None all the time.
+         # But it is not ideal, and if you have a better idea,
+         # please open an issue / submit a PR, thanks.
+         _labels = labels
+         _position_ids = position_ids
+         _attention_mask = attention_mask
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+         else:
+             attention_mask = attention_mask.bool()
+         if position_ids is None:
+             position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+         if labels is None:
+             labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+         # remove the padding using attention_mask -- FIXME
+         _input_ids = input_ids
+         input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+         new_input_embeds = []
+         new_labels = []
+         cur_image_idx = 0
+         for batch_idx, cur_input_ids in enumerate(input_ids):
+             num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+             if num_images == 0:
+                 cur_image_features = image_features[cur_image_idx]
+                 cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+                 cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+                 new_input_embeds.append(cur_input_embeds)
+                 new_labels.append(labels[batch_idx])
+                 cur_image_idx += 1
+                 continue
+
+             image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+             cur_input_ids_noim = []
+             cur_labels = labels[batch_idx]
+             cur_labels_noim = []
+             for i in range(len(image_token_indices) - 1):
+                 cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
+                 cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
+             split_sizes = [x.shape[0] for x in cur_labels_noim]
+             cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+             cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+             cur_new_input_embeds = []
+             cur_new_labels = []
+
+             for i in range(num_images + 1):
+                 cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+                 cur_new_labels.append(cur_labels_noim[i])
+                 if i < num_images:
+                     cur_image_features = image_features[cur_image_idx]
+                     cur_image_idx += 1
+                     cur_new_input_embeds.append(cur_image_features)
+                     cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+
+             cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+             cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+             cur_new_labels = torch.cat(cur_new_labels)
+
+             new_input_embeds.append(cur_new_input_embeds)
+             new_labels.append(cur_new_labels)
+
+         # Truncate sequences to max length as image embeddings can make the sequence longer
+         tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
+         if tokenizer_model_max_length is not None:
+             new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+             new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+         # Combine them
+         max_len = max(x.shape[0] for x in new_input_embeds)
+         batch_size = len(new_input_embeds)
+
+         new_input_embeds_padded = []
+         new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+         attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+         position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+         for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+             cur_len = cur_new_embed.shape[0]
+             if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
+                 new_input_embeds_padded.append(torch.cat((
+                     torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
+                     cur_new_embed
+                 ), dim=0))
+                 if cur_len > 0:
+                     new_labels_padded[i, -cur_len:] = cur_new_labels
+                     attention_mask[i, -cur_len:] = True
+                     position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+             else:
+                 new_input_embeds_padded.append(torch.cat((
+                     cur_new_embed,
+                     torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
+                 ), dim=0))
+                 if cur_len > 0:
+                     new_labels_padded[i, :cur_len] = cur_new_labels
+                     attention_mask[i, :cur_len] = True
+                     position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+         new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+         if _labels is None:
+             new_labels = None
+         else:
+             new_labels = new_labels_padded
+
+         if _attention_mask is None:
+             attention_mask = None
+         else:
+             attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+         if _position_ids is None:
+             position_ids = None
+
+         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
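+     # Note: the method above returns input_ids=None because the IMAGE_TOKEN_INDEX
+     # placeholders have already been replaced by projected image features inside
+     # `inputs_embeds`; a prompt with T text tokens and one image thus becomes an
+     # embedding sequence of length T - 1 + num_patches (T - 1 + 256 for the 1024px
+     # tower), which is why truncation and padding are applied to the embedded
+     # sequences rather than to token ids.
+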
+     def initialize_vision_tokenizer(self, model_args, tokenizer):
+         if model_args.mm_use_im_patch_token:
+             tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+             self.resize_token_embeddings(len(tokenizer))
+
+         if model_args.mm_use_im_start_end:
+             num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+             self.resize_token_embeddings(len(tokenizer))
+
+             if num_new_tokens > 0:
+                 input_embeddings = self.get_input_embeddings().weight.data
+                 output_embeddings = self.get_output_embeddings().weight.data
+
+                 input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True)
+                 output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True)
+
+                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+             if model_args.tune_mm_mlp_adapter:
+                 for p in self.get_input_embeddings().parameters():
+                     p.requires_grad = True
+                 for p in self.get_output_embeddings().parameters():
+                     p.requires_grad = False
+
+             if model_args.pretrain_mm_mlp_adapter:
+                 mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
+                 embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+                 assert num_new_tokens == 2
+                 if input_embeddings.shape == embed_tokens_weight.shape:
+                     input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+                 elif embed_tokens_weight.shape[0] == num_new_tokens:
+                     input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                 else:
+                     raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+         elif model_args.mm_use_im_patch_token:
+             if model_args.tune_mm_mlp_adapter:
+                 for p in self.get_input_embeddings().parameters():
+                     p.requires_grad = False
+                 for p in self.get_output_embeddings().parameters():
+                     p.requires_grad = False
+
+
+ class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
+     config_class = LlavaConfig
+
+     def __init__(self, config: Qwen2Config):
+         super(LlavaQwen2Model, self).__init__(config)
+
+
+ class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
+     config_class = LlavaConfig
+
+     def __init__(self, config):
+         super(Qwen2ForCausalLM, self).__init__(config)
+         self.model = LlavaQwen2Model(config)
+         # self.pretraining_tp = config.pretraining_tp
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         image_sizes: Optional[List[List[int]]] = None,
+         return_dict: Optional[bool] = None,
+         cache_position=None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if inputs_embeds is None:
+             (
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 inputs_embeds,
+                 labels
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 labels,
+                 images,
+                 image_sizes
+             )
+
+         return super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             (
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 _,
+                 inputs_embeds,
+                 _
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 None,
+                 None,
+                 images,
+                 image_sizes=image_sizes
+             )
+         else:
+             inputs_embeds = self.get_model().embed_tokens(inputs)
+
+         return super().generate(
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             **kwargs
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+                                       inputs_embeds=None, **kwargs):
+         images = kwargs.pop("images", None)
+         image_sizes = kwargs.pop("image_sizes", None)
+         inputs = super().prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         if images is not None:
+             inputs['images'] = images
+         if image_sizes is not None:
+             inputs['image_sizes'] = image_sizes
+         return inputs
+
+
+ AutoConfig.register("llava_qwen2", LlavaConfig)
+ AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)