# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4

from huggingface_hub import HfFolder, Repository, whoami

from . import __version__
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
from .utils.import_utils import (
    _flax_version,
    _jax_version,
    _onnxruntime_version,
    _torch_version,
    is_flax_available,
    is_modelcards_available,
    is_onnx_available,
    is_torch_available,
)


if is_modelcards_available():
    from modelcards import CardData, ModelCard


logger = logging.get_logger(__name__)


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.

    The string always starts with the diffusers version, the Python version and the per-process `SESSION_ID`. When
    `DISABLE_TELEMETRY` is set, `"; telemetry/off"` is appended and nothing else is added (the `user_agent` argument
    is ignored as well).

    Args:
        user_agent (`dict` or `str`, *optional*):
            Extra information to append. A dict is rendered as `"; "`-separated `key/value` pairs, a string is
            appended as is. Any other type is ignored.

    Returns:
        `str`: The `"; "`-separated user-agent string.
    """
    ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if DISABLE_TELEMETRY:
        return ua + "; telemetry/off"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_flax_available():
        ua += f"; jax/{_jax_version}"
        ua += f"; flax/{_flax_version}"
    if is_onnx_available():
        ua += f"; onnxruntime/{_onnxruntime_version}"
    # CI will set this value to True
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None) -> str:
    """
    Builds the `namespace/model_id` repository name.

    Args:
        model_id (`str`):
            Name of the repository without its namespace.
        organization (`str`, *optional*):
            Namespace to use. When `None`, the name returned by `whoami(token)` is used instead.
        token (`str`, *optional*):
            Hub token. When `None`, the token returned by `HfFolder.get_token()` is used.

    Returns:
        `str`: `"{organization}/{model_id}"`, or `"{username}/{model_id}"` when no organization is given.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def init_git_repo(args, at_init: bool = False):
    """
    Initializes a git repo in `args.output_dir`, cloned from `args.hub_model_id`.

    Deprecated: use `huggingface_hub.Repository` directly.

    Args:
        args:
            Namespace-like object. Reads `output_dir`, `hub_private_repo` and `overwrite_output_dir`, and optionally
            `local_rank`, `hub_token` and `hub_model_id`. When `hub_model_id` is missing or `None`, the name of
            `output_dir` is used as the repo name.
        at_init (`bool`, *optional*, defaults to `False`):
            Whether this function is called before any training or not. If `args.overwrite_output_dir` is `True`
            and `at_init` is `True`, the path to the repo (which is `args.output_dir`) might be wiped out.

    Returns:
        The `Repository` object, or `None` when `args.local_rank` is set to something other than `-1` or `0`.
    """
    deprecation_message = (
        "Please use `huggingface_hub.Repository`. "
        "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
    )
    deprecate("init_git_repo()", "0.10.0", deprecation_message)

    # Only the main process creates/clones the repo.
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return
    hub_token = getattr(args, "hub_token", None)
    # `True` lets `Repository` fall back to the locally stored token.
    use_auth_token = True if hub_token is None else hub_token
    if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
        repo_name = Path(args.output_dir).absolute().name
    else:
        repo_name = args.hub_model_id
    if "/" not in repo_name:
        repo_name = get_full_repo_name(repo_name, token=hub_token)

    try:
        repo = Repository(
            args.output_dir,
            clone_from=repo_name,
            use_auth_token=use_auth_token,
            private=args.hub_private_repo,
        )
    except EnvironmentError:
        if args.overwrite_output_dir and at_init:
            # Try again after wiping output_dir
            shutil.rmtree(args.output_dir)
            # Keep the same visibility setting as the first attempt so a retry cannot
            # silently ignore `hub_private_repo`.
            repo = Repository(
                args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
                private=args.hub_private_repo,
            )
        else:
            raise

    repo.git_pull()

    # By default, ignore the checkpoint folders
    gitignore_path = os.path.join(args.output_dir, ".gitignore")
    if not os.path.exists(gitignore_path):
        with open(gitignore_path, "w", encoding="utf-8") as writer:
            writer.writelines(["checkpoint-*/"])

    return repo


