diff options
author | 2018-04-03 10:32:57 -0700 | |
---|---|---|
committer | 2018-04-03 10:35:29 -0700 | |
commit | f005999d571536b859229229e4487cf749d9a786 (patch) | |
tree | af583f2229b7a4c143969a25ac035a888514b7ea /tensorflow/contrib/lite/kernels/strided_slice_test.cc | |
parent | 63a3be25482121e34ad3ac399e1ece4b24656318 (diff) |
Add 8bit strided slice op to tflite
PiperOrigin-RevId: 191461696
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/strided_slice_test.cc | 113 |
1 files changed, 64 insertions, 49 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index 5c98c5f431..22d7b097cb 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -24,6 +24,8 @@ namespace { using ::int32; using ::testing::ElementsAreArray; +template <typename input_type = float, + TensorType tensor_input_type = TensorType_FLOAT32> class StridedSliceOpModel : public SingleOpModel { public: StridedSliceOpModel(std::initializer_list<int> input_shape, @@ -32,11 +34,11 @@ class StridedSliceOpModel : public SingleOpModel { std::initializer_list<int> strides_shape, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask) { - input_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(tensor_input_type); begin_ = AddInput(TensorType_INT32); end_ = AddInput(TensorType_INT32); strides_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(tensor_input_type); SetBuiltinOp( BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, @@ -45,8 +47,8 @@ class StridedSliceOpModel : public SingleOpModel { BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); } - void SetInput(std::initializer_list<float> data) { - PopulateTensor<float>(input_, data); + void SetInput(std::initializer_list<input_type> data) { + PopulateTensor<input_type>(input_, data); } void SetBegin(std::initializer_list<int32> data) { PopulateTensor<int32>(begin_, data); @@ -58,7 +60,9 @@ class StridedSliceOpModel : public SingleOpModel { PopulateTensor<int32>(strides_, data); } - std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<input_type> GetOutput() { + return ExtractVector<input_type>(output_); + } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } private: @@ -71,19 +75,19 @@ class StridedSliceOpModel : public SingleOpModel { TEST(StridedSliceOpTest, UnsupportedInputSize) { EXPECT_DEATH( - StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), + StridedSliceOpModel<>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), "StridedSlice op only supports 1D-4D input arrays."); } TEST(StridedSliceOpTest, UnssupportedArgs) { - EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), + EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), "ellipsis_mask is not implemented yet."); - EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), + EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), "new_axis_mask is not implemented yet."); } TEST(StridedSliceOpTest, In1D) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -94,7 +98,7 @@ TEST(StridedSliceOpTest, In1D) { } TEST(StridedSliceOpTest, In1D_EmptyOutput) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({10}); m.SetEnd({3}); @@ -104,7 +108,7 @@ TEST(StridedSliceOpTest, In1D_EmptyOutput) { } TEST(StridedSliceOpTest, In1D_NegativeBegin) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({3}); @@ -115,7 +119,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBegin) { } TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-5}); m.SetEnd({3}); @@ -126,7 +130,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { } TEST(StridedSliceOpTest, In1D_NegativeEnd) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({-2}); @@ -137,7 +141,7 @@ TEST(StridedSliceOpTest, In1D_NegativeEnd) { } TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({5}); @@ -148,7 +152,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { } TEST(StridedSliceOpTest, In1D_BeginMask) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -159,7 +163,7 @@ TEST(StridedSliceOpTest, In1D_BeginMask) { } TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-2}); m.SetEnd({-3}); @@ -170,7 +174,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { } TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({5}); m.SetEnd({2}); @@ -181,7 +185,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { } TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({2}); m.SetEnd({-4}); @@ -192,7 +196,7 @@ TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { } TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({-5}); @@ -203,7 +207,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { } TEST(StridedSliceOpTest, In1D_EndMask) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -214,7 +218,7 @@ TEST(StridedSliceOpTest, In1D_EndMask) { } TEST(StridedSliceOpTest, In1D_NegStride) { - StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); m.SetBegin({-1}); m.SetEnd({-4}); @@ -225,7 +229,7 @@ TEST(StridedSliceOpTest, In1D_NegStride) { } TEST(StridedSliceOpTest, In1D_EvenLenStride2) { - StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2}); m.SetBegin({0}); m.SetEnd({2}); @@ -236,7 +240,7 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) { } TEST(StridedSliceOpTest, In1D_OddLenStride2) { - StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); m.SetBegin({0}); m.SetEnd({3}); @@ -247,7 +251,7 @@ TEST(StridedSliceOpTest, In1D_OddLenStride2) { } TEST(StridedSliceOpTest, In2D_Identity) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -258,7 +262,7 @@ TEST(StridedSliceOpTest, In2D_Identity) { } TEST(StridedSliceOpTest, In2D) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -269,7 +273,7 @@ TEST(StridedSliceOpTest, In2D) { } TEST(StridedSliceOpTest, In2D_Stride2) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -280,7 +284,7 @@ TEST(StridedSliceOpTest, In2D_Stride2) { } TEST(StridedSliceOpTest, In2D_NegStride) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -1}); m.SetEnd({2, -4}); @@ -291,7 +295,7 @@ TEST(StridedSliceOpTest, In2D_NegStride) { } TEST(StridedSliceOpTest, In2D_BeginMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -302,7 +306,7 @@ TEST(StridedSliceOpTest, In2D_BeginMask) { } TEST(StridedSliceOpTest, In2D_EndMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -313,7 +317,7 @@ TEST(StridedSliceOpTest, In2D_EndMask) { } TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -2}); m.SetEnd({2, -4}); @@ -324,7 +328,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { } TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -2}); m.SetEnd({2, -3}); @@ -335,7 +339,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { } TEST(StridedSliceOpTest, In3D_Identity) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -347,7 +351,7 @@ TEST(StridedSliceOpTest, In3D_Identity) { } TEST(StridedSliceOpTest, In3D_NegStride) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({-1, -1, -1}); m.SetEnd({-3, -4, -3}); @@ -359,7 +363,7 @@ TEST(StridedSliceOpTest, In3D_NegStride) { } TEST(StridedSliceOpTest, In3D_Strided2) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -370,7 +374,7 @@ TEST(StridedSliceOpTest, In3D_Strided2) { } TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -381,7 +385,7 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { } TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({2}); m.SetEnd({1}); @@ -392,7 +396,7 @@ TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { } TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -403,7 +407,7 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { } TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { - StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({-2}); m.SetEnd({-3}); @@ -414,7 +418,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { } TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -425,7 +429,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { } TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -436,7 +440,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { } TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -447,7 +451,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -458,7 +462,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -469,7 +473,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -480,7 +484,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -491,7 +495,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -502,7 +506,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -513,7 +517,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); + StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -525,7 +529,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { // This tests catches a very subtle bug that was fixed by cl/188403234. TEST(StridedSliceOpTest, RunTwice) { - StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); auto setup_inputs = [&m]() { m.SetInput({1, 2, 3, 4, 5, 6}); @@ -544,6 +548,17 @@ TEST(StridedSliceOpTest, RunTwice) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); } +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { + StridedSliceOpModel<uint8, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, + 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} } // namespace } // namespace tflite |