diff options
author | 2017-11-14 15:16:45 -0800 | |
---|---|---|
committer | 2017-11-14 15:21:07 -0800 | |
commit | 6357bafeb80523c45bee21a19def146d221cd295 (patch) | |
tree | 7dbcca4b1a2b4e5a148485dd7131511152546cd8 | |
parent | f89cffd37c88e4d9fa0ee3ac191e6f5fd5c005c8 (diff) |
Use Toggle instead of bool to make the layout optimizer name and usage consistent with other optimizers.
PiperOrigin-RevId: 175743440
4 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index ead44de1e2..e2db47b758 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -57,7 +57,7 @@ void Cluster::DisableOptimizer(bool disable) { // Disable Grappler optimizations. auto rewriter_config = options_.config.mutable_graph_options()->mutable_rewrite_options(); - rewriter_config->set_optimize_tensor_layout(false); + rewriter_config->set_layout_optimizer(RewriterConfig::OFF); rewriter_config->set_disable_model_pruning(true); rewriter_config->set_constant_folding(RewriterConfig::OFF); rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6204a81f80..eb04bc6e9a 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -71,7 +71,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr<GraphOptimizer>( new ArithmeticOptimizer(cfg_.arithmetic_optimization()))); } - if (cfg_.optimize_tensor_layout()) { + if (cfg_.layout_optimizer() == RewriterConfig::ON) { optimizers.push_back( std::unique_ptr<GraphOptimizer>(new LayoutOptimizer())); } @@ -175,7 +175,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, } bool MetaOptimizerEnabled(const RewriterConfig& cfg) { - return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() || + return !cfg.disable_model_pruning() || + cfg.layout_optimizer() == RewriterConfig::ON || cfg.constant_folding() != RewriterConfig::OFF || cfg.arithmetic_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 8f3457e97c..eb74d4b1c5 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -30,7 +30,7 @@ message RewriterConfig { } // Optimize tensor layouts - bool optimize_tensor_layout = 1; + Toggle layout_optimizer = 1; // Fold constants (default is ON) Toggle constant_folding = 3; // Arithmetic optimizations (default is ON) diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index bc9d910447..9ac33fbb4a 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -88,8 +88,12 @@ def loop(): def get_config(layout_optimizer=True): - rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=layout_optimizer) + if layout_optimizer: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + else: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -194,7 +198,7 @@ class LayoutOptimizerTest(test.TestCase): meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def()) rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=True) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph) found = 0 |