Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Union | |
| import torch | |
| from modules import models | |
| from modules.utils.SeedContext import SeedContext | |
| import uuid | |
def create_speaker_from_seed(seed):
    """Sample a random speaker embedding from the ChatTTS model for *seed*.

    The model is obtained via ``models.load_chat_tts()`` before the seed
    context is entered; only the sampling call runs inside
    ``SeedContext(seed)``.

    Args:
        seed: Value handed to ``SeedContext``.
            NOTE(review): presumably this pins the RNG so the same seed
            yields the same embedding -- confirm against SeedContext.

    Returns:
        Whatever ``sample_random_speaker()`` returns (the speaker embedding).
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed):
        speaker_emb = tts_model.sample_random_speaker()
    return speaker_emb
class Speaker:
    """Metadata plus voice embedding describing one speaker.

    Instances are persisted whole by ``SpeakerManager`` (via ``torch.save``),
    so objects loaded from older files may lack attributes that were added
    later; :meth:`fix` backfills those.
    """

    def __init__(self, seed, name="", gender="", describe=""):
        """Create a speaker with a fresh random id and no embedding yet.

        Args:
            seed: Seed the embedding was (or will be) sampled from. ``-2`` is
                the value used elsewhere in this file for speakers that have
                no seed.
            name: Display name, also used for lookup by name.
            gender: Free-form gender label.
            describe: Free-form description.
        """
        self.id = uuid.uuid4()
        self.seed = seed
        self.name = name
        self.gender = gender
        self.describe = describe
        # Filled in by whoever creates the speaker; stays None until then.
        self.emb = None

    def to_json(self, with_emb=False):
        """Return a JSON-serializable dict describing this speaker.

        Args:
            with_emb: When true, include the embedding converted with
                ``tolist()``. When false, or when no embedding has been
                assigned yet, the ``"emb"`` entry is ``None``.

        Returns:
            dict with keys ``id`` (as a string), ``seed``, ``name``,
            ``gender``, ``describe`` and ``emb``.
        """
        emb = None
        # Guard against emb still being None (the state __init__ leaves it
        # in); calling .tolist() on it would raise AttributeError.
        if with_emb and self.emb is not None:
            emb = self.emb.tolist()
        return {
            "id": str(self.id),
            "seed": self.seed,
            "name": self.name,
            "gender": self.gender,
            "describe": self.describe,
            "emb": emb,
        }

    def fix(self):
        """Backfill attributes that are missing on this instance.

        Only attributes absent from the instance ``__dict__`` are set;
        existing values are never overwritten.

        Returns:
            True if at least one attribute had to be added (so the caller
            may want to re-save the object), otherwise False.
        """
        # Factories rather than values so a new uuid is only generated when
        # the id is actually missing.
        default_factories = (
            ("id", uuid.uuid4),
            ("seed", lambda: -2),
            ("name", lambda: ""),
            ("gender", lambda: "*"),
            ("describe", lambda: ""),
        )
        is_update = False
        for attr, make_default in default_factories:
            if attr not in self.__dict__:
                setattr(self, attr, make_default())
                is_update = True
        return is_update

    def __hash__(self):
        """Hash on the string form of the id, consistent with ``__eq__``."""
        return hash(str(self.id))

    def __eq__(self, other):
        """Two speakers are equal when their ids have the same string form."""
        if not isinstance(other, Speaker):
            return False
        return str(self.id) == str(other.id)
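# Each speaker is stored as one embedding file (.pt).
# Managing speakers means managing every speaker file under ./data/speakers/.
# A speaker can be created from a seed.
# The list can be refreshed by re-reading all speaker files.
# All loaded speakers can be listed.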
class SpeakerManager:
    """Manage the Speaker objects persisted as ``.pt`` files in one directory.

    Every file in ``speaker_dir`` whose name ends in ``.pt`` holds one Speaker
    written with ``torch.save``. ``self.speakers`` maps file name -> loaded
    Speaker and is rebuilt from disk by :meth:`refresh_speakers`.
    """

    def __init__(self):
        """Set the speaker directory and load everything stored in it."""
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _path_for(self, filename):
        """Return the path of *filename* inside the speaker directory."""
        return os.path.join(self.speaker_dir, filename)

    def refresh_speakers(self):
        """Reload ``self.speakers`` from the ``.pt`` files on disk.

        The directory is created when it does not exist yet, so a fresh
        checkout starts with an empty manager instead of raising
        ``FileNotFoundError``. Loaded objects are passed through ``fix()``;
        any object that needed backfilling is written back to its file.
        """
        os.makedirs(self.speaker_dir, exist_ok=True)
        self.speakers = {}
        for speaker_file in os.listdir(self.speaker_dir):
            if not speaker_file.endswith(".pt"):
                continue
            path = self._path_for(speaker_file)
            # NOTE(review): torch.load unpickles the file, so only trusted
            # files should be placed in this directory. Recent PyTorch
            # releases may default to weights_only=True and refuse custom
            # classes -- confirm against the torch version in use.
            speaker = torch.load(path, map_location=torch.device("cpu"))
            self.speakers[speaker_file] = speaker
            if speaker.fix():
                torch.save(speaker, path)

    def list_speakers(self):
        """Return all loaded speakers as a list."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create a speaker whose embedding is sampled from *seed* and save it.

        Args:
            seed: Seed passed to the module-level ``create_speaker_from_seed``.
            name: Speaker name, also used as the file stem. When empty, the
                string form of *seed* is used (a non-string seed previously
                failed when building the file name).
            gender: Stored on the speaker unchanged.
            describe: Stored on the speaker unchanged.

        Returns:
            The newly created Speaker.
        """
        if name == "":
            name = str(seed)
        filename = name + ".pt"
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        # Inside a method this name resolves to the module-level function,
        # not to this method.
        speaker.emb = create_speaker_from_seed(seed)
        torch.save(speaker, self._path_for(filename))
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create a speaker from an existing embedding and save it.

        Args:
            tensor: Embedding as a ``torch.Tensor`` (used as-is) or a list
                (converted with ``torch.tensor``). Any other type leaves the
                speaker's ``emb`` as None.
            filename: File stem; the speaker is written to ``<filename>.pt``.
            name: Speaker name; defaults to *filename* when empty.
            gender: Stored on the speaker unchanged.
            describe: Stored on the speaker unchanged.

        Returns:
            The newly created Speaker (its seed is ``-2``).
        """
        if name == "":
            name = filename
        speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        elif isinstance(tensor, list):
            speaker.emb = torch.tensor(tensor)
        torch.save(speaker, self._path_for(filename + ".pt"))
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union["Speaker", None]:
        """Return the first loaded speaker whose name equals *name*, or None."""
        for speaker in self.speakers.values():
            if speaker.name == name:
                return speaker
        return None

    def get_speaker_by_id(self, id) -> Union["Speaker", None]:
        """Return the loaded speaker whose id matches *id*, or None.

        Ids are compared by their string form, so a UUID object and its
        string representation match.
        """
        wanted = str(id)
        for speaker in self.speakers.values():
            if str(speaker.id) == wanted:
                return speaker
        return None

    def get_speaker_filename(self, id: str):
        """Return the file name the speaker with *id* was loaded from, or None."""
        wanted = str(id)
        for fname, spk in self.speakers.items():
            if str(spk.id) == wanted:
                return fname
        return None

    def update_speaker(self, speaker: "Speaker"):
        """Overwrite the stored file of an already-known speaker.

        Args:
            speaker: Speaker whose id matches one of the loaded speakers.

        Returns:
            The *speaker* argument, after it was saved and the list reloaded.

        Raises:
            ValueError: If no loaded speaker has the same id.
        """
        filename = self.get_speaker_filename(speaker.id)
        if filename is None:
            raise ValueError("Speaker not found for update")
        torch.save(speaker, self._path_for(filename))
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every loaded speaker back to the file it was loaded from."""
        # The dict already maps file name -> speaker, so no per-speaker
        # filename lookup is needed.
        for fname, speaker in self.speakers.items():
            torch.save(speaker, self._path_for(fname))

    def __len__(self):
        """Return the number of loaded speakers."""
        return len(self.speakers)
# Shared module-level manager instance. Constructing it scans the speaker
# directory, so that scan happens as soon as this module is imported.
speaker_mgr = SpeakerManager()