about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2018-06-05 12:19:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-05 12:22:39 -0700
commit: 2b5f598fbd822f911ad305ae1e57325aefd50826 (patch)
tree: 30ced01eceaa62a99ea7908688df5f79bf4c46d6
parent: 920df27282b3f5d03d79f54ef05cea305c2a30d7 (diff)
Move ReplaceMulWithSquare to a separate optimizer stage.
PiperOrigin-RevId: 199338297
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 68
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 1
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 47
3 files changed, 73 insertions, 43 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 400af82627..561930f858 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2079,6 +2079,49 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
}
};
+// Replace Mul node with identical inputs with a Square.
+class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
+ public:
+ explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
+ ~ReplaceMulWithSquare() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMul(*node) && node->input(0) == node->input(1);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(mul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ const DataType type = GetDataTypeFromAttr(*node, "T");
+ bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+
+ string task;
+ string device;
+ bool is_on_cpu =
+ DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+ str_util::StrContains(device, DEVICE_CPU);
+
+ if (!is_complex || is_on_cpu) {
+ NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
+ new_square_node->set_op("Square");
+ for (int i = 1; i < new_square_node->input_size(); ++i) {
+ new_square_node->set_input(i - 1, new_square_node->input(i));
+ }
+ new_square_node->mutable_input()->RemoveLast();
+ for (const string& input : new_square_node->input()) {
+ ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
+ }
+ *simplified_node_name = new_square_node->name();
+ }
+
+ return Status::OK();
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2331,29 +2374,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
// ArithmeticOptimizerStage
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
- if (node->op() == "Mul" && node->input(0) == node->input(1) &&
- !OptimizedNodeExists(*node, "square")) {
- const DataType type = GetDataTypeFromAttr(*node, "T");
- bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
- string dontcare;
- string device;
- bool is_on_cpu =
- DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) &&
- str_util::StrContains(device, DEVICE_CPU);
- if (!is_complex || is_on_cpu) {
- NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true);
- new_square_node->set_op("Square");
- for (int i = 1; i < new_square_node->input_size(); ++i) {
- new_square_node->set_input(i - 1, new_square_node->input(i));
- }
- new_square_node->mutable_input()->RemoveLast();
- for (const string& input : new_square_node->input()) {
- node_map_->AddOutput(NodeName(input), new_square_node->name());
- }
- return new_square_node->name();
- }
- }
-
if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
// Discard aggregate nodes with a single input and no control dependencies.
if (node->input_size() == 1) {
@@ -2528,6 +2548,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+ if (options_.replace_mul_with_square)
+ pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
if (options_.reorder_cast_and_transpose)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index e6fc311929..8e00b83a70 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -74,6 +74,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_cast = true;
bool remove_redundant_reshape = true;
bool reorder_cast_and_transpose = true;
+ bool replace_mul_with_square = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index b9fec0f860..f15cbfe407 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -139,6 +139,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_negation = false;
options.remove_logical_not = false;
options.reorder_cast_and_transpose = false;
+ options.replace_mul_with_square = false;
optimizer->options_ = options;
}
@@ -201,6 +202,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.reorder_cast_and_transpose = true;
}
+ void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.replace_mul_with_square = true;
+ }
+
void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_cwise_unary_chains = true;
@@ -345,33 +351,36 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, MulToSquare) {
+TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
Output id = ops::Identity(s.WithOpName("id"), mul);
+
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
- ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ EnableOnlyReplaceMulWithSquare(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
- EXPECT_EQ(5, output.node_size());
- EXPECT_EQ("id", output.node(3).name());
- EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0));
- EXPECT_EQ("Square", output.node(4).op());
- EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name());
- EXPECT_EQ(2, output.node(4).input_size());
- EXPECT_EQ("c", output.node(4).input(0));
- EXPECT_EQ("^d", output.node(4).input(1));
+ EXPECT_EQ(4, output.node_size());
- auto tensors = EvaluateNodes(output, fetch);
+ NodeMap node_map(&output);
+ const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
+ const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul"));
+
+ ASSERT_NE(square_node, nullptr);
+ EXPECT_EQ("Square", square_node->op());
+ EXPECT_EQ("c", square_node->input(0));
+ EXPECT_EQ("^d", square_node->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -386,12 +395,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
auto id = ops::Identity(s.WithOpName("id"), recip2);
- std::vector<string> fetch = {"id"};
-
GrapplerItem item;
- item.fetch = fetch;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
@@ -404,7 +411,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
EXPECT_EQ("id", output.node(1).name());
EXPECT_EQ("c", output.node(1).input(0));
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}