diff options
author | 2018-08-20 12:23:46 -0700 | |
---|---|---|
committer | 2018-08-20 12:28:35 -0700 | |
commit | d0377209b9a39a3e00483609e16f3d8efe72b3f9 (patch) | |
tree | 816822c145db303e2d80ff5607f83cb40cc1b1be /tensorflow/contrib/lite/kernels/pack_test.cc | |
parent | 1a6d7f5acd50ed23c38f14e11e563f771a596656 (diff) |
Added uint8 support for Pack.
PiperOrigin-RevId: 209463575
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pack_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/pack_test.cc | 40 |
1 files changed, 37 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc index 485a50ad3a..c70dbd2764 100644 --- a/tensorflow/contrib/lite/kernels/pack_test.cc +++ b/tensorflow/contrib/lite/kernels/pack_test.cc @@ -51,6 +51,7 @@ class PackOpModel : public SingleOpModel { int output_; }; +// float32 tests. TEST(PackOpTest, FloatThreeInputs) { PackOpModel<float> model({TensorType_FLOAT32, {2}}, 0, 3); model.SetInput(0, {1, 4}); @@ -81,7 +82,8 @@ TEST(PackOpTest, FloatMultilDimensions) { ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } -TEST(PackOpTest, IntThreeInputs) { +// int32 tests. +TEST(PackOpTest, Int32ThreeInputs) { PackOpModel<int32_t> model({TensorType_INT32, {2}}, 0, 3); model.SetInput(0, {1, 4}); model.SetInput(1, {2, 5}); @@ -91,7 +93,7 @@ TEST(PackOpTest, IntThreeInputs) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); } -TEST(PackOpTest, IntThreeInputsDifferentAxis) { +TEST(PackOpTest, Int32ThreeInputsDifferentAxis) { PackOpModel<int32_t> model({TensorType_INT32, {2}}, 1, 3); model.SetInput(0, {1, 4}); model.SetInput(1, {2, 5}); @@ -101,7 +103,7 @@ TEST(PackOpTest, IntThreeInputsDifferentAxis) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(PackOpTest, IntMultilDimensions) { +TEST(PackOpTest, Int32MultilDimensions) { PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2); model.SetInput(0, {1, 2, 3, 4, 5, 6}); model.SetInput(1, {7, 8, 9, 10, 11, 12}); @@ -110,6 +112,38 @@ TEST(PackOpTest, IntMultilDimensions) { EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } + +// uint8 +TEST(PackOpTest, Uint8ThreeInputs) { + PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 0, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) { + PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(PackOpTest, Uint8MultilDimensions) { + PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2); + model.SetInput(0, {1, 2, 3, 4, 5, 6}); + model.SetInput(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + } // namespace } // namespace tflite |