aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/reduce.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-16 19:09:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 19:13:36 -0700
commit1a2af489b1087eb22ec76863867e4e397e453e34 (patch)
treed2576dbbfc6f77897e59353e0844aa0c75a23033 /tensorflow/contrib/lite/kernels/reduce.cc
parentd8f3425e5b054dff01b5ece80e8c8a101c4ed816 (diff)
Support reduce_max and reduce_prod
PiperOrigin-RevId: 204846139
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reduce.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc111
1 files changed, 111 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 31c331a8c6..52e4084ff8 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -315,6 +315,99 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+template <KernelType kernel_type>
+TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int64_t num_axis = NumElements(op_context.axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_PROD(kernel_type, data_type) \
+ kernel_type::ReduceProd<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ // TODO(wangtz): uint8 reduce_prod is not yet supported.
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_PROD
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int64_t num_axis = NumElements(op_context.axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_MAX(kernel_type, data_type) \
+ kernel_type::ReduceMax<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
+ op_context.output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
+ op_context.output->params.zero_point);
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
+ break;
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_MAX
+ return kTfLiteOk;
+}
+
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
@@ -331,9 +424,27 @@ TfLiteRegistration* Register_SUM_REF() {
return &r;
}
+TfLiteRegistration* Register_REDUCE_PROD_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalProd<reduce::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_MAX_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalMax<reduce::kReference>};
+ return &r;
+}
+
// TODO(kanlig): add optimized implementation of Mean.
TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
+TfLiteRegistration* Register_REDUCE_PROD() {
+ return Register_REDUCE_PROD_REF();
+}
+TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
} // namespace builtin
} // namespace ops