diff options
author | 2018-01-23 17:35:53 -0800 | |
---|---|---|
committer | 2018-01-23 17:40:21 -0800 | |
commit | 6dfe00ca4114371ed47c93810219736e3deda2d2 (patch) | |
tree | 14262ad0f7b3e212c0e2d84a480c383a4074031f /tensorflow/contrib/lite | |
parent | a3e81ec2892126056ad6c1feb9161bc16c2c2975 (diff) |
Support StridedSlice in TFLite for 1D-4D tensors.
PiperOrigin-RevId: 183020501
Diffstat (limited to 'tensorflow/contrib/lite')
18 files changed, 1107 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 3b43a1fd5d..ab07c58c92 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -88,7 +88,9 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; -typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; +typedef struct { + TfLiteFusedActivation activation; +} TfLiteFullyConnectedParams; typedef enum { kTfLiteLshProjectionUnknown = 0, @@ -96,9 +98,13 @@ typedef enum { kTfLiteLshProjectionDense = 2, } TfLiteLSHProjectionType; -typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams; +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; -typedef struct { float beta; } TfLiteSoftmaxParams; +typedef struct { + float beta; +} TfLiteSoftmaxParams; typedef struct { int axis; @@ -226,6 +232,14 @@ typedef struct { int num_squeeze_dims; } TfLiteSqueezeParams; +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; +} TfLiteStridedSliceParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 2d51b8727f..4195e7553c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -103,6 +103,7 @@ cc_library( "space_to_batch_nd.cc", "space_to_depth.cc", "squeeze.cc", + "strided_slice.cc", "sub.cc", "svdf.cc", "transpose.cc", @@ -518,6 +519,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "strided_slice_test", + size = "small", + srcs = ["strided_slice_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 1d86183d94..03708eb32b 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -2330,6 +2330,18 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, } } +inline bool LoopCondition(int index, int stop, int stride) { + return stride > 0 ? index < stop : index > stop; +} + +inline int StartIndex(int start, int stride, int dim, bool masked) { + return masked ? (stride > 0 ? 0 : dim - 1) : start; +} + +inline int StopIndex(int stop, int stride, int dim, bool masked) { + return masked ? (stride > 0 ? dim : -1) : stop; +} + template <typename T> inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, int begin_mask, int end_mask, @@ -2337,20 +2349,35 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, const std::vector<int>& stops, const std::vector<int>& strides, T* output_data, const Dims<4>& output_dims) { - const int start_b = (begin_mask & 8) ? 0 : starts[3]; - const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; - const int start_h = (begin_mask & 4) ? 0 : starts[2]; - const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; - const int start_w = (begin_mask & 2) ? 0 : starts[1]; - const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; - const int start_d = (begin_mask & 1) ? 0 : starts[0]; - const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + TFLITE_DCHECK_EQ(starts.size(), 4); + TFLITE_DCHECK_EQ(stops.size(), 4); + TFLITE_DCHECK_EQ(strides.size(), 4); + const int start_b = + StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8); + const int stop_b = + StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8); + const int start_h = + StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4); + const int stop_h = + StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4); + const int start_w = + StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2); + const int stop_w = + StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2); + const int start_d = + StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1); + const int stop_d = + StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1); T* out_ptr = output_data; - for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { - for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { - for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { - for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]); + in_b += strides[3]) { + for (int in_h = start_h; LoopCondition(in_h, stop_h, strides[2]); + in_h += strides[2]) { + for (int in_w = start_w; LoopCondition(in_w, stop_w, strides[1]); + in_w += strides[1]) { + for (int in_d = start_d; LoopCondition(in_d, stop_d, strides[0]); + in_d += strides[0]) { *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; } } diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index c9e74cb8d5..f605deaa5b 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -58,6 +58,7 @@ TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_TRANSPOSE(); TfLiteRegistration* Register_MEAN(); TfLiteRegistration* Register_SQUEEZE(); +TfLiteRegistration* Register_STRIDED_SLICE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -103,6 +104,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); + AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc new file mode 100644 index 0000000000..91ba4a9b78 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -0,0 +1,256 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <string.h> +#include <cmath> +#include <vector> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace strided_slice { + +enum KernelType { + kReference, + // TODO(soroosh): add kGenericOptimized +}; + +constexpr int kInputTensor = 0; +constexpr int kBeginTensor = 1; +constexpr int kEndTensor = 2; +constexpr int kStridesTensor = 3; +constexpr int kOutputTensor = 0; + +struct StridedSliceContext { + StridedSliceContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data); + input = GetInput(context, node, kInputTensor); + begin = GetInput(context, node, kBeginTensor); + end = GetInput(context, node, kEndTensor); + strides = GetInput(context, node, kStridesTensor); + output = GetOutput(context, node, kOutputTensor); + dims = NumDimensions(input); + } + TfLiteStridedSliceParams* params; + TfLiteTensor* input; + TfLiteTensor* begin; + TfLiteTensor* end; + TfLiteTensor* strides; + TfLiteTensor* output; + int dims; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0, + "shrink_axis_mask is not implemented yet."); + + // TODO(soroosh): optimize for constant tensors to do allocation in Prepare + op_context.output->allocation_type = kTfLiteDynamic; + return kTfLiteOk; +} // namespace strided_slice + +// TODO(soroosh): consolidate with BytesRequired in interpreter.h +TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type, + const int* dims, int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(context, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + return kTfLiteError; + } + return kTfLiteOk; +} + +// Reverse order of bits in the mask to match the expected order in kernel +inline int ReverseMaskBits(int mask, int num_dimensions) { + int out = 0; + for (int dim = 0; dim < num_dimensions; dim++) { + out <<= 1; + out += (mask & 1); + mask >>= 1; + } + return out; +} + +// This Op only supports 1-4D cases and since we use the reference 4D +// implementation, the 1-3D tensors are mapped to 4D. +const int kMaxDim = 4; + +inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) { + return (divisor + (dividend % divisor)) % divisor; +} + +inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { + return pos_stride + ? (index >= dim ? dim + : PositiveRemainder( + std::min(std::max(index, -dim), dim), dim)) + : (index < -dim + ? -1 + : PositiveRemainder( + std::min(std::max(index, -dim), dim - 1), dim)); +} + +template <KernelType kernel_type> +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + StridedSliceContext op_context(context, node); + + std::vector<int> starts; + std::vector<int> stops; + std::vector<int> strides; + + // Determine size of output tensor and map indices + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims); + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + int dim = op_context.input->dims->data[idx]; + int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx]; + TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); + bool pos_stride = stride > 0; + + int32_t begin = + op_context.params->begin_mask & (1 << idx) + ? pos_stride ? 0 : dim - 1 + : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim, + pos_stride); + int32_t end = + op_context.params->end_mask & (1 << idx) + ? pos_stride ? dim : -1 + : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim, + pos_stride); + + // This is valid for both positive and negative strides + output_shape->data[idx] = ceil((end - begin) / static_cast<float>(stride)); + output_shape->data[idx] = + output_shape->data[idx] < 0 ? 0 : output_shape->data[idx]; + starts.emplace_back(begin); + stops.emplace_back(end); + strides.emplace_back(stride); + } + + for (int i = op_context.dims; i < kMaxDim; i++) { + starts.emplace_back(0); + stops.emplace_back(1); + strides.emplace_back(1); + } + + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context.output, output_shape)); + + size_t required_bytes; + TF_LITE_ENSURE_OK( + context, + BytesRequired(context, op_context.output->type, output_shape->data, + output_shape->size, &required_bytes)); + TfLiteTensorRealloc(required_bytes, op_context.output); + + op_context.params->begin_mask = + ReverseMaskBits(op_context.params->begin_mask, op_context.dims); + op_context.params->end_mask = + ReverseMaskBits(op_context.params->end_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData<data_type>(op_context.input), \ + GetTensorDims(op_context.input), op_context.params->begin_mask, \ + op_context.params->end_mask, starts, stops, strides, \ + GetTensorData<data_type>(op_context.output), \ + GetTensorDims(op_context.output)) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, float); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported " + "by StridedSlice."); + return kTfLiteError; + } +#undef TF_LITE_STRIDED_SLICE + return kTfLiteOk; +} + +} // namespace strided_slice + +TfLiteRegistration* Register_STRIDED_SLICE_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, strided_slice::Prepare, + strided_slice::Eval<strided_slice::kReference>}; + return &r; +} + +// TODO(soroosh): add optimized +TfLiteRegistration* Register_STRIDED_SLICE() { + return Register_STRIDED_SLICE_REF(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc new file mode 100644 index 0000000000..cd4a364682 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -0,0 +1,375 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class StridedSliceOpModel : public SingleOpModel { + public: + StridedSliceOpModel(std::initializer_list<int> input_shape, + std::initializer_list<int> begin_shape, + std::initializer_list<int> end_shape, + std::initializer_list<int> strides_shape, int begin_mask, + int end_mask, int ellipsis_mask, int new_axis_mask, + int shrink_axis_mask) { + input_ = AddInput(TensorType_FLOAT32); + begin_ = AddInput(TensorType_INT32); + end_ = AddInput(TensorType_INT32); + strides_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, + CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask) + .Union()); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + } + + void SetInput(std::initializer_list<float> data) { + PopulateTensor<float>(input_, data); + } + void SetBegin(std::initializer_list<int32> data) { + PopulateTensor<int32>(begin_, data); + } + void SetEnd(std::initializer_list<int32> data) { + PopulateTensor<int32>(end_, data); + } + void SetStrides(std::initializer_list<int32> data) { + PopulateTensor<int32>(strides_, data); + } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int end_; + int strides_; + int output_; +}; + +TEST(StridedSliceOpTest, UnsupportedInputSize) { + EXPECT_DEATH( + StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), + "StridedSlice op only supports 1D-4D input arrays."); +} + +TEST(StridedSliceOpTest, UnssupportedArgs) { + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), + "ellipsis_mask is not implemented yet."); + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), + "new_axis_mask is not implemented yet."); + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1), + "shrink_axis_mask is not implemented yet."); +} + +TEST(StridedSliceOpTest, In1D) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(StridedSliceOpTest, In1D_EmptyOutput) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({10}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-5}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In1D_NegativeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({-2}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({5}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} + +TEST(StridedSliceOpTest, In1D_BeginMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-2}); + m.SetEnd({-3}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({5}); + m.SetEnd({2}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4})); +} + +TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({2}); + m.SetEnd({-4}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({-5}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1})); +} + +TEST(StridedSliceOpTest, In1D_EndMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} +TEST(StridedSliceOpTest, In1D_NegStride) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3}); + m.SetBegin({-1}); + m.SetEnd({-4}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1})); +} + +TEST(StridedSliceOpTest, In1D_EvenLenStride2) { + StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2}); + m.SetBegin({0}); + m.SetEnd({2}); + m.SetStrides({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} +TEST(StridedSliceOpTest, In1D_OddLenStride2) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3}); + m.SetBegin({0}); + m.SetEnd({3}); + m.SetStrides({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); +} + +TEST(StridedSliceOpTest, In2D_Identity) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} +TEST(StridedSliceOpTest, In2D) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(StridedSliceOpTest, In2D_Stride2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); +} + +TEST(StridedSliceOpTest, In2D_NegStride) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -1}); + m.SetEnd({2, -4}); + m.SetStrides({2, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); +} + +TEST(StridedSliceOpTest, In2D_BeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); +} + +TEST(StridedSliceOpTest, In2D_EndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6})); +} + +TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -2}); + m.SetEnd({2, -4}); + m.SetStrides({1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); +} +TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -2}); + m.SetEnd({2, -3}); + m.SetStrides({1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4})); +} + +TEST(StridedSliceOpTest, In3D_Identity) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); +} + +TEST(StridedSliceOpTest, In3D_NegStride) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({-1, -1, -1}); + m.SetEnd({-3, -4, -3}); + m.SetStrides({-1, -1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); +} +TEST(StridedSliceOpTest, In3D_Strided2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({2, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 4d8a6d10c8..a01b74f9da 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -618,6 +618,18 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_STRIDED_SLICE: { + auto* params = MallocPOD<TfLiteStridedSliceParams>(); + if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { + params->begin_mask = schema_params->begin_mask(); + params->end_mask = schema_params->end_mask(); + params->ellipsis_mask = schema_params->ellipsis_mask(); + params->new_axis_mask = schema_params->new_axis_mask(); + params->shrink_axis_mask = schema_params->shrink_axis_mask(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 998a7b7614..d5b9319407 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -339,6 +339,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_DIV: case tflite::BuiltinOperator_SUB: case tflite::BuiltinOperator_SQUEEZE: + case tflite::BuiltinOperator_STRIDED_SLICE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 6fcd3e51a4..8ddad4d251 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -118,6 +118,7 @@ enum BuiltinOperator : byte { DIV = 42, SQUEEZE = 43, UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, } // Options for the builtin operators. @@ -153,6 +154,7 @@ union BuiltinOptions { DivOptions, SqueezeOptions, SequenceRNNOptions, + StridedSliceOptions, } enum Padding : byte { SAME, VALID } @@ -340,6 +342,14 @@ table SqueezeOptions { squeeze_dims:[int]; } +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 6eb9ae2926..b756891f66 100755..100644 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -120,6 +120,9 @@ struct MeanOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; +struct StridedSliceOptions; +struct StridedSliceOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -207,11 +210,12 @@ enum BuiltinOperator { BuiltinOperator_DIV = 42, BuiltinOperator_SQUEEZE = 43, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + BuiltinOperator_STRIDED_SLICE = 45, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM + BuiltinOperator_MAX = BuiltinOperator_STRIDED_SLICE }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[42] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[43] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -254,7 +258,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[42] { BuiltinOperator_SUB, BuiltinOperator_DIV, BuiltinOperator_SQUEEZE, - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM}; + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_STRIDED_SLICE}; return values; } @@ -304,6 +309,7 @@ inline const char **EnumNamesBuiltinOperator() { "DIV", "SQUEEZE", "UNIDIRECTIONAL_SEQUENCE_LSTM", + "STRIDED_SLICE", nullptr}; return names; } @@ -346,11 +352,12 @@ enum BuiltinOptions { BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, BuiltinOptions_SequenceRNNOptions = 31, + BuiltinOptions_StridedSliceOptions = 32, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_SequenceRNNOptions + BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -383,7 +390,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] { BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, BuiltinOptions_SqueezeOptions, - BuiltinOptions_SequenceRNNOptions}; + BuiltinOptions_SequenceRNNOptions, + BuiltinOptions_StridedSliceOptions}; return values; } @@ -420,6 +428,7 @@ inline const char **EnumNamesBuiltinOptions() { "DivOptions", "SqueezeOptions", "SequenceRNNOptions", + "StridedSliceOptions", nullptr}; return names; } @@ -593,6 +602,11 @@ struct BuiltinOptionsTraits<SequenceRNNOptions> { static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; }; +template <> +struct BuiltinOptionsTraits<StridedSliceOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -950,6 +964,16 @@ struct BuiltinOptionsUnion { ? reinterpret_cast<const SequenceRNNOptionsT *>(value) : nullptr; } + StridedSliceOptionsT *AsStridedSliceOptions() { + return type == BuiltinOptions_StridedSliceOptions + ? reinterpret_cast<StridedSliceOptionsT *>(value) + : nullptr; + } + const StridedSliceOptionsT *AsStridedSliceOptions() const { + return type == BuiltinOptions_StridedSliceOptions + ? reinterpret_cast<const StridedSliceOptionsT *>(value) + : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, @@ -3532,6 +3556,111 @@ flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions( flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct StridedSliceOptionsT : public flatbuffers::NativeTable { + typedef StridedSliceOptions TableType; + int32_t begin_mask; + int32_t end_mask; + int32_t ellipsis_mask; + int32_t new_axis_mask; + int32_t shrink_axis_mask; + StridedSliceOptionsT() + : begin_mask(0), + end_mask(0), + ellipsis_mask(0), + new_axis_mask(0), + shrink_axis_mask(0) {} +}; + +struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef StridedSliceOptionsT NativeTableType; + enum { + VT_BEGIN_MASK = 4, + VT_END_MASK = 6, + VT_ELLIPSIS_MASK = 8, + VT_NEW_AXIS_MASK = 10, + VT_SHRINK_AXIS_MASK = 12 + }; + int32_t begin_mask() const { return GetField<int32_t>(VT_BEGIN_MASK, 0); } + int32_t end_mask() const { return GetField<int32_t>(VT_END_MASK, 0); } + int32_t ellipsis_mask() const { + return GetField<int32_t>(VT_ELLIPSIS_MASK, 0); + } + int32_t new_axis_mask() const { + return GetField<int32_t>(VT_NEW_AXIS_MASK, 0); + } + int32_t shrink_axis_mask() const { + return GetField<int32_t>(VT_SHRINK_AXIS_MASK, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_BEGIN_MASK) && + VerifyField<int32_t>(verifier, VT_END_MASK) && + VerifyField<int32_t>(verifier, VT_ELLIPSIS_MASK) && + VerifyField<int32_t>(verifier, VT_NEW_AXIS_MASK) && + VerifyField<int32_t>(verifier, VT_SHRINK_AXIS_MASK) && + verifier.EndTable(); + } + StridedSliceOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + StridedSliceOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<StridedSliceOptions> Pack( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StridedSliceOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_begin_mask(int32_t begin_mask) { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0); + } + void add_end_mask(int32_t end_mask) { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_END_MASK, end_mask, 0); + } + void add_ellipsis_mask(int32_t ellipsis_mask) { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_ELLIPSIS_MASK, + ellipsis_mask, 0); + } + void add_new_axis_mask(int32_t new_axis_mask) { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_NEW_AXIS_MASK, + new_axis_mask, 0); + } + void add_shrink_axis_mask(int32_t shrink_axis_mask) { + fbb_.AddElement<int32_t>(StridedSliceOptions::VT_SHRINK_AXIS_MASK, + shrink_axis_mask, 0); + } + explicit StridedSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + StridedSliceOptionsBuilder &operator=(const StridedSliceOptionsBuilder &); + flatbuffers::Offset<StridedSliceOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<StridedSliceOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0, + int32_t end_mask = 0, int32_t ellipsis_mask = 0, int32_t new_axis_mask = 0, + int32_t shrink_axis_mask = 0) { + StridedSliceOptionsBuilder builder_(_fbb); + builder_.add_shrink_axis_mask(shrink_axis_mask); + builder_.add_new_axis_mask(new_axis_mask); + builder_.add_ellipsis_mask(ellipsis_mask); + builder_.add_end_mask(end_mask); + builder_.add_begin_mask(begin_mask); + return builder_.Finish(); +} + +flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -3816,6 +3945,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { ? static_cast<const SequenceRNNOptions *>(builtin_options()) : nullptr; } + const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { + return builtin_options_type() == BuiltinOptions_StridedSliceOptions + ? static_cast<const StridedSliceOptions *>(builtin_options()) + : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -4023,6 +4157,12 @@ Operator::builtin_options_as<SequenceRNNOptions>() const { return builtin_options_as_SequenceRNNOptions(); } +template <> +inline const StridedSliceOptions * +Operator::builtin_options_as<StridedSliceOptions>() const { + return builtin_options_as_StridedSliceOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4962,11 +5102,11 @@ inline void SequenceRNNOptions::UnPackTo( { auto _e = time_major(); _o->time_major = _e; - } + }; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; - } + }; } inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack( @@ -6040,6 +6180,67 @@ inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions( return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims); } +inline StridedSliceOptionsT *StridedSliceOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new StridedSliceOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void StridedSliceOptions::UnPackTo( + StridedSliceOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = begin_mask(); + _o->begin_mask = _e; + }; + { + auto _e = end_mask(); + _o->end_mask = _e; + }; + { + auto _e = ellipsis_mask(); + _o->ellipsis_mask = _e; + }; + { + auto _e = new_axis_mask(); + _o->new_axis_mask = _e; + }; + { + auto _e = shrink_axis_mask(); + _o->shrink_axis_mask = _e; + }; +} + +inline flatbuffers::Offset<StridedSliceOptions> StridedSliceOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateStridedSliceOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions( + flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const StridedSliceOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _begin_mask = _o->begin_mask; + auto _end_mask = _o->end_mask; + auto _ellipsis_mask = _o->ellipsis_mask; + auto _new_axis_mask = _o->new_axis_mask; + auto _shrink_axis_mask = _o->shrink_axis_mask; + return tflite::CreateStridedSliceOptions(_fbb, _begin_mask, _end_mask, + _ellipsis_mask, _new_axis_mask, + _shrink_axis_mask); +} + inline OperatorCodeT *OperatorCode::UnPack( const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); @@ -6552,6 +6753,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } @@ -6700,6 +6905,10 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -6835,6 +7044,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast<const SequenceRNNOptionsT *>(value); return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast<const StridedSliceOptionsT *>(value); + return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -6985,6 +7198,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) *reinterpret_cast<SequenceRNNOptionsT *>(u.value)); break; } + case BuiltinOptions_StridedSliceOptions: { + value = new StridedSliceOptionsT( + *reinterpret_cast<StridedSliceOptionsT *>(u.value)); + break; + } default: break; } @@ -7147,6 +7365,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast<StridedSliceOptionsT *>(value); + delete ptr; + break; + } default: break; } diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 933da11353..50e8ca75f8 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -46,6 +46,7 @@ gen_zipped_test_files( "space_to_batch_nd.zip", "space_to_depth.zip", "squeeze.zip", + "strided_slice.zip", "sub.zip", "transpose.zip", ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 56e4dfc7a2..6204471e52 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1447,6 +1447,97 @@ def make_squeeze_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_strided_slice_tests(zip_path): + """Make a set of tests to do strided_slice.""" + + # TODO(soroosh): add test/support for uint8. + test_parameters = [ + # 4-D + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], + "end": [[8, 2, 2, 3], [12, 2, 2, 5]], + "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]], + "begin_mask": [None, 0, 1, 2, 8], + "end_mask": [None, 0, 1, 2, 8], + }, + # 2-D + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, 0], [1, 0]], + "end": [[2, 3], [2, 2]], + "strides": [None, [1, 1], [2, 2]], + "begin_mask": [None, 0, 1, 2], + "end_mask": [None, 0, 1, 2], + }, + # Negative strides + { + "dtype": [tf.float32, tf.int32, tf.int64], + "index_type": [tf.int32], + "input_shape": [[2, 3]], + "begin": [[0, -1]], + "end": [[2, -3]], + "strides": [[1, -1]], + "begin_mask": [None, 0, 1, 2], + "end_mask": [None, 0, 1, 2], + }, + ] + + def build_graph(parameters): + """Build graph for stride_slice test.""" + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + begin = tf.placeholder( + dtype=parameters["index_type"], + name="begin", + shape=[len(parameters["input_shape"])]) + end = tf.placeholder( + dtype=parameters["index_type"], + name="end", + shape=[len(parameters["input_shape"])]) + strides = ( + tf.placeholder( + dtype=parameters["index_type"], + name="strides", + shape=[len(parameters["input_shape"])]) + if parameters["strides"] is not None else None) + tensors = [input_tensor, begin, end] + if strides is not None: + tensors.append(strides) + out = tf.strided_slice( + input_tensor, + begin, + end, + strides, + begin_mask=parameters["begin_mask"], + end_mask=parameters["end_mask"]) + return tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build inputs for stride_slice test.""" + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + index_type = _TF_TYPE_INFO[parameters["index_type"]][0] + begin_values = np.array(parameters["begin"]).astype(index_type) + end_values = np.array(parameters["end"]).astype(index_type) + stride_values = ( + np.array(parameters["strides"]).astype(index_type) + if parameters["strides"] is not None else None) + values = [input_values, begin_values, end_values] + if stride_values is not None: + values.append(stride_values) + + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_l2_pool(input_tensor, ksize, strides, padding, data_format): """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" return tf.sqrt(tf.nn.avg_pool( @@ -1505,6 +1596,7 @@ def main(unused_args): "transpose.zip": make_transpose_tests, "mean.zip": make_mean_tests, "squeeze.zip": make_squeeze_tests, + "strided_slice.zip": make_strided_slice_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index c8a6e07abd..c29cd85c4d 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -263,6 +263,7 @@ INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(transpose) INSTANTIATE_TESTS(mean) INSTANTIATE_TESTS(squeeze) +INSTANTIATE_TESTS(strided_slice) } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index de4d06be2a..7e8b249b07 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -33,6 +33,10 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 4); const auto& start_array = model->GetArray(op->inputs[1]); if (!start_array.has_shape()) return false; + if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) { + // Only 1-4D arrays are supported for now. + return false; + } const auto& stop_array = model->GetArray(op->inputs[2]); if (!stop_array.has_shape()) return false; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 1947271f55..e8f318cd43 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1179,6 +1179,8 @@ void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "StridedSlice"); + // TODO(soroosh): The 4th input (strides) should be e optional, to be + // consistent with TF. CheckInputsCount(node, tf_import_flags, 4); auto* op = new StridedSliceOperator; diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 54fbba7381..3cda63a8ce 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -738,6 +738,9 @@ struct PadOperator : Operator { // // Inputs: // inputs[0]: required: the input array +// inputs[1]: required: the begin array +// inputs[2]: required: the end array +// inputs[3]: optional: the strides array // // TensorFlow equivalent: StridedSlice struct StridedSliceOperator : Operator { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 0111e1ed92..0c2b570aad 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -617,6 +617,30 @@ class Split : public CustomOperator<TensorFlowSplitOperator> { } }; +class StridedSlice + : public BuiltinOperator<StridedSliceOperator, + ::tflite::StridedSliceOptions, + ::tflite::BuiltinOptions_StridedSliceOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateStridedSliceOptions( + *builder, op.begin_mask, op.end_mask, op.ellipsis_mask, + op.new_axis_mask, op.shrink_axis_mask); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->begin_mask = options.begin_mask(); + op->end_mask = options.end_mask(); + op->ellipsis_mask = options.ellipsis_mask(); + op->new_axis_mask = options.new_axis_mask(); + op->shrink_axis_mask = options.shrink_axis_mask(); + } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -777,6 +801,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); + ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, + OperatorType::kStridedSlice)); // Custom Operators. ops.emplace_back(new Cast("CAST", OperatorType::kCast)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 77c70847d1..de79c70e1b 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -398,6 +398,28 @@ TEST_F(OperatorTest, Squeeze) { EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims); } +TEST_F(OperatorTest, StridedSlice) { + StridedSliceOperator op; + + op.begin_mask = 1; + op.end_mask = 2; + op.ellipsis_mask = 1; + op.new_axis_mask = 1; + op.shrink_axis_mask = 2; + + auto output_toco_op = SerializeAndDeserialize( + GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op); + EXPECT_EQ(op.start_indices, output_toco_op->start_indices); + EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices); + EXPECT_EQ(op.strides, output_toco_op->strides); + EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask); + EXPECT_EQ(op.end_mask, output_toco_op->end_mask); + EXPECT_EQ(op.end_mask, output_toco_op->end_mask); + EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask); + EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask); + EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; |