about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org>, 2017-06-28 09:36:53 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-06-28 09:41:22 -0700
commit 17c5907a0f35cc2644737478137ed2b558998da9 (patch)
tree 13de773feaabb8a6a10136e0956340124e5160e5
parent 184a5fd8da87f79f46c75c716a802863aee28a02 (diff)
Only wait for one non-control input for Merge nodes if there is a loop. This is
to enable the propagation of shapes for conditionals, which also include Merge nodes.

PiperOrigin-RevId: 160417770
-rw-r--r-- tensorflow/core/graph/graph_constructor.cc 39
-rw-r--r-- tensorflow/core/graph/graph_constructor_test.cc 18
-rw-r--r-- tensorflow/core/grappler/costs/graph_properties_test.cc 365
3 files changed, 415 insertions, 7 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 38a780dfac..a2929d0210 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -45,6 +45,11 @@ inline bool IsMerge(const NodeDef& node_def) {
return node_def.op() == "Merge" || node_def.op() == "RefMerge";
}
+inline bool IsNextIteration(const NodeDef& node_def) {
+ return node_def.op() == "NextIteration" ||
+ node_def.op() == "RefNextIteration";
+}
+
bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
using ::tensorflow::strings::Scanner;
return Scanner(s)
@@ -365,24 +370,54 @@ Status GraphConstructor::BuildNodeIndex() {
return Status::OK();
}
+std::unordered_set<string> GetNextIterationNodes(
+ const GraphConstructor::NodeDefSlice& node_defs) {
+ std::unordered_set<string> next_iteration_nodes;
+
+ for (int n = 0; n < node_defs.size(); ++n) {
+ const NodeDef& node_def = *node_defs[n];
+ if (IsNextIteration(node_def)) {
+ next_iteration_nodes.insert(node_def.name());
+ }
+ }
+
+ return next_iteration_nodes;
+}
+
Status GraphConstructor::InitFromEdges() {
const int num_nodes = node_defs_.size();
pending_count_.reserve(num_nodes);
outputs_.resize(num_nodes);
+ std::unordered_set<string> next_iteration_nodes_ =
+ GetNextIterationNodes(node_defs_);
// Parse the inputs for each node.
for (int n = 0; n < num_nodes; ++n) {
const NodeDef& node_def = *node_defs_[n];
if (IsMerge(node_def)) {
- // for merge only wait for one non-control input.
+ // Cycles in the graph are only allowed for while loops. A while loop is
+ // identified by an edge from a NextIteration node to a Merge node. For
+ // such Merge nodes, only wait for one non-control input before
+ // considering the node ready to process in Convert().
int32 num_control_edges = 0;
+ bool has_loop_back_edge = false;
for (int i = 0; i < node_def.input_size(); ++i) {
StringPiece input_name(node_def.input(i));
if (input_name.starts_with("^")) {
num_control_edges++;
+ } else {
+ TensorId id(ParseTensorName(input_name));
+ if (next_iteration_nodes_.find(id.first.ToString()) !=
+ next_iteration_nodes_.end()) {
+ has_loop_back_edge = true;
+ }
}
}
- pending_count_.push_back(num_control_edges + 1);
+ if (has_loop_back_edge) {
+ pending_count_.push_back(num_control_edges + 1);
+ } else {
+ pending_count_.push_back(node_def.input_size());
+ }
} else {
pending_count_.push_back(node_def.input_size());
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 8abf21235e..b8d1879fa0 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1871,25 +1871,29 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
// new_input
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
- // ImportGraphDef only allows backedges into merge nodes (since backedges are
- // only expected in while loops)
+ // ImportGraphDef only allows backedges into merge nodes that are part of
+ // while loops (since backedges are only expected in while loops)
ExpectOK(
R"EOF(
node { name: 'new_input' op: 'TestInput' }
- node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 't1:0' ]
+ node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 'next:0' ]
attr { key: "N" value: { i: 2 } }
attr { key: "T" value: { type: DT_FLOAT } } }
node { name: 't1' op: 'TestMul' input: [ 'merge:0', 'merge:0' ] }
+ node { name: 'next' op: 'NextIteration' input: ['t1:0']
+ attr { key: "T" value: { type: DT_FLOAT } } }
)EOF",
opts, &refiner);
EXPECT_TRUE(HasNode("new_input"));
EXPECT_TRUE(HasNode("merge"));
EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasNode("next"));
// Sanity check we created cycle
EXPECT_TRUE(HasEdge("merge", 0, "t1", 0));
- EXPECT_TRUE(HasEdge("t1", 0, "merge", 1));
+ EXPECT_TRUE(HasEdge("t1", 0, "next", 0));
+ EXPECT_TRUE(HasEdge("next", 0, "merge", 1));
// Test that control dep was added to exactly one node of cycle
EXPECT_TRUE(HasControlEdge("W1", "merge"));
@@ -1899,13 +1903,17 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
Node* merge = FindNode("merge");
ASSERT_EQ(merge->requested_inputs().size(), 3);
EXPECT_EQ(merge->requested_inputs()[0], "input:0");
- EXPECT_EQ(merge->requested_inputs()[1], "t1:0");
+ EXPECT_EQ(merge->requested_inputs()[1], "next:0");
EXPECT_EQ(merge->requested_inputs()[2], "^W1");
Node* t1 = FindNode("t1");
ASSERT_EQ(t1->requested_inputs().size(), 2);
EXPECT_EQ(t1->requested_inputs()[0], "merge:0");
EXPECT_EQ(t1->requested_inputs()[1], "merge:0");
+
+ Node* next = FindNode("next");
+ ASSERT_EQ(next->requested_inputs().size(), 1);
+ EXPECT_EQ(next->requested_inputs()[0], "t1:0");
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) {
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index cc6f097cd0..29b6adef5e 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -309,6 +309,371 @@ TEST_F(GraphPropertiesTest, Queues) {
EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
}
+TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
+ // Python code used to generate the graph is below.
+ const string gdef_ascii = R"EOF(
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 7
+ }
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 5
+ }
+ }
+ }
+}
+node {
+ name: "ones"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Const"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "cond/switch_t"
+ op: "Identity"
+ input: "cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "cond/switch_f"
+ op: "Identity"
+ input: "cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "cond/concat/axis"
+ op: "Const"
+ input: "^cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "cond/concat/Switch"
+ op: "Switch"
+ input: "ones"
+ input: "cond/pred_id"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@ones"
+ }
+ }
+ }
+}
+node {
+ name: "cond/concat"
+ op: "ConcatV2"
+ input: "cond/concat/Switch:1"
+ input: "cond/concat/Switch:1"
+ input: "cond/concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "cond/concat_1/axis"
+ op: "Const"
+ input: "^cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "cond/concat_1/Switch"
+ op: "Switch"
+ input: "ones"
+ input: "cond/pred_id"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@ones"
+ }
+ }
+ }
+}
+node {
+ name: "cond/concat_1"
+ op: "ConcatV2"
+ input: "cond/concat_1/Switch"
+ input: "cond/concat_1/Switch"
+ input: "cond/concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "cond/Merge"
+ op: "Merge"
+ input: "cond/concat"
+ input: "cond/concat_1"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 2
+ }
+ }
+ }
+}
+node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "cond/Merge"
+ input: "cond/Merge"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+versions {
+ producer: 21
+}
+ )EOF";
+
+ // Test graph produced in python using:
+ /*
+ with tf.Graph().as_default():
+ x = tf.constant(7)
+ y = tf.constant(5)
+ z = tf.ones([1,1,1])
+ def f1(): return tf.concat([z, z], axis=0)
+ def f2(): return tf.concat([z, z], axis=1)
+ r = tf.cond(tf.less(x, y), f1, f2)
+ tf.concat([r, r], axis=2)
+ with open('/tmp/graph.pbtxt', 'w') as f:
+ f.write(str(tf.get_default_graph().as_graph_def()))
+ */
+
+ GrapplerItem item;
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically());
+
+ std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
+ std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
+ "float: [1,2,1]"};
+ for (int i = 0; i < nodes.size(); i++) {
+ const auto props = properties.GetOutputProperties(nodes[i]);
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT, prop.dtype());
+ EXPECT_EQ(expected_outputs[i], PropToString(prop));
+ }
+}
+
TEST_F(GraphPropertiesTest, WhileLoop) {
// Python code used to generate the graph is below.
const string gdef_ascii = R"EOF(