-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 384
 1 file changed, 192 insertions(+), 192 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index d7a0005f27..f08d9d6d57 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -324,6 +324,198 @@ void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
   }
 }
 
+#ifdef GEMMLOWP_NEON
+// In the common case of batch size 1, a fully-connected node degenerates
+// to a matrix*vector product. LSTM cells contain a fully-connected node;
+// when quantized, this becomes a special type of GEMV operation where
+// the output is 16bit-quantized, thus needs its own special path.
+inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
+                            const uint8* weights_data,
+                            const Dims<4>& weights_dims,
+                            uint8 weights_zero_point, const int32* bias_data,
+                            const Dims<4>& bias_dims, int32 accum_multiplier,
+                            int accum_shift, int16* output_data,
+                            const Dims<4>& output_dims) {
+  gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
+  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+  TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
+  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+  TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+                       ArraySize(output_dims, 3),
+                   1);
+  const int input_size = input_dims.strides[3];
+  const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+  // This special fast path for quantized LSTM cells does not try to support
+  // odd sizes that we haven't encountered in any LSTM cell, that would
+  // require special code (that would go untested until any LSTM cell
+  // exercises it). We just guard our assumptions about size evenness with
+  // the following assertions.
+  TFLITE_DCHECK(!(output_size % 4));
+  TFLITE_DCHECK(!(input_size % 8));
+  const int32* bias_ptr = bias_data;
+  int16* output_ptr = output_data;
+  for (int out = 0; out < output_size; out += 4) {
+    int32x4_t acc_0 = vdupq_n_s32(0);
+    int32x4_t acc_1 = vdupq_n_s32(0);
+    int32x4_t acc_2 = vdupq_n_s32(0);
+    int32x4_t acc_3 = vdupq_n_s32(0);
+    const int16x8_t input_offset_vec = vdupq_n_s16(-128);
+    const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
+    int in = 0;
+    // Handle 16 levels of depth at a time.
+    for (; in <= input_size - 16; in += 16) {
+      const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
+      const uint8* weights_ptr = weights_data + in + out * input_size;
+      uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
+      uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
+      uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
+      uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
+      int16x8_t input_val_0, input_val_1;
+      const uint8x8_t low = vget_low_u8(input_val_u8);
+      const uint8x8_t high = vget_high_u8(input_val_u8);
+      input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
+      input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
+      input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
+      input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
+      int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
+          weights_val_3_0;
+      int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
+          weights_val_3_1;
+      weights_val_0_0 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
+          weights_offset_vec);
+      weights_val_0_1 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
+          weights_offset_vec);
+      weights_val_1_0 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
+          weights_offset_vec);
+      weights_val_1_1 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
+          weights_offset_vec);
+      weights_val_2_0 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
+          weights_offset_vec);
+      weights_val_2_1 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
+          weights_offset_vec);
+      weights_val_3_0 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
+          weights_offset_vec);
+      weights_val_3_1 = vaddq_s16(
+          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
+          weights_offset_vec);
+      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
+                        vget_low_s16(input_val_0));
+      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
+                        vget_low_s16(input_val_0));
+      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
+                        vget_low_s16(input_val_0));
+      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
+                        vget_low_s16(input_val_0));
+      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
+                        vget_high_s16(input_val_0));
+      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
+                        vget_high_s16(input_val_0));
+      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
+                        vget_high_s16(input_val_0));
+      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
+                        vget_high_s16(input_val_0));
+      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
+                        vget_low_s16(input_val_1));
+      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
+                        vget_low_s16(input_val_1));
+      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
+                        vget_low_s16(input_val_1));
+      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
+                        vget_low_s16(input_val_1));
+      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
+                        vget_high_s16(input_val_1));
+      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
+                        vget_high_s16(input_val_1));
+      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
+                        vget_high_s16(input_val_1));
+      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
+                        vget_high_s16(input_val_1));
+    }
+    // Handle 8 levels of depth at a time.
+    for (; in < input_size; in += 8) {
+      const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
+      const uint8* weights_ptr = weights_data + in + out * input_size;
+      uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
+      uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
+      uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
+      uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
+      int16x8_t input_val;
+      input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
+      input_val = vaddq_s16(input_val, input_offset_vec);
+      int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
+      weights_val_0 =
+          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
+                    weights_offset_vec);
+      weights_val_1 =
+          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
+                    weights_offset_vec);
+      weights_val_2 =
+          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
+                    weights_offset_vec);
+      weights_val_3 =
+          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
+                    weights_offset_vec);
+      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
+                        vget_low_s16(input_val));
+      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
+                        vget_low_s16(input_val));
+      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
+                        vget_low_s16(input_val));
+      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
+                        vget_low_s16(input_val));
+      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
+                        vget_high_s16(input_val));
+      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
+                        vget_high_s16(input_val));
+      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
+                        vget_high_s16(input_val));
+      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
+                        vget_high_s16(input_val));
+    }
+    // Horizontally reduce accumulators
+    int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
+        pairwise_reduced_acc_2, pairwise_reduced_acc_3;
+    pairwise_reduced_acc_0 =
+        vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
+    pairwise_reduced_acc_1 =
+        vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
+    pairwise_reduced_acc_2 =
+        vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
+    pairwise_reduced_acc_3 =
+        vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
+    const int32x2_t reduced_lo =
+        vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
+    const int32x2_t reduced_hi =
+        vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
+    int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
+    // Add bias values.
+    int32x4_t bias_vec = vld1q_s32(bias_ptr);
+    bias_ptr += 4;
+    reduced = vaddq_s32(reduced, bias_vec);
+    int left_shift = accum_shift > 0 ? accum_shift : 0;
+    int right_shift = accum_shift > 0 ? 0 : -accum_shift;
+    reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
+    // Multiply by the fixed-point multiplier.
+    reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
+    // Rounding-shift-right.
+    using gemmlowp::RoundingDivideByPOT;
+    reduced = RoundingDivideByPOT(reduced, right_shift);
+    // Narrow values down to 16 bit signed.
+    const int16x4_t res16 = vqmovn_s32(reduced);
+    vst1_s16(output_ptr, res16);
+    output_ptr += 4;
+  }
+}
+#endif
+
 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                            const float* weights_data,
                            const Dims<4>& weights_dims, const float* bias_data,
@@ -2478,198 +2670,6 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
   output_state_map.tanh();
 }
 
-#ifdef GEMMLOWP_NEON
-// In the common case of batch size 1, a fully-connected node degenerates
-// to a matrix*vector product. LSTM cells contain a fully-connected node;
-// when quantized, this becomes a special type of GEMV operation where
-// the output is 16bit-quantized, thus needs its own special path.
-inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
-                            const uint8* weights_data,
-                            const Dims<4>& weights_dims,
-                            uint8 weights_zero_point, const int32* bias_data,
-                            const Dims<4>& bias_dims, int32 accum_multiplier,
-                            int accum_shift, int16* output_data,
-                            const Dims<4>& output_dims) {
-  gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                       ArraySize(output_dims, 3),
-                   1);
-  const int input_size = input_dims.strides[3];
-  const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
-  // This special fast path for quantized LSTM cells does not try to support
-  // odd sizes that we haven't encountered in any LSTM cell, that would
-  // require special code (that would go untested until any LSTM cell
-  // exercises it). We just guard our assumptions about size evenness with
-  // the following assertions.
-  TFLITE_DCHECK(!(output_size % 4));
-  TFLITE_DCHECK(!(input_size % 8));
-  const int32* bias_ptr = bias_data;
-  int16* output_ptr = output_data;
-  for (int out = 0; out < output_size; out += 4) {
-    int32x4_t acc_0 = vdupq_n_s32(0);
-    int32x4_t acc_1 = vdupq_n_s32(0);
-    int32x4_t acc_2 = vdupq_n_s32(0);
-    int32x4_t acc_3 = vdupq_n_s32(0);
-    const int16x8_t input_offset_vec = vdupq_n_s16(-128);
-    const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
-    int in = 0;
-    // Handle 16 levels of depth at a time.
-    for (; in <= input_size - 16; in += 16) {
-      const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
-      const uint8* weights_ptr = weights_data + in + out * input_size;
-      uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
-      uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
-      uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
-      uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
-      int16x8_t input_val_0, input_val_1;
-      const uint8x8_t low = vget_low_u8(input_val_u8);
-      const uint8x8_t high = vget_high_u8(input_val_u8);
-      input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
-      input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
-      input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
-      input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
-      int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
-          weights_val_3_0;
-      int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
-          weights_val_3_1;
-      weights_val_0_0 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
-          weights_offset_vec);
-      weights_val_0_1 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
-          weights_offset_vec);
-      weights_val_1_0 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
-          weights_offset_vec);
-      weights_val_1_1 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
-          weights_offset_vec);
-      weights_val_2_0 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
-          weights_offset_vec);
-      weights_val_2_1 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
-          weights_offset_vec);
-      weights_val_3_0 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
-          weights_offset_vec);
-      weights_val_3_1 = vaddq_s16(
-          vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
-          weights_offset_vec);
-      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
-                        vget_low_s16(input_val_0));
-      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
-                        vget_low_s16(input_val_0));
-      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
-                        vget_low_s16(input_val_0));
-      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
-                        vget_low_s16(input_val_0));
-      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
-                        vget_high_s16(input_val_0));
-      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
-                        vget_high_s16(input_val_0));
-      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
-                        vget_high_s16(input_val_0));
-      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
-                        vget_high_s16(input_val_0));
-      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
-                        vget_low_s16(input_val_1));
-      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
-                        vget_low_s16(input_val_1));
-      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
-                        vget_low_s16(input_val_1));
-      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
-                        vget_low_s16(input_val_1));
-      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
-                        vget_high_s16(input_val_1));
-      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
-                        vget_high_s16(input_val_1));
-      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
-                        vget_high_s16(input_val_1));
-      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
-                        vget_high_s16(input_val_1));
-    }
-    // Handle 8 levels of depth at a time.
-    for (; in < input_size; in += 8) {
-      const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
-      const uint8* weights_ptr = weights_data + in + out * input_size;
-      uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
-      uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
-      uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
-      uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
-      int16x8_t input_val;
-      input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
-      input_val = vaddq_s16(input_val, input_offset_vec);
-      int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
-      weights_val_0 =
-          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
-                    weights_offset_vec);
-      weights_val_1 =
-          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
-                    weights_offset_vec);
-      weights_val_2 =
-          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
-                    weights_offset_vec);
-      weights_val_3 =
-          vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
-                    weights_offset_vec);
-      acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
-                        vget_low_s16(input_val));
-      acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
-                        vget_low_s16(input_val));
-      acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
-                        vget_low_s16(input_val));
-      acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
-                        vget_low_s16(input_val));
-      acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
-                        vget_high_s16(input_val));
-      acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
-                        vget_high_s16(input_val));
-      acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
-                        vget_high_s16(input_val));
-      acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
-                        vget_high_s16(input_val));
-    }
-    // Horizontally reduce accumulators
-    int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
-        pairwise_reduced_acc_2, pairwise_reduced_acc_3;
-    pairwise_reduced_acc_0 =
-        vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
-    pairwise_reduced_acc_1 =
-        vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
-    pairwise_reduced_acc_2 =
-        vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
-    pairwise_reduced_acc_3 =
-        vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
-    const int32x2_t reduced_lo =
-        vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
-    const int32x2_t reduced_hi =
-        vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
-    int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
-    // Add bias values.
-    int32x4_t bias_vec = vld1q_s32(bias_ptr);
-    bias_ptr += 4;
-    reduced = vaddq_s32(reduced, bias_vec);
-    int left_shift = accum_shift > 0 ? accum_shift : 0;
-    int right_shift = accum_shift > 0 ? 0 : -accum_shift;
-    reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
-    // Multiply by the fixed-point multiplier.
-    reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
-    // Rounding-shift-right.
-    using gemmlowp::RoundingDivideByPOT;
-    reduced = RoundingDivideByPOT(reduced, right_shift);
-    // Narrow values down to 16 bit signed.
-    const int16x4_t res16 = vqmovn_s32(reduced);
-    vst1_s16(output_ptr, res16);
-    output_ptr += 4;
-  }
-}
-#endif
-
 // Quantized LSTM cell. Currently just a copy of the reference impl in
 // reference_ops.h. See the big function comment there, not replicating it
 // here.
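For reference, below is a minimal scalar sketch of the arithmetic that the moved GEMVForLstmCell implements with NEON intrinsics: an integer dot product with quantization offsets applied, a bias add, and a fixed-point requantization down to int16. This is not TFLite code. The flat-pointer interface, the function names, and the two scalar helpers are illustrative stand-ins (the helpers are modeled on gemmlowp's SaturatingRoundingDoublingHighMul and RoundingDivideByPOT, which back vqrdmulhq_n_s32 and the rounding shift above), assuming row-major packed weights and the fixed input zero point of 128 used in the diff.

#include <cstdint>
#include <limits>

// Scalar model of vqrdmulhq_n_s32 / gemmlowp::SaturatingRoundingDoublingHighMul:
// high 32 bits of 2*a*b with rounding, saturating on the one overflow case.
inline int32_t SaturatingRoundingDoublingHighMulScalar(int32_t a, int32_t b) {
  if (a == std::numeric_limits<int32_t>::min() &&
      b == std::numeric_limits<int32_t>::min()) {
    return std::numeric_limits<int32_t>::max();
  }
  const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int64_t nudge =
      ab >= 0 ? (int64_t{1} << 30) : (1 - (int64_t{1} << 30));
  return static_cast<int32_t>((ab + nudge) / (int64_t{1} << 31));
}

// Scalar model of gemmlowp::RoundingDivideByPOT: arithmetic shift right
// with round-to-nearest, ties rounded away from zero.
inline int32_t RoundingDivideByPOTScalar(int32_t x, int exponent) {
  const int32_t mask = (int32_t{1} << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// Hypothetical scalar equivalent of GEMVForLstmCell, one output at a time.
void ScalarGemvForLstmCellSketch(const uint8_t* input, int input_size,
                                 const uint8_t* weights,  // row-major
                                 int output_size, uint8_t weights_zero_point,
                                 const int32_t* bias, int32_t accum_multiplier,
                                 int accum_shift, int16_t* output) {
  for (int out = 0; out < output_size; ++out) {
    // Offset-corrected integer dot product, plus bias.
    int32_t acc = bias[out];
    for (int in = 0; in < input_size; ++in) {
      const int32_t input_val = static_cast<int32_t>(input[in]) - 128;
      const int32_t weights_val =
          static_cast<int32_t>(weights[out * input_size + in]) -
          static_cast<int32_t>(weights_zero_point);
      acc += input_val * weights_val;
    }
    // Requantize in the same order as the NEON epilogue: optional left
    // shift, fixed-point multiply, then rounding right shift.
    const int left_shift = accum_shift > 0 ? accum_shift : 0;
    const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
    acc = SaturatingRoundingDoublingHighMulScalar(
        acc * (int32_t{1} << left_shift), accum_multiplier);
    acc = RoundingDivideByPOTScalar(acc, right_shift);
    // Saturating narrow to int16, as vqmovn_s32 does per lane.
    if (acc > 32767) acc = 32767;
    if (acc < -32768) acc = -32768;
    output[out] = static_cast<int16_t>(acc);
  }
}

The NEON path in the diff computes four such outputs per outer iteration, which lets each loaded input vector be reused across four weight rows; that reuse is what makes the 16-wide and 8-wide inner loops worthwhile compared to this one-row-at-a-time sketch.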