# Spaces: Runtime error
import os
import subprocess

import gradio as gr
from huggingface_hub import snapshot_download

# Hugging Face access token, read once at import time (None when HF_TOKEN is unset).
hf_token = os.environ.get("HF_TOKEN")
def set_accelerate_default_config():
    """Write a default Accelerate configuration via the ``accelerate`` CLI.

    Runs ``accelerate config default``. A non-zero exit status is reported on
    stdout rather than raised, so the caller carries on either way. Errors from
    starting the process at all (e.g. the executable is missing) are not caught
    here.
    """
    config_cmd = ["accelerate", "config", "default"]
    try:
        subprocess.run(config_cmd, check=True)
    except subprocess.CalledProcessError as err:
        print(f"An error occurred: {err}")
    else:
        print("Accelerate default config set successfully!")
def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
    """Launch the local DreamBooth-LoRA SDXL training script through Accelerate.

    Args:
        instance_data_dir: Directory containing the training images.
        lora_trained_xl_folder: Passed to the script as ``--output_dir``.
        instance_prompt: Passed to the script as ``--instance_prompt``.
        max_train_steps: Total number of training steps. Gradio ``Number``
            inputs may arrive as floats (e.g. ``500.0``), so the value is
            converted with ``int()`` before it goes on the command line.
        checkpoint_steps: Passed as ``--checkpointing_steps``; also converted
            with ``int()``.

    Returns:
        True if the training process exited with status 0, False if it exited
        with a non-zero status (the error is also printed).
    """
    script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder
    command = [
        "accelerate",
        "launch",
        script_filename,  # Use the local script
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",
        "--mixed_precision=fp16",
        f"--instance_prompt={instance_prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        # int() so a float like 500.0 is sent as "500", not "500.0".
        f"--max_train_steps={int(max_train_steps)}",
        f"--checkpointing_steps={int(checkpoint_steps)}",
        "--seed=0",
        "--push_to_hub",
    ]
    # Only forward a real token: formatting None would pass the literal "None".
    # NOTE(review): without --hub_token the script must find credentials
    # elsewhere (e.g. the inherited HF_TOKEN environment variable) — confirm.
    if hf_token:
        command.append(f"--hub_token={hf_token}")
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
        return False
    print("Training is finished!")
    return True
def main(dataset_id,
         lora_trained_xl_folder,
         instance_prompt,
         max_train_steps,
         checkpoint_steps):
    """Gradio callback: download a dataset repo, then launch LoRA training on it.

    Args:
        dataset_id: Hugging Face dataset repo id, e.g. ``"diffusers/dog-example"``.
            Surrounding whitespace and leading/trailing slashes are ignored.
        lora_trained_xl_folder: Output folder name handed to the training step.
        instance_prompt: Concept prompt handed to the training step.
        max_train_steps: Total training steps; converted with ``int()``.
        checkpoint_steps: Checkpoint interval in steps; converted with ``int()``.

    Returns:
        A status message for the UI.

    Raises:
        gr.Error: If the dataset id or output folder is blank, or a step count
            cannot be converted to an integer.
    """
    dataset_repo = (dataset_id or "").strip().strip("/")
    output_folder = (lora_trained_xl_folder or "").strip()
    if not dataset_repo:
        raise gr.Error("Please provide a dataset ID, e.g. diffusers/dog-example.")
    if not output_folder:
        raise gr.Error("Please provide an output model folder name.")
    try:
        # Gradio Number inputs can arrive as floats (500.0) or None (cleared box).
        max_train_steps = int(max_train_steps)
        checkpoint_steps = int(checkpoint_steps)
    except (TypeError, ValueError):
        raise gr.Error("Max train steps and checkpoint steps must be numbers.") from None

    # Download into ./<name>, where <name> is the last segment of the repo id.
    local_dir = f"./{dataset_repo.split('/')[-1]}"
    os.makedirs(local_dir, exist_ok=True)

    gr.Info("Downloading dataset ...")
    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token,
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")
    succeeded = train_dreambooth_lora_sdxl(local_dir, output_folder, instance_prompt, max_train_steps, checkpoint_steps)
    # Treat only an explicit False as failure; a None return carries no status.
    if succeeded is False:
        return "Training failed, see the console output for details."
    return f"Done, your trained model has been stored in your models library: your_user_name/{output_folder}"
# Gradio UI: collect training inputs and run `main` when the button is clicked.
with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
        instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
        model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
        with gr.Row():
            # precision=0 makes Gradio deliver ints, so step counts reach the
            # command line as e.g. 500 rather than 500.0.
            max_train_steps = gr.Number(label="Max train steps", value=500, precision=0)
            checkpoint_steps = gr.Number(label="Checkpoint steps", value=100, precision=0)
        train_button = gr.Button("Train !")
        status = gr.Textbox(label="Training status")
    # Input order must match main's positional parameters.
    train_button.click(
        fn=main,
        inputs=[
            dataset_id,
            model_output_folder,
            instance_prompt,
            max_train_steps,
            checkpoint_steps,
        ],
        outputs=[status],
    )

if __name__ == "__main__":
    demo.queue().launch()