aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-09 17:01:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 17:06:31 -0700
commit 03e10b0b485fca75ebd476201fa49c7f3b86bfa3 (patch)
tree 00e9ad88e0ea5c6cdc1c980bd7ef3864f1834db5 /tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
parent 37bfe7a9f290c267ff7a804038fb6c8979975dee (diff)
Make FullyConnected() op work with real_multiplier > 1.
#20451 #19607 PiperOrigin-RevId: 208135233
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h')
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h 17
1 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 6adb879c71..b870789772 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -893,6 +893,7 @@ inline void FullyConnectedAsGEMV(
const int input_size = FlatSizeSkipDim(input_dims, 3);
const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
static constexpr int kPeel = 4;
+ const bool shift_left = (output_shift <= 0);
for (int k = 0; k < input_size; k += 64) {
optimized_ops_preload_l1_stream(input_data + k);
}
@@ -1004,11 +1005,17 @@ inline void FullyConnectedAsGEMV(
int32x4_t bias_vec = vld1q_s32(bias_ptr);
bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
- // Multiply by the fixed-point multiplier.
- reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
- // Rounding-shift-right.
- using gemmlowp::RoundingDivideByPOT;
- reduced = RoundingDivideByPOT(reduced, output_shift);
+ if (shift_left) {
+ const int32 multiplier_power_of_two = 1 << -output_shift;
+ reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ } else {
+ // Multiply by the fixed-point multiplier.
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ // Rounding-shift-right.
+ using gemmlowp::RoundingDivideByPOT;
+ reduced = RoundingDivideByPOT(reduced, output_shift);
+ }
// Add the output offset.
const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
reduced = vaddq_s32(reduced, output_offset_vec);