File size: 5,482 Bytes
76ccb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee7393
76ccb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee7393
76ccb95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee7393
 
76ccb95
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from diffusers import AutoencoderKL, DDIMScheduler
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
from omegaconf import OmegaConf
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import load_weights
from safetensors import safe_open
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from faceadapter.face_adapter import FaceAdapterPlusForVideoLora

# Style name -> path of the base model checkpoint (.safetensors) that
# load_model() reads for that style.
model_style_type2base_model_path = {
    "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
    "anime": "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
    # LDM format. Needs to be converted.
    "photorealistic": "models/sar/sar.safetensors",
}

# Substrings found in converted VAE mid-block attention keys, paired with the
# replacement used when renaming them. Checked in this order; first match wins.
_VAE_ATTENTION_RENAMES = (
    ("key", "to_k"),
    ("query", "to_q"),
    ("value", "to_v"),
    ("proj_attn", "to_out.0"),
)


def _read_safetensors_state_dict(checkpoint_path):
    """Read every tensor of a safetensors file into a dict, loaded on CPU."""
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def _remap_vae_attention_keys(vae_state_dict):
    """Rename mid-block attention keys of a converted VAE state dict in place.

    Only keys containing ``encoder.mid_block.attentions`` or
    ``decoder.mid_block.attentions`` are touched; the first matching entry of
    ``_VAE_ATTENTION_RENAMES`` decides the new name. Returns the same dict.
    """
    # Iterate over a snapshot because entries are popped/re-inserted below.
    for key in list(vae_state_dict):
        if ("encoder.mid_block.attentions" not in key
                and "decoder.mid_block.attentions" not in key):
            continue
        for old, new in _VAE_ATTENTION_RENAMES:
            if old in key:
                vae_state_dict[key.replace(old, new)] = vae_state_dict.pop(key)
                break
    return vae_state_dict


def _load_dreambooth_weights(pipeline, base_model_path, device):
    """Load VAE, UNet and text-encoder weights from a base checkpoint into ``pipeline``."""
    print(f"load dreambooth model from {base_model_path}")
    # All tensors are read up front, so the file is closed before conversion.
    dreambooth_state_dict = _read_safetensors_state_dict(base_model_path)

    converted_vae_checkpoint = _remap_vae_attention_keys(
        convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
    )
    pipeline.vae.load_state_dict(converted_vae_checkpoint)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
    # strict=False: mismatched keys are only counted and reported, not fatal.
    m, u = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
    print(f"### custom unet missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

    pipeline.text_encoder = convert_ldm_clip_checkpoint(
        dreambooth_state_dict, dtype=torch.float16
    ).to(device=device)


def load_model(model_style_type="realistic", device="cuda"):
    """Build the animation pipeline for a style and wrap it in the face adapter.

    Loads tokenizer, text encoder, VAE and 3D UNet from ``models/animatediff/sd``,
    assembles an ``AnimationPipeline`` with a DDIM scheduler configured from
    ``inference-v2.yaml``, applies the motion module and adapter LoRA through
    ``load_weights``, overlays the base checkpoint selected by
    ``model_style_type`` and finally casts the pipeline to float16.

    Args:
        model_style_type: Key into ``model_style_type2base_model_path``.
        device: Device the models are moved to; also passed to ``torch.device``.

    Returns:
        A ``FaceAdapterPlusForVideoLora`` built around the prepared pipeline.

    Raises:
        KeyError: If ``model_style_type`` is not a key of
            ``model_style_type2base_model_path``.
    """
    inference_config_path = "inference-v2.yaml"
    sd_version = "models/animatediff/sd"
    id_ckpt = "models/animator.ckpt"
    image_encoder_path = "models/image_encoder"
    motion_module_path = "models/v3_sd15_mm.ckpt"
    motion_lora_path = "models/v3_sd15_adapter.ckpt"

    base_model_path = model_style_type2base_model_path[model_style_type]
    inference_config = OmegaConf.load(inference_config_path)

    tokenizer = CLIPTokenizer.from_pretrained(
        sd_version, subfolder="tokenizer", torch_dtype=torch.float16,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        sd_version, subfolder="text_encoder", torch_dtype=torch.float16,
    ).to(device=device)
    vae = AutoencoderKL.from_pretrained(
        sd_version, subfolder="vae", torch_dtype=torch.float16,
    ).to(device=device)
    unet = UNet3DConditionModel.from_pretrained_2d(
        sd_version,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).to(device=device)
    # unet.to(dtype=torch.float16) does not work on hf spaces.
    unet = unet.half()

    scheduler = DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs))
    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=None,
        scheduler=scheduler,
        torch_dtype=torch.float16,
    ).to(device=device)

    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path=motion_module_path,
        motion_module_lora_configs=[],
        # domain adapter
        adapter_lora_path=motion_lora_path,
        adapter_lora_scale=1,
        # image layers
        dreambooth_model_path=None,
        lora_model_path="",
        lora_alpha=0.8,
    ).to(device=device)

    if base_model_path:
        _load_dreambooth_weights(pipeline, base_model_path, device)

    # The cast and the adapter construction run whether or not a base
    # checkpoint was loaded, so the function never falls through to None.
    pipeline = pipeline.to(torch.float16)
    id_animator = FaceAdapterPlusForVideoLora(
        pipeline, image_encoder_path, id_ckpt, num_tokens=16,
        device=torch.device(device), torch_type=torch.float16,
    )
    return id_animator