ameenmarashi committed
Commit b842669 · verified · 1 Parent(s): a08e47f

Update app.py

Files changed (1):
  app.py  +204 -86
app.py CHANGED
@@ -1,100 +1,218 @@
 import gradio as gr
-import zipfile, os, uuid, shutil, subprocess, sys, time
+import zipfile
+import os
+import uuid
+import shutil
+import subprocess
+import sys
+import time
 import tensorflow as tf
 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 import numpy as np
 
-# Directory constants
+# Directory setup
 UPLOAD_DIR = "uploads"
 MODEL_DIR = "models"
 os.makedirs(UPLOAD_DIR, exist_ok=True)
 os.makedirs(MODEL_DIR, exist_ok=True)
 
-def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
-    # Save upload
-    uid = str(uuid.uuid4())
-    zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
-    with open(zip_path, "wb") as f:
-        f.write(dataset_zip.read())
-    extract_path = os.path.join(UPLOAD_DIR, uid)
-    with zipfile.ZipFile(zip_path, "r") as z:
-        z.extractall(extract_path)
-    train_dir = os.path.join(extract_path, "train")
-    val_dir = os.path.join(extract_path, "validation")
-    # Data gens
-    train_gen = ImageDataGenerator(rescale=1/255,
-                                   rotation_range=20,
-                                   width_shift_range=0.2,
-                                   height_shift_range=0.2,
-                                   horizontal_flip=True,
-                                   zoom_range=0.2) \
-        .flow_from_directory(train_dir, target_size=(image_size, image_size),
-                             batch_size=batch_size, class_mode="categorical")
-    val_gen = ImageDataGenerator(rescale=1/255) \
-        .flow_from_directory(val_dir, target_size=(image_size, image_size),
-                             batch_size=batch_size, class_mode="categorical")
-    # Build model (e.g., simplified ResNet-inspired CNN)
-    model = tf.keras.Sequential([
-        tf.keras.layers.Conv2D(32,3,activation="relu",
-                               input_shape=(image_size,image_size,3)),
-        tf.keras.layers.BatchNormalization(),
-        tf.keras.layers.MaxPooling2D(),
-        tf.keras.layers.Conv2D(64,3,activation="relu"),
-        tf.keras.layers.BatchNormalization(),
-        tf.keras.layers.MaxPooling2D(),
-        tf.keras.layers.Flatten(),
-        tf.keras.layers.Dense(128,activation="relu"),
-        tf.keras.layers.Dropout(0.5),
-        tf.keras.layers.Dense(num_classes,activation="softmax")
-    ])
-    model.compile(optimizer="adam", loss="categorical_crossentropy",
-                  metrics=["accuracy"])
-    # Train
-    start = time.time()
-    history = model.fit(train_gen,
-                        validation_data=val_gen,
-                        epochs=epochs,
-                        verbose=0)
-    elapsed = time.time() - start
-    # Save outputs
-    space_model_dir = os.path.join(MODEL_DIR, uid)
-    os.makedirs(space_model_dir, exist_ok=True)
-    h5_path = os.path.join(space_model_dir, f"{model_name}.h5")
-    model.save(h5_path)
-    savedmodel_path = os.path.join(space_model_dir, "savedmodel")
-    model.save(savedmodel_path)
-    tfjs_dir = os.path.join(space_model_dir, "tfjs")
-    subprocess.run([
-        sys.executable, "-m", "tensorflowjs_converter",
-        "--input_format=tf_saved_model",
-        savedmodel_path,
-        tfjs_dir
-    ], check=True)
-    size_mb = sum(os.path.getsize(os.path.join(dp,fn))
-                  for dp,_,fns in os.walk(space_model_dir)
-                  for fn in fns) / (1024*1024)
-    return (f"✅ Trained in {elapsed:.1f}s\n"
-            f"Final val acc: {max(history.history['val_accuracy']):.4f}\n"
-            f"Model size: {size_mb:.2f} MB\n"
-            f"Download H5: {h5_path}\n"
-            f"Download SavedModel: {savedmodel_path}\n"
-            f"Download TF.js: {tfjs_dir}")
+def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size, image_size):
+    try:
+        # Generate unique ID for this training session
+        uid = str(uuid.uuid4())
+        zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
+
+        # Copy uploaded file to our storage
+        shutil.copyfile(dataset_file, zip_path)
+
+        # Extract dataset
+        extract_path = os.path.join(UPLOAD_DIR, uid)
+        os.makedirs(extract_path, exist_ok=True)
+        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_path)
+
+        # Locate train and validation directories
+        train_dir = os.path.join(extract_path, "train")
+        val_dir = os.path.join(extract_path, "validation")
+
+        # Verify dataset structure
+        if not os.path.exists(train_dir) or not os.path.exists(val_dir):
+            return "Error: Dataset must contain 'train' and 'validation' folders"
+
+        # Create data generators
+        train_datagen = ImageDataGenerator(
+            rescale=1./255,
+            rotation_range=20,
+            width_shift_range=0.2,
+            height_shift_range=0.2,
+            horizontal_flip=True,
+            zoom_range=0.2
+        )
+
+        val_datagen = ImageDataGenerator(rescale=1./255)
+
+        train_generator = train_datagen.flow_from_directory(
+            train_dir,
+            target_size=(image_size, image_size),
+            batch_size=batch_size,
+            class_mode='categorical'
+        )
+
+        val_generator = val_datagen.flow_from_directory(
+            val_dir,
+            target_size=(image_size, image_size),
+            batch_size=batch_size,
+            class_mode='categorical'
+        )
+
+        # Update num_classes based on actual data
+        actual_classes = train_generator.num_classes
+        if actual_classes != num_classes:
+            num_classes = actual_classes
+
+        # Build model
+        model = tf.keras.Sequential([
+            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(image_size, image_size, 3)),
+            tf.keras.layers.BatchNormalization(),
+            tf.keras.layers.MaxPooling2D(),
+            tf.keras.layers.Dropout(0.25),
+
+            tf.keras.layers.Conv2D(64, 3, activation='relu'),
+            tf.keras.layers.BatchNormalization(),
+            tf.keras.layers.MaxPooling2D(),
+            tf.keras.layers.Dropout(0.25),
+
+            tf.keras.layers.Conv2D(128, 3, activation='relu'),
+            tf.keras.layers.BatchNormalization(),
+            tf.keras.layers.MaxPooling2D(),
+            tf.keras.layers.Dropout(0.25),
+
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(256, activation='relu'),
+            tf.keras.layers.BatchNormalization(),
+            tf.keras.layers.Dropout(0.5),
+            tf.keras.layers.Dense(num_classes, activation='softmax')
+        ])
+
+        model.compile(
+            optimizer='adam',
+            loss='categorical_crossentropy',
+            metrics=['accuracy']
+        )
+
+        # Train model
+        start_time = time.time()
+        history = model.fit(
+            train_generator,
+            steps_per_epoch=train_generator.samples // train_generator.batch_size,
+            epochs=epochs,
+            validation_data=val_generator,
+            validation_steps=val_generator.samples // val_generator.batch_size,
+            verbose=0
+        )
+        training_time = time.time() - start_time
+
+        # Save models
+        model_dir = os.path.join(MODEL_DIR, uid)
+        os.makedirs(model_dir, exist_ok=True)
+
+        # Save H5 model
+        h5_path = os.path.join(model_dir, f"{model_name}.h5")
+        model.save(h5_path)
+
+        # Save SavedModel
+        savedmodel_path = os.path.join(model_dir, "savedmodel")
+        model.save(savedmodel_path)
+
+        # Convert to TensorFlow.js
+        tfjs_path = os.path.join(model_dir, "tfjs")
+        try:
+            subprocess.run([
+                "tensorflowjs_converter",
+                "--input_format=tf_saved_model",
+                savedmodel_path,
+                tfjs_path
+            ], check=True)
+        except Exception:
+            # Install tensorflowjs if not available
+            subprocess.run([sys.executable, "-m", "pip", "install", "tensorflowjs"], check=True)
+            subprocess.run([
+                "tensorflowjs_converter",
+                "--input_format=tf_saved_model",
+                savedmodel_path,
+                tfjs_path
+            ], check=True)
+
+        # Calculate model size
+        model_size = 0
+        for dirpath, _, filenames in os.walk(model_dir):
+            for f in filenames:
+                fp = os.path.join(dirpath, f)
+                model_size += os.path.getsize(fp)
+        model_size_mb = model_size / (1024 * 1024)
+
+        # Get class names
+        class_names = list(train_generator.class_indices.keys())
+
+        # Prepare download links
+        download_info = f"""
+        ✅ Training completed successfully!
+        ⏱️ Training time: {training_time:.2f} seconds
+        📊 Validation accuracy: {max(history.history['val_accuracy']):.4f}
+        📦 Model size: {model_size_mb:.2f} MB
+        🗂️ Number of classes: {num_classes}
+        """
+
+        # Return paths for download
+        return download_info, h5_path, savedmodel_path, tfjs_path
+
+    except Exception as e:
+        return f"❌ Training failed: {str(e)}", None, None, None
 
 # Gradio interface
-demo = gr.Interface(
-    fn=train_and_export,
-    inputs=[
-        gr.File(label="Dataset ZIP"),
-        gr.Textbox(label="Model Name", value="my_model"),
-        gr.Slider(2, 100, value=5, label="Number of Classes"),
-        gr.Slider(1, 200, value=50, step=1, label="Epochs"),
-        gr.Radio([16,32,64,128], value=32, label="Batch Size"),
-        gr.Radio([128,224,256], value=224, label="Image Size")
-    ],
-    outputs="text",
-    title="AI Image Classifier Trainer",
-    description="Upload train/validation dataset, train a CNN, and download H5, SavedModel, or TF.js."
-)
+with gr.Blocks(title="AI Image Classifier Trainer") as demo:
+    gr.Markdown("# 🖼️ AI Image Classifier Trainer")
+    gr.Markdown("Upload your dataset (ZIP with train/validation folders), configure training, and download models in multiple formats.")
+
+    with gr.Row():
+        with gr.Column():
+            dataset = gr.File(label="Dataset ZIP File", file_types=[".zip"])
+            model_name = gr.Textbox(label="Model Name", value="my_classifier")
+            num_classes = gr.Slider(2, 100, value=5, step=1, label="Number of Classes")
+            epochs = gr.Slider(5, 200, value=30, step=1, label="Training Epochs")
+            batch_size = gr.Radio([16, 32, 64], value=32, label="Batch Size")
+            image_size = gr.Radio([128, 224, 256], value=224, label="Image Size (px)")
+            train_btn = gr.Button("🚀 Train Model", variant="primary")
+
+        with gr.Column():
+            output = gr.Textbox(label="Training Results", interactive=False)
+            h5_download = gr.File(label="H5 Model Download", visible=False)
+            savedmodel_download = gr.File(label="SavedModel Download", visible=False)
+            tfjs_download = gr.File(label="TensorFlow.js Download", visible=False)
+
+    def toggle_downloads(results, h5_path, saved_path, tfjs_path):
+        downloads_visible = h5_path is not None
+        return (
+            gr.File(visible=downloads_visible, value=h5_path),
+            gr.File(visible=downloads_visible, value=saved_path),
+            gr.File(visible=downloads_visible, value=tfjs_path)
+        )
+
+    train_btn.click(
+        fn=train_and_export,
+        inputs=[dataset, model_name, num_classes, epochs, batch_size, image_size],
+        outputs=[output, h5_download, savedmodel_download, tfjs_download]
+    ).then(
+        fn=toggle_downloads,
+        inputs=[output, h5_download, savedmodel_download, tfjs_download],
+        outputs=[h5_download, savedmodel_download, tfjs_download]
+    )
 
+# Launch settings for Hugging Face Spaces
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        max_file_size=100  # 100MB file size limit
+    )
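
The updated train_and_export expects the uploaded ZIP to unpack into train/ and validation/ folders, each containing one subfolder per class, since flow_from_directory with class_mode='categorical' infers labels from those subfolder names. As a usage sketch (not part of this commit), the exported H5 file could be loaded for inference roughly as below; the model path, the sample image name, and the 224 px size are placeholder assumptions based on the app's defaults.

# Hypothetical inference sketch; replace the model path (the h5_path returned
# by the Space) and the image path with real files before running.
import numpy as np
import tensorflow as tf

model = tf.keras.models.load_model("models/<uid>/my_classifier.h5")

# Preprocess one image the same way the training generators do:
# resize to the chosen image_size and rescale pixel values to [0, 1].
img = tf.keras.preprocessing.image.load_img("sample.jpg", target_size=(224, 224))
x = tf.keras.preprocessing.image.img_to_array(img) / 255.0
probs = model.predict(np.expand_dims(x, 0))[0]
print("Predicted class index:", int(np.argmax(probs)))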