aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/reduce.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reduce.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc115
1 files changed, 115 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 31c331a8c6..e99f67c725 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -78,6 +78,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
size_t num_axis = NumElements(op_context->axis);
const TfLiteIntArray* input_dims = op_context->input->dims;
int input_num_dims = NumDimensions(op_context->input);
+ if (input_num_dims == 0) {
+ return context->ResizeTensor(context, op_context->output,
+ TfLiteIntArrayCreate(0));
+ }
const int* axis = GetTensorData<int>(op_context->axis);
if (op_context->params->keep_dims) {
TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
@@ -315,6 +319,99 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+template <KernelType kernel_type>
+TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int64_t num_axis = NumElements(op_context.axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_PROD(kernel_type, data_type) \
+ kernel_type::ReduceProd<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ // TODO(wangtz): uint8 reduce_prod is not yet supported.
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_PROD
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
+TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int64_t num_axis = NumElements(op_context.axis);
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_MAX(kernel_type, data_type) \
+ kernel_type::ReduceMax<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
+ op_context.output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
+ op_context.output->params.zero_point);
+ TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
+ break;
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_MAX
+ return kTfLiteOk;
+}
+
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
@@ -331,9 +428,27 @@ TfLiteRegistration* Register_SUM_REF() {
return &r;
}
+TfLiteRegistration* Register_REDUCE_PROD_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalProd<reduce::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_MAX_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalMax<reduce::kReference>};
+ return &r;
+}
+
// TODO(kanlig): add optimized implementation of Mean.
TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
+TfLiteRegistration* Register_REDUCE_PROD() {
+ return Register_REDUCE_PROD_REF();
+}
+TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
} // namespace builtin
} // namespace ops