Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import vision_transformer as models | |
import cv2 | |
from torch import nn | |
from utils import load_pretrained_weights | |
class PatchEmbedding:
    """Wrap a pretrained ViT and turn an image into L2-normalised token embeddings.

    The backbone is built from the local ``vision_transformer`` module
    (``models.__dict__[arch]``) with ``num_classes=0``, switched to eval mode
    with parameter gradients disabled, and then loaded with the given
    checkpoint.

    Args:
        pretrained_weights (str): Path to the pretrained weight file, passed to
            ``utils.load_pretrained_weights``.
        arch (str, optional): Name of the architecture constructor in
            ``vision_transformer``. Defaults to ``'vit_base'``.
        patch_size (int, optional): Patch size passed to the architecture
            constructor. Defaults to 16.

    Attributes:
        model: The ViT backbone.
        embed_dim (int): Embedding dimension reported by the backbone.
        tfms: Preprocessing pipeline: ``ToTensor`` followed by ImageNet
            mean/std ``Normalize``.

    Methods:
        load_pretrained_weights(pretrained_weights): Load a checkpoint into ``model``.
        get_representation(image): Return ``(cls_token, patch_tokens)`` for one image.
        __call__(x): Run ``model.forward_features`` on an already-preprocessed batch.
    """

    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        # Inference only: eval mode and no gradients for any parameter.
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)

        # Local import: torchvision is only needed to build the preprocessing pipeline.
        from torchvision import transforms

        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        """Load the checkpoint at ``pretrained_weights`` into ``self.model``."""
        # This name resolves to the module-level helper imported from ``utils``,
        # not to this method: methods are not part of the global namespace.
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """Compute the L2-normalised class token and patch tokens of one image.

        Args:
            image: A single image accepted by ``self.tfms`` (``ToTensor`` takes
                a PIL image or an HxWxC ndarray). This is image data, not a
                file path.

        Returns:
            tuple[ndarray, ndarray]: ``(cls_token, patch_tokens)``. Here
            ``cls_token`` is the first token, with shape (C,), and
            ``patch_tokens`` are the remaining tokens, with shape (N - 1, C).
            ``N`` is the number of tokens ``forward_features`` returns for the
            image. Every row has unit L2 norm.
        """
        x = self.tfms(image)[None]  # add a batch dimension: 1, C, H, W
        # no_grad keeps inference free of autograd bookkeeping even if gradient
        # tracking is ever re-enabled on the model.
        with torch.no_grad():
            # assumes forward_features returns (B, N, C) with the class token
            # first — TODO confirm against vision_transformer
            tokens = self.model.forward_features(x)[0]  # N, C (class token + patches)
        tokens = nn.functional.normalize(tokens, dim=-1, p=2).cpu().numpy()
        cls_token = tokens[0]  # C
        patch_tokens = tokens[1:]  # N - 1, C
        return cls_token, patch_tokens

    def __call__(self, x):
        """Run the backbone's ``forward_features`` on a preprocessed batch ``x``."""
        return self.model.forward_features(x)
# Value passed as ``shape=`` to every Gradio image input below.
default_shape = (224, 224)

# Backbone loaded once at import time and shared by classify() and segment().
WEIGHTS_PATH = 'weights/mmc.pth'
embedding = PatchEmbedding(WEIGHTS_PATH)
def classify(query_image, support_image):
    """Score how similar two images are using their ViT class tokens.

    Both images are embedded with the module-level ``embedding``. The class
    tokens it returns are L2-normalised, so their dot product is the cosine
    similarity.

    Args:
        query_image: Image from the "Query Image" input.
        support_image: Image from the "Support Image" input.

    Returns:
        str: Cosine similarity times 100, formatted with two decimals. If
        either input is missing (``None``), returns a prompt asking for both
        images instead.
    """
    # Presumably Gradio passes None for an image input left empty; the
    # ToTensor transform would otherwise fail with an opaque TypeError.
    if query_image is None or support_image is None:
        return "Please provide both a query image and a support image."
    q_cls, _ = embedding.get_representation(query_image)
    s_cls, _ = embedding.get_representation(support_image)
    sim = (q_cls * s_cls).sum() * 100
    return f"{sim:.2f}"
