ZennyKenny commited on
Commit
f376d1d
Β·
verified Β·
1 Parent(s): 74dabbb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +380 -0
app.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import io
5
+ import base64
6
+ from typing import Optional, Tuple
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
+ import warnings
11
+ warnings.filterwarnings("ignore")
12
+
13
+ # Import Mostly AI SDK
14
+ try:
15
+ from mostlyai.sdk import MostlyAI
16
+ MOSTLY_AI_AVAILABLE = True
17
+ except ImportError:
18
+ MOSTLY_AI_AVAILABLE = False
19
+ print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
20
+
21
+ class SyntheticDataGenerator:
22
+ def __init__(self):
23
+ self.mostly = None
24
+ self.generator = None
25
+ self.original_data = None
26
+
27
+ def initialize_mostly_ai(self):
28
+ """Initialize Mostly AI SDK"""
29
+ if not MOSTLY_AI_AVAILABLE:
30
+ return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
31
+
32
+ try:
33
+ self.mostly = MostlyAI(local=True, local_port=8080)
34
+ return True, "Mostly AI SDK initialized successfully!"
35
+ except Exception as e:
36
+ return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
37
+
38
+
39
+ def train_generator(self, data: pd.DataFrame, name: str, epochs: int = 10, max_training_time: int = 60, batch_size: int = 32, value_protection: bool = True) -> Tuple[bool, str]:
40
+ """Train the synthetic data generator"""
41
+ if not self.mostly:
42
+ return False, "Mostly AI SDK not initialized"
43
+
44
+ try:
45
+ self.original_data = data
46
+ train_config = {'tables':
47
+ [
48
+ {
49
+ 'name': name,
50
+ 'data': data,
51
+ 'tabular_model_configuration':
52
+ {
53
+ 'max_epochs': epochs,
54
+ 'max_training_time': max_training_time,
55
+ 'value_protection': value_protection,
56
+ 'batch_size': batch_size
57
+ }
58
+ }
59
+ ]
60
+ }
61
+
62
+ self.generator = self.mostly.train(
63
+ config = train_config
64
+ )
65
+ return True, f"Generator trained successfully! Model: {name}"
66
+ except Exception as e:
67
+ return False, f"Training failed: {str(e)}"
68
+
69
+ def generate_synthetic_data(self, size: int) -> Tuple[pd.DataFrame, str]:
70
+ """Generate synthetic data"""
71
+ if not self.generator:
72
+ return None, "No trained generator available"
73
+
74
+ try:
75
+ synthetic_data = self.mostly.generate(self.generator, size=size)
76
+ df = synthetic_data.data()
77
+ return df, f"Generated {len(df)} synthetic records successfully!"
78
+ except Exception as e:
79
+ return None, f"Generation failed: {str(e)}"
80
+
81
+ def get_quality_report(self) -> str:
82
+ """Get quality assurance report"""
83
+ if not self.generator:
84
+ return "No trained generator available"
85
+
86
+ try:
87
+ report = self.generator.reports(display=False)
88
+ return str(report)
89
+ except Exception as e:
90
+ return f"Failed to generate report: {str(e)}"
91
+
92
+ def estimate_memory_usage(self, df: pd.DataFrame) -> str:
93
+ """Estimate memory usage for the dataset"""
94
+ if df is None or df.empty:
95
+ return "No data to analyze"
96
+
97
+ # Calculate approximate memory usage
98
+ memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
99
+ rows, cols = len(df), len(df.columns)
100
+
101
+ # Estimate training memory (roughly 3-5x the data size)
102
+ estimated_training_mb = memory_mb * 4
103
+
104
+ status = "βœ… Good" if memory_mb < 100 else "⚠️ Large" if memory_mb < 500 else "❌ Very Large"
105
+
106
+ return f"""
107
+ **Memory Usage Estimate:**
108
+ - Data size: {memory_mb:.1f} MB
109
+ - Estimated training memory: {estimated_training_mb:.1f} MB
110
+ - Status: {status}
111
+ - Rows: {rows:,} | Columns: {cols}
112
+ """.strip()
113
+
114
+ # Initialize the generator
115
+ generator = SyntheticDataGenerator()
116
+
117
+
118
+ def initialize_sdk() -> Tuple[str, str]:
119
+ """Initialize the Mostly AI SDK"""
120
+ success, message = generator.initialize_mostly_ai()
121
+ status = "βœ… Success" if success else "❌ Error"
122
+ return status, message
123
+
124
+ def train_model(data: pd.DataFrame, model_name: str, epochs: int, max_training_time: int, batch_size: int, value_protection: bool) -> Tuple[str, str]:
125
+ """Train the synthetic data generator"""
126
+ if data is None or data.empty:
127
+ return "❌ Error", "Please upload or create sample data first"
128
+
129
+ success, message = generator.train_generator(data, model_name, epochs, max_training_time, batch_size, value_protection)
130
+ status = "βœ… Success" if success else "❌ Error"
131
+ return status, message
132
+
133
+ def generate_data(size: int) -> Tuple[pd.DataFrame, str]:
134
+ """Generate synthetic data"""
135
+ if generator.generator is None:
136
+ return None, "❌ Please train a model first"
137
+
138
+ synthetic_df, message = generator.generate_synthetic_data(size)
139
+ if synthetic_df is not None:
140
+ status = "βœ… Success"
141
+ else:
142
+ status = "❌ Error"
143
+
144
+ return synthetic_df, f"{status} - {message}"
145
+
146
+ def get_quality_report() -> str:
147
+ """Get quality report"""
148
+ return generator.get_quality_report()
149
+
150
+ def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> go.Figure:
151
+ """Create comparison plots between original and synthetic data"""
152
+ if original_df is None or synthetic_df is None:
153
+ return None
154
+
155
+ # Select numeric columns for comparison
156
+ numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
157
+
158
+ if not numeric_cols:
159
+ return None
160
+
161
+ # Create subplots
162
+ n_cols = min(3, len(numeric_cols))
163
+ n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
164
+
165
+ fig = make_subplots(
166
+ rows=n_rows,
167
+ cols=n_cols,
168
+ subplot_titles=numeric_cols[:n_rows*n_cols]
169
+ )
170
+
171
+ for i, col in enumerate(numeric_cols[:n_rows*n_cols]):
172
+ row = i // n_cols + 1
173
+ col_idx = i % n_cols + 1
174
+
175
+ # Add original data histogram
176
+ fig.add_trace(
177
+ go.Histogram(
178
+ x=original_df[col],
179
+ name=f'Original {col}',
180
+ opacity=0.7,
181
+ nbinsx=20
182
+ ),
183
+ row=row, col=col_idx
184
+ )
185
+
186
+ # Add synthetic data histogram
187
+ fig.add_trace(
188
+ go.Histogram(
189
+ x=synthetic_df[col],
190
+ name=f'Synthetic {col}',
191
+ opacity=0.7,
192
+ nbinsx=20
193
+ ),
194
+ row=row, col=col_idx
195
+ )
196
+
197
+ fig.update_layout(
198
+ title="Original vs Synthetic Data Comparison",
199
+ height=300 * n_rows,
200
+ showlegend=True
201
+ )
202
+
203
+ return fig
204
+
205
+ def download_csv(df: pd.DataFrame) -> str:
206
+ """Convert DataFrame to CSV for download"""
207
+ if df is None or df.empty:
208
+ return None
209
+
210
+ csv = df.to_csv(index=False)
211
+ return csv
212
+
213
+ # Create the Gradio interface
214
+ def create_interface():
215
+ with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
216
+ gr.Markdown("""
217
+ # 🎭 MOSTLY AI Synthetic Data Generator
218
+
219
+ Generate high-quality synthetic data using the Mostly AI SDK. Upload your own CSV files to generate synthetic data that preserves the statistical properties of your original dataset.
220
+ """)
221
+
222
+ with gr.Tab("πŸš€ Quick Start"):
223
+ gr.Markdown("### Initialize the SDK and upload your data")
224
+
225
+ with gr.Row():
226
+ with gr.Column():
227
+ init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
228
+ init_status = gr.Textbox(label="Initialization Status", interactive=False)
229
+
230
+ with gr.Column():
231
+ gr.Markdown("""
232
+ **Next Steps:**
233
+ 1. Initialize the SDK (click button above)
234
+ 2. Go to "Upload Data and Train Model" tab to upload your CSV file
235
+ 3. Train a model on your data
236
+ 4. Generate synthetic data
237
+ """)
238
+
239
+ with gr.Tab("πŸ“Š Upload Data and Train Model"):
240
+ gr.Markdown("### Upload your CSV file to generate synthetic data")
241
+
242
+ gr.Markdown("""
243
+ **πŸ“‹ File Requirements:**
244
+ - **Format:** CSV with header row
245
+ - **Size:** Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
246
+ """)
247
+
248
+ file_upload = gr.File(
249
+ label="Upload CSV File",
250
+ file_types=[".csv"],
251
+ file_count="single"
252
+ )
253
+
254
+ uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
255
+
256
+ memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
257
+
258
+ with gr.Row():
259
+ with gr.Column():
260
+ model_name = gr.Textbox(
261
+ value="My Synthetic Model",
262
+ label="Model Name",
263
+ placeholder="Enter a name for your model"
264
+ )
265
+ epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
266
+ max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
267
+ batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
268
+ value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
269
+ train_btn = gr.Button("Train Model", variant="primary")
270
+
271
+ with gr.Column():
272
+ train_status = gr.Textbox(label="Training Status", interactive=False)
273
+ quality_report = gr.Textbox(label="Quality Report", lines=10, interactive=False)
274
+
275
+ get_report_btn = gr.Button("Get Quality Report", variant="secondary")
276
+
277
+ with gr.Tab("🎲 Generate Data"):
278
+ gr.Markdown("### Generate synthetic data from your trained model")
279
+
280
+ with gr.Row():
281
+ with gr.Column():
282
+ gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
283
+ generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
284
+
285
+ with gr.Column():
286
+ gen_status = gr.Textbox(label="Generation Status", interactive=False)
287
+
288
+ synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
289
+
290
+ with gr.Row():
291
+ download_btn = gr.DownloadButton("Download CSV", variant="secondary")
292
+ comparison_plot = gr.Plot(label="Data Comparison")
293
+
294
+ # Event handlers
295
+ init_btn.click(
296
+ initialize_sdk,
297
+ outputs=[init_status, init_status]
298
+ )
299
+
300
+ train_btn.click(
301
+ train_model,
302
+ inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
303
+ outputs=[train_status, train_status]
304
+ )
305
+
306
+ get_report_btn.click(
307
+ get_quality_report,
308
+ outputs=[quality_report]
309
+ )
310
+
311
+ generate_btn.click(
312
+ generate_data,
313
+ inputs=[gen_size],
314
+ outputs=[synthetic_data, gen_status]
315
+ )
316
+
317
+ # Update download button when synthetic data changes
318
+ synthetic_data.change(
319
+ download_csv,
320
+ inputs=[synthetic_data],
321
+ outputs=[download_btn]
322
+ )
323
+
324
+ # Create comparison plot when both datasets are available
325
+ synthetic_data.change(
326
+ create_comparison_plot,
327
+ inputs=[uploaded_data, synthetic_data],
328
+ outputs=[comparison_plot]
329
+ )
330
+
331
+ # Handle file upload with size and column limits
332
+ def process_uploaded_file(file):
333
+ if file is None:
334
+ return None, "No file uploaded", gr.update(visible=False)
335
+
336
+ try:
337
+ # Read the CSV file
338
+ df = pd.read_csv(file.name)
339
+
340
+ # # Check column limit (max 20 columns)
341
+ # if len(df.columns) > 20:
342
+ # return None, f"❌ Too many columns! Maximum allowed: 20, found: {len(df.columns)}. Please reduce the number of columns in your CSV file.", gr.update(visible=False)
343
+
344
+ # # Check row limit (max 10,000 records)
345
+ # if len(df) > 10000:
346
+ # return None, f"❌ Too many records! Maximum allowed: 10,000, found: {len(df)}. Please reduce the number of rows in your CSV file.", gr.update(visible=False)
347
+
348
+ # # Check minimum requirements
349
+ # if len(df) < 1000:
350
+ # return None, f"❌ Too few records! Minimum required: 1,000, found: {len(df)}. Please provide more data for training.", gr.update(visible=False)
351
+
352
+ # if len(df.columns) < 2:
353
+ # return None, f"❌ Too few columns! Minimum required: 2, found: {len(df.columns)}. Please provide more columns for training.", gr.update(visible=False)
354
+
355
+ # Success message with file info
356
+ success_msg = f"βœ… File uploaded successfully! {len(df)} rows Γ— {len(df.columns)} columns"
357
+
358
+ # Generate memory usage info
359
+ memory_info = generator.estimate_memory_usage(df)
360
+
361
+ return df, success_msg, gr.update(value=memory_info, visible=True)
362
+
363
+ except Exception as e:
364
+ return None, f"❌ Error reading file: {str(e)}", gr.update(visible=False)
365
+
366
+ file_upload.change(
367
+ process_uploaded_file,
368
+ inputs=[file_upload],
369
+ outputs=[uploaded_data, train_status, memory_info]
370
+ )
371
+
372
+ return demo
373
+
374
+ if __name__ == "__main__":
375
+ demo = create_interface()
376
+ demo.launch(
377
+ server_name="0.0.0.0",
378
+ server_port=7860,
379
+ share=True
380
+ )