Spaces:
Build error
Build error
import gradio as gr | |
import os | |
from .common_gui import ( | |
get_file_path, | |
get_folder_path, | |
set_pretrained_model_name_or_path_input, | |
scriptdir, | |
list_dirs, | |
list_files, | |
create_refresh_button, | |
) | |
from .class_gui_config import KohyaSSGUIConfig | |
# Emoji glyphs kept as module-level constants. Named escapes are used so the
# intended character is readable even where the emoji itself fails to render.
folder_symbol = "\N{OPEN FILE FOLDER}"  # 📂
refresh_symbol = "\N{ANTICLOCKWISE DOWNWARDS AND UPWARDS OPEN CIRCLE ARROWS}"  # 🔄
save_style_symbol = "\N{FLOPPY DISK}"  # 💾
document_symbol = "\N{PAGE FACING UP}"  # 📄
# Model identifiers offered by default in the "Pretrained model name or path"
# dropdown, ahead of any checkpoint files found on disk. Kept as a list because
# it is concatenated with other lists when the dropdown choices are built.
default_models = [
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    "stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned",
    "stabilityai/stable-diffusion-2-1-base",
    "stabilityai/stable-diffusion-2-base",
    "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/stable-diffusion-2",
    "runwayml/stable-diffusion-v1-5",
    "CompVis/stable-diffusion-v1-4",
]
class SourceModel:
    """Build the "Model" accordion of the GUI and expose its components.

    Constructing an instance creates the Gradio components (model dropdown,
    output name, image folder, dataset config, comment, save format and
    precision) inside the currently active Gradio layout context and wires
    their refresh / change / click events. The created components are stored
    as attributes on the instance.
    """

    # Immutable fallbacks used when the caller does not supply choices.
    # Tuples avoid the shared-mutable-default-argument pitfall.
    _DEFAULT_SAVE_MODEL_AS_CHOICES = (
        "same as source model",
        "ckpt",
        "diffusers",
        "diffusers_safetensors",
        "safetensors",
    )
    _DEFAULT_SAVE_PRECISION_CHOICES = (
        "float",
        "fp16",
        "bf16",
    )

    def __init__(
        self,
        save_model_as_choices=None,
        save_precision_choices=None,
        headless=False,
        finetuning=False,
        config: "KohyaSSGUIConfig | None" = None,
    ):
        """Create the model-source widgets.

        Parameters:
        - save_model_as_choices: Options for the "Save trained model as"
          radio. ``None`` selects the built-in list of five formats.
        - save_precision_choices: Options for the "Save precision" radio.
          ``None`` selects ``float`` / ``fp16`` / ``bf16``.
        - headless: When true, the file/folder picker buttons are hidden.
        - finetuning: Only changes the label of the image folder dropdown.
        - config: Object providing ``get(key, default)`` for initial values.
          ``None`` behaves like an empty configuration.
        """
        if save_model_as_choices is None:
            save_model_as_choices = list(self._DEFAULT_SAVE_MODEL_AS_CHOICES)
        if save_precision_choices is None:
            save_precision_choices = list(self._DEFAULT_SAVE_PRECISION_CHOICES)

        self.headless = headless
        self.save_model_as_choices = save_model_as_choices
        self.save_precision_choices = save_precision_choices
        self.finetuning = finetuning
        # A fresh dict per instance; only ``.get`` is used on it below.
        self.config = config if config is not None else {}

        # Set default directories if not provided
        self.current_models_dir = self.config.get(
            "model.models_dir", os.path.join(scriptdir, "models")
        )
        self.current_train_data_dir = self.config.get(
            "model.train_data_dir", os.path.join(scriptdir, "data")
        )
        # NOTE(review): "model.dataset_config" is also used further down as the
        # dropdown's selected *value*; confirm whether this key is meant to
        # hold a directory or a file path.
        self.current_dataset_config_dir = self.config.get(
            "model.dataset_config", os.path.join(scriptdir, "dataset_config")
        )

        model_checkpoints = list(
            list_files(
                self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True
            )
        )

        def list_models(path):
            """Remember the directory of ``path`` and return model choices.

            Returns the default model names followed by the checkpoint files
            that ``list_files`` yields for ``path``.
            """
            self.current_models_dir = (
                path if os.path.isdir(path) else os.path.dirname(path)
            )
            return default_models + list(
                list_files(path, exts=[".ckpt", ".safetensors"], all=True)
            )

        def list_train_data_dirs(path):
            """Remember ``path`` as the training data dir and list its dirs.

            An empty (or missing) path falls back to the current directory.
            """
            self.current_train_data_dir = path or "."
            return list(list_dirs(self.current_train_data_dir))

        def list_dataset_config_dirs(path: str) -> list:
            """
            List the .toml files found for the given dataset config path.

            Parameters:
            - path (str): The path to list files from. An empty (or missing)
              path falls back to the current directory.

            Returns:
            - list: The entries yielded by ``list_files`` for ``.toml``.
            """
            # Deliberately a local: unlike the other helpers, this one does not
            # update ``self.current_dataset_config_dir``.
            target = path or "."
            # Lists all .toml files for the target, used for populating
            # dropdown choices.
            return list(list_files(target, exts=[".toml"], all=True))

        with gr.Accordion("Model", open=True):
            with gr.Column(), gr.Group():
                # Hidden textboxes that feed the file-picker's filter settings.
                model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
                model_ext_name = gr.Textbox(value="Model types", visible=False)

                # Define the input elements
                with gr.Row():
                    with gr.Column(), gr.Row():
                        self.model_list = gr.Textbox(visible=False, value="")
                        # NOTE(review): the initial value is read from
                        # "model.models_dir", the same key used for the models
                        # directory above; confirm a dedicated model-path key
                        # was not intended.
                        self.pretrained_model_name_or_path = gr.Dropdown(
                            label="Pretrained model name or path",
                            choices=default_models + model_checkpoints,
                            value=self.config.get(
                                "model.models_dir", "runwayml/stable-diffusion-v1-5"
                            ),
                            allow_custom_value=True,
                            visible=True,
                            min_width=100,
                        )
                        create_refresh_button(
                            self.pretrained_model_name_or_path,
                            lambda: None,
                            lambda: {"choices": list_models(self.current_models_dir)},
                            "open_folder_small",
                        )

                        self.pretrained_model_name_or_path_file = gr.Button(
                            document_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not headless),
                        )
                        self.pretrained_model_name_or_path_file.click(
                            get_file_path,
                            inputs=[
                                self.pretrained_model_name_or_path,
                                model_ext,
                                model_ext_name,
                            ],
                            outputs=self.pretrained_model_name_or_path,
                            show_progress=False,
                        )
                        self.pretrained_model_name_or_path_folder = gr.Button(
                            folder_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not headless),
                        )
                        self.pretrained_model_name_or_path_folder.click(
                            get_folder_path,
                            inputs=self.pretrained_model_name_or_path,
                            outputs=self.pretrained_model_name_or_path,
                            show_progress=False,
                        )

                    with gr.Column(), gr.Row():
                        self.output_name = gr.Textbox(
                            label="Trained Model output name",
                            placeholder="(Name of the model to output)",
                            value=self.config.get("model.output_name", "last"),
                            interactive=True,
                        )

                with gr.Row():
                    with gr.Column(), gr.Row():
                        self.train_data_dir = gr.Dropdown(
                            label=(
                                "Image folder (containing training images subfolders)"
                                if not finetuning
                                else "Image folder (containing training images)"
                            ),
                            choices=[""]
                            + list_train_data_dirs(self.current_train_data_dir),
                            value=self.config.get("model.train_data_dir", ""),
                            interactive=True,
                            allow_custom_value=True,
                        )
                        create_refresh_button(
                            self.train_data_dir,
                            lambda: None,
                            lambda: {
                                "choices": [""]
                                + list_train_data_dirs(self.current_train_data_dir)
                            },
                            "open_folder_small",
                        )
                        # Use the shared folder glyph rather than a hard-coded
                        # (and previously mis-encoded) literal.
                        self.train_data_dir_folder = gr.Button(
                            folder_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not self.headless),
                        )
                        self.train_data_dir_folder.click(
                            get_folder_path,
                            outputs=self.train_data_dir,
                            show_progress=False,
                        )

                    with gr.Column(), gr.Row():
                        # Toml directory dropdown
                        self.dataset_config = gr.Dropdown(
                            label="Dataset config file (Optional. Select the toml configuration file to use for the dataset)",
                            # The configured value is offered as the first
                            # choice so it is selectable even if not on disk.
                            choices=[self.config.get("model.dataset_config", "")]
                            + list_dataset_config_dirs(self.current_dataset_config_dir),
                            value=self.config.get("model.dataset_config", ""),
                            interactive=True,
                            allow_custom_value=True,
                        )
                        # Refresh button for dataset_config directory
                        create_refresh_button(
                            self.dataset_config,
                            lambda: None,
                            lambda: {
                                "choices": [""]
                                + list_dataset_config_dirs(
                                    self.current_dataset_config_dir
                                )
                            },
                            "open_folder_small",
                        )
                        # Toml directory button
                        self.dataset_config_folder = gr.Button(
                            document_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not self.headless),
                        )
                        # Toml directory button click event
                        self.dataset_config_folder.click(
                            get_file_path,
                            inputs=[
                                self.dataset_config,
                                gr.Textbox(value="*.toml", visible=False),
                                gr.Textbox(value="Dataset config types", visible=False),
                            ],
                            outputs=self.dataset_config,
                            show_progress=False,
                        )
                        # Change event for dataset_config directory dropdown
                        self.dataset_config.change(
                            fn=lambda path: gr.Dropdown(
                                choices=[""] + list_dataset_config_dirs(path)
                            ),
                            inputs=self.dataset_config,
                            outputs=self.dataset_config,
                            show_progress=False,
                        )

                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            # Hidden flags; they are outputs of the model
                            # dropdown's change handler registered below.
                            self.v2 = gr.Checkbox(
                                label="v2", value=False, visible=False, min_width=60
                            )
                            self.v_parameterization = gr.Checkbox(
                                label="v_parameterization",
                                value=False,
                                visible=False,
                                min_width=130,
                            )
                            self.sdxl_checkbox = gr.Checkbox(
                                label="SDXL",
                                value=False,
                                visible=False,
                                min_width=60,
                            )
                    with gr.Column():
                        gr.Group(visible=False)

                with gr.Row():
                    self.training_comment = gr.Textbox(
                        label="Training comment",
                        placeholder="(Optional) Add training comment to be included in metadata",
                        interactive=True,
                        value=self.config.get("model.training_comment", ""),
                    )

                with gr.Row():
                    self.save_model_as = gr.Radio(
                        save_model_as_choices,
                        label="Save trained model as",
                        value=self.config.get("model.save_model_as", "safetensors"),
                    )
                    self.save_precision = gr.Radio(
                        save_precision_choices,
                        label="Save precision",
                        value=self.config.get("model.save_precision", "fp16"),
                    )

                # Selecting a model updates the dropdown itself plus the three
                # hidden model-type checkboxes.
                self.pretrained_model_name_or_path.change(
                    fn=lambda path: set_pretrained_model_name_or_path_input(
                        path, refresh_method=list_models
                    ),
                    inputs=[
                        self.pretrained_model_name_or_path,
                    ],
                    outputs=[
                        self.pretrained_model_name_or_path,
                        self.v2,
                        self.v_parameterization,
                        self.sdxl_checkbox,
                    ],
                    show_progress=False,
                )

                # Re-list sub-directories whenever the image folder changes.
                self.train_data_dir.change(
                    fn=lambda path: gr.Dropdown(
                        choices=[""] + list_train_data_dirs(path)
                    ),
                    inputs=self.train_data_dir,
                    outputs=self.train_data_dir,
                    show_progress=False,
                )