Liyang Chen
committed on
Commit
·
e1b47be
1
Parent(s):
88fbe87
full pipeline
Browse files
init_cross_attn.py
CHANGED
|
@@ -9,6 +9,17 @@ from lightning.pytorch import seed_everything
|
|
| 9 |
import random
|
| 10 |
from datetime import datetime
|
| 11 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from ThinkSound.data.datamodule import DataModule
|
| 14 |
from ThinkSound.models import create_model_from_config
|
|
@@ -62,6 +73,7 @@ def main():
|
|
| 62 |
|
| 63 |
# Step 5: 初始化 cross-attn 模块(只初始化新增部分)
|
| 64 |
def init_cross_attn_weights(module):
|
|
|
|
| 65 |
if isinstance(module, nn.Linear):
|
| 66 |
nn.init.xavier_uniform_(module.weight)
|
| 67 |
if module.bias is not None:
|
|
@@ -69,7 +81,14 @@ def main():
|
|
| 69 |
elif isinstance(module, nn.LayerNorm):
|
| 70 |
nn.init.ones_(module.weight)
|
| 71 |
nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
|
|
|
|
|
|
| 73 |
# 只遍历 cross-attn 模块进行初始化
|
| 74 |
for name, module in model.named_modules():
|
| 75 |
if 'cross_attn' in name:
|
|
@@ -77,7 +96,7 @@ def main():
|
|
| 77 |
print(f"[INIT] Initialized {name}")
|
| 78 |
|
| 79 |
# Step 6: 保存新权重
|
| 80 |
-
torch.save(model.state_dict(), 'ckpts/
|
| 81 |
print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
|
| 82 |
|
| 83 |
if __name__ == '__main__':
|
|
|
|
| 9 |
import random
|
| 10 |
from datetime import datetime
|
| 11 |
import numpy as np
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
# 获取当前脚本所在目录(ckpts/)
|
| 15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
|
| 17 |
+
# 项目根目录 = ckpts 的上级目录
|
| 18 |
+
project_root = os.path.abspath(os.path.join(current_dir, '..'))
|
| 19 |
+
|
| 20 |
+
# 添加项目根目录到 sys.path
|
| 21 |
+
if project_root not in sys.path:
|
| 22 |
+
sys.path.insert(0, project_root)
|
| 23 |
|
| 24 |
from ThinkSound.data.datamodule import DataModule
|
| 25 |
from ThinkSound.models import create_model_from_config
|
|
|
|
| 73 |
|
| 74 |
# Step 5: 初始化 cross-attn 模块(只初始化新增部分)
|
| 75 |
def init_cross_attn_weights(module):
|
| 76 |
+
from einops.layers.torch import Rearrange
|
| 77 |
if isinstance(module, nn.Linear):
|
| 78 |
nn.init.xavier_uniform_(module.weight)
|
| 79 |
if module.bias is not None:
|
|
|
|
| 81 |
elif isinstance(module, nn.LayerNorm):
|
| 82 |
nn.init.ones_(module.weight)
|
| 83 |
nn.init.zeros_(module.bias)
|
| 84 |
+
elif isinstance(module, nn.RMSNorm) or module.__class__.__name__ == "RMSNorm":
|
| 85 |
+
if hasattr(module, 'weight'):
|
| 86 |
+
nn.init.ones_(module.weight)
|
| 87 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
| 88 |
+
nn.init.zeros_(module.bias)
|
| 89 |
|
| 90 |
+
import pdb; pdb.set_trace()
|
| 91 |
+
pass
|
| 92 |
# 只遍历 cross-attn 模块进行初始化
|
| 93 |
for name, module in model.named_modules():
|
| 94 |
if 'cross_attn' in name:
|
|
|
|
| 96 |
print(f"[INIT] Initialized {name}")
|
| 97 |
|
| 98 |
# Step 6: 保存新权重
|
| 99 |
+
torch.save(model.state_dict(), 'ckpts/row_thinksound_light_cross_attn.ckpt')
|
| 100 |
print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
|
| 101 |
|
| 102 |
if __name__ == '__main__':
|
row_thinksound_light_cross_attn.ckpt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db9f641e234f91448d9ca1ec9254339ad64414cbdc2c637311ff06b985d8fb65
|
| 3 |
+
size 6026895638
|
thinksound_light_cross_attn.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:1a80397e1a4f44c9e698ce2b6cbacf5cc775ef598907e41b2241b285b9e7eb78
|
| 3 |
-
size 5909451670
|
|
|
|
|
|
|
|
|
|
|
|