author     Sanjoy Das <sanjoy@google.com>    2017-10-24 10:41:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-10-24 10:46:59 -0700
commit     720efa37a4e93d5833e6e928993790f2523f0d85 (patch)
tree       c3c8709db2128acbdd096a884b29396baa5eff82
parent     1bbec9e4e9c5d3fbbc2fa2b58841435e86dbf76a (diff)
Roll forward CL 171084886
171084886 had to be rolled back twice due to various open source build issues. I'm trying again, now that I think I've addressed all the pertinent issues.

Original CL description:

Don't use dlsym to resolve symbols in the CPU JIT

Instead of resolving symbols via dlsym when JITting for the CPU backend, use a registry based mechanism. This lets us kill off the --export_dynamic hack that we used to need for CustomCall on the CPU backend.

PiperOrigin-RevId: 173277862
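For illustration, here is a minimal sketch (not part of this diff) of what the new scheme looks like from a CustomCall author's point of view. `MyR0F32Add1` is a hypothetical target; the header and the REGISTER_CUSTOM_CALL_TARGET macro are the ones added by this CL:

#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"

namespace {
// Hypothetical CustomCall target; the (result pointer, operand-pointer
// array) signature mirrors the argmax kernels and test targets below.
void MyR0F32Add1(float* out, float** in) {
  *out = **in + 1.0f;
}
}  // namespace

// Registers the function's address under the symbol string "MyR0F32Add1"
// at static-initialization time, so the CPU JIT can resolve it without
// dlsym() and without linking the binary with --export-dynamic.
REGISTER_CUSTOM_CALL_TARGET(MyR0F32Add1);

At JIT time, SimpleResolver::findSymbol (changed below) resolves the CustomCall's symbolic name through CustomCallTargetRegistry::Global()->Lookup("MyR0F32Add1") instead of asking the dynamic linker.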
Diffstat:
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD                                 |   3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc  |   3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc  |   3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD                                |  12
-rw-r--r--  tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc      |  39
-rw-r--r--  tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h       |  74
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc                   | 198
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                                      |   3
-rw-r--r--  tensorflow/compiler/xla/tests/custom_call_test.cc                        |  14
-rw-r--r--  tensorflow/compiler/xla/xla.bzl                                          |   8
10 files changed, 259 insertions(+), 98 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 4ee7989824..2b43e313eb 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -5,7 +5,6 @@ package(
)
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
-load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts")
tf_kernel_library(
name = "xla_ops",
@@ -153,6 +152,7 @@ cc_library(
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
@@ -164,6 +164,7 @@ cc_library(
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc
index afbd64ca50..47cf8c6675 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
@@ -47,3 +48,5 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_1d_xla_impl(out, data);
}
+
+REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc
index 841ff2f4df..9b83392d8f 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
@@ -49,3 +50,5 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_2d_xla_impl(out, data);
}
+
+REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 136cbe7cb7..56bc1a6706 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -153,6 +153,7 @@ cc_library(
":cpu_runtime_avx",
":cpu_runtime_neon",
":cpu_runtime_sse4_1",
+ ":custom_call_target_registry",
":disassembler",
":external_constant_pool",
":runtime_conv2d",
@@ -719,6 +720,17 @@ cc_library(
],
)
+cc_library(
+ name = "custom_call_target_registry",
+ srcs = [
+ "custom_call_target_registry.cc",
+ ],
+ hdrs = [
+ "custom_call_target_registry.h",
+ ],
+ visibility = ["//visibility:public"],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc
new file mode 100644
index 0000000000..5f5803874b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
+
+namespace xla {
+namespace cpu {
+
+CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
+ static auto* registry = new CustomCallTargetRegistry;
+ return registry;
+}
+
+void CustomCallTargetRegistry::Register(const std::string& symbol,
+ void* address) {
+ std::lock_guard<std::mutex> lock(mu_);
+ registered_symbols_[symbol] = address;
+}
+
+void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const {
+ std::lock_guard<std::mutex> lock(mu_);
+ auto it = registered_symbols_.find(symbol);
+ return it == registered_symbols_.end() ? nullptr : it->second;
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h
new file mode 100644
index 0000000000..2994642356
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
+
+// This file is depended on by kernels that have to build for mobile devices.
+// For this reason, we avoid relying on TensorFlow and instead only use the
+// standard C++ library.
+
+#include <mutex> // NOLINT
+#include <string>
+#include <unordered_map>
+
+namespace xla {
+namespace cpu {
+
+// The CPU JIT compiler uses this registry to resolve symbolic CustomCall
+// targets; so when using the CPU JIT, CustomCall targets need to be registered
+// here with the symbol name used in the CustomCall.
+//
+// The XLA AOT compiler links using a standard offline linker; so when compiling
+// in AOT mode, you *also* need to make sure the name of the callee (presumably
+// implemented in C++) matches up with the symbolic name used in the CustomCall.
+//
+// We maintain the registry in both the JIT and the AOT cases for simplicity,
+// but we only use it when running in JIT mode.
+class CustomCallTargetRegistry {
+ public:
+ static CustomCallTargetRegistry* Global();
+
+ void Register(const std::string& symbol, void* address);
+ void* Lookup(const std::string& symbol) const;
+
+ private:
+ std::unordered_map<std::string, void*> registered_symbols_;
+ mutable std::mutex mu_;
+};
+
+class RegisterCustomCallTarget {
+ public:
+ explicit RegisterCustomCallTarget(const std::string& name, void* address) {
+ CustomCallTargetRegistry::Global()->Register(name, address);
+ }
+};
+
+#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
+
+#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \
+ static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \
+ custom_call_target_register, counter)(symbol, \
+ reinterpret_cast<void*>(address))
+
+#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
+ REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__)
+
+#define REGISTER_CUSTOM_CALL_TARGET(function) \
+ REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
+
+} // namespace cpu
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
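As a usage note (a sketch, not part of the diff): the macro layering above simply constructs a file-local RegisterCustomCallTarget object whose constructor calls Register() on the global singleton, so a target can also be registered under an explicit symbol string via the _WITH_SYM variant. `MyCustomSum` and the "my_custom_sum" symbol name are hypothetical:

#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"

// Hypothetical callee registered under an explicit symbol name rather
// than one derived from the C++ identifier.
static void MyCustomSum(float* out, float** in) {
  *out = in[0][0] + in[0][1];
}

REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("my_custom_sum", MyCustomSum);

// The JIT-side resolver then queries the same singleton:
//   void* addr =
//       xla::cpu::CustomCallTargetRegistry::Global()->Lookup("my_custom_sum");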
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index cfffb3fbc3..fdf02e5b42 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
@@ -43,81 +44,6 @@ namespace xla {
namespace cpu {
namespace {
-// Converts a symbol 'name' into the form expected by dlsym().
-std::string CanonicalizeSymbol(const std::string& name) {
-#if defined(__APPLE__)
- // On Mac OS X, dlsym() expects names not to be prefixed with a leading
- // underscore.
- if (!name.empty() && name.front() == '_') {
- return name.substr(1);
- }
-#endif
- return name;
-}
-
-class JITSymbolTable {
- public:
- JITSymbolTable() { Populate(); }
-
- void* Lookup(llvm::StringRef jit_symbol_name) const {
- auto it = jit_symbol_table_.find(jit_symbol_name);
- return it == jit_symbol_table_.end() ? nullptr : it->getValue();
- }
-
- static bool MustBeInTable(llvm::StringRef name) {
- // In particular, names starting with
- // runtime::kXlaCpuRuntimeSymbolNamePrefix should not be dlsym'ed.
- return name.startswith(runtime::kXlaCpuRuntimeSymbolNamePrefix);
- }
-
- private:
- void AddJITSymbolToTable(llvm::StringRef jit_symbol_name,
- llvm::StringRef cpp_symbol_name,
- void* jit_symbol_value) {
- // The JIT symbol name and the C++ symbol name (with an extern "C" linkage)
- // need to match, otherwise AOT links will fail.
- CHECK(jit_symbol_name == cpp_symbol_name);
- CHECK(jit_symbol_table_.insert({jit_symbol_name, jit_symbol_value}).second);
- }
-
- void Populate() {
-#define ADD_JIT_SYMBOL_TO_TABLE(base_name) \
- do { \
- AddJITSymbolToTable( \
- xla::cpu::runtime::k##base_name##SymbolName, \
- "__xla_cpu_runtime_" #base_name, \
- reinterpret_cast<void*>(__xla_cpu_runtime_##base_name)); \
- } while (false)
-
- ADD_JIT_SYMBOL_TO_TABLE(AcquireInfeedBufferForDequeue);
- ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue);
- ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation);
- ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation);
- ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32AVX);
- ADD_JIT_SYMBOL_TO_TABLE(LogV8F32AVX);
- ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32SSE);
- ADD_JIT_SYMBOL_TO_TABLE(LogV4F32SSE);
- ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32NEON);
- ADD_JIT_SYMBOL_TO_TABLE(LogV4F32NEON);
- ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32);
- ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32);
- ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64);
- ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32);
- ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32);
- ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64);
- ADD_JIT_SYMBOL_TO_TABLE(ParallelForkJoin);
-
-#undef ADD_JIT_SYMBOL_TO_TABLE
- }
-
- llvm::StringMap<void*> jit_symbol_table_;
-};
-
-const JITSymbolTable& GetJITSymbolTable() {
- static JITSymbolTable* symbol_table = new JITSymbolTable;
- return *symbol_table;
-}
-
// A simple SymbolResolver that delegates to the host dynamic linker.
class SimpleResolver : public llvm::JITSymbolResolver {
public:
@@ -125,7 +51,6 @@ class SimpleResolver : public llvm::JITSymbolResolver {
: external_constant_pool_(external_constant_pool) {}
llvm::JITSymbol findSymbol(const std::string& name) override {
- string name_as_string(name);
if (const uint8* from_constant_pool =
external_constant_pool_->Find(string(name))) {
return llvm::JITEvaluatedSymbol(
@@ -133,13 +58,7 @@ class SimpleResolver : public llvm::JITSymbolResolver {
llvm::JITSymbolFlags::None);
}
- std::string canonical_name = CanonicalizeSymbol(name);
- const JITSymbolTable& jit_symbol_table = GetJITSymbolTable();
-
- void* func_addr = JITSymbolTable::MustBeInTable(canonical_name)
- ? jit_symbol_table.Lookup(canonical_name)
- : dlsym(RTLD_DEFAULT, canonical_name.c_str());
-
+ void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
if (func_addr == nullptr) {
return nullptr;
}
@@ -255,5 +174,118 @@ llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) {
return nullptr;
}
+namespace {
+// Register some known symbols with the CustomCallTargetRegistry.
+bool RegisterKnownJITSymbols() {
+ CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
+
+#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
+ do { \
+ auto* function_address = \
+ reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
+ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
+ function_address); \
+ CHECK_EQ( \
+ tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \
+ "__xla_cpu_runtime_" #base_name); \
+ } while (false)
+
+ REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
+ REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
+ REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON);
+ REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE);
+ REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX);
+ REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON);
+ REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE);
+ REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX);
+ REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
+ REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
+ REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+
+#undef REGISTER_CPU_RUNTIME_SYMBOL
+
+#define REGISTER_LIBM_SYMBOL(name) \
+ do { \
+ /* Register both the F32 and F64 variants of the libm symbol. */ \
+ registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
+ registry->Register(#name, reinterpret_cast<void*>(name)); \
+ } while (false)
+
+ REGISTER_LIBM_SYMBOL(acos);
+ REGISTER_LIBM_SYMBOL(acosh);
+ REGISTER_LIBM_SYMBOL(asin);
+ REGISTER_LIBM_SYMBOL(asinh);
+ REGISTER_LIBM_SYMBOL(atan);
+ REGISTER_LIBM_SYMBOL(atan2);
+ REGISTER_LIBM_SYMBOL(atanh);
+ REGISTER_LIBM_SYMBOL(cbrt);
+ REGISTER_LIBM_SYMBOL(ceil);
+ REGISTER_LIBM_SYMBOL(copysign);
+ REGISTER_LIBM_SYMBOL(cos);
+ REGISTER_LIBM_SYMBOL(cosh);
+ REGISTER_LIBM_SYMBOL(erf);
+ REGISTER_LIBM_SYMBOL(erfc);
+ REGISTER_LIBM_SYMBOL(exp);
+ REGISTER_LIBM_SYMBOL(exp2);
+ REGISTER_LIBM_SYMBOL(expm1);
+ REGISTER_LIBM_SYMBOL(fabs);
+ REGISTER_LIBM_SYMBOL(fdim);
+ REGISTER_LIBM_SYMBOL(floor);
+ REGISTER_LIBM_SYMBOL(fma);
+ REGISTER_LIBM_SYMBOL(fmax);
+ REGISTER_LIBM_SYMBOL(fmin);
+ REGISTER_LIBM_SYMBOL(fmod);
+ REGISTER_LIBM_SYMBOL(frexp);
+ REGISTER_LIBM_SYMBOL(hypot);
+ REGISTER_LIBM_SYMBOL(ilogb);
+ REGISTER_LIBM_SYMBOL(ldexp);
+ REGISTER_LIBM_SYMBOL(lgamma);
+ REGISTER_LIBM_SYMBOL(llrint);
+ REGISTER_LIBM_SYMBOL(llround);
+ REGISTER_LIBM_SYMBOL(log);
+ REGISTER_LIBM_SYMBOL(log10);
+ REGISTER_LIBM_SYMBOL(log1p);
+ REGISTER_LIBM_SYMBOL(log2);
+ REGISTER_LIBM_SYMBOL(logb);
+ REGISTER_LIBM_SYMBOL(lrint);
+ REGISTER_LIBM_SYMBOL(lround);
+ REGISTER_LIBM_SYMBOL(modf);
+ REGISTER_LIBM_SYMBOL(nan);
+ REGISTER_LIBM_SYMBOL(nearbyint);
+ REGISTER_LIBM_SYMBOL(nextafter);
+ REGISTER_LIBM_SYMBOL(nexttoward);
+ REGISTER_LIBM_SYMBOL(pow);
+ REGISTER_LIBM_SYMBOL(remainder);
+ REGISTER_LIBM_SYMBOL(remquo);
+ REGISTER_LIBM_SYMBOL(rint);
+ REGISTER_LIBM_SYMBOL(round);
+ REGISTER_LIBM_SYMBOL(scalbln);
+ REGISTER_LIBM_SYMBOL(scalbn);
+ REGISTER_LIBM_SYMBOL(sin);
+ REGISTER_LIBM_SYMBOL(sincos);
+ REGISTER_LIBM_SYMBOL(sinh);
+ REGISTER_LIBM_SYMBOL(sqrt);
+ REGISTER_LIBM_SYMBOL(tan);
+ REGISTER_LIBM_SYMBOL(tanh);
+ REGISTER_LIBM_SYMBOL(tgamma);
+ REGISTER_LIBM_SYMBOL(trunc);
+
+#undef REGISTER_LIBM_SYMBOL
+
+ registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
+ registry->Register("memmove", reinterpret_cast<void*>(memmove));
+ registry->Register("memset", reinterpret_cast<void*>(memset));
+ return true;
+}
+
+bool unused = RegisterKnownJITSymbols();
+} // namespace
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 43127925e6..2ea7b9bd8e 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -23,7 +23,6 @@ filegroup(
]),
)
-load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
@@ -988,13 +987,13 @@ xla_test(
xla_test(
name = "custom_call_test",
srcs = ["custom_call_test.cc"],
- linkopts = export_dynamic_linkopts,
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 342478bc74..74f73a1ddc 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -31,19 +32,19 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
-
-extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) {
+namespace {
+void R0F32Add2(float* out, float** in) {
TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
*out = **in + 2.0f;
}
-extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) {
+void R2F32ReduceSum(float* out, float** in) {
TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
float* array = in[0];
*out = array[0] + array[1] + array[2] + array[3];
}
-extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) {
+void Add1ToValues(float* out, float** in) {
TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
float* array = in[0];
out[0] = array[0] + 1;
@@ -51,6 +52,11 @@ extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) {
out[2] = array[2] + 1;
out[3] = array[3] + 1;
}
+} // namespace
+
+REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
+REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
+REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl
index 22e70ec97a..3fa5bcc1df 100644
--- a/tensorflow/compiler/xla/xla.bzl
+++ b/tensorflow/compiler/xla/xla.bzl
@@ -17,11 +17,3 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0):
protoc="@protobuf_archive//:protoc",
testonly=testonly,
visibility=visibility,)
-
-# Flags required for modules that export symbols that are to be called by the
-# XLA CustomCall operator. CustomCall must be able to find symbols with dlsym(),
-# which on Linux requires we link with --export-dynamic.
-export_dynamic_linkopts = select({
- "//tensorflow:darwin": [],
- "//conditions:default": ["-Wl,--export-dynamic"],
-})