aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h1
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto4
5 files changed, 59 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 4e41c2bb12..bd96e2b33c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -17,6 +17,7 @@ filegroup(
srcs = glob(
[
"*_optimizer.*",
+ "constant_folding.*",
"model_pruner.*",
"graph_rewriter.*",
],
@@ -175,6 +176,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":constant_folding",
":graph_optimizer",
":layout_optimizer",
":model_pruner",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7cddedef2e..8f79c55810 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -72,8 +72,7 @@ class DeviceSimple : public DeviceBase {
Tensor* tensor) override {
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- tensor_proto.DebugString());
+ return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
}
*tensor = parsed;
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 44a1f5bab9..d82d5a469d 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/lib/core/status.h"
@@ -21,25 +23,64 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
+ const string& optimizer) {
+ VLOG(1) << "Adding graph optimization pass: " << optimizer;
+ std::unique_ptr<GraphOptimizer> graph_optimizer;
+ if (optimizer == "pruning") {
+ graph_optimizer.reset(new ModelPruner());
+ }
+ if (optimizer == "constfold") {
+ graph_optimizer.reset(new ConstantFolding());
+ }
+ if (optimizer == "layout") {
+ graph_optimizer.reset(new LayoutOptimizer());
+ }
+ return graph_optimizer;
+}
+
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- bool already_optimized = false;
- if (!cfg_.disable_model_pruning()) {
- already_optimized = true;
- ModelPruner pruner;
- TF_RETURN_IF_ERROR(pruner.Optimize(nullptr, item, optimized_graph));
+ std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
+ if (cfg_.optimizers().empty()) {
+ if (!cfg_.disable_model_pruning()) {
+ optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
+ }
+ if (cfg_.constant_folding()) {
+ optimizers.push_back(
+ std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+ }
+ if (cfg_.optimize_tensor_layout()) {
+ optimizers.push_back(
+ std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
+ }
+ } else {
+    std::set<string> available_optimizers = {"pruning", "constfold", "layout"};
+    for (const auto& optimizer : cfg_.optimizers()) {
+      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
+ optimizers.push_back(NewOptimizer(optimizer));
+ }
+ }
+ }
+
+ if (optimizers.empty()) {
+ *optimized_graph = item.graph;
+ return Status::OK();
}
- if (cfg_.optimize_tensor_layout()) {
- LayoutOptimizer layout_optimizer;
+
+ bool already_optimized = false;
+ for (const auto& optimizer : optimizers) {
if (!already_optimized) {
- return layout_optimizer.Optimize(nullptr, item, optimized_graph);
+ TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
+ already_optimized = true;
} else {
GrapplerItem optimized_item = item;
optimized_item.graph = *optimized_graph;
- return layout_optimizer.Optimize(nullptr, optimized_item,
- optimized_graph);
+ TF_RETURN_IF_ERROR(
+ optimizer->Optimize(nullptr, optimized_item, optimized_graph));
}
}
+
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index d7ff03f590..9def2cd711 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -39,6 +39,7 @@ class MetaOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
+ std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
RewriterConfig cfg_;
};
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index aef69461d8..6e9eff6225 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -9,4 +9,8 @@ option java_package = "org.tensorflow.framework";
message RewriterConfig {
bool optimize_tensor_layout = 1;
bool disable_model_pruning = 2;
+ bool constant_folding = 3;
+ // If non-empty, will use this as an alternative way to specify a list of
+ // optimizations to turn on and the order of the optimizations.
+ repeated string optimizers = 100;
}