path: root/tensorflow/contrib/lite/kernels/slice.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-08 22:49:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-09 10:41:25 -0700
commit    eac758802e66934a6fde4e23fd92023780a5c075 (patch)
tree      bdef3d3c378bafae1f1908d8e3b89523046cd87e /tensorflow/contrib/lite/kernels/slice.cc
parent    7bd992b02c0a19ce7aa9c085ab5caa0e00fe2516 (diff)
Implementation of Slice.
PiperOrigin-RevId: 195926057
Diffstat (limited to 'tensorflow/contrib/lite/kernels/slice.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/slice.cc  197
1 file changed, 197 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
new file mode 100644
index 0000000000..82baf53e1d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string.h>
+#include <cmath>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace slice {
+
+constexpr int kInputTensor = 0;
+constexpr int kBeginTensor = 1;
+constexpr int kSizeTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// This op only supports 1-4D cases. Since the optimized ops implementation
+// works on 4D tensors, 1-3D inputs are mapped to 4D.
+const int kMaxDim = 4;
+
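+// Computes the output shape from the begin and size tensors. A size value of
+// -1 means "all remaining elements in this dimension, starting at begin",
+// mirroring tf.slice; any other negative size, or a begin + size that runs
+// past the end of the dimension, is reported as an error.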
+template <typename T>
+TfLiteStatus CalculateOutputShapeVector(
+ TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* begin,
+ TfLiteTensor* size, std::vector<int64_t>* output_shape_vector) {
+ for (int idx = 0; idx < NumDimensions(input); ++idx) {
+ T size_value = GetTensorData<T>(size)[idx];
+ if (size_value < 0) {
+ if (size_value != -1) {
+ context->ReportError(context, "Invalid size.");
+ return kTfLiteError;
+ }
+ size_value = SizeOfDimension(input, idx) - GetTensorData<T>(begin)[idx];
+ } else {
+ if (SizeOfDimension(input, idx) <
+ GetTensorData<T>(begin)[idx] + size_value) {
+ context->ReportError(context, "Invalid begin and size.");
+ return kTfLiteError;
+ }
+ }
+ output_shape_vector->push_back(size_value);
+ }
+ return kTfLiteOk;
+}
+
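+// Copies begin and size into int vectors in reverse dimension order, matching
+// the reversed (Dims<4>-style) dimension ordering expected by
+// optimized_ops::Slice.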
+template <typename T>
+void GetBeginAndSizeVectors(int dimensions, TfLiteTensor* begin,
+ TfLiteTensor* size, std::vector<int>* begins,
+ std::vector<int>* sizes) {
+ for (int idx = dimensions - 1; idx >= 0; --idx) {
+ begins->push_back(GetTensorData<T>(begin)[idx]);
+ sizes->push_back(GetTensorData<T>(size)[idx]);
+ }
+}
+
+TfLiteStatus ResizeOutputShape(TfLiteContext* context, TfLiteTensor* input,
+ TfLiteTensor* begin, TfLiteTensor* size,
+ TfLiteTensor* output) {
+ std::vector<int64_t> output_shape_vector;
+
+ if (begin->type == kTfLiteInt32) {
+ TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>(
+ context, input, begin, size, &output_shape_vector));
+ } else if (begin->type == kTfLiteInt64) {
+ TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int64_t>(
+ context, input, begin, size, &output_shape_vector));
+ } else {
+ context->ReportError(context, "Type is currently not supported by Slice.");
+ return kTfLiteError;
+ }
+
+ TfLiteIntArray* output_shape =
+ TfLiteIntArrayCreate(output_shape_vector.size());
+ std::copy(output_shape_vector.begin(), output_shape_vector.end(),
+ output_shape->data);
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
+ TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // Ensure the input, begin, and size tensors have valid types and dimensions.
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE(context,
+ begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context,
+ size->type == kTfLiteInt32 || size->type == kTfLiteInt64);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(begin), 1);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
+ TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim,
+ "Slice op only supports 1D-4D input arrays.");
+
+  // Postpone output allocation if either of the indexing tensors (begin or
+  // size) is not constant.
+ if (!(IsConstantTensor(begin) && IsConstantTensor(size))) {
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+
+ return ResizeOutputShape(context, input, begin, size, output);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
+ TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeOutputShape(context, input, begin, size, output));
+ }
+
+ std::vector<int> begins;
+ begins.reserve(kMaxDim);
+ std::vector<int> sizes;
+ sizes.reserve(kMaxDim);
+
+ if (begin->type == kTfLiteInt32) {
+ GetBeginAndSizeVectors<int32_t>(NumDimensions(input), begin, size, &begins,
+ &sizes);
+ } else if (begin->type == kTfLiteInt64) {
+ GetBeginAndSizeVectors<int64_t>(NumDimensions(input), begin, size, &begins,
+ &sizes);
+ } else {
+ context->ReportError(context, "Type is currently not supported by Slice.");
+ return kTfLiteError;
+ }
+
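+  // The optimized kernel works on 4D data: pad the (reversed) begin/size
+  // vectors so that dimensions missing from a 1-3D input behave as extent-1
+  // dimensions starting at offset 0.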
+ for (int i = NumDimensions(input); i < kMaxDim; ++i) {
+ begins.push_back(0);
+ sizes.push_back(1);
+ }
+
+#define TF_LITE_SLICE(data_type) \
+ optimized_ops::Slice<data_type>( \
+ GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+
+ switch (input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_SLICE(float);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_SLICE(int32_t);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_SLICE(int64_t);
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_SLICE(uint8_t);
+ break;
+ case kTfLiteBool:
+ TF_LITE_SLICE(bool);
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported by Slice.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_SLICE
+ return kTfLiteOk;
+}
+
+} // namespace slice
+
+TfLiteRegistration* Register_SLICE() {
+ static TfLiteRegistration r = {nullptr, nullptr, slice::Prepare, slice::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
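
For reference, a minimal standalone sketch of the begin/size semantics the kernel above implements: the output extent in each dimension is size[i], and a size of -1 takes everything from begin[i] to the end of that dimension. The SliceOutputShape helper below is hypothetical (it is not part of the commit) and only mirrors the shape computation done in CalculateOutputShapeVector.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring CalculateOutputShapeVector above: computes the
// sliced output shape, returning an empty vector for an invalid begin/size.
std::vector<int64_t> SliceOutputShape(const std::vector<int64_t>& input_shape,
                                      const std::vector<int64_t>& begin,
                                      const std::vector<int64_t>& size) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < input_shape.size(); ++i) {
    int64_t extent = size[i];
    if (extent == -1) {
      extent = input_shape[i] - begin[i];  // -1: take the rest of the dim.
    } else if (extent < 0 || begin[i] + extent > input_shape[i]) {
      return {};  // Invalid begin/size combination.
    }
    out.push_back(extent);
  }
  return out;
}

For example, slicing a [3, 2, 3] input with begin = [1, 0, 0] and size = [1, -1, 3] yields an output of shape [1, 2, 3].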