Spaces:
Paused
Paused
"""
ostris/ai-toolkit on https://modal.com
Run training with the following command:
modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml
"""
import os

# Enable hf_transfer for Hugging Face Hub downloads. This is set before any
# other import so that nothing loaded below sees the variable unset.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})

import sys

import modal
from dotenv import load_dotenv

# Load the .env file if it exists
load_dotenv()

# Put the remote ai-toolkit checkout first on the import path so that the
# `toolkit` package can be imported further down.
sys.path.insert(0, "/root/ai-toolkit")

# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc

# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ.update({"DISABLE_TELEMETRY": "YES"})
# Path inside the container where the output volume is mounted. A dedicated
# `modal_output` directory is used because of Modal's "cannot mount volume on
# non-empty path" requirement.
MOUNT_DIR = "/root/ai-toolkit/modal_output"

# Volume for storing model outputs, using "creating volumes lazily":
# https://modal.com/docs/guide/volumes
# You will find your model, samples and optimizer stored in:
# https://modal.com/storage/your-username/main/flux-lora-models
model_volume = modal.Volume.from_name(
    "flux-lora-models",
    create_if_missing=True,
)
# Container image used by the app: Debian slim with Python 3.11, plus the
# system and pip packages listed below.
# More about this Modal approach: https://modal.com/docs/examples/dreambooth_app
_APT_PACKAGES = ("libgl1", "libglib2.0-0")

_PIP_PACKAGES = (
    "python-dotenv",
    "torch",
    "diffusers[torch]",
    "transformers",
    "ftfy",
    "torchvision",
    "oyaml",
    "opencv-python",
    "albumentations",
    "safetensors",
    "lycoris-lora==1.8.3",
    "flatten_json",
    "pyyaml",
    "tensorboard",
    "kornia",
    "invisible-watermark",
    "einops",
    "accelerate",
    "toml",
    "pydantic",
    "omegaconf",
    "k-diffusion",
    "open_clip_torch",
    "timm",
    "prodigyopt",
    "controlnet_aux==0.0.7",
    "bitsandbytes",
    "hf_transfer",
    "lpips",
    "pytorch_fid",
    "optimum-quanto",
    "sentencepiece",
    "huggingface_hub",
    "peft",
)

image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install(*_APT_PACKAGES)
    .pip_install(*_PIP_PACKAGES)
)
# Mount for the entire ai-toolkit directory.
# Example: "/Users/username/ai-toolkit" is the local directory (edit it to
# match your machine) and "/root/ai-toolkit" is the remote directory.
code_mount = modal.Mount.from_local_dir(
    "/Users/username/ai-toolkit",
    remote_path="/root/ai-toolkit",
)

# The Modal app, created with the image, the code mount and the output volume
# attached at MOUNT_DIR.
app = modal.App(
    name="flux-lora-training",
    image=image,
    mounts=[code_mount],
    volumes={MOUNT_DIR: model_volume},
)
# Opt-in debugging: DEBUG_TOOLKIT=1 enables torch autograd anomaly detection.
if os.getenv("DEBUG_TOOLKIT", "0") == "1":
    # torch is imported only when debugging was requested
    import torch

    torch.autograd.set_detect_anomaly(True)
| import argparse | |
| from toolkit.job import get_job | |
def print_end_message(jobs_completed, jobs_failed):
    """Print a short banner reporting how many jobs completed and failed.

    The completed-jobs line is always printed; the failures line is printed
    only when ``jobs_failed`` is greater than zero. Both lines pluralise
    their noun unless the count is exactly 1.
    """
    divider = "========================================"

    completed_suffix = "" if jobs_completed == 1 else "s"
    summary_lines = [f"{jobs_completed} completed job{completed_suffix}"]

    if jobs_failed > 0:
        failed_suffix = "" if jobs_failed == 1 else "s"
        summary_lines.append(f"{jobs_failed} failure{failed_suffix}")

    print("")
    print(divider)
    print("Result:")
    for entry in summary_lines:
        print(f" - {entry}")
    print(divider)
@app.function(
    # NOTE(review): nothing else in this file specifies a GPU type or a
    # timeout; treat these values as a starting point and confirm them
    # against your model size and expected run time.
    gpu="A100",
    timeout=7200,  # seconds
)
def main(config_file_list_str: str, recover: bool = False, name: str = None):
    """Run one or more training jobs sequentially and report the outcome.

    The function is registered on ``app`` so that it can be launched with
    ``modal run`` (see the module docstring) or invoked remotely from the
    ``__main__`` block.

    Args:
        config_file_list_str: Comma-separated config file names or paths.
            Whitespace around entries and empty entries (for example from a
            trailing comma) are ignored.
        recover: When true, continue with the remaining jobs after a job
            fails; otherwise the first failure is re-raised.
        name: Passed to ``get_job`` together with each config file; the CLI
            help describes it as the replacement for the ``[name]`` tag in
            a config file.

    Raises:
        Exception: Whatever the failing job raised, when ``recover`` is
            false. The end-of-run summary is printed first.
    """
    # convert the config file list from a string to a list, dropping blanks
    config_file_list = [
        entry.strip()
        for entry in config_file_list_str.split(",")
        if entry.strip()
    ]

    jobs_completed = 0
    jobs_failed = 0

    print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

    for config_file in config_file_list:
        try:
            job = get_job(config_file, name)

            # Redirect the first process's training folder to the mounted
            # volume so its outputs land in persistent storage.
            job.config['process'][0]['training_folder'] = MOUNT_DIR
            os.makedirs(MOUNT_DIR, exist_ok=True)
            print(f"Training outputs will be saved to: {MOUNT_DIR}")

            # run the job
            job.run()

            # commit the volume after training
            model_volume.commit()

            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print(f"Error running job: {e}")
            jobs_failed += 1
            if not recover:
                print_end_message(jobs_completed, jobs_failed)
                # bare raise keeps the original traceback intact
                raise

    print_end_message(jobs_completed, jobs_failed)
if __name__ == "__main__":
    # Command-line entry point for running this file as a plain Python
    # script: parse the arguments, then invoke `main` remotely on Modal.
    parser = argparse.ArgumentParser()

    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
    )

    # flag to continue if a job fails
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )

    # optional name replacement for config file
    parser.add_argument(
        '-n', '--name',
        type=str,
        default=None,
        help='Name to replace [name] tag in config file, useful for shared config file'
    )
    args = parser.parse_args()

    # convert list of config files to a comma-separated string for Modal compatibility
    config_file_list_str = ",".join(args.config_file_list)

    # `.call()` is the legacy spelling for invoking a Modal function;
    # `.remote()` is the current one. `main` must be a Modal function on
    # `app` for this to work, and when launched as a plain script the app
    # has to be running, hence the `app.run()` context.
    with app.run():
        main.remote(
            config_file_list_str=config_file_list_str,
            recover=args.recover,
            name=args.name,
        )