diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 7271a29319..9e579098ef 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -1126,5 +1126,77 @@ TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) { EXPECT_EQ(0, costs.num_ops_with_unknown_shapes); } } + +TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) { + { + TensorShapeProto x; + x.set_unknown_rank(true); + bool unknown_shapes = false; + TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes); + EXPECT_TRUE(unknown_shapes); + ExpectTensorShape({1, 1, 1, 1}, y); + } + + { + TensorShapeProto x; + x.set_unknown_rank(false); + bool unknown_shapes = false; + TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes); + EXPECT_FALSE(unknown_shapes); + ExpectTensorShape({1}, y); + } + + { + TensorShapeProto x; + x.set_unknown_rank(false); + bool unknown_shapes = false; + TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes); + EXPECT_FALSE(unknown_shapes); + ExpectTensorShape({1, 1}, y); + } + + { + TensorShapeProto x; + x.set_unknown_rank(false); + x.add_dim()->set_size(10); + x.add_dim()->set_size(20); + bool unknown_shapes = false; + TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes); + EXPECT_FALSE(unknown_shapes); + ExpectTensorShape({10, 20}, y); + + unknown_shapes = false; + TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes); + EXPECT_TRUE(unknown_shapes); + EXPECT_EQ(4, z.dim_size()); + ExpectTensorShape({10, 20, 1, 1}, z); + } + + { + TensorShapeProto x; + x.set_unknown_rank(false); + x.add_dim()->set_size(10); + x.add_dim()->set_size(20); + x.add_dim()->set_size(-1); + x.add_dim()->set_size(20); + bool unknown_shapes = false; + TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes); + EXPECT_TRUE(unknown_shapes); + ExpectTensorShape({10, 20, 1, 20}, y); + } + + { + TensorShapeProto x; + x.set_unknown_rank(false); + x.add_dim()->set_size(10); + x.add_dim()->set_size(20); + x.add_dim()->set_size(30); + x.add_dim()->set_size(20); + bool unknown_shapes = false; + TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes); + EXPECT_TRUE(unknown_shapes); + ExpectTensorShape({10, 20}, y); + } +} } // end namespace grappler } // end namespace tensorflow |