# Author: Gent (PG/R - Comp Sci & Elec Eng)
# Commit 460258c: "Add application file" (3.96 kB)
import gradio as gr
import torch
import vision_transformer as models
import cv2
from torch import nn
from utils import load_pretrained_weights
class PatchEmbedding:
    """
    Wrap a pretrained ViT encoder that turns an image into token embeddings.

    Args:
        pretrained_weights (str): Path to the pretrained weight file.
        arch (str, optional): Name of the architecture constructor looked up in
            ``vision_transformer``. Defaults to ``'vit_base'``.
        patch_size (int, optional): Side length of the square patches the image
            is split into. Defaults to 16.

    Attributes:
        model: The ViT encoder, put in eval mode with gradients disabled.
        embed_dim (int): Dimension of each token embedding.
        tfms: Preprocessing pipeline (``ToTensor`` followed by ImageNet
            mean/std normalisation).

    Methods:
        load_pretrained_weights(pretrained_weights): Load weights into ``model``.
        get_representation(image): Return the L2-normalised CLS token and
            patch tokens of a single image.
        __call__(x): Run ``model.forward_features`` on a preprocessed batch.
    """

    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        # num_classes=0 presumably drops the classification head -- TODO confirm
        # against vision_transformer.
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)
        from torchvision import transforms
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        """Load the weights stored at ``pretrained_weights`` into ``self.model``."""
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """
        Embed a single image and return its CLS token and patch tokens.

        Args:
            image: One image accepted by ``self.tfms`` (``ToTensor`` takes a PIL
                image or an H x W x C numpy array).

        Returns:
            tuple: ``(cls_token, patch_tokens)`` as numpy arrays of shape
            ``(C,)`` and ``(N - 1, C)``, where ``N`` is the number of tokens
            produced by ``forward_features`` for this image. Every token is
            L2-normalised along the channel axis.
        """
        batch = self.tfms(image)[None]  # add a batch dimension: 1, C, H, W
        # no_grad keeps autograd out of this pure-inference pass, so .numpy()
        # cannot fail even if the model's parameters require gradients.
        with torch.no_grad():
            tokens = self.model.forward_features(batch)[0]  # N, C; index 0 is the CLS token
            tokens = nn.functional.normalize(tokens, dim=-1, p=2)
        tokens = tokens.cpu().numpy()
        cls_token = tokens[0]  # C
        patch_tokens = tokens[1:]  # N - 1, C
        return cls_token, patch_tokens

    def __call__(self, x):
        """Run ``model.forward_features`` on an already-preprocessed batch ``x``."""
        return self.model.forward_features(x)
# (height, width) every Gradio image input below is resized to.
default_shape = (224,) * 2
# One shared encoder, built and loaded once at import time and reused by both tabs.
embedding = PatchEmbedding(pretrained_weights='weights/mmc.pth')
def classify(query_image, support_image):
    """
    Score how similar two images are.

    Both images are embedded with the shared encoder. The dot product of their
    unit-length CLS tokens (a cosine similarity) is scaled by 100.

    Args:
        query_image: Image from the "Query Image" input.
        support_image: Image from the "Support Image" input.

    Returns:
        str: The scaled similarity formatted with two decimals.
    """
    query_cls, _ = embedding.get_representation(query_image)
    support_cls, _ = embedding.get_representation(support_image)
    score = (query_cls * support_cls).sum() * 100
    return f"{score:.2f}"
def segment(threshold, input):
    """
    Keep every region of an image that resembles a user-scribbled area.

    The patch tokens lying under the scribble are averaged into a query vector.
    Patches whose cosine similarity to that query exceeds ``threshold`` are
    kept; everything else is blacked out.

    Args:
        threshold (float): Cosine-similarity cut-off (the 0-1 slider value).
        input (dict): Gradio sketch payload with ``'image'`` (presumably an
            H x W x 3 array -- TODO confirm) and ``'mask'`` (the scribble; only
            its first channel is read).

    Returns:
        numpy.ndarray: ``uint8`` copy of the image with non-matching areas
        zeroed. If nothing was scribbled, the result is entirely black.
    """
    image = input['image']
    scribble = input['mask']
    patch_tokens = embedding.get_representation(image)[1]  # N - 1, C; unit-length rows
    # Patch tokens are assumed to form a square grid (14 x 14 for the 224 x 224
    # inputs this app uses).
    grid = round(len(patch_tokens) ** 0.5)
    height, width = image.shape[:2]
    # INTER_AREA averages every pixel covered by a patch, so any stroke touching a
    # patch selects it. The default bilinear filter only samples near patch
    # centres when shrinking and silently drops thin strokes. A float copy avoids
    # a faint average being rounded down to 0.
    coarse = cv2.resize(scribble[:, :, 0].astype('float32'), (grid, grid),
                        interpolation=cv2.INTER_AREA)
    select = (coarse > 0).flatten()
    if not select.any():
        # Nothing drawn: averaging zero tokens would only produce NaNs, and no
        # patch could match, so return an all-black image directly.
        return (image * 0).astype('uint8')
    q_pat = patch_tokens[select].mean(0)  # C
    # The mean of unit vectors is generally shorter than 1. Re-normalise so that
    # `sim` is a true cosine similarity on the same scale as the threshold slider.
    q_pat = q_pat / max(float((q_pat ** 2).sum()) ** 0.5, 1e-12)
    sim = patch_tokens @ q_pat[:, None]  # N - 1, 1
    keep = (sim.reshape(grid, grid) > threshold).astype('float')
    keep = cv2.resize(keep, (width, height))  # cv2 dsize is (width, height)
    ans = image * keep[:, :, None]
    return ans.astype('uint8')
# Classification tab: compare the CLS embeddings of a query and a support image.
query_input = gr.inputs.Image(label="Query Image", shape=default_shape)
support_input = gr.inputs.Image(label="Support Image", shape=default_shape)
classification_tab = gr.Interface(
    fn=classify,
    inputs=[query_input, support_input],
    outputs=gr.outputs.Textbox(label="Prediction"),
    title="Classification",
)

# Segmentation tab: scribble on an image and keep the patches that match it.
threshold_input = gr.inputs.Slider(0, 1, step=0.01, label="Threshold", default=0.8)
sketch_input = gr.inputs.Image(label="Input Image", tool="sketch", shape=default_shape)
segmentation_tab = gr.Interface(
    fn=segment,
    inputs=[threshold_input, sketch_input],
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
    title="Segmentation",
)

# Both interfaces are served as tabs of a single app; names follow the same order.
interface = gr.TabbedInterface(
    [classification_tab, segmentation_tab],
    ["Classification", "Segmentation"],
)
interface.launch()