diff options
author | 2018-03-14 20:39:10 -0700 | |
---|---|---|
committer | 2018-03-14 20:44:02 -0700 | |
commit | 9037e241de1e64044ff55ab539ccc1fb013c178a (patch) | |
tree | f7b8bda19a5efdd57f99ce9cd7b0bf6fed211628 /tensorflow/core/grappler/utils_test.cc | |
parent | 357cd4b8b2f960520fc57b6cfbf41117a2a20fc7 (diff) |
Enable Add/AddN tree rewrite for symbolically equal shapes.
1) Rewrite a tree of Add/AddN ops with a single AddN,
if all shapes are symbolically equal
2) Lookup shape properties using GraphProperties instead
of direct access to Node attributes
PiperOrigin-RevId: 189131726
Diffstat (limited to 'tensorflow/core/grappler/utils_test.cc')
-rw-r--r-- | tensorflow/core/grappler/utils_test.cc | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index eabce5b5ee..49a1996d25 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) { EXPECT_EQ("gnu", foo.input(1)); } +TEST_F(UtilsTest, NumNonControlOutputs) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + // *) Round node has control dependency edge from Add, which + // is not on this scheme (ASCII graphics limitation). + // + // *Round [Sqrt, Shape] + // | | + // | ctrl | + // Mul ------> Add + // / \ / \ + // x y a b + auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT); + auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT); + auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT); + + auto mul = ops::Multiply(s.WithOpName("mul"), x, y); + auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b); + + auto shape = ops::Shape(s.WithOpName("shape"), add); + auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add); + + auto round = + ops::Round(s.WithOpName("round").WithControlDependencies(add), mul); + + GraphDef graph; + TF_CHECK_OK(s.ToGraphDef(&graph)); + NodeMap node_map(&graph); + + const NodeDef* add_node = node_map.GetNode("add"); + ASSERT_TRUE(add_node != nullptr); + + // [a, b] are only non-control inputs + EXPECT_EQ(2, NumNonControlInputs(*add_node)); + // [sqrt, shape] are non control outputs + EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map)); + // sqrt is the only data output + EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map)); +} + TEST_F(UtilsTest, DeleteNodes) {} } // namespace |