import numpy as np
import torch
from torch.nn import functional as F

from TTS.tts.utils.generic_utils import sequence_mask

try:
    from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c

    CYTHON = True
except ModuleNotFoundError:
    CYTHON = False


def convert_pad_shape(pad_shape):
    """Reverse a per-dimension pad spec and flatten it into the flat list expected by F.pad."""
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape
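

# A small illustrative check of convert_pad_shape (example values assumed, not
# taken from this module): F.pad expects pad widths ordered from the last
# tensor dimension backwards, so the per-dimension spec is reversed and then
# flattened.
# >>> convert_pad_shape([[1, 2], [3, 4]])
# [3, 4, 1, 2]
# >>> convert_pad_shape([[0, 0], [1, 0], [0, 0]])   # as used in generate_path
# [0, 0, 1, 0, 0, 0]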


def generate_path(duration, mask):
    """Expand token durations into a monotonic alignment path.
    duration: [b, t_x]
    mask: [b, t_x, t_y]
    """
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path
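

# An illustrative example of generate_path (input values assumed for
# demonstration): with durations [2, 1, 3] and an all-ones mask, each text
# token is expanded over as many output frames as its duration.
# >>> dur = torch.tensor([[2, 1, 3]])          # [b=1, t_x=3]
# >>> msk = torch.ones(1, 3, 6)                # [b=1, t_x=3, t_y=6]
# >>> generate_path(dur, msk)[0]
# tensor([[1., 1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 1., 1.]])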


def maximum_path(value, mask):
    """Find the monotonic alignment path that maximises the total value.

    Uses the Cython kernel when it was imported successfully, otherwise the
    numpy fallback.
    """
    if CYTHON:
        return maximum_path_cython(value, mask)
    return maximum_path_numpy(value, mask)
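

# A usage sketch (the variable names below are assumptions, not part of this
# module): in Glow-TTS style training, `value` holds pairwise log-likelihoods
# between text tokens and decoder frames, and the mask is the outer product of
# the text mask and the spectrogram mask.
#   x_mask = sequence_mask(x_lengths, t_x).float()           # [b, t_x]
#   y_mask = sequence_mask(y_lengths, t_y).float()           # [b, t_y]
#   attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(1)   # [b, t_x, t_y]
#   attn = maximum_path(log_likelihood, attn_mask)           # hard alignment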


def maximum_path_cython(value, mask):
    """Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    # valid lengths along t_x and t_y for each batch item
    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)


def maximum_path_numpy(value, mask, max_neg_val=None):
    """
    Monotonic alignment search algorithm.
    Numpy-friendly version. It's about 4 times faster than the torch version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    if max_neg_val is None:
        max_neg_val = -np.inf
    value = value * mask

    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy().astype(bool)

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
    # Forward pass: accumulate the best score for every (x, y) cell, recording
    # whether it came from (x, y-1) or (x-1, y-1).
    for j in range(t_y):
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v
        max_mask = v1 >= v0
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        index_mask = x_range <= j
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    direction = np.where(mask, direction, 1)

    # Backward pass: trace the stored directions from the last valid cell to
    # recover the alignment path.
    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1
    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path
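

# A small hand-checkable example of maximum_path_numpy (input values assumed):
# with an all-ones mask, the highest-scoring monotonic path through this value
# matrix is the diagonal.
# >>> value = torch.tensor([[[1.0, 0.0],
# ...                        [0.0, 1.0]]])     # [b=1, t_x=2, t_y=2]
# >>> maximum_path_numpy(value, torch.ones_like(value))
# tensor([[[1., 0.],
#          [0., 1.]]])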