diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-30 14:55:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-30 15:28:43 -0700 |
commit | d15f77048558a7af16648146faca1c5d13d8d6e1 (patch) | |
tree | 098fc91e752605870bc56b04251bd0198d991285 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | e469934f1274c7c498e5061995fec425a21c9be8 (diff) |
Move RemoveInvolution optimization to optimizer stage.
PiperOrigin-RevId: 198624394
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 130 |
1 files changed, 75 insertions, 55 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 64fdc8a83b..a908416e45 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -115,12 +115,17 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.dedup_computations = false; options.enable_try_simplify_and_replace = false; options.combine_add_to_addn = false; + options.convert_sqrt_div_to_rsqrt_mul = false; options.hoist_common_factor_out_of_aggregation = false; + options.hoist_cwise_unary_chains = false; options.minimize_broadcasts = false; options.remove_identity_transpose = false; + options.remove_involution = false; + options.remove_idempotent = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; options.remove_negation = false; + options.remove_logical_not = false; optimizer->options_ = options; } @@ -148,6 +153,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_identity_transpose = true; } + void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_involution = true; + } + void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_redundant_bitcast = true; @@ -338,100 +348,110 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { +TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); - Output neg1 = ops::Neg(s.WithOpName("neg1"), c); - Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1); - Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2); - Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); - Output id = ops::Identity(s.WithOpName("id"), recip2); + + auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + auto neg1 = ops::Neg(s.WithOpName("neg1"), c); + auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1); + auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2); + auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); + auto id = ops::Identity(s.WithOpName("id"), recip2); + + std::vector<string> fetch = {"id"}; + GrapplerItem item; + item.fetch = fetch; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector<string> fetch = {"id"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInvolution(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); - EXPECT_EQ(6, output.node_size()); + // Negation and Reciprocal nodes cancelled each other. + EXPECT_EQ(2, output.node_size()); + EXPECT_EQ("id", output.node(1).name()); EXPECT_EQ("c", output.node(1).input(0)); - EXPECT_EQ("c", output.node(3).input(0)); - EXPECT_EQ("c", output.node(5).input(0)); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { +TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AroundValuePreservingChain) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); - Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); - Output id1 = ops::Identity(s.WithOpName("id1"), recip1); - Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); - Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze); - Output id2 = ops::Identity(s.WithOpName("id2"), recip2); + + auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); + auto id1 = ops::Identity(s.WithOpName("id1"), recip1); + auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); + auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze); + auto id2 = ops::Identity(s.WithOpName("id2"), recip2); + + std::vector<string> fetch = {"id2"}; + GrapplerItem item; + item.fetch = fetch; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector<string> fetch = {"id2"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInvolution(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); - EXPECT_EQ(6, output.node_size()); - EXPECT_EQ("squeeze", output.node(5).input(0)); - EXPECT_EQ("c", output.node(2).input(0)); + // Check that Reciprocal nodes were removed from the graph. + EXPECT_EQ(3, output.node_size()); + + // And const directly flows into squeeze. + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "squeeze") { + EXPECT_EQ("c", node.input(0)); + found++; + } else if (node.name() == "id2") { + EXPECT_EQ("squeeze", node.input(0)); + found++; + } + } + EXPECT_EQ(2, found); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { +TEST_F(ArithmeticOptimizerTest, RemoveInvolution_SkipControlDependencies) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); - Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); - Output id1 = ops::Identity(s.WithOpName("id1"), recip1); - Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); - Output recip2 = ops::Reciprocal( + + auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); + auto id1 = ops::Identity(s.WithOpName("id1"), recip1); + auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); + auto recip2 = ops::Reciprocal( s.WithOpName("recip2").WithControlDependencies(squeeze), c); - Output id2 = ops::Identity(s.WithOpName("id2"), recip2); + auto id2 = ops::Identity(s.WithOpName("id2"), recip2); + + std::vector<string> fetch = {"id2"}; + GrapplerItem item; + item.fetch = fetch; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector<string> fetch = {"id2"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInvolution(&optimizer); + OptimizeTwice(&optimizer, &item, &output); // do not prune in this test // The optimizer should be a noop. - EXPECT_EQ(item.graph.node_size(), output.node_size()); - for (int i = 0; i < item.graph.node_size(); ++i) { - const NodeDef& original = item.graph.node(i); - const NodeDef& optimized = output.node(i); - EXPECT_EQ(original.name(), optimized.name()); - EXPECT_EQ(original.op(), optimized.op()); - EXPECT_EQ(original.input_size(), optimized.input_size()); - for (int j = 0; j < original.input_size(); ++j) { - EXPECT_EQ(original.input(j), optimized.input(j)); - } - } + VerifyGraphsMatch(item.graph, output, __LINE__); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); @@ -2777,7 +2797,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { ArithmeticOptimizer optimizer; EnableOnlyRemoveLogicalNot(&optimizer); OptimizeTwice(&optimizer, &item, &output); - LOG(INFO) << output.DebugString(); + int found = 0; for (const NodeDef& node : output.node()) { if (node.name() == "id_not_eq") { |