Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update MAE-ViT image reconstruction descriptions and add links to model card and GitHub repository
		f1a7938
		
		| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| import random | |
| from einops import rearrange | |
| import matplotlib.pyplot as plt | |
| from torchvision.transforms import v2 | |
| from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor | |
# Example images offered under each tab. Each tab gets its own (equal) list of
# single-element rows, one row per example file.
_EXAMPLE_NAMES = ("cat", "dog", "horse", "airplane", "truck")
path_1 = [[f"images/{name}.jpg"] for name in _EXAMPLE_NAMES]
path_2 = [[f"images/{name}.jpg"] for name in _EXAMPLE_NAMES]
path_3 = [[f"images/{name}.jpg"] for name in _EXAMPLE_NAMES]
device = torch.device("cpu")


def _load_model(model_path):
    """Load a pickled MAE model from *model_path*, put it in eval mode on ``device``.

    The checkpoint's unpickled object is used directly as a module (``.eval()``,
    ``.to()``, and later ``model(img)``), i.e. the ``.pt`` files hold whole
    model objects rather than state dicts.

    Args:
        model_path: path to the ``.pt`` checkpoint file.

    Returns:
        The loaded model, in eval mode, on ``device``.
    """
    # weights_only=False is required to unpickle a full model object; since
    # PyTorch 2.6 the default is weights_only=True, which rejects such files.
    # SECURITY NOTE: unpickling can execute arbitrary code -- only ever point
    # this at checkpoints shipped with this app, never at user-supplied files.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    model.to(device)
    return model


