# Hugging Face Space app (Space status at capture time: Running)
import base64
import io
import tempfile
import warnings
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Silence every Python warning process-wide to keep the app's console output quiet.
warnings.filterwarnings("ignore")

# The Mostly AI SDK is optional at import time: if it is missing the UI still loads,
# and SyntheticDataGenerator.initialize_mostly_ai reports the problem to the user.
try:
    from mostlyai.sdk import MostlyAI
except ImportError:
    MOSTLY_AI_AVAILABLE = False
    print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
else:
    MOSTLY_AI_AVAILABLE = True
class SyntheticDataGenerator:
    """Stateful wrapper around the Mostly AI SDK shared by the Gradio callbacks.

    Attributes:
        mostly: the ``MostlyAI`` client, or ``None`` until ``initialize_mostly_ai`` succeeds.
        generator: the object returned by ``MostlyAI.train``, or ``None`` before training.
        original_data: the DataFrame of the last successful training run, or ``None``.
    """

    def __init__(self):
        self.mostly = None
        self.generator = None
        self.original_data = None

    def initialize_mostly_ai(self, local_port: int = 8080) -> Tuple[bool, str]:
        """Create a local-mode Mostly AI client.

        Args:
            local_port: port passed to ``MostlyAI(local=True, local_port=...)``.

        Returns:
            ``(success, message)``; errors are reported in the message, never raised.
        """
        if not MOSTLY_AI_AVAILABLE:
            return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
        if self.mostly is not None:
            # Reuse the existing client instead of starting a second local one
            # (which would presumably compete for the same port).
            return True, "Mostly AI SDK already initialized."
        try:
            self.mostly = MostlyAI(local=True, local_port=local_port)
        except Exception as e:
            return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
        return True, "Mostly AI SDK initialized successfully!"

    def train_generator(self, data: pd.DataFrame, name: str, epochs: int = 10, max_training_time: int = 60, batch_size: int = 32, value_protection: bool = True) -> Tuple[bool, str]:
        """Train a single-table generator on ``data``.

        Args:
            data: the training table.
            name: table/model name passed to the SDK.
            epochs: ``max_epochs`` of the tabular model configuration.
            max_training_time: ``max_training_time`` of the tabular model configuration.
            batch_size: ``batch_size`` of the tabular model configuration.
            value_protection: ``value_protection`` flag of the tabular model configuration.

        Returns:
            ``(success, message)``; SDK errors are reported in the message, never raised.
        """
        if self.mostly is None:
            return False, "Mostly AI SDK not initialized"
        try:
            # UI sliders may deliver floats; normalize before handing them to the SDK.
            model_config = {
                "max_epochs": int(epochs),
                "max_training_time": int(max_training_time),
                "value_protection": bool(value_protection),
                "batch_size": int(batch_size),
            }
            train_config = {
                "tables": [
                    {
                        "name": name,
                        "data": data,
                        "tabular_model_configuration": model_config,
                    }
                ]
            }
            self.generator = self.mostly.train(config=train_config)
        except Exception as e:
            return False, f"Training failed: {str(e)}"
        # Remember the training data only once a generator was actually produced.
        self.original_data = data
        return True, f"Generator trained successfully! Model: {name}"

    def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
        """Sample ``size`` synthetic rows from the trained generator.

        Returns:
            ``(dataframe, message)`` on success, ``(None, message)`` on failure.
        """
        if self.generator is None:
            return None, "No trained generator available"
        try:
            synthetic_data = self.mostly.generate(self.generator, size=int(size))
            df = synthetic_data.data()
        except Exception as e:
            return None, f"Generation failed: {str(e)}"
        return df, f"Generated {len(df)} synthetic records successfully!"

    def get_quality_report(self) -> str:
        """Return ``str()`` of ``generator.reports(display=False)``, or an error message.

        NOTE(review): per the SDK docs this is presumably the path of the downloaded
        report archive rather than the report text itself - confirm.
        """
        if self.generator is None:
            return "No trained generator available"
        try:
            report = self.generator.reports(display=False)
        except Exception as e:
            return f"Failed to generate report: {str(e)}"
        return str(report)

    def estimate_memory_usage(self, df: Optional[pd.DataFrame]) -> str:
        """Return a Markdown summary of ``df``'s size and a rough training-memory estimate.

        Status is "Good" below 100 MB, "Large" below 500 MB, otherwise "Very Large".
        """
        if df is None or df.empty:
            return "No data to analyze"
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = df.shape
        # Rule of thumb: training needs roughly 3-5x the raw data size; use the midpoint.
        estimated_training_mb = memory_mb * 4
        if memory_mb < 100:
            status = "✅ Good"
        elif memory_mb < 500:
            status = "⚠️ Large"
        else:
            status = "❌ Very Large"
        # Build line by line so no source indentation leaks into the Markdown
        # (indented lines can keep the "- ..." items from rendering as a list).
        lines = [
            "**Memory Usage Estimate:**",
            f"- Data size: {memory_mb:.1f} MB",
            f"- Estimated training memory: {estimated_training_mb:.1f} MB",
            f"- Status: {status}",
            f"- Rows: {rows:,} | Columns: {cols}",
        ]
        return "\n".join(lines)