def push_to_hub(
    args,
    pipeline,
    repo: Repository,
    commit_message: Optional[str] = "End of training",
    blocking: bool = True,
    **kwargs,
) -> str:
    """
    Saves `pipeline` to `args.output_dir` and pushes the content of `repo` to the 🤗 model hub, followed by a separate
    commit for the generated model card.

    Deprecated: use `huggingface_hub.Repository` and `Repository.push_to_hub()` directly.

    Parameters:
        args:
            Namespace-like object. Reads `output_dir` and optionally `hub_model_id` and `local_rank`, plus everything
            `create_model_card` reads.
        pipeline:
            Object exposing `save_pretrained(output_dir)`.
        repo (`Repository`):
            The local clone to push from.
        commit_message (`str`, *optional*, defaults to `"End of training"`):
            Message to commit while pushing.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has finished.
        kwargs:
            Accepted for backward compatibility; currently not forwarded anywhere.

    Returns:
        The value returned by `repo.push_to_hub` for the main commit: the url of the commit of your model in the given
        repository if `blocking=True`, a tuple with the url of the commit and an object to track the progress of the
        commit if `blocking=False`. Returns `None` when `args.local_rank` is set to something other than `-1` or `0`
        (the pipeline is still saved locally in that case).
    """
    deprecation_message = (
        "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
        "See `examples/unconditional_image_generation/train_unconditional.py` for an example."
    )
    deprecate("push_to_hub()", "0.10.0", deprecation_message)

    if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
        model_name = Path(args.output_dir).name
    else:
        model_name = args.hub_model_id.split("/")[-1]

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving pipeline checkpoint to {output_dir}")
    pipeline.save_pretrained(output_dir)

    # Only push from one node.
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
    if (
        blocking
        and len(repo.command_queue) > 0
        and repo.command_queue[-1] is not None
        and not repo.command_queue[-1].is_done
    ):
        repo.command_queue[-1]._process.kill()

    git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
    # push separately the model card to be independent from the rest of the model
    create_model_card(args, model_name=model_name)
    try:
        repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
    except EnvironmentError as exc:
        # Best effort: the model itself is already pushed, so only report the failure.
        logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

    return git_head_commit_url


def create_model_card(args, model_name):
    """
    Renders the model card template and writes it to `README.md` inside `args.output_dir`.

    Does nothing when `args.local_rank` is set to something other than `-1` or `0`.

    Args:
        args:
            Namespace-like object. `output_dir`, `learning_rate`, `train_batch_size`, `eval_batch_size` and
            `mixed_precision` are read unconditionally; `hub_token`, `dataset_name` and the optimizer / scheduler /
            EMA settings are optional and default to `None` when missing.
        model_name (`str`):
            Name of the model, without namespace. The full repo name is resolved with `get_full_repo_name`.

    Raises:
        ValueError: If the `modelcards` package is not installed.
    """
    # The availability helper must be *called*; the bare function object is always truthy.
    if not is_modelcards_available():
        raise ValueError(
            "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
            " install the package with `pip install modelcards`."
        )

    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = getattr(args, "hub_token", None)
    repo_name = get_full_repo_name(model_name, token=hub_token)
    # `dataset_name` is optional; use the same guarded value for the card metadata and the template.
    # NOTE(review): assumes `CardData` tolerates `datasets=None` -- confirm against the installed `modelcards`.
    dataset_name = getattr(args, "dataset_name", None)

    model_card = ModelCard.from_template(
        card_data=CardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=dataset_name,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=getattr(args, "gradient_accumulation_steps", None),
        adam_beta1=getattr(args, "adam_beta1", None),
        adam_beta2=getattr(args, "adam_beta2", None),
        adam_weight_decay=getattr(args, "adam_weight_decay", None),
        adam_epsilon=getattr(args, "adam_epsilon", None),
        lr_scheduler=getattr(args, "lr_scheduler", None),
        lr_warmup_steps=getattr(args, "lr_warmup_steps", None),
        ema_inv_gamma=getattr(args, "ema_inv_gamma", None),
        ema_power=getattr(args, "ema_power", None),
        ema_max_decay=getattr(args, "ema_max_decay", None),
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)