# ddgdgd / tts_vits.py
# Kfjjdjdjdhdhd's picture
# Upload 13 files
# f5790af verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os
# Local directory and file names for the pretrained TTS model.
TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"

# Remote files come from the Hugging Face `kakao-enterprise/vits-vctk` repo.
# Each URL is built from its local file-name constant so the remote name and
# the local name cannot drift apart.
TTS_REPO_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main"
TTS_CONFIG_URL = f"{TTS_REPO_URL}/{TTS_MODEL_CONFIG}"
TTS_MODEL_WEIGHTS_URL = f"{TTS_REPO_URL}/{TTS_MODEL_WEIGHTS}"
TTS_VOCAB_URL = f"{TTS_REPO_URL}/{TTS_VOCAB}"

# (remote URL, local file name) pairs fetched by ensure_tts_files_exist().
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]
def ensure_tts_files_exist(folder=None, files=None):
    """Download any model files that are not already present locally.

    A file is fetched only if nothing exists at ``folder/filename``. Files
    that already exist are kept as-is and are not validated, so a previously
    truncated file will not be re-downloaded.

    Args:
        folder: Destination directory. Defaults to ``TTS_FOLDER``. It is
            created, including parents, if missing.
        files: Iterable of ``(url, filename)`` pairs. Defaults to
            ``TTS_FILES_URLS``.

    Returns:
        List of local file paths (``folder/filename``), in the order given
        by *files*.
    """
    # Resolve defaults at call time so the module-level constants stay the
    # single source of truth and callers can override them (e.g. in tests).
    folder = TTS_FOLDER if folder is None else folder
    files = TTS_FILES_URLS if files is None else files

    os.makedirs(folder, exist_ok=True)
    paths = []
    for url, filename in files:
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
        paths.append(filepath)
    return paths
class VITS(nn.Module):
    """Minimal text-to-audio model: a symbol embedding followed by a conv decoder.

    NOTE(review): despite the name, this is not the full VITS architecture.
    There is no text encoder, flow, duration predictor or speaker
    conditioning. It therefore presumably cannot load the ``pytorch_model.bin``
    checkpoint referenced in this file as-is; confirm before relying on that.

    Args:
        spec_channels: Intermediate channel count passed to the ``Generator``
            decoder.
        segment_size: Stored on the module; not used by ``forward``.
        num_speakers: Stored on the module; not used by ``forward``.
        num_languages: Stored on the module; not used by ``forward``.
        num_symbols: Size of the input symbol vocabulary (embedding rows).
    """

    def __init__(self, spec_channels, segment_size, num_speakers, num_languages, num_symbols):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.num_symbols = num_symbols
        # 192-dim embeddings; this must equal the decoder's input channel count.
        self.embedding = nn.Embedding(num_symbols, 192)
        self.decoder = Generator(spec_channels)

    def forward(self, text):
        """Embed symbol ids and decode them.

        Args:
            text: Integer tensor of symbol ids. Assumed to be
                (batch, seq_len); TODO confirm with callers.

        Returns:
            The decoder output for a (batch, 192, seq_len) input.
        """
        # nn.Embedding produces (batch, seq_len, 192), but the decoder's first
        # layer is a 2-D conv with 192 input channels, and PyTorch expects the
        # channel dim at position 1. Without this transpose the conv raises a
        # channel-mismatch error for any seq_len != 192.
        x = self.embedding(text).transpose(1, 2)
        return self.decoder(x)
class Generator(nn.Module):
    """Two-layer conv decoder mapping channel-first features to a 1-channel map.

    Args:
        spec_channels: Output channels of the upsampling transposed conv and
            input channels of the final conv.
        in_channels: Channel count expected on dim 1 of the ``forward``
            input. Defaults to 192, the previously hard-coded value.
    """

    def __init__(self, spec_channels, in_channels=192):
        super().__init__()
        self.spec_channels = spec_channels
        self.in_channels = in_channels
        # kernel 4 / stride 2 / padding 1: out = (in - 1) * 2 - 2 + 4 = 2 * in,
        # so both spatial dims are exactly doubled.
        self.initial_conv = nn.ConvTranspose2d(
            in_channels, spec_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)
        )
        # kernel 7 with padding 3 preserves the spatial size.
        self.final_conv = nn.Conv2d(spec_channels, 1, kernel_size=(7, 7), padding=(3, 3))

    def forward(self, encoder_outputs):
        """Upsample and project features to a single channel.

        Args:
            encoder_outputs: Tensor of shape (batch, in_channels, length).

        Returns:
            Tensor of shape (batch, 2, 2 * length).
        """
        x = encoder_outputs.unsqueeze(2)  # (B, C, 1, L): height-1 axis for the 2-D convs
        x = self.initial_conv(x)  # (B, spec_channels, 2, 2L)
        x = self.final_conv(x)  # (B, 1, 2, 2L)
        return x.squeeze(1)  # drop the single output channel -> (B, 2, 2L)