aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc72
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 7271a29319..9e579098ef 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -1126,5 +1126,77 @@ TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
}
+
+TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(true);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({1, 1, 1, 1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({1, 1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({10, 20}, y);
+
+ unknown_shapes = false;
+ TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ EXPECT_EQ(4, z.dim_size());
+ ExpectTensorShape({10, 20, 1, 1}, z);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ x.add_dim()->set_size(-1);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({10, 20, 1, 20}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ x.add_dim()->set_size(30);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({10, 20}, y);
+ }
+}
} // end namespace grappler
} // end namespace tensorflow