aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 16:28:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 16:31:56 -0700
commit514814057e03dcc9389f58e29187898ce7f3a44e (patch)
treeb113aee910c613c9e2647bef522ff8838642b450 /tensorflow/contrib/lite/kernels/internal
parenta444f6a29f4340fc673ce0fc70ceac58dbbf43b9 (diff)
Make 8bit reduce sum op handler rescaling
PiperOrigin-RevId: 214062241
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h41
1 files changed, 27 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index bb1d30b216..5bfa3bd084 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -4661,12 +4661,15 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis for quantized values.
template <typename T, typename U>
-inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
- const int* input_dims, const int input_num_dims,
- T* output_data, int32 output_zero_point, float output_scale,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis, U* temp_sum) {
+inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
+ float input_scale, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ int32 output_zero_point, float output_scale,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ bool compute_sum) {
// Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
@@ -4708,14 +4711,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
if (num_elements_in_axis > 0) {
const float scale = input_scale / output_scale;
- const float bias = -input_zero_point * scale;
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- float float_mean = static_cast<float>(temp_sum[idx]) /
- static_cast<float>(num_elements_in_axis);
-
- // Convert to float value.
- output_data[idx] =
- static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ if (compute_sum) {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) +
+ output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ } else {
+ const float bias = -input_zero_point * scale + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Convert to float value.
+ output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) +
+ output_zero_point;
+ }
}
}
return true;