diff options
author | 2018-09-25 09:32:30 -0700 | |
---|---|---|
committer | 2018-09-25 09:39:44 -0700 | |
commit | aee2ab023837adbfc61253ffec07f8d2dcd6c2a8 (patch) | |
tree | 835e00c9055acc6ba56310221351490867126c4b /tensorflow/core/grappler | |
parent | c0b63bef59bd2a94de2d1925259d1499d3ad04ea (diff) |
Fix a bug in debug_stripper.
AsControlDependency accepts a node name not a tensor name.
PiperOrigin-RevId: 214451885
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/debug_stripper.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/debug_stripper_test.cc | 29 |
2 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc index 9701a038d0..800160e649 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc @@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, // be optimized away by dependency optimizer. for (string& inp : *node.mutable_input()) { if (!IsControlInput(inp)) { - inp = AsControlDependency(inp); + inp = AsControlDependency(NodeName(inp)); } } } else if (IsCheckNumerics(node) || IsPrint(node)) { @@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, // input. for (size_t i = 1; i < node.input_size(); ++i) { if (!IsControlInput(node.input(i))) { - *node.mutable_input(i) = AsControlDependency(node.input(i)); + *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i))); } } } diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc index 96ceee791f..affd2d51c2 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc @@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) { CompareGraphs(item.graph, output); } +TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape({6})); + auto split = + ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2); + Output x = split[0]; + Output y = split[1]; + Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y); + auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y}); + Output add = ops::Add( + s.WithOpName("add").WithControlDependencies({assert.operation}), x, y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DebugStripper optimizer; + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + for (const NodeDef& node : output.node()) { + for (const string& input : node.input()) { + if (IsControlInput(input)) { + EXPECT_EQ(input.find(':'), -1); + } + } + } +} + TEST_F(DebugStripperTest, StripAssertFromGraph) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, |