Spaces:
Runtime error
Runtime error
File size: 3,742 Bytes
05fd390 4efa9a6 05fd390 62b04c4 463536c 4efa9a6 05fd390 118c8fd 05fd390 118c8fd 05fd390 463536c 05fd390 463536c 05fd390 62b04c4 05fd390 463536c 383a495 463536c 05fd390 463536c 05fd390 8f30316 05fd390 4efa9a6 05fd390 f559d19 118c8fd 05fd390 463536c 05fd390 463536c 05fd390 86cbf7f 463536c 118c8fd 463536c 86cbf7f 05fd390 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import gradio as gr
import os
import subprocess
from huggingface_hub import snapshot_download
# Hugging Face Hub access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")
def set_accelerate_default_config():
    """Write a default Accelerate configuration by running ``accelerate config default``.

    A non-zero exit status from the CLI is reported with ``print`` instead of
    being raised, so the caller continues either way. Only
    ``subprocess.CalledProcessError`` is handled; any other exception raised
    while launching the command propagates to the caller.
    """
    cli_args = ["accelerate", "config", "default"]
    try:
        subprocess.run(cli_args, check=True)
    except subprocess.CalledProcessError as failure:
        print(f"An error occurred: {failure}")
    else:
        print("Accelerate default config set successfully!")
def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
    """Launch the DreamBooth LoRA SDXL training script through ``accelerate launch``.

    The script ``train_dreambooth_lora_sdxl.py`` is referenced by a relative
    path, so it is expected to sit in the current working directory (next to
    this app).

    Args:
        instance_data_dir: Directory containing the concept images to train on.
        lora_trained_xl_folder: Value passed as ``--output_dir``; the run is
            also started with ``--push_to_hub``.
        instance_prompt: Concept prompt passed as ``--instance_prompt``.
        max_train_steps: Total number of training steps. Coerced with ``int()``
            because Gradio ``Number`` inputs may arrive as floats (e.g. ``500.0``),
            which the training script's integer argument parsing rejects.
        checkpoint_steps: Interval passed as ``--checkpointing_steps``; coerced
            with ``int()`` for the same reason.

    Raises:
        TypeError / ValueError: If a step count cannot be converted to ``int``
            (e.g. ``None`` from a cleared input).

    A non-zero exit status of the launcher is reported with ``print`` and not
    raised.
    """
    script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder

    # Convert up front so a bad value fails here with a clear Python error
    # instead of an opaque argparse failure inside the child process.
    max_steps = int(max_train_steps)
    checkpointing_steps = int(checkpoint_steps)

    command = [
        "accelerate",
        "launch",
        script_filename,  # Use the local script
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",
        "--mixed_precision=fp16",
        f"--instance_prompt={instance_prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        f"--max_train_steps={max_steps}",
        f"--checkpointing_steps={checkpointing_steps}",
        "--seed=0",
        "--push_to_hub",
    ]
    # Formatting an unset token would pass the literal string "None" as the
    # token; omit the flag instead so the Hub client uses its default auth.
    # NOTE(review): a token on the command line is visible in process listings;
    # the child process also inherits HF_TOKEN from the environment, which may
    # make this flag unnecessary — confirm before removing it.
    if hf_token:
        command.append(f"--hub_token={hf_token}")

    try:
        subprocess.run(command, check=True)
        print("Training is finished!")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
def main(dataset_id,
         lora_trained_xl_folder,
         instance_prompt,
         max_train_steps,
         checkpoint_steps):
    """Download a Hub dataset and run DreamBooth LoRA SDXL training on it.

    Args:
        dataset_id: Hub dataset repository id, e.g. ``"diffusers/dog-example"``.
            Its last path segment names the local download folder
            (``"./dog-example"`` for the example).
        lora_trained_xl_folder: Output folder name for the trained LoRA weights.
        instance_prompt: Concept prompt passed to the training script.
        max_train_steps: Total number of training steps.
        checkpoint_steps: Interval, in steps, between checkpoints.

    Returns:
        A status message shown in the Gradio status textbox.

    Raises:
        gr.Error: If ``dataset_id`` is empty or ends with ``/`` (no folder name
            can be derived from it).
    """
    dataset_repo = (dataset_id or "").strip()
    # Automatically set local_dir based on the last part of dataset_repo.
    repo_name = dataset_repo.split("/")[-1]
    if not repo_name:
        # Without this guard the dataset would be downloaded into "./" itself.
        raise gr.Error("Please provide a dataset ID such as diffusers/dog-example.")
    local_dir = f"./{repo_name}"
    os.makedirs(local_dir, exist_ok=True)

    gr.Info("Downloading dataset ...")
    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token,
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")
    # Train on exactly the directory the dataset was downloaded into.
    train_dreambooth_lora_sdxl(local_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps)
    # NOTE(review): train_dreambooth_lora_sdxl only prints on failure, so this
    # message is returned even when training did not succeed.
    return f"Done, your trained model has been stored in your models library: your_user_name/{lora_trained_xl_folder}"
# Gradio UI: collect dataset/prompt/output settings and run `main` on click.
with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
        instance_prompt = gr.Textbox(
            label="Concept prompt",
            info="concept prompt - use a unique, made up word to avoid collisions",
        )
        model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
        with gr.Row():
            # precision=0 makes Gradio deliver ints, so the training CLI receives
            # "500" rather than "500.0".
            max_train_steps = gr.Number(label="Max training steps", value=500, precision=0)
            checkpoint_steps = gr.Number(label="Checkpointing steps", value=100, precision=0)
        train_button = gr.Button("Train !")
        status = gr.Textbox(label="Training status")

    # Inputs are listed in the same order as main()'s parameters.
    train_button.click(
        fn=main,
        inputs=[
            dataset_id,
            model_output_folder,
            instance_prompt,
            max_train_steps,
            checkpoint_steps,
        ],
        outputs=[status],
    )

demo.queue().launch()