# HienK64BKHN's picture
# Update app.py
# 905f281 verified
import gradio as gr
import torch
from Unet import UNet
import torchvision
from torchvision.transforms import functional as f
import os
from timeit import default_timer as timer
# All inference in this app is pinned to the CPU.
device = 'cpu'

# Build the 3-channel-in / 3-class-out U-Net and restore its trained weights.
# NOTE: torch.load unpickles the checkpoint file, so only load trusted files.
model = UNet(device=device, in_channels=3, num_classes=3)
_state_dict = torch.load(
    "./data/models/Unet_v1.pth",
    map_location=torch.device('cpu'),
)
model.load_state_dict(_state_dict)

# Preprocessing applied to every incoming picture: resize to 128x128, then
# convert to a tensor.
image_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(128, 128)),
        torchvision.transforms.ToTensor(),
    ]
)
def predict(img):
    """Segment *img* with the U-Net and report how long the prediction took.

    Args:
        img: Input picture as a PIL image (the Gradio input component is
            declared with ``type="pil"``).

    Returns:
        A ``(mask, seconds)`` tuple. ``mask`` is a PIL image rendered from the
        per-pixel predicted class indices, replicated across the channels of
        the transformed input. ``seconds`` is the elapsed time rounded to
        three decimals.
    """
    start_time = timer()

    # The model is created with in_channels=3, so collapse any other PIL mode
    # (grayscale, RGBA, palette, ...) to RGB before building the tensor.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_transformed = image_transforms(img).to(device)

    model.eval()
    with torch.inference_mode():
        y_logits = model(img_transformed.unsqueeze(dim=0)).squeeze(dim=0)

    # Highest-scoring class per pixel; presumably (H, W) given the per-pixel
    # indexing the original rendering loop relied on.
    predicted_label = torch.argmax(y_logits, dim=0).to('cpu')

    # Replicate the label map across the channels in one vectorized step
    # instead of writing each of the 3 * 128 * 128 elements from Python.
    num_channels = img_transformed.shape[0]
    mask = (
        predicted_label
        .to(img_transformed.dtype)
        .unsqueeze(dim=0)
        .repeat(num_channels, 1, 1)
    )

    # NOTE(review): to_pil_image scales float tensors by 255 before casting to
    # uint8, so a label of 2 lands outside the uint8 range. The rendering is
    # kept as it was; consider dividing by (num_classes - 1) if the three
    # classes should be visually distinct.
    segmentation = f.to_pil_image(mask)
    return segmentation, round((timer() - start_time), 3)
# Static text shown on the Gradio demo page.
title = "Animal Segmentation"
description = (
    "An UNet* feature extractor computer vision model to segment animal in an image.\n"
    "Model works more precisely on an image that only contains just one animal."
)
article = (
    "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    " (https://arxiv.org/abs/1505.04597)"
)
# Build the Gradio examples (one single-element row per file) from the
# "examples/" directory. os.listdir() returns entries in arbitrary order and
# also yields sub-directories, so sort the names and keep regular files only
# to get a stable list of loadable paths.
example_list = [
    ["examples/" + example]
    for example in sorted(os.listdir("examples"))
    if os.path.isfile(os.path.join("examples", example))
]
# Assemble the Gradio interface: one uploaded picture goes in; the
# segmentation image and the prediction time come out (predict returns two
# values, hence two output components).
_interface_kwargs = dict(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Segmentation"),
        gr.Number(label="Prediction time (s)"),
    ],
    # Example rows come from the "examples/" directory listing.
    examples=example_list,
    title=title,
    description=description,
    article=article,
)
demo = gr.Interface(**_interface_kwargs)

# Start serving the demo.
demo.launch()