aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-01-29 10:31:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 10:37:43 -0800
commit730071d0dca35a9e08f3bdc49661ae34d109da74 (patch)
treea8e25ab416d0501ed5cf3236832388b11287a267 /tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
parent1f26c65254268730b7409f517d1ed1b554d01e50 (diff)
Make TFLite BatchToSpaceND op have parity with TF BatchToSpaceND op.
PiperOrigin-RevId: 183687487
Diffstat (limited to 'tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc91
1 files changed, 71 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
index 3ec4efbebc..c9152bf967 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
@@ -26,37 +26,88 @@ using ::testing::ElementsAreArray;
class BatchToSpaceNDOpModel : public SingleOpModel {
public:
- BatchToSpaceNDOpModel(std::initializer_list<int> input_shape,
- std::initializer_list<int> block_shape,
- std::initializer_list<int> before_crops,
- std::initializer_list<int> after_crops) {
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
- BuiltinOptions_BatchToSpaceNDOptions,
- CreateBatchToSpaceNDOptions(
- builder_, builder_.CreateVector<int>(block_shape),
- builder_.CreateVector<int>(before_crops),
- builder_.CreateVector<int>(after_crops))
- .Union());
- BuildInterpreter({input_shape});
- }
-
void SetInput(std::initializer_list<float> data) {
PopulateTensor<float>(input_, data);
}
+ void SetBlockShape(std::initializer_list<int> data) {
+ PopulateTensor<int>(block_shape_, data);
+ }
+
+ void SetCrops(std::initializer_list<int> data) {
+ PopulateTensor<int>(crops_, data);
+ }
+
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
+ int block_shape_;
+ int crops_;
int output_;
};
-TEST(BatchToSpaceNDOpTest, SimpleTest) {
- BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0});
+// Tests case where block_shape and crops are const tensors.
+//
+// Example usage is as follows:
+// BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops);
+// m.SetInput(input_data);
+// m.Invoke();
+class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
+ public:
+ BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> block_shape,
+ std::initializer_list<int> crops) {
+ input_ = AddInput(TensorType_FLOAT32);
+ block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
+ crops_ = AddConstInput(TensorType_INT32, crops, {2, 2});
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
+ BuiltinOptions_BatchToSpaceNDOptions,
+ CreateBatchToSpaceNDOptions(builder_).Union());
+ BuildInterpreter({input_shape});
+ }
+};
+
+// Tests case where block_shape and crops are non-const tensors.
+//
+// Example usage is as follows:
+// BatchToSpaceNDOpDynamicModel m(input_shape);
+// m.SetInput(input_data);
+// m.SetBlockShape(block_shape);
+// m.SetPaddings(crops);
+// m.Invoke();
+class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
+ public:
+ BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ block_shape_ = AddInput(TensorType_INT32);
+ crops_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
+ BuiltinOptions_BatchToSpaceNDOptions,
+ CreateBatchToSpaceNDOptions(builder_).Union());
+ BuildInterpreter({input_shape, {2}, {2, 2}});
+ }
+};
+
+TEST(BatchToSpaceNDOpTest, SimpleConstTest) {
+ BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
+ 4, 8, 11, 15, 12, 16}));
+}
+
+TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) {
+ BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.SetBlockShape({2, 2});
+ m.SetCrops({0, 0, 0, 0});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
@@ -64,7 +115,7 @@ TEST(BatchToSpaceNDOpTest, SimpleTest) {
}
TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
- EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}),
+ EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}),
"Cannot allocate tensors");
}