aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-02-23 16:04:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-23 16:08:55 -0800
commitbeed05217cf8c3d90784a66cec7c97e042ff5258 (patch)
tree05205522cf14fb20515b63b63a40f8de034c9a1a /tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
parentbd946a5bd7b59be8bb276fdd93e0a97653dedbfd (diff)
Add custom registered graph optimizers run by MetaOptimizer.
PiperOrigin-RevId: 186837828
Diffstat (limited to 'tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
new file mode 100644
index 0000000000..6eed43c2b1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
+ RegistrationMap;
+RegistrationMap* registered_optimizers = nullptr;
+RegistrationMap* GetRegistrationMap() {
+ if (registered_optimizers == nullptr)
+ registered_optimizers = new RegistrationMap;
+ return registered_optimizers;
+}
+} // namespace
+
+std::unique_ptr<CustomGraphOptimizer>
+CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
+ const auto it = GetRegistrationMap()->find(name);
+ if (it == GetRegistrationMap()->end()) return nullptr;
+ return std::unique_ptr<CustomGraphOptimizer>(it->second());
+}
+
+std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
+ std::vector<string> optimizer_names;
+ optimizer_names.reserve(GetRegistrationMap()->size());
+ for (const auto& opt : *GetRegistrationMap())
+ optimizer_names.emplace_back(opt.first);
+ return optimizer_names;
+}
+
+void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+ const Creator& optimizer_creator, const string& name) {
+ const auto it = GetRegistrationMap()->find(name);
+ if (it != GetRegistrationMap()->end()) {
+ LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
+ }
+ GetRegistrationMap()->insert({name, optimizer_creator});
+}
+
+} // end namespace grappler
+} // end namespace tensorflow