diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-06-05 12:19:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 12:22:39 -0700 |
commit | 2b5f598fbd822f911ad305ae1e57325aefd50826 (patch) | |
tree | 30ced01eceaa62a99ea7908688df5f79bf4c46d6 | |
parent | 920df27282b3f5d03d79f54ef05cea305c2a30d7 (diff) |
Move ReplaceMulWithSquare to a separate optimizer stage.
PiperOrigin-RevId: 199338297
3 files changed, 73 insertions, 43 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 400af82627..561930f858 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2079,6 +2079,49 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
   }
 };
 
+// Replace Mul node with identical inputs with a Square.
+class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
+ public:
+  explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
+                                const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
+  ~ReplaceMulWithSquare() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsMul(*node) && node->input(0) == node->input(1);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
+    const string optimized_node_name = OptimizedNodeName(mul);
+    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+    const DataType type = GetDataTypeFromAttr(*node, "T");
+    bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+
+    string task;
+    string device;
+    bool is_on_cpu =
+        DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+        str_util::StrContains(device, DEVICE_CPU);
+
+    if (!is_complex || is_on_cpu) {
+      NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
+      new_square_node->set_op("Square");
+      for (int i = 1; i < new_square_node->input_size(); ++i) {
+        new_square_node->set_input(i - 1, new_square_node->input(i));
+      }
+      new_square_node->mutable_input()->RemoveLast();
+      for (const string& input : new_square_node->input()) {
+        ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
+      }
+      *simplified_node_name = new_square_node->name();
+    }
+
+    return Status::OK();
+  }
+};
+
 } // namespace
 
 class UniqueNodes {
@@ -2331,29 +2374,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
 // ArithmeticOptimizerStage
 string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
-  if (node->op() == "Mul" && node->input(0) == node->input(1) &&
-      !OptimizedNodeExists(*node, "square")) {
-    const DataType type = GetDataTypeFromAttr(*node, "T");
-    bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
-    string dontcare;
-    string device;
-    bool is_on_cpu =
-        DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) &&
-        str_util::StrContains(device, DEVICE_CPU);
-    if (!is_complex || is_on_cpu) {
-      NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true);
-      new_square_node->set_op("Square");
-      for (int i = 1; i < new_square_node->input_size(); ++i) {
-        new_square_node->set_input(i - 1, new_square_node->input(i));
-      }
-      new_square_node->mutable_input()->RemoveLast();
-      for (const string& input : new_square_node->input()) {
-        node_map_->AddOutput(NodeName(input), new_square_node->name());
-      }
-      return new_square_node->name();
-    }
-  }
-
   if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
     // Discard aggregate nodes with a single input and no control dependencies.
     if (node->input_size() == 1) {
@@ -2528,6 +2548,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
   if (options_.remove_negation)
     pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+  if (options_.replace_mul_with_square)
+    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
   if (options_.remove_logical_not)
     pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
   if (options_.reorder_cast_and_transpose)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index e6fc311929..8e00b83a70 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -74,6 +74,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool remove_redundant_cast = true;
     bool remove_redundant_reshape = true;
     bool reorder_cast_and_transpose = true;
+    bool replace_mul_with_square = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index b9fec0f860..f15cbfe407 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -139,6 +139,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     options.remove_negation = false;
     options.remove_logical_not = false;
     options.reorder_cast_and_transpose = false;
+    options.replace_mul_with_square = false;
     optimizer->options_ = options;
   }
 
@@ -201,6 +202,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     optimizer->options_.reorder_cast_and_transpose = true;
   }
 
+  void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.replace_mul_with_square = true;
+  }
+
   void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.hoist_cwise_unary_chains = true;
@@ -345,33 +351,36 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, MulToSquare) {
+TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
   Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
   Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
   Output id = ops::Identity(s.WithOpName("id"), mul);
+
   GrapplerItem item;
+  item.fetch = {"id"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  std::vector<string> fetch = {"id"};
-  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
   EXPECT_EQ(1, tensors_expected.size());
 
-  ArithmeticOptimizer optimizer;
   GraphDef output;
-  Status status = optimizer.Optimize(nullptr, item, &output);
-  TF_EXPECT_OK(status);
+  ArithmeticOptimizer optimizer;
+  EnableOnlyReplaceMulWithSquare(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
 
-  EXPECT_EQ(5, output.node_size());
-  EXPECT_EQ("id", output.node(3).name());
-  EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0));
-  EXPECT_EQ("Square", output.node(4).op());
-  EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name());
-  EXPECT_EQ(2, output.node(4).input_size());
-  EXPECT_EQ("c", output.node(4).input(0));
-  EXPECT_EQ("^d", output.node(4).input(1));
+  EXPECT_EQ(4, output.node_size());
 
-  auto tensors = EvaluateNodes(output, fetch);
+  NodeMap node_map(&output);
+  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
+  const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul"));
+
+  ASSERT_NE(square_node, nullptr);
+  EXPECT_EQ("Square", square_node->op());
+  EXPECT_EQ("c", square_node->input(0));
+  EXPECT_EQ("^d", square_node->input(1));
+
+  auto tensors = EvaluateNodes(output, item.fetch);
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
@@ -386,12 +395,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
   auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
   auto id = ops::Identity(s.WithOpName("id"), recip2);
 
-  std::vector<string> fetch = {"id"};
-
   GrapplerItem item;
-  item.fetch = fetch;
+  item.fetch = {"id"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
   EXPECT_EQ(1, tensors_expected.size());
 
   GraphDef output;
@@ -404,7 +411,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
   EXPECT_EQ("id", output.node(1).name());
   EXPECT_EQ("c", output.node(1).input(0));
 
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors = EvaluateNodes(output, item.fetch);
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0],
                                 1e-6);
 }