Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import tensorflow as tf
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from PIL import Image, ImageOps
|
6 |
+
|
7 |
+
def load_model(path="mnist_model.h5"):
    """Return a Keras digit classifier, training a new one if loading fails.

    The function first tries to load a saved model from ``path``. If that
    fails, it builds a small dense network (Flatten -> Dense(128, relu) ->
    Dense(10, softmax)), trains it on the MNIST training split for 5 epochs,
    saves it to ``path`` and returns it.

    Args:
        path: File to load the model from, and to save a freshly trained
            model to. Defaults to ``"mnist_model.h5"``.

    Returns:
        The loaded model, or the newly trained fallback model.
    """
    import logging

    try:
        return tf.keras.models.load_model(path)
    except Exception:
        # Deliberate best-effort: any ordinary load failure falls back to
        # training. `Exception` (not a bare except) lets KeyboardInterrupt
        # and SystemExit propagate, and the reason is logged, not hidden.
        logging.getLogger(__name__).warning(
            "Could not load model from %s; training a new one.",
            path,
            exc_info=True,
        )

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # Scale pixel values from 0..255 to 0..1 before training.
    x_train = x_train / 255.0
    model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=0)

    # Persist so the next start-up can load instead of retraining.
    model.save(path)
    return model
|
24 |
+
|
25 |
+
# Module-level classifier shared by every request. Note this runs at import
# time, so importing the module may load a file or train a model.
model = load_model()
|
26 |
+
|
27 |
+
def segment_digits(image: Image.Image, min_area: int = 25):
    """Cut an image into per-digit 28x28 tiles, ordered left to right.

    The image is converted to grayscale, inverted (so dark strokes become
    bright foreground), upscaled 2x and thresholded at 128. Each external
    contour's bounding box is then turned into one tile.

    Unlike a plain stretch to 28x28, each crop is padded to a square, scaled
    to 20x20 and centred on a 28x28 canvas. This keeps the digit's aspect
    ratio and leaves a margin, which is closer to how MNIST digits are laid
    out.

    Args:
        image: PIL image to segment. ``None`` yields an empty list.
        min_area: Bounding boxes smaller than this many pixels (measured in
            the 2x-upscaled image) are treated as noise and skipped.

    Returns:
        A list of float arrays of shape ``(1, 28, 28)`` with values in
        ``[0, 1]``, one per detected digit, sorted by left edge.
    """
    if image is None:
        return []

    # Invert directly on the PIL image; no array -> PIL -> array round trip.
    inverted = np.array(ImageOps.invert(image.convert("L")))
    height, width = inverted.shape
    upscaled = cv2.resize(inverted, (width * 2, height * 2))
    _, thresh = cv2.threshold(upscaled, 128, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Compute each bounding box once, then order by x (left to right).
    boxes = sorted(
        (cv2.boundingRect(cnt) for cnt in contours),
        key=lambda box: box[0],
    )

    digit_images = []
    for x, y, w, h in boxes:
        if w * h < min_area:
            # Speck of noise: would otherwise be emitted as a bogus digit.
            continue

        crop = thresh[y:y + h, x:x + w]

        # Pad the shorter side with background so resizing does not
        # stretch narrow digits such as "1" across the whole tile.
        side = max(w, h)
        square = np.zeros((side, side), dtype=thresh.dtype)
        top = (side - h) // 2
        left = (side - w) // 2
        square[top:top + h, left:left + w] = crop

        # 20x20 digit centred in a 28x28 tile, scaled to 0..1.
        tile = np.zeros((28, 28))
        tile[4:24, 4:24] = cv2.resize(square, (20, 20)) / 255.0
        digit_images.append(tile.reshape(1, 28, 28))
    return digit_images
|
43 |
+
|
44 |
+
def classify_multi_digit(image):
    """Segment ``image`` into digits and classify each with the global model.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing
            was uploaded.

    Returns:
        A ``(text, confidences)`` tuple. ``text`` is
        ``"Predicted Number: <digits>"``. ``confidences`` maps
        ``"Digit <position> (<class>)"`` to the top softmax score rounded
        to two decimals, as a plain Python ``float``.
    """
    if image is None:
        return "No image provided.", {}

    digits = segment_digits(image)
    if not digits:
        # Same output the per-digit loop gave when nothing was found.
        return "Predicted Number: ", {}

    # One batched forward pass instead of one predict() call per digit.
    batch = np.concatenate(digits, axis=0)
    predictions = model.predict(batch, verbose=0)

    classes = [int(np.argmax(pred)) for pred in predictions]
    # Convert NumPy scalars to built-in floats so the Label output receives
    # plain JSON-serialisable numbers.
    confidences = {
        f"Digit {position} ({digit_class})": round(float(np.max(pred)), 2)
        for position, (digit_class, pred) in enumerate(
            zip(classes, predictions), start=1
        )
    }
    result = "".join(str(digit_class) for digit_class in classes)
    return f"Predicted Number: {result}", confidences
|
54 |
+
|
55 |
+
# Gradio UI: one image input, with the predicted number as text and the
# per-digit confidences as a label widget.
_digit_image_input = gr.Image(
    type="pil",
    label="Upload image with digits (e.g. 178)",
)
_prediction_outputs = [
    gr.Text(label="Predicted Number"),
    gr.Label(label="Confidence per Digit"),
]

demo = gr.Interface(
    fn=classify_multi_digit,
    inputs=_digit_image_input,
    outputs=_prediction_outputs,
    title="🧠 Multi-digit MNIST Classifier",
    description="Upload an image containing multiple handwritten digits (e.g. '178'). The app segments and classifies each digit using a simple MNIST-trained neural network.",
)

# Launch the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|