# BLIP-2 style visual front-end built on SwissArmyTransformer (sat): an EVA ViT
# encoder feeds a Q-Former, whose query outputs are projected into the language
# model's embedding space.
import torch
import torch.nn as nn
from sat.model import ViTModel, BaseModel
from sat.model import BaseMixin
from sat import AutoModel
from copy import deepcopy
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

class LNFinalyMixin(BaseMixin):
    # Swap-in for the ViT "cls" head: instead of a classification projection,
    # apply a final LayerNorm (ln_vision) to the encoder output, as in BLIP-2.
    def __init__(self, hidden_size):
        super().__init__()
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        return self.ln_vision(logits)

class EVAViT(ViTModel):
    # EVA ViT image encoder; the default "cls" mixin is replaced by the LayerNorm above.
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        self.del_mixin("cls")
        self.add_mixin("cls", LNFinalyMixin(args.hidden_size))

    def forward(self, image):
        batch_size = image.size(0)
        # The sat ViT interface expects token ids and an attention mask; only the image
        # matters here, so pass a dummy token per sample and an all-ones (no-op) mask.
        input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
        attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)

class QFormer(BaseModel):
    # Q-Former: a small transformer whose 32 learned query tokens cross-attend to the
    # image features produced by the ViT. Position embeddings are disabled.
    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, activation_func=nn.functional.gelu, **kwargs)
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        # No positional embeddings for the query tokens.
        return None

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
        # Token ids 0..31 index the 32 learned query embeddings.
        input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1)
        # All-ones masks: full self-attention over the queries and full cross-attention to the image features.
        attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask)

class BLIP2(torch.nn.Module):
    # Full visual pipeline: EVA ViT -> Q-Former -> linear projection into the
    # language model's embedding space.
    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        super().__init__()
        if vit is not None:
            self.vit = vit
        else:
            self.vit = EVAViT(EVAViT.get_args(**eva_args))
        if qformer is not None:
            self.qformer = qformer
        else:
            self.qformer = QFormer(QFormer.get_args(**qformer_args))
        # Project 768-dim Q-Former features to the 4096-dim hidden size of the language
        # model, matching the Q-Former's device and dtype.
        ref_param = next(self.qformer.parameters())
        self.glm_proj = nn.Linear(768, 4096).to(device=ref_param.device, dtype=ref_param.dtype)

    def forward(self, image, **kwargs):
        enc = self.vit(image)[0]      # image patch features
        out = self.qformer(enc)[0]    # one embedding per learned query token
        return self.glm_proj(out)
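
# Shape flow through BLIP2.forward, assuming a typical BLIP-2 configuration
# (batch size B; the exact token count and hidden sizes come from eva_args /
# qformer_args, so these numbers are illustrative rather than guaranteed):
#   image              -> (B, 3, H, W)
#   self.vit(image)    -> (B, num_image_tokens, vit_hidden_size), layer-normalized
#   self.qformer(...)  -> (B, 32, 768), one vector per learned query
#   self.glm_proj(...) -> (B, 32, 4096), query features in the language model's width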

class BlipImageBaseProcessor:
    def __init__(self, mean=None, std=None):
        # Default to the CLIP image normalization statistics.
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = transforms.Normalize(mean, std)

class BlipImageEvalProcessor(BlipImageBaseProcessor):
    # Evaluation-time preprocessing: bicubic resize to a square, convert to tensor, normalize.
    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)
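
# Minimal usage sketch: how the processor and BLIP2 are typically wired together.
# `example.jpg` and the `eva_args` / `qformer_args` dicts are hypothetical
# placeholders; in practice they come from the pretrained checkpoint's configuration.
if __name__ == "__main__":
    from PIL import Image

    processor = BlipImageEvalProcessor(image_size=224)
    pixel_values = processor(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)

    # model = BLIP2(eva_args, qformer_args)   # args would be taken from the checkpoint config
    # image_embeds = model(pixel_values)      # expected shape: (1, 32, 4096)
    # print(image_embeds.shape)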