diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-18 12:53:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-18 13:00:35 -0700 |
commit | 38bcb3c02fbc5185d6c1fb7e8327a070284b66e4 (patch) | |
tree | 39befc19c046875ef9ba9ee675a878e048ebad4f /tensorflow/tools/graph_transforms/fold_constants_test.cc | |
parent | 09ff3f7296a66c39535e097ecb6b82e3fc42ba30 (diff) |
Bug fixes for fold_constants_lib.
1. Tensor names in TF may be in the form of "a:0", "a:1", or "a" as a shorthand
notation of "a:0". FoldConstant library always expected the shorthand notation,
and did not handle the cases where explicit notation was passed to input or
output list. This means that this library could not handle the case when input
or output were not the first output of a node.
2. To match the input nodes in the original graph and the added Recv nodes in
rewritten graph, FoldConstant library used prefix matching. Unfortunately, this
means that when a input name is a prefix of another input name, there is
possibility that wrong Recv node gets matched. For example, if input names were
"placeholder" and "placeholder_1", then it did not handle the case very well.
3. RemoveUnusedNodes() in FoldConstants lib could remove nodes which output
depended on. This happened when an input name points to a node with multiple
outputs and not all outputs of that node were included in the input names.
4. ReplaceSendRecvs() in FoldConstants lib assumed that all input nodes are
removed during rewriting the graph. This assumption is not necessarily true,
and it could add a duplicate node in the graph.
PiperOrigin-RevId: 172641947
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_constants_test.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_test.cc | 85 |
1 files changed, 84 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index fd4188a6a4..41106de008 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -74,6 +74,9 @@ class ConstantFoldingTest : public ::testing::Test { TestConstantFolding(graph_def, {{"placeholder_expect_remains", placeholder_tensor}}, {}, {"output_expect_remains"}, {}); + TestConstantFolding(graph_def, + {{"placeholder_expect_remains:0", placeholder_tensor}}, + {}, {"output_expect_remains:0"}, {}); } void TestOpExclusionAdd() { @@ -256,10 +259,40 @@ class ConstantFoldingTest : public ::testing::Test { EXPECT_EQ(0, node_map.count("new_send")); } + void TestReplaceSendRecvsPrefixNames() { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + auto o_root = tensorflow::Scope::NewRootScope(); + auto a = Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT); + auto b = Placeholder(o_root.WithOpName("placeholder_1"), DT_FLOAT); + auto add_o = Add(o_root.WithOpName("add"), a, b); + GraphDef o_graph_def; + TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def)); + + auto n_root = tensorflow::Scope::NewRootScope(); + auto c = _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "", + 0, ""); + auto d = _Recv(n_root.WithOpName("_recv_placeholder_1_0"), DT_FLOAT, "", "", + 0, ""); + auto add_n = Add(n_root.WithOpName("add"), c, d); + GraphDef n_graph_def; + TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def)); + + GraphDef result_graph_def; + TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs( + o_graph_def, n_graph_def, {"placeholder", "placeholder_1"}, {"add"}, + &result_graph_def)); + + std::map<string, const NodeDef*> node_map; + graph_transforms::MapNamesToNodes(result_graph_def, &node_map); + EXPECT_EQ(1, node_map.count("placeholder")); + EXPECT_EQ(1, node_map.count("placeholder_1")); + EXPECT_EQ(1, node_map.count("add")); + } + void TestRemoveUnusedNodes() { using namespace ::tensorflow::ops; // NOLINT(build/namespaces) auto root = tensorflow::Scope::NewRootScope(); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) const int width = 100; @@ -295,6 +328,48 @@ class ConstantFoldingTest : public ::testing::Test { EXPECT_EQ(1, node_map.count("output")); EXPECT_EQ(0, node_map.count("unused")); } + + void TestRemoveUnusedNodesMultipleOutputs() { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + auto root = tensorflow::Scope::NewRootScope(); + + // a b + // \ / + // shape_n + // \ / + // c + auto a = Placeholder(root.WithOpName("a"), DT_FLOAT); + auto b = Placeholder(root.WithOpName("b"), DT_FLOAT); + auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)}); + auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + GraphDef result_graph_def; + TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( + graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def)); + + // Only one output of shape_n node is fed input. Hence the graph search + // should propagate to inputs of shape_n. Nothing to remove here. + std::map<string, const NodeDef*> node_map; + graph_transforms::MapNamesToNodes(result_graph_def, &node_map); + EXPECT_EQ(1, node_map.count("a")); + EXPECT_EQ(1, node_map.count("b")); + EXPECT_EQ(1, node_map.count("c")); + + result_graph_def.Clear(); + TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( + graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}}, + &result_graph_def)); + + // Both outputs of shape_n node are fed inputs. shape_n does not function + // and inputs to shape_n should be removed. + node_map.clear(); + graph_transforms::MapNamesToNodes(result_graph_def, &node_map); + EXPECT_EQ(0, node_map.count("a")); + EXPECT_EQ(0, node_map.count("b")); + EXPECT_EQ(1, node_map.count("c")); + } }; TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); } @@ -309,7 +384,15 @@ TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) { TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); } +TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) { + TestReplaceSendRecvsPrefixNames(); +} + TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); } +TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) { + TestRemoveUnusedNodesMultipleOutputs(); +} + } // namespace graph_transforms } // namespace tensorflow |