diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-07 15:41:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 17:27:20 -0700 |
commit | fc7f0b296dd53d1b72af21d36d36b6bcc5291ea7 (patch) | |
tree | 46e76ead2391a3fb1232459189ad0b8d0d8066ac /tensorflow/contrib/lite/kernels/select.cc | |
parent | 3a2f1cfb73fa6a21eba077485bdc08aa05646ad1 (diff) |
Add support for select (via tf.where) support to tflite.
PiperOrigin-RevId: 195734246
Diffstat (limited to 'tensorflow/contrib/lite/kernels/select.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/select.cc | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc new file mode 100644 index 0000000000..029ad9a709 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace select { + +constexpr int kInputTensorCondition = 0; +constexpr int kInputTensorX = 1; +constexpr int kInputTensorY = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input_condition = + GetInput(context, node, kInputTensorCondition); + TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); + TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Input must be bool. + TF_LITE_ENSURE(context, input_condition->type == kTfLiteBool); + + // Input tensors must have the same type and size + TF_LITE_ENSURE_EQ(context, input_x->type, input_y->type); + TF_LITE_ENSURE(context, HaveSameShapes(input_x, input_y)); + output->type = input_x->type; + + // Either the same shape, or input_condition must be Rank 1 and match over the + // first dimension. + bool same_shape = HaveSameShapes(input_condition, input_x); + if (!same_shape && NumDimensions(input_condition) == 1) { + same_shape = + SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0); + } + + TF_LITE_ENSURE(context, same_shape); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_x->dims); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input_condition = + GetInput(context, node, kInputTensorCondition); + TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); + TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + bool is_rank_one = !HaveSameShapes(input_condition, input_x); + +#define TF_LITE_SELECT(type, op) \ + reference_ops::op(GetTensorData<bool>(input_condition), \ + GetTensorDims(input_condition), \ + GetTensorData<type>(input_x), GetTensorDims(input_x), \ + GetTensorData<type>(input_y), GetTensorDims(input_y), \ + GetTensorData<type>(output), GetTensorDims(output)); + +#define TF_LITE_SWITCH(type, op) \ + switch (type) { \ + break; \ + case kTfLiteBool: \ + TF_LITE_SELECT(bool, op); \ + break; \ + case kTfLiteFloat32: \ + TF_LITE_SELECT(float, op); \ + break; \ + case kTfLiteUInt8: \ + TF_LITE_SELECT(uint8_t, op); \ + break; \ + case kTfLiteInt32: \ + TF_LITE_SELECT(int32_t, op); \ + break; \ + case kTfLiteInt64: \ + TF_LITE_SELECT(int64_t, op); \ + break; \ + default: \ + context->ReportError(context, \ + "Does not support type other than bool|float|int"); \ + return kTfLiteError; \ + } + + if (is_rank_one) { + TF_LITE_SWITCH(input_x->type, RankOneSelect); + } else { + TF_LITE_SWITCH(input_x->type, Select); + } + +#undef TF_LITE_SELECT +#undef TF_LITE_SWITCH + return kTfLiteOk; +} + +} // namespace select + +TfLiteRegistration* Register_SELECT() { + static TfLiteRegistration r = {nullptr, nullptr, select::SelectPrepare, + select::SelectEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite |