import gradio as gr
import torch
import cv2
from torch import nn
from torchvision import transforms

import vision_transformer as models
from utils import load_pretrained_weights


class PatchEmbedding:
    """
    Loads a pretrained ViT-Base model and generates token embeddings for an
    input image.

    Args:
        pretrained_weights (str): Path to the pretrained weights file.
        arch (str, optional): Model architecture to use. Defaults to 'vit_base'.
        patch_size (int, optional): Size of the patches extracted from the
            image. Defaults to 16.

    Attributes:
        model: The image embedding model.
        embed_dim (int): Dimensionality of the image embeddings.

    Methods:
        load_pretrained_weights(pretrained_weights): Load pretrained weights
            into the model.
        get_representation(image): Generate the CLS token and patch tokens for
            an input image.
    """

    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """
        Generate token embeddings for an input image.

        Args:
            image (ndarray): Input RGB image of shape (H, W, 3).

        Returns:
            cls_token (ndarray): L2-normalized CLS token of shape (C,).
            patch_tokens (ndarray): L2-normalized patch tokens of shape (N, C),
                where N is the number of patches.
        """
        img = self.tfms(image)
        x = img[None, :]  # add a batch dimension: (1, 3, H, W)
        tokens = self.model.forward_features(x)[0]  # (1 + N, C): CLS token followed by patch tokens
        tokens = nn.functional.normalize(tokens, dim=-1, p=2).numpy()
        cls_token = tokens[0]      # (C,)
        patch_tokens = tokens[1:]  # (N, C)
        return cls_token, patch_tokens

    def __call__(self, x):
        return self.model.forward_features(x)


default_shape = (224, 224)
embedding = PatchEmbedding('weights/mmc.pth')


def classify(query_image, support_image):
    # The tokens are L2-normalized, so the dot product of the two CLS tokens
    # is their cosine similarity; scale it to a percentage for display.
    q_cls = embedding.get_representation(query_image)[0]
    s_cls = embedding.get_representation(support_image)[0]
    sim = (q_cls * s_cls).sum() * 100
    return f"{sim:.2f}"


def segment(threshold, inputs):
    image = inputs['image']
    mask = inputs['mask']
    patch_tokens = embedding.get_representation(image)[1]
    # Downsample the user's sketch to the 14x14 patch grid (224 / 16 = 14) and
    # average the tokens of the selected patches into a single query vector.
    select = (cv2.resize(mask[:, :, 0], (14, 14)) > 0).flatten()
    q_pat = patch_tokens[select].mean(0)  # (C,)
    # Similarity of every patch token to the query vector.
    sim = patch_tokens @ q_pat[:, None]  # (N, 1)
    # Threshold the similarity map, upsample it back to the image resolution,
    # and use it to mask the input image.
    mask = (sim.reshape(14, 14) > threshold).astype('float')
    mask = cv2.resize(mask, (224, 224))
    ans = image * mask[:, :, None]
    return ans.astype('uint8')


classification_tab = gr.Interface(
    fn=classify,
    inputs=[
        # gr.inputs.Slider(0, 1, step=0.01, label="Threshold", default=0.8),
        gr.inputs.Image(label="Query Image", shape=default_shape),
        gr.inputs.Image(label="Support Image", shape=default_shape),
    ],
    outputs=gr.outputs.Textbox(label="Prediction"),
    title="Classification",
)

segmentation_tab = gr.Interface(
    fn=segment,
    inputs=[
        gr.inputs.Slider(0, 1, step=0.01, label="Threshold", default=0.8),
        gr.inputs.Image(label="Input Image", tool="sketch", shape=default_shape),
    ],
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
    title="Segmentation",
)

with gr.Blocks() as app:
    gr.Markdown("""
    @misc{wu2023masked,
        title={Masked Momentum Contrastive Learning for Zero-shot Semantic Understanding},
        author={Jiantao Wu and Shentong Mo and Muhammad Awais and Sara Atito and Zhenhua Feng and Josef Kittler},
        year={2023},
        eprint={2308.11448},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""")
    interface = gr.TabbedInterface(
        [classification_tab, segmentation_tab],
        ["Classification", "Segmentation"],
        # layout="horizontal"
    )

app.launch()
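
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original demo): exercising
# PatchEmbedding without the Gradio UI. 'example.jpg' is a hypothetical local
# image path; the expected shapes follow from patch_size=16 on a 224x224
# input (14 * 14 = 196 patch tokens). Uncomment to run as a quick sanity
# check.
#
# img = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)
# img = cv2.resize(img, default_shape)
# cls_token, patch_tokens = embedding.get_representation(img)
# print(cls_token.shape)     # (embed_dim,)
# print(patch_tokens.shape)  # (196, embed_dim)
# print(classify(img, img))  # ~"100.00" for identical images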