diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 16:28:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 16:31:56 -0700 |
commit | 514814057e03dcc9389f58e29187898ce7f3a44e (patch) | |
tree | b113aee910c613c9e2647bef522ff8838642b450 /tensorflow/contrib/lite/kernels/internal | |
parent | a444f6a29f4340fc673ce0fc70ceac58dbbf43b9 (diff) |
Make 8bit reduce sum op handle rescaling
PiperOrigin-RevId: 214062241
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 41 |
1 files changed, 27 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index bb1d30b216..5bfa3bd084 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4661,12 +4661,15 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis for quantized values. template <typename T, typename U> -inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale, - const int* input_dims, const int input_num_dims, - T* output_data, int32 output_zero_point, float output_scale, - const int* output_dims, const int output_num_dims, - const int* axis, const int num_axis_dimensions, bool keep_dims, - int* temp_index, int* resolved_axis, U* temp_sum) { +inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, + float input_scale, const int* input_dims, + const int input_num_dims, T* output_data, + int32 output_zero_point, float output_scale, + const int* output_dims, + const int output_num_dims, const int* axis, + const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis, U* temp_sum, + bool compute_sum) { // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { @@ -4708,14 +4711,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale, if (num_elements_in_axis > 0) { const float scale = input_scale / output_scale; - const float bias = -input_zero_point * scale; - for (size_t idx = 0; idx < num_outputs; ++idx) { - float float_mean = static_cast<float>(temp_sum[idx]) / - static_cast<float>(num_elements_in_axis); - - // Convert to float value. 
- output_data[idx] = - static_cast<T>(round(float_mean * scale + bias)) + output_zero_point; + if (compute_sum) { + // TODO(b/116341117): Eliminate float and do this completely in 8bit. + const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5; + for (size_t idx = 0; idx < num_outputs; ++idx) { + const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) + + output_zero_point; + output_data[idx] = static_cast<T>(value); + } + } else { + const float bias = -input_zero_point * scale + 0.5; + for (size_t idx = 0; idx < num_outputs; ++idx) { + float float_mean = static_cast<float>(temp_sum[idx]) / + static_cast<float>(num_elements_in_axis); + + // Convert to float value. + output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) + + output_zero_point; + } } } return true; |