Commit: update

Files changed:
- inference/tts/ps_flow.py (+2 -1)
- modules/tts/portaspeech/portaspeech.py (+12 -1)
- tasks/tts/ps.py (+1 -8)
- tasks/tts/ps_flow.py (+1 -9)
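Taken together, the four diffs below appear to move inference-time flow-inverse caching and weight-norm removal out of the individual task classes and into a single PortaSpeech.store_inverse_all() method, so the standalone inference script and the test loop share one code path.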
inference/tts/ps_flow.py

@@ -11,7 +11,8 @@ class PortaSpeechFlowInfer(BaseTTSInfer):
         word_dict_size = len(self.word_encoder)
         model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams)
         load_ckpt(model, hparams['work_dir'], 'model')
-
+        with torch.no_grad():
+            model.store_inverse_all()
         model.eval()
         return model

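What store_inverse_all buys here: Glow-style flows invert their 1x1 convolutions at synthesis time, and caching each inverse weight matrix once (under torch.no_grad(), since this is a pure inference-time optimization) avoids recomputing torch.inverse on every reverse pass. A minimal sketch of the underlying store_inverse pattern; the InvConv1x1 class below is illustrative, not the repository's actual module:

import torch
import torch.nn as nn

class InvConv1x1(nn.Module):
    """Illustrative Glow-style invertible 1x1 convolution (not the repo's class)."""

    def __init__(self, channels):
        super().__init__()
        # Start from a random orthogonal matrix so the layer is invertible.
        w, _ = torch.linalg.qr(torch.randn(channels, channels))
        self.weight = nn.Parameter(w)
        self.weight_inv = None  # filled in by store_inverse()

    def store_inverse(self):
        # Cache the inverse once; reverse passes then skip torch.inverse().
        self.weight_inv = torch.inverse(self.weight.data)

    def forward(self, x, reverse=False):
        # x: [B, C, T]; mix channels with W (forward) or W^-1 (reverse).
        if reverse:
            w = self.weight_inv if self.weight_inv is not None else torch.inverse(self.weight)
        else:
            w = self.weight
        return torch.einsum('oc,bct->bot', w, x)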
modules/tts/portaspeech/portaspeech.py

@@ -212,4 +212,15 @@ class PortaSpeech(FastSpeech):
         x_pos = build_word_mask(word2word, x2word).float()  # [B, T_word, T_ph]
         x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
         x_pos = self.sin_pos(x_pos.float())  # [B, T_ph, H]
-        return x_pos
+        return x_pos
+
+    def store_inverse_all(self):
+        def remove_weight_norm(m):
+            try:
+                if hasattr(m, 'store_inverse'):
+                    m.store_inverse()
+                nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(remove_weight_norm)
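Two PyTorch details make this helper work: nn.Module.apply(fn) visits every submodule recursively, so a single call covers all flow layers, and nn.utils.remove_weight_norm raises ValueError on modules that never had weight norm registered, which the try/except deliberately swallows. A self-contained demonstration of the same pattern:

import torch.nn as nn

net = nn.Sequential(
    nn.utils.weight_norm(nn.Conv1d(8, 8, 3, padding=1)),  # has weight norm
    nn.ReLU(),                                            # has none
)

def strip_weight_norm(m):
    try:
        nn.utils.remove_weight_norm(m)  # ValueError if m has no weight norm
    except ValueError:
        pass

net.apply(strip_weight_norm)  # recursively visits Sequential, Conv1d, ReLU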
tasks/tts/ps.py

@@ -156,14 +156,7 @@ class PortaSpeechTask(FastSpeechTask):
         super().test_start()
         if hparams.get('save_attn', False):
             os.makedirs(f'{self.gen_dir}/attn', exist_ok=True)
-
-        def remove_weight_norm(m):
-            try:
-                nn.utils.remove_weight_norm(m)
-            except ValueError:
-                return
-
-        self.apply(remove_weight_norm)
+        self.model.store_inverse_all()

     def test_step(self, sample, batch_idx):
         assert sample['txt_tokens'].shape[0] == 1, 'only support batch_size=1 in inference'
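Note the trade implied by this hunk: the old inline helper stripped weight norm from the task's submodules but never called store_inverse on the flow layers, while the replacement delegates both jobs to self.model.store_inverse_all() and keeps the logic next to the model definition.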
tasks/tts/ps_flow.py

@@ -131,12 +131,4 @@ class PortaSpeechFlowTask(PortaSpeechTask):
         return [self.optimizer]

     def build_scheduler(self, optimizer):
-        return FastSpeechTask.build_scheduler(self, optimizer[0])
-
-    ############
-    # infer
-    ############
-    def test_start(self):
-        super().test_start()
-        if hparams['use_post_flow']:
-            self.model.post_flow.store_inverse()
+        return FastSpeechTask.build_scheduler(self, optimizer[0])
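With PortaSpeechTask.test_start now calling self.model.store_inverse_all(), this test_start override became redundant: store_inverse_all reaches post_flow through self.apply and the hasattr(m, 'store_inverse') guard, so the explicit use_post_flow branch is no longer needed. (The build_scheduler return line shows as removed and re-added with identical text, which usually indicates a whitespace-only change.)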