diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-23 17:35:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-23 17:40:21 -0800 |
commit | 6dfe00ca4114371ed47c93810219736e3deda2d2 (patch) | |
tree | 14262ad0f7b3e212c0e2d84a480c383a4074031f /tensorflow/contrib/lite/kernels/strided_slice.cc | |
parent | a3e81ec2892126056ad6c1feb9161bc16c2c2975 (diff) |
Support StridedSlice in TFLite for 1D-4D tensors.
PiperOrigin-RevId: 183020501
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/strided_slice.cc | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc new file mode 100644 index 0000000000..91ba4a9b78 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -0,0 +1,256 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <string.h> +#include <cmath> +#include <vector> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace strided_slice { + +enum KernelType { + kReference, + // TODO(soroosh): add kGenericOptimized +}; + +constexpr int kInputTensor = 0; +constexpr int kBeginTensor = 1; +constexpr int kEndTensor = 2; +constexpr int kStridesTensor = 3; +constexpr int kOutputTensor = 0; + +struct StridedSliceContext { + StridedSliceContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data); + input = GetInput(context, node, kInputTensor); + begin = GetInput(context, node, kBeginTensor); + end = GetInput(context, node, kEndTensor); + strides = GetInput(context, node, kStridesTensor); + output = GetOutput(context, node, kOutputTensor); + dims = NumDimensions(input); + } + TfLiteStridedSliceParams* params; + TfLiteTensor* input; + TfLiteTensor* begin; + TfLiteTensor* end; + TfLiteTensor* strides; + TfLiteTensor* output; + int dims; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0, + "shrink_axis_mask is not implemented yet."); + + // TODO(soroosh): optimize for constant tensors to do allocation in Prepare + op_context.output->allocation_type = kTfLiteDynamic; + return kTfLiteOk; +} // namespace strided_slice + +// TODO(soroosh): consolidate with BytesRequired in interpreter.h +TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type, + const int* dims, int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(context, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + return kTfLiteError; + } + return kTfLiteOk; +} + +// Reverse order of bits in the mask to match the expected order in kernel +inline int ReverseMaskBits(int mask, int num_dimensions) { + int out = 0; + for (int dim = 0; dim < num_dimensions; dim++) { + out <<= 1; + out += (mask & 1); + mask >>= 1; + } + return out; +} + +// This Op only supports 1-4D cases and since we use the reference 4D +// implementation, the 1-3D tensors are mapped to 4D. +const int kMaxDim = 4; + +inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) { + return (divisor + (dividend % divisor)) % divisor; +} + +inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { + return pos_stride + ? (index >= dim ? dim + : PositiveRemainder( + std::min(std::max(index, -dim), dim), dim)) + : (index < -dim + ? -1 + : PositiveRemainder( + std::min(std::max(index, -dim), dim - 1), dim)); +} + +template <KernelType kernel_type> +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + StridedSliceContext op_context(context, node); + + std::vector<int> starts; + std::vector<int> stops; + std::vector<int> strides; + + // Determine size of output tensor and map indices + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims); + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + int dim = op_context.input->dims->data[idx]; + int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx]; + TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); + bool pos_stride = stride > 0; + + int32_t begin = + op_context.params->begin_mask & (1 << idx) + ? pos_stride ? 0 : dim - 1 + : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim, + pos_stride); + int32_t end = + op_context.params->end_mask & (1 << idx) + ? pos_stride ? dim : -1 + : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim, + pos_stride); + + // This is valid for both positive and negative strides + output_shape->data[idx] = ceil((end - begin) / static_cast<float>(stride)); + output_shape->data[idx] = + output_shape->data[idx] < 0 ? 0 : output_shape->data[idx]; + starts.emplace_back(begin); + stops.emplace_back(end); + strides.emplace_back(stride); + } + + for (int i = op_context.dims; i < kMaxDim; i++) { + starts.emplace_back(0); + stops.emplace_back(1); + strides.emplace_back(1); + } + + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context.output, output_shape)); + + size_t required_bytes; + TF_LITE_ENSURE_OK( + context, + BytesRequired(context, op_context.output->type, output_shape->data, + output_shape->size, &required_bytes)); + TfLiteTensorRealloc(required_bytes, op_context.output); + + op_context.params->begin_mask = + ReverseMaskBits(op_context.params->begin_mask, op_context.dims); + op_context.params->end_mask = + ReverseMaskBits(op_context.params->end_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData<data_type>(op_context.input), \ + GetTensorDims(op_context.input), op_context.params->begin_mask, \ + op_context.params->end_mask, starts, stops, strides, \ + GetTensorData<data_type>(op_context.output), \ + GetTensorDims(op_context.output)) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, float); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported " + "by StridedSlice."); + return kTfLiteError; + } +#undef TF_LITE_STRIDED_SLICE + return kTfLiteOk; +} + +} // namespace strided_slice + +TfLiteRegistration* Register_STRIDED_SLICE_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, strided_slice::Prepare, + strided_slice::Eval<strided_slice::kReference>}; + return &r; +} + +// TODO(soroosh): add optimized +TfLiteRegistration* Register_STRIDED_SLICE() { + return Register_STRIDED_SLICE_REF(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite |