aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/custom_graph_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc5
4 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 3f573cda10..ad2db685fc 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -243,6 +243,7 @@ cc_library(
deps = [
":graph_optimizer",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h
index a80d46f416..4d7f8c98d0 100644
--- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -26,7 +27,8 @@ namespace grappler {
class CustomGraphOptimizer : public GraphOptimizer {
public:
virtual ~CustomGraphOptimizer() {}
- virtual Status Init() = 0;
+ virtual Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer*
+ config = nullptr) = 0;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc
index 629f5e83c1..bdb1ae8532 100644
--- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc
@@ -32,7 +32,10 @@ static const char* kTestOptimizerName = "Test";
class TestGraphOptimizer : public CustomGraphOptimizer {
public:
- Status Init() override { return Status::OK(); }
+ Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
+ nullptr) override {
+ return Status::OK();
+ }
string name() const override { return kTestOptimizerName; }
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) override {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index d9a386b9be..9fcf07651b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -36,7 +36,10 @@ class TestOptimizer : public CustomGraphOptimizer {
TestOptimizer() {}
string name() const override { return "test_optimizer"; }
- Status Init() override { return Status::OK(); }
+ Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
+ nullptr) override {
+ return Status::OK();
+ }
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) override {