# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import warnings
from collections import OrderedDict

import torch
from safetensors.torch import save_file
from scepter.modules.solver.hooks import CheckpointHook, BackwardHook
from scepter.modules.solver.hooks.registry import HOOKS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS

_DEFAULT_CHECKPOINT_PRIORITY = 300

def convert_to_comfyui_lora(ori_sd, prefix='lora_unet'):
    """Rename SwiftLoRA weights to ComfyUI's LoRA naming scheme and add a
    per-module ``alpha`` tensor equal to the LoRA rank."""
    new_ckpt = OrderedDict()
    for k, v in ori_sd.items():
        # '.lora_A.*' / '.lora_B.*' become ComfyUI's '.lora_down.' / '.lora_up.'.
        new_k = k.replace('.lora_A.0_SwiftLoRA.', '.lora_down.').replace(
            '.lora_B.0_SwiftLoRA.', '.lora_up.')
        # Flatten the module path: 'model.a.b.lora_down.weight'
        # -> '<prefix>_a_b.lora_down.weight'.
        module_part, lora_part = new_k.split('.lora', 1)
        new_k = '{}_{}.lora{}'.format(
            prefix, module_part.replace('model.', '').replace('.', '_'),
            lora_part)
        new_ckpt[new_k] = v
        if 'lora_up' in new_k:
            alpha = v.shape[-1]
        elif 'lora_down' in new_k:
            alpha = v.shape[0]
        else:
            # No alpha entry for keys that are neither up nor down projections.
            continue
        new_ckpt[new_k.split('.lora', 1)[0] + '.alpha'] = torch.tensor(
            float(alpha)).to(v)
    return new_ckpt
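
# Minimal usage sketch (illustrative; mirrors what ACECheckpointHook does
# below, using the file names it actually writes):
#
#   swift_lora_sd = torch.load('adapter_model.bin', weights_only=True)
#   comfyui_sd = convert_to_comfyui_lora(swift_lora_sd, prefix='lora_unet')
#   save_file(comfyui_sd, 'comfyui_model.safetensors')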

@HOOKS.register_class()
class ACECheckpointHook(CheckpointHook):
    """ Checkpoint resume or save hook.
    Args:
        interval (int): Save interval, by epoch.
        save_best (bool): Save the best checkpoint by a metric key, default is False.
        save_best_by (str): How to get the best the checkpoint by the metric key, default is ''.
            + means the higher the best (default).
            - means the lower the best.
            E.g. +acc@1, -err@1, acc@5(same as +acc@5)
    """

    def __init__(self, cfg, logger=None):
        super(ACECheckpointHook, self).__init__(cfg, logger=logger)

    def after_iter(self, solver):
        super().after_iter(solver)
        # At every save step, additionally export the SwiftLoRA adapter as a
        # ComfyUI-compatible safetensors file next to the original checkpoint.
        if solver.total_iter != 0 and (
            (solver.total_iter + 1) % self.interval == 0
                or solver.total_iter == solver.max_steps - 1):
            from swift import SwiftModel
            if isinstance(solver.model, SwiftModel) or (
                    hasattr(solver.model, 'module')
                    and isinstance(solver.model.module, SwiftModel)):
                save_path = osp.join(
                    solver.work_dir,
                    'checkpoints/{}-{}'.format(self.save_name_prefix,
                                               solver.total_iter + 1))
                # Only the rank-0 process converts and writes the file.
                if we.rank == 0:
                    tuner_model = os.path.join(save_path, '0_SwiftLoRA',
                                               'adapter_model.bin')
                    save_model = os.path.join(save_path, '0_SwiftLoRA',
                                              'comfyui_model.safetensors')
                    if FS.exists(tuner_model):
                        with FS.get_from(tuner_model) as local_file:
                            swift_lora_sd = torch.load(local_file,
                                                       weights_only=True)
                        safetensor_lora_sd = convert_to_comfyui_lora(
                            swift_lora_sd)
                        with FS.put_to(save_model) as local_file:
                            save_file(safetensor_lora_sd, local_file)

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACECheckpointHook.para_dict,
                            set_name=True)
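
# For reference, the artifacts written by ACECheckpointHook.after_iter at every
# save step (layout inferred from the code above; illustrative paths):
#   <work_dir>/checkpoints/<save_name_prefix>-<step>/0_SwiftLoRA/adapter_model.bin
#   <work_dir>/checkpoints/<save_name_prefix>-<step>/0_SwiftLoRA/comfyui_model.safetensors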

@HOOKS.register_class()
class ACEBackwardHook(BackwardHook):
    def grad_clip(self, optimizer):
        # Clip only the parameters that are actually being trained.
        for params_group in optimizer.param_groups:
            train_params = [
                param for param in params_group['params']
                if param.requires_grad
            ]
            torch.nn.utils.clip_grad_norm_(parameters=train_params,
                                           max_norm=self.gradient_clip)

    def after_iter(self, solver):
        """Backward with optional AMP scaling and gradient accumulation.

        Every ``accumulate_step`` calls, this clips gradients (if enabled),
        steps the optimizer and the LR scheduler, and sets
        ``solver.backward_step`` so downstream hooks know an optimizer step
        happened on this iteration.
        """
        if solver.optimizer is not None and solver.is_train_mode:
            if solver.loss is None:
                warnings.warn(
                    'solver.loss should not be None in train mode, remember to call solver._reduce_scalar()!'
                )
                return
            if solver.scaler is not None:
                solver.scaler.scale(solver.loss /
                                    self.accumulate_step).backward()
                self.current_step += 1
                # The profiler is expected to run right after backward, so
                # profile just before the optimizer step.
                if self.current_step % self.accumulate_step == 0:
                    solver.scaler.unscale_(solver.optimizer)
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.scaler.step(solver.optimizer)
                    solver.scaler.update()
                    solver.optimizer.zero_grad()
            else:
                (solver.loss / self.accumulate_step).backward()
                self.current_step += 1
                # The profiler is expected to run right after backward, so
                # profile just before the optimizer step.
                if self.current_step % self.accumulate_step == 0:
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.optimizer.step()
                    solver.optimizer.zero_grad()
            if solver.lr_scheduler:
                if self.current_step % self.accumulate_step == 0:
                    solver.lr_scheduler.step()
            if self.current_step % self.accumulate_step == 0:
                setattr(solver, 'backward_step', True)
                self.current_step = 0
            else:
                setattr(solver, 'backward_step', False)
            solver.loss = None
        if self.empty_cache_step > 0 and solver.total_iter % self.empty_cache_step == 0:
            torch.cuda.empty_cache()

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACEBackwardHook.para_dict,
                            set_name=True)