diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties_test.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties_test.cc | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index a53f6414c3..3e44b222fd 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -952,6 +952,39 @@ TEST_F(GraphPropertiesTest, Performance) { TF_CHECK_OK(properties.InferStatically(false)); } +TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = + ops::Placeholder(s.WithOpName("a"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + auto shp = ops::Shape(s.WithOpName("shape"), {a}); + + Output index1 = ops::Const(s.WithOpName("index1"), 0, {1}); + Output index2 = ops::Const(s.WithOpName("index2"), 1, {1}); + Output index3 = ops::Const(s.WithOpName("index3"), 2, {1}); + + Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2); + Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2); + + Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); + Output o1 = ops::Fill(s.WithOpName("o1"), b, zero); + Output o2 = ops::Fill(s.WithOpName("o2"), c, zero); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); + const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape(); + const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape(); + EXPECT_EQ(2, shape_a.dim_size()); + EXPECT_EQ(1, shape_o1.dim_size()); + EXPECT_EQ(1, shape_o2.dim_size()); + EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size()); + EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size()); +} + } // namespace } // namespace grappler } // namespace tensorflow |