Spaces:
Sleeping
Sleeping
| """ | |
| WaveGRU model: melspectrogram => mu-law encoded waveform | |
| """ | |
| from typing import Tuple | |
| import jax | |
| import jax.numpy as jnp | |
| import pax | |
| from pax import GRUState | |
| from tqdm.cli import tqdm | |
class ReLU(pax.Module):
    """
    ReLU activation packaged as a module.

    Being a ``pax.Module`` lets the activation sit alongside convolution and
    normalization layers inside a ``pax.Sequential`` pipeline.
    """

    def __call__(self, x):
        """Return ``jax.nn.relu`` applied to ``x``."""
        activated = jax.nn.relu(x)
        return activated
def dilated_residual_conv_block(dim, kernel, stride, dilation):
    """
    Build the convolutional branch of a dilated residual block.

    The branch is two Conv1D >> LayerNorm >> ReLU stages. The first
    convolution is given ``kernel``, ``stride`` and ``dilation`` (the dilation
    is what enlarges the receptive field); the second is a size-1 convolution.
    Both convolutions use "VALID" padding and no bias, and keep ``dim``
    channels.

    Only the branch is built here: adding the skip connection is left to the
    caller.
    """
    # Modules are created in the same order as they are applied.
    dilated_stage = (
        pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
    )
    pointwise_stage = (
        pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
    )
    return pax.Sequential(*dilated_stage, *pointwise_stage)
def tile_1d(x, factor):
    """
    Repeat every time step of ``x`` ``factor`` times in a row.

    Input of shape (N, L, D) becomes (N, L * factor, D); the copies of one
    time step are adjacent along the second axis.
    """
    batch, length, channels = x.shape
    # Add a repeat axis right after the time axis, expand it to `factor`
    # copies, then fold it back into the time axis.
    expanded = jnp.broadcast_to(
        x[:, :, None, :], (batch, length, factor, channels)
    )
    return expanded.reshape((batch, length * factor, channels))
def up_block(in_dim, out_dim, factor, relu=True):
    """
    Tile >> Conv >> LayerNorm (>> ReLU)

    Each input frame is first repeated ``factor`` times along the time axis
    with ``tile_1d``. A bias-free Conv1D with kernel size ``2 * factor``,
    stride 1 and "VALID" padding then maps ``in_dim`` channels to ``out_dim``,
    followed by a LayerNorm. The trailing ReLU is appended only when ``relu``
    is true.
    """
    block = pax.Sequential(
        lambda frames: tile_1d(frames, factor),
        pax.Conv1D(
            in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
        ),
        pax.LayerNorm(out_dim, -1, True, True),
    )
    if not relu:
        return block
    block >>= ReLU()
    return block
class Upsample(pax.Module):
    """
    Upsample melspectrogram to match raw audio sample rate.

    Pipeline applied by ``__call__``:

    1. a size-1 convolution + LayerNorm projecting ``input_dim`` channels to
       ``hidden_dim``;
    2. five dilated conv branches (dilations 1, 2, 4, 8, 16), each added to a
       centre-cropped copy of its own input;
    3. one ``up_block`` per entry of ``upsample_factors[:-1]``; the last of
       them has no ReLU and outputs ``3 * rnn_dim`` channels (or
       ``hidden_dim`` channels followed by a linear layer to ``3 * rnn_dim``
       when ``has_linear_output`` is true);
    4. plain repetition of every frame ``upsample_factors[-1]`` times, unless
       the caller asks to skip it.

    ``upsample_factors`` needs at least two entries: the last one is the
    repetition factor and the ones before it drive the conv up-blocks.
    """

    def __init__(
        self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
    ):
        super().__init__()
        self.input_conv = pax.Sequential(
            pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
            pax.LayerNorm(hidden_dim, -1, True, True),
        )
        self.upsample_factors = upsample_factors
        self.dilated_convs = [
            dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5)
        ]
        # The last factor is applied by plain tiling in __call__, not by a conv.
        self.up_factors = upsample_factors[:-1]
        self.up_blocks = [
            up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1]
        ]
        # Final conv up-block: no ReLU, and it emits the GRU input width
        # directly unless a separate linear output layer is requested.
        self.up_blocks.append(
            up_block(
                hidden_dim,
                hidden_dim if has_linear_output else 3 * rnn_dim,
                self.up_factors[-1],
                relu=False,
            )
        )
        if has_linear_output:
            self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
        self.has_linear_output = has_linear_output
        self.final_tile = upsample_factors[-1]

    def __call__(self, x, no_repeat=False):
        """
        Upsample ``x`` of shape (N, L, input_dim).

        When ``no_repeat`` is true the final frame repetition (step 4 of the
        class description) is skipped and the output of the last up-block
        (and linear layer, if any) is returned as is.
        """
        x = self.input_conv(x)
        for residual in self.dilated_convs:
            y = residual(x)
            # The "VALID" convs make y shorter than x, so centre-crop the skip
            # path to y's length. The stop index is explicit because
            # `x[:, pad:-pad]` would be empty if pad were ever 0.
            pad = (x.shape[1] - y.shape[1]) // 2
            x = x[:, pad : pad + y.shape[1], :] + y
        for f in self.up_blocks:
            x = f(x)
        if self.has_linear_output:
            x = self.x2zrh_fc(x)
        if no_repeat:
            return x
        x = tile_1d(x, self.final_tile)
        return x
