diff options
author | 2018-01-11 10:20:51 -0800 | |
---|---|---|
committer | 2018-01-11 10:26:08 -0800 | |
commit | 4ba3147461f2cd1b73029f986cf806b33d0ce290 (patch) | |
tree | 547e5f1567a12ca1afa194b3410ca0a77e8abedd | |
parent | 7eba57baec4442640f11059caecfc10898966e00 (diff) |
Enable identity reshape and common factor hoisting optimizations.
PiperOrigin-RevId: 181625889
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 13 |
2 files changed, 14 insertions, 22 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d6bc8614f9..fe0af3434a 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -632,12 +632,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } // If the reshape is a no-op, forward its input to its consumers. This is - // considered aggressive and turned off by default, because users may state - // that the placeholder outputs tensors of shape [M, N] while feeding it - // with tensors of shape [M*N] (or worse). The reshape nodes are then - // necessary to update the tensor metadata to the required shape. - if (opt_level_ == RewriterConfig::AGGRESSIVE && - ReshapeIsIdentity(*reshape, *input, output_pos)) { + // considered aggressive, because users may state that the placeholder + // outputs tensors of shape [M, N] while feeding it with tensors of shape + // [M*N] (or worse). The reshape nodes are then necessary to update the + // tensor metadata to the required shape. + if (ReshapeIsIdentity(*reshape, *input, output_pos)) { return reshape->input(0); } } @@ -896,8 +895,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn)) // to the following: // Mul(x, AddN(y1, y2, y3, ... yn)) - if (opt_level_ == RewriterConfig::AGGRESSIVE && IsAggregate(*node) && - NumNonControlInputs(*node) > 1 && + if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 && !OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) { // Determine the set of common factors if the input nodes are all Mul nodes. std::set<string> common_factors; @@ -1110,12 +1108,9 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, &frame_map_, &num_frames)); // Shapes are only needed in aggressive mode. - if (opt_level_ == RewriterConfig::AGGRESSIVE) { - graph_properties_.reset(new GraphProperties(item)); - TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); - TF_RETURN_IF_ERROR( - graph_properties_->AnnotateOutputShapes(optimized_graph_)); - } + graph_properties_.reset(new GraphProperties(item)); + TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); + TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_)); // Perform the optimizations. DedupComputations(); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index da4263ff42..b5b1ec7021 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -350,7 +350,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { for (int i = 0; i < item.graph.node_size(); ++i) { item.graph.mutable_node(i)->set_device(devices[i]); } - ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + ArithmeticOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -423,7 +423,7 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + ArithmeticOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -625,8 +625,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); @@ -650,8 +649,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); @@ -673,8 +671,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); |