/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

namespace csrblocksparse {

constexpr int kAVX2SIMDWidth = 8;
// Loads 8x fixed32 from |ptr0| and adds to |input|.
// If |kTwoInputs|, also loads from |ptr1| and adds that as well.
// Returns the 2- or 3-way sum.
template <bool kTwoInputs>
inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1,
                                 const __m256i& input) {
  __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
  if (kTwoInputs) {
    __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
    data0 = _mm256_add_epi32(data0, data1);
  }
  return _mm256_add_epi32(data0, input);
}
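
// Note: _mm256_load_si256 requires 32-byte-aligned pointers, which the
// callers in this file guarantee. A minimal usage sketch (hypothetical
// buffers |a|, |b|, assumed 32-byte aligned with at least 8 int32s each):
//   __m256i acc = _mm256_setzero_si256();
//   acc = LoadAndAddFixed32</*kTwoInputs=*/true>(a, b, acc);  // acc = a + b.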
// Loads 8x fixed32 from |ptr0|.
// If |kTwoInputs|, also loads from |ptr1| and adds.
// Converts the loaded sum to float, multiplies it by |float_factor| and adds
// |input|.
// Returns the sum.
template <bool kTwoInputs>
inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1,
                                     const __m256& float_factor,
                                     const __m256& input) {
  __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
  if (kTwoInputs) {
    __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
    data0 = _mm256_add_epi32(data0, data1);
  }
  __m256 float_result = _mm256_cvtepi32_ps(data0);
  float_result = _mm256_mul_ps(float_result, float_factor);
  return _mm256_add_ps(float_result, input);
}
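
// Scaling note (inferred from the caller below): |float_factor| is not a
// fixed-point scale constant here; GRUComputeState passes the reset gate
// (per-element values in [0, 1]), so the converted gate data stays on its
// original fixed32 scale, and the |kInputMantissaBits| template argument of
// float_tanh_float suggests the scale conversion happens inside the tanh.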
// Loads 16x float (2 registers of 8) from |ptr0_1| and multiplies by
// |input_pairs|, which is likewise formatted as 8x float, alternating between
// the two AR inputs, then sums each adjacent pair of products, making 8x
// float results.
// If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by
// |third_input|, which must be formatted as 8x float. The second product is
// added to the previous result.
// Returns the sum added to |accumulator|.
template <bool kThreeInputs>
inline __m256 MultiplyAddFloat(const __m256& input_pairs,
                               const __m256& third_input, const float* ptr0_1,
                               const float* ptr2, const __m256& accumulator) {
  __m256 data_pair0 = _mm256_load_ps(ptr0_1);
  __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8);
  data_pair0 = _mm256_mul_ps(data_pair0, input_pairs);
  data_pair1 = _mm256_mul_ps(data_pair1, input_pairs);
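  // _mm256_hadd_ps sums adjacent pairs within each 128-bit lane, interleaving
  // its two arguments: the result is [a01, a23, b01, b23 | a45, a67, b45, b67]
  // (low lane | high lane), so the pair sums from |data_pair1| land in the
  // middle two 64-bit groups and must be swapped back into order.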
  data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1);
  // Swap the middle two 64-bit pairs to correct the hadd result.
  data_pair0 = _mm256_castpd_ps(
      _mm256_permute4x64_pd(_mm256_castps_pd(data_pair0), 0xd8));
  if (kThreeInputs) {
    // Load 256 bits (8 x float) of data, then multiply-accumulate.
    data_pair1 = _mm256_load_ps(ptr2);
    data_pair1 = _mm256_mul_ps(data_pair1, third_input);
    data_pair0 = _mm256_add_ps(data_pair0, data_pair1);
  }
  // Add conditioning.
  return _mm256_add_ps(data_pair0, accumulator);
}
// Processes the tanh and the final combination, returns the new GRU state.
template <int kInputMantissaBits, int kStateMantissaBits, bool kSplitGates>
inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1,
                               const __m256& reset0, const __m256& reset1,
                               const __m256& update0, const __m256& update1,
                               const int32_t* gate_ptr,
                               const int32_t* gate_other_ptr,
                               const void* gru_h_ptr) {
  // Multiply the cell gru output and the reset.
  __m256 float_gru0 = LoadMultiplyAddToFloat<kSplitGates>(
      gate_ptr, gate_other_ptr, reset0, cell0);
  __m256 float_gru1 = LoadMultiplyAddToFloat<kSplitGates>(
      gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1,
      cell1);
  // Compute tanh on the result.
  __m256 hbar0, hbar1;
  float_tanh_float<kInputMantissaBits, TM_ORDER4_FLOAT>(float_gru0, float_gru1,
                                                        hbar0, hbar1);
  // Load the 16-bit previous gru state and update.
  __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr));
  __m256 state_factor =
      _mm256_set1_ps(1.0f / (static_cast<float>(1 << kStateMantissaBits)));
  float_gru0 =
      _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru)));
  float_gru1 = _mm256_cvtepi32_ps(
      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1)));
  float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
  float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
  float_gru0 = _mm256_sub_ps(float_gru0, hbar0);
  float_gru1 = _mm256_sub_ps(float_gru1, hbar1);
  float_gru0 = _mm256_mul_ps(float_gru0, update0);
  float_gru1 = _mm256_mul_ps(float_gru1, update1);
  state_factor = _mm256_set1_ps(static_cast<float>(1 << kStateMantissaBits));
  float_gru0 = _mm256_add_ps(float_gru0, hbar0);
  float_gru1 = _mm256_add_ps(float_gru1, hbar1);
  float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
  float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
  return PackFloatsToFixed16(float_gru0, float_gru1);
}
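
// Algebra note: the sub/mul/add sequence above computes, per element,
//   h_new = (h_prev - hbar) * update + hbar
//         = update * h_prev + (1 - update) * hbar,
// the standard GRU state update, with |state_factor| converting the fixed16
// state to float on entry and back to the fixed16 scale before packing.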
// According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and
// combines them with |input| and |gates*|.
// With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies
// by |paired_ar|, likewise formatted as 8x float, but scaled such that the
// product with |pair_weights| is on the same scale as |*input| and |*gates0|,
// and sums each pair of results, making 8x float results.
// With 3 AR inputs, also loads 8x float from |third_weights| and multiplies by
// |third_ar|, which must be formatted as 8x scaled floats. The second product
// is added to the previous result.
// 8x fixed32 inputs are loaded from |input| and added to the total.
// Finally, 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as
// well.
// Returns the total sum as float, on the scale of |*input|.
template <bool kTwoGates, ARInputsMode kInputsMode>
inline __m256 GruInput32ToFloat(const __m256& paired_ar,
                                const __m256& third_ar,
                                const float* pair_weights,
                                const float* third_weights,
                                const int32_t* gates0, const int32_t* gates1,
                                const int32_t* input) {
  __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input));
  data32 = LoadAndAddFixed32<kTwoGates>(gates0, gates1, data32);
  __m256 float_data = _mm256_cvtepi32_ps(data32);
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    float_data = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
        paired_ar, third_ar, pair_weights, third_weights, float_data);
  }
  return float_data;
}
// Generic GRU gates function controlled by template parameters thus:
// - |kInputBits|: the mantissa bits in |*input_ptr| and |*gru_recurrent_ptr|.
// - |kStateBits|: the mantissa bits in |*gru_state_ptr|.
// - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs, so
// |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights|
// are ignored.
// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied
// by |ar_01_weights| and added to the (conditioning) input.
// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is additionally multiplied by
// |ar_2_weights| and added to the other two AR inputs (and the conditioning
// input).
// - |kReplicas| determines the number of duplicates of the output to be
// written, separated by |replica_stride|. If zero, the number of replicas is
// variable and taken from the |replicas| argument.
// - If |kSplitGates| is true: |*gru_recurrent_other_ptr| is a secondary
// recurrent input that must be added to |*gru_recurrent_ptr|.
// - |start|, |end| are rows in [0, |state_size|] to be processed by this
// thread.
//
// The previous state is read from |*gru_state_ptr| and the new state is
// written to *(|gru_state_ptr| + i * |replica_stride|) for i in
// [0, |kReplicas|).
template <int kInputBits, int kStateBits,
          ARInputsMode kInputsMode = ARInputsMode::k0ARInputs,
          int kReplicas = 1, bool kSplitGates = false>
inline void GruGatesTemplate(
    int start, int end, int state_size, int replicas, int replica_stride,
    const int32_t* gru_recurrent_ptr, const int32_t* input_ptr,
    const std::pair<float, float>* ar_sample01, const float* ar_01_weights,
    const float* ar_sample2, const float* ar_2_weights,
    const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
  constexpr int kQRIncrement = kAVX2SIMDWidth;
  // Increment all the pointers to save on pointer arithmetic in the loop.
  input_ptr += start;
  gru_state_ptr += start;
  gru_recurrent_ptr += start;
  if (kSplitGates) gru_recurrent_other_ptr += start;
  __m256 ar_2_inputs, ar_3rd_input;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    ar_01_weights += 2 * start;
    ar_2_inputs = _mm256_castsi256_ps(
        _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(ar_sample01)));
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      ar_2_weights += start;
      ar_3rd_input = _mm256_set1_ps(*ar_sample2);
    } else {
      ar_3rd_input = {};
    }
  } else {
    ar_2_inputs = {};
    ar_3rd_input = {};
  }
  // The transcendentals handle 2x registers of data at once, so we have to do
  // everything in duplicate.
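  // Memory layout implied by the offsets below: in both the recurrent and
  // input buffers, the reset gate occupies rows [0, state_size), the update
  // gate [state_size, 2 * state_size) and the cell [2 * state_size,
  // 3 * state_size). The QR (AR) weights are interleaved in pairs for the two
  // samples, hence the factor of 2 on all |ar_01_weights| offsets.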
  for (int i = start; i < end; i += kQRIncrement * 2) {
    // Load and accumulate 2x 8x fixed32 for each of reset, update and cell.
    __m256 reset0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights,
        gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr);
    __m256 reset1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement,
        ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth,
        gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth);
    float_sigmoid_float<kInputBits>(reset0, reset1);
    __m256 update0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size,
        ar_2_weights + state_size, gru_recurrent_ptr + state_size,
        gru_recurrent_other_ptr + state_size, input_ptr + state_size);
    __m256 update1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input,
        ar_01_weights + 2 * state_size + 2 * kQRIncrement,
        ar_2_weights + state_size + kQRIncrement,
        gru_recurrent_ptr + state_size + kAVX2SIMDWidth,
        gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth,
        input_ptr + state_size + kAVX2SIMDWidth);
    float_sigmoid_float<kInputBits>(update0, update1);
    __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256(
        reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size)));
    __m256 cell1 =
        _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>(
            input_ptr + 2 * state_size + kAVX2SIMDWidth)));
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      cell0 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
          ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size,
          ar_2_weights + 2 * state_size, cell0);
      cell1 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
          ar_2_inputs, ar_3rd_input,
          ar_01_weights + 4 * state_size + 2 * kQRIncrement,
          ar_2_weights + 2 * state_size + kQRIncrement, cell1);
    }
    __m256i gru_state = GRUComputeState<kInputBits, kStateBits, kSplitGates>(
        cell0, cell1, reset0, reset1, update0, update1,
        gru_recurrent_ptr + 2 * state_size,
        gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr);
    if (kReplicas > 0) {
      // With |kReplicas| a template parameter, the compiler will unroll the
      // loop.
      for (int j = 0; j < kReplicas; ++j) {
        _mm256_store_si256(
            reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
            gru_state);
      }
    } else {
      // This loop will not unroll, as |replicas| is variable.
      for (int j = 0; j < replicas; ++j) {
        _mm256_store_si256(
            reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
            gru_state);
      }
    }
    // Increment all the pointers.
    input_ptr += 2 * kAVX2SIMDWidth;
    gru_state_ptr += 2 * kAVX2SIMDWidth;
    gru_recurrent_ptr += 2 * kAVX2SIMDWidth;
    if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      ar_01_weights += 4 * kQRIncrement;
      if (kInputsMode == ARInputsMode::k3ARInputs)
        ar_2_weights += 2 * kQRIncrement;
    }
  }
}
// Dispatches calls to the GruGatesTemplate function above, converting the
// variable |num_replicas| argument to a template parameter so that the
// compiler can unroll the write loop.
// |ar_sample01| packs samples 0 and 1 into a pair because the QR weights are
// formatted with the weights interleaved for samples 0 and 1. The two samples
// represent coarse and fine for WaveRNN.
template <int kInputBits, int kStateBits,
          ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
          bool kSplitGates = false>
inline void GruGatesAVXFixed(
    int start, int end, int state_size, const int32_t* gru_recurrent_ptr,
    const int32_t* input_ptr, const std::pair<float, float>* ar_sample01,
    const float* ar_01_weights, int num_replicas, int replica_stride,
    const float* ar_sample2, const float* ar_2_weights,
    const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
  // Convert the number of replicas from a variable to a template parameter
  // with a switch. This enables the compiler to unroll the loop for
  // the write, making it faster for common numbers of threads.
  switch (num_replicas) {
    case 1:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/1,
                       kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    case 2:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/2,
                       kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    case 4:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/4,
                       kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    case 6:
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/6,
                       kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
      break;
    default:
      // Zero |kReplicas| tells the function to use the |num_replicas|
      // variable.
      GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/0,
                       kSplitGates>(
          start, end, state_size, num_replicas, replica_stride,
          gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
          ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
  }
}
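
// Minimal usage sketch (hypothetical values; not part of the library API).
// Assumes |state_size| is a multiple of 2 * kAVX2SIMDWidth, all buffers are
// 32-byte aligned, |input| and |recurrent| each hold 3 * state_size fixed32
// gate values (reset, update, cell), and the bit counts match the caller's
// fixed-point types:
//   std::pair<float, float> ar01 = {coarse_sample, fine_sample};
//   GruGatesAVXFixed</*kInputBits=*/23, /*kStateBits=*/14>(
//       /*start=*/0, /*end=*/state_size, state_size, recurrent, input, &ar01,
//       ar01_weights, /*num_replicas=*/1, /*replica_stride=*/0,
//       /*ar_sample2=*/nullptr, /*ar_2_weights=*/nullptr,
//       /*gru_recurrent_other_ptr=*/nullptr, gru_state);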
}  // namespace csrblocksparse