aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/expand_dims.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-01 12:53:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 12:57:00 -0700
commitb812f37e26889bb168fa0279a536b907c3fb5fdd (patch)
tree0b5804e7749d0b83a44748ca848917ad1554ceae /tensorflow/contrib/lite/kernels/expand_dims.cc
parent10b2b3b44a6f93f4fd414e8ac450587ece2207ae (diff)
TFLite: adding tile and expand_dims ops.
PiperOrigin-RevId: 198913026
Diffstat (limited to 'tensorflow/contrib/lite/kernels/expand_dims.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims.cc113
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
new file mode 100644
index 0000000000..ed33012864
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -0,0 +1,113 @@
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace expand_dims {
+constexpr int kInput = 0;
+constexpr int kAxis = 1;
+constexpr int kOutput = 0;
+
+namespace {
+TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input,
+ int axis, TfLiteTensor* output) {
+ const TfLiteIntArray& input_dims = *input.dims;
+ if (axis < 0) {
+ axis = input_dims.size + 1 + axis;
+ }
+ TF_LITE_ENSURE(context, axis <= input_dims.size);
+
+ TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1);
+ for (int i = 0; i < output_dims->size; ++i) {
+ if (i < axis) {
+ output_dims->data[i] = input_dims.data[i];
+ } else if (i == axis) {
+ output_dims->data[i] = 1;
+ } else {
+ output_dims->data[i] = input_dims.data[i - 1];
+ }
+ }
+
+ return context->ResizeTensor(context, output, output_dims);
+}
+
+TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
+ const TfLiteTensor& axis, int* axis_value) {
+ TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1);
+ switch (axis.type) {
+ case kTfLiteInt32:
+ *axis_value = *GetTensorData<int32_t>(&axis);
+ return kTfLiteOk;
+ case kTfLiteInt64:
+ *axis_value = *GetTensorData<int64_t>(&axis);
+ return kTfLiteOk;
+ default:
+ return kTfLiteError;
+ }
+}
+
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ if (IsConstantTensor(axis)) {
+ int axis_value;
+ TF_LITE_ENSURE_OK(context,
+ GetAxisValueFromTensor(context, *axis, &axis_value));
+ return ExpandTensorDim(context, *input, axis_value, output);
+ }
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ // Just copy input to output.
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
+ if (IsDynamicTensor(output)) {
+ int axis_value;
+ TF_LITE_ENSURE_OK(context,
+ GetAxisValueFromTensor(context, *axis, &axis_value));
+ TF_LITE_ENSURE_OK(context,
+ ExpandTensorDim(context, *input, axis_value, output));
+ }
+ memcpy(output->data.raw, input->data.raw, input->bytes);
+ return kTfLiteOk;
+}
+
+} // namespace expand_dims
+TfLiteRegistration* Register_EXPAND_DIMS() {
+ static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare,
+ expand_dims::Eval};
+ return &r;
+}
+} // namespace builtin
+} // namespace ops
+} // namespace tflite