def segment(threshold, input):
    """Keep the image patches that resemble the user-sketched region.

    The patch tokens under the sketched mask are averaged into a query vector.
    Each patch's dot product with that vector is thresholded into a patch-level
    mask. That mask is bilinearly upsampled to the image size and multiplied
    into the image.

    Args:
        threshold (float): Patches whose score is strictly greater than this
            value are kept.
        input (dict): Sketch-tool payload with ``'image'`` (assumed HxWxC
            ndarray) and ``'mask'`` (non-zero where the user drew; may be HxW
            or HxWxC, and only the first channel is used).

    Returns:
        ndarray | None: uint8 image with pixels outside matching patches set to
        zero (soft edges come from the bilinear upsampling). Returns an
        all-zero image when nothing was sketched, and ``None`` when no input
        was provided.
    """
    import math

    import numpy as np

    # Presumably Gradio passes None when no image has been uploaded.
    if input is None:
        return None
    image = input['image']
    mask = input['mask']

    patch_tokens = embedding.get_representation(image)[1]  # N - 1, C
    # The patch grid is square (14 x 14 for a 224 x 224 input with 16 px patches).
    grid = math.isqrt(len(patch_tokens))
    if grid * grid != len(patch_tokens):
        raise ValueError(
            f"expected a square patch grid, got {len(patch_tokens)} patch tokens"
        )

    if mask.ndim == 3:
        mask = mask[:, :, 0]
    # INTER_AREA averages every pixel a patch covers, so any stroke inside a
    # patch selects it. The default INTER_LINEAR only samples a few pixels per
    # patch at this downscale factor and can miss thin strokes entirely.
    drawn = (mask > 0).astype('float32')
    select = (
        cv2.resize(drawn, (grid, grid), interpolation=cv2.INTER_AREA) > 0
    ).flatten()
    if not select.any():
        # Nothing sketched: the mean below would be NaN (with a RuntimeWarning)
        # and no patch could pass the threshold, so return the all-zero result directly.
        return np.zeros_like(image, dtype='uint8')

    q_pat = patch_tokens[select].mean(0)  # C
    # NOTE(review): q_pat is not re-normalised, so this score is a dot product
    # with the mean embedding rather than an exact cosine similarity.
    sim = patch_tokens @ q_pat[:, None]  # N - 1, 1
    patch_mask = (sim.reshape(grid, grid) > threshold).astype('float')
    height, width = image.shape[:2]
    patch_mask = cv2.resize(patch_mask, (width, height))
    ans = image * patch_mask[:, :, None]
    return ans.astype('uint8')
# Classification tab: compares the class tokens of a query and a support image.
classification_inputs = [
    gr.inputs.Image(label="Query Image", shape=default_shape),
    gr.inputs.Image(label="Support Image", shape=default_shape),
]
classification_tab = gr.Interface(
    fn=classify,
    inputs=classification_inputs,
    outputs=gr.outputs.Textbox(label="Prediction"),
    title="Classification",
)

# Segmentation tab: a threshold slider plus an image the user sketches a region on.
segmentation_inputs = [
    gr.inputs.Slider(0, 1, step=0.01, label="Threshold", default=0.8),
    gr.inputs.Image(label="Input Image", tool="sketch", shape=default_shape),
]
segmentation_tab = gr.Interface(
    fn=segment,
    inputs=segmentation_inputs,
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
    title="Segmentation",
)

# Combine both demos into one tabbed app and start the server.
tab_titles = ["Classification", "Segmentation"]
interface = gr.TabbedInterface([classification_tab, segmentation_tab], tab_titles)
interface.launch()