import gradio as gr
from einops import rearrange
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor, Pad

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

device = "cpu"


# identity placeholder; not used by MyViT below
class Transformer_dummy(nn.Module):
    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=2):
        super().__init__()

    def forward(self, x):
        return x


class MyViT(nn.Module):
    def __init__(self, image_size, patch_size, dim, n_classes=len(labels_map), device=device, depth=5):
        super().__init__()
        self.image_size = image_size  # height == width
        self.patch_size = patch_size  # height == width
        self.dim = dim  # dim of latent space for each patch
        self.n_classes = n_classes
        self.nh = self.nw = image_size // patch_size
        self.n_patches = self.nh * self.nw  # number of patches, i.e. NLP's seq len
        self.layernorm1 = nn.LayerNorm(self.patch_size**2)
        self.ln = nn.Linear(self.patch_size**2, dim)  # patch pixels -> latent dim
        self.layernorm2 = nn.LayerNorm(dim)
        self.pos_encoding = nn.Embedding(self.n_patches, self.dim)
        self.transformer = Transformer(dim=self.dim, depth=depth)
        # self.proj = nn.Linear(self.dim * self.n_patches, self.n_classes)
        self.proj = nn.Linear(self.dim, self.n_classes)

    def forward(self, x):
        # split the image into patches: 'b c (nh ph) (nw pw) -> b nh nw (c ph pw)'
        x = rearrange(x, 'b c (nh ph) (nw pw) -> b nh nw (c ph pw)', nh=self.nh, nw=self.nw)
        # flatten the patch grid into a sequence: 'b nh nw d -> b (nh nw) d'
        x = rearrange(x, 'b nh nw d -> b (nh nw) d')
        x = self.layernorm1(x)
        x = self.ln(x)  # (b, n_patches, patch_size*patch_size) -> (b, n_patches, dim)
        x = self.layernorm2(x)
        pos = self.pos_encoding(torch.arange(0, self.n_patches).to(device))
        x = x + pos
        x = self.transformer(x)
        # x = self.proj(x.view(x.shape[0], -1))
        x = self.proj(x.mean(dim=1))  # mean-pool over patches, then classify
        return x


class MLPBlock(nn.Module):
    def __init__(self, dim, mlp_hidden_dim=4096, dropout=0.):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.proj1 = nn.Linear(dim, mlp_hidden_dim)
        self.proj2 = nn.Linear(mlp_hidden_dim, dim)
        self.activation = nn.GELU()

    def forward(self, x):
        x = self.layernorm(x)
        x = self.proj1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.proj2(x)
        x = self.dropout2(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(self, dim, attention_heads=8, depth=2, dropout=0.):
        super().__init__()
        self.dim = dim
        self.attention_heads = attention_heads
        self.layernorm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, 3 * dim)  # joint q, k, v projection
        self.attention = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layernorm(x)
        q, k, v = self.proj(x).chunk(3, dim=-1)
        # rearrange to (batch, num_heads, seq, head_size); k is transposed for the matmul
        q = rearrange(q, 'b s (nh hs) -> b nh s hs', nh=self.attention_heads)
        k = rearrange(k, 'b s (nh hs) -> b nh hs s', nh=self.attention_heads)
        v = rearrange(v, 'b s (nh hs) -> b nh s hs', nh=self.attention_heads)
        # attention scores q @ kT
        x = q @ k
        # scale (note: k was transposed above, so k.shape[-1] is the sequence length)
        x = x * (k.shape[-1] ** -0.5)
        # attention mask not needed
        # x = x.mask_fill(torch.ones((1, 1, k.shape[-1], k.shape[-1])).tril())
        # attention softmax
        x = self.attention(x)
        # dropout
        x = self.drop(x)
        # weighted sum of values
        x = x @ v
        # rearrange back to (batch, seq, num_heads * head_size)
        x = rearrange(x, 'b nh s hs -> b s (nh hs)', nh=self.attention_heads)
        return x
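
# A minimal shape sketch (not called by the app): a single AttentionBlock maps
# (batch, seq, dim) -> (batch, seq, dim). The sizes below are illustrative
# assumptions, not values taken from the trained checkpoint.
def _attention_shape_example():
    block = AttentionBlock(dim=64, attention_heads=8)
    tokens = torch.zeros(2, 16, 64)   # (batch, n_patches, dim)
    return block(tokens).shape        # -> torch.Size([2, 16, 64])
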
class Transformer(nn.Module):
    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=5):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        # one attention block followed by one MLP block per layer; each layer
        # gets its own instances so weights are not shared across layers
        blocks = []
        for _ in range(depth):
            blocks += [AttentionBlock(dim=dim), MLPBlock(dim=dim)]
        self.net = nn.ModuleList(blocks)

    def forward(self, x):
        # residual connection around every block, final layer norm at the end
        for m in self.net:
            x = x + m(x)
        x = self.layernorm(x)
        return x


data_test = torchvision.datasets.FashionMNIST(
    root='./data/',
    train=False,
    download=True,
    # pad the 28x28 images to 32x32 so they split evenly into patches
    transform=transforms.Compose([Pad([2, 2, 2, 2]), ToTensor()]),
)

# load the full pickled model; unpickling relies on the class definitions above
model = torch.load("vit01.pt", map_location=torch.device('cpu')).to("cpu")
model.eval()


@torch.no_grad()
def generate():
    # draw one random test image per button click
    dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=4)
    image_eval, label_eval = next(iter(dl_test))
    image_eval = image_eval - 0.5  # shift pixels from [0, 1] to [-0.5, 0.5]
    logits = model(image_eval)
    probability = torch.nn.functional.softmax(logits, dim=1)[-1]
    n_topk = 3
    topk = probability.topk(n_topk, dim=-1)
    result = "Predictions (top 3):\n"
    print(topk.indices)
    for idx in range(n_topk):
        print(topk.indices[idx].item())
        label = labels_map[topk.indices[idx].item()]
        prob = topk.values[idx].item()
        print(prob)
        label = label + ":"
        label = f'{label: <12}'
        result = result + label + " " + f'{prob*100:.2f}' + "%\n"
    # undo the shift so the displayed image is back in [0, 1]
    return (image_eval + 0.5)[0].squeeze().detach().numpy(), result
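
# Optional offline spot check (not wired into the UI): a small sketch that runs
# the loaded model over a few random test images and prints predicted vs. true
# labels. The sample count is an illustrative assumption.
@torch.no_grad()
def _spot_check(n_samples=8):
    dl = torch.utils.data.DataLoader(data_test, batch_size=n_samples, shuffle=True)
    images, targets = next(iter(dl))
    preds = model(images - 0.5).argmax(dim=-1)
    for p, t in zip(preds.tolist(), targets.tolist()):
        print(f"predicted: {labels_map[p]:<12} true: {labels_map[t]}")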

with gr.Blocks() as demo:
    gr.HTML("""
        ViT (Vision Transformer) Model
    """)
    gr.HTML("""
        trained with FashionMNIST
    """)

    session_data = gr.State([])
    sampling_button = gr.Button("Random image and zero-shot classification")

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""
                Random image
            """)
            gr_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""
                Classification
            """)
            gr_text = gr.Text(label="Classification")

    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

demo.queue().launch(share=False, inbrowser=True)