aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-11 10:20:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-11 10:26:08 -0800
commit4ba3147461f2cd1b73029f986cf806b33d0ce290 (patch)
tree547e5f1567a12ca1afa194b3410ca0a77e8abedd
parent7eba57baec4442640f11059caecfc10898966e00 (diff)
Enable identity reshape and common factor hoisting optimizations.
PiperOrigin-RevId: 181625889
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc23
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc13
2 files changed, 14 insertions, 22 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d6bc8614f9..fe0af3434a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -632,12 +632,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
// If the reshape is a no-op, forward its input to its consumers. This is
- // considered aggressive and turned off by default, because users may state
- // that the placeholder outputs tensors of shape [M, N] while feeding it
- // with tensors of shape [M*N] (or worse). The reshape nodes are then
- // necessary to update the tensor metadata to the required shape.
- if (opt_level_ == RewriterConfig::AGGRESSIVE &&
- ReshapeIsIdentity(*reshape, *input, output_pos)) {
+ // considered aggressive, because users may state that the placeholder
+ // outputs tensors of shape [M, N] while feeding it with tensors of shape
+ // [M*N] (or worse). The reshape nodes are then necessary to update the
+ // tensor metadata to the required shape.
+ if (ReshapeIsIdentity(*reshape, *input, output_pos)) {
return reshape->input(0);
}
}
@@ -896,8 +895,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
// Mul(x, AddN(y1, y2, y3, ... yn))
- if (opt_level_ == RewriterConfig::AGGRESSIVE && IsAggregate(*node) &&
- NumNonControlInputs(*node) > 1 &&
+ if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
!OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) {
// Determine the set of common factors if the input nodes are all Mul nodes.
std::set<string> common_factors;
@@ -1110,12 +1108,9 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
&frame_map_, &num_frames));
// Shapes are only needed in aggressive mode.
- if (opt_level_ == RewriterConfig::AGGRESSIVE) {
- graph_properties_.reset(new GraphProperties(item));
- TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
- TF_RETURN_IF_ERROR(
- graph_properties_->AnnotateOutputShapes(optimized_graph_));
- }
+ graph_properties_.reset(new GraphProperties(item));
+ TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+ TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
// Perform the optimizations.
DedupComputations();
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index da4263ff42..b5b1ec7021 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -350,7 +350,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device(devices[i]);
}
- ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ ArithmeticOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -423,7 +423,7 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ ArithmeticOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -625,8 +625,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
- .Optimize(nullptr, item, &output));
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph = output;
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
@@ -650,8 +649,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
- .Optimize(nullptr, item, &output));
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph = output;
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
@@ -673,8 +671,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
- .Optimize(nullptr, item, &output));
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph = output;
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));