hiwei commited on
Commit
0763043
·
verified ·
1 Parent(s): e10b15b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ TITLE = "Handwritten Digit Recognition Demo"
7
+
8
+ DESCRIPTION = "This demo employs a basic CNN architecture inspired by [MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "\
9
+ "It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
10
+
11
+ model = tf.keras.saving.load_model("tf_model_mnist")
12
+
13
+
14
+ def preprocess(image):
15
+ """ Normalize Gradio image to MNIST format """
16
+ image = image.resize((28, 28), Image.Resampling.HAMMING)
17
+ img_array = np.asarray(image, dtype=np.float32)
18
+ for i in range(img_array.shape[0]):
19
+ for j in range(img_array.shape[1]):
20
+ alpha = img_array[i, j, 3]
21
+ if alpha == 0.:
22
+ img_array[i, j] = [0., 0., 0., 255.]
23
+ else:
24
+ img_array[i, j] = [255., 255., 255., 255.]
25
+
26
+ new_image = Image.fromarray(img_array.astype(np.uint8), "RGBA")
27
+ new_image = new_image.convert("L")
28
+ image_array = tf.keras.utils.img_to_array(new_image)
29
+ image_array = (np.expand_dims(image_array, axis=0)/255.).astype(np.float32)
30
+ return image_array, new_image
31
+
32
+
33
+ def predict(img):
34
+ img = img["composite"]
35
+ input_arr, new_image = preprocess(img)
36
+ print("input:", input_arr.shape)
37
+ predictions = model.predict(input_arr)
38
+ return {str(i): predictions[0][i] for i in range(10)}, new_image
39
+
40
+
41
+ input_image = gr.Sketchpad(
42
+ layers=False,
43
+ type="pil",
44
+ )
45
+ demo = gr.Interface(
46
+ title=TITLE,
47
+ description=DESCRIPTION,
48
+ predict,
49
+ inputs=input_image,
50
+ outputs=['label', 'image']
51
+ )
52
+
53
+
54
+ demo.launch()