aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-11-14 15:16:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-14 15:21:07 -0800
commit6357bafeb80523c45bee21a19def146d221cd295 (patch)
tree7dbcca4b1a2b4e5a148485dd7131511152546cd8
parentf89cffd37c88e4d9fa0ee3ac191e6f5fd5c005c8 (diff)
Use Toggle instead of bool to make the layout optimizer name and usage consistent with other optimizers.
PiperOrigin-RevId: 175743440
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc5
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto2
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py10
4 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index ead44de1e2..e2db47b758 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -57,7 +57,7 @@ void Cluster::DisableOptimizer(bool disable) {
// Disable Grappler optimizations.
auto rewriter_config =
options_.config.mutable_graph_options()->mutable_rewrite_options();
- rewriter_config->set_optimize_tensor_layout(false);
+ rewriter_config->set_layout_optimizer(RewriterConfig::OFF);
rewriter_config->set_disable_model_pruning(true);
rewriter_config->set_constant_folding(RewriterConfig::OFF);
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 6204a81f80..eb04bc6e9a 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -71,7 +71,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new ArithmeticOptimizer(cfg_.arithmetic_optimization())));
}
- if (cfg_.optimize_tensor_layout()) {
+ if (cfg_.layout_optimizer() == RewriterConfig::ON) {
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
@@ -175,7 +175,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
- return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() ||
+ return !cfg.disable_model_pruning() ||
+ cfg.layout_optimizer() == RewriterConfig::ON ||
cfg.constant_folding() != RewriterConfig::OFF ||
cfg.arithmetic_optimization() != RewriterConfig::OFF ||
cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 ||
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 8f3457e97c..eb74d4b1c5 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -30,7 +30,7 @@ message RewriterConfig {
}
// Optimize tensor layouts
- bool optimize_tensor_layout = 1;
+ Toggle layout_optimizer = 1;
// Fold constants (default is ON)
Toggle constant_folding = 3;
// Arithmetic optimizations (default is ON)
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index bc9d910447..9ac33fbb4a 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -88,8 +88,12 @@ def loop():
def get_config(layout_optimizer=True):
- rewrite_options = rewriter_config_pb2.RewriterConfig(
- optimize_tensor_layout=layout_optimizer)
+ if layout_optimizer:
+ rewrite_options = rewriter_config_pb2.RewriterConfig(
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
+ else:
+ rewrite_options = rewriter_config_pb2.RewriterConfig(
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(
rewrite_options=rewrite_options, build_cost_model=1)
config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -194,7 +198,7 @@ class LayoutOptimizerTest(test.TestCase):
meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def())
rewrite_options = rewriter_config_pb2.RewriterConfig(
- optimize_tensor_layout=True)
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph)
found = 0