# MNIST / app.py
# GbrlOl's picture
# Update app.py
# faa9b86 verified
import gradio as gr
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import numpy as np
from PIL import Image
# Pick the compute backend in order of preference: CUDA, then Apple MPS,
# falling back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
class CNN(nn.Module):
    """Three-stage convolutional classifier that outputs 10 raw class scores.

    Each stage is a 3x3 convolution (padding 1) followed by ReLU and a 2x2
    max-pool, widening the channels 1 -> 32 -> 64 -> 128. The pooled feature
    map is flattened into rows of ``128 * 3 * 3`` values, which matches a
    single-channel 28x28 input (28 -> 14 -> 7 -> 3 after the three poolings).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Fully connected classification head.
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 10)
        # Max-pooling (shared by every stage) and dropout.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the unnormalised class scores for the input batch ``x``."""
        # Run the three conv -> ReLU -> max-pool stages in sequence.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Flatten the feature maps so they can feed the linear layers.
        flat = x.view(-1, 128 * 3 * 3)
        # Hidden layer with dropout, then the final 10-way projection.
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
# Build the network once on the selected device (the original constructed it
# twice and discarded the first instance).
model = CNN().to(device)
# The checkpoint is deserialised onto the CPU first, so it can be read even
# when the accelerator it was saved from is absent; load_state_dict then
# copies the weights into the model's parameters on `device`.
# NOTE(review): torch.load unpickles the file. If the installed torch version
# supports it, passing weights_only=True would restrict loading to tensor
# data -- confirm against the deployed torch version before enabling.
model.load_state_dict(
    torch.load(
        "model_mnist_cnn_data_augmentation.pth",
        map_location=torch.device("cpu"),
    )
)
# Spanish display names of the ten digit classes, indexed by class id.
_DIGIT_NAMES = (
    "Cero",
    "Uno",
    "Dos",
    "Tres",
    "Cuatro",
    "Cinco",
    "Seis",
    "Siete",
    "Ocho",
    "Nueve",
)


def predict(im):
    """Classify the digit drawn on the sketchpad.

    Args:
        im: Sketchpad value, a mapping whose ``"composite"`` entry holds the
            drawing as an image with an alpha channel at index 3 of its last
            axis; the strokes are read from that channel. ``None`` (or a
            ``None`` composite) is treated as an empty canvas.

    Returns:
        A ``(composite, label)`` tuple: the composite image unchanged and the
        Spanish name of the predicted digit. For an empty canvas the result
        is ``(None, "")`` so the outputs are cleared instead of raising.
    """
    composite = None if im is None else im["composite"]
    if composite is None:
        return None, ""

    # The drawing lives in the alpha channel of the RGBA composite.
    alpha = np.array(composite)[:, :, 3]

    # Downscale to the 28x28 resolution the network was built for and scale
    # the pixel values into [0, 1].
    small = Image.fromarray(alpha, mode="L").resize((28, 28))
    pixels = np.array(small).astype(np.float32) / 255.0

    # (28, 28) -> (1, 1, 28, 28): add the channel axis, then the batch axis.
    batch = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0).to(device)

    model.eval()  # disable dropout for inference
    with torch.no_grad():
        scores = model(batch)

    # .item() yields a plain int index rather than indexing with a tensor.
    label = _DIGIT_NAMES[scores[0].argmax(0).item()]
    return composite, label
# Red-accented variant of the default Gradio theme.
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.red,
    secondary_hue=gr.themes.colors.red,
)

with gr.Blocks(theme=theme) as demo:
    # Page header. The accented characters were mis-encoded in the original
    # strings and are restored here so the UI text renders correctly.
    descripcion = (
        "# MNIST\n"
        "El siguiente sistema permite predecir dígitos del 0 al 9, "
        "utilizando el Sketchpad de gradio. "
        "Se entrenó una CNN con el dataset MNIST."
    )
    gr.Markdown(descripcion)

    with gr.Row():
        # Left column: drawing canvas. RGBA is required because predict()
        # reads the strokes from the alpha channel of the composite image.
        with gr.Column():
            im = gr.Sketchpad(type="pil", image_mode="RGBA")
        # Right column: predicted label and a preview of the drawing.
        with gr.Column():
            prediction_text = gr.Textbox(label="Predicción")
            im_preview = gr.Image()

    # Re-run the classifier every time the drawing changes.
    im.change(
        predict,
        inputs=im,
        outputs=[im_preview, prediction_text],
        show_progress="full",
    )

    gr.Markdown("> Gabriel Olmos Leiva")

demo.launch(share=True, debug=False)