diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-31 14:01:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-31 14:03:45 -0700 |
commit | 395428bcaf02c9a9e8067083993d7e6b5afdc0a6 (patch) | |
tree | 028e3a7c9922edab67e89cddeaec37f90ac1bec7 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | d3b5b07e7810782c3760468312f9cace10b89073 (diff) |
Move RemodeRedundantReshape optimization to a separate stage.
PiperOrigin-RevId: 198775276
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 90 |
1 files changed, 49 insertions, 41 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f678ea7227..43355ef945 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -124,6 +124,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_idempotent = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; + options.remove_redundant_reshape = false; options.remove_negation = false; options.remove_logical_not = false; optimizer->options_ = options; @@ -168,6 +169,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_redundant_cast = true; } + void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_reshape = true; + } + void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_negation = true; @@ -955,7 +961,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, IdentityReshape) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28})); @@ -977,11 +983,11 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); @@ -989,7 +995,8 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1})); @@ -1009,27 +1016,28 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { Output reshape = ops::Reshape(s, inputs, target_shape); Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28})); GrapplerItem item; item.fetch = {"outputs"}; + item.feed = {{"Placeholder", x_t}}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28})); - auto tensors_expected = - EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + // Assume valid feed shape in aggressive mode. + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); - auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); @@ -1047,10 +1055,9 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { EXPECT_EQ(1, tensors_expected.size()); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); // The reshape is preserved because the shape of the placeholder can be // different from the shape of the actual feed. @@ -1061,7 +1068,8 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); @@ -1077,12 +1085,11 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); @@ -1090,7 +1097,7 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) { // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can // be from [4,3,28,28] to [8,6,28,28]. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -1106,11 +1113,11 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { item.feed = {{"Placeholder", x_t}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); @@ -1118,7 +1125,8 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3})); @@ -1128,16 +1136,16 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); } -TEST_F(ArithmeticOptimizerTest, CombineReshapes) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) { // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two // reshapes should be combined. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -1162,11 +1170,11 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { item.feed = {{"nchw_vect_c", x_t}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); |