Diffstat (limited to 'tensorflow/compiler/xla/service/compiler.cc')
-rw-r--r-- tensorflow/compiler/xla/service/compiler.cc | 96
1 file changed, 96 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
new file mode 100644
index 0000000000..f71b2b6b9c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_;
+
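+// Allocates platform_compiler_mutex_ exactly once on first use, so that
+// registrations performed from static initializers in other translation
+// units do not depend on static initialization order.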
+/* static */ void Compiler::LazyInitMutex() {
+ static std::once_flag mutex_init_flag;
+ std::call_once(mutex_init_flag, []() {
+ Compiler::platform_compiler_mutex_ = new tensorflow::mutex;
+ });
+}
+
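+// The factory and compiler registries below are function-local, heap-allocated
+// statics that are never deleted, which avoids static destruction-order
+// problems at process shutdown.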
+/* static */ std::map<perftools::gputools::Platform::Id,
+ Compiler::CompilerFactory>*
+Compiler::GetPlatformCompilerFactories() {
+ static auto* r =
+ new std::map<perftools::gputools::Platform::Id, CompilerFactory>;
+ return r;
+}
+
+/* static */
+std::map<perftools::gputools::Platform::Id, std::unique_ptr<Compiler>>*
+Compiler::GetPlatformCompilers() {
+ static auto* r = new std::map<perftools::gputools::Platform::Id,
+ std::unique_ptr<Compiler>>;
+ return r;
+}
+
+/* static */ void Compiler::RegisterCompilerFactory(
+ se::Platform::Id platform_id,
+ std::function<std::unique_ptr<Compiler>()> compiler_factory) {
+ LazyInitMutex();
+ tensorflow::mutex_lock lock(*platform_compiler_mutex_);
+ auto* factories = GetPlatformCompilerFactories();
+ CHECK(factories->find(platform_id) == factories->end());
+ (*factories)[platform_id] = std::move(compiler_factory);
+}
+
+/* static */ StatusOr<Compiler*> Compiler::GetForPlatform(
+ const se::Platform* platform) {
+ LazyInitMutex();
+ tensorflow::mutex_lock lock(*platform_compiler_mutex_);
+
+ auto* compilers = GetPlatformCompilers();
+ // See if we already instantiated a compiler for this platform.
+ {
+ auto it = compilers->find(platform->id());
+ if (it != compilers->end()) {
+ return it->second.get();
+ }
+
+ // If not, we just fall through to try to create one with a registered
+ // factory.
+ }
+
+ auto* factories = GetPlatformCompilerFactories();
+ auto it = factories->find(platform->id());
+ if (it == factories->end()) {
+ return NotFound(
+ "could not find registered compiler for platform %s -- check "
+ "target linkage",
+ platform->Name().c_str());
+ }
+
+ // And then we invoke the factory, placing the result into the mapping.
+ compilers->insert(std::make_pair(platform->id(), it->second()));
+ return compilers->at(platform->id()).get();
+}
+
+} // namespace xla
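
For reference, a minimal usage sketch of the registration and lookup path added above. The backend type (MyCompiler), platform id (kMyPlatformId), and registration helper are hypothetical placeholders, not part of this change:

// Hypothetical backend self-registering its factory, typically from a static
// initializer in the backend's own translation unit.
static bool RegisterMyCompiler() {
  xla::Compiler::RegisterCompilerFactory(kMyPlatformId, []() {
    return std::unique_ptr<xla::Compiler>(new MyCompiler());
  });
  return true;
}
static bool my_compiler_registered = RegisterMyCompiler();

// Later, given a se::Platform* platform whose id matches a registered factory:
xla::StatusOr<xla::Compiler*> compiler_or =
    xla::Compiler::GetForPlatform(platform);
if (!compiler_or.ok()) {
  return compiler_or.status();  // No compiler was linked in for this platform.
}
xla::Compiler* compiler = compiler_or.ValueOrDie();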