""" Mô-đun mở rộng embedding của mô hình bằng cách thêm token mới vào vocab. Áp dụng khi fine-tuning mô hình F5-TTS. """ import os import random import torch from cached_path import cached_path from safetensors.torch import load_file # Định nghĩa seed để đảm bảo tái lập kết quả SEED = 666 def set_random_seed(seed: int): """ Đặt seed cho các thư viện ngẫu nhiên để đảm bảo reproducibility. """ random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def load_vocab(file_path: str) -> list: """ Đọc danh sách token từ file vocab. """ if not os.path.exists(file_path): raise FileNotFoundError(f"Không tìm thấy file: {file_path}") with open(file_path, "r", encoding="utf8") as file: return [line.strip() for line in file.readlines()] def expand_model_embeddings(ckpt_path: str, new_ckpt_path: str, num_new_tokens: int = 42): """ Mở rộng embedding của mô hình bằng cách thêm token mới. Args: ckpt_path (str): Đường dẫn đến file checkpoint gốc. new_ckpt_path (str): Đường dẫn để lưu checkpoint đã mở rộng. num_new_tokens (int): Số lượng token mới cần thêm vào. """ if ckpt_path.endswith(".safetensors"): ckpt = load_file(ckpt_path, device="cpu") ckpt = {"ema_model_state_dict": ckpt} elif ckpt_path.endswith(".pt"): ckpt = torch.load(ckpt_path, map_location="cpu") else: raise ValueError("Định dạng checkpoint không được hỗ trợ. Chỉ hỗ trợ .safetensors hoặc .pt") ema_sd = ckpt.get("ema_model_state_dict", {}) embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight" if embed_key_ema not in ema_sd: raise KeyError(f"Không tìm thấy khóa {embed_key_ema} trong checkpoint.") old_embed_ema = ema_sd[embed_key_ema] vocab_old, embed_dim = old_embed_ema.shape vocab_new = vocab_old + num_new_tokens def expand_embeddings(old_embeddings: torch.Tensor) -> torch.Tensor: """ Mở rộng embeddings bằng cách thêm vector mới. """ new_embeddings = torch.zeros((vocab_new, embed_dim)) new_embeddings[:vocab_old] = old_embeddings new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim)) return new_embeddings ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema]) torch.save(ckpt, new_ckpt_path) if __name__ == "__main__": # Thiết lập seed ngẫu nhiên set_random_seed(SEED) # Đường dẫn file vocab TOKEN_PRETRAINED_PATH = "data/Emilia_ZH_EN_pinyin/vocab.txt" TOKEN_NEW_PATH = "data/your_training_dataset/vocab.txt" # Load vocab tokens_pretrained = load_vocab(TOKEN_PRETRAINED_PATH) tokens_new = load_vocab(TOKEN_NEW_PATH) # Số lượng token mới cần thêm vocab_size_new = len(tokens_new) - len(tokens_pretrained) # Đường dẫn checkpoint ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) new_ckpt_path = "ckpts/your_training_dataset/pretrained_model_1200000.pt" # Mở rộng embedding expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=vocab_size_new) print(f"Checkpoint đã được mở rộng và lưu tại: {new_ckpt_path}")