diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 10:23:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 10:27:35 -0700 |
commit | 941b4e0f226de76f083401842e73bd9efd6db2d0 (patch) | |
tree | 52bb6ab9afb0d74fa1c5c437d0b7ca1978138926 /tensorflow/core/grappler | |
parent | 62e60166de65d6604b897f2575a5accc86160496 (diff) |
Fix support for custom optimizers in explicit schedule
PiperOrigin-RevId: 214794973
Diffstat (limited to 'tensorflow/core/grappler')
3 files changed, 56 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index c59645e5f2..e18a5f21d2 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -172,11 +172,12 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>( cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); } - return InitializeCustomGraphOptimizers(optimizers); + return InitializeCustomGraphOptimizers(std::set<string>(), optimizers); } Status MetaOptimizer::InitializeOptimizersByName( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { + std::set<string> initialized_custom_optimizers; for (const string& optimizer_name : cfg_.optimizers()) { auto optimizer = MakeNewOptimizer(optimizer_name); if (optimizer) { @@ -190,18 +191,26 @@ Status MetaOptimizer::InitializeOptimizersByName( if (custom_optimizer) { VLOG(2) << "Registered custom graph optimizer: " << optimizer_name; - TF_RETURN_IF_ERROR(custom_optimizer->Init()); + TF_RETURN_IF_ERROR(custom_optimizer->Init( + GetCustomGraphOptimizerConfig(optimizer_name))); optimizers->push_back(std::move(custom_optimizer)); + initialized_custom_optimizers.insert(optimizer_name); } else { VLOG(2) << "Can't register an optimizer by name: " << optimizer_name; } } - return InitializeCustomGraphOptimizers(optimizers); + return InitializeCustomGraphOptimizers(initialized_custom_optimizers, + optimizers); } Status MetaOptimizer::InitializeCustomGraphOptimizers( + const std::set<string>& pre_initialized_optimizers, std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { for (const auto& optimizer_config : cfg_.custom_optimizers()) { + if (pre_initialized_optimizers.find(optimizer_config.name()) != + pre_initialized_optimizers.end()) { + continue; + } // Initialize the ExperimentalImplementationSelector here instead of // CustomizeOptimizer registry, due the static link issue in TensorRT for // double registry. @@ -237,6 +246,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers( return Status::OK(); } +const RewriterConfig::CustomGraphOptimizer* +MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const { + for (const auto& config : cfg_.custom_optimizers()) { + if (config.name() == name) { + return &config; + } + } + return nullptr; +} + Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 831c5e37c0..99a0a33ffa 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer { std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; // Initialize active optimizers from RewriterConfig.custom_optimizers. Status InitializeCustomGraphOptimizers( + const std::set<string>& pre_initialized_optimizers, std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; + // Returns the config for a custom graph optimizer. Null if none was found. + const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig( + const string& name) const; // Run optimization pass over a single GrapplerItem. Meta optimizer might run // multiple such passes: 1) for the main graph 2) for the function library diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index e74e0f7501..c477c4d4b1 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -71,6 +71,17 @@ class TestGraphOptimizer : public TestOptimizer { REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer); +class TestOptimizerWithParams : public TestOptimizer { + public: + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + CHECK(config != nullptr); + return Status::OK(); + } +}; + +REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams); + class MetaOptimizerTest : public GrapplerTest {}; TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { @@ -90,6 +101,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizerWithParams"); + auto* custom_config = rewriter_config.add_custom_optimizers(); + custom_config->set_name("TestOptimizerWithParams"); + (*custom_config->mutable_parameter_map())["foo"] = AttrValue(); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; |