import gradio as gr
import torch
from torch import nn
import torchvision
from torchvision.transforms import ToTensor

# FashionMNIST test split, converted to tensors by ToTensor and downloaded on
# first use. A forward-slash relative path is portable across operating
# systems; the previous ".\data" literal relied on the invalid escape "\d"
# (SyntaxWarning on Python 3.12+) and created a folder literally named ".\data"
# on POSIX systems.
data_test = torchvision.datasets.FashionMNIST(
    root="./data", train=False, transform=ToTensor(), download=True
)

class MyVAE(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The encoder downsamples twice (28 -> 14 -> 7) and projects to a
    3-channel 7x7 latent map; the decoder upsamples it back to a 1-channel
    28x28 image squashed into (0, 1) by a final Sigmoid.

    Despite the name there is no sampling step or KL term: ``forward`` is a
    deterministic encode/decode scored with a plain MSE reconstruction loss.

    NOTE: the order of layers inside ``encoder``/``decoder`` determines the
    ``state_dict`` keys (``encoder.<index>.*``); do not reorder, insert or
    remove layers or previously saved weights may stop matching.
    """

    def __init__(self):
        super().__init__()
        # Layer-group comments follow diffusers-style naming; shapes are
        # (channels, height, width) for a 28x28 input.
        self.encoder = nn.Sequential(
            # (conv_in) -> (32, 28, 28)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),

            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1) -> (32, 28, 28)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2) -> (32, 28, 28)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv): stride 2 -> (32, 14, 14)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),

            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1) -> (64, 14, 14)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2) -> (64, 14, 14)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # A conv_shortcut used to sit between these two activations; it was
            # dropped, leaving back-to-back SiLUs. Both are kept so the module
            # indices (and hence state_dict keys) stay unchanged.
            nn.SiLU(),
            # (downsamplers)(conv): stride 2 -> (64, 7, 7)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),

            # (conv_norm_out)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act)
            nn.SiLU(),
            # (conv_out): latent -> (3, 7, 7)
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )

        self.decoder = nn.Sequential(
            # (conv_in) -> (64, 7, 7)
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),

            # (norm1)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1) -> (32, 7, 7)
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2) -> (32, 7, 7)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),

            # (upsamplers) -> (32, 14, 14)
            nn.Upsample(scale_factor=2, mode='nearest'),

            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1) -> (16, 14, 14)
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2) -> (16, 14, 14)
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),

            # (upsamplers) -> (16, 28, 28)
            nn.Upsample(scale_factor=2, mode='nearest'),

            # (norm1)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1) -> (1, 28, 28)
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),

            # Squash reconstructions into (0, 1).
            nn.Sigmoid()
        )

    def forward(self, xb, yb=None):
        """Reconstruct ``xb`` and score the reconstruction.

        Args:
            xb: Input batch of shape (N, 1, H, W); the layer comments assume
                H = W = 28.
            yb: Unused (reconstruction needs no labels); accepted so existing
                ``model(xb, yb)`` calls keep working.

        Returns:
            ``(reconstruction, loss)``: the decoded batch (same shape as
            ``xb`` for 28x28 inputs) and the scalar mean-squared error
            between it and ``xb``.
        """
        recon = self.decoder(self.encoder(xb))
        # ``F`` is never imported in this file; reach the functional API via
        # ``nn`` so computing the loss no longer raises NameError.
        return recon, nn.functional.mse_loss(recon, xb)

# Human-readable names for the integer class targets returned by data_test.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# The checkpoint holds a whole pickled module (it exposes .encoder/.decoder
# below), not just a state_dict; presumably it was pickled as
# __main__.MyVAE by the training script — TODO confirm. torch >= 2.6 defaults
# to weights_only=True, which refuses to unpickle arbitrary classes, so opt out
# explicitly (the keyword exists since torch 1.13).
# NOTE(security): weights_only=False executes pickle code from the file; only
# load checkpoints you trust.
myVAE2 = torch.load(
    "checkpoint10.pt",
    map_location=torch.device('cpu'),
    weights_only=False,
).to("cpu")
# Inference mode: disables the Dropout layers.
myVAE2.eval()
  
def sample():
    """Pick a random test image and run it through the loaded autoencoder.

    Returns:
        A 4-tuple matching the Gradio outputs: the original image, the latent
        map rendered as an image, the reconstruction (all PIL images), and the
        human-readable class label.
    """
    idx = torch.randint(0, len(data_test), (1,)).item()
    # Index the dataset once; every access re-reads and re-transforms the item.
    img, label = data_test[idx]
    img = img.squeeze()  # drop singleton dims (presumably (1, 28, 28) -> (28, 28))

    to_pil = torchvision.transforms.functional.to_pil_image

    # Pure inference: skip autograd bookkeeping and run the encoder only once.
    with torch.no_grad():
        latent = myVAE2.encoder(img[None, None, ...])
        decoded = myVAE2.decoder(latent)

    latent = latent.squeeze()
    # The latent is raw conv output and may lie outside [0, 1]; to_pil_image
    # scales floats by 255 and casts to uint8, which would wrap such values.
    # Min-max rescale so the visualisation is meaningful.
    lo, hi = latent.min(), latent.max()
    latent = (latent - lo) / (hi - lo).clamp_min(1e-8)

    return (
        to_pil(img),
        to_pil(latent),
        to_pil(decoded.squeeze()),
        labels_map[label],
    )

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Variational Autoencoder</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")

    sampling_button = gr.Button("Sample random FashionMNIST image")

    # Three side-by-side panels: input image, latent map, reconstruction.
    # Headings are <h3>, so they must be closed with </h3> (was </h1>).
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Original image</h3>""")
            sample_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Encoded image</h3>""")
            encoded_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Decoded image</h3>""")
            decoded_image = gr.Image(height=250, width=200)

    image_label = gr.Label(label="Image label")

    # sample() takes no inputs and returns one value per output, in this order.
    sampling_button.click(
        sample,
        [],
        [sample_image, encoded_image, decoded_image, image_label],
    )

demo.queue().launch(share=False, inbrowser=True)