From 82dced12ac82541d36c03dc273794e49956530a6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 30 Aug 2018 16:25:50 -0700
Subject: Refactor reduce_sum, reduce_prod, reduce_max, reduce_min, reduce_any

PiperOrigin-RevId: 211002207
---
 tensorflow/contrib/lite/kernels/reduce.cc | 324 +++++++++++-------------------
 1 file changed, 115 insertions(+), 209 deletions(-)

diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 4001cf357f..ca83797936 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
+#include <limits>
 #include <vector>
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/context.h"
@@ -296,221 +297,125 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-template <KernelType kernel_type>
-TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
-  OpContext op_context(context, node);
-  int num_axis = static_cast<int>(NumElements(op_context.axis));
+// The underlying logic for Reduce Sum/Prod/Max/Min/Any
+template <typename T>
+TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
+                       OpContext* op_context, T init_value,
+                       T reducer(const T current, const T in)) {
+  int64_t num_axis = NumElements(op_context->axis);
   TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
   TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
   // Resize the output tensor if the output tensor is dynamic.
-  if (IsDynamicTensor(op_context.output)) {
+  if (IsDynamicTensor(op_context->output)) {
     TF_LITE_ENSURE_OK(context,
-                      ResizeTempAxis(context, &op_context, resolved_axis));
-    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+                      ResizeTempAxis(context, op_context, resolved_axis));
+    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
   }
-
-#define TF_LITE_SUM(kernel_type, data_type)                          \
-  kernel_type::Sum<>(                                                \
-      GetTensorData<data_type>(op_context.input),                    \
-      op_context.input->dims->data, op_context.input->dims->size,    \
-      GetTensorData<data_type>(op_context.output),                   \
-      op_context.output->dims->data, op_context.output->dims->size,  \
-      GetTensorData<int>(op_context.axis), num_axis,                 \
-      op_context.params->keep_dims, GetTensorData<int>(temp_index),  \
-      GetTensorData<int>(resolved_axis))
-
-  if (kernel_type == kReference) {
-    switch (op_context.input->type) {
-      case kTfLiteFloat32:
-        TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float));
-        break;
-      case kTfLiteInt32:
-        TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int));
-        break;
-      case kTfLiteInt64:
-        TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t));
-        break;
-      case kTfLiteUInt8:
-        TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
-                          op_context.output->params.scale);
-        TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
-                          op_context.output->params.zero_point);
-        TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t));
-        break;
-      default:
-        return kTfLiteError;
-    }
+  if (op_context->input->type == kTfLiteUInt8) {
+    TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
+                      op_context->output->params.scale);
+    TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
+                      op_context->output->params.zero_point);
   }
-#undef TF_LITE_SUM
+  TF_LITE_ENSURE(
+      context,
+      reference_ops::ReduceGeneric<T>(
+          GetTensorData<T>(op_context->input), op_context->input->dims->data,
+          op_context->input->dims->size, GetTensorData<T>(op_context->output),
+          op_context->output->dims->data, op_context->output->dims->size,
+          GetTensorData<int>(op_context->axis), num_axis,
+          op_context->params->keep_dims, GetTensorData<int>(temp_index),
+          GetTensorData<int>(resolved_axis), init_value, reducer));
   return kTfLiteOk;
 }
 
-template <KernelType kernel_type>
-TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
-  OpContext op_context(context, node);
-  int64_t num_axis = NumElements(op_context.axis);
-  TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
-  // Resize the output tensor if the output tensor is dynamic.
-  if (IsDynamicTensor(op_context.output)) {
-    TF_LITE_ENSURE_OK(context,
-                      ResizeTempAxis(context, &op_context, resolved_axis));
-    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
-  }
-
-#define TF_LITE_PROD(kernel_type, data_type)                         \
-  kernel_type::ReduceProd<>(                                         \
-      GetTensorData<data_type>(op_context.input),                    \
-      op_context.input->dims->data, op_context.input->dims->size,    \
-      GetTensorData<data_type>(op_context.output),                   \
-      op_context.output->dims->data, op_context.output->dims->size,  \
-      GetTensorData<int>(op_context.axis), num_axis,                 \
-      op_context.params->keep_dims, GetTensorData<int>(temp_index),  \
-      GetTensorData<int>(resolved_axis))
+enum ReduceType {
+  kSum,
+  kProd,
+  kMax,
+  kMin,
+  kAny,
+};
 
-  if (kernel_type == kReference) {
-    switch (op_context.input->type) {
-      case kTfLiteFloat32:
-        TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
-        break;
-      case kTfLiteInt32:
-        TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
-        break;
-      case kTfLiteInt64:
-        TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
-        break;
-      case kTfLiteUInt8:
-        // TODO(wangtz): uint8 reduce_prod is not yet supported.
-      default:
-        return kTfLiteError;
-    }
+// Eval for determined input type and reduce type.
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
+                      OpContext* op_context, ReduceType reduce_type) {
+  switch (reduce_type) {
+    case kSum:
+      return EvalLogic<T>(
+          context, node, op_context, static_cast<T>(0),
+          [](const T current, const T in) -> T { return in + current; });
+      break;
+    case kProd:
+      return EvalLogic<T>(
+          context, node, op_context, static_cast<T>(1),
+          [](const T current, const T in) -> T { return in * current; });
+      break;
+    case kMax:
+      return EvalLogic<T>(context, node, op_context,
+                          std::numeric_limits<T>::lowest(),
+                          [](const T current, const T in) -> T {
+                            return (in > current) ? in : current;
+                          });
+      break;
+    case kMin:
+      return EvalLogic<T>(context, node, op_context,
+                          std::numeric_limits<T>::max(),
+                          [](const T current, const T in) -> T {
+                            return (in < current) ? in : current;
+                          });
+      break;
+    default:
+      return kTfLiteError;
   }
-#undef TF_LITE_PROD
-  return kTfLiteOk;
 }
 
-template <KernelType kernel_type>
-TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
-  OpContext op_context(context, node);
-  int64_t num_axis = NumElements(op_context.axis);
-  TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
-  // Resize the output tensor if the output tensor is dynamic.
-  if (IsDynamicTensor(op_context.output)) {
-    TF_LITE_ENSURE_OK(context,
-                      ResizeTempAxis(context, &op_context, resolved_axis));
-    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
-  }
-
-#define TF_LITE_MAX(kernel_type, data_type)                          \
-  kernel_type::ReduceMax<>(                                          \
-      GetTensorData<data_type>(op_context.input),                    \
-      op_context.input->dims->data, op_context.input->dims->size,    \
-      GetTensorData<data_type>(op_context.output),                   \
-      op_context.output->dims->data, op_context.output->dims->size,  \
-      GetTensorData<int>(op_context.axis), num_axis,                 \
-      op_context.params->keep_dims, GetTensorData<int>(temp_index),  \
-      GetTensorData<int>(resolved_axis))
-
-  if (kernel_type == kReference) {
-    switch (op_context.input->type) {
-      case kTfLiteFloat32:
-        TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
-        break;
-      case kTfLiteInt32:
-        TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
-        break;
-      case kTfLiteInt64:
-        TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
-        break;
-      case kTfLiteUInt8:
-        TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
-                          op_context.output->params.scale);
-        TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
-                          op_context.output->params.zero_point);
-        TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
-        break;
-      default:
-        return kTfLiteError;
-    }
+// Template specialization for bool type
+template <>
+TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
+                            OpContext* op_context, ReduceType reduce_type) {
+  switch (reduce_type) {
+    case kAny:
+      return EvalLogic<bool>(context, node, op_context, false,
+                             [](const bool current, const bool in) -> bool {
+                               return in || current;
+                             });
+      break;
+    default:
+      return kTfLiteError;
   }
-#undef TF_LITE_MAX
-  return kTfLiteOk;
 }
 
-template <KernelType kernel_type>
-TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
-  OpContext op_context(context, node);
-  int64_t num_axis = NumElements(op_context.axis);
-  TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
-  // Resize the output tensor if the output tensor is dynamic.
-  if (IsDynamicTensor(op_context.output)) {
-    TF_LITE_ENSURE_OK(context,
-                      ResizeTempAxis(context, &op_context, resolved_axis));
-    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
-  }
-
-#define TF_LITE_MIN(kernel_type, data_type)                          \
-  kernel_type::ReduceMin<>(                                          \
-      GetTensorData<data_type>(op_context.input),                    \
-      op_context.input->dims->data, op_context.input->dims->size,    \
-      GetTensorData<data_type>(op_context.output),                   \
-      op_context.output->dims->data, op_context.output->dims->size,  \
-      GetTensorData<int>(op_context.axis), num_axis,                 \
-      op_context.params->keep_dims, GetTensorData<int>(temp_index),  \
-      GetTensorData<int>(resolved_axis))
-
-  if (kernel_type == kReference) {
-    switch (op_context.input->type) {
-      case kTfLiteFloat32:
-        TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, float));
-        break;
-      case kTfLiteInt32:
-        TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int));
-        break;
-      case kTfLiteInt64:
-        TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int64_t));
-        break;
-      case kTfLiteUInt8:
-        TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
-                          op_context.output->params.scale);
-        TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
-                          op_context.output->params.zero_point);
-        TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, uint8_t));
-        break;
-      default:
-        return kTfLiteError;
-    }
+// The entry point that handles input types and then calls template functions to
+// handle ReduceType.
+template <KernelType kernel_type, ReduceType reduce_type>
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type != kReference) {
+    return kTfLiteOk;
   }
-#undef TF_LITE_MIN
-  return kTfLiteOk;
-}
-
-template <KernelType kernel_type>
-TfLiteStatus EvalAny(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
-  int64_t num_axis = NumElements(op_context.axis);
-  TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
-  // Resize the output tensor if the output tensor is dynamic.
-  if (IsDynamicTensor(op_context.output)) {
-    TF_LITE_ENSURE_OK(context,
-                      ResizeTempAxis(context, &op_context, resolved_axis));
-    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
-  }
-  if (kernel_type == kReference) {
-    reference_ops::ReduceAny(
-        GetTensorData<bool>(op_context.input), op_context.input->dims->data,
-        op_context.input->dims->size, GetTensorData<bool>(op_context.output),
-        op_context.output->dims->data, op_context.output->dims->size,
-        GetTensorData<int>(op_context.axis), num_axis,
-        op_context.params->keep_dims, GetTensorData<int>(temp_index),
-        GetTensorData<int>(resolved_axis));
+  switch (op_context.input->type) {
+    case kTfLiteFloat32:
+      return EvalType<float>(context, node, &op_context, reduce_type);
+      break;
+    case kTfLiteInt32:
+      return EvalType<int>(context, node, &op_context, reduce_type);
+      break;
+    case kTfLiteInt64:
+      return EvalType<int64_t>(context, node, &op_context, reduce_type);
+      break;
+    case kTfLiteUInt8:
+      return EvalType<uint8_t>(context, node, &op_context, reduce_type);
+      break;
+    case kTfLiteBool:
+      return EvalType<bool>(context, node, &op_context, reduce_type);
+      break;
+    default:
+      return kTfLiteError;
   }
-
-  return kTfLiteOk;
 }
+
 }  // namespace reduce
 
 TfLiteRegistration* Register_MEAN_REF() {
@@ -521,36 +426,37 @@ TfLiteRegistration* Register_MEAN_REF() {
 }
 
 TfLiteRegistration* Register_SUM_REF() {
-  static TfLiteRegistration r = {reduce::Init, reduce::Free,
-                                 reduce::PrepareSimple,
-                                 reduce::EvalSum<reduce::kReference>};
+  static TfLiteRegistration r = {
+      reduce::Init, reduce::Free, reduce::PrepareSimple,
+      reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
   return &r;
 }
 
 TfLiteRegistration* Register_REDUCE_PROD_REF() {
-  static TfLiteRegistration r = {reduce::Init, reduce::Free,
-                                 reduce::PrepareSimple,
-                                 reduce::EvalProd<reduce::kReference>};
+  static TfLiteRegistration r = {
+      reduce::Init, reduce::Free, reduce::PrepareSimple,
+      reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
   return &r;
 }
 
 TfLiteRegistration* Register_REDUCE_MAX_REF() {
-  static TfLiteRegistration r = {reduce::Init, reduce::Free,
-                                 reduce::PrepareSimple,
-                                 reduce::EvalMax<reduce::kReference>};
+  static TfLiteRegistration r = {
+      reduce::Init, reduce::Free, reduce::PrepareSimple,
+      reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
  return &r;
 }
 
 TfLiteRegistration* Register_REDUCE_MIN_REF() {
-  static TfLiteRegistration r = {reduce::Init, reduce::Free,
-                                 reduce::PrepareSimple,
-                                 reduce::EvalMin<reduce::kReference>};
+  static TfLiteRegistration r = {
+      reduce::Init, reduce::Free, reduce::PrepareSimple,
+      reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
   return &r;
 }
 
 TfLiteRegistration* Register_REDUCE_ANY_REF() {
-  static TfLiteRegistration r = {reduce::Init, reduce::Free, reduce::PrepareAny,
-                                 reduce::EvalAny<reduce::kReference>};
+  static TfLiteRegistration r = {
+      reduce::Init, reduce::Free, reduce::PrepareAny,
+      reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
   return &r;
 }
 
-- 
cgit v1.2.3
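The whole refactor hinges on one idea: every reduction performs the same traversal and differs only in its identity value and its binary combine step, which EvalLogic forwards to reference_ops::ReduceGeneric as init_value and reducer. The standalone sketch below is not TFLite code; Reduce2D, its parameters, and the sample data are invented for illustration. It shows the same init-value-plus-reducer pattern on a flat 2-D buffer, so the lambdas in EvalType have a concrete analogue.

// Illustrative only: reduce a row-major rows x cols buffer along one axis
// using an init value and a reducer callback, mirroring the pattern that
// EvalLogic passes to ReduceGeneric in the patch above.
#include <cstdio>
#include <limits>
#include <vector>

template <typename T>
std::vector<T> Reduce2D(const std::vector<T>& input, int rows, int cols,
                        int axis, T init_value,
                        T reducer(const T current, const T in)) {
  // axis == 0 collapses rows (output has `cols` elements);
  // axis == 1 collapses columns (output has `rows` elements).
  const int out_size = (axis == 0) ? cols : rows;
  std::vector<T> output(out_size, init_value);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int out_index = (axis == 0) ? c : r;
      output[out_index] = reducer(output[out_index], input[r * cols + c]);
    }
  }
  return output;
}

int main() {
  const std::vector<float> data = {1.f, 5.f, 2.f,   // row 0
                                   4.f, 3.f, 6.f};  // row 1
  // Sum over rows: identity 0 and an addition reducer.
  auto sums = Reduce2D<float>(data, 2, 3, /*axis=*/0, 0.f,
                              [](const float cur, const float in) -> float {
                                return cur + in;
                              });
  // Max over columns: identity lowest() and a max reducer.
  auto maxes = Reduce2D<float>(data, 2, 3, /*axis=*/1,
                               std::numeric_limits<float>::lowest(),
                               [](const float cur, const float in) -> float {
                                 return in > cur ? in : cur;
                               });
  for (float v : sums) std::printf("%g ", v);   // prints: 5 8 8
  std::printf("\n");
  for (float v : maxes) std::printf("%g ", v);  // prints: 5 6
  std::printf("\n");
  return 0;
}

Passing the combine step as a plain function pointer, which a capture-free lambda converts to, is what lets a single traversal serve sum, prod, max, min, and any, and it is why the patch can delete the per-op TF_LITE_SUM/PROD/MAX/MIN macros.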