class GRU(pax.Module):
    """
    A customized GRU module.

    Only the hidden-to-gates projection lives in this module. The input fed
    to ``__call__`` is used as the input-side gate pre-activations directly
    (no input projection is applied here), laid out along the last axis as
    [update | reset | candidate].
    """

    input_dim: int
    hidden_dim: int

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        recurrent_init = jax.nn.initializers.variance_scaling(
            1, "fan_out", "truncated_normal"
        )
        self.h_zrh_fc = pax.Linear(hidden_dim, hidden_dim * 3, w_init=recurrent_init)

    def initial_state(self, batch_size: int) -> GRUState:
        """Create an all zeros initial state."""
        zeros = jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32)
        return GRUState(zeros)

    def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
        """
        Run one GRU step.

        Returns the new state together with the new hidden vector, which is
        also the step output.
        """
        prev_hidden = state.hidden
        gate_width = 2 * self.hidden_dim

        recurrent = self.h_zrh_fc(prev_hidden)
        # Separate the (update, reset) part from the candidate part on both
        # the input side and the recurrent side.
        input_gates, input_cand = jnp.split(x, [gate_width], axis=-1)
        rec_gates, rec_cand = jnp.split(recurrent, [gate_width], axis=-1)

        gates = jax.nn.sigmoid(input_gates + rec_gates)
        update_gate, reset_gate = jnp.split(gates, 2, axis=-1)

        # The reset gate scales only the recurrent part of the candidate.
        candidate = jnp.tanh(input_cand + reset_gate * rec_cand)
        # The update gate is the weight given to the candidate.
        new_hidden = (1 - update_gate) * prev_hidden + update_gate * candidate
        return GRUState(new_hidden), new_hidden
class Pruner(pax.Module):
    """
    Base class for pruners

    Provides the sparsity schedule and the block-wise magnitude pruning rule
    shared by the concrete pruners.
    """

    def compute_sparsity(self, step):
        """
        Return the target sparsity for training step ``step``.

        The value is 0 up to step 1_000, rises along a cubic curve over the
        following 200_000 steps, and stays at 0.95 afterwards.
        """
        t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
        # Bounds are passed positionally: the `a_min`/`a_max` keyword names
        # are deprecated in newer JAX releases.
        z = 0.95 * jnp.clip(1.0 - t, 0, 1)
        return z

    def prune(self, step, weights):
        """
        Return a mask

        ``weights`` is a 2-D array whose two dimensions are multiples of 4.
        It is scored in 4x4 blocks by the sum of absolute values; blocks whose
        score is below the quantile given by the current target sparsity are
        masked out. The result is a boolean array with the shape of
        ``weights`` (True = keep).
        """
        z = self.compute_sparsity(step)
        x = weights
        H, W = x.shape
        # View the matrix as a grid of 4x4 blocks and score each block.
        x = x.reshape(H // 4, 4, W // 4, 4)
        x = jnp.abs(x)
        x = jnp.sum(x, axis=(1, 3), keepdims=True)
        # Score threshold below which a block is dropped.
        q = jnp.quantile(jnp.reshape(x, (-1,)), z)
        x = x >= q
        # Expand the per-block decision back to one entry per weight.
        x = jnp.tile(x, (1, 4, 1, 4))
        x = jnp.reshape(x, (H, W))
        return x
class GRUPruner(Pruner):
    """
    Pruner for the hidden-to-gates weight matrix (``h_zrh_fc``) of the GRU
    module defined in this file.

    The mask starts as all True and can only lose entries: every mask update
    is combined with the previous mask.
    """

    def __init__(self, gru):
        super().__init__()
        self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1

    def __call__(self, gru: "GRU"):
        """
        Apply mask after an optimization step

        Returns a GRU whose ``h_zrh_fc`` weight has its masked-out entries
        set to 0.
        """
        zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
        gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights)
        return gru

    def update_mask(self, step, gru: "GRU"):
        """
        Update internal masks

        The weight matrix is split into three equal parts along axis 1 and
        each part is pruned on its own, so the target sparsity applies to
        each part separately.
        """
        z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
        z_mask = self.prune(step, z_weight)
        r_mask = self.prune(step, r_weight)
        h_mask = self.prune(step, h_weight)
        # Product of boolean masks: an entry stays True only if it was kept
        # before and is kept now.
        self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1)
