# Eagle2-2B / convnext_encoder.py
import os

import torch
import torch.nn as nn
from PIL import Image
from timm import create_model
from torch.utils.checkpoint import checkpoint
from torchvision import transforms as T
from transformers import CLIPImageProcessor

from .convnext import convnext_xxlarge
cfg={
"crop_size": 256,
"do_center_crop": True,
"do_normalize": True,
"do_resize": True,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"resample": 3,
"size": 256
}
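
# Minimal sketch of how the cfg above is consumed (see ConvNextFPNVisionTower.load_model
# below): it builds a CLIPImageProcessor that resizes, center-crops and normalizes
# images with CLIP statistics. The synthetic PIL image and the helper name are
# assumptions for illustration only; this function is not called anywhere in this file.
def _example_image_processor():
    processor = CLIPImageProcessor(**cfg)
    image = Image.new('RGB', (512, 384), color=(128, 128, 128))
    pixel_values = processor(images=image, return_tensors='pt')['pixel_values']
    print(pixel_values.shape)  # (1, 3, 256, 256) with the cfg above
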
# "SLIP" below refers to SigLIP-style normalization (mean/std 0.5). The affine
# coefficients a, b remap a SigLIP-normalized image directly into CLIP
# normalization: x_clip = a * x_siglip + b.
MEAN_SLIP = [0.5, 0.5, 0.5]
STD_SLIP = [0.5, 0.5, 0.5]
MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
STD_CLIP = [0.26862954, 0.26130258, 0.27577711]
a = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
b = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
class SlipToClipTransform:
    """Per-channel affine map x * a + b that converts SigLIP-normalized pixels
    into CLIP-normalized pixels without going back to raw pixel space."""

    def __init__(self, a, b):
        self.a = torch.tensor(a).view(-1, 1, 1)
        self.b = torch.tensor(b).view(-1, 1, 1)

    def __call__(self, x_slip):
        return x_slip * self.a.to(x_slip.device) + self.b.to(x_slip.device)

slip_to_clip = SlipToClipTransform(a, b)
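
# Illustrative sketch (not used elsewhere in this file): remapping a tensor that was
# normalized with SigLIP statistics into the equivalent CLIP-normalized tensor via
# the affine map above, and checking it against direct CLIP normalization. The helper
# name and the synthetic tensor are assumptions for illustration only.
def _example_slip_to_clip():
    raw = torch.rand(2, 3, 256, 256)                           # pixels in [0, 1]
    siglip_norm = T.Normalize(mean=MEAN_SLIP, std=STD_SLIP)(raw)
    clip_norm = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)(raw)
    remapped = slip_to_clip(siglip_norm)                       # SigLIP space -> CLIP space
    # The remapped tensor should match direct CLIP normalization up to float error.
    print((remapped - clip_norm).abs().max())
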
class ConvNextVisionTower(nn.Module):
    """ConvNeXt-XXLarge vision tower returning flattened patch features from the
    last stage (or the last four stages when mm_vision_select_layer > 100)."""

    def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
        super().__init__()
self.is_loaded = False
self.freeze_vision=args.freeze_vision
self.input_image_size=args.input_image_size
self.vision_tower_name = vision_tower
self.name = 'convnext'
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
self.pre_norm = normalize_type
print('pre_norm: ', self.pre_norm)
self.delay_load = delay_load
self.load_model()
def load_model(self):
if 'xxlarge' in self.vision_tower_name:
if self.delay_load:
self.vision_tower = convnext_xxlarge(pretrained=False)
else:
self.vision_tower = convnext_xxlarge(self.vision_tower_name)
setattr(self.vision_tower, 'hidden_size', 3072)
elif os.path.exists(self.vision_tower_name):
self.vision_tower = torch.load(self.vision_tower_name)
else:
            raise NotImplementedError(f'Unsupported vision tower: {self.vision_tower_name}')
self.vision_tower = self.vision_tower.to(torch.bfloat16)
if self.freeze_vision:
self.vision_tower.requires_grad_(False)
# if self.vision_tower.grad_checkpointing:
for s in self.vision_tower.stages:
s.grad_checkpointing = True
self.is_loaded = True
def feature_select(self, image_forward_outs):
if self.select_layer>100:
image_features = image_forward_outs[-4:]
else:
image_features = image_forward_outs[-1]
return image_features
def forward_features(self, x):
x = self.vision_tower.stem(x)
image_forward_out=[]
for blk in self.vision_tower.stages:
x = blk(x)
b,c,h,w=x.shape
image_forward_out.append(x.view(b,c,-1).transpose(1,2))
return image_forward_out
def forward(self, images):
if self.freeze_vision:
with torch.no_grad():
image_features = self._forward_images(images)
else:
image_features = self._forward_images(images)
return image_features
def _forward_images(self, images):
if type(images) is list:
image_features = []
for image in images:
if self.pre_norm == 'siglip':
dtype = image.dtype
image = slip_to_clip(image.to(torch.float32)).to(dtype)
image_forward_out = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
image_feature = self.feature_select(image_forward_out)
image_features.append(image_feature)
else:
if self.pre_norm == 'siglip':
dtype = images.dtype
images = slip_to_clip(images.to(torch.float32)).to(dtype)
image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
image_features = self.feature_select(image_forward_outs)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
@property
def device(self):
return next(self.vision_tower.parameters()).device
@property
def config(self):
        raise NotImplementedError
@property
def num_attention_heads(self):
# as constant
return 16
@property
def num_layers(self):
# as constant
return 4
@property
def hidden_size(self):
return self.vision_tower.hidden_size
@property
def num_patches(self):
        # ConvNeXt has no patch_embed; the backbone's overall stride is 32.
        return (self.input_image_size // 32) ** 2
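
# Minimal usage sketch for ConvNextVisionTower. The SimpleNamespace fields mirror the
# attributes read in __init__; the tower name and the helper itself are assumptions
# for illustration (delay_load=True builds the backbone with random weights, so no
# checkpoint is required). Not called anywhere in this file.
def _example_convnext_tower():
    from types import SimpleNamespace
    args = SimpleNamespace(
        freeze_vision=True,
        input_image_size=256,
        mm_vision_select_layer=-1,        # <= 100, so only the last stage is selected
        mm_vision_select_feature='patch',
    )
    tower = ConvNextVisionTower('convnext_xxlarge', args, delay_load=True)
    images = torch.rand(1, 3, 256, 256)
    # With a 256x256 input and an overall stride of 32, the last stage yields
    # a (1, 8*8, 3072) token sequence.
    feats = tower(images)
    print(feats.shape)
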
class ConvNextFPNVisionTower(nn.Module):
    """ConvNeXt vision tower exposing multi-scale (FPN-style) feature maps,
    by default the x8, x16 and x32 outputs of stages 1-3."""

    def __init__(self,
                 vision_tower,
                 args,
                 fpn_target_level=1,
                 fpn_layer_idx=[1, 2, 3],
                 fpn_input_dim=[768, 1536, 3072],
                 delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
self.freeze_vision = getattr(args, "frozen_backbone", True)
# self.input_image_size = getattr(args, "vision_tower_input_size", 1024)
self.input_image_size = 1024 # hardcode
self.select_layer = args.mm_vision_select_layer # no effect
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
self.need_fpn = True
self.fpn_layer_idx = fpn_layer_idx # [1, 2, 3] # x8, x16, x32
self.fpn_input_dim = [768, 1536, 3072]
self.delay_load = delay_load
self.load_model()
def load_model(self):
if self.is_loaded:
return
self.image_processor = CLIPImageProcessor(**cfg)
if 'xxlarge' in self.vision_tower_name:
self.vision_tower = convnext_xxlarge(self.vision_tower_name)
setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
# setattr(self.vision_tower, 'hidden_size', 3072)
        else:
            # convnext_large_mlp is assumed to be defined in the local convnext
            # module (as in timm); imported lazily so the xxlarge path above does
            # not depend on it.
            from .convnext import convnext_large_mlp
            self.vision_tower = convnext_large_mlp(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 1536)
if self.freeze_vision:
self.vision_tower.requires_grad_(False)
# if self.vision_tower.grad_checkpointing:
for s in self.vision_tower.stages:
s.grad_checkpointing = True
if self.input_image_size is not None:
self.image_processor.size=self.input_image_size
self.image_processor.crop_size={
'height':self.input_image_size,
'width': self.input_image_size
}
self.is_loaded = True
@torch.no_grad()
def forward_features(self, x):
x = self.vision_tower.stem(x)
image_forward_out=[]
for blk in self.vision_tower.stages:
x = blk(x)
image_forward_out.append(x)
return image_forward_out
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
image_features.append(image_feature)
else:
image_features = self.forward_features(images.to(device=self.device, dtype=self.dtype))
image_features = [image_features[idx] for idx in self.fpn_layer_idx]
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
@property
def device(self):
return next(self.vision_tower.parameters()).device
@property
def config(self):
        raise NotImplementedError
@property
def num_attention_heads(self):
# as constant
return 16
@property
def num_layers(self):
# as constant
return 4
@property
def hidden_size(self):
return self.vision_tower.hidden_size
@property
def num_patches(self):
        # ConvNeXt has no patch_embed; the last stage's overall stride is 32.
        return (self.input_image_size // 32) ** 2
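
# Minimal usage sketch for ConvNextFPNVisionTower. With fpn_layer_idx = [1, 2, 3],
# forward() returns the x8, x16 and x32 feature maps of a 1024x1024 input, i.e.
# spatial sizes 128, 64 and 32 with 768, 1536 and 3072 channels. The args container
# and the helper itself are assumptions for illustration; instantiating the tower
# this way attempts to load the backbone weights named by the tower string.
def _example_convnext_fpn_tower():
    from types import SimpleNamespace
    args = SimpleNamespace(
        frozen_backbone=True,
        mm_vision_select_layer=-1,
        mm_vision_select_feature='patch',
    )
    tower = ConvNextFPNVisionTower('convnext_xxlarge-fpn', args)
    images = torch.rand(1, 3, 1024, 1024)
    feats = tower(images)                 # list of three NCHW feature maps
    for f in feats:
        print(f.shape)
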
if __name__ == '__main__':
    # Sanity check: remapping a SigLIP-normalized image with the combined affine
    # transform should match normalizing the raw image with CLIP statistics.
    COMBINED_STD = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
    COMBINED_MEAN = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
    # Define the combined normalization transform.
    combined_normalize = T.Normalize(mean=COMBINED_MEAN, std=COMBINED_STD)
    # Reference transforms built from the module-level statistics.
    normalize_clip = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)
    normalize_siglip = T.Normalize(mean=MEAN_SLIP, std=STD_SLIP)
    denormalize_siglip = T.Normalize(
        mean=[-m / s for m, s in zip(MEAN_SLIP, STD_SLIP)],
        std=[1.0 / s for s in STD_SLIP],
    )

    x = torch.randn(1, 3, 256, 256).cuda()
    a = normalize_clip(x).to(torch.bfloat16)
    b = normalize_siglip(x).to(torch.bfloat16)
    c = denormalize_siglip(b.to(torch.float32))
    c2 = normalize_clip(c).to(torch.bfloat16)
    c3 = combined_normalize(b)
    print((c - x).abs().max())    # SigLIP round-trip error
    print((c2 - a).abs().max())   # denormalize + CLIP normalize vs direct CLIP normalize
    print((c3 - a).abs().max())   # combined transform vs direct CLIP normalize