diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-28 09:36:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-28 09:41:22 -0700 |
commit | 17c5907a0f35cc2644737478137ed2b558998da9 (patch) | |
tree | 13de773feaabb8a6a10136e0956340124e5160e5 /tensorflow/core/graph | |
parent | 184a5fd8da87f79f46c75c716a802863aee28a02 (diff) |
Only wait for one non-control input for Merge nodes if there is a loop. This is
to enable the propagation of shapes for conditionals, which also include Merge
nodes.
PiperOrigin-RevId: 160417770
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 39 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 18 |
2 files changed, 50 insertions, 7 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 38a780dfac..a2929d0210 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -45,6 +45,11 @@ inline bool IsMerge(const NodeDef& node_def) { return node_def.op() == "Merge" || node_def.op() == "RefMerge"; } +inline bool IsNextIteration(const NodeDef& node_def) { + return node_def.op() == "NextIteration" || + node_def.op() == "RefNextIteration"; +} + bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { using ::tensorflow::strings::Scanner; return Scanner(s) @@ -365,24 +370,54 @@ Status GraphConstructor::BuildNodeIndex() { return Status::OK(); } +std::unordered_set<string> GetNextIterationNodes( + const GraphConstructor::NodeDefSlice& node_defs) { + std::unordered_set<string> next_iteration_nodes; + + for (int n = 0; n < node_defs.size(); ++n) { + const NodeDef& node_def = *node_defs[n]; + if (IsNextIteration(node_def)) { + next_iteration_nodes.insert(node_def.name()); + } + } + + return next_iteration_nodes; +} + Status GraphConstructor::InitFromEdges() { const int num_nodes = node_defs_.size(); pending_count_.reserve(num_nodes); outputs_.resize(num_nodes); + std::unordered_set<string> next_iteration_nodes_ = + GetNextIterationNodes(node_defs_); // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { const NodeDef& node_def = *node_defs_[n]; if (IsMerge(node_def)) { - // for merge only wait for one non-control input. + // Cycles in the graph are only allowed for while loops. A while loop is + // identified by an edge from a NextIteration node to a Merge node. For + // such Merge nodes, only wait for one non-control input before + // considering the node ready to process in Convert(). int32 num_control_edges = 0; + bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name(node_def.input(i)); if (input_name.starts_with("^")) { num_control_edges++; + } else { + TensorId id(ParseTensorName(input_name)); + if (next_iteration_nodes_.find(id.first.ToString()) != + next_iteration_nodes_.end()) { + has_loop_back_edge = true; + } } } - pending_count_.push_back(num_control_edges + 1); + if (has_loop_back_edge) { + pending_count_.push_back(num_control_edges + 1); + } else { + pending_count_.push_back(node_def.input_size()); + } } else { pending_count_.push_back(node_def.input_size()); } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 8abf21235e..b8d1879fa0 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1871,25 +1871,29 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { // new_input opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); - // ImportGraphDef only allows backedges into merge nodes (since backedges are - // only expected in while loops) + // ImportGraphDef only allows backedges into merge nodes that are part of + // while loops (since backedges are only expected in while loops) ExpectOK( R"EOF( node { name: 'new_input' op: 'TestInput' } - node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 't1:0' ] + node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 'next:0' ] attr { key: "N" value: { i: 2 } } attr { key: "T" value: { type: DT_FLOAT } } } node { name: 't1' op: 'TestMul' input: [ 'merge:0', 'merge:0' ] } + node { name: 'next' op: 'NextIteration' input: ['t1:0'] + attr { key: "T" value: { type: DT_FLOAT } } } )EOF", opts, &refiner); EXPECT_TRUE(HasNode("new_input")); EXPECT_TRUE(HasNode("merge")); EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasNode("next")); // Sanity check we created cycle EXPECT_TRUE(HasEdge("merge", 0, "t1", 0)); - EXPECT_TRUE(HasEdge("t1", 0, "merge", 1)); + EXPECT_TRUE(HasEdge("t1", 0, "next", 0)); + EXPECT_TRUE(HasEdge("next", 0, "merge", 1)); // Test that control dep was added to exactly one node of cycle EXPECT_TRUE(HasControlEdge("W1", "merge")); @@ -1899,13 +1903,17 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { Node* merge = FindNode("merge"); ASSERT_EQ(merge->requested_inputs().size(), 3); EXPECT_EQ(merge->requested_inputs()[0], "input:0"); - EXPECT_EQ(merge->requested_inputs()[1], "t1:0"); + EXPECT_EQ(merge->requested_inputs()[1], "next:0"); EXPECT_EQ(merge->requested_inputs()[2], "^W1"); Node* t1 = FindNode("t1"); ASSERT_EQ(t1->requested_inputs().size(), 2); EXPECT_EQ(t1->requested_inputs()[0], "merge:0"); EXPECT_EQ(t1->requested_inputs()[1], "merge:0"); + + Node* next = FindNode("next"); + ASSERT_EQ(next->requested_inputs().size(), 1); + EXPECT_EQ(next->requested_inputs()[0], "t1:0"); } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) { |