import collections
import math
import types
import typing

import torch
import torch.nn as nn
from torchvision import models, transforms

# Try to import CLIP, but handle import errors gracefully
try:
    import clip
    CLIP_AVAILABLE = True
except ImportError:
    print("Warning: CLIP not available, using FashionCLIP fallback")
    CLIP_AVAILABLE = False


# code lifted from CLIPasso
# For ViT
class CLIPVisualEncoder(nn.Module):
    def __init__(self, model_name, stride, device):
        super().__init__()
        self.load_model(model_name, device)
        self.old_stride = self.model.conv1.stride[0]
        self.new_stride = stride
        self.patch_vit_resolution(stride)

        for i in range(12):  # 12 resblocks in the ViT visual transformer
            self.model.transformer.resblocks[i].register_forward_hook(
                self.make_hook(i))

    def load_model(self, model_name, device):
        if CLIP_AVAILABLE:
            try:
                model, preprocess = clip.load(model_name, device=device)
                self.model = model.visual
                # The last preprocess transform is Normalize; reuse its statistics
                self.mean = torch.tensor(preprocess.transforms[-1].mean, device=device)
                self.std = torch.tensor(preprocess.transforms[-1].std, device=device)
            except Exception as e:
                print(f"Error loading CLIP model: {e}")
                print("Falling back to FashionCLIP...")
                self._load_fashion_clip_fallback(device)
        else:
            print("CLIP not available, using FashionCLIP fallback...")
            self._load_fashion_clip_fallback(device)

    def _load_fashion_clip_fallback(self, device):
        """Fallback method using FashionCLIP when regular CLIP is not available."""
        try:
            from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
            fclip = FashionCLIP('fashion-clip')
            self.model = fclip.model.vision_model
            # Standard CLIP normalization constants
            self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device)
            self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
            print("Successfully loaded FashionCLIP fallback")
        except Exception as e:
            print(f"Error loading FashionCLIP fallback: {e}")
            # Create a dummy model if all else fails
            self._create_dummy_model(device)

    def _create_dummy_model(self, device):
        """Create a dummy model when all CLIP options fail."""
        print("Creating dummy CLIP model - functionality will be limited")

        # Minimal stand-in exposing the attributes the rest of the class touches
        class DummyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 768, kernel_size=16, stride=16)
                self.class_embedding = nn.Parameter(torch.randn(768))
                self.positional_embedding = nn.Parameter(torch.randn(197, 768))
                self.ln_pre = nn.LayerNorm(768)
                self.transformer = nn.TransformerEncoder(
                    nn.TransformerEncoderLayer(768, 12), 12)

        self.model = DummyModel()
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device)
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)

    @staticmethod
    def _fix_pos_enc(patch_size: int, stride_hw: typing.Tuple[int, int]):
        def interpolate_pos_encoding(self, x, w, h):
            npatch = x.shape[1] - 1
            N = self.positional_embedding.shape[0] - 1
            if npatch == N and w == h:
                return self.positional_embedding
            class_pos_embed = self.positional_embedding[:1].type(x.dtype)
            patch_pos_embed = self.positional_embedding[1:].type(x.dtype)
            dim = x.shape[-1]
            # Number of patches along each axis given the new (smaller) stride
            w0 = 1 + (w - patch_size) // stride_hw[1]
            h0 = 1 + (h - patch_size) // stride_hw[0]
            assert (w0 * h0 == npatch)
            # Small offset avoids floating-point rounding issues in interpolate
            w0, h0 = w0 + 0.1, h0 + 0.1
            patch_pos_embed = torch.nn.functional.interpolate(
                patch_pos_embed.reshape(int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(2, 0, 1).unsqueeze(0),
                scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
                mode='bicubic',
                align_corners=False,
                recompute_scale_factor=False,
            ).squeeze()
            assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
            patch_pos_embed = patch_pos_embed.permute(1, 2, 0).view(1, -1, dim)
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

        return interpolate_pos_encoding

    def patch_vit_resolution(self, stride):
        patch_size = self.model.conv1.stride[0]
        if stride == patch_size:
            # Nothing to re-stride, but still bind the interpolation helper
            # so forward() can call self.model.interpolate_pos_encoding
            self.model.interpolate_pos_encoding = types.MethodType(
                CLIPVisualEncoder._fix_pos_enc(patch_size, (stride, stride)), self.model)
            return

        stride = (stride, stride)
        assert all([(patch_size // s_) * s_ == patch_size for s_ in stride])
        self.model.conv1.stride = stride
        self.model.interpolate_pos_encoding = types.MethodType(
            CLIPVisualEncoder._fix_pos_enc(patch_size, stride), self.model)

    @property
    def dtype(self):
        return self.model.conv1.weight.dtype

    def make_hook(self, name):
        def hook(module, input, output):
            if len(output.shape) == 3:
                # Transformer blocks emit LND; store as NLD (batch, tokens, width)
                self.featuremaps[name] = output.permute(1, 0, 2)  # LND -> NLD
            else:
                self.featuremaps[name] = output

        return hook

    def forward(self, x, preprocess=False):
        self.featuremaps = collections.OrderedDict()
        if preprocess:
            x = (x - self.mean[None, :, None, None]) / self.std[None, :, None, None]
        B, C, W, H = x.shape
        x = self.model.conv1(x.type(self.dtype))  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.model.class_embedding.to(x.dtype)
                       + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                       x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.model.interpolate_pos_encoding(x, W, H)
        x = self.model.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)

        # One hooked feature map per block; drop the cls token -> [N, width, tokens - 1]
        featuremaps = [self.featuremaps[k].permute(0, 2, 1)[..., 1:]
                       for k in range(12)]
        return featuremaps
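

# Minimal usage sketch (illustrative, not part of the original module): assumes the
# OpenAI `clip` package is installed so the "ViT-B/32" checkpoint loads, and that the
# input is a 224x224 RGB batch in [0, 1] passed with preprocess=True so the encoder
# applies CLIP normalization itself.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # stride=8 densifies the ViT-B/32 patch grid from 7x7 to 25x25 at 224px input
    encoder = CLIPVisualEncoder("ViT-B/32", stride=8, device=device)
    dummy_batch = torch.rand(1, 3, 224, 224, device=device)
    with torch.no_grad():
        feature_maps = encoder(dummy_batch, preprocess=True)
    # Expect 12 feature maps (one per resblock), each shaped [batch, width, num_patches]
    print(len(feature_maps), feature_maps[0].shape)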