import os

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torchvision import transforms as T
from PIL import Image

from timm import create_model
from transformers import CLIPImageProcessor

# convnext_large_mlp is assumed to live next to convnext_xxlarge in the local
# .convnext module; it is required by ConvNextFPNVisionTower.load_model below.
from .convnext import convnext_xxlarge, convnext_large_mlp

cfg = {
    "crop_size": 256,
    "do_center_crop": True,
    "do_normalize": True,
    "do_resize": True,
    "feature_extractor_type": "CLIPFeatureExtractor",
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "resample": 3,
    "size": 256
}

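# Example (sketch, kept as a comment): `cfg` can be passed directly to a
# CLIPImageProcessor, which is exactly what ConvNextFPNVisionTower.load_model does below.
#
#   processor = CLIPImageProcessor(**cfg)
#   pixel_values = processor(images=Image.new('RGB', (512, 512)), return_tensors='pt')['pixel_values']
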
# SigLIP and CLIP image-normalization statistics.
MEAN_SIGLIP = [0.5, 0.5, 0.5]
STD_SIGLIP = [0.5, 0.5, 0.5]

MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
STD_CLIP = [0.26862954, 0.26130258, 0.27577711]

# Channel-wise affine map x_clip = x_siglip * a + b that converts a tensor
# normalized with SigLIP statistics into one normalized with CLIP statistics.
a = [s_siglip / s_clip for s_siglip, s_clip in zip(STD_SIGLIP, STD_CLIP)]
b = [(m_siglip - m_clip) / s_clip for m_siglip, m_clip, s_clip in zip(MEAN_SIGLIP, MEAN_CLIP, STD_CLIP)]

class SlipToClipTransform:
    """Re-normalizes an image tensor from SigLIP statistics to CLIP statistics."""

    def __init__(self, a, b):
        self.a = torch.tensor(a).view(-1, 1, 1)
        self.b = torch.tensor(b).view(-1, 1, 1)

    def __call__(self, x_slip):
        return x_slip * self.a.to(x_slip.device) + self.b.to(x_slip.device)


slip_to_clip = SlipToClipTransform(a, b)

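# Example (sketch, kept as a comment): remap a SigLIP-normalized batch to CLIP
# statistics before feeding a backbone trained with CLIP preprocessing. Shapes and
# values are illustrative only.
#
#   imgs_siglip = torch.rand(2, 3, 256, 256)   # assumed already normalized with SigLIP stats
#   imgs_clip = slip_to_clip(imgs_siglip)      # channel-wise x * a + b, same shape
#
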
class ConvNextVisionTower(nn.Module):

    def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
        super().__init__()

        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.name = 'convnext'
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        # If set to 'siglip', inputs are assumed to be SigLIP-normalized and are
        # remapped to CLIP statistics before entering the backbone.
        self.pre_norm = normalize_type

        print('pre_norm: ', self.pre_norm)
        self.delay_load = delay_load
        self.load_model()

    def load_model(self):
        if 'xxlarge' in self.vision_tower_name:
            if self.delay_load:
                # Weights are expected to be loaded later (e.g. from a full checkpoint).
                self.vision_tower = convnext_xxlarge(pretrained=False)
            else:
                self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 3072)
        elif os.path.exists(self.vision_tower_name):
            self.vision_tower = torch.load(self.vision_tower_name)
        else:
            raise NotImplementedError(f'Unsupported vision tower: {self.vision_tower_name}')

        self.vision_tower = self.vision_tower.to(torch.bfloat16)

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # Enable gradient checkpointing on every ConvNeXt stage to save memory.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        # select_layer > 100 is used as a flag for multi-scale features from the
        # last four stages; otherwise only the final stage output is returned.
        if self.select_layer > 100:
            image_features = image_forward_outs[-4:]
        else:
            image_features = image_forward_outs[-1]
        return image_features

    def forward_features(self, x):
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            b, c, h, w = x.shape
            # Flatten each stage's (B, C, H, W) map into a (B, H*W, C) token sequence.
            image_forward_out.append(x.view(b, c, -1).transpose(1, 2))
        return image_forward_out

    def forward(self, images):
        if self.freeze_vision:
            with torch.no_grad():
                image_features = self._forward_images(images)
        else:
            image_features = self._forward_images(images)

        return image_features

    def _forward_images(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                if self.pre_norm == 'siglip':
                    dtype = image.dtype
                    image = slip_to_clip(image.to(torch.float32)).to(dtype)
                image_forward_out = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = self.feature_select(image_forward_out)
                image_features.append(image_feature)
        else:
            if self.pre_norm == 'siglip':
                dtype = images.dtype
                images = slip_to_clip(images.to(torch.float32)).to(dtype)
            image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
            image_features = self.feature_select(image_forward_outs)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def num_attention_heads(self):
        # ConvNeXt has no attention heads; kept for interface compatibility.
        return 16

    @property
    def num_layers(self):
        # Number of ConvNeXt stages.
        return 4

    @property
    def hidden_size(self):
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        # ConvNeXt reduces resolution by 32x overall (stride-4 stem plus three
        # stride-2 downsampling layers), so the last stage yields (size/32)^2 tokens.
        return (self.input_image_size // 32) ** 2

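# Example usage (sketch, kept as a comment; not executed). Attribute names mirror
# what __init__ reads above; the checkpoint identifier is illustrative only.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(freeze_vision=True, input_image_size=256,
#                          mm_vision_select_layer=-1, mm_vision_select_feature='patch')
#   tower = ConvNextVisionTower('convnext_xxlarge.clip_laion2b_soup', args)
#   feats = tower(torch.rand(1, 3, 256, 256))   # (1, 64, 3072): stride-32 tokens from the last stage
#
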
class ConvNextFPNVisionTower(nn.Module):

    def __init__(self,
                 vision_tower,
                 args,
                 fpn_target_level=1,
                 fpn_layer_idx=[1, 2, 3],
                 fpn_input_dim=[768, 1536, 3072],
                 delay_load=False):

        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
        self.freeze_vision = getattr(args, "frozen_backbone", True)

        self.input_image_size = 1024
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        self.need_fpn = True
        self.fpn_layer_idx = fpn_layer_idx    # indices of the stages fed to the FPN
        self.fpn_input_dim = fpn_input_dim    # per-stage channel dimensions
        self.delay_load = delay_load
        self.load_model()

    def load_model(self):
        if self.is_loaded:
            return

        self.image_processor = CLIPImageProcessor(**cfg)
        if 'xxlarge' in self.vision_tower_name:
            self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
        else:
            self.vision_tower = convnext_large_mlp(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 1536)

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # Enable gradient checkpointing on every ConvNeXt stage to save memory.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True

        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size
            }

        self.is_loaded = True

    @torch.no_grad()
    def forward_features(self, x):
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            # Keep each stage's (B, C, H, W) feature map for the FPN.
            image_forward_out.append(x)
        return image_forward_out

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_features.append(image_feature)
        else:
            image_features = self.forward_features(images.to(device=self.device, dtype=self.dtype))
            # Keep only the stages requested for the FPN (batched path).
            image_features = [image_features[idx] for idx in self.fpn_layer_idx]

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def num_attention_heads(self):
        # ConvNeXt has no attention heads; kept for interface compatibility.
        return 16

    @property
    def num_layers(self):
        # Number of ConvNeXt stages.
        return 4

    @property
    def hidden_size(self):
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        # ConvNeXt reduces resolution by 32x overall, so the deepest stage
        # yields (size/32)^2 spatial positions.
        return (self.input_image_size // 32) ** 2

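# Example usage (sketch, kept as a comment; not executed). Attribute names follow
# what __init__ reads above; the checkpoint identifier is illustrative only.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(frozen_backbone=True, mm_vision_select_layer=-1)
#   fpn_tower = ConvNextFPNVisionTower('convnext_xxlarge.clip_laion2b_soup', args)
#   maps = fpn_tower(torch.rand(1, 3, 1024, 1024))
#   # -> three maps with 768 / 1536 / 3072 channels at strides 8 / 16 / 32
#
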
if __name__ == '__main__':
    # Sanity check: converting a SigLIP-normalized tensor to CLIP normalization via a
    # single channel-wise affine map should match denormalizing with SigLIP statistics
    # and re-normalizing with CLIP statistics.
    COMBINED_STD = [s_siglip / s_clip for s_siglip, s_clip in zip(STD_SIGLIP, STD_CLIP)]
    COMBINED_MEAN = [(m_siglip - m_clip) / s_clip for m_siglip, m_clip, s_clip in zip(MEAN_SIGLIP, MEAN_CLIP, STD_CLIP)]

    normalize_clip = T.Normalize(mean=MEAN_CLIP, std=STD_CLIP)
    normalize_siglip = T.Normalize(mean=MEAN_SIGLIP, std=STD_SIGLIP)
    denormalize_siglip = T.Normalize(
        mean=[-m / s for m, s in zip(MEAN_SIGLIP, STD_SIGLIP)],
        std=[1 / s for s in STD_SIGLIP],
    )
    # T.Normalize computes (x - mean) / std, so realizing
    # x_clip = x_siglip * COMBINED_STD + COMBINED_MEAN requires the inverse parameters.
    combined_normalize = T.Normalize(
        mean=[-m / s for m, s in zip(COMBINED_MEAN, COMBINED_STD)],
        std=[1 / s for s in COMBINED_STD],
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 256, 256, device=device)
    a = normalize_clip(x).to(torch.bfloat16)
    b = normalize_siglip(x).to(torch.bfloat16)
    c = denormalize_siglip(b.to(torch.float32))
    c2 = normalize_clip(c).to(torch.bfloat16)
    c3 = combined_normalize(b.to(torch.float32)).to(torch.bfloat16)
    print((c - x).abs().max())    # SigLIP round-trip error (bfloat16 quantization only)
    print((c2 - a).abs().max())   # denormalize + CLIP renormalize vs. direct CLIP normalization
    print((c3 - a).abs().max())   # combined affine map vs. direct CLIP normalization