PraneshJs's picture
Added files to hf space
fff6782 verified
raw
history blame
1.39 kB
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf
# Load the trained model once, at import time, so every prediction request
# reuses the same instance. The path 'mnist_model.h5' is relative, so the file
# must be reachable from the process's working directory.
model = tf.keras.models.load_model('mnist_model.h5')
def cnn_predict_digit(image):
    """Classify a hand-drawn digit with the module-level ``model``.

    Args:
        image: Either an image array (or array-like), or the dictionary
            emitted by a Gradio Sketchpad, in which case its
            ``'composite'`` entry is used. Arrays may be 2-D grayscale or
            3-D with 3 (RGB) or 4 (RGBA) channels.

    Returns:
        The predicted class index as a string, or ``None`` when there is
        nothing to classify (no image supplied, or an empty image).
    """
    # Handle Gradio Sketchpad dictionary input
    if isinstance(image, dict) and 'composite' in image:
        image = image['composite']

    # Submitting an empty/cleared sketchpad yields no image; return None
    # instead of crashing on ``image.ndim`` below.
    if image is None:
        return None

    image = np.asarray(image)
    if image.size == 0:
        # cv2.resize cannot handle a zero-sized image.
        return None

    # Collapse colour input to a single channel.
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        elif channels == 4:
            # The alpha channel is discarded; without this branch the
            # reshape to (1, 28, 28, 1) below would fail for RGBA input.
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)

    # Invert colors (white background -> black background)
    image = 255 - image

    # Resize to 28x28
    image = cv2.resize(image, (28, 28))

    # Normalize to [0, 1] and add batch and channel axes
    image = image.astype('float32') / 255.0
    image = image.reshape(1, 28, 28, 1)

    # Predict: take the highest-scoring class for the single batch item
    prediction = model.predict(image)
    pred_label = np.argmax(prediction, axis=1)[0]
    return str(pred_label)
# Build the web UI: a heading, a sketchpad next to a label that shows the
# prediction, a Submit button that runs the classifier, and a Clear button
# that resets both components.
with gr.Blocks() as interface:
    gr.Markdown(
        """
## ✍️ Digit Classification using Convolutional Neural Network
Draw a digit in the sketchpad below (0 to 9), then click **Submit** to see the prediction.
"""
    )

    # Drawing canvas and prediction output share one row.
    with gr.Row():
        sketchpad = gr.Sketchpad(image_mode='L')
        output = gr.Label()

    # Clicking Submit sends the sketch to the classifier and shows the result.
    submit_button = gr.Button("Submit")
    submit_button.click(fn=cnn_predict_digit, inputs=sketchpad, outputs=output)

    # Clear resets both the canvas and the displayed prediction.
    gr.ClearButton(components=[sketchpad, output])

# Start the Gradio server.
interface.launch()