aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2018-09-25 09:32:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 09:39:44 -0700
commitaee2ab023837adbfc61253ffec07f8d2dcd6c2a8 (patch)
tree835e00c9055acc6ba56310221351490867126c4b /tensorflow/core/grappler
parentc0b63bef59bd2a94de2d1925259d1499d3ad04ea (diff)
Fix a bug in debug_stripper.
AsControlDependency accepts a node name not a tensor name. PiperOrigin-RevId: 214451885
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc29
2 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 9701a038d0..800160e649 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// be optimized away by dependency optimizer.
for (string& inp : *node.mutable_input()) {
if (!IsControlInput(inp)) {
- inp = AsControlDependency(inp);
+ inp = AsControlDependency(NodeName(inp));
}
}
} else if (IsCheckNumerics(node) || IsPrint(node)) {
@@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// input.
for (size_t i = 1; i < node.input_size(); ++i) {
if (!IsControlInput(node.input(i))) {
- *node.mutable_input(i) = AsControlDependency(node.input(i));
+ *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i)));
}
}
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index 96ceee791f..affd2d51c2 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
CompareGraphs(item.graph, output);
}
+TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape({6}));
+ auto split =
+ ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2);
+ Output x = split[0];
+ Output y = split[1];
+ Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("add").WithControlDependencies({assert.operation}), x, y);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ for (const string& input : node.input()) {
+ if (IsControlInput(input)) {
+ EXPECT_EQ(input.find(':'), -1);
+ }
+ }
+ }
+}
+
TEST_F(DebugStripperTest, StripAssertFromGraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,