diff options
author | 2018-08-09 17:01:57 -0700 | |
---|---|---|
committer | 2018-08-09 17:06:31 -0700 | |
commit | 03e10b0b485fca75ebd476201fa49c7f3b86bfa3 (patch) | |
tree | 00e9ad88e0ea5c6cdc1c980bd7ef3864f1834db5 /tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | |
parent | 37bfe7a9f290c267ff7a804038fb6c8979975dee (diff) |
Make FullyConnected() op work with real_multiplier > 1.
#20451
#19607
PiperOrigin-RevId: 208135233
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 17 |
1 file changed, 12 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 6adb879c71..b870789772 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -893,6 +893,7 @@ inline void FullyConnectedAsGEMV( const int input_size = FlatSizeSkipDim(input_dims, 3); const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); static constexpr int kPeel = 4; + const bool shift_left = (output_shift <= 0); for (int k = 0; k < input_size; k += 64) { optimized_ops_preload_l1_stream(input_data + k); } @@ -1004,11 +1005,17 @@ inline void FullyConnectedAsGEMV( int32x4_t bias_vec = vld1q_s32(bias_ptr); bias_ptr += 4; reduced = vaddq_s32(reduced, bias_vec); - // Multiply by the fixed-point multiplier. - reduced = vqrdmulhq_n_s32(reduced, output_multiplier); - // Rounding-shift-right. - using gemmlowp::RoundingDivideByPOT; - reduced = RoundingDivideByPOT(reduced, output_shift); + if (shift_left) { + const int32 multiplier_power_of_two = 1 << -output_shift; + reduced = vmulq_n_s32(reduced, multiplier_power_of_two); + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + } else { + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, output_shift); + } // Add the output offset. const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); reduced = vaddq_s32(reduced, output_offset_vec); |