sanchit-gandhi committed on
Commit 72d1bae · 1 Parent(s): 1e5a7f2

Create app.py

Files changed (1)
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
import gradio as gr
import torch

from transformers import VitsModel, VitsTokenizer, set_seed

title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
  <div
    style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
  >
    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
      VITS TTS Demo
    </h1>
  </div>
</div>
"""

description = """
VITS is an end-to-end speech synthesis model that predicts a speech waveform conditioned on an input text sequence. It
is a conditional variational autoencoder (VAE) composed of a posterior encoder, a decoder, and a conditional prior.

This demo showcases the official VITS checkpoints, trained on the [LJ
Speech](https://huggingface.co/kakao-enterprise/vits-ljs) and [VCTK](https://huggingface.co/kakao-enterprise/vits-vctk)
datasets.
"""

article = "Model by Jaehyeon Kim et al. from Kakao Enterprise. Code and demo by 🤗 Hugging Face."
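
# Load the single-speaker (LJ Speech) and multi-speaker (VCTK) VITS checkpoints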
ljs_model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
ljs_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-ljs")

vctk_model = VitsModel.from_pretrained("kakao-enterprise/vits-vctk")
vctk_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-vctk")
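
# Run on GPU when available; the forward functions below move the tokenized inputs
# to the same device before calling the model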
device = "cuda" if torch.cuda.is_available() else "cpu"
ljs_model.to(device)
vctk_model.to(device)
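

# Synthesize speech from text with the LJ Speech checkpoint. Fixing the seed keeps
# the stochastic duration predictor and VAE sampling reproducible across calls.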
def ljs_forward(text, speaking_rate=1.0):
    # tokenize the input text and move it to the same device as the model
    inputs = ljs_tokenizer(text, return_tensors="pt").to(device)

    ljs_model.speaking_rate = speaking_rate
    set_seed(555)
    with torch.no_grad():
        outputs = ljs_model(**inputs)[0]

    # both checkpoints generate audio at a 22050 Hz sampling rate
    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))
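

# Synthesize speech with the multi-speaker VCTK checkpoint: `speaker_id` selects one
# of the learned speaker embeddings (the UI slider is 1-indexed, the model 0-indexed)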
def vctk_forward(text, speaking_rate=1.0, speaker_id=1):
    inputs = vctk_tokenizer(text, return_tensors="pt").to(device)

    vctk_model.speaking_rate = speaking_rate
    set_seed(555)
    with torch.no_grad():
        outputs = vctk_model(**inputs, speaker_id=speaker_id - 1)[0]

    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))
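

# One Gradio interface per checkpoint; both are combined into a tabbed layout below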
ljs_inference = gr.Interface(
    fn=ljs_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
    ],
    outputs=gr.Audio(),
)

vctk_inference = gr.Interface(
    fn=vctk_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
        gr.Slider(
            1,
            vctk_model.config.num_speakers,
            value=1,
            step=1,
            label="Speaker id",
            info=f"The VCTK model is trained on {vctk_model.config.num_speakers} speakers. You can prompt the model using one of these speaker ids.",
        ),
    ],
    outputs=gr.Audio(),
)
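
# Assemble the demo: title, description, tabbed interfaces, and attribution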
demo = gr.Blocks()

with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface([ljs_inference, vctk_inference], ["LJ Speech", "VCTK"])
    gr.Markdown(article)

demo.queue(max_size=10)
demo.launch()