diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
4 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 3f573cda10..ad2db685fc 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -243,6 +243,7 @@ cc_library( deps = [ ":graph_optimizer", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h index a80d46f416..4d7f8c98d0 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { @@ -26,7 +27,8 @@ namespace grappler { class CustomGraphOptimizer : public GraphOptimizer { public: virtual ~CustomGraphOptimizer() {} - virtual Status Init() = 0; + virtual Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* + config = nullptr) = 0; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc index 629f5e83c1..bdb1ae8532 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc @@ -32,7 +32,10 @@ static const char* kTestOptimizerName = "Test"; class TestGraphOptimizer : public CustomGraphOptimizer { public: - Status Init() override { return Status::OK(); } + Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config = + nullptr) override { + return Status::OK(); + } string name() const override { return kTestOptimizerName; } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index d9a386b9be..9fcf07651b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -36,7 +36,10 @@ class TestOptimizer : public CustomGraphOptimizer { TestOptimizer() {} string name() const override { return "test_optimizer"; } - Status Init() override { return Status::OK(); } + Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config = + nullptr) override { + return Status::OK(); + } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override { |