diff options
author | 2017-05-24 19:05:37 -0700 | |
---|---|---|
committer | 2017-05-24 19:09:20 -0700 | |
commit | 70fc6abad77e276145e28ecda735c80c376b1036 (patch) | |
tree | e4f5da649006aa95ee25f61a92dc9081cfe047c8 /tensorflow/tools/graph_transforms/remove_nodes_test.cc | |
parent | 4a09e96797fdcf55b308fade4fd719ef77497d0d (diff) |
Enable removal of nodes with multiple inputs
PiperOrigin-RevId: 157068219
Diffstat (limited to 'tensorflow/tools/graph_transforms/remove_nodes_test.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/remove_nodes_test.cc | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/remove_nodes_test.cc b/tensorflow/tools/graph_transforms/remove_nodes_test.cc index e87ea1daa6..d8d85a3b47 100644 --- a/tensorflow/tools/graph_transforms/remove_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/remove_nodes_test.cc @@ -210,6 +210,58 @@ class RemoveNodesTest : public ::testing::Test { EXPECT_EQ(0, node_lookup.count("identity_node2")); EXPECT_EQ(0, node_lookup.count("identity_node3")); } + + void TestRemoveMultipleInputs() { + GraphDef graph_def; + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* const_node4 = graph_def.add_node(); + const_node4->set_name("const_node4"); + const_node4->set_op("Const"); + + NodeDef* fake_quant_node = graph_def.add_node(); + fake_quant_node->set_name("fake_quant_node"); + fake_quant_node->set_op("FakeQuantWithMinMaxVars"); + fake_quant_node->add_input("const_node1"); + fake_quant_node->add_input("const_node2"); + fake_quant_node->add_input("const_node3"); + + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("fake_quant_node"); + add_node->add_input("const_node4"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"add_node"}; + context.params.insert(std::pair<string, std::vector<string>>( + {"op", {string("FakeQuantWithMinMaxVars")}})); + context.params.insert( + std::pair<string, std::vector<string>>({"max_inputs", {string("3")}})); + TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); + + std::map<string, const NodeDef*> node_lookup; + MapNamesToNodes(result, &node_lookup); + ASSERT_EQ(1, node_lookup.count("const_node1")); + ASSERT_EQ(1, node_lookup.count("const_node4")); + ASSERT_EQ(0, node_lookup.count("fake_quant_node")); + ASSERT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); + EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1)); + } }; TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); } @@ -218,5 +270,9 @@ TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); } TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); } +TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) { + TestRemoveMultipleInputs(); +} + } // namespace graph_transforms } // namespace tensorflow |