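# Rebuild the ThinkSound diffusion model from its config, merge in the weights
# of an older checkpoint that predates the cross-attn modules, re-initialize
# the new cross-attn parameters, and save the result as a fresh checkpoint.
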
import torch
import torch.nn as nn
from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os
import re
import torchaudio
from lightning.pytorch import seed_everything
import random
from datetime import datetime
import numpy as np
import sys

# Get the directory containing this script (ckpts/)
current_dir = os.path.dirname(os.path.abspath(__file__))

# The project root is the parent directory of ckpts/
project_root = os.path.abspath(os.path.join(current_dir, '..'))

# Add the project root to sys.path so the ThinkSound package can be imported
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from ThinkSound.data.datamodule import DataModule
from ThinkSound.models import create_model_from_config
from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from ThinkSound.inference.sampling import sample, sample_discrete_euler
from pathlib import Path
from tqdm import tqdm

def main():
    args = get_all_args()

    if args.save_dir == '':
        args.save_dir = args.results_dir

    seed = args.seed
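    # Offset the seed per SLURM rank so each distributed worker uses a
    # different RNG stream.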
    if os.environ.get("SLURM_PROCID") is not None:
        seed += int(os.environ.get("SLURM_PROCID"))
    seed_everything(seed, workers=True)

    # Load config
    if args.model_config == '':
        args.model_config = "ThinkSound/configs/model_configs/thinksound.json"
    with open(args.model_config) as f:
        model_config = json.load(f)

    duration = float(args.duration_sec)
    sample_rate = model_config["sample_rate"]
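    # Latent frames per second: 44100 / 64 / 32 ≈ 21.5. The constants are
    # hardcoded here rather than derived from sample_rate; they presumably
    # reflect the VAE's 64x downsampling plus a further 32x temporal
    # compression (an assumption inferred from the constants, not verified
    # against the model config).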
    latent_length = round(44100 / 64 / 32 * duration)

    model_config["sample_size"] = duration * sample_rate
    model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24 * int(duration)
    model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8 * int(duration)
    model_config["model"]["diffusion"]["config"]["latent_seq_len"] = latent_length

    model = create_model_from_config(model_config)

    # model.load_state_dict(torch.load(args.ckpt_dir))

    # Step 2: Load the old checkpoint (which does not yet contain cross-attn weights)
    old_ckpt = torch.load(args.ckpt_dir, map_location='cpu')

    # Step 3: Keep only the matching weights (both key name and tensor shape must match)
    model_state = model.state_dict()
    matched_ckpt = {k: v for k, v in old_ckpt.items() if k in model_state and v.shape == model_state[k].shape}
    print(f"[INFO] Loaded {len(matched_ckpt)} keys from old checkpoint")

    # Step 4: Load the merged state dict back into the model
    model_state.update(matched_ckpt)
    model.load_state_dict(model_state)
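    # Keys present in the model but absent from the old checkpoint (e.g. the
    # new cross-attn projections) keep their freshly constructed values and
    # are re-initialized explicitly below.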

    # Step 5: Initialize the cross-attn modules (only the newly added parameters)
    def init_cross_attn_weights(module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        # nn.RMSNorm only exists in recent PyTorch releases; match on the class
        # name instead so this doesn't raise AttributeError on older versions.
        elif module.__class__.__name__ == "RMSNorm":
            if hasattr(module, 'weight'):
                nn.init.ones_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    # Apply the initializer only to cross-attn submodules
    for name, module in model.named_modules():
        if 'cross_attn' in name:
            module.apply(init_cross_attn_weights)
            print(f"[INIT] Initialized {name}")

    # Step 6: Save the new checkpoint
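    # NOTE: this output path is relative to the current working directory, so
    # run the script from the project root (consistent with the sys.path
    # setup above).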
    torch.save(model.state_dict(), 'ckpts/row_thinksound_light_cross_attn.ckpt')
    print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")

if __name__ == '__main__':
    main()