diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-01-29 10:31:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-29 10:37:43 -0800 |
commit | 730071d0dca35a9e08f3bdc49661ae34d109da74 (patch) | |
tree | a8e25ab416d0501ed5cf3236832388b11287a267 /tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc | |
parent | 1f26c65254268730b7409f517d1ed1b554d01e50 (diff) |
Make TFLite BatchToSpaceND op have parity with TF BatchToSpaceND op.
PiperOrigin-RevId: 183687487
Diffstat (limited to 'tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc | 91 |
1 files changed, 71 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc index 3ec4efbebc..c9152bf967 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -26,37 +26,88 @@ using ::testing::ElementsAreArray; class BatchToSpaceNDOpModel : public SingleOpModel { public: - BatchToSpaceNDOpModel(std::initializer_list<int> input_shape, - std::initializer_list<int> block_shape, - std::initializer_list<int> before_crops, - std::initializer_list<int> after_crops) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, - BuiltinOptions_BatchToSpaceNDOptions, - CreateBatchToSpaceNDOptions( - builder_, builder_.CreateVector<int>(block_shape), - builder_.CreateVector<int>(before_crops), - builder_.CreateVector<int>(after_crops)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list<float> data) { PopulateTensor<float>(input_, data); } + void SetBlockShape(std::initializer_list<int> data) { + PopulateTensor<int>(block_shape_, data); + } + + void SetCrops(std::initializer_list<int> data) { + PopulateTensor<int>(crops_, data); + } + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int block_shape_; + int crops_; int output_; }; -TEST(BatchToSpaceNDOpTest, SimpleTest) { - BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); +// Tests case where block_shape and crops are const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops); +// m.SetInput(input_data); +// m.Invoke(); +class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape, + std::initializer_list<int> block_shape, + std::initializer_list<int> crops) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + crops_ = AddConstInput(TensorType_INT32, crops, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and crops are non-const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetPaddings(crops); +// m.Invoke(); +class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + crops_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + +TEST(BatchToSpaceNDOpTest, SimpleConstTest) { + BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 0, 0}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, @@ -64,7 +115,7 @@ TEST(BatchToSpaceNDOpTest, SimpleTest) { } TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}), "Cannot allocate tensors"); } |