class LinearPruner(Pruner):
    """
    Pruner for the weight matrix of a single ``pax.Linear`` layer.

    The mask starts as all True; each update is multiplied into the previous
    mask, so entries that were dropped stay dropped.
    """

    def __init__(self, linear):
        super().__init__()
        all_ones = jnp.ones_like(linear.weight)
        self.mask = all_ones == 1

    def __call__(self, linear: pax.Linear):
        """
        Apply mask after an optimization step

        Returns a copy of ``linear`` whose masked-out weights are set to 0.
        """
        pruned_weight = jnp.where(self.mask, linear.weight, 0)
        return linear.replace(weight=pruned_weight)

    def update_mask(self, step, linear: pax.Linear):
        """
        Update internal masks
        """
        fresh_mask = self.prune(step, linear.weight)
        self.mask = self.mask * fresh_mask
class WaveGRU(pax.Module):
    """
    WaveGRU vocoder model.

    A melspectrogram is upsampled into per-step conditioning features. At each
    step the embedding of the previous sample (one of 256 classes) is added to
    the conditioning features and fed to the GRU, and a two-layer head turns
    the GRU output into logits over the 256 classes.
    """

    def __init__(
        self,
        mel_dim=80,
        rnn_dim=1024,
        upsample_factors=(5, 3, 20),
        has_linear_output=False,
    ):
        super().__init__()
        # Embedding and upsampler both produce 3 * rnn_dim features: the
        # input-side pre-activations of the three GRU gates.
        self.embed = pax.Embed(256, 3 * rnn_dim)
        self.upsample = Upsample(
            input_dim=mel_dim,
            hidden_dim=512,
            rnn_dim=rnn_dim,
            upsample_factors=upsample_factors,
            has_linear_output=has_linear_output,
        )
        self.rnn = GRU(rnn_dim)
        self.o1 = pax.Linear(rnn_dim, rnn_dim)
        self.o2 = pax.Linear(rnn_dim, 256)
        self.gru_pruner = GRUPruner(self.rnn)
        self.o1_pruner = LinearPruner(self.o1)
        self.o2_pruner = LinearPruner(self.o2)

    def output(self, x):
        """Map GRU outputs to logits over the 256 classes (Linear >> ReLU >> Linear)."""
        x = self.o1(x)
        x = jax.nn.relu(x)
        x = self.o2(x)
        return x

    def inference(self, mel, no_gru=False, seed=42):
        """
        generate waveform from melspectrogram

        Samples are drawn one step at a time from the categorical
        distribution given by the model's logits, starting from class 127 and
        a zero GRU state created for batch size 1. ``seed`` initializes the
        random key.

        When ``no_gru`` is true, sampling is skipped and the output of the
        upsampling network (without the final frame repetition) is returned
        instead.

        Otherwise returns the sampled class indices concatenated over all
        steps.
        """

        def step(rnn_state, mel, rng_key, x):
            x = self.embed(x)
            x = x + mel
            rnn_state, x = self.rnn(rnn_state, x)
            x = self.output(x)
            # One half of the split key samples this step; the other half is
            # carried over to the next step.
            rng_key, next_rng_key = jax.random.split(rng_key, 2)
            x = jax.random.categorical(rng_key, x, axis=-1)
            return rnn_state, next_rng_key, x

        y = self.upsample(mel, no_repeat=no_gru)
        if no_gru:
            return y
        x = jnp.array([127], dtype=jnp.int32)
        rnn_state = self.rnn.initial_state(1)
        output = []
        rng_key = jax.random.PRNGKey(seed)
        for i in tqdm(range(y.shape[1])):
            rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
            output.append(x)
        x = jnp.concatenate(output, axis=0)
        return x

    def __call__(self, mel, x):
        """
        Compute logits for every step given the melspectrogram ``mel`` and
        the input sample classes ``x`` (integer indices for the embedding).

        The embedded samples are centre-cropped along the time axis to the
        length of the upsampled conditioning features before the two are
        added.
        """
        x = self.embed(x)
        y = self.upsample(mel)
        # Any odd leftover is cropped from the right. The stop index is
        # explicit because `x[:, pad_left:-pad_right]` is empty when
        # pad_right is 0, i.e. when x and y already have the same length.
        pad_left = (x.shape[1] - y.shape[1]) // 2
        x = x[:, pad_left : pad_left + y.shape[1]]
        x = x + y
        _, x = pax.scan(
            self.rnn,
            self.rnn.initial_state(x.shape[0]),
            x,
            time_major=False,
        )
        x = self.output(x)
        return x