diff options
author | 2017-10-24 10:41:58 -0700 | |
---|---|---|
committer | 2017-10-24 10:46:59 -0700 | |
commit | 720efa37a4e93d5833e6e928993790f2523f0d85 (patch) | |
tree | c3c8709db2128acbdd096a884b29396baa5eff82 | |
parent | 1bbec9e4e9c5d3fbbc2fa2b58841435e86dbf76a (diff) |
Roll forward CL 171084886
CL 171084886 had to be rolled back twice due to various open-source build issues.
I'm trying again, now that I think I've addressed all the pertinent issues.
Original CL description:
Don't use dlsym to resolve symbols in the CPU JIT
Instead of resolving symbols via dlsym when JITting for the CPU backend, use a
registry-based mechanism. This lets us kill off the --export-dynamic hack that
we used to need for CustomCall on the CPU backend.
PiperOrigin-RevId: 173277862
10 files changed, 259 insertions, 98 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 4ee7989824..2b43e313eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -5,7 +5,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") tf_kernel_library( name = "xla_ops", @@ -153,6 +152,7 @@ cc_library( srcs = ["index_ops_kernel_argmax_float_1d.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -164,6 +164,7 @@ cc_library( srcs = ["index_ops_kernel_argmax_float_2d.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index afbd64ca50..47cf8c6675 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -16,6 +16,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -47,3 +48,5 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } + +REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index 841ff2f4df..9b83392d8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -49,3 +50,5 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } + +REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 136cbe7cb7..56bc1a6706 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -153,6 +153,7 @@ cc_library( ":cpu_runtime_avx", ":cpu_runtime_neon", ":cpu_runtime_sse4_1", + ":custom_call_target_registry", ":disassembler", 
":external_constant_pool", ":runtime_conv2d", @@ -719,6 +720,17 @@ cc_library( ], ) +cc_library( + name = "custom_call_target_registry", + srcs = [ + "custom_call_target_registry.cc", + ], + hdrs = [ + "custom_call_target_registry.h", + ], + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc new file mode 100644 index 0000000000..5f5803874b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" + +namespace xla { +namespace cpu { + +CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { + static auto* registry = new CustomCallTargetRegistry; + return registry; +} + +void CustomCallTargetRegistry::Register(const std::string& symbol, + void* address) { + std::lock_guard<std::mutex> lock(mu_); + registered_symbols_[symbol] = address; +} + +void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const { + std::lock_guard<std::mutex> lock(mu_); + auto it = registered_symbols_.find(symbol); + return it == registered_symbols_.end() ? 
nullptr : it->second; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h new file mode 100644 index 0000000000..2994642356 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ + +// This file is depended on by kernels that have to build for mobile devices. +// For this reason, we avoid relying on TensorFlow and instead only use the +// standard C++ library. + +#include <mutex> // NOLINT +#include <string> +#include <unordered_map> + +namespace xla { +namespace cpu { + +// The CPU JIT compiler uses this registry to resolve symbolic CustomCall +// targets; so when using the CPU JIT, CustomCall targets need to be registered +// here with the symbol name used in the CustomCall. +// +// The XLA AOT compiler links using a standard offline linker; so when compiling +// in AOT mode, you *also* need to make sure the name of the callee (presumably +// implemented in C++) matches up with the symbolic name used in the CustomCall. 
+// +// We maintain the registry in both the JIT and the AOT cases for simplicity, +// but we only use it when running in JIT mode. +class CustomCallTargetRegistry { + public: + static CustomCallTargetRegistry* Global(); + + void Register(const std::string& symbol, void* address); + void* Lookup(const std::string& symbol) const; + + private: + std::unordered_map<std::string, void*> registered_symbols_; + mutable std::mutex mu_; +}; + +class RegisterCustomCallTarget { + public: + explicit RegisterCustomCallTarget(const std::string& name, void* address) { + CustomCallTargetRegistry::Global()->Register(name, address); + } +}; + +#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b + +#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \ + static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \ + custom_call_target_register, counter)(symbol, \ + reinterpret_cast<void*>(address)) + +#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ + REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__) + +#define REGISTER_CUSTOM_CALL_TARGET(function) \ + REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index cfffb3fbc3..fdf02e5b42 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" @@ -43,81 +44,6 @@ namespace xla { namespace cpu { namespace { -// Converts a symbol 'name' into the form expected by dlsym(). -std::string CanonicalizeSymbol(const std::string& name) { -#if defined(__APPLE__) - // On Mac OS X, dlsym() expects names not to be prefixed with a leading - // underscore. - if (!name.empty() && name.front() == '_') { - return name.substr(1); - } -#endif - return name; -} - -class JITSymbolTable { - public: - JITSymbolTable() { Populate(); } - - void* Lookup(llvm::StringRef jit_symbol_name) const { - auto it = jit_symbol_table_.find(jit_symbol_name); - return it == jit_symbol_table_.end() ? nullptr : it->getValue(); - } - - static bool MustBeInTable(llvm::StringRef name) { - // In particular, names starting with - // runtime::kXlaCpuRuntimeSymbolNamePrefix should not be dlsym'ed. - return name.startswith(runtime::kXlaCpuRuntimeSymbolNamePrefix); - } - - private: - void AddJITSymbolToTable(llvm::StringRef jit_symbol_name, - llvm::StringRef cpp_symbol_name, - void* jit_symbol_value) { - // The JIT symbol name and the C++ symbol name (with an extern "C" linkage) - // need to match, otherwise AOT links will fail. 
- CHECK(jit_symbol_name == cpp_symbol_name); - CHECK(jit_symbol_table_.insert({jit_symbol_name, jit_symbol_value}).second); - } - - void Populate() { -#define ADD_JIT_SYMBOL_TO_TABLE(base_name) \ - do { \ - AddJITSymbolToTable( \ - xla::cpu::runtime::k##base_name##SymbolName, \ - "__xla_cpu_runtime_" #base_name, \ - reinterpret_cast<void*>(__xla_cpu_runtime_##base_name)); \ - } while (false) - - ADD_JIT_SYMBOL_TO_TABLE(AcquireInfeedBufferForDequeue); - ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue); - ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation); - ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation); - ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32AVX); - ADD_JIT_SYMBOL_TO_TABLE(LogV8F32AVX); - ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32SSE); - ADD_JIT_SYMBOL_TO_TABLE(LogV4F32SSE); - ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32NEON); - ADD_JIT_SYMBOL_TO_TABLE(LogV4F32NEON); - ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64); - ADD_JIT_SYMBOL_TO_TABLE(ParallelForkJoin); - -#undef ADD_JIT_SYMBOL_TO_TABLE - } - - llvm::StringMap<void*> jit_symbol_table_; -}; - -const JITSymbolTable& GetJITSymbolTable() { - static JITSymbolTable* symbol_table = new JITSymbolTable; - return *symbol_table; -} - // A simple SymbolResolver that delegates to the host dynamic linker. 
class SimpleResolver : public llvm::JITSymbolResolver { public: @@ -125,7 +51,6 @@ class SimpleResolver : public llvm::JITSymbolResolver { : external_constant_pool_(external_constant_pool) {} llvm::JITSymbol findSymbol(const std::string& name) override { - string name_as_string(name); if (const uint8* from_constant_pool = external_constant_pool_->Find(string(name))) { return llvm::JITEvaluatedSymbol( @@ -133,13 +58,7 @@ class SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbolFlags::None); } - std::string canonical_name = CanonicalizeSymbol(name); - const JITSymbolTable& jit_symbol_table = GetJITSymbolTable(); - - void* func_addr = JITSymbolTable::MustBeInTable(canonical_name) - ? jit_symbol_table.Lookup(canonical_name) - : dlsym(RTLD_DEFAULT, canonical_name.c_str()); - + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; } @@ -255,5 +174,118 @@ llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { return nullptr; } +namespace { +// Register some known symbols with the CustomCallTargetRegistry. 
+bool RegisterKnownJITSymbols() { + CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); + +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ( \ + tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ + } while (false) + + REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); + REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); + REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); + REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); + REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); + REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); + REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); + REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + +#undef REGISTER_CPU_RUNTIME_SYMBOL + +#define REGISTER_LIBM_SYMBOL(name) \ + do { \ + /* Register both the F32 and F64 variants of the libm symbol. 
*/ \ + registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \ + registry->Register(#name, reinterpret_cast<void*>(name)); \ + } while (false) + + REGISTER_LIBM_SYMBOL(acos); + REGISTER_LIBM_SYMBOL(acosh); + REGISTER_LIBM_SYMBOL(asin); + REGISTER_LIBM_SYMBOL(asinh); + REGISTER_LIBM_SYMBOL(atan); + REGISTER_LIBM_SYMBOL(atan2); + REGISTER_LIBM_SYMBOL(atanh); + REGISTER_LIBM_SYMBOL(cbrt); + REGISTER_LIBM_SYMBOL(ceil); + REGISTER_LIBM_SYMBOL(copysign); + REGISTER_LIBM_SYMBOL(cos); + REGISTER_LIBM_SYMBOL(cosh); + REGISTER_LIBM_SYMBOL(erf); + REGISTER_LIBM_SYMBOL(erfc); + REGISTER_LIBM_SYMBOL(exp); + REGISTER_LIBM_SYMBOL(exp2); + REGISTER_LIBM_SYMBOL(expm1); + REGISTER_LIBM_SYMBOL(fabs); + REGISTER_LIBM_SYMBOL(fdim); + REGISTER_LIBM_SYMBOL(floor); + REGISTER_LIBM_SYMBOL(fma); + REGISTER_LIBM_SYMBOL(fmax); + REGISTER_LIBM_SYMBOL(fmin); + REGISTER_LIBM_SYMBOL(fmod); + REGISTER_LIBM_SYMBOL(frexp); + REGISTER_LIBM_SYMBOL(hypot); + REGISTER_LIBM_SYMBOL(ilogb); + REGISTER_LIBM_SYMBOL(ldexp); + REGISTER_LIBM_SYMBOL(lgamma); + REGISTER_LIBM_SYMBOL(llrint); + REGISTER_LIBM_SYMBOL(llround); + REGISTER_LIBM_SYMBOL(log); + REGISTER_LIBM_SYMBOL(log10); + REGISTER_LIBM_SYMBOL(log1p); + REGISTER_LIBM_SYMBOL(log2); + REGISTER_LIBM_SYMBOL(logb); + REGISTER_LIBM_SYMBOL(lrint); + REGISTER_LIBM_SYMBOL(lround); + REGISTER_LIBM_SYMBOL(modf); + REGISTER_LIBM_SYMBOL(nan); + REGISTER_LIBM_SYMBOL(nearbyint); + REGISTER_LIBM_SYMBOL(nextafter); + REGISTER_LIBM_SYMBOL(nexttoward); + REGISTER_LIBM_SYMBOL(pow); + REGISTER_LIBM_SYMBOL(remainder); + REGISTER_LIBM_SYMBOL(remquo); + REGISTER_LIBM_SYMBOL(rint); + REGISTER_LIBM_SYMBOL(round); + REGISTER_LIBM_SYMBOL(scalbln); + REGISTER_LIBM_SYMBOL(scalbn); + REGISTER_LIBM_SYMBOL(sin); + REGISTER_LIBM_SYMBOL(sincos); + REGISTER_LIBM_SYMBOL(sinh); + REGISTER_LIBM_SYMBOL(sqrt); + REGISTER_LIBM_SYMBOL(tan); + REGISTER_LIBM_SYMBOL(tanh); + REGISTER_LIBM_SYMBOL(tgamma); + REGISTER_LIBM_SYMBOL(trunc); + +#undef REGISTER_LIBM_SYMBOL + + 
registry->Register("memcpy", reinterpret_cast<void*>(memcpy)); + registry->Register("memmove", reinterpret_cast<void*>(memmove)); + registry->Register("memset", reinterpret_cast<void*>(memset)); + return true; +} + +bool unused = RegisterKnownJITSymbols(); +} // namespace + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 43127925e6..2ea7b9bd8e 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -23,7 +23,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") @@ -988,13 +987,13 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], - linkopts = export_dynamic_linkopts, deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 342478bc74..74f73a1ddc 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,19 +32,19 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" - -extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { +namespace { +void R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { +void R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { +void Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; out[0] = array[0] + 1; @@ -51,6 +52,11 @@ extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } +} // namespace + +REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); +REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); +REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); namespace xla { namespace { diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 22e70ec97a..3fa5bcc1df 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,11 +17,3 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) - -# Flags required for modules that export symbols that are to be called by the -# XLA CustomCall operator. 
CustomCall must be able to find symbols with dlsym(), -# which on Linux requires we link with --export-dynamic. -export_dynamic_linkopts = select({ - "//tensorflow:darwin": [], - "//conditions:default": ["-Wl,--export-dynamic"], -}) |