antoniomae1234's picture
changes in flenema
2493d72 verified
import numpy as np
import torch
from torch.nn import functional as F
from TTS.tts.utils.generic_utils import sequence_mask
try:
    # TODO: fix pypi cython installation problem.
    # Use the compiled Cython kernel when its extension module can be imported.
    from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c

    # Flag read by maximum_path() to select the Cython implementation.
    CYTHON = True
except ModuleNotFoundError:
    # Extension module not found: maximum_path() uses the numpy implementation.
    CYTHON = False
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension list of pad pairs, last dimension first.

    Args:
        pad_shape: sequence of per-dimension sequences (e.g. ``[before, after]``
            pairs), ordered from the first dimension to the last.

    Returns:
        A flat list with the entries of the last dimension first, which is the
        argument order this module passes to ``F.pad``.
    """
    flattened = []
    # Walk the dimensions back to front and append each pair's entries in order.
    for dim_pads in pad_shape[::-1]:
        flattened.extend(dim_pads)
    return flattened
def generate_path(duration, mask):
    """Build an alignment path from per-token durations.

    Args:
        duration: [b, t_x]
        mask: [b, t_x, t_y]

    Returns:
        path: [b, t_x, t_y], cast to the dtype of ``mask`` and multiplied by it.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    # One mask row of length t_y per (batch, token) pair, driven by the
    # cumulative duration. NOTE(review): this presumes ``sequence_mask`` marks
    # the positions below each given length -- confirm against its definition.
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)

    # Shift the rows down by one along t_x (zero row inserted at the front,
    # last row dropped) and subtract, so each row keeps only what it adds on
    # top of the previous token's cumulative row.
    shifted = F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path - shifted
    path = path * mask
    return path
def maximum_path(value, mask):
    """Run monotonic alignment search with the fastest available backend.

    Uses the Cython implementation when the module-level ``CYTHON`` flag is
    set and the numpy implementation otherwise.

    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    search = maximum_path_cython if CYTHON else maximum_path_numpy
    return search(value, mask)
def maximum_path_cython(value, mask):
    """Monotonic alignment search backed by the compiled Cython kernel.

    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]

    Returns a tensor holding the path written by ``maximum_path_c``, placed on
    the device and cast to the dtype of the masked ``value``.
    """
    masked_value = value * mask
    out_device, out_dtype = masked_value.device, masked_value.dtype

    # Move both inputs to host numpy arrays for the kernel.
    scores = masked_value.detach().cpu().numpy().astype(np.float32)
    mask_np = mask.detach().cpu().numpy()

    # Per-sample counts of valid entries, read from the first column / first
    # row of the mask.
    t_x_max = mask_np.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask_np.sum(2)[:, 0].astype(np.int32)

    # Output buffer that the kernel fills in place.
    path = np.zeros_like(scores, dtype=np.int32)
    maximum_path_c(path, scores, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=out_device, dtype=out_dtype)
def maximum_path_numpy(value, mask, max_neg_val=None):
    """Monotonic alignment search algorithm, numpy implementation.

    Per the original author's note, it is about 4 times faster than the torch
    version.

    Args:
        value: [b, t_x, t_y] scores.
        mask: [b, t_x, t_y]; nonzero entries mark valid cells.
        max_neg_val: score given to unreachable cells. ``None`` (the default)
            means ``-np.inf``; it is resolved in the body rather than in the
            signature as a patch for a Sphinx complaint.

    Returns:
        Tensor [b, t_x, t_y] on the device and with the dtype of the masked
        ``value``, holding 1 on the selected path and 0 elsewhere.
    """
    if max_neg_val is None:
        max_neg_val = -np.inf  # Patch for Sphinx complaint
    value = value * mask

    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the same dtype.
    mask = mask.cpu().detach().numpy().astype(bool)

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)

    # Forward pass: for every column j keep the best cumulative score per row
    # and record whether it came from the same row (1) or the row above (0).
    for j in range(t_y):
        # Scores of the previous column shifted down by one row.
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v
        max_mask = v1 >= v0
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        # Row x cannot be reached before column x, so rows beyond j stay at
        # max_neg_val.
        index_mask = x_range <= j
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    # In masked-out cells always "stay on the same row" while backtracking.
    direction = np.where(mask, direction, 1)

    # Backward pass: start at the last valid row of each sample and follow the
    # recorded directions from the last column to the first.
    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1
    # Clear anything written into masked-out cells.
    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path