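# Gradio demo of a small Vision Transformer (ViT) trained on FashionMNIST.
# The script loads a pre-trained model from "vit01.pt" and classifies a
# randomly sampled test image on each button click.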
import gradio as gr
from einops import rearrange
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor, Pad
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
device = "cpu"
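# Placeholder transformer that simply returns its input unchanged; it is not
# used by MyViT (MyViT builds the real Transformer defined further below).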
class Transformer_dummy(nn.Module):
    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=2):
        super().__init__()

    def forward(self, x):
        return x
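# Minimal ViT: the image is split into non-overlapping patches, each patch is
# layer-normalized and linearly projected to `dim`, a learned positional
# embedding is added, the sequence runs through the Transformer, and the
# mean-pooled representation is projected to the class logits.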
class MyViT(nn.Module):
    def __init__(self, image_size, patch_size, dim, n_classes=len(labels_map), device=device, depth=5):
        super().__init__()
        self.image_size = image_size  # height == width
        self.patch_size = patch_size  # height == width
        self.dim = dim  # dim of latent space for each patch
        self.n_classes = n_classes
        self.nh = self.nw = image_size // patch_size
        self.n_patches = self.nh * self.nw  # number of patches, i.e. NLP's seq len
        self.layernorm1 = nn.LayerNorm(self.patch_size**2)
        self.ln = nn.Linear(self.patch_size**2, dim)
        self.layernorm2 = nn.LayerNorm(dim)
        self.pos_encoding = nn.Embedding(self.n_patches, self.dim)
        self.transformer = Transformer(dim=self.dim, depth=depth)
        #self.proj = nn.Linear(self.dim * self.n_patches, self.n_classes)
        self.proj = nn.Linear(self.dim, self.n_classes)

    def forward(self, x):
        # rearrange 'b c (nh ph) (nw pw) -> b nh nw (c ph pw)'
        x = rearrange(x, 'b c (nh ph) (nw pw) -> b nh nw (c ph pw)', nh=self.nh, nw=self.nw)
        # rearrange 'b nh nw d -> b (nh nw) d'
        x = rearrange(x, 'b nh nw d -> b (nh nw) d')
        x = self.layernorm1(x)
        x = self.ln(x)  # (b n_patches patch_size*patch_size) -> (b n_patches dim)
        x = self.layernorm2(x)
        pos = self.pos_encoding(torch.arange(0, self.n_patches).to(device))
        x = x + pos
        x = self.transformer(x)
        #x = self.proj(x.view(x.shape[0],-1))
        x = self.proj(x.mean(dim=1))
        return x
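# Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.
# The residual connection is added by the Transformer wrapper, not here.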
class MLPBlock(nn.Module):
    def __init__(self, dim, mlp_hidden_dim=4096, dropout=0.):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.proj1 = nn.Linear(dim, mlp_hidden_dim)
        self.proj2 = nn.Linear(mlp_hidden_dim, dim)
        self.activation = nn.GELU()

    def forward(self, x):
        x = self.layernorm(x)
        x = self.proj1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.proj2(x)
        x = self.dropout2(x)
        return x
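# Pre-norm multi-head self-attention with a fused qkv projection.
# No attention mask is applied: every patch may attend to every other patch.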
class AttentionBlock(nn.Module):
    def __init__(self, dim, attention_heads=8, depth=2, dropout=0.):
        super().__init__()
        self.dim = dim
        self.attention_heads = attention_heads
        self.layernorm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, 3*dim)
        self.attention = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layernorm(x)
        q, k, v = self.proj(x).chunk(3, dim=-1)
        # rearrange to b, num_heads, seq, head_size (k is produced already transposed)
        q = rearrange(q, 'b s (nh hs) -> b nh s hs', nh=self.attention_heads)
        k = rearrange(k, 'b s (nh hs) -> b nh hs s', nh=self.attention_heads)
        v = rearrange(v, 'b s (nh hs) -> b nh s hs', nh=self.attention_heads)
        # attention scores q@kT
        x = q@k
        # scale (note: since k was already transposed, k.shape[-1] here is the sequence length)
        x = x * (k.shape[-1] ** -0.5)
        # attention mask not needed
        #x = x.mask_fill(torch.ones((1,1, k.shape[-1], k.shape[-1])).tril())
        # attention softmax
        x = self.attention(x)
        # drop out
        x = self.drop(x)
        # attention weights @ v
        x = x@v
        # rearrange back to b, seq, (num_heads, head_size)
        x = rearrange(x, 'b nh s hs -> b s (nh hs)', nh=self.attention_heads)
        return x
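# Stack of alternating attention and MLP blocks, each wrapped in a residual
# connection. Note that the list multiplication below reuses the same
# AttentionBlock and MLPBlock instances at every repetition, so their weights
# are shared across the depth.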
class Transformer(nn.Module):
    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=5):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        self.net = nn.ModuleList([AttentionBlock(dim=dim), MLPBlock(dim=dim)] * depth)

    def forward(self, x):
        for m in self.net:
            x = x + m(x)
        x = self.layernorm(x)
        return x
data_test = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=transforms.Compose([Pad([2,2,2,2]), ToTensor()]))
model = torch.load("vit01.pt", map_location=torch.device('cpu')).to("cpu")
model.eval()
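# Draw one random FashionMNIST test image, run it through the model, and return
# the image together with a text summary of the top-3 predicted classes.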
@torch.no_grad()
def generate():
    dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=4)
    image_eval, label_eval = next(iter(dl_test))
    image_eval = image_eval - 0.5
    logits = model(image_eval)
    probability = torch.nn.functional.softmax(logits, dim=1)[-1]
    n_topk = 3
    topk = probability.topk(n_topk, dim=-1)
    result = "Predictions (top 3):\n"
    print(topk.indices)
    for idx in range(n_topk):
        print(topk.indices[idx].item())
        label = labels_map[topk.indices[idx].item()]
        prob = topk.values[idx].item()
        print(prob)
        label = label + ":"
        label = f'{label: <12}'
        result = result + label + " " + f'{prob*100:.2f}' + "%\n"
    return (image_eval+0.5)[0].squeeze().detach().numpy(), result
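# Gradio UI: a single button samples a random test image and shows it next to
# the predicted classes.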
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ViT (Vision Transformer) Model</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])
    sampling_button = gr.Button("Random image and classification")
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""<h3 align="left">Random image</h3>""")
            gr_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Classification</h3>""")
            gr_text = gr.Text(label="Classification")
    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

demo.queue().launch(share=False, inbrowser=True)