author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-21 16:28:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-21 16:31:56 -0700
commit    514814057e03dcc9389f58e29187898ce7f3a44e (patch)
tree      b113aee910c613c9e2647bef522ff8838642b450 /tensorflow/contrib/lite/kernels
parent    a444f6a29f4340fc673ce0fc70ceac58dbbf43b9 (diff)
Make 8bit reduce sum op handle rescaling
PiperOrigin-RevId: 214062241
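In short: before this change the reference SUM kernel reused the generic reduction path, which implicitly assumes the input and output tensors share the same quantization parameters. The change renames the quantized Mean helper to QuantizedMeanOrSum, adds a compute_sum mode that requantizes the accumulated integer sums, and routes the uint8 SUM kernel through it when the parameters differ. With s_in/s_out the input/output scales, zp_in/zp_out the zero points, and N the number of elements reduced over, the added sum branch roughly computes:

    output_q = round(sum_q * (s_in / s_out) - zp_in * (s_in / s_out) * N + 0.5) + zp_out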
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 41
-rw-r--r--  tensorflow/contrib/lite/kernels/reduce.cc                          | 52
-rw-r--r--  tensorflow/contrib/lite/kernels/reduce_test.cc                     | 12
3 files changed, 84 insertions, 21 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index bb1d30b216..5bfa3bd084 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -4661,12 +4661,15 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
// It does so in two stages: first it calculates the sum of elements along the axis,
// then divides it by the number of elements in the axis for quantized values.
template <typename T, typename U>
-inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
- const int* input_dims, const int input_num_dims,
- T* output_data, int32 output_zero_point, float output_scale,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis, U* temp_sum) {
+inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
+ float input_scale, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ int32 output_zero_point, float output_scale,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ bool compute_sum) {
// Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
@@ -4708,14 +4711,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
if (num_elements_in_axis > 0) {
const float scale = input_scale / output_scale;
- const float bias = -input_zero_point * scale;
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- float float_mean = static_cast<float>(temp_sum[idx]) /
- static_cast<float>(num_elements_in_axis);
-
- // Convert to float value.
- output_data[idx] =
- static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ if (compute_sum) {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) +
+ output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ } else {
+ const float bias = -input_zero_point * scale + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Convert to float value.
+ output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) +
+ output_zero_point;
+ }
}
}
return true;
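For readers who want to try the requantization in isolation, below is a minimal standalone sketch of the compute_sum branch added above. The function name, vector-based signature and explicit clamp are illustrative only (they are not part of the TFLite kernel), and the float arithmetic mirrors the TODO(b/116341117) note rather than a final fixed-point design.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative re-implementation of the compute_sum branch of
// QuantizedMeanOrSum: requantize int32 sums of uint8 values into uint8
// outputs whose quantization parameters differ from the input's.
std::vector<uint8_t> RequantizeSums(const std::vector<int32_t>& sums,
                                    int num_elements_in_axis,
                                    int32_t input_zero_point, float input_scale,
                                    int32_t output_zero_point,
                                    float output_scale) {
  const float scale = input_scale / output_scale;
  // Each accumulated sum contains num_elements_in_axis copies of the input
  // zero point, so the bias removes all of them at once; the +0.5 matches
  // the bias used in the kernel above.
  const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5f;
  std::vector<uint8_t> output(sums.size());
  for (size_t idx = 0; idx < sums.size(); ++idx) {
    const int32_t value =
        static_cast<int32_t>(std::round(sums[idx] * scale + bias)) +
        output_zero_point;
    // Clamp to the uint8 range; the reference kernel relies on the cast.
    output[idx] = static_cast<uint8_t>(
        std::min<int32_t>(255, std::max<int32_t>(0, value)));
  }
  return output;
}

Accumulating in the wider int32 temp_sum buffer and requantizing once at the end avoids overflowing uint8 while summing, which is why the kernel takes the accumulator type U (int32 here) as a separate template parameter.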
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index d94d821e87..4732a37a65 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -215,7 +215,7 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
return PrepareSimple(context, node);
}
-TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
// reduce_mean requires a buffer to store intermediate sum result.
@@ -274,7 +274,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
} else {
TF_LITE_ENSURE(
context,
- reference_ops::Mean<>(
+ reference_ops::QuantizedMeanOrSum<>(
GetTensorData<uint8_t>(op_context.input),
op_context.input->params.zero_point,
op_context.input->params.scale, op_context.input->dims->data,
@@ -286,7 +286,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis),
- GetTensorData<int>(temp_sum)));
+ GetTensorData<int>(temp_sum), /*compute_sum=*/false));
}
break;
default:
@@ -416,19 +416,57 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ const auto& input = op_context.input;
+ const auto& output = op_context.output;
+ if (input->type != kTfLiteUInt8 ||
+ (input->params.scale == output->params.scale &&
+ input->params.zero_point == output->params.zero_point)) {
+ return EvalGeneric<kReference, kSum>(context, node);
+ } else {
+ // Rescaling 8bit reduce sum.
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
+ }
+
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::QuantizedMeanOrSum<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point, op_context.input->params.scale,
+ op_context.input->dims->data, op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale, op_context.output->dims->data,
+ op_context.output->dims->size, GetTensorData<int>(op_context.axis),
+ num_axis, op_context.params->keep_dims,
+ GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+ GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
+ }
+
+ return kTfLiteOk;
+}
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareMean,
+ reduce::PrepareMeanOrSum,
reduce::EvalMean<reduce::kReference>};
return &r;
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {
- reduce::Init, reduce::Free, reduce::PrepareSimple,
- reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMeanOrSum, reduce::EvalSum};
return &r;
}
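The new EvalSum keeps the old behaviour as a fast path: inputs that are not uint8, or whose quantization parameters already match the output's, still go through EvalGeneric<kReference, kSum>; only the mismatched uint8 case pays for the temporaries and the requantizing QuantizedMeanOrSum call. A condensed restatement of that dispatch condition, with QuantParams and NeedsRescaledSum as purely illustrative names rather than TFLite types:

#include <cstdint>

struct QuantParams {
  float scale;
  int32_t zero_point;
};

// True when the rescaling path added by this change is needed; otherwise the
// kernel can fall back to the generic sum reduction.
bool NeedsRescaledSum(bool input_is_uint8, const QuantParams& input,
                      const QuantParams& output) {
  return input_is_uint8 && (input.scale != output.scale ||
                            input.zero_point != output.zero_point);
}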
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 6d289b14d8..fb2ec58ab2 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -488,6 +488,18 @@ TEST(ConstUint8SumOpTest, NotKeepDims) {
ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
}
+TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) {
+ float kQuantizedTolerance = GetTolerance(0.0, 2.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {1.2, 1.2}, kQuantizedTolerance)));
+}
+
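// Back-of-the-envelope check of NotKeepDimsRescaling above, assuming the
// [0, 1] input range quantizes to scale 1/255 (zero point 0) and the [0, 2]
// output range to scale 2/255 (zero point 0); the exact parameters are picked
// by the test harness, so the numbers are approximate:
//   float column sums over axis 1: 0.4+0.3+0.5 = 1.2 and 0.2+0.4+0.6 = 1.2
//   quantized inputs ~ {102, 51, 77, 102, 128, 153}; integer sums ~ 307, 306
//   scale = (1/255) / (2/255) = 0.5, bias = 0.5
//   requantized outputs ~ round(307*0.5 + 0.5) = 154, round(306*0.5 + 0.5) = 154
//   dequantized: 154 * (2/255) ~ 1.21, within kQuantizedTolerance of {1.2, 1.2}.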
TEST(ConstUint8SumOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};