ai-toolkit / toolkit /accelerator.py
jbilcke-hf's picture
jbilcke-hf HF Staff
Upload 430 files
3cc1e25 verified
raw
history blame
551 Bytes
import logging

from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module
# Module-level cache for the shared Accelerator instance. It stays None until
# get_accelerator() creates the instance on first use.
global_accelerator = None
def get_accelerator() -> Accelerator:
    """Return the module-wide ``Accelerator``, creating it on first call.

    The instance is cached in the module-level ``global_accelerator`` so
    that every caller receives the same object.

    Returns:
        The shared ``Accelerator`` instance.
    """
    global global_accelerator
    # Fast path: an instance has already been created and cached.
    if global_accelerator is not None:
        return global_accelerator
    # First call: build the instance with default arguments and cache it.
    global_accelerator = Accelerator()
    return global_accelerator
def unwrap_model(model):
    """Return ``model`` with accelerator and ``torch.compile`` wrappers removed.

    The model is first passed through the shared accelerator's
    ``unwrap_model``. If the result is a compiled module (as reported by
    ``is_compiled_module``), its ``_orig_mod`` attribute is returned instead.

    Unwrapping is best-effort: if any step raises, the failure is logged at
    DEBUG level and the model is returned in whatever state had been reached
    (the original object when nothing was unwrapped yet).

    Args:
        model: The possibly wrapped model.

    Returns:
        The unwrapped model, or the input unchanged when unwrapping fails.
    """
    try:
        model = get_accelerator().unwrap_model(model)
        if is_compiled_module(model):
            model = model._orig_mod
    except Exception:
        # Deliberately non-fatal so callers always get a usable model back,
        # but record the traceback instead of discarding it silently.
        logging.getLogger(__name__).debug(
            "unwrap_model failed; returning the model without further unwrapping",
            exc_info=True,
        )
    return model