aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/graph_properties_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties_test.cc')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index a53f6414c3..3e44b222fd 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -952,6 +952,39 @@ TEST_F(GraphPropertiesTest, Performance) {
TF_CHECK_OK(properties.InferStatically(false));
}
+TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a =
+ ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ auto shp = ops::Shape(s.WithOpName("shape"), {a});
+
+ Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
+ Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
+ Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});
+
+ Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
+ Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);
+
+ Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
+ Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
+ Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
+ const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
+ const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
+ EXPECT_EQ(2, shape_a.dim_size());
+ EXPECT_EQ(1, shape_o1.dim_size());
+ EXPECT_EQ(1, shape_o2.dim_size());
+ EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
+ EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow