import collections
import math
import types
import typing

import torch
import torch.nn as nn
from torchvision import models, transforms

# Try to import CLIP, but handle import errors gracefully
try:
    import clip
    CLIP_AVAILABLE = True
except ImportError:
    print("Warning: CLIP not available, using FashionCLIP fallback")
    CLIP_AVAILABLE = False


# code lifted from CLIPasso
# For ViT
class CLIPVisualEncoder(nn.Module):
    def __init__(self, model_name, stride, device):
        super().__init__()
        self.load_model(model_name, device)
        self.old_stride = self.model.conv1.stride[0]
        self.new_stride = stride
        self.patch_vit_resolution(stride)

        for i in range(12):  # 12 resblocks in ViT visual transformer
            self.model.transformer.resblocks[i].register_forward_hook(
                self.make_hook(i))

    def load_model(self, model_name, device):
        if CLIP_AVAILABLE:
            try:
                model, preprocess = clip.load(model_name, device=device)
                self.model = model.visual
                self.mean = torch.tensor(preprocess.transforms[-1].mean, device=device)
                self.std = torch.tensor(preprocess.transforms[-1].std, device=device)
            except Exception as e:
                print(f"Error loading CLIP model: {e}")
                print("Falling back to FashionCLIP...")
                self._load_fashion_clip_fallback(device)
        else:
            print("CLIP not available, using FashionCLIP fallback...")
            self._load_fashion_clip_fallback(device)

    def _load_fashion_clip_fallback(self, device):
        """Fallback method using FashionCLIP when regular CLIP is not available."""
        try:
            from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
            fclip = FashionCLIP('fashion-clip')
            self.model = fclip.model.vision_model
            # Use standard CLIP normalization mean and std values
            self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device)
            self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
            print("Successfully loaded FashionCLIP fallback")
        except Exception as e:
            print(f"Error loading FashionCLIP fallback: {e}")
            # Create a dummy model if all else fails
            self._create_dummy_model(device)

    def _create_dummy_model(self, device):
        """Create a dummy model when all CLIP options fail."""
        print("Creating dummy CLIP model - functionality will be limited")

        # Create a simple dummy model structure mirroring the attributes used in forward()
        class DummyModel:
            def __init__(self):
                self.conv1 = nn.Conv2d(3, 768, kernel_size=16, stride=16)
                self.class_embedding = nn.Parameter(torch.randn(768))
                self.positional_embedding = nn.Parameter(torch.randn(197, 768))
                self.ln_pre = nn.LayerNorm(768)
                self.transformer = nn.TransformerEncoder(
                    nn.TransformerEncoderLayer(768, 12), 12)

        self.model = DummyModel()
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device)
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)

    @staticmethod
    def _fix_pos_enc(patch_size: int, stride_hw: typing.Tuple[int, int]):
        def interpolate_pos_encoding(self, x, w, h):
            npatch = x.shape[1] - 1
            N = self.positional_embedding.shape[0] - 1
            if npatch == N and w == h:
                return self.positional_embedding
            class_pos_embed = self.positional_embedding[:1].type(x.dtype)
            patch_pos_embed = self.positional_embedding[1:].type(x.dtype)
            dim = x.shape[-1]
            # compute the number of patch tokens per axis, taking the new stride into account
            w0 = 1 + (w - patch_size) // stride_hw[1]
            h0 = 1 + (h - patch_size) // stride_hw[0]
            assert (w0 * h0 == npatch)
            # add a small offset to avoid floating point errors in the interpolation
            w0, h0 = w0 + 0.1, h0 + 0.1
            patch_pos_embed = torch.nn.functional.interpolate(
                patch_pos_embed.reshape(int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(2, 0, 1).unsqueeze(0),
                scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
                mode='bicubic',
                align_corners=False, recompute_scale_factor=False,
            ).squeeze()
            assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
            patch_pos_embed = patch_pos_embed.permute(1, 2, 0).view(1, -1, dim)
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
        return interpolate_pos_encoding

    def patch_vit_resolution(self, stride):
        patch_size = self.model.conv1.stride[0]
        if stride == patch_size:
            return

        stride = (stride, stride)
        assert all([(patch_size // s_) * s_ == patch_size for s_ in stride])
        self.model.conv1.stride = stride
        self.model.interpolate_pos_encoding = types.MethodType(
            CLIPVisualEncoder._fix_pos_enc(patch_size, stride), self.model)

    @property
    def dtype(self):
        return self.model.conv1.weight.dtype

    def make_hook(self, name):
        def hook(module, input, output):
            if len(output.shape) == 3:
                # LND -> NLD, i.e. [tokens, batch, 768] -> [batch, tokens, 768]
                self.featuremaps[name] = output.permute(1, 0, 2)
            else:
                self.featuremaps[name] = output
        return hook

    def forward(self, x, preprocess=False):
        self.featuremaps = collections.OrderedDict()
        if preprocess:
            x = (x - self.mean[None, :, None, None]) / self.std[None, :, None, None]
        B, C, W, H = x.shape
        x = self.model.conv1(x.type(self.dtype))  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.model.class_embedding.to(x.dtype)
                       + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                       x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.model.interpolate_pos_encoding(x, W, H)
        x = self.model.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)
        # drop the cls token from each hooked resblock output
        featuremaps = [self.featuremaps[k].permute(0, 2, 1)[..., 1:] for k in range(12)]
        return featuremaps
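

# Usage sketch (not part of the original module): a minimal, hedged example of
# driving the encoder. The checkpoint name "ViT-B/32", the stride of 16, and the
# 224x224 input size are illustrative assumptions, and the OpenAI `clip` package
# is assumed to be installed (otherwise the fallbacks above kick in).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = CLIPVisualEncoder("ViT-B/32", stride=16, device=device)
    image = torch.rand(1, 3, 224, 224, device=device)  # raw [0, 1] image batch
    # preprocess=True applies CLIP's channel-wise normalization inside forward()
    feats = encoder(image, preprocess=True)
    # 12 per-resblock feature maps of shape [batch, 768, num_patch_tokens]
    print(len(feats), feats[0].shape)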