Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
import os
import torch
import logging
import open_clip
import numpy as np
import torch.nn as nn
import torchvision.transforms as T
from utils.registry_class import EMBEDDER
@EMBEDDER.register_class()
class FrozenOpenCLIPEmbedder(nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text.

    The OpenCLIP model is built on the CPU and its visual tower is deleted,
    so only the text-side modules remain. Inputs are moved to ``device`` in
    ``forward``; the module itself is not moved here.
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        """
        Args:
            pretrained: forwarded to ``open_clip.create_model_and_transforms``.
            arch: OpenCLIP architecture name.
            device: device the tokenized input is moved to in ``forward``.
            max_length: stored on the instance; not otherwise used in this class.
            freeze: when true, put the model in eval mode and disable gradients.
            layer: one of ``LAYERS``; selects how many trailing transformer
                blocks are skipped (``"last"`` -> 0, ``"penultimate"`` -> 1).
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        # Only the text branch is needed by this embedder.
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            # Still reachable when asserts are stripped (python -O).
            raise NotImplementedError()

    def freeze(self):
        """Switch the wrapped model to eval mode and disable gradients on all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` with ``open_clip.tokenize`` and return the encoded sequence."""
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        """
        Run token ids through embedding, the (possibly truncated) transformer
        stack and the final layer norm.

        Args:
            text: token-id tensor as produced by ``open_clip.tokenize``.

        Returns:
            The per-token features after ``ln_final``.
        """
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """
        Apply the transformer residual blocks, stopping ``self.layer_idx``
        blocks before the end.

        Args:
            x: input to the first residual block.
            attn_mask: attention mask forwarded to every block.

        Returns:
            Output of the last block that was executed.
        """
        # Imported locally: the name was used here without ever being
        # imported, which raised NameError once grad checkpointing was on.
        from torch.utils.checkpoint import checkpoint

        transformer = self.model.transformer
        resblocks = transformer.resblocks
        stop_at = len(resblocks) - self.layer_idx
        use_checkpoint = transformer.grad_checkpointing and not torch.jit.is_scripting()
        for i, r in enumerate(resblocks):
            if i == stop_at:
                break
            if use_checkpoint:
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
@EMBEDDER.register_class()
class FrozenOpenCLIPVisualEmbedder(nn.Module):
    """
    Uses the OpenCLIP image encoder to embed images.

    The OpenCLIP model is built on the CPU and its text transformer is
    deleted in ``__init__``. Inputs are moved to ``device`` in ``forward``;
    the module itself is not moved here.
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        """
        Args:
            pretrained: forwarded to ``open_clip.create_model_and_transforms``.
            vit_resolution: (height, width) of the all-white reference image.
            arch: OpenCLIP architecture name.
            device: device the input image is moved to in ``forward``.
            max_length: stored on the instance; not otherwise used in this class.
            freeze: when true, put the model in eval mode and disable gradients.
            layer: one of ``LAYERS``; stored together with ``layer_idx``
                (``"last"`` -> 0, ``"penultimate"`` -> 1).
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        # The text transformer is dropped; only image encoding is used.
        del model.transformer
        self.model = model

        # All-255 (white) image pushed through the model's own preprocessing,
        # with a leading batch dimension added by unsqueeze(0).
        data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255
        self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)

        self.device = device
        self.max_length = max_length  # 77
        if freeze:
            self.freeze()
        self.layer = layer  # 'penultimate'
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            # Still reachable when asserts are stripped (python -O).
            raise NotImplementedError()

    def freeze(self):
        """Switch the wrapped model to eval mode and disable gradients on all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        """Move ``image`` to ``self.device`` and return ``model.encode_image`` of it."""
        z = self.model.encode_image(image.to(self.device))
        return z

    def encode_with_transformer(self, text):
        """
        Text-side encoding path kept for interface parity with the text embedder.

        NOTE: ``__init__`` deletes ``model.transformer``, so the call into
        ``text_transformer_forward`` below raises AttributeError on an
        instance built by this class.
        """
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """
        Apply the text transformer residual blocks, stopping
        ``self.layer_idx`` blocks before the end.

        NOTE: reads ``self.model.transformer``, which ``__init__`` deletes;
        retained only for interface parity.
        """
        # Imported locally: the name was used here without ever being imported.
        from torch.utils.checkpoint import checkpoint

        transformer = self.model.transformer
        resblocks = transformer.resblocks
        stop_at = len(resblocks) - self.layer_idx
        use_checkpoint = transformer.grad_checkpointing and not torch.jit.is_scripting()
        for i, r in enumerate(resblocks):
            if i == stop_at:
                break
            if use_checkpoint:
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """
        Alias for calling the module; the argument is handed to ``forward``
        as its ``image`` input despite the parameter name.
        """
        return self(text)
@EMBEDDER.register_class()
class FrozenOpenCLIPTextVisualEmbedder(nn.Module):
    """
    Uses OpenCLIP to embed both images and text.

    Unlike the text-only and visual-only embedders, neither tower of the
    loaded model is deleted. The model is built on the CPU; inputs are moved
    to ``device`` in ``forward``, the module itself is not moved here.
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
                 freeze=True, layer="last", **kwargs):
        """
        Args:
            pretrained: forwarded to ``open_clip.create_model_and_transforms``.
            arch: OpenCLIP architecture name.
            device: device the inputs are moved to in ``forward``.
            max_length: stored on the instance; not otherwise used in this class.
            freeze: when true, put the model in eval mode and disable gradients.
            layer: one of ``LAYERS``; selects how many trailing text-transformer
                blocks are skipped (``"last"`` -> 0, ``"penultimate"`` -> 1).
            **kwargs: accepted and ignored.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            # Still reachable when asserts are stripped (python -O).
            raise NotImplementedError()

    def freeze(self):
        """Switch the wrapped model to eval mode and disable gradients on all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image=None, text=None):
        """
        Encode an optional image and an optional text.

        Args:
            image: input for ``model.encode_image``, or None to skip it.
            text: input for ``open_clip.tokenize``, or None to skip it.

        Returns:
            Tuple ``(xi, xt, x)``: the image embedding, the projected text
            embedding and the per-token text features. Entries belonging to
            an input that was not supplied are None.
        """
        xi = self.model.encode_image(image.to(self.device)) if image is not None else None
        if text is None:
            # The signature defaults text to None, but tokenizing None fails;
            # honour the default by skipping the text branch.
            return xi, None, None
        tokens = open_clip.tokenize(text)
        xt, x = self.encode_with_transformer(tokens.to(self.device))
        return xi, xt, x

    def encode_with_transformer(self, text):
        """
        Run token ids through the text tower.

        Args:
            text: token-id tensor as produced by ``open_clip.tokenize``.

        Returns:
            Tuple ``(xt, x)``: ``x`` is the per-token output of ``ln_final``;
            ``xt`` is one token's features per sequence multiplied by
            ``text_projection``.
        """
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        # Pick, per sequence, the position holding the largest token id.
        # NOTE(review): presumably that is the end-of-text token — confirm
        # against the tokenizer's vocabulary.
        xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
        return xt, x

    def encode_image(self, image):
        """Return the output of the model's visual tower for ``image``."""
        return self.model.visual(image)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """
        Apply the text transformer residual blocks, stopping
        ``self.layer_idx`` blocks before the end.

        Args:
            x: input to the first residual block.
            attn_mask: attention mask forwarded to every block.

        Returns:
            Output of the last block that was executed.
        """
        # Imported locally: the name was used here without ever being
        # imported, which raised NameError once grad checkpointing was on.
        from torch.utils.checkpoint import checkpoint

        transformer = self.model.transformer
        resblocks = transformer.resblocks
        stop_at = len(resblocks) - self.layer_idx
        use_checkpoint = transformer.grad_checkpointing and not torch.jit.is_scripting()
        for i, r in enumerate(resblocks):
            if i == stop_at:
                break
            if use_checkpoint:
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """
        Encode text only.

        Passes ``text`` by keyword: a positional call would bind it to
        ``forward``'s ``image`` parameter.

        Returns:
            The ``(xi, xt, x)`` tuple from ``forward`` with ``xi`` set to None.
        """
        return self(text=text)