# One module-level instance, shared by every Gradio callback defined below.
generator = SyntheticDataGenerator()
def initialize_sdk() -> Tuple[str, str]:
    """Initialize the shared Mostly AI client.

    Returns:
        ``(status, message)`` where ``status`` is ``"✅ Success"`` or ``"❌ Error"``.
    """
    success, message = generator.initialize_mostly_ai()
    status = "✅ Success" if success else "❌ Error"
    return status, message
def train_model(data: pd.DataFrame, model_name: str, epochs: int, max_training_time: int, batch_size: int, value_protection: bool) -> Tuple[str, str]:
    """Validate the UI inputs and train the shared generator.

    Returns:
        ``(status, message)`` where ``status`` is ``"✅ Success"`` or ``"❌ Error"``.
    """
    if data is None or data.empty:
        return "❌ Error", "Please upload or create sample data first"
    # A cleared/whitespace-only textbox would otherwise become an empty table name.
    model_name = (model_name or "").strip() or "My Synthetic Model"
    success, message = generator.train_generator(data, model_name, epochs, max_training_time, batch_size, value_protection)
    status = "✅ Success" if success else "❌ Error"
    return status, message
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Sample ``size`` synthetic rows from the shared generator.

    Returns:
        ``(dataframe_or_None, "<status> - <message>")``.
    """
    if generator.generator is None:
        return None, "❌ Please train a model first"
    # Slider values may arrive as floats; the SDK's ``size`` is a row count.
    synthetic_df, message = generator.generate_synthetic_data(int(size))
    status = "✅ Success" if synthetic_df is not None else "❌ Error"
    return synthetic_df, f"{status} - {message}"
def get_quality_report() -> str:
    """Return the shared generator's quality-report string (or an explanatory message)."""
    report_text = generator.get_quality_report()
    return report_text
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> Optional[go.Figure]:
    """Overlay per-column histograms of the original and the synthetic data.

    Only numeric columns of ``original_df`` that also exist in ``synthetic_df`` are
    plotted, up to three subplots per row. Each histogram is normalized to
    probabilities and both traces of a column share bin edges, so tables with
    different row counts remain comparable.

    Returns:
        A Plotly figure, or ``None`` when either frame is missing or there is no
        shared numeric column.
    """
    if original_df is None or synthetic_df is None:
        return None
    numeric_cols = [
        column
        for column in original_df.select_dtypes(include=[np.number]).columns
        if column in synthetic_df.columns
    ]
    if not numeric_cols:
        return None

    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[str(column) for column in numeric_cols],
    )

    for i, column in enumerate(numeric_cols):
        row, col_idx = divmod(i, n_cols)
        traces = (
            ("Original", original_df[column], "#636efa"),
            # Values may come back non-numeric (e.g. as strings); coerce so they bin on the same axis.
            ("Synthetic", pd.to_numeric(synthetic_df[column], errors="coerce"), "#EF553B"),
        )
        for label, values, color in traces:
            fig.add_trace(
                go.Histogram(
                    x=values,
                    name=label,
                    legendgroup=label,
                    showlegend=(i == 0),  # one legend entry per dataset, not per column
                    marker_color=color,
                    opacity=0.7,
                    nbinsx=20,
                    histnorm="probability",
                    bingroup=str(column),  # same bin edges for both traces of this column
                ),
                row=row + 1,
                col=col_idx + 1,
            )

    fig.update_layout(
        title="Original vs Synthetic Data Comparison",
        height=300 * n_rows,
        showlegend=True,
        barmode="overlay",  # draw the semi-transparent histograms on top of each other
    )
    return fig
def download_csv(df: pd.DataFrame, filename: str = "synthetic_data.csv") -> Optional[str]:
    """Write ``df`` to a temporary CSV file and return its path for ``gr.DownloadButton``.

    ``gr.DownloadButton`` takes a file path as its value, not file contents, so the
    CSV is written to disk instead of being returned as text.

    Args:
        df: data to export; the index is not written.
        filename: base name of the file offered to the user.

    Returns:
        Path of the written CSV, or ``None`` when ``df`` is missing or empty.
    """
    if df is None or df.empty:
        return None
    # A fresh directory per call keeps concurrent exports from overwriting each
    # other while the file itself keeps a friendly name.
    out_path = Path(tempfile.mkdtemp(prefix="synthetic_")) / filename
    df.to_csv(out_path, index=False)
    return str(out_path)
