diff options
author | 2018-02-23 16:04:38 -0800 | |
---|---|---|
committer | 2018-02-23 16:08:55 -0800 | |
commit | beed05217cf8c3d90784a66cec7c97e042ff5258 (patch) | |
tree | 05205522cf14fb20515b63b63a40f8de034c9a1a /tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc | |
parent | bd946a5bd7b59be8bb276fdd93e0a97653dedbfd (diff) |
Add custom registered graph optimizers run by MetaOptimizer.
PiperOrigin-RevId: 186837828
Diffstat (limited to 'tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc new file mode 100644 index 0000000000..6eed43c2b1 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +#include <string> +#include <unordered_map> + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace grappler { + +namespace { +typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator> + RegistrationMap; +RegistrationMap* registered_optimizers = nullptr; +RegistrationMap* GetRegistrationMap() { + if (registered_optimizers == nullptr) + registered_optimizers = new RegistrationMap; + return registered_optimizers; +} +} // namespace + +std::unique_ptr<CustomGraphOptimizer> +CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) { + const auto it = GetRegistrationMap()->find(name); + if (it == GetRegistrationMap()->end()) return nullptr; + return std::unique_ptr<CustomGraphOptimizer>(it->second()); +} + +std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() { + std::vector<string> optimizer_names; + optimizer_names.reserve(GetRegistrationMap()->size()); + for (const auto& opt : *GetRegistrationMap()) + optimizer_names.emplace_back(opt.first); + return optimizer_names; +} + +void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + const Creator& optimizer_creator, const string& name) { + const auto it = GetRegistrationMap()->find(name); + if (it != GetRegistrationMap()->end()) { + LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name; + } + GetRegistrationMap()->insert({name, optimizer_creator}); +} + +} // end namespace grappler +} // end namespace tensorflow |