# Plain MAE (no PCA), and MAE trained with the bottom-25% / top-75% PCA modes.
model_no_mode = _load_model("model/no_mode/vit-t-mae-pretrain.pt")
model_pca_mode_bottom = _load_model("model/bottom_25/vit-t-mae-pretrain.pt")
model_pca_mode_top = _load_model("model/top_75/vit-t-mae-pretrain.pt")
# Preprocessing shared by all tabs: resize to 96x96 (the resolution at which
# images are fed to the models), convert to float32 in [0, 1], then map each
# channel to [-1, 1] via (x - 0.5) / 0.5.
transform = v2.Compose([
    v2.Resize((96, 96)),
    # v2.ToTensor() is deprecated in torchvision's v2 API; ToImage followed by
    # ToDtype(float32, scale=True) is the documented replacement and yields the
    # same [0, 1] float values for 8-bit RGB images.
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Load and Preprocess the Image
def load_image(image_path, transform):
    """Open an image file and turn it into a batched model input.

    Args:
        image_path: path to an image file readable by PIL.
        transform: callable applied to the RGB PIL image (e.g. the module-level
            ``transform``); expected to return a (C, H, W) tensor.

    Returns:
        The transformed tensor with a leading batch dimension of size 1.
    """
    # The context manager releases the underlying file handle promptly;
    # convert() returns a new, fully loaded image that outlives the `with`.
    with Image.open(image_path) as source:
        rgb = source.convert('RGB')
    return transform(rgb).unsqueeze(0)  # Add batch dimension
def show_image(img, title):
    """Draw one normalized CHW image tensor on the current pyplot axes.

    Args:
        img: tensor of shape (channels, height, width). Inputs produced by
            ``transform`` lie in [-1, 1]; model reconstructions may overshoot.
        title: title for the axes.
    """
    img = rearrange(img, "c h w -> h w c")
    # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]. Clip explicitly so
    # out-of-range reconstruction values don't make imshow warn on every call.
    img = np.clip((img.cpu().detach().numpy() + 1) / 2, 0.0, 1.0)
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
def _figure_to_array(fig):
    """Rasterize *fig* into a numpy RGBA array through an in-memory PNG.

    Replaces a round-trip through a fixed "output.png" on disk, which every
    request overwrote (a race if requests overlap) and which left a stray file
    in the working directory.
    """
    import io  # local import keeps this block self-contained

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    return np.array(plt.imread(buffer, format="png"))


def _visualize_single_image(model, image_path, paste_visible):
    """Run *model* on one image and render original / masked / reconstruction panels.

    Args:
        model: one of the loaded MAE models. It is called as ``model(img)`` and
            must return ``(predicted_img, mask)``. Both are combined
            element-wise with ``img`` (shape (1, C, H, W)), so they are assumed
            to broadcast against it -- TODO confirm against ``model.MAE_ViT``.
        image_path: filesystem path of the input image.
        paste_visible: if True, the third panel uses the prediction only where
            ``mask`` is 1 and the original pixels elsewhere; if False, it shows
            the raw model output.

    Returns:
        numpy RGBA float array of the rendered 1x3 figure.
    """
    img = load_image(image_path, transform).to(device)
    with torch.no_grad():
        predicted_img, mask = model(img)

    # Keep only the pixels where mask == 0 (the visible patches).
    im_masked = img * (1 - mask)
    if paste_visible:
        # MAE reconstruction pasted with the visible patches of the original.
        reconstruction = im_masked + predicted_img * mask
    else:
        reconstruction = predicted_img

    panels = (
        (img, "original"),
        (im_masked, "masked"),
        (reconstruction, "reconstruction"),
    )
    fig = plt.figure(figsize=(18, 8))
    try:
        for position, (batch, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            show_image(batch[0], title)  # [0] drops the batch dimension
        fig.tight_layout()
        return _figure_to_array(fig)
    finally:
        # Always release the figure: pyplot keeps every open figure alive, so
        # without this, memory grows with each request.
        plt.close(fig)


# Visualize a Single Image
def visualize_single_image_no_mode(image_path):
    """Gradio handler for the plain (no PCA) model.

    The third panel is the reconstruction pasted over the visible patches.
    """
    return _visualize_single_image(model_no_mode, image_path, paste_visible=True)


def visualize_single_image_pca_mode_bottom(image_path):
    """Gradio handler for the bottom-25% PCA-mode model; shows the raw prediction."""
    return _visualize_single_image(model_pca_mode_bottom, image_path, paste_visible=False)


def visualize_single_image_pca_mode_top(image_path):
    """Gradio handler for the top-75% PCA-mode model; shows the raw prediction."""
    return _visualize_single_image(model_pca_mode_top, image_path, paste_visible=False)
def _make_image_io():
    """Build one tab's Gradio I/O: ([filepath input image], [numpy output image]).

    New component instances are created on every call, so each tab gets its own.
    """
    input_component = gr.components.Image(type="filepath", label="Input Image")
    output_component = gr.components.Image(type="numpy", label="Output Image")
    return [input_component], [output_component]


inputs_image_1, outputs_image_1 = _make_image_io()
inputs_image_2, outputs_image_2 = _make_image_io()
inputs_image_3, outputs_image_3 = _make_image_io()
# Shared tail of every tab description: where to find more information.
_REPO_LINKS = (
    " Check out the huggingface model card and the github repository for more"
    " information. https://huggingface.co/turhancan97/MAE-Models and"
    " https://github.com/turhancan97/Learning-by-Reconstruction-with-MAE"
)


def _describe(training):
    """Return the tab description for a model trained ``training`` (e.g. "without PCA mode")."""
    return (
        "This is a demo of the MAE-ViT model for image reconstruction. "
        f"The model is trained {training}. It was trained on the STL-10 dataset."
        + _REPO_LINKS
    )


def _make_interface(fn, inputs, outputs, examples, training):
    """Create one tab's ``gr.Interface`` with the settings all tabs share.

    Previously only the first tab disabled example caching and linked the
    model card / repository; every tab now gets both.
    """
    return gr.Interface(
        fn=fn,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        cache_examples=False,
        title="MAE-ViT Image Reconstruction",
        description=_describe(training),
    )


inference_no_mode = _make_interface(
    visualize_single_image_no_mode, inputs_image_1, outputs_image_1, path_1,
    "without PCA mode",
)
inference_pca_mode_bottom = _make_interface(
    visualize_single_image_pca_mode_bottom, inputs_image_2, outputs_image_2, path_2,
    "with PCA mode (bottom 25%)",
)
inference_pca_mode_top = _make_interface(
    visualize_single_image_pca_mode_top, inputs_image_3, outputs_image_3, path_3,
    "with PCA mode (top 75%)",
)
# One tab per model variant, in the same order as the tab names.
demo = gr.TabbedInterface(
    [inference_no_mode, inference_pca_mode_bottom, inference_pca_mode_top],
    tab_names=['Normal Mode', 'PCA Mode (Bottom 25%)', 'PCA Mode (Top 75%)'],
)

if __name__ == "__main__":
    # Launch only when run as a script (e.g. `python app.py`), so importing
    # this module does not start a server. Gradio's request queue is enabled.
    demo.queue().launch()