diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-28 09:36:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-28 09:41:22 -0700 |
commit | 17c5907a0f35cc2644737478137ed2b558998da9 (patch) | |
tree | 13de773feaabb8a6a10136e0956340124e5160e5 | |
parent | 184a5fd8da87f79f46c75c716a802863aee28a02 (diff) |
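Only wait for one non-control input for a Merge node if it is part of a while
loop, i.e. if it has a back edge from a NextIteration node; all other Merge
nodes now wait for all of their inputs. This is to enable the propagation of
shapes for conditionals, which also include Merge nodes.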
PiperOrigin-RevId: 160417770
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 39 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 18 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties_test.cc | 365 |
3 files changed, 415 insertions, 7 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 38a780dfac..a2929d0210 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -45,6 +45,11 @@ inline bool IsMerge(const NodeDef& node_def) { return node_def.op() == "Merge" || node_def.op() == "RefMerge"; } +inline bool IsNextIteration(const NodeDef& node_def) { + return node_def.op() == "NextIteration" || + node_def.op() == "RefNextIteration"; +} + bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { using ::tensorflow::strings::Scanner; return Scanner(s) @@ -365,24 +370,54 @@ Status GraphConstructor::BuildNodeIndex() { return Status::OK(); } +std::unordered_set<string> GetNextIterationNodes( + const GraphConstructor::NodeDefSlice& node_defs) { + std::unordered_set<string> next_iteration_nodes; + + for (int n = 0; n < node_defs.size(); ++n) { + const NodeDef& node_def = *node_defs[n]; + if (IsNextIteration(node_def)) { + next_iteration_nodes.insert(node_def.name()); + } + } + + return next_iteration_nodes; +} + Status GraphConstructor::InitFromEdges() { const int num_nodes = node_defs_.size(); pending_count_.reserve(num_nodes); outputs_.resize(num_nodes); + std::unordered_set<string> next_iteration_nodes_ = + GetNextIterationNodes(node_defs_); // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { const NodeDef& node_def = *node_defs_[n]; if (IsMerge(node_def)) { - // for merge only wait for one non-control input. + // Cycles in the graph are only allowed for while loops. A while loop is + // identified by an edge from a NextIteration node to a Merge node. For + // such Merge nodes, only wait for one non-control input before + // considering the node ready to process in Convert(). 
int32 num_control_edges = 0; + bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name(node_def.input(i)); if (input_name.starts_with("^")) { num_control_edges++; + } else { + TensorId id(ParseTensorName(input_name)); + if (next_iteration_nodes_.find(id.first.ToString()) != + next_iteration_nodes_.end()) { + has_loop_back_edge = true; + } } } - pending_count_.push_back(num_control_edges + 1); + if (has_loop_back_edge) { + pending_count_.push_back(num_control_edges + 1); + } else { + pending_count_.push_back(node_def.input_size()); + } } else { pending_count_.push_back(node_def.input_size()); } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 8abf21235e..b8d1879fa0 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1871,25 +1871,29 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { // new_input opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); - // ImportGraphDef only allows backedges into merge nodes (since backedges are - // only expected in while loops) + // ImportGraphDef only allows backedges into merge nodes that are part of + // while loops (since backedges are only expected in while loops) ExpectOK( R"EOF( node { name: 'new_input' op: 'TestInput' } - node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 't1:0' ] + node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 'next:0' ] attr { key: "N" value: { i: 2 } } attr { key: "T" value: { type: DT_FLOAT } } } node { name: 't1' op: 'TestMul' input: [ 'merge:0', 'merge:0' ] } + node { name: 'next' op: 'NextIteration' input: ['t1:0'] + attr { key: "T" value: { type: DT_FLOAT } } } )EOF", opts, &refiner); EXPECT_TRUE(HasNode("new_input")); EXPECT_TRUE(HasNode("merge")); EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasNode("next")); // Sanity check we created cycle EXPECT_TRUE(HasEdge("merge", 0, 
"t1", 0)); - EXPECT_TRUE(HasEdge("t1", 0, "merge", 1)); + EXPECT_TRUE(HasEdge("t1", 0, "next", 0)); + EXPECT_TRUE(HasEdge("next", 0, "merge", 1)); // Test that control dep was added to exactly one node of cycle EXPECT_TRUE(HasControlEdge("W1", "merge")); @@ -1899,13 +1903,17 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { Node* merge = FindNode("merge"); ASSERT_EQ(merge->requested_inputs().size(), 3); EXPECT_EQ(merge->requested_inputs()[0], "input:0"); - EXPECT_EQ(merge->requested_inputs()[1], "t1:0"); + EXPECT_EQ(merge->requested_inputs()[1], "next:0"); EXPECT_EQ(merge->requested_inputs()[2], "^W1"); Node* t1 = FindNode("t1"); ASSERT_EQ(t1->requested_inputs().size(), 2); EXPECT_EQ(t1->requested_inputs()[0], "merge:0"); EXPECT_EQ(t1->requested_inputs()[1], "merge:0"); + + Node* next = FindNode("next"); + ASSERT_EQ(next->requested_inputs().size(), 1); + EXPECT_EQ(next->requested_inputs()[0], "t1:0"); } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) { diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index cc6f097cd0..29b6adef5e 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -309,6 +309,371 @@ TEST_F(GraphPropertiesTest, Queues) { EXPECT_EQ("float: [1,2,3]", PropToString(props5[2])); } +TEST_F(GraphPropertiesTest, MergeWithoutLoops) { + // Python code used to generate the graph is below. 
+ const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 7 + } + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 5 + } + } + } +} +node { + name: "ones" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "Less" + op: "Less" + input: "Const" + input: "Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "cond/Switch" + op: "Switch" + input: "Less" + input: "Less" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "cond/switch_t" + op: "Identity" + input: "cond/Switch:1" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "cond/switch_f" + op: "Identity" + input: "cond/Switch" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "cond/pred_id" + op: "Identity" + input: "Less" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "cond/concat/axis" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "cond/concat/Switch" + op: "Switch" + input: "ones" + input: "cond/pred_id" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@ones" + } + } + } +} +node { + name: "cond/concat" + op: "ConcatV2" + input: "cond/concat/Switch:1" + input: "cond/concat/Switch:1" + input: "cond/concat/axis" + attr { + 
key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "cond/concat_1/axis" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "cond/concat_1/Switch" + op: "Switch" + input: "ones" + input: "cond/pred_id" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@ones" + } + } + } +} +node { + name: "cond/concat_1" + op: "ConcatV2" + input: "cond/concat_1/Switch" + input: "cond/concat_1/Switch" + input: "cond/concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "cond/Merge" + op: "Merge" + input: "cond/concat" + input: "cond/concat_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "concat/axis" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "cond/Merge" + input: "cond/Merge" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +versions { + producer: 21 +} + )EOF"; + + // Test graph produced in python using: + /* + with tf.Graph().as_default(): + x = tf.constant(2) + y = tf.constant(5) + z = tf.ones([1,1,1]) + def f1(): return tf.concat([z, z], axis=0) + def f2(): return tf.concat([z, z], axis=1) + r = tf.cond(tf.less(x, y), f1, f2) + tf.concat([r, r], axis=2) + with open('/tmp/graph.pbtxt', 'w') as f: + 
f.write(str(tf.get_default_graph().as_graph_def())) + */ + + GrapplerItem item; + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"}; + std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]", + "float: [1,2,1]"}; + for (int i = 0; i < nodes.size(); i++) { + const auto props = properties.GetOutputProperties(nodes[i]); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ(expected_outputs[i], PropToString(prop)); + } +} + TEST_F(GraphPropertiesTest, WhileLoop) { // Python code used to generate the graph is below. const string gdef_ascii = R"EOF( |