aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/remove_nodes_test.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-05-24 19:05:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-24 19:09:20 -0700
commit70fc6abad77e276145e28ecda735c80c376b1036 (patch)
treee4f5da649006aa95ee25f61a92dc9081cfe047c8 /tensorflow/tools/graph_transforms/remove_nodes_test.cc
parent4a09e96797fdcf55b308fade4fd719ef77497d0d (diff)
Enable removal of nodes with multiple inputs
PiperOrigin-RevId: 157068219
Diffstat (limited to 'tensorflow/tools/graph_transforms/remove_nodes_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/remove_nodes_test.cc56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/remove_nodes_test.cc b/tensorflow/tools/graph_transforms/remove_nodes_test.cc
index e87ea1daa6..d8d85a3b47 100644
--- a/tensorflow/tools/graph_transforms/remove_nodes_test.cc
+++ b/tensorflow/tools/graph_transforms/remove_nodes_test.cc
@@ -210,6 +210,58 @@ class RemoveNodesTest : public ::testing::Test {
EXPECT_EQ(0, node_lookup.count("identity_node2"));
EXPECT_EQ(0, node_lookup.count("identity_node3"));
}
+
+ void TestRemoveMultipleInputs() {
+ GraphDef graph_def;
+
+ NodeDef* const_node1 = graph_def.add_node();
+ const_node1->set_name("const_node1");
+ const_node1->set_op("Const");
+
+ NodeDef* const_node2 = graph_def.add_node();
+ const_node2->set_name("const_node2");
+ const_node2->set_op("Const");
+
+ NodeDef* const_node3 = graph_def.add_node();
+ const_node3->set_name("const_node3");
+ const_node3->set_op("Const");
+
+ NodeDef* const_node4 = graph_def.add_node();
+ const_node4->set_name("const_node4");
+ const_node4->set_op("Const");
+
+ NodeDef* fake_quant_node = graph_def.add_node();
+ fake_quant_node->set_name("fake_quant_node");
+ fake_quant_node->set_op("FakeQuantWithMinMaxVars");
+ fake_quant_node->add_input("const_node1");
+ fake_quant_node->add_input("const_node2");
+ fake_quant_node->add_input("const_node3");
+
+ NodeDef* add_node = graph_def.add_node();
+ add_node->set_name("add_node");
+ add_node->set_op("Add");
+ add_node->add_input("fake_quant_node");
+ add_node->add_input("const_node4");
+
+ GraphDef result;
+ TransformFuncContext context;
+ context.input_names = {};
+ context.output_names = {"add_node"};
+ context.params.insert(std::pair<string, std::vector<string>>(
+ {"op", {string("FakeQuantWithMinMaxVars")}}));
+ context.params.insert(
+ std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
+ TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
+
+ std::map<string, const NodeDef*> node_lookup;
+ MapNamesToNodes(result, &node_lookup);
+ ASSERT_EQ(1, node_lookup.count("const_node1"));
+ ASSERT_EQ(1, node_lookup.count("const_node4"));
+ ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
+ ASSERT_EQ(1, node_lookup.count("add_node"));
+ EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
+ EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
+ }
};
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
@@ -218,5 +270,9 @@ TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
+TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
+ TestRemoveMultipleInputs();
+}
+
} // namespace graph_transforms
} // namespace tensorflow