aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/op_types.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc26
3 files changed, 13 insertions, 28 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index b4ddd61c29..bdeb5c66fc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -629,7 +629,8 @@ bool HasOpDef(const NodeDef& node) {
}
bool IsIdempotent(const NodeDef& node) {
- return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node);
+ return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
+ !ModifiesFrameInfo(node);
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d518685216..90be051764 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1722,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage {
~RemoveIdempotentStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsIdempotent(*node) && !IsInPreserveSet(*node);
+ return node->input_size() == 1 && IsIdempotent(*node) &&
+ !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- auto root_scope_and_name = ParseNodeScopeAndName(node->name());
- const string new_name = OptimizedNodeName(root_scope_and_name);
- if (input->op() == node->op() && input->device() == node->device() &&
- IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) {
- NodeDef* new_input_node = AddCopyNode(new_name, input);
- ForwardControlDependencies(new_input_node, {node});
- *simplified_node_name = new_input_node->name();
+ if (input->op() == node->op() && input->device() == node->device()) {
+ *simplified_node_name = node->input(0);
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e1d55cdf5f..d0e6b04679 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2976,12 +2976,8 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
- Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
- Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
- Output sn1 =
- ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a);
- Output sn2 =
- ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1);
+ Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
+ Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
Output id1 = ops::Identity(s.WithOpName("id1"), a);
Output id2 = ops::Identity(s.WithOpName("id2"), id1);
@@ -2997,32 +2993,24 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
EnableOnlyRemoveIdempotent(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(11, output.node_size());
+ EXPECT_EQ(7, output.node_size());
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "out1") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0));
- found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") {
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("a", node.input(0));
- EXPECT_EQ("^ctrl1", node.input(1));
- EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("sn1", node.input(0));
found++;
} else if (node.name() == "out2") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0));
+ EXPECT_EQ("id1", node.input(0));
found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") {
- EXPECT_EQ("Identity", node.op());
+ } else if (node.name() == "sn1") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("a", node.input(0));
found++;
}
}
- EXPECT_EQ(4, found);
+ EXPECT_EQ(3, found);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());