From 6357bafeb80523c45bee21a19def146d221cd295 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Tue, 14 Nov 2017 15:16:45 -0800 Subject: Use Toggle instead of bool to make the layout optimizer name and usage consistent with other optimizers. PiperOrigin-RevId: 175743440 --- tensorflow/core/grappler/clusters/cluster.cc | 2 +- tensorflow/core/grappler/optimizers/meta_optimizer.cc | 5 +++-- tensorflow/core/protobuf/rewriter_config.proto | 2 +- tensorflow/python/grappler/layout_optimizer_test.py | 10 +++++++--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index ead44de1e2..e2db47b758 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -57,7 +57,7 @@ void Cluster::DisableOptimizer(bool disable) { // Disable Grappler optimizations. auto rewriter_config = options_.config.mutable_graph_options()->mutable_rewrite_options(); - rewriter_config->set_optimize_tensor_layout(false); + rewriter_config->set_layout_optimizer(RewriterConfig::OFF); rewriter_config->set_disable_model_pruning(true); rewriter_config->set_constant_folding(RewriterConfig::OFF); rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6204a81f80..eb04bc6e9a 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -71,7 +71,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr( new ArithmeticOptimizer(cfg_.arithmetic_optimization()))); } - if (cfg_.optimize_tensor_layout()) { + if (cfg_.layout_optimizer() == RewriterConfig::ON) { optimizers.push_back( std::unique_ptr(new LayoutOptimizer())); } @@ -175,7 +175,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, } bool MetaOptimizerEnabled(const RewriterConfig& cfg) { - return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() || + return !cfg.disable_model_pruning() || + cfg.layout_optimizer() == RewriterConfig::ON || cfg.constant_folding() != RewriterConfig::OFF || cfg.arithmetic_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 8f3457e97c..eb74d4b1c5 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -30,7 +30,7 @@ message RewriterConfig { } // Optimize tensor layouts - bool optimize_tensor_layout = 1; + Toggle layout_optimizer = 1; // Fold constants (default is ON) Toggle constant_folding = 3; // Arithmetic optimizations (default is ON) diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index bc9d910447..9ac33fbb4a 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -88,8 +88,12 @@ def loop(): def get_config(layout_optimizer=True): - rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=layout_optimizer) + if layout_optimizer: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + else: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -194,7 +198,7 @@ class LayoutOptimizerTest(test.TestCase): meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def()) rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=True) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph) found = 0 -- cgit v1.2.3