path: root/tensorflow/contrib/lite/kernels/strided_slice.cc
author	A. Unique TensorFlower <gardener@tensorflow.org>	2018-01-23 17:35:53 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-01-23 17:40:21 -0800
commit	6dfe00ca4114371ed47c93810219736e3deda2d2 (patch)
tree	14262ad0f7b3e212c0e2d84a480c383a4074031f /tensorflow/contrib/lite/kernels/strided_slice.cc
parent	a3e81ec2892126056ad6c1feb9161bc16c2c2975 (diff)
Support StridedSlice in TFLite for 1D-4D tensors.
PiperOrigin-RevId: 183020501
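For context, an illustrative example of the semantics this kernel implements: slicing a 1-D input [1, 2, 3, 4, 5] with begin = [1], end = [5], strides = [2] selects indices 1 and 3 and yields [2, 4]; negative begin/end values count from the end of the dimension.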
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r--	tensorflow/contrib/lite/kernels/strided_slice.cc	256
1 file changed, 256 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
new file mode 100644
index 0000000000..91ba4a9b78
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -0,0 +1,256 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <algorithm>
+#include <cmath>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace strided_slice {
+
+enum KernelType {
+ kReference,
+ // TODO(soroosh): add kGenericOptimized
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kBeginTensor = 1;
+constexpr int kEndTensor = 2;
+constexpr int kStridesTensor = 3;
+constexpr int kOutputTensor = 0;
+
+struct StridedSliceContext {
+ StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
+ input = GetInput(context, node, kInputTensor);
+ begin = GetInput(context, node, kBeginTensor);
+ end = GetInput(context, node, kEndTensor);
+ strides = GetInput(context, node, kStridesTensor);
+ output = GetOutput(context, node, kOutputTensor);
+ dims = NumDimensions(input);
+ }
+ TfLiteStridedSliceParams* params;
+ TfLiteTensor* input;
+ TfLiteTensor* begin;
+ TfLiteTensor* end;
+ TfLiteTensor* strides;
+ TfLiteTensor* output;
+ int dims;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ StridedSliceContext op_context(context, node);
+
+  // Ensure the begin, end and strides tensors are 1-D and that the input and
+  // output types match.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ // Only INT32 begin/end/strides are supported
+ // TODO(soroosh) add support for INT64
+ TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
+ TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
+ "StridedSlice op only supports 1D-4D input arrays.");
+
+ // TODO(soroosh): add the following missing functionalities
+ TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
+ "ellipsis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
+ "new_axis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0,
+ "shrink_axis_mask is not implemented yet.");
+
+ // TODO(soroosh): optimize for constant tensors to do allocation in Prepare
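+  // The output shape depends on the values (not just the shapes) of the
+  // begin/end/strides tensors, which are generally only known at Eval time,
+  // so the output is marked dynamic and resized in Eval.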
+ op_context.output->allocation_type = kTfLiteDynamic;
+ return kTfLiteOk;
+}
+
+// TODO(soroosh): consolidate with BytesRequired in interpreter.h
+TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type,
+ const int* dims, int dims_size, size_t* bytes) {
+ // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
+ // MultiplyWithoutOverflow.
+ TF_LITE_ENSURE(context, bytes != nullptr);
+ size_t count = 1;
+ for (int k = 0; k < dims_size; k++) count *= dims[k];
+ switch (type) {
+ case kTfLiteFloat32:
+ *bytes = sizeof(float) * count;
+ break;
+ case kTfLiteInt32:
+ *bytes = sizeof(int32_t) * count;
+ break;
+ case kTfLiteUInt8:
+ *bytes = sizeof(uint8_t) * count;
+ break;
+ case kTfLiteInt64:
+ *bytes = sizeof(int64_t) * count;
+ break;
+ default:
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Reverses the order of the bits in the mask to match the order expected by
+// the kernel.
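+// For example, with num_dimensions = 3, mask 0b011 becomes 0b110.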
+inline int ReverseMaskBits(int mask, int num_dimensions) {
+ int out = 0;
+ for (int dim = 0; dim < num_dimensions; dim++) {
+ out <<= 1;
+ out += (mask & 1);
+ mask >>= 1;
+ }
+ return out;
+}
+
+// This op only supports 1D-4D cases. Since the 4D reference implementation is
+// used, 1D-3D tensors are mapped onto 4D.
+const int kMaxDim = 4;
+
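+// Shifts the C++ remainder of dividend / divisor into [0, divisor). For
+// example, PositiveRemainder(-1, 5) == 4, so a negative index wraps around
+// to count from the end of the dimension.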
+inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
+ return (divisor + (dividend % divisor)) % divisor;
+}
+
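+// Clamps an index to the valid range for a dimension of extent `dim` and
+// resolves negative indices. Illustrative examples for dim = 4: with a
+// positive stride, ClampedIndex(-10, 4, true) == 0 and
+// ClampedIndex(10, 4, true) == 4 (one past the last element); with a
+// negative stride, ClampedIndex(-10, 4, false) == -1 (one before the first
+// element) and ClampedIndex(10, 4, false) == 3.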
+inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
+ return pos_stride
+ ? (index >= dim ? dim
+ : PositiveRemainder(
+ std::min(std::max(index, -dim), dim), dim))
+ : (index < -dim
+ ? -1
+ : PositiveRemainder(
+ std::min(std::max(index, -dim), dim - 1), dim));
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ StridedSliceContext op_context(context, node);
+
+ std::vector<int> starts;
+ std::vector<int> stops;
+ std::vector<int> strides;
+
+  // Determine the size of the output tensor and map indices. Dimensions are
+  // visited from last to first, so starts/stops/strides are collected in
+  // reverse dimension order, which is the order the reference kernel expects.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims);
+ for (int idx = op_context.dims - 1; idx >= 0; --idx) {
+ int dim = op_context.input->dims->data[idx];
+ int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
+ TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
+ bool pos_stride = stride > 0;
+
+ int32_t begin =
+ op_context.params->begin_mask & (1 << idx)
+ ? pos_stride ? 0 : dim - 1
+ : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim,
+ pos_stride);
+ int32_t end =
+ op_context.params->end_mask & (1 << idx)
+ ? pos_stride ? dim : -1
+ : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim,
+ pos_stride);
+
+ // This is valid for both positive and negative strides
+    output_shape->data[idx] =
+        std::ceil((end - begin) / static_cast<float>(stride));
+ output_shape->data[idx] =
+ output_shape->data[idx] < 0 ? 0 : output_shape->data[idx];
+ starts.emplace_back(begin);
+ stops.emplace_back(end);
+ strides.emplace_back(stride);
+ }
+
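+  // For 1D-3D inputs, pad the remaining dimensions with degenerate size-1
+  // slices so that the 4D reference implementation can be used.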
+ for (int i = op_context.dims; i < kMaxDim; i++) {
+ starts.emplace_back(0);
+ stops.emplace_back(1);
+ strides.emplace_back(1);
+ }
+
+ TF_LITE_ENSURE_STATUS(
+ context->ResizeTensor(context, op_context.output, output_shape));
+
+ size_t required_bytes;
+ TF_LITE_ENSURE_OK(
+ context,
+ BytesRequired(context, op_context.output->type, output_shape->data,
+ output_shape->size, &required_bytes));
+ TfLiteTensorRealloc(required_bytes, op_context.output);
+
+ op_context.params->begin_mask =
+ ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
+ op_context.params->end_mask =
+ ReverseMaskBits(op_context.params->end_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice( \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), op_context.params->begin_mask, \
+ op_context.params->end_mask, starts, stops, strides, \
+ GetTensorData<data_type>(op_context.output), \
+ GetTensorDims(op_context.output))
+
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, float);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported "
+ "by StridedSlice.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_STRIDED_SLICE
+ return kTfLiteOk;
+}
+
+} // namespace strided_slice
+
+TfLiteRegistration* Register_STRIDED_SLICE_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, strided_slice::Prepare,
+ strided_slice::Eval<strided_slice::kReference>};
+ return &r;
+}
+
+// TODO(soroosh): add optimized
+TfLiteRegistration* Register_STRIDED_SLICE() {
+ return Register_STRIDED_SLICE_REF();
+}
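+// Note: this kernel is typically made available by registering it in
+// kernels/register.cc, e.g. (illustrative, not part of this file):
+//   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());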
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite