# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import warnings
from collections import OrderedDict

import torch
from safetensors.torch import save_file

from scepter.modules.solver.hooks import CheckpointHook, BackwardHook
from scepter.modules.solver.hooks.registry import HOOKS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS

_DEFAULT_CHECKPOINT_PRIORITY = 300


def convert_to_comfyui_lora(ori_sd, prefix='lora_unet'):
    """Convert a SwiftLoRA state dict into the ComfyUI LoRA key layout."""
    new_ckpt = OrderedDict()
    for k, v in ori_sd.items():
        # Map Swift's lora_A/lora_B naming to ComfyUI's lora_down/lora_up.
        new_k = k.replace('.lora_A.0_SwiftLoRA.', '.lora_down.').replace(
            '.lora_B.0_SwiftLoRA.', '.lora_up.')
        # Flatten the module path into an underscore-joined name, e.g.
        # 'model.a.b.lora_down.weight' -> '<prefix>_a_b.lora_down.weight'.
        new_k = prefix + '_' + new_k.split('.lora')[0].replace(
            'model.', '').replace('.', '_') + '.lora' + new_k.split('.lora')[1]
        alpha_k = new_k.split('.lora')[0] + '.alpha'
        new_ckpt[new_k] = v
        # ComfyUI expects an alpha per LoRA pair; use the LoRA rank, which is
        # the last dim of lora_up and the first dim of lora_down.
        if 'lora_up' in new_k:
            alpha = v.shape[-1]
        elif 'lora_down' in new_k:
            alpha = v.shape[0]
        new_ckpt[alpha_k] = torch.tensor(float(alpha)).to(v)
    return new_ckpt
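
# For reference, a hypothetical SwiftLoRA entry (the layer name below is made
# up for illustration, not taken from a real checkpoint) is renamed by
# convert_to_comfyui_lora as follows:
#
#   model.layers.0.attn.qkv.lora_A.0_SwiftLoRA.weight  (shape [r, in])
#       -> lora_unet_layers_0_attn_qkv.lora_down.weight
#   model.layers.0.attn.qkv.lora_B.0_SwiftLoRA.weight  (shape [out, r])
#       -> lora_unet_layers_0_attn_qkv.lora_up.weight
#
# plus a scalar entry lora_unet_layers_0_attn_qkv.alpha holding the rank r.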


@HOOKS.register_class()
class ACECheckpointHook(CheckpointHook):
    """Checkpoint resume or save hook.

    Args:
        interval (int): Save interval, by epoch.
        save_best (bool): Save the best checkpoint by a metric key, default is False.
        save_best_by (str): How to determine the best checkpoint by the metric key, default is ''.
            + means higher is better (default).
            - means lower is better.
            E.g. +acc@1, -err@1, acc@5 (same as +acc@5).
    """
    def __init__(self, cfg, logger=None):
        super(ACECheckpointHook, self).__init__(cfg, logger=logger)

    def after_iter(self, solver):
        super().after_iter(solver)
        if solver.total_iter != 0 and (
                (solver.total_iter + 1) % self.interval == 0
                or solver.total_iter == solver.max_steps - 1):
            from swift import SwiftModel
            if isinstance(solver.model, SwiftModel) or (
                    hasattr(solver.model, 'module')
                    and isinstance(solver.model.module, SwiftModel)):
                save_path = osp.join(
                    solver.work_dir,
                    'checkpoints/{}-{}'.format(self.save_name_prefix,
                                               solver.total_iter + 1))
                # Only rank 0 converts and saves the ComfyUI-format LoRA.
                if we.rank == 0:
                    tuner_model = os.path.join(save_path, '0_SwiftLoRA',
                                               'adapter_model.bin')
                    save_model = os.path.join(save_path, '0_SwiftLoRA',
                                              'comfyui_model.safetensors')
                    if FS.exists(tuner_model):
                        with FS.get_from(tuner_model) as local_file:
                            swift_lora_sd = torch.load(local_file,
                                                       weights_only=True)
                        safetensor_lora_sd = convert_to_comfyui_lora(
                            swift_lora_sd)
                        with FS.put_to(save_model) as local_file:
                            save_file(safetensor_lora_sd, local_file)

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACECheckpointHook.para_dict,
                            set_name=True)
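

# A minimal sketch (not called by the hooks in this module) of how the
# exported 'comfyui_model.safetensors' could be inspected locally; `path` is a
# hypothetical local file path, and load_file is the safetensors counterpart
# of the save_file call used in ACECheckpointHook.after_iter.
def _inspect_comfyui_lora(path):
    from safetensors.torch import load_file
    lora_sd = load_file(path)
    for k, v in lora_sd.items():
        # Expect triplets of *.lora_down.weight, *.lora_up.weight and *.alpha.
        print(k, tuple(v.shape))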


@HOOKS.register_class()
class ACEBackwardHook(BackwardHook):
    def grad_clip(self, optimizer):
        # Clip only the trainable parameters of each param group.
        for params_group in optimizer.param_groups:
            train_params = []
            for param in params_group['params']:
                if param.requires_grad:
                    train_params.append(param)
            torch.nn.utils.clip_grad_norm_(parameters=train_params,
                                           max_norm=self.gradient_clip)

    def after_iter(self, solver):
        if solver.optimizer is not None and solver.is_train_mode:
            if solver.loss is None:
                warnings.warn(
                    'solver.loss should not be None in train mode, '
                    'remember to call solver._reduce_scalar()!')
                return
            if solver.scaler is not None:
                # AMP path: scale the accumulation-normalized loss before backward.
                solver.scaler.scale(solver.loss /
                                    self.accumulate_step).backward()
                self.current_step += 1
                # The profiler is assumed to run right after backward, so it is
                # invoked just before the optimizer step.
                if self.current_step % self.accumulate_step == 0:
                    solver.scaler.unscale_(solver.optimizer)
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.scaler.step(solver.optimizer)
                    solver.scaler.update()
                    solver.optimizer.zero_grad()
            else:
                # Plain FP32 path: same accumulation logic without the scaler.
                (solver.loss / self.accumulate_step).backward()
                self.current_step += 1
                if self.current_step % self.accumulate_step == 0:
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.optimizer.step()
                    solver.optimizer.zero_grad()
            if solver.lr_scheduler:
                if self.current_step % self.accumulate_step == 0:
                    solver.lr_scheduler.step()
            if self.current_step % self.accumulate_step == 0:
                setattr(solver, 'backward_step', True)
                self.current_step = 0
            else:
                setattr(solver, 'backward_step', False)
            solver.loss = None
        if self.empty_cache_step > 0 and solver.total_iter % self.empty_cache_step == 0:
            torch.cuda.empty_cache()

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACEBackwardHook.para_dict,
                            set_name=True)
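

if __name__ == '__main__':
    # Illustrative only: a plain-PyTorch loop mirroring ACEBackwardHook's
    # accumulation logic (loss divided by accumulate_step, gradients clipped
    # and the optimizer stepped once every accumulate_step micro-batches).
    # The toy model, data and hyper-parameters below are made up for the sketch.
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accumulate_step, gradient_clip = 4, 1.0
    for step in range(8):
        loss = model(torch.randn(2, 8)).mean()
        (loss / accumulate_step).backward()
        if (step + 1) % accumulate_step == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()
            optimizer.zero_grad()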