diff options
author | 2018-05-08 22:49:20 -0700 | |
---|---|---|
committer | 2018-05-09 10:41:25 -0700 | |
commit | eac758802e66934a6fde4e23fd92023780a5c075 (patch) | |
tree | bdef3d3c378bafae1f1908d8e3b89523046cd87e /tensorflow/contrib/lite/kernels/slice.cc | |
parent | 7bd992b02c0a19ce7aa9c085ab5caa0e00fe2516 (diff) |
Implementation of Slice.
PiperOrigin-RevId: 195926057
Diffstat (limited to 'tensorflow/contrib/lite/kernels/slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/slice.cc | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc new file mode 100644 index 0000000000..82baf53e1d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/slice.cc @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <string.h> +#include <cmath> +#include <vector> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace slice { + +constexpr int kInputTensor = 0; +constexpr int kBeginTensor = 1; +constexpr int kSizeTensor = 2; +constexpr int kOutputTensor = 0; + +// This Op only supports 1-4D cases and since we use the optimized ops 4D +// implementation, the 1-3D tensors are mapped to 4D. +const int kMaxDim = 4; + +template <typename T> +TfLiteStatus CalculateOutputShapeVector( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* begin, + TfLiteTensor* size, std::vector<int64_t>* output_shape_vector) { + for (int idx = 0; idx < NumDimensions(input); ++idx) { + T size_value = GetTensorData<T>(size)[idx]; + if (size_value < 0) { + if (size_value != -1) { + context->ReportError(context, "Invalid size."); + return kTfLiteError; + } + size_value = SizeOfDimension(input, idx) - GetTensorData<T>(begin)[idx]; + } else { + if (SizeOfDimension(input, idx) < + GetTensorData<T>(begin)[idx] + size_value) { + context->ReportError(context, "Invalid begin and size."); + return kTfLiteError; + } + } + output_shape_vector->push_back(size_value); + } + return kTfLiteOk; +} + +template <typename T> +void GetBeginAndSizeVectors(int dimensions, TfLiteTensor* begin, + TfLiteTensor* size, std::vector<int>* begins, + std::vector<int>* sizes) { + for (int idx = dimensions - 1; idx >= 0; --idx) { + begins->push_back(GetTensorData<T>(begin)[idx]); + sizes->push_back(GetTensorData<T>(size)[idx]); + } +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, TfLiteTensor* input, + TfLiteTensor* begin, TfLiteTensor* size, + TfLiteTensor* output) { + std::vector<int64_t> output_shape_vector; + + if (begin->type == kTfLiteInt32) { + TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>( + context, input, begin, size, &output_shape_vector)); + } else if (begin->type == kTfLiteInt64) { + TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int64_t>( + context, input, begin, size, &output_shape_vector)); + } else { + context->ReportError(context, "Type is currently not supported by Slice."); + return kTfLiteError; + } + + TfLiteIntArray* output_shape = + TfLiteIntArrayCreate(output_shape_vector.size()); + std::copy(output_shape_vector.begin(), output_shape_vector.end(), + output_shape->data); + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* begin = GetInput(context, node, kBeginTensor); + TfLiteTensor* size = GetInput(context, node, kSizeTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Ensure validity of input tensor and its dimension. + TF_LITE_ENSURE_EQ(context, input->type, output->type); + TF_LITE_ENSURE(context, + begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64); + TF_LITE_ENSURE(context, + size->type == kTfLiteInt32 || size->type == kTfLiteInt64); + TF_LITE_ENSURE(context, NumDimensions(begin) == NumDimensions(size) == 1); + TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim, + "Slice op only supports 1D-4D input arrays."); + + // Postpone allocation of output if any of the indexing tensors is not + // constant + if (!(IsConstantTensor(begin) && IsConstantTensor(size))) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + + return ResizeOutputShape(context, input, begin, size, output); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* begin = GetInput(context, node, kBeginTensor); + TfLiteTensor* size = GetInput(context, node, kSizeTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, input, begin, size, output)); + } + + std::vector<int> begins; + begins.reserve(kMaxDim); + std::vector<int> sizes; + sizes.reserve(kMaxDim); + + if (begin->type == kTfLiteInt32) { + GetBeginAndSizeVectors<int32_t>(NumDimensions(input), begin, size, &begins, + &sizes); + } else if (begin->type == kTfLiteInt64) { + GetBeginAndSizeVectors<int64_t>(NumDimensions(input), begin, size, &begins, + &sizes); + } else { + context->ReportError(context, "Type is currently not supported by Slice."); + return kTfLiteError; + } + + for (int i = NumDimensions(input); i < kMaxDim; ++i) { + begins.push_back(0); + sizes.push_back(1); + } + +#define TF_LITE_SLICE(data_type) \ + optimized_ops::Slice<data_type>( \ + GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \ + GetTensorData<data_type>(output), GetTensorDims(output)) + + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_SLICE(float); + break; + case kTfLiteInt32: + TF_LITE_SLICE(int32_t); + break; + case kTfLiteInt64: + TF_LITE_SLICE(int64_t); + break; + case kTfLiteUInt8: + TF_LITE_SLICE(uint8_t); + break; + case kTfLiteBool: + TF_LITE_SLICE(bool); + break; + default: + context->ReportError(context, + "Type is currently not supported by Slice."); + return kTfLiteError; + } +#undef TF_LITE_SLICE + return kTfLiteOk; +} + +} // namespace slice + +TfLiteRegistration* Register_SLICE() { + static TfLiteRegistration r = {nullptr, nullptr, slice::Prepare, slice::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite |