diff options
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 26 |
3 files changed, 13 insertions, 28 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index b4ddd61c29..bdeb5c66fc 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -629,7 +629,8 @@ bool HasOpDef(const NodeDef& node) { } bool IsIdempotent(const NodeDef& node) { - return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node); + return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) && + !ModifiesFrameInfo(node); } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d518685216..90be051764 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1722,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage { ~RemoveIdempotentStage() override = default; bool IsSupported(const NodeDef* node) const override { - return IsIdempotent(*node) && !IsInPreserveSet(*node); + return node->input_size() == 1 && IsIdempotent(*node) && + !IsInPreserveSet(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); - auto root_scope_and_name = ParseNodeScopeAndName(node->name()); - const string new_name = OptimizedNodeName(root_scope_and_name); - if (input->op() == node->op() && input->device() == node->device() && - IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) { - NodeDef* new_input_node = AddCopyNode(new_name, input); - ForwardControlDependencies(new_input_node, {node}); - *simplified_node_name = new_input_node->name(); + if (input->op() == node->op() && input->device() == node->device()) { + *simplified_node_name = node->input(0); } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index e1d55cdf5f..d0e6b04679 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2976,12 +2976,8 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); - Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {}); - Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {}); - Output sn1 = - ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a); - Output sn2 = - ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1); + Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a); + Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1); Output out1 = ops::Identity(s.WithOpName("out1"), sn2); Output id1 = ops::Identity(s.WithOpName("id1"), a); Output id2 = ops::Identity(s.WithOpName("id2"), id1); @@ -2997,32 +2993,24 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { EnableOnlyRemoveIdempotent(&optimizer); OptimizeTwice(&optimizer, &item, &output); - EXPECT_EQ(11, output.node_size()); + EXPECT_EQ(7, output.node_size()); int found = 0; for (const NodeDef& node : output.node()) { if (node.name() == "out1") { EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0)); - found++; - } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") { - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ("Snapshot", node.op()); - EXPECT_EQ("a", node.input(0)); - EXPECT_EQ("^ctrl1", node.input(1)); - EXPECT_EQ("^ctrl2", node.input(2)); + EXPECT_EQ("sn1", node.input(0)); found++; } else if (node.name() == "out2") { EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0)); + EXPECT_EQ("id1", node.input(0)); found++; - } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") { - EXPECT_EQ("Identity", node.op()); + } else if (node.name() == "sn1") { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("a", node.input(0)); found++; } } - EXPECT_EQ(4, found); + EXPECT_EQ(3, found); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(tensors.size(), tensors_expected.size()); |