aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/fold_constants_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-18 12:53:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-18 13:00:35 -0700
commit38bcb3c02fbc5185d6c1fb7e8327a070284b66e4 (patch)
tree39befc19c046875ef9ba9ee675a878e048ebad4f /tensorflow/tools/graph_transforms/fold_constants_test.cc
parent09ff3f7296a66c39535e097ecb6b82e3fc42ba30 (diff)
Bug fixes for fold_constants_lib.
1. Tensor names in TF may be in the form of "a:0", "a:1", or "a" as a shorthand notation of "a:0". FoldConstant library always expected the shorthand notation, and did not handle the cases where explicit notation was passed to input or output list. This means that this library could not handle the case when input or output were not the first output of a node. 2. To match the input nodes in the original graph and the added Recv nodes in rewritten graph, FoldConstant library used prefix matching. Unfortunately, this means that when a input name is a prefix of another input name, there is possibility that wrong Recv node gets matched. For example, if input names were "placeholder" and "placeholder_1", then it did not handle the case very well. 3. RemoveUnusedNodes() in FoldConstants lib could remove nodes which output depended on. This happened when an input name points to a node with multiple outputs and not all outputs of that node were included in the input names. 4. ReplaceSendRecvs() in FoldConstants lib assumed that all input nodes are removed during rewriting the graph. This assumption is not necessarily true, and it could add a duplicate node in the graph. PiperOrigin-RevId: 172641947
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_constants_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc85
1 files changed, 84 insertions, 1 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index fd4188a6a4..41106de008 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -74,6 +74,9 @@ class ConstantFoldingTest : public ::testing::Test {
TestConstantFolding(graph_def,
{{"placeholder_expect_remains", placeholder_tensor}},
{}, {"output_expect_remains"}, {});
+ TestConstantFolding(graph_def,
+ {{"placeholder_expect_remains:0", placeholder_tensor}},
+ {}, {"output_expect_remains:0"}, {});
}
void TestOpExclusionAdd() {
@@ -256,10 +259,40 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("new_send"));
}
+ void TestReplaceSendRecvsPrefixNames() {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ auto o_root = tensorflow::Scope::NewRootScope();
+ auto a = Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT);
+ auto b = Placeholder(o_root.WithOpName("placeholder_1"), DT_FLOAT);
+ auto add_o = Add(o_root.WithOpName("add"), a, b);
+ GraphDef o_graph_def;
+ TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def));
+
+ auto n_root = tensorflow::Scope::NewRootScope();
+ auto c = _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "",
+ 0, "");
+ auto d = _Recv(n_root.WithOpName("_recv_placeholder_1_0"), DT_FLOAT, "", "",
+ 0, "");
+ auto add_n = Add(n_root.WithOpName("add"), c, d);
+ GraphDef n_graph_def;
+ TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def));
+
+ GraphDef result_graph_def;
+ TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs(
+ o_graph_def, n_graph_def, {"placeholder", "placeholder_1"}, {"add"},
+ &result_graph_def));
+
+ std::map<string, const NodeDef*> node_map;
+ graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
+ EXPECT_EQ(1, node_map.count("placeholder"));
+ EXPECT_EQ(1, node_map.count("placeholder_1"));
+ EXPECT_EQ(1, node_map.count("add"));
+ }
+
void TestRemoveUnusedNodes() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
auto root = tensorflow::Scope::NewRootScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 100;
@@ -295,6 +328,48 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(1, node_map.count("output"));
EXPECT_EQ(0, node_map.count("unused"));
}
+
+ void TestRemoveUnusedNodesMultipleOutputs() {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ auto root = tensorflow::Scope::NewRootScope();
+
+ // a b
+ // \ /
+ // shape_n
+ // \ /
+ // c
+ auto a = Placeholder(root.WithOpName("a"), DT_FLOAT);
+ auto b = Placeholder(root.WithOpName("b"), DT_FLOAT);
+ auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)});
+ auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ GraphDef result_graph_def;
+ TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
+ graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def));
+
+ // Only one output of shape_n node is fed input. Hence the graph search
+ // should propagate to inputs of shape_n. Nothing to remove here.
+ std::map<string, const NodeDef*> node_map;
+ graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
+ EXPECT_EQ(1, node_map.count("a"));
+ EXPECT_EQ(1, node_map.count("b"));
+ EXPECT_EQ(1, node_map.count("c"));
+
+ result_graph_def.Clear();
+ TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
+ graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}},
+ &result_graph_def));
+
+ // Both outputs of shape_n node are fed inputs. shape_n does not function
+ // and inputs to shape_n should be removed.
+ node_map.clear();
+ graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
+ EXPECT_EQ(0, node_map.count("a"));
+ EXPECT_EQ(0, node_map.count("b"));
+ EXPECT_EQ(1, node_map.count("c"));
+ }
};
TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); }
@@ -309,7 +384,15 @@ TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) {
TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); }
+TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
+ TestReplaceSendRecvsPrefixNames();
+}
+
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }
+TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
+ TestRemoveUnusedNodesMultipleOutputs();
+}
+
} // namespace graph_transforms
} // namespace tensorflow