# Create the Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app.

    Tabs: "Quick Start" (SDK initialization), "Upload Data and Train Model"
    (CSV upload, memory estimate, training, quality report) and "Generate Data"
    (sampling, CSV download, original-vs-synthetic histograms). All callbacks go
    through the module-level helper functions and the shared ``generator``.
    """

    def _status_line(status: str, message: str) -> str:
        # Show both parts in one Textbox; wiring two outputs to the same
        # component would let the message overwrite the status.
        return f"{status} - {message}"

    def _initialize() -> str:
        return _status_line(*initialize_sdk())

    def _train(data, name, n_epochs, max_minutes, batch, protect) -> str:
        return _status_line(*train_model(data, name, n_epochs, max_minutes, batch, protect))

    def _generate(size, original_df):
        # Sample once, then build the download file and the comparison plot from
        # that same DataFrame instead of re-reading it from the table widget.
        synthetic_df, message = generate_data(size)
        return (
            synthetic_df,
            message,
            download_csv(synthetic_df),
            create_comparison_plot(original_df, synthetic_df),
        )

    def process_uploaded_file(file):
        """Parse an uploaded CSV.

        Returns:
            ``(preview, raw_dataframe, status_message, memory_info_update)``.
        """
        if file is None:
            return None, None, "No file uploaded", gr.update(visible=False)
        # With type="filepath" Gradio passes a path string; older Gradio versions
        # passed a tempfile wrapper exposing ``.name``.
        path = file if isinstance(file, str) else getattr(file, "name", file)
        try:
            df = pd.read_csv(path)
            memory_text = generator.estimate_memory_usage(df)
        except Exception as e:
            return None, None, f"❌ Error reading file: {str(e)}", gr.update(visible=False)
        # No row/column limits are enforced on uploads.
        success_msg = f"✅ File uploaded successfully! {len(df)} rows × {len(df.columns)} columns"
        return df, df, success_msg, gr.update(value=memory_text, visible=True)

    with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🚀 MOSTLY AI Synthetic Data Generator
        Generate high-quality synthetic data using the Mostly AI SDK. Upload your own CSV files to generate synthetic data that preserves the statistical properties of your original dataset.
        """)

        # The parsed upload is kept server-side so training and plotting use the
        # DataFrame as read by pandas, not values round-tripped through the table
        # widget (which may not preserve dtypes).
        original_state = gr.State(None)

        with gr.Tab("🏁 Quick Start"):
            gr.Markdown("### Initialize the SDK and upload your data")
            with gr.Row():
                with gr.Column():
                    init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
                    init_status = gr.Textbox(label="Initialization Status", interactive=False)
                with gr.Column():
                    gr.Markdown("""
                    **Next Steps:**
                    1. Initialize the SDK (click button above)
                    2. Go to "Upload Data and Train Model" tab to upload your CSV file
                    3. Train a model on your data
                    4. Generate synthetic data
                    """)

        with gr.Tab("📊 Upload Data and Train Model"):
            gr.Markdown("### Upload your CSV file to generate synthetic data")
            gr.Markdown("""
            **📋 File Requirements:**
            - **Format:** CSV with header row
            - **Size:** Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
            """)
            file_upload = gr.File(
                label="Upload CSV File",
                file_types=[".csv"],
                file_count="single",
                type="filepath",
            )
            uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
            memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
            with gr.Row():
                with gr.Column():
                    model_name = gr.Textbox(
                        value="My Synthetic Model",
                        label="Model Name",
                        placeholder="Enter a name for your model",
                    )
                    epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
                    max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time (minutes)")
                    batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
                    # Default matches SyntheticDataGenerator.train_generator(value_protection=True).
                    value_protection = gr.Checkbox(value=True, label="Value Protection", info="Enable Value Protection")
                    train_btn = gr.Button("Train Model", variant="primary")
                with gr.Column():
                    train_status = gr.Textbox(label="Training Status", interactive=False)
                    quality_report = gr.Textbox(label="Quality Report", lines=10, interactive=False)
                    get_report_btn = gr.Button("Get Quality Report", variant="secondary")

        with gr.Tab("🎲 Generate Data"):
            gr.Markdown("### Generate synthetic data from your trained model")
            with gr.Row():
                with gr.Column():
                    gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
                    generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
                with gr.Column():
                    gen_status = gr.Textbox(label="Generation Status", interactive=False)
            synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
            with gr.Row():
                download_btn = gr.DownloadButton("Download CSV", variant="secondary")
            comparison_plot = gr.Plot(label="Data Comparison")

        # Event handlers
        init_btn.click(_initialize, outputs=[init_status])
        train_btn.click(
            _train,
            inputs=[original_state, model_name, epochs, max_training_time, batch_size, value_protection],
            outputs=[train_status],
        )
        get_report_btn.click(get_quality_report, outputs=[quality_report])
        generate_btn.click(
            _generate,
            inputs=[gen_size, original_state],
            outputs=[synthetic_data, gen_status, download_btn, comparison_plot],
        )
        file_upload.change(
            process_uploaded_file,
            inputs=[file_upload],
            outputs=[uploaded_data, original_state, train_status, memory_info],
        )

    return demo
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and request a public share link.
    app = create_interface()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)