"""Minimal VITS-style text-to-spectrogram model and a helper that downloads the
kakao-enterprise/vits-vctk config, weights and vocab files from Hugging Face."""

import json  # noqa: F401  (not used in this module; kept for compatibility)
import os

import torch  # noqa: F401
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401

try:
    # Not used by this module; kept for backward compatibility, but made
    # optional so the model classes can be imported without torchaudio.
    import torchaudio  # noqa: F401
except ImportError:  # pragma: no cover - depends on the environment
    torchaudio = None

TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"

TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"

# (download URL, local filename inside TTS_FOLDER) pairs.
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]


def ensure_tts_files_exist():
    """Download any missing TTS files into ``TTS_FOLDER``.

    Creates ``TTS_FOLDER`` if needed. Then, for each ``(url, filename)`` pair in
    ``TTS_FILES_URLS``, downloads ``url`` to ``TTS_FOLDER/filename`` unless
    that path already exists. Existing files are neither re-downloaded nor
    validated.
    """
    # Imported lazily: only downloading needs the third-party ``wget``
    # package, so the model classes below stay importable without it.
    import wget

    os.makedirs(TTS_FOLDER, exist_ok=True)
    for url, filename in TTS_FILES_URLS:
        filepath = os.path.join(TTS_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)


class VITS(nn.Module):
    """Toy text model: a symbol embedding followed by a convolutional ``Generator``.

    Args:
        spec_channels: intermediate channel count used inside ``Generator``.
        segment_size: stored as an attribute; not used by ``forward``.
        num_speakers: stored as an attribute; not used by ``forward``.
        num_languages: stored as an attribute; not used by ``forward``.
        num_symbols: number of rows in the symbol embedding table.
        hidden_channels: embedding width, which is also the ``Generator``
            input channel count. Defaults to 192, the previously hard-coded value.
    """

    def __init__(self, spec_channels, segment_size, num_speakers, num_languages, num_symbols, hidden_channels=192):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.num_symbols = num_symbols
        self.hidden_channels = hidden_channels
        self.embedding = nn.Embedding(num_symbols, hidden_channels)
        self.decoder = Generator(spec_channels, hidden_channels=hidden_channels)

    def forward(self, text):
        """Map integer symbol ids to the decoder output.

        Args:
            text: integer tensor of symbol ids with shape ``(batch, seq_len)``,
                or ``(seq_len,)`` for a single unbatched sequence.

        Returns:
            Tensor of shape ``(batch, 2, 2 * seq_len)``, or ``(2, 2 * seq_len)``
            for unbatched input.
        """
        unbatched = text.dim() == 1
        if unbatched:
            text = text.unsqueeze(0)
        # nn.Embedding produces (batch, seq_len, hidden), but the decoder's
        # convolutions expect channels first: (batch, hidden, seq_len).
        # Without this transpose the conv layers reject any seq_len != hidden.
        x = self.embedding(text).transpose(1, 2)
        audio = self.decoder(x)
        return audio.squeeze(0) if unbatched else audio


class Generator(nn.Module):
    """Two-layer convolutional decoder.

    The input ``(batch, hidden_channels, seq_len)`` is treated as a height-1
    image. A stride-2 transposed conv upsamples it to ``spec_channels``
    channels, and a 7x7 size-preserving conv then projects it to one channel.

    Args:
        spec_channels: output channels of the transposed conv / input channels
            of the final conv.
        hidden_channels: expected input channel count (default 192).
    """

    def __init__(self, spec_channels, hidden_channels=192):
        super().__init__()
        self.spec_channels = spec_channels
        self.initial_conv = nn.ConvTranspose2d(
            hidden_channels, spec_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)
        )
        self.final_conv = nn.Conv2d(spec_channels, 1, kernel_size=(7, 7), padding=(3, 3))

    def forward(self, encoder_outputs):
        """Decode ``(batch, hidden_channels, seq_len)`` to ``(batch, 2, 2 * seq_len)``."""
        x = encoder_outputs.unsqueeze(2)  # (batch, hidden, 1, seq_len)
        x = self.initial_conv(x)  # (batch, spec_channels, 2, 2 * seq_len)
        x = self.final_conv(x)  # (batch, 1, 2, 2 * seq_len)
        return x.squeeze(1)