diff options
author | 2018-06-08 04:12:07 -0700 | |
---|---|---|
committer | 2018-06-08 04:14:40 -0700 | |
commit | 1c241ba791f578a67c80e932cbbb06b5af5ca81a (patch) | |
tree | af99cf35186e0d49a5eb48a039f10b524405ea1d /tensorflow/tools/graph_transforms | |
parent | 16c1d25110e48b8cecbf61ea8e15a7c9da26dd83 (diff) |
Fix RemoveUnusedNodes generating invalid graphs for PlaceholderWithDefault inputs
PiperOrigin-RevId: 199776409
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_lib.cc | 26 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_test.cc | 46 |
2 files changed, 26 insertions, 46 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 85660f94a8..f858411876 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -117,6 +117,31 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def, return Status::OK(); } +Status RewriteInputsAsPlaceholders(const TransformFuncContext& context, + GraphDef* graph_def) { + std::unordered_set<string> input_names; + for (const string& input_name : context.input_names) { + input_names.insert(ParseTensorName(input_name).first.ToString()); + } + + for (NodeDef& node : *graph_def->mutable_node()) { + if (input_names.find(node.name()) == input_names.end()) { + continue; + } + if (node.op() == "PlaceholderWithDefault") { + node.set_op("Placeholder"); + node.clear_input(); + } else if (node.op() != "Placeholder") { + return errors::InvalidArgument( + "Input '", node.name(), + "' was expected to be a Placeholder or PlaceholderWithDefault op, " + "but was ", + node.op()); + } + } + return Status::OK(); +} + Status RemoveUnusedNodes(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { @@ -165,6 +190,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def, input_graph_def, [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; }, output_graph_def); + TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def)); return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index a082399a87..dcdc3c2906 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -330,48 +330,6 @@ class ConstantFoldingTest : public ::testing::Test { EXPECT_EQ(0, node_map.count("unused")); } - void TestRemoveUnusedNodesMultipleOutputs() { - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - auto root = tensorflow::Scope::NewRootScope(); - - // a b - // \ / - // shape_n - // \ / - // c - auto a = Placeholder(root.WithOpName("a"), DT_FLOAT); - auto b = Placeholder(root.WithOpName("b"), DT_FLOAT); - auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)}); - auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]); - - GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - GraphDef result_graph_def; - TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( - graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def)); - - // Only one output of shape_n node is fed input. Hence the graph search - // should propagate to inputs of shape_n. Nothing to remove here. - std::map<string, const NodeDef*> node_map; - graph_transforms::MapNamesToNodes(result_graph_def, &node_map); - EXPECT_EQ(1, node_map.count("a")); - EXPECT_EQ(1, node_map.count("b")); - EXPECT_EQ(1, node_map.count("c")); - - result_graph_def.Clear(); - TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( - graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}}, - &result_graph_def)); - - // Both outputs of shape_n node are fed inputs. shape_n does not function - // and inputs to shape_n should be removed. - node_map.clear(); - graph_transforms::MapNamesToNodes(result_graph_def, &node_map); - EXPECT_EQ(0, node_map.count("a")); - EXPECT_EQ(0, node_map.count("b")); - EXPECT_EQ(1, node_map.count("c")); - } - void TestMaxConstantSizeInBytes() { auto root = tensorflow::Scope::NewRootScope(); @@ -431,10 +389,6 @@ TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) { TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); } -TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) { - TestRemoveUnusedNodesMultipleOutputs(); -} - TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) { TestMaxConstantSizeInBytes(); } |