aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/reduce.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 16:25:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 16:30:13 -0700
commit82dced12ac82541d36c03dc273794e49956530a6 (patch)
treed67170f5138f7fcaa49aab11d5ff4c8f11c0eaff /tensorflow/contrib/lite/kernels/reduce.cc
parentde05ebe296d83607ca0ab1803bd8eed6afa7f74f (diff)
Refactor reduce_sum, reduce_prod, reduce_max, reduce_min, reduce_any
PiperOrigin-RevId: 211002207
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reduce.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc324
1 files changed, 115 insertions, 209 deletions
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 4001cf357f..ca83797936 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
+#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
@@ -296,221 +297,125 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int num_axis = static_cast<int>(NumElements(op_context.axis));
+// The underlying logic for Reduce Sum/Prod/Max/Min/Any
+template <typename T>
+TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, T init_value,
+ T reducer(const T current, const T in)) {
+ int64_t num_axis = NumElements(op_context->axis);
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
// Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
+ if (IsDynamicTensor(op_context->output)) {
TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ ResizeTempAxis(context, op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
}
-
-#define TF_LITE_SUM(kernel_type, data_type) \
- kernel_type::Sum<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+ if (op_context->input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
+ op_context->output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
+ op_context->output->params.zero_point);
}
-#undef TF_LITE_SUM
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::ReduceGeneric<T>(
+ GetTensorData<T>(op_context->input), op_context->input->dims->data,
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ op_context->output->dims->data, op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis), init_value, reducer));
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_PROD(kernel_type, data_type) \
- kernel_type::ReduceProd<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
+enum ReduceType {
+ kSum,
+ kProd,
+ kMax,
+ kMin,
+ kAny,
+};
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- // TODO(wangtz): uint8 reduce_prod is not yet supported.
- default:
- return kTfLiteError;
- }
+// Eval for determined input type and reduce type.
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kSum:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(0),
+ [](const T current, const T in) -> T { return in + current; });
+ break;
+ case kProd:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(1),
+ [](const T current, const T in) -> T { return in * current; });
+ break;
+ case kMax:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::lowest(),
+ [](const T current, const T in) -> T {
+ return (in > current) ? in : current;
+ });
+ break;
+ case kMin:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::max(),
+ [](const T current, const T in) -> T {
+ return (in < current) ? in : current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_PROD
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_MAX(kernel_type, data_type) \
- kernel_type::ReduceMax<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+// Template specialization for bool type
+template <>
+TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kAny:
+ return EvalLogic<bool>(context, node, op_context, false,
+ [](const bool current, const bool in) -> bool {
+ return in || current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_MAX
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_MIN(kernel_type, data_type) \
- kernel_type::ReduceMin<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+// The entry point that handles input types and then calls template functions to
+// handle ReduceType.
+template <KernelType kernel_type, ReduceType reduce_type>
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
+ if (kernel_type != kReference) {
+ return kTfLiteOk;
}
-#undef TF_LITE_MIN
- return kTfLiteOk;
-}
-
-template <KernelType kernel_type>
-TfLiteStatus EvalAny(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
- if (kernel_type == kReference) {
- reference_ops::ReduceAny(
- GetTensorData<bool>(op_context.input), op_context.input->dims->data,
- op_context.input->dims->size, GetTensorData<bool>(op_context.output),
- op_context.output->dims->data, op_context.output->dims->size,
- GetTensorData<int>(op_context.axis), num_axis,
- op_context.params->keep_dims, GetTensorData<int>(temp_index),
- GetTensorData<int>(resolved_axis));
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ return EvalType<float>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt32:
+ return EvalType<int>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt64:
+ return EvalType<int64_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteUInt8:
+ return EvalType<uint8_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteBool:
+ return EvalType<bool>(context, node, &op_context, reduce_type);
+ break;
+ default:
+ return kTfLiteError;
}
-
- return kTfLiteOk;
}
+
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
@@ -521,36 +426,37 @@ TfLiteRegistration* Register_MEAN_REF() {
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalSum<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
return &r;
}
TfLiteRegistration* Register_REDUCE_PROD_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalProd<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MAX_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMax<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MIN_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMin<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
return &r;
}
TfLiteRegistration* Register_REDUCE_ANY_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free, reduce::PrepareAny,
- reduce::EvalAny<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareAny,
+ reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
return &r;
}