Spaces:
Sleeping
Sleeping
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| namespace csrblocksparse { | |
| static constexpr int kNeonSIMDWidth = 4; | |
| // ------ Scalar calculation -------- | |
| // See "Efficient Neural Audio Synthesis" for a description of the calculation. | |
| // https://arxiv.org/abs/1802.08435 | |
| // | |
| // NOTE: | |
| // |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|, | |
| // |coarse_at_sminus1|, |fine_at_sminus1|) | |
| // |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|) | |
| // | |
| // CHEATSHEET: | |
| // vld1q_f32 = load 4 32-bit floats | |
| // vmulq_f32(a, b) : return a * b; | |
| // vaddq_f32(a, b) : return a + b; | |
| // vmlaq_f32(c, a, b) : return c + a * b; | |
| // vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3) | |
| // vsubq_f32(a, b) : return a - b; | |
| // vst1q_f32 = store 4 32-bit floats | |
// Backport of vpaddq_f32 to ARM32: pairwise-adds adjacent lanes, returning
// (a0 + a1, a2 + a3, b0 + b1, b2 + b3).
// On AArch64 vpaddq_f32 is a native intrinsic declared by <arm_neon.h>;
// redefining it there collides with the compiler's declaration, so only
// provide the backport when not building for AArch64.
#if !defined(__aarch64__)
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
  float32x2_t a10 = vget_low_f32(a);
  float32x2_t a32 = vget_high_f32(a);
  float32x2_t b10 = vget_low_f32(b);
  float32x2_t b32 = vget_high_f32(b);
  // vpadd_f32 pairwise-adds within each 64-bit half; recombine to 128 bits.
  return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
}
#endif  // !defined(__aarch64__)
// Applies the GRU gate nonlinearities to state elements [start, end), float
// version. |gru_gates_ptr| holds the three precomputed gate pre-activations
// (reset, update, cell) stacked |proj_size| apart; when SplitGates is true,
// the partial sums in |gru_gates_other_ptr| are added in first. Unless
// kInputsMode == k0ARInputs, the autoregressive sample contributions are
// added using the QR weights (two interleaved rows per 8 loaded floats, hence
// the vpaddq_f32 reductions) and, for k3ARInputs, the |w_hat| weights scaled
// by |*coarse_at_s|. The updated state is written back to |gru_h_ptr| in
// place. Processes kNeonSIMDWidth (4) elements per iteration; assumes
// end - start is a multiple of kNeonSIMDWidth — TODO confirm with callers.
template <ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
                         const float* gru_gates_ptr,
                         const float* gru_gates_other_ptr,
                         const float* conditioning_ptr, float* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const float* coarse_at_sminus1,
                         const float* fine_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    // QR weights are interleaved for two samples, so advance by 2 per element.
    qr_ptr += 2 * start;
    DCHECK_NE(coarse_at_sminus1, nullptr);
    DCHECK_NE(fine_at_sminus1, nullptr);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    float32x4_t reset = vld1q_f32(gru_gates_ptr);
    float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
    float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
    float32x4_t qr_cell;
    if (SplitGates) {
      // Sum in the other thread's partial gate results.
      reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
      update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
      cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
    }
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Setup the sample vector: (coarse, fine, coarse, fine) per the NOTE
      // at the top of the file.
      float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);
      // All auto types are float32x4_t, auto used to fit statements on one line
      // for readability. Do two rows of QR at once; vpaddq_f32 reduces the
      // interleaved products to one result per output element.
      auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
      auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
      auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
      auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
      auto qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
      auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);
      auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
      auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      // qr_cell is folded in later, after |reset| has gone through the
      // sigmoid (see the vmlaq_f32 on the cell below).
      reset = vaddq_f32(reset, qr_reset);
      update = vaddq_f32(update, qr_update);
    }
    auto reset_conditioning = vld1q_f32(conditioning_ptr);
    auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
    auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);
    reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
    update = fast_sigmoid(vaddq_f32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset, cell);
    } else {
      // cell = qr_cell + reset * cell: the AR cell term bypasses the reset
      // gate.
      cell = vmlaq_f32(qr_cell, reset, cell);
    }
    auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
    auto prev_h = vld1q_f32(gru_h_ptr);
    auto diff = vsubq_f32(prev_h, hbar);
    // new_h = hbar + update * (prev_h - hbar)
    //       = update * prev_h + (1 - update) * hbar: the GRU state blend.
    auto new_h = vmlaq_f32(hbar, diff, update);
    vst1q_f32(gru_h_ptr, new_h);
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
// Applies the GRU gate nonlinearities to state elements [start, end), fixed
// point version: gate pre-activations and conditioning arrive as 32-bit
// fixed point (GRUMatMulOutType) and the state is stored as 16-bit fixed
// point (GRUStateType), while the AR sample terms are computed in float.
// This version should only be used if all of the 32-bit fixed point
// representations have the same number of mantissa bits.
// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are
// formatted with the weights interleaved for sample 0 and 1. The two samples
// represent coarse and fine for WaveRNN.
// Processes kNeonSIMDWidth (4) elements per iteration; assumes end - start
// is a multiple of kNeonSIMDWidth — TODO confirm with callers.
template <typename GRUStateType, typename GRUMatMulOutType,
          ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
                         const int32_t* gru_gates_ptr,
                         const int32_t* gru_gates_other_ptr,
                         const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const std::pair<float, float>* ar_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  // |sample01| and |w_sample| are loop-invariant (neither |ar_at_sminus1| nor
  // |coarse_at_s| is advanced in the loop), so they are broadcast once here.
  float32x4_t sample01;
  float32x4_t w_sample;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    // QR weights are interleaved for two samples, so advance by 2 per element.
    qr_ptr += 2 * start;
    DCHECK_NE(ar_at_sminus1, nullptr);
    // Lane layout (coarse, fine, coarse, fine) per the NOTE at the top of the
    // file.
    sample01 = vdupq_n_f32(ar_at_sminus1->first);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
      w_sample = vdupq_n_f32(*coarse_at_s);
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    auto reset = vld1q_s32(gru_gates_ptr);
    auto update = vld1q_s32(gru_gates_ptr + proj_size);
    // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32
    auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
    if (SplitGates) {
      // Sum in the other thread's partial gate results (still fixed point).
      reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
      update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
      cell_int =
          vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
    }
    float32x4_t cell =
        vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
    float32x4_t qr_cell;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Do two rows of QR at once; vpaddq_f32 reduces the interleaved
      // products to one result per output element.
      float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
      float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
      float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
      float32x4_t qr_update_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
      float32x4_t qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
      float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);
      float32x4_t qr_cell_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
      float32x4_t qr_cell_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        // Uses the |w_sample| broadcast before the loop. (A previous version
        // shadowed it with an identical loop-local vdupq_n_f32(*coarse_at_s),
        // leaving the pre-loop setup dead and redoing invariant work.)
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      // Fold the float AR terms into the fixed-point gate accumulators;
      // qr_cell stays float and is folded in after the sigmoid below.
      reset = vaddq_s32(
          reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
      update = vaddq_s32(
          update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
    }
    auto reset_conditioning = vld1q_s32(conditioning_ptr);
    auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
    float32x4_t cell_conditioning =
        vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
                        GRUMatMulOutType::kMantissaBits);
    // fast_sigmoid on fixed-point input yields float gate activations.
    float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(reset, reset_conditioning));
    float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset_f32, cell);
    } else {
      // cell = qr_cell + reset * cell: the AR cell term bypasses the reset
      // gate.
      cell = vmlaq_f32(qr_cell, reset_f32, cell);
    }
    float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
    float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
                                         GRUStateType::kMantissaBits);
    float32x4_t diff = vsubq_f32(prev_h, hbar);
    // new_h = hbar + update * (prev_h - hbar): the GRU state blend.
    float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);
    // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point
    // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to
    // convert a 32-bit fixed point value to a 16-bit fixed point value
    vst1_s16(gru_h_ptr,
             vqrshrn_n_s32(
                 vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
| } // namespace csrblocksparse | |