mahesh1209 commited on
Commit
2358be3
·
verified ·
1 Parent(s): 21f583e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image, ImageOps
6
+
7
+ def load_model():
8
+ try:
9
+ return tf.keras.models.load_model("mnist_model.h5")
10
+ except:
11
+ model = tf.keras.models.Sequential([
12
+ tf.keras.layers.Flatten(input_shape=(28, 28)),
13
+ tf.keras.layers.Dense(128, activation='relu'),
14
+ tf.keras.layers.Dense(10, activation='softmax')
15
+ ])
16
+ model.compile(optimizer='adam',
17
+ loss='sparse_categorical_crossentropy',
18
+ metrics=['accuracy'])
19
+ (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
20
+ x_train = x_train / 255.0
21
+ model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=0)
22
+ model.save("mnist_model.h5")
23
+ return model
24
+
25
+ model = load_model()
26
+
27
+ def segment_digits(image: Image.Image):
28
+ img = np.array(image.convert("L"))
29
+ img = ImageOps.invert(Image.fromarray(img))
30
+ img = np.array(img)
31
+ img = cv2.resize(img, (img.shape[1]*2, img.shape[0]*2))
32
+ _, thresh = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY)
33
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
34
+
35
+ digit_images = []
36
+ for cnt in sorted(contours, key=lambda x: cv2.boundingRect(x)[0]):
37
+ x, y, w, h = cv2.boundingRect(cnt)
38
+ digit = thresh[y:y+h, x:x+w]
39
+ digit = cv2.resize(digit, (28, 28))
40
+ digit = digit / 255.0
41
+ digit_images.append(digit.reshape(1, 28, 28))
42
+ return digit_images
43
+
44
+ def classify_multi_digit(image):
45
+ digits = segment_digits(image)
46
+ result = ""
47
+ confidences = {}
48
+ for i, digit in enumerate(digits):
49
+ pred = model.predict(digit)[0]
50
+ digit_class = np.argmax(pred)
51
+ result += str(digit_class)
52
+ confidences[f"Digit {i+1} ({digit_class})"] = round(np.max(pred), 2)
53
+ return f"Predicted Number: {result}", confidences
54
+
55
+ demo = gr.Interface(
56
+ fn=classify_multi_digit,
57
+ inputs=gr.Image(type="pil", label="Upload image with digits (e.g. 178)"),
58
+ outputs=[
59
+ gr.Text(label="Predicted Number"),
60
+ gr.Label(label="Confidence per Digit")
61
+ ],
62
+ title="🧠 Multi-digit MNIST Classifier",
63
+ description="Upload an image containing multiple handwritten digits (e.g. '178'). The app segments and classifies each digit using a simple MNIST-trained neural network."
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ demo.launch()