aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-14 20:39:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-14 20:44:02 -0700
commit9037e241de1e64044ff55ab539ccc1fb013c178a (patch)
treef7b8bda19a5efdd57f99ce9cd7b0bf6fed211628 /tensorflow/core/grappler/utils_test.cc
parent357cd4b8b2f960520fc57b6cfbf41117a2a20fc7 (diff)
Enable Add/AddN tree rewrite for symbolically equal shapes.
1) Rewrite a tree of Add/AddN ops with a single AddN, if all shapes are symbolically equal 2) Lookup shape properties using GraphProperties instead of direct access to Node attributes PiperOrigin-RevId: 189131726
Diffstat (limited to 'tensorflow/core/grappler/utils_test.cc')
-rw-r--r--tensorflow/core/grappler/utils_test.cc41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index eabce5b5ee..49a1996d25 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) {
EXPECT_EQ("gnu", foo.input(1));
}
+TEST_F(UtilsTest, NumNonControlOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // *) Round node has control dependency edge from Add, which
+ // is not on this scheme (ASCII graphics limitation).
+ //
+ // *Round [Sqrt, Shape]
+ // | |
+ // | ctrl |
+ // Mul ------> Add
+ // / \ / \
+ // x y a b
+ auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+ auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT);
+ auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT);
+
+ auto mul = ops::Multiply(s.WithOpName("mul"), x, y);
+ auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b);
+
+ auto shape = ops::Shape(s.WithOpName("shape"), add);
+ auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add);
+
+ auto round =
+ ops::Round(s.WithOpName("round").WithControlDependencies(add), mul);
+
+ GraphDef graph;
+ TF_CHECK_OK(s.ToGraphDef(&graph));
+ NodeMap node_map(&graph);
+
+ const NodeDef* add_node = node_map.GetNode("add");
+ ASSERT_TRUE(add_node != nullptr);
+
+ // [a, b] are only non-control inputs
+ EXPECT_EQ(2, NumNonControlInputs(*add_node));
+ // [sqrt, shape] are non control outputs
+ EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
+ // sqrt is the only data output
+ EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
+}
+
TEST_F(UtilsTest, DeleteNodes) {}
} // namespace