File size: 1,823 Bytes
0daa129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn.functional as F

def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")
if is_torch2_available():
    from .attention_processor import HairAttnProcessor2_0 as HairAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from .attention_processor import HairAttnProcessor, AttnProcessor

def adapter_injection(unet, device="cuda", dtype=torch.float32, use_resampler=False, continue_learning_path=None):
    device = device
    dtype = dtype
    # load Hair attention layers
    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = HairAttnProcessor(hidden_size=hidden_size, cross_attention_dim=hidden_size, scale=1, use_resampler=use_resampler).to(device, dtype=dtype)
        else:
            attn_procs[name] = AttnProcessor()
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    adapter_layers = adapter_modules
    adapter_layers.to(device, dtype=dtype)

    return adapter_layers

def set_scale(unet, scale):
    for attn_processor in unet.attn_processors.values():
        if isinstance(attn_processor, HairAttnProcessor):
            attn_processor.scale = scale