aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_constructor_test.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-08-19 10:25:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-19 10:29:34 -0700
commite1030858725b485b0f848cc27597b8e2c2d8383f (patch)
tree1bdfcf53434398a44bbe6ab42b8f2feb24cdbbc1 /tensorflow/core/graph/graph_constructor_test.cc
parent181267c51f4827e351bb64a620b3dff83539bc11 (diff)
Don't create cond_input and body_input nodes when finishing while loop
These nodes are not needed, but they caused failures in functions with while loops because functions currently execute all ops in their bodies and these ops are placeholders without feeds. PiperOrigin-RevId: 165814802
Diffstat (limited to 'tensorflow/core/graph/graph_constructor_test.cc')
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc80
1 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 6be8e36ab6..e448ce4927 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1433,6 +1433,86 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
&refiner);
}
+TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ // Populate graph with node we'll use in input map
+ ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(),
+ &refiner);
+
+ // Create input_map and use it to import more nodes
+ ImportGraphDefOptions opts;
+ opts.skip_mapped_nodes = true;
+ opts.input_map[TensorId("new_input", 0)] = TensorId("input", 1);
+ opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
+
+ ExpectOK(
+ R"EOF(
+ node { name: 'new_input' op: 'TestInput' }
+ node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
+ node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ )EOF",
+ opts, &refiner);
+
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasNode("t2"));
+ // `new_input` node is not imported because we set skip_mapped_nodes = true
+ // and all of its inputs are mapped
+ EXPECT_FALSE(HasNode("new_input"));
+
+ EXPECT_TRUE(HasEdge("input", 1, "t1", 0));
+ EXPECT_TRUE(HasEdge("input", 0, "t1", 1));
+ // Test that t2 is unaffected
+ EXPECT_TRUE(HasEdge("t1", 0, "t2", 0));
+
+ // Check that t1's NodeDef is consistent with graph
+ Node* t1 = FindNode("t1");
+ ASSERT_EQ(t1->requested_inputs().size(), 2);
+ ASSERT_EQ(t1->requested_inputs()[0], "input:1");
+ ASSERT_EQ(t1->requested_inputs()[1], "input:0");
+}
+
+TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_NotFullyMapped) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ // Populate graph with node we'll use in input map
+ ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(),
+ &refiner);
+
+ // Create input_map and use it to import more nodes
+ ImportGraphDefOptions opts;
+ opts.skip_mapped_nodes = true;
+ opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
+
+ ExpectOK(
+ R"EOF(
+ node { name: 'new_input' op: 'TestInput' }
+ node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
+ node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ )EOF",
+ opts, &refiner);
+
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasNode("t2"));
+ // `new_input` node is imported because not all of its inputs are mapped
+ EXPECT_TRUE(HasNode("new_input"));
+
+ EXPECT_FALSE(HasEdge("input", 1, "t1", 0));
+ EXPECT_TRUE(HasEdge("input", 0, "t1", 1));
+ EXPECT_TRUE(HasEdge("new_input", 0, "t1", 0));
+ EXPECT_FALSE(HasEdge("new_input", 1, "t1", 1));
+ // Test that t2 is unaffected
+ EXPECT_TRUE(HasEdge("t1", 0, "t2", 0));
+
+ // Check that t1's NodeDef is consistent with graph
+ Node* t1 = FindNode("t1");
+ ASSERT_EQ(t1->requested_inputs().size(), 2);
+ ASSERT_EQ(t1->requested_inputs()[0], "new_input:0");
+ ASSERT_EQ(t1->requested_inputs()[1], "input:0");
+}
+
TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());