ameenmarashi commited on
Commit
a08e47f
·
verified ·
1 Parent(s): 29183c0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import zipfile, os, uuid, shutil, subprocess, sys, time
3
+ import tensorflow as tf
4
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
+ import numpy as np
6
+
7
+ # Directory constants
8
+ UPLOAD_DIR = "uploads"
9
+ MODEL_DIR = "models"
10
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
11
+ os.makedirs(MODEL_DIR, exist_ok=True)
12
+
13
+ def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
14
+ # Save upload
15
+ uid = str(uuid.uuid4())
16
+ zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
17
+ with open(zip_path, "wb") as f:
18
+ f.write(dataset_zip.read())
19
+ extract_path = os.path.join(UPLOAD_DIR, uid)
20
+ with zipfile.ZipFile(zip_path, "r") as z:
21
+ z.extractall(extract_path)
22
+ train_dir = os.path.join(extract_path, "train")
23
+ val_dir = os.path.join(extract_path, "validation")
24
+ # Data gens
25
+ train_gen = ImageDataGenerator(rescale=1/255,
26
+ rotation_range=20,
27
+ width_shift_range=0.2,
28
+ height_shift_range=0.2,
29
+ horizontal_flip=True,
30
+ zoom_range=0.2) \
31
+ .flow_from_directory(train_dir, target_size=(image_size, image_size),
32
+ batch_size=batch_size, class_mode="categorical")
33
+ val_gen = ImageDataGenerator(rescale=1/255) \
34
+ .flow_from_directory(val_dir, target_size=(image_size, image_size),
35
+ batch_size=batch_size, class_mode="categorical")
36
+ # Build model (e.g., simplified ResNet-inspired CNN)
37
+ model = tf.keras.Sequential([
38
+ tf.keras.layers.Conv2D(32,3,activation="relu",
39
+ input_shape=(image_size,image_size,3)),
40
+ tf.keras.layers.BatchNormalization(),
41
+ tf.keras.layers.MaxPooling2D(),
42
+ tf.keras.layers.Conv2D(64,3,activation="relu"),
43
+ tf.keras.layers.BatchNormalization(),
44
+ tf.keras.layers.MaxPooling2D(),
45
+ tf.keras.layers.Flatten(),
46
+ tf.keras.layers.Dense(128,activation="relu"),
47
+ tf.keras.layers.Dropout(0.5),
48
+ tf.keras.layers.Dense(num_classes,activation="softmax")
49
+ ])
50
+ model.compile(optimizer="adam", loss="categorical_crossentropy",
51
+ metrics=["accuracy"])
52
+ # Train
53
+ start = time.time()
54
+ history = model.fit(train_gen,
55
+ validation_data=val_gen,
56
+ epochs=epochs,
57
+ verbose=0)
58
+ elapsed = time.time() - start
59
+ # Save outputs
60
+ space_model_dir = os.path.join(MODEL_DIR, uid)
61
+ os.makedirs(space_model_dir, exist_ok=True)
62
+ h5_path = os.path.join(space_model_dir, f"{model_name}.h5")
63
+ model.save(h5_path)
64
+ savedmodel_path = os.path.join(space_model_dir, "savedmodel")
65
+ model.save(savedmodel_path)
66
+ tfjs_dir = os.path.join(space_model_dir, "tfjs")
67
+ subprocess.run([
68
+ sys.executable, "-m", "tensorflowjs_converter",
69
+ "--input_format=tf_saved_model",
70
+ savedmodel_path,
71
+ tfjs_dir
72
+ ], check=True)
73
+ size_mb = sum(os.path.getsize(os.path.join(dp,fn))
74
+ for dp,_,fns in os.walk(space_model_dir)
75
+ for fn in fns) / (1024*1024)
76
+ return (f"✅ Trained in {elapsed:.1f}s\n"
77
+ f"Final val acc: {max(history.history['val_accuracy']):.4f}\n"
78
+ f"Model size: {size_mb:.2f} MB\n"
79
+ f"Download H5: {h5_path}\n"
80
+ f"Download SavedModel: {savedmodel_path}\n"
81
+ f"Download TF.js: {tfjs_dir}")
82
+
83
+ # Gradio interface
84
+ demo = gr.Interface(
85
+ fn=train_and_export,
86
+ inputs=[
87
+ gr.File(label="Dataset ZIP"),
88
+ gr.Textbox(label="Model Name", value="my_model"),
89
+ gr.Slider(2, 100, value=5, label="Number of Classes"),
90
+ gr.Slider(1, 200, value=50, step=1, label="Epochs"),
91
+ gr.Radio([16,32,64,128], value=32, label="Batch Size"),
92
+ gr.Radio([128,224,256], value=224, label="Image Size")
93
+ ],
94
+ outputs="text",
95
+ title="AI Image Classifier Trainer",
96
+ description="Upload train/validation dataset, train a CNN, and download H5, SavedModel, or TF.js."
97
+ )
98
+
99
+ if __name__ == "__main__":
100
+ demo.launch()