diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 16:28:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 16:31:56 -0700 |
commit | 514814057e03dcc9389f58e29187898ce7f3a44e (patch) | |
tree | b113aee910c613c9e2647bef522ff8838642b450 /tensorflow/contrib/lite/kernels | |
parent | a444f6a29f4340fc673ce0fc70ceac58dbbf43b9 (diff) |
Make the 8-bit reduce sum op handle rescaling
PiperOrigin-RevId: 214062241
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 41 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/reduce.cc | 52 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/reduce_test.cc | 12 |
3 files changed, 84 insertions, 21 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index bb1d30b216..5bfa3bd084 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4661,12 +4661,15 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis for quantized values. template <typename T, typename U> -inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale, - const int* input_dims, const int input_num_dims, - T* output_data, int32 output_zero_point, float output_scale, - const int* output_dims, const int output_num_dims, - const int* axis, const int num_axis_dimensions, bool keep_dims, - int* temp_index, int* resolved_axis, U* temp_sum) { +inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, + float input_scale, const int* input_dims, + const int input_num_dims, T* output_data, + int32 output_zero_point, float output_scale, + const int* output_dims, + const int output_num_dims, const int* axis, + const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis, U* temp_sum, + bool compute_sum) { // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { @@ -4708,14 +4711,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale, if (num_elements_in_axis > 0) { const float scale = input_scale / output_scale; - const float bias = -input_zero_point * scale; - for (size_t idx = 0; idx < num_outputs; ++idx) { - float float_mean = static_cast<float>(temp_sum[idx]) / - static_cast<float>(num_elements_in_axis); - - // Convert to float value. 
- output_data[idx] = - static_cast<T>(round(float_mean * scale + bias)) + output_zero_point; + if (compute_sum) { + // TODO(b/116341117): Eliminate float and do this completely in 8bit. + const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5; + for (size_t idx = 0; idx < num_outputs; ++idx) { + const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) + + output_zero_point; + output_data[idx] = static_cast<T>(value); + } + } else { + const float bias = -input_zero_point * scale + 0.5; + for (size_t idx = 0; idx < num_outputs; ++idx) { + float float_mean = static_cast<float>(temp_sum[idx]) / + static_cast<float>(num_elements_in_axis); + + // Convert to float value. + output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) + + output_zero_point; + } } } return true; diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc index d94d821e87..4732a37a65 100644 --- a/tensorflow/contrib/lite/kernels/reduce.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -215,7 +215,7 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) { return PrepareSimple(context, node); } -TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); // reduce_mean requires a buffer to store intermediate sum result. 
@@ -274,7 +274,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { } else { TF_LITE_ENSURE( context, - reference_ops::Mean<>( + reference_ops::QuantizedMeanOrSum<>( GetTensorData<uint8_t>(op_context.input), op_context.input->params.zero_point, op_context.input->params.scale, op_context.input->dims->data, @@ -286,7 +286,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { GetTensorData<int>(op_context.axis), num_axis, op_context.params->keep_dims, GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis), - GetTensorData<int>(temp_sum))); + GetTensorData<int>(temp_sum), /*compute_sum=*/false)); } break; default: @@ -416,19 +416,57 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + const auto& input = op_context.input; + const auto& output = op_context.output; + if (input->type != kTfLiteUInt8 || + (input->params.scale == output->params.scale && + input->params.zero_point == output->params.zero_point)) { + return EvalGeneric<kReference, kSum>(context, node); + } else { + // Rescaling 8bit reduce sum. + int num_axis = static_cast<int>(NumElements(op_context.axis)); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); + // Resize the output tensor if the output tensor is dynamic. 
+ if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum)); + } + + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum<>( + GetTensorData<uint8_t>(op_context.input), + op_context.input->params.zero_point, op_context.input->params.scale, + op_context.input->dims->data, op_context.input->dims->size, + GetTensorData<uint8_t>(op_context.output), + op_context.output->params.zero_point, + op_context.output->params.scale, op_context.output->dims->data, + op_context.output->dims->size, GetTensorData<int>(op_context.axis), + num_axis, op_context.params->keep_dims, + GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis), + GetTensorData<int32>(temp_sum), /*compute_sum=*/true)); + } + + return kTfLiteOk; +} } // namespace reduce TfLiteRegistration* Register_MEAN_REF() { static TfLiteRegistration r = {reduce::Init, reduce::Free, - reduce::PrepareMean, + reduce::PrepareMeanOrSum, reduce::EvalMean<reduce::kReference>}; return &r; } TfLiteRegistration* Register_SUM_REF() { - static TfLiteRegistration r = { - reduce::Init, reduce::Free, reduce::PrepareSimple, - reduce::EvalGeneric<reduce::kReference, reduce::kSum>}; + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareMeanOrSum, reduce::EvalSum}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc index 6d289b14d8..fb2ec58ab2 100644 --- a/tensorflow/contrib/lite/kernels/reduce_test.cc +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -488,6 +488,18 @@ TEST(ConstUint8SumOpTest, NotKeepDims) { ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance))); } +TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) { + float kQuantizedTolerance = GetTolerance(0.0, 2.0); + std::vector<float> data = {0.4, 
0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, 0.0, 1.0}, + {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {1}, false); + m.QuantizeAndPopulate<uint8_t>(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + {1.2, 1.2}, kQuantizedTolerance))); +} + TEST(ConstUint8SumOpTest, KeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; |