aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-23 17:35:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-23 17:40:21 -0800
commit6dfe00ca4114371ed47c93810219736e3deda2d2 (patch)
tree14262ad0f7b3e212c0e2d84a480c383a4074031f /tensorflow/contrib
parenta3e81ec2892126056ad6c1feb9161bc16c2c2975 (diff)
Support StridedSlice in TFLite for 1D-4D tensors.
PiperOrigin-RevId: 183020501
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h20
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD13
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h51
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc256
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice_test.cc375
-rw-r--r--tensorflow/contrib/lite/model.cc12
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs10
-rw-r--r--[-rwxr-xr-x]tensorflow/contrib/lite/schema/schema_generated.h239
-rw-r--r--tensorflow/contrib/lite/testing/BUILD1
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py92
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc2
-rw-r--r--tensorflow/contrib/lite/toco/model.h3
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc26
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc22
18 files changed, 1107 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 3b43a1fd5d..ab07c58c92 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -88,7 +88,9 @@ typedef struct {
TfLiteFusedActivation activation;
} TfLiteSequenceRNNParams;
-typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams;
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteFullyConnectedParams;
typedef enum {
kTfLiteLshProjectionUnknown = 0,
@@ -96,9 +98,13 @@ typedef enum {
kTfLiteLshProjectionDense = 2,
} TfLiteLSHProjectionType;
-typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams;
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
-typedef struct { float beta; } TfLiteSoftmaxParams;
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
typedef struct {
int axis;
@@ -226,6 +232,14 @@ typedef struct {
int num_squeeze_dims;
} TfLiteSqueezeParams;
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 2d51b8727f..4195e7553c 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -103,6 +103,7 @@ cc_library(
"space_to_batch_nd.cc",
"space_to_depth.cc",
"squeeze.cc",
+ "strided_slice.cc",
"sub.cc",
"svdf.cc",
"transpose.cc",
@@ -518,6 +519,18 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "strided_slice_test",
+ size = "small",
+ srcs = ["strided_slice_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 1d86183d94..03708eb32b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2330,6 +2330,18 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
}
}
+inline bool LoopCondition(int index, int stop, int stride) {
+ return stride > 0 ? index < stop : index > stop;
+}
+
+inline int StartIndex(int start, int stride, int dim, bool masked) {
+ return masked ? (stride > 0 ? 0 : dim - 1) : start;
+}
+
+inline int StopIndex(int stop, int stride, int dim, bool masked) {
+ return masked ? (stride > 0 ? dim : -1) : stop;
+}
+
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
int begin_mask, int end_mask,
@@ -2337,20 +2349,35 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& stops,
const std::vector<int>& strides, T* output_data,
const Dims<4>& output_dims) {
- const int start_b = (begin_mask & 8) ? 0 : starts[3];
- const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
- const int start_h = (begin_mask & 4) ? 0 : starts[2];
- const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
- const int start_w = (begin_mask & 2) ? 0 : starts[1];
- const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
- const int start_d = (begin_mask & 1) ? 0 : starts[0];
- const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
+ TFLITE_DCHECK_EQ(starts.size(), 4);
+ TFLITE_DCHECK_EQ(stops.size(), 4);
+ TFLITE_DCHECK_EQ(strides.size(), 4);
+ const int start_b =
+ StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8);
+ const int stop_b =
+ StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8);
+ const int start_h =
+ StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4);
+ const int stop_h =
+ StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4);
+ const int start_w =
+ StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2);
+ const int stop_w =
+ StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2);
+ const int start_d =
+ StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1);
+ const int stop_d =
+ StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1);
T* out_ptr = output_data;
- for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
- for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
- for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
- for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+ for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]);
+ in_b += strides[3]) {
+ for (int in_h = start_h; LoopCondition(in_h, stop_h, strides[2]);
+ in_h += strides[2]) {
+ for (int in_w = start_w; LoopCondition(in_w, stop_w, strides[1]);
+ in_w += strides[1]) {
+ for (int in_d = start_d; LoopCondition(in_d, stop_d, strides[0]);
+ in_d += strides[0]) {
*out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
}
}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index c9e74cb8d5..f605deaa5b 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -58,6 +58,7 @@ TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_TRANSPOSE();
TfLiteRegistration* Register_MEAN();
TfLiteRegistration* Register_SQUEEZE();
+TfLiteRegistration* Register_STRIDED_SLICE();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -103,6 +104,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_DIV, Register_DIV());
AddBuiltin(BuiltinOperator_SUB, Register_SUB());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
+ AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
}
TfLiteRegistration* BuiltinOpResolver::FindOp(
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
new file mode 100644
index 0000000000..91ba4a9b78
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -0,0 +1,256 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <cmath>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace strided_slice {
+
+enum KernelType {
+ kReference,
+ // TODO(soroosh): add kGenericOptimized
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kBeginTensor = 1;
+constexpr int kEndTensor = 2;
+constexpr int kStridesTensor = 3;
+constexpr int kOutputTensor = 0;
+
+struct StridedSliceContext {
+ StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
+ input = GetInput(context, node, kInputTensor);
+ begin = GetInput(context, node, kBeginTensor);
+ end = GetInput(context, node, kEndTensor);
+ strides = GetInput(context, node, kStridesTensor);
+ output = GetOutput(context, node, kOutputTensor);
+ dims = NumDimensions(input);
+ }
+ TfLiteStridedSliceParams* params;
+ TfLiteTensor* input;
+ TfLiteTensor* begin;
+ TfLiteTensor* end;
+ TfLiteTensor* strides;
+ TfLiteTensor* output;
+ int dims;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ StridedSliceContext op_context(context, node);
+
+ // Ensure validity of input tensor and its dimension
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ // Only INT32 begin/end/strides are supported
+ // TODO(soroosh) add support for INT64
+ TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
+ TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
+ "StridedSlice op only supports 1D-4D input arrays.");
+
+ // TODO(soroosh): add the following missing functionalities
+ TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
+ "ellipsis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
+ "new_axis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0,
+ "shrink_axis_mask is not implemented yet.");
+
+ // TODO(soroosh): optimize for constant tensors to do allocation in Prepare
+ op_context.output->allocation_type = kTfLiteDynamic;
+ return kTfLiteOk;
+} // namespace strided_slice
+
+// TODO(soroosh): consolidate with BytesRequired in interpreter.h
+TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type,
+ const int* dims, int dims_size, size_t* bytes) {
+ // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
+ // MultiplyWithoutOverflow.
+ TF_LITE_ENSURE(context, bytes != nullptr);
+ size_t count = 1;
+ for (int k = 0; k < dims_size; k++) count *= dims[k];
+ switch (type) {
+ case kTfLiteFloat32:
+ *bytes = sizeof(float) * count;
+ break;
+ case kTfLiteInt32:
+ *bytes = sizeof(int32_t) * count;
+ break;
+ case kTfLiteUInt8:
+ *bytes = sizeof(uint8_t) * count;
+ break;
+ case kTfLiteInt64:
+ *bytes = sizeof(int64_t) * count;
+ break;
+ default:
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Reverse order of bits in the mask to match the expected order in kernel
+inline int ReverseMaskBits(int mask, int num_dimensions) {
+ int out = 0;
+ for (int dim = 0; dim < num_dimensions; dim++) {
+ out <<= 1;
+ out += (mask & 1);
+ mask >>= 1;
+ }
+ return out;
+}
+
+// This Op only supports 1-4D cases and since we use the reference 4D
+// implementation, the 1-3D tensors are mapped to 4D.
+const int kMaxDim = 4;
+
+inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
+ return (divisor + (dividend % divisor)) % divisor;
+}
+
+inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
+ return pos_stride
+ ? (index >= dim ? dim
+ : PositiveRemainder(
+ std::min(std::max(index, -dim), dim), dim))
+ : (index < -dim
+ ? -1
+ : PositiveRemainder(
+ std::min(std::max(index, -dim), dim - 1), dim));
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ StridedSliceContext op_context(context, node);
+
+ std::vector<int> starts;
+ std::vector<int> stops;
+ std::vector<int> strides;
+
+ // Determine size of output tensor and map indices
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims);
+ for (int idx = op_context.dims - 1; idx >= 0; --idx) {
+ int dim = op_context.input->dims->data[idx];
+ int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
+ TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
+ bool pos_stride = stride > 0;
+
+ int32_t begin =
+ op_context.params->begin_mask & (1 << idx)
+ ? pos_stride ? 0 : dim - 1
+ : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim,
+ pos_stride);
+ int32_t end =
+ op_context.params->end_mask & (1 << idx)
+ ? pos_stride ? dim : -1
+ : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim,
+ pos_stride);
+
+ // This is valid for both positive and negative strides
+ output_shape->data[idx] = ceil((end - begin) / static_cast<float>(stride));
+ output_shape->data[idx] =
+ output_shape->data[idx] < 0 ? 0 : output_shape->data[idx];
+ starts.emplace_back(begin);
+ stops.emplace_back(end);
+ strides.emplace_back(stride);
+ }
+
+ for (int i = op_context.dims; i < kMaxDim; i++) {
+ starts.emplace_back(0);
+ stops.emplace_back(1);
+ strides.emplace_back(1);
+ }
+
+ TF_LITE_ENSURE_STATUS(
+ context->ResizeTensor(context, op_context.output, output_shape));
+
+ size_t required_bytes;
+ TF_LITE_ENSURE_OK(
+ context,
+ BytesRequired(context, op_context.output->type, output_shape->data,
+ output_shape->size, &required_bytes));
+ TfLiteTensorRealloc(required_bytes, op_context.output);
+
+ op_context.params->begin_mask =
+ ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
+ op_context.params->end_mask =
+ ReverseMaskBits(op_context.params->end_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice( \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), op_context.params->begin_mask, \
+ op_context.params->end_mask, starts, stops, strides, \
+ GetTensorData<data_type>(op_context.output), \
+ GetTensorDims(op_context.output))
+
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, float);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported "
+ "by StridedSlice.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_STRIDED_SLICE
+ return kTfLiteOk;
+}
+
+} // namespace strided_slice
+
+TfLiteRegistration* Register_STRIDED_SLICE_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, strided_slice::Prepare,
+ strided_slice::Eval<strided_slice::kReference>};
+ return &r;
+}
+
+// TODO(soroosh): add optimized
+TfLiteRegistration* Register_STRIDED_SLICE() {
+ return Register_STRIDED_SLICE_REF();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
new file mode 100644
index 0000000000..cd4a364682
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
@@ -0,0 +1,375 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class StridedSliceOpModel : public SingleOpModel {
+ public:
+ StridedSliceOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> begin_shape,
+ std::initializer_list<int> end_shape,
+ std::initializer_list<int> strides_shape, int begin_mask,
+ int end_mask, int ellipsis_mask, int new_axis_mask,
+ int shrink_axis_mask) {
+ input_ = AddInput(TensorType_FLOAT32);
+ begin_ = AddInput(TensorType_INT32);
+ end_ = AddInput(TensorType_INT32);
+ strides_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
+ CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
+ new_axis_mask, shrink_axis_mask)
+ .Union());
+ BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ void SetBegin(std::initializer_list<int32> data) {
+ PopulateTensor<int32>(begin_, data);
+ }
+ void SetEnd(std::initializer_list<int32> data) {
+ PopulateTensor<int32>(end_, data);
+ }
+ void SetStrides(std::initializer_list<int32> data) {
+ PopulateTensor<int32>(strides_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int begin_;
+ int end_;
+ int strides_;
+ int output_;
+};
+
+TEST(StridedSliceOpTest, UnsupportedInputSize) {
+ EXPECT_DEATH(
+ StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
+ "StridedSlice op only supports 1D-4D input arrays.");
+}
+
+TEST(StridedSliceOpTest, UnssupportedArgs) {
+ EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
+ "ellipsis_mask is not implemented yet.");
+ EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
+ "new_axis_mask is not implemented yet.");
+ EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1),
+ "shrink_axis_mask is not implemented yet.");
+}
+
+TEST(StridedSliceOpTest, In1D) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_EmptyOutput) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({10});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeBegin) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-3});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-5});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeEnd) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({-2});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-3});
+ m.SetEnd({5});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+
+TEST(StridedSliceOpTest, In1D_BeginMask) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-2});
+ m.SetEnd({-3});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({5});
+ m.SetEnd({2});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({2});
+ m.SetEnd({-4});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-3});
+ m.SetEnd({-5});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1}));
+}
+
+TEST(StridedSliceOpTest, In1D_EndMask) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+TEST(StridedSliceOpTest, In1D_NegStride) {
+ StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3});
+ m.SetBegin({-1});
+ m.SetEnd({-4});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1}));
+}
+
+TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
+ StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2});
+ m.SetBegin({0});
+ m.SetEnd({2});
+ m.SetStrides({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
+TEST(StridedSliceOpTest, In1D_OddLenStride2) {
+ StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3});
+ m.SetBegin({0});
+ m.SetEnd({3});
+ m.SetStrides({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3}));
+}
+
+TEST(StridedSliceOpTest, In2D_Identity) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+TEST(StridedSliceOpTest, In2D) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetEnd({2, 2});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
+}
+
+TEST(StridedSliceOpTest, In2D_Stride2) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({2, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3}));
+}
+
+TEST(StridedSliceOpTest, In2D_NegStride) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, -1});
+ m.SetEnd({2, -4});
+ m.SetStrides({2, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
+}
+
+TEST(StridedSliceOpTest, In2D_BeginMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetEnd({2, 2});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5}));
+}
+
+TEST(StridedSliceOpTest, In2D_EndMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetEnd({2, 2});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6}));
+}
+
+TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, -2});
+ m.SetEnd({2, -4});
+ m.SetStrides({1, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
+}
+TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, -2});
+ m.SetEnd({2, -3});
+ m.SetStrides({1, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4}));
+}
+
+TEST(StridedSliceOpTest, In3D_Identity) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+}
+
+TEST(StridedSliceOpTest, In3D_NegStride) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({-1, -1, -1});
+ m.SetEnd({-3, -4, -3});
+ m.SetStrides({-1, -1, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}));
+}
+TEST(StridedSliceOpTest, In3D_Strided2) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({2, 2, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 4d8a6d10c8..a01b74f9da 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -618,6 +618,18 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_STRIDED_SLICE: {
+ auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+ params->begin_mask = schema_params->begin_mask();
+ params->end_mask = schema_params->end_mask();
+ params->ellipsis_mask = schema_params->ellipsis_mask();
+ params->new_axis_mask = schema_params->new_axis_mask();
+ params->shrink_axis_mask = schema_params->shrink_axis_mask();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
}
return builtin_data;
}
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 998a7b7614..d5b9319407 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -339,6 +339,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_DIV:
case tflite::BuiltinOperator_SUB:
case tflite::BuiltinOperator_SQUEEZE:
+ case tflite::BuiltinOperator_STRIDED_SLICE:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 6fcd3e51a4..8ddad4d251 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -118,6 +118,7 @@ enum BuiltinOperator : byte {
DIV = 42,
SQUEEZE = 43,
UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
+ STRIDED_SLICE = 45,
}
// Options for the builtin operators.
@@ -153,6 +154,7 @@ union BuiltinOptions {
DivOptions,
SqueezeOptions,
SequenceRNNOptions,
+ StridedSliceOptions,
}
enum Padding : byte { SAME, VALID }
@@ -340,6 +342,14 @@ table SqueezeOptions {
squeeze_dims:[int];
}
+table StridedSliceOptions {
+ begin_mask: int;
+ end_mask: int;
+ ellipsis_mask: int;
+ new_axis_mask: int;
+ shrink_axis_mask: int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 6eb9ae2926..b756891f66 100755..100644
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -120,6 +120,9 @@ struct MeanOptionsT;
struct SqueezeOptions;
struct SqueezeOptionsT;
+struct StridedSliceOptions;
+struct StridedSliceOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -207,11 +210,12 @@ enum BuiltinOperator {
BuiltinOperator_DIV = 42,
BuiltinOperator_SQUEEZE = 43,
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
+ BuiltinOperator_STRIDED_SLICE = 45,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM
+ BuiltinOperator_MAX = BuiltinOperator_STRIDED_SLICE
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[42] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[43] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -254,7 +258,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[42] {
BuiltinOperator_SUB,
BuiltinOperator_DIV,
BuiltinOperator_SQUEEZE,
- BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM};
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOperator_STRIDED_SLICE};
return values;
}
@@ -304,6 +309,7 @@ inline const char **EnumNamesBuiltinOperator() {
"DIV",
"SQUEEZE",
"UNIDIRECTIONAL_SEQUENCE_LSTM",
+ "STRIDED_SLICE",
nullptr};
return names;
}
@@ -346,11 +352,12 @@ enum BuiltinOptions {
BuiltinOptions_DivOptions = 29,
BuiltinOptions_SqueezeOptions = 30,
BuiltinOptions_SequenceRNNOptions = 31,
+ BuiltinOptions_StridedSliceOptions = 32,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SequenceRNNOptions
+ BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -383,7 +390,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] {
BuiltinOptions_SubOptions,
BuiltinOptions_DivOptions,
BuiltinOptions_SqueezeOptions,
- BuiltinOptions_SequenceRNNOptions};
+ BuiltinOptions_SequenceRNNOptions,
+ BuiltinOptions_StridedSliceOptions};
return values;
}
@@ -420,6 +428,7 @@ inline const char **EnumNamesBuiltinOptions() {
"DivOptions",
"SqueezeOptions",
"SequenceRNNOptions",
+ "StridedSliceOptions",
nullptr};
return names;
}
@@ -593,6 +602,11 @@ struct BuiltinOptionsTraits<SequenceRNNOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions;
};
+template <>
+struct BuiltinOptionsTraits<StridedSliceOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -950,6 +964,16 @@ struct BuiltinOptionsUnion {
? reinterpret_cast<const SequenceRNNOptionsT *>(value)
: nullptr;
}
+ StridedSliceOptionsT *AsStridedSliceOptions() {
+ return type == BuiltinOptions_StridedSliceOptions
+ ? reinterpret_cast<StridedSliceOptionsT *>(value)
+ : nullptr;
+ }
+ const StridedSliceOptionsT *AsStridedSliceOptions() const {
+ return type == BuiltinOptions_StridedSliceOptions
+ ? reinterpret_cast<const StridedSliceOptionsT *>(value)
+ : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -3532,6 +3556,111 @@ flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct StridedSliceOptionsT : public flatbuffers::NativeTable {
+ typedef StridedSliceOptions TableType;
+ int32_t begin_mask;
+ int32_t end_mask;
+ int32_t ellipsis_mask;
+ int32_t new_axis_mask;
+ int32_t shrink_axis_mask;
+ StridedSliceOptionsT()
+ : begin_mask(0),
+ end_mask(0),
+ ellipsis_mask(0),
+ new_axis_mask(0),
+ shrink_axis_mask(0) {}
+};
+
+struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS
+ : private flatbuffers::Table {
+ typedef StridedSliceOptionsT NativeTableType;
+ enum {
+ VT_BEGIN_MASK = 4,
+ VT_END_MASK = 6,
+ VT_ELLIPSIS_MASK = 8,
+ VT_NEW_AXIS_MASK = 10,
+ VT_SHRINK_AXIS_MASK = 12
+ };
+ int32_t begin_mask() const { return GetField<int32_t>(VT_BEGIN_MASK, 0); }
+ int32_t end_mask() const { return GetField<int32_t>(VT_END_MASK, 0); }
+ int32_t ellipsis_mask() const {
+ return GetField<int32_t>(VT_ELLIPSIS_MASK, 0);
+ }
+ int32_t new_axis_mask() const {
+ return GetField<int32_t>(VT_NEW_AXIS_MASK, 0);
+ }
+ int32_t shrink_axis_mask() const {
+ return GetField<int32_t>(VT_SHRINK_AXIS_MASK, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_BEGIN_MASK) &&
+ VerifyField<int32_t>(verifier, VT_END_MASK) &&
+ VerifyField<int32_t>(verifier, VT_ELLIPSIS_MASK) &&
+ VerifyField<int32_t>(verifier, VT_NEW_AXIS_MASK) &&
+ VerifyField<int32_t>(verifier, VT_SHRINK_AXIS_MASK) &&
+ verifier.EndTable();
+ }
+ StridedSliceOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ StridedSliceOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<StridedSliceOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct StridedSliceOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_begin_mask(int32_t begin_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0);
+ }
+ void add_end_mask(int32_t end_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_END_MASK, end_mask, 0);
+ }
+ void add_ellipsis_mask(int32_t ellipsis_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_ELLIPSIS_MASK,
+ ellipsis_mask, 0);
+ }
+ void add_new_axis_mask(int32_t new_axis_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_NEW_AXIS_MASK,
+ new_axis_mask, 0);
+ }
+ void add_shrink_axis_mask(int32_t shrink_axis_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_SHRINK_AXIS_MASK,
+ shrink_axis_mask, 0);
+ }
+ explicit StridedSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StridedSliceOptionsBuilder &operator=(const StridedSliceOptionsBuilder &);
+ flatbuffers::Offset<StridedSliceOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StridedSliceOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0,
+ int32_t end_mask = 0, int32_t ellipsis_mask = 0, int32_t new_axis_mask = 0,
+ int32_t shrink_axis_mask = 0) {
+ StridedSliceOptionsBuilder builder_(_fbb);
+ builder_.add_shrink_axis_mask(shrink_axis_mask);
+ builder_.add_new_axis_mask(new_axis_mask);
+ builder_.add_ellipsis_mask(ellipsis_mask);
+ builder_.add_end_mask(end_mask);
+ builder_.add_begin_mask(begin_mask);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -3816,6 +3945,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
? static_cast<const SequenceRNNOptions *>(builtin_options())
: nullptr;
}
+ const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const {
+ return builtin_options_type() == BuiltinOptions_StridedSliceOptions
+ ? static_cast<const StridedSliceOptions *>(builtin_options())
+ : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -4023,6 +4157,12 @@ Operator::builtin_options_as<SequenceRNNOptions>() const {
return builtin_options_as_SequenceRNNOptions();
}
+template <>
+inline const StridedSliceOptions *
+Operator::builtin_options_as<StridedSliceOptions>() const {
+ return builtin_options_as_StridedSliceOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -4962,11 +5102,11 @@ inline void SequenceRNNOptions::UnPackTo(
{
auto _e = time_major();
_o->time_major = _e;
- }
+ };
{
auto _e = fused_activation_function();
_o->fused_activation_function = _e;
- }
+ };
}
inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(
@@ -6040,6 +6180,67 @@ inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims);
}
+inline StridedSliceOptionsT *StridedSliceOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new StridedSliceOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void StridedSliceOptions::UnPackTo(
+ StridedSliceOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = begin_mask();
+ _o->begin_mask = _e;
+ };
+ {
+ auto _e = end_mask();
+ _o->end_mask = _e;
+ };
+ {
+ auto _e = ellipsis_mask();
+ _o->ellipsis_mask = _e;
+ };
+ {
+ auto _e = new_axis_mask();
+ _o->new_axis_mask = _e;
+ };
+ {
+ auto _e = shrink_axis_mask();
+ _o->shrink_axis_mask = _e;
+ };
+}
+
+inline flatbuffers::Offset<StridedSliceOptions> StridedSliceOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateStridedSliceOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const StridedSliceOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _begin_mask = _o->begin_mask;
+ auto _end_mask = _o->end_mask;
+ auto _ellipsis_mask = _o->ellipsis_mask;
+ auto _new_axis_mask = _o->new_axis_mask;
+ auto _shrink_axis_mask = _o->shrink_axis_mask;
+ return tflite::CreateStridedSliceOptions(_fbb, _begin_mask, _end_mask,
+ _ellipsis_mask, _new_axis_mask,
+ _shrink_axis_mask);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
@@ -6552,6 +6753,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default:
return false;
}
@@ -6700,6 +6905,10 @@ inline void *BuiltinOptionsUnion::UnPack(
auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default:
return nullptr;
}
@@ -6835,6 +7044,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
auto ptr = reinterpret_cast<const SequenceRNNOptionsT *>(value);
return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<const StridedSliceOptionsT *>(value);
+ return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union();
+ }
default:
return 0;
}
@@ -6985,6 +7198,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
*reinterpret_cast<SequenceRNNOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_StridedSliceOptions: {
+ value = new StridedSliceOptionsT(
+ *reinterpret_cast<StridedSliceOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -7147,6 +7365,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<StridedSliceOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default:
break;
}
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 933da11353..50e8ca75f8 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -46,6 +46,7 @@ gen_zipped_test_files(
"space_to_batch_nd.zip",
"space_to_depth.zip",
"squeeze.zip",
+ "strided_slice.zip",
"sub.zip",
"transpose.zip",
],
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 56e4dfc7a2..6204471e52 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1447,6 +1447,97 @@ def make_squeeze_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_strided_slice_tests(zip_path):
+ """Make a set of tests to do strided_slice."""
+
+ # TODO(soroosh): add test/support for uint8.
+ test_parameters = [
+ # 4-D
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32],
+ "input_shape": [[12, 2, 2, 5]],
+ "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
+ "end": [[8, 2, 2, 3], [12, 2, 2, 5]],
+ "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]],
+ "begin_mask": [None, 0, 1, 2, 8],
+ "end_mask": [None, 0, 1, 2, 8],
+ },
+ # 2-D
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32],
+ "input_shape": [[2, 3]],
+ "begin": [[0, 0], [1, 0]],
+ "end": [[2, 3], [2, 2]],
+ "strides": [None, [1, 1], [2, 2]],
+ "begin_mask": [None, 0, 1, 2],
+ "end_mask": [None, 0, 1, 2],
+ },
+ # Negative strides
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32],
+ "input_shape": [[2, 3]],
+ "begin": [[0, -1]],
+ "end": [[2, -3]],
+ "strides": [[1, -1]],
+ "begin_mask": [None, 0, 1, 2],
+ "end_mask": [None, 0, 1, 2],
+ },
+ ]
+
+ def build_graph(parameters):
+ """Build graph for stride_slice test."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ begin = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="begin",
+ shape=[len(parameters["input_shape"])])
+ end = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="end",
+ shape=[len(parameters["input_shape"])])
+ strides = (
+ tf.placeholder(
+ dtype=parameters["index_type"],
+ name="strides",
+ shape=[len(parameters["input_shape"])])
+ if parameters["strides"] is not None else None)
+ tensors = [input_tensor, begin, end]
+ if strides is not None:
+ tensors.append(strides)
+ out = tf.strided_slice(
+ input_tensor,
+ begin,
+ end,
+ strides,
+ begin_mask=parameters["begin_mask"],
+ end_mask=parameters["end_mask"])
+ return tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build inputs for stride_slice test."""
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ index_type = _TF_TYPE_INFO[parameters["index_type"]][0]
+ begin_values = np.array(parameters["begin"]).astype(index_type)
+ end_values = np.array(parameters["end"]).astype(index_type)
+ stride_values = (
+ np.array(parameters["strides"]).astype(index_type)
+ if parameters["strides"] is not None else None)
+ values = [input_values, begin_values, end_values]
+ if stride_values is not None:
+ values.append(stride_values)
+
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
"""Given an input perform a sequence of TensorFlow ops to produce l2pool."""
return tf.sqrt(tf.nn.avg_pool(
@@ -1505,6 +1596,7 @@ def main(unused_args):
"transpose.zip": make_transpose_tests,
"mean.zip": make_mean_tests,
"squeeze.zip": make_squeeze_tests,
+ "strided_slice.zip": make_strided_slice_tests,
}
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index c8a6e07abd..c29cd85c4d 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -263,6 +263,7 @@ INSTANTIATE_TESTS(div)
INSTANTIATE_TESTS(transpose)
INSTANTIATE_TESTS(mean)
INSTANTIATE_TESTS(squeeze)
+INSTANTIATE_TESTS(strided_slice)
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index de4d06be2a..7e8b249b07 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -33,6 +33,10 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 4);
const auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
+ if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
+ // Only 1-4D arrays are supported for now.
+ return false;
+ }
const auto& stop_array = model->GetArray(op->inputs[2]);
if (!stop_array.has_shape()) return false;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 1947271f55..e8f318cd43 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1179,6 +1179,8 @@ void ConvertStridedSliceOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "StridedSlice");
+ // TODO(soroosh): The 4th input (strides) should be e optional, to be
+ // consistent with TF.
CheckInputsCount(node, tf_import_flags, 4);
auto* op = new StridedSliceOperator;
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 54fbba7381..3cda63a8ce 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -738,6 +738,9 @@ struct PadOperator : Operator {
//
// Inputs:
// inputs[0]: required: the input array
+// inputs[1]: required: the begin array
+// inputs[2]: required: the end array
+// inputs[3]: optional: the strides array
//
// TensorFlow equivalent: StridedSlice
struct StridedSliceOperator : Operator {
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 0111e1ed92..0c2b570aad 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -617,6 +617,30 @@ class Split : public CustomOperator<TensorFlowSplitOperator> {
}
};
+class StridedSlice
+ : public BuiltinOperator<StridedSliceOperator,
+ ::tflite::StridedSliceOptions,
+ ::tflite::BuiltinOptions_StridedSliceOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateStridedSliceOptions(
+ *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
+ op.new_axis_mask, op.shrink_axis_mask);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->begin_mask = options.begin_mask();
+ op->end_mask = options.end_mask();
+ op->ellipsis_mask = options.ellipsis_mask();
+ op->new_axis_mask = options.new_axis_mask();
+ op->shrink_axis_mask = options.shrink_axis_mask();
+ }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -777,6 +801,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
ops.emplace_back(
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
+ ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
+ OperatorType::kStridedSlice));
// Custom Operators.
ops.emplace_back(new Cast("CAST", OperatorType::kCast));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 77c70847d1..de79c70e1b 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -398,6 +398,28 @@ TEST_F(OperatorTest, Squeeze) {
EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
}
+TEST_F(OperatorTest, StridedSlice) {
+ StridedSliceOperator op;
+
+ op.begin_mask = 1;
+ op.end_mask = 2;
+ op.ellipsis_mask = 1;
+ op.new_axis_mask = 1;
+ op.shrink_axis_mask = 2;
+
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
+ EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
+ EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
+ EXPECT_EQ(op.strides, output_toco_op->strides);
+ EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
+ EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
+ EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
+ EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
+ EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
+ EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";