Spaces:
Sleeping
Sleeping
""" | |
Mô-đun mở rộng embedding của mô hình bằng cách thêm token mới vào vocab. | |
Áp dụng khi fine-tuning mô hình F5-TTS. | |
""" | |
import os | |
import random | |
import torch | |
from cached_path import cached_path | |
from safetensors.torch import load_file | |
# Định nghĩa seed để đảm bảo tái lập kết quả | |
SEED = 666 | |
def set_random_seed(seed: int): | |
""" Đặt seed cho các thư viện ngẫu nhiên để đảm bảo reproducibility. """ | |
random.seed(seed) | |
os.environ["PYTHONHASHSEED"] = str(seed) | |
torch.manual_seed(seed) | |
torch.cuda.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
def load_vocab(file_path: str) -> list: | |
""" Đọc danh sách token từ file vocab. """ | |
if not os.path.exists(file_path): | |
raise FileNotFoundError(f"Không tìm thấy file: {file_path}") | |
with open(file_path, "r", encoding="utf8") as file: | |
return [line.strip() for line in file.readlines()] | |
def expand_model_embeddings(ckpt_path: str, new_ckpt_path: str, num_new_tokens: int = 42): | |
""" | |
Mở rộng embedding của mô hình bằng cách thêm token mới. | |
Args: | |
ckpt_path (str): Đường dẫn đến file checkpoint gốc. | |
new_ckpt_path (str): Đường dẫn để lưu checkpoint đã mở rộng. | |
num_new_tokens (int): Số lượng token mới cần thêm vào. | |
""" | |
if ckpt_path.endswith(".safetensors"): | |
ckpt = load_file(ckpt_path, device="cpu") | |
ckpt = {"ema_model_state_dict": ckpt} | |
elif ckpt_path.endswith(".pt"): | |
ckpt = torch.load(ckpt_path, map_location="cpu") | |
else: | |
raise ValueError("Định dạng checkpoint không được hỗ trợ. Chỉ hỗ trợ .safetensors hoặc .pt") | |
ema_sd = ckpt.get("ema_model_state_dict", {}) | |
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight" | |
if embed_key_ema not in ema_sd: | |
raise KeyError(f"Không tìm thấy khóa {embed_key_ema} trong checkpoint.") | |
old_embed_ema = ema_sd[embed_key_ema] | |
vocab_old, embed_dim = old_embed_ema.shape | |
vocab_new = vocab_old + num_new_tokens | |
def expand_embeddings(old_embeddings: torch.Tensor) -> torch.Tensor: | |
""" Mở rộng embeddings bằng cách thêm vector mới. """ | |
new_embeddings = torch.zeros((vocab_new, embed_dim)) | |
new_embeddings[:vocab_old] = old_embeddings | |
new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim)) | |
return new_embeddings | |
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema]) | |
torch.save(ckpt, new_ckpt_path) | |
if __name__ == "__main__": | |
# Thiết lập seed ngẫu nhiên | |
set_random_seed(SEED) | |
# Đường dẫn file vocab | |
TOKEN_PRETRAINED_PATH = "data/Emilia_ZH_EN_pinyin/vocab.txt" | |
TOKEN_NEW_PATH = "data/your_training_dataset/vocab.txt" | |
# Load vocab | |
tokens_pretrained = load_vocab(TOKEN_PRETRAINED_PATH) | |
tokens_new = load_vocab(TOKEN_NEW_PATH) | |
# Số lượng token mới cần thêm | |
vocab_size_new = len(tokens_new) - len(tokens_pretrained) | |
# Đường dẫn checkpoint | |
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) | |
new_ckpt_path = "ckpts/your_training_dataset/pretrained_model_1200000.pt" | |
# Mở rộng embedding | |
expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=vocab_size_new) | |
print(f"Checkpoint đã được mở rộng và lưu tại: {new_ckpt_path}") |