aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-08 04:12:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 04:14:40 -0700
commit1c241ba791f578a67c80e932cbbb06b5af5ca81a (patch)
treeaf99cf35186e0d49a5eb48a039f10b524405ea1d /tensorflow/tools/graph_transforms
parent16c1d25110e48b8cecbf61ea8e15a7c9da26dd83 (diff)
Fix RemoveUnusedNodes generating invalid graphs for PlaceholderWithDefault inputs
PiperOrigin-RevId: 199776409
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc26
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc46
2 files changed, 26 insertions, 46 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 85660f94a8..f858411876 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -117,6 +117,31 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
return Status::OK();
}
+Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
+ GraphDef* graph_def) {
+ std::unordered_set<string> input_names;
+ for (const string& input_name : context.input_names) {
+ input_names.insert(ParseTensorName(input_name).first.ToString());
+ }
+
+ for (NodeDef& node : *graph_def->mutable_node()) {
+ if (input_names.find(node.name()) == input_names.end()) {
+ continue;
+ }
+ if (node.op() == "PlaceholderWithDefault") {
+ node.set_op("Placeholder");
+ node.clear_input();
+ } else if (node.op() != "Placeholder") {
+ return errors::InvalidArgument(
+ "Input '", node.name(),
+ "' was expected to be a Placeholder or PlaceholderWithDefault op, "
+ "but was ",
+ node.op());
+ }
+ }
+ return Status::OK();
+}
+
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
@@ -165,6 +190,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
input_graph_def,
[&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
output_graph_def);
+ TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def));
return Status::OK();
}
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index a082399a87..dcdc3c2906 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -330,48 +330,6 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("unused"));
}
- void TestRemoveUnusedNodesMultipleOutputs() {
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- auto root = tensorflow::Scope::NewRootScope();
-
- // a b
- // \ /
- // shape_n
- // \ /
- // c
- auto a = Placeholder(root.WithOpName("a"), DT_FLOAT);
- auto b = Placeholder(root.WithOpName("b"), DT_FLOAT);
- auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)});
- auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]);
-
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
- GraphDef result_graph_def;
- TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
- graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def));
-
- // Only one output of shape_n node is fed input. Hence the graph search
- // should propagate to inputs of shape_n. Nothing to remove here.
- std::map<string, const NodeDef*> node_map;
- graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
- EXPECT_EQ(1, node_map.count("a"));
- EXPECT_EQ(1, node_map.count("b"));
- EXPECT_EQ(1, node_map.count("c"));
-
- result_graph_def.Clear();
- TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
- graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}},
- &result_graph_def));
-
- // Both outputs of shape_n node are fed inputs. shape_n does not function
- // and inputs to shape_n should be removed.
- node_map.clear();
- graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
- EXPECT_EQ(0, node_map.count("a"));
- EXPECT_EQ(0, node_map.count("b"));
- EXPECT_EQ(1, node_map.count("c"));
- }
-
void TestMaxConstantSizeInBytes() {
auto root = tensorflow::Scope::NewRootScope();
@@ -431,10 +389,6 @@ TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }
-TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
- TestRemoveUnusedNodesMultipleOutputs();
-}
-
TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
TestMaxConstantSizeInBytes();
}